Package teamwork :: Package test :: Package agent :: Module testRecursiveAgent
[hide private]
[frames | no frames]

Source Code for Module teamwork.test.agent.testRecursiveAgent

  1  from teamwork.agent.Entities import * 
  2  from teamwork.multiagent.sequential import * 
  3  from teamwork.multiagent.GenericSociety import * 
  4  from teamwork.agent.DefaultBased import createEntity 
  5  from teamwork.messages.PsychMessage import * 
  6  from teamwork.math.Interval import * 
  7  from teamwork.math.rules import applyRules,internalCheck 
  8  import copy 
  9  import random 
 10  import time 
 11  import unittest 
 12   
class TestRecursiveAgentPort(unittest.TestCase):
    """Belief-update tests run against the port scenario
    (C{teamwork.examples.InfoShare.PortClasses})
    """
    debug = None

    def setUp(self):
        """Instantiates the scenario and prepares the test action and message"""
        from teamwork.examples.InfoShare import PortClasses
        self.society = GenericSociety()
        self.society.importDict(PortClasses.classHierarchy)
        self.instances = {'World':1,'FederalAuthority':1,'FirstResponder':1,'Shipper':1}
        agents = []
        for className,total in self.instances.items():
            for index in range(total):
                # A class with a single instance keeps the bare class name
                if total > 1:
                    label = '%s%d' % (className,index)
                else:
                    label = className
                agents.append(createEntity(className,label,
                                           self.society,PsychEntity))
        self.entities = SequentialAgents(agents)
        self.entities.applyDefaults()
        self.entities.compileDynamics()
        # The test action is the first inspection option of the responder
        responder = self.entities['FirstResponder']
        inspection = None
        for option in responder.actions.getOptions():
            if option[0]['type'] == 'inspect':
                inspection = option
                break
        if inspection is None:
            # No inspection act found!
            self.fail()
        self.actions = {responder.name:inspection}
        # The test message asserts a distribution over the shipper's
        # containerDanger feature
        self.danger = Distribution({0.:0.9,
                                    0.7:0.1})
        factor = {'topic':'state',
                  'relation':'=','value':self.danger,
                  'lhs':['entities','Shipper','state',
                         'containerDanger']}
        self.msg = Message({'factors': [factor]})

    def testIncorporateMessage(self):
        """Tests the hypothetical belief update produced by a message"""
        responder = self.entities['FirstResponder']
        delta,exp = responder.incorporateMessage(self.msg)
        self.assertTrue(isinstance(delta['state'],Distribution))
        expectedKeys = responder.entities.getStateKeys().keys()
        expectedKeys.sort()
        # Collects the constant-column entries of the containerDanger rows,
        # weighted by the probability of the matrix they come from
        info = Distribution()
        for matrix,prob in delta['state'].items():
            self.assertTrue(isinstance(matrix,KeyedMatrix))
            self.assertEqual(matrix.rowKeys(),expectedKeys)
            for rowKey in matrix.rowKeys():
                row = matrix[rowKey]
                self.assertEqual(row.keys(),expectedKeys)
                dangerRow = isinstance(rowKey,StateKey) and \
                            rowKey['feature'] == 'containerDanger'
                for colKey in row.keys():
                    value = matrix[rowKey][colKey]
                    if dangerRow and colKey == keyConstant:
                        info[value] = prob
                        continue
                    # Every other row must be an identity row; a danger row
                    # must be zero outside its constant column
                    if not dangerRow and rowKey == colKey:
                        expected = 1.
                    else:
                        expected = 0.
                    self.assertAlmostEqual(value,expected,7)
        self.assertEqual(info,self.danger)

    def testMessage(self):
        """Tests the real belief update produced by test message"""
        believer = self.entities['FirstResponder']
        self.msg['_unobserved'] = ['Shipper','World']
        self.msg.forceAccept()
        result,delta = believer.stateEstimator(believer,
                                               {'FederalAuthority':[self.msg]})
        # Walk down the nested self-models for as long as they hold a belief
        # about the shipper
        while believer.hasBelief('Shipper'):
            belief = believer.getBelief('Shipper','containerDanger')
            self.assertEqual(belief,self.danger)
            believer = believer.getEntity(believer.name)
90
91 -class TestRecursiveAgentIraq(unittest.TestCase):
92 debug = False 93 profile = False 94
def setUp(self):
    """Instantiates the PSYOP scenario and looks up the two Turkomen actions"""
    from teamwork.examples.PSYOP import Society
    self.society = GenericSociety()
    self.society.importDict(Society.classHierarchy)
    self.instances = {'GeographicArea':1,
                      'US':1,
                      'Turkomen':1,
                      'Kurds':1,
                      }
    agents = []
    for className,total in self.instances.items():
        for index in range(total):
            # A class with a single instance keeps the bare class name
            if total > 1:
                label = '%s%d' % (className,index)
            else:
                label = className
            agent = createEntity(className,label,self.society,PsychEntity)
            agents.append(agent)
            if agent.name in ('Turkomen','Kurds'):
                agent.relationships = {'location':['GeographicArea']}
    self.entities = SequentialAgents(agents)
    self.entities.applyDefaults()
    self.entities.compileDynamics(profile=self.profile)
    # Set up the spec of the desired test action
    options = self.entities['Turkomen'].actions.getOptions()
    self.wait = None
    self.attack = None
    for option in options:
        verb = option[0]['type']
        if verb == 'attack':
            self.attack = option
        elif verb == 'wait':
            self.wait = option
    self.assertTrue(self.wait)
    self.assertTrue(self.attack)
    self.assertEqual(len(options),2)
132
def verifyState(self,entity):
    """Checks whether the state vector is well formed

    Every row of the entity's state distribution must contain exactly the
    features expected for that entity, and any key that is not a
    L{StateKey} must be a L{ConstantKey}.  The same check is then applied
    to every entity returned by C{entity.getEntityBeliefs()}.
    @param entity: the agent whose state is verified
    """
    featureTable = {
        'Turkomen': ['population','politicalPower','militaryPower',
                     'economicPower'],
        'Kurds': ['population','politicalPower','militaryPower',
                  'economicPower'],
        'US': ['population','militaryPower','economicPower'],
        'GeographicArea': ['oilInfrastructure'],
        }
    allowed = featureTable.get(entity.name)
    if allowed is None:
        self.fail('Unexpected entity: %s' % (entity.name))
    # NOTE: the original listing carries a disabled snippet that would also
    # allow the entity's _supportFeature/_trustFeature; it stays disabled
    # Make sure that the state contains only state for this agent
    for row in entity.state.domain():
        count = 0
        for key in row.keys():
            if isinstance(key,StateKey) and key['entity'] == entity.name:
                count += 1
        # Report the number that is actually compared (the entity's own
        # features), not the length of the whole vector
        self.assertEqual(count,len(allowed),
                         '%s has %d state features of its own, instead of %d'\
                         % (entity.ancestry(),count,len(allowed)))
        for key,value in row.items():
            if isinstance(key,StateKey):
                if key['entity'] == entity.name:
                    self.assertTrue(key['feature'] in allowed)
            else:
                self.assertTrue(isinstance(key,ConstantKey))
    # Make sure all features are covered.  The declared features are
    # compared once, on sorted copies, so that the list returned by
    # getStateFeatures() is never mutated and a distribution with several
    # rows does not fail on its second row
    self.assertEqual(sorted(entity.getStateFeatures()),sorted(allowed))
    for row in entity.state.domain():
        for feature in allowed:
            key = StateKey({'entity':entity.name,'feature':feature})
            self.assertTrue(row.has_key(key))
    # Descend recursively
    for other in entity.getEntityBeliefs():
        self.verifyState(other)
176
def verifyGlobalState(self,entities):
    """Checks whether the scenario state is well formed

    Every member must share the scenario's state object, and every vector
    must hold one key per member feature plus the constant key, nothing
    else.
    @param entities: the scenario to verify
    @type entities: L{PsychAgents}
    """
    sharedState = entities.getState()
    for vector in sharedState.domain():
        unclaimed = vector.keys()
        for member in entities.members():
            self.assertEqual(id(member.state),id(sharedState))
            for feature in member.getStateFeatures():
                featureKey = StateKey({'entity':member.name,
                                       'feature':feature})
                self.assertTrue(featureKey in unclaimed)
                unclaimed.remove(featureKey)
        # Only the constant key may be left over
        self.assertEqual(unclaimed,[keyConstant])
193
def testLocalState(self):
    """Checks each agent's state, nested scenario state and initial values"""
    # Number of state features each entity is expected to own
    featureCounts = {'Turkomen':4,'Kurds':4,'GeographicArea':1,'US':3}
    for entity in self.entities.members():
        self.verifyState(entity)
        if len(entity.entities) > 0:
            for vector in entity.entities.getState().domain():
                self.assertEqual(len(vector),13)
            self.verifyGlobalState(entity.entities)
        name = entity.name
        if name not in ('Turkomen','Kurds','GeographicArea','US'):
            self.fail()
        count = 0
        for key,value in entity.state.expectation().items():
            # Only the entity's own features are inspected
            if not isinstance(key,StateKey) or key['entity'] != name:
                continue
            count += 1
            if name == 'GeographicArea':
                if key['feature'] != 'oilInfrastructure':
                    self.fail()
                expected = .8
            elif name == 'Kurds':
                if key['feature'] == 'politicalPower':
                    expected = .4
                else:
                    expected = .2
            else:
                # Turkomen and US
                expected = .1
            self.assertAlmostEqual(value,expected,10)
        self.assertEqual(count,featureCounts[name])
246
def testState(self):
    """Checks the size and the key structure of the top-level scenario state"""
    scenario = self.entities
    for vector in scenario.getState().domain():
        self.assertEqual(len(vector),13)
    self.verifyGlobalState(scenario)
252
def testDynamics(self):
    """Checks the dynamics matrices produced by waiting and by attacking"""
    entity = self.entities['Turkomen']
    action = self.wait[0]

    def isOilKey(key):
        """True iff the key is the oil infrastructure of the geographic area"""
        return isinstance(key,StateKey) and \
               key['entity'] == 'GeographicArea' and \
               key['feature'] == 'oilInfrastructure'

    # Waiting must leave political power and population untouched
    tree = self.entities.getDynamics({entity.name:self.wait})['state'].getTree()
    for matrix in tree.leaves():
        for feature in entity.getStateFeatures():
            rowKey = StateKey({'feature':feature,'entity':entity.name})
            self.assertTrue(matrix.has_key(rowKey))
            row = matrix[rowKey]
            if feature not in ['politicalPower','population']:
                continue
            for colKey,value in row.items():
                if colKey == rowKey:
                    self.assertAlmostEqual(value,1.,8)
                else:
                    self.assertAlmostEqual(value,0.,8)
    # The area's own dynamics under attack: the oil row must carry a
    # negative weight on the attacker's military power
    dynamics = self.entities['GeographicArea'].getDynamics(self.attack[0],'oilInfrastructure')
    tree = dynamics.getTree()
    flag = False
    for matrix in tree.leaves():
        for rowKey in matrix.rowKeys():
            if not isOilKey(rowKey):
                continue
            for colKey in matrix.colKeys():
                value = matrix[rowKey][colKey]
                if not isinstance(colKey,StateKey):
                    continue
                if colKey == rowKey:
                    self.assertAlmostEqual(value,1.,8)
                elif colKey['entity'] == 'Turkomen' and \
                         colKey['feature'] == 'militaryPower':
                    if value < -0.05:
                        flag = True
                else:
                    self.assertAlmostEqual(value,0.,8)
            break
        else:
            # No oil infrastructure row in this leaf
            self.fail()
    self.assertTrue(flag)
    # The scenario-wide dynamics under attack: same expectation for the
    # oil row, every other row must be an identity row
    tree = self.entities.getDynamics({entity.name:self.attack})['state'].getTree()
    flag = False
    for matrix in tree.leaves():
        for rowKey in self.entities.getStateKeys().keys():
            oilRow = isOilKey(rowKey)
            for colKey in self.entities.getStateKeys().keys():
                value = matrix[rowKey][colKey]
                if colKey == rowKey:
                    self.assertAlmostEqual(value,1.,8)
                elif oilRow and isinstance(colKey,StateKey) and \
                         colKey['entity'] == entity.name and \
                         colKey['feature'] == 'militaryPower':
                    if value < -.05:
                        flag = True
                else:
                    self.assertAlmostEqual(value,0.,8)
    self.assertTrue(flag)
317
def verifyEffect(self,entities,action):
    """Check the result of the given action

    Each state feature is compared against the value held by the first
    generic class of the entity that declares it: C{oilInfrastructure}
    must have dropped by a tenth of the actor's military power, any other
    feature must be unchanged.
    @param entities: the scenario to which the action has been applied
    @type entities: L{PsychAgents}
    @param action: the action performed
    @type action: L{Action}
    """
    for member in entities.members():
        for feature in member.getStateFeatures():
            for className in member.classes:
                generic = member.hierarchy[className]
                if feature not in generic.getStateFeatures():
                    continue
                new = member.getState(feature).domain()[0]
                old = generic.getState(feature).domain()[0]
                if feature == 'oilInfrastructure':
                    attacker = entities[action['actor']]
                    power = attacker.getState('militaryPower').domain()[0]
                    diff = -.1*power
                    self.assertAlmostEqual(old+diff,new,10)
                else:
                    self.assertAlmostEqual(new,old,10)
                break
            else:
                # Should be able to check belief features, too, but not right now
                self.assertEqual(feature[0],'_')
346
def testAction(self):
    """Performs the attack and checks turn order, observation flags and effects"""
    name = 'Turkomen'
    entity = self.entities[name]
    key = StateKey({'entity':name,
                    'feature':self.entities.turnFeature})
    self.assertAlmostEqual(self.entities.order[key],0.25,3)

    def checkObservations(performed):
        """Only the flag of the performed action (if any) may be set
        @param performed: the action just taken by the actor, or C{None}
        """
        for obsKey,value in self.entities.getActions().items():
            if not isinstance(obsKey,ActionKey):
                self.assertTrue(isinstance(obsKey,ConstantKey))
                expected = 1.
            elif performed is not None and \
                     obsKey['entity'] == name and \
                     obsKey['type'] == performed['type'] and \
                     obsKey['object'] == performed['object']:
                expected = 1.
            else:
                expected = 0.
            self.assertAlmostEqual(value,expected,8)

    # Check initial observation flags
    checkObservations(None)
    # Perform action
    dynamics = self.entities.getDynamics({name:self.attack})['state'].getTree()[self.entities.getState()]
    self.assertTrue(self.entities['GeographicArea'].state is self.entities.state)
    result = self.entities.microstep([{'name':name,
                                       'choices':[self.attack]}])
    # Check final observation flags
    checkObservations(self.attack[0])
    # Check effects on states
    for other in self.entities.members():
        self.verifyState(other)
    self.assertTrue(self.entities['GeographicArea'].state is self.entities.state)
    attack = self.attack[0]
    self.verifyEffect(self.entities,attack)
    for entity in self.entities.members():
        beliefs = entity.entities
        self.verifyEffect(beliefs,attack)
        if len(beliefs) > 0:
            self.assertAlmostEqual(beliefs.order[key],0.,8)
        for other in beliefs.members():
            self.verifyEffect(other.entities,attack)
    # The actor has used up its turn; the US turn value is unchanged
    self.assertAlmostEqual(self.entities.order[key],0.,8)
    key = StateKey({'entity':'US','feature':self.entities.turnFeature})
    self.assertAlmostEqual(self.entities.order[key],0.25,3)
393
def testValueAttack(self):
    """Compares the projected value of attacking with the value actually accrued"""
    self.entities.compileDynamics()
    entity = self.entities['Turkomen']
    horizon = 3
    # Compute the projected value of the action over different horizons
    expected = []
    for depth in range(1,horizon+1):
        projection = entity.actionValue(self.attack,horizon=depth)
        expected.append(projection[0])
    # Compute the actual value of the action over different horizons
    self.entities.performAct({entity.name:self.attack})
    actual = [entity.applyGoals()]
    for step in range(horizon-1):
        actor = entity.entities.next()[0]['name'] # Assume only one entity
        decision,explanation = entity.getEntity(actor).applyPolicy()
        self.entities.performAct({actor:decision})
        # Running total of the reward accrued so far
        actual.append(actual[-1]+entity.applyGoals())
    for projected,observed in zip(expected,actual):
        self.assertEqual(projected,observed)
418
def testValueWait(self):
    """Compares the projected value of waiting with the value actually accrued"""
    self.entities.compileDynamics()
    entity = self.entities['Turkomen']
    horizon = 3
    # Compute the projected value of the action over different horizons
    expected = []
    for depth in range(1,horizon+1):
        projection = entity.actionValue(self.wait,depth)
        expected.append(projection[0])
    # Compute the actual value of the action over different horizons
    self.entities.performAct({entity.name:self.wait})
    actual = [entity.applyGoals()]
    for step in range(horizon-1):
        actor = entity.entities.next()[0]['name'] # Assume only one entity
        decision,explanation = entity.getEntity(actor).applyPolicy()
        self.entities.performAct({actor:decision})
        # Running total of the reward accrued so far
        actual.append(actual[-1]+entity.applyGoals())
    for projected,observed in zip(expected,actual):
        self.assertEqual(projected,observed)
443
def testPolicy(self):
    """Checks policy ownership, the chosen action and the projected turn order"""
    # Verify the entity attribute on everyone's policy
    pending = self.entities.activeMembers()
    while len(pending) > 0:
        member = pending.pop()
        self.assertTrue(member.policy.entity is member)
        pending += member.entities.activeMembers()
    # Verify the result of applying one agent's policy
    entity = self.entities['Turkomen']
    self.entities.compileDynamics()
    values = {}
    for option in entity.actions.getOptions():
        values[option[0]],explanation = entity.actionValue(option,3)
    action,explanation = entity.applyPolicy()
    chosen = action[0]
    for option,value in values.items():
        if chosen != option:
            self.assertTrue(float(values[chosen]) > float(value),
                            '%s is preferred over %s, although value %s does not exceed %s' % (chosen,option,values[chosen],value))
    # Test turn order
    ranking = []
    for key,value in self.entities.order.items():
        if isinstance(key,StateKey):
            ranking.append((value,key['entity']))
    ranking.sort()
    ranking.reverse()
    order = [actorName for turnValue,actorName in ranking]
    breakdown = explanation['options'][str(action)]['breakdown']
    # The deciding entity moves first, whatever its place in the order
    topOrder = order[:]
    topOrder.remove(entity.name)
    topOrder.insert(0,entity.name)
    for t0,step in enumerate(breakdown):
        self.assertEqual(len(step['action']),1)
        actor = step['action'].keys()[0]
        self.assertEqual(actor,topOrder[0])
        topOrder.remove(actor)
        if len(topOrder) == 0:
            topOrder = order[:]
        decision = step['action'][actor]
        subBreakdown = step['breakdown'][actor]
        if t0 == 0:
            self.assertTrue(subBreakdown.has_key('forced'))
            self.assertTrue(subBreakdown['forced'])
            continue
        # Later steps carry a nested projection of their own
        botOrder = topOrder[:]+order
        botOrder.insert(0,actor)
        subBreakdown = subBreakdown['options'][str(decision)]
        self.assertTrue(subBreakdown.has_key('breakdown'))
        subBreakdown = subBreakdown['breakdown']
        for t1,subStep in enumerate(subBreakdown):
            self.assertEqual(len(subStep['action']),1)
            subActor = subStep['action'].keys()[0]
            self.assertEqual(subActor,botOrder[t1])
498
500 parent = self.entities['Turkomen'].getEntity('Kurds') 501 entity = parent.getEntity('Turkomen') 502 tree = entity.policy.compileTree() 503 state = copy.copy(parent.entities.getState()) 504 for index in xrange(pow(self.granularity,len(self.keyKeys))): 505 for key in self.keyKeys: 506 state.domain()[0][key] = float(index % self.granularity)/float(self.granularity) + self.base 507 index /= self.granularity 508 decision = tree[state].domain()[0] 509 best = entity.policy.getValueTree(decision)*state 510 for action in entity.actions.getOptions(): 511 if str(action) != decision: 512 # Relies on having passed testValueFunction 513 value = entity.policy.getValueTree(action)*state 514 self.assert_(float(best) >= float(value))
515
def DONTtestRuleReplacement(self):
    """Disabled test: substituting dynamics rules for the policy's action
    values must yield the same matrix as applying policy and dynamics in turn
    """
    target = '_value'
    entity = self.entities['Turkomen']
    entity.horizon = 1
    entity.policy.depth = entity.horizon
    rules,attributes,values = entity.policy.compileRules()
    # Test dynamics extension of policy
    dynamicsRules = {}
    dynamicsAttrs = {}
    dynamicsValues = {}
    rawValues = {}
    rawAttributes = copy.copy(attributes)
    for key,action in values.items():
        tree = entity.entities.getDynamics({entity.name:action})['state'].getTree()
        attrs = {}
        vals = {}
        dynamicsAttrs[key] = attrs
        dynamicsValues[key] = vals
        dynamicsRules[key] = tree.makeRules(attrs,vals)
        rawAttributes.update(attrs)
        rawValues.update(vals)
    self.assertEqual(dynamicsRules.keys(),values.keys())

    def checkAttributes(ruleSet):
        """Every condition used by a rule must be a known attribute"""
        for rule in ruleSet:
            for attr,value in rule.items():
                if attr != target:
                    self.assertTrue(rawAttributes.has_key(attr))

    checkAttributes(rules)
    for ruleSet in dynamicsRules.values():
        checkAttributes(ruleSet)
    rawRules = replaceValues(rules,rawAttributes,values,dynamicsRules,rawValues)
    state = copy.copy(self.entities.getState())
    for index in xrange(pow(self.granularity,len(self.keyKeys))):
        # Decode the index into one grid value per swept key
        for key in self.keyKeys:
            state.domain()[0][key] = float(index % self.granularity)/float(self.granularity) + self.base
            index /= self.granularity
        decision = str(applyRules(state,rules,attributes,values,target,True))
        rulesMatrix = applyRules(state,dynamicsRules[decision],
                                 dynamicsAttrs[decision],
                                 dynamicsValues[decision],target,True)
        raw = applyRules(state,rawRules,rawAttributes,rawValues,target,True)
        self.assertEqual(raw.getArray(),rulesMatrix.getArray())
        self.assertEqual(len(raw),len(rulesMatrix))
        self.assertEqual(len(raw),len(state.domainKeys()))
        self.assertEqual(len(raw.colKeys()),len(rulesMatrix.colKeys()))
        self.assertEqual(len(raw.colKeys()),len(state.domainKeys()))
        for row in rulesMatrix.keys():
            self.assertTrue(raw.has_key(row))
            rawRow = raw[row]
            dynRow = rulesMatrix[row]
            for col in dynRow.keys():
                self.assertTrue(dynRow.has_key(col))
                self.assertTrue(rawRow.has_key(col))
                self.assertAlmostEqual(dynRow[col],rawRow[col],8)
570
572 entity = self.entities['Turkomen'] 573 entity.horizon = 1 574 entity.policy.depth = entity.horizon 575 orig = entity.policy.compileRules() 576 for index in range(1000): 577 new = entity.policy.compileRules() 578 self.assertEqual(orig,new)
579
def DONTtestPolicyPruning(self):
    """Disabled test: rules pruned by C{internalCheck} must stay equivalent
    to the unpruned policy dynamics rules
    """
    target = '_value'
    # The 'Kurds' entity as obtained from the Turkomen agent
    entity = self.entities['Turkomen'].getEntity('Kurds')
    entity.horizon = 1
    entity.policy.depth = entity.horizon
    oldRules,oldAttrs,oldVals = entity.policy.getDynamics()
    self.verifyRuleConsistency(oldRules,oldAttrs,oldVals)
    print(len(oldRules))
    # Prune a copy so that the original rules remain available for comparison
    candidates = oldRules[:]
    newRules,newAttrs = internalCheck(candidates,oldAttrs,oldVals,target)
    print(len(newRules))
    self.verifyRuleEquality(oldRules,oldAttrs,oldVals,
                            newRules,newAttrs,oldVals)
591
593 entity = self.entities['Turkomen'] 594 goals = entity.getGoalTree().getValue() 595 other = entity.getEntity('US') 596 other.horizon = 1 597 other.policy.depth = other.horizon 598 start = time.time() 599 USRules,USAttrs,USVals = other.policy.getDynamics() 600 print other.ancestry(),time.time()-start 601 self.verifyRuleConsistency(USRules,USAttrs,USVals) 602 other = entity.getEntity('Kurds') 603 other.horizon = 1 604 other.policy.depth = other.horizon 605 start = time.time() 606 kRules,kAttrs,kVals = other.policy.getDynamics() 607 print other.ancestry(),time.time()-start 608 self.verifyRuleConsistency(kRules,kAttrs,kVals) 609 state0 = copy.copy(self.entities.getState()) 610 for action in entity.actions.getOptions(): 611 print action 612 # Cumulative storage of attributes and values 613 policyAttributes = {'_value':True} 614 policyValues = {} 615 dynamicsAttributes = {'_value':True} 616 dynamicsValues = {} 617 618 # Step 1: Turkomen perform given action 619 actionDict = {entity.name:action} 620 step1Tree = entity.entities.getDynamics(actionDict)['state'].getTree() 621 step1Attrs = {} 622 step1Vals = {} 623 start = time.time() 624 step1Rules = step1Tree.makeRules(step1Attrs,step1Vals) 625 print 'Making rules:',time.time()-start 626 627 # Update cumulative dynamics 628 dynamicsAttributes.update(step1Attrs) 629 dynamicsValues.update(step1Vals) 630 dynamicsRules = step1Rules 631 632 # Compute current total reward 633 totalVals = copy.copy(dynamicsValues) 634 start = time.time() 635 totalRules = mapValues(dynamicsRules,totalVals,lambda v:goals*v) 636 print 'Evaluating:',time.time()-start 637 policyValues.update(totalVals) 638 policyAttributes.update(dynamicsAttributes) 639 policyRules = totalRules 640 641 # Step 2: US follows policy 642 dynamicsAttributes.update(USAttrs) 643 dynamicsValues.update(USVals) 644 start = time.time() 645 ## dynamicsRules,dynamicsAttributes = self.detailedMultiply(USRules,dynamicsRules, 646 ## dynamicsAttributes,dynamicsValues) 647 
dynamicsRules,dynamicsAttributes = multiplyRules(USRules,dynamicsRules, 648 dynamicsAttributes, 649 dynamicsValues) 650 print 'Multiplying:',time.time()-start 651 print len(dynamicsRules) 652 653 # Update running total of reward 654 totalVals = copy.copy(dynamicsValues) 655 start = time.time() 656 totalRules = mapValues(dynamicsRules,totalVals,lambda v:goals*v) 657 print 'Evaluating:',time.time()-start 658 policyValues.update(totalVals) 659 policyAttributes.update(dynamicsAttributes) 660 policyRules = addRules(policyRules,totalRules,policyAttributes,policyValues) 661 662 # Step 3: Kurds follow Policy 663 dynamicsAttributes.update(kAttrs) 664 dynamicsValues.update(kVals) 665 start = time.time() 666 ## dynamicsRules,dynamicsAttributes = self.detailedMultiply(kRules,dynamicsRules, 667 ## dynamicsAttributes,dynamicsValues) 668 dynamicsRules,dynamicsAttributes = multiplyRules(kRules,dynamicsRules, 669 dynamicsAttributes, 670 dynamicsValues) 671 print 'Multiplying:',time.time()-start 672 print len(dynamicsRules) 673 674 # Update running total of reward 675 totalVals = copy.copy(dynamicsValues) 676 start = time.time() 677 totalRules = mapValues(dynamicsRules,totalVals,lambda v:goals*v) 678 print 'Evaluating:',time.time()-start 679 policyValues.update(totalVals) 680 policyAttributes.update(dynamicsAttributes) 681 policyRules = addRules(policyRules,totalRules,policyAttributes,policyValues) 682 683 print 'Running %d tests' % (pow(self.granularity,len(self.keyKeys))) 684 for index in xrange(pow(self.granularity,len(self.keyKeys))): 685 for key in self.keyKeys: 686 state0.domain()[0][key] = float(index % self.granularity)/float(self.granularity) + self.base 687 index /= self.granularity 688 ## print 0,state0.domain()[0].simpleText() 689 real0 = state0 690 rhs1 = applyRules(state0,step1Rules,step1Attrs,step1Vals,'_value',True) 691 product = rhs1 692 total = goals * rhs1 693 state1 = Distribution({rhs1*state0.domain()[0]:1.}) 694 real1 = step1Tree*real0 695 
self.assertEqual(state1,real1) 696 reward = goals*real1.domain()[0] 697 self.assertAlmostEqual(reward,total*real0.domain()[0],8) 698 ## print 1,state1.domain()[0].simpleText() 699 rhsUS = applyRules(state1,USRules,USAttrs,USVals,'_value',True) 700 product = rhsUS*product 701 total += goals*product 702 decision,exp = entity.getEntity('US').policy.execute({'state':real1}) 703 step2Tree = self.entities.getDynamics({'US':decision})['state'].getTree() 704 real2 = step2Tree*real1 705 state2 = Distribution({rhsUS*state1.domain()[0]:1.}) 706 self.assertEqual(state2,real2) 707 reward += goals*real2.domain()[0] 708 self.assertAlmostEqual(reward,total*real0.domain()[0],8) 709 ## print 2,state2.domain()[0].simpleText() 710 rhsKurds = applyRules(state2,kRules,kAttrs,kVals,'_value',True) 711 product = rhsKurds*product 712 total += goals*product 713 ## print result[0] 714 state3 = Distribution({rhsKurds*state2.domain()[0]:1.}) 715 decision,exp = entity.getEntity('Kurds').policy.execute({'state':real1}) 716 step3Tree = self.entities.getDynamics({'Kurds':decision})['state'].getTree() 717 real3 = step3Tree*real2 718 self.assertEqual(state3,real3) 719 reward += goals*real3.domain()[0] 720 self.assertAlmostEqual(reward,total*real0.domain()[0],8) 721 ## print 3,state3.domain()[0].simpleText() 722 projection = applyRules(state0,dynamicsRules,dynamicsAttributes, 723 dynamicsValues,'_value',True) 724 ## print 'Alleged:',(projection*state0.domain()[0]).simpleText() 725 self.assertEqual(product,projection, 726 'Difference found between %s and %s' % \ 727 (product.simpleText(),projection.simpleText())) 728 729 rhsTotal = applyRules(state0,policyRules,policyAttributes, 730 policyValues,'_value',True) 731 self.assertEqual(total,rhsTotal,'Difference found between %s and %s' % \ 732 (total.simpleText(),rhsTotal.simpleText()))
733
# detailedMultiply(self,set1,set2,attributes,values) -> (newRules,newAttributes)
#
# Assertion-instrumented product of two rule sets.  For every pair of rules
# (new from set1, old from set2):
#   - matrix = values[old['_value']]; every non-target, non-None condition
#     of `new` is projected through it (plane.weights * matrix, same
#     threshold) and compared against the conditions of `old` via
#     plane.compare(), whose results are tested against 'equal', 'greater'
#     and 'less';
#   - the pair is skipped (`continue`) as soon as `inconsistent` is set;
#   - otherwise the projected conditions are cross-checked with assertions
#     and the product values[new['_value']] * values[old['_value']] is stored
#     in `values` under its str() label, which also becomes the '_value' of
#     the combined rule appended to newRules.
# Finally every rule is padded with None for each attribute it lacks and
# newAttributes['_value'] is set to True.
#
# Side effect: `values` is modified in place.  `lhsKeys` is built with
# filter() and later appended to, which relies on Python 2 returning a list.
# NOTE(review): the only visible call sites are commented out next to calls
# to multiplyRules, so this is presumably a debugging stand-in for it - confirm.
# NOTE(review): this listing is flattened; the nesting of the break/else/elif
# branches below cannot be recovered from it - consult the original module
# before editing.
734 - def detailedMultiply(self,set1,set2,attributes,values):
735 newRules = [] 736 target = '_value' 737 lhsKeys = filter(lambda k:k!=target,attributes.keys()) 738 self.assertEqual(len(lhsKeys),len(attributes)-1) 739 newAttributes = {} 740 for new in set1: 741 for old in set2: 742 matrix = values[old[target]] 743 inconsistent = False 744 ## print 'projecting by:',matrix.simpleText() 745 projectedNew = {} 746 for newAttr,newValue in new.items(): 747 self.assert_(attributes.has_key(newAttr)) 748 if newAttr != target and newValue is not None: 749 ## print '\t',newAttr,newValue 750 newPlane = attributes[newAttr] 751 weights = newPlane.weights * matrix 752 newPlane = newPlane.__class__(weights, 753 newPlane.threshold) 754 label = newPlane.simpleText() 755 for oldAttr,oldValue in old.items(): 756 if oldAttr != target and oldValue is not None: 757 oldPlane = attributes[oldAttr] 758 result = oldPlane.compare(newPlane) 759 if result == 'equal': 760 label = oldAttr 761 newAttributes[oldAttr] = attributes[oldAttr] 762 if projectedNew.has_key(oldAttr): 763 if projectedNew[oldAttr] is None: 764 projectedNew[oldAttr] = newValue 765 elif newValue is not None and \ 766 projectedNew[oldAttr] != newValue: 767 inconsistent = True 768 else: 769 projectedNew[oldAttr] = newValue 770 if projectedNew[oldAttr] is None: 771 # Old rule takes precedence 772 projectedNew[oldAttr] = oldValue 773 elif projectedNew[oldAttr] != oldValue: 774 # Mismatch 775 inconsistent = True 776 break 777 elif result == 'greater' and oldValue == True: 778 # newAttr is guaranteed to be True 779 if newValue == False: 780 inconsistent = True 781 break 782 elif result == 'less' and oldValue == False: 783 # newAttr is guaranteed to be False 784 if newValue == True: 785 inconsistent = True 786 break 787 else: 788 if newAttributes.has_key(label): 789 self.assert_(not projectedNew.has_key(label)) 790 projectedNew[label] = newValue 791 ## print 'new ->',label,newValue 792 elif attributes.has_key(label): 793 self.assert_(isinstance(newValue,bool)) 794 if old.has_key(label) and 
old[label] is not None and old[label] != newValue: 795 # Mismatch 796 inconsistent = True 797 break 798 self.assert_(not projectedNew.has_key(label)) 799 projectedNew[label] = newValue 800 newAttributes[label] = attributes[label] 801 ## print 'old ->',label,newValue 802 else: 803 # No matching plane found 804 ## print 'New attribute:',label 805 newAttributes[label] = newPlane 806 lhsKeys.append(label) 807 projectedNew[label] = newValue 808 if inconsistent: 809 break 810 if inconsistent: 811 # These two rules are incompatible 812 continue 813 for oldAttr,oldValue in old.items(): 814 if oldAttr == target or oldValue is None: 815 pass 816 elif projectedNew.has_key(oldAttr): 817 self.assertEqual(oldValue,projectedNew[oldAttr]) 818 else: 819 newAttributes[oldAttr] = attributes[oldAttr] 820 projectedNew[oldAttr] = oldValue 821 for key in projectedNew.keys(): 822 self.assert_(newAttributes.has_key(key) or attributes.has_key(key),'Key %s is missing' % (str(key))) 823 # Verify projection 824 for attr,value in new.items(): 825 if attr == target or value is None: 826 continue 827 oldPlane = attributes[attr] 828 plane = oldPlane.__class__(oldPlane.weights*matrix, 829 oldPlane.threshold) 830 for key in projectedNew.keys(): 831 if key != target: 832 result = newAttributes[key].compare(plane) 833 if result == 'equal': 834 break 835 elif result == 'less' and projectedNew[key] == False and new[attr] == False: 836 break 837 elif result == 'greater' and projectedNew[key] == True and new[attr] == True: 838 break 839 else: 840 self.fail() 841 self.assert_(newAttributes.has_key(key)) 842 self.assert_(key in lhsKeys) 843 self.assertEqual(projectedNew[key],value, 844 'Deviation (%s vs. 
%s) on %s (to %s)' % \ 845 (projectedNew[key],value,attr,key)) 846 # Compute new RHS 847 oldValue = values[old[target]] 848 newValue = values[new[target]] 849 product = newValue * oldValue 850 label = str(product) 851 projectedNew[target] = label 852 if not values.has_key(label): 853 values[label] = product 854 newRules.append(projectedNew) 855 for rule in newRules: 856 for attr in newAttributes.keys(): 857 if not rule.has_key(attr): 858 rule[attr] = None 859 newAttributes[target] = True 860 return newRules,newAttributes
861
def DONTtestValueFunction(self):
    """Disabled test: verifies the value function of every active agent"""
    agents = self.entities.activeMembers()
    for agent in agents:
        self.verifyValueFunction(agent)
865
def verifyValueFunction(self,entity):
    """Checks three formulations of an agent's one-step value function

    For every action and every swept state, the value obtained from the
    raw goals and dynamics, from the policy's value tree, from the rules
    extracted from that tree and from the tree rebuilt out of those rules
    must all agree.

    Relies on C{self.granularity} and C{self.keyKeys}, which the visible
    C{setUp} does not define.
    @param entity: the agent whose value function is verified (its
        C{horizon} and policy C{depth} are set to 1 as a side effect)
    """
    entity.horizon = 1
    entity.policy.depth = entity.horizon
    target = '_value'
    rules = {}
    trees = {}
    # The attribute and value tables filled in by makeRules belong to one
    # action's rule set, so they are kept per action; a single pair of
    # variables would be overwritten on each pass and leave every action
    # evaluated against the last action's tables
    attributeTable = {}
    valueTable = {}
    oilKey = StateKey({'entity':'GeographicArea','feature':'oilInfrastructure'})
    for action in entity.actions.getOptions():
        label = str(action)
        tree = entity.policy.getValueTree(action)
        attributes = {target:True}
        values = {}
        rules[label] = tree.makeRules(attributes,values)
        attributeTable[label] = attributes
        valueTable[label] = values
        trees[label] = tree.__class__().fromRules(rules[label],
                                                  attributes,values)
        # Verify that all values are mutually exclusive
        entries = list(values.items())
        for index1 in range(len(entries)-1):
            key1,vector1 = entries[index1]
            for key2,vector2 in entries[index1+1:]:
                self.assertNotEqual(key1,key2)
                self.assertNotEqual(vector1[oilKey],vector2[oilKey])
    state = copy.copy(self.entities.getState())
    goals = entity.getGoalVector()['state']
    goals.fill(state.domain()[0].keys())
    for index in xrange(pow(self.granularity,len(self.keyKeys))):
        # Decode the index into one grid value per swept key
        for key in self.keyKeys:
            state.domain()[0][key] = float(index % self.granularity)/float(self.granularity)
            index /= self.granularity
        for action in entity.actions.getOptions():
            label = str(action)
            dynamics = self.entities.getDynamics({entity.name:action})
            matrix = dynamics['state'].getTree()[state].domain()[0]
            rawValue = (goals*matrix*state).domain()[0]
            tree = entity.policy.getValueTree(action)
            value = (tree*state).domain()[0]
            self.assertAlmostEqual(rawValue,value,8)
            # Test rule-based formulation of value function
            attributes = attributeTable[label]
            values = valueTable[label]
            rhs = None
            for rule in rules[label]:
                for attr,truth in rule.items():
                    if attr != target and truth is not None:
                        plane = attributes[attr]
                        if plane.test(state.domain()[0]) != truth:
                            break
                else:
                    # Every condition of this rule holds in the state
                    rhs = values[rule[target]]
            self.assertTrue(rhs is not None)
            ruleValue = rhs*(state.domain()[0])
            self.assertAlmostEqual(ruleValue,value,8)
            ruleValue = (trees[label]*state).domain()[0]
            self.assertAlmostEqual(ruleValue,value,8)
917
def DONTtestCompile(self):
    """Disabled test: compiling the policy must not change the chosen action

    Also checks that the goal tree and the action tree of the attack are
    expressed over the same set of keys.  When C{self.profile} is set, the
    compilation is profiled with C{hotshot} and the statistics are printed.
    """
    entity = self.entities['Turkomen']
    goalSet = {}
    firstLeaf = True
    for leaf in entity.getGoalTree().leaves():
        for key in leaf.keys():
            # The first leaf defines the key set; later leaves must match it
            if firstLeaf:
                goalSet[key] = True
            else:
                self.assertTrue(key in goalSet)
            self.assertTrue(entity.entities.getStateKeys().has_key(key),
                            'Extraneous goal key: %s' % (str(key)))
        for key in entity.state.domainKeys():
            self.assertTrue(leaf.has_key(key),
                            'Missing goal key: %s' % (str(key)))
        if firstLeaf:
            firstLeaf = False
        else:
            self.assertEqual(len(leaf),len(goalSet))
    goalKeys = goalSet.keys()
    goalKeys.sort()
    uncompiled,exp = entity.applyPolicy()
    if self.profile:
        import hotshot,hotshot.stats
        filename = '/tmp/stats'
        prof = hotshot.Profile(filename)
        prof.start()
    tree = entity.policy.getActionTree(self.attack)
    for leaf in tree.leaves():
        for key in leaf.rowKeys():
            if key not in goalKeys:
                print('%s %s' % ('extra key:',key))
        self.assertEqual(len(leaf.rowKeys()),len(goalKeys))
        self.assertEqual(len(leaf.colKeys()),len(goalKeys))
        for key in leaf.rowKeys():
            self.assertTrue(key in goalKeys)
        for key in leaf.colKeys():
            self.assertTrue(key in goalKeys)
    entity.policy.compile()
    if self.profile:
        prof.stop()
        prof.close()
        print('loading stats...')
        stats = hotshot.stats.load(filename)
        stats.strip_dirs()
        stats.sort_stats('time', 'calls')
        stats.print_stats()
    compiled,exp = entity.applyPolicy()
    self.assertEqual(uncompiled,compiled)
# Disabled remnants that trail DONTtestCompile in the original listing:
##         tree2 = entity.policy.getActionTree(entity,self.wait)
##         tree = tree1-tree2
##         print tree.simpleText()
##         self.entities.compile()

# Run every test case of this module when executed as a script
if __name__ == '__main__':
    unittest.main()