1 """Base classes for piecewise linearity
2 @var __CONSTANT__: flag indicating whether a constant factor should be included in each vector
3 @type __CONSTANT__: boolean
4 @var epsilon: margin of error used in comparison
5 @type epsilon: float
6 """
7 import copy
8
9
10
11
12 from types import *
13 from rules import pruneRules
14 from xml.dom.minidom import *
15 from id3 import gain
16 from dtree import create_decision_tree
17 import time
18
19 __CONSTANT__ = 1
20
21 epsilon = 0.00001
22
24 """A structure to represent linear separations on an I{n}-dimensional space
25 @ivar weights: the slope of this plane
26 @type weights: L{KeyedVector}
27 @ivar threshold: the offset of this plane
28 @type threshold: float
29 @ivar relation: the relation against this plane. Default is >, alternatives are: =.
30 @type relation: str
31 """
32
def __init__(self, weights, threshold, relation=None):
    """Constructs a hyperplane weights*x == threshold

    @param weights: the slope of the hyperplane
    @type weights: list or array
    @param threshold: the intercept of this hyperplane
    @type threshold: float
    @param relation: the relation tested against this plane; C{None} is handled like C{'>'} by L{test}, and C{'='} is the other value L{test} accepts
    @type relation: str
    @raise TypeError: if a list of weights cannot be converted by C{array}
    """
    self._string = None
    if isinstance(weights, list):
        # Lists (including list subclasses) are converted into an array;
        # any other kind of object is stored exactly as given
        try:
            self.weights = array(weights)
        except TypeError:
            print('Weights: %s' % (weights,))
            # A bare raise re-raises the same TypeError with its traceback
            raise
    else:
        self.weights = weights
    self.threshold = threshold
    self.relation = relation
50
52 """Return the slope of this hyperplane"""
53 if __CONSTANT__:
54 return self.weights[:len(self.weights)-1]
55 else:
56 return self.weights
57
59 if __CONSTANT__:
60 return self.weights[len(self.weights)-1]
61 else:
62 return 0.
63
def test(self, value):
    """Checks the passed-in value (in array form) against this hyperplane

    With no relation (or C{'>'}), the result is
    C{self.weights*value > self.threshold}.  With C{'='}, the result is
    whether the two sides differ by less than L{epsilon}.
    @rtype: boolean
    @raise UserWarning: if the relation is none of C{None}, C{'>'}, C{'='}
    """
    projection = dot(self.weights, value)
    relation = self.relation
    if relation is None or relation == '>':
        return projection > self.threshold
    if relation == '=':
        return abs(projection - self.threshold) < epsilon
    raise UserWarning('Unknown hyperplane test: %s' % (relation))
75
77 """
78 @return:
79 - True: iff this plane eliminates none of the state space (i.e., for all q, w*q > theta).
80 - False: iff this plane eliminates all of the state space (i.e., for all q, w*q <= theta).
81 - None: otherwise
82 @rtype: boolean
83 @warning: This has not yet been implemented for this class"""
84 raise NotImplementedError
85
87 """Modified version of __cmp__ method
88 @return:
89 - 'less': self < other
90 - 'greater': self > other
91 - 'equal': self == other
92 - 'indeterminate': none of the above
93 @rtype: str
94 """
95 if self == other:
96 return 'equal'
97 elif self < other:
98 return 'less'
99 elif self > other:
100 return 'greater'
101 else:
102 return 'indeterminate'
103
105 return '%s ? %5.3f' % (str(self.weights),self.threshold)
106
107
108
109
110
111
113 return self.__class__(self.weights*-1.,self.threshold)
114
116 """
117 Creates a plane exactly opposite to this one. In other words, for all C{x}, C{self.test(x)} implies C{not self.inverse().test(x)}
118 @rtype: L{Hyperplane}
119 """
120 return self.__class__(-self.weights,-self.threshold)
121
122
123
124
125
126
127
141
155
157 return self.compare(other) == 'equal'
158
159
160
162 return self.weights[index]
163
166
169
171 weights = copy.deepcopy(self.weights,memo)
172 memo[id(self.weights)] = weights
173 return self.__class__(weights,self.threshold,self.relation)
174
176 doc = Document()
177 root = doc.createElement('plane')
178 doc.appendChild(root)
179 root.setAttribute('threshold',str(self.threshold))
180 if self.relation:
181 root.setAttribute('relation',self.relation)
182 root.appendChild(self.weights.__xml__().documentElement)
183 return doc
184
def parse(self, element):
    """Loads the threshold, relation, and weights of this plane from XML

    @param element: The XML Element object specifying the plane
    @type element: Element
    @return: this L{Hyperplane} instance, updated in place
    """
    self.threshold = float(element.getAttribute('threshold'))
    # A missing or empty relation attribute is stored as None
    self.relation = str(element.getAttribute('relation')) or None
    vector = element.getElementsByTagName('vector')[0]
    self.weights = self.weights.parse(vector)
    if self.weights.__class__.__name__ == 'KeyedVector':
        # Keyed vectors are parsed a second time after the node has been
        # tagged with type 'Equal'
        vector.setAttribute('type', 'Equal')
        self.weights = self.weights.parse(vector)
    return self
201
203 """Represents a decision tree with hyperplane branches that divide
204 an n-dimensional space, and unrestricted values stored at the leaf
205 nodes (e.g., matrices for dynamics, actions for policies, etc.)
206 @cvar planeClass: the class used to instantiate the branches
207 @cvar checkTautology: flag that, if C{True}, activates check for hyperplanes that are either always C{True} or always C{False} in L{branch}. This can lead to smaller trees, but decreases efficiency
208 @type checkTautology: C{boolean}
209 @cvar checkPrune: flag that, if C{True}, activates the L{prune} method. This will lead to much smaller trees, but increases the overhead required to check for pruneability
210 @type checkPrune: C{boolean}
211 """
212 planeClass = Hyperplane
213 checkTautology = False
214 checkPrune = True
215
217 """Creates a DecisionTree
218 @param value: the optional leaf node value"""
219 self.parent = None
220 self.stats = {}
221 self.makeLeaf(value)
222
232
234 """
235 @return: the value of this tree
236 - If a leaf node, as a single object
237 - If a branch, as a tuple (falseTree,trueTree)"""
238 if self.isLeaf():
239 return self.falseTree
240 else:
241 return (self.falseTree,self.trueTree)
242
244 """
245 @return: C{True} iff this tree is a leaf node
246 @rtype: boolean"""
247 if len(self.split) > 0:
248 return False
249 else:
250 return True
251
253 """
254 @return: all subtrees rooted at this node
255 @rtype: list of L{DecisionTree} instances"""
256 if self.isLeaf():
257 return []
258 else:
259 falseTree,trueTree = self.getValue()
260 return [trueTree,falseTree]
261
263 """
264 @return: list of all leaf values (not necessarily unique) from L to R
265 @note: the leaf value is the result of calling L{getValue}, not an actual L{DecisionTree} instance
266 """
267 if self.isLeaf():
268 return [self.getValue()]
269 else:
270 leaves = []
271 for child in self.children():
272 leaves += child.leaves()
273 return leaves
274
276 """
277 @return: list of all leaf nodes (not necessarily unique) from L to R
278 @rtype: L{DecisionTree}[]
279 """
280 if self.isLeaf():
281 return [self]
282 else:
283 leaves = []
284 for child in self.children():
285 leaves += child.leafNodes()
286 return leaves
287
289 """
290 @return: the maximum distance between this node and the leaf nodes of the tree rooted at this node (a leaf node has a depth of 0, a branch node with two leaf nodes as children has a depth of 1, etc.)
291 @rtype: int
292 """
293 if self.isLeaf():
294 return 0
295 else:
296 return 1+max(map(lambda c:c.depth(),self.children()))
297
299 """
300 @return: all branches (not necessarily unique)
301 @rtype: intS{->}L{Hyperplane}
302 """
303 if result is None:
304 result = {}
305 if not self.isLeaf():
306 if isinstance(self.split,list):
307 for plane in self.split:
308 result[id(plane)] = plane
309 else:
310 assert(not isinstance(self.split,Hyperplane))
311 result[id(self.split)] = self.split
312 for child in self.children():
313 result.update(child.branches(result))
314 return result
315
def branch(self, plane, falseTree, trueTree, pruneF=True, pruneT=True, debug=False):
    """Marks this tree as a deterministic branching node

    @param plane: the branchpoint(s) separating the C{False} and C{True} subtrees
    @type plane: L{Hyperplane} or L{Hyperplane}[]
    @param falseTree: the C{False} subtree
    @type falseTree: L{DecisionTree} instance
    @param trueTree: the C{True} subtree
    @type trueTree: L{DecisionTree} instance
    @param pruneF: if true, will L{prune} the C{False} subtree
    @type pruneF: bool
    @param pruneT: if true, will L{prune} the C{True} subtree
    @type pruneT: bool
    @param debug: if C{True}, some debugging statements will be written to stdout (default is C{False})
    @type debug: bool
    @note: setting either prune flag to false will save time (though may lead to more inefficient trees)
    @note: a list passed as C{plane} is stored by reference; when L{checkTautology} is on, planes that are always true are removed from that same list
    """
    self.branchType = 'deterministic'
    if not isinstance(plane, list):
        self.split = [plane]
    else:
        self.split = plane
    if self.checkTautology:
        # Drop planes that always hold; stop at the first that never holds
        blocked = None
        for candidate in list(self.split):
            verdict = candidate.always()
            if verdict == False:
                blocked = verdict
                break
            elif verdict == True:
                self.split.remove(candidate)
        # Decide whether the whole test collapses onto a single subtree
        collapse = True
        survivor = None
        if len(self.split) == 0:
            # Every plane always holds, so only the True subtree is reachable
            survivor = trueTree
        elif blocked == False:
            # Some plane never holds, so only the False subtree is reachable
            survivor = falseTree
        else:
            collapse = False
        if collapse:
            if not isinstance(survivor, DecisionTree):
                self.makeLeaf(survivor)
            elif survivor.isLeaf():
                self.makeLeaf(survivor.getValue())
            else:
                newFalse, newTrue = survivor.getValue()
                self.branch(survivor.split, newFalse, newTrue,
                            pruneF=False, pruneT=False, debug=debug)
            return
    # Raw values are wrapped in fresh leaf nodes of this tree's class
    if not isinstance(falseTree, DecisionTree):
        wrapped = self.__class__()
        wrapped.makeLeaf(falseTree)
        falseTree = wrapped
    self.falseTree = falseTree
    if not isinstance(trueTree, DecisionTree):
        wrapped = self.__class__()
        wrapped.makeLeaf(trueTree)
        trueTree = wrapped
    self.trueTree = trueTree
    # Each child records its parent node and the side it hangs from
    self.falseTree.parent = (self, False)
    self.trueTree.parent = (self, True)
    if pruneF:
        self.falseTree.prune(debug=debug)
    if pruneT:
        self.trueTree.prune(debug=debug)
390
392 """
393 @return: the conditions under which this node will be reached, as a list of C{(plane,True/False)} tuples
394 @rtype: (L{Hyperplane},boolean)[]
395 """
396 if self.parent:
397 parent,side = self.parent
398 return [(parent.split,side)] + parent.getPath()
399 else:
400 return []
401
411
413 if not self.isLeaf():
414 truth = True
415 for plane in self.split:
416 truth = plane.always(negative)
417 if isinstance(truth,bool):
418 break
419 if truth is True:
420
421 fTree,tTree = self.getValue()
422 tTree.removeTautologies()
423 if tTree.isLeaf():
424 self.makeLeaf(tTree.getValue())
425 else:
426 fNew,tNew = tTree.getValue()
427 self.branch(tTree.split,fNew,tNew,
428 pruneF=False,pruneT=False)
429 elif truth is False:
430
431 fTree,tTree = self.getValue()
432 fTree.removeTautologies()
433 if fTree.isLeaf():
434 self.makeLeaf(fTree.getValue())
435 else:
436 fNew,tNew = fTree.getValue()
437 self.branch(fTree.split,fNew,tNew,
438 pruneF=False,pruneT=False)
439
# prune: simplifies the subtree rooted at this node using the tests already
# made by its ancestors.  The planes of this node are reduced with
# comparePlaneSets against each ancestor's planes in turn.
# - comparisons: cache of plane comparisons handed on to comparePlaneSets
#   (a fresh dict is created when omitted)
# - debug: if true, progress is printed to stdout
# - negative: passed through to comparePlaneSets
# NOTE(review): indentation was lost in this listing, so the nesting described
# in the comments below (in particular that the final "else" belongs to the
# while loop) is a reading of the statement order -- confirm against the
# original file.
440 - def prune(self,comparisons=None,debug=False,negative=True):
# Pruning can be switched off through the checkPrune flag
441 if not self.checkPrune:
442 return
443 if comparisons is None:
444 comparisons = {}
# Leaf nodes have no planes, so only branch nodes are processed
445 if not self.isLeaf():
446 ancestor = self.parent
447 split = self.split
448 if debug:
449 print
450 print 'Current:'
451 print ' and '.join(map(lambda p:p.simpleText(),split))
452 print len(self.leaves())
# Walk up the tree; each parent entry is a (tree, side) tuple
453 while ancestor:
# parent/side are re-read from self.parent on every pass, i.e. they always
# describe the immediate parent, while tree/direction describe the ancestor
# currently being compared against
454 parent,side = self.parent
455 tree,direction = ancestor
456 if debug:
457 print 'Ancestor:',len(tree.split)
458
459 print 'Side:',direction
460 split = comparePlaneSets(split,tree.split,direction,comparisons,debug,negative)
461 if debug:
462 print 'Result:',
463 if isinstance(split,bool):
464 print split
465 else:
466 print ' and '.join(map(lambda p:p.simpleText(),split))
# A boolean result means the ancestors already decide this node's test, so
# the immediate parent is re-branched with one of this node's subtrees
# taking this node's place
467 if isinstance(split,bool):
468 oldFalse,oldTrue = parent.getValue()
469 newFalse,newTrue = self.getValue()
470 if split:
471
# Test decided True: this node's True subtree is promoted
472 oldFalse,oldTrue = parent.getValue()
473 newFalse,newTrue = self.getValue()
474 if side:
475 parent.branch(parent.split,oldFalse,newTrue,
476 pruneF=False,pruneT=True,debug=debug)
477 else:
478 parent.branch(parent.split,newTrue,oldTrue,
479 pruneF=True,pruneT=False,debug=debug)
480 else:
481
# Test decided False: this node's False subtree is promoted
482 if side:
483 parent.branch(parent.split,oldFalse,newFalse,
484 pruneF=False,pruneT=True,debug=debug)
485 else:
486 parent.branch(parent.split,newFalse,oldTrue,
487 pruneF=True,pruneT=False,debug=debug)
488 break
489 ancestor = tree.parent
# Presumably the else of the while loop: reached only when no ancestor
# decided the test outright
490 else:
491 self.split = split
# NOTE(review): the recursive calls do not pass "negative", so the subtrees
# are pruned with its default value
492 self.falseTree.prune(comparisons,debug)
493 self.trueTree.prune(comparisons,debug)
494
# If both subtrees compare equal after pruning, this branch is unnecessary
495 falseTree,trueTree = self.getValue()
496 if falseTree == trueTree:
497 if debug:
498 print 'Equal subtrees:',falseTree,trueTree
499 if falseTree.isLeaf():
# NOTE(review): this passes the leaf tree itself to makeLeaf, whereas branch()
# passes tree.getValue() in the same situation -- confirm which is intended
500 self.makeLeaf(falseTree)
501 else:
502 newFalse,newTrue = falseTree.getValue()
503 self.branch(falseTree.split,newFalse,newTrue,
504 pruneF=False,pruneT=False,debug=debug)
505
507 """
508 @return: a dictionary of statistics about the decision tree rooted at this node:
509 - I{leaf}: # of leaves
510 - I{branch}: # of branch nodes
511 - I{depth}: depth of tree"""
512 if self.isLeaf():
513 self.stats['leaf'] = 1
514 self.stats['branch'] = 0
515 self.stats['depth'] = 0
516 return self.stats
517 else:
518 self.stats['leaf'] = 0
519 self.stats['branch'] = 1
520 self.stats['depth'] = 1
521 depth = 0
522 for tree in self.children():
523 subCount = tree.count()
524 for key in ['leaf','branch']:
525 self.stats[key] += subCount[key]
526 if subCount['depth'] > depth:
527 depth = subCount['depth']
528 self.stats['depth'] += depth
529 return self.stats
530
532 """
533 Uses ID3 heuristic to reorder branches
534 @return: C{True}, iff a rebalancing was applied at this level
535 """
536 target = '_value'
537 attributes = {target:True}
538 values = {}
539 data = self.makeRules(attributes,values)
540 new = create_decision_tree(data,attributes.keys(),target,gain)
541 self._extractTree(new,attributes,values)
542 return self
543
545 """
546 Extracts the rules from the given L{dtree} structure into this tree
547 """
548 if type(tree) == dict:
549 plane = attributes[tree.keys()[0]]
550 trueTree = None
551 falseTree = None
552 for item in tree.values()[0].keys():
553 if item == True:
554 trueTree = self.__class__()
555 trueTree._extractTree(tree.values()[0][item],
556 attributes,values)
557 elif item == False:
558 falseTree = self.__class__()
559 falseTree._extractTree(tree.values()[0][item],
560 attributes,values)
561 else:
562 raise UserWarning,'Unknown attribute value: %s' % \
563 (str(item))
564 if trueTree is None:
565 if falseTree is None:
566 raise UserWarning,'Null decision tree returned'
567 else:
568 self.makeLeaf(falseTree)
569 elif falseTree is None:
570 self.makeLeaf(trueTree)
571 else:
572 if falseTree == trueTree:
573 raise UserWarning
574 self.branch(plane,falseTree,trueTree,
575 pruneF=False,pruneT=False)
576 else:
577 self.makeLeaf(values[tree])
578 return self
579
# makeRules: flattens the subtree rooted at this node into a list of rule
# dictionaries.  Each rule maps '_value' to the string label of a leaf value
# and maps plane labels (plane.simpleText()) to the side taken on that plane.
# - attributes: filled in place with label -> plane (starts as {'_value':True})
# - values: filled in place with leaf label -> leaf value
# - conditions: list of (plane label, side) pairs accumulated on the way down
# - comparisons: cache of plane comparisons handed on to comparePlaneSets
# NOTE(review): indentation was lost in this listing; the for/else pairings
# described in the comments below are a reading of the statement order --
# confirm against the original file.
580 - def makeRules(self,attributes=None,values=None,conditions=None,
581 debug=False,comparisons=None):
582 """Represents this tree as a list of rules
583 @return: dict[]
584 """
585 if comparisons is None:
586 comparisons = {}
587 rules = []
588 if attributes is None:
589 attributes = {'_value':True}
590 if values is None:
591 values = {}
592 if conditions is None:
593 conditions = []
594 if self.isLeaf():
# A leaf yields a single rule: its label plus every condition on the path
595 label = str(self.getValue())
596 rule = {'_value':label}
597 values[label] = self.getValue()
598 for plane,side in conditions:
599 rule[plane] = side
600 rules.append(rule)
601 else:
602 falseTree,trueTree = self.getValue()
# Register each plane of this node under its text label
603 newConditions = {}
604 for plane in self.split:
605 label = plane.simpleText()
606 attributes[label] = plane
607 newConditions[label] = plane
608
# False side: each plane is considered on its own and checked against the
# conditions already on the path
609 for plane in newConditions.keys():
610 split = [newConditions[plane]]
611 for oldPlane,side in conditions:
612 split = comparePlaneSets(split,[attributes[oldPlane]],
613 side,comparisons)
614 if isinstance(split,bool):
615 if split:
616
# Plane already known to hold: no rules are generated for it being false
617 break
618 else:
619
# Plane already known to fail: recurse without adding a new condition
620 rules += falseTree.makeRules(attributes,values,
621 conditions,
622 debug,comparisons)
623 break
# Presumably the else of the inner for loop: the plane is undecided, so it is
# added as a False condition
624 else:
625
626 rules += falseTree.makeRules(attributes,values,
627 conditions+[(plane,False)],
628 debug,comparisons)
629
# True side: all planes of this node are considered together
630 split = newConditions.values()
631 for oldPlane,side in conditions:
632 split = comparePlaneSets(split,[attributes[oldPlane]],
633 side,comparisons)
634 if isinstance(split,bool):
635 if split:
636
# Conjunction already known to hold: recurse without adding conditions
637 rules += trueTree.makeRules(attributes,values,
638 conditions,
639 debug,comparisons)
640 break
641 else:
642
# Conjunction already known to fail: the True subtree contributes no rules
643 break
# Presumably the else of the for loop: undecided, so every plane of this node
# is added as a True condition
644 else:
645
646 rules += trueTree.makeRules(attributes,values,
647 conditions+map(lambda p:(p,True),
648 newConditions.keys()),
649 debug,comparisons)
650
651
# Only a node without a parent post-processes the collected rules
652 if not self.parent:
653 for rule in rules:
654 for attr in attributes.keys():
655 if not rule.has_key(attr):
656
# Attributes that a rule never tested are recorded as None
657 rule[attr] = None
658 if debug:
659 print '\t\tPruning %s rules' % (len(rules))
# NOTE(review): whether pruneRules runs only for the root or for every node
# cannot be told from this flattened listing
660 rules,attributes = pruneRules(rules,attributes,values,debug)
661 return rules
662
663 - def fromRules(self,rules,attributes,values,comparisons=None):
704
706 if not test:
707 test = lambda x,y: x != y
708 if self.isLeaf():
709 myValue = self.getValue()
710 if test(myValue,value):
711 return [{'plane':None,'truth':1,'value':myValue}]
712 else:
713
714 return []
715 else:
716 falseTree,trueTree = self.getValue()
717 if reduce(lambda x,y:x and y,
718 map(lambda p:p.test(index),self.split)):
719
720 alternatives = trueTree.generateAlternatives(index,value)
721 myValue = falseTree[index]
722 for action in myValue:
723 if test(action,value):
724
725 for plane in self.split:
726 if plane.test(index):
727 alternatives.append({'plane':plane,
728 'truth':False,
729 'value':myValue})
730 break
731 else:
732
733 alternatives = falseTree.generateAlternatives(index,value)
734 myValue = trueTree[index]
735 for action in myValue:
736 if test(action,value):
737
738 for plane in self.split:
739 if not plane.test(index):
740 alternatives.append({'plane':plane,
741 'truth':True,
742 'value':myValue})
743 break
744 return alternatives
745
747 if type(index) is IntType:
748
749 if not self.stats.has_key('index'):
750 self.createIndex()
751 if self.isLeaf():
752 if self.stats['index'] == index:
753 return self
754 else:
755 raise IndexError,index
756 else:
757 falseTree,trueTree = self.getValue()
758 if index < falseTree.stats['index'] + falseTree.stats['leaf']:
759 return falseTree[index]
760 else:
761 return trueTree[index]
762 else:
763
764 if self.isLeaf():
765 return self.getValue()
766 else:
767
768 if reduce(lambda x,y:x and y,
769 map(lambda p:p.test(index),self.split)):
770 return self.trueTree[index]
771 else:
772 return self.falseTree[index]
773
def replace(self, orig, new, comparisons=None, conditions=None):
    """Replaces any leaf nodes that match the given original value
    with the provided new value

    @param orig: leaf value to be replaced
    @param new: tree with which to replace matching leaves
    @type new: L{DecisionTree} instance
    @param comparisons: cache of plane comparisons handed to L{comparePlaneSets} (a fresh dictionary is used if omitted)
    @param conditions: the C{(planes,True/False)} tests already passed on the way to this node (empty if omitted)
    @raise NotImplementedError: if C{new} is not a L{DecisionTree}
    @warning: the replacement modifies this tree in place"""
    if not isinstance(new, DecisionTree):
        raise NotImplementedError('Currently unable to replace leaf nodes with non-tree objects')
    if comparisons is None:
        comparisons = {}
    if conditions is None:
        # A fresh list per call instead of a shared mutable default
        conditions = []
    if self.isLeaf():
        value = self.getValue()
        if isinstance(value, orig.__class__) and value == orig:
            if new.isLeaf():
                self.makeLeaf(new.getValue())
            else:
                falseTree, trueTree = new.getValue()
                # Reduce the new tree's planes given the tests already passed
                split = new.split
                for plane, truth in conditions:
                    split = comparePlaneSets(split, plane, truth, comparisons)
                    if isinstance(split, bool):
                        if split:
                            # Test already known to hold: use the True subtree
                            return self.replace(orig, trueTree, comparisons, conditions)
                        else:
                            # Test already known to fail: use the False subtree
                            return self.replace(orig, falseTree, comparisons, conditions)
                # Test still undecided: branch here and expand both sides
                newFalse = self.__class__()
                newFalse.makeLeaf(orig)
                newFalse.replace(orig, falseTree, comparisons, conditions)
                newTrue = self.__class__()
                newTrue.makeLeaf(orig)
                newTrue.replace(orig, trueTree, comparisons, conditions)
                self.branch(split, newFalse, newTrue, pruneF=False, pruneT=False)
    else:
        # Recurse into both subtrees, extending the path conditions
        falseTree, trueTree = self.getValue()
        falseTree.replace(orig, new, comparisons, conditions + [(self.split, False)])
        trueTree.replace(orig, new, comparisons, conditions + [(self.split, True)])
        self.branch(self.split, falseTree, trueTree, pruneF=False, pruneT=False)
816
def merge(self, other, op):
    """Combines this tree with another one, leaf by leaf, and prunes the outcome

    @param other: the other tree to merge with
    @type other: L{DecisionTree} instance
    @param op: the operator used to generate the new leaf values, C{lambda x,y:f(x,y)} where C{x} and C{y} are leaf values
    @rtype: a new L{DecisionTree} instance"""
    merged = self._merge(other, op)
    merged.prune()
    return merged
826
def _merge(self, other, op, comparisons=None, conditions=None):
    """Helper method that merges the two trees together using the given operator to combine leaf values, without pruning

    @param other: the other tree to merge with (a non-tree object is combined directly with each leaf value)
    @type other: L{DecisionTree} instance
    @param op: the operator used to generate the new leaf values, C{lambda x,y:f(x,y)} where C{x} and C{y} are leaf values
    @param comparisons: cache of plane comparisons handed to L{comparePlaneSets} (a fresh dictionary is used if omitted)
    @param conditions: the C{(planes,True/False)} tests already passed on the way to this node (empty if omitted)
    @rtype: a new L{DecisionTree} instance"""
    if comparisons is None:
        comparisons = {}
    if conditions is None:
        # A fresh list per call instead of a shared mutable default
        conditions = []
    result = self.__class__()
    if not self.isLeaf():
        # Descend through this tree first, recording the path conditions
        falseTree, trueTree = self.getValue()
        falseTree = falseTree._merge(other, op, comparisons, conditions + [(self.split, False)])
        trueTree = trueTree._merge(other, op, comparisons, conditions + [(self.split, True)])
        result.branch(self.split, falseTree, trueTree, pruneF=False, pruneT=False)
    elif isinstance(other, DecisionTree):
        if other.isLeaf():
            result.makeLeaf(op(self.getValue(), other.getValue()))
        else:
            falseTree, trueTree = other.getValue()
            # Reduce the other tree's planes given the tests already passed
            split = other.split
            for plane, truth in conditions:
                split = comparePlaneSets(split, plane, truth, comparisons)
                if isinstance(split, bool):
                    if split:
                        # Test already known to hold: only the True subtree matters
                        return self._merge(trueTree, op, comparisons, conditions)
                    else:
                        # Test already known to fail: only the False subtree matters
                        return self._merge(falseTree, op, comparisons, conditions)
            # Test still undecided: branch on the reduced planes
            newFalse = self._merge(falseTree, op, comparisons, conditions)
            newTrue = self._merge(trueTree, op, comparisons, conditions)
            result.branch(split, newFalse, newTrue,
                          pruneF=False, pruneT=False)
    else:
        # The other operand is a plain value rather than a tree
        result.makeLeaf(op(self.getValue(), other))
    return result
865
867 return self.merge(other,lambda x,y:x+y)
868
870 result = self._multiply(other)
871 result.prune()
872 return result
873
def _multiply(self, other, comparisons=None, conditions=None):
    """Helper method that multiplies this tree by another one, without pruning

    @param other: the tree on the right-hand side of the product
    @param comparisons: cache of plane comparisons handed to L{comparePlaneSets} (a fresh dictionary is used if omitted)
    @param conditions: the C{(planes,True/False)} tests already passed on the way to this node (empty if omitted)
    @return: a new tree of the same class as this one"""
    if comparisons is None:
        comparisons = {}
    if conditions is None:
        # A fresh list per call instead of a shared mutable default
        conditions = []
    result = self.__class__()
    if other.isLeaf():
        if self.isLeaf():
            result.makeLeaf(matrixmultiply(self.getValue(),
                                           other.getValue()))
        else:
            falseTree, trueTree = self.getValue()
            # Multiply the weights of every branch plane by the other leaf value
            new = []
            for original in self.split:
                weights = matrixmultiply(original.weights, other.getValue())
                # NOTE(review): the rebuilt plane is constructed without
                # passing original.relation -- confirm that is intended
                plane = original.__class__(weights, original.threshold)
                new.append(plane)
            result.branch(new,
                          falseTree._multiply(other, comparisons, conditions + [(new, False)]),
                          trueTree._multiply(other, comparisons, conditions + [(new, True)]),
                          pruneF=False, pruneT=False)
    else:
        falseTree, trueTree = other.getValue()
        # Reduce the other tree's planes given the tests already passed
        split = other.split
        for plane, truth in conditions:
            split = comparePlaneSets(split, plane, truth, comparisons)
            if isinstance(split, bool):
                if split:
                    # Test already known to hold: only the True subtree matters
                    return self._multiply(trueTree, comparisons, conditions)
                else:
                    # Test already known to fail: only the False subtree matters
                    return self._multiply(falseTree, comparisons, conditions)
        # Test still undecided: branch on the reduced planes
        newFalse = self._multiply(falseTree, comparisons, conditions)
        newTrue = self._multiply(trueTree, comparisons, conditions)
        result.branch(split, newFalse, newTrue, pruneF=False, pruneT=False)
    return result
910
912 return self + (-other)
913
915 result = self.__class__()
916 if self.isLeaf():
917 result.makeLeaf(-self.getValue())
918 else:
919 result.branch(self.split,-self.falseTree,-self.trueTree,
920 pruneF=False,pruneT=False)
921 return result
922
934
936 return hash(str(self))
937
940
950
952 doc = Document()
953 root = doc.createElement('tree')
954 doc.appendChild(root)
955 if self.isLeaf():
956 root.setAttribute('type','leaf')
957 value = self.getValue()
958 try:
959 root.appendChild(value.__xml__().documentElement)
960 except AttributeError:
961
962 root.appendChild(doc.createTextNode(str(value)))
963 else:
964 root.setAttribute('type','branch')
965 element = doc.createElement('split')
966 root.appendChild(element)
967 for plane in self.split:
968 element.appendChild(plane.__xml__().documentElement)
969 falseTree,trueTree = self.getValue()
970 element = doc.createElement('false')
971 root.appendChild(element)
972 element.appendChild(falseTree.__xml__().documentElement)
973 element = doc.createElement('true')
974 root.appendChild(element)
975 element.appendChild(trueTree.__xml__().documentElement)
976 return doc
977
def parse(self, element, valueClass=None, debug=False):
    """Rebuilds this tree from the given XML element

    @param element: The XML Element object specifying the tree
    @type element: Element
    @param valueClass: The class used to generate leaf values stored as child elements (C{float} if omitted); leaf values stored as text are kept as stripped strings, with the text C{'None'} becoming C{None}
    @param debug: passed on unchanged when parsing subtrees
    @return: this tree instance, updated in place"""
    if element.getAttribute('type') == 'leaf':
        if not valueClass:
            valueClass = float
        # The value stays None unless a child element or non-blank text is found
        value = None
        child = element.firstChild
        while child:
            if child.nodeType == child.ELEMENT_NODE:
                value = valueClass().parse(child)
                break
            if child.nodeType == child.TEXT_NODE:
                text = str(child.data).strip()
                if len(text) > 0:
                    if text != 'None':
                        value = text
                    break
            child = child.nextSibling
        self.makeLeaf(value)
        return self
    # Branch node: collect the planes and both subtrees from the children
    planes = []
    falseTree = self.__class__()
    trueTree = self.__class__()
    child = element.firstChild
    while child:
        if child.nodeType == child.ELEMENT_NODE:
            tag = child.tagName
            if tag == 'split':
                # Every child element of <split> describes one plane
                entry = child.firstChild
                while entry:
                    if entry.nodeType == entry.ELEMENT_NODE:
                        planes.append(self.planeClass({}, 0.).parse(entry))
                    entry = entry.nextSibling
            elif tag in ['false', 'left']:
                # The first child element holds the False subtree
                entry = child.firstChild
                while entry and entry.nodeType != child.ELEMENT_NODE:
                    entry = entry.nextSibling
                falseTree = falseTree.parse(entry, valueClass, debug)
            elif tag in ['true', 'right']:
                # The first child element holds the True subtree
                entry = child.firstChild
                while entry and entry.nodeType != child.ELEMENT_NODE:
                    entry = entry.nextSibling
                trueTree = trueTree.parse(entry, valueClass, debug)
        child = child.nextSibling
    self.branch(planes, falseTree, trueTree)
    return self
1035
1037 for datum in data:
1038 print '\n\t',
1039 for attr,val in datum.items():
1040 if attr != '_value':
1041 print '%5s' % (val),
1042 if values:
1043 print values.index(datum['_value']),pow(2,datum.values().count(None)),
1044 print
1045
1046
def comparePlaneSets(set1, set2, side, comparisons=None,
                     debug=False, negative=True):
    """
    Compares a conjunction of planes against a second conjunction of planes that has already been tested against. It prunes the current conjunction based on any redundancy or inconsistency with the test

    @param set1: the plane set to be pruned
    @param set2: the plane set already tested
    @type set1,set2: L{Hyperplane}[]
    @param side: the side of the second set that we're already guaranteed to be on
    @type side: boolean
    @param comparisons: optional cache of comparison results, indexed as C{comparisons[id(p2)][id(p1)]} for C{p2} in C{set2} and C{p1} in C{set1}; missing entries are computed with C{p2.compare(p1,negative)} and stored
    @param debug: accepted but not used by this function
    @param negative: if C{True}, then assume that weights may be negative (default is C{True})
    @return: The minimal set of planes in the first set that are not redundant given these a priori conditions (if guaranteed to be C{True} or C{False}, then the boolean value is returned)
    @rtype: L{Hyperplane}[]
    """
    useCache = isinstance(comparisons, dict)
    total = len(set2)
    # One entry per tested plane: None while unknown, otherwise the truth
    # value that plane is forced to have
    required = [None] * total
    confirmed = 0
    survivors = []
    for mine in set1:
        mineKey = id(mine)
        redundant = False
        for position, theirs in enumerate(set2):
            if useCache:
                theirKey = id(theirs)
                try:
                    relation = comparisons[theirKey][mineKey]
                except KeyError:
                    relation = theirs.compare(mine, negative)
                    comparisons.setdefault(theirKey, {})[mineKey] = relation
            else:
                relation = theirs.compare(mine, negative)
            if side:
                # The tested conjunction is known to hold
                if relation == 'equal' or relation == 'greater':
                    # This plane is implied by the test, so it can be dropped
                    redundant = True
                    break
                if relation == 'inverse':
                    # This plane contradicts the test
                    return False
            elif relation == 'equal' or relation == 'less':
                # This plane holding would force the tested plane to hold too
                if required[position] is False:
                    return False
                elif not required[position]:
                    required[position] = True
                    confirmed += 1
                    if confirmed == total:
                        # Every tested plane would hold, yet the test failed
                        return False
            elif relation == 'inverse':
                # This plane holding would force the tested plane to fail
                if required[position]:
                    return False
                required[position] = False
        if not redundant:
            survivors.append(mine)
    if len(survivors) == 0:
        return True
    return survivors
1140
1142 """Pre-computes a comparison matrix between two sets of planes
1143 @param set1,set2: the two sets of planes
1144 @type set1,set2: L{Hyperplane}[]
1145 @return: a pairwise matrix of comparisons, indexed by the C{id} of each plane, so that C{result[id(p1)][id(p2)] = p1.compare(p2)}
1146 @rtype: str{}{}
1147 """
1148 comparisons = {}
1149 for plane1 in set1:
1150 comparisons[id(plane1)] = {}
1151 for plane2 in set2:
1152 comparisons[id(plane1)][id(plane2)] = plane1.compare(plane2)
1153 return comparisons
1154
if __name__ == '__main__':
    # Ad hoc profiling harness: loads a pickled tree, pre-computes all
    # pairwise plane comparisons, and profiles a pruning pass over the tree
    from ProbabilityTree import *
    import pickle
    # NOTE(review): pickle.load can execute arbitrary code; only run this
    # against a trusted /tmp/tree.pickle
    # Pickles are binary data, and the file is closed even if loading fails
    f = open('/tmp/tree.pickle', 'rb')
    try:
        tree = pickle.load(f)
    finally:
        f.close()
    print('%s leaves' % (len(tree.leaves())))
    # Collect every distinct plane object used by a branch of the tree
    planes = {}
    nodes = [tree]
    while len(nodes) > 0:
        node = nodes.pop()
        if not node.isLeaf():
            for plane in node.split:
                planes[id(plane)] = plane
            nodes += node.children()
    print(len(planes))
    # Pre-compute the comparison cache and check that it is complete
    comparisons = generateComparisons(planes.values(), planes.values())
    for plane1 in planes.values():
        assert(id(plane1) in comparisons)
        for plane2 in planes.values():
            assert(id(plane2) in comparisons[id(plane1)])
    from teamwork.utils.Debugger import quickProfile
    quickProfile(tree.prune, (comparisons, False))
    print(len(tree.leaves()))
1181