@@ -68,14 +68,13 @@ julia> gradient([7, 11], 0, 1) do x, y, d
6868 p = size(x, d)
6969 sum(x.^p .+ y)
7070 end
71- ([14.0, 22.0], 2, nothing)
71+ ([14.0, 22.0], 2.0 , nothing)
7272```
7373"""
7474function gradient (f, args... )
7575 y, back = pullback (f, args... )
7676 grad = back (sensitivity (y))
77- isnothing (grad) && return nothing
78- map (_project, args, grad)
77+ isnothing (grad) ? nothing : map (_project, args, grad)
7978end
8079
8180Base. adjoint (f:: Function ) = x -> gradient (f, x)[1 ]
9897function withgradient (f, args... )
9998 y, back = pullback (f, args... )
10099 grad = back (sensitivity (y))
101- isnothing (grad) && return (val = y, grad= nothing )
102- (val = y, grad = map (_project, args, grad) )
100+ results = isnothing (grad) ? map (_ -> nothing , args) : map (_project, args, grad)
101+ (val= y, grad= results )
103102end
104103
105104# Param-style wrappers
@@ -138,23 +137,17 @@ julia> g = gradient(Params([x, y])) do
138137Grads(...)
139138
140139julia> g[x]
141- 2×3 Matrix{Int64 }:
142- 7 70 700
143- 8 80 800
140+ 2×3 Matrix{Float64 }:
141+ 7.0 70.0 700.0
142+ 8.0 80.0 800.0
144143
145144julia> haskey(g, z) # only x and y are parameters
146145false
147146```
148147"""
149- function gradient (f, ps:: Params )
150- y, back = pullback (f, ps)
151- back (sensitivity (y))
152- end
148+ gradient
153149
154- function withgradient (f, ps:: Params )
155- y, back = pullback (f, ps)
156- (val = y, grad = back (sensitivity (y)))
157- end
150+ Base. map (:: typeof (_project), args:: Tuple{Params} , grad) = grad # skip _project in gradient(f, ::Params)
158151
159152function Base. union! (ps:: Params , itrs... )
160153 foreach (itr -> foreach (x -> push! (ps, x), itr), itrs)
0 commit comments