Skip to content

Commit 1a40446

Browse files
authored
fix: don't force fused_recurrent when in training mode (#636)
1 parent 9406930 commit 1a40446

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

fla/layers/kda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def forward(
172172

173173
batch_size, q_len, _ = hidden_states.shape
174174
# change to inference mode.
175-
mode = 'fused_recurrent' if q_len <= 64 else self.mode
175+
mode = 'fused_recurrent' if q_len <= 64 and not self.training else self.mode
176176
if self.training:
177177
assert mode == 'chunk', "Only chunk mode is supported in training."
178178

0 commit comments

Comments
 (0)