
Source Code for Module teamwork.policy.ObservationPolicy

from generic import *

class ObservationPolicy(Policy):
    """Policy that uses a lookup table, indexed by observation history
    @ivar Omega: list of possible observations
    @ivar entries: lookup table, with one entry per observation history length
    @ivar horizon: the maximum observation history length that has been solved for (default is 0)
    """
    def __init__(self,choices,observations,horizon=0):
        Policy.__init__(self,choices)
        self.Omega = observations[:]
        self.initialize(horizon)

    def initialize(self,horizon):
        """Sets the policy to the first policy in the ordered space"""
        self.horizon = horizon
        self.entries = [0]*(horizon+1)

    def next(self):
        """Increments this policy to the next one in the ordered space
        @return: C{True} if the next policy has been found; C{False} if the end of the policy space has been reached
        @rtype: bool
        """
        sizeA = len(self.choices)
        sizeOmega = len(self.Omega)
        pos = 0
        while True:
            try:
                self.entries[pos] += 1
            except IndexError:
                # Carried past the last entry, so we have reached the end
                # of the policy space
                return False
            if self.entries[pos] < pow(sizeA,pow(sizeOmega,pos)):
                break
            else:
                # This entry has wrapped around; carry over to the next one
                self.entries[pos] = 0
                pos += 1
        return True

    def execute(self,state=None,choices=[],debug=None,depth=-1,explain=False):
        """Execute the policy in the given state
        @param state: observation history
        @warning: Ignores the C{choices} and C{depth} arguments
        """
        try:
            policyIndex = self.entries[len(state)]
        except IndexError:
            # Haven't planned for an observation history this long
            try:
                # Fall back on the most extensive policy found
                policyIndex = self.entries[-1]
            except IndexError:
                raise UserWarning('No policy available')
            # Ignore the oldest observations, keeping only as many as the
            # fallback policy covers (i.e., len(self.entries)-1 of them)
            state = state[len(state)-len(self.entries)+1:]
        # Find entry for given observation history
        history = 0
        for omega in state:
            history *= len(self.Omega)
            if isinstance(omega,int):
                # Already an index
                history += omega
            else:
                # Convert symbol to index
                history += self.Omega.index(omega)
        # Extract the base-len(choices) digit of the policy entry that
        # corresponds to this history
        action = (policyIndex / pow(len(self.choices),history)) \
                 % (len(self.choices))
        return self.choices[action]

    def __str__(self,buf=None):
        if buf is None:
            import cStringIO
            buf = cStringIO.StringIO()
        sizeOmega = len(self.Omega)
        for horizon in range(len(self.entries)):
            # For each possible history length
            for history in range(pow(sizeOmega,horizon)):
                # For each possible history of that length
                obs = []
                while len(obs) < horizon:
                    obs.insert(0,history % sizeOmega)
                    history /= sizeOmega
                print >> buf,map(lambda o: str(self.Omega[o]),obs),
                print >> buf,self.execute(obs)
        content = buf.getvalue()
        buf.close()
        return content
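
# A minimal usage sketch of the class above.  The action and observation
# names are hypothetical, and it assumes that Policy.__init__ (from the
# generic module) simply records the action choices.
def _demoEnumeration():
    policy = ObservationPolicy(['listen','open'],
                               ['growl-left','growl-right'],horizon=1)
    count = 0
    while True:
        count += 1
        print str(policy)
        if not policy.next():
            break
    # Two history lengths: pow(2,1)*pow(2,2) = 8 policies in all
    print count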

def solve(policies,horizon,evaluate,debug=False,identical=False):
    """Exhaustive search to find the optimal joint policy over the given horizon
    @type policies: L{ObservationPolicy}[]
    @type horizon: int
    @param evaluate: function that takes the list of policy objects and returns an expected value
    @type evaluate: lambda L{ObservationPolicy}[]: float
    @param identical: if C{True}, then assume that all agents use an identical policy (default is C{False})
    @type identical: bool
    @return: the value of the best policy found
    @rtype: float
    @warning: has the side effect of setting all policies in the list to the best one found
    """
    best = None
    for policy in policies:
        policy.initialize(horizon)
    while True:
        # Evaluate current candidate joint policy
        if debug:
            print 'Evaluating:',map(lambda p: str(p.entries),policies)
        value = evaluate(policies)
        if debug:
            print 'EV =',value
        if best is None or value > best['value']:
            if debug:
                print 'New best:'
                for policy in policies:
                    print '\t'+str(policy).replace('\n','\n\t')
            best = {'policy': map(lambda p: p.entries[:],policies),
                    'value': value}
        # Go on to next policy
        if identical:
            if policies[0].next():
                # Copy the new policy over the others
                for policy in policies[1:]:
                    policy.entries = policies[0].entries[:]
            else:
                # No more new policies to try
                break
        else:
            for index in range(len(policies)):
                policy = policies[index]
                if policy.next():
                    break
                else:
                    # Reached end of space for this policy; wrap around
                    policy.initialize(horizon)
            else:
                # Gone through all combinations
                break
    # Update policies with the best found
    for index in range(len(policies)):
        policies[index].entries = best['policy'][index]
    return best['value']

def solveExhaustively(scenario,transition,Omega,observations,evaluate,
                      horizon,identical=False,debug=False):
    """Exhaustive search for the optimal policy in the given scenario
    @return: the value of the best policy found
    @rtype: float
    """
    old = {}
    policies = []
    # Set up the initial observation-based policies
    for agent in scenario.members():
        actions = agent.actions.getOptions()
        if actions:
            old[agent.name] = agent.policy
            agent.policy = ObservationPolicy(actions,Omega.values())
            policies.append(agent.policy)
    # Search the joint policy space
    value = solve(policies,horizon,evaluate,debug=debug,identical=identical)
    # Restore the original policies
    if debug:
        print value/float(horizon+1)
    for agent in scenario.members():
        if agent.actions.getOptions():
            if debug:
                print str(agent.policy)
            agent.policy = old[agent.name]
    return value

if __name__ == '__main__':
    import sys
    import time
    from teamwork.examples.TigerScenario import setupTigers,Omega,EV

    scenario,full,transition,reward,observations = setupTigers()

    for horizon in range(10):
        def evaluateJoint(policies):
            assert isinstance(scenario['Player 2'].policy,ObservationPolicy)
            return EV(scenario,transition,observations,reward,horizon)

        start = time.time()
        value = solveExhaustively(scenario,transition,Omega,observations,
                                  evaluateJoint,horizon,True)
        delta = time.time()-start
        print '%d,%f,%f' % (horizon,value/float(horizon+1),delta)
        sys.stdout.flush()