11package de .fraunhofer .iem .swan .model .toolkit ;
22
33import de .fraunhofer .iem .swan .cli .SwanOptions ;
4+ import de .fraunhofer .iem .swan .data .Category ;
5+ import de .fraunhofer .iem .swan .data .Method ;
46import de .fraunhofer .iem .swan .features .WekaFeatureSet ;
7+ import de .fraunhofer .iem .swan .io .dataset .SrmList ;
58import de .fraunhofer .iem .swan .model .ModelEvaluator ;
69import de .fraunhofer .iem .swan .model .MonteCarloValidator ;
10+ import javafx .util .Pair ;
711import org .slf4j .Logger ;
812import org .slf4j .LoggerFactory ;
13+ import weka .classifiers .AbstractClassifier ;
914import weka .classifiers .Classifier ;
15+ import weka .classifiers .Evaluation ;
1016import weka .classifiers .bayes .BayesNet ;
1117import weka .classifiers .bayes .NaiveBayes ;
1218import weka .classifiers .functions .Logistic ;
1319import weka .classifiers .functions .SMO ;
1420import weka .classifiers .rules .JRip ;
1521import weka .classifiers .trees .DecisionStump ;
1622import weka .classifiers .trees .J48 ;
23+ import weka .core .Instance ;
1724import weka .core .Instances ;
1825import weka .filters .Filter ;
1926import weka .filters .MultiFilter ;
2027
21- import java .util .ArrayList ;
22- import java .util .HashMap ;
23- import java .util .LinkedHashMap ;
24- import java .util .List ;
28+ import java .text .DecimalFormat ;
29+ import java .util .*;
2530
2631/**
2732 * Finds possible sources and sinks in a given set of system methods using a
@@ -33,45 +38,104 @@ public class Weka {
3338
3439 private WekaFeatureSet features ;
3540 private SwanOptions options ;
36- private static final Logger logger = LoggerFactory .getLogger (ModelEvaluator .class );
41+ private Set <Method > methods ;
42+ private static final Logger logger = LoggerFactory .getLogger (Weka .class );
43+ private HashMap <String , ArrayList <Category >> predictions ;
44+ private HashMap <String , HashMap <String , ArrayList <Double >>> results ;
45+ private DecimalFormat df = new DecimalFormat ("####0.00" );
3746
/**
 * Creates the WEKA toolkit wrapper.
 *
 * @param features feature set providing the instances and the dataset
 * @param options  run configuration
 * @param methods  methods stored for the prediction phase
 */
public Weka(WekaFeatureSet features, SwanOptions options, Set<Method> methods) {
    this.methods = methods;
    this.options = options;
    this.features = features;
    this.results = new HashMap<>();
    this.predictions = new HashMap<>();

    // Prediction slots are only needed when the run is in the predict phase.
    if (!options.isPredictPhase()) {
        return;
    }

    // Start every test method with an empty category list, keyed by its ARFF-safe signature.
    for (Method testMethod : features.getDataset().getTestMethods()) {
        this.predictions.put(testMethod.getArffSafeSignature(), new ArrayList<>());
    }
}
4259
4360 /**
4461 * Trains and evaluates the model with the given training data and specified classification mode.
4562 *
4663 * @return Hashmap containing the name of the classifier and it's F-Measure
4764 */
48- public HashMap < String , HashMap < String , String >> trainModel () {
65+ public SrmList trainModel () {
4966
5067 switch (ModelEvaluator .Phase .valueOf (options .getPhase ().toUpperCase ())) {
5168 case VALIDATE :
5269
5370 //Phase 1: classify SRM classes
5471 logger .info ("Performing {}-fold cross-validation for {} using WEKA" , options .getIterations (), options .getSrmClasses ());
5572 for (String srm : options .getSrmClasses ())
56- runManualEvaluation (features .getTrainInstances ().get (srm ));
73+ crossValidateModel (features .getTrainInstances ().get (srm ));
5774
5875 //Filter methods from CWE instances that were not classified into one of the SRM classes
5976 //Phase 2: classify CWE classes
6077 logger .info ("Performing {}-fold cross-validation for {} using WEKA" , options .getIterations (), options .getCweClasses ());
6178 for (String cwe : options .getCweClasses ())
62- runManualEvaluation (features .getTrainInstances ().get (cwe ));
79+ crossValidateModel (features .getTrainInstances ().get (cwe ));
80+
81+ // TreeMap to store values of HashMap
82+ TreeMap <String , HashMap <String , ArrayList <Double >>> sorted
83+ = new TreeMap <>(results );
6384
64- return null ;
85+ // Display the TreeMap which is naturally sorted
86+ for (Map .Entry <String , HashMap <String , ArrayList <Double >>> entry :
87+ sorted .entrySet ())
88+
89+ return null ;
6590 case PREDICT :
6691
92+ logger .info ("Predicting {} for TEST dataset using WEKA" , options .getSrmClasses ());
93+ for (String srm : options .getSrmClasses ()) {
94+ predictModel (srm );
95+ }
96+
97+ for (Method method : methods ) {
98+ for (Category category : predictions .get (method .getArffSafeSignature ())) {
99+ method .addCategory (category );
100+ }
101+ }
102+ return new SrmList (methods );
67103 }
68104 return null ;
69105 }
70106
107+ public void predictModel (String srm ) {
108+ Pair <String , Double > bestClassifier = crossValidateModel (features .getTrainInstances ().get (srm ));
109+
110+ try {
111+
112+ Classifier classifier = AbstractClassifier .forName (bestClassifier .getKey (), null );
113+ classifier .buildClassifier (features .getTrainInstances ().get (srm ));
114+
115+ //NaiveBayes classifier = new NaiveBayes();
116+ //classifier.buildClassifier(features.getTrainInstances().get(srm));
117+
118+ Evaluation eval = new Evaluation (features .getTestInstances ().get (srm ));
119+ eval .evaluateModel (classifier , features .getTestInstances ().get (srm ));
120+
121+ for (Instance instance : features .getTestInstances ().get (srm )) {
122+
123+ double prediction = eval .evaluateModelOnce (classifier , instance );
124+ if (prediction > 0 ) {
125+ predictions .get (instance .stringValue (features .getTestInstances ().get (srm ).attribute ("id" ).index ()))
126+ .add (Category .valueOf (srm .toUpperCase ()));
127+ }
128+
129+ }
130+ } catch (Exception e ) {
131+ e .printStackTrace ();
132+ }
133+ }
134+
71135 /**
72136 * @return
73137 */
74- public HashMap <String , HashMap < String , String >> runManualEvaluation (Instances instances ) {
138+ public Pair <String , Double > crossValidateModel (Instances instances ) {
75139
76140 String category = instances .attribute (instances .numAttributes () - 1 ).name ();
77141 instances .setClass (instances .attribute (instances .numAttributes () - 1 ));
@@ -87,19 +151,47 @@ public HashMap<String, HashMap<String, String>> runManualEvaluation(Instances in
87151 classifiers .add (new DecisionStump ());
88152 classifiers .add (new Logistic ());
89153
154+ Pair <String , Double > bestClassifier = new Pair <>("" , 0.0 );
155+ List <Pair > classifierSummary = new ArrayList <>();
90156 //For each classifier, evaluate its performance on the instances
157+
158+ HashMap <String , ArrayList <Double >> measure = new HashMap <>();
91159 for (Classifier classifier : classifiers ) {
92160
93161 MonteCarloValidator evaluator = new MonteCarloValidator ();
94162 evaluator .monteCarloValidate (instances , classifier , options .getTrainTestSplit (), options .getIterations ());
95163
96164 for (String key : evaluator .getFMeasure ().keySet ()) {
97165
98- logger .info ("Average F-measure for {}({}) using {}: {}, {}" , category , key , classifier .getClass ().getSimpleName (),
99- evaluator .getFMeasure ().get (key ).stream ().mapToDouble (a -> a ).average ().getAsDouble (), evaluator .getFMeasure ().get (key ));
166+ double averageFMeasure = evaluator .getFMeasure ().get (key ).stream ().mapToDouble (a -> a ).average ().getAsDouble ();
167+ Pair summary = new Pair <>(classifier .getClass ().getSimpleName (), Double .parseDouble (df .format (averageFMeasure )));
168+ classifierSummary .add (summary );
169+
170+ if (averageFMeasure > bestClassifier .getValue ()) {
171+ bestClassifier = summary ;
172+ }
173+
174+ if (category .contains ("authentication" )) {
175+
176+ if (!results .containsKey (category + "-" + key )) {
177+ HashMap <String , ArrayList <Double >> m = new HashMap <>();
178+ m .put (classifier .getClass ().getSimpleName (), evaluator .getFMeasure ().get (key ));
179+ results .put (category + "-" + key , m );
180+ } else {
181+ results .get (category + "-" + key ).put (classifier .getClass ().getSimpleName (), evaluator .getFMeasure ().get (key ));
182+ }
183+ } else
184+ measure .put (classifier .getClass ().getSimpleName (), evaluator .getFMeasure ().get (key ));
185+
186+ logger .debug ("Average F-measure for {}({}) using {}: {}, {}" , category , key , classifier .getClass ().getSimpleName (),
187+ averageFMeasure , evaluator .getFMeasure ().get (key ));
100188 }
189+ if (!category .contains ("authentication" ))
190+ results .put (category , measure );
101191 }
102- return fMeasure ;
192+ logger .info ("Selecting {} model for {}, evaluated classifiers={}" , bestClassifier .getKey (), category , classifierSummary );
193+
194+ return bestClassifier ;
103195 }
104196
105197 /**
0 commit comments