@@ -189,18 +189,18 @@ end
189189
190190# Vanilla RNN
191191
192- struct RNNCell{F,A ,V,S}
192+ struct RNNCell{F,I,H ,V,S}
193193 σ:: F
194- Wi:: A
195- Wh:: A
194+ Wi:: I
195+ Wh:: H
196196 b:: V
197197 state0:: S
198198end
199199
200200RNNCell ((in, out):: Pair , σ= tanh; init= Flux. glorot_uniform, initb= zeros32, init_state= zeros32) =
201201 RNNCell (σ, init (out, in), init (out, out), initb (out), init_state (out,1 ))
202202
203- function (m:: RNNCell{F,A, V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {F,A ,V,T}
203+ function (m:: RNNCell{F,I,H, V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {F,I,H ,V,T}
204204 Wi, Wh, b = m. Wi, m. Wh, m. b
205205 σ = NNlib. fast_act (m. σ, x)
206206 h = σ .(Wi* x .+ Wh* h .+ b)
@@ -271,15 +271,27 @@ julia> r(rand(Float32, 3, 10)) |> size # batch size of 10
271271 julia> r(rand(Float32, 3)) |> size # erroneously outputs a length 5*10 = 50 vector.
272272 (50,)
273273 ```
274+
275+ # Note:
276+ `RNNCell`s can be constructed directly by specifying the non-linear function, the `W_i` and `W_h` internal matrices, a bias vector `b`, and a learnable initial state `state0`. The `W_i` and `W_h` matrices do not need to be the same type, but if `W_h` is `dxd`, then `W_i` should be of shape `dxN`.
277+
278+ ```julia
279+ julia> using LinearAlgebra
280+
281+ julia> r = Flux.Recur(Flux.RNNCell(tanh, rand(5, 4), Tridiagonal(rand(5, 5)), rand(5), rand(5, 1)))
282+
283+ julia> r(rand(4, 10)) |> size # batch size of 10
284+ (5, 10)
285+ ````
274286"""
275287RNN (a... ; ka... ) = Recur (RNNCell (a... ; ka... ))
276288Recur (m:: RNNCell ) = Recur (m, m. state0)
277289
278290# LSTM
279291
280- struct LSTMCell{A ,V,S}
281- Wi:: A
282- Wh:: A
292+ struct LSTMCell{I,H ,V,S}
293+ Wi:: I
294+ Wh:: H
283295 b:: V
284296 state0:: S
285297end
@@ -293,7 +305,7 @@ function LSTMCell((in, out)::Pair;
293305 return cell
294306end
295307
296- function (m:: LSTMCell{A, V,<:NTuple{2,AbstractMatrix{T}}} )((h, c), x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {A ,V,T}
308+ function (m:: LSTMCell{I,H, V,<:NTuple{2,AbstractMatrix{T}}} )((h, c), x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {I,H ,V,T}
297309 b, o = m. b, size (h, 1 )
298310 g = muladd (m. Wi, x, muladd (m. Wh, h, b))
299311 input, forget, cell, output = multigate (g, o, Val (4 ))
@@ -351,17 +363,17 @@ function _gru_output(gxs, ghs, bs)
351363 return r, z
352364end
353365
354- struct GRUCell{A ,V,S}
355- Wi:: A
356- Wh:: A
366+ struct GRUCell{I,H ,V,S}
367+ Wi:: I
368+ Wh:: H
357369 b:: V
358370 state0:: S
359371end
360372
361373GRUCell ((in, out):: Pair ; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
362374 GRUCell (init (out * 3 , in), init (out * 3 , out), initb (out * 3 ), init_state (out,1 ))
363375
364- function (m:: GRUCell{A, V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {A ,V,T}
376+ function (m:: GRUCell{I,H, V,<:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {I,H ,V,T}
365377 Wi, Wh, b, o = m. Wi, m. Wh, m. b, size (h, 1 )
366378 gxs, ghs, bs = multigate (Wi* x, o, Val (3 )), multigate (Wh* h, o, Val (3 )), multigate (b, o, Val (3 ))
367379 r, z = _gru_output (gxs, ghs, bs)
@@ -414,19 +426,19 @@ Recur(m::GRUCell) = Recur(m, m.state0)
414426
415427# GRU v3
416428
417- struct GRUv3Cell{A,V ,S}
418- Wi:: A
419- Wh:: A
429+ struct GRUv3Cell{I,H,V,HH ,S}
430+ Wi:: I
431+ Wh:: H
420432 b:: V
421- Wh_h̃:: A
433+ Wh_h̃:: HH
422434 state0:: S
423435end
424436
425437GRUv3Cell ((in, out):: Pair ; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
426438 GRUv3Cell (init (out * 3 , in), init (out * 2 , out), initb (out * 3 ),
427439 init (out, out), init_state (out,1 ))
428440
429- function (m:: GRUv3Cell{A,V, <:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {A,V ,T}
441+ function (m:: GRUv3Cell{I,H,V,HH, <:AbstractMatrix{T}} )(h, x:: Union{AbstractVecOrMat{T},OneHotArray} ) where {I,H,V,HH ,T}
430442 Wi, Wh, b, Wh_h̃, o = m. Wi, m. Wh, m. b, m. Wh_h̃, size (h, 1 )
431443 gxs, ghs, bs = multigate (Wi* x, o, Val (3 )), multigate (Wh* h, o, Val (2 )), multigate (b, o, Val (3 ))
432444 r, z = _gru_output (gxs, ghs, bs)
0 commit comments