
Source Code for Module teamwork.test.agent.testLightweight

import bz2
import copy      # used throughout for copying belief vectors and tables
import random
import time      # used to time policy projection in testProject
import unittest

import hotshot,hotshot.stats

from teamwork.agent.Entities import PsychEntity
from teamwork.multiagent.PsychAgents import PsychAgents
from teamwork.multiagent.GenericSociety import GenericSociety
from teamwork.multiagent.pwlSimulation import PWLSimulation
from teamwork.agent.lightweight import PWLAgent
from teamwork.examples.TigerScenario import *
from teamwork.policy.pwlTable import PWLTable
from teamwork.policy.pwlPolicy import PWLPolicy

# The following is needed to raise an exception for division by zero
# under some versions of Numeric Python
try:
    from numpy import seterr
    seterr(divide='raise')
except ImportError:
    pass

class probabilityIterator:
    """Iterates over two-world belief vectors, either stepping linearly
    across the probability simplex (additive) or halving toward its
    corners (multiplicative)."""
    def __init__(self,vector,factor=0.001,update='additive',epsilon=1e-8):
        self.value = copy.copy(vector)
        if update == 'multiplicative':
            self.value.getArray()[0] = 0.5
        else:
            self.value.getArray()[0] = 0.
        self.value.getArray()[1] = 1. - self.value.getArray()[0]
        self.factor = factor
        self.update = update
        self.epsilon = epsilon
        self.first = True

    def __iter__(self):
        return self

    def next(self):
        if self.first:
            self.first = False
            # Return a copy so the caller never aliases internal state
            return copy.copy(self.value)
        else:
            original = self.value.getArray()[0]
            if self.update == 'multiplicative':
                if original + self.epsilon > 0.5:
                    new = self.factor*(1.-original)
                    if new < self.epsilon:
                        raise StopIteration
                else:
                    new = 1. - original
            else:
                new = original + self.factor
                if new > 1.:
                    raise StopIteration
            self.value.getArray()[0] = new
            self.value.getArray()[1] = 1. - self.value.getArray()[0]
            return copy.copy(self.value)

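# A sketch of the sweeps this iterator yields over a two-world belief
# vector (assuming getArray() exposes [P(world 0), P(world 1)]):
#   additive (default, factor=0.001):
#       [0., 1.], [0.001, 0.999], ..., [1., 0.]
#   multiplicative (factor=0.5, as verifyPrune uses below):
#       [0.5, 0.5], [0.25, 0.75], [0.75, 0.25], [0.125, 0.875], ...
#       closing in on the simplex corners until the step falls below
#       epsilon, which is what lets low-probability rules be exercised.
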
class TestPWLAgent(unittest.TestCase):
    """Uses the tiger scenario to test PWL policy generation
    """
    profile = False

    def setUp(self):
        """Creates the instantiated scenario used for testing"""
        society = GenericSociety()
        society.importDict(classHierarchy)
        # Instantiate scenario
        agents = []
        agents.append(society.instantiate('Tiger','Tiger',PsychEntity))
        agents.append(society.instantiate('Dude','Player 1',PsychEntity))
        agents.append(society.instantiate('Dude','Player 2',PsychEntity))
        self.full = PsychAgents(agents)
        self.full.applyDefaults()
        # Make into PWL agents
        self.scenario = PWLSimulation(self.full)
        state = self.scenario.getState()
        # Initialize dynamics
        keyList = self.scenario.state.domain()[0].keys()
        keyList.sort()
        actions = {}
        for act1 in self.scenario['Player 1'].actions.getOptions():
            actions['Player 1'] = act1
            for act2 in self.scenario['Player 2'].actions.getOptions():
                actions['Player 2'] = act2
                actionKey = ' '.join(map(str,actions.values()))
                if act1[0]['type'] == 'Listen' and act2[0]['type'] == 'Listen':
                    dynamics = self.full.getDynamics({'Player 1':act1})
                elif act1[0]['type'] != 'Listen':
                    dynamics = self.full.getDynamics({'Player 1':act1})
                else:
                    dynamics = self.full.getDynamics({'Player 2':act2})
                tree = dynamics['state'].getTree()
                tree.unfreeze()
                tree.fill(keyList)
                tree.freeze()
                dynamics['state'].args['tree'] = tree
                self.full.dynamics[actionKey] = dynamics
        self.frozen = True
        # Find reachable worlds at 0-level
        self.worlds,self.lookup = self.full.generateWorlds()
        total = 0
        for world,state in self.worlds.items():
            self.assert_(isinstance(world,WorldKey))
            self.assert_(world['world'] == 0 or world['world'] == 1)
            key = StateKey({'entity':'Tiger','feature':'position'})
            self.assertAlmostEqual(state[key],float(world['world']),3)
            total += world['world']
        self.assertEqual(total,1)
        for name in ['Player 1','Player 2']:
            self.scenario[name].beliefs = KeyedVector()
            for key,world in self.worlds.items():
                self.scenario[name].beliefs[key] = self.scenario.state[world]
        state = KeyedVector()
        for key,world in self.worlds.items():
            state[key] = self.scenario.state[world]
        self.scenario.state = state
        self.transition = self.full.getDynamicsMatrix(self.worlds,self.lookup)
        self.reward = {}
        self.observations = {}
        for actions in self.full.generateActions():
            actionKey = ' '.join(map(str,actions.values()))
            # Transform reward function into matrix representation
            tree = rewardDict[actions.values()[0][0]['type']+\
                              actions.values()[1][0]['type']]
            vector = KeyedVector()
            for key,world in self.worlds.items():
                vector[key] = tree[world]
            vector.freeze()
            self.reward[actionKey] = vector
            # Transform observation probability into matrix representation
            tree = observationDict[actions.values()[0][0]['type']+\
                                   actions.values()[1][0]['type']]
            matrix = KeyedMatrix()
            for colKey,world in self.worlds.items():
                new = tree[world]*world
                for vector,prob in new.items():
                    for omega in Omega.values():
                        if vector[omega] > 0.5:
                            matrix.set(omega,colKey,prob)
            matrix.freeze()
            self.observations[actionKey] = matrix
        # Seed policies with a very naive one --- best joint action
        self.best = {'key':None,'value':None}
        for key,vector in self.reward.items():
            value = vector*state
            if self.best['key'] is None or value > self.best['value']:
                self.best['key'] = key
                self.best['value'] = value
        for actions in self.full.generateActions():
            if self.best['key'] == ' '.join(map(str,actions.values())):
                self.best['action'] = actions
                break
        else:
            raise UserWarning('Unknown joint action: %s' % (self.best['key']))
        # Set up null policy
        for name,option in self.best['action'].items():
            self.scenario[name].policy = PWLPolicy(self.scenario[name])
            table = PWLTable()
            table.rules = {0:option}
            table.values = {0:{}}
            actions = copy.copy(self.best['action'])
            for alternative in self.scenario[name].actions.getOptions():
                actions[name] = alternative
                actionKey = ' '.join(map(str,actions.values()))
                value = self.reward[actionKey]
                table.values[0][str(alternative)] = value
            self.scenario[name].policy.tables.append([table])

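    # Note on the representation built by setUp: each joint-action key maps
    # to a reward KeyedVector over worlds and an observation KeyedMatrix
    # with one row per observation symbol, so expected immediate reward is
    # the dot product self.reward[actionKey]*beliefs and an observation
    # likelihood is the row product self.observations[actionKey][omega]*state.
    # The verify* helpers below lean on exactly these two products.
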
    def generateState(self):
        """
        @return: a random state vector for the scenario of interest
        @rtype: L{KeyedVector}
        """
        state = copy.copy(self.scenario.state)
        state.getArray()[0] = random.random()
        state.getArray()[1] = 1. - state.getArray()[0]
        return state

    def testEstimator(self):
        agent = self.scenario['Player 1']
        other = self.scenario['Player 2']
        # Check that initial beliefs are correct
        for world in self.worlds.keys():
            self.assertAlmostEqual(agent.beliefs[world],0.5,3)
        agent.setEstimator(self.transition,self.observations)
        yrOption,exp = other.policy.execute()
        self.assertEqual(len(yrOption),1)
        self.assertEqual(yrOption[0]['actor'],other.name)
        self.assertEqual(yrOption[0]['type'],'Listen')
        self.assertEqual(yrOption[0]['object'],None)
        actions = {other.name: yrOption}
        self.verifyEstimator(agent,actions)

    def verifyEstimator(self,agent,actions):
        for myOption in agent.actions.getOptions():
            actions[agent.name] = myOption
            actionKey = ' '.join(map(str,actions.values()))
            for omega in self.observations[actionKey].keys():
                beliefs = agent.stateEstimator(agent.beliefs,myOption,omega)
                normalization = 0.
                for world,prob in agent.beliefs.items():
                    normalization += prob*self.observations[actionKey][omega][world]
                for world,prob in beliefs.items():
                    new = agent.beliefs[world]*self.observations[actionKey][omega][world]/normalization
                    self.assertAlmostEqual(prob,new,3)

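    # The check in verifyEstimator is Bayes' rule with the transition factor
    # dropped:
    #   b'(w) = b(w)*O(omega|w) / sum_w b(w)*O(omega|w)
    # This suffices here because testEstimator starts from the uniform
    # belief asserted above, which both of the tiger transitions (stay put
    # under Listen, reset when a door is opened) map back to itself.
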
    def testProject(self):
        agent = self.scenario['Player 1']
        other = self.scenario['Player 2']
        # Build R, the immediate reward table given the other agent's rules
        R = other.policy.getTable()
        for index in range(len(R)):
            yrAction = other.policy.getTable().rules[index]
            R.values[index].clear()
            for myAction in agent.actions.getOptions():
                actions = {agent.name:myAction,
                           other.name:yrAction}
                actionKey = ' '.join(map(str,actions.values()))
                R.values[index][str(myAction)] = copy.copy(self.reward[actionKey])
        # Set up the state estimators (SE)
        agent.setEstimator(self.transition,self.observations)
        other.setEstimator(self.transition,self.observations)
        # Horizon = 0
        V = R.getTable()
        V.rules.clear()
        self.verifyPolicy(V)
        policy = V.max()
        self.verifyMax(V,policy)
        self.verifyPolicy(policy)
        agent.policy.project(R,depth=1)
        self.verifyPolicy()
        print 'Horizon: 0'
        print agent.policy.getTable()
        table = agent.policy.getTable()
        self.verifyPrune(table)
        for horizon in range(1,8):
            print 'Horizon:',horizon
            self.deepVerify(R)
            if horizon == 7:
                prof = hotshot.Profile("tiger%d.prof" % (horizon))
                prof.start()
            start = time.time()
            agent.policy.project(R,depth=1)
            print time.time()-start
            if horizon == 7:
                prof.stop()
                prof.close()
                stats = hotshot.stats.load("tiger%d.prof" % (horizon))
                stats.strip_dirs()
                stats.sort_stats('time', 'calls')
                stats.print_stats(20)
            policy = agent.policy.getTable()
            print 'Rules:',len(policy.rules)
            self.verifyPolicy()
            policy.prune(rulesOnly=True)
            print policy

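    # Note: the class-level profile flag is not consulted; profiling above is
    # keyed directly off horizon == 7. The saved tiger7.prof can also be
    # inspected offline with the same hotshot.stats calls, e.g.
    #   hotshot.stats.load('tiger7.prof').print_stats(20)
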
    def deepVerify(self,R):
        agent = self.scenario['Player 1']
        other = self.scenario['Player 2']
        previous = agent.policy.getTable()
        # Compute new value function: V_a(b) = R(a,b) + ...
        V = R.getTable()
        Vstar = previous.star()
        self.verifyStar(Vstar,previous)
        old = Vstar.getTable()
        Vstar.prune()
        self.verifyStar(Vstar,previous)
        self.verifyPrune(Vstar,old)
        for omega,SE in agent.estimators.items():
            # ... + \sum_\omega V^*(SE_a(b,\omega))
            product = Vstar.__mul__(SE,debug=True)
##            product.prune(debug=False)
            self.verifyProduct(Vstar,omega,product)
            Vnew = V.__add__(product,debug=False)
            self.verifySum(V,product,Vnew)
            Vnew.prune()
            self.verifySum(V,product,Vnew)
            V = Vnew
        self.verifyPrune(V,inspect='values')
        self.verifyPolicy(V,Vstar)
        policy = V.max()
        self.verifyMax(V,policy)
        self.verifyPolicy(policy,Vstar)
        old = policy.getTable()
        policy.pruneAttributes()
        self.verifyPrune(policy,old)
        self.verifyPolicy(policy,Vstar)

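    # deepVerify exercises one step of POMDP value iteration in PWL-table
    # form:
    #   V_a(b) = R(a,b) + \sum_\omega V^*(SE_a(b,\omega))
    # with each observation's probability folded into the unnormalized
    # estimator matrix SE (verifyProduct below checks precisely that
    # weighted term). Pruning after the product and after each sum keeps
    # the rule set from blowing up as the horizon grows.
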
    def verifyMax(self,raw,maxed):
        for beliefs in probabilityIterator(self.scenario.state):
            # Find rules that trigger
            rawIndex = raw.index(beliefs)
            self.assert_(raw.values.has_key(rawIndex))
            maxIndex = maxed.index(beliefs)
            self.assert_(maxed.rules.has_key(maxIndex),
                         'Missing: (%d) %s' % (maxIndex,maxed.factorString(maxIndex)))
            self.assert_(maxed.values.has_key(maxIndex))
            # Iterate through actions
            best = None
            for option,V in raw.values[rawIndex].items():
                value = V*beliefs
                if best is None or value > best:
                    best = value
                self.assert_(maxed.values[maxIndex].has_key(option))
                maxValue = maxed.values[maxIndex][option]
                self.failUnlessEqual(V,maxValue,'Failed on rule %d\n%s' % \
                                     (rawIndex,maxed.index2factored(maxIndex)))
            option = maxed.rules[maxIndex]
            value = maxed.values[maxIndex][option]*beliefs
            self.assertAlmostEqual(best,value,5,
                                   'Failed on rule %d (%f vs. %f)' % (rawIndex,best,value))

    def verifyPolicy(self,policy=None,Vstar=None,debug=False):
        agent = self.scenario['Player 1']
        other = self.scenario['Player 2']
        yrOption,exp = other.policy.execute()
        if policy is None:
            policy = agent.policy.tables[-1][-1]
            if Vstar is None and len(agent.policy.tables[-1]) > 1:
                Vstar = agent.policy.tables[-1][-2].star()
                Vstar.prune()
        for beliefs in probabilityIterator(self.scenario.state):
            index = policy.index(beliefs)
            actions = {other.name:yrOption}
            best = None
            results = {}
            for option in agent.actions.getOptions():
                actions[agent.name] = option
                actionKey = ' '.join(map(str,actions.values()))
                results[str(option)] = self.reward[actionKey]*beliefs
                if Vstar:
                    # Project next time step
                    partial = {}
                    newState = self.transition[actionKey]*beliefs
                    for omega,O in self.observations[actionKey].items():
                        prob = O*newState
                        newBeliefs = agent.stateEstimator(beliefs,option,omega)
                        projection = Vstar[newBeliefs]*newBeliefs
                        results[str(option)] += prob*projection
                        partial[str(omega)] = {'probability':prob,
                                               'state':newState.getArray(),
                                               'beliefs':newBeliefs.getArray(),
                                               'value': projection,
                                               'V': Vstar[newBeliefs].getArray(),
                                               }
                if policy.values.has_key(index):
                    # Checking value function
                    value = policy.values[index][str(option)]*beliefs
                    self.assertAlmostEqual(value,results[str(option)],3)
            if policy.rules.has_key(index):
                best = str(policy[beliefs])
                for option,value in results.items():
                    self.failIf(results[best]+1e-10 < value)

    def verifySum(self,A,B,total):
        agent = self.scenario['Player 1']
        for index,table in total.values.items():
            for option in agent.actions.getOptions():
                self.assert_(table.has_key(str(option)),
                             'Only %d entries' % (len(table)))
        for beliefs in probabilityIterator(self.scenario.state):
            for option in agent.actions.getOptions():
                index = A.index(beliefs)
                realValue = A.values[index][str(option)]*beliefs
                index = B.index(beliefs)
                realValue += B.values[index][str(option)]*beliefs
                index = total.index(beliefs)
                testValue = total.values[index][str(option)]*beliefs
                self.assertAlmostEqual(realValue,testValue,5)

    def verifyProduct(self,V,omega,product,debug=False):
        agent = self.scenario['Player 1']
        other = self.scenario['Player 2']
        yrOption,exp = other.policy.execute()
        actions = {other.name:yrOption}
        for index,table in product.values.items():
            for option in agent.actions.getOptions():
                if not table.has_key(str(option)):
                    print 'Rule %d missing %s' % (index,str(option))
                    factors = product.index2factored(index)
                    for i in range(len(product.attributes)):
                        print product.attributes[i][0].getArray(),bool(factors[i])
                self.assert_(table.has_key(str(option)))
        for beliefs in probabilityIterator(self.scenario.state):
            self.assertAlmostEqual(sum(beliefs.getArray()),1.,5)
            if debug:
                print 'Beliefs:',beliefs.getArray(),omega
                beliefs.getArray()[0] = 0.5
                beliefs.getArray()[1] = 0.5
            index = product.index(beliefs)
            if debug:
                for attr in product.attributes:
                    print '\t',attr[0].getArray()
                print product.index2factored(index)
            self.assert_(product.values.has_key(index))
            for myOption in agent.actions.getOptions():
                self.assert_(product.values[index].has_key(str(myOption)))
                rhs = product.values[index][str(myOption)]
                value = rhs*beliefs
                # Compute value from first principles
                actions[agent.name] = myOption
                actionKey = ' '.join(map(str,actions.values()))
                new = copy.copy(beliefs)
                for oldWorld in new.keys():
                    new[oldWorld] = 0.
                for oldWorld,oldProb in beliefs.items():
                    for newWorld in new.keys():
                        newProb = oldProb*\
                                  self.transition[actionKey][newWorld][oldWorld]*\
                                  self.observations[actionKey][omega][newWorld]
                        new[newWorld] += newProb
                normalization = sum(new.getArray())
                for world,prob in new.items():
                    new[world] = prob/normalization
                # Compute value at next time step
                real = V[new]*new
                # Compute probability of receiving observation
                obs = self.observations[actionKey][omega]*beliefs
                real *= obs
                if debug:
                    print 'New:',new.getArray()
                    print 'V0:',V[new].getArray()
                    print 'V0(b):',V[new]*new
                    print 'O:',obs
                    print 'Real:',real
                    print 'Computed RHS:',rhs.getArray()
                    print 'Computed Product:',value
                self.assertAlmostEqual(value,real,5)

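    # First-principles form of the quantity verifyProduct checks, for the
    # two-world tiger case: the unnormalized successor belief is
    #   new(w') = \sum_w b(w) * T_a(w'|w) * O_a(\omega|w')
    # whose normalization constant is the observation probability; the
    # product-table entry must then match V(b')*b' scaled by that
    # likelihood term.
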
    def verifyStar(self,star,old):
        for beliefs in probabilityIterator(self.scenario.state):
            myReward = star[beliefs]*beliefs
            index = old.index(beliefs)
            for option,value in old.values[index].items():
                self.failIf(value*beliefs > myReward+0.000001)

    def verifyPrune(self,table,old=None,inspect='rules'):
        missing = {}
        if inspect == 'values':
            keyList = table.values.keys()
        else:
            keyList = table.rules.keys()
        for index in keyList:
            missing[index] = True
        for beliefs in probabilityIterator(self.scenario.state):
            if self._verifyPrune(table,beliefs,missing,old,inspect):
                break
        else:
            for beliefs in probabilityIterator(self.scenario.state,factor=0.5,update='multiplicative'):
                if self._verifyPrune(table,beliefs,missing,old,inspect):
                    break
            else:
                for rule in missing.keys():
                    print table.factorString(rule)
                self.fail('Able to hit only %d/%d rules' % \
                          (len(keyList)-len(missing),len(keyList)))

    def _verifyPrune(self,table,beliefs,missing,old=None,inspect='rules'):
        index = table.index(beliefs)
        if inspect == 'values':
            self.assert_(table.values.has_key(index))
        else:
            self.assert_(table.rules.has_key(index))
        if missing.has_key(index):
            del missing[index]
            if len(missing) == 0:
                return True
        if old:
            if inspect == 'values':
                self.assertEqual(table.values[index],
                                 old.values[old.index(beliefs)])
            else:
                self.assertEqual(table.rules[index],
                                 old.rules[old.index(beliefs)])
        return False

    def DONTtestXML(self):
        for agent in self.scenario.members():
            doc = agent.__xml__()
            # Parse into a fresh agent and compare against the original
            new = PWLAgent()
            new.parse(doc.documentElement)
            self.assertEqual(new.name,agent.name)
            self.assertEqual(new.beliefs,agent.beliefs)
            self.assertEqual(len(new.dynamics.keys()),
                             len(agent.dynamics.keys()))
            for action in agent.dynamics.keys():
                self.assert_(new.dynamics.has_key(action))
                self.assertEqual(new.dynamics[action],
                                 agent.dynamics[action])

if __name__ == '__main__':
    unittest.main()

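# Standard unittest entry point: running this module executes the whole
# suite, and a single test can be selected by name, e.g.
#   python testLightweight.py TestPWLAgent.testEstimator
# (hotshot, used by testProject, exists only under Python 2.)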