From 9cf22a59257e8dcaf61e351c07229da19037aaa9 Mon Sep 17 00:00:00 2001 From: Milan Lukac Date: Mon, 12 May 2025 15:47:06 +0200 Subject: [PATCH 1/2] Support multiple datasource connections --- .../model/data_source/data_source.py | 22 +++++++++++++++---- .../data_sources/databricks_data_source.py | 2 +- .../data_source/databricks_data_source.py | 4 ---- .../data_sources/postgres_data_source.py | 2 +- .../model/data_source/postgres_data_source.py | 11 ++++------ .../data_sources/snowflake_data_source.py | 2 +- .../data_source/snowflake_data_source.py | 4 ---- 7 files changed, 25 insertions(+), 22 deletions(-) diff --git a/soda-core/src/soda_core/model/data_source/data_source.py b/soda-core/src/soda_core/model/data_source/data_source.py index b447910e2..25bb50d86 100644 --- a/soda-core/src/soda_core/model/data_source/data_source.py +++ b/soda-core/src/soda_core/model/data_source/data_source.py @@ -1,7 +1,7 @@ import abc -from typing import Literal +from typing import Dict, Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from soda_core.model.data_source.data_source_connection_properties import ( DataSourceConnectionProperties, ) @@ -17,11 +17,25 @@ class DataSourceBase( name: str = Field(..., description="Data source name") type: Literal["data_source"] # Alias connection -> connection_properties to read that from yaml config - connection_properties: DataSourceConnectionProperties = Field( - ..., alias="connection", description="Data source connection details" + connection_properties: Dict[str, DataSourceConnectionProperties] = Field( + ..., alias="connections", description="Data source connection details" ) + @property + def default_connection(self) -> DataSourceConnectionProperties: + """Default first iteration implemetation. Use first connection in the dict for now. Empty dict is not allowed.""" + + # TODO: Decide and implement the connection selection strategy. Use "default" or "primary" connection if available? etc + return next(iter(self.connection_properties.values())) + @classmethod def get_class_type(cls) -> str: """The only way to get the simple type like 'postgres' for plugin registry before instantiating the class""" return cls.model_fields["type"].default + + @field_validator("connection_properties", mode="before") + @classmethod + def _validate_connections(cls, value): + if not isinstance(value, dict): + raise ValueError("connections must be a dict") + return {name: cls.infer_connection_type(conn) for name, conn in value.items()} diff --git a/soda-databricks/src/soda_databricks/common/data_sources/databricks_data_source.py b/soda-databricks/src/soda_databricks/common/data_sources/databricks_data_source.py index 72d524051..2c5b29af9 100644 --- a/soda-databricks/src/soda_databricks/common/data_sources/databricks_data_source.py +++ b/soda-databricks/src/soda_databricks/common/data_sources/databricks_data_source.py @@ -20,7 +20,7 @@ def _create_sql_dialect(self) -> SqlDialect: def _create_data_source_connection(self) -> DataSourceConnection: return DatabricksDataSourceConnection( - name=self.data_source_model.name, connection_properties=self.data_source_model.connection_properties + name=self.data_source_model.name, connection_properties=self.data_source_model.default_connection ) def create_metadata_columns_query(self) -> DatabricksMetadataColumnsQuery: diff --git a/soda-databricks/src/soda_databricks/model/data_source/databricks_data_source.py b/soda-databricks/src/soda_databricks/model/data_source/databricks_data_source.py index 98e8c667f..a776acc35 100644 --- a/soda-databricks/src/soda_databricks/model/data_source/databricks_data_source.py +++ b/soda-databricks/src/soda_databricks/model/data_source/databricks_data_source.py @@ -11,11 +11,7 @@ class DatabricksDataSource(DataSourceBase, abc.ABC): type: Literal["databricks"] = Field("databricks") - connection_properties: DatabricksConnectionProperties = Field( - ..., alias="connection", description="Databricks connection configuration" - ) - @field_validator("connection_properties", mode="before") @classmethod def infer_connection_type(cls, value): if "access_token" in value: diff --git a/soda-postgres/src/soda_postgres/common/data_sources/postgres_data_source.py b/soda-postgres/src/soda_postgres/common/data_sources/postgres_data_source.py index ff6f7c107..bb8de40d7 100644 --- a/soda-postgres/src/soda_postgres/common/data_sources/postgres_data_source.py +++ b/soda-postgres/src/soda_postgres/common/data_sources/postgres_data_source.py @@ -19,7 +19,7 @@ def _create_sql_dialect(self) -> SqlDialect: def _create_data_source_connection(self) -> DataSourceConnection: return PostgresDataSourceConnection( - name=self.data_source_model.name, connection_properties=self.data_source_model.connection_properties + name=self.data_source_model.name, connection_properties=self.data_source_model.default_connection ) diff --git a/soda-postgres/src/soda_postgres/model/data_source/postgres_data_source.py b/soda-postgres/src/soda_postgres/model/data_source/postgres_data_source.py index 33a0f9cfd..0997bd825 100644 --- a/soda-postgres/src/soda_postgres/model/data_source/postgres_data_source.py +++ b/soda-postgres/src/soda_postgres/model/data_source/postgres_data_source.py @@ -1,7 +1,7 @@ import abc from typing import Literal -from pydantic import Field, field_validator +from pydantic import Field from soda_core.model.data_source.data_source import DataSourceBase from soda_postgres.model.data_source.postgres_connection_properties import ( PostgresConnectionPassword, @@ -12,14 +12,11 @@ class PostgresDataSource(DataSourceBase, abc.ABC): type: Literal["postgres"] = Field("postgres") - connection_properties: PostgresConnectionProperties = Field( - ..., alias="connection", description="Data source connection details" - ) - @field_validator("connection_properties", mode="before") - def infer_connection_type(cls, value): + @classmethod + def infer_connection_type(cls, value: dict) -> PostgresConnectionProperties: if "password" in value: return PostgresConnectionPassword(**value) elif "password_file" in value: return PostgresConnectionPasswordFile(**value) - raise ValueError("Unknown connection structure") + raise ValueError("Unknown Postgres connection structure") diff --git a/soda-snowflake/src/soda_snowflake/common/data_sources/snowflake_data_source.py b/soda-snowflake/src/soda_snowflake/common/data_sources/snowflake_data_source.py index c547bedb0..f6c366c8b 100644 --- a/soda-snowflake/src/soda_snowflake/common/data_sources/snowflake_data_source.py +++ b/soda-snowflake/src/soda_snowflake/common/data_sources/snowflake_data_source.py @@ -19,7 +19,7 @@ def _create_sql_dialect(self) -> SqlDialect: def _create_data_source_connection(self) -> DataSourceConnection: return SnowflakeDataSourceConnection( - name=self.data_source_model.name, connection_properties=self.data_source_model.connection_properties + name=self.data_source_model.name, connection_properties=self.data_source_model.default_connection ) diff --git a/soda-snowflake/src/soda_snowflake/model/data_source/snowflake_data_source.py b/soda-snowflake/src/soda_snowflake/model/data_source/snowflake_data_source.py index 223c07c73..6fcf34939 100644 --- a/soda-snowflake/src/soda_snowflake/model/data_source/snowflake_data_source.py +++ b/soda-snowflake/src/soda_snowflake/model/data_source/snowflake_data_source.py @@ -15,11 +15,7 @@ class SnowflakeDataSource(DataSourceBase, abc.ABC): type: Literal["snowflake"] = Field("snowflake") - connection_properties: SnowflakeConnectionProperties = Field( - ..., alias="connection", description="Snowflake connection configuration" - ) - @field_validator("connection_properties", mode="before") @classmethod def infer_connection_type(cls, value): if "password" in value: From eed009597636146439d473904e2d172b847d4353 Mon Sep 17 00:00:00 2001 From: Milan Lukac Date: Tue, 13 May 2025 09:23:42 +0200 Subject: [PATCH 2/2] tests --- soda-core/tests/components/test_data_source_api.py | 13 +++++++------ .../databricks_data_source_test_helper.py | 11 ++++++----- .../postgres_data_source_test_helper.py | 13 +++++++------ .../snowflake_data_source_test_helper.py | 11 ++++++----- 4 files changed, 26 insertions(+), 22 deletions(-) diff --git a/soda-core/tests/components/test_data_source_api.py b/soda-core/tests/components/test_data_source_api.py index 658102d5a..e11896da5 100644 --- a/soda-core/tests/components/test_data_source_api.py +++ b/soda-core/tests/components/test_data_source_api.py @@ -26,16 +26,17 @@ def test_data_source_env_var_resolving(env_vars: dict): yaml_str=f""" type: postgres name: postgres_test_ds - connection: - host: ${{env.TEST_POSTGRES_HOST}} - user: ${{env.TEST_POSTGRES_USERNAME}} - password: '${{env.TEST_POSTGRES_PASSWORD}}' - database: ${{env.TEST_POSTGRES_DATABASE}} + connections: + default: + host: ${{env.TEST_POSTGRES_HOST}} + user: ${{env.TEST_POSTGRES_USERNAME}} + password: '${{env.TEST_POSTGRES_PASSWORD}}' + database: ${{env.TEST_POSTGRES_DATABASE}} """ ) data_source_impl: DataSourceImpl = DataSourceImpl.from_yaml_source(data_source_yaml_source) - connection_properties = data_source_impl.data_source_model.connection_properties + connection_properties = data_source_impl.data_source_model.connection_properties["default"] assert connection_properties.host == "localhost" assert connection_properties.user == "soda_test" assert isinstance(connection_properties.password, SecretStr) diff --git a/soda-databricks/tests/data_sources/databricks_data_source_test_helper.py b/soda-databricks/tests/data_sources/databricks_data_source_test_helper.py index d39736cda..f5b3280a5 100644 --- a/soda-databricks/tests/data_sources/databricks_data_source_test_helper.py +++ b/soda-databricks/tests/data_sources/databricks_data_source_test_helper.py @@ -27,11 +27,12 @@ def _create_data_source_yaml_str(self) -> str: return f""" type: databricks name: DATABRICKS_TEST_DS - connection: - host: {os.getenv("DATABRICKS_HOST")} - http_path: {os.getenv("DATABRICKS_HTTP_PATH")} - access_token: {os.getenv("DATABRICKS_TOKEN")} - catalog: {os.getenv("DATABRICKS_CATALOG", "unity_catalog")} + connections: + default: + host: {os.getenv("DATABRICKS_HOST")} + http_path: {os.getenv("DATABRICKS_HTTP_PATH")} + access_token: {os.getenv("DATABRICKS_TOKEN")} + catalog: {os.getenv("DATABRICKS_CATALOG", "unity_catalog")} """ def create_test_schema_if_not_exists_sql(self) -> str: diff --git a/soda-postgres/tests/data_sources/postgres_data_source_test_helper.py b/soda-postgres/tests/data_sources/postgres_data_source_test_helper.py index 35dc1fb3f..3def8d306 100644 --- a/soda-postgres/tests/data_sources/postgres_data_source_test_helper.py +++ b/soda-postgres/tests/data_sources/postgres_data_source_test_helper.py @@ -20,10 +20,11 @@ def _create_data_source_yaml_str(self) -> str: return f""" type: postgres name: postgres_test_ds - connection: - host: {os.getenv("POSTGRES_HOST", "localhost")} - user: {os.getenv("POSTGRES_USERNAME", "soda_test")} - password: {os.getenv("POSTGRES_PASSWORD")} - port: {int(os.getenv("POSTGRES_PORT", "5432"))} - database: {self.dataset_prefix[0]} + connections: + default: + host: {os.getenv("POSTGRES_HOST", "localhost")} + user: {os.getenv("POSTGRES_USERNAME", "soda_test")} + password: {os.getenv("POSTGRES_PASSWORD")} + port: {int(os.getenv("POSTGRES_PORT", "5432"))} + database: {self.dataset_prefix[0]} """ diff --git a/soda-snowflake/tests/data_sources/snowflake_data_source_test_helper.py b/soda-snowflake/tests/data_sources/snowflake_data_source_test_helper.py index 636a2991e..d4ec6754f 100644 --- a/soda-snowflake/tests/data_sources/snowflake_data_source_test_helper.py +++ b/soda-snowflake/tests/data_sources/snowflake_data_source_test_helper.py @@ -27,11 +27,12 @@ def _create_data_source_yaml_str(self) -> str: return f""" type: snowflake name: SNOWFLAKE_TEST_DS - connection: - account: {os.getenv("SNOWFLAKE_ACCOUNT")} - user: {os.getenv("SNOWFLAKE_USER")} - password: {os.getenv("SNOWFLAKE_PASSWORD")} - database: {self.dataset_prefix[0]} + connections: + default: + account: {os.getenv("SNOWFLAKE_ACCOUNT")} + user: {os.getenv("SNOWFLAKE_USER")} + password: {os.getenv("SNOWFLAKE_PASSWORD")} + database: {self.dataset_prefix[0]} """ def create_test_schema_if_not_exists_sql(self) -> str: