Package teamwork :: Package math :: Module matrices
[hide private]
[frames | no frames]

Source Code for Module teamwork.math.matrices

   1  """Base classes for piecewise linearity 
   2  @var __CONSTANT__: flag indicating whether a constant factor should be included in each vector 
   3  @type __CONSTANT__: boolean 
   4  @var epsilon: margin of error used in comparison 
   5  @type epsilon: float 
   6  """ 
   7  import copy 
   8  ##try: 
   9  ##    from numarray.numarrayall import * 
  10  ##except ImportError: 
  11  ##    pass 
  12  from types import * 
  13  from rules import pruneRules 
  14  from xml.dom.minidom import * 
  15  from id3 import gain 
  16  from dtree import create_decision_tree 
  17  import time 
  18   
  19  __CONSTANT__ = 1 
  20   
  21  epsilon = 0.00001 
  22       
class Hyperplane:
    """A structure to represent linear separations on an I{n}-dimensional space

    @ivar weights: the slope of this plane
    @type weights: L{KeyedVector}
    @ivar threshold: the offset of this plane
    @type threshold: float
    @ivar relation: the relation against this plane.  Default is >,
        alternatives are: =.
    @type relation: str
    """

    def __init__(self,weights,threshold,relation=None):
        """Constructs a hyperplane weights*x == threshold
        @param weights: the slope of the hyperplane
        @type weights: list or array
        @param threshold: the intercept of this hyperplane
        @type threshold: float
        @param relation: the test applied by L{test} (C{None} or C{'>'} for
            strictly above, C{'='} for equality within L{epsilon})
        @type relation: str
        """
        self._string = None
        if isinstance(weights,list):
            # NOTE(review): ``array`` is not imported in the visible part of
            # this module (the numarray import is commented out) -- confirm
            # where it is expected to come from.
            try:
                self.weights = array(weights)
            except TypeError:
                print('Weights: %s' % (weights,))
                raise
        else:
            self.weights = weights
        self.threshold = threshold
        self.relation = relation

    def getWeights(self):
        """Return the slope of this hyperplane

        When C{__CONSTANT__} is set, the last entry of the weight vector is
        the constant factor and is excluded from the result.
        """
        if __CONSTANT__:
            return self.weights[:len(self.weights)-1]
        return self.weights

    def getConstant(self):
        """Return the constant factor of this hyperplane

        @return: the last weight when C{__CONSTANT__} is set, otherwise 0.
        """
        if __CONSTANT__:
            return self.weights[len(self.weights)-1]
        return 0.

    def test(self,value):
        """Returns true iff the passed in value (in array form) lies
        above this hyperplane (self.weights*value > self.threshold)
        @rtype: boolean
        @raise UserWarning: if L{relation} is neither C{None}, C{'>'} nor C{'='}
        """
        # NOTE(review): ``dot`` is not imported in the visible part of this
        # module -- confirm where it is expected to come from.
        total = dot(self.weights,value)
        if self.relation is None or self.relation == '>':
            return total > self.threshold
        if self.relation == '=':
            return abs(total - self.threshold) < epsilon
        raise UserWarning('Unknown hyperplane test: %s' % (self.relation))

    def always(self):
        """
        @return:
           - True: iff this plane eliminates none of the state space (i.e., for all q, w*q > theta).
           - False: iff this plane eliminates all of the state space (i.e., for all q, w*q <= theta).
           - None: otherwise
        @rtype: boolean
        @warning: This has not yet been implemented for this class"""
        raise NotImplementedError

    def compare(self,other):
        """Modified version of __cmp__ method
        @return:
           - 'less': self < other
           - 'greater': self > other
           - 'equal': self == other
           - 'indeterminate': none of the above
        @rtype: str
        """
        if self == other:
            return 'equal'
        if self < other:
            return 'less'
        if self > other:
            return 'greater'
        return 'indeterminate'

    def __str__(self):
        return '%s ? %5.3f' % (str(self.weights),self.threshold)

    def __neg__(self):
        """Return a plane with negated weights and the same threshold"""
        return self.__class__(self.weights*-1.,self.threshold)

    def inverse(self):
        """
        Creates a plane exactly opposite to this one.  In other words, for all C{x}, C{self.test[x]} implies C{not self.inverse().test[x]}
        @rtype: L{Hyperplane}
        """
        return self.__class__(-self.weights,-self.threshold)

    def _thresholdGap(self,other):
        """Difference of the constant-adjusted thresholds of self and other"""
        return (self.threshold - self.getConstant()) \
               - (other.threshold - other.getConstant())

    def __gt__(self,other):
        mine = self.getWeights()
        yours = other.getWeights()
        if sum(mine > yours) > 0:
            # One of our weights is greater than the other's
            return False
        diff = self._thresholdGap(other)
        if sum(mine < yours) > 0:
            # Some weight differs strictly, so allow a margin of error
            return diff > -epsilon
        # All weights identical, so the offsets must differ strictly
        return diff > 0.0

    def __lt__(self,other):
        mine = self.getWeights()
        yours = other.getWeights()
        if sum(mine < yours) > 0:
            # One of our weights is less than the other's
            return False
        diff = self._thresholdGap(other)
        if sum(mine > yours) > 0:
            # Some weight differs strictly, so allow a margin of error
            return diff < epsilon
        # All weights identical, so the offsets must differ strictly
        return diff < 0.0

    def __eq__(self,other):
        """Two planes are equal iff every weight matches and the thresholds
        agree to within L{epsilon}.

        This is computed directly: L{compare} itself evaluates
        C{self == other}, so delegating to it from here would never terminate.
        """
        if not isinstance(other,Hyperplane):
            return NotImplemented
        return sum(self.weights != other.weights) == 0 \
               and abs(self.threshold-other.threshold) < epsilon

    def __getitem__(self,index):
        return self.weights[index]

    def __setitem__(self,index,value):
        self.weights[index] = value

    def __copy__(self):
        return self.__class__(copy.copy(self.weights),self.threshold,
                              self.relation)

    def __deepcopy__(self,memo):
        weights = copy.deepcopy(self.weights,memo)
        memo[id(self.weights)] = weights
        return self.__class__(weights,self.threshold,self.relation)

    def __xml__(self):
        """
        @return: a document whose root C{plane} element carries the threshold
            (and relation, if any) as attributes and the weights as a child
        """
        doc = Document()
        root = doc.createElement('plane')
        doc.appendChild(root)
        root.setAttribute('threshold',str(self.threshold))
        if self.relation:
            root.setAttribute('relation',self.relation)
        root.appendChild(self.weights.__xml__().documentElement)
        return doc

    def parse(self,element):
        """Extracts the plane from the given XML element
        @param element: The XML Element object specifying the plane
        @type element: Element
        @return: the L{Hyperplane} instance"""
        self.threshold = float(element.getAttribute('threshold'))
        # A missing attribute comes back as an empty string
        self.relation = str(element.getAttribute('relation')) or None
        nodes = element.getElementsByTagName('vector')
        self.weights = self.weights.parse(nodes[0])
        # Patch bug in writing EqualRow?
        if self.weights.__class__.__name__ == 'KeyedVector':
            nodes[0].setAttribute('type','Equal')
            self.weights = self.weights.parse(nodes[0])
        return self
201
class DecisionTree:
    """Represents a decision tree with hyperplane branches that divide
    an n-dimensional space, and unrestricted values stored at the leaf
    nodes (e.g., matrices for dynamics, actions for policies, etc.)

    A leaf keeps its value in C{falseTree} (with an empty C{split} and
    C{trueTree} set to C{None}); a branch keeps a conjunction of planes in
    C{split} and its two subtrees in C{falseTree} and C{trueTree}.

    @cvar planeClass: the class used to instantiate the branches
    @cvar checkTautology: flag that, if C{True}, activates check for hyperplanes that are either always C{True} or always C{False} in L{branch}.  This can lead to smaller trees, but decreases efficiency
    @type checkTautology: C{boolean}
    @cvar checkPrune: flag that, if C{True}, activates the L{prune} method.  This will lead to much smaller trees, but increases the overhead required to check for pruneability
    @type checkPrune: C{boolean}
    """
    planeClass = Hyperplane
    checkTautology = False
    checkPrune = True

    def __init__(self,value=None):
        """Creates a DecisionTree
        @param value: the optional leaf node value"""
        self.parent = None
        self.stats = {}
        self.makeLeaf(value)

    def makeLeaf(self,value):
        """Marks this tree as a leaf node
        @param value: the new value of this leaf node"""
        self.branchType = None
        self.split = []
        # Unwrap nested trees so that the leaf holds a raw value
        while isinstance(value,DecisionTree):
            value = value.getValue()
        self.falseTree = value
        self.trueTree = None

    def getValue(self):
        """
        @return: the value of this tree
           - If a leaf node, as a single object
           - If a branch, as a tuple (falseTree,trueTree)"""
        if self.isLeaf():
            return self.falseTree
        return (self.falseTree,self.trueTree)

    def isLeaf(self):
        """
        @return: C{True} iff this tree is a leaf node
        @rtype: boolean"""
        return not len(self.split) > 0

    def children(self):
        """
        @return: all subtrees rooted at this node (C{True} subtree first)
        @rtype: list of L{DecisionTree} instances"""
        if self.isLeaf():
            return []
        falseTree,trueTree = self.getValue()
        return [trueTree,falseTree]

    def leaves(self):
        """
        @return: list of all leaf values (not necessarily unique), in the order produced by L{children}
        @note: the leaf value is the result of calling L{getValue}, not an actual L{DecisionTree} instance
        """
        if self.isLeaf():
            return [self.getValue()]
        found = []
        for child in self.children():
            found += child.leaves()
        return found

    def leafNodes(self):
        """
        @return: list of all leaf nodes (not necessarily unique), in the order produced by L{children}
        @rtype: L{DecisionTree}[]
        """
        if self.isLeaf():
            return [self]
        found = []
        for child in self.children():
            found += child.leafNodes()
        return found

    def depth(self):
        """
        @return: the maximum distance between this node and the leaf nodes of the tree rooted at this node (a leaf node has a depth of 0, a branch node with two leaf nodes as children has a depth of 1, etc.)
        @rtype: int
        """
        if self.isLeaf():
            return 0
        return 1+max([child.depth() for child in self.children()])

    def branches(self,result=None):
        """
        @param result: dictionary to accumulate into (a new one is created if omitted)
        @return: all branches (not necessarily unique), keyed by C{id}
        @rtype: intS{->}L{Hyperplane}
        """
        if result is None:
            result = {}
        if not self.isLeaf():
            if isinstance(self.split,list):
                for plane in self.split:
                    result[id(plane)] = plane
            else:
                assert(not isinstance(self.split,Hyperplane))
                result[id(self.split)] = self.split
            for child in self.children():
                # Children write into the same dictionary
                child.branches(result)
        return result

    def _collapseInto(self,subtree,debug=False):
        """Replaces the content of this node by that of the given subtree
        (a value or a L{DecisionTree}), without any pruning"""
        if isinstance(subtree,DecisionTree) and not subtree.isLeaf():
            newFalse,newTrue = subtree.getValue()
            self.branch(subtree.split,newFalse,newTrue,
                        pruneF=False,pruneT=False,debug=debug)
        else:
            self.makeLeaf(subtree)

    def branch(self,plane,falseTree,trueTree,pruneF=True,pruneT=True,debug=False):
        """Marks this tree as a deterministic branching node
        @param plane: the branchpoint(s) separating the C{False} and C{True} subtrees
        @type plane: L{Hyperplane} or L{Hyperplane}[]
        @param falseTree: the C{False} subtree
        @type falseTree: L{DecisionTree} instance
        @param trueTree: the C{True} subtree
        @type trueTree: L{DecisionTree} instance
        @param pruneF: if true, will L{prune} the C{False} subtree
        @type pruneF: bool
        @param pruneT: if true, will L{prune} the C{True} subtree
        @type pruneT: bool
        @param debug: if C{True}, some debugging statements will be written to stdout (default is C{False})
        @type debug: bool
        @note: setting either prune flag to false will save time (though may lead to more inefficient trees)"""
        self.branchType = 'deterministic'
        if isinstance(plane,list):
            self.split = plane
        else:
            self.split = [plane]
        if self.checkTautology:
            # Check whether these conditions are always true
            always = None
            for plane in self.split[:]:
                value = plane.always()
                if value == False:
                    # A single False value makes the whole condition False
                    always = value
                    break
                elif value == True:
                    # If always True, remove from conjunction
                    self.split.remove(plane)
            if len(self.split) == 0:
                # Always True, so the False subtree is irrelevant
                self._collapseInto(trueTree,debug)
                return
            elif always == False:
                # the True subtree is irrelevant
                self._collapseInto(falseTree,debug)
                return
        # Create False subtree
        if isinstance(falseTree,DecisionTree):
            self.falseTree = falseTree
        else:
            self.falseTree = self.__class__()
            self.falseTree.makeLeaf(falseTree)
        # Create True subtree
        if isinstance(trueTree,DecisionTree):
            self.trueTree = trueTree
        else:
            self.trueTree = self.__class__()
            self.trueTree.makeLeaf(trueTree)
        self.falseTree.parent = (self,False)
        self.trueTree.parent = (self,True)
        if pruneF:
            self.falseTree.prune(debug=debug)
        if pruneT:
            self.trueTree.prune(debug=debug)

    def getPath(self):
        """
        @return: the conditions under which this node will be reached, as a list of C{(plane,True/False)} tuples (nearest ancestor first)
        @rtype: (L{Hyperplane},boolean)[]
        """
        if self.parent:
            parent,side = self.parent
            return [(parent.split,side)] + parent.getPath()
        return []

    def createIndex(self,start=0):
        """Numbers the nodes so that leaves can be retrieved by integer index
        @param start: the index assigned to this node (and to its first leaf)
        @type start: int
        """
        self.stats['index'] = start
        if not self.isLeaf():
            falseTree,trueTree = self.getValue()
            # Initialize statistics, if not already done
            if 'leaf' not in falseTree.stats:
                self.count()
            falseTree.createIndex(start)
            trueTree.createIndex(start+falseTree.stats['leaf'])

    def removeTautologies(self,negative=True):
        """Collapses this node onto one of its subtrees whenever its branch
        is reported as always C{True} or always C{False}
        @param negative: passed through to the planes' C{always} method
        """
        if self.isLeaf():
            return
        truth = True # Assume True if no branches
        for plane in self.split:
            truth = plane.always(negative)
            if isinstance(truth,bool):
                break
        if truth is True:
            # Only need true tree
            survivor = self.getValue()[1]
        elif truth is False:
            # Only need false tree
            survivor = self.getValue()[0]
        else:
            return
        # Carry the caller's flag down instead of silently resetting it
        survivor.removeTautologies(negative)
        if survivor.isLeaf():
            self.makeLeaf(survivor.getValue())
        else:
            fNew,tNew = survivor.getValue()
            self.branch(survivor.split,fNew,tNew,
                        pruneF=False,pruneT=False)

    def prune(self,comparisons=None,debug=False,negative=True):
        """Removes branches made redundant or impossible by the branches
        already taken on the path from the root to this node
        @param comparisons: cache of plane comparisons, shared across calls
        @type comparisons: dict
        @param debug: if C{True}, some debugging statements will be written to stdout
        @type debug: bool
        @param negative: passed through to L{comparePlaneSets}
        """
        if not self.checkPrune:
            return
        if comparisons is None:
            comparisons = {}
        if self.isLeaf():
            return
        ancestor = self.parent
        split = self.split
        if debug:
            print('')
            print('Current:')
            print(' and '.join([p.simpleText() for p in split]))
            print(len(self.leaves()))
        while ancestor:
            parent,side = self.parent
            tree,direction = ancestor
            if debug:
                print('Ancestor: %s' % (len(tree.split),))
                print('Side: %s' % (direction,))
            split = comparePlaneSets(split,tree.split,direction,
                                     comparisons,debug,negative)
            if debug:
                if isinstance(split,bool):
                    print('Result: %s' % (split,))
                else:
                    print('Result: %s' % \
                          (' and '.join([p.simpleText() for p in split]),))
            if isinstance(split,bool):
                oldFalse,oldTrue = parent.getValue()
                newFalse,newTrue = self.getValue()
                if split:
                    # The conjunction has degenerated to always be True
                    survivor = newTrue
                else:
                    # We're already guaranteed to be False
                    survivor = newFalse
                # Splice the surviving subtree into our own slot in the parent
                if side:
                    parent.branch(parent.split,oldFalse,survivor,
                                  pruneF=False,pruneT=True,debug=debug)
                else:
                    parent.branch(parent.split,survivor,oldTrue,
                                  pruneF=True,pruneT=False,debug=debug)
                break
            ancestor = tree.parent
        else:
            # NOTE(review): the source listing lost its indentation; this
            # whole section is assumed to belong to the loop's else clause
            # (i.e. it runs only when no ancestor settled the branch).
            self.split = split
            self.falseTree.prune(comparisons,debug,negative)
            self.trueTree.prune(comparisons,debug,negative)
            # Check whether pruning has reduced T/F to be identical
            falseTree,trueTree = self.getValue()
            if falseTree == trueTree:
                if debug:
                    print('Equal subtrees: %s %s' % (falseTree,trueTree))
                if falseTree.isLeaf():
                    self.makeLeaf(falseTree)
                else:
                    newFalse,newTrue = falseTree.getValue()
                    self.branch(falseTree.split,newFalse,newTrue,
                                pruneF=False,pruneT=False,debug=debug)

    def count(self):
        """
        @return: a dictionary of statistics about the decision tree rooted at this node:
           - I{leaf}: # of leaves
           - I{branch}: # of branch nodes
           - I{depth}: depth of tree"""
        if self.isLeaf():
            self.stats['leaf'] = 1
            self.stats['branch'] = 0
            self.stats['depth'] = 0
            return self.stats
        self.stats['leaf'] = 0
        self.stats['branch'] = 1
        self.stats['depth'] = 1
        depth = 0
        for tree in self.children():
            subCount = tree.count()
            for key in ['leaf','branch']:
                self.stats[key] += subCount[key]
            if subCount['depth'] > depth:
                depth = subCount['depth']
        self.stats['depth'] += depth
        return self.stats

    def rebalance(self,debug=False):
        """
        Uses ID3 heuristic to reorder branches
        @return: this tree, rebuilt in place from the ID3 result
        @rtype: L{DecisionTree}
        """
        target = '_value'
        attributes = {target:True}
        values = {}
        data = self.makeRules(attributes,values)
        new = create_decision_tree(data,list(attributes.keys()),target,gain)
        self._extractTree(new,attributes,values)
        return self

    def _extractTree(self,tree,attributes,values):
        """
        Extracts the rules from the given L{dtree} structure into this tree
        @param tree: either a single-entry dictionary (attribute label S{->} {True/False: subtree}) or a leaf label
        @param attributes: mapping from attribute label to plane
        @param values: mapping from leaf label to leaf value
        @raise UserWarning: if the structure has an unknown attribute value, no subtree at all, or two identical subtrees
        """
        if type(tree) == dict:
            plane = attributes[list(tree.keys())[0]]
            subtrees = list(tree.values())[0]
            trueTree = None
            falseTree = None
            for item in subtrees.keys():
                if item == True:
                    trueTree = self.__class__()
                    trueTree._extractTree(subtrees[item],attributes,values)
                elif item == False:
                    falseTree = self.__class__()
                    falseTree._extractTree(subtrees[item],attributes,values)
                else:
                    raise UserWarning('Unknown attribute value: %s' % \
                                      (str(item)))
            if trueTree is None:
                if falseTree is None:
                    raise UserWarning('Null decision tree returned')
                self.makeLeaf(falseTree)
            elif falseTree is None:
                self.makeLeaf(trueTree)
            else:
                if falseTree == trueTree:
                    raise UserWarning
                self.branch(plane,falseTree,trueTree,
                            pruneF=False,pruneT=False)
        else:
            self.makeLeaf(values[tree])
        return self

    def makeRules(self,attributes=None,values=None,conditions=None,
                  debug=False,comparisons=None):
        """Represents this tree as a list of rules
        @param attributes: filled in with a mapping from plane label to plane
        @param values: filled in with a mapping from leaf label to leaf value
        @param conditions: the C{(label,side)} pairs already assumed on the way to this node
        @param debug: if C{True}, some debugging statements will be written to stdout
        @param comparisons: cache of plane comparisons, shared across calls
        @return: dict[]
        """
        if comparisons is None:
            comparisons = {}
        rules = []
        if attributes is None:
            attributes = {'_value':True}
        if values is None:
            values = {}
        if conditions is None:
            conditions = []
        if self.isLeaf():
            label = str(self.getValue())
            rule = {'_value':label}
            values[label] = self.getValue()
            for plane,side in conditions:
                rule[plane] = side
            rules.append(rule)
        else:
            falseTree,trueTree = self.getValue()
            newConditions = {}
            for plane in self.split:
                label = plane.simpleText()
                attributes[label] = plane
                newConditions[label] = plane
            # Determine rules when we branch False
            for plane in list(newConditions.keys()):
                split = [newConditions[plane]]
                for oldPlane,side in conditions:
                    split = comparePlaneSets(split,[attributes[oldPlane]],
                                             side,comparisons)
                    if isinstance(split,bool):
                        if not split:
                            # Guaranteed to be False, so no need to add plane
                            rules += falseTree.makeRules(attributes,values,
                                                         conditions,
                                                         debug,comparisons)
                        # (if guaranteed True, there is nothing to add)
                        break
                else:
                    # Must add this plane as extra condition
                    rules += falseTree.makeRules(attributes,values,
                                                 conditions+[(plane,False)],
                                                 debug,comparisons)
            # Determine rules when we branch True
            split = list(newConditions.values())
            for oldPlane,side in conditions:
                split = comparePlaneSets(split,[attributes[oldPlane]],
                                         side,comparisons)
                if isinstance(split,bool):
                    if split:
                        # Guaranteed to be True, so no need to add plane
                        rules += trueTree.makeRules(attributes,values,
                                                    conditions,
                                                    debug,comparisons)
                    # (if guaranteed False, there is nothing to add)
                    break
            else:
                # Must add this plane as extra condition
                added = [(label,True) for label in newConditions.keys()]
                rules += trueTree.makeRules(attributes,values,
                                            conditions+added,
                                            debug,comparisons)
        # Once we've created all of the rules, fill in any missing
        # conditions on the left-hand sides
        # NOTE(review): the source listing lost its indentation; the rule
        # pruning below is assumed to happen at the root only -- confirm.
        if not self.parent:
            for rule in rules:
                for attr in attributes.keys():
                    if attr not in rule:
                        # Add a "wildcard"
                        rule[attr] = None
            if debug:
                print('\t\tPruning %s rules' % (len(rules)))
            rules,attributes = pruneRules(rules,attributes,values,debug)
        return rules

    def fromRules(self,rules,attributes,values,comparisons=None):
        """Rebuilds this tree, in place, as a chain of branches, one per rule
        (the last rule provides the final default leaf)
        @param rules: the rules, in the format produced by L{makeRules}
        @param attributes: mapping from attribute label to plane
        @param values: mapping from leaf label to leaf value
        @param comparisons: cache of plane comparisons, shared across calls
        @return: this tree
        """
        tree = self
        if comparisons is None:
            comparisons = {}
        for rule in rules[:-1]:
            split = []
            for attr,value in rule.items():
                if attr == '_value':
                    tTree = self.__class__()
                    tTree.makeLeaf(values[value])
                elif value == True:
                    split.append(attributes[attr])
                elif value == False:
                    split.append(attributes[attr].inverse())
            # Minimize branches in conjunction
            value = True
            while value is not None:
                value = None
                for index in range(len(split)):
                    result = comparePlaneSets([split[index]],
                                              split[:index]+split[index+1:],
                                              True,comparisons)
                    if isinstance(result,bool):
                        if result:
                            # This plane is redundant
                            del split[index]
                            value = True
                        else:
                            # This plane is in conflict with the others
                            value = False
                        break
                if value is False:
                    # Contradictory rule: skip it altogether
                    break
            else:
                fTree = self.__class__()
                fTree.makeLeaf(None)
                tree.branch(split,fTree,tTree)
                tree = fTree
        tree.makeLeaf(values[rules[-1]['_value']])
        return self

    def generateAlternatives(self,index,value,test=None):
        """Finds the branches that, if flipped, would lead to a different value
        @param index: the point at which the tree is evaluated
        @param value: the value to find alternatives to
        @param test: predicate C{test(candidate,value)} that is true when the candidate counts as different (default is C{!=})
        @return: list of dictionaries with keys C{plane}, C{truth}, C{value}
        """
        if not test:
            test = lambda x,y: x != y
        if self.isLeaf():
            myValue = self.getValue()
            if test(myValue,value):
                return [{'plane':None,'truth':1,'value':myValue}]
            # No alternative
            return []
        falseTree,trueTree = self.getValue()
        onTrueSide = all([p.test(index) for p in self.split])
        if onTrueSide:
            taken,other = trueTree,falseTree
        else:
            taken,other = falseTree,trueTree
        # Hand the caller's predicate down; otherwise subtrees would fall
        # back to the default inequality test
        alternatives = taken.generateAlternatives(index,value,test)
        myValue = other[index]
        for action in myValue:
            if test(action,value):
                # Here's a way to get a different value
                for plane in self.split:
                    # On the True side every plane holds and could be
                    # flipped; on the False side only the failing ones
                    if bool(plane.test(index)) == onTrueSide:
                        alternatives.append({'plane':plane,
                                             'truth':not onTrueSide,
                                             'value':myValue})
                # NOTE(review): the source listing lost its indentation;
                # this break is assumed to end the scan over actions.
                break
        return alternatives

    def __getitem__(self,index):
        """
        @param index: either an integer (direct index of a leaf node, see L{createIndex}) or a point to be tested against the branches
        @return: the leaf node (integer index) or the leaf value (point)
        @raise IndexError: if an integer index does not match a leaf
        """
        if type(index) is int:
            # Direct index into leaf node
            if 'index' not in self.stats:
                self.createIndex()
            if self.isLeaf():
                if self.stats['index'] == index:
                    return self
                raise IndexError(index)
            falseTree,trueTree = self.getValue()
            if index < falseTree.stats['index'] + falseTree.stats['leaf']:
                return falseTree[index]
            return trueTree[index]
        # Array type index into decision tree
        if self.isLeaf():
            return self.getValue()
        # All planes in branch must be true
        if all([p.test(index) for p in self.split]):
            return self.trueTree[index]
        return self.falseTree[index]

    def replace(self,orig,new,comparisons=None,conditions=None):
        """Replaces any leaf nodes that match the given original value
        with the provided new value, followed by a pruning phase
        @param orig: leaf value to be replaced
        @param new: leaf value with which to replace
        @param comparisons: cache of plane comparisons, shared across calls
        @param conditions: the C{(planes,side)} pairs already assumed on the way to this node
        @raise NotImplementedError: if C{new} is not a L{DecisionTree}
        @warning: the replacement modifies this tree in place"""
        if not isinstance(new,DecisionTree):
            raise NotImplementedError('Currently unable to replace leaf nodes with non-tree objects')
        if comparisons is None:
            comparisons = {}
        if conditions is None:
            conditions = []
        if self.isLeaf():
            value = self.getValue()
            if isinstance(value,orig.__class__) and value == orig:
                if new.isLeaf():
                    self.makeLeaf(new.getValue())
                else:
                    falseTree,trueTree = new.getValue()
                    # Check whether this branch is relevant
                    split = new.split
                    for plane,truth in conditions:
                        split = comparePlaneSets(split,plane,truth,comparisons)
                        if isinstance(split,bool):
                            if split:
                                # Guaranteed True
                                return self.replace(orig,trueTree,comparisons,conditions)
                            # Guaranteed False
                            return self.replace(orig,falseTree,comparisons,conditions)
                    # Merge the subtree branch
                    newFalse = self.__class__()
                    newFalse.makeLeaf(orig)
                    newFalse.replace(orig,falseTree,comparisons,conditions)
                    newTrue = self.__class__()
                    newTrue.makeLeaf(orig)
                    newTrue.replace(orig,trueTree,comparisons,conditions)
                    self.branch(split,newFalse,newTrue,pruneF=False,pruneT=False)
        else:
            # Copy the current tree
            falseTree,trueTree = self.getValue()
            falseTree.replace(orig,new,comparisons,conditions+[(self.split,False)])
            trueTree.replace(orig,new,comparisons,conditions+[(self.split,True)])
            self.branch(self.split,falseTree,trueTree,pruneF=False,pruneT=False)

    def merge(self,other,op):
        """Merges the two trees together using the given operator to combine leaf values
        @param other: the other tree to merge with
        @type other: L{DecisionTree} instance
        @param op: the operator used to generate the new leaf values, C{lambda x,y:f(x,y)} where C{x} and C{y} are leaf values
        @rtype: a new L{DecisionTree} instance"""
        result = self._merge(other,op)
        result.prune()
        return result

    def _merge(self,other,op,comparisons=None,conditions=None):
        """Helper method that merges the two trees together using the given operator to combine leaf values, without pruning
        @param other: the other tree to merge with
        @type other: L{DecisionTree} instance
        @param op: the operator used to generate the new leaf values, C{lambda x,y:f(x,y)} where C{x} and C{y} are leaf values
        @param comparisons: cache of plane comparisons, shared across calls
        @param conditions: the C{(planes,side)} pairs already assumed on the way to this node
        @rtype: a new L{DecisionTree} instance"""
        if comparisons is None:
            comparisons = {}
        if conditions is None:
            conditions = []
        result = self.__class__()
        if not self.isLeaf():
            falseTree,trueTree = self.getValue()
            falseTree = falseTree._merge(other,op,comparisons,conditions+[(self.split,False)])
            trueTree = trueTree._merge(other,op,comparisons,conditions+[(self.split,True)])
            result.branch(self.split,falseTree,trueTree,pruneF=False,pruneT=False)
        elif isinstance(other,DecisionTree):
            if other.isLeaf():
                result.makeLeaf(op(self.getValue(),other.getValue()))
            else:
                falseTree,trueTree = other.getValue()
                # Check whether this branch is relevant
                split = other.split
                for plane,truth in conditions:
                    split = comparePlaneSets(split,plane,truth,comparisons)
                    if isinstance(split,bool):
                        if split:
                            # Guaranteed True
                            return self._merge(trueTree,op,comparisons,conditions)
                        # Guaranteed False
                        return self._merge(falseTree,op,comparisons,conditions)
                # Merge the subtree branch
                newFalse = self._merge(falseTree,op,comparisons,conditions)
                newTrue = self._merge(trueTree,op,comparisons,conditions)
                result.branch(split,newFalse,newTrue,
                              pruneF=False,pruneT=False)
        else:
            result.makeLeaf(op(self.getValue(),other))
        return result

    def __add__(self,other):
        return self.merge(other,lambda x,y:x+y)

    def __mul__(self,other):
        result = self._multiply(other)
        result.prune()
        return result

    def _multiply(self,other,comparisons=None,conditions=None):
        """Helper method that multiplies the two trees together, without pruning
        @param other: the tree to multiply by
        @type other: L{DecisionTree} instance
        @param comparisons: cache of plane comparisons, shared across calls
        @param conditions: the C{(planes,side)} pairs already assumed on the way to this node
        @rtype: a new L{DecisionTree} instance"""
        # NOTE(review): ``matrixmultiply`` is not imported in the visible
        # part of this module -- confirm where it is expected to come from.
        if comparisons is None:
            comparisons = {}
        if conditions is None:
            conditions = []
        result = self.__class__()
        if other.isLeaf():
            if self.isLeaf():
                result.makeLeaf(matrixmultiply(self.getValue(),
                                               other.getValue()))
            else:
                falseTree,trueTree = self.getValue()
                # Project each of our planes through the other's leaf value
                new = []
                for original in self.split:
                    weights = matrixmultiply(original.weights,other.getValue())
                    plane = original.__class__(weights,original.threshold)
                    new.append(plane)
                result.branch(new,falseTree._multiply(other,comparisons,conditions+[(new,False)]),
                              trueTree._multiply(other,comparisons,conditions+[(new,True)]),
                              pruneF=False,pruneT=False)
        else:
            falseTree,trueTree = other.getValue()
            split = other.split
            # Check whether this branch is relevant
            for plane,truth in conditions:
                split = comparePlaneSets(split,plane,truth,comparisons)
                if isinstance(split,bool):
                    if split:
                        # Guaranteed True
                        return self._multiply(trueTree,comparisons,conditions)
                    # Guaranteed False
                    return self._multiply(falseTree,comparisons,conditions)
            # Merge the subtree branch
            newFalse = self._multiply(falseTree,comparisons,conditions)
            newTrue = self._multiply(trueTree,comparisons,conditions)
            result.branch(split,newFalse,newTrue,pruneF=False,pruneT=False)
        return result

    def __sub__(self,other):
        return self + (-other)

    def __neg__(self):
        result = self.__class__()
        if self.isLeaf():
            result.makeLeaf(-self.getValue())
        else:
            result.branch(self.split,-self.falseTree,-self.trueTree,
                          pruneF=False,pruneT=False)
        return result

    def __eq__(self,other):
        if self.__class__ != other.__class__:
            return False
        if self.isLeaf() and other.isLeaf():
            return self.getValue() == other.getValue()
        if not self.isLeaf() and not other.isLeaf():
            return (self.split == other.split) and \
                   (self.getValue() == other.getValue())
        return False

    def __hash__(self):
        return hash(str(self))

    def __str__(self):
        return self.simpleText()

    def __copy__(self):
        new = self.__class__()
        if self.isLeaf():
            new.makeLeaf(copy.copy(self.getValue()))
        else:
            falseTree,trueTree = self.getValue()
            new.branch(copy.copy(self.split),copy.copy(falseTree),
                       copy.copy(trueTree),0,0)
        return new

    def __xml__(self):
        """
        @return: a document whose root C{tree} element is either a leaf (holding the value) or a branch (holding C{split}, C{false} and C{true} children)
        """
        doc = Document()
        root = doc.createElement('tree')
        doc.appendChild(root)
        if self.isLeaf():
            root.setAttribute('type','leaf')
            value = self.getValue()
            try:
                root.appendChild(value.__xml__().documentElement)
            except AttributeError:
                # Floats, lists, strings, etc. all get converted into strings
                root.appendChild(doc.createTextNode(str(value)))
        else:
            root.setAttribute('type','branch')
            element = doc.createElement('split')
            root.appendChild(element)
            for plane in self.split:
                element.appendChild(plane.__xml__().documentElement)
            falseTree,trueTree = self.getValue()
            element = doc.createElement('false')
            root.appendChild(element)
            element.appendChild(falseTree.__xml__().documentElement)
            element = doc.createElement('true')
            root.appendChild(element)
            element.appendChild(trueTree.__xml__().documentElement)
        return doc

    def _firstElement(self,node):
        """Returns the first element node among the children of the given node (C{None} if there is none)"""
        subNode = node.firstChild
        while subNode and subNode.nodeType != node.ELEMENT_NODE:
            subNode = subNode.nextSibling
        return subNode

    def parse(self,element,valueClass=None,debug=False):
        """Extracts the tree from the given XML element
        @param element: The XML Element object specifying the plane
        @type element: Element
        @param valueClass: The class used to generate the leaf values
        @param debug: passed down to the subtrees
        @return: the L{KeyedTree} instance"""
        if element.getAttribute('type') == 'leaf':
            # Extract leaf value
            if not valueClass:
                valueClass = float
            node = element.firstChild
            while node:
                if node.nodeType == node.ELEMENT_NODE:
                    value = valueClass()
                    value = value.parse(node)
                    break
                elif node.nodeType == node.TEXT_NODE:
                    value = str(node.data).strip()
                    if len(value) > 0:
                        if value == 'None':
                            # Better hope that this wasn't intended to be a string
                            value = None
                        break
                node = node.nextSibling
            else:
                # Should this be an error?  No, be proud of your emptiness.
                value = None
            self.makeLeaf(value)
        else:
            # Extract plane, False, and True
            planes = []
            falseTree = self.__class__()
            trueTree = self.__class__()
            node = element.firstChild
            while node:
                if node.nodeType == node.ELEMENT_NODE:
                    if node.tagName == 'split':
                        subNode = node.firstChild
                        while subNode:
                            if subNode.nodeType == subNode.ELEMENT_NODE:
                                plane = self.planeClass({},0.)
                                planes.append(plane.parse(subNode))
                            subNode = subNode.nextSibling
                    elif node.tagName in ['false','left']:
                        falseTree = falseTree.parse(self._firstElement(node),
                                                    valueClass,debug)
                    elif node.tagName in ['true','right']:
                        trueTree = trueTree.parse(self._firstElement(node),
                                                  valueClass,debug)
                node = node.nextSibling
            self.branch(planes,falseTree,trueTree)
        return self
1035
def printData(data,values=None):
    """Prints a table of data points on standard output, one row per datum

    Each row consists of a newline and a tab, then every attribute value
    other than C{_value} (right-justified to a width of at least 5),
    separated by single spaces, then a final newline.  If C{values} is
    given, the row ends with two more columns: the position of the datum's
    C{_value} within C{values}, and 2 raised to the number of entries of
    the datum that are C{None}.
    @param data: the data points to print
    @type data: dict[]
    @param values: the possible C{_value} entries; when omitted or empty, the two trailing columns are left out
    @type values: list
    @raise KeyError: if C{values} is given and a datum has no C{_value} entry
    @raise ValueError: if C{values} is given and a datum's C{_value} is not in it
    """
    import sys
    for datum in data:
        fields = []
        for attr,val in datum.items():
            if attr != '_value':
                # One-element tuple, so that a tuple-valued attribute is
                # printed as is rather than unpacked as format arguments
                fields.append('%5s' % (val,))
        if values:
            fields.append(str(values.index(datum['_value'])))
            fields.append(str(pow(2,list(datum.values()).count(None))))
        # The whole row goes out in a single write, so a failure above
        # cannot leave a half-printed row behind
        sys.stdout.write('\n\t%s\n' % (' '.join(fields)))
1045 1046
def comparePlaneSets(set1,set2,side,comparisons=None,
                     debug=False,negative=True):
    """
    Prunes a conjunction of planes against a second conjunction of planes whose outcome is already known, removing whatever is redundant and detecting inconsistency with that outcome
    @param set1: the plane set to be pruned
    @param set2: the plane set already tested
    @type set1,set2: L{Hyperplane}[]
    @param side: the side of the second set that we're already guaranteed to be on (C{True}: every plane in C{set2} holds; C{False}: at least one of them does not)
    @type side: boolean
    @param comparisons: optional cache of comparison results, indexed as C{comparisons[id(p2)][id(p1)]} for C{p2} in C{set2} and C{p1} in C{set1}; missing entries are computed and stored in it.  Anything other than a dictionary is ignored
    @type comparisons: str{}{}
    @param debug: not used
    @type debug: boolean
    @param negative: if C{True}, then assume that weights may be negative (default is C{True}); passed on to each plane's C{compare} method
    @return: The minimal set of planes in the first set that are not redundant given these a priori conditions (if guaranteed to be C{True} or C{False}, then the boolean value is returned)
    @rtype: L{Hyperplane}[]
    """
    def relate(tested,candidate):
        # Result of tested.compare(candidate), going through the cache if there is one
        if not isinstance(comparisons,dict):
            return tested.compare(candidate,negative)
        row,column = id(tested),id(candidate)
        try:
            return comparisons[row][column]
        except KeyError:
            outcome = tested.compare(candidate,negative)
            comparisons.setdefault(row,{})[column] = outcome
            return outcome

    # Planes of set1 found to be relevant so far
    survivors = []
    # What set1 has required of each plane in set2 so far:
    # None (nothing yet), True, or False
    required = [None] * len(set2)
    numTrue = 0
    for candidate in set1:
        redundant = False
        for index,tested in enumerate(set2):
            outcome = relate(tested,candidate)
            if side:
                # All of set2 is True
                if outcome == 'inverse':
                    # ...but candidate needs tested to be False
                    return False
                if outcome == 'equal' or outcome == 'greater':
                    # candidate is already guaranteed to be True
                    redundant = True
                    break
                # Anything else ('less' included) tells us nothing
            elif outcome == 'equal' or outcome == 'less':
                # At least one in set2 is False, and candidate needs
                # tested to be True
                if required[index] is False:
                    # Oops, already asked tested to be False
                    return False
                if not required[index]:
                    required[index] = True
                    numTrue += 1
                    if numTrue == len(set2):
                        # We require all of set2 to be True, but at
                        # least one's False
                        return False
            elif outcome == 'inverse':
                # At least one in set2 is False, and candidate needs
                # tested to be False
                if required[index]:
                    # Oops, already asked tested to be True
                    return False
                required[index] = False
            # 'greater' and inconclusive comparisons tell us nothing here
        if not redundant:
            # Nothing in set2 made this plane superfluous
            survivors.append(candidate)
    if survivors:
        return survivors
    return True
1140
def generateComparisons(set1,set2):
    """Builds the table of all pairwise comparisons between two sets of planes
    @param set1,set2: the two sets of planes
    @type set1,set2: L{Hyperplane}[]
    @return: a pairwise matrix of comparisons, indexed by the C{id} of each plane, so that C{result[id(p1)][id(p2)] = p1.compare(p2)} for every C{p1} in C{set1} and C{p2} in C{set2}
    @rtype: str{}{}
    """
    table = {}
    for first in set1:
        # One row per plane of set1, one column per plane of set2
        table[id(first)] = dict([(id(second),first.compare(second))
                                 for second in set2])
    return table
if __name__ == '__main__':
    # Developer check: loads a pickled tree, pre-computes the comparison
    # table for every plane appearing in it, verifies that the table is
    # complete, and then runs tree.prune under quickProfile, printing the
    # number of leaves before and after.
    from ProbabilityTree import *
    import pickle
    # NOTE(review): unpickling executes whatever the file dictates, and
    # /tmp is world-writable -- only run this on a file you created yourself
    # Pickles are binary data, hence 'rb'; the finally clause makes sure
    # that the file is closed even if loading fails
    f = open('/tmp/tree.pickle','rb')
    try:
        tree = pickle.load(f)
    finally:
        f.close()
    print('%d leaves' % (len(tree.leaves())))

    # Collect every distinct plane object used by a branch of the tree
    planes = {}
    nodes = [tree]
    while nodes:
        node = nodes.pop()
        if not node.isLeaf():
            for plane in node.split:
                planes[id(plane)] = plane
            nodes += node.children()
    print(len(planes))

    comparisons = generateComparisons(planes.values(),planes.values())
    # Sanity check: there must be an entry for every ordered pair of planes
    for plane1 in planes.values():
        assert id(plane1) in comparisons
        for plane2 in planes.values():
            assert id(plane2) in comparisons[id(plane1)]
    from teamwork.utils.Debugger import quickProfile
    # presumably calls tree.prune(comparisons,False) under a profiler --
    # verify against quickProfile
    quickProfile(tree.prune,(comparisons,False))
    print(len(tree.leaves()))