diff --git a/nx/config/config.exs b/nx/config/config.exs index c1d328181e..a4f4ea7a48 100644 --- a/nx/config/config.exs +++ b/nx/config/config.exs @@ -5,3 +5,7 @@ import Config # true inside Nx. config :nx, :verify_grad, true config :nx, :verify_binary_size, true + +# If set to true, shards and sharding stages will be +# inspected with their debug ids alongside their unique ref ids +config :nx, :debug_shards, true diff --git a/nx/lib/nx/application.ex b/nx/lib/nx/application.ex index 3ba5b5ebca..062c6ed014 100644 --- a/nx/lib/nx/application.ex +++ b/nx/lib/nx/application.ex @@ -4,6 +4,7 @@ defmodule Nx.Application do def start(_type, _args) do children = [ + Nx.Defn.ShardingCompiler.ShardRegistry, %{id: Nx.Serving.PG, start: {:pg, :start_link, [Nx.Serving.PG]}}, {Nx.HiddenServing, Nx.Serving.PG} ] diff --git a/nx/lib/nx/defn/sharding_compiler.ex b/nx/lib/nx/defn/sharding_compiler.ex index 884e231e3c..c0cca1efb5 100644 --- a/nx/lib/nx/defn/sharding_compiler.ex +++ b/nx/lib/nx/defn/sharding_compiler.ex @@ -20,14 +20,13 @@ defmodule Nx.Defn.ShardingCompiler do [args] = args - %T{ - shape: shape, - type: type, - data: %ShardPropagation{ - shards: output_shards, - parameter_ids_to_index: parameter_ids_to_index - } - } = + {%T{ + type: type, + data: %ShardPropagation{ + shards: output_shards + } + }, parameter_ids_to_index, + shape} = propagate_shards(vars, fun, opts[:sharding_config] || []) data_sections = @@ -152,9 +151,9 @@ defmodule Nx.Defn.ShardingCompiler do |> Enum.with_index(fn x, idx -> {idx, x} end) |> Map.new() - {container, _cache, _state} = ShardPropagation.traverse(expr, tensor_shardings) + {container, _cache, state} = ShardPropagation.traverse(expr, tensor_shardings) - container + {container, state.parameter_ids_to_index, expr.shape} end @impl true diff --git a/nx/lib/nx/defn/sharding_compiler/passes/graph_splitter.ex b/nx/lib/nx/defn/sharding_compiler/passes/graph_splitter.ex index 41b70d9273..44e20c9d06 100644 --- 
a/nx/lib/nx/defn/sharding_compiler/passes/graph_splitter.ex +++ b/nx/lib/nx/defn/sharding_compiler/passes/graph_splitter.ex @@ -4,16 +4,23 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do alias Nx.Tensor, as: T alias Nx.Defn.Expr alias Nx.Defn.ShardingCompiler.Shard + alias Nx.Defn.ShardingCompiler.Passes.GraphSplitter.Stage @gather_ops [:dot] @reduction_ops [:sum] - def traverse(expr, expr_shards \\ %{}) do + @ops_to_split Map.merge( + Map.new(@gather_ops, &{&1, :gather}), + Map.new(@reduction_ops, &{&1, :reduce}) + ) + + def traverse(expr, expr_shards \\ %{}, ops_to_split \\ @ops_to_split) do # expression_chain is going to be a reverse-accumulation of {category, subexpr} # that we can then compile and chain-execute elsewhere. category is either :gather, :reduce or :none state = %{ expression_chain: [], nodes_to_replace: %{}, + ops_to_split: ops_to_split, # contains the sharding configuration for each node by id shards: expr_shards, # args is a map of id -> {stage_id, output_container_position} @@ -54,9 +61,8 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do {id, {expr, nil}}, idx -> {id, put_in(expr.data.args, [idx])} - {id, {expr, shard_propagation}}, idx -> + {id, {expr, _shard_propagation}}, idx -> expr = put_in(expr.data.args, [idx]) - expr = Expr.metadata(expr, %{shards: shard_propagation.shards}) {id, expr} end) |> Map.new() @@ -64,26 +70,35 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do {expr, _} = composite_rewrite_subtree(expr, %{state | nodes_to_replace: arg_remapping}) - expr = - Composite.traverse(expr, fn - %T{data: %Expr{id: id}} = t -> - if shard_propagation = state.shards[id] do - Expr.metadata(t, %{shards: shard_propagation.shards}) - else - t - end - - other -> - other + # Traverse the expression to remap all shapes according to the sharding given + expr = set_shard_metadata(expr, state.shards) + + arguments = + Map.new(arg_remapping, fn {_id, arg_expr} -> + {arg_expr.data.id, set_shard_metadata(arg_expr, 
state.shards)} end) - argument_sources = Map.take(state.args, Map.keys(arg_remapping)) + argument_sources = + state.args + |> Map.take(Map.keys(arg_remapping)) + |> Map.new(fn {remap_id, v} -> + {arg_remapping[remap_id].data.id, v} + end) - [{id, category, expr, argument_sources} | acc] + [ + %Stage{ + id: id, + category: category, + expr: expr, + arguments: arguments, + argument_sources: argument_sources + } + | acc + ] end ) - {expr_chain, Map.delete(state, :expression_chain), cache} + {expr_chain, cache, Map.delete(state, :expression_chain)} end defp composite_eval(expr, state, cache) do @@ -91,25 +106,19 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do end defp eval(%T{data: %Expr{id: id, op: op}} = ans, {cache, state}) do - case {cache, state.nodes_to_replace} do - {_, %{^id => res}} -> + case {cache, state.nodes_to_replace, state.ops_to_split} do + {_, %{^id => res}, _} -> # Replace the node with the corresponding parameter {res, {Map.put(cache, id, res), state}} - {%{^id => res}, _} -> + {%{^id => res}, _, _} -> {res, {cache, state}} - {_, _} -> - cond do - op in @gather_ops -> - rewrite_args(ans, :gather, {cache, state}) - - op in @reduction_ops -> - rewrite_args(ans, :reduce, {cache, state}) + {_, _, %{^op => category}} -> + rewrite_args(ans, category, {cache, state}) - true -> - eval_apply(op, ans, {cache, state}) - end + _ -> + eval_apply(op, ans, {cache, state}) end end @@ -203,8 +212,8 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do {new_expr, {cache, state}} end - defp eval_apply(:parameter, %T{data: %Expr{id: id}} = expr, {cache, state}) do - state = put_in(state.args[id], nil) + defp eval_apply(:parameter, %T{data: %Expr{id: id, args: [idx]}} = expr, {cache, state}) do + state = put_in(state.args[id], {nil, idx}) {expr, {Map.put(cache, id, expr), state}} end @@ -220,19 +229,26 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do {ans, {Map.put(cache, id, ans), state}} end - defp composite_rewrite_subtree(args, 
state, acc \\ %{used_args: %{}}) + defp composite_rewrite_subtree(container, state, acc \\ %{used_args: %{}}) - defp composite_rewrite_subtree(args, state, acc) when is_list(args) do - Enum.map_reduce(args, acc, fn + defp composite_rewrite_subtree(container, state, acc) when is_list(container) do + Enum.map_reduce(container, acc, fn %T{} = arg, acc -> composite_rewrite_subtree(arg, state, acc) + arg, acc when is_list(arg) -> + composite_rewrite_subtree(arg, state, acc) + arg, acc -> {arg, acc} end) end - defp composite_rewrite_subtree(%T{data: %Expr{id: id, op: :parameter}} = expr, state, acc) do + defp composite_rewrite_subtree(container, state, acc) do + Composite.traverse(container, acc, &rewrite_subtree(&1, state, &2)) + end + + defp rewrite_subtree(%T{data: %Expr{id: id, op: :parameter}} = expr, state, acc) do case state.nodes_to_replace do %{^id => res} -> {res, put_in(acc.used_args[id], {res, state.shards[id]})} @@ -242,22 +258,75 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do end end - defp composite_rewrite_subtree(arg, state, acc) do - Composite.traverse(arg, acc, &rewrite_subtree(&1, state, &2)) + defp rewrite_subtree( + %T{data: %Expr{op: :optional, id: id, args: [call, subexpr, fun]}} = expr, + state, + acc + ) do + case state.nodes_to_replace do + %{^id => res} -> + {res, put_in(acc.used_args[id], {res, state.shards[id]})} + + _ -> + {call, acc} = rewrite_subtree(call, state, acc) + # `subexpr` is hermetic, in the sense that it is a self-contained scope + # from which the arguments always come from `call`, so we can + # keep it as is. 
+ + {put_in(expr.data.args, [call, subexpr, fun]), acc} + end end defp rewrite_subtree(%T{data: %Expr{id: id, args: args}} = expr, state, acc) do case state.nodes_to_replace do %{^id => res} -> # nodes_to_replace always contains a param - {res, put_in(acc.used_args[id], res)} + {res, put_in(acc.used_args[id], {res, state.shards[id]})} _ -> {args, acc} = composite_rewrite_subtree(args, state, acc) - {put_in(expr.data.args, args), acc} end end defp rewrite_subtree(other, _, acc), do: {other, acc} + + defp set_shard_metadata(expr, shards) do + Composite.traverse(expr, fn + %T{data: %Expr{id: id}} = t -> + if shard_propagation = shards[id] do + shape = + shard_propagation.shards + |> Enum.sort() + |> Enum.map(fn {_axis, [%Shard{length: length} | _]} -> length end) + |> List.to_tuple() + + t = do_set_shard_metadata(%{t | shape: shape}, shards) + Expr.metadata(t, %{shards: shard_propagation.shards}) + else + do_set_shard_metadata(t, shards) + end + + other -> + other + end) + end + + defp do_set_shard_metadata(%T{data: %Expr{args: args}} = expr, shards) do + args = + Enum.map(args, fn + %T{} = arg -> + set_shard_metadata(arg, shards) + + arg when is_list(arg) -> + Enum.map(arg, &do_set_shard_metadata(&1, shards)) + + arg -> + arg + end) + + put_in(expr.data.args, args) + end + + defp do_set_shard_metadata(other, _), do: other end diff --git a/nx/lib/nx/defn/sharding_compiler/passes/graph_splitter/stage.ex b/nx/lib/nx/defn/sharding_compiler/passes/graph_splitter/stage.ex new file mode 100644 index 0000000000..c6aba0d60f --- /dev/null +++ b/nx/lib/nx/defn/sharding_compiler/passes/graph_splitter/stage.ex @@ -0,0 +1,3 @@ +defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter.Stage do + defstruct [:id, :category, :expr, :arguments, :argument_sources] +end diff --git a/nx/lib/nx/defn/sharding_compiler/passes/shard_propagation.ex b/nx/lib/nx/defn/sharding_compiler/passes/shard_propagation.ex index 08206e1a0a..83b703b083 100644 --- 
a/nx/lib/nx/defn/sharding_compiler/passes/shard_propagation.ex +++ b/nx/lib/nx/defn/sharding_compiler/passes/shard_propagation.ex @@ -5,7 +5,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do alias Nx.Defn.ShardingCompiler.Shard - defstruct [:id, :shards, :input_tensor_shardings, :parameter_ids_to_index, :expr] + defstruct [:id, :shards, :expr] def traverse(expr, tensor_shardings) do {container, {cache, state}} = @@ -19,9 +19,6 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do %{} ) - container = put_in(container.data.input_tensor_shardings, tensor_shardings) - container = put_in(container.data.parameter_ids_to_index, state.parameter_ids_to_index) - {container, cache, state} end @@ -53,7 +50,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do t |> Nx.axes() |> Map.new(fn axis -> - {axis, [0..(elem(t.shape, axis) - 1)]} + {axis, elem(t.shape, axis)} end) expr = shard_from_config(t, config) @@ -62,7 +59,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do end defp eval(%T{data: %Expr{op: :constant, args: [_constant]}} = ans, {cache, state}) do - expr = shard_from_config(ans, %{0 => [0..0]}) + expr = shard_from_config(ans, %{}) state = put_in(state.expr_shards[expr.data.id], expr.data) {expr, {cache, state}} end @@ -361,6 +358,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.ShardPropagation do defp resolve_sharding_broadcast(axis, left_shards, false, right_shards, false) do # We have a shard on both sides. We need to determine the intersection of the two. 
# This is fine only if all shards are equal + {reverse_out_shards, all_shards_match} = Enum.zip_reduce(left_shards, right_shards, {[], true}, fn left, right, diff --git a/nx/lib/nx/defn/sharding_compiler/shard.ex b/nx/lib/nx/defn/sharding_compiler/shard.ex index e2feac672f..f5130f12a0 100644 --- a/nx/lib/nx/defn/sharding_compiler/shard.ex +++ b/nx/lib/nx/defn/sharding_compiler/shard.ex @@ -1,6 +1,6 @@ defmodule Nx.Defn.ShardingCompiler.Shard do import Inspect.Algebra - defstruct [:id, :axis, :input_id, :start, :length, :parents] + defstruct [:id, :axis, :input_id, :start, :length, :parents, :debug_id] def inspect(%__MODULE__{start: start, length: length}, inspect_opts) when is_nil(start) or is_nil(length) do @@ -8,12 +8,26 @@ defmodule Nx.Defn.ShardingCompiler.Shard do end def inspect( - %__MODULE__{id: id, axis: axis, start: start, length: length, input_id: input_id}, + %__MODULE__{ + debug_id: debug_id, + id: id, + axis: axis, + start: start, + length: length, + input_id: input_id + }, inspect_opts ) do single_line = inspect_opts.custom_options[:single_line] print_axis = inspect_opts.custom_options[:print_axis] + id_doc = + if Application.get_env(:nx, :debug_shards) do + "(#{inspect(debug_id)} | #{inspect(id)})" + else + "(#{inspect(id)})" + end + range_doc = "#{start}..#{start + length - 1}" input_id_doc = if(input_id, do: "(#{inspect(input_id)})", else: "") @@ -22,7 +36,8 @@ defmodule Nx.Defn.ShardingCompiler.Shard do color("Shard<", :map, inspect_opts), if(print_axis && axis, do: "#{axis}: ", else: ""), range_doc, - " (#{inspect(id)})", + " ", + id_doc, input_id_doc, color(">", :map, inspect_opts) ]) @@ -35,7 +50,7 @@ defmodule Nx.Defn.ShardingCompiler.Shard do if(print_axis && axis, do: "#{axis}: ", else: ""), range_doc, line(), - "(#{inspect(id)})", + id_doc, line(), input_id_doc ]), @@ -56,10 +71,11 @@ defmodule Nx.Defn.ShardingCompiler.Shard do """ def from_config(tensor, config, opts \\ []) do input_id = opts[:input_id] + debug_id = opts[:debug_id] shards 
= Map.new(config, fn - {axis_or_name, slices} -> + {axis_or_name, length} -> axis = if is_atom(axis_or_name) do Nx.axis_index(tensor, axis_or_name) @@ -67,15 +83,27 @@ defmodule Nx.Defn.ShardingCompiler.Shard do axis_or_name end + axis_size = Nx.axis_size(tensor, axis) + + {slices, checksum} = + Enum.map_reduce(0..(axis_size - 1)//length, 0, fn start, checksum -> + {{start, length}, checksum + length} + end) + + if checksum != axis_size do + raise "Shard length #{length} does not evenly divide axis #{inspect(axis_or_name)} of size #{axis_size}" + end + shards = - Enum.map(slices, fn start..finish//1 -> + Enum.map(slices, fn {start, length} -> id = make_ref() %__MODULE__{ id: id, + debug_id: debug_id, axis: axis, start: start, - length: finish - start + 1, + length: length, input_id: input_id, parents: [] } diff --git a/nx/lib/nx/defn/sharding_compiler/shard_execution.ex b/nx/lib/nx/defn/sharding_compiler/shard_execution.ex new file mode 100644 index 0000000000..22db845712 --- /dev/null +++ b/nx/lib/nx/defn/sharding_compiler/shard_execution.ex @@ -0,0 +1,173 @@ +defmodule Nx.Defn.ShardingCompiler.ShardExecution do + # processes a single shard of an output entry, given the corresponding input data sections (1 per input) + defstruct [ + :compiled_fun, + :stage, + :input_data_sections, + :output_entry_index, + :output_data_section_id, + :output_starts, + :output_lengths, + :fetched_inputs, + :output + ] + + use GenServer + + alias Nx.Defn.ShardingCompiler.Passes.GraphSplitter.Stage + alias Nx.Defn.ShardingCompiler.ShardRegistry + + alias Nx.Tensor, as: T + alias Nx.Defn.Expr + + def init([ + %Stage{} = stage, + input_data_sections, + output_entry_index, + output_data_section_id, + output_starts, + output_lengths + ]) do + Process.send_after(self(), :fetch_inputs, 0) + + fetched_inputs = Map.new(input_data_sections, fn {_idx, {arg_id, _}} -> {arg_id, nil} end) + + arg_templates = + Enum.map(input_data_sections, fn {idx, {arg_id, shard_ids}} -> + arg = 
stage.arguments[arg_id] + + {shape, type} = + case arg do + %T{data: %Expr{op: :parameter, args: [_idx]}, shape: shape, type: type} -> + {shape, type} + + %T{data: %Expr{op: :metadata, args: [_arg, %{shards: shards}]}, type: type} -> + shape = + Enum.with_index(shard_ids, fn shard_id, axis -> + %{length: length} = + Enum.find(shards[axis], fn shard -> shard.id == shard_id end) + + length + end) + |> List.to_tuple() + + {shape, type} + end + + arg = %T{ + data: nil, + shape: shape, + type: type, + names: List.duplicate(nil, tuple_size(shape)) + } + + Expr.parameter(arg, :root, idx) + end) + + compiled_fun = + Nx.Defn.Evaluator.__compile__( + make_ref(), + arg_templates, + fn _ -> + stage.expr + end, + [] + ) + + fun = fn [args] -> + [res] = + compiled_fun.([ + Enum.map(Tuple.to_list(args), fn arg -> + fn -> arg end + end) + ]) + + res + end + + {:ok, + %__MODULE__{ + stage: stage, + input_data_sections: input_data_sections, + output_entry_index: output_entry_index, + output_data_section_id: output_data_section_id, + output_starts: output_starts, + output_lengths: output_lengths, + fetched_inputs: fetched_inputs, + # TO-DO: pass compiled_fun as argument + compiled_fun: fun + }} + end + + def start_link(args) do + GenServer.start_link(__MODULE__, args, name: via_tuple(args)) + end + + defp via_tuple([ + stage, + _input_data_sections, + output_entry_index, + output_data_section_id, + _starts, + _lengths + ]) do + {:via, Registry, {ShardRegistry, {stage.id, output_entry_index, output_data_section_id}}} + end + + def handle_info(:fetch_inputs, state) do + state = + for {arg_idx, {arg_id, data_section_id}} <- state.input_data_sections, + is_nil(state.fetched_inputs[arg_id]), + reduce: state do + state -> + {stage_id, stage_idx} = state.stage.argument_sources[arg_id] + + case get(stage_id, stage_idx, data_section_id) do + {:ok, data} -> + put_in(state.fetched_inputs[arg_id], {arg_idx, data}) + + {:error, :pending} -> + state + end + end + + if 
Enum.any?(state.fetched_inputs, fn {_arg_id, data} -> is_nil(data) end) do + Process.send_after(self(), :fetch_inputs, 10) + {:noreply, state} + else + state = compute(state) + {:noreply, state} + end + end + + def get(stage_id, stage_idx, data_section_id) do + key = {stage_id, stage_idx, data_section_id} + + case ShardRegistry.lookup(key) do + {:ok, pid} -> GenServer.call(pid, :get) + {:error, :pending} -> {:error, :pending} + end + end + + def handle_call(:get, _from, state) do + result = + case state.output do + nil -> {:error, :pending} + data -> {:ok, {data, state.output_starts, state.output_lengths}} + end + + {:reply, result, state} + end + + defp compute(state) do + args = + state.fetched_inputs + |> Enum.map(fn {_id, {idx, data}} -> {idx, data} end) + |> Enum.sort() + |> Enum.map(fn {_idx, data} -> data end) + |> List.to_tuple() + + output = state.compiled_fun.([args]) + %{state | output: output} + end +end diff --git a/nx/lib/nx/defn/sharding_compiler/shard_execution/argument_provider.ex b/nx/lib/nx/defn/sharding_compiler/shard_execution/argument_provider.ex new file mode 100644 index 0000000000..425c9e8b75 --- /dev/null +++ b/nx/lib/nx/defn/sharding_compiler/shard_execution/argument_provider.ex @@ -0,0 +1,21 @@ +defmodule Nx.Defn.ShardingCompiler.ShardExecution.ArgumentProvider do + use GenServer + + alias Nx.Defn.ShardingCompiler.ShardRegistry + + def init(data) do + {:ok, data} + end + + def start_link([data, idx, section_id]) do + GenServer.start_link(__MODULE__, data, name: via_tuple(idx, section_id)) + end + + defp via_tuple(idx, section_id) do + {:via, Registry, {ShardRegistry, {nil, idx, section_id}}} + end + + def handle_call(:get, _from, data) do + {:reply, {:ok, data}, data} + end +end diff --git a/nx/lib/nx/defn/sharding_compiler/shard_execution/output_collector.ex b/nx/lib/nx/defn/sharding_compiler/shard_execution/output_collector.ex new file mode 100644 index 0000000000..b5776afaee --- /dev/null +++ 
b/nx/lib/nx/defn/sharding_compiler/shard_execution/output_collector.ex
@@ -0,0 +1,175 @@
+defmodule Nx.Defn.ShardingCompiler.ShardExecution.OutputCollector do
+  # Gathers every output shard of the stage `previous_stage_id`, writes each
+  # one into a zero-filled tensor at its start offsets and, once all shards
+  # have arrived, sends `{OutputCollector, :done, pid, output}` to the listener.
+  use GenServer
+
+  alias Nx.Defn.ShardingCompiler.ShardRegistry
+  alias Nx.Defn.ShardingCompiler.ShardExecution
+  alias Nx.Defn.ShardingCompiler.Passes.ShardPropagation
+
+  alias Nx.Tensor, as: T
+
+  # Delay between polls while shards are pending. Mirrors the 10 ms retry in
+  # ShardExecution instead of re-polling the registry in a zero-delay loop.
+  @poll_interval_ms 10
+
+  @impl true
+  def init([expr, previous_stage_id, listener_pid]) do
+    # A lone tensor is wrapped in a 1-tuple so both output kinds share one
+    # code path; `unwrap_result` remembers to undo the wrapping at the end.
+    {expr_tuple, unwrap_result} =
+      if is_tuple(expr) do
+        {expr, false}
+      else
+        {{expr}, true}
+      end
+
+    # List of {output_index, [{data_section_id, starts, data | nil}]}.
+    data_sections_by_index =
+      expr_tuple
+      |> Tuple.to_list()
+      |> Enum.with_index(fn expr, idx ->
+        sections =
+          for {starts, data_section_id} <- starts_and_data_section_ids(expr) do
+            {data_section_id, starts, nil}
+          end
+
+        {idx, sections}
+      end)
+
+    Process.send_after(self(), :collect_data, 0)
+
+    {:ok,
+     %{
+       listener_pid: listener_pid,
+       expr_tuple: expr_tuple,
+       previous_stage_id: previous_stage_id,
+       unwrap_result: unwrap_result,
+       data_sections_by_index: data_sections_by_index,
+       output: nil
+     }}
+  end
+
+  def start_link(sharded_expr, previous_stage_id, listener_pid) do
+    GenServer.start_link(
+      __MODULE__,
+      [sharded_expr, previous_stage_id, listener_pid],
+      # The option must be spelled `:name` for the via tuple to be registered.
+      name: via_tuple(sharded_expr)
+    )
+  end
+
+  defp via_tuple(expr) do
+    expr_id =
+      if is_tuple(expr) do
+        expr
+        |> Tuple.to_list()
+        |> Enum.map(& &1.data.id)
+      else
+        expr.data.id
+      end
+
+    {:via, Registry, {ShardRegistry, {:output, expr_id}}}
+  end
+
+  @impl true
+  def handle_call(:get, _from, state) do
+    case state.output do
+      nil -> {:reply, {:error, :pending}, state}
+      data -> {:reply, {:ok, data}, state}
+    end
+  end
+
+  @impl true
+  def handle_info(:collect_data, state) do
+    # Stays a list (not a map) so outputs are assembled in index order.
+    data_sections_by_index =
+      Enum.map(state.data_sections_by_index, fn {idx, data_sections} ->
+        {idx, Enum.map(data_sections, &fetch_section(&1, state.previous_stage_id, idx))}
+      end)
+
+    finished =
+      Enum.all?(data_sections_by_index, fn {_idx, data_sections} ->
+        Enum.all?(data_sections, fn {_, _, data} -> not is_nil(data) end)
+      end)
+
+    output =
+      if finished do
+        out_list = produce_output(state.expr_tuple, data_sections_by_index)
+
+        if state.unwrap_result do
+          [out] = out_list
+          out
+        else
+          List.to_tuple(out_list)
+        end
+      end
+
+    if output do
+      Process.send_after(self(), :notify_listener, 0)
+    else
+      Process.send_after(self(), :collect_data, @poll_interval_ms)
+    end
+
+    {:noreply, %{state | output: output, data_sections_by_index: data_sections_by_index}}
+  end
+
+  def handle_info(:notify_listener, state) do
+    send(state.listener_pid, {__MODULE__, :done, self(), state.output})
+    {:noreply, state}
+  end
+
+  # Polls the producing ShardExecution for a section that has no data yet.
+  defp fetch_section({data_section_id, starts, nil}, previous_stage_id, idx) do
+    case ShardExecution.get(previous_stage_id, idx, data_section_id) do
+      {:ok, {data, _starts, _lengths}} ->
+        {data_section_id, starts, data}
+
+      {:error, :pending} ->
+        {data_section_id, starts, nil}
+    end
+  end
+
+  # A section fetched on an earlier poll is kept as is; dropping it here would
+  # leave zeros in the assembled output.
+  defp fetch_section(fetched_section, _previous_stage_id, _idx), do: fetched_section
+
+  # Returns one {starts, data_section_id} pair per combination of per-axis shards.
+  defp starts_and_data_section_ids(%T{data: %ShardPropagation{shards: shards}}) do
+    shards
+    |> Enum.sort_by(fn {axis, _} -> axis end)
+    |> Enum.map(fn {axis, shard} -> {shard, axis} end)
+    |> cartesian_product()
+    |> Enum.map(fn sections ->
+      starts =
+        Enum.map(sections, fn {shard, _axis} -> shard.start end)
+
+      data_section_id = Enum.map(sections, fn {shard, _axis} -> shard.id end)
+
+      {starts, data_section_id}
+    end)
+  end
+
+  defp cartesian_product([{data, meta} | rest]) do
+    for x <- data, y <- cartesian_product(rest), do: [{x, meta} | y]
+  end
+
+  defp cartesian_product([]), do: [[]]
+
+  # Builds one tensor per output entry by writing every section into a
+  # zero-filled tensor of the entry's full shape.
+  defp produce_output(expr_tuple, data_sections_by_index) do
+    Enum.map(data_sections_by_index, fn {idx, data_sections} ->
+      hole_template = elem(expr_tuple, idx)
+      hole = Nx.broadcast(Nx.tensor(0, type: hole_template.type), hole_template.shape)
+
+      data =
+        Enum.reduce(data_sections, hole, fn {_, starts, data}, acc ->
+          Nx.put_slice(acc, starts, data)
+        end)
+
+      data
+    end)
+  end
+end
diff --git a/nx/lib/nx/defn/sharding_compiler/shard_execution/supervisor.ex b/nx/lib/nx/defn/sharding_compiler/shard_execution/supervisor.ex
new file mode 100644
index 0000000000..aed79ee1d2
---
/dev/null
+++ b/nx/lib/nx/defn/sharding_compiler/shard_execution/supervisor.ex
@@ -0,0 +1,126 @@
+defmodule Nx.Defn.ShardingCompiler.ShardExecution.Supervisor do
+  # Supervises the ShardExecution workers of a single stage: one worker is
+  # started for every data section of every output entry of the stage.
+  use Supervisor
+
+  alias Nx.Defn.ShardingCompiler.Passes.GraphSplitter.Stage
+  alias Nx.Defn.ShardingCompiler.Shard
+
+  alias Nx.Defn.Expr
+  alias Nx.Tensor, as: T
+
+  def start_link(%Stage{} = stage) do
+    Supervisor.start_link(__MODULE__, stage)
+  end
+
+  @impl true
+  def init(stage) do
+    children =
+      for {output_entry_index, output_data_sections} <- output_data_sections(stage),
+          {output_data_section_id, input_data_sections, output_starts, output_lengths} <-
+            input_data_sections(stage.arguments, output_data_sections) do
+        %{
+          # Child ids must be unique per supervisor. Different output sections can
+          # read the same input sections, so the id is built from the output
+          # coordinates, i.e. the same triple the worker registers itself under.
+          id: {stage.id, output_entry_index, output_data_section_id},
+          start:
+            {Nx.Defn.ShardingCompiler.ShardExecution, :start_link,
+             [
+               [
+                 stage,
+                 input_data_sections,
+                 output_entry_index,
+                 output_data_section_id,
+                 output_starts,
+                 output_lengths
+               ]
+             ]},
+          restart: :permanent,
+          type: :worker
+        }
+      end
+
+    Supervisor.init(children, strategy: :one_for_one)
+  end
+
+  # Returns [{output_entry_index, sections}] for a tuple or single-tensor stage.
+  defp output_data_sections(%Stage{expr: expr}) do
+    if is_tuple(expr) do
+      expr
+      |> Tuple.to_list()
+      |> Enum.with_index(fn expr, idx -> {idx, output_data_sections_for_expr(expr)} end)
+    else
+      [{0, output_data_sections_for_expr(expr)}]
+    end
+  end
+
+  defp output_data_sections_for_expr(%T{data: %Expr{op: :metadata, args: [_, %{shards: shards}]}}) do
+    shards
+    |> Enum.sort_by(fn {axis, _} -> axis end)
+    |> Enum.map(fn {axis, shard} -> {shard, axis} end)
+    |> cartesian_product()
+    |> Enum.map(fn sections ->
+      {starts, lengths} =
+        sections
+        |> Enum.map(fn {shard, _axis} -> {shard.start, shard.length} end)
+        |> Enum.unzip()
+
+      data_section_id = Enum.map(sections, fn {shard, _axis} -> shard.id end)
+
+      roots =
+        Enum.map(sections, fn {shard, axis} -> {axis, get_root_parents(shard)} end)
+
+      {data_section_id, {roots, starts, lengths}}
+    end)
+  end
+
+  # For every output section, works out which data section of each stage
+  # argument it reads, by matching the root shards of the output against the
+  # root shards of the argument.
+  defp input_data_sections(arguments, output_data_sections) do
+    for {data_section_id, {output_roots_by_dim, starts, lengths}} <- output_data_sections do
+      arg_sections =
+        for {arg_id, arg} <- arguments do
+          %T{data: %Expr{op: :metadata, args: [param, %{shards: shards}]}} = arg
+          %T{data: %Expr{op: :parameter, args: [arg_idx]}} = param
+
+          shards_by_root =
+            for {_axis, shards_for_axis} <- shards,
+                shard <- shards_for_axis,
+                root <- get_root_parents(shard),
+                into: %{} do
+              {root.id, shard}
+            end
+
+          data_section_id_for_input =
+            output_roots_by_dim
+            |> Enum.map(fn {_axis, roots} ->
+              %Shard{} = shard = Enum.find(roots, &shards_by_root[&1.id])
+              {shard.axis, shard.id}
+            end)
+            |> Enum.sort()
+            |> Enum.map(fn {_axis, id} -> id end)
+
+          {arg_idx, {arg_id, data_section_id_for_input}}
+        end
+
+      {data_section_id, arg_sections, starts, lengths}
+    end
+  end
+
+  defp cartesian_product([{data, meta} | rest]) do
+    for x <- data, y <- cartesian_product(rest), do: [{x, meta} | y]
+  end
+
+  defp cartesian_product([]), do: [[]]
+
+  defp get_root_parents(shard, acc \\ [])
+
+  defp get_root_parents(%Shard{parents: []} = shard, acc), do: List.flatten([shard | acc])
+
+  defp get_root_parents(%Shard{parents: parents}, acc) do
+    Enum.reduce(parents, acc, &get_root_parents/2)
+    |> List.flatten()
+  end
+end
diff --git a/nx/lib/nx/defn/sharding_compiler/shard_registry.ex b/nx/lib/nx/defn/sharding_compiler/shard_registry.ex
new file mode 100644
index 0000000000..98626d1475
--- /dev/null
+++ b/nx/lib/nx/defn/sharding_compiler/shard_registry.ex
@@ -0,0 +1,30 @@
+defmodule Nx.Defn.ShardingCompiler.ShardRegistry do
+  # Unique-key Registry under which the sharding processes register themselves
+  # through `:via` tuples.
+  def child_spec(opts) do
+    %{
+      id: __MODULE__,
+      start: {__MODULE__, :start_link, [opts]}
+    }
+  end
+
+  def start_link(_) do
+    Registry.start_link(name: __MODULE__, keys: :unique)
+  end
+
+  # Looks `key` up on this node and on every connected node, returning the
+  # first pid found, or `{:error, :pending}` when no node has it registered.
+  def lookup(key) do
+    results = :erpc.multicall([Node.self() | Node.list()], Registry, :lookup, [__MODULE__, key])
+
+    results
+    |> Enum.find_value(fn
+      {:ok, [{pid, _}]} -> pid
+      _ -> nil
+    end)
+    |> case do
+      nil -> {:error, :pending}
+      pid -> {:ok, pid}
+    end
+  end
+end
diff --git a/nx/test/nx/defn/sharding_compiler/passes/graph_splitter_test.exs b/nx/test/nx/defn/sharding_compiler/passes/graph_splitter_test.exs index 24206263ec..76a50e40dd 100644 --- a/nx/test/nx/defn/sharding_compiler/passes/graph_splitter_test.exs +++ b/nx/test/nx/defn/sharding_compiler/passes/graph_splitter_test.exs @@ -2,6 +2,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do use ExUnit.Case, async: true alias Nx.Defn.ShardingCompiler.Passes.GraphSplitter + alias Nx.Defn.ShardingCompiler.Passes.GraphSplitter.Stage alias Nx.Defn.ShardingCompiler.Passes.ShardPropagation alias Nx.Defn.ShardingCompiler.Shard @@ -19,14 +20,26 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do Nx.divide(w, 4) end).(Nx.tensor([1, 2]), Nx.tensor([3, 4])) - {chain, state, cache} = GraphSplitter.traverse(expr) + {chain, cache, state} = GraphSplitter.traverse(expr) assert [ - {stage_0_id, :gather, stage_0_expr, stage_0_argument_sources}, - {_stage_1_id, :none, stage_1_expr, stage_1_argument_sources} + %Stage{ + id: stage_0_id, + category: :gather, + expr: stage_0_expr, + argument_sources: stage_0_argument_sources + }, + %Stage{ + id: _stage_1_id, + category: :none, + expr: stage_1_expr, + argument_sources: stage_1_argument_sources + } ] = chain - assert Enum.all?(stage_0_argument_sources, fn {_id, source} -> source == nil end) + assert Enum.all?(stage_0_argument_sources, fn {_id, {source_id, _idx}} -> + source_id == nil + end) assert [{2, arg_2_original_node_id, arg_2_id}, {3, arg_3_original_node_id, arg_3_id}] = state.nodes_to_replace @@ -123,15 +136,32 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do |> Nx.subtract(arg2) end).(Nx.tensor([[1, 2]]), Nx.tensor([[3], [4]]), Nx.tensor([5, 6])) - {chain, state, cache} = GraphSplitter.traverse(expr) + {chain, cache, state} = GraphSplitter.traverse(expr) assert [ - {stage_0_id, :gather, stage_0_expr, stage_0_argument_sources}, - {stage_1_id, :reduce, stage_1_expr, stage_1_argument_sources}, - 
{_stage_2_id, :none, stage_2_expr, stage_2_argument_sources} + %Stage{ + id: stage_0_id, + category: :gather, + expr: stage_0_expr, + argument_sources: stage_0_argument_sources + }, + %Stage{ + id: stage_1_id, + category: :reduce, + expr: stage_1_expr, + argument_sources: stage_1_argument_sources + }, + %Stage{ + id: _stage_2_id, + category: :none, + expr: stage_2_expr, + argument_sources: stage_2_argument_sources + } ] = chain - assert Enum.all?(stage_0_argument_sources, fn {_id, source} -> source == nil end) + assert Enum.all?(stage_0_argument_sources, fn {_id, {source_id, _idx}} -> + source_id == nil + end) assert map_size(state.args) == 6 @@ -235,7 +265,7 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do assert %T{data: %Expr{op: :sum, args: [^a, [axes: nil, keep_axes: false]]}} = b assert %T{data: %Expr{id: ^arg_5_id, op: :parameter, args: [1]}} = a - assert %{arg_2_id => nil, arg_5_id => {stage_1_id, 0}} == stage_2_argument_sources + assert %{arg_2_id => {nil, 2}, arg_5_id => {stage_1_id, 0}} == stage_2_argument_sources end test "does not split on dot if arguments are not sharded on the reduction axis" do @@ -263,12 +293,13 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do {sharded_expr, _cache, %{expr_shards: expr_shards}} = ShardPropagation.traverse(expr, %{ - 0 => Shard.from_config(arg0, %{0 => [0..0, 1..1], 1 => [0..2]}), - 1 => Shard.from_config(arg1, %{0 => [0..2], 1 => [0..0, 1..1]}) + 0 => Shard.from_config(arg0, %{0 => 1, 1 => 3}), + 1 => Shard.from_config(arg1, %{0 => 3, 1 => 1}) }) # This ensures the data hasn't been split - assert {[{_id, :none, out_expr, sources}], _state, _cache} = + assert {[%Stage{category: :none, expr: out_expr, argument_sources: sources}], _cache, + _state} = GraphSplitter.traverse(expr, expr_shards) # Following assertions ensure that: @@ -276,24 +307,15 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do # - The expression is unchanged aside from extra metadata nodes; # - And 
that the shards are set to the parameters too assert %T{ + shape: {1, 1}, data: %Expr{ op: :metadata, args: [ %T{ + shape: {1, 1}, data: %Expr{ op: :divide, - args: [ - %T{ - data: %Expr{ - op: :multiply, - args: [ - %T{data: %Expr{op: :constant, args: [3]}}, - %T{data: %Expr{op: :dot, args: [t0, _, _, t1, _, _]}} - ] - } - }, - %T{data: %Expr{op: :constant, args: [4]}} - ] + args: [div_arg_meta, %T{data: %Expr{op: :constant, args: [4]}}] } }, %{shards: output_shards} @@ -301,52 +323,81 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do } } = out_expr + assert %T{shape: {1, 1}, data: %Expr{op: :metadata, args: [div_arg, _]}} = div_arg_meta + + assert %T{ + shape: {1, 1}, + data: %Expr{ + op: :multiply, + args: [ + %T{data: %Expr{op: :constant, args: [3]}}, + mul_arg + ] + } + } = div_arg + + assert %T{ + shape: {1, 1}, + data: %Expr{ + op: :metadata, + args: [ + %T{shape: {1, 1}, data: %Expr{op: :dot, args: [t0_meta, _, _, t1_meta, _, _]}}, + _shards + ] + } + } = mul_arg + assert sharded_expr.data.shards == output_shards - %T{ - data: %Expr{ - op: :add, - args: [ - %T{data: %Expr{op: :constant, args: [1]}}, - %T{ - data: %Expr{ - op: :metadata, - args: [%T{data: %Expr{op: :parameter, args: [0]}}, %{shards: arg0_shards}] - } - } - ] - } - } = t0 + assert %T{shape: {1, 3}, data: %Expr{op: :metadata, args: [t0, _]}} = t0_meta + assert %T{shape: {3, 1}, data: %Expr{op: :metadata, args: [t1, _]}} = t1_meta + + assert %T{ + data: %Expr{ + op: :add, + args: [ + %T{data: %Expr{op: :constant, args: [1]}}, + %T{ + data: %Expr{ + op: :metadata, + args: [%T{data: %Expr{op: :parameter, args: [0]}}, %{shards: arg0_shards}] + } + } + ] + } + } = t0 assert %{ 0 => [%Shard{start: 0, length: 1}, %Shard{start: 1, length: 1}], 1 => [%Shard{start: 0, length: 3}] } = arg0_shards - %T{ - data: %Expr{ - op: :subtract, - args: [ - %T{ - data: %Expr{ - op: :metadata, - args: [%T{data: %Expr{op: :parameter, args: [1]}}, %{shards: arg1_shards}] - } - }, - %T{data: %Expr{op: 
:constant, args: [2]}} - ] - } - } = t1 + assert %T{ + data: %Expr{ + op: :subtract, + args: [ + %T{ + data: %Expr{ + op: :metadata, + args: [%T{data: %Expr{op: :parameter, args: [1]}}, %{shards: arg1_shards}] + } + }, + %T{data: %Expr{op: :constant, args: [2]}} + ] + } + } = t1 assert %{ 0 => [%Shard{start: 0, length: 3}], 1 => [%Shard{start: 0, length: 1}, %Shard{start: 1, length: 1}] } = arg1_shards - assert Enum.all?(sources, fn {_id, source} -> source == nil end) + assert Enum.all?(sources, fn {_id, {source_id, _idx}} -> + source_id == nil + end) end - test "splits on dot if arguments are not sharded on the reduction axis" do + test "splits on dot if arguments are sharded on the reduction axis" do arg0 = Nx.tensor([ [1, 2, 3], @@ -372,18 +423,18 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do {_sharded_expr, _cache, %{expr_shards: expr_shards}} = ShardPropagation.traverse(expr, %{ 0 => Shard.from_config(arg0, %{}), - 1 => Shard.from_config(arg1, %{0 => [0..2], 1 => [0..0, 1..1]}) + 1 => Shard.from_config(arg1, %{0 => 3, 1 => 1}) }) - assert {[_, _], _state, _cache} = GraphSplitter.traverse(expr, expr_shards) + assert {[_, _], _cache, _state} = GraphSplitter.traverse(expr, expr_shards) {sharded_expr, _cache, %{expr_shards: expr_shards}} = ShardPropagation.traverse(expr, %{ - 0 => Shard.from_config(arg0, %{0 => [0..0, 1..1], 1 => [0..2]}), + 0 => Shard.from_config(arg0, %{0 => 1, 1 => 3}), 1 => Shard.from_config(arg1, %{}) }) - assert {[{_, _, stage_0_expr, _}, {_, _, stage_1_expr, _}], _state, _cache} = + assert {[%Stage{expr: stage_0_expr}, %Stage{expr: stage_1_expr}], _cache, _state} = GraphSplitter.traverse(expr, expr_shards) assert {%T{data: %Expr{op: :metadata, args: [_left, %{shards: left_shards}]}}, @@ -409,5 +460,120 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitterTest do assert out_shards == sharded_expr.data.shards end + + test "supports optional callbacks" do + arg0 = + Nx.u8([ + [1, 0, 1], + [1, 1, 1] + ]) + + expr = + 
Nx.Defn.debug_expr(fn a, b -> + x = Nx.add(b, 1) + y = Nx.sum(x, axes: [1]) + z = Nx.logical_not(y) + Nx.subtract(z, a) + end).(1, arg0) + + assert {[%Stage{} = stage_0, %Stage{} = stage_1], _cache, _state} = + GraphSplitter.traverse(expr) + + [{arg1_id, %T{shape: {2, 3}, type: {:u, 8}, data: %Expr{args: [0]}}}] = + Enum.to_list(stage_0.arguments) + + assert stage_0.argument_sources == %{arg1_id => {nil, 1}} + + stage_1_args = + Enum.sort_by(stage_1.arguments, fn {_id, %T{data: %Expr{op: :parameter, args: [idx]}}} -> + idx + end) + + assert [ + {arg_0_id, %T{shape: {}, type: {:s, 32}}}, + {arg_1_id, %T{shape: {2, 3}, type: {:u, 8}}} + ] = + stage_1_args + + assert stage_1.argument_sources == %{arg_0_id => {nil, 0}, arg_1_id => {stage_0.id, 0}} + + assert %T{data: %Expr{op: :subtract, args: [c, d]}} = stage_1.expr + assert %T{data: %Expr{op: :optional, args: [call, subexpr, _fun]}} = c + + assert %T{data: %Expr{id: ^arg_0_id, op: :parameter, args: [0]}} = d + + assert %T{data: %Expr{op: :logical_not, args: [b]}} = call + assert %T{data: %Expr{op: :sum, args: [a, [axes: [1], keep_axes: false]]}} = b + assert %T{data: %Expr{id: ^arg_1_id, op: :parameter, args: [1]}} = a + + assert %T{ + data: %Expr{ + op: :equal, + args: [ + %T{data: %Expr{id: subexpr_arg_0_id, op: :parameter, args: [0]}}, + %T{data: %Expr{op: :constant, args: [0]}} + ] + } + } = subexpr + + # ensure subexpr is hermetic + assert subexpr_arg_0_id != arg_0_id + assert subexpr_arg_0_id != arg_1_id + end + + test "supports in-line anonymous functions" do + arg0 = + Nx.u8([ + [1, 0, 1], + [1, 1, 1] + ]) + + expr = + Nx.Defn.debug_expr(fn a, b -> + x = Nx.add(b, 1) + y = Nx.sum(x, axes: [1]) + f = fn a -> Nx.equal(a, 0) end + z = f.(y) + Nx.subtract(z, a) + end).(1, arg0) + + assert {[%Stage{} = stage_0, %Stage{} = stage_1], _cache, _state} = + GraphSplitter.traverse(expr) + + [{arg1_id, %T{shape: {2, 3}, type: {:u, 8}, data: %Expr{args: [0]}}}] = + Enum.to_list(stage_0.arguments) + + assert 
stage_0.argument_sources == %{arg1_id => {nil, 1}} + + stage_1_args = + Enum.sort_by(stage_1.arguments, fn {_id, %T{data: %Expr{op: :parameter, args: [idx]}}} -> + idx + end) + + assert [ + {arg_0_id, %T{shape: {}, type: {:s, 32}}}, + {arg_1_id, %T{shape: {2, 3}, type: {:u, 8}}} + ] = + stage_1_args + + assert stage_1.argument_sources == %{arg_0_id => {nil, 0}, arg_1_id => {stage_0.id, 0}} + + assert %T{data: %Expr{op: :subtract, args: [c, d]}} = stage_1.expr + + assert %T{ + data: %Expr{ + op: :equal, + args: [ + left, + %T{data: %Expr{op: :constant, args: [0]}} + ] + } + } = c + + assert %T{data: %Expr{id: ^arg_0_id, op: :parameter, args: [0]}} = d + + assert %T{data: %Expr{op: :sum, args: [a, [axes: [1], keep_axes: false]]}} = left + assert %T{data: %Expr{id: ^arg_1_id, op: :parameter, args: [1]}} = a + end end end diff --git a/nx/test/nx/defn/sharding_compiler/shard_execution_test.exs b/nx/test/nx/defn/sharding_compiler/shard_execution_test.exs new file mode 100644 index 0000000000..f43b87fdbd --- /dev/null +++ b/nx/test/nx/defn/sharding_compiler/shard_execution_test.exs @@ -0,0 +1,159 @@ +defmodule Nx.Defn.ShardingCompiler.ShardExecutionTest do + use ExUnit.Case, async: true + + alias Nx.Defn.ShardingCompiler.Passes.GraphSplitter + alias Nx.Defn.ShardingCompiler.Passes.GraphSplitter.Stage + alias Nx.Defn.ShardingCompiler.Passes.ShardPropagation + alias Nx.Defn.ShardingCompiler.ShardExecution + alias Nx.Defn.ShardingCompiler.Shard + + alias Nx.Tensor, as: T + alias Nx.Defn.Expr + + test "Creates all the necessary children for each stage" do + arg0 = + Nx.tensor([ + [1, 2, 3], + [4, 5, 6] + ]) + + arg1 = + Nx.tensor([ + [1, 2], + [3, 4], + [5, 6] + ]) + + fun = fn arg0, arg1 -> + x = Nx.add(arg0, 1) + y = Nx.subtract(arg1, 2) + + Nx.multiply(x, Nx.transpose(y)) + end + + expected_output = fun.(arg0, arg1) + + expr = + Nx.Defn.debug_expr(fun).(arg0, arg1) + + {%T{data: %ShardPropagation{expr: sharded_expr}} = ans, _cache, %{expr_shards: expr_shards}} = + 
ShardPropagation.traverse(expr, %{ + 0 => Shard.from_config(arg0, %{0 => 1, 1 => 3}, debug_id: "arg 0"), + 1 => Shard.from_config(arg1, %{0 => 3}, debug_id: "arg 1") + }) + + assert {[%Stage{} = stage0], _cache, _state} = + GraphSplitter.traverse(%T{ans | data: sharded_expr}, expr_shards) + + args_by_idx = %{0 => arg0, 1 => arg1} + + arg_providers = + Enum.flat_map(stage0.arguments, fn {_id, expr} -> + start_shard_providers(expr, args_by_idx) + end) + + assert Enum.count(arg_providers, &match?({:ok, _}, &1)) == 4 + + assert {:ok, pid} = ShardExecution.Supervisor.start_link(stage0) + + children = Supervisor.which_children(pid) + + states = + Enum.map(children, fn {key, pid, :worker, [ShardExecution]} -> + {key, :sys.get_state(pid)} + end) + |> Enum.sort_by(fn {_, state} -> {state.output_entry_index, state.output_starts} end) + + assert [executor0, executor1] = states + + assert {_key0, executor0_state} = executor0 + assert {_key1, executor1_state} = executor1 + + idx_to_id = + Map.new(stage0.arguments, fn + {id, %T{data: %Expr{op: :parameter, args: [idx]}}} -> + {idx, {id, nil}} + + {id, %T{data: %Expr{op: :metadata, args: [expr, %{shards: shards}]}}} -> + %T{data: %Expr{op: :parameter, args: [idx]}} = expr + {idx, {id, shards}} + end) + + assert %ShardExecution{ + input_data_sections: [{0, input_section0}, {1, input_section1}], + output_starts: [0, 0], + output_lengths: [1, 3] + } = executor0_state + + {id0, %{0 => [shard0, shard1], 1 => [shard2]}} = idx_to_id[0] + {id1, %{0 => [shard3], 1 => [shard4, shard5]}} = idx_to_id[1] + + assert {id0, [shard0.id, shard2.id]} == input_section0 + assert {id1, [shard3.id, shard4.id]} == input_section1 + + assert %ShardExecution{ + input_data_sections: [{0, input_section0}, {1, input_section1}], + output_starts: [1, 0], + output_lengths: [1, 3] + } = executor1_state + + assert {id0, [shard1.id, shard2.id]} == input_section0 + assert {id1, [shard3.id, shard5.id]} == input_section1 + + assert executor0_state.output == + 
Nx.tensor([ + [-2, 3, 12] + ]) + + assert executor1_state.output == + Nx.tensor([ + [0, 12, 28] + ]) + + {:ok, output_collector_pid} = + ShardExecution.OutputCollector.start_link(ans, stage0.id, self()) + + assert_receive {ShardExecution.OutputCollector, :done, ^output_collector_pid, result} + + assert expected_output == result + end + + defp start_shard_providers(sharded_expr, arg_data) do + case sharded_expr do + %T{data: %Expr{op: :parameter, args: [idx]}} -> + [ + ShardExecution.ArgumentProvider.start_link([ + sharded_expr, + idx, + arg_data[idx] + ]) + ] + + %T{data: %Expr{op: :metadata, args: [%T{data: %Expr{args: [idx]}}, %{shards: shards}]}} -> + shards + |> Enum.sort_by(fn {axis, _} -> axis end) + |> Enum.map(fn {axis, shard} -> {shard, axis} end) + |> cartesian_product() + |> Enum.map(fn sections -> + {starts, lengths} = + sections + |> Enum.map(fn {shard, _axis} -> {shard.start, shard.length} end) + |> Enum.unzip() + + data_section_id = Enum.map(sections, fn {shard, _axis} -> shard.id end) + + ShardExecution.ArgumentProvider.start_link([ + Nx.slice(arg_data[idx], starts, lengths), + idx, + data_section_id + ]) + end) + end + end + + defp cartesian_product([{data, meta} | rest]) do + for x <- data, y <- cartesian_product(rest), do: [{x, meta} | y] + end + + defp cartesian_product([]), do: [[]] +end diff --git a/nx/test/nx/defn/sharding_compiler_test.exs b/nx/test/nx/defn/sharding_compiler_test.exs index 3cdecf8d91..31ea72119b 100644 --- a/nx/test/nx/defn/sharding_compiler_test.exs +++ b/nx/test/nx/defn/sharding_compiler_test.exs @@ -15,8 +15,8 @@ defmodule Nx.Defn.ShardingCompilerTest do inputs = [in0, in1] - arg0_sharding = %{0 => [0..1], 1 => [0..1]} - arg1_sharding = %{4 => [0..0, 1..1]} + arg0_sharding = %{0 => 2, 1 => 2} + arg1_sharding = %{4 => 1} sharding = [arg0_sharding, arg1_sharding] @@ -39,8 +39,8 @@ defmodule Nx.Defn.ShardingCompilerTest do inputs = [t, t] - arg0_sharding = %{0 => [0..0, 1..1, 2..2]} - arg1_sharding = %{1 => [0..0, 1..1, 
2..2]} + arg0_sharding = %{0 => 1} + arg1_sharding = %{1 => 1} sharding = [arg0_sharding, arg1_sharding] @@ -64,7 +64,7 @@ defmodule Nx.Defn.ShardingCompilerTest do inputs = [t] - arg0_sharding = %{0 => [0..0, 1..1, 2..2], 1 => [0..2]} + arg0_sharding = %{0 => 1, 1 => 3} sharding = [arg0_sharding] @@ -88,8 +88,8 @@ defmodule Nx.Defn.ShardingCompilerTest do inputs = [t0, t1] - arg0_sharding = %{0 => [0..0, 1..1, 2..2]} - arg1_sharding = %{1 => [0..0, 1..1, 2..2]} + arg0_sharding = %{0 => 1} + arg1_sharding = %{1 => 1} sharding = [arg0_sharding, arg1_sharding]