1 from teamwork.math.Keys import *
2 from teamwork.math.KeyedMatrix import *
3 from teamwork.math.probability import *
4 from teamwork.math.ProbabilityTree import *
5 from testPWL import makeVector,makePlane
6 import random
7
8 import unittest
9
11
# Fixture: a two-point distribution over Bill's 'power' feature (values 0.2
# and 0.4), both given the same unnormalized weight and then normalized.
self.key = StateKey({'entity': 'Bill', 'feature': 'power'})
self.rows = [KeyedVector({self.key: level}) for level in (0.2, 0.4)]
first, second = self.rows
dist = Distribution()
dist[first] = 1.
dist[second] = dist[first]  # copy the first row's weight -> equal weights
self.distribution = dist
self.distribution.normalize()
23
# After normalize() the probabilities must sum to one, and the two rows,
# which started with equal weight, must still carry equal probability.
mass = sum(self.distribution.values())
self.assertAlmostEqual(mass, 1., 5)
first, second = self.rows
self.assertAlmostEqual(self.distribution[first],
                       self.distribution[second], 5)
28
# Draw many samples and check that each element's empirical frequency
# matches its probability to one decimal place (assertAlmostEqual places=1).
total = 10000
domain = self.distribution.domain()
counts = {}
for element in domain:
    counts[element] = 0
# The sampled support must coincide with domain(). Compare by size and
# membership instead of as ordered lists: dict key order is arbitrary, so
# comparing counts.keys() to domain() position by position can fail
# even when both hold exactly the same elements.
self.assertEqual(len(counts), len(domain))
for element in domain:
    self.assert_(element in counts)
for _ in range(total):
    counts[self.distribution.sample()] += 1
for element, prob in self.distribution.items():
    self.assertAlmostEqual(float(counts[element]) / float(total), prob, 1)
39
# The mean of 0.2 and 0.4 at equal probability is 0.3. Check it both via
# expectation() (a vector over the single key) and via the scalar
# marginal for that key.
expected = 0.3
mean = self.distribution.expectation()
self.assertEqual(mean.keys(), [self.key])
self.assertAlmostEqual(mean[self.key], expected, 5)
marginal = self.distribution.getMarginal(self.key)
self.assertAlmostEqual(float(marginal), expected, 5)
46
50
# Adding the distribution to itself: sums of {0.2, 0.4} (each 1/2) give
# 0.4, 0.6 and 0.8 with probabilities 1/4, 1/2 and 1/4 respectively.
result = self.distribution + self.distribution
self.assertEqual(len(result), 3)
for row, prob in result.items():
    total = row[self.key]
    if total < 0.5:
        expectedValue, expectedProb = 0.4, 0.25
    elif total < 0.7:
        expectedValue, expectedProb = 0.6, 0.5
    else:
        expectedValue, expectedProb = 0.8, 0.25
    self.assertAlmostEqual(total, expectedValue, 5)
    self.assertAlmostEqual(prob, expectedProb, 5)
64
# Subtracting the distribution from itself: differences of {0.2, 0.4}
# (each 1/2) give -0.2, 0 and 0.2 with probabilities 1/4, 1/2 and 1/4.
result = self.distribution - self.distribution
self.assertEqual(len(result), 3)
for row, prob in result.items():
    diff = row[self.key]
    if diff < 0.:
        expectedValue, expectedProb = -.2, 0.25
    elif diff < 0.1:
        expectedValue, expectedProb = 0., 0.5
    else:
        expectedValue, expectedProb = 0.2, 0.25
    self.assertAlmostEqual(diff, expectedValue, 5)
    self.assertAlmostEqual(prob, expectedProb, 5)
78
# Multiplying the distribution by itself yields scalar outcomes: products
# of {0.2, 0.4} (each 1/2) give 0.04, 0.08 and 0.16 with probabilities
# 1/4, 1/2 and 1/4.
result = self.distribution * self.distribution
self.assertEqual(len(result), 3)
for product, prob in result.items():
    if product < 0.05:
        expectedValue, expectedProb = 0.04, 0.25
    elif product < 0.1:
        expectedValue, expectedProb = 0.08, 0.5
    else:
        expectedValue, expectedProb = .16, 0.25
    self.assertAlmostEqual(product, expectedValue, 5)
    self.assertAlmostEqual(prob, expectedProb, 5)
92
# Combine the power distribution with an independent two-valued
# trustworthiness distribution, then check that the marginal over power
# has one entry per joint row and assigns each row's probability to the
# matching power value.
trustKey = StateKey({'entity': 'Bill', 'feature': '_trustworthiness'})
trust = Distribution()
for level, weight in ((-.3, 0.6), (.3, 0.4)):
    trust[KeyedVector({trustKey: level})] = weight
joint = self.distribution + trust
marginal = joint.getMarginal(self.key)
self.assertEqual(len(marginal), len(joint))
for row in joint.domain():
    target = row[self.key]
    # Probabilities of marginal entries whose value matches this row.
    hits = [p for val, p in marginal.items() if abs(val - target) < 0.001]
    if not hits:
        self.fail()
    self.assertAlmostEqual(hits[0], joint[row])
110
def assertConsistent(dist):
    # keys() and the internal _domain table must list the same entries.
    self.assertEqual(dist.keys(), dist._domain.keys())

# Join a new trustworthiness feature onto the power distribution, first
# as a constant and then as a two-valued distribution.
trustKey = StateKey({'entity': 'Bill', 'feature': '_trustworthiness'})
before = len(self.distribution)
assertConsistent(self.distribution)

# Joining a constant adds the key without changing the number of rows.
self.distribution.join(trustKey, 0.3)
assertConsistent(self.distribution)
self.assertEqual(len(self.distribution), before)
marginal = self.distribution.getMarginal(trustKey)
assertConsistent(marginal)
self.assertAlmostEqual(marginal.domain()[0], 0.3, 5)
self.assertAlmostEqual(float(marginal), 0.3, 5)

# Joining a two-valued distribution doubles the number of rows.
twoPoint = Distribution({0.3: 0.7, -0.3: 0.3})
assertConsistent(twoPoint)
self.distribution.join(trustKey, twoPoint)
self.assertEqual(len(self.distribution), 2 * before)
128
134
136
138
# Fixture: a joint state over the positions of Radar ('enemy'), Escort
# and Transport plus keyConstant, and a table of dynamics trees by name.
self.keys = {'enemy': StateKey({'entity': 'Radar',
                                'feature': 'position'}),
             }
# Radar's position is uniform over 0.2, 0.4, 0.6 and 0.8; every state
# vector also carries keyConstant = 1.
self.state = Distribution()
for position in range(2, 10, 2):
    row = KeyedVector({self.keys['enemy']: float(position) / 10.})
    row[keyConstant] = 1.
    self.state[row] = 1.
self.state.normalize()

# Escort and Transport both start at position 0 with certainty. Both
# marginals use a float 0. key (the escort one previously used int 0) so
# the two joins are built identically.
self.keys['escort'] = StateKey({'entity': 'Escort',
                                'feature': 'position'})
self.state.join(self.keys['escort'], Distribution({0.: 1.}))

self.keys['transport'] = StateKey({'entity': 'Transport',
                                   'feature': 'position'})
self.state.join(self.keys['transport'], Distribution({0.: 1.}))

# Extend each state vector with every position key (tests expect
# len(self.keys) + 1 entries per vector afterwards).
for row in self.state.domain():
    row.fill(self.keys.values())

self.dynamics = {}

# 'enemy': branch on a plane over (escort - enemy) with threshold -.1;
# one subtree applies IdentityMatrix('position'), the other ScaleMatrix
# on the 'self' position with factor -1.
# NOTE(review): which side of the plane picks which subtree is defined by
# KeyedPlane / ProbabilityTree.branch -- confirm there.
tree = ProbabilityTree()
row = KeyedVector({self.keys['escort']: 1.,
                   self.keys['enemy']: -1.})
plane = KeyedPlane(row, -.1)
left = ProbabilityTree(IdentityMatrix('position'))
key = StateKey({'entity': 'self', 'feature': 'position'})
right = ProbabilityTree(ScaleMatrix('position', key, -1.))
tree.branch(plane, left, right)
self.dynamics['enemy'] = tree

def incrementTree(amount):
    # Build a fresh single-leaf tree per action. instantiateKeys()
    # rewrites a tree in place, so sharing one tree between 'escort' and
    # 'fly-normal' would let instantiating one silently rebind the other.
    leafTree = ProbabilityTree()
    leafTree.makeLeaf(IncrementMatrix('position', value=amount))
    return leafTree

self.dynamics['escort'] = incrementTree(0.2)
self.dynamics['fly-normal'] = incrementTree(0.2)
self.dynamics['fly-noe'] = incrementTree(0.1)

# 'constant': maps keyConstant onto itself with weight 1.
tree = ProbabilityTree()
tree.makeLeaf(KeyedMatrix({keyConstant: KeyedVector({keyConstant: 1.})}))
self.dynamics['constant'] = tree
183
188
# Instantiate the escort dynamics for the Escort entity, fill it out to
# the full key list, apply it to the state and inspect the result matrix.
width = len(self.keys) + 1  # every position key plus keyConstant
keyList = self.keys.values() + [keyConstant]
keyList.sort()

tree = self.dynamics['escort']
tree.instantiateKeys({'self': 'Escort'})
# After instantiation no leaf row may still refer to the 'self' placeholder.
placeholder = StateKey({'entity': 'self', 'feature': 'position'})
for leafRow in tree.leaves()[0].values():
    self.assert_(not leafRow.has_key(placeholder))
tree.fill(keyList)
self.assertEqual(len(tree.getValue()), width)
self.assert_(tree.isLeaf())

outcomes = tree[self.state]
self.assertEqual(len(outcomes), 1)
for vector in outcomes.domain():
    self.assertEqual(len(vector), width)
self.assertEqual(len(self.state), 4)
for vector in self.state.domain():
    self.assertEqual(len(vector), width)

matrix = outcomes.domain()[0]
self.assertEqual(len(matrix), width)
for key in self.keys.values():
    row = matrix[key]
    self.assertEqual(len(row), width)
    isEscort = key['entity'] == 'Escort'
    if isEscort:
        # Escort's row carries the 0.2 increment in the constant column.
        self.assertAlmostEqual(row[keyConstant], 0.2, 5)
    for subKey in self.keys.values():
        # Only Escort -> Escort is 1; every other position weight is 0.
        if isEscort and subKey['entity'] == 'Escort':
            expected = 1.
        else:
            expected = 0.
        self.assertAlmostEqual(row[subKey], expected, 5)

# The keyConstant row is expected to be all zeros.
constantRow = matrix[keyConstant]
self.assertAlmostEqual(constantRow[keyConstant], 0., 5)
for key in self.keys.values():
    self.assertAlmostEqual(constantRow[key], 0., 5)
227
# Instantiate each action's dynamics for its actor, fill each to the full
# key list, sum them into one tree, apply it to the initial state, and
# check every resulting state vector.
self.assertEqual(len(self.state), len(self.state.domain()))
tables = {'escort': {'self': 'Escort'},
          'enemy': {'self': 'Radar'},
          'fly-noe': {'self': 'Transport'},
          'constant': {},
          }
total = None
# sorted(list(...)) rather than values()+[...] then .sort(): it does not
# rely on values() returning a list.
keyList = sorted(list(self.keys.values()) + [keyConstant])
for key, table in tables.items():
    dynamics = self.dynamics[key]
    dynamics.instantiateKeys(table)
    dynamics.fill(keyList)
    if total is None:
        total = dynamics
    else:
        total += dynamics
total.freeze()
state = total[self.state] * self.state

# Radar's resulting position must be one of its initial positions. Build
# a real list (map() is lazy on Python 3, which breaks len() and `in`)
# and reuse the fixture's enemy key instead of re-creating it by hand.
enemyKey = self.keys['enemy']
valid = [row[enemyKey] for row in self.state.domain()]
self.assertEqual(len(state), len(valid))
for vector in state.domain():
    self.assertEqual(len(vector), len(self.keys) + 1)
    self.assertAlmostEqual(vector[self.keys['escort']], .2, 5)
    self.assertAlmostEqual(vector[self.keys['transport']], .1, 5)
    self.assertAlmostEqual(vector[keyConstant], 1., 5)
    self.assert_(vector[enemyKey] in valid)
258
266
304
# Run this module's test cases via unittest's command-line runner when
# the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
307