|
1 | 1 | package de.fraunhofer.iem.swan.model; |
2 | 2 |
|
3 | | -import ai.libs.jaicore.logging.LoggerUtil; |
4 | | -import ai.libs.jaicore.ml.weka.dataset.WekaInstances; |
| 3 | +import ai.libs.jaicore.ml.classification.loss.dataset.EClassificationPerformanceMeasure; |
| 4 | +import ai.libs.jaicore.ml.core.dataset.serialization.ArffDatasetAdapter; |
| 5 | +import ai.libs.jaicore.ml.core.evaluation.evaluator.SupervisedLearnerExecutor; |
| 6 | +import ai.libs.jaicore.ml.core.filter.SplitterUtil; |
| 7 | +import ai.libs.jaicore.ml.weka.classification.learner.IWekaClassifier; |
| 8 | +import ai.libs.mlplan.core.MLPlan; |
5 | 9 | import ai.libs.mlplan.multiclass.wekamlplan.MLPlanWekaBuilder; |
6 | | -import de.fraunhofer.iem.swan.util.Util; |
| 10 | +import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification; |
7 | 11 | import org.api4.java.ai.ml.core.dataset.serialization.DatasetDeserializationFailedException; |
8 | 12 | import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException; |
| 13 | +import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset; |
| 14 | +import org.api4.java.ai.ml.core.evaluation.execution.ILearnerRunReport; |
| 15 | +import org.api4.java.ai.ml.core.evaluation.execution.LearnerExecutionFailedException; |
| 16 | +import org.api4.java.algorithm.Timeout; |
9 | 17 | import org.api4.java.algorithm.exceptions.AlgorithmException; |
10 | 18 | import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException; |
11 | 19 | import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException; |
12 | 20 | import org.slf4j.Logger; |
13 | 21 | import org.slf4j.LoggerFactory; |
14 | | -import weka.classifiers.Classifier; |
15 | | -import weka.classifiers.meta.FilteredClassifier; |
16 | 22 | import weka.core.Instances; |
17 | 23 |
|
| 24 | +import java.io.File; |
18 | 25 | import java.io.IOException; |
19 | | -import java.util.HashMap; |
20 | | -import java.util.NoSuchElementException; |
| 26 | +import java.util.List; |
| 27 | +import java.util.Random; |
| 28 | +import java.util.concurrent.TimeUnit; |
21 | 29 |
|
22 | 30 | /** |
23 | 31 | * @author Oshando Johnson on 27.09.20 |
24 | 32 | */ |
25 | 33 | public class MLPlanExecutor { |
26 | 34 |
|
27 | 35 | private static final Logger LOGGER = LoggerFactory.getLogger(MLPlanExecutor.class); |
28 | | - FilteredClassifier filteredClassifier; |
| 36 | + private final int ITERATIONS = 10; |
29 | 37 |
|
30 | 38 | public MLPlanExecutor() { |
31 | | - this.filteredClassifier = new FilteredClassifier(); |
| 39 | + |
32 | 40 | } |
33 | 41 |
|
34 | 42 | /** |
35 | | - * Runs ML-Plan with the given instance set. |
36 | | - * @param instances training instances |
37 | | - * @return hashmap containing F1 measures |
| 43 | + * Run ML-Plan using the provided path to the ARFF file. |
| 44 | + * |
| 45 | + * @param arffFilePath file path for ARFF file |
38 | 46 | */ |
39 | | - public HashMap<String, String> run(Instances instances) { |
| 47 | + public void evaluateDataset(String arffFilePath) { |
40 | 48 |
|
41 | 49 | long start = System.currentTimeMillis(); |
42 | 50 |
|
43 | | - instances.setClassIndex(instances.numAttributes() - 1); |
44 | | - |
45 | | - HashMap<String, String> results = new HashMap<>(); |
46 | | - |
47 | | - //Find classifier using ML-Plan |
| 51 | + //Initialize dataset using ARFF file path |
| 52 | + ILabeledDataset<?> dataset = null; |
48 | 53 | try { |
| 54 | + dataset = ArffDatasetAdapter.readDataset(new File(arffFilePath)); |
| 55 | + } catch (DatasetDeserializationFailedException e) { |
| 56 | + e.printStackTrace(); |
| 57 | + } |
49 | 58 |
|
50 | | - Classifier classifier = new MLPlanWekaBuilder() |
51 | | - .withDataset(new WekaInstances(instances)) |
52 | | - .withPortionOfDataReservedForSelection(.8) |
53 | | - //.withNumCpus(4) |
54 | | - //.withTimeOut(new Timeout(300, TimeUnit.SECONDS)) |
55 | | - .withMCCVBasedCandidateEvaluationInSearchPhase(10, .8) |
56 | | - .build() |
57 | | - .call() |
58 | | - .getClassifier(); |
| 59 | + //For each iteration, create a new train-test-split and run ML-Plan |
| 60 | + for (int iteration = 0; iteration < ITERATIONS; iteration++) { |
59 | 61 |
|
60 | | - long trainTime = (int) (System.currentTimeMillis() - start) / 1000; |
61 | | - filteredClassifier.setClassifier(classifier); |
62 | | - LOGGER.info("Finished build of the classifier. Training time was {}s.", trainTime); |
| 62 | + try { |
| 63 | + List<ILabeledDataset<?>> split = SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, new Random(1337 + (iteration * 11)), 0.7); |
| 64 | + LOGGER.info("Data read. Time to create dataset object was {}ms", System.currentTimeMillis() - start); |
63 | 65 |
|
64 | | - LOGGER.info("Chosen model is: {}", (filteredClassifier.getClassifier().getClass().getSimpleName())); |
65 | | - } catch (IOException | AlgorithmTimeoutedException | InterruptedException | AlgorithmException | AlgorithmExecutionCanceledException e) { |
66 | | - e.printStackTrace(); |
67 | | - } |
| 66 | + IWekaClassifier optimizedClassifier = getClassifier(split.get(0)); |
68 | 67 |
|
69 | | - //Evaluate classifier using MCCV |
70 | | - try { |
| 68 | + /* evaluate solution produced by mlplan */ |
| 69 | + SupervisedLearnerExecutor executor = new SupervisedLearnerExecutor(); |
| 70 | + ILearnerRunReport report = executor.execute(optimizedClassifier, split.get(0), split.get(1)); |
71 | 71 |
|
72 | | - /* evaluate solution produced by mlplan */ |
73 | | - ModelEvaluator modelEvaluator = new ModelEvaluator(); |
74 | | - results = modelEvaluator.monteCarloValidate(instances, filteredClassifier, 0.8, 10); |
| 72 | + System.out.println(report.getPredictionDiffList().getPredictionsAsList().toString()); |
75 | 73 |
|
76 | | - //Output evaluation results |
77 | | - for(String category: results.keySet()){ |
78 | | - System.out.println("---" + category + "---"); |
79 | | - System.out.println(filteredClassifier.getClassifier().getClass().getSimpleName()+ ";"+ |
80 | | - results.get(category).replace(".", ",").substring(0, results.get(category).lastIndexOf(";"))); |
81 | | - } |
| 74 | + LOGGER.info("Error Rate of the solution produced by ML-Plan: {}. ", |
| 75 | + EClassificationPerformanceMeasure.F1_WITH_1_POSITIVE.loss(report.getPredictionDiffList().getCastedView(Integer.class, ISingleLabelClassification.class))); |
82 | 76 |
|
83 | | - long totalTime = (int) (System.currentTimeMillis() - start) / 1000; |
84 | | - LOGGER.info("ML-Plan execution completed. Total time {}s.", totalTime); |
85 | | - return modelEvaluator.getfMeasure(); |
86 | | - } catch (NoSuchElementException e) { |
87 | | - LOGGER.error("Building the classifier failed: {}", LoggerUtil.getExceptionInfo(e)); |
| 77 | + } catch (SplitFailedException | InterruptedException | LearnerExecutionFailedException e) { |
| 78 | + e.printStackTrace(); |
| 79 | + } |
88 | 80 | } |
89 | | - return results; |
90 | 81 | } |
91 | 82 |
|
92 | 83 | /** |
93 | | - * Returns classifier selected by ML-Plan. |
94 | | - * @return WEKA classifier |
| 84 | + * SReturns trained clssifier |
| 85 | + * @param trainingSet training set |
| 86 | + * @return trained classifier |
95 | 87 | */ |
96 | | - public FilteredClassifier getClassifier() { |
97 | | - return filteredClassifier; |
| 88 | + public IWekaClassifier getClassifier(ILabeledDataset<?> trainingSet) { |
| 89 | + |
| 90 | + IWekaClassifier optimizedClassifier = null; |
| 91 | + /* initialize mlplan with a tiny search space, and let it run for 30 seconds */ |
| 92 | + |
| 93 | + try { |
| 94 | + MLPlan<IWekaClassifier> mlPlan = new MLPlanWekaBuilder() |
| 95 | + .withNumCpus(8)//Set to about 12 on the server |
| 96 | + .withSeed(35467463) |
| 97 | + //set default timeout |
| 98 | + .withTimeOut(new Timeout(30, TimeUnit.SECONDS)) |
| 99 | + .withDataset(trainingSet) |
| 100 | + .withPortionOfDataReservedForSelection(0.0)//ignore selection phase |
| 101 | + .withPerformanceMeasureForSearchPhase(EClassificationPerformanceMeasure.F1_WITH_1_POSITIVE)//use F1 |
| 102 | + .withMCCVBasedCandidateEvaluationInSearchPhase(10, .7) |
| 103 | + .build(); |
| 104 | + |
| 105 | + mlPlan.setLoggerName("testedalgorithm"); |
| 106 | + |
| 107 | + long start = System.currentTimeMillis(); |
| 108 | + |
| 109 | + optimizedClassifier = mlPlan.call(); |
| 110 | + |
| 111 | + long trainTime = (int) (System.currentTimeMillis() - start) / 1000; |
| 112 | + LOGGER.info("Finished build of the classifier. Training time was {}s.", trainTime); |
| 113 | + LOGGER.info("Chosen model is: {}", (mlPlan.getSelectedClassifier())); |
| 114 | + LOGGER.info("Internally believed error was {}", mlPlan.getInternalValidationErrorOfSelectedClassifier()); |
| 115 | + |
| 116 | + } catch (IOException | AlgorithmTimeoutedException | InterruptedException | AlgorithmException | AlgorithmExecutionCanceledException e) { |
| 117 | + e.printStackTrace(); |
| 118 | + } |
| 119 | + return optimizedClassifier; |
98 | 120 | } |
99 | 121 |
|
100 | | - public static void main(String[] args) throws DatasetDeserializationFailedException, IOException, InterruptedException, SplitFailedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException { |
| 122 | + public static void main(String[] args) { |
101 | 123 |
|
102 | | - Instances dataset = Util.loadArffFile("/Users/oshando/Projects/thesis/03-code/swan/swan_core/swan-out/weka/Train_sanitizer_none.arff"); |
103 | | - //Instances dataset = Util.loadArffFile("/Users/oshando/Projects/thesis/03-code/swan/swan_core/src/main/resources/waveform.arff"); |
| 124 | + String file = "/Users/oshando/Projects/thesis/03-code/swan/swan_core/src/main/resources/waveform.arff"; |
104 | 125 |
|
105 | 126 | MLPlanExecutor mlPlan = new MLPlanExecutor(); |
106 | | - mlPlan.run(dataset); |
| 127 | + mlPlan.evaluateDataset(file); |
107 | 128 | } |
108 | 129 | } |
0 commit comments