@@ -1320,33 +1320,32 @@ struct ComputeTile_A16W8_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
13201320// dequant B
13211321#pragma unroll
13221322 for (int i = 0 ; i < WARP_NITER / 2 ; ++i) {
1323- cvt_8bx4_to_16bx4_bias128 (BQ_frag[reg_buf_idx][2 * i],
1324- BF_frag[reg_buf_idx][2 * i]);
1323+ typename HalfType<FType>::T2 B_zero_x2 =
1324+ num2num2 (static_cast <typename HalfType<FType>::T1>(0 .f ));
1325+ typename HalfType<FType>::T2 B_zero_y2 =
1326+ num2num2 (static_cast <typename HalfType<FType>::T1>(0 .f ));
13251327 if (has_zp) {
1326- BF_frag[reg_buf_idx][2 * i][0 ] =
1327- __hsub2 (BF_frag[reg_buf_idx][2 * i][0 ], num2num2 (B_zero[i].x ));
1328- BF_frag[reg_buf_idx][2 * i][1 ] =
1329- __hsub2 (BF_frag[reg_buf_idx][2 * i][1 ], num2num2 (B_zero[i].x ));
1328+ B_zero_x2 = num2num2 (B_zero[i].x );
1329+ B_zero_y2 = num2num2 (B_zero[i].y );
13301330 }
13311331
1332- BF_frag[reg_buf_idx][2 * i][0 ] =
1333- __hmul2 (BF_frag[reg_buf_idx][2 * i][0 ], num2num2 (B_scale[i].x ));
1334- BF_frag[reg_buf_idx][2 * i][1 ] =
1335- __hmul2 (BF_frag[reg_buf_idx][2 * i][1 ], num2num2 (B_scale[i].x ));
1332+ cvt_8bx4_to_16bx4_bias128 (BQ_frag[reg_buf_idx][2 * i],
1333+ BF_frag[reg_buf_idx][2 * i]);
1334+
1335+ BF_frag[reg_buf_idx][2 * i][0 ] = dequantize_func (
1336+ BF_frag[reg_buf_idx][2 * i][0 ], num2num2 (B_scale[i].x ), B_zero_x2);
1337+ BF_frag[reg_buf_idx][2 * i][1 ] = dequantize_func (
1338+ BF_frag[reg_buf_idx][2 * i][1 ], num2num2 (B_scale[i].x ), B_zero_x2);
13361339
13371340 cvt_8bx4_to_16bx4_bias128 (BQ_frag[reg_buf_idx][2 * i + 1 ],
13381341 BF_frag[reg_buf_idx][2 * i + 1 ]);
1339- if (has_zp) {
1340- BF_frag[reg_buf_idx][2 * i + 1 ][0 ] =
1341- __hsub2 (BF_frag[reg_buf_idx][2 * i + 1 ][0 ], num2num2 (B_zero[i].y ));
1342- BF_frag[reg_buf_idx][2 * i + 1 ][1 ] =
1343- __hsub2 (BF_frag[reg_buf_idx][2 * i + 1 ][1 ], num2num2 (B_zero[i].y ));
1344- }
13451342
13461343 BF_frag[reg_buf_idx][2 * i + 1 ][0 ] =
1347- __hmul2 (BF_frag[reg_buf_idx][2 * i + 1 ][0 ], num2num2 (B_scale[i].y ));
1344+ dequantize_func (BF_frag[reg_buf_idx][2 * i + 1 ][0 ],
1345+ num2num2 (B_scale[i].y ), B_zero_y2);
13481346 BF_frag[reg_buf_idx][2 * i + 1 ][1 ] =
1349- __hmul2 (BF_frag[reg_buf_idx][2 * i + 1 ][1 ], num2num2 (B_scale[i].y ));
1347+ dequantize_func (BF_frag[reg_buf_idx][2 * i + 1 ][1 ],
1348+ num2num2 (B_scale[i].y ), B_zero_y2);
13501349 }
13511350 }
13521351
@@ -1677,6 +1676,10 @@ void ampere_hgemm_A16W8_perc_f16_f16_MtilexNtilex32_mma16816_multistage_AN_BTN32
16771676 const uint32_t K, void * workspace, const int sm_version,
16781677 const SplitKParams fused_gemm_params, const float alpha,
16791678 cudaStream_t stream) {
1679+ if (sm_version < 0x0800 ) {
1680+ throw std::runtime_error (
1681+ " this kernel is not supported on devices below sm80" );
1682+ }
16801683 int Mtile = fused_gemm_params.Mtile ;
16811684 int grid_x = (M + Mtile - 1 ) / Mtile;
16821685 int Ntile = fused_gemm_params.Ntile ;
0 commit comments