@@ -42,6 +42,13 @@ def _node_pos(graph, dim, seed, iterations):
4242 iterations = iterations )
4343
4444
45+ def _node_col (graph , colors , agg , default = 0.5 ):
46+ if colors is not None :
47+ return aggregate_graph (colors , graph , agg )
48+ else :
49+ return [default for _ in graph .nodes ()]
50+
51+
4552def _node_pos_array (graph , dim , node_pos ):
4653 return tuple ([node_pos [n ][i ] for n in graph .nodes ()] for i in range (dim ))
4754
@@ -116,12 +123,8 @@ def __init__(self, graph, dim,
116123 self .__width = width
117124 self .__height = height
118125 self .__cmap = cmap
119- node_col = aggregate_graph (self .__colors , self .__graph , self .__agg )
126+ node_col = _node_col (self .__graph , self .__colors , self .__agg )
120127 self .__fig = self ._figure (node_col )
121- #self._update_traces_col()
122- #self._update_layout()
123- #self._update_traces_cmap()
124- #self._update_traces_title()
125128
126129 def _nodes_trace (self , node_pos_arr , node_col ):
127130 attr_size = nx .get_node_attributes (self .__graph , ATTR_SIZE )
@@ -192,7 +195,6 @@ def _layout(self):
192195 line_col = 'rgba(230, 230, 230, 1.0)'
193196 axis = dict (
194197 showline = True ,
195- #linecolor='rgba(230, 230, 230, 1.0)',
196198 linewidth = 1 ,
197199 mirror = True ,
198200 visible = True ,
@@ -263,17 +265,13 @@ def _colorbar(self):
263265 elif self .__dim == 2 :
264266 return go .scatter .marker .ColorBar (cbar )
265267
266- def _text (self , colors = None ):
268+ def _text (self , colors ):
267269 attr_size = nx .get_node_attributes (self .__graph , ATTR_SIZE )
268- if colors is None :
269- def _lbl (n ):
270- size = _fmt (attr_size [n ], 5 )
271- return f'node: { n } <br>size: { size } '
272- else :
273- def _lbl (n ):
274- col = _fmt (colors [n ], 3 )
275- size = _fmt (attr_size [n ], 5 )
276- return f'color: { col } <br>node: { n } <br>size: { size } '
270+
271+ def _lbl (n ):
272+ col = _fmt (colors [n ], 3 )
273+ size = _fmt (attr_size [n ], 5 )
274+ return f'color: { col } <br>node: { n } <br>size: { size } '
277275 return [_lbl (n ) for n in self .__graph .nodes ()]
278276
279277 def _update_traces_pos (self ):
@@ -311,10 +309,10 @@ def _update_traces_pos(self):
311309
312310 def _update_traces_col (self ):
313311 if (self .__colors is not None ) and (self .__agg is not None ):
314- colors_agg = aggregate_graph (self .__colors , self .__graph , self .__agg )
315- colors_list = list (colors_agg .values ())
316- self ._update_node_trace_col (colors_agg , colors_list )
317- self ._update_edge_trace_col (colors_agg , colors_list )
312+ nodes_col = _node_col (self .__graph , self .__colors , self .__agg )
313+ colors_list = list (nodes_col .values ())
314+ self ._update_node_trace_col (nodes_col , colors_list )
315+ self ._update_edge_trace_col (nodes_col , colors_list )
318316
319317 def _update_edge_trace_col (self , colors_agg , colors_list ):
320318 colors_avg = []
@@ -538,7 +536,7 @@ def _plot_nodes(self, ax, nodes_pos):
538536 nodes_arr = _node_pos_array (self .__graph , self .__dim , nodes_pos )
539537 attr_size = nx .get_node_attributes (self .__graph , ATTR_SIZE )
540538 max_size = max (attr_size .values ()) if attr_size else 1.0
541- colors_agg = aggregate_graph (self .__colors , self .__graph , self .__agg )
539+ colors_agg = _node_col (self .__graph , self .__colors , self .__agg )
542540 marker_color = [colors_agg [n ] for n in self .__graph .nodes ()]
543541 marker_size = [200.0 * math .sqrt (attr_size [n ] / max_size ) for n in self .__graph .nodes ()]
544542 verts = ax .scatter (
0 commit comments