Package jazzparser :: Package utils :: Package nltk :: Package ngram :: Module baumwelch
[hide private]
[frames] | [no frames]

Source Code for Module jazzparser.utils.nltk.ngram.baumwelch

  1  """Unsupervised EM training for HMMs that use  
  2  L{jazzparser.utils.nltk.ngram.NgramModel} as their base implementation. 
  3  This is a generic implementation of the Baum-Welch algorithm for EM training  
  4  of HMMs. C{BaumWelchTrainer} should be subclassed to override anything that  
  5  needs to be customized for the model type. 
  6   
  7  """ 
  8  """ 
  9  ============================== License ======================================== 
 10   Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding 
 11    
 12   This file is part of The Jazz Parser. 
 13    
 14   The Jazz Parser is free software: you can redistribute it and/or modify 
 15   it under the terms of the GNU General Public License as published by 
 16   the Free Software Foundation, either version 3 of the License, or 
 17   (at your option) any later version. 
 18    
 19   The Jazz Parser is distributed in the hope that it will be useful, 
 20   but WITHOUT ANY WARRANTY; without even the implied warranty of 
 21   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 22   GNU General Public License for more details. 
 23    
 24   You should have received a copy of the GNU General Public License 
 25   along with The Jazz Parser.  If not, see <http://www.gnu.org/licenses/>. 
 26   
 27  ============================ End license ====================================== 
 28   
 29  """ 
 30  __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"  
 31   
 32  import numpy 
 33  from numpy import float64, sum as array_sum, zeros, log2, add as array_add 
 34  from multiprocessing import Pool 
 35   
 36  from jazzparser.utils.nltk.probability import logprob 
 37  from jazzparser.utils.nltk.ngram import NgramModel, DictionaryHmmModel 
 38  from jazzparser.utils.options import ModuleOption 
 39  from jazzparser.utils.system import get_host_info_string 
 40  from jazzparser.utils.strings import str_to_bool 
 41  from jazzparser import settings 
 42   
 43  # Small quantity added to every probability to ensure we never get zeros 
 44  ADD_SMALL = 1e-6 
def sequence_updates(sequence, last_model, empty_arrays, array_ids, update_initial=True):
    """
    Computes the Baum-Welch update contributions of a single sequence.

    The sequence is evaluated under C{last_model}, the model produced by the
    previous iteration. Expected transition counts (from xi) and expected
    emission counts (from gamma) are accumulated into the arrays supplied in
    C{empty_arrays} and returned. The caller sums the contributions of all
    sequences before re-estimating the model.

    This is a top-level function rather than a method so that it can be
    pickled and run in a multiprocessing pool worker.

    @type sequence: list
    @param sequence: emissions of one (possibly split) training sequence
    @param last_model: model from the previous iteration. It must provide
        C{normal_forward_probabilities}, C{gamma_probabilities},
        C{compute_xi} and C{label_dom}.
    @type empty_arrays: tuple
    @param empty_arrays: C{(trans, ems, trans_denom, ems_denom)}. C{trans}
        and C{ems} are added to in place. The two denominators passed in are
        ignored and recomputed from the numerators.
    @type array_ids: tuple
    @param array_ids: C{(state_ids, em_ids)}, dicts mapping states and
        emissions to array indices
    @type update_initial: bool
    @param update_initial: usually you want to update all distributions,
        including the initial state distribution. If update_initial=False,
        the initial state distribution updates won't be made for this
        sequence. We want this when the sequence is actually a non-initial
        fragment of a longer sequence. NOTE(review): this generic
        implementation makes no initial-state update, so the flag is not
        read here; overriding implementations may use it.
    @return: C{(trans, ems, trans_denom, ems_denom, seq_logprob)}, or None
        if a KeyboardInterrupt was received.

    """
    try:
        trans, ems, _, _ = empty_arrays
        state_ids, em_ids = array_ids

        # seq_prob=True also yields the log probability of the whole sequence
        forward, seq_logprob = last_model.normal_forward_probabilities(
            sequence, seq_prob=True)
        # Probability of occupying each state at each timestep
        gamma = last_model.gamma_probabilities(sequence, forward=forward)
        # Probability of each state-to-state transition at each timestep
        xi = last_model.compute_xi(sequence)

        states = last_model.label_dom
        final_time = len(sequence) - 1

        for t, emission in enumerate(sequence):
            for state in states:
                i = state_ids[state]
                # Nothing transitions out of the final timestep
                if t < final_time:
                    for next_state in states:
                        j = state_ids[next_state]
                        trans[i][j] += xi[t][i][j]
                ems[i][em_ids[emission]] += gamma[t][i]

        # Denominators are the per-state totals of the numerators
        return (trans, ems,
                array_sum(trans, axis=1), array_sum(ems, axis=1),
                seq_logprob)
    except KeyboardInterrupt:
        # Swallowed so that pool workers exit cleanly; the caller treats a
        # None result as a cancelled sequence
        return
class BaumWelchTrainer(object):
    """
    Class with methods to retrain an HMM using the Baum-Welch EM algorithm.

    Note that although the default implementation is for a plain
    L{jazzparser.utils.nltk.ngram.NgramModel}, Baum-Welch training only makes
    sense if the model is an HMM. It will therefore complain if the order is
    not 2 and if there's a backoff model.

    Module options must be processed externally. This allows them to be
    combined with other options as appropriate. The options defined here
    are a standard set of options for generic training and should be processed
    before the trainer is instantiated.

    This is designed as a generic implementation of the algorithm. To use it
    for a special kind of model (e.g. one with a non-standard transition
    distribution), you need to override certain methods to make them
    appropriate to the model:
     - C{create_mutable_model}
     - C{update_model}
     - C{sequence_updates}
     - C{get_empty_arrays}
     - C{sequence_updates_callback}
     - C{get_array_indices}

    The generic version of the trainer can be used to train a
    DictionaryHmmModel. Subclasses are used to train other model types.

    """
    OPTIONS = [
        ModuleOption('max_iterations', filter=int,
            help_text="Number of training iterations to give up after "\
                "if we don't reach convergence before.",
            usage="max_iterations=N, where N is an integer", default=100),
        ModuleOption('convergence_logprob', filter=float,
            help_text="Difference in overall log probability of the "\
                "training data made by one iteration after which we "\
                "consider the training to have converged.",
            usage="convergence_logprob=X, where X is a small floating "\
                "point number (e.g. 1e-3)", default=1e-3),
        ModuleOption('split', filter=int,
            help_text="Limits the length of inputs by splitting them into "\
                "fragments of at most this length. The initial state "\
                "distribution will only be updated for the initial fragments.",
            usage="split=X, where X is an int"),
        ModuleOption('truncate', filter=int,
            help_text="Limits the length of inputs by truncating them to this "\
                "number of timesteps. Truncation is applied before splitting.",
            usage="truncate=X, where X is an int"),
        ModuleOption('save_intermediate', filter=str_to_bool,
            help_text="Save the model between iterations",
            usage="save_intermediate=B, where B is 'true' or 'false' "\
                "(default true)",
            default=True),
        ModuleOption('trainprocs', filter=int,
            help_text="Number of processes to spawn during training. Use -1 "\
                "to spawn a process for every sequence.",
            usage="trainprocs=P, where P is an integer",
            default=1),
    ]

    def __init__(self, model, options=None):
        """
        @type model: L{jazzparser.utils.nltk.ngram.NgramModel}
        @param model: the model to retrain. It must be of order 2 and have
            no backoff model.
        @type options: dict
        @param options: training options, already processed with
            L{process_option_dict}. Defaults to a new empty dict.
        @raise BaumWelchTrainingError: if the model can't be trained with
            Baum-Welch.

        """
        self.model = model
        # A fresh dict per trainer, rather than a shared mutable default
        self.options = {} if options is None else options

        # Do some checks on the model to make sure it's suitable for training
        if not isinstance(model, NgramModel):
            raise BaumWelchTrainingError("BaumWelchTrainer can only be used "\
                "to train an instance of a subclass of NgramModel, not %s" % \
                type(model).__name__)
        if model.order != 2:
            raise BaumWelchTrainingError("can only train a bigram model with "\
                "Baum-Welch. Got model of order %d" % model.order)
        if model.backoff_model is not None:
            raise BaumWelchTrainingError("model to be retrained has a backoff "\
                "model, but we can't train that using Baum-Welch")

    @classmethod
    def process_option_dict(cls, options):
        """
        Verifies and processes the training option values. Returns the
        processed dict.

        """
        return ModuleOption.process_option_dict(options, cls.OPTIONS)

    def record_history(self, line):
        """
        Stores a line in the history of the model or wherever else it is
        appropriate to keep a record of training steps.

        Default implementation does nothing, but subclasses may want to
        store this information.

        """
        return

    # This should be overridden by subclasses, but not by defining a static
    # method on the class, since the function must be picklable. For this,
    # it needs to be a top-level function. Then you can set the
    # sequence_updates attribute to point to it (using staticmethod), as we
    # have done in the default implementation.
    sequence_updates = staticmethod(sequence_updates)

    def create_mutable_model(self, model):
        """
        Creates a mutable version of the given model. This mutable version
        will be the model that receives updates during training, as defined
        by L{update_model}.

        """
        return DictionaryHmmModel.from_ngram_model(model, mutable=True)

    def get_empty_arrays(self):
        """
        Creates zero-filled arrays to hold the accumulated expected counts
        during training. The sizes depend on self.model.

        @return: C{(trans, ems, trans_denom, ems_denom)}, with shapes
            (states, states), (states, emissions), (states,), (states,)

        """
        num_states = len(self.model.label_dom)
        num_ems = len(self.model.emission_dom)
        trans = zeros((num_states, num_states), float64)
        trans_denom = zeros((num_states, ), float64)
        ems = zeros((num_states, num_ems), float64)
        ems_denom = zeros((num_states, ), float64)
        return (trans, ems, trans_denom, ems_denom)

    def get_array_indices(self):
        """
        Returns a tuple of the dicts that map labels, emissions, etc to the
        indices of arrays to which they correspond. These will need to be
        different for non-standard models.

        """
        state_ids = {state: idx for idx, state in enumerate(self.model.label_dom)}
        em_ids = {em: idx for idx, em in enumerate(self.model.emission_dom)}
        return (state_ids, em_ids)

    def sequence_updates_callback(self, result):
        """
        Takes the updates computed for a single sequence and adds them onto
        the global update accumulators, stored as self.global_arrays.

        @param result: the tuple returned by C{sequence_updates}, or None if
            that computation was cancelled (in which case nothing is added)

        """
        if result is None:
            # Process cancelled: do no updates
            log = getattr(self, '_logger', None)
            if log is None:
                # No training logger available (train() not running)
                import logging
                log = logging.getLogger(__name__)
            log.warning("Child process was cancelled")
            return

        # sequence_updates() returns all of this as a tuple
        (trans_local, ems_local,
         trans_denom_local, ems_denom_local,
         seq_logprob) = result
        # Get the global arrays that we're updating
        (trans, ems, trans_denom, ems_denom) = self.global_arrays

        # Accumulate in place (out= the global array). array_add is used
        # rather than += so this also works for non-numpy accumulators
        # returned by overridden get_empty_arrays
        array_add(ems, ems_local, ems)
        array_add(trans, trans_local, trans)
        array_add(ems_denom, ems_denom_local, ems_denom)
        array_add(trans_denom, trans_denom_local, trans_denom)

    def update_model(self, arrays, array_ids):
        """
        Replaces the distributions of the saved model with the probabilities
        taken from the arrays of updates. self.model is expected to be
        made up of mutable distributions when this is called.

        Every expected count has ADD_SMALL added (and each denominator the
        corresponding total), so no re-estimated probability is zero.

        """
        trans, ems, trans_denom, ems_denom = arrays
        state_ids, em_ids = array_ids
        num_states = len(self.model.label_dom)
        num_emissions = len(self.model.emission_dom)

        for state in self.model.label_dom:
            state_i = state_ids[state]
            # Smoothed denominators are shared by every outcome of this state
            trans_denom_logprob = logprob(
                trans_denom[state_i] + num_states*ADD_SMALL)
            ems_denom_logprob = logprob(
                ems_denom[state_i] + num_emissions*ADD_SMALL)

            for next_state in self.model.label_dom:
                # Update the probability of this transition
                prob = logprob(trans[state_i][state_ids[next_state]] + ADD_SMALL) - \
                    trans_denom_logprob
                self.model.label_dist[(state,)].update(next_state, prob)

            for emission in self.model.emission_dom:
                # Update the probability of this emission
                prob = logprob(ems[state_i][em_ids[emission]] + ADD_SMALL) - \
                    ems_denom_logprob
                self.model.emission_dist[state].update(emission, prob)

    def save(self):
        """
        Saves the model in self.model to disk. This may be called at the end
        of each iteration and will be called at the end of the whole training
        process.

        By default, does nothing. You don't have to put something in here,
        but you'll need to override this if you want the model to be saved
        during training before it gets returned at the end.

        """
        return

    @staticmethod
    def _split_sequence(sequence, split_length):
        """
        Splits one emission sequence into fragments of about split_length.

        Consecutive fragments overlap by one emission, so every transition
        of the original sequence appears in some fragment. A remainder of at
        most split_length/5 emissions is merged onto the last fragment,
        which may then be slightly longer than split_length.

        @return: list of C{(first, fragment)} pairs, where first is True
            only for the fragment that begins the sequence. An empty
            sequence gives an empty list.

        """
        remaining = list(sequence)
        fragments = []
        first = True
        # Take bits of length split_length until we're under the max
        while len(remaining) >= split_length:
            fragments.append((first, remaining[:split_length]))
            # Keep the fragment's last emission as the next one's first
            remaining = remaining[split_length-1:]
            first = False

        if not fragments:
            # Too short to split: keep it whole (unless empty)
            if remaining:
                fragments.append((True, remaining))
            return fragments

        # The remainder starts with the emission shared with the last
        # fragment. If that is all it contains, it adds nothing new.
        new_emissions = remaining[1:]
        if not new_emissions:
            return fragments
        if len(remaining) <= split_length / 5:
            # Avoid a tiny trailing fragment: append to the last one,
            # without repeating the shared emission
            fragments[-1][1].extend(new_emissions)
        else:
            fragments.append((False, remaining))
        return fragments

    def _parallel_iteration(self, split_emissions, array_ids, processes,
                            iteration, logger):
        """
        Computes every sequence's updates in a pool of worker processes and
        accumulates them into self.global_arrays.

        @return: summed log probability of the sequences that completed

        """
        logger.info("Creating a pool of %d processes" % processes)
        pool = Pool(processes=processes)

        def _notifier_closure(seq_index):
            def _notifier(res):
                logger.info("Sequence %d finished" % seq_index)
            return _notifier

        async_results = []
        try:
            for seq_i, (first, sequence) in enumerate(split_emissions):
                logger.info("Iteration %d, sequence %d" % (iteration, seq_i))
                if len(sequence) == 0:
                    continue
                # Create some empty arrays for the updates to go into
                empty_arrays = self.get_empty_arrays()
                # Fire off a new call to the process pool for every sequence
                async_results.append(
                    pool.apply_async(self.sequence_updates,
                        (sequence, self.model, empty_arrays, array_ids),
                        {'update_initial': first},
                        _notifier_closure(seq_i)))
            pool.close()
            # Wait for all the workers to complete
            pool.join()
        except KeyboardInterrupt:
            # If Ctl+C is fired during the processing, we exit here
            logger.info("Keyboard interrupt was received during EM "\
                "updates")
            # Don't leave worker processes running behind us
            pool.terminate()
            raise

        total_logprob = 0.0
        # Call get() on every AsyncResult so that any exceptions in workers
        # get raised here. The callback is applied in this process (not via
        # apply_async's callback) so it needn't be picklable.
        for res in async_results:
            res_tuple = res.get()
            self.sequence_updates_callback(res_tuple)
            # A cancelled worker returns None and contributes nothing
            if res_tuple is not None:
                total_logprob += res_tuple[-1]
        return total_logprob

    def _sequential_iteration(self, split_emissions, array_ids, iteration,
                              logger):
        """
        Computes every sequence's updates one after another in this process
        and accumulates them into self.global_arrays.

        @return: summed log probability of all the sequences

        """
        if len(split_emissions) == 1:
            logger.info("One sequence: not using a process pool")
        else:
            logger.info("Not using a process pool: training %d "\
                "emission sequences sequentially" % \
                len(split_emissions))

        total_logprob = 0.0
        for seq_i, (first, sequence) in enumerate(split_emissions):
            if len(sequence) == 0:
                continue
            logger.info("Iteration %d, sequence %d" % (iteration, seq_i))
            # Create some empty arrays for the updates to go into
            empty_arrays = self.get_empty_arrays()
            updates = self.sequence_updates(
                                sequence, self.model,
                                empty_arrays, array_ids,
                                update_initial=first)
            if updates is None:
                # sequence_updates swallows KeyboardInterrupt (so that pool
                # workers exit cleanly); here the interrupt was aimed at us
                logger.info("Keyboard interrupt was received during EM "\
                    "updates")
                raise KeyboardInterrupt
            self.sequence_updates_callback(updates)
            total_logprob += updates[-1]
        return total_logprob

    def train(self, emissions, logger=None):
        """
        Performs unsupervised training using Baum-Welch EM.

        This is performed as a retraining step on a model that has already
        been initialized.

        This is based on the training procedure in NLTK for HMMs:
        C{nltk.tag.hmm.HiddenMarkovModelTrainer.train_unsupervised}.

        @type emissions: list of lists of emissions
        @param emissions: training data. Each element is a list of
            emissions representing a sequence in the training data.
            Each emission is an emission like those used for
            C{emission_log_probability} on the model
        @type logger: logging.Logger
        @param logger: a logger to send progress logging to
        @return: the retrained (mutable) model, also left in self.model
        @raise BaumWelchTrainingError: if the split option is less than 2

        """
        if logger is None:
            from jazzparser.utils.loggers import create_dummy_logger
            logger = create_dummy_logger()
        # Kept so that sequence_updates_callback can report to it
        self._logger = logger

        self.record_history("Beginning Baum-Welch training on %s" % get_host_info_string())
        self.record_history("Training on %d inputs (with %s segments)" % \
            (len(emissions), ", ".join("%d" % len(seq) for seq in emissions)))
        logger.info("Beginning Baum-Welch training on %s" % get_host_info_string())

        # Get some options out of the module options
        max_iterations = self.options['max_iterations']
        convergence_logprob = self.options['convergence_logprob']
        split_length = self.options['split']
        truncate_length = self.options['truncate']
        save_intermediate = self.options['save_intermediate']
        processes = self.options['trainprocs']

        # Fragments overlap by one emission, so a split length below 2
        # would never shrink the remaining input
        if split_length is not None and split_length < 2:
            raise BaumWelchTrainingError("split length must be at least 2, "\
                "got %d" % split_length)

        # Make a mutable version of the model that we can update each iteration
        self.model = self.create_mutable_model(self.model)
        # Getting the array id mappings
        array_ids = self.get_array_indices()

        ########## Data preprocessing
        logger.info("%d input sequences" % len(emissions))
        # Truncate long streams
        if truncate_length is not None:
            logger.info("Truncating sequences to max %d timesteps" % \
                                                        truncate_length)
            emissions = [stream[:truncate_length] for stream in emissions]
        # After this, each stream is a tuple (first,stream), where first
        # indicates whether the stream segment begins a song
        if split_length is not None:
            logger.info("Splitting sequences into max %d-sized chunks" % \
                                                        split_length)
            split_emissions = []
            for emstream in emissions:
                split_emissions.extend(
                    self._split_sequence(emstream, split_length))
        else:
            # All streams begin a song
            split_emissions = [(True, stream) for stream in emissions]
        logger.info("Sequence lengths after preprocessing: %s" %
                " ".join([str(len(em[1])) for em in split_emissions]))
        ##########

        # Special case of -1 for number of sequences
        # No point in creating more processes than there are sequences
        if processes == -1 or processes > len(split_emissions):
            processes = len(split_emissions)

        iteration = 0
        last_logprob = None
        while iteration < max_iterations:
            logger.info("Beginning iteration %d" % iteration)

            # Build a tuple of the arrays that will be updated by each sequence
            self.global_arrays = self.get_empty_arrays()

            # Only use a process pool if there's more than one process
            if processes > 1:
                current_logprob = self._parallel_iteration(
                    split_emissions, array_ids, processes, iteration, logger)
            else:
                current_logprob = self._sequential_iteration(
                    split_emissions, array_ids, iteration, logger)

            ######## Model updates
            self.update_model(self.global_arrays, array_ids)
            # Clear the model's cache so we get the new probabilities
            self.model.clear_cache()

            logger.info("Training data log prob: %s" % current_logprob)
            if last_logprob is not None:
                if current_logprob < last_logprob:
                    logger.error("Log probability dropped by %s" % \
                                    (last_logprob - current_logprob))
                logger.info("Log prob change: %s" % \
                                    (current_logprob - last_logprob))
                # Check whether the log probability has converged
                if abs(current_logprob - last_logprob) < convergence_logprob:
                    logger.info("Distribution has converged: ceasing training")
                    break

            iteration += 1
            last_logprob = current_logprob

            # Only save if we've been asked to save between iterations
            if save_intermediate:
                self.save()

        self.record_history("Completed Baum-Welch training")
        # Always save the model now that we're done
        self.save()
        return self.model
class BaumWelchTrainingError(Exception):
    """
    Raised by L{BaumWelchTrainer} when Baum-Welch training cannot proceed,
    e.g. because the model given for retraining is unsuitable.
    """