Skip to content

Commit a265c90

Browse files
committed
Adding improvements for TPU compatibility.
1 parent 43d16e5 commit a265c90

File tree

13 files changed

+283
-183
lines changed

13 files changed

+283
-183
lines changed

docs/graph_nets.md

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4108,7 +4108,7 @@ Constructs an instance from an iterable of networkx graphs.
41084108
* `ValueError`: If `graph_nxs` is not an iterable of networkx instances.
41094109

41104110

4111-
### [`utils_tf.concat(input_graphs, axis, name='graph_concat')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=352)<!-- utils_tf.concat .code-reference -->
4111+
### [`utils_tf.concat(input_graphs, axis, name='graph_concat')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=359)<!-- utils_tf.concat .code-reference -->
41124112

41134113
Returns an op that concatenates graphs along a given axis.
41144114

@@ -4141,7 +4141,7 @@ corresponding fields is not `None`.
41414141
in `input_graphs` are not the same for all the graphs.
41424142

41434143

4144-
### [`utils_tf.data_dicts_to_graphs_tuple(data_dicts, name='data_dicts_to_graphs_tuple')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=915)<!-- utils_tf.data_dicts_to_graphs_tuple .code-reference -->
4144+
### [`utils_tf.data_dicts_to_graphs_tuple(data_dicts, name='data_dicts_to_graphs_tuple')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=922)<!-- utils_tf.data_dicts_to_graphs_tuple .code-reference -->
41454145

41464146
Creates a `graphs.GraphsTuple` containing tensors from data dicts.
41474147

@@ -4168,7 +4168,7 @@ Creates a `graphs.GraphsTuple` containing tensors from data dicts.
41684168
A `graphs.GraphTuple` representing the graphs in `data_dicts`.
41694169

41704170

4171-
### [`utils_tf.fully_connect_graph_dynamic(graph, exclude_self_edges=False, name='fully_connect_graph_dynamic')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=755)<!-- utils_tf.fully_connect_graph_dynamic .code-reference -->
4171+
### [`utils_tf.fully_connect_graph_dynamic(graph, exclude_self_edges=False, name='fully_connect_graph_dynamic')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=762)<!-- utils_tf.fully_connect_graph_dynamic .code-reference -->
41724172

41734173
Adds edges to a graph by fully-connecting the nodes.
41744174

@@ -4195,7 +4195,7 @@ or to be known at graph building time.
41954195
`None` in `graph`.
41964196

41974197

4198-
### [`utils_tf.fully_connect_graph_static(graph, exclude_self_edges=False, name='fully_connect_graph_static')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=687)<!-- utils_tf.fully_connect_graph_static .code-reference -->
4198+
### [`utils_tf.fully_connect_graph_static(graph, exclude_self_edges=False, name='fully_connect_graph_static')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=694)<!-- utils_tf.fully_connect_graph_static .code-reference -->
41994199

42004200
Adds edges to a graph by fully-connecting the nodes.
42014201

@@ -4231,7 +4231,7 @@ case, the method may silently yield an incorrect result.
42314231
the constantness of the number of nodes per graph).
42324232

42334233

4234-
### [`utils_tf.get_feed_dict(placeholders, graph)`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=219)<!-- utils_tf.get_feed_dict .code-reference -->
4234+
### [`utils_tf.get_feed_dict(placeholders, graph)`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=226)<!-- utils_tf.get_feed_dict .code-reference -->
42354235

42364236
Feeds a `graphs.GraphsTuple` of numpy arrays or `None` into `placeholders`.
42374237

@@ -4269,7 +4269,7 @@ restoring the correct behavior.
42694269
match.
42704270

42714271

4272-
### [`utils_tf.get_graph(input_graphs, index, name='get_graph')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=970)<!-- utils_tf.get_graph .code-reference -->
4272+
### [`utils_tf.get_graph(input_graphs, index, name='get_graph')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=977)<!-- utils_tf.get_graph .code-reference -->
42734273

42744274
Indexes into a graph.
42754275

@@ -4300,7 +4300,7 @@ graphs specified by the slice, and returns them into an another instance of a
43004300
* `ValueError`: if `index` is a slice and `index.step` if not None.
43014301

43024302

4303-
### [`utils_tf.get_num_graphs(input_graphs, name='get_num_graphs')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=1051)<!-- utils_tf.get_num_graphs .code-reference -->
4303+
### [`utils_tf.get_num_graphs(input_graphs, name='get_num_graphs')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=1058)<!-- utils_tf.get_num_graphs .code-reference -->
43044304

43054305
Returns the number of graphs (i.e. the batch size) in `input_graphs`.
43064306

@@ -4316,7 +4316,7 @@ Returns the number of graphs (i.e. the batch size) in `input_graphs`.
43164316
number of graphs is dynamic).
43174317

43184318

4319-
### [`utils_tf.identity(graph, name='graph_identity')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=465)<!-- utils_tf.identity .code-reference -->
4319+
### [`utils_tf.identity(graph, name='graph_identity')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=472)<!-- utils_tf.identity .code-reference -->
43204320

43214321
Pass each element of a graph through a `tf.identity`.
43224322

@@ -4341,7 +4341,7 @@ with tf.name_scope("encoder"):
43414341
`graph_output.x = tf.identity(graph.x)`
43424342

43434343

4344-
### [`utils_tf.make_runnable_in_session(graph, name='make_graph_runnable_in_session')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=490)<!-- utils_tf.make_runnable_in_session .code-reference -->
4344+
### [`utils_tf.make_runnable_in_session(graph, name='make_graph_runnable_in_session')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=497)<!-- utils_tf.make_runnable_in_session .code-reference -->
43454345

43464346
Allows a graph containing `None` fields to be run in a `tf.Session`.
43474347

@@ -4363,7 +4363,7 @@ meant to be called just before a call to `sess.run` on a Tensorflow session
43634363
otherwise
43644364

43654365

4366-
### [`utils_tf.nest_to_numpy(nest_of_tensors)`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=1066)<!-- utils_tf.nest_to_numpy .code-reference -->
4366+
### [`utils_tf.nest_to_numpy(nest_of_tensors)`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=1073)<!-- utils_tf.nest_to_numpy .code-reference -->
43674367

43684368
Converts a nest of eager tensors to a nest of numpy arrays.
43694369

@@ -4384,7 +4384,7 @@ tensors into a `graphs.GraphsTuple` of arrays, or nests containing
43844384
arrays and all other elements are kept the same.
43854385

43864386

4387-
### [`utils_tf.placeholders_from_data_dicts(data_dicts, force_dynamic_num_graphs=True, name='placeholders_from_data_dicts')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=263)<!-- utils_tf.placeholders_from_data_dicts .code-reference -->
4387+
### [`utils_tf.placeholders_from_data_dicts(data_dicts, force_dynamic_num_graphs=True, name='placeholders_from_data_dicts')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=270)<!-- utils_tf.placeholders_from_data_dicts .code-reference -->
43884388

43894389
Constructs placeholders compatible with a list of data dicts.
43904390

@@ -4402,7 +4402,7 @@ Constructs placeholders compatible with a list of data dicts.
44024402
dimensions of the dictionaries in `data_dicts`.
44034403

44044404

4405-
### [`utils_tf.placeholders_from_networkxs(graph_nxs, node_shape_hint=None, edge_shape_hint=None, data_type_hint=tf.float32, force_dynamic_num_graphs=True, name='placeholders_from_networkxs')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=284)<!-- utils_tf.placeholders_from_networkxs .code-reference -->
4405+
### [`utils_tf.placeholders_from_networkxs(graph_nxs, node_shape_hint=None, edge_shape_hint=None, data_type_hint=tf.float32, force_dynamic_num_graphs=True, name='placeholders_from_networkxs')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=291)<!-- utils_tf.placeholders_from_networkxs .code-reference -->
44064406

44074407
Constructs placeholders compatible with a list of networkx instances.
44084408

@@ -4443,7 +4443,7 @@ The networkx graph should be set up such that, for fixed shapes `node_shape`,
44434443
dimensions of the graph_nxs.
44444444

44454445

4446-
### [`utils_tf.repeat(tensor, repeats, axis=0, name='repeat')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=512)<!-- utils_tf.repeat .code-reference -->
4446+
### [`utils_tf.repeat(tensor, repeats, axis=0, name='repeat')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=519)<!-- utils_tf.repeat .code-reference -->
44474447

44484448
Repeats a `tf.Tensor`'s elements along an axis by custom amounts.
44494449

@@ -4463,7 +4463,7 @@ Equivalent to Numpy's `np.repeat`.
44634463
The `tf.Tensor` with repeated values.
44644464

44654465

4466-
### [`utils_tf.set_zero_edge_features(graph, edge_size, dtype=tf.float32, name='set_zero_edge_features')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=844)<!-- utils_tf.set_zero_edge_features .code-reference -->
4466+
### [`utils_tf.set_zero_edge_features(graph, edge_size, dtype=tf.float32, name='set_zero_edge_features')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=851)<!-- utils_tf.set_zero_edge_features .code-reference -->
44674467

44684468
Completes the edge state of a graph.
44694469

@@ -4489,7 +4489,7 @@ Completes the edge state of a graph.
44894489
* `ValueError`: If `edge_size` is None.
44904490

44914491

4492-
### [`utils_tf.set_zero_global_features(graph, global_size, dtype=tf.float32, name='set_zero_global_features')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=884)<!-- utils_tf.set_zero_global_features .code-reference -->
4492+
### [`utils_tf.set_zero_global_features(graph, global_size, dtype=tf.float32, name='set_zero_global_features')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=891)<!-- utils_tf.set_zero_global_features .code-reference -->
44934493

44944494
Completes the global state of a graph.
44954495

@@ -4513,7 +4513,7 @@ Completes the global state of a graph.
45134513
* `ValueError`: If `global_size` is not `None`.
45144514

45154515

4516-
### [`utils_tf.set_zero_node_features(graph, node_size, dtype=tf.float32, name='set_zero_node_features')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=812)<!-- utils_tf.set_zero_node_features .code-reference -->
4516+
### [`utils_tf.set_zero_node_features(graph, node_size, dtype=tf.float32, name='set_zero_node_features')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=819)<!-- utils_tf.set_zero_node_features .code-reference -->
45174517

45184518
Completes the node state of a graph.
45194519

@@ -4538,7 +4538,7 @@ Completes the node state of a graph.
45384538
* `ValueError`: If `node_size` is None.
45394539

45404540

4541-
### [`utils_tf.specs_from_graphs_tuple(graphs_tuple_sample, dynamic_num_graphs=False, dynamic_num_nodes=True, dynamic_num_edges=True, description_fn=<class 'tf.TensorSpec'>)`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=1087)<!-- utils_tf.specs_from_graphs_tuple .code-reference -->
4541+
### [`utils_tf.specs_from_graphs_tuple(graphs_tuple_sample, dynamic_num_graphs=False, dynamic_num_nodes=True, dynamic_num_edges=True, description_fn=<class 'tf.TensorSpec'>)`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=1094)<!-- utils_tf.specs_from_graphs_tuple .code-reference -->
45424542

45434543
Returns the `TensorSpec` specification for a given `GraphsTuple`.
45444544

@@ -4594,7 +4594,7 @@ for i in range(num_training_steps):
45944594
* `ValueError`: If a `GraphsTuple` has a field with `None`.
45954595

45964596

4597-
### [`utils_tf.stop_gradient(graph, stop_edges=True, stop_nodes=True, stop_globals=True, name='graph_stop_gradient')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=420)<!-- utils_tf.stop_gradient .code-reference -->
4597+
### [`utils_tf.stop_gradient(graph, stop_edges=True, stop_nodes=True, stop_globals=True, name='graph_stop_gradient')`](https://github.com/deepmind/graph_nets/blob/master/graph_nets/utils_tf.py?l=427)<!-- utils_tf.stop_gradient .code-reference -->
45984598

45994599
Stops the gradient flow through a graph.
46004600

graph_nets/blocks.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,26 @@ def _validate_broadcasted_graph(graph, from_field, to_field):
6464
_validate_graph(graph, [from_field, to_field], additional_message)
6565

6666

67-
def broadcast_globals_to_edges(graph, name="broadcast_globals_to_edges"):
67+
def _get_static_num_nodes(graph):
68+
"""Returns the static total number of nodes in a batch or None."""
69+
return None if graph.nodes is None else graph.nodes.shape.as_list()[0]
70+
71+
72+
def _get_static_num_edges(graph):
73+
"""Returns the static total number of edges in a batch or None."""
74+
return None if graph.senders is None else graph.senders.shape.as_list()[0]
75+
76+
77+
def broadcast_globals_to_edges(graph, name="broadcast_globals_to_edges",
78+
num_edges_hint=None):
6879
"""Broadcasts the global features to the edges of a graph.
6980
7081
Args:
7182
graph: A `graphs.GraphsTuple` containing `Tensor`s, with globals features of
7283
shape `[n_graphs] + global_shape`, and `N_EDGE` field of shape
7384
`[n_graphs]`.
7485
name: (string, optional) A name for the operation.
86+
num_edges_hint: Integer indicating the total number of edges, if known.
7587
7688
Returns:
7789
A tensor of shape `[n_edges] + global_shape`, where
@@ -85,17 +97,20 @@ def broadcast_globals_to_edges(graph, name="broadcast_globals_to_edges"):
8597
"""
8698
_validate_broadcasted_graph(graph, GLOBALS, N_EDGE)
8799
with tf.name_scope(name):
88-
return utils_tf.repeat(graph.globals, graph.n_edge, axis=0)
100+
return utils_tf.repeat(graph.globals, graph.n_edge, axis=0,
101+
sum_repeats_hint=num_edges_hint)
89102

90103

91-
def broadcast_globals_to_nodes(graph, name="broadcast_globals_to_nodes"):
104+
def broadcast_globals_to_nodes(graph, name="broadcast_globals_to_nodes",
105+
num_nodes_hint=None):
92106
"""Broadcasts the global features to the nodes of a graph.
93107
94108
Args:
95109
graph: A `graphs.GraphsTuple` containing `Tensor`s, with globals features of
96110
shape `[n_graphs] + global_shape`, and `N_NODE` field of shape
97111
`[n_graphs]`.
98112
name: (string, optional) A name for the operation.
113+
num_nodes_hint: Integer indicating the total number of nodes, if known.
99114
100115
Returns:
101116
A tensor of shape `[n_nodes] + global_shape`, where
@@ -109,7 +124,8 @@ def broadcast_globals_to_nodes(graph, name="broadcast_globals_to_nodes"):
109124
"""
110125
_validate_broadcasted_graph(graph, GLOBALS, N_NODE)
111126
with tf.name_scope(name):
112-
return utils_tf.repeat(graph.globals, graph.n_node, axis=0)
127+
return utils_tf.repeat(graph.globals, graph.n_node, axis=0,
128+
sum_repeats_hint=num_nodes_hint)
113129

114130

115131
def broadcast_sender_nodes_to_edges(
@@ -189,7 +205,8 @@ def _build(self, graph):
189205
additional_message="when aggregating from edges.")
190206
num_graphs = utils_tf.get_num_graphs(graph)
191207
graph_index = tf.range(num_graphs)
192-
indices = utils_tf.repeat(graph_index, graph.n_edge, axis=0)
208+
indices = utils_tf.repeat(graph_index, graph.n_edge, axis=0,
209+
sum_repeats_hint=_get_static_num_edges(graph))
193210
return self._reducer(graph.edges, indices, num_graphs)
194211

195212

@@ -225,7 +242,8 @@ def _build(self, graph):
225242
additional_message="when aggregating from nodes.")
226243
num_graphs = utils_tf.get_num_graphs(graph)
227244
graph_index = tf.range(num_graphs)
228-
indices = utils_tf.repeat(graph_index, graph.n_node, axis=0)
245+
indices = utils_tf.repeat(graph_index, graph.n_node, axis=0,
246+
sum_repeats_hint=_get_static_num_nodes(graph))
229247
return self._reducer(graph.nodes, indices, num_graphs)
230248

231249

@@ -453,7 +471,9 @@ def _build(self, graph):
453471
edges_to_collect.append(broadcast_sender_nodes_to_edges(graph))
454472

455473
if self._use_globals:
456-
edges_to_collect.append(broadcast_globals_to_edges(graph))
474+
num_edges_hint = _get_static_num_edges(graph)
475+
edges_to_collect.append(
476+
broadcast_globals_to_edges(graph, num_edges_hint=num_edges_hint))
457477

458478
collected_edges = tf.concat(edges_to_collect, axis=-1)
459479
updated_edges = self._edge_model(collected_edges)
@@ -563,7 +583,12 @@ def _build(self, graph):
563583
nodes_to_collect.append(graph.nodes)
564584

565585
if self._use_globals:
566-
nodes_to_collect.append(broadcast_globals_to_nodes(graph))
586+
# The hint will be an integer if the graph has node features and the total
587+
# number of nodes is known at tensorflow graph definition time, or None
588+
# otherwise.
589+
num_nodes_hint = _get_static_num_nodes(graph)
590+
nodes_to_collect.append(
591+
broadcast_globals_to_nodes(graph, num_nodes_hint=num_nodes_hint))
567592

568593
collected_nodes = tf.concat(nodes_to_collect, axis=-1)
569594
updated_nodes = self._node_model(collected_nodes)

graph_nets/demos/graph_nets_basics.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@
115115
"import networkx as nx\n",
116116
"import numpy as np\n",
117117
"import sonnet as snt\n",
118-
"import tensorflow as tf"
118+
"import tensorflow as tf\n",
119+
""
119120
]
120121
},
121122
{

graph_nets/demos/physics.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@
130130
"import sonnet as snt\n",
131131
"import tensorflow as tf\n",
132132
"\n",
133+
"\n",
133134
"try:\n",
134135
" import seaborn as sns\n",
135136
"except ImportError:\n",

graph_nets/demos/shortest_path.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@
124124
"from scipy import spatial\n",
125125
"import tensorflow as tf\n",
126126
"\n",
127+
"\n",
127128
"SEED = 1\n",
128129
"np.random.seed(SEED)\n",
129130
"tf.set_random_seed(SEED)"

graph_nets/demos/sort.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137
"import numpy as np\n",
138138
"import tensorflow as tf\n",
139139
"\n",
140+
"\n",
140141
"SEED = 1\n",
141142
"np.random.seed(SEED)\n",
142143
"tf.set_random_seed(SEED)"
@@ -697,8 +698,8 @@
697698
"colab": {
698699
"collapsed_sections": [],
699700
"last_runtime": {
700-
"build_target": "//learning/deepmind/dm_python:dm_notebook3",
701-
"kind": "private"
701+
"build_target": "",
702+
"kind": "local"
702703
},
703704
"name": "sort.ipynb",
704705
"provenance": [],

0 commit comments

Comments
 (0)