Commit b82a9d6

Vendor DeepGEMM FP8 lightning indexer kernels. They're not wired up yet.
1 parent: 7289478

File tree

11 files changed (+2015, -0 lines)

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -218,3 +218,8 @@ endif()
 target_include_directories(ggml-cuda PRIVATE
     ${CMAKE_CURRENT_SOURCE_DIR}/vendors/cutlass/include
 )
+
+# DeepGEMM vendor includes for sm90/sm100 FP8 paged MQA logits kernel (DSA)
+target_include_directories(ggml-cuda PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/vendors/DeepGEMM
+)
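
Once this include path is on the target, ggml-cuda sources will be able to pull the vendored headers in directly. A minimal sketch of what that looks like (the including source file is hypothetical; the header paths are the ones the vendored code itself uses):

    // In a future ggml-cuda source that wires up the DSA kernels (not part of this commit):
    #include <deep_gemm/common/types.hpp>
    #include <deep_gemm/common/utils.cuh>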
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
#pragma once

namespace cute {

struct ignore_t {
    template <typename T>
    constexpr const ignore_t& operator=(T&&) const noexcept {
        return *this;
    }
};

inline constexpr ignore_t ignore{};

} // namespace cute

#define CUTE_TIE_CONCAT_IMPL(A, B) A##B
#define CUTE_TIE_CONCAT(A, B) CUTE_TIE_CONCAT_IMPL(A, B)

#define CUTE_TIE_GET_NTH_ARG(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N
#define CUTE_TIE_COUNT_ARGS(...) \
    CUTE_TIE_GET_NTH_ARG(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)

#define CUTE_TIE_OP_DECL(I, TUPLE, VAR) auto VAR = ::cute::get<I>(TUPLE)
#define CUTE_TIE_OP_ASSIGN(I, TUPLE, VAR) VAR = ::cute::get<I>(TUPLE)

#define CUTE_TIE_APPLY_OP_1(OP, T, V1) OP(0, T, V1);
#define CUTE_TIE_APPLY_OP_2(OP, T, V1, V2) OP(0, T, V1); OP(1, T, V2);
#define CUTE_TIE_APPLY_OP_3(OP, T, V1, V2, V3) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3);
#define CUTE_TIE_APPLY_OP_4(OP, T, V1, V2, V3, V4) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4);
#define CUTE_TIE_APPLY_OP_5(OP, T, V1, V2, V3, V4, V5) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); OP(4, T, V5);

#define CUTE_TIE_DECL(TUPLE_EXPR, ...) \
    auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \
    CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \
        CUTE_TIE_OP_DECL, \
        CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \
        __VA_ARGS__ \
    )

#define CUTE_TIE(TUPLE_EXPR, ...) \
    do { \
        auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \
        CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \
            CUTE_TIE_OP_ASSIGN, \
            CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \
            __VA_ARGS__ \
        ); \
    } while (0)
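
A brief usage sketch of these macros, assuming CuTe's tuple utilities (cute::make_tuple, cute::get) are available; the variable names are illustrative only and not taken from the vendored kernels:

    // CUTE_TIE_DECL declares one local per tuple element; CUTE_TIE assigns into existing locals.
    auto shape = cute::make_tuple(128u, 64u);
    CUTE_TIE_DECL(shape, rows, cols);                         // auto rows = cute::get<0>(shape); auto cols = cute::get<1>(shape);
    CUTE_TIE(cute::make_tuple(rows * 2, cols), rows, cols);   // rows = 256; cols = 64;
    cute::ignore = rows;                                      // ignore_t discards any assigned value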
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
#pragma once

#include <deep_gemm/common/types.hpp>
#include <deep_gemm/common/utils.cuh>

namespace deep_gemm {

struct EpilogueIdentity {
    template <uint32_t STORE_BLOCK_N>
    __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) {
        return n_idx;
    }
};

template <uint32_t kLeft, uint32_t kMid, uint32_t kRight>
struct EpilogueHeadSplits: EpilogueIdentity {
    template <uint32_t STORE_BLOCK_N>
    __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) {
        DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0
                         and kRight % STORE_BLOCK_N == 0, "Invalid head splits config");
        return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid;
    }
};

#pragma clang diagnostic pop

} // namespace deep_gemm
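
To make the index math concrete, here is a small host-side check of the same mapping with illustrative parameters (kLeft = kMid = kRight = 64; these values are not taken from the commit): for each head group the epilogue writes the left split, skips a kMid-wide hole, writes the right split, then moves to the next group.

    // Minimal sketch replicating EpilogueHeadSplits<...>::apply_index_n on the host.
    #include <cstdint>

    constexpr uint32_t head_split_index(uint32_t n_idx, uint32_t kLeft, uint32_t kMid, uint32_t kRight) {
        return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid;
    }

    static_assert(head_split_index(0,   64, 64, 64) == 0);    // left split of group 0, unchanged
    static_assert(head_split_index(64,  64, 64, 64) == 128);  // right split of group 0, middle 64 columns skipped
    static_assert(head_split_index(128, 64, 64, 64) == 192);  // left split of group 1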
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
#pragma once

#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda/std/cstdint>
#include <cuda/std/utility>

#include <deep_gemm/common/utils.cuh>

// Operation functors
template <typename T> struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } };
template <typename T> struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } };
template <typename T> struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } };
template <typename T> struct ReduceAnd { __device__ T operator()(T a, T b) const { return a & b; } };
template <typename T> struct ReduceOr  { __device__ T operator()(T a, T b) const { return a | b; } };

// Unified reduction function
template <int kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
__forceinline__ __device__ T warp_reduce(T value, Op op) {
    DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
                     kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
                     "Invalid number of lanes");
    constexpr uint32_t mask = 0xffffffff;
    if constexpr (kIntergroupReduce) {
        if constexpr (kNumLanesPerGroup <= 1)  value = op(value, __shfl_xor_sync(mask, value, 1));
        if constexpr (kNumLanesPerGroup <= 2)  value = op(value, __shfl_xor_sync(mask, value, 2));
        if constexpr (kNumLanesPerGroup <= 4)  value = op(value, __shfl_xor_sync(mask, value, 4));
        if constexpr (kNumLanesPerGroup <= 8)  value = op(value, __shfl_xor_sync(mask, value, 8));
        if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16));
    } else {
        if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16));
        if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8));
        if constexpr (kNumLanesPerGroup >= 8)  value = op(value, __shfl_xor_sync(mask, value, 4));
        if constexpr (kNumLanesPerGroup >= 4)  value = op(value, __shfl_xor_sync(mask, value, 2));
        if constexpr (kNumLanesPerGroup >= 2)  value = op(value, __shfl_xor_sync(mask, value, 1));
    }
    return value;
}

// Convenience aliases
template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_sum(T value) {
    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceSum<T>{});
}
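
For orientation, a sketch of how these helpers are typically called from a kernel (the kernel below is illustrative and not part of the commit, and assumes a launch with 32 threads per block): each lane contributes one partial value, and after the butterfly __shfl_xor_sync exchanges every lane of an 8-lane group holds that group's sum.

    // Illustrative only: per-8-lane sums using warp_reduce_sum.
    __global__ void group_sum_sketch(const float* __restrict__ in, float* __restrict__ out) {
        const unsigned lane  = threadIdx.x % 32u;   // lane within the warp
        const unsigned group = lane / 8u;           // 4 groups of 8 lanes per warp
        float partial = in[blockIdx.x * 32u + threadIdx.x];
        // Reduce within each group of 8 consecutive lanes (shuffle offsets 4, 2, 1).
        const float sum = warp_reduce_sum<8>(partial);
        if (lane % 8u == 0u)
            out[blockIdx.x * 4u + group] = sum;
    }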
Lines changed: 257 additions & 0 deletions
@@ -0,0 +1,257 @@
#pragma once

#include <deep_gemm/common/types.hpp>
#include <deep_gemm/common/utils.cuh>

namespace deep_gemm {

enum class IndexType {
    MN,
    K,
    SF_K,
};

template <GemmType kGemmType, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumSMs, bool kIsMulticastOnA>
static constexpr uint32_t get_num_1d_blocks_per_group() {
    // Select the best from candidates
    uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits<uint32_t>::max();
    for (const auto& candidate: {8u, 16u}) {
        const auto& usage = kIsMulticastOnA ?
            candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
            candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
        if (usage < min_usage)
            min_usage = usage, num_best_blocks = candidate;
    }
    return num_best_blocks;
}

#pragma clang diagnostic push
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
template <GemmType kGemmType,
          uint32_t BLOCK_M, uint32_t BLOCK_N,
          uint32_t kNumGroups,
          uint32_t kNumMulticast, bool kIsMulticastOnA,
          uint32_t kNumSMs,
          uint32_t SF_K_ALIGNMENT = 512u, // for k-grouped GEMM only: 128 (SM90 float SF) or 512 (SM100 UE8M0 SF)
          uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group<kGemmType, BLOCK_M, BLOCK_N, kNumSMs, kIsMulticastOnA>()>
struct Scheduler {
    int current_iter = -1;

    // Block configs
    uint32_t num_blocks;
    uint32_t num_m_blocks;
    uint32_t num_n_blocks;

    // For SM90 multicast checks
    uint32_t num_blocks_in_group;
    bool is_peer_cta_alive = true;

    // For grouped GEMM
    int* grouped_layout;
    uint32_t current_group_idx = 0;
    // Only used for masked layout
    uint32_t current_m_cumsum = 0;
    // Only used for k-grouped layout
    uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0;
    uint32_t next_group_idx, next_shape_k;

    // Only used for k-grouped gemm
    __device__ __forceinline__ void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const {
        for (; group_idx < kNumGroups; ++ group_idx) {
            shape_k = __ldg(grouped_layout + group_idx);
            if (shape_k > 0)
                break;
        }
    }

    // ReSharper disable once CppPossiblyUninitializedMember
    __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, const uint32_t& shape_k,
                                                  int* grouped_layout = nullptr) {
        num_m_blocks = ceil_div(shape_m, BLOCK_M);
        num_n_blocks = ceil_div(shape_n, BLOCK_N);
        current_shape_k = shape_k;
        if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
            num_blocks = num_m_blocks * num_n_blocks;
        } else if (kGemmType == GemmType::MGroupedContiguous) {
            num_blocks = num_m_blocks * num_n_blocks;
            this->grouped_layout = grouped_layout;
        } else if (kGemmType == GemmType::MGroupedMasked) {
            this->grouped_layout = grouped_layout;
        } else if (kGemmType == GemmType::KGroupedContiguous) {
            this->grouped_layout = grouped_layout;
            get_next_k_group(current_group_idx, current_shape_k);
            next_group_idx = current_group_idx + 1;
            get_next_k_group(next_group_idx, next_shape_k);
        }
    }

    __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
        DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size");

        // Swizzle for better L2 usages
        const auto& primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks;
        const auto& secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks;
        const auto& num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
        const auto& group_idx = block_idx / num_blocks_per_group;
        auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
        auto in_group_idx = block_idx % num_blocks_per_group;
        num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);

        // Fix unaligned TMA multicast
        // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast
        // while SM100 uses 2-CTA, which can not be dynamically disabled
#if __CUDA_ARCH__ < 1000
        if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) {
            if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
                num_blocks_in_group = num_blocks_in_group ^ 1;
            } else {
                in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
                first_block_idx += num_blocks_in_group ^ 1;
                num_blocks_in_group = 1;
            }
        }
#endif

        // Convert to final M/N block indices
        // `kIsMulticastOnA == true` leads to groups on N
        if constexpr (kIsMulticastOnA) {
            m_block_idx = in_group_idx / num_blocks_in_group;
            n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
        } else {
            m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
            n_block_idx = in_group_idx / num_blocks_in_group;
        }
    }

    template <bool kWithGroupOffset, IndexType kIndexType = IndexType::MN>
    __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
                                                       const uint32_t& block_idx, const uint32_t& m_block_idx = 0) {
        if constexpr (kGemmType == GemmType::Normal) {
            return block_idx * block_size;
        } else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
            const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0;
            return offset * shape_dim + block_idx * block_size;
        } else if constexpr (kGemmType == GemmType::MGroupedMasked) {
            const auto offset = kWithGroupOffset ? current_group_idx : 0;
            return offset * shape_dim + block_idx * block_size;
        } else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
            auto offset = 0;
            if constexpr (kWithGroupOffset) {
                if constexpr (kIndexType == IndexType::MN)
                    offset = current_group_idx * shape_dim;
                else if constexpr (kIndexType == IndexType::K)
                    offset = current_k_cumsum;
                else if constexpr (kIndexType == IndexType::SF_K)
                    offset = current_sf_k_cumsum;
            }
            return offset + block_idx * block_size;
        } else if constexpr (kGemmType == GemmType::Batched) {
            // Ignore kWithGroupOffset, and apply offset for IndexType::SF_K
            const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0;
            return offset * shape_dim + block_idx * block_size;
        }
    }

    __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
        const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x;

        if constexpr (kGemmType == GemmType::MGroupedMasked) {
            while (true) {
                // End of the task
                if (current_group_idx == kNumGroups)
                    return false;

                // Within current group
                num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + current_group_idx)), BLOCK_M);
                const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks;
                if (next_block_idx < current_m_block_cumsum * num_n_blocks)
                    break;

                // Move to check the next group
                current_group_idx ++, current_m_cumsum = current_m_block_cumsum;
            }

            get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx);
        } else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
            while (true) {
                // End of the task
                if (current_group_idx == kNumGroups)
                    return false;

                // Within current group
                if (next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks)
                    break;

                // Move to check the next group
                current_k_cumsum += current_shape_k;
                current_sf_k_cumsum += ceil_div(current_shape_k, SF_K_ALIGNMENT);
                current_num_valid_groups ++;

                current_group_idx = next_group_idx ++;
                current_shape_k = next_shape_k;
                get_next_k_group(next_group_idx, next_shape_k);
            }

            get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx);
        } else if constexpr (kGemmType == GemmType::Batched) {
            if (next_block_idx >= num_blocks * kNumGroups)
                return false;

            current_group_idx = next_block_idx / num_blocks;
            const auto& block_idx = next_block_idx - current_group_idx * num_blocks;
            if constexpr (kIsMulticastOnA) {
                m_block_idx = block_idx / num_n_blocks;
                n_block_idx = block_idx % num_n_blocks;
            } else {
                m_block_idx = block_idx % num_m_blocks;
                n_block_idx = block_idx / num_m_blocks;
            }
        } else {
            if (next_block_idx >= num_blocks)
                return false;

            // For SM90 only
            // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
            is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or  // Always aligned on N (constant bypass)
                                num_m_blocks % kNumMulticast == 0 or  // Always aligned on M (constant bypass)
                                (next_block_idx ^ 1) < num_blocks;    // Peer CTA in bound
            get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx);
        }
        return true;
    }

    // For SM90 only
    __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
        if (num_blocks_in_group == 1)
            return false;
        if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or
                      kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched) {
            return true;
        } else {
            DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type");
            if constexpr (kIsMulticastOnA) {
                return true;
            } else {
                const auto& group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
                const auto& peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
                return group_idx == peer_group_idx;
            }
        }
    }

    // For SM90 only
    // ReSharper disable once CppNotAllPathsReturnValue
    __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
        if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
            return true;
        } else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
            return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
        } else if constexpr (kGemmType == GemmType::MGroupedMasked) {
            return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx);
        }
    }
};

#pragma clang diagnostic pop

} // namespace deep_gemm
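
Finally, a sketch of the persistent-kernel loop this scheduler is designed to drive (illustrative only; the template arguments and the compute_tile placeholder are not part of the commit): each CTA repeatedly asks the scheduler for its next (m, n) tile, handed out in the L2-friendly swizzled order, until get_next_block() returns false.

    // Assumes deep_gemm::GemmType and the Scheduler above are visible to this translation unit.
    template <uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumSMs>
    __global__ void persistent_loop_sketch(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k) {
        using SchedulerT = deep_gemm::Scheduler<deep_gemm::GemmType::Normal,
                                                BLOCK_M, BLOCK_N,
                                                /*kNumGroups=*/1u,
                                                /*kNumMulticast=*/1u, /*kIsMulticastOnA=*/false,
                                                kNumSMs>;
        auto scheduler = SchedulerT(shape_m, shape_n, shape_k);
        uint32_t m_block_idx, n_block_idx;
        while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
            // compute_tile(m_block_idx * BLOCK_M, n_block_idx * BLOCK_N, shape_k);  // placeholder for the real kernel body
        }
    }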
