Skip to content

Commit 5f0d2a1

Browse files
committed
feat(ggml-metal): Metal impl of tri
Branch: Mamba2SSD Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent c71e35e commit 5f0d2a1

File tree

4 files changed

+124
-2
lines changed

4 files changed

+124
-2
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,27 @@ typedef struct {
586586
uint64_t nb3;
587587
} ggml_metal_kargs_cumsum;
588588

589+
// Argument block for the TRI kernel. The host encoder fills one instance and
// binds it as the kernel's constant `args` parameter, so the field order and
// types here must stay in sync between the host code and the shader.
typedef struct {
590+
// Source tensor (op->src[0]): number of elements in each dimension.
int64_t ne00;
591+
int64_t ne01;
592+
int64_t ne02;
593+
int64_t ne03;
594+
// Source tensor strides. The kernel applies nb01..nb03 as byte offsets on a
// char pointer to locate a row.
uint64_t nb00;
595+
uint64_t nb01;
596+
uint64_t nb02;
597+
uint64_t nb03;
598+
// Destination tensor (the op node itself): number of elements in each dimension.
int64_t ne0;
599+
int64_t ne1;
600+
int64_t ne2;
601+
int64_t ne3;
602+
// Destination tensor strides, used by the kernel as byte offsets (nb1..nb3).
uint64_t nb0;
603+
uint64_t nb1;
604+
uint64_t nb2;
605+
uint64_t nb3;
606+
// Fill value for elements inside the selected triangle. The kernel treats NaN
// as "keep the source value" instead of filling.
float c;
607+
// Triangle selector: the ggml_tri_type value cast to uint32_t by the host.
uint32_t ttype;
608+
} ggml_metal_kargs_tri;
609+
589610
typedef struct {
590611
int32_t ne00;
591612
int32_t ne01;

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,8 +1004,57 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
10041004

10051005
// Encodes the TRI op for graph node `idx` on the Metal command encoder.
//
// Reads the triangle type (op_params[0]) and the fill constant (op_params[1],
// stored as raw float bits) from the node, packs source/destination shapes and
// strides into ggml_metal_kargs_tri, and dispatches one threadgroup per source
// row (grid = ne01 x ne02 x ne03).
//
// Always returns 1.
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
    // Strides are byte offsets and the kargs fields are uint64_t: keep the
    // locals 64-bit as well so large destination strides are not truncated
    // (and so they match the source strides declared above).
    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);

    const ggml_tri_type ttype = (ggml_tri_type) op->op_params[0];
    // op_params[1] holds the float constant's bit pattern; reinterpret it.
    const float         c     = *((float *) &(op->op_params[1]));

    ggml_metal_kargs_tri args = {
        /*.ne00  =*/ ne00,
        /*.ne01  =*/ ne01,
        /*.ne02  =*/ ne02,
        /*.ne03  =*/ ne03,
        /*.nb00  =*/ nb00,
        /*.nb01  =*/ nb01,
        /*.nb02  =*/ nb02,
        /*.nb03  =*/ nb03,
        /*.ne0   =*/ ne0,
        /*.ne1   =*/ ne1,
        /*.ne2   =*/ ne2,
        /*.ne3   =*/ ne3,
        /*.nb0   =*/ nb0,
        /*.nb1   =*/ nb1,
        /*.nb2   =*/ nb2,
        /*.nb3   =*/ nb3,
        /*.c     =*/ c,
        /*.ttype =*/ static_cast<uint32_t>(ttype)
    };

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_tri(lib, op);

    // Threads per threadgroup: start at the SIMD width and double until the
    // row is covered or the pipeline limit is reached, then clamp to both.
    int nth = 32; // SIMD width

    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
        nth *= 2;
    }

    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
    nth = std::min(nth, ne00);

    // Binding slots: 0 = kernel args, 1 = source buffer, 2 = destination buffer.
    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);

    // One threadgroup per row; the kernel strides over the row with nth threads.
    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);

    return 1;
}
10111060

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1857,6 +1857,56 @@ template [[host_name("kernel_cumsum_f16")]] kernel kernel_cumsum_t kernel_cumsum
18571857
template [[host_name("kernel_cumsum_bf16")]] kernel kernel_cumsum_t kernel_cumsum<bfloat>;
18581858
#endif
18591859

1860+
// Returns true when the element at index `i` within row `r` belongs to the
// triangular region selected by `type`.
//
// `type` is the raw numeric value of ggml_tri_type; the case labels below
// mirror that enum (declared in ggml.h) and must be kept in sync with it.
// Any unrecognized value selects nothing (returns false) rather than flowing
// off the end of the function.
inline static bool _ggml_vec_tri_cmp(const int i, const int r, const uint32_t type) {
    switch (type) {
        case /* GGML_TRI_TYPE_LOWER      */ 3: return i <  r;
        case /* GGML_TRI_TYPE_LOWER_DIAG */ 2: return i <= r;
        case /* GGML_TRI_TYPE_UPPER      */ 1: return i >  r;
        case /* GGML_TRI_TYPE_UPPER_DIAG */ 0: return i >= r;
        default: break;
    }

    // unknown triangle type: treat every element as outside the region
    return false;
}
1869+
1870+
// Triangular mask kernel.
//
// Each threadgroup handles one row: tgpig.x/y/z select the row index in
// dimensions 1/2/3, and the threads of the group stride over the ne00 elements
// of that row. Elements selected by _ggml_vec_tri_cmp(column, row, ttype)
// receive args.c, or the source value when args.c is NaN; all other elements
// are set to zero.
//
// Buffer bindings: args, src0, dst (in that parameter order).
template<typename T>
kernel void kernel_tri(
        constant ggml_metal_kargs_tri & args,
        device  const char * src0,
        device        char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int64_t i3 = tgpig.z;
    const int64_t i2 = tgpig.y;
    const int64_t i1 = tgpig.x;

    // guard against threadgroups outside the source extent
    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    // nb* are byte strides, hence the arithmetic on char pointers before the
    // cast to the element type
    device const T * src_row = (device const T *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
    device       T * dst_row = (device       T *) (dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

    // A NaN constant means "keep the original value" inside the triangle.
    // NOTE(review): isnan() may be unreliable if the shader library is built
    // with fast-math enabled - confirm against the build flags.
    const bool keep_org_val = isnan(args.c);
    const T    fill_val     = static_cast<T>(args.c);
    const T    zero_val     = static_cast<T>(0.f);

    // Each thread starts at its own index in the threadgroup and advances by
    // the threadgroup size, so rows longer than the group are fully covered.
    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        const bool inside = _ggml_vec_tri_cmp((int) i0, (int) i1, args.ttype);

        dst_row[i0] = inside
            ? (keep_org_val ? src_row[i0] : fill_val)
            : zero_val;
    }
}
1901+
1902+
// Function type shared by every kernel_tri instantiation below.
typedef decltype(kernel_tri<float>) kernel_tri_t;
1903+
1904+
// Explicit instantiations; host_name is the name each kernel is exported under.
template [[host_name("kernel_tri_f32")]] kernel kernel_tri_t kernel_tri<float>;
1905+
template [[host_name("kernel_tri_f16")]] kernel kernel_tri_t kernel_tri<half>;
1906+
// The bfloat variant is only compiled when GGML_METAL_HAS_BF16 is defined.
#if defined(GGML_METAL_HAS_BF16)
1907+
template [[host_name("kernel_tri_bf16")]] kernel kernel_tri_t kernel_tri<bfloat>;
1908+
#endif
1909+
18601910
template<typename T>
18611911
kernel void kernel_soft_max(
18621912
constant ggml_metal_kargs_soft_max & args,

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6951,6 +6951,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval(int verbose
69516951
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16}, 42.f));
69526952
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F16, {8, 8, 4, 16}, 42.f));
69536953
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_BF16, {8, 8, 4, 16}, 42.f));
6954+
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {2025, 2025, 1, 1}));
69546955

69556956
for (bool v : {false, true}) {
69566957
test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v));
@@ -7123,6 +7124,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
71237124
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {8, 8, 4, 16}, 42.f));
71247125
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F16, {8, 8, 4, 16}, 42.f));
71257126
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_BF16, {8, 8, 4, 16}, 42.f));
7127+
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER, GGML_TYPE_F32, {2025, 2025, 1, 1}));
71267128

71277129
for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
71287130
for (ggml_type type_a : all_types) {

0 commit comments

Comments
 (0)