Skip to content

Commit f2ad887

Browse files
committed
feat: Move the ttype conditional to templating to avoid conditional in kernel
Branch: ggml-cumsum-tri Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 2a7bbc7 commit f2ad887

File tree

4 files changed

+39
-19
lines changed

4 files changed

+39
-19
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,8 +363,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_tri(ggml_metal_library_t l
363363
char name[256];
364364

365365
const char * op_str = "tri";
366+
const int ttype = op->op_params[0];
366367

367-
snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
368+
snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype);
368369

369370
snprintf(name, 256, "%s", base);
370371

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,6 @@ typedef struct {
848848
uint64_t nb1;
849849
uint64_t nb2;
850850
uint64_t nb3;
851-
uint32_t ttype;
852851
} ggml_metal_kargs_tri;
853852

854853
typedef struct {

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3914,8 +3914,6 @@ int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
39143914
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
39153915
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
39163916

3917-
const ggml_tri_type ttype = (ggml_tri_type) op->op_params[0];
3918-
39193917
ggml_metal_kargs_tri args = {
39203918
/*.ne00 =*/ ne00,
39213919
/*.ne01 =*/ ne01,
@@ -3933,7 +3931,6 @@ int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
39333931
/*.nb1 =*/ nb1,
39343932
/*.nb2 =*/ nb2,
39353933
/*.nb3 =*/ nb3,
3936-
/*.ttype =*/ static_cast<uint32_t>(ttype)
39373934
};
39383935

39393936
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_tri(lib, op);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1943,17 +1943,31 @@ typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
19431943

19441944
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
19451945

1946-
inline static bool _ggml_vec_tri_cmp(const int i, const int r, const uint32_t type) {
1947-
switch (type) {
1948-
// ggml.h:620
1949-
case /* GGML_TRI_TYPE_LOWER */ 3: return i < r; break;
1950-
case /* GGML_TRI_TYPE_LOWER_DIAG */ 2: return i <= r; break;
1951-
case /* GGML_TRI_TYPE_UPPER */ 1: return i > r; break;
1952-
case /* GGML_TRI_TYPE_UPPER_DIAG */ 0: return i >= r; break;
1953-
}
1946+
1947+
template<uint32_t ttype>
1948+
bool _ggml_vec_tri_cmp(const int i, const int r);
1949+
1950+
template<>
1951+
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
1952+
return i < r;
19541953
}
19551954

1956-
template<typename T>
1955+
template<>
1956+
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
1957+
return i <= r;
1958+
}
1959+
1960+
template<>
1961+
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
1962+
return i > r;
1963+
}
1964+
1965+
template<>
1966+
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
1967+
return i >= r;
1968+
}
1969+
1970+
template<typename T, int ttype>
19571971
kernel void kernel_tri(
19581972
constant ggml_metal_kargs_tri & args,
19591973
device const char * src0,
@@ -1978,16 +1992,25 @@ kernel void kernel_tri(
19781992
// threadgroup, so this will loop once for each index that this thread is
19791993
// responsible for
19801994
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
1981-
dst_row[i0] = _ggml_vec_tri_cmp(i0, i1, args.ttype) ? src_row[i0] : static_cast<T>(0.f);
1995+
dst_row[i0] = _ggml_vec_tri_cmp<ttype>(i0, i1) ? src_row[i0] : static_cast<T>(0.f);
19821996
}
19831997
}
19841998

1985-
typedef decltype(kernel_tri<float>) kernel_tri_t;
1999+
typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
19862000

1987-
template [[host_name("kernel_tri_f32")]] kernel kernel_tri_t kernel_tri<float>;
1988-
template [[host_name("kernel_tri_f16")]] kernel kernel_tri_t kernel_tri<half>;
2001+
template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
2002+
template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
2003+
template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
2004+
template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
2005+
template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
2006+
template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
2007+
template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
2008+
template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
19892009
#if defined(GGML_METAL_HAS_BF16)
1990-
template [[host_name("kernel_tri_bf16")]] kernel kernel_tri_t kernel_tri<bfloat>;
2010+
template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
2011+
template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
2012+
template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
2013+
template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
19912014
#endif
19922015

19932016
template<typename T>

0 commit comments

Comments
 (0)