Skip to content

Commit 9b5ca52

Browse files
committed
Fixed issue with None colors argument
1 parent e00e391 commit 9b5ca52

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

src/tdamapper/plot.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ def _node_pos(graph, dim, seed, iterations):
4242
iterations=iterations)
4343

4444

45+
def _node_col(graph, colors, agg, default=0.5):
46+
if colors is not None:
47+
return aggregate_graph(colors, graph, agg)
48+
else:
49+
return [default for _ in graph.nodes()]
50+
51+
4552
def _node_pos_array(graph, dim, node_pos):
4653
return tuple([node_pos[n][i] for n in graph.nodes()] for i in range(dim))
4754

@@ -116,7 +123,7 @@ def __init__(self, graph, dim,
116123
self.__width = width
117124
self.__height = height
118125
self.__cmap = cmap
119-
node_col = aggregate_graph(self.__colors, self.__graph, self.__agg)
126+
node_col = _node_col(self.__graph, self.__colors, self.__agg)
120127
self.__fig = self._figure(node_col)
121128
#self._update_traces_col()
122129
#self._update_layout()
@@ -311,10 +318,10 @@ def _update_traces_pos(self):
311318

312319
def _update_traces_col(self):
313320
if (self.__colors is not None) and (self.__agg is not None):
314-
colors_agg = aggregate_graph(self.__colors, self.__graph, self.__agg)
315-
colors_list = list(colors_agg.values())
316-
self._update_node_trace_col(colors_agg, colors_list)
317-
self._update_edge_trace_col(colors_agg, colors_list)
321+
nodes_col = _node_col(self.__graph, self.__colors, self.__agg)
322+
colors_list = list(nodes_col.values())
323+
self._update_node_trace_col(nodes_col, colors_list)
324+
self._update_edge_trace_col(nodes_col, colors_list)
318325

319326
def _update_edge_trace_col(self, colors_agg, colors_list):
320327
colors_avg = []
@@ -538,7 +545,7 @@ def _plot_nodes(self, ax, nodes_pos):
538545
nodes_arr = _node_pos_array(self.__graph, self.__dim, nodes_pos)
539546
attr_size = nx.get_node_attributes(self.__graph, ATTR_SIZE)
540547
max_size = max(attr_size.values()) if attr_size else 1.0
541-
colors_agg = aggregate_graph(self.__colors, self.__graph, self.__agg)
548+
colors_agg = _node_col(self.__graph, self.__colors, self.__agg)
542549
marker_color = [colors_agg[n] for n in self.__graph.nodes()]
543550
marker_size = [200.0 * math.sqrt(attr_size[n] / max_size) for n in self.__graph.nodes()]
544551
verts = ax.scatter(

0 commit comments

Comments
 (0)