Package jazzparser :: Package taggers :: Package baseline3 :: Module tagger
[hide private]
[frames] | no frames]

Source Code for Module jazzparser.taggers.baseline3.tagger

  1  """Third, very simple baseline tagger model. 
  2   
  3  Tagging model 'baseline3' is another very simple tagging model that tags  
  4  using just the unigram probabilities on the basis of observed chord  
  5  intervals and chord types. 
  6   
  7  It is the model presented as 'model 5' in the Stupid Baselines talk. 
  8   
  9  """ 
 10  """ 
 11  ============================== License ======================================== 
 12   Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding 
 13    
 14   This file is part of The Jazz Parser. 
 15    
 16   The Jazz Parser is free software: you can redistribute it and/or modify 
 17   it under the terms of the GNU General Public License as published by 
 18   the Free Software Foundation, either version 3 of the License, or 
 19   (at your option) any later version. 
 20    
 21   The Jazz Parser is distributed in the hope that it will be useful, 
 22   but WITHOUT ANY WARRANTY; without even the implied warranty of 
 23   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 24   GNU General Public License for more details. 
 25    
 26   You should have received a copy of the GNU General Public License 
 27   along with The Jazz Parser.  If not, see <http://www.gnu.org/licenses/>. 
 28   
 29  ============================ End license ====================================== 
 30   
 31  """ 
 32  __author__ = "Mark Granroth-Wilding <mark@granroth-wilding.co.uk>"  
 33   
 34  import pickle 
 35  from jazzparser.taggers.models import ModelTagger, ModelLoadError, TaggerModel 
 36  from jazzparser.taggers import process_chord_input 
 37  from jazzparser.utils.probabilities import batch_sizes 
 38  from jazzparser.data import Chord 
 39  from jazzparser.utils.base import group_pairs 
 40   
def observation_from_chord_pair(crd1, crd2):
    """
    Builds the observation string that the baseline3 model conditions on
    for a chord and the chord that follows it.

    The result has the form C{"<interval>-<type>"}. The interval comes from
    C{Chord.interval} applied to both chords, each re-parsed from its string
    form. The type is the chord type of C{crd1}. When C{crd2} is C{None}
    (presumably the final chord of a sequence, as produced by
    C{group_pairs(..., none_final=True)}), the interval is taken to be 0.

    @param crd1: the current chord; a C{Chord} or anything whose C{str()}
        can be parsed by C{Chord.from_name}
    @param crd2: the following chord, in the same forms, or C{None}
    @return: observation string, e.g. C{"5-7"} style (interval, then type)
    """
    # Interval first, so any parse error surfaces in the same order as before
    interval = 0 if crd2 is None else Chord.interval(
        Chord.from_name(str(crd1)),
        Chord.from_name(str(crd2)),
    )
    # Only parse crd1 again if it is not already a Chord; we need its type
    chord = crd1 if isinstance(crd1, Chord) else Chord.from_name(str(crd1))
    return "%d-%s" % (interval, chord.type)
49
class Baseline3Model(TaggerModel):
    """
    Model data for the baseline3 tagger.

    Stores three sets of training counts:
     - how often each category was seen;
     - how often each interval/chord-type observation was seen;
     - how often the two were seen together.

    These counts are used to estimate the unigram probability of a
    category given an observation.
    """
    MODEL_TYPE = "baseline3"

    def __init__(self, model_name, *args, **kwargs):
        super(Baseline3Model, self).__init__(model_name, *args, **kwargs)
        # category -> {observation -> joint count}
        self.category_chord_count = {}
        # category -> number of times the category was seen
        self.category_count = {}
        # observation -> number of times the observation was seen
        self.chord_count = {}

    def _add_category_chord_count(self, category, chord):
        """
        Records one joint occurrence of C{category} with the observation
        C{chord}.

        Increments the joint count, the category count and the observation
        count by one each.
        """
        joint = self.category_chord_count.setdefault(category, {})
        joint[chord] = joint.get(chord, 0) + 1
        self.category_count[category] = self.category_count.get(category, 0) + 1
        self.chord_count[chord] = self.chord_count.get(chord, 0) + 1

    def train(self, sequences, grammar=None, logger=None):
        """
        Accumulates counts from a corpus of chord sequences.

        Each sequence's chords are paired with their successors using
        C{group_pairs(..., none_final=True)}. Every pair contributes one
        count of the first chord's C{category} with the observation built
        by C{observation_from_chord_pair}. Afterwards, C{model_description}
        is set to a short summary of the training data.

        @param sequences: iterable of sequences exposing C{iterator()}
        @param grammar: not used by this model
        @param logger: not used by this model
        """
        num_sequences = 0
        num_samples = 0
        for sequence in sequences:
            num_sequences += 1
            for chord, next_chord in group_pairs(sequence.iterator(), none_final=True):
                num_samples += 1
                category = chord.category
                observation = observation_from_chord_pair(chord, next_chord)
                self._add_category_chord_count(category, observation)
        # Add a bit of training info to the descriptive text
        self.model_description = """\
Unigram probability model of combined observations of interval and chord type

Training sequences: %(seqs)d
Training samples: %(samples)d""" % {
            'seqs' : num_sequences,
            'samples' : num_samples,
        }

    def get_prob_cat_given_chord_pair(self, cat, chord1, chord2):
        """
        Estimates P(cat | observation of chord1 followed by chord2).

        For an observation seen in training, returns the joint count divided
        by the observation count. For an unseen observation:
         - a category seen in training gets a uniform share,
           1 / (number of seen categories);
         - an unseen category gets 0.0 (no smoothing).

        @return: float probability
        """
        obs = observation_from_chord_pair(chord1, chord2)
        obs_total = self.chord_count.get(obs, 0)
        if obs_total:
            joint = self.category_chord_count.get(cat, {}).get(obs, 0)
            return float(joint) / obs_total
        # Unseen observation: no smoothing for categories never seen
        if cat not in self.category_count:
            return 0.0
        return 1.0 / len(self.category_count)
116
class Baseline3Tagger(ModelTagger):
    """
    The third of the simple baseline tagger models.

    This models unigram probabilities of tags given an observation. The
    observation combines the interval from each chord to the next with the
    chord's type (see C{observation_from_chord_pair}).
    """
    MODEL_CLASS = Baseline3Model
    INPUT_TYPES = ['db', 'chords']

    def __init__(self, grammar, input, options=None, *args, **kwargs):
        """
        Tags the whole input when the tagger is constructed.

        For each input chord, a sign is requested from the grammar for every
        category the model has counts for. Each sign that exists is scored
        with the model's probability. The signs are stored most-probable
        first, together with the batch (start, end) ranges used by
        C{get_signs_for_word}.

        @param options: tagger options dict. C{None} (the default) is
            replaced by a fresh empty dict, so that separate taggers never
            share one mutable default.
        """
        if options is None:
            options = {}
        super(Baseline3Tagger, self).__init__(grammar, input, options, *args, **kwargs)
        # NOTE(review): presumably sets self.durations / self.times used below
        process_chord_input(self)

        #### Tag the input sequence ####
        self._tagged_data = []
        self._batch_ranges = []
        # Pair each chord with its successor (last chord paired with None)
        inpairs = group_pairs(self.input, none_final=True)
        for index, pair in enumerate(inpairs):
            features = {
                'duration' : self.durations[index],
                'time' : self.times[index],
            }
            word_signs = []
            # Score every category the model has seen that the grammar can
            # build a sign for on this word
            for tag in self.model.category_count:
                sign = self.grammar.get_sign_for_word_by_tag(
                    self.input[index], tag, extra_features=features)
                if sign is not None:
                    probability = self.model.get_prob_cat_given_chord_pair(tag, *pair)
                    word_signs.append((sign, tag, probability))
            # Most probable first. A stable ascending sort followed by a
            # reverse deliberately reproduces the established ordering of
            # equal-probability entries (reverse of insertion order), which
            # sort(reverse=True) would not.
            word_signs.sort(key=lambda entry: entry[2])
            word_signs.reverse()
            self._tagged_data.append(word_signs)

            # Work out the sizes of the batches to return these in
            batches = batch_sizes([p for __, __, p in word_signs], self.batch_ratio)
            # Turn batch sizes into (start, end) slices of word_signs
            batch_ranges = []
            so_far = 0
            for batch in batches:
                batch_ranges.append((so_far, so_far + batch))
                so_far += batch
            self._batch_ranges.append(batch_ranges)

    def get_signs_for_word(self, index, offset=0):
        """
        Returns one batch of (sign, tag, probability) tuples for a word.

        When C{best_only} is set, only the single most probable entry is
        returned, and only for C{offset} 0.

        @param index: position of the word in the input
        @param offset: which batch to return, counting from 0
        @return: list of (sign, tag, probability) tuples, or C{None} when no
            batches remain
        """
        if self.best_only:
            # Only ever return one sign
            if offset == 0 and self._tagged_data[index]:
                return [self._tagged_data[index][0]]
            return None
        ranges = self._batch_ranges[index]
        if offset >= len(ranges):
            # No more batches left
            return None
        start, end = ranges[offset]
        return self._tagged_data[index][start:end]

    def get_word(self, index):
        """Returns the input item at position C{index}."""
        return self.input[index]
177