1 from teamwork.policy.pwlTable import PWLTable
2 from teamwork.math.Keys import WorldKey
3 from teamwork.math.KeyedVector import KeyedVector
4 from teamwork.math.KeyedMatrix import KeyedMatrix
5
6 import copy
7 import random
8 import unittest
9
11
13 """
14 @param dimension: the size of the vector (default is 2)
15 @type dimension: int
16 @param probability: if C{True}, then make sure vector values are nonnegative and sum to 1. (default is C{False})
17 @type probability: bool
18 @return: a random vector
19 @rtype: L{KeyedVector}
20 """
# NOTE(review): the `def` line for this body and all indentation are missing
# from this listing; callers elsewhere in the file invoke it as
# self.generateVector(...) -- confirm the signature against the real file.
# Builds a dict of WorldKey -> float and wraps it in a KeyedVector.
# A vector needs at least one entry.
21 self.assert_(dimension > 0)
# Mass still to be handed out; only consulted when `probability` is true.
22 total = 1.
23 vector = {}
# Fill the first dimension-1 entries, one per world index.
24 for world in range(dimension-1):
25 key = WorldKey({'world': world})
26 if probability:
# Take a random fraction of what is left, then deduct it, so the entries
# stay nonnegative and their running sum never exceeds 1.
27 value = random.random()*total
28 total -= value
29 else:
# random.random() is in [0, 1), so this is a draw from [-1, 1).
30 value = random.random()*2. - 1.
31 vector[key] = value
# The last entry is handled separately: in probability mode it receives all
# of the leftover mass, which makes the entries sum to 1.
32 key = WorldKey({'world': dimension-1})
33 if probability:
34 vector[key] = total
35 else:
36 vector[key] = random.random()*2. - 1.
37 return KeyedVector(vector)
38
# NOTE(review): the `def` line for this body and all indentation are missing
# from this listing; it is invoked elsewhere in the file as self.generateLHS()
# -- confirm the signature against the real file.
40 """
41 Generates an initialized table with random LHS conditions, but no RHS
42 @param attributes: the number of LHS attributes (default is 3)
43 @type attributes: int
44 @param dimension: the size of the vector (default is 2)
45 @type dimension: int
46 @return: a table with random LHS attributes
47 @rtype: L{PWLTable}
48 """
49 table = PWLTable()
# Keep adding random attributes (each with the single value 0.) until the
# table reports the requested number; the length is re-read on every pass.
50 while len(table.attributes) < attributes:
# NOTE(review): generateVector() is called with no arguments, so the
# `dimension` documented above is not forwarded here -- confirm whether
# that is intended.
51 table.addAttribute(self.generateVector(),0.)
# initialize() is called once, after all attributes have been added.
52 table.initialize()
53 return table
54
def generateRHS(self,table,actions=2,dimension=2,matrix=False,
                probability=False,rules=None):
    """
    Fills in random RHS value function for the given table.

    Rule indices are drawn at random, without replacement, from
    C{range(size)}, where C{size} is the product over all attributes of
    (number of entries in the attribute's second element + 1).  Every chosen
    rule gets a fresh dictionary holding one random vector (or matrix) per
    action, stored under the key C{'Action <n>'}.
    @param table: initialized table
    @type table: L{PWLTable}
    @param actions: the number of actions to define value function over (default is 2)
    @type actions: int
    @param dimension: dimension of arrays in RHS values (default is 2)
    @type dimension: int
    @param matrix: if C{True}, use matrices in RHS; otherwise, use vectors (default is C{False}, i.e., vectors)
    @type matrix: bool
    @param probability: if C{True}, then make sure vector values are nonnegative and sum to 1. (default is C{False})
    @type probability: bool
    @param rules: the maximum number of rules to generate (default is as many as supported by LHS combinations)
    @type rules: int
    @warning: the table is modified in place
    """
    # Count the available rule indices.
    size = 1
    for attribute in table.attributes:
        size *= len(attribute[1]) + 1
    # Never ask for more rules than there are indices to draw from.
    if rules is None:
        rules = size
    else:
        rules = min(rules,size)
    # Pick the generator once; the matrix generator is only looked up when
    # it is actually requested.
    # NOTE(review): generateMatrix is not defined in the visible part of
    # this file -- confirm it exists before calling with matrix=True.
    if matrix:
        generate = self.generateMatrix
    else:
        generate = self.generateVector
    # A real list is required because chosen indices are removed from it;
    # a bare range() object has no remove() where range is not a list.
    remaining = list(range(size))
    while len(table.values) < rules:
        rule = random.choice(remaining)
        remaining.remove(rule)
        table.values[rule] = {}
        for action in range(actions):
            table.values[rule]['Action %d' % (action)] = \
                generate(dimension,probability)
91
# NOTE(review): the `def` line of this test method is missing from the
# listing and indentation has been lost, so the nesting of the statements
# below cannot be read off this text -- confirm against the real file.
# The check is repeated on 10 freshly built tables.
93 for iteration in range(10):
94 table = PWLTable()
# Vectors seen so far, used to spot repeated random draws.
95 cache = {}
# Expected number of attributes; lowered once for every repeated vector.
96 count = 10
# range(count) is evaluated once, so lowering count below does not shorten
# this loop.
97 for attr in range(count):
98 vector = self.generateVector()
99 table.addAttribute(vector,0.)
# NOTE(review): dict.has_key exists only in Python 2.
100 if cache.has_key(vector):
101 count -= 1
102 else:
103 cache[vector] = True
# The same vector (as a copy) is added again with a second value, 0.5.
# NOTE(review): with indentation lost it is not visible whether this call
# sits inside the else branch or directly in the loop body.
104 table.addAttribute(copy.copy(vector),0.5)
105
# The table should hold exactly one attribute per distinct vector.
106 self.assertEqual(len(table.attributes),count)
# Walk consecutive pairs of attributes.
107 for attr in range(len(table.attributes)-1):
108
# Each attribute's array, compared as a list, must be strictly smaller than
# the next one's; the message shows the offending pair.
109 self.assert_(list(table.attributes[attr][0].getArray()) < \
110 list(table.attributes[attr+1][0].getArray()),
111 '%s >= %s' % (table.attributes[attr][0],
112 table.attributes[attr+1][0]))
113
# Each attribute must carry exactly the two values added above, 0. then 0.5.
114 self.assertEqual(len(table.attributes[attr][1]),2)
115 self.assertAlmostEqual(table.attributes[attr][1][0],0.)
116 self.assertAlmostEqual(table.attributes[attr][1][1],0.5)
# The pair loop stops one short of the end, so the last attribute's values
# are checked explicitly.
117 self.assertEqual(len(table.attributes[-1][1]),2)
118 self.assertAlmostEqual(table.attributes[-1][1][0],0.)
119 self.assertAlmostEqual(table.attributes[-1][1][1],0.5)
120
# NOTE(review): the `def` line of this test method is missing from the
# listing and indentation has been lost -- confirm the nesting against the
# real file.
122 """
123 @note: tests C{star} method as well
124 """
# 100 random tables are generated, each with a random LHS and RHS.
125 for iteration in range(100):
126 table = self.generateLHS()
127 self.generateRHS(table)
# Deep copies taken before max() so the source table can be checked for
# unwanted modification afterwards.
128 originalRules = copy.deepcopy(table.rules)
129 originalValue = copy.deepcopy(table.values)
130 maxTable = table.max()
131 self.assertEqual(table.rules,originalRules)
132 self.assertEqual(table.values,originalValue)
133 starTable = maxTable.star()
# Each table is probed with 100 random probability vectors.
134 for subiteration in range(100):
135 state = self.generateVector(probability=True)
136 index = table.index(state)
# Brute-force reference: multiply every action's RHS by the state and keep
# the first action with the strictly largest product.
137 best = None
138 for action,V in table.values[index].items():
139 ER = V*state
140 if best is None or ER > best['value']:
141 best = {'action': action, 'value': ER}
# The max table must select that same action for this state ...
142 action = maxTable[state]
143 self.assertEqual(action,best['action'])
# ... and its stored RHS for that action must give exactly the same product
# (assertEqual, i.e. no floating-point tolerance).
144 index = maxTable.index(state)
145 ER = maxTable.values[index][action]*state
146 self.assertEqual(ER,best['value'])
# The star table, indexed by state, must reproduce the same product too.
147 self.assertEqual(ER,starTable[state]*state)
148
# Run the unittest command-line runner when this file is executed as a script.
149 if __name__ == '__main__':
150 unittest.main()
151