Skip to content

Commit 863014c

Browse files
authored
Merge pull request #57 from lucasimi/feature/alignement
Improved docstrings
2 parents 7fc17ff + 4a274ba commit 863014c

File tree

3 files changed

+47
-41
lines changed

3 files changed

+47
-41
lines changed

src/tdamapper/core.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@ def mapper_labels(X, y, cover, clustering):
2121
:type cover: A class from `tdamapper.cover`.
2222
:param clustering: A clustering algorithm.
2323
:type clustering: A class from `tdamapper.clustering` or a class from `sklearn.cluster`.
24-
:return: A list where each item is a sorted list of ints with no duplicate.
25-
The list at position `i` contains the cluster labels to which the point at position `i` in `X`
26-
belongs to. If `i < j`, the labels at position `i` are strictly less then those at position `j`.
27-
:rtype: `list[list[int]]`.
24+
:return: A list where each item is a sorted list of ints with no duplicate. The list at
25+
position `i` contains the cluster labels to which the point at position `i` in `X`
26+
belongs to. If `i < j`, the labels at position `i` are strictly less then those at
27+
position `j`.
28+
:rtype: `list[list[int]]`
2829
'''
2930
itm_lbls = [[] for _ in X]
3031
max_lbl = 0
@@ -43,9 +44,10 @@ def mapper_labels(X, y, cover, clustering):
4344
def mapper_connected_components(X, y, cover, clustering):
4445
'''
4546
Computes the connected components of the Mapper graph.
46-
The algorithm computes the connected components using a union-find data structure.
47-
This approach should be faster than computing the Mapper graph by first calling
48-
`tdamapper.core.mapper_graph` and then calling `networkx.connected_components` on it.
47+
48+
The algorithm computes the connected components using a union-find data structure. This
49+
approach should be faster than computing the Mapper graph by first calling `mapper_graph`
50+
and then calling `networkx.connected_components` on it.
4951
5052
:param X: A dataset.
5153
:type X: `numpy.ndarray` or list-like.
@@ -55,9 +57,9 @@ def mapper_connected_components(X, y, cover, clustering):
5557
:type cover: A class from `tdamapper.cover`.
5658
:param clustering: A clustering algorithm.
5759
:type clustering: A class from `tdamapper.clustering` or a class from `sklearn.cluster`.
58-
:return: A list of labels, where the value at position `i` identifies
59-
the connected component of the point `X[i]`.
60-
:rtype: `list[int]`.
60+
:return: A list of labels, where the value at position `i` identifies the connected
61+
component of the point `X[i]`.
62+
:rtype: `list[int]`
6163
'''
6264
itm_lbls = mapper_labels(X, y, cover, clustering)
6365
label_values = set()
@@ -91,7 +93,7 @@ def mapper_graph(X, y, cover, clustering):
9193
:param clustering: A clustering algorithm.
9294
:type clustering: A class from `tdamapper.clustering` or a class from `sklearn.cluster`.
9395
:return: The Mapper graph.
94-
:rtype: `networkx.Graph`.
96+
:rtype: `networkx.Graph`
9597
'''
9698
itm_lbls = mapper_labels(X, y, cover, clustering)
9799
graph = nx.Graph()
@@ -124,7 +126,7 @@ def aggregate_graph(y, graph, agg):
124126
:param agg: An aggregation function.
125127
:type agg: Callable.
126128
:return: A dict of values, where each node is mapped to its aggregation.
127-
:rtype: `dict`.
129+
:rtype: `dict`
128130
'''
129131
agg_values = {}
130132
nodes = graph.nodes()
@@ -142,7 +144,8 @@ class MapperAlgorithm:
142144
:param cover: A cover algorithm.
143145
:type cover: A class from `tdamapper.cover`.
144146
:param clustering: A clustering algorithm.
145-
:type clustering: A class from `tdamapper.clustering` or a class from `sklearn.cluster`.
147+
:type clustering: A class from `tdamapper.clustering`
148+
or a class from `sklearn.cluster`
146149
'''
147150

148151
def __init__(self, cover, clustering):
@@ -152,7 +155,7 @@ def __init__(self, cover, clustering):
152155

153156
def fit(self, X, y=None):
154157
'''
155-
Computes the Mapper Graph
158+
Computes the Mapper Graph.
156159
157160
:param X: A dataset.
158161
:type X: `numpy.ndarray` or list-like.

src/tdamapper/cover.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
def proximity_net(X, proximity):
99
'''
1010
Compute proximity-net for a given proximity function.
11-
Returns a generator where each item is a subset of ids
12-
of points from `X`.
11+
12+
Returns a generator where each item is a subset of ids of points from `X`.
1313
1414
:param X: A dataset.
1515
:type X: `numpy.ndarray` or list-like.
1616
:param proximity: A proximity function.
17-
:type proximity: `tdamapper.cover.Proximity`.
17+
:type proximity: `tdamapper.cover.Proximity`
1818
'''
1919
covered_ids = set()
2020
proximity.fit(X)
@@ -28,9 +28,9 @@ def proximity_net(X, proximity):
2828

2929
class Proximity:
3030
'''
31-
This class serves as a blueprint for proximity functions used inside
32-
`tdamapper.cover.proximity_net`. Subclasses are expected to override
33-
the methods `fit` and `search`.
31+
This class serves as a blueprint for proximity functions used inside `proximity_net`.
32+
33+
Subclasses are expected to override the methods `fit` and `search`.
3434
'''
3535

3636
def fit(self, X):
@@ -56,15 +56,16 @@ def apply(self, X):
5656
class BallCover(ProximityNetCover):
5757
'''
5858
Creates an open cover made of overlapping open balls of fixed radius.
59-
This class implements the Ball Proximity function: after calling fit on X,
60-
the `BallCover.search` method returns all the points within a ball centered in the target point.
59+
60+
This class implements the Ball Proximity function: after calling `fit`, the `search` method
61+
returns all the points within a ball centered in the target point.
6162
6263
:param radius: The radius of open balls
6364
:type radius: float.
6465
:param metric: The metric used to define open balls.
6566
:type metric: Callable.
6667
:param flat: Set to True to use flat vptrees.
67-
:type flat: bool.
68+
:type flat: `bool`
6869
'''
6970

7071
def __init__(self, radius, metric, flat=True):
@@ -95,15 +96,16 @@ def search(self, x):
9596
class KNNCover(ProximityNetCover):
9697
'''
9798
Creates an open cover where each open set containes a fixed number of neighbors, using KNN.
98-
This class implements the KNN Proximity function: after calling fit on X,
99-
the `KNNCover.search` method returns the k nearest points to the target point.
99+
100+
This class implements the KNN Proximity function: after calling `fit`, the `search` method
101+
returns the k nearest points to the target point.
100102
101103
:param neighbors: The number of neighbors.
102104
:type neighbors: int.
103105
:param metric: The metric used to search neighbors.
104106
:type metric: function.
105107
:param flat: Set to True to use flat vptrees.
106-
:type flat: bool.
108+
:type flat: `bool`
107109
'''
108110

109111
def __init__(self, neighbors, metric, flat=True):
@@ -134,17 +136,17 @@ def search(self, x):
134136
class CubicalCover(ProximityNetCover):
135137
'''
136138
Creates an open cover of hypercubes of evenly-sized sides and overlap.
137-
This class implements the Cubical Proximity function: after calling fit on X,
138-
the `CubicalCover.search` method returns the hypercube whose center is nearest to
139-
the target point. Each hypercube is the product of 1-dimensional intervals
140-
with the same lenght and overlap.
139+
140+
This class implements the Cubical Proximity function: after calling `fit`, the `search` method
141+
returns the hypercube whose center is nearest to the target point. Each hypercube is the
142+
product of 1-dimensional intervals with the same lenght and overlap.
141143
142144
:param n_intervals: The number of intervals on each dimension.
143145
:type n_intervals: int.
144146
:param overlap_frac: The overlap fraction.
145-
:type overlap_frac: float in (0.0, 1.0).
147+
:type overlap_frac: `float` in (0.0, 1.0).
146148
:param flat: Set to True to use flat vptrees.
147-
:type flat: bool.
149+
:type flat: `bool`
148150
'''
149151

150152
def __init__(self, n_intervals, overlap_frac, flat=True):

src/tdamapper/plot.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
'''A module for plotting the Mapper graph.'''
12
import math
23

34
import numpy as np
@@ -28,15 +29,15 @@ class MapperPlot:
2829
:param X: A dataset.
2930
:type X: `numpy.ndarray` or list-like.
3031
:param graph: The Mapper graph.
31-
:type graph: `networkx.Graph`.
32+
:type graph: `networkx.Graph`
3233
:param colors: A dataset of values to plot as nodes color.
3334
:type colors: `numpy.ndarray` or list-like.
3435
:param agg: Aggregation function that computes nodes color.
3536
:type agg: Callable on the values of colors.
3637
:param cmap: A colormap, to convert values into colors.
37-
:type cmap: `str`.
38-
:param **kwargs: Additional arguments to networkx.spring_layout.
39-
:type: `dict`.
38+
:type cmap: `str`
39+
:param kwargs: Additional arguments to networkx.spring_layout.
40+
:type: `dict`
4041
'''
4142

4243
def __init__(
@@ -77,12 +78,12 @@ def plot(self, *args, style='interactive', **kwargs):
7778
'''
7879
Turns the plot object into a displayable figure.
7980
80-
:param *args: Arguments to supply.
81-
:type *args: `list`.
81+
:param args: Arguments to supply.
82+
:type args: `list`
8283
:param style: The type of plot, can either be 'interactive' or 'static'.
83-
:type style: `str`.
84-
:param **kwargs: Additional arguments to supply.
85-
:type: `dict`.
84+
:type style: `str`
85+
:param kwargs: Additional arguments to supply.
86+
:type kwargs: `dict`
8687
'''
8788
if not self.__pos:
8889
return

0 commit comments

Comments
 (0)