88import de .fraunhofer .iem .swan .io .dataset .SrmListUtils ;
99import de .fraunhofer .iem .swan .model .ModelEvaluator ;
1010import de .fraunhofer .iem .swan .util .Util ;
11- import org .apache .commons .lang3 .StringUtils ;
1211import org .slf4j .Logger ;
1312import org .slf4j .LoggerFactory ;
14-
1513import java .io .File ;
1614import java .io .IOException ;
17- import java .util .*;
1815
1916/**
2017 * Runner for SWAN
2421
2522public class SwanPipeline {
2623
27- private Learner learner ;
28- private Loader loader ;
29- private Parser parser ;
30- private FeatureHandler featureHandler ;
31- private Writer writer ;
32- // Configuration tags for debugging
33-
3424 private static final Logger logger = LoggerFactory .getLogger (SwanPipeline .class );
35- private DocFeatureHandler docFeatureHandler ;
36- public static HashMap <String , HashSet <String >> predictions ;
37-
3825 public static SwanOptions options ;
3926
4027 public SwanPipeline (SwanOptions options ) {
4128 SwanPipeline .options = options ;
4229 }
4330
4431 /**
45- * Executes the analysis and can also be called from outside by lients .
32+ * Executes the analysis and can also be called from outside by clients .
4633 *
47- * @throws IOException In case an error occurs during the preparation
48- * or execution of the analysis.
34+ * @throws IOException In case an error occurs during the preparation
35+ * or execution of the analysis.
4936 */
5037 public void run () throws IOException , InterruptedException {
5138
@@ -64,85 +51,17 @@ public void run() throws IOException, InterruptedException {
6451 FeaturesHandler featuresHandler = new FeaturesHandler (dataset , testDataset , options );
6552 featuresHandler .createFeatures ();
6653
67- // Cache the methods from the second test set.
68- loader .pruneNone ();
69-
70- /*
71- SECOND PHASE - binary classification for each of the CWE categories.
72- (1) Classify: cwe78, cwe079, cwe089, cwe306, cwe601, cwe862, cwe863
73- */
74- runClassEvaluation (options .getCweClasses (), feature , learnerMode );
54+ //Train and evaluate model for SRM and CWE categories
55+ ModelEvaluator modelEvaluator = new ModelEvaluator (featuresHandler , options .getLearningMode (), options .getIterations (), options .getTrainTestSplit ());
56+ modelEvaluator .trainModel ();
7557
58+ //TODO export final list to JSON file
7659 String outputFile = options .getOutputDir () + File .separator + "swan-srm-cwe-list.json" ;
7760 ObjectMapper objectMapper = new ObjectMapper ();
7861 objectMapper .writeValue (new File (outputFile ), dataset );
7962 logger .info ("SRM/CWE list exported to {}" , outputFile );
8063
8164 long analysisTime = System .currentTimeMillis () - startAnalysisTime ;
82- logger .info ("Total runtime {} mins" , analysisTime / 60000 );
83- }
84-
85- public void runClassEvaluation (List <String > classes , InstancesHandler .FeatureSet featureSet , Learner .Mode learnerMode ) {
86-
87- for (String cat : classes ) {
88-
89- HashSet <Category > categories ;
90-
91- if (cat .contentEquals ("authentication" ))
92- categories = new HashSet <>(Arrays .asList (Category .AUTHENTICATION_TO_HIGH ,
93- Category .AUTHENTICATION_TO_LOW , Category .AUTHENTICATION_NEUTRAL , Category .NONE ));
94- else
95- categories = new HashSet <>(Arrays .asList (Category .fromText (cat ), Category .NONE ));
96-
97- runClassifier (categories , Learner .EVAL_MODE .CLASS , featureSet , learnerMode );
98- }
99- }
100-
101-
102- private double runClassifier (HashSet <Category > categories , Learner .EVAL_MODE eval_mode , InstancesHandler .FeatureSet featureSet , Learner .Mode learnerMode ) {
103- parser .resetMethods ();
104- loader .resetMethods ();
105-
106- logger .info ("Starting classification for {}" , categories .toString ());
107- long startAnalysisTime ;
108-
109- if (categories .stream ().anyMatch (Category ::isCwe )) {
110-
111- ArrayList <InstancesHandler > instancesHandlers = new ArrayList <>();
112- for (String iteration : predictions .keySet ()) {
113-
114- if (predictions .get (iteration ).size () == 0 )
115- continue ;
116-
117- HashSet <Method > methods = new HashSet <>();
118-
119- for (Method method : parser .getMethods ()) {
120- if (predictions .get (iteration ).contains (method .getArffSafeSignature ()))
121- methods .add (method );
122- }
123-
124- InstancesHandler instancesHandler = new InstancesHandler ();
125- instancesHandler .createInstances (methods , featureHandler .features (), docFeatureHandler , categories , featureSet );
126-
127- instancesHandlers .add (instancesHandler );
128- }
129-
130- startAnalysisTime = System .currentTimeMillis ();
131- learner .trainModel (instancesHandlers , learnerMode );
132-
133- } else {
134- InstancesHandler instancesHandler = new InstancesHandler ();
135- instancesHandler .createInstances (parser .getMethods (), featureHandler .features (), docFeatureHandler , categories , featureSet );
136- startAnalysisTime = System .currentTimeMillis ();
137-
138- ArrayList <InstancesHandler > instancesHandlers = new ArrayList <>();
139- instancesHandlers .add (instancesHandler );
140- learner .trainModel (instancesHandlers , learnerMode );
141- }
142-
143- long analysisTime = System .currentTimeMillis () - startAnalysisTime ;
144- logger .info ("Total time for classification {}ms" , analysisTime );
145-
146- return 0.0 ;
65+ logger .info ("Total runtime {} minutes" , analysisTime / 60000 );
14766 }
14867}
0 commit comments