Package jptests :: Package utils :: Package nltk :: Package ngram :: Module baumwelch
[hide private]
[frames] | [no frames]

Source Code for Module jptests.utils.nltk.ngram.baumwelch

  1  """Unit tests for jazzparser.utils.nltk.ngram.baumwelch 
  2   
  3  """ 
  4  """ 
  5  ============================== License ======================================== 
  6   Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding 
  7    
  8   This file is part of The Jazz Parser. 
  9    
 10   The Jazz Parser is free software: you can redistribute it and/or modify 
 11   it under the terms of the GNU General Public License as published by 
 12   the Free Software Foundation, either version 3 of the License, or 
 13   (at your option) any later version. 
 14    
 15   The Jazz Parser is distributed in the hope that it will be useful, 
 16   but WITHOUT ANY WARRANTY; without even the implied warranty of 
 17   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 18   GNU General Public License for more details. 
 19    
 20   You should have received a copy of the GNU General Public License 
 21   along with The Jazz Parser.  If not, see <http://www.gnu.org/licenses/>. 
 22   
 23  ============================ End license ====================================== 
 24   
 25  """ 
 26  __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"  
 27   
 28  import unittest 
 29  from jazzparser.utils.nltk.ngram import DictionaryHmmModel 
 30  from jazzparser.utils.nltk.ngram.baumwelch import BaumWelchTrainer 
 31  from nltk.probability import DictionaryConditionalProbDist, DictionaryProbDist 
 32   
class TestTrain(unittest.TestCase):
    """
    Tests for L{BaumWelchTrainer} using a small three-state HMM
    (states 'H', 'M', 'L') over the emission alphabet 0-9.

    """
    def setUp(self):
        """
        Prepare some training data and an initial model to train from.

        """
        self.TRAINING_DATA = [
            [0, 5, 5, 7, 6, 7, 8, 5, 2, 0, 3, 1, 2, 2, 2, 9, 9, 8, 0, 8,
             9, 9, 1, 3, 2, 2, 1],
            [3, 3, 1, 2, 1, 1, 0, 1, 9, 7, 8, 7, 7, 9, 0],
            [7, 8, 6, 9, 8, 9, 9, 1, 3, 0, 1, 3, 0, 1, 1, 0, 5, 7, 5, 4,
             5, 7, 7],
        ]
        self.TEST_DATA = [0, 1, 2, 3, 4, 3, 5, 6, 7, 8, 8, 9, 7, 7, 0, 0, 1]

        ems = list(range(10))
        states = ['H', 'M', 'L']

        # Emission distributions: each state puts its mass on a different
        # band of the alphabet (H high values, M middle, L low). Each row
        # sums to 1.0.
        hprobs = {
            0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0, 5: 0.0, 6: 0.1, 7: 0.3,
            8: 0.3, 9: 0.3}
        mprobs = {
            0: 0.0, 1: 0.0, 2: 0.0, 3: 0.1, 4: 0.3, 5: 0.3, 6: 0.3, 7: 0.0,
            8: 0.0, 9: 0.0}
        lprobs = {
            0: 0.2, 1: 0.2, 2: 0.2, 3: 0.2, 4: 0.2, 5: 0.0, 6: 0.0, 7: 0.0,
            8: 0.0, 9: 0.0}
        emdist = DictionaryConditionalProbDist({
            'H': DictionaryProbDist(hprobs),
            'M': DictionaryProbDist(mprobs),
            'L': DictionaryProbDist(lprobs),
        })

        # Transition distributions: uniform over every possible next state.
        # None is included alongside the real states (presumably the
        # start/end marker -- confirm against DictionaryHmmModel).
        trans_states = states + [None]
        # Divide by the actual number of outcomes. A hard-coded 1/3 over these
        # 4 outcomes summed to 4/3 and was not a valid distribution
        # (DictionaryProbDist does not normalize by default).
        uniform = 1.0 / len(trans_states)
        trans_conddist = {}
        for first in trans_states:
            # Conditions are 1-tuples: the model is keyed on the previous
            # state as a length-1 history.
            trans_conddist[(first,)] = DictionaryProbDist(
                {second: uniform for second in trans_states})
        transdist = DictionaryConditionalProbDist(trans_conddist)

        # Initialize an ngram model with these distributions
        self.model = DictionaryHmmModel(transdist, emdist, states, ems)

    def test_init_model(self):
        """
        Check that the initialized model is doing something sensible.

        """
        # We don't check these values, only that computing them raises no
        # errors: errors here would badly confuse the rest of the tests.
        self.model.emission_probability(2, 'H')
        self.model.emission_probability(6, 'H')
        self.model.transition_probability('H', 'L')
        self.model.transition_probability('H', 'H')

    def test_init_decode(self):
        """
        Try running the Viterbi decoder using the initial model.

        """
        self.model.viterbi_decode(self.TEST_DATA)

    def test_baum_welch(self):
        """
        Runs the Baum Welch trainer using the training data.

        """
        options = BaumWelchTrainer.process_option_dict({})
        trainer = BaumWelchTrainer(self.model, options)
        # Train the model with Baum Welch
        trainer.train(self.TRAINING_DATA)
        model = trainer.model
        # Try decoding using the trained model to check it still works
        model.viterbi_decode(self.TEST_DATA)

    def test_baum_welch_mp(self):
        """
        Does the same as L{test_baum_welch}, but uses multiprocessing.

        """
        # trainprocs=-1 presumably means "use all available processors";
        # confirm against BaumWelchTrainer's option definitions.
        options = BaumWelchTrainer.process_option_dict({'trainprocs': -1})
        trainer = BaumWelchTrainer(self.model, options)
        # Train the model with Baum Welch
        trainer.train(self.TRAINING_DATA)
        model = trainer.model
        # As in test_baum_welch: check the trained model still decodes
        model.viterbi_decode(self.TEST_DATA)