diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh
index d5d62dbf228..bf0a22b6e2d 100644
--- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh
@@ -602,7 +602,8 @@ template <typename T,
           int VecSize = 4,
           int HeadDim = 128,
-          bool IsFP8 = false>
+          bool IsFP8 = false,
+          bool IsDynamic = true>
 __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
     const T* __restrict__ quant_qkv,  // [num_head, num_heads + 2 *
                                       // gqa_group_size, head_size]
@@ -662,8 +663,6 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
         (head_idx - num_heads) % gqa_group_size * block_size + block_offset;
   }
-  T* cache_k_scale_now = cache_k_scale + cache_offset;
-  T* cache_v_scale_now = cache_v_scale + cache_offset;
 
   float thread_m2 = 0.0f;
   float warp_m2 = 0.0f;
@@ -811,25 +810,34 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
     }
   }
   // reduce max, 1 head per warp
-  T local_max = -INFINITY;
+  if constexpr (IsDynamic) {
+    T local_max = -INFINITY;
 #pragma unroll
-  for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
-    local_max = __hmax(local_max, __habs(bias_vec1[i]));
-    local_max = __hmax(local_max, __habs(bias_vec2[i]));
-  }
+    for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
+      local_max = __hmax(local_max, __habs(bias_vec1[i]));
+      local_max = __hmax(local_max, __habs(bias_vec2[i]));
+    }
 #pragma unroll
-  for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
-    local_max =
-        __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
-  }
-
-  scale = __hdiv(448, local_max);
+    for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
+      local_max =
+          __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
+    }
 
-  if (lane_id == 0) {
+    scale = __hdiv(448, local_max);
+    T* cache_k_scale_now = cache_k_scale + cache_offset;
+    T* cache_v_scale_now = cache_v_scale + cache_offset;
+    if (lane_id == 0) {
+      if (head_idx < num_heads + gqa_group_size) {
+        cache_k_scale_now[0] = __hdiv(1, scale);
+      } else {
+        cache_v_scale_now[0] = __hdiv(1, scale);
+      }
+    }
+  } else {
     if (head_idx < num_heads + gqa_group_size) {
-      cache_k_scale_now[0] = __hdiv(1, scale);
+      scale = __ldg(&cache_k_scale[kv_head_idx]);
     } else {
-      cache_v_scale_now[0] = __hdiv(1, scale);
+      scale = __ldg(&cache_v_scale[kv_head_idx]);
     }
   }
diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu
index 7b6a3c15927..3a9305df2b5 100644
--- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu
+++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu
@@ -17,32 +17,32 @@ template <typename T, typename QKV_TYPE>
 void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
-                                        T* key_cache,
-                                        T* value_cache,
-                                        T* qkv_out,
-                                        const int* block_tables,
-                                        const int* batch_id_per_token,
-                                        const int* cu_seqlens_q,
-                                        const int* seq_lens,
-                                        const int* seq_lens_encoder,
-                                        const float* cos_emb,
-                                        const float* sin_emb,
-                                        const float* qkv_out_scales,
-                                        const T* qkv_biases,
-                                        const int max_seq_len,
-                                        const int max_blocks_per_seq,
-                                        const int num_heads,
-                                        const int kv_num_heads,
-                                        const int dim_head,
-                                        const int block_size,
-                                        const int bsz,
-                                        const int token_num,
-                                        const cudaStream_t& stream,
-                                        const bool use_neox_style,
-                                        const float* q_norm_weight,
-                                        const float* k_norm_weight,
-                                        const float rms_norm_eps,
-                                        const bool rope_3d) {
+                                         T* key_cache,
+                                         T* value_cache,
+                                         T* qkv_out,
+                                         const int* block_tables,
+                                         const int* batch_id_per_token,
+                                         const int* cu_seqlens_q,
+                                         const int* seq_lens,
+                                         const int* seq_lens_encoder,
+                                         const float* cos_emb,
+                                         const float* sin_emb,
+                                         const float* qkv_out_scales,
+                                         const T* qkv_biases,
+                                         const int max_seq_len,
+                                         const int max_blocks_per_seq,
+                                         const int num_heads,
+                                         const int kv_num_heads,
+                                         const int dim_head,
+                                         const int block_size,
+                                         const int bsz,
+                                         const int token_num,
+                                         const cudaStream_t& stream,
+                                         const bool use_neox_style,
+                                         const float* q_norm_weight,
+                                         const float* k_norm_weight,
+                                         const float rms_norm_eps,
+                                         const bool rope_3d) {
   int output_inner_dim = num_heads + 2 * kv_num_heads;
   const uint32_t elem_nums =
       use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
@@ -55,35 +55,34 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
   int grid_size = 1;
   GetNumBlocks<128>(pack_num, &grid_size);
   if (use_neox_style) {
-    PD_THROW(
-        "append_speculate_cache_rope_qk_norm not support neox rope yet");
+    PD_THROW("append_speculate_cache_rope_qk_norm does not support neox rope yet");
   } else {
     dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
     append_speculate_cache_T_rope_qk_norm_kernel<QKV_TYPE, PackSize>
-      <<<grid_size, block_dim, 0, stream>>>(qkv,
-                                            key_cache,
-                                            value_cache,
-                                            qkv_out,
-                                            block_tables,
-                                            batch_id_per_token,
-                                            cu_seqlens_q,
-                                            seq_lens,
-                                            cos_emb,
-                                            sin_emb,
-                                            qkv_out_scales,
-                                            qkv_biases,
-                                            max_seq_len,
-                                            max_blocks_per_seq,
-                                            num_heads,
-                                            output_inner_dim,
-                                            dim_head,
-                                            block_size,
-                                            elem_nums,
-                                            kv_num_heads,
-                                            q_norm_weight,
-                                            k_norm_weight,
-                                            rms_norm_eps,
-                                            rope_3d);
+        <<<grid_size, block_dim, 0, stream>>>(qkv,
+                                              key_cache,
+                                              value_cache,
+                                              qkv_out,
+                                              block_tables,
+                                              batch_id_per_token,
+                                              cu_seqlens_q,
+                                              seq_lens,
+                                              cos_emb,
+                                              sin_emb,
+                                              qkv_out_scales,
+                                              qkv_biases,
+                                              max_seq_len,
+                                              max_blocks_per_seq,
+                                              num_heads,
+                                              output_inner_dim,
+                                              dim_head,
+                                              block_size,
+                                              elem_nums,
+                                              kv_num_heads,
+                                              q_norm_weight,
+                                              k_norm_weight,
+                                              rms_norm_eps,
+                                              rope_3d);
   }
 }
 
@@ -175,33 +174,33 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
   }
 }
 
-template <typename T>
-void append_speculate_cache_fp8_dynamic_rope(const T* qkv,
-                                             uint8_t* key_cache,
-                                             uint8_t* value_cache,
-                                             T* qkv_out,
-                                             const int* block_tables,
-                                             const int* batch_id_per_token,
-                                             const int* cu_seqlens_q,
-                                             const int* seq_lens,
-                                             const int* seq_lens_encoder,
-                                             const float* cos_emb,
-                                             const float* sin_emb,
-                                             T* cache_k_scale,
-                                             T* cache_v_scale,
-                                             const float* q_norm_weight,
-                                             const float* k_norm_weight,
-                                             const int max_seq_len,
-                                             const int max_blocks_per_seq,
-                                             const int num_heads,
-                                             const int kv_num_heads,
-                                             const int dim_head,
-                                             const int block_size,
-                                             const int bsz,
-                                             const int token_num,
-                                             const cudaStream_t& stream,
-                                             const bool rope_3d,
-                                             const float rms_norm_eps) {
+template <typename T, bool IsDynamic>
+void append_speculate_cache_fp8_rope(const T* qkv,
+                                     uint8_t* key_cache,
+                                     uint8_t* value_cache,
+                                     T* qkv_out,
+                                     const int* block_tables,
+                                     const int* batch_id_per_token,
+                                     const int* cu_seqlens_q,
+                                     const int* seq_lens,
+                                     const int* seq_lens_encoder,
+                                     const float* cos_emb,
+                                     const float* sin_emb,
+                                     T* cache_k_scale,
+                                     T* cache_v_scale,
+                                     const float* q_norm_weight,
+                                     const float* k_norm_weight,
+                                     const int max_seq_len,
+                                     const int max_blocks_per_seq,
+                                     const int num_heads,
+                                     const int kv_num_heads,
+                                     const int dim_head,
+                                     const int block_size,
+                                     const int bsz,
+                                     const int token_num,
+                                     const cudaStream_t& stream,
+                                     const bool rope_3d,
+                                     const float rms_norm_eps) {
   constexpr int num_warps = 4;
   const int all_warps =
       ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -220,7 +219,12 @@ void append_speculate_cache_fp8_dynamic_rope(const T* qkv,
                        num_heads,
                        block_size,
                        kv_num_heads);
-  append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel<T, 4, 128, true>
+  append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel<T,
+                                                         4,
+                                                         128,
+                                                         true,
+                                                         IsDynamic>
       <<<grids, num_warps * 32, 0, stream>>>(qkv,
                                              key_cache,
                                              value_cache,
@@ -247,7 +251,7 @@ void append_speculate_cache_fp8_dynamic_rope(const T* qkv,
                                              rms_norm_eps);
 }
 
-template <typename T, typename QKV_TYPE, bool IsFP8=false>
+template <typename T, typename QKV_TYPE, bool IsFP8 = false>
 void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
                                       uint8_t* key_cache,
                                       uint8_t* value_cache,
@@ -489,7 +493,6 @@ void SpeculateWriteCacheWithRoPEKernel(
   auto num_heads = meta_data.q_num_heads;
   auto kv_num_heads = meta_data.kv_num_heads;
-
   const float* cos_emb =
       rotary_embs ? rotary_embs.get().data<float>() : nullptr;
   const float* sin_emb;
@@ -515,8 +518,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         sin_emb,
         qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
         qkv_biases ? reinterpret_cast<T*>(
-                        const_cast<T*>(qkv_biases.get().data<T>()))
-                   : nullptr,
+                         const_cast<T*>(qkv_biases.get().data<T>()))
+                   : nullptr,
         max_seq_len,
         max_blocks_per_seq,
         num_heads,
@@ -532,209 +535,243 @@ void SpeculateWriteCacheWithRoPEKernel(
         rms_norm_eps,
         rope_3d);
   } else if (cache_quant_type_str == "block_wise_fp8") {
-    append_speculate_cache_fp8_dynamic_rope(
-        reinterpret_cast<const T*>(qkv_ptr),
-        key_cache_out->data<uint8_t>(),
-        value_cache_out->data<uint8_t>(),
-        reinterpret_cast<T*>(qkv_out->data()),
-        block_tables.data<int>(),
-        batch_id_per_token.data<int>(),
-        cu_seqlens_q.data<int>(),
-        seq_lens.data<int>(),
-        seq_lens_encoder.data<int>(),
-        cos_emb,
-        sin_emb,
-        const_cast<T*>(reinterpret_cast<const T*>(cache_k_scale.get().data())),
-        const_cast<T*>(reinterpret_cast<const T*>(cache_v_scale.get().data())),
-        q_norm_weight.get().data<float>(),
-        k_norm_weight.get().data<float>(),
-        max_seq_len,
-        max_blocks_per_seq,
-        num_heads,
-        kv_num_heads,
-        dim_head,
-        block_size,
-        bsz,
-        token_nums,
-        stream,
-        rope_3d,
-        rms_norm_eps
-    );
+    append_speculate_cache_fp8_rope<T, true>(
+          reinterpret_cast<const T*>(qkv_ptr),
+          key_cache_out->data<uint8_t>(),
+          value_cache_out->data<uint8_t>(),
+          reinterpret_cast<T*>(qkv_out->data()),
+          block_tables.data<int>(),
+          batch_id_per_token.data<int>(),
+          cu_seqlens_q.data<int>(),
+          seq_lens.data<int>(),
+          seq_lens_encoder.data<int>(),
+          cos_emb,
+          sin_emb,
+          const_cast<T*>(reinterpret_cast<const T*>(
+              cache_k_scale.get().data())),
+          const_cast<T*>(reinterpret_cast<const T*>(
+              cache_v_scale.get().data())),
+          q_norm_weight.get().data<float>(),
+          k_norm_weight.get().data<float>(),
+          max_seq_len,
+          max_blocks_per_seq,
+          num_heads,
+          kv_num_heads,
+          dim_head,
+          block_size,
+          bsz,
+          token_nums,
+          stream,
+          rope_3d,
+          rms_norm_eps);
+  } else if (cache_quant_type_str == "cache_fp8") {
+    append_speculate_cache_fp8_rope<T, false>(
+          reinterpret_cast<const T*>(qkv_ptr),
+          key_cache_out->data<uint8_t>(),
+          value_cache_out->data<uint8_t>(),
+          reinterpret_cast<T*>(qkv_out->data()),
+          block_tables.data<int>(),
+          batch_id_per_token.data<int>(),
+          cu_seqlens_q.data<int>(),
+          seq_lens.data<int>(),
+          seq_lens_encoder.data<int>(),
+          cos_emb,
+          sin_emb,
+          const_cast<T*>(reinterpret_cast<const T*>(
+              cache_k_scale.get().data())),
+          const_cast<T*>(reinterpret_cast<const T*>(
+              cache_v_scale.get().data())),
+          q_norm_weight.get().data<float>(),
+          k_norm_weight.get().data<float>(),
+          max_seq_len,
+          max_blocks_per_seq,
+          num_heads,
+          kv_num_heads,
+          dim_head,
+          block_size,
+          bsz,
+          token_nums,
+          stream,
+          rope_3d,
+          rms_norm_eps);
   } else {
     PD_THROW(
-        "append_decode_cache_rope_qk_norm not support cachekv quant yet");
+        "speculate_append_decode_cache_rope_qk_norm only supports "
+        "cache_quant_type "
+        "none/block_wise_fp8/cache_fp8");
   }
 } else {
   if (cache_quant_type_str == "none") {
-      append_speculate_cache_rope(
-          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
-          reinterpret_cast<T*>(key_cache_out->data()),
-          reinterpret_cast<T*>(value_cache_out->data()),
-          reinterpret_cast<T*>(qkv_out->data()),
-          block_tables.data<int>(),
-          batch_id_per_token.data<int>(),
-          cu_seqlens_q.data<int>(),
-          seq_lens.data<int>(),
-          seq_lens_encoder.data<int>(),
-          cos_emb,
-          sin_emb,
-          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
-          qkv_biases ? reinterpret_cast<T*>(
-                          const_cast<T*>(qkv_biases.get().data<T>()))
-                   : nullptr,
-          max_seq_len,
-          max_blocks_per_seq,
-          num_heads,
-          kv_num_heads,
-          dim_head,
-          block_size,
-          bsz,
-          token_nums,
-          stream,
-          use_neox_rotary_style,
-          rope_3d);
+      append_speculate_cache_rope(
+            reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
+            reinterpret_cast<T*>(key_cache_out->data()),
+            reinterpret_cast<T*>(value_cache_out->data()),
+            reinterpret_cast<T*>(qkv_out->data()),
+            block_tables.data<int>(),
+            batch_id_per_token.data<int>(),
+            cu_seqlens_q.data<int>(),
+            seq_lens.data<int>(),
+            seq_lens_encoder.data<int>(),
+            cos_emb,
+            sin_emb,
+            qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
+            qkv_biases ? reinterpret_cast<T*>(
+                             const_cast<T*>(qkv_biases.get().data<T>()))
+                       : nullptr,
+            max_seq_len,
+            max_blocks_per_seq,
+            num_heads,
+            kv_num_heads,
+            dim_head,
+            block_size,
+            bsz,
+            token_nums,
+            stream,
+            use_neox_rotary_style,
+            rope_3d);
    } else if (cache_quant_type_str == "cache_int8") {
-      append_speculate_cache_int8_rope(
-          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
-          key_cache_out->data<uint8_t>(),
-          value_cache_out->data<uint8_t>(),
-          reinterpret_cast<T*>(qkv_out->data()),
-          block_tables.data<int>(),
-          batch_id_per_token.data<int>(),
-          cu_seqlens_q.data<int>(),
-          seq_lens.data<int>(),
-          seq_lens_encoder.data<int>(),
-          cos_emb,
-          sin_emb,
-          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
-          qkv_biases ? reinterpret_cast<T*>(
-                          const_cast<T*>(qkv_biases.get().data<T>()))
-                   : nullptr,
-          cache_k_scale ? reinterpret_cast<T*>(
-                             const_cast<T*>(cache_k_scale.get().data<T>()))
+      append_speculate_cache_int8_rope(
+            reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
+            key_cache_out->data<uint8_t>(),
+            value_cache_out->data<uint8_t>(),
+            reinterpret_cast<T*>(qkv_out->data()),
+            block_tables.data<int>(),
+            batch_id_per_token.data<int>(),
+            cu_seqlens_q.data<int>(),
+            seq_lens.data<int>(),
+            seq_lens_encoder.data<int>(),
+            cos_emb,
+            sin_emb,
+            qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
+            qkv_biases ? reinterpret_cast<T*>(
+                             const_cast<T*>(qkv_biases.get().data<T>()))
+                       : nullptr,
+            cache_k_scale ? reinterpret_cast<T*>(
+                                const_cast<T*>(cache_k_scale.get().data<T>()))
                         : nullptr,
-          cache_v_scale ? reinterpret_cast<T*>(
-                             const_cast<T*>(cache_v_scale.get().data<T>()))
+            cache_v_scale ? reinterpret_cast<T*>(
+                                const_cast<T*>(cache_v_scale.get().data<T>()))
                         : nullptr,
-          max_seq_len,
-          max_blocks_per_seq,
-          num_heads,
-          kv_num_heads,
-          dim_head,
-          block_size,
-          bsz,
-          token_nums,
-          stream,
-          use_neox_rotary_style,
-          rope_3d);
+            max_seq_len,
+            max_blocks_per_seq,
+            num_heads,
+            kv_num_heads,
+            dim_head,
+            block_size,
+            bsz,
+            token_nums,
+            stream,
+            use_neox_rotary_style,
+            rope_3d);
    } else if (cache_quant_type_str == "cache_fp8") {
-      append_speculate_cache_int8_rope<T, QKV_TYPE, true>(
-          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
-          key_cache_out->data<uint8_t>(),
-          value_cache_out->data<uint8_t>(),
-          reinterpret_cast<T*>(qkv_out->data()),
-          block_tables.data<int>(),
-          batch_id_per_token.data<int>(),
-          cu_seqlens_q.data<int>(),
-          seq_lens.data<int>(),
-          seq_lens_encoder.data<int>(),
-          cos_emb,
-          sin_emb,
-          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
-          qkv_biases ? reinterpret_cast<T*>(
-                          const_cast<T*>(qkv_biases.get().data<T>()))
-                   : nullptr,
-          cache_k_scale ? reinterpret_cast<T*>(
-                             const_cast<T*>(cache_k_scale.get().data<T>()))
+      append_speculate_cache_int8_rope<T, QKV_TYPE, true>(
+            reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
+            key_cache_out->data<uint8_t>(),
+            value_cache_out->data<uint8_t>(),
+            reinterpret_cast<T*>(qkv_out->data()),
+            block_tables.data<int>(),
+            batch_id_per_token.data<int>(),
+            cu_seqlens_q.data<int>(),
+            seq_lens.data<int>(),
+            seq_lens_encoder.data<int>(),
+            cos_emb,
+            sin_emb,
+            qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
+            qkv_biases ? reinterpret_cast<T*>(
+                             const_cast<T*>(qkv_biases.get().data<T>()))
+                       : nullptr,
+            cache_k_scale ? reinterpret_cast<T*>(
+                                const_cast<T*>(cache_k_scale.get().data<T>()))
                         : nullptr,
-          cache_v_scale ? reinterpret_cast<T*>(
-                             const_cast<T*>(cache_v_scale.get().data<T>()))
+            cache_v_scale ? reinterpret_cast<T*>(
+                                const_cast<T*>(cache_v_scale.get().data<T>()))
                         : nullptr,
-          max_seq_len,
-          max_blocks_per_seq,
-          num_heads,
-          kv_num_heads,
-          dim_head,
-          block_size,
-          bsz,
-          token_nums,
-          stream,
-          use_neox_rotary_style,
-          rope_3d);
+            max_seq_len,
+            max_blocks_per_seq,
+            num_heads,
+            kv_num_heads,
+            dim_head,
+            block_size,
+            bsz,
+            token_nums,
+            stream,
+            use_neox_rotary_style,
+            rope_3d);
    } else if (cache_quant_type_str == "block_wise_fp8") {
-      append_speculate_cache_fp8_dynamic_rope(
-          reinterpret_cast<const T*>(qkv_ptr),
-          key_cache_out->data<uint8_t>(),
-          value_cache_out->data<uint8_t>(),
-          reinterpret_cast<T*>(qkv_out->data()),
-          block_tables.data<int>(),
-          batch_id_per_token.data<int>(),
-          cu_seqlens_q.data<int>(),
-          seq_lens.data<int>(),
-          seq_lens_encoder.data<int>(),
-          cos_emb,
-          sin_emb,
-          const_cast<T*>(reinterpret_cast<const T*>(cache_k_scale.get().data())),
-          const_cast<T*>(reinterpret_cast<const T*>(cache_v_scale.get().data())),
-          nullptr,  // q_norm_weight
-          nullptr,  // k_norm_weight
-          max_seq_len,
-          max_blocks_per_seq,
-          num_heads,
-          kv_num_heads,
-          dim_head,
-          block_size,
-          bsz,
-          token_nums,
-          stream,
-          rope_3d,
-          rms_norm_eps
-      );
+      append_speculate_cache_fp8_rope<T, true>(
+            reinterpret_cast<const T*>(qkv_ptr),
+            key_cache_out->data<uint8_t>(),
+            value_cache_out->data<uint8_t>(),
+            reinterpret_cast<T*>(qkv_out->data()),
+            block_tables.data<int>(),
+            batch_id_per_token.data<int>(),
+            cu_seqlens_q.data<int>(),
+            seq_lens.data<int>(),
+            seq_lens_encoder.data<int>(),
+            cos_emb,
+            sin_emb,
+            const_cast<T*>(reinterpret_cast<const T*>(
+                cache_k_scale.get().data())),
+            const_cast<T*>(reinterpret_cast<const T*>(
+                cache_v_scale.get().data())),
+            nullptr,  // q_norm_weight
+            nullptr,  // k_norm_weight
+            max_seq_len,
+            max_blocks_per_seq,
+            num_heads,
+            kv_num_heads,
+            dim_head,
+            block_size,
+            bsz,
+            token_nums,
+            stream,
+            rope_3d,
+            rms_norm_eps);
    } else if (cache_quant_type_str == "cache_int4_zp") {
-      append_speculate_cache_int4_rope(
-          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
-          key_cache_out->data<uint8_t>(),
-          value_cache_out->data<uint8_t>(),
-          reinterpret_cast<T*>(const_cast<T*>(qkv_out->data<T>())),
-          block_tables.data<int>(),
-          batch_id_per_token.data<int>(),
-          cu_seqlens_q.data<int>(),
-          seq_lens.data<int>(),
-          seq_lens_encoder.data<int>(),
-          cos_emb,
-          sin_emb,
-          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
-          qkv_biases ? reinterpret_cast<T*>(
-                          const_cast<T*>(qkv_biases.get().data<T>()))
-                   : nullptr,
-          cache_k_scale ? reinterpret_cast<T*>(
-                             const_cast<T*>(cache_k_scale.get().data<T>()))
                         : nullptr,
-          cache_v_scale ? reinterpret_cast<T*>(
-                             const_cast<T*>(cache_v_scale.get().data<T>()))
                         : nullptr,
-          cache_k_zp ? reinterpret_cast<T*>(
-                          const_cast<T*>(cache_k_zp.get().data<T>()))
-                     : nullptr,
-          cache_v_zp ? reinterpret_cast<T*>(
-                          const_cast<T*>(cache_v_zp.get().data<T>()))
-                     : nullptr,
-          max_seq_len,
-          max_blocks_per_seq,
-          num_heads,
-          kv_num_heads,
-          dim_head,
-          block_size,
-          bsz,
-          token_nums,
-          stream,
-          use_neox_rotary_style,
-          rope_3d);
+      append_speculate_cache_int4_rope(
+            reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
+            key_cache_out->data<uint8_t>(),
+            value_cache_out->data<uint8_t>(),
+            reinterpret_cast<T*>(const_cast<T*>(qkv_out->data<T>())),
+            block_tables.data<int>(),
+            batch_id_per_token.data<int>(),
+            cu_seqlens_q.data<int>(),
+            seq_lens.data<int>(),
+            seq_lens_encoder.data<int>(),
+            cos_emb,
+            sin_emb,
+            qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
+            qkv_biases ? reinterpret_cast<T*>(
+                             const_cast<T*>(qkv_biases.get().data<T>()))
+                       : nullptr,
+            cache_k_scale ? reinterpret_cast<T*>(
+                                const_cast<T*>(cache_k_scale.get().data<T>()))
+                          : nullptr,
+            cache_v_scale ? reinterpret_cast<T*>(
+                                const_cast<T*>(cache_v_scale.get().data<T>()))
+                          : nullptr,
+            cache_k_zp ? reinterpret_cast<T*>(
+                             const_cast<T*>(cache_k_zp.get().data<T>()))
+                       : nullptr,
+            cache_v_zp ? reinterpret_cast<T*>(
+                             const_cast<T*>(cache_v_zp.get().data<T>()))
+                       : nullptr,
+            max_seq_len,
+            max_blocks_per_seq,
+            num_heads,
+            kv_num_heads,
+            dim_head,
+            block_size,
+            bsz,
+            token_nums,
+            stream,
+            use_neox_rotary_style,
+            rope_3d);
    } else {
-      PD_THROW(
-          "cache_quant_type_str should be one of [none, cache_int8, "
-          "cache_int4_zp]");
+      PD_THROW(
+          "cache_quant_type_str should be one of [none, cache_int8, "
+          "cache_fp8, block_wise_fp8, cache_int4_zp]");
    }
   }
 }
@@ -827,7 +864,6 @@ template void SpeculateWriteCacheWithRoPEKernel(
     const paddle::optional<paddle::Tensor>& k_norm_weight,
     const float rms_norm_eps);
-
 template void SpeculateWriteCacheWithRoPEKernel(
     const AppendAttnMetaData& meta_data,
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index e7d26417d3d..4bb0a445c36 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -198,6 +198,8 @@ def __init__(
         self.pooler_config: Optional["PoolerConfig"] = field(init=False)
         self.override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
         self.revision = None
+        self.prefix_layer_name = "layers"
+        self.kv_cache_quant_scale_path = ""
 
         self.partial_rotary_factor: float = 1.0
         self.num_nextn_predict_layers = 0
@@ -244,6 +246,7 @@ def _post_init(self):
         self.enable_mm = is_multimodal_model
 
+        self.kv_cache_quant_scale_path = os.path.join(self.model, "kv_cache_scale.json")
         if self.runner_type == "pooling":
             os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"
@@ -1591,6 +1594,10 @@ def postprocess(self):
         else:
             self.scheduler_config.max_num_batched_tokens = self.model_config.max_model_len
 
+        self.scheduler_config.max_chunk_len = (
+            self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_extra_num_batched_tokens
+        )
+
         if self.long_prefill_token_threshold == 0:
             self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04)
diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index d2b03f8b17e..6bf85c0681d 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -475,15 +475,16 @@ def deal_state_dict(state_dict):
         src_tensor._share_data_with(dst_tensor)
 
 
-def load_cache_scale(model_path, fd_config, state_dict):
-    file_path = os.path.join(model_path, "kv_cache_scale.json")
+def load_cache_scale(fd_config, state_dict):
+    file_path = fd_config.model_config.kv_cache_quant_scale_path
+    prefix_layer_name = fd_config.model_config.prefix_layer_name
     if os.path.exists(file_path):
         with open(file_path, "r") as f:
             data = json.load(f)
         for i in range(fd_config.model_config.num_hidden_layers):
-            k_scale_name = f"ernie.layers.{i}.self_attn.cachek_matmul.activation_scale"
-            v_scale_name = f"ernie.layers.{i}.self_attn.cachev_matmul.activation_scale"
+            k_scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.cachek_matmul.activation_scale"
+            v_scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.cachev_matmul.activation_scale"
 
             k_scale = data[k_scale_name]
             k_scale_tensor = paddle.to_tensor(k_scale, dtype=paddle.get_default_dtype())
@@ -547,6 +548,6 @@ def load_composite_checkpoint(
     if hasattr(fd_config.quant_config, "kv_cache_quant_type"):
         kv_cache_quant_type = fd_config.quant_config.kv_cache_quant_type
         if kv_cache_quant_type == "float8_e4m3fn":
"float8_e4m3fn": - load_cache_scale(model_path, fd_config, state_dict) + load_cache_scale(fd_config, state_dict) return state_dict diff --git a/fastdeploy/scheduler/config.py b/fastdeploy/scheduler/config.py index 83ee476e467..3be17c48cf1 100644 --- a/fastdeploy/scheduler/config.py +++ b/fastdeploy/scheduler/config.py @@ -268,7 +268,9 @@ def __init__(self, args): Exception: If invalid scheduler type is specified """ self.name = "local" # "local" for LocalScheduler or "global" for GlobalScheduler - self.max_num_batched_tokens = 2048 + self.max_num_batched_tokens = 2048 # base token_num for text inputs + self.max_extra_num_batched_tokens = 16384 # extra token_num for multimodal inputs + self.max_chunk_len = 18432 # max supported token_num = max_num_batched_tokens + max_extra_num_batched_tokens self.max_num_seqs = 34 self.splitwise_role = "mixed" self.config = None diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 3b40c8c164f..1403bd3d28b 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -104,6 +104,7 @@ def _update_mtp_config(self, main_model): self.model_config.num_hidden_layers = 1 self.model_config.model = self.speculative_config.model self.model_config.pretrained_config.prefix_name = "ernie.mtp_block" + self.model_config.prefix_layer_name = "mtp_block" if self.speculative_config.quantization != "": self.model_config.quantization = self.speculative_config.quantization self.model_config.start_layer_index = self.num_main_model_layers @@ -354,7 +355,7 @@ def _init_model_inputs(self): self.target_model_inputs["decoder_tile_ids_per_batch"] ) self.model_inputs["target_hidden_states"] = paddle.full( - [self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16" + [self.fd_config.scheduler_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16" ) tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))