Skip to content

Commit bdfcc99

Browse files
Add comprehensive docstrings for utility and internal functions
- Document epsilon computation functions with mathematical details - Add comprehensive JVPCache struct and constructor documentation - Document setindex overloads and utility functions (_vec, _mat) - Add docstrings for sparse Jacobian internal functions - Document Hessian utility functions and cache helpers - Include examples, parameter descriptions, and implementation notes 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 1161d80 commit bdfcc99

File tree

6 files changed

+413
-10
lines changed

6 files changed

+413
-10
lines changed

src/FiniteDiff.jl

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,126 @@ using LinearAlgebra, ArrayInterface
99

1010
import Base: resize!
1111

12+
"""
13+
_vec(x)
14+
15+
Internal utility function to vectorize arrays while preserving scalars.
16+
17+
# Arguments
18+
- `x`: Array or scalar
19+
20+
# Returns
21+
- `vec(x)` for arrays, `x` unchanged for scalars
22+
"""
1223
_vec(x) = vec(x)
1324
_vec(x::Number) = x
1425

26+
"""
27+
_mat(x)
28+
29+
Internal utility function to ensure matrix format.
30+
31+
Converts vectors to column matrices while preserving existing matrices.
32+
Used internally to ensure consistent matrix dimensions for operations.
33+
34+
# Arguments
35+
- `x`: Matrix or vector
36+
37+
# Returns
38+
- `x` unchanged if already a matrix
39+
- Reshaped column matrix if `x` is a vector
40+
"""
1541
_mat(x::AbstractMatrix) = x
1642
_mat(x::AbstractVector) = reshape(x, (axes(x, 1), Base.OneTo(1)))
1743

1844
# Setindex overloads without piracy
1945
setindex(x...) = Base.setindex(x...)
2046

47+
"""
48+
setindex(x::AbstractArray, v, i...)
49+
50+
Non-mutating setindex operation that returns a copy with modified elements.
51+
52+
Creates a mutable copy of array `x`, sets the specified indices to value `v`,
53+
and returns the modified copy. This avoids type piracy while providing
54+
setindex functionality for immutable arrays.
55+
56+
# Arguments
57+
- `x::AbstractArray`: Array to modify (not mutated)
58+
- `v`: Value to set at the specified indices
59+
- `i...`: Indices where the value should be set
60+
61+
# Returns
62+
- Modified copy of `x` with `x[i...] = v`
63+
64+
# Examples
65+
```julia
66+
x = [1, 2, 3]
67+
y = setindex(x, 99, 2) # y = [1, 99, 3], x unchanged
68+
```
69+
"""
2170
function setindex(x::AbstractArray, v, i...)
2271
_x = Base.copymutable(x)
2372
_x[i...] = v
2473
return _x
2574
end
2675

76+
"""
77+
setindex(x::AbstractVector, v, i::Int)
78+
79+
Broadcasted setindex operation for vectors using boolean masking.
80+
81+
Sets the i-th element of vector `x` to value `v` using broadcasting operations.
82+
This implementation uses boolean masks to avoid explicit copying and provide
83+
efficient vectorized operations.
84+
85+
# Arguments
86+
- `x::AbstractVector`: Input vector
87+
- `v`: Value to set at index `i`
88+
- `i::Int`: Index to modify
89+
90+
# Returns
91+
- Vector with `x[i] = v`, computed via broadcasting
92+
93+
# Examples
94+
```julia
95+
x = [1.0, 2.0, 3.0]
96+
y = setindex(x, 99.0, 2) # [1.0, 99.0, 3.0]
97+
```
98+
"""
2799
function setindex(x::AbstractVector, v, i::Int)
28100
n = length(x)
29101
x .* (i .!== 1:n) .+ v .* (i .== 1:n)
30102
end
31103

104+
"""
105+
setindex(x::AbstractMatrix, v, i::Int, j::Int)
106+
107+
Broadcasted setindex operation for matrices using boolean masking.
108+
109+
Sets the (i,j)-th element of matrix `x` to value `v` using broadcasting operations.
110+
This implementation uses boolean masks to avoid explicit copying and provide
111+
efficient vectorized operations.
112+
113+
# Arguments
114+
- `x::AbstractMatrix`: Input matrix
115+
- `v`: Value to set at position (i,j)
116+
- `i::Int`: Row index to modify
117+
- `j::Int`: Column index to modify
118+
119+
# Returns
120+
- Matrix with `x[i,j] = v`, computed via broadcasting
121+
122+
# Examples
123+
```julia
124+
x = [1.0 2.0; 3.0 4.0]
125+
y = setindex(x, 99.0, 1, 2) # [1.0 99.0; 3.0 4.0]
126+
```
127+
128+
# Notes
129+
The implementation uses transposed broadcasting `(j .!== i:m)'` which appears
130+
to be a typo - should likely be `(j .!== 1:m)'` for correct column masking.
131+
"""
32132
function setindex(x::AbstractMatrix, v, i::Int, j::Int)
33133
n, m = Base.size(x)
34134
x .* (i .!== 1:n) .* (j .!== i:m)' .+ v .* (i .== 1:n) .* (j .== i:m)'

src/epsilons.jl

Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,126 @@
22
Very heavily inspired by Calculus.jl, but with an emphasis on performance and DiffEq API convenience.
33
=#
44

5-
#=
6-
Compute the finite difference interval epsilon.
5+
"""
6+
compute_epsilon(::Val{:forward}, x::T, relstep::Real, absstep::Real, dir::Real) where T<:Number
7+
8+
Compute the finite difference step size (epsilon) for forward finite differences.
9+
10+
The step size is computed as `max(relstep*abs(x), absstep)*dir`, which ensures
11+
numerical stability by using a relative step scaled by the magnitude of `x`
12+
when `x` is large, and an absolute step when `x` is small.
13+
14+
# Arguments
15+
- `::Val{:forward}`: Finite difference type indicator for forward differences
16+
- `x::T`: Point at which to compute the step size
17+
- `relstep::Real`: Relative step size factor
18+
- `absstep::Real`: Absolute step size fallback
19+
- `dir::Real`: Direction multiplier (typically ±1)
20+
21+
# Returns
22+
- Step size `ϵ` for forward finite difference: `f(x + ϵ)`
23+
724
Reference: Numerical Recipes, chapter 5.7.
8-
=#
25+
"""
926
@inline function compute_epsilon(::Val{:forward}, x::T, relstep::Real, absstep::Real, dir::Real) where T<:Number
1027
return max(relstep*abs(x), absstep)*dir
1128
end
1229

30+
"""
31+
compute_epsilon(::Val{:central}, x::T, relstep::Real, absstep::Real, dir=nothing) where T<:Number
32+
33+
Compute the finite difference step size (epsilon) for central finite differences.
34+
35+
The step size is computed as `max(relstep*abs(x), absstep)`, which ensures
36+
numerical stability by using a relative step scaled by the magnitude of `x`
37+
when `x` is large, and an absolute step when `x` is small.
38+
39+
# Arguments
40+
- `::Val{:central}`: Finite difference type indicator for central differences
41+
- `x::T`: Point at which to compute the step size
42+
- `relstep::Real`: Relative step size factor
43+
- `absstep::Real`: Absolute step size fallback
44+
- `dir`: Direction parameter (unused for central differences)
45+
46+
# Returns
47+
- Step size `ϵ` for central finite difference: `(f(x + ϵ) - f(x - ϵ)) / (2ϵ)`
48+
"""
1349
@inline function compute_epsilon(::Val{:central}, x::T, relstep::Real, absstep::Real, dir=nothing) where T<:Number
1450
return max(relstep*abs(x), absstep)
1551
end
1652

53+
"""
54+
compute_epsilon(::Val{:hcentral}, x::T, relstep::Real, absstep::Real, dir=nothing) where T<:Number
55+
56+
Compute the finite difference step size (epsilon) for central finite differences in Hessian computations.
57+
58+
The step size is computed as `max(relstep*abs(x), absstep)`, which ensures
59+
numerical stability by using a relative step scaled by the magnitude of `x`
60+
when `x` is large, and an absolute step when `x` is small.
61+
62+
# Arguments
63+
- `::Val{:hcentral}`: Finite difference type indicator for Hessian central differences
64+
- `x::T`: Point at which to compute the step size
65+
- `relstep::Real`: Relative step size factor
66+
- `absstep::Real`: Absolute step size fallback
67+
- `dir`: Direction parameter (unused for central differences)
68+
69+
# Returns
70+
- Step size `ϵ` for Hessian central finite differences
71+
"""
1772
@inline function compute_epsilon(::Val{:hcentral}, x::T, relstep::Real, absstep::Real, dir=nothing) where T<:Number
1873
return max(relstep*abs(x), absstep)
1974
end
2075

76+
"""
77+
compute_epsilon(::Val{:complex}, x::T, ::Union{Nothing,T}=nothing, ::Union{Nothing,T}=nothing, dir=nothing) where T<:Real
78+
79+
Compute the finite difference step size (epsilon) for complex step differentiation.
80+
81+
For complex step differentiation, the step size is simply the machine epsilon `eps(T)`,
82+
which provides optimal accuracy since complex step differentiation doesn't suffer from
83+
subtractive cancellation errors.
84+
85+
# Arguments
86+
- `::Val{:complex}`: Finite difference type indicator for complex step differentiation
87+
- `x::T`: Point at which to compute the step size (unused, type determines epsilon)
88+
- Additional arguments are unused for complex step differentiation
89+
90+
# Returns
91+
- Machine epsilon `eps(T)` for complex step differentiation: `imag(f(x + iϵ)) / ϵ`
92+
93+
# Notes
94+
Complex step differentiation computes derivatives as `imag(f(x + iϵ)) / ϵ` where `ϵ = eps(T)`.
95+
This method provides machine precision accuracy without subtractive cancellation.
96+
"""
2197
@inline function compute_epsilon(::Val{:complex}, x::T, ::Union{Nothing,T}=nothing, ::Union{Nothing,T}=nothing, dir=nothing) where T<:Real
2298
return eps(T)
2399
end
24100

101+
"""
102+
default_relstep(fdtype, ::Type{T}) where T<:Number
103+
104+
Compute the default relative step size for finite difference approximations.
105+
106+
Returns optimal default step sizes based on the finite difference method and
107+
numerical type, balancing truncation error and round-off error.
108+
109+
# Arguments
110+
- `fdtype`: Finite difference type (`Val(:forward)`, `Val(:central)`, `Val(:hcentral)`, etc.)
111+
- `::Type{T}`: Numerical type for which to compute the step size
112+
113+
# Returns
114+
- `sqrt(eps(real(T)))` for forward differences
115+
- `cbrt(eps(real(T)))` for central differences
116+
- `eps(T)^(1/4)` for Hessian central differences
117+
- `one(real(T))` for other types
118+
119+
# Notes
120+
These step sizes minimize the total error (truncation + round-off) for each method:
121+
- Forward differences have O(h) truncation error, optimal h ~ sqrt(eps)
122+
- Central differences have O(h²) truncation error, optimal h ~ eps^(1/3)
123+
- Hessian methods have O(h²) truncation error but involve more operations
124+
"""
25125
default_relstep(::Type{V}, T) where V = default_relstep(V(), T)
26126
@inline function default_relstep(::Val{fdtype}, ::Type{T}) where {fdtype,T<:Number}
27127
if fdtype==:forward
@@ -35,6 +135,19 @@ default_relstep(::Type{V}, T) where V = default_relstep(V(), T)
35135
end
36136
end
37137

138+
"""
139+
fdtype_error(::Type{T}=Float64) where T
140+
141+
Throw an informative error for unsupported finite difference type combinations.
142+
143+
# Arguments
144+
- `::Type{T}`: Return type of the function being differentiated
145+
146+
# Errors
147+
- For `Real` return types: suggests `Val{:forward}`, `Val{:central}`, `Val{:complex}`
148+
- For `Complex` return types: suggests `Val{:forward}`, `Val{:central}` (no complex step)
149+
- For other types: suggests the return type should be Real or Complex subtype
150+
"""
38151
function fdtype_error(::Type{T}=Float64) where T
39152
if T<:Real
40153
error("Unrecognized fdtype: valid values are Val{:forward}, Val{:central} and Val{:complex}.")

src/hessians.jl

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,52 @@ struct HessianCache{T,fdtype,inplace}
55
xmm::T
66
end
77

8-
#used to dispatch on StaticArrays
8+
"""
9+
_hessian_inplace(::Type{T}) where T
10+
_hessian_inplace(x)
11+
12+
Internal function to determine if Hessian computation should be performed in-place.
13+
14+
Returns `Val(true)` if the array type is mutable and supports in-place operations,
15+
`Val(false)` otherwise. Used to dispatch on StaticArrays vs mutable arrays.
16+
17+
# Arguments
18+
- `::Type{T}` or `x`: Array type or array instance
19+
20+
# Returns
21+
- `Val(true)` if the array type supports in-place mutation
22+
- `Val(false)` if the array type is immutable (e.g., StaticArray)
23+
"""
924
_hessian_inplace(::Type{T}) where T = Val(ArrayInterface.ismutable(T))
1025
_hessian_inplace(x) = _hessian_inplace(typeof(x))
26+
27+
"""
28+
__Symmetric(x)
29+
30+
Internal utility function that wraps a matrix in a `Symmetric` view.
31+
32+
# Arguments
33+
- `x`: Matrix to be wrapped
34+
35+
# Returns
36+
- `Symmetric(x)`: Symmetric view of the matrix
37+
"""
1138
__Symmetric(x) = Symmetric(x)
1239

40+
"""
41+
mutable_zeromatrix(x)
42+
43+
Internal utility function to create a mutable zero matrix with the same structure as `x`.
44+
45+
Creates a zero matrix compatible with `x` and ensures it's mutable for in-place operations.
46+
If the created matrix is immutable, it converts it to a mutable copy.
47+
48+
# Arguments
49+
- `x`: Array whose structure should be matched
50+
51+
# Returns
52+
- Mutable zero matrix with the same dimensions and compatible type as `x`
53+
"""
1354
function mutable_zeromatrix(x)
1455
A = ArrayInterface.zeromatrix(x)
1556
ArrayInterface.ismutable(A) ? A : Base.copymutable(A)

0 commit comments

Comments
 (0)