Skip to content

Commit 0521267

Browse files
committed
Make it work with keras.
1 parent 7a3d89e commit 0521267

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

lib-static/interaction-3.5.jar

66 Bytes
Binary file not shown.

src/gate/plugin/learningframework/engines/EnginePythonNetworksBase.java

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ protected File findWrapperCommand(File dataDirectory, boolean apply) {
147147
@Override
148148
protected void loadModel(File directory, String parms) {
149149
ArrayList<String> finalCommand = new ArrayList<String>();
150+
// we need the corpus representation here! Normally this is done from loadEngine and after
151+
// load model, but we do it here. The load crm method only loads anything if it is still
152+
// null, so we will do this only once anyway.
153+
loadMalletCorpusRepresentation(directory);
150154
CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget)corpusRepresentationMallet;
151155
SimpleEntry<String,Integer> modeAndNrC = findOutMode(data);
152156
String mode = modeAndNrC.getKey();
@@ -360,14 +364,12 @@ public void initializeAlgorithm(Algorithm algorithm, String parms) {
360364

361365
@Override
362366
protected void loadMalletCorpusRepresentation(File directory) {
363-
corpusRepresentationMallet = CorpusRepresentationMalletTarget.load(directory);
367+
if(corpusRepresentationMallet==null)
368+
corpusRepresentationMallet = CorpusRepresentationMalletTarget.load(directory);
364369
}
365370

366371
protected AbstractMap.SimpleEntry<String,Integer> findOutMode(CorpusRepresentationMalletTarget crm) {
367372
InstanceList instances = crm.getRepresentationMallet();
368-
if(instances.size() == 0) {
369-
throw new GateRuntimeException("No instances in the training set, cannot train");
370-
}
371373
// we pass on a "mode" for the learning problem, which is one of the following:
372374
// - classind: predict the index of a class
373375
// - classcosts: targets are vectors of class costs
@@ -381,15 +383,24 @@ protected AbstractMap.SimpleEntry<String,Integer> findOutMode(CorpusRepresentati
381383
Alphabet ta = crm.getPipe().getTargetAlphabet();
382384

383385
if(ta != null) {
384-
Instance firstInstance = instances.get(0);
385-
Object targetObj = firstInstance.getTarget();
386-
if(targetObj instanceof NominalTargetWithCosts) {
387-
NominalTargetWithCosts target = (NominalTargetWithCosts)targetObj;
388-
nrClasses = target.getCosts().length;
389-
mode = "classcosts";
386+
// if this is invoked for training, we should have a first instance, but for
387+
// application, we do not have any instances yet. If we do not have any instances, we
388+
// just use dummy values for now since at the moment we do not need this information
389+
// at application time. Should we ever need it we need to store this in the pipe!
390+
if(instances==null || instances.isEmpty()) {
391+
mode="classind";
392+
nrClasses=-1;
390393
} else {
391-
mode = "classind";
392-
nrClasses = ta.size();
394+
Instance firstInstance = instances.get(0);
395+
Object targetObj = firstInstance.getTarget();
396+
if(targetObj instanceof NominalTargetWithCosts) {
397+
NominalTargetWithCosts target = (NominalTargetWithCosts)targetObj;
398+
nrClasses = target.getCosts().length;
399+
mode = "classcosts";
400+
} else {
401+
mode = "classind";
402+
nrClasses = ta.size();
403+
}
393404
}
394405
}
395406
AbstractMap.SimpleEntry<String,Integer> ret = new AbstractMap.SimpleEntry<String, Integer>(mode,nrClasses);

0 commit comments

Comments
 (0)