1 """N-gram tagger that operates on chord inputs or chord lattices.
2
3 This differs from the "ngram" tagger in that it can accept lattice input.
4 The model is also very slightly different: it's almost equivalent, but
5 makes a slightly different independence assumption. In general, you're
6 probably better off using this version.
7
8 I'm using this version for all the experiments in the thesis, so that
9 I can use the same supertagger for the supertagging experiments, parsing
10 experiments and MIDI parsing experiments.
11
12 Note that this used to be called C{bigram-multi}, before I generalized it
to n-grams and renamed it C{ngram-multi}. There may still be bugs caused by
this renaming, or old config files etc. that haven't been updated. The
tagger and model classes have been renamed correspondingly.
16
17 """
18 """
19 ============================== License ========================================
20 Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding
21
22 This file is part of The Jazz Parser.
23
24 The Jazz Parser is free software: you can redistribute it and/or modify
25 it under the terms of the GNU General Public License as published by
26 the Free Software Foundation, either version 3 of the License, or
27 (at your option) any later version.
28
29 The Jazz Parser is distributed in the hope that it will be useful,
30 but WITHOUT ANY WARRANTY; without even the implied warranty of
31 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
32 GNU General Public License for more details.
33
34 You should have received a copy of the GNU General Public License
35 along with The Jazz Parser. If not, see <http://www.gnu.org/licenses/>.
36
37 ============================ End license ======================================
38
39 """
40 __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"
41
42 import cPickle as pickle
43 from cStringIO import StringIO
44 from operator import mul
45
46 from jazzparser.taggers.models import ModelTagger, ModelLoadError, \
47 TaggerModel, TaggingModelError, ModelSaveError
48 from jazzparser.taggers import process_chord_input
49 from jazzparser.parsers.base.utils import SpanCombiner
50 from jazzparser.data.input import DbBulkInput, AnnotatedDbBulkInput, \
51 ChordInput, WeightedChordLabelInput, DbInput
52
53 from jazzparser.utils.options import ModuleOption, choose_from_list, \
54 choose_from_dict
55 from jazzparser.utils.probabilities import batch_sizes, beamed_batch_sizes
56 from jazzparser.utils.nltk.probability import ESTIMATORS, laplace_estimator, \
57 get_estimator_name
58 from jazzparser.taggers.chordmap import get_chord_mapping, \
59 get_chord_mapping_module_option
60
61 from .model import MultiChordNgramModel, lattice_to_emissions
62 from .. import TaggerTrainingError
# Identifier under which this tagger model type is registered/stored
MODEL_TYPE = 'ngram-multi'

# Options accepted at training time: the model-specific ones declared here,
# followed by the generic options inherited from TaggerModel
TRAINING_OPTIONS = [
    # Order of the n-gram model
    ModuleOption('n',
        filter=int,
        help_text="Length of the n-grams which this model will use.",
        usage="n=N, where N is an integer. Defaults to bigrams",
        default=2),
    # How many lower orders to back off to
    ModuleOption('backoff',
        filter=int,
        help_text=("Number of orders of backoff to use. This must be "
                   "less than n. E.g. if using a trigram model (n=3) you can "
                   "set backoff=2 to back off to bigrams and from bigrams "
                   "to unigrams. Set to 0 to use no backoff at all (default)."),
        usage="backoff=X, where X is an integer < n",
        default=0),
    # Count threshold below which counts are treated as zero
    ModuleOption('cutoff',
        filter=int,
        help_text=("In estimating probabilities, treat any counts below "
                   "cutoff as zero"),
        usage="cutoff=X, where X is an integer",
        default=0),
    # Separate threshold for the backoff model; no default is given here,
    # so train() falls back to the main cutoff when it is unset
    ModuleOption('backoff_cutoff',
        filter=int,
        help_text=("Apply a different cutoff setting to the backoff model. "
                   "Default is to use the same as the main model"),
        usage="backoff_cutoff=X, where X is an integer"),
    # Probability estimator, chosen by name from ESTIMATORS
    ModuleOption('estimator',
        filter=choose_from_dict(ESTIMATORS),
        help_text=("A way of constructing a probability model given "
                   "the set of counts from the data. Default is to use "
                   "laplace (add-one) smoothing."),
        usage=("estimator=X, where X is one of: %s"
               % ", ".join(ESTIMATORS.keys())),
        default=laplace_estimator),

    get_chord_mapping_module_option(),
] + TaggerModel.TRAINING_OPTIONS
95
96 - def __init__(self, model_name, model=None, chordmap=None, *args, **kwargs):
107
def train(self, sequences, grammar=None, logger=None):
    """
    Trains a L{MultiChordNgramModel} on bulk chord-sequence input and
    stores it in C{self.model}, along with a human-readable summary of the
    training settings in C{self.model_description}.

    @type sequences: L{DbBulkInput} or L{AnnotatedDbBulkInput}
    @param sequences: the training chord sequences
    @param grammar: grammar whose C{pos_tags} are used as the set of
        schemata (tags). If not given, C{get_grammar()} is used to load one.
    @param logger: accepted for interface compatibility; not used here
    @raise TaggerTrainingError: if C{sequences} is not bulk db input, or if
        the C{backoff} option is not in the range 0 <= backoff < n
    """
    # Check the input before doing anything expensive, like loading the
    # grammar
    if not isinstance(sequences, (DbBulkInput, AnnotatedDbBulkInput)):
        raise TaggerTrainingError("can only train ngram-multi model "
            "on bulk db chord input (bulk-db or bulk-db-annotated). Got "
            "input of type '%s'" % type(sequences).__name__)

    order = self.options['n']
    backoff = self.options['backoff']
    # The backoff option is documented as having to be less than n: a
    # model can't back off further than unigrams
    if not 0 <= backoff < order:
        raise TaggerTrainingError("the backoff option must be between 0 "
            "and n-1 (got backoff=%d with n=%d)" % (backoff, order))

    if grammar is None:
        from jazzparser.grammar import get_grammar
        # No grammar given: load the default one
        grammar = get_grammar()

    # Only override the backoff model's cutoff if one was explicitly given
    if self.options['backoff_cutoff'] is None:
        backoff_kwargs = {}
    else:
        backoff_kwargs = {'cutoff' : self.options['backoff_cutoff']}

    # All possible POS tags (schemata) from the grammar
    schemata = grammar.pos_tags
    # Build the emission domain from every chord type the mapping can
    # produce, not just the ones that happen to occur in the training data
    chord_types = list(set(self.options['chord_mapping'].values()))

    self.model = MultiChordNgramModel.train(
                        sequences,
                        schemata,
                        chord_types,
                        self.options['estimator'],
                        cutoff=self.options['cutoff'],
                        chord_map=self.options['chord_mapping'],
                        order=order,
                        backoff_orders=backoff,
                        backoff_kwargs=backoff_kwargs)

    # Record the training settings in the model's descriptive text
    est_name = get_estimator_name(self.options['estimator'])
    self.model_description = """\
Order: %(order)d
Backoff orders: %(backoff)d
Probability estimator: %(est)s
Zero-count threshold: %(cutoff)d
Chord mapping: %(chordmap)s
Training sequences: %(seqs)d\
""" % \
        {
            'est' : est_name,
            'seqs' : len(sequences),
            'cutoff' : self.options['cutoff'],
            'chordmap' : self.options['chord_mapping'].name,
            'order' : order,
            'backoff' : backoff,
        }
162
163 @staticmethod
171
179
180
181
182
184 """
185 Returns a list of timesteps, each consisting of a dictionary mapping
186 states to their occupation probability in that timestep.
187
188 """
189 matrix = []
190
191 gamma = self.model.compute_gamma(observations)
192
193 T,N = gamma.shape
194
195 for t in range(T):
196 state_probs = {}
197 for s,state in enumerate(self.model.label_dom):
198 state_probs[state] = gamma[t, s]
199 matrix.append(state_probs)
200 return matrix
201
203 """
204 Like L{forward_backward_probabilities}, but only uses forward algorithm.
205
206 """
207
208 return self.model.normal_forward_probabilities(observations)
209
210
212 """ Produce a human-readable repr of the params of the model """
213 buff = StringIO()
214
215 try:
216
217
218 print >>buff, self.model_description
219
220 print >>buff, "\nNum emissions: %d" % self.model.num_emissions
221 print >>buff, "\nShowing only probs for non-zero counts. "\
222 "Others may have a non-zero prob by smoothing"
223
224 print >>buff, "\nChord mapping: %s:" % self.chordmap.name
225 for (crdin, crdout) in self.chordmap.items():
226 print >>buff, " %s -> %s" % (crdin, crdout)
227
228 print >>buff, "\nRoot transition dist"
229 for schema in sorted(self.model.root_transition_dist.conditions()):
230 print >>buff, " %s" % schema
231 for prob,interval in reversed(sorted(\
232 (self.model.root_transition_dist[schema].prob(interval),
233 interval) for \
234 interval in self.model.root_transition_dist[schema].samples())):
235 print >>buff, " %s: %s " % (interval, prob)
236 print >>buff
237
238 print >>buff, "Schema transition dist"
239 for context in sorted(self.model.schema_transition_dist.conditions()):
240 print >>buff, " %s" % ",".join([str(s) for s in context])
241 for prob,schema in reversed(sorted(\
242 (self.model.schema_transition_dist[context].prob(schema),
243 schema) for \
244 schema in self.model.schema_transition_dist[context].samples())):
245 print >>buff, " %s: %s " % (schema, prob)
246 print >>buff
247
248 print >>buff, "Emission dist"
249 for schema in sorted(self.model.emission_dist.conditions()):
250 print >>buff, " %s" % schema
251 for prob,chord in reversed(sorted(\
252 (self.model.emission_dist[schema].prob(chord),
253 chord) for \
254 chord in self.model.emission_dist[schema].samples())):
255 print >>buff, " %s: %s " % (chord, prob)
256 except AttributeError, err:
257
258
259 raise ValueError, "error generating model description "\
260 "(attribute error): %s" % err
261
262 return buff.getvalue()
263 readable_parameters = property(_get_readable_parameters)
264
# Tagger model class that this tagger loads and uses for inference
MODEL_CLASS = MultiChordNgramTaggerModel
# Tagging-time options: the generic ModelTagger options plus the choice of
# decoding method. Only the two methods accepted by choose_from_list are
# listed in the usage text (there is no 'viterbi' decoding option).
TAGGER_OPTIONS = ModelTagger.TAGGER_OPTIONS + [
    ModuleOption('decode', filter=choose_from_list([
                    'forward-backward', 'forward']),
        help_text="Decoding method for inference.",
        usage="decode=X, where X is one of 'forward-backward' or 'forward'",
        default="forward-backward"),
]
# Input types this tagger accepts
INPUT_TYPES = ['db', 'chords', 'labels']
277
def __init__(self, grammar, input, options=None, *args, **kwargs):
    """
    Tags the input with the n-gram model and prepares the batches of
    signs to be handed out.

    After construction:
     - C{self._tagged_times} holds, for each timestep, a list of
       C{(sign, (root, schema), prob)} tuples in descending order of
       probability. Zero-probability states and states for which the
       grammar gives no sign are left out.
     - C{self._tagged_spans} holds one list per batch offset of
       C{(start, end, (sign, (root, schema), prob))} spans. These include
       longer spans built by combining the same root and schema at adjacent
       timesteps.

    @param grammar: grammar used to instantiate a sign for each
        (root, schema) state
    @param input: the input to tag. One of the types in C{INPUT_TYPES}
        is expected.
    @param options: tagger options. A fresh empty dict is used if not given.
    @raise TypeError: if the wrapped input is not chord, db or weighted
        chord-label input
    """
    # Avoid a shared mutable default argument
    if options is None:
        options = {}
    super(MultiChordNgramTagger, self).__init__(grammar, input, options,
                                                *args, **kwargs)
    # NOTE(review): presumably this sets wrapped_input, input, times and
    # durations, all of which are used below. Confirm in jazzparser.taggers.
    process_chord_input(self)

    #### Tag the input sequence ####
    self._tagged_times = []
    self._tagged_spans = []
    # NOTE(review): this is never filled in here; the batch_ranges computed
    # below stay local to this method
    self._batch_ranges = []
    word_tag_probs = []

    # Map the chord types as the model requires
    chord_map = self.model.chordmap

    if isinstance(self.wrapped_input, ChordInput):
        chords = self.wrapped_input.to_db_input().chords
        observations = [(chord.root, chord_map[chord.type])
                        for chord in chords]
        self.input = chords
    elif isinstance(self.wrapped_input, DbInput):
        observations = [(chord.root, chord_map[chord.type])
                        for chord in self.wrapped_input.chords]
    elif isinstance(self.wrapped_input, WeightedChordLabelInput):
        # Use the wrapped input object that was type-checked above, as the
        # other branches do, not the raw constructor argument
        observations = lattice_to_emissions(self.wrapped_input,
                                            chord_map=chord_map)
    else:
        # Otherwise an unsupported input would only fail later, with an
        # obscure UnboundLocalError on observations
        raise TypeError("ngram-multi tagger cannot tag input of type "
                        "'%s'" % type(self.wrapped_input).__name__)

    # Use the ngram model to get a state probability distribution for each
    # timestep
    if self.options['decode'] == "forward":
        probabilities = self.model.forward_probabilities(observations)
    else:
        probabilities = self.model.forward_backward_probabilities(observations)

    for index, timestep in enumerate(probabilities):
        duration = self.durations[index]
        start_time = self.times[index]
        # Drop zero-probability states and order by descending probability
        ranked_states = reversed(sorted(
            [(state, prob) for (state, prob) in timestep.items()
                                                        if prob > 0.0],
            key=lambda x: x[1]))

        word_signs = []
        for (state, prob) in ranked_states:
            root, schema = state
            # Give each lookup its own feature dict, so that no two signs can
            # share (and see later changes to) the same dict
            features = {
                'duration' : duration,
                'time' : start_time,
                'root' : root,
            }
            signs = self.grammar.get_signs_for_tag(schema, features)
            # There should only be one of these. Skip states that the
            # grammar gives no sign for.
            if not signs:
                continue
            word_signs.append((signs[0], (root, schema), prob))

        self._tagged_times.append(word_signs)
        # Keep the tag probabilities, so the batch sizes can be worked out
        # once every word has been tagged
        word_tag_probs.append([p for __, __, p in word_signs])

    if self.options['best']:
        # Only return the top sign for each word. A timestep with no signs
        # gets an empty range: (0,1) would index past the end of its list.
        batch_ranges = [[(0, min(1, len(word_signs)))]
                        for word_signs in self._tagged_times]
    else:
        # Work out the number of tags to return in each batch. This local is
        # named so that it does not shadow the imported batch_sizes.
        word_batch_sizes = beamed_batch_sizes(word_tag_probs,
                                              self.batch_ratio,
                                              max_batch=self.options['max_batch'])
        # Turn each word's batch sizes into (start, end) offsets into its
        # sign list
        batch_ranges = [[(sum(batches[:i]), sum(batches[:i+1]))
                         for i in range(len(batches))]
                        for batches in word_batch_sizes]

    # Add the signs batch by batch. Repetitions of the same root+schema at
    # adjacent timesteps are combined into longer spans as they are added.
    def prob_combiner(probs):
        # The combined probability is the mean of the probabilities given
        return sum(probs, 0.0) / float(len(probs))
    combiner = SpanCombiner()
    added = True
    offset = 0
    while added:
        added = False
        batch_spans = []
        for time, time_ranges in enumerate(batch_ranges):
            if offset >= len(time_ranges):
                # This word has no batch at this offset
                continue
            start, end = time_ranges[offset]
            for sign_offset in range(start, end):
                sign, (root, schema), prob = \
                    self._tagged_times[time][sign_offset]
                added = True
                # Add the length-1 span
                batch_spans.append((time, time+1, (sign, (root, schema), prob)))
                # Check whether this combines with anything added earlier
                combined = combiner.combine_edge(
                                    (time, time+1, (root, schema)),
                                    properties=prob,
                                    prop_combiner=prob_combiner)
                # Add a sign for each new span the combiner reports
                for (span_start, span_end) in combined:
                    new_prob = combiner.edge_properties[
                                    (span_start, span_end, (root, schema))]
                    span_features = {
                        'duration' : sum(self.durations[span_start:span_end]),
                        'time' : self.times[span_start],
                        'root' : root,
                    }
                    # Technically there could be several of these, though
                    # in practice there never are
                    new_signs = \
                        self.grammar.get_signs_for_tag(schema, span_features)
                    for new_sign in new_signs:
                        batch_spans.append(
                            (span_start, span_end,
                                (new_sign, (root, schema), new_prob)))
        self._tagged_spans.append(batch_spans)
        offset += 1
396
398 if offset < len(self._tagged_spans):
399 return self._tagged_spans[offset]
400 else:
401 return []
402
404 return self.input[index]
405