Package jazzparser :: Package taggers :: Package baseline2 :: Module tagger
[hide private]
[frames] | [no frames]

Source Code for Module jazzparser.taggers.baseline2.tagger

  1  """Second, very simple baseline tagger model. 
  2   
  3  Tagging model 'baseline2' is another very simple tagging model that tags  
  4  using just the unigram probabilities on the basis of observed chord  
  5  intervals (no types). 
  6   
  7  It is the model presented as 'model 4' in the Stupid Baselines talk. 
  8   
  9  """ 
 10  """ 
 11  ============================== License ======================================== 
 12   Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding 
 13    
 14   This file is part of The Jazz Parser. 
 15    
 16   The Jazz Parser is free software: you can redistribute it and/or modify 
 17   it under the terms of the GNU General Public License as published by 
 18   the Free Software Foundation, either version 3 of the License, or 
 19   (at your option) any later version. 
 20    
 21   The Jazz Parser is distributed in the hope that it will be useful, 
 22   but WITHOUT ANY WARRANTY; without even the implied warranty of 
 23   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 24   GNU General Public License for more details. 
 25    
 26   You should have received a copy of the GNU General Public License 
 27   along with The Jazz Parser.  If not, see <http://www.gnu.org/licenses/>. 
 28   
 29  ============================ End license ====================================== 
 30   
 31  """ 
 32  __author__ = "Mark Granroth-Wilding <mark@granroth-wilding.co.uk>"  
 33   
 34  import pickle 
 35  from jazzparser.taggers.models import ModelTagger, ModelLoadError, TaggerModel 
 36  from jazzparser.taggers import process_chord_input 
 37  from jazzparser.utils.probabilities import batch_sizes 
 38  from jazzparser.data import Chord 
 39  from jazzparser.utils.base import group_pairs 
 40   
def observation_from_chord_pair(crd1, crd2):
    """
    Build the observation string the baseline2 model uses for a pair of
    consecutive chords.

    When both chords are present, each is round-tripped through its
    string name with C{Chord.from_name}, and the result of
    C{Chord.interval} between them is rendered with C{"%d"}.

    When either chord is C{None}, the observation is the literal string
    C{"0"}. This is presumably the padding partner that
    C{group_pairs(..., none_final=True)} supplies for the last chord.

    @note: presumably a genuine zero interval also renders as C{"0"}, so
        the missing-partner case would be indistinguishable from it
        (TODO confirm against C{Chord.interval}). It is left as-is
        because trained models store these strings as keys.
    @return: observation string
    """
    if crd1 is not None and crd2 is not None:
        first = Chord.from_name(str(crd1))
        second = Chord.from_name(str(crd2))
        return "%d" % Chord.interval(first, second)
    return "0"

class Baseline2Model(TaggerModel):
    """
    A class to encapsulate the model data for the tagger.

    The model holds raw counts in three plain dicts:
     - C{category_chord_count}: category -> {observation -> joint count}
     - C{category_count}: category -> count
     - C{chord_count}: observation -> count

    Observations are the strings produced by C{observation_from_chord_pair}.
    """
    MODEL_TYPE = "baseline2"

    def __init__(self, model_name, *args, **kwargs):
        super(Baseline2Model, self).__init__(model_name, *args, **kwargs)
        # category -> {observation -> joint count}
        self.category_chord_count = {}
        # category -> number of training samples with that category
        self.category_count = {}
        # observation -> number of training samples with that observation
        self.chord_count = {}

    def _add_category_chord_count(self, category, chord):
        """
        Adds a count of the joint observation of the category and the
        chord, and of the category and the chord individually.

        @param category: the category (tag) of the training sample
        @param chord: the observation string of the training sample
        """
        # Count the cat-chord combo
        cat_chords = self.category_chord_count.setdefault(category, {})
        cat_chords[chord] = cat_chords.get(chord, 0) + 1
        # Count the cat occurrence
        self.category_count[category] = self.category_count.get(category, 0) + 1
        # Count the chord occurrence
        self.chord_count[chord] = self.chord_count.get(chord, 0) + 1

    def train(self, sequences, grammar=None, logger=None):
        """
        Accumulate counts from a training corpus.

        Each item of each sequence is paired with its successor via
        C{group_pairs(..., none_final=True)}. The first item's C{category}
        is counted against the pair's observation. Afterwards
        C{model_description} is set to a short summary of the training
        data.

        @param sequences: iterable of sequences, each providing
            C{iterator()} over items that carry a C{category} attribute
        @param grammar: unused here; accepted for interface compatibility
        @param logger: unused here; accepted for interface compatibility
        """
        seqs = 0
        chords = 0
        # Each sequence in the given corpus
        for seq in sequences:
            seqs += 1
            # Each chord in the sequence, paired with the one that follows
            for c1, c2 in group_pairs(seq.iterator(), none_final=True):
                chords += 1
                self._add_category_chord_count(
                    c1.category, observation_from_chord_pair(c1, c2))
        # Add a bit of training info to the descriptive text
        self.model_description = """\
Unigram probability model, observing only root intervals

Training sequences: %(seqs)d
Training samples: %(samples)d""" % {
            'seqs': seqs,
            'samples': chords,
        }

    def get_prob_cat_given_chord_pair(self, cat, chord1, chord2):
        """
        Relative-frequency estimate of P(cat | observation), where the
        observation is derived from the chord pair.

        If the observation never occurred in training, each category seen
        in training gets 1 / (number of seen categories). A category never
        seen in training gets 0.0 (no smoothing).

        @return: float probability
        """
        obs = observation_from_chord_pair(chord1, chord2)
        chord_count = self.chord_count.get(obs, 0)
        if chord_count == 0:
            # Unseen data: give all seen cats equal probability
            if cat in self.category_count:
                return 1.0 / len(self.category_count)
            # Haven't seen the category before: don't smooth
            return 0.0
        count = self.category_chord_count.get(cat, {}).get(obs, 0)
        return float(count) / chord_count

class Baseline2Tagger(ModelTagger):
    """
    The second of the simple baseline tagger models. This models unigram
    probabilities of tags, given only the intervals between chords.

    """
    MODEL_CLASS = Baseline2Model
    INPUT_TYPES = ['db', 'chords']

    def __init__(self, grammar, input, options=None, *args, **kwargs):
        """
        Tag the whole input up front.

        For every word, each tag the model has seen is scored. Tags for
        which the grammar yields no sign are skipped. The surviving
        (sign, tag, probability) triples are stored in descending
        probability order, together with the batch ranges used by
        L{get_signs_for_word}.

        @param options: tagger options dict. C{None} (the default) means
            a fresh empty dict; this replaces the former shared mutable
            C{{}} default.
        """
        if options is None:
            options = {}
        super(Baseline2Tagger, self).__init__(grammar, input, options, *args, **kwargs)
        # NOTE(review): presumably this populates self.durations and
        # self.times, which are read below — confirm.
        process_chord_input(self)

        #### Tag the input sequence ####
        self._tagged_data = []
        self._batch_ranges = []
        # Pair each chord with its successor (None after the last one)
        inpairs = group_pairs(self.input, none_final=True)
        for index, pair in enumerate(inpairs):
            features = {
                'duration': self.durations[index],
                'time': self.times[index],
            }
            word_signs = []
            # Assign a probability to each tag the grammar can realise here
            for tag in self.model.category_count:
                sign = self.grammar.get_sign_for_word_by_tag(
                    self.input[index], tag, extra_features=features)
                if sign is not None:
                    probability = self.model.get_prob_cat_given_chord_pair(tag, *pair)
                    word_signs.append((sign, tag, probability))
            # Stable ascending sort then reverse: highest probability first.
            # Ties end up in reverse tag-iteration order, matching the
            # previous behaviour.
            word_signs = list(reversed(sorted(word_signs, key=lambda x: x[2])))
            self._tagged_data.append(word_signs)

            # Work out the sizes of the batches to return these in
            batches = batch_sizes([prob for _sign, _tag, prob in word_signs],
                                  self.batch_ratio)
            # Turn the sizes into (start, end) slice bounds
            batch_ranges = []
            start = 0
            for size in batches:
                batch_ranges.append((start, start + size))
                start += size
            self._batch_ranges.append(batch_ranges)

    def get_signs_for_word(self, index, offset=0):
        """
        Return the (sign, tag, probability) triples in batch C{offset} for
        word C{index}, or C{None} when no such batch exists.

        In C{best_only} mode, only offset 0 returns anything: a list holding
        the single highest-probability triple, or C{None} if there is none.
        """
        if self.best_only:
            # Only ever return one sign
            if offset == 0 and len(self._tagged_data[index]) > 0:
                return [self._tagged_data[index][0]]
            return None
        ranges = self._batch_ranges[index]
        if offset >= len(ranges):
            # No more batches left
            return None
        start, end = ranges[offset]
        return self._tagged_data[index][start:end]

    def get_word(self, index):
        """Return the input item at position C{index}."""
        return self.input[index]
