Skip to content

Commit 3ba98d4

Browse files
committed
Print precision and recall to logs
1 parent bfdb19a commit 3ba98d4

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/MonteCarloValidator.java

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,15 @@ public class MonteCarloValidator {
1919

2020
private ArrayList<String> predictions;
2121
private HashMap<String, ArrayList<Double>> fMeasure;
22+
private HashMap<String, ArrayList<Double>> precision;
23+
private HashMap<String, ArrayList<Double>> recall;
2224
private DecimalFormat df = new DecimalFormat("####0.00");
2325

2426
public MonteCarloValidator() {
2527
predictions = new ArrayList<>();
2628
fMeasure = new HashMap<>();
29+
precision = new HashMap<>();
30+
recall = new HashMap<>();
2731
}
2832

2933
public ArrayList<String> getPredictions() {
@@ -34,6 +38,14 @@ public HashMap<String, ArrayList<Double>> getFMeasure() {
3438
return fMeasure;
3539
}
3640

41+
public HashMap<String, ArrayList<Double>> getRecall() {
42+
return recall;
43+
}
44+
45+
public HashMap<String, ArrayList<Double>> getPrecision() {
46+
return precision;
47+
}
48+
3749
/**
3850
* Evaluates instances using Monte Carlo Cross Evaluation.
3951
*
@@ -55,11 +67,11 @@ public void monteCarloValidate(Instances instances, Classifier classifier, doubl
5567
Instances trainInstances = new Instances(instances, 0, trainSize);
5668
Instances testInstances = new Instances(instances, trainSize, testSize);
5769

58-
evaluate(classifier, trainInstances, testInstances);
70+
evaluate(classifier, trainInstances, testInstances, i);
5971
}
6072
}
6173

62-
public ArrayList<String> evaluate(Classifier classifier, Instances trainInstances, Instances testInstances) {
74+
public ArrayList<String> evaluate(Classifier classifier, Instances trainInstances, Instances testInstances, int iteration) {
6375

6476
Evaluation eval = null;
6577
try {
@@ -75,6 +87,9 @@ public ArrayList<String> evaluate(Classifier classifier, Instances trainInstance
7587

7688
eval.evaluateModel(classifier, testInstances, abstractOutput);
7789

90+
String header = "=== " + classifier.getClass().getSimpleName() + " Iteration #" + iteration + ": " + trainInstances.attribute(trainInstances.numAttributes() - 1).name() + " ===";
91+
System.out.println(eval.toClassDetailsString(header));
92+
7893
//Obtain all predictions and extract method signatures
7994
String[] output = abstractOutput.getBuffer().toString().split("\n");
8095

@@ -99,10 +114,14 @@ public void updateResultSet(Instances instances, Evaluation eval) {
99114
String currentClass = instances.classAttribute().value(c);
100115

101116
if (!currentClass.contentEquals("0")) {
102-
if (!fMeasure.containsKey(currentClass))
117+
if (!fMeasure.containsKey(currentClass)) {
103118
fMeasure.put(currentClass, new ArrayList<>(Collections.singletonList(Double.isNaN(eval.fMeasure(c)) ? 0 : Double.parseDouble(df.format(eval.fMeasure(c))))));
104-
else {
119+
recall.put(currentClass, new ArrayList<>(Collections.singletonList(Double.isNaN(eval.recall(c)) ? 0 : Double.parseDouble(df.format(eval.recall(c))))));
120+
precision.put(currentClass, new ArrayList<>(Collections.singletonList(Double.isNaN(eval.precision(c)) ? 0 : Double.parseDouble(df.format(eval.precision(c))))));
121+
} else {
105122
fMeasure.get(currentClass).add(Double.isNaN(eval.fMeasure(c)) ? 0 : Double.parseDouble(df.format(eval.fMeasure(c))));
123+
recall.get(currentClass).add(Double.isNaN(eval.recall(c)) ? 0 : Double.parseDouble(df.format(eval.recall(c))));
124+
precision.get(currentClass).add(Double.isNaN(eval.precision(c)) ? 0 : Double.parseDouble(df.format(eval.precision(c))));
106125
}
107126
}
108127
}

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/toolkit/Weka.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public class Weka {
4242
private static final Logger logger = LoggerFactory.getLogger(Weka.class);
4343
private HashMap<String, ArrayList<Category>> predictions;
4444
private HashMap<String, HashMap<String, ArrayList<Double>>> results;
45-
private DecimalFormat df = new DecimalFormat("####0.00");
45+
private DecimalFormat df = new DecimalFormat("####0.000");
4646

4747
public Weka(WekaFeatureSet features, SwanOptions options, Set<Method> methods) {
4848
this.features = features;
@@ -164,6 +164,9 @@ public Pair<String, Double> crossValidateModel(Instances instances) {
164164
for (String key : evaluator.getFMeasure().keySet()) {
165165

166166
double averageFMeasure = evaluator.getFMeasure().get(key).stream().mapToDouble(a -> a).average().getAsDouble();
167+
double averagePrecision = evaluator.getPrecision().get(key).stream().mapToDouble(a -> a).average().getAsDouble();
168+
double averageRecall = evaluator.getRecall().get(key).stream().mapToDouble(a -> a).average().getAsDouble();
169+
167170
Pair summary = new Pair<>(classifier.getClass().getSimpleName(), Double.parseDouble(df.format(averageFMeasure)));
168171
classifierSummary.add(summary);
169172

@@ -183,8 +186,7 @@ public Pair<String, Double> crossValidateModel(Instances instances) {
183186
} else
184187
measure.put(classifier.getClass().getSimpleName(), evaluator.getFMeasure().get(key));
185188

186-
logger.debug("Average F-measure for {}({}) using {}: {}, {}", category, key, classifier.getClass().getSimpleName(),
187-
averageFMeasure, evaluator.getFMeasure().get(key));
189+
logger.debug("{} Average F-measure ({}), Precision ({}) and Recall ({}) for {}({}) ", classifier.getClass().getSimpleName(), averageFMeasure,averagePrecision, averageRecall, category, key);
188190
}
189191
if (!category.contains("authentication"))
190192
results.put(category, measure);

0 commit comments

Comments
 (0)