1 """Generic HMM model implementation, using NLTK's probability handling.
2
3 This is similar to L{jazzparser.utils.nltk.ngram.NgramModel}, but is
4 specialized to HMMs (bigram models) and stores probability distributions
5 as dictionaries instead of estimating them from counts. It may be trained
6 from counts in a corpus, but these are thrown away once the model is
7 estimated.
8
This type of model may be used in Baum-Welch re-estimation: since the
probabilities are stored directly rather than estimated from counts,
they can be updated.
11 Baum-Welch training for this model type (and its subclasses) can be found
12 in L{jazzparser.utils.nltk.ngram.baumwelch}.
13
14 """
15 """
16 ============================== License ========================================
17 Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding
18
19 This file is part of The Jazz Parser.
20
21 The Jazz Parser is free software: you can redistribute it and/or modify
22 it under the terms of the GNU General Public License as published by
23 the Free Software Foundation, either version 3 of the License, or
24 (at your option) any later version.
25
26 The Jazz Parser is distributed in the hope that it will be useful,
27 but WITHOUT ANY WARRANTY; without even the implied warranty of
28 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
29 GNU General Public License for more details.
30
31 You should have received a copy of the GNU General Public License
32 along with The Jazz Parser. If not, see <http://www.gnu.org/licenses/>.
33
34 ============================ End license ======================================
35
36 """
37 __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"
38
39 import numpy
40 from numpy import sum as array_sum
41
42 from jazzparser.utils.nltk.probability import logprob, \
43 cond_prob_dist_to_dictionary_cond_prob_dist
44 from nltk.probability import sum_logs, DictionaryProbDist
45 from .model import NgramModel
48 """
49 Like an NgramModel, but (a) restricted to be an HMM (order 2 and no
50 backoff) and (b) uses dictionary distributions, rather than distributions
51 generated from counts.
52
53 Using dictionary distributions allows the parameters to be retrained, e.g.
54 with Baum-Welch EM, but means we can't use the kind of backoff model
55 NgramModel implements, since this is dependent on having counts for the
56 training data available.
57
58 The distributions given at initialization don't have to be dictionary
59 distributions, but will be converted to such.
60
61 """
def __init__(self, label_dist, emission_dist, label_dom, emission_dom,
             mutable=False):
    """
    Builds an HMM whose transition and emission distributions are held as
    dictionary distributions (converted from whatever is passed in).

    @type label_dist: nltk prob dist
    @param label_dist: transition distribution
    @type emission_dist: nltk prob dist
    @param emission_dist: emission distribution
    @type label_dom: list
    @param label_dom: state domain
    @type emission_dom: list
    @param emission_dom: emission domain
    @type mutable: bool
    @param mutable: if true, the distributions stored will be mutable
        dictionary distributions, so the model can be updated

    """
    # An HMM is always a bigram model
    self.order = 2

    # Domains, together with their sizes
    self.label_dom, self.num_labels = label_dom, len(label_dom)
    self.emission_dom, self.num_emissions = emission_dom, len(emission_dom)

    # Keep dictionary copies of both distributions; with mutable=True these
    # can later be overwritten (e.g. by re-estimation)
    to_dictionary = cond_prob_dist_to_dictionary_cond_prob_dist
    self.label_dist = to_dictionary(label_dist, mutable=mutable)
    self.emission_dist = to_dictionary(emission_dist, mutable=mutable)

    # For every emission sample, add up the probability that each label's
    # emission distribution assigns to it. This reads the distribution as
    # passed in, not the converted copy stored above.
    # NOTE(review): DictionaryProbDist is built without normalize=True, so
    # these totals can sum to more than 1 (roughly one unit of mass per
    # label) - confirm this is intended.
    totals = {}
    for condition in emission_dist.conditions():
        per_label = emission_dist[condition]
        for sample in per_label.samples():
            totals[sample] = totals.get(sample, 0.0) + per_label.prob(sample)
    self.observation_dist = DictionaryProbDist(totals)
    self.seen_labels = label_dom

    # Dictionary distributions cannot support NgramModel-style backoff
    self.backoff_model = None

    self.clear_cache()
104 @staticmethod
106 """
107 Creates a DictionaryHmmModel from an L{NgramModel}. Note that the
ngram model must be of the correct sort: order 2 with no backoff.
109
110 """
111 if model.order != 2:
112 raise TypeError, "can only create an HMM from an order 2 ngram"
113 if model.backoff_model is not None:
114 raise TypeError, "tried to create a dictionary HMM from a model "\
115 "with backoff. The backoff can't be replicated in such a model"
116
117 return DictionaryHmmModel(model.label_dist,
118 model.emission_dist,
119 model.label_dom,
120 model.emission_dom,
121 mutable=mutable)
122
124 """
125 Produces a picklable representation of model as a dict.
126
127 """
128 from jazzparser.utils.nltk.storage import object_to_dict
129
130 return {
131 'label_dom' : self.label_dom,
132 'emission_dom' : self.emission_dom,
133 'label_dist' : object_to_dict(self.label_dist),
134 'emission_dist' : object_to_dict(self.emission_dist),
135 }
136
137 @staticmethod
150
152 """
153 This is now implemented much better than it used to be by the
154 superclass. Use
155 L{jazzparser.utils.nltk.ngram.NgramModel.gamma_probabilities}.
156 This method is now just a wrapper to that.
157
158 """
159 return self.gamma_probabilities(*args, **kwargs)
160
def compute_xi(self, sequence, forward=None, backward=None,
               emission_matrix=None, transition_matrix=None,
               use_logs=False):
    """
    Computes the xi matrix used by Baum-Welch. It is the matrix of joint
    probabilities of occupation of pairs of consecutive states:
    P(i_t, j_{t+1} | O).

    As with L{compute_gamma}, the forward and backward matrices can
    optionally be passed in to avoid recomputing them. The same holds for
    the emission and transition matrices.

    @param sequence: observation sequence. Only used to compute any of the
        matrices below that are not given.
    @type forward: 2D numpy array
    @param forward: forward probabilities indexed [time, state]. Defaults
        to C{self.normal_forward_probabilities(sequence)}.
    @type backward: 2D numpy array
    @param backward: backward probabilities indexed [time, state]. Defaults
        to C{self.normal_backward_probabilities(sequence)}.
    @type emission_matrix: 2D numpy array
    @param emission_matrix: emission probabilities indexed [time, state].
        Defaults to C{self.get_emission_matrix(sequence)}.
    @type transition_matrix: 2D numpy array
    @param transition_matrix: transition probabilities. This matrix is used
        transposed, i.e. entry [j, i] is the weight for moving from state i
        to state j. Defaults to C{self.get_transition_matrix()}.
    @type use_logs: bool
    @param use_logs: by default, this function does not use logs in its
        calculations. This can lead to underflow if your forward/backward
        matrices have sufficiently low values. If C{use_logs=True}, logs
        will be used internally (though the returned values are
        exponentiated again). This is somewhat slower.
    @rtype: 3D numpy float64 array
    @return: array of shape (T-1, N, N), where T is the number of time
        steps and N the number of states. Entry [t, i, j] is the
        probability of being in state i at time t and in state j at time
        t+1. Each xi[t] is normalized to sum to 1.

    """
    if forward is None:
        forward = self.normal_forward_probabilities(sequence)
    if backward is None:
        backward = self.normal_backward_probabilities(sequence)
    if emission_matrix is None:
        emission_matrix = self.get_emission_matrix(sequence)
    if transition_matrix is None:
        transition_matrix = self.get_transition_matrix()

    forward = numpy.asarray(forward, dtype=numpy.float64)
    backward = numpy.asarray(backward, dtype=numpy.float64)
    emission_matrix = numpy.asarray(emission_matrix, dtype=numpy.float64)
    transition_matrix = numpy.asarray(transition_matrix, dtype=numpy.float64)

    T, N = forward.shape

    if use_logs:
        # Zero probabilities become -inf, which the log-space arithmetic
        # below handles correctly, so the divide-by-zero warning is noise
        with numpy.errstate(divide='ignore'):
            forward = numpy.log2(forward)
            backward = numpy.log2(backward)
            emission_matrix = numpy.log2(emission_matrix)
            transition_matrix = numpy.log2(transition_matrix)

    # Every factor is broadcast to shape (T-1, N, N) with axes [t, i, j]:
    # i is the state at time t, j the state at time t+1
    fwd = forward[:-1, :, numpy.newaxis]
    trans = transition_matrix.T[numpy.newaxis, :, :]

    if use_logs:
        arrival = (backward[1:] + emission_matrix[1:])[:, numpy.newaxis, :]
        xi = trans + fwd + arrival
        # Normalize over the whole (i, j) plane of each time step. Reducing
        # the 2-D slice directly would only reduce axis 0, giving a separate
        # normalizer per column.
        totals = numpy.logaddexp2.reduce(xi.reshape(T - 1, N * N), axis=1)
        xi = numpy.exp2(xi - totals[:, numpy.newaxis, numpy.newaxis])
    else:
        arrival = (backward[1:] * emission_matrix[1:])[:, numpy.newaxis, :]
        xi = trans * fwd * arrival
        totals = xi.reshape(T - 1, N * N).sum(axis=1)
        xi /= totals[:, numpy.newaxis, numpy.newaxis]

    return xi
233
235 """ More efficient Viterbi decoding than superclass. """
236 T = len(sequence)
237 N = len(self.label_dom)
238
239 viterbi_matrix = numpy.zeros((T,N), numpy.float64)
240 back_pointers = numpy.zeros((T-1,N), numpy.int)
241
242 ems = self.get_emission_matrix(sequence)
243 trans = self.get_transition_matrix()
244
245
246
247 for i,state in enumerate(self.label_dom):
248 viterbi_matrix[0,i] = self.transition_probability(state, None) * ems[0,i]
249
250
251 for t in range(1, T):
252
253
254
255 transitions = trans * viterbi_matrix[t-1]
256
257 max_transitions = numpy.max(transitions, axis=1)
258
259
260 viterbi_matrix[t, :] = max_transitions * ems[t]
261
262 viterbi_matrix[t, :] /= numpy.sum(viterbi_matrix[t, :])
263
264 back_pointers[t-1] = numpy.argmax(transitions, axis=1)
265
266
267 final_state = numpy.argmax(viterbi_matrix[T-1,:])
268
269
270 current_state = final_state
271 states = [final_state]
272 for t in range(T-2, -1, -1):
273 last_state = back_pointers[t, current_state]
274 states.append(last_state)
275 current_state = last_state
276 states = list(reversed(states))
277
278
279 states = [self.label_dom[s] for s in states]
280 return states
281