@@ -240,62 +240,62 @@ def __init__(self, graph, dim,
240240 cmap = 'jet' ):
241241 self .__graph = graph
242242 self .__dim = dim
243- self .seed = seed
244- self .iterations = iterations
245- self .colors = colors
246- self .agg = agg
247- self .title = title
248- self .width = width
249- self .height = height
250- self .cmap = cmap
251- self .fig = _plotly_mapper_fig (
243+ self .__seed = seed
244+ self .__iterations = iterations
245+ self .__colors = colors
246+ self .__agg = agg
247+ self .__title = title
248+ self .__width = width
249+ self .__height = height
250+ self .__cmap = cmap
251+ self .__fig = _plotly_mapper_fig (
252252 self .__graph ,
253253 self .__dim ,
254- self .seed ,
255- self .iterations )
254+ self .__seed ,
255+ self .__iterations )
256256 self ._update_traces_col ()
257257 self ._update_layout ()
258258 self ._update_traces_cmap ()
259259 self ._update_traces_title ()
260260
261261 def _update_traces_pos (self ):
262- pos = _nodes_pos (self .__graph , self .__dim , self .seed , self .iterations )
262+ pos = _nodes_pos (self .__graph , self .__dim , self .__seed , self .__iterations )
263263 node_arr = _nodes_array (self .__graph , self .__dim , pos )
264264 edge_arr = _edges_array (self .__graph , self .__dim , pos )
265265 if self .__dim == 3 :
266- self .fig .update_traces (
266+ self .__fig .update_traces (
267267 patch = dict (
268268 x = node_arr [0 ],
269269 y = node_arr [1 ],
270270 z = node_arr [2 ]),
271271 selector = dict (
272272 name = 'nodes_trace' ))
273- self .fig .update_traces (
273+ self .__fig .update_traces (
274274 patch = dict (
275275 x = edge_arr [0 ],
276276 y = edge_arr [1 ],
277277 z = edge_arr [2 ]),
278278 selector = dict (
279279 name = 'edges_trace' ))
280280 elif self .__dim == 2 :
281- self .fig .update_traces (
281+ self .__fig .update_traces (
282282 patch = dict (
283283 x = node_arr [0 ],
284284 y = node_arr [1 ]),
285285 selector = dict (
286286 name = 'nodes_trace' ))
287- self .fig .update_traces (
287+ self .__fig .update_traces (
288288 patch = dict (
289289 x = edge_arr [0 ],
290290 y = edge_arr [1 ]),
291291 selector = dict (
292292 name = 'edges_trace' ))
293293
294294 def _update_traces_col (self ):
295- if (self .colors is not None ) and (self .agg is not None ):
296- colors_agg = aggregate_graph (self .colors , self .__graph , self .agg )
295+ if (self .__colors is not None ) and (self .__agg is not None ):
296+ colors_agg = aggregate_graph (self .__colors , self .__graph , self .__agg )
297297 colors_list = [colors_agg [n ] for n in self .__graph .nodes ()]
298- self .fig .update_traces (
298+ self .__fig .update_traces (
299299 patch = dict (
300300 marker_color = colors_list ,
301301 marker_cmax = max (colors_list ),
@@ -305,26 +305,28 @@ def _update_traces_col(self):
305305 name = 'nodes_trace' ))
306306
307307 def _update_traces_cmap (self ):
308- self .fig .update_traces (
308+ self .__fig .update_traces (
309309 patch = dict (
310- marker_colorscale = self .cmap ,
311- marker_line_colorscale = self .cmap ),
310+ marker_colorscale = self .__cmap ,
311+ marker_line_colorscale = self .__cmap ),
312312 selector = dict (
313313 name = 'nodes_trace' ))
314314
315315 def _update_traces_title (self ):
316- self .fig .update_traces (
316+ self .__fig .update_traces (
317317 patch = dict (
318- marker_colorbar = _plotly_colorbar (self .__dim , self .title )),
318+ marker_colorbar = _plotly_colorbar (self .__dim , self .__title )),
319319 selector = dict (
320320 name = 'nodes_trace' ))
321321
322322 def _update_layout (self ):
323- self .fig .update_layout (
324- width = self .width ,
325- height = self .height )
323+ self .__fig .update_layout (
324+ width = self .__width ,
325+ height = self .__height )
326326
327327 def update (self ,
328+ graph = None ,
329+ dim = None ,
328330 seed = None ,
329331 iterations = None ,
330332 colors = None ,
@@ -340,6 +342,12 @@ def update(self,
340342 calling this method, the figure will be updated according to the supplied
341343 parameters.
342344
345+ :param graph: The precomputed Mapper graph to be embedded. This can be
346+ obtained by calling :func:`tdamapper.core.mapper_graph` or
347+ :func:`tdamapper.core.MapperAlgorithm.fit_transform`.
348+ :type graph: :class:`networkx.Graph`, optional
349+ :param dim: The dimension of the graph embedding (2 or 3).
350+ :type dim: int, optional
343351 :param seed: The random seed used to construct the graph embedding.
344352 :type seed: int, optional
345353 :param iterations: The number of iterations used to construct the graph embedding.
@@ -362,45 +370,54 @@ def update(self,
362370 :type cmap: str, optional
363371 """
364372 _update_pos = False
373+ _update_col = False
374+ _update_layout = False
375+ if graph is not None :
376+ self .__graph = graph
377+ _update_pos = True
378+ _update_col = True
379+ _update_layout = True
380+ if dim is not None :
381+ self .__dim = dim
382+ _update_pos = True
383+ _update_layout = True
365384 if seed is not None :
366- self .seed = seed
385+ self .__seed = seed
367386 _update_pos = True
368387 if iterations is not None :
369- self .iterations = iterations
388+ self .__iterations = iterations
370389 _update_pos = True
371390 if _update_pos :
372391 self ._update_traces_pos ()
373- _update_col = False
374392 if agg is not None :
375- self .agg = agg
393+ self .__agg = agg
376394 if (colors is not None ) and (agg is not None ):
377- self .colors = colors
378- self .agg = agg
395+ self .__colors = colors
396+ self .__agg = agg
379397 _update_col = True
380- if (colors is not None ) and (self .agg is not None ):
381- self .colors = colors
398+ if (colors is not None ) and (self .__agg is not None ):
399+ self .__colors = colors
382400 _update_col = True
383- if (self .colors is not None ) and (agg is not None ):
384- self .agg = agg
401+ if (self .__colors is not None ) and (agg is not None ):
402+ self .__agg = agg
385403 _update_col = True
386404 if _update_col :
387405 self ._update_traces_col ()
388406 if cmap is not None :
389- self .cmap = cmap
407+ self .__cmap = cmap
390408 self ._update_traces_cmap ()
391409 if title is not None :
392- self .title = title
410+ self .__title = title
393411 self ._update_traces_title ()
394- _update_layout = False
395412 if (width is not None ) and (height is not None ):
396- self .width = width
397- self .height = height
413+ self .__width = width
414+ self .__height = height
398415 _update_layout = True
399- if (width is not None ) and (self .height is not None ):
400- self .width = width
416+ if (width is not None ) and (self .__height is not None ):
417+ self .__width = width
401418 _update_layout = True
402419 if height is not None :
403- self .height = height
420+ self .__height = height
404421 _update_layout = True
405422 if _update_layout :
406423 self ._update_layout ()
@@ -413,7 +430,7 @@ def plot(self):
413430 For 3D embeddings, the figure requires a WebGL context to be shown.
414431 :rtype: :class:`plotly.graph_objects.Figure`
415432 """
416- return self .fig
433+ return self .__fig
417434
418435
419436class MapperLayoutStatic :
@@ -462,14 +479,14 @@ def __init__(self, graph, dim,
462479 cmap = 'jet' ):
463480 self .__graph = graph
464481 self .__dim = dim
465- self .seed = seed
466- self .iterations = iterations
467- self .colors = colors
468- self .agg = agg
469- self .title = title
470- self .width = width
471- self .height = height
472- self .cmap = cmap
482+ self .__seed = seed
483+ self .__iterations = iterations
484+ self .__colors = colors
485+ self .__agg = agg
486+ self .__title = title
487+ self .__width = width
488+ self .__height = height
489+ self .__cmap = cmap
473490
474491 def plot (self ):
475492 """
@@ -479,10 +496,10 @@ def plot(self):
479496 :rtype: :class:`matplotlib.figure.Figure`, :class:`matplotlib.axes.Axes`
480497 """
481498 px = 1 / plt .rcParams ['figure.dpi' ] # pixel in inches
482- fig , ax = plt .subplots (figsize = (self .width * px , self .height * px ))
499+ fig , ax = plt .subplots (figsize = (self .__width * px , self .__height * px ))
483500 ax .get_xaxis ().set_visible (False )
484501 ax .get_yaxis ().set_visible (False )
485- pos = _nodes_pos (self .__graph , self .__dim , self .seed , self .iterations )
502+ pos = _nodes_pos (self .__graph , self .__dim , self .__seed , self .__iterations )
486503 self ._plot_edges (ax , pos )
487504 self ._plot_nodes (ax , pos )
488505 return fig , ax
@@ -491,15 +508,15 @@ def _plot_nodes(self, ax, nodes_pos):
491508 nodes_arr = _nodes_array (self .__graph , self .__dim , nodes_pos )
492509 attr_size = nx .get_node_attributes (self .__graph , ATTR_SIZE )
493510 max_size = max (attr_size .values ()) if attr_size else 1.0
494- colors_agg = aggregate_graph (self .colors , self .__graph , self .agg )
511+ colors_agg = aggregate_graph (self .__colors , self .__graph , self .__agg )
495512 marker_color = [colors_agg [n ] for n in self .__graph .nodes ()]
496513 marker_size = [200.0 * math .sqrt (attr_size [n ] / max_size ) for n in self .__graph .nodes ()]
497514 verts = ax .scatter (
498515 x = nodes_arr [0 ],
499516 y = nodes_arr [1 ],
500517 c = marker_color ,
501518 s = marker_size ,
502- cmap = self .cmap ,
519+ cmap = self .__cmap ,
503520 alpha = 1.0 ,
504521 vmin = min (marker_color ),
505522 vmax = max (marker_color ),
@@ -514,7 +531,7 @@ def _plot_nodes(self, ax, nodes_pos):
514531 ax = ax ,
515532 format = "%.2g" )
516533 colorbar .set_label (
517- self .title ,
534+ self .__title ,
518535 color = _NODE_OUTER_COLOR )
519536 colorbar .set_alpha (1.0 )
520537 colorbar .outline .set_color (_NODE_OUTER_COLOR )
0 commit comments