Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class TakeFunctor
ssize_t src_offset = orthog_offsets.get_first_offset();
ssize_t dst_offset = orthog_offsets.get_second_offset();

const ProjectorT proj{};
constexpr ProjectorT proj{};
for (int axis_idx = 0; axis_idx < k_; ++axis_idx) {
indT *ind_data = reinterpret_cast<indT *>(ind_[axis_idx]);

Expand Down Expand Up @@ -239,7 +239,7 @@ class PutFunctor
ssize_t dst_offset = orthog_offsets.get_first_offset();
ssize_t val_offset = orthog_offsets.get_second_offset();

const ProjectorT proj{};
constexpr ProjectorT proj{};
for (int axis_idx = 0; axis_idx < k_; ++axis_idx) {
indT *ind_data = reinterpret_cast<indT *>(ind_[axis_idx]);

Expand Down
81 changes: 39 additions & 42 deletions dpctl/tensor/libtensor/include/utils/indexing_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,41 +49,40 @@ template <typename IndT> struct WrapIndex
ssize_t operator()(ssize_t max_item, IndT ind) const
{
ssize_t projected;
max_item = sycl::max<ssize_t>(max_item, 1);
constexpr ssize_t unit(1);
max_item = sycl::max(max_item, unit);

constexpr std::uintmax_t ind_max = std::numeric_limits<IndT>::max();
constexpr std::uintmax_t ssize_max =
std::numeric_limits<ssize_t>::max();

if constexpr (std::is_signed_v<IndT>) {
static constexpr std::uintmax_t ind_max =
std::numeric_limits<IndT>::max();
static constexpr std::uintmax_t ssize_max =
std::numeric_limits<ssize_t>::max();
static constexpr std::intmax_t ind_min =
std::numeric_limits<IndT>::min();
static constexpr std::intmax_t ssize_min =
constexpr std::intmax_t ind_min = std::numeric_limits<IndT>::min();
constexpr std::intmax_t ssize_min =
std::numeric_limits<ssize_t>::min();

if constexpr (ind_max <= ssize_max && ind_min >= ssize_min) {
projected = sycl::clamp<ssize_t>(static_cast<ssize_t>(ind),
-max_item, max_item - 1);
const ssize_t ind_ = static_cast<ssize_t>(ind);
const ssize_t lb = -max_item;
const ssize_t ub = max_item - 1;
projected = sycl::clamp(ind_, lb, ub);
}
else {
projected = sycl::clamp<IndT>(ind, static_cast<IndT>(-max_item),
static_cast<IndT>(max_item - 1));
const IndT lb = static_cast<IndT>(-max_item);
const IndT ub = static_cast<IndT>(max_item - 1);
projected = static_cast<ssize_t>(sycl::clamp(ind, lb, ub));
}
return (projected < 0) ? projected + max_item : projected;
}
else {
static constexpr std::uintmax_t ind_max =
std::numeric_limits<IndT>::max();
static constexpr std::uintmax_t ssize_max =
std::numeric_limits<ssize_t>::max();

if constexpr (ind_max <= ssize_max) {
projected =
sycl::min<ssize_t>(static_cast<ssize_t>(ind), max_item - 1);
const ssize_t ind_ = static_cast<ssize_t>(ind);
const ssize_t ub = max_item - 1;
projected = sycl::min(ind_, ub);
}
else {
projected =
sycl::min<IndT>(ind, static_cast<IndT>(max_item - 1));
const IndT ub = static_cast<IndT>(max_item - 1);
projected = static_cast<ssize_t>(sycl::min(ind, ub));
}
return projected;
}
Expand All @@ -95,40 +94,38 @@ template <typename IndT> struct ClipIndex
ssize_t operator()(ssize_t max_item, IndT ind) const
{
ssize_t projected;
max_item = sycl::max<ssize_t>(max_item, 1);
constexpr ssize_t unit(1);
max_item = sycl::max<ssize_t>(max_item, unit);

constexpr std::uintmax_t ind_max = std::numeric_limits<IndT>::max();
constexpr std::uintmax_t ssize_max =
std::numeric_limits<ssize_t>::max();
if constexpr (std::is_signed_v<IndT>) {
static constexpr std::uintmax_t ind_max =
std::numeric_limits<IndT>::max();
static constexpr std::uintmax_t ssize_max =
std::numeric_limits<ssize_t>::max();
static constexpr std::intmax_t ind_min =
std::numeric_limits<IndT>::min();
static constexpr std::intmax_t ssize_min =
constexpr std::intmax_t ind_min = std::numeric_limits<IndT>::min();
constexpr std::intmax_t ssize_min =
std::numeric_limits<ssize_t>::min();

if constexpr (ind_max <= ssize_max && ind_min >= ssize_min) {
projected = sycl::clamp<ssize_t>(static_cast<ssize_t>(ind),
ssize_t(0), max_item - 1);
const ssize_t ind_ = static_cast<ssize_t>(ind);
constexpr ssize_t lb(0);
const ssize_t ub = max_item - 1;
projected = sycl::clamp(ind_, lb, ub);
}
else {
projected = sycl::clamp<IndT>(ind, IndT(0),
static_cast<IndT>(max_item - 1));
constexpr IndT lb(0);
const IndT ub = static_cast<IndT>(max_item - 1);
projected = static_cast<size_t>(sycl::clamp(ind, lb, ub));
}
}
else {
static constexpr std::uintmax_t ind_max =
std::numeric_limits<IndT>::max();
static constexpr std::uintmax_t ssize_max =
std::numeric_limits<ssize_t>::max();

if constexpr (ind_max <= ssize_max) {
projected =
sycl::min<ssize_t>(static_cast<ssize_t>(ind), max_item - 1);
const ssize_t ind_ = static_cast<ssize_t>(ind);
const ssize_t ub = max_item - 1;
projected = sycl::min(ind_, ub);
}
else {
projected =
sycl::min<IndT>(ind, static_cast<IndT>(max_item - 1));
const IndT ub = static_cast<IndT>(max_item - 1);
projected = static_cast<ssize_t>(sycl::min(ind, ub));
}
}
return projected;
Expand Down
Loading