
Source Code for Module featuregenerator.lm.srilm.srilm_ngram

import xmlrpclib
#import base64
from featuregenerator.languagefeaturegenerator import LanguageFeatureGenerator
from nltk.tokenize.punkt import PunktWordTokenizer
import sys
from util.freqcaser import FreqCaser
from numpy import average, std


class SRILMngramGenerator(LanguageFeatureGenerator):
    '''
    Passes all the words of a sentence through an SRILM language model and counts how many of them are unknown (unigram probability of -99)
    '''
    def __init__(self, url, lang="en", lowercase=True, tokenize=True, freqcase_file=False):
        '''
        Set up the connection to the SRILM XML-RPC server
        '''
        self.server = xmlrpclib.Server(url)
        self.lang = lang
        self.lowercase = lowercase
        self.tokenize = tokenize
        self.freqcaser = None
        if freqcase_file:
            self.freqcaser = FreqCaser(freqcase_file)

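    # Expected server interface (summarized from the calls made in this module,
    # not an authoritative specification): the XML-RPC endpoint should expose
    # getUnigramProb(token), getBigramProb(string), getTrigramProb(string),
    # getSentenceProb(string, length) and getNgramFeatures_batch(batch), with a
    # unigram log-probability of -99 marking out-of-vocabulary words.
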
    def get_features_src(self, simplesentence, parallelsentence):
        atts = {}
        src_lang = parallelsentence.get_attribute("langsrc")
        if src_lang == self.lang:
            atts = self.get_features_simplesentence(simplesentence)
        return atts

    def get_features_tgt(self, simplesentence, parallelsentence):
        atts = {}
        tgt_lang = parallelsentence.get_attribute("langtgt")
        if tgt_lang == self.lang:
            atts = self.get_features_simplesentence(simplesentence)
        return atts

    def _prepare_sentence(self, simplesentence):
        sent_string = simplesentence.get_string().replace('-', ' ').strip()
        if self.freqcaser:
            tokenized_string = self.freqcaser.freqcase(sent_string)
        else:
            if self.lowercase:
                sent_string = sent_string.lower()
            if self.tokenize:
                sent_string = sent_string.replace('%', ' %') #TODO: this is an issue
                tokenized_string = PunktWordTokenizer().tokenize(sent_string)
                sent_string = ' '.join(tokenized_string)
            else:
                #split and remove empty tokens (due to multiple spaces)
                tokenized_string = [tok.strip() for tok in sent_string.split(' ') if tok.strip()]

        return (tokenized_string, sent_string)

    def prepare_sentence(self, simplesentence):
        sent_string = simplesentence.get_string().replace('-', ' ').strip()
        if self.freqcaser:
            tokenized_string = self.freqcaser.freqcase(sent_string)
        else:
            if self.lowercase:
                sent_string = sent_string.lower()
            if self.tokenize:
                sent_string = sent_string.replace('%', ' %') #TODO: this is an issue
                tokenized_string = PunktWordTokenizer().tokenize(sent_string)
                sent_string = ' '.join(tokenized_string)
            else:
                tokenized_string = sent_string.split(' ')

        #for i in range(len(tokenized_string)):
        #    tokenized_string[i] = base64.standard_b64encode(tokenized_string[i])

        return unicode(tokenized_string)

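    # Note (added for clarity): _prepare_sentence() is used by the per-sentence
    # feature extraction below and returns both the token list and the normalized
    # string, whereas prepare_sentence() appears to be intended for the
    # (commented-out) batch path at the end of the module and returns a unicode
    # rendering of the token list for transmission over XML-RPC.
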
    def _standouts(self, vector, sign):
        #count the values that lie more than one standard deviation away from the
        #mean, in the direction given by sign (-1: low standouts, +1: high standouts)
        std_value = std(vector)
        avg_value = average(vector)
        standout = 0

        for value in vector:
            if sign * (value - avg_value) > std_value:
                standout += 1

        return standout

    def _standout_pos(self, vector, sign):
        #return the 1-based positions of the values that lie more than one
        #standard deviation away from the mean, in the direction given by sign
        std_value = std(vector)
        avg_value = average(vector)
        standout = []

        for pos, value in enumerate(vector, start=1):
            if sign * (value - avg_value) > std_value:
                standout.append(pos)

        return standout

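    # Illustrative example (not in the original source): for a log-probability
    # vector such as [-1.0, -1.5, -6.5] the mean is -3.0 and the population
    # standard deviation is roughly 2.48, so _standouts(vector, -1) returns 1
    # (only -6.5 falls below mean - std ~ -5.48) and _standout_pos(vector, -1)
    # returns [3]; with sign=+1 no value exceeds mean + std ~ -0.52, so the
    # results are 0 and [] respectively.
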
    def get_features_simplesentence(self, simplesentence):
        (tokens, sent_string) = self._prepare_sentence(simplesentence)
        unk_count = 0
        uni_probs = 1
        bi_probs = 1
        tri_probs = 1
        unk_tokens = []

        prob = self._get_sentence_probability(sent_string)

        #check for unknown words and collect unigram probabilities
        pos = 0
        unk_pos = [] #keep the positions of unknown words
        uni_probs_vector = []
        bi_probs_vector = []
        tri_probs_vector = []

        for token in tokens:
            pos += 1
            uni_prob = self.server.getUnigramProb(token)
            if uni_prob == -99:
                unk_count += 1
                unk_pos.append(pos)
                unk_tokens.append(token)
                sys.stderr.write("Unknown word: %s of len %d\n" % (token, len(token)))
            else:
                uni_probs_vector.append(uni_prob)
                uni_probs += uni_prob

        #get bigram probabilities, skipping bigrams that contain unknown words
        for pos in range(len(tokens) - 1):
            token = tokens[pos:pos + 2]
            if (token[0] not in unk_tokens) and (token[1] not in unk_tokens):
                bi_prob = self.server.getBigramProb(' '.join(token))
                bi_probs += bi_prob
                bi_probs_vector.append(bi_prob)

        #get trigram probabilities, skipping trigrams that contain unknown words
        for pos in range(len(tokens) - 2):
            token = tokens[pos:pos + 3]
            if (token[0] not in unk_tokens) and (token[1] not in unk_tokens) and (token[2] not in unk_tokens):
                tri_prob = self.server.getTrigramProb(' '.join(token))
                tri_probs += tri_prob
                tri_probs_vector.append(tri_prob)

        unk_rel_pos = [(unk_pos_item * 1.00) / len(tokens) for unk_pos_item in unk_pos]
        unk_len = sum([len(token) for token in unk_tokens])

        if len(unk_pos) == 0:
            unk_pos = [0]
            unk_rel_pos = [0]

        attributes = {'lm_unk_pos_abs_avg': str(average(unk_pos)),
                      'lm_unk_pos_abs_std': str(std(unk_pos)),
                      'lm_unk_pos_abs_min': str(min(unk_pos)),
                      'lm_unk_pos_abs_max': str(max(unk_pos)),
                      'lm_unk_pos_rel_avg': str(average(unk_rel_pos)),
                      'lm_unk_pos_rel_std': str(std(unk_rel_pos)),
                      'lm_unk_pos_rel_min': str(min(unk_rel_pos)),
                      'lm_unk_pos_rel_max': str(max(unk_rel_pos)),
                      'lm_unk': str(unk_count),
                      'lm_unk_len': unk_len,

                      'lm_uni-prob': str(uni_probs),
                      'lm_uni-prob_avg': str(average(uni_probs_vector)),
                      'lm_uni-prob_std': str(std(uni_probs_vector)),
                      'lm_uni-prob_low': self._standouts(uni_probs_vector, -1),
                      'lm_uni-prob_high': self._standouts(uni_probs_vector, +1),
                      'lm_uni-prob_low_pos_avg': average(self._standout_pos(uni_probs_vector, -1)),
                      'lm_uni-prob_low_pos_std': std(self._standout_pos(uni_probs_vector, -1)),

                      'lm_bi-prob': str(bi_probs),
                      'lm_bi-prob_avg': str(average(bi_probs_vector)),
                      'lm_bi-prob_std': str(std(bi_probs_vector)),
                      'lm_bi-prob_low': self._standouts(bi_probs_vector, -1),
                      'lm_bi-prob_high': self._standouts(bi_probs_vector, +1),
                      'lm_bi-prob_low_pos_avg': average(self._standout_pos(bi_probs_vector, -1)),
                      'lm_bi-prob_low_pos_std': std(self._standout_pos(bi_probs_vector, -1)),

                      'lm_tri-prob': str(tri_probs),
                      'lm_tri-prob_avg': str(average(tri_probs_vector)),
                      'lm_tri-prob_std': str(std(tri_probs_vector)),
                      'lm_tri-prob_low': self._standouts(tri_probs_vector, -1),
                      'lm_tri-prob_high': self._standouts(tri_probs_vector, +1),
                      'lm_tri-prob_low_pos_avg': average(self._standout_pos(tri_probs_vector, -1)),
                      'lm_tri-prob_low_pos_std': std(self._standout_pos(tri_probs_vector, -1)),

                      'lm_prob': str(prob)}

        return attributes

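    # Worked example (added for illustration): for a 10-token sentence with
    # unknown words at positions 3 and 8, unk_pos = [3, 8] and
    # unk_rel_pos = [0.3, 0.8], so the result contains lm_unk = '2',
    # lm_unk_pos_abs_avg ~ 5.5, lm_unk_pos_abs_std ~ 2.5, lm_unk_pos_abs_min = '3',
    # lm_unk_pos_abs_max = '8' and lm_unk_pos_rel_avg ~ 0.55; lm_unk_len is the
    # total character length of the two unknown tokens.
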
    def _get_sentence_probability(self, sent_string):
        #query the server with the space-delimited sentence and its token count
        l = len(sent_string.split(" "))
        return str(self.server.getSentenceProb(sent_string, l))

    def xmlrpc_call(self, batch):
        return self.server.getNgramFeatures_batch(batch)


    # def add_features_batch(self, parallelsentences):
    #
    #     #return self.add_features_batch_xmlrpc(parallelsentences)
    #     batch = []
    #     preprocessed_batch = []
    #     for parallelsentence in parallelsentences:
    #         batch.append((parallelsentence.serialize(), parallelsentence.get_attribute("langsrc"), parallelsentence.get_attribute("langtgt")))
    #
    #     for (row, langsrc, langtgt) in batch:
    #         preprocessed_row = []
    #         col_id = 0
    #         for simplesentence in row:
    #             if (col_id == 0 and langsrc == self.lang) or (col_id > 0 and langtgt == self.lang):
    #                 simplesentence = self.prepare_sentence(simplesentence)
    #                 preprocessed_row.append(simplesentence)
    #             else:
    #                 simplesentence = ["DUMMY"]
    #                 preprocessed_row.append(simplesentence)
    #             col_id += 1
    #         preprocessed_batch.append(preprocessed_row)
    #
    #     print "sending request"
    #     features_batch = self.server.getNgramFeatures_batch(preprocessed_batch)
    #
    #     row_id = 0
    #
    #     new_parallelsentences = []
    #     for row in features_batch:
    #         parallelsentence = parallelsentences[row_id]
    #         src = parallelsentence.get_source()
    #         targets = parallelsentence.get_translations()
    #
    #         column_id = 0
    #         #dig in the batch to retrieve features
    #         for feature_set in row:
    #             for key in feature_set:
    #                 if column_id == 0:
    #                     src.add_attribute(key, feature_set[key])
    #                 else:
    #                     targets[column_id - 1].add_attribute(key, feature_set[key])
    #             column_id += 1
    #
    #         parallelsentence.set_source(src)
    #         parallelsentence.set_translations(targets)
    #         new_parallelsentences.append(parallelsentence)
    #         row_id += 1
    #
    #     return new_parallelsentences
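

# Minimal usage sketch (added for illustration; not part of the original module).
# It assumes an SRILM XML-RPC server exposing the methods used above is already
# running at the placeholder URL below, and it substitutes a tiny stand-in object
# for the SimpleSentence class used elsewhere in the package, since
# get_features_simplesentence() only needs a get_string() method.
if __name__ == "__main__":
    class _DemoSentence(object):
        #minimal stand-in exposing only the method the generator actually calls
        def __init__(self, text):
            self._text = text
        def get_string(self):
            return self._text

    generator = SRILMngramGenerator("http://localhost:8080", lang="en")
    features = generator.get_features_simplesentence(_DemoSentence("this is a short test sentence ."))
    for name in sorted(features):
        print "%s = %s" % (name, features[name])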