@@ -1748,20 +1748,19 @@ radix_sort_axis1_contig_impl(sycl::queue &exec_q,
17481748}
17491749
17501750template <typename ValueT, typename IndexT>
1751- class populate_indexed_data_for_radix_sort_krn ;
1751+ class radix_argsort_index_write_out_krn ;
17521752
1753- template <typename ValueT, typename IndexT>
1754- class index_write_out_for_radix_sort_krn ;
1753+ template <typename ValueT, typename IndexT> class radix_argsort_iota_krn ;
17551754
17561755template <typename argTy, typename IndexTy>
17571756sycl::event
17581757radix_argsort_axis1_contig_impl (sycl::queue &exec_q,
17591758 const bool sort_ascending,
1760- // number of sub-arrays to sort (num. of rows in
1761- // a matrix when sorting over rows)
1759+ // number of sub-arrays to sort (num. of
1760+ // rows in a matrix when sorting over rows)
17621761 size_t iter_nelems,
1763- // size of each array to sort (length of rows,
1764- // i.e. number of columns)
1762+ // size of each array to sort (length of
1763+ // rows, i.e. number of columns)
17651764 size_t sort_nelems,
17661765 const char *arg_cp,
17671766 char *res_cp,
@@ -1776,90 +1775,6 @@ radix_argsort_axis1_contig_impl(sycl::queue &exec_q,
17761775 IndexTy *res_tp =
17771776 reinterpret_cast <IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;
17781777
1779- using ValueIndexT = std::pair<argTy, IndexTy>;
1780-
1781- const std::size_t total_nelems = iter_nelems * sort_nelems;
1782- const std::size_t padded_total_nelems = ((total_nelems + 63 ) / 64 ) * 64 ;
1783- ValueIndexT *workspace = sycl::malloc_device<ValueIndexT>(
1784- padded_total_nelems + total_nelems, exec_q);
1785-
1786- if (nullptr == workspace) {
1787- throw std::runtime_error (" Could not allocate workspace on device" );
1788- }
1789-
1790- ValueIndexT *indexed_data_tp = workspace;
1791- ValueIndexT *temp_tp = workspace + padded_total_nelems;
1792-
1793- using Proj = radix_sort_details::ValueProj<argTy, IndexTy>;
1794- constexpr Proj proj_op{};
1795-
1796- sycl::event populate_indexed_data_ev =
1797- exec_q.submit ([&](sycl::handler &cgh) {
1798- cgh.depends_on (depends);
1799-
1800- using KernelName =
1801- populate_indexed_data_for_radix_sort_krn<argTy, IndexTy>;
1802-
1803- cgh.parallel_for <KernelName>(
1804- sycl::range<1 >(total_nelems), [=](sycl::id<1 > id) {
1805- size_t i = id[0 ];
1806- IndexTy sort_id = static_cast <IndexTy>(i % sort_nelems);
1807- indexed_data_tp[i] = std::make_pair (arg_tp[i], sort_id);
1808- });
1809- });
1810-
1811- sycl::event radix_sort_ev =
1812- radix_sort_details::parallel_radix_sort_impl<ValueIndexT, Proj>(
1813- exec_q, iter_nelems, sort_nelems, indexed_data_tp, temp_tp, proj_op,
1814- sort_ascending, {populate_indexed_data_ev});
1815-
1816- sycl::event write_out_ev = exec_q.submit ([&](sycl::handler &cgh) {
1817- cgh.depends_on (radix_sort_ev);
1818-
1819- using KernelName = index_write_out_for_radix_sort_krn<argTy, IndexTy>;
1820-
1821- cgh.parallel_for <KernelName>(
1822- sycl::range<1 >(total_nelems),
1823- [=](sycl::id<1 > id) { res_tp[id] = std::get<1 >(temp_tp[id]); });
1824- });
1825-
1826- sycl::event cleanup_ev = exec_q.submit ([&](sycl::handler &cgh) {
1827- cgh.depends_on (write_out_ev);
1828-
1829- const sycl::context &ctx = exec_q.get_context ();
1830-
1831- using dpctl::tensor::alloc_utils::sycl_free_noexcept;
1832- cgh.host_task ([ctx, workspace] { sycl_free_noexcept (workspace, ctx); });
1833- });
1834-
1835- return cleanup_ev;
1836- }
1837-
1838- template <typename ValueT, typename IndexT> class iota_for_radix_sort_krn ;
1839-
1840- template <typename argTy, typename IndexTy>
1841- sycl::event
1842- radix_argsort_axis1_contig_alt_impl (sycl::queue &exec_q,
1843- const bool sort_ascending,
1844- // number of sub-arrays to sort (num. of
1845- // rows in a matrix when sorting over rows)
1846- size_t iter_nelems,
1847- // size of each array to sort (length of
1848- // rows, i.e. number of columns)
1849- size_t sort_nelems,
1850- const char *arg_cp,
1851- char *res_cp,
1852- ssize_t iter_arg_offset,
1853- ssize_t iter_res_offset,
1854- ssize_t sort_arg_offset,
1855- ssize_t sort_res_offset,
1856- const std::vector<sycl::event> &depends)
1857- {
1858- const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
1859- iter_arg_offset + sort_arg_offset;
1860- IndexTy *res_tp =
1861- reinterpret_cast <IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;
1862-
18631778 const std::size_t total_nelems = iter_nelems * sort_nelems;
18641779 const std::size_t padded_total_nelems = ((total_nelems + 63 ) / 64 ) * 64 ;
18651780 IndexTy *workspace = sycl::malloc_device<IndexTy>(
@@ -1877,7 +1792,7 @@ radix_argsort_axis1_contig_alt_impl(sycl::queue &exec_q,
18771792 sycl::event iota_ev = exec_q.submit ([&](sycl::handler &cgh) {
18781793 cgh.depends_on (depends);
18791794
1880- using KernelName = iota_for_radix_sort_krn <argTy, IndexTy>;
1795+ using KernelName = radix_argsort_iota_krn <argTy, IndexTy>;
18811796
18821797 cgh.parallel_for <KernelName>(
18831798 sycl::range<1 >(total_nelems), [=](sycl::id<1 > id) {
@@ -1895,7 +1810,7 @@ radix_argsort_axis1_contig_alt_impl(sycl::queue &exec_q,
18951810 sycl::event map_back_ev = exec_q.submit ([&](sycl::handler &cgh) {
18961811 cgh.depends_on (radix_sort_ev);
18971812
1898- using KernelName = index_write_out_for_radix_sort_krn <argTy, IndexTy>;
1813+ using KernelName = radix_argsort_index_write_out_krn <argTy, IndexTy>;
18991814
19001815 cgh.parallel_for <KernelName>(
19011816 sycl::range<1 >(total_nelems), [=](sycl::id<1 > id) {
0 commit comments