Skip to content

Commit e2f2a49

Browse files
committed
YDB Decimal support
1 parent 8729c2e commit e2f2a49

File tree

4 files changed

+325
-3
lines changed

4 files changed

+325
-3
lines changed

test/test_suite.py

Lines changed: 255 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ctypes
2+
import decimal
23

34
import pytest
45
import sqlalchemy as sa
@@ -266,7 +267,7 @@ def test_huge_int_auto_accommodation(self, connection, intvalue):
266267
pass
267268

268269

269-
@pytest.mark.skip("TODO: fix & skip those tests - add Double/Decimal support. see #12")
270+
@pytest.mark.skip("Use YdbDecimalTest for Decimal type testing")
270271
class NumericTest(_NumericTest):
271272
# SqlAlchemy maybe eat Decimal and throw Double
272273
pass
@@ -596,3 +597,256 @@ class RowFetchTest(_RowFetchTest):
596597
@pytest.mark.skip("scalar subquery unsupported")
597598
def test_row_w_scalar_select(self, connection):
598599
pass
600+
601+
602+
class DecimalTest(fixtures.TablesTest):
603+
"""Tests for YDB Decimal type using standard sa.DECIMAL"""
604+
605+
@classmethod
606+
def define_tables(cls, metadata):
607+
Table(
608+
"decimal_test",
609+
metadata,
610+
Column("id", Integer, primary_key=True),
611+
Column("decimal_default", sa.DECIMAL), # Default: precision=22, scale=9
612+
Column("decimal_custom", sa.DECIMAL(precision=10, scale=2)),
613+
Column("decimal_as_float", sa.DECIMAL(asdecimal=False)), # Should behave like Float
614+
)
615+
616+
def test_decimal_basic_operations(self, connection):
617+
"""Test basic insert and select operations with Decimal"""
618+
619+
table = self.tables.decimal_test
620+
621+
test_values = [
622+
decimal.Decimal("1"),
623+
decimal.Decimal("2"),
624+
decimal.Decimal("3"),
625+
]
626+
627+
# Insert test values
628+
for i, val in enumerate(test_values):
629+
connection.execute(table.insert().values(id=i + 1, decimal_default=val))
630+
631+
# Select and verify
632+
results = connection.execute(select(table.c.decimal_default).order_by(table.c.id)).fetchall()
633+
634+
for i, (result,) in enumerate(results):
635+
expected = test_values[i]
636+
assert isinstance(result, decimal.Decimal)
637+
assert result == expected
638+
639+
def test_decimal_with_precision_scale(self, connection):
640+
"""Test Decimal with specific precision and scale"""
641+
642+
table = self.tables.decimal_test
643+
644+
# Test value that fits precision(10, 2)
645+
test_value = decimal.Decimal("12345678.99")
646+
647+
connection.execute(table.insert().values(id=100, decimal_custom=test_value))
648+
649+
result = connection.scalar(select(table.c.decimal_custom).where(table.c.id == 100))
650+
651+
assert isinstance(result, decimal.Decimal)
652+
assert result == test_value
653+
654+
def test_decimal_literal_rendering(self, connection):
655+
"""Test literal rendering of Decimal values"""
656+
from sqlalchemy import literal
657+
658+
table = self.tables.decimal_test
659+
660+
# Test literal in INSERT
661+
test_value = decimal.Decimal("999.99")
662+
663+
connection.execute(table.insert().values(id=300, decimal_default=literal(test_value, sa.DECIMAL())))
664+
665+
result = connection.scalar(select(table.c.decimal_default).where(table.c.id == 300))
666+
667+
assert isinstance(result, decimal.Decimal)
668+
assert result == test_value
669+
670+
def test_decimal_overflow(self, connection):
671+
"""Test behavior when precision is exceeded"""
672+
673+
table = self.tables.decimal_test
674+
675+
# Try to insert value that exceeds precision=10, scale=2
676+
overflow_value = decimal.Decimal("99999.99999")
677+
678+
with pytest.raises(Exception): # Should raise some kind of database error
679+
connection.execute(table.insert().values(id=500, decimal_custom=overflow_value))
680+
connection.commit()
681+
682+
def test_decimal_asdecimal_false(self, connection):
683+
"""Test DECIMAL with asdecimal=False (should return float)"""
684+
685+
table = self.tables.decimal_test
686+
687+
test_value = decimal.Decimal("123.45")
688+
689+
connection.execute(table.insert().values(id=600, decimal_as_float=test_value))
690+
691+
result = connection.scalar(select(table.c.decimal_as_float).where(table.c.id == 600))
692+
693+
assert isinstance(result, float), f"Expected float, got {type(result)}"
694+
assert abs(result - 123.45) < 0.01
695+
696+
def test_decimal_arithmetic(self, connection):
697+
"""Test arithmetic operations with Decimal columns"""
698+
699+
table = self.tables.decimal_test
700+
701+
val1 = decimal.Decimal("100.50")
702+
val2 = decimal.Decimal("25.25")
703+
704+
connection.execute(table.insert().values(id=900, decimal_default=val1))
705+
connection.execute(table.insert().values(id=901, decimal_default=val2))
706+
707+
# Test various arithmetic operations
708+
addition_result = connection.scalar(
709+
select(table.c.decimal_default + decimal.Decimal("10.00")).where(table.c.id == 900)
710+
)
711+
712+
subtraction_result = connection.scalar(
713+
select(table.c.decimal_default - decimal.Decimal("5.25")).where(table.c.id == 900)
714+
)
715+
716+
multiplication_result = connection.scalar(
717+
select(table.c.decimal_default * decimal.Decimal("2.0")).where(table.c.id == 901)
718+
)
719+
720+
division_result = connection.scalar(
721+
select(table.c.decimal_default / decimal.Decimal("2.0")).where(table.c.id == 901)
722+
)
723+
724+
# Verify results
725+
assert abs(addition_result - decimal.Decimal("110.50")) < decimal.Decimal("0.01")
726+
assert abs(subtraction_result - decimal.Decimal("95.25")) < decimal.Decimal("0.01")
727+
assert abs(multiplication_result - decimal.Decimal("50.50")) < decimal.Decimal("0.01")
728+
assert abs(division_result - decimal.Decimal("12.625")) < decimal.Decimal("0.01")
729+
730+
def test_decimal_comparison_operations(self, connection):
731+
"""Test comparison operations with Decimal columns"""
732+
733+
table = self.tables.decimal_test
734+
735+
values = [
736+
decimal.Decimal("10.50"),
737+
decimal.Decimal("20.75"),
738+
decimal.Decimal("15.25"),
739+
]
740+
741+
for i, val in enumerate(values):
742+
connection.execute(table.insert().values(id=1000 + i, decimal_default=val))
743+
744+
# Test various comparisons
745+
greater_than = connection.execute(
746+
select(table.c.id).where(table.c.decimal_default > decimal.Decimal("15.00")).order_by(table.c.id)
747+
).fetchall()
748+
749+
less_than = connection.execute(
750+
select(table.c.id).where(table.c.decimal_default < decimal.Decimal("15.00")).order_by(table.c.id)
751+
).fetchall()
752+
753+
equal_to = connection.execute(
754+
select(table.c.id).where(table.c.decimal_default == decimal.Decimal("15.25"))
755+
).fetchall()
756+
757+
between_values = connection.execute(
758+
select(table.c.id)
759+
.where(table.c.decimal_default.between(decimal.Decimal("15.00"), decimal.Decimal("21.00")))
760+
.order_by(table.c.id)
761+
).fetchall()
762+
763+
# Verify results
764+
assert len(greater_than) == 2 # 20.75 and 15.25
765+
assert len(less_than) == 1 # 10.50
766+
assert len(equal_to) == 1 # 15.25
767+
assert len(between_values) == 2 # 20.75 and 15.25
768+
769+
def test_decimal_null_handling(self, connection):
770+
"""Test NULL handling with Decimal columns"""
771+
772+
table = self.tables.decimal_test
773+
774+
# Insert NULL value
775+
connection.execute(table.insert().values(id=1100, decimal_default=None))
776+
777+
# Insert non-NULL value for comparison
778+
connection.execute(table.insert().values(id=1101, decimal_default=decimal.Decimal("42.42")))
779+
780+
# Test NULL retrieval
781+
null_result = connection.scalar(select(table.c.decimal_default).where(table.c.id == 1100))
782+
783+
non_null_result = connection.scalar(select(table.c.decimal_default).where(table.c.id == 1101))
784+
785+
assert null_result is None
786+
assert non_null_result == decimal.Decimal("42.42")
787+
788+
# Test IS NULL / IS NOT NULL
789+
null_count = connection.scalar(select(func.count()).where(table.c.decimal_default.is_(None)))
790+
791+
not_null_count = connection.scalar(select(func.count()).where(table.c.decimal_default.isnot(None)))
792+
793+
# Should have at least 1 NULL and several non-NULL values from other tests
794+
assert null_count >= 1
795+
assert not_null_count >= 1
796+
797+
def test_decimal_input_type_conversion(self, connection):
798+
"""Test that bind_processor handles different input types correctly (float, string, int, Decimal)"""
799+
800+
table = self.tables.decimal_test
801+
802+
# Test different input types that should all be converted to Decimal
803+
test_cases = [
804+
(1400, 123.45, "float input"), # float
805+
(1401, "456.78", "string input"), # string
806+
(1402, decimal.Decimal("789.12"), "decimal input"), # already Decimal
807+
(1403, 100, "int input"), # int
808+
]
809+
810+
for test_id, input_value, description in test_cases:
811+
connection.execute(table.insert().values(id=test_id, decimal_default=input_value))
812+
813+
result = connection.scalar(select(table.c.decimal_default).where(table.c.id == test_id))
814+
815+
# All should be returned as Decimal
816+
assert isinstance(result, decimal.Decimal), f"Failed for {description}: got {type(result)}"
817+
818+
# Verify the value is approximately correct
819+
expected = decimal.Decimal(str(input_value))
820+
error_str = f"Failed for {description}: expected {expected}, got {result}"
821+
assert abs(result - expected) < decimal.Decimal("0.01"), error_str
822+
823+
def test_decimal_asdecimal_comparison(self, connection):
824+
"""Test comparison between asdecimal=True and asdecimal=False behavior"""
825+
826+
table = self.tables.decimal_test
827+
828+
test_value = decimal.Decimal("999.123")
829+
830+
# Insert same value into both columns
831+
connection.execute(
832+
table.insert().values(
833+
id=1500,
834+
decimal_default=test_value, # asdecimal=True (default)
835+
decimal_as_float=test_value, # asdecimal=False
836+
)
837+
)
838+
839+
# Get results from both columns
840+
result_as_decimal = connection.scalar(select(table.c.decimal_default).where(table.c.id == 1500))
841+
result_as_float = connection.scalar(select(table.c.decimal_as_float).where(table.c.id == 1500))
842+
843+
# Check types are different
844+
assert isinstance(result_as_decimal, decimal.Decimal), f"Expected Decimal, got {type(result_as_decimal)}"
845+
assert isinstance(result_as_float, float), f"Expected float, got {type(result_as_float)}"
846+
847+
# Check values are approximately equal
848+
assert abs(result_as_decimal - test_value) < decimal.Decimal("0.001")
849+
assert abs(result_as_float - float(test_value)) < 0.001
850+
851+
# Check that converting between them gives same value
852+
assert abs(float(result_as_decimal) - result_as_float) < 0.001

ydb_sqlalchemy/sqlalchemy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ class YqlDialect(StrCompileDialect):
140140
sa.types.DateTime: types.YqlTimestamp, # Because YDB's DateTime doesn't store microseconds
141141
sa.types.DATETIME: types.YqlDateTime,
142142
sa.types.TIMESTAMP: types.YqlTimestamp,
143+
sa.types.DECIMAL: types.Decimal,
143144
}
144145

145146
connection_characteristics = util.immutabledict(

ydb_sqlalchemy/sqlalchemy/compiler/base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,13 @@ def visit_INTEGER(self, type_: sa.INTEGER, **kw):
113113
return "Int64"
114114

115115
def visit_NUMERIC(self, type_: sa.Numeric, **kw):
116-
"""Only Decimal(22,9) is supported for table columns"""
117116
return f"Decimal({type_.precision}, {type_.scale})"
118117

118+
def visit_DECIMAL(self, type_: sa.DECIMAL, **kw):
119+
precision = getattr(type_, "precision", None) or 22
120+
scale = getattr(type_, "scale", None) or 9
121+
return f"Decimal({precision}, {scale})"
122+
119123
def visit_BINARY(self, type_: sa.BINARY, **kw):
120124
return "String"
121125

@@ -204,7 +208,9 @@ def get_ydb_type(
204208
elif isinstance(type_, sa.Boolean):
205209
ydb_type = ydb.PrimitiveType.Bool
206210
elif isinstance(type_, sa.Numeric):
207-
ydb_type = ydb.DecimalType(type_.precision, type_.scale)
211+
precision = getattr(type_, "precision", None) or 22
212+
scale = getattr(type_, "scale", None) or 9
213+
ydb_type = ydb.DecimalType(precision, scale)
208214
elif isinstance(type_, (types.ListType, sa.ARRAY)):
209215
ydb_type = ydb.ListType(self.get_ydb_type(type_.item_type, is_optional=False))
210216
elif isinstance(type_, types.StructType):

ydb_sqlalchemy/sqlalchemy/types.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import decimal
12
from typing import Any, Mapping, Type, Union
23

34
from sqlalchemy import __version__ as sa_version
@@ -46,6 +47,66 @@ class Int8(types.Integer):
4647
__visit_name__ = "int8"
4748

4849

50+
class Decimal(types.DECIMAL):
51+
__visit_name__ = "DECIMAL"
52+
53+
def __init__(self, precision=None, scale=None, asdecimal=True):
54+
# YDB supports Decimal(22,9) by default
55+
if precision is None:
56+
precision = 22
57+
if scale is None:
58+
scale = 9
59+
super().__init__(precision=precision, scale=scale, asdecimal=asdecimal)
60+
61+
def bind_processor(self, dialect):
62+
def process(value):
63+
if value is None:
64+
return None
65+
# Convert float to Decimal if needed
66+
if isinstance(value, float):
67+
return decimal.Decimal(str(value))
68+
elif isinstance(value, str):
69+
return decimal.Decimal(value)
70+
elif not isinstance(value, decimal.Decimal):
71+
return decimal.Decimal(str(value))
72+
return value
73+
74+
return process
75+
76+
def result_processor(self, dialect, coltype):
77+
def process(value):
78+
if value is None:
79+
return None
80+
81+
# YDB always returns Decimal values as decimal.Decimal objects
82+
# But if asdecimal=False, we should convert to float
83+
if not self.asdecimal:
84+
return float(value)
85+
86+
# For asdecimal=True (default), return as Decimal
87+
if not isinstance(value, decimal.Decimal):
88+
return decimal.Decimal(str(value))
89+
return value
90+
91+
return process
92+
93+
def literal_processor(self, dialect):
94+
def process(value):
95+
# Convert float to Decimal if needed
96+
if isinstance(value, float):
97+
value = decimal.Decimal(str(value))
98+
elif not isinstance(value, decimal.Decimal):
99+
value = decimal.Decimal(str(value))
100+
101+
# Use default precision and scale if not specified
102+
precision = self.precision if self.precision is not None else 22
103+
scale = self.scale if self.scale is not None else 9
104+
105+
return f'Decimal("{str(value)}", {precision}, {scale})'
106+
107+
return process
108+
109+
49110
class ListType(ARRAY):
50111
__visit_name__ = "list_type"
51112

0 commit comments

Comments
 (0)