Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 255 additions & 1 deletion test/test_suite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ctypes
import decimal

import pytest
import sqlalchemy as sa
Expand Down Expand Up @@ -266,7 +267,7 @@ def test_huge_int_auto_accommodation(self, connection, intvalue):
pass


@pytest.mark.skip("TODO: fix & skip those tests - add Double/Decimal support. see #12")
@pytest.mark.skip("Use YdbDecimalTest for Decimal type testing")
class NumericTest(_NumericTest):
# SqlAlchemy maybe eat Decimal and throw Double
pass
Expand Down Expand Up @@ -596,3 +597,256 @@ class RowFetchTest(_RowFetchTest):
@pytest.mark.skip("scalar subquery unsupported")
def test_row_w_scalar_select(self, connection):
pass


class DecimalTest(fixtures.TablesTest):
"""Tests for YDB Decimal type using standard sa.DECIMAL"""

@classmethod
def define_tables(cls, metadata):
Table(
"decimal_test",
metadata,
Column("id", Integer, primary_key=True),
Column("decimal_default", sa.DECIMAL), # Default: precision=22, scale=9
Column("decimal_custom", sa.DECIMAL(precision=10, scale=2)),
Column("decimal_as_float", sa.DECIMAL(asdecimal=False)), # Should behave like Float
)

def test_decimal_basic_operations(self, connection):
"""Test basic insert and select operations with Decimal"""

table = self.tables.decimal_test

test_values = [
decimal.Decimal("1"),
decimal.Decimal("2"),
decimal.Decimal("3"),
]

# Insert test values
for i, val in enumerate(test_values):
connection.execute(table.insert().values(id=i + 1, decimal_default=val))

# Select and verify
results = connection.execute(select(table.c.decimal_default).order_by(table.c.id)).fetchall()

for i, (result,) in enumerate(results):
expected = test_values[i]
assert isinstance(result, decimal.Decimal)
assert result == expected

def test_decimal_with_precision_scale(self, connection):
"""Test Decimal with specific precision and scale"""

table = self.tables.decimal_test

# Test value that fits precision(10, 2)
test_value = decimal.Decimal("12345678.99")

connection.execute(table.insert().values(id=100, decimal_custom=test_value))

result = connection.scalar(select(table.c.decimal_custom).where(table.c.id == 100))

assert isinstance(result, decimal.Decimal)
assert result == test_value

def test_decimal_literal_rendering(self, connection):
"""Test literal rendering of Decimal values"""
from sqlalchemy import literal

table = self.tables.decimal_test

# Test literal in INSERT
test_value = decimal.Decimal("999.99")

connection.execute(table.insert().values(id=300, decimal_default=literal(test_value, sa.DECIMAL())))

result = connection.scalar(select(table.c.decimal_default).where(table.c.id == 300))

assert isinstance(result, decimal.Decimal)
assert result == test_value

def test_decimal_overflow(self, connection):
"""Test behavior when precision is exceeded"""

table = self.tables.decimal_test

# Try to insert value that exceeds precision=10, scale=2
overflow_value = decimal.Decimal("99999.99999")

with pytest.raises(Exception): # Should raise some kind of database error
connection.execute(table.insert().values(id=500, decimal_custom=overflow_value))
connection.commit()

def test_decimal_asdecimal_false(self, connection):
"""Test DECIMAL with asdecimal=False (should return float)"""

table = self.tables.decimal_test

test_value = decimal.Decimal("123.45")

connection.execute(table.insert().values(id=600, decimal_as_float=test_value))

result = connection.scalar(select(table.c.decimal_as_float).where(table.c.id == 600))

assert isinstance(result, float), f"Expected float, got {type(result)}"
assert abs(result - 123.45) < 0.01

def test_decimal_arithmetic(self, connection):
"""Test arithmetic operations with Decimal columns"""

table = self.tables.decimal_test

val1 = decimal.Decimal("100.50")
val2 = decimal.Decimal("25.25")

connection.execute(table.insert().values(id=900, decimal_default=val1))
connection.execute(table.insert().values(id=901, decimal_default=val2))

# Test various arithmetic operations
addition_result = connection.scalar(
select(table.c.decimal_default + decimal.Decimal("10.00")).where(table.c.id == 900)
)

subtraction_result = connection.scalar(
select(table.c.decimal_default - decimal.Decimal("5.25")).where(table.c.id == 900)
)

multiplication_result = connection.scalar(
select(table.c.decimal_default * decimal.Decimal("2.0")).where(table.c.id == 901)
)

division_result = connection.scalar(
select(table.c.decimal_default / decimal.Decimal("2.0")).where(table.c.id == 901)
)

# Verify results
assert abs(addition_result - decimal.Decimal("110.50")) < decimal.Decimal("0.01")
assert abs(subtraction_result - decimal.Decimal("95.25")) < decimal.Decimal("0.01")
assert abs(multiplication_result - decimal.Decimal("50.50")) < decimal.Decimal("0.01")
assert abs(division_result - decimal.Decimal("12.625")) < decimal.Decimal("0.01")

def test_decimal_comparison_operations(self, connection):
"""Test comparison operations with Decimal columns"""

table = self.tables.decimal_test

values = [
decimal.Decimal("10.50"),
decimal.Decimal("20.75"),
decimal.Decimal("15.25"),
]

for i, val in enumerate(values):
connection.execute(table.insert().values(id=1000 + i, decimal_default=val))

# Test various comparisons
greater_than = connection.execute(
select(table.c.id).where(table.c.decimal_default > decimal.Decimal("15.00")).order_by(table.c.id)
).fetchall()

less_than = connection.execute(
select(table.c.id).where(table.c.decimal_default < decimal.Decimal("15.00")).order_by(table.c.id)
).fetchall()

equal_to = connection.execute(
select(table.c.id).where(table.c.decimal_default == decimal.Decimal("15.25"))
).fetchall()

between_values = connection.execute(
select(table.c.id)
.where(table.c.decimal_default.between(decimal.Decimal("15.00"), decimal.Decimal("21.00")))
.order_by(table.c.id)
).fetchall()

# Verify results
assert len(greater_than) == 2 # 20.75 and 15.25
assert len(less_than) == 1 # 10.50
assert len(equal_to) == 1 # 15.25
assert len(between_values) == 2 # 20.75 and 15.25

def test_decimal_null_handling(self, connection):
"""Test NULL handling with Decimal columns"""

table = self.tables.decimal_test

# Insert NULL value
connection.execute(table.insert().values(id=1100, decimal_default=None))

# Insert non-NULL value for comparison
connection.execute(table.insert().values(id=1101, decimal_default=decimal.Decimal("42.42")))

# Test NULL retrieval
null_result = connection.scalar(select(table.c.decimal_default).where(table.c.id == 1100))

non_null_result = connection.scalar(select(table.c.decimal_default).where(table.c.id == 1101))

assert null_result is None
assert non_null_result == decimal.Decimal("42.42")

# Test IS NULL / IS NOT NULL
null_count = connection.scalar(select(func.count()).where(table.c.decimal_default.is_(None)))

not_null_count = connection.scalar(select(func.count()).where(table.c.decimal_default.isnot(None)))

# Should have at least 1 NULL and several non-NULL values from other tests
assert null_count >= 1
assert not_null_count >= 1

def test_decimal_input_type_conversion(self, connection):
"""Test that bind_processor handles different input types correctly (float, string, int, Decimal)"""

table = self.tables.decimal_test

# Test different input types that should all be converted to Decimal
test_cases = [
(1400, 123.45, "float input"), # float
(1401, "456.78", "string input"), # string
(1402, decimal.Decimal("789.12"), "decimal input"), # already Decimal
(1403, 100, "int input"), # int
]

for test_id, input_value, description in test_cases:
connection.execute(table.insert().values(id=test_id, decimal_default=input_value))

result = connection.scalar(select(table.c.decimal_default).where(table.c.id == test_id))

# All should be returned as Decimal
assert isinstance(result, decimal.Decimal), f"Failed for {description}: got {type(result)}"

# Verify the value is approximately correct
expected = decimal.Decimal(str(input_value))
error_str = f"Failed for {description}: expected {expected}, got {result}"
assert abs(result - expected) < decimal.Decimal("0.01"), error_str

def test_decimal_asdecimal_comparison(self, connection):
"""Test comparison between asdecimal=True and asdecimal=False behavior"""

table = self.tables.decimal_test

test_value = decimal.Decimal("999.123")

# Insert same value into both columns
connection.execute(
table.insert().values(
id=1500,
decimal_default=test_value, # asdecimal=True (default)
decimal_as_float=test_value, # asdecimal=False
)
)

# Get results from both columns
result_as_decimal = connection.scalar(select(table.c.decimal_default).where(table.c.id == 1500))
result_as_float = connection.scalar(select(table.c.decimal_as_float).where(table.c.id == 1500))

# Check types are different
assert isinstance(result_as_decimal, decimal.Decimal), f"Expected Decimal, got {type(result_as_decimal)}"
assert isinstance(result_as_float, float), f"Expected float, got {type(result_as_float)}"

# Check values are approximately equal
assert abs(result_as_decimal - test_value) < decimal.Decimal("0.001")
assert abs(result_as_float - float(test_value)) < 0.001

# Check that converting between them gives same value
assert abs(float(result_as_decimal) - result_as_float) < 0.001
1 change: 1 addition & 0 deletions ydb_sqlalchemy/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class YqlDialect(StrCompileDialect):
sa.types.DateTime: types.YqlTimestamp, # Because YDB's DateTime doesn't store microseconds
sa.types.DATETIME: types.YqlDateTime,
sa.types.TIMESTAMP: types.YqlTimestamp,
sa.types.DECIMAL: types.Decimal,
}

connection_characteristics = util.immutabledict(
Expand Down
10 changes: 8 additions & 2 deletions ydb_sqlalchemy/sqlalchemy/compiler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,13 @@ def visit_INTEGER(self, type_: sa.INTEGER, **kw):
return "Int64"

def visit_NUMERIC(self, type_: sa.Numeric, **kw):
"""Only Decimal(22,9) is supported for table columns"""
return f"Decimal({type_.precision}, {type_.scale})"

def visit_DECIMAL(self, type_: sa.DECIMAL, **kw):
precision = getattr(type_, "precision", None) or 22
scale = getattr(type_, "scale", None) or 9
return f"Decimal({precision}, {scale})"

def visit_BINARY(self, type_: sa.BINARY, **kw):
return "String"

Expand Down Expand Up @@ -204,7 +208,9 @@ def get_ydb_type(
elif isinstance(type_, sa.Boolean):
ydb_type = ydb.PrimitiveType.Bool
elif isinstance(type_, sa.Numeric):
ydb_type = ydb.DecimalType(type_.precision, type_.scale)
precision = getattr(type_, "precision", None) or 22
scale = getattr(type_, "scale", None) or 9
ydb_type = ydb.DecimalType(precision, scale)
elif isinstance(type_, (types.ListType, sa.ARRAY)):
ydb_type = ydb.ListType(self.get_ydb_type(type_.item_type, is_optional=False))
elif isinstance(type_, types.StructType):
Expand Down
61 changes: 61 additions & 0 deletions ydb_sqlalchemy/sqlalchemy/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import decimal
from typing import Any, Mapping, Type, Union

from sqlalchemy import __version__ as sa_version
Expand Down Expand Up @@ -46,6 +47,66 @@ class Int8(types.Integer):
__visit_name__ = "int8"


class Decimal(types.DECIMAL):
__visit_name__ = "DECIMAL"

def __init__(self, precision=None, scale=None, asdecimal=True):
# YDB supports Decimal(22,9) by default
if precision is None:
precision = 22
if scale is None:
scale = 9
super().__init__(precision=precision, scale=scale, asdecimal=asdecimal)

def bind_processor(self, dialect):
def process(value):
if value is None:
return None
# Convert float to Decimal if needed
if isinstance(value, float):
return decimal.Decimal(str(value))
elif isinstance(value, str):
return decimal.Decimal(value)
elif not isinstance(value, decimal.Decimal):
return decimal.Decimal(str(value))
return value

return process

def result_processor(self, dialect, coltype):
def process(value):
if value is None:
return None

# YDB always returns Decimal values as decimal.Decimal objects
# But if asdecimal=False, we should convert to float
if not self.asdecimal:
return float(value)

# For asdecimal=True (default), return as Decimal
if not isinstance(value, decimal.Decimal):
return decimal.Decimal(str(value))
return value

return process

def literal_processor(self, dialect):
def process(value):
# Convert float to Decimal if needed
if isinstance(value, float):
value = decimal.Decimal(str(value))
elif not isinstance(value, decimal.Decimal):
value = decimal.Decimal(str(value))

# Use default precision and scale if not specified
precision = self.precision if self.precision is not None else 22
scale = self.scale if self.scale is not None else 9

return f'Decimal("{str(value)}", {precision}, {scale})'

return process


class ListType(ARRAY):
__visit_name__ = "list_type"

Expand Down