1 """ Extensions to NLTK's probability module.
2
3 """
4 """
5 ============================== License ========================================
6 Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding
7
8 This file is part of The Jazz Parser.
9
10 The Jazz Parser is free software: you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation, either version 3 of the License, or
13 (at your option) any later version.
14
15 The Jazz Parser is distributed in the hope that it will be useful,
16 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 GNU General Public License for more details.
19
20 You should have received a copy of the GNU General Public License
21 along with The Jazz Parser. If not, see <http://www.gnu.org/licenses/>.
22
23 ============================ End license ======================================
24
25 """
26 __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"
27
28 import math
29 from nltk.probability import FreqDist, ConditionalFreqDist, \
30 MLEProbDist, FreqDist, ConditionalFreqDist, \
31 ConditionalProbDist, LaplaceProbDist, WittenBellProbDist, \
32 GoodTuringProbDist, add_logs as nltk_add_logs, \
33 DictionaryProbDist, DictionaryConditionalProbDist, \
34 MutableProbDist, SimpleGoodTuringProbDist
35 from .storage import ObjectStorer
36 from ..probabilities import random_selection
39 """
40 Returns the base 2 log of the given probability (or other float). If
41 prob == 0.0, returns -inf.
42
43 """
44 if prob == 0.0:
45 return float('-inf')
46 else:
47 return math.log(prob, 2)
48
50 """
51 Identical to NLTK's L{nltk.probability.add_logs}, but handles the special
52 case where one or both of the numbers is -inf. NLTK's version gives a nan
53 in this case.
54
55 """
56 if logx == float('-inf'):
57 return logy
58 elif logy == float('-inf'):
59 return logx
60 else:
61 return nltk_add_logs(logx, logy)
62
64 """
65 Identical to NLTK's L{nltk.probability.sum_logs}, but uses our version of
66 L{add_logs} and -inf for zero probs
67
68 """
69 if len(logs) == 0:
70 return float('-inf')
71 else:
72 return reduce(add_logs, logs[1:], logs[0])
73
75 """
76 Generates a sample chosen randomly from the observed samples of an NLTK
77 prob dist, weighted according to their probability.
78 NLTK provides this, but doesn't allow for the summed probability of the
79 observed samples not being 1.0. But, of course, this is the case when
80 we're smoothing.
81
82 """
83 samp_probs = [(samp, dist.prob(samp)) for samp in dist.samples()]
84 return random_selection(samp_probs, normalize=True)
85
87 """
88 Takes a probability distribution estimated in any way (e.g. from
89 a freq dist) and produces a corresponding dictionary prob dist
90 that just stores the probability of every sample.
91
92 Can be used to turn any kind of prob dist into a dictionary-based
93 one, including a MutableProbDist.
94
95 @type mutable: bool
96 @param mutable: if True, the returned dist is a mutable prob dist
97
98 """
99
100
101 if samples is None:
102 samples = dist.samples()
103
104 probs = {}
105 for sample in samples:
106 probs[sample] = dist.prob(sample)
107
108 dictpd = DictionaryProbDist(probs, normalize=True)
109
110 if mutable:
111
112 dictpd = MutableProbDist(dictpd, samples)
113 return dictpd
114
117 """
118 Takes a conditional probability distribution which may estimate
119 its probabilities in any way (most likely from a set of frequency
120 distributions) and produces an equivalent dictionary conditional
121 distribution, whose distributions are dictionary prob dists.
122
123 @type mutable: bool
124 @param mutable: if True, the returned dist contains mutable prob dists
125
126 """
127 dists = {}
128 if conditions is None:
129 conditions = dist.conditions()
130 for condition in conditions:
131 dists[condition] = prob_dist_to_dictionary_prob_dist(dist[condition], \
132 mutable=mutable, samples=samples)
133 return DictionaryConditionalProbDist(dists)
134
137 """
138 There's a nasty bug in WittenBellProbDist, but the fix is very simple.
139 Use this instead of WittenBellProbDist.
140
141 """
def __init__(self, freqdist, bins=None):
    """
    Set up a Witten-Bell estimate from the counts in C{freqdist}.

    Unlike the general Witten-Bell formula, the case where there are no
    unseen bins (C{bins == freqdist.B()}) is handled explicitly. This
    gives zero reserved mass instead of dividing by zero.

    @type freqdist: L{FreqDist}
    @param freqdist: the observed counts
    @type bins: int
    @param bins: total number of possible sample values. Defaults to
        C{freqdist.B()}, i.e. no unseen samples. Must not be less than
        C{freqdist.B()}.

    """
    # B() is only needed once; it was previously recomputed several times
    observed_bins = freqdist.B()
    assert bins is None or bins >= observed_bins, \
        'Bins parameter must not be less than freqdist.B()'
    if bins is None:
        bins = observed_bins
    self._freqdist = freqdist
    # T: number of observed sample types
    # Z: number of possible bins that were never observed
    # N: total count in the frequency distribution
    self._T = observed_bins
    self._Z = bins - observed_bins
    self._N = freqdist.N()

    if self._Z == 0:
        # No unseen bins, so no probability mass is reserved for them.
        # Handling this here also avoids the division by Z in the
        # general formula below.
        self._P0 = 0.0
    elif self._N == 0:
        # Nothing has been observed: share all the mass uniformly
        # between the Z unseen bins
        self._P0 = 1.0 / self._Z
    else:
        # Total mass T / (N + T) is reserved for unseen samples and
        # shared equally between the Z unseen bins
        self._P0 = self._T / float(self._Z * (self._N + self._T))
161
165 """ Decorator to add a name attribute to the estimator functions """
166 def _estimator_name(estimator):
167 estimator.estimator_name = name
168 return estimator
169 return _estimator_name
170
174
178
182
186
189 return SimpleGoodTuringProbDist(fdist, bins=bins)
190
192 if hasattr(estimator, 'estimator_name'):
193
194 return estimator.estimator_name
195 else:
196 return estimator.__name__
197
# Registry of the available estimators, keyed by a short string name.
# The values are estimator functions defined in this module (they are not
# among the imports at the top of the file).
ESTIMATORS = {
    'mle': mle_estimator,
    'laplace': laplace_estimator,
    'witten-bell': witten_bell_estimator,
    'good-turing': good_turing_estimator,
    'simple-good-turing': simple_good_turing_estimator,
}
207 """
208 Like FreqDist, but returns zero counts for everything with a count
209 less than a given cutoff. Also adjusts the total count to account
210 for the lost counts.
211
212 """
213 - def __init__(self, cutoff, *args, **kwargs):
216
218 val = self.raw_count(key)
219 if val <= self._cutoff:
220 return 0
221 else:
222 return val
223
225 """
226 Returns the raw counts (i.e. without the cutoff applied) as a
227 dictionary. This could, for example, be used as init data to
228 another FreqDist.
229
230 """
231 return dict(dict.items(self))
232
234 """
235 Returns the raw count of this sample (doesn't apply a cutoff).
236
237 """
238 return super(CutoffFreqDist, self).__getitem__(sample)
239
241 return self._N - self.lost_N()
242
244 """
245 This is slightly more complicated than the superclass, because
246 we want to count only samples that have non-zero counts after the
247 cutoff has been applied.
248
249 """
250 return len([count for count in self.values() if count > self._cutoff])
251
254
def freq(self, sample):
    """
    Return the relative frequency of C{sample} with the cutoff applied.

    The superclass divides by its internal C{_N} rather than calling
    C{N()}, so it would not reflect the counts removed by the cutoff.
    This version divides the cutoff-adjusted count C{self[sample]} by
    C{self.N()}.

    @return: 0 when C{self.N()} is zero, otherwise
        C{self[sample] / self.N()} as a float.

    """
    total = self.N()
    if total == 0:
        # Empty (or entirely cut-off) distribution: avoid dividing by zero
        return 0
    return float(self[sample]) / total
264
268
270 """
271 Returns a CutoffFreqDist like this one, but with counts from the
272 other added. The other may only be another CutoffFreqDist.
273
274 """
275 if not isinstance(other, CutoffFreqDist):
276 raise TypeError, "can only sum a CutoffFreqDist with "\
277 "another CutoffFreqDist, not %s" % type(other).__name__
278 clone = self.copy()
279 clone.update(other.raw_counts())
280 return clone
281
286
288 """ The number of counts lost by applying the cutoff """
289 if self._lost_N is None:
290
291 self._lost_N = 0
292 raw_counts = self.raw_counts()
293 for key in raw_counts:
294 if raw_counts[key] <= self._cutoff:
295
296 self._lost_N += raw_counts[key]
297 return self._lost_N
298
300 """ Make cutoff a read-only attribute """
301 return self._cutoff
302 cutoff = property(_get_cutoff)
303
305 """
306 Need to override this because dict.items(self) accesses the
307 non-cutoff values.
308 """
309 if not self._item_cache:
310 items = [(key,self[key]) for key in dict.keys(self)]
311
312 items = [(key,val) for (key,val) in items if val != 0]
313 self._item_cache = sorted(items, key=lambda x:(-x[1], x[0]))
314
316 """
317 A version of ConditionalFreqDist that uses a CutoffFreqDist for
318 each distribution instead of FreqDist.
319
320 """
321 - def __init__(self, cutoff, *args, **kwargs):
324
326 """ Make cutoff a read-only attribute """
327 return self._cutoff
328 cutoff = property(_get_cutoff)
329
331 """
332 Override this to use CutoffFreqDists instead of FreqDists.
333
334 """
335
336 if condition not in self._fdists:
337 self._fdists[condition] = CutoffFreqDist(self._cutoff)
338 return self._fdists[condition]
339
358
380