Skip to content

Commit e2e64ce

Browse files
authored
kasba bug fixes (#3120)
1 parent b999d55 commit e2e64ce

File tree

3 files changed

+48
-45
lines changed

3 files changed

+48
-45
lines changed

aeon/clustering/averaging/_kasba_average.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def _kasba_refine_one_iter(
248248
for i in shuffled_indices:
249249
curr_ts = X[i]
250250
curr_alignment, _ = _get_alignment_path(
251-
center=barycenter,
251+
center=barycenter_copy,
252252
ts=curr_ts,
253253
distance=distance,
254254
window=window,

aeon/clustering/averaging/tests/test_kasba.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ def test_kasba_ba_expected():
6363
],
6464
)
6565
def test_kasba_ba_uni(distance, init_barycenter):
66-
"""Test kasba dba functionality."""
66+
"""Test kasba ba functionality."""
6767
distance = distance[0]
68-
X_train_uni = make_example_3d_numpy(10, 1, 10, random_state=1, return_y=False)
68+
X_train_uni = make_example_3d_numpy(20, 1, 10, random_state=1, return_y=False)
6969

7070
params = {
7171
"window": 0.2,
@@ -92,7 +92,6 @@ def test_kasba_ba_uni(distance, init_barycenter):
9292
assert average_ts_uni.shape == X_train_uni[0].shape
9393
assert np.allclose(average_ts_uni, call_directly_average_ts_uni)
9494

95-
# EDR and shape_dtw with random values don't update the barycenter so skipping
9695
if distance not in ["shape_dtw", "edr"]:
9796
# Test not just returning the init barycenter
9897
assert not np.array_equal(average_ts_uni, init_barycenter)
@@ -144,7 +143,11 @@ def test_kasba_distance_params(distance):
144143
"""Test kasba with various distance parameters."""
145144
distance_params = distance[1]
146145
distance = distance[0]
147-
X_train_uni = make_example_3d_numpy(10, 1, 10, random_state=1, return_y=False)
146+
if distance == "soft_dtw":
147+
# Skip for now and add back when soft-dtw refactored
148+
return
149+
150+
X_train_uni = make_example_3d_numpy(20, 1, 10, random_state=1, return_y=False)
148151

149152
for key in distance_params:
150153
curr_param = {key: distance_params[key]}

aeon/testing/expected_results/expected_average_results.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -329,54 +329,54 @@
329329

330330
expected_kasba_ba_univariate = [
331331
[
332-
1.66927775,
333-
2.00725484,
334-
0.95286261,
335-
1.98356308,
336-
0.81880663,
337-
1.37102377,
338-
1.77943005,
339-
1.37294541,
340-
1.33994775,
341-
1.67563731,
332+
1.6974843043271997,
333+
1.9018800362618855,
334+
1.1382045429284227,
335+
1.8251262846578693,
336+
1.0275022774799252,
337+
1.4765112774785576,
338+
1.7094282969055936,
339+
1.3646825812938472,
340+
1.1952637933803256,
341+
1.6600703806952517,
342342
]
343343
]
344344

345345
expected_kasba_ba_multivariate = [
346346
[
347-
0.71767481,
348-
1.39930349,
349-
1.29434224,
350-
0.72628471,
351-
1.12290607,
352-
1.23966907,
353-
0.78614546,
354-
1.56670693,
355-
1.09827156,
356-
1.57560871,
347+
0.7176748106272357,
348+
1.3993034910724549,
349+
1.2943422445754889,
350+
0.7262847133788719,
351+
1.1005654993888412,
352+
1.2172160392802192,
353+
0.7861454601925223,
354+
1.5667069293313187,
355+
1.098271560623938,
356+
1.5756087067076097,
357357
],
358358
[
359-
1.13660818,
360-
1.36478053,
361-
1.09463845,
362-
1.65103481,
363-
0.97217855,
364-
1.41841601,
365-
1.2157067,
366-
1.11209033,
367-
1.30966526,
368-
1.47004663,
359+
1.1366081768662664,
360+
1.3647805338946,
361+
1.094638448269596,
362+
1.6510348057425757,
363+
1.0445155948956755,
364+
1.3809953389562177,
365+
1.215706700151818,
366+
1.1120903277144476,
367+
1.3096652637918798,
368+
1.4700466261528247,
369369
],
370370
[
371-
1.19600771,
372-
1.52977945,
373-
1.38014143,
374-
1.08513837,
375-
1.18066111,
376-
0.99505074,
377-
1.12834455,
378-
1.0172058,
379-
1.32168863,
380-
1.30702151,
371+
1.1960077112744116,
372+
1.5297794504034785,
373+
1.3801414272436885,
374+
1.0851383745767733,
375+
1.0978851749939638,
376+
1.085522061831642,
377+
1.1283445515251094,
378+
1.0172058031512394,
379+
1.3216886255042977,
380+
1.3070215133512242,
381381
],
382382
]

0 commit comments

Comments
 (0)