1
2
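"""Rotate a matrix by multiplying it with a unitary (orthogonal) matrix
so that the resulting elements are either nearly zero or relatively large.
It (roughly) minimizes the entropy of the result y: the magnitudes in each
column are normalized to sum to one, p_ij = |y_ij| / sum_i |y_ij|, and the
rotation is chosen to minimize -sum p_ij*log(p_ij), i.e. to maximize
sum p_ij*log(p_ij).
"""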
7
8 import math
9 from gmisclib import Num
10 from gmisclib import mcmc
11 from gmisclib import mcmc_helper
12 from gmisclib import gpkmisc
13
14
16 """extra_entropy is a vector that helps choose what to optimize."""
17 x = Num.asarray(x, Num.Float)
18 print 'x.shape=', x.shape
19 ndim, nvec = x.shape
20
def unitary(p):
    """Build an nvec x nvec orthogonal matrix from the parameter vector p.

    p must have length nvec**2 (nvec comes from the enclosing function).
    It is reshaped into a square matrix M, and the eigenvector matrix of the
    symmetric product M . M^T, as returned by Num.LA.eigh, is the result.
    The eigenvalues are discarded.
    """
    assert p.shape[0] == nvec**2
    square = Num.reshape(p, (nvec, nvec))
    gram = Num.dot(square, Num.transpose(square))
    return Num.LA.eigh(gram)[1]
27
def rotated(xt, p):
    """Return unitary(p) . xt; the product must keep the shape of xt."""
    result = Num.dot(unitary(p), xt)
    assert result.shape == xt.shape
    return result
32
33 def fom(p, c):
34 xt, nvec, extra_entropy, G = c
35 frot = rotated(xt, p)
36 assert frot.shape[0] == nvec
37
38 frotfom = Num.absolute(frot)
39 frotn = frotfom/Num.sum(frotfom, axis=1)[:,Num.NewAxis]
40 assert abs(Num.sum(frotn[0])-1.0) < 0.01
41 nege = Num.sum(frotn*(Num.log(frotn)-extra_entropy), axis=1)
42 assert nege.shape == (nvec,)
43 print 'fom=', Num.sum(nege)
44 return G*Num.sum(nege)
45
def fixer(p, c):
    """Rescale p in place so its root-mean-square is 1, and return it.

    c is unused here (presumably required by the mcmc.bootstepper fixer
    call signature).
    """
    rms = math.sqrt(Num.average(Num.square(p)))
    Num.divide(p, rms, p)  # third argument is the output array: in-place
    return p
49
50 prmvec = Num.RA.normal(0.0, 1.0, (nvec**2,))
51 stpr = mcmc.bootstepper(fom, [prmvec], v=gpkmisc.make_diag(Num.square(prmvec)),
52 c = (Num.transpose(x), nvec, extra_entropy, 20.0),
53 fixer=fixer
54 )
55 mcmch = mcmc_helper.stepper(stpr)
56 mcmch.run_to_bottom()
57 stpr.T = 0.3
58 mcmch.run_to_bottom()
59 stpr.T = 0.1
60 mcmch.run_to_bottom()
61 prms = stpr.current().prms()
62 return (Num.transpose(rotated(Num.transpose(x), prms)),
63 unitary(prms)
64 )
65
66
76
78 x = [[0, 0, 1], [1.0, -1.0, 0], [1.0, 1.0, 0]]
79 y,u = make_min_entropy(x)
80 assert abs(Num.sum(u**2) - u.shape[0]) < 0.001
81 print 'y=', y
82 ay = Num.absolute(y)
83 assert Num.sum(Num.less(ay, 0.01)) == 6
84 r2 = math.sqrt(2.0)
85 assert Num.sum(Num.less(Num.absolute(ay-r2), 0.01)) == 2
86 assert Num.sum(Num.less(Num.absolute(ay-1), 0.01)) == 1
87
88
if __name__ == '__main__':
    # Run the self-tests, in order, when this file is executed as a script.
    for _selftest in (test2, test3):
        _selftest()
92