Skip to content

Commit 4c0c8f9

Browse files
RomeoVChrisRackauckas
authored andcommitted
Start implementing trimming test
We introduce `SafeTestsets` as part of this, inspired by the way the downstream tests are set up in the `SciMLBase` repo.
1 parent a10d495 commit 4c0c8f9

File tree

5 files changed

+125
-7
lines changed

5 files changed

+125
-7
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
148148
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
149149
SciMLLogging = "a6db7da4-7206-11f0-1eab-35f2a5dbe1d1"
150150
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
151+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
151152
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
152153
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
153154
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
@@ -159,4 +160,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
159160
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
160161

161162
[targets]
162-
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker", "SciMLLogging"]
163+
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "PolyesterForwardDiff", "Random", "ReTestItems", "SafeTestsets", "SIAMFANLEquations", "SparseConnectivityTracer", "SparseMatrixColorings", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote", "ReverseDiff", "Tracker", "SciMLLogging"]

test/runtests.jl

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
using ReTestItems, NonlinearSolve, Hwloc, InteractiveUtils, Pkg
1+
using NonlinearSolve, Hwloc, InteractiveUtils, Pkg
2+
using SafeTestsets
3+
using ReTestItems
24

35
@info sprint(InteractiveUtils.versioninfo)
46

@@ -28,6 +30,13 @@ const GROUP = lowercase(get_from_test_args_or_env("GROUP", "all"))
2830
# To re-enable: change condition to `true` or `VERSION < v"1.13"`
2931
const ENZYME_ENABLED = VERSION < v"1.12"
3032

33+
function activate_trim_env!()
34+
Pkg.activate(abspath(joinpath(dirname(@__FILE__), "trim")))
35+
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
36+
Pkg.instantiate()
37+
return nothing
38+
end
39+
3140
const EXTRA_PKGS = Pkg.PackageSpec[]
3241
if GROUP == "all" || GROUP == "downstream"
3342
push!(EXTRA_PKGS, Pkg.PackageSpec("ModelingToolkit"))
@@ -68,8 +77,22 @@ end
6877

6978
@info "Running tests for group: $(GROUP) with $(RETESTITEMS_NWORKERS) workers"
7079

71-
ReTestItems.runtests(
72-
NonlinearSolve; tags = (GROUP == "all" ? nothing : [Symbol(GROUP)]),
73-
nworkers = RETESTITEMS_NWORKERS, nworker_threads = RETESTITEMS_NWORKER_THREADS,
74-
testitem_timeout = 3600
75-
)
80+
if GROUP != "trim"
81+
ReTestItems.runtests(
82+
NonlinearSolve; tags = (GROUP == "all" ? nothing : [Symbol(GROUP)]),
83+
nworkers = RETESTITEMS_NWORKERS, nworker_threads = RETESTITEMS_NWORKER_THREADS,
84+
testitem_timeout = 3600
85+
)
86+
elseif GROUP == "trim" && VERSION >= v"1.12.0-rc1" # trimming has been introduced in julia 1.12
87+
activate_trim_env!()
88+
@safetestset "Clean implementation (non-trimmable)" begin
89+
using SciMLBase: successful_retcode
90+
include("trim/clean_optimization.jl")
91+
@test successful_retcode(minimize(1.0).retcode)
92+
end
93+
@safetestset "Trimmable implementation" begin
94+
using SciMLBase: successful_retcode
95+
include("trim/trimmable_optimization.jl")
96+
@test successful_retcode(minimize(1.0).retcode)
97+
end
98+
end

test/trim/Project.toml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
[deps]
2+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
3+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
4+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
5+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
6+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
7+
NonlinearSolveFirstOrder = "5959db7a-ea39-4486-b5fe-2dd0bf03d60d"
8+
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
9+
PolyesterWeave = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad"
10+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
11+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
12+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
13+
14+
[sources]
15+
ForwardDiff = {rev = "rv/remove-quote-assert-string-interpolation", url = "https://github.com/RomeoV/ForwardDiff.jl"}
16+
LinearSolve = {rev = "rv/remove-linsolve-forwarddiff-special-path", url = "https://github.com/RomeoV/LinearSolve.jl"}
17+
NonlinearSolveFirstOrder = {path = "../../lib/NonlinearSolveFirstOrder"}
18+
Polyester = {rev = "master", url = "https://github.com/RomeoV/Polyester.jl"}
19+
PolyesterWeave = {rev = "main", url = "https://github.com/RomeoV/PolyesterWeave.jl"}
20+
SciMLBase = {rev = "as/fix-jet-opt", url = "https://github.com/AayushSabharwal/SciMLBase.jl"}
21+
22+
[compat]
23+
ADTypes = "1.15.0"
24+
DiffEqBase = "6.179.0"
25+
LinearAlgebra = "1.12.0"
26+
NonlinearSolveFirstOrder = "1.6.0"
27+
StaticArrays = "1.9.0"

test/trim/clean_optimization.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using NonlinearSolveFirstOrder
2+
using ADTypes: AutoForwardDiff
3+
using ForwardDiff
4+
using LinearAlgebra
5+
using StaticArrays
6+
using LinearSolve
7+
const LS = LinearSolve
8+
9+
function f(u, p)
10+
L, U = cholesky(p.Σ)
11+
rhs = (u .* u .- p.λ)
12+
# there are some issues currently with LinearSolve and triangular matrices,
13+
# so we just make `L` dense here.
14+
linprob = LinearProblem(Matrix(L), rhs)
15+
alg = LS.GenericLUFactorization()
16+
sol = LinearSolve.solve(linprob, alg)
17+
return sol.u
18+
end
19+
20+
struct MyParams{T, M}
21+
λ::T
22+
Σ::M
23+
end
24+
25+
function minimize(x)
26+
autodiff = AutoForwardDiff(; chunksize=1)
27+
alg = TrustRegion(; autodiff, linsolve=LS.CholeskyFactorization())
28+
ps = MyParams(rand(), hermitianpart(rand(2,2)+2I))
29+
prob = NonlinearLeastSquaresProblem{false}(f, rand(2), ps)
30+
sol = solve(prob, alg)
31+
return sol
32+
end
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using NonlinearSolveFirstOrder
2+
using ADTypes: AutoForwardDiff
3+
using ForwardDiff
4+
using LinearAlgebra
5+
using StaticArrays
6+
using LinearSolve
7+
const LS = LinearSolve
8+
9+
function f(u, p)
10+
L, U = cholesky(p.Σ)
11+
rhs = (u .* u .- p.λ)
12+
# there are some issues currently with LinearSolve and triangular matrices,
13+
# so we just make `L` dense here.
14+
linprob = LinearProblem(Matrix(L), rhs)
15+
alg = LS.GenericLUFactorization()
16+
sol = LinearSolve.solve(linprob, alg)
17+
return sol.u
18+
end
19+
20+
struct MyParams{T, M}
21+
λ::T
22+
Σ::M
23+
end
24+
25+
const autodiff = AutoForwardDiff(; chunksize = 1)
26+
const alg = TrustRegion(; autodiff, linsolve = LS.CholeskyFactorization())
27+
const prob = NonlinearLeastSquaresProblem{false}(f, rand(2), MyParams(rand(), hermitianpart(rand(2, 2) + 2I)))
28+
const cache = init(prob, alg)
29+
30+
function minimize(x)
31+
ps = MyParams(x, hermitianpart(rand(2, 2) + 2I))
32+
reinit!(cache, rand(2); p = ps)
33+
solve!(cache)
34+
return cache
35+
end

0 commit comments

Comments
 (0)