Skip to content

Commit 65226ae

Browse files
update to SciMLSensitivity
1 parent 4dcb907 commit 65226ae

File tree

7 files changed

+13
-13
lines changed

7 files changed

+13
-13
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
name = "DiffEqFlux"
22
uuid = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
4-
version = "1.50.0"
4+
version = "1.51.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
99
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
1010
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
1111
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
12-
DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
1312
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1413
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1514
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
@@ -32,6 +31,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
3231
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
3332
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
3433
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
34+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
3535
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3636
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
3737
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -43,7 +43,6 @@ Cassette = "0.3.7"
4343
ConsoleProgressMonitor = "0.1"
4444
DataInterpolations = "3.3"
4545
DiffEqBase = "6.41"
46-
DiffEqSensitivity = "6.65"
4746
DiffResults = "1.0"
4847
Distributions = "0.23, 0.24, 0.25"
4948
DistributionsAD = "0.6"
@@ -62,6 +61,7 @@ RecursiveArrayTools = "2"
6261
Reexport = "0.2, 1"
6362
Requires = "0.5, 1.0"
6463
SciMLBase = "1"
64+
SciMLSensitivity = "7"
6565
StaticArrays = "0.11, 0.12, 1"
6666
TerminalLoggers = "0.1"
6767
Zygote = "0.5, 0.6"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[deps]
22
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
3-
DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
3+
SciMLSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
44
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
55
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
66
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

docs/src/examples/GPUs.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ condition is a GPU array. Thus, for example, we can define a neural ODE by hand
55
that runs on the GPU (if no GPU is available, the calculation defaults back to the CPU):
66

77
```julia
8-
using DifferentialEquations, Flux, DiffEqFlux, DiffEqSensitivity
8+
using DifferentialEquations, Flux, DiffEqFlux, SciMLSensitivity
99

1010
using Random
1111
rng = Random.default_rng()
@@ -70,7 +70,7 @@ same code works on CPUs and GPUs, dependent on `using CUDA`.
7070

7171
```julia
7272
using Flux, DiffEqFlux, Optimization, OptimizationFlux, Zygote,
73-
OrdinaryDiffEq, Plots, CUDA, DiffEqSensitivity, Random, ComponentArrays
73+
OrdinaryDiffEq, Plots, CUDA, SciMLSensitivity, Random, ComponentArrays
7474
CUDA.allowscalar(false) # Makes sure no slow operations are occuring
7575

7676
#rng for Lux.setup

docs/src/examples/collocation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ pretraining the neural network against a smoothed collocation of the
55
data. First the example and then an explanation.
66

77
```@example collocation_cp
8-
using Lux, DiffEqFlux, OrdinaryDiffEq, DiffEqSensitivity, Optimization, OptimizationFlux, Plots
8+
using Lux, DiffEqFlux, OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationFlux, Plots
99
1010
using Random
1111
rng = Random.default_rng()

src/DiffEqFlux.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
module DiffEqFlux
22

33
using Adapt, Base.Iterators, ConsoleProgressMonitor, DataInterpolations,
4-
DiffEqBase, DiffEqSensitivity, DiffResults, Distributions, DistributionsAD,
4+
DiffEqBase, SciMLSensitivity, DiffResults, Distributions, DistributionsAD,
55
ForwardDiff, Optimization, OptimizationPolyalgorithms, LinearAlgebra,
66
Logging, LoggingExtras, Printf, ProgressLogging, Random, RecursiveArrayTools,
77
Reexport, SciMLBase, StaticArrays, TerminalLoggers, Zygote, ZygoteRules
88

9-
@reexport using DiffEqSensitivity
9+
@reexport using SciMLSensitivity
1010
@reexport using Zygote
1111

1212
# deprecate

src/fast_layers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,12 +407,12 @@ paramlength(f::StaticDense{out,in,bias}) where {out,in,bias} = out*(in + bias)
407407
initial_params(f::StaticDense) = f.initial_params()
408408

409409
# Override FastDense to exclude the branch from the check
410-
function Cassette.overdub(ctx::DiffEqSensitivity.HasBranchingCtx, f::FastDense, x, p)
410+
function Cassette.overdub(ctx::SciMLSensitivity.HasBranchingCtx, f::FastDense, x, p)
411411
y = reshape(p[1:(f.out*f.in)],f.out,f.in)*x
412412
Cassette.@overdub ctx f.σ.(y)
413413
end
414414

415-
function Cassette.overdub(ctx::DiffEqSensitivity.HasBranchingCtx, f::StaticDense{out,in,bias}, x, p) where {out,in,bias}
415+
function Cassette.overdub(ctx::SciMLSensitivity.HasBranchingCtx, f::StaticDense{out,in,bias}, x, p) where {out,in,bias}
416416
y = reshape(p[1:(out*in)],out,in)*x
417417
Cassette.@overdub ctx f.σ.(y)
418418
end

src/neural_de.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ derivatives of the loss backwards in time.
1212
```julia
1313
NeuralODE(model,tspan,alg=nothing,args...;kwargs...)
1414
NeuralODE(model::FastChain,tspan,alg=nothing,args...;
15-
sensealg=InterpolatingAdjoint(autojacvec=DiffEqSensitivity.ReverseDiffVJP(true)),
15+
sensealg=InterpolatingAdjoint(autojacvec=SciMLSensitivity.ReverseDiffVJP(true)),
1616
kwargs...)
1717
```
1818
@@ -490,7 +490,7 @@ the constraint equations.
490490
```julia
491491
NeuralODEMM(model,constraints_model,tspan,mass_matrix,alg=nothing,args...;kwargs...)
492492
NeuralODEMM(model::FastChain,tspan,mass_matrix,alg=nothing,args...;
493-
sensealg=InterpolatingAdjoint(autojacvec=DiffEqSensitivity.ReverseDiffVJP(true)),
493+
sensealg=InterpolatingAdjoint(autojacvec=SciMLSensitivity.ReverseDiffVJP(true)),
494494
kwargs...)
495495
```
496496

0 commit comments

Comments
 (0)