Skip to content

Commit 94c7c1a

Browse files
committed
Refactor model evaluation implementation
1 parent 857758a commit 94c7c1a

File tree

4 files changed

+140
-119
lines changed

4 files changed

+140
-119
lines changed

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/SwanPipeline.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public void run() throws IOException, InterruptedException {
5252
featuresHandler.createFeatures();
5353

5454
//Train and evaluate model for SRM and CWE categories
55-
ModelEvaluator modelEvaluator = new ModelEvaluator(featuresHandler, options);
55+
ModelEvaluator modelEvaluator = new ModelEvaluator(featuresHandler, options, testDataset.getMethods());
5656
modelEvaluator.trainModel();
5757

5858
//TODO export final list to JSON file
Lines changed: 22 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,14 @@
11
package de.fraunhofer.iem.swan.model;
22

33
import de.fraunhofer.iem.swan.cli.SwanOptions;
4+
import de.fraunhofer.iem.swan.data.Method;
45
import de.fraunhofer.iem.swan.features.FeaturesHandler;
6+
import de.fraunhofer.iem.swan.model.engine.MLPlan;
7+
import de.fraunhofer.iem.swan.model.engine.Meka;
8+
import de.fraunhofer.iem.swan.model.engine.Weka;
59
import org.slf4j.Logger;
610
import org.slf4j.LoggerFactory;
7-
import weka.classifiers.Classifier;
8-
import weka.classifiers.bayes.BayesNet;
9-
import weka.classifiers.bayes.NaiveBayes;
10-
import weka.classifiers.evaluation.output.prediction.AbstractOutput;
11-
import weka.classifiers.functions.Logistic;
12-
import weka.classifiers.functions.SMO;
13-
import weka.classifiers.rules.JRip;
14-
import weka.classifiers.trees.DecisionStump;
15-
import weka.classifiers.trees.J48;
16-
import weka.core.Instances;
17-
import weka.filters.Filter;
18-
import weka.filters.MultiFilter;
19-
20-
import java.util.ArrayList;
21-
import java.util.HashMap;
22-
import java.util.LinkedHashMap;
23-
import java.util.List;
11+
import java.util.Set;
2412

2513
/**
2614
* Finds possible sources and sinks in a given set of system methods using a
@@ -43,11 +31,13 @@ public enum Phase {
4331

4432
private FeaturesHandler features;
4533
private SwanOptions options;
34+
private Set<Method> methods;
4635
private static final Logger logger = LoggerFactory.getLogger(ModelEvaluator.class);
4736

4837
public ModelEvaluator(FeaturesHandler features, SwanOptions options, Set<Method> methods) {
4938
this.features = features;
5039
this.options = options;
40+
this.methods = methods;
5141
}
5242

5343
/**
@@ -59,86 +49,21 @@ public void trainModel() {
5949

6050
switch (Mode.valueOf(options.getLearningMode().toUpperCase())) {
6151

62-
case MANUAL:
63-
64-
//Phase 1: classify SRM classes
65-
for (String srm : options.getSrmClasses())
66-
runManualEvaluation(features.getInstances().get(srm));
67-
68-
//Filter methods from CWE instances that were not classified
69-
//into one of the SRM classes
70-
71-
72-
//Phase 2: classify CWE classes
73-
for (String cwe : options.getCweClasses())
74-
runManualEvaluation(features.getInstances().get(cwe));
75-
76-
case AUTOMATIC:
77-
//return runAutomaticEvaluation(instances);
78-
}
79-
return null;
80-
}
81-
82-
/**
83-
* Run AutoML training and evaluation on instances.
84-
*
85-
* @param instances list of instances
86-
* @return
87-
*/
88-
public HashMap<String, HashMap<String, String>> runAutomaticEvaluation(Instances instances) {
89-
90-
LinkedHashMap<String, HashMap<String, String>> fMeasure = new LinkedHashMap<>();
91-
92-
MLPlanExecutor mlPlanExecutor = new MLPlanExecutor();
93-
// fMeasure.put("ML-Plan", mlPlanExecutor.evaluateDataset(instances));
94-
95-
//outputFMeasure(fMeasure);
96-
return fMeasure;
97-
}
98-
99-
/**
100-
* @return
101-
*/
102-
public HashMap<String, HashMap<String, String>> runManualEvaluation(Instances instances) {
103-
104-
LinkedHashMap<String, HashMap<String, String>> fMeasure = new LinkedHashMap<>();
105-
106-
List<Classifier> classifiers = new ArrayList<>();
107-
classifiers.add(new BayesNet());
108-
classifiers.add(new NaiveBayes());
109-
classifiers.add(new J48());
110-
classifiers.add(new SMO());
111-
classifiers.add(new JRip());
112-
classifiers.add(new DecisionStump());
113-
classifiers.add(new Logistic());
114-
115-
//For each classifier, evaluate its performance on the instances
116-
for (Classifier classifier : classifiers) {
117-
118-
MonteCarloValidator evaluator = new MonteCarloValidator();
119-
evaluator.monteCarloValidate(instances, classifier, options.getTrainTestSplit(), options.getIterations());
120-
121-
for (String key : evaluator.getFMeasure().keySet())
122-
logger.info("F-measure for {} using {}: {}", key, classifier.getClass().getSimpleName(), evaluator.getFMeasure().get(key));
123-
}
124-
return fMeasure;
125-
}
126-
127-
/**
128-
* Applies the Weka filters to the instances.
129-
*
130-
* @param instances instane set
131-
* @param filters array of filters
132-
* @return instances with filter applied
133-
*/
134-
public Instances applyFilter(Instances instances, MultiFilter filters) {
135-
136-
try {
137-
filters.setInputFormat(instances);
138-
return Filter.useFilter(instances, filters);
139-
} catch (Exception e) {
140-
e.printStackTrace();
52+
case MEKA:
53+
logger.info("Evaluating model with MEKA");
54+
Meka meka = new Meka(features, options, methods);
55+
meka.trainModel();
56+
break;
57+
case WEKA:
58+
logger.info("Evaluating model with WEKA");
59+
Weka weka = new Weka(features, options);
60+
weka.trainModel();
61+
break;
62+
case MLPLAN:
63+
logger.info("Evaluating model with ML-PLAN");
64+
MLPlan mlPlan = new MLPlan();
65+
mlPlan.evaluateDataset(features.getInstances().get("train"));
66+
break;
14167
}
142-
return null;
14368
}
14469
}

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/MLPlanExecutor.java renamed to swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/engine/MLPlan.java

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
package de.fraunhofer.iem.swan.model;
1+
package de.fraunhofer.iem.swan.model.engine;
22

33
import ai.libs.jaicore.ml.classification.loss.dataset.EClassificationPerformanceMeasure;
44
import ai.libs.jaicore.ml.core.dataset.schema.attribute.IntBasedCategoricalAttribute;
55
import ai.libs.jaicore.ml.core.dataset.serialization.ArffDatasetAdapter;
66
import ai.libs.jaicore.ml.core.filter.SplitterUtil;
77
import ai.libs.jaicore.ml.weka.classification.learner.IWekaClassifier;
8-
import ai.libs.mlplan.core.MLPlan;
98
import ai.libs.mlplan.multiclass.wekamlplan.MLPlanWekaBuilder;
9+
import de.fraunhofer.iem.swan.model.MonteCarloValidator;
1010
import de.fraunhofer.iem.swan.util.Util;
1111
import org.api4.java.ai.ml.core.dataset.schema.attribute.IAttribute;
1212
import org.api4.java.ai.ml.core.dataset.serialization.DatasetDeserializationFailedException;
@@ -33,12 +33,12 @@
3333
/**
3434
* @author Oshando Johnson on 27.09.20
3535
*/
36-
public class MLPlanExecutor {
36+
public class MLPlan {
3737

38-
private static final Logger LOGGER = LoggerFactory.getLogger(MLPlanExecutor.class);
38+
private static final Logger LOGGER = LoggerFactory.getLogger(MLPlan.class);
3939
private final int ITERATIONS = 1;
4040

41-
public MLPlanExecutor() {
41+
public MLPlan() {
4242

4343
}
4444

@@ -94,15 +94,13 @@ public HashMap<String, ArrayList<Double>> evaluateDataset(Instances instances1)
9494

9595
//optimizedClassifier.fit(split.get(0));
9696

97-
9897
String trainPath = "swan/swan_core/swan-out/mlplan/train-methods-dataset.arff";
9998
ArffDatasetAdapter.serializeDataset(new File(trainPath), split.get(0));
10099
ArffLoader trainLoader = new ArffLoader();
101100
trainLoader.setFile(new File(trainPath));
102101
Instances trainInstances = trainLoader.getDataSet();
103102
trainInstances.setClassIndex(trainInstances.numAttributes() - 1);
104103

105-
106104
String testPath = "swan/swan_core/swan-out/mlplan/test-methods-dataset.arff";
107105
ArffDatasetAdapter.serializeDataset(new File(testPath), split.get(1));
108106
ArffLoader testLoader = new ArffLoader();
@@ -153,7 +151,6 @@ public void evaluateDataset(Instances instances, int k) {
153151

154152
String mClass = Util.getClassName(instances);
155153

156-
157154
long start = System.currentTimeMillis();
158155

159156
//Initialize dataset using ARFF file path
@@ -171,12 +168,10 @@ public void evaluateDataset(Instances instances, int k) {
171168
//For each iteration, create a new train-test-split and run ML-Plan
172169
for (int iteration = 0; iteration < ITERATIONS; iteration++) {
173170

174-
175171
try {
176172
List<ILabeledDataset<?>> split = SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, new Random(1337 + (iteration * 11)), 0.7);
177173
LOGGER.info("Data read. Time to create dataset object was {}ms", System.currentTimeMillis() - start);
178174

179-
180175
System.out.println(split.get(1).getLabelAttribute().getName());
181176
for (IAttribute attribute : split.get(1).getListOfAttributes()) {
182177

@@ -194,22 +189,14 @@ public void evaluateDataset(Instances instances, int k) {
194189

195190
//System.out.println(dataset.getLabelVector().);
196191
System.out.println(((IntBasedCategoricalAttribute) split.get(1).getAttribute(attributeIndex)).getLabelOfCategory((int) split.get(1).get(x).getAttributeValue(attributeIndex)));
197-
198-
199192
System.out.println(((IntBasedCategoricalAttribute) split.get(1).getLabelAttribute()).getLabelOfCategory((int) split.get(1).get(x).getLabel()));
200-
201-
202193
// System.out.println(split.get(1).getAttribute());
203-
204194
System.out.println(split.get(1).get(x).getAttributeValue(split.get(1).getNumAttributes() - 2) + " " + split.get(1).get(x).getAttributeValue(split.get(1).getNumAttributes() - 1));
205195
}
206-
207-
208196
} catch (SplitFailedException | InterruptedException | IOException e) {
209197
e.printStackTrace();
210198
}
211199
}
212-
213200
}
214201

215202
/**
@@ -224,7 +211,7 @@ public Classifier getClassifier(ILabeledDataset<?> trainingSet) {
224211
/* initialize mlplan with a tiny search space, and let it run for 30 seconds */
225212

226213
try {
227-
MLPlan<IWekaClassifier> mlPlan = new MLPlanWekaBuilder()
214+
ai.libs.mlplan.core.MLPlan<IWekaClassifier> mlPlan = new MLPlanWekaBuilder()
228215
.withNumCpus(12)//Set to about 12 on the server
229216
.withSeed(35467463)
230217
//set default timeout
@@ -252,12 +239,11 @@ public Classifier getClassifier(ILabeledDataset<?> trainingSet) {
252239
return optimizedClassifier;
253240
}
254241

255-
256242
public static void maihn(String[] args) {
257243

258244
String file = "swan/swan_core/src/main/resources/waveform.arff";
259245

260-
MLPlanExecutor mlPlan = new MLPlanExecutor();
246+
MLPlan mlPlan = new MLPlan();
261247
// mlPlan.evaluateDataset(file, "sdfs");
262248
}
263249
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
package de.fraunhofer.iem.swan.model.engine;

import de.fraunhofer.iem.swan.cli.SwanOptions;
import de.fraunhofer.iem.swan.features.FeaturesHandler;
import de.fraunhofer.iem.swan.model.MonteCarloValidator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.functions.Logistic;
import weka.classifiers.functions.SMO;
import weka.classifiers.rules.JRip;
import weka.classifiers.trees.DecisionStump;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.MultiFilter;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;

/**
 * Finds possible sources and sinks in a given set of system methods using a
 * probabilistic algorithm trained on a previously annotated sample set.
 *
 * <p>Runs a fixed roster of standard Weka classifiers over the feature
 * instances via Monte-Carlo cross-validation and logs per-class F-measures.
 *
 * @author Steven Arzt, Lisa Nguyen Quang Do, Goran Piskachev
 */
public class Weka {

    private final FeaturesHandler features;
    private final SwanOptions options;
    // Fixed: logger was created for ModelEvaluator.class (copy-paste from the
    // class this code was extracted from); log records now name Weka.
    private static final Logger logger = LoggerFactory.getLogger(Weka.class);

    public Weka(FeaturesHandler features, SwanOptions options) {
        this.features = features;
        this.options = options;
    }

    /**
     * Trains and evaluates the model with the given training data and specified classification mode.
     *
     * <p>Phase 1 evaluates every SRM class, phase 2 every CWE class, each via
     * {@link #runManualEvaluation(Instances)}.
     *
     * @return map from classifier name to its F-measure results; currently empty
     *         because {@link #runManualEvaluation(Instances)} does not yet record
     *         results (see TODO there), but never {@code null}
     */
    public HashMap<String, HashMap<String, String>> trainModel() {

        // Collect results instead of returning null so callers never NPE.
        LinkedHashMap<String, HashMap<String, String>> results = new LinkedHashMap<>();

        //Phase 1: classify SRM classes
        for (String srm : options.getSrmClasses())
            results.putAll(runManualEvaluation(features.getInstances().get(srm)));

        //TODO: Filter methods from CWE instances that were not classified
        //into one of the SRM classes

        //Phase 2: classify CWE classes
        for (String cwe : options.getCweClasses())
            results.putAll(runManualEvaluation(features.getInstances().get(cwe)));

        return results;
    }

    /**
     * Evaluates a fixed set of Weka classifiers on the given instances using
     * Monte-Carlo validation and logs each classifier's F-measure per class.
     *
     * @param instances the labelled instances to evaluate on
     * @return map of evaluation results; TODO(review): currently returned empty —
     *         F-measures are only logged, never stored in the map
     */
    public HashMap<String, HashMap<String, String>> runManualEvaluation(Instances instances) {

        LinkedHashMap<String, HashMap<String, String>> fMeasure = new LinkedHashMap<>();

        List<Classifier> classifiers = new ArrayList<>();
        classifiers.add(new BayesNet());
        classifiers.add(new NaiveBayes());
        classifiers.add(new J48());
        classifiers.add(new SMO());
        classifiers.add(new JRip());
        classifiers.add(new DecisionStump());
        classifiers.add(new Logistic());

        //For each classifier, evaluate its performance on the instances
        for (Classifier classifier : classifiers) {

            MonteCarloValidator evaluator = new MonteCarloValidator();
            evaluator.monteCarloValidate(instances, classifier, options.getTrainTestSplit(), options.getIterations());

            for (String key : evaluator.getFMeasure().keySet())
                logger.info("F-measure for {} using {}: {}", key, classifier.getClass().getSimpleName(), evaluator.getFMeasure().get(key));
        }
        return fMeasure;
    }

    /**
     * Applies the Weka filters to the instances.
     *
     * @param instances instance set
     * @param filters   array of filters
     * @return instances with filter applied, or {@code null} if filtering fails
     */
    public Instances applyFilter(Instances instances, MultiFilter filters) {

        try {
            filters.setInputFormat(instances);
            return Filter.useFilter(instances, filters);
        } catch (Exception e) {
            // Log through SLF4J (with stack trace) instead of printStackTrace().
            logger.error("Failed to apply filters to instances", e);
        }
        return null;
    }
}

0 commit comments

Comments
 (0)