Changes from 20 commits
26 commits
2f7cfcf
mmf for rdna4
Nov 7, 2025
d564a35
align the padding for rdna4
Nov 7, 2025
0ec241d
Merge branch 'ggml-org:master' into mmf_wmma_rdna4
zhang-hui-yulo Nov 9, 2025
bbee5fe
forbid mul_mat_f for rdna4
Nov 9, 2025
6b8ceeb
Merge branch 'ggml-org:master' into mmf_wmma_rdna4
zhang-hui-yulo Nov 11, 2025
fd18344
fix as comment
Nov 11, 2025
7a09e22
remove device kernels
Nov 11, 2025
c65dd59
add constexpr for early return
Nov 11, 2025
48a53b5
Merge branch 'ggml-org:master' into mmf_wmma_rdna4
zhang-hui-yulo Nov 13, 2025
b7c13ee
update based on review comment
Nov 13, 2025
a0aa491
change based on the review comment
Nov 13, 2025
8c2f9a3
Merge branch 'ggml-org:master' into mmf_wmma_rdna4
zhang-hui-yulo Nov 14, 2025
7a88d7c
pass compile error
Nov 14, 2025
cfc149a
Merge branch 'ggml-org:master' into mmf_wmma_rdna4
zhang-hui-yulo Nov 14, 2025
59a012f
keep code consistency
Nov 17, 2025
6802fbf
Merge branch 'ggml-org:master' into mmf_wmma_rdna4
zhang-hui-yulo Nov 19, 2025
facded5
Merge branch 'ggml-org:master' into mmf_wmma_rdna4
zhang-hui-yulo Nov 21, 2025
28b5e45
Attempt to port RDNA4 WMMA optimizations to RDNA3
Nov 21, 2025
ba25661
Merge branch 'master' into mmf_wmma_rdna3
Nov 25, 2025
edb86ef
WMMA RDNA3 fixes
Nov 25, 2025
5a80fb4
fix RDNA3 not using the fast DP4A-based MMQ path. RDNA4 should still …
Nov 25, 2025
8aed111
more fixes for RDNA3 crashes with quantized models
Nov 25, 2025
9191856
Fix RDNA3 WMMA int8 codepath
Nov 27, 2025
c692629
fix endif comments
Nov 27, 2025
9c9b0ea
More fixes on get_i/get_j, add support for more tile sizes. get_j res…
Nov 29, 2025
816cb1d
Disable WMMA for RDNA3, enable only for RDNA4
Nov 29, 2025
6 changes: 3 additions & 3 deletions ggml/src/ggml-cuda/common.cuh
@@ -224,9 +224,9 @@ static const char * cu_get_error_str(CUresult err) {
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)

#if defined(GGML_USE_HIP) && defined(RDNA4)
#if defined(GGML_USE_HIP) && (defined(RDNA3) || defined(RDNA4))
#define AMD_WMMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(RDNA4)
#endif // defined(GGML_USE_HIP) && (defined(RDNA3) || defined(RDNA4))

// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
@@ -288,7 +288,7 @@ static bool amd_mfma_available(const int cc) {
}

static bool amd_wmma_available(const int cc) {
return GGML_CUDA_CC_IS_RDNA4(cc);
return GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
}

static bool volta_mma_available(const int cc) {
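The two hunks above are the compile-time and run-time halves of the same gate: AMD_WMMA_AVAILABLE selects WMMA device code per offload target, while amd_wmma_available(cc) makes the matching host-side choice for the active device. A minimal sketch of how the pair is meant to be used (the kernel and helper names below are hypothetical; only the macro, amd_wmma_available and NO_DEVICE_CODE come from the ggml-cuda headers):

__global__ void mul_mat_wmma_sketch(/* ... */) {
#if defined(AMD_WMMA_AVAILABLE)
    // Compiled only for RDNA3/RDNA4 offload targets after this change.
#else
    NO_DEVICE_CODE; // other targets must never launch this kernel
#endif // AMD_WMMA_AVAILABLE
}

static bool pick_wmma_path(const int cc) {
    // Host side: cc is the compute capability of the active device. This has to
    // agree with the preprocessor condition above, otherwise a launch would land
    // on NO_DEVICE_CODE at run time.
    return amd_wmma_available(cc);
}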
64 changes: 51 additions & 13 deletions ggml/src/ggml-cuda/mma.cuh
@@ -150,13 +150,20 @@ namespace ggml_cuda_mma {
}
}
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
T x[ne] = {0};

static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
return false;
// Integer WMMA is only supported on RDNA4
if constexpr (std::is_same_v<T, int>) {
#if defined(RDNA4)
if (I == 16 && J == 16) return true;
#endif
return false;
} else {
if (I == 16 && J == 16) return true;
return false;
}
}

static __device__ __forceinline__ int get_i(const int l) {
@@ -176,7 +183,6 @@ namespace ggml_cuda_mma {
return -1;
}
}
#endif
#else
static constexpr int ne = I * J / 32;
T x[ne] = {0};
@@ -223,7 +229,7 @@ namespace ggml_cuda_mma {
return -1;
}
}
#endif // defined(GGML_USE_HIP)
#endif
};

template <int I_, int J_>
@@ -265,7 +271,11 @@ namespace ggml_cuda_mma {
}
}
#elif defined(AMD_WMMA_AVAILABLE)
static constexpr int ne = I * J / 32;
#if defined(RDNA4)
static constexpr int ne = I * J / 32; // 4 half2 = 8 FP16 for RDNA4
#else
static constexpr int ne = I * J / 16; // 8 half2 = 16 FP16 for RDNA3 (duplicate layout)
#endif
half2 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
@@ -341,7 +351,11 @@ namespace ggml_cuda_mma {
static constexpr int J = J_;

#if defined(AMD_WMMA_AVAILABLE)
static constexpr int ne = I * J / 32;
#if defined(RDNA4)
static constexpr int ne = I * J / 32; // 4 bfloat162 = 8 BF16 for RDNA4
#else
static constexpr int ne = I * J / 16; // 8 bfloat162 = 16 BF16 for RDNA3 (duplicate layout)
#endif
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
@@ -441,13 +455,22 @@ namespace ggml_cuda_mma {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
#if !defined(RDNA4)
// RDNA3 has double the tile size, load one more int64_t (two more ints)
xi[1] = xs[1];
#endif
}else if constexpr (I == 16 && J == 8) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
xi[0] = xs[0];

const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
xi[1] = xs1[0];
#if !defined(RDNA4)
// RDNA3 has double the tile size, load 2 more int64_t
xi[2] = xs[1];
xi[3] = xs1[1];
#endif
}else{
NO_DEVICE_CODE;
}
@@ -738,12 +761,21 @@ namespace ggml_cuda_mma {
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
#else // RDNA3
using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
#endif
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
@@ -753,12 +785,21 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
#else // RDNA3
using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);
const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);
#endif
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
@@ -786,16 +827,14 @@ namespace ggml_cuda_mma {
0, 0, 0);
#endif // defined(CDNA3)

#elif defined(AMD_WMMA_AVAILABLE)
#elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;

using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;

#if defined(RDNA4)

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[0],
@@ -812,8 +851,7 @@ namespace ggml_cuda_mma {
b_vec[1],
acc[0],
true
);
#endif // defined(RDNA4)
);

#else
GGML_UNUSED_VARS(D, A, B);
@@ -889,7 +927,7 @@ namespace ggml_cuda_mma {

static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
#if defined(AMD_WMMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;
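Most of the RDNA3/RDNA4 branching above reduces to per-lane fragment sizes: both targets run these kernels in wave32 mode, but the RDNA3 WMMA instructions expect their A/B operands duplicated across the two halves of the wave, so each lane carries twice as many elements and the builtins take 16-wide instead of 8-wide vectors. A worked check of the ne arithmetic (illustrative only, for a 16x8 half2 tile as in the diff):

constexpr int I = 16, J = 8;          // tile dimensions, counted in half2 elements
constexpr int ne_rdna4 = I * J / 32;  // 4 half2 = 8 FP16 per lane, unique data per lane
constexpr int ne_rdna3 = I * J / 16;  // 8 half2 = 16 FP16 per lane, operand duplicated
                                      // across the wave halves (the "duplicate layout")
static_assert(ne_rdna4 == 4, "matches the halfx8_t operands of the gfx12 builtin");
static_assert(ne_rdna3 == 8, "matches the halfx16_t operands of the RDNA3 builtin");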
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/mmf.cu
@@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
return false;
}
} else {
if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
if (src1_ncols > 16 || amd_wmma_available(cc)) {
return false;
}
}
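The guard is the only change here: the RDNA4-specific check becomes a capability check, so RDNA3 now falls out of this mul_mat_f branch for the same reason RDNA4 already did. A behavioural sketch of the new condition (the standalone helper is hypothetical; the predicate itself is taken from the hunk):

// Wide batches and WMMA-capable AMD GPUs (now RDNA3 as well as RDNA4) reject
// this mul_mat_f branch in ggml_cuda_should_use_mmf.
static bool mmf_branch_rejected(const int cc, const int64_t src1_ncols) {
    return src1_ncols > 16 || amd_wmma_available(cc);
}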
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/mmq.cu
@@ -310,6 +310,10 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return true;
}
// RDNA3 doesn't support integer WMMA operations required for MMQ
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
return false;
}
}

return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
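Together with the unchanged return at the bottom, the added lines give the AMD targets a clean split inside this branch: RDNA4 takes the int8 WMMA MMQ path, RDNA3 is kept off it (its WMMA has no usable int8 variant here), and everything else still falls through to the DP4A batch-size heuristic. A hedged summary in code form (hypothetical helper; the condition of the enclosing if is not visible in this hunk):

static bool mmq_decision_sketch(const int cc, const int64_t ne11) {
    if (GGML_CUDA_CC_IS_RDNA4(cc)) {
        return true;   // int8 WMMA available: use MMQ
    }
    if (GGML_CUDA_CC_IS_RDNA3(cc)) {
        return false;  // no int8 WMMA: do not take the MMQ path from this branch
    }
    // Remaining archs keep the existing DP4A-based heuristic.
    return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}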