
Commit d77a1a4

Refactor ML-Plan configuration and integration
1 parent 2c883e6 commit d77a1a4

File tree

  • swan-pipeline/src/main/java/de/fraunhofer/iem/swan/model/toolkit

1 file changed: 57 additions & 174 deletions
@@ -1,26 +1,29 @@
 package de.fraunhofer.iem.swan.model.toolkit;
 
 import ai.libs.jaicore.ml.classification.loss.dataset.EClassificationPerformanceMeasure;
-import ai.libs.jaicore.ml.core.dataset.schema.attribute.IntBasedCategoricalAttribute;
 import ai.libs.jaicore.ml.core.dataset.serialization.ArffDatasetAdapter;
+import ai.libs.jaicore.ml.core.evaluation.evaluator.SupervisedLearnerExecutor;
 import ai.libs.jaicore.ml.core.filter.SplitterUtil;
 import ai.libs.jaicore.ml.weka.classification.learner.IWekaClassifier;
 import ai.libs.mlplan.weka.MLPlanWekaBuilder;
-import de.fraunhofer.iem.swan.model.MonteCarloValidator;
+import de.fraunhofer.iem.swan.cli.SwanOptions;
+import de.fraunhofer.iem.swan.features.WekaFeatureSet;
+import de.fraunhofer.iem.swan.model.ModelEvaluator;
 import de.fraunhofer.iem.swan.util.Util;
-import org.api4.java.ai.ml.core.dataset.schema.attribute.IAttribute;
+import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification;
 import org.api4.java.ai.ml.core.dataset.serialization.DatasetDeserializationFailedException;
 import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
 import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
+import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
+import org.api4.java.ai.ml.core.evaluation.execution.ILearnerRunReport;
+import org.api4.java.ai.ml.core.evaluation.execution.LearnerExecutionFailedException;
 import org.api4.java.algorithm.Timeout;
 import org.api4.java.algorithm.exceptions.AlgorithmException;
 import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
 import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import weka.classifiers.Classifier;
-import weka.core.Instances;
-import weka.core.converters.ArffLoader;
 
 import java.io.File;
 import java.io.IOException;
@@ -37,215 +40,95 @@ public class MLPlan {
 
     private static final Logger LOGGER = LoggerFactory.getLogger(MLPlan.class);
     private final int ITERATIONS = 1;
+    private WekaFeatureSet featureSet;
+    private SwanOptions swanOptions;
+    long start;
 
-    public MLPlan() {
-
+    public MLPlan(WekaFeatureSet features, SwanOptions options) {
+        this.featureSet = features;
+        swanOptions = options;
     }
 
     /**
-     * Run ML-Plan using the provided path to the ARFF file.
+     * Trains and evaluates the model with the given training data and specified classification mode.
      *
-     * @param instances1 file path for ARFF file
+     * @return Hashmap containing the name of the classifier and its F-Measure
      */
-    public HashMap<String, ArrayList<Double>> evaluateDataset(Instances instances1) {
-
-        String arffFilePath = Util.exportInstancesToArff(instances1, "mlplan");
-        ArffDatasetAdapter arffDatasetAdapter = new ArffDatasetAdapter();
+    public HashMap<String, HashMap<String, String>> trainModel() {
 
-        String mClass = Util.getClassName(instances1);
+        switch (ModelEvaluator.Phase.valueOf(swanOptions.getPhase().toUpperCase())) {
+            case VALIDATE:
 
-        long start = System.currentTimeMillis();
-
-        //Initialize dataset using ARFF file path
-        ILabeledDataset<?> dataset = null;
-        try {
-            dataset = arffDatasetAdapter.readDataset(new File(arffFilePath));
-        } catch (DatasetDeserializationFailedException e) {
-            e.printStackTrace();
+                evaluateData(Util.exportInstancesToArff(featureSet.getTrainInstances().get("sanitizer"), "mlplan"));
+                return null;
+            case PREDICT:
         }
-
-        //dataset.removeColumn("id");
-
-        ArrayList<Double> fScores = new ArrayList<>();
-        ArrayList<String> algorithms = new ArrayList<>();
-
-        MonteCarloValidator monteCarloValidator = new MonteCarloValidator();
-
-        ArffLoader loader = new ArffLoader();
-        try {
-            loader.setFile(new File(arffFilePath));
-            Instances instances = loader.getDataSet();
-            instances.setClassIndex(instances.numAttributes() - 1);
-            // monteCarloValidator.initializeResultSet(instances);
-        } catch (IOException e) {
-            e.printStackTrace();
-        }
-
-
-        //For each iteration, create a new train-test-split and run ML-Plan
-        for (int iteration = 0; iteration < ITERATIONS; iteration++) {
-
-            System.out.println("Iteration #" + iteration);
-            try {
-                List<ILabeledDataset<?>> split = SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, new Random(1337 + (iteration * 11)), 0.7);
-                LOGGER.info("Data read. Time to create dataset object was {}ms", System.currentTimeMillis() - start);
-
-                Classifier optimizedClassifier = getClassifier(split.get(0));
-                //System.out.println("Classify: " + optimizedClassifier.getClassifier().getClass().getSimpleName());
-
-                //optimizedClassifier.fit(split.get(0));
-
-                String trainPath = "swan/swan_core/swan-out/mlplan/train-methods-dataset.arff";
-                arffDatasetAdapter.serializeDataset(new File(trainPath), split.get(0));
-                ArffLoader trainLoader = new ArffLoader();
-                trainLoader.setFile(new File(trainPath));
-                Instances trainInstances = trainLoader.getDataSet();
-                trainInstances.setClassIndex(trainInstances.numAttributes() - 1);
-
-                String testPath = "swan/swan_core/swan-out/mlplan/test-methods-dataset.arff";
-                arffDatasetAdapter.serializeDataset(new File(testPath), split.get(1));
-                ArffLoader testLoader = new ArffLoader();
-                testLoader.setFile(new File(testPath));
-                Instances testInstances = testLoader.getDataSet();
-                testInstances.setClassIndex(testInstances.numAttributes() - 1);
-
-                monteCarloValidator.evaluate(optimizedClassifier, trainInstances, testInstances);
-
-
-                /* evaluate solution produced by mlplan */
-                /* SupervisedLearnerExecutor executor = new SupervisedLearnerExecutor();
-                ILearnerRunReport report = executor.execute(optimizedClassifier, split.get(0), split.get(1));
-
-                for (Object pred : report.getPredictionDiffList().getPredictionsAsList()) {
-
-                    SingleLabelClassification cl = (SingleLabelClassification) pred;
-
-                }
-
-                for (Object prediction : report.getPredictionDiffList().getPredictionsAsList()) {
-
-                    SingleLabelClassification label = (SingleLabelClassification) prediction;
-                }
-
-
-                LOGGER.info("Model selected: {},{},{},{}", mClass, iteration,
-                        optimizedClassifier.getClassifier().getClass().getSimpleName(),
-                        EClassificationPerformanceMeasure.F1_WITH_1_POSITIVE.F1_WITH_1_POSITIVE.loss(report.getPredictionDiffList().getCastedView(Integer.class, ISingleLabelClassification.class)));
-                */
-                //fScores.add(EClassificationPerformanceMeasure.F1_WITH_1_POSITIVE.loss(report.getPredictionDiffList().getCastedView(Integer.class, ISingleLabelClassification.class)));
-                //algorithms.add(optimizedClassifier.getClassifier().getClass().getSimpleName());
-                //LOGGER.info("Error Rate of the solution produced by ML-Plan: {}. ", );
-
-            } catch (SplitFailedException | InterruptedException | IOException e) {
-                e.printStackTrace();
-            }
-        }
-        return monteCarloValidator.getFMeasure();
+        return null;
     }
 
-    public void evaluateDataset(Instances instances, int k) {
+    public HashMap<String, ArrayList<Double>> evaluateData(String arffFilePath) {
 
-        //arffFilePath = "swan/swan_core/src/main/resources/waveform.arff";
-        String arffFilePath = Util.exportInstancesToArff(instances, "mlplan");
-        ArffDatasetAdapter arffDatasetAdapter = new ArffDatasetAdapter();
+        start = System.currentTimeMillis();
 
-        String mClass = Util.getClassName(instances);
-
-        long start = System.currentTimeMillis();
-
-        //Initialize dataset using ARFF file path
-        ILabeledDataset<?> dataset = null;
         try {
-            dataset = arffDatasetAdapter.readDataset(new File(arffFilePath));
-        } catch (DatasetDeserializationFailedException e) {
-            e.printStackTrace();
-        }
-
-        //dataset.removeColumn("id");
-
-        MonteCarloValidator monteCarloValidator = new MonteCarloValidator();
 
-        //For each iteration, create a new train-test-split and run ML-Plan
-        for (int iteration = 0; iteration < ITERATIONS; iteration++) {
+            ArffDatasetAdapter arffDatasetAdapter = new ArffDatasetAdapter();
+            ILabeledDataset<ILabeledInstance> dataset = arffDatasetAdapter.readDataset(new File(arffFilePath));
 
-            try {
-                List<ILabeledDataset<?>> split = SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, new Random(1337 + (iteration * 11)), 0.7);
-                LOGGER.info("Data read. Time to create dataset object was {}ms", System.currentTimeMillis() - start);
+            List<ILabeledDataset<?>> split = SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, new Random(42), .7);
+            LOGGER.info("Data read. Time to create dataset object was {}ms", System.currentTimeMillis() - start);
 
-                System.out.println(split.get(1).getLabelAttribute().getName());
-                for (IAttribute attribute : split.get(1).getListOfAttributes()) {
+            getClassifier(split);
 
-                    // System.out.println(attribute.getName());
-                }
-                arffDatasetAdapter.serializeDataset(new File("swan/swan_core/swan-out/mlplan/methods-dataset.arff"), split.get(1));
-
-
-                for (int x = 0; x < split.get(1).size(); x++) {
-
-                    int attributeIndex = split.get(1).getNumAttributes() - 1;
-                    //System.out.println(Arrays.toString(split.get(1).get(x).getAttributes()));
-
-                    IAttribute attribute = split.get(1).getAttribute(attributeIndex);
-
-                    //System.out.println(dataset.getLabelVector().);
-                    System.out.println(((IntBasedCategoricalAttribute) split.get(1).getAttribute(attributeIndex)).getLabelOfCategory((int) split.get(1).get(x).getAttributeValue(attributeIndex)));
-                    System.out.println(((IntBasedCategoricalAttribute) split.get(1).getLabelAttribute()).getLabelOfCategory((int) split.get(1).get(x).getLabel()));
-                    // System.out.println(split.get(1).getAttribute());
-                    System.out.println(split.get(1).get(x).getAttributeValue(split.get(1).getNumAttributes() - 2) + " " + split.get(1).get(x).getAttributeValue(split.get(1).getNumAttributes() - 1));
-                }
-            } catch (SplitFailedException | InterruptedException | IOException e) {
-                e.printStackTrace();
-            }
+        } catch (InterruptedException | DatasetDeserializationFailedException | SplitFailedException e) {
+            throw new RuntimeException(e);
         }
+        return null;
     }
 
     /**
-     * SReturns trained clssifier
+     * Returns trained classifier
      *
      * @param trainingSet training set
     * @return trained classifier
     */
-    public Classifier getClassifier(ILabeledDataset<?> trainingSet) {
+    public Classifier getClassifier(List<ILabeledDataset<?>> trainingSet) {
 
-        Classifier optimizedClassifier = null;
-        /* initialize mlplan with a tiny search space, and let it run for 30 seconds */
+        IWekaClassifier optimizedClassifier = null;
 
         try {
+            /* initialize mlplan with a tiny search space, and let it run for 30 seconds */
            ai.libs.mlplan.core.MLPlan<IWekaClassifier> mlPlan = new MLPlanWekaBuilder()
-                    .withNumCpus(12)//Set to about 12 on the server
+                    .withNumCpus(4)//Set to about 12 on the server
                    .withSeed(35467463)
                    //set default timeout
-                    .withTimeOut(new Timeout(30, TimeUnit.SECONDS))
-                    .withDataset(trainingSet)
-                    .withCandidateEvaluationTimeOut(new Timeout(30, TimeUnit.SECONDS))
+                    .withTimeOut(new Timeout(60, TimeUnit.SECONDS))
+                    .withDataset(trainingSet.get(0))
+                    /*.withCandidateEvaluationTimeOut(new Timeout(5, TimeUnit.SECONDS))
                    .withPortionOfDataReservedForSelection(0.0)//ignore selection phase
                    .withPerformanceMeasureForSearchPhase(EClassificationPerformanceMeasure.F1_WITH_1_POSITIVE)//use F1
-                    .withMCCVBasedCandidateEvaluationInSearchPhase(5, .7)
+                    .withMCCVBasedCandidateEvaluationInSearchPhase(1, .7)*/
                    .build();
+            mlPlan.setLoggerName("mlplan-swan");
 
-            mlPlan.setLoggerName("testedalgorithm");
-
-            long start = System.currentTimeMillis();
-
-            optimizedClassifier = mlPlan.call().getClassifier();
+            optimizedClassifier = mlPlan.call();
 
            long trainTime = (int) (System.currentTimeMillis() - start) / 1000;
            LOGGER.info("Finished build of the classifier. Training time was {}s.", trainTime);
-            LOGGER.info("Internally believed error was {}", mlPlan.getInternalValidationErrorOfSelectedClassifier());
+            LOGGER.info("Chosen model is: {}", mlPlan.getSelectedClassifier());
+
+            /* evaluate solution produced by mlplan */
+            SupervisedLearnerExecutor executor = new SupervisedLearnerExecutor();
+            ILearnerRunReport report = executor.execute(optimizedClassifier, trainingSet.get(1));
+            LOGGER.info("F-measure for ML-Plan Solution: {}",
+                    EClassificationPerformanceMeasure.F1_WITH_1_POSITIVE.loss(report.getPredictionDiffList().getCastedView(Integer.class, ISingleLabelClassification.class)));
 
-        } catch (IOException | AlgorithmTimeoutedException | InterruptedException | AlgorithmException | AlgorithmExecutionCanceledException e) {
+        } catch (IOException | AlgorithmTimeoutedException | InterruptedException | AlgorithmException |
+                 AlgorithmExecutionCanceledException e) {
            e.printStackTrace();
+        } catch (LearnerExecutionFailedException e) {
+            throw new RuntimeException(e);
        }
-        return optimizedClassifier;
-    }
-
-    public static void maihn(String[] args) {
-
-        String file = "swan/swan_core/src/main/resources/waveform.arff";
-
-        MLPlan mlPlan = new MLPlan();
-        // mlPlan.evaluateDataset(file, "sdfs");
+        return optimizedClassifier.getClassifier();
    }
-}
+}
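
For orientation, the end-to-end flow that the refactored class now follows (read an ARFF file, make a label-stratified train-test split, let ML-Plan select a WEKA classifier, then score it with F1) can be condensed into a small standalone sketch. This sketch is not part of the commit: the class name, the main method, and the ARFF path are placeholders added for illustration, and the builder settings simply mirror the values introduced above.

import ai.libs.jaicore.ml.classification.loss.dataset.EClassificationPerformanceMeasure;
import ai.libs.jaicore.ml.core.dataset.serialization.ArffDatasetAdapter;
import ai.libs.jaicore.ml.core.evaluation.evaluator.SupervisedLearnerExecutor;
import ai.libs.jaicore.ml.core.filter.SplitterUtil;
import ai.libs.jaicore.ml.weka.classification.learner.IWekaClassifier;
import ai.libs.mlplan.weka.MLPlanWekaBuilder;
import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.execution.ILearnerRunReport;
import org.api4.java.algorithm.Timeout;

import java.io.File;
import java.util.List;
import java.util.Random;
import java.util.concurrent.TimeUnit;

/**
 * Standalone sketch of the ARFF -> split -> ML-Plan search -> F1 evaluation flow
 * used by the refactored MLPlan toolkit class. Placeholder names and paths only.
 */
public class MlPlanFlowSketch {

    public static void main(String[] args) throws Exception {
        // Read a labeled dataset from an ARFF file (placeholder path).
        ILabeledDataset<ILabeledInstance> dataset =
                new ArffDatasetAdapter().readDataset(new File("path/to/dataset.arff"));

        // Label-stratified 70/30 train-test split, as in evaluateData().
        List<ILabeledDataset<?>> split =
                SplitterUtil.getLabelStratifiedTrainTestSplit(dataset, new Random(42), .7);

        // Configure ML-Plan for WEKA learners with the settings used in getClassifier().
        ai.libs.mlplan.core.MLPlan<IWekaClassifier> mlPlan = new MLPlanWekaBuilder()
                .withNumCpus(4)
                .withSeed(35467463)
                .withTimeOut(new Timeout(60, TimeUnit.SECONDS))
                .withDataset(split.get(0))
                .build();
        mlPlan.setLoggerName("mlplan-swan");

        // Run the AutoML search and obtain the selected WEKA classifier.
        IWekaClassifier optimizedClassifier = mlPlan.call();

        // Evaluate the selected model on the held-out split and report the F1 loss.
        ILearnerRunReport report =
                new SupervisedLearnerExecutor().execute(optimizedClassifier, split.get(1));
        double f1Loss = EClassificationPerformanceMeasure.F1_WITH_1_POSITIVE.loss(
                report.getPredictionDiffList().getCastedView(Integer.class, ISingleLabelClassification.class));
        System.out.println("F1 loss on the test split: " + f1Loss);
    }
}

In the commit itself, trainModel() drives this flow for the VALIDATE phase, exporting the training instances from the pipeline's WekaFeatureSet to ARFF before handing the path to evaluateData().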
