
Commit e9a7e7e
Commit message: 8192 upper
1 parent: cc965d9

File tree: 2 files changed, 20 insertions(+), 3 deletions(-)


megatron/fused_kernels/rocm/scaled_upper_triang_masked_softmax.h

Lines changed: 19 additions & 2 deletions
@@ -340,7 +340,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
     int softmax_elements_stride,
     int attn_batches)
 {
-    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
+    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192 );
     if (softmax_elements == 0) {
         return;
     } else {
@@ -361,6 +361,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
     int warps_per_block = (threads_per_block / warp_size);
     int batches_per_block = warps_per_block * batches_per_warp;
     TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
     int blocks_per_seq = attn_batches / batches_per_block;
     dim3 blocks(seq_len, blocks_per_seq, 1);
     dim3 threads(warp_size, warps_per_block, 1);
@@ -414,6 +415,14 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
       scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
         <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
       break;
+    case 12: // 4096
+      scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
+        <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+      break;
+    case 13: // 8192
+      scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
+        <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+      break;
     default:
       break;
   }
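The switch arms above are indexed by the log2 of the power-of-two-rounded row length, which is why lifting the cap to 8192 needs the two new arms, case 12 (4096) and case 13 (8192). Below is a minimal sketch, assuming a log2-ceiling helper along the lines such dispatchers usually use (the helper itself is not part of this diff):

// Sketch only: a log2-ceiling helper consistent with how this kind of
// dispatcher typically picks its template arm; the repo's actual helper
// may differ.
#include <cassert>

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) {
        ++log2_value;
    }
    return log2_value;
}

int main() {
    assert(log2_ceil(2048) == 11);  // old upper bound   -> case 11
    assert(log2_ceil(3000) == 12);  // rounds up to 4096 -> case 12
    assert(log2_ceil(8192) == 13);  // new upper bound   -> case 13
    return 0;
}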
@@ -430,7 +439,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
     int softmax_elements_stride,
     int attn_batches)
 {
-    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
+    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 8192 );
     if (softmax_elements == 0) {
         return;
     } else {
@@ -505,6 +514,14 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
       scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
         <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
       break;
+    case 12: // 4096
+      scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
+        <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+      break;
+    case 13: // 8192
+      scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
+        <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+      break;
     default:
       break;
   }
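The backward dispatcher mirrors the forward one, so the same two arms are added. One thing the diff does not show is the register cost: each warp-level instantiation keeps its whole power-of-two row in thread-local storage, so the per-thread element count doubles with every new case. A rough arithmetic sketch, assuming a 64-wide wavefront as is typical on ROCm (the kernel's actual warp constants are not visible here):

// Rough arithmetic sketch (warp width is an assumption, not repo code):
// per-thread element counts for the template arms touched by this commit.
#include <cstdio>

int main() {
    const int warp_size = 64;  // assumed ROCm wavefront width
    for (int log2_elements = 11; log2_elements <= 13; ++log2_elements) {
        const int next_power_of_two = 1 << log2_elements;   // 2048, 4096, 8192
        const int elements_per_thread = next_power_of_two / warp_size;
        std::printf("case %d: %d elements -> %d per thread\n",
                    log2_elements, next_power_of_two, elements_per_thread);
    }
    return 0;
}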

megatron/fused_kernels/rocm/scaled_upper_triang_masked_softmax_cuda.cu

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ torch::Tensor fwd_cuda(
   // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
   const int attn_batches = input.size(0);
   const int seq_len = input.size(1);
-  TORCH_INTERNAL_ASSERT(seq_len <= 2048);
+  TORCH_INTERNAL_ASSERT(seq_len <= 8192);

   // Output
   auto act_options = input.options().requires_grad(false);
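The host-side assert in fwd_cuda has to agree with the header's bound, since seq_len is what ultimately becomes the softmax row length handed to the dispatcher. A hypothetical shape check along those lines, with an illustrative function name not taken from the repo:

// Hypothetical shape check mirroring fwd_cuda's input assumptions; the
// function name and granularity here are illustrative only.
#include <torch/extension.h>

void check_upper_triang_softmax_input(const torch::Tensor& input) {
    TORCH_CHECK(input.dim() == 3,
                "expected a 3D tensor [attn_batches, seq_len, seq_len]");
    const int64_t seq_len = input.size(1);
    TORCH_CHECK(input.size(2) == seq_len,
                "the masked softmax operates over a square [seq_len, seq_len] slice");
    TORCH_CHECK(seq_len <= 8192,
                "fused upper-triangular masked softmax supports seq_len up to 8192");
}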
