 package de.fraunhofer.iem.swan.model.toolkit;
 
 import ai.libs.jaicore.ml.classification.loss.dataset.EClassificationPerformanceMeasure;
-import ai.libs.jaicore.ml.core.dataset.schema.attribute.IntBasedCategoricalAttribute;
 import ai.libs.jaicore.ml.core.dataset.serialization.ArffDatasetAdapter;
+import ai.libs.jaicore.ml.core.evaluation.evaluator.SupervisedLearnerExecutor;
 import ai.libs.jaicore.ml.core.filter.SplitterUtil;
 import ai.libs.jaicore.ml.weka.classification.learner.IWekaClassifier;
 import ai.libs.mlplan.weka.MLPlanWekaBuilder;
-import de.fraunhofer.iem.swan.model.MonteCarloValidator;
+import de.fraunhofer.iem.swan.cli.SwanOptions;
+import de.fraunhofer.iem.swan.features.WekaFeatureSet;
+import de.fraunhofer.iem.swan.model.ModelEvaluator;
 import de.fraunhofer.iem.swan.util.Util;
-import org.api4.java.ai.ml.core.dataset.schema.attribute.IAttribute;
+import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification;
 import org.api4.java.ai.ml.core.dataset.serialization.DatasetDeserializationFailedException;
 import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
 import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
+import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
+import org.api4.java.ai.ml.core.evaluation.execution.ILearnerRunReport;
+import org.api4.java.ai.ml.core.evaluation.execution.LearnerExecutionFailedException;
 import org.api4.java.algorithm.Timeout;
 import org.api4.java.algorithm.exceptions.AlgorithmException;
 import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
 import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import weka.classifiers.Classifier;
-import weka.core.Instances;
-import weka.core.converters.ArffLoader;
 
 import java.io.File;
 import java.io.IOException;
@@ -37,215 +40,95 @@ public class MLPlan {
 
     private static final Logger LOGGER = LoggerFactory.getLogger(MLPlan.class);
     private final int ITERATIONS = 1;
+    private WekaFeatureSet featureSet;
+    private SwanOptions swanOptions;
+    long start;
 
-    public MLPlan() {
-
+    public MLPlan(WekaFeatureSet features, SwanOptions options) {
+        this.featureSet = features;
+        swanOptions = options;
     }
 
     /**
-     * Run ML-Plan using the provided path to the ARFF file.
+     * Trains and evaluates the model with the given training data and the specified classification mode.
      *
-     * @param instances1 file path for ARFF file
+     * @return HashMap containing the name of the classifier and its F-measure
      */
-    public HashMap<String, ArrayList<Double>> evaluateDataset(Instances instances1) {
-
-        String arffFilePath = Util.exportInstancesToArff(instances1, "mlplan");
-        ArffDatasetAdapter arffDatasetAdapter = new ArffDatasetAdapter();
+    public HashMap<String, HashMap<String, String>> trainModel() {
 
-        String mClass = Util.getClassName(instances1);
+        switch (ModelEvaluator.Phase.valueOf(swanOptions.getPhase().toUpperCase())) {
+            case VALIDATE:
 
-        long start = System.currentTimeMillis();
-
-        //Initialize dataset using ARFF file path
-        ILabeledDataset<?> dataset = null;
-        try {
-            dataset = arffDatasetAdapter.readDataset(new File(arffFilePath));
-        } catch (DatasetDeserializationFailedException e) {
-            e.printStackTrace();
+                evaluateData(Util.exportInstancesToArff(featureSet.getTrainInstances().get("sanitizer"), "mlplan"));
+                return null;
+            case PREDICT:
         }
-
-        //dataset.removeColumn("id");
-
-        ArrayList<Double> fScores = new ArrayList<>();
-        ArrayList<String> algorithms = new ArrayList<>();
-
-        MonteCarloValidator monteCarloValidator = new MonteCarloValidator();
-
-        ArffLoader loader = new ArffLoader();
-        try {
-            loader.setFile(new File(arffFilePath));
-            Instances instances = loader.getDataSet();
-            instances.setClassIndex(instances.numAttributes() - 1);
-            // monteCarloValidator.initializeResultSet(instances);
-        } catch (IOException e) {
-            e.printStackTrace();
-        }
-
-
-        //For each iteration, create a new train-test-split and run ML-Plan
-        for (int iteration = 0; iteration < ITERATIONS; iteration++) {
-
-            System.out.println("Iteration #" + iteration);
-            try {
-                List<ILabeledDataset<?>> split = SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, new Random(1337 + (iteration * 11)), 0.7);
-                LOGGER.info("Data read. Time to create dataset object was {}ms", System.currentTimeMillis() - start);
-
-                Classifier optimizedClassifier = getClassifier(split.get(0));
-                //System.out.println("Classify: " + optimizedClassifier.getClassifier().getClass().getSimpleName());
-
-                //optimizedClassifier.fit(split.get(0));
-
-                String trainPath = "swan/swan_core/swan-out/mlplan/train-methods-dataset.arff";
-                arffDatasetAdapter.serializeDataset(new File(trainPath), split.get(0));
-                ArffLoader trainLoader = new ArffLoader();
-                trainLoader.setFile(new File(trainPath));
-                Instances trainInstances = trainLoader.getDataSet();
-                trainInstances.setClassIndex(trainInstances.numAttributes() - 1);
-
-                String testPath = "swan/swan_core/swan-out/mlplan/test-methods-dataset.arff";
-                arffDatasetAdapter.serializeDataset(new File(testPath), split.get(1));
-                ArffLoader testLoader = new ArffLoader();
-                testLoader.setFile(new File(testPath));
-                Instances testInstances = testLoader.getDataSet();
-                testInstances.setClassIndex(testInstances.numAttributes() - 1);
-
-                monteCarloValidator.evaluate(optimizedClassifier, trainInstances, testInstances);
-
-
-                /* evaluate solution produced by mlplan */
-                /* SupervisedLearnerExecutor executor = new SupervisedLearnerExecutor();
-                ILearnerRunReport report = executor.execute(optimizedClassifier, split.get(0), split.get(1));
-
-                for (Object pred : report.getPredictionDiffList().getPredictionsAsList()) {
-
-                    SingleLabelClassification cl = (SingleLabelClassification) pred;
-
-                }
-
-                for (Object prediction : report.getPredictionDiffList().getPredictionsAsList()) {
-
-                    SingleLabelClassification label = (SingleLabelClassification) prediction;
-                }
-
-
-
-
-                LOGGER.info("Model selected: {},{},{},{}", mClass, iteration,
-                        optimizedClassifier.getClassifier().getClass().getSimpleName(),
-                        EClassificationPerformanceMeasure.F1_WITH_1_POSITIVE.F1_WITH_1_POSITIVE.loss(report.getPredictionDiffList().getCastedView(Integer.class, ISingleLabelClassification.class)));
-                */
-                //fScores.add(EClassificationPerformanceMeasure.F1_WITH_1_POSITIVE.loss(report.getPredictionDiffList().getCastedView(Integer.class, ISingleLabelClassification.class)));
-                //algorithms.add(optimizedClassifier.getClassifier().getClass().getSimpleName());
-                //LOGGER.info("Error Rate of the solution produced by ML-Plan: {}. ", );
-
-            } catch (SplitFailedException | InterruptedException | IOException e) {
-                e.printStackTrace();
-            }
-        }
-        return monteCarloValidator.getFMeasure();
+        return null;
     }
 
-    public void evaluateDataset(Instances instances, int k) {
+    public HashMap<String, ArrayList<Double>> evaluateData(String arffFilePath) {
 
-        //arffFilePath = "swan/swan_core/src/main/resources/waveform.arff";
-        String arffFilePath = Util.exportInstancesToArff(instances, "mlplan");
-        ArffDatasetAdapter arffDatasetAdapter = new ArffDatasetAdapter();
+        start = System.currentTimeMillis();
 
-        String mClass = Util.getClassName(instances);
-
-        long start = System.currentTimeMillis();
-
-        //Initialize dataset using ARFF file path
-        ILabeledDataset<?> dataset = null;
         try {
-            dataset = arffDatasetAdapter.readDataset(new File(arffFilePath));
-        } catch (DatasetDeserializationFailedException e) {
-            e.printStackTrace();
-        }
-
-        //dataset.removeColumn("id");
-
-        MonteCarloValidator monteCarloValidator = new MonteCarloValidator();
 
-        //For each iteration, create a new train-test-split and run ML-Plan
-        for (int iteration = 0; iteration < ITERATIONS; iteration++) {
+            ArffDatasetAdapter arffDatasetAdapter = new ArffDatasetAdapter();
+            ILabeledDataset<ILabeledInstance> dataset = arffDatasetAdapter.readDataset(new File(arffFilePath));
 
-            try {
-                List<ILabeledDataset<?>> split = SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, new Random(1337 + (iteration * 11)), 0.7);
-                LOGGER.info("Data read. Time to create dataset object was {}ms", System.currentTimeMillis() - start);
+            List<ILabeledDataset<?>> split = SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, new Random(42), .7);
+            LOGGER.info("Data read. Time to create dataset object was {}ms", System.currentTimeMillis() - start);
 
-                System.out.println(split.get(1).getLabelAttribute().getName());
-                for (IAttribute attribute : split.get(1).getListOfAttributes()) {
+            getClassifier(split);
 
-                    // System.out.println(attribute.getName());
-                }
-                arffDatasetAdapter.serializeDataset(new File("swan/swan_core/swan-out/mlplan/methods-dataset.arff"), split.get(1));
-
-
-                for (int x = 0; x < split.get(1).size(); x++) {
-
-                    int attributeIndex = split.get(1).getNumAttributes() - 1;
-                    //System.out.println(Arrays.toString(split.get(1).get(x).getAttributes()));
-
-                    IAttribute attribute = split.get(1).getAttribute(attributeIndex);
-
-                    //System.out.println(dataset.getLabelVector().);
-                    System.out.println(((IntBasedCategoricalAttribute) split.get(1).getAttribute(attributeIndex)).getLabelOfCategory((int) split.get(1).get(x).getAttributeValue(attributeIndex)));
-                    System.out.println(((IntBasedCategoricalAttribute) split.get(1).getLabelAttribute()).getLabelOfCategory((int) split.get(1).get(x).getLabel()));
-                    // System.out.println(split.get(1).getAttribute());
-                    System.out.println(split.get(1).get(x).getAttributeValue(split.get(1).getNumAttributes() - 2) + " " + split.get(1).get(x).getAttributeValue(split.get(1).getNumAttributes() - 1));
-                }
-            } catch (SplitFailedException | InterruptedException | IOException e) {
-                e.printStackTrace();
-            }
+        } catch (InterruptedException | DatasetDeserializationFailedException | SplitFailedException e) {
+            throw new RuntimeException(e);
         }
+        return null;
     }
 
     /**
-     * SReturns trained clssifier
+     * Returns the trained classifier
      *
      * @param trainingSet training set
      * @return trained classifier
      */
-    public Classifier getClassifier(ILabeledDataset<?> trainingSet) {
+    public Classifier getClassifier(List<ILabeledDataset<?>> trainingSet) {
 
-        Classifier optimizedClassifier = null;
-        /* initialize mlplan with a tiny search space, and let it run for 30 seconds */
+        IWekaClassifier optimizedClassifier = null;
 
         try {
+            /* initialize ML-Plan and let it run for 60 seconds */
             ai.libs.mlplan.core.MLPlan<IWekaClassifier> mlPlan = new MLPlanWekaBuilder()
-                    .withNumCpus(12)//Set to about 12 on the server
+                    .withNumCpus(4)//Set to about 12 on the server
                     .withSeed(35467463)
                     //set default timeout
-                    .withTimeOut(new Timeout(30, TimeUnit.SECONDS))
-                    .withDataset(trainingSet)
-                    .withCandidateEvaluationTimeOut(new Timeout(30, TimeUnit.SECONDS))
+                    .withTimeOut(new Timeout(60, TimeUnit.SECONDS))
+                    .withDataset(trainingSet.get(0))
+                    /* .withCandidateEvaluationTimeOut(new Timeout(5, TimeUnit.SECONDS))
                     .withPortionOfDataReservedForSelection(0.0)//ignore selection phase
                     .withPerformanceMeasureForSearchPhase(EClassificationPerformanceMeasure.F1_WITH_1_POSITIVE)//use F1
-                    .withMCCVBasedCandidateEvaluationInSearchPhase(5, .7)
+                    .withMCCVBasedCandidateEvaluationInSearchPhase(1, .7)*/
                     .build();
+            mlPlan.setLoggerName("mlplan-swan");
 
-            mlPlan.setLoggerName("testedalgorithm");
-
-            long start = System.currentTimeMillis();
-
-            optimizedClassifier = mlPlan.call().getClassifier();
+            optimizedClassifier = mlPlan.call();
 
             long trainTime = (int) (System.currentTimeMillis() - start) / 1000;
             LOGGER.info("Finished build of the classifier. Training time was {}s.", trainTime);
-            LOGGER.info("Internally believed error was {}", mlPlan.getInternalValidationErrorOfSelectedClassifier());
+            LOGGER.info("Chosen model is: {}", mlPlan.getSelectedClassifier());
+
+            /* evaluate solution produced by mlplan */
+            SupervisedLearnerExecutor executor = new SupervisedLearnerExecutor();
+            ILearnerRunReport report = executor.execute(optimizedClassifier, trainingSet.get(1));
+            LOGGER.info("F-measure for ML-Plan Solution: {}",
+                    EClassificationPerformanceMeasure.F1_WITH_1_POSITIVE.loss(report.getPredictionDiffList().getCastedView(Integer.class, ISingleLabelClassification.class)));
 
-        } catch (IOException | AlgorithmTimeoutedException | InterruptedException | AlgorithmException | AlgorithmExecutionCanceledException e) {
+        } catch (IOException | AlgorithmTimeoutedException | InterruptedException | AlgorithmException |
+                 AlgorithmExecutionCanceledException e) {
            e.printStackTrace();
+        } catch (LearnerExecutionFailedException e) {
+            throw new RuntimeException(e);
         }
-        return optimizedClassifier;
-    }
-
-    public static void maihn(String[] args) {
-
-        String file = "swan/swan_core/src/main/resources/waveform.arff";
-
-        MLPlan mlPlan = new MLPlan();
-        // mlPlan.evaluateDataset(file, "sdfs");
+        return optimizedClassifier.getClassifier();
     }
-}
+}
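
A minimal usage sketch of the refactored entry point (not part of this commit; the wrapper class and method names below are illustrative), assuming the surrounding SWAN pipeline supplies a populated WekaFeatureSet and parsed SwanOptions with the phase set to "validate":

import de.fraunhofer.iem.swan.cli.SwanOptions;
import de.fraunhofer.iem.swan.features.WekaFeatureSet;
import de.fraunhofer.iem.swan.model.toolkit.MLPlan;

public class MLPlanUsageSketch {

    // Hypothetical wiring: featureSet and options are assumed to come from the
    // feature-extraction and CLI-parsing steps of the SWAN pipeline.
    static void runValidation(WekaFeatureSet featureSet, SwanOptions options) {
        MLPlan mlPlan = new MLPlan(featureSet, options);
        // In the VALIDATE phase this exports the "sanitizer" training instances to ARFF,
        // runs ML-Plan on a label-stratified 70/30 split, and logs the chosen model and
        // its F-measure; the return value is still null at this stage of the refactoring.
        mlPlan.trainModel();
    }
}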