diff --git a/dev-requirements.txt b/dev-requirements.txt index 33de8e444..51988662f 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.10 +# This file is autogenerated by pip-compile with Python 3.11 # by the following command: # # pip-compile dev-requirements.in @@ -18,8 +18,6 @@ distlib==0.3.8 # via virtualenv docopt==0.6.2 # via tbump -exceptiongroup==1.3.0 - # via pytest filelock==3.16.1 # via virtualenv freezegun==1.5.1 @@ -62,15 +60,8 @@ tabulate==0.8.10 # via cli-ui tbump==6.11.0 # via -r dev-requirements.in -tomli==2.2.1 - # via - # build - # pip-tools - # pytest tomlkit==0.11.8 # via tbump -typing-extensions==4.14.0 - # via exceptiongroup unidecode==1.3.8 # via cli-ui virtualenv==20.26.6 diff --git a/soda-core/src/soda_core/common/data_source_impl.py b/soda-core/src/soda_core/common/data_source_impl.py index a7b203a1f..c25e27d27 100644 --- a/soda-core/src/soda_core/common/data_source_impl.py +++ b/soda-core/src/soda_core/common/data_source_impl.py @@ -6,6 +6,7 @@ from soda_core.common.data_source_connection import DataSourceConnection from soda_core.common.data_source_results import QueryResult, UpdateResult +from soda_core.common.dataset_identifier import DatasetIdentifier from soda_core.common.exceptions import DataSourceConnectionException from soda_core.common.logging_constants import soda_logger from soda_core.common.sql_dialect import SqlDialect @@ -145,16 +146,16 @@ def get_max_aggregation_query_length(self) -> int: # BigQuery: No documented limit on query size, but practical limits on complexity and performance. 
return 63 * 1024 * 1024 - def is_different_data_type(self, expected_column: ColumnMetadata, actual_column_metadata: ColumnMetadata) -> bool: + def is_different_data_type(self, expected_column: ColumnMetadata, actual_column: ColumnMetadata) -> bool: canonical_expected_data_type: str = self.get_canonical_data_type(expected_column.data_type) - canonical_actual_data_type: str = self.get_canonical_data_type(actual_column_metadata.data_type) + canonical_actual_data_type: str = self.get_canonical_data_type(actual_column.data_type) if canonical_expected_data_type != canonical_actual_data_type: return True if ( isinstance(expected_column.character_maximum_length, int) - and expected_column.character_maximum_length != actual_column_metadata.character_maximum_length + and expected_column.character_maximum_length != actual_column.character_maximum_length ): return True @@ -203,6 +204,12 @@ def test_connection_error_message(self) -> Optional[str]: def build_data_source(self) -> DataSource: return DataSource(name=self.name, type=self.type_name) + def qualify_dataset_name(self, dataset_identifier: DatasetIdentifier) -> str: + assert dataset_identifier.data_source_name == self.name + return self.sql_dialect.qualify_dataset_name( + dataset_prefix=dataset_identifier.prefixes, dataset_name=dataset_identifier.dataset_name + ) + def quote_identifier(self, identifier: str) -> str: c = self.sql_dialect._get_default_quote_char() return f"{c}{identifier}{c}" diff --git a/soda-core/src/soda_core/common/extensions.py b/soda-core/src/soda_core/common/extensions.py index a8e495b83..2947a60e8 100644 --- a/soda-core/src/soda_core/common/extensions.py +++ b/soda-core/src/soda_core/common/extensions.py @@ -1,8 +1,6 @@ from importlib import import_module from typing import Callable, Optional -from soda_core.common.exceptions import ExtensionException - class Extensions: @classmethod @@ -11,7 +9,6 @@ def find_class_method(cls, module_name: str, class_name: str, method_name: str) module = 
import_module(module_name) class_ = getattr(module, class_name) return getattr(class_, method_name) - except AttributeError as e: - raise ExtensionException( - message=f"Feature '{class_name}.{method_name}' requires the Soda Extensions to be installed." - ) + except (AttributeError, ModuleNotFoundError) as e: + # Extension not installed + return None diff --git a/soda-core/src/soda_core/common/soda_cloud.py b/soda-core/src/soda_core/common/soda_cloud.py index d8c6abfe9..f7a61d473 100644 --- a/soda-core/src/soda_core/common/soda_cloud.py +++ b/soda-core/src/soda_core/common/soda_cloud.py @@ -228,9 +228,9 @@ def upload_contract_file(self, contract: Contract) -> Optional[str]: soda_cloud_file_path: str = f"{contract.soda_qualified_dataset_name.lower()}.yml" return self._upload_scan_yaml_file(yaml_str=contract_yaml_source_str, soda_cloud_file_path=soda_cloud_file_path) - def send_contract_result(self, contract_verification_result: ContractVerificationResult) -> bool: + def send_contract_result(self, contract_verification_result: ContractVerificationResult) -> Optional[dict]: """ - Returns True if a 200 OK was received, False otherwise + Returns A scanId string if a 200 OK was received, None otherwise """ contract_verification_result = _build_contract_result_json( contract_verification_result=contract_verification_result @@ -241,14 +241,13 @@ def send_contract_result(self, contract_verification_result: ContractVerificatio ) if response.status_code == 200: logger.info(f"{Emoticons.OK_HAND} Results sent to Soda Cloud") - response_json = response.json() + response_json: dict = response.json() if isinstance(response_json, dict): cloud_url: Optional[str] = response_json.get("cloudUrl") if isinstance(cloud_url, str): logger.info(f"To view the dataset on Soda Cloud, see {cloud_url}") - return True - else: - return False + return response_json + return None def send_contract_skeleton(self, contract_yaml_str: str, soda_cloud_file_path: str) -> None: file_id: Optional[str] = 
self._upload_scan_yaml_file( @@ -879,6 +878,9 @@ def _get_token(self) -> str: assert self.token, "No token in login response?!" return self.token + def send_failed_rows_diagnostics(self, scan_id: str, failed_rows_diagnostics: list[FailedRowsDiagnostic]): + print(f"TODO sending failed rows diagnostics for scan {scan_id} to Soda Cloud: {failed_rows_diagnostics}") + def to_jsonnable(o) -> object: if o is None or isinstance(o, str) or isinstance(o, int) or isinstance(o, float) or isinstance(o, bool): @@ -1196,3 +1198,25 @@ def _append_exception_to_cloud_log_dicts(cloud_log_dicts: list[dict], exception: exc_cloud_log_dict["index"] = len(cloud_log_dicts) cloud_log_dicts.append(exc_cloud_log_dict) return cloud_log_dicts + + +class FailedRowsDiagnostic: + def __init__(self, check_identity: str, name: str, query: str): + self.check_identity: str = check_identity + self.name: str = name + self.query: str = query + + +class QuerySourceFailedRowsDiagnostic(FailedRowsDiagnostic): + def __init__(self, check_identity: str, name: str, query: str): + super().__init__(check_identity, name, query) + + +class StoreKeysFailedRowsDiagnostic(FailedRowsDiagnostic): + def __init__(self, check_identity: str, name: str, query: str): + super().__init__(check_identity, name, query) + + +class StoreDataFailedRowsDiagnostic(FailedRowsDiagnostic): + def __init__(self, check_identity: str, name: str, query: str): + super().__init__(check_identity, name, query) diff --git a/soda-core/src/soda_core/common/sql_dialect.py b/soda-core/src/soda_core/common/sql_dialect.py index d29266bb2..670bb5b90 100644 --- a/soda-core/src/soda_core/common/sql_dialect.py +++ b/soda-core/src/soda_core/common/sql_dialect.py @@ -5,6 +5,7 @@ from numbers import Number from textwrap import dedent, indent +from soda_core.common.dataset_identifier import DatasetIdentifier from soda_core.common.sql_ast import * @@ -28,6 +29,11 @@ def quote_default(self, identifier: Optional[str]) -> Optional[str]: else None ) + def 
build_fully_qualified_sql_name(self, dataset_identifier: DatasetIdentifier) -> str: + return self.qualify_dataset_name( + dataset_prefix=dataset_identifier.prefixes, dataset_name=dataset_identifier.dataset_name + ) + def qualify_dataset_name(self, dataset_prefix: list[str], dataset_name: str) -> str: """ Creates a fully qualified table name, optionally quoting the table name @@ -85,14 +91,18 @@ def escape_string(self, value: str): def escape_regex(self, value: str): return value - def build_select_sql(self, select_elements: list) -> str: + def create_schema_if_not_exists_sql(self, schema_name: str) -> str: + quoted_schema_name: str = self.quote_default(schema_name) + return f"CREATE SCHEMA IF NOT EXISTS {quoted_schema_name};" + + def build_select_sql(self, select_elements: list, add_semicolon: bool = True) -> str: statement_lines: list[str] = [] statement_lines.extend(self._build_cte_sql_lines(select_elements)) statement_lines.extend(self._build_select_sql_lines(select_elements)) statement_lines.extend(self._build_from_sql_lines(select_elements)) statement_lines.extend(self._build_where_sql_lines(select_elements)) statement_lines.extend(self._build_order_by_lines(select_elements)) - return "\n".join(statement_lines) + ";" + return "\n".join(statement_lines) + (";" if add_semicolon else "") def _build_select_sql_lines(self, select_elements: list) -> list[str]: select_field_sqls: list[str] = [] diff --git a/soda-core/src/soda_core/common/statements/metadata_columns_query.py b/soda-core/src/soda_core/common/statements/metadata_columns_query.py index e849586ff..31409bcb2 100644 --- a/soda-core/src/soda_core/common/statements/metadata_columns_query.py +++ b/soda-core/src/soda_core/common/statements/metadata_columns_query.py @@ -13,6 +13,12 @@ class ColumnMetadata: data_type: str character_maximum_length: Optional[int] + def get_data_type_ddl(self) -> str: + if self.character_maximum_length is None: + return self.data_type + else: + return 
f"{self.data_type}({self.character_maximum_length})" + class MetadataColumnsQuery: def __init__(self, sql_dialect: SqlDialect, data_source_connection: DataSourceConnection): diff --git a/soda-core/src/soda_core/contracts/impl/check_types/invalidity_check.py b/soda-core/src/soda_core/contracts/impl/check_types/invalidity_check.py index 028c99fce..5533495b8 100644 --- a/soda-core/src/soda_core/contracts/impl/check_types/invalidity_check.py +++ b/soda-core/src/soda_core/contracts/impl/check_types/invalidity_check.py @@ -98,7 +98,7 @@ def __init__( ) else: self.invalid_count_metric_impl = self._resolve_metric( - InvalidCountMetric(contract_impl=contract_impl, column_impl=column_impl, check_impl=self) + InvalidCountMetricImpl(contract_impl=contract_impl, column_impl=column_impl, check_impl=self) ) self.row_count_metric = self._resolve_metric(RowCountMetricImpl(contract_impl=contract_impl, check_impl=self)) @@ -145,8 +145,11 @@ def evaluate(self, measurement_values: MeasurementValues, contract: Contract) -> diagnostic_metric_values=diagnostic_metric_values, ) + def get_threshold_metric_impl(self) -> Optional[MetricImpl]: + return self.invalid_count_metric_impl -class InvalidCountMetric(AggregationMetricImpl): + +class InvalidCountMetricImpl(AggregationMetricImpl): def __init__( self, contract_impl: ContractImpl, @@ -162,15 +165,17 @@ def __init__( ) def sql_expression(self) -> SqlExpression: + return SUM(CASE_WHEN(self.sql_condition_expression(), LITERAL(1))) + + def sql_condition_expression(self) -> SqlExpression: column_name: str = self.column_impl.column_yaml.name - invalid_count_condition: SqlExpression = AND.optional( + return AND.optional( [ SqlExpressionStr.optional(self.check_filter), NOT.optional(self.missing_and_validity.is_missing_expr(column_name)), self.missing_and_validity.is_invalid_expr(column_name), ] ) - return SUM(CASE_WHEN(invalid_count_condition, LITERAL(1))) def convert_db_value(self, value) -> int: # Note: expression SUM(CASE WHEN "id" IS NULL THEN 
1 ELSE 0 END) gives NULL / None as a result if diff --git a/soda-core/src/soda_core/contracts/impl/check_types/missing_check.py b/soda-core/src/soda_core/contracts/impl/check_types/missing_check.py index c7e094c17..acc901c11 100644 --- a/soda-core/src/soda_core/contracts/impl/check_types/missing_check.py +++ b/soda-core/src/soda_core/contracts/impl/check_types/missing_check.py @@ -58,7 +58,7 @@ def __init__( ) self.metric_name = "missing_percent" if check_yaml.metric == "percent" else "missing_count" - self.missing_count_metric = self._resolve_metric( + self.missing_count_metric_impl = self._resolve_metric( MissingCountMetricImpl(contract_impl=contract_impl, column_impl=column_impl, check_impl=self) ) @@ -69,7 +69,7 @@ def __init__( self.missing_percent_metric_impl: MetricImpl = self.contract_impl.metrics_resolver.resolve_metric( DerivedPercentageMetricImpl( metric_type="missing_percent", - fraction_metric_impl=self.missing_count_metric, + fraction_metric_impl=self.missing_count_metric_impl, total_metric_impl=self.row_count_metric_impl, ) ) @@ -79,7 +79,7 @@ def evaluate(self, measurement_values: MeasurementValues, contract: Contract) -> diagnostic_metric_values: dict[str, float] = {} - missing_count: int = measurement_values.get_value(self.missing_count_metric) + missing_count: int = measurement_values.get_value(self.missing_count_metric_impl) diagnostic_metric_values["missing_count"] = missing_count row_count: int = measurement_values.get_value(self.row_count_metric_impl) @@ -104,6 +104,9 @@ def evaluate(self, measurement_values: MeasurementValues, contract: Contract) -> diagnostic_metric_values=diagnostic_metric_values, ) + def get_threshold_metric_impl(self) -> Optional[MetricImpl]: + return self.missing_count_metric_impl + class MissingCountMetricImpl(AggregationMetricImpl): def __init__( @@ -121,14 +124,16 @@ def __init__( ) def sql_expression(self) -> SqlExpression: + return SUM(CASE_WHEN(self.sql_condition_expression(), LITERAL(1))) + + def 
sql_condition_expression(self) -> SqlExpression: column_name: str = self.column_impl.column_yaml.name not_missing_and_invalid_expr = self.missing_and_validity.is_missing_expr(column_name) - missing_count_condition: SqlExpression = ( + return ( not_missing_and_invalid_expr if not self.check_filter else AND([SqlExpressionStr(self.check_filter), not_missing_and_invalid_expr]) ) - return SUM(CASE_WHEN(missing_count_condition, LITERAL(1), LITERAL(0))) def convert_db_value(self, value) -> int: # Note: expression SUM(CASE WHEN "id" IS NULL THEN 1 ELSE 0 END) gives NULL / None as a result if diff --git a/soda-core/src/soda_core/contracts/impl/check_types/schema_check.py b/soda-core/src/soda_core/contracts/impl/check_types/schema_check.py index 2f209f1a5..36693f49e 100644 --- a/soda-core/src/soda_core/contracts/impl/check_types/schema_check.py +++ b/soda-core/src/soda_core/contracts/impl/check_types/schema_check.py @@ -141,7 +141,7 @@ def evaluate(self, measurement_values: MeasurementValues, contract: Contract) -> actual_column_metadata and expected_column.data_type and self.contract_impl.data_source_impl.is_different_data_type( - expected_column=expected_column, actual_column_metadata=actual_column_metadata + expected_column=expected_column, actual_column=actual_column_metadata ) ): column_data_type_mismatches.append( diff --git a/soda-core/src/soda_core/contracts/impl/contract_verification_impl.py b/soda-core/src/soda_core/contracts/impl/contract_verification_impl.py index 29c549a06..e59619261 100644 --- a/soda-core/src/soda_core/contracts/impl/contract_verification_impl.py +++ b/soda-core/src/soda_core/contracts/impl/contract_verification_impl.py @@ -5,6 +5,7 @@ from datetime import timezone from enum import Enum from io import StringIO +from typing import Callable from ruamel.yaml import YAML from soda_core.common.consistent_hash_builder import ConsistentHashBuilder @@ -12,6 +13,7 @@ from soda_core.common.data_source_results import QueryResult from 
soda_core.common.dataset_identifier import DatasetIdentifier from soda_core.common.exceptions import InvalidRegexException, SodaCoreException +from soda_core.common.extensions import Extensions from soda_core.common.logging_constants import Emoticons, ExtraKeys, soda_logger from soda_core.common.logs import Location, Logs from soda_core.common.soda_cloud import SodaCloud @@ -31,6 +33,7 @@ ContractVerificationStatus, DataSource, Measurement, + SodaException, Threshold, YamlFileContentInfo, ) @@ -49,6 +52,28 @@ logger: logging.Logger = soda_logger +class ContractVerificationHandler: + @classmethod + def instance(cls, identifier: Optional[str] = None) -> Optional[ContractVerificationHandler]: + # TODO: replace with plugin extension mechanism + create_method: Callable[..., Optional[ContractVerificationHandler]] = Extensions.find_class_method( + module_name="soda.failed_rows_extractor.failed_rows_extractor", + class_name="FailedRowsExtractor", + method_name="create", + ) + return create_method() if create_method else None + + def handle( + self, + contract_impl: ContractImpl, + data_source_impl: DataSourceImpl, + contract_verification_result: ContractVerificationResult, + soda_cloud: SodaCloud, + soda_cloud_send_results_response_json: dict, + ): + pass + + class ContractVerificationSessionImpl: @classmethod def execute( @@ -177,6 +202,7 @@ def _execute_locally( contract_yaml=contract_yaml, only_validate_without_execute=only_validate_without_execute, data_timestamp=contract_yaml.data_timestamp, + execution_timestamp=contract_yaml.execution_timestamp, data_source_impl=data_source_impl, soda_cloud=soda_cloud_impl, publish_results=soda_cloud_publish_results, @@ -279,8 +305,9 @@ def __init__( logs: Logs, contract_yaml: ContractYaml, only_validate_without_execute: bool, - data_source_impl: DataSourceImpl, + data_source_impl: Optional[DataSourceImpl], data_timestamp: datetime, + execution_timestamp: datetime, soda_cloud: Optional[SodaCloud], publish_results: bool, ): @@ -295,6 
+322,7 @@ def __init__( self.started_timestamp: datetime = datetime.now(tz=timezone.utc) + self.execution_timestamp: datetime = execution_timestamp self.data_timestamp: datetime = data_timestamp self.dataset_name: Optional[str] = None @@ -308,15 +336,22 @@ def __init__( self.column_impls: list[ColumnImpl] = [] self.check_impls: list[CheckImpl] = [] + # TODO replace usage of self.soda_qualified_dataset_name with self.dataset_identifier self.soda_qualified_dataset_name = contract_yaml.dataset - - self.sql_qualified_dataset_name: Optional[str] = ( - data_source_impl.sql_dialect.qualify_dataset_name( + # TODO replace usage of self.sql_qualified_dataset_name with self.dataset_identifier + self.sql_qualified_dataset_name: Optional[str] = None + + self.dataset_identifier: Optional[DatasetIdentifier] = None + if data_source_impl: + self.dataset_identifier = DatasetIdentifier( + data_source_name=self.data_source_impl.name, + prefixes=self.dataset_prefix, + dataset_name=self.dataset_name, + ) + # TODO replace usage of self.sql_qualified_dataset_name with self.dataset_identifier + self.sql_qualified_dataset_name = data_source_impl.sql_dialect.qualify_dataset_name( dataset_prefix=self.dataset_prefix, dataset_name=self.dataset_name ) - if data_source_impl - else None - ) self.column_impls: list[ColumnImpl] = self._parse_columns(contract_yaml=contract_yaml) self.check_impls: list[CheckImpl] = self._parse_checks(contract_yaml) @@ -333,9 +368,11 @@ def __init__( ) self._verify_duplicate_identities(self.all_check_impls) - self.metrics: list[MetricImpl] = self.metrics_resolver.get_resolved_metrics() - self.queries: list[Query] = self._build_queries() if data_source_impl else [] + + self.queries: list[Query] = [] + if data_source_impl: + self.queries = self._build_queries() def _dataset_checks_came_before_columns_in_yaml(self) -> Optional[bool]: contract_keys: list[str] = self.contract_yaml.contract_yaml_object.keys() @@ -466,18 +503,30 @@ def verify(self) -> 
ContractVerificationResult: contract_verification_result.log_records = self.logs.pop_log_records() + soda_cloud_response_json: Optional[dict] = None if self.soda_cloud and self.publish_results: file_id: Optional[str] = self.soda_cloud.upload_contract_file(contract_verification_result.contract) if file_id: # Side effect to pass file id to console logging later on. TODO reconsider this contract.source.soda_cloud_file_id = file_id # send_contract_result will use contract.source.soda_cloud_file_id - response_ok: bool = self.soda_cloud.send_contract_result(contract_verification_result) - if not response_ok: + soda_cloud_response_json = self.soda_cloud.send_contract_result(contract_verification_result) + scan_id: Optional[str] = soda_cloud_response_json.get("scanId") if soda_cloud_response_json else None + if not scan_id: contract_verification_result.sending_results_to_soda_cloud_failed = True else: logger.debug(f"Not sending results to Soda Cloud {Emoticons.CROSS_MARK}") + contract_verification_handler: Optional[ContractVerificationHandler] = ContractVerificationHandler.instance() + if contract_verification_handler: + contract_verification_handler.handle( + contract_impl=self, + data_source_impl=self.data_source_impl, + contract_verification_result=contract_verification_result, + soda_cloud=self.soda_cloud, + soda_cloud_send_results_response_json=soda_cloud_response_json, + ) + return contract_verification_result def build_log_summary(self, contract_verification_result: ContractVerificationResult) -> str: @@ -1090,6 +1139,12 @@ def _build_definition(self) -> str: def _build_threshold(self) -> Optional[Threshold]: return self.threshold.to_threshold_info() if self.threshold else None + def get_threshold_metric_impl(self) -> Optional[MetricImpl]: + """ + Used in extensions + """ + raise SodaException(f"Check type '{self.type}' does not support 'get_threshold_metric_impl'") + class MissingAndValidityCheckImpl(CheckImpl): def __init__( @@ -1158,6 +1213,10 @@ def __eq__(self, other): return False return self.id == 
other.id + @abstractmethod + def sql_condition_expression(self) -> Optional[SqlExpression]: + pass + class AggregationMetricImpl(MetricImpl): def __init__( @@ -1180,6 +1239,12 @@ def __init__( def sql_expression(self) -> SqlExpression: pass + @abstractmethod + def sql_condition_expression(self) -> SqlExpression: + """ + Used in extensions + """ + def convert_db_value(self, value: any) -> any: return value diff --git a/soda-core/src/soda_core/contracts/impl/contract_yaml.py b/soda-core/src/soda_core/contracts/impl/contract_yaml.py index 5c7f44fed..93e8f1c62 100644 --- a/soda-core/src/soda_core/contracts/impl/contract_yaml.py +++ b/soda-core/src/soda_core/contracts/impl/contract_yaml.py @@ -155,11 +155,11 @@ def __init__( self.variables: list[VariableYaml] = self._parse_variable_yamls(contract_yaml_source, provided_variable_values) - soda_now: datetime = datetime.now(timezone.utc) - self.data_timestamp: datetime = self._get_data_timestamp(data_timestamp, soda_now) + self.execution_timestamp: datetime = datetime.now(timezone.utc) + self.data_timestamp: datetime = self._get_data_timestamp(data_timestamp, self.execution_timestamp) soda_variable_values: dict[str, str] = { - "NOW": convert_datetime_to_str(soda_now), + "NOW": convert_datetime_to_str(self.execution_timestamp), "DATA_TIMESTAMP": convert_datetime_to_str(self.data_timestamp), } @@ -570,6 +570,9 @@ def __init__(self, type_name: str, check_yaml_object: YamlObject): qualifier = check_yaml_object.read_value("qualifier") if check_yaml_object else None self.qualifier: Optional[str] = str(qualifier) if qualifier is not None else None self.filter: Optional[str] = check_yaml_object.read_string_opt("filter") if check_yaml_object else None + self.store_failed_rows: Optional[bool] = ( + check_yaml_object.read_bool_opt("store_failed_rows", default_value=False) if check_yaml_object else None + ) if self.filter: self.filter = self.filter.strip() diff --git a/soda-core/src/soda_core/model/data_source/data_source.py 
b/soda-core/src/soda_core/model/data_source/data_source.py index b447910e2..3f0562bb3 100644 --- a/soda-core/src/soda_core/model/data_source/data_source.py +++ b/soda-core/src/soda_core/model/data_source/data_source.py @@ -1,10 +1,11 @@ import abc -from typing import Literal +from typing import Literal, Optional from pydantic import BaseModel, Field from soda_core.model.data_source.data_source_connection_properties import ( DataSourceConnectionProperties, ) +from soda_core.model.failed_rows import FailedRowsConfigDatasource class DataSourceBase( @@ -20,6 +21,7 @@ class DataSourceBase( connection_properties: DataSourceConnectionProperties = Field( ..., alias="connection", description="Data source connection details" ) + failed_rows: Optional[FailedRowsConfigDatasource] = Field(None, description="Configuration for failed rows storage") @classmethod def get_class_type(cls) -> str: diff --git a/soda-core/src/soda_core/model/failed_rows.py b/soda-core/src/soda_core/model/failed_rows.py new file mode 100644 index 000000000..cf3cf0085 --- /dev/null +++ b/soda-core/src/soda_core/model/failed_rows.py @@ -0,0 +1,79 @@ +import abc +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + +DEFAULT_SCHEMA = "failed_rows" + + +class FailedRowsStrategy(str, Enum): + """Enum for failed rows strategy""" + + NONE = "none" + STORE_DIAGNOSTICS = "store_diagnostics" + STORE_KEYS = "store_keys" + STORE_DATA = "store_data" + + +class FailedRowsConfigBase( + BaseModel, + abc.ABC, + frozen=True, + extra="forbid", + validate_by_name=True, # Allow to use both field names and aliases when populating from dict +): + enabled: Optional[bool] = Field( + None, + description="Enable or disable the storage of failed rows. 
" "If set to false, failed rows will not be stored.", + ) + + +class FailedRowsConfigOrganisation(FailedRowsConfigBase): + """Top-level configuration for failed rows, used on Soda Cloud.""" + + # Overrides the `enabled` field on other levels as Cloud config will return a bool always. + enabled: bool = Field( + ..., + description="Enable or disable the storage of failed rows. " "If set to false, failed rows will not be stored.", + ) + + path_default: Optional[str] = Field( + "{{ data_source.database }}/{DEFAULT_SCHEMA}", # TODO: revisit + description="Path to the warehouse location where failed rows will be stored. ", + ) + enabled_by_default: bool = Field( + True, + description="Enable or disable the storage of failed rows by default. Does not override the `enabled` setting if `enabled` is set to false." + "If set to false, failed rows will not be stored unless explicitly enabled in the contract or check.", + ) + strategy_default: FailedRowsStrategy = Field( + FailedRowsStrategy.STORE_DIAGNOSTICS, description="Default strategy for storing failed rows." + ) + + +class FailedRowsConfigDatasource(FailedRowsConfigBase): + """Top-level configuration for failed rows, on data source level.""" + + path: Optional[str] = Field(None, description="Path to the warehouse location where failed rows will be stored.") + enabled_by_default: Optional[bool] = Field( + True, + description="Enable or disable the storage of failed rows by default. Does not override the `enabled` setting if `enabled` is set to false." + "If set to false, failed rows will not be stored unless explicitly enabled in the contract or check.", + ) + strategy_default: Optional[FailedRowsStrategy] = Field( + FailedRowsStrategy.STORE_DIAGNOSTICS, description="Default strategy for storing failed rows." 
+ ) + + +class FailedRowsConfigContract(FailedRowsConfigBase): + """Configuration for failed rows at the contract level.""" + + path: Optional[str] = Field(None, description="Path to the warehouse location where failed rows will be stored.") + strategy: Optional[FailedRowsStrategy] = Field(None, description="Strategy for storing failed rows.") + + +class FailedRowsConfigCheck(FailedRowsConfigBase): + """Configuration for failed rows at the check level.""" + + strategy: Optional[FailedRowsStrategy] = Field(None, description="Strategy for storing failed rows.") diff --git a/soda-postgres/src/soda_postgres/common/data_sources/postgres_data_source.py b/soda-postgres/src/soda_postgres/common/data_sources/postgres_data_source.py index ff6f7c107..4d3897360 100644 --- a/soda-postgres/src/soda_postgres/common/data_sources/postgres_data_source.py +++ b/soda-postgres/src/soda_postgres/common/data_sources/postgres_data_source.py @@ -30,3 +30,7 @@ def __init__(self): def _build_regex_like_sql(self, matches: REGEX_LIKE) -> str: expression: str = self.build_expression_sql(matches.expression) return f"{expression} ~ '{matches.regex_pattern}'" + + def create_schema_if_not_exists_sql(self, schema_name: str) -> str: + quoted_schema_name: str = self.quote_default(schema_name) + return f"CREATE SCHEMA IF NOT EXISTS {quoted_schema_name} AUTHORIZATION CURRENT_USER;" diff --git a/soda-tests/src/helpers/data_source_test_helper.py b/soda-tests/src/helpers/data_source_test_helper.py index bef5b2720..350f9769c 100644 --- a/soda-tests/src/helpers/data_source_test_helper.py +++ b/soda-tests/src/helpers/data_source_test_helper.py @@ -281,7 +281,8 @@ def create_test_schema_if_not_exists(self) -> None: self.data_source_impl.execute_update(sql) def create_test_schema_if_not_exists_sql(self) -> str: - return f"CREATE SCHEMA IF NOT EXISTS {self.dataset_prefix[1]} AUTHORIZATION CURRENT_USER;" + schema_name: str = self.dataset_prefix[1] + return 
self.data_source_impl.sql_dialect.create_schema_if_not_exists_sql(schema_name) def drop_test_schema_if_exists(self) -> None: sql: str = self.drop_test_schema_if_exists_sql() diff --git a/soda-tests/src/helpers/databricks_data_source_test_helper.py b/soda-tests/src/helpers/databricks_data_source_test_helper.py index d39736cda..991d97f53 100644 --- a/soda-tests/src/helpers/databricks_data_source_test_helper.py +++ b/soda-tests/src/helpers/databricks_data_source_test_helper.py @@ -34,9 +34,6 @@ def _create_data_source_yaml_str(self) -> str: catalog: {os.getenv("DATABRICKS_CATALOG", "unity_catalog")} """ - def create_test_schema_if_not_exists_sql(self) -> str: - return f"CREATE SCHEMA IF NOT EXISTS {self.dataset_prefix[1]};" - def _get_contract_data_type_dict(self) -> dict[str, str]: return { TestDataType.TEXT: "varchar", diff --git a/soda-tests/src/helpers/snowflake_data_source_test_helper.py b/soda-tests/src/helpers/snowflake_data_source_test_helper.py index 636a2991e..e1366e980 100644 --- a/soda-tests/src/helpers/snowflake_data_source_test_helper.py +++ b/soda-tests/src/helpers/snowflake_data_source_test_helper.py @@ -34,11 +34,6 @@ def _create_data_source_yaml_str(self) -> str: database: {self.dataset_prefix[0]} """ - def create_test_schema_if_not_exists_sql(self) -> str: - sql_dialect: "SqlDialect" = self.data_source_impl.sql_dialect - schema_name: str = self.dataset_prefix[1] - return f"CREATE SCHEMA IF NOT EXISTS {sql_dialect.quote_default(schema_name)};" - def _adjust_schema_name(self, schema_name: str) -> str: return schema_name.upper() diff --git a/soda-tests/tests/components/manual_test_agent_flow.py b/soda-tests/tests/components/manual_test_agent_flow.py index 7fb93131d..e9276d3d5 100644 --- a/soda-tests/tests/components/manual_test_agent_flow.py +++ b/soda-tests/tests/components/manual_test_agent_flow.py @@ -2,6 +2,7 @@ from dotenv import load_dotenv from soda_core.common.logging_configuration import configure_logging +from soda_core.common.soda_cloud 
import SodaCloud from soda_core.common.yaml import ContractYamlSource, SodaCloudYamlSource from soda_core.contracts.contract_verification import ContractVerificationSession @@ -31,13 +32,16 @@ def main(): soda_cloud_yaml_str = dedent( """ soda_cloud: - bla: bla + api_key_id: ${SODA_CLOUD_API_KEY_ID} + api_key_secret: ${SODA_CLOUD_API_KEY_SECRET} """ ).strip() + soda_cloud: SodaCloud = SodaCloud.from_yaml_source(SodaCloudYamlSource.from_str(soda_cloud_yaml_str), provided_variable_values={}) + ContractVerificationSession.execute( contract_yaml_sources=[ContractYamlSource.from_str(contract_yaml_str)], - soda_cloud_yaml_source=SodaCloudYamlSource.from_str(soda_cloud_yaml_str), + soda_cloud_impl=soda_cloud, soda_cloud_use_agent=True, soda_cloud_use_agent_blocking_timeout_in_minutes=55, ) diff --git a/soda-tests/tests/features/test_duplicate_column_check.py b/soda-tests/tests/features/test_duplicate_column_check.py index f38a25c51..fd80137b8 100644 --- a/soda-tests/tests/features/test_duplicate_column_check.py +++ b/soda-tests/tests/features/test_duplicate_column_check.py @@ -116,7 +116,7 @@ def test_duplicate_metric_typo_error(data_source_test_helper: DataSourceTestHelp metric: percentttt """, ) - assert "'metric' must be in ['count', 'percent']" == contract_verification_result.get_errors_str() + assert "'metric' must be in ['count', 'percent']" in contract_verification_result.get_errors_str() def test_duplicate_with_check_filter(data_source_test_helper: DataSourceTestHelper):