Skip to content
This repository was archived by the owner on Jun 24, 2022. It is now read-only.

Commit 92b8656

Browse files
committed
splice in some more fixes
1 parent e37b27f commit 92b8656

File tree

6 files changed

+48
-45
lines changed

6 files changed

+48
-45
lines changed

src/SparsityDetection.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ include("propagate_tags.jl")
1414
include("linearity.jl")
1515
include("jacobian.jl")
1616
include("hessian.jl")
17+
include("blas.jl")
1718

1819
Base.@deprecate sparsity!(args...) jacobian_sparsity(args...)
1920
Base.@deprecate hsparsity(args...) hessian_sparsity(args...)

src/blas.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Forward BLAS calls to generic implementation
2+
#
3+
using LinearAlgebra
4+
import LinearAlgebra.BLAS
5+
6+
# generic implementations
7+
8+
@reroute LinearAlgebra.BLAS.dot(x,y) LinearAlgebra.dot(Any, Any)
9+
@reroute LinearAlgebra.BLAS.axpy!(x, y) LinearAlgebra.axpy!(Any,
10+
AbstractArray,
11+
AbstractArray)
12+
13+
gengemv!(tA, α, A, x, β, y) = LinearAlgebra.generic_matvecmul!(y, tA, A, x, LinearAlgebra.MulAddMul(α, β))
14+
15+
@reroute LinearAlgebra.BLAS.gemv!(tA, α, A, x, β, y) gengemv!(Any, Any, Any, Any, Any, Any)

src/controlflow.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,7 @@ end
3434
"""
3535
function abstract_run(acc, ctx::Cassette.Context, overdub_fn, args...)
3636
pass_ctx = Cassette.similarcontext(ctx, pass=AbsintPass)
37-
@ambrun begin
3837
acc(Cassette.overdub(pass_ctx, overdub_fn, args...))
39-
@amb
40-
end
4138
end
4239

4340
"""

src/hessian.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -248,14 +248,3 @@ function hessian_sparsity(f, X, args...; raw=false)
248248
end
249249

250250

251-
# Forward BLAS calls to generic implementation
252-
#
253-
using LinearAlgebra
254-
import LinearAlgebra.BLAS
255-
256-
# generic implementations
257-
258-
@reroute HessianSparsityContext BLAS.dot dot(Any, Any)
259-
@reroute HessianSparsityContext BLAS.axpy! axpy!(Any,
260-
AbstractArray,
261-
AbstractArray)

src/jacobian.jl

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -128,24 +128,6 @@ function Cassette.overdub(ctx::JacobianSparsityContext,
128128
end
129129
end
130130

131-
function Cassette.overdub(ctx::JacobianSparsityContext,
132-
f::typeof(Base.unsafe_copyto!),
133-
X::Tagged,
134-
xstart,
135-
Y::Tagged,
136-
ystart,
137-
len)
138-
S = ctx.metadata
139-
if metatype(Y, ctx) <: JacInput
140-
val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len)
141-
nometa = Cassette.NoMetaMeta()
142-
X.meta.meta[xstart:xstart+len-1] .= (i->Cassette.Meta(ProvinanceSet(i), nometa)).(ystart:ystart+len-1)
143-
val
144-
else
145-
Cassette.recurse(ctx, f, X, xstart, Y, ystart, len)
146-
end
147-
end
148-
149131
function jacobian_sparsity(f!, Y, X, args...;
150132
sparsity=Sparsity(length(Y), length(X)),
151133
raw = false)
@@ -170,29 +152,29 @@ function jacobian_sparsity(f!, Y, X, args...;
170152
end
171153
end
172154

173-
function Cassette.overdub(ctx::SparsityContext,
155+
function Cassette.overdub(ctx::JacobianSparsityContext,
174156
f::typeof(Base.unsafe_copyto!),
175157
X::Tagged,
176158
xstart,
177159
Y::Tagged,
178160
ystart,
179161
len)
180162
S = ctx.metadata
181-
if ismetatype(Y, ctx, JacInput) && ismetatype(X, ctx, JacOutput)
163+
if metatype(Y, ctx) <: JacInput && metatype(X, ctx) <: JacOutput
182164
# Write directly to the output sparsity
183165
val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len)
184166
for (i, j) in zip(xstart:xstart+len-1, ystart:ystart+len-1)
185167
push!(S, i, j)
186168
end
187169
val
188-
elseif ismetatype(Y, ctx, JacInput)
170+
elseif metatype(Y, ctx) <: JacInput
189171
# Keep around a ProvinanceSet
190172
val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len)
191173
nometa = Cassette.NoMetaMeta()
192-
rhs = (i->Cassette.Meta(pset(i), nometa)).(ystart:ystart+len-1)
174+
rhs = (i->Cassette.Meta(ProvinanceSet(i), nometa)).(ystart:ystart+len-1)
193175
X.meta.meta[xstart:xstart+len-1] .= rhs
194176
val
195-
elseif ismetatype(X, ctx, JacOutput)
177+
elseif metatype(X, ctx) <: JacOutput
196178
val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len)
197179
for (i, j) in zip(xstart:xstart+len-1, ystart:ystart+len-1)
198180
y = Cassette.@overdub ctx Y[j]

src/util.jl

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,36 @@ function metatype(x, ctx)
1212
end
1313
end
1414

15-
macro reroute(ctx, f, g)
15+
# generic implementations
16+
17+
_name(x::Symbol) = x
18+
_name(x::Expr) = (@assert x.head == :(::); x.args[1])
19+
macro reroute(f, g)
20+
fname = f.args[1]
21+
fargs = f.args[2:end]
22+
gname = g.args[1]
23+
gargs = g.args[2:end]
1624
quote
17-
function Cassette.overdub(ctx::$ctx,
18-
f::typeof($(esc(f))),
19-
args...)
20-
Cassette.overdub(
25+
@inline function Cassette.overdub(ctx::JacobianSparsityContext,
26+
f::typeof($(esc(fname))),
27+
$(fargs...))
28+
Cassette.recurse(
29+
ctx,
30+
invoke,
31+
$(esc(gname)),
32+
$(esc(:(Tuple{$(gargs...)}))),
33+
$(map(_name, fargs)...))
34+
end
35+
36+
@inline function Cassette.overdub(ctx::HessianSparsityContext,
37+
f::typeof($(esc(fname))),
38+
$(fargs...))
39+
Cassette.recurse(
2140
ctx,
2241
invoke,
23-
$(esc(g.args[1])),
24-
$(esc(:(Tuple{$(g.args[2:end]...)}))),
25-
args...)
42+
$(esc(gname)),
43+
$(esc(:(Tuple{$(gargs...)}))),
44+
$(map(_name, fargs)...))
2645
end
2746
end
2847
end

0 commit comments

Comments
 (0)