Skip to content

Commit 7a3d89e

Browse files
committed
WIP for making keras wrapper work.
1 parent d6ba50b commit 7a3d89e

File tree

1 file changed

+63
-8
lines changed

1 file changed

+63
-8
lines changed

src/gate/plugin/learningframework/engines/EnginePythonNetworksBase.java

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package gate.plugin.learningframework.engines;
22

3+
import cc.mallet.types.Alphabet;
34
import cc.mallet.types.FeatureVector;
45
import cc.mallet.types.Instance;
6+
import cc.mallet.types.InstanceList;
57
import gate.Annotation;
68
import gate.AnnotationSet;
79
import gate.lib.interaction.process.Process4JsonStream;
@@ -12,10 +14,13 @@
1214
import gate.plugin.learningframework.GateClassification;
1315
import gate.plugin.learningframework.data.CorpusRepresentationMalletTarget;
1416
import gate.plugin.learningframework.mallet.LFPipe;
17+
import gate.plugin.learningframework.mallet.NominalTargetWithCosts;
1518
import gate.util.GateRuntimeException;
1619
import java.io.File;
1720
import java.io.FileInputStream;
1821
import java.io.InputStreamReader;
22+
import java.util.AbstractMap;
23+
import java.util.AbstractMap.SimpleEntry;
1924
import java.util.ArrayList;
2025
import java.util.Arrays;
2126
import java.util.Collections;
@@ -142,6 +147,10 @@ protected File findWrapperCommand(File dataDirectory, boolean apply) {
142147
@Override
143148
protected void loadModel(File directory, String parms) {
144149
ArrayList<String> finalCommand = new ArrayList<String>();
150+
CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget)corpusRepresentationMallet;
151+
SimpleEntry<String,Integer> modeAndNrC = findOutMode(data);
152+
String mode = modeAndNrC.getKey();
153+
Integer nrClasses = modeAndNrC.getValue();
145154
// Instead of loading a model, this establishes a connection with the
146155
// external wrapper process.
147156

@@ -150,6 +159,8 @@ protected void loadModel(File directory, String parms) {
150159
finalCommand.add(commandFile.getAbsolutePath());
151160
finalCommand.add(wrapperhome);
152161
finalCommand.add(modelFileName);
162+
finalCommand.add(mode);
163+
finalCommand.add(nrClasses.toString());
153164
// if we have a shell command prepend that, and if we have shell parms too, include them
154165
if(shellcmd != null) {
155166
finalCommand.add(0,shellcmd);
@@ -176,6 +187,11 @@ protected void saveModel(File directory) {
176187
@Override
177188
public void trainModel(File dataDirectory, String instanceType, String parms) {
178189
ArrayList<String> finalCommand = new ArrayList<String>();
190+
CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget)corpusRepresentationMallet;
191+
SimpleEntry<String,Integer> modeAndNrC = findOutMode(data);
192+
String mode = modeAndNrC.getKey();
193+
Integer nrClasses = modeAndNrC.getValue();
194+
179195
// invoke wrapper for training
180196
File commandFile = findWrapperCommand(dataDirectory, false);
181197
// Export the data
@@ -184,14 +200,20 @@ public void trainModel(File dataDirectory, String instanceType, String parms) {
184200
// TODO: NOTE: not sure if classification/regression matters here as long as
185201
// the actual exporter class does the right thing based on the corpus representation!
186202
// TODO: we have to choose the correct target type here!!!
203+
// NOTE: the last argument here are the parameters for the exporter method.
204+
// we use the CSV exporter with parameters:
205+
// -t: twofiles, export indep and dep into separate files
206+
// -n: noheaders, do not add a header row
187207
Exporter.export(getCorpusRepresentationMallet(),
188-
Exporter.EXPORTER_CSV_CLASS, dataDirectory, instanceType, parms);
208+
Exporter.EXPORTER_CSV_CLASS, dataDirectory, instanceType, "-t -n");
189209
String dataFileName = dataDirectory.getAbsolutePath()+File.separator;
190210
String modelFileName = new File(dataDirectory, MODEL_BASENAME).getAbsolutePath();
191211
finalCommand.add(commandFile.getAbsolutePath());
192212
finalCommand.add(wrapperhome);
193213
finalCommand.add(dataFileName);
194214
finalCommand.add(modelFileName);
215+
finalCommand.add(mode);
216+
finalCommand.add(nrClasses.toString());
195217
if(!parms.trim().isEmpty()) {
196218
String[] tmp = parms.split("\\s+",-1);
197219
finalCommand.addAll(Arrays.asList(tmp));
@@ -240,17 +262,19 @@ public List<GateClassification> classify(AnnotationSet instanceAS, AnnotationSet
240262
}
241263
// create the datastructure we need for the application script:
242264
// a map that contains the following fields:
243-
// - cmd: either STOP or CSR1
265+
// - cmd: either STOP or "AC" for apply classification or "AR" for apply regression
244266
// - values: the non-zero values, for increasing rows and increasing cols within rows
245267
// - rowinds: for the k-th value which row number it is in
246268
// - colinds: for the k-th value which column number (location index) it is in
247269
// - shaperows: number of rows in total
248270
// - shapecols: maximum number of cols in a vector
249271
Map map = new HashMap<String,Object>();
250-
map.put("cmd", "CSR1");
272+
if(classList==null)
273+
map.put("cmd", "AR");
274+
else
275+
map.put("cmd","AC");
251276
ArrayList<double[]> rows = new ArrayList<double[]>();
252277
int rowIndex = 0;
253-
pipe.getDataAlphabet().size();
254278
List<Annotation> instances = instanceAS.inDocumentOrder();
255279
for(Annotation instAnn : instances) {
256280
Instance inst = data.extractIndependentFeatures(instAnn, inputAS);
@@ -272,9 +296,8 @@ public List<GateClassification> classify(AnnotationSet instanceAS, AnnotationSet
272296
}
273297
// send the matrix data over to the weka process
274298
// TODO: add a key with the featureWeights to the map!
275-
map.put("rows", rows);
276-
map.put("nrRows", rowIndex);
277-
map.put("nrCols", nrCols);
299+
map.put("values", rows);
300+
map.put("n", nrCols);
278301
process.writeObject(map);
279302
// get the result back
280303
Object ret = process.readObject();
@@ -339,6 +362,38 @@ public void initializeAlgorithm(Algorithm algorithm, String parms) {
339362
protected void loadMalletCorpusRepresentation(File directory) {
340363
corpusRepresentationMallet = CorpusRepresentationMalletTarget.load(directory);
341364
}
342-
365+
366+
protected AbstractMap.SimpleEntry<String,Integer> findOutMode(CorpusRepresentationMalletTarget crm) {
367+
InstanceList instances = crm.getRepresentationMallet();
368+
if(instances.size() == 0) {
369+
throw new GateRuntimeException("No instances in the training set, cannot train");
370+
}
371+
// we pass on a "mode" for the learning problem, which is one of the following:
372+
// - classind: predict the index of a class
373+
// - classcosts: targets are vectors of class costs
374+
// - regr: regression
375+
// we also pass on another parameter which provides details of the learning problem:
376+
// - the number of class indices in case of classind and classcosts
377+
// - 0 as a dummy value in case of "regr"
378+
379+
int nrClasses = 0;
380+
String mode = "regr";
381+
Alphabet ta = crm.getPipe().getTargetAlphabet();
382+
383+
if(ta != null) {
384+
Instance firstInstance = instances.get(0);
385+
Object targetObj = firstInstance.getTarget();
386+
if(targetObj instanceof NominalTargetWithCosts) {
387+
NominalTargetWithCosts target = (NominalTargetWithCosts)targetObj;
388+
nrClasses = target.getCosts().length;
389+
mode = "classcosts";
390+
} else {
391+
mode = "classind";
392+
nrClasses = ta.size();
393+
}
394+
}
395+
AbstractMap.SimpleEntry<String,Integer> ret = new AbstractMap.SimpleEntry<String, Integer>(mode,nrClasses);
396+
return ret;
397+
}
343398

344399
}

0 commit comments

Comments
 (0)