Package jazzparser :: Package parsers :: Package tagrank :: Module chart
[hide private]
[frames] | no frames]

Source Code for Module jazzparser.parsers.tagrank.chart

  1  """Simple probabilistic extension to the CKY chart for the tagrank parser. 
  2   
  3  This is rather like the PCFG chart, but doesn't do as much - it just  
  4  combines probabilities very naively from arguments of rule applications  
  5  so that products of tag probabilities work their way up the tree. 
  6   
  7  """ 
  8  """ 
  9  ============================== License ======================================== 
 10   Copyright (C) 2008, 2010-12 University of Edinburgh, Mark Granroth-Wilding 
 11    
 12   This file is part of The Jazz Parser. 
 13    
 14   The Jazz Parser is free software: you can redistribute it and/or modify 
 15   it under the terms of the GNU General Public License as published by 
 16   the Free Software Foundation, either version 3 of the License, or 
 17   (at your option) any later version. 
 18    
 19   The Jazz Parser is distributed in the hope that it will be useful, 
 20   but WITHOUT ANY WARRANTY; without even the implied warranty of 
 21   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 22   GNU General Public License for more details. 
 23    
 24   You should have received a copy of the GNU General Public License 
 25   along with The Jazz Parser.  If not, see <http://www.gnu.org/licenses/>. 
 26   
 27  ============================ End license ====================================== 
 28   
 29  """ 
 30  __author__ = "Mark Granroth-Wilding <mark.granroth-wilding@ed.ac.uk>"  
 31   
 32  from jazzparser.parsers.cky.chart import Chart, SignHashSet 
 33  from jazzparser.data import DerivationTrace 
 34  from jazzparser.utils.strings import fmt_prob 
 35  import logging 
 36   
 37  # Get the logger from the logging system 
 38  logger = logging.getLogger("main_logger") 
 39   
class TagProbabilitySignHashSet(SignHashSet):
    """
    For this chart's internal data structure, we use a modified 
    implementation of the HashSet which adds some basic handling of 
    the tag probabilities that came from the tagger's model.

    Every sign stored in the set is expected to carry a C{probability}
    attribute (a product of tag probabilities, not a true probability).

    Optional keyword arguments, removed before the remaining arguments
    are passed on to L{SignHashSet}:
        - C{threshold}: signs whose probability is below this fraction
          of the most probable sign's are pruned by the beam
          (default L{DEFAULT_THRESHOLD})
        - C{maxsize}: hard limit on the number of signs kept once the
          beam has been applied; 0 disables the hard limit
          (default L{DEFAULT_MAX_SIZE})
        - C{beaming_threshold}: no pruning at all happens until the set
          holds at least this many signs
          (default L{DEFAULT_BEAMING_THRESHOLD})

    """
    DEFAULT_THRESHOLD = 0.001
    DEFAULT_MAX_SIZE = 50
    # The beam won't operate at all until there are this many signs on the arc
    DEFAULT_BEAMING_THRESHOLD = 20

    def __init__(self, *args, **kwargs):
        # Take our own options out of kwargs so the base class never sees them
        self.threshold = kwargs.pop('threshold', self.DEFAULT_THRESHOLD)
        self.maxsize = kwargs.pop('maxsize', self.DEFAULT_MAX_SIZE)
        self.beaming_threshold = kwargs.pop('beaming_threshold',
                                            self.DEFAULT_BEAMING_THRESHOLD)
        super(TagProbabilitySignHashSet, self).__init__(*args, **kwargs)

    def _add_existing_value(self, existing_value, new_value):
        """
        Merges C{new_value} into C{existing_value}, an equivalent sign
        already in the set (presumably invoked by L{SignHashSet} when an
        added sign matches an existing one -- confirm in the base class).

        The existing sign keeps the larger of the two probability
        products; the base class then does any formalism-specific merging.

        """
        # Take the max of the two probability products.
        # These aren't probabilities - don't sum them.
        # In effect, this means we're getting the probability of the
        # most likely tag sequence that led to this sign
        existing_value.probability = max(existing_value.probability,
                                         new_value.probability)
        # Continue to do whatever the formalism wants with the signs
        super(TagProbabilitySignHashSet, self)._add_existing_value(
            existing_value, new_value)

    def _max_probability(self):
        """
        Returns the highest C{probability} of any sign in the set, or
        0.0 if the set is empty.

        """
        return max([val.probability for val in self.values()] + [0.0])

    def _apply_beam(self):
        """
        Applies a beam, using the already given threshold, to the set, 
        pruning out any signs with a probability lower than the 
        given ratio of the most probable sign.

        If more than C{maxsize} signs remain afterwards (and C{maxsize}
        is non-zero), only the C{maxsize} most probable signs are kept.

        Nothing is pruned while the set holds fewer than
        C{beaming_threshold} signs.

        """
        if len(self) < self.beaming_threshold:
            # The beam won't operate at all on small sets
            return

        max_prob = self._max_probability()
        cutoff = max_prob * self.threshold
        # Collect first, then remove: don't mutate the set while iterating it
        to_remove = [sign for sign in self.values() if sign.probability < cutoff]
        for sign in to_remove:
            self.remove(sign)
        logger.debug("Beam removed %d signs (max %s, min %s)",
                     len(to_remove), max_prob, cutoff)

        # Beam is now applied: check the remaining size
        if self.maxsize != 0 and len(self) > self.maxsize:
            logger.debug("Hard beam removed %d signs", len(self) - self.maxsize)
            # Too many signs: apply a hard cutoff. ranked() lists the most
            # probable signs first, so everything past the first maxsize
            # entries is the least probable and gets dropped. (The signs
            # must NOT be sorted ascending here, or the best would be lost.)
            for sign in self.ranked()[self.maxsize:]:
                self.remove(sign)

    def ranked(self):
        """
        Returns the signs in the set ranked by tag probability product
        (highest first), as a new list.

        """
        return list(reversed(sorted(self.values(), key=lambda s: s.probability)))

class TagRankChart(Chart):
    """
    CKY chart extended with naive probability propagation for the
    tagrank parser.

    Every input sign must carry a C{probability} attribute. Signs built
    by rule application are given one as well:
        - unary rules copy the probability of the sign they were applied to;
        - binary rules use the product of the two argument signs'
          probabilities.

    C{HASH_SET_IMPL} names L{TagProbabilitySignHashSet} as the set type
    for the chart's cells; L{apply_beam} relies on the C{_apply_beam}
    method of those sets.

    """
    HASH_SET_IMPL = TagProbabilitySignHashSet

    def __init__(self, *args, **kwargs):
        super(TagRankChart, self).__init__(*args, **kwargs)
        # Shortcut to the formalism's category representation
        pcfg_parser = self.grammar.formalism.PcfgParser
        self.catrep = pcfg_parser.category_representation

    def _get_ranked_parses(self):
        """
        Full parses ordered from most to least probable.
        Returns a new list.

        """
        ranking = sorted(self.parses, key=lambda parse: parse.probability)
        ranking.reverse()
        return ranking
    ranked_parses = property(_get_ranked_parses)

    def apply_unary_rule(self, rule, start, end, beam=True):
        """
        Applies C{rule} to the arc (start,end) using the superclass
        implementation. Its C{result_modifier} hook (presumably called
        with each result and the sign it came from) copies the source
        sign's probability onto the result.

        If C{beam} is true and anything was added, the arc is beamed.
        Returns whatever the superclass method returned.

        """
        def _inherit_probability(result, sign):
            result.probability = sign.probability

        parent = super(TagRankChart, self)
        added = parent.apply_unary_rule(rule, start, end,
                                        result_modifier=_inherit_probability)
        if beam and added:
            self.apply_beam((start, end))
        return added

    def _apply_binary_rule(self, rule, sign_pair):
        """
        Override to provide probability propagation: each result is
        given the product of the two argument signs' probabilities.

        See L{jazzparser.parsers.cky.chart.Chart._apply_binary_rule}
        for full doc.

        """
        product = sign_pair[0].probability * sign_pair[1].probability
        parent = super(TagRankChart, self)
        outputs = parent._apply_binary_rule(rule, sign_pair)
        for output in outputs:
            output.probability = product
        return outputs

    def _apply_binary_rule_semantics(self, rule, sign_pair, category):
        """
        Like _apply_binary_rule, but uses the C{apply_rule_semantics()}
        of the rule instead of C{apply_rule()}.

        Each result is given the product of the two argument signs'
        probabilities.

        @see: jazzparser.parsers.cky.chart.Chart._apply_binary_rule_semantics

        """
        product = sign_pair[0].probability * sign_pair[1].probability
        parent = super(TagRankChart, self)
        outputs = parent._apply_binary_rule_semantics(rule, sign_pair, category)
        for output in outputs:
            output.probability = product
        return outputs

    def apply_binary_rules(self, start, middle, end, beam=True):
        """
        Applies the binary rules via the superclass, then beams the arc
        (start,end) if C{beam} is true and anything was added.
        Returns whatever the superclass method returned.

        """
        parent = super(TagRankChart, self)
        added = parent.apply_binary_rules(start, middle, end)
        if beam and added:
            self.apply_beam((start, end))
        return added

    def apply_binary_rule(self, rule, start, middle, end, beam=True):
        """
        Applies a single binary rule via the superclass, then beams the
        arc (start,end) if C{beam} is true and anything was added.
        Returns whatever the superclass method returned.

        """
        parent = super(TagRankChart, self)
        added = parent.apply_binary_rule(rule, start, middle, end)
        if beam and added:
            self.apply_beam((start, end))
        return added

    def apply_beam(self, arc=None):
        """
        Beams the chart. With no C{arc}, every non-lexical arc is beamed.
        Otherwise C{arc} should be a (start,end) tuple and only that arc
        is beamed. Lexical arcs (end == start+1) are never beamed.

        """
        if arc is None:
            # Whole chart: entry 0 of each row is the length-1 (lexical) arc
            for row in self._table:
                for cell in row[1:]:
                    cell._apply_beam()
            return

        start, end = arc
        if end == start + 1:
            # A single lexical arc: leave it alone
            return
        logger.debug("Beaming (%s,%s)" % arc)
        self._table[start][end - start - 1]._apply_beam()

    def _sign_string(self, sign):
        prob_text = fmt_prob(sign.probability)
        return "%s (%s)" % (sign, prob_text)

    def launch_inspector(self, input=None):
        # Docs are inherited from Chart
        from .inspector import TagRankChartInspectorThread
        thread = TagRankChartInspectorThread(self, input_strs=input)
        thread.start()
