Skip to content
This repository was archived by the owner on Jun 24, 2022. It is now read-only.

Commit 0237121

Browse files
committed
test setup
1 parent 78f4923 commit 0237121

File tree

8 files changed

+41
-29
lines changed

8 files changed

+41
-29
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ authors = ["Shashi Gowda", "Chris Rackauckas <contact@chrisrackauckas.com>"]
44
version = "0.1.1"
55

66
[deps]
7-
Amb = "c42f9944-9a8f-11e9-2851-1930a3e1c813"
87
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
98
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
109
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -14,7 +13,8 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1413
julia = "1"
1514

1615
[extras]
16+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1717
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1818

1919
[targets]
20-
test = ["Test"]
20+
test = ["Test", "FiniteDiff"]

src/SparsityDetection.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ include("jacobian.jl")
1616
include("hessian.jl")
1717
include("blas.jl")
1818

19-
Base.@deprecate sparsity!(args...) jacobian_sparsity(args...)
20-
Base.@deprecate hsparsity(args...) hessian_sparsity(args...)
19+
sparsity!(args...; kwargs...) = jacobian_sparsity(args...; kwargs...)
20+
hsparsity(args...; kwargs...) = hessian_sparsity(args...; kwargs...)
2121

2222
end

src/hessian.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -222,17 +222,19 @@ end
222222

223223
function hessian_sparsity(f, X, args...; raw=false)
224224

225-
terms = zero(TermCombination)
225+
terms = Ref(zero(TermCombination))
226226
ctx = HessianSparsityContext()
227227
ctx = Cassette.enabletagging(ctx, f)
228228
ctx = Cassette.disablehooks(ctx)
229229
val = nothing
230230
function process(result)
231-
try
232-
terms += metadata(result, ctx)
233-
catch err
234-
@warn("Could not extract hessian sparsity")
235-
println(err)
231+
if Cassette.hasmetadata(result, ctx)
232+
try
233+
terms[] += metadata(result, ctx)
234+
catch err
235+
@warn("Could not extract hessian sparsity")
236+
println(err)
237+
end
236238
end
237239
val=result
238240
end
@@ -242,9 +244,9 @@ function hessian_sparsity(f, X, args...; raw=false)
242244
arg.value : tag(arg, ctx, one(TermCombination)), args)...)
243245

244246
if raw
245-
return ctx, val
247+
return ctx, terms[], val
246248
end
247-
_sparse(terms, length(X))
249+
_sparse(terms[], length(X))
248250
end
249251

250252

src/jacobian.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ end
124124

125125
function jacobian_sparsity(f!, Y, X, args...;
126126
sparsity=Sparsity(length(Y), length(X)),
127+
verbose = true,
127128
raw = false)
128129

129130
ctx = JacobianSparsityContext(metadata=sparsity)
@@ -137,7 +138,8 @@ function jacobian_sparsity(f!, Y, X, args...;
137138
tag(Y, ctx, JacOutput()),
138139
tag(X, ctx, JacInput()),
139140
map(arg -> arg isa Fixed ?
140-
arg.value : tag(arg, ctx, ProvinanceSet(())), args)...)
141+
arg.value : tag(arg, ctx, ProvinanceSet(())), args)...;
142+
verbose=verbose)
141143

142144
if raw
143145
return (ctx, res)

src/propagate_tags.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,17 @@ macro proptagcontext(name)
1212
## of x and y. So here we recurse into promote, and then tag each
1313
## element of the result with the original tag
1414
function Cassette.overdub(ctx::$name, f::typeof(promote), args...)
15-
promoted = Cassette.recurse(ctx, f, args...)
15+
promoted = Cassette.fallback(ctx, f, args...)
1616

1717
# put the tags back on:
18-
tagged_promoted = map((x,v)->tag(v, ctx, metadata(x, ctx)), args, promoted)
19-
Cassette.overdub(ctx, tuple, tagged_promoted...)
18+
tagged_promoted = map(args, promoted) do orig, prom
19+
if Cassette.hasmetadata(orig, ctx)
20+
tag(prom, ctx, metadata(orig, ctx))
21+
else
22+
prom
23+
end
24+
end
25+
Cassette.recurse(ctx, tuple, tagged_promoted...)
2026
end
2127

2228
function Cassette.overdub(ctx::$name, f, args...)

src/sparsity.jl

Whitespace-only changes.

test/common.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,28 @@ using Cassette
44
import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged
55
using SparseArrays
66
using SparsityDetection
7-
import SparsityDetection: abstract_run, HessianSparsityContext, JacobianSparsityContext
7+
import SparsityDetection: abstract_run, HessianSparsityContext,
8+
JacobianSparsityContext, TermCombination, HessInput
9+
using Test
10+
11+
Term(i...) = TermCombination(Set([Dict(j=>1 for j in i)]))
812

913
function jactester(f, Y, X, args...)
1014
ctx, val = jacobian_sparsity(f, Y, X, args...; raw=true)
1115
end
1216

13-
jactestmeta(args...) = jactester(args...)[1].metadata[1]
17+
jactestmeta(args...) = jactester(args...)[1].metadata
1418
jactestval(args...) = jactester(args...) |> ((ctx,val),) -> untag(val, ctx)
1519
jactesttag(args...) = jactester(args...) |> ((ctx,val),) -> metadata(val, ctx)
1620

1721
function hesstester(f, X, args...)
18-
ctx, val = hessian_sparsity(f, X, args...; raw=true)
22+
ctx, terms, val = hessian_sparsity(f, X, args...; raw=true)
1923
end
2024

21-
hesstestmeta(args...) = hesstester(args...)[1].metadata[1]
22-
hesstestval(args...) = hesstester(args...) |> ((ctx,val),) -> untag(val, ctx)
23-
hesstesttag(args...) = hesstester(args...) |> ((ctx,val),) -> metadata(val, ctx)
25+
hesstestmeta(args...) = hesstester(args...)[1].metadata
26+
hesstestval(args...) = hesstester(args...) |> ((ctx,terms,val),) -> untag(val, ctx)
27+
hesstesttag(args...) = hesstester(args...) |> ((ctx,terms,val),) -> metadata(val, ctx)
28+
hesstestterms(args...) = hesstester(args...) |> ((ctx,terms,val),) -> terms
2429

2530

2631
Base.show(io::IO, ::Type{<:Cassette.Context}) = print(io, "ctx")

test/hessian.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
1-
import SparsityDetection: TermCombination, HessInput
2-
using Test
31

4-
Term(i...) = TermCombination(Set([Dict(j=>1 for j in i)]))
5-
6-
@test hesstesttag(x->x, [1,2]) == HessInput()
7-
@test hesstesttag(x->x[1], [1,2]) == Term(1)
2+
@test hesstestterms(x->x[1], [1,2]) == Term(1)
83

94
# Tuple / struct
105
@test hesstesttag(x->(x[1],x[2])[2], [1,2]) == Term(2)
116

7+
@test hesstesttag(x->promote(x[1],convert(Float64, x[2]))[2], [1,2]) == Term(2)
8+
129
# 1-arg linear
1310
@test hesstesttag(x->deg2rad(x[1]), [1,2]) == Term(1)
1411

1512
# 1-arg nonlinear
16-
@test hesstesttag(x->sin(x[1]), [1,2]) == (Term(1) * Term(1))
13+
@test hesstestterms(x->sin(x[1]), [1,2]) == (Term(1) + Term(1) * Term(1))
1714

1815
# 2-arg (true,true,true)
1916
@test hesstesttag(x->x[1]+x[2], [1,2]) == Term(1)+Term(2)

0 commit comments

Comments
 (0)