1 """Baum-Welch EM trainer for the chord labeling model.
2
3 """
4 """
5 ============================== License ========================================
6 Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding
7
8 This file is part of The Jazz Parser.
9
10 The Jazz Parser is free software: you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation, either version 3 of the License, or
13 (at your option) any later version.
14
15 The Jazz Parser is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with The Jazz Parser. If not, see <http://www.gnu.org/licenses/>.
22
23 ============================ End license ======================================
24
25 """
26 __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"
27
28 import numpy
29 from numpy import float64, sum as array_sum, zeros, log2, add as array_add
30
31 from jazzparser.utils.nltk.probability import logprob
32 from jazzparser.utils.options import ModuleOption
33 from jazzparser.utils.strings import str_to_bool
34 from jazzparser.utils.base import ExecutionTimer
35
36 from jazzparser.utils.nltk.ngram.baumwelch import BaumWelchTrainer
37
38
39 ADD_SMALL = 1e-6
40
def sequence_updates(sequence, last_model, empty_arrays, array_ids,
                     update_initial=True):
    """
    Evaluates the forward/backward probability matrices for a
    single sequence under the model that came from the previous
    iteration and returns matrices that contain the updates
    to be made to the distributions during this iteration.

    This is wrapped up in a function so it can be run in
    parallel for each sequence. Once all sequences have been
    evaluated, the results are combined and model updated.

    @param sequence: the observations, indexed by timestep. Each timestep
        is iterated over as a collection of notes, which are reduced to
        pitch classes relative to the root of each state.
    @param last_model: model from the previous iteration. It must provide
        C{normal_forward_probabilities}, C{gamma_probabilities},
        C{compute_xi} and a C{label_dom} of (key, root, label) states.
    @param empty_arrays: tuple of ten arrays: the five update arrays
        (initial keys, initial chords, key transitions, chord transitions,
        emissions) followed by their five denominators. The update arrays
        and the first three denominators are modified in place.
    @param array_ids: pair (chord_ids, chord_type_ids) of dictionaries
        mapping chords and chord types to array indices.
    @type update_initial: bool
    @param update_initial: if update_initial=False,
        the initial state distribution updates won't be made for this sequence
    @return: the five update arrays, the five denominators and the log
        probability of the sequence (11 values in all), or None if the
        computation was interrupted by a KeyboardInterrupt.

    """
    try:
        # chord_trans_denom and ems_denom are unpacked only to check the
        # shape of the tuple: they are recomputed from scratch below.
        (initial_keys, initial_chords, key_trans, chord_trans, ems,
         initial_keys_denom, initial_chords_denom, key_trans_denom,
         chord_trans_denom, ems_denom) = empty_arrays
        chord_ids, chord_type_ids = array_ids

        # Forward probabilities, reused for the state occupation
        # probabilities (gamma)
        fwds, seq_logprob = last_model.normal_forward_probabilities(
            sequence, seq_prob=True)
        gamma = last_model.gamma_probabilities(sequence, forward=fwds)
        # Transition occupation probabilities
        xi = last_model.compute_xi(sequence)

        label_dom = last_model.label_dom
        # Position of each state in the label domain, used to index gamma/xi
        state_ids = dict((state, index)
                         for (index, state) in enumerate(label_dom))
        T = len(sequence)

        for time in range(T):
            for state in label_dom:
                keyi, rooti, labeli = state
                state_i = state_ids[state]
                # Chords are identified by their root relative to the key
                chord_i = chord_ids[((rooti - keyi) % 12, labeli)]
                occupancy = gamma[time][state_i]

                if time == 0 and update_initial:
                    # Initial state distributions: skipped when the caller
                    # asks for no initial updates, as documented above
                    initial_keys[keyi] += occupancy
                    initial_chords[chord_i] += occupancy

                if time == T - 1:
                    # Last timestep: the mass goes to the final column of
                    # the chord transition array
                    chord_trans[chord_i][-1] += occupancy
                else:
                    # Any other timestep: count the transition to every
                    # possible next state
                    for next_state in label_dom:
                        keyj, rootj, labelj = next_state
                        state_j = state_ids[next_state]
                        chord_j = chord_ids[((rootj - keyj) % 12, labelj)]
                        key_change = (keyj - keyi) % 12

                        transition = xi[time][state_i][state_j]
                        key_trans[key_change] += transition
                        chord_trans[chord_i][chord_j] += transition

                # Emissions: each note counts as a pitch class relative to
                # the chord root
                for note in sequence[time]:
                    pc = (note - rooti) % 12
                    ems[chord_type_ids[labeli]][pc] += occupancy

        # The denominators are the totals of the corresponding numerators
        initial_keys_denom[0] = array_sum(initial_keys)
        initial_chords_denom[0] = array_sum(initial_chords)
        key_trans_denom[0] = array_sum(key_trans)
        chord_trans_denom = array_sum(chord_trans, axis=1)
        ems_denom = array_sum(ems, axis=1)

        return (initial_keys, initial_chords, key_trans, chord_trans, ems,
                initial_keys_denom, initial_chords_denom, key_trans_denom,
                chord_trans_denom, ems_denom, seq_logprob)
    except KeyboardInterrupt:
        # Return quietly (None) rather than propagating the interrupt
        return
123
"""
Baum-Welch training for L{jazzparser.misc.chordlabel.HPChordLabeler}
models.

"""
# Trainer options: everything BaumWelchTrainer accepts, plus 'initkey'.
# The 'initkey' value is parsed with str_to_bool and is read when the model
# distributions are replaced, to decide whether the initial key distribution
# is re-estimated.
# NOTE(review): the usage string says "(default true)" but the option is
# declared with default=False -- confirm which one is intended.
OPTIONS = BaumWelchTrainer.OPTIONS + [
    ModuleOption('initkey', filter=str_to_bool,
        help_text="Train the initial key distribution. The default "\
            "behaviour will leave the distribution alone (probably inited "\
            "to uniform): suitable if the training data is transposed into "\
            "a common key. If your data has keys, set to true",
        usage="initkey=B, where B is 'true' or 'false' "\
            "(default true)",
        default=False),
]
140
def record_history(self, line):
    """
    Adds the given line to the history kept by the model, so that the
    model carries a record of the training steps applied to it.

    """
    model = self.model
    model.add_history(line)
148
# Expose the module-level sequence_updates function as a static method, so
# it can be looked up on the trainer without being bound to an instance.
sequence_updates = staticmethod(sequence_updates)
150
153
# The sizes of the model's chord domain and chord-type vocabulary fix the
# dimensions of the accumulators below.
num_chords = len(self.model.chord_dom)
num_chord_types = len(self.model.chord_vocab)

# Numerator accumulators: zero-filled float64 arrays.
initial_keys = zeros((12,), float64)
initial_chords = zeros((num_chords,), float64)
key_trans = zeros((12,), float64)
# One column more than there are chords: sequence_updates adds the mass of
# the last timestep into the final column.
chord_trans = zeros((num_chords, num_chords+1), float64)
ems = zeros((num_chord_types, 12), float64)

# Denominator accumulators: a single value for each of the first three
# distributions, one value per row for the two conditional ones.
initial_keys_denom = zeros((1,), float64)
initial_chords_denom = zeros((1,), float64)
key_trans_denom = zeros((1,), float64)
chord_trans_denom = zeros((num_chords,), float64)
ems_denom = zeros((num_chord_types,), float64)

# The order matters: sequence_updates unpacks this tuple positionally.
return (initial_keys, initial_chords, key_trans, chord_trans, ems,
        initial_keys_denom, initial_chords_denom, key_trans_denom,
        chord_trans_denom, ems_denom)
175
# Map every chord in the model's chord domain to its position, with None
# appended as the last entry (so None gets the highest index).
chord_ids = dict([(chord,id) for (id,chord) in \
                    enumerate(self.model.chord_dom+[None])])
# Map every chord type (the keys of chord_vocab) to its position.
chord_type_ids = dict([(ctype,id) for (id,ctype) in \
                    enumerate(self.model.chord_vocab.keys())])
return (chord_ids, chord_type_ids)
182
if result is None:
    # Nothing to merge. NOTE(review): presumably result comes from
    # sequence_updates, which returns None when interrupted -- confirm
    # against the caller.
    return

# Only the first ten entries are merged (sequence_updates returns the
# sequence log probability as an eleventh). Each is added to the
# corresponding global array in place: the third argument of array_add is
# the output array.
for local_array,global_array in zip(result[:10], self.global_arrays):
    array_add(global_array, local_array, global_array)
193
"""
Replaces the distributions of the saved model with the probabilities
taken from the arrays of updates. self.model is expected to be
made up of mutable distributions when this is called.

Every probability is computed as logprob(count + ADD_SMALL) minus
logprob(denominator + ADD_SMALL * n), i.e. the counts are smoothed by
adding a small constant before normalizing.

"""
# arrays: the five update arrays followed by their five denominators.
(initial_keys, initial_chords, key_trans, chord_trans, ems,
 initial_keys_denom, initial_chords_denom, key_trans_denom,
 chord_trans_denom, ems_denom) = arrays
# array_ids: dictionaries from chords / chord types to array indices.
chord_ids, chord_type_ids = array_ids

num_chords = len(self.model.chord_dom)
# NOTE(review): num_emissions is not used in the code below.
num_emissions = len(self.model.emission_dom)
num_chord_types = len(self.model.chord_vocab)

# Initial key distribution: only re-estimated if the 'initkey' option is set
if self.options['initkey']:
    for key in range(12):
        prob = logprob(initial_keys[key] + ADD_SMALL) - \
                logprob(initial_keys_denom[0] + ADD_SMALL*12)
        self.model.initial_key_dist.update(key, prob)

# Initial chord distribution
for chord in self.model.chord_dom:
    chordi = chord_ids[chord]

    prob = logprob(initial_chords[chordi] + ADD_SMALL) - \
            logprob(initial_chords_denom[0] + ADD_SMALL*num_chords)
    self.model.initial_chord_dist.update(chord, prob)

# Key transition distribution, over the 12 possible key changes
for key in range(12):
    prob = logprob(key_trans[key] + ADD_SMALL) - \
            logprob(key_trans_denom[0] + ADD_SMALL*12)
    self.model.key_transition_dist.update(key, prob)

# Chord transition distribution, one distribution per previous chord.
# NOTE(review): the inner loop covers num_chords+1 outcomes (every chord
# plus None) but the denominator is smoothed with ADD_SMALL*num_chords, so
# the smoothed values do not sum exactly to one -- confirm whether
# ADD_SMALL*(num_chords+1) was intended.
for chord0 in self.model.chord_dom:
    chordi = chord_ids[chord0]

    for chord1 in self.model.chord_dom+[None]:
        chordj = chord_ids[chord1]

        prob = logprob(chord_trans[chordi][chordj] + ADD_SMALL) - \
                logprob(chord_trans_denom[chordi] + ADD_SMALL*num_chords)
        self.model.chord_transition_dist[chord0].update(chord1, prob)

# Emission distribution, one distribution over 12 pitches per chord type.
# NOTE(review): the loop covers 12 pitches but the denominator is smoothed
# with ADD_SMALL*num_chord_types -- confirm whether ADD_SMALL*12 was
# intended.
for label in self.model.chord_vocab:
    labeli = chord_type_ids[label]

    for pitch in range(12):
        prob = logprob(ems[labeli][pitch] + ADD_SMALL) - \
                logprob(ems_denom[labeli] + ADD_SMALL*num_chord_types)
        self.model.emission_dist[label].update(pitch, prob)
251
253
254
255
# Keep trying until the model has been saved: an IOError or OSError from
# save() does not abort, the user is prompted and the save is attempted
# again.
while True:
    try:
        self.model.save()
    except (IOError, OSError), err:
        # Report the failure and wait for the user before retrying
        print "Error writing model to disk: %s. " % err
        raw_input("Press <enter> to try again... ")
    else:
        # Saved without an error: stop retrying
        break
265