Skip to content

Commit 9a5d84a

Browse files
committed
feat: apply topsort to expr chain
1 parent f7bb062 commit 9a5d84a

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

nx/lib/nx/defn/graph_splitter.ex

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ defmodule Nx.Defn.GraphSplitter do
8686
end
8787
)
8888

89-
{expr_chain, cache, Map.delete(state, :expression_chain)}
89+
# Apply topological sort to the expr_chain
90+
sorted_expr_chain = topological_sort(expr_chain)
91+
92+
{sorted_expr_chain, cache, Map.delete(state, :expression_chain)}
9093
end
9194

9295
defp composite_eval(expr, state, cache) do
@@ -234,4 +237,39 @@ defmodule Nx.Defn.GraphSplitter do
234237
end
235238

236239
defp rewrite_subtree(other, _, acc), do: {other, acc}
240+
241+
defp topological_sort(expr_chain) do
242+
# Create a new directed graph
243+
graph = :digraph.new()
244+
245+
# Add vertices for each stage output
246+
Enum.each(expr_chain, fn %Stage{id: id, arguments: arguments} ->
247+
Enum.with_index(arguments, fn _, idx ->
248+
:digraph.add_vertex(graph, {id, idx})
249+
end)
250+
end)
251+
252+
# Add edges based on argument sources
253+
Enum.each(expr_chain, fn %Stage{id: id, argument_sources: sources} ->
254+
Enum.each(sources, fn {_arg_id, {source_stage_id, source_index}} ->
255+
if source_stage_id != nil do
256+
Enum.each(0..(map_size(sources) - 1)//1, fn idx ->
257+
:digraph.add_edge(graph, {source_stage_id, source_index}, {id, idx})
258+
end)
259+
end
260+
end)
261+
end)
262+
263+
# Perform topological sort
264+
sorted_ids =
265+
:digraph_utils.topsort(graph) |> Enum.map(fn {stage_id, _} -> stage_id end) |> Enum.uniq()
266+
267+
# Clean up the graph
268+
:digraph.delete(graph)
269+
270+
# Return the sorted stages based on sorted_ids
271+
Enum.map(sorted_ids, fn stage_id ->
272+
Enum.find(expr_chain, fn %Stage{id: id} -> id == stage_id end)
273+
end)
274+
end
237275
end

0 commit comments

Comments
 (0)