Skip to content

Commit 905901c

Browse files
authored
Merge pull request #53 from lucasimi/develop
Develop
2 parents 6a07371 + c312e80 commit 905901c

File tree

1 file changed

+37
-25
lines changed

1 file changed

+37
-25
lines changed

src/tdamapper/cover.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,40 +4,52 @@
44
from tdamapper.utils.vptree_flat import VPTree
55

66

7-
class ProximityNetCover:
7+
def proximity_net(X, proximity):
88
'''
9-
This class serves as a blueprint for proximity-based cover algorithm
10-
and implements proximity-net in the `ProximityNetCover.apply` method.
11-
Subclasses are expected to override the methods `ProximityNetCover.fit`
12-
and `ProximityNetCover.search`.
9+
Compute proximity-net for a given proximity function.
10+
Returns a generator where each item is a subset of ids
11+
of points from `X`.
12+
13+
:param X: A dataset.
14+
:type X: `numpy.ndarray` or list-like.
15+
:param proximity: A proximity function.
16+
:type proximity: `tdamapper.cover.Proximity`.
1317
'''
18+
covered_ids = set()
19+
proximity.fit(X)
20+
for i, xi in enumerate(X):
21+
if i not in covered_ids:
22+
neigh_ids = proximity.search(xi)
23+
covered_ids.update(neigh_ids)
24+
if neigh_ids:
25+
yield neigh_ids
1426

15-
def __init__(self):
16-
pass
1727

18-
def apply(self, X):
19-
'''
20-
Compute proximity-net for a given open cover.
21-
Returns a generator where each item is a subset of ids
22-
of points from `X`.
23-
24-
:param X: A dataset.
25-
:type X: `numpy.ndarray` or list-like.
26-
'''
27-
covered_ids = set()
28-
self.fit(X)
29-
for i, xi in enumerate(X):
30-
if i not in covered_ids:
31-
neigh_ids = self.search(xi)
32-
covered_ids.update(neigh_ids)
33-
if neigh_ids:
34-
yield neigh_ids
28+
class Proximity:
29+
'''
30+
This class serves as a blueprint for proximity functions used inside
31+
`tdamapper.cover.proximity_net`. Subclasses are expected to override
32+
the methods `fit` and `search`.
33+
'''
3534

3635
def fit(self, X):
36+
self.__X = X
3737
return self
3838

3939
def search(self, x):
40-
return []
40+
return [i for i, _ in enumerate(self.__X)]
41+
42+
43+
class ProximityNetCover(Proximity):
44+
'''
45+
This class serves as a blueprint for cover algorithm based on proximity-net.
46+
'''
47+
48+
def __init__(self):
49+
pass
50+
51+
def apply(self, X):
52+
return proximity_net(X, self)
4153

4254

4355
class BallCover(ProximityNetCover):

0 commit comments

Comments
 (0)