@@ -87,3 +87,75 @@ for T in (:Any,)
8787 @eval Base.:* (a:: AbstractThunk , b:: $T ) = unthunk (a) * b
8888 @eval Base.:* (a:: $T , b:: AbstractThunk ) = a * unthunk (b)
8989end
90+
91+ # ################# Composite ##############################################################
92+
93+ # We intentionally do not define, `Base.*(::Composite, ::Composite)` as that is not meaningful
94+ # In general one doesn't have to represent multiplications of 2 differentials
95+ # Only of a differential and a scaling factor (generally `Real`)
96+ Base.* (s:: Any , comp:: Composite ) = map (x-> s* x, comp)
97+ Base.* (comp:: Composite , s:: Any ) = s* comp
98+
99+ function Base.:+ (a:: Composite{Primal, NamedTuple{an}} , b:: Composite{Primal, NamedTuple{bn}} ) where Primal
100+ # Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base.
101+ # https://github.com/JuliaLang/julia/blob/592748adb25301a45bd6edef3ac0a93eed069852/base/namedtuple.jl#L220-L231
102+ if @generated
103+ names = Base. merge_names (an, bn)
104+ types = Base. merge_types (names, a, b)
105+
106+ vals = map (names) do field
107+ a_field = :(getproperty (:a , $ (QuoteNode (field))))
108+ b_field = :(getproperty (:b , $ (QuoteNode (field))))
109+ val_expr = if Base. sym_in (field, an)
110+ if Base. sym_in (field, bn)
111+ # in both
112+ :($ a_field + $ b_field)
113+ else
114+ # only in `an`
115+ a_field
116+ end
117+ else # must be in `b` only
118+ b_field
119+ end
120+ end
121+ return :(NamedTuple {$names, $types} (($ (vals... ),)))
122+ else
123+ names = Base. merge_names (an, bn)
124+ types = Base. merge_types (names, typeof (a), typeof (b))
125+ vals = map (names) do field
126+ val_expr = if Base. sym_in (field, an)
127+ a_field = getproperty (a, field)
128+ if Base. sym_in (field, bn)
129+ # in both
130+ b_field = getproperty (a, field)
131+ :($ a_field + $ b_field)
132+ else
133+ # only in `an`
134+ a_field
135+ end
136+ else # must be in `b` only
137+ b_field = getproperty (a, field)
138+ b_field
139+ end
140+ end
141+ NamedTuple {names,types} (map (n-> getfield (sym_in (n, bn) ? b : a, n), names))
142+ end
143+ end
144+ end
145+
146+ # this should not need to be generated, # TODO test that
147+ function Base.:+ (a:: Composite{Primal, <:Tuple} , b:: Composite{Primal, <:Tuple} ) where Primal
148+ # TODO : should we even allow it on different lengths?
149+ short, long = length (a) < length (b) ? (a. backing, b. backing) : (b. backing, a. backing)
150+ backing = ntuple (length (long)) do ii
151+ long_val = getfield (long, ii)
152+ if ii <= length (short)
153+ short_val = getfield (short, ii)
154+ return short_val + long_val
155+ else
156+ return long_val
157+ end
158+ end
159+
160+ return Composite {Primal, typeof(backing)} (backing)
161+ end
0 commit comments