|
1 | | -struct PrimalAdditionFailedException{P} <: Exception |
2 | | - primal::P |
3 | | - differential::Composite{P} |
4 | | - original::Exception |
| 1 | +""" |
| 2 | + Composite{P, T} <: AbstractDifferential |
| 3 | +
|
| 4 | +This type represents the differential for a `struct`/`NamedTuple`, or `Tuple`. |
| 5 | +`P` is the the corresponding primal type that this is a differential for. |
| 6 | +
|
| 7 | +`Composite{P}` should have fields (technically properties), that match to a subset of the |
| 8 | +fields of the primal type; and each should be a differential type matching to the primal |
| 9 | +type of that field. |
| 10 | +Fields of the P that are not present in the Composite are treated as `Zero`. |
| 11 | +
|
| 12 | +`T` is an implementation detail representing the backing data structure. |
| 13 | +For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`. |
| 14 | +It should not be passed in by user. |
| 15 | +""" |
| 16 | +struct Composite{P, T} <: AbstractDifferential |
| 17 | + # Note: If T is a Tuple, then P is also a Tuple |
| 18 | + # (but potentially a different one, as it doesn't contain differentials) |
| 19 | + backing::T |
5 | 20 | end |
6 | 21 |
|
7 | | -function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} |
8 | | - println(io, "Could not construct $P after addition.") |
9 | | - println(io, "This probably means no default constructor is defined.") |
10 | | - println(io, "Either define a default constructor") |
11 | | - printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color=:blue) |
12 | | - println(io, "\nor overload") |
13 | | - printstyled(io, |
14 | | - "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))"; |
15 | | - color=:blue |
16 | | - ) |
17 | | - println(io, "\nor overload") |
18 | | - printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue) |
19 | | - println(io, "\nOriginal Exception:") |
20 | | - printstyled(io, err.original; color=:yellow) |
21 | | - println(io) |
| 22 | +function Composite{P}(; kwargs...) where P |
| 23 | + backing = (; kwargs...) # construct as NamedTuple |
| 24 | + return Composite{P, typeof(backing)}(backing) |
| 25 | +end |
| 26 | + |
| 27 | +function Composite{P}(args...) where P |
| 28 | + return Composite{P, typeof(args)}(args) |
| 29 | +end |
| 30 | + |
| 31 | +function Base.show(io::IO, comp::Composite{P}) where P |
| 32 | + print(io, "Composite{") |
| 33 | + show(io, P) |
| 34 | + print(io, "}") |
| 35 | + # allow Tuple or NamedTuple `show` to do the rendering of brackets etc |
| 36 | + show(io, backing(comp)) |
22 | 37 | end |
23 | 38 |
|
| 39 | +Base.convert(::Type{<:NamedTuple}, comp::Composite{<:Any, <:NamedTuple}) = backing(comp) |
| 40 | +Base.convert(::Type{<:Tuple}, comp::Composite{<:Any, <:Tuple}) = backing(comp) |
| 41 | + |
| 42 | +Base.getindex(comp::Composite, idx) = getindex(backing(comp), idx) |
| 43 | +Base.getproperty(comp::Composite, idx::Int) = getproperty(backing(comp), idx) # for Tuple |
| 44 | +Base.getproperty(comp::Composite, idx::Symbol) = getproperty(backing(comp), idx) |
| 45 | +Base.propertynames(comp::Composite) = propertynames(backing(comp)) |
| 46 | + |
| 47 | +Base.iterate(comp::Composite, args...) = iterate(backing(comp), args...) |
| 48 | +Base.length(comp::Composite) = length(backing(comp)) |
| 49 | +Base.eltype(::Type{<:Composite{<:Any, T}}) where T = eltype(T) |
| 50 | + |
| 51 | +function Base.map(f, comp::Composite{P, <:Tuple}) where P |
| 52 | + vals::Tuple = map(f, backing(comp)) |
| 53 | + return Composite{P, typeof(vals)}(vals) |
| 54 | +end |
| 55 | +function Base.map(f, comp::Composite{P, <:NamedTuple{L}}) where{P, L} |
| 56 | + vals = map(f, Tuple(backing(comp))) |
| 57 | + named_vals = NamedTuple{L, typeof(vals)}(vals) |
| 58 | + return Composite{P, typeof(named_vals)}(named_vals) |
| 59 | +end |
| 60 | + |
| 61 | +Base.conj(comp::Composite) = map(conj, comp) |
| 62 | + |
| 63 | +extern(comp::Composite) = backing(map(extern, comp)) # gives a NamedTuple or Tuple |
| 64 | + |
| 65 | + |
24 | 66 | """ |
25 | 67 | backing(x) |
26 | 68 |
|
@@ -131,3 +173,36 @@ function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn} |
131 | 173 | return NamedTuple{names,types}(vals) |
132 | 174 | end |
133 | 175 | end |
| 176 | + |
| 177 | + |
| 178 | +struct PrimalAdditionFailedException{P} <: Exception |
| 179 | + primal::P |
| 180 | + differential::Composite{P} |
| 181 | + original::Exception |
| 182 | +end |
| 183 | + |
| 184 | +function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} |
| 185 | + println(io, "Could not construct $P after addition.") |
| 186 | + println(io, "This probably means no default constructor is defined.") |
| 187 | + println(io, "Either define a default constructor") |
| 188 | + printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color=:blue) |
| 189 | + println(io, "\nor overload") |
| 190 | + printstyled(io, |
| 191 | + "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))"; |
| 192 | + color=:blue |
| 193 | + ) |
| 194 | + println(io, "\nor overload") |
| 195 | + printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue) |
| 196 | + println(io, "\nOriginal Exception:") |
| 197 | + printstyled(io, err.original; color=:yellow) |
| 198 | + println(io) |
| 199 | +end |
| 200 | + |
| 201 | +""" |
| 202 | + NO_FIELDS |
| 203 | +
|
| 204 | +Constant for the reverse-mode derivative with respect to a structure that has no fields. |
| 205 | +The most notable use for this is for the reverse-mode derivative with respect to the |
| 206 | +function itself, when that function is not a closure. |
| 207 | +""" |
| 208 | +const NO_FIELDS = DoesNotExist() |
0 commit comments