1 """V=WH, where W and H are non-negative.
2
From D. D. Lee and H. S. Seung, 'Learning the parts of objects by
non-negative matrix factorization', Nature 401, 788-791 (1999).
5 """
6
7 import Num
8 import math
9 import die
10
11
12
# Convergence tolerance for _converged (scaled there by a per-iteration
# fudge factor).  Also the relative offset, in units of the mean element
# of W H, added to the W H denominators in the update rules.
EPS = 1e-6
# nmf stops only after the convergence test has passed on MORE than CC
# consecutive iterations (it breaks when the count exceeds CC).
CC = 2
# Added to the squared standard normals in _initialize so the random
# starting values are bounded away from zero (before H's column scaling).
IEPS = 0.03
16
def _norm(x):
    """Return the Frobenius norm of the array ``x``: sqrt(sum of x**2).

    The sum uses ``axis=None``, so it runs over every element whatever the
    rank of ``x``.

    Raises:
        OverflowError: re-raised, after printing a diagnostic that shows
            the offending array, if the array library raises it while
            squaring or summing.
    """
    try:
        return math.sqrt(Num.sum(x**2, axis=None))
    except OverflowError as err:
        # The old ``except OverflowError, x`` rebound ``x`` to the exception
        # object, so this diagnostic never showed the array itself.
        print("x= %s (%s)" % (x, err))
        raise
24
25
def _converged(wh, whold, v, fudge):
    """Return True when the factorization passes the convergence test.

    Two measures are taken relative to ||v||:
        d1 = ||wh - whold|| / ||v||   (change of W H over the last step)
        d2 = ||v - wh|| / ||v||       (relative residual of the fit)
    The test passes when min(d1*F, d2) < EPS*fudge, where F is the square
    root of the number of elements of ``wh``.  ``nmf`` passes
    fudge = sqrt(iteration / rank), so the threshold loosens over time.

    Both measures are logged through ``die.dbg``.  Raises
    ZeroDivisionError if ``v`` is all zero (||v|| == 0).
    """
    v_norm = _norm(v)  # computed once; it is the denominator of both ratios
    size_scale = math.sqrt(wh.shape[0] * wh.shape[1])
    d1 = _norm(wh - whold) / v_norm
    d2 = _norm(v - wh) / v_norm
    die.dbg("Converged= %f %f" % (d1, d2))
    return min(d1 * size_scale, d2) < EPS * fudge
32
33
35
def _updateH(w, h, wh, v):
    """Return the multiplicative update of ``h`` for the model v ~ w h.

    Each h[a,mu] is multiplied by sum_i w[i,a] * v[i,mu] / (wh[i,mu] + eps).
    ``eps`` is EPS times the mean element of ``wh``.  It keeps the ratio
    finite at zero entries of ``wh`` (as long as ``wh`` is not all zero).

    No sum_i w[i,a] denominator is applied.  That sum is 1 when the
    columns of ``w`` sum to one, which _updateW's column normalization
    arranges.

    ``wh`` is expected to be matrixmultiply(w, h) for the same ``w`` and
    ``h`` (that is what nmf passes).
    """
    n, m = v.shape
    eps = EPS * Num.sum(wh, axis=None) / (n * m)
    ratio = v / (wh + eps)
    return h * Num.matrixmultiply(Num.transpose(w), ratio)
41
42
def _updateW(w, h, wh, v):
    """Return the multiplicative update of ``w``, with columns normalized.

    Each w[i,a] is multiplied by sum_mu v[i,mu] / (wh[i,mu] + eps) * h[a,mu].
    Every column of the result is then divided by its sum, so the returned
    columns sum to one.  ``eps`` is EPS times the mean element of ``wh``.
    It keeps the ratio finite at zero entries of ``wh`` (as long as ``wh``
    is not all zero).

    ``wh`` is expected to be matrixmultiply(w, h) for the same ``w`` and
    ``h`` (that is what nmf passes).
    """
    n, m = v.shape
    eps = EPS * Num.sum(wh, axis=None) / (n * m)
    ratio = v / (wh + eps)
    new_w = w * Num.matrixmultiply(ratio, Num.transpose(h))
    col_sums = Num.sum(new_w, axis=0)
    # A column whose sum is zero (e.g. the matching row of h is all zero)
    # would give 0/0 = NaN and poison every later product.  Dividing it by
    # 1 instead leaves that column at zero.
    col_sums = Num.where(Num.greater(col_sums, 0.0), col_sums, 1.0)
    return new_w / col_sums[Num.NewAxis, :]
59
60
def _initialize(v, rank):
    """Return random non-negative starting matrices ``(w, h)`` for ``v``.

    For an (n, m) ``v``, ``w`` is (n, rank) and ``h`` is (rank, m).

    ``w`` holds squared standard normals plus IEPS, so every entry is at
    least IEPS.  ``h`` holds the same kind of values, but column mu is
    scaled by sum_i v[i,mu]**2 / (n*m).  As a result, an all-zero column
    of ``v`` starts with an all-zero column of ``h``.
    """
    n, m = v.shape
    w = Num.RA.standard_normal((n, rank))**2 + IEPS
    col_scale = Num.sum(v**2, axis=0) / (n * m)
    # col_scale has shape (m,), so it broadcasts across the columns of h.
    h = col_scale * (Num.RA.standard_normal((rank, m))**2 + IEPS)
    return (w, h)
67
68
def nmf(v, rank, maxiter=None):
    """Factor the non-negative matrix ``v`` as ``w`` times ``h``.

    Iterates the Lee-Seung multiplicative updates (_updateW / _updateH)
    from a random start until _converged has held on more than CC
    consecutive iterations.

    Parameters:
        v: non-negative n x m matrix (anything ``Num.asarray`` accepts).
        rank: number of factors; must be positive.
        maxiter: optional cap on the number of update steps.  ``None``
            (the default) keeps iterating until the convergence test is
            satisfied.

    Returns:
        ``(w, h, err)``: ``w`` is (n, rank), ``h`` is (rank, m), and ``err``
        is the Frobenius norm of ``v - w h``.

    Raises:
        AssertionError: if ``rank`` is not positive or ``v`` has a negative
            element.  These checks use ``assert``, so they are skipped
            under ``python -O``.
    """
    assert rank > 0, "Zero rank approximations are usually pretty awful."
    v = Num.asarray(v, Num.Float)
    assert Num.alltrue(Num.greater_equal(Num.ravel(v), 0.0)), "Negative element!"
    w, h = _initialize(v, rank)
    if Num.alltrue(Num.equal(Num.ravel(v), 0.0)):
        # An all-zero v is fitted exactly by h = 0.  Iterating would divide
        # by ||v|| == 0 in the convergence test.
        return (w, Num.zeros(h.shape, Num.Float), 0.0)
    n_passed = 0    # consecutive iterations that passed the convergence test
    iteration = 0   # number of update steps taken so far
    wh = Num.zeros(v.shape, Num.Float)
    while True:
        whold = wh
        wh = Num.matrixmultiply(w, h)
        # The tolerance is loosened by sqrt(iteration / rank) as we go.
        if _converged(wh, whold, v, math.sqrt(iteration / float(rank))):
            n_passed += 1
        else:
            n_passed = 0
        if n_passed > CC:
            break
        if maxiter is not None and iteration >= maxiter:
            break
        # Both updates must see the same (old) w, h and wh, so compute both
        # before rebinding either.
        w, h = _updateW(w, h, wh, v), _updateH(w, h, wh, v)
        iteration += 1
    # Every exit from the loop happens right after wh = w h was computed,
    # so wh already matches the returned factors.
    return (w, h, _norm(v - wh))
95
96
97
98
def _check_fit(a, rank, tol):
    """Factor ``a`` at ``rank`` and assert that the product reproduces it.

    Asserts that ||W H - a|| < ``tol`` and that the error reported by
    ``nmf`` is below 0.001.
    """
    w, h, err = nmf(a, rank)
    wh = Num.matrixmultiply(w, h)
    assert _norm(wh - a) < tol
    assert err < 0.001


def _test1():
    """A rank-1 matrix must be reproduced by a rank-1 factorization."""
    _check_fit([[1, 0], [1, 0], [0, 0]], 1, 30*EPS)


def _test2():
    """A 3x2 matrix with two independent rows must be fitted at rank 2."""
    _check_fit([[1, 0], [0, 1], [0, 0]], 2, 30*EPS)


def _test3(rank):
    """Fit a 3x3 matrix whose third row is the sum of the first two.

    The matrix has rank 2, so the ``__main__`` driver expects ``rank=1`` to
    fail and ranks 2..5 to pass.  The looser tolerance 30*sqrt(EPS) is used.
    """
    a = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [1.0, 1.0, 1.0]]
    _check_fit(a, rank, 30*math.sqrt(EPS))
120
121
if __name__ == '__main__':
    # Run the self-tests in order: a banner before each test and a blank
    # line after it.  Every print takes one pre-formatted string, so the
    # output is the same under Python 2 and Python 3.
    print("TEST1")
    _test1()
    print("")
    print("TEST2")
    _test2()
    print("")
    print("TEST3(1)")
    # A rank-1 fit of the rank-2 test matrix must fail its assertions.
    try:
        _test3(1)
    except AssertionError:
        pass
    else:
        raise AssertionError("Test3(1) should fail!")
    print("")
    for test_rank in (2, 3, 4, 5):
        print("TEST3(%d)" % test_rank)
        _test3(test_rank)
        print("")
    print("LAST")
    # Over-complete case: rank 4 for a 2x2 identity.  Results are printed
    # for inspection only; nothing is asserted.
    a = [[1, 0], [0, 1]]
    w, h, err = nmf(a, 4)
    print("w= %s" % (w,))
    print("h= %s" % (h,))
    print("wh= %s" % (Num.matrixmultiply(w, h),))
    print("a= %s" % (a,))
157