Skip to content

Commit 63de708

Browse files
authored
Merge pull request #319 from JuliaDiff/ox/scope
Add qualification to at-nondifferentiable
2 parents 7404c03 + 262f99b commit 63de708

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.9.30"
3+
version = "0.9.31"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/rule_definition_tools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
361361
)
362362
return esc(quote
363363
# Manually defined kw version to save compiler work. See explanation in rules.jl
364-
function (::Core.kwftype(typeof(rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), $(primal_sig_parts...))
364+
function (::Core.kwftype(typeof(ChainRulesCore.rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), $(primal_sig_parts...))
365365
return ($(_with_kwargs_expr(primal_invoke)), $pullback_expr)
366366
end
367367
function ChainRulesCore.rrule($(primal_sig_parts...))

test/rule_definition_tools.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,24 @@ end
251251

252252

253253
end
254+
255+
256+
module IsolatedModuleForTestingScoping
257+
using Test
258+
# need to make sure macros work in something that hasn't imported all exports
259+
# all that matters is that the following don't error, since they will resolve at
260+
# parse time
261+
using ChainRulesCore: ChainRulesCore
262+
263+
@testset "@non_differentiable" begin
264+
# this is
265+
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/317
266+
fixed(x) = :abc
267+
ChainRulesCore.@non_differentiable fixed(x)
268+
end
269+
270+
@testset "@scalar_rule" begin
271+
my_id(x) = x
272+
ChainRulesCore.@scalar_rule(my_id(x), 1.0)
273+
end
274+
end

0 commit comments

Comments
 (0)