@@ -795,14 +795,14 @@ mha_bwd(
         const at::Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,  // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v,  // batch_size x seqlen_k x num_heads_k x head_size
-        const std::optional<at::Tensor> &mask_,  // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k
-        const std::optional<at::Tensor> &bias_,  // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k
+        const std::optional<at::Tensor> &mask_,  // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1}
+        const std::optional<at::Tensor> &bias_,  // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1}
         const at::Tensor &out,  // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &softmax_lse,  // b x h x seqlen_q
         std::optional<at::Tensor> &dq_,  // batch_size x seqlen_q x num_heads x head_size
         std::optional<at::Tensor> &dk_,  // batch_size x seqlen_k x num_heads_k x head_size
         std::optional<at::Tensor> &dv_,  // batch_size x seqlen_k x num_heads_k x head_size
-        std::optional<at::Tensor> &dbias_,  // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k
+        std::optional<at::Tensor> &dbias_,  // {batch_size|1} x {num_heads|num_heads_k|1} x {seqlen_q|1} x {seqlen_k|1}
         const float softmax_scale,
         const bool is_causal,
         const float softcap,
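Note: the signature change above moves mask_, bias_ and dbias_ from a fixed batch_size x heads x seqlen layout to a broadcastable 4-D layout, where the batch, head and query-length dimensions may each be 1 (and the head dimension may also be num_heads_k). A rough caller-side sketch of shapes this convention describes; the tensor names and option settings here are illustrative and not part of this diff, and the shape checks further down additionally require the key-length dimension to be padded to seqlen_k rounded up to a multiple of 128 (see the padding sketch after that hunk):

// Sketch only: broadcastable layouts described by the new shape comments.
auto opts_f = torch::dtype(torch::kFloat16).device(torch::kCUDA);
auto opts_b = torch::dtype(torch::kBool).device(torch::kCUDA);
// Bias shared across the batch and across query positions, one row per KV head:
at::Tensor bias = torch::zeros({1, num_heads_k, 1, seqlen_k}, opts_f);
// Full per-element boolean mask:
at::Tensor mask = torch::ones({batch_size, num_heads, seqlen_q, seqlen_k}, opts_b);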
@@ -845,11 +845,8 @@ mha_bwd(
         mask = mask_.value();
         TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool");
         CHECK_DEVICE(mask);
+        TORCH_CHECK(mask.dim() == 4, "mask must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k)");
         TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
-        if (mask.dim() == 3) {
-            // Add a dummy dimension for seqlen_q
-            mask = mask.unsqueeze(2).expand({-1, -1, q.size(1), -1});
-        }
     } else {
         mask = torch::empty({0}, opts);
     }
@@ -859,11 +856,8 @@ mha_bwd(
         bias = bias_.value();
         TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs");
         CHECK_DEVICE(bias);
+        TORCH_CHECK(bias.dim() == 4, "bias must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k)");
         TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension");
-        if (bias.dim() == 3) {
-            // Add a dummy dimension for seqlen_q
-            bias = bias.unsqueeze(2).expand({-1, -1, q.size(1), -1});
-        }
     } else {
         bias = torch::empty({0}, opts);
     }
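With the 3-D fallback (unsqueeze(2) + expand) removed, a caller that used to pass a (batch, heads, seqlen_k) mask or bias now has to supply the query-length dimension itself. A minimal sketch of that caller-side adaptation, mirroring the removed lines; since the new checks accept a size-1 seqlen_q dimension, the explicit expand over q.size(1) is no longer needed:

// Sketch: adapt legacy 3-D mask/bias tensors to the 4-D layout mha_bwd now requires.
if (mask.defined() && mask.dim() == 3) {
    mask = mask.unsqueeze(2);  // (b, h, seqlen_k) -> (b, h, 1, seqlen_k)
}
if (bias.defined() && bias.dim() == 3) {
    bias = bias.unsqueeze(2);
}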
@@ -878,29 +872,39 @@ mha_bwd(
     const int num_heads_k = k.size(2);
     int num_heads_mask = has_mask ? mask.size(1) : 1;
     int num_heads_bias = has_bias ? bias.size(1) : 1;
+    int batch_size_mask = has_mask ? mask.size(0) : batch_size;
+    int batch_size_bias = has_bias ? bias.size(0) : batch_size;
+    int seqlen_q_mask = has_mask ? mask.size(2) : seqlen_q;
+    int seqlen_q_bias = has_bias ? bias.size(2) : seqlen_q;
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
+    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
     if (has_mask) {
+        TORCH_CHECK(mask.size(0) == batch_size || mask.size(0) == 1, "Batch dimension in mask must be 1 or equal to batch size");
         TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h");
+        TORCH_CHECK(mask.size(2) == 1 || mask.size(2) == seqlen_q, "Query length dimension in mask must be 1 or equal to seqlen_q");
+        TORCH_CHECK(mask.size(3) == seqlen_k_rounded, "Key length dimension in mask must be seqlen_k_rounded");
     }
     if (has_bias) {
+        TORCH_CHECK(bias.size(0) == batch_size || bias.size(0) == 1, "Batch dimension in bias must be 1 or equal to batch size");
         TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h");
+        TORCH_CHECK(bias.size(2) == 1 || bias.size(2) == seqlen_q, "Query length dimension in bias must be 1 or equal to seqlen_q");
+        TORCH_CHECK(bias.size(3) == seqlen_k_rounded, "Key length dimension in bias must be seqlen_k_rounded");
     }
 
-    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
-    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
-    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
-
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
     CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
-
+
     at::Tensor dq, dk, dv, dbias;
     if (dq_.has_value()) {
         dq = dq_.value();
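The new checks compare the key-length dimension of mask and bias against seqlen_k_rounded (seqlen_k rounded up to a multiple of 128) rather than seqlen_k itself. Assuming the intent is that callers pre-pad these tensors to the rounded length (an assumption on my part; the diff does not show the caller side), the padding could look roughly like this:

// Sketch: pad the last (key) dimension up to a multiple of 128, matching
// round_multiple(seqlen_k, 128) above. constant_pad_nd pads the last dimension
// first, so {0, pad} only extends the key axis.
const int seqlen_k_rounded = (seqlen_k + 128 - 1) / 128 * 128;
const int pad = seqlen_k_rounded - seqlen_k;
if (pad > 0) {
    mask = at::constant_pad_nd(mask, {0, pad}, /*value=*/0);  // padded keys become false, assuming false means "do not attend"
    bias = at::constant_pad_nd(bias, {0, pad}, /*value=*/0);
}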
@@ -934,30 +938,14 @@ mha_bwd(
             dbias = dbias_.value();
             TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q");
             CHECK_DEVICE(dbias);
+            TORCH_CHECK(dbias.dim() == 4, "dbias must have 4 dimensions with shape (batch_size, nheads, seqlen_q, seqlen_k)");
             TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension");
-            if (dbias.dim() == 4) {
-                CHECK_SHAPE(dbias, batch_size, num_heads_bias, seqlen_q, seqlen_k);
-            } else {
-                CHECK_SHAPE(dbias, batch_size, num_heads_bias, seqlen_k);
-            }
+            TORCH_CHECK(dbias.size(0) == batch_size || dbias.size(0) == 1, "Batch dimension in dbias must be 1 or equal to batch size");
+            TORCH_CHECK(dbias.size(1) == num_heads || dbias.size(1) == num_heads_k || dbias.size(1) == 1, "Number of heads in dbias must be 1, h_k or h");
+            TORCH_CHECK(dbias.size(2) == seqlen_q || dbias.size(2) == 1, "Query length dimension in dbias must be 1 or equal to seqlen_q");
+            TORCH_CHECK(dbias.size(3) == seqlen_k_rounded, "Key length dimension in dbias must be seqlen_k_rounded");
         } else {
-            if (bias.dim() == 4) {
-                if (num_heads_bias == 1) {
-                    dbias = torch::empty({batch_size, 1, seqlen_q, seqlen_k}, opts);
-                } else if (num_heads_bias == num_heads_k) {
-                    dbias = torch::empty({batch_size, num_heads_k, seqlen_q, seqlen_k}, opts);
-                } else {
-                    dbias = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts);
-                }
-            } else {
-                if (num_heads_bias == 1) {
-                    dbias = torch::empty({batch_size, 1, seqlen_k}, opts);
-                } else if (num_heads_bias == num_heads_k) {
-                    dbias = torch::empty({batch_size, num_heads_k, seqlen_k}, opts);
-                } else {
-                    dbias = torch::empty({batch_size, num_heads, seqlen_k}, opts);
-                }
-            }
+            dbias = torch::empty({batch_size_bias, num_heads_bias, seqlen_q_bias, seqlen_k_rounded}, opts);
         }
     } else {
         dbias = torch::empty({0}, opts);
@@ -990,8 +978,8 @@ mha_bwd(
         : dv;
     dbias_expanded = has_bias
         ? (
-            (num_heads_bias != num_heads) || (bias_.has_value() && bias_.value().dim() == 3)  // MQA / GQA or bias has no seqlen_q dimension
-            ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts)
+            (num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q)  // MQA / GQA, or dbias has a different batch size or seqlen_q
+            ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
             : dbias
         )
         : torch::empty({0}, opts);
@@ -1046,24 +1034,19 @@ mha_bwd(
         at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
         at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
     }
-    // For MQA/GQA or num_heads_bias != num_heads, we also need to sum dbias across the heads
+    // For MQA/GQA, or when dbias has a different batch size or seqlen_q, we need to sum dbias over the head groups and the broadcast batch and seqlen_q dimensions
     if (has_bias) {
-        bool sum_seqlen_q = bias_.has_value() && bias_.value().dim() == 3;
-        if (num_heads_bias != num_heads) {
-            if (sum_seqlen_q) {
-                dbias_expanded = at::sum(
-                    at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2}
-                );
-            } else {
-                at::sum_out(
-                    dbias,
-                    at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2}
-                );
+        if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) {
+            at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
+        } else {
+            dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
+            if (seqlen_q_bias == 1) {
+                dbias_expanded = at::sum(dbias_expanded, {2}, true);
             }
-        }
-        if (sum_seqlen_q) {
-            // We need to sum across the seqlen_q dimension
-            at::sum_out(dbias, dbias_expanded, {2});
+            if (batch_size_bias == 1) {
+                dbias_expanded = at::sum(dbias_expanded, {0}, true);
+            }
+            dbias.copy_(dbias_expanded);
         }
     }
 
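The reduction above follows the usual rule that the gradient of a broadcast tensor is the upstream gradient summed over the broadcast dimensions (head groups, then seqlen_q and batch when they are 1). A small standalone sketch, independent of this codebase, illustrating the same rule via autograd; the shapes are arbitrary:

// Sketch: autograd sums the upstream gradient over broadcast dimensions,
// which is what the manual at::sum / at::sum_out reductions above reproduce.
#include <torch/torch.h>
#include <iostream>

int main() {
    auto bias = torch::randn({1, 1, 1, 8}, torch::requires_grad());  // broadcast over batch, heads, seqlen_q
    auto scores = torch::randn({2, 4, 16, 8});                       // batch x heads x seqlen_q x seqlen_k
    (scores + bias).sum().backward();
    std::cout << bias.grad().sizes() << "\n";  // [1, 1, 1, 8]: gradient summed over the broadcast dims
    return 0;
}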