Package jazzparser :: Package utils :: Package nltk :: Module probability
[hide private]
[frames] | no frames]

Source Code for Module jazzparser.utils.nltk.probability

  1  """ Extensions to NLTK's probability module. 
  2   
  3  """ 
  4  """ 
  5  ============================== License ======================================== 
  6   Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding 
  7    
  8   This file is part of The Jazz Parser. 
  9    
 10   The Jazz Parser is free software: you can redistribute it and/or modify 
 11   it under the terms of the GNU General Public License as published by 
 12   the Free Software Foundation, either version 3 of the License, or 
 13   (at your option) any later version. 
 14    
 15   The Jazz Parser is distributed in the hope that it will be useful, 
 16   but WITHOUT ANY WARRANTY; without even the implied warranty of 
 17   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 18   GNU General Public License for more details. 
 19    
 20   You should have received a copy of the GNU General Public License 
 21   along with The Jazz Parser.  If not, see <http://www.gnu.org/licenses/>. 
 22   
 23  ============================ End license ====================================== 
 24   
 25  """ 
 26  __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"  
 27   
 28  import math 
 29  from nltk.probability import FreqDist, ConditionalFreqDist, \ 
 30                      MLEProbDist, FreqDist, ConditionalFreqDist, \ 
 31                      ConditionalProbDist, LaplaceProbDist, WittenBellProbDist, \ 
 32                      GoodTuringProbDist, add_logs as nltk_add_logs, \ 
 33                      DictionaryProbDist, DictionaryConditionalProbDist, \ 
 34                      MutableProbDist, SimpleGoodTuringProbDist 
 35  from .storage import ObjectStorer 
 36  from ..probabilities import random_selection 
def logprob(prob):
    """
    Converts a probability (or any other float) into its base 2 logarithm.

    A value equal to zero is mapped to -inf instead of being handed to
    C{math.log}.

    @type prob: float
    @param prob: the value to take the log of
    @rtype: float
    @return: log base 2 of C{prob}, or -inf if C{prob} is zero

    """
    # Handle the non-zero case first; zero falls through to the special value
    if prob != 0.0:
        return math.log(prob, 2)
    return float('-inf')
48
def add_logs(logx, logy):
    """
    Adds two log-domain values, as NLTK's L{nltk.probability.add_logs}
    does, but copes with either (or both) of them being -inf, a case in
    which NLTK's version produces a nan.

    @param logx: first log value
    @param logy: second log value
    @return: the other operand if one operand is -inf, otherwise the result
        of NLTK's C{add_logs}

    """
    neg_inf = float('-inf')
    # A -inf operand contributes nothing, so the answer is just the other one
    if logx == neg_inf:
        return logy
    if logy == neg_inf:
        return logx
    return nltk_add_logs(logx, logy)
62
def sum_logs(logs):
    """
    Sums a collection of log-domain values, like NLTK's
    L{nltk.probability.sum_logs}, but using our version of L{add_logs} and
    -inf to represent a zero probability.

    @param logs: iterable of log values; any iterable is accepted, not only
        sequences supporting C{len()} and slicing
    @return: the log-domain sum of the values, or -inf if C{logs} is empty

    """
    # reduce is only a builtin on Python 2; functools provides it on both
    # Python 2.6+ and Python 3
    from functools import reduce
    # L{add_logs} returns its second argument when the first is -inf, so
    # seeding the fold with -inf gives the same result as starting from the
    # first element, handles the empty case and avoids copying the input
    return reduce(add_logs, logs, float('-inf'))
73
def generate_from_prob_dist(dist):
    """
    Picks one of the observed samples of an NLTK prob dist at random,
    weighting each sample by the probability the dist assigns to it.

    NLTK has its own way of doing this, but it does not cope with the
    probabilities of the observed samples summing to something other than
    1.0, which is exactly what happens when the distribution is smoothed.
    The weights are therefore normalized during selection.

    @param dist: distribution providing C{samples()} and C{prob()}
    @return: the result of L{random_selection} over the weighted samples

    """
    weighted_samples = []
    for sample in dist.samples():
        weighted_samples.append((sample, dist.prob(sample)))
    return random_selection(weighted_samples, normalize=True)
85
def prob_dist_to_dictionary_prob_dist(dist, mutable=False, samples=None):
    """
    Builds a dictionary prob dist that simply stores, for every sample,
    the probability given to it by C{dist}. C{dist} may have been estimated
    in any way (e.g. from a freq dist).

    This can turn any kind of prob dist into a dictionary-based one,
    including a MutableProbDist.

    @param dist: the distribution to read probabilities from
    @type mutable: bool
    @param mutable: if True, the returned dist is a mutable prob dist
    @param samples: the samples to store probabilities for. If not given,
        the samples reported by C{dist} are used. Supplying them is useful,
        for example, when there are samples not represented in C{dist}
    @return: a DictionaryProbDist, or a MutableProbDist built from one if
        C{mutable} is True

    """
    if samples is None:
        # Default to whatever samples the original dist knows about
        samples = dist.samples()

    stored_probs = dict((sample, dist.prob(sample)) for sample in samples)
    # These ought to sum to one already, but normalize to be safe
    result = DictionaryProbDist(stored_probs, normalize=True)

    if not mutable:
        return result
    # Wrap the dictionary dist in a mutable one
    return MutableProbDist(result, samples)
114
def cond_prob_dist_to_dictionary_cond_prob_dist(dist, mutable=False,
                                                samples=None, conditions=None):
    """
    Builds a dictionary conditional prob dist equivalent to C{dist}, a
    conditional probability distribution that may estimate its
    probabilities in any way (most likely from a set of frequency
    distributions). Each conditioned distribution of the result is a
    dictionary prob dist produced by L{prob_dist_to_dictionary_prob_dist}.

    @param dist: the conditional distribution to convert
    @type mutable: bool
    @param mutable: if True, the returned dist contains mutable prob dists
    @param samples: passed on unchanged as the C{samples} argument for
        every conditioned distribution
    @param conditions: the conditions to convert. If not given, the
        conditions reported by C{dist} are used
    @return: a DictionaryConditionalProbDist over the converted dists

    """
    if conditions is None:
        conditions = dist.conditions()
    converted = dict(
        (cond, prob_dist_to_dictionary_prob_dist(
            dist[cond], mutable=mutable, samples=samples))
        for cond in conditions)
    return DictionaryConditionalProbDist(converted)
134
class WittenBellProbDistFix(WittenBellProbDist):
    """
    There's a nasty bug in WittenBellProbDist, but the fix is very simple.
    Use this instead of WittenBellProbDist.

    The constructor below replaces the superclass' one entirely (it is not
    called) and sets up the precomputed attributes itself.

    """
    def __init__(self, freqdist, bins=None):
        """
        @param freqdist: frequency distribution the estimate is based on
        @param bins: number of possible sample values. Must not be less
            than C{freqdist.B()}; defaults to C{freqdist.B()} if not given
        @raise AssertionError: if C{bins} is less than C{freqdist.B()}

        """
        assert bins is None or bins >= freqdist.B(),\
               'Bins parameter must not be less than freqdist.B()'
        if bins is None:
            bins = freqdist.B()
        self._freqdist = freqdist
        # T: number of bins with counts; Z: number of bins without counts;
        # N: total count
        self._T = self._freqdist.B()
        self._Z = bins - self._freqdist.B()
        self._N = self._freqdist.N()
        # self._P0 is P(0), precalculated for efficiency:
        if self._Z == 0:
            # No unseen events: probability of anything we have no
            # counts for is 0.
            # Both formulas below divide by Z, so this case has to be
            # handled first (presumably this is the bug being fixed)
            self._P0 = 0.0
        elif self._N == 0:
            # if freqdist is empty, we approximate P(0) by a UniformProbDist:
            self._P0 = 1.0 / self._Z
        else:
            self._P0 = self._T / float(self._Z * (self._N + self._T))
161
def estimator_name(name):
    """
    Decorator factory that labels an estimator function with a name.

    @param name: value to store in the function's C{estimator_name}
        attribute
    @return: a decorator that sets the attribute and hands back the very
        same function (it is not wrapped)

    """
    def _estimator_name(func):
        setattr(func, 'estimator_name', name)
        return func
    return _estimator_name
@estimator_name('mle')
def mle_estimator(fdist, bins):
    """
    Estimator producing an MLEProbDist from a frequency distribution.

    @param fdist: the frequency distribution to estimate from
    @param bins: not used; accepted so that this has the same signature as
        the other estimators in this module
    @return: an MLEProbDist over C{fdist}

    """
    estimated = MLEProbDist(fdist)
    return estimated
174
@estimator_name('laplace')
def laplace_estimator(fdist, bins):
    """
    Estimator producing a LaplaceProbDist from a frequency distribution.

    @param fdist: the frequency distribution to estimate from
    @param bins: passed on as the C{bins} argument of LaplaceProbDist
    @return: a LaplaceProbDist over C{fdist}

    """
    estimated = LaplaceProbDist(fdist, bins=bins)
    return estimated
178
@estimator_name('witten-bell')
def witten_bell_estimator(fdist, bins):
    """
    Estimator producing a Witten-Bell smoothed distribution from a
    frequency distribution. Uses L{WittenBellProbDistFix} rather than
    NLTK's own WittenBellProbDist.

    @param fdist: the frequency distribution to estimate from
    @param bins: passed on as the C{bins} argument of the distribution
    @return: a L{WittenBellProbDistFix} over C{fdist}

    """
    estimated = WittenBellProbDistFix(fdist, bins=bins)
    return estimated
182
# The readable name is spelt with a hyphen so that it is identical to the
# key this estimator is registered under in ESTIMATORS, as is the case for
# the mle, laplace and witten-bell estimators
@estimator_name('good-turing')
def good_turing_estimator(fdist, bins):
    """
    Estimator producing a GoodTuringProbDist from a frequency distribution.

    @param fdist: the frequency distribution to estimate from
    @param bins: passed on as the C{bins} argument of GoodTuringProbDist
    @return: a GoodTuringProbDist over C{fdist}

    """
    return GoodTuringProbDist(fdist, bins=bins)
186
# Give this estimator its own readable name, identical to the key it is
# registered under in ESTIMATORS, so that it cannot be confused with the
# plain Good-Turing estimator
@estimator_name('simple-good-turing')
def simple_good_turing_estimator(fdist, bins):
    """
    Estimator producing a SimpleGoodTuringProbDist from a frequency
    distribution.

    @param fdist: the frequency distribution to estimate from
    @param bins: passed on as the C{bins} argument of
        SimpleGoodTuringProbDist
    @return: a SimpleGoodTuringProbDist over C{fdist}

    """
    return SimpleGoodTuringProbDist(fdist, bins=bins)
190
def get_estimator_name(estimator):
    """
    Gets a name for an estimator function.

    @param estimator: the estimator function
    @return: the function's C{estimator_name} attribute if it has one,
        otherwise its C{__name__}

    """
    # A private marker distinguishes "no attribute" from any stored value
    missing = object()
    readable = getattr(estimator, 'estimator_name', missing)
    if readable is missing:
        return estimator.__name__
    # Use the readable name, since one is available
    return readable
# Registry of the estimator functions defined above, keyed by the name
# each one is selected by
ESTIMATORS = {
    'mle': mle_estimator,
    'laplace': laplace_estimator,
    'witten-bell': witten_bell_estimator,
    'good-turing': good_turing_estimator,
    'simple-good-turing': simple_good_turing_estimator,
}
class CutoffFreqDist(FreqDist):
    """
    Like FreqDist, but returns zero counts for everything with a count
    less than or equal to a given cutoff. Also adjusts the total count to
    account for the lost counts.

    The underlying dictionary always holds the raw (uncut) counts; the
    cutoff is applied whenever counts are read through this class.

    """
    def __init__(self, cutoff, *args, **kwargs):
        """
        @param cutoff: counts less than or equal to this are reported as 0
        @param args: passed on to the FreqDist constructor
        @param kwargs: passed on to the FreqDist constructor

        """
        # Store the cutoff before initializing the superclass, so that it
        # is already available to the methods overridden below
        self._cutoff = cutoff
        super(CutoffFreqDist, self).__init__(*args, **kwargs)

    def __getitem__(self, key):
        """
        Returns the count of the sample, or 0 if its raw count does not
        exceed the cutoff.

        """
        val = self.raw_count(key)
        if val <= self._cutoff:
            return 0
        return val

    def raw_counts(self):
        """
        Returns the raw counts (i.e. without the cutoff applied) as a
        dictionary. This could, for example, be used as init data to
        another FreqDist.

        """
        # Go through dict directly to bypass the overridden accessors
        return dict(dict.items(self))

    def raw_count(self, sample):
        """
        Returns the raw count of this sample (doesn't apply a cutoff).

        """
        return super(CutoffFreqDist, self).__getitem__(sample)

    def N(self):
        """
        Returns the total count, excluding the counts lost to the cutoff.

        """
        return self._N - self.lost_N()

    def B(self):
        """
        This is slightly more complicated than the superclass, because
        we want to count only samples that have non-zero counts after the
        cutoff has been applied.

        """
        return sum(1 for count in self.values() if count > self._cutoff)

    def __len__(self):
        """ Same as L{B}. """
        return self.B()

    def freq(self, sample):
        """
        Have to override this because the superclass doesn't use N(),
        but the internal _N to calculate the frequency.

        @return: the cut-off count of the sample divided by L{N}, or 0 if
            L{N} is 0

        """
        total = self.N()
        if total == 0:
            return 0
        return float(self[sample]) / total

    def copy(self):
        """
        Returns a new CutoffFreqDist with the same cutoff and raw counts.

        """
        # Don't use our samples, but the ones without the cutoff applied
        return CutoffFreqDist(self._cutoff, self.raw_counts())

    def __add__(self, other):
        """
        Returns a CutoffFreqDist like this one, but with counts from the
        other added. The other may only be another CutoffFreqDist.

        @raise TypeError: if C{other} is not a CutoffFreqDist

        """
        if not isinstance(other, CutoffFreqDist):
            # Call form of raise: valid on both Python 2 and Python 3
            raise TypeError("can only sum a CutoffFreqDist with "
                            "another CutoffFreqDist, not %s"
                            % type(other).__name__)
        clone = self.copy()
        # Add the raw counts: the clone applies its own cutoff on reading
        clone.update(other.raw_counts())
        return clone

    def _reset_caches(self):
        """ Add our own caches to the superclass' """
        self._lost_N = None
        super(CutoffFreqDist, self)._reset_caches()

    def lost_N(self):
        """ The number of counts lost by applying the cutoff """
        if self._lost_N is None:
            # Recompute the cached value: everything at or below the cutoff
            # would have been counted but for the cutoff
            self._lost_N = sum(
                count for count in self.raw_counts().values()
                if count <= self._cutoff)
        return self._lost_N

    def _get_cutoff(self):
        """ Make cutoff a read-only attribute """
        return self._cutoff
    cutoff = property(_get_cutoff)

    def _sort_keys_by_value(self):
        """
        Need to override this because dict.items(self) accesses the
        non-cutoff values.

        Fills C{self._item_cache}, if it is empty, with (sample, count)
        pairs using the cut-off counts, leaving out samples whose count is
        0 and ordering by decreasing count, then by sample.

        """
        if not self._item_cache:
            items = [(key, self[key]) for key in dict.keys(self)]
            # Eliminate 0 counts
            items = [(key, val) for (key, val) in items if val != 0]
            self._item_cache = sorted(items, key=lambda x: (-x[1], x[0]))
314
class CutoffConditionalFreqDist(ConditionalFreqDist):
    """
    A ConditionalFreqDist whose individual distributions are
    L{CutoffFreqDist}s, all sharing one cutoff, rather than plain
    FreqDists.

    """
    def __init__(self, cutoff, *args, **kwargs):
        """
        @param cutoff: cutoff given to every conditioned distribution
        @param args: passed on to the ConditionalFreqDist constructor
        @param kwargs: passed on to the ConditionalFreqDist constructor

        """
        # The cutoff is stored first, as __getitem__ below depends on it
        self._cutoff = cutoff
        super(CutoffConditionalFreqDist, self).__init__(*args, **kwargs)

    def _get_cutoff(self):
        """ Getter behind the read-only cutoff attribute """
        return self._cutoff
    cutoff = property(fget=_get_cutoff)

    def __getitem__(self, condition):
        """
        Returns the distribution for the condition. Overridden so that the
        distributions created on demand are CutoffFreqDists.

        """
        try:
            return self._fdists[condition]
        except KeyError:
            # No conditioned freq dist yet for this condition: create one
            fresh = CutoffFreqDist(self._cutoff)
            self._fdists[condition] = fresh
            return fresh
339
########################## Storers (see .storage) ######################
class CutoffFreqDistStorer(ObjectStorer):
    """
    Storer converting L{CutoffFreqDist}s to and from dictionaries. Most of
    the work is delegated to FreqDistStorer; only the cutoff is handled
    here.

    """
    STORED_CLASS = CutoffFreqDist

    @staticmethod
    def _object_to_dict(obj):
        """
        Returns FreqDistStorer's dictionary for the dist, with the dist's
        cutoff added under the key 'cutoff'.

        """
        from .storage import FreqDistStorer
        # FreqDistStorer does most of the job for us
        stored = FreqDistStorer._object_to_dict(obj)
        # Only the cutoff needs adding
        stored['cutoff'] = obj.cutoff
        return stored

    @staticmethod
    def _dict_to_object(dic):
        """
        Rebuilds a L{CutoffFreqDist} from a stored dictionary. The
        'cutoff' entry is removed from C{dic} and the rest is handed to
        FreqDistStorer, together with an empty dist to start from.

        """
        from .storage import FreqDistStorer
        cutoff = dic.pop('cutoff')
        empty_dist = CutoffFreqDist(cutoff)
        return FreqDistStorer._dict_to_object(dic, start_dist=empty_dist)
358
class CutoffConditionalFreqDistStorer(ObjectStorer):
    """
    Storer converting L{CutoffConditionalFreqDist}s to and from
    dictionaries.

    """
    STORED_CLASS = CutoffConditionalFreqDist

    @staticmethod
    def _object_to_dict(obj):
        """
        Returns ConditionalFreqDistStorer's dictionary for the dist, with
        the cutoff added under the key 'cutoff'.

        """
        from .storage import ConditionalFreqDistStorer
        stored = ConditionalFreqDistStorer._object_to_dict(obj)
        # The cutoff value goes in alongside the superclass' data
        stored['cutoff'] = obj._cutoff
        return stored

    @staticmethod
    def _dict_to_object(dic):
        """
        Rebuilds a L{CutoffConditionalFreqDist} from a stored dictionary,
        reading the cutoff from 'cutoff' and restoring each stored dist
        under 'fdists' with C{dict_to_object}.

        """
        from .storage import dict_to_object
        # It would be nicer to make better use of the superclass' storage,
        # but for now it is simpler to duplicate the logic here
        restored = CutoffConditionalFreqDist(dic['cutoff'])
        fdists = {}
        for condition, stored_dist in dic['fdists'].items():
            fdists[condition] = dict_to_object(stored_dist)
        restored._fdists = fdists
        return restored
380