diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index e23da5cccd4d..e02e56aa791f 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -840,7 +840,9 @@ def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.
 
 
 def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
+    # torch._scaled_mm requires both dimensions of matrices to be divisible by 16
+    # Check input dimensions and weight output dimension
+    if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0 or weight.shape[0] % 16 != 0:
         return F.linear(input, weight, bias)
     out = _linear_fp8(input, weight, bias)
     return out
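
For context (not part of the patch): torch._scaled_mm rejects matrices whose dimensions are not multiples of 16, and the old guard only inspected the input tensor, so a weight whose output dimension (weight.shape[0]) was not divisible by 16 still reached the FP8 path and raised a RuntimeError. A minimal sketch of the new fallback behavior, assuming a working ColossalAI install; the shapes are illustrative:

    import torch
    import torch.nn.functional as F

    from colossalai.quantization.fp8 import linear_fp8

    # out_features = 24 is not divisible by 16: before this patch the guard
    # passed (it only checked the input), so _linear_fp8 reached
    # torch._scaled_mm and raised; now linear_fp8 falls back to F.linear.
    x = torch.randn(16, 16)   # both input dimensions divisible by 16
    w = torch.randn(24, 16)   # weight.shape[0] == 24 -> triggers the fallback

    out = linear_fp8(x, w)
    # The fallback returns F.linear directly, so no FP8 quantization is applied.
    assert torch.equal(out, F.linear(x, w))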