@@ -340,7 +340,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
     int softmax_elements_stride,
     int attn_batches)
 {
-    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
+    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192);
     if (softmax_elements == 0) {
         return;
     } else {
@@ -361,6 +361,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
         TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
         int blocks_per_seq = attn_batches / batches_per_block;
         dim3 blocks(seq_len, blocks_per_seq, 1);
         dim3 threads(warp_size, warps_per_block, 1);
@@ -414,6 +415,14 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
                 scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                     <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                 break;
+            case 12: // 4096
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 13: // 8192
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
             default:
                 break;
         }
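Both dispatchers pick a warp-softmax specialization by switching on the base-2 log of the element count rounded up to a power of two, so raising the assert to 8192 only takes effect together with the new case 12 (2^12 = 4096) and case 13 (2^13 = 8192) branches; without them, those lengths would fall through to the empty default and no kernel would be launched. A minimal sketch of that mapping, assuming the elided part of the dispatcher computes log2_elements with a ceil-log2 helper along these lines (the helper name here is illustrative, not necessarily the one in the file):

int log2_ceil_sketch(int value) {
    // Smallest k such that (1 << k) >= value, e.g. 2048 -> 11, 4096 -> 12, 8192 -> 13.
    int k = 0;
    while ((1 << k) < value) {
        ++k;
    }
    return k;
}

With that mapping, a 4096-element row selects case 12 and anything in (4096, 8192] selects case 13, matching the comments on the new branches.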
@@ -430,7 +439,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
     int softmax_elements_stride,
     int attn_batches)
 {
-    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
+    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192);
     if (softmax_elements == 0) {
         return;
     } else {
@@ -505,6 +514,14 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
                 scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                     <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                 break;
+            case 12: // 4096
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 13: // 8192
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
             default:
                 break;
         }
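The backward dispatcher receives the identical two branches so that gradients for 4096- and 8192-element rows reach the same specializations as the forward pass. In warp-level softmax kernels of this style, each row is usually processed by the lanes of a single warp, so the log2 template argument also determines roughly how many elements every lane must keep in registers; the sketch below works that out under the assumed 32-lane layout (the real constants live in the kernel templates, which this diff does not show):

#include <cstdio>

int main() {
    const int kWarpSize = 32;  // assumed lanes cooperating on one row
    for (int log2_elements = 11; log2_elements <= 13; ++log2_elements) {
        int elements   = 1 << log2_elements;     // row length covered by this case
        int per_thread = elements / kWarpSize;   // approx. elements per lane in registers
        std::printf("case %d: %d elements -> ~%d per lane\n",
                    log2_elements, elements, per_thread);
    }
    return 0;
}

Under those assumptions, going from case 11 to case 13 quadruples the per-lane working set (64 to 256 elements), which is the main cost of extending the supported sequence length to 8192.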