Skip to content

Commit 8572053

Browse files
committed
Implement Monte Carlo cross-validation and F-measure reporting
1 parent fd9bb15 commit 8572053

File tree

1 file changed

+148
-0
lines changed

1 file changed

+148
-0
lines changed
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
package de.fraunhofer.iem.swan.model;
2+
3+
4+
import weka.classifiers.Evaluation;
5+
import weka.classifiers.evaluation.output.prediction.AbstractOutput;
6+
import weka.classifiers.evaluation.output.prediction.CSV;
7+
import weka.classifiers.meta.FilteredClassifier;
8+
import weka.core.Instances;
9+
import weka.core.Range;
10+
11+
import java.util.ArrayList;
12+
import java.util.HashMap;
13+
import java.util.Random;
14+
15+
/**
16+
* @author Oshando Johnson on 02.09.20
17+
*/
18+
public class ModelEvaluator {
19+
20+
private ArrayList<AbstractOutput> predictions;
21+
private HashMap<String, String> fMeasure;
22+
23+
public ModelEvaluator() {
24+
predictions = new ArrayList<>();
25+
}
26+
27+
public ArrayList<AbstractOutput> getPredictions() {
28+
return predictions;
29+
}
30+
31+
public HashMap<String, String> getfMeasure() {
32+
return fMeasure;
33+
}
34+
35+
/**
36+
* Evaluates instances using Monte Carlo Cross Evaluation.
37+
*
38+
* @param instances instance set
39+
* @param filteredClassifier classifier to model creation
40+
* @param trainPercentage percentage of instances for train set
41+
* @param iterations number of evaluation iterations
42+
* @return average F-score for iterations
43+
*/
44+
public HashMap<String, String> monteCarloValidate(Instances instances, FilteredClassifier filteredClassifier, double trainPercentage, int iterations) {
45+
46+
initializeResultSet(instances);
47+
48+
for (int i = 0; i < iterations; i++) {
49+
50+
//System.out.println("----" + filteredClassifier.getClassifier().getClass().getSimpleName() + " Iteration #" + i + "----");
51+
StringBuffer stringBuffer = new StringBuffer();
52+
Evaluation eval = null;
53+
54+
//Instances percentage split
55+
int trainSize = (int) Math.round(instances.numInstances() * trainPercentage);
56+
int testSize = instances.numInstances() - trainSize;
57+
58+
//System.out.println("Split: " + trainSize + "/" + testSize);
59+
instances.randomize(new Random(1337 + i * 11));
60+
61+
Instances trainInstances = new Instances(instances, 0, trainSize);
62+
Instances testInstances = new Instances(instances, trainSize, testSize);
63+
64+
try {
65+
66+
filteredClassifier.buildClassifier(trainInstances);
67+
68+
eval = new Evaluation(testInstances);
69+
70+
AbstractOutput abstractOutput = new CSV();
71+
abstractOutput.setBuffer(new StringBuffer());
72+
abstractOutput.setHeader(testInstances);
73+
abstractOutput.setAttributes("last");
74+
75+
eval.evaluateModel(filteredClassifier, testInstances, abstractOutput);
76+
// System.out.println(abstractOutput.getBuffer());
77+
predictions.add(abstractOutput);
78+
79+
80+
//System.out.println(eval.toClassDetailsString());
81+
82+
} catch (Exception e) {
83+
e.printStackTrace();
84+
}
85+
86+
updateResultSet(instances, eval);
87+
}
88+
return fMeasure;
89+
}
90+
91+
/**
92+
* Evaluates instances using Cross Evaluation.
93+
*
94+
* @param instances instance set
95+
* @param filteredClassifier classifier to model creation
96+
* @param iterations number of evaluation iterations
97+
* @return average F-score for iterations
98+
*/
99+
public HashMap<String, String> crossValidate(Instances instances, FilteredClassifier filteredClassifier, int iterations, int folds) {
100+
101+
initializeResultSet(instances);
102+
103+
for (int i = 0; i < iterations; i++) {
104+
105+
Evaluation eval = null;
106+
StringBuffer stringBuffer = new StringBuffer();
107+
108+
try {
109+
eval = new Evaluation(instances);
110+
eval.crossValidateModel(filteredClassifier, instances, folds
111+
, new Random(1337 + i * 11),
112+
stringBuffer, new Range(Integer.toString(instances.numAttributes() - 1)),
113+
true);
114+
//System.out.println(stringBuffer.toString());
115+
System.out.println(eval.toClassDetailsString());
116+
} catch (Exception e) {
117+
e.printStackTrace();
118+
}
119+
updateResultSet(instances, eval);
120+
}
121+
return fMeasure;
122+
}
123+
124+
public void initializeResultSet(Instances instances) {
125+
fMeasure = new HashMap<>();
126+
127+
for (int x = 0; x < instances.numClasses(); x++) {
128+
129+
if (!instances.classAttribute().value(x).contentEquals("none")) {
130+
fMeasure.put(instances.classAttribute().value(x), "");
131+
}
132+
}
133+
}
134+
135+
public void updateResultSet(Instances instances, Evaluation eval) {
136+
137+
for (int x = 0; x < instances.numClasses(); x++) {
138+
139+
if (!instances.classAttribute().value(x).contentEquals("none")) {
140+
141+
String current = fMeasure.get(instances.classAttribute().value(x));
142+
current += eval.fMeasure(x) + ";";
143+
144+
fMeasure.replace(instances.classAttribute().value(x), current.replace("NaN","0"));
145+
}
146+
}
147+
}
148+
}

0 commit comments

Comments
 (0)