Skip to content

Commit f6090c3

Browse files
committed
docs: add more docs
1 parent ab1adaa commit f6090c3

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

nx/lib/nx/defn/graph.ex

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,15 @@
11
defmodule Nx.Defn.Graph do
2+
@moduledoc """
3+
A module for splitting Nx.Defn.Expr into stages.
4+
5+
This module is used to split an Nx.Defn.Expr into stages, which are then
6+
executed in a chain.
7+
8+
`split/2` and `t:Stage.t()` describe how to split
9+
the graph and what's the expected result.
10+
11+
`run/2` executes the given graph against the provided arguments in a sequential manner.
12+
"""
213
alias Nx.Defn.Composite
314

415
alias Nx.Tensor, as: T
@@ -23,8 +34,10 @@ defmodule Nx.Defn.Graph do
2334
end
2435

2536
@doc """
26-
Splits the received Nx.Defn.Expr into stages given the rules
27-
defined by `expr_split_fn`.
37+
Splits the received Nx.Defn.Expr into stages given the rules.
38+
39+
`expr_split_fn` is a function that receives an `Nx.Tensor` containing an `Nx.Defn.Expr`
40+
and returns `true` when a split must happen, and `false` otherwise.
2841
2942
## Examples
3043
@@ -51,7 +64,7 @@ defmodule Nx.Defn.Graph do
5164
d = add b, c f32
5265
>
5366
"""
54-
def split(expr, expr_split_fn \\ fn _ -> false end) do
67+
def split(expr, expr_split_fn) when is_function(expr_split_fn, 1) do
5568
{chain, _, _} = __split__(expr, expr_split_fn)
5669
chain
5770
end

0 commit comments

Comments
 (0)