ggml : add ggml_top_k #17365
Conversation
Force-pushed 56ab2ca to 4d75c05
Force-pushed 4d75c05 to 5d8ce1c
Does this operator expect the top-K elements to be sorted?
I feel it should not be sorted, as algorithmically we are performing a selection, and depending on the algorithm the outcome of this selection is unordered: https://leimao.github.io/blog/CPU-TopK-Algorithm/ Should one wish to sort, one could easily do so afterwards.
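For illustration only (not code from this PR), here is a minimal C++ sketch of such an unordered selection, assuming a plain vector of scores: `std::nth_element` partitions the k largest values to the front without ordering them, and an explicit sort can be applied afterwards if desired.

```cpp
#include <algorithm>
#include <functional>
#include <vector>

// Unordered top-k selection: after nth_element the first k elements are the
// k largest values, but in no particular order.
std::vector<float> top_k_unordered(std::vector<float> v, size_t k) {
    k = std::min(k, v.size());
    std::nth_element(v.begin(), v.begin() + k, v.end(), std::greater<float>());
    v.resize(k);
    // optional: std::sort(v.begin(), v.end(), std::greater<float>()); // only if order matters
    return v;
}
```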
In principle it does not have to expect the elements to be sorted. However, the current implementation sorts them in descending order in order to be able to verify correctness with test-backend-ops.
By treating them as sets rather than lists? We could use std::unordered_set for this.
Currently, test-backend-ops relies on NMSE of outputs rather than cardinality checks, but I guess that can be changed.
It would be ok to add an overridable error function to test-backend-ops.
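For illustration, a hypothetical C++ sketch (not the actual test-backend-ops API; the struct and method names here are invented) of how an overridable error function could let a top-k test compare outputs as a set instead of via NMSE:

```cpp
#include <unordered_set>
#include <vector>

// Hypothetical base class: by default, compare outputs element-wise (NMSE-style).
struct test_case_sketch {
    virtual ~test_case_sketch() = default;
    virtual double error(const std::vector<float> & out, const std::vector<float> & ref) const {
        double err = 0.0, norm = 0.0;
        for (size_t i = 0; i < out.size(); ++i) {
            err  += (out[i] - ref[i])*(out[i] - ref[i]);
            norm += ref[i]*ref[i];
        }
        return norm > 0.0 ? err/norm : err;
    }
};

// Hypothetical top-k test: the result is a selection, so compare it as a
// multiset of values and ignore the order in which they are returned.
struct test_top_k_sketch : test_case_sketch {
    double error(const std::vector<float> & out, const std::vector<float> & ref) const override {
        std::unordered_multiset<float> a(out.begin(), out.end());
        std::unordered_multiset<float> b(ref.begin(), ref.end());
        return a == b ? 0.0 : 1.0;
    }
};
```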
What are common tensor shapes and values of k? Does this operation support non-contiguous rows?
@jeffbolznv This will be used in #17004 to do top-k sampling efficiently on the GPU. The typical shapes are:
Support for non-contiguous rows is not necessary for now - will add asserts for that.
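As a usage sketch under assumptions (the helper name and the [n_vocab, n_tokens] shape are illustrative, not quoted from this thread), top-k sampling support might build the op like this:

```cpp
#include "ggml.h"

// logits is assumed to be [n_vocab, n_tokens] with contiguous rows;
// non-contiguous rows are not supported for now.
static struct ggml_tensor * build_top_k(struct ggml_context * ctx, struct ggml_tensor * logits, int k) {
    // one row of k selected entries per input row; unlike the old argsort-based
    // implementation, the dedicated op does not guarantee any ordering
    return ggml_top_k(ctx, logits, k);
}
```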
OK, understood. When you get a chance, please rebase; I'll implement something based on #17313.
Force-pushed 5d8ce1c to 4dea5dd
  ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
- ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
+ ggml_tensor * group_scores = ggml_argsort_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
I guess these are temporary until support in all backends is in place? Add a TODO?
Not 100% sure yet - keeping the expert order deterministic might be necessary. And using ggml_top_k here would likely not make a big difference performance-wise since the arrays are very small.
Completely unnecessary for the expert group selection at least.
Please don't change it anywhere else, as it will also break fusion in all backends for topk-moe.
The values are also shuffled: llama.cpp/tests/test-backend-ops.cpp, lines 5034 to 5042 in c63ecde.
So the top-1 could be any number.
Vulkan support is ready in #17418.
Force-pushed db4570a to 1e3d461
Added the overridable error function in 961dd4f. I think this is OK to merge. Will do so later today if there are no additional concerns.
- int64_t ne00;
- int64_t ne01;
- int64_t ne02;
- int64_t ne03;
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
I don't know my way around the Metal backend, but this could lead to overflow. Do we need to protect against this?
The convention is to use 32-bit ints for the number of elements:
llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h, lines 87 to 94 in b1846f1:
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
//   however, be careful from int overflows when using those in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t
Overflows are handled by explicitly casting to 64-bit when we multiply 32-bit ints:
llama.cpp/ggml/src/ggml-metal/ggml-metal.metal, lines 3200 to 3202 in 1e3d461:
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
It's possible that some casts are missing here and there, but I usually update these when I spot them.
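To make the convention concrete, here is a minimal sketch in plain C++ (not Metal; the function and parameter names just mirror the quoted line) of widening before the multiplication:

```cpp
#include <cstdint>

// With 32-bit element counters, cast one operand to 64 bits *before* multiplying;
// (uint64_t)(im * ne0 * ne1) would already overflow in 32-bit arithmetic.
uint64_t row_offset(int32_t im, int32_t r1, int32_t ne0, int32_t ne1) {
    return (uint64_t) im * ne0 * ne1 + (uint64_t) r1 * ne0;
}
```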
Add a dedicated top-k op so that it can be more efficiently optimized by backend implementations. The old implementation is renamed to ggml_argsort_top_k.

TODO:
- op_params

Next PRs:
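To illustrate the distinction described above (a hedged sketch; the function and tensor names are just for the example), the renamed op keeps the old sorted behavior while the new dedicated op performs an unordered selection:

```cpp
#include "ggml.h"

void top_k_variants(struct ggml_context * ctx, struct ggml_tensor * logits, int k) {
    // old implementation, renamed: results sorted in descending order
    struct ggml_tensor * sorted   = ggml_argsort_top_k(ctx, logits, k);
    // new dedicated op: same selection, but the order is left to the backend
    struct ggml_tensor * unsorted = ggml_top_k(ctx, logits, k);
    (void) sorted;
    (void) unsorted;
}
```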