2020from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
2121
2222
23+ def _expected_largest_inds (inp , n , shift , k ):
24+ "Computed expected top_k indices for mode='largest'"
25+ assert k < n
26+ ones_start_id = shift % (2 * n )
27+
28+ alloc_dev = inp .device
29+
30+ if ones_start_id < n :
31+ expected_inds = dpt .arange (
32+ ones_start_id , ones_start_id + k , dtype = "i8" , device = alloc_dev
33+ )
34+ else :
35+ # wrap-around
36+ ones_end_id = (ones_start_id + n ) % (2 * n )
37+ if ones_end_id >= k :
38+ expected_inds = dpt .arange (k , dtype = "i8" , device = alloc_dev )
39+ else :
40+ expected_inds = dpt .concat (
41+ (
42+ dpt .arange (ones_end_id , dtype = "i8" , device = alloc_dev ),
43+ dpt .arange (
44+ ones_start_id ,
45+ ones_start_id + k - ones_end_id ,
46+ dtype = "i8" ,
47+ device = alloc_dev ,
48+ ),
49+ )
50+ )
51+
52+ return expected_inds
53+
54+
2355@pytest .mark .parametrize (
2456 "dtype" ,
2557 [
3870 "c16" ,
3971 ],
4072)
41- @pytest .mark .parametrize ("n" , [33 , 255 , 511 , 1021 , 8193 ])
42- def test_topk_1d_largest (dtype , n ):
73+ @pytest .mark .parametrize ("n" , [33 , 43 , 255 , 511 , 1021 , 8193 ])
74+ def test_top_k_1d_largest (dtype , n ):
4375 q = get_queue_or_skip ()
4476 skip_if_dtype_not_supported (dtype , q )
4577
78+ shift , k = 734 , 5
4679 o = dpt .ones (n , dtype = dtype )
4780 z = dpt .zeros (n , dtype = dtype )
48- zo = dpt .concat ((o , z ))
49- inp = dpt .roll (zo , 734 )
50- k = 5
81+ oz = dpt .concat ((o , z ))
82+ inp = dpt .roll (oz , shift )
83+
84+ expected_inds = _expected_largest_inds (oz , n , shift , k )
5185
5286 s = dpt .top_k (inp , k , mode = "largest" )
5387 assert s .values .shape == (k ,)
5488 assert s .values .dtype == inp .dtype
5589 assert s .indices .shape == (k ,)
56- assert dpt .all (s .values == dpt .ones (k , dtype = dtype ))
57- assert dpt .all (s .values == inp [s .indices ])
90+ assert dpt .all (s .indices == expected_inds )
91+ assert dpt .all (s .values == dpt .ones (k , dtype = dtype )), s .values
92+ assert dpt .all (s .values == inp [s .indices ]), s .indices
93+
94+
95+ def _expected_smallest_inds (inp , n , shift , k ):
96+ "Computed expected top_k indices for mode='smallest'"
97+ assert k < n
98+ zeros_start_id = (n + shift ) % (2 * n )
99+ zeros_end_id = (shift ) % (2 * n )
100+
101+ alloc_dev = inp .device
102+
103+ if zeros_start_id < zeros_end_id :
104+ expected_inds = dpt .arange (
105+ zeros_start_id , zeros_start_id + k , dtype = "i8" , device = alloc_dev
106+ )
107+ else :
108+ if zeros_end_id >= k :
109+ expected_inds = dpt .arange (k , dtype = "i8" , device = alloc_dev )
110+ else :
111+ expected_inds = dpt .concat (
112+ (
113+ dpt .arange (zeros_end_id , dtype = "i8" , device = alloc_dev ),
114+ dpt .arange (
115+ zeros_start_id ,
116+ zeros_start_id + k - zeros_end_id ,
117+ dtype = "i8" ,
118+ device = alloc_dev ,
119+ ),
120+ )
121+ )
122+
123+ return expected_inds
58124
59125
60126@pytest .mark .parametrize (
@@ -75,41 +141,80 @@ def test_topk_1d_largest(dtype, n):
75141 "c16" ,
76142 ],
77143)
78- @pytest .mark .parametrize ("n" , [33 , 255 , 257 , 513 , 1021 , 8193 ])
79- def test_topk_1d_smallest (dtype , n ):
144+ @pytest .mark .parametrize ("n" , [37 , 39 , 61 , 255 , 257 , 513 , 1021 , 8193 ])
145+ def test_top_k_1d_smallest (dtype , n ):
80146 q = get_queue_or_skip ()
81147 skip_if_dtype_not_supported (dtype , q )
82148
149+ shift , k = 734 , 5
83150 o = dpt .ones (n , dtype = dtype )
84151 z = dpt .zeros (n , dtype = dtype )
85- zo = dpt .concat ((o , z ))
86- inp = dpt .roll (zo , 734 )
87- k = 5
152+ oz = dpt .concat ((o , z ))
153+ inp = dpt .roll (oz , shift )
154+
155+ expected_inds = _expected_smallest_inds (oz , n , shift , k )
88156
89157 s = dpt .top_k (inp , k , mode = "smallest" )
90158 assert s .values .shape == (k ,)
91159 assert s .values .dtype == inp .dtype
92160 assert s .indices .shape == (k ,)
93- assert dpt .all (s .values == dpt .zeros (k , dtype = dtype ))
94- assert dpt .all (s .values == inp [s .indices ])
161+ assert dpt .all (s .indices == expected_inds )
162+ assert dpt .all (s .values == dpt .zeros (k , dtype = dtype )), s .values
163+ assert dpt .all (s .values == inp [s .indices ]), s .indices
95164
96165
97166# triage failing top k radix implementation on CPU
98167# replicates from Python behavior of radix sort topk implementation
99- @pytest .mark .parametrize ("n" , [33 , 255 , 511 , 1021 , 8193 ])
100- def test_topk_largest_1d_radix_i1_255 (n ):
168+ @pytest .mark .parametrize (
169+ "n" ,
170+ [
171+ 33 ,
172+ 34 ,
173+ 35 ,
174+ 36 ,
175+ 37 ,
176+ 38 ,
177+ 39 ,
178+ 40 ,
179+ 41 ,
180+ 42 ,
181+ 43 ,
182+ 44 ,
183+ 45 ,
184+ 46 ,
185+ 47 ,
186+ 48 ,
187+ 49 ,
188+ 50 ,
189+ 61 ,
190+ 137 ,
191+ 255 ,
192+ 511 ,
193+ 1021 ,
194+ 8193 ,
195+ ],
196+ )
197+ def test_top_k_largest_1d_radix_i1 (n ):
101198 get_queue_or_skip ()
102199 dt = "i1"
103200
201+ shift , k = 734 , 5
104202 o = dpt .ones (n , dtype = dt )
105203 z = dpt .zeros (n , dtype = dt )
106- zo = dpt .concat ((o , z ))
107- inp = dpt .roll (zo , 734 )
108- k = 5
109-
110- sorted = dpt .copy (dpt .sort (inp , descending = True , kind = "radixsort" )[:k ])
111- argsorted = dpt .copy (
112- dpt .argsort (inp , descending = True , kind = "radixsort" )[:k ]
113- )
114- assert dpt .all (sorted == dpt .ones (k , dtype = dt ))
115- assert dpt .all (sorted == inp [argsorted ])
204+ oz = dpt .concat ((o , z ))
205+ inp = dpt .roll (oz , shift )
206+
207+ expected_inds = _expected_largest_inds (oz , n , shift , k )
208+
209+ sorted_v = dpt .sort (inp , descending = True , kind = "radixsort" )
210+ argsorted = dpt .argsort (inp , descending = True , kind = "radixsort" )
211+
212+ assert dpt .all (sorted_v == inp [argsorted ])
213+
214+ topk_vals = dpt .copy (sorted_v [:k ])
215+ topk_inds = dpt .copy (argsorted [:k ])
216+
217+ assert dpt .all (topk_vals == dpt .ones (k , dtype = dt ))
218+ assert dpt .all (topk_inds == expected_inds )
219+
220+ assert dpt .all (topk_vals == inp [topk_inds ]), topk_inds
0 commit comments