Skip to content

Commit db55f88

Browse files
Fix docstring for 'g' parameter shape (#631)
Updated the shape of the 'g' parameter in the docstring to reflect the new dimensions.
1 parent 6eefe8e commit db55f88

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

fla/ops/kda/fused_recurrent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def fused_recurrent_kda(
2929
values of shape `[B, T, HV, V]`.
3030
GVA is applied if `HV > H`.
3131
g (torch.Tensor):
32-
g (decays) of shape `[B, T, HV]`.
32+
g (decays) of shape `[B, T, HV, K]`.
3333
beta (torch.Tensor):
3434
betas of shape `[B, T, HV]`.
3535
scale (Optional[float]):

0 commit comments

Comments
 (0)