Package teamwork :: Package math :: Module ProbabilityTree
[hide private]
[frames] | [no frames]

Source Code for Module teamwork.math.ProbabilityTree

  1  """Defines the layer of probabilistic branches over L{KeyedTree}""" 
  2  from xml.dom.minidom import * 
  3   
  4  from probability import *     
  5  from KeyedTree import * 
  6   
class ProbabilityTree(KeyedTree):
    """A decision tree that supports probabilistic branches

    If this node is I{not} a probabilistic branch, then identical to a
    L{KeyedTree} object.  A probabilistic node has C{branchType} set to
    C{'probabilistic'} and stores a L{Distribution} over child subtrees
    in C{split}."""

    @staticmethod
    def _accumulate(table,key,prob):
        """Adds C{prob} to the entry for C{key} in C{table}, creating the
        entry if it is absent
        @param table: a C{dict} or L{Distribution} being built up
        @param key: the element whose probability is being increased
        @param prob: the probability mass to add"""
        try:
            table[key] += prob
        except KeyError:
            table[key] = prob

    def fill(self,keys,value=0.):
        """Fills in any missing slots with a default value
        @param keys: the slots that should be filled
        @type keys: list of L{Key} instances
        @param value: the default value (defaults to 0)
        @note: does not overwrite existing values"""
        if self.isProbabilistic():
            for subtree in self.children():
                try:
                    subtree.fill(keys,value)
                except AttributeError:
                    # Leaf is keyless
                    pass
        else:
            KeyedTree.fill(self,keys,value)

    def freeze(self):
        """Locks in the dimensions and keys of all leaves"""
        if self.isProbabilistic():
            for child in self.children():
                child.freeze()
        else:
            KeyedTree.freeze(self)

    def unfreeze(self):
        """Unlocks the dimensions and keys of all leaves"""
        if self.isProbabilistic():
            for child in self.children():
                child.unfreeze()
        else:
            KeyedTree.unfreeze(self)

    def isProbabilistic(self):
        """
        @return: true iff there's a probabilistic branch at this node
        @rtype: boolean"""
        return (not self.isLeaf()) and (self.branchType == 'probabilistic')

    def children(self):
        """
        @return: all child nodes of this node (for a probabilistic node,
        the domain of its distribution)
        @rtype: L{ProbabilityTree}[]
        """
        if self.isProbabilistic():
            return self.split.domain()
        else:
            return KeyedTree.children(self)

    def branch(self,plane,falseTree=None,trueTree=None,
               pruneF=True,pruneT=True,prune=True,debug=False):
        """Same as C{L{KeyedTree}.branch}, except that plane can be a L{Distribution}
        @param plane: if a L{Hyperplane}, then the arguments are interpreted
        as for C{L{KeyedTree}.branch}; if a L{Distribution}, then the tree
        arguments are ignored and the distribution's elements become the
        children of this node
        @type plane: L{Hyperplane}/L{Distribution}(L{ProbabilityTree})
        @param prune: used (iff C{plane} is a L{Distribution}) to determine
        whether the given subtrees should be pruned
        @type prune: C{boolean}
        """
        if isinstance(plane,Distribution):
            self.branchType = 'probabilistic'
            self.split = plane
            self.falseTree = None
            self.trueTree = None
            # Each child records this node and its own key as its parent link
            for key,subtree in plane._domain.items():
                subtree.parent = (self,key)
            if prune:
                for subtree in self.children():
                    if isinstance(subtree,DecisionTree):
                        subtree.prune()
        else:
            KeyedTree.branch(self,plane,falseTree,trueTree,pruneF,pruneT,debug)

    def _merge(self,other,op,comparisons=None,conditions=None):
        """Helper method that merges the two trees together using the given
        operator to combine leaf values, without pruning
        @param other: the other tree to merge with
        @type other: L{DecisionTree} instance
        @param op: the operator used to generate the new leaf values,
        C{lambda x,y:f(x,y)} where C{x} and C{y} are leaf values
        @param comparisons: dictionary passed through to recursive calls
        (a fresh one is created if omitted)
        @param conditions: list passed through to recursive calls (a fresh
        one is created if omitted)
        @rtype: a new L{DecisionTree} instance"""
        if comparisons is None:
            comparisons = {}
        if conditions is None:
            # A fresh list per call, rather than one shared default list
            conditions = []
        if self.isProbabilistic():
            # Merge each of my possible subtrees with the other tree
            dist = {}
            for child,prob in self.split.items():
                newChild = child._merge(other,op,comparisons,conditions)
                self._accumulate(dist,newChild,prob)
        elif not self.isLeaf():
            return KeyedTree._merge(self,other,op,comparisons,conditions)
        elif other.isProbabilistic():
            # I am a leaf: merge myself with each of the other's subtrees
            dist = {}
            for child,prob in other.split.items():
                newChild = self._merge(child,op,comparisons,conditions)
                self._accumulate(dist,newChild,prob)
        else:
            return KeyedTree._merge(self,other,op,comparisons,conditions)
        result = self.__class__()
        result.branch(Distribution(dist),prune=False)
        return result

    def prune(self,comparisons=None,debug=False):
        """Prunes this tree in place; a probabilistic node prunes each of
        its subtrees, any other node defers to C{L{KeyedTree}.prune}
        @param comparisons: dictionary passed through to the recursive
        calls (a fresh one is created if omitted)
        @param debug: passed through to the recursive calls"""
        if comparisons is None:
            comparisons = {}
        if self.isProbabilistic():
            for subtree in self.children():
                subtree.prune(comparisons,debug)
        else:
            KeyedTree.prune(self,comparisons,debug)

    def marginalize(self,key):
        """Marginalizes any distributions to remove the given key (not in
        place! returns the new tree)
        @param key: the key to marginalize over
        @return: a new L{ProbabilityTree} object representing the marginal function
        @note: no exception is raised if the key is not present"""
        result = self.__class__()
        if self.isProbabilistic():
            distribution = {}
            for element,prob in self.split.items():
                if isinstance(element,ProbabilityTree):
                    new = element.marginalize(key)
                else:
                    new = copy.deepcopy(element)
                    new.unfreeze()
                    try:
                        del new[key]
                    except KeyError:
                        pass
                self._accumulate(distribution,new,prob)
            result.branch(Distribution(distribution))
        elif self.isLeaf():
            new = copy.deepcopy(self.getValue())
            new.unfreeze()
            try:
                del new[key]
            except KeyError:
                pass
            result.makeLeaf(new)
        else:
            fTree,tTree = self.getValue()
            result.branch(self.split,fTree.marginalize(key),
                          tTree.marginalize(key))
        return result

    def condition(self,observation):
        """Restricts this tree to the branches consistent with an
        observation (not in place! returns a new tree)
        @param observation: mapping from keys to observed values
        @return: the conditioned tree, or C{None} if no branch is
        consistent with the observation
        @rtype: L{ProbabilityTree}
        @warning: leaves must be L{KeyedMatrix} objects, and only rows that
        set a feature to a constant are handled"""
        result = self.__class__()
        if self.isProbabilistic():
            distribution = {}
            for element,prob in self.split.items():
                element = element.condition(observation)
                if element is not None:
                    self._accumulate(distribution,element,prob)
            if len(distribution) == 0:
                result = None
            else:
                result.branch(Distribution(distribution))
        elif self.isLeaf():
            # Assuming that the leaf is a matrix
            matrix = self.getValue()
            assert(isinstance(matrix,KeyedMatrix))
            for rowKey in observation.keys():
                try:
                    row = matrix[rowKey]
                except KeyError:
                    # No row for this key: treat it as setting the value to
                    # 0.  This must be a KeyedVector, not a plain dict, or the
                    # assertion below always fails.
                    row = KeyedVector({keyConstant:0.})
                # WARNING: this handles only SetToConstant rows!
                assert(isinstance(row,KeyedVector))
                for colKey,value in row.items():
                    if isinstance(colKey,ConstantKey):
                        if value != observation[rowKey]:
                            # Setting it to a different value
                            return None
                    elif abs(value) > epsilon:
                        # Adding it to something, kind of different
                        return None
            result.makeLeaf(matrix)
        else:
            fTree,tTree = self.getValue()
            fTree = fTree.condition(observation)
            tTree = tTree.condition(observation)
            if fTree is None:
                # Only the true side survives (or neither, if tTree is None)
                result = tTree
            elif tTree is None:
                result = fTree
            else:
                result.branch(self.split,fTree,tTree)
        return result

    def instantiate(self,values):
        """
        @return: for a probabilistic node, a new tree branching on
        C{self.split.instantiate(values)}; otherwise, the result of
        C{L{KeyedTree}.instantiate}"""
        if self.isProbabilistic():
            new = self.__class__()
            new.branch(self.split.instantiate(values))
            return new
        else:
            return KeyedTree.instantiate(self,values)

    def instantiateKeys(self,values):
        """Instantiates the keys in this tree; a probabilistic node
        recurses into each subtree in place (and returns C{None}), any
        other node returns the result of C{L{KeyedTree}.instantiateKeys}"""
        if self.isProbabilistic():
            for subtree in self.children():
                subtree.instantiateKeys(values)
        else:
            return KeyedTree.instantiateKeys(self,values)

    def generateAlternatives(self,index,value,test=None):
        """Same as C{L{KeyedTree}.generateAlternatives}, except that
        probabilistic nodes and L{Distribution} indices are expanded, with
        each alternative's C{'probability'} entry scaled by (or set to) the
        probability of the branch that produced it
        @rtype: C{dict}[]"""
        if self.isProbabilistic():
            alternatives = []
            for subtree,prob in self.split.items():
                for alt in subtree.generateAlternatives(index,value,test):
                    try:
                        alt['probability'] *= prob
                    except KeyError:
                        alt['probability'] = prob
                    alternatives.append(alt)
            return alternatives
        elif isinstance(index,Distribution):
            alternatives = []
            for subIndex,prob in index.items():
                for alt in KeyedTree.generateAlternatives(self,subIndex,value,
                                                          test):
                    try:
                        alt['probability'] *= prob
                    except KeyError:
                        alt['probability'] = prob
                    alternatives.append(alt)
            return alternatives
        else:
            return KeyedTree.generateAlternatives(self,index,value,test)

    def simpleText(self,printLeaves=True,numbers=True,all=False):
        """Returns a more readable string version of this tree
        @param printLeaves: optional flag indicating whether the leaves
        should also be converted into a user-friendly string
        @type printLeaves: C{boolean}
        @param numbers: if C{True}, floats are used to represent the
        threshold; otherwise, an automatically generated English
        representation (defaults to C{True})
        @type numbers: boolean
        @param all: passed through unchanged to the subtrees /
        C{L{KeyedTree}.simpleText}
        @rtype: C{str}
        """
        if self.isProbabilistic():
            lines = []
            for subtree,prob in self.split.items():
                substr = subtree.simpleText(printLeaves,numbers,all)
                lines.append('%s with probability %5.3f\n' % (substr,prob))
            return ''.join(lines)
        else:
            return KeyedTree.simpleText(self,printLeaves,numbers,all)

    def updateKeys(self):
        """Updates the keys of this tree; a probabilistic node recurses
        into each subtree
        @return: C{self.keys}
        @note: NOTE(review): for a probabilistic node, C{self.keys} itself
        is not recomputed here -- confirm this is intended"""
        if self.isProbabilistic():
            for subtree in self.children():
                subtree.updateKeys()
        else:
            KeyedTree.updateKeys(self)
        return self.keys

    def __getitem__(self,index):
        """
        @param index: the value (or L{Distribution} over values) to look up
        @return: the distribution over leaf nodes for this value
        """
        if self.isProbabilistic():
            result = Distribution()
            for subtree,prob in self.split.items():
                value = subtree[index]
                if isinstance(value,Distribution):
                    for subValue,subProb in value.items():
                        self._accumulate(result,subValue,subProb*prob)
                else:
                    self._accumulate(result,value,prob)
        elif isinstance(index,Distribution):
            result = Distribution()
            for subIndex,prob in index.items():
                value = KeyedTree.__getitem__(self,subIndex)
                if isinstance(value,Distribution):
                    for subValue,subProb in value.items():
                        self._accumulate(result,subValue,prob*subProb)
                else:
                    # Update return distribution
                    self._accumulate(result,value,prob)
        else:
            result = KeyedTree.__getitem__(self,index)
        # <HACK>
        # By default, matrices use a cruder (but faster) string rep, so there
        # will likely be duplicate matrix entries that need to be consolidated
        if isinstance(result,Distribution) and len(result) > 0 \
               and isinstance(result.domain()[0],KeyedMatrix):
            for value in list(result.domain()):
                text = value.simpleText()
                for other in result.domain():
                    if value is not other and other.simpleText() == text:
                        # Use the current mass of value, not a snapshot taken
                        # before the loop: value may already have absorbed an
                        # earlier duplicate, whose mass must not be lost
                        result[other] += result[value]
                        del result[value]
                        break
        # </HACK>
        return result

    def _multiply(self,other,comparisons=None,conditions=None):
        """Helper method that multiplies this tree by C{other}
        @param other: another tree, a L{Distribution}, or a L{KeyedVector}
        @param comparisons: dictionary passed through to recursive calls
        (a fresh one is created if omitted)
        @param conditions: list passed through to recursive calls (a fresh
        one is created if omitted)
        @return: a new L{ProbabilityTree} when either operand is
        probabilistic; a L{Distribution} when C{other} is a L{Distribution};
        C{self[other]*other} when C{other} is a L{KeyedVector}; otherwise,
        the result of C{L{KeyedTree}._multiply}"""
        if comparisons is None:
            comparisons = {}
        if conditions is None:
            # A fresh list per call, rather than one shared default list
            conditions = []
        if self.isProbabilistic():
            distribution = {}
            if other.isProbabilistic():
                # Cross product of both distributions over subtrees
                for myChild,myProb in self.split.items():
                    for yrChild,yrProb in other.split.items():
                        new = myChild._multiply(yrChild,comparisons,conditions)
                        self._accumulate(distribution,new,myProb*yrProb)
            else:
                for myChild,myProb in self.split.items():
                    new = myChild._multiply(other,comparisons,conditions)
                    self._accumulate(distribution,new,myProb)
            result = self.__class__()
            result.branch(Distribution(distribution))
            return result
        elif isinstance(other,Distribution):
            distribution = {}
            for yrChild,yrProb in other.items():
                product = self._multiply(yrChild,comparisons,conditions)
                if isinstance(product,Distribution):
                    for element,myProb in product.items():
                        self._accumulate(distribution,element,myProb*yrProb)
                else:
                    self._accumulate(distribution,product,yrProb)
            return Distribution(distribution)
        elif isinstance(other,KeyedVector):
            return self[other]*other
        elif other.isProbabilistic():
            distribution = {}
            for yrChild,yrProb in other.split.items():
                new = self._multiply(yrChild,comparisons,conditions)
                self._accumulate(distribution,new,yrProb)
            result = self.__class__()
            result.branch(Distribution(distribution))
            return result
        else:
            return KeyedTree._multiply(self,other,comparisons,conditions)

    def __str__(self):
        """@return: C{self.simpleText(printLeaves=True)}"""
        return self.simpleText(printLeaves=True)

    def __xml__(self):
        """
        @return: an XML document whose C{tree} root element has
        C{type="probabilistic"} and contains the serialized distribution
        (for probabilistic nodes), or C{L{KeyedTree}.__xml__} otherwise
        @rtype: Document"""
        if self.isProbabilistic():
            doc = Document()
            root = doc.createElement('tree')
            doc.appendChild(root)
            root.setAttribute('type','probabilistic')
            root.appendChild(self.split.__xml__().documentElement)
            return doc
        else:
            return KeyedTree.__xml__(self)

    def parse(self,element,valueClass=None,debug=False):
        """Extracts the tree from the given XML element
        @param element: The XML Element object specifying the plane
        @type element: Element
        @param valueClass: The class used to generate the leaf values
        (defaults to L{KeyedMatrix})
        @return: the L{ProbabilityTree} instance"""
        if not valueClass:
            valueClass = KeyedMatrix
        if element.getAttribute('type') == 'probabilistic':
            # This branch is a distribution over subtrees.  Skip any
            # leading text/comment nodes (e.g., whitespace in
            # pretty-printed XML) to reach the distribution element.
            child = element.firstChild
            while child is not None and child.nodeType != child.ELEMENT_NODE:
                child = child.nextSibling
            split = Distribution()
            split.parse(child,ProbabilityTree)
            self.branch(split)
        else:
            # This is a leaf or deterministic branch
            KeyedTree.parse(self,element,valueClass,debug)
        return self
def createBranchTree(plane,falseTree,trueTree):
    """Builds a new tree whose root tests a single plane
    @param plane: the plane tested at the root
    @type plane: L{Hyperplane}
    @param falseTree: the subtree followed when the plane tests C{False}
    @param trueTree: the subtree followed when the plane tests C{True}
    @type falseTree,trueTree: L{ProbabilityTree}
    @return: the newly constructed tree
    @rtype: L{ProbabilityTree}
    @note: neither subtree is pruned
    """
    result = ProbabilityTree()
    result.branch(plane,falseTree=falseTree,trueTree=trueTree,
                  pruneT=False,pruneF=False)
    return result
433
def createNodeTree(node=None):
    """Builds a new tree consisting of a single leaf
    @param node: the value stored at the leaf (defaults to C{None})
    @rtype: L{ProbabilityTree}"""
    leaf = ProbabilityTree()
    leaf.makeLeaf(node)
    return leaf
439
def createEqualTree(plane,equalTree,unequalTree):
    """Builds a two-level tree that distinguishes values lying on the
    plane from all others
    The root tests C{plane}: its C{True} side leads to C{unequalTree}.  Its
    C{False} side tests a copy of C{plane} whose threshold is lowered by
    C{2*epsilon}; that test's C{True} side leads to C{equalTree} and its
    C{False} side to C{unequalTree}.
    @param plane: the plane on which equality is tested
    @param equalTree: the subtree followed when the value lies on the plane
    @param unequalTree: the subtree followed otherwise (shared by both
    unequal leaves)
    @rtype: L{ProbabilityTree}"""
    # Shallow copy, so only the copy's threshold is shifted
    lowerPlane = copy.copy(plane)
    lowerPlane.threshold -= 2.*epsilon
    return createBranchTree(plane,
                            createBranchTree(lowerPlane,unequalTree,equalTree),
                            unequalTree)
449
def createDynamicNode(feature,weights):
    """Builds a single-leaf tree holding a dynamics matrix with one row
    @param feature: the row key; anything that is not already a L{Key} is
    treated as a feature name of C{'self'} and converted via C{makeStateKey}
    @param weights: the row's weights, either a L{KeyedVector} or a plain
    dictionary (which is wrapped in a L{KeyedVector})
    @rtype: L{ProbabilityTree}"""
    if not isinstance(feature,Key):
        feature = makeStateKey('self',feature)
    if not isinstance(weights,KeyedVector):
        weights = KeyedVector(weights)
    return createNodeTree(KeyedMatrix({feature:weights}))
463
def createANDTree(keyWeights,falseTree,trueTree):
    """
    To create a tree that follows the C{True} branch iff both the actor has accepted and the negotiation is not terminated:

    >>> tree = createANDTree([(StateKey({'entity':'actor','feature':'accepted'}),True), (StateKey({'entity':'self','feature':'terminated'}),False)], falseTree, trueTree)

    @note: the default truth value of the plane is C{True} (i.e., if no keys are provided, then C{trueTree} itself is returned)
    @param keyWeights: a list of tuples, C{(key,True/False)}, of the preconditions for the test to be true
    @type keyWeights: (L{Key},boolean)[]
    @param falseTree: the tree to invoke if the conjunction evaluates to C{False}
    @param trueTree: the tree to invoke if the conjunction evaluates to C{True}
    @type falseTree,trueTree: L{DecisionTree}
    @return: the new tree with the conjunction test at the root
    @rtype: L{ProbabilityTree}
    """
    if not keyWeights:
        return trueTree
    count = float(len(keyWeights))
    share = 1./count
    weights = {}
    for key,truth in keyWeights:
        if truth:
            weights[key] = share
        else:
            # A negated precondition contributes -share when the key holds;
            # the matching +share on the constant term offsets it when the
            # key does not hold
            weights[key] = -share
            weights[keyConstant] = weights.get(keyConstant,0.) + share
    row = ANDRow(args=weights,keys=[key for key,truth in keyWeights])
    return createBranchTree(KeyedPlane(row,1.-1/(2.*count)),
                            falseTree,trueTree)
495
def createORTree(keyWeights,falseTree,trueTree):
    """
    To create a tree that follows the C{True} branch iff either the actor has accepted or the negotiation is not terminated:

    >>> tree = createORTree([(StateKey({'entity':'actor','feature':'accepted'}),True), (StateKey({'entity':'self','feature':'terminated'}),False)], falseTree, trueTree)

    @note: the default truth value of the plane is C{False} (i.e., if no keys are provided, then C{falseTree} itself is returned)
    @param keyWeights: a list of tuples, C{(key,True/False)}, of the conditions, any one of which makes the test true
    @type keyWeights: (L{Key},boolean)[]
    @param falseTree: the tree to invoke if the disjunction evaluates to C{False}
    @param trueTree: the tree to invoke if the disjunction evaluates to C{True}
    @type falseTree,trueTree: L{DecisionTree}
    @return: the new tree with the disjunction test at the root
    @rtype: L{ProbabilityTree}
    """
    if len(keyWeights) == 0:
        return falseTree
    weights = ORRow(keys=[key for key,truth in keyWeights])
    length = float(len(keyWeights))
    for key,truth in keyWeights:
        if truth:
            weights[key] = 1./length
        else:
            # A negated condition contributes -1/n when the key holds; the
            # matching +1/n on the constant term offsets it otherwise
            weights[key] = -1./length
            try:
                weights[keyConstant] += 1./length
            except KeyError:
                weights[keyConstant] = 1./length
    # Any single satisfied condition contributes 1/n, which exceeds 1/(2n)
    plane = KeyedPlane(weights,1/(2.*length))
    return createBranchTree(plane,falseTree,trueTree)
526
def identityTree(feature):
    """Builds a tree whose dynamics leave the given feature unchanged
    @param feature: the state feature whose dynamics we are creating
    @type feature: C{str}/L{Key}
    @return: a tree built directly from C{IdentityMatrix(feature)}
    @rtype: L{ProbabilityTree}
    """
    matrix = IdentityMatrix(feature)
    return ProbabilityTree(matrix)
if __name__ == '__main__':
    # Debugging aid: parse a serialized tree and print it in readable form.
    # The XML file may be given on the command line; otherwise the
    # original developer-specific default path is used.
    import sys
    if len(sys.argv) > 1:
        filename = sys.argv[1]
    else:
        filename = '/tmp/pynadath/tree.xml'
    f = open(filename)
    try:
        data = f.read()
    finally:
        # Close the file even if reading fails
        f.close()
    doc = parseString(data)
    tree = ProbabilityTree()
    tree.parse(doc.documentElement)
    # Parenthesized form prints the same under Python 2 and 3
    print(tree.simpleText())