Skip to content

Commit d708a56

Browse files
committed
Use "Datetime" ydb type for sa.DATETIME
1 parent 2c0cb4a commit d708a56

File tree

4 files changed

+37
-6
lines changed

4 files changed

+37
-6
lines changed

test/test_core.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,16 @@ def test_integer_types(self, connection):
253253
assert result == (b"Uint8", b"Uint16", b"Uint32", b"Uint64", b"Int8", b"Int16", b"Int32", b"Int64")
254254

255255
def test_datetime_types(self, connection: sa.Connection):
256+
stmt = sa.Select(
257+
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_datetime", datetime.datetime.now(), sa.DateTime))),
258+
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_DATETIME", datetime.datetime.now(), sa.DATETIME))),
259+
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_TIMESTAMP", datetime.datetime.now(), sa.TIMESTAMP))),
260+
)
261+
262+
result = connection.execute(stmt).fetchone()
263+
assert result == (b"Timestamp", b"Datetime", b"Timestamp")
264+
265+
def test_datetime_types_timezone(self, connection: sa.Connection):
256266
table = self.tables.test_datetime_types
257267

258268
now_dt = datetime.datetime.now()

ydb_sqlalchemy/sqlalchemy/__init__.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,13 @@ def visit_BINARY(self, type_: sa.BINARY, **kw):
132132
def visit_BLOB(self, type_: sa.BLOB, **kw):
133133
return "String"
134134

135-
def visit_DATETIME(self, type_: sa.TIMESTAMP, **kw):
135+
def visit_datetime(self, type_: sa.TIMESTAMP, **kw):
136+
return self.visit_TIMESTAMP(type_, **kw)
137+
138+
def visit_DATETIME(self, type_: sa.DATETIME, **kw):
139+
return "DateTime"
140+
141+
def visit_TIMESTAMP(self, type_: sa.TIMESTAMP, **kw):
136142
return "Timestamp"
137143

138144
def visit_list_type(self, type_: types.ListType, **kw):
@@ -193,7 +199,10 @@ def get_ydb_type(
193199
elif isinstance(type_, types.YqlJSON.YqlJSONPathType):
194200
ydb_type = ydb.PrimitiveType.Utf8
195201
# Json
196-
202+
elif isinstance(type_, sa.DATETIME):
203+
ydb_type = ydb.PrimitiveType.Datetime
204+
elif isinstance(type_, sa.TIMESTAMP):
205+
ydb_type = ydb.PrimitiveType.Timestamp
197206
elif isinstance(type_, sa.DateTime):
198207
ydb_type = ydb.PrimitiveType.Timestamp
199208
elif isinstance(type_, sa.Date):
@@ -610,7 +619,9 @@ class YqlDialect(StrCompileDialect):
610619
colspecs = {
611620
sa.types.JSON: types.YqlJSON,
612621
sa.types.JSON.JSONPathType: types.YqlJSON.YqlJSONPathType,
613-
sa.types.DateTime: types.YqlDateTime,
622+
sa.types.DateTime: types.YqlTimestamp,
623+
sa.types.DATETIME: types.YqlDateTime,
624+
sa.types.TIMESTAMP: types.YqlTimestamp,
614625
}
615626

616627
connection_characteristics = util.immutabledict(

ydb_sqlalchemy/sqlalchemy/datetime_types.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
from sqlalchemy import Dialect
55
from sqlalchemy import types as sqltypes
6-
from sqlalchemy.sql.type_api import _ResultProcessorType
6+
from sqlalchemy.sql.type_api import _BindProcessorType, _ResultProcessorType
77

88

9-
class YqlDateTime(sqltypes.DateTime):
9+
class YqlTimestamp(sqltypes.DateTime):
1010
def result_processor(self, dialect: Dialect, coltype: str) -> Optional[_ResultProcessorType[datetime.datetime]]:
1111
def process(value: Optional[datetime.datetime]) -> Optional[datetime.datetime]:
1212
if value is None:
@@ -16,3 +16,13 @@ def process(value: Optional[datetime.datetime]) -> Optional[datetime.datetime]:
1616
if self.timezone:
1717
return process
1818
return None
19+
20+
21+
class YqlDateTime(YqlTimestamp):
22+
def bind_processor(self, dialect: Dialect) -> Optional[_BindProcessorType[datetime.datetime]]:
23+
def process(value: Optional[datetime.datetime]) -> Optional[int]:
24+
if value is None:
25+
return None
26+
return int(value.timestamp())
27+
28+
return process

ydb_sqlalchemy/sqlalchemy/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from sqlalchemy import ARRAY, ColumnElement, exc, types
44
from sqlalchemy.sql import type_api
55

6-
from .datetime_types import YqlDateTime # noqa: F401
6+
from .datetime_types import YqlDateTime, YqlTimestamp # noqa: F401
77
from .json import YqlJSON # noqa: F401
88

99

0 commit comments

Comments
 (0)