@@ -68,14 +68,13 @@ julia> gradient([7, 11], 0, 1) do x, y, d
6868 p = size(x, d)
6969 sum(x.^p .+ y)
7070 end
71- ([14.0, 22.0], 2, nothing)
71+ ([14.0, 22.0], 2.0 , nothing)
7272```
7373"""
7474function gradient (f, args... )
7575 y, back = pullback (f, args... )
7676 grad = back (sensitivity (y))
77- isnothing (grad) && return nothing
78- map (_project, args, grad)
77+ isnothing (grad) ? nothing : map (_project, args, grad)
7978end
8079
8180Base. adjoint (f:: Function ) = x -> gradient (f, x)[1 ]
9897function withgradient (f, args... )
9998 y, back = pullback (f, args... )
10099 grad = back (sensitivity (y))
101- isnothing (grad) && return (val = y, grad= nothing )
102- (val = y, grad = map (_project, args, grad) )
100+ results = isnothing (grad) ? map (_ -> nothing , args) : map (_project, args, grad)
101+ (val= y, grad= results )
103102end
104103
105104# Param-style wrappers
106105
107- """
108- Params([A, B])
109-
110- Container for implicit parameters, used when differentiating
111- a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`.
112- """
113- struct Params
114- order:: Buffer # {Any, Vector{Any}}
115- params:: IdSet{Any} # TODO store ids only
116- end
117-
118- Params () = Params (Buffer ([], false ), IdSet ())
119- Params (xs) = Params (Buffer (xs, false ), IdSet (xs))
120- Params (ps:: Params ) = ps
121- Params (xs:: Tuple ) = Params (collect (xs))
122-
123- @forward Params. order Base. iterate, Base. length, Base. getindex
124- @forward Params. params Base. in
125-
126106"""
127107 gradient(() -> loss(), ps::Params) -> Grads
128108
@@ -138,24 +118,37 @@ julia> g = gradient(Params([x, y])) do
138118Grads(...)
139119
140120julia> g[x]
141- 2×3 Matrix{Int64 }:
142- 7 70 700
143- 8 80 800
121+ 2×3 Matrix{Float64 }:
122+ 7.0 70.0 700.0
123+ 8.0 80.0 800.0
144124
145125julia> haskey(g, z) # only x and y are parameters
146126false
147127```
148128"""
149- function gradient (f, ps:: Params )
150- y, back = pullback (f, ps)
151- back (sensitivity (y))
152- end
129+ gradient
153130
154- function withgradient (f, ps:: Params )
155- y, back = pullback (f, ps)
156- (val = y, grad = back (sensitivity (y)))
131+ """
132+ Params([A, B])
133+
134+ Container for implicit parameters, used when differentiating
135+ a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`.
136+ """
137+ struct Params
138+ order:: Buffer # {Any, Vector{Any}}
139+ params:: IdSet{Any} # TODO store ids only
157140end
158141
142+ Params () = Params (Buffer ([], false ), IdSet ())
143+ Params (xs) = Params (Buffer (xs, false ), IdSet (xs))
144+ Params (ps:: Params ) = ps
145+ Params (xs:: Tuple ) = Params (collect (xs))
146+
147+ @forward Params. order Base. iterate, Base. length, Base. getindex
148+ @forward Params. params Base. in
149+
150+ Base. map (:: typeof (_project), args:: Tuple{Params} , grad) = grad # skip _project in gradient(f, ::Params)
151+
159152function Base. union! (ps:: Params , itrs... )
160153 foreach (itr -> foreach (x -> push! (ps, x), itr), itrs)
161154 return ps
0 commit comments