Skip to content

Commit 3433cdd

Browse files
ToucheSirmcabbott
andcommitted
Add accum_param tests
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
1 parent e9a6075 commit 3433cdd

File tree

1 file changed

+57
-4
lines changed

1 file changed

+57
-4
lines changed

test/features.jl

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ end
476476
@test_broken gradient(x -> abs2(x[1].x) + 7 * x[1].x.re, [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],)
477477
@test_broken gradient(x -> abs2(x[1].x) + 7 * real(x[1].x), [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],) # worked on 0.6.0, 0.6.20
478478

479-
@test_broken gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) == ((x = 9.0 + 2.0im,),) # gives nothing, same in 0.6.0
479+
@test gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) == ((x = (x = 9.0 + 2.0im,),),) # gave `nothing` from 0.6.0 to 0.6.41
480480

481481
# Array of mutables:
482482
@test gradient(x -> sum(getindex.(x).^2), Ref.(1:3))[1] == [(;x=2i) for i in 1:3]
@@ -490,6 +490,59 @@ end
490490
@test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],)
491491
end
492492

493+
@testset "mutable accum_param bugs" begin
494+
mutable struct Mut{T}; x::T; end
495+
struct Imm{T}; x::T; end
496+
497+
# Indexing a tuple containing a mutable struct gave `nothing`
498+
x1 = (Mut(3.0),)
499+
x2 = (Imm(3.0),)
500+
x3 = (Ref(3.0),)
501+
@test gradient(x -> x[1].x^2, x1)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
502+
@test gradient(x -> x[1].x^2, x2)[1] == ((x = 6.0,),)
503+
@test gradient(x -> x[1].x^2, x3)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
504+
i1 = 1
505+
@test gradient(x -> x[i1].x^2, x1)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
506+
@test gradient(x -> x[i1].x^2, x2)[1] == ((x = 6.0,),)
507+
@test gradient(x -> x[i1].x^2, x3)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
508+
509+
@test gradient(x -> x[1][1].x^2, [x1])[1] == [((x = 6.0,),)] # fails on v0.6.0 v0.6.41
510+
@test gradient(x -> x[1][1].x^2, [x2])[1] == [((x = 6.0,),)]
511+
@test gradient(x -> x[1][1].x^2, [x3])[1] == [((x = 6.0,),)] # fails on v0.6.0 v0.6.41
512+
513+
# When `getfield` returns a mutable struct, it gave `nothing`:
514+
x4 = Imm(Mut(4.0))
515+
x5 = Mut(Mut(4.0))
516+
x6 = Imm(Imm(4.0))
517+
@test gradient(x -> x.x.x^3, x4)[1] == (x = (x = 48.0,),) # fails on v0.6.0 v0.6.41
518+
@test gradient(x -> x.x.x^3, x5)[1] == (x = (x = 48.0,),) # fails on v0.6.0
519+
@test gradient(x -> x.x.x^3, x6)[1] == (x = (x = 48.0,),) # fails on v0.6.41
520+
521+
@test gradient(x -> x[2].x.x^3, [x4, x4])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.0 v0.6.41
522+
@test gradient(x -> x[2].x.x^3, [x4, x5])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.0
523+
@test gradient(x -> x[2].x.x^3, [x4, x6])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.41
524+
525+
# Check when using implicit parameters, Params cases used to pass:
526+
y1 = [3.0]
527+
y2 = (Mut(y1),)
528+
y3 = (Imm(y1),)
529+
@test gradient(x -> sum(x[1].x)^2, y2)[1] == ((x = [6.0],),) # fails on v0.6.0 v0.6.41
530+
@test gradient(() -> sum(y2[1].x)^2, Params([y1]))[y1] == [6.0]
531+
@test gradient(x -> sum(x[1].x)^2, y3)[1] == ((x = [6.0],),)
532+
@test gradient(() -> sum(y3[1].x)^2, Params([y1]))[y1] == [6.0]
533+
534+
@test gradient(x -> sum(x[1].x .+ x[1].x)^3, y2)[1] == ((x = [216.0],),) # fails on v0.6.0 v0.6.41
535+
@test gradient(() -> sum(y2[1].x .+ y2[1].x)^3, Params([y1]))[y1] == [216.0]
536+
@test gradient(x -> sum(x[1].x .+ x[1].x)^3, y3)[1] == ((x = [216.0],),)
537+
@test gradient(() -> sum(y3[1].x .+ y3[1].x)^3, Params([y1]))[y1] == [216.0]
538+
539+
i1 = 1
540+
@test gradient(x -> sum(x[i1].x .+ x[1].x)^3, y2)[1] == ((x = [216.0],),) # fails on v0.6.0 v0.6.41
541+
@test gradient(() -> sum(y2[i1].x .+ y2[1].x)^3, Params([y1]))[y1] == [216.0]
542+
@test gradient(x -> sum(x[i1].x .+ x[1].x)^3, y3)[1] == ((x = [216.0],),)
543+
@test gradient(() -> sum(y3[i1].x .+ y3[1].x)^3, Params([y1]))[y1] == [216.0]
544+
end
545+
493546
@testset "NamedTuples" begin
494547
@test gradient(x -> x.a, (a=1, b=2)) == ((a = 1, b = nothing),)
495548
@test gradient(x -> x[1].a, [(a=1, b=2)]) == ([(a = 1, b = nothing)],)
@@ -517,7 +570,7 @@ end
517570
@test (x->10*(x => 2)[2])'(100) === nothing
518571

519572
@test gradient(x-> (:x => x)[2], 17) == (1,)
520-
573+
521574
d = Dict(:x=>1.0, :y=>3.0);
522575
@test gradient(d -> Dict(:x => d[:x])[:x], d) == (Dict(:x => 1),)
523576
end
@@ -546,7 +599,7 @@ end
546599
# zip
547600
if VERSION >= v"1.5"
548601
# On Julia 1.4 and earlier, [x/y for (x,y) in zip(10:14, 1:10)] is a DimensionMismatch,
549-
# while on 1.5 - 1.7 it stops early.
602+
# while on 1.5 - 1.7 it stops early.
550603

551604
@test gradient(10:14, 1:10) do xs, ys
552605
sum([x/y for (x,y) in zip(xs, ys)])
@@ -608,7 +661,7 @@ end
608661

609662
# Iterators.Product with enumerate
610663
@test gradient([2 3; 4 5]) do xs
611-
sum([x^i+y for (i,x) in enumerate(xs), y in xs])
664+
sum([x^i+y for (i,x) in enumerate(xs), y in xs])
612665
end == ([8 112; 36 2004],)
613666
end
614667

0 commit comments

Comments
 (0)