1 from __future__ import absolute_import
2 """First suggested segmidi tagger, based on chord classes.
3
4 This is the first model structure presented in my 2nd-year review. It is
5 based on the Raphael & Stoddard model and conditions emission distributions
6 on chord classes.
7
8 @deprecated: this tagger is deprecated. I started doing some experiments
9 with it, but never concluded them, so it's not useable.
10
11 """
12 """
13 ============================== License ========================================
14 Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding
15
16 This file is part of The Jazz Parser.
17
18 The Jazz Parser is free software: you can redistribute it and/or modify
19 it under the terms of the GNU General Public License as published by
20 the Free Software Foundation, either version 3 of the License, or
21 (at your option) any later version.
22
23 The Jazz Parser is distributed in the hope that it will be useful,
24 but WITHOUT ANY WARRANTY; without even the implied warranty of
25 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26 GNU General Public License for more details.
27
28 You should have received a copy of the GNU General Public License
29 along with The Jazz Parser. If not, see <http://www.gnu.org/licenses/>.
30
31 ============================ End license ======================================
32
33 """
34 __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"
35
36 from cStringIO import StringIO
37 from jazzparser.taggers.segmidi.base import SegmidiTagger
38 from jazzparser.taggers.segmidi.chordclass.hmm import ChordClassHmm
39 from jazzparser.taggers.segmidi.chordclass.train import ChordClassBaumWelchTrainer
40 from jazzparser.taggers.segmidi.chordclass.tagutils import prepare_categories
41 from jazzparser.taggers.segmidi.midi import midi_to_emission_stream
42 from jazzparser.taggers.models import TaggerModel
43 from jazzparser.data.input import detect_input_type, MidiTaggerTrainingBulkInput
44 from jazzparser.utils.options import ModuleOption, choose_from_list
45 from jazzparser.utils.strings import str_to_bool
46 from jazzparser.utils.chords import int_to_chord_numeral
47 from jazzparser.utils.midi import note_ons
48 from . import tools
51
def _filter_illegal_transition_string(val):
    """
    Parse the value of the C{illegal_transitions} option.

    The value is a comma-separated list of items of the form
    C{SCHEMA-SCHEMA}. Either side may name several schemata separated by
    C{|}, in which case every pairing of a left-hand schema with a
    right-hand schema is produced. Only the first C{-} of an item is
    treated as the separator.

    @type val: string
    @param val: raw option string, e.g. C{"A|B-C,D-E"}
    @rtype: list of (string, string) pairs
    @return: the expanded (schema0, schema1) pairs, in the order given
    @raise ValueError: if an item contains no C{-} separator

    """
    transitions = []
    for trans_string in val.split(","):
        trans_string = trans_string.strip()
        if "-" not in trans_string:
            # Call form of raise: valid on Python 2 and 3 alike
            raise ValueError(
                "illegal transition specification must be of "
                "the form SCHEMA-SCHEMA. Could not understand %s" % trans_string)
        left, __, right = trans_string.partition("-")
        # Each side may be a group: expand to the full cross product
        for schema0 in left.split("|"):
            for schema1 in right.split("|"):
                transitions.append((schema0, schema1))
    return transitions
67
69
def _filter_fixed_root_transition_string(val):
    """
    Parse the value of the C{fixed_roots} option.

    The value is a comma-separated list of items of the form
    C{SCHEMA-SCHEMA-ROOT_INTERVAL}. Either schema position may name several
    schemata separated by C{|}; every pairing of a first schema with a
    second schema is mapped to the item's root interval. If a pair occurs
    more than once, the last occurrence wins.

    @type val: string
    @param val: raw option string, e.g. C{"A|B-C-7,D-E-0"}
    @rtype: dict mapping (string, string) pairs to ints
    @return: root interval for each (schema0, schema1) pair
    @raise ValueError: if an item does not contain exactly two C{-}
        separators, if the root interval is not an integer, or if it
        lies outside [0,11]

    """
    transitions = {}
    for trans_string in val.split(","):
        trans_string = trans_string.strip()
        if trans_string.count("-") != 2:
            # Call form of raise: valid on Python 2 and 3 alike
            raise ValueError(
                "fixed root transition string must be of "
                "the form SCHEMA-SCHEMA-ROOT_INTERVAL. Could not understand "
                "%s" % trans_string)
        # Exactly two separators were verified above, so this yields 3 parts
        left, right, root = trans_string.split("-")
        root = int(root)
        if root < 0 or root > 11:
            raise ValueError(
                "root interval must be in the range [0,11], got "
                "%d" % root)
        # Each side may be a group: expand to the full cross product
        for schema0 in left.split("|"):
            for schema1 in right.split("|"):
                transitions[(schema0, schema1)] = root
    return transitions
91
94 """
95 Model class to go with L{ChordClassMidiTagger}. This is where the real
96 meat of the model is implemented.
97
98 """
MODEL_TYPE = 'chordclass'
# Options accepted at training time: the base model's options, the options
# declared here, and the Baum-Welch trainer's own options.
TRAINING_OPTIONS = TaggerModel.TRAINING_OPTIONS + [
    ModuleOption('ccprob', filter=float,
        help_text="Initialization of the emission distribution.",
        usage="ccprob=P, P is a probability. Prob P is distributed "
              "over the pitch classes that are in the chord class.",
        required=True),
    ModuleOption('metric', filter=str_to_bool,
        help_text="Create the model with a metrical component, as in the "
                  "original Raphael & Stoddard model",
        usage="metric=True, or metric=False (default False)",
        default=False),
    ModuleOption('contprob', filter=(float, choose_from_list(["learn"])),
        help_text="Continuation probability for transition initialization: "
                  "probability of staying in the same state between emissions. "
                  "Use value 'learn' to learn the self-transition probabilities "
                  "from the durations in the transition training data",
        usage="contprob=P, P is a probability or 'learn'",
        default=0.3),
    ModuleOption('maxnotes', filter=int,
        help_text="Maximum number of notes that can be generated from "
                  "a state. Limit is required to make the distribution finite.",
        usage="maxnotes=N, N is an integer",
        default=100),
    ModuleOption('illegal_transitions', filter=_filter_illegal_transition_string,
        help_text="List of grammatical schema transitions (pairs) that "
                  "will be forced to have a 0 probability. You may specify "
                  "groups of schemata, separating each with a |",
        usage="illegal_transitions=X0-Y0,X1-Y1,... where Xs and Ys are "
              "schema tags",
        default=[]),
    ModuleOption('fixed_roots', filter=_filter_fixed_root_transition_string,
        help_text="List of schema transitions that may have only one "
                  "non-zero probability root interval. You may specify "
                  "groups of schemata, separating each with a |",
        usage="fixed_roots=X0-Y0-R0,X1-Y1-R1,... where Xs and Ys are "
              "schema tags and Rs are integers in [0,11]",
        default={}),
] + ChordClassBaumWelchTrainer.OPTIONS
140
141 - def __init__(self, model_name, *args, **kwargs):
144
def train(self, inputs, grammar=None, logger=None):
    """
    Initialize the chord class HMM from the model's options and retrain
    it on the given inputs with the Baum-Welch trainer.

    @type inputs: L{jazzparser.data.input.MidiTaggerTrainingBulkInput} or
        list of L{jazzparser.data.input.Input}s
    @param inputs: training MIDI data. Annotated chord sequences should
        also be given (though this is optional) by loading a
        bulk db input file in the MidiTaggerTrainingBulkInput.
    @param grammar: grammar used to initialize the HMM; the default
        grammar from C{get_grammar()} is loaded if none is given
    @param logger: passed through to the trainer's C{train()}
    @return: None. Nothing is trained (and C{self.hmm} is not set) if
        C{inputs} is empty.

    """
    if grammar is None:
        from jazzparser.grammar import get_grammar
        grammar = get_grammar()

    if len(inputs) == 0:
        # No data to train on
        return

    # The result is not needed here; the call is kept for its check of the
    # first input against the allowed types (presumably it raises on a
    # mismatch -- confirm against detect_input_type)
    detect_input_type(inputs[0], allowed=['segmidi'])

    # Annotated chord data only comes with the bulk training input type
    chord_inputs = None
    if isinstance(inputs, MidiTaggerTrainingBulkInput):
        chord_inputs = inputs.chords

    opts = self.options
    self.hmm = ChordClassHmm.initialize_chord_classes(
        opts['ccprob'],
        opts['maxnotes'],
        grammar,
        metric=opts['metric'],
        illegal_transitions=opts['illegal_transitions'],
        fixed_root_transitions=opts['fixed_roots'])

    if chord_inputs:
        self.hmm.add_history("Training initial transition distribution "
                             "from annotated chord data")
        self.hmm.train_transition_distribution(chord_inputs, grammar,
                                               contprob=opts['contprob'])
    else:
        self.hmm.add_history("No annotated chord training data given. "
                             "Transition distribution initialized to uniform.")

    # Hand on to the trainer only those options that it declares itself
    bw_opt_names = [opt.name for opt in ChordClassBaumWelchTrainer.OPTIONS]
    bw_opts = dict((name, val) for (name, val) in opts.items()
                   if name in bw_opt_names)
    retrainer = ChordClassBaumWelchTrainer(self.hmm, options=bw_opts)

    def _save_callback():
        # Given to the trainer so that it can save this model
        self.save()

    retrainer.train(inputs, logger=logger, save_callback=_save_callback)

    self.model_description = """\
Initial chord class emission prob: %(ccprob)f
Initial self-transition prob: %(contprob)s
Metrical model: %(metric)s
""" % {
        'ccprob': opts['ccprob'],
        'metric': opts['metric'],
        'contprob': opts['contprob'],
    }
219
220 @staticmethod
225
232
def _get_readable_parameters(self):
    """
    Produce a human-readable repr of the params of the model.

    Lists the HMM's chord classes, schemata, illegal and fixed-root
    transitions, then its emission, schema transition, root transition
    and initial state distributions. Within each distribution, samples
    are listed from the highest probability to the lowest.

    @rtype: string
    @return: multi-line report; every line ends with a newline

    """
    hmm = self.hmm
    lines = []
    emit = lines.append

    def _ranked(dist):
        # (probability, sample) pairs, highest probability first
        pairs = sorted([(dist.prob(samp), samp) for samp in dist.samples()])
        pairs.reverse()
        return pairs

    def _fmt_cond(cond):
        if hmm.metric:
            # Show every component of the condition. The previous fixed
            # three-slot format was given only two values and raised
            # TypeError for any metric model.
            return ", ".join(["%s" % (part,) for part in cond])
        else:
            # Only the first component is shown for a non-metric model
            return "%s" % (cond[0],)

    emit("\nChord classes:\n%s" % ", ".join(
        [str(cc) for cc in hmm.chord_classes]))
    emit("\nSchemata:\n%s" % ", ".join(sorted(hmm.schemata)))
    emit("\nIllegal transitions (probabilities below will be "
         "redistributed):\n%s" % ", ".join(
            ["%s-%s" % labels for labels in hmm.illegal_transitions]))
    emit("\nFixed-root transitions:\n%s" % "\n".join(
        ["%s -> %s, %d" % (label0, label1, root)
         for ((label0, label1), root)
         in sorted(hmm.fixed_root_transitions.items())]))

    emit("\n*** Emission distributions ***")
    em_dist = hmm.emission_dist
    for cond in sorted(em_dist.conditions()):
        emit("%s:" % _fmt_cond(cond))
        for (prob, samp) in _ranked(em_dist[cond]):
            emit(" %s: %s" % (samp, prob))

    emit("\n*** Transition distributions ***")
    emit("Schema transitions")
    schema_trans_dist = hmm.schema_transition_dist
    for label0 in sorted(schema_trans_dist.conditions()):
        emit("%s ->" % label0)
        for (prob, samp) in _ranked(schema_trans_dist[label0]):
            emit(" %s: %s" % (samp, prob))

    emit("\nRoot transitions")
    root_trans_dist = hmm.root_transition_dist
    for label0, label1 in sorted(root_trans_dist.conditions()):
        pair = (label0, label1)
        if pair in hmm.illegal_transitions:
            emit("%s -> %s illegal" % pair)
        elif pair in hmm.fixed_root_transitions:
            # Only one root interval is permitted, so just name it
            emit("%s -> %s, only %d" % (
                label0, label1, hmm.fixed_root_transitions[pair]))
        else:
            emit("%s -> %s," % pair)
            for (prob, samp) in _ranked(root_trans_dist[pair]):
                emit(" %s: %s" % (samp, prob))

    emit("\n*** Initial state distribution ***")
    for (prob, label) in _ranked(hmm.initial_state_dist):
        emit("%s: %s" % (label, prob))

    # One newline after each line, as a print statement would have written
    return "".join([line + "\n" for line in lines])
readable_parameters = property(_get_readable_parameters)
301
def __get_description(self):
    """
    Overridden to add history onto description.

    @rtype: string
    @return: the stored description, then the model description (if one
        has been set), then the HMM's history

    """
    model_desc = ""
    if self.model_description is not None:
        model_desc = "\n\n%s" % self.model_description
    return "%s%s\nModel history:\n%s" % (
        self._description, model_desc, self.hmm.history)
description = property(__get_description)
311
MODEL_CLASS = ChordClassTaggerModel
# The base segmidi tagger's options, plus the size of the per-timestep n-best
TAGGER_OPTIONS = SegmidiTagger.TAGGER_OPTIONS + [
    ModuleOption(
        'decoden',
        filter=int,
        help_text="Number of best categories to consider for each timestep",
        usage="decoden=N, where N is an integer",
        default=5,
    ),
]
# The base segmidi tagger's shell tools, plus the state grid tool
shell_tools = SegmidiTagger.shell_tools + [tools.StateGridTool()]
323
def __init__(self, *args, **kwargs):
    """
    Run the base tagger initialization, then decode the input with the
    model's HMM and build the sets of categories for the spans.

    Sets C{self.top_tags}, the C{decoden} highest-scoring
    (score, (schema, root)) pairs for each timestep, and
    C{self.category_sets}, one list of
    (start, end, (category, schema, probability)) tuples for each span
    set returned by C{prepare_categories}.

    All arguments are passed on to C{SegmidiTagger.__init__}.

    """
    SegmidiTagger.__init__(self, *args, **kwargs)
    grammar = self.grammar
    hmm = self.model.hmm

    emissions = midi_to_emission_stream(self.input,
                                        metric=hmm.metric,
                                        remove_empty=False)

    N = self.options['decoden']

    # Only the first element of the emission stream result is decoded.
    # gamma is indexed [timestep, label index]
    gamma = hmm.compute_gamma(emissions[0])

    top_tags = []
    for t in range(gamma.shape[0]):
        # Score of each (schema, root, chord class) label at this timestep
        timeprobs = dict(
            (label, gamma[t, i]) for (i, label) in enumerate(hmm.label_dom))
        # The chord class is dropped from the label: rank by score only
        ranked = sorted(
            [(prob, (schema, root))
             for ((schema, root, chord_class), prob) in timeprobs.items()])
        ranked.reverse()
        top_tags.append(ranked[:N])
    self.top_tags = top_tags

    spans = prepare_categories(top_tags)

    category_sets = []
    # (start, end, category) triples already added, across all span sets.
    # Kept as a list so that duplicates are detected by equality alone.
    added_spans = []
    for spanset in spans:
        categories = []
        for (start, end, (log_prob, (schema, root))) in spanset:
            signs = grammar.get_signs_for_tag(
                schema, {'root': root, 'time': start})
            for category in signs:
                span_key = (start, end, category)
                if span_key not in added_spans:
                    # The score is treated as a base-2 log probability
                    categories.append(
                        (start, end, (category, schema, 2**log_prob)))
                    added_spans.append(span_key)
        category_sets.append(categories)
    self.category_sets = category_sets
381
383 if offset >= len(self.category_sets):
384 return []
385 else:
386 return self.category_sets[offset]
387