1 """Utility functions and base testing class for verifying PWL code
2 @author: David V. Pynadath <pynadath@isi.edu>
3 """
4 from teamwork.math.Keys import *
5 from teamwork.math.KeyedMatrix import *
6 from teamwork.math.KeyedTree import *
7
8 from random import random,uniform
9 import unittest
10
12 """Generate a random vector, with random constant field
13 @param keys: the keys to consider using in the vector
14 @type keys: L{Key}[]
15 @param fillProb: the probability that each key will be present in the vector (default is 0.5)
16 @type fillProb: float
17 @rtype: L{KeyedVector}"""
18 row = KeyedVector()
19 for key in keys:
20 if random() <= fillProb:
21
22 row[key] = uniform(-1.,1.)
23 if random() <= fillProb:
24 row[keyConstant] = uniform(-1.,1.)
25 return row
26
42
44 """Generate a random vector, with constant field of 1
45 @param keys: the keys to consider using in the vector
46 @type keys: L{Key}[]
47 @param fillProb: the probability that each key will be present in the vector (default is 0.5)
48 @type fillProb: float
49 @rtype: L{KeyedVector}"""
50 vector = makeVector(keys,fillProb)
51 vector[keyConstant] = 1.
52 vector.fill(keys)
53 return vector
54
56 """Generate a random hyperplane
57 @param keys: the keys to consider using in the vector
58 @type keys: L{Key}[]
59 @param fillProb: the probability that a given key will be present in the weights vector (default is 0.5)
60 @type fillProb: float
61 @rtype: L{KeyedPlane}"""
62 weights = makeVector(keys,fillProb)
63 weights[keyConstant] = 0.
64 weights.fill(keys)
65 threshold = random()
66 return KeyedPlane(weights,threshold)
67
69 """Constructs a random dynamics tree
70 @param keys: the keys to consider using in the vector
71 @type keys: L{Key}[]
72 @param depth: the depth of the tree to return
73 @type depth: int
74 @rtype: L{KeyedTree} instance"""
75 tree = _makeTree(keys,depth,fillProb)
76 tree.freeze()
77 return tree
78
80 """Constructs a random dynamics tree
81 @param keys: the keys to consider using in the vector
82 @type keys: L{Key}[]
83 @param depth: the depth of the tree to return
84 @type depth: int
85 @rtype: L{KeyedTree} instance"""
86 tree = KeyedTree()
87 if depth > 0:
88 plane = makePlane(keys,fillProb)
89 falseTree = makeTree(keys,depth-1,fillProb)
90 trueTree = makeTree(keys,depth-1,fillProb)
91 tree.branch(plane,falseTree,trueTree,pruneT=False,pruneF=False)
92 else:
93 dynamics = makeDynamics(keys,fillProb)
94 tree.makeLeaf(dynamics)
95 return tree
96
98 """Base class for testing PWL code. Intended to be abstract class
99 @cvar precision: number of significant digits to check for equality
100 @type precision: int
101 @cvar agents: a list of entity names to use as test cases
102 @type agents: str[]
103 @cvar features: a list of state features to use as test case
104 @type features: str[]
105 """
106 features = []
107 agents = []
108 precision = 8
109
114
116 """Tests equality of given vectors
117 @type old,new: L{KeyedVector}
118 """
119 self.assert_(isinstance(old,KeyedVector))
120 self.assert_(isinstance(new,KeyedVector))
121 for key in old.keys():
122 self.assert_(new.has_key(key))
123 self.assertAlmostEqual(old[key],new[key],self.precision)
124 for key in new.keys():
125 self.assert_(old.has_key(key))
126
128 """Tests equality of given matrices
129 @type old,new: L{KeyedMatrix}
130 """
131 self.assert_(isinstance(old,KeyedMatrix))
132 self.assert_(isinstance(new,KeyedMatrix))
133 for key in old.keys():
134 self.assert_(new.has_key(key))
135 self.verifyVector(old[key],new[key])
136 for key in new.keys():
137 self.assert_(old.has_key(key))
138
147
149 """Tests equality of given trees
150 @type old,new: L{KeyedTree}
151 """
152 self.assert_(isinstance(old,KeyedTree))
153 self.assert_(isinstance(new,KeyedTree))
154 if old.isLeaf():
155 self.assert_(new.isLeaf())
156 if isinstance(old.getValue(),KeyedMatrix):
157 matrix1 = old.getValue()
158 matrix2 = new.getValue()
159 self.verifyMatrix(matrix1,matrix2)
160 else:
161 self.assertEqual(old.getValue(),new.getValue())
162 else:
163 self.assert_(not new.isLeaf())
164 self.assertEqual(len(old.split),len(new.split))
165 for index in range(len(old.split)):
166 self.verifyPlane(old.split[index],new.split[index])
167 self.verifyTree(old.falseTree,new.falseTree)
168 self.verifyTree(old.trueTree,new.trueTree)
169