1 import math
2
3 import LinearAlgebra
4
5 import ortho_poly
6
7 from gmisclib import Num
8 from gmisclib import die
9 from gmisclib import gpk_lsq
10
11
15
16
# Default lower limit on the ratio svr = min(q)/max(q) computed in
# ortho_fit_svd; it is applied when that function is called with
# svr_lim=None, and a smaller ratio is reported as a nearly degenerate fit.
SING_VAL_RATIO_LIM = 1e-6
18
19
22 RuntimeError.__init__(self, "Support for fit is poor: %s" % s)
23
24
def __init__(self, polys, coefs, err=None, rank=None, fit=None):
    """Store the pieces that describe one fitted expansion.

    NOTE(review): the enclosing class header is not visible next to this
    method; the calls ``solution(op, coefs, err=..., rank=...)`` in the
    fitting functions suggest this is the constructor of ``solution`` --
    confirm.

    polys -- basis object; stored as self.p.
    coefs -- expansion coefficients; stored as self.c.
    err   -- optional error measure of the fit; stored as self.e.
    rank  -- optional rank reported by the solver; stored as self.rank.
    fit   -- optional precomputed fitted values; converted with
             Num.array() when supplied.
    """
    self.p, self.c = polys, coefs
    self.e, self.rank = err, rank
    # None means "no precomputed fit is cached"; only a supplied fit is
    # converted to an array.
    self._fit = None
    if fit is not None:
        self._fit = Num.array(fit)
35
37 if self._fit is None:
38 self._fit = self.p.expand(self.c)
39 return self._fit
40
41
43 sum = Num.zeros(time.shape, Num.Float)
44 for i in range(len(self.c)):
45
46
47 sum += self.p.P(i) * self.c[i]
48 return sum
49
50
52 p = x / Num.sum(x)
53 return math.exp(-Num.sum(p * Num.log(p)))
54
55
56
57
58
def poly(x=None, n=None, name='Chebyshev2'):
    """Build an orthogonal-polynomial object by delegating to ortho_poly.F.

    x, n -- forwarded unchanged as the keyword arguments of the same names.
    name -- polynomial family, forwarded as the first positional argument
            (default 'Chebyshev2').

    Returns whatever ortho_poly.F returns.
    """
    family = name
    return ortho_poly.F(family, n=n, x=x)
61
62
def ortho_fit_svd(x, y, wt, N, svr_lim=None, svrls=None, name='Chebyshev2'):
    """This uses a SVD fit to handle ill-constrained fits.

    Fits the first N basis functions of the family `name` (obtained from
    ortho_poly.F(name, x=x)) to y, weighting every sample by
    op.wt() * wt.

    Arguments:
    x       -- sample positions; every |x| must be <= 1 (1e-8 slack).
    y       -- data values, one per sample.
    wt      -- per-sample weights; multiplied by the basis' own op.wt().
    N       -- number of basis functions, 0 <= N < 100.
    svr_lim -- smallest acceptable ratio min(q)/max(q); None selects
               SING_VAL_RATIO_LIM.
    svrls   -- third argument handed to gpk_lsq.linear_least_squares;
               None selects 0.5/sqrt(len(y)).
    name    -- basis family passed to ortho_poly.F.

    Returns a tuple (soln, swt, svr): a `solution` object, the sum of the
    combined weights, and the ratio min(q)/max(q).  For N == 0 the
    solution has no coefficients and svr is None.

    Raises:
    AssertionError -- N out of range, x outside [-1, 1], or unexpected
                      array shapes.
    NotEnoughData  -- len(y) < N, or at most N samples have a combined
                      weight above 1e-10.
    BadDataSupport -- the ratio svr is below svr_lim.
    Exception      -- "Next line is obsolete!"; see the note in the body.
                      As written, every call with N > 0 that passes the
                      checks above ends here.
    """
    assert N >= 0
    assert N < 100
    EPS = 1e-8
    die.note('y.shape', y.shape)
    die.note('len(y)', len(y))
    die.note('N', N)
    assert Num.alltrue(Num.absolute(x)<=1+EPS), "Time must be normalized into [-1,1]"

    if len(y) < N:
        raise NotEnoughData("y")
    op = ortho_poly.F(name, x=x)
    # Combined weight: the basis' own weight function times the caller's.
    wwt = op.wt() * wt
    rwwt = Num.sqrt(wwt)
    if N == 0:
        # Nothing to fit: empty coefficient vector, no singular-value ratio.
        return (solution(op, Num.zeros((0,), Num.Float)),
                Num.sum(wwt), None)

    assert len(wwt.shape) == 1
    # Scaling data and basis by sqrt(weight) turns the weighted problem
    # into an ordinary least-squares problem.
    yw = y * rwwt
    assert len(yw.shape) == 1
    die.note('yw.shape', yw.shape)
    die.note('len(yw)', len(yw))

    # Need more effectively-weighted samples than unknowns.
    if Num.sum(wwt > 1e-10) <= N:
        raise NotEnoughData("wt")
    basisvec = [op.P(i) * rwwt for i in range(N)]
    basisw = Num.transpose(basisvec)
    assert len(basisw.shape) == 2

    if svrls is None:
        svrls = 0.5/math.sqrt(len(y))

    # Deliberate guard left by the author: the call below uses an obsolete
    # gpk_lsq interface, so the function stops here.
    # NOTE(review): the raise is assumed to be unconditional (function
    # level) rather than part of the `if svrls is None` branch -- confirm
    # against the canonical source.
    raise Exception("Next line is obsolete!")

    # ---- Unreachable while the guard above is in place. ----
    soln, fitw, rank, q = gpk_lsq.linear_least_squares(basisw, yw, svrls)

    assert len(soln) == N

    try:
        svr = min(q) / max(q)
    except ZeroDivisionError:
        die.warn('Largest singular value is apparently zero!')
        raise

    if svr_lim is None:
        svr_lim = SING_VAL_RATIO_LIM
    if svr < svr_lim:
        raise BadDataSupport('Ortho_fit is nearly degenerate (svr=%g::%g)' % (svr, svr_lim))

    assert len(soln) == N
    swt = Num.sum(wwt)
    # Error is the weighted mean squared residual of the fit.
    soln = solution(op, soln, err=Num.sum((fitw-yw)**2)/swt, rank=rank)
    return (soln, swt, svr)
142
143
# Default fitting routine: ortho_fit is simply another name for the
# SVD-based ortho_fit_svd.
ortho_fit = ortho_fit_svd
145
146
def ortho_fit_reg(x, y, wt, N, regstr, regtgt,
                  rscale=None, name='Chebyshev2'):
    """This uses linear regularization to handle ill-constrained fits.

    Fits the first N basis functions of the family `name` (obtained from
    ortho_poly.F(name, x=x)) to y, weighting every sample by
    op.wt() * wt, via gpk_lsq.reg_linear_least_squares.

    Arguments:
    x       -- sample positions; every |x| must be <= 1 (1e-8 slack).
    y       -- data values, one per sample.
    wt      -- per-sample weights; multiplied by the basis' own op.wt().
    N       -- number of basis functions, 0 <= N < 100.
    regstr, regtgt, rscale
            -- passed straight through to
               gpk_lsq.reg_linear_least_squares.
    name    -- basis family passed to ortho_poly.F.

    Raises:
    AssertionError -- N out of range, x outside [-1, 1], or unexpected
                      array shapes.
    NotEnoughData  -- len(y) < N, or at most N samples have a combined
                      weight above 1e-10.
    Exception      -- "something wrong here -- inconsistent return
                      values": the author's guard, raised both for
                      N == 0 and at the end of a successful fit.  The
                      N == 0 path would return a 3-tuple and the normal
                      path a 2-tuple, so as written this function never
                      returns normally.
    """
    assert N >= 0
    assert N < 100
    EPS = 1e-8
    die.note('y.shape', y.shape)
    die.note('len(y)', len(y))
    die.note('N', N)
    assert Num.alltrue(Num.absolute(x)<=1+EPS), "Time must be normalized into [-1,1]"

    if len(y) < N:
        raise NotEnoughData("y")
    op = ortho_poly.F(name, x=x)
    # Combined weight: the basis' own weight function times the caller's.
    wwt = op.wt() * wt
    rwwt = Num.sqrt(wwt)
    if N == 0:
        # Deliberate guard: the return below is unreachable and disagrees
        # in length with the (soln, swt) pair at the end of the function.
        raise Exception("something wrong here -- inconsistent return values")
        return (solution(op, Num.zeros((0,), Num.Float)),
                Num.sum(wwt), None)

    assert len(wwt.shape) == 1
    # Scaling data and basis by sqrt(weight) turns the weighted problem
    # into an ordinary least-squares problem.
    yw = y * rwwt
    assert len(yw.shape) == 1
    die.note('yw.shape', yw.shape)
    die.note('len(yw)', len(yw))

    # Need more effectively-weighted samples than unknowns.
    if Num.sum(wwt > 1e-10) <= N:
        raise NotEnoughData("wt")
    basisvec = [op.P(i) * rwwt for i in range(N)]
    basisw = Num.transpose(basisvec)
    assert len(basisw.shape) == 2

    rlss = gpk_lsq.reg_linear_least_squares(basisw, yw, regstr,
                                            regtgt, rscale=rscale)

    assert len(rlss.x) == N
    swt = Num.sum(wwt)
    rank = rlss.eff_rank()
    # Error is the weighted mean squared residual of the fit.
    soln = solution(op, rlss.x, err=Num.sum((rlss.fit()-yw)**2)/swt, rank=rank)
    # Deliberate guard (see docstring); the return below is unreachable.
    raise Exception("something wrong here -- inconsistent return values")
    return (soln, swt)
199
200
201
203 wwt = soln.p.wt() * wt
204 r = data - soln.expand(x)
205 return Num.sum(wwt * r**2)/Num.sum(wwt)
206
207
209 wwt = soln.p.wt() * wt
210 r = data - soln.expand(x)
211 return Num.sum(wwt*Num.absolute(r))/Num.sum(wwt)
212
213
214
215
216
217
218
def eff_perplexity(wt, correlation_length, threshold = 0.2, name='Chebyshev2'):
    """Count the significant eigenvalues of the weighted basis products.

    With n = len(wt), the first n-1 basis functions of the family `name`
    are scaled by sqrt(wt * op.wt()), their matrix of mutual inner
    products is formed, and its eigenvalues are computed.  The result is
    the number of eigenvalues larger than threshold**2 times the largest
    one, divided by hypot(correlation_length, 1.0).

    wt                 -- one-dimensional array of weights.
    correlation_length -- enters only through the hypot() divisor.
    threshold          -- relative cutoff; it is squared before use.
    name               -- basis family passed to ortho_poly.F.

    Raises AssertionError if wt is not one-dimensional.
    """
    assert len(wt.shape) == 1
    n = wt.shape[0]
    assert wt.shape == (n,)
    # NOTE(review): n is passed positionally here, while the other callers
    # in this file use the x=/n= keywords -- confirm that the second
    # positional parameter of ortho_poly.F is the intended one.
    op = ortho_poly.F(name, n)
    sqwt = Num.sqrt(wt * op.wt())

    rows = [op.P(k) * sqwt for k in range(n - 1)]
    basis = Num.array(rows, Num.Float)
    gram = Num.matrixmultiply(basis, Num.transpose(basis))
    assert gram.shape == (n-1, n-1), "xx.shape=%s, n=%d" % (str(gram.shape), n)
    ev = LinearAlgebra.Heigenvalues(gram)

    mxev = ev[Num.argmax(ev)]
    cutoff = threshold**2 * mxev

    n_significant = Num.sum(ev > cutoff)
    return n_significant / math.hypot(correlation_length, 1.0)
237
238
239
241 import RandomArray
242 N = 2
243 M = 10
244 wt1 = RandomArray.exponential(1.0, (M,))
245 wt = Num.concatenate((wt1, wt1[::-1]))
246
247 y1 = RandomArray.exponential(1.0, (M,))
248 y = Num.concatenate((y1, y1[::-1]))
249 ya = Num.sum(y*wt)/Num.sum(wt)
250 y /= ya
251
252 x = (Num.arrayrange(2*M)-M+0.5)/float(2*M)
253 soln, swt, svr = ortho_fit(x, y, wt, N, svr_lim=None, svrls=None, name='Legendre')
254 assert svr>0.5 and svr<=1.0
255 assert abs(swt - 2*M) < 5*math.sqrt(2*M)
256 assert abs(soln.c[0]-1) < 1e-6
257 assert abs(soln.c[1]) < 1e-6
258
259
261 import RandomArray
262 N = 5
263 M = 60
264 x = (Num.arrayrange(2*M)-M+0.5)/float(2*M)
265 op = ortho_poly.Legendre(x=x)
266 y = op.P(0) + op.P(1) - op.P(2) - op.P(3)
267 wt = RandomArray.exponential(1.0, (2*M,))
268 soln, swt, svr = ortho_fit(x, y, wt, N, svr_lim=None, svrls=None, name='Legendre')
269 assert Num.alltrue( Num.absolute(soln.c - [1, 1, -1, -1, 0]) < 1e-6)
270 assert svr > 0.5 and svr<=1.0
271 assert abs(swt - Num.sum(wt)) < 1e-6
272 fitdiff = soln.fit() - y
273 assert Num.sum(fitdiff**2) < 0.001
274 diff = soln.expand(x) - y
275 assert Num.sum(diff**2) < 0.001
276
277
# When executed as a script, run the module's two self-tests in order.
if __name__ == '__main__':
    test_r()
    test_expand()
281