#include "pad_reflect_1d.hpp"

-void pad_reflect_1d_f32(const float* src,float* dst,
-    const int64_t ne0, const int64_t ne02, const int p0, const int p1,
-    const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
-    const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
-    const sycl::nd_item<3> &item_ct1){
-
-    const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0);
-    const int i1 = item_ct1.get_group(1);
-    const int g2 = item_ct1.get_group(2);
-    const int i2 = g2 % ne02;
-    const int i3 = g2 / ne02;
-
-    if (i0 >= p0 + ne0 + p1) return;
-
-    int t = i0 - p0;
-    int period = 2 * ne0 -2;
-    int m = t % period;
-    m += (m < 0) * period;
-    int center = ne0 -1;
-    int srci0 = center - abs(center - m);
-
-    int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0;
-    int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00;
-    dst[offest_dst] = src[offest_src];
+static void pad_reflect_1d_kernel_f32(
+    const void *__restrict__ src0, void *__restrict__ dst, const int64_t ne0,
+    const int64_t ne00, const sycl::uint3 ne01, const int64_t ne02,
+    const int64_t ne03, const int64_t nb00, const int64_t nb01,
+    const int64_t nb02, const int64_t nb03, const int64_t nb0,
+    const int64_t nb1, const int64_t nb2, const int64_t nb3, const int p0,
+    const int p1, sycl::nd_item<3> item_ct1) {

+    const int64_t i3 = item_ct1.get_group(0);
+    const int64_t i2 = item_ct1.get_group(1);
+
+    const sycl::uint2 div_mod_packed =
+        fast_div_modulo(item_ct1.get_group(2), ne01);
+    const int64_t tile1 = div_mod_packed.y();
+    const int64_t tile0 = div_mod_packed.x();
+    const int64_t i1 = tile1;
+    const int64_t i0 =
+        item_ct1.get_local_id(2) + tile0 * item_ct1.get_local_range(2);
+
+    if (i0 >= ne0 || i1 >= ne01.z() || i2 >= ne02 || i3 >= ne03) {
+        return;
+    }
+
+    const char *src0_ptr =
+        (const char *)src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
+    char *dst_ptr = (char *)dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
+
+    const int64_t rel_i0 = i0 - p0; // relative i0 in src0
+    int64_t src_idx;
+
+    if (rel_i0 < 0) {
+        // Left padding - reflect
+        src_idx = -rel_i0;
+    } else if (rel_i0 < ne00) {
+        // Middle - copy
+        src_idx = rel_i0;
+    } else {
+        // Right padding - reflect
+        src_idx = 2 * ne00 - 2 - rel_i0;
+    }
+    const float value = *(const float *)(src0_ptr + src_idx * nb00);
+    *(float *)(dst_ptr + i0 * nb0) = value;
+
+    GGML_UNUSED(p1);
}

-void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){
+void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context &ctx,
+                                 ggml_tensor *dst) {

-    const ggml_tensor * src0 = dst->src[0];
-    queue_ptr stream = ctx.stream();
+    const ggml_tensor *src0 = dst->src[0];
+    dpct::queue_ptr stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);

-    const int32_t * opts = (const int32_t *) dst->op_params;
+    const int32_t *opts = (const int32_t *)dst->op_params;
    const int p0 = opts[0];
    const int p1 = opts[1];

-    const int64_t ne0 = src0->ne[0];
-
-    const int64_t ne00 = dst->ne[0];
-    const int64_t ne01 = dst->ne[1];
-    const int64_t ne02 = dst->ne[2];
-    const int64_t ne03 = dst->ne[3];
-
-    const int64_t nb00 = dst->nb[0];
-    const int64_t nb01 = dst->nb[1];
-    const int64_t nb02 = dst->nb[2];
-    const int64_t nb03 = dst->nb[3];
-    const int64_t nb0 = src0->nb[0];
-    const int64_t nb1 = src0->nb[1];
-    const int64_t nb2 = src0->nb[2];
-    const int64_t nb3 = src0->nb[3];
-
-    int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
-    sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03);
-    sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1);
-
-    stream->parallel_for(
-        sycl::nd_range<3>(global,
-            local),
-        [=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32(
-            (const float *) src0->data, (float *) dst->data,
-            ne0, ne02, p0, p1,
-            nb0, nb1, nb2, nb3,
-            nb00, nb01, nb02, nb03
-            , item_ct1);
-        });
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const sycl::uint3 ne01_packed = init_fastdiv_values(ne01);
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+
+    GGML_ASSERT(ne0 == ne00 + p0 + p1);
+
+    constexpr int64_t bx = SYCL_PAD_REFLECT_1D_BLOCK_SIZE;
+    const int64_t tiles0 = (ne0 + bx - 1) / bx;
+    const dpct::dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02,
+                               (unsigned)ne03);
+    const dpct::dim3 block_dims((unsigned)bx, 1, 1);
+
+    stream->submit([&](sycl::handler &cgh) {
+        auto src0_data_ct0 = src0->data;
+        auto dst_data_ct1 = dst->data;
+        auto src0_nb_ct7 = src0->nb[0];
+        auto src0_nb_ct8 = src0->nb[1];
+        auto src0_nb_ct9 = src0->nb[2];
+        auto src0_nb_ct10 = src0->nb[3];
+        auto dst_nb_ct11 = dst->nb[0];
+        auto dst_nb_ct12 = dst->nb[1];
+        auto dst_nb_ct13 = dst->nb[2];
+        auto dst_nb_ct14 = dst->nb[3];
+
+        cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             pad_reflect_1d_kernel_f32(
+                                 src0_data_ct0, dst_data_ct1, ne0, ne00,
+                                 ne01_packed, ne02, ne03, src0_nb_ct7,
+                                 src0_nb_ct8, src0_nb_ct9, src0_nb_ct10,
+                                 dst_nb_ct11, dst_nb_ct12, dst_nb_ct13,
+                                 dst_nb_ct14, p0, p1, item_ct1);
+                         });
+    });
}
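
For reference, the index mapping used by `pad_reflect_1d_kernel_f32` above can be checked against a plain host-side version. The sketch below is illustrative only: `pad_reflect_1d_ref` is a hypothetical helper written for this note (not part of the patch), it assumes a contiguous float row, and, as with reflect padding in general, it is only meaningful when `p0` and `p1` are at most `ne00 - 1`. It applies the same three-way mapping of the padded index `i0` back into the source row: reflect on the left (`-rel_i0`), copy in the middle, reflect on the right (`2 * ne00 - 2 - rel_i0`).

```cpp
// Host-side reference for one row, assuming contiguous float data.
// pad_reflect_1d_ref is a hypothetical name used only for this sketch.
#include <cstdint>
#include <vector>

static std::vector<float> pad_reflect_1d_ref(const std::vector<float> & src, int p0, int p1) {
    const int64_t ne00 = (int64_t) src.size();
    const int64_t ne0  = ne00 + p0 + p1;          // padded length, matches the GGML_ASSERT above
    std::vector<float> dst((size_t) ne0);

    for (int64_t i0 = 0; i0 < ne0; ++i0) {
        const int64_t rel_i0 = i0 - p0;           // position relative to the source row
        int64_t src_idx;
        if (rel_i0 < 0) {
            src_idx = -rel_i0;                    // left padding: reflect around element 0
        } else if (rel_i0 < ne00) {
            src_idx = rel_i0;                     // interior: plain copy
        } else {
            src_idx = 2 * ne00 - 2 - rel_i0;      // right padding: reflect around the last element
        }
        dst[(size_t) i0] = src[(size_t) src_idx];
    }
    return dst;
}

// Example: src = {1, 2, 3, 4}, p0 = 2, p1 = 2  ->  {3, 2, 1, 2, 3, 4, 3, 2}
```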
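
The kernel recovers `(tile0, i1)` from the flattened group index with `init_fastdiv_values` / `fast_div_modulo` instead of a hardware integer division, which is why `ne01` is passed as a packed `sycl::uint3` and the bounds check reads the original divisor back via `ne01.z()`. Those helpers are defined elsewhere in the SYCL backend; the sketch below only shows the usual Granlund-Montgomery magic-number scheme that this kind of helper is typically built on. The names `fastdiv_vals`, `make_fastdiv` and `fast_divmod` are made up for the example, and the exact packing and arithmetic in ggml may differ.

```cpp
// Illustrative magic-number division/modulo by a divisor that is constant per kernel launch.
// make_fastdiv runs once on the host; fast_divmod replaces n / d and n % d on the device.
#include <cstdint>
#include <utility>

struct fastdiv_vals {
    uint32_t mp;  // magic multiplier
    uint32_t L;   // post-shift amount
    uint32_t d;   // original divisor, kept around for the remainder and bounds checks
};

static fastdiv_vals make_fastdiv(uint32_t d) {
    uint32_t L = 0;
    while (L < 32 && (uint32_t(1) << L) < d) {
        ++L;                                            // L = ceil(log2(d))
    }
    // mp = floor(2^32 * (2^L - d) / d) + 1, so that n / d == (mulhi(n, mp) + n) >> L
    const uint32_t mp = (uint32_t) (((uint64_t(1) << 32) * ((uint64_t(1) << L) - d)) / d + 1);
    return { mp, L, d };
}

static std::pair<uint32_t, uint32_t> fast_divmod(uint32_t n, fastdiv_vals f) {
    const uint64_t hi = ((uint64_t) n * f.mp) >> 32;    // high 32 bits of n * mp
    const uint32_t q  = (uint32_t) ((hi + n) >> f.L);   // quotient; 64-bit sum avoids overflow
    return { q, n - q * f.d };                          // {quotient, remainder}
}

// Usage, mirroring the kernel: split the flattened group index into a tile index and i1.
// auto [tile0, i1] = fast_divmod(group2, make_fastdiv((uint32_t) ne01));
```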