Skip to content
This repository was archived by the owner on Jun 24, 2022. It is now read-only.

Commit 34aaceb

Browse files
committed
fix promote
1 parent 69e736e commit 34aaceb

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

src/controlflow.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ function rewrite_ir(ctx, ref)
221221

222222
ir.ssavaluetypes = length(ir.code)
223223
# Core.Compiler.validate_code(ir)
224+
#@show ref.method
225+
#@show ir
224226
return ir
225227
end
226228

src/propagate_tags.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,17 @@ macro proptagcontext(name)
88
quote
99
Cassette.@context($name)
1010

11+
## promote(x,y) should not tag the output tuple with the union of tags
12+
## of x and y. So here we recurse into promote, and then tag each
13+
## element of the result with the original tag
14+
function Cassette.overdub(ctx::$name, f::typeof(promote), args...)
15+
promoted = Cassette.recurse(ctx, f, args...)
16+
17+
# put the tags back on:
18+
tagged_promoted = map((x,v)->tag(v, ctx, metadata(x, ctx)), args, promoted)
19+
Cassette.overdub(ctx, tuple, tagged_promoted...)
20+
end
21+
1122
function Cassette.overdub(ctx::$name, f, args...)
1223
# this check can be inferred (in theory)
1324
if anytagged(args...)

test/examples.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ using SparsityDetection, SparseArrays
1313
input = rand(10)
1414
output = similar(input)
1515
sparsity_pattern = sparsity!(f,output,input)
16-
jac = Float64.(sparse(sparsity_pattern))
16+
17+
@test nnz(sparse(sparsity_pattern)) == 28

0 commit comments

Comments
 (0)