Skip to content

Commit 73f799f

Browse files
feat: Nx.Defn.Graph (#1544)
Co-authored-by: José Valim <jose.valim@dashbit.co>
1 parent 46c233c commit 73f799f

File tree

2 files changed

+788
-0
lines changed

2 files changed

+788
-0
lines changed

nx/lib/nx/defn/graph.ex

Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
defmodule Nx.Defn.Graph do
2+
@moduledoc """
3+
A module for splitting `Nx.Defn.Expr` into stages.
4+
5+
This module is used to split an `Nx.Defn.Expr` into stages, which are then
6+
executed in a chain.
7+
8+
`split/2` and `t:Stage.t()` describe how to split
9+
the graph and what's the expected result.
10+
11+
`run/2` executes the given graph against the provided arguments in a sequential manner.
12+
"""
13+
alias Nx.Defn.Composite
14+
15+
alias Nx.Tensor, as: T
16+
alias Nx.Defn.Expr
17+
18+
defmodule Stage do
19+
@typedoc """
20+
A stage in the graph splitter.
21+
22+
* `:arguments`: a list of maps that point to the source from which to fetch the corresponding
23+
value for the given argument.
24+
* `:expr`: the expression that represents the computation for the Stage.
25+
* `:id`: the unique id for the Stage.
26+
"""
27+
@type t :: %__MODULE__{
28+
id: reference(),
29+
expr: %{__struct__: Nx.Defn.Expr},
30+
arguments: [%{source: {reference() | nil, non_neg_integer()}}]
31+
}
32+
33+
defstruct [:id, :expr, :arguments]
34+
end
35+
36+
@doc """
37+
Splits the received Nx.Defn.Expr into stages given the rules.
38+
39+
`expr_split_fn` is a function that receives an `Nx.Tensor` containing an `Nx.Defn.Expr`
40+
and returns `true` when a split must happen, and `false` otherwise.
41+
42+
## Examples
43+
44+
iex> expr = Nx.Defn.debug_expr(fn x, y -> x |> Nx.negate() |> Nx.sin() |> Nx.cos() |> Nx.add(y) end).(1, 2)
45+
iex> [stage0, stage1] = Nx.Defn.Graph.split(expr, fn %Nx.Tensor{data: %Nx.Defn.Expr{op: op}} -> op == :cos end)
46+
iex> {out0} = stage0.expr
47+
iex> out0
48+
#Nx.Tensor<
49+
f32
50+
\n\
51+
Nx.Defn.Expr
52+
parameter a:0 s32
53+
b = negate a s32
54+
c = sin b f32
55+
>
56+
iex> stage1.expr
57+
#Nx.Tensor<
58+
f32
59+
\n\
60+
Nx.Defn.Expr
61+
parameter a:1 f32
62+
parameter c:0 s32
63+
b = cos a f32
64+
d = add b, c f32
65+
>
66+
"""
67+
def split(expr, expr_split_fn) when is_function(expr_split_fn, 1) do
68+
{chain, _, _} = __split__(expr, expr_split_fn)
69+
chain
70+
end
71+
72+
@doc """
73+
Executes the stage chain with the given arguments.
74+
"""
75+
def run(chain, args) do
76+
scope =
77+
Enum.with_index(args, fn arg, idx -> {{nil, idx}, arg} end)
78+
|> Map.new()
79+
80+
{result, _scope} =
81+
Enum.reduce(chain, {nil, scope}, fn stage, {_result, scope} ->
82+
%{id: id, expr: expr, arguments: arguments} = stage
83+
84+
args =
85+
Enum.map(arguments, fn %{source: source} ->
86+
Map.fetch!(scope, source)
87+
end)
88+
89+
case Nx.Defn.jit_apply(fn _ -> expr end, [List.to_tuple(args)]) do
90+
%T{} = tensor ->
91+
{tensor, Map.put(scope, {id, 0}, tensor)}
92+
93+
tuple ->
94+
{_idx, scope} =
95+
tuple
96+
|> Tuple.to_list()
97+
|> Enum.reduce({0, scope}, fn tensor, {idx, scope} ->
98+
{idx + 1, Map.put(scope, {id, idx}, tensor)}
99+
end)
100+
101+
{tuple, scope}
102+
end
103+
end)
104+
105+
result
106+
end
107+
108+
@doc false
109+
def __split__(expr, expr_split_fn) do
110+
# state.expression_chain is a reverse accumulation of the stages and
111+
# snapshots of the state at each one so that we can properly remap parameters for each stage.
112+
state = %{
113+
expression_chain: [],
114+
nodes_to_replace: %{},
115+
expr_split_fn: expr_split_fn,
116+
# args is a map of id -> {stage_id, output_container_position}
117+
args: %{}
118+
}
119+
120+
cache = %{}
121+
{expr, {cache, state}} = composite_eval(expr, state, cache)
122+
123+
expr_chain =
124+
Enum.reduce(
125+
[{make_ref(), expr, state.nodes_to_replace} | state.expression_chain],
126+
[],
127+
fn {id, expr, nodes_to_replace}, acc ->
128+
# TO-DO: we need to also do a pass to avoid recalculating results that have been previously calculated.
129+
# For example:
130+
# x = arg0 + arg1
131+
# y = arg0 - arg1
132+
# z = x + y
133+
# -----
134+
# w = dot(z, arg1)
135+
# y + w <- here, we currently have to recalculate y given that only z, arg0 and arg1 will be passed as arguments.
136+
# ideally, we should also pass y as a value to avoid recalculating it.
137+
# We might be able to calculate this in the first traversal somehow.
138+
139+
{expr, %{used_args: used_args}} =
140+
composite_rewrite_subtree(
141+
expr,
142+
%{state | nodes_to_replace: nodes_to_replace}
143+
)
144+
145+
arg_remapping =
146+
used_args
147+
|> Enum.sort_by(fn {_id, %T{data: %Expr{op: :parameter, args: [idx]}}} -> idx end)
148+
|> Enum.with_index(fn
149+
{id, expr}, idx ->
150+
{id, put_in(expr.data.args, [idx])}
151+
end)
152+
|> Map.new()
153+
154+
{expr, _} =
155+
composite_rewrite_subtree(expr, %{state | nodes_to_replace: arg_remapping})
156+
157+
arguments =
158+
arg_remapping
159+
|> Enum.map(fn {_id, arg_expr} ->
160+
id = arg_expr.data.id
161+
[idx] = arg_expr.data.args
162+
source = Map.fetch!(state.args, id)
163+
{idx, %{source: source}}
164+
end)
165+
|> Enum.sort_by(fn {idx, _} -> idx end)
166+
|> Enum.map(fn {_, arg} -> arg end)
167+
168+
[
169+
%Stage{
170+
id: id,
171+
expr: expr,
172+
arguments: arguments
173+
}
174+
| acc
175+
]
176+
end
177+
)
178+
179+
{expr_chain, cache, Map.delete(state, :expression_chain)}
180+
end
181+
182+
defp composite_eval(expr, state, cache) do
183+
Composite.traverse(expr, {cache, state}, &eval/2)
184+
end
185+
186+
defp eval(%T{data: %Expr{id: id, op: op}} = ans, {cache, state}) do
187+
case {cache, state.nodes_to_replace} do
188+
{_, %{^id => res}} ->
189+
# Replace the node with the corresponding parameter
190+
{res, {Map.put(cache, id, res), state}}
191+
192+
{%{^id => res}, _} ->
193+
{res, {cache, state}}
194+
195+
_ ->
196+
if state.expr_split_fn.(ans) do
197+
split_expr(ans, {cache, state})
198+
else
199+
eval_apply(op, ans, {cache, state})
200+
end
201+
end
202+
end
203+
204+
defp eval(other, {cache, state}) do
205+
{other, {cache, state}}
206+
end
207+
208+
defp split_expr(expr, {cache, state}) do
209+
{args, {cache, state}} = Nx.Defn.Tree.apply_args(expr, {cache, state}, &eval/2)
210+
# We need to save this so that each previous stage
211+
# isn't affected by following ones
212+
nodes_to_replace = state.nodes_to_replace
213+
214+
stage_id = make_ref()
215+
216+
{args, {tensor_args, _out_position, state}} =
217+
Enum.map_reduce(args, {[], 0, state}, fn
218+
%T{} = expr, {tensor_args, out_position, state} ->
219+
arg = Expr.parameter(expr, map_size(state.args))
220+
221+
state = %{
222+
state
223+
| args: Map.put(state.args, arg.data.id, {stage_id, out_position}),
224+
nodes_to_replace: Map.put(state.nodes_to_replace, expr.data.id, arg)
225+
}
226+
227+
{arg, {[expr | tensor_args], out_position + 1, state}}
228+
229+
non_tensor_arg, acc ->
230+
{non_tensor_arg, acc}
231+
end)
232+
233+
new_expr = put_in(expr.data.args, args)
234+
235+
state =
236+
update_in(
237+
state.expression_chain,
238+
&[
239+
{stage_id, List.to_tuple(Enum.reverse(tensor_args)), nodes_to_replace}
240+
| &1
241+
]
242+
)
243+
244+
cache = Map.put(cache, new_expr.data.id, new_expr)
245+
246+
{new_expr, {cache, state}}
247+
end
248+
249+
defp eval_apply(:parameter, %T{data: %Expr{id: id, args: [idx]}} = expr, {cache, state}) do
250+
state = put_in(state.args[id], {nil, idx})
251+
{expr, {Map.put(cache, id, expr), state}}
252+
end
253+
254+
defp eval_apply(:elem, %T{data: %Expr{id: id, args: [tuple, i]}}, {cache, state}) do
255+
{tuple, cache} = composite_eval(tuple, state, cache)
256+
res = elem(tuple, i)
257+
{res, {Map.put(cache, id, res), state}}
258+
end
259+
260+
defp eval_apply(_op, %T{data: %Expr{id: id}} = ans, {cache, state}) do
261+
{args, {cache, state}} = Nx.Defn.Tree.apply_args(ans, {cache, state}, &eval/2)
262+
ans = put_in(ans.data.args, args)
263+
{ans, {Map.put(cache, id, ans), state}}
264+
end
265+
266+
defp composite_rewrite_subtree(container, state, acc \\ %{used_args: %{}})
267+
268+
defp composite_rewrite_subtree(container, state, acc) when is_list(container) do
269+
Enum.map_reduce(container, acc, fn
270+
%T{} = arg, acc ->
271+
composite_rewrite_subtree(arg, state, acc)
272+
273+
arg, acc when is_list(arg) ->
274+
composite_rewrite_subtree(arg, state, acc)
275+
276+
arg, acc ->
277+
{arg, acc}
278+
end)
279+
end
280+
281+
defp composite_rewrite_subtree(container, state, acc) do
282+
Composite.traverse(container, acc, &rewrite_subtree(&1, state, &2))
283+
end
284+
285+
defp rewrite_subtree(%T{data: %Expr{id: id, op: :parameter}} = expr, state, acc) do
286+
case state.nodes_to_replace do
287+
%{^id => res} ->
288+
{res, put_in(acc.used_args[id], res)}
289+
290+
_ ->
291+
{expr, put_in(acc.used_args[id], expr)}
292+
end
293+
end
294+
295+
defp rewrite_subtree(
296+
%T{data: %Expr{op: :optional, id: id, args: [call, subexpr, fun]}} = expr,
297+
state,
298+
acc
299+
) do
300+
case state.nodes_to_replace do
301+
%{^id => res} ->
302+
{res, put_in(acc.used_args[id], res)}
303+
304+
_ ->
305+
{call, acc} = rewrite_subtree(call, state, acc)
306+
# `subexpr` is hermetic, in the sense that it is a self-contained scope
307+
# from which the arguments always come from `call`, so we can
308+
# keep it as is.
309+
310+
{put_in(expr.data.args, [call, subexpr, fun]), acc}
311+
end
312+
end
313+
314+
defp rewrite_subtree(%T{data: %Expr{id: id, args: args}} = expr, state, acc) do
315+
case state.nodes_to_replace do
316+
%{^id => res} ->
317+
# nodes_to_replace always contains a param
318+
{res, put_in(acc.used_args[id], res)}
319+
320+
_ ->
321+
{args, acc} = composite_rewrite_subtree(args, state, acc)
322+
{put_in(expr.data.args, args), acc}
323+
end
324+
end
325+
326+
defp rewrite_subtree(other, _, acc), do: {other, acc}
327+
end

0 commit comments

Comments
 (0)