Skip to content

Commit d236e3c

Browse files
committed
Add ML-Plan executor
1 parent 4284a91 commit d236e3c

File tree

3 files changed

+129
-7
lines changed

3 files changed

+129
-7
lines changed

swan_core/src/main/java/de/fraunhofer/iem/swan/model/Learner.java

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,13 @@ public HashMap<String, HashMap<String, String>> runAutomaticEvaluation(Instances
7878
MLPlanExecutor mlPlanExecutor = new MLPlanExecutor();
7979
LinkedHashMap<String, HashMap<String, String>> fMeasure = new LinkedHashMap<>();
8080

81-
try {
82-
HashMap<String, String> f = mlPlanExecutor.run(instances);
83-
fMeasure.put(mlPlanExecutor.getFilteredClassifier().getClass().getSimpleName(),f);
84-
} catch (IOException e) {
85-
e.printStackTrace();
86-
}
81+
HashMap<String, String> f = mlPlanExecutor.run(instances);
82+
fMeasure.put(mlPlanExecutor.getClassifier().getClass().getSimpleName(), f);
8783

8884
return fMeasure;
8985
}
9086

9187
/**
92-
*
9388
* @param instances
9489
* @return
9590
*/
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package de.fraunhofer.iem.swan.model;
2+
3+
import ai.libs.jaicore.logging.LoggerUtil;
4+
import ai.libs.jaicore.ml.weka.dataset.WekaInstances;
5+
import ai.libs.mlplan.multiclass.wekamlplan.MLPlanWekaBuilder;
6+
import de.fraunhofer.iem.swan.util.Util;
7+
import org.api4.java.ai.ml.core.dataset.serialization.DatasetDeserializationFailedException;
8+
import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
9+
import org.api4.java.algorithm.exceptions.AlgorithmException;
10+
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
11+
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
12+
import org.slf4j.Logger;
13+
import org.slf4j.LoggerFactory;
14+
import weka.classifiers.Classifier;
15+
import weka.classifiers.meta.FilteredClassifier;
16+
import weka.core.Instances;
17+
18+
import java.io.IOException;
19+
import java.util.HashMap;
20+
import java.util.NoSuchElementException;
21+
22+
/**
23+
* @author Oshando Johnson on 27.09.20
24+
*/
25+
public class MLPlanExecutor {
26+
27+
private static final Logger LOGGER = LoggerFactory.getLogger(MLPlanExecutor.class);
28+
FilteredClassifier filteredClassifier;
29+
30+
public MLPlanExecutor() {
31+
this.filteredClassifier = new FilteredClassifier();
32+
}
33+
34+
/**
35+
* Runs ML-Plan with the given instance set.
36+
* @param instances training instances
37+
* @return hashmap containing F1 measures
38+
*/
39+
public HashMap<String, String> run(Instances instances) {
40+
41+
long start = System.currentTimeMillis();
42+
43+
instances.setClassIndex(instances.numAttributes() - 1);
44+
45+
HashMap<String, String> results = new HashMap<>();
46+
47+
//Find classifier using ML-Plan
48+
try {
49+
50+
Classifier classifier = new MLPlanWekaBuilder()
51+
.withDataset(new WekaInstances(instances))
52+
.withPortionOfDataReservedForSelection(.8)
53+
//.withNumCpus(4)
54+
//.withTimeOut(new Timeout(300, TimeUnit.SECONDS))
55+
.withMCCVBasedCandidateEvaluationInSearchPhase(10, .8)
56+
.build()
57+
.call()
58+
.getClassifier();
59+
60+
long trainTime = (int) (System.currentTimeMillis() - start) / 1000;
61+
filteredClassifier.setClassifier(classifier);
62+
LOGGER.info("Finished build of the classifier. Training time was {}s.", trainTime);
63+
64+
LOGGER.info("Chosen model is: {}", (filteredClassifier.getClassifier().getClass().getSimpleName()));
65+
} catch (IOException | AlgorithmTimeoutedException | InterruptedException | AlgorithmException | AlgorithmExecutionCanceledException e) {
66+
e.printStackTrace();
67+
}
68+
69+
//Evaluate classifier using MCCV
70+
try {
71+
72+
/* evaluate solution produced by mlplan */
73+
ModelEvaluator modelEvaluator = new ModelEvaluator();
74+
results = modelEvaluator.monteCarloValidate(instances, filteredClassifier, 0.8, 10);
75+
76+
//Output evaluation results
77+
for(String category: results.keySet()){
78+
System.out.println("---" + category + "---");
79+
System.out.println(filteredClassifier.getClassifier().getClass().getSimpleName()+ ";"+
80+
results.get(category).replace(".", ",").substring(0, results.get(category).lastIndexOf(";")));
81+
}
82+
83+
long totalTime = (int) (System.currentTimeMillis() - start) / 1000;
84+
LOGGER.info("ML-Plan execution completed. Total time {}s.", totalTime);
85+
return modelEvaluator.getfMeasure();
86+
} catch (NoSuchElementException e) {
87+
LOGGER.error("Building the classifier failed: {}", LoggerUtil.getExceptionInfo(e));
88+
}
89+
return results;
90+
}
91+
92+
/**
93+
* Returns classifier selected by ML-Plan.
94+
* @return WEKA classifier
95+
*/
96+
public FilteredClassifier getClassifier() {
97+
return filteredClassifier;
98+
}
99+
100+
public static void main(String[] args) throws DatasetDeserializationFailedException, IOException, InterruptedException, SplitFailedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
101+
102+
Instances dataset = Util.loadArffFile("/Users/oshando/Projects/thesis/03-code/swan/swan_core/swan-out/weka/Train_sanitizer_none.arff");
103+
//Instances dataset = Util.loadArffFile("/Users/oshando/Projects/thesis/03-code/swan/swan_core/src/main/resources/waveform.arff");
104+
105+
MLPlanExecutor mlPlan = new MLPlanExecutor();
106+
mlPlan.run(dataset);
107+
}
108+
}

swan_core/src/main/java/de/fraunhofer/iem/swan/util/Util.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import java.io.File;
44
import java.io.FileInputStream;
5+
import java.io.FileReader;
56
import java.io.IOException;
67
import java.util.HashMap;
78
import java.util.HashSet;
@@ -18,6 +19,7 @@
1819
import de.fraunhofer.iem.swan.data.Method;
1920
import de.fraunhofer.iem.swan.features.type.AbstractSootFeature;
2021
import soot.SootMethod;
22+
import weka.core.Instances;
2123

2224
public class Util {
2325
private static final Logger logger = LoggerFactory.getLogger(Util.class);
@@ -226,4 +228,21 @@ public static Set<String> getFiles(String fileInDirectory) throws IOException {
226228
}
227229
return files;
228230
}
231+
232+
/**
233+
* Loads the ARFF file to an instances object.
234+
* @param filePath path to ARFF file
235+
* @return data as Instances object
236+
*/
237+
public static Instances loadArffFile(String filePath) {
238+
Instances dataset = null;
239+
240+
try {
241+
dataset = new Instances(new FileReader(filePath));
242+
} catch (IOException e) {
243+
e.printStackTrace();
244+
}
245+
246+
return dataset;
247+
}
229248
}

0 commit comments

Comments
 (0)