From 1f5fa07229d702ced0f046036f3ee60c60c8a153 Mon Sep 17 00:00:00 2001 From: Satwik Sai Prakash Sahoo Date: Fri, 14 Nov 2025 19:19:14 +0530 Subject: [PATCH 1/3] [ENH] Fix dataset root path (#3088) --- aeon/datasets/_data_loaders.py | 3 ++- aeon/datasets/_single_problem_loaders.py | 3 ++- aeon/datasets/dataset_collections.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/aeon/datasets/_data_loaders.py b/aeon/datasets/_data_loaders.py index d4c20b5cf8..5047b6e20d 100644 --- a/aeon/datasets/_data_loaders.py +++ b/aeon/datasets/_data_loaders.py @@ -23,6 +23,7 @@ import zipfile from datetime import datetime from http.client import IncompleteRead, RemoteDisconnected +from pathlib import Path from urllib.error import HTTPError, URLError from urllib.parse import urlparse from urllib.request import Request, urlopen, urlretrieve @@ -40,7 +41,7 @@ from aeon.utils.conversion import convert_collection DIRNAME = "data" -MODULE = os.path.join(os.path.dirname(aeon.__file__), "datasets") +MODULE = Path(aeon.__file__).parent / "datasets" CONNECTION_ERRORS = ( HTTPError, diff --git a/aeon/datasets/_single_problem_loaders.py b/aeon/datasets/_single_problem_loaders.py index 30e6a053c6..38e9999e33 100644 --- a/aeon/datasets/_single_problem_loaders.py +++ b/aeon/datasets/_single_problem_loaders.py @@ -24,6 +24,7 @@ ] import os +from pathlib import Path import numpy as np import pandas as pd @@ -32,7 +33,7 @@ from aeon.datasets._data_loaders import _load_saved_dataset, _load_tsc_dataset DIRNAME = "data" -MODULE = os.path.dirname(__file__) +MODULE = Path(__file__).parent def load_gunpoint(split=None, return_type="numpy3d"): diff --git a/aeon/datasets/dataset_collections.py b/aeon/datasets/dataset_collections.py index f47dac5cc4..24cbe28a8a 100644 --- a/aeon/datasets/dataset_collections.py +++ b/aeon/datasets/dataset_collections.py @@ -34,13 +34,14 @@ "get_available_tsf_datasets", ] import os +from pathlib import Path import aeon from aeon.datasets.tsc_datasets import multivariate, univariate from aeon.datasets.tser_datasets import tser_monash, tser_soton from aeon.datasets.tsf_datasets import tsf_all -MODULE = os.path.join(os.path.dirname(aeon.__file__), "datasets") +MODULE = Path(aeon.__file__).parent / "datasets" def get_available_tser_datasets(name="tser_soton", return_list=True): From f86f86dd6b678bb5cd6a1e4e9bb0489f5ae24bf7 Mon Sep 17 00:00:00 2001 From: Satwik Sai Prakash Sahoo Date: Sat, 29 Nov 2025 01:54:52 +0530 Subject: [PATCH 2/3] Change default dataset download path to user home directory --- aeon/datasets/_data_loaders.py | 69 ++++++++++++++++++------ aeon/datasets/_single_problem_loaders.py | 5 +- aeon/datasets/dataset_collections.py | 5 +- aeon/datasets/tests/test_data_loaders.py | 4 +- 4 files changed, 60 insertions(+), 23 deletions(-) diff --git a/aeon/datasets/_data_loaders.py b/aeon/datasets/_data_loaders.py index 5047b6e20d..53cda54ffc 100644 --- a/aeon/datasets/_data_loaders.py +++ b/aeon/datasets/_data_loaders.py @@ -41,7 +41,7 @@ from aeon.utils.conversion import convert_collection DIRNAME = "data" -MODULE = Path(aeon.__file__).parent / "datasets" +MODULE = os.path.join(os.path.dirname(aeon.__file__), "datasets") CONNECTION_ERRORS = ( HTTPError, @@ -466,7 +466,9 @@ def _download_and_extract(url, extract_path=None): with open(zip_file_name, "wb") as out_file: out_file.write(response.read()) if extract_path is None: - extract_path = os.path.join(MODULE, "local_data/%s/" % file_name.split(".")[0]) + extract_path = os.path.join( + str(Path.home() / ".aeon"), "local_data/%s/" % file_name.split(".")[0] + ) else: extract_path = os.path.join(extract_path, "%s/" % file_name.split(".")[0]) @@ -525,8 +527,14 @@ def _load_tsc_dataset( local_module = extract_path local_dirname = "" else: - local_module = MODULE - local_dirname = "data" + bundled_path = os.path.join(MODULE, "data", name) + if os.path.exists(bundled_path): + local_module = MODULE + local_dirname = "data" + else: + aeon_home = Path.home() / ".aeon" + local_module = str(aeon_home) + local_dirname = "data" if not os.path.exists(os.path.join(local_module, local_dirname)): os.makedirs(os.path.join(local_module, local_dirname)) @@ -546,7 +554,11 @@ def _load_tsc_dataset( try: _download_and_extract( url, - extract_path=extract_path, + extract_path=( + extract_path + if extract_path is not None + else os.path.join(local_module, local_dirname) + ), ) except zipfile.BadZipFile as e: raise ValueError( @@ -988,8 +1000,13 @@ def load_forecasting(name, extract_path=None, return_metadata=False): local_module = extract_path local_dirname = "" else: - local_module = MODULE - local_dirname = "data" + bundled_path = os.path.join(MODULE, "data", name) + if os.path.exists(bundled_path): + local_module = MODULE + local_dirname = "data" + else: + local_module = str(Path.home() / ".aeon") + local_dirname = "data" if not os.path.exists(os.path.join(local_module, local_dirname)): os.makedirs(os.path.join(local_module, local_dirname)) @@ -1029,7 +1046,11 @@ def load_forecasting(name, extract_path=None, return_metadata=False): try: _download_and_extract( url, - extract_path=extract_path, + extract_path=( + extract_path + if extract_path is not None + else os.path.join(local_module, local_dirname) + ), ) except zipfile.BadZipFile: raise ValueError( @@ -1142,8 +1163,13 @@ def load_regression( local_module = extract_path local_dirname = "" else: - local_module = MODULE - local_dirname = "data" + bundled_path = os.path.join(MODULE, "data", name) + if os.path.exists(bundled_path): + local_module = MODULE + local_dirname = "data" + else: + local_module = str(Path.home() / ".aeon") + local_dirname = "data" error_str = ( f"File name {name} is not in the list of valid files to download," f"see aeon.datasets.tser_datasetss.tser_soton for the list. " @@ -1183,7 +1209,11 @@ def load_regression( try: _download_and_extract( url, - extract_path=extract_path, + extract_path=( + extract_path + if extract_path is not None + else os.path.join(local_module, local_dirname) + ), ) except zipfile.BadZipFile: try_monash = True @@ -1323,8 +1353,13 @@ def load_classification( local_module = extract_path local_dirname = None else: - local_module = MODULE - local_dirname = "data" + bundled_path = os.path.join(MODULE, "data", name) + if os.path.exists(bundled_path): + local_module = MODULE + local_dirname = "data" + else: + local_module = str(Path.home() / ".aeon") + local_dirname = "data" if local_dirname is None: path = local_module else: @@ -1363,7 +1398,11 @@ def load_classification( try: _download_and_extract( url, - extract_path=extract_path, + extract_path=( + extract_path + if extract_path is not None + else os.path.join(local_module, local_dirname) + ), ) except zipfile.BadZipFile: try_zenodo = True @@ -1444,7 +1483,7 @@ def download_all_regression(extract_path=None): local_module = extract_path local_dirname = "" else: - local_module = MODULE + local_module = str(Path.home() / ".aeon") local_dirname = "data" if not os.path.exists(os.path.join(local_module, local_dirname)): diff --git a/aeon/datasets/_single_problem_loaders.py b/aeon/datasets/_single_problem_loaders.py index 38e9999e33..75e994ba68 100644 --- a/aeon/datasets/_single_problem_loaders.py +++ b/aeon/datasets/_single_problem_loaders.py @@ -24,7 +24,6 @@ ] import os -from pathlib import Path import numpy as np import pandas as pd @@ -33,7 +32,7 @@ from aeon.datasets._data_loaders import _load_saved_dataset, _load_tsc_dataset DIRNAME = "data" -MODULE = Path(__file__).parent +MODULE = os.path.dirname(__file__) def load_gunpoint(split=None, return_type="numpy3d"): @@ -990,4 +989,4 @@ def load_longley(return_array=True): data = data.astype(float) if return_array: return data.to_numpy().T - return data.T + return data.T \ No newline at end of file diff --git a/aeon/datasets/dataset_collections.py b/aeon/datasets/dataset_collections.py index 24cbe28a8a..870ffc1d81 100644 --- a/aeon/datasets/dataset_collections.py +++ b/aeon/datasets/dataset_collections.py @@ -34,14 +34,13 @@ "get_available_tsf_datasets", ] import os -from pathlib import Path import aeon from aeon.datasets.tsc_datasets import multivariate, univariate from aeon.datasets.tser_datasets import tser_monash, tser_soton from aeon.datasets.tsf_datasets import tsf_all -MODULE = Path(aeon.__file__).parent / "datasets" +MODULE = os.path.join(os.path.dirname(aeon.__file__), "datasets") def get_available_tser_datasets(name="tser_soton", return_list=True): @@ -160,4 +159,4 @@ def get_downloaded_tsf_datasets(extract_path=None): all_files = os.listdir(sub_dir) if name + ".tsf" in all_files: datasets.append(name) - return datasets + return datasets \ No newline at end of file diff --git a/aeon/datasets/tests/test_data_loaders.py b/aeon/datasets/tests/test_data_loaders.py index 29d7049b9e..59da1980d6 100644 --- a/aeon/datasets/tests/test_data_loaders.py +++ b/aeon/datasets/tests/test_data_loaders.py @@ -57,7 +57,7 @@ def test_load_forecasting_from_repo(): assert not meta["contain_missing_values"] assert not meta["contain_equal_length"] - shutil.rmtree(os.path.dirname(__file__) + "/../local_data") + shutil.rmtree(os.path.dirname(__file__) + "/../local_data", ignore_errors=True) @pytest.mark.skipif( @@ -84,7 +84,7 @@ def test_load_classification_from_repo(): assert meta["classlabel"] assert not meta["targetlabel"] assert meta["class_values"] == ["1", "2"] - shutil.rmtree(os.path.dirname(__file__) + "/../local_data") + shutil.rmtree(os.path.dirname(__file__) + "/../local_data", ignore_errors=True) @pytest.mark.skipif( From bf3b0cbfbabb41114fafb9e816c88de2e7eec1ef Mon Sep 17 00:00:00 2001 From: satwiksps <215063428+satwiksps@users.noreply.github.com> Date: Fri, 28 Nov 2025 20:25:38 +0000 Subject: [PATCH 3/3] Automatic `pre-commit` fixes --- aeon/datasets/_single_problem_loaders.py | 2 +- aeon/datasets/dataset_collections.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aeon/datasets/_single_problem_loaders.py b/aeon/datasets/_single_problem_loaders.py index 75e994ba68..30e6a053c6 100644 --- a/aeon/datasets/_single_problem_loaders.py +++ b/aeon/datasets/_single_problem_loaders.py @@ -989,4 +989,4 @@ def load_longley(return_array=True): data = data.astype(float) if return_array: return data.to_numpy().T - return data.T \ No newline at end of file + return data.T diff --git a/aeon/datasets/dataset_collections.py b/aeon/datasets/dataset_collections.py index 870ffc1d81..f47dac5cc4 100644 --- a/aeon/datasets/dataset_collections.py +++ b/aeon/datasets/dataset_collections.py @@ -159,4 +159,4 @@ def get_downloaded_tsf_datasets(extract_path=None): all_files = os.listdir(sub_dir) if name + ".tsf" in all_files: datasets.append(name) - return datasets \ No newline at end of file + return datasets