@@ -4,16 +4,23 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
44 alias Nx.Tensor , as: T
55 alias Nx.Defn.Expr
66 alias Nx.Defn.ShardingCompiler.Shard
7+ alias Nx.Defn.ShardingCompiler.Passes.GraphSplitter.Stage
78
89 @ gather_ops [ :dot ]
910 @ reduction_ops [ :sum ]
1011
11- def traverse ( expr , expr_shards \\ % { } ) do
12+ @ ops_to_split Map . merge (
13+ Map . new ( @ gather_ops , & { & 1 , :gather } ) ,
14+ Map . new ( @ reduction_ops , & { & 1 , :reduce } )
15+ )
16+
17+ def traverse ( expr , expr_shards \\ % { } , ops_to_split \\ @ ops_to_split ) do
1218 # expression_chain is going to be a reverse-accumulation of {category, subexpr}
1319 # that we can then compile and chain-execute elsewhere. category is either :gather, :reduce or :none
1420 state = % {
1521 expression_chain: [ ] ,
1622 nodes_to_replace: % { } ,
23+ ops_to_split: ops_to_split ,
1724 # contains the sharding configuration for each node by id
1825 shards: expr_shards ,
1926 # args is a map of id -> {stage_id, output_container_position}
@@ -54,62 +61,64 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
5461 { id , { expr , nil } } , idx ->
5562 { id , put_in ( expr . data . args , [ idx ] ) }
5663
57- { id , { expr , shard_propagation } } , idx ->
64+ { id , { expr , _shard_propagation } } , idx ->
5865 expr = put_in ( expr . data . args , [ idx ] )
59- expr = Expr . metadata ( expr , % { shards: shard_propagation . shards } )
6066 { id , expr }
6167 end )
6268 |> Map . new ( )
6369
6470 { expr , _ } =
6571 composite_rewrite_subtree ( expr , % { state | nodes_to_replace: arg_remapping } )
6672
67- expr =
68- Composite . traverse ( expr , fn
69- % T { data: % Expr { id: id } } = t ->
70- if shard_propagation = state . shards [ id ] do
71- Expr . metadata ( t , % { shards: shard_propagation . shards } )
72- else
73- t
74- end
75-
76- other ->
77- other
73+ # Traverse the expression to remap all shapes according to the sharding given
74+ expr = set_shard_metadata ( expr , state . shards )
75+
76+ arguments =
77+ Map . new ( arg_remapping , fn { _id , arg_expr } ->
78+ { arg_expr . data . id , set_shard_metadata ( arg_expr , state . shards ) }
7879 end )
7980
80- argument_sources = Map . take ( state . args , Map . keys ( arg_remapping ) )
81+ argument_sources =
82+ state . args
83+ |> Map . take ( Map . keys ( arg_remapping ) )
84+ |> Map . new ( fn { remap_id , v } ->
85+ { arg_remapping [ remap_id ] . data . id , v }
86+ end )
8187
82- [ { id , category , expr , argument_sources } | acc ]
88+ [
89+ % Stage {
90+ id: id ,
91+ category: category ,
92+ expr: expr ,
93+ arguments: arguments ,
94+ argument_sources: argument_sources
95+ }
96+ | acc
97+ ]
8398 end
8499 )
85100
86- { expr_chain , Map . delete ( state , :expression_chain ) , cache }
101+ { expr_chain , cache , Map . delete ( state , :expression_chain ) }
87102 end
88103
89104 defp composite_eval ( expr , state , cache ) do
90105 Composite . traverse ( expr , { cache , state } , & eval / 2 )
91106 end
92107
93108 defp eval ( % T { data: % Expr { id: id , op: op } } = ans , { cache , state } ) do
94- case { cache , state . nodes_to_replace } do
95- { _ , % { ^ id => res } } ->
109+ case { cache , state . nodes_to_replace , state . ops_to_split } do
110+ { _ , % { ^ id => res } , _ } ->
96111 # Replace the node with the corresponding parameter
97112 { res , { Map . put ( cache , id , res ) , state } }
98113
99- { % { ^ id => res } , _ } ->
114+ { % { ^ id => res } , _ , _ } ->
100115 { res , { cache , state } }
101116
102- { _ , _ } ->
103- cond do
104- op in @ gather_ops ->
105- rewrite_args ( ans , :gather , { cache , state } )
106-
107- op in @ reduction_ops ->
108- rewrite_args ( ans , :reduce , { cache , state } )
117+ { _ , _ , % { ^ op => category } } ->
118+ rewrite_args ( ans , category , { cache , state } )
109119
110- true ->
111- eval_apply ( op , ans , { cache , state } )
112- end
120+ _ ->
121+ eval_apply ( op , ans , { cache , state } )
113122 end
114123 end
115124
@@ -203,8 +212,8 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
203212 { new_expr , { cache , state } }
204213 end
205214
206- defp eval_apply ( :parameter , % T { data: % Expr { id: id } } = expr , { cache , state } ) do
207- state = put_in ( state . args [ id ] , nil )
215+ defp eval_apply ( :parameter , % T { data: % Expr { id: id , args: [ idx ] } } = expr , { cache , state } ) do
216+ state = put_in ( state . args [ id ] , { nil , idx } )
208217 { expr , { Map . put ( cache , id , expr ) , state } }
209218 end
210219
@@ -220,19 +229,26 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
220229 { ans , { Map . put ( cache , id , ans ) , state } }
221230 end
222231
223- defp composite_rewrite_subtree ( args , state , acc \\ % { used_args: % { } } )
232+ defp composite_rewrite_subtree ( container , state , acc \\ % { used_args: % { } } )
224233
225- defp composite_rewrite_subtree ( args , state , acc ) when is_list ( args ) do
226- Enum . map_reduce ( args , acc , fn
234+ defp composite_rewrite_subtree ( container , state , acc ) when is_list ( container ) do
235+ Enum . map_reduce ( container , acc , fn
227236 % T { } = arg , acc ->
228237 composite_rewrite_subtree ( arg , state , acc )
229238
239+ arg , acc when is_list ( arg ) ->
240+ composite_rewrite_subtree ( arg , state , acc )
241+
230242 arg , acc ->
231243 { arg , acc }
232244 end )
233245 end
234246
235- defp composite_rewrite_subtree ( % T { data: % Expr { id: id , op: :parameter } } = expr , state , acc ) do
247+ defp composite_rewrite_subtree ( container , state , acc ) do
248+ Composite . traverse ( container , acc , & rewrite_subtree ( & 1 , state , & 2 ) )
249+ end
250+
251+ defp rewrite_subtree ( % T { data: % Expr { id: id , op: :parameter } } = expr , state , acc ) do
236252 case state . nodes_to_replace do
237253 % { ^ id => res } ->
238254 { res , put_in ( acc . used_args [ id ] , { res , state . shards [ id ] } ) }
@@ -242,22 +258,75 @@ defmodule Nx.Defn.ShardingCompiler.Passes.GraphSplitter do
242258 end
243259 end
244260
245- defp composite_rewrite_subtree ( arg , state , acc ) do
246- Composite . traverse ( arg , acc , & rewrite_subtree ( & 1 , state , & 2 ) )
261+ defp rewrite_subtree (
262+ % T { data: % Expr { op: :optional , id: id , args: [ call , subexpr , fun ] } } = expr ,
263+ state ,
264+ acc
265+ ) do
266+ case state . nodes_to_replace do
267+ % { ^ id => res } ->
268+ { res , put_in ( acc . used_args [ id ] , { res , state . shards [ id ] } ) }
269+
270+ _ ->
271+ { call , acc } = rewrite_subtree ( call , state , acc )
272+ # `subexpr` is hermetic, in the sense that it is a self-contained scope
273+ # from which the arguments always come from `call`, so we can
274+ # keep it as is.
275+
276+ { put_in ( expr . data . args , [ call , subexpr , fun ] ) , acc }
277+ end
247278 end
248279
249280 defp rewrite_subtree ( % T { data: % Expr { id: id , args: args } } = expr , state , acc ) do
250281 case state . nodes_to_replace do
251282 % { ^ id => res } ->
252283 # nodes_to_replace always contains a param
253- { res , put_in ( acc . used_args [ id ] , res ) }
284+ { res , put_in ( acc . used_args [ id ] , { res , state . shards [ id ] } ) }
254285
255286 _ ->
256287 { args , acc } = composite_rewrite_subtree ( args , state , acc )
257-
258288 { put_in ( expr . data . args , args ) , acc }
259289 end
260290 end
261291
262292 defp rewrite_subtree ( other , _ , acc ) , do: { other , acc }
293+
294+ defp set_shard_metadata ( expr , shards ) do
295+ Composite . traverse ( expr , fn
296+ % T { data: % Expr { id: id } } = t ->
297+ if shard_propagation = shards [ id ] do
298+ shape =
299+ shard_propagation . shards
300+ |> Enum . sort ( )
301+ |> Enum . map ( fn { _axis , [ % Shard { length: length } | _ ] } -> length end )
302+ |> List . to_tuple ( )
303+
304+ t = do_set_shard_metadata ( % { t | shape: shape } , shards )
305+ Expr . metadata ( t , % { shards: shard_propagation . shards } )
306+ else
307+ do_set_shard_metadata ( t , shards )
308+ end
309+
310+ other ->
311+ other
312+ end )
313+ end
314+
315+ defp do_set_shard_metadata ( % T { data: % Expr { args: args } } = expr , shards ) do
316+ args =
317+ Enum . map ( args , fn
318+ % T { } = arg ->
319+ set_shard_metadata ( arg , shards )
320+
321+ arg when is_list ( arg ) ->
322+ Enum . map ( arg , & do_set_shard_metadata ( & 1 , shards ) )
323+
324+ arg ->
325+ arg
326+ end )
327+
328+ put_in ( expr . data . args , args )
329+ end
330+
331+ defp do_set_shard_metadata ( other , _ ) , do: other
263332end
0 commit comments