Package classifiers :: Module g_kmeans
[frames | no frames]

Source Code for Module classifiers.g_kmeans

  1  #!/usr/bin/env python 
  2   
  3   
  4  import math 
  5  import random 
  6  from gmisclib import g_implements 
  7  from gmisclib import Num 
  8  from gmisclib import gpkmisc 
  9  from gmisclib import Numeric_gpk as NG 
 10  from gmisclib import fiatio 
 11  from gmisclib import chunkio 
 12  from gmisclib import die 
 13  import q_classifier_r 
 14   
 15  HUGE = 1e30 
 16  Clip = None 
 17   
def find_which_cluster(datum, cpos, distfcn):
    """Find the cluster center closest to C{datum}.

    @param datum: object whose C{value} attribute is handed to C{distfcn}.
    @param cpos: iterable of candidate cluster centers.
    @param distfcn: callable invoked as C{distfcn(datum.value, center)};
        its result is compared with C{<}.
    @return: tuple C{(index, distance)} for the nearest center.  If no
        center gives a distance strictly below C{HUGE} (for instance when
        C{cpos} is empty), the result is C{(-1, HUGE)}.
    """
    winner = -1
    shortest = HUGE
    for (idx, center) in enumerate(cpos):
        d = distfcn(datum.value, center)
        if not (d < shortest):
            # Not an improvement; ties keep the earlier center.
            continue
        shortest = d
        winner = idx
    return (winner, shortest)
27 28
def find_center(i, membership, data, clip):
    """Compute the center of cluster C{i} from its current members.

    @param i: index of the cluster whose center is wanted.
    @param membership: sequence parallel to C{data}; entry C{k} is the
        cluster index assigned to C{data[k]}.
    @param data: sequence of objects with a C{value} attribute.
    @param clip: forwarded unchanged to C{NG.trimmed_mean_across}
        (presumably the fraction of data to trim -- confirm against that
        function's documentation).
    @return: whatever C{NG.trimmed_mean_across} returns for the values of
        the members of cluster C{i}.
    """
    # membership and data are parallel sequences, so they must be walked in
    # lockstep with zip().  (enumerate(membership, data) would pass the data
    # list as enumerate's integer "start" argument and raise TypeError.)
    members = [datum.value
               for (mem, datum) in zip(membership, data)
               if mem == i]
    # NOTE(review): members is empty when no datum is assigned to cluster i;
    # behaviour of trimmed_mean_across on an empty list is not visible here.
    return NG.trimmed_mean_across(members, None, clip)
36 37
class cluster_descriptor:
    """Simple record holding the description of a single cluster.

    Attributes (stored exactly as passed to the constructor):
      - C{center}: the cluster center.
      - C{membership}: the members of the cluster.
      - C{size}: a measure of the cluster's size.
    """

    def __init__(self, center, membership, size):
        # No copying or validation: the caller's objects are kept by reference.
        self.center, self.membership, self.size = center, membership, size
43 44 45
def kmeans(data, Ncl, distfcn=None, Nit=None, clip=None):
    """Cluster C{data} into C{Ncl} groups using k-means with random restarts.

    Each restart picks C{Ncl} distinct data points as initial centers, then
    alternates between assigning every datum to its nearest center and
    recomputing the centers via L{find_center}, until the assignment stops
    changing or C{Nit} passes have been made.  The restart with the smallest
    total distance ("scatter") is kept.

    @param data: sequence of objects satisfying C{q_classifier_r.datum_c}
        (checked with C{g_implements.check}); their C{value} and C{uid}
        attributes are used, and the objects themselves become dictionary
        keys in the result, so they must be hashable.
    @param Ncl: number of clusters; must not exceed C{len(data)}.
    @param distfcn: callable C{distfcn(value, center)} returning a distance.
        When C{None}, the module's L{euclid} is used.
    @param Nit: number of random restarts and also the maximum number of
        passes within each restart.  Defaults to
        C{10 + round(sqrt(len(data)*Ncl))}.
    @param clip: passed to L{find_center}; defaults to 0.10.
    @return: tuple C{(clusters, map_to_cluster, err)} where C{clusters} is a
        list of L{cluster_descriptor} (center, list of member uids, median
        member-to-center distance), C{map_to_cluster} maps each datum to its
        cluster index, and C{err} is the scatter of the best restart.
    @raise ValueError: if no restart produced a usable clustering (for
        example when C{Nit} is not positive).
    """
    for datum in data:
        g_implements.check(datum, q_classifier_r.datum_c)

    if distfcn is None:
        # The None default is not callable; fall back to the module's metric.
        distfcn = euclid
    if Nit is None:
        Nit = 10 + int(round(math.sqrt(len(data)*Ncl)))
    if clip is None:
        clip = 0.10
    assert len(data) >= Ncl
    bestcp = None
    bestm = None
    bestcldist = None
    besterr = HUGE
    starts = 0
    while starts < Nit:
        # Fresh random initialisation: Ncl distinct data points as centers.
        cpos = [data[c].value for c in random.sample(range(len(data)), Ncl)]
        membership = Num.zeros((len(data),), Num.Int)
        assert len(cpos) == Ncl
        passes = 0
        while passes < Nit:
            cldtmp = [[] for x in range(Ncl)]
            scatter = 0.0
            omem = Num.array(membership, copy=True)
            for (i, datum) in enumerate(data):
                membership[i], dist = find_which_cluster(datum, cpos, distfcn)
                if membership[i] >= 0:
                    # Only data that were assigned to a cluster contribute.
                    cldtmp[membership[i]].append(dist)
                    scatter += dist
            cldist = [gpkmisc.median(cldtmp[i]) for i in range(Ncl)]

            for i in range(Ncl):
                cpos[i] = find_center(i, membership, data, clip)
            # Converged when no datum changed cluster.  Compare element by
            # element: "omem == membership" on arrays yields an array whose
            # truth value does not mean "all equal".
            if all(a == b for (a, b) in zip(omem, membership)):
                break
            passes += 1

        if scatter < besterr:
            besterr = scatter
            bestcp = cpos
            bestm = membership
            # Keep the per-cluster sizes that belong to this restart, so the
            # reported sizes match the reported centers and memberships.
            bestcldist = cldist

        starts += 1

    if bestm is None:
        raise ValueError('kmeans: no clustering found (Nit=%r)' % (Nit,))

    o = []
    for i in range(Ncl):
        m = [d.uid for (mem, d) in zip(bestm, data) if mem == i]
        o.append(cluster_descriptor(bestcp[i], m, bestcldist[i]))
    map_to_cluster = {}
    for (datum, mem) in zip(data, bestm):
        map_to_cluster[datum] = mem
    return (o, map_to_cluster, besterr)
96 97 98 99
def read_data(fd):
    """Read feature vectors, one per line, from an iterable of text lines.

    Lines beginning with C{#} are skipped.  On any other line, text after the
    first C{#} (stripped) is the datum's uid; without a C{#} the uid is
    C{'Line:N'} with N the 1-based line number.  The text before the C{#} is
    split on whitespace and converted to floats.  All vectors must have the
    same length as the first one; otherwise C{die.die} is called.

    @param fd: iterable yielding lines (e.g. an open file).
    @return: list of C{q_classifier_r.datum_c} objects, in input order.
    """
    vectors = []
    width = None
    lineno = 0
    for line in fd:
        lineno += 1
        if line.startswith('#'):
            continue
        pieces = line.split('#', 1)
        if len(pieces) > 1:
            uid = pieces[1].strip()
        else:
            uid = 'Line:%d' % lineno
        fields = pieces[0].split()
        if width is None:
            # The first data line fixes the expected dimensionality.
            width = len(fields)
        elif len(fields) != width:
            die.die('Not all vectors have length=%d. Problem on line %d'
                    % (width, lineno)
                    )
        values = Num.array([float(tok) for tok in fields])
        vectors.append(q_classifier_r.datum_c(values, uid))
    return vectors
127 128
def euclid(a, b):
    """Return the Euclidean distance between C{a} and C{b}.

    The difference C{a-b} is flattened with C{Num.ravel}, squared, summed
    with C{Num.sum}, and the square root of that sum is returned.
    """
    delta = Num.ravel(a - b)
    sum_sq = Num.sum(delta**2)
    return math.sqrt(sum_sq)
if __name__ == '__main__':
    # Command-line driver.
    # Flags: -header KEY VALUE, -uidname NAME, -clip FLOAT, -- (end of flags).
    # The first remaining argument is the number of clusters.  Feature
    # vectors are read from stdin and the clustering is written to stdout
    # through fiatio.writer.
    import sys
    arglist = sys.argv[1:]
    hdrs = {}
    uidname = 'uid'
    while arglist and arglist[0].startswith('-'):
        arg = arglist.pop(0)
        if arg == '--':
            break
        if arg == '-header':
            # Pop the key first, then the value: order matters.
            hdr_name = arglist.pop(0)
            hdr_value = arglist.pop(0)
            hdrs[hdr_name] = hdr_value
        elif arg == '-uidname':
            uidname = arglist.pop(0)
        elif arg == '-clip':
            Clip = float(arglist.pop(0))
        else:
            die.die('Unrecognized flag: %s' % arg)
    ncl = int(arglist[0])
    d = read_data(sys.stdin)
    o, map_to_cluster, err = kmeans(d, ncl, distfcn=euclid, clip=Clip)

    w = fiatio.writer(sys.stdout)
    w.headers(hdrs)
    w.header('Ncl', ncl)
    if Clip is not None:
        w.header('Clip', Clip)
    w.header('Error', err)
    # All per-cluster headers are emitted before any datum is written.
    for (clnum, descr) in enumerate(o):
        center_text = chunkio.chunkstring_w().write_NumArray(descr.center, str).close()
        w.header('cluster%d' % clnum, center_text)
        w.header('size%d' % clnum, descr.size)
    # One output record per datum: its cluster number and its uid.
    for (clnum, descr) in enumerate(o):
        for datum_uid in descr.membership:
            w.datum({'clnum': clnum, uidname: datum_uid})
    w.close()