Skip to content

Commit 069c74e

Browse files
authored
add high performance moe kernel; fix a16w8 compile bug for sm<80 (#67)
1 parent 22807e4 commit 069c74e

File tree

9 files changed

+3465
-98
lines changed

9 files changed

+3465
-98
lines changed

csrc/core/kernel/cuda/gemm_lowp/gemm_a16w8_perc_kernel.cu

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1320,33 +1320,32 @@ struct ComputeTile_A16W8_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
13201320
// dequant B
13211321
#pragma unroll
13221322
for (int i = 0; i < WARP_NITER / 2; ++i) {
1323-
cvt_8bx4_to_16bx4_bias128(BQ_frag[reg_buf_idx][2 * i],
1324-
BF_frag[reg_buf_idx][2 * i]);
1323+
typename HalfType<FType>::T2 B_zero_x2 =
1324+
num2num2(static_cast<typename HalfType<FType>::T1>(0.f));
1325+
typename HalfType<FType>::T2 B_zero_y2 =
1326+
num2num2(static_cast<typename HalfType<FType>::T1>(0.f));
13251327
if (has_zp) {
1326-
BF_frag[reg_buf_idx][2 * i][0] =
1327-
__hsub2(BF_frag[reg_buf_idx][2 * i][0], num2num2(B_zero[i].x));
1328-
BF_frag[reg_buf_idx][2 * i][1] =
1329-
__hsub2(BF_frag[reg_buf_idx][2 * i][1], num2num2(B_zero[i].x));
1328+
B_zero_x2 = num2num2(B_zero[i].x);
1329+
B_zero_y2 = num2num2(B_zero[i].y);
13301330
}
13311331

1332-
BF_frag[reg_buf_idx][2 * i][0] =
1333-
__hmul2(BF_frag[reg_buf_idx][2 * i][0], num2num2(B_scale[i].x));
1334-
BF_frag[reg_buf_idx][2 * i][1] =
1335-
__hmul2(BF_frag[reg_buf_idx][2 * i][1], num2num2(B_scale[i].x));
1332+
cvt_8bx4_to_16bx4_bias128(BQ_frag[reg_buf_idx][2 * i],
1333+
BF_frag[reg_buf_idx][2 * i]);
1334+
1335+
BF_frag[reg_buf_idx][2 * i][0] = dequantize_func(
1336+
BF_frag[reg_buf_idx][2 * i][0], num2num2(B_scale[i].x), B_zero_x2);
1337+
BF_frag[reg_buf_idx][2 * i][1] = dequantize_func(
1338+
BF_frag[reg_buf_idx][2 * i][1], num2num2(B_scale[i].x), B_zero_x2);
13361339

13371340
cvt_8bx4_to_16bx4_bias128(BQ_frag[reg_buf_idx][2 * i + 1],
13381341
BF_frag[reg_buf_idx][2 * i + 1]);
1339-
if (has_zp) {
1340-
BF_frag[reg_buf_idx][2 * i + 1][0] =
1341-
__hsub2(BF_frag[reg_buf_idx][2 * i + 1][0], num2num2(B_zero[i].y));
1342-
BF_frag[reg_buf_idx][2 * i + 1][1] =
1343-
__hsub2(BF_frag[reg_buf_idx][2 * i + 1][1], num2num2(B_zero[i].y));
1344-
}
13451342

13461343
BF_frag[reg_buf_idx][2 * i + 1][0] =
1347-
__hmul2(BF_frag[reg_buf_idx][2 * i + 1][0], num2num2(B_scale[i].y));
1344+
dequantize_func(BF_frag[reg_buf_idx][2 * i + 1][0],
1345+
num2num2(B_scale[i].y), B_zero_y2);
13481346
BF_frag[reg_buf_idx][2 * i + 1][1] =
1349-
__hmul2(BF_frag[reg_buf_idx][2 * i + 1][1], num2num2(B_scale[i].y));
1347+
dequantize_func(BF_frag[reg_buf_idx][2 * i + 1][1],
1348+
num2num2(B_scale[i].y), B_zero_y2);
13501349
}
13511350
}
13521351

@@ -1677,6 +1676,10 @@ void ampere_hgemm_A16W8_perc_f16_f16_MtilexNtilex32_mma16816_multistage_AN_BTN32
16771676
const uint32_t K, void* workspace, const int sm_version,
16781677
const SplitKParams fused_gemm_params, const float alpha,
16791678
cudaStream_t stream) {
1679+
if (sm_version < 0x0800) {
1680+
throw std::runtime_error(
1681+
"this kernel is not supported on devices below sm80");
1682+
}
16801683
int Mtile = fused_gemm_params.Mtile;
16811684
int grid_x = (M + Mtile - 1) / Mtile;
16821685
int Ntile = fused_gemm_params.Ntile;

csrc/core/kernel/cuda/gemm_lowp/gemm_lowp_utils.cuh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,15 @@ __device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__nv_bfloat162>(
10301030
}
10311031

10321032
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
1033+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
10331034
return __bfloat162bfloat162(x);
1035+
#else
1036+
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 3
1037+
__builtin_unreachable();
1038+
#else
1039+
return nv_bfloat162{};
1040+
#endif // __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 3
1041+
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
10341042
}
10351043

10361044
static __device__ half2 inline num2num2(const half x) {

0 commit comments

Comments
 (0)