Skip to content
4 changes: 4 additions & 0 deletions nx/config/config.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ import Config
# true inside Nx.
config :nx, :verify_grad, true
config :nx, :verify_binary_size, true

# If set to true, shards and sharding stages will be
# inspected with their debug ids alongside their unique ref ids
config :nx, :debug_shards, true
1 change: 1 addition & 0 deletions nx/lib/nx/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ defmodule Nx.Application do

def start(_type, _args) do
children = [
Nx.Defn.ShardingCompiler.ShardRegistry,
%{id: Nx.Serving.PG, start: {:pg, :start_link, [Nx.Serving.PG]}},
{Nx.HiddenServing, Nx.Serving.PG}
]
Expand Down
19 changes: 9 additions & 10 deletions nx/lib/nx/defn/sharding_compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@ defmodule Nx.Defn.ShardingCompiler do

[args] = args

%T{
shape: shape,
type: type,
data: %ShardPropagation{
shards: output_shards,
parameter_ids_to_index: parameter_ids_to_index
}
} =
{%T{
type: type,
data: %ShardPropagation{
shards: output_shards
}
}, parameter_ids_to_index,
shape} =
propagate_shards(vars, fun, opts[:sharding_config] || [])

data_sections =
Expand Down Expand Up @@ -152,9 +151,9 @@ defmodule Nx.Defn.ShardingCompiler do
|> Enum.with_index(fn x, idx -> {idx, x} end)
|> Map.new()

{container, _cache, _state} = ShardPropagation.traverse(expr, tensor_shardings)
{container, _cache, state} = ShardPropagation.traverse(expr, tensor_shardings)

container
{container, state.parameter_ids_to_index, expr.shape}
end

@impl true
Expand Down
149 changes: 109 additions & 40 deletions nx/lib/nx/defn/sharding_compiler/passes/graph_splitter.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,23 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
alias Nx.Tensor, as: T
alias Nx.Defn.Expr
alias Nx.Defn.ShardingCompiler.Shard
alias Nx.Defn.ShardingCompiler.Passes.GraphSplitter.Stage

@gather_ops [:dot]
@reduction_ops [:sum]

def traverse(expr, expr_shards \\ %{}) do
@ops_to_split Map.merge(
Map.new(@gather_ops, &{&1, :gather}),
Map.new(@reduction_ops, &{&1, :reduce})
)

def traverse(expr, expr_shards \\ %{}, ops_to_split \\ @ops_to_split) do
# expression_chain is going to be a reverse-accumulation of {category, subexpr}
# that we can then compile and chain-execute elsewhere. category is either :gather, :reduce or :none
state = %{
expression_chain: [],
nodes_to_replace: %{},
ops_to_split: ops_to_split,
# contains the sharding configuration for each node by id
shards: expr_shards,
# args is a map of id -> {stage_id, output_container_position}
Expand Down Expand Up @@ -54,62 +61,64 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
{id, {expr, nil}}, idx ->
{id, put_in(expr.data.args, [idx])}

{id, {expr, shard_propagation}}, idx ->
{id, {expr, _shard_propagation}}, idx ->
expr = put_in(expr.data.args, [idx])
expr = Expr.metadata(expr, %{shards: shard_propagation.shards})
{id, expr}
end)
|> Map.new()

{expr, _} =
composite_rewrite_subtree(expr, %{state | nodes_to_replace: arg_remapping})

expr =
Composite.traverse(expr, fn
%T{data: %Expr{id: id}} = t ->
if shard_propagation = state.shards[id] do
Expr.metadata(t, %{shards: shard_propagation.shards})
else
t
end

other ->
other
# Traverse the expression to remap all shapes according to the sharding given
expr = set_shard_metadata(expr, state.shards)

arguments =
Map.new(arg_remapping, fn {_id, arg_expr} ->
{arg_expr.data.id, set_shard_metadata(arg_expr, state.shards)}
end)

argument_sources = Map.take(state.args, Map.keys(arg_remapping))
argument_sources =
state.args
|> Map.take(Map.keys(arg_remapping))
|> Map.new(fn {remap_id, v} ->
{arg_remapping[remap_id].data.id, v}
end)

[{id, category, expr, argument_sources} | acc]
[
%Stage{
id: id,
category: category,
expr: expr,
arguments: arguments,
argument_sources: argument_sources
}
| acc
]
end
)

{expr_chain, Map.delete(state, :expression_chain), cache}
{expr_chain, cache, Map.delete(state, :expression_chain)}
end

defp composite_eval(expr, state, cache) do
Composite.traverse(expr, {cache, state}, &eval/2)
end

defp eval(%T{data: %Expr{id: id, op: op}} = ans, {cache, state}) do
case {cache, state.nodes_to_replace} do
{_, %{^id => res}} ->
case {cache, state.nodes_to_replace, state.ops_to_split} do
{_, %{^id => res}, _} ->
# Replace the node with the corresponding parameter
{res, {Map.put(cache, id, res), state}}

{%{^id => res}, _} ->
{%{^id => res}, _, _} ->
{res, {cache, state}}

{_, _} ->
cond do
op in @gather_ops ->
rewrite_args(ans, :gather, {cache, state})

op in @reduction_ops ->
rewrite_args(ans, :reduce, {cache, state})
{_, _, %{^op => category}} ->
rewrite_args(ans, category, {cache, state})

true ->
eval_apply(op, ans, {cache, state})
end
_ ->
eval_apply(op, ans, {cache, state})
end
end

Expand Down Expand Up @@ -203,8 +212,8 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
{new_expr, {cache, state}}
end

defp eval_apply(:parameter, %T{data: %Expr{id: id}} = expr, {cache, state}) do
state = put_in(state.args[id], nil)
defp eval_apply(:parameter, %T{data: %Expr{id: id, args: [idx]}} = expr, {cache, state}) do
state = put_in(state.args[id], {nil, idx})
{expr, {Map.put(cache, id, expr), state}}
end

Expand All @@ -220,19 +229,26 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
{ans, {Map.put(cache, id, ans), state}}
end

defp composite_rewrite_subtree(args, state, acc \\ %{used_args: %{}})
defp composite_rewrite_subtree(container, state, acc \\ %{used_args: %{}})

defp composite_rewrite_subtree(args, state, acc) when is_list(args) do
Enum.map_reduce(args, acc, fn
defp composite_rewrite_subtree(container, state, acc) when is_list(container) do
Enum.map_reduce(container, acc, fn
%T{} = arg, acc ->
composite_rewrite_subtree(arg, state, acc)

arg, acc when is_list(arg) ->
composite_rewrite_subtree(arg, state, acc)

arg, acc ->
{arg, acc}
end)
end

defp composite_rewrite_subtree(%T{data: %Expr{id: id, op: :parameter}} = expr, state, acc) do
defp composite_rewrite_subtree(container, state, acc) do
Composite.traverse(container, acc, &rewrite_subtree(&1, state, &2))
end

defp rewrite_subtree(%T{data: %Expr{id: id, op: :parameter}} = expr, state, acc) do
case state.nodes_to_replace do
%{^id => res} ->
{res, put_in(acc.used_args[id], {res, state.shards[id]})}
Expand All @@ -242,22 +258,75 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
end
end

defp composite_rewrite_subtree(arg, state, acc) do
Composite.traverse(arg, acc, &rewrite_subtree(&1, state, &2))
defp rewrite_subtree(
%T{data: %Expr{op: :optional, id: id, args: [call, subexpr, fun]}} = expr,
state,
acc
) do
case state.nodes_to_replace do
%{^id => res} ->
{res, put_in(acc.used_args[id], {res, state.shards[id]})}

_ ->
{call, acc} = rewrite_subtree(call, state, acc)
# `subexpr` is hermetic, in the sense that it is a self-contained scope
# from which the arguments always come from `call`, so we can
# keep it as is.

{put_in(expr.data.args, [call, subexpr, fun]), acc}
end
end

defp rewrite_subtree(%T{data: %Expr{id: id, args: args}} = expr, state, acc) do
case state.nodes_to_replace do
%{^id => res} ->
# nodes_to_replace always contains a param
{res, put_in(acc.used_args[id], res)}
{res, put_in(acc.used_args[id], {res, state.shards[id]})}

_ ->
{args, acc} = composite_rewrite_subtree(args, state, acc)

{put_in(expr.data.args, args), acc}
end
end

defp rewrite_subtree(other, _, acc), do: {other, acc}

defp set_shard_metadata(expr, shards) do
Composite.traverse(expr, fn
%T{data: %Expr{id: id}} = t ->
if shard_propagation = shards[id] do
shape =
shard_propagation.shards
|> Enum.sort()
|> Enum.map(fn {_axis, [%Shard{length: length} | _]} -> length end)
|> List.to_tuple()

t = do_set_shard_metadata(%{t | shape: shape}, shards)
Expr.metadata(t, %{shards: shard_propagation.shards})
else
do_set_shard_metadata(t, shards)
end

other ->
other
end)
end

defp do_set_shard_metadata(%T{data: %Expr{args: args}} = expr, shards) do
args =
Enum.map(args, fn
%T{} = arg ->
set_shard_metadata(arg, shards)

arg when is_list(arg) ->
Enum.map(arg, &do_set_shard_metadata(&1, shards))

arg ->
arg
end)

put_in(expr.data.args, args)
end

defp do_set_shard_metadata(other, _), do: other
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter.Stage do
defstruct [:id, :category, :expr, :arguments, :argument_sources]
end
10 changes: 4 additions & 6 deletions nx/lib/nx/defn/sharding_compiler/passes/shard_propagation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do

alias Nx.Defn.ShardingCompiler.Shard

defstruct [:id, :shards, :input_tensor_shardings, :parameter_ids_to_index, :expr]
defstruct [:id, :shards, :expr]

def traverse(expr, tensor_shardings) do
{container, {cache, state}} =
Expand All @@ -19,9 +19,6 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do
%{}
)

container = put_in(container.data.input_tensor_shardings, tensor_shardings)
container = put_in(container.data.parameter_ids_to_index, state.parameter_ids_to_index)

{container, cache, state}
end

Expand Down Expand Up @@ -53,7 +50,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do
t
|> Nx.axes()
|> Map.new(fn axis ->
{axis, [0..(elem(t.shape, axis) - 1)]}
{axis, elem(t.shape, axis)}
end)

expr = shard_from_config(t, config)
Expand All @@ -62,7 +59,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do
end

defp eval(%T{data: %Expr{op: :constant, args: [_constant]}} = ans, {cache, state}) do
expr = shard_from_config(ans, %{0 => [0..0]})
expr = shard_from_config(ans, %{})
state = put_in(state.expr_shards[expr.data.id], expr.data)
{expr, {cache, state}}
end
Expand Down Expand Up @@ -361,6 +358,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do
defp resolve_sharding_broadcast(axis, left_shards, false, right_shards, false) do
# We have a shard on both sides. We need to determine the intersection of the two.
# This is fine only if all shards are equal

{reverse_out_shards, all_shards_match} =
Enum.zip_reduce(left_shards, right_shards, {[], true}, fn left,
right,
Expand Down
Loading
Loading