Package dataprocessor :: Package ce :: Module pysvmlight
[hide private]
[frames | no frames]

Source Code for Module dataprocessor.ce.pysvmlight

  1  ''' 
  2  Created on 27 Aug 2012 
  3   
  4  @author: Eleftherios Avramidis 
  5  ''' 
  6  from xml.etree.cElementTree import iterparse 
  7   
# Tag of the element that opens a new sentence entry; its XML attributes are
# collected as sentence-level attributes (no prefix).
TAG_SENT = 'judgedsentence'
# Tag of the source-sentence element; its attributes are exposed with a "src_" prefix.
TAG_SRC = 'src'
# Tag of a target-sentence element; its attributes are exposed with a "tgt_" prefix.
TAG_TGT = 'tgt'
# NOTE(review): not referenced by the functions visible in this module;
# presumably the tag of the document root element -- confirm.
TAG_DOC = 'jcml'
 12  import sys 
 13  import math 
 14   
def get_svmlight_format(dataset):
    '''
    Convert every parallel sentence of a dataset into an instance tuple.

    The attribute names of all sentences are collected first and sorted, so
    that each attribute gets the same position-based id in every instance.

    @param dataset: object providing get_parallelsentences(); every returned
        item must provide get_nested_attribute_names() and
        get_nested_attributes()
    @return: list with one instance per parallel sentence, as built by
        get_instance_from_parallelsentence()
    '''
    # Materialise once: the sentences are walked twice (names, then values).
    parallelsentences = list(dataset.get_parallelsentences())

    attribute_names = set()
    for parallelsentence in parallelsentences:
        attribute_names.update(parallelsentence.get_nested_attribute_names())
    attribute_names = sorted(attribute_names)

    # The shared name list must be handed over; the callee has no default for it.
    return [get_instance_from_parallelsentence(parallelsentence, attribute_names)
            for parallelsentence in parallelsentences]
26
def get_instance_from_parallelsentence(parallelsentence, attribute_names):
    '''
    Build a (label, features) tuple from the attributes of one parallel sentence.

    @param parallelsentence: object whose get_nested_attributes() returns a
        mapping of attribute name to value; it must contain the key "rank"
    @param attribute_names: list of attribute names; the position of a name in
        this list (as a float) becomes the id of the corresponding feature
    @return: tuple (int(rank), [(attribute_id, float_value), ...])
    @raise KeyError: if the sentence has no "rank" attribute
    @raise ValueError: if an attribute is not listed in attribute_names, or if
        the rank cannot be converted to int
    '''
    current_attributes = parallelsentence.get_nested_attributes()
    label = current_attributes["rank"]

    new_attributes = []
    for att, value in current_attributes.items():
        # The label is skipped instead of deleted, so the mapping handed out
        # by the sentence object is left untouched.
        if att == "rank":
            continue
        att_id = float(attribute_names.index(att))
        try:
            value = float(value)
        except (TypeError, ValueError):
            # Non-numeric attributes cannot become features.
            continue
        new_attributes.append((att_id, value))
    return (int(label), new_attributes)
41 42
def get_attribute_names(input_xml_filename):
    '''
    Parse once the given XML file and return a set with the attribute names.

    Attributes of sentence elements are reported under their own name,
    attributes of source elements with a "src_" prefix and attributes of
    target elements with a "tgt_" prefix.

    @param input_xml_filename: The XML file to be parsed
    @return: set of attribute names found in the file
    '''
    attribute_names = set()
    # Binary mode lets the XML parser handle the encoding itself; the context
    # manager closes the file even if parsing fails.
    with open(input_xml_filename, "rb") as source_xml_file:
        context = iter(iterparse(source_xml_file, events=("start", "end")))
        # The first event delivers the root element.
        event, root = next(context)

        for event, elem in context:
            if event == "start":
                if elem.tag == TAG_SENT:
                    attribute_names.update(elem.attrib.keys())
                elif elem.tag == TAG_SRC:
                    attribute_names.update("src_{}".format(key) for key in elem.attrib.keys())
                elif elem.tag == TAG_TGT:
                    attribute_names.update("tgt_{}".format(key) for key in elem.attrib.keys())
            elif elem.tag == TAG_SENT:
                # Sentence finished: drop the parsed elements held by the root.
                root.clear()
    return attribute_names
78 79
def read_file_incremental(input_xml_filename, **kwargs):
    '''
    Read the given XML file incrementally and build one instance per target sentence.

    Every instance is a tuple (label, features, sentence_id): label is the
    value of the class attribute converted to int, features is a list of
    (attribute_id, float_value) pairs with 1-based attribute ids, and
    sentence_id is a counter increased for every sentence element.

    @param input_xml_filename: The XML file to be parsed
    @keyword desired_attributes: names of the attributes to use, in the order
        that defines their ids; if empty or if none of them exists in the
        file, all existing attributes (sorted) minus meta_attributes are used
    @keyword meta_attributes: names to leave out when falling back to all
        existing attributes (default: [])
    @keyword class_name: attribute providing the label (default: "tgt_rank")
    @keyword group_test: if True, return a list with one list of instances
        per sentence instead of a flat list (default: False)
    @keyword id_start: value the sentence counter starts from; the first
        sentence gets id_start + 1 (default: 0)
    @keyword impute: if True, non-numeric values become 0, otherwise the
        attribute is skipped for that instance (default: True)
    @keyword remove_inf: if True, NaN is replaced by a zero and infinities by
        99999 with the same sign (default: True)
    @return: list of instances, or list of lists of instances if group_test
    @raise KeyError: if a target sentence has no class attribute
    '''
    desired_attributes = kwargs.get("desired_attributes", [])
    class_name = kwargs.get("class_name", "tgt_rank")
    group_test = kwargs.get("group_test", False)
    id_start = kwargs.get("id_start", 0)
    impute = kwargs.get("impute", True)
    remove_inf = kwargs.get("remove_inf", True)

    existing_attribute_names = get_attribute_names(input_xml_filename)

    usable_attribute_names = set()
    if desired_attributes:
        desired_set = set(desired_attributes)
        missing_attribute_names = desired_set - existing_attribute_names
        usable_attribute_names = desired_set & existing_attribute_names
        if missing_attribute_names:
            sys.stderr.write("could not find attributes {}".format("\n\t".join(sorted(missing_attribute_names))))
        attribute_names = list(desired_attributes)

    if not usable_attribute_names:
        # Nothing usable was requested: fall back to everything in the file.
        meta_attributes = kwargs.get("meta_attributes", [])
        attribute_names = sorted(existing_attribute_names - set(meta_attributes))

    # Name -> 1-based id, resolved once instead of a list search per value.
    # setdefault keeps the first position if a name is listed twice.
    feature_ids = {}
    for position, name in enumerate(attribute_names, 1):
        feature_ids.setdefault(name, position)

    instances = []
    instancegroups = []
    general_attributes = {}
    source_attributes = []
    i = id_start

    # Binary mode lets the XML parser handle the encoding itself; the context
    # manager closes the file even if parsing fails.
    with open(input_xml_filename, "rb") as source_xml_file:
        context = iter(iterparse(source_xml_file, events=("start", "end")))
        # The first event delivers the root element.
        event, root = next(context)

        for event, elem in context:
            if event == "start":
                if elem.tag == TAG_SENT:
                    # New sentence: remember its attributes and forget the
                    # source attributes of the previous sentence.
                    general_attributes = dict(elem.attrib)
                    source_attributes = []
                    i += 1
                elif elem.tag == TAG_SRC:
                    source_attributes = [("src_{}".format(key), value)
                                         for key, value in elem.attrib.items()]
                elif elem.tag == TAG_TGT:
                    # Later updates win: source, then target, then sentence level.
                    attributes = dict(source_attributes)
                    attributes.update(("tgt_{}".format(key), value)
                                      for key, value in elem.attrib.items())
                    attributes.update(general_attributes)
                    label = attributes.pop(class_name)

                    new_attributes = []
                    for att, value in attributes.items():
                        att_id = feature_ids.get(att)
                        if att_id is None:
                            # Not a selected attribute (maybe it is a meta).
                            continue
                        try:
                            value = float(value)
                        except (TypeError, ValueError):
                            if not impute:
                                continue
                            value = 0.0
                        if remove_inf:
                            if math.isnan(value):
                                value = math.copysign(0, value)
                            elif math.isinf(value):
                                value = math.copysign(99999, value)
                        new_attributes.append((att_id, 1.0 * value))
                    instances.append((int(label), new_attributes, i))
            elif elem.tag == TAG_SENT:
                # Exact tag comparison: a substring test would also match
                # shorter tags such as "sentence".
                if group_test:
                    instancegroups.append(instances)
                    instances = []
                # Sentence finished: drop the parsed elements held by the root.
                root.clear()

    if group_test:
        return instancegroups
    return instances
178
def convert_jcml_to_dat(jcml_filename, dat_filename, **kwargs):
    '''
    Read instances from an XML file and write them to a text file, one per line.

    Each line has the form "<label> qid:<sentence_id> <id>:<value> ...", with
    the features sorted by id.

    @param jcml_filename: The XML file to be read
    @param dat_filename: The file to be written (overwritten if it exists)
    @param kwargs: passed on unchanged to read_file_incremental()
    '''
    instances = read_file_incremental(jcml_filename, **kwargs)
    if kwargs.get("group_test"):
        # The reader then returns one list per sentence; flatten the groups
        # so that every instance still yields exactly one output line.
        instances = [instance for group in instances for instance in group]

    # The context manager closes the output file even if writing fails.
    with open(dat_filename, 'w') as dat:
        for label, features, qid in instances:
            featurestring = " ".join("{}:{}".format(name, value) for name, value in sorted(features))
            dat.write("{} qid:{} {}\n".format(label, qid, featurestring))
187