@ssam18 ssam18 commented Nov 12, 2025

Fixes #6390

Problem

When use_fp8=True is enabled in HybridParallelPlugin and the model has output layers whose dimensions are not divisible by 16 (e.g., binary classification with 2 outputs), training fails with:

Expected both dimensions of mat2 to be divisible by 16 but got torch.Size([768, 2])
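For reference, a rough reproduction of the failing configuration (the model name and plugin arguments here are illustrative, and the usual colossalai launch/boost boilerplate is elided):

```python
from transformers import GPT2ForSequenceClassification
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Binary classifier: the score head is a 768 -> 2 projection.
model = GPT2ForSequenceClassification.from_pretrained("gpt2", num_labels=2)

# use_fp8=True routes eligible linear layers through torch._scaled_mm.
plugin = HybridParallelPlugin(tp_size=1, pp_size=1, use_fp8=True)
booster = Booster(plugin=plugin)
# ... boost the model/optimizer and run a training step; the 2-wide score
# layer triggers the divisibility error above before this fix.
```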

Root Cause

torch._scaled_mm requires both dimensions of its mat2 operand (the weight matrix in the matmul) to be divisible by 16. The existing check in linear_fp8() only validated:

  • Input dimension (input.shape[-1])
  • Batch dimensions (np.prod(input.shape[:-1]))

But it did not check the output dimension (weight.shape[0]).

When using GPT2ForSequenceClassification with num_labels=2, the score layer is a 768 → 2 projection: its weight has shape [2, 768] (out_features=2), so the mat2 passed to torch._scaled_mm has shape [768, 2], and 2 is not divisible by 16.
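To make the shape mismatch concrete (shapes taken from the standard GPT2-small config):

```python
import torch

# GPT2ForSequenceClassification's score head: nn.Linear(768, 2, bias=False)
score = torch.nn.Linear(768, 2, bias=False)
weight = score.weight        # shape [2, 768] -> out_features = 2
mat2 = weight.t()            # what torch._scaled_mm sees: [768, 2]

print(weight.shape[0] % 16)  # 2 -> not divisible by 16
# The old check only looked at the input-side dimensions, so this layer
# slipped through to torch._scaled_mm and raised the error quoted above.
```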

Solution

Added a check for weight.shape[0] % 16 != 0 to fall back to regular F.linear when the output dimension is not compatible with FP8:

if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0 or weight.shape[0] % 16 != 0:
    return F.linear(input, weight, bias)
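For context, a minimal sketch of where this check sits inside linear_fp8 (the real implementation lives in ColossalAI's FP8 utilities; the FP8 branch below is abbreviated and its details are assumed, not copied from the codebase):

```python
import numpy as np
import torch
import torch.nn.functional as F

def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
    # torch._scaled_mm needs every matmul dimension divisible by 16,
    # so bail out to the standard kernel for incompatible shapes.
    if (
        input.shape[-1] % 16 != 0               # input feature dim
        or np.prod(input.shape[:-1]) % 16 != 0  # flattened batch dims
        or weight.shape[0] % 16 != 0            # output dim (the new check)
    ):
        return F.linear(input, weight, bias)
    # ... FP8 path: cast input/weight to float8, call torch._scaled_mm,
    # then add bias and return in the original dtype.
    ...
```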

Testing

This fix allows the model to:

  • Use FP8 for layers with compatible dimensions (performance benefit)
  • Fallback to standard FP16/BF16 for incompatible layers (correctness)
  • Run successfully with small output dimensions (e.g., binary classification)

The change is backward compatible and doesn't affect existing working configurations.
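A small smoke test along these lines exercises the fallback (the import path and signature of linear_fp8 are assumptions based on the snippet above; adjust to the actual module):

```python
import torch
import torch.nn.functional as F
from colossalai.quantization.fp8 import linear_fp8  # assumed import path

def test_small_output_dim_falls_back():
    x = torch.randn(16, 768, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(2, 768, dtype=torch.bfloat16, device="cuda")  # out_features=2
    out = linear_fp8(x, w)  # should take the F.linear fallback, not torch._scaled_mm
    assert out.shape == (16, 2)
    assert torch.allclose(out, F.linear(x, w), atol=1e-2, rtol=1e-2)
```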
