Skip to content

Commit 1d2e11b

Browse files
committed
add tests for issues closed
1 parent 2217614 commit 1d2e11b

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

test/complex.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using Zygote, Test, LinearAlgebra
22

3+
@testset "basic" begin
4+
35
@test gradient(x -> real(abs(x)*exp(im*angle(x))), 10+20im)[1] 1
46
@test gradient(x -> imag(real(x)+0.3im), 0.3)[1] 0
57
@test gradient(x -> imag(conj(x)+0.3im), 0.3 + 0im)[1] -1im
@@ -23,6 +25,8 @@ using Zygote, Test, LinearAlgebra
2325
@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] real(im .* exp.(1:3))
2426
@test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] im .* exp.(1:3)
2527

28+
end # @testset
29+
2630
fs_C_to_R = (real,
2731
imag,
2832
abs,
@@ -83,3 +87,26 @@ fs_C_to_C_non_holomorphic = (conj,
8387
end
8488
end
8589
end
90+
91+
@testset "issue 342" begin
92+
@test Zygote.gradient(x->real(x + 2.0*im), 3.0) == (1.0,)
93+
@test Zygote.gradient(x->imag(x + 2.0*im), 3.0) == (0.0,)
94+
end
95+
96+
@testset "issue 402" begin
97+
A = [1,2,3.0]
98+
y, B_getindex = Zygote.pullback(x->getindex(x,2,1),Diagonal(A))
99+
bA = B_getindex(1)[1]
100+
@test bA isa Diagonal
101+
@test bA == [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]
102+
end
103+
104+
@testset "issue #917" begin
105+
function fun(v)
106+
c = v[1:3] + v[4:6]*im
107+
r = v[7:9]
108+
sum(r .* abs2.(c)) # This would be calling my actual function depending on r and c
109+
end
110+
@test Zygote.hessian(fun, collect(1:9)) [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0]
111+
end
112+

0 commit comments

Comments
 (0)