Skip to content

Commit 363a1f2

Browse files
authored
Merge pull request #181 from lucasimi/develop
Develop
2 parents 30f1eb0 + aa096ca commit 363a1f2

File tree

5 files changed

+241
-42
lines changed

5 files changed

+241
-42
lines changed

src/tdamapper/_common.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,14 @@ class ParamsMixin:
8181
scikit-learn `get_params` and `set_params`.
8282
"""
8383

84-
def __is_param_internal(self, k):
85-
return k.startswith('_') or k.endswith('_')
84+
def __is_param_public(self, k):
85+
return (not k.startswith('_')) and (not k.endswith('_'))
86+
87+
def __split_param(self, k):
88+
k_split = k.split('__')
89+
outer = k_split[0]
90+
inner = '__'.join(k_split[1:])
91+
return outer, inner
8692

8793
def get_params(self, deep=True):
8894
"""
@@ -91,20 +97,47 @@ def get_params(self, deep=True):
9197
:param deep: A flag for returning also nested parameters.
9298
:type deep: bool, optional.
9399
"""
94-
params = self.__dict__.items()
95-
return {k: v for k, v in params if not self.__is_param_internal(k)}
100+
params = {}
101+
for k, v in self.__dict__.items():
102+
if self.__is_param_public(k):
103+
params[k] = v
104+
if hasattr(v, 'get_params') and deep:
105+
for _k, _v in v.get_params().items():
106+
params[f'{k}__{_k}'] = _v
107+
return params
96108

97109
def set_params(self, **params):
    """
    Set public parameters. Only updates attributes that already exist.

    Keys of the form ``outer__inner`` are forwarded to the ``set_params``
    method of the attribute ``outer``, after all top-level parameters have
    been assigned, so that a nested object replaced in the same call
    receives its nested values. Non-public keys, unknown attributes and
    attributes that do not expose ``set_params`` are ignored.

    :param params: The parameters to set, given as keyword arguments.
    :return: The object itself.
    :rtype: self
    """
    # Nested values grouped by the attribute they belong to, so that each
    # nested object is updated with a single set_params call.
    nested_params = {}
    for k, v in params.items():
        if not self.__is_param_public(k):
            continue
        k_outer, k_inner = self.__split_param(k)
        if k_inner:
            nested_params.setdefault(k_outer, {})[k_inner] = v
        elif hasattr(self, k_outer):
            setattr(self, k_outer, v)
    for k_outer, k_inner_params in nested_params.items():
        k_attr = getattr(self, k_outer, None)
        # Mirror the hasattr(v, 'get_params') guard used by get_params:
        # an attribute such as None cannot receive nested parameters.
        if hasattr(k_attr, 'set_params'):
            k_attr.set_params(**k_inner_params)
    return self
105127

128+
def __repr__(self):
    """
    Return a constructor-like representation of the object.

    Only public attributes whose repr differs from the one found on a
    default-constructed instance are listed, so the class must be
    constructible without arguments.

    :return: A string of the form ``ClassName(param=value, ...)``.
    :rtype: str
    """
    obj_noargs = type(self)()
    # Sentinel marking attributes the default instance does not have.
    missing = object()
    args_repr = []
    for k, v in self.__dict__.items():
        # Filter first: private attributes created after construction
        # (for instance while fitting) do not exist on the default
        # instance and must not be looked up on it.
        if not self.__is_param_public(k):
            continue
        v_default = getattr(obj_noargs, k, missing)
        v_repr = repr(v)
        if v_default is missing or v_repr != repr(v_default):
            args_repr.append(f'{k}={v_repr}')
    return f"{self.__class__.__name__}({', '.join(args_repr)})"
138+
106139

107-
def clone(estimator):
140+
def clone(obj):
108141
"""
109142
Clone an estimator, returning a new one, unfitted, having the same public
110143
parameters.
@@ -114,5 +147,7 @@ def clone(estimator):
114147
:return: A new estimator with the same parameters.
115148
:rtype: A scikit-learn compatible estimator
116149
"""
117-
params = estimator.get_params(deep=True)
118-
return type(estimator)(**params)
150+
params = obj.get_params(deep=True)
151+
obj_noargs = type(obj)()
152+
obj_noargs.set_params(**params)
153+
return obj_noargs

src/tdamapper/cover.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,14 @@ def fit(self, X):
303303
else:
304304
self.__overlap_frac = self.overlap_frac
305305
self.__n_intervals = self.n_intervals
306-
if (self.__overlap_frac <= 0.0) or (self.__overlap_frac > 0.5):
306+
if self.__overlap_frac <= 0.0:
307+
raise ValueError(
308+
'The parameter overlap_frac is expected to be '
309+
'> 0.0'
310+
)
311+
if self.__overlap_frac > 0.5:
307312
warn_user(
308-
'The parameter overlap_frac is expected within range '
309-
'(0.0, 0.5]'
313+
'The parameter overlap_frac is expected to be <= 0.5'
310314
)
311315
self.__min, self.__max, self.__delta = self._get_bounds(X)
312316
radius = 1.0 / (2.0 - 2.0 * self.__overlap_frac)
@@ -337,16 +341,20 @@ def search(self, x):
337341
return self.__cover.search(center)
338342

339343
def _get_center(self, x):
    """
    Return the cell key and the cell center associated with a point.

    :param x: The point to locate.
    :return: A pair made of the offset returned by ``_offset`` converted
        to a tuple, and the center returned by ``_phi``.
    :rtype: tuple
    """
    cell = self._offset(x)
    midpoint = self._phi(x)
    # The offset is turned into a tuple so the result has a hashable key.
    return tuple(cell), midpoint
343347

344348
def _get_overlap_frac(self, dim, overlap_vol_frac):
345349
beta = math.pow(1.0 - overlap_vol_frac, 1.0 / dim)
346350
return 1.0 - 1.0 / (2.0 - beta)
347351

352+
def _offset(self, x):
    """
    Return the floored value of ``_gamma_n(x)``, capped at
    ``n_intervals - 1``.

    The cap keeps values that would floor to ``n_intervals`` or more in
    the last cell.

    :param x: The point to locate.
    :return: The element-wise minimum of the cap and the floored value.
    """
    last_index = self.__n_intervals - 1
    floored = np.floor(self._gamma_n(x))
    return np.minimum(last_index, floored)
354+
348355
def _phi(self, x):
    """
    Return the center of the cell containing a point.

    The offset of the cell is shifted by one half and passed to
    ``_gamma_n_inv``.

    :param x: The point to locate.
    :return: The value of ``_gamma_n_inv`` at the shifted offset.
    """
    shifted = 0.5 + self._offset(x)
    return self._gamma_n_inv(shifted)
350358

351359
def _gamma_n(self, x):
    """
    Rescale a point: shift by the stored minimum, multiply by the number
    of intervals, then divide by the stored delta.

    :param x: The point to rescale.
    :return: ``n_intervals * (x - min) / delta``.
    """
    # Multiply before dividing to keep the original evaluation order.
    scaled = self.__n_intervals * (x - self.__min)
    return scaled / self.__delta
@@ -512,24 +520,20 @@ def __init__(
512520
self.leaf_capacity = leaf_capacity
513521
self.leaf_radius = leaf_radius
514522
self.pivoting = pivoting
515-
if algorithm == 'proximity':
516-
self.__cubical_cover = ProximityCubicalCover(
517-
n_intervals=n_intervals,
518-
overlap_frac=overlap_frac,
519-
kind=kind,
520-
leaf_capacity=leaf_capacity,
521-
leaf_radius=leaf_radius,
522-
pivoting=pivoting,
523-
)
524-
elif algorithm == 'standard':
525-
self.__cubical_cover = StandardCubicalCover(
526-
n_intervals=n_intervals,
527-
overlap_frac=overlap_frac,
528-
kind=kind,
529-
leaf_capacity=leaf_capacity,
530-
leaf_radius=leaf_radius,
531-
pivoting=pivoting,
532-
)
523+
524+
def __get_cubical_cover(self):
525+
params = dict(
526+
n_intervals=self.n_intervals,
527+
overlap_frac=self.overlap_frac,
528+
kind=self.kind,
529+
leaf_capacity=self.leaf_capacity,
530+
leaf_radius=self.leaf_radius,
531+
pivoting=self.pivoting,
532+
)
533+
if self.algorithm == 'proximity':
534+
return ProximityCubicalCover(**params)
535+
elif self.algorithm == 'standard':
536+
return StandardCubicalCover(**params)
533537
else:
534538
raise ValueError(
535539
"The only possible values for algorithm are 'standard' and "
@@ -548,7 +552,9 @@ def fit(self, X):
548552
:return: The object itself.
549553
:rtype: self
550554
"""
551-
return self.__cubical_cover.fit(X)
555+
self.__cubical_cover = self.__get_cubical_cover()
556+
self.__cubical_cover.fit(X)
557+
return self
552558

553559
def search(self, x):
554560
"""
@@ -576,4 +582,5 @@ def apply(self, X):
576582
:return: A generator of lists of ids.
577583
:rtype: generator of lists of ints
578584
"""
585+
self.__cubical_cover = self.__get_cubical_cover()
579586
return self.__cubical_cover.apply(X)

tests/test_unit_core.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
mapper_labels,
1111
TrivialCover,
1212
)
13-
from tdamapper.cover import BallCover
13+
from tdamapper.cover import BallCover, CubicalCover, ProximityCubicalCover, StandardCubicalCover
1414
from tdamapper.clustering import TrivialClustering
1515

1616

@@ -80,7 +80,7 @@ def test_ball_large_radius(self):
8080
ccs2 = mapper_connected_components(data, data, cover, clustering)
8181
self.assertEqual(len(data), len(ccs2))
8282

83-
def test_two_disconnected_clusters(self):
83+
def test_ball_two_disconnected_clusters(self):
8484
data = [np.array([float(i), 0.0]) for i in range(100)]
8585
data.extend([np.array([float(i), 500.0]) for i in range(100)])
8686
data = np.array(data)
@@ -98,7 +98,7 @@ def test_two_disconnected_clusters(self):
9898
ccs2 = mapper_connected_components(data, data, cover, clustering)
9999
self.assertEqual(len(data), len(ccs2))
100100

101-
def test_two_connected_clusters(self):
101+
def test_ball_two_connected_clusters(self):
102102
data = [
103103
np.array([0.0, 1.0]), np.array([1.0, 0.0]),
104104
np.array([0.0, 0.0]), np.array([1.0, 1.0])]
@@ -116,7 +116,7 @@ def test_two_connected_clusters(self):
116116
ccs2 = mapper_connected_components(data, data, cover, clustering)
117117
self.assertEqual(len(data), len(ccs2))
118118

119-
def test_two_connected_clusters_parallel(self):
119+
def test_ball_two_connected_clusters_parallel(self):
120120
data = [
121121
np.array([0.0, 1.0]), np.array([1.0, 0.0]),
122122
np.array([0.0, 0.0]), np.array([1.0, 1.0])]
@@ -136,7 +136,39 @@ def test_two_connected_clusters_parallel(self):
136136
ccs2 = mapper_connected_components(data, data, cover, clustering)
137137
self.assertEqual(len(data), len(ccs2))
138138

139-
def test_connected_components(self):
139+
def test_proximity_cubical_line(self):
    """
    A line of 1000 points covered by a proximity cubical cover with 4
    intervals must produce a graph with 4 nodes.
    """
    points = np.array([[float(i)] for i in range(1000)])
    mapper = MapperAlgorithm(
        ProximityCubicalCover(n_intervals=4, overlap_frac=0.5),
        TrivialClustering(),
    )
    graph = mapper.fit_transform(points, points)
    self.assertEqual(4, len(graph.nodes))
146+
147+
def test_standard_cubical_line(self):
    """
    A line of 1000 points covered by a standard cubical cover with 4
    intervals must produce a graph with 4 nodes.
    """
    points = np.array([[float(i)] for i in range(1000)])
    mapper = MapperAlgorithm(
        StandardCubicalCover(n_intervals=4, overlap_frac=0.5),
        TrivialClustering(),
    )
    graph = mapper.fit_transform(points, points)
    self.assertEqual(4, len(graph.nodes))
154+
155+
def test_cubical_line(self):
    """
    A line of 1000 points covered by a cubical cover with 4 intervals
    must produce a graph with 4 nodes.
    """
    points = np.array([[float(i)] for i in range(1000)])
    mapper = MapperAlgorithm(
        CubicalCover(n_intervals=4, overlap_frac=0.5),
        TrivialClustering(),
    )
    graph = mapper.fit_transform(points, points)
    self.assertEqual(4, len(graph.nodes))
162+
163+
def test_cubical_no_overlap(self):
    """
    Running the mapper with a standard cubical cover whose overlap
    fraction is zero must raise a ValueError.
    """
    points = np.array([[0.0], [1.0], [2.0]])
    mapper = MapperAlgorithm(
        StandardCubicalCover(n_intervals=2, overlap_frac=0),
        TrivialClustering(),
    )
    self.assertRaises(ValueError, mapper.fit_transform, points, points)
170+
171+
def test_mock_connected_components(self):
140172
data = [0, 1, 2, 3]
141173

142174
class MockCover:
@@ -156,7 +188,7 @@ def apply(self, X):
156188
self.assertEqual(cc0, ccs[2])
157189
self.assertEqual(cc0, ccs[3])
158190

159-
def test_labels(self):
191+
def test_mock_labels(self):
160192
data = [0, 1, 2, 3]
161193

162194
class MockCover:

tests/test_unit_params.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import unittest
2+
3+
from sklearn.cluster import DBSCAN
4+
5+
from tdamapper._common import clone
6+
from tdamapper.core import MapperAlgorithm
7+
from tdamapper.cover import (
8+
BallCover,
9+
KNNCover,
10+
CubicalCover
11+
)
12+
from tdamapper.clustering import MapperClustering
13+
14+
15+
class TestParams(unittest.TestCase):
    """
    Tests for parameter access (``get_params`` / ``set_params``), cloning
    and ``repr`` of covers and mapper estimators.
    """

    def __check_clone(self, obj):
        # A clone must print exactly like the object it was cloned from.
        self.assertEqual(repr(obj), repr(clone(obj)))

    def __check_repr(self, obj):
        # Evaluating the repr must rebuild an object with the same repr.
        # eval only ever sees strings produced by the objects built here.
        text = repr(obj)
        rebuilt = eval(text)
        self.assertEqual(text, repr(rebuilt))

    def __check_clone_and_repr(self, obj):
        self.__check_clone(obj)
        self.__check_repr(obj)

    def __check_cover_params(self, est, n_shallow, n_deep):
        # Shared body of the get_params/set_params tests: est must wrap a
        # cover built with n_intervals=3 and overlap_frac=0.3.
        self.assertEqual(n_shallow, len(est.get_params(deep=False)))
        deep = est.get_params()
        self.assertEqual(n_deep, len(deep))
        self.assertEqual(3, deep['cover__n_intervals'])
        self.assertEqual(0.3, deep['cover__overlap_frac'])
        est.set_params(cover__n_intervals=2, cover__overlap_frac=0.2)
        deep = est.get_params()
        self.assertEqual(n_deep, len(deep))
        self.assertEqual(2, deep['cover__n_intervals'])
        self.assertEqual(0.2, deep['cover__overlap_frac'])

    def test_params_mapper_algorithm(self):
        cover = CubicalCover(n_intervals=3, overlap_frac=0.3)
        self.__check_cover_params(MapperAlgorithm(cover=cover), 5, 12)

    def test_params_mapper_clustering(self):
        cover = CubicalCover(n_intervals=3, overlap_frac=0.3)
        self.__check_cover_params(MapperClustering(cover=cover), 3, 10)

    def test_clone_and_repr_ball_cover(self):
        self.__check_clone_and_repr(BallCover())
        custom = BallCover(
            radius=2.0,
            metric='test',
            metric_params={'f': 4},
            kind='kind_test',
            leaf_capacity=3.0,
            leaf_radius=-2.0,
            pivoting=7,
        )
        self.__check_clone_and_repr(custom)

    def test_clone_and_repr_cubical_cover(self):
        self.__check_clone_and_repr(CubicalCover())
        custom = CubicalCover(
            n_intervals=4,
            overlap_frac=5,
            algorithm='algo_test',
            kind='simple',
            leaf_radius=5,
            leaf_capacity=6,
            pivoting='no',
        )
        self.__check_clone_and_repr(custom)

    def test_clone_repr_mapper_algorithm(self):
        self.__check_clone_and_repr(MapperAlgorithm())
        custom = MapperAlgorithm(
            cover=CubicalCover(
                n_intervals=3,
                overlap_frac=0.3,
            ),
            clustering=DBSCAN(
                eps='none',
                min_samples=5.4,
            ),
            failsafe=4,
            n_jobs='foo',
            verbose=4,
        )
        self.__check_clone_and_repr(custom)

    def test_clone_repr_mapper_clustering(self):
        self.__check_clone_and_repr(MapperClustering())
        custom = MapperClustering(
            cover=CubicalCover(
                n_intervals=3,
                overlap_frac=0.3,
            ),
            clustering=DBSCAN(
                eps='none',
                min_samples=5.4,
            ),
            n_jobs='foo',
        )
        self.__check_clone_and_repr(custom)

0 commit comments

Comments
 (0)