import operator
import string

import Num

# The type of a plain integer literal (``int``).  Kept at module level
# for any code that still compares index types against it.
inttype = type(1)


class sbd:
    """Symmetric band-diagonal matrix.

    Only the diagonal and the kd sub-diagonals are stored, in the
    (n, kd+1) Float array ``self.d``: element A[i, j] with j <= i <= j+kd
    lives at ``self.d[j, i-j]``, and its mirror A[j, i] shares that same
    cell.  Read in Fortran (column-major) order this is the LAPACK band
    layout for UPLO='L' with LDAB = kd+1, which is how ``solve`` and
    ``multiply`` hand ``self.d`` to LAPACK/BLAS.  Cells ``self.d[j, k]``
    with j + k >= n are padding and never hold a matrix element.

    NOTE(review): this class was recovered from a numbered listing that
    lost the ``class`` line and every ``def`` line.  ``__idx2``,
    ``__getitem__``, ``__setitem__``, ``bd_increment`` and
    ``__getslice__`` are confirmed by code that calls them; ``copy``,
    ``increment`` and ``__pow__`` are inferred from their bodies and
    docstrings -- confirm against the original source.  Original lines
    29-30, 82-91 and 93-94 were lost entirely and are not reconstructed.
    """

    def __init__(self, sz, bw, _data=None):
        """Construct a symmetric band-diagonal matrix of sz*sz with a
        nonzero bandwidth of bw.  bw==0 corresponds to a diagonal matrix.

        _data, when given, is used as the (sz, bw+1) band storage as-is
        (it is not copied).

        Raises ValueError if bw is negative or larger than sz.
        """
        # A negative bandwidth would give a zero/negative-width storage
        # array and an invalid KD for LAPACK.
        if bw < 0:
            raise ValueError('Bandwidth must be non-negative')
        if bw > sz:
            raise ValueError('Bandwidth too large for size')
        self.n = sz
        self.kd = bw
        self.ldab = self.kd + 1
        if _data is None:
            self.d = Num.zeros((self.n, self.ldab), Num.Float)
        else:
            self.d = _data
        self.shape = (sz, sz)

    def copy(self):
        """Return an independent sbd: copy the data, not just the data
        description."""
        return sbd(self.n, self.kd, Num.array(self.d, copy=True))

    def __idx2(self, key):
        """Convert an (x, y) position in the full virtual matrix into the
        (row, column) index of ``self.d`` where that element is stored.

        Raises TypeError unless key is a pair of integers, and IndexError
        if an index is negative or >= n, or the position is out of band.
        """
        if len(key) != 2:
            raise TypeError('Need two integer indices')
        try:
            # operator.index accepts any integer-like (int, long, array
            # integer scalars) and rejects floats.
            i = operator.index(key[0])
            j = operator.index(key[1])
        except TypeError:
            raise TypeError('Need two integer indices')
        diagdist = abs(j - i)
        minidx = min(i, j)
        if minidx < 0:
            raise IndexError('indices must be positive')
        # Without this check an in-band position past the end of the
        # matrix would silently address a padding cell of self.d.
        if max(i, j) >= self.n:
            raise IndexError('index out of range: (%d,%d) for size %d'
                             % (i, j, self.n))
        if diagdist > self.kd:
            raise IndexError('out of band: %d/%d/(%d,%d)'
                             % (diagdist, self.kd,
                                self.d.shape[0], self.d.shape[1]))
        return (minidx, diagdist)

    def __getitem__(self, key):
        """Return element key=(x, y) of the full virtual matrix (as if it
        were a full square matrix).

        Positions outside the band -- and negative or out-of-range
        indices -- read as 0.0 instead of raising.
        """
        try:
            k0, k1 = self.__idx2(key)
        except IndexError:
            return 0.0
        return self.d[k0, k1]

    def __setitem__(self, key, value):
        """Set element key=(x, y); its mirror (y, x) shares the same cell.

        Raises IndexError for positions outside the band or the matrix.
        """
        k0, k1 = self.__idx2(key)
        self.d[k0, k1] = value

    def increment(self, key, delta):
        """Increment a single value (and its symmetric partner, if off the
        diagonal, since both share one storage cell)."""
        k0, k1 = self.__idx2(key)
        self.d[k0, k1] = delta + self.d[k0, k1]

    def bd_increment(self, key, delta):
        """Block diagonal increment: add the square block delta to the
        sub-matrix whose top-left corner is (key, key).

        Only the upper triangle delta[i, i:] is read, so delta is assumed
        symmetric.  The whole block is validated before anything is
        modified.

        Raises ValueError if delta is not a square 2-D array, and
        IndexError if the block does not fit inside the matrix or is
        wider than the band.
        """
        if len(delta.shape) != 2 or delta.shape[0] != delta.shape[1]:
            raise ValueError('delta must be a square 2-D array')
        n = delta.shape[0]
        if key < 0 or key + n > self.n:
            raise IndexError('block of size %d at %d does not fit in size %d'
                             % (n, key, self.n))
        if n - 1 > self.kd:
            raise IndexError('block of size %d exceeds bandwidth %d'
                             % (n, self.kd))
        for i in range(n):
            # Storage row key+i holds column key+i of the lower triangle,
            # A[key+i, key+i], A[key+i+1, key+i], ...; by symmetry that is
            # row i of delta from the diagonal rightwards.
            wrk = self.d[key + i, :n - i]
            Num.add(wrk, delta[i, i:], wrk)

    def __pow__(self, other):
        """Return a new sbd whose stored entries are self.d ** other.

        This is an element-wise power of the band storage (padding cells
        included), not a matrix power.
        """
        return sbd(self.n, self.kd, (self.d) ** other)

    def __getslice__(self, i, j):
        """Return rows i..j-1 of the full (dense) matrix as a new Float
        array of shape (rows, n).

        This is copy semantics, not shared reference.  j is clipped to n,
        and None is returned when the clipped range is empty.
        Raises IndexError if i or j is negative.
        """
        if i < 0 or j < 0:
            raise IndexError('indices must be positive')
        high = min(j, self.n)
        # Clip the start against the clipped end so that a start beyond n
        # yields an empty range instead of a negative array size.
        low = min(i, high)
        if high == low:
            return None
        rv = Num.zeros((high - low, self.n), Num.Float)
        # Entries on or right of the diagonal: storage row r holds column r
        # of the lower triangle, which by symmetry is row r from the
        # diagonal rightwards.
        for r in range(low, high):
            e = min(r + self.kd + 1, self.n)
            rv[r - low, r:e] = self.d[r, 0:e - r]
        # Entries on or below the diagonal: column r of the result, for
        # every column whose band reaches into rows low..high-1.
        for r in range(max(0, low - self.kd), high):
            e = min(r + self.kd + 1, self.n)
            ee = min(e, high) - low
            q = max(low - r, 0)
            rv[q + r - low:ee, r] = self.d[r, q:ee - (r - low)]
        return rv


def solve(a, b0):
    """Solve a*x = b0, where a is class sbd: symmetric, positive
    definite, and band-diagonal.

    b0 is either one right-hand side of length a.n or a 2-D array of
    shape (nrhs, a.n) with one right-hand side per row.  The solution is
    returned as a new Float array of shape (nrhs, a.n); a 1-D b0 gives
    shape (1, a.n).  b0 itself is left untouched.

    Note that this destroys the contents of a: dpbsv overwrites the band
    storage a.d with its Cholesky factor.

    Raises ValueError if b0's shape does not match a.n, and
    NoSolutionError if dpbsv reports a non-zero info code.
    NOTE(review): NoSolutionError is not defined in the visible listing
    -- confirm it still exists in this module.
    """
    # Always copy: dpbsv overwrites its B argument with the solution, and
    # Num.asarray would return b0 itself whenever it is already a Float
    # array, silently clobbering the caller's right-hand side.
    b = Num.array(b0, Num.Float, copy=True)
    if len(b.shape) == 1:
        # One right-hand side becomes a single row; read in Fortran order
        # the (nrhs, n) array is the (LDB = n, NRHS) matrix dpbsv expects.
        b = Num.reshape(b, (1, b.shape[0]))
    # Explicit check rather than an assert: ``python -O`` strips asserts
    # and would let a mis-sized b reach LAPACK.
    if len(b.shape) != 2 or b.shape[1] != a.n:
        raise ValueError('right-hand side has shape %s; expected (%d,) or (nrhs, %d)'
                         % (b.shape, a.n, a.n))
    import lapack_dpb
    result = lapack_dpb.dpbsv('L', a.n, a.kd, b.shape[0], a.d,
                              a.ldab, b, max(1, a.n), 0)
    if result['info'] != 0:
        raise NoSolutionError('Linear system has no solution. Lapack_dpb.dpbsv info code=%d' % result['info'])
    return b


def multiply(a, x):
    """Calculate a*x, where a is class sbd (symmetric, band-diagonal).

    x must be a 1-D vector of length a.n; the product is returned as a
    new 1-D Float array.  The work is done by BLAS dsbmv, which computes
    y := alpha*A*x + beta*y, here with alpha = 1, beta = 0, UPLO='L'.
    A product does not require a to be positive definite.

    Raises ValueError if x is not 1-D of length a.n, and RuntimeError if
    the dblas wrapper returns a non-zero status.
    """
    xx = Num.asarray(x, Num.Float)
    # Explicit checks rather than asserts: under ``python -O`` asserts
    # vanish and a wrong-sized x would be handed straight to BLAS.
    if len(xx.shape) != 1 or xx.shape[0] != a.n:
        raise ValueError('x has shape %s; expected (%d,)' % (xx.shape, a.n))
    import dblas
    y = Num.zeros(xx.shape, Num.Float)
    # alpha (1.0) and beta (0.0) are passed as 1-element Float arrays.
    result = dblas.dsbmv('L', a.n, a.kd,
                         Num.ones((1,), Num.Float), a.d,
                         a.ldab, xx, 1,
                         Num.zeros((1,), Num.Float),
                         y, 1)
    if result != 0:
        raise RuntimeError('dblas.dsbmv returned status %r' % (result,))
    return y


def test1():
    """Element access: set/get round trips, symmetric reads, and the
    IndexError raised when writing outside the band."""
    x = sbd(100, 10)
    x[0, 0] = 1
    assert x[0, 0] == 1
    x[99, 99] = 2
    assert x[99, 99] == 2
    x[10, 0] = 10
    assert x[10, 0] == 10
    # (0, 10) and (10, 0) share one storage cell.
    assert x[0, 10] == 10
    # (12, 1) is 11 off the diagonal, beyond the bandwidth of 10.
    try:
        x[12, 1] = 2
    except IndexError:
        pass
    else:
        raise RuntimeError("whoops")


def test2():
    """solve() on a small symmetric positive definite system whose exact
    solution is known."""
    # NOTE(review): ``err`` is not defined in the visible listing --
    # confirm it still exists in this module.
    x = sbd(6, 2)
    b = Num.zeros((6,), Num.Float) + 10
    for i in range(6):
        x[i, i] = 2
    x[2, 1] = 0.5
    y = solve(x, b)
    # Rows 0 and 3-5 decouple: 2*y = 10.  Rows 1 and 2 give
    # 2*y1 + 0.5*y2 = 10 and 0.5*y1 + 2*y2 = 10, so y1 = y2 = 4.
    y_t = Num.array(([5, 4, 4, 5, 5, 5],))
    assert err(y_t, y) < 1e-9


# NOTE(review): the def line was lost from the listing; the name is
# inferred from the __main__ call list -- confirm.
def testm():
    """multiply(): a*e_k must equal column k of the symmetric matrix.

    The products are still printed, and they are now also checked.
    """
    a = sbd(6, 2)
    a[4, 3] = 1
    a[4, 4] = 2
    a[5, 3] = 3
    a[5, 5] = -1
    cases = (
        ((1, 0, 0, 0, 0, 0), [0, 0, 0, 0, 0, 0]),
        ((0, 0, 0, 0, 1, 0), [0, 0, 0, 1, 2, 0]),
        ((0, 0, 0, 0, 0, 1), [0, 0, 0, 3, 0, -1]),
    )
    for unit, column in cases:
        y = multiply(a, Num.array(unit, Num.Float))
        print(y)
        # Small integer entries, so the BLAS result is exact.
        assert list(y) == column


# NOTE(review): the def line was lost from the listing; the name is
# inferred from the __main__ call list -- confirm.
def testa():
    """Slicing: x[:] and x[5:6] return dense copies with both triangles
    filled in.  (x[i:j] dispatches to __getslice__ under Python 2.)"""
    x = sbd(6, 2)
    x[4, 3] = 1
    x[4, 4] = 2
    x[5, 3] = 3
    x[5, 5] = -1

    y = x[:]
    assert y[4, 3] == 1
    assert y[4, 4] == 2
    assert y[5, 3] == 3
    assert y[4, 5] == 0
    # Upper-triangle mirrors of the stored lower entries.
    assert y[3, 4] == 1
    assert y[3, 5] == 3
    assert y[3, 3] == 0
    assert y[5, 5] == -1

    y = x[5:6]
    assert y[0, 5] == -1
    assert y[0, 3] == 3
    assert y[0, 4] == 0


def testbdi():
    """bd_increment(): add a 2x2 and then a 3x3 symmetric block and check
    the touched and untouched entries."""
    x = sbd(6, 2)
    y = Num.array([[1, 2], [2, 3]])
    x.bd_increment(1, y)
    assert x[1, 1] == 1
    assert x[2, 2] == 3
    assert x[2, 1] == 2
    assert x[1, 2] == 2
    assert x[0, 0] == 0
    assert x[3, 3] == 0
    assert x[3, 2] == 0
    # A second block overlapping the first accumulates onto it.
    z = Num.array([[-1, 0, 1], [0, 0, 0], [0, 0, 0]])
    x.bd_increment(0, z)
    assert x[0, 0] == -1
    assert x[0, 2] == 1
    assert x[1, 1] == 1
    assert x[1, 2] == 2


if __name__ == '__main__':
    # Run the module's self-tests in their original order.
    test1()
    test2()
    testa()
    testbdi()
    testm()