[frames] | no frames]

# Source Code for Module gmisclib.mcmc_indexclass

```  1  """This module provides several classes to manage the parameters of an algorithm,
2  particularly so that you can run the mcmc.py optimizer on it.  Essentially, the main
3  problem that it solves is how to keep track of human-readable names assigned to
4  components of a vector of parameters.
5
6  You have an L{index_base} object C{i}, and you can get a parameter from it
7  like this:  C{i.p("velocity", "car", "x")}.   That would get out a number that
8  might correspond to the x-velocity of a car.   That's the human-readable side,
9  but you can also get a vector of all the parameters via C{i.get_prms()} and
10  get a mapping between names and indices into that vector via C{i.map}.
11
12  You can also have parameters that do not appear in the vector.
14  These are known as "fixed" parameters; their intent is to let you
15  fix some of the parameters of an optimization, and let the optimizer
16  play with the remainder.
17
18  Use L{guess} first to establish the mapping, write it to a file
19  and guess reasonable starting values for an optimization.
20  Then, use L{index} to get parameter values corresponding to names.
21  """
22
23  import re
24  import os
25
26  import numpy
27
28  import die
29  import dictops
30  import g_encode
31
# When true, guess.set_prm reports every call through die.info.
Debug = False
33
34
35
class PriorProbDist(object):
    """This is a map from a regular expression to a callable object.   It's used to find
    which element of the PriorProbDist applies to a particular parameter.

    Patterns are tried in the order given; the first one that matches the
    whole (string-formatted) parameter name wins.
    """

    def __init__(self, guesses):
        """@type guesses: list( (str, L{mcmc_idxr.probdist_c}, C{R}) )   where C{R} is an optional float.
        @param guesses: a list (or any iterable) of associations between a pattern and a
                probability distribution.
                If C{R} exists, it is a float used for probing for function roughness.
                Each pattern must match the entire parameter name: it is compiled
                as C{^(?:pattern)$}, so a top-level alternation such as C{a|b}
                is anchored as a whole.
        @raise ValueError: if an entry is not a 2- or 3-tuple.
        """
        # Materialize once: the entries are walked twice below, and a one-shot
        # iterator would otherwise be exhausted by the validation pass.
        guesses = list(guesses)
        try:
            for guess in guesses:
                if not (2 <= len(guess) <= 3):
                    raise ValueError("Wrong length=%d" % len(guess))
                assert isinstance(guess[0], str)
        except ValueError as x:
            raise ValueError(*(('Guesses needs to be a 2- or 3-tuple', str(x), 'guesses:',) + tuple(guesses)))
        self.guesses = []
        for guess in guesses:
            pat, fcn = guess[:2]
            probedist = guess[2] if len(guess) == 3 else None
            # The non-capturing group keeps both anchors applying to the whole
            # pattern even when it contains a top-level "|".
            self.guesses.append((re.compile('^(?:%s)$' % pat), fcn, probedist))

    def _lookup(self, fkey):
        """Return the first stored C{(compiled_pattern, fcn, probedist)} entry matching C{fkey}.
        @raise KeyError: if no pattern matches.
        """
        for entry in self.guesses:
            if entry[0].match(fkey):
                return entry
        raise KeyError("No match found for '%s'" % fkey)

    def __getitem__(self, fkey):
        """Look up a string against the stored list of patterns.  Return the corresponding callable.

        @param fkey: this takes the key formatted as a string, not as a tuple.
        @type fkey: str
        @rtype: whatever is the middle component of a C{guesses} tuple.  This is normally an instance of L{mcmc_idxr.probdist_c}.
        @raise KeyError: if no pattern matches.
        """
        return self._lookup(fkey)[1]

    def getprobe(self, fkey):
        """Return the probe value associated with the first pattern matching C{fkey}.

        @param fkey: this takes the key formatted as a string, not as a tuple.
        @type fkey: str
        @rtype: whatever you get from the final component of a C{guesses} tuple
                (normally a float), or C{None} if that entry was a 2-tuple.
        @raise KeyError: if no pattern matches.
        """
        return self._lookup(fkey)[2]
88
89
class IndexKeyError(KeyError):
    """A L{KeyError} raised when a parameter key is not known to an indexer.

    All positional arguments are forwarded unchanged to L{KeyError}.
    """

    def __init__(self, *a):
        super(IndexKeyError, self).__init__(*a)
93
94
class index_base(object):
    """Common base for the parameter indexers L{guess} and L{index}.

    Parameter names ("keys") are tuples of strings.  C{self.map} maps the key
    of each I{adjustable} parameter to its position in the parameter vector;
    C{self._p} holds the current value of every known parameter, whether
    adjustable or fixed.  Subclasses provide L{p}, L{p_lower}, L{p_range}
    and L{get_prms}.
    """

    # Encoder applied to every key component by _fmt (fwd) and _unfmt (back).
    # The characters in notallowed include the ',' that _fmt uses to join
    # components.
    _e = g_encode.encoder(notallowed=' %,')

    def __init__(self, keymap, name=None):
        """
        @param keymap: map from adjustable-parameter key to vector index.
        @type keymap: dict
        @param name: an arbitrary label, stored as C{self.name}.
        """
        assert isinstance(keymap, dict), "Keymap must be a dict, not %s: %s" % (type(keymap), keymap)
        self.map = keymap
        self._p = {}
        self.name = name
        self.modified = 0   # incremented whenever a parameter value is changed

    def clear(self):
        """Forget every parameter value and reset the modification counter."""
        self._p = {}
        self.modified = 0

    def p(self, *key):
        """Return the value of the parameter named C{key}.  Abstract."""
        raise NotImplementedError('Virtual Function')

    def p_lower(self, ll, *key):
        """Return parameter C{key}, constrained to be at least C{ll}.  Abstract."""
        raise NotImplementedError('Virtual Function')

    def p_range(self, lb, ub, *key):
        """Return parameter C{key}, constrained to C{[lb, ub]}.  Abstract."""
        raise NotImplementedError('Virtual Function')

    def pN(self, num, *key):
        """Return C{num} parameters whose keys are C{key} extended by C{'0'}, C{'1'}, ..."""
        return [self.p(*(key + (str(i),))) for i in range(num)]

    def pNr(self, num, low, high, *key):
        """Like L{pN}, but each element is fetched with C{p_range(low, high, ...)}."""
        return [self.p_range(low, high, *(key + (str(i),))) for i in range(num)]

    def pNl(self, num, low, high, *key):
        """Like L{pN}, but each element is fetched with C{p_lower(low, ...)}.

        @param high: ignored.  L{p_lower} takes only a lower limit; the
                argument is kept so the signature parallels L{pNr} and
                existing callers keep working.
        """
        # Forwarding `high` to p_lower (as before) made it the first key
        # component instead of a limit.
        return [self.p_lower(low, *(key + (str(i),))) for i in range(num)]

    def get_prms(self):
        """@return: a vector of the adjustable parameters
        @rtype: L{numpy.ndarray}
        """
        raise NotImplementedError('Virtual Function')

    def columns(self):
        """@return: the formatted names of the adjustable parameters, ordered by vector index.
        @rtype: list(str)
        """
        cols = [None] * len(self.map)
        for (k, i) in self.map.items():
            cols[i] = self._fmt(k)
        return cols

    def n(self):
        """@return: the number of adjustable parameters.
        @rtype: int
        @note: to include the number of fixed parameters, you can take C{len(self._p)}.
        """
        return len(self.map)

    @classmethod
    def _unfmt(cls, s):
        """This converts the in-file format of a parameter name (a comma separated string)
        into the in-memory format (a tuple).
        @param s: parameter name made of comma-separated components
        @type s: str
        @return: tuple-formatted parameter name
        @rtype: tuple(str)
        """
        return tuple([cls._e.back(q) for q in s.split(',')])

    @classmethod
    def _fmt(cls, key):
        """Convert an in-memory key (tuple of str) to its in-file form:
        the encoded components joined by commas.  The inverse of L{_unfmt}.
        @note: asserts that every component is a str.
        """
        for kq in key:
            assert isinstance(kq, str), "Key is not entirely composed of strings: %s" % repr(key)
        return ','.join([cls._e.fwd(kq) for kq in key])

    def i_k_val(self):
        """@return: C{(index, key, value)} for every adjustable parameter, sorted by index.
        @rtype: list(tuple(int, tuple(str), float))
        """
        rv = [(v, k, self.p(*k)) for (k, v) in self.map.items()]
        rv.sort()
        return rv

    def comment(self, c):
        """Hook for logging a comment; the base implementation ignores C{c}."""
        pass

    def hash(self):
        """Return a hash of the key-to-index map only (parameter values are not included)."""
        return hash(tuple(sorted(self.map.items())))

    def __call__(self):
        raise NotImplementedError("Virtual method - returns a float")

    def logdens(self, x):
        raise NotImplementedError("Virtual method - returns a float")

    def set_all_fixed(self, kv):
        """This sets a whole dictionary's worth of fixed parameters.
        @param kv: parameters to set
        @type kv: dict
        """
        self._p.update(kv)
        self.modified += 1

    def set_fixed(self, v, *k):
        """You can have "fixed" parameters that are not updated by the MCMC optimizer.
        This sets one of them.
        @param v: the value
        @param k: the key.
        """
        self._p[k] = v
        self.modified += 1

    def get_fixed(self):
        """Get only the fixed parameters (those with a value but no vector index).
        @rtype: a dictionary from parameter name to value.
        """
        return dict((k, v) for (k, v) in self._p.items() if k not in self.map)

    def get_all(self):
        """@return: a shallow copy of all parameter values, adjustable and fixed.
        @rtype: dict
        """
        return self._p.copy()
229
230
class MissingParameterError(KeyError):
    """A L{KeyError} raised by a frozen L{guess} when asked for a parameter it has not seen.

    Positional arguments are forwarded unchanged to L{KeyError}.
    """

    def __init__(self, *s):
        super(MissingParameterError, self).__init__(*s)
234
235
236 -def _get_doc(fcn):
237          try:
238                  doc = fcn.__doc__
239                  # doc = getattr(fcn, '__doc__')
240          except AttributeError:
241                  doc = str(fcn)
242          return doc
243
244
class sampler(object):
    """Abstract base for the samplers used by L{guess}.

    Calling a sampler with no arguments should draw one random value;
    C{logdens(x)} presumably returns the log probability density at C{x}.
    Subclasses must override both.  (L{guess} treats a sampler whose
    C{logdens} attribute is C{None} as producing a fixed parameter.)
    """

    def __init__(self):
        pass

    def __call__(self):
        # NotImplementedError subclasses RuntimeError, so handlers written for
        # the old RuntimeError still catch it.
        raise NotImplementedError("Virtual method - returns a float")

    def logdens(self, x):
        raise NotImplementedError("Virtual method - returns a float")
254
255
class guess(index_base):
    """Indexer that discovers parameters as they are requested and draws initial guesses for them.

    While active, asking for an unknown parameter key looks up its formatted
    name in a L{PriorProbDist}, draws a value from the matching sampler and,
    if a log file was given, records the key and the calling stack.  A
    sampler whose C{logdens} attribute is not C{None} makes the parameter
    adjustable (it is added to C{self.map}); otherwise the parameter is fixed.
    After L{freeze} (which L{get_prms} calls), unknown keys raise
    L{MissingParameterError}.
    """

    def __init__(self, guesses, fp=None, name=''):
        """Create an object that produces the initial guesses at parameter values.
        Feed it a list of patterns that match parameter names, each with a function to
        produce a random guess for that parameter.   When you request the value of
        a parameter via C{self.p} or similar, it will look up the parameter name,
        find the first regular expression that matches it, and call the associated
        function to generate a random value for that parameter.

        @type guesses: list(tuple(str, L{sampler}))
        @param guesses: a list of C{(pattern, sampler)} tuples (an optional third
                element, a float used for probing roughness, is also accepted;
                see L{PriorProbDist}), where
                C{pattern} is an (uncompiled) regular expression.
                The C{pattern} selects for which parameters C{sampler} will be used.
                It must match the entire parameter name; it is anchored with
                "^" and "$" before use.
                C{sampler} is a function (or class) that produces samples from some
                probability distribution.  C{sampler} is
                called without arguments, once for each parameter, for each guess that is needed.
        @param fp: an open file which can be used as a log, or None.
        @type fp: a writeable C{file} object.
        @param name: a name for the indexer.   It is just stored away, and may be accessed
                as the "name" attribute.
        @type name: anything.
        """
        index_base.__init__(self, {}, name=name)
        self.guesses = PriorProbDist(guesses)
        self.fp = fp
        if fp is not None:
            # Write a separator line into the log at construction.
            fp.writelines('# ------------------------------------\n')
        self.active = True      # set to False by freeze(): no new parameters after that
        self.pmaker = {}        # key -> the sampler that produced that parameter
        # NOTE(review): listing line 287 is missing from this listing; restore
        # it from the real source file.
        self.user = None        # set by set_user()
        self._usage = dictops.dict_of_sets()
        if fp is not None and hasattr(fp, 'name'):
            # NOTE(review): the body of this `if` (listing line 291) is missing
            # from this listing; restore it from the real source file.

    def comment(self, c):
        """Write C{c} to the log file (if any) as a C{#}-prefixed line."""
        if self.fp is not None:
            self.fp.writelines(['#', str(c), '\n'])
            self.fp.flush()

    def clear(self):
        """Forget cached values and usage records.
        @note: C{self.map}, C{self.pmaker} and the active/frozen state are kept.
        """
        index_base.clear(self)
        self._usage = dictops.dict_of_sets()

    def get_prms(self):
        """This L{freeze}s your set of parameters so no more will be added,
        and then returns the current guess for all the adjustable parameters.
        @return: a vector of the adjustable parameters
        @rtype: L{numpy.ndarray}
        @note: afterwards C{self._p} is emptied, so the next request for a
                parameter draws a fresh value; this also discards fixed values,
                so L{get_fixed} returns C{{}} after this call.
        """
        self.freeze()
        a = numpy.zeros((self.n(),))
        for (k,i) in self.map.items():
            # NOTE(review): ndarray.itemset was removed in NumPy 2.0;
            # `a[i] = ...` is the portable equivalent.
            try:
                a.itemset(i, self._p[k])
            except KeyError:
                # No cached value: draw a fresh, unfolded sample from its maker.
                a.itemset(i, self.pmaker[k]() )
        self._p = {}
        return a

    def _wrerr(self, text, key):
        """Write an C{ERR:} line containing C{text} and the formatted C{key} to the log file, if any."""
        if self.fp is not None:
            fmtd = self._fmt(key)
            self.fp.writelines('ERR: %s %s\n' % (text, fmtd))
            self.fp.flush()

    # Matches "/component" when another "/" follows, i.e. each directory
    # component of a path; _write shortens each one to "/" + its first char.
    _fnsp = re.compile('/([^/]*)(?=/)')

    def _write(self, key, extra_stackdrop=0):
        """Append a log record for C{key}: formatted key, vector index (or
        C{"__fixed__"}), an abbreviated call stack, and the indexer name,
        serialized with C{avio.concoct}.  Does nothing without a log file.

        @param extra_stackdrop: number of additional innermost stack frames to
                omit, beyond this frame and its caller.
        """
        if self.fp is not None:
            import inspect
            import avio
            fmtd = self._fmt(key)
            stack = []
            for (fro, fname, lnum, funcname, context, icontext) in inspect.stack()[2+extra_stackdrop:]:
                # Drop the frame reference right away (presumably to avoid
                # keeping frames alive through reference cycles).
                del fro
                fname = self._fnsp.sub(lambda g: g.group(0)[:2], fname)
                stack.append( '%s(%s:%d)' % (funcname, fname, lnum) )
            tmp = {'key': fmtd, 'index': self.map.get(key, "__fixed__"),
                    'stack': ','.join(stack),
                    'name': self.name
                    }
            self.fp.writelines(avio.concoct(tmp) + '\n')
            self.fp.flush()

    def _compute(self, key, fixer, extra_stackdrop):
        """Return the value of parameter C{key}, creating it if necessary.

        Order of resolution: a cached value in C{self._p}; else a fresh draw
        from the key's existing sampler passed through C{fixer}; else (only
        while active) a first draw from the sampler whose pattern matches the
        formatted key, which also registers the key and logs it.
        @raise MissingParameterError: if the key is new and the indexer is frozen.
        @raise KeyError: if the key is new and no pattern matches it.
        """
        try:
            return self._p[key]
        except KeyError:
            pass
        try:
            # NOTE(review): p() passes fixer=None, so once a sampler exists for
            # key (e.g. after get_prms() emptied self._p) this raises TypeError.
            tmp = fixer(self.pmaker[key]())
        except KeyError:
            if not self.active:
                raise MissingParameterError(key)

            fkey = self._fmt(key)
            tmpmaker = self.guesses[fkey]
            # NOTE(review): `fixer` is not applied to this first draw, so
            # p_lower/p_range/p_periodic can return an out-of-range value here.
            tmp = tmpmaker()
            if tmpmaker.logdens is not None:
                # NOTE(review): pmaker also counts fixed parameters, so indices
                # skip values (and can reach n()) once a fixed parameter has been
                # created; len(self.map) looks intended -- confirm.
                self.map[key] = len(self.pmaker)
            self.pmaker[key] = tmpmaker
            if self.user is not None:
                # NOTE(review): the body of this `if` (listing line 366) is
                # missing from this listing; restore it from the real file.
            self._write(key, extra_stackdrop)
        self._p[key] = tmp
        return tmp

    def __repr__(self):
        """List each adjustable parameter as C{key=doc=value}, preceded by the active/frozen state."""
        s = []
        for k in self.map.keys():
            # NOTE(review): self.pmaker is a dict and has no `.doc`, so this
            # raises AttributeError; `_get_doc(self.pmaker[k])` looks intended.
            # self._p[k] also raises KeyError after get_prms() empties _p.
            s.append('%s=%s=%s' % (','.join(k), self.pmaker.doc,
                                    str(self._p[k])))
        s.sort()
        s.insert(0, "active" if self.active else "frozen")
        return '<guess: %s>' % '\n\t'.join(s)

    def set_user(self, user):
        """Store C{user}; presumably used to attribute parameter accesses in
        L{usage} (the consuming line in _compute is missing from this listing).
        """
        self.user = user

    def usage(self):
        """Tells you which users access each key.
        @returns: L{dictops.dict_of_sets} (approximately a dict)
                presumably indexed by parameter key and mapping to a set of users,
                e.g. {('some', 'key'): set(['user1', 'user2']), ...}.
        @rtype: dict(tuple: set(str))
        """
        return self._usage

    def p(self, *key):
        """Return parameter C{key}, drawing an initial guess the first time it is requested."""
        return self._compute(key, None, 0)

    pc = p

    def p_lower(self, ll, *key):
        """A range, with a lower limit.

        Values below C{ll} are reflected about it (C{2*ll - x}) when they are
        redrawn from an existing sampler; see _compute for the first draw.
        """
        def fix_ll(x):
            if x < ll:
                return 2*ll - x
            return x
        fix_ll.__doc__ = "lower limit %g" % ll
        return self._compute(key, fix_ll, 0)

    pc_lower = p_lower

    def p_range(self, lb, ub, *key):
        """A range, folding at the ends: values outside C{[lb, ub]} are
        reflected back into it (see _compute for when folding is applied).
        """
        assert ub > lb, "Bad range: %s<%s for %s" % (str(ub), str(lb), str(key))
        # NOTE(review): this string follows the assert, so it is a no-op
        # statement rather than the docstring.
        """A range, folding at the ends."""
        def fix_range(v):
            if not ( lb <= v <= ub):
                r = ub - lb
                tmp = (v-lb) % (2*r)
                v = lb + (tmp if tmp<=r else 2*r-tmp)
            return v
        fix_range.__doc__ = "range from %g to %g" % (lb, ub)
        return self._compute(key, fix_range, 0)

    def p_periodic(self, ub, *key):
        """A range with periodic boundary conditions 0 to ub (values are taken modulo C{ub})."""
        assert ub > 0
        def fix_range(v):
            return v % ub
        fix_range.__doc__ = "periodic from 0 to %g" % ub
        return self._compute(key, fix_range, 0)

    def freeze(self):
        """This call means that all parameter names have been seen.
        Any novel names in calls to C{self.p()} will then raise an error.
        """
        self.active = False

    def set_prm(self, v, *key):
        """This function is used to adjust parameters in non-trivial ways.
        It stores C{v} as the cached value of C{key} and bumps C{self.modified}.
        """
        self.modified += 1
        self._p[key] = v
        if Debug:
            die.info('set_prm on guess: %g %s' % (v, str(key)))
449
450
451
452
453
454
class index(index_base):
    """Indexer for a known set of parameters, backed by a parameter vector.

    C{self.map} maps each adjustable parameter key to its position in the
    vector C{self._prms}; C{self._p} mirrors those values by key and also
    holds any fixed parameters.  Looking up an unknown key raises
    L{IndexKeyError}.
    """

    # Default key filter for toDisplay(); '.' matches every non-empty formatted key.
    reprkeys = re.compile('.')

    def __init__(self, keymap, p=None, name='', fixed=None):
        """
        @param keymap: map from adjustable-parameter key (a tuple of str) to vector index.
        @type keymap: dict
        @param p: optional parameter vector, installed (without copying) by L{set_prms}.
        @param name: stored as C{self.name}.
        @param fixed: optional dict of fixed parameters, installed by L{set_all_fixed}.
        @type fixed: dict or None
        """
        assert isinstance(keymap, dict)
        index_base.__init__(self, keymap, name=name)
        assert p is None or len(p) == len(keymap)
        self._prms = None
        self._hashmod = -1      # value of self.modified when _hashcache was computed
        if p is not None:
            self.set_prms(p)
        if fixed is not None:
            self.set_all_fixed(fixed)

    def hash(self):
        """Hash of the key map, the current parameter values and C{self.modified}.

        The value is cached and recomputed only after C{self.modified} changes.
        """
        if self._hashmod == self.modified:
            return self._hashcache
        phash = hash(tuple(self._p.values()))
        self._hashcache = hash((index_base.hash(self), phash, self.modified))
        self._hashmod = self.modified
        return self._hashcache

    def clear(self):
        """Forget all values (adjustable and fixed) and the parameter vector."""
        index_base.clear(self)
        self._prms = None
        self._hashmod = -1

    def set_prms(self, prm):
        """This sets the vector of parameters to be used and modified.
        Note that p_lower(), p_range() and p_periodic() can modify the prm
        argument in place!  No copy is made and this side effect is needed
        for the problem_definition.fixer() function.

        This function should only be called after the map is established.
        """
        assert prm.shape == (len(self.map),)
        self.modified = 0
        self._hashmod = -1
        self._prms = prm
        for (k, v) in self.map.items():
            self._p[k] = prm.item(v)

    def _value(self, key):
        """Return the current value of C{key}.
        @raise IndexKeyError: if C{key} is neither adjustable nor fixed.
        """
        try:
            return self._p[key]
        except KeyError:
            raise IndexKeyError(key, self)

    def _store(self, key, v):
        """Write C{v} for adjustable parameter C{key} into the vector and C{self._p}.

        The map lookup happens first, so a non-adjustable C{key} raises
        L{KeyError} without modifying anything.
        """
        pos = self.map[key]
        # Plain indexing replaces ndarray.itemset, which NumPy 2.0 removed.
        self._prms[pos] = v
        self._p[key] = v
        self.modified += 1

    def set_prm(self, v, *key):
        """This sets one element of the parameters vector.
        This function should only be called after the map is established; it cannot be used to set fixed parameters.
        @raise IndexKeyError: if C{key} is not an adjustable parameter.
        """
        try:
            self._store(key, v)
        except KeyError:
            raise IndexKeyError(key, self)

    def p(self, *key):
        """
        @type key: tuple(str)
        @return: the parameter indexed by C{*key}.
        @rtype: float
        @raise IndexKeyError: if C{key} is unknown.
        """
        return self._value(key)

    def p_lower(self, ll, *key):
        """Return parameter C{key}, reflecting it about C{ll} if it is below C{ll}.
        @note: this may have side effects on the array passed to set_prms().
        @raise KeyError: when you apply it to a fixed parameter below the specified C{ll}.
        @raise IndexKeyError: if C{key} is unknown.
        """
        v = self._value(key)
        if v < ll:
            v = 2.0*ll - v
            self._store(key, v)
        return v

    def p_range(self, lb, ub, *key):
        """Return parameter C{key}, folded (reflected at the ends) into C{[lb, ub]}.
        @note: this may have side effects on the array passed to set_prms().
        @raise KeyError: when you apply it to a fixed parameter outside specified range C{(lb,ub)}.
        @raise IndexKeyError: if C{key} is unknown.
        """
        v = self._value(key)
        if not (lb <= v <= ub):
            assert ub > lb
            r = ub - lb
            # Folding has period 2r measured from lb: offsets in [0, r] map
            # to themselves and offsets in (r, 2r) are mirrored back.  (The
            # offset used to be measured from ub, which put values reflected
            # at one end next to the other end.)
            tmp = (v-lb) % (2*r)
            v = lb + (tmp if tmp<=r else 2*r-tmp)
            self._store(key, v)
        return v

    def p_periodic(self, ub, *key):
        """A range with periodic boundary conditions 0 to ub.
        @note: this may have side effects on the array passed to set_prms().
        @raise KeyError: when you apply it to a fixed parameter outside the range C{[0, ub)}.
        @raise IndexKeyError: if C{key} is unknown.
        """
        v = self._value(key)
        if not (0 <= v < ub):
            assert ub > 0
            v = v % ub
            self._store(key, v)
        return v

    pc = p
    pc_lower = p_lower

    def get_prms(self):
        """@return: a vector of the adjustable parameters (the array itself, not a copy)
        @rtype: L{numpy.ndarray}
        """
        return self._prms

    def toDisplay(self, reprkeys=None):
        """Return a one-line summary C{"key=value, ..."} of the adjustable parameters.

        @param reprkeys: compiled regular expression; only parameters whose
                formatted name it matches (via C{.match}) are shown.  Defaults
                to C{self.reprkeys}, which shows every parameter.
        @return: the selected entries in sorted order, followed by C{'...'}
                if some adjustable parameters were left out.
        @note: This prints the values without the folding or wrapping that is done by
                L{p_range}, L{p_periodic}, or L{p_lower}.   So, you should not
                be surprised or disturbed if you see a value outside the expected range.
                If you extract the values normally, using the p_* accessor functions,
                all will be well.
        """
        if reprkeys is None:
            reprkeys = self.reprkeys
        shown = []
        for k in self.map.keys():
            fk = index._fmt(k)
            if reprkeys.match(fk):
                shown.append("%s=%.3f" % (fk, self.p(*k)))
        shown.sort()
        # Add the ellipsis after sorting so it always marks the end.
        if len(shown) < self.n():
            shown.append('...')
        return ', '.join(shown)

    def __repr__(self):
        return '<Index:%s %s>' % (self.name, self.toDisplay())

    def items(self):
        """@return: a sequence of all the parameters, fixed or adjustable.
        @rtype: sequence( (key, value) ) where C{key} could be anything hashable,
                but is usually a tuple of strings, and C{value} is a L{float}.
        """
        return self._p.items()

    def keys(self):
        """@return: the keys of all parameters, fixed or adjustable."""
        return self._p.keys()
622
623
def default_mapfp(nmprefix='m'):
    """Open a new, timestamped map file for writing in the C{MAPS} directory.

    The directory is ensured via C{gpkmisc.makedirs} (presumably like
    C{mkdir -p}).  The file is named C{<nmprefix>YYYYMMDD-HHMMSS.map}.
    @param nmprefix: prefix for the file name.
    @return: the open, writable file object; the caller is responsible for closing it.
    """
    import time
    import gpkmisc
    mapdir = 'MAPS'
    gpkmisc.makedirs(mapdir)
    stamp = time.strftime('%Y%m%d-%H%M%S.map')
    target = os.path.join(mapdir, '%s%s' % (nmprefix, stamp))
    return open(target, 'w')
633
634
def guess_to_indexer(g):
    """Convert a L{guess} into an L{index} holding the same parameters.

    The fixed parameters are collected I{before} calling C{g.get_prms()}:
    L{guess.get_prms} empties C{g._p}, which is where L{get_fixed} reads them
    from, so the old argument order always produced an empty C{fixed}.
    @type g: L{guess}
    @rtype: L{index}
    """
    assert isinstance(g, guess)
    fixed = g.get_fixed()
    prms = g.get_prms()
    return index(g.map, prms, name=g.name, fixed=fixed)
638
639
def reindex(q, ref):
    """This remaps the parameters of C{q} so that they are in
    the same order as in C{ref}.

    If the two indexers know different sets of keys, the common keys and the
    keys unique to each side are reported through C{die.info}.
    @rtype: L{index}
    @return: an indexer that contains the same data as C{q}, but shares its parameter mapping with C{ref}.
    @note: The parameters in C{q} must be a superset of the adjustable
            parameters in C{ref}.  Extra parameters in C{q} will be
            carried over as fixed parameters.
    """
    qk = set(q.keys())
    rk = set(ref.keys())
    if qk != rk:
        die.info("Common parameters: %s" % str(rk.intersection(qk)))
        die.info("Parameters only in q: %s" % str(qk-rk))
        die.info("Parameters only in ref: %s" % str(rk-qk))
    adjustable = numpy.zeros((ref.n(),))
    for k, pos in ref.map.items():
        adjustable[pos] = q.p(*k)
    extras = dict((k, v) for (k, v) in q.items() if k not in ref.map)
    return index(ref.map, p=adjustable, name=q.name, fixed=extras)
663
664
665
666
class index_counted(index):
    """This is a special-purpose wrapped version of L{index}.
    Its only function is to make sure that all parameters were
    used, in addition to the normal checking that you do not use
    any parameters that are undefined.

    Use it exactly as you would use L{index}, but it is slower.

    @note: the C{pc} and C{pc_lower} aliases are inherited from L{index},
            where they are bound to C{index.p} and C{index.p_lower}, so they
            bypass the overrides below.
    """

    def __init__(self, idxr):
        """Copy the key map, parameter vector and name of indexer C{idxr}.
        @note: when C{idxr} is an L{index}, its vector is shared, not copied
                (L{index.get_prms} returns it and L{index.set_prms} stores it).
        """
        # NOTE(review): idxr's fixed parameters (idxr.get_fixed()) are not
        # passed on, so they are unavailable here -- confirm this is intended.
        index.__init__(self, idxr.map, idxr.get_prms(), name=idxr.name)
        self.used = set()   # keys recorded as used; compared by confirm_all_used()

    def p(self, *key):
        # NOTE(review): listing line 681 is missing here; presumably it records
        # `key` in self.used -- restore it from the real source file.
        return index.p(self, *key)

    def p_lower(self, ll, *key):
        # NOTE(review): listing line 685 is missing here (presumably the usage
        # record, as in p); restore it from the real source file.
        return index.p_lower(self, ll, *key)

    def p_range(self, lb, ub, *key):
        # NOTE(review): listing line 689 is missing here (presumably the usage
        # record, as in p); restore it from the real source file.
        return index.p_range(self, lb, ub, *key)

    def p_periodic(self, ub, *key):
        # NOTE(review): listing line 693 is missing here (presumably the usage
        # record, as in p); restore it from the real source file.
        return index.p_periodic(self, ub, *key)

    def confirm_all_used(self):
        """This can be used to check that we aren't using an indexer with
        a problem that is too small (for instance, if the indexer was read in
        from a log file).
        @raise IndexKeyError: if the set of used keys differs from the keys in
                the map.  This happens when some map keys have not been used, and
                also when C{self.used} holds keys that are not in the map.
        """
        sm = set(self.map.keys())
        if self.used != sm:
            die.info("used keys: %s" % self.used)
            die.info("keys in map: %s" % self.map.keys())
            raise IndexKeyError("Unused keys:", sm-self.used)
708
<!--
expandto(location.href);
// -->

```

 Generated by Epydoc 3.0.1 on Thu Sep 22 04:25:11 2011 http://epydoc.sourceforge.net