Package gmisclib :: Module mcmc_mpi

Source Code for Module gmisclib.mcmc_mpi

# -*- coding: utf-8 -*-
"""This is a helper module to make use of mcmc.py and mcmc_big.py.
It allows you to conveniently run a Monte Carlo simulation of any
kind until it converges (L{stepper.run_to_bottom}) or until it
has explored a large chunk of parameter space (L{stepper.run_to_ergodic}).

It also helps you with logging the process.

When run in parallel, each processor runs more-or-less
independently.  However, every few steps, they exchange notes on
their current progress.  If one processor finds an improved vertex, it
is passed on to the other processors via MPI.
"""

import sys
import mpi              # This uses pyMPI.
import random
import numpy
from gmisclib import die
from gmisclib import mcmc
from gmisclib import mcmc_helper as MCH
from gmisclib.mcmc_helper import TooManyLoops, warnevery, logger_template, test
from gmisclib.mcmc_helper import step_acceptor, make_stepper_from_lov

Debug = 0       # Verbosity level for the note() method below.


def rank():
    """Rank of this process within the MPI job.
    (Assumes pyMPI, which exposes the rank as mpi.rank.)
    """
    return mpi.rank


def size():
    """Number of processes in the MPI job (pyMPI's mpi.size)."""
    return mpi.size


class stepper(MCH.stepper):
    def __init__(self, x, maxloops=-1, logger=None, share=None):
        # N.B. the "share" argument is accepted for interface
        # compatibility but is not used here.
        die.info('# mpi stepper rank=%d size=%d' % (rank(), size()))
        assert maxloops == -1, "The MPI stepper requires maxloops=-1."
        MCH.stepper.__init__(self, x, maxloops, logger)

    def reset_loops(self, maxloops=-1):
        assert maxloops == -1
        MCH.stepper.reset_loops(self, maxloops)

    def communicate_hook(self, id):
        """Called every few steps: pass our current vertex around a
        ring of processors and consider adopting the vertex received
        from our neighbor.
        """
        self.note('chook iter=%d' % self.iter, 4)
        if size() > 1:
            self.note('chook active iter=%d' % self.iter, 3)
            c = self.x.current()
            v = c.vec()
            lp = c.logp()
            r = rank()
            s = size()
            self.note('sendrecv from %d to %d' % (r, (r+1)%s), 5)
            # Ring exchange: send our (vector, logp, id) triple to rank
            # r+1 and receive the corresponding triple from rank r-1.
            nv, nlp, nid = mpi.sendrecv(sendobj=(v, lp, id),
                                        dest=(r+1)%s, sendtag=self.MPIID,
                                        source=(r+s-1)%s,
                                        recvtag=self.MPIID
                                        )
            assert nid == id, "ID mismatch: %s vs %s" % (id, nid)
            self.note('communicate succeeded from %s' % nid, 1)
            delta = nlp - lp
            if self.x.acceptable(delta):
                # Step onto the neighbor's vertex.
                q = self.x._set_current(c.new(nv - v, logp=nlp))
                self.note('communicate accepted: %s' % q, 1)
            else:
                self.note('communicate not accepted %g vs %g' % (nlp, lp), 1)
                self.x._set_current(self.x.current())
            sys.stdout.flush()

    MPIID = 1241        # MPI message tag for the vertex exchange.

    def _nc_get_hook(self, nc):
        # Average the per-processor counts over the whole MPI job.
        self.note('_nye pre', 5)
        ncsum = mpi.allreduce(float(nc), mpi.SUM)
        self.note('_nye post', 5)
        return ncsum / float(size())

    def _not_at_bottom(self, xchanged, nchg, es, dotchanged, ndot):
        mytmp = (numpy.sometrue(numpy.less(xchanged, nchg))
                 or es < 1.0 or dotchanged < ndot
                 or self.x.acceptable.T() > 1.5
                 )
        self.note('_nab pre', 5)
        ntrue = mpi.allreduce(int(mytmp), mpi.SUM)
        self.note('_nab post', 5)
        # Keep going as long as at least a quarter of the processors
        # still think they are not at the bottom.
        return ntrue * 4 >= mpi.size

    def synchronize_start(self, id):
        self.synchronize('start ' + id)

    def synchronize_end(self, id):
        self.synchronize('end ' + id)

    def synchronize_abort(self, id):
        # MPI has no good way to handle this case!
        raise RuntimeError("MPI cannot handle an abort.")

    def synchronize(self, id):
        # A barrier, done via a broadcast so that we can also check
        # that every processor is at the same point in the computation.
        self.note('pre join %s' % id, 5)
        rootid = mpi.bcast(id)
        assert rootid == id
        self.note('post join %s' % id, 5)

    def note(self, s, lvl):
        if Debug >= lvl:
            sys.stderr.write('# %s, rank=%d\n' % (s, rank()))
            sys.stderr.flush()

    def size(self):
        return size()

    def rank(self):
        return rank()


def precompute_logp(lop):
    """Does a parallel evaluation of logp for all items in lop.
    Each processor evaluates its own contiguous chunk of lop; the
    chunks are then broadcast so that every processor ends up holding
    the evaluated items for the entire list.
    N.B. if len(lop) is not a multiple of mpi.size, the last few
    items are left unevaluated.
    """
    nper = len(lop) // mpi.size
    r = rank()
    mychunk = lop[r*nper:(r+1)*nper]
    for p in mychunk:
        q = p.logp()
        print 'logp=', q, 'for rank', r
    # Collect everyone's results: rank "root" broadcasts its chunk and
    # each processor splices it into its own copy of the list.
    for root in range(size()):
        nc = mpi.bcast(mychunk, root)
        lop[root*nper:(root+1)*nper] = nc
    mpi.barrier()
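
# Example (a hypothetical sketch, not from the original module): here
# "positions" would be a list of objects with a logp() method, such as
# the position objects that mcmc.bootstepper manipulates.  After the
# call, every processor holds the same fully-evaluated list:
#
#   precompute_logp(positions)
#   best = max(positions, key=lambda p: p.logp())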


def test_():
    def test_logp(x, c):
        # print '#', x[0], x[1]
        return -(x[0] - x[1]**2)**2 + 0.001*x[1]**2

    x = mcmc.bootstepper(test_logp, numpy.array([0.0, 2.0]),
                         numpy.array([[1.0, 0], [0, 2.0]]))
    print 'TEST: rank=', rank()
    thr = stepper(x)
    # nsteps = thr.run_to_bottom(x)
    # print '#nsteps', nsteps
    # assert nsteps < 100
    for i in range(2):
        print 'RTC'
        thr.run_to_change(2)
        print 'RTE'
        thr.run_to_ergodic(1.0)
        print 'DONE'
    thr.close()


if __name__ == '__main__':
    test_()