Commit 525040c

ggml : add ggml_top_k

1 parent b61de2b

5 files changed: +85, -10 lines

ggml/include/ggml.h
Lines changed: 1 addition & 0 deletions

@@ -530,6 +530,7 @@ extern "C" {
         GGML_OP_ARANGE,
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
+        GGML_OP_TOP_K,
         GGML_OP_LEAKY_RELU,
         GGML_OP_TRI,
         GGML_OP_FILL,

ggml/src/ggml-cpu/ggml-cpu.c
Lines changed: 9 additions & 0 deletions

@@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
            {
                ggml_compute_forward_argsort(params, tensor);
            } break;
+        case GGML_OP_TOP_K:
+            {
+                ggml_compute_forward_top_k(params, tensor);
+            } break;
        case GGML_OP_LEAKY_RELU:
            {
                ggml_compute_forward_leaky_relu(params, tensor);

@@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_ARANGE:
        case GGML_OP_TIMESTEP_EMBEDDING:
        case GGML_OP_ARGSORT:
+        case GGML_OP_TOP_K:
        case GGML_OP_FLASH_ATTN_EXT:
        case GGML_OP_FLASH_ATTN_BACK:
        case GGML_OP_SSM_CONV:

@@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan(
                    cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
                    cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
                } break;
+            case GGML_OP_TOP_K:
+                {
+                    cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks;
+                } break;
            case GGML_OP_FLASH_ATTN_EXT:
                {
                    const int64_t ne10 = node->src[1]->ne[0]; // DK
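A brief reading of the new work-buffer term (my summary, not text from the commit): ggml_graph_plan now reserves sizeof(int32_t)*ne[0] bytes per task for GGML_OP_TOP_K, i.e. one int32 index per element of the row being sorted for each worker thread; this is the scratch array that the CPU kernel in ops.cpp fills with 0..ne00-1 and partially sorts for every row it processes.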

ggml/src/ggml-cpu/ops.cpp
Lines changed: 64 additions & 3 deletions

@@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding(
// ggml_compute_forward_argsort

template<enum ggml_sort_order order>
-struct argsort_cmp {
+struct cmp_argsort {
    const float * data;
    bool operator()(int32_t a, int32_t b) const {
        if constexpr (order == GGML_SORT_ORDER_ASC) {

@@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32(

        switch (order) {
            case GGML_SORT_ORDER_ASC:
-                std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
+                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
                break;

            case GGML_SORT_ORDER_DESC:
-                std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
+                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
                break;

            default:

@@ -7864,6 +7864,67 @@ void ggml_compute_forward_argsort(
    }
}

+// ggml_compute_forward_top_k
+
+struct cmp_top_k {
+    const float * data;
+    bool operator()(int32_t a, int32_t b) const {
+        return data[a] > data[b];
+    }
+};
+
+static void ggml_compute_forward_top_k_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int k = ggml_get_op_params_i32(dst, 0);
+
+    int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int64_t i = ith; i < nr; i += nth) {
+        const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+        for (int64_t j = 0; j < ne00; j++) {
+            tmp[j] = j;
+        }
+
+        std::partial_sort(tmp, tmp + k, tmp + ne00, cmp_top_k{src_data});
+
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
+        std::copy(tmp, tmp + k, dst_data);
+    }
+}
+
+void ggml_compute_forward_top_k(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_top_k_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
// ggml_compute_forward_flash_attn_ext

static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
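To make the selection logic above easier to follow in isolation, here is a minimal standalone C++ sketch of the same idea: build an index array, partially sort only its first k positions by value, and keep those k indices. The helper name top_k_indices and the sample data are illustrative only, not part of ggml.

// top_k_sketch.cpp - standalone illustration of index-based top-k selection.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

// Hypothetical helper (not a ggml API): indices of the k largest elements,
// largest first, mirroring what ggml_compute_forward_top_k_f32 does per row.
static std::vector<int32_t> top_k_indices(const std::vector<float> & data, int k) {
    std::vector<int32_t> idx(data.size());
    std::iota(idx.begin(), idx.end(), 0); // idx = 0, 1, 2, ...
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
        [&data](int32_t a, int32_t b) { return data[a] > data[b]; });
    idx.resize(k); // keep only the top-k indices
    return idx;
}

int main() {
    const std::vector<float> row = {0.1f, 2.5f, -1.0f, 3.0f, 0.7f};
    for (int32_t i : top_k_indices(row, 2)) {
        printf("%d ", i); // prints "3 1": indices of the two largest values
    }
    printf("\n");
    return 0;
}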

ggml/src/ggml-cpu/ops.h
Lines changed: 1 addition & 0 deletions

@@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml.c
Lines changed: 10 additions & 7 deletions

@@ -990,6 +990,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "ARANGE",
    "TIMESTEP_EMBEDDING",
    "ARGSORT",
+    "TOP_K",
    "LEAKY_RELU",
    "TRI",
    "FILL",

@@ -1023,7 +1024,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "GLU",
};

-static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
+static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "none",

@@ -1098,6 +1099,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "arange(start, stop, step)",
    "timestep_embedding(timesteps, dim, max_period)",
    "argsort(x)",
+    "top_k(x)",
    "leaky_relu(x)",
    "tri(x)",
    "fill(x, c)",

@@ -1131,7 +1133,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "glu(x)",
};

-static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
+static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -5139,6 +5141,7 @@ struct ggml_tensor * ggml_argsort(
        struct ggml_tensor * a,
        enum ggml_sort_order order) {
    GGML_ASSERT(a->ne[0] <= INT32_MAX);
+
    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);

    ggml_set_op_params_i32(result, 0, (int32_t) order);

@@ -5157,12 +5160,12 @@ struct ggml_tensor * ggml_top_k(
        int k) {
    GGML_ASSERT(a->ne[0] >= k);

-    struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);
+
+    ggml_set_op_params_i32(result, 0, (int32_t) k);

-    result = ggml_view_4d(ctx, result,
-            k, result->ne[1], result->ne[2], result->ne[3],
-            result->nb[1], result->nb[2], result->nb[3],
-            0);
+    result->op = GGML_OP_TOP_K;
+    result->src[0] = a;

    return result;
}
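With this change ggml_top_k is a first-class op rather than a view over a full argsort, but the call site stays the same. A minimal usage sketch follows (my own example, not from the commit; the memory size and the ggml_graph_compute_with_ctx call from ggml-cpu.h are assumptions about the usual CPU-backend workflow):

// top_k_usage.c - sketch of building and running a graph with ggml_top_k.
#include "ggml.h"
#include "ggml-cpu.h" // assumed home of ggml_graph_compute_with_ctx

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024, // illustrative size, plenty for this toy graph
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * x   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
    struct ggml_tensor * top = ggml_top_k(ctx, x, 4); // I32 result of shape [4, 1, 1, 1]

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, top);

    // ... fill x->data with 16 floats, then run the graph on the CPU backend:
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    // top->data now holds the indices of the 4 largest values of x,
    // largest first (the op returns indices, not the values themselves).

    ggml_free(ctx);
    return 0;
}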
