Skip to content

Commit f5b51fe

Browse files
authored
Merge pull request #208 from lucasimi/develop
Develop
2 parents 3ac3a9b + 854b64f commit f5b51fe

File tree

3 files changed

+82
-87
lines changed

3 files changed

+82
-87
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ from tdamapper.plot import MapperPlot
9292

9393
# Generate toy dataset
9494
X, labels = make_circles(n_samples=5000, noise=0.05, factor=0.3, random_state=42)
95-
plt.scatter(X[:,0], X[:,1], c=labels, cmap='jet', s=0.25)
95+
plt.figure(figsize=(5, 5))
96+
plt.scatter(X[:,0], X[:,1], c=labels, s=0.25, cmap="jet")
97+
plt.axis("off")
9698
plt.show()
9799

98100
# Apply PCA as lens
@@ -105,7 +107,7 @@ graph = MapperAlgorithm(cover, clust).fit_transform(X, y)
105107

106108
# Visualize the Mapper graph
107109
fig = MapperPlot(graph, dim=2, seed=42, iterations=60).plot_plotly(colors=labels)
108-
fig.show(config={'scrollZoom': True})
110+
fig.show(config={"scrollZoom": True})
109111
```
110112

111113
| Original Dataset | Mapper Graph |

docs/source/notebooks/circles_online.py

Lines changed: 35 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,74 +17,63 @@
1717

1818
# %%
1919
import numpy as np
20-
2120
from matplotlib import pyplot as plt
22-
21+
from sklearn.cluster import DBSCAN
2322
from sklearn.datasets import make_circles
2423
from sklearn.decomposition import PCA
25-
from sklearn.cluster import DBSCAN
2624

27-
from tdamapper.learn import MapperAlgorithm
2825
from tdamapper.cover import CubicalCover
26+
from tdamapper.learn import MapperAlgorithm
2927
from tdamapper.plot import MapperPlot
3028

31-
X, y = make_circles( # load a labelled dataset
32-
n_samples=5000,
33-
noise=0.05,
34-
factor=0.3,
35-
random_state=42
36-
)
37-
lens = PCA(2, random_state=42).fit_transform(X)
29+
width, height, dpi = 500, 500, 100
30+
31+
# Generate toy dataset
32+
X, labels = make_circles(n_samples=5000, noise=0.05, factor=0.3, random_state=42)
33+
34+
fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
35+
plt.scatter(X[:, 0], X[:, 1], c=labels, s=0.25, cmap="jet")
36+
plt.axis("off")
37+
plt.show()
38+
# fig.savefig("circles_dataset.png", dpi=dpi)
39+
40+
# Apply PCA as lens
41+
y = PCA(2, random_state=42).fit_transform(X)
3842

39-
plt.scatter(lens[:, 0], lens[:, 1], c=y, cmap='jet')
4043

4144
# %% [markdown]
4245
# ### Build Mapper graph
4346

4447
# %%
45-
mapper_algo = MapperAlgorithm(
46-
cover=CubicalCover(
47-
n_intervals=10,
48-
overlap_frac=0.3
49-
),
50-
clustering=DBSCAN()
51-
)
52-
53-
mapper_graph = mapper_algo.fit_transform(X, lens)
48+
cover = CubicalCover(n_intervals=10, overlap_frac=0.3)
49+
clust = DBSCAN()
50+
mapper = MapperAlgorithm(cover=cover, clustering=clust)
51+
graph = mapper.fit_transform(X, y)
5452

5553
# %% [markdown]
5654
# ### Plot Mapper graph with mean
5755

5856
# %%
59-
mapper_plot = MapperPlot(
60-
mapper_graph,
61-
dim=2,
62-
iterations=60,
63-
seed=42
64-
)
57+
plot = MapperPlot(graph, dim=2, iterations=60, seed=42)
6558

66-
fig = mapper_plot.plot_plotly(
67-
colors=y, # color according to categorical values
68-
cmap='jet', # Jet colormap, for classes
69-
agg=np.nanmean, # aggregate on nodes according to mean
59+
fig = plot.plot_plotly(
60+
colors=labels, # color according to categorical values
61+
cmap="jet", # Jet colormap, for classes
62+
agg=np.nanmean, # aggregate on nodes according to mean
7063
width=600,
71-
height=600
64+
height=600,
7265
)
7366

74-
fig.show(
75-
renderer='notebook_connected',
76-
config={'scrollZoom': True}
77-
)
67+
fig.show(renderer="notebook_connected", config={"scrollZoom": True})
68+
# fig.write_image("circles_mean.png", width=width, height=height)
7869

7970
# %%
80-
mapper_plot.plot_plotly_update(
81-
fig, # update the old figure
82-
colors=y,
83-
cmap='viridis', # viridis colormap, for ranges
84-
agg=np.nanstd # aggregate on nodes according to std
85-
)
86-
87-
fig.show(
88-
renderer='notebook_connected',
89-
config={'scrollZoom': True}
71+
plot.plot_plotly_update(
72+
fig, # update the old figure
73+
colors=labels,
74+
cmap="viridis", # viridis colormap, for ranges
75+
agg=np.nanstd, # aggregate on nodes according to std
9076
)
77+
78+
fig.show(renderer="notebook_connected", config={"scrollZoom": True})
79+
# fig.write_image("circles_std.png", width=width, height=height)

docs/source/notebooks/digits_online.py

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,74 +17,78 @@
1717

1818
# %%
1919
import numpy as np
20-
21-
from sklearn.datasets import load_digits
2220
from sklearn.cluster import AgglomerativeClustering
21+
from sklearn.datasets import load_digits
2322
from sklearn.decomposition import PCA
2423

25-
from tdamapper.learn import MapperAlgorithm
26-
from tdamapper.cover import CubicalCover
2724
from tdamapper.clustering import FailSafeClustering
25+
from tdamapper.cover import CubicalCover
26+
from tdamapper.learn import MapperAlgorithm
2827
from tdamapper.plot import MapperPlot
2928

29+
# We load a labelled dataset
30+
X, labels = load_digits(return_X_y=True)
3031

31-
X, y = load_digits(return_X_y=True) # We load a labelled dataset
32-
lens = PCA(2, random_state=42).fit_transform(X) # We compute the lens values
32+
# Apply PCA as lens
33+
y = PCA(2, random_state=42).fit_transform(X)
3334

3435
# %% [markdown]
3536
# ### Build Mapper graph
3637

3738
# %%
38-
mapper_algo = MapperAlgorithm(
39-
cover=CubicalCover(
40-
n_intervals=10,
41-
overlap_frac=0.65
42-
),
39+
algo = MapperAlgorithm(
40+
cover=CubicalCover(n_intervals=10, overlap_frac=0.5),
4341
clustering=AgglomerativeClustering(10),
44-
verbose=False
42+
verbose=False,
4543
)
4644

47-
mapper_graph = mapper_algo.fit_transform(X, lens)
45+
graph = algo.fit_transform(X, y)
4846

4947
# %% [markdown]
5048
# ### Plot Mapper graph with mean
5149

5250
# %%
53-
mapper_plot = MapperPlot(
54-
mapper_graph,
55-
dim=2,
56-
iterations=400,
57-
seed=42
58-
)
51+
plot = MapperPlot(graph, dim=3, iterations=400, seed=42)
5952

60-
fig = mapper_plot.plot_plotly(
61-
colors=y, # We color according to digit values
62-
cmap='jet', # Jet colormap, used for classes
63-
agg=np.nanmean, # We aggregate on graph nodes according to mean
64-
title='digit (mean)',
53+
fig = plot.plot_plotly(
54+
colors=labels, # We color according to digit values
55+
cmap="jet", # Jet colormap, used for classes
56+
agg=np.nanmean, # We aggregate on graph nodes according to mean
57+
title="digit (mean)",
6558
width=600,
66-
height=600
59+
height=600,
6760
)
6861

69-
fig.show(
70-
renderer='notebook_connected',
71-
config={'scrollZoom': True}
72-
)
62+
fig.show(renderer="notebook_connected", config={"scrollZoom": True})
7363

7464
# %% [markdown]
7565
# ### Plot Mapper graph with standard deviation
7666

7767
# %%
78-
fig = mapper_plot.plot_plotly(
79-
colors=y,
80-
cmap='viridis', # Viridis colormap, used for ranges
81-
agg=np.nanstd, # We aggregate on graph nodes according to std
82-
title='digit (std)',
68+
fig = plot.plot_plotly(
69+
colors=labels,
70+
cmap="viridis", # Viridis colormap, used for ranges
71+
agg=np.nanstd, # We aggregate on graph nodes according to std
72+
title="digit (std)",
8373
width=600,
84-
height=600
74+
height=600,
8575
)
8676

87-
fig.show(
88-
renderer='notebook_connected',
89-
config={'scrollZoom': True}
90-
)
77+
fig.show(renderer="notebook_connected", config={"scrollZoom": True})
78+
79+
# %% [markdown]
80+
# ### Inspect interesting nodes
81+
82+
# %%
83+
from matplotlib import pyplot as plt
84+
85+
# By interacting with the plot we see that node 140 is joining the cluster of
86+
# digit 0 with the cluster of digit 4. Let's see what the digits inside look like!
87+
88+
node_140 = [X[i, :] for i in graph.nodes()[140]["ids"]]
89+
fig, axes = plt.subplots(1, len(node_140))
90+
for dgt, ax in zip(node_140, axes):
91+
ax.imshow(dgt.reshape(8, 8), cmap="gray")
92+
ax.axis("off")
93+
plt.tight_layout()
94+
plt.show()

0 commit comments

Comments
 (0)