Skip to content

Commit ab922ad

Browse files
committed
Rename and refactor model evaluation classes
1 parent 62af5ca commit ab922ad

File tree

3 files changed

+116
-417
lines changed

3 files changed

+116
-417
lines changed

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/SwanPipeline.java

Lines changed: 8 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -8,13 +8,10 @@
88
import de.fraunhofer.iem.swan.io.dataset.SrmListUtils;
99
import de.fraunhofer.iem.swan.model.ModelEvaluator;
1010
import de.fraunhofer.iem.swan.util.Util;
11-
import org.apache.commons.lang3.StringUtils;
1211
import org.slf4j.Logger;
1312
import org.slf4j.LoggerFactory;
14-
1513
import java.io.File;
1614
import java.io.IOException;
17-
import java.util.*;
1815

1916
/**
2017
* Runner for SWAN
@@ -24,28 +21,18 @@
2421

2522
public class SwanPipeline {
2623

27-
private Learner learner;
28-
private Loader loader;
29-
private Parser parser;
30-
private FeatureHandler featureHandler;
31-
private Writer writer;
32-
// Configuration tags for debugging
33-
3424
private static final Logger logger = LoggerFactory.getLogger(SwanPipeline.class);
35-
private DocFeatureHandler docFeatureHandler;
36-
public static HashMap<String, HashSet<String>> predictions;
37-
3825
public static SwanOptions options;
3926

4027
public SwanPipeline(SwanOptions options) {
4128
SwanPipeline.options = options;
4229
}
4330

4431
/**
45-
* Executes the analysis and can also be called from outside by lients.
32+
* Executes the analysis and can also be called from outside by clients.
4633
*
47-
* @throws IOException In case an error occurs during the preparation
48-
* or execution of the analysis.
34+
* @throws IOException In case an error occurs during the preparation
35+
* or execution of the analysis.
4936
*/
5037
public void run() throws IOException, InterruptedException {
5138

@@ -64,85 +51,17 @@ public void run() throws IOException, InterruptedException {
6451
FeaturesHandler featuresHandler = new FeaturesHandler(dataset, testDataset, options);
6552
featuresHandler.createFeatures();
6653

67-
// Cache the methods from the second test set.
68-
loader.pruneNone();
69-
70-
/*
71-
SECOND PHASE - binary classification for each of the CWE categories.
72-
(1) Classify: cwe78, cwe079, cwe089, cwe306, cwe601, cwe862, cwe863
73-
*/
74-
runClassEvaluation(options.getCweClasses(), feature, learnerMode);
54+
//Train and evaluate model for SRM and CWE categories
55+
ModelEvaluator modelEvaluator = new ModelEvaluator(featuresHandler, options.getLearningMode(), options.getIterations(), options.getTrainTestSplit());
56+
modelEvaluator.trainModel();
7557

58+
//TODO export final list to JSON file
7659
String outputFile = options.getOutputDir() + File.separator + "swan-srm-cwe-list.json";
7760
ObjectMapper objectMapper = new ObjectMapper();
7861
objectMapper.writeValue(new File(outputFile), dataset);
7962
logger.info("SRM/CWE list exported to {}", outputFile);
8063

8164
long analysisTime = System.currentTimeMillis() - startAnalysisTime;
82-
logger.info("Total runtime {} mins", analysisTime / 60000);
83-
}
84-
85-
public void runClassEvaluation(List<String> classes, InstancesHandler.FeatureSet featureSet, Learner.Mode learnerMode) {
86-
87-
for (String cat : classes) {
88-
89-
HashSet<Category> categories;
90-
91-
if (cat.contentEquals("authentication"))
92-
categories = new HashSet<>(Arrays.asList(Category.AUTHENTICATION_TO_HIGH,
93-
Category.AUTHENTICATION_TO_LOW, Category.AUTHENTICATION_NEUTRAL, Category.NONE));
94-
else
95-
categories = new HashSet<>(Arrays.asList(Category.fromText(cat), Category.NONE));
96-
97-
runClassifier(categories, Learner.EVAL_MODE.CLASS, featureSet, learnerMode);
98-
}
99-
}
100-
101-
102-
private double runClassifier(HashSet<Category> categories, Learner.EVAL_MODE eval_mode, InstancesHandler.FeatureSet featureSet, Learner.Mode learnerMode) {
103-
parser.resetMethods();
104-
loader.resetMethods();
105-
106-
logger.info("Starting classification for {}", categories.toString());
107-
long startAnalysisTime;
108-
109-
if (categories.stream().anyMatch(Category::isCwe)) {
110-
111-
ArrayList<InstancesHandler> instancesHandlers = new ArrayList<>();
112-
for (String iteration : predictions.keySet()) {
113-
114-
if (predictions.get(iteration).size() == 0)
115-
continue;
116-
117-
HashSet<Method> methods = new HashSet<>();
118-
119-
for (Method method : parser.getMethods()) {
120-
if (predictions.get(iteration).contains(method.getArffSafeSignature()))
121-
methods.add(method);
122-
}
123-
124-
InstancesHandler instancesHandler = new InstancesHandler();
125-
instancesHandler.createInstances(methods, featureHandler.features(), docFeatureHandler, categories, featureSet);
126-
127-
instancesHandlers.add(instancesHandler);
128-
}
129-
130-
startAnalysisTime = System.currentTimeMillis();
131-
learner.trainModel(instancesHandlers, learnerMode);
132-
133-
} else {
134-
InstancesHandler instancesHandler = new InstancesHandler();
135-
instancesHandler.createInstances(parser.getMethods(), featureHandler.features(), docFeatureHandler, categories, featureSet);
136-
startAnalysisTime = System.currentTimeMillis();
137-
138-
ArrayList<InstancesHandler> instancesHandlers = new ArrayList<>();
139-
instancesHandlers.add(instancesHandler);
140-
learner.trainModel(instancesHandlers, learnerMode);
141-
}
142-
143-
long analysisTime = System.currentTimeMillis() - startAnalysisTime;
144-
logger.info("Total time for classification {}ms", analysisTime);
145-
146-
return 0.0;
65+
logger.info("Total runtime {} minutes", analysisTime / 60000);
14766
}
14867
}

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/Learner.java

Lines changed: 0 additions & 168 deletions
This file was deleted.

0 commit comments

Comments
 (0)