Skip to content

Commit 743bc4a

Browse files
author
kabulov kozim
committed
fix orm for tables in directories
1 parent 6dc5578 commit 743bc4a

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

ydb_sqlalchemy/sqlalchemy/__init__.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -584,14 +584,15 @@ class YqlDialect(StrCompileDialect):
584584
def import_dbapi(cls: Any):
585585
return dbapi.YdbDBApi()
586586

587-
def __init__(self, json_serializer=None, json_deserializer=None, _add_declare_for_yql_stmt_vars=False, **kwargs):
587+
def __init__(self, json_serializer=None, json_deserializer=None, _add_declare_for_yql_stmt_vars=False, directories=[], **kwargs):
588588
super().__init__(**kwargs)
589589

590590
self._json_deserializer = json_deserializer
591591
self._json_serializer = json_serializer
592592
# NOTE: _add_declare_for_yql_stmt_vars is temporary and is soon to be removed.
593593
# no need in declare in yql statement here since ydb 24-1
594594
self._add_declare_for_yql_stmt_vars = _add_declare_for_yql_stmt_vars
595+
self._directories = directories
595596

596597
def _describe_table(self, connection, table_name, schema=None):
597598
if schema is not None:
@@ -673,6 +674,12 @@ def do_rollback(self, dbapi_connection: dbapi.Connection) -> None:
673674
def do_commit(self, dbapi_connection: dbapi.Connection) -> None:
674675
dbapi_connection.commit()
675676

677+
def _fix_variable_name(self, variable):
678+
for directory in self._directories:
679+
if variable.startswith(f"{directory}/"):
680+
return f"{directory}_" + variable[len(directory) + 1:]
681+
return variable
682+
676683
def _format_variables(
677684
self,
678685
statement: str,
@@ -689,12 +696,12 @@ def _format_variables(
689696
formatted_parameters = []
690697
for i in range(len(parameters_sequence)):
691698
variable_names.update(set(parameters_sequence[i].keys()))
692-
formatted_parameters.append({f"${k}": v for k, v in parameters_sequence[i].items()})
699+
formatted_parameters.append({f"${self._fix_variable_name(k)}": v for k, v in parameters_sequence[i].items()})
693700
else:
694701
variable_names = set(parameters.keys())
695-
formatted_parameters = {f"${k}": v for k, v in parameters.items()}
702+
formatted_parameters = {f"${self._fix_variable_name(k)}": v for k, v in parameters.items()}
696703

697-
formatted_variable_names = {variable_name: f"${variable_name}" for variable_name in variable_names}
704+
formatted_variable_names = {variable_name: f"${self._fix_variable_name(variable_name)}" for variable_name in variable_names}
698705
formatted_statement = formatted_statement % formatted_variable_names
699706

700707
formatted_statement = formatted_statement.replace("%%", "%")
@@ -717,7 +724,7 @@ def _make_ydb_operation(
717724

718725
if not is_ddl and parameters:
719726
parameters_types = context.compiled.get_bind_types(parameters)
720-
parameters_types = {f"${k}": v for k, v in parameters_types.items()}
727+
parameters_types = {f"${self._fix_variable_name(k)}": v for k, v in parameters_types.items()}
721728
statement, parameters = self._format_variables(statement, parameters, execute_many)
722729
if self._add_declare_for_yql_stmt_vars:
723730
statement = self._add_declare_for_yql_stmt_vars_impl(statement, parameters_types)

0 commit comments

Comments
 (0)