diff --git a/csrc/kernels/bgmv_expand.cpp b/csrc/kernels/bgmv_expand.cpp index c910005c7e3..6749c120250 100644 --- a/csrc/kernels/bgmv_expand.cpp +++ b/csrc/kernels/bgmv_expand.cpp @@ -341,10 +341,8 @@ class BGMVExpand { } // declare all dtype kernel -BGMV_EXPAND_TYPE_DECLARE(half) -#if (__CCE_AICORE__ >= 220) - BGMV_EXPAND_TYPE_DECLARE(bfloat16_t) -#endif +BGMV_EXPAND_TYPE_DECLARE(half); +BGMV_EXPAND_TYPE_DECLARE(bfloat16_t); namespace vllm_ascend { extern void bgmv_expand_impl(AscendType type, void* stream, void* x, void* weight, void* indices, uint32_t indicesSize, @@ -356,11 +354,9 @@ extern void bgmv_expand_impl(AscendType type, void* stream, void* x, void* weigh bgmv_expand_half<<>>(x, weight, indices, indicesSize, yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, outputHiddenDim, sliceOffset, outputFullDim); } else if (type == AscendType::BF16) { - #if (__CCE_AICORE__ >= 220) - bgmv_expand_bfloat16_t<<>>(x, weight, indices, indicesSize, yIn, yOut, batchSize, - numTokensPerCore, maxLoRARank, outputHiddenDim, - sliceOffset, outputFullDim); - #endif + bgmv_expand_bfloat16_t<<>>(x, weight, indices, indicesSize, yIn, yOut, batchSize, + numTokensPerCore, maxLoRARank, outputHiddenDim, + sliceOffset, outputFullDim); } else { return; } diff --git a/csrc/kernels/bgmv_shrink.cpp b/csrc/kernels/bgmv_shrink.cpp index b5a2d15dd5a..e168e4b64d8 100644 --- a/csrc/kernels/bgmv_shrink.cpp +++ b/csrc/kernels/bgmv_shrink.cpp @@ -225,10 +225,8 @@ class BGMVShrink { } // declare all dtype kernel -BGMV_SHRINK_TYPE_DECLARE(half) -#if (__CCE_AICORE__ >= 220) - BGMV_SHRINK_TYPE_DECLARE(bfloat16_t) -#endif +BGMV_SHRINK_TYPE_DECLARE(half); +BGMV_SHRINK_TYPE_DECLARE(bfloat16_t); namespace vllm_ascend { extern void bgmv_shrink_impl(AscendType type, void* stream, void* x, void* weight, void* indices, uint32_t indicesSize, @@ -240,10 +238,8 @@ extern void bgmv_shrink_impl(AscendType type, void* stream, void* x, void* weigh bgmv_shrink_half<<>>(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale); } else if (type == AscendType::BF16) { - #if (__CCE_AICORE__ >= 220) - bgmv_shrink_bfloat16_t<<>>(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore, - inputHiddenDim, maxLoRARank, scale); - #endif + bgmv_shrink_bfloat16_t<<>>(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore, + inputHiddenDim, maxLoRARank, scale); } else { return; } diff --git a/csrc/kernels/sgmv_expand.cpp b/csrc/kernels/sgmv_expand.cpp index 5466bd69950..65ec2271a34 100644 --- a/csrc/kernels/sgmv_expand.cpp +++ b/csrc/kernels/sgmv_expand.cpp @@ -356,10 +356,8 @@ class SGMVExpand { } // declare all dtype kernel -SGMV_EXPAND_TYPE_DECLARE(half) -#if (__CCE_AICORE__ >= 220) - SGMV_EXPAND_TYPE_DECLARE(bfloat16_t) -#endif +SGMV_EXPAND_TYPE_DECLARE(half); +SGMV_EXPAND_TYPE_DECLARE(bfloat16_t); namespace vllm_ascend { extern void sgmv_expand_impl(AscendType type, void* stream, void* x, void* weight, @@ -375,12 +373,10 @@ extern void sgmv_expand_impl(AscendType type, void* stream, void* x, void* weigh numTokensPerCore, maxLoRARank, outputHiddenDim, sliceOffset, outputFullDim); } else if (type == AscendType::BF16) { - #if (__CCE_AICORE__ >= 220) - sgmv_expand_bfloat16_t<<>>(x, weight, loraIndices, loraIndicesSize, - seqLen, seqLenSize, yIn, yOut, batchSize, - numTokensPerCore, maxLoRARank, outputHiddenDim, - sliceOffset, outputFullDim); - #endif + sgmv_expand_bfloat16_t<<>>(x, weight, loraIndices, loraIndicesSize, + seqLen, seqLenSize, yIn, yOut, batchSize, + numTokensPerCore, maxLoRARank, outputHiddenDim, + sliceOffset, outputFullDim); } else { return; } diff --git a/csrc/kernels/sgmv_shrink.cpp b/csrc/kernels/sgmv_shrink.cpp index a72e592ea1e..c8d8deca928 100644 --- a/csrc/kernels/sgmv_shrink.cpp +++ b/csrc/kernels/sgmv_shrink.cpp @@ -241,10 +241,8 @@ class SGMVShrink { } // declare all dtype kernel -SGMV_SHRINK_TYPE_DECLARE(half) -#if (__CCE_AICORE__ >= 220) - SGMV_SHRINK_TYPE_DECLARE(bfloat16_t) -#endif +SGMV_SHRINK_TYPE_DECLARE(half); +SGMV_SHRINK_TYPE_DECLARE(bfloat16_t); namespace vllm_ascend { extern void sgmv_shrink_impl(AscendType type, void* stream, void* x, void* weight, @@ -260,13 +258,11 @@ extern void sgmv_shrink_impl(AscendType type, void* stream, void* x, void* weigh numTokensPerCore, inputHiddenDim, maxLoRARank, scale); } else if (type == AscendType::BF16) { - #if (__CCE_AICORE__ >= 220) - sgmv_shrink_bfloat16_t<<>>(x, weight, loraIndices, loraIndicesSize, - seqLen, seqLenSize, - y, batchSize, - numTokensPerCore, inputHiddenDim, maxLoRARank, - scale); - #endif + sgmv_shrink_bfloat16_t<<>>(x, weight, loraIndices, loraIndicesSize, + seqLen, seqLenSize, + y, batchSize, + numTokensPerCore, inputHiddenDim, maxLoRARank, + scale); } else { return; } diff --git a/vllm_ascend/lora/punica_npu.py b/vllm_ascend/lora/punica_npu.py index bf86501d72e..ac90660838f 100644 --- a/vllm_ascend/lora/punica_npu.py +++ b/vllm_ascend/lora/punica_npu.py @@ -255,6 +255,7 @@ def add_lora_embedding(self, # Embedding layer only need expand op expand_fun: Callable = (self._expand_prefill if self.is_prefill else self._expand_decode) + x = x.to(torch.float32) expand_fun(y, x, lora_b_stacked, add_inputs) def add_lora_linear(self,