Skip to content

Commit 2a87510

Browse files
committed
Move model evaluation code for MEKA, WEKA and ML-Plan
1 parent 120a966 commit 2a87510

File tree

4 files changed

+20
-18
lines changed

4 files changed

+20
-18
lines changed

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/ModelEvaluator.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import de.fraunhofer.iem.swan.cli.SwanOptions;
44
import de.fraunhofer.iem.swan.data.Method;
5-
import de.fraunhofer.iem.swan.features.FeaturesHandler;
5+
import de.fraunhofer.iem.swan.features.IFeatureSet;
6+
import de.fraunhofer.iem.swan.features.MekaFeatureSet;
7+
import de.fraunhofer.iem.swan.features.WekaFeatureSet;
68
import de.fraunhofer.iem.swan.model.engine.MLPlan;
79
import de.fraunhofer.iem.swan.model.engine.Meka;
810
import de.fraunhofer.iem.swan.model.engine.Weka;
@@ -29,12 +31,12 @@ public enum Phase {
2931
PREDICT
3032
}
3133

32-
private FeaturesHandler features;
34+
private IFeatureSet features;
3335
private SwanOptions options;
3436
private Set<Method> methods;
3537
private static final Logger logger = LoggerFactory.getLogger(ModelEvaluator.class);
3638

37-
public ModelEvaluator(FeaturesHandler features, SwanOptions options, Set<Method> methods) {
39+
public ModelEvaluator(IFeatureSet features, SwanOptions options, Set<Method> methods) {
3840
this.features = features;
3941
this.options = options;
4042
this.methods = methods;
@@ -51,18 +53,18 @@ public void trainModel() {
5153

5254
case MEKA:
5355
logger.info("Evaluating model with MEKA");
54-
Meka meka = new Meka(features, options, methods);
56+
Meka meka = new Meka((MekaFeatureSet)features, options, methods);
5557
meka.trainModel();
5658
break;
5759
case WEKA:
5860
logger.info("Evaluating model with WEKA");
59-
Weka weka = new Weka(features, options);
61+
Weka weka = new Weka((WekaFeatureSet) features, options);
6062
weka.trainModel();
6163
break;
6264
case MLPLAN:
6365
logger.info("Evaluating model with ML-PLAN");
6466
MLPlan mlPlan = new MLPlan();
65-
mlPlan.evaluateDataset(features.getInstances().get("train"));
67+
mlPlan.evaluateDataset(((WekaFeatureSet)features).getInstances().get("train"));
6668
break;
6769
}
6870
}

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/engine/MLPlan.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ public HashMap<String, ArrayList<Double>> evaluateDataset(Instances instances1)
108108
Instances testInstances = testLoader.getDataSet();
109109
testInstances.setClassIndex(testInstances.numAttributes() - 1);
110110

111-
monteCarloValidator.evaluate(optimizedClassifier, trainInstances, testInstances, iteration);
111+
monteCarloValidator.evaluate(optimizedClassifier, trainInstances, testInstances);
112112

113113

114114
/* evaluate solution produced by mlplan */

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/engine/Meka.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import de.fraunhofer.iem.swan.cli.SwanOptions;
44
import de.fraunhofer.iem.swan.data.Category;
55
import de.fraunhofer.iem.swan.data.Method;
6-
import de.fraunhofer.iem.swan.features.FeaturesHandler;
6+
import de.fraunhofer.iem.swan.features.MekaFeatureSet;
77
import de.fraunhofer.iem.swan.io.dataset.SrmList;
88
import de.fraunhofer.iem.swan.io.dataset.SrmListUtils;
99
import de.fraunhofer.iem.swan.model.ModelEvaluator;
@@ -22,12 +22,12 @@
2222

2323
public class Meka {
2424

25-
private FeaturesHandler features;
25+
private MekaFeatureSet features;
2626
private SwanOptions options;
2727
private Set<Method> methods;
2828
private static final Logger logger = LoggerFactory.getLogger(ModelEvaluator.class);
2929

30-
public Meka(FeaturesHandler features, SwanOptions options, Set<Method> methods) {
30+
public Meka(MekaFeatureSet features, SwanOptions options, Set<Method> methods) {
3131
this.features = features;
3232
this.options = options;
3333
this.methods = methods;

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/engine/Weka.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package de.fraunhofer.iem.swan.model.engine;
22

33
import de.fraunhofer.iem.swan.cli.SwanOptions;
4-
import de.fraunhofer.iem.swan.features.FeaturesHandler;
4+
import de.fraunhofer.iem.swan.features.WekaFeatureSet;
55
import de.fraunhofer.iem.swan.model.ModelEvaluator;
66
import de.fraunhofer.iem.swan.model.MonteCarloValidator;
77
import org.slf4j.Logger;
@@ -31,11 +31,11 @@
3131
*/
3232
public class Weka {
3333

34-
private FeaturesHandler features;
34+
private WekaFeatureSet features;
3535
private SwanOptions options;
3636
private static final Logger logger = LoggerFactory.getLogger(ModelEvaluator.class);
3737

38-
public Weka(FeaturesHandler features, SwanOptions options) {
38+
public Weka(WekaFeatureSet features, SwanOptions options) {
3939
this.features = features;
4040
this.options = options;
4141
}
@@ -47,14 +47,11 @@ public Weka(FeaturesHandler features, SwanOptions options) {
4747
*/
4848
public HashMap<String, HashMap<String, String>> trainModel() {
4949

50-
5150
//Phase 1: classify SRM classes
5251
for (String srm : options.getSrmClasses())
5352
runManualEvaluation(features.getInstances().get(srm));
5453

55-
//Filter methods from CWE instances that were not classified
56-
//into one of the SRM classes
57-
54+
//Filter methods from CWE instances that were not classified into one of the SRM classes
5855
//Phase 2: classify CWE classes
5956
for (String cwe : options.getCweClasses())
6057
runManualEvaluation(features.getInstances().get(cwe));
@@ -67,6 +64,9 @@ public HashMap<String, HashMap<String, String>> trainModel() {
6764
*/
6865
public HashMap<String, HashMap<String, String>> runManualEvaluation(Instances instances) {
6966

67+
String category = instances.attribute(instances.numAttributes()-1).name();
68+
instances.setClass(instances.attribute(instances.numAttributes()-1));
69+
7070
LinkedHashMap<String, HashMap<String, String>> fMeasure = new LinkedHashMap<>();
7171

7272
List<Classifier> classifiers = new ArrayList<>();
@@ -85,7 +85,7 @@ public HashMap<String, HashMap<String, String>> runManualEvaluation(Instances in
8585
evaluator.monteCarloValidate(instances, classifier, options.getTrainTestSplit(), options.getIterations());
8686

8787
for (String key : evaluator.getFMeasure().keySet())
88-
logger.info("F-measure for {} using {}: {}", key, classifier.getClass().getSimpleName(), evaluator.getFMeasure().get(key));
88+
logger.info("F-measure for {}({}) using {}: {}", category, key , classifier.getClass().getSimpleName(), evaluator.getFMeasure().get(key));
8989
}
9090
return fMeasure;
9191
}

0 commit comments

Comments
 (0)