1 """HMM that underlies the chordclass MIDI tagger.
2
3 This is based on the model proposed by Raphael & Stoddard for harmonic analysis.
4 See L{jazzparser.misc.raphsto} for the pure implementation of their model,
5 which this builds on.
6
7 """
8 """
9 ============================== License ========================================
10 Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding
11
12 This file is part of The Jazz Parser.
13
14 The Jazz Parser is free software: you can redistribute it and/or modify
15 it under the terms of the GNU General Public License as published by
16 the Free Software Foundation, either version 3 of the License, or
17 (at your option) any later version.
18
19 The Jazz Parser is distributed in the hope that it will be useful,
20 but WITHOUT ANY WARRANTY; without even the implied warranty of
21 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
22 GNU General Public License for more details.
23
24 You should have received a copy of the GNU General Public License
25 along with The Jazz Parser. If not, see <http://www.gnu.org/licenses/>.
26
27 ============================ End license ======================================
28
29 """
30 __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"
31
32 import numpy, os, re, math, warnings
33 from numpy import ones, float64, sum as array_sum, zeros
34 import cPickle as pickle
35 from datetime import datetime
36
37 from jazzparser.utils.nltk.ngram import NgramModel
38 from jazzparser.utils.nltk.probability import mle_estimator, logprob, add_logs, \
39 sum_logs, prob_dist_to_dictionary_prob_dist, \
40 cond_prob_dist_to_dictionary_cond_prob_dist, \
41 prob_dist_to_dictionary_prob_dist
42 from jazzparser.utils.base import group_pairs
43 from jazzparser.taggers.models import TaggerModel
44 from jazzparser import settings
45 from jazzparser.utils.midi import note_ons
46
47 from nltk.probability import ConditionalProbDist, FreqDist, \
48 ConditionalFreqDist, DictionaryProbDist, \
49 DictionaryConditionalProbDist, MutableProbDist
52 """
    Hidden Markov Model based on the model described by Raphael & Stoddard.
    The structure of this model was described in my 2nd year review.
55
56 States are in the form of a tuple C{(schema,root)} where C{schema} is the
57 name of a lexical schema from the jazz grammar and C{root} is a pitch
58 class root.
59
60 Emissions are in the form of a list of pairs (pc,r), where pc is a pitch
61 class (like C{root} above) and r is an onset time abstraction. The
62 metrical model (the r part) can be disabled.
63
64 An additional distribution is stored over the number of notes emitted.
65 This needs to be included in the computation of the emission probability
66 for a set of notes for it to form a valid probability distribution.
67 For simplicity, we don't condition this on the chord class.
68 We also need a maximum number of possible notes that can be emitted
69 C{max_notes} so that this is a finite distribution.
70
71 Unlike with NgramModel, the emission domain is the domain of values
72 from which each element of an emission is selected. In other words,
    the actual domain of emissions is the set of all combinations of
    the values in C{emission_dom}: unbounded in principle, but finite
    in practice because of C{max_notes}.
75
76 As for prior distributions (start state distribution), we ignore
77 the root of the first state - it doesn't make any sense to look
78 at it since the model is pitch-invariant throughout.
79
80 @note: B{mutable distributions}: if you use mutable distributions for
81 transition or emission distributions, make sure you invalidate the cache
82 by calling L{clear_cache} after updating the distributions. Various
    caches are used to speed up retrieval of probabilities. If you fail to
    do this, you will unpredictably get some values from the old
    distributions.
86
87 @todo: Make this inherit from
88 L{jazzparser.utils.nltk.ngram.DictionaryHmmModel} so that we can use a
89 specialization of the
90 L{jazzparser.utils.nltk.ngram.baumwelch.BaumWelchTrainer} with it.
91
92 """
93 - def __init__(self, schema_transition_dist, root_transition_dist,
94 emission_dist, emission_number_dist,
95 initial_state_dist,
96 schemata, chord_class_mapping,
97 chord_classes, history="", description="",
98 metric=False,
99 illegal_transitions=[],
100 fixed_root_transitions={}):
101 warnings.warn("DEPRECATED: The chord class tagger was never really "\
102 "finished properly, so is deprecated")
103
104
105
106
107 self.order = 2
108 self.metric = metric
109 if metric:
110 self.r_values = range(4)
111 else:
112 self.r_values = [0]
113
114
115 self.root_transition_dom = list(range(12))
116 self.schemata = schemata
117 self.emission_dom = list(range(12))
118 self.max_notes = max(emission_number_dist.samples())
119
120 self.label_dom = sum(\
121 [[(schema,root,cclass) for cclass in chord_class_mapping[schema]] \
122 for schema in schemata for root in range(12)], [])
123
124 self.num_labels = len(self.label_dom)
125 self.num_emissions = len(self.emission_dom)
126
127 self.schema_transition_dist = schema_transition_dist
128 self.root_transition_dist = root_transition_dist
129 self.emission_dist = emission_dist
130 self.initial_state_dist = initial_state_dist
131 self.emission_number_dist = emission_number_dist
132
133 self.chord_class_mapping = chord_class_mapping
134 self.chord_classes = chord_classes
135 self.num_chord_classes = dict(\
136 [(schema, len(chord_class_mapping[schema])) for schema in schemata])
137
138
139
140 self.illegal_transitions = illegal_transitions
141 self.fixed_root_transitions = fixed_root_transitions
142
143
144 self.backoff_model = None
145
146 self.history = history
147
148 self.description = description
149
150
151 self.clear_cache()
152
154 """
155 Initializes or empties probability distribution caches.
156
157 Make sure to call this if you change or update the distributions.
158
159 """
160
161 self._emission_cache = {}
162
163 self._emission_class_cache = {}
164
165 self._transition_cache = {}
166
167
168 illegal_prob = dict([(label, 0.0) for label in self.schemata])
169 for label0,label1 in self.illegal_transitions:
170
171
172 illegal_prob[label0] += self.schema_transition_dist[label0].prob(label1)
173 self._schema_prob_scalers = {}
174 for label in self.schemata:
175
176 self._schema_prob_scalers[label] = - logprob(1.0 - illegal_prob[label])
177
178 - def add_history(self, string):
179 """ Adds a line to the end of this model's history string. """
180 self.history += "%s: %s\n" % (datetime.now().isoformat(' '), string)
181
183 """ Needed for the generic Ngram stuff. """
184 if len(seq) > 1:
185 raise ValueError, "sequence_to_ngram() got a sequence with "\
186 "more than one value. This shouldn't happen on an HMM"
187 elif len(seq) == 0:
188 return None
189 else:
190 return seq[0]
191
193 """ Needed for the generic Ngram stuff. """
194 if ngram is None:
195 return []
196 else:
197 return [ngram]
198
200 """ Needed for the generic Ngram stuff. """
201 return ngram
202
204 """ Needed for the generic Ngram stuff. """
205 raise NotImplementedError, "backoff_ngram() should never be called, "\
206 "since we don't use backoff models"
207
208 @staticmethod
209 - def train(*args, **kwargs):
210 """
211 We don't train these HMMs using the C{train} method, since our
212 training procedure is not the same as the superclass, so this would
213 be confusing, as this method would require completely different
214 input.
215
216 We train our models by initializing in some way (usually hand-setting
217 of parameters), then using Baum-Welch on unlabelled data.
218
219 This method will just raise an error.
220
221 """
222 raise NotImplementedError, "don't use train() to train a ChordClassHmm. "\
223 "Instead initialize and then train unsupervisedly"
224
225 @classmethod
226 - def initialize_chord_classes(cls, tetrad_prob, max_notes, grammar, \
227 illegal_transitions=[], fixed_root_transitions={}, metric=False):
228 """
229 Creates a new model with the distributions initialized naively to
230 favour simple chord-types, in a similar way to what R&S do in the paper.
231
232 The transition distribution is initialized so that everything is
233 equiprobable.
234
235 @type tetrad_prob: float
236 @param tetrad_prob: prob of a note in the tetrad. This prob is
237 distributed over the notes of the tetrad. The remaining prob
238 mass is distributed over the remaining notes. You'll want this
239 to be >0.33, so that tetrad notes are more probable than others.
240 @type max_notes: int
241 @param max_notes: maximum number of notes that can be generated in
242 each emission. Usually best to set to something high, like 100 -
243 it's just to make the distribution finite.
244 @type grammar: L{jazzparser.grammar.Grammar}
245 @param grammar: grammar from which to take the chord class definitions
246 @type metric: bool
247 @param metric: if True, creates a model with a metrical component
248 (dependence on metrical position). Default False
249
250 """
251
252 classes = [ccls for ccls in grammar.chord_classes.values() if ccls.used]
253
254
255 dists = {}
256
257
258
259 if metric:
260 r_vals = range(4)
261 else:
262 r_vals = [0]
263
264 for ccls in classes:
265 for r in r_vals:
266 probabilities = {}
267
268
269 in_tetrad_prob = tetrad_prob / len(ccls.notes)
270 out_tetrad_prob = (1.0 - tetrad_prob) / (12 - len(ccls.notes))
271
272 for d in range(12):
273 if d in ccls.notes:
274 probabilities[d] = in_tetrad_prob
275 else:
276 probabilities[d] = out_tetrad_prob
277 dists[(ccls.name,r)] = DictionaryProbDist(probabilities)
278 emission_dist = DictionaryConditionalProbDist(dists)
279
280
281
282
283 schemata = grammar.midi_families.keys()
284
285
286
287 for labels in illegal_transitions:
288 for label in labels:
289 if label not in schemata:
290 raise ValueError, "%s, given in illegal transition "\
291 "specification, is not a valid schema in the grammar" \
292 % label
293 for labels in fixed_root_transitions:
294 for label in labels:
295 if label not in schemata:
296 raise ValueError, "%s, given in fixed root transition "\
297 "specification, is not a valid schema in the grammar" \
298 % label
299
300
301
302 chord_class_mapping = {}
303 for morph in grammar.morphs:
304 if morph.pos in schemata:
305 chord_class_mapping.setdefault(morph.pos, []).append(str(morph.chord_class.name))
306
307 for label in schemata:
308 if label not in chord_class_mapping:
309 chord_class_mapping[label] = []
310
311
312 schema_transition_counts = ConditionalFreqDist()
313 root_transition_counts = ConditionalFreqDist()
314 for label0 in schemata:
315 for label1 in schemata:
316
317
318
319 for cclass in chord_class_mapping[label1]:
320 schema_transition_counts[label0].inc(label1)
321 for root_change in range(12):
322
323 root_transition_counts[(label0,label1)].inc(root_change)
324
325 schema_transition_counts[label0].inc(None)
326
327 schema_trans_dist = ConditionalProbDist(schema_transition_counts, mle_estimator, None)
328 root_trans_dist = ConditionalProbDist(root_transition_counts, mle_estimator, None)
329
330 schema_trans_dist = cond_prob_dist_to_dictionary_cond_prob_dist(schema_trans_dist)
331 root_trans_dist = cond_prob_dist_to_dictionary_cond_prob_dist(root_trans_dist)
332
333
334 initial_state_counts = FreqDist()
335 for label in schemata:
336 initial_state_counts.inc(label)
337 initial_state_dist = mle_estimator(initial_state_counts, None)
338 initial_state_dist = prob_dist_to_dictionary_prob_dist(initial_state_dist)
339
340
341 emission_number_counts = FreqDist()
342 for i in range(max_notes):
343 emission_number_counts.inc(i)
344 emission_number_dist = mle_estimator(emission_number_counts, None)
345 emission_number_dist = prob_dist_to_dictionary_prob_dist(emission_number_dist)
346
347
348 model = cls(schema_trans_dist,
349 root_trans_dist,
350 emission_dist,
351 emission_number_dist,
352 initial_state_dist,
353 schemata,
354 chord_class_mapping,
355 classes,
356 metric=metric,
357 illegal_transitions=illegal_transitions,
358 fixed_root_transitions=fixed_root_transitions)
359 model.add_history(\
360 "Initialized model to chord type probabilities, using "\
361 "tetrad probability %s. Metric: %s" % \
362 (tetrad_prob, metric))
363
364 return model
365
367 """
368 Train the transition distribution parameters in a supervised manner,
369 using chord corpus input.
370
371 This is used as an initialization step to set transition parameters
372 before running EM on unannotated data.
373
374 @type inputs: L{jazzparser.data.input.AnnotatedDbBulkInput}
375 @param inputs: annotated chord training data
376 @type contprob: float or string
377 @param contprob: probability mass to reserve for staying on the
378 same state (self transitions). Use special value 'learn' to
379 learn the probabilities from the durations
380
381 """
382 self.add_history(
383 "Training transition probabilities using %d annotated chord "\
384 "sequences" % len(inputs))
385 learn_cont = contprob == "learn"
386
387
388 if learn_cont:
389
390 sequences = []
391 for seq in inputs:
392 sequence = []
393 last_cat = None
394 for chord,cat in zip(seq, seq.categories):
395
396 for i in range(chord.duration):
397 sequence.append((chord,cat))
398 sequences.append(sequence)
399 else:
400 sequences = [list(zip(sequence, sequence.categories)) for \
401 sequence in inputs]
402
403
404 label_transform = {}
405
406 for schema in self.schemata:
407 label_transform[schema] = (schema, 0)
408
409 for pos,mapping in grammar.equiv_map.items():
410 label_transform[pos] = (mapping.target.pos, mapping.root)
411
412
413 training_samples = []
414 for chord_cats in sequences:
415 seq_samples = []
416 for chord,cat in chord_cats:
417
418 if cat in label_transform:
419 use_cat, alter_root = label_transform[cat]
420 else:
421 use_cat, alter_root = cat, 0
422 root = (chord.root + alter_root) % 12
423 seq_samples.append((str(use_cat), root))
424 training_samples.append(seq_samples)
425
426 training_data = sum([
427 [(cat0, cat1, (root1 - root0) % 12)
428 for ((cat0,root0),(cat1,root1)) in \
429 group_pairs(seq_samples)] \
430 for seq_samples in training_samples], [])
431
432
433 schema_transition_counts = ConditionalFreqDist()
434 root_transition_counts = ConditionalFreqDist()
435 for (label0, label1, root_change) in training_data:
436
437 if label0 in self.schemata and label1 in self.schemata:
438 schema_transition_counts[label0].inc(label1)
439 root_transition_counts[(label0,label1)].inc(root_change)
440
441
442 for sequence in training_samples:
443
444
445 schema_transition_counts[sequence[-1][0]].inc(None)
446
447
448
449
450
451 for label0 in self.schemata:
452 for label1 in self.schemata:
453 for root_change in range(12):
454
455 if learn_cont or not (label0 == label1 and root_change == 0):
456 schema_transition_counts[label0].inc(label1)
457 root_transition_counts[(label0,label1)].inc(root_change)
458
459
460
461
462 schema_trans_dist = cond_prob_dist_to_dictionary_cond_prob_dist(\
463 ConditionalProbDist(schema_transition_counts, mle_estimator, None), \
464 mutable=True, samples=self.schemata+[None])
465 root_trans_dist = cond_prob_dist_to_dictionary_cond_prob_dist(\
466 ConditionalProbDist(root_transition_counts, mle_estimator, None), \
467 mutable=True, samples=range(12))
468
469 if not learn_cont:
470
471 discount = logprob(1.0 - contprob)
472 self_prob = logprob(contprob)
473 for label0 in self.schemata:
474
475 trans_dist[label0].update((label0, 0), self_prob)
476
477
478 for label1 in self.schemata:
479 for root_change in range(12):
480 if not (label0 == label1 and root_change == 0):
481
482 trans_dist[label0].update((label1, root_change), \
483 trans_dist[label0].logprob((label1, root_change)) + \
484 discount)
485
486
487 schema_trans_dist = cond_prob_dist_to_dictionary_cond_prob_dist(schema_trans_dist)
488 root_trans_dist = cond_prob_dist_to_dictionary_cond_prob_dist(root_trans_dist)
489
490
491
492 initial_counts = FreqDist()
493 for sequence in training_samples:
494 initial_counts.inc(sequence[0][0])
495
496
497
498
499
500 initial_dist = prob_dist_to_dictionary_prob_dist(\
501 mle_estimator(initial_counts, None), samples=self.schemata)
502
503
504 self.schema_transition_dist = schema_trans_dist
505 self.root_transition_dist = root_trans_dist
506 self.initial_state_dist = initial_dist
507
508 self.clear_cache()
509
511 """
512 Trains the distribution over the number of notes emitted from a
513 chord class. It's not conditioned on the chord class, so the only
514 training data needed is a segmented MIDI corpus.
515
516 @type inputs: list of lists
517 @param inputs: training data. The same format as is produced by
518 L{jazzparser.taggers.segmidi.midi.midi_to_emission_stream}
519
520 """
521 self.add_history(
522 "Training emission number probabilities using %d MIDI segments"\
523 % len(inputs))
524
525 emission_number_counts = FreqDist()
526 for sequence in inputs:
527 for segment in sequence:
528 notes = len(segment)
529
530 if notes <= self.max_notes:
531 emission_number_counts.inc(notes)
532
533
534 for notes in range(self.max_notes):
535 emission_number_counts.inc(notes)
536
537
538 emission_number_dist = prob_dist_to_dictionary_prob_dist(\
539 mle_estimator(emission_number_counts, None))
540 self.emission_number_dist = emission_number_dist
541
543
544 if (previous_state, state) not in self._transition_cache:
545 if state is None:
546
547
548
549 label,root,cclass = previous_state
550 prob = self.schema_transition_dist[label].logprob(None) + \
551 self._schema_prob_scalers[label]
552 return prob
553
554 if previous_state is None:
555
556 label,root,cclass = state
557
558
559 return self.initial_state_dist.logprob(label) \
560 - math.log(12.0, 2) \
561 - math.log(self.num_chord_classes[label], 2)
562
563
564 (label0,root0,cclass),(label1,root1,cclass) = (previous_state, state)
565 root_change = (root1 - root0) % 12
566
567
568 if (label0,label1) in self.illegal_transitions:
569 return float('-inf')
570
571 if (label0,label1) in self.fixed_root_transitions:
572
573 if self.fixed_root_transitions[(label0,label1)] == root_change:
574 root_prob = 0.0
575 else:
576
577 return float('-inf')
578 else:
579 root_prob = self.root_transition_dist[(label0,label1)].logprob(\
580 root_change)
581
582
583
584 prob = \
585 self.schema_transition_dist[label0].logprob(label1) + \
586 self._schema_prob_scalers[label0] + \
587 root_prob - \
588 math.log(self.num_chord_classes[label1], 2)
589
590 self._transition_cache[(previous_state,state)] = prob
591 return prob
592 else:
593 return self._transition_cache[(previous_state,state)]
594
596 """
597 Gives the probability P(emission | label). Returned as a base 2
598 log.
599
600 The emission should be a list of emitted notes.
601
602 Each note should be
603 given as a tuple (pc,beat), where pc is the pitch class of the note
604 and beat is the beat specifier for the metrical model. If the model
605 is non-metric, you may set to beat always to 0, as it will be ignored
606 and assumed to be 0.
607
608 """
609
610
611
612 cache_key = (tuple(sorted(emission)), state)
613
614 if cache_key not in self._emission_cache:
615
616 label, root, chord_class = state
617
618 prob = self.chord_class_emission_log_probability(emission,
619 chord_class,
620 root)
621 self._emission_cache[cache_key] = prob
622 return self._emission_cache[cache_key]
623
625 """
626 The standard emission probability is P(emission | state). This instead
627 returns P(emission | chord class). The emission is given in the same
628 way as to L{emission_log_probability}.
629
630 The root number is also required. For L{emission_log_probability},
631 this is included in the state label.
632
633 """
634
635
636 if not self.metric:
637
638
639 notes = sorted([((pc-root) % 12, 0) for (pc,beat) in emission])
640 else:
641
642 notes = sorted([((pc-root) % 12, beat) for (pc,beat) in emission])
643
644
645
646
647 cache_key = (chord_class,tuple(notes))
648
649 if cache_key not in self._emission_class_cache:
650 prob = 0.0
651 for beat in self.r_values:
652
653 beat_notes = [rel_pc for (rel_pc,notebeat) in notes \
654 if notebeat == beat]
655
656 prob += self.emission_number_dist.logprob(len(beat_notes))
657
658
659
660 for rel_pc in beat_notes:
661 prob += self.emission_dist[(chord_class,beat)].logprob(rel_pc)
662 self._emission_class_cache[cache_key] = prob
663 else:
664 prob = self._emission_class_cache[cache_key]
665 return prob
666
668 """We override this to provide a faster implementation.
669
670 It might also be possible to speed up the superclass' implementation
671 using numpy, but it's easier here because we know we're using an
672 HMM, not a higher-order ngram.
673
674 This is based on the fwd prob calculation in NLTK's HMM implementation.
675
676 @type array: bool
677 @param array: if True, returns a numpy 2d array instead of a list of
678 dicts.
679
680 """
681 T = len(sequence)
682 N = len(self.label_dom)
683 alpha = numpy.zeros((T, N), numpy.float64)
684
685
686
687 for i,state in enumerate(self.label_dom):
688 alpha[0,i] = self.transition_log_probability(state, None) + \
689 self.emission_log_probability(sequence[0], state)
690
691
692 for t in range(1, T):
693 for j,sj in enumerate(self.label_dom):
694
695
696 log_probs = [
697 alpha[t-1, i] + self.transition_log_probability(sj, si) \
698 for i,si in enumerate(self.label_dom)]
699
700 alpha[t, j] = sum_logs(log_probs) + \
701 self.emission_log_probability(sequence[t], sj)
702
703 if normalize:
704 for t in range(T):
705 total = sum_logs(alpha[t,:])
706 for j in range(N):
707 alpha[t,j] -= total
708
709 if not array:
710
711 matrix = []
712 for t in range(T):
713 timestep = {}
714 for (i,label) in enumerate(self.label_dom):
715 timestep[label] = alpha[t,i]
716 matrix.append(timestep)
717 return matrix
718 else:
719 return alpha
720
722 """We override this to provide a faster implementation.
723
724 @see: forward_log_probability
725
726 @type array: bool
727 @param array: if True, returns a numpy 2d array instead of a list of
728 dicts.
729
730 """
731 T = len(sequence)
732 N = len(self.label_dom)
733 beta = numpy.zeros((T, N), numpy.float64)
734
735
736 for i,si in enumerate(self.label_dom):
737 beta[T-1, i] = self.transition_log_probability(None, si)
738
739
740 for t in range(T-2, -1, -1):
741 for i,si in enumerate(self.label_dom):
742
743
744
745 log_probs = [
746 beta[t+1, j] + self.transition_log_probability(sj, si) + \
747 self.emission_log_probability(sequence[t+1], sj) \
748 for j,sj in enumerate(self.label_dom)]
749 beta[t, i] = sum_logs(log_probs)
750
751 if normalize:
752 total = sum_logs(beta[t,:])
753 for j in range(N):
754 beta[t,j] -= total
755
756 if not array:
757
758 matrix = []
759 for t in range(T):
760 timestep = {}
761 for (i,label) in enumerate(self.label_dom):
762 timestep[label] = beta[t,i]
763 matrix.append(timestep)
764 return matrix
765 else:
766 return beta
767
768
770 """If you want the normalized matrix of forward probabilities, it's
771 ok to use normal (non-log) probabilities and these can be computed
772 more quickly, since you don't need to sum logs (which is time
773 consuming).
774
775 Returns the matrix, and also the vector of values that each timestep
776 was divided by to normalize (i.e. total probability of each timestep
777 over all states).
778 Also returns the total log probability of the sequence.
779
780 @type array: bool
781 @param array: if True, returns a numpy 2d array instead of a list of
782 dicts.
783 @return: (matrix,normalizing vector,log prob)
784
785 """
786 T = len(sequence)
787 N = len(self.label_dom)
788 alpha = numpy.zeros((T, N), numpy.float64)
789 scale = numpy.zeros(T, numpy.float64)
790
791
792
793 for i,state in enumerate(self.label_dom):
794 alpha[0,i] = self.transition_probability(state, None) * \
795 self.emission_probability(sequence[0], state)
796
797 total = array_sum(alpha[0,:])
798 alpha[0,:] /= total
799 scale[0] = total
800
801
802 for t in range(1, T):
803 for j,sj in enumerate(self.label_dom):
804
805
806 prob = sum(
807 (alpha[t-1, i] * self.transition_probability(sj, si) \
808 for i,si in enumerate(self.label_dom)), 0.0)
809
810 alpha[t, j] = prob * \
811 self.emission_probability(sequence[t], sj)
812
813 total = array_sum(alpha[t,:])
814 alpha[t,:] /= total
815 scale[t] = total
816
817
818
819
820
821
822 log_prob = sum((logprob(total) for total in scale), 0.0)
823
824 if not array:
825
826 matrix = []
827 for t in range(T):
828 timestep = {}
829 for (i,label) in enumerate(self.label_dom):
830 timestep[label] = alpha[t,i]
831 matrix.append(timestep)
832 return matrix,scale,log_prob
833 else:
834 return alpha,scale,log_prob
835
837 """
838 @see: normal_forward_probabilities
839
840 (except that this doesn't return the logprob)
841
842 @type array: bool
843 @param array: if True, returns a numpy 2d array instead of a list of
844 dicts.
845
846 """
847 T = len(sequence)
848 N = len(self.label_dom)
849 beta = numpy.zeros((T, N), numpy.float64)
850 scale = numpy.zeros(T, numpy.float64)
851
852
853 for i,si in enumerate(self.label_dom):
854 beta[T-1, i] = self.transition_probability(None, si)
855
856 total = array_sum(beta[T-1, :])
857 beta[T-1,:] /= total
858
859
860 scale[T-1] = total
861
862
863 for t in range(T-2, -1, -1):
864
865
866 em_probs = [
867 self.emission_probability(sequence[t+1], sj) \
868 for sj in self.label_dom]
869
870 for i,si in enumerate(self.label_dom):
871
872
873
874 beta[t, i] = sum(
875 (beta[t+1, j] * self.transition_probability(sj, si) * \
876 em_probs[j] \
877 for j,sj in enumerate(self.label_dom)), 0.0)
878
879 total = array_sum(beta[t,:])
880 beta[t,:] /= total
881 scale[t] = total
882
883 if not array:
884
885 matrix = []
886 for t in range(T):
887 timestep = {}
888 for (i,label) in enumerate(self.label_dom):
889 timestep[label] = beta[t,i]
890 matrix.append(timestep)
891 return matrix,scale
892 else:
893 return beta,scale
894
896 """
897 Computes the gamma matrix used in Baum-Welch. This is the matrix
898 of state occupation probabilities for each timestep. It is computed
899 from the forward and backward matrices.
900
901 These can be passed in as
902 arguments to avoid recomputing if you need to reuse them, but will
903 be computed from the model if not given. They are assumed to be
904 the matrices computed by L{normal_forward_probabilities} and
905 L{normal_backward_probabilities} (i.e. normalized, non-log
906 probabilities).
907
908 """
909 if forward is None:
910 forward = self.normal_forward_probabilities(sequence, array=True)[0]
911 if backward is None:
912 backward = self.normal_backward_probabilities(sequence, array=True)[0]
913
914
915 T,N = forward.shape
916
917
918 gamma = forward * backward
919
920 denominators = array_sum(gamma, axis=1)
921
922 gamma = (gamma.transpose() / denominators).transpose()
923
924 return gamma
925
926 - def compute_xi(self, sequence, forward=None, backward=None):
927 """
928 Computes the xi matrix used by Baum-Welch. It is the matrix of joint
929 probabilities of occupation of pairs of conecutive states:
930 P(i_t, j_{t+1} | O).
931
932 As with L{compute_gamma} forward and backward matrices can optionally
933 be passed in to avoid recomputing.
934
935 """
936 if forward is None:
937 forward = self.normal_forward_probabilities(sequence, array=True)
938 if backward is None:
939 backward = self.normal_backward_probabilities(sequence, array=True)
940
941
942 T,N = forward.shape
943
944 xi = zeros((T-1,N,N), float64)
945 for t in range(T-1):
946 total = 0.0
947
948
949 em_probs = [
950 self.emission_probability(sequence[t+1], statej) \
951 for statej in self.label_dom]
952
953
954 for i,statei in enumerate(self.label_dom):
955
956 for j,statej in enumerate(self.label_dom):
957
958 prob = forward[t][i] * backward[t+1][j] * \
959 self.transition_probability(statej, statei) * \
960 em_probs[j]
961 xi[t][i][j] = prob
962 total += prob
963
964 for i in range(N):
965 for j in range(N):
966 xi[t][i][j] /= total
967 return xi
968
970 """
971 Produces a picklable representation of model as a dict.
972 You can't just pickle the object directly because some of the
973 NLTK classes can't be pickled. You can pickle this dict and
974 reconstruct the model using NgramModel.from_picklable_dict(dict).
975
976 """
977 from jazzparser.utils.nltk.storage import object_to_dict
978 return {
979 'schema_transition_dist' : object_to_dict(self.schema_transition_dist),
980 'root_transition_dist' : object_to_dict(self.root_transition_dist),
981 'emission_dist' : object_to_dict(self.emission_dist),
982 'emission_number_dist' : object_to_dict(self.emission_number_dist),
983 'initial_state_dist' : object_to_dict(self.initial_state_dist),
984 'schemata' : self.schemata,
985 'history' : self.history,
986 'description' : self.description,
987 'chord_class_mapping' : self.chord_class_mapping,
988 'chord_classes' : self.chord_classes,
989 'illegal_transitions' : self.illegal_transitions,
990 'fixed_root_transitions' : self.fixed_root_transitions,
991 }
992
993 @classmethod
995 """
996 Reproduces an n-gram model that was converted to a picklable
997 form using to_picklable_dict.
998
999 """
1000 from jazzparser.utils.nltk.storage import dict_to_object
1001 return cls(dict_to_object(data['schema_transition_dist']),
1002 dict_to_object(data['root_transition_dist']),
1003 dict_to_object(data['emission_dist']),
1004 dict_to_object(data['emission_number_dist']),
1005 dict_to_object(data['initial_state_dist']),
1006 data['schemata'],
1007 data['chord_class_mapping'],
1008 data['chord_classes'],
1009 history=data.get('history', ''),
1010 description=data.get('description', ''),
1011 illegal_transitions=data.get('illegal_transitions', []),
1012 fixed_root_transitions=data.get('fixed_root_transitions', {}))
1013