From 89e2d0c473f84734e26d8157cbc7897a72db439f Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 14 Oct 2025 22:45:01 +0330 Subject: [PATCH 1/9] Rewrite the progressbar part of `OptimizationOptimisers` --- lib/OptimizationOptimisers/Project.toml | 5 +- .../src/OptimizationOptimisers.jl | 118 +++++++++--------- 2 files changed, 60 insertions(+), 63 deletions(-) diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml index 2c65ea722..a9f079f09 100644 --- a/lib/OptimizationOptimisers/Project.toml +++ b/lib/OptimizationOptimisers/Project.toml @@ -5,10 +5,10 @@ version = "0.3.12" [deps] OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" -Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [extras] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -19,6 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" [compat] julia = "1.10" @@ -29,4 +30,4 @@ Optimisers = "0.2, 0.3, 0.4" Reexport = "1.2" [targets] -test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote"] +test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote", "Printf"] diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index a86b23ac8..9e51e5992 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -1,6 +1,6 @@ module OptimizationOptimisers -using Reexport, Printf, ProgressLogging +using Reexport, ProgressLogging, UUIDs @reexport using Optimisers, OptimizationBase using SciMLBase @@ -95,77 +95,73 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{ gevals = 0 t0 = time() breakall = false - begin - for epoch in 1:epochs - if breakall - break + progress_id = uuid4() + for epoch in 1:epochs, d in data + if cache.f.fg !== nothing && dataiterate + x = cache.f.fg(G, θ, d) + iterations += 1 + fevals += 1 + gevals += 1 + elseif dataiterate + cache.f.grad(G, θ, d) + x = cache.f(θ, d) + iterations += 1 + fevals += 2 + gevals += 1 + elseif cache.f.fg !== nothing + x = cache.f.fg(G, θ) + iterations += 1 + fevals += 1 + gevals += 1 + else + cache.f.grad(G, θ) + x = cache.f(θ) + iterations += 1 + fevals += 2 + gevals += 1 + end + opt_state = OptimizationBase.OptimizationState( + iter = iterations, + u = θ, + p = d, + objective = x[1], + grad = G, + original = state) + breakall = cache.callback(opt_state, x...) + if !(breakall isa Bool) + error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.") + elseif breakall + break + end + cache.progress && + @info ProgressLogging.Progress(progress_id, iterations / maxiters; + name = "loss: $(round(first(first(x)); digits=3))") + + if cache.solver_args.save_best + if first(x)[1] < first(min_err)[1] #found a better solution + min_opt = opt + min_err = x + min_θ = copy(θ) end - for (i, d) in enumerate(data) - if cache.f.fg !== nothing && dataiterate - x = cache.f.fg(G, θ, d) - iterations += 1 - fevals += 1 - gevals += 1 - elseif dataiterate - cache.f.grad(G, θ, d) - x = cache.f(θ, d) - iterations += 1 - fevals += 2 - gevals += 1 - elseif cache.f.fg !== nothing - x = cache.f.fg(G, θ) - iterations += 1 - fevals += 1 - gevals += 1 - else - cache.f.grad(G, θ) - x = cache.f(θ) - iterations += 1 - fevals += 2 - gevals += 1 - end - opt_state = OptimizationBase.OptimizationState( - iter = i + (epoch - 1) * length(data), + if iterations == length(data) * epochs #Last iter, revert to best. + opt = min_opt + x = min_err + θ = min_θ + cache.f.grad(G, θ, d) + opt_state = OptimizationBase.OptimizationState(iter = iterations, u = θ, p = d, objective = x[1], grad = G, original = state) breakall = cache.callback(opt_state, x...) - if !(breakall isa Bool) - error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.") - elseif breakall - break - end - msg = @sprintf("loss: %.3g", first(x)[1]) - cache.progress && ProgressLogging.@logprogress msg iterations/maxiters - - if cache.solver_args.save_best - if first(x)[1] < first(min_err)[1] #found a better solution - min_opt = opt - min_err = x - min_θ = copy(θ) - end - if iterations == length(data) * epochs #Last iter, revert to best. - opt = min_opt - x = min_err - θ = min_θ - cache.f.grad(G, θ, d) - opt_state = OptimizationBase.OptimizationState(iter = iterations, - u = θ, - p = d, - objective = x[1], - grad = G, - original = state) - breakall = cache.callback(opt_state, x...) - break - end - end - state, θ = Optimisers.update(state, θ, G) + break end end + state, θ = Optimisers.update(state, θ, G) end + cache.progress && @info ProgressLogging.Progress(progress_id; done = true) t1 = time() stats = OptimizationBase.OptimizationStats(; iterations, time = t1 - t0, fevals, gevals) From ee9f9f9465b4fd2ade355892bcf9fa409988a099 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 19 Oct 2025 14:11:54 +0330 Subject: [PATCH 2/9] comply with ode progress log --- Project.toml | 4 +--- lib/OptimizationOptimisers/Project.toml | 2 -- .../src/OptimizationOptimisers.jl | 12 ++++++------ src/Optimization.jl | 2 +- 4 files changed, 8 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index a522b5b68..00af684ea 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" -ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -58,7 +57,6 @@ OptimizationOptimisers = "0.3" OrdinaryDiffEqTsit5 = "1" Pkg = "1" Printf = "1.10" -ProgressLogging = "0.1" Random = "1.10" Reexport = "1.2" ReverseDiff = "1" @@ -109,6 +107,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [targets] test = ["Aqua", "BenchmarkTools", "Boltz", "ComponentArrays", "DiffEqFlux", "Enzyme", "FiniteDiff", "Flux", "ForwardDiff", - "Ipopt", "IterTools", "Lux", "MLUtils", "ModelingToolkit", "Optim", "OptimizationLBFGSB", "OptimizationMOI", "OptimizationOptimJL", "OptimizationOptimisers", + "Ipopt", "IterTools", "Lux", "MLUtils", "ModelingToolkit", "Optim", "OptimizationLBFGSB", "OptimizationMOI", "OptimizationOptimJL", "OptimizationOptimisers", "OrdinaryDiffEqTsit5", "Pkg", "Random", "ReverseDiff", "SafeTestsets", "SciMLSensitivity", "SparseArrays", "Symbolics", "Test", "Tracker", "Zygote", "Mooncake"] diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml index a9f079f09..70e4ca961 100644 --- a/lib/OptimizationOptimisers/Project.toml +++ b/lib/OptimizationOptimisers/Project.toml @@ -4,7 +4,6 @@ authors = ["Vaibhav Dixit and contributors"] version = "0.3.12" [deps] OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" -ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -24,7 +23,6 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" [compat] julia = "1.10" OptimizationBase = "3" -ProgressLogging = "0.1" SciMLBase = "2.58" Optimisers = "0.2, 0.3, 0.4" Reexport = "1.2" diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 9e51e5992..4cd14a2d7 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -1,6 +1,6 @@ module OptimizationOptimisers -using Reexport, ProgressLogging, UUIDs +using Reexport, UUIDs @reexport using Optimisers, OptimizationBase using SciMLBase @@ -134,9 +134,9 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{ break end cache.progress && - @info ProgressLogging.Progress(progress_id, iterations / maxiters; - name = "loss: $(round(first(first(x)); digits=3))") - + @logmsg(LogLevel(-1), "Optimization"; + _id = progress_id, message = "Loss: $(round(first(first(x)); digits=3))", + progress = iterations / maxiters) if cache.solver_args.save_best if first(x)[1] < first(min_err)[1] #found a better solution min_opt = opt @@ -160,8 +160,8 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{ end state, θ = Optimisers.update(state, θ, G) end - - cache.progress && @info ProgressLogging.Progress(progress_id; done = true) + cache.progress && @logmsg(LogLevel(-1), "Optimization"; + _id = progress_id, message = "Done", progress = 1.0) t1 = time() stats = OptimizationBase.OptimizationStats(; iterations, time = t1 - t0, fevals, gevals) diff --git a/src/Optimization.jl b/src/Optimization.jl index e419377ca..681ce22c8 100644 --- a/src/Optimization.jl +++ b/src/Optimization.jl @@ -11,7 +11,7 @@ if !isdefined(Base, :get_extension) using Requires end -using Logging, ProgressLogging, ConsoleProgressMonitor, TerminalLoggers, LoggingExtras +using Logging, ConsoleProgressMonitor, TerminalLoggers, LoggingExtras using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra import OptimizationBase: instantiate_function, OptimizationCache, ReInitCache From 662a894f4ae2df81aae4d3e1f88c4dfa834f8c22 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 19 Oct 2025 14:53:51 +0330 Subject: [PATCH 3/9] add `Logging` to deps --- lib/OptimizationOptimisers/Project.toml | 3 +++ lib/OptimizationOptimisers/src/OptimizationOptimisers.jl | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml index 70e4ca961..de65c12ef 100644 --- a/lib/OptimizationOptimisers/Project.toml +++ b/lib/OptimizationOptimisers/Project.toml @@ -2,12 +2,14 @@ name = "OptimizationOptimisers" uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1" authors = ["Vaibhav Dixit and contributors"] version = "0.3.12" + [deps] OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" [extras] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -26,6 +28,7 @@ OptimizationBase = "3" SciMLBase = "2.58" Optimisers = "0.2, 0.3, 0.4" Reexport = "1.2" +Logging = "1.10" [targets] test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote", "Printf"] diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 4cd14a2d7..9836ed994 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -1,6 +1,6 @@ module OptimizationOptimisers -using Reexport, UUIDs +using Reexport, UUIDs, Logging @reexport using Optimisers, OptimizationBase using SciMLBase From 7b2be4e48bf2582c872b25b11f668255cd60a41f Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 19 Oct 2025 14:55:19 +0330 Subject: [PATCH 4/9] newline --- lib/OptimizationOptimisers/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml index 90844cd42..26ea7efde 100644 --- a/lib/OptimizationOptimisers/Project.toml +++ b/lib/OptimizationOptimisers/Project.toml @@ -2,6 +2,7 @@ name = "OptimizationOptimisers" uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1" authors = ["Vaibhav Dixit and contributors"] version = "0.3.13" + [deps] OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" From 6b17597efa8606a0a41715e8ec4bffb8feaa00c8 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 19 Oct 2025 15:20:22 +0330 Subject: [PATCH 5/9] fix: seperate message --- lib/OptimizationOptimisers/src/OptimizationOptimisers.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 9836ed994..3bec045cc 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -133,10 +133,11 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{ elseif breakall break end - cache.progress && - @logmsg(LogLevel(-1), "Optimization"; - _id = progress_id, message = "Loss: $(round(first(first(x)); digits=3))", - progress = iterations / maxiters) + if cache.progress + message = "Loss: $(round(first(first(x)); digits = 3))" + @logmsg(LogLevel(-1), "Optimization"; _id = progress_id, + message = message, progress = iterations / maxiters) + end if cache.solver_args.save_best if first(x)[1] < first(min_err)[1] #found a better solution min_opt = opt From 2aa5eb3c51557d320bbbdb05f3d69e9b757a56a3 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sun, 19 Oct 2025 16:11:12 +0330 Subject: [PATCH 6/9] fix: use comma --- lib/OptimizationOptimisers/src/OptimizationOptimisers.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 3bec045cc..bf4884e70 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -135,8 +135,8 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{ end if cache.progress message = "Loss: $(round(first(first(x)); digits = 3))" - @logmsg(LogLevel(-1), "Optimization"; _id = progress_id, - message = message, progress = iterations / maxiters) + @logmsg(LogLevel(-1), "Optimization", _id=progress_id, + message=message, progress=iterations / maxiters) end if cache.solver_args.save_best if first(x)[1] < first(min_err)[1] #found a better solution @@ -161,8 +161,8 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{ end state, θ = Optimisers.update(state, θ, G) end - cache.progress && @logmsg(LogLevel(-1), "Optimization"; - _id = progress_id, message = "Done", progress = 1.0) + cache.progress && @logmsg(LogLevel(-1), "Optimization", + _id=progress_id, message="Done", progress=1.0) t1 = time() stats = OptimizationBase.OptimizationStats(; iterations, time = t1 - t0, fevals, gevals) From e82d35439ddbff2f93710a08c599e874891f133a Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 22 Oct 2025 06:31:06 -0400 Subject: [PATCH 7/9] Update lib/OptimizationOptimisers/Project.toml --- lib/OptimizationOptimisers/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml index 26ea7efde..28989ef78 100644 --- a/lib/OptimizationOptimisers/Project.toml +++ b/lib/OptimizationOptimisers/Project.toml @@ -8,7 +8,6 @@ OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" [extras] From 1eaacc3639afde6f335c25445657d920eded3142 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 22 Oct 2025 06:31:24 -0400 Subject: [PATCH 8/9] Update lib/OptimizationOptimisers/src/OptimizationOptimisers.jl --- lib/OptimizationOptimisers/src/OptimizationOptimisers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index bf4884e70..60d1666c0 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -1,6 +1,6 @@ module OptimizationOptimisers -using Reexport, UUIDs, Logging +using Reexport, Logging @reexport using Optimisers, OptimizationBase using SciMLBase From e690813d9fed3bb3ec63f32549caf874ba916fbb Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 22 Oct 2025 06:31:33 -0400 Subject: [PATCH 9/9] Update lib/OptimizationOptimisers/src/OptimizationOptimisers.jl --- lib/OptimizationOptimisers/src/OptimizationOptimisers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 60d1666c0..de36f25a8 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -95,7 +95,7 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{ gevals = 0 t0 = time() breakall = false - progress_id = uuid4() + progress_id = :OptimizationOptimizersJL for epoch in 1:epochs, d in data if cache.f.fg !== nothing && dataiterate x = cache.f.fg(G, θ, d)