File tree Expand file tree Collapse file tree 1 file changed +16
-3
lines changed
Expand file tree Collapse file tree 1 file changed +16
-3
lines changed Original file line number Diff line number Diff line change 11defmodule Nx.Defn.Graph do
2+ @ moduledoc """
3+ A module for splitting Nx.Defn.Expr into stages.
4+
5+ This module is used to split an Nx.Defn.Expr into stages, which are then
6+ executed in a chain.
7+
8+ `split/2` and `t:Stage.t()` describe how to split
9+ the graph and what's the expected result.
10+
11+ `run/2` executes the given graph against the provided arguments in a sequential manner.
12+ """
213 alias Nx.Defn.Composite
314
415 alias Nx.Tensor , as: T
@@ -23,8 +34,10 @@ defmodule Nx.Defn.Graph do
2334 end
2435
2536 @ doc """
26- Splits the received Nx.Defn.Expr into stages given the rules
27- defined by `expr_split_fn`.
37+ Splits the received Nx.Defn.Expr into stages given the rules.
38+
39+ `expr_split_fn` is a function that receives an `Nx.Tensor` containing an `Nx.Defn.Expr`
40+ and returns `true` when a split must happen, and `false` otherwise.
2841
2942 ## Examples
3043
@@ -51,7 +64,7 @@ defmodule Nx.Defn.Graph do
5164 d = add b, c f32
5265 >
5366 """
54- def split ( expr , expr_split_fn \\ fn _ -> false end ) do
67+ def split ( expr , expr_split_fn ) when is_function ( expr_split_fn , 1 ) do
5568 { chain , _ , _ } = __split__ ( expr , expr_split_fn )
5669 chain
5770 end
You can’t perform that action at this time.
0 commit comments