Skip to content

Commit fb34703

Browse files
committed
Handle nothing grads for Pairs.data
1 parent de078c8 commit fb34703

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

src/lib/base.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ function _pullback(cx::AContext, ::typeof(literal_getindex),
162162
ps::Iterators.Pairs{<:Any,<:Any,<:Any,<:NamedTuple}, ::Val{K}) where K
163163
val, gf_back = _pullback(cx, literal_getfield, NamedTuple(ps), Val(K))
164164
function kwargs_literal_getindex_pullback(Δ)
165-
dps = (data = gf_back(Δ)[2], itr = nothing)
165+
dps = (data = gradindex(gf_back(Δ), 2), itr = nothing)
166166
return (nothing, dps, nothing)
167167
end
168168
return val, kwargs_literal_getindex_pullback

test/features.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,10 @@ end
591591
h(somedata) = g(; somedata...)
592592
@test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = nothing, z = 3.0),)
593593
@test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = nothing, z = 3.0, x = 2.3),)
594+
595+
# for when no kwargs have grads backpropogated
596+
no_kwarg_grad(x; kwargs...) = x[kwargs[:i]]
597+
@test gradient(x -> no_kwarg_grad(x; i=1), [1]) == (1,)
594598
end
595599

596600
@testset "Iterators" begin

0 commit comments

Comments
 (0)