Package teamwork :: Package policy :: Module pwlTable
[hide private]
[frames] | [no frames]

Source Code for Module teamwork.policy.pwlTable

   1  import copy 
   2  from teamwork.math.matrices import epsilon 
   3  from teamwork.math.KeyedVector import KeyedVector 
   4  from teamwork.math.KeyedTree import KeyedPlane 
   5   
6 -class PWLTable:
7 """Tabular representation of a PWL function, as an alternative to L{KeyedTree<teamwork.math.KeyedTree.KeyedTree>} 8 @ivar rules: table of RHS, in dictionary form, indexed by row number 9 @ivar values: table of value function, in dictionary form, indexed by row number 10 @type values: intS{->}dict 11 @ivar attributes: the list of LHS conditions 12 @type attributes: L{KeyedVector}[] 13 @ivar _attributes: mapping from LHS condition to position in C{attributes} list 14 @type _attributes: L{KeyedVector}S{->}int 15 @ivar _consistency: table of cached consistency checks among attribute values 16 @type _consistency: intS{->}intS{->}bool 17 @ivar zeroPlanes: C{True} iff all of the attributes are hyperplanes through the origin 18 @type zeroPlanes: bool 19 """ 20
def __init__(self):
    """Create an empty table: no attributes, rules, values, or cached checks."""
    self.rules, self.values = {}, {}
    self._attributes, self._consistency = {}, {}
    # reset() installs the empty attribute list and clears the containers above
    self.reset()
    self.zeroPlanes = True
28
def reset(self):
    """Clears all existing contents (including attributes) of the table.

    With no attributes left, there are no off-origin hyperplanes, so
    C{zeroPlanes} is restored to C{True}; otherwise a table that once held a
    nonzero threshold would keep reporting C{zeroPlanes=False} after reset.
    """
    self.attributes = []
    self.zeroPlanes = True
    self.initialize()
33
def initialize(self):
    """Clear rules, values and cached consistency checks, keeping the attributes.

    The C{_attributes} lookup is rebuilt so that the string form of each
    attribute's weight array maps to that attribute's position.
    """
    for table in (self.rules, self.values, self._attributes):
        table.clear()
    self._attributes.update(
        (str(entry[0].getArray()), position)
        for position, entry in enumerate(self.attributes))
    self._consistency.clear()
42
def addAttribute(self,obj,value):
    """Register a LHS condition/threshold pair with this table.

    If C{obj} (or its negation) is already an attribute, C{value} is merged
    into that attribute's sorted threshold list.  Otherwise a new attribute is
    inserted before the first existing attribute it sorts before (binary
    attributes compared by C{solveTuple}, others by their weight arrays), or
    appended at the end.
    @param obj: the condition
    @param value: the test value
    @return: the position of the attribute; when C{obj} matched the negation
    of an existing attribute, that position minus the number of attributes
    (a negative index referring to the same entry)
    @rtype: int
    @note: any threshold farther than C{epsilon} from 0 clears C{zeroPlanes}
    """
    if abs(value) > epsilon:
        self.zeroPlanes = False
    total = len(self.attributes)
    # Reuse an existing test when obj matches it directly or in negated form
    for position, (existing, thresholds) in enumerate(self.attributes):
        if obj == existing:
            found = position
        elif obj == -existing:
            # NOTE(review): the threshold is stored as given, not negated;
            # confirm callers only hit this case with zero thresholds
            found = position - total
        else:
            continue
        if value not in thresholds:
            thresholds.append(value)
            thresholds.sort()
        return found
    # Brand-new attribute: keep the attribute list in its sorted order
    for position, entry in enumerate(self.attributes):
        if len(obj) == 2:
            a, b = obj.getArray()
            assert a > b-epsilon, obj.getArray()
            less = solveTuple(obj) < solveTuple(entry[0])
        else:
            less = list(obj.getArray()) < list(entry[0].getArray())
        if less:
            self.attributes.insert(position, (obj, [value]))
            return position
    self.attributes.append((obj, [value]))
    return len(self.attributes) - 1
81
def delAttribute(self,index):
    """Deletes the attribute in the given position and reorganizes the rows accordingly.

    Every populated row (one with a rule, a value-function entry, or both) is
    re-indexed by dropping its sub-index for the deleted attribute.  If every
    remaining attribute has the single threshold 0 (within C{epsilon}),
    C{zeroPlanes} is set to C{True}.
    @warning: it does not do any clever aggregation over multiple rows that
    may be collapsed because of the deletion of this attribute; when several
    old rows map onto the same new row, the highest-numbered old row wins
    @param index: the position of the attribute to be deleted within the list of attributes
    @type index: int
    """
    oldRules = dict(self.rules)
    oldValues = dict(self.values)
    # Factor every row under the old attribute list, before it changes.
    # Value-only tables (no rules) must be remapped too, and a row may have a
    # rule without a value entry (or vice versa).
    oldFactors = {}
    for rule in sorted(set(oldRules) | set(oldValues)):
        factors = self.index2factored(rule)
        del factors[index]
        oldFactors[rule] = factors
    del self.attributes[index]
    for obj,values in self.attributes:
        if len(values) != 1 or abs(values[0]) > epsilon:
            break
    else:
        self.zeroPlanes = True
    self.initialize()
    for rule in sorted(oldFactors):
        newRule = self.factored2index(oldFactors[rule])[0]
        if rule in oldRules:
            self.rules[newRule] = oldRules[rule]
        if rule in oldValues:
            self.values[newRule] = oldValues[rule]
105
def index(self,state,observations=None):
    """Find the rule index whose LHS matches the given beliefs.

    Each hyperplane attribute (L{KeyedVector}) is evaluated against
    C{state}; any other attribute is looked up by name in C{observations},
    which must currently be empty, so such attributes get the value C{None}.
    @param state: the beliefs to use in identifying the appropriate rule
    @param observations: not supported yet; must be empty or C{None}
    @type observations: dict or None
    @return: the rule index corresponding to the given beliefs
    @rtype: int
    @raise NotImplementedError: if any observations are supplied
    """
    if observations:
        raise NotImplementedError('Direct testing of observations not currently supported')
    if observations is None:
        observations = {}
    factors = []
    for position, (obj, values) in enumerate(self.attributes):
        if isinstance(obj,KeyedVector):
            # Determine which plane interval this state is in
            value = obj*state
        else:
            # Need to pick out observation
            value = observations.get(obj.name)
        factors.append(self.subIndex(position,value))
    return self.factored2index(factors)[0]
129
def __getitem__(self,index):
    """Shortcut lookup of a rule's RHS.

    @param index: either an int (a direct row number in the table) or a
    belief vector, which is mapped to a row via C{self.index}; observations
    cannot be incorporated this way
    @return: the RHS stored for that row, or C{None} if the row has no rule,
    in which case diagnostic information is printed
    """
    if isinstance(index,int):
        rule = index
    else:
        rule = self.index(index,{})
    try:
        return self.rules[rule]
    except KeyError:
        # Diagnostic dump; guard the attribute lookup so an attribute-less
        # table still returns None instead of raising IndexError.
        print(self)
        if self.attributes:
            last = self.attributes[-1][0]
        else:
            last = None
        print('%s %s %s' % (index,rule,last))
        return None
143
def subIndex(self,attr,value):
    """
    Computes the index corresponding to the given value for the given attribute

    For a hyperplane attribute (L{KeyedVector}) the result is the position of
    the first threshold C{t} with C{value < t + epsilon}, or C{len(values)}
    when the value lies above every threshold.  For any other attribute it is
    the position of C{value} in the attribute's value list, or 0 if absent.
    @param attr: the index of the attribute
    @type attr: int
    @param value: the actual value to determine the index of
    @rtype: int
    """
    obj,values = self.attributes[attr]
    if isinstance(obj,KeyedVector):
        for position,threshold in enumerate(values):
            if value < threshold + epsilon:
                return position
        return len(values)
    try:
        return values.index(value)
    except ValueError:
        # list.index reports a missing element with ValueError (not KeyError)
        return 0
164
def consistentp(self,assignment,subIndex):
    """Check whether extending a partial LHS assignment with one more sub-index is self-consistent.

    The candidate sub-index belongs to the attribute at position
    C{len(assignment)}; it is checked pairwise, via the module-level helper
    C{detectConflict}, against every attribute already assigned.
    @type assignment: int[]
    @type subIndex: int
    @return: C{True} iff no already-assigned attribute conflicts with the new sub-index
    @rtype: bool
    """
    newAttr,newVals = self.attributes[len(assignment)]
    conflicts = (detectConflict(self.attributes[pos][0],chosen,newAttr,subIndex)
                 for pos,chosen in enumerate(assignment))
    return not any(conflicts)
179
def factored2index(self,factors,check=False):
    """
    Transforms a list of subindices into a list of matching rule indices

    Each entry of C{factors} constrains one attribute: C{None} allows every
    sub-index, a tuple C{(lo,hi)} allows that inclusive range, a list allows
    exactly its members, and anything else is taken as a single sub-index.
    Rule indices are mixed-radix numbers with the last attribute least
    significant (hyperplane attributes have C{len(values)+1} sub-indices,
    others C{len(values)}).
    @param check: if C{True}, drop assignments that fail L{consistentp} (default is C{False})
    @type check: bool
    @note: subindex can be a list of subindices
    @type factors: int[]
    @rtype: int[]
    """
    radices = []
    for obj,values in self.attributes:
        if isinstance(obj,KeyedVector):
            radices.append(len(values) + 1)
        else:
            radices.append(len(values))
    # Expand the constraints into every matching full assignment
    assignments = [[]]
    for position,radix in enumerate(radices):
        if not assignments:
            # Nothing left to extend; later factors are never consulted
            break
        constraint = factors[position]
        if constraint is None:
            choices = range(radix)
        elif isinstance(constraint,tuple):
            choices = range(constraint[0],constraint[1]+1)
        elif isinstance(constraint,list):
            choices = constraint
        else:
            choices = [constraint]
        assignments = [prefix + [choice]
                       for prefix in assignments
                       for choice in choices
                       if not check or self.consistentp(prefix,choice)]
    # Encode each assignment as a single mixed-radix integer
    indices = []
    for assignment in assignments:
        code = 0
        for radix,choice in zip(radices,assignment):
            code = code*radix + choice
        indices.append(code)
    return indices
232
def OLDfactored2index(self,factors):
    """
    Legacy variant of factored2index, without optional consistency checking.

    An unconstrained (C{None}) factor enumerates every sub-index of its
    attribute: C{len(values)+1} for hyperplane attributes, C{len(values)}
    otherwise.
    @note: subindex can be a list of subindices
    @type factors: int[]
    @rtype: int[]
    """
    # Start with a single index: all attributes at the minimum value
    indices = [0]
    for position in range(len(self.attributes)):
        if not indices:
            # Nothing left to extend; later factors are never consulted
            break
        obj,values = self.attributes[position]
        if isinstance(obj,KeyedVector):
            size = len(values) + 1
        else:
            size = len(values)
        factor = factors[position]
        if factor is None:
            # No constraint on this attribute value
            choices = range(size)
        elif isinstance(factor,tuple):
            # Attribute value is an interval
            choices = range(factor[0],factor[1]+1)
        elif isinstance(factor,list):
            # Attribute value is a set
            choices = factor
        else:
            # Assume attribute value is a singleton
            choices = [factor]
        indices = [index*size + subIndex
                   for index in indices for subIndex in choices]
    return indices
270
def index2factored(self,index):
    """
    Transforms a rule index into a list of subindices

    The inverse of the mixed-radix encoding used by factored2index: the last
    attribute is the least significant digit.  Integer division is explicit
    so the result is also correct under Python 3.
    @type index: int
    @rtype: int[]
    """
    factors = []
    for obj,values in reversed(self.attributes):
        if isinstance(obj,KeyedVector):
            size = len(values) + 1
        else:
            size = len(values)
        index,remainder = divmod(index,size)
        factors.append(remainder)
    factors.reverse()
    return factors
287
def fromTree(self,tree,debug=False):
    """Import the LHS conditions of a PWL tree as the attributes of this table.

    The table is reset, then every node of C{tree} is visited; each
    hyperplane of every non-leaf, non-probabilistic node is registered via
    C{addAttribute(plane.weights,plane.threshold)}.  Only attributes are
    imported: no rules or values are filled in.
    @param tree: the tree to import
    @type tree: L{KeyedTree}
    @param debug: if C{True}, print each plane and the final attributes (default is C{False})
    @type debug: bool
    """
    self.reset()
    remaining = [tree]
    while remaining:
        node = remaining.pop()
        if node.isLeaf():
            continue
        if not node.isProbabilistic():
            for plane in node.split:
                if debug:
                    print(plane)
                self.addAttribute(plane.weights,plane.threshold)
        remaining += node.children()
    if debug:
        for obj,values in self.attributes:
            print(obj)
            print(values)
306
def getTable(self):
    """Copy this table's base contents into a fresh, plain L{PWLTable}.

    Whatever the runtime class of C{self}, the result is a bare PWLTable
    filled in by C{self.copy}, so subclass-specific extras are dropped.
    @rtype: L{PWLTable}
    """
    return self.copy(PWLTable())
314
def _consistent(self,attr1,great1,attr2,great2=None,debug=False):
    """Compares an attribute-value pair against another (or others) to determine whether they're potentially consistent

    Attributes may be given as positions in C{self.attributes} or as weight
    vectors.  When both are positions, the answer is memoized in
    C{self._consistency}, keyed by C{'position,side'} strings.
    When C{attr2} is a list of (attribute,side) pairs, each pair is compared
    against C{attr1,great1}: any C{False} answer makes the result C{False};
    otherwise the result is C{True} if some pair answered C{True}, else C{None}.
    Only the weights of each attribute are used; no threshold is read here.
    @type attr1: L{KeyedVector} or int
    @type attr2: L{KeyedVector}, int, or (L{KeyedVector},bool)[]
    @type great1: bool
    @type great2: bool or None
    @return: C{False} if never consistent, C{True} if always consistent, C{None} otherwise
    @raise NotImplementedError: if an attribute does not have two weights, or
    if dividing by the weight on its first key fails (e.g. a unary test)
    """
    if isinstance(attr1,int) and isinstance(attr2,int):
        # Look for cached result
        # NOTE(review): int(great2) raises TypeError if great2 is left as None
        cache1 = '%d,%d' % (attr1,int(great1))
        cache2 = '%d,%d' % (attr2,int(great2))
        try:
            table = self._consistency[cache1]
            try:
                result = self._consistency[cache1][cache2]
                if debug: print '\tCache hit:',result
                return result
            except KeyError:
                pass
        except KeyError:
            # First query for this attribute/side: create its cache row
            self._consistency[cache1] = {}
    else:
        cache1,cache2 = None,None
    if isinstance(attr1,int):
        attr1 = self.attributes[attr1][0]
    if isinstance(attr2,list):
        # Multiple attributes to test against each other
        value = None
        for pos in range(len(attr2)):
            # NOTE: switch order
            test = self._consistent(attr2[pos][0],attr2[pos][1],
                                    attr1,great1,debug)
            if debug: print '\t\t',test
            if test is False:
                # If one is inconsistent, then whole thing is
                return False
            elif test is True:
                # Subsumed by another factor
                value = True
        return value
    elif isinstance(attr2,int):
        attr2 = self.attributes[attr2][0]
    # Solve for first of two variables
    # NOTE(review): keys() is unpacked before the length check below, so an
    # attr1 without exactly two keys raises ValueError here instead of the
    # NotImplementedError that follows.
    key1,key2 = attr1.keys()
    if len(attr1) != 2:
        # Don't have a general n-tuple version of this
        raise NotImplementedError,'I handle only binary attributes'
    try:
        # Rewrite the test as x[key1] compared against weight1*x[key2]
        weight1 = - attr1[key2] / attr1[key1]
    except:
        raise NotImplementedError,'Unable to handle unary tests: %s' \
              % (str(attr1.getArray()))
    if attr1[key1] < 0.:
        # Dividing through by a negative coefficient flips the comparison
        great1 = not great1
    if len(attr2) != 2:
        # Don't have a general n-tuple version of this
        raise NotImplementedError,'I handle only binary attributes'
    # Solve for first of two variables
    try:
        weight2 = - attr2[key2] / attr2[key1]
    except:
        raise NotImplementedError,'Unable to handle unary tests'
    if attr2[key1] < 0.:
        great2 = not great2
    if debug:
        print '\tComparing:',getProbRep(attr1,great1)
        print '\tvs.:',getProbRep(attr2,great2)
    result = None
    if great1 != great2:
        # Thresholds in different direction
        # (the comparisons below presumably assume x[key2] is nonnegative,
        # e.g. a probability; confirm)
        if great1:
            # weight1*y < x < weight2*y
            if weight1 > weight2:
                if debug: print '\t\tInconsistent'
                result = False
        else:
            # weight1*y > x > weight2*y
            if weight2 > weight1:
                result = False
    elif attr1 == attr2:
        # Exact match
        if debug: print '\t\tEqual'
        result = True
    elif len(attr1) == 2:
        # Probabilistic comparison
        # (always reached here: non-binary attributes were rejected above)
        thresh1 = solveTuple(attr1)
        thresh2 = solveTuple(attr2)
        # Both are thresholds on the same variable
        if great1 and great2:
            # x > theta
            if thresh1 > thresh2:
                # Subsumed
                if debug: print '\t\tSubsumed'
                result = True
        elif not great1 and not great2:
            # x < theta
            if thresh1 < thresh2:
                # Subsumed
                if debug: print '\t\tSubsumed'
                result = True
    if cache1:
        # Only set when both attributes were given as positions
        self._consistency[cache1][cache2] = result
    return result
419
def prune(self,rulesOnly=False,debug=False):
    """Remove contradictory rows, then attributes that no longer matter.

    Contradictory rows are deleted before the attribute-relevance check runs.
    @param rulesOnly: if C{True}, attributes only need to distinguish the
    rules' RHS, not the value function as well (default is C{False})
    @type rulesOnly: bool
    @param debug: if C{True}, print a trace
    @type debug: bool
    """
    self.pruneRules(debug=debug)
    self.pruneAttributes(rulesOnly=rulesOnly,debug=debug)
427
def pruneRules(self,debug=False):
    """Delete rows whose combination of LHS conditions is contradictory.

    Rows are the keys of C{rules} (or of C{values} if there are no rules).
    A row is removed from both C{rules} and C{values} as soon as some pair of
    its attribute settings is reported inconsistent (C{False}) by
    C{_consistent}.
    @param debug: if C{True}, print a trace
    @type debug: bool
    @warning: only supports attributes whose sole threshold is 0; every
    attribute, including the last one, is checked before pruning
    """
    keyList = list(self.rules.keys())
    if not keyList:
        keyList = list(self.values.keys())
    if debug:
        print('Starting with %d rules' % (len(keyList)))
    if keyList:
        for attr,values in self.attributes:
            assert values == [0.],'Unable to prune tables with nonzero intercepts in their LHS conditions'
    count = len(self.attributes)
    for rule in keyList:
        factors = self.index2factored(rule)
        # Find the first pair of attribute settings that contradict each other
        conflict = None
        for i in range(count-1):
            for j in range(i+1,count):
                pairwise = self._consistent(i,bool(factors[i]),j,bool(factors[j]))
                if pairwise is None or pairwise:
                    continue
                assert pairwise is False
                conflict = (i,j)
                break
            if conflict is not None:
                break
        if conflict is None:
            continue
        if debug:
            i,j = conflict
            print('')
            print('%s %s' % (self.attributes[i][0].getArray(),factors[i]))
            print('inconsistent with')
            print('%s %s' % (self.attributes[j][0].getArray(),factors[j]))
        self.rules.pop(rule,None)
        self.values.pop(rule,None)
465
def pruneAttributes(self,rulesOnly=False,debug=False):
    """Prune irrelevant attributes

    An attribute is irrelevant when, starting from any populated row and
    varying only that attribute's sub-index, every populated row reached has
    the same RHS.  Populated rows are the keys of C{rules} (or of C{values}
    if there are no rules); the RHS compared is the row's rule plus, unless
    C{rulesOnly}, its value-function entry.  Irrelevant attributes are
    removed with C{delAttribute}, highest position first so that the
    remaining positions stay valid.
    @param rulesOnly: if C{True}, only the rules need to be distinct, not the value function as well (default is C{False})
    @type rulesOnly: bool
    @param debug: if C{True}, print a trace
    @type debug: bool
    """
    if debug:
        print('Starting with %d attributes' % (len(self.attributes)))
    keyList = list(self.rules.keys())
    if not keyList:
        keyList = list(self.values.keys())
    populated = set(keyList)

    def signature(rule):
        # The RHS content that must match for rows to count as identical
        rhs = {}
        if rule in self.rules:
            rhs['rules'] = self.rules[rule]
        if not rulesOnly and rule in self.values:
            rhs['values'] = self.values[rule]
        return rhs

    delete = []
    for attrIndex in range(len(self.attributes)):
        attr,values = self.attributes[attrIndex]
        if debug:
            print('Testing distinctness of: %s' % (attr.getArray(),))
        visited = set()
        distinct = False
        for ruleIndex in keyList:
            if ruleIndex in visited:
                continue
            if debug:
                print('\tStarting rule: %s' % (ruleIndex,))
            factors = self.index2factored(ruleIndex)
            reference = None
            for valueIndex in range(len(values)+1):
                factors[attrIndex] = valueIndex
                sibling = self.factored2index(factors)[0]
                if debug:
                    print('\tTesting rule: %s' % (sibling,))
                if sibling not in populated:
                    continue
                visited.add(sibling)
                current = signature(sibling)
                if reference is None:
                    reference = current
                elif current != reference:
                    distinct = True
                    break
            if distinct:
                # One difference is enough to keep this attribute
                break
        if not distinct:
            if debug:
                print('Delete attribute: %d' % (attrIndex,))
            delete.append(attrIndex)
    for attrIndex in reversed(delete):
        self.delAttribute(attrIndex)
518
def max(self,debug=False):
    """
    Computes the rules based on maximizing the values in this table

    For each row of the value function, the options are compared pairwise;
    the hyperplane on which two options' values are equal (their normalized
    difference, against threshold 0) becomes a candidate LHS condition.  The
    surviving conditions for each winning option are turned into attributes
    and rows of a new table.
    @return: the table with the newly generated rules; each row's RHS is the
    winning option key and its value entry is the originating row's value
    dict (shared, not copied)
    @rtype: L{PWLTable}
    @warning: assumes that the same option keys exist in every rule in the value function
    @note: Python 2 code (relies on list-returning C{dict.keys()}/C{values()} and C{map})
    """
    # Generate new LHS conditions
    options = self.values.values()[0].keys()
    options.sort()
    rhs = {} # Store possible conditions that would trigger each RHS
    for desired in options:
        rhs[desired] = []
    others = {} # Cache defeating conditions of other RHS
    rules = self.values.keys()
    rules.sort()
    for rule in rules:
        # Identify the LHS conditions that trigger this rule
        factors = self.index2factored(rule)
        if debug:
            print 'Rule:',rule
        lhs = []
        for index in range(len(self.attributes)):
            lhs.append((self.attributes[index][0],factors[index]))
            if debug:
                print '\t',getProbRep(self.attributes[index][0],factors[index])
        # Initialize preconditions
        for desired in options:
            others[desired] = []
        # Do pairwise comparisons between RHS values
        for i in range(len(options)):
            desired = options[i]
            if not others.has_key(desired):
                # Alternative has been previously eliminated
                continue
            if debug:
                print desired,self.values[rule][desired].getArray()
            # Add on preconditions found so far
            path = lhs + others[desired]
            # Compare against other possible RHS values
            for j in range(i+1,len(options)):
                alternative = options[j]
                if not others.has_key(alternative):
                    # Alternative has been previous eliminated
                    continue
                if debug:
                    print
                    print '\tvs.',alternative,self.values[rule][alternative].getArray()
##                if self.values[rule][desired] == self.values[rule][alternative]:
##                    # No preference
##                    if debug: print '\tIndifferent'
##                    del others[alternative]
##                    continue
                # Hyperplane separating the two options' values
                weights = self.values[rule][desired] - self.values[rule][alternative]
                try:
                    weights.normalize()
                except ZeroDivisionError:
                    # Identical values: the alternative never needs a condition
                    if debug: print '\tZero vector'
                    del others[alternative]
                    continue
                side = 1
                if len(weights) == 2:
                    # Normalize direction
                    a,b = weights.getArray()
                    if a < b:
                        weights = -weights
                        side = 0
                if debug: print '\tDifference:',getProbRep(weights,side)
                # Check whether this condition can ever be met
                test = KeyedPlane(weights,0.).always(probability=True)
                if test is None:
                    pass
                elif not bool(side) is test:
                    if debug: print '\tNever True'
                    break
                elif bool(side) is test:
                    if debug: print '\tAlways True'
                    del others[alternative]
                    continue
                # Check whether condition consistent with original LHS
                test = self._consistent(weights,side,lhs)
                if debug: print '\tConsistent?',test
                if test is None:
                    # Compare against my pre-conditions
                    test = self._consistent(weights,side,others[desired])
                    if test:
                        if debug: print '\t\tSubsumed by pre-condition'
                        continue
                    elif test is False:
                        if debug: print '\t\tInconsistent with pre-condition'
                        continue
                    # Add to alternative's pre-conditions
                    if others[alternative]:
                        test = self._consistent(weights,1-side,lhs+others[alternative])
                    else:
                        test = None
                    if test is None:
                        if debug: print '\tPrecondition:',getProbRep(weights,1-side)
                        others[alternative].append((weights,1-side))
                    elif test is False:
                        if debug: print '\tImpossible'
                        del others[alternative]
                    else:
                        assert test is True
                        if debug: print '\tSubsumed'
                    # Check whether any existing conditions are subsumed
                    if debug:
                        print
                    index = 0
                    while index < len(path):
                        test = self._consistent(path[index][0],path[index][1],weights,side)
                        if test is True:
                            if debug: print '\tSubsumes:',getProbRep(path[index][0],path[index][1])
                            del path[index]
                        else:
                            index += 1
                    path.append((weights,side))
                elif test:
                    # This condition is always met
                    if debug: print '\tDominated'
                    del others[alternative]
                else:
                    # This condition is never met
                    if debug: print '\tInconsistent'
                    break
            else:
                # Conditions are all potentially meetable
                # (for-else: reached only if no comparison hit a break)
                if debug:
                    print 'Final for rule',rule,desired
                    for weights,side in path:
                        print '\t',getProbRep(weights,side)
                rhs[desired].append({'lhs':path,'value':rule})
    # Generate new table attributes
    policy = PWLTable()
    for desired,conditions in rhs.items():
        for condition in conditions:
            path = condition['lhs']
            for index in range(len(path)):
                pos = policy.addAttribute(path[index][0],0.)
    policy.initialize()
    if debug:
        print 'New attributes:'
        for attr in policy.attributes:
            print '\t',getProbRep(attr[0])
    # Translate plane into attribute index
    cache = {}
    for desired,conditions in rhs.items():
        if debug: print 'Processing:',desired
        for condition in conditions:
            path = condition['lhs']
            if debug: print 'From rule',condition['value']
            for index in range(len(path)):
                try:
                    attr = policy._attributes[str(path[index][0].getArray())]
                    value = path[index][1]
                except KeyError:
                    # Stored as the negated plane: flip the side
                    attr = policy._attributes[str(-path[index][0].getArray())]
                    assert policy.attributes[attr][1] == [0.]
                    value = 1 - path[index][1]
                if debug: print '\t%d,%d' % (attr,value)
                # Rewrites condition['lhs'] in place as (position,side) pairs
                path[index] = (attr,value)
    # Generate new rules for this table
    for desired,conditions in rhs.items():
        if debug: print 'Inserting:',desired
        for condition in conditions:
            # Initialize attribute values
            factors = map(lambda i: None,range(len(policy.attributes)))
            path = []
            # Override defaults with LHS values
            for attr,value in condition['lhs']:
                assert isinstance(attr,int)
                assert factors[attr] is None
                factors[attr] = value
                path.append((attr,value))
            if debug: print 'Original:',condition['value'],policy.factorString(factors)
            for attr in range(len(policy.attributes)):
                if factors[attr] is None:
                    test = policy._consistent(attr,1,path)
                    if debug: print '\t',getProbRep(policy.attributes[attr][0],1),test
                    if test is None:
                        # Possibly satisfied
                        factors[attr] = [0,1]
                    elif test:
                        # Always satisfied
                        factors[attr] = 1
                        path.append((policy.attributes[attr][0],1))
                    else:
                        # Never satisfied
                        factors[attr] = 0
                        path.append((policy.attributes[attr][0],0))
            if debug: print '\t%s' % (policy.factorString(factors))
            # Insert RHS and value into specified rule
            for index in policy.factored2index(factors):
                if debug: print index,
                assert not policy.rules.has_key(index)
                policy.rules[index] = desired
                policy.values[index] = self.values[condition['value']]
            if debug: print
    return policy
718
def star(self):
    """Computes the optimal value function, independent of action.

    Works on a copy from C{getTable}: each row's rule (an option) is replaced
    by that option's entry in the row's value function, looked up by
    C{str(option)}, and the row's value dict is then emptied.
    @return: a table with the optimal value as the rules' RHS, and no values
    @rtype: L{PWLTable}
    """
    table = self.getTable()
    for rule,option in list(table.rules.items()):
        table.rules[rule] = table.values[rule][str(option)]
        table.values[rule].clear()
    return table
730
def __copy__(self):
    """Support C{copy.copy}: build a blank instance of the same class and fill it via C{copy}."""
    blank = self.__class__()
    return self.copy(blank)
734
def copy(self,result):
    """Copy this table's contents into C{result} and return it.

    Attributes get fresh threshold lists, so a later C{addAttribute} on either
    table (which appends to those lists in place) cannot leak into the other.
    Each row's value dict is copied as well.  C{zeroPlanes} and the
    C{_attributes} position index are carried over; without them a copy of a
    table with off-origin planes would claim C{zeroPlanes=True}.
    @param result: the (normally empty) table to fill in
    @type result: L{PWLTable}
    @return: C{result}
    @rtype: L{PWLTable}
    """
    result.attributes = [(obj,values[:]) for obj,values in self.attributes]
    result.zeroPlanes = self.zeroPlanes
    result._attributes.clear()
    result._attributes.update(self._attributes)
    result.rules.update(self.rules)
    for key,table in self.values.items():
        result.values[key] = dict(table)
    return result
742
def __len__(self):
    """Return the number of rows in the full table.

    The row count is the product, over all attributes, of the number of
    sub-indices: a hyperplane attribute (L{KeyedVector}) with k thresholds
    has k+1 intervals, any other attribute has one sub-index per listed value.
    """
    total = 1
    for obj,values in self.attributes:
        extra = 1 if isinstance(obj,KeyedVector) else 0
        total *= len(values) + extra
    return total
751
def __add__(self,other,debug=False):
    """Return a new table whose value function is the sum of both tables' value functions.

    If both tables use only hyperplanes through the origin, the work is
    delegated to C{mergeZero}.  Otherwise the result's attributes are the
    union of both tables' attributes, and every pair of rows for which
    C{result.mapIndex} returns a list of factors contributes
    C{myRHS + yrRHS} for each option of C{other}'s row.
    @param other: the addend
    @type other: L{PWLTable}
    @param debug: if C{True}, print a trace
    @type debug: bool
    @rtype: L{PWLTable}
    @warning: only value functions are combined (rules are not transferred),
    and every option in C{other}'s rows must also appear in C{self}'s rows
    """
    if self.zeroPlanes and other.zeroPlanes:
        return self.mergeZero(other,lambda x,y: x+y,None,debug)
    result = PWLTable()
    if debug:
        print('I: %s' % (self,))
        print('U: %s' % (other,))
    # Start with addend's attributes (own copies of the threshold lists)
    for obj,values in other.attributes:
        result.attributes.append((obj,values[:]))
    # Insert mine in as well
    for obj,values in self.attributes:
        index = result.addAttribute(obj,values[0])
        thresholds = result.attributes[index][1]
        for value in values[1:]:
            if value not in thresholds:
                thresholds.append(value)
        thresholds.sort()
    # Thresholds copied directly bypass addAttribute's bookkeeping, so derive
    # the flag from the inputs (in this branch at least one has off-origin planes)
    result.zeroPlanes = self.zeroPlanes and other.zeroPlanes
    result.initialize()
    if debug:
        print('New attributes:')
        for attr in result.attributes:
            print('\t%s' % (attr[0].getArray(),))
    # Transfer RHS
    for myRule in self.values:
        myFactors = self.index2factored(myRule)
        for yrRule in other.values:
            yrFactors = other.index2factored(yrRule)
            # Compute new rule index
            newFactors = result.mapIndex(other,yrFactors)
            newFactors = result.mapIndex(self,myFactors,newFactors)
            if not isinstance(newFactors,list):
                # No consistent mapping for this pair of rows
                continue
            indexList = result.factored2index(newFactors)
            for option,yrRHS in other.values[yrRule].items():
                newRHS = self.values[myRule][option] + yrRHS
                for newIndex in indexList:
                    row = result.values.setdefault(newIndex,{})
                    assert option not in row
                    row[option] = newRHS
    return result
795
def __mul__(self,other,combiner=None,debug=False):
    """
    Compose this table with C{other} (applied first), combining their RHS.

    If both tables use only hyperplanes through the origin, the work is
    delegated to C{mergeZero}.  Otherwise the result starts from C{other}'s
    attributes; each of my attributes is projected through every RHS of
    C{other} (C{obj*rhs}, normalized) and kept only for thresholds whose
    outcome is not fixed (C{KeyedPlane.always} returns C{None}).  Each
    compatible pair of rows (per C{result.mapIndex}) stores
    C{combiner(myRHS,yrRHS)} per option, and irrelevant attributes are then
    pruned.
    @param combiner: optional binary function for using in combining RHS matrices (default is multiplication, duh)
    @type combiner: lambda
    @param debug: if C{True}, print a trace
    @type debug: bool
    @rtype: L{PWLTable}
    @warning: like matrix multiplication, not commutative
    """
    if self.zeroPlanes and other.zeroPlanes:
        return self.mergeZero(other,combiner,lambda x,y: x*y,debug)
    result = self.__class__()
    # Start with right multiplicand's LHS (own copies of the threshold lists)
    for obj,values in other.attributes:
        result.attributes.append((obj,values[:]))
    # Project my LHS
    for rule,V in other.values.items():
        if V:
            # Access all RHS in value function
            candidates = list(V.values())
        else:
            # No value function, take rule RHS
            candidates = [other.rules[rule]]
        for rhs in candidates:
            for obj,values in self.attributes:
                projected = obj*rhs
                projected.normalize()
                index = None
                for value in values:
                    if KeyedPlane(projected,value).always(probability=True) is not None:
                        # Outcome fixed for this threshold: not a real split
                        continue
                    if index is None:
                        index = result.addAttribute(projected,value)
                    elif value not in result.attributes[index][1]:
                        # Avoid duplicate thresholds when the attribute already existed
                        result.attributes[index][1].append(value)
                if index is not None:
                    result.attributes[index][1].sort()
    # Thresholds copied directly bypass addAttribute's bookkeeping; this is
    # conservative (in this branch at least one input has off-origin planes)
    result.zeroPlanes = self.zeroPlanes and other.zeroPlanes
    result.initialize()
    if debug:
        print('New attributes:')
        for attr in result.attributes:
            print('\t%s' % (attr[0].getArray(),))
    # Transfer RHS
    for myRule in self.values:
        myFactors = self.index2factored(myRule)
        for yrRule in other.values:
            yrFactors = other.index2factored(yrRule)
            for option,yrRHS in other.values[yrRule].items():
                # Compute new RHS
                try:
                    myRHS = self.values[myRule][option]
                except KeyError:
                    # No value function... use rules
                    myRHS = self.rules[myRule]
                if combiner:
                    newRHS = combiner(myRHS,yrRHS)
                else:
                    newRHS = myRHS*yrRHS
                if debug:
                    print('')
                    print('A:')
                    for index in range(len(self.attributes)):
                        print('%s %s' % (bool(myFactors[index]),self.attributes[index][0].getArray()))
                    print('B:')
                    for index in range(len(other.attributes)):
                        print('%s %s' % (bool(yrFactors[index]),other.attributes[index][0].getArray()))
                    print(option)
                    print('Product: %s' % (newRHS.getArray(),))
                # Compute new rule index
                newFactors = result.mapIndex(other,yrFactors,debug=debug)
                newFactors = result.mapIndex(self,myFactors,
                                             newFactors,yrRHS,debug=debug)
                if isinstance(newFactors,list):
                    # Consistent mapping found, so insert RHS
                    if debug:
                        print(newFactors)
                    indices = result.factored2index(newFactors)
                    if debug:
                        print(' '.join([str(newIndex) for newIndex in indices]))
                    for newIndex in indices:
                        row = result.values.setdefault(newIndex,{})
                        assert option not in row
                        row[option] = newRHS
                elif debug:
                    print('Rejected')
    result.pruneAttributes()
    return result
884
885 - def mergeZero(self,other,combiner=None,projector=None,debug=False):
886 """ 887 Merging when both tables have all of their hyperplanes going through the origin 888 @param combiner: optional binary function for using in combining RHS matrices (default is multiplication) 889 @type combiner: lambda 890 @param projector: optional binary function for using in projecting my LHS attributes based on the RHS of the other 891 @type projector: lambda 892 @warning: like matrix multiplication, not commutative 893 """ 894 # Build pairwise combinations of all rules 895 entries = [] 896 for yrRule,yrValue in other.values.items(): 897 yrFactors = other.index2factored(yrRule) 898 for myRule,myValue in self.values.items(): 899 myFactors = self.index2factored(myRule) 900 for option,yrRHS in yrValue.items(): 901 # Extract my RHS 902 try: 903 myRHS = self.values[myRule][option] 904 except KeyError: 905 # No value function... use rules 906 myRHS = self.rules[myRule] 907 if debug: 908 print 'Combining:',self.factorString(myFactors),myRHS.getArray() 909 print 'with:',yrFactors,yrRHS.getArray() 910 print 'under:',option 911 path = [] 912 for myIndex in range(len(self.attributes)): 913 # Project hyperplane against RHS of multiplicand 914 myAttr = self.attributes[myIndex][0] 915 if projector: 916 newAttr = projector(myAttr,yrRHS) 917 newAttr.normalize() 918 if debug: 919 print '\t\t\tProjecting:',myAttr.getArray(),myFactors[myIndex] 920 print '\t\t\tInto:',newAttr.getArray() 921 # Check whether this is a degenerate condition 922 threshold = solveTuple(newAttr) 923 if not isinstance(threshold,float): 924 if threshold == myFactors[myIndex]: 925 # Always satisfied 926 if debug: 927 print '\tRedundant:',newAttr.getArray() 928 newAttr = None 929 else: 930 # Never satisfiable 931 if debug: 932 print '\tInconsistent:',newAttr.getArray() 933 consistent = False 934 break 935 else: 936 newAttr = myAttr 937 if newAttr: 938 if debug: 939 print '\t\tChecking:',newAttr.getArray() 940 consistent = True 941 # Check consistency against existing settings 942 for yrIndex in 
range(len(other.attributes)): 943 yrAttr = other.attributes[yrIndex][0] 944 if detectConflict(newAttr,myFactors[myIndex], 945 yrAttr,yrFactors[yrIndex]): 946 # Conflict with other rule 947 consistent = False 948 break 949 for yrAttr,yrSide in path: 950 if detectConflict(newAttr,myFactors[myIndex], 951 yrAttr,yrSide): 952 # Conflict with other rule 953 consistent = False 954 break 955 else: 956 path.append((newAttr,myFactors[myIndex])) 957 if not consistent: 958 break 959 else: 960 # Consistent path found 961 entry = {'LHS': yrFactors + path,'option': option} 962 if combiner: 963 entry['RHS'] = combiner(myRHS,yrRHS) 964 else: 965 entry['RHS'] = myRHS*yrRHS 966 if debug: 967 print 'Path found:' 968 for vector,side in entry['LHS']: 969 print '\t',getProbRep(vector,side) 970 print '\t',entry['RHS'].getArray() 971 entries.append(entry) 972 # Extract entries found 973 result = self.__class__() 974 # Copy over attributes from other 975 attributes = {} 976 for obj,values in other.attributes: 977 result.addAttribute(obj,values[0]) 978 attributes[str(obj.getArray())] = True 979 # Add new attributes from each new rule 980 for entry in entries: 981 for attr,side in entry['LHS'][len(other.attributes):]: 982 key = str(attr.getArray()) 983 if not attributes.has_key(key): 984 result.addAttribute(attr,0.) 
985 attributes[key] = True 986 # Find indices for new set of attributes 987 for index in range(len(result.attributes)): 988 attributes[str(result.attributes[index][0].getArray())] = index 989 if debug: 990 print 'Attribute:',result.attributes[index][0].getArray() 991 # Convert entries into value function entries 992 for entry in entries: 993 if debug: 994 print 'New Entry:',entry['option'] 995 for vector,side in entry['LHS']: 996 print '\t',getProbRep(vector,side) 997 # Start with all wildcards 998 factors = [] 999 for attr in result.attributes: 1000 factors.append(None) 1001 # Set factors from other's rule index 1002 for index in range(len(other.attributes)): 1003 plane = other.attributes[index][0] 1004 side = entry['LHS'][index] 1005 factors[attributes[str(plane.getArray())]] = side 1006 # Set new factors 1007 for plane,side in entry['LHS'][len(other.attributes):]: 1008 factors[attributes[str(plane.getArray())]] = side 1009 # Insert new value entry 1010 indexList = result.factored2index(factors,check=True) 1011 if debug: 1012 print '\t',factors 1013 print '\t',indexList 1014 for index in indexList: 1015 if not result.values.has_key(index): 1016 result.values[index] = {} 1017 assert not result.values[index].has_key(entry['option']),\ 1018 '%d, %s' % (index,entry['option']) 1019 result.values[index][entry['option']] = entry['RHS'] 1020 return result
1021
1022 - def mapIndex(self,other,factors,result=None,multiplicand=None,debug=False):
1023 """Translates an index in another table into one for this table 1024 @param other: the other table 1025 @type other: L{PWLTable} 1026 @param factors: the index or factors of the rule to map 1027 @type factors: int or int[] 1028 @param result: previously determined factors that should be merged (default is C{None}) 1029 @type result: int[] 1030 @param multiplicand: matrix used to scale any attributes (default is identity) 1031 @type multiplicand: L{KeyedMatrix<teamwork.math.KeyedMatrix.KeyedMatrix>} 1032 @return: a list of attributes subindices, C{None} if no consistent index exists 1033 @rtype: int[] 1034 """ 1035 if result is None: 1036 result = map(lambda attr: None,self.attributes) 1037 if isinstance(factors,int): 1038 factors = other.index2factored(factors) 1039 for pos in range(len(factors)): 1040 obj,values = other.attributes[pos] 1041 assert values == [0.],'Unable to handle non-zero thresholds' 1042 if multiplicand: 1043 # Apply projection to LHS condition 1044 obj = obj*multiplicand 1045 obj.normalize() 1046 # Figure out interval of acceptable values for this attribute 1047 greater = factors[pos] == 1 1048 if debug: 1049 print '\tMapping:',obj.getArray() 1050 # Map interval into new range of possible values 1051 try: 1052 index = self._attributes[str(obj.getArray())] 1053 except KeyError: 1054 try: 1055 index = self._attributes[str(-obj.getArray())] 1056 greater = not greater 1057 except KeyError: 1058 # Check whether we've generated a degenerate condition 1059 assert values == [0.] 1060 plane = KeyedPlane(obj,0.) 
1061 always = plane.always(probability=True) 1062 if always is None: 1063 raise UserWarning,str(plane) 1064 elif always: 1065 # This condition is always true 1066 if greater: 1067 continue 1068 else: 1069 if debug: 1070 print '\t',always 1071 return None 1072 else: 1073 # This condition will never be met 1074 if greater: 1075 if debug: 1076 print '\t',always 1077 return None 1078 else: 1079 continue 1080 obj,values = self.attributes[index] 1081 assert values == [0.],'Unable to handle non-zero thresholds' 1082 if result[index] is None: 1083 # Check consistency 1084 if index < 0: 1085 always = self._consistent(index+len(self.attributes),greater, 1086 map(lambda i: (i,result[i]),range(len(self.attributes)))) 1087 else: 1088 always = self._consistent(index,greater, 1089 map(lambda i: (i,result[i]),range(len(self.attributes)))) 1090 if always is False: 1091 return None 1092 # Setting attribute value fresh 1093 if greater: 1094 result[index] = 1 1095 else: 1096 result[index] = 0 1097 else: 1098 # Merge with existing attribute 1099 if greater: 1100 if result[index] == 0: 1101 # Mismatch 1102 return None 1103 elif result[index] == 1: 1104 # Mismatch 1105 return None 1106 return result
1107
def __str__(self):
    """Renders the rule table when any rules exist, otherwise the value function
    @rtype: str
    """
    if self.rules:
        # Print out rules
        return self._ruleString()
    # Print out value function
    return self.valueString()
1115
def _ruleString(self):
    """Helper method that returns a string representation of the rules, one row per rule index in ascending order
    @rtype: str
    """
    lines = [self._attributeHeader()]
    for rule in sorted(self.rules):
        rhs = self.rules[rule]
        if isinstance(rhs,KeyedVector):
            rhs = rhs.getArray()
        # rhs is formatted as one tuple element, so a tuple-valued RHS is not unpacked by %
        lines.append('%5d%s\t%s' % (rule,self.factorString(rule),rhs))
    return '\n'.join(lines)
1131
def valueString(self):
    """Helper method that renders the value function as text: one row per (rule, option) pair, both in ascending order
    @rtype: str
    """
    lines = [self._attributeHeader() + '\tValue']
    for rule in sorted(self.values):
        lhs = self.factorString(rule)
        table = self.values[rule]
        for key in sorted(table):
            lines.append('%5d%s\t%s\t%s' % (rule,lhs,key,table[key].getArray()))
    return '\n'.join(lines)
1148
def _attributeHeader(self,rhsLabel='Action'):
    """Helper method that returns column headings for the attributes
    @param rhsLabel: column heading to use for RHS (default is 'Action')
    @type rhsLabel: str
    @rtype: str
    """
    # attrString() results are concatenated without any added separator
    headings = ['Index'] + [attrString(obj) for obj,values in self.attributes]
    # Previously hard-coded to 'Action', ignoring rhsLabel
    headings.append('\t%s' % (rhsLabel,))
    return ''.join(headings)
1160
def factorString(self,factors):
    """Helper method that returns string representation of factor tuple
    For 2-dimensional attributes with a single zero threshold, the first one
    whose factor is 0 (or the last such attribute) determines the whole
    string. Other attributes contribute one interval each: factor C{k} means
    between C{values[k-1]} and C{values[k]}.
    @param factors: factors (or rule index); a C{None} factor is a wildcard
    @type factors: int or int[]
    @rtype: str
    """
    if not isinstance(factors,list):
        factors = self.index2factored(factors)
    lhs = ''
    last = len(self.attributes)-1
    for attr in range(len(self.attributes)):
        plane,values = self.attributes[attr]
        factor = factors[attr]
        if len(plane) == 2 and values == [0.]:
            if factor == 0 or attr == last:
                return getProbRep(plane,factor)
        elif factor is None:
            lhs += '\t*'
        elif factor == 0:
            lhs += '\t<=%8.3f' % (values[0])
        elif factor == len(values):
            lhs += '\t >%8.3f' % (values[-1])
        else:
            # Interior interval: values[factor-1] < x <= values[factor]
            # (previously indexed factors[attr-1], another attribute's factor)
            lhs += '\t<=%5.3f,>%5.3f' % (values[factor],values[factor-1])
    return lhs
1187
def attrString(attr):
    """
    @param attr: an LHS condition (a L{KeyedVector}), or an object with a C{name}
    @return: a happy string representation of the given attribute, starting with its own separator (space or tab) so headings can be concatenated
    @rtype: str
    """
    if not isinstance(attr,KeyedVector):
        return ' %s action' % (attr.name,)
    if len(attr) == 2:
        return getProbRep(attr,True)
    keys = [key for key in attr.keys() if abs(attr[key]) > epsilon]
    if len(keys) == 2:
        return getArrayRep(attr[keys[0]],attr[keys[1]],True)
    elif len(keys) == 1:
        # Leading space, like every other branch (headings were running together)
        return ' %s>0.' % (keys[0],)
    else:
        return '\t'+','.join(['%6.4f' % (x,) for x in attr.getArray()])
1206
def getProbRep(vector,side=True):
    """
    @param vector: a 2-dimensional vector (a,b), read as the probabilistic tuple ax + by
    @param side: which side of the constraint to render (see L{getArrayRep})
    @return: for probabilistic tuples, returns a unary constraint representation of this vector
    @rtype: str
    """
    array = vector.getArray()
    return getArrayRep(array[0],array[1],side)
1215
def getArrayRep(a,b,side=True):
    """
    @param a,b: coefficients of the probabilistic tuple a*L + b*(1-L)
    @param side: C{True} for the C{> 0} side, C{False} for C{<= 0}, C{None} if unknown
    @return: for binary array, returns a unary constraint representation of this vector
    @rtype: str
    """
    # Probabilistic tuple: ax + by > 0.
    threshold = solveTuple(a,b)
    if side is not None and a-b < 0.:
        # Solving (a-b)L > -b divides by a negative number, flipping the inequality.
        # An unknown side stays unknown (previously it became True here).
        side = not side
    if side is None:
        sign = '??'
    elif side:
        sign = '> '
    else:
        sign = '<='
    return ' %s%s%5.3f' % ('L',sign,threshold)
1232
def solveTuple(a,b=None):
    """Solves a 2-dimensional vector for one of the variables
    @param a: the vector ax + by (a L{KeyedVector}) or, if C{b} is given, its first coefficient
    @param b: the second coefficient (default is C{None}, i.e., take both from C{a.getArray()})
    @return: -b/(a-b) if a!=b; otherwise, C{True} iff b>0
    @rtype: float or bool
    """
    # Solve one for the other
    if b is None:
        a,b = a.getArray()
    denominator = a-b
    if denominator == 0:
        # Explicit test: numpy scalars return inf/nan here instead of raising,
        # which a bare except could never catch
        return bool(b > 0.)
    # float() avoids integer floor division for int coefficients under Python 2
    return -float(b)/denominator
1247
def detectConflict(vector1,side1,vector2,side2):
    """Detects whether there is a conflict between two attribute-value pairs, where each attribute is a binary, 2-dimensional vector
    Each pair (vector,side) is read as the constraint a*L + b*(1-L) > 0 (or
    <= 0 when side is false). That bounds L from below when a > b and from above
    when a < b, the same orientation rule used by L{getArrayRep}.
    @param vector1,vector2: the attribute hyperplanes
    @type vector1,vector2: L{KeyedVector}
    @param side1,side2: true (or 1) for the C{>} side of the corresponding hyperplane
    @type side1,side2: bool
    @return: C{True} if there is a conflict
    @rtype: bool
    @raise NotImplementedError: unless both vectors are 2-dimensional
    """
    if len(vector1) == 2 and len(vector2) == 2:
        weight1 = solveTuple(vector1)
        weight2 = solveTuple(vector2)
        lower1 = _isLowerBound(vector1,side1)
        lower2 = _isLowerBound(vector2,side2)
        if lower1 and not lower2:
            # L > w1, L <= w2
            return bool(weight1 > weight2-epsilon)
        elif lower2 and not lower1:
            # L <= w1, L > w2
            return bool(weight2 > weight1-epsilon)
        return False
    elif len(vector1) == 1 and len(vector2) == 1:
        raise NotImplementedError('I should be able to do this, but my creator is lazy')
    else:
        raise NotImplementedError('Your %d-dimensional vectors frighten and confuse me' % (max(len(vector1),len(vector2))))

def _isLowerBound(vector,side):
    """@return: C{True} iff the condition (vector,side) requires L to lie above the vector's threshold
    @rtype: bool
    """
    a,b = vector.getArray()
    if a-b < 0.:
        # Negative coefficient on L flips the inequality
        return not side
    return bool(side)
1272