Package teamwork :: Package math :: Module KeyedVector
[hide private]
[frames] | [no frames]

Source Code for Module teamwork.math.KeyedVector

   1  """Classes for vectors with symbolic indices (i.e., L{Key} objects) 
   2  @var slopeTypes: dictionary of available classes for hyperplanes, indexed by appropriate labels 
   3  """ 
   4  try: 
   5      from numpy.core.numeric import array,dot,all,seterr 
   6      seterr(divide='raise') 
   7      try: 
   8          from numpy.core.numeric import matrixmultiply 
   9      except ImportError: 
  10          matrixmultiply = dot 
  11  except ImportError: 
  12      try: 
  13          from scipy import array,matrixmultiply,dot,all 
  14      except ImportError: 
  15          from Numeric import array,matrixmultiply,dot 
  16          from Numeric import alltrue as all 
  17  from xml.dom.minidom import * 
  18  import copy 
  19  from matrices import epsilon 
  20  from Keys import * 
  21   
class KeyedVector:
    """A dictionary-based representation of a one-dimensional vector

    The values live in a numeric array, positioned according to a sorted
    list of keys (C{_orderedKeys}) and a key-to-position map (C{_order}).
    Vectors derived from this one (via L{compose}, negation or
    multiplication) share those two containers with it, so no method here
    edits them in place: any change of dimensions installs fresh ones.
    @ivar _frozen: flag indicating whether the dimensions of this vector are subject to change
    @type _frozen: C{boolean}
    @ivar _order: the array position of each key
    @type _order: dict:L{Key}S{->}int
    @ivar _orderedKeys: the keys, in array order
    @type _orderedKeys: L{Key}[]
    @ivar _string: the string representation of this vector
    @type _string: C{str}
    @ivar _array: The numeric representation of this vector (C{None} for a vector created without values)
    @type _array: C{array}
    """

    def __init__(self, args=None):
        """
        @param args: the initial values, indexed by key (defaults to an empty vector)
        @type args: dict:L{Key}S{->}float
        """
        self._frozen = False
        self._order = {}
        if args is not None and len(args) > 0:
            # Store initial values, in sorted key order
            self._orderedKeys = sorted(args.keys())
            values = []
            for index, key in enumerate(self._orderedKeys):
                self._order[key] = index
                values.append(args[key])
            self._array = array(values)
        else:
            # Start with no values
            self._orderedKeys = []
            self._array = None  # This is somehow faster than array([])

    def keys(self):
        """
        @return: a consistently ordered list of keys
        @rtype: L{Key}[]
        """
        return self._orderedKeys

    def setArray(self):
        """Now deprecated because of irrelevance. Used to update the internal numeric representation based on the current dictionary contents
        @raise DeprecationWarning: always
        """
        raise DeprecationWarning('Calls to setArray should be unnecessary')

    def _updateString(self):
        """Updates the string representation of this vector, as needed to enforce consistent ordering (e.g., for hashing)
        @raise TypeError: if a value cannot be formatted as a float
        """
        parts = ['{']
        for key in self.keys():
            value = self[key]
            try:
                parts.append('\n\t%s: %5.3f' % (key, value))
            except TypeError:
                raise TypeError('Illegal vector value, %s (%s)' %
                                (value, type(value)))
        parts.append('\n}')
        self._string = ''.join(parts)

    def _columnValue(self, key, values):
        """Determines the value of a newly inserted slot (see L{addColumns})
        @param values: a single value for every slot, a dictionary of values indexed by key, or C{None} for 0.
        """
        if isinstance(values, dict):
            try:
                return values[key]
            except KeyError:
                return 0.
        elif values is None:
            return 0.
        else:
            return values

    def addColumns(self, keys, values=None):
        """Adds new slots to this vector; existing slots keep their values
        @param keys: the (sorted) keys for the new slots to insert
        @type keys: L{Key}[]
        @param values: the values to insert for each key (defaults to 0.)
        @type values: float, or dict:L{Key}S{->}float
        @warning: Assumes that keys are sorted!
        @raise UserWarning: if a frozen vector that already has an array would gain new slots
        """
        if self._array is None:
            # Build up array from scratch; all keys are missing
            self._orderedKeys = list(keys)
            self._order = {}
            for index, key in enumerate(self._orderedKeys):
                self._order[key] = index
            self._array = array([self._columnValue(key, values)
                                 for key in self._orderedKeys])
            return
        # Find out which keys are missing
        newKeys = [key for key in keys if key not in self._order]
        if not newKeys:
            # Nothing to do if there are no missing columns
            return
        if self._frozen:
            raise UserWarning('You are modifying a frozen vector')
        newValues = {}
        for key in newKeys:
            newValues[key] = self._columnValue(key, values)
        # Merge the old and new (both sorted) keys into a fresh layout
        oldKeys = self._orderedKeys
        finalKeys = []   # The ordered key list for the resulting vector
        finalArray = []  # The ordered value list for the resulting vector
        oldIndex = 0     # The pointer to the next old key to copy
        for newKey in newKeys:
            # Copy all the old keys that come before this new key
            while oldIndex < len(oldKeys) and oldKeys[oldIndex] < newKey:
                finalKeys.append(oldKeys[oldIndex])
                finalArray.append(self._array[oldIndex])
                oldIndex += 1
            finalKeys.append(newKey)
            finalArray.append(newValues[newKey])
        # Copy the old keys that come after every new key
        while oldIndex < len(oldKeys):
            finalKeys.append(oldKeys[oldIndex])
            finalArray.append(self._array[oldIndex])
            oldIndex += 1
        finalOrder = {}
        for index, key in enumerate(finalKeys):
            finalOrder[key] = index
        # Install new containers rather than editing the (possibly shared) old ones
        self._order = finalOrder
        self._orderedKeys = finalKeys
        self._array = array(finalArray)

    def getArray(self):
        """
        @return: the numeric array representation of this vector
        @rtype: C{array}
        """
        if self._array is None:
            return array([])
        return self._array

    def __getitem__(self, key):
        """@raise KeyError: if this vector has no slot for the given key"""
        return self.getArray()[self._order[key]]

    def __setitem__(self, key, value):
        """Sets the value of the given slot, adding the slot if necessary
        @type key: L{Key} instance
        @type value: C{float}"""
        try:
            index = self._order[key]
        except KeyError:
            self.addColumns([key], values={key: value})
        else:
            self.getArray()[index] = value

    def normalize(self):
        """Scales this vector so that the highest absolute weight is 1
        @raise ZeroDivisionError: if the vector is empty or all 0s
        """
        magnitudes = [abs(value) for value in self.getArray()]
        if magnitudes:
            factor = max(magnitudes)
        else:
            factor = 0.
        if factor > 1e-10:
            # Rebind instead of scaling in place: an in-place float
            # multiplication fails on an integer-typed array
            self._array = self._array * (1./factor)
        else:
            raise ZeroDivisionError

    def __len__(self):
        return len(self._order)

    def items(self):
        """
        @return: the (key, value) pairs of this vector, in key order
        @rtype: list
        """
        return [(key, self[key]) for key in self.keys()]

    def __eq__(self, other):
        """Two frozen vectors (assumed aligned) are equal if their values agree within 1e-10; otherwise both key layout and values must match exactly"""
        if not isinstance(other, KeyedVector):
            return NotImplemented
        if self._frozen and other._frozen:
            diff = sum(map(abs, self.getArray() - other.getArray()))
            return diff < 1e-10
        # Check the layout first, so that vectors of different dimensions
        # never reach the element-wise array comparison
        if self._order != other._order:
            return False
        return bool(all(self.getArray() == other.getArray()))

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __delitem__(self, key):
        """Removes the given slot
        @raise UserWarning: if this vector is frozen
        @raise KeyError: if this vector has no slot for the given key
        """
        if self._frozen:
            raise UserWarning('You are modifying a frozen vector')
        index = self._order[key]
        values = self._array.tolist()
        self._array = array(values[:index]+values[index+1:])
        # Build new containers: the current ones may be shared with other vectors
        self._orderedKeys = self._orderedKeys[:index] + \
            self._orderedKeys[index+1:]
        order = {}
        for position, other in enumerate(self._orderedKeys):
            order[other] = position
        self._order = order

    def has_key(self, key):
        return key in self._order

    def fill(self, keys, value=None):
        """Fills in any missing slots with a default value (it's really just a call to L{addColumns} now)
        @param keys: the slots that should be filled
        @type keys: list of L{Key} instances
        @param value: the default value (defaults to 0)
        @note: does not overwrite existing values
        """
        self.addColumns(keys, value)

    def _instantiateArgs(self, table):
        """Maps each entry onto its instantiated key(s), summing the values of entries that land on the same key and dropping entries whose key instantiates to C{keyDelete}
        @rtype: dict:L{Key}S{->}float
        """
        args = {}
        for key, value in self.items():
            newKeys = key.instantiate(table)
            if not isinstance(newKeys, list):
                if newKeys == keyDelete:
                    newKeys = []
                else:
                    newKeys = [newKeys]
            for newKey in newKeys:
                try:
                    args[newKey] += value
                except KeyError:
                    args[newKey] = value
        return args

    def instantiate(self, table):
        """Substitutes values for any abstract references, using the given substitution table
        @param table: dictionary of key-value pairs, where the value will be substituted for any appearance of the given key in a field of this L{Key} object
        @type table: dictionary
        @rtype: L{KeyedVector}
        """
        return self.__class__(self._instantiateArgs(table))

    def instantiateKeys(self, table):
        """Substitutes values for any abstract references in place, using the given substitution table
        @param table: dictionary of key-value pairs, where the value will be substituted for any appearance of the given key in a field of this L{Key} object
        @type table: dictionary"""
        args = self._instantiateArgs(table)
        # Reset the vector, using a new key map instead of clearing the
        # old one (which may be shared with other vectors)
        self._order = {}
        self._orderedKeys = sorted(args.keys())
        values = []
        for index, key in enumerate(self._orderedKeys):
            self._order[key] = index
            values.append(args[key])
        self._array = array(values)
        self._updateString()

    def _withArray(self, values):
        """Creates a new L{KeyedVector} with the given values, sharing this vector's key layout and copying its frozen status
        @type values: C{array}
        @rtype: L{KeyedVector}
        """
        result = KeyedVector()
        result._order = self._order
        result._orderedKeys = self._orderedKeys
        result._array = values
        if self._frozen:
            result.freeze()
        return result

    def compose(self, other, op):
        """Composes the two vectors together using the given operator
        @param other: the other vector to compose with
        @type other: L{KeyedVector} instance
        @param op: the operator used to generate the new array values
        @type op: C{lambda x,y:f(x,y)} where C{x} and C{y} are C{array} instances
        @rtype: a new L{KeyedVector} instance
        @warning: assumes that the two vectors are aligned"""
        return self._withArray(op(self.getArray(), other.getArray()))

    def freeze(self):
        """Locks in the dimensions and keys of this vector. A frozen vector leads to faster math.
        """
        self._frozen = True

    def unfreeze(self):
        """Unlocks the dimensions and keys of this vector, giving it its own copies of the key containers
        @return: C{True} iff the vector was originally frozen
        @rtype: bool
        """
        if self._frozen:
            self._order = copy.copy(self._order)
            self._orderedKeys = copy.copy(self._orderedKeys)
            self._frozen = False
            return True
        return False

    def __add__(self, other):
        """
        @warning: assumes that your vectors are aligned
        """
        return self.compose(other, lambda x, y: x+y)

    def __sub__(self, other):
        return self + (-other)

    def __neg__(self):
        return self._withArray(-self.getArray())

    def __mul__(self, other):
        """
         - If other is a L{KeyedVector}, then the result is the dot product
         - If other is a L{KeyedMatrix}, then the result is product of this vector, transposed, by the matrix
         - Otherwise, each element in this vector is scaled by other
        @raise UserWarning: if two frozen vectors, or this vector and a matrix, are not aligned
        """
        if isinstance(other, KeyedVector):
            # Dot product
            if self._frozen and other._frozen:
                # Assume that they are aligned
                try:
                    return dot(self.getArray(), other.getArray())
                except ValueError:
                    # Generate helpful error message
                    missing = ['"%s"' % (str(key)) for key in self.keys()
                               if not other.has_key(key)]
                    extra = ['"%s"' % (str(key)) for key in other.keys()
                             if not self.has_key(key)]
                    msg = 'Multiplicand'
                    if len(missing) > 0:
                        msg += ' is missing %s' % (', '.join(missing))
                        if len(extra) > 0:
                            msg += ' and'
                    if len(extra) > 0:
                        msg += ' has extra %s' % (', '.join(extra))
                    raise UserWarning(msg)
            # Not aligned, so go key by key
            result = 0.
            for key in self.keys():
                try:
                    result += self[key]*other[key]
                except KeyError:
                    # Assume other[key] is 0
                    pass
            return result
        elif isinstance(other, dict):
            try:
                product = matrixmultiply(self.getArray(), other.getArray())
            except ValueError:
                # More helpful error message than simply "objects are not aligned"
                rowKeys = other.rowKeys()
                colKeys = other.colKeys()
                missing = []
                for key in self.keys():
                    if key not in rowKeys:
                        missing.append('row "%s"' % (str(key)))
                    if key not in colKeys:
                        missing.append('column "%s"' % (str(key)))
                extra = []
                for key in rowKeys:
                    if key not in self._order:
                        extra.append('row "%s"' % (str(key)))
                for key in colKeys:
                    if key not in self._order:
                        extra.append('column "%s"' % (str(key)))
                msg = 'Multiplicand has'
                if len(missing) > 0:
                    msg += ' missing %s' % (', '.join(missing))
                    if len(extra) > 0:
                        msg += ' and'
                if len(extra) > 0:
                    msg += ' extra %s' % (', '.join(extra))
                raise UserWarning(msg)
            return self._withArray(product)
        else:
            # Scale every element; copying preserves this vector's class
            result = copy.copy(self)
            result._array = self.getArray()*other
            result._order = self._order
            result._orderedKeys = self._orderedKeys
            if self._frozen:
                result.freeze()
            else:
                result.unfreeze()
            return result

    def __rmul__(self, other):
        """Multiplies each value of the dictionary C{other} by this vector, key by key; keys missing from C{other} keep this vector's own value
        @raise UserWarning: if C{other} is not a dictionary
        """
        if not isinstance(other, dict):
            raise UserWarning('Unable to multiply %s by KeyedVector' %
                              (other.__class__.__name__))
        # Key by key multiplication
        result = KeyedVector()
        for key in self.keys():
            try:
                result[key] = self*other[key]
            except KeyError:
                # Assume no-change if missing (by convention)
                result[key] = self[key]
        return result

    def __str__(self):
        self._updateString()
        return self._string

    def simpleText(self, numbers=True, all=False):
        """
        @param numbers: currently has no effect; values are always printed as numbers
        @param all: if C{True}, include entries whose magnitude does not exceed C{epsilon}
        @return: a readable listing of this vector's entries
        @rtype: str
        """
        lines = ['{']
        for key in self.keys():
            value = self[key]
            if all or abs(value) > epsilon:
                lines.append('\n\t%s:\t%s' % (str(key), '%5.3f' % (value)))
        lines.append('\n}')
        return ''.join(lines)

    def __hash__(self):
        # Based on the string rendering, which rounds values to 3 decimals
        return hash(str(self))

    def __xml__(self):
        doc = Document()
        root = doc.createElement('vector')
        doc.appendChild(root)
        for key, value in self.items():
            node = doc.createElement('entry')
            root.appendChild(node)
            node.appendChild(key.__xml__().documentElement)
            node.setAttribute('weight', str(value))
        return doc

    def parse(self, element, changeInPlace=False):
        """Extracts the distribution from the given XML element
        @param element: The XML Element object specifying the vector
        @type element: Element
        @param changeInPlace: flag, if C{True}, then modify this vector itself; otherwise, return a new vector
        @type changeInPlace: boolean
        @return: the L{KeyedVector} instance"""
        assert element.tagName == 'vector'
        if not changeInPlace:
            # Determine what type of vector this is
            vectorType = str(element.getAttribute('type'))
            try:
                cls = globals()['%sRow' % (vectorType)]
            except KeyError:
                cls = self.__class__
            return cls().parse(element, True)
        node = element.firstChild
        while node:
            if node.nodeType == node.ELEMENT_NODE and node.tagName == 'entry':
                value = float(node.getAttribute('weight'))
                child = node.firstChild
                while child and child.nodeType != child.ELEMENT_NODE:
                    child = child.nextSibling
                key = Key()
                key = key.parse(child)
                try:
                    self[key] = value
                except Exception:
                    # Best effort: skip entries that cannot be inserted
                    print('Ignoring: %s' % child.toxml())
            node = node.nextSibling
        return self

    def __copy__(self):
        result = self.__class__()
        result._array = copy.copy(self.getArray())
        result._order = self._order
        result._orderedKeys = self._orderedKeys
        if self._frozen:
            result.freeze()
        else:
            # Force unfreeze() to give the copy its own key containers
            result._frozen = True
            result.unfreeze()
        return result

    def __deepcopy__(self, memo):
        result = KeyedVector()
        memo[id(self)] = result
        result._array = copy.deepcopy(self.getArray(), memo)
        result._order = self._order
        result._orderedKeys = self._orderedKeys
        if self._frozen:
            result.freeze()
        else:
            # Force unfreeze() to give the copy its own key containers
            result._frozen = True
            result.unfreeze()
        return result
class DeltaRow(KeyedVector):
    """Subclass for rows used to compute deltas in dynamics
    @cvar keyClass: the class that the delta key must be an instance of
    @cvar label: a readable name for this kind of row
    @ivar sourceKey: the feature to be changed
    @ivar deltaKey: the feature to use in computing the delta
    """
    keyClass = Key
    label = 'change'

    def __init__(self, args=None, sourceKey=None, deltaKey=None, value=0.):
        """
        @param args: additional initial values, indexed by key (only allowed together with C{sourceKey})
        @type args: dict:L{Key}S{->}float
        @param sourceKey: the feature to be changed
        @param deltaKey: the feature to use in computing the delta (defaults to a new instance of L{keyClass})
        @type sourceKey,deltaKey: L{Key}
        @param value: the coefficient for that feature
        @type value: C{float}
        @raise UserWarning: if C{args} is given without a C{sourceKey}
        """
        if args is None:
            args = {}
        self.sourceKey = sourceKey
        self.deltaKey = deltaKey
        if sourceKey is None and len(args) > 0:
            raise UserWarning('Use keyword arguments for typed row constructors')
        if self.sourceKey is None:
            KeyedVector.__init__(self, args)
        else:
            if self.deltaKey is None:
                # Default to a new key of the expected class
                self.deltaKey = self.keyClass()
            assert isinstance(self.deltaKey, self.keyClass)
            row = {self.sourceKey: 1.}
            # When source and delta coincide, the coefficient adds to the 1
            try:
                row[self.deltaKey] += value
            except KeyError:
                row[self.deltaKey] = value
            row.update(args)
            KeyedVector.__init__(self, row)

    def _deltaCoefficient(self):
        """
        @return: the C{value} that makes the constructor reproduce this row's delta coefficient (discounting the implicit 1 when source and delta keys coincide)
        @rtype: float
        """
        if self.sourceKey == self.deltaKey:
            return self[self.deltaKey]-1.
        return self[self.deltaKey]

    def _instantiateKey(self, key, table):
        """Instantiates a single key, which must not resolve to more than one key
        @raise UserWarning: if the key instantiates ambiguously
        """
        newKey = key.instantiate(table)
        if isinstance(newKey, list):
            if len(newKey) > 1:
                raise UserWarning('Unable to instantiate ambiguous %s: %s'
                                  % (self.__class__.__name__,
                                     self.simpleText()))
            newKey = newKey[0]
        return newKey

    def instantiate(self, table):
        """Substitutes values for any abstract references, using the given substitution table
        @rtype: L{DeltaRow}
        @note: only the source and delta entries are carried over
        """
        source = self._instantiateKey(self.sourceKey, table)
        delta = self._instantiateKey(self.deltaKey, table)
        # Use the key equality BEFORE instantiation to strip the implicit
        # 1. If distinct keys collapse onto one, the constructor then
        # correctly adds the delta coefficient to the source's 1.
        return self.__class__(sourceKey=source, deltaKey=delta,
                              value=self._deltaCoefficient())

    def instantiateKeys(self, table):
        """Substitutes values for any abstract references, using the given substitution table
        @param table: dictionary of key-value pairs, where the value will be substituted for any appearance of the given key in a field of this L{Key} object
        @type table: dictionary"""
        KeyedVector.instantiateKeys(self, table)
        self.sourceKey = self._instantiateKey(self.sourceKey, table)
        self.deltaKey = self._instantiateKey(self.deltaKey, table)

    def __delitem__(self, key):
        KeyedVector.__delitem__(self, key)
        if key == self.sourceKey:
            print('Deleting source key from %s!' % (self.__class__.__name__))
        elif key == self.deltaKey:
            self.deltaKey = self.sourceKey

    def __copy__(self):
        # Rebuilt from the typed parameters; other entries are not copied
        return self.__class__(sourceKey=self.sourceKey,
                              deltaKey=self.deltaKey,
                              value=self._deltaCoefficient())

    def __deepcopy__(self, memo):
        result = self.__class__(sourceKey=self.sourceKey,
                                deltaKey=self.deltaKey,
                                value=self._deltaCoefficient())
        memo[id(self)] = result
        # Check whether there are any other keys to be inserted
        for key in self.keys():
            if key != self.sourceKey and key != self.deltaKey:
                break
        else:
            # Nope
            return result
        result._array = copy.deepcopy(self.getArray(), memo)
        result._order = self._order
        result._orderedKeys = self._orderedKeys
        if self._frozen:
            result.freeze()
        else:
            # Force unfreeze() to give the copy its own key containers
            result._frozen = True
            result.unfreeze()
        return result

    def __xml__(self):
        doc = KeyedVector.__xml__(self)
        element = doc.documentElement
        node = doc.createElement('source')
        element.appendChild(node)
        node.appendChild(self.sourceKey.__xml__().documentElement)
        node = doc.createElement('delta')
        element.appendChild(node)
        node.appendChild(self.deltaKey.__xml__().documentElement)
        # Class name without its trailing "Row"
        element.setAttribute('type', self.__class__.__name__[:-3])
        return doc

    def _firstElement(self, node):
        """@return: the first element child of the given XML node, or C{None}"""
        child = node.firstChild
        while child and child.nodeType != child.ELEMENT_NODE:
            child = child.nextSibling
        return child

    def parse(self, element, changeInPlace=True):
        """Extracts the distribution from the given XML element
        @param element: The XML Element object specifying the vector
        @type element: Element
        @param changeInPlace: flag, if C{True}, then modify this vector itself; otherwise, return a new vector
        @type changeInPlace: boolean
        @return: the L{KeyedVector} instance"""
        if not changeInPlace:
            return KeyedVector.parse(self, element, False)
        KeyedVector.parse(self, element, True)
        # Fill in the missing bits from XML
        node = element.firstChild
        while node:
            if node.nodeType == node.ELEMENT_NODE:
                if node.tagName == 'source':
                    key = Key()
                    self.sourceKey = key.parse(self._firstElement(node))
                elif node.tagName == 'delta':
                    key = Key()
                    self.deltaKey = key.parse(self._firstElement(node))
                    self.value = self[self.deltaKey]
            node = node.nextSibling
        return self
class SetToConstantRow(DeltaRow):
    """A row that assigns a fixed value to a given feature
    @note: can possibly be abused to create a row that sets the value of one feature to a percentage of some other feature
    """
    label = 'set to constant'
    keyClass = ConstantKey

    def __init__(self, args={}, sourceKey=None,
                 deltaKey=keyConstant, value=0.):
        """
        @param sourceKey: the feature whose value gets replaced
        @param deltaKey: the L{Key} for the value column; should be omitted
        @type sourceKey: L{StateKey}
        @type deltaKey: L{ConstantKey}
        @param value: the value to assign
        @type value: C{float}
        """
        DeltaRow.__init__(self, args, sourceKey, deltaKey, value)
        if not self.sourceKey:
            return
        # The feature's old value plays no part in its new one
        self[self.sourceKey] = 0.

    def simpleText(self, numbers=True, all=False):
        constant = self[self.deltaKey]
        return 'set to %4.2f' % (constant)
718
class SetToFeatureRow(DeltaRow):
    """A row that assigns to a given feature a specific percentage of some
    other feature
    """
    label = 'set to feature'
    keyClass = StateKey

    def __init__(self, args={}, sourceKey=None, deltaKey=keyConstant, value=0.):
        """
        @param sourceKey: the feature whose value gets replaced
        @param deltaKey: the L{Key} for the feature supplying the new value
        @type sourceKey: L{StateKey}
        @type deltaKey: L{StateKey}
        @param value: the proportion of the C{deltaKey} feature to assign (shown as a percentage by L{simpleText})
        @type value: C{float}
        """
        DeltaRow.__init__(self, args, sourceKey, deltaKey, value)
        if not self.sourceKey:
            return
        # The feature's old value plays no part in its new one
        self[self.sourceKey] = 0.

    def simpleText(self, numbers=True, all=False):
        percentage = int(100.*self[self.deltaKey])
        description = self.deltaKey.simpleText()
        return 'set to %d%% of %s' % (percentage, description)
741
class IncrementRow(DeltaRow):
    """A row that adds a constant amount to the given feature"""
    label = 'add constant'
    keyClass = ConstantKey

    def __init__(self, args={}, sourceKey=None, deltaKey=keyConstant, value=0.):
        """
        @param sourceKey: the feature to be changed
        @param deltaKey: the L{Key} for the increment column; should be omitted
        @type sourceKey: L{StateKey}
        @type deltaKey: L{ConstantKey}
        @param value: the amount added (negative to decrease)
        @type value: C{float}
        """
        DeltaRow.__init__(self, args, sourceKey, deltaKey, value)

    def simpleText(self, numbers=True, all=False):
        amount = self[keyConstant]
        if amount < 0.:
            return 'decrease by %5.3f' % (-amount)
        return 'increase by %5.3f' % (amount)
763
class ScaleRow(DeltaRow):
    """A row that adds to the given feature a percentage of another feature"""
    label = 'add feature'
    keyClass = StateKey

    def simpleText(self, numbers=True, all=False):
        try:
            percent = self[self.deltaKey]
        except KeyError:
            # No delta column: fall back on the generic rendering
            return DeltaRow.simpleText(self, numbers=numbers, all=all)
        if len(self) == 1:
            # A single entry presumably means source and delta coincide, so
            # the stored coefficient includes the implicit 1 of the source
            percent -= 1.
        percent *= 100.
        if percent < 0.:
            return 'decrease by %d%% of %s' % (-int(percent),
                                               self.deltaKey.simpleText())
        return 'increase by %d%% of %s' % (int(percent),
                                           self.deltaKey.simpleText())
783
class ActionCountRow(DeltaRow):
    """A row that sets the value to a count of a given type of action
    """
    label = 'action count'
    keyClass = ActionKey

    def simpleText(self, numbers=True, all=False):
        share = int(100.*self[self.deltaKey])
        actionText = self.deltaKey.simpleText()
        return 'set to %d%% of # of %s' % (share, actionText)
793
class UnchangedRow(IncrementRow):
    """A row that leaves the given feature as it is"""
    label = 'no change'

    def __init__(self, args={}, sourceKey=None, deltaKey=keyConstant, value=0.):
        """
        @param sourceKey: the feature to be (not) changed
        @param deltaKey: the L{Key} for the increment column; should be omitted
        @type sourceKey: L{StateKey}
        @type deltaKey: L{ConstantKey}
        @param value: ignored; the increment is always 0
        @type value: C{float}
        """
        noIncrement = 0.
        IncrementRow.__init__(self, args, sourceKey, deltaKey, noIncrement)

    def simpleText(self, numbers=True, all=False):
        return 'no change'
811
def getDeltaTypes():
    """Collects the L{DeltaRow} subclasses currently present in this module's namespace
    @return: those subclasses (excluding L{DeltaRow} itself), indexed by class name minus its last three characters
    @rtype: C{dict:str->class}
    """
    import inspect
    found = {}
    for name, candidate in list(globals().items()):
        if not inspect.isclass(candidate):
            continue
        if candidate is DeltaRow or not issubclass(candidate, DeltaRow):
            continue
        found[name[:-3]] = candidate
    return found

deltaTypes = getDeltaTypes()
class SlopeRow(KeyedVector):
    """Subclass for rows used to represent the slope of planes
    @cvar args: list of keys for this test; each entry gives the C{type} of L{Key} and its default C{weight}
    @cvar threshold: the default threshold for a plane of this type
    @cvar relation: the default relation for a plane of this type
    @ivar specialKeys: the keys this test was built from, one per entry of C{args}
    """
    args = []
    threshold = None
    relation = None

    def __init__(self, args=None, keys=None):
        """
        @param args: additional values, indexed by key (these override the default weights)
        @type args: dict:L{Key}S{->}float
        @param keys: one key per entry of the C{args} class attribute; plain dictionaries are converted to the expected L{Key} class
        @type keys: list
        @raise IndexError: if fewer keys are given than this type of row expects
        """
        initial = {}
        self.specialKeys = []
        if keys is not None:
            for index, spec in enumerate(self.args):
                try:
                    key = keys[index]
                except IndexError:
                    raise IndexError('%s expects %d keys' %
                                     (self.__class__.__name__, len(self.args)))
                if key.__class__ is dict:
                    # Only plain dictionaries get converted
                    key = spec['type'](key)
                self.specialKeys.append(key)
                if key != keyDelete:
                    initial[key] = spec['weight']
        if args is not None:
            initial.update(args)
        KeyedVector.__init__(self, initial)

    def _instantiateKey(self, key, table):
        """Instantiates a single key, which must not resolve to more than one key
        @raise UserWarning: if the key instantiates ambiguously
        """
        newKey = key.instantiate(table)
        if isinstance(newKey, list):
            if len(newKey) > 1:
                raise UserWarning('Unable to instantiate ambiguous %s: %s'
                                  % (self.__class__.__name__,
                                     self.simpleText()))
            newKey = newKey[0]
        return newKey

    def instantiate(self, table):
        """Substitutes values for any abstract references, using the given substitution table
        @param table: dictionary of key-value pairs, where the value will be substituted for any appearance of the given key in a field of this L{Key} object
        @type table: dictionary
        @note: the new row is built from the instantiated special keys only"""
        keyList = [self._instantiateKey(key, table) for key in self.specialKeys]
        return self.__class__(keys=keyList)

    def instantiateKeys(self, table):
        """Substitutes values for any abstract references in place, using the given substitution table
        @param table: dictionary of key-value pairs, where the value will be substituted for any appearance of the given key in a field of this L{Key} object
        @type table: dictionary"""
        KeyedVector.instantiateKeys(self, table)
        for index, key in enumerate(self.specialKeys):
            self.specialKeys[index] = self._instantiateKey(key, table)

    def __copy__(self):
        result = self.__class__(keys=self.specialKeys)
        # Carry over the actual values, which may differ from the defaults
        # in self.args (e.g., after normalize() or item assignment)
        result._array = copy.copy(self._array)
        result._order = self._order
        result._orderedKeys = self._orderedKeys
        # The copy stays unfrozen, but gets its own key containers
        result._frozen = True
        result.unfreeze()
        return result

    def __deepcopy__(self, memo):
        result = self.__class__(keys=self.specialKeys)
        memo[id(self)] = result
        # Always copy the values: weights may differ from the defaults even
        # when there are no extra columns
        result._array = copy.deepcopy(self.getArray(), memo)
        result._order = self._order
        result._orderedKeys = self._orderedKeys
        if self._frozen and len(self) > len(self.args):
            # As before, only copies carrying extra columns stay frozen
            result.freeze()
        else:
            # Give the copy its own key containers
            result._frozen = True
            result.unfreeze()
        return result

    def __xml__(self):
        doc = KeyedVector.__xml__(self)
        element = doc.documentElement
        # Class name without its trailing "Row"
        element.setAttribute('type', self.__class__.__name__[:-3])
        for key in self.specialKeys:
            node = doc.createElement('slopeKey')
            element.appendChild(node)
            node.appendChild(key.__xml__().documentElement)
        return doc

    def _firstElement(self, node):
        """@return: the first element child of the given XML node, or C{None}"""
        child = node.firstChild
        while child and child.nodeType != child.ELEMENT_NODE:
            child = child.nextSibling
        return child

    def parse(self, element, changeInPlace=True):
        """Extracts the distribution from the given XML element
        @param element: The XML Element object specifying the vector
        @type element: Element
        @param changeInPlace: flag, if C{True}, then modify this vector itself; otherwise, return a new vector
        @type changeInPlace: boolean
        @return: the L{KeyedVector} instance"""
        if not changeInPlace:
            return KeyedVector.parse(self, element, False)
        KeyedVector.parse(self, element, True)
        # Fill in the missing bits from XML
        node = element.firstChild
        while node:
            if node.nodeType == node.ELEMENT_NODE and \
                    node.tagName == 'slopeKey':
                key = Key()
                self.specialKeys.append(key.parse(self._firstElement(node)))
            node = node.nextSibling
        # Patch bug in writing Equal Row: recover the keys from the entries
        if len(self.specialKeys) < len(self.args):
            node = element.firstChild
            while node:
                if node.nodeType == node.ELEMENT_NODE and \
                        node.tagName == 'entry':
                    key = Key()
                    self.specialKeys.append(key.parse(self._firstElement(node)))
                node = node.nextSibling
        assert len(self.specialKeys) == len(self.args)
        return self
951
class TrueRow(SlopeRow):
    """A vector whose hyperplane holds unconditionally

    Such a row is created with no keys at all:

    >>> row = TrueRow()
    """
    args = []

    def simpleText(self, numbers=True, all=False):
        placeholder = '...'
        return placeholder
963
class ThresholdRow(SlopeRow):
    """A vector that tests whether a given state feature exceeds a threshold. Build it as C{row = ThresholdRow(keys=[{'entity':entity,'feature':feature}])} to test the given C{feature} value of the given C{entity}

    For example, this row tests my power:

    >>> row = ThresholdRow(keys=[{'entity':'self','feature':'power'}])"""
    args = [{'type':StateKey,'weight':1.}]

    def simpleText(self, numbers=True, all=False):
        featureKey = self.specialKeys[0]
        return featureKey.simpleText()
974
class ClassRow(SlopeRow):
    """A vector that tests whether a given entity belongs to a given class. Build it as C{row = ClassRow(keys=[{'entity':entity,'value':cls}])} to test that the given C{entity} is a member of the given C{cls}

    For example, this row tests that the actor is a teacher:

    >>> row = ClassRow(keys=[{'entity':'actor','value':'Teacher'}])"""
    threshold = 0.
    args = [{'type':ClassKey,'weight':1.}]

    def simpleText(self, numbers=True, all=False):
        classKey = self.specialKeys[0]
        return classKey.simpleText()
987
988 -class RelationshipRow(SlopeRow):
989 """A vector for testing that a given entity has the specified relationship to another entity. It should be created as C{row = RelationshipRow(keys=[{'feature':relation,'relatee':entity}])} to create a test that C{entity} is a C{relation} of me 990 991 The following creates a row to test that the actor is my student 992 993 >>> row = RelationshipRow(keys=[{'feature':'student','relatee':'actor'}]) 994 """ 995 args = [{'type':RelationshipKey,'weight':1.}] 996 threshold = 0. 997
998 - def simpleText(self,numbers=True,all=False):
999 return self.specialKeys[0].simpleText()
1000
1001 -class IdentityRow(SlopeRow):
1002 """A vector for testing that a given entity is identical to another. It should be created as C{row = IdentityRow(keys=[{'entity':entity}])} to create a test that the entity being tested is the given C{entity} 1003 1004 The following creates a row to test that I am the object of the current action: 1005 1006 >>> row = IdentityRow(keys=[{'entity':'object','relationship':'equals'}])""" 1007 args = [{'type':IdentityKey,'weight':1.}] 1008 threshold = 0. 1009
1010 - def simpleText(self,numbers=True,all=False):
1011 return self.specialKeys[0].simpleText()
1012
1013 -class SumRow(SlopeRow):
1014 """A vector for testing that the sum of two given state features exceeds a threshold. 1015 @warning: not tested, probably doesn't work""" 1016 args = [{'type':StateKey,'weight':1.}, 1017 {'type':StateKey,'weight':1.}] 1018
1019 - def simpleText(self,numbers=True,all=False):
1020 key1,key2 = self.specialKeys 1021 return '%s + %s' % (key1.simpleText(),key2.simpleText())
1022
1023 -class DifferenceRow(SlopeRow):
1024 """A vector for testing that the difference between two given state features exceeds a threshold. 1025 @warning: not tested, probably doesn't work""" 1026 args = [{'type':StateKey,'weight':1.}, 1027 {'type':StateKey,'weight':-1.}] 1028
1029 - def simpleText(self,numbers=True,all=False):
1030 key1,key2 = self.specialKeys 1031 if self[key1] > 0.: 1032 pos = key1 1033 neg = key2 1034 else: 1035 pos = key2 1036 neg = key1 1037 return '%s - %s' % (pos.simpleText(),neg.simpleText())
1038
1039 -class EqualRow(SlopeRow):
1040 """A vector for testing that two given state features have the same value. 1041 @warning: not tested, probably doesn't work""" 1042 args = [{'type':StateKey,'weight':1.}, 1043 {'type':StateKey,'weight':-1.}] 1044 threshold = 0. 1045 relation = '=' 1046
1047 - def simpleText(self,numbers=True,all=False):
1048 key1,key2 = self.specialKeys 1049 if self[key1] > 0.: 1050 pos = key1 1051 neg = key2 1052 else: 1053 pos = key2 1054 neg = key1 1055 return '%s = %s' % (pos.simpleText(),neg.simpleText())
1056
1057 -class ANDRow(SlopeRow):
1058 """Subclass representing a conjunction of tests. 1059 1060 The following creates a row to test that the current state is both 'terminated' and 'accepted': 1061 1062 >>> row = ANDRow(keys=[{'entity':'self','feature':'terminated'},{'entity':'self','feature':'accepted'}]) 1063 """
1064 - def __init__(self,args={},keys=None):
1065 SlopeRow.__init__(self,args,keys) 1066 if keys: 1067 self.specialKeys = keys[:] 1068 else: 1069 self.specialKeys = []
1070
1071 - def simpleText(self,numbers=True,all=False):
1072 content = [] 1073 for key in self.specialKeys: 1074 if self[key] > 0.: 1075 content.append(key.simpleText()) 1076 else: 1077 content.append('not %s' % (key.simpleText())) 1078 return string.join(content,' and ')
1079
1080 -class ORRow(SlopeRow):
1081 """Subclass representing a disjunction of tests. 1082 1083 The following creates a row to test that the current state is either 'terminated' or 'accepted': 1084 1085 >>> row = ORRow(keys=[{'entity':'self','feature':'terminated'},{'entity':'self','feature':'accepted'}]) 1086 """
1087 - def __init__(self,args={},keys=None):
1088 SlopeRow.__init__(self,args,keys) 1089 if keys: 1090 self.specialKeys = keys[:] 1091 else: 1092 self.specialKeys = []
1093
1094 - def simpleText(self,numbers=True,all=False):
1095 content = [] 1096 for key in self.specialKeys: 1097 if self[key] > 0.: 1098 content.append(key.simpleText()) 1099 else: 1100 content.append('not %s' % (key.simpleText())) 1101 return string.join(content,' or ')
1102
1103 -def getSlopeTypes():
1104 """Automatic extraction of possible L{SlopeRow} subclasses 1105 @return: all available subclasses of L{SlopeRow} 1106 @rtype: C{dict:str->class} 1107 """ 1108 import inspect 1109 result = {} 1110 for key,value in globals().items(): 1111 if inspect.isclass(value) and issubclass(value,SlopeRow) and \ 1112 not value is SlopeRow: 1113 result[key[:-3]] = value 1114 # The following types are not easily supported by GUI 1115 del result['OR'] 1116 del result['AND'] 1117 # The following is not intended to be used, but is merely for display 1118 del result['True'] 1119 return result
1120 slopeTypes = getSlopeTypes() 1121 vectorTypes = copy.copy(deltaTypes) 1122 vectorTypes.update(slopeTypes) 1123 1124 if __name__ == '__main__': 1125 old = TrueRow() 1126 doc = old.__xml__() 1127 new = KeyedVector() 1128 new = new.parse(doc.documentElement) 1129 print new.__class__.__name__ 1130 print new.simpleText() 1131