[frames] | [no frames]

# Source Code for Module ortho_fit

```  1  import math
2
3  import LinearAlgebra
4
5  import ortho_poly
6
7  from gmisclib import Num
8  from gmisclib import die
9  from gmisclib import gpk_lsq
10
11
class NotEnoughData(RuntimeError):
    """Raised when too few usable data points are available for a fit.

    The fitting routines raise it with the message "y" when there are fewer
    samples than requested basis polynomials, and with "wt" when too few
    samples carry a non-negligible weight.
    """

    def __init__(self, x=''):
        super(NotEnoughData, self).__init__(x)
15
16
# Default lower bound on (smallest / largest) singular value accepted by
# ortho_fit_svd(); fits whose ratio falls below it are rejected with
# BadDataSupport as nearly degenerate.
SING_VAL_RATIO_LIM = 1e-6


class BadDataSupport(RuntimeError):
    """Raised when the data give poor support for the requested fit.

    The message is always prefixed with "Support for fit is poor: ".
    """

    def __init__(self, s=""):
        RuntimeError.__init__(self, "Support for fit is poor: %s" % s)
23
24
class solution:
    """Result of an orthogonal-polynomial fit.

    Attributes:
        p: the polynomial-family object the fit was made with (the fitting
           routines in this module pass an ortho_poly.F instance).
        c: fitted coefficients, one per basis polynomial.
        e: error value supplied by the fitting routine, or None.
        rank: effective rank supplied by the fitting routine, or None.
    """

    def __init__(self, polys, coefs, err=None, rank=None, fit=None):
        self.p, self.c, self.e, self.rank = polys, coefs, err, rank
        # Cached fitted curve; filled in lazily by fit() when not supplied.
        self._fit = None if fit is None else Num.array(fit)

    def fit(self):
        """Return the fitted curve, computing it via p.expand(c) once and caching it."""
        if self._fit is None:
            self._fit = self.p.expand(self.c)
        return self._fit

    def expand(self, time):
        """Return sum_i c[i] * p.P(i), accumulated into a float array shaped like *time*.

        NOTE(review): *time* only supplies the output shape; the basis values
        come from self.p.P(i) (wherever the polynomial object was evaluated),
        not from *time* itself.  Confirm this is intended.
        """
        total = Num.zeros(time.shape, Num.Float)
        for idx, coef in enumerate(self.c):
            total += self.p.P(idx) * coef
        return total
49
50
def perplexity(x):
    """Return the perplexity exp(H) of the distribution proportional to *x*.

    *x* (assumed 1-D and non-negative) is normalized to sum to one, and
    H = -sum(p * log(p)) is its entropy in nats.  Zero entries contribute
    nothing (0 * log 0 is taken as 0).  Input that sums to zero is not
    supported.
    """
    # float() forces true division: under Python 2 / Numeric an integer
    # array divided by an integer sum truncates every entry to zero.
    p = x / float(Num.sum(x))
    # Drop exact zeros; otherwise p * log(p) evaluates 0 * (-inf) = nan.
    p = Num.compress(p != 0, p)
    return math.exp(-Num.sum(p * Num.log(p)))
54
55
56
57
58
def poly(x=None, n=None, name='Chebyshev2'):
    """Construct the ortho_poly family called *name*.

    *x* and *n* are forwarded unchanged, as keyword arguments, to
    ortho_poly.F.
    """
    family = ortho_poly.F(name, n=n, x=x)
    return family
61
62
def ortho_fit_svd(x, y, wt, N, svr_lim=None, svrls=None, name='Chebyshev2'):
    """This uses a SVD fit to handle ill-constrained fits.

    Fit y(x) with the first N polynomials of the ortho_poly family *name*
    by weighted linear least squares with a singular-value cutoff.

    Parameters:
        x: sample positions, already normalized into [-1, 1].
        y: 1-D data values at *x*.
        wt: per-sample weights; multiplied by the family's own weight
            function (op.wt()) to form the fitting weights.
        N: number of basis polynomials, 0 <= N < 100.
        svr_lim: minimum acceptable ratio of smallest to largest singular
            value; defaults to SING_VAL_RATIO_LIM.
        svrls: singular-value ratio cutoff handed to the least-squares
            solver; defaults to 0.5/sqrt(len(y)).
        name: name of the ortho_poly family.

    Returns:
        (soln, swt, svr): a solution object, the sum of the fitting weights,
        and the singular-value ratio (None when N == 0).

    Raises:
        NotEnoughData: len(y) < N, or no more than N samples have a fitting
            weight above 1e-10.
        Exception: for N > 0, once the input checks pass, this currently
            always raises "Next line is obsolete!".  The least-squares call
            it guards is marked obsolete, so the rest of the function is
            disabled.  Once re-enabled, it may also raise ZeroDivisionError
            (largest singular value is zero) or BadDataSupport (singular-value
            ratio below *svr_lim*).
    """
    assert N >= 0
    assert N < 100
    EPS = 1e-8
    die.note('y.shape', y.shape)
    die.note('len(y)', len(y))
    die.note('N', N)
    assert Num.alltrue(Num.absolute(x) <= 1 + EPS), "Time must be normalized into [-1,1]"
    if len(y) < N:
        raise NotEnoughData("y")
    op = ortho_poly.F(name, x=x)
    # Fitting weight = the family's weight function times the caller's weights.
    wwt = op.wt() * wt
    rwwt = Num.sqrt(wwt)
    if N == 0:
        return (solution(op, Num.zeros((0,), Num.Float)), Num.sum(wwt), None)

    assert len(wwt.shape) == 1
    # Scaling each row by sqrt(weight) makes ordinary least squares minimize
    # the weighted sum of squared residuals.
    yw = y * rwwt
    assert len(yw.shape) == 1
    die.note('yw.shape', yw.shape)
    die.note('len(yw)', len(yw))
    if Num.sum(wwt > 1e-10) <= N:
        raise NotEnoughData("wt")
    basisw = Num.transpose([op.P(i) * rwwt for i in range(N)])
    assert len(basisw.shape) == 2
    if svrls is None:
        # Consider fitting a line with support at (0,0) and (1,0), N times
        # each, plus a single point at (1,1).  N points determine the first
        # coefficient but only one point determines the second, so the
        # singular values are sqrt(N) and 1.  The ratio limit is therefore
        # set a little below 1/sqrt(len(y)).  If the rank gets limited by
        # this cutoff, the equations must be both somewhat singular and
        # poorly supported.
        svrls = 0.5 / math.sqrt(len(y))
    # TODO: gpk_lsq.linear_least_squares is obsolete; the fit stays disabled
    # until it is ported to a current least-squares routine.  Everything
    # below this raise is currently unreachable.
    raise Exception("Next line is obsolete!")
    soln, fitw, rank, q = gpk_lsq.linear_least_squares(basisw, yw, svrls)
    # Careful: fitw is the fit to the data multiplied by sqrt(weight).
    # It is *not* the fit to the input data.

    assert len(soln) == N
    qmax = max(q)
    if qmax == 0:
        # Checked explicitly: with array-scalar singular values, min(q)/max(q)
        # produces inf/nan instead of raising.  A nan ratio would then slip
        # silently past the svr_lim test below.
        die.warn('Largest singular value is apparently zero!')
        raise ZeroDivisionError('largest singular value is zero')
    svr = min(q) / qmax

    if svr_lim is None:
        svr_lim = SING_VAL_RATIO_LIM
    if svr < svr_lim:
        raise BadDataSupport('Ortho_fit is nearly degenerate (svr=%g::%g)' % (svr, svr_lim))

    swt = Num.sum(wwt)
    soln = solution(op, soln, err=Num.sum((fitw - yw)**2) / swt, rank=rank)
    return (soln, swt, svr)


# Default fitting entry point for this module.
ortho_fit = ortho_fit_svd
145
146
def ortho_fit_reg(x, y, wt, N, regstr, regtgt,
                  rscale=None, name='Chebyshev2'):
    """This uses linear regularization to handle ill-constrained fits.

    Fit y(x) with the first N polynomials of the ortho_poly family *name*,
    solving the weighted problem with gpk_lsq.reg_linear_least_squares
    (regularization strength *regstr*, target *regtgt*, scale *rscale*).

    NOTE: currently disabled.  Every path that passes the input checks raises
    Exception("something wrong here -- inconsistent return values").  The
    N == 0 path used to return a 3-tuple, while the main path returns a
    2-tuple (and ortho_fit_svd returns a 3-tuple).

    Raises:
        NotEnoughData: len(y) < N, or no more than N samples have a fitting
            weight above 1e-10.
    """
    assert 0 <= N < 100
    tol = 1e-8
    for label, value in (('y.shape', y.shape), ('len(y)', len(y)), ('N', N)):
        die.note(label, value)
    assert Num.alltrue(Num.absolute(x) <= 1 + tol), "Time must be normalized into [-1,1]"
    if N > len(y):
        raise NotEnoughData("y")
    op = ortho_poly.F(name, x=x)
    wwt = op.wt() * wt
    rwwt = Num.sqrt(wwt)
    if N == 0:
        # Formerly returned (solution(op, zeros), Num.sum(wwt), None).
        raise Exception("something wrong here -- inconsistent return values")

    assert len(wwt.shape) == 1
    yw = y * rwwt
    assert len(yw.shape) == 1
    die.note('yw.shape', yw.shape)
    die.note('len(yw)', len(yw))
    if Num.sum(wwt > 1e-10) <= N:
        raise NotEnoughData("wt")
    basisw = Num.transpose([op.P(k) * rwwt for k in range(N)])
    assert len(basisw.shape) == 2
    rlss = gpk_lsq.reg_linear_least_squares(basisw, yw, regstr, regtgt,
                                            rscale=rscale)
    # rlss.fit() is the fit to the sqrt(weight)-scaled data, *not* the fit
    # to the input data.
    assert len(rlss.x) == N
    swt = Num.sum(wwt)
    rank = rlss.eff_rank()
    resid = rlss.fit() - yw
    soln = solution(op, rlss.x, err=Num.sum(resid**2) / swt, rank=rank)
    raise Exception("something wrong here -- inconsistent return values")
    # Unreachable; kept to record the intended return value.
    return (soln, swt)
199
200
201
def ortho_err_rms(soln, x, data, wt):
    """Return the weighted mean squared residual of *soln* against *data*.

    Weights are soln.p.wt() times *wt*; residuals are data - soln.expand(x).
    NOTE(review): despite the name, no square root is taken, so this is a
    mean square, not a root-mean-square.  Confirm that callers expect this.
    """
    weights = soln.p.wt() * wt
    resid = data - soln.expand(x)
    numer = Num.sum(weights * resid ** 2)
    return numer / Num.sum(weights)
206
207
def ortho_err_abs(soln, x, data, wt):
    """Return the weighted mean absolute residual of *soln* against *data*.

    Weights are soln.p.wt() times *wt*; residuals are data - soln.expand(x).
    """
    weights = soln.p.wt() * wt
    resid = data - soln.expand(x)
    numer = Num.sum(weights * Num.absolute(resid))
    return numer / Num.sum(weights)
212
213
214
215
216
217
218
def eff_perplexity(wt, correlation_length, threshold=0.2, name='Chebyshev2'):
    """Count well-supported basis directions under the weighting *wt*.

    Evaluates the first n-1 polynomials of family *name* (n = len(wt)), each
    scaled by sqrt(wt * family weight), and forms their Gram matrix.  It
    counts the eigenvalues that exceed threshold**2 times the largest one;
    equivalently, the singular values of the scaled basis above *threshold*
    times the largest.  That count is divided by hypot(correlation_length, 1).

    NOTE(review): ortho_poly.F receives n positionally here, whereas poly()
    passes x= and n= by keyword; confirm that F's second positional
    parameter is n.
    """
    assert len(wt.shape) == 1
    npts = wt.shape[0]
    assert wt.shape == (npts,)
    family = ortho_poly.F(name, npts)
    scale = Num.sqrt(wt * family.wt())
    rows = Num.array([family.P(k) * scale for k in range(npts - 1)], Num.Float)
    gram = Num.matrixmultiply(rows, Num.transpose(rows))
    assert gram.shape == (npts - 1, npts - 1), "xx.shape=%s, n=%d" % (str(gram.shape), npts)
    eigvals = LinearAlgebra.Heigenvalues(gram)
    biggest = eigvals[Num.argmax(eigvals)]
    n_big = Num.sum(eigvals > threshold**2 * biggest)
    return n_big / math.hypot(correlation_length, 1.0)
237
238
239
def test_r():
    """Fit a mirror-symmetric random series (unit weighted mean, zero slope)
    with a 2-term Legendre fit and check that the coefficients come out as
    (1, 0).

    NOTE(review): ortho_fit_svd currently raises before fitting (its
    least-squares call is marked obsolete), so this test cannot pass until
    that is resolved.
    """
    import RandomArray
    N, M = 2, 10
    half_wt = RandomArray.exponential(1.0, (M,))
    wt = Num.concatenate((half_wt, half_wt[::-1]))
    half_y = RandomArray.exponential(1.0, (M,))
    y = Num.concatenate((half_y, half_y[::-1]))
    # Rescale so that the weighted mean of y is exactly 1; the mirror
    # symmetry of y and wt makes the fitted slope zero.
    y /= Num.sum(y * wt) / Num.sum(wt)
    npts = 2 * M
    x = (Num.arrayrange(npts) - M + 0.5) / float(npts)
    soln, swt, svr = ortho_fit(x, y, wt, N, svr_lim=None, svrls=None, name='Legendre')
    assert 0.5 < svr <= 1.0
    assert abs(swt - npts) < 5 * math.sqrt(npts)
    assert abs(soln.c[0] - 1) < 1e-6
    assert abs(soln.c[1]) < 1e-6
258
259
def test_expand():
    """Fit the exact combination P0 + P1 - P2 - P3 of Legendre polynomials
    with N = 5 terms under random weights.  Check the recovered coefficients
    and the weight sum, and that both fit() and expand() reproduce y.

    NOTE(review): ortho_fit_svd currently raises before fitting (its
    least-squares call is marked obsolete), so this test cannot pass until
    that is resolved.
    """
    import RandomArray
    N, M = 5, 60
    npts = 2 * M
    x = (Num.arrayrange(npts) - M + 0.5) / float(npts)
    op = ortho_poly.Legendre(x=x)
    y = op.P(0) + op.P(1) - op.P(2) - op.P(3)
    wt = RandomArray.exponential(1.0, (npts,))
    soln, swt, svr = ortho_fit(x, y, wt, N, svr_lim=None, svrls=None, name='Legendre')
    expected = [1, 1, -1, -1, 0]
    assert Num.alltrue(Num.absolute(soln.c - expected) < 1e-6)
    assert 0.5 < svr <= 1.0
    assert abs(swt - Num.sum(wt)) < 1e-6
    assert Num.sum((soln.fit() - y)**2) < 0.001
    assert Num.sum((soln.expand(x) - y)**2) < 0.001
276
277
if __name__ == '__main__':
    # Run the module's self-tests, in order, when executed as a script.
    for _check in (test_r, test_expand):
        _check()
281
<!--
expandto(location.href);
// -->

```

 Generated by Epydoc 3.0.1 on Tue Aug 25 01:53:13 2009 http://epydoc.sourceforge.net