
Commit 99d23eb (parent: ab922ad)

Move code for Monte Carlo validation to new class

2 files changed, 154 insertions(+), 22 deletions(-)

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/MLPlanExecutor.java

Lines changed: 16 additions & 22 deletions
@@ -1,39 +1,33 @@
 package de.fraunhofer.iem.swan.model;
 
 import ai.libs.jaicore.ml.classification.loss.dataset.EClassificationPerformanceMeasure;
-import ai.libs.jaicore.ml.classification.singlelabel.SingleLabelClassification;
 import ai.libs.jaicore.ml.core.dataset.schema.attribute.IntBasedCategoricalAttribute;
 import ai.libs.jaicore.ml.core.dataset.serialization.ArffDatasetAdapter;
-import ai.libs.jaicore.ml.core.evaluation.evaluator.SupervisedLearnerExecutor;
 import ai.libs.jaicore.ml.core.filter.SplitterUtil;
 import ai.libs.jaicore.ml.weka.classification.learner.IWekaClassifier;
 import ai.libs.mlplan.core.MLPlan;
 import ai.libs.mlplan.multiclass.wekamlplan.MLPlanWekaBuilder;
-import de.fraunhofer.iem.swan.features.InstancesHandler;
 import de.fraunhofer.iem.swan.util.Util;
-import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification;
 import org.api4.java.ai.ml.core.dataset.schema.attribute.IAttribute;
 import org.api4.java.ai.ml.core.dataset.serialization.DatasetDeserializationFailedException;
 import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
 import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
-import org.api4.java.ai.ml.core.evaluation.execution.ILearnerRunReport;
-import org.api4.java.ai.ml.core.evaluation.execution.LearnerExecutionFailedException;
-import org.api4.java.ai.ml.core.exception.TrainingException;
 import org.api4.java.algorithm.Timeout;
 import org.api4.java.algorithm.exceptions.AlgorithmException;
 import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
 import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
-import org.nd4j.common.io.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import weka.classifiers.Classifier;
-import weka.core.Attribute;
 import weka.core.Instances;
 import weka.core.converters.ArffLoader;
 
 import java.io.File;
 import java.io.IOException;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Random;
 import java.util.concurrent.TimeUnit;
 
 /**
@@ -51,13 +45,13 @@ public MLPlanExecutor() {
     /**
      * Run ML-Plan using the provided path to the ARFF file.
      *
-     * @param instancesHandler file path for ARFF file
+     * @param instances1 file path for ARFF file
      */
-    public HashMap<String, String> evaluateDataset(InstancesHandler instancesHandler) {
+    public HashMap<String, String> evaluateDataset(Instances instances1) {
 
-        String arffFilePath = Util.exportInstancesToArff(instancesHandler.getInstances());
+        String arffFilePath = Util.exportInstancesToArff(instances1);
 
-        String mClass = Util.getClassName(instancesHandler.getInstances());
+        String mClass = Util.getClassName(instances1);
 
         long start = System.currentTimeMillis();
 
@@ -74,14 +68,14 @@ public HashMap<String, String> evaluateDataset(InstancesHandler instancesHandler
         ArrayList<Double> fScores = new ArrayList<>();
         ArrayList<String> algorithms = new ArrayList<>();
 
-        ModelEvaluator modelEvaluator = new ModelEvaluator();
+        MonteCarloValidator monteCarloValidator = new MonteCarloValidator();
 
         ArffLoader loader = new ArffLoader();
         try {
             loader.setFile(new File(arffFilePath));
             Instances instances = loader.getDataSet();
             instances.setClassIndex(instances.numAttributes() - 1);
-            modelEvaluator.initializeResultSet(instances);
+            monteCarloValidator.initializeResultSet(instances);
         } catch (IOException e) {
             e.printStackTrace();
         }
@@ -116,7 +110,7 @@ public HashMap<String, String> evaluateDataset(InstancesHandler instancesHandler
             Instances testInstances = testLoader.getDataSet();
             testInstances.setClassIndex(testInstances.numAttributes() - 1);
 
-            modelEvaluator.evaluate(optimizedClassifier, trainInstances, testInstances, iteration);
+            monteCarloValidator.evaluate(optimizedClassifier, trainInstances, testInstances, iteration);
 
 
             /* evaluate solution produced by mlplan */
@@ -149,15 +143,15 @@ public HashMap<String, String> evaluateDataset(InstancesHandler instancesHandler
                 e.printStackTrace();
             }
         }
-        return modelEvaluator.getFMeasure();
+        return monteCarloValidator.getFMeasure();
     }
 
-    public void evaluateDataset(InstancesHandler instancesHandler, int k) {
+    public void evaluateDataset(Instances instances, int k) {
 
         //arffFilePath = "swan/swan_core/src/main/resources/waveform.arff";
-        String arffFilePath = Util.exportInstancesToArff(instancesHandler.getInstances());
+        String arffFilePath = Util.exportInstancesToArff(instances);
 
-        String mClass = Util.getClassName(instancesHandler.getInstances());
+        String mClass = Util.getClassName(instances);
 
 
         long start = System.currentTimeMillis();
@@ -172,7 +166,7 @@ public void evaluateDataset(InstancesHandler instancesHandler, int k) {
 
         //dataset.removeColumn("id");
 
-        ModelEvaluator modelEvaluator = new ModelEvaluator();
+        MonteCarloValidator monteCarloValidator = new MonteCarloValidator();
 
         //For each iteration, create a new train-test-split and run ML-Plan
         for (int iteration = 0; iteration < ITERATIONS; iteration++) {
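Usage after this refactoring: callers hand MLPlanExecutor the raw Weka Instances instead of the SWAN InstancesHandler wrapper. A minimal sketch of the new call site, assuming an already-built handler named instancesHandler; only evaluateDataset(Instances) and getInstances() are taken from the diff above, everything else is illustrative:

    // Unwrap the Weka dataset; getInstances() is the call the removed lines made internally
    Instances instances = instancesHandler.getInstances();

    // Run ML-Plan and collect the per-class F-measures gathered by MonteCarloValidator
    MLPlanExecutor executor = new MLPlanExecutor();
    HashMap<String, String> fMeasure = executor.evaluateDataset(instances);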
swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/MonteCarloValidator.java (new file)

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+package de.fraunhofer.iem.swan.model;
+
+import de.fraunhofer.iem.swan.util.Util;
+import weka.classifiers.Classifier;
+import weka.classifiers.Evaluation;
+import weka.classifiers.evaluation.output.prediction.AbstractOutput;
+import weka.classifiers.evaluation.output.prediction.CSV;
+import weka.core.Instances;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Random;
+
+/**
+ * @author Oshando Johnson on 02.09.20
+ */
+public class MonteCarloValidator {
+
+    private ArrayList<AbstractOutput> predictions;
+    private HashMap<String, String> fMeasure;
+
+    public MonteCarloValidator() {
+        predictions = new ArrayList<>();
+    }
+
+    public ArrayList<AbstractOutput> getPredictions() {
+        return predictions;
+    }
+
+    public HashMap<String, String> getFMeasure() {
+        return fMeasure;
+    }
+
+    /**
+     * Evaluates instances using Monte Carlo Cross Evaluation.
+     *
+     * @param instances       instance set
+     * @param classifier      classifier to model creation
+     * @param trainPercentage percentage of instances for train set
+     * @param iterations      number of evaluation iterations
+     * @return average F-score for iterations
+     */
+    public HashMap<String, String> monteCarloValidate(Instances instances, Classifier classifier, double trainPercentage, int iterations) {
+
+        initializeResultSet(instances);
+
+        for (int i = 0; i < iterations; i++) {
+            Util.exportInstancesToArff(instances);
+            evaluateIteration(instances, classifier, trainPercentage, i);
+        }
+        return fMeasure;
+    }
+
+    public void evaluateIteration(Instances instances, Classifier classifier, double trainPercentage, int iteration) {
+
+        int trainSize = (int) Math.round(instances.numInstances() * trainPercentage);
+        int testSize = instances.numInstances() - trainSize;
+
+        instances.randomize(new Random(1337 + iteration * 11));
+        instances.stratify(10);
+
+        Instances trainInstances = new Instances(instances, 0, trainSize);
+        Instances testInstances = new Instances(instances, trainSize, testSize);
+
+        evaluate(classifier, trainInstances, testInstances, iteration);
+    }
+
+    public void evaluate(Classifier classifier, Instances trainInstances, Instances testInstances, int iteration) {
+
+        Evaluation eval = null;
+        try {
+
+            classifier.buildClassifier(trainInstances);
+
+            eval = new Evaluation(testInstances);
+
+            AbstractOutput abstractOutput = new CSV();
+            abstractOutput.setBuffer(new StringBuffer());
+            abstractOutput.setHeader(testInstances);
+            abstractOutput.setAttributes(Integer.toString(testInstances.numAttributes() - 1));
+
+            eval.evaluateModel(classifier, testInstances, abstractOutput);
+
+            String[] predictions = abstractOutput.getBuffer().toString().split("\n");
+
+            for (String result : predictions) {
+                String[] entry = result.split(",");
+
+                if (entry[2].contains("source") || entry[2].contains("sink") || entry[2].contains("sanitizer")
+                        || entry[2].contains("auth")) {
+
+                    String method = entry[5].replace("'", "");
+
+                    // System.out.println(method);
+                    // SwanPipeline.predictions.get(Integer.toString(iteration)).add(method);
+                }
+            }
+
+            //get class name
+            String className = "";
+            for (int x = 0; x < testInstances.attribute("class").numValues(); x++) {
+
+                if (!testInstances.attribute("class").value(x).contains("none")) {
+                    className = testInstances.attribute("class").value(x);
+                    break;
+                }
+            }
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+        updateResultSet(testInstances, eval);
+    }
+
+    public void initializeResultSet(Instances instances) {
+        fMeasure = new HashMap<>();
+
+        for (int x = 0; x < instances.numClasses(); x++) {
+
+            if (!instances.classAttribute().value(x).contentEquals("none")) {
+                fMeasure.put(instances.classAttribute().value(x), "");
+            }
+        }
+    }
+
+    public void updateResultSet(Instances instances, Evaluation eval) {
+
+        for (int x = 0; x < instances.numClasses(); x++) {
+
+            if (!instances.classAttribute().value(x).contentEquals("none")) {
+
+                String current = fMeasure.get(instances.classAttribute().value(x));
+                current += eval.fMeasure(x) + ";";
+
+                fMeasure.replace(instances.classAttribute().value(x), current.replace("NaN", "0"));
+            }
+        }
+    }
+}
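The new MonteCarloValidator can also be driven on its own with any Weka classifier. A minimal sketch, assuming the demo class lives in the same package, an ARFF file at a placeholder path, and weka.classifiers.trees.RandomForest as the learner; the class name MonteCarloValidatorDemo, the 80/20 split, and the 10 iterations are illustrative choices, not taken from the commit:

    import weka.classifiers.Classifier;
    import weka.classifiers.trees.RandomForest;
    import weka.core.Instances;
    import weka.core.converters.ArffLoader;

    import java.io.File;
    import java.util.HashMap;

    public class MonteCarloValidatorDemo {

        public static void main(String[] args) throws Exception {

            // Load a dataset and mark the last attribute as the class, as MLPlanExecutor does above
            ArffLoader loader = new ArffLoader();
            loader.setFile(new File("dataset.arff"));   // placeholder path
            Instances instances = loader.getDataSet();
            instances.setClassIndex(instances.numAttributes() - 1);

            // Any Weka classifier works here; RandomForest is an arbitrary choice
            Classifier classifier = new RandomForest();

            // 80% train / 20% test, repeated over 10 random stratified splits
            MonteCarloValidator validator = new MonteCarloValidator();
            HashMap<String, String> fMeasure = validator.monteCarloValidate(instances, classifier, 0.8, 10);

            // One semicolon-separated string of F-measures per non-"none" class label
            fMeasure.forEach((label, scores) -> System.out.println(label + ": " + scores));
        }
    }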
