Skip to content

Commit a8d022b

Browse files
committed
fix: correct negative axis handling in roll function
1 parent cbb9a36 commit a8d022b

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

include/xtensor/misc/xmanipulation.hpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,17 +1052,15 @@ namespace xt
10521052
{
10531053
auto cpy = empty_like(e);
10541054
const auto& shape = cpy.shape();
1055-
std::size_t saxis = static_cast<std::size_t>(axis);
1056-
if (axis < 0)
1057-
{
1058-
axis += std::ptrdiff_t(cpy.dimension());
1059-
}
1055+
const auto dim = cpy.dimension();
10601056

1061-
if (saxis >= cpy.dimension() || axis < 0)
1057+
if (axis < -static_cast<std::ptrdiff_t>(dim) || axis >= static_cast<std::ptrdiff_t>(dim))
10621058
{
1063-
XTENSOR_THROW(std::runtime_error, "axis is no within shape dimension.");
1059+
XTENSOR_THROW(std::runtime_error, "axis is not within shape dimension.");
10641060
}
10651061

1062+
std::size_t saxis = normalize_axis(dim, axis);
1063+
10661064
const auto axis_dim = static_cast<std::ptrdiff_t>(shape[saxis]);
10671065
while (shift < 0)
10681066
{

test/test_xmanipulation.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,18 @@ namespace xt
502502

503503
xarray<double> expected8 = {{{3, 1, 2}}, {{6, 4, 5}}, {{9, 7, 8}}};
504504
ASSERT_EQ(expected8, xt::roll(e2, -2, /*axis*/ 2));
505+
506+
EXPECT_THROW(xt::roll(e2, 1, /*axis*/ 3), std::runtime_error);
507+
EXPECT_THROW(xt::roll(e2, 1, /*axis*/ -4), std::runtime_error);
508+
509+
xarray<double> expected9 = {{{3, 1, 2}}, {{6, 4, 5}}, {{9, 7, 8}}};
510+
ASSERT_EQ(expected9, xt::roll(e2, -2, /*axis*/ -1));
511+
512+
xarray<double> expected10 = {{{1, 2, 3}}, {{4, 5, 6}}, {{7, 8, 9}}};
513+
ASSERT_EQ(expected10, xt::roll(e2, -2, /*axis*/ -2));
514+
515+
xarray<double> expected11 = {{{4, 5, 6}}, {{7, 8, 9}}, {{1, 2, 3}}};
516+
ASSERT_EQ(expected11, xt::roll(e2, 2, /*axis*/ -3));
505517
}
506518

507519
TEST(xmanipulation, repeat_all_elements_of_axis_0_of_int_array_2_times)

0 commit comments

Comments
 (0)