Skip to content

Commit f094f15

Browse files
authored
Merge pull request #95 from lucasimi/develop
Develop
2 parents 88e6ec1 + 13ff71f commit f094f15

File tree

2 files changed

+78
-59
lines changed

2 files changed

+78
-59
lines changed

src/tdamapper/plot.py

Lines changed: 76 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -240,62 +240,62 @@ def __init__(self, graph, dim,
240240
cmap='jet'):
241241
self.__graph = graph
242242
self.__dim = dim
243-
self.seed = seed
244-
self.iterations = iterations
245-
self.colors = colors
246-
self.agg = agg
247-
self.title = title
248-
self.width = width
249-
self.height = height
250-
self.cmap = cmap
251-
self.fig = _plotly_mapper_fig(
243+
self.__seed = seed
244+
self.__iterations = iterations
245+
self.__colors = colors
246+
self.__agg = agg
247+
self.__title = title
248+
self.__width = width
249+
self.__height = height
250+
self.__cmap = cmap
251+
self.__fig = _plotly_mapper_fig(
252252
self.__graph,
253253
self.__dim,
254-
self.seed,
255-
self.iterations)
254+
self.__seed,
255+
self.__iterations)
256256
self._update_traces_col()
257257
self._update_layout()
258258
self._update_traces_cmap()
259259
self._update_traces_title()
260260

261261
def _update_traces_pos(self):
262-
pos = _nodes_pos(self.__graph, self.__dim, self.seed, self.iterations)
262+
pos = _nodes_pos(self.__graph, self.__dim, self.__seed, self.__iterations)
263263
node_arr = _nodes_array(self.__graph, self.__dim, pos)
264264
edge_arr = _edges_array(self.__graph, self.__dim, pos)
265265
if self.__dim == 3:
266-
self.fig.update_traces(
266+
self.__fig.update_traces(
267267
patch=dict(
268268
x=node_arr[0],
269269
y=node_arr[1],
270270
z=node_arr[2]),
271271
selector=dict(
272272
name='nodes_trace'))
273-
self.fig.update_traces(
273+
self.__fig.update_traces(
274274
patch=dict(
275275
x=edge_arr[0],
276276
y=edge_arr[1],
277277
z=edge_arr[2]),
278278
selector=dict(
279279
name='edges_trace'))
280280
elif self.__dim == 2:
281-
self.fig.update_traces(
281+
self.__fig.update_traces(
282282
patch=dict(
283283
x=node_arr[0],
284284
y=node_arr[1]),
285285
selector=dict(
286286
name='nodes_trace'))
287-
self.fig.update_traces(
287+
self.__fig.update_traces(
288288
patch=dict(
289289
x=edge_arr[0],
290290
y=edge_arr[1]),
291291
selector=dict(
292292
name='edges_trace'))
293293

294294
def _update_traces_col(self):
295-
if (self.colors is not None) and (self.agg is not None):
296-
colors_agg = aggregate_graph(self.colors, self.__graph, self.agg)
295+
if (self.__colors is not None) and (self.__agg is not None):
296+
colors_agg = aggregate_graph(self.__colors, self.__graph, self.__agg)
297297
colors_list = [colors_agg[n] for n in self.__graph.nodes()]
298-
self.fig.update_traces(
298+
self.__fig.update_traces(
299299
patch=dict(
300300
marker_color=colors_list,
301301
marker_cmax=max(colors_list),
@@ -305,26 +305,28 @@ def _update_traces_col(self):
305305
name='nodes_trace'))
306306

307307
def _update_traces_cmap(self):
308-
self.fig.update_traces(
308+
self.__fig.update_traces(
309309
patch=dict(
310-
marker_colorscale=self.cmap,
311-
marker_line_colorscale=self.cmap),
310+
marker_colorscale=self.__cmap,
311+
marker_line_colorscale=self.__cmap),
312312
selector=dict(
313313
name='nodes_trace'))
314314

315315
def _update_traces_title(self):
316-
self.fig.update_traces(
316+
self.__fig.update_traces(
317317
patch=dict(
318-
marker_colorbar=_plotly_colorbar(self.__dim, self.title)),
318+
marker_colorbar=_plotly_colorbar(self.__dim, self.__title)),
319319
selector=dict(
320320
name='nodes_trace'))
321321

322322
def _update_layout(self):
323-
self.fig.update_layout(
324-
width=self.width,
325-
height=self.height)
323+
self.__fig.update_layout(
324+
width=self.__width,
325+
height=self.__height)
326326

327327
def update(self,
328+
graph=None,
329+
dim=None,
328330
seed=None,
329331
iterations=None,
330332
colors=None,
@@ -340,6 +342,12 @@ def update(self,
340342
calling this method, the figure will be updated according to the supplied
341343
parameters.
342344
345+
:param graph: The precomputed Mapper graph to be embedded. This can be
346+
obtained by calling :func:`tdamapper.core.mapper_graph` or
347+
:func:`tdamapper.core.MapperAlgorithm.fit_transform`.
348+
:type graph: :class:`networkx.Graph`, optional
349+
:param dim: The dimension of the graph embedding (2 or 3).
350+
:type dim: int, optional
343351
:param seed: The random seed used to construct the graph embedding.
344352
:type seed: int, optional
345353
:param iterations: The number of iterations used to construct the graph embedding.
@@ -362,45 +370,54 @@ def update(self,
362370
:type cmap: str, optional
363371
"""
364372
_update_pos = False
373+
_update_col = False
374+
_update_layout = False
375+
if graph is not None:
376+
self.__graph = graph
377+
_update_pos = True
378+
_update_col = True
379+
_update_layout = True
380+
if dim is not None:
381+
self.__dim = dim
382+
_update_pos = True
383+
_update_layout = True
365384
if seed is not None:
366-
self.seed = seed
385+
self.__seed = seed
367386
_update_pos = True
368387
if iterations is not None:
369-
self.iterations = iterations
388+
self.__iterations = iterations
370389
_update_pos = True
371390
if _update_pos:
372391
self._update_traces_pos()
373-
_update_col = False
374392
if agg is not None:
375-
self.agg = agg
393+
self.__agg = agg
376394
if (colors is not None) and (agg is not None):
377-
self.colors = colors
378-
self.agg = agg
395+
self.__colors = colors
396+
self.__agg = agg
379397
_update_col = True
380-
if (colors is not None) and (self.agg is not None):
381-
self.colors = colors
398+
if (colors is not None) and (self.__agg is not None):
399+
self.__colors = colors
382400
_update_col = True
383-
if (self.colors is not None) and (agg is not None):
384-
self.agg = agg
401+
if (self.__colors is not None) and (agg is not None):
402+
self.__agg = agg
385403
_update_col = True
386404
if _update_col:
387405
self._update_traces_col()
388406
if cmap is not None:
389-
self.cmap = cmap
407+
self.__cmap = cmap
390408
self._update_traces_cmap()
391409
if title is not None:
392-
self.title = title
410+
self.__title = title
393411
self._update_traces_title()
394-
_update_layout = False
395412
if (width is not None) and (height is not None):
396-
self.width = width
397-
self.height = height
413+
self.__width = width
414+
self.__height = height
398415
_update_layout = True
399-
if (width is not None) and (self.height is not None):
400-
self.width = width
416+
if (width is not None) and (self.__height is not None):
417+
self.__width = width
401418
_update_layout = True
402419
if height is not None:
403-
self.height = height
420+
self.__height = height
404421
_update_layout = True
405422
if _update_layout:
406423
self._update_layout()
@@ -413,7 +430,7 @@ def plot(self):
413430
For 3D embeddings, the figure requires a WebGL context to be shown.
414431
:rtype: :class:`plotly.graph_objects.Figure`
415432
"""
416-
return self.fig
433+
return self.__fig
417434

418435

419436
class MapperLayoutStatic:
@@ -462,14 +479,14 @@ def __init__(self, graph, dim,
462479
cmap='jet'):
463480
self.__graph = graph
464481
self.__dim = dim
465-
self.seed = seed
466-
self.iterations = iterations
467-
self.colors = colors
468-
self.agg = agg
469-
self.title = title
470-
self.width = width
471-
self.height = height
472-
self.cmap = cmap
482+
self.__seed = seed
483+
self.__iterations = iterations
484+
self.__colors = colors
485+
self.__agg = agg
486+
self.__title = title
487+
self.__width = width
488+
self.__height = height
489+
self.__cmap = cmap
473490

474491
def plot(self):
475492
"""
@@ -479,10 +496,10 @@ def plot(self):
479496
:rtype: :class:`matplotlib.figure.Figure`, :class:`matplotlib.axes.Axes`
480497
"""
481498
px = 1 / plt.rcParams['figure.dpi'] # pixel in inches
482-
fig, ax = plt.subplots(figsize=(self.width * px, self.height * px))
499+
fig, ax = plt.subplots(figsize=(self.__width * px, self.__height * px))
483500
ax.get_xaxis().set_visible(False)
484501
ax.get_yaxis().set_visible(False)
485-
pos = _nodes_pos(self.__graph, self.__dim, self.seed, self.iterations)
502+
pos = _nodes_pos(self.__graph, self.__dim, self.__seed, self.__iterations)
486503
self._plot_edges(ax, pos)
487504
self._plot_nodes(ax, pos)
488505
return fig, ax
@@ -491,15 +508,15 @@ def _plot_nodes(self, ax, nodes_pos):
491508
nodes_arr = _nodes_array(self.__graph, self.__dim, nodes_pos)
492509
attr_size = nx.get_node_attributes(self.__graph, ATTR_SIZE)
493510
max_size = max(attr_size.values()) if attr_size else 1.0
494-
colors_agg = aggregate_graph(self.colors, self.__graph, self.agg)
511+
colors_agg = aggregate_graph(self.__colors, self.__graph, self.__agg)
495512
marker_color = [colors_agg[n] for n in self.__graph.nodes()]
496513
marker_size = [200.0 * math.sqrt(attr_size[n] / max_size) for n in self.__graph.nodes()]
497514
verts = ax.scatter(
498515
x=nodes_arr[0],
499516
y=nodes_arr[1],
500517
c=marker_color,
501518
s=marker_size,
502-
cmap=self.cmap,
519+
cmap=self.__cmap,
503520
alpha=1.0,
504521
vmin=min(marker_color),
505522
vmax=max(marker_color),
@@ -514,7 +531,7 @@ def _plot_nodes(self, ax, nodes_pos):
514531
ax=ax,
515532
format="%.2g")
516533
colorbar.set_label(
517-
self.title,
534+
self.__title,
518535
color=_NODE_OUTER_COLOR)
519536
colorbar.set_alpha(1.0)
520537
colorbar.outline.set_color(_NODE_OUTER_COLOR)

tests/test_plot.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def testTwoConnectedClusters(self):
2626
mp_plot2 = MapperLayoutInteractive(g, colors=data, dim=3)
2727
mp_plot2.plot()
2828
mp_plot2.update(
29+
graph=g,
30+
dim=2,
2931
colors=data,
3032
seed=123,
3133
iterations=10,

0 commit comments

Comments
 (0)