Skip to content

Commit 69c0e2a

Browse files
gdallewsmoses
authored andcommitted
fix: correct return tuple for Enzyme's reverse autodiff
1 parent 09caa1f commit 69c0e2a

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

src/Enzyme.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -496,12 +496,11 @@ function overload_autodiff(
496496
func2.operation = MLIR.API.MlirOperation(C_NULL)
497497

498498
if reverse
499-
resv = if EnzymeCore.needs_primal(CMode)
500-
result
499+
if EnzymeCore.needs_primal(CMode)
500+
return ((restup...,), result)
501501
else
502-
nothing
502+
return ((restup...,),)
503503
end
504-
return ((restup...,), resv)
505504
else
506505
if EnzymeCore.needs_primal(CMode)
507506
if CMode <: ForwardMode && !(A <: Const)

test/autodiff.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,12 @@ end
407407
@test results_fd[2].y results_enz[2].y
408408
@test results_fd[3] results_enz[3]
409409
end
410+
411+
@testset "Correct return tuple" begin
412+
# issue 1875
413+
x = ones(2)
414+
xr = Reactant.to_rarray(x)
415+
res = autodiff(Reverse, sum, Duplicated(x, zero(x)))
416+
res_reactant = @jit autodiff(Reverse, sum, Duplicated(xr, zero(xr)))
417+
@test length(res) == length(res_reactant)
418+
end

0 commit comments

Comments
 (0)