Skip to content

Commit 1746da1

Browse files
committed
Add tests
1 parent f2eb1c2 commit 1746da1

File tree

3 files changed

+47
-24
lines changed

3 files changed

+47
-24
lines changed

src/Nonlinear/ReverseAD/Coloring/Coloring.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
module Coloring
88

9-
include("Forest.jl")
9+
include("IntDisjointSet.jl")
1010
include("topological_sort.jl")
1111

1212
"""
@@ -176,7 +176,7 @@ function _grow_star(v, w, e_idx, firstNeighbor, color, S)
176176
return
177177
end
178178

179-
function _merge_trees(S::_Forest, eg::Int, eg1::Int)
179+
function _merge_trees(S::_IntDisjointSet, eg::Int, eg1::Int)
180180
if _find_root!(S, eg) != _find_root!(S, eg1)
181181
_root_union!(S, eg, eg1)
182182
end
@@ -202,7 +202,7 @@ function acyclic_coloring(g::UndirectedGraph)
202202
firstNeighbor = _Edge[]
203203
firstVisitToTree = fill(_Edge(0, 0, 0), _num_edges(g))
204204
color = fill(0, _num_vertices(g))
205-
S = _Forest(_num_edges(g))
205+
S = _IntDisjointSet(_num_edges(g))
206206
@inbounds for v in 1:_num_vertices(g)
207207
n_neighbor = _num_neighbors(v, g)
208208
start_neighbor = _start_neighbors(v, g)

src/Nonlinear/ReverseAD/Coloring/Forest.jl renamed to src/Nonlinear/ReverseAD/Coloring/IntDisjointSet.jl

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,46 +6,42 @@
66
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
77

88
# The code in this file was taken from
9-
# https://github.com/gdalle/SparseMatrixColorings.jl/blob/main/src/forest.jl
9+
# https://github.com/gdalle/SparseMatrixColorings.jl/blob/main/src/Forest.jl
1010
#
1111
# It was copied at the suggestion of Alexis in his JuMP-dev 2025 talk.
1212
#
1313
# @odow made minor changes to match MOI coding styles.
1414
#
1515
# x-ref https://github.com/gdalle/SparseMatrixColorings.jl/pull/190
1616

17-
mutable struct _Forest
18-
# current number of distinct trees in the forest
17+
mutable struct _IntDisjointSet
18+
# current number of distinct trees in the S
1919
number_of_trees::Int
2020
# vector storing the index of a parent in the tree for each edge, used in
2121
# union-find operations
2222
parents::Vector{Int}
2323
# vector approximating the depth of each tree to optimize path compression
2424
ranks::Vector{Int}
2525

26-
_Forest(n::Integer) = new(n, collect(Base.OneTo(n)), zeros(Int, n))
26+
_IntDisjointSet(n::Integer) = new(n, collect(1:n), zeros(Int, n))
2727
end
2828

29-
function _find_root!(parents::Vector{Int}, index_edge::Integer)
30-
p = parents[index_edge]
31-
if parents[p] != p
32-
parents[index_edge] = p = _find_root!(parents, p)
29+
function _find_root!(S::_IntDisjointSet, x::Integer)
30+
p = S.parents[x]
31+
if S.parents[p] != p
32+
S.parents[x] = p = _find_root!(S, p)
3333
end
3434
return p
3535
end
3636

37-
function _find_root!(forest::_Forest, index_edge::Integer)
38-
return _find_root!(forest.parents, index_edge)
39-
end
40-
41-
function _root_union!(forest::_Forest, index_edge1::Int, index_edge2::Int)
42-
rank1, rank2 = forest.ranks[index_edge1], forest.ranks[index_edge2]
37+
function _root_union!(S::_IntDisjointSet, x::Int, y::Int)
38+
rank1, rank2 = S.ranks[x], S.ranks[y]
4339
if rank1 < rank2
44-
index_edge1, index_edge2 = index_edge2, index_edge1
40+
x, y = y, x
4541
elseif rank1 == rank2
46-
forest.ranks[index_edge1] += 1
42+
S.ranks[x] += 1
4743
end
48-
forest.parents[index_edge2] = index_edge1
49-
forest.number_of_trees -= 1
44+
S.parents[y] = x
45+
S.number_of_trees -= 1
5046
return
5147
end

test/Nonlinear/ReverseAD.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ import LinearAlgebra
1111
import MathOptInterface as MOI
1212
import SparseArrays
1313

14-
const Nonlinear = MOI.Nonlinear
15-
const ReverseAD = Nonlinear.ReverseAD
16-
const Coloring = ReverseAD.Coloring
14+
import MathOptInterface.Nonlinear
15+
import MathOptInterface.Nonlinear.ReverseAD
16+
import MathOptInterface.Nonlinear.ReverseAD.Coloring
1717

1818
function runtests()
1919
for name in names(@__MODULE__; all = true)
@@ -1421,6 +1421,33 @@ function test_hessian_reinterpret_unsafe()
14211421
return
14221422
end
14231423

1424+
function test_IntDisjointSet()
1425+
for case in [
1426+
[(1, 2) => [1, 1, 3], (1, 3) => [1, 1, 1]],
1427+
[(1, 2) => [1, 1, 3], (3, 1) => [1, 1, 1]],
1428+
[(2, 1) => [2, 2, 3], (1, 3) => [2, 2, 2]],
1429+
[(2, 1) => [2, 2, 3], (3, 1) => [3, 2, 3]],
1430+
[(1, 3) => [1, 2, 1], (2, 3) => [1, 2, 2]],
1431+
[(1, 3) => [1, 2, 1], (3, 2) => [1, 1, 1]],
1432+
[(3, 1) => [3, 2, 3], (2, 3) => [3, 3, 3]],
1433+
[(3, 1) => [3, 2, 3], (3, 2) => [3, 3, 3]],
1434+
[(2, 3) => [1, 2, 2], (1, 3) => [1, 2, 1]],
1435+
[(2, 3) => [1, 2, 2], (3, 1) => [2, 2, 2]],
1436+
[(3, 2) => [1, 3, 3], (1, 3) => [3, 3, 3]],
1437+
[(3, 2) => [1, 3, 3], (3, 1) => [3, 3, 3]],
1438+
]
1439+
S = Coloring._IntDisjointSet(3)
1440+
@test Coloring._find_root!.((S,), [1, 2, 3]) == [1, 2, 3]
1441+
@test S.number_of_trees == 3
1442+
for (i, (union, result)) in enumerate(case)
1443+
Coloring._root_union!(S, union[1], union[2])
1444+
@test Coloring._find_root!.((S,), [1, 2, 3]) == result
1445+
@test S.number_of_trees == 3 - i
1446+
end
1447+
end
1448+
return
1449+
end
1450+
14241451
end # module
14251452

14261453
TestReverseAD.runtests()

0 commit comments

Comments
 (0)