
Commit 2a8f9ea

Enforces 4D mask/bias; fixes dbias broadcasting
Standardizes the attention aux inputs (mask, bias, dbias) to strict 4D shapes with a contiguous last dimension and explicit broadcasting over batch, heads, and seqlen_q. Removes the 3D mask/bias handling and validates the key-length dimension against the rounded key length. Allocates and validates dbias with broadcast-aware shapes and updates the reductions to sum correctly over group, batch, and seqlen_q when those dimensions are broadcast, improving correctness for MQA/GQA and padded key lengths. Tightens the shape checks and internal consistency to prevent silent misalignment and shape-induced bugs.
1 parent 59c7f39 commit 2a8f9ea
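
For callers migrating off the removed 3D mask/bias path, the sketch below shows one way to build an aux tensor that satisfies the new 4D contract (broadcastable batch/head/seqlen_q dimensions, contiguous last dimension, key length padded to the rounded value). It is a minimal, hypothetical helper written against the ATen C++ API; the name prepare_bias_4d and the fixed 128-element rounding are assumptions taken from the checks in this commit, not an exported API.

#include <ATen/ATen.h>

// Hypothetical helper: turn a (batch, heads, seqlen_k) bias into the 4D layout
// the updated mha_bwd checks expect. Not part of flash_api.cpp.
at::Tensor prepare_bias_4d(const at::Tensor &bias3d, int64_t seqlen_k) {
    // Assumption: key length is rounded up to a multiple of 128, matching
    // seqlen_k_rounded = round_multiple(seqlen_k, 128) in the diff below.
    const int64_t seqlen_k_rounded = (seqlen_k + 127) / 128 * 128;
    // Add a singleton seqlen_q dimension so the tensor broadcasts over queries:
    // (batch, heads, seqlen_k) -> (batch, heads, 1, seqlen_k).
    at::Tensor bias4d = bias3d.unsqueeze(2);
    // Zero-pad the key dimension to the rounded length and force a contiguous
    // last dimension, satisfying the stride(-1) == 1 check.
    bias4d = at::constant_pad_nd(bias4d, {0, seqlen_k_rounded - seqlen_k}, 0).contiguous();
    return bias4d;
}

A boolean mask would be prepared the same way, padding with false so the extra key columns stay masked out; the zero fill for padded bias columns is likewise an assumption, since the kernel's treatment of the padded region is not shown on this page.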

csrc/flash_dmattn/flash_api.cpp

Lines changed: 40 additions & 57 deletions
@@ -795,14 +795,14 @@ mha_bwd(
         const at::Tensor &q,                       // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,                       // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v,                       // batch_size x seqlen_k x num_heads_k x head_size
-        const std::optional<at::Tensor> &mask_,    // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k
-        const std::optional<at::Tensor> &bias_,    // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k
+        const std::optional<at::Tensor> &mask_,    // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1}
+        const std::optional<at::Tensor> &bias_,    // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1}
         const at::Tensor &out,                     // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &softmax_lse,             // b x h x seqlen_q
         std::optional<at::Tensor> &dq_,            // batch_size x seqlen_q x num_heads x head_size
         std::optional<at::Tensor> &dk_,            // batch_size x seqlen_k x num_heads_k x head_size
         std::optional<at::Tensor> &dv_,            // batch_size x seqlen_k x num_heads_k x head_size
-        std::optional<at::Tensor> &dbias_,         // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k
+        std::optional<at::Tensor> &dbias_,         // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1}
         const float softmax_scale,
         const bool is_causal,
         const float softcap,
@@ -845,11 +845,8 @@ mha_bwd(
         mask = mask_.value();
         TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
         CHECK_DEVICE(mask);
+        TORCH_CHECK(mask.dim() == 4, "mask must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k)");
         TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
-        if (mask.dim() == 3) {
-            // Add a dummy dimension for seqlen_q
-            mask = mask.unsqueeze(2).expand({-1, -1, q.size(1), -1});
-        }
     } else {
         mask = torch::empty({0}, opts);
     }
@@ -859,11 +856,8 @@ mha_bwd(
         bias = bias_.value();
         TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs");
         CHECK_DEVICE(bias);
+        TORCH_CHECK(bias.dim() == 4, "bias must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k)");
         TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension");
-        if (bias.dim() == 3) {
-            // Add a dummy dimension for seqlen_q
-            bias = bias.unsqueeze(2).expand({-1, -1, q.size(1), -1});
-        }
     } else {
         bias = torch::empty({0}, opts);
     }
@@ -878,29 +872,39 @@ mha_bwd(
     const int num_heads_k = k.size(2);
     int num_heads_mask = has_mask ? mask.size(1) : 1;
     int num_heads_bias = has_bias ? bias.size(1) : 1;
+    int batch_size_mask = has_mask ? mask.size(0) : batch_size;
+    int batch_size_bias = has_bias ? bias.size(0) : batch_size;
+    int seqlen_q_mask = has_mask ? mask.size(2) : seqlen_q;
+    int seqlen_q_bias = has_bias ? bias.size(2) : seqlen_q;
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
+    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
     if (has_mask) {
+        TORCH_CHECK(mask.size(0) == batch_size || mask.size(0) == 1, "Batch dimension in mask must be 1 or equal to batch size");
         TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h");
+        TORCH_CHECK(mask.size(2) == 1 || mask.size(2) == seqlen_q, "Query length dimension in mask must be 1 or equal to seqlen_q");
+        TORCH_CHECK(mask.size(3) == seqlen_k_rounded, "Key length dimension in mask must be seqlen_k_rounded");
     }
     if (has_bias) {
+        TORCH_CHECK(bias.size(0) == batch_size || bias.size(0) == 1, "Batch dimension in bias must be 1 or equal to batch size");
         TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h");
+        TORCH_CHECK(bias.size(2) == 1 || bias.size(2) == seqlen_q, "Query length dimension in bias must be 1 or equal to seqlen_q");
+        TORCH_CHECK(bias.size(3) == seqlen_k_rounded, "Key length dimension in bias must be seqlen_k_rounded");
     }

-    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
-    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
-    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
-
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
     CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
-
+
     at::Tensor dq, dk, dv, dbias;
     if (dq_.has_value()) {
         dq = dq_.value();
@@ -934,30 +938,14 @@ mha_bwd(
             dbias = dbias_.value();
             TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q");
             CHECK_DEVICE(dbias);
+            TORCH_CHECK(dbias.dim() == 4, "dbias must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k)");
             TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension");
-            if (dbias.dim() == 4) {
-                CHECK_SHAPE(dbias, batch_size, num_heads_bias, seqlen_q, seqlen_k);
-            } else {
-                CHECK_SHAPE(dbias, batch_size, num_heads_bias, seqlen_k);
-            }
+            TORCH_CHECK(dbias.size(0) == batch_size || dbias.size(0) == 1, "Batch dimension in dbias must be 1 or equal to batch size");
+            TORCH_CHECK(dbias.size(1) == num_heads || dbias.size(1) == num_heads_k || dbias.size(1) == 1, "Number of heads in dbias must be 1, h_k or h");
+            TORCH_CHECK(dbias.size(2) == seqlen_q || dbias.size(2) == 1, "Query length dimension in dbias must be 1 or equal to seqlen_q");
+            TORCH_CHECK(dbias.size(3) == seqlen_k_rounded, "Key length dimension in dbias must be seqlen_k_rounded");
         } else {
-            if (bias.dim() == 4) {
-                if (num_heads_bias == 1) {
-                    dbias = torch::empty({batch_size, 1, seqlen_q, seqlen_k}, opts);
-                } else if (num_heads_bias == num_heads_k) {
-                    dbias = torch::empty({batch_size, num_heads_k, seqlen_q, seqlen_k}, opts);
-                } else {
-                    dbias = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts);
-                }
-            } else {
-                if (num_heads_bias == 1) {
-                    dbias = torch::empty({batch_size, 1, seqlen_k}, opts);
-                } else if (num_heads_bias == num_heads_k) {
-                    dbias = torch::empty({batch_size, num_heads_k, seqlen_k}, opts);
-                } else {
-                    dbias = torch::empty({batch_size, num_heads, seqlen_k}, opts);
-                }
-            }
+            dbias = torch::empty({batch_size_bias, num_heads_bias, seqlen_q_bias, seqlen_k_rounded}, opts);
         }
     } else {
         dbias = torch::empty({0}, opts);
@@ -990,8 +978,8 @@ mha_bwd(
         : dv;
     dbias_expanded = has_bias
         ? (
-            (num_heads_bias != num_heads) || (bias_.has_value() && bias_.value().dim() == 3)    // MQA / GQA or bias has no seqlen_q dimension
-                ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts)
+            (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q)    // MQA / GQA or dbias has different batch size or seqlen_q
+                ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
                 : dbias
         )
         : torch::empty({0}, opts);
@@ -1046,24 +1034,19 @@ mha_bwd(
         at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
         at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
     }
-    // For MQA/GQA or num_heads_bias != num_heads, we also need to sum dbias across the heads
+    // For MQA/GQA or dbias has different batch size or seqlen_q, we need to sum dbias across the groups, batch and seqlen_q
     if (has_bias) {
-        bool sum_seqlen_q = bias_.has_value() && bias_.value().dim() == 3;
-        if (num_heads_bias != num_heads) {
-            if (sum_seqlen_q) {
-                dbias_expanded = at::sum(
-                    at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2}
-                );
-            } else {
-                at::sum_out(
-                    dbias,
-                    at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2}
-                );
+        if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) {
+            at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
+        } else {
+            dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
+            if (seqlen_q_bias == 1) {
+                dbias_expanded = at::sum(dbias_expanded, {2}, true);
             }
-        }
-        if (sum_seqlen_q) {
-            // We need to sum across the seqlen_q dimension
-            at::sum_out(dbias, dbias_expanded, {2});
+            if (batch_size_bias == 1) {
+                dbias_expanded = at::sum(dbias_expanded, {0}, true);
+            }
+            dbias.copy_(dbias_expanded);
         }
     }

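
Since the new reduction logic is the densest part of the diff, here is the same computation restated as a standalone helper, purely as an illustrative sketch: reduce_dbias is not a function in this repository, and it assumes dbias_expanded already holds the full (batch, num_heads, seqlen_q, seqlen_k_rounded) gradient produced by the kernel.

#include <ATen/ATen.h>

// Illustrative restatement of the broadcast-aware dbias reduction above.
// Not part of flash_api.cpp; names mirror the variables in the diff.
at::Tensor reduce_dbias(
    const at::Tensor &dbias_expanded,   // (batch, num_heads, seqlen_q, seqlen_k_rounded)
    int64_t num_heads_bias,             // 1, num_heads_k, or num_heads
    int64_t batch_size_bias,            // 1 or batch_size
    int64_t seqlen_q_bias               // 1 or seqlen_q
) {
    const int64_t batch_size = dbias_expanded.size(0);
    const int64_t num_heads  = dbias_expanded.size(1);
    const int64_t seqlen_q   = dbias_expanded.size(2);
    const int64_t seqlen_k_rounded = dbias_expanded.size(3);
    // Fold the query heads into (num_heads_bias, group) and sum out the group
    // axis, exactly as the at::sum(at::reshape(...), {2}) call in the diff.
    at::Tensor d = at::sum(
        at::reshape(dbias_expanded,
                    {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}),
        {2});
    // Dimensions that the bias broadcast over receive the sum of their
    // gradients, kept at size 1 so the result matches the stored dbias shape.
    if (seqlen_q_bias == 1)   { d = at::sum(d, {2}, /*keepdim=*/true); }
    if (batch_size_bias == 1) { d = at::sum(d, {0}, /*keepdim=*/true); }
    return d;
}

The fast path in the diff (at::sum_out straight into dbias when only the head count differs) is an in-place specialization of the same reduction; the general branch materializes the sum and then copies it into dbias.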
