Skip to content

Commit d9e03db

Browse files
authored
sycl: add missing BF16 conversion support for Intel oneAPI (ggml-org#17780)
* sycl: add missing BF16 conversion support for Intel oneAPI * Fix Line 645: Trailing whitespace
1 parent db97837 commit d9e03db

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

ggml/src/ggml-sycl/convert.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
#include "dequantize.hpp"
33
#include "presets.hpp"
44

5+
#if defined(__INTEL_LLVM_COMPILER)
6+
#if __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
7+
#include <sycl/ext/oneapi/bfloat16.hpp>
8+
#define GGML_SYCL_HAS_BF16
9+
#endif
10+
#endif
11+
512
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
613
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
714
const sycl::nd_item<3> &item_ct1) {
@@ -566,6 +573,10 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
566573
return dequantize_row_iq4_nl_sycl;
567574
case GGML_TYPE_F32:
568575
return convert_unary_sycl<float>;
576+
#ifdef GGML_SYCL_HAS_BF16
577+
case GGML_TYPE_BF16:
578+
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
579+
#endif
569580
default:
570581
return nullptr;
571582
}
@@ -627,6 +638,10 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
627638
return dequantize_row_iq4_nl_sycl;
628639
case GGML_TYPE_F16:
629640
return convert_unary_sycl<sycl::half>;
641+
#ifdef GGML_SYCL_HAS_BF16
642+
case GGML_TYPE_BF16:
643+
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
644+
#endif
630645
default:
631646
return nullptr;
632647
}
@@ -636,6 +651,10 @@ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
636651
switch (type) {
637652
case GGML_TYPE_F32:
638653
return convert_unary_nc_sycl<float>;
654+
#ifdef GGML_SYCL_HAS_BF16
655+
case GGML_TYPE_BF16:
656+
return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>;
657+
#endif
639658
default:
640659
return nullptr;
641660
}

0 commit comments

Comments
 (0)