Skip to content

Commit dce2c60

Browse files
authored
feat: add shard execution workflow (#1557)
1 parent 6eb6fba commit dce2c60

File tree

15 files changed

+1046
-129
lines changed

15 files changed

+1046
-129
lines changed

nx/config/config.exs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,7 @@ import Config
55
# true inside Nx.
66
config :nx, :verify_grad, true
77
config :nx, :verify_binary_size, true
8+
9+
# If set to true, shards and sharding stages will be
10+
# inspected with their debug ids alongside their unique ref ids
11+
config :nx, :debug_shards, true

nx/lib/nx/application.ex

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ defmodule Nx.Application do
44

55
def start(_type, _args) do
66
children = [
7+
Nx.Defn.ShardingCompiler.ShardRegistry,
78
%{id: Nx.Serving.PG, start: {:pg, :start_link, [Nx.Serving.PG]}},
89
{Nx.HiddenServing, Nx.Serving.PG}
910
]

nx/lib/nx/defn/sharding_compiler.ex

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,13 @@ defmodule Nx.Defn.ShardingCompiler do
2020

2121
[args] = args
2222

23-
%T{
24-
shape: shape,
25-
type: type,
26-
data: %ShardPropagation{
27-
shards: output_shards,
28-
parameter_ids_to_index: parameter_ids_to_index
29-
}
30-
} =
23+
{%T{
24+
type: type,
25+
data: %ShardPropagation{
26+
shards: output_shards
27+
}
28+
}, parameter_ids_to_index,
29+
shape} =
3130
propagate_shards(vars, fun, opts[:sharding_config] || [])
3231

3332
data_sections =
@@ -152,9 +151,9 @@ defmodule Nx.Defn.ShardingCompiler do
152151
|> Enum.with_index(fn x, idx -> {idx, x} end)
153152
|> Map.new()
154153

155-
{container, _cache, _state} = ShardPropagation.traverse(expr, tensor_shardings)
154+
{container, _cache, state} = ShardPropagation.traverse(expr, tensor_shardings)
156155

157-
container
156+
{container, state.parameter_ids_to_index, expr.shape}
158157
end
159158

160159
@impl true

nx/lib/nx/defn/sharding_compiler/passes/graph_splitter.ex

Lines changed: 109 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,23 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
44
alias Nx.Tensor, as: T
55
alias Nx.Defn.Expr
66
alias Nx.Defn.ShardingCompiler.Shard
7+
alias Nx.Defn.ShardingCompiler.Passes.GraphSplitter.Stage
78

89
@gather_ops [:dot]
910
@reduction_ops [:sum]
1011

11-
def traverse(expr, expr_shards \\ %{}) do
12+
@ops_to_split Map.merge(
13+
Map.new(@gather_ops, &{&1, :gather}),
14+
Map.new(@reduction_ops, &{&1, :reduce})
15+
)
16+
17+
def traverse(expr, expr_shards \\ %{}, ops_to_split \\ @ops_to_split) do
1218
# expression_chain is going to be a reverse-accumulation of {category, subexpr}
1319
# that we can then compile and chain-execute elsewhere. `category` is either :gather, :reduce or :none
1420
state = %{
1521
expression_chain: [],
1622
nodes_to_replace: %{},
23+
ops_to_split: ops_to_split,
1724
# contains the sharding configuration for each node by id
1825
shards: expr_shards,
1926
# args is a map of id -> {stage_id, output_container_position}
@@ -54,62 +61,64 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
5461
{id, {expr, nil}}, idx ->
5562
{id, put_in(expr.data.args, [idx])}
5663

57-
{id, {expr, shard_propagation}}, idx ->
64+
{id, {expr, _shard_propagation}}, idx ->
5865
expr = put_in(expr.data.args, [idx])
59-
expr = Expr.metadata(expr, %{shards: shard_propagation.shards})
6066
{id, expr}
6167
end)
6268
|> Map.new()
6369

6470
{expr, _} =
6571
composite_rewrite_subtree(expr, %{state | nodes_to_replace: arg_remapping})
6672

67-
expr =
68-
Composite.traverse(expr, fn
69-
%T{data: %Expr{id: id}} = t ->
70-
if shard_propagation = state.shards[id] do
71-
Expr.metadata(t, %{shards: shard_propagation.shards})
72-
else
73-
t
74-
end
75-
76-
other ->
77-
other
73+
# Traverse the expression to remap all shapes according to the sharding given
74+
expr = set_shard_metadata(expr, state.shards)
75+
76+
arguments =
77+
Map.new(arg_remapping, fn {_id, arg_expr} ->
78+
{arg_expr.data.id, set_shard_metadata(arg_expr, state.shards)}
7879
end)
7980

80-
argument_sources = Map.take(state.args, Map.keys(arg_remapping))
81+
argument_sources =
82+
state.args
83+
|> Map.take(Map.keys(arg_remapping))
84+
|> Map.new(fn {remap_id, v} ->
85+
{arg_remapping[remap_id].data.id, v}
86+
end)
8187

82-
[{id, category, expr, argument_sources} | acc]
88+
[
89+
%Stage{
90+
id: id,
91+
category: category,
92+
expr: expr,
93+
arguments: arguments,
94+
argument_sources: argument_sources
95+
}
96+
| acc
97+
]
8398
end
8499
)
85100

86-
{expr_chain, Map.delete(state, :expression_chain), cache}
101+
{expr_chain, cache, Map.delete(state, :expression_chain)}
87102
end
88103

89104
defp composite_eval(expr, state, cache) do
90105
Composite.traverse(expr, {cache, state}, &eval/2)
91106
end
92107

93108
defp eval(%T{data: %Expr{id: id, op: op}} = ans, {cache, state}) do
94-
case {cache, state.nodes_to_replace} do
95-
{_, %{^id => res}} ->
109+
case {cache, state.nodes_to_replace, state.ops_to_split} do
110+
{_, %{^id => res}, _} ->
96111
# Replace the node with the corresponding parameter
97112
{res, {Map.put(cache, id, res), state}}
98113

99-
{%{^id => res}, _} ->
114+
{%{^id => res}, _, _} ->
100115
{res, {cache, state}}
101116

102-
{_, _} ->
103-
cond do
104-
op in @gather_ops ->
105-
rewrite_args(ans, :gather, {cache, state})
106-
107-
op in @reduction_ops ->
108-
rewrite_args(ans, :reduce, {cache, state})
117+
{_, _, %{^op => category}} ->
118+
rewrite_args(ans, category, {cache, state})
109119

110-
true ->
111-
eval_apply(op, ans, {cache, state})
112-
end
120+
_ ->
121+
eval_apply(op, ans, {cache, state})
113122
end
114123
end
115124

@@ -203,8 +212,8 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
203212
{new_expr, {cache, state}}
204213
end
205214

206-
defp eval_apply(:parameter, %T{data: %Expr{id: id}} = expr, {cache, state}) do
207-
state = put_in(state.args[id], nil)
215+
defp eval_apply(:parameter, %T{data: %Expr{id: id, args: [idx]}} = expr, {cache, state}) do
216+
state = put_in(state.args[id], {nil, idx})
208217
{expr, {Map.put(cache, id, expr), state}}
209218
end
210219

@@ -220,19 +229,26 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
220229
{ans, {Map.put(cache, id, ans), state}}
221230
end
222231

223-
defp composite_rewrite_subtree(args, state, acc \\ %{used_args: %{}})
232+
defp composite_rewrite_subtree(container, state, acc \\ %{used_args: %{}})
224233

225-
defp composite_rewrite_subtree(args, state, acc) when is_list(args) do
226-
Enum.map_reduce(args, acc, fn
234+
defp composite_rewrite_subtree(container, state, acc) when is_list(container) do
235+
Enum.map_reduce(container, acc, fn
227236
%T{} = arg, acc ->
228237
composite_rewrite_subtree(arg, state, acc)
229238

239+
arg, acc when is_list(arg) ->
240+
composite_rewrite_subtree(arg, state, acc)
241+
230242
arg, acc ->
231243
{arg, acc}
232244
end)
233245
end
234246

235-
defp composite_rewrite_subtree(%T{data: %Expr{id: id, op: :parameter}} = expr, state, acc) do
247+
defp composite_rewrite_subtree(container, state, acc) do
248+
Composite.traverse(container, acc, &rewrite_subtree(&1, state, &2))
249+
end
250+
251+
defp rewrite_subtree(%T{data: %Expr{id: id, op: :parameter}} = expr, state, acc) do
236252
case state.nodes_to_replace do
237253
%{^id => res} ->
238254
{res, put_in(acc.used_args[id], {res, state.shards[id]})}
@@ -242,22 +258,75 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
242258
end
243259
end
244260

245-
defp composite_rewrite_subtree(arg, state, acc) do
246-
Composite.traverse(arg, acc, &rewrite_subtree(&1, state, &2))
261+
defp rewrite_subtree(
262+
%T{data: %Expr{op: :optional, id: id, args: [call, subexpr, fun]}} = expr,
263+
state,
264+
acc
265+
) do
266+
case state.nodes_to_replace do
267+
%{^id => res} ->
268+
{res, put_in(acc.used_args[id], {res, state.shards[id]})}
269+
270+
_ ->
271+
{call, acc} = rewrite_subtree(call, state, acc)
272+
# `subexpr` is hermetic, in the sense that it is a self-contained scope
273+
# from which the arguments always come from `call`, so we can
274+
# keep it as is.
275+
276+
{put_in(expr.data.args, [call, subexpr, fun]), acc}
277+
end
247278
end
248279

249280
defp rewrite_subtree(%T{data: %Expr{id: id, args: args}} = expr, state, acc) do
250281
case state.nodes_to_replace do
251282
%{^id => res} ->
252283
# nodes_to_replace always contains a param
253-
{res, put_in(acc.used_args[id], res)}
284+
{res, put_in(acc.used_args[id], {res, state.shards[id]})}
254285

255286
_ ->
256287
{args, acc} = composite_rewrite_subtree(args, state, acc)
257-
258288
{put_in(expr.data.args, args), acc}
259289
end
260290
end
261291

262292
defp rewrite_subtree(other, _, acc), do: {other, acc}
293+
294+
defp set_shard_metadata(expr, shards) do
295+
Composite.traverse(expr, fn
296+
%T{data: %Expr{id: id}} = t ->
297+
if shard_propagation = shards[id] do
298+
shape =
299+
shard_propagation.shards
300+
|> Enum.sort()
301+
|> Enum.map(fn {_axis, [%Shard{length: length} | _]} -> length end)
302+
|> List.to_tuple()
303+
304+
t = do_set_shard_metadata(%{t | shape: shape}, shards)
305+
Expr.metadata(t, %{shards: shard_propagation.shards})
306+
else
307+
do_set_shard_metadata(t, shards)
308+
end
309+
310+
other ->
311+
other
312+
end)
313+
end
314+
315+
defp do_set_shard_metadata(%T{data: %Expr{args: args}} = expr, shards) do
316+
args =
317+
Enum.map(args, fn
318+
%T{} = arg ->
319+
set_shard_metadata(arg, shards)
320+
321+
arg when is_list(arg) ->
322+
Enum.map(arg, &do_set_shard_metadata(&1, shards))
323+
324+
arg ->
325+
arg
326+
end)
327+
328+
put_in(expr.data.args, args)
329+
end
330+
331+
defp do_set_shard_metadata(other, _), do: other
263332
end
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter.Stage do
2+
defstruct [:id, :category, :expr, :arguments, :argument_sources]
3+
end

nx/lib/nx/defn/sharding_compiler/passes/shard_propagation.ex

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do
55

66
alias Nx.Defn.ShardingCompiler.Shard
77

8-
defstruct [:id, :shards, :input_tensor_shardings, :parameter_ids_to_index, :expr]
8+
defstruct [:id, :shards, :expr]
99

1010
def traverse(expr, tensor_shardings) do
1111
{container, {cache, state}} =
@@ -19,9 +19,6 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do
1919
%{}
2020
)
2121

22-
container = put_in(container.data.input_tensor_shardings, tensor_shardings)
23-
container = put_in(container.data.parameter_ids_to_index, state.parameter_ids_to_index)
24-
2522
{container, cache, state}
2623
end
2724

@@ -53,7 +50,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do
5350
t
5451
|> Nx.axes()
5552
|> Map.new(fn axis ->
56-
{axis, [0..(elem(t.shape, axis) - 1)]}
53+
{axis, elem(t.shape, axis)}
5754
end)
5855

5956
expr = shard_from_config(t, config)
@@ -62,7 +59,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do
6259
end
6360

6461
defp eval(%T{data: %Expr{op: :constant, args: [_constant]}} = ans, {cache, state}) do
65-
expr = shard_from_config(ans, %{0 => [0..0]})
62+
expr = shard_from_config(ans, %{})
6663
state = put_in(state.expr_shards[expr.data.id], expr.data)
6764
{expr, {cache, state}}
6865
end
@@ -361,6 +358,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do
361358
defp resolve_sharding_broadcast(axis, left_shards, false, right_shards, false) do
362359
# We have a shard on both sides. We need to determine the intersection of the two.
363360
# This is fine only if all shards are equal
361+
364362
{reverse_out_shards, all_shards_match} =
365363
Enum.zip_reduce(left_shards, right_shards, {[], true}, fn left,
366364
right,

0 commit comments

Comments
 (0)