Skip to content

Commit 3a66244

Browse files
committed
refactor: simplify argument mapping representation
1 parent 76ef352 commit 3a66244

File tree

2 files changed

+10
-18
lines changed

2 files changed

+10
-18
lines changed

nx/lib/nx/defn/graph_splitter.ex

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ defmodule Nx.Defn.GraphSplitter do
2727

2828
args =
2929
arguments
30-
|> Enum.map(fn {id, %T{data: %Expr{args: [idx]}}} ->
30+
|> Enum.map(fn {id, idx} ->
3131
source = Map.fetch!(argument_sources, id)
3232
argument = Map.fetch!(scope, source)
3333
{idx, argument}
@@ -116,7 +116,8 @@ defmodule Nx.Defn.GraphSplitter do
116116

117117
arguments =
118118
Map.new(arg_remapping, fn {_id, arg_expr} ->
119-
{arg_expr.data.id, arg_expr}
119+
[idx] = arg_expr.data.args
120+
{arg_expr.data.id, idx}
120121
end)
121122

122123
argument_sources =

nx/test/nx/defn/graph_splitter_test.exs

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -294,19 +294,17 @@ defmodule Nx.Defn.GraphSplitterTest do
294294

295295
assert [%Stage{} = stage_0, %Stage{} = stage_1] = GraphSplitter.traverse(expr, split_fn)
296296

297-
[{arg1_id, %T{shape: {2, 3}, type: {:u, 8}, data: %Expr{args: [0]}}}] =
297+
[{arg1_id, 0}] =
298298
Enum.to_list(stage_0.arguments)
299299

300300
assert stage_0.argument_sources == %{arg1_id => {nil, 1}}
301301

302302
stage_1_args =
303-
Enum.sort_by(stage_1.arguments, fn {_id, %T{data: %Expr{op: :parameter, args: [idx]}}} ->
304-
idx
305-
end)
303+
Enum.sort_by(stage_1.arguments, fn {_id, idx} -> idx end)
306304

307305
assert [
308-
{arg_0_id, %T{shape: {}, type: {:s, 32}}},
309-
{arg_1_id, %T{shape: {2, 3}, type: {:u, 8}}}
306+
{arg_0_id, 0},
307+
{arg_1_id, 1}
310308
] =
311309
stage_1_args
312310

@@ -359,21 +357,14 @@ defmodule Nx.Defn.GraphSplitterTest do
359357

360358
assert [%Stage{} = stage_0, %Stage{} = stage_1] = GraphSplitter.traverse(expr, split_fn)
361359

362-
[{arg1_id, %T{shape: {2, 3}, type: {:u, 8}, data: %Expr{args: [0]}}}] =
363-
Enum.to_list(stage_0.arguments)
360+
[{arg1_id, 0}] = Enum.to_list(stage_0.arguments)
364361

365362
assert stage_0.argument_sources == %{arg1_id => {nil, 1}}
366363

367364
stage_1_args =
368-
Enum.sort_by(stage_1.arguments, fn {_id, %T{data: %Expr{op: :parameter, args: [idx]}}} ->
369-
idx
370-
end)
365+
Enum.sort_by(stage_1.arguments, fn {_id, idx} -> idx end)
371366

372-
assert [
373-
{arg_0_id, %T{shape: {}, type: {:s, 32}}},
374-
{arg_1_id, %T{shape: {2, 3}, type: {:u, 8}}}
375-
] =
376-
stage_1_args
367+
assert [{arg_0_id, 0}, {arg_1_id, 1}] = stage_1_args
377368

378369
assert stage_1.argument_sources == %{arg_0_id => {nil, 0}, arg_1_id => {stage_0.id, 0}}
379370

0 commit comments

Comments
 (0)