Package teamwork :: Package test :: Package policy :: Module testPWLTable
[hide private]
[frames] | [no frames]

Source Code for Module teamwork.test.policy.testPWLTable

  1  from teamwork.policy.pwlTable import PWLTable 
  2  from teamwork.math.Keys import WorldKey 
  3  from teamwork.math.KeyedVector import KeyedVector 
  4  from teamwork.math.KeyedMatrix import KeyedMatrix 
  5   
  6  import copy 
  7  import random 
  8  import unittest 
  9   
class TestPWLTable(unittest.TestCase):
    """
    Randomized tests for L{PWLTable}: attribute insertion and the
    C{max}/C{star} operations.
    """

    def generateVector(self, dimension=2, probability=False):
        """
        Generates a random vector keyed by C{WorldKey}s 0..dimension-1
        @param dimension: the size of the vector (default is 2)
        @type dimension: int
        @param probability: if C{True}, then make sure vector values are
            nonnegative and sum to 1. (default is C{False})
        @type probability: bool
        @return: a random vector
        @rtype: L{KeyedVector}
        """
        # assertTrue replaces the deprecated assert_ alias (removed in 3.12)
        self.assertTrue(dimension > 0)
        total = 1.
        vector = {}
        for world in range(dimension - 1):
            key = WorldKey({'world': world})
            if probability:
                # Take a random share of the probability mass still unassigned
                value = random.random() * total
                total -= value
            else:
                value = random.random() * 2. - 1.
            vector[key] = value
        key = WorldKey({'world': dimension - 1})
        if probability:
            # The last entry absorbs the remainder so the values sum to 1
            vector[key] = total
        else:
            vector[key] = random.random() * 2. - 1.
        return KeyedVector(vector)

    def generateLHS(self, attributes=2, dimension=2):
        """
        Generates an initialized table with random LHS conditions, but no RHS
        @param attributes: the number of LHS attributes (default is 2)
        @type attributes: int
        @param dimension: the size of the vector (default is 2)
        @type dimension: int
        @return: a table with random LHS attributes
        @rtype: L{PWLTable}
        """
        table = PWLTable()
        # addAttribute merges duplicate vectors (see testAddAttribute), so
        # keep drawing until the table holds enough distinct attributes
        while len(table.attributes) < attributes:
            table.addAttribute(self.generateVector(dimension), 0.)
        table.initialize()
        return table

    def generateRHS(self, table, actions=2, dimension=2, matrix=False,
                    probability=False, rules=None):
        """
        Fills in random RHS value function for the given table
        @param table: initialized table
        @type table: L{PWLTable}
        @param actions: the number of actions to define value function over
            (default is 2)
        @type actions: int
        @param dimension: dimension of arrays in RHS values (default is 2)
        @type dimension: int
        @param matrix: if C{True}, use matrices in RHS; otherwise, use vectors
            (default is C{False}, i.e., vectors)
        @type matrix: bool
        @param probability: if C{True}, then make sure vector values are
            nonnegative and sum to 1. (default is C{False})
        @type probability: bool
        @param rules: the number of rules the table should hold afterwards,
            capped at the number of LHS combinations (default is as many as
            supported by LHS combinations)
        @type rules: int
        @warning: the table is modified in place
        """
        # Number of possible rule indices: product over attributes of
        # (number of values for that attribute + 1)
        size = 1
        for attribute in table.attributes:
            size *= len(attribute[1]) + 1
        if rules is None:
            rules = size
        else:
            rules = min(rules, size)
        # list() so that remove() works on Python 3, where range is immutable
        remaining = list(range(size))
        # Stop when candidates run out (e.g. table already held some values)
        # instead of letting random.choice fail on an empty list
        while remaining and len(table.values) < rules:
            rule = random.choice(remaining)
            remaining.remove(rule)
            table.values[rule] = {}
            for action in range(actions):
                if matrix:
                    # NOTE(review): generateMatrix is not defined in this
                    # module; matrix=True raises AttributeError until it is.
                    rhs = self.generateMatrix(dimension, probability)
                else:
                    rhs = self.generateVector(dimension, probability)
                table.values[rule]['Action %d' % (action)] = rhs

    def testAddAttribute(self):
        """
        Checks that addAttribute keeps attributes unique and sorted, and that
        each attribute records both of the values it was added with.
        """
        for _ in range(10):
            table = PWLTable()
            cache = {}
            count = 10
            for _ in range(10):
                vector = self.generateVector()
                table.addAttribute(vector, 0.)
                # A repeated random vector must not create a new attribute
                if vector in cache:
                    count -= 1
                else:
                    cache[vector] = True
                # Re-adding an equal copy should only extend its value list
                table.addAttribute(copy.copy(vector), 0.5)
            # Test that uniqueness is preserved
            self.assertEqual(len(table.attributes), count)
            # Test that attributes are sorted
            for index in range(1, len(table.attributes)):
                previous = table.attributes[index - 1][0]
                current = table.attributes[index][0]
                self.assertTrue(list(previous.getArray()) <
                                list(current.getArray()),
                                '%s >= %s' % (previous, current))
            # Test that values are correct
            for attribute in table.attributes:
                self.assertEqual(len(attribute[1]), 2)
                self.assertAlmostEqual(attribute[1][0], 0.)
                self.assertAlmostEqual(attribute[1][1], 0.5)

    def testMax(self):
        """
        Compares C{max} against a brute-force search for the best action.
        @note: tests C{star} method as well
        """
        for _ in range(100):
            table = self.generateLHS()
            self.generateRHS(table)
            originalRules = copy.deepcopy(table.rules)
            originalValue = copy.deepcopy(table.values)
            maxTable = table.max()
            # max() must leave the source table unchanged
            self.assertEqual(table.rules, originalRules)
            self.assertEqual(table.values, originalValue)
            starTable = maxTable.star()
            for _ in range(100):
                state = self.generateVector(probability=True)
                # Brute force: highest expected reward in the original table
                index = table.index(state)
                best = None
                for action, V in table.values[index].items():
                    ER = V * state
                    if best is None or ER > best['value']:
                        best = {'action': action, 'value': ER}
                action = maxTable[state]
                self.assertEqual(action, best['action'])
                index = maxTable.index(state)
                ER = maxTable.values[index][action] * state
                # Rewards are floats derived via different tables; compare
                # approximately rather than bit-for-bit
                self.assertAlmostEqual(ER, best['value'])
                self.assertAlmostEqual(ER, starTable[state] * state)
if __name__ == '__main__':
    # Run as a script: hand control to the stdlib runner, which discovers
    # and runs every TestCase defined in this module.
    unittest.main(module='__main__')