|
102 | 102 |
|
103 | 103 | # Param-style wrappers |
104 | 104 |
|
| 105 | +""" |
| 106 | + Params([A, B]) |
| 107 | +
|
| 108 | +Container for implicit parameters, used when differentiating |
| 109 | +a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`. |
| 110 | +""" |
| 111 | +struct Params |
| 112 | + order::Buffer # {Any, Vector{Any}} |
| 113 | + params::IdSet{Any} # TODO store ids only |
| 114 | +end |
| 115 | + |
| 116 | +Params() = Params(Buffer([], false), IdSet()) |
| 117 | +Params(xs) = Params(Buffer(xs, false), IdSet(xs)) |
| 118 | +Params(ps::Params) = ps |
| 119 | +Params(xs::Tuple) = Params(collect(xs)) |
| 120 | + |
| 121 | +@forward Params.order Base.iterate, Base.length, Base.getindex |
| 122 | +@forward Params.params Base.in |
| 123 | + |
105 | 124 | """ |
106 | 125 | gradient(() -> loss(), ps::Params) -> Grads |
107 | 126 |
|
@@ -135,25 +154,6 @@ function withgradient(f, ps::Params) |
135 | 154 | (val = y, grad = back(sensitivity(y))) |
136 | 155 | end |
137 | 156 |
|
138 | | -""" |
139 | | - Params([A, B]) |
140 | | -
|
141 | | -Container for implicit parameters, used when differentiating |
142 | | -a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`. |
143 | | -""" |
144 | | -struct Params |
145 | | - order::Buffer # {Any, Vector{Any}} |
146 | | - params::IdSet{Any} # TODO store ids only |
147 | | -end |
148 | | - |
149 | | -Params() = Params(Buffer([], false), IdSet()) |
150 | | -Params(xs) = Params(Buffer(xs, false), IdSet(xs)) |
151 | | -Params(ps::Params) = ps |
152 | | -Params(xs::Tuple) = Params(collect(xs)) |
153 | | - |
154 | | -@forward Params.order Base.iterate, Base.length, Base.getindex |
155 | | -@forward Params.params Base.in |
156 | | - |
157 | 157 | function Base.union!(ps::Params, itrs...) |
158 | 158 | foreach(itr -> foreach(x -> push!(ps, x), itr), itrs) |
159 | 159 | return ps |
|
0 commit comments