Package jazzparser :: Package misc :: Package raphsto :: Module train
[hide private]
[frames] | no frames]

Source Code for Module jazzparser.misc.raphsto.train

  1  """Unsupervised EM training for Raphael and Stoddard's chord labelling model. 
  2   
  3  """ 
  4  """ 
  5  ============================== License ======================================== 
  6   Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding 
  7    
  8   This file is part of The Jazz Parser. 
  9    
 10   The Jazz Parser is free software: you can redistribute it and/or modify 
 11   it under the terms of the GNU General Public License as published by 
 12   the Free Software Foundation, either version 3 of the License, or 
 13   (at your option) any later version. 
 14    
 15   The Jazz Parser is distributed in the hope that it will be useful, 
 16   but WITHOUT ANY WARRANTY; without even the implied warranty of 
 17   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 18   GNU General Public License for more details. 
 19    
 20   You should have received a copy of the GNU General Public License 
 21   along with The Jazz Parser.  If not, see <http://www.gnu.org/licenses/>. 
 22   
 23  ============================ End license ====================================== 
 24   
 25  """ 
 26  __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"  
 27   
 28   
 29  import numpy, os 
 30  from numpy import ones, float64, sum as array_sum, zeros, log2, add as array_add 
 31  import cPickle as pickle 
 32  from multiprocessing import Pool 
 33   
 34  from jazzparser.utils.nltk.probability import mle_estimator, logprob, add_logs, \ 
 35                          sum_logs, prob_dist_to_dictionary_prob_dist, \ 
 36                          cond_prob_dist_to_dictionary_cond_prob_dist 
 37  from jazzparser.utils.options import ModuleOption 
 38  from jazzparser.utils.system import get_host_info_string 
 39  from jazzparser import settings 
 40  from . import constants, RaphstoHmm, RaphstoHmmThreeChord, RaphstoHmmFourChord, \ 
 41                          RaphstoHmmUnigram, RaphstoHmmParameterError 
 42   
 43  from nltk.probability import ConditionalProbDist, FreqDist, \ 
 44              ConditionalFreqDist, DictionaryProbDist, \ 
 45              DictionaryConditionalProbDist, MutableProbDist 
 46   
 47  # Small quantity added to every probability to ensure we never get zeros 
 48  ADD_SMALL = 1e-6 
 49   
50 -def _sequence_updates(sequence, last_model, label_dom, state_ids, mode_ids, \ 51 chord_ids, beat_ids, d_ids, d_func):
52 """ 53 Evaluates the forward/backward probability matrices for a 54 single sequence under the model that came from the previous 55 iteration and returns matrices that contain the updates 56 to be made to the distributions during this iteration. 57 58 This is wrapped up in a function so it can be run in 59 parallel for each sequence. Once all sequences have been 60 evaluated, the results are combined and model updated. 61 62 """ 63 num_chords = len(chord_ids) 64 num_beats = len(beat_ids) 65 num_modes = len(mode_ids) 66 num_ds = len(d_ids) 67 num_ktrans = 12 68 69 # Local versions of the matrices store the accumulated values 70 # for just this sequence (so we can normalize before adding 71 # to the global matrices) 72 # The numerators 73 ctrans_local = zeros((num_chords,num_chords), float64) 74 ems_local = zeros((num_beats,num_ds), float64) 75 ktrans_local = zeros((num_modes,num_ktrans,num_modes), float64) 76 uni_chords_local = zeros(num_chords, float64) 77 78 # Compute the forward and backward probabilities 79 alpha,scale,seq_logprob = last_model.normal_forward_probabilities(sequence) 80 beta,scale = last_model.normal_backward_probabilities(sequence) 81 # gamma contains the state occupation probability for each state at each 82 # timestep 83 gamma = last_model.compute_gamma(sequence, alpha, beta) 84 # xi contains the probability of every state transition at every timestep 85 xi = last_model.compute_xi(sequence, alpha, beta) 86 87 T = len(sequence) 88 89 for time in range(T): 90 for state in label_dom: 91 tonic,mode,chord = state 92 state_i = state_ids[state] 93 mode_i = mode_ids[mode] 94 95 if time < T-1: 96 # Go through all possible pairs of states to update the 97 # transition distributions 98 for next_state in label_dom: 99 ntonic,nmode,nchord = next_state 100 state_j = state_ids[next_state] 101 mode_j = mode_ids[nmode] 102 103 ## Key transition dist update ## 104 tonic_change = (ntonic - tonic) % 12 105 ktrans_local[mode_i][tonic_change][mode_j] += \ 106 
xi[time][state_i][state_j] 107 108 ## Chord transition dist update ## 109 chord_i, chord_j = chord_ids[chord], chord_ids[nchord] 110 if tonic == ntonic and mode == nmode: 111 # Add to chord transition dist for this chord pair 112 ctrans_local[chord_i][chord_j] += xi[time][state_i][state_j] 113 else: 114 uni_chords_local[chord_j] += xi[time][state_i][state_j] 115 116 ## Emission dist update ## 117 # Add the state occupation probability to the emission numerator 118 # for every note 119 for pc,beat in sequence[time]: 120 beat_i = beat_ids[beat] 121 d = d_func(pc, state) 122 d_i = d_ids[d] 123 124 ems_local[beat_i][d_i] += gamma[time][state_i] 125 126 # Calculate the denominators 127 ctrans_denom_local = array_sum(ctrans_local, axis=1) 128 ems_denom_local = array_sum(ems_local, axis=1) 129 ktrans_denom_local = array_sum(array_sum(ktrans_local, axis=2), axis=1) 130 uni_chords_denom_local = array_sum(uni_chords_local) 131 132 # Wrap this all up in a tuple to return to the master 133 return (ktrans_local, ctrans_local, ems_local, \ 134 uni_chords_local, \ 135 ktrans_denom_local, ctrans_denom_local, \ 136 ems_denom_local, uni_chords_denom_local, \ 137 seq_logprob)
138 ## End of pool operation _sequence_updates 139 140
class RaphstoBaumWelchTrainer(object):
    """
    Class with methods to retrain a Raphsto model using the Baum-Welch
    EM algorithm.

    The trainer is constructed around an existing model. L{train} runs
    EM over a set of training sequences and L{update_model} copies the
    retrained distributions back into that model.

    """
    # Options accepted in the C{options} dictionary given to the constructor
    OPTIONS = [
        ModuleOption('max_iterations', filter=int,
            help_text="Number of training iterations to give up after "\
                "if we don't reach convergence before.",
            usage="max_iterations=N, where N is an integer", default=100),
        ModuleOption('convergence_logprob', filter=float,
            help_text="Difference in overall log probability of the "\
                "training data made by one iteration after which we "\
                "consider the training to have converged.",
            usage="convergence_logprob=X, where X is a small floating "\
                "point number (e.g. 1e-3)", default=1e-3),
    ]
    # Only models of these types may be trained with this trainer
    MODEL_TYPES = [
        RaphstoHmm,
        RaphstoHmmThreeChord,
        RaphstoHmmFourChord
    ]

    def __init__(self, model, options=None):
        """
        @param model: the model to retrain. Its type must be exactly one
            of those in C{MODEL_TYPES} (subclasses are not accepted).
        @type options: dict
        @param options: values for the options in C{OPTIONS}. If omitted,
            an empty dictionary is processed, so the defaults apply.
        @raise RaphstoHmmParameterError: if the model is of a type that
            this trainer cannot train.

        """
        self.model = model
        # Check this model is of one of the types we can train
        if type(model) not in self.MODEL_TYPES:
            raise RaphstoHmmParameterError("trainer %s cannot train a model "\
                "of type %s" % (type(self).__name__, type(model).__name__))

        # Use a fresh dictionary per instance rather than a shared default
        if options is None:
            options = {}
        self.options = ModuleOption.process_option_dict(options, self.OPTIONS)
        self.model_cls = type(model)

    def train(self, emissions, max_iterations=None, \
                    convergence_logprob=None, logger=None, processes=1,
                    save=True, save_intermediate=False):
        """
        Performs unsupervised training using Baum-Welch EM.

        This is an instance method, because it is performed on a model
        that has already been initialized. You might, for example,
        create such a model using C{initialize_chord_types}.

        This is based on the training procedure in NLTK for HMMs:
        C{nltk.tag.hmm.HiddenMarkovModelTrainer.train_unsupervised}.

        Every non-empty sequence in C{emissions} contributes to every
        iteration. With C{processes > 1} the sequences are distributed
        over a process pool; otherwise they are processed one after
        another in the calling process.

        @type emissions: list of lists of emissions
        @param emissions: training data. Each element is a list of
            emissions representing a sequence in the training data.
            Each emission is an emission like those used for
            L{jazzparser.misc.raphsto.RaphstoHmm.emission_log_probability},
            i.e. a list of note observations
        @type max_iterations: int
        @param max_iterations: maximum number of iterations to allow
            for EM (default 100). Overrides the corresponding
            module option
        @type convergence_logprob: float
        @param convergence_logprob: maximum change in log probability
            to consider convergence to have been reached (default 1e-3).
            Overrides the corresponding module option
        @type logger: logging.Logger
        @param logger: a logger to send progress logging to
        @type processes: int
        @param processes: number of processes to spawn. A pool of this
            many processes will be used to compute distribution updates
            for sequences in parallel during each iteration. Never more
            processes are used than there are sequences.
        @type save: bool
        @param save: save the model at the end of training
        @type save_intermediate: bool
        @param save_intermediate: save the model after each iteration. Implies
            C{save}
        @raise ValueError: if C{emissions} contains no sequences.

        """
        from . import raphsto_d
        if logger is None:
            from jazzparser.utils.loggers import create_dummy_logger
            logger = create_dummy_logger()

        if len(emissions) == 0:
            raise ValueError("cannot run Baum-Welch training without any "\
                "training sequences")

        if save_intermediate:
            save = True

        # No point in creating more processes than there are sequences
        if processes > len(emissions):
            processes = len(emissions)

        self.model.add_history("Beginning Baum-Welch training on %s" % get_host_info_string())
        self.model.add_history("Training on %d sequences (with %s chords)" % \
            (len(emissions), ", ".join("%d" % len(seq) for seq in emissions)))

        # Use kwargs if given, otherwise module options
        if max_iterations is None:
            max_iterations = self.options['max_iterations']
        if convergence_logprob is None:
            convergence_logprob = self.options['convergence_logprob']

        # Enumerate the chords
        chord_ids = dict((crd,num) for (num,crd) in \
                                enumerate(self.model.chord_transition_dom))
        num_chords = len(chord_ids)
        # Enumerate the states
        state_ids = dict((state,num) for (num,state) in \
                                enumerate(self.model.label_dom))
        # Enumerate the beat values (they're probably consecutive ints, but
        # let's not rely on it)
        beat_ids = dict((beat,num) for (num,beat) in \
                                enumerate(self.model.beat_dom))
        num_beats = len(beat_ids)
        # Enumerate the d-values (d-function's domain)
        d_ids = dict((d,num) for (num,d) in \
                                enumerate(self.model.emission_dist_dom))
        num_ds = len(d_ids)
        # Enumerate the modes
        mode_ids = dict((m,num) for (num,m) in enumerate(constants.MODES))
        num_modes = len(mode_ids)
        # The number of key transitions is always 12
        num_ktrans = 12

        # Make a mutable distribution for each of the distributions
        # we'll be updating
        emission_mdist = DictionaryConditionalProbDist(
                    dict((s, MutableProbDist(self.model.emission_dist[s],
                                             self.model.emission_dist_dom))
                        for s in self.model.emission_dist.conditions()))
        key_mdist = DictionaryConditionalProbDist(
                    dict((s, MutableProbDist(self.model.key_transition_dist[s],
                                             self.model.key_transition_dom))
                        for s in self.model.key_transition_dist.conditions()))
        chord_mdist = DictionaryConditionalProbDist(
                    dict((s, MutableProbDist(self.model.chord_transition_dist[s],
                                             self.model.chord_transition_dom))
                        for s in self.model.chord_transition_dist.conditions()))
        chord_uni_mdist = MutableProbDist(self.model.chord_dist,
                                          self.model.chord_transition_dom)

        # Construct a model using these mutable distributions so we can
        # evaluate using them
        model = self.model_cls(key_mdist,
                               chord_mdist,
                               emission_mdist,
                               chord_uni_mdist,
                               chord_set=self.model.chord_set)

        # Everything _sequence_updates needs apart from the sequence itself.
        # The model's distributions are updated in place, so this tuple
        # stays valid across iterations.
        update_args = (model, self.model.label_dom,
                       state_ids, mode_ids, chord_ids,
                       beat_ids, d_ids, raphsto_d)

        iteration = 0
        last_logprob = None
        while iteration < max_iterations:
            logger.info("Beginning iteration %d" % iteration)
            current_logprob = 0.0

            ### Matrices in which to accumulate new probability estimates
            # ctrans contains new chord transition numerator probabilities
            # ctrans[c][c'] = Sum_{t_n=t_(n+1), m_n=m_(n+1),c_n=c,c_(n+1)=c'}
            #                  alpha(x_n).beta(x_(n+1)).
            #                    p(x_(n+1)|x_n).p(y_(n+1)|x_(n+1))
            ctrans = zeros((num_chords,num_chords), float64)
            # ems contains the new emission numerator probabilities
            # ems[r][d] = Sum_{d(y_n^k, x_n)=d, r_n^k=r}
            #                  alpha(x_n).beta(x_n) /
            #                    Sum_{x'_n} (alpha(x'_n).beta(x'_n))
            ems = zeros((num_beats,num_ds), float64)
            # ktrans contains new key transition numerator probabilities
            # ktrans[m][dt][m'] = Sum_{t_(n+1)-t_n=dt,m_(n+1)=m',m_n=m}
            #                  alpha(x_n).beta(x_(n+1)).
            #                    p(x_(n+1)|x_n).p(y_(n+1)|x_(n+1))
            ktrans = zeros((num_modes,num_ktrans,num_modes), float64)
            # uni_chords contains the new chord numerator probabilities (q_c^1,
            # the one not conditioned on the previous chord)
            uni_chords = zeros(num_chords, float64)
            # And these are the denominators
            ctrans_denom = zeros(num_chords, float64)
            ems_denom = zeros(num_beats, float64)
            ktrans_denom = zeros(num_modes, float64)
            # It may seem silly to use a matrix for this, but it allows
            # us to update it in the callback
            uni_chords_denom = zeros(1, float64)

            def _training_callback(result):
                """
                Callback for the _sequence_updates processes that takes
                the updates from a single sequence and adds them onto
                the global update accumulators.

                """
                # _sequence_updates() returns all of this as a tuple
                (ktrans_local, ctrans_local, ems_local, uni_chords_local, \
                 ktrans_denom_local, ctrans_denom_local, ems_denom_local, \
                 uni_chords_denom_local, \
                 seq_logprob) = result

                # Add the values from this sequence to the global
                # matrices, in place
                # Numerators
                array_add(ems, ems_local, ems)
                array_add(ktrans, ktrans_local, ktrans)
                array_add(ctrans, ctrans_local, ctrans)
                array_add(uni_chords, uni_chords_local, uni_chords)
                # Denominators
                array_add(ems_denom, ems_denom_local, ems_denom)
                array_add(ktrans_denom, ktrans_denom_local, ktrans_denom)
                array_add(ctrans_denom, ctrans_denom_local, ctrans_denom)
                array_add(uni_chords_denom, uni_chords_denom_local, uni_chords_denom)
            ## End of _training_callback

            if processes > 1:
                # Create a process pool to use for training
                logger.info("Creating a pool of %d processes" % processes)
                pool = Pool(processes=processes)

                async_results = []
                for seq_i,sequence in enumerate(emissions):
                    logger.info("Iteration %d, sequence %d" % (iteration, seq_i))
                    if len(sequence) == 0:
                        continue

                    # Fire off a new call to the process pool for every sequence
                    async_results.append(
                            pool.apply_async(_sequence_updates,
                                             (sequence,) + update_args,
                                             callback=_training_callback) )
                pool.close()
                # Wait for all the workers to complete
                pool.join()

                # Call get() on every AsyncResult so that any exceptions in
                # workers get raised
                for res in async_results:
                    # If there was an exception in _sequence_updates, it
                    # will get raised here
                    res_tuple = res.get()
                    # Add this sequence's logprob into the total for all sequences
                    current_logprob += res_tuple[8]
            else:
                # A single process must still train on every sequence,
                # not just the first one
                logger.info("Single process: not using a process pool")
                for seq_i,sequence in enumerate(emissions):
                    logger.info("Iteration %d, sequence %d" % (iteration, seq_i))
                    if len(sequence) == 0:
                        continue

                    updates = _sequence_updates(sequence, *update_args)
                    _training_callback(updates)
                    # Add this sequence's logprob into the total for all sequences
                    current_logprob += updates[8]

            # Update the model's probabilities from the accumulated values.
            # ADD_SMALL is added to every numerator (and the matching
            # multiple of it to the denominator) so no probability is zero.
            for beat in self.model.beat_dom:
                beat_i = beat_ids[beat]
                denom = ems_denom[beat_i]
                for d in self.model.emission_dist_dom:
                    if denom == 0.0:
                        # Zero denominator: use a uniform distribution
                        prob = - logprob(len(d_ids))
                    else:
                        prob = logprob(ems[beat_i][d_ids[d]] + ADD_SMALL) - \
                                logprob(denom + len(d_ids)*ADD_SMALL)
                    model.emission_dist[beat].update(d, prob)

            for mode0 in mode_ids:
                mode_i = mode_ids[mode0]
                denom = ktrans_denom[mode_i]
                for key in range(num_ktrans):
                    for mode1 in mode_ids:
                        mode_j = mode_ids[mode1]
                        if denom == 0.0:
                            # Zero denominator: use a uniform distribution
                            prob = - logprob(num_ktrans*num_modes)
                        else:
                            prob = logprob(ktrans[mode_i][key][mode_j] + ADD_SMALL) - \
                                    logprob(denom + num_ktrans*num_modes*ADD_SMALL)
                        model.key_transition_dist[mode0].update(
                                                        (key,mode1), prob)

            for chord0 in chord_ids:
                chord_i = chord_ids[chord0]
                denom = ctrans_denom[chord_i]
                for chord1 in chord_ids:
                    chord_j = chord_ids[chord1]
                    if denom == 0.0:
                        # Zero denominator: use a uniform distribution
                        prob = - logprob(num_chords)
                    else:
                        prob = logprob(ctrans[chord_i][chord_j] + ADD_SMALL) - \
                                logprob(denom + num_chords*ADD_SMALL)
                    model.chord_transition_dist[chord0].update(chord1, prob)

            # The smoothing keeps this denominator above zero, so no
            # special case is needed here
            for chord in chord_ids:
                prob = logprob(uni_chords[chord_ids[chord]] + ADD_SMALL) - \
                        logprob(uni_chords_denom[0] + len(chord_ids)*ADD_SMALL)
                model.chord_dist.update(chord, prob)

            # Clear the model's cache so we get the new probabilities
            model.clear_cache()

            logger.info("Training data log prob: %s" % current_logprob)
            if last_logprob is not None and current_logprob < last_logprob:
                logger.error("Log probability dropped by %s" % \
                                (last_logprob - current_logprob))
            if last_logprob is not None:
                logger.info("Log prob change: %s" % \
                                (current_logprob - last_logprob))
            # Check whether the log probability has converged
            if iteration > 0 and \
                    abs(current_logprob - last_logprob) < convergence_logprob:
                # Don't iterate any more
                logger.info("Distribution has converged: ceasing training")
                break

            iteration += 1
            last_logprob = current_logprob

            # Update the main model
            # Only save if we've been asked to save between iterations
            self.update_model(model, save=save_intermediate)

        self.model.add_history("Completed Baum-Welch training")
        # Update the distribution's parameters with those we've trained
        self.update_model(model, save=save)
        return

    def update_model(self, model, save=True):
        """
        Replaces the distributions of the saved model with those of the given
        model and saves it.

        @param model: the model whose key transition, chord transition,
            emission and chord distributions are copied into C{self.model}.
        @type save: bool
        @param save: save the model. Otherwise just updates the distributions.

        """
        self.model.key_transition_dist = \
                cond_prob_dist_to_dictionary_cond_prob_dist(
                                                    model.key_transition_dist)
        self.model.chord_transition_dist = \
                cond_prob_dist_to_dictionary_cond_prob_dist(
                                                    model.chord_transition_dist)
        self.model.emission_dist = \
                cond_prob_dist_to_dictionary_cond_prob_dist(model.emission_dist)
        self.model.chord_dist = prob_dist_to_dictionary_prob_dist(
                                                    model.chord_dist)
        if save:
            self.model.save()
489 490 491 ########################## Unigram model ############################ 492
493 -def _sequence_updates_uni(sequence, last_model, label_dom, state_ids, \ 494 beat_ids, d_ids, d_func):
495 """Same as L{_sequence_updates}, modified for unigram models. """ 496 num_beats = len(beat_ids) 497 num_ds = len(d_ids) 498 num_ktrans = 12 499 500 # Local versions of the matrices store the accumulated values 501 # for just this sequence (so we can normalize before adding 502 # to the global matrices) 503 # The numerators 504 ems_local = zeros((num_beats,num_ds), float64) 505 506 # Compute the forward and backward probabilities 507 alpha,scale,seq_logprob = last_model.normal_forward_probabilities(sequence) 508 beta,scale = last_model.normal_backward_probabilities(sequence) 509 # gamma contains the state occupation probability for each state at each 510 # timestep 511 gamma = last_model.compute_gamma(sequence, alpha, beta) 512 # xi contains the probability of every state transition at every timestep 513 xi = last_model.compute_xi(sequence, alpha, beta) 514 515 T = len(sequence) 516 517 for time in range(T): 518 for state in label_dom: 519 tonic,mode,chord = state 520 state_i = state_ids[state] 521 # We don't update the transition distribution here, because it's fixed 522 523 ## Emission dist update ## 524 # Add the state occupation probability to the emission numerator 525 # for every note 526 for pc,beat in sequence[time]: 527 beat_i = beat_ids[beat] 528 d = d_func(pc, state) 529 d_i = d_ids[d] 530 531 ems_local[beat_i][d_i] += gamma[time][state_i] 532 533 # Calculate the denominators 534 ems_denom_local = array_sum(ems_local, axis=1) 535 536 # Wrap this all up in a tuple to return to the master 537 return (ems_local, ems_denom_local, seq_logprob)
538 ## End of pool operation _sequence_updates_uni 539
class RaphstoBaumWelchUnigramTrainer(RaphstoBaumWelchTrainer):
    """
    Class with methods to retrain a Raphsto model using the Baum-Welch
    EM algorithm.
    Special trainer to train unigram models. That is, it doesn't update
    the transition distribution: only the emission distribution is
    re-estimated.

    """
    # Model types which may be trained by this trainer: override the superclass'
    MODEL_TYPES = [
        RaphstoHmmUnigram,
    ]

    def train(self, emissions, max_iterations=None, \
                    convergence_logprob=None, logger=None, processes=1,
                    save=True, save_intermediate=False):
        """
        Performs unsupervised training using Baum-Welch EM.

        This is an instance method, because it is performed on a model
        that has already been initialized. You might, for example,
        create such a model using C{initialize_chord_types}.

        This is based on the training procedure in NLTK for HMMs:
        C{nltk.tag.hmm.HiddenMarkovModelTrainer.train_unsupervised}.

        Every non-empty sequence in C{emissions} contributes to every
        iteration. With C{processes > 1} the sequences are distributed
        over a process pool; otherwise they are processed one after
        another in the calling process.

        @type emissions: list of lists of emissions
        @param emissions: training data. Each element is a list of
            emissions representing a sequence in the training data.
            Each emission is an emission like those used for
            L{jazzparser.misc.raphsto.RaphstoHmm.emission_log_probability},
            i.e. a list of note observations
        @type max_iterations: int
        @param max_iterations: maximum number of iterations to allow
            for EM (default 100). Overrides the corresponding
            module option
        @type convergence_logprob: float
        @param convergence_logprob: maximum change in log probability
            to consider convergence to have been reached (default 1e-3).
            Overrides the corresponding module option
        @type logger: logging.Logger
        @param logger: a logger to send progress logging to
        @type processes: int
        @param processes: number of processes to spawn. A pool of this
            many processes will be used to compute distribution updates
            for sequences in parallel during each iteration. Never more
            processes are used than there are sequences.
        @type save: bool
        @param save: save the model at the end of training
        @type save_intermediate: bool
        @param save_intermediate: save the model after each iteration. Implies
            C{save}
        @raise ValueError: if C{emissions} contains no sequences.

        """
        from . import raphsto_d
        if logger is None:
            from jazzparser.utils.loggers import create_dummy_logger
            logger = create_dummy_logger()

        if len(emissions) == 0:
            raise ValueError("cannot run Baum-Welch training without any "\
                "training sequences")

        if save_intermediate:
            save = True

        # No point in creating more processes than there are sequences
        if processes > len(emissions):
            processes = len(emissions)

        self.model.add_history("Beginning Baum-Welch unigram training on %s" % get_host_info_string())
        self.model.add_history("Training on %d sequences (with %s chords)" % \
            (len(emissions), ", ".join("%d" % len(seq) for seq in emissions)))

        # Use kwargs if given, otherwise module options
        if max_iterations is None:
            max_iterations = self.options['max_iterations']
        if convergence_logprob is None:
            convergence_logprob = self.options['convergence_logprob']

        # Enumerate the states
        state_ids = dict((state,num) for (num,state) in \
                                    enumerate(self.model.label_dom))
        # Enumerate the beat values (they're probably consecutive ints, but
        # let's not rely on it)
        beat_ids = dict((beat,num) for (num,beat) in \
                                    enumerate(self.model.beat_dom))
        num_beats = len(beat_ids)
        # Enumerate the d-values (d-function's domain)
        d_ids = dict((d,num) for (num,d) in \
                                    enumerate(self.model.emission_dist_dom))
        num_ds = len(d_ids)

        # Make a mutable distribution for the emission distribution we'll
        # be updating
        emission_mdist = DictionaryConditionalProbDist(
                    dict((s, MutableProbDist(self.model.emission_dist[s],
                                             self.model.emission_dist_dom))
                        for s in self.model.emission_dist.conditions()))
        # Create dummy distributions to fill the places of the transition
        # distribution components
        key_mdist = DictionaryConditionalProbDist({})
        chord_mdist = DictionaryConditionalProbDist({})
        chord_uni_mdist = MutableProbDist({}, [])

        # Construct a model using these mutable distributions so we can
        # evaluate using them
        model = self.model_cls(key_mdist,
                               chord_mdist,
                               emission_mdist,
                               chord_uni_mdist,
                               chord_set=self.model.chord_set)

        # Everything _sequence_updates_uni needs apart from the sequence
        # itself. The model's emission distribution is updated in place,
        # so this tuple stays valid across iterations.
        update_args = (model, self.model.label_dom,
                       state_ids, beat_ids, d_ids, raphsto_d)

        iteration = 0
        last_logprob = None
        while iteration < max_iterations:
            logger.info("Beginning iteration %d" % iteration)
            current_logprob = 0.0

            # ems contains the new emission numerator probabilities
            # ems[r][d] = Sum_{d(y_n^k, x_n)=d, r_n^k=r}
            #                  alpha(x_n).beta(x_n) /
            #                    Sum_{x'_n} (alpha(x'_n).beta(x'_n))
            ems = zeros((num_beats,num_ds), float64)
            # And these are the denominators
            ems_denom = zeros(num_beats, float64)

            def _training_callback(result):
                """
                Callback for the _sequence_updates_uni processes that
                takes the updates from a single sequence and adds them
                onto the global update accumulators.

                """
                # _sequence_updates_uni() returns all of this as a tuple
                (ems_local, ems_denom_local, seq_logprob) = result

                # Add the values from this sequence to the global
                # matrices, in place
                # Emission numerator
                array_add(ems, ems_local, ems)
                # Denominators
                array_add(ems_denom, ems_denom_local, ems_denom)
            ## End of _training_callback

            if processes > 1:
                # Create a process pool to use for training
                logger.info("Creating a pool of %d processes" % processes)
                pool = Pool(processes=processes)

                async_results = []
                for seq_i,sequence in enumerate(emissions):
                    logger.info("Iteration %d, sequence %d" % (iteration, seq_i))
                    if len(sequence) == 0:
                        continue

                    # Fire off a new call to the process pool for every sequence
                    async_results.append(
                            pool.apply_async(_sequence_updates_uni,
                                             (sequence,) + update_args,
                                             callback=_training_callback) )
                pool.close()
                # Wait for all the workers to complete
                pool.join()

                # Call get() on every AsyncResult so that any exceptions in
                # workers get raised
                for res in async_results:
                    # If there was an exception in _sequence_updates_uni,
                    # it will get raised here
                    res_tuple = res.get()
                    # Add this sequence's logprob into the total for all sequences
                    current_logprob += res_tuple[2]
            else:
                # A single process must still train on every sequence,
                # not just the first one
                logger.info("Single process: not using a process pool")
                for seq_i,sequence in enumerate(emissions):
                    logger.info("Iteration %d, sequence %d" % (iteration, seq_i))
                    if len(sequence) == 0:
                        continue

                    updates = _sequence_updates_uni(sequence, *update_args)
                    _training_callback(updates)
                    # Add this sequence's logprob into the total for all sequences
                    current_logprob += updates[2]

            # Update the model's probabilities from the accumulated values.
            # ADD_SMALL is added to every numerator (and the matching
            # multiple of it to the denominator) so no probability is zero.
            for beat in self.model.beat_dom:
                beat_i = beat_ids[beat]
                denom = ems_denom[beat_i]
                for d in self.model.emission_dist_dom:
                    if denom == 0.0:
                        # Zero denominator: use a uniform distribution
                        prob = - logprob(len(d_ids))
                    else:
                        prob = logprob(ems[beat_i][d_ids[d]] + ADD_SMALL) - \
                                logprob(denom + len(d_ids)*ADD_SMALL)
                    model.emission_dist[beat].update(d, prob)

            # Clear the model's cache so we get the new probabilities
            model.clear_cache()

            logger.info("Training data log prob: %s" % current_logprob)
            if last_logprob is not None and current_logprob < last_logprob:
                logger.error("Log probability dropped by %s" % \
                                (last_logprob - current_logprob))
            if last_logprob is not None:
                logger.info("Log prob change: %s" % \
                                (current_logprob - last_logprob))
            # Check whether the log probability has converged
            if iteration > 0 and \
                    abs(current_logprob - last_logprob) < convergence_logprob:
                # Don't iterate any more
                logger.info("Distribution has converged: ceasing training")
                break

            iteration += 1
            last_logprob = current_logprob

            # Update the main model
            # Only save if we've been asked to save between iterations
            self.update_model(model, save=save_intermediate)

        self.model.add_history("Completed Baum-Welch unigram training")
        # Update the distribution's parameters with those we've trained
        self.update_model(model, save=save)
        return

    def update_model(self, model, save=True):
        """
        Replaces the emission distribution of the saved model with that
        of the given model and saves it. The other distributions of the
        saved model are left as they are.

        @param model: the model whose emission distribution is copied
            into C{self.model}.
        @type save: bool
        @param save: save the model. Otherwise just updates the distributions.

        """
        self.model.emission_dist = \
                cond_prob_dist_to_dictionary_cond_prob_dist(model.emission_dist)
        if save:
            self.model.save()
783