Module ortho_fit
[frames | no frames]

Source Code for Module ortho_fit

  1  import math 
  2   
  3  import LinearAlgebra 
  4   
  5  import ortho_poly 
  6   
  7  from gmisclib import Num 
  8  from gmisclib import die 
  9  from gmisclib import gpk_lsq 
 10   
 11   
class NotEnoughData(RuntimeError):
    """Raised by the fitting routines when there is too little usable data.

    The optional argument x becomes the exception message; the fitters in
    this module pass "y" (too few data points) or "wt" (too few points
    with non-negligible weight).
    """

    def __init__(self, x=""):
        # Explicit base-class call (rather than super()) keeps this usable
        # on interpreters where exceptions are old-style classes.
        RuntimeError.__init__(self, x)
# Default lower limit on the ratio (smallest singular value / largest
# singular value); ortho_fit_svd falls back to it when svr_lim is None.
SING_VAL_RATIO_LIM = 1e-6
class BadDataSupport(RuntimeError):
    """Raised when the data constrain the fit too poorly.

    The argument s is appended to a fixed prefix to form the exception
    message.
    """

    def __init__(self, s=''):
        message = "Support for fit is poor: %s" % s
        RuntimeError.__init__(self, message)
23 24
class solution:
    """Result of a fit: a polynomial basis object plus its coefficients.

    Attributes:
        p:     the basis object (must offer P(i); fit() also calls its
               expand() method).
        c:     sequence of coefficients, one per basis function.
        e:     error value supplied by the fitter, or None.
        rank:  rank reported by the fitter, or None.
        _fit:  cached fitted values, or None until first requested.
    """

    def __init__(self, polys, coefs, err=None, rank=None, fit=None):
        self.p = polys
        self.c = coefs
        self.e = err
        self.rank = rank
        # Keep a private array copy of any precomputed fit.
        self._fit = None
        if fit is not None:
            self._fit = Num.array(fit)

    def fit(self):
        """Return the fitted values, computing and caching them on first use."""
        if self._fit is not None:
            return self._fit
        self._fit = self.p.expand(self.c)
        return self._fit

    def expand(self, time):
        """Return sum_i c[i] * P(i) as an array with the shape of time.

        Only time.shape is read here; the basis values come from
        self.p.P(i).
        NOTE(review): presumably time must be the same sample points the
        basis object was built on -- confirm against ortho_poly.
        """
        total = Num.zeros(time.shape, Num.Float)
        for index, coef in enumerate(self.c):
            total += self.p.P(index) * coef
        return total
49 50
def perplexity(x):
    """Return the perplexity exp(-sum(p * log(p))) of the weights in x.

    x is divided by its sum to give the distribution p.  Entries that are
    exactly zero are dropped before the logarithm is taken, following the
    usual convention that 0 * log(0) contributes nothing to the entropy;
    without this a single zero weight turns the result into nan (or an
    error, depending on the array package).

    Negative entries are deliberately NOT filtered, so invalid input still
    fails in the logarithm instead of being silently accepted.

    NOTE(review): assumes x is a 1-D array -- confirm with callers.
    """
    p = x / Num.sum(x)
    # Keep only the entries that actually carry probability mass.
    p = Num.compress(p != 0, p)
    return math.exp(-Num.sum(p * Num.log(p)))
54 55 56 57 58
def poly(x=None, n=None, name='Chebyshev2'):
    """Build the orthogonal-polynomial object for the family called name.

    Thin convenience wrapper: x and n are forwarded unchanged, as keyword
    arguments, to ortho_poly.F, and its result is returned.
    """
    family = ortho_poly.F(name, x=x, n=n)
    return family
61 62
def ortho_fit_svd(x, y, wt, N, svr_lim=None, svrls=None, name='Chebyshev2'):
    """This uses a SVD fit to handle ill-constrained fits.

    Fits y with the first N basis functions P(0)..P(N-1) of the polynomial
    family name, weighting each point by wt times the family's own weight.

    Arguments:
        x:       sample positions; every |x| must be <= 1 (within 1e-8).
        y:       data values (1-D).
        wt:      per-point weights.
        N:       number of basis functions, 0 <= N < 100.
        svr_lim: smallest acceptable ratio of singular values; defaults to
                 SING_VAL_RATIO_LIM.
        svrls:   singular-value ratio limit handed to the least-squares
                 solver; defaults to 0.5/sqrt(len(y)).
        name:    polynomial family, passed to ortho_poly.F.

    Returns:
        (solution, sum of combined weights, singular value ratio).  For
        N == 0 the solution has no coefficients and the ratio is None.

    Raises:
        NotEnoughData: fewer than N data points ("y"), or no more than N
            points whose combined weight exceeds 1e-10 ("wt").
        Exception: always, for N > 0, once the input checks pass -- the
            solver call below was marked obsolete by the author.
        BadDataSupport: (disabled code path) singular value ratio below
            svr_lim.
    """
    assert 0 <= N < 100
    tolerance = 1e-8
    die.note('y.shape', y.shape)
    die.note('len(y)', len(y))
    die.note('N', N)
    assert Num.alltrue(Num.absolute(x) <= 1 + tolerance), "Time must be normalized into [-1,1]"
    if len(y) < N:
        raise NotEnoughData("y")

    basis = ortho_poly.F(name, x=x)
    weight = basis.wt() * wt
    root_weight = Num.sqrt(weight)
    if N == 0:
        no_coefs = Num.zeros((0,), Num.Float)
        return (solution(basis, no_coefs), Num.sum(weight), None)

    assert len(weight.shape) == 1
    weighted_y = y * root_weight
    assert len(weighted_y.shape) == 1
    die.note('yw.shape', weighted_y.shape)
    die.note('len(yw)', len(weighted_y))
    if Num.sum(weight > 1e-10) <= N:
        raise NotEnoughData("wt")

    # One weighted basis function per row, transposed so that each basis
    # function becomes a column of the design matrix.
    rows = [basis.P(k) * root_weight for k in range(N)]
    design = Num.transpose(rows)
    assert len(design.shape) == 2

    if svrls is None:
        # If you are fitting a linear equation with support at
        # (0,0)*N, (1,0)*N, (1,1) (i.e. there are N points to determine
        # the first coefficient, and only one point that determines the
        # second coefficient), then the singular values are sqrt(N) and 1,
        # so the limit to the ratio of singular values is set to be a
        # little smaller than that.
        # So, if the rank is limited because of this, the equations must
        # be both somewhat singular and poorly supported.
        svrls = 0.5 / math.sqrt(len(y))

    # NOTE(review): this unconditional raise disables everything below.
    # Its placement at function level (rather than inside the svrls
    # default branch) should be confirmed against the upstream source.
    raise Exception("Next line is obsolete!")

    coefs, weighted_fit, rank, sing_vals = gpk_lsq.linear_least_squares(design, weighted_y, svrls)
    # Gotta be careful here: weighted_fit is the fit to the data multiplied
    # by the weight.  It is *not* the fit to the input data.
    assert len(coefs) == N
    try:
        svr = min(sing_vals) / max(sing_vals)
    except ZeroDivisionError:
        die.warn('Largest singular value is apparently zero!')
        raise

    if svr_lim is None:
        svr_lim = SING_VAL_RATIO_LIM
    if svr < svr_lim:
        raise BadDataSupport('Ortho_fit is nearly degenerate (svr=%g::%g)' % (svr, svr_lim))

    assert len(coefs) == N
    swt = Num.sum(weight)
    mean_sq_err = Num.sum((weighted_fit - weighted_y)**2) / swt
    result = solution(basis, coefs, err=mean_sq_err, rank=rank)
    return (result, swt, svr)
# Default fitting entry point: ortho_fit is an alias for the SVD-based fit.
ortho_fit = ortho_fit_svd
def ortho_fit_reg(x, y, wt, N, regstr, regtgt,
                  rscale=None, name='Chebyshev2'):
    """This uses linear regularization to handle ill-constrained fits.

    Fits y with the first N basis functions P(0)..P(N-1) of the polynomial
    family name, weighting each point by wt times the family's own weight,
    using gpk_lsq.reg_linear_least_squares.

    Arguments:
        x:       sample positions; every |x| must be <= 1 (within 1e-8).
        y:       data values (1-D).
        wt:      per-point weights.
        N:       number of basis functions, 0 <= N < 100.
        regstr, regtgt, rscale:
                 forwarded unchanged to gpk_lsq.reg_linear_least_squares.
        name:    polynomial family, passed to ortho_poly.F.

    Raises:
        NotEnoughData: fewer than N data points ("y"), or no more than N
            points whose combined weight exceeds 1e-10 ("wt").
        Exception: on every path that would otherwise return.  The author
            flagged the N == 0 return (a 3-tuple) and the normal return
            (a 2-tuple) as inconsistent and blocked both, so this function
            currently never returns a value.
    """
    assert 0 <= N < 100
    tolerance = 1e-8
    die.note('y.shape', y.shape)
    die.note('len(y)', len(y))
    die.note('N', N)
    assert Num.alltrue(Num.absolute(x) <= 1 + tolerance), "Time must be normalized into [-1,1]"
    if len(y) < N:
        raise NotEnoughData("y")

    basis = ortho_poly.F(name, x=x)
    weight = basis.wt() * wt
    root_weight = Num.sqrt(weight)
    if N == 0:
        raise Exception("something wrong here -- inconsistent return values")
        # Unreachable: kept to show the intended (3-tuple) result.
        no_coefs = Num.zeros((0,), Num.Float)
        return (solution(basis, no_coefs), Num.sum(weight), None)

    assert len(weight.shape) == 1
    weighted_y = y * root_weight
    assert len(weighted_y.shape) == 1
    die.note('yw.shape', weighted_y.shape)
    die.note('len(yw)', len(weighted_y))
    if Num.sum(weight > 1e-10) <= N:
        raise NotEnoughData("wt")

    # One weighted basis function per row, transposed so that each basis
    # function becomes a column of the design matrix.
    rows = [basis.P(k) * root_weight for k in range(N)]
    design = Num.transpose(rows)
    assert len(design.shape) == 2

    rlss = gpk_lsq.reg_linear_least_squares(design, weighted_y, regstr,
                                            regtgt, rscale=rscale)
    # Gotta be careful here: rlss.fit() is the fit to the data multiplied
    # by the weight.  It is *not* the fit to the input data.
    assert len(rlss.x) == N
    swt = Num.sum(weight)
    rank = rlss.eff_rank()
    mean_sq_err = Num.sum((rlss.fit() - weighted_y)**2) / swt
    result = solution(basis, rlss.x, err=mean_sq_err, rank=rank)
    raise Exception("something wrong here -- inconsistent return values")
    # Unreachable: kept to show the intended (2-tuple) result.
    return (result, swt)
199 200 201
def ortho_err_rms(soln, x, data, wt):
    """Return the weighted mean of the squared residuals of a fit.

    The residual is data - soln.expand(x); each point is weighted by wt
    times the basis weight soln.p.wt().

    Note: despite the "rms" in the name, no square root is taken -- the
    value returned is the weighted mean *square* error.
    """
    weight = soln.p.wt() * wt
    residual = data - soln.expand(x)
    weighted_sq = Num.sum(weight * residual**2)
    return weighted_sq / Num.sum(weight)
206 207
def ortho_err_abs(soln, x, data, wt):
    """Return the weighted mean absolute residual of a fit.

    The residual is data - soln.expand(x); each point is weighted by wt
    times the basis weight soln.p.wt().
    """
    weight = soln.p.wt() * wt
    residual = data - soln.expand(x)
    weighted_abs = Num.sum(weight * Num.absolute(residual))
    return weighted_abs / Num.sum(weight)
212 213 214 215 216 217 218
def eff_perplexity(wt, correlation_length, threshold = 0.2, name='Chebyshev2'):
    """Count the well-supported basis directions for the weights wt.

    Builds the first n-1 basis functions (n = len(wt)), each scaled by the
    square root of wt times the basis weight, forms their Gram matrix and
    takes its eigenvalues.  The result is the number of eigenvalues larger
    than threshold**2 times the largest eigenvalue, divided by
    hypot(correlation_length, 1.0).

    wt must be a 1-D array.
    """
    assert len(wt.shape) == 1
    n = wt.shape[0]
    assert wt.shape == (n,)
    # NOTE(review): n is passed positionally here, whereas other calls in
    # this module pass x= and n= by keyword -- confirm that the second
    # positional parameter of ortho_poly.F is the intended one.
    basis = ortho_poly.F(name, n)
    root_weight = Num.sqrt(wt * basis.wt())
    rows = Num.array([basis.P(k) * root_weight for k in range(n - 1)],
                     Num.Float)
    gram = Num.matrixmultiply(rows, Num.transpose(rows))
    assert gram.shape == (n-1,n-1), "xx.shape=%s, n=%d" % (str(gram.shape), n)
    eigvals = LinearAlgebra.Heigenvalues(gram)
    largest = eigvals[Num.argmax(eigvals)]
    significant = Num.sum(eigvals > threshold**2 * largest)
    return significant / math.hypot(correlation_length, 1.0)
237 238 239
def test_r():
    """Self-test: two-term Legendre fit to mirror-symmetric random data.

    Weights and data are each built as a random half followed by its
    reverse, and y is rescaled to a weighted mean of 1, so the fit should
    return a constant term of 1 and a linear term of 0.
    """
    import RandomArray
    n_terms = 2
    half = 10
    half_wt = RandomArray.exponential(1.0, (half,))
    wt = Num.concatenate((half_wt, half_wt[::-1]))
    half_y = RandomArray.exponential(1.0, (half,))
    y = Num.concatenate((half_y, half_y[::-1]))
    weighted_mean = Num.sum(y * wt) / Num.sum(wt)
    y /= weighted_mean
    # y is constructed to have a mean of 1 and a slope of zero.
    x = (Num.arrayrange(2 * half) - half + 0.5) / float(2 * half)
    soln, swt, svr = ortho_fit(x, y, wt, n_terms, svr_lim=None, svrls=None,
                               name='Legendre')
    assert 0.5 < svr <= 1.0
    assert abs(swt - 2 * half) < 5 * math.sqrt(2 * half)
    assert abs(soln.c[0] - 1) < 1e-6
    assert abs(soln.c[1]) < 1e-6
258 259
def test_expand():
    """Self-test: recover known Legendre coefficients and re-expand them.

    y is built as P0 + P1 - P2 - P3, fitted with five terms under random
    weights, and both soln.fit() and soln.expand(x) are compared with y.
    """
    import RandomArray
    n_terms = 5
    half = 60
    x = (Num.arrayrange(2 * half) - half + 0.5) / float(2 * half)
    legendre = ortho_poly.Legendre(x=x)
    y = legendre.P(0) + legendre.P(1) - legendre.P(2) - legendre.P(3)
    wt = RandomArray.exponential(1.0, (2 * half,))
    soln, swt, svr = ortho_fit(x, y, wt, n_terms, svr_lim=None, svrls=None,
                               name='Legendre')
    expected = [1, 1, -1, -1, 0]
    assert Num.alltrue(Num.absolute(soln.c - expected) < 1e-6)
    assert 0.5 < svr <= 1.0
    assert abs(swt - Num.sum(wt)) < 1e-6
    fit_residual = soln.fit() - y
    assert Num.sum(fit_residual**2) < 0.001
    expand_residual = soln.expand(x) - y
    assert Num.sum(expand_residual**2) < 0.001
if __name__ == '__main__':
    # Run the module's self-tests when executed as a script.
    test_r()
    test_expand()