[frames | no frames]

# Source Code for Module gmisclib.nmf

```  1  """V=WH, where W and H are non-negative.
2
3  From D. D. Lee and H. S. Seung, 'Learning the parts of objects by
4  nonnegative matrix factorization.'
5  """
6
7  import Num
8  import math
9  import die
10
11
12
13  EPS = 1e-6  # base tolerance: scales the convergence threshold and the small term added to W*H before dividing
14  CC = 2  # the convergence test must pass on more than CC consecutive iterations before nmf() stops
15  IEPS = 0.03  # offset added to the squared random draws so the initial W and H entries are strictly positive
16
def _norm(x):
        """Return the square root of the sum of the squared entries of x.

        The sum is taken with C{axis=None} (presumably over every element,
        as in numpy -- confirm against the C{Num} wrapper).

        @param x: an array supporting C{x**2} and C{Num.sum}.
        @return: C{math.sqrt} of the summed squares.
        @raise OverflowError: re-raised unchanged after the offending
                input has been printed.
        """
        try:
                return math.sqrt(Num.sum(x**2, axis=None))
        except OverflowError:
                # The exception is deliberately not bound to a name.  Binding
                # it to "x" would replace the input, and the diagnostic below
                # would show the exception text instead of the data that
                # overflowed.
                print("x= %s" % (x,))
                raise
24
25
def _converged(wh, whold, v, fudge):
        """Report whether the current approximation passes the stopping test.

        Two quantities are computed, both relative to the norm of v: how far
        the product moved since the previous iteration, and how far it still
        is from v.  The first is scaled by the square root of the number of
        entries of wh.  The test passes when the smaller of the two is below
        C{EPS*fudge}.

        @param wh: the current product of the factors (two-dimensional).
        @param whold: the product from the previous iteration.
        @param v: the matrix being approximated.
        @param fudge: multiplier on C{EPS}; larger values loosen the test.
        @return: True if the test passes.
        """
        size_factor = math.sqrt(wh.shape[0] * wh.shape[1])
        step = _norm(wh - whold)
        v_norm = _norm(v)
        rel_step = step / v_norm
        rel_residual = _norm(v - wh) / v_norm
        die.dbg("Converged= %f %f" % (rel_step, rel_residual))
        threshold = EPS * fudge
        return min(rel_step * size_factor, rel_residual) < threshold
32
33
def _updateH(w, h, wh, v):
        """Return the multiplicatively updated H factor.

        Each entry of h is multiplied by the matching entry of
        C{transpose(w) * (v / (wh + eps))}, where eps is C{EPS} times the
        mean entry of wh and keeps the element-wise division away from zero.

        @param w: current W factor.
        @param h: current H factor.
        @param wh: the product of w and h, supplied by the caller.
        @param v: the matrix being approximated.
        @return: the new H; the inputs are not modified.
        """
        n_entries = v.shape[0] * v.shape[1]
        eps = EPS * Num.sum(wh, axis=None) / n_entries
        ratio = v / (wh + eps)
        factor = Num.matrixmultiply(Num.transpose(w), ratio)
        return h * factor
41
42
def _updateW(w, h, wh, v):
        """Return the multiplicatively updated and renormalized W factor.

        Each entry of w is multiplied by the matching entry of
        C{(v / (wh + eps)) * transpose(h)}, where eps is C{EPS} times the
        mean entry of wh.  The result is then divided by its sums along the
        first index, so every column of the returned matrix sums to one.

        No diagnostics are written to stdout; this function is called on
        every iteration of the factorization loop.

        @param w: current W factor.
        @param h: current H factor.
        @param wh: the product of w and h, supplied by the caller.
        @param v: the matrix being approximated.
        @return: the new W; the inputs are not modified.
        """
        eps = EPS * Num.sum(wh, axis=None)/(v.shape[0]*v.shape[1])
        f = Num.matrixmultiply(v/(wh+eps), Num.transpose(h))
        new_w = w * f
        # Sum over the first index gives one total per column of new_w.
        first_index_sum = Num.sum(new_w, axis=0)
        # NOTE(review): a column of new_w that sums to zero would make this a
        # division by zero; the original code has no guard either.
        return new_w/first_index_sum[Num.NewAxis, :]
59
60
def _initialize(v, rank):
        """Build random, strictly positive starting factors for v.

        Both factors start as squared standard-normal draws plus C{IEPS}.
        Each column of H is additionally scaled by the sum of squares of
        the matching column of v divided by the number of entries of v.

        @param v: two-dimensional matrix to be factored.
        @param rank: number of columns of W and rows of H.
        @return: the tuple C{(w, h)}.
        """
        n_rows, n_cols = v.shape
        w0 = Num.RA.standard_normal((n_rows, rank))**2 + IEPS
        column_scale = Num.sum(v**2, axis=0) / (n_rows * n_cols)
        h_noise = Num.RA.standard_normal((rank, n_cols))**2 + IEPS
        h0 = column_scale * h_noise
        return (w0, h0)
67
68
def nmf(v, rank):
        """Approximate the non-negative matrix v by a product W*H.

        Starting from random positive factors, W and H are updated together
        (both updates see the factors from the previous iteration) until the
        convergence test has passed on more than C{CC} consecutive
        iterations.  The tolerance handed to the test grows as the square
        root of C{iterations/rank}, so it loosens as the loop runs.

        @param v: two-dimensional array-like with no negative entries.
        @param rank: number of components; must be positive.
        @return: the tuple C{(w, h, err)}, where err is the norm of
                C{v - w*h}.
        @raise AssertionError: if rank is not positive or v has a negative
                entry.
        """
        assert rank > 0, "Zero rank approximations are usually pretty awful."
        v = Num.asarray(v, Num.Float)
        assert Num.alltrue(Num.greater_equal(Num.ravel(v), 0.0)), "Negative element!"
        w, h = _initialize(v, rank)
        streak = 0      # consecutive iterations on which the test passed
        n_iter = 0      # number of updates applied so far
        wh = Num.zeros(v.shape, Num.Float)
        while True:
                previous, wh = wh, Num.matrixmultiply(w, h)
                fudge = math.sqrt(n_iter / float(rank))
                if _converged(wh, previous, v, fudge):
                        streak += 1
                else:
                        streak = 0
                if streak > CC:
                        break
                # Both right-hand sides are evaluated with the old w and h
                # before either name is rebound.
                w, h = _updateW(w, h, wh, v), _updateH(w, h, wh, v)
                n_iter += 1

        residual = v - Num.matrixmultiply(w, h)
        return (w, h, _norm(residual))
95
96
97
98
def _test1():
        """Check a one-component factorization of a 3x2 matrix with one non-zero column."""
        target = [[1, 0], [1, 0], [0, 0]]
        w, h, err = nmf(target, 1)
        product = Num.matrixmultiply(w, h)
        mismatch = _norm(product - target)
        assert mismatch < 30*EPS
        assert err<0.001
105
def _test2():
        """Check a two-component factorization of a 3x2 matrix with two unit rows."""
        target = [[1, 0], [0, 1], [0, 0]]
        w, h, err = nmf(target, 2)
        product = Num.matrixmultiply(w, h)
        mismatch = _norm(product - target)
        assert mismatch < 30*EPS
        assert err<0.001
112
113
def _test3( rank ):
        """Factor a fixed 3x3 matrix with the requested rank and check the fit.

        @param rank: number of components passed to C{nmf}.
        @raise AssertionError: if the reconstruction is not close enough.
        """
        target = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [1.0, 1.0, 1.0]]
        w, h, err = nmf(target, rank)
        product = Num.matrixmultiply(w, h)
        mismatch = _norm(product - target)
        assert mismatch < 30*math.sqrt(EPS)
        assert err<0.001
120
121
if __name__ == '__main__':
        # Self-test driver: run each check in turn, labelling it on stdout.
        print("TEST1")
        _test1()
        print("")
        print("TEST2")
        _test2()
        print("")
        print("TEST3(1)")
        # A single component cannot reproduce the 3x3 test matrix, so the
        # check is expected to fail here.
        try:
                _test3(1)
        except AssertionError:
                pass
        else:
                raise AssertionError("Test3(1) should fail!")
        print("")
        for rank in (2, 3, 4, 5):
                print("TEST3(%d)" % rank)
                _test3(rank)
                # No blank line is emitted between the rank-4 and rank-5 runs.
                if rank != 4:
                        print("")
        print("LAST")

        # Over-complete case: four components for a 2x2 identity matrix.
        a = [[1, 0], [0, 1]]
        w, h, err = nmf(a, 4)
        print("w= %s" % (w,))
        print("h= %s" % (h,))
        print("wh= %s" % (Num.matrixmultiply(w, h),))
        print("a= %s" % (a,))
157
<!--
expandto(location.href);
// -->

```

 Generated by Epydoc 3.0.1 on Thu Sep 22 04:25:12 2011 http://epydoc.sourceforge.net