|
1 | 1 | import ctypes |
| 2 | +import decimal |
2 | 3 |
|
3 | 4 | import pytest |
4 | 5 | import sqlalchemy as sa |
@@ -266,7 +267,7 @@ def test_huge_int_auto_accommodation(self, connection, intvalue): |
266 | 267 | pass |
267 | 268 |
|
268 | 269 |
|
269 | | -@pytest.mark.skip("TODO: fix & skip those tests - add Double/Decimal support. see #12") |
| 270 | +@pytest.mark.skip("Use YdbDecimalTest for Decimal type testing") |
270 | 271 | class NumericTest(_NumericTest): |
271 | 272 | # SqlAlchemy maybe eat Decimal and throw Double |
272 | 273 | pass |
@@ -596,3 +597,256 @@ class RowFetchTest(_RowFetchTest): |
596 | 597 | @pytest.mark.skip("scalar subquery unsupported") |
597 | 598 | def test_row_w_scalar_select(self, connection): |
598 | 599 | pass |
| 600 | + |
| 601 | + |
| 602 | +class DecimalTest(fixtures.TablesTest): |
| 603 | + """Tests for YDB Decimal type using standard sa.DECIMAL""" |
| 604 | + |
| 605 | + @classmethod |
| 606 | + def define_tables(cls, metadata): |
| 607 | + Table( |
| 608 | + "decimal_test", |
| 609 | + metadata, |
| 610 | + Column("id", Integer, primary_key=True), |
| 611 | + Column("decimal_default", sa.DECIMAL), # Default: precision=22, scale=9 |
| 612 | + Column("decimal_custom", sa.DECIMAL(precision=10, scale=2)), |
| 613 | + Column("decimal_as_float", sa.DECIMAL(asdecimal=False)), # Should behave like Float |
| 614 | + ) |
| 615 | + |
| 616 | + def test_decimal_basic_operations(self, connection): |
| 617 | + """Test basic insert and select operations with Decimal""" |
| 618 | + |
| 619 | + table = self.tables.decimal_test |
| 620 | + |
| 621 | + test_values = [ |
| 622 | + decimal.Decimal("1"), |
| 623 | + decimal.Decimal("2"), |
| 624 | + decimal.Decimal("3"), |
| 625 | + ] |
| 626 | + |
| 627 | + # Insert test values |
| 628 | + for i, val in enumerate(test_values): |
| 629 | + connection.execute(table.insert().values(id=i + 1, decimal_default=val)) |
| 630 | + |
| 631 | + # Select and verify |
| 632 | + results = connection.execute(select(table.c.decimal_default).order_by(table.c.id)).fetchall() |
| 633 | + |
| 634 | + for i, (result,) in enumerate(results): |
| 635 | + expected = test_values[i] |
| 636 | + assert isinstance(result, decimal.Decimal) |
| 637 | + assert result == expected |
| 638 | + |
| 639 | + def test_decimal_with_precision_scale(self, connection): |
| 640 | + """Test Decimal with specific precision and scale""" |
| 641 | + |
| 642 | + table = self.tables.decimal_test |
| 643 | + |
| 644 | + # Test value that fits precision(10, 2) |
| 645 | + test_value = decimal.Decimal("12345678.99") |
| 646 | + |
| 647 | + connection.execute(table.insert().values(id=100, decimal_custom=test_value)) |
| 648 | + |
| 649 | + result = connection.scalar(select(table.c.decimal_custom).where(table.c.id == 100)) |
| 650 | + |
| 651 | + assert isinstance(result, decimal.Decimal) |
| 652 | + assert result == test_value |
| 653 | + |
| 654 | + def test_decimal_literal_rendering(self, connection): |
| 655 | + """Test literal rendering of Decimal values""" |
| 656 | + from sqlalchemy import literal |
| 657 | + |
| 658 | + table = self.tables.decimal_test |
| 659 | + |
| 660 | + # Test literal in INSERT |
| 661 | + test_value = decimal.Decimal("999.99") |
| 662 | + |
| 663 | + connection.execute(table.insert().values(id=300, decimal_default=literal(test_value, sa.DECIMAL()))) |
| 664 | + |
| 665 | + result = connection.scalar(select(table.c.decimal_default).where(table.c.id == 300)) |
| 666 | + |
| 667 | + assert isinstance(result, decimal.Decimal) |
| 668 | + assert result == test_value |
| 669 | + |
| 670 | + def test_decimal_overflow(self, connection): |
| 671 | + """Test behavior when precision is exceeded""" |
| 672 | + |
| 673 | + table = self.tables.decimal_test |
| 674 | + |
| 675 | + # Try to insert value that exceeds precision=10, scale=2 |
| 676 | + overflow_value = decimal.Decimal("99999.99999") |
| 677 | + |
| 678 | + with pytest.raises(Exception): # Should raise some kind of database error |
| 679 | + connection.execute(table.insert().values(id=500, decimal_custom=overflow_value)) |
| 680 | + connection.commit() |
| 681 | + |
| 682 | + def test_decimal_asdecimal_false(self, connection): |
| 683 | + """Test DECIMAL with asdecimal=False (should return float)""" |
| 684 | + |
| 685 | + table = self.tables.decimal_test |
| 686 | + |
| 687 | + test_value = decimal.Decimal("123.45") |
| 688 | + |
| 689 | + connection.execute(table.insert().values(id=600, decimal_as_float=test_value)) |
| 690 | + |
| 691 | + result = connection.scalar(select(table.c.decimal_as_float).where(table.c.id == 600)) |
| 692 | + |
| 693 | + assert isinstance(result, float), f"Expected float, got {type(result)}" |
| 694 | + assert abs(result - 123.45) < 0.01 |
| 695 | + |
| 696 | + def test_decimal_arithmetic(self, connection): |
| 697 | + """Test arithmetic operations with Decimal columns""" |
| 698 | + |
| 699 | + table = self.tables.decimal_test |
| 700 | + |
| 701 | + val1 = decimal.Decimal("100.50") |
| 702 | + val2 = decimal.Decimal("25.25") |
| 703 | + |
| 704 | + connection.execute(table.insert().values(id=900, decimal_default=val1)) |
| 705 | + connection.execute(table.insert().values(id=901, decimal_default=val2)) |
| 706 | + |
| 707 | + # Test various arithmetic operations |
| 708 | + addition_result = connection.scalar( |
| 709 | + select(table.c.decimal_default + decimal.Decimal("10.00")).where(table.c.id == 900) |
| 710 | + ) |
| 711 | + |
| 712 | + subtraction_result = connection.scalar( |
| 713 | + select(table.c.decimal_default - decimal.Decimal("5.25")).where(table.c.id == 900) |
| 714 | + ) |
| 715 | + |
| 716 | + multiplication_result = connection.scalar( |
| 717 | + select(table.c.decimal_default * decimal.Decimal("2.0")).where(table.c.id == 901) |
| 718 | + ) |
| 719 | + |
| 720 | + division_result = connection.scalar( |
| 721 | + select(table.c.decimal_default / decimal.Decimal("2.0")).where(table.c.id == 901) |
| 722 | + ) |
| 723 | + |
| 724 | + # Verify results |
| 725 | + assert abs(addition_result - decimal.Decimal("110.50")) < decimal.Decimal("0.01") |
| 726 | + assert abs(subtraction_result - decimal.Decimal("95.25")) < decimal.Decimal("0.01") |
| 727 | + assert abs(multiplication_result - decimal.Decimal("50.50")) < decimal.Decimal("0.01") |
| 728 | + assert abs(division_result - decimal.Decimal("12.625")) < decimal.Decimal("0.01") |
| 729 | + |
| 730 | + def test_decimal_comparison_operations(self, connection): |
| 731 | + """Test comparison operations with Decimal columns""" |
| 732 | + |
| 733 | + table = self.tables.decimal_test |
| 734 | + |
| 735 | + values = [ |
| 736 | + decimal.Decimal("10.50"), |
| 737 | + decimal.Decimal("20.75"), |
| 738 | + decimal.Decimal("15.25"), |
| 739 | + ] |
| 740 | + |
| 741 | + for i, val in enumerate(values): |
| 742 | + connection.execute(table.insert().values(id=1000 + i, decimal_default=val)) |
| 743 | + |
| 744 | + # Test various comparisons |
| 745 | + greater_than = connection.execute( |
| 746 | + select(table.c.id).where(table.c.decimal_default > decimal.Decimal("15.00")).order_by(table.c.id) |
| 747 | + ).fetchall() |
| 748 | + |
| 749 | + less_than = connection.execute( |
| 750 | + select(table.c.id).where(table.c.decimal_default < decimal.Decimal("15.00")).order_by(table.c.id) |
| 751 | + ).fetchall() |
| 752 | + |
| 753 | + equal_to = connection.execute( |
| 754 | + select(table.c.id).where(table.c.decimal_default == decimal.Decimal("15.25")) |
| 755 | + ).fetchall() |
| 756 | + |
| 757 | + between_values = connection.execute( |
| 758 | + select(table.c.id) |
| 759 | + .where(table.c.decimal_default.between(decimal.Decimal("15.00"), decimal.Decimal("21.00"))) |
| 760 | + .order_by(table.c.id) |
| 761 | + ).fetchall() |
| 762 | + |
| 763 | + # Verify results |
| 764 | + assert len(greater_than) == 2 # 20.75 and 15.25 |
| 765 | + assert len(less_than) == 1 # 10.50 |
| 766 | + assert len(equal_to) == 1 # 15.25 |
| 767 | + assert len(between_values) == 2 # 20.75 and 15.25 |
| 768 | + |
| 769 | + def test_decimal_null_handling(self, connection): |
| 770 | + """Test NULL handling with Decimal columns""" |
| 771 | + |
| 772 | + table = self.tables.decimal_test |
| 773 | + |
| 774 | + # Insert NULL value |
| 775 | + connection.execute(table.insert().values(id=1100, decimal_default=None)) |
| 776 | + |
| 777 | + # Insert non-NULL value for comparison |
| 778 | + connection.execute(table.insert().values(id=1101, decimal_default=decimal.Decimal("42.42"))) |
| 779 | + |
| 780 | + # Test NULL retrieval |
| 781 | + null_result = connection.scalar(select(table.c.decimal_default).where(table.c.id == 1100)) |
| 782 | + |
| 783 | + non_null_result = connection.scalar(select(table.c.decimal_default).where(table.c.id == 1101)) |
| 784 | + |
| 785 | + assert null_result is None |
| 786 | + assert non_null_result == decimal.Decimal("42.42") |
| 787 | + |
| 788 | + # Test IS NULL / IS NOT NULL |
| 789 | + null_count = connection.scalar(select(func.count()).where(table.c.decimal_default.is_(None))) |
| 790 | + |
| 791 | + not_null_count = connection.scalar(select(func.count()).where(table.c.decimal_default.isnot(None))) |
| 792 | + |
| 793 | + # Should have at least 1 NULL and several non-NULL values from other tests |
| 794 | + assert null_count >= 1 |
| 795 | + assert not_null_count >= 1 |
| 796 | + |
| 797 | + def test_decimal_input_type_conversion(self, connection): |
| 798 | + """Test that bind_processor handles different input types correctly (float, string, int, Decimal)""" |
| 799 | + |
| 800 | + table = self.tables.decimal_test |
| 801 | + |
| 802 | + # Test different input types that should all be converted to Decimal |
| 803 | + test_cases = [ |
| 804 | + (1400, 123.45, "float input"), # float |
| 805 | + (1401, "456.78", "string input"), # string |
| 806 | + (1402, decimal.Decimal("789.12"), "decimal input"), # already Decimal |
| 807 | + (1403, 100, "int input"), # int |
| 808 | + ] |
| 809 | + |
| 810 | + for test_id, input_value, description in test_cases: |
| 811 | + connection.execute(table.insert().values(id=test_id, decimal_default=input_value)) |
| 812 | + |
| 813 | + result = connection.scalar(select(table.c.decimal_default).where(table.c.id == test_id)) |
| 814 | + |
| 815 | + # All should be returned as Decimal |
| 816 | + assert isinstance(result, decimal.Decimal), f"Failed for {description}: got {type(result)}" |
| 817 | + |
| 818 | + # Verify the value is approximately correct |
| 819 | + expected = decimal.Decimal(str(input_value)) |
| 820 | + error_str = f"Failed for {description}: expected {expected}, got {result}" |
| 821 | + assert abs(result - expected) < decimal.Decimal("0.01"), error_str |
| 822 | + |
| 823 | + def test_decimal_asdecimal_comparison(self, connection): |
| 824 | + """Test comparison between asdecimal=True and asdecimal=False behavior""" |
| 825 | + |
| 826 | + table = self.tables.decimal_test |
| 827 | + |
| 828 | + test_value = decimal.Decimal("999.123") |
| 829 | + |
| 830 | + # Insert same value into both columns |
| 831 | + connection.execute( |
| 832 | + table.insert().values( |
| 833 | + id=1500, |
| 834 | + decimal_default=test_value, # asdecimal=True (default) |
| 835 | + decimal_as_float=test_value, # asdecimal=False |
| 836 | + ) |
| 837 | + ) |
| 838 | + |
| 839 | + # Get results from both columns |
| 840 | + result_as_decimal = connection.scalar(select(table.c.decimal_default).where(table.c.id == 1500)) |
| 841 | + result_as_float = connection.scalar(select(table.c.decimal_as_float).where(table.c.id == 1500)) |
| 842 | + |
| 843 | + # Check types are different |
| 844 | + assert isinstance(result_as_decimal, decimal.Decimal), f"Expected Decimal, got {type(result_as_decimal)}" |
| 845 | + assert isinstance(result_as_float, float), f"Expected float, got {type(result_as_float)}" |
| 846 | + |
| 847 | + # Check values are approximately equal |
| 848 | + assert abs(result_as_decimal - test_value) < decimal.Decimal("0.001") |
| 849 | + assert abs(result_as_float - float(test_value)) < 0.001 |
| 850 | + |
| 851 | + # Check that converting between them gives same value |
| 852 | + assert abs(float(result_as_decimal) - result_as_float) < 0.001 |
0 commit comments