Skip to content

Commit 7ebb544

Browse files
committed
Add model selection using WEKA
1 parent 9182a91 commit 7ebb544

File tree

1 file changed

+106
-14
lines changed
  • swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/toolkit

1 file changed

+106
-14
lines changed

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/toolkit/Weka.java

Lines changed: 106 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,32 @@
11
package de.fraunhofer.iem.swan.model.toolkit;
22

33
import de.fraunhofer.iem.swan.cli.SwanOptions;
4+
import de.fraunhofer.iem.swan.data.Category;
5+
import de.fraunhofer.iem.swan.data.Method;
46
import de.fraunhofer.iem.swan.features.WekaFeatureSet;
7+
import de.fraunhofer.iem.swan.io.dataset.SrmList;
58
import de.fraunhofer.iem.swan.model.ModelEvaluator;
69
import de.fraunhofer.iem.swan.model.MonteCarloValidator;
10+
import javafx.util.Pair;
711
import org.slf4j.Logger;
812
import org.slf4j.LoggerFactory;
13+
import weka.classifiers.AbstractClassifier;
914
import weka.classifiers.Classifier;
15+
import weka.classifiers.Evaluation;
1016
import weka.classifiers.bayes.BayesNet;
1117
import weka.classifiers.bayes.NaiveBayes;
1218
import weka.classifiers.functions.Logistic;
1319
import weka.classifiers.functions.SMO;
1420
import weka.classifiers.rules.JRip;
1521
import weka.classifiers.trees.DecisionStump;
1622
import weka.classifiers.trees.J48;
23+
import weka.core.Instance;
1724
import weka.core.Instances;
1825
import weka.filters.Filter;
1926
import weka.filters.MultiFilter;
2027

21-
import java.util.ArrayList;
22-
import java.util.HashMap;
23-
import java.util.LinkedHashMap;
24-
import java.util.List;
28+
import java.text.DecimalFormat;
29+
import java.util.*;
2530

2631
/**
2732
* Finds possible sources and sinks in a given set of system methods using a
@@ -33,45 +38,104 @@ public class Weka {
3338

3439
private WekaFeatureSet features;
3540
private SwanOptions options;
36-
private static final Logger logger = LoggerFactory.getLogger(ModelEvaluator.class);
41+
private Set<Method> methods;
42+
private static final Logger logger = LoggerFactory.getLogger(Weka.class);
43+
private HashMap<String, ArrayList<Category>> predictions;
44+
private HashMap<String, HashMap<String, ArrayList<Double>>> results;
45+
private DecimalFormat df = new DecimalFormat("####0.00");
3746

38-
/**
 * Creates the WEKA toolkit wrapper and prepares the result containers.
 *
 * @param features feature set that provides the train/test instances and the dataset
 * @param options  SWAN options controlling the phase, classes and validation settings
 * @param methods  methods that receive the predicted categories in the predict phase
 */
public Weka(WekaFeatureSet features, SwanOptions options, Set<Method> methods) {
    this.features = features;
    this.options = options;
    this.methods = methods;
    this.predictions = new HashMap<>();
    this.results = new HashMap<>();

    if (!options.isPredictPhase()) {
        return;
    }

    // Every test method starts with an empty category list, keyed by its ARFF-safe signature.
    for (Method testMethod : features.getDataset().getTestMethods()) {
        String signature = testMethod.getArffSafeSignature();
        predictions.put(signature, new ArrayList<>());
    }
}
4259

4360
/**
4461
* Trains and evaluates the model with the given training data and specified classification mode.
4562
*
4663
* @return Hashmap containing the name of the classifier and it's F-Measure
4764
*/
48-
public HashMap<String, HashMap<String, String>> trainModel() {
65+
public SrmList trainModel() {
4966

5067
switch (ModelEvaluator.Phase.valueOf(options.getPhase().toUpperCase())) {
5168
case VALIDATE:
5269

5370
//Phase 1: classify SRM classes
5471
logger.info("Performing {}-fold cross-validation for {} using WEKA", options.getIterations(), options.getSrmClasses());
5572
for (String srm : options.getSrmClasses())
56-
runManualEvaluation(features.getTrainInstances().get(srm));
73+
crossValidateModel(features.getTrainInstances().get(srm));
5774

5875
//Filter methods from CWE instances that were not classified into one of the SRM classes
5976
//Phase 2: classify CWE classes
6077
logger.info("Performing {}-fold cross-validation for {} using WEKA", options.getIterations(), options.getCweClasses());
6178
for (String cwe : options.getCweClasses())
62-
runManualEvaluation(features.getTrainInstances().get(cwe));
79+
crossValidateModel(features.getTrainInstances().get(cwe));
80+
81+
// TreeMap to store values of HashMap
82+
TreeMap<String, HashMap<String, ArrayList<Double>>> sorted
83+
= new TreeMap<>(results);
6384

64-
return null;
85+
// Display the TreeMap which is naturally sorted
86+
for (Map.Entry<String, HashMap<String, ArrayList<Double>>> entry :
87+
sorted.entrySet())
88+
89+
return null;
6590
case PREDICT:
6691

92+
logger.info("Predicting {} for TEST dataset using WEKA", options.getSrmClasses());
93+
for (String srm : options.getSrmClasses()) {
94+
predictModel(srm);
95+
}
96+
97+
for (Method method : methods) {
98+
for (Category category : predictions.get(method.getArffSafeSignature())) {
99+
method.addCategory(category);
100+
}
101+
}
102+
return new SrmList(methods);
67103
}
68104
return null;
69105
}
70106

107+
public void predictModel(String srm) {
108+
Pair<String, Double> bestClassifier = crossValidateModel(features.getTrainInstances().get(srm));
109+
110+
try {
111+
112+
Classifier classifier = AbstractClassifier.forName(bestClassifier.getKey(), null);
113+
classifier.buildClassifier(features.getTrainInstances().get(srm));
114+
115+
//NaiveBayes classifier = new NaiveBayes();
116+
//classifier.buildClassifier(features.getTrainInstances().get(srm));
117+
118+
Evaluation eval = new Evaluation(features.getTestInstances().get(srm));
119+
eval.evaluateModel(classifier, features.getTestInstances().get(srm));
120+
121+
for (Instance instance : features.getTestInstances().get(srm)) {
122+
123+
double prediction = eval.evaluateModelOnce(classifier, instance);
124+
if (prediction > 0) {
125+
predictions.get(instance.stringValue(features.getTestInstances().get(srm).attribute("id").index()))
126+
.add(Category.valueOf(srm.toUpperCase()));
127+
}
128+
129+
}
130+
} catch (Exception e) {
131+
e.printStackTrace();
132+
}
133+
}
134+
71135
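/**
 * Evaluates the candidate classifiers on the given instances with Monte Carlo validation
 * and selects the one with the highest average F-measure.
 *
 * @param instances instances to validate on; the last attribute is used as the class attribute
 * @return pair of the selected classifier's simple class name and its average F-measure
 * rounded to two decimals; the pair ("", 0.0) if no classifier scored above zero
 */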
74-
public HashMap<String, HashMap<String, String>> runManualEvaluation(Instances instances) {
138+
public Pair<String, Double> crossValidateModel(Instances instances) {
75139

76140
String category = instances.attribute(instances.numAttributes() - 1).name();
77141
instances.setClass(instances.attribute(instances.numAttributes() - 1));
@@ -87,19 +151,47 @@ public HashMap<String, HashMap<String, String>> runManualEvaluation(Instances in
87151
classifiers.add(new DecisionStump());
88152
classifiers.add(new Logistic());
89153

154+
Pair<String, Double> bestClassifier = new Pair<>("", 0.0);
155+
List<Pair> classifierSummary = new ArrayList<>();
90156
//For each classifier, evaluate its performance on the instances
157+
158+
HashMap<String, ArrayList<Double>> measure = new HashMap<>();
91159
for (Classifier classifier : classifiers) {
92160

93161
MonteCarloValidator evaluator = new MonteCarloValidator();
94162
evaluator.monteCarloValidate(instances, classifier, options.getTrainTestSplit(), options.getIterations());
95163

96164
for (String key : evaluator.getFMeasure().keySet()) {
97165

98-
logger.info("Average F-measure for {}({}) using {}: {}, {}", category, key, classifier.getClass().getSimpleName(),
99-
evaluator.getFMeasure().get(key).stream().mapToDouble(a -> a).average().getAsDouble(), evaluator.getFMeasure().get(key));
166+
double averageFMeasure = evaluator.getFMeasure().get(key).stream().mapToDouble(a -> a).average().getAsDouble();
167+
Pair summary = new Pair<>(classifier.getClass().getSimpleName(), Double.parseDouble(df.format(averageFMeasure)));
168+
classifierSummary.add(summary);
169+
170+
if (averageFMeasure > bestClassifier.getValue()) {
171+
bestClassifier = summary;
172+
}
173+
174+
if (category.contains("authentication")) {
175+
176+
if (!results.containsKey(category + "-" + key)) {
177+
HashMap<String, ArrayList<Double>> m = new HashMap<>();
178+
m.put(classifier.getClass().getSimpleName(), evaluator.getFMeasure().get(key));
179+
results.put(category + "-" + key, m);
180+
} else {
181+
results.get(category + "-" + key).put(classifier.getClass().getSimpleName(), evaluator.getFMeasure().get(key));
182+
}
183+
} else
184+
measure.put(classifier.getClass().getSimpleName(), evaluator.getFMeasure().get(key));
185+
186+
logger.debug("Average F-measure for {}({}) using {}: {}, {}", category, key, classifier.getClass().getSimpleName(),
187+
averageFMeasure, evaluator.getFMeasure().get(key));
100188
}
189+
if (!category.contains("authentication"))
190+
results.put(category, measure);
101191
}
102-
return fMeasure;
192+
logger.info("Selecting {} model for {}, evaluated classifiers={}", bestClassifier.getKey(), category, classifierSummary);
193+
194+
return bestClassifier;
103195
}
104196

105197
/**

0 commit comments

Comments
 (0)