Package jazzparser :: Package taggers :: Package segmidi :: Package chordclass
[hide private]
[frames] | no frames]

Source Code for Package jazzparser.taggers.segmidi.chordclass

  1  from __future__ import absolute_import 
  2  """First suggested segmidi tagger, based on chord classes. 
  3   
  4  This is the first model structure presented in my 2nd-year review. It is  
  5  based on the Raphael & Stoddard model and conditions emission distributions  
  6  on chord classes. 
  7   
  8  @deprecated: this tagger is deprecated. I started doing some experiments  
  9      with it, but never concluded them, so it's not useable. 
 10   
 11  """ 
 12  """ 
 13  ============================== License ======================================== 
 14   Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding 
 15    
 16   This file is part of The Jazz Parser. 
 17    
 18   The Jazz Parser is free software: you can redistribute it and/or modify 
 19   it under the terms of the GNU General Public License as published by 
 20   the Free Software Foundation, either version 3 of the License, or 
 21   (at your option) any later version. 
 22    
 23   The Jazz Parser is distributed in the hope that it will be useful, 
 24   but WITHOUT ANY WARRANTY; without even the implied warranty of 
 25   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 26   GNU General Public License for more details. 
 27    
 28   You should have received a copy of the GNU General Public License 
 29   along with The Jazz Parser.  If not, see <http://www.gnu.org/licenses/>. 
 30   
 31  ============================ End license ====================================== 
 32   
 33  """ 
 34  __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"  
 35   
 36  from cStringIO import StringIO 
 37  from jazzparser.taggers.segmidi.base import SegmidiTagger 
 38  from jazzparser.taggers.segmidi.chordclass.hmm import ChordClassHmm 
 39  from jazzparser.taggers.segmidi.chordclass.train import ChordClassBaumWelchTrainer 
 40  from jazzparser.taggers.segmidi.chordclass.tagutils import prepare_categories 
 41  from jazzparser.taggers.segmidi.midi import midi_to_emission_stream 
 42  from jazzparser.taggers.models import TaggerModel 
 43  from jazzparser.data.input import detect_input_type, MidiTaggerTrainingBulkInput 
 44  from jazzparser.utils.options import ModuleOption, choose_from_list 
 45  from jazzparser.utils.strings import str_to_bool 
 46  from jazzparser.utils.chords import int_to_chord_numeral 
 47  from jazzparser.utils.midi import note_ons 
 48  from . import tools 
49 50 -def _filter_illegal_transition_string(val):
51 # Interpret a specification of illegal transitions in string form 52 transitions = [] 53 for trans_string in val.split(","): 54 trans_string = trans_string.strip() 55 # This should be of the form SCHEMA-SCHEMA 56 if "-" not in trans_string: 57 raise ValueError, "illegal transition specification must be of "\ 58 "the form SCHEMA-SCHEMA. Could not understand %s" % trans_string 59 schema0,__,schema1 = trans_string.partition("-") 60 # Allow multiple alternative schemata to be given at once 61 schema0s = schema0.split("|") 62 schema1s = schema1.split("|") 63 for schema0 in schema0s: 64 for schema1 in schema1s: 65 transitions.append((schema0,schema1)) 66 return transitions
67
68 -def _filter_fixed_root_transition_string(val):
69 # Interpret a specification of fixed root transition in string form 70 transitions = {} 71 for trans_string in val.split(","): 72 trans_string = trans_string.strip() 73 # This should be of the form SCHEMA-SCHEMA-ROOT_INTERVAL 74 if trans_string.count("-") != 2: 75 raise ValueError, "fixed root transition string must be of "\ 76 "the form SCHEMA-SCHEMA-ROOT_INTERVAL. Could not understand "\ 77 "%s" % trans_string 78 schema0,__,rest = trans_string.partition("-") 79 schema1,__,root = rest.partition("-") 80 root = int(root) 81 if root < 0 or root > 11: 82 raise ValueError, "root interval must be in the range [0,11], got "\ 83 "%d" % root 84 # Allow multiple alternative schemata to be given at once 85 schema0s = schema0.split("|") 86 schema1s = schema1.split("|") 87 for schema0 in schema0s: 88 for schema1 in schema1s: 89 transitions[(schema0,schema1)] = root 90 return transitions
91
class ChordClassTaggerModel(TaggerModel):
    """
    Model class to go with L{ChordClassMidiTagger}. This is where the real
    meat of the model is implemented.

    The model's parameters live in a L{ChordClassHmm}, stored in C{self.hmm}.
    It is C{None} until the model has been trained or loaded.

    """
    MODEL_TYPE = 'chordclass'
    TRAINING_OPTIONS = TaggerModel.TRAINING_OPTIONS + [
        # Initialization options
        ModuleOption('ccprob', filter=float,
            help_text="Initialization of the emission distribution.",
            usage="ccprob=P, P is a probability. Prob P is distributed "
                "over the pitch classes that are in the chord class.",
            required=True),
        ModuleOption('metric', filter=str_to_bool,
            help_text="Create the model with a metrical component, as in the "
                "original Raphael & Stoddard model",
            usage="metric=True, or metric=False (default False)",
            default=False),
        ModuleOption('contprob', filter=(float, choose_from_list(["learn"])),
            help_text="Continuation probability for transition initialization: "
                "probability of staying in the same state between emissions. "
                "Use value 'learn' to learn the self-transition probabilities "
                "from the durations in the transition training data",
            usage="contprob=P, P is a probability or 'learn'",
            default=0.3),
        ModuleOption('maxnotes', filter=int,
            help_text="Maximum number of notes that can be generated from a "
                "a state. Limit is required to make the distribution finite.",
            usage="maxnotes=N, N is an integer",
            default=100),
        ModuleOption('illegal_transitions', filter=_filter_illegal_transition_string,
            help_text="List of grammatical schema transitions (pairs) that "
                "will be forced to have a 0 probability. You may specify "
                "groups of schemata, separating each with a |",
            usage="illegal_transitions=X0-Y0,X1-Y1,... where Xs and Ys and "
                "schema tags",
            default=[]),
        ModuleOption('fixed_roots', filter=_filter_fixed_root_transition_string,
            help_text="List of schema transitions that may have only one "
                "non-zero probability root interval. You may specify "
                "groups of schemata, separating each with a |",
            usage="fixed_roots=X0-Y0-R0,X1-Y1-R1,... where Xs and Ys and "
                "schema tags and Rs are integers in [0,11]",
            default={}),
        # Also include the options for the baum-welch training
    ] + ChordClassBaumWelchTrainer.OPTIONS

    def __init__(self, model_name, *args, **kwargs):
        """
        @type model_name: string
        @param model_name: name of the model, passed on to the superclass
        @keyword model: the L{ChordClassHmm} to wrap. Used when loading a
            stored model; otherwise the HMM is created by L{train}.

        """
        # Take 'model' out of the kwargs before the superclass sees them
        self.hmm = kwargs.pop('model', None)
        super(ChordClassTaggerModel, self).__init__(model_name, *args, **kwargs)

    def train(self, inputs, grammar=None, logger=None):
        """
        Initializes a new HMM from the training options and retrains it on
        the MIDI data using Baum-Welch. Does nothing if C{inputs} is empty.

        @type inputs: L{jazzparser.data.input.MidiTaggerTrainingBulkInput} or
            list of L{jazzparser.data.input.Input}s
        @param inputs: training MIDI data. Annotated chord sequences should
            also be given (though this is optional) by loading a
            bulk db input file in the MidiTaggerTrainingBulkInput.
        @param grammar: grammar to take the schemata from. The default
            grammar is loaded if none is given.
        @param logger: passed through to the Baum-Welch trainer

        """
        if grammar is None:
            from jazzparser.grammar import get_grammar
            # Load the default grammar
            grammar = get_grammar()

        if len(inputs) == 0:
            # No data - nothing to do
            return

        # Check the type of one of the inputs - no guarantee they're all the
        # same, but there's something seriously weird going on if they're not.
        # The result isn't needed: the call is kept only for the check.
        # NOTE(review): presumably raises on a disallowed type - confirm
        detect_input_type(inputs[0], allowed=['segmidi'])
        # Get the chord training data too if it's been given
        if isinstance(inputs, MidiTaggerTrainingBulkInput) and \
                inputs.chords is not None:
            chord_inputs = inputs.chords
        else:
            chord_inputs = None

        # Initialize the emission distribution for chord classes
        self.hmm = ChordClassHmm.initialize_chord_classes(
            self.options['ccprob'],
            self.options['maxnotes'],
            grammar,
            metric=self.options['metric'],
            illegal_transitions=self.options['illegal_transitions'],
            fixed_root_transitions=self.options['fixed_roots'])

        if chord_inputs:
            # If chord training data was given, initially train transition
            # distribution from this
            self.hmm.add_history("Training initial transition distribution "
                "from annotated chord data")
            self.hmm.train_transition_distribution(chord_inputs, grammar,
                contprob=self.options['contprob'])
        else:
            # Otherwise it gets left as a uniform distribution
            self.hmm.add_history("No annotated chord training data given. "
                "Transition distribution initialized to uniform.")

        # Get a Baum-Welch trainer to do the EM retraining
        # Pull out the options to pass to the trainer
        bw_opt_names = [opt.name for opt in ChordClassBaumWelchTrainer.OPTIONS]
        bw_opts = dict((name, val) for (name, val) in self.options.items()
            if name in bw_opt_names)
        retrainer = ChordClassBaumWelchTrainer(self.hmm, options=bw_opts)
        # Do the Baum-Welch training. The bound save method is the save
        # callback: calling it with no arguments saves this model.
        retrainer.train(inputs, logger=logger, save_callback=self.save)

        self.model_description = (
            "Initial chord class emission prob: %(ccprob)f\n"
            "Initial self-transition prob: %(contprob)s\n"
            "Metrical model: %(metric)s\n"
        ) % {
            'ccprob' : self.options['ccprob'],
            'metric' : self.options['metric'],
            'contprob' : self.options['contprob'],
        }

    @staticmethod
    def _load_model(data):
        """
        Rebuilds a model from the dictionary produced by L{_get_model_data}.

        """
        model = ChordClassHmm.from_picklable_dict(data['model'])
        name = data['name']
        return ChordClassTaggerModel(name, model=model)

    def _get_model_data(self):
        """
        Packs the model's name and its HMM (in picklable form) into a
        dictionary that L{_load_model} can read back.

        """
        data = {
            'name' : self.model_name,
            'model' : self.hmm.to_picklable_dict()
        }
        return data

    def _get_readable_parameters(self):
        """ Produce a human-readable repr of the params of the model """
        hmm = self.hmm
        # Output is gathered one line at a time and joined at the end
        lines = []
        emit = lines.append

        def _ranked(dist):
            # (probability, sample) pairs of a distribution, highest first
            return reversed(sorted(
                [(dist.prob(samp), samp) for samp in dist.samples()]))

        def _fmt_cond(cond):
            if hmm.metric:
                # Two values, so two placeholders: a third one here raised
                # a TypeError for every metrical model
                return "%s, %s" % (cond[0], cond[1])
            else:
                # Don't bother showing the 2nd element: it's always 0
                return "%s" % (cond[0])

        emit("\nChord classes:\n%s" % ", ".join(
            [str(cc) for cc in hmm.chord_classes]))
        emit("\nSchemata:\n%s" % ", ".join(sorted(hmm.schemata)))
        emit("\nIllegal transitions (probabilities below will be "
            "redistributed):\n%s" % ", ".join(
                ["%s-%s" % labels for labels in hmm.illegal_transitions]))
        emit("\nFixed-root transitions:\n%s" % "\n".join(
            ["%s -> %s, %d" % (label0, label1, root) for
                ((label0, label1), root) in
                    sorted(hmm.fixed_root_transitions.items())]))

        emit("\n*** Emission distributions ***")
        # Output emission parameters
        em_dist = hmm.emission_dist
        for cond in sorted(em_dist.conditions()):
            emit("%s:" % _fmt_cond(cond))
            for (prob, samp) in _ranked(em_dist[cond]):
                emit(" %s: %s" % (samp, prob))

        emit("\n*** Transition distributions ***")
        emit("Schema transitions")
        # Output transition parameters
        schema_trans_dist = hmm.schema_transition_dist
        for label0 in sorted(schema_trans_dist.conditions()):
            emit("%s ->" % label0)
            for (prob, samp) in _ranked(schema_trans_dist[label0]):
                emit(" %s: %s" % (samp, prob))

        emit("\nRoot transitions")
        root_trans_dist = hmm.root_transition_dist
        for label0, label1 in sorted(root_trans_dist.conditions()):
            # Don't show the distribution for transitions where the schema
            # transition is forced to have 0 probability
            if (label0, label1) in hmm.illegal_transitions:
                emit("%s -> %s illegal" % (label0, label1))
            elif (label0, label1) in hmm.fixed_root_transitions:
                # Show a special case for the constrained transitions
                emit("%s -> %s, only %d" % (label0, label1,
                    hmm.fixed_root_transitions[(label0, label1)]))
            else:
                emit("%s -> %s," % (label0, label1))
                for (prob, samp) in _ranked(root_trans_dist[(label0, label1)]):
                    emit(" %s: %s" % (samp, prob))

        emit("\n*** Initial state distribution ***")
        for (prob, label) in _ranked(hmm.initial_state_dist):
            emit("%s: %s" % (label, prob))

        # Every line is newline-terminated, including the last
        return "\n".join(lines) + "\n"
    readable_parameters = property(_get_readable_parameters)

    def __get_description(self):
        """ Overridden to add history onto description. """
        if self.model_description is not None:
            model_desc = "\n\n%s" % self.model_description
        else:
            model_desc = ""
        return "%s%s\nModel history:\n%s" % \
            (self._description, model_desc, self.hmm.history)
    description = property(__get_description)
class ChordClassMidiTagger(SegmidiTagger):
    """
    Segmidi tagger that uses a L{ChordClassTaggerModel}. All of the tagging
    is done at construction time: the categories for each priority group
    are stored in C{self.category_sets} and handed out by L{get_signs}.

    """
    MODEL_CLASS = ChordClassTaggerModel
    TAGGER_OPTIONS = SegmidiTagger.TAGGER_OPTIONS + [
        ModuleOption('decoden', filter=int,
            help_text="Number of best categories to consider for each timestep",
            usage="decoden=N, where N is an integer",
            default=5),
    ]
    shell_tools = SegmidiTagger.shell_tools + [
        tools.StateGridTool()
    ]

    def __init__(self, *args, **kwargs):
        SegmidiTagger.__init__(self, *args, **kwargs)
        grammar = self.grammar
        hmm = self.model.hmm

        # Convert the input into the observation form the HMM works with
        emissions = midi_to_emission_stream(self.input,
                                            metric=hmm.metric,
                                            remove_empty=False)

        # Number of tags to keep for each timestep
        limit = self.options['decoden']
        # Only the first element of the emission stream result is decoded.
        # NOTE(review): gamma is presumably the per-timestep state
        # probability matrix, indexed [time, state] - confirm against
        # ChordClassHmm.compute_gamma
        gamma = hmm.compute_gamma(emissions[0])

        top_tags = []
        for t in range(gamma.shape[0]):
            # Pair each state label with its value at this timestep
            state_probs = dict(
                (label, gamma[t, i]) for i, label in enumerate(hmm.label_dom))
            # The chord class is dropped from the label: tags are just
            # (schema, root) pairs
            candidates = [
                (prob, (schema, root))
                for ((schema, root, chord_class), prob) in state_probs.items()]
            # Highest values first
            candidates.sort()
            candidates.reverse()
            top_tags.append(candidates[:limit])
        self.top_tags = top_tags

        # Process the tags to add spans for repeated tags
        spans = prepare_categories(top_tags)
        # Each spanset is a priority group of spans
        category_sets = []
        # (start, end, category) triples that have already been added, in
        # any group. Kept as a list so that membership is tested by equality.
        seen = []
        for spanset in spans:
            group = []
            # Get a category for each span by its tag
            for start, end, (log_prob, (schema, root)) in spanset:
                # For now just use the start cell as the time value
                # TODO: maybe use the midi tick time??
                signs = grammar.get_signs_for_tag(
                    schema, {'root' : root, 'time' : start })
                for category in signs:
                    # NOTE(review): the value is exponentiated base 2, so it
                    # is presumably a base-2 log probability - confirm
                    entry = (start, end, (category, schema, 2**log_prob))
                    # Don't add the same category twice to the same span
                    # This can happen because some (schema,root) pairs map
                    # to the same category
                    key = (start, end, category)
                    if key not in seen:
                        group.append(entry)
                        seen.append(key)
            category_sets.append(group)
        self.category_sets = category_sets

    def get_signs(self, offset=0):
        """
        Returns the categories of the priority group at index C{offset}, or
        an empty list once the groups have run out.

        """
        if offset < len(self.category_sets):
            return self.category_sets[offset]
        return []
387