Package gmisclib :: Module nmf
[frames] | no frames]

Source Code for Module gmisclib.nmf

"""Non-negative matrix factorization: V = WH, where W and H are non-negative.

Implements the multiplicative update rules from D. D. Lee and H. S. Seung,
'Learning the parts of objects by non-negative matrix factorization,'
Nature 401, 788-791 (1999).
"""
  6   
  7  import Num 
  8  import math 
  9  import die 
 10   
 11   
 12   
# Convergence tolerance: the iteration in nmf() stops once the relative
# step/residual computed by _converged() drops below EPS times a fudge factor.
# EPS also scales the small offset added to W*H in _updateW/_updateH so the
# ratio v/(wh+eps) never divides by zero.
EPS = 1e-6
# nmf() stops only after the convergence test has passed more than CC times
# in a row.
CC = 2
# Offset added to the squared normal deviates in _initialize(), so every
# initial element of W (and of H, for columns of v that are not all zero)
# is strictly positive.
IEPS = 0.03
 16   
def _norm(x):
    """Return the Frobenius (Euclidean) norm of array ``x``: sqrt(sum(x**2)).

    If the sum cannot be converted by math.sqrt (OverflowError), the
    offending input array is printed before the exception is re-raised.
    """
    try:
        tmp = math.sqrt(Num.sum(x**2, axis=None))
    except OverflowError:
        # Do not bind the exception to ``x``: the diagnostic must show the
        # input array, not the exception object.
        print("x= %s" % (x,))
        raise
    return tmp

def _converged(wh, whold, v, fudge):
    """Report whether the factorization has settled.

    Two quantities are measured relative to ||v||:
      * the last step, ||wh - whold|| / ||v||, scaled by sqrt(#elements of wh);
      * the residual, ||v - wh|| / ||v||.
    Converged means the smaller of the two is below EPS*fudge.  Both values
    are emitted through die.dbg on every call.
    """
    size_factor = math.sqrt(wh.shape[0] * wh.shape[1])
    step = _norm(wh - whold)
    vnorm = _norm(v)
    d1 = step / vnorm
    d2 = _norm(v - wh) / vnorm
    die.dbg("Converged= %f %f" % (d1, d2))
    # min() (not "a < t or b < t") so a NaN step never counts as converged.
    return min(d1 * size_factor, d2) < EPS * fudge

def _updateH(w, h, wh, v):
    """Return the multiplicatively updated H: h * (w^T (v / (wh + eps))).

    ``eps`` is EPS times the average of ``wh`` over v's element count; it keeps
    the element-wise ratio finite where ``wh`` is zero.  No division by the
    column sums of ``w`` is applied, so this matches the Lee-Seung divergence
    update when w's columns sum to one (which _updateW enforces).
    """
    n_elements = v.shape[0] * v.shape[1]
    floor = EPS * Num.sum(wh, axis=None) / n_elements
    ratio = v / (wh + floor)
    return h * Num.matrixmultiply(Num.transpose(w), ratio)

def _updateW(w, h, wh, v):
    """Return the multiplicatively updated W, with each column rescaled to sum to 1.

    The raw update is  w * ((v / (wh + eps)) h^T),  where ``eps`` is EPS times
    the average of ``wh`` over v's element count, keeping the ratio finite
    where ``wh`` is zero.  Normalizing the columns lets _updateH omit the
    column-sum denominator of the Lee-Seung H update.

    Note: a column of the raw update that sums to zero would yield NaNs here.
    """
    eps = EPS * Num.sum(wh, axis=None) / (v.shape[0] * v.shape[1])
    f = Num.matrixmultiply(v / (wh + eps), Num.transpose(h))
    new_w = w * f
    column_sums = Num.sum(new_w, axis=0)
    return new_w / column_sums[Num.NewAxis, :]

def _initialize(v, rank):
    """Build random strictly non-negative starting factors for ``v`` (2-D).

    W is (rows, rank) with entries z**2 + IEPS for standard-normal z.
    H is (rank, cols) with entries of the same form, each column scaled by
    sum(v[:, j]**2) / (rows*cols).  W is drawn before H.
    """
    nrows, ncols = v.shape
    draw = Num.RA.standard_normal
    w = draw((nrows, rank)) ** 2 + IEPS
    column_scale = Num.sum(v ** 2, axis=0) / (nrows * ncols)
    h = column_scale * (draw((rank, ncols)) ** 2 + IEPS)
    return w, h

def nmf(v, rank, maxiter=None):
    """Factor a non-negative matrix as v ~= w h with w >= 0 and h >= 0.

    Alternates the multiplicative updates of _updateW / _updateH (columns of
    w are normalized to sum to one after each step) until _converged() has
    reported convergence more than CC times in a row.

    Parameters
    ----------
    v : 2-D array-like of non-negative numbers, shape (n, m).
    rank : int > 0, inner dimension of the factorization.
    maxiter : optional int.  If given, stop after at most this many update
        steps even if the convergence test has not been met.  The default,
        None, iterates until convergence (the historical behaviour).

    Returns
    -------
    (w, h, err) : w has shape (n, rank), h has shape (rank, m), and err is
        the Frobenius norm of v - w h.

    Raises
    ------
    AssertionError if rank <= 0 or any element of v is negative.
    """
    assert rank > 0, "Zero rank approximations are usually pretty awful."
    v = Num.asarray(v, Num.Float)
    assert Num.alltrue(Num.greater_equal(Num.ravel(v), 0.0)), "Negative element!"
    w, h = _initialize(v, rank)
    if Num.sum(v, axis=None) == 0.0:
        # All elements are non-negative, so v is identically zero (or empty):
        # h = 0 is an exact factorization.  Returning here avoids the
        # ZeroDivisionError _converged would raise when dividing by ||v||.
        w = w / Num.sum(w, axis=0)[Num.NewAxis, :]
        h = Num.zeros((rank, v.shape[1]), Num.Float)
        return (w, h, 0.0)
    streak = 0  # consecutive passes of the convergence test
    ic = 0      # number of update steps taken so far
    wh = Num.zeros(v.shape, Num.Float)
    while True:
        whold = wh
        wh = Num.matrixmultiply(w, h)
        # The threshold EPS*sqrt(ic/rank) is zero on the first pass (so it
        # can never count as converged) and loosens as iterations accumulate.
        if _converged(wh, whold, v, math.sqrt(ic / float(rank))):
            streak += 1
        else:
            streak = 0
        if streak > CC:
            break
        if maxiter is not None and ic >= maxiter:
            break
        # Both updates are computed from the same pre-update w, h and wh.
        wnew = _updateW(w, h, wh, v)
        hnew = _updateH(w, h, wh, v)
        w = wnew
        h = hnew
        ic += 1
    return (w, h, _norm(v - Num.matrixmultiply(w, h)))

def _test1():
    """A rank-1 3x2 matrix factored at rank 1 must be reproduced almost exactly."""
    target = [[1, 0], [1, 0], [0, 0]]
    w, h, residual = nmf(target, 1)
    approx = Num.matrixmultiply(w, h)
    assert _norm(approx - target) < 30*EPS
    assert residual < 0.001

def _test2():
    """A rank-2 3x2 matrix factored at rank 2 must be reproduced almost exactly."""
    target = [[1, 0], [0, 1], [0, 0]]
    w, h, residual = nmf(target, 2)
    approx = Num.matrixmultiply(w, h)
    assert _norm(approx - target) < 30*EPS
    assert residual < 0.001

def _test3( rank ):
    """Factor a fixed 3x3 matrix (third row = sum of the first two, so rank 2).

    Raises AssertionError when the fit is poor, e.g. when ``rank`` is 1.
    """
    target = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [1.0, 1.0, 1.0]]
    w, h, residual = nmf(target, rank)
    approx = Num.matrixmultiply(w, h)
    tolerance = 30*math.sqrt(EPS)
    assert _norm(approx - target) < tolerance
    assert residual < 0.001

if __name__ == '__main__':
    # Self-test driver.  Single-argument print(...) produces the same output
    # under Python 2's print statement and Python 3's print function.
    print("TEST1")
    _test1()
    print("")
    print("TEST2")
    _test2()
    print("")
    print("TEST3(1)")
    # The _test3 matrix has rank 2, so a rank-1 factorization must fail.
    try:
        _test3(1)
    except AssertionError:
        pass
    else:
        raise AssertionError("Test3(1) should fail!")
    for _rank in (2, 3, 4):
        print("")
        print("TEST3(%d)" % _rank)
        _test3(_rank)
    print("TEST3(5)")
    _test3(5)
    print("")
    print("LAST")

    # Over-complete factorization (rank 4 for a 2x2 identity); just show it.
    a = [[1, 0], [0, 1]]
    w, h, err = nmf(a, 4)
    print("w= %s" % (w,))
    print("h= %s" % (h,))
    print("wh= %s" % (Num.matrixmultiply(w, h),))
    print("a= %s" % (a,))