Package gmisclib :: Module HTK_HMM_io
[frames] | [no frames]

Source Code for Module gmisclib.HTK_HMM_io

# -*- coding: utf-8 -*-
"""Read (and write back) HMM definition files produced by HTK.

L{read} parses one or more HTK HMM definition files into an L{hmmset};
the C{write()} methods of the classes in this module turn the parsed
structures back into lists of text lines.
"""

import re
import sys
import numpy
import string
from gmisclib import g_encode


# One level of indentation in the text produced by the write() methods.
I = '  '

class BadFormatError(ValueError):
    """Raised when HTK HMM text cannot be parsed or is inconsistent.

    When a token stream is supplied, the message is prefixed with the
    current line number and file name.
    """

    def __init__(self, t, s):
        """@param t: the L{tokstream} being read, or None when no location is known.
        @param s: description of the problem.
        """
        if t is not None:
            # File-like objects such as StringIO may lack a .name attribute;
            # the error-reporting path must not itself raise AttributeError
            # and hide the real problem.
            fname = getattr(t.current, 'name', '<unknown>')
            prefix = 'Line %d in %s: ' % (t.lnum, fname)
        else:
            prefix = ''
        ValueError.__init__(self, '%s%s' % (prefix, s))


class vec(object):
    """A block of numbers from an HMM: a mean or variance vector, an
    inverse-covariance matrix, or a transition matrix.

    C{code} is the one-letter type ('u', 'v', 'i' or 't'), the same letter
    used for the corresponding HTK macro; C{name} is the macro name, or None
    for data given inline.
    """

    # HTK keyword written in front of each kind of data.
    cname = {'u': 'Mean', 'v': 'Variance', 'i': 'InvCov', 't': 'TransP'}
    # Layout: 1 = vector, 2 = full square matrix,
    # -2 = symmetric matrix stored in the file as its upper triangle.
    dim = {'u': 1, 'v': 1, 'i': -2, 't': 2}

    def __init__(self, vtype, v, name=None):
        """@type vtype: string, one of 'u', 'v', 'i', 't'
        @param v: the data (a numpy array, laid out as given by C{dim[vtype]}).
        @param name: macro name, or None.
        """
        self.name = name
        self.v = v
        assert vtype in self.cname
        self.code = vtype
        # tobytes() yields exactly what the deprecated tostring() did, so the
        # hash is unchanged.
        self.id = hash((self.name, self.code, v.tobytes()))

    def __eq__(self, other):
        if not isinstance(other, vec):
            return NotImplemented
        # The hash is only a fast pre-check; equal hashes do not guarantee
        # equal data, so compare the shape and raw bytes as well.
        return (self.name == other.name and self.code == other.code
                and self.id == other.id
                and self.v.shape == other.v.shape
                and self.v.tobytes() == other.v.tobytes())

    def __ne__(self, other):
        # Needed on Python 2, where != is not derived from __eq__.
        eq = self.__eq__(other)
        if eq is NotImplemented:
            return eq
        return not eq

    def __hash__(self):
        return self.id

    def __repr__(self):
        return "<%s:%s %s>" % (self.code, self.name, str(self.v))

    def write(self):
        """Format the data as HTK text.

        Vectors use 4 decimals and matrices 3 decimals (fixed-point). This
        can round very small values to zero.
        @return: list of lines, indented with the module-level unit C{I}.
        """
        n = self.v.shape[0]
        title = self.cname[self.code]
        layout = self.dim[self.code]
        data = self.v
        o = [I*4 + '<%s> %d' % (title, n)]
        if layout == 1:
            assert len(data.shape) == 1
            o.append(I*5 + ' '.join(['%.4f' % q for q in data]))
        elif layout == -2:
            # Symmetric: only the upper triangle, one row per line.
            for i in range(n):
                o.append(' '.join([I*5] + ['%.3f' % data[i, j] for j in range(i, n)]))
        elif layout == 2:
            for i in range(n):
                o.append(' '.join([I*5] + ['%.3f' % data[i, j] for j in range(n)]))
        return o


class mixture(object):
    """One mixture component: a weight, a mean and a (co)variance.

    C{m} and C{v} each hold either a L{vec} with the data inline or a
    L{ref} to a shared macro.
    """
    code = 'm'

    def __init__(self, wt=1.0, mean=None, var=None, name=None):
        self.name = name
        self.wt = wt
        self.m = mean
        self.v = var

    def __hash__(self):
        return hash((self.wt, self.m, self.v))

    def __eq__(self, other):
        if not isinstance(other, mixture):
            return NotImplemented
        return self.wt == other.wt and self.m == other.m and self.v == other.v

    def __ne__(self, other):
        # Needed on Python 2, where != is not derived from __eq__.
        eq = self.__eq__(other)
        if eq is NotImplemented:
            return eq
        return not eq

    def check(self):
        """Assert that weight, mean and variance are all present and well-typed."""
        assert self.v is not None
        assert self.m is not None
        assert self.wt >= 0.0
        assert isinstance(self.v, vec) or isinstance(self.v, ref)
        assert isinstance(self.m, vec) or isinstance(self.m, ref)

    def __repr__(self):
        v0 = str(self.v)[:50]
        m0 = str(self.m)[:50]
        return "[M wt=%g mu=%s... v=%s...]" % (self.wt, m0, v0)

    def write(self, imix=None):
        """Format as HTK text.

        @param imix: mixture number; if None no <Mixture> header is written.
        @return: list of lines.
        """
        o = []
        if imix is not None:
            o.append(I*3 + '<Mixture> %d %f' % (imix, self.wt))
        o.extend(self.m.write())
        o.extend(self.v.write())
        return o


class ref(object):
    """A reference to a named macro (e.g. C{~v "name"}) instead of inline data.

    Two refs are equal when they point at the very same macro object, so
    mixtures that share a macro compare equal.
    """

    def __init__(self, thing):
        """@param thing: the referenced object; it must have a C{code} and a C{name}."""
        assert thing.code is not None and thing.name is not None, "thing=%s" % repr(thing)
        self.thing = thing

    def __eq__(self, other):
        if not isinstance(other, ref):
            return NotImplemented
        return self.thing is other.thing

    def __ne__(self, other):
        # Needed on Python 2, where != is not derived from __eq__.
        eq = self.__eq__(other)
        if eq is NotImplemented:
            return eq
        return not eq

    def __hash__(self):
        # Consistent with __eq__, which compares identity of the target.
        return hash(id(self.thing))

    def __repr__(self):
        return '<ref ~%s "%s">' % (self.thing.code, self.thing.name)

    def write(self):
        """@return: a one-element list holding the macro reference line."""
        return [I*4 + '~%s "%s"' % (self.thing.code, self.thing.name)]


def deref(x):
    """Return the target of a L{ref}; any other value is returned unchanged."""
    return x.thing if isinstance(x, ref) else x


class state(object):
    """An HMM state holding its mixture components.

    C{mx} is indexed by mixture number; slot 0 is kept empty (None) so
    that indices match the 1-based numbers used in the file.
    """

    def __init__(self, name=None):
        # A fresh state starts with exactly one mixture, number 1.
        self.mx = [None, mixture()]
        self.name = name

    def create_mx(self, i):
        """Return mixture number C{i}, creating it (and padding the list
        with None) when it does not exist yet.
        """
        missing = i - len(self.mx)
        if missing >= 0:
            self.mx.extend([None] * missing)
            self.mx.append(mixture())
        elif self.mx[i] is None:
            self.mx[i] = mixture()
        return self.mx[i]

    def check(self, istate=-1, phone='???'):
        """Verify that slot 0 is empty and every numbered mixture is present and valid.

        @raise BadFormatError: on the first problem found.
        """
        if self.mx[0] is not None:
            raise BadFormatError(None, "mixture[0] should be None in %s" % phone)
        total = len(self.mx)
        for k, component in enumerate(self.mx[1:], 1):
            if component is None:
                raise BadFormatError(None, "mixture %d/%d not set in state %d of %s" % (k, total, istate, phone))
            component.check()

    def __repr__(self):
        return "<STATE %s>" % ', '.join(map(str, self.mx))

    def write(self, istate):
        """Format as HTK text.

        @param istate: state number, or None to omit the <State> part of the header.
        @return: list of lines.
        """
        count = len(self.mx) - 1
        if istate is None:
            header = I*2 + '<NumMixes> %d' % count
        else:
            header = I*2 + '<State> %d <NumMixes> %d' % (istate, count)
        lines = [header]
        for k in range(1, len(self.mx)):
            lines.extend(self.mx[k].write(k))
        return lines


class hmm(object):
    """One HMM: a name, a list of states and a transition matrix.

    C{states} has one slot per state counted by <NumStates>, indexed by
    state number. Slots 0 and 1 stay None; the emitting states occupy
    slots 2 and up. C{t} is the <TransP> L{vec}.
    """

    def __init__(self, numstates, name=None):
        self.states = [None] * numstates
        self.t = None
        self.name = name

    def check(self):
        """Verify that the HMM is complete enough to be written out.

        @raise BadFormatError: if there are fewer than two state slots,
            slot 0 or 1 is occupied, an emitting state is missing or
            invalid, or no transition matrix has been set.
        """
        if len(self.states) < 2:
            raise BadFormatError(None, "Too few states (%d) in %s" % (len(self.states), self.name))
        if self.states[0] is not None or self.states[1] is not None:
            raise BadFormatError(None, "State[0] and [1] should be None in %s" % self.name)
        for i in range(2, len(self.states)):
            if self.states[i] is None:
                raise BadFormatError(None, "State %d/%d not set in %s" % (i, len(self.states), self.name))
            self.states[i].check(istate=i, phone=self.name)
        # write() needs the transition matrix; catch its absence here rather
        # than with an AttributeError on None later.
        if self.t is None:
            raise BadFormatError(None, "No transition matrix in %s" % self.name)

    def write(self):
        """Format the HMM as an HTK C{~h} definition.

        @return: list of lines.
        """
        o = []
        for (i, s) in enumerate(self.states[2:]):
            o.extend(s.write(i+2))
        tr = self.t.write()
        nstates = len(self.states)
        return ['~h "%s"' % self.name, '<BeginHMM>',
                I*1 + '<NumStates> %d' % nstates
                ] + o + tr + ['<EndHMM>']


class hmmset(object):
    """A collection of HMMs together with the macros and global options they share."""

    def __init__(self):
        self.phones = []      # hmm objects, in the order read
        self.v_macros = {}    # ~v macros: name -> variance vec
        self.u_macros = {}    # ~u macros: name -> mean vec
        self.i_macros = {}    # ~i macros: name -> inverse-covariance vec
        self.s_macros = {}    # ~s macros: name -> state
        self.o_macro = {}     # global options collected from the ~o section

    def twiddle_o(self):
        """Format the global options as a single C{~o} line.

        Requires the 'streaminfo', 'vecsize', 'durkind', 'basekind' and
        'parmkind' entries; 'covkind' is optional.
        """
        om = self.o_macro
        si0, si1 = om['streaminfo']
        parmkind = om['parmkind']
        if isinstance(parmkind, str):
            # read() stores '' (not a list) for <USER> data; adding a str to
            # a list would raise TypeError.
            parmkind = [parmkind] if parmkind else []
        sourcekind = '_'.join([om['basekind']] + list(parmkind))
        o = ['~o', '<STREAMINFO> %d %d' % (si0, si1),
             '<VECSIZE> %d' % om['vecsize'],
             '<%s>' % om['durkind'],
             '<%s>' % sourcekind
             ]
        if 'covkind' in om:
            o.append('<%s>' % om['covkind'])
        return ' '.join(o)

    def write(self):
        """Format the whole set as HTK text: options, macros, then the HMMs.

        Macros are written in sorted name order so the output is reproducible.
        @return: list of lines.
        """
        o = [self.twiddle_o()]
        for code, table in (('v', self.v_macros), ('u', self.u_macros),
                            ('i', self.i_macros), ('s', self.s_macros)):
            for k in sorted(table):
                o.append('~%s "%s"' % (code, k))
                if code == 's':
                    # state.write() requires a state number; None makes it
                    # emit only the <NumMixes> header.
                    o.extend(table[k].write(None))
                else:
                    o.extend(table[k].write())
        for p in self.phones:
            o.extend(p.write())
        return o


class tokstream(object):
    """This breaks the HTK HMM file down into tokens.

    Several files may be given; they are read one after another as a single
    stream. Lines are stripped and blank lines are skipped.
    """

    # Token patterns. Group 1 of each match is the token text. (The names
    # int and float shadow the builtins only inside this class body.)
    int = re.compile(r'([+-]?\d+)')
    # A decimal number with a point and/or an exponent, e.g. "1.5", ".5",
    # "-2.0e-01", "1e5". A lone "." or a dangling exponent does not match.
    float = re.compile(r'([+-]?(?:(?:[0-9]+[.][0-9]*|[.][0-9]+)(?:[eE][+-]?[0-9]+)?|[0-9]+[eE][+-]?[0-9]+))')
    bracket = re.compile(r'<\s*(\w+)\s*>')
    macro = re.compile(r'~([a-zA-Z])')
    name = re.compile(r'"([a-zA-Z_]\w*)"')

    def __init__(self, fdlist):
        """@param fdlist: file-like objects with C{readline()}; their C{name}
            (if any) is used in error messages.
        """
        self.fds = list(fdlist)
        if self.fds:
            self.current = self.fds.pop(0)
        else:
            self.current = None
        self.lnum = 0
        self.nextline()
        self.fulltok = None

    def _getline(self):
        """Return the next non-blank, stripped line of the current file, or None at its end."""
        if self.current is None:
            return None
        while True:
            line = self.current.readline()
            self.lnum += 1
            if line == '':
                return None
            line = line.strip()
            if line != '':
                return line

    def nextline(self):
        """Get the next line from the input file(s). Sets self.line.

        Empty files are skipped, so the stream ends only after the last file.
        @return: the line.
        @rtype: str or L{None}.
        """
        self.line = self._getline()
        while self.line is None and self.fds:
            self.current = self.fds.pop(0)
            self.lnum = 0
            self.line = self._getline()
        return self.line

    def __nonzero__(self):
        """@return: True when there is more work to be done."""
        return self.line is not None

    # Python 3 name for the truth-value hook.
    __bool__ = __nonzero__

    def get(self):
        """Get a token.
        This also sets C{self.fulltok} with the full text of the token.
        @raise BadFormatError: When the input is exhausted, or the remainder
            of the line doesn't match any legal pattern.
        @return: The regular expression that matches the next token, and the relevant text.
        @rtype: (compiled regular expression, str)
        """
        if self.line is None:
            raise BadFormatError(self, "Unexpected end of input")
        # Order matters: float must be tried before int.
        for pat in (self.macro, self.bracket, self.float, self.int, self.name):
            m = pat.match(self.line)
            if m:
                self.line = self.line[m.end():].strip()
                if self.line == '':
                    self.nextline()
                tok = m.group(1)
                if pat is self.bracket:
                    # HTK keywords are compared case-insensitively.
                    tok = tok.lower()
                self.fulltok = m.group(0)
                return (pat, tok)
        raise BadFormatError(self, "Cannot match remainder of line: '%s'" % self.line)

    def get_name(self):
        """@return: the text of a quoted name token."""
        p, tok = self.get()
        if p is not self.name:
            raise BadFormatError(self, "Expected a name, got %s" % self.fulltok)
        return tok

    def get_bracket(self):
        """@return: the lower-cased keyword of a <Keyword> token."""
        p, tok = self.get()
        if p is not self.bracket:
            raise BadFormatError(self, "Expected a bracket, got %s" % self.fulltok)
        return tok

    def get_int(self):
        """@return: the next token as an int."""
        p, tok = self.get()
        if p is not self.int:
            raise BadFormatError(self, "Expected an int, got %s" % self.fulltok)
        return int(tok)

    def get_float(self):
        """@return: the next token (float or int) as a float."""
        p, tok = self.get()
        if p is not self.float and p is not self.int:
            raise BadFormatError(self, "Expected a float, got %s" % self.fulltok)
        return float(tok)


def read_floats_t(vtype, dim, t, tok, phoneme='???'):
    """Read inline data whose <Keyword> tag has already been consumed.

    The next token is the size, followed by the values.
    @param vtype: one of 'u', 'v', 'i', 't' (see L{vec}).
    @param dim: layout code passed on to L{read_vec}.
    @param tok: unused; kept for call compatibility.
    @param phoneme: unused; kept for call compatibility.
    @return: an unnamed L{vec}.
    """
    size = t.get_int()
    result = read_vec(dim, size, t, vtype=vtype)
    return result


def read_floats_b(vtype, dim, t, tok, phoneme='???', name=None):
    """Read a macro body: the <Keyword> tag for C{vtype}, a size, then the values.

    @param vtype: one of 'u', 'v', 'i', 't'; selects the expected tag via L{vec.cname}.
    @param dim: layout code passed on to L{read_vec}.
    @param tok: the macro letter, used in the error message.
    @param phoneme: used in the error message.
    @param name: the macro name stored in the returned L{vec}.
    @raise BadFormatError: if the tag does not match C{vtype}.
    """
    expected = vec.cname[vtype].lower()
    got = t.get_bracket()
    if got != expected:
        # Name the tag that belongs to this macro type (previously the
        # message always said <variance>).
        raise BadFormatError(t, 'macro "%s" expects a <%s> tag, got <%s> (phoneme=%s)'
                             % (tok, expected, got, phoneme))
    n = t.get_int()
    return read_vec(dim, n, t, vtype=vtype, name=name)


def read_vec(dim, n, t, name=None, vtype=None):
    """Read C{n}-sized numeric data from token stream C{t} into a L{vec}.

    @param dim: 1 reads C{n} values into a vector; 2 reads C{n*n} values
        row by row into a square matrix; -2 reads the upper triangle row by
        row and mirrors it into a symmetric matrix.
    @raise AssertionError: for any other C{dim}.
    """
    if dim == 1:
        tmp = numpy.zeros((n,))
        for i in range(n):
            tmp[i] = t.get_float()
    elif dim == -2:
        tmp = numpy.zeros((n, n))
        for i in range(n):
            for j in range(i, n):
                tmp[i, j] = t.get_float()
        # Copy the upper triangle into the (still zero) lower triangle.
        lower = numpy.tril_indices(n, -1)
        tmp[lower] = tmp.T[lower]
    elif dim == 2:
        tmp = numpy.zeros((n, n))
        for i in range(n):
            for j in range(n):
                tmp[i, j] = t.get_float()
    else:
        # Call syntax works on both Python 2 and 3 (unlike "raise X, msg").
        raise AssertionError("Never happens: dim=%r" % (dim,))
    return vec(vtype, tmp, name)


def _macro_ref(table, t, tok, phoneme):
    """Read a macro name from C{t} and return a L{ref} to its definition in C{table}.

    @raise BadFormatError: if no macro of that name has been defined yet.
    """
    name = t.get_name()
    if name not in table:
        raise BadFormatError(t, "undefined macro ~%s '%s' in phoneme %s" % (tok, name, phoneme))
    return ref(table[name])


def read(fdlist):
    """Parse HTK HMM definition files.

    Supported: ~o options, ~h models, ~u/~v/~i macros (definitions and
    references inside a mixture), ~s references, and <BeginHMM> ...
    <EndHMM> with <NumStates>, <State>, <NumMixes>, <Mixture>, <Mean>,
    <Variance>, <InvCov>, <GConst> (value ignored) and <TransP>.

    @param fdlist: file-like objects; they are read in order as one stream.
    @return: the parsed models and macros.
    @rtype: L{hmmset}
    @raise BadFormatError: on malformed input or an unsupported construct.
    """
    rv = hmmset()
    # Parser state. The HMM/state/mixture variables are None outside the
    # construct they belong to, so later macro definitions are not mistaken
    # for references.
    phoneme = None
    numstates = None
    statenum = None
    mixnum = None
    nummixes = None
    thishmm = None
    thisstate = None
    thismix = None
    t = tokstream(fdlist)
    while t:
        pat, tok = t.get()
        if pat is t.macro:
            if tok == 'h':
                phoneme = t.get_name()
            elif tok == 'o':
                rv.o_macro = {}
            elif tok == 'u' and thismix is not None:
                # Inside a mixture, ~u "name" refers to an existing mean macro.
                thismix.m = _macro_ref(rv.u_macros, t, tok, phoneme)
            elif tok == 'u':
                name = t.get_name()
                rv.u_macros[name] = read_floats_b('u', 1, t, tok, '---', name=name)
            elif tok == 'v' and thismix is not None:
                thismix.v = _macro_ref(rv.v_macros, t, tok, phoneme)
            elif tok == 'v':
                name = t.get_name()
                rv.v_macros[name] = read_floats_b('v', 1, t, tok, '---', name=name)
            elif tok == 'i' and thismix is not None:
                thismix.v = _macro_ref(rv.i_macros, t, tok, phoneme)
            elif tok == 'i':
                name = t.get_name()
                rv.i_macros[name] = read_floats_b('i', -2, t, tok, '---', name=name)
            elif tok == 's' and statenum is not None:
                # statenum is only set between <State> and <EndHMM>, so
                # thishmm is not None here. Read the name exactly once.
                name = t.get_name()
                if name not in rv.s_macros:
                    raise BadFormatError(t, "undefined macro ~%s '%s' in phoneme %s" % (tok, name, phoneme))
                thishmm.states[statenum] = rv.s_macros[name]
            elif tok == 's':
                raise BadFormatError(t, "~s macro definitions are not supported")
            else:
                raise BadFormatError(t, "Unsupported macro %s." % t.fulltok)
        elif pat is t.bracket and tok == 'gconst':
            t.get_float()
        elif pat is t.bracket and tok == 'streaminfo':
            rv.o_macro['streaminfo'] = (t.get_int(), t.get_int())
        elif pat is t.bracket and tok == 'vecsize':
            rv.o_macro['vecsize'] = t.get_int()
        elif pat is t.bracket and tok == 'beginhmm':
            if phoneme is None:
                raise BadFormatError(t, "phoneme not named at BeginHMM")
            if thishmm is not None:
                raise BadFormatError(t, "BeginHMM before EndHMM of %s" % thishmm.name)
            got = t.get_bracket()
            if got != 'numstates':
                raise BadFormatError(t, "Expected numstates, got <%s>" % got)
            numstates = t.get_int()
            if not (numstates > 0):
                raise BadFormatError(t, "numstates out of range: 0<%d for phoneme %s" % (numstates, phoneme))
            thishmm = hmm(numstates, phoneme)
        elif pat is t.bracket and tok == 'state':
            if thishmm is None:
                raise BadFormatError(t, "numstates not set before state (no BeginHMM)")
            statenum = t.get_int()
            # Validate before indexing: the states list has numstates slots.
            if not (1 < statenum < numstates):
                raise BadFormatError(t, "State number out of range: 1<%d<%d" % (statenum, numstates))
            thisstate = state()
            thishmm.states[statenum] = thisstate
            mixnum = 1
            nummixes = 1
            thismix = thisstate.create_mx(mixnum)
        elif pat is t.bracket and tok == 'nummixes':
            if thisstate is None:
                raise BadFormatError(t, "state not set before nummixes")
            nummixes = t.get_int()
            if not (nummixes > 0):
                raise BadFormatError(t, "nummixes out of range: 0<%d for state %d" % (nummixes, statenum))
        elif pat is t.bracket and tok == 'mixture':
            if thisstate is None:
                raise BadFormatError(t, "state not set before mixture")
            nummixes = None
            mixnum = t.get_int()
            if mixnum < 1:
                raise BadFormatError(t, "mixture number out of range: 0<%d" % mixnum)
            thismix = thisstate.create_mx(mixnum)
            thismix.wt = t.get_float()
        elif pat is t.bracket and tok == 'transp':
            if thishmm is None:
                raise BadFormatError(t, "TransP outside BeginHMM/EndHMM")
            n = t.get_int()
            if n != numstates:
                raise BadFormatError(t, "transp: size mismatch: %d != %d" % (n, numstates))
            thishmm.t = read_vec(2, numstates, t, vtype='t')
        elif pat is t.bracket and tok == 'endhmm':
            if thishmm is None:
                raise BadFormatError(t, "EndHMM without BeginHMM")
            thishmm.check()
            rv.phones.append(thishmm)
            thishmm = None
            thisstate = None
            thismix = None
            statenum = None
            mixnum = None
            phoneme = None
        elif thismix is not None and pat is t.bracket and tok == 'variance':
            thismix.v = read_floats_t('v', 1, t, tok, phoneme)
        elif thismix is not None and pat is t.bracket and tok == 'mean':
            thismix.m = read_floats_t('u', 1, t, tok, phoneme)
        elif thismix is not None and pat is t.bracket and tok == 'invcov':
            thismix.v = read_floats_t('i', -2, t, tok, phoneme)
        elif pat is t.bracket and tok in ['diagc', 'invdiagc', 'fullc', 'lltc', 'xformc']:
            rv.o_macro['covkind'] = tok
        elif pat is t.bracket and tok in ['nulld', 'poissond', 'gammad', 'gend']:
            rv.o_macro['durkind'] = tok
        elif pat is t.bracket and tok == 'user':
            rv.o_macro['basekind'] = tok
            # A list, like the qualifier list stored for "base_q1_q2" below.
            rv.o_macro['parmkind'] = []
        elif pat is t.bracket and '_' in tok:
            klist = tok.split('_')
            rv.o_macro['basekind'] = klist[0]
            rv.o_macro['parmkind'] = klist[1:]
        else:
            raise BadFormatError(t, "Unexpectedly saw %s." % t.fulltok)

    return rv


class NoPhonemeToFormat(ValueError):
    """Raised by the label-formatting helpers when given an empty label.

    Constructed like any ValueError: all positional arguments become C{args}.
    """

# Encoder used by HTKformat: ASCII letters and digits are allowed as-is,
# with '_' as the escape character (presumably other characters are
# escaped; see gmisclib.g_encode). string.ascii_letters, unlike
# string.letters, does not depend on the locale and exists on Python 3.
_e = g_encode.encoder(allowed=string.ascii_letters+string.digits, eschar='_')
# Characters HTKformat prefixes with 'q': a leading digit or any capital letter.
_needsq = re.compile('(^[0-9])|[A-Z]')
# 'q' is the quoting character, so literal q's are doubled first.
_q = re.compile('q')

def HTKformat(s):
    """Format for HTK state name.

    Literal 'q's are doubled, then a leading digit or any capital letter
    becomes 'q' + its lower-case form, and the result is passed through the
    module encoder C{_e}.
    @raise NoPhonemeToFormat: if C{s} is empty.
    """
    assert isinstance(s, str), "s=%s" % repr(s)
    if not s:
        raise NoPhonemeToFormat("lbl='%s'" % s)
    doubled = _q.sub('qq', s)
    quoted = _needsq.sub(lambda mo: 'q' + mo.group(0).lower(), doubled)
    return _e.fwd(quoted)



# Characters escaped in HTK dictionary entries: '[', ']' and '%' itself.
_brkt = re.compile(r'\[|\]|%')

def dict_format(s):
    """Format for center [bracket] in HTK dictionary.

    Each '[', ']' or '%' is replaced by '%' followed by its decimal
    character code, e.g. '[' -> '%91'.
    @raise NoPhonemeToFormat: if C{s} is empty.
    """
    assert isinstance(s, str), "s=%s" % repr(s)
    if not s:
        raise NoPhonemeToFormat("lbl='%s'" % s)

    def _escape(match):
        return '%%%d' % ord(match.group(0))

    return _brkt.sub(_escape, s)



# Sample HTK HMM definitions used by test(): a ~o options section, a ~v
# variance macro "b", a second ~o section, and a 6-state HMM "b" whose four
# emitting states each have one mixture referring to the variance macro "b".
# read() resets the options on every ~o, so only the second ~o section
# (which has no covariance kind) survives.
_TEST = """
~o
<STREAMINFO> 1 26
<VECSIZE> 26<NULLD><MFCC_E_D><DIAGC>
~v "b"
<VARIANCE> 26
6.262400e+01 5.858900e+01 1.003000e+02 7.347500e+01 6.431700e+01 5.047400e+01 6.242300e+01 4.893200e+01 5.247800e+01 3.391300e+01 3.731700e+01 3.165200e+01 4.400000e-02 5.378000e+00 5.956000e+00 6.876000e+00 6.459000e+00 6.584000e+00 6.842000e+00 6.603000e+00 5.119000e+00 5.212000e+00 4.712000e+00 4.471000e+00 3.459000e+00 5.000000e-03
~o <STREAMINFO> 1 26 <VECSIZE> 26 <NULLD><MFCC_E_D>
~h "b"
<BeginHMM>
<NumStates> 6
<State> 2 <NumMixes> 1
<Mixture> 1 1
<Mean> 26
0.939 6.000 9.331 -4.474 -9.654 -4.482 -9.613 -2.793 -1.901 -0.622 -2.544 -3.954 0.786 0.413 -0.327 0.125 -1.229 -0.422 0.067 -0.052 -0.266 -0.783 -0.376 -0.270 -0.252 0.020
~v "b"
<State> 3 <NumMixes> 1
<Mixture> 1 1
<Mean> 26
0.909 6.071 9.146 -4.544 -9.690 -4.453 -9.601 -2.748 -1.946 -0.706 -2.698 -3.882 0.790 0.401 -0.369 0.164 -1.257 -0.419 0.133 -0.073 -0.286 -0.796 -0.369 -0.296 -0.248 0.020
~v "b"
<State> 4 <NumMixes> 1
<Mixture> 1 1
<Mean> 26
0.823 5.966 9.268 -4.783 -9.793 -4.583 -9.584 -2.669 -2.134 -0.603 -2.607 -3.922 0.784 0.424 -0.337 0.192 -1.219 -0.428 0.100 -0.100 -0.311 -0.810 -0.350 -0.268 -0.270 0.020
~v "b"
<State> 5 <NumMixes> 1
<Mixture> 1 1
<Mean> 26
0.891 6.011 9.190 -4.533 -9.694 -4.507 -9.363 -2.863 -2.066 -0.611 -2.551 -3.939 0.785 0.430 -0.344 0.155 -1.231 -0.445 0.104 -0.034 -0.269 -0.785 -0.371 -0.272 -0.261 0.020
~v "b"
<TransP> 6
0.0000 1.0000 0.0000 0.0000 0.0000 0.0000
0.0000 0.8000 0.2000 0.0000 0.0000 0.0000
0.0000 0.0000 0.8000 0.2000 0.0000 0.0000
0.0000 0.0000 0.0000 0.8000 0.2000 0.0000
0.0000 0.0000 0.0000 0.0000 0.8000 0.2000
0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
<EndHMM>
"""


def test():
    """Round-trip self-test.

    Parses C{_TEST}, writes it back out, parses that output and writes it
    again; the two written texts must be identical.
    @return: the regenerated text.
    """
    from gmisclib import fake_file

    def _roundtrip(text):
        # Parse text from an in-memory file and return the re-written text.
        fd = fake_file.file(name='-')
        fd.write(text)
        fd.seek(0)
        return '\n'.join(read([fd]).write())

    first = _roundtrip(_TEST)
    second = _roundtrip(first)
    assert first == second
    return first


if __name__ == '__main__':
    # Run the self-test and show the regenerated HMM text. print(...) with a
    # single argument works on both Python 2 and 3. To parse real files
    # instead: read([open(q, 'r') for q in sys.argv[1:]])
    print(test())