Skip to content

Commit 13ff71f

Browse files
committed
Minor improvements
1 parent 85408fb commit 13ff71f

File tree

1 file changed

+59
-59
lines changed

1 file changed

+59
-59
lines changed

src/tdamapper/plot.py

Lines changed: 59 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -240,62 +240,62 @@ def __init__(self, graph, dim,
240240
cmap='jet'):
241241
self.__graph = graph
242242
self.__dim = dim
243-
self.seed = seed
244-
self.iterations = iterations
245-
self.colors = colors
246-
self.agg = agg
247-
self.title = title
248-
self.width = width
249-
self.height = height
250-
self.cmap = cmap
251-
self.fig = _plotly_mapper_fig(
243+
self.__seed = seed
244+
self.__iterations = iterations
245+
self.__colors = colors
246+
self.__agg = agg
247+
self.__title = title
248+
self.__width = width
249+
self.__height = height
250+
self.__cmap = cmap
251+
self.__fig = _plotly_mapper_fig(
252252
self.__graph,
253253
self.__dim,
254-
self.seed,
255-
self.iterations)
254+
self.__seed,
255+
self.__iterations)
256256
self._update_traces_col()
257257
self._update_layout()
258258
self._update_traces_cmap()
259259
self._update_traces_title()
260260

261261
def _update_traces_pos(self):
262-
pos = _nodes_pos(self.__graph, self.__dim, self.seed, self.iterations)
262+
pos = _nodes_pos(self.__graph, self.__dim, self.__seed, self.__iterations)
263263
node_arr = _nodes_array(self.__graph, self.__dim, pos)
264264
edge_arr = _edges_array(self.__graph, self.__dim, pos)
265265
if self.__dim == 3:
266-
self.fig.update_traces(
266+
self.__fig.update_traces(
267267
patch=dict(
268268
x=node_arr[0],
269269
y=node_arr[1],
270270
z=node_arr[2]),
271271
selector=dict(
272272
name='nodes_trace'))
273-
self.fig.update_traces(
273+
self.__fig.update_traces(
274274
patch=dict(
275275
x=edge_arr[0],
276276
y=edge_arr[1],
277277
z=edge_arr[2]),
278278
selector=dict(
279279
name='edges_trace'))
280280
elif self.__dim == 2:
281-
self.fig.update_traces(
281+
self.__fig.update_traces(
282282
patch=dict(
283283
x=node_arr[0],
284284
y=node_arr[1]),
285285
selector=dict(
286286
name='nodes_trace'))
287-
self.fig.update_traces(
287+
self.__fig.update_traces(
288288
patch=dict(
289289
x=edge_arr[0],
290290
y=edge_arr[1]),
291291
selector=dict(
292292
name='edges_trace'))
293293

294294
def _update_traces_col(self):
295-
if (self.colors is not None) and (self.agg is not None):
296-
colors_agg = aggregate_graph(self.colors, self.__graph, self.agg)
295+
if (self.__colors is not None) and (self.__agg is not None):
296+
colors_agg = aggregate_graph(self.__colors, self.__graph, self.__agg)
297297
colors_list = [colors_agg[n] for n in self.__graph.nodes()]
298-
self.fig.update_traces(
298+
self.__fig.update_traces(
299299
patch=dict(
300300
marker_color=colors_list,
301301
marker_cmax=max(colors_list),
@@ -305,24 +305,24 @@ def _update_traces_col(self):
305305
name='nodes_trace'))
306306

307307
def _update_traces_cmap(self):
308-
self.fig.update_traces(
308+
self.__fig.update_traces(
309309
patch=dict(
310-
marker_colorscale=self.cmap,
311-
marker_line_colorscale=self.cmap),
310+
marker_colorscale=self.__cmap,
311+
marker_line_colorscale=self.__cmap),
312312
selector=dict(
313313
name='nodes_trace'))
314314

315315
def _update_traces_title(self):
316-
self.fig.update_traces(
316+
self.__fig.update_traces(
317317
patch=dict(
318-
marker_colorbar=_plotly_colorbar(self.__dim, self.title)),
318+
marker_colorbar=_plotly_colorbar(self.__dim, self.__title)),
319319
selector=dict(
320320
name='nodes_trace'))
321321

322322
def _update_layout(self):
323-
self.fig.update_layout(
324-
width=self.width,
325-
height=self.height)
323+
self.__fig.update_layout(
324+
width=self.__width,
325+
height=self.__height)
326326

327327
def update(self,
328328
graph=None,
@@ -345,9 +345,9 @@ def update(self,
345345
:param graph: The precomputed Mapper graph to be embedded. This can be
346346
obtained by calling :func:`tdamapper.core.mapper_graph` or
347347
:func:`tdamapper.core.MapperAlgorithm.fit_transform`.
348-
:type graph: :class:`networkx.Graph`, required
348+
:type graph: :class:`networkx.Graph`, optional
349349
:param dim: The dimension of the graph embedding (2 or 3).
350-
:type dim: int
350+
:type dim: int, optional
351351
:param seed: The random seed used to construct the graph embedding.
352352
:type seed: int, optional
353353
:param iterations: The number of iterations used to construct the graph embedding.
@@ -382,42 +382,42 @@ def update(self,
382382
_update_pos = True
383383
_update_layout = True
384384
if seed is not None:
385-
self.seed = seed
385+
self.__seed = seed
386386
_update_pos = True
387387
if iterations is not None:
388-
self.iterations = iterations
388+
self.__iterations = iterations
389389
_update_pos = True
390390
if _update_pos:
391391
self._update_traces_pos()
392392
if agg is not None:
393-
self.agg = agg
393+
self.__agg = agg
394394
if (colors is not None) and (agg is not None):
395-
self.colors = colors
396-
self.agg = agg
395+
self.__colors = colors
396+
self.__agg = agg
397397
_update_col = True
398-
if (colors is not None) and (self.agg is not None):
399-
self.colors = colors
398+
if (colors is not None) and (self.__agg is not None):
399+
self.__colors = colors
400400
_update_col = True
401-
if (self.colors is not None) and (agg is not None):
402-
self.agg = agg
401+
if (self.__colors is not None) and (agg is not None):
402+
self.__agg = agg
403403
_update_col = True
404404
if _update_col:
405405
self._update_traces_col()
406406
if cmap is not None:
407-
self.cmap = cmap
407+
self.__cmap = cmap
408408
self._update_traces_cmap()
409409
if title is not None:
410-
self.title = title
410+
self.__title = title
411411
self._update_traces_title()
412412
if (width is not None) and (height is not None):
413-
self.width = width
414-
self.height = height
413+
self.__width = width
414+
self.__height = height
415415
_update_layout = True
416-
if (width is not None) and (self.height is not None):
417-
self.width = width
416+
if (width is not None) and (self.__height is not None):
417+
self.__width = width
418418
_update_layout = True
419419
if height is not None:
420-
self.height = height
420+
self.__height = height
421421
_update_layout = True
422422
if _update_layout:
423423
self._update_layout()
@@ -430,7 +430,7 @@ def plot(self):
430430
For 3D embeddings, the figure requires a WebGL context to be shown.
431431
:rtype: :class:`plotly.graph_objects.Figure`
432432
"""
433-
return self.fig
433+
return self.__fig
434434

435435

436436
class MapperLayoutStatic:
@@ -479,14 +479,14 @@ def __init__(self, graph, dim,
479479
cmap='jet'):
480480
self.__graph = graph
481481
self.__dim = dim
482-
self.seed = seed
483-
self.iterations = iterations
484-
self.colors = colors
485-
self.agg = agg
486-
self.title = title
487-
self.width = width
488-
self.height = height
489-
self.cmap = cmap
482+
self.__seed = seed
483+
self.__iterations = iterations
484+
self.__colors = colors
485+
self.__agg = agg
486+
self.__title = title
487+
self.__width = width
488+
self.__height = height
489+
self.__cmap = cmap
490490

491491
def plot(self):
492492
"""
@@ -496,10 +496,10 @@ def plot(self):
496496
:rtype: :class:`matplotlib.figure.Figure`, :class:`matplotlib.axes.Axes`
497497
"""
498498
px = 1 / plt.rcParams['figure.dpi'] # pixel in inches
499-
fig, ax = plt.subplots(figsize=(self.width * px, self.height * px))
499+
fig, ax = plt.subplots(figsize=(self.__width * px, self.__height * px))
500500
ax.get_xaxis().set_visible(False)
501501
ax.get_yaxis().set_visible(False)
502-
pos = _nodes_pos(self.__graph, self.__dim, self.seed, self.iterations)
502+
pos = _nodes_pos(self.__graph, self.__dim, self.__seed, self.__iterations)
503503
self._plot_edges(ax, pos)
504504
self._plot_nodes(ax, pos)
505505
return fig, ax
@@ -508,15 +508,15 @@ def _plot_nodes(self, ax, nodes_pos):
508508
nodes_arr = _nodes_array(self.__graph, self.__dim, nodes_pos)
509509
attr_size = nx.get_node_attributes(self.__graph, ATTR_SIZE)
510510
max_size = max(attr_size.values()) if attr_size else 1.0
511-
colors_agg = aggregate_graph(self.colors, self.__graph, self.agg)
511+
colors_agg = aggregate_graph(self.__colors, self.__graph, self.__agg)
512512
marker_color = [colors_agg[n] for n in self.__graph.nodes()]
513513
marker_size = [200.0 * math.sqrt(attr_size[n] / max_size) for n in self.__graph.nodes()]
514514
verts = ax.scatter(
515515
x=nodes_arr[0],
516516
y=nodes_arr[1],
517517
c=marker_color,
518518
s=marker_size,
519-
cmap=self.cmap,
519+
cmap=self.__cmap,
520520
alpha=1.0,
521521
vmin=min(marker_color),
522522
vmax=max(marker_color),
@@ -531,7 +531,7 @@ def _plot_nodes(self, ax, nodes_pos):
531531
ax=ax,
532532
format="%.2g")
533533
colorbar.set_label(
534-
self.title,
534+
self.__title,
535535
color=_NODE_OUTER_COLOR)
536536
colorbar.set_alpha(1.0)
537537
colorbar.outline.set_color(_NODE_OUTER_COLOR)

0 commit comments

Comments
 (0)