
Commit 55f88eb

[MLU] fix rnn with set projection zero (#855) (#859)

Authored by ShawnNew and PeiyuLau
Co-authored-by: PeiyuLau <135964669+PeiyuLau@users.noreply.github.com>

1 parent eedb71e

File tree

1 file changed: +2 −2 lines changed

backends/mlu/kernels/rnn_kernel.cc

Lines changed: 2 additions & 2 deletions
@@ -81,7 +81,7 @@ void RnnKernel(const Context& dev_ctx,
   int in_dim_arr[in_out_dim_num] = {seq_len, batch_size, input_dim};
   int out_dim_arr[in_out_dim_num] = {
       seq_len, batch_size, direction_num * hidden_size};
-  int proj_size = hidden_size;
+  int proj_size = 0;
 
   std::vector<int> seq_len_vec(batch_size, seq_len);
   if (sequence_length.is_initialized()) {  // set seq_len if no padding,
@@ -380,7 +380,7 @@ void RnnGradKernel(const Context& dev_ctx,
   int in_dim_arr[in_out_dim_num] = {seq_len, batch_size, input_dim};
   int out_dim_arr[in_out_dim_num] = {
       seq_len, batch_size, direction_num * hidden_size};
-  int proj_size = hidden_size;
+  int proj_size = 0;
   PADDLE_ENFORCE_EQ(
       num_layers,
       1,

0 commit comments
