Skip to content

Commit ecd2bb6

Browse files
committed
test clear hooks
1 parent 7d32509 commit ecd2bb6

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

test/ruleset_loading.jl

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,31 @@
1010
op = sig.parameters[1]
1111
push!(rrule_history, op)
1212
end
13+
14+
@testset "new rules hit the hooks" begin
15+
# Now define some rules
16+
@scalar_rule x + y (1, 1)
17+
@scalar_rule x - y (1, -1)
18+
refresh_rules()
19+
20+
@test Set(frule_history[end-1:end]) == Set((typeof(+), typeof(-)))
21+
@test Set(rrule_history[end-1:end]) == Set((typeof(+), typeof(-)))
22+
end
1323

14-
# Now define some rules
15-
@scalar_rule x + y (1, 1)
16-
@scalar_rule x - y (1, -1)
17-
refresh_rules()
24+
@testset "# Make sure nothing happens anymore once we clear the hooks" begin
25+
ChainRulesCore.clear_new_rule_hooks!(frule)
26+
ChainRulesCore.clear_new_rule_hooks!(rrule)
27+
28+
old_frule_history = copy(frule_history)
29+
old_rrule_history = copy(rrule_history)
30+
31+
@scalar_rule sin(x) cos(x)
32+
refresh_rules()
33+
34+
@test old_rrule_history == rrule_history
35+
@test old_frule_history == frule_history
36+
end
1837

19-
@test Set(frule_history[end-1:end]) == Set((typeof(+), typeof(-)))
20-
@test Set(rrule_history[end-1:end]) == Set((typeof(+), typeof(-)))
21-
22-
ChainRulesCore.clear_new_rule_hooks!(frule)
23-
ChainRulesCore.clear_new_rule_hooks!(rrule)
2438
end
2539

2640
@testset "_primal_sig" begin

0 commit comments

Comments
 (0)