1 import re
2
3 import numpy
4
5 from gmisclib import die
6 from gmisclib import fiatio
7 from gmisclib import dictops
8 from gmisclib import gpkmisc
9 from gmisclib import nice_hash
10 from gmisclib import mcmc_helper as mcmcP
11
12
16
17
19 """This class is called from the Markov-Chain stepper
20 and writes the log file.
21 """
def __init__(self, logwriter, iterstep=100, ergstep=0.1):
    """Set up the bookkeeping that L{add} uses to decide when to log.

    @param logwriter: object the log is written through; kept as C{self.lg}.
    @param iterstep: iteration spacing compared against in L{add};
        must be non-negative.
    @param ergstep: spacing of the accumulated C{ergodic()} measure
        compared against in L{add}; must be non-negative.
    @raise AssertionError: if either spacing is negative.
    """
    assert iterstep >= 0
    assert ergstep >= 0.0

    # Configuration supplied by the caller.
    self.lg, self.iterstep, self.ergstep = logwriter, iterstep, ergstep
    # Tag attached to entries logged because the spacing thresholds passed.
    self.reason = 'sampled'

    # Running totals.
    self.erg = 0.0
    self.last_add_iter = 0
    self.nlogged = 0

    # Nothing has been logged yet, so there is no "last" or "best" entry.
    self.best = None
    self.last_logged_iter = None
    self.last_logged_erg = None
37
38
def set_logstep(self, iterstep, ergstep, reason='sampled'):
    """Change the logging spacing used by L{add}, then flush the log writer.

    Every argument is checked before any attribute is modified, so a
    rejected call leaves the logger exactly as it was (previously
    C{reason} and C{iterstep} could already have been overwritten when
    an assertion fired).

    @param iterstep: new iteration spacing; must be positive.
    @param ergstep: new spacing of the accumulated C{ergodic()} measure;
        must be non-negative.  C{None} is treated as 0.0.
    @param reason: tag attached to entries that L{add} logs because the
        spacing thresholds were passed.
    @raise AssertionError: if C{iterstep} or C{ergstep} is out of range.
    """
    if ergstep is None:
        ergstep = 0.0

    # Validate first, in the same order as before, and only then mutate.
    assert iterstep > 0
    assert ergstep >= 0.0

    self.reason = reason
    self.iterstep = iterstep
    self.ergstep = ergstep
    self.lg.flush()
49
50
def add(self, stepperInstance, iter):
    """Called every step: decide whether this step should be logged.

    The accumulated measure C{self.erg} grows by C{ergodic()} times the
    number of iterations since the previous call.  The step is then
    passed to L{Add} with one of these reasons:

      - C{'initial'}: nothing has been logged yet;
      - C{self.reason}: at least C{iterstep} iterations and more than
        C{ergstep} of accumulated measure have passed since the last
        logged entry;
      - C{'best'}: neither of the above, but the current log-probability
        exceeds the best one seen in this branch so far.

    Otherwise nothing is logged.

    @param stepperInstance: the stepper; must provide C{ergodic()} and
        C{current().logp_nocompute()}.
    @param iter: current iteration number.
    """
    elapsed = iter - self.last_add_iter
    self.erg += stepperInstance.ergodic() * elapsed
    self.last_add_iter = iter

    tags = []
    never_logged = self.last_logged_iter is None or self.last_logged_erg is None
    if never_logged:
        tags.append('initial')
    else:
        far_enough = iter >= self.last_logged_iter + self.iterstep
        if far_enough and self.erg > self.last_logged_erg + self.ergstep:
            tags.append(self.reason)
        else:
            # The best-so-far value is only examined (and updated) on
            # steps that were not logged for another reason.
            logp = stepperInstance.current().logp_nocompute()
            if self.best is None or logp > self.best:
                self.best = logp
                tags.append('best')

    if not tags:
        return
    self.Add(stepperInstance, iter, reason=','.join(tags))
71
72
def Add(self, stepperInstance, iter, reason=None):
    """Write one entry to the log; concrete loggers must override this.

    L{add} calls this once it has decided that a step is worth logging.

    @param stepperInstance: the stepper whose current state is to be logged.
    @param iter: current iteration number.
    @param reason: comma-separated tag(s) saying why the step is logged.
    @raise NotImplementedError: always.  C{NotImplementedError} is a
        subclass of C{RuntimeError}, so callers that catch
        C{RuntimeError} keep working.
    """
    # Call-form raise is valid in both Python 2 and Python 3.
    raise NotImplementedError("Virtual Function")
75
76
77
80
84
87
90
92 self.nlogged = 0
93 self.erg = 0.0
94 self.last_logged_erg = None
95 self.comment("Reset")
96
97
99 self.nlogged = 0
100 self.erg = 0.0
101 self.last_logged_erg = None
102 self.comment("run_to_bottom finished")
103
104
105
106
def __init__(self, logwriter, iterstep=100, ergstep=0.1):
    """Initialise through C{logger_c} and set the two size thresholds.

    @param logwriter: a fiatio.writer or avio.writer or similar instance.
    @param iterstep: passed unchanged to C{logger_c.__init__}.
    @param ergstep: passed unchanged to C{logger_c.__init__}.
    """
    logger_c.__init__(
        self,
        logwriter,
        iterstep=iterstep,
        ergstep=ergstep,
    )
    # HUGEl is the bound used in the -HUGEl < logpval < HUGEl range check;
    # HUGEp is presumably the matching bound for parameter values
    # (TODO confirm).  Both start at the same value.
    self.HUGEl = self.HUGEp = 1e10
114
115
117 return -self.HUGEl < logpval < self.HUGEl
118
119
122
123
def fmt(self, prmname):
    """Render the components of C{prmname} as one comma-separated string.

    @param prmname: iterable of name components; each one is formatted
        with C{'%s'}.
    @return: the formatted components joined by commas (the empty
        string when there are no components).
    @rtype: str
    """
    pieces = []
    for component in prmname:
        pieces.append('%s' % component)
    return ','.join(pieces)
126
127
130
131
132 - def Add(self, stepperInstance, iter, reason=None):
162
163
164
165
def __init__(self, logwriter, iterstep=100, ergstep=0.1):
    """Initialise through C{logger_c}; no extra attributes are set here.

    @param logwriter: really a newstem2.newlogger instance.
    @param iterstep: passed unchanged to C{logger_c.__init__}.
    @param ergstep: passed unchanged to C{logger_c.__init__}.
    """
    logger_c.__init__(
        self,
        logwriter,
        iterstep=iterstep,
        ergstep=ergstep,
    )
171
172
173
def Add(self, stepperInstance, iter, reason=None):
    """Actually write a set of parameters into the log file.

    Writes a C{'UID'} record with the current parameters.  If the
    indexer provides C{get_fixed()} and it returns a non-empty mapping,
    a second C{'fixed'} record is written with those values, ordered by
    sorted key.  The log writer is flushed afterwards.

    Also updates C{nlogged}, C{last_logged_erg} and C{last_logged_iter},
    which L{add} relies on.

    @param stepperInstance: the stepper; must provide C{current()},
        C{acceptable.T()} and C{reset_id()}.
    @param iter: current iteration number.
    @param reason: tag(s) saying why this step is being logged.
    """
    current = stepperInstance.current()
    logp = current.logp_nocompute()
    # One pre-formatted string prints the same text under the Python 2
    # print statement and the Python 3 print function.
    print('Logging %s iter= %s' % (logp, iter))

    self.nlogged += 1
    self.last_logged_erg = self.erg
    self.last_logged_iter = iter

    idxr = current.pd.idxr
    # Read once and reused for both records, so they cannot disagree.
    temperature = stepperInstance.acceptable.T()
    resetid = stepperInstance.reset_id()

    self.lg.add('UID', current.prms(),
                idxr.map,
                logp,
                iter,
                extra={'T': temperature, 'resetid': resetid},
                reason=reason
                )

    if hasattr(idxr, 'get_fixed'):
        fixed = idxr.get_fixed()
        if fixed:
            mp = {}
            pc = numpy.zeros((len(fixed),))
            for (i, (k, v)) in enumerate(sorted(fixed.items())):
                # Index assignment: ndarray.itemset() no longer exists
                # in NumPy 2.0.
                pc[i] = v
                mp[k] = i
            # A fresh dict, in case the writer keeps or edits the first.
            self.lg.add('fixed', pc, mp,
                        logp,
                        iter,
                        extra={'T': temperature, 'resetid': resetid},
                        reason=reason
                        )
    self.lg.flush()
209
210
211
212
213
214 BadFormatError = fiatio.BadFormatError
215
217 """Starting from a file name, read in a log file.
218 @param fname: name of a (possibly compressed) log file.
219 @type fname: str
220 @param trigger: where to start in the log file. See L{readiter}.
221 @rtype: C{tuple(list(dict(str:str)), dict(str:str))}
222 @return: a list of data lines and header information.
223 A data line is a dictionary, containing parameters and other information from one iteration.
224 @note: this is very similar to L{newstem2.newlogger.read_currentlist} except that returns a more
225 complicated data structure that has one more level of indirection.
226 """
227 if trigger is not None:
228 print '# Trigger="%s"' % trigger
229 trigger = re.compile(trigger)
230 currentlist = []
231 lasthdr = {}
232 lineno = 1
233 for (hdr, data, comments) in fiatio.readiter(gpkmisc.open_compressed(fname)):
234 if comments and trigger is not None:
235 for c in comments:
236 if trigger.match(c):
237 print '# Trigger match'
238 trigger = None
239 if trigger is None and data is not None:
240 data['__line'] = lineno
241 currentlist.append( data )
242 lasthdr.update(hdr)
243 lineno += 1
244 if trigger is not None:
245 die.warn('Trigger set but never fired: %s' % fname)
246 return (currentlist, lasthdr)
247
248
251 self.h = nice_hash.simple()
252 from newstem2 import indexclass
253 self.IC = indexclass
254
256 """This constructs an L{newstem2.indexclass.index} object from a line from a log file.
257 @param d: a line from a log file specifying information for an iteration.
258 @type d: dict(str:str)
259 @return: the parameters for that iteration.
260 @rtype: L{newstem2.indexclass.index}
261 """
262 unfmt = self.IC.index_base._unfmt
263 PREFIX = 'P.'
264 lp = len(PREFIX)
265 if not self.h.n:
266
267 for (k, v) in d.items():
268 if k.startswith(PREFIX):
269 self.h.add(unfmt(k[lp:]))
270 prms = numpy.zeros((self.h.n,))
271 n = 0
272 for (k, v) in d.items():
273 if k.startswith(PREFIX):
274 prms.itemset(self.h.get_image(unfmt(k[lp:])), float(v))
275 n += 1
276 if n < self.h.n:
277 raise BadFormatError, "Missing parameter"
278 return self.IC.index(self.h.map(), p=prms, name=d.get('uid'))
279
280
281
285
286
287
289 """Read in a log file: it gives you multiple samples
290 from the log. In other words, it provides a time-series of the changes
291 to the parameters.
292
293 @param tail: Where to start in the file?
294 C{tail=0} means that you start at the trigger or beginning.
295 C{tail=1-epsilon} means you take just the last tiny bit of the file.
296 @param reader: A function that sucks in the data from the log file, yielding a list of
297 iterations and an overall header.
298 @type reader: function(str, trigger) -> (list(dict(str:str)), dict(str:str))
299 @return:
300 @rtype: (dict(str:str), list(indexer), list(float))
301 @raise NoDatError: when the log file is empty.
302 @raise BadFormatError: when a parameter appears mid-way through the optimization or an iteration cannot be read.
303 """
304 currentlist, lasthdr = reader(fname, trigger=trigger)
305 if len(currentlist) == 0:
306 raise NoDataError, fname
307
308 lc = len(currentlist)
309 tail = min(tail, 1.0 - 0.5/lc)
310 nsamp = min(Nsamp, lc-int(lc*tail))
311
312 assert nsamp > 0
313 lti = lti_c()
314 indexers = []
315 logps = []
316 try:
317 for i in range(nsamp):
318 f = float(i)/float(nsamp)
319 aCurrent = currentlist[int((tail + (1.0-tail)*f)*lc)]
320 assert len(currentlist) > 0
321 indexers.append( lti.parse(aCurrent) )
322 logps.append( float(aCurrent['logp']) )
323 except nice_hash.NotInHash, x:
324 raise BadFormatError, "Unexpected parameter %s appearing on %s:%d" % (x, aCurrent['__line'], fname)
325 except BadFormatError, x:
326 raise BadFormatError, "%s line %s:%d" % (x, aCurrent['__line'], fname)
327 return (lasthdr, indexers, logps)
328