Skip to content

Commit 857758a

Browse files
committed
Implement multi-label classification with MEKA
1 parent 8d2ace0 commit 857758a

File tree

1 file changed

+115
-0
lines changed
  • swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/engine

1 file changed

+115
-0
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package de.fraunhofer.iem.swan.model.engine;
2+
3+
import de.fraunhofer.iem.swan.cli.SwanOptions;
4+
import de.fraunhofer.iem.swan.data.Category;
5+
import de.fraunhofer.iem.swan.data.Method;
6+
import de.fraunhofer.iem.swan.features.FeaturesHandler;
7+
import de.fraunhofer.iem.swan.io.dataset.SrmList;
8+
import de.fraunhofer.iem.swan.io.dataset.SrmListUtils;
9+
import de.fraunhofer.iem.swan.model.ModelEvaluator;
10+
import meka.classifiers.multilabel.BR;
11+
import meka.classifiers.multilabel.Evaluation;
12+
import meka.core.Result;
13+
import org.slf4j.Logger;
14+
import org.slf4j.LoggerFactory;
15+
import weka.core.Instances;
16+
17+
import java.io.File;
18+
import java.io.IOException;
19+
import java.util.ArrayList;
20+
import java.util.HashMap;
21+
import java.util.Set;
22+
23+
public class Meka {
24+
25+
private FeaturesHandler features;
26+
private SwanOptions options;
27+
private Set<Method> methods;
28+
private static final Logger logger = LoggerFactory.getLogger(ModelEvaluator.class);
29+
30+
public Meka(FeaturesHandler features, SwanOptions options, Set<Method> methods) {
31+
this.features = features;
32+
this.options = options;
33+
this.methods = methods;
34+
}
35+
36+
/**
37+
* Trains and evaluates the model with the given training data and specified classification mode.
38+
*
39+
*/
40+
public void trainModel() {
41+
42+
switch (ModelEvaluator.Phase.valueOf(options.getPhase().toUpperCase())) {
43+
case VALIDATE:
44+
crossValidateModel(features.getTrainInstances());
45+
break;
46+
case PREDICT:
47+
HashMap<String, ArrayList<Category>> predictions = predictModel(features.getTrainInstances(), features.getTestInstances(), options.getPredictionThreshold());
48+
49+
for (Method method : methods) {
50+
for (Category category : predictions.get(method.getArffSafeSignature())) {
51+
method.addCategory(category);
52+
}
53+
}
54+
55+
try {
56+
SrmListUtils.exportFile(new SrmList(methods), options.getOutputDir() + File.separator + "swan-srm-cwe-list.json");
57+
} catch (IOException e) {
58+
e.printStackTrace();
59+
}
60+
break;
61+
}
62+
}
63+
64+
/**
65+
* Cross-validates a ML model using the provided instances and outputs metrics.
66+
*
67+
* @param instances training instances
68+
*/
69+
public void crossValidateModel(Instances instances) {
70+
71+
try {
72+
73+
BR classifier = new BR();
74+
String top = "PCut1";
75+
String verbosity = "7";
76+
Result result = Evaluation.cvModel(classifier, instances, options.getIterations(), top, verbosity);
77+
78+
logger.info("Model cross-validation results {}", result);
79+
} catch (Exception e) {
80+
e.printStackTrace();
81+
}
82+
}
83+
84+
/**
85+
* Trains a model and uses the trained model to predict the categories for the methods in the test set.
86+
*
87+
* @param train training instances
88+
* @param test test instances
89+
* @param threshold threshold used to determine if a method should be classified into a category
90+
* @return hash map of method signatures and the categories they're classified into
91+
*/
92+
public HashMap<String, ArrayList<Category>> predictModel(Instances train, Instances test, double threshold) {
93+
94+
HashMap<String, ArrayList<Category>> predictions = new HashMap<>();
95+
try {
96+
BR classifier = new BR();
97+
classifier.buildClassifier(train);
98+
99+
for (int i = 0; i < test.numInstances(); i++) {
100+
double[] dist = classifier.distributionForInstance(test.instance(i));
101+
102+
ArrayList<Category> categories = new ArrayList<>();
103+
for (int p = 0; p < dist.length; p++) {
104+
if (dist[p] >= threshold)
105+
categories.add(Category.valueOf(test.attribute(p).name().toUpperCase()));
106+
}
107+
predictions.put(test.get(i).stringValue(test.attribute("id").index()), categories);
108+
}
109+
110+
} catch (Exception e) {
111+
e.printStackTrace();
112+
}
113+
return predictions;
114+
}
115+
}

0 commit comments

Comments
 (0)