Skip to content

Commit e716c88

Browse files
authored
Add feature shape and class count methods to HiggsDataLoader (#1641)
This update introduces two new methods, `get_feature_shape` and `get_num_classes`, to the HiggsDataLoader class. These methods provide easy access to the feature shape and the number of classes for the Higgs dataset, enhancing the usability of the data loader. Signed-off-by: Rahul Garg <rahul.garg@intel.com>
1 parent 333049f commit e716c88

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

openfl-workspace/xgb_higgs/src/dataloader.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ class HiggsDataLoader(XGBoostDataLoader):
2525
def __init__(self, data_path=None, **kwargs):
2626
super().__init__(**kwargs)
2727

28+
# Define default feature shape and number of classes for Higgs dataset
29+
self.feature_shape = (28,)
30+
self.num_classes = 2
2831

2932
# If data_path is None, this is being used for model initialization only
3033
if data_path is None:
@@ -38,6 +41,22 @@ def __init__(self, data_path=None, **kwargs):
3841
self.X_valid = X_valid
3942
self.y_valid = y_valid
4043

44+
def get_feature_shape(self):
45+
"""Returns the shape of an example feature array.
46+
47+
Returns:
48+
list: The shape of an example feature array [3, 150, 150] for Histology images.
49+
"""
50+
return self.feature_shape
51+
52+
def get_num_classes(self):
53+
"""Returns the number of classes for classification tasks.
54+
55+
Returns:
56+
int: The number of classes (8 for Histology dataset).
57+
"""
58+
return self.num_classes
59+
4160

4261
def load_Higgs(data_path, **kwargs):
4362
"""

0 commit comments

Comments
 (0)