Package gmisclib :: Module mcmc_newlogger
[frames | no frames]

Source Code for Module gmisclib.mcmc_newlogger

  1   
  2  import os 
  3  import re 
  4  import cPickle 
  5   
  6  import numpy 
  7   
  8  import die 
  9  import avio 
 10  import gpkmisc as GM 
 11  import dictops 
 12  import load_mod 
 13   
 14  import mcmc_indexclass as IC 
 15   
# Magnitude limit used by ok(): only values strictly between -HUGE and HUGE pass.
HUGE = 1e30
 17   
def ok(x):
    """Tell whether C{x} lies strictly between C{-HUGE} and C{HUGE}.

    @param x: the value to check.
    @return: the outcome of the two-sided comparison; it is false as soon
        as either bound is violated.
    """
    above_floor = -HUGE < x
    return above_floor and x < HUGE


# Import-time sanity checks of the bounds test.
assert ok(1.0)
assert not ok(2*HUGE)
class WildNumber(ValueError):
    """Raised by L{logger.add} when C{logp} or a parameter is out of range."""

    def __init__(self, *s):
        super(WildNumber, self).__init__(*s)
# Alias for avio.BadFormatError, so it can be referred to by its short name
# in this module.
BadFormatError = avio.BadFormatError
class logger_base(object):
    """Common base for the loggers in this module.

    It wraps an C{avio.writer}; header and comment output is flushed
    immediately after it is written.
    """

    def __init__(self, fd):
        """@param fd: handed to C{avio.writer} to create the underlying writer."""
        self.writer = avio.writer(fd)

    def flush(self):
        """Flush the underlying writer."""
        self.writer.flush()

    def close(self):
        """Close the underlying writer."""
        self.writer.close()

    def header(self, k, v):
        """Write one header item (key C{k}, value C{v}), then flush."""
        out = self.writer
        out.header(k, v)
        out.flush()

    def headers(self, hdict):
        """Write every item of C{hdict} as a header, in sorted order.

        @param hdict: header keys and values.
        @type hdict: dict
        """
        for key, value in sorted(hdict.items()):
            self.header(key, value)

    def comment(self, comment):
        """Write a comment, then flush."""
        out = self.writer
        out.comment(comment)
        out.flush()
59 60
def _fmt(x):
    """Format C{x} via C{IC.index_base._fmt}; a module-level shorthand."""
    formatter = IC.index_base._fmt
    return formatter(x)
63 64
class DBGlogger(logger_base):
    """A logger that writes every parameter as its own rounded text field."""

    def __init__(self, fd):
        logger_base.__init__(self, fd)

    def add(self, uid, prms, map, logp, iter, extra=None, reason=None):
        """Adds parameters to the log file.

        @param uid: which subset of parameters,
        @param prms: a vector of the actual numbers,
        @param map: a dictionary mapping from parameter names to an index
            into the prms array.
        @param logp: how good is the fit?  log(probability)
        @param iter: an integer to keep track of the number of iterations
        @param extra: None, or a dictionary of extra items; it is copied,
            not modified.
        @param reason: if not None, it is logged under C{why_logged}.
        """
        record = {} if extra is None else extra.copy()
        if reason is not None:
            record['why_logged'] = reason
        record['uid'] = uid
        record['iter'] = iter
        record['logp'] = '%.1f' % logp
        for nm, position in map.items():
            record[_fmt(nm)] = '%.2f' % prms[position]
        self.writer.datum(record)
        self.writer.flush()
90 91 92
class logger(logger_base):
    """Writes the compact log format that L{readiter} reads back."""

    def __init__(self, fd, huge=1e30):
        """Create a data log.

        @param fd: Where to log the results
        @type fd: file
        @param huge: sets a threshold for throwing L{WildNumber}.
            If the absolute value of any parameter gets bigger than huge,
            something is assumed to be wrong.
        @type huge: normally float, but could be numpy.ndarray
        """
        logger_base.__init__(self, fd)
        # Maps (uid, index) to the parameter name last logged for that slot.
        self.nmmap = {}
        self.HUGE = huge

    def add(self, uid, prms, map, logp, iter, extra=None, reason=None):
        """Adds a set of parameters to the log file.

        @param uid: which subset of parameters,
        @type uid: string
        @param prms: a vector of the actual numbers,
        @type prms: numpy.ndarray
        @param map: a dictionary mapping from parameter names to an index
            into the prms array.
        @type map: dict
        @param logp: how good is the fit?  log(probability)
        @type logp: float
        @param iter: which iteration is it?
        @type iter: int
        @param extra: any extra information desired.
        @type extra: None or a dictionary.
        @param reason: if not None, it is logged under C{why_logged}.
        @raise WildNumber: after the record has been written, if C{logp} is
            None or fails L{ok}, or if any parameter is outside
            C{[-self.HUGE, self.HUGE]}.
        """
        # Only names that are new or changed since the last call are logged.
        deltanm = []
        for (nm, i) in map.items():
            k = (uid, i)
            if k not in self.nmmap or self.nmmap[k] != nm:
                self.nmmap[k] = nm
                deltanm.append((i, nm))
        if extra:
            tmp = extra.copy()
        else:
            tmp = {}
        if reason is not None:
            tmp['why_logged'] = reason
        tmp['uid'] = uid
        tmp['iter'] = iter
        wild = []
        # A missing logp cannot be formatted as a number, so it is reported
        # as a wild value instead of crashing inside the '%' operator.
        if logp is not None and ok(logp):
            tmp['logp'] = '%.2f' % logp
        else:
            wild.append("logp=%s" % str(logp))
        okprm = numpy.greater_equal(self.HUGE, prms) * numpy.greater_equal(prms, -self.HUGE)
        if not okprm.all():
            rmap = dictops.rev1to1(map)
            for (i, isok) in enumerate(okprm):
                if not isok:
                    wild.append('p%d(%s)=%s' % (i, _fmt(rmap[i]), prms[i]))
        tmp['prms'] = cPickle.dumps(prms, protocol=0)
        tmp['deltanm'] = cPickle.dumps(deltanm, protocol=0)
        # The record is written even when something is wild, so that the
        # offending state ends up in the log before the exception is raised.
        self.writer.datum(tmp)
        self.writer.flush()
        if wild:
            raise WildNumber('; '.join(wild))
154 155 156 157
class NoDataError(ValueError):
    """The log file does not hold a complete set of data:
    data from the outer optimizer and at least one full set of inner
    optimizer results.
    """

    def __init__(self, *s):
        super(NoDataError, self).__init__(*s)
165 166
class logline(object):
    """This holds the information from one line of a log file.
    For efficiency reasons, some of the data may be stored pickled,
    and should be accessed through C{prms()} or C{logp()}.
    """

    __slots__ = ['uid', 'iter', '_logp', '_prms', 'names', 'extra']

    # Loose pattern used by T() to dig a number out of a mis-formatted field.
    # The sign of the exponent is optional, so that "2.5e3" is read as 2500
    # rather than being cut short at "2.5".
    _flg = re.compile(r'-?[0-9.]+([eE][+-]?[0-9]+)?')

    def __init__(self, uid, iter, logp, prms, names, extra):
        """
        @param uid: which subset of parameters this line belongs to.
        @param iter: the iteration number; converted with C{int()}.
        @param logp: log(probability), kept in whatever form it was read.
        @param prms: the pickled parameter vector.
        @param names: the parameter names, one per entry of the vector.
        @param extra: the remaining items from the log line.
        """
        self.uid = uid
        self.iter = int(iter)
        self._logp = logp    #: left as a string
        self._prms = prms    #: This is a pickled value
        self.names = names
        self.extra = extra

    def prms(self):
        """Unpickle and return the parameter vector.

        @raise ValueError: if its length differs from C{len(self.names)}.
        """
        # NOTE(review): unpickling executes whatever the log file contains;
        # only read log files from a trusted source.
        tmp = cPickle.loads(self._prms)
        if len(tmp) != len(self.names):
            raise ValueError("Problem unpickling logline: len=%d, should be %d"
                             % (len(tmp), len(self.names)))
        return tmp

    def logp(self):
        """@return: log(probability), converted to a float."""
        return float(self._logp)

    def idxr(self):
        """@return: an L{IC.index} built from the names and parameters."""
        keymap = dict([(key, j) for (j, key) in enumerate(self.names)])
        return IC.index(keymap, p=self.prms(), name=self.uid)

    def __repr__(self):
        rv = ['# ' + avio.concoct({'uid': self.uid, 'logP': self.logp(), 'iter': self.iter})]
        for (nm, v) in zip(self.names, self.prms()):
            rv.append("%s %s" % (v, IC.index._fmt(nm)))
        return '\n'.join(rv)

    def T(self):
        """Return the C{T} item of C{self.extra} as a float.

        @raise ValueError: if no number can be found in that item.
        @raise KeyError: if there is no C{T} item.
        """
        def grabfloat(s):
            """This is a horrible kluge to deal with some mis-formatted log files."""
            m = self._flg.search(s)
            if m:
                return float(m.group(0))
            raise ValueError("Cannot find float in {%s}" % s)

        try:
            tee = float(self.extra['T'])
        except ValueError:
            tee = grabfloat(self.extra['T'])
        return tee
218 219
def readiter(logfd, trigger=None):
    """Reads in one line at a time from a log file.
    This is an iterator.  At every non-comment line, it yields an output.

    @param logfd: a L{file} descriptor for a log file produced by this module.
    @param trigger: this is None or a regular expression that specifies when
        to start paying attention to the data.  If C{trigger} is not None,
        body lines before C{trigger} matches are ignored.  Note that
        C{trigger} is matched to a line with trailing whitespace stripped away.
    @return: (Yields) C{(hdr, x)} where C{hdr} is a dictionary of the header
        information current at that point in the file and C{x} is a
        L{logline} class containing the current line's data.
        The same C{hdr} dictionary is yielded every time and keeps being
        updated as more header lines are read.
    @raise BadFormatError: when a data line lacks a required item, or when
        its name updates are out of range, contradict an earlier name, or
        leave a gap.
    """
    if trigger is not None:
        die.info('# Trigger="%s"' % trigger)
        trigger = re.compile(r"^#\s*" + trigger)
    hdr = {}
    names_by_uid = {}
    lineno = 0
    try:
        for raw in logfd:
            lineno += 1
            if not raw.endswith('\n'):
                break  # Incomplete last line of a file.
            line = raw.rstrip()
            if line == '':
                continue
            if trigger is not None and trigger.match(line):
                die.info("Trigger match")
                trigger = None
            if line.startswith('#'):
                # Comment lines of the form "# key = value" carry header data.
                if '=' in line:
                    key, value = line[1:].split('=', 1)
                    hdr[key.strip()] = avio.back(value.strip())
                continue
            a = avio.parse(line)
            uid = a.pop('uid')
            # NOTE(review): unpickling executes whatever the log file
            # contains; only read log files from a trusted source.
            deltanm = cPickle.loads(a.pop('deltanm'))

            # We do *NOT* unpickle a['prms'] here because the odds are very
            # high that no one will care about this particular set of
            # parameters.  Unpickling is done, as needed, in the logline class.

            names = names_by_uid.setdefault(uid, [])
            if len(deltanm) > 0:
                olen = len(names)
                lmax = olen + len(deltanm)
                for (i, nm) in deltanm:
                    if i < 0 or i >= lmax:
                        raise BadFormatError(
                            "Bad index for deltanm: old=%d i=%d len(names)=%d"
                            % (olen, i, len(names)))
                    if i >= len(names):
                        # Grow the list; slots skipped over stay None for now.
                        names.extend([None]*(1+i-len(names)))
                        names[i] = nm
                    elif names[i] is None:
                        names[i] = nm
                    elif names[i] != nm:
                        raise BadFormatError(
                            "Inconsistent names in slot %d: %s and %s"
                            % (i, names[i], nm))
                if None in names[olen:]:
                    raise BadFormatError("A name is none: %s" % names)
            names_by_uid[uid] = names
            if trigger is None:
                # The three items are popped first, so that what is left
                # of _a_ becomes the "extra" information of the logline.
                iteration = a.pop('iter')
                logp = a.pop('logp')
                pickled = a.pop('prms')
                yield (hdr, logline(uid, iteration, logp, pickled, names, a))
    except KeyError as x:
        raise BadFormatError("Missing key: %s on line %d of %s"
                             % (x, lineno, getattr(logfd, 'name', "?")))
    except BadFormatError as x:
        raise BadFormatError("file %s line %d: %s"
                             % (getattr(logfd, 'name', '?'), lineno, str(x)))
    if trigger is not None:
        die.warn('Trigger set but never fired: %s' % getattr(logfd, 'name', '???'))
300 301 302
def read_raw(logfd):
    """Read log information in raw line-by-line form.

    @param logfd: an open log file, as accepted by L{readiter}.
    @return: C{(hdr, data)}: the header dictionary that came with the final
        data line (None if there were no data lines), and the list of all
        L{logline} items.
    """
    allhdrs = None
    data = []
    for allhdrs, datum in readiter(logfd):
        data.append(datum)
    return (allhdrs, data)
311 312 313
314 -def _find_longest_suboptimizer(clidx):
315 n = 0 316 for (k, v) in clidx.items(): 317 lp = len(v.prms()) 318 if lp > n: 319 n = lp 320 return n
321 322
def _read_currentlist(fname, trigger):
    """Starting from a file name, read in a log file.

    @param fname: name of a (possibly compressed) log file.
    @type fname: str
    @param trigger: where to start in the log file.  See L{readiter}.
    @rtype: C{(list(dict(str:X)), dict(str:str))} where C{X} is a
        L{logline} class containing the data from one line of the log file.
    @return: a list of data groups and header information.
        A data group is a dictionary of information, all from the same
        iteration.  For simple optimizations, there is one item in a data
        group, conventionally under the key "UID".
    """
    groups = []
    group = {}
    prev_iter = -1
    hdr_seen = None
    widest = 0      # The largest number of items seen in any one group.
    for (hdr, entry) in readiter(GM.open_compressed(fname), trigger=trigger):
        # print '_read_currentlist: d=', str(d)[:50], "..."
        if group and entry.iter != prev_iter:
            # A new iteration begins.  The finished group is kept, unless
            # its iteration number was negative.
            if prev_iter >= 0:
                groups.append(group)
            group = {}
        prev_iter = entry.iter
        group[entry.uid] = entry
        if len(group) > widest:
            widest = len(group)
        hdr_seen = hdr
    # The final group is kept only when it is as large as the largest one.
    if widest > 0 and len(group) == widest:
        groups.append(group)
    return (groups, hdr_seen)
353 354 355
def read_log(fname, which):
    """Read in a log file, returning information that you can use to
    restart the optimizer.  The log file must be produced
    by L{logger.add}().

    @param fname: the filename to read.
    @type fname: str
    @param which: a string that tells which log entry to grab.
        - "last": the final logged entry;
        - "index%d" (i.e. "index" followed directly by an integer):
            the %dth log entry (negative numbers count backwards from the
            end), so "index-1" is the final logged entry and "index0" is
            the first;
        - "frac%g" (i.e. "frac" followed by a float):
            picks a log entry as a fraction of the length of the log.
            "frac0" is the first, and "frac1" is the final entry.
        - "best": picks the iteration with the most positive logP value
            (summed over that iteration's entries); the earliest one wins
            a tie.
    @type which: str
    @return: (hdr, stepperstate, vlists, logps), where
        - hdr is all the header information from the log file
            (e.g. metadata).
        - stepperstate is a dictionary of L{IC.index} instances, one
            per UID and L{newstem2.position.OuterUID}.
            It will evoke a True when passed to
            L{newstem2.position.is_statedict}().
            Basically, this is a single moment of the stepper's state.
        - vlists is a dictionary (again, one per UID plus...) whose
            values are lists of position vectors that were recently
            encountered.  This is helpful for the optimizer to
            figure out plausible step sizes.
            Basically, this is many moments of state information near C{idx}.
        - logps is a dictionary (again...) whose values are recently
            obtained values for logP for that UID.  This is mostly
            useful for display and debugging purposes.
    @raise NoDataError: when the log file yields no complete data group.
    @raise ValueError: when C{which} is not understood, or asks for an
        entry outside the log.
    """
    def sumlogp(x):
        return sum([v.logp() for v in x.values()], 0.0)

    currentlist, lasthdr = _read_currentlist(fname, None)
    # currentlist is list(dict(str:X))
    lc = len(currentlist)
    if lc == 0:
        raise NoDataError(fname)

    if which == 'last':
        idx = lc - 1
    elif which == 'best':
        # Is this the appropriate definition of best?
        idx = 0
        slp = sumlogp(currentlist[0])
        for (i, current) in enumerate(currentlist):
            tmp = sumlogp(current)
            if tmp > slp:
                # The running maximum has to move along with the index,
                # otherwise every later entry is only compared to the first.
                idx = i
                slp = tmp
    elif which.startswith('frac'):
        frac = float(GM.dropfront('frac', which))
        if not (0 <= frac <= 1.0):
            raise ValueError("Requires 0 <= (frac=%g) <= 1.0" % frac)
        # frac=1 has to select the final entry, not one past the end.
        idx = min(int(lc*frac), lc - 1)
    elif which.startswith('index'):
        idx = int(GM.dropfront('index', which))
        if not (-lc <= idx < lc):
            raise ValueError("Requires -len(log_items) <= (idx=%d) < len(log_items)=%d"
                             % (idx, lc))
        if idx < 0:
            idx = lc + idx
    else:
        raise ValueError("Bad choice of which=%s" % which)

    n = _find_longest_suboptimizer(currentlist[idx]) + 1

    # The window of neighbouring entries that feeds vlists.  It holds at most
    # 2*n entries and is the same for every sub-optimizer.
    lo = int(round(idx*float(max(0, lc-n))/float(lc)))
    assert 0 <= lo < lc
    high = min(2*n + lo, lc)
    assert lo < high <= lc, "lc=%d, lo=%d, idx=%d, n=%d" % (lc, lo, idx, n)

    logps = {}
    vlists = {}
    stepperstate = {}
    for (k, v) in currentlist[idx].items():
        # k is the name of a sub-optimizer
        # v is information about that sub-optimizer's idx^th iteration.
        keymap = dict([(key, j) for (j, key) in enumerate(v.names)])
        stepperstate[k] = IC.index(keymap, p=v.prms(), name=str(k))
        vlists[k] = [c[k].prms() for c in currentlist[lo:high]]
        logps[k] = v.logp()
    return (lasthdr, stepperstate, vlists, logps)
def read_multilog(fname, Nsamp=10, tail=0.5):
    """Read in a log file, producing a sample of stepperstate information.

    @param fname: the filename to read.
    @type fname: str
    @param Nsamp: the maximum number of samples to extract.
    @type Nsamp: int
    @param tail: Where to start in the file?
        C{tail=0} means that you start at the trigger or beginning.
        C{tail=1-epsilon} means you take just the last tiny bit of the file.
    @return: C{(hdr, stepperstates)}: the header information and a list of
        dictionaries of L{IC.index} instances, one dictionary per sample.
    @raise NoDataError: when the log file yields no complete data group.
    """
    currentlist, lasthdr = _read_currentlist(fname, None)
    lc = len(currentlist)
    if lc == 0:
        raise NoDataError(fname)

    # Same clamp as read_multi_uid() and read_tail_uid(): however large
    # tail is, the final entry of the log remains available.
    tail = min(tail, 1.0 - 0.5/lc)
    # How many samples to extract?  At least one, even from a short log.
    nsamp = min(Nsamp, max(1, int(lc*(1.0-tail))))
    assert nsamp > 0
    stepperstates = []
    for i in range(nsamp):
        f = float(i)/float(nsamp)
        aCurrent = currentlist[int((tail + (1.0-tail)*f)*lc)]
        stepperstate = {}
        for (k, v) in aCurrent.items():
            # This is iterating over the various utterances.
            keymap = dict([(key, j) for (j, key) in enumerate(v.names)])
            stepperstate[k] = IC.index(keymap, p=v.prms(), name=str(k))
        stepperstates.append(stepperstate)
    return (lasthdr, stepperstates)
473 474
def read_multi_uid(fname, uid, Nsamp=10, tail=0.0, trigger=None):
    """Read in a log file, selecting information only for a particular UID.
    This *doesn't* read in multiple UIDs, no matter what the name says:
    the "multi" refers to the fact that it gives you multiple samples
    from the log.  In other words, it provides a time-series of the changes
    to the parameters.

    OBSOLETE.

    @param fname: the filename to read.
    @param uid: the UID whose entries are wanted.
    @param Nsamp: the maximum number of samples to extract (or None, or -1).
        None means "as many as are available"; -1 means "as many samples as
        there are parameters."
    @type Nsamp: L{int} or L{None}.
    @param tail: Where to start in the file?
        C{tail=0} means that you start at the trigger or beginning.
        C{tail=1-epsilon} means you take just the last tiny bit of the file.
    @param trigger: where to start in the log file.  See L{readiter}.
    @return: C{(hdr, indexers, logps)}: the header information, a list of
        L{IC.index} instances, and a numpy array with the matching logP
        values.
    @raise NoDataError: When the log file is empty or a trigger is set but
        not triggered.
    """
    currentlist, lasthdr = _read_currentlist(fname, trigger)
    lc = len(currentlist)
    if lc == 0:
        raise NoDataError(fname)

    # Keep at least the final entry of the log, however large tail is.
    tail = min(tail, 1.0 - 0.5/lc)
    nsamp = lc - int(lc*tail)    # How many samples to extract?
    if Nsamp is not None:
        if Nsamp == -1:
            Nsamp = currentlist[0][uid].prms().shape[0]
            assert Nsamp > 0
        if Nsamp < nsamp:
            nsamp = Nsamp
    # A single pre-formatted argument prints the same text under the
    # Python 2 print statement and the Python 3 print function.
    print('# len(currentlist)= %s tail= %s nsamp= %s' % (lc, tail, nsamp))
    assert nsamp > 0
    indexers = []
    logps = numpy.zeros((nsamp,))
    for i in range(nsamp):
        f = float(i)/float(nsamp)
        sample = currentlist[int((tail + (1.0-tail)*f)*lc)][uid]
        keymap = dict([(key, j) for (j, key) in enumerate(sample.names)])
        indexers.append(IC.index(keymap, p=sample.prms(), name=uid))
        # Plain item assignment; ndarray.itemset() is gone in NumPy 2.0.
        logps[i] = sample.logp()
    return (lasthdr, indexers, logps)
524 525
def read_tail_uid(fname, uid, Nsamp=10, tail=0.0, trigger=None):
    """Read in a log file, selecting information only for a particular UID.
    It gives you multiple samples from the log.  In other words, it provides
    a time-series of the changes to the parameters.

    @param fname: the filename to read.
    @param uid: the UID whose entries are wanted.
    @param Nsamp: the maximum number of samples to extract (or None, or -1).
        None means "as many as are available"; -1 means "as many samples as
        there are parameters."
    @type Nsamp: L{int} or L{None}.
    @param tail: Where to start in the file?
        C{tail=0} means that you start at the trigger or beginning.
        C{tail=1-epsilon} means you take just the last tiny bit of the file.
    @param trigger: where to start in the log file.  See L{readiter}.
    @raise NoDataError: When the log file is empty or a trigger is set but
        not triggered.
    @rtype: (dict{str:str}, list(logline))
    @return: The data file's header, and the selected list of L{logline}
        instances.
    """
    currentlist, lasthdr = _read_currentlist(fname, trigger)
    lc = len(currentlist)
    if lc == 0:
        raise NoDataError(fname)

    # Keep at least the final entry of the log, however large tail is.
    tail = min(tail, 1.0 - 0.5/lc)
    nsamp = lc - int(lc*tail)    # How many samples to extract?
    if Nsamp is not None:
        if Nsamp == -1:
            Nsamp = currentlist[0][uid].prms().shape[0]
            assert Nsamp > 0
        if Nsamp < nsamp:
            nsamp = Nsamp
    die.info('N(data lines)=%d tail=%.3f nsamp=%d' % (lc, tail, nsamp))
    assert nsamp > 0
    rv = []
    for i in range(nsamp):
        f = float(i)/float(nsamp)
        where = int((tail + (1.0-tail)*f)*lc)
        rv.append(currentlist[where][uid])
    return (lasthdr, rv)
565 566
def read_human_fmt(fname):
    """This reads a more human-readable format, as produced by L{print_log}.
    It can easily be modified with a text editor to change parameters.

    Lines starting with C{#} are parsed with C{avio.parse}: one that has a
    C{uid} item starts a new group of parameters, any other one adds to the
    header.  All remaining non-blank lines must have two columns: a number,
    then the name of the parameter.

    @param fname: the filename to read.
    @type fname: str
    @return: C{(hdr, statedict, logps)}: the header information, a
        dictionary of L{IC.index} instances keyed by uid, and a dictionary
        of logP values keyed by uid.
    @raise BadFormatError: for a duplicate uid, or for a data line that does
        not consist of a number followed by a name.
    """
    statedict = {}
    logps = {}
    hdr = {}
    lnum = 0
    uid = None
    keymap = {}
    prms = []
    # The with-statement closes the file even when a line is rejected.
    with open(fname, 'r') as fd:
        for line in fd:
            if not line.endswith('\n'):
                die.warn("Incomplete line - no newline: %d" % (lnum+1))
                continue
            lnum += 1
            line = line.strip()
            if line == '':
                continue
            elif line.startswith('#'):
                tmp = avio.parse(line[1:])
                if 'uid' in tmp:
                    if uid is not None:
                        # Finish off the group that was being collected.
                        statedict[uid] = IC.index(keymap, p=numpy.array(prms, float), name=uid)
                        keymap = {}
                        prms = []
                    uid = tmp['uid']
                    if uid in statedict:
                        raise BadFormatError("human format: duplicate uid=%s" % uid)
                    if 'logP' in tmp:
                        logps[uid] = float(tmp['logP'])
                else:
                    hdr.update(tmp)
            else:
                cols = line.strip().split(None, 1)
                if len(cols) != 2:
                    raise BadFormatError("human format: needs two columns line %d" % lnum)
                keymap[IC.index_base._unfmt(cols[1])] = len(prms)
                try:
                    prms.append(float(cols[0]))
                except ValueError:
                    raise BadFormatError("human format: need number in first column line %d: %s"
                                         % (lnum, cols[0]))
    # The builtin float is what the numpy.float alias used to stand for;
    # the alias itself no longer exists in current numpy.
    statedict[uid] = IC.index(keymap, p=numpy.array(prms, float), name=uid)

    return (hdr, statedict, logps)
613 614
def load_module(phdr):
    """Load the module that a log file's header refers to.

    The C{model_file} item, with its extension removed, is tried first;
    if that raises ImportError or produces None, the C{model_name} item is
    loaded instead.

    @param phdr: header information from a log file.
    @type phdr: dict
    @return: whatever C{load_mod.load_named} returns.
    """
    stem = os.path.splitext(phdr['model_file'])[0]
    try:
        found = load_mod.load_named(stem)
    except ImportError:
        found = None
    if found is not None:
        return found
    return load_mod.load_named(phdr['model_name'])
if __name__ == '__main__':
    # Usage: [-last | -best | -frac F | -index N] [--] logfile
    import sys
    which = 'last'
    arglist = sys.argv[1:]
    while arglist and arglist[0].startswith('-'):
        arg = arglist.pop(0)
        if arg == '--':
            break
        if arg in ('-last', '-best'):
            which = arg[1:]
        elif arg == '-frac':
            which = 'frac%f' % float(arglist.pop(0))
        elif arg == '-index':
            which = 'index%d' % int(arglist.pop(0))
        else:
            die.die('Unrecognized arg: %s' % arg)
    assert len(arglist) == 1
    fname = arglist.pop(0)
    print_log(fname, which)