Package gmisclib :: Module mcmc_indexclass
[frames] | no frames]

Source Code for Module gmisclib.mcmc_indexclass

  1  """This module provides several classes to manage the parameters of an algorithm, 
  2  particularly so that you can run the mcmc.py optimizer on it.  Essentially, the main 
  3  problem that it solves is how to keep track of human-readable names assigned to 
  4  components of a vector of parameters. 
  5   
  6  You have an L{index_base} object C{i}, and you can get a parameter from it 
  7  like this:  C{i.p("velocity", "car", "x")}.   That would get out a number that 
  8  might correspond to the x-velocity of a car.   That's the human-readable side, 
  9  but you can also get a vector of all the parameters via C{i.get_prms()} and 
 10  get a mapping between names and indices into that vector via C{i.map}. 
 11   
 12  You can also have parameters that do not appear in the vector 
 13  of adjustable parameters. 
 14  These are known as "fixed" parameters; their intent is to let you 
 15  fix some of the parameters of an optimization, and let the optimizer 
 16  play with the remainder. 
 17   
 18  Use L{guess} first to establish the mapping, write it to a file 
 19  and guess reasonable starting values for an optimization. 
 20  Then, use L{index} to get parameter values corresponding to names. 
 21  """ 
 22   
 23  import re 
 24  import os 
 25   
 26  import numpy 
 27   
 28  import die 
 29  import dictops 
 30  import g_encode 
 31   
 32  Debug = False 
class PriorProbDist(object):
    """This is an ordered map from regular expressions to callable objects.
    It is used to find which prior probability distribution applies to a
    particular parameter: the first pattern (in the order given) that
    matches the I{entire} formatted parameter name wins.
    """

    def __init__(self, guesses):
        """@type guesses: list( (str, L{mcmc_idxr.probdist_c}, C{R}) ) where C{R} is an optional float.
        @param guesses: a list of associations between a pattern and a probability distribution.
            If C{R} exists, it is a float used for probing for function roughness.
            The list is iterated more than once, so it must not be a one-shot iterator.
        @raise ValueError: if an entry does not have two or three components.
        """
        try:
            for guess in guesses:
                if not (2 <= len(guess) <= 3):
                    raise ValueError("Wrong length=%d" % len(guess))
                pat = guess[0]
                assert isinstance(pat, str)
        except ValueError as x:
            raise ValueError(*(('Guesses needs to be a 2-tuple', str(x), 'guesses:',) + tuple(guesses)))
        self.guesses = []
        for guess in guesses:
            pat, dist = guess[:2]
            probedist = guess[2] if len(guess) == 3 else None
            # The pattern is wrapped in a non-capturing group before anchoring.
            # Without the group, a top-level alternation such as "a|b" would
            # become "^a|b$", which anchors only one end of each alternative
            # and therefore matches keys that merely start with "a".
            self.guesses.append((re.compile('^(?:%s)$' % pat), dist, probedist))

    def _find(self, fkey):
        """Return the first C{(pattern, callable, probedist)} entry whose pattern matches C{fkey}.
        @raise KeyError: if no pattern matches.
        """
        for entry in self.guesses:
            if entry[0].match(fkey):
                return entry
        raise KeyError("No match found for '%s'" % fkey)

    def __getitem__(self, fkey):
        """Look up a string against the stored list of patterns.  Return the corresponding callable.

        @param fkey: this takes the key formatted as a string, not as a tuple.
        @type fkey: str
        @rtype: whatever is the middle component of a C{guesses} tuple.
            This is normally an instance of L{mcmc_idxr.probdist_c}.
        @raise KeyError: if no pattern matches.
        """
        return self._find(fkey)[1]

    def getprobe(self, fkey):
        """Look up the optional third component associated with the first matching pattern.

        @param fkey: this takes the key formatted as a string, not as a tuple.
        @type fkey: str
        @rtype: whatever you get from the final component of a C{guesses} tuple
            (C{None} if the tuple had only two components).
            This is normally a float.
        @raise KeyError: if no pattern matches.
        """
        return self._find(fkey)[2]
class IndexKeyError(KeyError):
    """Raised when a parameter name is not known to an indexer.
    It is a L{KeyError}, so existing C{except KeyError} handlers catch it.
    """

    def __init__(self, *a):
        # All arguments are handed, unchanged, to KeyError.
        super(IndexKeyError, self).__init__(*a)
class index_base(object):
    """Shared machinery for the parameter indexers in this module.

    An indexer keeps C{self.map}, a dictionary from parameter names
    (tuples of strings) to positions in the vector of adjustable
    parameters, and C{self._p}, a dictionary from parameter names to
    values.   Names that are in C{self._p} but not in C{self.map} are the
    "fixed" parameters.   Subclasses supply the accessors (L{p},
    L{p_lower}, L{p_range}, L{get_prms}); here they only raise.
    """

    # Translates key components to and from the in-file format (see _fmt/_unfmt).
    _e = g_encode.encoder(notallowed=' %,')

    def __init__(self, keymap, name=None):
        """@param keymap: mapping from parameter name to index in the parameter vector.
        @type keymap: dict
        @param name: stored as C{self.name}; not otherwise used here.
        """
        assert isinstance(keymap, dict), "Keymap must be a dict, not %s: %s" % (type(keymap), keymap)
        self.name = name
        self.map = keymap
        self.modified = 0
        self._p = {}

    def clear(self):
        """Forget all stored parameter values and reset the modification counter."""
        self.modified = 0
        self._p = {}

    def p(self, *key):
        """Virtual: subclasses return the value of the parameter named C{key}."""
        raise RuntimeError('Virtual Function')

    def p_lower(self, ll, *key):
        """Virtual: subclasses return the parameter named C{key}, constrained by lower limit C{ll}."""
        raise RuntimeError('Virtual Function')

    def p_range(self, lb, ub, *key):
        """Virtual: subclasses return the parameter named C{key}, constrained to C{lb}..C{ub}."""
        raise RuntimeError('Virtual Function')

    def pN(self, num, *key):
        """@return: the C{num} parameters named C{key + ('0',)}, C{key + ('1',)}, ...
        @rtype: list
        """
        values = []
        for i in range(num):
            subkey = key + (str(i),)
            values.append(self.p(*subkey))
        return values

    def pNr(self, num, low, high, *key):
        """Like L{pN}, but each parameter is fetched with C{p_range(low, high, ...)}.
        @rtype: list
        """
        values = []
        for i in range(num):
            subkey = key + (str(i),)
            values.append(self.p_range(low, high, *subkey))
        return values

    def pNl(self, num, low, high, *key):
        """Like L{pN}, but each parameter is fetched with L{p_lower}.

        NOTE(review): C{high} is forwarded positionally to L{p_lower}, whose
        signature is C{(ll, *key)}, so it becomes the first component of the
        key rather than acting as an upper bound.  Confirm whether callers
        depend on that before changing it.
        @rtype: list
        """
        values = []
        for i in range(num):
            subkey = key + (str(i),)
            values.append(self.p_lower(low, high, *subkey))
        return values

    def get_prms(self):
        """@return: a vector of the adjustable parameters
        @rtype: L{numpy.ndarray}
        """
        raise RuntimeError('Virtual Function')

    def columns(self):
        """@return: the formatted names of the adjustable parameters, placed
            at the position that C{self.map} assigns to each of them.
        @rtype: list(str)
        """
        names = [None] * len(self.map)
        for (key, position) in self.map.items():
            names[position] = self._fmt(key)
        return names

    def n(self):
        """@return: the number of adjustable parameters.
        @rtype: int
        @note: to include the number of fixed parameters, you can take C{len(self._p)}.
        """
        return len(self.map)

    @classmethod
    def _unfmt(cls, s):
        """This converts the in-file format of a parameter name (a comma separated string)
        into the in-memory format (a tuple).
        @param s: parameter name made of comma-separated components
        @type s: str
        @return: tuple-formatted parameter name
        @rtype: tuple(str)
        """
        return tuple(cls._e.back(piece) for piece in s.split(','))

    @classmethod
    def _fmt(cls, key):
        """This converts the in-memory format of a parameter name (a tuple of strings)
        into the in-file format (a comma separated string).
        @type key: tuple(str)
        @rtype: str
        """
        for kq in key:
            assert isinstance(kq, str), "Key is not entirely composed of strings: %s" % repr(key)
        encoded = [cls._e.fwd(kq) for kq in key]
        return ','.join(encoded)

    def i_k_val(self):
        """@return: a sorted list of C{(index, key, value)} for every adjustable parameter.
        @rtype: list(tuple)
        """
        return sorted((v, k, self.p(*k)) for (k, v) in self.map.items())

    def comment(self, c):
        """Does nothing here; a subclass may record the comment C{c} somewhere."""
        pass

    def hash(self):
        """@return: a hash of the name-to-index mapping (values are not included).
        @rtype: int
        """
        pairs = sorted(self.map.items())
        return hash(tuple(pairs))

    def __call__(self):
        raise RuntimeError("Virtual method - returns a float")

    def logdens(self, x):
        raise RuntimeError("Virtual method - returns a float")

    def set_all_fixed(self, kv):
        """This sets a whole dictionary's worth of fixed parameters.
        @param kv: parameters to set
        @type kv: dict
        """
        self._p.update(kv)
        self.modified += 1

    def set_fixed(self, v, *k):
        """You can have "fixed" parameters that are not updated by the MCMC optimizer.
        This sets one of them.
        @param v: the value
        @param k: the key.
        """
        self._p[k] = v
        self.modified += 1

    def get_fixed(self):
        """Get only the fixed parameters, i.e. those with a stored value but no
        entry in C{self.map}.
        @rtype: a dictionary from parameter name to value.
        """
        return dict((k, v) for (k, v) in self._p.items() if k not in self.map)

    def get_all(self):
        """@return: a shallow copy of all stored parameter values, fixed or adjustable.
        @rtype: dict
        """
        return self._p.copy()
class MissingParameterError(KeyError):
    """Raised by L{guess} when a previously unseen parameter is requested
    after the set of parameters has been frozen.
    """

    def __init__(self, *s):
        # All arguments are handed, unchanged, to KeyError.
        super(MissingParameterError, self).__init__(*s)
235 236 -def _get_doc(fcn):
237 try: 238 doc = fcn.__doc__ 239 # doc = getattr(fcn, '__doc__') 240 except AttributeError: 241 doc = str(fcn) 242 return doc
243
class sampler(object):
    """Interface of the objects that L{guess} uses to draw parameter values.
    Both methods are virtual and raise L{RuntimeError} here.
    """

    def __init__(self):
        pass

    def __call__(self):
        """Virtual: a subclass returns a float."""
        raise RuntimeError("Virtual method - returns a float")

    def logdens(self, x):
        """Virtual: a subclass returns a float."""
        raise RuntimeError("Virtual method - returns a float")
class guess(index_base):
    """An indexer that discovers parameter names as they are requested and
    draws a random starting value for each one.   See L{__init__}.
    """

    def __init__(self, guesses, fp=None, name=''):
        """Create an object that produces the initial guesses at parameter values.
        Feed it a list of patterns that match parameter names, each with a function to
        produce a random guess for that parameter.   When you request the value of
        a parameter via C{self.p} or similar, it will look up the parameter name,
        find the first regular expression that matches it, and call the associated
        function to generate a random value for that parameter.

        @type guesses: list(tuple(str, L{sampler}))
        @param guesses: a list of C{(pattern, sampler)} tuples, where
            C{pattern} is an (uncompiled) regular expression.
            The C{pattern} selects for which parameters C{sampler} will be used.
            It must match the entire parameter name; it will be surrounded by "^$" before
            use.
            C{sampler} is a function (or class) that produces samples from some
            probability distribution.  C{Sampler} is
            called without arguments, once for each parameter, for each guess that is needed.
        @param fp: an open file which can be used as a log, or None.
        @type fp: a writeable C{file} object.
        @param name: a name for the indexer.  It is just stored away, and may be accessed
            as the "name" attribute.
        @type name: anything.
        """
        index_base.__init__(self, {}, name=name)
        self.guesses = PriorProbDist(guesses)
        self.fp = fp
        if fp is not None:
            fp.writelines('# ------------------------------------\n')
        self.active = True
        # Maps each adjustable key to the callable that draws its value.
        self.pmaker = {}
        self.header = {}
        self.user = None
        self._usage = dictops.dict_of_sets()
        if fp is not None and hasattr(fp, 'name'):
            self.header['mapfname'] = fp.name

    def comment(self, c):
        """Write C{c} to the log file (if there is one) as a '#' comment line."""
        if self.fp is not None:
            self.fp.writelines(['#', str(c), '\n'])
            self.fp.flush()

    def clear(self):
        """Forget all stored values and the record of which user accessed which key."""
        index_base.clear(self)
        self._usage = dictops.dict_of_sets()

    def get_prms(self):
        """This L{freeze}s your set of parameters so no more will be added,
        and then returns the current guess for all the adjustable parameters.
        @return: a vector of the adjustable parameters
        @rtype: L{numpy.ndarray}
        @note: all stored values are discarded afterwards, including any set with
            L{set_fixed} or L{set_prm}, so the next call draws fresh values.
        """
        self.freeze()
        a = numpy.zeros((self.n(),))
        for (k, i) in self.map.items():
            try:
                a.itemset(i, self._p[k])
            except KeyError:
                # No value is stored for this key: draw a new one.
                a.itemset(i, self.pmaker[k]())
        self._p = {}
        return a

    def _wrerr(self, text, key):
        """Write an 'ERR:' line naming C{key} to the log file, if there is one."""
        if self.fp is not None:
            fmtd = self._fmt(key)
            self.fp.writelines('ERR: %s %s\n' % (text, fmtd))
            self.fp.flush()

    # Used by _write to shorten each leading directory of a path to "/" plus one character.
    _fnsp = re.compile('/([^/]*)(?=/)')

    def _write(self, key, extra_stackdrop=0):
        """Log a newly seen key, its index and the call stack that requested it.
        @param extra_stackdrop: number of additional innermost stack frames to omit
            (two are always dropped).
        """
        if self.fp is not None:
            import inspect
            import avio
            fmtd = self._fmt(key)
            stack = []
            for (fro, fname, lnum, funcname, context, icontext) in inspect.stack()[2+extra_stackdrop:]:
                del fro
                fname = self._fnsp.sub(lambda g: g.group(0)[:2], fname)
                stack.append('%s(%s:%d)' % (funcname, fname, lnum))
            tmp = {'key': fmtd, 'index': self.map.get(key, "__fixed__"),
                   'stack': ','.join(stack),
                   'name': self.name
                   }
            self.fp.writelines(avio.concoct(tmp) + '\n')
            self.fp.flush()

    def _compute(self, key, fixer, extra_stackdrop):
        """Return the value for C{key}, drawing and recording one if necessary.

        @param key: the parameter name
        @type key: tuple(str)
        @param fixer: C{None}, or a function applied to a value re-drawn for an
            already-known adjustable key.
        @param extra_stackdrop: passed on to L{_write}.
        @raise MissingParameterError: if C{key} is new and the indexer is frozen.
        @raise KeyError: if C{key} is new and no pattern matches it.
        """
        try:
            return self._p[key]
        except KeyError:
            pass
        # Only the dictionary lookup is guarded, so that a KeyError raised inside
        # the sampler or the fixer is not mistaken for an unknown parameter.
        try:
            maker = self.pmaker[key]
        except KeyError:
            maker = None
        if maker is not None:
            tmp = maker()
            # p() passes fixer=None; calling it unconditionally raised TypeError.
            if fixer is not None:
                tmp = fixer(tmp)
        else:
            if not self.active:
                raise MissingParameterError(key)
            fkey = self._fmt(key)
            tmpmaker = self.guesses[fkey]
            # NOTE(review): the first draw for a new key is not passed through
            # fixer, so it may lie outside the requested range -- confirm intent.
            tmp = tmpmaker()
            # Only samplers with a logdens become adjustable parameters; other
            # keys get a value but no entry in self.map.
            if tmpmaker.logdens is not None:
                self.map[key] = len(self.pmaker)
                self.pmaker[key] = tmpmaker
            if self.user is not None:
                self._usage.add(key, self.user)
            self._write(key, extra_stackdrop)
        self._p[key] = tmp
        return tmp

    def __repr__(self):
        s = []
        for k in self.map.keys():
            # pmaker is a dict keyed like self.map; describe this key's sampler.
            # The value may be absent (e.g. after get_prms), hence .get().
            s.append('%s=%s=%s' % (','.join(k), _get_doc(self.pmaker.get(k)),
                                   str(self._p.get(k))))
        s.sort()
        s.insert(0, "active" if self.active else "frozen")
        return '<guess: %s>' % '\n\t'.join(s)

    def set_user(self, user):
        """Set the user that will be recorded (see L{usage}) for keys first seen from now on."""
        self.user = user

    def usage(self):
        """Tells you which users access each key.
        @returns: L{dictops.dict_of_sets} (approximately a dict)
            indexed by key and mapping to a set of users.
            e.g. {('some', 'key'): set(['user1', 'user2']), ...}.
        @rtype: dict(tuple: set(str))
        """
        return self._usage

    def p(self, *key):
        """@type key: tuple(str)
        @return: the parameter indexed by C{*key}, drawing a guess for it if needed.
        """
        return self._compute(key, None, 0)

    pc = p

    def p_lower(self, ll, *key):
        """A range, with a lower limit."""
        def fix_ll(x):
            # Reflect values below the limit back above it.
            if x < ll:
                return 2*ll - x
            return x
        fix_ll.__doc__ = "lower limit %g" % ll
        return self._compute(key, fix_ll, 0)

    pc_lower = p_lower

    def p_range(self, lb, ub, *key):
        """A range, folding at the ends."""
        assert ub > lb, "Bad range: %s<%s for %s" % (str(ub), str(lb), str(key))

        def fix_range(v):
            if not (lb <= v <= ub):
                r = ub - lb
                # Reflect repeatedly at both ends: a triangle wave of period 2*r.
                tmp = (v-lb) % (2*r)
                v = lb + (tmp if tmp <= r else 2*r-tmp)
            return v
        fix_range.__doc__ = "range from %g to %g" % (lb, ub)
        return self._compute(key, fix_range, 0)

    def p_periodic(self, ub, *key):
        """A range with periodic boundary conditions 0 to ub."""
        assert ub > 0

        def fix_range(v):
            return v % ub
        fix_range.__doc__ = "periodic from 0 to %g" % ub
        return self._compute(key, fix_range, 0)

    def freeze(self):
        """This call means that all parameter names have been seen.
        Any novel names in calls to C{self.p()} will then raise an error.
        """
        self.active = False

    def set_prm(self, v, *key):
        """This function is used to adjust parameters in non-trivial ways.
        It stores C{v} as the current value for C{key}.
        """
        self.modified += 1
        self._p[key] = v
        if Debug:
            die.info('set_prm on guess: %g %s' % (v, str(key)))
class index(index_base):
    """An indexer over an established name-to-index map and a concrete
    vector of adjustable parameters, plus optional fixed parameters.
    """

    # Default filter for toDisplay(): matched against each formatted key.
    reprkeys = re.compile('.')

    def __init__(self, keymap, p=None, name='', fixed=None):
        """@param keymap: mapping from parameter name to index in C{p}.
        @type keymap: dict
        @param p: the vector of adjustable parameters, or None to supply it
            later via L{set_prms}.
        @param name: stored as C{self.name}.
        @param fixed: fixed parameters, or None.
        @type fixed: dict
        """
        assert isinstance(keymap, dict)
        index_base.__init__(self, keymap, name=name)
        assert p is None or len(p) == len(keymap)
        self._prms = None
        self._hashmod = -1
        if p is not None:
            self.set_prms(p)
        if fixed is not None:
            self.set_all_fixed(fixed)

    def hash(self):
        """@return: a hash of the mapping, the stored values and the modification count.
            The result is cached until C{self.modified} changes.
        @rtype: int
        """
        if self._hashmod == self.modified:
            return self._hashcache
        phash = hash(tuple(self._p.values()))
        self._hashcache = hash((index_base.hash(self), phash, self.modified))
        self._hashmod = self.modified
        return self._hashcache

    def clear(self):
        """Forget all values, the parameter vector and the cached hash."""
        index_base.clear(self)
        self._prms = None
        self._hashmod = -1

    def set_prms(self, prm):
        """This sets the vector of parameters to be used and modified.
        Note that p_lower() can affect the prm argument!   No copy is
        made and this side effect is needed for the problem_definition.fixer()
        function.

        This function should only be called after the map is established.
        """
        assert prm.shape == (len(self.map),)
        self.modified = 0
        self._hashmod = -1
        self._prms = prm
        for (k, v) in self.map.items():
            self._p[k] = prm.item(v)

    def set_prm(self, v, *key):
        """This sets one element of the parameters vector.
        This function should only be called after the map is established; it cannot be used to set fixed parameters.
        @raise IndexKeyError: if C{key} is not an adjustable parameter.
        """
        try:
            self._prms.itemset(self.map[key], v)
        except KeyError:
            raise IndexKeyError(key, self)
        self._p[key] = v
        self.modified += 1

    def p(self, *key):
        """
        @type key: tuple(str)
        @return: the parameter indexed by C{*key}.
        @rtype: float
        @raise IndexKeyError: if there is no such parameter.
        """
        try:
            return self._p[key]
        except KeyError:
            raise IndexKeyError(key, self)

    def p_lower(self, ll, *key):
        """A parameter with a lower limit; values below C{ll} are reflected above it.
        @note: this may have side effects on the array passed to L{set_prms}.
        @raise IndexKeyError: if there is no such parameter.
        @raise KeyError: when you apply it to a fixed parameter below the specified C{ll}.
        """
        try:
            v = self._p[key]
        except KeyError:
            raise IndexKeyError(key, self)
        if v < ll:
            v = 2.0*ll - v
            self.modified += 1
            self._prms.itemset(self.map[key], v)
            self._p[key] = v
        return v

    def p_range(self, lb, ub, *key):
        """A range, folding at the ends.
        @note: this may have side effects on the array passed to L{set_prms}.
        @raise IndexKeyError: if there is no such parameter.
        @raise KeyError: when you apply it to a fixed parameter outside specified range C{(lb,ub)}.
        """
        try:
            v = self._p[key]
        except KeyError:
            raise IndexKeyError(key, self)
        if not (lb <= v <= ub):
            assert ub > lb
            r = ub - lb
            # Offsets are measured from lb so that the fold is a reflection at
            # each end (lb-e -> lb+e, ub+e -> ub-e), as in guess.p_range.
            # Measuring from ub sent ub+e to lb+e instead.
            tmp = (v-lb) % (2*r)
            v = lb + (tmp if tmp <= r else 2*r-tmp)
            self.modified += 1
            self._p[key] = v
            self._prms.itemset(self.map[key], v)
        return v

    def p_periodic(self, ub, *key):
        """A range with periodic boundary conditions 0 to ub.
        @note: this may have side effects on the array passed to L{set_prms}.
        @raise IndexKeyError: if there is no such parameter.
        @raise KeyError: when you apply it to a fixed parameter outside the range 0 to C{ub}.
        """
        try:
            v = self._p[key]
        except KeyError:
            raise IndexKeyError(key, self)
        if not (0 <= v < ub):
            assert ub > 0
            v = v % ub
            self.modified += 1
            self._p[key] = v
            self._prms.itemset(self.map[key], v)
        return v

    pc = p
    pc_lower = p_lower

    def get_prms(self):
        """@return: a vector of the adjustable parameters
        @rtype: L{numpy.ndarray}
        """
        return self._prms

    def toDisplay(self, reprkeys=None):
        """Format the adjustable parameters whose names match C{reprkeys}.

        @param reprkeys: a compiled regular expression, matched against each
            formatted key; defaults to C{self.reprkeys}.
        @rtype: str
        @note: This prints the values without the folding or wrapping that is done by
            L{p_range}, L{p_periodic} or L{p_lower}.  So, you should not
            be surprised or disturbed if you see a value outside the expected range.
            If you extract the values normally, using the p_* accessor functions,
            all will be well.
        """
        if reprkeys is None:
            reprkeys = self.reprkeys
        tmp = []
        for k in self.map.keys():
            ifk = index._fmt(k)
            if reprkeys.match(ifk):
                tmp.append("%s=%.3f" % (ifk, self.p(*k)))
        if len(tmp) < self.n():
            # Some parameters were filtered out.
            tmp.append('...')
        tmp.sort()
        return ', '.join(tmp)

    def __repr__(self):
        return '<Index:%s %s>' % (self.name, self.toDisplay())

    def items(self):
        """@return: a sequence of all the parameters, fixed or adjustable.
        @rtype: sequence( (key, value) ) where C{key} could be anything hashable,
            but is usually a tuple of strings, and C{value} is a L{float}.
        """
        return self._p.items()

    def keys(self):
        """@return: the names of all the parameters, fixed or adjustable."""
        return self._p.keys()
def default_mapfp(nmprefix='m'):
    """Open a new, time-stamped map file for writing in the directory C{MAPS}.
    The directory is created via C{gpkmisc.makedirs}.

    @param nmprefix: prefix for the file name.
    @type nmprefix: str
    @return: the open file.
    """
    import time
    import gpkmisc
    gpkmisc.makedirs('MAPS')
    stamp = time.strftime('%Y%m%d-%H%M%S.map')
    path = os.path.join('MAPS', '%s%s' % (nmprefix, stamp))
    return open(path, 'w')
def guess_to_indexer(g):
    """Convert a L{guess} into an L{index} with the same map, name,
    current adjustable values and fixed parameters.

    @type g: L{guess}
    @rtype: L{index}
    @note: this calls C{g.get_prms()}, which freezes C{g} and discards its stored values.
    """
    assert isinstance(g, guess)
    # The fixed parameters must be collected first: get_prms() empties g._p,
    # so calling get_fixed() after it always returned an empty dictionary.
    fixed = g.get_fixed()
    return index(g.map, g.get_prms(), name=g.name, fixed=fixed)
def reindex(q, ref):
    """This remaps the parameters of C{q} so that they are in
    the same order as in C{ref}.
    @rtype: indexer
    @return: an indexer that contains the same data as C{q}, but shares its parameter mapping with C{ref}.
    @note: The parameters in C{q} must be a superset of the adjustable
        parameters in C{ref}.   Extra parameters in C{q} will be
        carried over as fixed parameters.
    """
    rk = set(ref.keys())
    qk = set(q.keys())
    if qk != rk:
        # Report how the two sets of names differ; this is not an error in itself.
        die.info("Common parameters: %s" % str(rk.intersection(qk)))
        die.info("Parameters only in q: %s" % str(qk-rk))
        die.info("Parameters only in ref: %s" % str(rk-qk))
    adjustable = numpy.zeros((ref.n(),))
    for (key, position) in ref.map.items():
        adjustable[position] = q.p(*key)
    leftover = dict((key, value) for (key, value) in q.items() if key not in ref.map)
    return index(ref.map, p=adjustable, name=q.name, fixed=leftover)
class index_counted(index):
    """This is a special-purpose wrapped version of L{index}.
    Its only function is to make sure that all parameters were
    used, in addition to the normal checking that you do not use
    any parameters that are undefined.

    Use it exactly as you would use L{index}, but it is slower.

    NOTE(review): only the map, the adjustable values and the name of the
    wrapped indexer are copied; its fixed parameters are not.
    """

    def __init__(self, idxr):
        """@param idxr: the indexer to wrap; its C{map}, C{get_prms()} and C{name} are used."""
        index.__init__(self, idxr.map, idxr.get_prms(), name=idxr.name)
        # Every key that has been requested, whether or not the lookup succeeded.
        self.used = set()

    def p(self, *key):
        self.used.add(key)
        return index.p(self, *key)

    def p_lower(self, ll, *key):
        self.used.add(key)
        return index.p_lower(self, ll, *key)

    def p_range(self, lb, ub, *key):
        self.used.add(key)
        return index.p_range(self, lb, ub, *key)

    def p_periodic(self, ub, *key):
        self.used.add(key)
        return index.p_periodic(self, ub, *key)

    def confirm_all_used(self):
        """This can be used to check that we aren't using an indexer with
        a problem that is too small (for instance, if the indexer was read in
        from a log file).
        @raise IndexKeyError: if there are some keys in the index that have not
            been used.
        """
        sm = set(self.map.keys())
        # Test for missing keys only.  self.used can also hold keys that are not
        # in the map (a key is recorded before its lookup can fail), and a plain
        # inequality test then reported "Unused keys" with an empty set.
        unused = sm - self.used
        if unused:
            die.info("used keys: %s" % self.used)
            die.info("keys in map: %s" % self.map.keys())
            raise IndexKeyError("Unused keys:", unused)