Skip to content

Commit 76ef352

Browse files
committed
refactor: remove topsort because expr chain is already sorted as it's built
1 parent 384c770 commit 76ef352

File tree

2 files changed

+2
-49
lines changed

2 files changed

+2
-49
lines changed

nx/lib/nx/defn/graph_splitter.ex

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,7 @@ defmodule Nx.Defn.GraphSplitter do
138138
end
139139
)
140140

141-
# Apply topological sort to the expr_chain
142-
sorted_expr_chain = topological_sort(expr_chain)
143-
144-
{sorted_expr_chain, cache, Map.delete(state, :expression_chain)}
141+
{expr_chain, cache, Map.delete(state, :expression_chain)}
145142
end
146143

147144
defp composite_eval(expr, state, cache) do
@@ -289,42 +286,4 @@ defmodule Nx.Defn.GraphSplitter do
289286
end
290287

291288
defp rewrite_subtree(other, _, acc), do: {other, acc}
292-
293-
defp topological_sort(expr_chain) do
294-
# Create a new directed graph
295-
graph = :digraph.new()
296-
297-
# Add vertices for each stage output
298-
Enum.each(expr_chain, fn %Stage{id: id, arguments: arguments} ->
299-
Enum.with_index(arguments, fn _, idx ->
300-
:digraph.add_vertex(graph, {id, idx})
301-
end)
302-
end)
303-
304-
# Add edges based on argument sources
305-
Enum.each(expr_chain, fn %Stage{id: id, argument_sources: sources} ->
306-
Enum.each(sources, fn {_arg_id, {source_stage_id, source_index}} ->
307-
if source_stage_id != nil do
308-
Enum.each(0..(map_size(sources) - 1)//1, fn idx ->
309-
:digraph.add_edge(graph, {source_stage_id, source_index}, {id, idx})
310-
end)
311-
end
312-
end)
313-
end)
314-
315-
# Perform topological sort
316-
sorted_ids =
317-
graph
318-
|> :digraph_utils.topsort()
319-
|> Enum.map(fn {stage_id, _} -> stage_id end)
320-
|> Enum.uniq()
321-
322-
# Clean up the graph
323-
:digraph.delete(graph)
324-
325-
# Return the sorted stages based on sorted_ids
326-
Enum.map(sorted_ids, fn stage_id ->
327-
Enum.find(expr_chain, fn %Stage{id: id} -> id == stage_id end)
328-
end)
329-
end
330289
end

nx/test/nx/defn/graph_splitter_test.exs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -426,18 +426,12 @@ defmodule Nx.Defn.GraphSplitterTest do
426426

427427
chain = GraphSplitter.traverse(expr, split_fn)
428428

429-
assert [root, side1, side2, merge] = chain
429+
assert [root, right, left, merge] = chain
430430

431431
assert {%T{data: %Expr{op: :multiply, args: [arg0, arg1]}}} = root.expr
432432
assert %T{data: %Expr{op: :parameter, args: [0]}} = arg0
433433
assert %T{data: %Expr{op: :parameter, args: [1]}} = arg1
434434

435-
# because things are balanced, we don't know which of side1 and side2 are left and right
436-
# in our expr, so we should disambiguate:
437-
438-
{[%Stage{} = left], [%Stage{} = right]} =
439-
Enum.split_with([side1, side2], fn %Stage{expr: {expr}} -> expr.data.op == :multiply end)
440-
441435
# left should depend on exactly the same parameters as the root, as it's pulling from
442436
# the global scope
443437
assert {%T{data: %Expr{op: :multiply, args: [x, arg1_left]}}} = left.expr

0 commit comments

Comments
 (0)