diff --git a/quaddtype/numpy_quaddtype/src/ops.hpp b/quaddtype/numpy_quaddtype/src/ops.hpp index dd385d92..d0041ff7 100644 --- a/quaddtype/numpy_quaddtype/src/ops.hpp +++ b/quaddtype/numpy_quaddtype/src/ops.hpp @@ -76,6 +76,40 @@ quad_sqrt(const Sleef_quad *op) return Sleef_sqrtq1_u05(*op); } +static inline Sleef_quad +quad_cbrt(const Sleef_quad *op) +{ + // SLEEF doesn't provide cbrt, so we implement it using pow + // cbrt(x) = x^(1/3) + // For negative values: cbrt(-x) = -cbrt(x) + + // Handle special cases + if (Sleef_iunordq1(*op, *op)) { + return *op; // NaN + } + if (Sleef_icmpeqq1(*op, QUAD_ZERO)) { + return *op; // ±0 + } + // Check if op is ±inf: isinf(x) = abs(x) == inf + if (Sleef_icmpeqq1(Sleef_fabsq1(*op), QUAD_POS_INF)) { + return *op; // ±inf + } + + // Compute 1/3 as a quad precision constant + Sleef_quad three = Sleef_cast_from_int64q1(3); + Sleef_quad one_third = Sleef_divq1_u05(QUAD_ONE, three); + + // Handle negative values: cbrt(-x) = -cbrt(x) + if (Sleef_icmpltq1(*op, QUAD_ZERO)) { + Sleef_quad abs_val = Sleef_fabsq1(*op); + Sleef_quad result = Sleef_powq1_u10(abs_val, one_third); + return Sleef_negq1(result); + } + + // Positive values + return Sleef_powq1_u10(*op, one_third); +} + static inline Sleef_quad quad_square(const Sleef_quad *op) { @@ -260,6 +294,12 @@ ld_sqrt(const long double *op) return sqrtl(*op); } +static inline long double +ld_cbrt(const long double *op) +{ + return cbrtl(*op); +} + static inline long double ld_square(const long double *op) { diff --git a/quaddtype/numpy_quaddtype/src/umath/unary_ops.cpp b/quaddtype/numpy_quaddtype/src/umath/unary_ops.cpp index b8a82aee..d8095afa 100644 --- a/quaddtype/numpy_quaddtype/src/umath/unary_ops.cpp +++ b/quaddtype/numpy_quaddtype/src/umath/unary_ops.cpp @@ -179,6 +179,9 @@ init_quad_unary_ops(PyObject *numpy) if (create_quad_unary_ufunc(numpy, "sqrt") < 0) { return -1; } + if (create_quad_unary_ufunc(numpy, "cbrt") < 0) { + return -1; + } if (create_quad_unary_ufunc(numpy, "square") < 0) { return -1; } diff --git a/quaddtype/release_tracker.md b/quaddtype/release_tracker.md index cb61ecbe..781c3843 100644 --- a/quaddtype/release_tracker.md +++ b/quaddtype/release_tracker.md @@ -38,7 +38,7 @@ | log1p | ✅ | ✅ | | sqrt | ✅ | ✅ | | square | ✅ | ✅ | -| cbrt | | | +| cbrt | ✅ | ✅ | | reciprocal | ✅ | ✅ | | gcd | | | | lcm | | | diff --git a/quaddtype/tests/test_quaddtype.py b/quaddtype/tests/test_quaddtype.py index b63daeef..5d565790 100644 --- a/quaddtype/tests/test_quaddtype.py +++ b/quaddtype/tests/test_quaddtype.py @@ -253,6 +253,81 @@ def test_rint_near_halfway(): assert np.rint(QuadPrecision("7.5")) == 8 +@pytest.mark.parametrize("val", [ + # Perfect cubes + "1.0", "8.0", "27.0", "64.0", "125.0", "1000.0", + # Negative perfect cubes + "-1.0", "-8.0", "-27.0", "-64.0", "-125.0", "-1000.0", + # Small positive values + "0.001", "0.008", "0.027", "1e-9", "1e-15", "1e-100", + # Small negative values + "-0.001", "-0.008", "-0.027", "-1e-9", "-1e-15", "-1e-100", + # Large positive values + "1e10", "1e15", "1e100", "1e300", + # Large negative values + "-1e10", "-1e15", "-1e100", "-1e300", + # Fractional values + "0.5", "2.5", "3.5", "10.5", "100.5", + "-0.5", "-2.5", "-3.5", "-10.5", "-100.5", + # Edge cases + "0.0", "-0.0", + # Special values + "inf", "-inf", "nan", "-nan" +]) +def test_cbrt(val): + """Comprehensive test for cube root function""" + quad_val = QuadPrecision(val) + float_val = float(val) + + quad_result = np.cbrt(quad_val) + float_result = np.cbrt(float_val) + + # Handle NaN cases + if np.isnan(float_result): + assert np.isnan( + float(quad_result)), f"Expected NaN for cbrt({val}), got {float(quad_result)}" + return + + # Handle infinity cases + if np.isinf(float_result): + assert np.isinf( + float(quad_result)), f"Expected inf for cbrt({val}), got {float(quad_result)}" + assert np.sign(float_result) == np.sign( + float(quad_result)), f"Infinity sign mismatch for cbrt({val})" + return + + # For finite results, check value and sign + # Use relative tolerance for cbrt + if float_result != 0.0: + rtol = 1e-14 if abs(float_result) < 1e100 else 1e-10 + np.testing.assert_allclose(float(quad_result), float_result, rtol=rtol, atol=1e-15, + err_msg=f"Value mismatch for cbrt({val})") + else: + # For zero results + assert float(quad_result) == 0.0, f"Expected 0 for cbrt({val}), got {float(quad_result)}" + assert np.signbit(float_result) == np.signbit( + quad_result), f"Zero sign mismatch for cbrt({val})" + + +def test_cbrt_accuracy(): + """Test that cbrt gives accurate results for perfect cubes""" + # Test perfect cubes + for i in [1, 2, 3, 4, 5, 10, 100]: + val = QuadPrecision(i ** 3) + result = np.cbrt(val) + expected = QuadPrecision(i) + np.testing.assert_allclose(float(result), float(expected), rtol=1e-14, atol=1e-15, + err_msg=f"cbrt({i}^3) should equal {i}") + + # Test negative perfect cubes + for i in [1, 2, 3, 4, 5, 10, 100]: + val = QuadPrecision(-(i ** 3)) + result = np.cbrt(val) + expected = QuadPrecision(-i) + np.testing.assert_allclose(float(result), float(expected), rtol=1e-14, atol=1e-15, + err_msg=f"cbrt(-{i}^3) should equal -{i}") + + @pytest.mark.parametrize("op", ["exp", "exp2"]) @pytest.mark.parametrize("val", [ # Basic cases