Skip to content

Commit aa096ca

Browse files
authored
Merge pull request #180 from lucasimi/bugfix/cubical-cover-extra-interval
Bugfix/cubical cover extra interval
2 parents b07de8a + 78b66da commit aa096ca

File tree

4 files changed

+50
-24
lines changed

4 files changed

+50
-24
lines changed

src/tdamapper/cover.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,14 @@ def fit(self, X):
303303
else:
304304
self.__overlap_frac = self.overlap_frac
305305
self.__n_intervals = self.n_intervals
306-
if (self.__overlap_frac <= 0.0) or (self.__overlap_frac > 0.5):
306+
if self.__overlap_frac <= 0.0:
307+
raise ValueError(
308+
'The parameter overlap_frac is expected to be '
309+
'> 0.0'
310+
)
311+
if self.__overlap_frac > 0.5:
307312
warn_user(
308-
'The parameter overlap_frac is expected within range '
309-
'(0.0, 0.5]'
313+
'The parameter overlap_frac is expected to be <= 0.5'
310314
)
311315
self.__min, self.__max, self.__delta = self._get_bounds(X)
312316
radius = 1.0 / (2.0 - 2.0 * self.__overlap_frac)
@@ -337,16 +341,20 @@ def search(self, x):
337341
return self.__cover.search(center)
338342

339343
def _get_center(self, x):
340-
cell = self.__n_intervals * (x - self.__min) // self.__delta
344+
offset = self._offset(x)
341345
center = self._phi(x)
342-
return tuple(cell), center
346+
return tuple(offset), center
343347

344348
def _get_overlap_frac(self, dim, overlap_vol_frac):
345349
beta = math.pow(1.0 - overlap_vol_frac, 1.0 / dim)
346350
return 1.0 - 1.0 / (2.0 - beta)
347351

352+
def _offset(self, x):
353+
return np.minimum(self.__n_intervals - 1, np.floor(self._gamma_n(x)))
354+
348355
def _phi(self, x):
349-
return self._gamma_n_inv(0.5 + np.floor(self._gamma_n(x)))
356+
offset = self._offset(x)
357+
return self._gamma_n_inv(0.5 + offset)
350358

351359
def _gamma_n(self, x):
352360
return self.__n_intervals * (x - self.__min) / self.__delta

tests/test_unit_core.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,14 @@ def test_proximity_cubical_line(self):
144144
g = mp.fit_transform(data, data)
145145
self.assertEqual(4, len(g.nodes))
146146

147+
def test_standard_cubical_line(self):
148+
data = np.array([[float(i)] for i in range(1000)])
149+
cover = StandardCubicalCover(n_intervals=4, overlap_frac=0.5)
150+
clustering = TrivialClustering()
151+
mp = MapperAlgorithm(cover, clustering)
152+
g = mp.fit_transform(data, data)
153+
self.assertEqual(4, len(g.nodes))
154+
147155
def test_cubical_line(self):
148156
data = np.array([[float(i)] for i in range(1000)])
149157
cover = CubicalCover(n_intervals=4, overlap_frac=0.5)
@@ -152,6 +160,14 @@ def test_cubical_line(self):
152160
g = mp.fit_transform(data, data)
153161
self.assertEqual(4, len(g.nodes))
154162

163+
def test_cubical_no_overlap(self):
164+
data = np.array([[0.0], [1.0], [2.0]])
165+
cover = StandardCubicalCover(n_intervals=2, overlap_frac=0)
166+
clustering = TrivialClustering()
167+
mp = MapperAlgorithm(cover, clustering)
168+
with self.assertRaises(ValueError):
169+
mp.fit_transform(data, data)
170+
155171
def test_mock_connected_components(self):
156172
data = [0, 1, 2, 3]
157173

tests/test_unit_params.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ def __test_clone(self, obj):
1818
obj_repr = repr(obj)
1919
obj_cln = clone(obj)
2020
cln_repr = repr(obj_cln)
21-
self.assertEquals(obj_repr, cln_repr)
21+
self.assertEqual(obj_repr, cln_repr)
2222

2323
def __test_repr(self, obj):
2424
obj_repr = repr(obj)
2525
_obj = eval(obj_repr)
2626
_obj_repr = repr(_obj)
27-
self.assertEquals(obj_repr, _obj_repr)
27+
self.assertEqual(obj_repr, _obj_repr)
2828

2929
def __test_clone_and_repr(self, obj):
3030
self.__test_clone(obj)
@@ -38,16 +38,16 @@ def test_params_mapper_algorithm(self):
3838
),
3939
)
4040
params = est.get_params(deep=False)
41-
self.assertEquals(5, len(params))
41+
self.assertEqual(5, len(params))
4242
params = est.get_params()
43-
self.assertEquals(12, len(params))
44-
self.assertEquals(3, params['cover__n_intervals'])
45-
self.assertEquals(0.3, params['cover__overlap_frac'])
43+
self.assertEqual(12, len(params))
44+
self.assertEqual(3, params['cover__n_intervals'])
45+
self.assertEqual(0.3, params['cover__overlap_frac'])
4646
est.set_params(cover__n_intervals=2, cover__overlap_frac=0.2)
4747
params = est.get_params()
48-
self.assertEquals(12, len(params))
49-
self.assertEquals(2, params['cover__n_intervals'])
50-
self.assertEquals(0.2, params['cover__overlap_frac'])
48+
self.assertEqual(12, len(params))
49+
self.assertEqual(2, params['cover__n_intervals'])
50+
self.assertEqual(0.2, params['cover__overlap_frac'])
5151

5252
def test_params_mapper_clustering(self):
5353
est = MapperClustering(
@@ -57,16 +57,16 @@ def test_params_mapper_clustering(self):
5757
),
5858
)
5959
params = est.get_params(deep=False)
60-
self.assertEquals(3, len(params))
60+
self.assertEqual(3, len(params))
6161
params = est.get_params()
62-
self.assertEquals(10, len(params))
63-
self.assertEquals(3, params['cover__n_intervals'])
64-
self.assertEquals(0.3, params['cover__overlap_frac'])
62+
self.assertEqual(10, len(params))
63+
self.assertEqual(3, params['cover__n_intervals'])
64+
self.assertEqual(0.3, params['cover__overlap_frac'])
6565
est.set_params(cover__n_intervals=2, cover__overlap_frac=0.2)
6666
params = est.get_params()
67-
self.assertEquals(10, len(params))
68-
self.assertEquals(2, params['cover__n_intervals'])
69-
self.assertEquals(0.2, params['cover__overlap_frac'])
67+
self.assertEqual(10, len(params))
68+
self.assertEqual(2, params['cover__n_intervals'])
69+
self.assertEqual(0.2, params['cover__overlap_frac'])
7070

7171
def test_clone_and_repr_ball_cover(self):
7272
self.__test_clone_and_repr(BallCover())

tests/test_unit_proximity.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_cubical_proximity(self):
4343
data = list(range(m, M + 1))
4444
cover = CubicalCover(n_intervals=n, overlap_frac=p)
4545
cover.fit(data)
46-
for x in data:
46+
for x in data[:-1]:
4747
result = cover.search(x)
4848
i = math.floor((x - m) / (w - delta))
4949
a_i = m + i * (w - delta) - delta / 2.0
@@ -53,7 +53,9 @@ def test_cubical_proximity(self):
5353
self.assertTrue(c in expected)
5454
for c in expected:
5555
self.assertTrue(c in result)
56-
#self.assertEqual(set(expected), set(result))
56+
x = data[-1]
57+
last_result = cover.search(x)
58+
self.assertEqual(result, last_result)
5759

5860
def test_cubical_params(self):
5961
cover = CubicalCover(n_intervals=10, overlap_frac=0.5)

0 commit comments

Comments
 (0)