@@ -42,6 +42,13 @@ def _node_pos(graph, dim, seed, iterations):
4242 iterations = iterations )
4343
4444
45+ def _node_col (graph , colors , agg , default = 0.5 ):
46+ if colors is not None :
47+ return aggregate_graph (colors , graph , agg )
48+ else :
49+ return [default for _ in graph .nodes ()]
50+
51+
4552def _node_pos_array (graph , dim , node_pos ):
4653 return tuple ([node_pos [n ][i ] for n in graph .nodes ()] for i in range (dim ))
4754
@@ -116,7 +123,7 @@ def __init__(self, graph, dim,
116123 self .__width = width
117124 self .__height = height
118125 self .__cmap = cmap
119- node_col = aggregate_graph (self .__colors , self .__graph , self .__agg )
126+ node_col = _node_col (self .__graph , self .__colors , self .__agg )
120127 self .__fig = self ._figure (node_col )
121128 #self._update_traces_col()
122129 #self._update_layout()
@@ -311,10 +318,10 @@ def _update_traces_pos(self):
311318
312319 def _update_traces_col (self ):
313320 if (self .__colors is not None ) and (self .__agg is not None ):
314- colors_agg = aggregate_graph (self .__colors , self .__graph , self .__agg )
315- colors_list = list (colors_agg .values ())
316- self ._update_node_trace_col (colors_agg , colors_list )
317- self ._update_edge_trace_col (colors_agg , colors_list )
321+ nodes_col = _node_col (self .__graph , self .__colors , self .__agg )
322+ colors_list = list (nodes_col .values ())
323+ self ._update_node_trace_col (nodes_col , colors_list )
324+ self ._update_edge_trace_col (nodes_col , colors_list )
318325
319326 def _update_edge_trace_col (self , colors_agg , colors_list ):
320327 colors_avg = []
@@ -538,7 +545,7 @@ def _plot_nodes(self, ax, nodes_pos):
538545 nodes_arr = _node_pos_array (self .__graph , self .__dim , nodes_pos )
539546 attr_size = nx .get_node_attributes (self .__graph , ATTR_SIZE )
540547 max_size = max (attr_size .values ()) if attr_size else 1.0
541- colors_agg = aggregate_graph (self .__colors , self .__graph , self .__agg )
548+ colors_agg = _node_col (self .__graph , self .__colors , self .__agg )
542549 marker_color = [colors_agg [n ] for n in self .__graph .nodes ()]
543550 marker_size = [200.0 * math .sqrt (attr_size [n ] / max_size ) for n in self .__graph .nodes ()]
544551 verts = ax .scatter (
0 commit comments