@@ -128,24 +128,6 @@ function Cassette.overdub(ctx::JacobianSparsityContext,
128128 end
129129end
130130
131- function Cassette. overdub (ctx:: JacobianSparsityContext ,
132- f:: typeof (Base. unsafe_copyto!),
133- X:: Tagged ,
134- xstart,
135- Y:: Tagged ,
136- ystart,
137- len)
138- S = ctx. metadata
139- if metatype (Y, ctx) <: JacInput
140- val = Cassette. fallback (ctx, f, X, xstart, Y, ystart, len)
141- nometa = Cassette. NoMetaMeta ()
142- X. meta. meta[xstart: xstart+ len- 1 ] .= (i-> Cassette. Meta (ProvinanceSet (i), nometa)). (ystart: ystart+ len- 1 )
143- val
144- else
145- Cassette. recurse (ctx, f, X, xstart, Y, ystart, len)
146- end
147- end
148-
149131function jacobian_sparsity (f!, Y, X, args... ;
150132 sparsity= Sparsity (length (Y), length (X)),
151133 raw = false )
@@ -170,29 +152,29 @@ function jacobian_sparsity(f!, Y, X, args...;
170152 end
171153end
172154
173- function Cassette. overdub (ctx:: SparsityContext ,
155+ function Cassette. overdub (ctx:: JacobianSparsityContext ,
174156 f:: typeof (Base. unsafe_copyto!),
175157 X:: Tagged ,
176158 xstart,
177159 Y:: Tagged ,
178160 ystart,
179161 len)
180162 S = ctx. metadata
181- if ismetatype (Y, ctx, JacInput) && ismetatype (X, ctx, JacOutput)
163+ if metatype (Y, ctx) <: JacInput && metatype (X, ctx) <: JacOutput
182164 # Write directly to the output sparsity
183165 val = Cassette. fallback (ctx, f, X, xstart, Y, ystart, len)
184166 for (i, j) in zip (xstart: xstart+ len- 1 , ystart: ystart+ len- 1 )
185167 push! (S, i, j)
186168 end
187169 val
188- elseif ismetatype (Y, ctx, JacInput)
170+ elseif metatype (Y, ctx) <: JacInput
189171 # Keep around a ProvinanceSet
190172 val = Cassette. fallback (ctx, f, X, xstart, Y, ystart, len)
191173 nometa = Cassette. NoMetaMeta ()
192- rhs = (i-> Cassette. Meta (pset (i), nometa)). (ystart: ystart+ len- 1 )
174+ rhs = (i-> Cassette. Meta (ProvinanceSet (i), nometa)). (ystart: ystart+ len- 1 )
193175 X. meta. meta[xstart: xstart+ len- 1 ] .= rhs
194176 val
195- elseif ismetatype (X, ctx, JacOutput)
177+ elseif metatype (X, ctx) <: JacOutput
196178 val = Cassette. fallback (ctx, f, X, xstart, Y, ystart, len)
197179 for (i, j) in zip (xstart: xstart+ len- 1 , ystart: ystart+ len- 1 )
198180 y = Cassette. @overdub ctx Y[j]
0 commit comments