Skip to content

Commit 9182a91

Browse files
committed
Add test set prediction using WEKA
1 parent 7d54c6b commit 9182a91

File tree

1 file changed

+39
-2
lines changed

1 file changed

+39
-2
lines changed

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/features/WekaFeatureSet.java

Lines changed: 39 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -113,9 +113,46 @@ public void createFeatures() {
113113
}
114114

115115
//Set attributes for the test instances
116-
if (options.isPredictPhase()) {
117-
//TODO implement predict phase for WEKA
116+
if (options.getPhase().toUpperCase().contentEquals(ModelEvaluator.Phase.PREDICT.name())) {
117+
118+
for (Category category : options.getAllClasses().stream().map(Category::fromText).collect(Collectors.toList())) {
119+
120+
createAttributes(category, dataset.getTestMethods());
121+
evaluateFeatureData(dataset.getTestMethods());
122+
123+
Instances testInstances = new Instances(structures.get(category.getId().toLowerCase()));
124+
testInstances.setRelationName(testInstances.relationName() + "-test");
125+
126+
//Replace existing method IDs with test set IDs
127+
Attribute idAttr = new Attribute("id", dataset.getTestMethods().stream().map(Method::getArffSafeSignature).collect(Collectors.toList()));
128+
testInstances.replaceAttributeAt(idAttr, testInstances.attribute("id").index());
129+
130+
ArrayList<Attribute> aList = Collections.list(testInstances.enumerateAttributes());
131+
132+
Instances tInstances = createInstances(testInstances, aList, dataset.getTestMethods(), Collections.singleton(category));
133+
tInstances.setClassIndex(tInstances.numAttributes() - 1);
134+
135+
logger.info("Creating {} TEST dataset with {} features",
136+
category.getId(), tInstances.numAttributes());
137+
138+
this.testInstances.put(category.getId().toLowerCase(), tInstances);
139+
}
118140
}
141+
}
142+
143+
public Instances filterInstances(Instances instances, Set<Method> methods) {
144+
145+
Set<String> train = methods.stream().map(Method::getArffSafeSignature).collect(Collectors.toSet());
146+
147+
for (int i = instances.numInstances() - 1; i >= 0; i--) {
148+
String instanceId = instances.get(i).stringValue(instances.attribute("id").index());
149+
150+
if (!train.contains(instanceId)) {
151+
instances.remove(i);
152+
}
153+
}
154+
return instances;
155+
}
119156

120157
public Instances performAttributeSelection(Instances instances) {
121158

0 commit comments

Comments
 (0)