
Commit c195af2

Merge remote-tracking branch 'origin/master' into ggml-cumsum-tri
* origin/master:
  CUDA: generalized (mma) FA, add Volta support (#17505)
  chat : reserve memory in compute_diffs and improve naming (#17729)

2 parents 60fe39b + 2e1c9cd

File tree: 12 files changed, +975 -760 lines

common/chat.cpp

Lines changed: 21 additions & 13 deletions

@@ -85,29 +85,36 @@ json common_chat_msg::to_json_oaicompat() const
     return message;
 }
 
-std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) {
+std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) {
     std::vector<common_chat_msg_diff> diffs;
-    if (previous_msg.reasoning_content != new_msg.reasoning_content) {
+    if (msg_new.tool_calls.size() > msg_prv.tool_calls.size()) {
+        diffs.reserve(msg_new.tool_calls.size() - msg_prv.tool_calls.size() + 3);
+    } else {
+        diffs.reserve(3);
+    }
+
+    // TODO: these can become expensive for long messages - how to optimize?
+    if (msg_prv.reasoning_content != msg_new.reasoning_content) {
         auto & diff = diffs.emplace_back();
-        diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content);
+        diff.reasoning_content_delta = string_diff(msg_prv.reasoning_content, msg_new.reasoning_content);
     }
-    if (previous_msg.content != new_msg.content) {
+    if (msg_prv.content != msg_new.content) {
         auto & diff = diffs.emplace_back();
-        diff.content_delta = string_diff(previous_msg.content, new_msg.content);
+        diff.content_delta = string_diff(msg_prv.content, msg_new.content);
     }
 
-    if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) {
+    if (msg_new.tool_calls.size() < msg_prv.tool_calls.size()) {
         throw std::runtime_error("Invalid diff: now finding less tool calls!");
     }
 
-    if (!previous_msg.tool_calls.empty()) {
-        auto idx = previous_msg.tool_calls.size() - 1;
-        const auto & pref = previous_msg.tool_calls[idx];
-        const auto & newf = new_msg.tool_calls[idx];
+    if (!msg_prv.tool_calls.empty()) {
+        const auto idx = msg_prv.tool_calls.size() - 1;
+        const auto & pref = msg_prv.tool_calls[idx];
+        const auto & newf = msg_new.tool_calls[idx];
         if (pref.name != newf.name) {
             throw std::runtime_error("Invalid diff: tool call mismatch!");
         }
-        auto args_diff = string_diff(pref.arguments, newf.arguments);
+        const auto args_diff = string_diff(pref.arguments, newf.arguments);
         if (!args_diff.empty() || pref.id != newf.id) {
             auto & diff = diffs.emplace_back();
             diff.tool_call_index = idx;
@@ -118,11 +125,12 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
             diff.tool_call_delta.arguments = args_diff;
         }
     }
-    for (size_t idx = previous_msg.tool_calls.size(); idx < new_msg.tool_calls.size(); ++idx) {
+    for (size_t idx = msg_prv.tool_calls.size(); idx < msg_new.tool_calls.size(); ++idx) {
         auto & diff = diffs.emplace_back();
         diff.tool_call_index = idx;
-        diff.tool_call_delta = new_msg.tool_calls[idx];
+        diff.tool_call_delta = msg_new.tool_calls[idx];
     }
+
     return diffs;
 }
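For context, compute_diffs implements an append-only delta scheme: each streamed snapshot of a message is compared against the previous snapshot and only the appended suffix is emitted. The sketch below illustrates that pattern with plain std::string stand-ins; string_diff_sketch is a hypothetical simplification of the real string_diff helper from common/, assumed here to return just the appended suffix.

// Stand-alone sketch of the append-only delta pattern behind compute_diffs.
// string_diff_sketch is a made-up simplification: it assumes the new content
// strictly extends the old content and returns only the appended suffix.
#include <cstdio>
#include <stdexcept>
#include <string>
#include <vector>

static std::string string_diff_sketch(const std::string & prv, const std::string & cur) {
    if (cur.compare(0, prv.size(), prv) != 0) {
        throw std::runtime_error("invalid diff: new content does not extend the old content");
    }
    return cur.substr(prv.size());
}

int main() {
    // Simulated streaming snapshots of one assistant message.
    const std::vector<std::string> snapshots = { "", "Hel", "Hello", "Hello, wor", "Hello, world!" };

    std::string prv;
    for (const auto & cur : snapshots) {
        const std::string delta = string_diff_sketch(prv, cur);
        if (!delta.empty()) {
            printf("content delta: '%s'\n", delta.c_str()); // a server would stream this chunk to the client
        }
        prv = cur;
    }
    return 0;
}

The reserve(3) in the patch matches the worst case for the non-growing part of a snapshot: at most one reasoning delta, one content delta, and one delta for the last previously seen tool call, plus one entry per newly added tool call.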

common/chat.h

Lines changed: 1 addition & 1 deletion

@@ -77,7 +77,7 @@ struct common_chat_msg_diff {
     size_t tool_call_index = std::string::npos;
     common_chat_tool_call tool_call_delta;
 
-    static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
+    static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new);
 
     bool operator==(const common_chat_msg_diff & other) const {
         return content_delta == other.content_delta

ggml/include/ggml.h

Lines changed: 1 addition & 1 deletion

@@ -2279,7 +2279,7 @@ extern "C" {
             float stop,
             float step);
 
-#define GGML_KQ_MASK_PAD 64
+#define GGML_KQ_MASK_PAD 1
 
     // q: [n_embd_k, n_batch, n_head, ne3 ]
     // k: [n_embd_k, n_kv, n_head_kv, ne3 ]
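GGML_KQ_MASK_PAD is the multiple to which the batch dimension of the KQ mask is padded for ggml_flash_attn_ext, so lowering it from 64 to 1 effectively removes the padding requirement (consistent with the padding assert deleted from launch_fattn further down). A minimal sketch of what such a pad constant does, assuming the usual round-up-to-a-multiple behavior of GGML_PAD:

// Illustrative only: PAD_UP mirrors the round-up-to-a-multiple formula used by
// the GGML_PAD macro in ggml.h.
#include <cstdio>

#define PAD_UP(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    const int n_tokens = 70;

    // With the old value of 64 the mask rows were rounded up to a multiple of 64;
    // with the new value of 1 no extra padding rows are required.
    printf("mask rows, pad 64: %d\n", PAD_UP(n_tokens, 64)); // 128
    printf("mask rows, pad  1: %d\n", PAD_UP(n_tokens,  1)); //  70
    return 0;
}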

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 12 additions & 10 deletions

@@ -25,7 +25,7 @@ typedef void (* fattn_kernel_t)(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t ne00, const uint3   ne01, const int32_t ne02, const int32_t ne03,
         const int32_t nb01, const int32_t nb02, const int32_t nb03,
         const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
         const int32_t nb11, const int32_t nb12, const int64_t nb13,
@@ -621,7 +621,8 @@ static __global__ void flash_attn_mask_to_KV_max(
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
+        const int nbatch_fa) {
     constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
@@ -632,8 +633,8 @@ static __global__ void flash_attn_stream_k_fixup(
 
     const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
 
-    const int iter_k = ne11 / FATTN_KQ_STRIDE;
-    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+    const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j = (ne01 + (ncols1    - 1)) / ncols1;
 
     const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
     const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
@@ -765,7 +766,7 @@ static __global__ void flash_attn_combine_results(
 template <int DV, int ncols1, int ncols2>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
-        const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
+        const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
@@ -790,8 +791,6 @@ void launch_fattn(
     GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
@@ -915,7 +914,7 @@ void launch_fattn(
 
         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
-        const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+        const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
 
         // parallel_blocks must not be larger than what the tensor size allows:
         parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
@@ -970,6 +969,9 @@ void launch_fattn(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    // TODO other tensor dimensions after removal of WMMA kernel:
+    const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
+
     GGML_ASSERT(block_dim.x % warp_size == 0);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
@@ -980,7 +982,7 @@ void launch_fattn(
         KV_max.ptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
+        Q->ne[0], ne01,     Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
         nb21, nb22, nb23,
         mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
@@ -995,7 +997,7 @@ void launch_fattn(
 
         flash_attn_stream_k_fixup<DV, ncols1, ncols2>
             <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-            ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
+            ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
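Two things change in this launcher: the K/V tile count is now a ceiling division by the kernel-specific nbatch_fa instead of a fixed FATTN_KQ_STRIDE, so a K/V length that is not a multiple of the per-kernel batch size still gets a final partial tile, and ne01 is passed to the kernel as a uint3 produced by init_fastdiv_values so the kernel can divide by it without an integer-division instruction. The host-side sketch below illustrates the magic-multiplier idea behind that kind of fastdiv packing; it is a generic illustration with made-up names, not the actual ggml-cuda implementation.

// Generic sketch of "fastdiv": precompute a magic multiplier (mp), a shift (L)
// and the divisor (d) once on the host, then replace n / d with a
// multiply-high plus shift at the point of use. Names are illustrative only.
#include <cassert>
#include <cstdint>
#include <cstdio>

struct fastdiv_vals { uint32_t mp, L, d; }; // stand-in for the packed uint3

static fastdiv_vals make_fastdiv(uint32_t d) {
    uint32_t L = 0;
    while (L < 32 && (uint64_t{1} << L) < d) {
        ++L; // L = ceil(log2(d))
    }
    const uint32_t mp = (uint32_t) (((uint64_t{1} << 32)*((uint64_t{1} << L) - d))/d + 1);
    return { mp, L, d };
}

static uint32_t fastdiv(uint32_t n, const fastdiv_vals f) {
    // On the GPU the high 32 bits of the product would come from __umulhi(f.mp, n).
    const uint32_t hi = (uint32_t) (((uint64_t) f.mp*n) >> 32);
    return (hi + n) >> f.L; // == n / f.d (assumes n < 2^31 so the addition cannot overflow)
}

int main() {
    for (uint32_t d = 1; d <= 1024; ++d) {
        const fastdiv_vals f = make_fastdiv(d);
        for (uint32_t n = 0; n <= 100000; n += 13) {
            assert(fastdiv(n, f) == n/d);
        }
    }
    printf("fastdiv sketch: all checks passed\n");
    return 0;
}

Precomputing the uint3 once in launch_fattn keeps the per-thread cost on the device at one multiply-high, one add and one shift, which is why ne01 is the dimension worth packing this way (the TODO notes that other dimensions may follow once the WMMA kernel is removed).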
