Skip to content

Commit 4640852

Browse files
committed
feat: add multi-axis support for xt::roll with optimized pointer arithmetic
1 parent cbb9a36 commit 4640852

File tree

2 files changed

+226
-0
lines changed

2 files changed

+226
-0
lines changed

include/xtensor/misc/xmanipulation.hpp

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,6 +1033,77 @@ namespace xt
10331033
}
10341034
return to;
10351035
}
1036+
1037+
/**
1038+
* Multi-axis roll using pointer arithmetic.
1039+
* Algorithm extended from single-axis roll above.
1040+
*
1041+
* @param to output iterator
1042+
* @param from input iterator
1043+
* @param shifts normalized shift values for each axis (size == shape.size())
1044+
* @param shape the shape of the tensor
1045+
* @param M current dimension being processed
1046+
*/
1047+
template <class To, class From, class Shifts, class S>
1048+
To roll_multi(To to, From from, const Shifts& shifts, const S& shape, std::size_t M)
1049+
{
1050+
std::ptrdiff_t dim = static_cast<std::ptrdiff_t>(shape[M]);
1051+
std::ptrdiff_t offset = std::accumulate(
1052+
shape.begin() + M + 1,
1053+
shape.end(),
1054+
std::ptrdiff_t(1),
1055+
std::multiplies<std::ptrdiff_t>()
1056+
);
1057+
std::ptrdiff_t shift = shifts[M];
1058+
1059+
if (shape.size() == M + 1)
1060+
{
1061+
// Innermost dimension: direct copy
1062+
if (shift != 0)
1063+
{
1064+
const auto split = from + (dim - shift) * offset;
1065+
for (auto iter = split, end = from + dim * offset; iter != end; iter += offset, ++to)
1066+
{
1067+
*to = *iter;
1068+
}
1069+
for (auto iter = from, end = split; iter != end; iter += offset, ++to)
1070+
{
1071+
*to = *iter;
1072+
}
1073+
}
1074+
else
1075+
{
1076+
for (auto iter = from, end = from + dim * offset; iter != end; iter += offset, ++to)
1077+
{
1078+
*to = *iter;
1079+
}
1080+
}
1081+
}
1082+
else
1083+
{
1084+
// Recursive case: process current dimension, then recurse
1085+
if (shift != 0)
1086+
{
1087+
const auto split = from + (dim - shift) * offset;
1088+
for (auto iter = split, end = from + dim * offset; iter != end; iter += offset)
1089+
{
1090+
to = roll_multi(to, iter, shifts, shape, M + 1);
1091+
}
1092+
for (auto iter = from, end = split; iter != end; iter += offset)
1093+
{
1094+
to = roll_multi(to, iter, shifts, shape, M + 1);
1095+
}
1096+
}
1097+
else
1098+
{
1099+
for (auto iter = from, end = from + dim * offset; iter != end; iter += offset)
1100+
{
1101+
to = roll_multi(to, iter, shifts, shape, M + 1);
1102+
}
1103+
}
1104+
}
1105+
return to;
1106+
}
10361107
}
10371108

10381109
/**
@@ -1073,6 +1144,87 @@ namespace xt
10731144
return cpy;
10741145
}
10751146

1147+
/**
1148+
* Roll an expression along multiple axes.
1149+
*
1150+
* Elements that roll beyond the last position are re-introduced at the first.
1151+
* This function does not change the input expression.
1152+
*
1153+
* @ingroup xt_xmanipulation
1154+
* @param e the input xexpression
1155+
* @param shifts container of shift values for each axis
1156+
* @param axes container of axes along which elements are shifted
1157+
* @return a rolled copy of the input expression
1158+
*
1159+
* @note shifts and axes must have the same size. Each element in shifts
1160+
* corresponds to the shift amount for the axis at the same position in axes.
1161+
* @note If the same axis appears multiple times, the shifts are accumulated.
1162+
* @note Negative axis indices are supported (e.g., -1 refers to the last axis).
1163+
*
1164+
* Example:
1165+
* @code
1166+
* xt::xarray<int> a = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
1167+
* auto result = xt::roll(a, {1, 2}, {0, 1}); // roll 1 on axis 0, roll 2 on axis 1
1168+
* @endcode
1169+
*/
1170+
template <class E, class S, class X, XTL_REQUIRES(std::negation<xtl::is_integral<std::decay_t<S>>>)>
1171+
inline auto roll(E&& e, S&& shifts, X&& axes)
1172+
{
1173+
XTENSOR_ASSERT(std::size(shifts) == std::size(axes));
1174+
1175+
if (std::size(shifts) == 0)
1176+
{
1177+
return empty_like(e) = e;
1178+
}
1179+
1180+
const auto dim = e.dimension();
1181+
auto cpy = empty_like(e);
1182+
const auto& shape = cpy.shape();
1183+
1184+
// Accumulate shifts per axis (like NumPy)
1185+
// and normalize to positive values in range [0, axis_size)
1186+
std::vector<std::ptrdiff_t> total_shifts(dim, 0);
1187+
auto shift_it = std::begin(shifts);
1188+
auto axis_it = std::begin(axes);
1189+
for (; shift_it != std::end(shifts); ++shift_it, ++axis_it)
1190+
{
1191+
auto ax = normalize_axis(dim, static_cast<std::ptrdiff_t>(*axis_it));
1192+
total_shifts[ax] += static_cast<std::ptrdiff_t>(*shift_it);
1193+
}
1194+
1195+
// Normalize shifts to positive values
1196+
for (std::size_t ax = 0; ax < dim; ++ax)
1197+
{
1198+
auto axis_size = static_cast<std::ptrdiff_t>(shape[ax]);
1199+
if (axis_size > 0)
1200+
{
1201+
total_shifts[ax] = ((total_shifts[ax] % axis_size) + axis_size) % axis_size;
1202+
}
1203+
}
1204+
1205+
detail::roll_multi(cpy.begin(), e.begin(), total_shifts, shape, 0);
1206+
return cpy;
1207+
}
1208+
1209+
/**
1210+
* Roll an expression along multiple axes (C-style array overload).
1211+
*
1212+
* @ingroup xt_xmanipulation
1213+
* @param e the input xexpression
1214+
* @param shifts C-style array of shift values
1215+
* @param axes C-style array of axes
1216+
* @return a roll of the input expression
1217+
*/
1218+
template <class E, class S, std::size_t N, class I, std::size_t M>
1219+
inline auto roll(E&& e, const S (&shifts)[N], const I (&axes)[M])
1220+
{
1221+
return roll(
1222+
std::forward<E>(e),
1223+
xtl::forward_sequence<std::array<S, N>, decltype(shifts)>(shifts),
1224+
xtl::forward_sequence<std::array<I, M>, decltype(axes)>(axes)
1225+
);
1226+
}
1227+
10761228
/****************************
10771229
* repeat implementation *
10781230
****************************/

test/test_xmanipulation.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,80 @@ namespace xt
504504
ASSERT_EQ(expected8, xt::roll(e2, -2, /*axis*/ 2));
505505
}
506506

507+
TEST(xmanipulation, roll_multi_axis)
508+
{
509+
// Test 1: Basic 2D multi-axis roll
510+
xarray<double> e1 = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
511+
512+
// Roll 1 on axis 0, 2 on axis 1
513+
xarray<double> expected1 = {{8, 9, 7}, {2, 3, 1}, {5, 6, 4}};
514+
ASSERT_EQ(expected1, xt::roll(e1, {1, 2}, {0, 1}));
515+
516+
// Verify equivalence with sequential single-axis rolls
517+
auto sequential_result = xt::roll(xt::roll(e1, 1, 0), 2, 1);
518+
ASSERT_EQ(expected1, sequential_result);
519+
520+
// Test 2: std::array as input
521+
std::array<int, 2> shifts = {1, 2};
522+
std::array<int, 2> axes = {0, 1};
523+
ASSERT_EQ(expected1, xt::roll(e1, shifts, axes));
524+
525+
// Test 3: std::vector as input
526+
std::vector<int> shifts_v = {1, 2};
527+
std::vector<int> axes_v = {0, 1};
528+
ASSERT_EQ(expected1, xt::roll(e1, shifts_v, axes_v));
529+
530+
// Test 4: Negative shifts
531+
xarray<double> expected4 = {{6, 4, 5}, {9, 7, 8}, {3, 1, 2}};
532+
ASSERT_EQ(expected4, xt::roll(e1, {-1, -2}, {0, 1}));
533+
534+
// Test 5: Negative axis indices
535+
ASSERT_EQ(expected1, xt::roll(e1, {1, 2}, {-2, -1}));
536+
537+
// Test 6: 3D array
538+
xarray<double> e3 = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};
539+
xarray<double> expected6 = {{{8, 7}, {6, 5}}, {{4, 3}, {2, 1}}};
540+
ASSERT_EQ(expected6, xt::roll(e3, {1, 1, 1}, {0, 1, 2}));
541+
542+
// Test 7: Single axis via multi-axis interface (should match single-axis version)
543+
xarray<double> expected7 = {{7, 8, 9}, {1, 2, 3}, {4, 5, 6}};
544+
ASSERT_EQ(expected7, xt::roll(e1, {1}, {0}));
545+
ASSERT_EQ(expected7, xt::roll(e1, 1, 0));
546+
547+
// Test 8: Empty shifts (should return copy of original)
548+
std::vector<int> empty_shifts;
549+
std::vector<int> empty_axes;
550+
ASSERT_EQ(e1, xt::roll(e1, empty_shifts, empty_axes));
551+
552+
// Test 9: Large shift values (should wrap around)
553+
// shift of 10 on axis 0 (size 3) is equivalent to shift of 1
554+
ASSERT_EQ(xt::roll(e1, {1}, {0}), xt::roll(e1, {10}, {0}));
555+
556+
// Test 10: Same axis appears multiple times (shifts accumulate)
557+
// NumPy: np.roll(a, (1, 2), axis=(0, 0)) equals np.roll(a, 3, axis=0)
558+
xarray<double> expected10 = xt::roll(e1, 3, 0);
559+
ASSERT_EQ(expected10, xt::roll(e1, {1, 2}, {0, 0}));
560+
561+
// Test 11: xtensor (fixed dimension) instead of xarray
562+
xt::xtensor<double, 2> t1 = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
563+
xt::xtensor<double, 2> t_expected = {{8, 9, 7}, {2, 3, 1}, {5, 6, 4}};
564+
ASSERT_EQ(t_expected, xt::roll(t1, {1, 2}, {0, 1}));
565+
566+
// Test 12: 1D array multi-axis operation (only one axis)
567+
xarray<double> e1d = {1, 2, 3, 4, 5};
568+
xarray<double> expected1d = {4, 5, 1, 2, 3};
569+
ASSERT_EQ(expected1d, xt::roll(e1d, {2}, {0}));
570+
571+
// Test 13: Column-major layout (result should be layout-independent)
572+
xarray<double, layout_type::column_major> cm = {{1, 2, 3}, {4, 5, 6}};
573+
xarray<double> expected_cm = {{3, 1, 2}, {6, 4, 5}};
574+
ASSERT_EQ(expected_cm, xt::roll(cm, {1}, {1}));
575+
576+
// Test 14: Column-major with multi-axis roll
577+
xarray<double> expected_cm2 = {{6, 4, 5}, {3, 1, 2}};
578+
ASSERT_EQ(expected_cm2, xt::roll(cm, {1, 1}, {0, 1}));
579+
}
580+
507581
TEST(xmanipulation, repeat_all_elements_of_axis_0_of_int_array_2_times)
508582
{
509583
xarray<int> array = {

0 commit comments

Comments
 (0)