Skip to content

Commit b8dd06a

Browse files
committed
refactor: represent arguments as a list
1 parent 591b6f4 commit b8dd06a

File tree

3 files changed

+33
-85
lines changed

3 files changed

+33
-85
lines changed

nx/lib/nx/defn/graph_splitter.ex

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,9 @@ defmodule Nx.Defn.GraphSplitter do
2626
%{id: id, expr: expr, arguments: arguments} = stage
2727

2828
args =
29-
arguments
30-
|> Enum.map(fn {_id, %{source: source, index: idx}} ->
31-
argument = Map.fetch!(scope, source)
32-
{idx, argument}
29+
Enum.map(arguments, fn %{source: source} ->
30+
Map.fetch!(scope, source)
3331
end)
34-
|> Enum.sort_by(fn {idx, _} -> idx end)
35-
|> Enum.map(fn {_, argument} -> argument end)
3632

3733
case Nx.Defn.jit_apply(fn _ -> expr end, [List.to_tuple(args)]) do
3834
%T{} = tensor ->
@@ -114,12 +110,15 @@ defmodule Nx.Defn.GraphSplitter do
114110
composite_rewrite_subtree(expr, %{state | nodes_to_replace: arg_remapping})
115111

116112
arguments =
117-
Map.new(arg_remapping, fn {_id, arg_expr} ->
113+
arg_remapping
114+
|> Enum.map(fn {_id, arg_expr} ->
118115
id = arg_expr.data.id
119116
[idx] = arg_expr.data.args
120117
source = Map.fetch!(state.args, id)
121-
{id, %{source: source, index: idx}}
118+
{idx, %{source: source}}
122119
end)
120+
|> Enum.sort_by(fn {idx, _} -> idx end)
121+
|> Enum.map(fn {_, arg} -> arg end)
123122

124123
[
125124
%Stage{

nx/lib/nx/defn/graph_splitter/stage.ex

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,7 @@ defmodule Nx.Defn.GraphSplitter.Stage do
99
@type t :: %__MODULE__{
1010
id: reference(),
1111
expr: %{__struct__: Nx.Defn.Expr},
12-
arguments: %{
13-
reference() => %{
14-
source: {reference() | nil, non_neg_integer()},
15-
index: non_neg_integer()
16-
}
17-
}
12+
arguments: [%{source: {reference() | nil, non_neg_integer()}}]
1813
}
1914

2015
defstruct [:id, :expr, :arguments]

nx/test/nx/defn/graph_splitter_test.exs

Lines changed: 25 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ defmodule Nx.Defn.GraphSplitterTest do
3838
}
3939
] = chain
4040

41-
assert Enum.all?(stage_0_arguments, fn {_id, %{source: {source_id, _idx}}} ->
42-
source_id == nil
43-
end)
41+
assert [%{source: {nil, 0}}, %{source: {nil, 1}}] == stage_0_arguments
4442

4543
assert [{2, arg_2_original_node_id, arg_2_id}, {3, arg_3_original_node_id, arg_3_id}] =
4644
state.nodes_to_replace
@@ -51,11 +49,7 @@ defmodule Nx.Defn.GraphSplitterTest do
5149
|> Enum.sort()
5250

5351
# ensure that arg2 and arg3 map to the correct stage and output container position
54-
assert %{
55-
arg_2_id => %{source: {stage_0_id, 0}, index: 0},
56-
arg_3_id => %{source: {stage_0_id, 1}, index: 1}
57-
} ==
58-
stage_1_arguments
52+
assert [%{source: {stage_0_id, 0}}, %{source: {stage_0_id, 1}}] == stage_1_arguments
5953

6054
# ensure that arg2 and arg3 are replacing the correct nodes
6155
{_dot_node_id, %T{data: %Expr{args: [dot_arg_0, _, _, dot_arg_1, _, _]}}} =
@@ -163,9 +157,7 @@ defmodule Nx.Defn.GraphSplitterTest do
163157
}
164158
] = chain
165159

166-
assert Enum.all?(stage_0_arguments, fn {_id, %{source: {source_id, _idx}}} ->
167-
source_id == nil
168-
end)
160+
assert [%{source: {nil, 0}}, %{source: {nil, 1}}] == stage_0_arguments
169161

170162
assert map_size(state.args) == 6
171163

@@ -199,11 +191,7 @@ defmodule Nx.Defn.GraphSplitterTest do
199191
assert arg_5_id not in original_args
200192

201193
# ensure that arg3 and arg4 map to the correct stage and output container position
202-
assert %{
203-
arg_3_id => %{source: {stage_0_id, 0}, index: 0},
204-
arg_4_id => %{source: {stage_0_id, 1}, index: 1}
205-
} ==
206-
stage_1_arguments
194+
assert [%{source: {stage_0_id, 0}}, %{source: {stage_0_id, 1}}] == stage_1_arguments
207195

208196
# ensure that arg3 and arg4 are replacing the correct nodes
209197
{_dot_node_id, %T{data: %Expr{args: [dot_arg_0, _, _, dot_arg_1, _, _]}}} =
@@ -269,10 +257,7 @@ defmodule Nx.Defn.GraphSplitterTest do
269257
assert %T{data: %Expr{op: :sum, args: [^a, [axes: nil, keep_axes: false]]}} = b
270258
assert %T{data: %Expr{id: ^arg_5_id, op: :parameter, args: [1]}} = a
271259

272-
assert %{
273-
arg_2_id => %{source: {nil, 2}, index: 0},
274-
arg_5_id => %{source: {stage_1_id, 0}, index: 1}
275-
} == stage_2_arguments
260+
assert [%{source: {nil, 2}}, %{source: {stage_1_id, 0}}] == stage_2_arguments
276261
end
277262

278263
test "supports optional callbacks" do
@@ -297,35 +282,17 @@ defmodule Nx.Defn.GraphSplitterTest do
297282

298283
assert [%Stage{} = stage_0, %Stage{} = stage_1] = GraphSplitter.traverse(expr, split_fn)
299284

300-
[{arg1_id, %{source: {nil, 1}, index: 0}}] =
301-
Enum.to_list(stage_0.arguments)
302-
303-
assert stage_0.arguments == %{arg1_id => %{source: {nil, 1}, index: 0}}
304-
305-
stage_1_args =
306-
Enum.sort_by(stage_1.arguments, fn {_id, idx} -> idx end)
307-
308-
stage_0_id = stage_0.id
309-
310-
assert [
311-
{arg_0_id, %{source: {nil, 0}, index: 0}},
312-
{arg_1_id, %{source: {^stage_0_id, 0}, index: 1}}
313-
] =
314-
stage_1_args
315-
316-
assert stage_1.arguments == %{
317-
arg_0_id => %{source: {nil, 0}, index: 0},
318-
arg_1_id => %{source: {stage_0.id, 0}, index: 1}
319-
}
285+
assert stage_0.arguments == [%{source: {nil, 1}}]
286+
assert stage_1.arguments == [%{source: {nil, 0}}, %{source: {stage_0.id, 0}}]
320287

321288
assert %T{data: %Expr{op: :subtract, args: [c, d]}} = stage_1.expr
322289
assert %T{data: %Expr{op: :optional, args: [call, subexpr, _fun]}} = c
323290

324-
assert %T{data: %Expr{id: ^arg_0_id, op: :parameter, args: [0]}} = d
291+
assert %T{data: %Expr{id: arg_0_id, op: :parameter, args: [0]}} = d
325292

326293
assert %T{data: %Expr{op: :logical_not, args: [b]}} = call
327294
assert %T{data: %Expr{op: :sum, args: [a, [axes: [1], keep_axes: false]]}} = b
328-
assert %T{data: %Expr{id: ^arg_1_id, op: :parameter, args: [1]}} = a
295+
assert %T{data: %Expr{id: arg_1_id, op: :parameter, args: [1]}} = a
329296

330297
assert %T{
331298
data: %Expr{
@@ -365,22 +332,9 @@ defmodule Nx.Defn.GraphSplitterTest do
365332

366333
assert [%Stage{} = stage_0, %Stage{} = stage_1] = GraphSplitter.traverse(expr, split_fn)
367334

368-
assert [{stage_0_arg_0_id, %{source: {nil, 1}, index: 0}}] = Enum.to_list(stage_0.arguments)
369-
370-
stage_1_args =
371-
Enum.sort_by(stage_1.arguments, fn {_id, idx} -> idx end)
372-
373-
stage_0_id = stage_0.id
374-
375-
assert [
376-
{arg_0_id, %{source: {nil, 0}, index: 0}},
377-
{arg_1_id, %{source: {^stage_0_id, 0}, index: 1}}
378-
] = stage_1_args
379-
380-
assert arg_0_id != arg_1_id
381-
assert arg_0_id != stage_0_arg_0_id
382-
assert arg_1_id != stage_0_arg_0_id
335+
assert [%{source: {nil, 1}}] == stage_0.arguments
383336

337+
assert [%{source: {nil, 0}}, %{source: {stage_0.id, 0}}] == stage_1.arguments
384338
assert %T{data: %Expr{op: :subtract, args: [c, d]}} = stage_1.expr
385339

386340
assert %T{
@@ -393,10 +347,10 @@ defmodule Nx.Defn.GraphSplitterTest do
393347
}
394348
} = c
395349

396-
assert %T{data: %Expr{id: ^arg_0_id, op: :parameter, args: [0]}} = d
350+
assert %T{data: %Expr{op: :parameter, args: [0]}} = d
397351

398352
assert %T{data: %Expr{op: :sum, args: [a, [axes: [1], keep_axes: false]]}} = left
399-
assert %T{data: %Expr{id: ^arg_1_id, op: :parameter, args: [1]}} = a
353+
assert %T{data: %Expr{op: :parameter, args: [1]}} = a
400354
end
401355
end
402356

@@ -444,16 +398,16 @@ defmodule Nx.Defn.GraphSplitterTest do
444398
data: %Expr{
445399
op: :metadata,
446400
args: [
447-
%T{data: %Expr{id: x_left_id, op: :parameter, args: [1]}},
401+
%T{data: %Expr{op: :parameter, args: [1]}},
448402
%{split: true}
449403
]
450404
}
451405
} = x
452406

453-
assert %T{data: %Expr{id: arg1_left_id, op: :parameter, args: [0]}} = arg1_left
407+
assert %T{data: %Expr{op: :parameter, args: [0]}} = arg1_left
454408

455-
assert left.arguments[arg1_left_id].source == {nil, 1}
456-
assert left.arguments[x_left_id].source == {root.id, 0}
409+
assert Enum.fetch!(left.arguments, 0).source == {nil, 1}
410+
assert Enum.fetch!(left.arguments, 1).source == {root.id, 0}
457411

458412
# right should depend on the result of the root and on arg1, but arg1 will be reindexed
459413
# we should assert that the argument source for arg1_right is correct
@@ -463,24 +417,24 @@ defmodule Nx.Defn.GraphSplitterTest do
463417
data: %Expr{
464418
op: :metadata,
465419
args: [
466-
%T{data: %Expr{id: x_right_id, op: :parameter, args: [1]}},
420+
%T{data: %Expr{op: :parameter, args: [1]}},
467421
%{split: true}
468422
]
469423
}
470424
} = x
471425

472-
assert %T{data: %Expr{id: arg1_right_id, op: :parameter, args: [0]}} = arg1_right
426+
assert %T{data: %Expr{op: :parameter, args: [0]}} = arg1_right
473427

474-
assert right.arguments[arg1_right_id].source == {nil, 1}
475-
assert right.arguments[x_right_id].source == {root.id, 0}
428+
assert Enum.fetch!(right.arguments, 0).source == {nil, 1}
429+
assert Enum.fetch!(right.arguments, 1).source == {root.id, 0}
476430

477431
assert %T{data: %Expr{op: :add, args: [w_right, w_left]}} = merge.expr
478432

479433
assert %T{
480434
data: %Expr{
481435
op: :metadata,
482436
args: [
483-
%T{data: %Expr{id: w_right_id, op: :parameter, args: [0]}},
437+
%T{data: %Expr{op: :parameter, args: [0]}},
484438
%{split: true}
485439
]
486440
}
@@ -490,14 +444,14 @@ defmodule Nx.Defn.GraphSplitterTest do
490444
data: %Expr{
491445
op: :metadata,
492446
args: [
493-
%T{data: %Expr{id: w_left_id, op: :parameter, args: [1]}},
447+
%T{data: %Expr{op: :parameter, args: [1]}},
494448
%{split: true}
495449
]
496450
}
497451
} = w_left
498452

499-
assert merge.arguments[w_right_id].source == {right.id, 0}
500-
assert merge.arguments[w_left_id].source == {left.id, 0}
453+
assert Enum.fetch!(merge.arguments, 0).source == {right.id, 0}
454+
assert Enum.fetch!(merge.arguments, 1).source == {left.id, 0}
501455

502456
assert GraphSplitter.run(chain, args) == expected_result
503457
end

0 commit comments

Comments
 (0)