diff --git a/tslearn/metrics/ctw.py b/tslearn/metrics/ctw.py index f3d48cef..97546b73 100644 --- a/tslearn/metrics/ctw.py +++ b/tslearn/metrics/ctw.py @@ -1,3 +1,8 @@ +""" +The :mod:`tslearn.metrics:ctc` module provides utilities related to +Canonical Time Warping. +""" + import numpy as np from sklearn.cross_decomposition import CCA @@ -157,7 +162,7 @@ def ctw_path( # (possibly truncated to a fixed number of features) inputs seq1_tr = s1 @ be.eye(s1.shape[1], n_components, dtype=be.float64) seq2_tr = s2 @ be.eye(s2.shape[1], n_components, dtype=be.float64) - current_path, score_match = dtw_path( + current_path, current_score = dtw_path( seq1_tr, seq2_tr, global_constraint=global_constraint, @@ -165,7 +170,6 @@ def ctw_path( itakura_max_slope=itakura_max_slope, be=be, ) - current_score = score_match if verbose: print("Iteration 0, score={}".format(current_score)) @@ -176,7 +180,7 @@ def ctw_path( cca.fit(Wx @ s1, Wy @ s2) seq1_tr, seq2_tr = cca.transform(s1, s2) - current_path, score_match = dtw_path( + new_path, new_score = dtw_path( seq1_tr, seq2_tr, global_constraint=global_constraint, @@ -185,10 +189,10 @@ def ctw_path( be=be, ) - if np.array_equal(current_path, current_path): + if np.array_equal(current_path, new_path): break - current_score = score_match + current_path, current_score = new_path, new_score if verbose: print("Iteration {}, score={}".format(it + 1, current_score)) @@ -329,7 +333,7 @@ def cdist_ctw( If shape is (n_ts1, sz1), the dataset is composed of univariate time series. If shape is (sz1,), the dataset is composed of a unique univariate time series. dataset2 : None or array-like, shape=(n_ts2, sz2, d) or (n_ts2, sz2) or (sz2,) (default: None) - Another dataset of time series. + Another dataset of time series. If `None`, self-similarity of `dataset1` is returned. If shape is (n_ts2, sz2), the dataset is composed of univariate time series. If shape is (sz2,), the dataset is composed of a unique univariate time series. diff --git a/tslearn/tests/test_metrics.py b/tslearn/tests/test_metrics.py index 20d93c41..49fe2704 100644 --- a/tslearn/tests/test_metrics.py +++ b/tslearn/tests/test_metrics.py @@ -77,6 +77,22 @@ def test_ctw(): np.testing.assert_allclose(dist, 1.0) assert backend.belongs_to_backend(dist) + x = [[1, 1], [3, 4], [126, 126]] + y = [[1, 1.], [3., 3], [4., 4], [2., 2], [0, 0], [127, 127]] + dist_0 = tslearn.metrics.ctw_path( + cast(x, array_type), + cast(y, array_type), + max_iter=2, + be=be + )[2] + dist_1 = tslearn.metrics.ctw_path( + cast(x, array_type), + cast(y, array_type), + max_iter=3, + be=be + )[2] + assert dist_0 >= dist_1 + # dtw n1, n2, d1, d2 = 15, 10, 3, 1 rng = np.random.RandomState(0)