@@ -86,7 +86,10 @@ defmodule Nx.Defn.GraphSplitter do
8686 end
8787 )
8888
89- { expr_chain , cache , Map . delete ( state , :expression_chain ) }
89+ # Apply topological sort to the expr_chain
90+ sorted_expr_chain = topological_sort ( expr_chain )
91+
92+ { sorted_expr_chain , cache , Map . delete ( state , :expression_chain ) }
9093 end
9194
9295 defp composite_eval ( expr , state , cache ) do
@@ -234,4 +237,39 @@ defmodule Nx.Defn.GraphSplitter do
234237 end
235238
236239 defp rewrite_subtree ( other , _ , acc ) , do: { other , acc }
240+
241+ defp topological_sort ( expr_chain ) do
242+ # Create a new directed graph
243+ graph = :digraph . new ( )
244+
245+ # Add vertices for each stage output
246+ Enum . each ( expr_chain , fn % Stage { id: id , arguments: arguments } ->
247+ Enum . with_index ( arguments , fn _ , idx ->
248+ :digraph . add_vertex ( graph , { id , idx } )
249+ end )
250+ end )
251+
252+ # Add edges based on argument sources
253+ Enum . each ( expr_chain , fn % Stage { id: id , argument_sources: sources } ->
254+ Enum . each ( sources , fn { _arg_id , { source_stage_id , source_index } } ->
255+ if source_stage_id != nil do
256+ Enum . each ( 0 .. ( map_size ( sources ) - 1 ) // 1 , fn idx ->
257+ :digraph . add_edge ( graph , { source_stage_id , source_index } , { id , idx } )
258+ end )
259+ end
260+ end )
261+ end )
262+
263+ # Perform topological sort
264+ sorted_ids =
265+ :digraph_utils . topsort ( graph ) |> Enum . map ( fn { stage_id , _ } -> stage_id end ) |> Enum . uniq ( )
266+
267+ # Clean up the graph
268+ :digraph . delete ( graph )
269+
270+ # Return the sorted stages based on sorted_ids
271+ Enum . map ( sorted_ids , fn stage_id ->
272+ Enum . find ( expr_chain , fn % Stage { id: id } -> id == stage_id end )
273+ end )
274+ end
237275end
0 commit comments