Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions soda-core/src/soda_core/model/data_source/data_source.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
from typing import Literal
from typing import Dict, Literal

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from soda_core.model.data_source.data_source_connection_properties import (
DataSourceConnectionProperties,
)
Expand All @@ -17,11 +17,25 @@ class DataSourceBase(
name: str = Field(..., description="Data source name")
type: Literal["data_source"]
# Alias connection -> connection_properties to read that from yaml config
connection_properties: DataSourceConnectionProperties = Field(
..., alias="connection", description="Data source connection details"
connection_properties: Dict[str, DataSourceConnectionProperties] = Field(
..., alias="connections", description="Data source connection details"
)

@property
def default_connection(self) -> DataSourceConnectionProperties:
"""Default first iteration implemetation. Use first connection in the dict for now. Empty dict is not allowed."""
Copy link

Copilot AI May 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spelling mistake: 'implemetation' should be corrected to 'implementation'.

Suggested change
"""Default first iteration implemetation. Use first connection in the dict for now. Empty dict is not allowed."""
"""Default first iteration implementation. Use first connection in the dict for now. Empty dict is not allowed."""

Copilot uses AI. Check for mistakes.

# TODO: Decide and implement the connection selection strategy. Use "default" or "primary" connection if available? etc
Comment on lines +26 to +28
Copy link

Copilot AI May 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Consider explicitly selecting the 'default' connection key (if present) instead of relying on iteration order, to improve clarity and maintainability when multiple connections are provided.

Suggested change
"""Default first iteration implemetation. Use first connection in the dict for now. Empty dict is not allowed."""
# TODO: Decide and implement the connection selection strategy. Use "default" or "primary" connection if available? etc
"""Return the default connection. Prioritize the 'default' key if present; otherwise, use the first connection in the dict."""
if "default" in self.connection_properties:
return self.connection_properties["default"]

Copilot uses AI. Check for mistakes.
return next(iter(self.connection_properties.values()))

@classmethod
def get_class_type(cls) -> str:
"""The only way to get the simple type like 'postgres' for plugin registry before instantiating the class"""
return cls.model_fields["type"].default

@field_validator("connection_properties", mode="before")
@classmethod
def _validate_connections(cls, value):
if not isinstance(value, dict):
raise ValueError("connections must be a dict")
return {name: cls.infer_connection_type(conn) for name, conn in value.items()}
13 changes: 7 additions & 6 deletions soda-core/tests/components/test_data_source_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,17 @@ def test_data_source_env_var_resolving(env_vars: dict):
yaml_str=f"""
type: postgres
name: postgres_test_ds
connection:
host: ${{env.TEST_POSTGRES_HOST}}
user: ${{env.TEST_POSTGRES_USERNAME}}
password: '${{env.TEST_POSTGRES_PASSWORD}}'
database: ${{env.TEST_POSTGRES_DATABASE}}
connections:
default:
host: ${{env.TEST_POSTGRES_HOST}}
user: ${{env.TEST_POSTGRES_USERNAME}}
password: '${{env.TEST_POSTGRES_PASSWORD}}'
database: ${{env.TEST_POSTGRES_DATABASE}}
"""
)
data_source_impl: DataSourceImpl = DataSourceImpl.from_yaml_source(data_source_yaml_source)

connection_properties = data_source_impl.data_source_model.connection_properties
connection_properties = data_source_impl.data_source_model.connection_properties["default"]
assert connection_properties.host == "localhost"
assert connection_properties.user == "soda_test"
assert isinstance(connection_properties.password, SecretStr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _create_sql_dialect(self) -> SqlDialect:

def _create_data_source_connection(self) -> DataSourceConnection:
return DatabricksDataSourceConnection(
name=self.data_source_model.name, connection_properties=self.data_source_model.connection_properties
name=self.data_source_model.name, connection_properties=self.data_source_model.default_connection
)

def create_metadata_columns_query(self) -> DatabricksMetadataColumnsQuery:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@

class DatabricksDataSource(DataSourceBase, abc.ABC):
type: Literal["databricks"] = Field("databricks")
connection_properties: DatabricksConnectionProperties = Field(
..., alias="connection", description="Databricks connection configuration"
)

@field_validator("connection_properties", mode="before")
@classmethod
def infer_connection_type(cls, value):
if "access_token" in value:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ def _create_data_source_yaml_str(self) -> str:
return f"""
type: databricks
name: DATABRICKS_TEST_DS
connection:
host: {os.getenv("DATABRICKS_HOST")}
http_path: {os.getenv("DATABRICKS_HTTP_PATH")}
access_token: {os.getenv("DATABRICKS_TOKEN")}
catalog: {os.getenv("DATABRICKS_CATALOG", "unity_catalog")}
connections:
default:
host: {os.getenv("DATABRICKS_HOST")}
http_path: {os.getenv("DATABRICKS_HTTP_PATH")}
access_token: {os.getenv("DATABRICKS_TOKEN")}
catalog: {os.getenv("DATABRICKS_CATALOG", "unity_catalog")}
"""

def create_test_schema_if_not_exists_sql(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _create_sql_dialect(self) -> SqlDialect:

def _create_data_source_connection(self) -> DataSourceConnection:
return PostgresDataSourceConnection(
name=self.data_source_model.name, connection_properties=self.data_source_model.connection_properties
name=self.data_source_model.name, connection_properties=self.data_source_model.default_connection
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
from typing import Literal

from pydantic import Field, field_validator
from pydantic import Field
from soda_core.model.data_source.data_source import DataSourceBase
from soda_postgres.model.data_source.postgres_connection_properties import (
PostgresConnectionPassword,
Expand All @@ -12,14 +12,11 @@

class PostgresDataSource(DataSourceBase, abc.ABC):
type: Literal["postgres"] = Field("postgres")
connection_properties: PostgresConnectionProperties = Field(
..., alias="connection", description="Data source connection details"
)

@field_validator("connection_properties", mode="before")
def infer_connection_type(cls, value):
@classmethod
def infer_connection_type(cls, value: dict) -> PostgresConnectionProperties:
if "password" in value:
return PostgresConnectionPassword(**value)
elif "password_file" in value:
return PostgresConnectionPasswordFile(**value)
raise ValueError("Unknown connection structure")
raise ValueError("Unknown Postgres connection structure")
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ def _create_data_source_yaml_str(self) -> str:
return f"""
type: postgres
name: postgres_test_ds
connection:
host: {os.getenv("POSTGRES_HOST", "localhost")}
user: {os.getenv("POSTGRES_USERNAME", "soda_test")}
password: {os.getenv("POSTGRES_PASSWORD")}
port: {int(os.getenv("POSTGRES_PORT", "5432"))}
database: {self.dataset_prefix[0]}
connections:
default:
host: {os.getenv("POSTGRES_HOST", "localhost")}
user: {os.getenv("POSTGRES_USERNAME", "soda_test")}
password: {os.getenv("POSTGRES_PASSWORD")}
port: {int(os.getenv("POSTGRES_PORT", "5432"))}
database: {self.dataset_prefix[0]}
"""
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _create_sql_dialect(self) -> SqlDialect:

def _create_data_source_connection(self) -> DataSourceConnection:
return SnowflakeDataSourceConnection(
name=self.data_source_model.name, connection_properties=self.data_source_model.connection_properties
name=self.data_source_model.name, connection_properties=self.data_source_model.default_connection
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@

class SnowflakeDataSource(DataSourceBase, abc.ABC):
type: Literal["snowflake"] = Field("snowflake")
connection_properties: SnowflakeConnectionProperties = Field(
..., alias="connection", description="Snowflake connection configuration"
)

@field_validator("connection_properties", mode="before")
@classmethod
def infer_connection_type(cls, value):
if "password" in value:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ def _create_data_source_yaml_str(self) -> str:
return f"""
type: snowflake
name: SNOWFLAKE_TEST_DS
connection:
account: {os.getenv("SNOWFLAKE_ACCOUNT")}
user: {os.getenv("SNOWFLAKE_USER")}
password: {os.getenv("SNOWFLAKE_PASSWORD")}
database: {self.dataset_prefix[0]}
connections:
default:
account: {os.getenv("SNOWFLAKE_ACCOUNT")}
user: {os.getenv("SNOWFLAKE_USER")}
password: {os.getenv("SNOWFLAKE_PASSWORD")}
database: {self.dataset_prefix[0]}
"""

def create_test_schema_if_not_exists_sql(self) -> str:
Expand Down