11package gate .plugin .learningframework .engines ;
22
3+ import cc .mallet .types .Alphabet ;
34import cc .mallet .types .FeatureVector ;
45import cc .mallet .types .Instance ;
6+ import cc .mallet .types .InstanceList ;
57import gate .Annotation ;
68import gate .AnnotationSet ;
79import gate .lib .interaction .process .Process4JsonStream ;
1214import gate .plugin .learningframework .GateClassification ;
1315import gate .plugin .learningframework .data .CorpusRepresentationMalletTarget ;
1416import gate .plugin .learningframework .mallet .LFPipe ;
17+ import gate .plugin .learningframework .mallet .NominalTargetWithCosts ;
1518import gate .util .GateRuntimeException ;
1619import java .io .File ;
1720import java .io .FileInputStream ;
1821import java .io .InputStreamReader ;
22+ import java .util .AbstractMap ;
23+ import java .util .AbstractMap .SimpleEntry ;
1924import java .util .ArrayList ;
2025import java .util .Arrays ;
2126import java .util .Collections ;
@@ -142,6 +147,10 @@ protected File findWrapperCommand(File dataDirectory, boolean apply) {
142147 @ Override
143148 protected void loadModel (File directory , String parms ) {
144149 ArrayList <String > finalCommand = new ArrayList <String >();
150+ CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget )corpusRepresentationMallet ;
151+ SimpleEntry <String ,Integer > modeAndNrC = findOutMode (data );
152+ String mode = modeAndNrC .getKey ();
153+ Integer nrClasses = modeAndNrC .getValue ();
145154 // Instead of loading a model, this establishes a connection with the
146155 // external wrapper process.
147156
@@ -150,6 +159,8 @@ protected void loadModel(File directory, String parms) {
150159 finalCommand .add (commandFile .getAbsolutePath ());
151160 finalCommand .add (wrapperhome );
152161 finalCommand .add (modelFileName );
162+ finalCommand .add (mode );
163+ finalCommand .add (nrClasses .toString ());
153164 // if we have a shell command prepend that, and if we have shell parms too, include them
154165 if (shellcmd != null ) {
155166 finalCommand .add (0 ,shellcmd );
@@ -176,6 +187,11 @@ protected void saveModel(File directory) {
176187 @ Override
177188 public void trainModel (File dataDirectory , String instanceType , String parms ) {
178189 ArrayList <String > finalCommand = new ArrayList <String >();
190+ CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget )corpusRepresentationMallet ;
191+ SimpleEntry <String ,Integer > modeAndNrC = findOutMode (data );
192+ String mode = modeAndNrC .getKey ();
193+ Integer nrClasses = modeAndNrC .getValue ();
194+
179195 // invoke wrapper for training
180196 File commandFile = findWrapperCommand (dataDirectory , false );
181197 // Export the data
@@ -184,14 +200,20 @@ public void trainModel(File dataDirectory, String instanceType, String parms) {
184200 // TODO: NOTE: not sure if classification/regression matters here as long as
185201 // the actual exporter class does the right thing based on the corpus representation!
186202 // TODO: we have to choose the correct target type here!!!
203+ // NOTE: the last argument here are the parameters for the exporter method.
204+ // we use the CSV exporter with parameters:
205+ // -t: twofiles, export indep and dep into separate files
206+ // -n: noheaders, do not add a header row
187207 Exporter .export (getCorpusRepresentationMallet (),
188- Exporter .EXPORTER_CSV_CLASS , dataDirectory , instanceType , parms );
208+ Exporter .EXPORTER_CSV_CLASS , dataDirectory , instanceType , "-t -n" );
189209 String dataFileName = dataDirectory .getAbsolutePath ()+File .separator ;
190210 String modelFileName = new File (dataDirectory , MODEL_BASENAME ).getAbsolutePath ();
191211 finalCommand .add (commandFile .getAbsolutePath ());
192212 finalCommand .add (wrapperhome );
193213 finalCommand .add (dataFileName );
194214 finalCommand .add (modelFileName );
215+ finalCommand .add (mode );
216+ finalCommand .add (nrClasses .toString ());
195217 if (!parms .trim ().isEmpty ()) {
196218 String [] tmp = parms .split ("\\ s+" ,-1 );
197219 finalCommand .addAll (Arrays .asList (tmp ));
@@ -240,17 +262,19 @@ public List<GateClassification> classify(AnnotationSet instanceAS, AnnotationSet
240262 }
241263 // create the datastructure we need for the application script:
242264 // a map that contains the following fields:
243- // - cmd: either STOP or CSR1
265+ // - cmd: either STOP or "AC" for apply classification or "AR" for apply regression
244266 // - values: the non-zero values, for increasing rows and increasing cols within rows
245267 // - rowinds: for the k-th value which row number it is in
246268 // - colinds: for the k-th value which column number (location index) it is in
247269 // - shaperows: number of rows in total
248270 // - shapecols: maximum number of cols in a vector
249271 Map map = new HashMap <String ,Object >();
250- map .put ("cmd" , "CSR1" );
272+ if (classList ==null )
273+ map .put ("cmd" , "AR" );
274+ else
275+ map .put ("cmd" ,"AC" );
251276 ArrayList <double []> rows = new ArrayList <double []>();
252277 int rowIndex = 0 ;
253- pipe .getDataAlphabet ().size ();
254278 List <Annotation > instances = instanceAS .inDocumentOrder ();
255279 for (Annotation instAnn : instances ) {
256280 Instance inst = data .extractIndependentFeatures (instAnn , inputAS );
@@ -272,9 +296,8 @@ public List<GateClassification> classify(AnnotationSet instanceAS, AnnotationSet
272296 }
273297 // send the matrix data over to the weka process
274298 // TODO: add a key with the featureWeights to the map!
275- map .put ("rows" , rows );
276- map .put ("nrRows" , rowIndex );
277- map .put ("nrCols" , nrCols );
299+ map .put ("values" , rows );
300+ map .put ("n" , nrCols );
278301 process .writeObject (map );
279302 // get the result back
280303 Object ret = process .readObject ();
@@ -339,6 +362,38 @@ public void initializeAlgorithm(Algorithm algorithm, String parms) {
339362 protected void loadMalletCorpusRepresentation (File directory ) {
340363 corpusRepresentationMallet = CorpusRepresentationMalletTarget .load (directory );
341364 }
342-
365+
366+ protected AbstractMap .SimpleEntry <String ,Integer > findOutMode (CorpusRepresentationMalletTarget crm ) {
367+ InstanceList instances = crm .getRepresentationMallet ();
368+ if (instances .size () == 0 ) {
369+ throw new GateRuntimeException ("No instances in the training set, cannot train" );
370+ }
371+ // we pass on a "mode" for the learning problem, which is one of the following:
372+ // - classind: predict the index of a class
373+ // - classcosts: targets are vectors of class costs
374+ // - regr: regression
375+ // we also pass on another parameter which provides details of the learning problem:
376+ // - the number of class indices in case of classind and classcosts
377+ // - 0 as a dummy value in case of "regr"
378+
379+ int nrClasses = 0 ;
380+ String mode = "regr" ;
381+ Alphabet ta = crm .getPipe ().getTargetAlphabet ();
382+
383+ if (ta != null ) {
384+ Instance firstInstance = instances .get (0 );
385+ Object targetObj = firstInstance .getTarget ();
386+ if (targetObj instanceof NominalTargetWithCosts ) {
387+ NominalTargetWithCosts target = (NominalTargetWithCosts )targetObj ;
388+ nrClasses = target .getCosts ().length ;
389+ mode = "classcosts" ;
390+ } else {
391+ mode = "classind" ;
392+ nrClasses = ta .size ();
393+ }
394+ }
395+ AbstractMap .SimpleEntry <String ,Integer > ret = new AbstractMap .SimpleEntry <String , Integer >(mode ,nrClasses );
396+ return ret ;
397+ }
343398
344399}
0 commit comments