1 import bz2
2 import hotshot,hotshot.stats
3 import random
4 from teamwork.agent.Entities import PsychEntity
5 from teamwork.multiagent.PsychAgents import PsychAgents
6 from teamwork.multiagent.GenericSociety import GenericSociety
7 from teamwork.multiagent.pwlSimulation import PWLSimulation
8 from teamwork.agent.lightweight import PWLAgent
9 from teamwork.examples.TigerScenario import *
10 from teamwork.policy.pwlTable import PWLTable
11 from teamwork.policy.pwlPolicy import PWLPolicy
12
13 import unittest
14 import hotshot,hotshot.stats
15
16
17
# numpy is optional for this module.  When it can be imported, ask it to
# raise on floating-point division by zero (divide='raise') so such errors
# surface during the tests instead of passing silently.
try:
    from numpy import seterr
except ImportError:
    pass
else:
    seterr(divide='raise')
23
25 - def __init__(self,vector,factor=0.001,update='additive',epsilon=1e-8):
36
39
60
62 """Uses the tiger scenario to test PWL policy generation
63 """
64 profile = False
65
66
68 """Creates the instantiated scenario used for testing"""
69 society = GenericSociety()
70 society.importDict(classHierarchy)
71
72 agents = []
73 agents.append(society.instantiate('Tiger','Tiger',PsychEntity))
74 agents.append(society.instantiate('Dude','Player 1',PsychEntity))
75 agents.append(society.instantiate('Dude','Player 2',PsychEntity))
76 self.full = PsychAgents(agents)
77 self.full.applyDefaults()
78
79 self.scenario = PWLSimulation(self.full)
80 state = self.scenario.getState()
81
82 keyList = self.scenario.state.domain()[0].keys()
83 keyList.sort()
84 actions = {}
85 for act1 in self.scenario['Player 1'].actions.getOptions():
86 actions['Player 1'] = act1
87 for act2 in self.scenario['Player 2'].actions.getOptions():
88 actions['Player 2'] = act2
89 actionKey = ' '.join(map(str,actions.values()))
90 if act1[0]['type'] == 'Listen' and act2[0]['type'] == 'Listen':
91 dynamics = self.full.getDynamics({'Player 1':act1})
92 elif act1[0]['type'] != 'Listen':
93 dynamics = self.full.getDynamics({'Player 1':act1})
94 else:
95 dynamics = self.full.getDynamics({'Player 2':act2})
96 tree = dynamics['state'].getTree()
97 tree.unfreeze()
98 tree.fill(keyList)
99 tree.freeze()
100 dynamics['state'].args['tree'] = tree
101 self.full.dynamics[actionKey] = dynamics
102 self.frozen = True
103
104 self.worlds,self.lookup = self.full.generateWorlds()
105 total = 0
106 for world,state in self.worlds.items():
107 self.assert_(isinstance(world,WorldKey))
108 self.assert_(world['world'] == 0 or world['world'] == 1)
109 key = StateKey({'entity':'Tiger','feature':'position'})
110 self.assertAlmostEqual(state[key],float(world['world']),3)
111 total += world['world']
112 self.assertEqual(total,1)
113 for name in ['Player 1','Player 2']:
114 self.scenario[name].beliefs = KeyedVector()
115 for key,world in self.worlds.items():
116 self.scenario[name].beliefs[key] = self.scenario.state[world]
117 state = KeyedVector()
118 for key,world in self.worlds.items():
119 state[key] = self.scenario.state[world]
120 self.scenario.state = state
121 self.transition = self.full.getDynamicsMatrix(self.worlds,self.lookup)
122 self.reward = {}
123 self.observations = {}
124 for actions in self.full.generateActions():
125 actionKey = ' '.join(map(str,actions.values()))
126
127 tree = rewardDict[actions.values()[0][0]['type']+\
128 actions.values()[1][0]['type']]
129 vector = KeyedVector()
130 for key,world in self.worlds.items():
131 vector[key] = tree[world]
132 vector.freeze()
133 self.reward[actionKey] = vector
134
135 tree = observationDict[actions.values()[0][0]['type']+\
136 actions.values()[1][0]['type']]
137 matrix = KeyedMatrix()
138 for colKey,world in self.worlds.items():
139 new = tree[world]*world
140 for vector,prob in new.items():
141 for omega in Omega.values():
142 if vector[omega] > 0.5:
143 matrix.set(omega,colKey,prob)
144 matrix.freeze()
145 self.observations[actionKey] = matrix
146
147 self.best = {'key':None,'value':None}
148 for key,vector in self.reward.items():
149 value = vector*state
150 if self.best['key'] is None or value > self.best['value']:
151 self.best['key'] = key
152 self.best['value'] = value
153 for actions in self.full.generateActions():
154 if self.best['key'] == ' '.join(map(str,actions.values())):
155 self.best['action'] = actions
156 break
157 else:
158 raise UserWarning,'Unknown joint action: %s' % (self.best['key'])
159
160 for name,option in self.best['action'].items():
161 self.scenario[name].policy = PWLPolicy(self.scenario[name])
162 table = PWLTable()
163 table.rules = {0:option}
164 table.values = {0:{}}
165 actions = copy.copy(self.best['action'])
166 for alternative in self.scenario[name].actions.getOptions():
167 actions[name] = alternative
168 actionKey = ' '.join(map(str,actions.values()))
169 value = self.reward[actionKey]
170 table.values[0][str(alternative)] = value
171 self.scenario[name].policy.tables.append([table])
172
183
185 agent = self.scenario['Player 1']
186 other = self.scenario['Player 2']
187
188 for world in self.worlds.keys():
189 self.assertAlmostEqual(agent.beliefs[world],0.5,3)
190 agent.setEstimator(self.transition,self.observations)
191 yrOption,exp = other.policy.execute()
192 self.assertEqual(len(yrOption),1)
193 self.assertEqual(yrOption[0]['actor'],other.name)
194 self.assertEqual(yrOption[0]['type'],'Listen')
195 self.assertEqual(yrOption[0]['object'],None)
196 actions = {other.name: yrOption}
197 self.verifyEstimator(agent,actions)
198
200 for myOption in agent.actions.getOptions():
201 actions[agent.name] = myOption
202 actionKey = ' '.join(map(str,actions.values()))
203 for omega in self.observations[actionKey].keys():
204 beliefs = agent.stateEstimator(agent.beliefs,myOption,omega)
205 normalization = 0.
206 for world,prob in agent.beliefs.items():
207 normalization += prob*self.observations[actionKey][omega][world]
208 for world,prob in beliefs.items():
209 new = agent.beliefs[world]*self.observations[actionKey][omega][world]/normalization
210 self.assertAlmostEqual(prob,new,3)
211
213 agent = self.scenario['Player 1']
214 other = self.scenario['Player 2']
215
216 R = other.policy.getTable()
217 for index in range(len(R)):
218 yrAction = other.policy.getTable().rules[index]
219 R.values[index].clear()
220 for myAction in agent.actions.getOptions():
221 actions = {agent.name:myAction,
222 other.name:yrAction}
223 actionKey = ' '.join(map(str,actions.values()))
224 R.values[index][str(myAction)] = copy.copy(self.reward[actionKey])
225
226 agent.setEstimator(self.transition,self.observations)
227 other.setEstimator(self.transition,self.observations)
228
229 V = R.getTable()
230 V.rules.clear()
231 self.verifyPolicy(V)
232 policy = V.max()
233 self.verifyMax(V,policy)
234 self.verifyPolicy(policy)
235 agent.policy.project(R,depth=1)
236 self.verifyPolicy()
237 print 'Horizon: 0'
238 print agent.policy.getTable()
239 table = agent.policy.getTable()
240 self.verifyPrune(table)
241 for horizon in range(1,8):
242 print 'Horizon:',horizon
243 self.deepVerify(R)
244 if horizon == 7:
245 prof = hotshot.Profile("tiger%d.prof" % (horizon))
246 prof.start()
247 start = time.time()
248 agent.policy.project(R,depth=1)
249 print time.time()-start
250 if horizon == 7:
251 prof.stop()
252 prof.close()
253 stats = hotshot.stats.load("tiger%d.prof" % (horizon))
254 stats.strip_dirs()
255 stats.sort_stats('time', 'calls')
256 stats.print_stats(20)
257 policy = agent.policy.getTable()
258 print 'Rules:',len(policy.rules)
259 self.verifyPolicy()
260 policy.prune(rulesOnly=True)
261 print policy
262
264 agent = self.scenario['Player 1']
265 other = self.scenario['Player 2']
266 previous = agent.policy.getTable()
267
268 V = R.getTable()
269 Vstar = previous.star()
270 self.verifyStar(Vstar,previous)
271 old = Vstar.getTable()
272 Vstar.prune()
273 self.verifyStar(Vstar,previous)
274 self.verifyPrune(Vstar,old)
275 for omega,SE in agent.estimators.items():
276
277 product = Vstar.__mul__(SE,debug=True)
278
279 self.verifyProduct(Vstar,omega,product)
280 Vnew = V.__add__(product,debug=False)
281 self.verifySum(V,product,Vnew)
282 Vnew.prune()
283 self.verifySum(V,product,Vnew)
284 V = Vnew
285 self.verifyPrune(V,inspect='values')
286 self.verifyPolicy(V,Vstar)
287 policy = V.max()
288 self.verifyMax(V,policy)
289 self.verifyPolicy(policy,Vstar)
290 old = policy.getTable()
291 policy.pruneAttributes()
292 self.verifyPrune(policy,old)
293 self.verifyPolicy(policy,Vstar)
294
296 for beliefs in probabilityIterator(self.scenario.state):
297
298 rawIndex = raw.index(beliefs)
299 self.assert_(raw.values.has_key(rawIndex))
300 maxIndex = maxed.index(beliefs)
301 self.assert_(maxed.rules.has_key(maxIndex),'Missing: (%d) %s' % (maxIndex,maxed.factorString(maxIndex)))
302 self.assert_(maxed.values.has_key(maxIndex))
303
304 best = None
305 for option,V in raw.values[rawIndex].items():
306 value = V*beliefs
307 if best is None or value > best:
308 best = value
309 self.assert_(maxed.values[maxIndex].has_key(option))
310 maxValue = maxed.values[maxIndex][option]
311 self.failUnlessEqual(V,maxValue,'Failed on rule %d\n%s' % \
312 (rawIndex,maxed.index2factored(maxIndex)))
313 option = maxed.rules[maxIndex]
314 value = maxed.values[maxIndex][option]*beliefs
315 self.assertAlmostEqual(best,value,5,
316 'Failed on rule %d (%f vs. %f)' % (rawIndex,best,value))
317
319 agent = self.scenario['Player 1']
320 other = self.scenario['Player 2']
321 yrOption,exp = other.policy.execute()
322 if policy is None:
323 policy = agent.policy.tables[-1][-1]
324 if Vstar is None and len(agent.policy.tables[-1]) > 1:
325 Vstar = agent.policy.tables[-1][-2].star()
326 Vstar.prune()
327 for beliefs in probabilityIterator(self.scenario.state):
328 index = policy.index(beliefs)
329 actions = {other.name:yrOption}
330 best = None
331 results = {}
332 for option in agent.actions.getOptions():
333 actions[agent.name] = option
334 actionKey = ' '.join(map(str,actions.values()))
335 results[str(option)] = self.reward[actionKey]*beliefs
336 if Vstar:
337
338 partial = {}
339 newState = self.transition[actionKey]*beliefs
340 for omega,O in self.observations[actionKey].items():
341 prob = O*newState
342 newBeliefs = agent.stateEstimator(beliefs,option,omega)
343 projection = Vstar[newBeliefs]*newBeliefs
344 results[str(option)] += prob*projection
345 partial[str(omega)] = {'probability':prob,
346 'state':newState.getArray(),
347 'beliefs':newBeliefs.getArray(),
348 'value': projection,
349 'V': Vstar[newBeliefs].getArray(),
350 }
351 if policy.values.has_key(index):
352
353 value = policy.values[index][str(option)]*beliefs
354 self.assertAlmostEqual(value,results[str(option)],3)
355 if policy.rules.has_key(index):
356 best = str(policy[beliefs])
357 for option,value in results.items():
358 self.failIf(results[best]+1e-10 < value)
359
375
377 agent = self.scenario['Player 1']
378 other = self.scenario['Player 2']
379 yrOption,exp = other.policy.execute()
380 actions = {other.name:yrOption}
381 for index,table in product.values.items():
382 for option in agent.actions.getOptions():
383 if not table.has_key(str(option)):
384 print 'Rule %d missing %s' % (index,str(option))
385 factors = product.index2factored(index)
386 for i in range(len(product.attributes)):
387 print product.attributes[i][0].getArray(),bool(factors[i])
388 self.assert_(table.has_key(str(option)))
389 for beliefs in probabilityIterator(self.scenario.state):
390 self.assertAlmostEqual(sum(beliefs.getArray()),1.,5)
391 if debug:
392 print 'Beliefs:',beliefs.getArray(),omega
393 beliefs.getArray()[0] = 0.5
394 beliefs.getArray()[1] = 0.5
395 index = product.index(beliefs)
396 if debug:
397 for attr in product.attributes:
398 print '\t',attr[0].getArray()
399 print product.index2factored(index)
400 self.assert_(product.values.has_key(index))
401 for myOption in agent.actions.getOptions():
402 self.assert_(product.values[index].has_key(str(myOption)))
403 rhs = product.values[index][str(myOption)]
404 value = rhs*beliefs
405
406 actions[agent.name] = myOption
407 actionKey = ' '.join(map(str,actions.values()))
408 new = copy.copy(beliefs)
409 for oldWorld in new.keys():
410 new[oldWorld] = 0.
411 for oldWorld,oldProb in beliefs.items():
412 for newWorld in new.keys():
413 newProb = oldProb*self.transition[actionKey][newWorld][oldWorld]*self.observations[actionKey][omega][newWorld]
414 new[newWorld] += newProb
415 normalization = sum(new.getArray())
416 for world,prob in new.items():
417 new[world] = prob/normalization
418
419 real = V[new]*new
420
421 obs = self.observations[actionKey][omega]*beliefs
422 real *= obs
423 if debug:
424 print 'New:',new.getArray()
425 print 'V0:',V[new].getArray()
426 print 'V0(b):',V[new]*new
427 print 'O:',obs
428 print 'Real:',real
429 print 'Computed RHS:',rhs.getArray()
430 print 'Computed Product:',value
431 self.assertAlmostEqual(value,real,5)
432
439
441 missing = {}
442 if inspect == 'values':
443 keyList = table.values.keys()
444 else:
445 keyList = table.rules.keys()
446 for index in keyList:
447 missing[index] = True
448 for beliefs in probabilityIterator(self.scenario.state):
449 if self._verifyPrune(table,beliefs,missing,old,inspect):
450 break
451 else:
452 for beliefs in probabilityIterator(self.scenario.state,factor=0.5,update='multiplicative'):
453 if self._verifyPrune(table,beliefs,missing,old,inspect):
454 break
455 else:
456 for rule in missing.keys():
457 print table.factorString(rule)
458 self.fail('Able to hit only %d/%d rules' % \
459 (len(keyList)-len(missing),len(keyList)))
460
def _verifyPrune(self,table,beliefs,missing,old=None,inspect='rules'):
    """Check a single belief sample against a (possibly pruned) PWL table.

    Looks up the index that ``table`` assigns to ``beliefs`` and asserts that
    it is present in the inspected mapping: ``table.values`` when
    ``inspect == 'values'``, otherwise ``table.rules``.  When ``old`` is
    given, also asserts that this entry equals the entry ``old`` selects for
    the same beliefs.  Finally the index is crossed off ``missing``.

    :param table: table under test; must provide ``index(beliefs)`` and the
        ``rules``/``values`` mappings
    :param beliefs: belief sample to look up
    :param missing: dict whose keys are the indices not yet hit by any
        sample; mutated in place
    :param old: optional reference table whose entries must agree with
        ``table`` for the same beliefs
    :param inspect: ``'values'`` to inspect ``values``; anything else
        inspects ``rules``
    :returns: True once every index in ``missing`` has been hit, else False
    """
    if inspect == 'values':
        attribute = 'values'
    else:
        attribute = 'rules'
    entries = getattr(table, attribute)
    index = table.index(beliefs)
    self.assertTrue(index in entries,
                    'Index %s not in table.%s' % (index, attribute))
    # Compare against the reference table before the coverage bookkeeping,
    # so the sample that completes coverage is checked as well.  An explicit
    # None test is used because a table object may evaluate falsy (tables
    # support len()) and must still be compared.
    if old is not None:
        reference = getattr(old, attribute)
        self.assertEqual(entries[index], reference[old.index(beliefs)])
    missing.pop(index, None)
    return not missing
479
493
# When executed as a script, hand control to unittest's command-line runner,
# which runs the TestCase classes defined in this module.
if __name__ == '__main__':
    unittest.main()
496