Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,8 @@ template <typename T,
int VecSize = 4,
int RoundType = 0,
int HeadDim = 128,
bool IsFP8 = false>
bool IsFP8 = false,
bool IsDynamic = true>
__global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
const T* __restrict__ quant_qkv, // [num_head, num_heads + 2 *
// gqa_group_size, head_size]
Expand Down Expand Up @@ -662,8 +663,6 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
(head_idx - num_heads) % gqa_group_size * block_size +
block_offset;
}
T* cache_k_scale_now = cache_k_scale + cache_offset;
T* cache_v_scale_now = cache_v_scale + cache_offset;

float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
Expand Down Expand Up @@ -811,25 +810,34 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
}
}
// reduce max, 1 head per warp
T local_max = -INFINITY;
if constexpr (IsDynamic) {
T local_max = -INFINITY;
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
local_max = __hmax(local_max, __habs(bias_vec1[i]));
local_max = __hmax(local_max, __habs(bias_vec2[i]));
}
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
local_max = __hmax(local_max, __habs(bias_vec1[i]));
local_max = __hmax(local_max, __habs(bias_vec2[i]));
}
#pragma unroll
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
local_max =
__hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
}

scale = __hdiv(448, local_max);
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
local_max =
__hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
}

if (lane_id == 0) {
scale = __hdiv(448, local_max);
T* cache_k_scale_now = cache_k_scale + cache_offset;
T* cache_v_scale_now = cache_v_scale + cache_offset;
if (lane_id == 0) {
if (head_idx < num_heads + gqa_group_size) {
cache_k_scale_now[0] = __hdiv(1, scale);
} else {
cache_v_scale_now[0] = __hdiv(1, scale);
}
}
} else {
if (head_idx < num_heads + gqa_group_size) {
cache_k_scale_now[0] = __hdiv(1, scale);
scale = __ldg(&cache_k_scale[kv_head_idx]);
} else {
cache_v_scale_now[0] = __hdiv(1, scale);
scale = __ldg(&cache_v_scale[kv_head_idx]);
}
}

Expand Down
Loading
Loading