From 7957bd646cd20f742f42fab420e025e23cc389bf Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 24 Oct 2025 08:46:21 -0400 Subject: [PATCH 01/10] Skip parameter updates when gradients contain NaN or Inf - Add check before Optimisers.update to detect NaN/Inf in gradients - Skip update but still increment iteration counter when detected - Add warning message (maxlog=10) to inform users - Fixes issue where NaN gradients corrupt all subsequent updates Addresses: https://discourse.julialang.org/t/how-to-ignore-minibatches-with-nan-gradients-optimizing-a-hybrid-lux-model-using-optimization-jl/132615 --- lib/OptimizationOptimisers/src/OptimizationOptimisers.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index b1713244d..21aa17ac0 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -129,7 +129,12 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule break end end - state, θ = Optimisers.update(state, θ, G) + # Skip update if gradient contains NaN or Inf values + if !any(x -> isnan(x) || isinf(x), G) + state, θ = Optimisers.update(state, θ, G) + else + @warn "Skipping parameter update due to NaN or Inf in gradients at iteration $iterations" maxlog=10 + end end cache.progress && @logmsg(LogLevel(-1), "Optimization", _id=progress_id, message="Done", progress=1.0) From 0b17bc958cd6b5550bbd461004419a1e52992d3f Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 24 Oct 2025 08:47:03 -0400 Subject: [PATCH 02/10] Add tests for NaN/Inf gradient handling - Test with custom gradient function that injects NaN periodically - Test with custom gradient function that injects Inf periodically - Verify iterations complete and parameters remain finite - Verify optimizer doesn't crash when encountering bad gradients --- lib/OptimizationOptimisers/test/runtests.jl | 61 +++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index ad754cf74..d347d3b12 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -134,3 +134,64 @@ end @test res.objective < 1e-4 end + +@testset "NaN/Inf gradient handling" begin + # Test that optimizer skips updates when gradients contain NaN or Inf + rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2 + x0 = zeros(2) + _p = [1.0, 100.0] + + # Counter to track gradient evaluations + grad_counter = Ref(0) + + # Custom gradient function that returns NaN on every 5th call + function custom_grad!(G, x, p) + grad_counter[] += 1 + if grad_counter[] % 5 == 0 + # Inject NaN into gradient + G .= NaN + else + # Normal gradient computation + G[1] = -2.0 * (p[1] - x[1]) - 4.0 * p[2] * x[1] * (x[2] - x[1]^2) + G[2] = 2.0 * p[2] * (x[2] - x[1]^2) + end + return nothing + end + + optprob = OptimizationFunction(rosenbrock; grad = custom_grad!) 
+ prob = OptimizationProblem(optprob, x0, _p) + + # Should not throw error and should complete all iterations + sol = solve(prob, Optimisers.Adam(0.01), maxiters = 20, progress = false) + + # Verify solution completed all iterations + @test sol.stats.iterations == 20 + + # Verify parameters are not NaN (would be NaN if updates were applied with NaN gradients) + @test all(!isnan, sol.u) + @test all(isfinite, sol.u) + + # Test with Inf gradients + grad_counter_inf = Ref(0) + function custom_grad_inf!(G, x, p) + grad_counter_inf[] += 1 + if grad_counter_inf[] % 7 == 0 + # Inject Inf into gradient + G .= Inf + else + # Normal gradient computation + G[1] = -2.0 * (p[1] - x[1]) - 4.0 * p[2] * x[1] * (x[2] - x[1]^2) + G[2] = 2.0 * p[2] * (x[2] - x[1]^2) + end + return nothing + end + + optprob_inf = OptimizationFunction(rosenbrock; grad = custom_grad_inf!) + prob_inf = OptimizationProblem(optprob_inf, x0, _p) + + sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 20, progress = false) + + @test sol_inf.stats.iterations == 20 + @test all(!isnan, sol_inf.u) + @test all(isfinite, sol_inf.u) +end From 8fd8cef0f0646e037cc9e80846d71774c1137e77 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 24 Oct 2025 08:50:49 -0400 Subject: [PATCH 03/10] Fix NaN/Inf check to handle arrays properly Use any(isnan, G) || any(isinf, G) instead of lambda function to correctly handle array elements and hierarchical structures --- lib/OptimizationOptimisers/src/OptimizationOptimisers.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 21aa17ac0..e54cf9670 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -130,7 +130,8 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule end end # Skip update if gradient contains NaN or Inf values - if !any(x -> isnan(x) || isinf(x), G) + has_nan_or_inf = any(isnan, G) || any(isinf, G) + if !has_nan_or_inf state, θ = Optimisers.update(state, θ, G) else @warn "Skipping parameter update due to NaN or Inf in gradients at iteration $iterations" maxlog=10 From 1aa8780e8553c58a1ad7697b7169d370bc739d41 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 24 Oct 2025 08:53:28 -0400 Subject: [PATCH 04/10] Use broadcasting for NaN/Inf check Use any(.!(isfinite.(G))) to properly handle arrays with broadcasting --- lib/OptimizationOptimisers/src/OptimizationOptimisers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index e54cf9670..dede16bba 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -130,7 +130,7 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule end end # Skip update if gradient contains NaN or Inf values - has_nan_or_inf = any(isnan, G) || any(isinf, G) + has_nan_or_inf = any(.!(isfinite.(G))) if !has_nan_or_inf state, θ = Optimisers.update(state, θ, G) else From 5a2d69289a1551693f745a5fe9d4419de0c9a6b6 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 24 Oct 2025 08:57:03 -0400 Subject: [PATCH 05/10] Add robust NaN/Inf checker using Functors.fmap - Add has_nan_or_inf() helper function - Uses Functors.fmap to recursively check all elements - Handles 
arbitrary nested structures (arrays, ComponentArrays, etc.) - Checks if any element is not finite (catches both NaN and Inf) --- .../src/OptimizationOptimisers.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index dede16bba..2ce28fbe4 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -3,11 +3,24 @@ module OptimizationOptimisers using Reexport, Logging @reexport using Optimisers, OptimizationBase using SciMLBase +using Functors SciMLBase.has_init(opt::AbstractRule) = true SciMLBase.requiresgradient(opt::AbstractRule) = true SciMLBase.allowsfg(opt::AbstractRule) = true +# Helper function to check if gradients contain NaN or Inf +function has_nan_or_inf(x) + result = Ref(false) + Functors.fmap(x) do val + if val isa Number && (!isfinite(val)) + result[] = true + end + return val + end + return result[] +end + function SciMLBase.__init( prob::SciMLBase.OptimizationProblem, opt::AbstractRule; callback = (args...) -> (false), @@ -130,8 +143,7 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule end end # Skip update if gradient contains NaN or Inf values - has_nan_or_inf = any(.!(isfinite.(G))) - if !has_nan_or_inf + if !has_nan_or_inf(G) state, θ = Optimisers.update(state, θ, G) else @warn "Skipping parameter update due to NaN or Inf in gradients at iteration $iterations" maxlog=10 From 0cf9eecb3cf123da06256ee09311e55efdb81416 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 24 Oct 2025 08:57:44 -0400 Subject: [PATCH 06/10] Add Functors as dependency for NaN/Inf checking --- lib/OptimizationOptimisers/Project.toml | 41 +++++++++++++------------ 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml index d5a921f13..b3d164e26 100644 --- a/lib/OptimizationOptimisers/Project.toml +++ b/lib/OptimizationOptimisers/Project.toml @@ -4,33 +4,36 @@ authors = ["Vaibhav Dixit and contributors"] version = "0.3.13" [deps] -OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" -SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" +OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -[extras] -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" -MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +[sources] +OptimizationBase = {path = "../OptimizationBase"} [compat] -julia = "1.10" -OptimizationBase = "4" -SciMLBase = "2.122.1" +Functors = "0.4, 0.5" +Logging = "1.10" Optimisers = "0.2, 0.3, 0.4" +OptimizationBase = "4" Reexport = "1.2" -Logging = "1.10" +SciMLBase = "2.122.1" +julia = "1.10" -[sources] -OptimizationBase = {path = 
"../OptimizationBase"} +[extras] +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote", "Printf"] From cc69c5dd97f8f3a1188ee884d53ba25497b814b1 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 24 Oct 2025 09:01:43 -0400 Subject: [PATCH 07/10] Fix test to use callback approach for NaN/Inf injection - Use Zygote for gradient computation - Inject NaN/Inf via callback that modifies state.grad - This better simulates real-world scenarios where autodiff produces NaN - Avoids issues with custom gradient function signatures --- lib/OptimizationOptimisers/test/runtests.jl | 45 +++++++++------------ 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index d347d3b12..3f86246f2 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -141,28 +141,26 @@ end x0 = zeros(2) _p = [1.0, 100.0] - # Counter to track gradient evaluations + # Test with NaN gradients using Zygote + # We'll use a callback to inject NaN into some iterations grad_counter = Ref(0) - # Custom gradient function that returns NaN on every 5th call - function custom_grad!(G, x, p) + # Create optimization problem with automatic differentiation + optprob = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote()) + prob = OptimizationProblem(optprob, x0, _p) + + # Use a callback that modifies the gradient to inject NaN periodically + function nan_callback(state, l) grad_counter[] += 1 if grad_counter[] % 5 == 0 - # Inject NaN into gradient - G .= NaN - else - # Normal gradient computation - G[1] = -2.0 * (p[1] - x[1]) - 4.0 * p[2] * x[1] * (x[2] - x[1]^2) - G[2] = 2.0 * p[2] * (x[2] - x[1]^2) + # Inject NaN into gradient on every 5th iteration + state.grad .= NaN end - return nothing + return false end - optprob = OptimizationFunction(rosenbrock; grad = custom_grad!) - prob = OptimizationProblem(optprob, x0, _p) - # Should not throw error and should complete all iterations - sol = solve(prob, Optimisers.Adam(0.01), maxiters = 20, progress = false) + sol = solve(prob, Optimisers.Adam(0.01), maxiters = 20, progress = false, callback = nan_callback) # Verify solution completed all iterations @test sol.stats.iterations == 20 @@ -173,23 +171,18 @@ end # Test with Inf gradients grad_counter_inf = Ref(0) - function custom_grad_inf!(G, x, p) + prob_inf = OptimizationProblem(optprob, x0, _p) + + function inf_callback(state, l) grad_counter_inf[] += 1 if grad_counter_inf[] % 7 == 0 - # Inject Inf into gradient - G .= Inf - else - # Normal gradient computation - G[1] = -2.0 * (p[1] - x[1]) - 4.0 * p[2] * x[1] * (x[2] - x[1]^2) - G[2] = 2.0 * p[2] * (x[2] - x[1]^2) + # Inject Inf into gradient on every 7th iteration + state.grad .= Inf end - return nothing + return false end - optprob_inf = OptimizationFunction(rosenbrock; grad = custom_grad_inf!) 
- prob_inf = OptimizationProblem(optprob_inf, x0, _p) - - sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 20, progress = false) + sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 20, progress = false, callback = inf_callback) @test sol_inf.stats.iterations == 20 @test all(!isnan, sol_inf.u) From 0295dcfe0d9a363c57ff48e7b41d83a9ac14fd04 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 24 Oct 2025 09:05:41 -0400 Subject: [PATCH 08/10] Apply JuliaFormatter with SciMLStyle and fix Project.toml - Apply SciMLStyle formatting to OptimizationOptimisers.jl - Remove JuliaFormatter from runtime dependencies --- lib/OptimizationOptimisers/src/OptimizationOptimisers.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 2ce28fbe4..1e94a99f0 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -80,6 +80,7 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule breakall = false progress_id = :OptimizationOptimizersJL for epoch in 1:epochs, d in data + if cache.f.fg !== nothing && dataiterate x = cache.f.fg(G, θ, d) iterations += 1 @@ -119,7 +120,7 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule if cache.progress message = "Loss: $(round(first(first(x)); digits = 3))" @logmsg(LogLevel(-1), "Optimization", _id=progress_id, - message=message, progress=iterations / maxiters) + message=message, progress=iterations/maxiters) end if cache.solver_args.save_best if first(x)[1] < first(min_err)[1] #found a better solution From 9ae783984a664eaa2381e3a59136af4bbd70fe8c Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 24 Oct 2025 09:11:12 -0400 Subject: [PATCH 09/10] Address PR feedback - Remove Functors dependency and use simple all(isfinite, G) check - Make warning conditional on cache.progress flag - Rewrite tests to use functions that return NaN/Inf in certain regions instead of callback-based approach --- lib/OptimizationOptimisers/Project.toml | 2 - .../src/OptimizationOptimisers.jl | 17 +----- lib/OptimizationOptimisers/test/runtests.jl | 54 ++++++++----------- 3 files changed, 25 insertions(+), 48 deletions(-) diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml index b3d164e26..05e07b123 100644 --- a/lib/OptimizationOptimisers/Project.toml +++ b/lib/OptimizationOptimisers/Project.toml @@ -4,7 +4,6 @@ authors = ["Vaibhav Dixit and contributors"] version = "0.3.13" [deps] -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" @@ -16,7 +15,6 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" OptimizationBase = {path = "../OptimizationBase"} [compat] -Functors = "0.4, 0.5" Logging = "1.10" Optimisers = "0.2, 0.3, 0.4" OptimizationBase = "4" diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 1e94a99f0..2476e5743 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -3,24 +3,11 @@ module OptimizationOptimisers using Reexport, Logging @reexport using Optimisers, OptimizationBase using SciMLBase -using Functors 
SciMLBase.has_init(opt::AbstractRule) = true SciMLBase.requiresgradient(opt::AbstractRule) = true SciMLBase.allowsfg(opt::AbstractRule) = true -# Helper function to check if gradients contain NaN or Inf -function has_nan_or_inf(x) - result = Ref(false) - Functors.fmap(x) do val - if val isa Number && (!isfinite(val)) - result[] = true - end - return val - end - return result[] -end - function SciMLBase.__init( prob::SciMLBase.OptimizationProblem, opt::AbstractRule; callback = (args...) -> (false), @@ -144,9 +131,9 @@ function SciMLBase.__solve(cache::OptimizationCache{O}) where {O <: AbstractRule end end # Skip update if gradient contains NaN or Inf values - if !has_nan_or_inf(G) + if all(isfinite, G) state, θ = Optimisers.update(state, θ, G) - else + elseif cache.progress @warn "Skipping parameter update due to NaN or Inf in gradients at iteration $iterations" maxlog=10 end end diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index 3f86246f2..cc40c103b 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -137,54 +137,46 @@ end @testset "NaN/Inf gradient handling" begin # Test that optimizer skips updates when gradients contain NaN or Inf - rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2 + # Function that returns NaN when parameters are in certain regions + function weird_nan_function(x, p) + # Return NaN when x[1] is close to certain values to simulate numerical issues + if abs(x[1] - 0.3) < 0.05 || abs(x[1] + 0.3) < 0.05 + return NaN + end + return (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2 + end + x0 = zeros(2) _p = [1.0, 100.0] - # Test with NaN gradients using Zygote - # We'll use a callback to inject NaN into some iterations - grad_counter = Ref(0) - - # Create optimization problem with automatic differentiation - optprob = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote()) + optprob = OptimizationFunction(weird_nan_function, OptimizationBase.AutoZygote()) prob = OptimizationProblem(optprob, x0, _p) - # Use a callback that modifies the gradient to inject NaN periodically - function nan_callback(state, l) - grad_counter[] += 1 - if grad_counter[] % 5 == 0 - # Inject NaN into gradient on every 5th iteration - state.grad .= NaN - end - return false - end - # Should not throw error and should complete all iterations - sol = solve(prob, Optimisers.Adam(0.01), maxiters = 20, progress = false, callback = nan_callback) + sol = solve(prob, Optimisers.Adam(0.01), maxiters = 50, progress = false) # Verify solution completed all iterations - @test sol.stats.iterations == 20 + @test sol.stats.iterations == 50 # Verify parameters are not NaN (would be NaN if updates were applied with NaN gradients) @test all(!isnan, sol.u) @test all(isfinite, sol.u) - # Test with Inf gradients - grad_counter_inf = Ref(0) - prob_inf = OptimizationProblem(optprob, x0, _p) - - function inf_callback(state, l) - grad_counter_inf[] += 1 - if grad_counter_inf[] % 7 == 0 - # Inject Inf into gradient on every 7th iteration - state.grad .= Inf + # Function that returns Inf when parameters are in certain regions + function weird_inf_function(x, p) + # Return Inf when x[1] is close to certain values + if abs(x[1] - 0.2) < 0.05 || abs(x[1] + 0.2) < 0.05 + return Inf end - return false + return (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2 end - sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 20, progress = false, callback = inf_callback) + optprob_inf = 
OptimizationFunction(weird_inf_function, OptimizationBase.AutoZygote()) + prob_inf = OptimizationProblem(optprob_inf, x0, _p) + + sol_inf = solve(prob_inf, Optimisers.Adam(0.01), maxiters = 50, progress = false) - @test sol_inf.stats.iterations == 20 + @test sol_inf.stats.iterations == 50 @test all(!isnan, sol_inf.u) @test all(isfinite, sol_inf.u) end From 16140e65c2bdca8cbe206731ab73f33634e1714e Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Fri, 24 Oct 2025 09:13:45 -0400 Subject: [PATCH 10/10] Update test functions to produce NaN/Inf gradients naturally - Use sqrt and max to produce NaN when x goes negative - Use 1/x pattern to produce Inf gradients - Functions naturally produce problematic gradients during optimization --- lib/OptimizationOptimisers/test/runtests.jl | 24 ++++++++++----------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index cc40c103b..269d01932 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -137,16 +137,15 @@ end @testset "NaN/Inf gradient handling" begin # Test that optimizer skips updates when gradients contain NaN or Inf - # Function that returns NaN when parameters are in certain regions + # Function that can produce NaN due to sqrt of negative number function weird_nan_function(x, p) - # Return NaN when x[1] is close to certain values to simulate numerical issues - if abs(x[1] - 0.3) < 0.05 || abs(x[1] + 0.3) < 0.05 - return NaN - end - return (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2 + val = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2 + # sqrt of a value that can become negative produces NaN + val += sqrt(max(x[1], 0.0)) * 0.01 + return val end - x0 = zeros(2) + x0 = [-0.5, 0.1] # Start with negative x[1] to trigger sqrt of negative _p = [1.0, 100.0] optprob = OptimizationFunction(weird_nan_function, OptimizationBase.AutoZygote()) @@ -162,13 +161,12 @@ end @test all(!isnan, sol.u) @test all(isfinite, sol.u) - # Function that returns Inf when parameters are in certain regions + # Function with 1/x that can produce Inf gradient when x is very small function weird_inf_function(x, p) - # Return Inf when x[1] is close to certain values - if abs(x[1] - 0.2) < 0.05 || abs(x[1] + 0.2) < 0.05 - return Inf - end - return (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2 + val = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2 + # 1/(x[1] + 0.01) can have very large gradient near x[1] = -0.01 + val += 0.01 / (abs(x[1] - 0.1) + 1e-8) + return val end optprob_inf = OptimizationFunction(weird_inf_function, OptimizationBase.AutoZygote())
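
For reference, a minimal standalone sketch of the guard this series converges on (the `all(isfinite, G)` check introduced in patch 09): the parameter update is applied only when every gradient entry is finite; otherwise the step is skipped while the iteration counter keeps advancing. The `loss_grad!` and `demo` helpers below are illustrative assumptions, not part of Optimization.jl or Optimisers.jl; only `Optimisers.Adam`, `Optimisers.setup`, and `Optimisers.update` are the real API calls used in the diffs above.

using Optimisers

# Hypothetical gradient helper: every 5th call returns NaN, standing in for a
# "bad" minibatch that corrupts the gradient (illustration only).
function loss_grad!(G, θ, iter)
    if iter % 5 == 0
        G .= NaN                      # simulate a corrupted gradient
    else
        G .= 2 .* (θ .- 1.0)          # gradient of sum((θ .- 1.0).^2)
    end
    return nothing
end

function demo(; iters = 100)
    θ = zeros(2)
    G = similar(θ)
    state = Optimisers.setup(Optimisers.Adam(0.01), θ)
    for iter in 1:iters
        loss_grad!(G, θ, iter)
        if all(isfinite, G)           # same guard as the final patch
            state, θ = Optimisers.update(state, θ, G)
        end                           # otherwise skip the step but keep iterating
    end
    return θ
end

θ_final = demo()
@assert all(isfinite, θ_final)

Running `demo()` mimics a training loop in which every fifth minibatch produces a NaN gradient; the final parameters stay finite because those steps are simply skipped rather than applied, which is the behavior the test sets in patches 02, 07, and 10 are probing from different directions.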