Skip to content

Commit b7b4076

Browse files
committed
Refactor ML-Plan integration
1 parent d236e3c commit b7b4076

File tree

1 file changed

+80
-59
lines changed

1 file changed

+80
-59
lines changed
Lines changed: 80 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -1,108 +1,129 @@
11
package de.fraunhofer.iem.swan.model;
22

3-
import ai.libs.jaicore.logging.LoggerUtil;
4-
import ai.libs.jaicore.ml.weka.dataset.WekaInstances;
3+
import ai.libs.jaicore.ml.classification.loss.dataset.EClassificationPerformanceMeasure;
4+
import ai.libs.jaicore.ml.core.dataset.serialization.ArffDatasetAdapter;
5+
import ai.libs.jaicore.ml.core.evaluation.evaluator.SupervisedLearnerExecutor;
6+
import ai.libs.jaicore.ml.core.filter.SplitterUtil;
7+
import ai.libs.jaicore.ml.weka.classification.learner.IWekaClassifier;
8+
import ai.libs.mlplan.core.MLPlan;
59
import ai.libs.mlplan.multiclass.wekamlplan.MLPlanWekaBuilder;
6-
import de.fraunhofer.iem.swan.util.Util;
10+
import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification;
711
import org.api4.java.ai.ml.core.dataset.serialization.DatasetDeserializationFailedException;
812
import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
13+
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
14+
import org.api4.java.ai.ml.core.evaluation.execution.ILearnerRunReport;
15+
import org.api4.java.ai.ml.core.evaluation.execution.LearnerExecutionFailedException;
16+
import org.api4.java.algorithm.Timeout;
917
import org.api4.java.algorithm.exceptions.AlgorithmException;
1018
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
1119
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
1220
import org.slf4j.Logger;
1321
import org.slf4j.LoggerFactory;
14-
import weka.classifiers.Classifier;
15-
import weka.classifiers.meta.FilteredClassifier;
1622
import weka.core.Instances;
1723

24+
import java.io.File;
1825
import java.io.IOException;
19-
import java.util.HashMap;
20-
import java.util.NoSuchElementException;
26+
import java.util.List;
27+
import java.util.Random;
28+
import java.util.concurrent.TimeUnit;
2129

2230
/**
2331
* @author Oshando Johnson on 27.09.20
2432
*/
2533
public class MLPlanExecutor {
2634

2735
private static final Logger LOGGER = LoggerFactory.getLogger(MLPlanExecutor.class);
28-
FilteredClassifier filteredClassifier;
36+
private final int ITERATIONS = 10;
2937

3038
public MLPlanExecutor() {
31-
this.filteredClassifier = new FilteredClassifier();
39+
3240
}
3341

3442
/**
35-
* Runs ML-Plan with the given instance set.
36-
* @param instances training instances
37-
* @return hashmap containing F1 measures
43+
* Run ML-Plan using the provided path to the ARFF file.
44+
*
45+
* @param arffFilePath file path for ARFF file
3846
*/
39-
public HashMap<String, String> run(Instances instances) {
47+
public void evaluateDataset(String arffFilePath) {
4048

4149
long start = System.currentTimeMillis();
4250

43-
instances.setClassIndex(instances.numAttributes() - 1);
44-
45-
HashMap<String, String> results = new HashMap<>();
46-
47-
//Find classifier using ML-Plan
51+
//Initialize dataset using ARFF file path
52+
ILabeledDataset<?> dataset = null;
4853
try {
54+
dataset = ArffDatasetAdapter.readDataset(new File(arffFilePath));
55+
} catch (DatasetDeserializationFailedException e) {
56+
e.printStackTrace();
57+
}
4958

50-
Classifier classifier = new MLPlanWekaBuilder()
51-
.withDataset(new WekaInstances(instances))
52-
.withPortionOfDataReservedForSelection(.8)
53-
//.withNumCpus(4)
54-
//.withTimeOut(new Timeout(300, TimeUnit.SECONDS))
55-
.withMCCVBasedCandidateEvaluationInSearchPhase(10, .8)
56-
.build()
57-
.call()
58-
.getClassifier();
59+
//For each iteration, create a new train-test-split and run ML-Plan
60+
for (int iteration = 0; iteration < ITERATIONS; iteration++) {
5961

60-
long trainTime = (int) (System.currentTimeMillis() - start) / 1000;
61-
filteredClassifier.setClassifier(classifier);
62-
LOGGER.info("Finished build of the classifier. Training time was {}s.", trainTime);
62+
try {
63+
List<ILabeledDataset<?>> split = SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, new Random(1337 + (iteration * 11)), 0.7);
64+
LOGGER.info("Data read. Time to create dataset object was {}ms", System.currentTimeMillis() - start);
6365

64-
LOGGER.info("Chosen model is: {}", (filteredClassifier.getClassifier().getClass().getSimpleName()));
65-
} catch (IOException | AlgorithmTimeoutedException | InterruptedException | AlgorithmException | AlgorithmExecutionCanceledException e) {
66-
e.printStackTrace();
67-
}
66+
IWekaClassifier optimizedClassifier = getClassifier(split.get(0));
6867

69-
//Evaluate classifier using MCCV
70-
try {
68+
/* evaluate solution produced by mlplan */
69+
SupervisedLearnerExecutor executor = new SupervisedLearnerExecutor();
70+
ILearnerRunReport report = executor.execute(optimizedClassifier, split.get(0), split.get(1));
7171

72-
/* evaluate solution produced by mlplan */
73-
ModelEvaluator modelEvaluator = new ModelEvaluator();
74-
results = modelEvaluator.monteCarloValidate(instances, filteredClassifier, 0.8, 10);
72+
System.out.println(report.getPredictionDiffList().getPredictionsAsList().toString());
7573

76-
//Output evaluation results
77-
for(String category: results.keySet()){
78-
System.out.println("---" + category + "---");
79-
System.out.println(filteredClassifier.getClassifier().getClass().getSimpleName()+ ";"+
80-
results.get(category).replace(".", ",").substring(0, results.get(category).lastIndexOf(";")));
81-
}
74+
LOGGER.info("Error Rate of the solution produced by ML-Plan: {}. ",
75+
EClassificationPerformanceMeasure.F1_WITH_1_POSITIVE.loss(report.getPredictionDiffList().getCastedView(Integer.class, ISingleLabelClassification.class)));
8276

83-
long totalTime = (int) (System.currentTimeMillis() - start) / 1000;
84-
LOGGER.info("ML-Plan execution completed. Total time {}s.", totalTime);
85-
return modelEvaluator.getfMeasure();
86-
} catch (NoSuchElementException e) {
87-
LOGGER.error("Building the classifier failed: {}", LoggerUtil.getExceptionInfo(e));
77+
} catch (SplitFailedException | InterruptedException | LearnerExecutionFailedException e) {
78+
e.printStackTrace();
79+
}
8880
}
89-
return results;
9081
}
9182

9283
/**
93-
* Returns classifier selected by ML-Plan.
94-
* @return WEKA classifier
84+
* Returns the classifier that ML-Plan selects and trains on the given training set.
85+
* @param trainingSet training set
86+
* @return trained classifier
9587
*/
96-
public FilteredClassifier getClassifier() {
97-
return filteredClassifier;
88+
public IWekaClassifier getClassifier(ILabeledDataset<?> trainingSet) {
89+
90+
IWekaClassifier optimizedClassifier = null;
91+
/* configure ML-Plan with a 30-second timeout, F1 as the search-phase measure, and no data reserved for the selection phase */
92+
93+
try {
94+
MLPlan<IWekaClassifier> mlPlan = new MLPlanWekaBuilder()
95+
.withNumCpus(8)//Set to about 12 on the server
96+
.withSeed(35467463)
97+
//set default timeout
98+
.withTimeOut(new Timeout(30, TimeUnit.SECONDS))
99+
.withDataset(trainingSet)
100+
.withPortionOfDataReservedForSelection(0.0)//ignore selection phase
101+
.withPerformanceMeasureForSearchPhase(EClassificationPerformanceMeasure.F1_WITH_1_POSITIVE)//use F1
102+
.withMCCVBasedCandidateEvaluationInSearchPhase(10, .7)
103+
.build();
104+
105+
mlPlan.setLoggerName("testedalgorithm");
106+
107+
long start = System.currentTimeMillis();
108+
109+
optimizedClassifier = mlPlan.call();
110+
111+
long trainTime = (int) (System.currentTimeMillis() - start) / 1000;
112+
LOGGER.info("Finished build of the classifier. Training time was {}s.", trainTime);
113+
LOGGER.info("Chosen model is: {}", (mlPlan.getSelectedClassifier()));
114+
LOGGER.info("Internally believed error was {}", mlPlan.getInternalValidationErrorOfSelectedClassifier());
115+
116+
} catch (IOException | AlgorithmTimeoutedException | InterruptedException | AlgorithmException | AlgorithmExecutionCanceledException e) {
117+
e.printStackTrace();
118+
}
119+
return optimizedClassifier;
98120
}
99121

100-
public static void main(String[] args) throws DatasetDeserializationFailedException, IOException, InterruptedException, SplitFailedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
122+
public static void main(String[] args) {
101123

102-
Instances dataset = Util.loadArffFile("/Users/oshando/Projects/thesis/03-code/swan/swan_core/swan-out/weka/Train_sanitizer_none.arff");
103-
//Instances dataset = Util.loadArffFile("/Users/oshando/Projects/thesis/03-code/swan/swan_core/src/main/resources/waveform.arff");
124+
String file = "/Users/oshando/Projects/thesis/03-code/swan/swan_core/src/main/resources/waveform.arff";
104125

105126
MLPlanExecutor mlPlan = new MLPlanExecutor();
106-
mlPlan.run(dataset);
127+
mlPlan.evaluateDataset(file);
107128
}
108129
}

0 commit comments

Comments
 (0)