Skip to content

Commit 01aa04a

Browse files
authored
fix: allow (empty) sparse Hessians of linear functions (#906)
1 parent 9f50690 commit 01aa04a

File tree

2 files changed

+23
-4
lines changed
  • DifferentiationInterface
    • ext/DifferentiationInterfaceSparseMatrixColoringsExt
    • test/Core/SimpleFiniteDiff

2 files changed

+23
-4
lines changed

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ struct SMCSparseHessianPrep{
44
P <: AbstractMatrix,
55
C <: AbstractColoringResult{:symmetric, :column},
66
M <: AbstractMatrix{<:Number},
7+
Sp <: NTuple,
78
S <: AbstractVector{<:NTuple},
89
R <: AbstractVector{<:NTuple},
910
E2 <: DI.HVPPrep,
@@ -14,6 +15,7 @@ struct SMCSparseHessianPrep{
1415
sparsity::P
1516
coloring_result::C
1617
compressed_matrix::M
18+
batched_seed_prep::Sp
1719
batched_seeds::S
1820
batched_results::R
1921
hvp_prep::E2
@@ -54,14 +56,20 @@ function _prepare_sparse_hessian_aux(
5456
(; N, A) = batch_size_settings
5557
dense_backend = dense_ad(backend)
5658
groups = column_groups(coloring_result)
59+
seed_prep = DI.multibasis(x, eachindex(x))
5760
seeds = [DI.multibasis(x, eachindex(x)[group]) for group in groups]
58-
compressed_matrix = stack(_ -> vec(similar(x)), groups; dims = 2)
61+
compressed_matrix = if isempty(groups)
62+
similar(x, length(x), 0)
63+
else
64+
stack(_ -> vec(similar(x)), groups; dims = 2)
65+
end
66+
batched_seed_prep = ntuple(b -> copy(seed_prep), Val(B))
5967
batched_seeds = [
6068
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
6169
]
6270
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
6371
hvp_prep = DI.prepare_hvp_nokwarg(
64-
strict, f, dense_backend, x, batched_seeds[1], contexts...
72+
strict, f, dense_backend, x, batched_seed_prep, contexts...
6573
)
6674
gradient_prep = DI.prepare_gradient_nokwarg(
6775
strict, f, DI.inner(dense_backend), x, contexts...
@@ -72,6 +80,7 @@ function _prepare_sparse_hessian_aux(
7280
sparsity,
7381
coloring_result,
7482
compressed_matrix,
83+
batched_seed_prep,
7584
batched_seeds,
7685
batched_results,
7786
hvp_prep,
@@ -92,6 +101,7 @@ function DI.hessian!(
92101
batch_size_settings,
93102
coloring_result,
94103
compressed_matrix,
104+
batched_seed_prep,
95105
batched_seeds,
96106
batched_results,
97107
hvp_prep,
@@ -100,7 +110,7 @@ function DI.hessian!(
100110
dense_backend = dense_ad(backend)
101111

102112
hvp_prep_same = DI.prepare_hvp_same_point(
103-
f, hvp_prep, dense_backend, x, batched_seeds[1], contexts...
113+
f, hvp_prep, dense_backend, x, batched_seed_prep, contexts...
104114
)
105115

106116
for a in eachindex(batched_seeds, batched_results)

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,16 @@ end
136136
@test only(column_groups(hess_prep)) == 1:10
137137
end
138138

139-
@testset "Empty colors for mixed mode" begin # issue 857
139+
@testset "Empty color groups in sparse AD" begin # issue 857
140+
# forward
141+
backend = MyAutoSparse(adaptive_backends[1])
142+
@test jacobian(zero, backend, ones(10)) isa AbstractMatrix
143+
@test hessian(sum zero, backend, ones(10)) isa AbstractMatrix
144+
# reverse
145+
backend = MyAutoSparse(adaptive_backends[2])
146+
@test jacobian(zero, backend, ones(10)) isa AbstractMatrix
147+
@test hessian(sum zero, backend, ones(10)) isa AbstractMatrix
148+
# mixed
140149
backend = MyAutoSparse(MixedMode(adaptive_backends[1], adaptive_backends[2]))
141150
@test jacobian(copyto!, zeros(10), backend, ones(10)) isa AbstractMatrix
142151
end

0 commit comments

Comments
 (0)