Package teamwork :: Package policy :: Module pwlPolicy
[hide private]
[frames] | [no frames]

Source Code for Module teamwork.policy.pwlPolicy

  1  from LookaheadPolicy import LookaheadPolicy 
  2  from pwlTable import PWLTable 
  3       
class PWLPolicy(LookaheadPolicy):
    """
    Policy that uses L{PWLTable}s to store action rules.

    C{self.tables} is indexed by recursive belief depth; each entry is a
    list of policy tables, one appended per call to L{project} at that
    depth (and indexed by horizon in L{execute} and L{getTable}).
    """

    def __init__(self, entity, actions=None, horizon=None):
        """Same arguments used by constructor for L{LookaheadPolicy}
        superclass
        @param entity: the agent this policy belongs to
        @param actions: the candidate actions (default is
            C{entity.actions.getOptions()})
        @param horizon: the lookahead horizon (default is C{entity.horizon})
        @type horizon: int
        """
        if actions is None:
            actions = entity.actions.getOptions()
        if horizon is None:
            horizon = entity.horizon
        LookaheadPolicy.__init__(self, entity=entity, actions=actions,
                                 horizon=horizon)
        self.type = 'PWL'
        self.reset()

    def reset(self):
        """Removes any cached policy tables
        """
        self.tables = []

    def __getitem__(self, index):
        """Executes the policy, using C{index} as the state
        @return: the result of L{execute} for that state
        """
        return self.execute(index)

    def execute(self, state=None, choices=None, debug=None, horizon=-1,
                explain=False):
        """Execute the policy in the given state
        @param state: the state to act in (default is the entity's beliefs)
        @param choices: the legal actions to consider (default is all
            available actions)
        @type choices: L{Action<teamwork.action.PsychActions.Action>}[][]
        @param horizon: the horizon to consider (by default, use the
            entity's given horizon)
        @type horizon: int
        @return: C{(action, None)} when the deepest cached policy table has
            a rule for C{state}; otherwise the result of
            L{LookaheadPolicy.execute}
        """
        if choices is None:
            # Fresh list per call; a shared mutable default would persist
            # any mutation made downstream across calls
            choices = []
        if state is None:
            state = self.entity.beliefs
        if horizon < 0:
            horizon = self.entity.horizon
        rhs = None
        if self.tables and self.tables[-1]:
            # Use the deepest solved depth; if the requested horizon has not
            # been solved yet, fall back to the longest one that has
            solved = self.tables[-1]
            try:
                table = solved[horizon]
            except IndexError:
                table = solved[-1]
            rhs = table[state]
        if rhs is None:
            return LookaheadPolicy.execute(self, state, choices, debug,
                                           horizon, explain)
        return rhs, None

    def project(self, R, depth=-1, debug=False):
        """
        Project the value function and policy one step further at the given depth
        @param R: the reward function in tabular form
        @type R: L{PWLTable}
        @param depth: the recursive belief depth to compute the value
            function for (default is deepest level already computed). If
            the level does not exist yet, it is created, projecting any
            missing shallower levels first.
        @type depth: int
        @raise NameError: if a rule's RHS does not match the string form of
            any of the entity's action options
        """
        # Find policy at the specified depth (create one if none present)
        try:
            previous = self.tables[depth]
        except IndexError:
            # A non-negative depth d names level d; an out-of-range negative
            # depth -k names level k-1. Either way, `needed` levels must
            # exist before the requested one can be appended.
            if depth < 0:
                needed, shallower = -depth - 1, depth + 1
            else:
                needed, shallower = depth, depth - 1
            if len(self.tables) < needed:
                # Need to do intervening level first
                self.project(R, shallower, debug)
            self.tables.append([])
            # The new level is always the last one; re-indexing by a
            # negative depth would select an older level once others exist
            previous = self.tables[-1]
        if debug:
            print('Horizon = %d' % len(previous))
        # Compute new value function: V_a(b) = R(a,b) + ...
        V = R.getTable()
        if previous:
            # Transition to previous time step's value function
            Vstar = previous[-1].star()
            Vstar.prune()
            if debug:
                print('V*')
                print(Vstar)
            for omega, SE in self.entity.estimators.items():
                # ... + \sum_\omega V^*(SE_a(b,\omega))
                # __mul__ is called directly so the debug flag can be passed
                product = Vstar.__mul__(SE, debug=debug)
                if debug:
                    print('SE(b,%s)' % (omega,))
                    print(SE)
                    print('V*(SE(b,%s))' % (omega,))
                    print(product)
                V = V.__add__(product)
        # NOTE: pruning of the intermediate products, of V, and of the
        # policy's attributes is currently disabled here
        # Compute policy
        if debug:
            print('\tMax (%d rules)' % len(V.values))
            print(V)
        policy = V.max()
        # Replace string RHS with actual actions. Build the name -> option
        # map once; setdefault keeps the first option for a duplicated
        # name, matching a first-match linear search.
        byName = {}
        for option in self.entity.actions.getOptions():
            byName.setdefault(str(option), option)
        for rule, rhs in list(policy.rules.items()):
            try:
                policy.rules[rule] = byName[rhs]
            except (KeyError, TypeError):
                raise NameError('Unable to find RHS action "%s"' % (rhs,))
        # Cache the policy only after every RHS has been resolved, so a
        # failure never leaves a half-converted table behind
        previous.append(policy)

    def getTable(self, depth=-1, horizon=-1):
        """
        @param depth: the recursive depth for the desired policy (default is maximum depth solved)
        @param horizon: the horizon for the desired policy (default is maximum horizon solved)
        @type depth, horizon: int
        @return: a given policy table
        @rtype: L{PWLTable}
        @warning: Will raise C{IndexError} if no policy has been solved for the given horizon and depth settings
        """
        return self.tables[depth][horizon].getTable()