
Commit a14a3a0

8371711: AArch64: SVE intrinsics for Arrays.sort methods (int, float)
This patch adds an SVE implementation of primitive array sorting (Arrays.sort()) on AArch64 systems that support SVE. On non-SVE machines, we fall back to the existing Java implementation. For smaller arrays (length <= 64) we use insertion sort; for larger arrays we use an SVE-vectorized quicksort partitioner followed by an odd-even transposition cleanup pass.

The SVE path is enabled by default for the int type. For the float type it is available through the experimental flag:

  -XX:+UnlockExperimentalVMOptions -XX:+UseSVELibSimdSortForFP

The flag is disabled by default; without it, the default Java implementation is used for floats. The float path is gated behind the flag due to observed regressions on some small and medium sizes. On larger arrays, the SVE float path shows up to a 1.47x speedup on Neoverse V2 and 2.12x on Neoverse V1.

The performance numbers below are for the ArraysSort JMH benchmark.

Case A: ratio between the scores of the master branch and this patch with UseSVELibSimdSortForFP disabled (the default).
Case B: ratio between the scores of the master branch and this patch with UseSVELibSimdSortForFP enabled (the int numbers stay the same, but SVE vectorized sorting is now also used for floats).

A ratio >= 1 means at par with or better than the default Java implementation (master branch).

On Neoverse V1:

Benchmark                      (size)  Mode  Cnt     A     B
ArraysSort.floatParallelSort       10  avgt    3  0.98  0.98
ArraysSort.floatParallelSort       25  avgt    3  1.01  0.83
ArraysSort.floatParallelSort       50  avgt    3  0.99  0.55
ArraysSort.floatParallelSort       75  avgt    3  0.99  0.66
ArraysSort.floatParallelSort      100  avgt    3  0.98  0.66
ArraysSort.floatParallelSort     1000  avgt    3  1.00  0.84
ArraysSort.floatParallelSort    10000  avgt    3  1.03  1.52
ArraysSort.floatParallelSort   100000  avgt    3  1.03  1.46
ArraysSort.floatParallelSort  1000000  avgt    3  0.98  1.81
ArraysSort.floatSort               10  avgt    3  1.00  0.98
ArraysSort.floatSort               25  avgt    3  1.00  0.81
ArraysSort.floatSort               50  avgt    3  0.99  0.56
ArraysSort.floatSort               75  avgt    3  0.99  0.65
ArraysSort.floatSort              100  avgt    3  0.98  0.70
ArraysSort.floatSort             1000  avgt    3  0.99  0.84
ArraysSort.floatSort            10000  avgt    3  0.99  1.72
ArraysSort.floatSort           100000  avgt    3  1.00  1.94
ArraysSort.floatSort          1000000  avgt    3  1.00  2.13
ArraysSort.intParallelSort         10  avgt    3  1.08  1.08
ArraysSort.intParallelSort         25  avgt    3  1.04  1.05
ArraysSort.intParallelSort         50  avgt    3  1.29  1.30
ArraysSort.intParallelSort         75  avgt    3  1.16  1.16
ArraysSort.intParallelSort        100  avgt    3  1.07  1.07
ArraysSort.intParallelSort       1000  avgt    3  1.13  1.13
ArraysSort.intParallelSort      10000  avgt    3  1.49  1.38
ArraysSort.intParallelSort     100000  avgt    3  1.64  1.62
ArraysSort.intParallelSort    1000000  avgt    3  2.26  2.27
ArraysSort.intSort                 10  avgt    3  1.08  1.08
ArraysSort.intSort                 25  avgt    3  1.02  1.02
ArraysSort.intSort                 50  avgt    3  1.25  1.25
ArraysSort.intSort                 75  avgt    3  1.16  1.20
ArraysSort.intSort                100  avgt    3  1.07  1.07
ArraysSort.intSort               1000  avgt    3  1.12  1.13
ArraysSort.intSort              10000  avgt    3  1.94  1.95
ArraysSort.intSort             100000  avgt    3  1.86  1.86
ArraysSort.intSort            1000000  avgt    3  2.09  2.09

On Neoverse V2:

Benchmark                      (size)  Mode  Cnt     A     B
ArraysSort.floatParallelSort       10  avgt    3  1.02  1.02
ArraysSort.floatParallelSort       25  avgt    3  0.97  0.71
ArraysSort.floatParallelSort       50  avgt    3  0.94  0.65
ArraysSort.floatParallelSort       75  avgt    3  0.96  0.82
ArraysSort.floatParallelSort      100  avgt    3  0.95  0.84
ArraysSort.floatParallelSort     1000  avgt    3  1.01  0.94
ArraysSort.floatParallelSort    10000  avgt    3  1.01  1.25
ArraysSort.floatParallelSort   100000  avgt    3  1.01  1.09
ArraysSort.floatParallelSort  1000000  avgt    3  1.00  1.10
ArraysSort.floatSort               10  avgt    3  1.02  1.00
ArraysSort.floatSort               25  avgt    3  0.99  0.76
ArraysSort.floatSort               50  avgt    3  0.97  0.66
ArraysSort.floatSort               75  avgt    3  1.01  0.83
ArraysSort.floatSort              100  avgt    3  1.00  0.85
ArraysSort.floatSort             1000  avgt    3  0.99  0.93
ArraysSort.floatSort            10000  avgt    3  1.00  1.28
ArraysSort.floatSort           100000  avgt    3  1.00  1.37
ArraysSort.floatSort          1000000  avgt    3  1.00  1.48
ArraysSort.intParallelSort         10  avgt    3  1.05  1.05
ArraysSort.intParallelSort         25  avgt    3  0.99  0.84
ArraysSort.intParallelSort         50  avgt    3  1.03  1.14
ArraysSort.intParallelSort         75  avgt    3  0.91  0.99
ArraysSort.intParallelSort        100  avgt    3  0.98  0.96
ArraysSort.intParallelSort       1000  avgt    3  1.32  1.30
ArraysSort.intParallelSort      10000  avgt    3  1.40  1.40
ArraysSort.intParallelSort     100000  avgt    3  1.00  1.04
ArraysSort.intParallelSort    1000000  avgt    3  1.15  1.14
ArraysSort.intSort                 10  avgt    3  1.05  1.05
ArraysSort.intSort                 25  avgt    3  1.03  1.03
ArraysSort.intSort                 50  avgt    3  1.08  1.14
ArraysSort.intSort                 75  avgt    3  0.88  0.98
ArraysSort.intSort                100  avgt    3  1.01  0.99
ArraysSort.intSort               1000  avgt    3  1.30  1.32
ArraysSort.intSort              10000  avgt    3  1.43  1.43
ArraysSort.intSort             100000  avgt    3  1.30  1.30
ArraysSort.intSort            1000000  avgt    3  1.37  1.37

This patch is part of a series of patches adding support for vectorized array sorting on AArch64 (including fixing the regressions for small/medium float arrays, support for double/long, etc.).
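For readers unfamiliar with the cleanup pass mentioned above, the following is a minimal scalar sketch of an odd-even transposition sort. It is illustrative only: the patch's actual cleanup pass is vectorized with SVE, and the function name here is hypothetical.

    #include <cstddef>
    #include <utility>

    // Scalar odd-even transposition sort: alternately compare-exchange
    // (even, even+1) and (odd, odd+1) neighbour pairs, repeating until a
    // full pass performs no swaps.
    static void odd_even_transposition_sort(int *a, size_t n) {
        bool swapped = true;
        while (swapped) {
            swapped = false;
            for (size_t start = 0; start < 2; start++) {  // even pass, then odd pass
                for (size_t i = start; i + 1 < n; i += 2) {
                    if (a[i] > a[i + 1]) {
                        std::swap(a[i], a[i + 1]);
                        swapped = true;
                    }
                }
            }
        }
    }

Because each pass only touches independent adjacent pairs, the compare-exchange steps map naturally onto vector min/max operations, which is what makes this pattern attractive as a vectorized cleanup stage.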
1 parent 2735140 commit a14a3a0

17 files changed: +1088 −0 lines changed
Lines changed: 367 additions & 0 deletions
@@ -0,0 +1,367 @@
/*
 * Copyright (c) 2021, 2023, Intel Corporation. All rights reserved.
 * Copyright (c) 2021 Serge Sans Paille. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 *
 */

// This implementation is based on x86-simd-sort(https://github.com/intel/x86-simd-sort)

#ifndef AVX2_QSORT_32BIT
#define AVX2_QSORT_32BIT

#include "avx2-emu-funcs.hpp"
#include "xss-common-qsort.h"

/*
 * Constants used in sorting 8 elements in a ymm registers. Based on Bitonic
 * sorting network (see
 * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
 */

// ymm                       7, 6, 5, 4, 3, 2, 1, 0
#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3
#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7
#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2
#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4

/*
 * Assumes ymm is random and performs a full sorting network defined in
 * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
 */
template <typename vtype, typename reg_t = typename vtype::reg_t>
X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm) {
    const typename vtype::opmask_t oxAA = _mm256_set_epi32(
        0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0);
    const typename vtype::opmask_t oxCC = _mm256_set_epi32(
        0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0);
    const typename vtype::opmask_t oxF0 = _mm256_set_epi32(
        0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0);

    const typename vtype::ymmi_t rev_index = vtype::seti(NETWORK_32BIT_AVX2_2);
    ymm = cmp_merge<vtype>(
        ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
    ymm = cmp_merge<vtype>(
        ymm, vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_1), ymm), oxCC);
    ymm = cmp_merge<vtype>(
        ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
    ymm = cmp_merge<vtype>(ymm, vtype::permutexvar(rev_index, ymm), oxF0);
    ymm = cmp_merge<vtype>(
        ymm, vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_3), ymm), oxCC);
    ymm = cmp_merge<vtype>(
        ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
    return ymm;
}

struct avx2_32bit_swizzle_ops;

template <>
struct avx2_vector<int32_t> {
    using type_t = int32_t;
    using reg_t = __m256i;
    using ymmi_t = __m256i;
    using opmask_t = __m256i;
    static const uint8_t numlanes = 8;
#ifdef XSS_MINIMAL_NETWORK_SORT
    static constexpr int network_sort_threshold = numlanes;
#else
    static constexpr int network_sort_threshold = 256;
#endif
    static constexpr int partition_unroll_factor = 4;

    using swizzle_ops = avx2_32bit_swizzle_ops;

    static type_t type_max() { return X86_SIMD_SORT_MAX_INT32; }
    static type_t type_min() { return X86_SIMD_SORT_MIN_INT32; }
    static reg_t zmm_max() {
        return _mm256_set1_epi32(type_max());
    }  // TODO: this should broadcast bits as is?
    static opmask_t get_partial_loadmask(uint64_t num_to_read) {
        auto mask = ((0x1ull << num_to_read) - 0x1ull);
        return convert_int_to_avx2_mask(mask);
    }
    static ymmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7,
                       int v8) {
        return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
    }
    static opmask_t kxor_opmask(opmask_t x, opmask_t y) {
        return _mm256_xor_si256(x, y);
    }
    static opmask_t ge(reg_t x, reg_t y) {
        opmask_t equal = eq(x, y);
        opmask_t greater = _mm256_cmpgt_epi32(x, y);
        return _mm256_castps_si256(_mm256_or_ps(_mm256_castsi256_ps(equal),
                                                _mm256_castsi256_ps(greater)));
    }
    static opmask_t gt(reg_t x, reg_t y) { return _mm256_cmpgt_epi32(x, y); }
    static opmask_t eq(reg_t x, reg_t y) { return _mm256_cmpeq_epi32(x, y); }
    template <int scale>
    static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index,
                                void const *base) {
        return _mm256_mask_i32gather_epi32(src, base, index, mask, scale);
    }
    template <int scale>
    static reg_t i64gather(__m256i index, void const *base) {
        return _mm256_i32gather_epi32((int const *)base, index, scale);
    }
    static reg_t loadu(void const *mem) {
        return _mm256_loadu_si256((reg_t const *)mem);
    }
    static reg_t max(reg_t x, reg_t y) { return _mm256_max_epi32(x, y); }
    static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) {
        return avx2_emu_mask_compressstoreu32<type_t>(mem, mask, x);
    }
    static reg_t maskz_loadu(opmask_t mask, void const *mem) {
        return _mm256_maskload_epi32((const int *)mem, mask);
    }
    static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) {
        reg_t dst = _mm256_maskload_epi32((type_t *)mem, mask);
        return mask_mov(x, mask, dst);
    }
    static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) {
        return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(x),
                                                    _mm256_castsi256_ps(y),
                                                    _mm256_castsi256_ps(mask)));
    }
    static void mask_storeu(void *mem, opmask_t mask, reg_t x) {
        return _mm256_maskstore_epi32((type_t *)mem, mask, x);
    }
    static reg_t min(reg_t x, reg_t y) { return _mm256_min_epi32(x, y); }
    static reg_t permutexvar(__m256i idx, reg_t ymm) {
        return _mm256_permutevar8x32_epi32(ymm, idx);
        // return avx2_emu_permutexvar_epi32(idx, ymm);
    }
    static reg_t permutevar(reg_t ymm, __m256i idx) {
        return _mm256_permutevar8x32_epi32(ymm, idx);
    }
    static reg_t reverse(reg_t ymm) {
        const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
        return permutexvar(rev_index, ymm);
    }
    static type_t reducemax(reg_t v) {
        return avx2_emu_reduce_max32<type_t>(v);
    }
    static type_t reducemin(reg_t v) {
        return avx2_emu_reduce_min32<type_t>(v);
    }
    static reg_t set1(type_t v) { return _mm256_set1_epi32(v); }
    template <uint8_t mask>
    static reg_t shuffle(reg_t ymm) {
        return _mm256_shuffle_epi32(ymm, mask);
    }
    static void storeu(void *mem, reg_t x) {
        _mm256_storeu_si256((__m256i *)mem, x);
    }
    static reg_t sort_vec(reg_t x) {
        return sort_ymm_32bit<avx2_vector<type_t>>(x);
    }
    static reg_t cast_from(__m256i v) { return v; }
    static __m256i cast_to(reg_t v) { return v; }
    static int double_compressstore(type_t *left_addr, type_t *right_addr,
                                    opmask_t k, reg_t reg) {
        return avx2_double_compressstore32<type_t>(left_addr, right_addr, k,
                                                   reg);
    }
};

template <>
struct avx2_vector<float> {
    using type_t = float;
    using reg_t = __m256;
    using ymmi_t = __m256i;
    using opmask_t = __m256i;
    static const uint8_t numlanes = 8;
#ifdef XSS_MINIMAL_NETWORK_SORT
    static constexpr int network_sort_threshold = numlanes;
#else
    static constexpr int network_sort_threshold = 256;
#endif
    static constexpr int partition_unroll_factor = 4;

    using swizzle_ops = avx2_32bit_swizzle_ops;

    static type_t type_max() { return X86_SIMD_SORT_INFINITYF; }
    static type_t type_min() { return -X86_SIMD_SORT_INFINITYF; }
    static reg_t zmm_max() { return _mm256_set1_ps(type_max()); }

    static ymmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7,
                       int v8) {
        return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
    }

    static reg_t maskz_loadu(opmask_t mask, void const *mem) {
        return _mm256_maskload_ps((const float *)mem, mask);
    }
    static opmask_t ge(reg_t x, reg_t y) {
        return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_GE_OQ));
    }
    static opmask_t gt(reg_t x, reg_t y) {
        return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_GT_OQ));
    }
    static opmask_t eq(reg_t x, reg_t y) {
        return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_EQ_OQ));
    }
    static opmask_t get_partial_loadmask(uint64_t num_to_read) {
        auto mask = ((0x1ull << num_to_read) - 0x1ull);
        return convert_int_to_avx2_mask(mask);
    }
    static int32_t convert_mask_to_int(opmask_t mask) {
        return convert_avx2_mask_to_int(mask);
    }
    template <int type>
    static opmask_t fpclass(reg_t x) {
        if constexpr (type == (0x01 | 0x80)) {
            return _mm256_castps_si256(_mm256_cmp_ps(x, x, _CMP_UNORD_Q));
        } else {
            static_assert(type == (0x01 | 0x80), "should not reach here");
        }
    }
    template <int scale>
    static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index,
                                void const *base) {
        return _mm256_mask_i32gather_ps(src, base, index,
                                        _mm256_castsi256_ps(mask), scale);
    }
    template <int scale>
    static reg_t i64gather(__m256i index, void const *base) {
        return _mm256_i32gather_ps((float *)base, index, scale);
    }
    static reg_t loadu(void const *mem) {
        return _mm256_loadu_ps((float const *)mem);
    }
    static reg_t max(reg_t x, reg_t y) { return _mm256_max_ps(x, y); }
    static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) {
        return avx2_emu_mask_compressstoreu32<type_t>(mem, mask, x);
    }
    static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) {
        reg_t dst = _mm256_maskload_ps((type_t *)mem, mask);
        return mask_mov(x, mask, dst);
    }
    static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) {
        return _mm256_blendv_ps(x, y, _mm256_castsi256_ps(mask));
    }
    static void mask_storeu(void *mem, opmask_t mask, reg_t x) {
        return _mm256_maskstore_ps((type_t *)mem, mask, x);
    }
    static reg_t min(reg_t x, reg_t y) { return _mm256_min_ps(x, y); }
    static reg_t permutexvar(__m256i idx, reg_t ymm) {
        return _mm256_permutevar8x32_ps(ymm, idx);
    }
    static reg_t permutevar(reg_t ymm, __m256i idx) {
        return _mm256_permutevar8x32_ps(ymm, idx);
    }
    static reg_t reverse(reg_t ymm) {
        const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
        return permutexvar(rev_index, ymm);
    }
    static type_t reducemax(reg_t v) {
        return avx2_emu_reduce_max32<type_t>(v);
    }
    static type_t reducemin(reg_t v) {
        return avx2_emu_reduce_min32<type_t>(v);
    }
    static reg_t set1(type_t v) { return _mm256_set1_ps(v); }
    template <uint8_t mask>
    static reg_t shuffle(reg_t ymm) {
        return _mm256_castsi256_ps(
            _mm256_shuffle_epi32(_mm256_castps_si256(ymm), mask));
    }
    static void storeu(void *mem, reg_t x) {
        _mm256_storeu_ps((float *)mem, x);
    }
    static reg_t sort_vec(reg_t x) {
        return sort_ymm_32bit<avx2_vector<type_t>>(x);
    }
    static reg_t cast_from(__m256i v) { return _mm256_castsi256_ps(v); }
    static __m256i cast_to(reg_t v) { return _mm256_castps_si256(v); }
    static int double_compressstore(type_t *left_addr, type_t *right_addr,
                                    opmask_t k, reg_t reg) {
        return avx2_double_compressstore32<type_t>(left_addr, right_addr, k,
                                                   reg);
    }
};

struct avx2_32bit_swizzle_ops {
    template <typename vtype, int scale>
    X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(
        typename vtype::reg_t reg) {
        __m256i v = vtype::cast_to(reg);

        if constexpr (scale == 2) {
            __m256 vf = _mm256_castsi256_ps(v);
            vf = _mm256_permute_ps(vf, 0b10110001);
            v = _mm256_castps_si256(vf);
        } else if constexpr (scale == 4) {
            __m256 vf = _mm256_castsi256_ps(v);
            vf = _mm256_permute_ps(vf, 0b01001110);
            v = _mm256_castps_si256(vf);
        } else if constexpr (scale == 8) {
            v = _mm256_permute2x128_si256(v, v, 0b00000001);
        } else {
            static_assert(scale == -1, "should not be reached");
        }

        return vtype::cast_from(v);
    }

    template <typename vtype, int scale>
    X86_SIMD_SORT_INLINE typename vtype::reg_t reverse_n(
        typename vtype::reg_t reg) {
        __m256i v = vtype::cast_to(reg);

        if constexpr (scale == 2) {
            return swap_n<vtype, 2>(reg);
        } else if constexpr (scale == 4) {
            constexpr uint64_t mask = 0b00011011;
            __m256 vf = _mm256_castsi256_ps(v);
            vf = _mm256_permute_ps(vf, mask);
            v = _mm256_castps_si256(vf);
        } else if constexpr (scale == 8) {
            return vtype::reverse(reg);
        } else {
            static_assert(scale == -1, "should not be reached");
        }

        return vtype::cast_from(v);
    }

    template <typename vtype, int scale>
    X86_SIMD_SORT_INLINE typename vtype::reg_t merge_n(
        typename vtype::reg_t reg, typename vtype::reg_t other) {
        __m256i v1 = vtype::cast_to(reg);
        __m256i v2 = vtype::cast_to(other);

        if constexpr (scale == 2) {
            v1 = _mm256_blend_epi32(v1, v2, 0b01010101);
        } else if constexpr (scale == 4) {
            v1 = _mm256_blend_epi32(v1, v2, 0b00110011);
        } else if constexpr (scale == 8) {
            v1 = _mm256_blend_epi32(v1, v2, 0b00001111);
        } else {
            static_assert(scale == -1, "should not be reached");
        }

        return vtype::cast_from(v1);
    }
};

#endif  // AVX2_QSORT_32BIT
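
An aside on the header above: each cmp_merge step in sort_ymm_32bit pairs every lane with a shuffled or permuted partner and then keeps one of min/max per lane according to the mask. The sketch below is a scalar analogue for intuition only; it is not part of the commit, and it assumes the usual x86-simd-sort definition of cmp_merge (lane-wise minimum, with masked lanes replaced by the lane-wise maximum), which lives in xss-common-qsort.h rather than in this file.

    #include <algorithm>
    #include <array>
    #include <cstdint>

    using vec8 = std::array<int32_t, 8>;

    // Scalar analogue of one cmp_merge step: lanes whose mask bit is set keep
    // the larger of the two inputs, the remaining lanes keep the smaller one.
    static vec8 cmp_merge_scalar(const vec8 &a, const vec8 &b, uint8_t mask) {
        vec8 out{};
        for (int i = 0; i < 8; i++) {
            out[i] = ((mask >> i) & 1) ? std::max(a[i], b[i])
                                       : std::min(a[i], b[i]);
        }
        return out;
    }

    // First stage of the network above, expressed with the scalar analogue:
    // partner each lane with its neighbour (0<->1, 2<->3, ...) and use mask
    // 0xAA (the oxAA constant) so odd lanes end up with the larger value of
    // each pair and even lanes with the smaller one.
    static vec8 stage1(const vec8 &v) {
        vec8 swapped = {v[1], v[0], v[3], v[2], v[5], v[4], v[7], v[6]};
        return cmp_merge_scalar(v, swapped, 0xAA);
    }

The remaining stages follow the same pattern with wider partners (the NETWORK_32BIT_AVX2_* permutations) and the oxCC/oxF0 masks, which together realize the 8-element bitonic sorting network referenced in the file's comments.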
