Skip to content

Commit 84db119

Browse files
authored
[Nonlinear.ReverseAd.Coloring] fix acyclic coloring algorithm (#2898)
1 parent 7895482 commit 84db119

File tree

3 files changed

+23
-9
lines changed

3 files changed

+23
-9
lines changed

src/Nonlinear/ReverseAD/Coloring/Coloring.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,9 @@ function _grow_star(v, w, e_idx, firstNeighbor, color, S)
171171
@inbounds if p != v
172172
firstNeighbor[color[w]] = _Edge(e_idx, v, w)
173173
else
174-
_root_union!(S, e_idx, e.index)
175-
end
176-
return
177-
end
178-
179-
function _merge_trees(S::_IntDisjointSet, eg::Int, eg1::Int)
180-
if _find_root!(S, eg) != _find_root!(S, eg1)
181-
_root_union!(S, eg, eg1)
174+
root1 = _find_root!(S, e_idx)
175+
root2 = _find_root!(S, e.index)
176+
_root_union!(S, root1, root2)
182177
end
183178
return
184179
end
@@ -280,7 +275,7 @@ function acyclic_coloring(g::UndirectedGraph)
280275
continue
281276
end
282277
if color[x] == color[v]
283-
_merge_trees(S, e_idx, e2_idx)
278+
_union!(S, e_idx, e2_idx)
284279
end
285280
end
286281
end

src/Nonlinear/ReverseAD/Coloring/IntDisjointSet.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,12 @@ function _root_union!(S::_IntDisjointSet, x::Int, y::Int)
4545
S.number_of_trees -= 1
4646
return
4747
end
48+
49+
function _union!(S, x::Int, y::Int)
50+
root_x = _find_root!(S, x)
51+
root_y = _find_root!(S, y)
52+
if root_x != root_y
53+
_root_union!(S, root_x, root_y)
54+
end
55+
return
56+
end

test/Nonlinear/ReverseAD.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1448,6 +1448,16 @@ function test_IntDisjointSet()
14481448
return
14491449
end
14501450

1451+
function test_issue_2897()
1452+
I = [4, 5, 4, 6, 5, 6]
1453+
J = [2, 1, 1, 2, 3, 3]
1454+
g = Coloring.UndirectedGraph(I, J, length(I))
1455+
color, num_colors = Coloring.acyclic_coloring(g)
1456+
@test color == [1, 1, 1, 2, 2, 3]
1457+
@test num_colors == 3
1458+
return
1459+
end
1460+
14511461
end # module
14521462

14531463
TestReverseAD.runtests()

0 commit comments

Comments
 (0)