1 """Unsupervised EM training for Raphael and Stoddard's chord labelling model.
2
3 """
4 """
5 ============================== License ========================================
6 Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding
7
8 This file is part of The Jazz Parser.
9
10 The Jazz Parser is free software: you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation, either version 3 of the License, or
13 (at your option) any later version.
14
15 The Jazz Parser is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with The Jazz Parser. If not, see <http://www.gnu.org/licenses/>.
22
23 ============================ End license ======================================
24
25 """
26 __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"
27
28
29 import numpy, os
30 from numpy import ones, float64, sum as array_sum, zeros, log2, add as array_add
31 import cPickle as pickle
32 from multiprocessing import Pool
33
34 from jazzparser.utils.nltk.probability import mle_estimator, logprob, add_logs, \
35 sum_logs, prob_dist_to_dictionary_prob_dist, \
36 cond_prob_dist_to_dictionary_cond_prob_dist
37 from jazzparser.utils.options import ModuleOption
38 from jazzparser.utils.system import get_host_info_string
39 from jazzparser import settings
40 from . import constants, RaphstoHmm, RaphstoHmmThreeChord, RaphstoHmmFourChord, \
41 RaphstoHmmUnigram, RaphstoHmmParameterError
42
43 from nltk.probability import ConditionalProbDist, FreqDist, \
44 ConditionalFreqDist, DictionaryProbDist, \
45 DictionaryConditionalProbDist, MutableProbDist
46
47
# Pseudo-count added to every expected count (and, multiplied by the number
# of possible outcomes, to the matching denominator) when the distributions
# are re-estimated, so that no outcome is given a zero probability.
ADD_SMALL = 1e-6
49
50 -def _sequence_updates(sequence, last_model, label_dom, state_ids, mode_ids, \
51 chord_ids, beat_ids, d_ids, d_func):
52 """
53 Evaluates the forward/backward probability matrices for a
54 single sequence under the model that came from the previous
55 iteration and returns matrices that contain the updates
56 to be made to the distributions during this iteration.
57
58 This is wrapped up in a function so it can be run in
59 parallel for each sequence. Once all sequences have been
60 evaluated, the results are combined and model updated.
61
62 """
63 num_chords = len(chord_ids)
64 num_beats = len(beat_ids)
65 num_modes = len(mode_ids)
66 num_ds = len(d_ids)
67 num_ktrans = 12
68
69
70
71
72
73 ctrans_local = zeros((num_chords,num_chords), float64)
74 ems_local = zeros((num_beats,num_ds), float64)
75 ktrans_local = zeros((num_modes,num_ktrans,num_modes), float64)
76 uni_chords_local = zeros(num_chords, float64)
77
78
79 alpha,scale,seq_logprob = last_model.normal_forward_probabilities(sequence)
80 beta,scale = last_model.normal_backward_probabilities(sequence)
81
82
83 gamma = last_model.compute_gamma(sequence, alpha, beta)
84
85 xi = last_model.compute_xi(sequence, alpha, beta)
86
87 T = len(sequence)
88
89 for time in range(T):
90 for state in label_dom:
91 tonic,mode,chord = state
92 state_i = state_ids[state]
93 mode_i = mode_ids[mode]
94
95 if time < T-1:
96
97
98 for next_state in label_dom:
99 ntonic,nmode,nchord = next_state
100 state_j = state_ids[next_state]
101 mode_j = mode_ids[nmode]
102
103
104 tonic_change = (ntonic - tonic) % 12
105 ktrans_local[mode_i][tonic_change][mode_j] += \
106 xi[time][state_i][state_j]
107
108
109 chord_i, chord_j = chord_ids[chord], chord_ids[nchord]
110 if tonic == ntonic and mode == nmode:
111
112 ctrans_local[chord_i][chord_j] += xi[time][state_i][state_j]
113 else:
114 uni_chords_local[chord_j] += xi[time][state_i][state_j]
115
116
117
118
119 for pc,beat in sequence[time]:
120 beat_i = beat_ids[beat]
121 d = d_func(pc, state)
122 d_i = d_ids[d]
123
124 ems_local[beat_i][d_i] += gamma[time][state_i]
125
126
127 ctrans_denom_local = array_sum(ctrans_local, axis=1)
128 ems_denom_local = array_sum(ems_local, axis=1)
129 ktrans_denom_local = array_sum(array_sum(ktrans_local, axis=2), axis=1)
130 uni_chords_denom_local = array_sum(uni_chords_local)
131
132
133 return (ktrans_local, ctrans_local, ems_local, \
134 uni_chords_local, \
135 ktrans_denom_local, ctrans_denom_local, \
136 ems_denom_local, uni_chords_denom_local, \
137 seq_logprob)
138
139
140
142 """
143 Class with methods to retrain a Raphsto model using the Baum-Welch
144 EM algorithm.
145
146 """
147 OPTIONS = [
148 ModuleOption('max_iterations', filter=int,
149 help_text="Number of training iterations to give up after "\
150 "if we don't reach convergence before.",
151 usage="max_iterations=N, where N is an integer", default=100),
152 ModuleOption('convergence_logprob', filter=float,
153 help_text="Difference in overall log probability of the "\
154 "training data made by one iteration after which we "\
155 "consider the training to have converged.",
156 usage="convergence_logprob=X, where X is a small floating "\
157 "point number (e.g. 1e-3)", default=1e-3),
158 ]
159 MODEL_TYPES = [
160 RaphstoHmm,
161 RaphstoHmmThreeChord,
162 RaphstoHmmFourChord
163 ]
164
165
175
176 - def train(self, emissions, max_iterations=None, \
177 convergence_logprob=None, logger=None, processes=1,
178 save=True, save_intermediate=False):
179 """
180 Performs unsupervised training using Baum-Welch EM.
181
182 This is an instance method, because it is performed on a model
183 that has already been initialized. You might, for example,
184 create such a model using C{initialize_chord_types}.
185
186 This is based on the training procedure in NLTK for HMMs:
187 C{nltk.tag.hmm.HiddenMarkovModelTrainer.train_unsupervised}.
188
189 @type emissions: list of lists of emissions
190 @param emissions: training data. Each element is a list of
191 emissions representing a sequence in the training data.
192 Each emission is an emission like those used for
193 L{jazzparser.misc.raphsto.RaphstoHmm.emission_log_probability},
194 i.e. a list of note
195 observations
196 @type max_iterations: int
197 @param max_iterations: maximum number of iterations to allow
198 for EM (default 100). Overrides the corresponding
199 module option
200 @type convergence_logprob: float
201 @param convergence_logprob: maximum change in log probability
202 to consider convergence to have been reached (default 1e-3).
203 Overrides the corresponding module option
204 @type logger: logging.Logger
205 @param logger: a logger to send progress logging to
206 @type processes: int
207 @param processes: number processes to spawn. A pool of this
208 many processes will be used to compute distribution updates
209 for sequences in parallel during each iteration.
210 @type save: bool
211 @param save: save the model at the end of training
212 @type save_intermediate: bool
213 @param save_intermediate: save the model after each iteration. Implies
214 C{save}
215
216 """
217 from . import raphsto_d
218 if logger is None:
219 from jazzparser.utils.loggers import create_dummy_logger
220 logger = create_dummy_logger()
221
222 if save_intermediate:
223 save = True
224
225
226 if processes > len(emissions):
227 processes = len(emissions)
228
229 self.model.add_history("Beginning Baum-Welch training on %s" % get_host_info_string())
230 self.model.add_history("Training on %d sequences (with %s chords)" % \
231 (len(emissions), ", ".join("%d" % len(seq) for seq in emissions)))
232
233
234 if max_iterations is None:
235 max_iterations = self.options['max_iterations']
236 if convergence_logprob is None:
237 convergence_logprob = self.options['convergence_logprob']
238
239
240 chord_ids = dict((crd,num) for (num,crd) in \
241 enumerate(self.model.chord_transition_dom))
242 num_chords = len(chord_ids)
243
244 state_ids = dict((state,num) for (num,state) in \
245 enumerate(self.model.label_dom))
246
247
248
249 beat_ids = dict((beat,num) for (num,beat) in \
250 enumerate(self.model.beat_dom))
251 num_beats = len(beat_ids)
252
253 d_ids = dict((d,num) for (num,d) in \
254 enumerate(self.model.emission_dist_dom))
255 num_ds = len(d_ids)
256
257
258 mode_ids = dict((m,num) for (num,m) in enumerate(constants.MODES))
259 num_modes = len(mode_ids)
260
261 num_ktrans = 12
262
263
264
265 emission_mdist = DictionaryConditionalProbDist(
266 dict((s, MutableProbDist(self.model.emission_dist[s],
267 self.model.emission_dist_dom))
268 for s in self.model.emission_dist.conditions()))
269 key_mdist = DictionaryConditionalProbDist(
270 dict((s, MutableProbDist(self.model.key_transition_dist[s],
271 self.model.key_transition_dom))
272 for s in self.model.key_transition_dist.conditions()))
273 chord_mdist = DictionaryConditionalProbDist(
274 dict((s, MutableProbDist(self.model.chord_transition_dist[s],
275 self.model.chord_transition_dom))
276 for s in self.model.chord_transition_dist.conditions()))
277 chord_uni_mdist = MutableProbDist(self.model.chord_dist,
278 self.model.chord_transition_dom)
279
280
281
282 model = self.model_cls(key_mdist,
283 chord_mdist,
284 emission_mdist,
285 chord_uni_mdist,
286 chord_set=self.model.chord_set)
287
288 iteration = 0
289 last_logprob = None
290 while iteration < max_iterations:
291 logger.info("Beginning iteration %d" % iteration)
292 current_logprob = 0.0
293
294
295
296
297
298
299 ctrans = zeros((num_chords,num_chords), float64)
300
301
302
303
304 ems = zeros((num_beats,num_ds), float64)
305
306
307
308
309 ktrans = zeros((num_modes,num_ktrans,num_modes), float64)
310
311
312 uni_chords = zeros(num_chords, float64)
313
314 ctrans_denom = zeros(num_chords, float64)
315 ems_denom = zeros(num_beats, float64)
316 ktrans_denom = zeros(num_modes, float64)
317
318
319 uni_chords_denom = zeros(1, float64)
320
321 def _training_callback(result):
322 """
323 Callback for the _sequence_updates processes that takes
324 the updates from a single sequence and adds them onto
325 the global update accumulators.
326
327 """
328
329 (ktrans_local, ctrans_local, ems_local, uni_chords_local, \
330 ktrans_denom_local, ctrans_denom_local, ems_denom_local, \
331 uni_chords_denom_local, \
332 seq_logprob) = result
333
334
335
336
337 array_add(ems, ems_local, ems)
338
339 array_add(ktrans, ktrans_local, ktrans)
340
341 array_add(ctrans, ctrans_local, ctrans)
342
343 array_add(uni_chords, uni_chords_local, uni_chords)
344
345 array_add(ems_denom, ems_denom_local, ems_denom)
346 array_add(ktrans_denom, ktrans_denom_local, ktrans_denom)
347 array_add(ctrans_denom, ctrans_denom_local, ctrans_denom)
348 array_add(uni_chords_denom, uni_chords_denom_local, uni_chords_denom)
349
350
351
352
353 if processes > 1:
354
355 logger.info("Creating a pool of %d processes" % processes)
356 pool = Pool(processes=processes)
357
358 async_results = []
359 for seq_i,sequence in enumerate(emissions):
360 logger.info("Iteration %d, sequence %d" % (iteration, seq_i))
361 T = len(sequence)
362 if T == 0:
363 continue
364
365
366 async_results.append(
367 pool.apply_async(_sequence_updates,
368 (sequence, model,
369 self.model.label_dom,
370 state_ids, mode_ids, chord_ids,
371 beat_ids, d_ids, raphsto_d),
372 callback=_training_callback) )
373 pool.close()
374
375 pool.join()
376
377
378
379 for res in async_results:
380
381
382 res_tuple = res.get()
383
384 current_logprob += res_tuple[8]
385 else:
386 logger.info("One sequence: not using a process pool")
387 sequence = emissions[0]
388
389 if len(sequence) > 0:
390 updates = _sequence_updates(
391 sequence, model,
392 self.model.label_dom,
393 state_ids, mode_ids, chord_ids,
394 beat_ids, d_ids, raphsto_d)
395 _training_callback(updates)
396
397 current_logprob = updates[8]
398
399
400 for beat in self.model.beat_dom:
401 denom = ems_denom[beat_ids[beat]]
402 for d in self.model.emission_dist_dom:
403 if denom == 0.0:
404
405 prob = - logprob(len(d_ids))
406 else:
407 prob = logprob(ems[beat_ids[beat]][d_ids[d]] + ADD_SMALL) - logprob(denom + len(d_ids)*ADD_SMALL)
408 model.emission_dist[beat].update(d, prob)
409
410 for mode0 in mode_ids.keys():
411 mode_i = mode_ids[mode0]
412 denom = ktrans_denom[mode_ids[mode0]]
413 for key in range(num_ktrans):
414 for mode1 in mode_ids.keys():
415 mode_j = mode_ids[mode1]
416 if denom == 0.0:
417
418 prob = - logprob(num_ktrans*num_modes)
419 else:
420 prob = logprob(ktrans[mode_i][key][mode_j] + ADD_SMALL) - logprob(denom + num_ktrans*num_modes*ADD_SMALL)
421 model.key_transition_dist[mode0].update(
422 (key,mode1), prob)
423
424 for chord0 in chord_ids.keys():
425 chord_i = chord_ids[chord0]
426 denom = ctrans_denom[chord_i]
427 for chord1 in chord_ids.keys():
428 chord_j = chord_ids[chord1]
429 if denom == 0.0:
430
431 prob = - logprob(num_chords)
432 else:
433 prob = logprob(ctrans[chord_i][chord_j] + ADD_SMALL) - logprob(denom + num_chords*ADD_SMALL)
434 model.chord_transition_dist[chord0].update(chord1, prob)
435 for chord in chord_ids.keys():
436 prob = logprob(uni_chords[chord_ids[chord]] + ADD_SMALL) - logprob(uni_chords_denom[0] + len(chord_ids)*ADD_SMALL)
437 model.chord_dist.update(chord, prob)
438
439
440 model.clear_cache()
441
442 logger.info("Training data log prob: %s" % current_logprob)
443 if last_logprob is not None and current_logprob < last_logprob:
444 logger.error("Log probability dropped by %s" % \
445 (last_logprob - current_logprob))
446 if last_logprob is not None:
447 logger.info("Log prob change: %s" % \
448 (current_logprob - last_logprob))
449
450 if iteration > 0 and \
451 abs(current_logprob - last_logprob) < convergence_logprob:
452
453 logger.info("Distribution has converged: ceasing training")
454 break
455
456 iteration += 1
457 last_logprob = current_logprob
458
459
460
461 self.update_model(model, save=save_intermediate)
462
463 self.model.add_history("Completed Baum-Welch training")
464
465 self.update_model(model, save=save)
466 return
467
489
490
491
492
495 """Same as L{_sequence_updates}, modified for unigram models. """
496 num_beats = len(beat_ids)
497 num_ds = len(d_ids)
498 num_ktrans = 12
499
500
501
502
503
504 ems_local = zeros((num_beats,num_ds), float64)
505
506
507 alpha,scale,seq_logprob = last_model.normal_forward_probabilities(sequence)
508 beta,scale = last_model.normal_backward_probabilities(sequence)
509
510
511 gamma = last_model.compute_gamma(sequence, alpha, beta)
512
513 xi = last_model.compute_xi(sequence, alpha, beta)
514
515 T = len(sequence)
516
517 for time in range(T):
518 for state in label_dom:
519 tonic,mode,chord = state
520 state_i = state_ids[state]
521
522
523
524
525
526 for pc,beat in sequence[time]:
527 beat_i = beat_ids[beat]
528 d = d_func(pc, state)
529 d_i = d_ids[d]
530
531 ems_local[beat_i][d_i] += gamma[time][state_i]
532
533
534 ems_denom_local = array_sum(ems_local, axis=1)
535
536
537 return (ems_local, ems_denom_local, seq_logprob)
538
539
541 """
542 Class with methods to retrain a Raphsto model using the Baum-Welch
543 EM algorithm.
544 Special trainer to train unigram models. That is, it doesn't update
545 the transition distribution.
546
547 """
548 MODEL_TYPES = [
549 RaphstoHmmUnigram,
550 ]
551
552
553 - def train(self, emissions, max_iterations=None, \
554 convergence_logprob=None, logger=None, processes=1,
555 save=True, save_intermediate=False):
556 """
557 Performs unsupervised training using Baum-Welch EM.
558
559 This is an instance method, because it is performed on a model
560 that has already been initialized. You might, for example,
561 create such a model using C{initialize_chord_types}.
562
563 This is based on the training procedure in NLTK for HMMs:
564 C{nltk.tag.hmm.HiddenMarkovModelTrainer.train_unsupervised}.
565
566 @type emissions: list of lists of emissions
567 @param emissions: training data. Each element is a list of
568 emissions representing a sequence in the training data.
569 Each emission is an emission like those used for
570 L{jazzparser.misc.raphsto.RaphstoHmm.emission_log_probability},
571 i.e. a list of note
572 observations
573 @type max_iterations: int
574 @param max_iterations: maximum number of iterations to allow
575 for EM (default 100). Overrides the corresponding
576 module option
577 @type convergence_logprob: float
578 @param convergence_logprob: maximum change in log probability
579 to consider convergence to have been reached (default 1e-3).
580 Overrides the corresponding module option
581 @type logger: logging.Logger
582 @param logger: a logger to send progress logging to
583 @type processes: int
584 @param processes: number processes to spawn. A pool of this
585 many processes will be used to compute distribution updates
586 for sequences in parallel during each iteration.
587 @type save: bool
588 @param save: save the model at the end of training
589 @type save_intermediate: bool
590 @param save_intermediate: save the model after each iteration. Implies
591 C{save}
592
593 """
594 from . import raphsto_d
595 if logger is None:
596 from jazzparser.utils.loggers import create_dummy_logger
597 logger = create_dummy_logger()
598
599 if save_intermediate:
600 save = True
601
602
603 if processes > len(emissions):
604 processes = len(emissions)
605
606 self.model.add_history("Beginning Baum-Welch unigram training on %s" % get_host_info_string())
607 self.model.add_history("Training on %d sequences (with %s chords)" % \
608 (len(emissions), ", ".join("%d" % len(seq) for seq in emissions)))
609
610
611 if max_iterations is None:
612 max_iterations = self.options['max_iterations']
613 if convergence_logprob is None:
614 convergence_logprob = self.options['convergence_logprob']
615
616
617 state_ids = dict((state,num) for (num,state) in \
618 enumerate(self.model.label_dom))
619
620
621
622 beat_ids = dict((beat,num) for (num,beat) in \
623 enumerate(self.model.beat_dom))
624 num_beats = len(beat_ids)
625
626 d_ids = dict((d,num) for (num,d) in \
627 enumerate(self.model.emission_dist_dom))
628 num_ds = len(d_ids)
629
630
631
632 emission_mdist = DictionaryConditionalProbDist(
633 dict((s, MutableProbDist(self.model.emission_dist[s],
634 self.model.emission_dist_dom))
635 for s in self.model.emission_dist.conditions()))
636
637
638 key_mdist = DictionaryConditionalProbDist({})
639 chord_mdist = DictionaryConditionalProbDist({})
640 chord_uni_mdist = MutableProbDist({}, [])
641
642
643
644 model = self.model_cls(key_mdist,
645 chord_mdist,
646 emission_mdist,
647 chord_uni_mdist,
648 chord_set=self.model.chord_set)
649
650 iteration = 0
651 last_logprob = None
652 while iteration < max_iterations:
653 logger.info("Beginning iteration %d" % iteration)
654 current_logprob = 0.0
655
656
657
658
659
660 ems = zeros((num_beats,num_ds), float64)
661
662 ems_denom = zeros(num_beats, float64)
663
664 def _training_callback(result):
665 """
666 Callback for the _sequence_updates processes that takes
667 the updates from a single sequence and adds them onto
668 the global update accumulators.
669
670 """
671
672 (ems_local, ems_denom_local, seq_logprob) = result
673
674
675
676
677 array_add(ems, ems_local, ems)
678
679 array_add(ems_denom, ems_denom_local, ems_denom)
680
681
682
683
684 if processes > 1:
685
686 logger.info("Creating a pool of %d processes" % processes)
687 pool = Pool(processes=processes)
688
689 async_results = []
690 for seq_i,sequence in enumerate(emissions):
691 logger.info("Iteration %d, sequence %d" % (iteration, seq_i))
692 T = len(sequence)
693 if T == 0:
694 continue
695
696
697 async_results.append(
698 pool.apply_async(_sequence_updates_uni,
699 (sequence, model,
700 self.model.label_dom,
701 state_ids,
702 beat_ids, d_ids, raphsto_d),
703 callback=_training_callback) )
704 pool.close()
705
706 pool.join()
707
708
709
710 for res in async_results:
711
712
713 res_tuple = res.get()
714
715 current_logprob += res_tuple[2]
716 else:
717 logger.info("One sequence: not using a process pool")
718 sequence = emissions[0]
719
720 if len(sequence) > 0:
721 updates = _sequence_updates_uni(
722 sequence, model,
723 self.model.label_dom,
724 state_ids,
725 beat_ids, d_ids, raphsto_d)
726 _training_callback(updates)
727
728 current_logprob = updates[2]
729
730
731 for beat in self.model.beat_dom:
732 denom = ems_denom[beat_ids[beat]]
733 for d in self.model.emission_dist_dom:
734 if denom == 0.0:
735
736 prob = - logprob(len(d_ids))
737 else:
738 prob = logprob(ems[beat_ids[beat]][d_ids[d]] + ADD_SMALL) - logprob(denom + len(d_ids)*ADD_SMALL)
739 model.emission_dist[beat].update(d, prob)
740
741
742 model.clear_cache()
743
744 logger.info("Training data log prob: %s" % current_logprob)
745 if last_logprob is not None and current_logprob < last_logprob:
746 logger.error("Log probability dropped by %s" % \
747 (last_logprob - current_logprob))
748 if last_logprob is not None:
749 logger.info("Log prob change: %s" % \
750 (current_logprob - last_logprob))
751
752 if iteration > 0 and \
753 abs(current_logprob - last_logprob) < convergence_logprob:
754
755 logger.info("Distribution has converged: ceasing training")
756 break
757
758 iteration += 1
759 last_logprob = current_logprob
760
761
762
763 self.update_model(model, save=save_intermediate)
764
765 self.model.add_history("Completed Baum-Welch unigram training")
766
767 self.update_model(model, save=save)
768 return
769
771 """
772 Replaces the distributions of the saved model with those of the given
773 model and saves it.
774
775 @type save: bool
776 @param save: save the model. Otherwise just updates the distributions.
777
778 """
779 self.model.emission_dist = \
780 cond_prob_dist_to_dictionary_cond_prob_dist(model.emission_dist)
781 if save:
782 self.model.save()
783