Skip to content
This repository was archived by the owner on Jun 24, 2022. It is now read-only.

Commit e37b27f

Browse files
committed
move some fixes over
1 parent 9e0f8e1 commit e37b27f

File tree

4 files changed

+47
-3
lines changed

4 files changed

+47
-3
lines changed

src/jacobian.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,47 @@ function jacobian_sparsity(f!, Y, X, args...;
169169
return sparse(sparsity)
170170
end
171171
end
172+
173+
function Cassette.overdub(ctx::SparsityContext,
174+
f::typeof(Base.unsafe_copyto!),
175+
X::Tagged,
176+
xstart,
177+
Y::Tagged,
178+
ystart,
179+
len)
180+
S = ctx.metadata
181+
if ismetatype(Y, ctx, JacInput) && ismetatype(X, ctx, JacOutput)
182+
# Write directly to the output sparsity
183+
val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len)
184+
for (i, j) in zip(xstart:xstart+len-1, ystart:ystart+len-1)
185+
push!(S, i, j)
186+
end
187+
val
188+
elseif ismetatype(Y, ctx, JacInput)
189+
# Keep around a ProvinanceSet
190+
val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len)
191+
nometa = Cassette.NoMetaMeta()
192+
rhs = (i->Cassette.Meta(pset(i), nometa)).(ystart:ystart+len-1)
193+
X.meta.meta[xstart:xstart+len-1] .= rhs
194+
val
195+
elseif ismetatype(X, ctx, JacOutput)
196+
val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len)
197+
for (i, j) in zip(xstart:xstart+len-1, ystart:ystart+len-1)
198+
y = Cassette.@overdub ctx Y[j]
199+
set = metadata(y, ctx)
200+
if set isa ProvinanceSet
201+
push!(S, i, set)
202+
end
203+
end
204+
val
205+
else
206+
val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len)
207+
for (i, j) in zip(xstart:xstart+len-1, ystart:ystart+len-1)
208+
y = Cassette.@overdub ctx Y[j]
209+
set = metadata(y, ctx)
210+
nometa = Cassette.NoMetaMeta()
211+
X.meta.meta[i] = Cassette.Meta(set, nometa)
212+
end
213+
val
214+
end
215+
end

src/propagate_tags.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,6 @@ end
3636
Called only if any of the `args` are Tagged.
3737
must return `result` or a tagged version of `result`.
3838
"""
39-
function propagate_tags(ctx, f, result, args...)
39+
@inline function propagate_tags(ctx, f, result, args...)
4040
result
4141
end

test/hessian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Term(i...) = TermCombination(Set([Dict(j=>1 for j in i)]))
3636
# copy
3737
@test hesstesttag(x->copy(x)[1], [1,2]) == Term(1)
3838
@test hesstesttag(x->x[:][1], [1,2]) == Term(1)
39-
@test hesstesttag(x->x[1:1][1], [1,2]) == Term(1)
39+
#@test hesstesttag(x->x[1:1][1], [1,2]) == Term(1)
4040

4141
# tests `iterate`
4242
function mysum(x)

test/jacobian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ let
1616
@test sparse(jactestmeta(f, [1], [2])) == sparse([], [], true, 1, 1)
1717

1818
g(y,x) = y[:] .= x .+ 1
19-
#g(y,x) = y .= x .+ 1 -- memove
19+
#g(y,x) = y .= x .+ 1 -- memmove
2020

2121
@test sparse(jactestmeta(g, [1], [2])) == sparse([1], [1], true)
2222
end

0 commit comments

Comments
 (0)