Package teamwork :: Package test :: Package math :: Module testFitting
[hide private]
[frames | no frames]

Source Code for Module teamwork.test.math.testFitting

  1  from teamwork.agent.Entities import * 
  2  from teamwork.action.PsychActions import * 
  3  from teamwork.multiagent.GenericSociety import * 
  4  from teamwork.agent.DefaultBased import createEntity 
  5  from teamwork.multiagent.sequential import * 
  6  from teamwork.reward.MinMaxGoal import * 
  7  from teamwork.examples.InfoShare.PortClasses import * 
  8   
  9  from teamwork.math.Keys import * 
 10  from teamwork.math.fitting import * 
 11   
 12  import copy 
 13  import unittest 
 14   
15 -class TestFitting(unittest.TestCase):
16 debug = None 17 increment = 0.25 18
19 - def setUp(self):
20 """Creates the instantiated scenario used for testing""" 21 society = GenericSociety() 22 society.importDict(classHierarchy) 23 entities = [] 24 self.instances = {'FirstResponder':['FirstResponder'], 25 'World':['World'], 26 'FederalAuthority':['FederalAuthority'], 27 'Shipper':['Shipper'], 28 } 29 for cls,names in self.instances.items(): 30 for name in names: 31 entity = createEntity(cls,name,society,PsychEntity) 32 entities.append(entity) 33 self.entities = SequentialAgents(entities) 34 self.entities.applyDefaults() 35 self.entities.compileDynamics() 36 37 self.entity = self.entities['FirstResponder'] 38 for action in self.entity.actions.getOptions(): 39 if action[0]['type'] == 'inspect': 40 break 41 else: 42 self.fail('No inspect action found') 43 self.action = action 44 self.order = self.entity.policy.getSequence(self.entity) 45 ## self.order = self.entity.entities.keyOrder[:] 46 ## self.order.remove([self.entity.name]) 47 ## self.order.insert(0,[self.entity.name]) 48 49 self.features = {self.entity.name:{'waitTime':1, 50 'reputation':1, 51 }, 52 'Shipper':{'containerDanger':1, 53 }, 54 'FederalAuthority':{}, 55 'World':{'socialWelfare':1, 56 }, 57 } 58 self.keys = [] 59 for entity,features in self.features.items(): 60 for feature in features.keys(): 61 self.keys.append(StateKey({'entity':entity, 62 'feature':feature}))
63
64 - def generateFirstState(self):
65 state = KeyedVector({keyConstant:1.}) 66 for key in self.keys: 67 count = self.features[key['entity']][key['feature']] 68 if count == 1: 69 state[key] = 0.5 70 else: 71 state[key] = 0. 72 return state
73
74 - def generateNextState(self,state):
75 for key in self.keys: 76 count = self.features[key['entity']][key['feature']] 77 if count > 1: 78 if state.has_key(key): 79 if state[key] < float(count)*self.increment: 80 state[key] += self.increment 81 break 82 else: 83 state[key] = 0. 84 else: 85 state[key] = self.increment 86 break 87 else: 88 return None 89 return state
90
91 - def generateLoop(self,cmd):
92 # Loop through possible states to test values 93 done = {} 94 state = self.generateFirstState() 95 while state: 96 for key in self.keys: 97 try: 98 value = state[key] 99 except KeyError: 100 value = 0. 101 self.entity.setRecursiveBelief(key['entity'],key['feature'], 102 value) 103 # Execute state 104 cmd(state) 105 state = self.generateNextState(state)
106
107 - def verifyDynamics(self,state,tree,action):
108 """Verifies that the given dynamics tree 109 110 Checks that the result of applying the tree in the given state 111 produces the same result as directly applying the action to 112 the entity (using the performAct method)""" 113 # Apply the dynamics tree 114 delta = tree[state]*state 115 state += delta 116 # Apply the action 117 self.entity.entities.performAct({self.entity.name:action}) 118 self.checkState(state)
119
120 - def checkState(self,state):
121 """Verifies that the state vector is the same as the real state""" 122 for key in self.keys: 123 try: 124 value = state[key] 125 except KeyError: 126 value = 0. 127 if isinstance(key,StateKey): 128 real = self.entity.getBelief(key['entity'], 129 key['feature']) 130 msg = '%s:\n%f != %f' % \ 131 (key.simpleText(),real,value) 132 self.assertAlmostEqual(float(real),value,5,msg)
133
134 - def DONTtestGetKeys(self):
135 keyList = getActionKeys(self.entity,self.order) 136 self.assertEqual(keyList,[])
137
138 - def testExpandPolicy(self):
139 """Tests the expansion of an entity's lookup policy tree""" 140 keyList = [] # getActionKeys(self.entity,self.order) 141 for turn in self.order: 142 for name in turn: 143 if name != self.entity.name: 144 policy = expandPolicy(self.entity,name,keyList) 145 entity = self.entity.getEntity(name) 146 self.generateLoop(lambda state,s=self,e=entity,p=policy:\ 147 s.verifyPolicyEffect(e,state,p))
148
149 - def verifyPolicyEffect(self,entity,state,policy):
150 """Verifies the PWL policy in the given state""" 151 act,exp = entity.applyPolicy() 152 state += policy[state]*state 153 self.entity.entities.performAct({entity.name:act}) 154 self.checkState(state)
155
156 - def testGetLookahead(self):
157 """Tests the decision-tree compilation of the lookahead delta""" 158 entity = self.entity 159 for action in entity.actions.getOptions(): 160 # Loop through each action 161 length = len(self.order) 162 result = getLookaheadTree(entity,action, 163 self.order[:length]) 164 self.generateLoop(lambda state,s=self,t=result['transition'], 165 a=action,l=length:\ 166 s.verifyLookahead(state,t,a,l))
167
168 - def verifyLookahead(self,state,tree,action,length):
169 "Verifies the given lookahead (delta only) tree""" 170 # Apply the delta tree 171 state = tree[state]*state 172 # Explicit simulation of lookahead 173 for t in range(length): 174 if t == 0: 175 self.entity.entities.performAct({self.entity.name:action}) 176 else: 177 result = self.entity.entities.microstep() 178 self.assertEqual(self.order[t],result['decision'].keys()) 179 self.checkState(state)
180
181 - def testValueTree(self):
182 """Tests the expected value decision tree""" 183 # Check the basic goal vector 184 # (should probably be in different test case) 185 goals = self.entity.getGoalTree() 186 realValue = self.entity.applyGoals() 187 state = self.entities.getState() 188 self.assertAlmostEqual(float(goals[state]*state),float(realValue),5) 189 # Check expected value tree 190 for action in self.entity.actions.getOptions(): 191 tree = self.entity.policy.getValueTree(action) 192 self.generateLoop(lambda state,s=self,t=tree,a=action:\ 193 s.verifyValue(state,t,a))
194
195 - def verifyValue(self,state,tree,action):
196 """Verifies the PWL expected reward calculation in the given state""" 197 # Apply the delta tree 198 value = (tree[state]*state) 199 # Explicit simulation of lookahead 200 total,exp = self.entity.actionValue(action,len(self.order)) 201 self.assertAlmostEqual(value,float(total),5)
202
203 - def verifyPolicy(self,entity,state,policy):
204 """Verifies the PWL policy in the given state""" 205 act,exp = entity.applyPolicy() 206 self.assertEqual(act,policy[state])
207
208 - def testBuildPolicy(self):
209 """Test compilation of lookahead into decision tree""" 210 policy = self.entity.policy.buildPolicy() 211 self.generateLoop(lambda state,s=self,p=policy:\ 212 s.verifyPolicy(s.entity,state,p))
213
214 - def DONTtestFindAll(self):
215 """Tests the fitting procedure""" 216 for desired in self.entity.actions.getOptions(): 217 act,exp = self.entity.applyPolicy() 218 constraints = findAllConstraints(self.entity,desired,self.order) 219 if act == desired: 220 # Desired action already preferred 221 self.assertEqual(len(constraints),1) 222 constraint = constraints[0] 223 self.assertEqual(constraint['slope'],{}) 224 self.assertEqual(constraint['solution'],{}) 225 else: 226 # Must fit to desired action 227 for constraint in constraints: 228 for goal in self.entity.getGoals(): 229 try: 230 slope = constraint['slope'][goal.toKey()] 231 except KeyError: 232 continue 233 weight = self.entity.getGoalWeight(goal) 234 try: 235 delta = -constraint['delta']/slope - epsilon 236 except ZeroDivisionError: 237 continue 238 change = None 239 if (slope > 0. and goal.isMax()) or \ 240 (slope < 0. and not goal.isMax()): 241 if weight + delta < Interval.CEILING: 242 new = weight+delta 243 change = 1 244 else: 245 continue 246 elif (slope < 0. and goal.isMax()) or \ 247 (slope > 0. and not goal.isMax()): 248 if weight > delta: 249 new = weight-delta 250 change = 1 251 else: 252 continue 253 254 self.entity.setGoalWeight(goal,new) 255 act,exp = self.entity.applyPolicy() 256 self.assertEqual(act,desired) 257 self.entity.setGoalWeight(goal,weight)
if __name__ == '__main__':
    # Script entry point: hand control to unittest's command-line runner,
    # which collects the test* methods of TestFitting in this module.
    unittest.main()