Skip to content

Commit 8dda569

Browse files
Merge pull request #209 from menloresearch/update-dev-from-master-2025-08-20-00-11
Sync master with upstream release b6209
2 parents 576f4cd + fb22dd0 commit 8dda569

28 files changed

+258
-104
lines changed

common/arg.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1530,6 +1530,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
15301530
params.ctx_shift = false;
15311531
}
15321532
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
1533+
add_opt(common_arg(
1534+
{"--context-shift"},
1535+
string_format("enables context shift on infinite text generation (default: %s)", params.ctx_shift ? "enabled" : "disabled"),
1536+
[](common_params & params) {
1537+
params.ctx_shift = true;
1538+
}
1539+
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_CONTEXT_SHIFT"));
15331540
add_opt(common_arg(
15341541
{"--chunks"}, "N",
15351542
string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
@@ -1823,7 +1830,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
18231830
[](common_params & params, const std::string & value) {
18241831
params.sampling.top_n_sigma = std::stof(value);
18251832
}
1826-
).set_examples({LLAMA_EXAMPLE_MAIN}).set_sparam());
1833+
).set_sparam());
18271834
add_opt(common_arg(
18281835
{"--xtc-probability"}, "N",
18291836
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),

common/chat.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,6 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
632632
case COMMON_REASONING_FORMAT_AUTO: return "auto";
633633
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
634634
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
635-
case COMMON_REASONING_FORMAT_GRANITE: return "granite";
636635
default:
637636
throw std::runtime_error("Unknown reasoning format");
638637
}

common/common.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,15 @@ struct common_params_diffusion {
239239
bool add_gumbel_noise = false; // add gumbel noise to the logits if temp > 0.0
240240
};
241241

242+
// reasoning API response format (not to be confused as chat template's reasoning format)
enum common_reasoning_format {
    COMMON_REASONING_FORMAT_NONE,
    COMMON_REASONING_FORMAT_AUTO,            // Same as deepseek, using `message.reasoning_content`
    COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
    COMMON_REASONING_FORMAT_DEEPSEEK,        // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
    // do not extend this enum unless you absolutely have to
    // in most cases, use COMMON_REASONING_FORMAT_AUTO
    // see: https://github.com/ggml-org/llama.cpp/pull/15408
};
249252

250253

@@ -372,7 +375,7 @@ struct common_params {
372375
bool cont_batching = true; // insert new sequences for decoding on-the-fly
373376
bool flash_attn = false; // flash attention
374377
bool no_perf = false; // disable performance metrics
375-
bool ctx_shift = true; // context shift on infinite text generation
378+
bool ctx_shift = false; // context shift on infinite text generation
376379
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
377380
bool kv_unified = false; // enable unified KV cache
378381

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 106 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2154,86 +2154,129 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
21542154

21552155
GGML_TENSOR_BINARY_OP_LOCALS
21562156

2157-
// theta_scale arange, [0,1,...,ne00/2 - 1]
21582157
int64_t theta_scale_length = ne00 / 2;
2159-
ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(),
2160-
theta_scale_length * sizeof(float_t));
2161-
void* theta_scale_buffer = theta_scale_allocator.get();
21622158
int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
21632159
size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
21642160
theta_scale_length * sizeof(float_t)};
21652161

2166-
aclTensor* acl_theta_scale_tensor =
2167-
ggml_cann_create_tensor(theta_scale_buffer, ACL_FLOAT, sizeof(float_t),
2168-
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2169-
float start = 0;
2170-
float step = 1;
2171-
float stop = ne00 / 2;
2172-
float n_elements = ne00 / 2;
2173-
aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
2174-
2175-
// power
2176-
aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
2177-
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
2178-
acl_theta_scale_tensor);
2179-
2180-
// freq_scale
2181-
if (freq_scale != 1) {
2182-
aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
2183-
}
2184-
2185-
// freq_factors
2186-
if (src2) {
2187-
aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
2188-
src2->data, ggml_cann_type_mapping(src2->type),
2189-
ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2190-
aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
2191-
ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
2192-
}
2193-
2194-
// position
21952162
GGML_ASSERT(src1->type == GGML_TYPE_I32);
21962163
int64_t position_length = src1->ne[0];
21972164
int64_t position_ne[] = {1, 1, position_length, 1};
21982165
size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t), sizeof(int32_t),
21992166
sizeof(int32_t) * position_length};
2200-
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
2201-
src1->data, ggml_cann_type_mapping(src1->type),
2202-
ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
2203-
2204-
// power * position
2205-
int64_t theta_length = theta_scale_length * position_length;
2206-
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
2207-
theta_length * sizeof(float_t));
2208-
void* theta_buffer = theta_allocator.get();
2167+
22092168
int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1};
22102169
size_t theta_nb[GGML_MAX_DIMS];
22112170
theta_nb[0] = sizeof(float_t);
22122171
for (int i = 1; i < GGML_MAX_DIMS; i++) {
22132172
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
22142173
}
2215-
aclTensor* acl_theta_tensor =
2216-
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
2217-
theta_ne, theta_nb, GGML_MAX_DIMS);
2218-
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
2219-
acl_theta_tensor);
2220-
2221-
// sin/cos
2222-
ggml_cann_pool_alloc sin_allocator(ctx.pool(),
2223-
theta_length * sizeof(float_t));
2224-
void* sin_buffer = sin_allocator.get();
2225-
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
2226-
sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
2227-
GGML_MAX_DIMS, ACL_FORMAT_ND);
2228-
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
22292174

2230-
ggml_cann_pool_alloc cos_allocator(ctx.pool(),
2231-
theta_length * sizeof(float_t));
2232-
void* cos_buffer = cos_allocator.get();
2175+
bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
2176+
bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
2177+
2178+
// used for accuracy testing
2179+
bool is_attention = is_q || is_k;
2180+
2181+
if(ctx.init_ptr == nullptr || !is_attention) {
2182+
// theta_scale arange, [0,1,...,ne00/2 - 1]
2183+
if(ctx.init_ptr != nullptr){
2184+
ACL_CHECK(aclrtFree(ctx.init_ptr));
2185+
}
2186+
ACL_CHECK(aclrtMalloc(&ctx.init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
2187+
2188+
aclTensor* acl_theta_scale_tensor =
2189+
ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
2190+
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2191+
float start = 0;
2192+
float step = 1;
2193+
float stop = ne00 / 2;
2194+
float n_elements = ne00 / 2;
2195+
aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
2196+
2197+
// power
2198+
aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
2199+
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
2200+
acl_theta_scale_tensor);
2201+
2202+
// freq_scale
2203+
if (freq_scale != 1) {
2204+
aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
2205+
}
2206+
2207+
// freq_factors
2208+
if (src2) {
2209+
aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
2210+
src2->data, ggml_cann_type_mapping(src2->type),
2211+
ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2212+
aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
2213+
ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
2214+
}
2215+
// release
2216+
ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
2217+
}
2218+
2219+
if(ctx.sin_ptr == nullptr) {
2220+
int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
2221+
ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
2222+
ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
2223+
}
2224+
if(position_length > ctx.max_prompt_length) {
2225+
ctx.max_prompt_length = position_length;
2226+
int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
2227+
ACL_CHECK(aclrtFree(ctx.sin_ptr));
2228+
ACL_CHECK(aclrtFree(ctx.cos_ptr));
2229+
ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
2230+
ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
2231+
}
2232+
2233+
bool is_first_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
2234+
2235+
if(is_first_layer || !is_attention) {
2236+
2237+
aclTensor* acl_theta_scale_tensor =
2238+
ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
2239+
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2240+
2241+
// position
2242+
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
2243+
src1->data, ggml_cann_type_mapping(src1->type),
2244+
ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
2245+
2246+
// power * position
2247+
int64_t theta_length = theta_scale_length * position_length;
2248+
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
2249+
theta_length * sizeof(float_t));
2250+
void* theta_buffer = theta_allocator.get();
2251+
2252+
aclTensor* acl_theta_tensor =
2253+
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
2254+
theta_ne, theta_nb, GGML_MAX_DIMS);
2255+
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
2256+
acl_theta_tensor);
2257+
2258+
// sin/cos
2259+
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
2260+
ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
2261+
GGML_MAX_DIMS, ACL_FORMAT_ND);
2262+
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
2263+
2264+
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
2265+
ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
2266+
GGML_MAX_DIMS, ACL_FORMAT_ND);
2267+
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
2268+
2269+
// release
2270+
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
2271+
acl_theta_tensor, acl_sin_tensor, acl_cos_tensor);
2272+
}
2273+
2274+
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
2275+
ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
2276+
GGML_MAX_DIMS, ACL_FORMAT_ND);
22332277
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
2234-
cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
2235-
GGML_MAX_DIMS, ACL_FORMAT_ND);
2236-
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
2278+
ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
2279+
GGML_MAX_DIMS, ACL_FORMAT_ND);
22372280

22382281
// attn_factor
22392282
if (attn_factor != 1) {
@@ -2257,8 +2300,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
22572300
}
22582301

22592302
// release
2260-
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
2261-
acl_theta_tensor, acl_sin_tensor, acl_cos_tensor, acl_theta_scale);
2303+
ggml_cann_release_resources(ctx, acl_sin_tensor, acl_cos_tensor);
22622304
}
22632305

22642306
#ifdef __cplusplus

ggml/src/ggml-cann/common.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,10 @@ struct ggml_backend_cann_context {
368368
std::string name; /**< Name of the device. */
369369
std::string description; /**< Description of the device. */
370370
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
371+
void* init_ptr = nullptr;
372+
void* sin_ptr = nullptr;
373+
void* cos_ptr = nullptr;
374+
int64_t max_prompt_length = 65536;
371375
#ifdef USE_ACL_GRAPH
372376
/// Cached CANN ACL graph used for executing the current ggml computation graph.
373377
std::unique_ptr<ggml_cann_graph> cann_graph;
@@ -414,6 +418,15 @@ struct ggml_backend_cann_context {
414418
ACL_CHECK(aclrtDestroyStream(streams[i]));
415419
}
416420
}
421+
if(init_ptr != nullptr) {
422+
ACL_CHECK(aclrtFree(init_ptr));
423+
}
424+
if(sin_ptr != nullptr) {
425+
ACL_CHECK(aclrtFree(sin_ptr));
426+
}
427+
if(cos_ptr != nullptr) {
428+
ACL_CHECK(aclrtFree(cos_ptr));
429+
}
417430
}
418431

419432
/**

ggml/src/ggml-cpu/arch-fallback.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@
7373
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
7474
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
7575
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
76-
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
7776
// repack.cpp
7877
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
7978
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8

ggml/src/ggml-cpu/arch/powerpc/quants.c

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,72 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
278278
#endif
279279
}
280280

281+
// Dot product of one MXFP4-quantized row (vx) with one Q8_0-quantized row (vy)
// over n elements; the scalar result is written to *s.
//
// Parameters:
//   n   - number of elements; must be a multiple of QK_MXFP4
//   s   - output scalar (single float)
//   bs  - output stride (unused here; nrc == 1 is asserted)
//   vx  - MXFP4 blocks (packed 4-bit indices + shared E8M0 exponent per block)
//   bx  - row stride of vx (unused)
//   vy  - Q8_0 blocks (int8 values + fp16 scale per block)
//   by  - row stride of vy (unused)
//   nrc - number of result columns; only 1 is supported
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_MXFP4 == 0);
    // both formats pack the same number of elements per block, so one loop
    // index can walk x and y in lockstep
    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");

    const block_mxfp4 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    const int nb = n / QK_MXFP4;

    int ib = 0;
    float sumf = 0;

#if defined(__POWER9_VECTOR__)
    // POWER9 VSX path: decode MXFP4 nibbles via a table permute, multiply
    // against the int8 values, and accumulate in a single float vector.
    const vector signed char lowMask = vec_splats((signed char)0xF);
    const vector unsigned char vshift4 = vec_splats((unsigned char)4);
    vector float vsumf0 = vec_splats(0.0f);

    // 16-entry MXFP4 dequantization lookup table, used as the permute source
    vector signed char kv = vec_xl(0, (const signed char *)kvalues_mxfp4);

#pragma GCC unroll 8
    for (; ib < nb; ++ib) {
        __builtin_prefetch(x[ib].qs, 0, 1);
        __builtin_prefetch(y[ib].qs, 0, 1);

        // combined per-block scale: q8 fp16 scale times mxfp4 E8M0 exponent
        // (halved form, matching the kvalues_mxfp4 table)
        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d) *
                                      GGML_E8M0_TO_FP32_HALF(x[ib].e));

        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
        vector signed char q8y1 = vec_xl(16, y[ib].qs);

        vector signed char qxs = (vector signed char)vec_xl(0, x[ib].qs);

        // split each packed byte into the two 4-bit quant indices
        vector unsigned char lo_nibbles = (vector unsigned char)vec_and(qxs, lowMask);
        vector unsigned char hi_nibbles = (vector unsigned char)vec_sr(qxs, vshift4);

        // table lookup: nibble index -> signed MXFP4 value
        vector signed char q4x0 = vec_perm(kv, kv, lo_nibbles);
        vector signed char q4x1 = vec_perm(kv, kv, hi_nibbles);

        // widening multiplies of even/odd lanes, summed to int16 products
        vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
        vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));

        vector signed int vsumi0 = vec_splats((int32_t)0);
        vsumi0 = vec_sum4s(qv0, vsumi0);
        vsumi0 = vec_sum4s(qv1, vsumi0);

        // scale the integer block sum and accumulate into the running total
        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vyd, vsumf0);
    }

    // horizontal reduction of the 4 float lanes
    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
    sumf = vec_extract(vsumf0, 0);
    *s = sumf;
#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(ib);
    UNUSED(sumf);
    // no vector unit available: defer to the scalar reference implementation
    ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
346+
281347
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
282348
const int qk = QK8_0;
283349
const int nb = n / qk;

ggml/src/ggml-cuda/common.cuh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@
7878
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
7979

8080
// Moore Threads
81+
#define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons
82+
8183
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
8284
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
8385
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
@@ -490,13 +492,14 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
490492
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
491493
}
492494

493-
#if CUDART_VERSION < CUDART_HMASK
495+
#if (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || \
496+
(defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK)
494497
static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
495498
const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
496499
const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
497500
return mask_low | mask_high;
498501
}
499-
#endif // CUDART_VERSION < CUDART_HMASK
502+
#endif // (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK)
500503

501504
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
502505
#if defined(GGML_USE_HIP)

0 commit comments

Comments
 (0)