diff --git a/test/test_suite.py b/test/test_suite.py index 573bdbc..a870a6d 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -1,4 +1,5 @@ import ctypes +import decimal import pytest import sqlalchemy as sa @@ -266,7 +267,7 @@ def test_huge_int_auto_accommodation(self, connection, intvalue): pass -@pytest.mark.skip("TODO: fix & skip those tests - add Double/Decimal support. see #12") +@pytest.mark.skip("Use YdbDecimalTest for Decimal type testing") class NumericTest(_NumericTest): # SqlAlchemy maybe eat Decimal and throw Double pass @@ -596,3 +597,256 @@ class RowFetchTest(_RowFetchTest): @pytest.mark.skip("scalar subquery unsupported") def test_row_w_scalar_select(self, connection): pass + + +class DecimalTest(fixtures.TablesTest): + """Tests for YDB Decimal type using standard sa.DECIMAL""" + + @classmethod + def define_tables(cls, metadata): + Table( + "decimal_test", + metadata, + Column("id", Integer, primary_key=True), + Column("decimal_default", sa.DECIMAL), # Default: precision=22, scale=9 + Column("decimal_custom", sa.DECIMAL(precision=10, scale=2)), + Column("decimal_as_float", sa.DECIMAL(asdecimal=False)), # Should behave like Float + ) + + def test_decimal_basic_operations(self, connection): + """Test basic insert and select operations with Decimal""" + + table = self.tables.decimal_test + + test_values = [ + decimal.Decimal("1"), + decimal.Decimal("2"), + decimal.Decimal("3"), + ] + + # Insert test values + for i, val in enumerate(test_values): + connection.execute(table.insert().values(id=i + 1, decimal_default=val)) + + # Select and verify + results = connection.execute(select(table.c.decimal_default).order_by(table.c.id)).fetchall() + + for i, (result,) in enumerate(results): + expected = test_values[i] + assert isinstance(result, decimal.Decimal) + assert result == expected + + def test_decimal_with_precision_scale(self, connection): + """Test Decimal with specific precision and scale""" + + table = self.tables.decimal_test + + # Test value that fits precision(10, 2) + test_value = decimal.Decimal("12345678.99") + + connection.execute(table.insert().values(id=100, decimal_custom=test_value)) + + result = connection.scalar(select(table.c.decimal_custom).where(table.c.id == 100)) + + assert isinstance(result, decimal.Decimal) + assert result == test_value + + def test_decimal_literal_rendering(self, connection): + """Test literal rendering of Decimal values""" + from sqlalchemy import literal + + table = self.tables.decimal_test + + # Test literal in INSERT + test_value = decimal.Decimal("999.99") + + connection.execute(table.insert().values(id=300, decimal_default=literal(test_value, sa.DECIMAL()))) + + result = connection.scalar(select(table.c.decimal_default).where(table.c.id == 300)) + + assert isinstance(result, decimal.Decimal) + assert result == test_value + + def test_decimal_overflow(self, connection): + """Test behavior when precision is exceeded""" + + table = self.tables.decimal_test + + # Try to insert value that exceeds precision=10, scale=2 + overflow_value = decimal.Decimal("99999.99999") + + with pytest.raises(Exception): # Should raise some kind of database error + connection.execute(table.insert().values(id=500, decimal_custom=overflow_value)) + connection.commit() + + def test_decimal_asdecimal_false(self, connection): + """Test DECIMAL with asdecimal=False (should return float)""" + + table = self.tables.decimal_test + + test_value = decimal.Decimal("123.45") + + connection.execute(table.insert().values(id=600, decimal_as_float=test_value)) + + result = connection.scalar(select(table.c.decimal_as_float).where(table.c.id == 600)) + + assert isinstance(result, float), f"Expected float, got {type(result)}" + assert abs(result - 123.45) < 0.01 + + def test_decimal_arithmetic(self, connection): + """Test arithmetic operations with Decimal columns""" + + table = self.tables.decimal_test + + val1 = decimal.Decimal("100.50") + val2 = decimal.Decimal("25.25") + + connection.execute(table.insert().values(id=900, decimal_default=val1)) + connection.execute(table.insert().values(id=901, decimal_default=val2)) + + # Test various arithmetic operations + addition_result = connection.scalar( + select(table.c.decimal_default + decimal.Decimal("10.00")).where(table.c.id == 900) + ) + + subtraction_result = connection.scalar( + select(table.c.decimal_default - decimal.Decimal("5.25")).where(table.c.id == 900) + ) + + multiplication_result = connection.scalar( + select(table.c.decimal_default * decimal.Decimal("2.0")).where(table.c.id == 901) + ) + + division_result = connection.scalar( + select(table.c.decimal_default / decimal.Decimal("2.0")).where(table.c.id == 901) + ) + + # Verify results + assert abs(addition_result - decimal.Decimal("110.50")) < decimal.Decimal("0.01") + assert abs(subtraction_result - decimal.Decimal("95.25")) < decimal.Decimal("0.01") + assert abs(multiplication_result - decimal.Decimal("50.50")) < decimal.Decimal("0.01") + assert abs(division_result - decimal.Decimal("12.625")) < decimal.Decimal("0.01") + + def test_decimal_comparison_operations(self, connection): + """Test comparison operations with Decimal columns""" + + table = self.tables.decimal_test + + values = [ + decimal.Decimal("10.50"), + decimal.Decimal("20.75"), + decimal.Decimal("15.25"), + ] + + for i, val in enumerate(values): + connection.execute(table.insert().values(id=1000 + i, decimal_default=val)) + + # Test various comparisons + greater_than = connection.execute( + select(table.c.id).where(table.c.decimal_default > decimal.Decimal("15.00")).order_by(table.c.id) + ).fetchall() + + less_than = connection.execute( + select(table.c.id).where(table.c.decimal_default < decimal.Decimal("15.00")).order_by(table.c.id) + ).fetchall() + + equal_to = connection.execute( + select(table.c.id).where(table.c.decimal_default == decimal.Decimal("15.25")) + ).fetchall() + + between_values = connection.execute( + select(table.c.id) + .where(table.c.decimal_default.between(decimal.Decimal("15.00"), decimal.Decimal("21.00"))) + .order_by(table.c.id) + ).fetchall() + + # Verify results + assert len(greater_than) == 2 # 20.75 and 15.25 + assert len(less_than) == 1 # 10.50 + assert len(equal_to) == 1 # 15.25 + assert len(between_values) == 2 # 20.75 and 15.25 + + def test_decimal_null_handling(self, connection): + """Test NULL handling with Decimal columns""" + + table = self.tables.decimal_test + + # Insert NULL value + connection.execute(table.insert().values(id=1100, decimal_default=None)) + + # Insert non-NULL value for comparison + connection.execute(table.insert().values(id=1101, decimal_default=decimal.Decimal("42.42"))) + + # Test NULL retrieval + null_result = connection.scalar(select(table.c.decimal_default).where(table.c.id == 1100)) + + non_null_result = connection.scalar(select(table.c.decimal_default).where(table.c.id == 1101)) + + assert null_result is None + assert non_null_result == decimal.Decimal("42.42") + + # Test IS NULL / IS NOT NULL + null_count = connection.scalar(select(func.count()).where(table.c.decimal_default.is_(None))) + + not_null_count = connection.scalar(select(func.count()).where(table.c.decimal_default.isnot(None))) + + # Should have at least 1 NULL and several non-NULL values from other tests + assert null_count >= 1 + assert not_null_count >= 1 + + def test_decimal_input_type_conversion(self, connection): + """Test that bind_processor handles different input types correctly (float, string, int, Decimal)""" + + table = self.tables.decimal_test + + # Test different input types that should all be converted to Decimal + test_cases = [ + (1400, 123.45, "float input"), # float + (1401, "456.78", "string input"), # string + (1402, decimal.Decimal("789.12"), "decimal input"), # already Decimal + (1403, 100, "int input"), # int + ] + + for test_id, input_value, description in test_cases: + connection.execute(table.insert().values(id=test_id, decimal_default=input_value)) + + result = connection.scalar(select(table.c.decimal_default).where(table.c.id == test_id)) + + # All should be returned as Decimal + assert isinstance(result, decimal.Decimal), f"Failed for {description}: got {type(result)}" + + # Verify the value is approximately correct + expected = decimal.Decimal(str(input_value)) + error_str = f"Failed for {description}: expected {expected}, got {result}" + assert abs(result - expected) < decimal.Decimal("0.01"), error_str + + def test_decimal_asdecimal_comparison(self, connection): + """Test comparison between asdecimal=True and asdecimal=False behavior""" + + table = self.tables.decimal_test + + test_value = decimal.Decimal("999.123") + + # Insert same value into both columns + connection.execute( + table.insert().values( + id=1500, + decimal_default=test_value, # asdecimal=True (default) + decimal_as_float=test_value, # asdecimal=False + ) + ) + + # Get results from both columns + result_as_decimal = connection.scalar(select(table.c.decimal_default).where(table.c.id == 1500)) + result_as_float = connection.scalar(select(table.c.decimal_as_float).where(table.c.id == 1500)) + + # Check types are different + assert isinstance(result_as_decimal, decimal.Decimal), f"Expected Decimal, got {type(result_as_decimal)}" + assert isinstance(result_as_float, float), f"Expected float, got {type(result_as_float)}" + + # Check values are approximately equal + assert abs(result_as_decimal - test_value) < decimal.Decimal("0.001") + assert abs(result_as_float - float(test_value)) < 0.001 + + # Check that converting between them gives same value + assert abs(float(result_as_decimal) - result_as_float) < 0.001 diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 0f271f3..db1d1a6 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -140,6 +140,7 @@ class YqlDialect(StrCompileDialect): sa.types.DateTime: types.YqlTimestamp, # Because YDB's DateTime doesn't store microseconds sa.types.DATETIME: types.YqlDateTime, sa.types.TIMESTAMP: types.YqlTimestamp, + sa.types.DECIMAL: types.Decimal, } connection_characteristics = util.immutabledict( diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/base.py b/ydb_sqlalchemy/sqlalchemy/compiler/base.py index 7bb4469..de803c2 100644 --- a/ydb_sqlalchemy/sqlalchemy/compiler/base.py +++ b/ydb_sqlalchemy/sqlalchemy/compiler/base.py @@ -113,9 +113,13 @@ def visit_INTEGER(self, type_: sa.INTEGER, **kw): return "Int64" def visit_NUMERIC(self, type_: sa.Numeric, **kw): - """Only Decimal(22,9) is supported for table columns""" return f"Decimal({type_.precision}, {type_.scale})" + def visit_DECIMAL(self, type_: sa.DECIMAL, **kw): + precision = getattr(type_, "precision", None) or 22 + scale = getattr(type_, "scale", None) or 9 + return f"Decimal({precision}, {scale})" + def visit_BINARY(self, type_: sa.BINARY, **kw): return "String" @@ -204,7 +208,9 @@ def get_ydb_type( elif isinstance(type_, sa.Boolean): ydb_type = ydb.PrimitiveType.Bool elif isinstance(type_, sa.Numeric): - ydb_type = ydb.DecimalType(type_.precision, type_.scale) + precision = getattr(type_, "precision", None) or 22 + scale = getattr(type_, "scale", None) or 9 + ydb_type = ydb.DecimalType(precision, scale) elif isinstance(type_, (types.ListType, sa.ARRAY)): ydb_type = ydb.ListType(self.get_ydb_type(type_.item_type, is_optional=False)) elif isinstance(type_, types.StructType): diff --git a/ydb_sqlalchemy/sqlalchemy/types.py b/ydb_sqlalchemy/sqlalchemy/types.py index 261eb9f..d8d601c 100644 --- a/ydb_sqlalchemy/sqlalchemy/types.py +++ b/ydb_sqlalchemy/sqlalchemy/types.py @@ -1,3 +1,4 @@ +import decimal from typing import Any, Mapping, Type, Union from sqlalchemy import __version__ as sa_version @@ -46,6 +47,66 @@ class Int8(types.Integer): __visit_name__ = "int8" +class Decimal(types.DECIMAL): + __visit_name__ = "DECIMAL" + + def __init__(self, precision=None, scale=None, asdecimal=True): + # YDB supports Decimal(22,9) by default + if precision is None: + precision = 22 + if scale is None: + scale = 9 + super().__init__(precision=precision, scale=scale, asdecimal=asdecimal) + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + # Convert float to Decimal if needed + if isinstance(value, float): + return decimal.Decimal(str(value)) + elif isinstance(value, str): + return decimal.Decimal(value) + elif not isinstance(value, decimal.Decimal): + return decimal.Decimal(str(value)) + return value + + return process + + def result_processor(self, dialect, coltype): + def process(value): + if value is None: + return None + + # YDB always returns Decimal values as decimal.Decimal objects + # But if asdecimal=False, we should convert to float + if not self.asdecimal: + return float(value) + + # For asdecimal=True (default), return as Decimal + if not isinstance(value, decimal.Decimal): + return decimal.Decimal(str(value)) + return value + + return process + + def literal_processor(self, dialect): + def process(value): + # Convert float to Decimal if needed + if isinstance(value, float): + value = decimal.Decimal(str(value)) + elif not isinstance(value, decimal.Decimal): + value = decimal.Decimal(str(value)) + + # Use default precision and scale if not specified + precision = self.precision if self.precision is not None else 22 + scale = self.scale if self.scale is not None else 9 + + return f'Decimal("{str(value)}", {precision}, {scale})' + + return process + + class ListType(ARRAY): __visit_name__ = "list_type"