1 from LookaheadPolicy import LookaheadPolicy
2 from pwlTable import PWLTable
3
5 """
6 Policy that uses L{PWLTable}s to store action rules
7 """
8
9 - def __init__(self,entity,actions=None,horizon=None):
21
23 """Removes any cached policy tables
24 """
25 self.tables = []
26
29
30 - def execute(self,state=None,choices=[],debug=None,horizon=-1,explain=False):
31 """Execute the policy in the given state
32 @param horizon: the horizon to consider (by default, use the entity's given horizon)
33 @type horizon: int
34 @param choices: the legal actions to consider (default is all available actions)
35 @type choices: L{Action<teamwork.action.PsychActions.Action>}[][]
36 """
37 if state is None:
38 state = self.entity.beliefs
39 if horizon < 0:
40 horizon = self.entity.horizon
41 if self.tables and self.tables[-1]:
42 try:
43 table = self.tables[-1][horizon]
44 except IndexError:
45 table = self.tables[-1][-1]
46 rhs = table[state]
47 else:
48 rhs = None
49 if rhs is None:
50 return LookaheadPolicy.execute(self,state,choices,debug,horizon,
51 explain)
52 else:
53 return rhs,None
54
    def project(self,R,depth=-1,debug=False):
        """
        Project the value function and policy one step further at the given depth
        @param R: the reward function in tabular form
        @type R: L{PWLTable}
        @param depth: the recursive belief depth to compute the value function for (default is deepest level already computed)
        @type depth: int
        @param debug: if C{True}, print intermediate tables (Python 2 print
            statements)
        """
        # Make sure a table list exists at the requested depth, recursively
        # filling in any missing shallower/deeper levels first
        try:
            previous = self.tables[depth]
        except IndexError:
            if depth < 0:
                if len(self.tables) == abs(depth) - 1:
                    # Exactly one level missing; start a fresh horizon list
                    self.tables.append([])
                else:
                    # More than one level missing; fill the next one first
                    self.project(R,depth+1,debug)
            else:
                if len(self.tables) == depth:
                    # Exactly one level missing; start a fresh horizon list
                    self.tables.append([])
                else:
                    # More than one level missing; fill the next one first
                    self.project(R,depth-1,debug)
            previous = self.tables[depth]
        if debug:
            print 'Horizon = %d' % (len(self.tables[depth]))
        # Start from the immediate reward table
        V = R.getTable()
        if len(previous) > 0:
            # Add the expected future value from the previous horizon's
            # policy, folded through each observation's state estimator
            Vstar = previous[-1].star()
            Vstar.prune()
            if debug:
                print 'V*'
                print Vstar
            for omega,SE in self.entity.estimators.items():
                # Compose V* with the estimator for observation omega
                # (direct __mul__ call so the debug keyword can be passed)
                product = Vstar.__mul__(SE,debug=debug)
                if debug:
                    print 'SE(b,%s)' % (omega)
                    print SE
                    print 'V*(SE(b,%s))' % (omega)
                    print product
                V = V.__add__(product)
        if debug:
            print '\tMax (%d rules)' % (len(V.values))
            print V
        # Maximize over actions to get the one-step-longer policy
        policy = V.max()
        previous.append(policy)
        # Replace string RHS labels with the actual action options;
        # the for/else raises when a label matches no known option
        for rule in policy.rules.keys():
            for option in self.entity.actions.getOptions():
                if str(option) == policy.rules[rule]:
                    policy.rules[rule] = option
                    break
            else:
                raise NameError,'Unable to find RHS action "%s"' % \
                      (policy.rules[rule])
122
123 - def getTable(self,depth=-1,horizon=-1):
124 """
125 @param depth: the recursive depth for the desired policy (default is maximum depth solved)
126 @param horizon: the horizon for the desired policy (default is maximum horizon solved)
127 @type depth, horizon: int
128 @return: a given policy table
129 @rtype: L{PWLTable}
130 @warning: Will raise C{IndexError} if no policy has been solved for the given horizon and depth settings
131 """
132 table = self.tables[depth][horizon]
133 return table.getTable()
134