Skip to content

Commit 7690808

Browse files
committed
refactor: Branchless version of tri using _ggml_vec_tri_cmp as a mask
Branch: ggml-cumsum-tri Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 1496afd commit 7690808

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2038,7 +2038,8 @@ kernel void kernel_tri(
20382038
// threadgroup, so this will loop once for each index that this thread is
20392039
// responsible for
20402040
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
2041-
dst_row[i0] = _ggml_vec_tri_cmp<ttype>(i0, i1) ? src_row[i0] : static_cast<T>(0.f);
2041+
// Use the comparison as a mask for branchless
2042+
dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
20422043
}
20432044
}
20442045

0 commit comments

Comments
 (0)