Skip to content

Commit 52c98a2

Browse files
authored
[ENH] Add data loader for EEG classification datasets (#107)
* dataset lists * data loader * data loader * data loader * data loader
1 parent f3bb96a commit 52c98a2

File tree

5 files changed

+230
-1
lines changed

5 files changed

+230
-1
lines changed

aeon_neuro/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Utilities for loading datasets."""
22

33
__maintainer__ = ["TonyBagnall"]
4-
__all__ = ["load_kdd_example", "load_kdd_full_example"]
4+
__all__ = ["load_kdd_example", "load_kdd_full_example", "load_eeg_classification"]
55

6+
from aeon_neuro.datasets._data_loaders import load_eeg_classification
67
from aeon_neuro.datasets._single_problem_loaders import (
78
load_kdd_example,
89
load_kdd_full_example,
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""Function to load EEG Datasets from Zenodo."""
2+
3+
import os
4+
from urllib.request import urlretrieve
5+
6+
from aeon.datasets._single_problem_loaders import _load_saved_dataset
7+
from aeon.datasets.dataset_collections import get_downloaded_tsc_tsr_datasets
8+
9+
import aeon_neuro
10+
from aeon_neuro.datasets.classification_datasets import dataset_map
11+
12+
DIRNAME = "data"
13+
MODULE = os.path.join(os.path.dirname(aeon_neuro.__file__), "datasets")
14+
15+
16+
def load_eeg_classification(
17+
name,
18+
split=None,
19+
extract_path=None,
20+
return_metadata=False,
21+
):
22+
"""Load an EEG classification dataset.
23+
24+
This function loads EEG TSC problems into memory, attempting to load from the
25+
specified local path `extract_path`` or trying to download from
26+
https://zenodo.org// if the data is not in the local path. To download from
27+
zenodo, the dataset must be in the list ``dataset_map`` in data._data_loaders.py.
28+
This function assumes the data is stored in format
29+
``<extract_path>/<name>/<name>_TRAIN.ts`` and
30+
``<extract_path>/<name>/<name>_TEST.ts.`` If you want to load a file directly
31+
from a full path that is in ``aeon`` ts format, use the function
32+
`load_from_ts_file`` in ``aeon`` directly. If
33+
you do not specify ``extract_path``, it will set the path to
34+
``aeon_neuro/datasets/local_data``.
35+
36+
Data is assumed to be in the standard ``aeon`` .ts format: each row is a (possibly
37+
multivariate) time series. Each channel is separated by a colon, each value in
38+
a series is comma separated. For examples see aeon_neuro.datasets.data.
39+
40+
Parameters
41+
----------
42+
name : str
43+
Name of data set. If a dataset that is listed in tsc_datasets is given,
44+
this function will look in the extract_path first, and if it is not present,
45+
attempt to download the data from www.timeseriesclassification.com, saving it to
46+
the extract_path.
47+
split : None or str{"train", "test"}, default=None
48+
Whether to load the train or test partition of the problem. By default it
49+
loads both into a single dataset, otherwise it looks only for files of the
50+
format <name>_TRAIN.ts or <name>_TEST.ts.
51+
extract_path : str, default=None
52+
the path to look for the data. If no path is provided, the function
53+
looks in `aeon/datasets/local_data/`. If a path is given, it can be absolute,
54+
e.g. C:/Temp/ or relative, e.g. Temp/ or ./Temp/.
55+
return_metadata : boolean, default = True
56+
If True, returns a tuple (X, y, metadata)
57+
58+
Returns
59+
-------
60+
X: np.ndarray or list of np.ndarray
61+
y: np.ndarray
62+
The class labels for each case in X
63+
metadata: dict, optional
64+
returns the following metadata
65+
'problemname',timestamps, missing,univariate,equallength, class_values
66+
targetlabel should be false, and classlabel true
67+
68+
Raises
69+
------
70+
URLError or HTTPError
71+
If the website is not accessible.
72+
ValueError
73+
If a dataset name that does not exist on the repo is given or if a
74+
webpage is requested that does not exist.
75+
76+
Examples
77+
--------
78+
>>> from aeon.datasets import load_classification
79+
>>> X, y = load_classification(name="ArrowHead") # doctest: +SKIP
80+
"""
81+
if extract_path is not None:
82+
local_module = extract_path
83+
local_dirname = None
84+
else:
85+
local_module = MODULE
86+
local_dirname = "data"
87+
if local_dirname is None:
88+
path = local_module
89+
else:
90+
path = os.path.join(local_module, local_dirname)
91+
if not os.path.exists(path):
92+
os.makedirs(path)
93+
if name not in get_downloaded_tsc_tsr_datasets(path):
94+
if extract_path is None:
95+
local_dirname = "local_data"
96+
path = os.path.join(local_module, local_dirname)
97+
else:
98+
path = extract_path
99+
if not os.path.exists(path):
100+
os.makedirs(path)
101+
error_str = (
102+
f"File name {name} is not in the list of valid files to download,"
103+
f"see aeon_neuro.datasets.classification for the current list of "
104+
f"maintained datasets."
105+
)
106+
107+
if name not in get_downloaded_tsc_tsr_datasets(path):
108+
# Check if in the zenodo list
109+
if name in dataset_map.keys():
110+
id = dataset_map[name]
111+
if id == 49:
112+
raise ValueError(error_str)
113+
url_train = f"https://zenodo.org/record/{id}/files/{name}_TRAIN.ts"
114+
url_test = f"https://zenodo.org/record/{id}/files/{name}_TEST.ts"
115+
full_path = os.path.join(path, name)
116+
if not os.path.exists(full_path):
117+
os.makedirs(full_path)
118+
train_save = f"{full_path}/{name}_TRAIN.ts"
119+
test_save = f"{full_path}/{name}_TEST.ts"
120+
try:
121+
urlretrieve(url_train, train_save)
122+
urlretrieve(url_test, test_save)
123+
except Exception:
124+
raise ValueError(error_str)
125+
else:
126+
raise ValueError(error_str)
127+
X, y, meta = _load_saved_dataset(
128+
name=name,
129+
dir_name=name,
130+
split=split,
131+
local_module=local_module,
132+
local_dirname=local_dirname,
133+
return_meta=True,
134+
)
135+
# Check this is a classification problem
136+
if "classlabel" not in meta or not meta["classlabel"]:
137+
raise ValueError(
138+
f"You have tried to load a regression problem called {name} with "
139+
f"load_classifier. This will cause unintended consequences for any "
140+
f"classifier you build. If you want to load a regression problem, "
141+
f"use load_regression in ``aeon`` "
142+
)
143+
if return_metadata:
144+
return X, y, meta
145+
return X, y
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""Datasets in the EEG classification archive."""
2+
3+
# 31 EEG Classification problems
4+
5+
dataset_names = [
6+
"Alzhiemers",
7+
"Blink",
8+
"ButtonPress",
9+
"Epilepsy2",
10+
"EyesOpenShut",
11+
"FaceDetection",
12+
"FingerMovements",
13+
"HandMovementDirection",
14+
"ImaginedOpenCloseFist",
15+
"ImaginedOpenCloseFistFeet",
16+
"InnerSpeech",
17+
"Liverpool-Fibromyalgia",
18+
"LongIntervalTask",
19+
"LowCostEEG",
20+
"MatchingPennies",
21+
"MindReading",
22+
"MotorImagery",
23+
"N_Back",
24+
"OpenCloseFist",
25+
"OpenCloseFistFeet",
26+
"Photo-Stimulation",
27+
"PronouncedSpeech",
28+
"PsychologyButtonPress",
29+
"SelfRegulationSCP1",
30+
"SelfRegulationSCP2",
31+
"ShortIntervalTask",
32+
"SitStand",
33+
"Sleep",
34+
"SongFamiliarity",
35+
"VIPA",
36+
"VisualSpeech",
37+
]
38+
# Complete with zenodo number when available. 49 means not available yet
39+
dataset_map = {
40+
"Alzhiemers": 49,
41+
"Blink": 49,
42+
"ButtonPress": 49,
43+
"Epilepsy2": 49,
44+
"EyesOpenShut": 49,
45+
"FaceDetection": 49,
46+
"FingerMovements": 49,
47+
"HandMovementDirection": 49,
48+
"ImaginedOpenCloseFist": 49,
49+
"ImaginedOpenCloseFistFeet": 49,
50+
"InnerSpeech": 49,
51+
"Liverpool-Fibromyalgia": 49,
52+
"LongIntervalTask": 49,
53+
"LowCostEEG": 49,
54+
"MatchingPennies": 49,
55+
"MindReading": 49,
56+
"MotorImagery": 49,
57+
"N_Back": 49,
58+
"OpenCloseFist": 49,
59+
"OpenCloseFistFeet": 49,
60+
"Photo-Stimulation": 49,
61+
"PronouncedSpeech": 49,
62+
"PsychologyButtonPress": 49,
63+
"SelfRegulationSCP1": 49,
64+
"SelfRegulationSCP2": 49,
65+
"ShortIntervalTask": 49,
66+
"SitStand": 49,
67+
"Sleep": 49,
68+
"SongFamiliarity": 49,
69+
"VIPA": 49,
70+
"VisualSpeech": 49,
71+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Tests for loaders."""
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Test data loading with shipped data."""
2+
3+
from aeon_neuro.datasets._data_loaders import load_eeg_classification
4+
5+
6+
def test_load_eeg():
7+
"""Test data loading from provided datasets."""
8+
X, y = load_eeg_classification("SelfRegulationSCP1")
9+
assert X.shape == (561, 6, 896)
10+
X, y, meta = load_eeg_classification("SelfRegulationSCP1", return_metadata=True)
11+
assert meta["problemname"] == "selfregulationscp1"

0 commit comments

Comments
 (0)