|
1 | 1 | ### Rules |
2 | 2 | using Test |
3 | | -using Cassette |
4 | | -import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged |
5 | | -using SparseArrays |
6 | | -using SparsityDetection |
7 | | -import SparsityDetection: abstract_run, HessianSparsityContext, |
8 | | - JacobianSparsityContext, TermCombination, HessInput |
9 | | -using Test |
| 3 | +using Cassette, SparsityDetection |
| 4 | +using SparseArrays, Test |
| 5 | + |
| 6 | +using Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged |
| 7 | +using SparsityDetection: Path, BranchesPass, SparsityContext, Fixed, |
| 8 | + Input, Output, pset, Tainted, istainted, |
| 9 | + alldone, reset!, HessianSparsityContext |
| 10 | +using SparsityDetection: TermCombination |
10 | 11 |
|
11 | 12 | Term(i...) = TermCombination(Set([Dict(j=>1 for j in i)])) |
12 | 13 |
|
13 | 14 | function jactester(f, Y, X, args...) |
14 | 15 | ctx, val = jacobian_sparsity(f, Y, X, args...; raw=true) |
| 16 | + val = nothing |
| 17 | + while true |
| 18 | + val = Cassette.overdub(ctx, |
| 19 | + f, |
| 20 | + tag(Y, ctx, Output()), |
| 21 | + tag(X, ctx, Input()), |
| 22 | + map(arg -> arg isa Fixed ? |
| 23 | + arg.value : |
| 24 | + tag(arg, ctx, pset()), args)...) |
| 25 | + println("Explored path: ", path) |
| 26 | + alldone(path) && break |
| 27 | + reset!(path) |
| 28 | + end |
| 29 | + return ctx, val |
15 | 30 | end |
16 | 31 |
|
17 | 32 | jactestmeta(args...) = jactester(args...)[1].metadata |
|
0 commit comments