Skip to content

Commit efaaccd

Browse files
refactor pad_reflect_1d to make the UT case pass (#17204)
Co-authored-by: Zhang Jianyu <zhang.jianyu@outlook.com>
1 parent 4abef75 commit efaaccd

File tree

3 files changed

+115
-59
lines changed

3 files changed

+115
-59
lines changed

ggml/src/ggml-sycl/common.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,4 +617,30 @@ static __dpct_inline__ float get_alibi_slope(const float max_bias,
617617
return dpct::pow(base, exph);
618618
}
619619

620+
static const sycl::uint3 init_fastdiv_values(uint32_t d) {
621+
GGML_ASSERT(d != 0);
622+
623+
uint32_t L = 0;
624+
while (L < 32 && (uint32_t{ 1 } << L) < d) {
625+
L++;
626+
}
627+
628+
uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
629+
return sycl::uint3(mp, L, d);
630+
}
631+
632+
633+
static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) {
634+
const uint32_t hi = sycl::mul_hi<unsigned>(n, fastdiv_values.x());
635+
return (hi + n) >> fastdiv_values.y();
636+
}
637+
638+
639+
static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) {
640+
const uint32_t div_val = fastdiv(n, fastdiv_values);
641+
const uint32_t mod_val = n - div_val * fastdiv_values.z();
642+
return sycl::uint2(div_val, mod_val);
643+
}
644+
645+
620646
#endif // GGML_SYCL_COMMON_HPP
Lines changed: 87 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,100 @@
11
#include "pad_reflect_1d.hpp"
22

3-
void pad_reflect_1d_f32(const float* src,float* dst,
4-
const int64_t ne0, const int64_t ne02, const int p0, const int p1,
5-
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
6-
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
7-
const sycl::nd_item<3> &item_ct1){
8-
9-
const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0);
10-
const int i1 = item_ct1.get_group(1);
11-
const int g2 = item_ct1.get_group(2);
12-
const int i2 = g2 % ne02;
13-
const int i3 = g2 / ne02;
14-
15-
if (i0 >= p0 + ne0 + p1) return;
16-
17-
int t = i0 - p0;
18-
int period = 2 * ne0 -2;
19-
int m = t % period;
20-
m += (m < 0) * period;
21-
int center = ne0 -1;
22-
int srci0 = center - abs(center - m);
23-
24-
int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0;
25-
int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00;
26-
dst[offest_dst] = src[offest_src];
3+
static void pad_reflect_1d_kernel_f32(
4+
const void *__restrict__ src0, void *__restrict__ dst, const int64_t ne0,
5+
const int64_t ne00, const sycl::uint3 ne01, const int64_t ne02,
6+
const int64_t ne03, const int64_t nb00, const int64_t nb01,
7+
const int64_t nb02, const int64_t nb03, const int64_t nb0,
8+
const int64_t nb1, const int64_t nb2, const int64_t nb3, const int p0,
9+
const int p1, sycl::nd_item<3> item_ct1) {
2710

11+
const int64_t i3 = item_ct1.get_group(0);
12+
const int64_t i2 = item_ct1.get_group(1);
13+
14+
const sycl::uint2 div_mod_packed =
15+
fast_div_modulo(item_ct1.get_group(2), ne01);
16+
const int64_t tile1 = div_mod_packed.y();
17+
const int64_t tile0 = div_mod_packed.x();
18+
const int64_t i1 = tile1;
19+
const int64_t i0 =
20+
item_ct1.get_local_id(2) + tile0 * item_ct1.get_local_range(2);
21+
22+
if (i0 >= ne0 || i1 >= ne01.z() || i2 >= ne02 || i3 >= ne03) {
23+
return;
24+
}
25+
26+
const char *src0_ptr =
27+
(const char *)src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
28+
char *dst_ptr = (char *)dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
29+
30+
const int64_t rel_i0 = i0 - p0; // relative i0 in src0
31+
int64_t src_idx;
32+
33+
if (rel_i0 < 0) {
34+
// Left padding - reflect
35+
src_idx = -rel_i0;
36+
} else if (rel_i0 < ne00) {
37+
// Middle - copy
38+
src_idx = rel_i0;
39+
} else {
40+
// Right padding - reflect
41+
src_idx = 2 * ne00 - 2 - rel_i0;
42+
}
43+
const float value = *(const float *)(src0_ptr + src_idx * nb00);
44+
*(float *)(dst_ptr + i0 * nb0) = value;
45+
46+
GGML_UNUSED(p1);
2847
}
2948

30-
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){
49+
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context &ctx,
50+
ggml_tensor *dst) {
3151

32-
const ggml_tensor * src0 = dst->src[0];
33-
queue_ptr stream = ctx.stream();
52+
const ggml_tensor *src0 = dst->src[0];
53+
dpct::queue_ptr stream = ctx.stream();
3454

3555
GGML_ASSERT(src0->type == GGML_TYPE_F32);
36-
GGML_ASSERT( dst->type == GGML_TYPE_F32);
56+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
3757

38-
const int32_t * opts = (const int32_t *) dst->op_params;
58+
const int32_t *opts = (const int32_t *)dst->op_params;
3959
const int p0 = opts[0];
4060
const int p1 = opts[1];
4161

42-
const int64_t ne0 = src0->ne[0];
43-
44-
const int64_t ne00 = dst->ne[0];
45-
const int64_t ne01 = dst->ne[1];
46-
const int64_t ne02 = dst->ne[2];
47-
const int64_t ne03 = dst->ne[3];
48-
49-
const int64_t nb00 = dst->nb[0];
50-
const int64_t nb01 = dst->nb[1];
51-
const int64_t nb02 = dst->nb[2];
52-
const int64_t nb03 = dst->nb[3];
53-
const int64_t nb0 = src0->nb[0];
54-
const int64_t nb1 = src0->nb[1];
55-
const int64_t nb2 = src0->nb[2];
56-
const int64_t nb3 = src0->nb[3];
57-
58-
int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
59-
sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03);
60-
sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1);
61-
62-
stream->parallel_for(
63-
sycl::nd_range<3>(global,
64-
local),
65-
[=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32(
66-
(const float *) src0->data, (float *) dst->data,
67-
ne0, ne02, p0, p1,
68-
nb0, nb1, nb2, nb3,
69-
nb00, nb01, nb02, nb03
70-
, item_ct1);
71-
});
62+
const int64_t ne00 = src0->ne[0];
63+
const int64_t ne01 = src0->ne[1];
64+
const sycl::uint3 ne01_packed = init_fastdiv_values(ne01);
65+
const int64_t ne02 = src0->ne[2];
66+
const int64_t ne03 = src0->ne[3];
67+
68+
const int64_t ne0 = dst->ne[0];
69+
70+
GGML_ASSERT(ne0 == ne00 + p0 + p1);
71+
72+
constexpr int64_t bx = SYCL_PAD_REFLECT_1D_BLOCK_SIZE;
73+
const int64_t tiles0 = (ne0 + bx - 1) / bx;
74+
const dpct::dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02,
75+
(unsigned)ne03);
76+
const dpct::dim3 block_dims((unsigned)bx, 1, 1);
77+
78+
stream->submit([&](sycl::handler &cgh) {
79+
auto src0_data_ct0 = src0->data;
80+
auto dst_data_ct1 = dst->data;
81+
auto src0_nb_ct7 = src0->nb[0];
82+
auto src0_nb_ct8 = src0->nb[1];
83+
auto src0_nb_ct9 = src0->nb[2];
84+
auto src0_nb_ct10 = src0->nb[3];
85+
auto dst_nb_ct11 = dst->nb[0];
86+
auto dst_nb_ct12 = dst->nb[1];
87+
auto dst_nb_ct13 = dst->nb[2];
88+
auto dst_nb_ct14 = dst->nb[3];
89+
90+
cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
91+
[=](sycl::nd_item<3> item_ct1) {
92+
pad_reflect_1d_kernel_f32(
93+
src0_data_ct0, dst_data_ct1, ne0, ne00,
94+
ne01_packed, ne02, ne03, src0_nb_ct7,
95+
src0_nb_ct8, src0_nb_ct9, src0_nb_ct10,
96+
dst_nb_ct11, dst_nb_ct12, dst_nb_ct13,
97+
dst_nb_ct14, p0, p1, item_ct1);
98+
});
99+
});
72100
}

ggml/src/ggml-sycl/pad_reflect_1d.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
#include "common.hpp"
55

6+
#define SYCL_PAD_REFLECT_1D_BLOCK_SIZE 256
7+
68
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
79

810
#endif // GGML_SYCL_PAD_REFLECT_1D_HPP

0 commit comments

Comments
 (0)