diff --git a/aeon/classification/distance_based/_elastic_ensemble.py b/aeon/classification/distance_based/_elastic_ensemble.py index c0b0bce805..73068a8ce3 100644 --- a/aeon/classification/distance_based/_elastic_ensemble.py +++ b/aeon/classification/distance_based/_elastic_ensemble.py @@ -42,7 +42,9 @@ class ElasticEnsemble(BaseClassifier): A ``list`` of strings identifying which distance measures to include. Valid values are one or more of: ``euclidean``, ``dtw``, ``wdtw``, ``ddtw``, ``wddtw``, ``lcss``, ``erp``, ``msm``, ``twe``. The default value ``all`` means that all - the previously listed distances are used. + the previously listed distances are used. The special value ``ts-quad`` can be + used to select the distance measures for the TS-QUAD ensemble: WDTW, DDTW, + LCSS, and MSM. proportion_of_param_options : float, default=1 The proportion of the parameter grid space to search optional. proportion_train_in_param_finding : float, default=1 @@ -153,6 +155,17 @@ def _fit(self, X, y): "euclidean", "twe", ] + elif self.distance_measures == "ts-quad": + self._distance_measures = [ + "wdtw", + "ddtw", + "lcss", + "msm", + ] + if self.verbose > 0: + print( # noqa: T201 + "Configuring ElasticEnsemble as TS-QUAD with WDTW, DDTW, LCSS, MSM." + ) else: self._distance_measures = self.distance_measures @@ -515,6 +528,13 @@ def _get_test_params(cls, parameter_set: str = "default") -> dict | list[dict]: "majority_vote": True, "distance_measures": ["dtw", "ddtw", "wdtw"], } + elif parameter_set == "ts-quad": + return { + "proportion_of_param_options": 0.01, + "proportion_train_for_test": 0.1, + "majority_vote": True, + "distance_measures": "ts-quad", + } else: return { "proportion_of_param_options": 0.01, diff --git a/aeon/classification/distance_based/tests/test_elastic_ensemble.py b/aeon/classification/distance_based/tests/test_elastic_ensemble.py index 53215dad2f..a246d6ebc6 100644 --- a/aeon/classification/distance_based/tests/test_elastic_ensemble.py +++ b/aeon/classification/distance_based/tests/test_elastic_ensemble.py @@ -90,3 +90,17 @@ def test_all_distance_measures(): ee.fit(X, y) distances = list(ee.get_metric_params()) assert len(distances) == 9 + + +def test_ts_quad_distance_measures(): + """Test the 'ts-quad' option of the distance_measures parameter.""" + X = np.random.random(size=(10, 1, 10)) + y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + ee = ElasticEnsemble( + distance_measures="ts-quad", + proportion_train_in_param_finding=0.2, + proportion_of_param_options=0.1, + ) + ee.fit(X, y) + actual_distances = list(ee.get_metric_params()) + assert len(actual_distances) == 4 diff --git a/docs/changelogs/v1.3.md b/docs/changelogs/v1.3.md index eae8566a67..6e250984ab 100644 --- a/docs/changelogs/v1.3.md +++ b/docs/changelogs/v1.3.md @@ -35,6 +35,7 @@ September 2025 ### Enhancements +- [ENH] Implement TS-QUAD as a distance parameter for `ElasticEnsemble` ({pr}`3126`) {user}`Nithurshen` - [ENH] Improvements to ST transformer and classifier ({pr}`2968`) {user}`MatthewMiddlehurst` - [ENH] KNN n_jobs and updated kneighbours method ({pr}`2578`) {user}`chrisholder` - [ENH] Refactor signature code ({pr}`2943`) {user}`TonyBagnall`