@@ -1943,17 +1943,31 @@ typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
19431943
19441944template [[host_name(" kernel_cumsum_add_f32" )]] kernel kernel_cumsum_add_t kernel_cumsum_add<float >;
19451945
1946- inline static bool _ggml_vec_tri_cmp (const int i, const int r, const uint32_t type) {
1947- switch (type) {
1948- // ggml.h:620
1949- case /* GGML_TRI_TYPE_LOWER */ 3 : return i < r; break ;
1950- case /* GGML_TRI_TYPE_LOWER_DIAG */ 2 : return i <= r; break ;
1951- case /* GGML_TRI_TYPE_UPPER */ 1 : return i > r; break ;
1952- case /* GGML_TRI_TYPE_UPPER_DIAG */ 0 : return i >= r; break ;
1953- }
1946+
1947+ template <uint32_t ttype>
1948+ bool _ggml_vec_tri_cmp (const int i, const int r);
1949+
1950+ template <>
1951+ bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3 >(const int i, const int r) {
1952+ return i < r;
19541953}
19551954
1956- template <typename T>
1955+ template <>
1956+ bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2 >(const int i, const int r) {
1957+ return i <= r;
1958+ }
1959+
1960+ template <>
1961+ bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1 >(const int i, const int r) {
1962+ return i > r;
1963+ }
1964+
1965+ template <>
1966+ bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0 >(const int i, const int r) {
1967+ return i >= r;
1968+ }
1969+
1970+ template <typename T, int ttype>
19571971kernel void kernel_tri (
19581972 constant ggml_metal_kargs_tri & args,
19591973 device const char * src0,
@@ -1978,16 +1992,25 @@ kernel void kernel_tri(
19781992 // threadgroup, so this will loop once for each index that this thread is
19791993 // responsible for
19801994 for (int64_t i0 = tpitg.x ; i0 < args.ne00 ; i0 += ntg.x ) {
1981- dst_row[i0] = _ggml_vec_tri_cmp (i0, i1, args. ttype ) ? src_row[i0] : static_cast <T>(0 .f );
1995+ dst_row[i0] = _ggml_vec_tri_cmp<ttype> (i0, i1) ? src_row[i0] : static_cast <T>(0 .f );
19821996 }
19831997}
19841998
1985- typedef decltype (kernel_tri<float >) kernel_tri_t;
1999+ typedef decltype (kernel_tri<float , 0 >) kernel_tri_t;
19862000
1987- template [[host_name(" kernel_tri_f32" )]] kernel kernel_tri_t kernel_tri<float >;
1988- template [[host_name(" kernel_tri_f16" )]] kernel kernel_tri_t kernel_tri<half>;
2001+ template [[host_name(" kernel_tri_f32_0" )]] kernel kernel_tri_t kernel_tri<float , 0 >;
2002+ template [[host_name(" kernel_tri_f32_1" )]] kernel kernel_tri_t kernel_tri<float , 1 >;
2003+ template [[host_name(" kernel_tri_f32_2" )]] kernel kernel_tri_t kernel_tri<float , 2 >;
2004+ template [[host_name(" kernel_tri_f32_3" )]] kernel kernel_tri_t kernel_tri<float , 3 >;
2005+ template [[host_name(" kernel_tri_f16_0" )]] kernel kernel_tri_t kernel_tri<half, 0 >;
2006+ template [[host_name(" kernel_tri_f16_1" )]] kernel kernel_tri_t kernel_tri<half, 1 >;
2007+ template [[host_name(" kernel_tri_f16_2" )]] kernel kernel_tri_t kernel_tri<half, 2 >;
2008+ template [[host_name(" kernel_tri_f16_3" )]] kernel kernel_tri_t kernel_tri<half, 3 >;
19892009#if defined(GGML_METAL_HAS_BF16)
1990- template [[host_name(" kernel_tri_bf16" )]] kernel kernel_tri_t kernel_tri<bfloat>;
2010+ template [[host_name(" kernel_tri_bf16_0" )]] kernel kernel_tri_t kernel_tri<bfloat, 0 >;
2011+ template [[host_name(" kernel_tri_bf16_1" )]] kernel kernel_tri_t kernel_tri<bfloat, 1 >;
2012+ template [[host_name(" kernel_tri_bf16_2" )]] kernel kernel_tri_t kernel_tri<bfloat, 2 >;
2013+ template [[host_name(" kernel_tri_bf16_3" )]] kernel kernel_tri_t kernel_tri<bfloat, 3 >;
19912014#endif
19922015
19932016template <typename T>
0 commit comments