Skip to content

Commit 1c0ca22

Browse files
committed
docs: add example
1 parent cd9014a commit 1c0ca22

File tree

3 files changed

+62
-31
lines changed

3 files changed

+62
-31
lines changed

nx/lib/nx/defn/graph_splitter.ex renamed to nx/lib/nx/defn/graph.ex

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,60 @@
1-
defmodule Nx.Defn.GraphSplitter do
1+
defmodule Nx.Defn.Graph do
22
alias Nx.Defn.Composite
33

44
alias Nx.Tensor, as: T
55
alias Nx.Defn.Expr
6-
alias Nx.Defn.GraphSplitter.Stage
6+
7+
defmodule Stage do
8+
@typedoc """
9+
A stage in the graph splitter.
10+
11+
* `:arguments`: a list of maps that point to the source from which to fetch the corresponding
12+
value for the given argument.
13+
14+
* `:expr`: the expression that represents the computation for the Stage.
15+
16+
* `:id`: the unique id for the Stage.
17+
"""
18+
@type t :: %__MODULE__{
19+
id: reference(),
20+
expr: %{__struct__: Nx.Defn.Expr},
21+
arguments: [%{source: {reference() | nil, non_neg_integer()}}]
22+
}
23+
24+
defstruct [:id, :expr, :arguments]
25+
end
726

827
@doc """
9-
Traverses the expression and splits it into stages.
28+
Splits the received Nx.Defn.Expr into stages given the rules
29+
defined by `expr_split_fn`.
30+
31+
## Examples
32+
33+
iex> expr = Nx.Defn.debug_expr(fn x, y -> x |> Nx.negate() |> Nx.sin() |> Nx.cos() |> Nx.add(y) end).(1, 2)
34+
iex> [stage0, stage1] = Nx.Defn.Graph.split(expr, fn %Nx.Tensor{data: %Nx.Defn.Expr{op: op}} -> op == :cos end)
35+
iex> {out0} = stage0.expr
36+
iex> out0
37+
#Nx.Tensor<
38+
f32
39+
\n\
40+
Nx.Defn.Expr
41+
parameter a:0 s32
42+
b = negate a s32
43+
c = sin b f32
44+
>
45+
iex> stage1.expr
46+
#Nx.Tensor<
47+
f32
48+
\n\
49+
Nx.Defn.Expr
50+
parameter a:1 f32
51+
parameter c:0 s32
52+
b = cos a f32
53+
d = add b, c f32
54+
>
1055
"""
11-
def traverse(expr, expr_split_fn \\ fn _ -> false end) do
12-
{chain, _, _} = __traverse__(expr, expr_split_fn)
56+
def split(expr, expr_split_fn \\ fn _ -> false end) do
57+
{chain, _, _} = __split__(expr, expr_split_fn)
1358
chain
1459
end
1560

@@ -56,7 +101,7 @@ defmodule Nx.Defn.GraphSplitter do
56101
end
57102

58103
@doc false
59-
def __traverse__(expr, expr_split_fn) do
104+
def __split__(expr, expr_split_fn) do
60105
# expression_chain is going to be a reverse-accumulation of {category, subexpr}
61106
# that we can then compile and chain-execute elsewhere. category is either :gather, :reduce or :none
62107
state = %{

nx/lib/nx/defn/graph_splitter/stage.ex

Lines changed: 0 additions & 16 deletions
This file was deleted.

nx/test/nx/defn/graph_splitter_test.exs renamed to nx/test/nx/defn/graph_test.exs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
defmodule Nx.Defn.GraphSplitterTest do
1+
defmodule Nx.Defn.GraphTest do
22
use ExUnit.Case, async: true
33

4-
alias Nx.Defn.GraphSplitter
5-
alias Nx.Defn.GraphSplitter.Stage
4+
alias Nx.Defn.Graph
5+
alias Nx.Defn.Graph.Stage
66

77
alias Nx.Tensor, as: T
88
alias Nx.Defn.Expr
99

10+
doctest Nx.Defn.Graph
11+
1012
describe "traverse/1" do
1113
test "simple expression with 1 split and no common nodes" do
1214
expr =
@@ -23,7 +25,7 @@ defmodule Nx.Defn.GraphSplitterTest do
2325
_ -> false
2426
end
2527

26-
{chain, cache, state} = GraphSplitter.traverse_and_return_cache(expr, split_fn)
28+
{chain, cache, state} = Graph.__split__(expr, split_fn)
2729

2830
assert [
2931
%Stage{
@@ -137,7 +139,7 @@ defmodule Nx.Defn.GraphSplitterTest do
137139
_ -> false
138140
end
139141

140-
{chain, cache, state} = GraphSplitter.traverse_and_return_cache(expr, split_fn)
142+
{chain, cache, state} = Graph.__split__(expr, split_fn)
141143

142144
assert [
143145
%Stage{
@@ -280,7 +282,7 @@ defmodule Nx.Defn.GraphSplitterTest do
280282
_ -> false
281283
end
282284

283-
assert [%Stage{} = stage_0, %Stage{} = stage_1] = GraphSplitter.traverse(expr, split_fn)
285+
assert [%Stage{} = stage_0, %Stage{} = stage_1] = Graph.split(expr, split_fn)
284286

285287
assert stage_0.arguments == [%{source: {nil, 1}}]
286288
assert stage_1.arguments == [%{source: {nil, 0}}, %{source: {stage_0.id, 0}}]
@@ -330,7 +332,7 @@ defmodule Nx.Defn.GraphSplitterTest do
330332
_ -> false
331333
end
332334

333-
assert [%Stage{} = stage_0, %Stage{} = stage_1] = GraphSplitter.traverse(expr, split_fn)
335+
assert [%Stage{} = stage_0, %Stage{} = stage_1] = Graph.split(expr, split_fn)
334336

335337
assert [%{source: {nil, 1}}] == stage_0.arguments
336338

@@ -382,7 +384,7 @@ defmodule Nx.Defn.GraphSplitterTest do
382384
_ -> false
383385
end
384386

385-
chain = GraphSplitter.traverse(expr, split_fn)
387+
chain = Graph.split(expr, split_fn)
386388

387389
assert [root, right, left, merge] = chain
388390

@@ -453,7 +455,7 @@ defmodule Nx.Defn.GraphSplitterTest do
453455
assert Enum.fetch!(merge.arguments, 0).source == {right.id, 0}
454456
assert Enum.fetch!(merge.arguments, 1).source == {left.id, 0}
455457

456-
assert GraphSplitter.run(chain, args) == expected_result
458+
assert Graph.run(chain, args) == expected_result
457459
end
458460
end
459461
end

0 commit comments

Comments
 (0)