Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions tslearn/metrics/ctw.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
The :mod:`tslearn.metrics:ctc` module provides utilities related to
Canonical Time Warping.
"""

import numpy as np
from sklearn.cross_decomposition import CCA

Expand Down Expand Up @@ -157,15 +162,14 @@ def ctw_path(
# (possibly truncated to a fixed number of features) inputs
seq1_tr = s1 @ be.eye(s1.shape[1], n_components, dtype=be.float64)
seq2_tr = s2 @ be.eye(s2.shape[1], n_components, dtype=be.float64)
current_path, score_match = dtw_path(
current_path, current_score = dtw_path(
seq1_tr,
seq2_tr,
global_constraint=global_constraint,
sakoe_chiba_radius=sakoe_chiba_radius,
itakura_max_slope=itakura_max_slope,
be=be,
)
current_score = score_match

if verbose:
print("Iteration 0, score={}".format(current_score))
Expand All @@ -176,7 +180,7 @@ def ctw_path(
cca.fit(Wx @ s1, Wy @ s2)
seq1_tr, seq2_tr = cca.transform(s1, s2)

current_path, score_match = dtw_path(
new_path, new_score = dtw_path(
seq1_tr,
seq2_tr,
global_constraint=global_constraint,
Expand All @@ -185,10 +189,10 @@ def ctw_path(
be=be,
)

if np.array_equal(current_path, current_path):
if np.array_equal(current_path, new_path):
break

current_score = score_match
current_path, current_score = new_path, new_score

if verbose:
print("Iteration {}, score={}".format(it + 1, current_score))
Expand Down Expand Up @@ -329,7 +333,7 @@ def cdist_ctw(
If shape is (n_ts1, sz1), the dataset is composed of univariate time series.
If shape is (sz1,), the dataset is composed of a unique univariate time series.
dataset2 : None or array-like, shape=(n_ts2, sz2, d) or (n_ts2, sz2) or (sz2,) (default: None)
Another dataset of time series.
Another dataset of time series.
If `None`, self-similarity of `dataset1` is returned.
If shape is (n_ts2, sz2), the dataset is composed of univariate time series.
If shape is (sz2,), the dataset is composed of a unique univariate time series.
Expand Down
16 changes: 16 additions & 0 deletions tslearn/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,22 @@ def test_ctw():
np.testing.assert_allclose(dist, 1.0)
assert backend.belongs_to_backend(dist)

x = [[1, 1], [3, 4], [126, 126]]
y = [[1, 1.], [3., 3], [4., 4], [2., 2], [0, 0], [127, 127]]
dist_0 = tslearn.metrics.ctw_path(
cast(x, array_type),
cast(y, array_type),
max_iter=2,
be=be
)[2]
dist_1 = tslearn.metrics.ctw_path(
cast(x, array_type),
cast(y, array_type),
max_iter=3,
be=be
)[2]
assert dist_0 >= dist_1

# dtw
n1, n2, d1, d2 = 15, 10, 3, 1
rng = np.random.RandomState(0)
Expand Down
Loading