1 import copy
2 from teamwork.math.matrices import epsilon
3 from teamwork.math.KeyedVector import KeyedVector
4 from teamwork.math.KeyedTree import KeyedPlane
5
"""Tabular representation of a PWL function, as an alternative to L{KeyedTree<teamwork.math.KeyedTree.KeyedTree>}
@ivar rules: table of RHS, in dictionary form, indexed by row number
@ivar values: table of value function, in dictionary form, indexed by row number
@type values: intS{->}dict
@ivar attributes: the list of LHS conditions, each stored as a (condition, sorted list of test values) pair; conditions are usually L{KeyedVector} hyperplanes
@type attributes: (L{KeyedVector},float[])[]
@ivar _attributes: mapping from the string form of a LHS condition's array (C{str(condition.getArray())}) to its position in C{attributes} list; rebuilt by C{initialize}
@type _attributes: strS{->}int
@ivar _consistency: table of cached consistency checks among attribute values, keyed by C{'position,side'} strings on both levels
@type _consistency: strS{->}strS{->}bool or None
@ivar zeroPlanes: C{True} iff all of the attributes are hyperplanes through the origin
@type zeroPlanes: bool
"""
20
# Containers that initialize() later empties in place (never rebinds)
self.rules, self.values = {}, {}
self._attributes, self._consistency = {}, {}
self.reset()
# With no attributes, every condition trivially passes through the origin
self.zeroPlanes = True
28
"""Clears all existing contents (including attributes) of the table.
Since no attributes remain, C{zeroPlanes} is restored to C{True}."""
self.attributes = []
# An empty attribute list contains only origin hyperplanes (vacuously);
# without this, a table that once held a nonzero threshold stays flagged
# as non-zero and __add__/__mul__ skip the mergeZero fast path.
self.zeroPlanes = True
self.initialize()
33
"""Empties rows, values and caches while keeping the current attributes,
then rebuilds the lookup from each attribute's array form to its position."""
for table in (self.rules, self.values, self._attributes):
    table.clear()
for position, entry in enumerate(self.attributes):
    self._attributes[str(entry[0].getArray())] = position
self._consistency.clear()
42
"""Inserts the new attribute/value into the LHS conditions for this policy
@param obj: the condition
@param value: the test value
@return: the index of the attribute; when C{obj} matches the negation of an
existing condition, a negative index (C{index - len(attributes)}) that still
addresses the same list entry
@rtype: int
@note: C{_attributes} is not updated here; callers in this module call
C{initialize()} after adding attributes
"""
if abs(value) > epsilon:
    self.zeroPlanes = False
# Reuse an existing condition (or its negation) if there is one
for index in range(len(self.attributes)):
    other,values = self.attributes[index]
    if obj == other:
        if not value in values:
            values.append(value)
            values.sort()
        return index
    elif obj == -other:
        # NOTE(review): for a negated plane the equivalent threshold would be
        # -value, but value is stored unchanged; this is only exact for zero
        # thresholds. __mul__ appends to the same entry the same way -- confirm.
        if not value in values:
            values.append(value)
            values.sort()
        return index - len(self.attributes)
else:
    # New condition: insert it so the attribute list stays ordered
    # (2-D conditions by solveTuple threshold, others by raw array)
    for index in range(len(self.attributes)):
        if len(obj) == 2:
            a,b = obj.getArray()
            assert a > b-epsilon,obj.getArray()
            less = solveTuple(obj) < solveTuple(self.attributes[index][0])
        else:
            less = list(obj.getArray()) < \
                list(self.attributes[index][0].getArray())
        if less:
            self.attributes.insert(index,(obj,[value]))
            return index
    else:
        self.attributes.append((obj,[value]))
        return len(self.attributes)-1
81
"""Deletes the attribute in the given position and reorganizes the rules (and values) accordingly
@warning: it does not do any clever aggregation over multiple rules that may be collapsed because of the deletion of this attribute; when several old rows collapse onto one new row, whichever is remapped last wins
@param index: the position of the attribute to be deleted within the list of attributes
@type index: int
"""
oldRules = copy.copy(self.rules)
oldValues = copy.copy(self.values)
# Remap every row that has a rule OR a value. Value-only tables (e.g. the
# results built by __mul__, which are pruned through pruneAttributes) have no
# rules, so iterating only the rules used to discard all of their values;
# rows with a rule but no value used to raise KeyError.
oldFactors = {}
for table in (oldRules,oldValues):
    for rule in table.keys():
        if not rule in oldFactors:
            # factors must be computed while the attribute still exists
            factors = self.index2factored(rule)
            del factors[index]
            oldFactors[rule] = factors
del self.attributes[index]
# Deleting can only remove nonzero thresholds, so just re-check for all-zero
for obj,values in self.attributes:
    if len(values) != 1 or abs(values[0]) > epsilon:
        break
else:
    self.zeroPlanes = True
self.initialize()
for rule,factors in oldFactors.items():
    newRule = self.factored2index(factors)[0]
    if rule in oldRules:
        self.rules[newRule] = oldRules[rule]
    if rule in oldValues:
        self.values[newRule] = oldValues[rule]
105
def index(self,state,observations=None):
    """
    Maps beliefs to the row of this table whose LHS they satisfy
    @param state: the beliefs to use in identifying the appropriate rule
    @param observations: observations to test directly; must be empty (default is no observations)
    @type observations: dict
    @return: the rule index corresponding to the given beliefs
    @rtype: int
    @raise NotImplementedError: if any observations are supplied
    """
    if observations:
        raise NotImplementedError('Direct testing of observations not currently supported')
    if observations is None:
        # avoid a shared mutable default; behaves exactly like the old {}
        observations = {}
    factors = []
    for position in range(len(self.attributes)):
        obj = self.attributes[position][0]
        if isinstance(obj,KeyedVector):
            # hyperplane condition: project the beliefs onto it
            value = obj*state
        else:
            # non-vector condition: since non-empty observations are
            # rejected above, this lookup currently always misses
            try:
                value = observations[obj.name]
            except KeyError:
                value = None
        factors.append(self.subIndex(position,value))
    return self.factored2index(factors)[0]
129
131 """Shortcut method, index is either an int (for directly indexing into the table) or a belief vector. Can't incorporate observations.
132 """
133 if isinstance(index,int):
134 rule = index
135 else:
136 rule = self.index(index,{})
137 try:
138 return self.rules[rule]
139 except KeyError:
140 print self
141 print index,rule,self.attributes[-1][0]
142 return None
143
145 """
146 Computes the index corresponding to the given value for the given attribute
147 @param attr: the index of the attribute
148 @type attr: int
149 @param value: the actual value to determine the index of
150 @rtype: int
151 """
152 obj,values = self.attributes[attr]
153 if isinstance(obj,KeyedVector):
154 for subIndex in range(len(values)):
155 if value < values[subIndex]+epsilon:
156 return subIndex
157 else:
158 return len(values)
159 else:
160 try:
161 return values.index(value)
162 except KeyError:
163 return 0
164
166 """Tests whether extending a partial LHS assignment with a given subIndex is self-consistent
167 @type assignment: int[]
168 @type subIndex: int
169 @return: C{True} iff the sub-index is consistent with the current partial assignment
170 @rtype: bool
171 """
172 newAttr,newVals = self.attributes[len(assignment)]
173 for pos in range(len(assignment)):
174 oldAttr,oldVals = self.attributes[pos]
175 if detectConflict(oldAttr,assignment[pos],newAttr,subIndex):
176 return False
177 else:
178 return True
179
181 """
182 Transforms a list of subindices into a list of matching rule indices
183 @param check: if C{True}, then check consistency before returning indices (default is C{False})
184 @type check: bool
185 @note: subindex can be a list of subindices
186 @type factors: int[]
187 @rtype: int[]
188 """
189
190 old = [[]]
191
192 for position in range(len(self.attributes)):
193 obj,values = self.attributes[position]
194 if isinstance(obj,KeyedVector):
195 size = len(values) + 1
196 else:
197 size = len(values)
198
199 new = []
200 for assignment in old:
201 if factors[position] is None:
202
203 possible = range(size)
204 elif isinstance(factors[position],tuple):
205
206 possible = range(factors[position][0],
207 factors[position][1]+1)
208 elif isinstance(factors[position],list):
209
210 possible = factors[position]
211 else:
212
213 possible = [factors[position]]
214 for subIndex in possible:
215 if not check or self.consistentp(assignment,subIndex):
216 new.append(assignment + [subIndex])
217 old = new
218
219 indices = []
220 for assignment in old:
221 index = 0
222 for position in range(len(self.attributes)):
223 obj,values = self.attributes[position]
224 if isinstance(obj,KeyedVector):
225 size = len(values) + 1
226 else:
227 size = len(values)
228 index *= size
229 index += assignment[position]
230 indices.append(index)
231 return indices
232
270
272 """
273 Transforms a rule index into a list of subindices
274 @type index: int
275 @rtype: int[]
276 """
277 factors = []
278 for pos in range(len(self.attributes)):
279 obj,values = self.attributes[-pos-1]
280 if isinstance(obj,KeyedVector):
281 size = len(values) + 1
282 else:
283 size = len(values)
284 factors.insert(0,index % size)
285 index /= size
286 return factors
287
"""Extract a tabular representation of the given PWL tree. Updates this table to represent the same PWL function as the given tree.
@param tree: the tree to import
@type tree: L{KeyedTree}
"""
# NOTE(review): despite the description above, only the hyperplanes of
# non-probabilistic branches are imported as LHS attributes; no rules or
# values are filled in, and initialize() is not called after the attributes
# are added (so _attributes stays empty). Looks unfinished -- confirm intent.
self.reset()
# Depth-first walk over all nodes of the tree
remaining = [tree]
while remaining:
    node = remaining.pop()
    if not node.isLeaf():
        if not node.isProbabilistic():
            # every plane of a deterministic split becomes an LHS attribute
            for plane in node.split:
                print plane
                self.addAttribute(plane.weights,plane.threshold)
        remaining += node.children()
# Unconditional debugging output of the collected attributes
for obj,values in self.attributes:
    print obj
    print values
306
308 """
309 @return: the base table (stripped of any subclass extras)
310 @rtype: L{PWLTable}
311 """
312 result = PWLTable()
313 return self.copy(result)
314
def _consistent(self,attr1,great1,attr2,great2=None,debug=False):
    """Compares an attribute-value pair against another (or others) to determine whether they're potentially consistent
    @param attr1: the first condition, or its position in C{attributes}
    @type attr1: L{KeyedVector} or int
    @param great1: which side of C{attr1}'s hyperplane is asserted
    @type great1: bool
    @param attr2: the second condition (or its position), or a list of
    (condition,side) pairs, each compared against C{attr1}
    @type attr2: L{KeyedVector} or int or (L{KeyedVector},bool)[]
    @param great2: which side of C{attr2} is asserted (unused for a list)
    @type great2: bool or None
    @return: C{False} if never consistent, C{True} if always consistent, C{None} otherwise
    @note: results are cached in C{_consistency} only when both attributes are given by position
    @note: only 2-dimensional attributes are supported; others raise NotImplementedError
    """
    if isinstance(attr1,int) and isinstance(attr2,int):
        # Both attributes given by position: consult the cache first
        cache1 = '%d,%d' % (attr1,int(great1))
        cache2 = '%d,%d' % (attr2,int(great2))
        try:
            table = self._consistency[cache1]
            try:
                result = self._consistency[cache1][cache2]
                if debug: print '\tCache hit:',result
                return result
            except KeyError:
                pass
        except KeyError:
            self._consistency[cache1] = {}
    else:
        cache1,cache2 = None,None
    if isinstance(attr1,int):
        attr1 = self.attributes[attr1][0]
    if isinstance(attr2,list):
        # Compare every listed condition against attr1: any contradiction
        # makes the whole set inconsistent; otherwise True if any entry
        # yields True, else None.
        # NOTE(review): mapIndex passes entries whose side is None
        # (unassigned attributes); they are treated here as an ordinary side.
        value = None
        for pos in range(len(attr2)):
            test = self._consistent(attr2[pos][0],attr2[pos][1],
                                    attr1,great1,debug)
            if debug: print '\t\t',test
            if test is False:
                return False
            elif test is True:
                value = True
        return value
    elif isinstance(attr2,int):
        attr2 = self.attributes[attr2][0]
    # NOTE(review): this unpack fails before the dimension check below can
    # raise its more descriptive error for non-binary attributes
    key1,key2 = attr1.keys()
    if len(attr1) != 2:
        raise NotImplementedError,'I handle only binary attributes'
    try:
        weight1 = - attr1[key2] / attr1[key1]
    except:
        raise NotImplementedError,'Unable to handle unary tests: %s' \
              % (str(attr1.getArray()))
    # Normalize so both conditions are expressed with a positive first weight
    if attr1[key1] < 0.:
        great1 = not great1
    if len(attr2) != 2:
        raise NotImplementedError,'I handle only binary attributes'
    try:
        weight2 = - attr2[key2] / attr2[key1]
    except:
        raise NotImplementedError,'Unable to handle unary tests'
    if attr2[key1] < 0.:
        great2 = not great2
    if debug:
        print '\tComparing:',getProbRep(attr1,great1)
        print '\tvs.:',getProbRep(attr2,great2)
    result = None
    if great1 != great2:
        # Opposite sides: inconsistent when the two half-spaces cannot overlap
        if great1:
            if weight1 > weight2:
                if debug: print '\t\tInconsistent'
                result = False
        else:
            if weight2 > weight1:
                result = False
    elif attr1 == attr2:
        # Same condition, same side
        if debug: print '\t\tEqual'
        result = True
    elif len(attr1) == 2:
        # Same side of different thresholds: the tighter one subsumes the other
        thresh1 = solveTuple(attr1)
        thresh2 = solveTuple(attr2)
        if great1 and great2:
            if thresh1 > thresh2:
                if debug: print '\t\tSubsumed'
                result = True
        elif not great1 and not great2:
            if thresh1 < thresh2:
                if debug: print '\t\tSubsumed'
                result = True
    if cache1:
        self._consistency[cache1][cache2] = result
    return result
419
def prune(self,rulesOnly=False,debug=False):
    """Strips rows whose LHS is self-contradictory, then attributes whose
    value never changes the outcome
    @param rulesOnly: if C{True}, attributes only need to separate distinct
    rule RHS, not distinct value functions (default is C{False})
    @type rulesOnly: bool
    @param debug: if C{True}, print tracing output (default is C{False})
    @type debug: bool
    """
    # rows are pruned first, then attributes, in that fixed order
    self.pruneRules(debug)
    self.pruneAttributes(rulesOnly,debug)
427
# Remove every row whose LHS assignment is pairwise inconsistent.
# Rows are taken from the rules table, or from the values table when there
# are no rules at all (rows present only in values are then not examined).
keyList = self.rules.keys()
if not keyList:
    keyList = self.values.keys()
if debug:
    print 'Starting with %d rules' % (len(keyList))
# iterate over a copy because keyList shrinks as rows are removed
for rule in keyList[:]:
    factors = self.index2factored(rule)
    consistent = True
    for i in range(len(self.attributes)-1):
        attrI,values = self.attributes[i]
        # NOTE(review): only attributes 0..n-2 are checked here; the last
        # attribute's thresholds are never asserted to be [0.]
        assert values == [0.],'Unable to prune tables with nonzero intercepts in their LHS conditions'
        for j in range(i+1,len(self.attributes)):
            attrJ,values = self.attributes[j]
            pairwise = self._consistent(i,bool(factors[i]),j,bool(factors[j]))
            if pairwise is None:
                pass
            elif pairwise:
                pass
            else:
                assert pairwise is False
                consistent = False
                if debug:
                    print
                    print attrI.getArray(),factors[i]
                    print 'inconsistent with'
                    print attrJ.getArray(),factors[j]
                break
        if not consistent:
            # drop the row from both tables and stop checking it
            if self.rules.has_key(rule):
                del self.rules[rule]
            if self.values.has_key(rule):
                del self.values[rule]
            keyList.remove(rule)
            break
465
467 """Prune irrelevant attributes
468 """
469 if debug:
470 print 'Starting with %d attributes' % (self.attributes)
471 keyList = self.rules.keys()
472 if not keyList:
473 keyList = self.values.keys()
474 delete = []
475 for attrIndex in range(len(self.attributes)):
476 old = {}
477 attr,values = self.attributes[attrIndex]
478 distinct = False
479 if debug:
480 print 'Testing distinctness of:',attr.getArray()
481 for ruleIndex in keyList:
482 if debug:
483 print '\tStarting rule:',ruleIndex
484 if not old.has_key(ruleIndex):
485 factors = self.index2factored(ruleIndex)
486 rhs = None
487 for valueIndex in range(len(values)+1):
488 factors[attrIndex] = valueIndex
489 ruleIndex = self.factored2index(factors)[0]
490 if debug:
491 print '\tTesting rule:',ruleIndex
492 if ruleIndex in keyList:
493 old[ruleIndex] = True
494 if rhs is None:
495 rhs = {}
496 if self.rules.has_key(ruleIndex):
497 rhs['rules'] = self.rules[ruleIndex]
498 if not rulesOnly and self.values.has_key(ruleIndex):
499
500 rhs['values'] = self.values[ruleIndex]
501 elif not distinct:
502
503 me = {}
504 if self.rules.has_key(ruleIndex):
505 me['rules'] = self.rules[ruleIndex]
506 if not rulesOnly and self.values.has_key(ruleIndex):
507 me['values'] = self.values[ruleIndex]
508 if rhs != me:
509 distinct = True
510 if not distinct and not attrIndex in delete:
511 if debug:
512 print 'Delete attribute:',attrIndex
513 delete.append(attrIndex)
514 delete.sort()
515 delete.reverse()
516 for attrIndex in delete:
517 self.delAttribute(attrIndex)
518
def max(self,debug=False):
    """
    Computes the rules based on maximizing the values in this table
    @return: the table with the newly generated rules
    @rtype: L{PWLTable}
    @warning: assumes that the same option keys exist in every rule in the value function
    """
    # Option labels are read from an arbitrary row of the value function
    options = self.values.values()[0].keys()
    options.sort()
    rhs = {}
    for desired in options:
        rhs[desired] = []
    others = {}
    rules = self.values.keys()
    rules.sort()
    for rule in rules:
        # LHS of this row as (condition vector, sub-index) pairs
        factors = self.index2factored(rule)
        if debug:
            print 'Rule:',rule
        lhs = []
        for index in range(len(self.attributes)):
            lhs.append((self.attributes[index][0],factors[index]))
            if debug:
                print '\t',getProbRep(self.attributes[index][0],factors[index])
        # others[o]: extra conditions accumulated for option o; an option
        # whose entry is deleted has been ruled out for this row
        for desired in options:
            others[desired] = []
        for i in range(len(options)):
            desired = options[i]
            if not others.has_key(desired):
                continue
            if debug:
                print desired,self.values[rule][desired].getArray()
            path = lhs + others[desired]
            for j in range(i+1,len(options)):
                alternative = options[j]
                if not others.has_key(alternative):
                    continue
                if debug:
                    print
                    print '\tvs.',alternative,self.values[rule][alternative].getArray()
                # Hyperplane on which desired and alternative are worth the same
                weights = self.values[rule][desired] - self.values[rule][alternative]
                try:
                    weights.normalize()
                except ZeroDivisionError:
                    if debug: print '\tZero vector'
                    del others[alternative]
                    continue
                side = 1
                if len(weights) == 2:
                    # keep 2-D difference vectors in the a>=b orientation
                    a,b = weights.getArray()
                    if a < b:
                        weights = -weights
                        side = 0
                if debug: print '\tDifference:',getProbRep(weights,side)
                # Unconditional outcomes over the whole probability simplex
                test = KeyedPlane(weights,0.).always(probability=True)
                if test is None:
                    pass
                elif not bool(side) is test:
                    if debug: print '\tNever True'
                    break
                elif bool(side) is test:
                    if debug: print '\tAlways True'
                    del others[alternative]
                    continue
                # Outcome given this row's LHS
                test = self._consistent(weights,side,lhs)
                if debug: print '\tConsistent?',test
                if test is None:
                    test = self._consistent(weights,side,others[desired])
                    if test:
                        if debug: print '\t\tSubsumed by pre-condition'
                        continue
                    elif test is False:
                        if debug: print '\t\tInconsistent with pre-condition'
                        continue
                    # The opposite side becomes a precondition for alternative
                    if others[alternative]:
                        test = self._consistent(weights,1-side,lhs+others[alternative])
                    else:
                        test = None
                    if test is None:
                        if debug: print '\tPrecondition:',getProbRep(weights,1-side)
                        others[alternative].append((weights,1-side))
                    elif test is False:
                        if debug: print '\tImpossible'
                        del others[alternative]
                    else:
                        assert test is True
                        if debug: print '\tSubsumed'
                    if debug:
                        print
                    # Drop path conditions made redundant by the new one
                    index = 0
                    while index < len(path):
                        test = self._consistent(path[index][0],path[index][1],weights,side)
                        if test is True:
                            if debug: print '\tSubsumes:',getProbRep(path[index][0],path[index][1])
                            del path[index]
                        else:
                            index += 1
                    path.append((weights,side))
                elif test:
                    if debug: print '\tDominated'
                    del others[alternative]
                else:
                    if debug: print '\tInconsistent'
                    break
            else:
                # desired survived every comparison: record its conditions
                if debug:
                    print 'Final for rule',rule,desired
                    for weights,side in path:
                        print '\t',getProbRep(weights,side)
                rhs[desired].append({'lhs':path,'value':rule})
    # Build the policy's attributes from every recorded condition
    policy = PWLTable()
    for desired,conditions in rhs.items():
        for condition in conditions:
            path = condition['lhs']
            for index in range(len(path)):
                pos = policy.addAttribute(path[index][0],0.)
    policy.initialize()
    if debug:
        print 'New attributes:'
        for attr in policy.attributes:
            print '\t',getProbRep(attr[0])
    # NOTE(review): pos above and cache below are never used
    cache = {}
    # Rewrite each condition in place as (policy attribute position, side),
    # flipping the side when the policy stores the negated vector
    for desired,conditions in rhs.items():
        if debug: print 'Processing:',desired
        for condition in conditions:
            path = condition['lhs']
            if debug: print 'From rule',condition['value']
            for index in range(len(path)):
                try:
                    attr = policy._attributes[str(path[index][0].getArray())]
                    value = path[index][1]
                except KeyError:
                    attr = policy._attributes[str(-path[index][0].getArray())]
                    assert policy.attributes[attr][1] == [0.]
                    value = 1 - path[index][1]
                if debug: print '\t%d,%d' % (attr,value)
                path[index] = (attr,value)
    # Fill in policy rows: unconstrained attributes are fixed where implied,
    # otherwise left open ([0,1]) so every matching row gets the option
    for desired,conditions in rhs.items():
        if debug: print 'Inserting:',desired
        for condition in conditions:
            factors = map(lambda i: None,range(len(policy.attributes)))
            path = []
            for attr,value in condition['lhs']:
                assert isinstance(attr,int)
                assert factors[attr] is None
                factors[attr] = value
                path.append((attr,value))
            if debug: print 'Original:',condition['value'],policy.factorString(factors)
            for attr in range(len(policy.attributes)):
                if factors[attr] is None:
                    test = policy._consistent(attr,1,path)
                    if debug: print '\t',getProbRep(policy.attributes[attr][0],1),test
                    if test is None:
                        factors[attr] = [0,1]
                    elif test:
                        factors[attr] = 1
                        path.append((policy.attributes[attr][0],1))
                    else:
                        factors[attr] = 0
                        path.append((policy.attributes[attr][0],0))
            if debug: print '\t%s' % (policy.factorString(factors))
            for index in policy.factored2index(factors):
                if debug: print index,
                assert not policy.rules.has_key(index)
                policy.rules[index] = desired
                policy.values[index] = self.values[condition['value']]
            if debug: print
    return policy
718
720 """Computes the optimal value function, independent of action
721 @return: a table with the optimal value as the rules' RHS, and no values
722 @rtype: L{PWLTable}
723 """
724 result = self.getTable()
725 for rule in result.rules.keys():
726 rhs = result.rules[rule]
727 result.rules[rule] = result.values[rule][str(rhs)]
728 result.values[rule].clear()
729 return result
730
732 result = self.copy(self.__class__())
733 return result
734
def copy(self,result):
    """Copies the contents of this table into the given table
    @param result: the (normally freshly constructed) table to copy into
    @type result: L{PWLTable}
    @return: C{result}, now holding this table's attributes, rules and values
    @rtype: L{PWLTable}
    """
    # Copy each threshold list: addAttribute() appends to these lists in
    # place, so sharing them let changes to the copy leak into this table.
    result.attributes = [(obj,values[:]) for obj,values in self.attributes]
    result.zeroPlanes = self.zeroPlanes
    # The condition->position lookup and the consistency cache both depend
    # on the attribute list, so rebuild/clear them (rules and values in
    # result are left alone and merged into below, as before).
    result._attributes.clear()
    for position in range(len(result.attributes)):
        result._attributes[str(result.attributes[position][0].getArray())] = position
    result._consistency.clear()
    result.rules.update(self.rules)
    for key,table in self.values.items():
        result.values[key] = {}
        result.values[key].update(table)
    return result
742
744 count = 1
745 for obj,values in self.attributes:
746 if isinstance(obj,KeyedVector):
747 count *= len(values)+1
748 else:
749 count *= len(values)
750 return count
751
752 - def __add__(self,other,debug=False):
753 if self.zeroPlanes and other.zeroPlanes:
754 return self.mergeZero(other,lambda x,y: x+y,None,debug)
755 result = PWLTable()
756 if debug:
757 print 'I:',self
758 print 'U:',other
759
760 for obj,values in other.attributes:
761 result.attributes.append((obj,values[:]))
762
763 for obj,values in self.attributes:
764 index = result.addAttribute(obj,values[0])
765 for value in values[1:]:
766 if not value in result.attributes[index][1]:
767 result.attributes[index][1].append(value)
768 result.attributes[index][1].sort()
769 result.initialize()
770 if debug:
771 print 'New attributes:'
772 for attr in result.attributes:
773 print '\t',attr[0].getArray()
774
775 for myRule in self.values.keys():
776 myFactors = self.index2factored(myRule)
777 for yrRule in other.values.keys():
778 yrFactors = other.index2factored(yrRule)
779
780 newFactors = result.mapIndex(other,yrFactors)
781 newFactors = result.mapIndex(self,myFactors,newFactors)
782 if isinstance(newFactors,list):
783
784 indexList = result.factored2index(newFactors)
785 for option,yrRHS in other.values[yrRule].items():
786
787 myRHS = self.values[myRule][option]
788 newRHS = myRHS + yrRHS
789 for newIndex in indexList:
790 if not result.values.has_key(newIndex):
791 result.values[newIndex] = {}
792 assert not result.values[newIndex].has_key(option)
793 result.values[newIndex][option] = newRHS
794 return result
795
796 - def __mul__(self,other,combiner=None,debug=False):
797 """
798 @param combiner: optional binary function for using in combining RHS matrices (default is multiplication, duh)
799 @type combiner: lambda
800 @warning: like matrix multiplication, not commutative
801 """
802 if self.zeroPlanes and other.zeroPlanes:
803 return self.mergeZero(other,combiner,lambda x,y: x*y,debug)
804 result = self.__class__()
805
806 for obj,values in other.attributes:
807 result.attributes.append((obj,values[:]))
808
809 for rule,V in other.values.items():
810 if V:
811
812 new = V.values()
813 else:
814
815 new = [other.rules[rule]]
816 for rhs in new:
817 for obj,values in self.attributes:
818 new = obj*rhs
819 new.normalize()
820 index = None
821 for value in values:
822 plane = KeyedPlane(new,value)
823 if plane.always(probability=True) is None:
824 if index is None:
825 index = result.addAttribute(new,value)
826 else:
827 result.attributes[index][1].append(value)
828 if not index is None:
829 result.attributes[index][1].sort()
830 result.initialize()
831 if debug:
832 print 'New attributes:'
833 for attr in result.attributes:
834 print '\t',attr[0].getArray()
835
836 for myRule in self.values.keys():
837 myFactors = self.index2factored(myRule)
838 for yrRule in other.values.keys():
839 yrFactors = other.index2factored(yrRule)
840 for option,yrRHS in other.values[yrRule].items():
841
842 try:
843 myRHS = self.values[myRule][option]
844 except KeyError:
845
846 myRHS = self.rules[myRule]
847 if combiner:
848 newRHS = combiner(myRHS,yrRHS)
849 else:
850 newRHS = myRHS*yrRHS
851 if debug:
852 print
853 print 'A:'
854 for index in range(len(self.attributes)):
855 print bool(myFactors[index]),self.attributes[index][0].getArray()
856
857 print 'B:'
858 for index in range(len(other.attributes)):
859 print bool(yrFactors[index]),other.attributes[index][0].getArray()
860
861 print option
862 print 'Product:',newRHS.getArray()
863
864 newFactors = result.mapIndex(other,yrFactors,debug=debug)
865 newFactors = result.mapIndex(self,myFactors,
866 newFactors,yrRHS,debug=debug)
867 if isinstance(newFactors,list):
868
869 if debug:
870 print newFactors
871 for newIndex in result.factored2index(newFactors):
872 if debug:
873 print newIndex,
874 if not result.values.has_key(newIndex):
875 result.values[newIndex] = {}
876 assert not result.values[newIndex].has_key(option)
877 result.values[newIndex][option] = newRHS
878 if debug:
879 print
880 elif debug:
881 print 'Rejected'
882 result.pruneAttributes()
883 return result
884
def mergeZero(self,other,combiner=None,projector=None,debug=False):
    """
    Merging when both tables have all of their hyperplanes going through the origin
    @param other: the table to merge with
    @type other: L{PWLTable}
    @param combiner: optional binary function for using in combining RHS matrices (default is multiplication)
    @type combiner: lambda
    @param projector: optional binary function for using in projecting my LHS attributes based on the RHS of the other
    @type projector: lambda
    @param debug: if C{True}, print tracing output
    @return: a new table (of my class) holding only values, no rules
    @warning: like matrix multiplication, not commutative
    """
    # Phase 1: pair every row of other with every row of self and keep the
    # pairs whose (optionally projected) conditions do not conflict
    entries = []
    for yrRule,yrValue in other.values.items():
        yrFactors = other.index2factored(yrRule)
        for myRule,myValue in self.values.items():
            myFactors = self.index2factored(myRule)
            for option,yrRHS in yrValue.items():
                try:
                    myRHS = self.values[myRule][option]
                except KeyError:
                    # no value for this option: fall back on my rule's RHS
                    myRHS = self.rules[myRule]
                if debug:
                    print 'Combining:',self.factorString(myFactors),myRHS.getArray()
                    print 'with:',yrFactors,yrRHS.getArray()
                    print 'under:',option
                path = []
                for myIndex in range(len(self.attributes)):
                    myAttr = self.attributes[myIndex][0]
                    if projector:
                        newAttr = projector(myAttr,yrRHS)
                        newAttr.normalize()
                        if debug:
                            print '\t\t\tProjecting:',myAttr.getArray(),myFactors[myIndex]
                            print '\t\t\tInto:',newAttr.getArray()
                        # A non-float solveTuple result means the projected
                        # test is decided unconditionally (see solveTuple)
                        threshold = solveTuple(newAttr)
                        if not isinstance(threshold,float):
                            if threshold == myFactors[myIndex]:
                                if debug:
                                    print '\tRedundant:',newAttr.getArray()
                                newAttr = None
                            else:
                                if debug:
                                    print '\tInconsistent:',newAttr.getArray()
                                consistent = False
                                break
                    else:
                        newAttr = myAttr
                    if newAttr:
                        if debug:
                            print '\t\tChecking:',newAttr.getArray()
                        consistent = True
                        # against the other row's conditions
                        for yrIndex in range(len(other.attributes)):
                            yrAttr = other.attributes[yrIndex][0]
                            if detectConflict(newAttr,myFactors[myIndex],
                                              yrAttr,yrFactors[yrIndex]):
                                consistent = False
                                break
                        # against conditions already accepted on this path
                        for yrAttr,yrSide in path:
                            if detectConflict(newAttr,myFactors[myIndex],
                                              yrAttr,yrSide):
                                consistent = False
                                break
                        else:
                            path.append((newAttr,myFactors[myIndex]))
                        if not consistent:
                            break
                else:
                    # No conflict: LHS = other's sub-indices + my (vector,side) pairs
                    entry = {'LHS': yrFactors + path,'option': option}
                    if combiner:
                        entry['RHS'] = combiner(myRHS,yrRHS)
                    else:
                        entry['RHS'] = myRHS*yrRHS
                    if debug:
                        # NOTE(review): the leading LHS items are plain ints
                        # (yrFactors), so this unpacking fails when other has
                        # attributes
                        print 'Path found:'
                        for vector,side in entry['LHS']:
                            print '\t',getProbRep(vector,side)
                        print '\t',entry['RHS'].getArray()
                    entries.append(entry)
    # Phase 2: build the merged attribute list (other's first, then new ones)
    result = self.__class__()
    attributes = {}
    for obj,values in other.attributes:
        result.addAttribute(obj,values[0])
        attributes[str(obj.getArray())] = True
    for entry in entries:
        for attr,side in entry['LHS'][len(other.attributes):]:
            key = str(attr.getArray())
            if not attributes.has_key(key):
                result.addAttribute(attr,0.)
                attributes[key] = True
    # Local lookup from array string to final position
    # NOTE(review): result.initialize() is never called, so
    # result._attributes stays empty; also, if addAttribute merged a negated
    # vector, its key is absent here -- confirm neither matters to callers
    for index in range(len(result.attributes)):
        attributes[str(result.attributes[index][0].getArray())] = index
        if debug:
            print 'Attribute:',result.attributes[index][0].getArray()
    # Phase 3: write each entry's RHS into every consistent matching row
    for entry in entries:
        if debug:
            print 'New Entry:',entry['option']
            for vector,side in entry['LHS']:
                print '\t',getProbRep(vector,side)
        factors = []
        for attr in result.attributes:
            factors.append(None)
        for index in range(len(other.attributes)):
            plane = other.attributes[index][0]
            side = entry['LHS'][index]
            factors[attributes[str(plane.getArray())]] = side
        for plane,side in entry['LHS'][len(other.attributes):]:
            factors[attributes[str(plane.getArray())]] = side
        indexList = result.factored2index(factors,check=True)
        if debug:
            print '\t',factors
            print '\t',indexList
        for index in indexList:
            if not result.values.has_key(index):
                result.values[index] = {}
            assert not result.values[index].has_key(entry['option']),\
                   '%d, %s' % (index,entry['option'])
            result.values[index][entry['option']] = entry['RHS']
    return result
1021
def mapIndex(self,other,factors,result=None,multiplicand=None,debug=False):
    """Translates an index in another table into one for this table
    @param other: the other table
    @type other: L{PWLTable}
    @param factors: the index or factors of the rule to map
    @type factors: int or int[]
    @param result: previously determined factors that should be merged (default is C{None}, i.e. start with every attribute unassigned)
    @type result: int[]
    @param multiplicand: matrix used to scale any attributes (default is identity)
    @type multiplicand: L{KeyedMatrix<teamwork.math.KeyedMatrix.KeyedMatrix>}
    @param debug: if C{True}, print tracing output
    @return: a list of attributes subindices, C{None} if no consistent index exists
    @rtype: int[]
    @raise UserWarning: if a mapped condition is neither present in this table nor decided unconditionally
    @note: the C{result} list passed in is updated in place
    """
    if result is None:
        result = map(lambda attr: None,self.attributes)
    if isinstance(factors,int):
        factors = other.index2factored(factors)
    for pos in range(len(factors)):
        obj,values = other.attributes[pos]
        assert values == [0.],'Unable to handle non-zero thresholds'
        if multiplicand:
            # project the other table's condition through the matrix
            obj = obj*multiplicand
            obj.normalize()
        greater = factors[pos] == 1
        if debug:
            print '\tMapping:',obj.getArray()
        # Find the same condition here, or its negation (flipping the side)
        try:
            index = self._attributes[str(obj.getArray())]
        except KeyError:
            try:
                index = self._attributes[str(-obj.getArray())]
                greater = not greater
            except KeyError:
                # Not an attribute here: acceptable only if the test is
                # decided the same way everywhere
                assert values == [0.]
                plane = KeyedPlane(obj,0.)
                always = plane.always(probability=True)
                if always is None:
                    raise UserWarning,str(plane)
                elif always:
                    if greater:
                        continue
                    else:
                        if debug:
                            print '\t',always
                        return None
                else:
                    if greater:
                        if debug:
                            print '\t',always
                        return None
                    else:
                        continue
        obj,values = self.attributes[index]
        assert values == [0.],'Unable to handle non-zero thresholds'
        if result[index] is None:
            # NOTE(review): positions from _attributes are never negative, so
            # the first branch looks unreachable. The list passed to
            # _consistent also includes unassigned (None) entries.
            if index < 0:
                always = self._consistent(index+len(self.attributes),greater,
                                          map(lambda i: (i,result[i]),range(len(self.attributes))))
            else:
                always = self._consistent(index,greater,
                                          map(lambda i: (i,result[i]),range(len(self.attributes))))
            if always is False:
                return None
            if greater:
                result[index] = 1
            else:
                result[index] = 0
        else:
            # Already assigned: reject a contradictory side
            if greater:
                if result[index] == 0:
                    return None
            elif result[index] == 1:
                return None
    return result
1107
1115
1131
1148
1150 """Helper method that returns column headings for the attributes
1151 @param rhsLabel: column heading to use for RHS (default is 'Action')
1152 @type rhsLabel: str
1153 @rtype: str
1154 """
1155 row = 'Index'
1156 for obj,values in self.attributes:
1157 row += '%s' % (attrString(obj))
1158 row += '\tAction'
1159 return row
1160
1162 """Helper method that returns string representation of factor tuple
1163 @param factors: factors (or rule index)
1164 @type factors: int or int[]
1165 """
1166 lhs = ''
1167 if not isinstance(factors,list):
1168 factors = self.index2factored(factors)
1169 for attr in range(len(self.attributes)):
1170 values = self.attributes[attr][1]
1171 if len(self.attributes[attr][0]) == 2 and values == [0.]:
1172 if factors[attr] == 0:
1173 return (getProbRep(self.attributes[attr][0],
1174 factors[attr]))
1175 elif attr == len(self.attributes)-1:
1176 return (getProbRep(self.attributes[attr][0],
1177 factors[attr]))
1178 else:
1179 if factors[attr] == 0:
1180 lhs += '\t<=%8.3f' % (values[factors[attr]])
1181 elif factors[attr] == len(values):
1182 lhs += '\t >%8.3f' % (values[-1])
1183 else:
1184 lhs += '\t<=%5.3f,>%5.3f' % (values[factors[attr-1]],
1185 values[factors[attr]])
1186 return lhs
1187
1189 """
1190 @return: a happy string representation of the given attribute
1191 @rtype: str
1192 """
1193 if isinstance(attr,KeyedVector):
1194 if len(attr) == 2:
1195 return getProbRep(attr,True)
1196 else:
1197 keys = filter(lambda k: abs(attr[k]) > epsilon,attr.keys())
1198 if len(keys) == 2:
1199 return getArrayRep(attr[keys[0]],attr[keys[1]],True)
1200 elif len(keys) == 1:
1201 return '%s>0.' % (keys[0])
1202 else:
1203 return '\t'+','.join(map(lambda x: '%6.4f' % (x),attr.getArray()))
1204 else:
1205 return ' %s action' % (attr.name)
1206
1208 """
1209 @return: for probabilistic tuples, returns a unary constraint represenation of this vector
1210 @rtype: str
1211 """
1212
1213 label = getArrayRep(vector.getArray()[0],vector.getArray()[1],side)
1214 return label
1215
1217 """
1218 @return: for binary array, returns a unary constraint representation of this vector
1219 @rtype: str
1220 """
1221
1222 threshold,var = solveTuple(a,b),'L'
1223 if a-b < 0.:
1224 side = not side
1225 if side is None:
1226 sign = '??'
1227 elif side:
1228 sign = '> '
1229 else:
1230 sign = '<='
1231 return ' %s%s%5.3f' % (var,sign,threshold)
1232
1234 """Solves a 2-dimensional vector for one of the variables
1235 @param vector: ax + by
1236 @type vector: L{KeyedVector}
1237 @return: -b/(a-b) if a!=b; otherwise, C{True} iff b>0
1238 @rtype: float or bool
1239 """
1240
1241 if b is None:
1242 a,b = a.getArray()
1243 try:
1244 return -b/(a-b)
1245 except:
1246 return b > 0.
1247
1249 """Detects whether there is a conflict between two attribute-value pairs, where each attribute is a binary, 2-dimensional vector
1250 @type side1,side2: bool
1251 @type vector1,vector2: L{KeyedVector}
1252 @return: C{True} if there is a conflict
1253 """
1254 if len(vector1) == 2 and len(vector2) == 2:
1255 weight1 = solveTuple(vector1)
1256 weight2 = solveTuple(vector2)
1257
1258 if side1 != side2:
1259 if side1:
1260
1261 if weight1 > weight2-epsilon:
1262 return True
1263 else:
1264
1265 if weight2 > weight1-epsilon:
1266 return True
1267 elif len(vector1) == 1 and len(vector2) == 1:
1268 raise NotImplementedError,'I should be able to do this, but my creator is lazy'
1269 else:
1270 raise NotImplementedError,'Your %d-dimensional vectors frighten and confuse me' % (max(len(vector1),len(vector2)))
1271 return False
1272