Package teamwork :: Package test :: Package policy :: Module testPWLPolicy
[hide private]
[frames] | no frames]

Source Code for Module teamwork.test.policy.testPWLPolicy

  1  from teamwork.agent.Entities import * 
  2  from teamwork.multiagent.sequential import * 
  3  from teamwork.multiagent.GenericSociety import * 
  4  from teamwork.agent.DefaultBased import createEntity 
  5  from teamwork.math.rules import internalCheck,mergeAttributes 
  6   
  7  import time 
  8  import unittest 
  9   
class TestPWLPolicy(unittest.TestCase):
    """
    Tests that rule-based compilations of agent value functions, policies
    and dynamics agree with direct computation on the PSYOP example
    scenario built in L{setUp}.

    The methods that compare against direct computation sample a grid of
    state vectors: each feature in C{self.keyKeys} takes one of
    L{granularity} values, giving C{granularity**len(keyKeys)} test states.
    With the defaults below every varied feature is set to 0.25 or 0.75.

    @cvar filename: the temporary filename for storing profiling stats
        (NOTE(review): not referenced by any method visible in this module)
    @cvar granularity: when exploring possible state vectors for testing, the number of points to try for each feature
    @cvar base: the offset for generating the different points for each state feature (i.e., if L{granularity} is I{n}, then the points generated will be C{base}+0/I{n}, C{base}+1/I{n}, ..., C{base}+(I{n}-1)/I{n})
    """
    filename = '/tmp/compile.prof'
    granularity = 2
    base = 0.25
def setUp(self):
    """Build the PSYOP scenario shared by every test.

    One entity is created per class named in C{self.instances} (a numeric
    suffix is appended only when a class has more than one instance). The
    Turkomen and Kurds get a C{location} relationship to the
    GeographicArea. The agents are wrapped in L{SequentialAgents}, their
    defaults applied and dynamics compiled. Finally C{self.keyKeys} is set
    to the state keys the C{verify*} methods vary: every L{StateKey}
    except the C{economicPower} and C{population} features.
    """
    from teamwork.examples.PSYOP import Society
    society = GenericSociety()
    society.importDict(Society.classHierarchy)
    self.instances = {'GeographicArea':1,
                      'US':1,
                      'Turkomen':1,
                      'Kurds':1}
    # Entities that are given a location in the GeographicArea
    locatedGroups = ['Turkomen','Kurds']
    members = []
    for className,count in self.instances.items():
        for serial in range(count):
            name = className
            if count > 1:
                name = '%s%d' % (className,serial)
            member = createEntity(className,name,society,PsychEntity)
            members.append(member)
            if member.name in locatedGroups:
                member.relationships = {'location':['GeographicArea']}
    self.entities = SequentialAgents(members)
    self.entities.applyDefaults()
    self.entities.compileDynamics()
    # Keep only genuine state features, minus the two excluded ones
    ignored = ['economicPower','population']
    self.keyKeys = [key for key in self.entities.getStateKeys().keys()
                    if isinstance(key,StateKey) and key['feature'] not in ignored]
50
def verifyRuleConsistency(self,rules,attributes,values):
    """Assert that a compiled rule set is internally consistent.

    Three properties are checked:
      1. no two condition attributes (keys of C{attributes} other than
         C{'_value'}) C{compare} as C{'equal'};
      2. every key of every rule (including C{'_value'}) appears in
         C{attributes};
      3. any two distinct rules whose conditions do not conflict, either
         directly on a shared attribute or indirectly via the C{compare}
         result of two attributes, have equal right-hand sides, where a
         rule's RHS is C{values[rule['_value']]}.

    @param rules: list of rules; each rule is a dict from attribute key to
        a condition value (C{None} conditions are skipped), plus the
        C{'_value'} key naming its RHS in C{values}
    @param attributes: dict from attribute key to an object with a
        C{compare} method; the C{'_value'} entry is not compared
    @param values: dict from RHS key to RHS value
    """
    target = '_value'
    # Condition attributes only; relies on Python 2 filter() returning an
    # indexable list
    lhs = filter(lambda k:k!=target,attributes)
    # Memo: comparisons[a2][a1] caches attributes[a2].compare(attributes[a1])
    comparisons = {}
    # Check for duplicate attributes
    for index1 in range(len(lhs)-1):
        attr1 = lhs[index1]
        for index2 in range(index1+1,len(lhs)):
            attr2 = lhs[index2]
            try:
                result = comparisons[attr2][attr1]
            except KeyError:
                result = attributes[attr2].compare(attributes[attr1])
                try:
                    comparisons[attr2][attr1] = result
                except KeyError:
                    comparisons[attr2] = {attr1:result}
            self.assertNotEqual(result,'equal')
    # Every key used by a rule must be declared in attributes
    for rule1 in rules:
        for attr in rule1.keys():
            self.assert_(attributes.has_key(attr))
    # Any two rules that can fire on the same state must agree on their RHS
    for rule1 in rules:
        for rule2 in rules:
            if rule1 is rule2:
                continue
            # Breaking out of this loop means the two rules conflict on some
            # condition, so they can never both fire
            for attr1,value1 in rule1.items():
                if attr1 != target and value1 is not None:
                    # Check whether rule2 matches on this condition
                    if rule2.has_key(attr1):
                        # Check directly on this attribute
                        if rule2[attr1] != value1 and rule2[attr1] is not None:
                            # Mismatch
                            break
                    for attr2,value2 in rule2.items():
                        if attr2 != target:
                            # Check indirectly related attributes
                            try:
                                result = comparisons[attr2][attr1]
                            except KeyError:
                                result = attributes[attr2].compare(attributes[attr1])
                                try:
                                    comparisons[attr2][attr1] = result
                                except KeyError:
                                    comparisons[attr2] = {attr1:result}
                            if result == 'equal':
                                if value1 != value2 and value2 is not None:
                                    # Mismatch
                                    break
                            # 'greater'/'less' are treated as implications
                            # between the two attributes; these value
                            # combinations count as a conflict
                            elif result == 'greater' and value2 == True:
                                if value1 == False:
                                    break
                            elif result == 'less' and value2 == False:
                                if value1 == True:
                                    break
                    else:
                        # No indirect conflict on attr1: try the next condition
                        continue
                    # Propagate the inner conflict out of the attr1 loop
                    break
            else:
                # No conflict found, so the rules overlap and their RHS must
                # match. Build a readable dump of both rules for the message.
                value1 = values[rule1[target]]
                value2 = values[rule2[target]]
                content = '\n'
                for attr,val in rule1.items():
                    if attr != target and val is not None:
                        content += '%s %s\n' % (val,attr)
                if isinstance(value1,list):
                    content += '-> %s\n\n' % (str(value1))
                else:
                    # NOTE(review): bare except also hides errors other than a
                    # missing simpleText() method
                    try:
                        content += '-> %s\n\n' % (value1.simpleText())
                    except:
                        content += '-> %s\n\n' % (str(value1))
                for attr,val in rule2.items():
                    if attr != target and val is not None:
                        content += '%s %s\n' % (val,attr)
                if isinstance(value2,list):
                    content += '-> %s\n\n' % (str(value2))
                else:
                    try:
                        content += '-> %s\n\n' % (value2.simpleText())
                    except:
                        content += '-> %s\n\n' % (str(value2))
                self.assertEqual(len(value1),len(value2),content)
##                for rowKey,row1 in value1.items():
##                    self.assert_(value2.has_key(rowKey),content)
##                    row2 = value2[rowKey]
##                    self.assertEqual(len(row1),len(row2),content)
##                    for colKey,col1 in row1.items():
##                        self.assertAlmostEqual(col1,row2[colKey],8,content)
                self.assertEqual(value1,value2,content)
##    # Test attribute minimality
##    for index1 in range(len(lhs)-1):
##        attr1 = lhs[index1]
##        plane1 = attributes[attr1]
##        for index2 in range(index1+1,len(lhs)):
##            attr2 = lhs[index2]
##            plane2 = attributes[attr2]
##            self.assertNotEqual(plane1.compare(plane2),'equal')
##    # Test rule minimality
##    for myIndex in range(len(rules)-1):
##        myRule = rules[myIndex]
##        for yrIndex in range(myIndex+1,len(rules)):
##            yrRule = rules[yrIndex]
##            for attr in lhs:
##                pass
def verifyRuleEquality(self,rules1,attrs1,vals1,rules2,attrs2,vals2,comparison=None):
    """Assert that two rule sets give matching results on the test grid.

    Both rule sets are applied (with C{'_value'} as the target) to every
    state of the grid defined by L{granularity} and L{base}, and
    C{comparison} must return C{True} for every pair of results.

    @param rules1: first rule set
    @param attrs1: attributes of the first rule set
    @param vals1: values of the first rule set
    @param rules2: second rule set
    @param attrs2: attributes of the second rule set
    @param vals2: values of the second rule set
    @param comparison: binary predicate on the two results; C{==} if omitted
    """
    if comparison is None:
        def comparison(first,second):
            return first == second
    steps = self.granularity
    state = copy.copy(self.entities.getState())
    for point in xrange(pow(steps,len(self.keyKeys))):
        # Decode point as base-granularity digits, one per varied key
        remainder = point
        for key in self.keyKeys:
            remainder,digit = divmod(remainder,steps)
            state.domain()[0][key] = float(digit)/float(steps) + self.base
        lhs = applyRules(state,rules1,attrs1,vals1,'_value',True)
        rhs = applyRules(state,rules2,attrs2,vals2,'_value',True)
        matched = comparison(lhs,rhs)
        self.assertEqual(matched,True,
                         'Difference found between %s and %s' % (str(lhs),str(rhs)))
169
170 - def testPolicyRules(self):
171 for entity in self.entities.activeMembers(): 172 for other in entity.entities.activeMembers(): 173 self.verifyValueRules(other) 174 self.verifyMergeRules(other) 175 self.verifyPolicyRules(other) 176 ## self.verifyValueRules(entity) 177 ## self.verifyMergeRules(entity) 178 ## self.verifyPolicyRules(entity) 179 self.verifyProjectedRules(entity) 180 ## self.verifyMergeRules(entity) 181 self.verifyProjectedPolicy(entity)
182
def verifyValueRules(self,entity):
    """Check C{entity}'s one-step value rules against direct computation.

    C{entity}'s horizon and policy depth are set to 1. For every action
    option, the value rules from C{entity.policy.getValueRules} are checked
    with L{verifyRuleConsistency}. Then, at every state of the test grid,
    the value obtained by applying those rules must equal (to 8 places) the
    value computed from the goal vector and the actual dynamics matrix of
    that action.

    @param entity: the agent under test; its C{horizon} and
        C{policy.depth} are overwritten with 1
    """
    target = '_value'
    entity.horizon = 1
    entity.policy.depth = entity.horizon
    # Options are fetched once and reused for every grid point
    options = entity.actions.getOptions()
    # (rules, attributes, values) for each option, keyed by str(option)
    compiled = {}
    for action in options:
        rules,attributes,values = entity.policy.getValueRules(action)
        compiled[str(action)] = (rules,attributes,values)
        self.verifyRuleConsistency(rules,attributes,values)
    state = copy.copy(self.entities.getState())
    goals = entity.getGoalVector()['state']
    goals.fill(state.domain()[0].keys())
    for point in xrange(pow(self.granularity,len(self.keyKeys))):
        # Decode point as base-granularity digits, one per varied key. A
        # separate variable avoids rebinding the for-loop variable.
        remainder = point
        for key in self.keyKeys:
            remainder,digit = divmod(remainder,self.granularity)
            state.domain()[0][key] = float(digit)/float(self.granularity) + self.base
        for action in options:
            rules,attributes,values = compiled[str(action)]
            dynamics = entity.entities.getDynamics({entity.name:action})
            matrix = dynamics['state'].getTree()[state].domain()[0]
            # NOTE(review): presumably aligns goals with the matrix's row
            # keys so the product below is well-defined -- confirm
            goals.fill(matrix.rowKeys())
            rawValue = (goals*matrix*state).domain()[0]
            # Rule-based formulation of the same value
            rhs = applyRules(state,rules,attributes,values,target,True)
            self.assert_(rhs is not None)
            ruleValue = rhs*(state.domain()[0])
            self.assertAlmostEqual(ruleValue,rawValue,8,'%s has wrong value of %s' %\
                                   (entity.ancestry(),str(action)))
215
def verifyMergeRules(self,entity):
    """Check the incremental merge of per-action value rules into a policy.

    C{entity}'s options are merged one at a time. Each option's value
    rules (from C{entity.policy.getValueRules}) are relabelled so that
    every right-hand side records the action and its weights. They are
    then merged into the cumulative rule set with C{mergeRules}.

    After each merge, at every state of the test grid, the action chosen
    by the cumulative rules must have a value at least as high as every
    option merged so far. The same holds for a copy of the rules reduced
    to just the action.

    Unlike the other C{verify*} methods, this does not reset
    C{entity.horizon}; it uses the entity's current horizon.

    @param entity: the agent whose value rules are merged
    """
    target = '_value'
    rules = []
    attributes = {target:True}
    values = {}
    choices = entity.actions.getOptions()
    for index,action in enumerate(choices):
        subRules,subAttributes,subValues = entity.policy.getValueRules(action=action)
        self.verifyRuleConsistency(subRules,subAttributes,subValues)
        # Relabel each RHS as {'action':...,'weights':...}; action is bound
        # as a default so the lambda can never see a later loop value
        subValues = copy.copy(subValues)
        subRules = mapValues(subRules,subValues,
                             lambda v,action=action:{'action':action,'weights':v})
        for rule in subRules:
            self.assertEqual(subValues[rule[target]]['action'], action)
        self.verifyRuleConsistency(subRules,subAttributes,subValues)
        mergeAttributes(attributes,subRules,subAttributes)
        values.update(subValues)
        if len(rules) > 0:
            rules,attributes = mergeRules(rules,action,subRules,attributes,values)
        else:
            rules = subRules
        self.verifyRuleConsistency(rules,attributes,values)
        # Same policy with each RHS reduced to just the chosen action
        filteredAttrs = copy.copy(attributes)
        filteredVals = copy.copy(values)
        filteredRules = mapValues(rules,filteredVals,lambda v:v['action'])
        # Every option merged so far must be beaten by the chosen action
        merged = choices[:index+1]
        state = copy.copy(self.entities.getState())
        for point in xrange(pow(self.granularity,len(self.keyKeys))):
            # Decode point as base-granularity digits, one per varied key.
            # A separate variable is used so the option index above is not
            # clobbered (rebinding it left index == 0 after decoding).
            remainder = point
            for key in self.keyKeys:
                remainder,digit = divmod(remainder,self.granularity)
                state.domain()[0][key] = float(digit)/float(self.granularity) + self.base
            decision1 = applyRules(state,rules,attributes,values,target,True)['action']
            decision2 = applyRules(state,filteredRules,filteredAttrs,
                                   filteredVals,target,True)
            args = (state,)+entity.policy.getValueRules(decision1)+(target,True)
            best1 = applyRules(*args)*state.domain()[0]
            args = (state,)+entity.policy.getValueRules(decision2)+(target,True)
            best2 = applyRules(*args)*state.domain()[0]
            for choice in merged:
                # Relies on verifyValueRules having passed
                args = (state,)+entity.policy.getValueRules(choice)+(target,True)
                value = applyRules(*args)*state.domain()[0]
                self.assert_(best1 >= value,'%s incorrectly prefers %s over %s' % \
                             (entity.ancestry(),str(decision1),str(choice)))
                self.assert_(best2 >= value,'%s incorrectly prefers %s over %s' % \
                             (entity.ancestry(),str(decision2),str(choice)))
    # TODO(review): cross-checks of these rules against
    # entity.policy.buildPolicy() and of the filtered rules against
    # entity.policy.compileRules() were disabled in this method; restore
    # them once those methods are expected to match.
def verifyPolicyRules(self,entity):
    """Check C{entity}'s compiled policy and policy-dynamics rules.

    First C{entity}'s horizon and policy depth are set to 1. Then:
      - the policy rules from C{entity.policy.compileRules} and the
        combined rules from C{entity.policy.getDynamics} must both pass
        L{verifyRuleConsistency};
      - replacing each policy RHS (an action) by the rules of that action's
        dynamics tree (via C{replaceValues}) gives rules that, after
        C{pruneRules}, must agree on the test grid with both the unpruned
        replacement and C{getDynamics}'s rules;
      - at every grid state, the chosen action's value must be at least
        that of every option;
      - at every grid state, the rule-derived dynamics matrices must match
        the actual dynamics tree entry by entry (to 8 places).

    @param entity: the agent under test; its C{horizon} and
        C{policy.depth} are overwritten with 1
    """
    entity.horizon = 1
    entity.policy.depth = entity.horizon
    # Compiled policy: each RHS key in values maps to an action
    rules,attributes,values = entity.policy.compileRules()
    target = '_value'
    self.verifyRuleConsistency(rules,attributes,values)
    # Policy combined with dynamics, as compiled by the policy itself
    dynRules,dynAttributes,dynValues = entity.policy.getDynamics()
    self.verifyRuleConsistency(dynRules,dynAttributes,dynValues)
    # Test dynamics extension of policy: rebuild that combination by hand,
    # with one dynamics rule set per policy RHS key
    dynamicsRules = {}
    rawValues = {}
    rawAttributes = copy.copy(attributes)
    for key,action in values.items():
        actionDict = {entity.name:action}
        tree = entity.entities.getDynamics(actionDict)['state'].getTree()
        subAttributes = {target:True}
        subValues = {}
        dynamicsRules[key] = tree.makeRules(subAttributes,subValues)
        mergeAttributes(rawAttributes,dynamicsRules[key],subAttributes)
        rawValues.update(subValues)
    # Snapshots taken before replaceValues/pruneRules see rawAttributes
    dynamicsAttributes = copy.copy(rawAttributes)
    dynamicsValues = copy.copy(rawValues)
    # Every condition used by the policy or dynamics rules must be declared
    for rule in rules:
        for attr,value in rule.items():
            if attr != target:
                self.assert_(rawAttributes.has_key(attr))
    for ruleSet in dynamicsRules.values():
        for rule in ruleSet:
            for attr,value in rule.items():
                if attr != target:
                    self.assert_(rawAttributes.has_key(attr))
    # Substitute each policy RHS with the corresponding dynamics rules
    rawRules = replaceValues(rules,rawAttributes,values,dynamicsRules,rawValues)

    prunedRules,prunedAttributes = pruneRules(rawRules,rawAttributes,rawValues)
    self.verifyRuleEquality(prunedRules,prunedAttributes,rawValues,
                            rawRules,rawAttributes,rawValues)
    self.verifyRuleEquality(prunedRules,prunedAttributes,rawValues,
                            dynRules,dynAttributes,dynValues)
    state = copy.copy(self.entities.getState())
    for index in xrange(pow(self.granularity,len(self.keyKeys))):
        # Decode index as base-granularity digits, one per varied key
        for key in self.keyKeys:
            state.domain()[0][key] = float(index % self.granularity)/float(self.granularity) + self.base
            index /= self.granularity
        decision = applyRules(state,rules,attributes,values,'_value',True)
        args = (state,)+entity.policy.getValueRules(decision)+('_value',True)
        best = apply(applyRules,args)*state.domain()[0]
        for action in entity.actions.getOptions():
            # NOTE(review): compares a string with decision; if decision is
            # not a string nothing is skipped, which is harmless because
            # best >= best
            if str(action) != decision:
                # Relies on having passed testValueRules
                args = (state,)+entity.policy.getValueRules(action)+('_value',True)
                value = apply(applyRules,args)*state.domain()[0]
                self.assert_(best >= value,'%s incorrectly prefers %s over %s' % \
                             (entity.ancestry(),str(decision),str(action)))
        # Test policy-dynamics combo rules
        dynamics = entity.entities.getDynamics({entity.name:decision})
        matrix = dynamics['state'].getTree()[state]
        realState = matrix*state
        self.assertEqual(len(realState),1)
        # NOTE(review): assumes each key of values is the string form of
        # its action, since dynamicsRules is keyed by those keys -- confirm
        rulesMatrix = applyRules(state,dynamicsRules[str(decision)],
                                 dynamicsAttributes,
                                 dynamicsValues,'_value',True)
        rhs = applyRules(state,dynRules,dynAttributes,dynValues,'_value',True)
        raw = applyRules(state,rawRules,rawAttributes,rawValues,'_value',True)
        pruned = applyRules(state,prunedRules,prunedAttributes,rawValues,'_value',True)
        self.assertEqual(raw,rulesMatrix,'%s differs from actual %s' % \
                         (raw.simpleText(),rulesMatrix.simpleText()))
        self.assertEqual(raw,pruned,'pruned %s differs from original %s' % \
                         (pruned.simpleText(),raw.simpleText()))
        self.assertEqual(raw,rhs)
        # Entry-by-entry comparison against the actual dynamics matrix
        for row in matrix.domain()[0].rowKeys():
            realRow = matrix.domain()[0][row]
            self.assert_(rhs.has_key(row))
            ruleRow = rhs[row]
            self.assert_(raw.has_key(row))
            rawRow = raw[row]
            self.assert_(rulesMatrix.has_key(row))
            dynRow = rulesMatrix[row]
            for col in realRow.keys():
                self.assert_(dynRow.has_key(col))
                self.assertAlmostEqual(dynRow[col],realRow[col],8)
                self.assert_(rawRow.has_key(col))
                self.assertAlmostEqual(dynRow[col],rawRow[col],8)
                self.assertAlmostEqual(rawRow[col],realRow[col],8)
                self.assert_(ruleRow.has_key(col))
                self.assertAlmostEqual(realRow[col],ruleRow[col],8)
        # The projected state must also match the actual next state
        ruleState = rhs*state.domain()[0]
        for key in state.domainKeys():
            self.assertAlmostEqual(realState.domain()[0][key],
                                   ruleState[key],8)
385
387 entity = self.entities['Turkomen'] 388 goals = entity.getGoalTree().getValue() 389 other = entity.getEntity('US') 390 other.horizon = 1 391 other.policy.depth = other.horizon 392 start = time.time() 393 USRules,USAttrs,USVals = other.policy.getDynamics() 394 print other.ancestry(),time.time()-start 395 self.verifyRuleConsistency(USRules,USAttrs,USVals) 396 other = entity.getEntity('Kurds') 397 other.horizon = 1 398 other.policy.depth = other.horizon 399 start = time.time() 400 kRules,kAttrs,kVals = other.policy.getDynamics() 401 print other.ancestry(),time.time()-start 402 self.verifyRuleConsistency(kRules,kAttrs,kVals) 403 state0 = copy.copy(self.entities.getState()) 404 for action in entity.actions.getOptions(): 405 print action 406 # Cumulative storage of attributes and values 407 policyAttributes = {'_value':True} 408 policyValues = {} 409 dynamicsAttributes = {'_value':True} 410 dynamicsValues = {} 411 412 # Step 1: Turkomen perform given action 413 actionDict = {entity.name:action} 414 step1Tree = entity.entities.getDynamics(actionDict)['state'].getTree() 415 step1Attrs = {} 416 step1Vals = {} 417 start = time.time() 418 step1Rules = step1Tree.makeRules(step1Attrs,step1Vals) 419 print 'Making rules:',time.time()-start 420 421 # Update cumulative dynamics 422 mergeAttributes(dynamicsAttributes,step1Rules,step1Attrs) 423 dynamicsValues.update(step1Vals) 424 dynamicsRules = step1Rules 425 self.verifyRuleConsistency(dynamicsRules,dynamicsAttributes, 426 dynamicsValues) 427 428 # Compute current total reward 429 totalVals = copy.copy(dynamicsValues) 430 start = time.time() 431 totalRules = mapValues(dynamicsRules,totalVals,lambda v:goals*v) 432 print 'Evaluating:',time.time()-start 433 policyValues.update(totalVals) 434 mergeAttributes(policyAttributes,totalRules,dynamicsAttributes) 435 policyRules = totalRules 436 437 # Step 2: US follows policy 438 mergeAttributes(dynamicsAttributes,USRules,USAttrs) 439 dynamicsValues.update(USVals) 440 
self.verifyRuleConsistency(dynamicsRules,dynamicsAttributes, 441 dynamicsValues) 442 self.verifyRuleConsistency(USRules,dynamicsAttributes, 443 dynamicsValues) 444 start = time.time() 445 ## dynamicsRules,dynamicsAttributes = self.detailedMultiply(USRules,dynamicsRules, 446 ## dynamicsAttributes,dynamicsValues) 447 dynamicsRules,dynamicsAttributes = multiplyRules(USRules,dynamicsRules, 448 dynamicsAttributes, 449 dynamicsValues) 450 print 'Multiplying:',time.time()-start 451 print len(dynamicsRules) 452 self.verifyRuleConsistency(dynamicsRules,dynamicsAttributes, 453 dynamicsValues) 454 455 # Update running total of reward 456 totalVals = copy.copy(dynamicsValues) 457 start = time.time() 458 totalRules = mapValues(dynamicsRules,totalVals,lambda v:goals*v) 459 print 'Evaluating:',time.time()-start 460 policyValues.update(totalVals) 461 mergeAttributes(policyAttributes,totalRules,dynamicsAttributes) 462 policyRules = addRules(policyRules,totalRules,policyAttributes,policyValues) 463 464 # Step 3: Kurds follow Policy 465 self.verifyRuleConsistency(dynamicsRules,dynamicsAttributes, 466 dynamicsValues) 467 mergeAttributes(dynamicsAttributes,kRules,kAttrs) 468 dynamicsValues.update(kVals) 469 self.verifyRuleConsistency(dynamicsRules,dynamicsAttributes, 470 dynamicsValues) 471 self.verifyRuleConsistency(kRules,dynamicsAttributes, 472 dynamicsValues) 473 start = time.time() 474 ## dynamicsRules,dynamicsAttributes = self.detailedMultiply(kRules,dynamicsRules, 475 ## dynamicsAttributes,dynamicsValues) 476 dynamicsRules,dynamicsAttributes = multiplyRules(kRules,dynamicsRules, 477 dynamicsAttributes, 478 dynamicsValues,debug=True) 479 print 'Multiplying:',time.time()-start 480 print len(dynamicsRules) 481 self.verifyRuleConsistency(dynamicsRules,dynamicsAttributes, 482 dynamicsValues) 483 484 # Update running total of reward 485 totalVals = copy.copy(dynamicsValues) 486 start = time.time() 487 totalRules = mapValues(dynamicsRules,totalVals,lambda v:goals*v) 488 print 
'Evaluating:',time.time()-start 489 policyValues.update(totalVals) 490 mergeAttributes(policyAttributes,totalRules,dynamicsAttributes) 491 policyRules = addRules(policyRules,totalRules,policyAttributes,policyValues) 492 493 print 'Running %d tests' % (pow(self.granularity,len(self.keyKeys))) 494 for index in xrange(pow(self.granularity,len(self.keyKeys))): 495 for key in self.keyKeys: 496 state0.domain()[0][key] = float(index % self.granularity)/float(self.granularity) + self.base 497 index /= self.granularity 498 ## print 0,state0.domain()[0].simpleText() 499 real0 = state0 500 rhs1 = applyRules(state0,step1Rules,step1Attrs,step1Vals,'_value',True) 501 product = rhs1 502 total = goals * rhs1 503 state1 = Distribution({rhs1*state0.domain()[0]:1.}) 504 real1 = step1Tree*real0 505 self.assertEqual(state1,real1) 506 reward = goals*real1.domain()[0] 507 self.assertAlmostEqual(reward,total*real0.domain()[0],8) 508 ## print 1,state1.domain()[0].simpleText() 509 rhsUS = applyRules(state1,USRules,USAttrs,USVals,'_value',True) 510 product = rhsUS*product 511 total += goals*product 512 decision,exp = entity.getEntity('US').policy.execute({'state':real1}) 513 step2Tree = self.entities.getDynamics({'US':decision})['state'].getTree() 514 real2 = step2Tree*real1 515 state2 = Distribution({rhsUS*state1.domain()[0]:1.}) 516 self.assertEqual(state2,real2) 517 reward += goals*real2.domain()[0] 518 self.assertAlmostEqual(reward,total*real0.domain()[0],8) 519 ## print 2,state2.domain()[0].simpleText() 520 rhsKurds = applyRules(state2,kRules,kAttrs,kVals,'_value',True) 521 product = rhsKurds*product 522 total += goals*product 523 ## print result[0] 524 state3 = Distribution({rhsKurds*state2.domain()[0]:1.}) 525 decision,exp = entity.getEntity('Kurds').policy.execute({'state':real1}) 526 step3Tree = self.entities.getDynamics({'Kurds':decision})['state'].getTree() 527 real3 = step3Tree*real2 528 self.assertEqual(state3,real3) 529 reward += goals*real3.domain()[0] 530 
self.assertAlmostEqual(reward,total*real0.domain()[0],8) 531 ## print 3,state3.domain()[0].simpleText() 532 projection = applyRules(state0,dynamicsRules,dynamicsAttributes, 533 dynamicsValues,'_value',True) 534 ## print 'Alleged:',(projection*state0.domain()[0]).simpleText() 535 self.assertEqual(product,projection, 536 'Difference found between %s and %s' % \ 537 (product.simpleText(),projection.simpleText())) 538 539 rhsTotal = applyRules(state0,policyRules,policyAttributes, 540 policyValues,'_value',True) 541 self.assertEqual(total,rhsTotal,'Difference found between %s and %s' % \ 542 (total.simpleText(),rhsTotal.simpleText()))
543
def verifyProjectedRules(self,entity):
    """Check C{entity}'s multi-step value rules against a step-by-step projection.

    The turn sequence from C{entity.policy.getSequence} sets C{entity}'s
    horizon. Every other agent in C{entity}'s beliefs gets horizon 1 and
    has its policy rules precompiled.

    For each of C{entity}'s options and each grid state, the state is
    projected forward. The first step uses the option's own dynamics; each
    later step uses the action chosen by the agent whose turn it is (via
    C{policy.execute}). C{goals*state} is accumulated after every step.

    The one-step rule matrices must reproduce the actual transitions. The
    accumulated product from C{multiplyRules} is only consistency-checked.
    The value from C{entity.policy.getValueRules} must equal the
    accumulated reward to 8 places.

    @param entity: the agent under test; its C{horizon} and
        C{policy.depth} are overwritten, as are those of the other agents
        in its beliefs
    """
    # NOTE(review): sequence[t][0] is used below as an agent name;
    # presumably this is the turn order starting with entity -- confirm
    sequence = entity.policy.getSequence(entity,len(self.entities.activeMembers()))
    entity.horizon = len(sequence)
    entity.policy.depth = entity.horizon
    # Pre-compile mental models
    for other in entity.getEntityBeliefs():
        if other.name != entity.name:
            other.horizon = 1
            other.policy.depth = other.horizon
            # We assume this is correct based on testPolicyRules
            other.policy.compileRules()
    state = copy.copy(entity.entities.getState())
    goals = entity.getGoalVector()['state']
    goals.fill(state.domain()[0].keys())
    target = '_value'
    for action in entity.actions.getOptions():
        actionDict = {entity.name:action}
        # NOTE(review): `string` is not imported explicitly in the visible
        # imports; presumably it arrives via one of the star imports
        actionKey = string.join(map(str,actionDict.values()))
        self.assert_(self.entities.dynamics.has_key(actionKey))
        dynamics = self.entities.getDynamics(actionDict)
        tree = dynamics['state'].getTree()
        # Compile a "long-range" policy
        rules,attributes,values = entity.policy.getValueRules(action,debug=False)
        self.verifyRuleConsistency(rules,attributes,values)
##        original=copy.deepcopy((rules,attributes,values))
##        internRules,internAttrs = internalCheck(rules,attributes,values,target)
##        prunedRules,prunedAttrs = pruneRules(internRules,internAttrs,values)
##        print 'Pruning saved:',len(rules)-len(prunedRules)
##        self.assertEqual(original,(rules,attributes,values))
##        self.verifyRuleEquality(rules,attributes,values,
##                                prunedRules,prunedAttrs,values)
        for index in xrange(pow(self.granularity,len(self.keyKeys))):
            # Decode index as base-granularity digits, one per varied key
            for key in self.keyKeys:
                state.domain()[0][key] = float(index % self.granularity)/float(self.granularity) + self.base
                index /= self.granularity
            # Project chosen action
            current = copy.copy(state)
            stateSequence = [current]
            dynAttrs = {target:True}
            dynVals = {}
            dynRules = tree.makeRules(dynAttrs,dynVals)
            self.verifyRuleConsistency(dynRules,dynAttrs,dynVals)
            ruleMatrix = applyRules(current,dynRules,dynAttrs,dynVals,target,True)
            # NOTE(review): totalRules/totalAttrs/totalVals are never read
            totalRules = copy.deepcopy(dynRules)
            totalAttrs = copy.copy(dynAttrs)
            totalVals = copy.copy(dynVals)
            matrix = tree[current]
            self.assertEqual(ruleMatrix,matrix.domain()[0],'%s differs from %s' % \
                             (ruleMatrix.simpleText(),
                              matrix.domain()[0].simpleText()))
            # Reward accumulated over the projection, starting after step 1
            rawValue = 0.
            current = matrix*current
            stateSequence.append(current)
            rawValue += (goals*current).domain()[0]
            # Remaining steps: the agent whose turn it is follows its policy
            for t in range(1,entity.horizon):
                other = entity.getEntity(sequence[t][0])
                decision,exp = other.policy.execute({'state':current})
                dynamics = self.entities.getDynamics({other.name:decision})
                matrix = dynamics['state'].getTree()[current]
                newRules,newAttrs,newVals = other.policy.compileRules()
                self.verifyRuleConsistency(newRules,newAttrs,newVals)
                newRules,newAttrs,newVals = other.policy.getDynamics()
                newRules = copy.deepcopy(newRules)
                self.verifyRuleConsistency(newRules,newAttrs,newVals)
                ruleMatrix = applyRules(current,newRules,newAttrs,newVals,'_value',True)
                # Fold this step's rules into the cumulative dynamics rules
                mergeAttributes(dynAttrs,newRules,newAttrs)
                dynVals.update(newVals)
                dynRules,dynAttrs = multiplyRules(newRules,dynRules,dynAttrs,dynVals)
                self.verifyRuleConsistency(dynRules,dynAttrs,dynVals)
                # One-step rule projection must match the actual transition
                ruleState = ruleMatrix*current.domain()[0]
                current = matrix*current
                stateSequence.append(current)
                for key,value in ruleState.items():
                    self.assertAlmostEqual(value,current.domain()[0][key],8)
                rawValue += (goals*current).domain()[0]
            # Test rule-based formulation of value function
            rhs = applyRules(state,rules,attributes,values,'_value',True)
            ruleValue = rhs*(state.domain()[0])
            # Dump the projected trajectory before the assertion fails
            # (rebinding index here is harmless: xrange resets it)
            if abs(ruleValue-rawValue) > .000001:
                for index in range(len(stateSequence)):
                    print index,stateSequence[index].domain()[0]
            self.assertAlmostEqual(ruleValue,rawValue,8,
                                   'Incorrect value (%f!=%f) on %s' % \
                                   (ruleValue,rawValue,action))
628
def verifyProjectedPolicy(self,entity):
    """Check that C{entity}'s compiled policy picks a best-valued option.

    The policy rules from C{entity.policy.compileRules} are checked with
    L{verifyRuleConsistency}. Then, at every state of the test grid, the
    action they select must have a value (from
    C{entity.policy.getValueRules}) no lower than that of any other option.

    @param entity: the agent under test; its horizon is used as-is
    """
    target = '_value'
    policyRules,policyAttrs,policyVals = entity.policy.compileRules()
    self.verifyRuleConsistency(policyRules,policyAttrs,policyVals)
    state = copy.copy(entity.entities.getState())

    def valueOf(choice):
        # Value of taking choice in the current grid state
        choiceRules,choiceAttrs,choiceVals = entity.policy.getValueRules(choice)
        return applyRules(state,choiceRules,choiceAttrs,choiceVals,target,True)*state.domain()[0]

    steps = self.granularity
    for point in xrange(pow(steps,len(self.keyKeys))):
        # Decode point as base-granularity digits, one per varied key
        remainder = point
        for key in self.keyKeys:
            remainder,digit = divmod(remainder,steps)
            state.domain()[0][key] = float(digit)/float(steps) + self.base
        decision = applyRules(state,policyRules,policyAttrs,policyVals,target,True)
        best = valueOf(decision)
        rivals = [option for option in entity.actions.getOptions()
                  if option != decision]
        for rival in rivals:
            value = valueOf(rival)
            self.assert_(best >= value,'Nope, %s not better than %s' % (best,value))
if __name__ == '__main__':
    # Run this module's TestCase classes with the standard unittest runner
    unittest.main()