Package teamwork :: Package math :: Module probability
[hide private]
[frames] | no frames]

Source Code for Module teamwork.math.probability

  1  import copy 
  2  import random 
  3  from xml.dom.minidom import * 
  4   
  5  from Keys import ConstantKey 
  6  from KeyedVector import KeyedVector 
  7   
def setitemAndReturn(table,key,value):
    """Helper method used by L{Distribution.join} method

    Produces a shallow copy of C{table} (of the same class, via
    C{copy.copy}) in which C{key} maps to C{value}; the table passed in
    is never modified.
    @param table: the dictionary to modify
    @type table: dictionary
    @param key: the entry to set in that dictionary
    @type key: hashable instance
    @param value: the value to stick into the dictionary
    @return: a I{copy} of the original dictionary with the specified key-value association inserted
    """
    duplicate = copy.copy(table)
    duplicate[key] = value
    return duplicate
20
class Distribution(dict):
    """
    A probability distribution

      - C{dist.L{domain}()}: Returns the domain of possible values
      - C{dist.L{items}()}: Returns the list of all (value,prob) pairs
      - C{dist[value]}: Returns the probability of the given value
      - C{dist[value] = x}: Sets the probability of the given value to x

    The possible domain values are any objects.  Each entry is stored in
    the underlying C{dict} under the key C{str(value)}, and the value
    itself is kept in the C{_domain} table under that same key; as a
    result, two values with the same string representation are treated as
    the same domain element.
    @warning: If you make the domain values mutable types, try not to change the values while they are inside the distribution.  If you must change a domain value, it is better to first delete the old value, change it, and then re-insert it.
    @note: Only the methods defined here translate domain values into string keys; inherited C{dict} methods (e.g., C{keys}, C{values}, C{in}, C{get}) operate on the string keys directly.
    @cvar epsilon: the granularity for float comparisons
    @type epsilon: float
    """
    epsilon = 0.0001

    def __init__(self,args=None):
        """
        @param args: optional initial contents: any object whose C{items()} method yields (value,probability) pairs, such as a C{dict} or another L{Distribution}
        """
        # str(element) -> element, so that domain values can be recovered
        # from the string keys used by the underlying dict
        self._domain = {}
        dict.__init__(self)
        if args is not None:
            for key,value in args.items():
                self[key] = value

    def __getitem__(self,element):
        """
        @param element: the domain element
        @return: the probability associated with the given element
        @raise KeyError: if the element is not in the domain
        """
        return dict.__getitem__(self,str(element))

    def __setitem__(self,element,value):
        """
        @param element: the domain element
        @param value: the probability to associate with the given key
        @type value: float
        @warning: raises an C{AssertionError} if setting to an invalid prob value (unless assertions are disabled, e.g., by running C{python -O})"""
        # Plain assert statements: the form assert(cond, msg) asserts a
        # non-empty tuple and therefore can never fail
        assert value > -self.epsilon, \
            "Negative probability value: %f" % (value)
        assert value < 1.+self.epsilon, \
            "Probability value exceeds 1: %f" % (value)
        key = str(element)
        self._domain[key] = element
        dict.__setitem__(self,key,value)

    def __delitem__(self,element):
        """Removes the given element (and its probability) from the domain
        @raise KeyError: if the element is not in the domain"""
        key = str(element)
        dict.__delitem__(self,key)
        del self._domain[key]

    def clear(self):
        """Removes every element (and probability) from this distribution"""
        dict.clear(self)
        self._domain.clear()

    def replace(self,old,new):
        """Replaces one element in the sample space with another, keeping its probability
        @param old: the element to be replaced
        @param new: the element to take its place
        @raise KeyError: if the original element does not exist
        @note: if the new element already exists, its probability is overwritten (i.e., this does not do a merge)
        """
        prob = self[old]
        del self[old]
        self[new] = prob

    def domain(self):
        """
        @return: the sample space of this probability distribution (a new list)
        @rtype: C{list}
        """
        return list(self._domain.values())

    def items(self):
        """
        @return: a list of tuples of value,probability pairs
        @rtype: (value,float)[]
        """
        return [(self._domain[key],dict.__getitem__(self,key))
                for key in self.keys()]

    def domainKeys(self):
        """
        @return: all keys contained in the domain values (assumes each domain value has a C{keys()} method)
        @rtype: C{dict:L{teamwork.math.Keys.Key}S{->}boolean}
        """
        keys = {}
        for element in self.domain():
            for key in element.keys():
                keys[key] = True
        return keys

    def normalize(self):
        """Normalizes the distribution (in place) so that the sum of values = 1

        If the values sum to zero, every element is instead given the
        uniform probability 1/n.
        @note: Not sure if this is really necessary"""
        total = sum(self.values())
        if abs(total-1.) > self.epsilon:
            for element,prob in self.items():
                try:
                    # float() avoids Python 2 integer (floor) division when
                    # the values are integer weights
                    self[element] = prob/float(total)
                except ZeroDivisionError:
                    self[element] = 1./float(len(self))

    def marginalize(self,key):
        """Marginalizes the distribution to remove the given key (not in place! returns the new distribution)
        @param key: the key to marginalize over
        @return: a new L{Distribution} object representing the marginal distribution
        @note: no exception is raised if the key is not present
        @note: assumes domain values support C{copy.copy}, C{unfreeze()}, and item deletion"""
        result = self.__class__()
        for row,prob in self.items():
            new = copy.copy(row)
            new.unfreeze()
            try:
                del new[key]
            except KeyError:
                pass
            # Rows that coincide once the key is removed are summed
            try:
                result[new] += prob
            except KeyError:
                result[new] = prob
        return result

    def getMarginal(self,key):
        """Marginalizes the distribution over all but the given key
        @param key: the key to compute the marginal distribution over
        @return: a new L{Distribution} object representing the marginal
        @note: rows lacking the key are counted as having value C{0.}"""
        result = self.__class__()
        for row,prob in self.items():
            try:
                value = row[key]
            except KeyError:
                # If no entry, then assume 0 value
                # (maybe there are domains where this is incorrect?)
                value = 0.
            try:
                result[value] += prob
            except KeyError:
                result[value] = prob
        return result

    def join(self,key,value,debug=False):
        """Returns the joint distribution that includes the given key (modifies this distribution in place)
        @param key: any hashable instance
        @param value: if a L{Distribution}, the marginal distribution for the given key; otherwise, the marginal distribution is assumed to be I{P(key=value)=1}
        @param debug: if C{True}, prints each composed entry (only when C{value} is a L{Distribution})
        @return: the joint distribution combining the current distribution with the specified marginal over the given key (i.e., this object)
        @warning: this method assumes that this L{Distribution} has domain values that are C{dict} instances (i.e., for each domain element C{e}, it can set C{e[key]=value})...in other words, there should probably be a subclass."""
        if isinstance(value,Distribution):
            # Cross every existing row with every value of the marginal,
            # producing copies of the rows with row[key] set to that value
            self.compose(value,lambda x,y,k=key:setitemAndReturn(x,k,y),
                         replace=True,debug=debug)
        else:
            original = self.items()
            self.clear()
            for row,prob in original:
                row[key] = value
                # Rows that become identical after the assignment (e.g.,
                # rows that differed only in this key) are merged rather
                # than overwriting one another
                try:
                    self[row] += prob
                except KeyError:
                    self[row] = prob
        return self

    def expectation(self):
        """Returns the expected value of this distribution

        @return: the sum of element*probability over the domain (the lone element itself if there is only one, C{None} if the distribution is empty)
        @warning: As a side effect, the distribution will be normalized (unless it has exactly one element)"""
        if len(self) == 1:
            # Shortcut if no uncertainty
            return self.domain()[0]
        # I suppose we could just assume that the distribution is already
        # normalized
        self.normalize()
        total = None
        for element,prob in self.items():
            if total is None:
                total = element*prob
            else:
                total += element*prob
        return total

    def prune(self):
        """Removes any zero-probability entries (within L{epsilon}) from this distribution
        @return: the pruned distribution (not a copy)"""
        for element,prob in self.items():
            if abs(prob) < self.epsilon:
                del self[element]
        return self

    def fill(self,keys,value=0.):
        """Fills in any missing rows/columns in the domain matrices with a default value
        @param keys: the new slots that should be filled
        @type keys: C{L{teamwork.math.Keys.Key}[]}
        @param value: the default value (default is 0.)
        @note: essentially calls appropriate C{fill} method for any domain objects; elements that become identical once filled are merged (their probabilities are summed)
        """
        original = self.items()
        self.clear()
        for element,prob in original:
            element.fill(keys,value)
            try:
                self[element] += prob
            except KeyError:
                self[element] = prob

    def freeze(self):
        """Locks in the dimensions and keys of all domain values"""
        for element in self.domain():
            element.freeze()

    def unfreeze(self):
        """Unlocks in the dimensions and keys of all domain values"""
        for element in self.domain():
            element.unfreeze()

    def instantiate(self,table):
        """Substitutes values for any abstract references, using the
        given substitution table (not in place! returns a new distribution)
        @param table: dictionary of key-value pairs, where the value will be substituted for any appearance of the given key in a field of this L{teamwork.math.Keys.Key} object
        @type table: dictionary
        @return: a new distribution over the instantiated elements; elements that collide after instantiation have their probabilities summed"""
        result = self.__class__()
        for key,element in self._domain.items():
            prob = dict.__getitem__(self,key)
            new = element.instantiate(table)
            if key == str(element):
                # Make sure new key matches new element
                key = str(new)
            result._domain[key] = new
            try:
                prob += dict.__getitem__(result,key)
            except KeyError:
                pass
            dict.__setitem__(result,key,prob)
        return result

    def instantiateKeys(self,table):
        """Substitutes values for any abstract references (in place), using the
        given substitution table
        @param table: dictionary of key-value pairs, where the value will be substituted for any appearance of the given key in a field of this L{teamwork.math.Keys.Key} object
        @type table: dictionary
        @note: elements that collide after substitution have their probabilities summed (as in L{instantiate})"""
        # Snapshot first: the entries are re-keyed below, so iterating the
        # live tables would both fail under Python 3 and let re-keyed
        # entries clobber ones not yet processed
        original = [(key,element,dict.__getitem__(self,key))
                    for key,element in self._domain.items()]
        dict.clear(self)
        self._domain.clear()
        for key,element,prob in original:
            # Only entries keyed by their own string form are re-keyed;
            # entries stored under some other key keep that key
            update = key == str(element)
            element.instantiateKeys(table)
            if update:
                key = str(element)
            self._domain[key] = element
            try:
                prob += dict.__getitem__(self,key)
            except KeyError:
                pass
            dict.__setitem__(self,key,prob)

    def compose(self,other,operator,replace=False,debug=False):
        """Composes this distribution with the other given, using the given op
        @param other: a L{Distribution} object, or an object of the same class as the keys in this Distribution object
        @param operator: a binary operator applicable to the class of keys in this L{Distribution} object
        @param replace: if this flag is true, the result modifies this distribution itself
        @param debug: if C{True}, prints each pair of composed elements, the resulting element, and its accumulated probability
        @return: the composed distribution"""
        original = self.items()
        # Snapshot the other operand before any clearing, so that composing
        # a distribution with itself in place still sees its entries
        if isinstance(other,Distribution):
            otherItems = other.items()
        else:
            otherItems = None
        if replace:
            result = self
            self.clear()
        else:
            result = self.__class__()
        for key1,value1 in original:
            if otherItems is None:
                key = operator(key1,other)
                try:
                    result[key] += value1
                except KeyError:
                    result[key] = value1
            else:
                for key2,value2 in otherItems:
                    key = operator(key1,key2)
                    prob = value1*value2
                    try:
                        result[key] += prob
                    except KeyError:
                        result[key] = prob
                    if debug:
                        print('%s %s' % (key1,key2))
                        print('\t-> %s' % (key,))
                        print('\t= %s' % (result[key],))
        return result

    def __add__(self,other):
        """
        @note: Also supports + operator between Distribution object and objects of the same class as its keys"""
        return self.compose(other,lambda x,y:x+y)

    def __neg__(self):
        """@return: a new distribution over the negated domain elements
        @raise UserWarning: if any domain element is false-valued (e.g., 0)"""
        result = self.__class__()
        for element,prob in self.items():
            if not element:
                # NOTE(review): false-valued elements (e.g., 0 or an empty
                # vector) are rejected; the intent is unclear -- confirm
                # before relaxing this check
                raise UserWarning('Unable to negate domain element: %s'
                                  % (element,))
            result[-element] = prob
        return result

    def __sub__(self,other):
        """@note: Also supports - operator between L{Distribution} object and objects of the same class as its keys"""
        return self + (-other)

    def __mul__(self,other):
        """@note: Also supports * operator between L{Distribution} object and objects of the same class as its keys"""
        return self.compose(other,lambda x,y:x*y)

    def __div__(self,other):
        """Division by a L{Distribution} computes the conditional distribution (see L{conditional}); division by anything else scales every domain element by C{1./other}"""
        if isinstance(other,Distribution):
            return self.conditional(other,{})
        else:
            return self * (1./other)

    # The / operator under true division (Python 3, or __future__.division)
    __truediv__ = __div__

    def conditional(self,other,value=None):
        """Computes a conditional probability, given this joint probability I{P(AB)}, the marginal probability I{P(B)}, and the value of I{B} being conditioned on
        @param other: the marginal probability, I{P(B)}
        @type other: L{Distribution}
        @param value: the value of I{B}
        @type value: L{KeyedVector} (if omitted, it's assumed that both I{P(AB)} and I{P(B)} have already been conditioned on the desired value)
        @return: I{P(A|B=C{value})} where C{self} is I{P(AB)}
        @rtype: L{Distribution}
        """
        if value is None:
            value = {}
        result = {}
        for myValue,myProb in self.items():
            for yrValue,yrProb in other.items():
                # Skip marginal rows inconsistent with the conditioning value
                for key in value.keys():
                    if not yrValue.has_key(key) \
                           or yrValue[key] != value[key]:
                        break
                else:
                    # Skip joint rows inconsistent with this marginal row
                    for key in yrValue.keys():
                        if not myValue.has_key(key) \
                               or myValue[key] != yrValue[key]:
                            break
                    else:
                        # Drop B's (non-constant) keys, leaving the A part
                        new = copy.copy(myValue)
                        frozen = new.unfreeze()
                        for key in yrValue.keys():
                            if not isinstance(key,ConstantKey):
                                del new[key]
                        if frozen:
                            new.freeze()
                        try:
                            result[new] += myProb/yrProb
                        except KeyError:
                            result[new] = myProb/yrProb
        return Distribution(result)

    def reachable(self,estimators,observations,horizon):
        """Computes any reachable distributions from this one
        @param estimators: any possible conditional probability distributions, expressed as dictionaries, each containing C{numerator} and C{denominator} fields
        @type estimators: dict[]
        @param observations: any possible observations
        @type observations: L{KeyedVector}[]
        @param horizon: the maximum length of observation sequences to consider (if less than 1, then only the current distribution is reachable)
        @return: all the reachable distributions (distinct by string representation)
        @rtype: L{Distribution}[]
        """
        if horizon <= 0:
            return [self]
        found = {str(self):self}
        for estimator in estimators:
            numerator = estimator['numerator']*self
            denominator = estimator['denominator']*self
            for obs in observations:
                posterior = numerator.conditional(denominator,obs)
                for beliefState in posterior.reachable(estimators,observations,
                                                       horizon-1):
                    key = str(beliefState)
                    if key not in found:
                        found[key] = beliefState
        return list(found.values())

    def sample(self):
        """
        @return: a single element from the sample space, chosen randomly according to this distribution.
        @raise IndexError: if the distribution is empty
        @note: if the probabilities sum to less than the random draw (e.g., rounding error or an unnormalized distribution), the last element in sorted order is returned
        """
        elements = self.domain()
        # Sort so that a given random draw always maps to the same element
        elements.sort()
        remaining = random.random()
        last = len(elements)-1
        index = 0
        while index < last and remaining > self[elements[index]]:
            remaining -= self[elements[index]]
            index += 1
        return elements[index]

    def __float__(self):
        """Supports float conversion of distributions by returning EV.
        Invoked by calling C{float(self)}"""
        return float(self.expectation())

    def __str__(self):
        """Returns a pretty string representation of this distribution"""
        return self.simpleText()

    def simpleText(self,numbers=True,all=False):
        """
        @param numbers: if C{True}, each probability is written as a number; otherwise, it is described qualitatively (e.g., "low probability")
        @param all: passed along to the domain values' own C{simpleText} methods (when they accept these arguments)
        @return: a comma-separated list of "value with probability" phrases
        @rtype: str
        """
        phrases = []
        for value,prob in self.items():
            try:
                label = value.simpleText(numbers=numbers,all=all)
            except TypeError:
                label = value.simpleText()
            except AttributeError:
                label = str(value)
            if numbers:
                level = 'probability %5.3f' % (prob)
            elif prob < .3:
                level = 'low probability'
            elif prob < .6:
                level = 'medium likelihood'
            elif prob < 1.:
                level = 'high probability'
            else:
                level = 'certainty'
            phrases.append('%s with %s' % (label,level))
        return ', '.join(phrases)

    def __xml__(self):
        """@return: An XML Document object representing this distribution: one C{entry} element per domain element, with C{probability} and C{key} attributes, plus the element's own XML for non-string elements"""
        doc = Document()
        root = doc.createElement('distribution')
        doc.appendChild(root)
        for key,value in self._domain.items():
            prob = dict.__getitem__(self,key)
            node = doc.createElement('entry')
            root.appendChild(node)
            node.setAttribute('probability',str(prob))
            node.setAttribute('key',key)
            if not isinstance(value,str):
                node.appendChild(value.__xml__().documentElement)
        return doc

    def __copy__(self):
        """@return: a shallow copy (the domain elements themselves are shared)"""
        result = self.__class__()
        result._domain.update(self._domain)
        for key in result._domain.keys():
            dict.__setitem__(result,key,dict.__getitem__(self,key))
        return result

    def __deepcopy__(self,memo):
        """@return: a deep copy (the domain elements are deep-copied too)"""
        result = self.__class__()
        memo[id(self)] = result
        result._domain = copy.deepcopy(self._domain,memo)
        for key in result._domain.keys():
            dict.__setitem__(result,key,dict.__getitem__(self,key))
        return result

    def parse(self,element,valueClass=None):
        """Extracts the distribution from the given XML element
        @param element: The XML Element object specifying the distribution
        @type element: Element
        @param valueClass: The class used to generate the domain values for this distribution (entries without a child element use their C{key} string as the value)
        @return: This L{Distribution} object
        @raise UserWarning: if C{valueClass}'s C{parse} method returns C{None}"""
        assert(element.tagName == 'distribution')
        self.clear()
        node = element.firstChild
        while node:
            if node.nodeType == Node.ELEMENT_NODE:
                prob = float(node.getAttribute('probability'))
                key = str(node.getAttribute('key'))
                subNode = node.firstChild
                while subNode and subNode.nodeType != Node.ELEMENT_NODE:
                    subNode = subNode.nextSibling
                if subNode:
                    value = valueClass()
                    value = value.parse(subNode)
                else:
                    value = key
                if value is None:
                    raise UserWarning('XML parsing method for %s has null return value' % (valueClass.__name__))
                # Keep the stored key from the XML, but map it to the parsed value
                self[key] = prob
                self._domain[key] = value
            node = node.nextSibling
        return self
498