Skip to content

Commit dc36cdb

Browse files
authored
Merge pull request #97 from lucasimi/develop
Develop
2 parents 89dd06f + d3a791c commit dc36cdb

File tree

3 files changed

+41
-26
lines changed

3 files changed

+41
-26
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "tda-mapper"
7-
version = "0.5.1"
7+
version = "0.5.2"
88
description = "A simple and efficient Python implementation of Mapper algorithm for Topological Data Analysis"
99
readme = "README.md"
1010
authors = [{ name = "Luca Simi", email = "lucasimi90@gmail.com" }]

src/tdamapper/plot.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ def _node_pos(graph, dim, seed, iterations):
4242
iterations=iterations)
4343

4444

45+
def _node_col(graph, colors, agg, default=0.5):
46+
if colors is not None:
47+
return aggregate_graph(colors, graph, agg)
48+
else:
49+
return [default for _ in graph.nodes()]
50+
51+
4552
def _node_pos_array(graph, dim, node_pos):
4653
return tuple([node_pos[n][i] for n in graph.nodes()] for i in range(dim))
4754

@@ -116,12 +123,8 @@ def __init__(self, graph, dim,
116123
self.__width = width
117124
self.__height = height
118125
self.__cmap = cmap
119-
node_col = aggregate_graph(self.__colors, self.__graph, self.__agg)
126+
node_col = _node_col(self.__graph, self.__colors, self.__agg)
120127
self.__fig = self._figure(node_col)
121-
#self._update_traces_col()
122-
#self._update_layout()
123-
#self._update_traces_cmap()
124-
#self._update_traces_title()
125128

126129
def _nodes_trace(self, node_pos_arr, node_col):
127130
attr_size = nx.get_node_attributes(self.__graph, ATTR_SIZE)
@@ -192,7 +195,6 @@ def _layout(self):
192195
line_col = 'rgba(230, 230, 230, 1.0)'
193196
axis = dict(
194197
showline=True,
195-
#linecolor='rgba(230, 230, 230, 1.0)',
196198
linewidth=1,
197199
mirror=True,
198200
visible=True,
@@ -263,17 +265,13 @@ def _colorbar(self):
263265
elif self.__dim == 2:
264266
return go.scatter.marker.ColorBar(cbar)
265267

266-
def _text(self, colors=None):
268+
def _text(self, colors):
267269
attr_size = nx.get_node_attributes(self.__graph, ATTR_SIZE)
268-
if colors is None:
269-
def _lbl(n):
270-
size = _fmt(attr_size[n], 5)
271-
return f'node: {n}<br>size: {size}'
272-
else:
273-
def _lbl(n):
274-
col = _fmt(colors[n], 3)
275-
size = _fmt(attr_size[n], 5)
276-
return f'color: {col}<br>node: {n}<br>size: {size}'
270+
271+
def _lbl(n):
272+
col = _fmt(colors[n], 3)
273+
size = _fmt(attr_size[n], 5)
274+
return f'color: {col}<br>node: {n}<br>size: {size}'
277275
return [_lbl(n) for n in self.__graph.nodes()]
278276

279277
def _update_traces_pos(self):
@@ -311,10 +309,10 @@ def _update_traces_pos(self):
311309

312310
def _update_traces_col(self):
313311
if (self.__colors is not None) and (self.__agg is not None):
314-
colors_agg = aggregate_graph(self.__colors, self.__graph, self.__agg)
315-
colors_list = list(colors_agg.values())
316-
self._update_node_trace_col(colors_agg, colors_list)
317-
self._update_edge_trace_col(colors_agg, colors_list)
312+
nodes_col = _node_col(self.__graph, self.__colors, self.__agg)
313+
colors_list = list(nodes_col.values())
314+
self._update_node_trace_col(nodes_col, colors_list)
315+
self._update_edge_trace_col(nodes_col, colors_list)
318316

319317
def _update_edge_trace_col(self, colors_agg, colors_list):
320318
colors_avg = []
@@ -538,7 +536,7 @@ def _plot_nodes(self, ax, nodes_pos):
538536
nodes_arr = _node_pos_array(self.__graph, self.__dim, nodes_pos)
539537
attr_size = nx.get_node_attributes(self.__graph, ATTR_SIZE)
540538
max_size = max(attr_size.values()) if attr_size else 1.0
541-
colors_agg = aggregate_graph(self.__colors, self.__graph, self.__agg)
539+
colors_agg = _node_col(self.__graph, self.__colors, self.__agg)
542540
marker_color = [colors_agg[n] for n in self.__graph.nodes()]
543541
marker_size = [200.0 * math.sqrt(attr_size[n] / max_size) for n in self.__graph.nodes()]
544542
verts = ax.scatter(

tests/test_plot.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,34 @@ def testTwoConnectedClusters(self):
2121
mp = MapperAlgorithm(cover=BallCover(1.1, metric=dist),
2222
clustering=TrivialClustering())
2323
g = mp.fit_transform(data, data)
24-
mp_plot1 = MapperLayoutInteractive(g, colors=data, dim=2)
24+
mp_plot1 = MapperLayoutInteractive(g, dim=2,
25+
colors=data,
26+
seed=123,
27+
iterations=10,
28+
agg=np.nanmax,
29+
width=200,
30+
height=200,
31+
title='example',
32+
cmap='jet')
2533
mp_plot1.plot()
26-
mp_plot2 = MapperLayoutInteractive(g, colors=data, dim=3)
27-
mp_plot2.plot()
28-
mp_plot2.update(
34+
mp_plot2 = MapperLayoutInteractive(g, dim=3,
2935
colors=data,
3036
seed=123,
3137
iterations=10,
3238
agg=np.nanmax,
3339
width=200,
3440
height=200,
41+
title='example',
42+
cmap='jet')
43+
mp_plot2.plot()
44+
mp_plot2.update(
45+
colors=data,
46+
seed=124,
47+
iterations=15,
48+
agg=np.nanmin,
49+
width=300,
50+
height=300,
51+
title='example-updated',
3552
cmap='viridis')
3653
mp_plot2.plot()
3754
mp_plot3 = MapperLayoutStatic(g, colors=data, dim=2)

0 commit comments

Comments
 (0)