11'''A module containing the logic for building open covers for the Mapper algorithm.'''
22import numpy as np
33
4- from tdamapper .utils .vptree_flat import VPTree
4+ from tdamapper .utils .vptree_flat import VPTree as FVPT
5+ from tdamapper .utils .vptree import VPTree as VPT
56
67
78def proximity_net (X , proximity ):
89 '''
910 Compute proximity-net for a given proximity function.
10- Returns a generator where each item is a subset of ids
11- of points from `X`.
11+
12+ Returns a generator where each item is a subset of ids of points from `X`.
1213
1314 :param X: A dataset.
1415 :type X: `numpy.ndarray` or list-like.
1516 :param proximity: A proximity function.
16- :type proximity: `tdamapper.cover.Proximity`.
17+ :type proximity: `tdamapper.cover.Proximity`
1718 '''
1819 covered_ids = set ()
1920 proximity .fit (X )
@@ -27,9 +28,9 @@ def proximity_net(X, proximity):
2728
2829class Proximity :
2930 '''
30- This class serves as a blueprint for proximity functions used inside
31- `tdamapper.cover.proximity_net`. Subclasses are expected to override
32- the methods `fit` and `search`.
31+ This class serves as a blueprint for proximity functions used inside `proximity_net`.
32+
33+ Subclasses are expected to override the methods `fit` and `search`.
3334 '''
3435
3536 def fit (self , X ):
@@ -55,25 +56,34 @@ def apply(self, X):
5556class BallCover (ProximityNetCover ):
5657 '''
5758 Creates an open cover made of overlapping open balls of fixed radius.
58- This class implements the Ball Proximity function: after calling fit on X,
59- the `BallCover.search` method returns all the points within a ball centered in the target point.
59+
60+ This class implements the Ball Proximity function: after calling `fit`, the `search` method
61+ returns all the points within a ball centered in the target point.
6062
6163 :param radius: The radius of open balls
6264 :type radius: float.
6365 :param metric: The metric used to define open balls.
6466 :type metric: Callable.
67+ :param flat: Set to True to use flat vptrees.
68+ :type flat: `bool`
6569 '''
6670
67- def __init__ (self , radius , metric ):
71+ def __init__ (self , radius , metric , flat = True ):
6872 self .__metric = lambda x , y : metric (x [1 ], y [1 ])
6973 self .__radius = radius
7074 self .__data = None
7175 self .__vptree = None
76+ self .__flat = flat
77+
78+ def __flat_vpt (self ):
79+ return FVPT (self .__metric , self .__data , leaf_radius = self .__radius )
80+
81+ def __vpt (self ):
82+ return VPT (self .__metric , self .__data , leaf_radius = self .__radius )
7283
7384 def fit (self , X ):
7485 self .__data = list (enumerate (X ))
75- self .__vptree = VPTree (
76- self .__metric , self .__data , leaf_radius = self .__radius )
86+ self .__vptree = self .__flat_vpt () if self .__flat else self .__vpt ()
7787 return self
7888
7989 def search (self , x ):
@@ -86,24 +96,34 @@ def search(self, x):
8696class KNNCover (ProximityNetCover ):
8797 '''
8898 Creates an open cover where each open set contains a fixed number of neighbors, using KNN.
89- This class implements the KNN Proximity function: after calling fit on X,
90- the `KNNCover.search` method returns the k nearest points to the target point.
99+
100+ This class implements the KNN Proximity function: after calling `fit`, the `search` method
101+ returns the k nearest points to the target point.
91102
92103 :param neighbors: The number of neighbors.
93104 :type neighbors: int.
94105 :param metric: The metric used to search neighbors.
95106 :type metric: function.
107+ :param flat: Set to True to use flat vptrees.
108+ :type flat: `bool`
96109 '''
97110
98- def __init__ (self , neighbors , metric ):
111+ def __init__ (self , neighbors , metric , flat = True ):
99112 self .__neighbors = neighbors
100113 self .__metric = lambda x , y : metric (x [1 ], y [1 ])
101114 self .__data = None
102115 self .__vptree = None
116+ self .__flat = flat
117+
118+ def __flat_vpt (self ):
119+ return FVPT (self .__metric , self .__data , leaf_capacity = self .__neighbors )
120+
121+ def __vpt (self ):
122+ return VPT (self .__metric , self .__data , leaf_capacity = self .__neighbors )
103123
104124 def fit (self , X ):
105125 self .__data = list (enumerate (X ))
106- self .__vptree = VPTree ( self .__metric , self .__data , leaf_capacity = self .__neighbors )
126+ self .__vptree = self .__flat_vpt () if self .__flat else self .__vpt ( )
107127 return self
108128
109129 def search (self , x ):
@@ -116,26 +136,27 @@ def search(self, x):
116136class CubicalCover (ProximityNetCover ):
117137 '''
118138 Creates an open cover of hypercubes of evenly-sized sides and overlap.
119- This class implements the Cubical Proximity function: after calling fit on X,
120- the `CubicalCover.search` method returns the hypercube whose center is nearest to
121- the target point. Each hypercube is the product of 1-dimensional intervals
122- with the same lenght and overlap.
139+
140+ This class implements the Cubical Proximity function: after calling `fit`, the `search` method
141+ returns the hypercube whose center is nearest to the target point. Each hypercube is the
142+ product of 1-dimensional intervals with the same length and overlap.
123143
124144 :param n_intervals: The number of intervals on each dimension.
125145 :type n_intervals: int.
126146 :param overlap_frac: The overlap fraction.
127- :type overlap_frac: float in (0.0, 1.0).
147+ :type overlap_frac: `float` in (0.0, 1.0).
148+ :param flat: Set to True to use flat vptrees.
149+ :type flat: `bool`
128150 '''
129151
130- def __init__ (self , n_intervals , overlap_frac ):
152+ def __init__ (self , n_intervals , overlap_frac , flat = True ):
131153 self .__n_intervals = n_intervals
132- self .__overlap_frac = overlap_frac
133154 self .__radius = 1.0 / (2.0 - 2.0 * overlap_frac )
134155 self .__minimum = None
135156 self .__maximum = None
136157 self .__delta = None
137158 metric = self ._pullback (self ._gamma_n , self ._l_infty )
138- self .__ball_proximity = BallCover (self .__radius , metric )
159+ self .__ball_proximity = BallCover (self .__radius , metric , flat = flat )
139160
140161 def _l_infty (self , x , y ):
141162 return np .max (np .abs (x - y )) # in alternative: np.linalg.norm(x - y, ord=np.inf)
@@ -163,16 +184,15 @@ def _set_bounds(self, data):
163184 for w in data :
164185 minimum = np .minimum (minimum , np .array (w ))
165186 maximum = np .maximum (maximum , np .array (w ))
166- self .__minimum = np .nan_to_num (minimum , nan = - eps )
167- self .__maximum = np .nan_to_num (maximum , nan = eps )
187+ self .__minimum = np .nan_to_num (minimum , nan = - float ( eps ) )
188+ self .__maximum = np .nan_to_num (maximum , nan = float ( eps ) )
168189 delta = self .__maximum - self .__minimum
169- eps = np .finfo (np .float64 ).eps
170190 self .__delta = np .maximum (eps , delta )
171191
172192 def fit (self , X ):
173193 self ._set_bounds (X )
174194 self .__ball_proximity .fit (X )
175- return
195+ return self
176196
177197 def search (self , x ):
178198 return self .__ball_proximity .search (self ._phi (x ))
0 commit comments