|
| 1 | +package de.fraunhofer.iem.swan.model; |
| 2 | + |
| 3 | +import ai.libs.jaicore.logging.LoggerUtil; |
| 4 | +import ai.libs.jaicore.ml.weka.dataset.WekaInstances; |
| 5 | +import ai.libs.mlplan.multiclass.wekamlplan.MLPlanWekaBuilder; |
| 6 | +import de.fraunhofer.iem.swan.util.Util; |
| 7 | +import org.api4.java.ai.ml.core.dataset.serialization.DatasetDeserializationFailedException; |
| 8 | +import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException; |
| 9 | +import org.api4.java.algorithm.exceptions.AlgorithmException; |
| 10 | +import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException; |
| 11 | +import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException; |
| 12 | +import org.slf4j.Logger; |
| 13 | +import org.slf4j.LoggerFactory; |
| 14 | +import weka.classifiers.Classifier; |
| 15 | +import weka.classifiers.meta.FilteredClassifier; |
| 16 | +import weka.core.Instances; |
| 17 | + |
| 18 | +import java.io.IOException; |
| 19 | +import java.util.HashMap; |
| 20 | +import java.util.NoSuchElementException; |
| 21 | + |
| 22 | +/** |
| 23 | + * @author Oshando Johnson on 27.09.20 |
| 24 | + */ |
| 25 | +public class MLPlanExecutor { |
| 26 | + |
| 27 | + private static final Logger LOGGER = LoggerFactory.getLogger(MLPlanExecutor.class); |
| 28 | + FilteredClassifier filteredClassifier; |
| 29 | + |
| 30 | + public MLPlanExecutor() { |
| 31 | + this.filteredClassifier = new FilteredClassifier(); |
| 32 | + } |
| 33 | + |
| 34 | + /** |
| 35 | + * Runs ML-Plan with the given instance set. |
| 36 | + * @param instances training instances |
| 37 | + * @return hashmap containing F1 measures |
| 38 | + */ |
| 39 | + public HashMap<String, String> run(Instances instances) { |
| 40 | + |
| 41 | + long start = System.currentTimeMillis(); |
| 42 | + |
| 43 | + instances.setClassIndex(instances.numAttributes() - 1); |
| 44 | + |
| 45 | + HashMap<String, String> results = new HashMap<>(); |
| 46 | + |
| 47 | + //Find classifier using ML-Plan |
| 48 | + try { |
| 49 | + |
| 50 | + Classifier classifier = new MLPlanWekaBuilder() |
| 51 | + .withDataset(new WekaInstances(instances)) |
| 52 | + .withPortionOfDataReservedForSelection(.8) |
| 53 | + //.withNumCpus(4) |
| 54 | + //.withTimeOut(new Timeout(300, TimeUnit.SECONDS)) |
| 55 | + .withMCCVBasedCandidateEvaluationInSearchPhase(10, .8) |
| 56 | + .build() |
| 57 | + .call() |
| 58 | + .getClassifier(); |
| 59 | + |
| 60 | + long trainTime = (int) (System.currentTimeMillis() - start) / 1000; |
| 61 | + filteredClassifier.setClassifier(classifier); |
| 62 | + LOGGER.info("Finished build of the classifier. Training time was {}s.", trainTime); |
| 63 | + |
| 64 | + LOGGER.info("Chosen model is: {}", (filteredClassifier.getClassifier().getClass().getSimpleName())); |
| 65 | + } catch (IOException | AlgorithmTimeoutedException | InterruptedException | AlgorithmException | AlgorithmExecutionCanceledException e) { |
| 66 | + e.printStackTrace(); |
| 67 | + } |
| 68 | + |
| 69 | + //Evaluate classifier using MCCV |
| 70 | + try { |
| 71 | + |
| 72 | + /* evaluate solution produced by mlplan */ |
| 73 | + ModelEvaluator modelEvaluator = new ModelEvaluator(); |
| 74 | + results = modelEvaluator.monteCarloValidate(instances, filteredClassifier, 0.8, 10); |
| 75 | + |
| 76 | + //Output evaluation results |
| 77 | + for(String category: results.keySet()){ |
| 78 | + System.out.println("---" + category + "---"); |
| 79 | + System.out.println(filteredClassifier.getClassifier().getClass().getSimpleName()+ ";"+ |
| 80 | + results.get(category).replace(".", ",").substring(0, results.get(category).lastIndexOf(";"))); |
| 81 | + } |
| 82 | + |
| 83 | + long totalTime = (int) (System.currentTimeMillis() - start) / 1000; |
| 84 | + LOGGER.info("ML-Plan execution completed. Total time {}s.", totalTime); |
| 85 | + return modelEvaluator.getfMeasure(); |
| 86 | + } catch (NoSuchElementException e) { |
| 87 | + LOGGER.error("Building the classifier failed: {}", LoggerUtil.getExceptionInfo(e)); |
| 88 | + } |
| 89 | + return results; |
| 90 | + } |
| 91 | + |
| 92 | + /** |
| 93 | + * Returns classifier selected by ML-Plan. |
| 94 | + * @return WEKA classifier |
| 95 | + */ |
| 96 | + public FilteredClassifier getClassifier() { |
| 97 | + return filteredClassifier; |
| 98 | + } |
| 99 | + |
| 100 | + public static void main(String[] args) throws DatasetDeserializationFailedException, IOException, InterruptedException, SplitFailedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException { |
| 101 | + |
| 102 | + Instances dataset = Util.loadArffFile("/Users/oshando/Projects/thesis/03-code/swan/swan_core/swan-out/weka/Train_sanitizer_none.arff"); |
| 103 | + //Instances dataset = Util.loadArffFile("/Users/oshando/Projects/thesis/03-code/swan/swan_core/src/main/resources/waveform.arff"); |
| 104 | + |
| 105 | + MLPlanExecutor mlPlan = new MLPlanExecutor(); |
| 106 | + mlPlan.run(dataset); |
| 107 | + } |
| 108 | +} |
0 commit comments