@@ -149,6 +149,7 @@ Base.iterate(::One, ::Any) = nothing
149149# ####
150150# #### `AbstractThunk
151151# ####
152+
152153abstract type AbstractThunk <: AbstractDifferential end
153154
154155Base. Broadcast. broadcastable (x:: AbstractThunk ) = broadcastable (extern (x))
@@ -237,8 +238,17 @@ macro thunk(body)
237238 return :(Thunk ($ (esc (func))))
238239end
239240
241+ """
242+ unthunk(x)
243+
244+ `unthunk` removes 1 layer of thunking from an `AbstractThunk`,
245+ and on all other types is the `identity` function.
246+ """
247+ unthunk (x) = x
248+ unthunk (x:: Thunk ) = x ()
249+
240250# have to define this here after `@thunk` and `Thunk` is defined
241- Base. conj (x:: AbstractThunk ) = @thunk (conj (extern (x)))
251+ Base. conj (x:: AbstractThunk ) = @thunk (conj (unthunk (x)))
242252
243253(x:: Thunk )() = x. f ()
244254@inline unthunk (x:: Thunk ) = x ()
@@ -284,6 +294,73 @@ function itself, when that function is not a closure.
284294"""
285295const NO_FIELDS = DoesNotExist ()
286296
297+
298+ """
299+ Composite{P, T} <: AbstractDifferential
300+
301+ This type represents the differential for a `struct`/`NamedTuple`, or `Tuple`.
302+ `P` is the the corresponding primal type that this is a differential for.
303+
304+ `Composite{P}` should have fields (technically properties), that match to a subset of the
305+ fields of the primal type; and each should be a differential type matching to the primal
306+ type of that field.
307+ Fields of the P that are not present in the Composite are treated as `Zero`.
308+
309+ `T` is an implementation detail representing the backing data structure.
310+ For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`.
311+ It should not be passed in by user.
312+ """
313+ struct Composite{P, T} <: AbstractDifferential
314+ # Note: If T is a Tuple, then P is also a Tuple
315+ # (but potentially a different one, as it doesn't contain differentials)
316+ backing:: T
317+ end
318+
319+ function Composite {P} (; kwargs... ) where P
320+ backing = (; kwargs... ) # construct as NamedTuple
321+ return Composite {P, typeof(backing)} (backing)
322+ end
323+
324+ function Composite {P} (args... ) where P
325+ return Composite {P, typeof(args)} (args)
326+ end
327+
328+ function Base. show (io:: IO , comp:: Composite{P} ) where P
329+ print (io, " Composite{" )
330+ show (io, P)
331+ print (io, " }" )
332+ # allow Tuple or NamedTuple `show` to do the rendering of brackets etc
333+ show (io, backing (comp))
334+ end
335+
336+ Base. convert (:: Type{<:NamedTuple} , comp:: Composite{<:Any, <:NamedTuple} ) = backing (comp)
337+ Base. convert (:: Type{<:Tuple} , comp:: Composite{<:Any, <:Tuple} ) = backing (comp)
338+
339+ Base. getindex (comp:: Composite , idx) = getindex (backing (comp), idx)
340+ Base. getproperty (comp:: Composite , idx:: Int ) = getproperty (backing (comp), idx) # for Tuple
341+ Base. getproperty (comp:: Composite , idx:: Symbol ) = getproperty (backing (comp), idx)
342+ Base. propertynames (comp:: Composite ) = propertynames (backing (comp))
343+
344+ Base. iterate (comp:: Composite , args... ) = iterate (backing (comp), args... )
345+ Base. length (comp:: Composite ) = length (backing (comp))
346+ Base. eltype (:: Type{<:Composite{<:Any, T}} ) where T = eltype (T)
347+
348+ function Base. map (f, comp:: Composite{P, <:Tuple} ) where P
349+ vals:: Tuple = map (f, backing (comp))
350+ return Composite {P, typeof(vals)} (vals)
351+ end
352+ function Base. map (f, comp:: Composite{P, <:NamedTuple{L}} ) where {P, L}
353+ vals = map (f, Tuple (backing (comp)))
354+ named_vals = NamedTuple {L, typeof(vals)} (vals)
355+ return Composite {P, typeof(named_vals)} (named_vals)
356+ end
357+
358+ Base. conj (comp:: Composite ) = map (conj, comp)
359+
360+ extern (comp:: Composite ) = backing (map (extern, comp)) # gives a NamedTuple or Tuple
361+
362+ #= =============================================================================#
363+
287364"""
288365 refine_differential(𝒟::Type, der)
289366
0 commit comments