Package jazzparser :: Package misc :: Package chordlabel :: Module baumwelch
[hide private]
[frames] | [no frames]

Source Code for Module jazzparser.misc.chordlabel.baumwelch

  1  """Baum-Welch EM trainer for the chord labeling model. 
  2   
  3  """ 
  4  """ 
  5  ============================== License ======================================== 
  6   Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding 
  7    
  8   This file is part of The Jazz Parser. 
  9    
 10   The Jazz Parser is free software: you can redistribute it and/or modify 
 11   it under the terms of the GNU General Public License as published by 
 12   the Free Software Foundation, either version 3 of the License, or 
 13   (at your option) any later version. 
 14    
 15   The Jazz Parser is distributed in the hope that it will be useful, 
 16   but WITHOUT ANY WARRANTY; without even the implied warranty of 
 17   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 18   GNU General Public License for more details. 
 19    
 20   You should have received a copy of the GNU General Public License 
 21   along with The Jazz Parser.  If not, see <http://www.gnu.org/licenses/>. 
 22   
 23  ============================ End license ====================================== 
 24   
 25  """ 
 26  __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"  
 27   
 28  import numpy 
 29  from numpy import float64, sum as array_sum, zeros, log2, add as array_add 
 30   
 31  from jazzparser.utils.nltk.probability import logprob 
 32  from jazzparser.utils.options import ModuleOption 
 33  from jazzparser.utils.strings import str_to_bool 
 34  from jazzparser.utils.base import ExecutionTimer 
 35   
 36  from jazzparser.utils.nltk.ngram.baumwelch import BaumWelchTrainer 
 37   
 38  # Small quantity added to every probability to ensure we never get zeros 
 39  ADD_SMALL = 1e-6 
 40   
def sequence_updates(sequence, last_model, empty_arrays, array_ids,
                     update_initial=True):
    """
    Evaluates the forward/backward probability matrices for a
    single sequence under the model that came from the previous
    iteration and returns matrices that contain the updates
    to be made to the distributions during this iteration.

    This is wrapped up in a function so it can be run in
    parallel for each sequence. Once all sequences have been
    evaluated, the results are combined and model updated.

    Each state in C{last_model.label_dom} is a C{(key, root, label)} triple.
    It is mapped to the key-relative chord C{((root-key)%12, label)}, whose
    index is looked up in C{chord_ids}. Its chord label is looked up in
    C{chord_type_ids} to find the emission row.

    @param sequence: the observations. Each timestep is an iterable of
        integer notes, which are taken mod 12 relative to the state's root.
    @param last_model: the model from the previous iteration. It must
        provide C{label_dom}, C{normal_forward_probabilities},
        C{gamma_probabilities} and C{compute_xi}.
    @param empty_arrays: the 10 zeroed accumulator arrays, in the order
        C{(initial_keys, initial_chords, key_trans, chord_trans, ems,
        initial_keys_denom, initial_chords_denom, key_trans_denom,
        chord_trans_denom, ems_denom)}. The numerator arrays and the three
        single-element denominators are modified in place.
    @param array_ids: a C{(chord_ids, chord_type_ids)} pair of dicts mapping
        chords / chord labels to array indices. C{chord_ids[None]} must be
        the last column of C{chord_trans}, which is used for the final state.
    @type update_initial: bool
    @param update_initial: if update_initial=False,
        the initial state distribution updates won't be made for this sequence
    @return: the 10 accumulator arrays followed by the sequence's log
        probability, or C{None} if interrupted by C{KeyboardInterrupt}
    """
    try:
        (initial_keys, initial_chords, key_trans, chord_trans, ems,
         initial_keys_denom, initial_chords_denom, key_trans_denom,
         chord_trans_denom, ems_denom) = empty_arrays
        chord_ids, chord_type_ids = array_ids

        # Compute the forwards with seq_prob=True
        fwds, seq_logprob = last_model.normal_forward_probabilities(
            sequence, seq_prob=True)
        # gamma contains the state occupation probability for each state at
        # each timestep
        gamma = last_model.gamma_probabilities(sequence, forward=fwds)
        # xi contains the probability of every state transition at every
        # timestep
        xi = last_model.compute_xi(sequence)

        # Resolve every state's array indices once, rather than repeating
        # the dict lookups inside the O(T * S^2) loop below
        states = []
        for state_i, (key, root, label) in enumerate(last_model.label_dom):
            chord_i = chord_ids[((root - key) % 12, label)]
            states.append((state_i, key, root, label, chord_i))

        T = len(sequence)
        for time in range(T):
            gamma_t = gamma[time]
            last_step = (time == T - 1)

            for state_i, keyi, rooti, labeli, chord_i in states:
                occupancy = gamma_t[state_i]

                if time == 0 and update_initial:
                    # Update initial distributions (skipped if the caller
                    # asked for no initial-state updates from this sequence)
                    initial_keys[keyi] += occupancy
                    initial_chords[chord_i] += occupancy

                if last_step:
                    # Last timestep: count the transition to the final
                    # state, stored in the last column of chord_trans
                    chord_trans[chord_i][-1] += occupancy
                else:
                    # Go through all possible next states to update the
                    # transition distributions
                    xi_row = xi[time][state_i]
                    for state_j, keyj, _, _, chord_j in states:
                        trans_prob = xi_row[state_j]
                        key_trans[(keyj - keyi) % 12] += trans_prob
                        chord_trans[chord_i][chord_j] += trans_prob

                # Emission update: each note at this timestep, expressed as
                # a pitch class relative to the state's root
                for note in sequence[time]:
                    ems[chord_type_ids[labeli]][(note - rooti) % 12] += \
                        occupancy

        # Calculate the denominators by summing the numerators
        initial_keys_denom[0] = array_sum(initial_keys)
        initial_chords_denom[0] = array_sum(initial_chords)
        key_trans_denom[0] = array_sum(key_trans)
        chord_trans_denom = array_sum(chord_trans, axis=1)
        ems_denom = array_sum(ems, axis=1)

        # Wrap this all up in a tuple to return to the master
        return (initial_keys, initial_chords, key_trans, chord_trans, ems,
                initial_keys_denom, initial_chords_denom, key_trans_denom,
                chord_trans_denom, ems_denom, seq_logprob)
    except KeyboardInterrupt:
        # None signals cancellation; sequence_updates_callback skips it
        return None
123
class HPBaumWelchTrainer(BaumWelchTrainer):
    """
    Baum-Welch training for L{jazzparser.misc.chordlabel.HPChordLabeler}
    models.

    Supplies the model-specific pieces of training: the accumulator arrays,
    their index mappings, combination of per-sequence results and
    re-estimation of the model's distributions. These are presumably driven
    by the generic L{BaumWelchTrainer} loop.

    """
    OPTIONS = BaumWelchTrainer.OPTIONS + [
        ModuleOption('initkey', filter=str_to_bool,
            help_text="Train the initial key distribution. The default "
                "behaviour will leave the distribution alone (probably inited "
                "to uniform): suitable if the training data is transposed into "
                "a common key. If your data has keys, set to true",
            # The documented default now matches default=False below
            usage="initkey=B, where B is 'true' or 'false' "
                "(default false)",
            default=False),
    ]

    def record_history(self, line):
        """
        Stores a line in the history of the model to keep a record of training
        steps.

        """
        self.model.add_history(line)

    # The per-sequence worker lives at module level (so it can be run in
    # parallel for each sequence) and is exposed here as a static method
    sequence_updates = staticmethod(sequence_updates)

    def create_mutable_model(self, model):
        """Returns a mutable copy of C{model} whose distributions can be updated."""
        return model.copy(mutable=True)

    def get_empty_arrays(self):
        """
        Builds zeroed float64 accumulators for one round of updates.

        @return: a 10-tuple, in the order C{sequence_updates} expects:
            - C{initial_keys}, shape (12,)
            - C{initial_chords}, shape (num_chords,)
            - C{key_trans}, shape (12,), indexed by key change mod 12
            - C{chord_trans}, shape (num_chords, num_chords+1); the last
              column holds transitions to the final state
            - C{ems}, shape (num_chord_types, 12)
            - the denominators: three of shape (1,), then (num_chords,)
              and (num_chord_types,)
        """
        num_chords = len(self.model.chord_dom)
        num_chord_types = len(self.model.chord_vocab)

        # Accumulators
        initial_keys = zeros((12,), float64)
        initial_chords = zeros((num_chords,), float64)
        key_trans = zeros((12,), float64)
        chord_trans = zeros((num_chords, num_chords + 1), float64)
        ems = zeros((num_chord_types, 12), float64)

        # Denominator accumulators
        initial_keys_denom = zeros((1,), float64)
        initial_chords_denom = zeros((1,), float64)
        key_trans_denom = zeros((1,), float64)
        chord_trans_denom = zeros((num_chords,), float64)
        ems_denom = zeros((num_chord_types,), float64)

        return (initial_keys, initial_chords, key_trans, chord_trans, ems,
                initial_keys_denom, initial_chords_denom, key_trans_denom,
                chord_trans_denom, ems_denom)

    def get_array_indices(self):
        """
        Maps chords and chord labels to their row/column in the arrays.

        @return: C{(chord_ids, chord_type_ids)}. C{chord_ids} covers
            C{chord_dom} plus C{None} (the final state), which gets the last
            index. C{chord_type_ids} covers the keys of C{chord_vocab}.
        """
        chord_ids = dict((chord, index) for index, chord in
                         enumerate(self.model.chord_dom + [None]))
        chord_type_ids = dict((ctype, index) for index, ctype in
                              enumerate(self.model.chord_vocab.keys()))
        return (chord_ids, chord_type_ids)

    def sequence_updates_callback(self, result):
        """
        Adds one sequence's accumulators into C{self.global_arrays} in place.

        @param result: the tuple returned by C{sequence_updates}, or C{None}
            if that computation was cancelled
        """
        if result is None:
            # Process cancelled: do no updates
            return

        # The members of the result tuple (apart from the logprob at the end)
        # should match up with the array they're to be added to in global_arrays
        for local_array, global_array in zip(result[:10], self.global_arrays):
            # Add the arrays together and store the result in the global array
            array_add(global_array, local_array, global_array)

    def update_model(self, arrays, array_ids):
        """
        Replaces the distributions of the saved model with the probabilities
        taken from the arrays of updates. self.model is expected to be
        made up of mutable distributions when this is called.

        Each probability is re-estimated with additive smoothing::

            P(x) = (count(x) + ADD_SMALL) / (total + N*ADD_SMALL)

        where N is the number of values the distribution ranges over. Since
        each denominator is the sum of its counts (as C{sequence_updates}
        computes them), the smoothed distributions sum to one.

        @param arrays: the 10 accumulated arrays, in C{get_empty_arrays} order
        @param array_ids: C{(chord_ids, chord_type_ids)} from
            C{get_array_indices}
        """
        (initial_keys, initial_chords, key_trans, chord_trans, ems,
         initial_keys_denom, initial_chords_denom, key_trans_denom,
         chord_trans_denom, ems_denom) = arrays
        chord_ids, chord_type_ids = array_ids

        num_chords = len(self.model.chord_dom)
        # Sizes of the value sets used to renormalise the smoothing mass
        num_keys = 12
        num_pitch_classes = 12
        # Chord transitions can also go to the final state (None)
        num_chord_targets = num_chords + 1

        # Initial keys distribution
        # Only update this distribution if asked to: often we should leave it
        if self.options['initkey']:
            for key in range(num_keys):
                prob = logprob(initial_keys[key] + ADD_SMALL) - \
                    logprob(initial_keys_denom[0] + ADD_SMALL*num_keys)
                self.model.initial_key_dist.update(key, prob)

        # Initial chords distribution
        for chord in self.model.chord_dom:
            chordi = chord_ids[chord]
            prob = logprob(initial_chords[chordi] + ADD_SMALL) - \
                logprob(initial_chords_denom[0] + ADD_SMALL*num_chords)
            self.model.initial_chord_dist.update(chord, prob)

        # Key transition distribution
        for key in range(num_keys):
            prob = logprob(key_trans[key] + ADD_SMALL) - \
                logprob(key_trans_denom[0] + ADD_SMALL*num_keys)
            self.model.key_transition_dist.update(key, prob)

        # Chord transition distribution (targets include the final state)
        for chord0 in self.model.chord_dom:
            chordi = chord_ids[chord0]
            for chord1 in self.model.chord_dom + [None]:
                chordj = chord_ids[chord1]
                prob = logprob(chord_trans[chordi][chordj] + ADD_SMALL) - \
                    logprob(chord_trans_denom[chordi] +
                            ADD_SMALL*num_chord_targets)
                self.model.chord_transition_dist[chord0].update(chord1, prob)

        # Emission distribution: over the 12 root-relative pitch classes
        for label in self.model.chord_vocab:
            labeli = chord_type_ids[label]
            for pitch in range(num_pitch_classes):
                prob = logprob(ems[labeli][pitch] + ADD_SMALL) - \
                    logprob(ems_denom[labeli] + ADD_SMALL*num_pitch_classes)
                self.model.emission_dist[label].update(pitch, prob)

    def save(self):
        """Saves the model, prompting and retrying until the write succeeds."""
        # If the writing fails, wait till I've had a chance to sort it
        # out and then try again. This happens when my AFS token runs
        # out
        while True:
            try:
                self.model.save()
            except (IOError, OSError) as err:
                print("Error writing model to disk: %s. " % err)
                raw_input("Press <enter> to try again... ")
            else:
                # Saved normally
                break
265