Skip to content

Commit fc09bc4

Browse files
authored
Merge pull request #58 from lucasimi/develop
Develop
2 parents 905901c + 863014c commit fc09bc4

File tree

3 files changed

+78
-55
lines changed

3 files changed

+78
-55
lines changed

src/tdamapper/core.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@ def mapper_labels(X, y, cover, clustering):
2121
:type cover: A class from `tdamapper.cover`.
2222
:param clustering: A clustering algorithm.
2323
:type clustering: A class from `tdamapper.clustering` or a class from `sklearn.cluster`.
24-
:return: A list where each item is a sorted list of ints with no duplicate.
25-
The list at position `i` contains the cluster labels to which the point at position `i` in `X`
26-
belongs to. If `i < j`, the labels at position `i` are strictly less then those at position `j`.
27-
:rtype: `list[list[int]]`.
24+
:return: A list where each item is a sorted list of ints with no duplicate. The list at
25+
position `i` contains the cluster labels to which the point at position `i` in `X`
26+
belongs to. If `i < j`, the labels at position `i` are strictly less then those at
27+
position `j`.
28+
:rtype: `list[list[int]]`
2829
'''
2930
itm_lbls = [[] for _ in X]
3031
max_lbl = 0
@@ -43,9 +44,10 @@ def mapper_labels(X, y, cover, clustering):
4344
def mapper_connected_components(X, y, cover, clustering):
4445
'''
4546
Computes the connected components of the Mapper graph.
46-
The algorithm computes the connected components using a union-find data structure.
47-
This approach should be faster than computing the Mapper graph by first calling
48-
`tdamapper.core.mapper_graph` and then calling `networkx.connected_components` on it.
47+
48+
The algorithm computes the connected components using a union-find data structure. This
49+
approach should be faster than computing the Mapper graph by first calling `mapper_graph`
50+
and then calling `networkx.connected_components` on it.
4951
5052
:param X: A dataset.
5153
:type X: `numpy.ndarray` or list-like.
@@ -55,9 +57,9 @@ def mapper_connected_components(X, y, cover, clustering):
5557
:type cover: A class from `tdamapper.cover`.
5658
:param clustering: A clustering algorithm.
5759
:type clustering: A class from `tdamapper.clustering` or a class from `sklearn.cluster`.
58-
:return: A list of labels, where the value at position `i` identifies
59-
the connected component of the point `X[i]`.
60-
:rtype: `list[int]`.
60+
:return: A list of labels, where the value at position `i` identifies the connected
61+
component of the point `X[i]`.
62+
:rtype: `list[int]`
6163
'''
6264
itm_lbls = mapper_labels(X, y, cover, clustering)
6365
label_values = set()
@@ -67,12 +69,11 @@ def mapper_connected_components(X, y, cover, clustering):
6769
labels = [-1 for _ in X]
6870
for lbls in itm_lbls:
6971
len_lbls = len(lbls)
72+
root = -1
7073
# noise points
71-
if len_lbls == 0:
72-
root = -1
73-
elif len_lbls == 1:
74+
if len_lbls == 1:
7475
root = uf.find(lbls[0])
75-
else:
76+
elif len_lbls > 1:
7677
for first, second in zip(lbls, lbls[1:]):
7778
root = uf.union(first, second)
7879
labels.append(root)
@@ -92,7 +93,7 @@ def mapper_graph(X, y, cover, clustering):
9293
:param clustering: A clustering algorithm.
9394
:type clustering: A class from `tdamapper.clustering` or a class from `sklearn.cluster`.
9495
:return: The Mapper graph.
95-
:rtype: `networkx.Graph`.
96+
:rtype: `networkx.Graph`
9697
'''
9798
itm_lbls = mapper_labels(X, y, cover, clustering)
9899
graph = nx.Graph()
@@ -125,7 +126,7 @@ def aggregate_graph(y, graph, agg):
125126
:param agg: An aggregation function.
126127
:type agg: Callable.
127128
:return: A dict of values, where each node is mapped to its aggregation.
128-
:rtype: `dict`.
129+
:rtype: `dict`
129130
'''
130131
agg_values = {}
131132
nodes = graph.nodes()
@@ -143,7 +144,8 @@ class MapperAlgorithm:
143144
:param cover: A cover algorithm.
144145
:type cover: A class from `tdamapper.cover`.
145146
:param clustering: A clustering algorithm.
146-
:type clustering: A class from `tdamapper.clustering` or a class from `sklearn.cluster`.
147+
:type clustering: A class from `tdamapper.clustering`
148+
or a class from `sklearn.cluster`
147149
'''
148150

149151
def __init__(self, cover, clustering):
@@ -153,7 +155,7 @@ def __init__(self, cover, clustering):
153155

154156
def fit(self, X, y=None):
155157
'''
156-
Computes the Mapper Graph
158+
Computes the Mapper Graph.
157159
158160
:param X: A dataset.
159161
:type X: `numpy.ndarray` or list-like.

src/tdamapper/cover.py

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
'''A module containing the logic for building open covers for the Mapper algorithm.'''
22
import numpy as np
33

4-
from tdamapper.utils.vptree_flat import VPTree
4+
from tdamapper.utils.vptree_flat import VPTree as FVPT
5+
from tdamapper.utils.vptree import VPTree as VPT
56

67

78
def proximity_net(X, proximity):
89
'''
910
Compute proximity-net for a given proximity function.
10-
Returns a generator where each item is a subset of ids
11-
of points from `X`.
11+
12+
Returns a generator where each item is a subset of ids of points from `X`.
1213
1314
:param X: A dataset.
1415
:type X: `numpy.ndarray` or list-like.
1516
:param proximity: A proximity function.
16-
:type proximity: `tdamapper.cover.Proximity`.
17+
:type proximity: `tdamapper.cover.Proximity`
1718
'''
1819
covered_ids = set()
1920
proximity.fit(X)
@@ -27,9 +28,9 @@ def proximity_net(X, proximity):
2728

2829
class Proximity:
2930
'''
30-
This class serves as a blueprint for proximity functions used inside
31-
`tdamapper.cover.proximity_net`. Subclasses are expected to override
32-
the methods `fit` and `search`.
31+
This class serves as a blueprint for proximity functions used inside `proximity_net`.
32+
33+
Subclasses are expected to override the methods `fit` and `search`.
3334
'''
3435

3536
def fit(self, X):
@@ -55,25 +56,34 @@ def apply(self, X):
5556
class BallCover(ProximityNetCover):
5657
'''
5758
Creates an open cover made of overlapping open balls of fixed radius.
58-
This class implements the Ball Proximity function: after calling fit on X,
59-
the `BallCover.search` method returns all the points within a ball centered in the target point.
59+
60+
This class implements the Ball Proximity function: after calling `fit`, the `search` method
61+
returns all the points within a ball centered in the target point.
6062
6163
:param radius: The radius of open balls
6264
:type radius: float.
6365
:param metric: The metric used to define open balls.
6466
:type metric: Callable.
67+
:param flat: Set to True to use flat vptrees.
68+
:type flat: `bool`
6569
'''
6670

67-
def __init__(self, radius, metric):
71+
def __init__(self, radius, metric, flat=True):
6872
self.__metric = lambda x, y: metric(x[1], y[1])
6973
self.__radius = radius
7074
self.__data = None
7175
self.__vptree = None
76+
self.__flat = flat
77+
78+
def __flat_vpt(self):
79+
return FVPT(self.__metric, self.__data, leaf_radius=self.__radius)
80+
81+
def __vpt(self):
82+
return VPT(self.__metric, self.__data, leaf_radius=self.__radius)
7283

7384
def fit(self, X):
7485
self.__data = list(enumerate(X))
75-
self.__vptree = VPTree(
76-
self.__metric, self.__data, leaf_radius=self.__radius)
86+
self.__vptree = self.__flat_vpt() if self.__flat else self.__vpt()
7787
return self
7888

7989
def search(self, x):
@@ -86,24 +96,34 @@ def search(self, x):
8696
class KNNCover(ProximityNetCover):
8797
'''
8898
Creates an open cover where each open set containes a fixed number of neighbors, using KNN.
89-
This class implements the KNN Proximity function: after calling fit on X,
90-
the `KNNCover.search` method returns the k nearest points to the target point.
99+
100+
This class implements the KNN Proximity function: after calling `fit`, the `search` method
101+
returns the k nearest points to the target point.
91102
92103
:param neighbors: The number of neighbors.
93104
:type neighbors: int.
94105
:param metric: The metric used to search neighbors.
95106
:type metric: function.
107+
:param flat: Set to True to use flat vptrees.
108+
:type flat: `bool`
96109
'''
97110

98-
def __init__(self, neighbors, metric):
111+
def __init__(self, neighbors, metric, flat=True):
99112
self.__neighbors = neighbors
100113
self.__metric = lambda x, y: metric(x[1], y[1])
101114
self.__data = None
102115
self.__vptree = None
116+
self.__flat = flat
117+
118+
def __flat_vpt(self):
119+
return FVPT(self.__metric, self.__data, leaf_capacity=self.__neighbors)
120+
121+
def __vpt(self):
122+
return VPT(self.__metric, self.__data, leaf_capacity=self.__neighbors)
103123

104124
def fit(self, X):
105125
self.__data = list(enumerate(X))
106-
self.__vptree = VPTree(self.__metric, self.__data, leaf_capacity=self.__neighbors)
126+
self.__vptree = self.__flat_vpt() if self.__flat else self.__vpt()
107127
return self
108128

109129
def search(self, x):
@@ -116,26 +136,27 @@ def search(self, x):
116136
class CubicalCover(ProximityNetCover):
117137
'''
118138
Creates an open cover of hypercubes of evenly-sized sides and overlap.
119-
This class implements the Cubical Proximity function: after calling fit on X,
120-
the `CubicalCover.search` method returns the hypercube whose center is nearest to
121-
the target point. Each hypercube is the product of 1-dimensional intervals
122-
with the same lenght and overlap.
139+
140+
This class implements the Cubical Proximity function: after calling `fit`, the `search` method
141+
returns the hypercube whose center is nearest to the target point. Each hypercube is the
142+
product of 1-dimensional intervals with the same lenght and overlap.
123143
124144
:param n_intervals: The number of intervals on each dimension.
125145
:type n_intervals: int.
126146
:param overlap_frac: The overlap fraction.
127-
:type overlap_frac: float in (0.0, 1.0).
147+
:type overlap_frac: `float` in (0.0, 1.0).
148+
:param flat: Set to True to use flat vptrees.
149+
:type flat: `bool`
128150
'''
129151

130-
def __init__(self, n_intervals, overlap_frac):
152+
def __init__(self, n_intervals, overlap_frac, flat=True):
131153
self.__n_intervals = n_intervals
132-
self.__overlap_frac = overlap_frac
133154
self.__radius = 1.0 / (2.0 - 2.0 * overlap_frac)
134155
self.__minimum = None
135156
self.__maximum = None
136157
self.__delta = None
137158
metric = self._pullback(self._gamma_n, self._l_infty)
138-
self.__ball_proximity = BallCover(self.__radius, metric)
159+
self.__ball_proximity = BallCover(self.__radius, metric, flat=flat)
139160

140161
def _l_infty(self, x, y):
141162
return np.max(np.abs(x - y)) # in alternative: np.linalg.norm(x - y, ord=np.inf)
@@ -163,16 +184,15 @@ def _set_bounds(self, data):
163184
for w in data:
164185
minimum = np.minimum(minimum, np.array(w))
165186
maximum = np.maximum(maximum, np.array(w))
166-
self.__minimum = np.nan_to_num(minimum, nan=-eps)
167-
self.__maximum = np.nan_to_num(maximum, nan=eps)
187+
self.__minimum = np.nan_to_num(minimum, nan=-float(eps))
188+
self.__maximum = np.nan_to_num(maximum, nan=float(eps))
168189
delta = self.__maximum - self.__minimum
169-
eps = np.finfo(np.float64).eps
170190
self.__delta = np.maximum(eps, delta)
171191

172192
def fit(self, X):
173193
self._set_bounds(X)
174194
self.__ball_proximity.fit(X)
175-
return
195+
return self
176196

177197
def search(self, x):
178198
return self.__ball_proximity.search(self._phi(x))

src/tdamapper/plot.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
'''A module for plotting the Mapper graph.'''
12
import math
23

34
import numpy as np
@@ -28,15 +29,15 @@ class MapperPlot:
2829
:param X: A dataset.
2930
:type X: `numpy.ndarray` or list-like.
3031
:param graph: The Mapper graph.
31-
:type graph: `networkx.Graph`.
32+
:type graph: `networkx.Graph`
3233
:param colors: A dataset of values to plot as nodes color.
3334
:type colors: `numpy.ndarray` or list-like.
3435
:param agg: Aggregation function that computes nodes color.
3536
:type agg: Callable on the values of colors.
3637
:param cmap: A colormap, to convert values into colors.
37-
:type cmap: `str`.
38-
:param **kwargs: Additional arguments to networkx.spring_layout.
39-
:type: `dict`.
38+
:type cmap: `str`
39+
:param kwargs: Additional arguments to networkx.spring_layout.
40+
:type: `dict`
4041
'''
4142

4243
def __init__(
@@ -77,12 +78,12 @@ def plot(self, *args, style='interactive', **kwargs):
7778
'''
7879
Turns the plot object into a displayable figure.
7980
80-
:param *args: Arguments to supply.
81-
:type *args: `list`.
81+
:param args: Arguments to supply.
82+
:type args: `list`
8283
:param style: The type of plot, can either be 'interactive' or 'static'.
83-
:type style: `str`.
84-
:param **kwargs: Additional arguments to supply.
85-
:type: `dict`.
84+
:type style: `str`
85+
:param kwargs: Additional arguments to supply.
86+
:type kwargs: `dict`
8687
'''
8788
if not self.__pos:
8889
return

0 commit comments

Comments
 (0)