Skip to content

Commit 2f90bf4

Browse files
authored
Merge pull request #42 from secure-software-engineering/feature/train-dataset-arff
Improve training dataset loading
2 parents 471f793 + 03b30c2 commit 2f90bf4

File tree

17 files changed

+806
-87
lines changed

17 files changed

+806
-87
lines changed

swan-pipeline/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@
106106
<groupId>edu.stanford.nlp</groupId>
107107
<artifactId>stanford-corenlp</artifactId>
108108
<version>4.3.0</version>
109-
<classifier>models</classifier>
109+
<classifier>models-english</classifier>
110110
</dependency>
111111
<dependency>
112112
<groupId>org.jsoup</groupId>

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/SwanPipeline.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ public class SwanPipeline {
2626

2727
private static final Logger logger = LoggerFactory.getLogger(SwanPipeline.class);
2828
public static SwanOptions options;
29+
private ModelEvaluator modelEvaluator;
2930

3031
public SwanPipeline(SwanOptions options) {
3132
SwanPipeline.options = options;
@@ -47,7 +48,10 @@ public void run() throws IOException, InterruptedException {
4748
// Load methods in training dataset
4849
Dataset dataset = new Dataset();
4950
dataset.setTrain(SrmListUtils.importFile(options.getDatasetJson()));
50-
soot.cleanupList(dataset.getTrain());
51+
52+
if (!options.getTrainDataDir().isEmpty())
53+
soot.cleanupList(dataset.getTrain());
54+
5155
logger.info("Loaded {} training methods, distribution={}", dataset.getTrainMethods().size(), Util.countCategories(dataset.getTrainMethods()));
5256

5357
//Load methods from the test set
@@ -61,10 +65,14 @@ public void run() throws IOException, InterruptedException {
6165
IFeatureSet featureSet = featureSetSelector.select(dataset, options);
6266

6367
//Train and evaluate model for SRM and CWE categories
64-
ModelEvaluator modelEvaluator = new ModelEvaluator(featureSet, options, dataset.getTestMethods());
68+
modelEvaluator = new ModelEvaluator(featureSet, options, dataset.getTestMethods());
6569
modelEvaluator.trainModel();
6670

6771
long analysisTime = System.currentTimeMillis() - startAnalysisTime;
6872
logger.info("Total runtime {} minutes", analysisTime / 60000);
6973
}
74+
75+
public ModelEvaluator getModelEvaluator() {
76+
return modelEvaluator;
77+
}
7078
}

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/cli/CliRunner.java

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import picocli.CommandLine;
44

5+
import java.util.ArrayList;
56
import java.util.Collections;
67
import java.util.List;
78
import java.util.concurrent.Callable;
@@ -11,16 +12,19 @@
1112
public class CliRunner implements Callable<Integer> {
1213

1314
@CommandLine.Option(names = {"-test", "--test-data"}, description = {"Path of test JARs or class files"})
14-
private String testDataDir = "/input/test-data";
15+
private String testDataDir = "";
1516

1617
@CommandLine.Option(names = {"-train", "--train-data"}, description = {"Path of training JARs or class files"})
17-
private String trainDataDir = "/input/train-data";
18+
private String trainDataDir = "";
1819

1920
@CommandLine.Option(names = {"-d", "--dataset"}, description = {"Path to JSON dataset file"})
20-
private String datasetJson = "/input/dataset/swan-dataset.json";
21+
private String datasetJson = "/dataset/swan-dataset.json";
22+
23+
@CommandLine.Option(names = {"-in", "--train-instances"}, description = {"Path to ARFF files that contain training instances"})
24+
private List<String> instancesArff = new ArrayList<>();
2125

2226
@CommandLine.Option(names = {"-o", "--output"}, description = {"Directory to save output files"})
23-
private String outputDir = "/swan-output";
27+
private String outputDir = "";
2428

2529
@CommandLine.Option(names = {"-f", "--feature"}, description = {"Select one or more feature sets: all, code, doc-auto or doc-manual"})
2630
private List<String> featureSet = Collections.singletonList("code");
@@ -52,9 +56,7 @@ public class CliRunner implements Callable<Integer> {
5256
@CommandLine.Option(names = {"-pt", "--prediction-threshold"}, description = {"Threshold for predicting categories"})
5357
private double predictionThreshold = 0.5;
5458

55-
56-
@Override
57-
public Integer call() throws Exception {
59+
public SwanOptions initializeOptions(){
5860

5961
SwanOptions options = new SwanOptions(testDataDir,
6062
trainDataDir,
@@ -70,7 +72,14 @@ public Integer call() throws Exception {
7072
split,
7173
phase);
7274
options.setPredictionThreshold(predictionThreshold);
75+
options.setInstances(instancesArff);
76+
77+
return options;
78+
}
79+
80+
@Override
81+
public Integer call() throws Exception {
7382

74-
return new SwanCli().run(options);
83+
return new SwanCli().run(initializeOptions());
7584
}
7685
}

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/cli/SwanCli.java

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import org.slf4j.Logger;
55
import org.slf4j.LoggerFactory;
66

7+
import java.util.ArrayList;
78
import java.util.Arrays;
9+
import java.util.List;
810
import java.util.concurrent.CancellationException;
911

1012
/**
@@ -13,39 +15,44 @@
1315
public class SwanCli {
1416

1517
private static final Logger logger = LoggerFactory.getLogger(SwanCli.class);
18+
private SwanPipeline swanPipeline;
1619

1720
public Integer run(SwanOptions options) throws Exception {
1821

1922
FileUtility fileUtility = new FileUtility();
2023

21-
if (options.getDatasetJson().contentEquals("/input/dataset/swan-dataset.json")) {
24+
if (options.getDatasetJson().contentEquals("/dataset/swan-dataset.json")) {
2225
options.setDatasetJson(fileUtility.getResourceFile(options.getDatasetJson()).getAbsolutePath());
2326
}
2427

25-
if (options.getTrainDataDir().contentEquals("/input/train-data")) {
26-
options.setTrainDataDir(fileUtility.getResourceDirectory("/input/train-data").getAbsolutePath());
27-
}
28-
29-
if (options.getTestDataDir().contentEquals("/input/test-data")) {
30-
options.setTestDataDir(fileUtility.getResourceDirectory("/input/test-data").getAbsolutePath());
31-
}
32-
33-
if(options.getSrmClasses().contains("all")){
28+
if (options.getSrmClasses().contains("all")) {
3429
options.setSrmClasses(Arrays.asList("source", "sink", "sanitizer", "authentication"));
3530
}
3631

37-
if(options.getCweClasses().contains("all")){
32+
if (options.getCweClasses().contains("all")) {
3833
options.setCweClasses(Arrays.asList("cwe078", "cwe079", "cwe089", "cwe306", "cwe601", "cwe862", "cwe863"));
3934
}
4035

41-
if(options.getFeatureSet().contains("all")){
36+
if (options.getFeatureSet().contains("all")) {
4237
options.setFeatureSet(Arrays.asList("code", "doc-manual", "doc-auto"));
4338
}
4439

40+
if (options.getInstances().isEmpty()) {
41+
42+
List<String> instances = new ArrayList<>();
43+
44+
for (String feature : options.getFeatureSet()){
45+
String filepath = "/dataset/" + options.getToolkit() + "-" + feature + "-instances.arff";
46+
instances.add(fileUtility.getResourceFile(filepath).getAbsolutePath());
47+
}
48+
49+
options.setInstances(instances);
50+
}
51+
4552
logger.info("SWAN options: {}", options);
4653

4754
try {
48-
SwanPipeline swanPipeline = new SwanPipeline(options);
55+
swanPipeline = new SwanPipeline(options);
4956
swanPipeline.run();
5057

5158
return 0;
@@ -60,4 +67,8 @@ public Integer run(SwanOptions options) throws Exception {
6067
fileUtility.dispose();
6168
}
6269
}
70+
71+
public SwanPipeline getSwanPipeline() {
72+
return swanPipeline;
73+
}
6374
}

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/cli/SwanOptions.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ public class SwanOptions {
2424
private double trainTestSplit;
2525
private String phase;
2626
private double predictionThreshold;
27+
private List<String> instancesArff;
2728

2829
public SwanOptions(String testDataDir, String trainDataDir, String datasetJson, String outputDir,
2930
List<String> featureSet, String toolkit, List<String> srmClasses,
@@ -177,12 +178,21 @@ public void setPredictionThreshold(double predictionThreshold) {
177178
this.predictionThreshold = predictionThreshold;
178179
}
179180

181+
public List<String> getInstances() {
182+
return instancesArff;
183+
}
184+
185+
public void setInstances(List<String> instancesArff) {
186+
this.instancesArff = instancesArff;
187+
}
188+
180189
@Override
181190
public String toString() {
182191
return "SwanOptions{" +
183192
"testData='" + testDataDir + '\'' +
184193
", trainData='" + trainDataDir + '\'' +
185194
", datasetJson='" + datasetJson + '\'' +
195+
", instances='" + instancesArff + '\'' +
186196
", outputDir='" + outputDir + '\'' +
187197
", featureSet='" + featureSet + '\'' +
188198
", learningMode='" + toolkit + '\'' +

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/features/FeatureSet.java

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,9 @@ public ArrayList<Attribute> addDocAttributes(FeatureSet.Type instanceSet) {
188188
return attributes;
189189
}
190190

191-
public Instances createInstances(List<Type> featureSets, ArrayList<Attribute> attributes,
192-
Set<Method> methods, Set<Category> categories, String name) {
193191

194-
Instances instances = new Instances(name, attributes, 0);
192+
public Instances createInstances(Instances instances, List<Type> featureSets, ArrayList<Attribute> attributes,
193+
Set<Method> methods, Set<Category> categories) {
195194

196195
for (FeatureSet.Type featureSet : featureSets)
197196
switch (featureSet) {
@@ -207,6 +206,15 @@ public Instances createInstances(List<Type> featureSets, ArrayList<Attribute> at
207206
}
208207

209208

209+
public Instances createInstances(List<Type> featureSets, ArrayList<Attribute> attributes,
210+
Set<Method> methods, Set<Category> categories) {
211+
212+
Instances instances = new Instances("swan-srm", attributes, 0);
213+
214+
return createInstances(instances, featureSets, attributes, methods, categories);
215+
}
216+
217+
210218
/**
211219
* Adds data for SWAN features to instance set.
212220
*
@@ -256,13 +264,13 @@ public ArrayList<Instance> getCodeInstances(Instances instances, Set<Method> met
256264

257265
switch (entry.getKey().applies(method)) {
258266
case TRUE:
259-
inst.setValue(entry.getValue(), "true");
267+
inst.setValue(instances.attribute(String.valueOf(entry.getKey())), "true");
260268
break;
261269
case FALSE:
262-
inst.setValue(entry.getValue(), "false");
270+
inst.setValue(instances.attribute(String.valueOf(entry.getKey())), "false");
263271
break;
264272
default:
265-
inst.setMissing(entry.getValue());
273+
inst.setMissing(instances.attribute(String.valueOf(entry.getKey())));
266274
}
267275
}
268276

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/features/MekaFeatureSet.java

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@
99
import meka.filters.unsupervised.attribute.MekaClassAttributes;
1010
import weka.core.Attribute;
1111
import weka.core.Instances;
12+
import weka.core.converters.ArffLoader;
1213
import weka.filters.Filter;
1314

15+
import java.io.File;
16+
import java.io.IOException;
1417
import java.util.*;
18+
import java.util.stream.Collectors;
1519

1620
public class MekaFeatureSet extends FeatureSet implements IFeatureSet {
1721

@@ -26,14 +30,34 @@ public void createFeatures() {
2630

2731
List<FeatureSet.Type> featureSets = initializeFeatures();
2832

33+
Instances trainInstances = null;
34+
Instances structure = null;
35+
2936
//Create and set attributes for the train instances
30-
ArrayList<Attribute> trainAttributes = createAttributes(getCategories(options.getAllClasses()), dataset.getTrainMethods(), featureSets);
31-
Instances trainInstances = createInstances(featureSets, trainAttributes, dataset.getTrainMethods(), getCategories(options.getAllClasses()), "train-instances");
32-
this.instances.put("train", convertToMekaInstances(trainInstances));
37+
if (options.getInstances().isEmpty()) {
38+
ArrayList<Attribute> trainAttributes = createAttributes(getCategories(options.getAllClasses()), dataset.getTrainMethods(), featureSets);
39+
trainInstances = createInstances(featureSets, trainAttributes, dataset.getTrainMethods(), getCategories(options.getAllClasses()));
40+
} else {
41+
ArffLoader loader = new ArffLoader();
42+
43+
try {
44+
loader.setSource(new File(options.getInstances().get(0)));
45+
trainInstances = loader.getDataSet();
46+
structure = loader.getStructure();
47+
} catch (IOException e) {
48+
e.printStackTrace();
49+
}
50+
}
3351

3452
//Create and set attributes for the test instances.
53+
Attribute idAttr = new Attribute("id", dataset.getTestMethods().stream().map(Method::getArffSafeSignature).collect(Collectors.toList()));
54+
structure.replaceAttributeAt(idAttr, structure.attribute("id").index());
55+
ArrayList<Attribute> aList = Collections.list(structure.enumerateAttributes());
56+
3557
ArrayList<Attribute> testAttributes = createAttributes(getCategories(options.getAllClasses()), dataset.getTestMethods(), featureSets);
36-
Instances testInstances = createInstances(featureSets, testAttributes, dataset.getTestMethods(), getCategories(options.getAllClasses()), "test-instances");
58+
Instances testInstances = (createInstances(structure, featureSets, aList, dataset.getTestMethods(), getCategories(options.getAllClasses())));
59+
60+
this.instances.put("train", convertToMekaInstances(trainInstances));
3761
this.instances.put("test", convertToMekaInstances(testInstances));
3862
}
3963

@@ -67,7 +91,7 @@ public Instances convertToMekaInstances(Instances instances) {
6791
filter.setAttributeIndices("1-11");
6892
filter.setInputFormat(instances);
6993
output = Filter.useFilter(instances, filter);
70-
output.setRelationName(instances.relationName() + ":" + output.relationName());
94+
output.setRelationName("swan-srm:" + output.relationName());
7195

7296
Util.exportInstancesToArff(output);
7397
} catch (Exception e) {
@@ -85,5 +109,4 @@ public HashSet<Category> getCategories(List<String> cat) {
85109

86110
return categories;
87111
}
88-
89112
}

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/features/WekaFeatureSet.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public void createFeatures() {
3434
ArrayList<Attribute> trainAttributes = createAttributes(category, dataset.getTrainMethods(), featureSets);
3535

3636
String instanceName = category.getId().toLowerCase() + "-train-instances";
37-
Instances trainInstances = createInstances(featureSets, trainAttributes, dataset.getTrainMethods(), Collections.singleton(category), instanceName);
37+
Instances trainInstances = createInstances(featureSets, trainAttributes, dataset.getTrainMethods(), Collections.singleton(category));
3838
this.instances.put(category.getId().toLowerCase(), trainInstances);
3939
Util.exportInstancesToArff(trainInstances);
4040

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/io/dataset/SrmList.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ public SrmList(String sourceFileDir) {
3131
methods = new HashSet<>();
3232
}
3333

34-
3534
public void load(final Set<Method> trainingSet) {
3635

3736
Util.createSubclassAnnotations(methods, "classpath");

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/io/dataset/SrmListUtils.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,9 @@ public static SrmList importFile(String file) throws IOException {
4343
*/
4444
public static void exportFile(SrmList srmList, String file) throws IOException {
4545

46-
srmList.removeUnclassifiedMethods();
4746
ObjectMapper objectMapper = new ObjectMapper();
4847
objectMapper.writeValue(new File(file), srmList);
49-
logger.info("{} methods exported to {}", srmList.getMethods().size(), file);
48+
logger.info("{} SRMs exported to {}", srmList.getMethods().size(), file);
5049
}
5150

5251
public static void removeUndocumentedMethods(SrmList list) {

0 commit comments

Comments
 (0)