Package gmisclib :: Module mcmc_logger
[frames] | [no frames]

Source Code for Module gmisclib.mcmc_logger

  1  import re 
  2   
  3  import numpy 
  4   
  5  from gmisclib import die 
  6  from gmisclib import fiatio 
  7  from gmisclib import dictops 
  8  from gmisclib import gpkmisc 
  9  from gmisclib import nice_hash 
 10  from gmisclib import mcmc_helper as mcmcP 
 11   
 12   
class WildNumber(ValueError):
    """Signals that a logged value (logp or a parameter) is outside the sane range.

    logger_A.Add raises this after it has written the offending record.
    """

    def __init__(self, *args):
        super(WildNumber, self).__init__(*args)
16 17
class logger_c(mcmcP.logger_template):
    """This class is called from the Markov-Chain stepper
    and writes the log file.

    It decides *when* an iteration should be logged (see L{add});
    subclasses decide *what* gets written by overriding L{Add}.
    Subclasses' L{Add} is expected to update C{nlogged},
    C{last_logged_erg} and C{last_logged_iter}.
    """

    def __init__(self, logwriter, iterstep=100, ergstep=0.1):
        """@param logwriter: the object that actually writes the log; this class
            calls its C{flush}, C{close}, C{comment}, C{header} and C{headers} methods.
        @param iterstep: minimum number of iterations between routine log entries.
        @param ergstep: amount by which the accumulated ergodicity measure
            must exceed its value at the previous entry before a routine entry is made.
        """
        assert iterstep >= 0
        assert ergstep >= 0.0
        self.lg = logwriter
        self.iterstep = iterstep
        self.ergstep = ergstep
        self.nlogged = 0
        # Sum of stepper.ergodic() weighted by iterations, accumulated in add().
        self.erg = 0.0
        self.last_add_iter = 0
        self.last_logged_erg = None
        self.last_logged_iter = None
        # Best logp seen on steps that were not logged for another reason.
        self.best = None
        # Label recorded with routine (interval-driven) log entries.
        self.reason = 'sampled'

    def set_logstep(self, iterstep, ergstep, reason='sampled'):
        """Change the logging interval.

        @param iterstep: new minimum iteration spacing; must be positive.
        @param ergstep: new ergodicity spacing, or None meaning 0.0.
        @param reason: label recorded with routine log entries from now on.
        """
        self.reason = reason
        assert iterstep > 0
        self.iterstep = iterstep
        if ergstep is None:
            self.ergstep = 0.0
        else:
            assert ergstep >= 0.0
            self.ergstep = ergstep
        self.lg.flush()

    def add(self, stepperInstance, iter):
        """Called every step: this decides whether
        data should be logged.

        The stepper's C{ergodic()} value, weighted by the number of iterations
        since the previous call, is accumulated into C{self.erg}. Then L{Add} is
        called when:
          - C{last_logged_iter} or C{last_logged_erg} is None (reason C{'initial'});
          - at least C{iterstep} iterations have passed since the last entry
            AND C{self.erg} exceeds its value at that entry by more than
            C{ergstep} (reason C{self.reason});
          - otherwise, when the current logp beats every logp previously seen
            on such otherwise-unlogged steps (reason C{'best'}).
        """
        self.erg += stepperInstance.ergodic() * (iter - self.last_add_iter)
        self.last_add_iter = iter
        reason = []
        if self.last_logged_iter is None or self.last_logged_erg is None:
            reason.append('initial')
        elif (iter >= self.last_logged_iter + self.iterstep
                and self.erg > self.last_logged_erg + self.ergstep):
            reason.append(self.reason)
        else:
            clp = stepperInstance.current().logp_nocompute()
            if self.best is None or clp > self.best:
                self.best = clp
                reason.append('best')
        if reason:
            self.Add(stepperInstance, iter, reason=','.join(reason))

    def Add(self, stepperInstance, iter, reason=None):
        """Write one log entry; must be overridden by subclasses.

        @raise NotImplementedError: always (a subclass of RuntimeError, so
            callers that catch RuntimeError still work).
        """
        raise NotImplementedError("Virtual Function")

    def close(self):
        """Close the underlying log writer."""
        self.lg.close()

    def comment(self, comment):
        """Write a comment to the log and flush it."""
        self.lg.comment(comment)
        self.lg.flush()

    def header(self, k, v):
        """Write a single header key/value pair."""
        self.lg.header(k, v)

    def headers(self, hdict):
        """Write a dictionary of header key/value pairs."""
        self.lg.headers(hdict)

    def _restart(self, message):
        """Clear the logged count and ergodicity accumulator, forcing the next
        L{add} to log an 'initial' entry, and record C{message} as a comment.
        """
        self.nlogged = 0
        self.erg = 0.0
        self.last_logged_erg = None
        self.comment(message)

    def reset(self):
        """Restart logging statistics after the stepper has been reset."""
        self._restart("Reset")

    def set_trigger_point(self):
        """Restart logging statistics and mark the end of run_to_bottom in the log."""
        self._restart("run_to_bottom finished")
103 104 105 106
class logger_A(logger_c):
    """Logger that writes one flat record per logged iteration via C{logwriter.datum}.

    Each record holds C{T}, C{resetid}, C{iter}, optionally C{why_logged},
    C{logp} (formatted with two decimals) and one C{P.<name>} field per parameter.
    Out-of-range values are reported by raising L{WildNumber} *after* the
    record has been written and flushed.
    """

    def __init__(self, logwriter, iterstep=100, ergstep=0.1):
        """logwriter is a fiatio.writer or avio.writer or similar instance.
        """
        logger_c.__init__(self, logwriter, iterstep=iterstep, ergstep=ergstep)
        self.HUGEp = 1e10  # a parameter p is wild unless -HUGEp <= p <= HUGEp
        self.HUGEl = 1e10  # logp is wild unless -HUGEl < logp < HUGEl

    def ok_logp(self, logpval):
        """True when C{logpval} is strictly within (-HUGEl, HUGEl); NaN fails."""
        return -self.HUGEl < logpval < self.HUGEl

    def ok_prms(self, prms):
        """Elementwise boolean array: True where -HUGEp <= prms <= HUGEp (NaN fails)."""
        return numpy.logical_and(numpy.greater_equal(self.HUGEp, prms),
                                 numpy.greater_equal(prms, -self.HUGEp))

    def fmt(self, prmname):
        """Render a parameter name (an iterable of parts) as a comma-joined string."""
        return ','.join(['%s' % q for q in prmname])

    def get_map(self, stepper):
        """Return the parameter-name -> array-index map of the stepper's current position."""
        return stepper.current().pd.idxr.map

    def Add(self, stepperInstance, iter, reason=None):
        """Write one record describing the stepper's current position.

        @raise WildNumber: after the record is written, if logp is None or
            outside the L{ok_logp} range, or any parameter fails L{ok_prms}.
        """
        self.nlogged += 1
        self.last_logged_erg = self.erg
        self.last_logged_iter = iter
        current = stepperInstance.current()
        prms = current.prms()
        # No trailing comma: that would silently turn logp into a 1-tuple.
        logp = current.logp_nocompute()
        tmp = {'T': stepperInstance.acceptable.T(),
               'resetid': stepperInstance.reset_id(),
               'iter': iter,
               }
        if reason is not None:
            tmp['why_logged'] = reason
        wild = []
        if logp is None:
            # Cannot be formatted with %.2f; report it but write the rest.
            wild.append("logp=%s" % str(logp))
        else:
            # Always record logp, even when it is out of range, so the log
            # shows the value that triggered the WildNumber.
            tmp['logp'] = '%.2f' % logp
            if not self.ok_logp(logp):
                wild.append("logp=%s" % str(logp))
        okprm = self.ok_prms(prms)
        pmap = self.get_map(stepperInstance)
        if not okprm.all():
            rmap = dictops.rev1to1(pmap)
            for (i, isok) in enumerate(okprm):
                if not isok:
                    wild.append('p%d(%s)=%s' % (i, self.fmt(rmap[i]), prms[i]))
        for (nm, i) in pmap.items():
            tmp['P.%s' % self.fmt(nm)] = prms[i]
        self.lg.datum(tmp)
        self.lg.flush()
        if wild:
            raise WildNumber('; '.join(wild))
162 163 164 165
class logger_N(logger_c):
    """Logger that writes entries through a newstem2.newlogger-style C{add} method."""

    def __init__(self, logwriter, iterstep=100, ergstep=0.1):
        """logwriter is really a newstem2.newlogger instance.
        """
        logger_c.__init__(self, logwriter, iterstep=iterstep, ergstep=ergstep)

    def Add(self, stepperInstance, iter, reason=None):
        """Actually write a set of parameters into the log file.

        Writes a C{'UID'} entry for the current parameters. If the indexer has a
        C{get_fixed} method returning a non-empty mapping, a C{'fixed'} entry is
        also written, with the fixed values ordered by sorted key. Prints a
        progress line to stdout and flushes the writer.
        """
        current = stepperInstance.current()
        logp = current.logp_nocompute()
        print('Logging %s iter= %s' % (logp, iter))
        self.nlogged += 1
        self.last_logged_erg = self.erg
        self.last_logged_iter = iter
        idxr = current.pd.idxr
        T = stepperInstance.acceptable.T()
        resetid = stepperInstance.reset_id()
        self.lg.add('UID', current.prms(),
                    idxr.map,
                    logp,
                    iter,
                    extra={'T': T, 'resetid': resetid},
                    reason=reason
                    )

        if hasattr(idxr, 'get_fixed'):
            fixed = idxr.get_fixed()
            if fixed:
                mp = {}
                pc = numpy.zeros((len(fixed),))
                for (i, (k, v)) in enumerate(sorted(fixed.items())):
                    # Plain indexing: ndarray.itemset() was removed in NumPy 2.0.
                    pc[i] = v
                    mp[k] = i
                self.lg.add('fixed', pc, mp,
                            logp,
                            iter,
                            extra={'T': T, 'resetid': resetid},
                            reason=reason
                            )
        self.lg.flush()
# Re-export fiatio's parse-error exception so callers of this module can catch
# it without importing fiatio; lti_c.parse and read_multisample raise it.
BadFormatError = fiatio.BadFormatError
def read_currentlist(fname, trigger=None):
    """Starting from a file name, read in a log file.
    @param fname: name of a (possibly compressed) log file.
    @type fname: str
    @param trigger: where to start in the log file: a regular expression
        (matched with C{re.match}) against comments. Data are ignored until a
        comment matches; None means start at the beginning. See L{readiter}.
    @type trigger: str or None
    @rtype: C{tuple(list(dict(str:str)), dict(str:str))}
    @return: a list of data lines and header information.
        A data line is a dictionary, containing parameters and other information from one iteration.
        Each data line also gets a C{'__line'} entry: the 1-based index of the
        C{fiatio.readiter} item it came from. If a trigger is given but never
        matches, a warning is issued and the list is empty.
    @note: this is very similar to L{newstem2.newlogger.read_currentlist} except that that one returns a more
        complicated data structure that has one more level of indirection.
    """
    if trigger is not None:
        print('# Trigger="%s"' % trigger)
        trigger = re.compile(trigger)
    currentlist = []
    lasthdr = {}
    lineno = 1
    fd = gpkmisc.open_compressed(fname)
    try:
        for (hdr, data, comments) in fiatio.readiter(fd):
            if comments and trigger is not None:
                for c in comments:
                    if trigger.match(c):
                        print('# Trigger match')
                        trigger = None
                        # Stop here: trigger is now None, so calling
                        # trigger.match() on the next comment would crash.
                        break
            if trigger is None and data is not None:
                data['__line'] = lineno
                currentlist.append(data)
            lasthdr.update(hdr)
            lineno += 1
    finally:
        # Release the (possibly decompressing) file handle.
        close = getattr(fd, 'close', None)
        if close is not None:
            close()
    if trigger is not None:
        die.warn('Trigger set but never fired: %s' % fname)
    return (currentlist, lasthdr)
247 248
class lti_c(object):
    """Converts log-file lines (dicts) into L{newstem2.indexclass.index} objects.

    The mapping between parameter names and array positions is built from the
    first line parsed and reused for every later line.
    """

    def __init__(self):
        # Parameter-name <-> array-index map; filled on the first parse().
        self.h = nice_hash.simple()
        from newstem2 import indexclass
        self.IC = indexclass

    def parse(self, d):
        """This constructs an L{newstem2.indexclass.index} object from a line from a log file.
        @param d: a line from a log file specifying information for an iteration.
            Keys starting with C{'P.'} hold parameter values; C{'uid'}, if present,
            becomes the index's name.
        @type d: dict(str:str)
        @return: the parameters for that iteration.
        @rtype: L{newstem2.indexclass.index}
        @raise BadFormatError: if the line lacks a parameter seen on the first line.
        @note: a parameter not seen on the first line makes C{self.h.get_image}
            fail (read_multisample expects C{nice_hash.NotInHash}).
        """
        unfmt = self.IC.index_base._unfmt
        PREFIX = 'P.'
        lp = len(PREFIX)
        # Decode each parameter name once; values stay as strings until stored.
        named = [(unfmt(k[lp:]), v) for (k, v) in d.items() if k.startswith(PREFIX)]
        if not self.h.n:
            # Build a map of parameter names to integers.
            for (name, _) in named:
                self.h.add(name)
        prms = numpy.zeros((self.h.n,))
        for (name, v) in named:
            # Plain indexing: ndarray.itemset() was removed in NumPy 2.0.
            prms[self.h.get_image(name)] = float(v)
        if len(named) < self.h.n:
            raise BadFormatError("Missing parameter")
        return self.IC.index(self.h.map(), p=prms, name=d.get('uid'))
279 280 281
class NoDataError(ValueError):
    """Signals that a log file yielded no data lines (raised by read_multisample)."""

    def __init__(self, *args):
        super(NoDataError, self).__init__(*args)
285 286 287
def read_multisample(fname, Nsamp=10, tail=0.0, trigger=None, reader=read_currentlist):
    """Read in a log file: it gives you multiple samples
    from the log. In other words, it provides a time-series of the changes
    to the parameters.

    @param fname: name of the log file; passed to C{reader}.
    @param Nsamp: maximum number of samples to return. They are spaced evenly
        through the selected part of the data.
    @param tail: Where to start in the file?
        C{tail=0} means that you start at the trigger or beginning.
        C{tail=1-epsilon} means you take just the last tiny bit of the file.
        The value is clamped to [0, 1-0.5/N], where N is the number of data
        lines, so at least one line is always selected.
    @param trigger: passed through to C{reader}.
    @param reader: A function that sucks in the data from the log file, yielding a list of
        iterations and an overall header.
    @type reader: function(str, trigger) -> (list(dict(str:str)), dict(str:str))
    @return: the header, one parsed indexer per sample, and each sample's logp.
    @rtype: (dict(str:str), list(indexer), list(float))
    @raise NoDataError: when the log file is empty.
    @raise BadFormatError: when a parameter appears mid-way through the optimization or an iteration cannot be read.
    """
    currentlist, lasthdr = reader(fname, trigger=trigger)
    lc = len(currentlist)
    if lc == 0:
        raise NoDataError(fname)

    # Clamp so indices below are never negative and never reach lc.
    tail = max(0.0, min(tail, 1.0 - 0.5/lc))
    nsamp = min(Nsamp, lc - int(lc*tail))  # How many samples to extract?
    assert nsamp > 0
    lti = lti_c()
    indexers = []
    logps = []
    try:
        for i in range(nsamp):
            f = float(i)/float(nsamp)
            aCurrent = currentlist[int((tail + (1.0-tail)*f)*lc)]
            indexers.append(lti.parse(aCurrent))
            logps.append(float(aCurrent['logp']))
    except nice_hash.NotInHash as x:
        # File name first, then line; '%s' because a custom reader may omit '__line'.
        raise BadFormatError("Unexpected parameter %s appearing on %s:%s"
                             % (x, fname, aCurrent.get('__line', '?')))
    except BadFormatError as x:
        raise BadFormatError("%s at %s:%s" % (x, fname, aCurrent.get('__line', '?')))
    return (lasthdr, indexers, logps)
328