Skip to content

Commit 058940d

Browse files
Merge pull request #1670 from AayushSabharwal/as/fix-getindex-der
fix: fix derivatives of indexed array expressions
2 parents 9b0df3d + 8c2ddc8 commit 058940d

File tree

2 files changed

+45
-5
lines changed

2 files changed

+45
-5
lines changed

src/diff.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,16 +263,25 @@ function executediff(D, arg, simplify=false; throw_no_derivative=false)
263263
end
264264
end
265265
elseif op === getindex
266+
arr = arguments(arg)[1]
266267
inner_args = arguments(arguments(arg)[1])
268+
idxs = @views arguments(arg)[2:end]
267269
c = 0
268-
for a in inner_args
270+
# We know `D.x` is in `arg`, so the derivative is not identically zero.
271+
# `arg` cannot be `D.x` since, that would have also early exited.
272+
for (i, a) in enumerate(inner_args)
273+
der = derivative_idx(arr, i)
269274
if isequal(a, D.x)
270-
return D(arg)
275+
der isa NoDeriv && return D(arg)
276+
c += der[idxs...]
277+
continue
278+
elseif der isa NoDeriv
279+
c += Differential(a)(arg) * executediff(D, a)
271280
else
272-
c += Differential(a)(arg) * D(a)
281+
c += der[idxs...] * executediff(D, a)
273282
end
274283
end
275-
return expand_derivatives(c)
284+
return c
276285
elseif op === ifelse
277286
args = arguments(arg)
278287
O = op(args[1],

test/diff.jl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,4 +634,35 @@ end
634634
@test iszero(Symbolics.unwrap.(Symbolics.gradient(f, vp) .- Symbolics.gradient(f, p)))
635635
@test iszero(Symbolics.unwrap.(Symbolics.hessian(f, vp) .- Symbolics.hessian(f, p)))
636636
@test iszero(Symbolics.unwrap.(Symbolics.jacobian([f], vp) .- Symbolics.jacobian([f], p)))
637-
end
637+
end
638+
639+
@testset "Derivatives of indexed array expressions" begin
640+
@register_array_symbolic foobar(x, y) begin
641+
size = (2,)
642+
eltype = Real
643+
ndims = 1
644+
end
645+
@register_array_symbolic dfoobar1(x, y) begin
646+
size = (2,)
647+
eltype = Real
648+
ndims = 1
649+
end
650+
@register_array_symbolic dfoobar2(x, y) begin
651+
size = (2,)
652+
eltype = Real
653+
ndims = 1
654+
end
655+
Symbolics.derivative(::typeof(foobar), args::NTuple{2, Any}, ::Val{1}) = dfoobar1(args...)
656+
Symbolics.derivative(::typeof(foobar), args::NTuple{2, Any}, ::Val{2}) = dfoobar2(args...)
657+
658+
@variables x y
659+
ex = foobar(x + 2y, y)
660+
@test ex isa Symbolics.Arr
661+
@test size(ex) == (2,)
662+
663+
der = Symbolics.derivative(ex[1], x)
664+
@test isequal(der, dfoobar1(x + 2y, y)[1])
665+
666+
der = Symbolics.derivative(ex[1], y)
667+
@test isequal(der, dfoobar1(x + 2y, y)[1] * 2 + dfoobar2(x + 2y, y)[1])
668+
end

0 commit comments

Comments
 (0)