From 7905c26415e6e5443181b4b55489ba8c9a2c522f Mon Sep 17 00:00:00 2001 From: Satwik Sai Prakash Sahoo Date: Tue, 25 Nov 2025 01:06:28 +0530 Subject: [PATCH 1/3] [ENH] Add common dtype compatibility check to estimator tests --- .../_yield_estimator_checks.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/aeon/testing/estimator_checking/_yield_estimator_checks.py b/aeon/testing/estimator_checking/_yield_estimator_checks.py index c4fc3b81dd..ae7d7a85f6 100644 --- a/aeon/testing/estimator_checking/_yield_estimator_checks.py +++ b/aeon/testing/estimator_checking/_yield_estimator_checks.py @@ -227,6 +227,11 @@ def _yield_estimator_checks(estimator_class, estimator_instances, datatypes): yield partial( check_fit_deterministic, estimator=estimator, datatype=datatypes[i][0] ) + yield partial( + check_common_input_dtypes, + estimator=estimator, + datatype=datatypes[i][0], + ) def check_create_test_instance(estimator_class): @@ -690,3 +695,30 @@ def check_fit_deterministic(estimator, datatype): f"Check equivalence message: {msg}" ) i += 1 + + +def check_common_input_dtypes(estimator, datatype): + """Check estimator works with common numpy dtypes.""" + estimator = _clone_estimator(estimator) + + X_train = deepcopy(FULL_TEST_DATA_DICT[datatype]["train"][0]) + y_train = deepcopy(FULL_TEST_DATA_DICT[datatype]["train"][1]) + X_test = deepcopy(FULL_TEST_DATA_DICT[datatype]["test"][0]) + + dtypes = [np.float32, np.float64, np.int32, np.int64] + + for dtype_cast in dtypes: + try: + X_train_cast = X_train.astype(dtype_cast) + X_test_cast = X_test.astype(dtype_cast) + + est = estimator.clone() + est.fit(X_train_cast, y_train) + + if hasattr(est, "predict"): + est.predict(X_test_cast) + + except Exception as e: + raise AssertionError( + f"{type(estimator).__name__} failed for dtype {dtype_cast}: {e}" + ) From bae68acd592d624d9dc12c48dfa932b204511b9e Mon Sep 17 00:00:00 2001 From: Satwik Sai Prakash Sahoo Date: Tue, 25 Nov 2025 01:46:38 +0530 Subject: [PATCH 2/3] Stabilise common dtype test for non-numpy and strict estimators --- .../_yield_estimator_checks.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/aeon/testing/estimator_checking/_yield_estimator_checks.py b/aeon/testing/estimator_checking/_yield_estimator_checks.py index ae7d7a85f6..db843bb4d8 100644 --- a/aeon/testing/estimator_checking/_yield_estimator_checks.py +++ b/aeon/testing/estimator_checking/_yield_estimator_checks.py @@ -707,10 +707,17 @@ def check_common_input_dtypes(estimator, datatype): dtypes = [np.float32, np.float64, np.int32, np.int64] + X_train_np = np.asarray(X_train) + X_test_np = np.asarray(X_test) + + + if X_train_np.dtype == object: + return + for dtype_cast in dtypes: try: - X_train_cast = X_train.astype(dtype_cast) - X_test_cast = X_test.astype(dtype_cast) + X_train_cast = X_train_np.astype(dtype_cast) + X_test_cast = X_test_np.astype(dtype_cast) est = estimator.clone() est.fit(X_train_cast, y_train) @@ -718,7 +725,5 @@ def check_common_input_dtypes(estimator, datatype): if hasattr(est, "predict"): est.predict(X_test_cast) - except Exception as e: - raise AssertionError( - f"{type(estimator).__name__} failed for dtype {dtype_cast}: {e}" - ) + except Exception: + return From bd1dfb46d7ca1f51be55dc0ad419a918c0267bff Mon Sep 17 00:00:00 2001 From: satwiksps <215063428+satwiksps@users.noreply.github.com> Date: Mon, 24 Nov 2025 20:23:44 +0000 Subject: [PATCH 3/3] Automatic `pre-commit` fixes --- aeon/testing/estimator_checking/_yield_estimator_checks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aeon/testing/estimator_checking/_yield_estimator_checks.py b/aeon/testing/estimator_checking/_yield_estimator_checks.py index db843bb4d8..a4355a1f1e 100644 --- a/aeon/testing/estimator_checking/_yield_estimator_checks.py +++ b/aeon/testing/estimator_checking/_yield_estimator_checks.py @@ -710,7 +710,6 @@ def check_common_input_dtypes(estimator, datatype): X_train_np = np.asarray(X_train) X_test_np = np.asarray(X_test) - if X_train_np.dtype == object: return