Skip to content

Commit c782e40

Browse files
YingboMaoxinabox
andauthored
Make One, Zero, and DNE inferable when broadcasting with StaticArrays (#96)
* Make `One`, `Zero`, and `DNE` inferable when broadcasting with StaticArrays * New release * Remove superfluous [deps] * Update test/rules.jl Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au> Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
1 parent 932b704 commit c782e40

File tree

5 files changed

+17
-3
lines changed

5 files changed

+17
-3
lines changed

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.5.2"
3+
version = "0.5.3"
44

55
[deps]
66
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
77

88
[compat]
9-
julia = "^1.0"
109
MuladdMacro = "0.2.1"
10+
StaticArrays = "0.11, 0.12"
11+
julia = "^1.0"
1112

1213
[extras]
1314
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
15+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1416
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1517

1618
[targets]
17-
test = ["Test", "LinearAlgebra"]
19+
test = ["Test", "LinearAlgebra", "StaticArrays"]

src/differentials/does_not_exist.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ function extern(x::DoesNotExist)
3333
end
3434

3535
Base.Broadcast.broadcastable(::DoesNotExist) = Ref(DoesNotExist())
36+
Base.Broadcast.broadcasted(::Type{DoesNotExist}) = DoesNotExist()
3637

3738
Base.iterate(x::DoesNotExist) = (x, nothing)
3839
Base.iterate(::DoesNotExist, ::Any) = nothing

src/differentials/one.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ struct One <: AbstractDifferential end
88
extern(x::One) = true # true is a strong 1.
99

1010
Base.Broadcast.broadcastable(::One) = Ref(One())
11+
Base.Broadcast.broadcasted(::Type{One}) = One()
1112

1213
Base.iterate(x::One) = (x, nothing)
1314
Base.iterate(::One, ::Any) = nothing

src/differentials/zero.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ struct Zero <: AbstractDifferential end
88
extern(x::Zero) = false # false is a strong 0. E.g. `false * NaN = 0.0`
99

1010
Base.Broadcast.broadcastable(::Zero) = Ref(Zero())
11+
Base.Broadcast.broadcasted(::Type{Zero}) = Zero()
1112

1213
Base.iterate(x::Zero) = (x, nothing)
1314
Base.iterate(::Zero, ::Any) = nothing

test/rules.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#######
22
# Demo setup
3+
using StaticArrays: @SVector
34

45
cool(x) = x + 1
56
cool(x, y) = x + y + 1
@@ -11,6 +12,9 @@ dummy_identity(x) = x
1112
nice(x) = 1
1213
@scalar_rule(nice(x), Zero())
1314

15+
very_nice(x, y) = x + y
16+
@scalar_rule(very_nice(x, y), (One(), One()))
17+
1418
#######
1519

1620
_second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
@@ -46,4 +50,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
4650
@test nice_pushforward === 0
4751
rrx, nice_pullback = rrule(nice, 1)
4852
@test (NO_FIELDS, 0) === nice_pullback(1)
53+
54+
sx = @SVector [1, 2]
55+
sy = @SVector [3, 4]
56+
# This actually is testing that @scalar_rule and `One()` play nice together, w.r.t broadcasting
57+
@inferred frule(very_nice, 1, 2, Zero(), sx, sy)
4958
end

0 commit comments

Comments
 (0)