Skip to content

Commit b62aea9

Browse files
committed
YDB Decimal support
1 parent 8729c2e commit b62aea9

File tree

4 files changed

+381
-3
lines changed

4 files changed

+381
-3
lines changed

test/test_suite.py

Lines changed: 311 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ctypes
2+
import decimal
23

34
import pytest
45
import sqlalchemy as sa
@@ -266,7 +267,7 @@ def test_huge_int_auto_accommodation(self, connection, intvalue):
266267
pass
267268

268269

269-
@pytest.mark.skip("TODO: fix & skip those tests - add Double/Decimal support. see #12")
270+
@pytest.mark.skip("Use YdbDecimalTest for Decimal type testing")
270271
class NumericTest(_NumericTest):
271272
# SqlAlchemy maybe eat Decimal and throw Double
272273
pass
@@ -596,3 +597,312 @@ class RowFetchTest(_RowFetchTest):
596597
@pytest.mark.skip("scalar subquery unsupported")
597598
def test_row_w_scalar_select(self, connection):
598599
pass
600+
601+
602+
class DecimalTest(fixtures.TablesTest):
603+
"""Tests for YDB Decimal type using standard sa.DECIMAL"""
604+
605+
@classmethod
606+
def define_tables(cls, metadata):
607+
Table(
608+
"decimal_test",
609+
metadata,
610+
Column("id", Integer, primary_key=True),
611+
Column("decimal_default", sa.DECIMAL), # Default: precision=22, scale=9
612+
Column("decimal_custom", sa.DECIMAL(precision=10, scale=2)),
613+
Column("decimal_as_float", sa.DECIMAL(asdecimal=False)), # Should behave like Float
614+
)
615+
616+
def test_decimal_basic_operations(self, connection):
617+
"""Test basic insert and select operations with Decimal"""
618+
619+
table = self.tables.decimal_test
620+
621+
test_values = [
622+
decimal.Decimal("1"),
623+
decimal.Decimal("2"),
624+
decimal.Decimal("3"),
625+
]
626+
627+
# Insert test values
628+
for i, val in enumerate(test_values):
629+
connection.execute(
630+
table.insert().values(id=i + 1, decimal_default=val)
631+
)
632+
633+
# Select and verify
634+
results = connection.execute(
635+
select(table.c.decimal_default).order_by(table.c.id)
636+
).fetchall()
637+
638+
for i, (result,) in enumerate(results):
639+
expected = test_values[i]
640+
assert isinstance(result, decimal.Decimal)
641+
assert result == expected
642+
643+
def test_decimal_with_precision_scale(self, connection):
644+
"""Test Decimal with specific precision and scale"""
645+
646+
table = self.tables.decimal_test
647+
648+
# Test value that fits precision(10, 2)
649+
test_value = decimal.Decimal("12345678.99")
650+
651+
connection.execute(
652+
table.insert().values(id=100, decimal_custom=test_value)
653+
)
654+
655+
result = connection.scalar(
656+
select(table.c.decimal_custom).where(table.c.id == 100)
657+
)
658+
659+
assert isinstance(result, decimal.Decimal)
660+
assert result == test_value
661+
662+
def test_decimal_literal_rendering(self, connection):
663+
"""Test literal rendering of Decimal values"""
664+
from sqlalchemy import literal
665+
666+
table = self.tables.decimal_test
667+
668+
# Test literal in INSERT
669+
test_value = decimal.Decimal("999.99")
670+
671+
connection.execute(
672+
table.insert().values(
673+
id=300,
674+
decimal_default=literal(test_value, sa.DECIMAL())
675+
)
676+
)
677+
678+
result = connection.scalar(
679+
select(table.c.decimal_default).where(table.c.id == 300)
680+
)
681+
682+
assert isinstance(result, decimal.Decimal)
683+
assert result == test_value
684+
685+
def test_decimal_overflow(self, connection):
686+
"""Test behavior when precision is exceeded"""
687+
688+
table = self.tables.decimal_test
689+
690+
# Try to insert value that exceeds precision=10, scale=2
691+
overflow_value = decimal.Decimal("99999.99999")
692+
693+
with pytest.raises(Exception): # Should raise some kind of database error
694+
connection.execute(
695+
table.insert().values(id=500, decimal_custom=overflow_value)
696+
)
697+
connection.commit()
698+
699+
def test_decimal_asdecimal_false(self, connection):
700+
"""Test DECIMAL with asdecimal=False (should return float)"""
701+
702+
table = self.tables.decimal_test
703+
704+
test_value = decimal.Decimal("123.45")
705+
706+
connection.execute(
707+
table.insert().values(id=600, decimal_as_float=test_value)
708+
)
709+
710+
result = connection.scalar(
711+
select(table.c.decimal_as_float).where(table.c.id == 600)
712+
)
713+
714+
assert isinstance(result, float), f"Expected float, got {type(result)}"
715+
assert abs(result - 123.45) < 0.01
716+
717+
def test_decimal_arithmetic(self, connection):
718+
"""Test arithmetic operations with Decimal columns"""
719+
720+
table = self.tables.decimal_test
721+
722+
val1 = decimal.Decimal("100.50")
723+
val2 = decimal.Decimal("25.25")
724+
725+
connection.execute(table.insert().values(id=900, decimal_default=val1))
726+
connection.execute(table.insert().values(id=901, decimal_default=val2))
727+
728+
# Test various arithmetic operations
729+
addition_result = connection.scalar(
730+
select(table.c.decimal_default + decimal.Decimal("10.00"))
731+
.where(table.c.id == 900)
732+
)
733+
734+
subtraction_result = connection.scalar(
735+
select(table.c.decimal_default - decimal.Decimal("5.25"))
736+
.where(table.c.id == 900)
737+
)
738+
739+
multiplication_result = connection.scalar(
740+
select(table.c.decimal_default * decimal.Decimal("2.0"))
741+
.where(table.c.id == 901)
742+
)
743+
744+
division_result = connection.scalar(
745+
select(table.c.decimal_default / decimal.Decimal("2.0"))
746+
.where(table.c.id == 901)
747+
)
748+
749+
# Verify results
750+
assert abs(addition_result - decimal.Decimal("110.50")) < decimal.Decimal("0.01")
751+
assert abs(subtraction_result - decimal.Decimal("95.25")) < decimal.Decimal("0.01")
752+
assert abs(multiplication_result - decimal.Decimal("50.50")) < decimal.Decimal("0.01")
753+
assert abs(division_result - decimal.Decimal("12.625")) < decimal.Decimal("0.01")
754+
755+
def test_decimal_comparison_operations(self, connection):
756+
"""Test comparison operations with Decimal columns"""
757+
758+
table = self.tables.decimal_test
759+
760+
values = [
761+
decimal.Decimal("10.50"),
762+
decimal.Decimal("20.75"),
763+
decimal.Decimal("15.25"),
764+
]
765+
766+
for i, val in enumerate(values):
767+
connection.execute(
768+
table.insert().values(id=1000 + i, decimal_default=val)
769+
)
770+
771+
# Test various comparisons
772+
greater_than = connection.execute(
773+
select(table.c.id).where(
774+
table.c.decimal_default > decimal.Decimal("15.00")
775+
).order_by(table.c.id)
776+
).fetchall()
777+
778+
less_than = connection.execute(
779+
select(table.c.id).where(
780+
table.c.decimal_default < decimal.Decimal("15.00")
781+
).order_by(table.c.id)
782+
).fetchall()
783+
784+
equal_to = connection.execute(
785+
select(table.c.id).where(
786+
table.c.decimal_default == decimal.Decimal("15.25")
787+
)
788+
).fetchall()
789+
790+
between_values = connection.execute(
791+
select(table.c.id).where(
792+
table.c.decimal_default.between(
793+
decimal.Decimal("15.00"),
794+
decimal.Decimal("21.00")
795+
)
796+
).order_by(table.c.id)
797+
).fetchall()
798+
799+
# Verify results
800+
assert len(greater_than) == 2 # 20.75 and 15.25
801+
assert len(less_than) == 1 # 10.50
802+
assert len(equal_to) == 1 # 15.25
803+
assert len(between_values) == 2 # 20.75 and 15.25
804+
805+
def test_decimal_null_handling(self, connection):
806+
"""Test NULL handling with Decimal columns"""
807+
808+
table = self.tables.decimal_test
809+
810+
# Insert NULL value
811+
connection.execute(
812+
table.insert().values(id=1100, decimal_default=None)
813+
)
814+
815+
# Insert non-NULL value for comparison
816+
connection.execute(
817+
table.insert().values(id=1101, decimal_default=decimal.Decimal("42.42"))
818+
)
819+
820+
# Test NULL retrieval
821+
null_result = connection.scalar(
822+
select(table.c.decimal_default).where(table.c.id == 1100)
823+
)
824+
825+
non_null_result = connection.scalar(
826+
select(table.c.decimal_default).where(table.c.id == 1101)
827+
)
828+
829+
assert null_result is None
830+
assert non_null_result == decimal.Decimal("42.42")
831+
832+
# Test IS NULL / IS NOT NULL
833+
null_count = connection.scalar(
834+
select(func.count()).where(table.c.decimal_default.is_(None))
835+
)
836+
837+
not_null_count = connection.scalar(
838+
select(func.count()).where(table.c.decimal_default.isnot(None))
839+
)
840+
841+
# Should have at least 1 NULL and several non-NULL values from other tests
842+
assert null_count >= 1
843+
assert not_null_count >= 1
844+
845+
def test_decimal_input_type_conversion(self, connection):
846+
"""Test that bind_processor handles different input types correctly (float, string, int, Decimal)"""
847+
848+
table = self.tables.decimal_test
849+
850+
# Test different input types that should all be converted to Decimal
851+
test_cases = [
852+
(1400, 123.45, "float input"), # float
853+
(1401, "456.78", "string input"), # string
854+
(1402, decimal.Decimal("789.12"), "decimal input"), # already Decimal
855+
(1403, 100, "int input"), # int
856+
]
857+
858+
for test_id, input_value, description in test_cases:
859+
connection.execute(
860+
table.insert().values(id=test_id, decimal_default=input_value)
861+
)
862+
863+
result = connection.scalar(
864+
select(table.c.decimal_default).where(table.c.id == test_id)
865+
)
866+
867+
# All should be returned as Decimal
868+
assert isinstance(result, decimal.Decimal), f"Failed for {description}: got {type(result)}"
869+
870+
# Verify the value is approximately correct
871+
expected = decimal.Decimal(str(input_value))
872+
error_str = f"Failed for {description}: expected {expected}, got {result}"
873+
assert abs(result - expected) < decimal.Decimal("0.01"), error_str
874+
875+
def test_decimal_asdecimal_comparison(self, connection):
876+
"""Test comparison between asdecimal=True and asdecimal=False behavior"""
877+
878+
table = self.tables.decimal_test
879+
880+
test_value = decimal.Decimal("999.123")
881+
882+
# Insert same value into both columns
883+
connection.execute(
884+
table.insert().values(
885+
id=1500,
886+
decimal_default=test_value, # asdecimal=True (default)
887+
decimal_as_float=test_value # asdecimal=False
888+
)
889+
)
890+
891+
# Get results from both columns
892+
result_as_decimal = connection.scalar(
893+
select(table.c.decimal_default).where(table.c.id == 1500)
894+
)
895+
result_as_float = connection.scalar(
896+
select(table.c.decimal_as_float).where(table.c.id == 1500)
897+
)
898+
899+
# Check types are different
900+
assert isinstance(result_as_decimal, decimal.Decimal), f"Expected Decimal, got {type(result_as_decimal)}"
901+
assert isinstance(result_as_float, float), f"Expected float, got {type(result_as_float)}"
902+
903+
# Check values are approximately equal
904+
assert abs(result_as_decimal - test_value) < decimal.Decimal("0.001")
905+
assert abs(result_as_float - float(test_value)) < 0.001
906+
907+
# Check that converting between them gives same value
908+
assert abs(float(result_as_decimal) - result_as_float) < 0.001

ydb_sqlalchemy/sqlalchemy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ class YqlDialect(StrCompileDialect):
140140
sa.types.DateTime: types.YqlTimestamp, # Because YDB's DateTime doesn't store microseconds
141141
sa.types.DATETIME: types.YqlDateTime,
142142
sa.types.TIMESTAMP: types.YqlTimestamp,
143+
sa.types.DECIMAL: types.Decimal,
143144
}
144145

145146
connection_characteristics = util.immutabledict(

ydb_sqlalchemy/sqlalchemy/compiler/base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,13 @@ def visit_INTEGER(self, type_: sa.INTEGER, **kw):
113113
return "Int64"
114114

115115
def visit_NUMERIC(self, type_: sa.Numeric, **kw):
116-
"""Only Decimal(22,9) is supported for table columns"""
117116
return f"Decimal({type_.precision}, {type_.scale})"
118117

118+
def visit_DECIMAL(self, type_: sa.DECIMAL, **kw):
119+
precision = getattr(type_, 'precision', None) or 22
120+
scale = getattr(type_, 'scale', None) or 9
121+
return f"Decimal({precision}, {scale})"
122+
119123
def visit_BINARY(self, type_: sa.BINARY, **kw):
120124
return "String"
121125

@@ -204,7 +208,9 @@ def get_ydb_type(
204208
elif isinstance(type_, sa.Boolean):
205209
ydb_type = ydb.PrimitiveType.Bool
206210
elif isinstance(type_, sa.Numeric):
207-
ydb_type = ydb.DecimalType(type_.precision, type_.scale)
211+
precision = getattr(type_, 'precision', None) or 22
212+
scale = getattr(type_, 'scale', None) or 9
213+
ydb_type = ydb.DecimalType(precision, scale)
208214
elif isinstance(type_, (types.ListType, sa.ARRAY)):
209215
ydb_type = ydb.ListType(self.get_ydb_type(type_.item_type, is_optional=False))
210216
elif isinstance(type_, types.StructType):

0 commit comments

Comments
 (0)