We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9406930 commit 1a40446Copy full SHA for 1a40446
fla/layers/kda.py
@@ -172,7 +172,7 @@ def forward(
172
173
batch_size, q_len, _ = hidden_states.shape
174
# change to inference mode.
175
- mode = 'fused_recurrent' if q_len <= 64 else self.mode
+ mode = 'fused_recurrent' if q_len <= 64 and not self.training else self.mode
176
if self.training:
177
assert mode == 'chunk', "Only chunk mode is supported in training."
178
0 commit comments