1 """Defines the layer of probabilistic branches over L{KeyedTree}"""
2 from xml.dom.minidom import *
3
4 from probability import *
5 from KeyedTree import *
6
8 """A decision tree that supports probabilistic branches
9
10 If this node is I{not} a probabilistic branch, then identical to a L{KeyedTree} object."""
11
def fill(self, keys, value=0.):
    """Fills in any missing slots with a default value
    @param keys: the slots that should be filled
    @type keys: list of L{Key} instances
    @param value: the default value (defaults to 0)
    @note: does not overwrite existing values
    @note: at a probabilistic branch, the fill is forwarded to every child
    that has a C{fill} method; children without one are skipped. Any other
    node is handled by C{KeyedTree.fill}.
    """
    if not self.isProbabilistic():
        KeyedTree.fill(self, keys, value)
        return
    for subtree in self.children():
        # Skip only children that do not support fill(); an AttributeError
        # raised *inside* a child's fill() is a real bug and must propagate
        # instead of being silently swallowed.
        subFill = getattr(subtree, 'fill', None)
        if subFill is not None:
            subFill(keys, value)
27
35
43
45 """
46 @return: true iff there's a probabilistic branch at this node
47 @rtype: boolean"""
48 return (not self.isLeaf()) and (self.branchType == 'probabilistic')
49
59
def branch(self, plane, falseTree=None, trueTree=None,
           pruneF=True, pruneT=True, prune=True, debug=False):
    """Installs a branch at this node; same as C{L{KeyedTree}.branch},
    except that C{plane} may also be a L{Distribution} over subtrees
    @param plane: if a L{Hyperplane}, all the other arguments (except
    C{prune}) are forwarded to C{L{KeyedTree}.branch}; if a
    L{Distribution}, this node becomes a probabilistic branch over the
    distribution's elements, and C{falseTree}, C{trueTree}, C{pruneF},
    C{pruneT} and C{debug} are ignored
    @type plane: L{Hyperplane}/L{Distribution}(L{ProbabilityTree})
    @param prune: only consulted when C{plane} is a L{Distribution}; if
    C{True}, every child that is a L{DecisionTree} is pruned
    @type prune: C{boolean}
    """
    if not isinstance(plane, Distribution):
        KeyedTree.branch(self, plane, falseTree, trueTree, pruneF, pruneT, debug)
        return
    # Probabilistic branch: the distribution itself is the split, and the
    # binary false/true children are cleared.
    self.branchType = 'probabilistic'
    self.split = plane
    self.falseTree = self.trueTree = None
    # Give each alternative a back-pointer of (this node, its element key)
    for elementKey, alternative in plane._domain.items():
        alternative.parent = (self, elementKey)
    if not prune:
        return
    for child in self.children():
        if isinstance(child, DecisionTree):
            child.prune()
81
def _merge(self, other, op, comparisons=None, conditions=None):
    """Helper method that merges the two trees together using the given operator to combine leaf values, without pruning
    @param other: the other tree to merge with
    @type other: L{DecisionTree} instance
    @param op: the operator used to generate the new leaf values, C{lambda x,y:f(x,y)} where C{x} and C{y} are leaf values
    @param comparisons: passed through unchanged to the recursive calls and
    to C{KeyedTree._merge}; a fresh C{dict} is used when omitted
    @param conditions: passed through unchanged to the recursive calls and
    to C{KeyedTree._merge}; a fresh list is used when omitted (never a
    shared default, so state cannot leak between top-level calls)
    @rtype: a new L{DecisionTree} instance
    @note: when either side has a probabilistic branch at the point being
    merged, the result is a probabilistic branch in which alternatives
    that merge into equal results have their probabilities summed
    """
    if comparisons is None:
        comparisons = {}
    if conditions is None:
        conditions = []
    if self.isProbabilistic():
        # Merge each alternative of this node's distribution with the other tree
        result = self.__class__()
        dist = {}
        for child, prob in self.split.items():
            newChild = child._merge(other, op, comparisons, conditions)
            try:
                dist[newChild] += prob
            except KeyError:
                dist[newChild] = prob
        result.branch(Distribution(dist), prune=False)
        return result
    if self.isLeaf() and other.isProbabilistic():
        # A leaf merged with a probabilistic tree: merge the leaf into each
        # of the other tree's alternatives
        result = self.__class__()
        dist = {}
        for child, prob in other.split.items():
            newChild = self._merge(child, op, comparisons, conditions)
            try:
                dist[newChild] += prob
            except KeyError:
                dist[newChild] = prob
        result.branch(Distribution(dist), prune=False)
        return result
    # Deterministic branch at this node, or leaf vs. non-probabilistic tree
    return KeyedTree._merge(self, other, op, comparisons, conditions)
116
117 - def prune(self,comparisons=None,debug=False):
125
127 """Marginalizes any distributions to remove the given key (not in place! returns the new tree)
128 @param key: the key to marginalize over
129 @return: a new L{ProbabilityTree} object representing the marginal function
130 @note: no exception is raised if the key is not present"""
131 result = self.__class__()
132 if self.isProbabilistic():
133 distribution = {}
134 for element,prob in self.split.items():
135 if isinstance(element,ProbabilityTree):
136 new = element.marginalize(key)
137 else:
138 new = copy.deepcopy(element)
139 new.unfreeze()
140 try:
141 del new[key]
142 except KeyError:
143 pass
144 try:
145 distribution[new] += prob
146 except KeyError:
147 distribution[new] = prob
148 result.branch(Distribution(distribution))
149 elif self.isLeaf():
150 new = copy.deepcopy(self.getValue())
151 new.unfreeze()
152 try:
153 del new[key]
154 except KeyError:
155 pass
156 result.makeLeaf(new)
157 else:
158 fTree,tTree = self.getValue()
159 result.branch(self.split,fTree.marginalize(key),
160 tTree.marginalize(key))
161 return result
162
164 result = self.__class__()
165 if self.isProbabilistic():
166 distribution = {}
167 for element,prob in self.split.items():
168 element = element.condition(observation)
169 if element is not None:
170 try:
171 distribution[element] += prob
172 except KeyError:
173 distribution[element] = prob
174 if len(distribution) == 0:
175 result = None
176 else:
177 result.branch(Distribution(distribution))
178 elif self.isLeaf():
179
180 matrix = self.getValue()
181 assert(isinstance(matrix,KeyedMatrix))
182 for rowKey in observation.keys():
183 try:
184 row = matrix[rowKey]
185 except KeyError:
186 row = {keyConstant:0.}
187
188 assert(isinstance(row,KeyedVector))
189 for colKey,value in row.items():
190 if isinstance(colKey,ConstantKey):
191 if value != observation[rowKey]:
192
193 return None
194 else:
195 if abs(value) > epsilon:
196
197 return None
198 result.makeLeaf(matrix)
199 else:
200 fTree,tTree = self.getValue()
201 fTree = fTree.condition(observation)
202 tTree = tTree.condition(observation)
203 if fTree is None:
204 if tTree is None:
205 result = None
206 else:
207 result = tTree
208 elif tTree is None:
209 result = fTree
210 else:
211 result.branch(self.split,fTree,tTree)
212 return result
213
221
228
230 if self.isProbabilistic():
231 alternatives = []
232 for subtree,prob in self.split.items():
233 for alt in subtree.generateAlternatives(index,value,test):
234 try:
235 alt['probability'] *= prob
236 except KeyError:
237 alt['probability'] = prob
238 alternatives.append(alt)
239 return alternatives
240 elif isinstance(index,Distribution):
241 alternatives = []
242 for subIndex,prob in index.items():
243 for alt in KeyedTree.generateAlternatives(self,subIndex,value,
244 test):
245 try:
246 alt['probability'] *= prob
247 except KeyError:
248 alt['probability'] = prob
249 alternatives.append(alt)
250 return alternatives
251 else:
252 return KeyedTree.generateAlternatives(self,index,value,test)
253
def simpleText(self, printLeaves=True, numbers=True, all=False):
    """Returns a more readable string version of this tree
    @param printLeaves: optional flag indicating whether the leaves should also be converted into a user-friendly string
    @type printLeaves: C{boolean}
    @param numbers: if C{True} (the default), floats are used to represent the threshold; otherwise, an automatically generated English representation
    @type numbers: C{boolean}
    @param all: forwarded unchanged to the subtrees' C{simpleText} (or to C{KeyedTree.simpleText})
    @rtype: C{str}
    @note: at a probabilistic branch, each alternative's text is followed by
    C{' with probability P'} (P printed as C{%5.3f}) and a newline
    """
    if self.isProbabilistic():
        # Build the pieces and join once, rather than repeated string +=
        pieces = ['%s with probability %5.3f\n'
                  % (subtree.simpleText(printLeaves, numbers, all), prob)
                  for subtree, prob in self.split.items()]
        return ''.join(pieces)
    return KeyedTree.simpleText(self, printLeaves, numbers, all)
270
278
280 """
281 @return: the distribution over leaf nodes for this value
282 """
283 if self.isProbabilistic():
284 result = Distribution()
285 for subtree,prob in self.split.items():
286 value = subtree[index]
287 if isinstance(value,Distribution):
288 for subValue,subProb in value.items():
289 try:
290 result[subValue] += subProb*prob
291 except KeyError:
292 result[subValue] = subProb*prob
293 else:
294 try:
295 result[value] += prob
296 except KeyError:
297 result[value] = prob
298 elif isinstance(index,Distribution):
299 result = Distribution()
300 for subIndex,prob in index.items():
301 value = KeyedTree.__getitem__(self,subIndex)
302 if isinstance(value,Distribution):
303 for subValue,subProb in value.items():
304 try:
305 result[subValue] += prob*subProb
306 except KeyError:
307 result[subValue] = prob*subProb
308 else:
309
310 try:
311 result[value] += prob
312 except KeyError:
313 result[value] = prob
314 else:
315 result = KeyedTree.__getitem__(self,index)
316
317
318 if isinstance(result,Distribution) and len(result) > 0 \
319 and isinstance(result.domain()[0],KeyedMatrix):
320 for value,prob in result.items():
321 for other in result.domain():
322 if value is not other:
323 if other.simpleText() == value.simpleText():
324 result[other] += prob
325 del result[value]
326 break
327
328 return result
329
def _multiply(self, other, comparisons=None, conditions=None):
    """Multiplies this tree by C{other}, lifting probabilistic branches into the result
    @param other: the right-hand operand. When this node is probabilistic,
    C{other} must provide C{isProbabilistic()}. Otherwise, a L{Distribution},
    a L{KeyedVector}, or a tree with a probabilistic root is handled here,
    and anything else is delegated to C{KeyedTree._multiply}
    @param comparisons: passed through unchanged to recursive calls and to
    C{KeyedTree._multiply}; a fresh C{dict} is used when omitted
    @param conditions: passed through unchanged to recursive calls and to
    C{KeyedTree._multiply}; a fresh list is used when omitted (never a
    shared default, so state cannot leak between top-level calls)
    @return: a new tree of this node's class with a probabilistic root when
    this node, or a non-special C{other}, is probabilistic; a
    L{Distribution} when C{other} is a L{Distribution}; C{self[other]*other}
    when C{other} is a L{KeyedVector}; otherwise whatever
    C{KeyedTree._multiply} returns
    @note: alternatives that produce equal (same dict key) products have
    their probabilities summed
    """
    if comparisons is None:
        comparisons = {}
    if conditions is None:
        conditions = []
    if self.isProbabilistic():
        # NOTE(review): other.isProbabilistic() is called unconditionally
        # here; Distribution/KeyedVector operands are only special-cased
        # below, when this root is NOT probabilistic -- confirm callers never
        # pass those when self is probabilistic.
        if other.isProbabilistic():
            # Cross product of both distributions
            result = self.__class__()
            distribution = {}
            for myChild, myProb in self.split.items():
                for yrChild, yrProb in other.split.items():
                    product = myChild._multiply(yrChild, comparisons, conditions)
                    try:
                        distribution[product] += myProb*yrProb
                    except KeyError:
                        distribution[product] = myProb*yrProb
            result.branch(Distribution(distribution))
            return result
        else:
            result = self.__class__()
            distribution = {}
            for myChild, myProb in self.split.items():
                product = myChild._multiply(other, comparisons, conditions)
                try:
                    distribution[product] += myProb
                except KeyError:
                    distribution[product] = myProb
            result.branch(Distribution(distribution))
            return result
    elif isinstance(other, Distribution):
        # Multiply by each element of the distribution, flattening any
        # distribution-valued products into a single Distribution
        distribution = {}
        for yrChild, yrProb in other.items():
            product = self._multiply(yrChild, comparisons, conditions)
            if isinstance(product, Distribution):
                for value, myProb in product.items():
                    try:
                        distribution[value] += myProb*yrProb
                    except KeyError:
                        distribution[value] = myProb*yrProb
            else:
                try:
                    distribution[product] += yrProb
                except KeyError:
                    distribution[product] = yrProb
        return Distribution(distribution)
    elif isinstance(other, KeyedVector):
        return self[other]*other
    elif other.isProbabilistic():
        result = self.__class__()
        distribution = {}
        for yrChild, yrProb in other.split.items():
            product = self._multiply(yrChild, comparisons, conditions)
            try:
                distribution[product] += yrProb
            except KeyError:
                distribution[product] = yrProb
        result.branch(Distribution(distribution))
        return result
    else:
        return KeyedTree._multiply(self, other, comparisons, conditions)
388
391
393 if self.isProbabilistic():
394 doc = Document()
395 root = doc.createElement('tree')
396 doc.appendChild(root)
397 root.setAttribute('type','probabilistic')
398 root.appendChild(self.split.__xml__().documentElement)
399 return doc
400 else:
401 return KeyedTree.__xml__(self)
402
def parse(self, element, valueClass=None, debug=False):
    """Extracts the tree from the given XML element
    @param element: The XML Element object specifying the tree
    @type element: Element
    @param valueClass: The class used to generate the leaf values (defaults
    to L{KeyedMatrix}; only used for non-probabilistic nodes)
    @param debug: forwarded to C{KeyedTree.parse}
    @return: the L{ProbabilityTree} instance (C{self}, filled in place)"""
    if not valueClass:
        valueClass = KeyedMatrix
    if element.getAttribute('type') == 'probabilistic':
        # The distribution over subtrees is stored as a child element.
        # Skip any leading whitespace/comment nodes, which are present
        # whenever the XML has been pretty-printed or edited by hand.
        child = element.firstChild
        while child is not None and child.nodeType != child.ELEMENT_NODE:
            child = child.nextSibling
        if child is None:
            # No element child at all: keep the previous behavior
            child = element.firstChild
        split = Distribution()
        split.parse(child, ProbabilityTree)
        self.branch(split)
    else:
        KeyedTree.parse(self, element, valueClass, debug)
    return self
420
422 """Shorthand for constructing a decision tree with a single branch in it
423 @param plane: the plane to branch on
424 @type plane: L{Hyperplane}
425 @param falseTree: the tree that will be followed if the plane tests C{False}
426 @param trueTree: the tree that will be followed if the plane tests C{True}
427 @type falseTree,trueTree: L{ProbabilityTree}
428 @note: Will not prune tree
429 """
430 tree = ProbabilityTree()
431 tree.branch(plane,falseTree,trueTree,pruneF=False,pruneT=False)
432 return tree
433
439
441 """Shorthand for constructing a decision tree that branches on
442 whether the value lies on the plane or not, with the former/latter
443 cases leading down to the given equalTree/unequalTree"""
444 subPlane = copy.copy(plane)
445 subPlane.threshold -= 2.*epsilon
446 subTree = createBranchTree(subPlane,unequalTree,equalTree)
447 tree = createBranchTree(plane,subTree,unequalTree)
448 return tree
449
451 """Shorthand for constructing a leaf node with a dynamics matrix
452 for the given key with the specified weights (either KeyedVector, or
453 just plain old dictionary, for the lazy)"""
454 if isinstance(feature,Key):
455 key = feature
456 else:
457 key = makeStateKey('self',feature)
458 if isinstance(weights,KeyedVector):
459 matrix = KeyedMatrix({key:weights})
460 else:
461 matrix = KeyedMatrix({key:KeyedVector(weights)})
462 return createNodeTree(matrix)
463
465 """
466 To create a tree that follows the C{True} branch iff both the actor has accepted and the negotiation is not terminated:
467
468 >>> tree = createANDTree([(StateKey({'entity':'actor','feature':'accepted'}),True), (StateKey({'entity':'self','feature':'terminated'}),False)], falseTree, trueTree)
469
470 @note: the default truth value of the plane is C{True} (i.e., if no keys are provided, then C{trueTree} is returned
471 @param keyWeights: a list of tuples, C{(key,True/False)}, of the preconditions for the test to be true
472 @type keyWeights: (L{Key},boolean)[]
473 @param falseTree: the tree to invoke if the conjunction evaluates to C{False}
474 @param trueTree: the tree to invoke if the conjunction evaluates to C{True}
475 @type falseTree,trueTree: L{DecisionTree}
476 @return: the new tree with the conjunction test at the root
477 @rtype: L{ProbabilityTree}
478 """
479 if len(keyWeights) == 0:
480 return trueTree
481 weights = {}
482 length = float(len(keyWeights))
483 for key,truth in keyWeights:
484 if truth:
485 weights[key] = 1./length
486 else:
487 weights[key] = -1./length
488 try:
489 weights[keyConstant] += 1./length
490 except KeyError:
491 weights[keyConstant] = 1./length
492 weights = ANDRow(args=weights,keys=map(lambda t:t[0],keyWeights))
493 plane = KeyedPlane(weights,1.-1/(2.*length))
494 return createBranchTree(plane,falseTree,trueTree)
495
497 """
498 To create a tree that follows the C{True} branch iff either the actor has accepted or the negotiation is not terminated:
499
500 >>> tree = createORTree([(StateKey({'entity':'actor','feature':'accepted'}),True), (StateKey({'entity':'self','feature':'terminated'}),False)], falseTree, trueTree)
501
502 @note: the default truth value of the plane is C{False} (i.e., if no keys are provided, then C{falseTree} is returned
503 @param keyWeights: a list of tuples, C{(key,True/False)}, of the preconditions for the test to be true
504 @type keyWeights: (L{Key},boolean)[]
505 @param falseTree: the tree to invoke if the conjunction evaluates to C{False}
506 @param trueTree: the tree to invoke if the conjunction evaluates to C{True}
507 @type falseTree,trueTree: L{DecisionTree}
508 @return: the new tree with the conjunction test at the root
509 @rtype: L{ProbabilityTree}
510 """
511 if len(keyWeights) == 0:
512 return falseTree
513 weights = ORRow(keys=map(lambda t:t[0],keyWeights))
514 length = float(len(keyWeights))
515 for key,truth in keyWeights:
516 if truth:
517 weights[key] = 1./length
518 else:
519 weights[key] = -1./length
520 try:
521 weights[keyConstant] += 1./length
522 except KeyError:
523 weights[keyConstant] = 1./length
524 plane = KeyedPlane(weights,1/(2.*length))
525 return createBranchTree(plane,falseTree,trueTree)
526
528 """Creates a decision tree that will leave the given feature unchanged
529 @param feature: the state feature whose dynamics we are creating
530 @type feature: C{str}/L{Key}
531 @rtype: L{ProbabilityTree}
532 """
533 return ProbabilityTree(IdentityMatrix(feature))
534
if __name__ == '__main__':
    # Ad-hoc check: parse a tree from an XML file and print it in readable
    # form. The file may be given as the first command-line argument;
    # otherwise the original hard-coded path is used.
    import sys
    if len(sys.argv) > 1:
        path = sys.argv[1]
    else:
        path = '/tmp/pynadath/tree.xml'
    with open(path) as f:
        data = f.read()
    doc = parseString(data)
    tree = ProbabilityTree()
    tree.parse(doc.documentElement)
    print(tree.simpleText())
543
544
545
546
547
548
549
550
551
552
553
554
555
556