1+ package de .fraunhofer .iem .swan .model .engine ;
2+
3+ import de .fraunhofer .iem .swan .cli .SwanOptions ;
4+ import de .fraunhofer .iem .swan .data .Category ;
5+ import de .fraunhofer .iem .swan .data .Method ;
6+ import de .fraunhofer .iem .swan .features .FeaturesHandler ;
7+ import de .fraunhofer .iem .swan .io .dataset .SrmList ;
8+ import de .fraunhofer .iem .swan .io .dataset .SrmListUtils ;
9+ import de .fraunhofer .iem .swan .model .ModelEvaluator ;
10+ import meka .classifiers .multilabel .BR ;
11+ import meka .classifiers .multilabel .Evaluation ;
12+ import meka .core .Result ;
13+ import org .slf4j .Logger ;
14+ import org .slf4j .LoggerFactory ;
15+ import weka .core .Instances ;
16+
17+ import java .io .File ;
18+ import java .io .IOException ;
19+ import java .util .ArrayList ;
20+ import java .util .HashMap ;
21+ import java .util .Set ;
22+
23+ public class Meka {
24+
25+ private FeaturesHandler features ;
26+ private SwanOptions options ;
27+ private Set <Method > methods ;
28+ private static final Logger logger = LoggerFactory .getLogger (ModelEvaluator .class );
29+
30+ public Meka (FeaturesHandler features , SwanOptions options , Set <Method > methods ) {
31+ this .features = features ;
32+ this .options = options ;
33+ this .methods = methods ;
34+ }
35+
36+ /**
37+ * Trains and evaluates the model with the given training data and specified classification mode.
38+ *
39+ */
40+ public void trainModel () {
41+
42+ switch (ModelEvaluator .Phase .valueOf (options .getPhase ().toUpperCase ())) {
43+ case VALIDATE :
44+ crossValidateModel (features .getTrainInstances ());
45+ break ;
46+ case PREDICT :
47+ HashMap <String , ArrayList <Category >> predictions = predictModel (features .getTrainInstances (), features .getTestInstances (), options .getPredictionThreshold ());
48+
49+ for (Method method : methods ) {
50+ for (Category category : predictions .get (method .getArffSafeSignature ())) {
51+ method .addCategory (category );
52+ }
53+ }
54+
55+ try {
56+ SrmListUtils .exportFile (new SrmList (methods ), options .getOutputDir () + File .separator + "swan-srm-cwe-list.json" );
57+ } catch (IOException e ) {
58+ e .printStackTrace ();
59+ }
60+ break ;
61+ }
62+ }
63+
64+ /**
65+ * Cross-validates a ML model using the provided instances and outputs metrics.
66+ *
67+ * @param instances training instances
68+ */
69+ public void crossValidateModel (Instances instances ) {
70+
71+ try {
72+
73+ BR classifier = new BR ();
74+ String top = "PCut1" ;
75+ String verbosity = "7" ;
76+ Result result = Evaluation .cvModel (classifier , instances , options .getIterations (), top , verbosity );
77+
78+ logger .info ("Model cross-validation results {}" , result );
79+ } catch (Exception e ) {
80+ e .printStackTrace ();
81+ }
82+ }
83+
84+ /**
85+ * Trains a model and uses the trained model to predict the categories for the methods in the test set.
86+ *
87+ * @param train training instances
88+ * @param test test instances
89+ * @param threshold threshold used to determine if a method should be classified into a category
90+ * @return hash map of method signatures and the categories they're classified into
91+ */
92+ public HashMap <String , ArrayList <Category >> predictModel (Instances train , Instances test , double threshold ) {
93+
94+ HashMap <String , ArrayList <Category >> predictions = new HashMap <>();
95+ try {
96+ BR classifier = new BR ();
97+ classifier .buildClassifier (train );
98+
99+ for (int i = 0 ; i < test .numInstances (); i ++) {
100+ double [] dist = classifier .distributionForInstance (test .instance (i ));
101+
102+ ArrayList <Category > categories = new ArrayList <>();
103+ for (int p = 0 ; p < dist .length ; p ++) {
104+ if (dist [p ] >= threshold )
105+ categories .add (Category .valueOf (test .attribute (p ).name ().toUpperCase ()));
106+ }
107+ predictions .put (test .get (i ).stringValue (test .attribute ("id" ).index ()), categories );
108+ }
109+
110+ } catch (Exception e ) {
111+ e .printStackTrace ();
112+ }
113+ return predictions ;
114+ }
115+ }
0 commit comments