1 import copy
2 import random
3 from xml.dom.minidom import *
4
5 from Keys import ConstantKey
6 from KeyedVector import KeyedVector
7
9 """Helper method used by L{Distribution.join} method
10 @param table: the dictionary to modify
11 @type table: dictionary
12 @param key: the entry to set in that dictionary
13 @type key: hashable instance
14 @param value: the value to stick into the dictionary
15 @return: a I{copy} of the original dictionary with the specified key-value association inserted
16 """
17 new = copy.copy(table)
18 new[key] = value
19 return new
20
22 """
23 A probability distribution
24
25 - C{dist.L{domain}()}: Returns the domain of possible values
26 - C{dist.L{items}()}: Returns the list of all (value,prob) pairs
27 - C{dist[value]}: Returns the probability of the given value
28 - C{dist[value] = x}: Sets the probability of the given value to x
29
30 The possible domain values are any objects
31 @warning: If you make the domain values mutable types, try not to change the values while they are inside the distribution. If you must change a domain value, it is better to first delete the old value, change it, and then re-insert it.
32 @cvar epsilon: the granularity for float comparisons
33 @type epsilon: float
34 """
35 epsilon = 0.0001
36
43
45 key = str(element)
46 return dict.__getitem__(self,key)
47
49 """
50 @param element: the domain element
51 @param value: the probability to associate with the given key
52 @type value: float
53 @warning: raises an C{AssertionError} if setting to an invalid prob value"""
54 assert(value > -self.epsilon,
55 "Negative probability value: %f" % (value))
56 assert(value < 1.+self.epsilon,
57 "Probability value exceeds 1: %f" % (value))
58 key = str(element)
59 self._domain[key] = element
60 dict.__setitem__(self,key,value)
61
63 key = str(element)
64 dict.__delitem__(self,key)
65 del self._domain[key]
66
70
72 """Replaces on element in the sample space with another. Raises an exception if the original element does not exist, and an exception if the new element already exists (i.e., does not do a merge)
73 """
74 prob = self[old]
75 del self[old]
76 self[new] = prob
77
79 """
80 @return: the sample space of this probability distribution
81 @rtype: C{list}
82 """
83 return self._domain.values()
84
86 """
87 @return: a list of tuples of value,probability pairs
88 @rtype: (value,float)[]
89 """
90 return map(lambda k:(self._domain[k],dict.__getitem__(self,k)),
91 self.keys())
92
def domainKeys(self):
    """
    Gathers every key used by any element of the sample space
    @return: a table containing one entry (mapped to C{True}) for each key found among the domain values
    @rtype: C{dict:L{teamwork.math.Keys.Key}S{->}boolean}
    @note: every domain value must provide a C{keys()} method
    """
    found = {}
    for element in self.domain():
        # Mark all of this element's keys as present in one step
        found.update(dict.fromkeys(element.keys(), True))
    return found
103
105 """Normalizes the distribution so that the sum of values = 1
106 @note: Not sure if this is really necessary"""
107 total = sum(self.values())
108 if abs(total-1.) > self.epsilon:
109 for key,value in self.items():
110 try:
111 self[key] /= total
112 except ZeroDivisionError:
113 self[key] = 1./float(len(self))
114
116 """Marginalizes the distribution to remove the given key (not in place! returns the new distribution)
117 @param key: the key to marginalize over
118 @return: a new L{Distribution} object representing the marginal distribution
119 @note: no exception is raised if the key is not present"""
120 result = self.__class__()
121 for row,prob in self.items():
122 new = copy.copy(row)
123 new.unfreeze()
124 try:
125 del new[key]
126 except KeyError:
127 pass
128 try:
129 result[new] += prob
130 except KeyError:
131 result[new] = prob
132 return result
133
135 """Marginalizes the distribution over all but the given key
136 @param key: the key to compute the marginal distribution over
137 @return: a new L{Distribution} object representing the marginal"""
138 result = self.__class__()
139 for row,prob in self.items():
140 try:
141 value = row[key]
142 except KeyError:
143
144
145 value = 0.
146 try:
147 result[value] += prob
148 except KeyError:
149 result[value] = prob
150 return result
151
def join(self,key,value,debug=False):
    """Extends this distribution, in place, into the joint distribution that includes the given key
    @param key: any hashable instance
    @param value: if a L{Distribution}, the marginal distribution for the given key; otherwise, the marginal distribution is assumed to be I{P(key=value)=1}
    @param debug: handed to L{compose} when C{value} is a L{Distribution}; unused otherwise
    @return: this distribution itself, now representing the joint distribution
    @warning: this method assumes that the domain values are C{dict}-like (i.e., for each domain element C{e}, it can set C{e[key]=value})
    @note: when C{value} is not a L{Distribution}, the domain elements themselves are modified (not copies), and rows that become identical once the key is set have their probabilities added together"""
    if isinstance(value,Distribution):
        self.compose(value,lambda x,y,k=key:setitemAndReturn(x,k,y),
                     replace=True,debug=debug)
    else:
        # Work from a snapshot, because every row is about to be re-keyed
        rows = list(self.items())
        # Take all rows out before putting any back, so that an updated
        # row can never collide with a row that is still waiting its turn
        for row,prob in rows:
            del self[row]
        for row,prob in rows:
            row[key] = value
            try:
                # Merge with an identical row rather than overwriting it
                self[row] += prob
            except KeyError:
                self[row] = prob
    return self
179
181 """Returns the expected value of this distribution
182
183 @warning: As a side effect, the distribution will be normalized"""
184 if len(self) == 1:
185
186 return self.domain()[0]
187 else:
188
189
190 self.normalize()
191 total = None
192 for key,value in self.items():
193 if total is None:
194 total = key*value
195 else:
196 total += key*value
197 return total
198
200 """Removes any zero-probability entries from this distribution
201 @return: the pruned distribution (not a copy)"""
202 for key,value in self.items():
203 if abs(value) < self.epsilon:
204 del self[key]
205 return self
206
def fill(self,keys,value=0.):
    """Fills in any missing rows/columns in the domain matrices with a default value
    @param keys: the new slots that should be filled
    @type keys: C{L{teamwork.math.Keys.Key}[]}
    @param value: the default value (default is 0.)
    @note: essentially calls the C{fill} method of each domain object, removing the object from the distribution beforehand and re-inserting it afterward so that it is stored under its updated form
    @note: domain objects that become identical after filling have their probabilities added together
    """
    # Work from a snapshot, because every element is about to be re-keyed
    entries = list(self.items())
    # Take all elements out before putting any back, so that a filled
    # element can never collide with one that is still waiting its turn
    for element,prob in entries:
        del self[element]
    for element,prob in entries:
        element.fill(keys,value)
        try:
            # Merge with an identical element rather than overwriting it
            self[element] += prob
        except KeyError:
            self[element] = prob
218
219
221 """Locks in the dimensions and keys of all domain values"""
222 for element in self.domain():
223 element.freeze()
224
226 """Unlocks in the dimensions and keys of all domain values"""
227 for element in self.domain():
228 element.unfreeze()
229
231 """Substitutes values for any abstract references, using the
232 given substitution table
233 @param table: dictionary of key-value pairs, where the value will be substituted for any appearance of the given key in a field of this L{teamwork.math.Keys.Key} object
234 @type table: dictionary"""
235 result = self.__class__()
236 for key,element in self._domain.items():
237 prob = dict.__getitem__(self,key)
238 new = element.instantiate(table)
239 if key == str(element):
240
241 key = str(new)
242 result._domain[key] = new
243 try:
244 prob += dict.__getitem__(result,key)
245 except KeyError:
246 pass
247 dict.__setitem__(result,key,prob)
248 return result
249
251 """Substitutes values for any abstract references, using the
252 given substitution table
253 @param table: dictionary of key-value pairs, where the value will be substituted for any appearance of the given key in a field of this L{teamwork.math.Keys.Key} object
254 @type table: dictionary"""
255 for key,element in self._domain.items():
256 prob = dict.__getitem__(self,key)
257
258 update = key == str(element)
259 element.instantiateKeys(table)
260 if update:
261 dict.__delitem__(self,key)
262 del self._domain[key]
263 key = str(element)
264 self._domain[key] = element
265 dict.__setitem__(self,key,prob)
266
def compose(self,other,operator,replace=False,debug=False):
    """Composes this distribution with the other given, using the given op
    @param other: a L{Distribution} object, or an object of the same class as the keys in this Distribution object
    @param operator: a binary operator applicable to the class of keys in this L{Distribution} object; called as C{operator(myElement,otherElement)}, or as C{operator(myElement,other)} when C{other} is not a L{Distribution}
    @param replace: if this flag is true, the result modifies this distribution itself
    @param debug: if this flag is true (and C{other} is a L{Distribution}), prints each pair of combined elements, the resulting element, and its accumulated probability
    @return: the composed distribution (this object if C{replace} is true, otherwise a new instance of the same class)
    @note: elements that the operator maps to the same result have their probabilities added together"""
    # Snapshot both operands before anything is cleared, so that the
    # result stays correct even if other is this very distribution
    original = list(self.items())
    if isinstance(other,Distribution):
        pairs = list(other.items())
    else:
        pairs = None
    if replace:
        result = self
        self.clear()
    else:
        result = self.__class__()
    for key1,value1 in original:
        if pairs is None:
            # Plain operand: probabilities carry over unchanged
            key = operator(key1,other)
            try:
                result[key] += value1
            except KeyError:
                result[key] = value1
        else:
            for key2,value2 in pairs:
                key = operator(key1,key2)
                # Probability of the pair is the product of the two
                prob = value1*value2
                try:
                    result[key] += prob
                except KeyError:
                    result[key] = prob
                if debug:
                    print('%s %s' % (key1,key2))
                    print('\t-> %s' % (key,))
                    print('\t= %s' % (result[key],))
    return result
300
302 """
303 @note: Also supports + operator between Distribution object and objects of the same class as its keys"""
304 return self.compose(other,lambda x,y:x+y)
305
307 result = self.__class__()
308 for key,value in self.items():
309 if not key:
310 raise UserWarning
311 result[-key] = value
312 return result
313
315 """@note: Also supports - operator between L{Distribution} object and objects of the same class as its keys"""
316 return self + (-other)
317
319 """@note: Also supports * operator between L{Distribution} object and objects of the same class as its keys"""
320 return self.compose(other,lambda x,y:x*y)
321
322
324 if isinstance(other,Distribution):
325 return self.conditional(other,{})
326 else:
327 return self * (1./other)
328
330 """Computes a conditional probability, given this joint probability I{P(AB)}, the marginal probability I{P(B)}, and the value of I{B} being conditioned on
331 @param other: the marginal probability, I{P(B)}
332 @type other: L{Distribution}
333 @param value: the value of I{B}
334 @type value: L{KeyedVector} (if omitted, it's assumed that both I{P(AB)} and I{P(B)} have already been conditioned on the desired value)
335 @return: I{P(A|B=C{value})} where C{self} is I{P(AB)}
336 @rtype: L{Distribution}
337 """
338 result = {}
339 for myValue,myProb in self.items():
340 for yrValue,yrProb in other.items():
341 for key in value.keys():
342 if not yrValue.has_key(key) \
343 or yrValue[key] != value[key]:
344 break
345 else:
346 for key in yrValue.keys():
347 if not myValue.has_key(key) \
348 or myValue[key] != yrValue[key]:
349 break
350 else:
351 new = copy.copy(myValue)
352 frozen = new.unfreeze()
353 for key in yrValue.keys():
354 if not isinstance(key,ConstantKey):
355 del new[key]
356 if frozen:
357 new.freeze()
358 try:
359 result[new] += myProb/yrProb
360 except KeyError:
361 result[new] = myProb/yrProb
362 return Distribution(result)
363
364
def reachable(self,estimators,observations,horizon):
    """Computes any reachable distributions from this one
    @param estimators: any possible conditional probability distributions, expressed as dictionaries, each containing C{numerator} and C{denominator} fields
    @type estimators: dict[]
    @param observations: any possible observations
    @type observations: L{KeyedVector}[]
    @param horizon: the maximum length of observation sequences to consider (if less than 1, then only the current distribution is reachable)
    @return: all the reachable distributions, including this one; distributions with the same string form are reported only once (the first one found is kept)
    @rtype: L{Distribution}[]
    """
    if horizon <= 0:
        return [self]
    # Indexed by string form, which is what identifies duplicates
    found = {str(self):self}
    for estimator in estimators:
        numerator = estimator['numerator']*self
        denominator = estimator['denominator']*self
        for obs in observations:
            posterior = numerator.conditional(denominator,obs)
            # Recurse with one fewer observation left to consider
            for beliefState in posterior.reachable(estimators,observations,
                                                   horizon-1):
                # Keeps the existing entry if this string form was seen
                found.setdefault(str(beliefState),beliefState)
    return list(found.values())
389
391 """
392 @return: a single element from the sample space, chosen randomly according to this distribution.
393 """
394 elements = self.domain()
395 elements.sort()
396 total = random.random()
397 index = 0
398 while total > self[elements[index]]:
399 total -= self[elements[index]]
400 index += 1
401 return elements[index]
402
404 """Supports float conversion of distributions by returning EV.
405 Invoked by calling C{float(self)}"""
406 return float(self.expectation())
407
409 """Returns a pretty string representation of this distribution"""
410 return self.simpleText()
411
412
413
414
415
def simpleText(self,numbers=True,all=False):
    """
    @param numbers: if C{True}, each probability is written out as a number; otherwise it is described qualitatively (low probability, medium likelihood, high probability, or certainty)
    @param all: passed along (together with C{numbers}) to the C{simpleText} method of each domain value that accepts it
    @return: a comma-separated string with one "I{value} with I{likelihood}" phrase for each element of this distribution
    """
    # Upper bounds (exclusive) for each qualitative level, lowest first
    levels = ((.3,'low probability'),
              (.6,'medium likelihood'),
              (1.,'high probability'))
    phrases = []
    for value,prob in self.items():
        try:
            label = value.simpleText(numbers=numbers,all=all)
        except TypeError:
            # The value's simpleText does not take these arguments
            label = value.simpleText()
        except AttributeError:
            # The value has no simpleText at all
            label = str(value)
        if numbers:
            level = 'probability %5.3f' % (prob)
        else:
            # Anything not below the last bound is a sure thing
            level = 'certainty'
            for bound,name in levels:
                if prob < bound:
                    level = name
                    break
        phrases.append('%s with %s' % (label,level))
    return ', '.join(phrases)
440
442 """@return: An XML Document object representing this distribution"""
443 doc = Document()
444 root = doc.createElement('distribution')
445 doc.appendChild(root)
446 for key,value in self._domain.items():
447 prob = dict.__getitem__(self,key)
448 node = doc.createElement('entry')
449 root.appendChild(node)
450 node.setAttribute('probability',str(prob))
451 node.setAttribute('key',key)
452 if not isinstance(value,str):
453 node.appendChild(value.__xml__().documentElement)
454 return doc
455
457 result = self.__class__()
458 result._domain.update(self._domain)
459 for element in result._domain.keys():
460 dict.__setitem__(result,element,dict.__getitem__(self,element))
461 return result
462
464 result = self.__class__()
465 memo[id(self)] = result
466 result._domain = copy.deepcopy(self._domain,memo)
467 for element in result._domain.keys():
468 dict.__setitem__(result,element,dict.__getitem__(self,element))
469 return result
470
def parse(self,element,valueClass=None):
    """Extracts the distribution from the given XML element
    @param element: The XML Element object specifying the distribution
    @type element: Element
    @param valueClass: The class used to generate the domain values for this distribution; it is instantiated with no arguments and its C{parse} method is applied to the first child element of each entry
    @return: This L{Distribution} object
    @note: any previous content is cleared first
    @note: an entry without a child element uses its C{key} attribute string as the domain value
    @warning: raises an C{AssertionError} if the element is not a C{distribution} element, a C{TypeError} if an entry has a child element but no C{valueClass} was given, and a C{UserWarning} if the C{parse} method of C{valueClass} returns C{None}"""
    assert element.tagName == 'distribution', \
           'Expected a distribution element, not %s' % (element.tagName)
    self.clear()
    node = element.firstChild
    while node is not None:
        # Only element children are entries; anything else is skipped
        if node.nodeType == Node.ELEMENT_NODE:
            prob = float(node.getAttribute('probability'))
            key = str(node.getAttribute('key'))
            # Look for the first element child, which holds the value
            subNode = node.firstChild
            while subNode is not None and \
                      subNode.nodeType != Node.ELEMENT_NODE:
                subNode = subNode.nextSibling
            if subNode is None:
                value = key
            else:
                if valueClass is None:
                    raise TypeError('No valueClass given for parsing the value of entry %s' % (key))
                value = valueClass().parse(subNode)
                if value is None:
                    raise UserWarning('XML parsing method for %s has null return value' % (valueClass.__name__))
            self[key] = prob
            # Store the parsed value as the domain element for this key
            self._domain[key] = value
        node = node.nextSibling
    return self
498