@@ -4,6 +4,7 @@ struct SMCSparseHessianPrep{
44 P <: AbstractMatrix ,
55 C <: AbstractColoringResult{:symmetric, :column} ,
66 M <: AbstractMatrix{<:Number} ,
7+ Sp <: NTuple ,
78 S <: AbstractVector{<:NTuple} ,
89 R <: AbstractVector{<:NTuple} ,
910 E2 <: DI.HVPPrep ,
@@ -14,6 +15,7 @@ struct SMCSparseHessianPrep{
1415 sparsity:: P
1516 coloring_result:: C
1617 compressed_matrix:: M
18+ batched_seed_prep:: Sp
1719 batched_seeds:: S
1820 batched_results:: R
1921 hvp_prep:: E2
@@ -54,14 +56,20 @@ function _prepare_sparse_hessian_aux(
5456 (; N, A) = batch_size_settings
5557 dense_backend = dense_ad (backend)
5658 groups = column_groups (coloring_result)
59+ seed_prep = DI. multibasis (x, eachindex (x))
5760 seeds = [DI. multibasis (x, eachindex (x)[group]) for group in groups]
58- compressed_matrix = stack (_ -> vec (similar (x)), groups; dims = 2 )
61+ compressed_matrix = if isempty (groups)
62+ similar (x, length (x), 0 )
63+ else
64+ stack (_ -> vec (similar (x)), groups; dims = 2 )
65+ end
66+ batched_seed_prep = ntuple (b -> copy (seed_prep), Val (B))
5967 batched_seeds = [
6068 ntuple (b -> seeds[1 + ((a - 1 ) * B + (b - 1 )) % N], Val (B)) for a in 1 : A
6169 ]
6270 batched_results = [ntuple (b -> similar (x), Val (B)) for _ in batched_seeds]
6371 hvp_prep = DI. prepare_hvp_nokwarg (
64- strict, f, dense_backend, x, batched_seeds[ 1 ] , contexts...
72+ strict, f, dense_backend, x, batched_seed_prep , contexts...
6573 )
6674 gradient_prep = DI. prepare_gradient_nokwarg (
6775 strict, f, DI. inner (dense_backend), x, contexts...
@@ -72,6 +80,7 @@ function _prepare_sparse_hessian_aux(
7280 sparsity,
7381 coloring_result,
7482 compressed_matrix,
83+ batched_seed_prep,
7584 batched_seeds,
7685 batched_results,
7786 hvp_prep,
@@ -92,6 +101,7 @@ function DI.hessian!(
92101 batch_size_settings,
93102 coloring_result,
94103 compressed_matrix,
104+ batched_seed_prep,
95105 batched_seeds,
96106 batched_results,
97107 hvp_prep,
@@ -100,7 +110,7 @@ function DI.hessian!(
100110 dense_backend = dense_ad (backend)
101111
102112 hvp_prep_same = DI. prepare_hvp_same_point (
103- f, hvp_prep, dense_backend, x, batched_seeds[ 1 ] , contexts...
113+ f, hvp_prep, dense_backend, x, batched_seed_prep , contexts...
104114 )
105115
106116 for a in eachindex (batched_seeds, batched_results)
0 commit comments