Skip to content

Commit a57b46e

Browse files
sethaxenoxinabox
andauthored
Add dot overloads for Zero and One (#201)
* Add LinearAlgebra as dependancy * Add overrides for dot * Test dot rules * Increment version number * Increment version number * Add dot overloads for DoesNotExist * Add missing overload * Add comment Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> * Don't use 1 Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
1 parent 39f1caf commit a57b46e

File tree

7 files changed

+40
-10
lines changed

7 files changed

+40
-10
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.9.9"
3+
version = "0.9.10"
44

55
[deps]
6+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
67
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
78

89
[compat]
@@ -15,9 +16,8 @@ julia = "^1.0"
1516
[extras]
1617
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
1718
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
18-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1919
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2020
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2121

2222
[targets]
23-
test = ["Test", "BenchmarkTools", "FiniteDifferences", "LinearAlgebra", "StaticArrays"]
23+
test = ["Test", "BenchmarkTools", "FiniteDifferences", "StaticArrays"]

src/ChainRulesCore.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module ChainRulesCore
22
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
3+
using LinearAlgebra: LinearAlgebra
34
using MuladdMacro: @muladd
45

56
export on_new_rule, refresh_rules # generation tools

src/differential_arithmetic.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Base.:+(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
1818
Base.:-(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
1919
Base.:-(::DoesNotExist) = DoesNotExist()
2020
Base.:*(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
21+
LinearAlgebra.dot(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
2122
for T in (:One, :AbstractThunk, :Composite, :Any)
2223
@eval Base.:+(::DoesNotExist, b::$T) = b
2324
@eval Base.:+(a::$T, ::DoesNotExist) = a
@@ -26,6 +27,9 @@ for T in (:One, :AbstractThunk, :Composite, :Any)
2627

2728
@eval Base.:*(::DoesNotExist, ::$T) = DoesNotExist()
2829
@eval Base.:*(::$T, ::DoesNotExist) = DoesNotExist()
30+
31+
@eval LinearAlgebra.dot(::DoesNotExist, ::$T) = DoesNotExist()
32+
@eval LinearAlgebra.dot(::$T, ::DoesNotExist) = DoesNotExist()
2933
end
3034
# `DoesNotExist` and `Zero` have special relationship,
3135
# DoesNotExist wins add, Zero wins *. This is (in theory) to allow `*` to be used for
@@ -37,6 +41,9 @@ Base.:-(::Zero, ::DoesNotExist) = DoesNotExist()
3741
Base.:*(::DoesNotExist, ::Zero) = Zero()
3842
Base.:*(::Zero, ::DoesNotExist) = Zero()
3943

44+
LinearAlgebra.dot(::DoesNotExist, ::Zero) = Zero()
45+
LinearAlgebra.dot(::Zero, ::DoesNotExist) = Zero()
46+
4047
Base.muladd(::Zero, x, y) = y
4148
Base.muladd(x, ::Zero, y) = y
4249
Base.muladd(x, y, ::Zero) = x*y
@@ -51,6 +58,7 @@ Base.:+(::Zero, ::Zero) = Zero()
5158
Base.:-(::Zero, ::Zero) = Zero()
5259
Base.:-(::Zero) = Zero()
5360
Base.:*(::Zero, ::Zero) = Zero()
61+
LinearAlgebra.dot(::Zero, ::Zero) = Zero()
5462
for T in (:One, :AbstractThunk, :Composite, :Any)
5563
@eval Base.:+(::Zero, b::$T) = b
5664
@eval Base.:+(a::$T, ::Zero) = a
@@ -59,6 +67,9 @@ for T in (:One, :AbstractThunk, :Composite, :Any)
5967

6068
@eval Base.:*(::Zero, ::$T) = Zero()
6169
@eval Base.:*(::$T, ::Zero) = Zero()
70+
71+
@eval LinearAlgebra.dot(::Zero, ::$T) = Zero()
72+
@eval LinearAlgebra.dot(::$T, ::Zero) = Zero()
6273
end
6374

6475
Base.real(::Zero) = Zero()
@@ -88,6 +99,8 @@ for T in (:AbstractThunk, :Composite, :Any)
8899
@eval Base.:*(a::$T, ::One) = a
89100
end
90101

102+
LinearAlgebra.dot(::One, x::Number) = x
103+
LinearAlgebra.dot(x::Number, ::One) = conj(x) # see definition of Frobenius inner product
91104

92105
Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b)
93106
Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b)

test/differentials/abstract_zero.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@
1515
@test 1 - z === 1
1616
@test -z === z
1717
@test z * z === z
18-
@test z * 1 === Zero()
19-
@test 1 * z === Zero()
18+
@test z * 11.1 === Zero()
19+
@test 12.3 * z === Zero()
20+
@test dot(z, z) === z
21+
@test dot(z, 1.8) === z
22+
@test dot(2.1, z) === z
23+
@test dot([1, 2], z) === z
24+
@test dot(z, [1, 2]) === z
2025
for x in z
2126
@test x === z
2227
end
@@ -67,8 +72,11 @@
6772
@test 1 - dne == 1
6873
@test -dne == dne
6974
@test dne * dne == dne
70-
@test dne * 1 == dne
71-
@test 1 * dne == dne
75+
@test dne * 11.1 == dne
76+
@test 12.1 * dne == dne
77+
@test dot(dne, dne) == dne
78+
@test dot(dne, 17.2) == dne
79+
@test dot(11.9, dne) == dne
7280

7381
@test Zero() + dne == dne
7482
@test dne + Zero() == dne
@@ -77,6 +85,8 @@
7785

7886
@test Zero() * dne == Zero()
7987
@test dne * Zero() == Zero()
88+
@test dot(Zero(), dne) == Zero()
89+
@test dot(dne, Zero()) == Zero()
8090

8191
for x in dne
8292
@test x === dne

test/differentials/composite.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,13 @@ end
185185

186186
@test DoesNotExist() * c == DoesNotExist()
187187
@test c * DoesNotExist() == DoesNotExist()
188+
@test dot(DoesNotExist(), c) == DoesNotExist()
189+
@test dot(c, DoesNotExist()) == DoesNotExist()
188190

189191
@test Zero() * c == Zero()
190192
@test c * Zero() == Zero()
193+
@test dot(Zero(), c) == Zero()
194+
@test dot(c, Zero()) == Zero()
191195

192196
@test One() * c === c
193197
@test c * One() === c

test/differentials/one.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
@test o + 1 == 2
66
@test 1 + o == 2
77
@test o * o == o
8-
@test o * 1 == 1
9-
@test 1 * o == 1
8+
@test o * 17 == 17
9+
@test 6 * o == 6
10+
@test dot(2 + im, o) == 2 - im
11+
@test dot(o, 2 + im) == 2 + im
1012
for x in o
1113
@test x === o
1214
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Base.Broadcast: broadcastable
22
using BenchmarkTools
33
using ChainRulesCore
4-
using LinearAlgebra: Diagonal
4+
using LinearAlgebra: Diagonal, dot
55
using FiniteDifferences
66
using Test
77

0 commit comments

Comments
 (0)