diff --git a/tslearn/metrics/soft_dtw_keras b/tslearn/metrics/soft_dtw_keras
new file mode 100644
index 00000000..0bf05c27
--- /dev/null
+++ b/tslearn/metrics/soft_dtw_keras
@@ -0,0 +1,186 @@
+try:
+    import numpy as np
+    import tensorflow as tf
+    import tensorflow.keras.backend as K
+    from tensorflow.keras import layers
+
+    HAS_DEPS = True
+except ImportError:
+    HAS_DEPS = False
+
+if not HAS_DEPS:
+    class SoftDTWLossTF:
+        def __init__(self, *args, **kwargs):
+            raise ValueError(
+                "Could not use SoftDTWLossTF since keras and/or tensorflow "
+                "are not installed"
+            )
+else:
+    DBL_MAX = np.finfo("double").max
+
+    def euclidean_squared_dist(x, y):
+        """Compute the pairwise squared Euclidean distances between the
+        timesteps of x and y.
+
+        Parameters
+        ----------
+        x : Tensor, shape=[b, m, d]
+            Batch of time series.
+        y : Tensor, shape=[b, n, d]
+            Batch of time series.
+
+        Returns
+        -------
+        dist : Tensor, shape=[b, m, n]
+            The pairwise squared Euclidean distances.
+        """
+        # Expand to (b, m, 1, d) and (b, 1, n, d) so that x - y broadcasts to
+        # (b, m, n, d), then reduce over the feature axis.
+        x = x[:, :, None]
+        y = y[:, None]
+        return K.sum(K.pow(x - y, 2), axis=3)
+
+    def softmin3(a, b, c, *, gamma):
+        r"""Compute the softmin of three input tensors with parameter gamma.
+
+        The inputs must be broadcastable against each other. In the limit case
+        :math:`\gamma \rightarrow 0`, the softmin operator reduces to a hard
+        min.
+
+        Parameters
+        ----------
+        a : Tensor
+            First input tensor.
+        b : Tensor
+            Second input tensor.
+        c : Tensor
+            Third input tensor.
+        gamma : float64
+            Regularization parameter.
+
+        Returns
+        -------
+        softmin_value : Tensor
+            Softmin value.
+        """
+        a /= -gamma
+        b /= -gamma
+        c /= -gamma
+
+        # Log-sum-exp trick: subtract the maximum before exponentiating for
+        # numerical stability.
+        max_val = K.maximum(K.maximum(a, b), c)
+
+        tmp = (
+            K.exp(a - max_val)
+            + K.exp(b - max_val)
+            + K.exp(c - max_val)
+        )
+        return -gamma * (K.log(tmp) + max_val)
+
+    @tf.custom_gradient
+    def soft_dtw(D, gamma):
+        """Compute soft dynamic time warping from a batch of cost matrices.
+
+        Parameters
+        ----------
+        D : Tensor, shape=[b, m, n]
+            Batch of pairwise cost matrices.
+        gamma : float64
+            Regularization parameter.
+
+        Returns
+        -------
+        sdtw : Tensor, shape=[b]
+            Soft-DTW value of each sample in the batch.
+        grad : callable
+            Gradient function used by ``tf.custom_gradient``.
+        """
+        b, m, n = D.shape
+
+        # Initialization.
+        # np.empty with dtype=object fills the array with None; unset cells
+        # stand for +DBL_MAX and are substituted on the fly in the recursion.
+        R = np.empty((m + 2, n + 2), dtype=object)
+        # Assigning a tensor to a slice doesn't work because of numpy's
+        # auto-broadcasting:
+        #   R[: m + 1, 0] = K.constant(DBL_MAX)
+        #   R[0, : n + 1] = K.constant(DBL_MAX)
+        # so None cells are checked for instead.
+        # Note: with the default float32 backend, DBL_MAX saturates to +inf,
+        # which still acts as a large sentinel inside softmin3.
+        dbl_max = K.constant(DBL_MAX)[None, None, None]
+
+        R[0, 0] = K.constant(0.0)[None, None, None]
+
+        # DP recursion.
+        for i in range(1, m + 1):
+            for j in range(1, n + 1):
+                # D is indexed starting from 0; the batch axis is kept first so
+                # that every R[i, j] has shape (b, 1, 1).
+                R[i, j] = D[:, i - 1, j - 1][:, None, None] + softmin3(
+                    R[i - 1, j] if R[i - 1, j] is not None else dbl_max,
+                    R[i - 1, j - 1] if R[i - 1, j - 1] is not None else dbl_max,
+                    R[i, j - 1] if R[i, j - 1] is not None else dbl_max,
+                    gamma=gamma,
+                )
+
+        def grad(upstream):
+            # Gradient of soft-DTW with respect to the cost matrix D.
+            # D_ gets an extra row and column (already zero from np.zeros) so
+            # that the backward recursion can index D_[i, j] for i == m or
+            # j == n.
+            D_ = np.zeros((m + 1, n + 1), dtype=object)
+            R_ = np.zeros((m + 2, n + 2), dtype=object)
+            E_ = np.zeros((m + 2, n + 2), dtype=object)
+
+            for i in range(m):
+                for j in range(n):
+                    D_[i, j] = D[:, i, j][:, None, None]
+
+            for i in range(m + 2):
+                for j in range(n + 2):
+                    R_[i, j] = R[i, j]
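+            # The gradient is obtained by the backward recursion of
+            # Cuturi & Blondel (2017):
+            #     E[i, j] = a * E[i+1, j] + b * E[i, j+1] + c * E[i+1, j+1]
+            # where a, b and c are soft transition weights computed from R and
+            # D. The sentinel values written into R_ below are the boundary
+            # conditions of that recursion.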
+            # Initialization of the boundary conditions.
+            for i in range(1, m + 1):
+                R_[i, n + 1] = -dbl_max
+            for j in range(1, n + 1):
+                R_[m + 1, j] = -dbl_max
+
+            E_[m + 1, n + 1] = 1
+            R_[m + 1, n + 1] = R[m, n]
+
+            # Backward DP recursion.
+            for j in range(n, 0, -1):  # ranges from n to 1
+                for i in range(m, 0, -1):  # ranges from m to 1
+                    a = K.exp((R_[i + 1, j] - R_[i, j] - D_[i, j - 1]) / gamma)
+                    b = K.exp((R_[i, j + 1] - R_[i, j] - D_[i - 1, j]) / gamma)
+                    c = K.exp((R_[i + 1, j + 1] - R_[i, j] - D_[i, j]) / gamma)
+                    E_[i, j] = E_[i + 1, j] * a + E_[i, j + 1] * b + E_[i + 1, j + 1] * c
+
+            # Collect the gradient entries E_[1:m+1, 1:n+1] into a (b, m, n)
+            # tensor.
+            ll = list(E_[1 : m + 1, 1 : n + 1].flatten())
+            for k in range(len(ll)):
+                # Safety net: turn any remaining integer placeholder into a
+                # constant tensor before concatenation (this should not trigger
+                # once the recursion above has filled every cell).
+                if isinstance(ll[k], int):
+                    ll[k] = K.constant(ll[k])[None, None, None]
+
+            E_T = layers.Reshape((m, n), name="dtw_grad_reshape")(
+                layers.Concatenate()(ll)
+            )
+
+            # Chain rule: upstream has shape (b,) and E_T has shape (b, m, n).
+            # No gradient is returned for gamma.
+            return upstream[:, None, None] * E_T, None
+
+        # R[m, n] holds the soft-DTW value of each sample; squeeze to shape (b,).
+        return R[m, n][:, 0, 0], grad
+
+    class SoftDTWLossTF(tf.keras.losses.Loss):
+        """Soft-DTW loss for Keras / TensorFlow models.
+
+        Parameters
+        ----------
+        gamma : float, default=1.0
+            Regularization parameter of the soft-min. Lower values bring the
+            loss closer to the (hard) DTW score.
+        normalize : bool, default=False
+            If True, use the normalized soft-DTW
+            :math:`D(x, y) - (D(x, x) + D(y, y)) / 2`.
+        dist_func : callable, default=None
+            Distance function applied to the batches of time series; defaults
+            to the pairwise squared Euclidean distance.
+        """
+
+        def __init__(self, gamma=1.0, normalize=False, dist_func=None):
+            super().__init__()
+            self.gamma = gamma
+            self.normalize = normalize
+            self.dist_func = (
+                dist_func if dist_func is not None else euclidean_squared_dist
+            )
+
+        def call(self, y_true, y_pred):
+            x = y_true
+            y = y_pred
+
+            bx, lx, dx = x.shape
+            by, ly, dy = y.shape
+
+            assert bx == by
+            assert dx == dy
+
+            d_xy = self.dist_func(x, y)
+            loss_xy = soft_dtw(d_xy, self.gamma)
+
+            if self.normalize:
+                d_xx = self.dist_func(x, x)
+                d_yy = self.dist_func(y, y)
+                loss_xx = soft_dtw(d_xx, self.gamma)
+                loss_yy = soft_dtw(d_yy, self.gamma)
+                return loss_xy - 0.5 * (loss_xx + loss_yy)
+            else:
+                return loss_xy
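For reference, a minimal usage sketch of the proposed loss (not part of the diff). It assumes the new file ends up importable as tslearn.metrics.soft_dtw_keras (the diff adds it without a .py extension, so the exact import path may differ), and it only exercises the forward pass on a toy batch; the array sizes and the gamma value are made up for illustration:

    import numpy as np
    from tslearn.metrics.soft_dtw_keras import SoftDTWLossTF  # hypothetical import path

    # Two batches of univariate time series of length 8.
    rng = np.random.default_rng(0)
    y_true = rng.random((2, 8, 1)).astype("float32")
    y_pred = rng.random((2, 8, 1)).astype("float32")

    # gamma controls the smoothness of the soft-min; normalize subtracts the
    # self-comparison terms so that comparing a series with itself gives zero.
    loss_fn = SoftDTWLossTF(gamma=0.1, normalize=True)
    value = loss_fn(y_true, y_pred)  # scalar tensor after Keras' reduction
    print(float(value))

Such a loss would typically be passed to model.compile(loss=SoftDTWLossTF(...)) for sequence-to-sequence models whose outputs have static length and feature dimensions, since the dynamic-programming loops need concrete values of m and n at trace time.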