1 import copy
2 from types import *
3
4 from ThespianKeys import makeBelongKey
5 from matrices import Hyperplane,DecisionTree
6 from KeyedMatrix import *
7 from probability import *
8 from teamwork.utils.FriendlyFloat import *
9
11 """A L{Hyperplane} alternative that stores the weights as a L{KeyedVector}, rather than as a list/array. The interface is identical to that of the L{Hyperplane} superclass, except that the array arguments should be in dictionary (not list/array) form
12 @ivar weights: the slope of this plane
13 @type weights: L{KeyedVector}
14 @ivar threshold: the offset of this plane
15 @type threshold: float
16 @ivar relation: the relation against this plane. Default is >, alternatives are: =.
17 @type relation: str
18 """
19
def __init__(self,weights,threshold,relation=None):
    """Builds the plane C{weights*x == threshold}, where the slope comes from C{weights}
    @param weights: the slope of the plane; a plain C{dict} is wrapped in a L{KeyedVector} first
    @type weights: dict or L{KeyedVector}
    @param threshold: the offset of the plane
    @type threshold: float
    @param relation: how L{test} compares against the plane (C{None} or '>' for strictly greater, '=' for equality within epsilon)
    @type relation: str
    @warning: you should start passing in C{weights} in L{KeyedVector} form, because the handling of C{dict} arguments will likely be deprecated
    """
    if type(weights) is DictType:
        # Legacy dictionary form: convert before handing it to the superclass
        weights = KeyedVector(weights)
    Hyperplane.__init__(self,weights,threshold,relation)
    # Move any constant-term weight over to the threshold side of the
    # inequality, then zero it out in the weight vector.
    # NOTE(review): if Hyperplane stores the passed-in KeyedVector by
    # reference, this also modifies the caller's vector -- verify.
    if self.weights.has_key(keyConstant):
        self.threshold -= self.weights[keyConstant]
        self.weights[keyConstant] = 0.
33
40
def test(self,value):
    """Checks the given point against this hyperplane.
    @param value: the point to test (multiplied against C{self.weights})
    @return: for relation C{None} or '>': C{True} iff C{self.weights*value > self.threshold};
    for relation '=': C{True} iff C{self.weights*value} is within epsilon of C{self.threshold}
    @rtype: bool
    @raise UserWarning: if C{self.relation} is none of C{None}, '>', '='"""
    total = self.weights * value
    relation = self.relation
    if relation in (None,'>'):
        return total > self.threshold
    if relation == '=':
        return abs(total - self.threshold) < epsilon
    raise UserWarning('Unknown hyperplane test: %s' % (self.relation))
52
54 """
55 Computes the 'probability' that a point in this region will satisfy this plane. In reality, just a coarse measure of volume.
56 @param intervals: a list of dictionaries, each with:
57 - weights: the slope of the bounding planes (L{KeyedVector})
58 - lo: the low value for this interval (int)
59 - hi: the high value for this interval (int)
60 @return: a dictionary of probability values over C{True} and C{False} (although not a real L{Distribution} object)
61 @rtype: boolS{->}float
62 """
63 result = {True:0.,False:0.}
64 for interval in intervals:
65 diff = sum(abs(self.weights.getArray()-
66 interval['weights'].getArray()))
67 if diff < epsilon:
68
69 span = interval['hi'] - interval['lo']
70 if span < epsilon:
71
72 result[self.test(interval['lo'])] = 1.
73 else:
74
75 if interval['hi'] <= self.threshold:
76
77 result[False] = 1.
78 elif interval['lo'] > self.threshold:
79
80 result[True] = 1.
81 else:
82
83 result[True] = (interval['hi']-self.threshold)/span
84 result[False] = 1. - result[True]
85 break
86 else:
87
88
89 result[True] = (1.-self.threshold)/2.
90 result[False] = 1. - result[True]
91 assert result[True] > -epsilon
92 assert result[False] > -epsilon
93 assert abs(result[True]+result[False]-1.) < epsilon
94 return result
95
def always(self,negative=True,probability=False):
    """Determines whether this plane's test is constant over the state space.
    @return: C{True} iff this plane eliminates none of the state space (i.e., for all q, w*q > theta). C{False} iff this plane eliminates all of the state space (i.e., for all q, w*q <= theta). C{None} if neither can be established.
    @rtype: boolean or C{None}
    @param probability: if C{True}, then assume that weights are nonnegative and sum to 1 (default is C{False})
    @param negative: if C{True}, then assume that weights may be negative (default is C{True})
    @warning: guaranteed to hold only within unit hypercube"""
    if probability and len(self.weights) == 2 and \
       abs(self.threshold) < epsilon:
        # With x2 = 1-x1, the test w1*x1 + w2*x2 > 0 becomes (w1-w2)*x1 > -w2
        first,second = self.weights.keys()
        numerator = -self.weights[second]
        denominator = self.weights[first] - self.weights[second]
        if abs(denominator) < epsilon:
            # x1 drops out, so the outcome is the same everywhere
            return numerator < 0.
        cutoff = numerator/denominator
        if denominator > 0.:
            # The test holds iff x1 > cutoff
            if cutoff < 0.:
                return True
            if cutoff >= 1.:
                return False
        else:
            # The test holds iff x1 < cutoff
            if cutoff <= 0.:
                return False
            if cutoff > 1.:
                return True
        return None

    if negative:
        # +/- the sum of |weights|: the extremes of w*x when every
        # component of x ranges over [-1,1]
        hi = sum([abs(self.weights[key]) for key in self.weights.keys()],0.)
        lo = -hi
    else:
        # NOTE(review): if getArray() returns a numeric array, [0]+array
        # adds 0 elementwise rather than prepending a 0 entry -- verify intent
        candidates = [0]+self.weights.getArray()
        hi = max(candidates)
        lo = min(candidates)
    if lo > self.threshold:
        # Even the smallest attainable value clears the threshold
        return True
    if hi <= self.threshold:
        # Even the largest attainable value fails to clear it
        return False
    return None
149
151 """
152 @return: C{True} iff this plane has zero weights and a zero threshold
153 @rtype: bool
154 """
155 if abs(self.threshold) > epsilon:
156 return False
157 else:
158 for weight in self.weights.getArray():
159 if abs(weight) > epsilon:
160 return False
161 else:
162 return True
163
165 """
166 @return: the keys used by the weight vector of this plane
167 @rtype: L{Key}[]"""
168 return self.weights.keys()
169
172
178
180 return self == other or self > other
181
183 return self == other or self < other
184
def compare(self,other,negative=True):
    """Modified version of __cmp__ method
    @return:
      - 'less': C{self < other}, i.e., for all C{x}, if C{not self.test(x)}, then C{not other.test(x)}
      - 'greater': C{self > other}, i.e., for all C{x}, if C{self.test(x)}, then C{other.test(x)}
      - 'equal': C{self == other}, i.e., for all C{x}, C{self.test(x) == other.test(x)}
      - 'inverse': C{self == not other}, i.e., for all C{x}, C{self.test(x) != other.test(x)}
      - 'indeterminate': none of the above
    @rtype: str
    @param negative: if C{True}, then assume that weights may be negative (default is C{True})
    """
    def rescale(array,threshold):
        """Divides a plane's weights and threshold by its smallest weight magnitude above epsilon (1 if there is none)"""
        try:
            scaling = min(filter(lambda w:w>epsilon,map(abs,array)))
        except ValueError:
            scaling = 1.
        try:
            return array/scaling,threshold/scaling
        except ZeroDivisionError:
            return array,threshold

    if self.weights._frozen and other.weights._frozen:
        # Both vectors are frozen: work on their array forms directly
        # (assumes both arrays are laid out over the same keys -- TODO confirm)
        myArray = self.weights.getArray()
        yrArray = other.weights.getArray()
        myThresh = self.threshold
        yrThresh = other.threshold
        if negative:
            # Normalize so that parallel planes of different magnitude line up
            myArray,myThresh = rescale(myArray,myThresh)
            yrArray,yrThresh = rescale(yrArray,yrThresh)
        # Opposite weights and opposite thresholds: complementary planes
        diff = sum(abs(myArray+yrArray))
        if diff < epsilon and abs(myThresh+yrThresh) < epsilon:
            return 'inverse'
        diff = sum(abs(myArray-yrArray))
        if diff < epsilon:
            # Identical weights: only the thresholds matter
            if abs(myThresh-yrThresh) < epsilon:
                return 'equal'
            elif yrThresh > myThresh:
                return 'less'
            else:
                return 'greater'
        # Start from the gap between the thresholds (as in the unfrozen
        # branch below), then widen by how far the weights can pull apart
        hi = lo = yrThresh - myThresh
        if negative:
            hi += diff
            lo -= diff
        else:
            # NOTE(review): if the arrays are numeric arrays, [0]+array adds 0
            # elementwise rather than prepending a 0 entry -- verify intent
            hi += max([0]+(myArray-yrArray))
            lo += min([0]+(myArray-yrArray))
    else:
        # Unfrozen vectors may have different key sets, so walk the keys
        hi = lo = other.threshold - self.threshold
        for key,myValue in self.weights.items():
            try:
                yrValue = other.weights[key]
            except KeyError:
                yrValue = 0.
            diff = abs(myValue-yrValue)
            hi += diff
            lo -= diff
            if hi > 0. and lo < 0.:
                # The bounds only widen from here on
                return 'indeterminate'
        for key,yrValue in other.weights.items():
            if self.weights.has_key(key):
                # Shared keys were already counted in the loop above
                continue
            diff = abs(yrValue)
            hi += diff
            lo -= diff
            if hi > 0. and lo < 0.:
                return 'indeterminate'
    if hi > 0.:
        if lo < 0.:
            return 'indeterminate'
        return 'less'
    elif lo < 0.:
        return 'greater'
    else:
        return 'equal'
294
296 """Slightly unorthodox subtraction
297 @return: a tuple (lo,hi) representing the min/max that the difference between these two planes will be on the unit hypercube"""
298 diff = self.weights.getArray() - other.weights.getArray()
299 hi = lo = 0.
300 for value in diff:
301 hi += abs(value)
302 lo += -abs(value)
303 hi -= self.threshold - other.threshold
304 lo -= self.threshold - other.threshold
305 return lo,hi
306
327
329 self.weights.instantiateKeys(table)
330 if isinstance(self.threshold,str):
331 self.threshold = float(table[self.threshold])
332 if len(self.weights.keys()) > 1:
333 return 0
334 else:
335 try:
336 key = self.weights.keys()[0]
337 except IndexError:
338 return -1
339 if key == keyConstant:
340 if self.weights[key] > self.threshold:
341 return 1
342 else:
343 return -1
344 else:
345 return 0
346
def simpleText(self,numbers=True,all=False):
    """Renders this hyperplane as a user-friendly string.
    @param numbers: if C{True} (the default), the condition uses comparison symbols (e.g., C{> 0.500}); otherwise it is phrased in words (e.g., C{is at least 0.500})
    @type numbers: boolean
    @param all: passed through to the weight vector's C{simpleText} when the plane is not over a single key
    @return: a user-friendly string representation of this hyperplane
    @rtype: str
    """
    if isinstance(self.weights,TrueRow):
        return self.weights.simpleText()
    weight = 1.
    if len(self.weights) == 1:
        key,weight = self.weights.items()[0]
        if isinstance(key,IdentityKey) or isinstance(key,ClassKey) \
           or isinstance(key,RelationshipKey):
            # These keys read as a statement, or its negation
            if weight > 0:
                return key.simpleText()
            else:
                return 'not %s' % (key.simpleText())
    if len(self.weights) == 1 and abs(weight) > epsilon:
        # Divide through by the lone weight so the threshold applies to the key itself
        row = key
        threshold = self.threshold/weight
    else:
        # Several keys (or a zero weight, which would make the division
        # above blow up): describe the whole weight vector instead
        row = self.weights.simpleText(numbers,all)
        threshold = self.threshold
        weight = 1.
    if self.relation not in (None,'>'):
        # Other relations are rendered without their threshold
        return '%s' % (row,)
    if numbers:
        if weight < 0:
            # Dividing by a negative weight flips the inequality
            condition = '< %5.3f' % (threshold)
        else:
            condition = '> %5.3f' % (threshold)
    elif weight < 0:
        if abs(threshold) < epsilon:
            condition = 'is negative'
        else:
            condition = 'is no more than %5.3f' % (threshold)
    elif abs(threshold) < epsilon:
        condition = 'is positive'
    else:
        condition = 'is at least %5.3f' % (threshold)
    return '%s %s' % (row,condition)
404
406 """A L{DecisionTree} that requires L{KeyedPlane} branches.
407 @cvar planeClass: the L{Hyperplane} subclass used for the L{split} attribute
408 @ivar keys: the cumulative list of keys used by all of the branches below"""
409 planeClass = KeyedPlane
410
def fill(self,keys,value=0.):
    """Fills in any missing slots with a default value, throughout this tree
    @param keys: the slots that should be filled
    @type keys: list of L{Key} instances
    @param value: the default value (defaults to 0)
    @note: does not overwrite existing values
    @note: leaf values without a C{fill} method (e.g., plain strings) are left untouched"""
    if self.isLeaf():
        # Skip only leaves that cannot fill themselves. Errors raised inside
        # a leaf's own fill() are no longer silently swallowed.
        leafFill = getattr(self.getValue(),'fill',None)
        if leafFill is not None:
            leafFill(keys,value)
    else:
        for plane in self.split:
            plane.weights.fill(keys,value)
        falseTree,trueTree = self.getValue()
        falseTree.fill(keys,value)
        trueTree.fill(keys,value)
429
431 """
432 @return: all keys used in this tree and all subtrees and leaves
433 @rtype: C{L{Key}[]}
434 """
435 return self._keys().keys()
436
461
463 """Locks in the dimensions and keys of all leaves"""
464 if self.isLeaf():
465 try:
466 self.getValue().freeze()
467 except AttributeError:
468
469 pass
470 else:
471 for plane in self.split:
472 plane.weights.freeze()
473 for child in self.children():
474 child.freeze()
475
485
def simpleText(self,printLeaves=True,numbers=True,all=False):
    """Returns a more readable string version of this tree
    @param printLeaves: if C{True}, each leaf is rendered by its own C{simpleText} (or C{str} if it has none); if a callable, it is applied to each leaf value to produce its text; if false, leaves are shown as C{...}
    @type printLeaves: C{boolean} or callable
    @param numbers: passed to each branch plane's C{simpleText}; if C{True} (the default), thresholds are rendered as numeric comparisons, otherwise in words
    @type numbers: boolean
    @param all: passed through to the C{simpleText} of leaves and branch planes
    @rtype: C{str}
    """
    if self.isLeaf():
        if not printLeaves:
            return '...'
        value = self.getValue()
        if printLeaves is not True:
            return printLeaves(value)
        # Fall back to str() only for leaves lacking simpleText. Errors
        # raised inside a leaf's own simpleText() now propagate.
        render = getattr(value,'simpleText',None)
        if render is None:
            return str(value)
        return render(all=all)
    falseTree,trueTree = self.getValue()
    # Indent nested subtrees by one extra tab
    falseText = falseTree.simpleText(printLeaves,numbers,all).replace('\n','\n\t')
    trueText = trueTree.simpleText(printLeaves,numbers,all).replace('\n','\n\t')
    # The branch condition is the conjunction of all split planes
    condition = ' and '.join([plane.simpleText(numbers,all)
                              for plane in self.split])
    return 'if %s\n\tthen %s\n\telse %s' % (condition,trueText,falseText)
515
548
550 if self.isLeaf():
551 return self.__class__(self.getValue().instantiate(table))
552 else:
553 if len(self.split) > 1:
554 raise NotImplementedError,'Currently unable to instantiate trees with conjunction branches'
555 result = self.split[0].instantiate(table)
556 falseTree,trueTree = self.getValue()
557 if not isinstance(result,int):
558 new = self.__class__()
559 new.branch(result,falseTree.instantiate(table),
560 trueTree.instantiate(table))
561 return new
562 elif result > 0:
563
564 trueTree = trueTree.instantiate(table)
565 if trueTree.isLeaf():
566 return self.__class__(trueTree.getValue())
567 else:
568 new = self.__class__()
569 plane = trueTree.split
570 falseTree,trueTree = trueTree.getValue()
571 new.branch(plane,falseTree,trueTree)
572 return new
573 elif result < 0:
574
575 falseTree = falseTree.instantiate(table)
576 if falseTree.isLeaf():
577 return self.__class__(falseTree.getValue())
578 else:
579 new = self.__class__()
580 plane = falseTree.split
581 falseTree,trueTree = falseTree.getValue()
582 new.branch(plane,falseTree,trueTree)
583 return new
584
586 """Replaces any key references by the values in the table"""
587 self._instantiate(table)
588 self.updateKeys()
589
619
620
621
622
623
624
def _multiply(self,other,comparisons=None,conditions=None):
    """Multiplies this tree by another tree
    @param other: the right-hand operand
    @type other: L{DecisionTree}
    @param comparisons: passed unchanged to recursive calls and to L{DecisionTree._multiply} (a fresh C{dict} if omitted)
    @param conditions: list of (planes,side) pairs for the branches already taken above this node (empty if omitted)
    @return: the product tree, an instance of C{self.__class__}
    """
    if comparisons is None:
        comparisons = {}
    if conditions is None:
        # Avoid a shared mutable default argument
        conditions = []
    if not other.isLeaf():
        # Branching right operand: defer to the generic tree product
        return DecisionTree._multiply(self,other,comparisons,conditions)
    result = self.__class__()
    if self.isLeaf():
        try:
            result.makeLeaf(self.getValue()*other.getValue())
        except TypeError:
            # The left value does not support * with the right; try the other order
            result.makeLeaf(other.getValue()*self.getValue())
    else:
        falseTree,trueTree = self.getValue()
        # Push the right-hand leaf value through each branch plane's
        # weights. Each plane keeps its threshold and its relation; the
        # bound method plane.test was previously passed here by mistake.
        new = [plane.__class__(plane.weights*other.getValue(),
                               plane.threshold,plane.relation)
               for plane in self.split]
        newF = falseTree._multiply(other,comparisons,
                                   conditions+[(new,False)])
        newT = trueTree._multiply(other,comparisons,
                                  conditions+[(new,True)])
        result.branch(new,newF,newT,pruneF=False,pruneT=False)
    return result
650
652 """Scales all of the leaf nodes by the given float factor"""
653 if self.isLeaf():
654 self.makeLeaf(self.getValue()*factor)
655 else:
656 for subtree in self.children():
657 subtree.scale(factor)
658 return self
659
669
671 """
672 @return: a string representation of the internal array representation
673 @rtype: str
674 """
675 if self.isLeaf():
676 return str(self.getValue().getArray())
677 else:
678 falseTree,trueTree = self.getValue()
679 prefix = 'if '
680 content = ''
681 for plane in self.split:
682 content += '%s %s*x > %f:\n' % \
683 (prefix,
684 str(plane.weights.getArray().transpose()),
685 plane.threshold)
686 prefix = 'and '
687 substring = trueTree.toNumeric()
688 substring = substring.replace('\n','\n\t\t')
689 content += '\tthen:\t%s\n' % (substring)
690 substring = falseTree.toNumeric()
691 substring = substring.replace('\n','\n\t\t')
692 content += '\telse:\t%s\n' % (substring)
693 return content
694
696 """Replaces any identity matrices at the leaves of this tree with the number 1.
697 @return: the number of identity matrices found
698 @rtype: int
699 """
700 if self.isLeaf():
701 if self.getValue().isIdentity():
702 self.makeLeaf(1.)
703 return 1
704 else:
705 return 0
706 else:
707 return sum(map(self.__class__.pruneIdentities,self.children()))
708
735
738
741
if __name__ == '__main__':
    # Small demonstration: compare two single-key planes, then evaluate a
    # one-branch tree.
    # NOTE(review): makeStateKey is not imported explicitly at the top of
    # the file; presumably it comes from one of the wildcard imports -- verify
    key0 = makeStateKey('Poor','economicpower')
    weights1 = KeyedVector({key0:0.6})
    weights2 = KeyedVector({key0:0.5})

    plane1 = KeyedPlane(weights1,0.5)
    plane2 = KeyedPlane(weights2,0.5)
    if plane1 < plane2:
        print('Less!')
    elif plane1 > plane2:
        print('Great!')
    else:
        print('Who knows?')
    tree = KeyedTree()
    tree.branch(plane1,'wait','act')
    # Index the tree with a state vector (presumably evaluates it at that state)
    print(tree[KeyedVector({key0:0.7})])
    tree = copy.copy(tree)
    print(tree.getKeys())
762