1 """Unsupervised EM training for HMMs that use
2 L{jazzparser.utils.nltk.ngram.NgramModel} as their base implementation.
3 This is a generic implementation of the Baum-Welch algorithm for EM training
4 of HMMs. C{BaumWelchTrainer} should be subclassed to override anything that
5 needs to be customized for the model type.
6
7 """
8 """
9 ============================== License ========================================
10 Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding
11
12 This file is part of The Jazz Parser.
13
14 The Jazz Parser is free software: you can redistribute it and/or modify
15 it under the terms of the GNU General Public License as published by
16 the Free Software Foundation, either version 3 of the License, or
17 (at your option) any later version.
18
19 The Jazz Parser is distributed in the hope that it will be useful,
20 but WITHOUT ANY WARRANTY; without even the implied warranty of
21 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
22 GNU General Public License for more details.
23
24 You should have received a copy of the GNU General Public License
25 along with The Jazz Parser. If not, see <http://www.gnu.org/licenses/>.
26
27 ============================ End license ======================================
28
29 """
30 __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"
31
32 import numpy
33 from numpy import float64, sum as array_sum, zeros, log2, add as array_add
34 from multiprocessing import Pool
35
36 from jazzparser.utils.nltk.probability import logprob
37 from jazzparser.utils.nltk.ngram import NgramModel, DictionaryHmmModel
38 from jazzparser.utils.options import ModuleOption
39 from jazzparser.utils.system import get_host_info_string
40 from jazzparser.utils.strings import str_to_bool
41 from jazzparser import settings
42
43
# Additive smoothing constant. update_model adds it to every expected count
# (and the matching multiple of it to each denominator), so that a transition
# or emission that received no expected count still ends up with a small
# non-zero probability.
ADD_SMALL = 1e-06
45
def sequence_updates(sequence, last_model, empty_arrays, array_ids, update_initial=True):
    """
    Computes the expected-count updates contributed by a single sequence.

    Runs the forward/backward computations for C{sequence} under
    C{last_model} (the model produced by the previous iteration) and adds
    the expected transition counts (from xi) and expected emission counts
    (from gamma) into the transition and emission accumulators taken from
    C{empty_arrays}. Those two arrays are modified in place and returned.

    This is a top-level function rather than a method so that it can be
    handed to worker processes and run in parallel for each sequence. The
    trainer combines the results from all sequences afterwards and then
    updates the model.

    @param sequence: the emission sequence to evaluate
    @param last_model: the model from the previous iteration. It must provide
        C{normal_forward_probabilities}, C{gamma_probabilities},
        C{compute_xi} and C{label_dom}.
    @param empty_arrays: tuple (trans, ems, trans_denom, ems_denom) of
        zeroed accumulators. trans and ems are filled in place. The two
        denominators passed in are ignored and recomputed.
    @param array_ids: tuple (state_ids, em_ids) of dicts mapping states and
        emissions to array indices
    @type update_initial: bool
    @param update_initial: whether this sequence may contribute to the
        initial state distribution. It is False when the sequence is a
        non-initial fragment of a longer sequence. This default
        implementation computes no initial-state updates, so it accepts the
        flag but does not use it. Implementations for other model types may
        use it.
    @rtype: tuple
    @return: (trans, ems, trans_denom, ems_denom, seq_logprob). The
        denominators are the row sums of trans and ems. Returns None if a
        KeyboardInterrupt occurs while evaluating.

    """
    try:
        trans, ems, _, _ = empty_arrays
        state_ids, em_ids = array_ids

        # The forward pass also gives the log probability of the sequence
        forward, seq_logprob = last_model.normal_forward_probabilities(
            sequence, seq_prob=True)
        # State occupation probabilities at each timestep
        gamma = last_model.gamma_probabilities(sequence, forward=forward)
        # Transition probabilities between each pair of states at each timestep
        xi = last_model.compute_xi(sequence)

        states = last_model.label_dom
        final_t = len(sequence) - 1

        for t in range(final_t + 1):
            for state in states:
                i = state_ids[state]
                # No transition leaves the final timestep
                if t != final_t:
                    outgoing = xi[t][i]
                    for next_state in states:
                        j = state_ids[next_state]
                        trans[i][j] += outgoing[j]
                # Expected count of this state emitting the observation at time t
                ems[i][em_ids[sequence[t]]] += gamma[t][i]

        # The denominators are the row totals of the expected counts
        return (trans, ems,
                array_sum(trans, axis=1), array_sum(ems, axis=1),
                seq_logprob)
    except KeyboardInterrupt:
        return
105
108 """
109 Class with methods to retrain an HMM using the Baum-Welch EM algorithm.
110
111 Note that although the default implementation is for a plain
112 L{jazzparser.utils.nltk.ngram.NgramModel}, Baum-Welch training only makes
113 sense if the model is an HMM. It will therefore complain if the order is
113 not 2 or if there's a backoff model.
115
116 Module options must be processed externally. This allows them to be
117 combined with other options as appropriate. The options defined here
118 are a standard set of options for generic training and should be processed
119 before the trainer is instantiated.
120
121 This is designed as a generic implementation of the algorithm. To use it
122 for a special kind of model (e.g. one with a non-standard transition
123 distribution), you need to override certain methods to make them
124 appropriate to the model:
125 - C{create_mutable_model}
126 - C{update_model}
127 - C{sequence_updates}
128 - C{get_empty_arrays}
129 - C{sequence_updates_callback}
130 - C{get_array_indices}
131
132 The generic version of the trainer can be used to train a
133 DictionaryHmmModel. Subclasses are used to train other model types.
134
135 """
# Standard options for generic Baum-Welch training. The trainer does not
# process these itself: they are processed externally (possibly combined
# with other options) before the trainer is instantiated, and train() reads
# the processed values from self.options.
OPTIONS = [
    ModuleOption('max_iterations', filter=int,
        help_text=("Number of training iterations to give up after "
                   "if we don't reach convergence before."),
        usage="max_iterations=N, where N is an integer",
        default=100),
    ModuleOption('convergence_logprob', filter=float,
        help_text=("Difference in overall log probability of the "
                   "training data made by one iteration after which we "
                   "consider the training to have converged."),
        usage=("convergence_logprob=X, where X is a small floating "
               "point number (e.g. 1e-3)"),
        default=1e-3),
    # No default is given. train() only splits sequences when the processed
    # value is not None.
    ModuleOption('split', filter=int,
        help_text=("Limits the length of inputs by splitting them into "
                   "fragments of this length, which must be at least 2. "
                   "Consecutive fragments overlap by one timestep, so "
                   "transitions across a split are still counted, and a "
                   "very short final remainder is appended to the previous "
                   "fragment instead of forming a fragment of its own. The "
                   "initial state distribution will only be updated for the "
                   "initial fragments."),
        usage="split=X, where X is an int (at least 2)"),
    ModuleOption('truncate', filter=int,
        help_text=("Limits the length of inputs by truncating them to this "
                   "number of timesteps. Truncation is applied before "
                   "splitting."),
        usage="truncate=X, where X is an int"),
    ModuleOption('save_intermediate', filter=str_to_bool,
        help_text="Save the model between iterations",
        usage=("save_intermediate=B, where B is 'true' or 'false' "
               "(default true)"),
        default=True),
    ModuleOption('trainprocs', filter=int,
        help_text=("Number of processes to spawn during training. Use -1 "
                   "to spawn a process for every sequence."),
        usage="trainprocs=P, where P is an integer",
        default=1),
]
167
183
184 @classmethod
192
def record_history(self, line):
    """
    Hook for keeping a record of training steps.

    The trainer calls this with a one-line description at notable points
    during training. This base implementation discards the line. Override
    it in a subclass to store the line, for example in the model's history.

    @type line: str
    @param line: description of the training step

    """
    return None
203
# The per-sequence update function, exposed so that callers can use
# self.sequence_updates. Subclasses should override it the same way, not by
# defining a static method in the class body. The function is handed to
# worker processes, so it must be picklable, which requires a top-level
# function. Define one at module level and point this attribute at it with
# staticmethod(), as done here for the default implementation.
sequence_updates = staticmethod(sequence_updates)
213
215 """
216 Creates a mutable version of the given model. This mutable version
217 will be the model that receives updates during training, as defined
218 by L{update_model}.
219
220 """
221 return DictionaryHmmModel.from_ngram_model(model, mutable=True)
222
224 """
225 Creates empty arrays to hold the accumulated probabilities during
226 training. The sizes will depend on self.model.
227
228 """
229 num_states = len(self.model.label_dom)
230 trans = zeros((num_states, num_states), float64)
231 trans_denom = zeros((num_states, ), float64)
232
233 num_ems = len(self.model.emission_dom)
234 ems = zeros((num_states, num_ems), float64)
235 ems_denom = zeros((num_states, ), float64)
236 return (trans, ems, trans_denom, ems_denom)
237
239 """
240 Returns a tuple of the dicts that map labels, emissions, etc to the
241 indices of arrays to which they correspond. These will need to be
242 different for non-standard models.
243
244 """
245 state_ids = dict([(state,id) for (id,state) in \
246 enumerate(self.model.label_dom)])
247 em_ids = dict([(em,id) for (id,em) in \
248 enumerate(self.model.emission_dom)])
249 return (state_ids, em_ids)
250
252 """
253 Callback for the sequence_updates processes that takes
254 the updates from a single sequence and adds them onto
255 the global update accumulators.
256
257 The accumulators are stored as self.global_arrays.
258
259 """
260 if result is None:
261
262 logger.warning("Child process was cancelled")
263 return
264
265
266 (trans_local, ems_local, \
267 trans_denom_local, ems_denom_local, \
268 seq_logprob) = result
269
270 (trans, ems,
271 trans_denom, ems_denom) = self.global_arrays
272
273
274
275
276 array_add(ems, ems_local, ems)
277
278 array_add(trans, trans_local, trans)
279
280 array_add(ems_denom, ems_denom_local, ems_denom)
281 array_add(trans_denom, trans_denom_local, trans_denom)
282
284 """
285 Replaces the distributions of the saved model with the probabilities
286 taken from the arrays of updates. self.model is expected to be
287 made up of mutable distributions when this is called.
288
289 """
290 trans, ems, trans_denom, ems_denom = arrays
291 state_ids, em_ids = array_ids
292 num_states = len(self.model.label_dom)
293 num_emissions = len(self.model.emission_dom)
294
295 for state in self.model.label_dom:
296
297 state_i = state_ids[state]
298 denom = trans_denom[state_i]
299
300 for next_state in self.model.label_dom:
301 state_j = state_ids[next_state]
302
303 prob = logprob(trans[state_i][state_j] + ADD_SMALL) - \
304 logprob(trans_denom[state_i] + num_states*ADD_SMALL)
305 self.model.label_dist[(state,)].update(next_state, prob)
306
307 for emission in self.model.emission_dom:
308
309 prob = logprob(ems[state_i][em_ids[emission]] + ADD_SMALL) - \
310 logprob(ems_denom[state_i] + num_emissions*ADD_SMALL)
311 self.model.emission_dist[state].update(emission, prob)
312
314 """
315 Saves the model in self.model to disk. This may be called at the end
316 of each iteration and will be called at the end of the whole training
317 process.
318
319 By default, does nothing. You don't have to put something in here,
320 but you'll need to override this if you want the model to be saved
321 during training before it gets returned at the end.
322
323 """
324 return
325
def train(self, emissions, logger=None):
    """
    Performs unsupervised training using Baum-Welch EM.

    This is performed as a retraining step on a model that has already
    been initialized. self.model is replaced by a mutable version of itself
    (see C{create_mutable_model}). That version is re-estimated once per
    iteration until the change in the total log probability of the
    training data falls below the C{convergence_logprob} option, or until
    C{max_iterations} iterations have run.

    This is based on the training procedure in NLTK for HMMs:
    C{nltk.tag.hmm.HiddenMarkovModelTrainer.train_unsupervised}.

    @type emissions: list of lists of emissions
    @param emissions: training data. Each element is a list of
        emissions representing a sequence in the training data.
        Each emission is an emission like those used for
        C{emission_log_probability} on the model
    @type logger: logging.Logger
    @param logger: a logger to send progress logging to
    @return: the trained model, which is also left in self.model
    @raise ValueError: if the C{split} option is less than 2. Consecutive
        fragments overlap by one timestep, so shorter fragments would
        never advance through a sequence.
    @raise KeyboardInterrupt: if training is interrupted

    """
    if logger is None:
        from jazzparser.utils.loggers import create_dummy_logger
        logger = create_dummy_logger()

    # Read the training options (see OPTIONS)
    max_iterations = self.options['max_iterations']
    convergence_logprob = self.options['convergence_logprob']
    split_length = self.options['split']
    truncate_length = self.options['truncate']
    save_intermediate = self.options['save_intermediate']
    processes = self.options['trainprocs']

    # Validate before anything is recorded or the model is touched. With
    # split < 2, splitting would loop forever.
    if split_length is not None and split_length < 2:
        raise ValueError("split option must be at least 2, got %s" % \
            split_length)

    self.record_history("Beginning Baum-Welch training on %s" % get_host_info_string())
    self.record_history("Training on %d inputs (with %s segments)" % \
        (len(emissions), ", ".join("%d" % len(seq) for seq in emissions)))
    logger.info("Beginning Baum-Welch training on %s" % get_host_info_string())

    # Training updates distributions in place, so work on a mutable version
    self.model = self.create_mutable_model(self.model)
    array_ids = self.get_array_indices()

    logger.info("%d input sequences" % len(emissions))

    if truncate_length is not None:
        logger.info("Truncating sequences to max %d timesteps" % \
            truncate_length)
        emissions = [stream[:truncate_length] for stream in emissions]

    if split_length is not None:
        logger.info("Splitting sequences into max %d-sized chunks" % \
            split_length)
        split_emissions = self._split_emissions(emissions, split_length)
    else:
        # An unsplit sequence is its own initial fragment
        split_emissions = [(True, stream) for stream in emissions]
    logger.info("Sequence lengths after preprocessing: %s" %
        " ".join([str(len(em[1])) for em in split_emissions]))

    if not any(len(sequence) for (_, sequence) in split_emissions):
        # With no data every expected count is zero, and the default
        # update_model would replace every distribution with a uniform one
        logger.warning("No non-empty training sequences: the model has "\
            "not been retrained")
        self.record_history("Skipped Baum-Welch training: no training data")
        self.save()
        return self.model

    # Never create more processes than there are sequences to evaluate
    if processes == -1 or processes > len(split_emissions):
        processes = len(split_emissions)

    iteration = 0
    last_logprob = None
    while iteration < max_iterations:
        logger.info("Beginning iteration %d" % iteration)

        # Fresh global accumulators, filled in by sequence_updates_callback()
        self.global_arrays = self.get_empty_arrays()

        if processes > 1:
            current_logprob = self._accumulate_with_pool(
                split_emissions, array_ids, processes, iteration, logger)
        else:
            current_logprob = self._accumulate_sequentially(
                split_emissions, array_ids, iteration, logger)

        # Re-estimate the distributions from the accumulated counts
        self.update_model(self.global_arrays, array_ids)
        # Any cached probabilities were computed from the old distributions
        self.model.clear_cache()

        logger.info("Training data log prob: %s" % current_logprob)
        if last_logprob is not None:
            if current_logprob < last_logprob:
                # Logged as an error: the training data's log probability is
                # not expected to fall between EM iterations
                logger.error("Log probability dropped by %s" % \
                    (last_logprob - current_logprob))
            logger.info("Log prob change: %s" % \
                (current_logprob - last_logprob))
            if abs(current_logprob - last_logprob) < convergence_logprob:
                logger.info("Distribution has converged: ceasing training")
                break

        iteration += 1
        last_logprob = current_logprob

        if save_intermediate:
            self.save()

    self.record_history("Completed Baum-Welch training")
    self.save()
    return self.model

@staticmethod
def _split_emissions(emissions, split_length):
    """
    Splits each emission sequence into fragments of split_length timesteps.

    Consecutive fragments of a sequence overlap by one timestep, so the
    transition across each split point is still counted. What is left after
    the last full-length fragment becomes a final fragment, which also
    overlaps the previous one. If that remainder is very short (at most
    split_length/5 timesteps, counting the overlapping one), only its new
    timesteps are appended to the previous fragment instead. The overlapping
    timestep is never duplicated.

    @type split_length: int
    @param split_length: fragment length. Must be at least 2.
    @return: list of (is_initial, fragment) pairs. is_initial is True only
        for the first fragment of each input sequence. Empty input sequences
        produce no fragments.

    """
    split_emissions = []
    step = split_length - 1
    for stream in emissions:
        ems = list(stream)
        fragments = []
        start = 0
        while len(ems) - start >= split_length:
            fragments.append((start == 0, ems[start:start + split_length]))
            start += step
        remainder = ems[start:]
        if not fragments:
            # The whole sequence is shorter than split_length
            if remainder:
                fragments.append((True, remainder))
        elif len(remainder) > 1:
            # remainder[0] is the last timestep of the previous fragment
            if len(remainder) <= split_length / 5:
                fragments[-1][1].extend(remainder[1:])
            else:
                fragments.append((False, remainder))
        split_emissions.extend(fragments)
    return split_emissions

def _accumulate_with_pool(self, split_emissions, array_ids, processes,
                          iteration, logger):
    """
    Evaluates the sequences in a process pool and accumulates their
    updates into self.global_arrays. Returns the total log probability.

    """
    logger.info("Creating a pool of %d processes" % processes)
    pool = Pool(processes=processes)

    def _notifier(seq_index):
        def _notify(result):
            logger.info("Sequence %d finished" % seq_index)
        return _notify

    async_results = []
    try:
        for seq_i, (first, sequence) in enumerate(split_emissions):
            logger.info("Iteration %d, sequence %d" % (iteration, seq_i))
            if len(sequence) == 0:
                continue
            async_results.append(
                pool.apply_async(self.sequence_updates,
                    (sequence, self.model, self.get_empty_arrays(), array_ids),
                    {'update_initial': first},
                    _notifier(seq_i)))
        pool.close()
        pool.join()
    except KeyboardInterrupt:
        logger.info("Keyboard interrupt was received during EM "\
            "updates")
        # Stop the workers instead of leaving them running
        pool.terminate()
        raise

    total_logprob = 0.0
    for res in async_results:
        res_tuple = res.get()
        if res_tuple is None:
            # sequence_updates returns None when its worker was interrupted
            logger.warning("Child process was cancelled")
            continue
        self.sequence_updates_callback(res_tuple)
        total_logprob += res_tuple[-1]
    return total_logprob

def _accumulate_sequentially(self, split_emissions, array_ids, iteration,
                             logger):
    """
    Evaluates the sequences one by one in this process and accumulates
    their updates into self.global_arrays. Returns the total log probability.

    """
    if len(split_emissions) == 1:
        logger.info("One sequence: not using a process pool")
    else:
        logger.info("Not using a process pool: training %d "\
            "emission sequences sequentially" % \
            len(split_emissions))

    total_logprob = 0.0
    for seq_i, (first, sequence) in enumerate(split_emissions):
        if len(sequence) == 0:
            continue
        logger.info("Iteration %d, sequence %d" % (iteration, seq_i))
        updates = self.sequence_updates(
            sequence, self.model,
            self.get_empty_arrays(), array_ids,
            update_initial=first)
        if updates is None:
            # sequence_updates swallowed a KeyboardInterrupt in this process.
            # Re-raise it so that training stops.
            logger.info("Keyboard interrupt was received during EM updates")
            raise KeyboardInterrupt
        self.sequence_updates_callback(updates)
        total_logprob += updates[-1]
    return total_logprob
525
529