Skip to content

Commit 7d54c6b

Browse files
committed
Implement attribute selection for WEKA
1 parent 927107d commit 7d54c6b

File tree

1 file changed

+50
-12
lines changed

1 file changed

+50
-12
lines changed

swan-pipeline/src/main/java/de/fraunhofer/iem/swan/features/WekaFeatureSet.java

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ public void createFeatures() {
5555
structures.put(category.getId().toLowerCase(), new Instances("weka-", trainAttributes, 0));
5656

5757
Instances trainInstances = createInstances(trainAttributes, dataset.getTrainMethods(), Collections.singleton(category));
58+
59+
if (options.isReduceAttributes())
60+
trainInstances = performAttributeSelection(trainInstances);
61+
5862
trainInstances.setClassIndex(trainInstances.numAttributes() - 1);
5963
this.trainInstances.put(category.getId().toLowerCase(), trainInstances);
6064
Util.exportInstancesToArff(trainInstances, category.getId());
@@ -88,9 +92,20 @@ public void createFeatures() {
8892
}
8993
}
9094

91-
this.trainInstances.put(category.getId().toLowerCase(), trainInstances);
92-
structures.put(category.getId().toLowerCase(), structure);
95+
logger.info("Using default {} TRAIN dataset(s) file(s) in {} with {} features and {} instances",
96+
category.getId(), instancesFile, trainInstances.numAttributes(), trainInstances.numInstances());
97+
98+
if (options.isReduceAttributes()) {
99+
100+
Instances originalInstances = trainInstances;
101+
trainInstances = performAttributeSelection(trainInstances);
102+
103+
logger.debug("Performing feature selection on {} TRAIN dataset(s), {} reduced to {} features ",
104+
category.getId(), originalInstances.numAttributes(), trainInstances.numAttributes());
105+
}
93106

107+
this.trainInstances.put(category.getId().toLowerCase(), filterInstances(trainInstances, dataset.getTrainMethods()));
108+
structures.put(category.getId().toLowerCase(), structure);
94109
} catch (IOException e) {
95110
e.printStackTrace();
96111
}
@@ -102,21 +117,44 @@ public void createFeatures() {
102117
//TODO implement predict phase for WEKA
103118
}
104119

120+
public Instances performAttributeSelection(Instances instances) {
121+
122+
//CfsSubsetEval eval = new CfsSubsetEval();
123+
//CorrelationAttributeEval eval = new CorrelationAttributeEval();
124+
InfoGainAttributeEval eval = new InfoGainAttributeEval();
125+
//ReliefFAttributeEval eval = new ReliefFAttributeEval();
126+
127+
//Set search method
128+
//GreedyStepwise search = new GreedyStepwise();
129+
//search.setNumToSelect(980);
130+
//search.setSearchBackwards(true);
105131

106-
/* for (Category category : options.getAllClasses().stream().map(Category::fromText).collect(Collectors.toList())) {
132+
Ranker search = new Ranker();
133+
try {
134+
search.setOptions(new String[]{"-T", "0.0343"});
135+
} catch (Exception e) {
136+
throw new RuntimeException(e);
137+
}
138+
139+
//search.setNumToSelect(10);
140+
//Perform attribute selection
141+
AttributeSelection attributeSelection = new AttributeSelection();
142+
attributeSelection.setEvaluator(eval);
143+
attributeSelection.setSearch(search);
144+
attributeSelection.setRanking(true);
107145

108-
//Create and set attributes for the train instances
109-
ArrayList<Attribute> trainAttributes = createAttributes(category, dataset.getTrainMethods());
146+
Instances filteredInstances;
110147

111-
Instances trainInstances = createInstances(trainAttributes, dataset.getTrainMethods(), Collections.singleton(category));
112-
this.instances.put(category.getId().toLowerCase(), trainInstances);
113-
Util.exportInstancesToArff(trainInstances, "weka-"+category.getId());
148+
try {
149+
attributeSelection.SelectAttributes(instances);
150+
filteredInstances = attributeSelection.reduceDimensionality(instances);
114151

115-
//Create and set attributes for the test instances.
116-
/*ArrayList<Attribute> testAttributes = createAttributes(getCategories(category), testData.getMethods(), featureSets);
117-
Instances testInstances = createInstances(featureSets, testAttributes, testData.getMethods(), getCategories(category), category + "-test-instances");
152+
System.out.println(attributeSelection.toResultsString());
118153

119-
}*/
154+
} catch (Exception e) {
155+
throw new RuntimeException(e);
156+
}
157+
return filteredInstances;
120158
}
121159

122160

0 commit comments

Comments
 (0)