Skip to content

Commit 1efe316

Browse files
authored
Merge pull request #96 from torfjelde/tor/fix-95
Fix for #95
2 parents 37e3e4d + 4d1bb3e commit 1efe316

File tree

5 files changed

+108
-80
lines changed

5 files changed

+108
-80
lines changed

src/if_required/chainrulescore.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, ::Val{s}) where s = ChainRulesCore.rrule(getproperty, x, s)
2-
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Symbol)
1+
function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol, Val})
32
function getproperty_adjoint(Δ)
43
zero_x = zero(x)
54
setproperty!(zero_x, s, Δ)
6-
return (ChainRulesCore.NO_FIELDS, zero_x)
5+
return (ChainRulesCore.NoTangent(), zero_x, ChainRulesCore.NoTangent())
76
end
87

98
return getproperty(x, s), getproperty_adjoint
109
end
1110

12-
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->(ChainRulesCore.NO_FIELDS, ComponentArray(Δ, getaxes(x)))
11+
ChainRulesCore.rrule(::typeof(getdata), x::ComponentArray) = getdata(x), Δ->(ChainRulesCore.NoTangent(), ComponentArray(Δ, getaxes(x)))
1312

14-
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(ChainRulesCore.NO_FIELDS, getdata(Δ), getaxes(Δ))
13+
ChainRulesCore.rrule(::Type{ComponentArray}, data, axes) = ComponentArray(data, axes), Δ->(ChainRulesCore.NoTangent(), getdata(Δ), getaxes(Δ))

src/if_required/zygote.jl

Lines changed: 0 additions & 23 deletions
This file was deleted.

test/Manifest.toml

Lines changed: 81 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,49 @@
11
# This file is machine-generated - editing it directly is not advised
22

3+
[[AbstractFFTs]]
4+
deps = ["LinearAlgebra"]
5+
git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0"
6+
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
7+
version = "1.0.1"
8+
39
[[Adapt]]
410
deps = ["LinearAlgebra"]
5-
git-tree-sha1 = "f1b523983a58802c4695851926203b36e28f09db"
11+
git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7"
612
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
7-
version = "3.3.0"
13+
version = "3.3.1"
814

915
[[ArgTools]]
1016
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
1117

1218
[[ArrayInterface]]
1319
deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
14-
git-tree-sha1 = "2fbfa5f372352f92191b63976d070dc7195f47a4"
20+
git-tree-sha1 = "045ff5e1bc8c6fb1ecb28694abba0a0d55b5f4f5"
1521
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
16-
version = "3.1.7"
22+
version = "3.1.17"
1723

1824
[[Artifacts]]
1925
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
2026

2127
[[Base64]]
2228
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
2329

30+
[[ChainRules]]
31+
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"]
32+
git-tree-sha1 = "720fa9a9ce61ff18842a40f501d6a1f8ba771c64"
33+
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
34+
version = "0.8.6"
35+
2436
[[ChainRulesCore]]
2537
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
26-
git-tree-sha1 = "42e3c181483fbd2c416087a0a93838803e358358"
38+
git-tree-sha1 = "d659e42240c2162300b321f05173cab5cc40a5ba"
2739
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
28-
version = "0.9.38"
40+
version = "0.10.4"
2941

3042
[[ChainRulesTestUtils]]
3143
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
32-
git-tree-sha1 = "8951ac04086b1114303bfba244c1ce7b954c25a2"
44+
git-tree-sha1 = "ef004b4fd7c8ce775d19fc0b4b5c1030c51973e5"
3345
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
34-
version = "0.6.7"
46+
version = "0.7.9"
3547

3648
[[CommonSubexpressions]]
3749
deps = ["MacroTools", "Test"]
@@ -41,19 +53,19 @@ version = "0.3.0"
4153

4254
[[Compat]]
4355
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
44-
git-tree-sha1 = "ac4132ad78082518ec2037ae5770b6e796f7f956"
56+
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
4557
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
46-
version = "3.27.0"
58+
version = "3.30.0"
4759

4860
[[CompilerSupportLibraries_jll]]
4961
deps = ["Artifacts", "Libdl"]
5062
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
5163

5264
[[ConstructionBase]]
5365
deps = ["LinearAlgebra"]
54-
git-tree-sha1 = "48920211c95a6da1914a06c44ec94be70e84ffff"
66+
git-tree-sha1 = "1dc43957fb9a1574fa1b7a449e101bd1fd3a9fb7"
5567
uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
56-
version = "1.1.0"
68+
version = "1.2.1"
5769

5870
[[Dates]]
5971
deps = ["Printf"]
@@ -79,10 +91,22 @@ version = "1.0.2"
7991
deps = ["Random", "Serialization", "Sockets"]
8092
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
8193

94+
[[DocStringExtensions]]
95+
deps = ["LibGit2"]
96+
git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
97+
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
98+
version = "0.8.5"
99+
82100
[[Downloads]]
83101
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
84102
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
85103

104+
[[FillArrays]]
105+
deps = ["LinearAlgebra", "Random", "SparseArrays"]
106+
git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a"
107+
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
108+
version = "0.11.7"
109+
86110
[[FiniteDiff]]
87111
deps = ["ArrayInterface", "LinearAlgebra", "Requires", "SparseArrays", "StaticArrays"]
88112
git-tree-sha1 = "f6f80c8f934efd49a286bb5315360be66956dfc4"
@@ -91,9 +115,9 @@ version = "2.8.0"
91115

92116
[[FiniteDifferences]]
93117
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
94-
git-tree-sha1 = "a1c802a1407e4ff5a4733dedb8ed40bd6fb46021"
118+
git-tree-sha1 = "bdc9fb1d27a1ccecd2fe8f39c6211524cbe642cb"
95119
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
96-
version = "0.12.2"
120+
version = "0.12.13"
97121

98122
[[ForwardDiff]]
99123
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
@@ -106,6 +130,12 @@ git-tree-sha1 = "241552bc2209f0fa068b6415b1942cc0aa486bcc"
106130
uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
107131
version = "1.1.2"
108132

133+
[[IRTools]]
134+
deps = ["InteractiveUtils", "MacroTools", "Test"]
135+
git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510"
136+
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
137+
version = "0.4.2"
138+
109139
[[IfElse]]
110140
git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef"
111141
uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
@@ -123,9 +153,9 @@ version = "1.3.0"
123153

124154
[[LabelledArrays]]
125155
deps = ["ArrayInterface", "LinearAlgebra", "MacroTools", "StaticArrays"]
126-
git-tree-sha1 = "df09e970c816637191ef8794ef5c5c9f8950db88"
156+
git-tree-sha1 = "248a199fa42ec62922225334131c9330e1dd72f6"
127157
uuid = "2ee39098-c373-598a-b85f-a56591580800"
128-
version = "1.6.0"
158+
version = "1.6.1"
129159

130160
[[LibCURL]]
131161
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
@@ -150,6 +180,12 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
150180
deps = ["Libdl"]
151181
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
152182

183+
[[LogExpFunctions]]
184+
deps = ["DocStringExtensions", "LinearAlgebra"]
185+
git-tree-sha1 = "1ba664552f1ef15325e68dc4c05c3ef8c2d5d885"
186+
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
187+
version = "0.2.4"
188+
153189
[[Logging]]
154190
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
155191

@@ -183,25 +219,25 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
183219

184220
[[OffsetArrays]]
185221
deps = ["Adapt"]
186-
git-tree-sha1 = "b3dfef5f2be7d7eb0e782ba9146a5271ee426e90"
222+
git-tree-sha1 = "1381a7142eefd4cd12f052a4d2d790fe21bd1d55"
187223
uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
188-
version = "1.6.2"
224+
version = "1.9.2"
189225

190226
[[OpenSpecFun_jll]]
191227
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
192-
git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3"
228+
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
193229
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
194-
version = "0.5.3+4"
230+
version = "0.5.5+0"
195231

196232
[[Pkg]]
197233
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
198234
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
199235

200236
[[Preferences]]
201237
deps = ["TOML"]
202-
git-tree-sha1 = "ea79e4c9077208cd3bc5d29631a26bc0cff78902"
238+
git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a"
203239
uuid = "21216c6a-2e73-6563-6e65-726566657250"
204-
version = "1.2.1"
240+
version = "1.2.2"
205241

206242
[[Printf]]
207243
deps = ["Unicode"]
@@ -223,9 +259,9 @@ version = "1.1.3"
223259

224260
[[ReverseDiff]]
225261
deps = ["DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"]
226-
git-tree-sha1 = "6f8b8ce36bcefe8f2c16182c8e86c00748613e98"
262+
git-tree-sha1 = "63ee24ea0689157a1113dbdab10c6cb011d519c4"
227263
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
228-
version = "1.8.0"
264+
version = "1.9.0"
229265

230266
[[Richardson]]
231267
deps = ["LinearAlgebra"]
@@ -251,22 +287,22 @@ deps = ["LinearAlgebra", "Random"]
251287
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
252288

253289
[[SpecialFunctions]]
254-
deps = ["ChainRulesCore", "OpenSpecFun_jll"]
255-
git-tree-sha1 = "5919936c0e92cff40e57d0ddf0ceb667d42e5902"
290+
deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"]
291+
git-tree-sha1 = "a50550fa3164a8c46747e62063b4d774ac1bcf49"
256292
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
257-
version = "1.3.0"
293+
version = "1.5.1"
258294

259295
[[Static]]
260296
deps = ["IfElse"]
261-
git-tree-sha1 = "ddec5466a1d2d7e58adf9a427ba69763661aacf6"
297+
git-tree-sha1 = "2740ea27b66a41f9d213561a04573da5d3823d4b"
262298
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
263-
version = "0.2.4"
299+
version = "0.2.5"
264300

265301
[[StaticArrays]]
266302
deps = ["LinearAlgebra", "Random", "Statistics"]
267-
git-tree-sha1 = "e8cd1b100d37f5b4cfd2c83f45becf61c762eaf7"
303+
git-tree-sha1 = "42378d3bab8b4f57aa1ca443821b752850592668"
268304
uuid = "90137ffa-7385-5640-81b9-e52037218182"
269-
version = "1.1.1"
305+
version = "1.2.2"
270306

271307
[[Statistics]]
272308
deps = ["LinearAlgebra", "SparseArrays"]
@@ -293,14 +329,26 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
293329

294330
[[Unitful]]
295331
deps = ["ConstructionBase", "Dates", "LinearAlgebra", "Random"]
296-
git-tree-sha1 = "c6bbc170505c5ea36593a0072b61d3be8bf868ae"
332+
git-tree-sha1 = "b3682a0559219355f1e3c8024e9f97adce2d4623"
297333
uuid = "1986cc42-f94f-5a68-af5c-568840ba703d"
298-
version = "1.7.0"
334+
version = "1.8.0"
299335

300336
[[Zlib_jll]]
301337
deps = ["Libdl"]
302338
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
303339

340+
[[Zygote]]
341+
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
342+
git-tree-sha1 = "b1d95edd4e693066c38c13a10aab0a8f6a6e2f65"
343+
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
344+
version = "0.6.12"
345+
346+
[[ZygoteRules]]
347+
deps = ["MacroTools"]
348+
git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7"
349+
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
350+
version = "0.2.1"
351+
304352
[[nghttp2_jll]]
305353
deps = ["Artifacts", "Libdl"]
306354
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
99
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1010
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1111
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
12+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

test/autodiff_tests.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,32 @@
11
import FiniteDiff
22
import ForwardDiff
33
import ReverseDiff
4-
# import Zygote
4+
import Zygote
55

66
using Test
77

8-
grads = let
9-
F(x, θ, deg) = (θ[1] - x[1])^deg + θ[2] * (x[2] - x[1]^deg)^deg
8+
F(x, θ, deg) = (θ[1] - x[1])^deg + θ[2] * (x[2] - x[1]^deg)^deg
9+
F_idx_val(ca) = F(ca[Val(:x)], ca[Val()], ca[Val(:deg)])
10+
F_idx_sym(ca) = F(ca[:x], ca[], ca[:deg])
11+
F_prop(ca) = F(ca.x, ca.θ, ca.deg)
1012

11-
F_idx_val(ca) = F(ca[Val(:x)], ca[Val()], ca[Val(:deg)])
12-
F_idx_sym(ca) = F(ca[:x], ca[], ca[:deg])
13-
F_prop(ca) = F(ca.x, ca.θ, ca.deg)
13+
ca = ComponentArray(x = [1, 2], θ = [1.0, 100.0], deg = 2)
14+
truth = [-400, 200]
1415

15-
ca = ComponentArray(x = [1, 2], θ = [1.0, 100.0], deg = 2)
16+
@testset "$(nameof(F_))" for F_ in (F_idx_val, F_idx_sym, F_prop)
17+
finite = FiniteDiff.finite_difference_gradient(ca -> F_(ca), ca).x
18+
@test finite truth
1619

17-
(
18-
truth = [-400, 200],
19-
finite = FiniteDiff.finite_difference_gradient(ca -> F_prop(ca), ca).x,
20-
forward = ForwardDiff.gradient(ca -> F_prop(ca), ca).x,
21-
reverse = ReverseDiff.gradient(ca -> F_prop(ca), ca).x,
22-
# zygote = Zygote.gradient(ca -> F(ca), ca)[1].x,
23-
)
24-
end
20+
forward = ForwardDiff.gradient(ca -> F_(ca), ca).x
21+
@test forward truth
22+
23+
reverse = ReverseDiff.gradient(ca -> F_(ca), ca).x
24+
if F_ in (F_idx_val, F_idx_sym)
25+
@test_broken reverse truth
26+
else
27+
@test reverse truth
28+
end
2529

26-
@test grads.finite grads.truth
27-
@test grads.forward grads.truth
28-
@test grads.reverse grads.truth
29-
# @test grads.zygote ≈ grads.truth
30+
zygote = Zygote.gradient(ca -> F_(ca), ca)[1].x
31+
@test zygote truth
32+
end

0 commit comments

Comments
 (0)