Skip to content

Commit 591b6f4

Browse files
committed
refactor: unify arguments and argument sources
1 parent 3a66244 commit 591b6f4

File tree

3 files changed

+68
-44
lines changed

3 files changed

+68
-44
lines changed

nx/lib/nx/defn/graph_splitter.ex

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,11 @@ defmodule Nx.Defn.GraphSplitter do
2323

2424
scope =
2525
Enum.reduce(chain, scope, fn stage, scope ->
26-
%{id: id, expr: expr, argument_sources: argument_sources, arguments: arguments} = stage
26+
%{id: id, expr: expr, arguments: arguments} = stage
2727

2828
args =
2929
arguments
30-
|> Enum.map(fn {id, idx} ->
31-
source = Map.fetch!(argument_sources, id)
30+
|> Enum.map(fn {_id, %{source: source, index: idx}} ->
3231
argument = Map.fetch!(scope, source)
3332
{idx, argument}
3433
end)
@@ -116,23 +115,17 @@ defmodule Nx.Defn.GraphSplitter do
116115

117116
arguments =
118117
Map.new(arg_remapping, fn {_id, arg_expr} ->
118+
id = arg_expr.data.id
119119
[idx] = arg_expr.data.args
120-
{arg_expr.data.id, idx}
121-
end)
122-
123-
argument_sources =
124-
state.args
125-
|> Map.take(Map.keys(arg_remapping))
126-
|> Map.new(fn {remap_id, v} ->
127-
{arg_remapping[remap_id].data.id, v}
120+
source = Map.fetch!(state.args, id)
121+
{id, %{source: source, index: idx}}
128122
end)
129123

130124
[
131125
%Stage{
132126
id: id,
133127
expr: expr,
134-
arguments: arguments,
135-
argument_sources: argument_sources
128+
arguments: arguments
136129
}
137130
| acc
138131
]
Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
11
defmodule Nx.Defn.GraphSplitter.Stage do
2-
defstruct [:id, :expr, :arguments, :argument_sources]
2+
@typedoc """
3+
A stage in the graph splitter.
4+
5+
`:arguments` is a map of the id of the corresponding Nx.Defn.Expr :parameter
6+
node to the source {stage_id, output_container_position} of the argument
7+
and the index of the argument in the current stage.
8+
"""
9+
@type t :: %__MODULE__{
10+
id: reference(),
11+
expr: %{__struct__: Nx.Defn.Expr},
12+
arguments: %{
13+
reference() => %{
14+
source: {reference() | nil, non_neg_integer()},
15+
index: non_neg_integer()
16+
}
17+
}
18+
}
19+
20+
defstruct [:id, :expr, :arguments]
321
end

nx/test/nx/defn/graph_splitter_test.exs

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@ defmodule Nx.Defn.GraphSplitterTest do
2929
%Stage{
3030
id: stage_0_id,
3131
expr: stage_0_expr,
32-
argument_sources: stage_0_argument_sources
32+
arguments: stage_0_arguments
3333
},
3434
%Stage{
3535
id: _stage_1_id,
3636
expr: stage_1_expr,
37-
argument_sources: stage_1_argument_sources
37+
arguments: stage_1_arguments
3838
}
3939
] = chain
4040

41-
assert Enum.all?(stage_0_argument_sources, fn {_id, {source_id, _idx}} ->
41+
assert Enum.all?(stage_0_arguments, fn {_id, %{source: {source_id, _idx}}} ->
4242
source_id == nil
4343
end)
4444

@@ -52,10 +52,10 @@ defmodule Nx.Defn.GraphSplitterTest do
5252

5353
# ensure that arg2 and arg3 map to the correct stage and output container position
5454
assert %{
55-
arg_2_id => {stage_0_id, 0},
56-
arg_3_id => {stage_0_id, 1}
55+
arg_2_id => %{source: {stage_0_id, 0}, index: 0},
56+
arg_3_id => %{source: {stage_0_id, 1}, index: 1}
5757
} ==
58-
stage_1_argument_sources
58+
stage_1_arguments
5959

6060
# ensure that arg2 and arg3 are replacing the correct nodes
6161
{_dot_node_id, %T{data: %Expr{args: [dot_arg_0, _, _, dot_arg_1, _, _]}}} =
@@ -149,21 +149,21 @@ defmodule Nx.Defn.GraphSplitterTest do
149149
%Stage{
150150
id: stage_0_id,
151151
expr: stage_0_expr,
152-
argument_sources: stage_0_argument_sources
152+
arguments: stage_0_arguments
153153
},
154154
%Stage{
155155
id: stage_1_id,
156156
expr: stage_1_expr,
157-
argument_sources: stage_1_argument_sources
157+
arguments: stage_1_arguments
158158
},
159159
%Stage{
160160
id: _stage_2_id,
161161
expr: stage_2_expr,
162-
argument_sources: stage_2_argument_sources
162+
arguments: stage_2_arguments
163163
}
164164
] = chain
165165

166-
assert Enum.all?(stage_0_argument_sources, fn {_id, {source_id, _idx}} ->
166+
assert Enum.all?(stage_0_arguments, fn {_id, %{source: {source_id, _idx}}} ->
167167
source_id == nil
168168
end)
169169

@@ -200,10 +200,10 @@ defmodule Nx.Defn.GraphSplitterTest do
200200

201201
# ensure that arg3 and arg4 map to the correct stage and output container position
202202
assert %{
203-
arg_3_id => {stage_0_id, 0},
204-
arg_4_id => {stage_0_id, 1}
203+
arg_3_id => %{source: {stage_0_id, 0}, index: 0},
204+
arg_4_id => %{source: {stage_0_id, 1}, index: 1}
205205
} ==
206-
stage_1_argument_sources
206+
stage_1_arguments
207207

208208
# ensure that arg3 and arg4 are replacing the correct nodes
209209
{_dot_node_id, %T{data: %Expr{args: [dot_arg_0, _, _, dot_arg_1, _, _]}}} =
@@ -269,7 +269,10 @@ defmodule Nx.Defn.GraphSplitterTest do
269269
assert %T{data: %Expr{op: :sum, args: [^a, [axes: nil, keep_axes: false]]}} = b
270270
assert %T{data: %Expr{id: ^arg_5_id, op: :parameter, args: [1]}} = a
271271

272-
assert %{arg_2_id => {nil, 2}, arg_5_id => {stage_1_id, 0}} == stage_2_argument_sources
272+
assert %{
273+
arg_2_id => %{source: {nil, 2}, index: 0},
274+
arg_5_id => %{source: {stage_1_id, 0}, index: 1}
275+
} == stage_2_arguments
273276
end
274277

275278
test "supports optional callbacks" do
@@ -294,21 +297,26 @@ defmodule Nx.Defn.GraphSplitterTest do
294297

295298
assert [%Stage{} = stage_0, %Stage{} = stage_1] = GraphSplitter.traverse(expr, split_fn)
296299

297-
[{arg1_id, 0}] =
300+
[{arg1_id, %{source: {nil, 1}, index: 0}}] =
298301
Enum.to_list(stage_0.arguments)
299302

300-
assert stage_0.argument_sources == %{arg1_id => {nil, 1}}
303+
assert stage_0.arguments == %{arg1_id => %{source: {nil, 1}, index: 0}}
301304

302305
stage_1_args =
303306
Enum.sort_by(stage_1.arguments, fn {_id, idx} -> idx end)
304307

308+
stage_0_id = stage_0.id
309+
305310
assert [
306-
{arg_0_id, 0},
307-
{arg_1_id, 1}
311+
{arg_0_id, %{source: {nil, 0}, index: 0}},
312+
{arg_1_id, %{source: {^stage_0_id, 0}, index: 1}}
308313
] =
309314
stage_1_args
310315

311-
assert stage_1.argument_sources == %{arg_0_id => {nil, 0}, arg_1_id => {stage_0.id, 0}}
316+
assert stage_1.arguments == %{
317+
arg_0_id => %{source: {nil, 0}, index: 0},
318+
arg_1_id => %{source: {stage_0.id, 0}, index: 1}
319+
}
312320

313321
assert %T{data: %Expr{op: :subtract, args: [c, d]}} = stage_1.expr
314322
assert %T{data: %Expr{op: :optional, args: [call, subexpr, _fun]}} = c
@@ -357,16 +365,21 @@ defmodule Nx.Defn.GraphSplitterTest do
357365

358366
assert [%Stage{} = stage_0, %Stage{} = stage_1] = GraphSplitter.traverse(expr, split_fn)
359367

360-
[{arg1_id, 0}] = Enum.to_list(stage_0.arguments)
361-
362-
assert stage_0.argument_sources == %{arg1_id => {nil, 1}}
368+
assert [{stage_0_arg_0_id, %{source: {nil, 1}, index: 0}}] = Enum.to_list(stage_0.arguments)
363369

364370
stage_1_args =
365371
Enum.sort_by(stage_1.arguments, fn {_id, idx} -> idx end)
366372

367-
assert [{arg_0_id, 0}, {arg_1_id, 1}] = stage_1_args
373+
stage_0_id = stage_0.id
374+
375+
assert [
376+
{arg_0_id, %{source: {nil, 0}, index: 0}},
377+
{arg_1_id, %{source: {^stage_0_id, 0}, index: 1}}
378+
] = stage_1_args
368379

369-
assert stage_1.argument_sources == %{arg_0_id => {nil, 0}, arg_1_id => {stage_0.id, 0}}
380+
assert arg_0_id != arg_1_id
381+
assert arg_0_id != stage_0_arg_0_id
382+
assert arg_1_id != stage_0_arg_0_id
370383

371384
assert %T{data: %Expr{op: :subtract, args: [c, d]}} = stage_1.expr
372385

@@ -439,8 +452,8 @@ defmodule Nx.Defn.GraphSplitterTest do
439452

440453
assert %T{data: %Expr{id: arg1_left_id, op: :parameter, args: [0]}} = arg1_left
441454

442-
assert left.argument_sources[arg1_left_id] == {nil, 1}
443-
assert left.argument_sources[x_left_id] == {root.id, 0}
455+
assert left.arguments[arg1_left_id].source == {nil, 1}
456+
assert left.arguments[x_left_id].source == {root.id, 0}
444457

445458
# right should depend on the result of the root and on arg1, but arg1 will be reindexed
446459
# we should assert that the argument source for arg1_right is correct
@@ -458,8 +471,8 @@ defmodule Nx.Defn.GraphSplitterTest do
458471

459472
assert %T{data: %Expr{id: arg1_right_id, op: :parameter, args: [0]}} = arg1_right
460473

461-
assert right.argument_sources[arg1_right_id] == {nil, 1}
462-
assert right.argument_sources[x_right_id] == {root.id, 0}
474+
assert right.arguments[arg1_right_id].source == {nil, 1}
475+
assert right.arguments[x_right_id].source == {root.id, 0}
463476

464477
assert %T{data: %Expr{op: :add, args: [w_right, w_left]}} = merge.expr
465478

@@ -483,8 +496,8 @@ defmodule Nx.Defn.GraphSplitterTest do
483496
}
484497
} = w_left
485498

486-
assert merge.argument_sources[w_right_id] == {right.id, 0}
487-
assert merge.argument_sources[w_left_id] == {left.id, 0}
499+
assert merge.arguments[w_right_id].source == {right.id, 0}
500+
assert merge.arguments[w_left_id].source == {left.id, 0}
488501

489502
assert GraphSplitter.run(chain, args) == expected_result
490503
end

0 commit comments

Comments
 (0)