Package teamwork :: Package test :: Package math :: Module testProbability
[hide private]
[frames] | [no frames]

Source Code for Module teamwork.test.math.testProbability

  1  from teamwork.math.Keys import * 
  2  from teamwork.math.KeyedMatrix import * 
  3  from teamwork.math.probability import * 
  4  from teamwork.math.ProbabilityTree import * 
  5  from testPWL import makeVector,makePlane 
  6  import random 
  7   
  8  import unittest 
  9   
class TestProbability(unittest.TestCase):
    """Tests for Distribution: normalization, sampling, expectation,
    arithmetic, marginals, joins and XML round-tripping."""

    def setUp(self):
        """Build a normalized distribution with equal weight on two vectors
        over Bill's power: {power: 0.2} and {power: 0.4}."""
        distribution = Distribution()
        self.key = StateKey({'entity':'Bill',
                             'feature':'power'})
        row1 = KeyedVector({self.key:0.2})
        row2 = KeyedVector({self.key:0.4})
        self.rows = [row1,row2]
        distribution[row1] = 1.
        # Give the second row the same (unnormalized) weight as the first
        distribution[row2] = distribution[row1]
        self.distribution = distribution
        self.distribution.normalize()

    def _assertSameKeys(self, first, second):
        """Assert that two key sequences contain the same elements,
        ignoring order.

        The sequences compared in these tests come from separate dicts (or
        a dict and a list), and their iteration orders are not guaranteed
        to agree, so an ordered list comparison could fail spuriously.
        Assumes neither sequence contains duplicates, as holds for dict keys.
        """
        first = list(first)
        second = list(second)
        self.assertEqual(len(first), len(second))
        for element in first:
            self.assertTrue(element in second,
                            '%r missing from %r' % (element, second))

    def testNormalize(self):
        """Probabilities sum to 1 and the two equal weights stay equal."""
        self.assertAlmostEqual(sum(self.distribution.values()),1.,5)
        self.assertAlmostEqual(self.distribution[self.rows[0]],
                               self.distribution[self.rows[1]],5)

    def testSample(self):
        """Empirical sample frequencies match the probabilities to one
        decimal place."""
        total = 10000
        counts = {}
        for element in self.distribution.domain():
            counts[element] = 0
        # Every domain element must be usable as a distinct counter key
        self._assertSameKeys(counts.keys(),self.distribution.domain())
        for _ in range(total):
            counts[self.distribution.sample()] += 1
        for element,prob in self.distribution.items():
            self.assertAlmostEqual(float(counts[element])/float(total),prob,1)

    def testExpectation(self):
        """Expected value over power is 0.3, both directly and via the
        marginal."""
        eValue = self.distribution.expectation()
        self.assertEqual(eValue.keys(),[self.key])
        self.assertAlmostEqual(eValue[self.key],0.3,5)
        eValue = float(self.distribution.getMarginal(self.key))
        self.assertAlmostEqual(eValue,0.3,5)

    def testNegation(self):
        """Negating the distribution negates its expectation."""
        value = -self.distribution
        self.assertAlmostEqual(value.expectation()[self.key],-0.3,5)

    def testAddition(self):
        """Adding the distribution to itself gives 0.4, 0.6 and 0.8 with
        probabilities 1/4, 1/2 and 1/4."""
        result = self.distribution + self.distribution
        self.assertEqual(len(result),3)
        for row,prob in result.items():
            if row[self.key] < 0.5:
                self.assertAlmostEqual(row[self.key],0.4,5)
                self.assertAlmostEqual(prob,0.25,5)
            elif row[self.key] < 0.7:
                self.assertAlmostEqual(row[self.key],0.6,5)
                self.assertAlmostEqual(prob,0.5,5)
            else:
                self.assertAlmostEqual(row[self.key],0.8,5)
                self.assertAlmostEqual(prob,0.25,5)

    def testSubtraction(self):
        """Subtracting the distribution from itself gives -0.2, 0 and 0.2
        with probabilities 1/4, 1/2 and 1/4."""
        result = self.distribution - self.distribution
        self.assertEqual(len(result),3)
        for row,prob in result.items():
            if row[self.key] < 0.:
                self.assertAlmostEqual(row[self.key],-.2,5)
                self.assertAlmostEqual(prob,0.25,5)
            elif row[self.key] < 0.1:
                self.assertAlmostEqual(row[self.key],0.,5)
                self.assertAlmostEqual(prob,0.5,5)
            else:
                self.assertAlmostEqual(row[self.key],0.2,5)
                self.assertAlmostEqual(prob,0.25,5)

    def testMultiplication(self):
        """Multiplying the distribution by itself yields scalar elements
        0.04, 0.08 and 0.16 with probabilities 1/4, 1/2 and 1/4."""
        result = self.distribution * self.distribution
        self.assertEqual(len(result),3)
        for row,prob in result.items():
            if row < 0.05:
                self.assertAlmostEqual(row,0.04,5)
                self.assertAlmostEqual(prob,0.25,5)
            elif row < 0.1:
                self.assertAlmostEqual(row,0.08,5)
                self.assertAlmostEqual(prob,0.5,5)
            else:
                self.assertAlmostEqual(row,.16,5)
                self.assertAlmostEqual(prob,0.25,5)

    def testGetMarginal(self):
        """The marginal over power of a sum with a second distribution has
        one entry per element of the sum, each carrying that element's
        probability."""
        distribution = Distribution()
        key = StateKey({'entity':'Bill','feature':'_trustworthiness'})
        row = KeyedVector({key:-.3})
        distribution[row] = 0.6
        row = KeyedVector({key:.3})
        distribution[row] = 0.4
        value = self.distribution + distribution
        marginal = value.getMarginal(self.key)
        self.assertEqual(len(marginal),len(value))
        for row in value.domain():
            for val,prob in marginal.items():
                if abs(val-row[self.key]) < 0.001:
                    self.assertAlmostEqual(prob,value[row])
                    break
            else:
                # No marginal entry matched this element's power value
                self.fail('no marginal entry for %s = %r'
                          % (self.key, row[self.key]))

    def testJoint(self):
        """Joining a fixed value keeps the size and joining a two-valued
        distribution doubles it.

        The probability table and the domain table must always hold the
        same keys.
        """
        initial = len(self.distribution)
        key = StateKey({'entity':'Bill','feature':'_trustworthiness'})
        self._assertSameKeys(self.distribution.keys(),
                             self.distribution._domain.keys())
        self.distribution.join(key,0.3)
        self._assertSameKeys(self.distribution.keys(),
                             self.distribution._domain.keys())
        self.assertEqual(len(self.distribution),initial)
        marginal = self.distribution.getMarginal(key)
        self._assertSameKeys(marginal.keys(),marginal._domain.keys())
        self.assertAlmostEqual(marginal.domain()[0],0.3,5)
        self.assertAlmostEqual(float(marginal),0.3,5)
        distribution = Distribution({0.3:0.7,-0.3:0.3})
        self._assertSameKeys(distribution.keys(),distribution._domain.keys())
        self.distribution.join(key,distribution)
        self.assertEqual(len(self.distribution),2*initial)

    def testXML(self):
        """Serializing to XML and parsing back yields an equal
        distribution."""
        doc = self.distribution.__xml__()
        new = Distribution()
        new.parse(doc.documentElement,KeyedVector)
        self.assertEqual(self.distribution,new)
134
class TestProbabilityTree(unittest.TestCase):
    """Tests for ProbabilityTree dynamics applied to a Distribution over
    Radar, Escort and Transport positions."""

    def setUp(self):
        """Build the state distribution and a table of dynamics trees.

        State: Radar position uniform over 0.2, 0.4, 0.6 and 0.8, each
        vector carrying a constant term of 1. Escort and Transport positions
        are joined in at 0.

        Dynamics: 'enemy' branches on a plane over (escort - enemy) position
        with threshold -0.1, between an identity and a -1 scaling. 'escort'
        and 'fly-normal' increment position by 0.2. 'fly-noe' increments it
        by 0.1. 'constant' preserves the constant term.
        """
        # Set up distribution over radar position
        self.keys = {'enemy': StateKey({'entity':'Radar',
                                        'feature':'position'}),
                     }
        self.state = Distribution()
        for position in range(2,10,2):
            row = KeyedVector({self.keys['enemy']:float(position)/10.})
            row[keyConstant] = 1.
            self.state[row] = 1.
        self.state.normalize()
        # Set up escort state
        self.keys['escort'] = StateKey({'entity':'Escort',
                                        'feature':'position'})
        marginal = Distribution({0:1.})
        self.state.join(self.keys['escort'],marginal)
        # Set up transport state
        self.keys['transport'] = StateKey({'entity':'Transport',
                                           'feature':'position'})
        marginal = Distribution({0.:1.})
        self.state.join(self.keys['transport'],marginal)
        # Make every state vector cover all of the position keys
        for row in self.state.domain():
            row.fill(list(self.keys.values()))
        # Set up action dynamics
        self.dynamics = {}
        tree = ProbabilityTree()
        row = KeyedVector({self.keys['escort']:1.,
                           self.keys['enemy']:-1.})
        plane = KeyedPlane(row,-.1)
        left = ProbabilityTree(IdentityMatrix('position'))
        key = StateKey({'entity':'self','feature':'position'})
        right = ProbabilityTree(ScaleMatrix('position',key,-1.))
        tree.branch(plane,left,right)
        self.dynamics['enemy'] = tree
        tree = ProbabilityTree()
        tree.makeLeaf(IncrementMatrix('position',value=0.2))
        # NOTE: the same tree object is registered under both names
        self.dynamics['escort'] = tree
        self.dynamics['fly-normal'] = tree
        tree = ProbabilityTree()
        tree.makeLeaf(IncrementMatrix('position',value=0.1))
        self.dynamics['fly-noe'] = tree
        # Add constant factor
        tree = ProbabilityTree()
        tree.makeLeaf(KeyedMatrix({keyConstant:KeyedVector({keyConstant:1.})}))
        self.dynamics['constant'] = tree

    def testClear(self):
        """Clearing the state empties both the probabilities and the
        domain."""
        self.state.clear()
        self.assertEqual(len(self.state),0)
        self.assertEqual(len(self.state.domain()),0)

    def testGetItem(self):
        """Instantiated escort dynamics form one matrix.

        The matrix has identity on the Escort position plus a 0.2 constant,
        and zeros in every other row.
        """
        keyList = sorted(list(self.keys.values())+[keyConstant])
        # Test escort dynamics
        dynamics = self.dynamics['escort']
        table = {'self':'Escort'}
        dynamics.instantiateKeys(table)
        # After instantiation no row may still refer to the generic 'self'
        selfKey = StateKey({'entity':'self','feature':'position'})
        for row in dynamics.leaves()[0].values():
            self.assertFalse(row.has_key(selfKey))
        dynamics.fill(keyList)
        self.assertEqual(len(dynamics.getValue()),len(self.keys)+1)
        self.assertTrue(dynamics.isLeaf())
        dynamics = dynamics[self.state]
        self.assertEqual(len(dynamics),1)
        for vector in dynamics.domain():
            self.assertEqual(len(vector),len(self.keys)+1)
        self.assertEqual(len(self.state),4)
        for vector in self.state.domain():
            self.assertEqual(len(vector),len(self.keys)+1)
        matrix = dynamics.domain()[0]
        self.assertEqual(len(matrix),len(self.keys)+1)
        for key in self.keys.values():
            row = matrix[key]
            self.assertEqual(len(row),len(self.keys)+1)
            if key['entity'] == 'Escort':
                self.assertAlmostEqual(row[keyConstant],0.2,5)
                for subKey in self.keys.values():
                    if subKey['entity'] == 'Escort':
                        self.assertAlmostEqual(row[subKey],1.,5)
                    else:
                        self.assertAlmostEqual(row[subKey],0.,5)
            else:
                for subKey in self.keys.values():
                    self.assertAlmostEqual(row[subKey],0.,5)
        self.assertAlmostEqual(matrix[keyConstant][keyConstant],0.,5)
        for key in self.keys.values():
            self.assertAlmostEqual(matrix[keyConstant][key],0.,5)

    def testDynamics(self):
        """Apply the summed escort, enemy, fly-noe and constant dynamics.

        Each resulting vector should have escort 0.2, transport 0.1, a
        constant of 1, and a Radar position drawn from the original state.
        """
        self.assertEqual(len(self.state),len(self.state.domain()))
        tables = {'escort':{'self':'Escort'},
                  'enemy':{'self':'Radar'},
                  'fly-noe':{'self':'Transport'},
                  'constant':{},
                  }
        total = None
        keyList = sorted(list(self.keys.values())+[keyConstant])
        for key,table in tables.items():
            dynamics = self.dynamics[key]
            dynamics.instantiateKeys(table)
            dynamics.fill(keyList)
            if total is None:
                total = dynamics
            else:
                total += dynamics
        total.freeze()
        state = total[self.state]*self.state
        enemy = self.keys['enemy']
        valid = [row[enemy] for row in self.state.domain()]
        self.assertEqual(len(state),len(valid))
        for vector in state.domain():
            self.assertEqual(len(vector),len(self.keys)+1)
            self.assertAlmostEqual(vector[self.keys['escort']],.2,5)
            self.assertAlmostEqual(vector[self.keys['transport']],.1,5)
            self.assertAlmostEqual(vector[keyConstant],1.,5)
            self.assertTrue(vector[enemy] in valid)

    def testXML(self):
        """An XML round trip of the escort tree preserves its split, its
        value and overall equality."""
        doc = self.dynamics['escort'].__xml__()
        new = ProbabilityTree()
        new.parse(doc.documentElement)
        self.assertEqual(self.dynamics['escort'].split,new.split)
        self.assertEqual(self.dynamics['escort'].getValue(),new.getValue())
        self.assertEqual(self.dynamics['escort'],new)

    def testMerge(self):
        """Merging a one-row leaf into a branching tree keeps the branch.

        Every resulting leaf should hold the leaf's row plus exactly one
        row from the branch.
        """
        keys = list(self.keys.values())
        # Three distinct row keys (self.keys has exactly three entries)
        key1, key2, key3 = random.sample(keys, 3)
        # Make a 1-row matrix
        vector1 = makeVector(keys)
        matrix1 = KeyedMatrix({key1:vector1})
        # Make a 1-row matrix with a different row key
        vector2 = makeVector(keys)
        matrix2 = KeyedMatrix({key2:vector2})
        # Make another 1-row matrix with a different row key
        vector3 = makeVector(keys)
        matrix3 = KeyedMatrix({key3:vector3})
        # Make trees out of these matrices
        tree1 = ProbabilityTree(matrix1)
        tree2 = ProbabilityTree()
        plane2 = makePlane(keys)
        tree2.branch(plane2,
                     ProbabilityTree(matrix2),
                     ProbabilityTree(matrix3))
        # Merge them
        tree = tree1.merge(tree2,KeyedMatrix.merge)
        self.assertFalse(tree.isLeaf())
        self.assertEqual(tree.split[0],plane2)
        for matrix in tree.leaves():
            self.assertTrue(matrix.has_key(key1))
            self.assertEqual(len(matrix),2)
            self.assertEqual(matrix[key1],vector1)
            if matrix.has_key(key2):
                self.assertEqual(matrix[key2],vector2)
            else:
                self.assertTrue(matrix.has_key(key3))
                self.assertEqual(matrix[key3],vector3)
if __name__ == '__main__':
    # Run as a script rather than imported: hand control to unittest's
    # command-line runner, which collects this module's TestCase classes.
    unittest.main()