"""A classifier that assumes that P is linear in position.
This is known as a (linear) logistic discriminant analysis::

@note: A useful general reference is:
	@inbook{webb:spr:logistic,
	author = {Andrew Webb},
	title = {Statistical Pattern Recognition},
	pages = {124--132},
	year = {1999},
	publisher = {Arnold},
	address = {London, New York},
	note = {ISBN 0 340 74164 3}
	}

@note: This code was described in an appendix to
	"Dimensions of durational variation in speech",
	by Anastassia Loukina, Greg Kochanski, Burton Rosner, Chilin Shih, and Elinor Keane,
	submitted 2010 to J. Acoustical Society of America.
"""
23
24
import math
import random
import sys

import numpy

from gmisclib import Numeric_gpk
from gmisclib import die
from gmisclib import fiatio
from gmisclib import g_closure
from gmisclib import gpkmisc

import q_classifier_r as Q
36
# Defaults for the self-classification test harness below; they are passed
# as coverage=, ftest= and n_per_dim= to Q.compute_self_class().
COVERAGE = 3
FTEST = 0.25
N_PER_DIM = 10
40
41
42
43
def ref_from_data(data):
    """Compute a reference point for the classifier from training data.

    @param data: a sequence of datum objects, each carrying a ``value``
        feature vector.
    @return: the componentwise median across all the ``value`` vectors,
        as computed by L{gpkmisc.median_across}.
    """
    # NOTE(review): the 'def' line was lost in extraction; the name is
    # grounded by the call ref_from_data(data) in l_classifier_desc.__init__.
    tmp = [datum.value for datum in data]
    return gpkmisc.median_across(tmp)
47
48
def ref_from_models(models):
    """Compute a reference point from an existing set of models.

    @param models: a mapping whose values each carry a ``reference``
        attribute (a feature-space point).
    @return: the componentwise median of the models' reference points.
    """
    # NOTE(review): the 'def' line was lost in extraction; the name is
    # grounded by the call ref_from_models(models) in l_classifier_desc.__init__.
    tmp = [m.reference for m in models.values()]
    return gpkmisc.median_across(tmp)
52
53
55 """This class is metadata and helper functions for a linear discriminant classifier.
56 """
57
58 - def __init__(self, data=None, evaluator=None, models=None, ftrim=None):
59 if data is not None:
60 classes = Q.list_classes(data)
61 fvdim = Q.get_dim(data)
62 reference_pt = ref_from_data(data)
63 elif models is not None:
64 classes = models.keys()
65 fvdim = models.values()[0].direction.shape[0]
66 reference_pt = ref_from_models(models)
67 else:
68 raise ValueError, "Needs either a set of models or data to build a l_classifier_desc."
69 Q.classifier_desc.__init__(self, classes, fvdim, evaluator=evaluator, ftrim=ftrim)
70 self.reference = reference_pt
71
72
74 """The number of parameters required to define each class."""
75 assert self.ndim > 0
76 return self.ndim + 1
77
79 """The total number of parameters required to define all classes."""
80 assert self.nc > 1, "Only %d classes" % self.nc
81 assert self.np() > 0
82 return (self.nc-1) * self.np()
83
84 - def unpack(self, prmvec, trainingset_name=None, uid=None):
85 """Produce a classifier from a parameter vector."""
86 m = self.np()
87 assert len(prmvec) == self.npt()
88 assert self.c
89
90 o = l_classifier(self, trainingset_name=trainingset_name, uid=uid)
91 for (i, c) in enumerate(self.c):
92 if i >= self.nc-1:
93 break
94
95 pc = prmvec[i*m : (i+1)*m]
96 bias = pc[0]
97 prms = pc[1:]
98 o.add_model( c, Q.lmodel( prms, bias, self.reference ) )
99 o.add_model(c, Q.lzmodel( self.ndim ) )
100 return o
101
102
104 """Starting position for Markov Chain Monte Carlo."""
105 assert self.npt() > 0
106 p = numpy.zeros((self.npt(),))
107 v = numpy.identity(self.npt(), numpy.float)
108 var = Numeric_gpk.vec_variance( [datum.value for datum in data] )
109 if not numpy.greater(var, 0.0).all():
110 die.die("Variance of some data column is not positive: %s" % str(var))
111
112 print 'START: var=', var
113 assert var.shape == (self.ndim,)
114 assert self.npt() % self.np() == 0
115 for i in range(self.npt()):
116 snp = self.np()
117 isnp = i % snp
118 if isnp == 0:
119 sig = 1.0
120 else:
121 sig = math.sqrt(var[isnp-1])
122 p[i] = random.normalvariate(0.0, 1.0/sig)
123 v[i,i] = sig**-2
124
125
126 return (p, v)
127
129 return 'linear_discriminant_classifier'
130
131
132
133
134
135
137 """This is a linear discriminant classifier.
138 @type cdesc: L{l_classifier_desc} or None
139 """
140
141 - def __init__(self, cdesc=None, models=None, trainingset_name=None, uid=None):
150
151
152
153
154
155
156
161 d = Q.read_data(fd)
162 print '# classes: %s' % (' '.join(Q.list_classes(d)))
163 modelchoice = g_closure.Closure(l_classifier_desc, g_closure.NotYet,
164
165 ftrim=ftrim)
166 classout = fiatio.writer(open('classified.fiat', 'w'))
167
168 classout.header('ftrim', ftrim)
169 summary, out, wrong = Q.compute_self_class(d, coverage=coverage, ftest=ftest,
170 n_per_dim = n_per_dim,
171 modelchoice = modelchoice,
172 builder=Q.forest_build,
173 classout=classout, verbose=verbose,
174 modify_class=modify_class)
175 Q.default_writer(summary, out, classout, wrong)
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
216 print 'TEST_BUILD1:'
217 data = [ Q.datum_tr([1.0], 'A'), Q.datum_tr([0.0], 'B') ] * 2
218 assert issubclass(classifier_choice, Q.classifier_desc)
219 modelchoice = classifier_choice(data, evaluator=Q.evaluate_match)
220 assert isinstance(modelchoice, Q.classifier_desc)
221 print 'modelchoice=', modelchoice
222 print 'data=', data
223 ok = 0
224 for cdef in Q.forest_build(data, 1, modelchoice=modelchoice):
225 print 'cdef=', cdef
226 eigenvalue = Q.evaluate_match(cdef, data)
227 print 'eval=', eigenvalue
228 if eigenvalue < 1:
229 ok = 1
230 break
231 if not ok:
232 raise RuntimeError, "can't find a perfect classifier in this trivial situation!"
233
234
236 print 'TEST_BUILD1:'
237 NDAT = 300
238 data = []
239 for i in range(NDAT):
240 data.append(Q.datum_tr([random.normalvariate(1.0, 1./2.828)], 'A') )
241 data.append(Q.datum_tr([random.normalvariate(0.0, 1./2.828)], 'B') )
242 assert issubclass(classifier_choice, Q.classifier_desc)
243 modelchoice = classifier_choice(data, evaluator=Q.evaluate_match)
244 assert isinstance(modelchoice, Q.classifier_desc)
245 print 'modelchoice=', modelchoice
246 print 'data=', data
247 tmp = 0
248 ntmp = 0
249 for cdef in Q.forest_build(data, 10, modelchoice=modelchoice):
250 sc = Q.evaluate_match(cdef, data)
251 print 'cdef=', cdef, "score=", sc
252 tmp += sc
253 ntmp += 1
254 print 'tmp/ntmp=', tmp/ntmp
255 if tmp/float(ntmp) > 0.1 * 2*NDAT:
256 raise RuntimeError, "can't find a perfect classifier in this trivial situation!"
257
258
259
261 print 'TEST_BUILD1t:'
262
263 data = [ Q.datum_tr([100.0], 'A'), Q.datum_tr([-101.0], 'B') ]
264
265 for i in range(30):
266 data.append( Q.datum_tr([random.random()-1.5], 'A'))
267 data.append( Q.datum_tr([random.random()+0.5], 'B'))
268 print 'data=', data
269 assert issubclass(classifier_choice, Q.classifier_desc)
270 modelchoice = classifier_choice(data, evaluator=Q.evaluate_match, ftrim=(0.1,6))
271 assert isinstance(modelchoice, Q.classifier_desc)
272 ok = False
273 for cdef in Q.forest_build(data, 20, modelchoice=modelchoice):
274 e = Q.evaluate_match(cdef, data)
275 if e < 3:
276 ok = 1
277 break
278 else:
279 print 'cdef=', cdef, 'evaluation=', e
280 if not ok:
281 raise RuntimeError, "can't find a perfect classifier in this trivial situation!"
282
283
284
286 print 'TEST_BUILD32:'
287 data = [
288 Q.datum_tr([1.0, 0.0], 'A'),
289 Q.datum_tr([-0.01, 0.0], 'B'),
290 Q.datum_tr([-.01, 1.0], 'C'),
291 Q.datum_tr([1.02, 0.01], 'A'),
292 Q.datum_tr([-0.02, 0.01], 'B'),
293 Q.datum_tr([-.02, 1.01], 'C'),
294 Q.datum_tr([1.01, -0.01], 'A'),
295 Q.datum_tr([0.0, -0.01], 'B'),
296 Q.datum_tr([0, 0.99], 'C'),
297 Q.datum_tr([0.98, 0.02], 'A'),
298 Q.datum_tr([0.02, 0.0], 'B'),
299 Q.datum_tr([-0.01, 0.98], 'C'),
300 Q.datum_tr([0.99, 0.01], 'A'),
301 Q.datum_tr([0.01, 0.03], 'B'),
302 Q.datum_tr([-.02, 0.97], 'C'),
303 ]
304 assert issubclass(classifier_choice, Q.classifier_desc)
305 modelchoice = classifier_choice(data, evaluator=Q.evaluate_match)
306 assert isinstance(modelchoice, Q.classifier_desc)
307 ok = False
308 for cdef in Q.forest_build(data, 20, modelchoice=modelchoice):
309 e = Q.evaluate_match(cdef, data)
310 print cdef
311 if e < 1:
312 ok = True
313 break
314 if not ok:
315 raise RuntimeError, "can't find a perfect classifier in this trivial situation!"
316
317
319 print 'TEST_BUILD2:'
320 data = [
321 Q.datum_tr([1.0, 2.0, 0.0], 'A'),
322 Q.datum_tr([0.0, 1.0, 0.0], 'A'),
323 Q.datum_tr([1.0, 1.0, 1.0], 'A'),
324 Q.datum_tr([2.0, 0.0, 0.0], 'A'),
325 Q.datum_tr([2.0, 2.0, 2.0], 'A'),
326 Q.datum_tr([1.0, 0.0, 1.0], 'A'),
327 Q.datum_tr([0.0, -1.0, -2.0], 'B'),
328 Q.datum_tr([-1.0, -1.0, -2.0], 'B'),
329 Q.datum_tr([0.0, -1.0, -1.0], 'B'),
330 Q.datum_tr([-2.0, -1.0, -2.0], 'B'),
331 Q.datum_tr([0.0, -1.0, 0.0], 'B'),
332 Q.datum_tr([0.0, 0.0, -2.0], 'B'),
333 Q.datum_tr([-2.0, -2.0, -2.0], 'B'),
334 Q.datum_tr([-1.0, -1.0, 0.0], 'B')
335 ]
336 assert issubclass(classifier_choice, Q.classifier_desc)
337 modelchoice = classifier_choice(data, evaluator=Q.evaluate_match)
338 assert isinstance(modelchoice, Q.classifier_desc)
339 ok = False
340 for cdef in Q.forest_build(data, 20, modelchoice=modelchoice):
341 e = Q.evaluate_match(cdef, data)
342 if e < 1:
343 ok = True
344 break
345 if not ok:
346 raise RuntimeError, "can't find a perfect classifier in this trivial situation!"
347
348
349
351 print 'TEST_4_2:'
352 data = []
353 for i in range(400):
354 data.append( Q.datum_tr( [ random.random() ], 'a') )
355 data.append( Q.datum_tr( [ random.random() ], 'b') )
356 data.append( Q.datum_tr( [ 5+2*random.random() ], 'c') )
357 data.append( Q.datum_tr( [ 5+2*random.random() ], 'd') )
358 assert issubclass(classifier_choice, Q.classifier_desc)
359 modelchoice = g_closure.Closure(classifier_choice, g_closure.NotYet,
360 evaluator=Q.evaluate_match)
361 summary, out, wrong = Q.compute_self_class(data, ftest=FTEST, coverage=COVERAGE,
362 modelchoice = modelchoice,
363 n_per_dim = N_PER_DIM,
364 builder=Q.forest_build)
365 assert abs(summary['Chance'] - 0.25) < 0.1
366 assert abs(summary['Pcorrect'] - 0.5) < 0.1
367 assert summary['Pcorrect'] > summary['Chance']
368 print 'summary:', summary
369
370
371
373 print 'TEST_2_bias:'
374 data = []
375 for i in range(400):
376 data.append( Q.datum_tr( [ random.random() ], 'a') )
377 for i in range(40):
378 data.append( Q.datum_tr( [ random.random() ], 'b') )
379 assert issubclass(classifier_choice, Q.classifier_desc)
380 modelchoice = g_closure.Closure(classifier_choice, g_closure.NotYet,
381 evaluator=Q.evaluate_match)
382 classout = []
383 summary, out, wrong = Q.compute_self_class(data, ftest=FTEST, coverage=COVERAGE,
384 modelchoice = modelchoice,
385 n_per_dim = N_PER_DIM,
386 builder=Q.forest_build,
387 classout=classout)
388 print 'summary:', summary
389 assert abs(summary['Chance'] - 0.9) < 0.05
390 assert abs(summary['Pcorrect'] - 0.9) < 0.05
391 nbc = 0
392 nbt = 0
393 ntot = 0
394 for tmp in classout:
395 if tmp['trueclass']=='b':
396 nbt += 1
397 if tmp['compclass']=='b':
398 nbc += 1
399 ntot += 1
400 print 'true b=', nbt/float(ntot), 'computed b=', nbc/float(ntot)
401 assert abs(nbt/float(ntot) - 0.1) < 0.03
402 assert nbc < 1.3*nbt
403
404
406 o = []
407 for i in range(1,8):
408 o.append( random.gauss(off*i, f*i*i) )
409 return o
410
412 print 'TEST_2_scale:'
413 data = []
414 for i in range(300):
415 data.append( Q.datum_tr( _t2sguts(0.0, 0.001), 'a') )
416 for i in range(100):
417 data.append( Q.datum_tr( _t2sguts(0.0, 0.001), 'b') )
418 assert issubclass(classifier_choice, Q.classifier_desc)
419 modelchoice = g_closure.Closure(classifier_choice, g_closure.NotYet,
420 evaluator=Q.evaluate_match)
421 classout = []
422 summary, out, wrong = Q.compute_self_class(data, ftest=FTEST, coverage=COVERAGE,
423 modelchoice = modelchoice,
424 n_per_dim = N_PER_DIM,
425 builder=Q.forest_build,
426 classout=classout)
427 print 'summary:', summary
428 assert abs(summary['Chance'] - 0.75) < 0.05
429 assert abs(summary['Pcorrect'] - 0.75) < 0.10
430 nbc = 0
431 nbt = 0
432 ntot = 0
433 for tmp in classout:
434 if tmp['trueclass']=='b':
435 nbt += 1
436 if tmp['compclass']=='b':
437 nbc += 1
438 ntot += 1
439 print 'true b=', nbt/float(ntot), 'computed b=', nbc/float(ntot)
440 assert abs(nbt/float(ntot) - 0.25) < 0.10
441 assert nbc < 1.3*nbt
442
443
444
446 print 'TEST_var:'
447 data = []
448 for i in range(80):
449 data.append( Q.datum_tr( [ random.gauss(0.0, 1.0), random.gauss(0.0, 1.0) ], 'a') )
450 for i in range(160):
451 data.append( Q.datum_tr( [ random.gauss(0.0, 1.0), random.gauss(0.0, 1.0)+0.01 ], 'b') )
452 modelchoice = g_closure.Closure(l_classifier_desc, data,
453 evaluator=Q.evaluate_match)
454 summary, out, wrong = Q.compute_self_class(data,
455 modelchoice = modelchoice,
456 builder=Q.forest_build)
457 print 'summary:', summary
458
459
460
461
463 for i in range(15):
464 test_var()
465 sys.stdout.flush()
466
478
479
# NOTE(review): original lines 467-477 are missing from this extraction;
# PSYCO is unused in the code that is visible here (presumably it gated an
# optional psyco import in the lost lines -- confirm against VCS).
PSYCO = False

if __name__ == '__main__':
    test()
484