Package gmisclib :: Module sbd_array
[frames] | [no frames]

Source Code for Module gmisclib.sbd_array

  1   
  2  import string 
  3  import Num 
  4   
  # Exact type of a plain Python integer literal (the builtin int itself).
  inttype = int
  6   
  7   
class sbd:
    """Symmetric band-diagonal matrix held in compact band storage.

    Element (i, j) of the n*n matrix, with |i - j| <= kd, is stored at
    d[min(i, j), |i - j|].  (i, j) and (j, i) therefore share one cell, so
    the matrix is symmetric by construction.  Row r of ``d`` holds the
    diagonal element (r, r) followed by the kd elements below it in column
    r.  NOTE(review): for a C-contiguous ``d`` this matches LAPACK's
    column-major lower ('L') band layout, which is presumably why solve()
    and multiply() pass 'L'.  Cells d[r, k] with r + k >= n lie outside the
    matrix; the indexing methods below never address them.

    Attributes:
        n: matrix size (the matrix is n*n).
        kd: bandwidth; 0 means diagonal.
        ldab: kd + 1, the number of stored cells per row of ``d``.
        d: the (n, ldab) band-storage array.
        shape: (n, n), for array-like callers.
    """

    class _OutOfBand(IndexError):
        """Position lies inside the matrix but further than kd from the diagonal."""

    def __init__(self, sz, bw, _data=None):
        """Construct a zero sz*sz symmetric band-diagonal matrix of bandwidth bw.

        bw == 0 corresponds to a diagonal matrix.  ``_data``, if given, is
        used directly (no copy) as the band-storage array; it is not a
        public interface.

        Raises ValueError if bw is negative or larger than sz.
        """
        if bw < 0:
            raise ValueError('Bandwidth must be non-negative')
        if bw > sz:
            raise ValueError('Bandwidth too large for size')
        self.n = sz
        self.kd = bw
        self.ldab = self.kd + 1
        if _data is None:
            self.d = Num.zeros((self.n, self.ldab), Num.Float)
        else:
            self.d = _data  # Not a public interface!
        self.shape = (sz, sz)

    def __copy__(self):
        """Copy the data, not just the data description."""
        return sbd(self.n, self.kd, Num.array(self.d, copy=True))

    def __deepcopy__(self, memo):
        """Deep copy; the band array is the only mutable state, so __copy__ suffices."""
        return self.__copy__()

    # Old misspelled name (copy.deepcopy never looked it up); kept as an
    # alias in case anything calls the mangled name directly.
    __deepcopy = __deepcopy__

    def __idx2(self, key):
        """Map an external (i, j) position to its (row, column) inside self.d.

        Raises:
            TypeError: key is not a pair of plain ints.
            IndexError: an index is negative or >= n.
            sbd._OutOfBand (an IndexError): the position is inside the
                matrix but further than kd from the diagonal.
        """
        if len(key) != 2 or type(key[0]) != int or type(key[1]) != int:
            raise TypeError('Need two integer indices')
        lo = min(key[0], key[1])
        hi = max(key[0], key[1])
        if lo < 0:
            raise IndexError('indices must be positive')
        if hi >= self.n:
            # Without this check a column index past the end would land in
            # the unused padding cells of self.d instead of failing.
            raise IndexError('index out of range: (%d,%d) for %dx%d matrix'
                             % (key[0], key[1], self.n, self.n))
        diagdist = hi - lo
        if diagdist > self.kd:
            raise self._OutOfBand('out of band: %d/%d/(%d,%d)'
                                  % (diagdist, self.kd,
                                     self.d.shape[0], self.d.shape[1]))
        return (lo, diagdist)

    def __getitem__(self, key):
        """Return element key = (i, j) as if this were a full square matrix.

        Positions inside the matrix but outside the band read as 0.0.
        Positions outside the matrix raise IndexError.  A slice key with no
        step (e.g. ``x[a:b]``) is forwarded to __getslice__, so slicing also
        works on Pythons that no longer call __getslice__.
        """
        if isinstance(key, slice):
            if key.step not in (None, 1):
                raise TypeError('sbd slicing does not support a step')
            start = 0 if key.start is None else key.start
            stop = self.n if key.stop is None else key.stop
            return self.__getslice__(start, stop)
        try:
            k0, k1 = self.__idx2(key)
        except self._OutOfBand:
            return 0.0
        return self.d[k0, k1]

    def __setitem__(self, key, value):
        """Set element key = (i, j), and with it the symmetric partner (j, i).

        Raises IndexError for out-of-band or out-of-matrix positions.
        """
        k0, k1 = self.__idx2(key)
        self.d[k0, k1] = value

    def increment(self, key, delta):
        """Add delta to element key = (i, j).

        Because (i, j) and (j, i) share one storage cell, the symmetric
        partner is incremented too when i != j.
        """
        k0, k1 = self.__idx2(key)
        self.d[k0, k1] = delta + self.d[k0, k1]

    def bd_increment(self, key, delta):
        """Add the square block *delta* onto the diagonal, with delta[0, 0] at (key, key).

        Only the upper triangle of delta (delta[i, j] with j >= i) is read;
        the lower triangle is implied by the symmetric storage.

        Raises:
            IndexError: the block does not fit inside the matrix.
            ValueError: the block is wider than the band (delta size - 1 > kd).
        """
        assert len(delta.shape) == 2
        assert delta.shape[0] == delta.shape[1]
        n = delta.shape[0]
        # Validate before touching self.d so a bad call cannot leave a
        # partially applied block.  A negative key would otherwise wrap
        # around to the last rows.
        if key < 0 or key + n > self.n:
            raise IndexError('block of size %d at %d does not fit in %dx%d matrix'
                             % (n, key, self.n, self.n))
        if n - 1 > self.kd:
            raise ValueError('block of size %d exceeds bandwidth %d' % (n, self.kd))
        for i in range(n):
            # Row key+i of the storage holds elements (key+i, key+i .. key+i+kd);
            # the slice is a view, so the in-place add updates self.d.
            wrk = self.d[key + i, :n - i]
            Num.add(wrk, delta[i, i:], wrk)

    def __str__(self):
        """Render the full square matrix, one bracketed row per line."""
        o = ['<sbdarray\n']
        for i in range(self.n):
            o.append('[')
            for j in range(self.n):
                o.append(repr(self[i, j]))
            o.append(']\n')
        o.append('>')
        return ' '.join(o)

    def __repr__(self):
        return self.__str__()

    def __pow__(self, other):
        """Raise every stored band element to the power *other*; returns a new sbd.

        NOTE(review): the implicit zeros outside the band stay zero, so this
        is a true element-wise power only when 0**other == 0 (other > 0).
        """
        return sbd(self.n, self.kd, (self.d) ** other)

    def __getslice__(self, i, j):
        """Return rows i..j-1 of the full matrix as a dense (rows, n) array.

        This is copy semantics, not shared reference.  Returns None when the
        requested range is empty.  Raises IndexError for negative indices.
        """
        if i < 0 or j < 0:
            raise IndexError('indices must be positive')
        high = min(j, self.n)
        # Clamp against high (not j) so a start past the end yields an
        # empty result instead of a negative array dimension.
        low = min(i, high)
        if high == low:
            return None
        rv = Num.zeros((high - low, self.n), Num.Float)
        # Diagonal and upper triangle: row r, columns r .. r+kd.
        for r in range(low, high):
            e = min(r + self.kd + 1, self.n)
            rv[r - low, r:e] = self.d[r, 0:e - r]
        # Lower triangle: column r, for the slice rows from max(low, r)
        # up to r+kd.  Clunky, but correct.
        for r in range(max(0, low - self.kd), high):
            e = min(r + self.kd + 1, self.n)
            ee = min(e, high) - low
            q = max(low - r, 0)
            rv[q + r - low:ee, r] = self.d[r, q:ee - (r - low)]
        return rv
def symmetrize(matrix, direction):
    """Return *matrix* itself, unchanged; *direction* is ignored.

    NOTE(review): presumably a deliberate no-op because sbd objects store a
    single cell per (i, j)/(j, i) pair and so are symmetric by construction.
    Confirm before relying on it for other matrix types.
    """
    unchanged = matrix
    return unchanged
class NoSolutionError(ValueError):
    """Raised by solve() when LAPACK reports that the banded system cannot be solved."""

    def __init__(self, s):
        # Exactly one message argument is required.
        super(NoSolutionError, self).__init__(s)
def solve(a, b0):
    """Solve a*x = b for x, where *a* is a class sbd matrix.

    *a* must be symmetric, positive definite and band-diagonal.  Note that
    this destroys the contents of a: lapack_dpb.dpbsv factors a.d in place.
    *b0* is copied first, so the caller's right-hand side is left untouched.

    b0 may be 1-D (one right-hand side of length a.n) or 2-D with shape
    (nrhs, a.n), one right-hand side per row.  The result always has the
    2-D shape (nrhs, a.n), even when b0 is 1-D.

    Raises:
        ValueError: b0 does not have a.n columns or has more than 2 dims.
        NoSolutionError: dpbsv returned a nonzero info code.
    """
    # Num.array copies (unlike asarray, which returns an existing float
    # array as-is); dpbsv overwrites b with the solution, so without the
    # copy the caller's b0 would be clobbered.
    b = Num.array(b0, Num.Float, copy=True)
    if len(b.shape) == 1:
        b = Num.reshape(b, (1, b.shape[0]))
    # Explicit check rather than assert: under -O a mismatch would reach
    # the native solver with a wrong leading dimension.
    if len(b.shape) != 2 or b.shape[1] != a.n:
        raise ValueError('right-hand side shape %s does not match a %dx%d matrix'
                         % (b.shape, a.n, a.n))
    import lapack_dpb
    result = lapack_dpb.dpbsv('L', a.n, a.kd, b.shape[0], a.d,
                              a.ldab, b, max(1, a.n), 0)
    if result['info'] != 0:
        raise NoSolutionError('Linear system has no solution. Lapack_dpb.dpbsv info code=%d' % result['info'])
    return b
def multiply(a, x):
    """Return the matrix-vector product a*x as a new 1-D float array.

    *a* is a class sbd matrix (symmetric, band-diagonal).  Unlike solve(),
    it need not be positive definite: the BLAS dsbmv operation only
    requires symmetry.  *x* is not modified.

    Raises:
        ValueError: x is not 1-D or its length differs from a.n.
        RuntimeError: the dblas.dsbmv wrapper returned a nonzero code.
    """
    xx = Num.asarray(x, Num.Float)
    # Explicit checks rather than asserts: under -O a bad length would be
    # handed straight to the native BLAS routine.
    if len(xx.shape) != 1:
        raise ValueError('x must be one-dimensional, got shape %s' % (xx.shape,))
    if xx.shape[0] != a.n:
        raise ValueError('x has length %d but the matrix is %dx%d'
                         % (xx.shape[0], a.n, a.n))
    import dblas
    y = Num.zeros(xx.shape, Num.Float)
    # NOTE(review): presumably BLAS dsbmv, y := alpha*A*x + beta*y, called
    # with alpha = 1, beta = 0 and A in lower ('L') band storage.
    result = dblas.dsbmv('L', a.n, a.kd,
                         Num.ones((1,), Num.Float), a.d,
                         a.ldab, xx, 1,
                         Num.zeros((1,), Num.Float),
                         y, 1)
    if result != 0:
        raise RuntimeError('dblas.dsbmv failed with return code %r' % (result,))
    return y
def test1():
    """Check element storage, symmetric reads, and rejection of out-of-band writes."""
    m = sbd(100, 10)
    for (r, c), v in (((0, 0), 1), ((99, 99), 2), ((10, 0), 10)):
        m[r, c] = v
        assert m[r, c] == v
    assert m[0, 10] == 10  # symmetric partner of (10, 0)
    try:
        m[12, 1] = 2  # distance 11 from the diagonal exceeds bandwidth 10
    except IndexError:
        return
    raise RuntimeError("whoops")
def err(a, b):
    """Return the sum of squared element-wise differences between a and b."""
    residual = a - b
    return Num.sum(Num.square(residual))
def test2():
    """Solve a small diagonally dominant banded system and compare with the known answer."""
    m = sbd(6, 2)
    rhs = Num.zeros((6,), Num.Float) + 10
    for k in range(6):
        m[k, k] = 2
    m[2, 1] = 0.5
    got = solve(m, rhs)
    expected = Num.array(([5, 4, 4, 5, 5, 5],))
    assert err(expected, got) < 1e-9
def testm():
    """Print a*e for a few unit vectors e (output is for manual inspection)."""
    a = sbd(6, 2)
    for pos, val in (((4, 3), 1), ((4, 4), 2), ((5, 3), 3), ((5, 5), -1)):
        a[pos] = val
    for unit in ((1, 0, 0, 0, 0, 0), (0, 0, 0, 0, 1, 0), (0, 0, 0, 0, 0, 1)):
        print(multiply(a, Num.array(unit, Num.Float)))
def testa():
    """Check that slicing an sbd yields the expected dense rows of the full matrix."""
    src = sbd(6, 2)
    for pos, val in (((4, 3), 1), ((4, 4), 2), ((5, 3), 3), ((5, 5), -1)):
        src[pos] = val

    full = src[:]
    for (r, c), want in (((4, 3), 1), ((4, 4), 2), ((5, 3), 3), ((4, 5), 0),
                         ((3, 4), 1), ((3, 3), 0), ((5, 5), -1)):
        assert full[r, c] == want

    last_row = src[5:6]
    for c, want in ((5, -1), (3, 3), (4, 0)):
        assert last_row[0, c] == want
def testbdi():
    """Check that bd_increment adds square blocks onto the diagonal symmetrically."""
    m = sbd(6, 2)
    m.bd_increment(1, Num.array([[1, 2], [2, 3]]))
    for pos, want in (((1, 1), 1), ((2, 2), 3), ((2, 1), 2), ((1, 2), 2),
                      ((0, 0), 0), ((3, 3), 0), ((3, 2), 0)):
        assert m[pos] == want
    m.bd_increment(0, Num.array([[-1, 0, 1], [0, 0, 0], [0, 0, 0]]))
    for pos, want in (((0, 0), -1), ((0, 2), 1), ((1, 1), 1), ((1, 2), 2)):
        assert m[pos] == want
if __name__ == '__main__':
    # Run the module's self-tests in their original order.
    for _check in (test1, test2, testa, testbdi, testm):
        _check()