From 31d9c82693b56077f9d8193630465b911603b986 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 9 Dec 2024 21:26:07 -0800 Subject: [PATCH 01/18] Add kernel for nan_to_num to ufunc extension --- dpnp/backend/extensions/ufunc/CMakeLists.txt | 1 + .../ufunc/elementwise_functions/common.cpp | 2 + .../elementwise_functions/nan_to_num.cpp | 300 ++++++++++++++++++ .../elementwise_functions/nan_to_num.hpp | 35 ++ .../elementwise_functions/nan_to_num.hpp | 130 ++++++++ dpnp/dpnp_iface_mathematical.py | 38 +-- 6 files changed, 487 insertions(+), 19 deletions(-) create mode 100644 dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp create mode 100644 dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.hpp create mode 100644 dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp diff --git a/dpnp/backend/extensions/ufunc/CMakeLists.txt b/dpnp/backend/extensions/ufunc/CMakeLists.txt index d45bfa822e5d..5f892506b81c 100644 --- a/dpnp/backend/extensions/ufunc/CMakeLists.txt +++ b/dpnp/backend/extensions/ufunc/CMakeLists.txt @@ -38,6 +38,7 @@ set(_elementwise_sources ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/nan_to_num.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/radians.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/sinc.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/spacing.cpp diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp index 43a68e487cf4..f9d179d5ca4e 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp @@ -38,6 +38,7 @@ #include "lcm.hpp" #include "ldexp.hpp" #include "logaddexp2.hpp" +#include "nan_to_num.hpp" #include "radians.hpp" #include "sinc.hpp" #include "spacing.hpp" @@ -64,6 +65,7 @@ void init_elementwise_functions(py::module_ m) init_lcm(m); init_ldexp(m); init_logaddexp2(m); + init_nan_to_num(m); init_radians(m); init_sinc(m); init_spacing(m); diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp new file mode 100644 index 000000000000..c4705e5b7816 --- /dev/null +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp @@ -0,0 +1,300 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +#include + +#include "dpctl4pybind11.hpp" +#include +#include +#include + +#include "kernels/elementwise_functions/nan_to_num.hpp" + +#include "../../elementwise_functions/simplify_iteration_space.hpp" + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/output_validation.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +// declare pybind11 wrappers in py_internal namespace +namespace dpnp::extensions::ufunc +{ + +namespace impl +{ +typedef sycl::event (*nan_to_num_fn_ptr_t)(sycl::queue &, + int, + size_t, + py::ssize_t *, + const py::object &, + const py::object &, + const py::object &, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &); + +template +sycl::event nan_to_num_call(sycl::queue &exec_q, + int nd, + size_t nelems, + py::ssize_t *shape_strides, + const py::object &py_nan, + const py::object &py_posinf, + const py::object &py_neginf, + const char *arg_p, + py::ssize_t arg_offset, + char *dst_p, + py::ssize_t dst_offset, + const std::vector &depends) +{ + sycl::event to_num_ev; + + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using realT = typename T::value_type; + realT nan_v = py::cast(py_nan); + realT posinf_v = py::cast(py_posinf); + realT neginf_v = py::cast(py_neginf); + + using dpnp::kernels::nan_to_num::nan_to_num_impl; + to_num_ev = nan_to_num_impl( + exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p, + arg_offset, dst_p, dst_offset, depends); + } + else { + T nan_v = py::cast(py_nan); + T posinf_v = py::cast(py_posinf); + T neginf_v = py::cast(py_neginf); + + using dpnp::kernels::nan_to_num::nan_to_num_impl; + to_num_ev = nan_to_num_impl( + exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p, + arg_offset, dst_p, dst_offset, depends); + } + return to_num_ev; +} + +namespace td_ns = dpctl::tensor::type_dispatch; +nan_to_num_fn_ptr_t nan_to_num_dispatch_vector[td_ns::num_types]; + +std::pair + py_nan_to_num(const dpctl::tensor::usm_ndarray &src, + const py::object &py_nan, + const py::object &py_posinf, + const py::object &py_neginf, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &q, + const std::vector &depends) +{ + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + if (src_typeid != dst_typeid) { + throw py::value_error("Array data types are not the same."); + } + + if (!dpctl::utils::queues_are_compatible(q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + 
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + int src_nd = src.get_ndim(); + if (src_nd != dst.get_ndim()) { + throw py::value_error("Array dimensions are not the same."); + } + + const py::ssize_t *src_shape = src.get_shape_raw(); + const py::ssize_t *dst_shape = dst.get_shape_raw(); + + bool shapes_equal(true); + size_t nelems(1); + for (int i = 0; i < src_nd; ++i) { + nelems *= static_cast(src_shape[i]); + shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]); + } + if (!shapes_equal) { + throw py::value_error("Array shapes are not the same."); + } + + // if nelems is zero, return + if (nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, nelems); + + // check memory overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + auto const &same_logical_tensors = + dpctl::tensor::overlap::SameLogicalTensors(); + if (overlap(src, dst) && !same_logical_tensors(src, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + const char *src_data = src.get_data(); + char *dst_data = dst.get_data(); + + auto const &src_strides = src.get_strides_vector(); + auto const &dst_strides = dst.get_strides_vector(); + + using shT = std::vector; + shT simplified_shape; + shT simplified_src_strides; + shT simplified_dst_strides; + py::ssize_t src_offset(0); + py::ssize_t dst_offset(0); + + int nd = src_nd; + const py::ssize_t *shape = src_shape; + + py_internal::simplify_iteration_space( + nd, shape, src_strides, dst_strides, + // output + simplified_shape, simplified_src_strides, simplified_dst_strides, + src_offset, dst_offset); + + auto fn = nan_to_num_dispatch_vector[src_typeid]; + + if (fn == nullptr) { + throw std::runtime_error( + "nan_to_num implementation is missing for src_typeid=" + + std::to_string(src_typeid)); + } + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + + std::vector host_tasks{}; + host_tasks.reserve(2); + + const auto &ptr_size_event_triple_ = device_allocate_and_pack( + q, host_tasks, simplified_shape, simplified_src_strides, + simplified_dst_strides); + py::ssize_t *shape_strides = std::get<0>(ptr_size_event_triple_); + const sycl::event ©_shape_ev = std::get<2>(ptr_size_event_triple_); + + if (shape_strides == nullptr) { + throw std::runtime_error("Device memory allocation failed"); + } + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shape_ev); + + sycl::event comp_ev = + fn(q, nelems, nd, shape_strides, py_nan, py_posinf, py_neginf, src_data, + src_offset, dst_data, dst_offset, all_deps); + + // async free of shape_strides temporary + auto ctx = q.get_context(); + sycl::event tmp_cleanup_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(comp_ev); + using dpctl::tensor::alloc_utils::sycl_free_noexcept; + cgh.host_task( + [ctx, shape_strides]() { sycl_free_noexcept(shape_strides, ctx); }); + }); + host_tasks.push_back(tmp_cleanup_ev); + + return std::make_pair( + dpctl::utils::keep_args_alive(q, {src, dst}, host_tasks), comp_ev); +} + +namespace py_int = dpnp::extensions::py_internal; + +/** + * @brief A factory to define pairs of supported types for which + * nan_to_num_call function is available. + * + * @tparam T Type of input vector `a` and of result vector `y`. 
+ */ +template +struct NanToNumOutputType +{ + using value_type = typename std::disjunction< + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::DefaultResultEntry>::result_type; +}; + +template +struct NanToNumFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + return nullptr; + } + else { + using ::dpnp::extensions::ufunc::impl::nan_to_num_call; + return nan_to_num_call; + } + } +}; + +void populate_nan_to_num_dispatch_vector(void) +{ + using namespace td_ns; + + DispatchVectorBuilder dvb; + dvb.populate_dispatch_vector(nan_to_num_dispatch_vector); +} + +} // namespace impl + +void init_nan_to_num(py::module_ m) +{ + { + impl::populate_nan_to_num_dispatch_vector(); + + using impl::py_nan_to_num; + m.def("_nan_to_num", &py_nan_to_num, "", py::arg("src"), + py::arg("py_nan"), py::arg("py_posinf"), py::arg("py_neginf"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + } +} + +} // namespace dpnp::extensions::ufunc diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.hpp b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.hpp new file mode 100644 index 000000000000..26ac37bf1c4e --- /dev/null +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.hpp @@ -0,0 +1,35 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpnp::extensions::ufunc +{ +void init_nan_to_num(py::module_ m); +} // namespace dpnp::extensions::ufunc diff --git a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp new file mode 100644 index 000000000000..b7b7a0d7f989 --- /dev/null +++ b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp @@ -0,0 +1,130 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include +#include +#include + +#include +// dpctl tensor headers +#include "kernels/dpctl_tensor_types.hpp" +#include "utils/offset_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpnp::kernels::nan_to_num +{ + +template +T to_num(const T v, const T nan, const T posinf, const T neginf) +{ + return (sycl::isnan(v)) ? nan + : (sycl::isinf(v)) ? (v > 0) ? posinf : neginf + : v; +} + +template +struct NanToNumFunctor +{ +public: + NanToNumFunctor(const T *inp, + T *out, + const InOutIndexerT &inp_out_indexer, + const scT nan, + const scT posinf, + const scT neginf) + : inp_(inp), out_(out), inp_out_indexer_(inp_out_indexer), nan_(nan), + posinf_(posinf), neginf_(neginf) + { + } + + void operator()(sycl::id<1> wid) const + { + const auto &offsets_ = inp_out_indexer_(wid.get(0)); + const dpctl::tensor::ssize_t &inp_offset = offsets_.get_first_offset(); + const dpctl::tensor::ssize_t &out_offset = offsets_.get_second_offset(); + + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using realT = typename T::value_type; + static_assert(std::is_same_v); + T z = inp_[inp_offset]; + realT x = to_num(z.real(), nan_, posinf_, neginf_); + realT y = to_num(z.imag(), nan_, posinf_, neginf_); + out_[out_offset] = T{x, y}; + } + else { + out_[out_offset] = to_num(inp_[inp_offset], nan_, posinf_, neginf_); + } + } + +private: + const T *inp_ = nullptr; + T *out_ = nullptr; + const InOutIndexerT inp_out_indexer_; + const scT nan_; + const scT posinf_; + const scT neginf_; +}; + +template +class NanToNumKernel; + +template +sycl::event nan_to_num_impl(sycl::queue &q, + size_t nelems, + int nd, + const ssize_t *shape_strides, + const scT nan, + const scT posinf, + const scT neginf, + const char *in_cp, + dpctl::tensor::ssize_t in_offset, + char *out_cp, + dpctl::tensor::ssize_t out_offset, + const std::vector &depends) +{ + dpctl::tensor::type_utils::validate_type_for_device(q); + + const T *in_tp = reinterpret_cast(in_cp); + T *out_tp = reinterpret_cast(out_cp); + + using InOutIndexerT = + typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + const InOutIndexerT 
indexer{nd, in_offset, out_offset, shape_strides}; + + sycl::event comp_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using KernelName = NanToNumKernel; + cgh.parallel_for( + {nelems}, NanToNumFunctor( + in_tp, out_tp, indexer, nan, posinf, neginf)); + }); + return comp_ev; +} + +} // namespace dpnp::kernels::nan_to_num diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index cf3d14b98de9..880c974976fc 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -3125,21 +3125,11 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): "nan must be a scalar of an integer, float, bool, " f"but got {type(nan)}" ) - - out = dpnp.empty_like(x) if copy else x x_type = x.dtype.type if not issubclass(x_type, dpnp.inexact): - return x + return x.copy() if copy else x - parts = ( - (x.real, x.imag) if issubclass(x_type, dpnp.complexfloating) else (x,) - ) - parts_out = ( - (out.real, out.imag) - if issubclass(x_type, dpnp.complexfloating) - else (out,) - ) max_f, min_f = _get_max_min(x.real.dtype) if posinf is not None: if not isinstance(posinf, (int, float)): @@ -3156,16 +3146,26 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): ) min_f = neginf - for part, part_out in zip(parts, parts_out): - nan_mask = dpnp.isnan(part) - posinf_mask = dpnp.isposinf(part) - neginf_mask = dpnp.isneginf(part) + if copy: + out = dpnp.empty_like(x) + else: + if not x.flags.writable: + raise ValueError("copy is required for read-only array `x`") + out = x + + x_ary = dpnp.get_usm_ndarray(x) + out_ary = dpnp.get_usm_ndarray(out) + + q = x.sycl_queue + _manager = dpu.SequentialOrderManager[q] + + h_ev, comp_ev = ufi._nan_to_num( + x_ary, nan, max_f, min_f, out_ary, q, depends=_manager.submitted_events + ) - part = dpnp.where(nan_mask, nan, part, out=part_out) - part = dpnp.where(posinf_mask, max_f, part, out=part_out) - part = dpnp.where(neginf_mask, min_f, part, out=part_out) + _manager.add_event_pair(h_ev, comp_ev) - return out + return dpnp.get_result_array(out_ary) if copy else x _NEGATIVE_DOCSTRING = """ From 5ad83b05d8755ddf2f99e6c7483b648322b5babd Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 9 Dec 2024 21:26:35 -0800 Subject: [PATCH 02/18] Add missing headers in nan_to_num.cpp --- .../ufunc/elementwise_functions/nan_to_num.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp index c4705e5b7816..3b88a4841a88 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp @@ -23,7 +23,14 @@ // THE POSSIBILITY OF SUCH DAMAGE. 
//***************************************************************************** +#include +#include #include +#include +#include +#include +#include +#include #include @@ -145,12 +152,8 @@ std::pair const py::ssize_t *src_shape = src.get_shape_raw(); const py::ssize_t *dst_shape = dst.get_shape_raw(); - bool shapes_equal(true); - size_t nelems(1); - for (int i = 0; i < src_nd; ++i) { - nelems *= static_cast(src_shape[i]); - shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]); - } + size_t nelems = src.get_size(); + bool shapes_equal = std::equal(src_shape, src_shape + src_nd, dst_shape); if (!shapes_equal) { throw py::value_error("Array shapes are not the same."); } From d99c0d52a69cdf4591cdd98801a111dcf1e32802 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 9 Dec 2024 21:27:26 -0800 Subject: [PATCH 03/18] Add contiguous kernel for nan_to_num --- .../elementwise_functions/nan_to_num.cpp | 130 +++++++++++++++++- .../elementwise_functions/nan_to_num.hpp | 37 +++++ 2 files changed, 163 insertions(+), 4 deletions(-) diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp index 3b88a4841a88..2026fe5543dd 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp @@ -114,8 +114,54 @@ sycl::event nan_to_num_call(sycl::queue &exec_q, return to_num_ev; } +typedef sycl::event (*nan_to_num_contig_fn_ptr_t)( + sycl::queue &, + size_t, + const py::object &, + const py::object &, + const py::object &, + const char *, + char *, + const std::vector &); + +template +sycl::event nan_to_num_contig_call(sycl::queue &exec_q, + size_t nelems, + const py::object &py_nan, + const py::object &py_posinf, + const py::object &py_neginf, + const char *arg_p, + char *dst_p, + const std::vector &depends) +{ + sycl::event to_num_contig_ev; + + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using realT = typename T::value_type; + realT nan_v = py::cast(py_nan); + realT posinf_v = py::cast(py_posinf); + realT neginf_v = py::cast(py_neginf); + + using dpnp::kernels::nan_to_num::nan_to_num_contig_impl; + to_num_contig_ev = nan_to_num_contig_impl( + exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends); + } + else { + T nan_v = py::cast(py_nan); + T posinf_v = py::cast(py_posinf); + T neginf_v = py::cast(py_neginf); + + using dpnp::kernels::nan_to_num::nan_to_num_contig_impl; + to_num_contig_ev = nan_to_num_contig_impl( + exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends); + } + return to_num_contig_ev; +} + namespace td_ns = dpctl::tensor::type_dispatch; nan_to_num_fn_ptr_t nan_to_num_dispatch_vector[td_ns::num_types]; +nan_to_num_contig_fn_ptr_t nan_to_num_contig_dispatch_vector[td_ns::num_types]; std::pair py_nan_to_num(const dpctl::tensor::usm_ndarray &src, @@ -176,6 +222,37 @@ std::pair const char *src_data = src.get_data(); char *dst_data = dst.get_data(); + // handle contiguous inputs + bool is_src_c_contig = src.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_dst_f_contig = dst.is_f_contiguous(); + + bool both_c_contig = (is_src_c_contig && is_dst_c_contig); + bool both_f_contig = (is_src_f_contig && is_dst_f_contig); + + if (both_c_contig || both_f_contig) { + auto contig_fn = nan_to_num_contig_dispatch_vector[src_typeid]; + + if (contig_fn == nullptr) { + throw 
std::runtime_error( + "Contiguous implementation is missing for src_typeid=" + + std::to_string(src_typeid)); + } + + auto comp_ev = contig_fn(q, nelems, py_nan, py_posinf, py_neginf, + src_data, dst_data, depends); + sycl::event ht_ev = + dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev}); + + return std::make_pair(ht_ev, comp_ev); + } + + // simplify iteration space + // if 1d with strides 1 - input is contig + // dispatch to strided + auto const &src_strides = src.get_strides_vector(); auto const &dst_strides = dst.get_strides_vector(); @@ -195,6 +272,30 @@ std::pair simplified_shape, simplified_src_strides, simplified_dst_strides, src_offset, dst_offset); + if (nd == 1 && simplified_src_strides[0] == 1 && + simplified_dst_strides[0] == 1) { + // Special case of contiguous data + auto contig_fn = nan_to_num_contig_dispatch_vector[src_typeid]; + + if (contig_fn == nullptr) { + throw std::runtime_error( + "Contiguous implementation is missing for src_typeid=" + + std::to_string(src_typeid)); + } + + int src_elem_size = src.get_elemsize(); + int dst_elem_size = dst.get_elemsize(); + auto comp_ev = + contig_fn(q, nelems, py_nan, py_posinf, py_neginf, + src_data + src_elem_size * src_offset, + dst_data + dst_elem_size * dst_offset, depends); + + sycl::event ht_ev = + dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev}); + + return std::make_pair(ht_ev, comp_ev); + } + auto fn = nan_to_num_dispatch_vector[src_typeid]; if (fn == nullptr) { @@ -277,12 +378,33 @@ struct NanToNumFactory } }; -void populate_nan_to_num_dispatch_vector(void) +template +struct NanToNumContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + return nullptr; + } + else { + using ::dpnp::extensions::ufunc::impl::nan_to_num_contig_call; + return nan_to_num_contig_call; + } + } +}; + +void populate_nan_to_num_dispatch_vectors(void) { using namespace td_ns; - DispatchVectorBuilder dvb; - dvb.populate_dispatch_vector(nan_to_num_dispatch_vector); + DispatchVectorBuilder dvb1; + dvb1.populate_dispatch_vector(nan_to_num_dispatch_vector); + + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(nan_to_num_contig_dispatch_vector); } } // namespace impl @@ -290,7 +412,7 @@ void populate_nan_to_num_dispatch_vector(void) void init_nan_to_num(py::module_ m) { { - impl::populate_nan_to_num_dispatch_vector(); + impl::populate_nan_to_num_dispatch_vectors(); using impl::py_nan_to_num; m.def("_nan_to_num", &py_nan_to_num, "", py::arg("src"), diff --git a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp index b7b7a0d7f989..9a4cb8ba537c 100644 --- a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp +++ b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp @@ -127,4 +127,41 @@ sycl::event nan_to_num_impl(sycl::queue &q, return comp_ev; } +template +class NanToNumContigKernel; + +template +sycl::event nan_to_num_contig_impl(sycl::queue &q, + size_t nelems, + const scT nan, + const scT posinf, + const scT neginf, + const char *in_cp, + char *out_cp, + const std::vector &depends) +{ + dpctl::tensor::type_utils::validate_type_for_device(q); + + const T *in_tp = reinterpret_cast(in_cp); + T *out_tp = reinterpret_cast(out_cp); + + using dpctl::tensor::offset_utils::NoOpIndexer; + using InOutIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer; + constexpr NoOpIndexer in_indexer{}; + constexpr NoOpIndexer out_indexer{}; + constexpr InOutIndexerT indexer{in_indexer, out_indexer}; + + sycl::event 
comp_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using KernelName = NanToNumContigKernel; + cgh.parallel_for( + {nelems}, NanToNumFunctor( + in_tp, out_tp, indexer, nan, posinf, neginf)); + }); + return comp_ev; +} + } // namespace dpnp::kernels::nan_to_num From 13cd3b55e069e0aaa6fdae3c07bc497769052b3c Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 10 Dec 2024 00:44:36 -0800 Subject: [PATCH 04/18] Fix missed ssize_t to dpctl::tensor::ssize_t --- dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp index 9a4cb8ba537c..1f605ad6af44 100644 --- a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp +++ b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp @@ -97,7 +97,7 @@ template sycl::event nan_to_num_impl(sycl::queue &q, size_t nelems, int nd, - const ssize_t *shape_strides, + const dpctl::tensor::ssize_t *shape_strides, const scT nan, const scT posinf, const scT neginf, From f4c678214738fc82c1a45e62d5605e711d058e9b Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 10 Dec 2024 10:03:01 -0800 Subject: [PATCH 05/18] Clean-up nan_to_num.cpp dead code --- .../extensions/ufunc/elementwise_functions/nan_to_num.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp index 2026fe5543dd..18d294337296 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp @@ -342,8 +342,6 @@ std::pair dpctl::utils::keep_args_alive(q, {src, dst}, host_tasks), comp_ev); } -namespace py_int = dpnp::extensions::py_internal; - /** * @brief A factory to define pairs of supported types for which * nan_to_num_call function is available. 
@@ -372,7 +370,6 @@ struct NanToNumFactory return nullptr; } else { - using ::dpnp::extensions::ufunc::impl::nan_to_num_call; return nan_to_num_call; } } @@ -388,7 +385,6 @@ struct NanToNumContigFactory return nullptr; } else { - using ::dpnp::extensions::ufunc::impl::nan_to_num_contig_call; return nan_to_num_contig_call; } } From 377281518a954ec57c8ccdc4c9d87a1de866c71a Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 10 Dec 2024 10:05:47 -0800 Subject: [PATCH 06/18] Use dpnp.copy instead of copy method in nan_to_num --- dpnp/dpnp_iface_mathematical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 880c974976fc..f6dd53f442e8 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -3128,7 +3128,7 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): x_type = x.dtype.type if not issubclass(x_type, dpnp.inexact): - return x.copy() if copy else x + return dpnp.copy(x) if copy else x max_f, min_f = _get_max_min(x.real.dtype) if posinf is not None: From ffb1379c8aa124252f10eb83ec6c80f33c8307bf Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 10 Dec 2024 10:06:58 -0800 Subject: [PATCH 07/18] Fix typo in nan_to_num --- dpnp/dpnp_iface_mathematical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index f6dd53f442e8..550c9d23bd22 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -3165,7 +3165,7 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): _manager.add_event_pair(h_ev, comp_ev) - return dpnp.get_result_array(out_ary) if copy else x + return dpnp.get_result_array(out) if copy else x _NEGATIVE_DOCSTRING = """ From 16d4ea204e80ff4f1e9bd4acf287d057e3d59870 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 10 Dec 2024 10:08:34 -0800 Subject: [PATCH 08/18] inline to_num in nan_to_num kernel --- dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp index 1f605ad6af44..1e327b390e0c 100644 --- a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp +++ b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp @@ -39,7 +39,7 @@ namespace dpnp::kernels::nan_to_num { template -T to_num(const T v, const T nan, const T posinf, const T neginf) +inline T to_num(const T v, const T nan, const T posinf, const T neginf) { return (sycl::isnan(v)) ? nan : (sycl::isinf(v)) ? (v > 0) ? 
posinf : neginf From cd13f407cc7b3c0926a02cd1e80006103a409eb8 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 10 Dec 2024 10:09:55 -0800 Subject: [PATCH 09/18] Add additional const qualifiers in nan_to_num impl functions --- .../kernels/elementwise_functions/nan_to_num.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp index 1e327b390e0c..2147b3dc8159 100644 --- a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp +++ b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp @@ -95,16 +95,16 @@ class NanToNumKernel; template sycl::event nan_to_num_impl(sycl::queue &q, - size_t nelems, - int nd, + const size_t nelems, + const int nd, const dpctl::tensor::ssize_t *shape_strides, const scT nan, const scT posinf, const scT neginf, const char *in_cp, - dpctl::tensor::ssize_t in_offset, + const dpctl::tensor::ssize_t in_offset, char *out_cp, - dpctl::tensor::ssize_t out_offset, + const dpctl::tensor::ssize_t out_offset, const std::vector &depends) { dpctl::tensor::type_utils::validate_type_for_device(q); @@ -132,7 +132,7 @@ class NanToNumContigKernel; template sycl::event nan_to_num_contig_impl(sycl::queue &q, - size_t nelems, + const size_t nelems, const scT nan, const scT posinf, const scT neginf, From 533963c7438040178def8be45c735ab0491e3776 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 10 Dec 2024 10:16:09 -0800 Subject: [PATCH 10/18] Use is_complex_v in nan_to_num kernel --- dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp index 2147b3dc8159..a61a871d0f30 100644 --- a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp +++ b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp @@ -67,8 +67,8 @@ struct NanToNumFunctor const dpctl::tensor::ssize_t &inp_offset = offsets_.get_first_offset(); const dpctl::tensor::ssize_t &out_offset = offsets_.get_second_offset(); - using dpctl::tensor::type_utils::is_complex; - if constexpr (is_complex::value) { + using dpctl::tensor::type_utils::is_complex_v; + if constexpr (is_complex_v) { using realT = typename T::value_type; static_assert(std::is_same_v); T z = inp_[inp_offset]; From 99cc211f356f08bdb9d05e84bea93c6d8c4f492f Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 10 Dec 2024 10:33:51 -0800 Subject: [PATCH 11/18] Simplify nan_to_num call logic Use std::conditional and value_type_of_t struct to avoid constexpr branches with redundant code --- .../elementwise_functions/nan_to_num.cpp | 85 +++++++++---------- 1 file changed, 39 insertions(+), 46 deletions(-) diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp index 18d294337296..081bf8c36db8 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp @@ -60,6 +60,22 @@ namespace dpnp::extensions::ufunc namespace impl { + +template +struct value_type_of +{ + using type = T; +}; + +template +struct value_type_of> +{ + using type = T; +}; + +template +using value_type_of_t = typename value_type_of::type; + typedef sycl::event (*nan_to_num_fn_ptr_t)(sycl::queue &, int, size_t, @@ -87,30 +103,18 @@ sycl::event 
nan_to_num_call(sycl::queue &exec_q, py::ssize_t dst_offset, const std::vector &depends) { - sycl::event to_num_ev; - - using dpctl::tensor::type_utils::is_complex; - if constexpr (is_complex::value) { - using realT = typename T::value_type; - realT nan_v = py::cast(py_nan); - realT posinf_v = py::cast(py_posinf); - realT neginf_v = py::cast(py_neginf); - - using dpnp::kernels::nan_to_num::nan_to_num_impl; - to_num_ev = nan_to_num_impl( - exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p, - arg_offset, dst_p, dst_offset, depends); - } - else { - T nan_v = py::cast(py_nan); - T posinf_v = py::cast(py_posinf); - T neginf_v = py::cast(py_neginf); - - using dpnp::kernels::nan_to_num::nan_to_num_impl; - to_num_ev = nan_to_num_impl( - exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p, - arg_offset, dst_p, dst_offset, depends); - } + using dpctl::tensor::type_utils::is_complex_v; + using scT = std::conditional_t, value_type_of_t, T>; + + scT nan_v = py::cast(py_nan); + scT posinf_v = py::cast(py_posinf); + scT neginf_v = py::cast(py_neginf); + + using dpnp::kernels::nan_to_num::nan_to_num_impl; + sycl::event to_num_ev = nan_to_num_impl( + exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p, + arg_offset, dst_p, dst_offset, depends); + return to_num_ev; } @@ -134,28 +138,17 @@ sycl::event nan_to_num_contig_call(sycl::queue &exec_q, char *dst_p, const std::vector &depends) { - sycl::event to_num_contig_ev; - - using dpctl::tensor::type_utils::is_complex; - if constexpr (is_complex::value) { - using realT = typename T::value_type; - realT nan_v = py::cast(py_nan); - realT posinf_v = py::cast(py_posinf); - realT neginf_v = py::cast(py_neginf); - - using dpnp::kernels::nan_to_num::nan_to_num_contig_impl; - to_num_contig_ev = nan_to_num_contig_impl( - exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends); - } - else { - T nan_v = py::cast(py_nan); - T posinf_v = py::cast(py_posinf); - T neginf_v = py::cast(py_neginf); - - using dpnp::kernels::nan_to_num::nan_to_num_contig_impl; - to_num_contig_ev = nan_to_num_contig_impl( - exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends); - } + using dpctl::tensor::type_utils::is_complex_v; + using scT = std::conditional_t, value_type_of_t, T>; + + scT nan_v = py::cast(py_nan); + scT posinf_v = py::cast(py_posinf); + scT neginf_v = py::cast(py_neginf); + + using dpnp::kernels::nan_to_num::nan_to_num_contig_impl; + sycl::event to_num_contig_ev = nan_to_num_contig_impl( + exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends); + return to_num_contig_ev; } From dc6e7f940727662707671d1d3fb2c947e53dd139 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Sat, 11 Jan 2025 18:12:11 -0800 Subject: [PATCH 12/18] Improve test coverage for nan_to_num --- dpnp/tests/test_mathematical.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/dpnp/tests/test_mathematical.py b/dpnp/tests/test_mathematical.py index 059cce931d21..75c18075d3c7 100644 --- a/dpnp/tests/test_mathematical.py +++ b/dpnp/tests/test_mathematical.py @@ -1558,6 +1558,27 @@ def test_errors_diff_types(self, kwarg, value): with pytest.raises(TypeError): dpnp.nan_to_num(ia, **{kwarg: value}) + def test_error_readonly(self): + a = dpnp.array([0, 1, dpnp.nan, dpnp.inf, -dpnp.inf]) + a.flags.writable = False + with pytest.raises(ValueError): + dpnp.nan_to_num(a, copy=False) + + @pytest.mark.parametrize("copy", [True, False]) + @pytest.mark.parametrize("dt", get_all_dtypes(no_bool=True, no_none=True)) + def 
test_nan_to_num_strided(self, copy, dt): + n = 10 + dt = numpy.dtype(dt) + np_a = numpy.arange(2 * n, dtype=dt) + dp_a = dpnp.arange(2 * n, dtype=dt) + if dt.kind in "fc": + np_a[::4] = numpy.nan + dp_a[::4] = dpnp.nan + dp_r = dpnp.nan_to_num(dp_a[::-2], copy=copy, nan=57.0) + np_r = numpy.nan_to_num(np_a[::-2], copy=copy, nan=57.0) + + assert_dtype_allclose(dp_r, np_r) + class TestProd: @pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2, (1, 2), (0, -2)]) From cc45176301fa4af47ece1bee57eebbc04942d3da Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Sat, 11 Jan 2025 21:28:29 -0800 Subject: [PATCH 13/18] Align with changes to device_allocate_and_pack in dpctl --- .../elementwise_functions/nan_to_num.cpp | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp index 081bf8c36db8..9fd78bdee1f3 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp @@ -79,7 +79,7 @@ using value_type_of_t = typename value_type_of::type; typedef sycl::event (*nan_to_num_fn_ptr_t)(sycl::queue &, int, size_t, - py::ssize_t *, + const py::ssize_t *, const py::object &, const py::object &, const py::object &, @@ -93,7 +93,7 @@ template sycl::event nan_to_num_call(sycl::queue &exec_q, int nd, size_t nelems, - py::ssize_t *shape_strides, + const py::ssize_t *shape_strides, const py::object &py_nan, const py::object &py_posinf, const py::object &py_neginf, @@ -302,15 +302,12 @@ std::pair std::vector host_tasks{}; host_tasks.reserve(2); - const auto &ptr_size_event_triple_ = device_allocate_and_pack( + auto ptr_size_event_triple_ = device_allocate_and_pack( q, host_tasks, simplified_shape, simplified_src_strides, simplified_dst_strides); - py::ssize_t *shape_strides = std::get<0>(ptr_size_event_triple_); + auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_)); const sycl::event ©_shape_ev = std::get<2>(ptr_size_event_triple_); - - if (shape_strides == nullptr) { - throw std::runtime_error("Device memory allocation failed"); - } + const py::ssize_t *shape_strides = shape_strides_owner.get(); std::vector all_deps; all_deps.reserve(depends.size() + 1); @@ -322,13 +319,9 @@ std::pair src_offset, dst_data, dst_offset, all_deps); // async free of shape_strides temporary - auto ctx = q.get_context(); - sycl::event tmp_cleanup_ev = q.submit([&](sycl::handler &cgh) { - cgh.depends_on(comp_ev); - using dpctl::tensor::alloc_utils::sycl_free_noexcept; - cgh.host_task( - [ctx, shape_strides]() { sycl_free_noexcept(shape_strides, ctx); }); - }); + sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + q, {comp_ev}, shape_strides_owner); + host_tasks.push_back(tmp_cleanup_ev); return std::make_pair( From 00e51ad7111c2d000516630e82979459aaa3e469 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 28 Jan 2025 14:39:46 -0800 Subject: [PATCH 14/18] Add subgroup load and store based implementation for nan_to_num kernel --- .../elementwise_functions/nan_to_num.hpp | 177 +++++++++++++++--- 1 file changed, 147 insertions(+), 30 deletions(-) diff --git a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp index a61a871d0f30..549b61820427 100644 --- a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp +++ 
b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp @@ -31,8 +31,10 @@ #include // dpctl tensor headers +#include "kernels/alignment.hpp" #include "kernels/dpctl_tensor_types.hpp" #include "utils/offset_utils.hpp" +#include "utils/sycl_utils.hpp" #include "utils/type_utils.hpp" namespace dpnp::kernels::nan_to_num @@ -49,6 +51,14 @@ inline T to_num(const T v, const T nan, const T posinf, const T neginf) template struct NanToNumFunctor { +private: + const T *inp_ = nullptr; + T *out_ = nullptr; + const InOutIndexerT inp_out_indexer_; + const scT nan_; + const scT posinf_; + const scT neginf_; + public: NanToNumFunctor(const T *inp, T *out, @@ -80,18 +90,104 @@ struct NanToNumFunctor out_[out_offset] = to_num(inp_[inp_offset], nan_, posinf_, neginf_); } } +}; +template +struct NanToNumContigFunctor +{ private: - const T *inp_ = nullptr; + const T *in_ = nullptr; T *out_ = nullptr; - const InOutIndexerT inp_out_indexer_; + std::size_t nelems_; const scT nan_; const scT posinf_; const scT neginf_; -}; -template -class NanToNumKernel; +public: + NanToNumContigFunctor(const T *in, + T *out, + const std::size_t n_elems, + const scT nan, + const scT posinf, + const scT neginf) + : in_(in), out_(out), nelems_(n_elems), nan_(nan), posinf_(posinf), + neginf_(neginf) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz; + /* Each work-item processes vec_sz elements, contiguous in memory */ + /* NOTE: work-group size must be divisible by sub-group size */ + + using dpctl::tensor::type_utils::is_complex_v; + if constexpr (enable_sg_loadstore && !is_complex_v) { + auto sg = ndit.get_sub_group(); + const std::uint16_t sgSize = sg.get_max_local_range()[0]; + const std::size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + + if (base + elems_per_wi * sgSize < nelems_) { + using dpctl::tensor::sycl_utils::sub_group_load; + using dpctl::tensor::sycl_utils::sub_group_store; +#pragma unroll + for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) { + const std::size_t offset = base + it * sgSize; + auto in_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&in_[offset]); + auto out_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&out_[offset]); + + sycl::vec arg_vec = + sub_group_load(sg, in_multi_ptr); +#pragma unroll + for (std::uint32_t k = 0; k < vec_sz; ++k) { + arg_vec[k] = to_num(arg_vec[k], nan_, posinf_, neginf_); + } + sub_group_store(sg, arg_vec, out_multi_ptr); + } + } + else { + const std::size_t lane_id = sg.get_local_id()[0]; + for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) { + out_[k] = to_num(in_[k], nan_, posinf_, neginf_); + } + } + } + else { + const std::uint16_t sgSize = + ndit.get_sub_group().get_local_range()[0]; + const std::size_t gid = ndit.get_global_linear_id(); + const std::uint16_t elems_per_sg = sgSize * elems_per_wi; + + const std::size_t start = + (gid / sgSize) * (elems_per_sg - sgSize) + gid; + const std::size_t end = std::min(nelems_, start + elems_per_sg); + for (std::size_t offset = start; offset < end; offset += sgSize) { + if constexpr (is_complex_v) { + using realT = typename T::value_type; + static_assert(std::is_same_v); + + T z = in_[offset]; + realT x = to_num(z.real(), nan_, posinf_, neginf_); + realT y = to_num(z.imag(), nan_, posinf_, neginf_); + out_[offset] = T{x, y}; + } + else { + 
out_[offset] = to_num(in_[offset], nan_, posinf_, neginf_); + } + } + } + } +}; template sycl::event nan_to_num_impl(sycl::queue &q, @@ -119,48 +215,69 @@ sycl::event nan_to_num_impl(sycl::queue &q, sycl::event comp_ev = q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); - using KernelName = NanToNumKernel; - cgh.parallel_for( - {nelems}, NanToNumFunctor( - in_tp, out_tp, indexer, nan, posinf, neginf)); + using NanToNumFunc = NanToNumFunctor; + cgh.parallel_for( + {nelems}, + NanToNumFunc(in_tp, out_tp, indexer, nan, posinf, neginf)); }); return comp_ev; } -template -class NanToNumContigKernel; - -template -sycl::event nan_to_num_contig_impl(sycl::queue &q, - const size_t nelems, +template +sycl::event nan_to_num_contig_impl(sycl::queue &exec_q, + std::size_t nelems, const scT nan, const scT posinf, const scT neginf, const char *in_cp, char *out_cp, - const std::vector &depends) + const std::vector &depends = {}) { - dpctl::tensor::type_utils::validate_type_for_device(q); + constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz; + const std::size_t n_work_items_needed = nelems / elems_per_wi; + const std::size_t empirical_threshold = std::size_t(1) << 21; + const std::size_t lws = (n_work_items_needed <= empirical_threshold) + ? std::size_t(128) + : std::size_t(256); + + const std::size_t n_groups = + ((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi)); + const auto gws_range = sycl::range<1>(n_groups * lws); + const auto lws_range = sycl::range<1>(lws); const T *in_tp = reinterpret_cast(in_cp); T *out_tp = reinterpret_cast(out_cp); - using dpctl::tensor::offset_utils::NoOpIndexer; - using InOutIndexerT = - dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer; - constexpr NoOpIndexer in_indexer{}; - constexpr NoOpIndexer out_indexer{}; - constexpr InOutIndexerT indexer{in_indexer, out_indexer}; - - sycl::event comp_ev = q.submit([&](sycl::handler &cgh) { + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(depends); - using KernelName = NanToNumContigKernel; - cgh.parallel_for( - {nelems}, NanToNumFunctor( - in_tp, out_tp, indexer, nan, posinf, neginf)); + using dpctl::tensor::kernels::alignment_utils::is_aligned; + using dpctl::tensor::kernels::alignment_utils::required_alignment; + if (is_aligned(in_tp) && + is_aligned(out_tp)) + { + constexpr bool enable_sg_loadstore = true; + using NanToNumFunc = NanToNumContigFunctor; + + cgh.parallel_for( + sycl::nd_range<1>(gws_range, lws_range), + NanToNumFunc(in_tp, out_tp, nelems, nan, posinf, neginf)); + } + else { + constexpr bool disable_sg_loadstore = false; + using NanToNumFunc = NanToNumContigFunctor; + + cgh.parallel_for( + sycl::nd_range<1>(gws_range, lws_range), + NanToNumFunc(in_tp, out_tp, nelems, nan, posinf, neginf)); + } }); + return comp_ev; } From 9dd3e60999ba818a671a7bf8ca6a8671f72f991c Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 29 Jan 2025 10:13:21 -0800 Subject: [PATCH 15/18] size_t -> std::size_t in nan_to_num Python binding --- .../ufunc/elementwise_functions/nan_to_num.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp index 9fd78bdee1f3..1535b2378d94 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp @@ -25,6 +25,7 @@ #include #include +#include #include #include #include @@ -78,7 +79,7 @@ using 
value_type_of_t = typename value_type_of::type; typedef sycl::event (*nan_to_num_fn_ptr_t)(sycl::queue &, int, - size_t, + std::size_t, const py::ssize_t *, const py::object &, const py::object &, @@ -92,7 +93,7 @@ typedef sycl::event (*nan_to_num_fn_ptr_t)(sycl::queue &, template sycl::event nan_to_num_call(sycl::queue &exec_q, int nd, - size_t nelems, + std::size_t nelems, const py::ssize_t *shape_strides, const py::object &py_nan, const py::object &py_posinf, @@ -120,7 +121,7 @@ sycl::event nan_to_num_call(sycl::queue &exec_q, typedef sycl::event (*nan_to_num_contig_fn_ptr_t)( sycl::queue &, - size_t, + std::size_t, const py::object &, const py::object &, const py::object &, @@ -130,7 +131,7 @@ typedef sycl::event (*nan_to_num_contig_fn_ptr_t)( template sycl::event nan_to_num_contig_call(sycl::queue &exec_q, - size_t nelems, + std::size_t nelems, const py::object &py_nan, const py::object &py_posinf, const py::object &py_neginf, @@ -191,7 +192,7 @@ std::pair const py::ssize_t *src_shape = src.get_shape_raw(); const py::ssize_t *dst_shape = dst.get_shape_raw(); - size_t nelems = src.get_size(); + std::size_t nelems = src.get_size(); bool shapes_equal = std::equal(src_shape, src_shape + src_nd, dst_shape); if (!shapes_equal) { throw py::value_error("Array shapes are not the same."); From e9953f51b31350194fd2c56d969eff6205af0fb6 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 5 Feb 2025 10:28:55 -0800 Subject: [PATCH 16/18] nan_to_num always returns dpnp_array --- dpnp/dpnp_iface_mathematical.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 550c9d23bd22..8be5d6e28cea 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -3128,7 +3128,7 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): x_type = x.dtype.type if not issubclass(x_type, dpnp.inexact): - return dpnp.copy(x) if copy else x + return dpnp.copy(x) if copy else dpnp.get_result_array(x) max_f, min_f = _get_max_min(x.real.dtype) if posinf is not None: @@ -3165,7 +3165,7 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): _manager.add_event_pair(h_ev, comp_ev) - return dpnp.get_result_array(out) if copy else x + return dpnp.get_result_array(out) _NEGATIVE_DOCSTRING = """ From 99fd28b49168d94c2380a0021f72433d8b411f2c Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 5 Feb 2025 10:35:11 -0800 Subject: [PATCH 17/18] Apply review comments * nan_to_num_call -> nan_to_num_strided_call * add missing const markers on converted Python scalar objects --- .../elementwise_functions/nan_to_num.cpp | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp index 1535b2378d94..a328836a2b84 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp @@ -91,25 +91,25 @@ typedef sycl::event (*nan_to_num_fn_ptr_t)(sycl::queue &, const std::vector &); template -sycl::event nan_to_num_call(sycl::queue &exec_q, - int nd, - std::size_t nelems, - const py::ssize_t *shape_strides, - const py::object &py_nan, - const py::object &py_posinf, - const py::object &py_neginf, - const char *arg_p, - py::ssize_t arg_offset, - char *dst_p, - py::ssize_t dst_offset, - const std::vector &depends) +sycl::event 
nan_to_num_strided_call(sycl::queue &exec_q, + int nd, + std::size_t nelems, + const py::ssize_t *shape_strides, + const py::object &py_nan, + const py::object &py_posinf, + const py::object &py_neginf, + const char *arg_p, + py::ssize_t arg_offset, + char *dst_p, + py::ssize_t dst_offset, + const std::vector &depends) { using dpctl::tensor::type_utils::is_complex_v; using scT = std::conditional_t, value_type_of_t, T>; - scT nan_v = py::cast(py_nan); - scT posinf_v = py::cast(py_posinf); - scT neginf_v = py::cast(py_neginf); + const scT nan_v = py::cast(py_nan); + const scT posinf_v = py::cast(py_posinf); + const scT neginf_v = py::cast(py_neginf); using dpnp::kernels::nan_to_num::nan_to_num_impl; sycl::event to_num_ev = nan_to_num_impl( @@ -142,9 +142,9 @@ sycl::event nan_to_num_contig_call(sycl::queue &exec_q, using dpctl::tensor::type_utils::is_complex_v; using scT = std::conditional_t, value_type_of_t, T>; - scT nan_v = py::cast(py_nan); - scT posinf_v = py::cast(py_posinf); - scT neginf_v = py::cast(py_neginf); + const scT nan_v = py::cast(py_nan); + const scT posinf_v = py::cast(py_posinf); + const scT neginf_v = py::cast(py_neginf); using dpnp::kernels::nan_to_num::nan_to_num_contig_impl; sycl::event to_num_contig_ev = nan_to_num_contig_impl( @@ -331,7 +331,7 @@ std::pair /** * @brief A factory to define pairs of supported types for which - * nan_to_num_call function is available. + * nan-to-num function is available. * * @tparam T Type of input vector `a` and of result vector `y`. */ @@ -357,7 +357,7 @@ struct NanToNumFactory return nullptr; } else { - return nan_to_num_call; + return nan_to_num_strided_call; } } }; From 3acc9c4b16606abe7cfaf1e249f9a7f4860a45a8 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 5 Feb 2025 11:17:49 -0800 Subject: [PATCH 18/18] nan_to_num_impl -> nan_to_num_strided_impl --- .../elementwise_functions/nan_to_num.cpp | 4 ++-- .../elementwise_functions/nan_to_num.hpp | 24 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp index a328836a2b84..ec5dfd0a78b3 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp @@ -111,8 +111,8 @@ sycl::event nan_to_num_strided_call(sycl::queue &exec_q, const scT posinf_v = py::cast(py_posinf); const scT neginf_v = py::cast(py_neginf); - using dpnp::kernels::nan_to_num::nan_to_num_impl; - sycl::event to_num_ev = nan_to_num_impl( + using dpnp::kernels::nan_to_num::nan_to_num_strided_impl; + sycl::event to_num_ev = nan_to_num_strided_impl( exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p, arg_offset, dst_p, dst_offset, depends); diff --git a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp index 549b61820427..c4219de63f40 100644 --- a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp +++ b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp @@ -190,18 +190,18 @@ struct NanToNumContigFunctor }; template -sycl::event nan_to_num_impl(sycl::queue &q, - const size_t nelems, - const int nd, - const dpctl::tensor::ssize_t *shape_strides, - const scT nan, - const scT posinf, - const scT neginf, - const char *in_cp, - const dpctl::tensor::ssize_t in_offset, - char *out_cp, - const dpctl::tensor::ssize_t out_offset, - const std::vector &depends) +sycl::event 
nan_to_num_strided_impl(sycl::queue &q, + const size_t nelems, + const int nd, + const dpctl::tensor::ssize_t *shape_strides, + const scT nan, + const scT posinf, + const scT neginf, + const char *in_cp, + const dpctl::tensor::ssize_t in_offset, + char *out_cp, + const dpctl::tensor::ssize_t out_offset, + const std::vector &depends) { dpctl::tensor::type_utils::validate_type_for_device(q);
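
A rough usage sketch of the Python-level behavior this series implements, assuming the default replacements (NaN -> nan value, +inf -> dtype max, -inf -> dtype min) and the read-only check exercised by the added test_error_readonly test; it is illustrative only and not part of the patches:

    import dpnp

    a = dpnp.array([dpnp.nan, dpnp.inf, -dpnp.inf, 1.0])

    # defaults: NaN -> 0.0, +inf -> dtype max, -inf -> dtype min
    b = dpnp.nan_to_num(a)

    # explicit replacement values; copy=False writes into `a` in place
    dpnp.nan_to_num(a, copy=False, nan=57.0, posinf=1e6, neginf=-1e6)

    # copy=False on a read-only array is rejected
    a.flags.writable = False
    try:
        dpnp.nan_to_num(a, copy=False)
    except ValueError:
        pass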