From 6edda4ee5c13dba30f4e0698b5a32407723e686a Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Thu, 27 Nov 2025 13:20:36 +0000 Subject: [PATCH 01/24] feat: extended is_aggr_* --- .../docs/guide/quality_checks_definition.mdx | 131 ++++++++++++ docs/dqx/docs/reference/quality_checks.mdx | 124 ++++++++++++ src/databricks/labs/dqx/check_funcs.py | 187 +++++++++++++++--- tests/integration/test_dataset_checks.py | 153 ++++++++++++++ 4 files changed, 572 insertions(+), 23 deletions(-) diff --git a/docs/dqx/docs/guide/quality_checks_definition.mdx b/docs/dqx/docs/guide/quality_checks_definition.mdx index 621824e4f..148699dc8 100644 --- a/docs/dqx/docs/guide/quality_checks_definition.mdx +++ b/docs/dqx/docs/guide/quality_checks_definition.mdx @@ -135,6 +135,41 @@ This approach provides static type checking and autocompletion in IDEs, making i column="col1", check_func_kwargs={"aggr_type": "avg", "group_by": ["col2"], "limit": 10.5}, ), + + # Extended aggregate functions for advanced data quality checks + + DQDatasetRule( # Uniqueness check: each country should have one country code + criticality="error", + check_func=check_funcs.is_aggr_not_greater_than, + column="country_code", + check_func_kwargs={ + "aggr_type": "count_distinct", + "group_by": ["country"], + "limit": 1 + }, + ), + + DQDatasetRule( # Anomaly detection: detect unusual variance in sensor readings + criticality="warn", + check_func=check_funcs.is_aggr_not_greater_than, + column="temperature", + check_func_kwargs={ + "aggr_type": "stddev", + "group_by": ["machine_id"], + "limit": 5.0 + }, + ), + + DQDatasetRule( # SLA monitoring: P95 latency must be under 1 second + criticality="error", + check_func=check_funcs.is_aggr_not_greater_than, + column="latency_ms", + check_func_kwargs={ + "aggr_type": "percentile", + "aggr_params": {"percentile": 0.95}, + "limit": 1000 + }, + ), ] ``` @@ -144,6 +179,102 @@ This approach provides static type checking and autocompletion in IDEs, making i The validation of arguments and keyword arguments for the check function is automatically performed upon creating a `DQRowRule`. 
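+
+For illustration, a minimal sketch of applying the rules defined above with the `DQEngine` (the names `input_df` and `checks` are placeholders for your DataFrame and the list of rules defined earlier):
+
+```python
+from databricks.sdk import WorkspaceClient
+from databricks.labs.dqx.engine import DQEngine
+
+dq_engine = DQEngine(WorkspaceClient())
+
+# Rows that violate a rule are flagged according to the rule's criticality.
+validated_df = dq_engine.apply_checks(input_df, checks)
+```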
+## Practical Use Cases for Extended Aggregates + +### Uniqueness Validation with count_distinct + +Ensure referential integrity by verifying that each entity has exactly one identifier: + +```yaml +# Each country should have exactly one country code +- criticality: error + check: + function: is_aggr_not_greater_than + arguments: + column: country_code + aggr_type: count_distinct + group_by: + - country + limit: 1 +``` + +### Anomaly Detection with Statistical Functions + +Detect unusual patterns in sensor data or business metrics: + +```yaml +# Alert on unusually high temperature variance per machine +- criticality: warn + check: + function: is_aggr_not_greater_than + arguments: + column: temperature + aggr_type: stddev + group_by: + - machine_id + limit: 5.0 + +# Monitor revenue stability across product lines +- criticality: error + check: + function: is_aggr_not_greater_than + arguments: + column: daily_revenue + aggr_type: variance + group_by: + - product_line + limit: 1000000.0 +``` + +### SLA and Performance Monitoring with Percentiles + +Monitor service performance and ensure SLA compliance: + +```yaml +# P95 API latency must be under 1 second +- criticality: error + check: + function: is_aggr_not_greater_than + arguments: + column: latency_ms + aggr_type: percentile + aggr_params: + percentile: 0.95 + limit: 1000 + +# P99 response time check (fast approximate) +- criticality: warn + check: + function: is_aggr_not_greater_than + arguments: + column: response_time_ms + aggr_type: approx_percentile + aggr_params: + percentile: 0.99 + accuracy: 10000 + limit: 5000 + +# Median baseline for order processing time +- criticality: warn + check: + function: is_aggr_not_less_than + arguments: + column: processing_time_sec + aggr_type: median + group_by: + - order_type + limit: 30.0 +``` + + +- **count_distinct**: Uniqueness validation, cardinality checks +- **stddev/variance**: Anomaly detection, consistency monitoring +- **percentile**: SLA compliance, outlier detection (P95, P99) +- **median**: Baseline checks, central tendency +- **approx_percentile**: Fast percentile estimation for large datasets +- **approx_count_distinct**: Efficient cardinality for high-cardinality columns + + ### Checks defined using metadata (list of dictionaries) Checks can be defined using declarative syntax as a list of dictionaries. diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx index 4587eab77..dc9101d54 100644 --- a/docs/dqx/docs/reference/quality_checks.mdx +++ b/docs/dqx/docs/reference/quality_checks.mdx @@ -1601,6 +1601,130 @@ Complex data types are supported as well. - col3 limit: 200 +## Aggregate Function Types + +The `is_aggr_*` functions support a wide range of aggregate functions for data quality checks. Functions are categorized as **curated** (recommended for DQ) or **custom** (user-defined, validated at runtime). 
+ +### Curated Aggregate Functions + +Curated functions are validated and optimized for data quality use cases: + +#### Basic Aggregations +- `count` - Count records (nulls excluded) +- `sum` - Sum of values +- `avg` - Average of values +- `min` - Minimum value +- `max` - Maximum value + +#### Cardinality & Uniqueness +- `count_distinct` - Count unique values (e.g., "each country has one code") +- `approx_count_distinct` - Approximate distinct count (faster for large datasets) +- `count_if` - Conditional counting + +#### Statistical Analysis +- `stddev` / `stddev_samp` - Sample standard deviation (anomaly detection) +- `stddev_pop` - Population standard deviation +- `variance` / `var_samp` - Sample variance (stability checks) +- `var_pop` - Population variance +- `median` - 50th percentile baseline +- `skewness` - Distribution skewness (detect asymmetry) +- `kurtosis` - Distribution kurtosis (detect heavy tails) + +#### SLA & Performance Monitoring +- `percentile` - Exact percentile (requires `aggr_params`) +- `approx_percentile` - Approximate percentile (faster, requires `aggr_params`) + +#### Correlation Analysis +- `corr` - Correlation coefficient +- `covar_pop` - Population covariance +- `covar_samp` - Sample covariance + +### Custom Aggregates + +Custom aggregate functions (including UDAFs) are supported with runtime validation. A warning will be issued, and the function must return a single numeric value per group. + +**Not suitable:** Functions returning arrays, structs, or maps (e.g., `collect_list`, `collect_set`, `string_agg`) will fail validation. + +### Extended Examples + +```yaml +# count_distinct: Ensure each country has exactly one country code +- criticality: error + check: + function: is_aggr_not_greater_than + arguments: + column: country_code + aggr_type: count_distinct + group_by: + - country + limit: 1 + +# Standard deviation: Detect unusual variance in sensor readings per machine +- criticality: warn + check: + function: is_aggr_not_greater_than + arguments: + column: temperature + aggr_type: stddev + group_by: + - machine_id + limit: 5.0 + +# Percentile: Monitor P95 latency for SLA compliance +- criticality: error + check: + function: is_aggr_not_greater_than + arguments: + column: latency_ms + aggr_type: percentile + aggr_params: + percentile: 0.95 + limit: 1000 + +# Approximate percentile: Fast P99 check with accuracy control +- criticality: warn + check: + function: is_aggr_not_greater_than + arguments: + column: response_time + aggr_type: approx_percentile + aggr_params: + percentile: 0.99 + accuracy: 10000 + limit: 5000 + +# Median: Baseline check for central tendency +- criticality: warn + check: + function: is_aggr_not_less_than + arguments: + column: order_value + aggr_type: median + group_by: + - product_category + limit: 50.0 + +# Variance: Stability check for financial data +- criticality: error + check: + function: is_aggr_not_greater_than + arguments: + column: daily_revenue + aggr_type: variance + limit: 1000000.0 + +# Approximate count distinct: Efficient cardinality estimation +- criticality: warn + check: + function: is_aggr_not_greater_than + arguments: + column: user_id + aggr_type: approx_count_distinct + group_by: + - session_date + limit: 100000 +``` + # foreign_key check using reference DataFrame - criticality: error check: diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index 5ab24df1f..ab20c39d6 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ 
-27,6 +27,37 @@ IPV4_MAX_OCTET_COUNT = 4 IPV4_BIT_LENGTH = 32 +# Curated aggregate functions for data quality checks +CURATED_AGGR_FUNCTIONS = { + # Basic aggregations (5) + "count", + "sum", + "avg", + "min", + "max", + # Cardinality & Uniqueness (3) + "count_distinct", + "approx_count_distinct", + "count_if", + # Statistical Analysis (9) + "stddev", + "stddev_pop", + "stddev_samp", + "variance", + "var_pop", + "var_samp", + "median", + "skewness", + "kurtosis", + # SLA & Performance Monitoring (2) + "percentile", + "approx_percentile", + # Correlation Analysis (3) + "corr", + "covar_pop", + "covar_samp", +} + class DQPattern(Enum): """Enum class to represent DQ patterns used to match data in columns.""" @@ -1263,20 +1294,23 @@ def is_aggr_not_greater_than( column: str | Column, limit: int | float | str | Column, aggr_type: str = "count", + aggr_params: dict[str, any] | None = None, group_by: list[str | Column] | None = None, row_filter: str | None = None, ) -> tuple[Column, Callable]: """ Build an aggregation check condition and closure for dataset-level validation. - This function verifies that an aggregation (count, sum, avg, min, max) on a column - or group of columns does not exceed a specified limit. Rows where the aggregation - result exceeds the limit are flagged. + This function verifies that an aggregation on a column or group of columns does not exceed + a specified limit. Supports curated aggregate functions (count, sum, avg, stddev, percentile, etc.) + and custom aggregates. Rows where the aggregation result exceeds the limit are flagged. Args: column: Column name (str) or Column expression to aggregate. limit: Numeric value, column name, or SQL expression for the limit. - aggr_type: Aggregation type: 'count', 'sum', 'avg', 'min', or 'max' (default: 'count'). + aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, + count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., {"percentile": 0.95}). group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. @@ -1289,6 +1323,7 @@ def is_aggr_not_greater_than( column, limit, aggr_type, + aggr_params, group_by, row_filter, compare_op=py_operator.gt, @@ -1302,20 +1337,23 @@ def is_aggr_not_less_than( column: str | Column, limit: int | float | str | Column, aggr_type: str = "count", + aggr_params: dict[str, any] | None = None, group_by: list[str | Column] | None = None, row_filter: str | None = None, ) -> tuple[Column, Callable]: """ Build an aggregation check condition and closure for dataset-level validation. - This function verifies that an aggregation (count, sum, avg, min, max) on a column - or group of columns is not below a specified limit. Rows where the aggregation - result is below the limit are flagged. + This function verifies that an aggregation on a column or group of columns is not below + a specified limit. Supports curated aggregate functions (count, sum, avg, stddev, percentile, etc.) + and custom aggregates. Rows where the aggregation result is below the limit are flagged. Args: column: Column name (str) or Column expression to aggregate. limit: Numeric value, column name, or SQL expression for the limit. - aggr_type: Aggregation type: 'count', 'sum', 'avg', 'min', or 'max' (default: 'count'). 
+ aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, + count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., {"percentile": 0.95}). group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. @@ -1328,6 +1366,7 @@ def is_aggr_not_less_than( column, limit, aggr_type, + aggr_params, group_by, row_filter, compare_op=py_operator.lt, @@ -1341,20 +1380,23 @@ def is_aggr_equal( column: str | Column, limit: int | float | str | Column, aggr_type: str = "count", + aggr_params: dict[str, any] | None = None, group_by: list[str | Column] | None = None, row_filter: str | None = None, ) -> tuple[Column, Callable]: """ Build an aggregation check condition and closure for dataset-level validation. - This function verifies that an aggregation (count, sum, avg, min, max) on a column - or group of columns is equal to a specified limit. Rows where the aggregation - result is not equal to the limit are flagged. + This function verifies that an aggregation on a column or group of columns is equal to + a specified limit. Supports curated aggregate functions (count, sum, avg, stddev, percentile, etc.) + and custom aggregates. Rows where the aggregation result is not equal to the limit are flagged. Args: column: Column name (str) or Column expression to aggregate. limit: Numeric value, column name, or SQL expression for the limit. - aggr_type: Aggregation type: 'count', 'sum', 'avg', 'min', or 'max' (default: 'count'). + aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, + count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., {"percentile": 0.95}). group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. @@ -1367,6 +1409,7 @@ def is_aggr_equal( column, limit, aggr_type, + aggr_params, group_by, row_filter, compare_op=py_operator.ne, @@ -1380,20 +1423,23 @@ def is_aggr_not_equal( column: str | Column, limit: int | float | str | Column, aggr_type: str = "count", + aggr_params: dict[str, any] | None = None, group_by: list[str | Column] | None = None, row_filter: str | None = None, ) -> tuple[Column, Callable]: """ Build an aggregation check condition and closure for dataset-level validation. - This function verifies that an aggregation (count, sum, avg, min, max) on a column - or group of columns is not equal to a specified limit. Rows where the aggregation - result is equal to the limit are flagged. + This function verifies that an aggregation on a column or group of columns is not equal to + a specified limit. Supports curated aggregate functions (count, sum, avg, stddev, percentile, etc.) + and custom aggregates. Rows where the aggregation result is equal to the limit are flagged. Args: column: Column name (str) or Column expression to aggregate. limit: Numeric value, column name, or SQL expression for the limit. - aggr_type: Aggregation type: 'count', 'sum', 'avg', 'min', or 'max' (default: 'count'). + aggr_type: Aggregation type (default: 'count'). 
Curated types include count, sum, avg, min, max, + count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., {"percentile": 0.95}). group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. @@ -1406,6 +1452,7 @@ def is_aggr_not_equal( column, limit, aggr_type, + aggr_params, group_by, row_filter, compare_op=py_operator.eq, @@ -2221,10 +2268,52 @@ def _add_compare_condition( ) +def _validate_custom_aggregate( + agg_df: DataFrame, + aggr_type: str, + metric_col: str, + expected_row_count: int, +) -> None: + """ + Validate custom aggregate returns proper single numeric value per group. + + This function validates that a custom (non-curated) aggregate function behaves correctly + by ensuring it returns the expected number of rows and a numeric data type that can be + compared to limits. + + Args: + agg_df: DataFrame with aggregated results. + aggr_type: Name of the aggregate function being validated. + metric_col: Column name containing the aggregate result. + expected_row_count: Expected number of rows (1 for no group_by, count of groups for group_by). + + Raises: + InvalidParameterError: If the aggregate returns more rows than expected (not a proper aggregate) + or if it returns a non-numeric type (Array, Map, Struct) that cannot be compared to limits. + """ + # Check 1: Row count (must return exactly expected rows) + actual_count = agg_df.count() + if actual_count > expected_row_count: + raise InvalidParameterError( + f"Aggregate function '{aggr_type}' returned {actual_count} rows, " + f"expected {expected_row_count}. This is not a proper aggregate function." + ) + + # Check 2: Return type (must be numeric, not array/struct/map) + result_type = agg_df.schema[metric_col].dataType + if isinstance(result_type, (types.ArrayType, types.MapType, types.StructType)): + raise InvalidParameterError( + f"Aggregate function '{aggr_type}' returned {result_type.typeName()} " + f"which cannot be compared to numeric limits. " + f"Use aggregate functions that return numeric values (e.g., count, sum, avg)." + ) + + def _is_aggr_compare( column: str | Column, limit: int | float | str | Column, aggr_type: str, + aggr_params: dict[str, any] | None, group_by: list[str | Column] | None, row_filter: str | None, compare_op: Callable[[Column, Column], Column], @@ -2240,7 +2329,11 @@ def _is_aggr_compare( Args: column: Column name (str) or Column expression to aggregate. limit: Numeric value, column name, or SQL expression for the limit. - aggr_type: Aggregation type: 'count', 'sum', 'avg', 'min', or 'max'. + aggr_type: Aggregation type. Curated functions include 'count', 'sum', 'avg', 'min', 'max', + 'count_distinct', 'stddev', 'percentile', and more. Custom aggregate functions are + supported but will trigger a warning and runtime validation. + aggr_params: Optional dictionary of parameters for aggregate functions that require them + (e.g., percentile functions need {"percentile": 0.95}). group_by: Optional list of columns or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. compare_op: Comparison operator (e.g., operator.gt, operator.lt). @@ -2253,11 +2346,19 @@ def _is_aggr_compare( - A closure that applies the aggregation check logic. Raises: - InvalidParameterError: If an unsupported aggregation type is provided. 
- """ - supported_aggr_types = {"count", "sum", "avg", "min", "max"} - if aggr_type not in supported_aggr_types: - raise InvalidParameterError(f"Unsupported aggregation type: {aggr_type}. Supported: {supported_aggr_types}") + InvalidParameterError: If a custom aggregate returns non-numeric or multiple rows per group. + MissingParameterError: If required parameters for specific aggregates are not provided. + """ + # Warn if using non-curated aggregate function + is_curated = aggr_type in CURATED_AGGR_FUNCTIONS + if not is_curated: + warnings.warn( + f"Using non-curated aggregate function '{aggr_type}'. " + f"Curated functions: {', '.join(sorted(CURATED_AGGR_FUNCTIONS))}. " + f"Custom aggregates must return a single numeric value per group.", + UserWarning, + stacklevel=2, + ) aggr_col_str_norm, aggr_col_str, aggr_col_expr = _get_normalized_column_and_expr(column) @@ -2300,7 +2401,41 @@ def apply(df: DataFrame) -> DataFrame: """ filter_col = F.expr(row_filter) if row_filter else F.lit(True) filtered_expr = F.when(filter_col, aggr_col_expr) if row_filter else aggr_col_expr - aggr_expr = getattr(F, aggr_type)(filtered_expr) + + # Build aggregation expression based on function type + if aggr_type == "count_distinct": + # Spark uses countDistinct, not count_distinct + aggr_expr = F.countDistinct(filtered_expr) + elif aggr_type in ("percentile", "approx_percentile"): + # Percentile functions require percentile parameter + if not aggr_params or "percentile" not in aggr_params: + raise MissingParameterError( + f"'{aggr_type}' requires aggr_params with 'percentile' key (e.g., {{'percentile': 0.95}})" + ) + pct = aggr_params["percentile"] + + if aggr_type == "percentile": + aggr_expr = F.percentile(filtered_expr, pct) + else: # approx_percentile + # Check if accuracy parameter is provided + if "accuracy" in aggr_params: + aggr_expr = F.approx_percentile(filtered_expr, pct, aggr_params["accuracy"]) + else: + aggr_expr = F.approx_percentile(filtered_expr, pct) + else: + # All other aggregate functions (curated and custom) + try: + aggr_func = getattr(F, aggr_type) + # Apply aggr_params if provided and function supports them + if aggr_params: + aggr_expr = aggr_func(filtered_expr, **aggr_params) + else: + aggr_expr = aggr_func(filtered_expr) + except AttributeError: + raise InvalidParameterError( + f"Aggregate function '{aggr_type}' not found in pyspark.sql.functions. " + f"Verify the function name is correct." + ) if group_by: window_spec = Window.partitionBy(*[F.col(col) if isinstance(col, str) else col for col in group_by]) @@ -2312,6 +2447,12 @@ def apply(df: DataFrame) -> DataFrame: # Note: The aggregation naturally returns a single row without a groupBy clause, # so no explicit limit is required (informational only). 
agg_df = df.select(aggr_expr.alias(metric_col)).limit(1) + + # Validate custom aggregates + if not is_curated: + expected_rows = 1 # No group_by means single row expected + _validate_custom_aggregate(agg_df, aggr_type, metric_col, expected_rows) + df = df.crossJoin(agg_df) # bring the metric across all rows df = df.withColumn(condition_col, compare_op(F.col(metric_col), limit_expr)) diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index 99eafa264..b9e600b47 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -523,6 +523,159 @@ def test_is_aggr_not_equal(spark: SparkSession): assert_df_equality(actual, expected, ignore_nullable=True) +def test_is_aggr_with_count_distinct(spark: SparkSession): + """Test count_distinct aggregation - GitHub issue #929 requirement: 'each country has one code'.""" + test_df = spark.createDataFrame( + [ + ["US", "USA"], + ["US", "USA"], # Same code, OK + ["FR", "FRA"], + ["FR", "FRN"], # Different code, should FAIL + ["DE", "DEU"], + ], + "country: string, code: string", + ) + + checks = [ + # Each country should have exactly one distinct code + is_aggr_not_greater_than("code", limit=1, aggr_type="count_distinct", group_by=["country"]), + # Global distinct count + is_aggr_not_greater_than("country", limit=5, aggr_type="count_distinct"), + ] + + actual = _apply_checks(test_df, checks) + + expected_schema = "country: string, code: string, code_count_distinct_group_by_country_greater_than_limit: string, country_count_distinct_greater_than_limit: string" + + expected = spark.createDataFrame( + [ + ["US", "USA", None, None], + ["US", "USA", None, None], + ["FR", "FRA", "Count_distinct 2 in column 'code' per group of columns 'country' is greater than limit: 1", None], + ["FR", "FRN", "Count_distinct 2 in column 'code' per group of columns 'country' is greater than limit: 1", None], + ["DE", "DEU", None, None], + ], + expected_schema, + ) + + assert_df_equality(actual, expected, ignore_nullable=True) + + +def test_is_aggr_with_statistical_functions(spark: SparkSession): + """Test statistical aggregate functions: stddev, variance, median, skewness, kurtosis.""" + test_df = spark.createDataFrame( + [ + ["A", 10.0], + ["A", 20.0], + ["A", 30.0], + ["B", 5.0], + ["B", 5.0], + ["B", 5.0], + ], + "group: string, value: double", + ) + + checks = [ + # Standard deviation check (group A has higher variance than B) + is_aggr_not_greater_than("value", limit=10.0, aggr_type="stddev", group_by=["group"]), + # Variance check + is_aggr_not_greater_than("value", limit=100.0, aggr_type="variance", group_by=["group"]), + # Median check + is_aggr_not_greater_than("value", limit=25.0, aggr_type="median"), + ] + + actual = _apply_checks(test_df, checks) + + # Check that checks were applied (exact values depend on Spark's statistical functions) + assert "value_stddev_group_by_group_greater_than_limit" in actual.columns + assert "value_variance_group_by_group_greater_than_limit" in actual.columns + assert "value_median_greater_than_limit" in actual.columns + + +def test_is_aggr_with_percentile_functions(spark: SparkSession): + """Test percentile and approx_percentile with aggr_params.""" + test_df = spark.createDataFrame( + [(i, float(i)) for i in range(1, 101)], + "id: int, value: double", + ) + + checks = [ + # P95 should be around 95 + is_aggr_not_greater_than("value", limit=100.0, aggr_type="percentile", aggr_params={"percentile": 0.95}), + # P99 with approx_percentile + 
is_aggr_not_greater_than("value", limit=100.0, aggr_type="approx_percentile", aggr_params={"percentile": 0.99}), + # P50 (median) with accuracy parameter + is_aggr_not_less_than("value", limit=40.0, aggr_type="approx_percentile", aggr_params={"percentile": 0.50, "accuracy": 100}), + ] + + actual = _apply_checks(test_df, checks) + + # Verify columns exist + assert "value_percentile_greater_than_limit" in actual.columns + assert "value_approx_percentile_greater_than_limit" in actual.columns + assert "value_approx_percentile_less_than_limit" in actual.columns + + +def test_is_aggr_percentile_missing_params(spark: SparkSession): + """Test that percentile functions require percentile parameter.""" + test_df = spark.createDataFrame([(1, 10.0)], "id: int, value: double") + + from databricks.labs.dqx.errors import MissingParameterError + + # Should raise error when percentile param is missing + with pytest.raises(MissingParameterError, match="percentile.*requires aggr_params"): + condition, apply_fn = is_aggr_not_greater_than("value", limit=100.0, aggr_type="percentile") + apply_fn(test_df) + + +def test_is_aggr_with_invalid_aggregate_function(spark: SparkSession): + """Test that invalid aggregate function names raise clear errors.""" + test_df = spark.createDataFrame([(1, 10)], "id: int, value: int") + + from databricks.labs.dqx.errors import InvalidParameterError + + # Non-existent function should raise error + with pytest.raises(InvalidParameterError, match="not found in pyspark.sql.functions"): + condition, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="nonexistent_function") + apply_fn(test_df) + + +def test_is_aggr_with_collect_list_fails(spark: SparkSession): + """Test that collect_list (returns array) fails with clear error message.""" + test_df = spark.createDataFrame([(1, 10), (2, 20)], "id: int, value: int") + + from databricks.labs.dqx.errors import InvalidParameterError + + # collect_list returns array which cannot be compared to numeric limit + with pytest.raises(InvalidParameterError, match="ArrayType.*cannot be compared"): + condition, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="collect_list") + apply_fn(test_df) + + +def test_is_aggr_custom_aggregate_with_warning(spark: SparkSession): + """Test that custom (non-curated) aggregates work but produce warning.""" + import warnings + + test_df = spark.createDataFrame( + [("A", 10), ("B", 20), ("C", 30)], + "category: string, value: int", + ) + + # Use a valid aggregate that's not in curated list (e.g., if we had a UDAF) + # For now, test with any_value which is a valid Spark aggregate + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + condition, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="any_value") + result = apply_fn(test_df) + + # Should have warning about non-curated aggregate + assert len(w) > 0 + assert "non-curated" in str(w[0].message).lower() + + # Should still work + assert "value_any_value_greater_than_limit" in result.columns + + def test_dataset_compare(spark: SparkSession, set_utc_timezone): schema = "id1 long, id2 long, name string, dt date, ts timestamp, score float, likes bigint, active boolean" From e8602ce0f72191f90a08ec6b9bd72da02b890471 Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Thu, 27 Nov 2025 13:32:06 +0000 Subject: [PATCH 02/24] fmt: code formatting and cleanup --- src/databricks/labs/dqx/check_funcs.py | 106 +++++++++++++---------- tests/integration/test_dataset_checks.py | 37 ++++---- 2 files 
changed, 80 insertions(+), 63 deletions(-) diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index ab20c39d6..32cf0f503 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -7,6 +7,7 @@ from enum import Enum from itertools import zip_longest import operator as py_operator +from typing import Any import pandas as pd # type: ignore[import-untyped] import pyspark.sql.functions as F from pyspark.sql import types @@ -29,17 +30,14 @@ # Curated aggregate functions for data quality checks CURATED_AGGR_FUNCTIONS = { - # Basic aggregations (5) "count", "sum", "avg", "min", "max", - # Cardinality & Uniqueness (3) "count_distinct", "approx_count_distinct", "count_if", - # Statistical Analysis (9) "stddev", "stddev_pop", "stddev_samp", @@ -49,10 +47,8 @@ "median", "skewness", "kurtosis", - # SLA & Performance Monitoring (2) "percentile", "approx_percentile", - # Correlation Analysis (3) "corr", "covar_pop", "covar_samp", @@ -1294,7 +1290,7 @@ def is_aggr_not_greater_than( column: str | Column, limit: int | float | str | Column, aggr_type: str = "count", - aggr_params: dict[str, any] | None = None, + aggr_params: dict[str, Any] | None = None, group_by: list[str | Column] | None = None, row_filter: str | None = None, ) -> tuple[Column, Callable]: @@ -1337,7 +1333,7 @@ def is_aggr_not_less_than( column: str | Column, limit: int | float | str | Column, aggr_type: str = "count", - aggr_params: dict[str, any] | None = None, + aggr_params: dict[str, Any] | None = None, group_by: list[str | Column] | None = None, row_filter: str | None = None, ) -> tuple[Column, Callable]: @@ -1380,7 +1376,7 @@ def is_aggr_equal( column: str | Column, limit: int | float | str | Column, aggr_type: str = "count", - aggr_params: dict[str, any] | None = None, + aggr_params: dict[str, Any] | None = None, group_by: list[str | Column] | None = None, row_filter: str | None = None, ) -> tuple[Column, Callable]: @@ -1423,7 +1419,7 @@ def is_aggr_not_equal( column: str | Column, limit: int | float | str | Column, aggr_type: str = "count", - aggr_params: dict[str, any] | None = None, + aggr_params: dict[str, Any] | None = None, group_by: list[str | Column] | None = None, row_filter: str | None = None, ) -> tuple[Column, Callable]: @@ -2268,6 +2264,54 @@ def _add_compare_condition( ) +def _build_aggregate_expression( + aggr_type: str, + filtered_expr: Column, + aggr_params: dict[str, Any] | None, +) -> Column: + """ + Build the appropriate Spark aggregate expression based on function type and parameters. + + Args: + aggr_type: Name of the aggregate function. + filtered_expr: Column expression with filters applied. + aggr_params: Optional parameters for the aggregate function. + + Returns: + Spark Column expression for the aggregate. + + Raises: + MissingParameterError: If required parameters are missing for specific aggregates. + InvalidParameterError: If the aggregate function is not found. 
+ """ + if aggr_type == "count_distinct": + return F.countDistinct(filtered_expr) + + if aggr_type in {"percentile", "approx_percentile"}: + if not aggr_params or "percentile" not in aggr_params: + raise MissingParameterError( + f"'{aggr_type}' requires aggr_params with 'percentile' key (e.g., {{'percentile': 0.95}})" + ) + pct = aggr_params["percentile"] + + if aggr_type == "percentile": + return F.percentile(filtered_expr, pct) + if "accuracy" in aggr_params: + return F.approx_percentile(filtered_expr, pct, aggr_params["accuracy"]) + return F.approx_percentile(filtered_expr, pct) + + try: + aggr_func = getattr(F, aggr_type) + if aggr_params: + return aggr_func(filtered_expr, **aggr_params) + return aggr_func(filtered_expr) + except AttributeError as exc: + raise InvalidParameterError( + f"Aggregate function '{aggr_type}' not found in pyspark.sql.functions. " + f"Verify the function name is correct." + ) from exc + + def _validate_custom_aggregate( agg_df: DataFrame, aggr_type: str, @@ -2313,7 +2357,7 @@ def _is_aggr_compare( column: str | Column, limit: int | float | str | Column, aggr_type: str, - aggr_params: dict[str, any] | None, + aggr_params: dict[str, Any] | None, group_by: list[str | Column] | None, row_filter: str | None, compare_op: Callable[[Column, Column], Column], @@ -2401,41 +2445,9 @@ def apply(df: DataFrame) -> DataFrame: """ filter_col = F.expr(row_filter) if row_filter else F.lit(True) filtered_expr = F.when(filter_col, aggr_col_expr) if row_filter else aggr_col_expr - - # Build aggregation expression based on function type - if aggr_type == "count_distinct": - # Spark uses countDistinct, not count_distinct - aggr_expr = F.countDistinct(filtered_expr) - elif aggr_type in ("percentile", "approx_percentile"): - # Percentile functions require percentile parameter - if not aggr_params or "percentile" not in aggr_params: - raise MissingParameterError( - f"'{aggr_type}' requires aggr_params with 'percentile' key (e.g., {{'percentile': 0.95}})" - ) - pct = aggr_params["percentile"] - - if aggr_type == "percentile": - aggr_expr = F.percentile(filtered_expr, pct) - else: # approx_percentile - # Check if accuracy parameter is provided - if "accuracy" in aggr_params: - aggr_expr = F.approx_percentile(filtered_expr, pct, aggr_params["accuracy"]) - else: - aggr_expr = F.approx_percentile(filtered_expr, pct) - else: - # All other aggregate functions (curated and custom) - try: - aggr_func = getattr(F, aggr_type) - # Apply aggr_params if provided and function supports them - if aggr_params: - aggr_expr = aggr_func(filtered_expr, **aggr_params) - else: - aggr_expr = aggr_func(filtered_expr) - except AttributeError: - raise InvalidParameterError( - f"Aggregate function '{aggr_type}' not found in pyspark.sql.functions. " - f"Verify the function name is correct." - ) + + # Build aggregation expression + aggr_expr = _build_aggregate_expression(aggr_type, filtered_expr, aggr_params) if group_by: window_spec = Window.partitionBy(*[F.col(col) if isinstance(col, str) else col for col in group_by]) @@ -2447,12 +2459,12 @@ def apply(df: DataFrame) -> DataFrame: # Note: The aggregation naturally returns a single row without a groupBy clause, # so no explicit limit is required (informational only). 
agg_df = df.select(aggr_expr.alias(metric_col)).limit(1) - + # Validate custom aggregates if not is_curated: expected_rows = 1 # No group_by means single row expected _validate_custom_aggregate(agg_df, aggr_type, metric_col, expected_rows) - + df = df.crossJoin(agg_df) # bring the metric across all rows df = df.withColumn(condition_col, compare_op(F.col(metric_col), limit_expr)) diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index b9e600b47..8c22b4304 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -3,6 +3,7 @@ from typing import Any import json import itertools +import warnings import pytest import pyspark.sql.functions as F @@ -22,7 +23,7 @@ has_valid_schema, ) from databricks.labs.dqx.utils import get_column_name_or_alias -from databricks.labs.dqx.errors import InvalidParameterError +from databricks.labs.dqx.errors import InvalidParameterError, MissingParameterError from tests.conftest import TEST_CATALOG @@ -551,8 +552,18 @@ def test_is_aggr_with_count_distinct(spark: SparkSession): [ ["US", "USA", None, None], ["US", "USA", None, None], - ["FR", "FRA", "Count_distinct 2 in column 'code' per group of columns 'country' is greater than limit: 1", None], - ["FR", "FRN", "Count_distinct 2 in column 'code' per group of columns 'country' is greater than limit: 1", None], + [ + "FR", + "FRA", + "Count_distinct 2 in column 'code' per group of columns 'country' is greater than limit: 1", + None, + ], + [ + "FR", + "FRN", + "Count_distinct 2 in column 'code' per group of columns 'country' is greater than limit: 1", + None, + ], ["DE", "DEU", None, None], ], expected_schema, @@ -605,7 +616,9 @@ def test_is_aggr_with_percentile_functions(spark: SparkSession): # P99 with approx_percentile is_aggr_not_greater_than("value", limit=100.0, aggr_type="approx_percentile", aggr_params={"percentile": 0.99}), # P50 (median) with accuracy parameter - is_aggr_not_less_than("value", limit=40.0, aggr_type="approx_percentile", aggr_params={"percentile": 0.50, "accuracy": 100}), + is_aggr_not_less_than( + "value", limit=40.0, aggr_type="approx_percentile", aggr_params={"percentile": 0.50, "accuracy": 100} + ), ] actual = _apply_checks(test_df, checks) @@ -620,11 +633,9 @@ def test_is_aggr_percentile_missing_params(spark: SparkSession): """Test that percentile functions require percentile parameter.""" test_df = spark.createDataFrame([(1, 10.0)], "id: int, value: double") - from databricks.labs.dqx.errors import MissingParameterError - # Should raise error when percentile param is missing with pytest.raises(MissingParameterError, match="percentile.*requires aggr_params"): - condition, apply_fn = is_aggr_not_greater_than("value", limit=100.0, aggr_type="percentile") + _, apply_fn = is_aggr_not_greater_than("value", limit=100.0, aggr_type="percentile") apply_fn(test_df) @@ -632,11 +643,9 @@ def test_is_aggr_with_invalid_aggregate_function(spark: SparkSession): """Test that invalid aggregate function names raise clear errors.""" test_df = spark.createDataFrame([(1, 10)], "id: int, value: int") - from databricks.labs.dqx.errors import InvalidParameterError - # Non-existent function should raise error with pytest.raises(InvalidParameterError, match="not found in pyspark.sql.functions"): - condition, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="nonexistent_function") + _, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="nonexistent_function") apply_fn(test_df) @@ -644,18 
+653,14 @@ def test_is_aggr_with_collect_list_fails(spark: SparkSession): """Test that collect_list (returns array) fails with clear error message.""" test_df = spark.createDataFrame([(1, 10), (2, 20)], "id: int, value: int") - from databricks.labs.dqx.errors import InvalidParameterError - # collect_list returns array which cannot be compared to numeric limit with pytest.raises(InvalidParameterError, match="ArrayType.*cannot be compared"): - condition, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="collect_list") + _, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="collect_list") apply_fn(test_df) def test_is_aggr_custom_aggregate_with_warning(spark: SparkSession): """Test that custom (non-curated) aggregates work but produce warning.""" - import warnings - test_df = spark.createDataFrame( [("A", 10), ("B", 20), ("C", 30)], "category: string, value: int", @@ -665,7 +670,7 @@ def test_is_aggr_custom_aggregate_with_warning(spark: SparkSession): # For now, test with any_value which is a valid Spark aggregate with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - condition, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="any_value") + _, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="any_value") result = apply_fn(test_df) # Should have warning about non-curated aggregate From ae6de97ad77b56b74c2f97a3e75152fec7e4e760 Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Thu, 27 Nov 2025 14:14:23 +0000 Subject: [PATCH 03/24] add: 1. is_aggr with group_by and 2. updated demo library and removed 3. bivariate analysis aggr functions --- demos/dqx_demo_library.py | 213 ++++++++++++++++++ .../docs/guide/quality_checks_definition.mdx | 1 + docs/dqx/docs/reference/quality_checks.mdx | 17 +- src/databricks/labs/dqx/check_funcs.py | 44 ++-- tests/integration/test_dataset_checks.py | 55 ++++- 5 files changed, 296 insertions(+), 34 deletions(-) diff --git a/demos/dqx_demo_library.py b/demos/dqx_demo_library.py index d414317cb..4765e765a 100644 --- a/demos/dqx_demo_library.py +++ b/demos/dqx_demo_library.py @@ -690,6 +690,219 @@ def not_ends_with(column: str, suffix: str) -> Column: # COMMAND ---------- +# MAGIC %md +# MAGIC ### Extended Aggregate Functions for Data Quality Checks +# MAGIC +# MAGIC DQX now supports 20 curated aggregate functions for advanced data quality monitoring: +# MAGIC - **Statistical functions**: `stddev`, `variance`, `median`, `mode`, `skewness`, `kurtosis` for anomaly detection +# MAGIC - **Percentile functions**: `percentile`, `approx_percentile` for SLA monitoring +# MAGIC - **Cardinality functions**: `count_distinct`, `approx_count_distinct` for uniqueness validation +# MAGIC - **Custom aggregates**: Support for user-defined functions with runtime validation + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### Example 1: Statistical Functions - Anomaly Detection with Standard Deviation +# MAGIC +# MAGIC Detect unusual variance in sensor readings per machine. High standard deviation indicates unstable sensors that may need calibration. 
+ +# COMMAND ---------- + +from databricks.labs.dqx.engine import DQEngine +from databricks.labs.dqx.rule import DQDatasetRule +from databricks.labs.dqx import check_funcs +from databricks.sdk import WorkspaceClient + +# Manufacturing sensor data with readings from multiple machines +manufacturing_df = spark.createDataFrame([ + ["M1", "2024-01-01", 20.1], + ["M1", "2024-01-02", 20.3], + ["M1", "2024-01-03", 20.2], # Machine 1: stable readings (low stddev) + ["M2", "2024-01-01", 18.5], + ["M2", "2024-01-02", 25.7], + ["M2", "2024-01-03", 15.2], # Machine 2: unstable readings (high stddev) - should FAIL + ["M3", "2024-01-01", 19.8], + ["M3", "2024-01-02", 20.1], + ["M3", "2024-01-03", 19.9], # Machine 3: stable readings +], "machine_id: string, date: string, temperature: double") + +# Quality check: Standard deviation should not exceed 3.0 per machine +checks = [ + DQDatasetRule( + criticality="error", + check_func=check_funcs.is_aggr_not_greater_than, + column="temperature", + check_func_kwargs={ + "aggr_type": "stddev", + "group_by": ["machine_id"], + "limit": 3.0 + }, + ), +] + +dq_engine = DQEngine(WorkspaceClient()) +result_df = dq_engine.apply_checks(manufacturing_df, checks) +display(result_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### Example 2: Approximate Aggregate Functions - Efficient Cardinality Estimation +# MAGIC +# MAGIC **`approx_count_distinct`** provides fast, memory-efficient cardinality estimation for large datasets. +# MAGIC +# MAGIC **From [Databricks Documentation](https://docs.databricks.com/aws/en/sql/language-manual/functions/approx_count_distinct.html):** +# MAGIC - Uses **HyperLogLog++** (HLL++) algorithm, a state-of-the-art cardinality estimator +# MAGIC - **Accurate within 5%** by default (configurable via `relativeSD` parameter) +# MAGIC - **Memory efficient**: Uses fixed memory regardless of cardinality +# MAGIC - **Ideal for**: High-cardinality columns, large datasets, real-time analytics +# MAGIC +# MAGIC **Use Case**: Monitor daily active users without expensive exact counting. 
+ +# COMMAND ---------- + +# User activity data with high cardinality +user_activity_df = spark.createDataFrame([ + ["2024-01-01", f"user_{i}"] for i in range(1, 95001) # 95,000 distinct users on day 1 +] + [ + ["2024-01-02", f"user_{i}"] for i in range(1, 50001) # 50,000 distinct users on day 2 +], "activity_date: string, user_id: string") + +# Quality check: Ensure daily active users don't drop below 60,000 +# Using approx_count_distinct is much faster than count_distinct for large datasets +checks = [ + DQDatasetRule( + criticality="warn", + check_func=check_funcs.is_aggr_not_less_than, + column="user_id", + check_func_kwargs={ + "aggr_type": "approx_count_distinct", # Fast approximate counting + "group_by": ["activity_date"], + "limit": 60000 + }, + ), +] + +result_df = dq_engine.apply_checks(user_activity_df, checks) +display(result_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### Example 3: Custom Aggregate Functions with Runtime Validation +# MAGIC +# MAGIC DQX supports custom aggregate functions (including UDAFs) with: +# MAGIC - **Warning**: Non-curated functions trigger a warning +# MAGIC - **Runtime validation**: Ensures the function returns numeric values compatible with comparisons +# MAGIC - **Graceful errors**: Invalid aggregates (e.g., `collect_list` returning arrays) fail with clear messages + +# COMMAND ---------- + +import warnings + +# Sensor data with multiple readings per sensor +sensor_sample_df = spark.createDataFrame([ + ["S1", 45.2], + ["S1", 45.8], + ["S2", 78.1], + ["S2", 78.5], +], "sensor_id: string, reading: double") + +# Using a valid but non-curated aggregate function: any_value +# This will work but produce a warning +checks = [ + DQDatasetRule( + criticality="warn", + check_func=check_funcs.is_aggr_not_greater_than, + column="reading", + check_func_kwargs={ + "aggr_type": "any_value", # Not in curated list - triggers warning + "group_by": ["sensor_id"], + "limit": 100.0 + }, + ), +] + +# Capture warnings during execution +with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result_df = dq_engine.apply_checks(sensor_sample_df, checks) + + # Display warning message if present + if w: + print(f"⚠️ Warning: {w[0].message}") + +display(result_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### Example 4: Percentile Functions for SLA Monitoring +# MAGIC +# MAGIC Monitor P95 latency to ensure 95% of API requests meet SLA requirements. + +# COMMAND ---------- + +# API latency data in milliseconds +api_latency_df = spark.createDataFrame([ + ["2024-01-01", i * 10.0] for i in range(1, 101) # 10ms to 1000ms latencies +], "date: string, latency_ms: double") + +# Quality check: P95 latency must be under 950ms +checks = [ + DQDatasetRule( + criticality="error", + check_func=check_funcs.is_aggr_not_greater_than, + column="latency_ms", + check_func_kwargs={ + "aggr_type": "percentile", + "aggr_params": {"percentile": 0.95}, # P95 + "group_by": ["date"], + "limit": 950.0 + }, + ), +] + +result_df = dq_engine.apply_checks(api_latency_df, checks) +display(result_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC #### Example 5: Count Distinct for Uniqueness Validation +# MAGIC +# MAGIC Ensure referential integrity: each country should have exactly one country code. 
+ +# COMMAND ---------- + +# Country data with potential duplicates +country_df = spark.createDataFrame([ + ["US", "USA"], + ["US", "USA"], # OK: same code + ["FR", "FRA"], + ["FR", "FRN"], # ERROR: different codes for same country + ["DE", "DEU"], +], "country: string, country_code: string") + +# Quality check: Each country must have exactly one distinct country code +checks = [ + DQDatasetRule( + criticality="error", + check_func=check_funcs.is_aggr_not_greater_than, + column="country_code", + check_func_kwargs={ + "aggr_type": "count_distinct", + "group_by": ["country"], + "limit": 1 + }, + ), +] + +result_df = dq_engine.apply_checks(country_df, checks) +display(result_df) + +# COMMAND ---------- + # MAGIC %md # MAGIC ### Creating custom dataset-level checks # MAGIC Requirement: Fail all readings from a sensor if any reading for that sensor exceeds a specified threshold from the sensor specification table. diff --git a/docs/dqx/docs/guide/quality_checks_definition.mdx b/docs/dqx/docs/guide/quality_checks_definition.mdx index 148699dc8..086aafd99 100644 --- a/docs/dqx/docs/guide/quality_checks_definition.mdx +++ b/docs/dqx/docs/guide/quality_checks_definition.mdx @@ -271,6 +271,7 @@ Monitor service performance and ensure SLA compliance: - **stddev/variance**: Anomaly detection, consistency monitoring - **percentile**: SLA compliance, outlier detection (P95, P99) - **median**: Baseline checks, central tendency +- **mode**: Most frequent value checks, categorical data quality - **approx_percentile**: Fast percentile estimation for large datasets - **approx_count_distinct**: Efficient cardinality for high-cardinality columns diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx index dc9101d54..977b91624 100644 --- a/docs/dqx/docs/reference/quality_checks.mdx +++ b/docs/dqx/docs/reference/quality_checks.mdx @@ -1627,6 +1627,7 @@ Curated functions are validated and optimized for data quality use cases: - `variance` / `var_samp` - Sample variance (stability checks) - `var_pop` - Population variance - `median` - 50th percentile baseline +- `mode` - Most frequent value (categorical data quality) - `skewness` - Distribution skewness (detect asymmetry) - `kurtosis` - Distribution kurtosis (detect heavy tails) @@ -1634,11 +1635,6 @@ Curated functions are validated and optimized for data quality use cases: - `percentile` - Exact percentile (requires `aggr_params`) - `approx_percentile` - Approximate percentile (faster, requires `aggr_params`) -#### Correlation Analysis -- `corr` - Correlation coefficient -- `covar_pop` - Population covariance -- `covar_samp` - Sample covariance - ### Custom Aggregates Custom aggregate functions (including UDAFs) are supported with runtime validation. A warning will be issued, and the function must return a single numeric value per group. 
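+For example, a minimal Python sketch of this warn-and-validate behaviour (`input_df` is a placeholder DataFrame with a numeric `value` column; Spark's `any_value` stands in for a custom aggregate):
+
+```python
+import warnings
+
+from databricks.labs.dqx import check_funcs
+
+with warnings.catch_warnings(record=True) as caught:
+    warnings.simplefilter("always")
+    # Building the check emits a UserWarning because `any_value` is not curated.
+    condition, apply_fn = check_funcs.is_aggr_not_greater_than(
+        "value", limit=100, aggr_type="any_value"
+    )
+
+# The check still runs; runtime validation rejects aggregates that return
+# arrays, maps, or structs instead of a single numeric value per group.
+checked_df = apply_fn(input_df)
+```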
@@ -1723,6 +1719,17 @@ Custom aggregate functions (including UDAFs) are supported with runtime validati group_by: - session_date limit: 100000 + +# Mode: Alert if any single error code dominates +- criticality: warn + check: + function: is_aggr_not_greater_than + arguments: + column: error_code + aggr_type: mode + group_by: + - service_name + limit: 100 ``` # foreign_key check using reference DataFrame diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index 32cf0f503..7b058aa8e 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -29,6 +29,7 @@ IPV4_BIT_LENGTH = 32 # Curated aggregate functions for data quality checks +# These are univariate (single-column) aggregate functions suitable for DQ monitoring CURATED_AGGR_FUNCTIONS = { "count", "sum", @@ -45,13 +46,11 @@ "var_pop", "var_samp", "median", + "mode", "skewness", "kurtosis", "percentile", "approx_percentile", - "corr", - "covar_pop", - "covar_samp", } @@ -2312,39 +2311,27 @@ def _build_aggregate_expression( ) from exc -def _validate_custom_aggregate( - agg_df: DataFrame, +def _validate_aggregate_return_type( + df: DataFrame, aggr_type: str, metric_col: str, - expected_row_count: int, ) -> None: """ - Validate custom aggregate returns proper single numeric value per group. + Validate aggregate returns a numeric type that can be compared to limits. - This function validates that a custom (non-curated) aggregate function behaves correctly - by ensuring it returns the expected number of rows and a numeric data type that can be - compared to limits. + This is a schema-only validation (no data scanning) that checks whether the aggregate + function returns a type compatible with numeric comparisons. Args: - agg_df: DataFrame with aggregated results. + df: DataFrame containing the aggregate result column. aggr_type: Name of the aggregate function being validated. metric_col: Column name containing the aggregate result. - expected_row_count: Expected number of rows (1 for no group_by, count of groups for group_by). Raises: - InvalidParameterError: If the aggregate returns more rows than expected (not a proper aggregate) - or if it returns a non-numeric type (Array, Map, Struct) that cannot be compared to limits. + InvalidParameterError: If the aggregate returns a non-numeric type (Array, Map, Struct) + that cannot be compared to limits. """ - # Check 1: Row count (must return exactly expected rows) - actual_count = agg_df.count() - if actual_count > expected_row_count: - raise InvalidParameterError( - f"Aggregate function '{aggr_type}' returned {actual_count} rows, " - f"expected {expected_row_count}. This is not a proper aggregate function." 
- ) - - # Check 2: Return type (must be numeric, not array/struct/map) - result_type = agg_df.schema[metric_col].dataType + result_type = df.schema[metric_col].dataType if isinstance(result_type, (types.ArrayType, types.MapType, types.StructType)): raise InvalidParameterError( f"Aggregate function '{aggr_type}' returned {result_type.typeName()} " @@ -2452,6 +2439,10 @@ def apply(df: DataFrame) -> DataFrame: if group_by: window_spec = Window.partitionBy(*[F.col(col) if isinstance(col, str) else col for col in group_by]) df = df.withColumn(metric_col, aggr_expr.over(window_spec)) + + # Validate custom aggregates (type check only - window functions return same row count) + if not is_curated: + _validate_aggregate_return_type(df, aggr_type, metric_col) else: # When no group-by columns are provided, using partitionBy would move all rows into a single partition, # forcing the window function to process the entire dataset in one task. @@ -2460,10 +2451,9 @@ def apply(df: DataFrame) -> DataFrame: # so no explicit limit is required (informational only). agg_df = df.select(aggr_expr.alias(metric_col)).limit(1) - # Validate custom aggregates + # Validate custom aggregates (type check only - we already limited to 1 row) if not is_curated: - expected_rows = 1 # No group_by means single row expected - _validate_custom_aggregate(agg_df, aggr_type, metric_col, expected_rows) + _validate_aggregate_return_type(agg_df, aggr_type, metric_col) df = df.crossJoin(agg_df) # bring the metric across all rows diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index 8c22b4304..98a4980db 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -573,7 +573,7 @@ def test_is_aggr_with_count_distinct(spark: SparkSession): def test_is_aggr_with_statistical_functions(spark: SparkSession): - """Test statistical aggregate functions: stddev, variance, median, skewness, kurtosis.""" + """Test statistical aggregate functions: stddev, variance, median, mode, skewness, kurtosis.""" test_df = spark.createDataFrame( [ ["A", 10.0], @@ -593,6 +593,8 @@ def test_is_aggr_with_statistical_functions(spark: SparkSession): is_aggr_not_greater_than("value", limit=100.0, aggr_type="variance", group_by=["group"]), # Median check is_aggr_not_greater_than("value", limit=25.0, aggr_type="median"), + # Mode check (most frequent value per group) + is_aggr_not_greater_than("value", limit=10.0, aggr_type="mode", group_by=["group"]), ] actual = _apply_checks(test_df, checks) @@ -601,6 +603,42 @@ def test_is_aggr_with_statistical_functions(spark: SparkSession): assert "value_stddev_group_by_group_greater_than_limit" in actual.columns assert "value_variance_group_by_group_greater_than_limit" in actual.columns assert "value_median_greater_than_limit" in actual.columns + assert "value_mode_group_by_group_greater_than_limit" in actual.columns + + +def test_is_aggr_with_mode_function(spark: SparkSession): + """Test mode aggregate function for categorical data quality checks.""" + test_df = spark.createDataFrame( + [ + ["service_A", "ERROR_401"], + ["service_A", "ERROR_401"], + ["service_A", "ERROR_401"], # ERROR_401 appears 3 times (mode) + ["service_A", "ERROR_500"], + ["service_B", "ERROR_200"], + ["service_B", "ERROR_200"], # ERROR_200 appears 2 times (mode) + ["service_B", "ERROR_404"], + ], + "service: string, error_code: string", + ) + + # Check that modal error code doesn't appear too frequently (>2 times per service) + checks = [ + 
is_aggr_not_greater_than("error_code", limit=2, aggr_type="mode", group_by=["service"]), + ] + + actual = _apply_checks(test_df, checks) + + # service_A should fail (mode=3 > limit=2), service_B should pass (mode=2 <= limit=2) + assert "error_code_mode_group_by_service_greater_than_limit" in actual.columns + + # Verify failures - service_A rows should have error messages + service_a_errors = ( + actual.filter(F.col("service") == "service_A") + .select("error_code_mode_group_by_service_greater_than_limit") + .distinct() + .collect() + ) + assert len([r for r in service_a_errors if r[0] is not None]) > 0 def test_is_aggr_with_percentile_functions(spark: SparkSession): @@ -650,7 +688,7 @@ def test_is_aggr_with_invalid_aggregate_function(spark: SparkSession): def test_is_aggr_with_collect_list_fails(spark: SparkSession): - """Test that collect_list (returns array) fails with clear error message.""" + """Test that collect_list (returns array) fails with clear error message - no group_by.""" test_df = spark.createDataFrame([(1, 10), (2, 20)], "id: int, value: int") # collect_list returns array which cannot be compared to numeric limit @@ -659,6 +697,19 @@ def test_is_aggr_with_collect_list_fails(spark: SparkSession): apply_fn(test_df) +def test_is_aggr_with_collect_list_fails_with_group_by(spark: SparkSession): + """Test that collect_list with group_by also fails with clear error message - bug fix verification.""" + test_df = spark.createDataFrame( + [("A", 10), ("A", 20), ("B", 30)], + "category: string, value: int", + ) + + # This is the bug fix: collect_list with group_by should fail gracefully, not with cryptic Spark error + with pytest.raises(InvalidParameterError, match="ArrayType.*cannot be compared"): + _, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="collect_list", group_by=["category"]) + apply_fn(test_df) + + def test_is_aggr_custom_aggregate_with_warning(spark: SparkSession): """Test that custom (non-curated) aggregates work but produce warning.""" test_df = spark.createDataFrame( From a661b934b25f68b26e0f5dccb93df4b09dec286f Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Thu, 27 Nov 2025 14:18:07 +0000 Subject: [PATCH 04/24] fix: test expectation --- tests/unit/test_row_checks.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_row_checks.py b/tests/unit/test_row_checks.py index 2fcaada2a..3bd2e9cc9 100644 --- a/tests/unit/test_row_checks.py +++ b/tests/unit/test_row_checks.py @@ -51,8 +51,22 @@ def test_col_is_in_list_missing_allowed_list(): def test_incorrect_aggr_type(): - with pytest.raises(InvalidParameterError, match="Unsupported aggregation type"): - is_aggr_not_greater_than("a", 1, aggr_type="invalid") + # With new implementation, invalid aggr_type triggers a warning (not immediate error) + # The error occurs at runtime when the apply function is called + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + condition, apply_fn = is_aggr_not_greater_than("a", 1, aggr_type="invalid") + + # Should have warning about non-curated aggregate + assert len(w) > 0 + assert "non-curated" in str(w[0].message).lower() + assert "invalid" in str(w[0].message) + + # Function should return successfully (error will happen at runtime when applied to DataFrame) + assert condition is not None + assert apply_fn is not None def test_col_is_ipv4_address_in_cidr_missing_cidr_block(): From 91b5029eddab91492b3ce9fa8c26547e1be210b6 Mon Sep 17 00:00:00 2001 From: Varun 
Bhandary Date: Thu, 27 Nov 2025 14:30:52 +0000 Subject: [PATCH 05/24] fix + fmt: test assetion for arraytype --- tests/integration/test_dataset_checks.py | 4 ++-- tests/unit/test_row_checks.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index 98a4980db..d210b43c2 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -692,7 +692,7 @@ def test_is_aggr_with_collect_list_fails(spark: SparkSession): test_df = spark.createDataFrame([(1, 10), (2, 20)], "id: int, value: int") # collect_list returns array which cannot be compared to numeric limit - with pytest.raises(InvalidParameterError, match="ArrayType.*cannot be compared"): + with pytest.raises(InvalidParameterError, match="array.*cannot be compared"): _, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="collect_list") apply_fn(test_df) @@ -705,7 +705,7 @@ def test_is_aggr_with_collect_list_fails_with_group_by(spark: SparkSession): ) # This is the bug fix: collect_list with group_by should fail gracefully, not with cryptic Spark error - with pytest.raises(InvalidParameterError, match="ArrayType.*cannot be compared"): + with pytest.raises(InvalidParameterError, match="array.*cannot be compared"): _, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="collect_list", group_by=["category"]) apply_fn(test_df) diff --git a/tests/unit/test_row_checks.py b/tests/unit/test_row_checks.py index 3bd2e9cc9..2c1f5d085 100644 --- a/tests/unit/test_row_checks.py +++ b/tests/unit/test_row_checks.py @@ -1,3 +1,4 @@ +import warnings import pytest from databricks.labs.dqx.check_funcs import ( is_equal_to, @@ -53,17 +54,15 @@ def test_col_is_in_list_missing_allowed_list(): def test_incorrect_aggr_type(): # With new implementation, invalid aggr_type triggers a warning (not immediate error) # The error occurs at runtime when the apply function is called - import warnings - with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") condition, apply_fn = is_aggr_not_greater_than("a", 1, aggr_type="invalid") - + # Should have warning about non-curated aggregate assert len(w) > 0 assert "non-curated" in str(w[0].message).lower() assert "invalid" in str(w[0].message) - + # Function should return successfully (error will happen at runtime when applied to DataFrame) assert condition is not None assert apply_fn is not None From 817cf4a861ac6e8f8a3f6fb77608851f4516e812 Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Thu, 27 Nov 2025 15:08:15 +0000 Subject: [PATCH 06/24] docs: fixes --- docs/dqx/docs/reference/quality_checks.mdx | 2 ++ src/databricks/labs/dqx/check_funcs.py | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx index 977b91624..ca78a99bb 100644 --- a/docs/dqx/docs/reference/quality_checks.mdx +++ b/docs/dqx/docs/reference/quality_checks.mdx @@ -1600,6 +1600,7 @@ Complex data types are supported as well. 
group_by: - col3 limit: 200 +``` ## Aggregate Function Types @@ -1732,6 +1733,7 @@ Custom aggregate functions (including UDAFs) are supported with runtime validati limit: 100 ``` +```yaml # foreign_key check using reference DataFrame - criticality: error check: diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index 7b058aa8e..6eaad0d7d 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -1305,7 +1305,7 @@ def is_aggr_not_greater_than( limit: Numeric value, column name, or SQL expression for the limit. aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. - aggr_params: Optional dict of parameters for aggregates requiring them (e.g., {"percentile": 0.95}). + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. @@ -1348,7 +1348,7 @@ def is_aggr_not_less_than( limit: Numeric value, column name, or SQL expression for the limit. aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. - aggr_params: Optional dict of parameters for aggregates requiring them (e.g., {"percentile": 0.95}). + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. @@ -1391,7 +1391,7 @@ def is_aggr_equal( limit: Numeric value, column name, or SQL expression for the limit. aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. - aggr_params: Optional dict of parameters for aggregates requiring them (e.g., {"percentile": 0.95}). + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. @@ -1434,7 +1434,7 @@ def is_aggr_not_equal( limit: Numeric value, column name, or SQL expression for the limit. aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. - aggr_params: Optional dict of parameters for aggregates requiring them (e.g., {"percentile": 0.95}). + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. From 45eab640cb6b6e3cd38a563e3109e55133940f8f Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Thu, 27 Nov 2025 16:17:02 +0000 Subject: [PATCH 07/24] fix: count_distinct does not support group_by because of the use of windowing functions in DQX. 
Parameter ordering was changed accidentaly. --- demos/dqx_demo_library.py | 2 +- .../docs/guide/quality_checks_definition.mdx | 3 +- docs/dqx/docs/reference/quality_checks.mdx | 32 ++-- src/databricks/labs/dqx/check_funcs.py | 26 ++- tests/integration/test_dataset_checks.py | 172 ++++++++++-------- 5 files changed, 132 insertions(+), 103 deletions(-) diff --git a/demos/dqx_demo_library.py b/demos/dqx_demo_library.py index 4765e765a..f32e8bb2a 100644 --- a/demos/dqx_demo_library.py +++ b/demos/dqx_demo_library.py @@ -696,7 +696,7 @@ def not_ends_with(column: str, suffix: str) -> Column: # MAGIC DQX now supports 20 curated aggregate functions for advanced data quality monitoring: # MAGIC - **Statistical functions**: `stddev`, `variance`, `median`, `mode`, `skewness`, `kurtosis` for anomaly detection # MAGIC - **Percentile functions**: `percentile`, `approx_percentile` for SLA monitoring -# MAGIC - **Cardinality functions**: `count_distinct`, `approx_count_distinct` for uniqueness validation +# MAGIC - **Cardinality functions**: `count_distinct` (global only), `approx_count_distinct` (works with group_by, uses HyperLogLog++) # MAGIC - **Custom aggregates**: Support for user-defined functions with runtime validation # COMMAND ---------- diff --git a/docs/dqx/docs/guide/quality_checks_definition.mdx b/docs/dqx/docs/guide/quality_checks_definition.mdx index 086aafd99..3543d19bb 100644 --- a/docs/dqx/docs/guide/quality_checks_definition.mdx +++ b/docs/dqx/docs/guide/quality_checks_definition.mdx @@ -267,7 +267,8 @@ Monitor service performance and ensure SLA compliance: ``` -- **count_distinct**: Uniqueness validation, cardinality checks +- **count_distinct**: Exact cardinality (global only, cannot use with `group_by` due to Spark limitation) +- **approx_count_distinct**: Fast cardinality estimation (works with `group_by`, uses HyperLogLog++) - **stddev/variance**: Anomaly detection, consistency monitoring - **percentile**: SLA compliance, outlier detection (P95, P99) - **median**: Baseline checks, central tendency diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx index ca78a99bb..4b27a636f 100644 --- a/docs/dqx/docs/reference/quality_checks.mdx +++ b/docs/dqx/docs/reference/quality_checks.mdx @@ -1618,8 +1618,8 @@ Curated functions are validated and optimized for data quality use cases: - `max` - Maximum value #### Cardinality & Uniqueness -- `count_distinct` - Count unique values (e.g., "each country has one code") -- `approx_count_distinct` - Approximate distinct count (faster for large datasets) +- `count_distinct` - Exact distinct count (global aggregations only, not supported with `group_by`) +- `approx_count_distinct` - Approximate distinct count using HyperLogLog++ (works with `group_by`, faster for large datasets) - `count_if` - Conditional counting #### Statistical Analysis @@ -1645,16 +1645,25 @@ Custom aggregate functions (including UDAFs) are supported with runtime validati ### Extended Examples ```yaml -# count_distinct: Ensure each country has exactly one country code +# count_distinct: Exact distinct count (global only, no group_by) - criticality: error check: function: is_aggr_not_greater_than arguments: - column: country_code + column: user_id aggr_type: count_distinct + limit: 1000000 + +# approx_count_distinct: Fast cardinality estimation (works with group_by) +- criticality: error + check: + function: is_aggr_not_greater_than + arguments: + column: user_id + aggr_type: approx_count_distinct group_by: - - country - limit: 1 + - 
session_date + limit: 100000 # Standard deviation: Detect unusual variance in sensor readings per machine - criticality: warn @@ -1710,17 +1719,6 @@ Custom aggregate functions (including UDAFs) are supported with runtime validati aggr_type: variance limit: 1000000.0 -# Approximate count distinct: Efficient cardinality estimation -- criticality: warn - check: - function: is_aggr_not_greater_than - arguments: - column: user_id - aggr_type: approx_count_distinct - group_by: - - session_date - limit: 100000 - # Mode: Alert if any single error code dominates - criticality: warn check: diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index 6eaad0d7d..a9e3c3f78 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -1289,9 +1289,9 @@ def is_aggr_not_greater_than( column: str | Column, limit: int | float | str | Column, aggr_type: str = "count", - aggr_params: dict[str, Any] | None = None, group_by: list[str | Column] | None = None, row_filter: str | None = None, + aggr_params: dict[str, Any] | None = None, ) -> tuple[Column, Callable]: """ Build an aggregation check condition and closure for dataset-level validation. @@ -1305,9 +1305,9 @@ def is_aggr_not_greater_than( limit: Numeric value, column name, or SQL expression for the limit. aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. - aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). Returns: A tuple of: @@ -1332,9 +1332,9 @@ def is_aggr_not_less_than( column: str | Column, limit: int | float | str | Column, aggr_type: str = "count", - aggr_params: dict[str, Any] | None = None, group_by: list[str | Column] | None = None, row_filter: str | None = None, + aggr_params: dict[str, Any] | None = None, ) -> tuple[Column, Callable]: """ Build an aggregation check condition and closure for dataset-level validation. @@ -1348,9 +1348,10 @@ def is_aggr_not_less_than( limit: Numeric value, column name, or SQL expression for the limit. aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. - aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). + Returns: A tuple of: @@ -1375,9 +1376,9 @@ def is_aggr_equal( column: str | Column, limit: int | float | str | Column, aggr_type: str = "count", - aggr_params: dict[str, Any] | None = None, group_by: list[str | Column] | None = None, row_filter: str | None = None, + aggr_params: dict[str, Any] | None = None, ) -> tuple[Column, Callable]: """ Build an aggregation check condition and closure for dataset-level validation. 
@@ -1391,9 +1392,9 @@ def is_aggr_equal( limit: Numeric value, column name, or SQL expression for the limit. aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. - aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). Returns: A tuple of: @@ -1418,9 +1419,9 @@ def is_aggr_not_equal( column: str | Column, limit: int | float | str | Column, aggr_type: str = "count", - aggr_params: dict[str, Any] | None = None, group_by: list[str | Column] | None = None, row_filter: str | None = None, + aggr_params: dict[str, Any] | None = None, ) -> tuple[Column, Callable]: """ Build an aggregation check condition and closure for dataset-level validation. @@ -1434,9 +1435,9 @@ def is_aggr_not_equal( limit: Numeric value, column name, or SQL expression for the limit. aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. - aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). Returns: A tuple of: @@ -2380,6 +2381,15 @@ def _is_aggr_compare( InvalidParameterError: If a custom aggregate returns non-numeric or multiple rows per group. MissingParameterError: If required parameters for specific aggregates are not provided. """ + # Validate count_distinct with group_by (Spark limitation) + if aggr_type == "count_distinct" and group_by: + raise InvalidParameterError( + "count_distinct cannot be used with group_by due to Spark limitation: " + "DISTINCT is not supported in window functions. " + "Use 'approx_count_distinct' instead, which provides fast approximate counting using HyperLogLog++ " + "and works with both grouped and ungrouped aggregations." 
+ ) + # Warn if using non-curated aggregate function is_curated = aggr_type in CURATED_AGGR_FUNCTIONS if not is_curated: diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index d210b43c2..81974926a 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -525,55 +525,76 @@ def test_is_aggr_not_equal(spark: SparkSession): def test_is_aggr_with_count_distinct(spark: SparkSession): - """Test count_distinct aggregation - GitHub issue #929 requirement: 'each country has one code'.""" + """Test count_distinct for exact cardinality (works without group_by).""" test_df = spark.createDataFrame( [ - ["US", "USA"], - ["US", "USA"], # Same code, OK - ["FR", "FRA"], - ["FR", "FRN"], # Different code, should FAIL - ["DE", "DEU"], + ["val1", "data1"], + ["val1", "data2"], # Same first column + ["val2", "data3"], # Different first column + ["val3", "data4"], ], - "country: string, code: string", + "a: string, b: string", ) checks = [ - # Each country should have exactly one distinct code - is_aggr_not_greater_than("code", limit=1, aggr_type="count_distinct", group_by=["country"]), - # Global distinct count - is_aggr_not_greater_than("country", limit=5, aggr_type="count_distinct"), + # Global count_distinct (no group_by) - should work fine + is_aggr_not_greater_than("a", limit=5, aggr_type="count_distinct"), ] actual = _apply_checks(test_df, checks) - expected_schema = "country: string, code: string, code_count_distinct_group_by_country_greater_than_limit: string, country_count_distinct_greater_than_limit: string" + # Check that the check was applied + assert "a_count_distinct_greater_than_limit" in actual.columns - expected = spark.createDataFrame( + +def test_is_aggr_with_count_distinct_and_group_by_fails(spark: SparkSession): + """Test that count_distinct with group_by raises clear error about Spark limitation.""" + test_df = spark.createDataFrame( + [["group1", "val1"], ["group1", "val2"]], + "a: string, b: string", + ) + + # count_distinct with group_by should raise InvalidParameterError + with pytest.raises(InvalidParameterError, match="count_distinct cannot be used with group_by.*Spark limitation"): + _, apply_fn = is_aggr_not_greater_than("b", limit=1, aggr_type="count_distinct", group_by=["a"]) + apply_fn(test_df) + + +def test_is_aggr_with_approx_count_distinct(spark: SparkSession): + """Test approx_count_distinct for fast cardinality estimation with group_by.""" + test_df = spark.createDataFrame( [ - ["US", "USA", None, None], - ["US", "USA", None, None], - [ - "FR", - "FRA", - "Count_distinct 2 in column 'code' per group of columns 'country' is greater than limit: 1", - None, - ], - [ - "FR", - "FRN", - "Count_distinct 2 in column 'code' per group of columns 'country' is greater than limit: 1", - None, - ], - ["DE", "DEU", None, None], + ["group1", "val1"], + ["group1", "val1"], # Same value + ["group1", "val2"], # Different value + ["group2", "val3"], + ["group2", "val3"], # Same value - only 1 distinct ], - expected_schema, + "a: string, b: string", ) - assert_df_equality(actual, expected, ignore_nullable=True) + checks = [ + # group1 has 2 distinct values, should exceed limit of 1 + is_aggr_not_greater_than("b", limit=1, aggr_type="approx_count_distinct", group_by=["a"]), + ] + + actual = _apply_checks(test_df, checks) + + # Check that the check was applied + assert "b_approx_count_distinct_group_by_a_greater_than_limit" in actual.columns + + # Verify group1 has violations (approx 2 distinct values 
> limit 1) + group1_violations = ( + actual.filter(F.col("a") == "group1") + .select("b_approx_count_distinct_group_by_a_greater_than_limit") + .distinct() + .collect() + ) + assert len([r for r in group1_violations if r[0] is not None]) > 0 def test_is_aggr_with_statistical_functions(spark: SparkSession): - """Test statistical aggregate functions: stddev, variance, median, mode, skewness, kurtosis.""" + """Test statistical aggregate functions: stddev, variance, median, mode.""" test_df = spark.createDataFrame( [ ["A", 10.0], @@ -583,88 +604,87 @@ def test_is_aggr_with_statistical_functions(spark: SparkSession): ["B", 5.0], ["B", 5.0], ], - "group: string, value: double", + "a: string, b: double", ) checks = [ # Standard deviation check (group A has higher variance than B) - is_aggr_not_greater_than("value", limit=10.0, aggr_type="stddev", group_by=["group"]), + is_aggr_not_greater_than("b", limit=10.0, aggr_type="stddev", group_by=["a"]), # Variance check - is_aggr_not_greater_than("value", limit=100.0, aggr_type="variance", group_by=["group"]), - # Median check - is_aggr_not_greater_than("value", limit=25.0, aggr_type="median"), + is_aggr_not_greater_than("b", limit=100.0, aggr_type="variance", group_by=["a"]), + # Median check (dataset-level, no grouping) + is_aggr_not_greater_than("b", limit=25.0, aggr_type="median"), # Mode check (most frequent value per group) - is_aggr_not_greater_than("value", limit=10.0, aggr_type="mode", group_by=["group"]), + is_aggr_not_greater_than("b", limit=10.0, aggr_type="mode", group_by=["a"]), ] actual = _apply_checks(test_df, checks) - # Check that checks were applied (exact values depend on Spark's statistical functions) - assert "value_stddev_group_by_group_greater_than_limit" in actual.columns - assert "value_variance_group_by_group_greater_than_limit" in actual.columns - assert "value_median_greater_than_limit" in actual.columns - assert "value_mode_group_by_group_greater_than_limit" in actual.columns + # Check that checks were applied + assert "b_stddev_group_by_a_greater_than_limit" in actual.columns + assert "b_variance_group_by_a_greater_than_limit" in actual.columns + assert "b_median_greater_than_limit" in actual.columns + assert "b_mode_group_by_a_greater_than_limit" in actual.columns def test_is_aggr_with_mode_function(spark: SparkSession): - """Test mode aggregate function for categorical data quality checks.""" + """Test mode aggregate function for detecting most common numeric value.""" test_df = spark.createDataFrame( [ - ["service_A", "ERROR_401"], - ["service_A", "ERROR_401"], - ["service_A", "ERROR_401"], # ERROR_401 appears 3 times (mode) - ["service_A", "ERROR_500"], - ["service_B", "ERROR_200"], - ["service_B", "ERROR_200"], # ERROR_200 appears 2 times (mode) - ["service_B", "ERROR_404"], + # groupA: most common error code is 401 (appears 3 times) + ["groupA", 401], + ["groupA", 401], + ["groupA", 401], + ["groupA", 500], + # groupB: most common error code is 200 (appears 2 times) + ["groupB", 200], + ["groupB", 200], + ["groupB", 404], ], - "service: string, error_code: string", + "a: string, b: int", ) - # Check that modal error code doesn't appear too frequently (>2 times per service) + # Check that the most common error code value doesn't exceed threshold checks = [ - is_aggr_not_greater_than("error_code", limit=2, aggr_type="mode", group_by=["service"]), + is_aggr_not_greater_than("b", limit=400, aggr_type="mode", group_by=["a"]), ] actual = _apply_checks(test_df, checks) - # service_A should fail (mode=3 > limit=2), service_B 
should pass (mode=2 <= limit=2) - assert "error_code_mode_group_by_service_greater_than_limit" in actual.columns + # groupA should fail (mode=401 > limit=400), groupB should pass (mode=200 <= limit=400) + assert "b_mode_group_by_a_greater_than_limit" in actual.columns - # Verify failures - service_A rows should have error messages - service_a_errors = ( - actual.filter(F.col("service") == "service_A") - .select("error_code_mode_group_by_service_greater_than_limit") - .distinct() - .collect() + # Verify failures - groupA rows should have error messages (mode 401 > 400) + group_a_errors = ( + actual.filter(F.col("a") == "groupA").select("b_mode_group_by_a_greater_than_limit").distinct().collect() ) - assert len([r for r in service_a_errors if r[0] is not None]) > 0 + assert len([r for r in group_a_errors if r[0] is not None]) > 0 def test_is_aggr_with_percentile_functions(spark: SparkSession): """Test percentile and approx_percentile with aggr_params.""" test_df = spark.createDataFrame( - [(i, float(i)) for i in range(1, 101)], - "id: int, value: double", + [(f"row{i}", float(i)) for i in range(1, 101)], + "a: string, b: double", ) checks = [ # P95 should be around 95 - is_aggr_not_greater_than("value", limit=100.0, aggr_type="percentile", aggr_params={"percentile": 0.95}), + is_aggr_not_greater_than("b", limit=100.0, aggr_type="percentile", aggr_params={"percentile": 0.95}), # P99 with approx_percentile - is_aggr_not_greater_than("value", limit=100.0, aggr_type="approx_percentile", aggr_params={"percentile": 0.99}), + is_aggr_not_greater_than("b", limit=100.0, aggr_type="approx_percentile", aggr_params={"percentile": 0.99}), # P50 (median) with accuracy parameter is_aggr_not_less_than( - "value", limit=40.0, aggr_type="approx_percentile", aggr_params={"percentile": 0.50, "accuracy": 100} + "b", limit=40.0, aggr_type="approx_percentile", aggr_params={"percentile": 0.50, "accuracy": 100} ), ] actual = _apply_checks(test_df, checks) # Verify columns exist - assert "value_percentile_greater_than_limit" in actual.columns - assert "value_approx_percentile_greater_than_limit" in actual.columns - assert "value_approx_percentile_less_than_limit" in actual.columns + assert "b_percentile_greater_than_limit" in actual.columns + assert "b_approx_percentile_greater_than_limit" in actual.columns + assert "b_approx_percentile_less_than_limit" in actual.columns def test_is_aggr_percentile_missing_params(spark: SparkSession): @@ -714,22 +734,22 @@ def test_is_aggr_custom_aggregate_with_warning(spark: SparkSession): """Test that custom (non-curated) aggregates work but produce warning.""" test_df = spark.createDataFrame( [("A", 10), ("B", 20), ("C", 30)], - "category: string, value: int", + "a: string, b: int", ) - # Use a valid aggregate that's not in curated list (e.g., if we had a UDAF) - # For now, test with any_value which is a valid Spark aggregate + # Use a valid aggregate that's not in curated list (e.g., any_value) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - _, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="any_value") + _, apply_fn = is_aggr_not_greater_than("b", limit=100, aggr_type="any_value") result = apply_fn(test_df) # Should have warning about non-curated aggregate assert len(w) > 0 assert "non-curated" in str(w[0].message).lower() - # Should still work - assert "value_any_value_greater_than_limit" in result.columns + # Should still work - check for condition column with any_value in name + condition_cols = [c for c in result.columns 
if "any_value" in c and "condition" in c] + assert len(condition_cols) > 0 def test_dataset_compare(spark: SparkSession, set_utc_timezone): From 53ae329fb25ab354f3425ff898968d46d8dd6a5a Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Thu, 27 Nov 2025 17:14:40 +0000 Subject: [PATCH 08/24] docs + fix: count_distinct is now supported with a generalised implementation. docs updated with more user friendly language. --- demos/dqx_demo_library.py | 6 ++- .../docs/guide/quality_checks_definition.mdx | 26 ++++++----- docs/dqx/docs/reference/quality_checks.mdx | 38 ++++++++++++---- src/databricks/labs/dqx/check_funcs.py | 44 +++++++++++++------ tests/integration/test_dataset_checks.py | 43 +++++++++++++++--- 5 files changed, 116 insertions(+), 41 deletions(-) diff --git a/demos/dqx_demo_library.py b/demos/dqx_demo_library.py index f32e8bb2a..f5988eebf 100644 --- a/demos/dqx_demo_library.py +++ b/demos/dqx_demo_library.py @@ -869,9 +869,11 @@ def not_ends_with(column: str, suffix: str) -> Column: # COMMAND ---------- # MAGIC %md -# MAGIC #### Example 5: Count Distinct for Uniqueness Validation +# MAGIC #### Example 5: Uniqueness Validation with Count Distinct # MAGIC # MAGIC Ensure referential integrity: each country should have exactly one country code. +# MAGIC +# MAGIC Use `count_distinct` for exact cardinality validation across groups. # COMMAND ---------- @@ -891,7 +893,7 @@ def not_ends_with(column: str, suffix: str) -> Column: check_func=check_funcs.is_aggr_not_greater_than, column="country_code", check_func_kwargs={ - "aggr_type": "count_distinct", + "aggr_type": "count_distinct", # Exact distinct count per group "group_by": ["country"], "limit": 1 }, diff --git a/docs/dqx/docs/guide/quality_checks_definition.mdx b/docs/dqx/docs/guide/quality_checks_definition.mdx index 3543d19bb..f26ce73c6 100644 --- a/docs/dqx/docs/guide/quality_checks_definition.mdx +++ b/docs/dqx/docs/guide/quality_checks_definition.mdx @@ -143,7 +143,7 @@ This approach provides static type checking and autocompletion in IDEs, making i check_func=check_funcs.is_aggr_not_greater_than, column="country_code", check_func_kwargs={ - "aggr_type": "count_distinct", + "aggr_type": "count_distinct", # Exact distinct count (automatically uses two-stage aggregation) "group_by": ["country"], "limit": 1 }, @@ -266,15 +266,21 @@ Monitor service performance and ensure SLA compliance: limit: 30.0 ``` - -- **count_distinct**: Exact cardinality (global only, cannot use with `group_by` due to Spark limitation) -- **approx_count_distinct**: Fast cardinality estimation (works with `group_by`, uses HyperLogLog++) -- **stddev/variance**: Anomaly detection, consistency monitoring -- **percentile**: SLA compliance, outlier detection (P95, P99) -- **median**: Baseline checks, central tendency -- **mode**: Most frequent value checks, categorical data quality -- **approx_percentile**: Fast percentile estimation for large datasets -- **approx_count_distinct**: Efficient cardinality for high-cardinality columns + +**Uniqueness & Cardinality:** +- `count_distinct` - Exact distinct counts (uniqueness validation) +- `approx_count_distinct` - Fast approximate counts (very large datasets) + +**Statistical Monitoring:** +- `stddev` / `variance` - Detect anomalies and inconsistencies +- `median` - Baseline checks, central tendency +- `mode` - Most frequent value (categorical data) + +**Performance & SLAs:** +- `percentile` - Exact P95/P99 for SLA compliance +- `approx_percentile` - Fast percentile estimates for large datasets + +**Learn more:** See all 
[Databricks aggregate functions](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) ### Checks defined using metadata (list of dictionaries) diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx index 4b27a636f..190f0451a 100644 --- a/docs/dqx/docs/reference/quality_checks.mdx +++ b/docs/dqx/docs/reference/quality_checks.mdx @@ -1604,11 +1604,19 @@ Complex data types are supported as well. ## Aggregate Function Types -The `is_aggr_*` functions support a wide range of aggregate functions for data quality checks. Functions are categorized as **curated** (recommended for DQ) or **custom** (user-defined, validated at runtime). +Use `is_aggr_*` functions to validate metrics across your entire dataset or within groups (e.g., "each country has ≤ 1 country code", "P95 latency < 1s per region"). + +**When to use:** +- Validate cardinality and uniqueness constraints +- Monitor statistical properties (variance, outliers) +- Enforce SLAs and performance thresholds +- Detect data quality issues across groups + +DQX supports 20 **curated** aggregate functions (recommended for data quality) plus any [Databricks aggregate function](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) as custom aggregates. ### Curated Aggregate Functions -Curated functions are validated and optimized for data quality use cases: +Validated and optimized for data quality use cases: #### Basic Aggregations - `count` - Count records (nulls excluded) @@ -1618,8 +1626,8 @@ Curated functions are validated and optimized for data quality use cases: - `max` - Maximum value #### Cardinality & Uniqueness -- `count_distinct` - Exact distinct count (global aggregations only, not supported with `group_by`) -- `approx_count_distinct` - Approximate distinct count using HyperLogLog++ (works with `group_by`, faster for large datasets) +- `count_distinct` - Exact distinct count (use for uniqueness validation) +- `approx_count_distinct` - Fast approximate count (±5% accuracy, ideal for very large datasets) - `count_if` - Conditional counting #### Statistical Analysis @@ -1638,14 +1646,27 @@ Curated functions are validated and optimized for data quality use cases: ### Custom Aggregates -Custom aggregate functions (including UDAFs) are supported with runtime validation. A warning will be issued, and the function must return a single numeric value per group. +Any [Databricks aggregate function](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) can be used as a custom aggregate. DQX will: +- Issue a warning (use curated functions when possible) +- Validate the function returns numeric values -**Not suitable:** Functions returning arrays, structs, or maps (e.g., `collect_list`, `collect_set`, `string_agg`) will fail validation. +**Note:** Functions returning arrays, structs, or maps (e.g., `collect_list`, `collect_set`) will fail with a clear error message explaining they cannot be compared to numeric limits. 
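To make the warning-plus-validation behaviour described above concrete, a small Python sketch mirroring the integration tests in this patch (the `any_value` aggregate and column names are illustrative):

```python
import warnings
from databricks.labs.dqx.check_funcs import is_aggr_not_greater_than

# any_value is a valid Spark aggregate but not on the curated list, so building
# the check emits a UserWarning while still returning a usable condition/closure pair.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    condition, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="any_value")

assert any("non-curated" in str(w.message).lower() for w in caught)
# apply_fn(df) still runs; runtime validation only rejects non-numeric results such as collect_list.
```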
### Extended Examples ```yaml -# count_distinct: Exact distinct count (global only, no group_by) +# Uniqueness: Each country has exactly one country code +- criticality: error + check: + function: is_aggr_not_greater_than + arguments: + column: country_code + aggr_type: count_distinct + group_by: + - country + limit: 1 + +# Cardinality: Total unique users across entire dataset - criticality: error check: function: is_aggr_not_greater_than @@ -1654,7 +1675,7 @@ Custom aggregate functions (including UDAFs) are supported with runtime validati aggr_type: count_distinct limit: 1000000 -# approx_count_distinct: Fast cardinality estimation (works with group_by) +# Fast approximate count for very large datasets - criticality: error check: function: is_aggr_not_greater_than @@ -1731,6 +1752,7 @@ Custom aggregate functions (including UDAFs) are supported with runtime validati limit: 100 ``` + ```yaml # foreign_key check using reference DataFrame - criticality: error diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index a9e3c3f78..182756f53 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -53,6 +53,14 @@ "approx_percentile", } +# Aggregate functions incompatible with Spark window functions +# These require two-stage aggregation (groupBy + join) instead of window functions when used with group_by +# Spark limitation: DISTINCT operations are not supported in window functions +WINDOW_INCOMPATIBLE_AGGREGATES = { + "count_distinct", # DISTINCT_WINDOW_FUNCTION_UNSUPPORTED error + # Future: Add other aggregates that don't work with windows (e.g., collect_set with DISTINCT) +} + class DQPattern(Enum): """Enum class to represent DQ patterns used to match data in columns.""" @@ -2381,15 +2389,6 @@ def _is_aggr_compare( InvalidParameterError: If a custom aggregate returns non-numeric or multiple rows per group. MissingParameterError: If required parameters for specific aggregates are not provided. """ - # Validate count_distinct with group_by (Spark limitation) - if aggr_type == "count_distinct" and group_by: - raise InvalidParameterError( - "count_distinct cannot be used with group_by due to Spark limitation: " - "DISTINCT is not supported in window functions. " - "Use 'approx_count_distinct' instead, which provides fast approximate counting using HyperLogLog++ " - "and works with both grouped and ungrouped aggregations." 
- ) - # Warn if using non-curated aggregate function is_curated = aggr_type in CURATED_AGGR_FUNCTIONS if not is_curated: @@ -2447,12 +2446,29 @@ def apply(df: DataFrame) -> DataFrame: aggr_expr = _build_aggregate_expression(aggr_type, filtered_expr, aggr_params) if group_by: - window_spec = Window.partitionBy(*[F.col(col) if isinstance(col, str) else col for col in group_by]) - df = df.withColumn(metric_col, aggr_expr.over(window_spec)) + # Check if aggregate is incompatible with window functions (e.g., count_distinct with DISTINCT) + if aggr_type in WINDOW_INCOMPATIBLE_AGGREGATES: + # Use two-stage aggregation: groupBy + join (instead of window functions) + # This is required for aggregates like count_distinct that don't support window DISTINCT operations + group_cols = [F.col(col) if isinstance(col, str) else col for col in group_by] + agg_df = df.groupBy(*group_cols).agg(aggr_expr.alias(metric_col)) + + # Validate custom aggregates before join + if not is_curated: + _validate_aggregate_return_type(agg_df, aggr_type, metric_col) + + # Join aggregated metrics back to original DataFrame to maintain row-level granularity + # Use column names for join (extract names from Column objects if present) + join_cols = [col if isinstance(col, str) else get_column_name_or_alias(col) for col in group_by] + df = df.join(agg_df, on=join_cols, how="left") + else: + # Use standard window function approach for window-compatible aggregates + window_spec = Window.partitionBy(*[F.col(col) if isinstance(col, str) else col for col in group_by]) + df = df.withColumn(metric_col, aggr_expr.over(window_spec)) - # Validate custom aggregates (type check only - window functions return same row count) - if not is_curated: - _validate_aggregate_return_type(df, aggr_type, metric_col) + # Validate custom aggregates (type check only - window functions return same row count) + if not is_curated: + _validate_aggregate_return_type(df, aggr_type, metric_col) else: # When no group-by columns are provided, using partitionBy would move all rows into a single partition, # forcing the window function to process the entire dataset in one task. 
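The grouped branch above swaps the window expression for a two-stage aggregation when the aggregate is DISTINCT-based. A standalone PySpark sketch of the same pattern, outside DQX, may help clarify why this sidesteps the window limitation; the DataFrame and column names are illustrative:

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("FR", "FRA"), ("FR", "FRN"), ("DE", "DEU")],
    "country: string, code: string",
)

# F.count_distinct(...).over(window) fails with DISTINCT_WINDOW_FUNCTION_UNSUPPORTED,
# so the metric is computed per group first and then joined back to keep row-level granularity.
metrics = df.groupBy("country").agg(F.count_distinct("code").alias("code_count_distinct"))
with_metric = df.join(metrics, on="country", how="left")

with_metric.show()  # every row now carries its group's distinct count for comparison against a limit
```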
diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index 81974926a..c6e1bebc7 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -547,17 +547,46 @@ def test_is_aggr_with_count_distinct(spark: SparkSession): assert "a_count_distinct_greater_than_limit" in actual.columns -def test_is_aggr_with_count_distinct_and_group_by_fails(spark: SparkSession): - """Test that count_distinct with group_by raises clear error about Spark limitation.""" +def test_is_aggr_with_count_distinct_and_group_by(spark: SparkSession): + """Test that count_distinct with group_by works using two-stage aggregation.""" test_df = spark.createDataFrame( - [["group1", "val1"], ["group1", "val2"]], + [ + ["group1", "val1"], + ["group1", "val1"], # Same value + ["group1", "val2"], # Different value - 2 distinct + ["group2", "val3"], + ["group2", "val3"], # Same value - only 1 distinct + ], "a: string, b: string", ) - # count_distinct with group_by should raise InvalidParameterError - with pytest.raises(InvalidParameterError, match="count_distinct cannot be used with group_by.*Spark limitation"): - _, apply_fn = is_aggr_not_greater_than("b", limit=1, aggr_type="count_distinct", group_by=["a"]) - apply_fn(test_df) + checks = [ + # group1 has 2 distinct values, should exceed limit of 1 + is_aggr_not_greater_than("b", limit=1, aggr_type="count_distinct", group_by=["a"]), + ] + + actual = _apply_checks(test_df, checks) + + # Check that the check was applied + assert "b_count_distinct_group_by_a_greater_than_limit" in actual.columns + + # Verify group1 has violations (2 distinct values > limit 1) + group1_violations = ( + actual.filter(F.col("a") == "group1") + .select("b_count_distinct_group_by_a_greater_than_limit") + .distinct() + .collect() + ) + assert len([r for r in group1_violations if r[0] is not None]) > 0 + + # Verify group2 passes (1 distinct value <= limit 1) + group2_violations = ( + actual.filter(F.col("a") == "group2") + .select("b_count_distinct_group_by_a_greater_than_limit") + .distinct() + .collect() + ) + assert len([r for r in group2_violations if r[0] is None]) > 0 def test_is_aggr_with_approx_count_distinct(spark: SparkSession): From 1611dfb41632485be4662dd83931c3fbe3fe7f9e Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Thu, 27 Nov 2025 17:25:25 +0000 Subject: [PATCH 09/24] docs + fmt: improved the doc language and code comments. 
--- demos/dqx_demo_library.py | 10 ++++--- docs/dqx/docs/reference/quality_checks.mdx | 12 +++++---- src/databricks/labs/dqx/check_funcs.py | 31 +++++++++++----------- tests/integration/test_dataset_checks.py | 4 +-- 4 files changed, 30 insertions(+), 27 deletions(-) diff --git a/demos/dqx_demo_library.py b/demos/dqx_demo_library.py index f5988eebf..5b439c5a5 100644 --- a/demos/dqx_demo_library.py +++ b/demos/dqx_demo_library.py @@ -696,8 +696,8 @@ def not_ends_with(column: str, suffix: str) -> Column: # MAGIC DQX now supports 20 curated aggregate functions for advanced data quality monitoring: # MAGIC - **Statistical functions**: `stddev`, `variance`, `median`, `mode`, `skewness`, `kurtosis` for anomaly detection # MAGIC - **Percentile functions**: `percentile`, `approx_percentile` for SLA monitoring -# MAGIC - **Cardinality functions**: `count_distinct` (global only), `approx_count_distinct` (works with group_by, uses HyperLogLog++) -# MAGIC - **Custom aggregates**: Support for user-defined functions with runtime validation +# MAGIC - **Cardinality functions**: `count_distinct`, `approx_count_distinct` (uses HyperLogLog++) +# MAGIC - **Any Databricks built-in aggregate**: Supported with runtime validation # COMMAND ---------- @@ -789,12 +789,14 @@ def not_ends_with(column: str, suffix: str) -> Column: # COMMAND ---------- # MAGIC %md -# MAGIC #### Example 3: Custom Aggregate Functions with Runtime Validation +# MAGIC #### Example 3: Non-Curated Aggregate Functions with Runtime Validation # MAGIC -# MAGIC DQX supports custom aggregate functions (including UDAFs) with: +# MAGIC DQX supports any Databricks built-in aggregate function beyond the curated list: # MAGIC - **Warning**: Non-curated functions trigger a warning # MAGIC - **Runtime validation**: Ensures the function returns numeric values compatible with comparisons # MAGIC - **Graceful errors**: Invalid aggregates (e.g., `collect_list` returning arrays) fail with clear messages +# MAGIC +# MAGIC **Note**: User-Defined Aggregate Functions (UDAFs) are not currently supported. # COMMAND ---------- diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx index 190f0451a..1df44e05a 100644 --- a/docs/dqx/docs/reference/quality_checks.mdx +++ b/docs/dqx/docs/reference/quality_checks.mdx @@ -1612,7 +1612,7 @@ Use `is_aggr_*` functions to validate metrics across your entire dataset or with - Enforce SLAs and performance thresholds - Detect data quality issues across groups -DQX supports 20 **curated** aggregate functions (recommended for data quality) plus any [Databricks aggregate function](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) as custom aggregates. +DQX supports 20 **curated** aggregate functions (recommended for data quality) plus any other [Databricks built-in aggregate function](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha). ### Curated Aggregate Functions @@ -1644,13 +1644,15 @@ Validated and optimized for data quality use cases: - `percentile` - Exact percentile (requires `aggr_params`) - `approx_percentile` - Approximate percentile (faster, requires `aggr_params`) -### Custom Aggregates +### Non-Curated Aggregates -Any [Databricks aggregate function](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) can be used as a custom aggregate. 
DQX will: -- Issue a warning (use curated functions when possible) +Any [Databricks built-in aggregate function](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) can be used beyond the curated list. DQX will: +- Issue a warning (recommending curated functions when possible) - Validate the function returns numeric values -**Note:** Functions returning arrays, structs, or maps (e.g., `collect_list`, `collect_set`) will fail with a clear error message explaining they cannot be compared to numeric limits. +**Limitations:** +- Functions returning arrays, structs, or maps (e.g., `collect_list`, `collect_set`) will fail with a clear error message explaining they cannot be compared to numeric limits +- **User-Defined Aggregate Functions (UDAFs) are not currently supported** - only Databricks built-in functions in `pyspark.sql.functions` ### Extended Examples diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index 182756f53..9cd649405 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -1306,13 +1306,13 @@ def is_aggr_not_greater_than( This function verifies that an aggregation on a column or group of columns does not exceed a specified limit. Supports curated aggregate functions (count, sum, avg, stddev, percentile, etc.) - and custom aggregates. Rows where the aggregation result exceeds the limit are flagged. + and any Databricks built-in aggregate. Rows where the aggregation result exceeds the limit are flagged. Args: column: Column name (str) or Column expression to aggregate. limit: Numeric value, column name, or SQL expression for the limit. aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, - count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. + count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported. group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). @@ -1349,18 +1349,17 @@ def is_aggr_not_less_than( This function verifies that an aggregation on a column or group of columns is not below a specified limit. Supports curated aggregate functions (count, sum, avg, stddev, percentile, etc.) - and custom aggregates. Rows where the aggregation result is below the limit are flagged. + and any Databricks built-in aggregate. Rows where the aggregation result is below the limit are flagged. Args: column: Column name (str) or Column expression to aggregate. limit: Numeric value, column name, or SQL expression for the limit. aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, - count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. + count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported. group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). - Returns: A tuple of: - A Spark Column representing the condition for aggregation limit violations. 
@@ -1393,13 +1392,13 @@ def is_aggr_equal( This function verifies that an aggregation on a column or group of columns is equal to a specified limit. Supports curated aggregate functions (count, sum, avg, stddev, percentile, etc.) - and custom aggregates. Rows where the aggregation result is not equal to the limit are flagged. + and any Databricks built-in aggregate. Rows where the aggregation result is not equal to the limit are flagged. Args: column: Column name (str) or Column expression to aggregate. limit: Numeric value, column name, or SQL expression for the limit. aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, - count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. + count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported. group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). @@ -1436,13 +1435,13 @@ def is_aggr_not_equal( This function verifies that an aggregation on a column or group of columns is not equal to a specified limit. Supports curated aggregate functions (count, sum, avg, stddev, percentile, etc.) - and custom aggregates. Rows where the aggregation result is equal to the limit are flagged. + and any Databricks built-in aggregate. Rows where the aggregation result is equal to the limit are flagged. Args: column: Column name (str) or Column expression to aggregate. limit: Numeric value, column name, or SQL expression for the limit. aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max, - count_distinct, stddev, percentile, and more. Custom aggregates are supported with validation. + count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported. group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). @@ -2370,8 +2369,8 @@ def _is_aggr_compare( column: Column name (str) or Column expression to aggregate. limit: Numeric value, column name, or SQL expression for the limit. aggr_type: Aggregation type. Curated functions include 'count', 'sum', 'avg', 'min', 'max', - 'count_distinct', 'stddev', 'percentile', and more. Custom aggregate functions are - supported but will trigger a warning and runtime validation. + 'count_distinct', 'stddev', 'percentile', and more. Any Databricks built-in aggregate + function is supported (will trigger a warning for non-curated functions). aggr_params: Optional dictionary of parameters for aggregate functions that require them (e.g., percentile functions need {"percentile": 0.95}). group_by: Optional list of columns or Column expressions to group by. @@ -2386,7 +2385,7 @@ def _is_aggr_compare( - A closure that applies the aggregation check logic. Raises: - InvalidParameterError: If a custom aggregate returns non-numeric or multiple rows per group. + InvalidParameterError: If an aggregate returns non-numeric types or is not found. MissingParameterError: If required parameters for specific aggregates are not provided. 
""" # Warn if using non-curated aggregate function @@ -2395,7 +2394,7 @@ def _is_aggr_compare( warnings.warn( f"Using non-curated aggregate function '{aggr_type}'. " f"Curated functions: {', '.join(sorted(CURATED_AGGR_FUNCTIONS))}. " - f"Custom aggregates must return a single numeric value per group.", + f"Non-curated aggregates must return a single numeric value per group.", UserWarning, stacklevel=2, ) @@ -2453,7 +2452,7 @@ def apply(df: DataFrame) -> DataFrame: group_cols = [F.col(col) if isinstance(col, str) else col for col in group_by] agg_df = df.groupBy(*group_cols).agg(aggr_expr.alias(metric_col)) - # Validate custom aggregates before join + # Validate non-curated aggregates before join if not is_curated: _validate_aggregate_return_type(agg_df, aggr_type, metric_col) @@ -2466,7 +2465,7 @@ def apply(df: DataFrame) -> DataFrame: window_spec = Window.partitionBy(*[F.col(col) if isinstance(col, str) else col for col in group_by]) df = df.withColumn(metric_col, aggr_expr.over(window_spec)) - # Validate custom aggregates (type check only - window functions return same row count) + # Validate non-curated aggregates (type check only - window functions return same row count) if not is_curated: _validate_aggregate_return_type(df, aggr_type, metric_col) else: @@ -2477,7 +2476,7 @@ def apply(df: DataFrame) -> DataFrame: # so no explicit limit is required (informational only). agg_df = df.select(aggr_expr.alias(metric_col)).limit(1) - # Validate custom aggregates (type check only - we already limited to 1 row) + # Validate non-curated aggregates (type check only - we already limited to 1 row) if not is_curated: _validate_aggregate_return_type(agg_df, aggr_type, metric_col) diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index c6e1bebc7..f01d76713 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -759,8 +759,8 @@ def test_is_aggr_with_collect_list_fails_with_group_by(spark: SparkSession): apply_fn(test_df) -def test_is_aggr_custom_aggregate_with_warning(spark: SparkSession): - """Test that custom (non-curated) aggregates work but produce warning.""" +def test_is_aggr_non_curated_aggregate_with_warning(spark: SparkSession): + """Test that non-curated (built-in but not in curated list) aggregates work but produce warning.""" test_df = spark.createDataFrame( [("A", 10), ("B", 20), ("C", 30)], "a: string, b: int", From d38f67179b178ce74f84c1878b5fa50e12fe5f9c Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Thu, 27 Nov 2025 19:36:59 +0000 Subject: [PATCH 10/24] test cov + docs: removed dead code and added a new test. Improved docs for aggr_params --- demos/dqx_demo_library.py | 6 +++- docs/dqx/docs/reference/quality_checks.mdx | 31 +++++++++++++++++++ src/databricks/labs/dqx/check_funcs.py | 20 +++++++----- tests/integration/test_dataset_checks.py | 36 ++++++++++++++++++++++ 4 files changed, 84 insertions(+), 9 deletions(-) diff --git a/demos/dqx_demo_library.py b/demos/dqx_demo_library.py index 5b439c5a5..19d3bbbe3 100644 --- a/demos/dqx_demo_library.py +++ b/demos/dqx_demo_library.py @@ -842,6 +842,10 @@ def not_ends_with(column: str, suffix: str) -> Column: # MAGIC #### Example 4: Percentile Functions for SLA Monitoring # MAGIC # MAGIC Monitor P95 latency to ensure 95% of API requests meet SLA requirements. +# MAGIC +# MAGIC **Using `aggr_params`:** Pass aggregate function parameters as a dictionary. 
+# MAGIC - Single parameter: `aggr_params={"percentile": 0.95}` +# MAGIC - Multiple parameters: `aggr_params={"percentile": 0.99, "accuracy": 10000}` # COMMAND ---------- @@ -858,7 +862,7 @@ def not_ends_with(column: str, suffix: str) -> Column: column="latency_ms", check_func_kwargs={ "aggr_type": "percentile", - "aggr_params": {"percentile": 0.95}, # P95 + "aggr_params": {"percentile": 0.95}, # aggr_params as dict "group_by": ["date"], "limit": 950.0 }, diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx index 1df44e05a..ac0f4786f 100644 --- a/docs/dqx/docs/reference/quality_checks.mdx +++ b/docs/dqx/docs/reference/quality_checks.mdx @@ -1644,6 +1644,37 @@ Validated and optimized for data quality use cases: - `percentile` - Exact percentile (requires `aggr_params`) - `approx_percentile` - Approximate percentile (faster, requires `aggr_params`) +### Using `aggr_params` for Function Parameters + +Some aggregate functions require additional parameters beyond the column. Pass these as a dictionary: + +**Python:** +```python +check_func_kwargs={ + "aggr_type": "percentile", + "aggr_params": {"percentile": 0.95}, # Single parameter + "limit": 1000 +} + +# Multiple parameters: +check_func_kwargs={ + "aggr_type": "approx_percentile", + "aggr_params": { + "percentile": 0.99, + "accuracy": 10000 # Multiple parameters as dict + } +} +``` + +**YAML:** +```yaml +aggr_params: + percentile: 0.95 + accuracy: 10000 # Multiple parameters as nested YAML +``` + +The `aggr_params` dictionary is unpacked as keyword arguments to the Spark aggregate function. See [Databricks function documentation](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) for available parameters for each function. + ### Non-Curated Aggregates Any [Databricks built-in aggregate function](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) can be used beyond the curated list. DQX will: diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index 9cd649405..2d3661341 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -1315,7 +1315,9 @@ def is_aggr_not_greater_than( count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported. group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. - aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). + aggr_params: Optional dict of parameters for aggregates requiring them. The dict is unpacked as keyword + arguments to the Spark function. Single parameter: `{"percentile": 0.95}`. Multiple parameters: + `{"percentile": 0.99, "accuracy": 10000}`. Returns: A tuple of: @@ -1358,7 +1360,9 @@ def is_aggr_not_less_than( count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported. group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. - aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). + aggr_params: Optional dict of parameters for aggregates requiring them. The dict is unpacked as keyword + arguments to the Spark function. Single parameter: `{"percentile": 0.95}`. 
Multiple parameters: + `{"percentile": 0.99, "accuracy": 10000}`. Returns: A tuple of: @@ -1401,7 +1405,9 @@ def is_aggr_equal( count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported. group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. - aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). + aggr_params: Optional dict of parameters for aggregates requiring them. The dict is unpacked as keyword + arguments to the Spark function. Single parameter: `{"percentile": 0.95}`. Multiple parameters: + `{"percentile": 0.99, "accuracy": 10000}`. Returns: A tuple of: @@ -1444,7 +1450,9 @@ def is_aggr_not_equal( count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported. group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. - aggr_params: Optional dict of parameters for aggregates requiring them (e.g., `{"percentile": 0.95}`). + aggr_params: Optional dict of parameters for aggregates requiring them. The dict is unpacked as keyword + arguments to the Spark function. Single parameter: `{"percentile": 0.95}`. Multiple parameters: + `{"percentile": 0.99, "accuracy": 10000}`. Returns: A tuple of: @@ -2452,10 +2460,6 @@ def apply(df: DataFrame) -> DataFrame: group_cols = [F.col(col) if isinstance(col, str) else col for col in group_by] agg_df = df.groupBy(*group_cols).agg(aggr_expr.alias(metric_col)) - # Validate non-curated aggregates before join - if not is_curated: - _validate_aggregate_return_type(agg_df, aggr_type, metric_col) - # Join aggregated metrics back to original DataFrame to maintain row-level granularity # Use column names for join (extract names from Column objects if present) join_cols = [col if isinstance(col, str) else get_column_name_or_alias(col) for col in group_by] diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index f01d76713..fc4f3865f 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -622,6 +622,42 @@ def test_is_aggr_with_approx_count_distinct(spark: SparkSession): assert len([r for r in group1_violations if r[0] is not None]) > 0 +def test_is_aggr_with_aggr_params_generic(spark: SparkSession): + """Test aggr_params passed through to generic aggregate function (not percentile).""" + test_df = spark.createDataFrame( + [ + ["group1", "val1"], + ["group1", "val1"], + ["group1", "val2"], + ["group1", "val3"], + ["group2", "valA"], + ["group2", "valA"], + ], + "a: string, b: string", + ) + + # Test approx_count_distinct with rsd (relative standard deviation) parameter + # This tests the generic aggr_params pass-through (line 2313 in check_funcs.py) + checks = [ + is_aggr_not_greater_than( + "b", + limit=5, + aggr_type="approx_count_distinct", + aggr_params={"rsd": 0.01}, # More accurate approximation (1% relative error) + group_by=["a"], + ), + ] + + actual = _apply_checks(test_df, checks) + + # Check that the check was applied with custom rsd parameter + assert "b_approx_count_distinct_group_by_a_greater_than_limit" in actual.columns + + # Verify results are computed (group1 has ~3 distinct, group2 has ~1 distinct) + result_count = actual.select("b_approx_count_distinct_group_by_a_greater_than_limit").count() 
+ assert result_count == 6 # All rows should have the metric + + def test_is_aggr_with_statistical_functions(spark: SparkSession): """Test statistical aggregate functions: stddev, variance, median, mode.""" test_df = spark.createDataFrame( From b48ff6e0ab562c6d53d6b7123eb3481aebca74e0 Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Thu, 27 Nov 2025 19:41:35 +0000 Subject: [PATCH 11/24] fix: docs build docstrings and dict syntax issue --- src/databricks/labs/dqx/check_funcs.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index 2d3661341..b3ad8896a 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -1315,9 +1315,9 @@ def is_aggr_not_greater_than( count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported. group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. - aggr_params: Optional dict of parameters for aggregates requiring them. The dict is unpacked as keyword - arguments to the Spark function. Single parameter: `{"percentile": 0.95}`. Multiple parameters: - `{"percentile": 0.99, "accuracy": 10000}`. + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., percentile value for + percentile functions, accuracy for approximate aggregates). Parameters are passed as keyword + arguments to the Spark function. Returns: A tuple of: @@ -1360,9 +1360,9 @@ def is_aggr_not_less_than( count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported. group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. - aggr_params: Optional dict of parameters for aggregates requiring them. The dict is unpacked as keyword - arguments to the Spark function. Single parameter: `{"percentile": 0.95}`. Multiple parameters: - `{"percentile": 0.99, "accuracy": 10000}`. + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., percentile value for + percentile functions, accuracy for approximate aggregates). Parameters are passed as keyword + arguments to the Spark function. Returns: A tuple of: @@ -1405,9 +1405,9 @@ def is_aggr_equal( count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported. group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter. - aggr_params: Optional dict of parameters for aggregates requiring them. The dict is unpacked as keyword - arguments to the Spark function. Single parameter: `{"percentile": 0.95}`. Multiple parameters: - `{"percentile": 0.99, "accuracy": 10000}`. + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., percentile value for + percentile functions, accuracy for approximate aggregates). Parameters are passed as keyword + arguments to the Spark function. Returns: A tuple of: @@ -1450,9 +1450,9 @@ def is_aggr_not_equal( count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported. group_by: Optional list of column names or Column expressions to group by. row_filter: Optional SQL expression to filter rows before aggregation. 
Auto-injected from the check filter. - aggr_params: Optional dict of parameters for aggregates requiring them. The dict is unpacked as keyword - arguments to the Spark function. Single parameter: `{"percentile": 0.95}`. Multiple parameters: - `{"percentile": 0.99, "accuracy": 10000}`. + aggr_params: Optional dict of parameters for aggregates requiring them (e.g., percentile value for + percentile functions, accuracy for approximate aggregates). Parameters are passed as keyword + arguments to the Spark function. Returns: A tuple of: From 1f8a8dfa6cf7b3992b455c80fbf0f627472e7850 Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Thu, 27 Nov 2025 20:03:05 +0000 Subject: [PATCH 12/24] docs: Spark/DBR version compatibility. --- docs/dqx/docs/reference/quality_checks.mdx | 7 +++++++ src/databricks/labs/dqx/check_funcs.py | 4 +++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx index ac0f4786f..509e21fd0 100644 --- a/docs/dqx/docs/reference/quality_checks.mdx +++ b/docs/dqx/docs/reference/quality_checks.mdx @@ -1685,6 +1685,13 @@ Any [Databricks built-in aggregate function](https://docs.databricks.com/aws/en/ - Functions returning arrays, structs, or maps (e.g., `collect_list`, `collect_set`) will fail with a clear error message explaining they cannot be compared to numeric limits - **User-Defined Aggregate Functions (UDAFs) are not currently supported** - only Databricks built-in functions in `pyspark.sql.functions` +**Runtime Compatibility:** +DQX requires **Databricks Runtime 15.4+ (Spark 3.5+)**. Some newer functions may not be available on older runtimes: +- `mode`, `median` - Require Spark 3.4+ (DBR 13.3+) +- Most statistical functions - Available in all supported versions + +Unsupported functions will fail with an error message and version guidance. + ### Extended Examples ```yaml diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index b3ad8896a..7fa767a43 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -2323,7 +2323,9 @@ def _build_aggregate_expression( except AttributeError as exc: raise InvalidParameterError( f"Aggregate function '{aggr_type}' not found in pyspark.sql.functions. " - f"Verify the function name is correct." + f"Verify the function name is correct, or check if your Databricks Runtime version supports this function. " + f"Some newer aggregate functions (e.g., mode, median) require DBR 15.4+ (Spark 3.5+). " + f"See: https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha" ) from exc From 3640b8d695e15a3736b81daa77b3ccd04c4eed13 Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Fri, 28 Nov 2025 11:25:48 +0000 Subject: [PATCH 13/24] review: addressing PR code review comments. 
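The runtime-compatibility note added in this patch hinges on the aggregate name resolving in `pyspark.sql.functions`. As a quick pre-flight check on a given cluster, a user can probe that lookup directly; the helper below is an illustrative sketch under that assumption, not part of DQX itself.

```python
# Illustrative sketch (not DQX code): verify that an aggregate name resolves in
# pyspark.sql.functions on the current runtime before using it as aggr_type.
import pyspark.sql.functions as F


def aggregate_is_available(aggr_type: str) -> bool:
    """Return True if aggr_type resolves to a callable in pyspark.sql.functions."""
    return callable(getattr(F, aggr_type, None))


for candidate in ("median", "mode", "approx_percentile", "not_a_real_aggregate"):
    state = "available" if aggregate_is_available(candidate) else "not available on this runtime"
    print(f"{candidate}: {state}")
```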
--- docs/dqx/docs/reference/quality_checks.mdx | 55 ++++-- src/databricks/labs/dqx/check_funcs.py | 10 +- tests/integration/test_dataset_checks.py | 201 +++++++++++++-------- tests/unit/test_row_checks.py | 9 +- 4 files changed, 181 insertions(+), 94 deletions(-) diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx index 509e21fd0..9f8d3e3a1 100644 --- a/docs/dqx/docs/reference/quality_checks.mdx +++ b/docs/dqx/docs/reference/quality_checks.mdx @@ -1387,10 +1387,10 @@ You can also define your own custom dataset-level checks (see [Creating custom c | Check | Description | Arguments | | ---------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | `is_unique` | Checks whether the values in the input column are unique and reports an issue for each row that contains a duplicate value. It supports uniqueness check for multiple columns (composite key). Null values are not considered duplicates by default, following the ANSI SQL standard. | `columns`: columns to check (can be a list of column names or column expressions); `nulls_distinct`: controls how null values are treated, default is True, thus nulls are not duplicates, eg. (NULL, NULL) not equals (NULL, NULL) and (1, NULL) not equals (1, NULL) | -| `is_aggr_not_greater_than` | Checks whether the aggregated values over group of rows or all rows are not greater than the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression; `aggr_type`: aggregation function to use, such as "count" (default), "sum", "avg", "min", and "max"; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default) | -| `is_aggr_not_less_than` | Checks whether the aggregated values over group of rows or all rows are not less than the provided limit. 
| `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression; `aggr_type`: aggregation function to use, such as "count" (default), "sum", "avg", "min", and "max"; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default) | -| `is_aggr_equal` | Checks whether the aggregated values over group of rows or all rows are equal to the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression; `aggr_type`: aggregation function to use, such as "count" (default), "sum", "avg", "min", and "max"; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default) | -| `is_aggr_not_equal` | Checks whether the aggregated values over group of rows or all rows are not equal to the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression; `aggr_type`: aggregation function to use, such as "count" (default), "sum", "avg", "min", and "max"; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default) | +| `is_aggr_not_greater_than` | Checks whether the aggregated values over group of rows or all rows are not greater than the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression; `aggr_type`: aggregation function (default: "count"), supports 20 curated functions (count, sum, avg, stddev, percentile, etc.) plus any Databricks built-in aggregate; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default); `row_filter`: (optional) SQL expression to filter rows before aggregation; `aggr_params`: (optional) dict of parameters for aggregates requiring them | +| `is_aggr_not_less_than` | Checks whether the aggregated values over group of rows or all rows are not less than the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression; `aggr_type`: aggregation function (default: "count"), supports 20 curated functions (count, sum, avg, stddev, percentile, etc.) plus any Databricks built-in aggregate; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default); `row_filter`: (optional) SQL expression to filter rows before aggregation; `aggr_params`: (optional) dict of parameters for aggregates requiring them | +| `is_aggr_equal` | Checks whether the aggregated values over group of rows or all rows are equal to the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression; `aggr_type`: aggregation function (default: "count"), supports 20 curated functions (count, sum, avg, stddev, percentile, etc.) 
plus any Databricks built-in aggregate; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default); `row_filter`: (optional) SQL expression to filter rows before aggregation; `aggr_params`: (optional) dict of parameters for aggregates requiring them | +| `is_aggr_not_equal` | Checks whether the aggregated values over group of rows or all rows are not equal to the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression; `aggr_type`: aggregation function (default: "count"), supports 20 curated functions (count, sum, avg, stddev, percentile, etc.) plus any Databricks built-in aggregate; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default); `row_filter`: (optional) SQL expression to filter rows before aggregation; `aggr_params`: (optional) dict of parameters for aggregates requiring them | | `foreign_key` (aka is_in_list) | Checks whether input column or columns can be found in the reference DataFrame or Table (foreign key check). It supports foreign key check on single and composite keys. This check can be used to validate whether values in the input column(s) exist in a predefined list of allowed values (stored in the reference DataFrame or Table). It serves as a scalable alternative to `is_in_list` row-level checks, when working with large lists. | `columns`: columns to check (can be a list of string column names or column expressions); `ref_columns`: columns to check for existence in the reference DataFrame or Table (can be a list string column name or a column expression); `ref_df_name`: (optional) name of the reference DataFrame (dictionary of DataFrames can be passed when applying checks); `ref_table`: (optional) fully qualified reference table name; either `ref_df_name` or `ref_table` must be provided but never both; the number of passed `columns` and `ref_columns` must match and keys are checks in the given order; negate: if True the condition is negated (i.e. the check fails when the foreign key values exist in the reference DataFrame/Table), if False the check fails when the foreign key values do not exist in the reference | | `sql_query` | Checks whether the condition column produced by a SQL query is satisfied. The check expects the query to return a boolean condition column indicating whether a record meets the requirement (True = fail, False = pass), and one or more merge columns so that results can be joined back to the input DataFrame to preserve all original records. Important considerations: if merge columns aren't unique, multiple query rows can attach to a single input row, potentially causing false positives. Performance tip: since the check must join back to the input DataFrame to retain original records, writing a custom dataset-level rule is usually more performant than `sql_query` check. | `query`: query string, must return all merge columns and condition column; `input_placeholder`: name to be used in the sql query as `{{ input_placeholder }}` to refer to the input DataFrame, optional reference DataFrames are referred by the name provided in the dictionary of reference DataFrames (e.g. 
`{{ ref_df_key }}`, dictionary of DataFrames can be passed when applying checks); `merge_columns`: list of columns used for merging with the input DataFrame which must exist in the input DataFrame and be present in output of the sql query; `condition_column`: name of the column indicating a violation (False = pass, True = fail); `msg`: (optional) message to output; `name`: (optional) name of the resulting check (it can be overwritten by `name` specified at the check level); `negate`: if the condition should be negated | | `compare_datasets` | Compares two DataFrames at both row and column levels, providing detailed information about differences, including new or missing rows and column-level changes. Only columns present in both the source and reference DataFrames are compared. Use with caution if `check_missing_records` is enabled, as this may increase the number of rows in the output beyond the original input DataFrame. The comparison does not support Map types (any column comparison on map type is skipped automatically). Comparing datasets is valuable for validating data during migrations, detecting drift, performing regression testing, or verifying synchronization between source and target systems. | `columns`: columns to use for row matching with the reference DataFrame (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'), if not having primary keys or wanting to match against all columns you can pass 'df.columns'; `ref_columns`: list of columns in the reference DataFrame or Table to row match against the source DataFrame (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'), if not having primary keys or wanting to match against all columns you can pass 'ref_df.columns'; note that `columns` are matched with `ref_columns` by position, so the order of the provided columns in both lists must be exactly aligned; `exclude_columns`: (optional) list of columns to exclude from the value comparison but not from row matching (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'); the `exclude_columns` field does not alter the list of columns used to determine row matches (columns), it only controls which columns are skipped during the value comparison; `ref_df_name`: (optional) name of the reference DataFrame (dictionary of DataFrames can be passed when applying checks); `ref_table`: (optional) fully qualified reference table name; either `ref_df_name` or `ref_table` must be provided but never both; the number of passed `columns` and `ref_columns` must match and keys are checks in the given order; `check_missing_records`: perform a FULL OUTER JOIN to identify records that are missing from source or reference DataFrames, default is False; use with caution as this may increase the number of rows in the output, as unmatched rows from both sides are included; `null_safe_row_matching`: (optional) treat NULLs as equal when matching rows using `columns` and `ref_columns` (default: True); `null_safe_column_value_matching`: (optional) treat NULLs as equal when comparing column values (default: True) | @@ -1601,20 +1601,28 @@ Complex data types are supported as well. 
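The argument lists above are shared by all four `is_aggr_*` checks. As a minimal sketch of the programmatic form (the sample data, column names, and limit of 1 are illustrative assumptions; the import path follows the `databricks/labs/dqx` layout shown in this patch series), `is_aggr_equal` can enforce an exact per-group row count:

```python
# Minimal sketch: require exactly one configuration row per tenant.
# The DataFrame contents and the limit are assumptions for illustration.
from pyspark.sql import SparkSession
from databricks.labs.dqx import check_funcs

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("tenant_a", "cfg-1"), ("tenant_b", "cfg-2"), ("tenant_b", "cfg-3")],
    "tenant_id: string, config_id: string",
)

condition, apply_fn = check_funcs.is_aggr_equal(
    "config_id",
    limit=1,
    aggr_type="count",
    group_by=["tenant_id"],
)

# apply_fn adds the per-group metric and condition columns; tenant_b rows
# would be flagged because their group count is 2, not 1.
apply_fn(df).show(truncate=False)
```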
- col3 limit: 200 ``` + + +## Aggregate Function Types {#aggregate-function-types} -## Aggregate Function Types +DQX provides four aggregate check functions for validating metrics across your entire dataset or within groups: -Use `is_aggr_*` functions to validate metrics across your entire dataset or within groups (e.g., "each country has ≤ 1 country code", "P95 latency < 1s per region"). +- **`is_aggr_not_greater_than`** - Validates that aggregated values do not exceed a limit +- **`is_aggr_not_less_than`** - Validates that aggregated values meet a minimum threshold +- **`is_aggr_equal`** - Validates that aggregated values equal an expected value +- **`is_aggr_not_equal`** - Validates that aggregated values differ from a specific value -**When to use:** +Use these functions for scenarios like "each country has ≤ 1 country code" or "P95 latency < 1s per region". + +**When to use aggregate checks:** - Validate cardinality and uniqueness constraints - Monitor statistical properties (variance, outliers) - Enforce SLAs and performance thresholds -- Detect data quality issues across groups +- Detect data quality issues within groups DQX supports 20 **curated** aggregate functions (recommended for data quality) plus any other [Databricks built-in aggregate function](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha). -### Curated Aggregate Functions +### Curated Aggregate Functions {#curated-aggregate-functions} Validated and optimized for data quality use cases: @@ -1644,7 +1652,16 @@ Validated and optimized for data quality use cases: - `percentile` - Exact percentile (requires `aggr_params`) - `approx_percentile` - Approximate percentile (faster, requires `aggr_params`) -### Using `aggr_params` for Function Parameters + +`count_distinct` with `group_by` uses two-stage aggregation (groupBy + join) due to Spark's DISTINCT window limitation. **For large-scale datasets with many groups**, prefer `approx_count_distinct` which: +- Uses efficient window functions (no two-stage aggregation) +- Provides ±5% accuracy via HyperLogLog++ +- Scales better for high-cardinality groups + +Use `count_distinct` when exact counts are critical; use `approx_count_distinct` for large datasets or when approximation is acceptable. + + +### Using `aggr_params` for Function Parameters {#using-aggr_params-for-function-parameters} Some aggregate functions require additional parameters beyond the column. Pass these as a dictionary: @@ -1675,7 +1692,7 @@ aggr_params: The `aggr_params` dictionary is unpacked as keyword arguments to the Spark aggregate function. See [Databricks function documentation](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) for available parameters for each function. -### Non-Curated Aggregates +### Non-Curated Aggregates {#non-curated-aggregates} Any [Databricks built-in aggregate function](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) can be used beyond the curated list. DQX will: - Issue a warning (recommending curated functions when possible) @@ -1692,7 +1709,7 @@ DQX requires **Databricks Runtime 15.4+ (Spark 3.5+)**. Some newer functions may Unsupported functions will fail with an error message and version guidance. -### Extended Examples +### Extended Examples {#aggregate-extended-examples} ```yaml # Uniqueness: Each country has exactly one country code @@ -1790,9 +1807,21 @@ Unsupported functions will fail with an error message and version guidance. 
group_by: - service_name limit: 100 -``` +# count_if: Monitor error rate per service +- criticality: error + check: + function: is_aggr_not_greater_than + arguments: + column: status_code >= 500 + aggr_type: count_if + group_by: + - service_name + limit: 10 +``` +
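The `count_if` example above can also be expressed through the Python API. The snippet below is a hedged sketch: the boolean condition is passed as a column expression via `F.expr`, and the sample schema and threshold simply mirror the YAML.

```python
# Sketch of the count_if error-rate guard in Python; schema and limit mirror
# the YAML example above and are otherwise illustrative assumptions.
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from databricks.labs.dqx import check_funcs

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("checkout", 200), ("checkout", 503), ("search", 200)],
    "service_name: string, status_code: int",
)

# Fail when a service logs more than 10 responses with a 5xx status code.
condition, apply_fn = check_funcs.is_aggr_not_greater_than(
    F.expr("status_code >= 500"),  # column expression, as in the YAML form
    limit=10,
    aggr_type="count_if",
    group_by=["service_name"],
)

apply_fn(df).show(truncate=False)
```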
+**Checks defined in YAML (continued)** ```yaml # foreign_key check using reference DataFrame - criticality: error diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index 7fa767a43..12240a459 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -2303,7 +2303,7 @@ def _build_aggregate_expression( return F.countDistinct(filtered_expr) if aggr_type in {"percentile", "approx_percentile"}: - if not aggr_params or "percentile" not in aggr_params: + if not aggr_params or not aggr_params.get("percentile"): raise MissingParameterError( f"'{aggr_type}' requires aggr_params with 'percentile' key (e.g., {{'percentile': 0.95}})" ) @@ -2311,7 +2311,7 @@ def _build_aggregate_expression( if aggr_type == "percentile": return F.percentile(filtered_expr, pct) - if "accuracy" in aggr_params: + if aggr_params.get("accuracy"): return F.approx_percentile(filtered_expr, pct, aggr_params["accuracy"]) return F.approx_percentile(filtered_expr, pct) @@ -2462,6 +2462,12 @@ def apply(df: DataFrame) -> DataFrame: group_cols = [F.col(col) if isinstance(col, str) else col for col in group_by] agg_df = df.groupBy(*group_cols).agg(aggr_expr.alias(metric_col)) + # Validate non-curated aggregates (type check on aggregated result before join) + # This ensures consistent validation regardless of whether window functions or two-stage + # aggregation is used (e.g., if a non-curated function is added to WINDOW_INCOMPATIBLE_AGGREGATES) + if not is_curated: + _validate_aggregate_return_type(agg_df, aggr_type, metric_col) + # Join aggregated metrics back to original DataFrame to maintain row-level granularity # Use column names for join (extract names from Column objects if present) join_cols = [col if isinstance(col, str) else get_column_name_or_alias(col) for col in group_by] diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index fc4f3865f..4c99c3836 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -3,7 +3,6 @@ from typing import Any import json import itertools -import warnings import pytest import pyspark.sql.functions as F @@ -537,14 +536,23 @@ def test_is_aggr_with_count_distinct(spark: SparkSession): ) checks = [ - # Global count_distinct (no group_by) - should work fine + # Global count_distinct (no group_by) - 3 distinct values in 'a', limit is 5, should pass is_aggr_not_greater_than("a", limit=5, aggr_type="count_distinct"), ] actual = _apply_checks(test_df, checks) - # Check that the check was applied - assert "a_count_distinct_greater_than_limit" in actual.columns + expected = spark.createDataFrame( + [ + ["val1", "data1", None], + ["val1", "data2", None], + ["val2", "data3", None], + ["val3", "data4", None], + ], + "a: string, b: string, a_count_distinct_greater_than_limit: string", + ) + + assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True) def test_is_aggr_with_count_distinct_and_group_by(spark: SparkSession): @@ -567,26 +575,18 @@ def test_is_aggr_with_count_distinct_and_group_by(spark: SparkSession): actual = _apply_checks(test_df, checks) - # Check that the check was applied - assert "b_count_distinct_group_by_a_greater_than_limit" in actual.columns - - # Verify group1 has violations (2 distinct values > limit 1) - group1_violations = ( - actual.filter(F.col("a") == "group1") - .select("b_count_distinct_group_by_a_greater_than_limit") - .distinct() - .collect() + expected = spark.createDataFrame( 
+ [ + ["group1", "val1", "Count_distinct 2 in column 'b' per group of columns 'a' is greater than limit: 1"], + ["group1", "val1", "Count_distinct 2 in column 'b' per group of columns 'a' is greater than limit: 1"], + ["group1", "val2", "Count_distinct 2 in column 'b' per group of columns 'a' is greater than limit: 1"], + ["group2", "val3", None], + ["group2", "val3", None], + ], + "a: string, b: string, b_count_distinct_group_by_a_greater_than_limit: string", ) - assert len([r for r in group1_violations if r[0] is not None]) > 0 - # Verify group2 passes (1 distinct value <= limit 1) - group2_violations = ( - actual.filter(F.col("a") == "group2") - .select("b_count_distinct_group_by_a_greater_than_limit") - .distinct() - .collect() - ) - assert len([r for r in group2_violations if r[0] is None]) > 0 + assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True) def test_is_aggr_with_approx_count_distinct(spark: SparkSession): @@ -609,17 +609,30 @@ def test_is_aggr_with_approx_count_distinct(spark: SparkSession): actual = _apply_checks(test_df, checks) - # Check that the check was applied - assert "b_approx_count_distinct_group_by_a_greater_than_limit" in actual.columns - - # Verify group1 has violations (approx 2 distinct values > limit 1) - group1_violations = ( - actual.filter(F.col("a") == "group1") - .select("b_approx_count_distinct_group_by_a_greater_than_limit") - .distinct() - .collect() + expected = spark.createDataFrame( + [ + [ + "group1", + "val1", + "Approx_count_distinct 2 in column 'b' per group of columns 'a' is greater than limit: 1", + ], + [ + "group1", + "val1", + "Approx_count_distinct 2 in column 'b' per group of columns 'a' is greater than limit: 1", + ], + [ + "group1", + "val2", + "Approx_count_distinct 2 in column 'b' per group of columns 'a' is greater than limit: 1", + ], + ["group2", "val3", None], + ["group2", "val3", None], + ], + "a: string, b: string, b_approx_count_distinct_group_by_a_greater_than_limit: string", ) - assert len([r for r in group1_violations if r[0] is not None]) > 0 + + assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True) def test_is_aggr_with_aggr_params_generic(spark: SparkSession): @@ -650,12 +663,20 @@ def test_is_aggr_with_aggr_params_generic(spark: SparkSession): actual = _apply_checks(test_df, checks) - # Check that the check was applied with custom rsd parameter - assert "b_approx_count_distinct_group_by_a_greater_than_limit" in actual.columns + # All rows should pass (group1 has ~3 distinct, group2 has ~1 distinct, both <= 5) + expected = spark.createDataFrame( + [ + ["group1", "val1", None], + ["group1", "val1", None], + ["group1", "val2", None], + ["group1", "val3", None], + ["group2", "valA", None], + ["group2", "valA", None], + ], + "a: string, b: string, b_approx_count_distinct_group_by_a_greater_than_limit: string", + ) - # Verify results are computed (group1 has ~3 distinct, group2 has ~1 distinct) - result_count = actual.select("b_approx_count_distinct_group_by_a_greater_than_limit").count() - assert result_count == 6 # All rows should have the metric + assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True) def test_is_aggr_with_statistical_functions(spark: SparkSession): @@ -673,23 +694,50 @@ def test_is_aggr_with_statistical_functions(spark: SparkSession): ) checks = [ - # Standard deviation check (group A has higher variance than B) + # Standard deviation check (group A stddev ~10, group B stddev=0, both <= 10.0) is_aggr_not_greater_than("b", 
limit=10.0, aggr_type="stddev", group_by=["a"]), - # Variance check + # Variance check (group A variance ~100, group B variance=0, both <= 100.0) is_aggr_not_greater_than("b", limit=100.0, aggr_type="variance", group_by=["a"]), - # Median check (dataset-level, no grouping) + # Median check (dataset-level median ~10, should fail > 25.0) is_aggr_not_greater_than("b", limit=25.0, aggr_type="median"), - # Mode check (most frequent value per group) + # Mode check (group A mode=10/20/30, group B mode=5, both <= 10.0 or fail) is_aggr_not_greater_than("b", limit=10.0, aggr_type="mode", group_by=["a"]), ] actual = _apply_checks(test_df, checks) - # Check that checks were applied - assert "b_stddev_group_by_a_greater_than_limit" in actual.columns - assert "b_variance_group_by_a_greater_than_limit" in actual.columns - assert "b_median_greater_than_limit" in actual.columns - assert "b_mode_group_by_a_greater_than_limit" in actual.columns + # Group A has stddev ~8.16, variance ~66.67, mode=10 (first value), all pass + # Group B has stddev=0, variance=0, mode=5, all pass + # Median ~7.5 (pass < 25.0) + expected = spark.createDataFrame( + [ + ["A", 10.0, None, None, None, None], + [ + "A", + 20.0, + None, + None, + None, + "Mode 20.0 in column 'b' per group of columns 'a' is greater than limit: 10.0", + ], + [ + "A", + 30.0, + None, + None, + None, + "Mode 30.0 in column 'b' per group of columns 'a' is greater than limit: 10.0", + ], + ["B", 5.0, None, None, None, None], + ["B", 5.0, None, None, None, None], + ["B", 5.0, None, None, None, None], + ], + "a: string, b: double, b_stddev_group_by_a_greater_than_limit: string, " + "b_variance_group_by_a_greater_than_limit: string, b_median_greater_than_limit: string, " + "b_mode_group_by_a_greater_than_limit: string", + ) + + assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True) def test_is_aggr_with_mode_function(spark: SparkSession): @@ -717,13 +765,20 @@ def test_is_aggr_with_mode_function(spark: SparkSession): actual = _apply_checks(test_df, checks) # groupA should fail (mode=401 > limit=400), groupB should pass (mode=200 <= limit=400) - assert "b_mode_group_by_a_greater_than_limit" in actual.columns - - # Verify failures - groupA rows should have error messages (mode 401 > 400) - group_a_errors = ( - actual.filter(F.col("a") == "groupA").select("b_mode_group_by_a_greater_than_limit").distinct().collect() + expected = spark.createDataFrame( + [ + ["groupA", 401, "Mode 401 in column 'b' per group of columns 'a' is greater than limit: 400"], + ["groupA", 401, "Mode 401 in column 'b' per group of columns 'a' is greater than limit: 400"], + ["groupA", 401, "Mode 401 in column 'b' per group of columns 'a' is greater than limit: 400"], + ["groupA", 500, "Mode 401 in column 'b' per group of columns 'a' is greater than limit: 400"], + ["groupB", 200, None], + ["groupB", 200, None], + ["groupB", 404, None], + ], + "a: string, b: int, b_mode_group_by_a_greater_than_limit: string", ) - assert len([r for r in group_a_errors if r[0] is not None]) > 0 + + assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True) def test_is_aggr_with_percentile_functions(spark: SparkSession): @@ -734,11 +789,11 @@ def test_is_aggr_with_percentile_functions(spark: SparkSession): ) checks = [ - # P95 should be around 95 + # P95 should be around 95.95 (dataset-level), all pass (< 100.0) is_aggr_not_greater_than("b", limit=100.0, aggr_type="percentile", aggr_params={"percentile": 0.95}), - # P99 with approx_percentile + # P99 with 
approx_percentile should be around 99 (dataset-level), all pass (< 100.0) is_aggr_not_greater_than("b", limit=100.0, aggr_type="approx_percentile", aggr_params={"percentile": 0.99}), - # P50 (median) with accuracy parameter + # P50 (median) should be around 50.5 (dataset-level), all pass (>= 40.0) is_aggr_not_less_than( "b", limit=40.0, aggr_type="approx_percentile", aggr_params={"percentile": 0.50, "accuracy": 100} ), @@ -746,10 +801,14 @@ def test_is_aggr_with_percentile_functions(spark: SparkSession): actual = _apply_checks(test_df, checks) - # Verify columns exist - assert "b_percentile_greater_than_limit" in actual.columns - assert "b_approx_percentile_greater_than_limit" in actual.columns - assert "b_approx_percentile_less_than_limit" in actual.columns + # All checks should pass (P95 ~95 < 100, P99 ~99 < 100, P50 ~50 >= 40) + expected = spark.createDataFrame( + [(f"row{i}", float(i), None, None, None) for i in range(1, 101)], + "a: string, b: double, b_percentile_greater_than_limit: string, " + "b_approx_percentile_greater_than_limit: string, b_approx_percentile_less_than_limit: string", + ) + + assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True) def test_is_aggr_percentile_missing_params(spark: SparkSession): @@ -803,18 +862,18 @@ def test_is_aggr_non_curated_aggregate_with_warning(spark: SparkSession): ) # Use a valid aggregate that's not in curated list (e.g., any_value) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - _, apply_fn = is_aggr_not_greater_than("b", limit=100, aggr_type="any_value") - result = apply_fn(test_df) - - # Should have warning about non-curated aggregate - assert len(w) > 0 - assert "non-curated" in str(w[0].message).lower() - - # Should still work - check for condition column with any_value in name - condition_cols = [c for c in result.columns if "any_value" in c and "condition" in c] - assert len(condition_cols) > 0 + # any_value returns one arbitrary non-null value (likely 10, 20, or 30), all < 100 + with pytest.warns(UserWarning, match="non-curated.*any_value"): + checks = [is_aggr_not_greater_than("b", limit=100, aggr_type="any_value")] + actual = _apply_checks(test_df, checks) + + # Should still work - any_value returns an arbitrary value, all should pass (< 100) + expected = spark.createDataFrame( + [("A", 10, None), ("B", 20, None), ("C", 30, None)], + "a: string, b: int, b_any_value_greater_than_limit: string", + ) + + assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True) def test_dataset_compare(spark: SparkSession, set_utc_timezone): diff --git a/tests/unit/test_row_checks.py b/tests/unit/test_row_checks.py index 2c1f5d085..ef8b80d3f 100644 --- a/tests/unit/test_row_checks.py +++ b/tests/unit/test_row_checks.py @@ -1,4 +1,3 @@ -import warnings import pytest from databricks.labs.dqx.check_funcs import ( is_equal_to, @@ -54,15 +53,9 @@ def test_col_is_in_list_missing_allowed_list(): def test_incorrect_aggr_type(): # With new implementation, invalid aggr_type triggers a warning (not immediate error) # The error occurs at runtime when the apply function is called - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + with pytest.warns(UserWarning, match="non-curated.*invalid"): condition, apply_fn = is_aggr_not_greater_than("a", 1, aggr_type="invalid") - # Should have warning about non-curated aggregate - assert len(w) > 0 - assert "non-curated" in str(w[0].message).lower() - assert "invalid" in str(w[0].message) - # Function 
should return successfully (error will happen at runtime when applied to DataFrame) assert condition is not None assert apply_fn is not None From e9bc7b36f39664f79e0774ff97aaec5f1af6b11c Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Fri, 28 Nov 2025 11:55:12 +0000 Subject: [PATCH 14/24] perf test: added a new set of performance tests for is_aggr count distinct and approx count distinct --- docs/dqx/docs/reference/quality_checks.mdx | 2 + tests/perf/test_apply_checks.py | 59 ++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx index 9f8d3e3a1..6109fc211 100644 --- a/docs/dqx/docs/reference/quality_checks.mdx +++ b/docs/dqx/docs/reference/quality_checks.mdx @@ -1659,6 +1659,8 @@ Validated and optimized for data quality use cases: - Scales better for high-cardinality groups Use `count_distinct` when exact counts are critical; use `approx_count_distinct` for large datasets or when approximation is acceptable. + +**Performance benchmarks**: See [benchmarks](./benchmarks) for performance comparison across different cardinality levels. ### Using `aggr_params` for Function Parameters {#using-aggr_params-for-function-parameters} diff --git a/tests/perf/test_apply_checks.py b/tests/perf/test_apply_checks.py index 398461ae1..1a4dba33f 100644 --- a/tests/perf/test_apply_checks.py +++ b/tests/perf/test_apply_checks.py @@ -1592,3 +1592,62 @@ def test_benchmark_foreach_has_valid_schema(benchmark, ws, generated_string_df): benchmark.group += f"_{n_rows}_rows_{len(columns)}_columns" actual_count = benchmark(lambda: dq_engine.apply_checks(df, checks).count()) assert actual_count == EXPECTED_ROWS + + +def test_benchmark_is_aggr_count_distinct_with_group_by(benchmark, ws, generated_df): + """Benchmark count_distinct with group_by (uses two-stage aggregation: groupBy + join).""" + dq_engine = DQEngine(workspace_client=ws, extra_params=EXTRA_PARAMS) + checks = [ + DQDatasetRule( + criticality="warn", + check_func=check_funcs.is_aggr_not_greater_than, + column="col2", + check_func_kwargs={ + "aggr_type": "count_distinct", + "group_by": ["col3"], + "limit": 1000000, + }, + ) + ] + checked = dq_engine.apply_checks(generated_df, checks) + actual_count = benchmark(lambda: checked.count()) + assert actual_count == EXPECTED_ROWS + + +def test_benchmark_is_aggr_approx_count_distinct_with_group_by(benchmark, ws, generated_df): + """Benchmark approx_count_distinct with group_by (uses window functions - should be faster).""" + dq_engine = DQEngine(workspace_client=ws, extra_params=EXTRA_PARAMS) + checks = [ + DQDatasetRule( + criticality="warn", + check_func=check_funcs.is_aggr_not_greater_than, + column="col2", + check_func_kwargs={ + "aggr_type": "approx_count_distinct", + "group_by": ["col3"], + "limit": 1000000, + }, + ) + ] + checked = dq_engine.apply_checks(generated_df, checks) + actual_count = benchmark(lambda: checked.count()) + assert actual_count == EXPECTED_ROWS + + +def test_benchmark_is_aggr_count_distinct_no_group_by(benchmark, ws, generated_df): + """Benchmark count_distinct without group_by (baseline - uses standard aggregation).""" + dq_engine = DQEngine(workspace_client=ws, extra_params=EXTRA_PARAMS) + checks = [ + DQDatasetRule( + criticality="warn", + check_func=check_funcs.is_aggr_not_greater_than, + column="col2", + check_func_kwargs={ + "aggr_type": "count_distinct", + "limit": 1000000, + }, + ) + ] + checked = dq_engine.apply_checks(generated_df, checks) + actual_count = benchmark(lambda: 
checked.count()) + assert actual_count == EXPECTED_ROWS From 9671ca240578f075495e75d631aba8ee242c2be1 Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Fri, 28 Nov 2025 12:39:19 +0000 Subject: [PATCH 15/24] flaky test fix --- tests/integration/test_dataset_checks.py | 43 +++++++----------------- 1 file changed, 12 insertions(+), 31 deletions(-) diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index 4c99c3836..12e530619 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -680,7 +680,7 @@ def test_is_aggr_with_aggr_params_generic(spark: SparkSession): def test_is_aggr_with_statistical_functions(spark: SparkSession): - """Test statistical aggregate functions: stddev, variance, median, mode.""" + """Test statistical aggregate functions: stddev, variance, median.""" test_df = spark.createDataFrame( [ ["A", 10.0], @@ -694,47 +694,28 @@ def test_is_aggr_with_statistical_functions(spark: SparkSession): ) checks = [ - # Standard deviation check (group A stddev ~10, group B stddev=0, both <= 10.0) + # Standard deviation check (group A stddev ~8.16, group B stddev=0, both <= 10.0) is_aggr_not_greater_than("b", limit=10.0, aggr_type="stddev", group_by=["a"]), - # Variance check (group A variance ~100, group B variance=0, both <= 100.0) + # Variance check (group A variance ~66.67, group B variance=0, both <= 100.0) is_aggr_not_greater_than("b", limit=100.0, aggr_type="variance", group_by=["a"]), - # Median check (dataset-level median ~10, should fail > 25.0) + # Median check (dataset-level median 7.5, passes < 25.0) is_aggr_not_greater_than("b", limit=25.0, aggr_type="median"), - # Mode check (group A mode=10/20/30, group B mode=5, both <= 10.0 or fail) - is_aggr_not_greater_than("b", limit=10.0, aggr_type="mode", group_by=["a"]), ] actual = _apply_checks(test_df, checks) - # Group A has stddev ~8.16, variance ~66.67, mode=10 (first value), all pass - # Group B has stddev=0, variance=0, mode=5, all pass - # Median ~7.5 (pass < 25.0) + # All checks should pass expected = spark.createDataFrame( [ - ["A", 10.0, None, None, None, None], - [ - "A", - 20.0, - None, - None, - None, - "Mode 20.0 in column 'b' per group of columns 'a' is greater than limit: 10.0", - ], - [ - "A", - 30.0, - None, - None, - None, - "Mode 30.0 in column 'b' per group of columns 'a' is greater than limit: 10.0", - ], - ["B", 5.0, None, None, None, None], - ["B", 5.0, None, None, None, None], - ["B", 5.0, None, None, None, None], + ["A", 10.0, None, None, None], + ["A", 20.0, None, None, None], + ["A", 30.0, None, None, None], + ["B", 5.0, None, None, None], + ["B", 5.0, None, None, None], + ["B", 5.0, None, None, None], ], "a: string, b: double, b_stddev_group_by_a_greater_than_limit: string, " - "b_variance_group_by_a_greater_than_limit: string, b_median_greater_than_limit: string, " - "b_mode_group_by_a_greater_than_limit: string", + "b_variance_group_by_a_greater_than_limit: string, b_median_greater_than_limit: string", ) assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True) From 759edcfe23d165a510bc2398750f60d3f3fb27af Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Fri, 28 Nov 2025 12:46:01 +0000 Subject: [PATCH 16/24] docs: fix broken link --- docs/dqx/docs/reference/quality_checks.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx index 6109fc211..54eae684a 100644 --- 
a/docs/dqx/docs/reference/quality_checks.mdx +++ b/docs/dqx/docs/reference/quality_checks.mdx @@ -1660,7 +1660,7 @@ Validated and optimized for data quality use cases: Use `count_distinct` when exact counts are critical; use `approx_count_distinct` for large datasets or when approximation is acceptable. -**Performance benchmarks**: See [benchmarks](./benchmarks) for performance comparison across different cardinality levels. +**Performance benchmarks**: See [benchmarks](/docs/reference/benchmarks) for performance comparison across different cardinality levels. ### Using `aggr_params` for Function Parameters {#using-aggr_params-for-function-parameters} From 69f31c9be2da252224be3f9c13d6cedf7d2aa3ff Mon Sep 17 00:00:00 2001 From: vb-dbrks Date: Fri, 28 Nov 2025 13:52:44 +0000 Subject: [PATCH 17/24] Add pytest-benchmark performance baseline --- docs/dqx/docs/reference/benchmarks.mdx | 3 + tests/perf/.benchmarks/baseline.json | 105 +++++++++++++++++++++++++ 2 files changed, 108 insertions(+) diff --git a/docs/dqx/docs/reference/benchmarks.mdx b/docs/dqx/docs/reference/benchmarks.mdx index 5bddeb18e..18c1bc423 100644 --- a/docs/dqx/docs/reference/benchmarks.mdx +++ b/docs/dqx/docs/reference/benchmarks.mdx @@ -57,6 +57,9 @@ sidebar_position: 13 | test_benchmark_has_valid_schema | 0.172078 | 0.172141 | 0.163793 | 0.181081 | 0.006715 | 0.009295 | 0.167010 | 0.176305 | 6 | 0 | 2 | 5.81 | | test_benchmark_has_x_coordinate_between | 0.217192 | 0.213656 | 0.209310 | 0.236233 | 0.011150 | 0.012638 | 0.209410 | 0.222048 | 5 | 0 | 1 | 4.60 | | test_benchmark_has_y_coordinate_between | 0.218497 | 0.219630 | 0.209352 | 0.234111 | 0.010103 | 0.013743 | 0.209584 | 0.223327 | 5 | 0 | 1 | 4.58 | +| test_benchmark_is_aggr_approx_count_distinct_with_group_by | 0.242859 | 0.239567 | 0.223938 | 0.272970 | 0.018260 | 0.017771 | 0.232410 | 0.250181 | 5 | 0 | 2 | 4.12 | +| test_benchmark_is_aggr_count_distinct_no_group_by | 0.277439 | 0.272561 | 0.250425 | 0.320653 | 0.027686 | 0.038256 | 0.256090 | 0.294346 | 5 | 0 | 1 | 3.60 | +| test_benchmark_is_aggr_count_distinct_with_group_by | 0.248015 | 0.249024 | 0.231550 | 0.270066 | 0.016044 | 0.026386 | 0.233048 | 0.259434 | 5 | 0 | 2 | 4.03 | | test_benchmark_is_aggr_equal | 0.304401 | 0.305693 | 0.266624 | 0.330403 | 0.026888 | 0.044641 | 0.284540 | 0.329181 | 5 | 0 | 1 | 3.29 | | test_benchmark_is_aggr_not_equal | 0.296462 | 0.296800 | 0.275119 | 0.312035 | 0.013498 | 0.013448 | 0.291054 | 0.304502 | 5 | 0 | 2 | 3.37 | | test_benchmark_is_aggr_not_greater_than | 0.307771 | 0.315185 | 0.277924 | 0.316280 | 0.016705 | 0.010701 | 0.304974 | 0.315675 | 5 | 1 | 1 | 3.25 | diff --git a/tests/perf/.benchmarks/baseline.json b/tests/perf/.benchmarks/baseline.json index b3f449bb9..66e28fffd 100644 --- a/tests/perf/.benchmarks/baseline.json +++ b/tests/perf/.benchmarks/baseline.json @@ -1681,6 +1681,111 @@ "iterations": 1 } }, + { + "group": null, + "name": "test_benchmark_is_aggr_approx_count_distinct_with_group_by", + "fullname": "tests/perf/test_apply_checks.py::test_benchmark_is_aggr_approx_count_distinct_with_group_by", + "params": null, + "param": null, + "extra_info": {}, + "options": { + "disable_gc": false, + "timer": "perf_counter", + "min_rounds": 5, + "max_time": 1.0, + "min_time": 0.000005, + "warmup": false + }, + "stats": { + "min": 0.22393849300033253, + "max": 0.27297041700012414, + "mean": 0.2428591086000779, + "stddev": 0.01826011249364107, + "rounds": 5, + "median": 0.23956732599981478, + "iqr": 0.01777076375026354, + "q1": 0.23241047199996956, + "q3": 
0.2501812357502331, + "iqr_outliers": 0, + "stddev_outliers": 2, + "outliers": "2;0", + "ld15iqr": 0.22393849300033253, + "hd15iqr": 0.27297041700012414, + "ops": 4.117613729887829, + "total": 1.2142955430003894, + "iterations": 1 + } + }, + { + "group": null, + "name": "test_benchmark_is_aggr_count_distinct_no_group_by", + "fullname": "tests/perf/test_apply_checks.py::test_benchmark_is_aggr_count_distinct_no_group_by", + "params": null, + "param": null, + "extra_info": {}, + "options": { + "disable_gc": false, + "timer": "perf_counter", + "min_rounds": 5, + "max_time": 1.0, + "min_time": 0.000005, + "warmup": false + }, + "stats": { + "min": 0.2504249900002833, + "max": 0.3206533539996599, + "mean": 0.2774389961999987, + "stddev": 0.02768560893059817, + "rounds": 5, + "median": 0.2725611239998216, + "iqr": 0.03825604774954172, + "q1": 0.25609008650030773, + "q3": 0.29434613424984946, + "iqr_outliers": 0, + "stddev_outliers": 1, + "outliers": "1;0", + "ld15iqr": 0.2504249900002833, + "hd15iqr": 0.3206533539996599, + "ops": 3.6043959706339397, + "total": 1.3871949809999933, + "iterations": 1 + } + }, + { + "group": null, + "name": "test_benchmark_is_aggr_count_distinct_with_group_by", + "fullname": "tests/perf/test_apply_checks.py::test_benchmark_is_aggr_count_distinct_with_group_by", + "params": null, + "param": null, + "extra_info": {}, + "options": { + "disable_gc": false, + "timer": "perf_counter", + "min_rounds": 5, + "max_time": 1.0, + "min_time": 0.000005, + "warmup": false + }, + "stats": { + "min": 0.2315504329999385, + "max": 0.2700657520003915, + "mean": 0.24801521500012313, + "stddev": 0.01604423432706884, + "rounds": 5, + "median": 0.2490236530002221, + "iqr": 0.026385781999579194, + "q1": 0.2330477210002755, + "q3": 0.2594335029998547, + "iqr_outliers": 0, + "stddev_outliers": 2, + "outliers": "2;0", + "ld15iqr": 0.2315504329999385, + "hd15iqr": 0.2700657520003915, + "ops": 4.0320106974062195, + "total": 1.2400760750006157, + "iterations": 1 + } + }, { "group": null, "name": "test_benchmark_is_aggr_equal", From 20870c81099f2847670593e766f22009a9c18002 Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Tue, 2 Dec 2025 16:33:06 +0000 Subject: [PATCH 18/24] docs: pr review addressed --- docs/dqx/docs/reference/quality_checks.mdx | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx index 54eae684a..ae57797ed 100644 --- a/docs/dqx/docs/reference/quality_checks.mdx +++ b/docs/dqx/docs/reference/quality_checks.mdx @@ -1653,14 +1653,9 @@ Validated and optimized for data quality use cases: - `approx_percentile` - Approximate percentile (faster, requires `aggr_params`) -`count_distinct` with `group_by` uses two-stage aggregation (groupBy + join) due to Spark's DISTINCT window limitation. **For large-scale datasets with many groups**, prefer `approx_count_distinct` which: -- Uses efficient window functions (no two-stage aggregation) -- Provides ±5% accuracy via HyperLogLog++ -- Scales better for high-cardinality groups +Using `count_distinct` with `group_by` performs multiple expensive operations. **For large-scale datasets, use `approx_count_distinct` instead**—it processes efficiently and scales well with high-cardinality groups. -Use `count_distinct` when exact counts are critical; use `approx_count_distinct` for large datasets or when approximation is acceptable. 
- -**Performance benchmarks**: See [benchmarks](/docs/reference/benchmarks) for performance comparison across different cardinality levels. +**Performance benchmarks**: See [benchmarks](/docs/reference/benchmarks) for detailed comparison. ### Using `aggr_params` for Function Parameters {#using-aggr_params-for-function-parameters} @@ -1694,7 +1689,7 @@ aggr_params: The `aggr_params` dictionary is unpacked as keyword arguments to the Spark aggregate function. See [Databricks function documentation](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) for available parameters for each function. -### Non-Curated Aggregates {#non-curated-aggregates} +### Non-Curated Aggregate Functions {#non-curated-aggregates} Any [Databricks built-in aggregate function](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) can be used beyond the curated list. DQX will: - Issue a warning (recommending curated functions when possible) @@ -1704,12 +1699,9 @@ Any [Databricks built-in aggregate function](https://docs.databricks.com/aws/en/ - Functions returning arrays, structs, or maps (e.g., `collect_list`, `collect_set`) will fail with a clear error message explaining they cannot be compared to numeric limits - **User-Defined Aggregate Functions (UDAFs) are not currently supported** - only Databricks built-in functions in `pyspark.sql.functions` -**Runtime Compatibility:** -DQX requires **Databricks Runtime 15.4+ (Spark 3.5+)**. Some newer functions may not be available on older runtimes: -- `mode`, `median` - Require Spark 3.4+ (DBR 13.3+) -- Most statistical functions - Available in all supported versions - -Unsupported functions will fail with an error message and version guidance. + +Some aggregate functions require newer Spark versions. If a function is unavailable, the error message will indicate the minimum required Spark version. 
**Use the latest Databricks Runtime for best compatibility.** + ### Extended Examples {#aggregate-extended-examples} From 1d0c96af762d1c4325465ac3e5992fa4c9566d67 Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Tue, 2 Dec 2025 22:21:05 +0000 Subject: [PATCH 19/24] improvement: readability of error messages + update tests --- src/databricks/labs/dqx/check_funcs.py | 68 ++++++++++++++++-------- tests/integration/test_apply_checks.py | 26 ++++----- tests/integration/test_dataset_checks.py | 66 +++++++++++------------ 3 files changed, 91 insertions(+), 69 deletions(-) diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index ae12e5035..c7d48111c 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -30,27 +30,28 @@ # Curated aggregate functions for data quality checks # These are univariate (single-column) aggregate functions suitable for DQ monitoring +# Maps function names to human-readable display names for error messages CURATED_AGGR_FUNCTIONS = { - "count", - "sum", - "avg", - "min", - "max", - "count_distinct", - "approx_count_distinct", - "count_if", - "stddev", - "stddev_pop", - "stddev_samp", - "variance", - "var_pop", - "var_samp", - "median", - "mode", - "skewness", - "kurtosis", - "percentile", - "approx_percentile", + "count": "Count", + "sum": "Sum", + "avg": "Average", + "min": "Min value", + "max": "Max value", + "count_distinct": "Distinct value count", + "approx_count_distinct": "Approximate distinct value count", + "count_if": "Conditional count", + "stddev": "Standard deviation", + "stddev_pop": "Population standard deviation", + "stddev_samp": "Sample standard deviation", + "variance": "Variance", + "var_pop": "Population variance", + "var_samp": "Sample variance", + "median": "Median", + "mode": "Mode", + "skewness": "Skewness", + "kurtosis": "Kurtosis", + "percentile": "Percentile", + "approx_percentile": "Approximate percentile", } # Aggregate functions incompatible with Spark window functions @@ -2303,7 +2304,7 @@ def _build_aggregate_expression( return F.countDistinct(filtered_expr) if aggr_type in {"percentile", "approx_percentile"}: - if not aggr_params or not aggr_params.get("percentile"): + if not aggr_params or "percentile" not in aggr_params: raise MissingParameterError( f"'{aggr_type}' requires aggr_params with 'percentile' key (e.g., {{'percentile': 0.95}})" ) @@ -2311,7 +2312,7 @@ def _build_aggregate_expression( if aggr_type == "percentile": return F.percentile(filtered_expr, pct) - if aggr_params.get("accuracy"): + if "accuracy" in aggr_params: return F.approx_percentile(filtered_expr, pct, aggr_params["accuracy"]) return F.approx_percentile(filtered_expr, pct) @@ -2498,11 +2499,14 @@ def apply(df: DataFrame) -> DataFrame: return df + # Get human-readable display name for aggregate function + aggr_display_name = _get_aggregate_display_name(aggr_type) + condition = make_condition( condition=F.col(condition_col), message=F.concat_ws( "", - F.lit(f"{aggr_type.capitalize()} "), + F.lit(f"{aggr_display_name} "), F.col(metric_col).cast("string"), F.lit(f" in column '{aggr_col_str}'"), F.lit(f"{' per group of columns ' if group_by_list_str else ''}"), @@ -2631,6 +2635,24 @@ def _get_normalized_column_and_expr(column: str | Column) -> tuple[str, str, Col return col_str_norm, column_str, col_expr +def _get_aggregate_display_name(aggr_type: str) -> str: + """ + Get a human-readable display name for an aggregate function. 
+ + This helper provides user-friendly names for aggregate functions in error messages, + transforming technical function names (e.g., 'count_distinct') into readable text + (e.g., 'Distinct value count'). + + Args: + aggr_type: The aggregate function name (e.g., 'count_distinct', 'max', 'avg'). + + Returns: + A human-readable display name for the aggregate function. If no mapping exists, + returns the capitalized function name. + """ + return CURATED_AGGR_FUNCTIONS.get(aggr_type, aggr_type.capitalize()) + + def _get_column_expr(column: Column | str) -> Column: """ Convert a column input (string or Column) into a Spark Column expression. diff --git a/tests/integration/test_apply_checks.py b/tests/integration/test_apply_checks.py index db2a25ad1..b5c9a287d 100644 --- a/tests/integration/test_apply_checks.py +++ b/tests/integration/test_apply_checks.py @@ -6872,7 +6872,7 @@ def test_apply_aggr_checks(ws, spark): [ { "name": "c_avg_greater_than_limit", - "message": "Avg 4.0 in column 'c' is greater than limit: 0", + "message": "Average 4.0 in column 'c' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -6951,7 +6951,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "c_avg_greater_than_limit", - "message": "Avg 4.0 in column 'c' is greater than limit: 0", + "message": "Average 4.0 in column 'c' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -6961,7 +6961,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "c_avg_group_by_a_greater_than_limit", - "message": "Avg 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", + "message": "Average 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7050,7 +7050,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "c_avg_greater_than_limit", - "message": "Avg 4.0 in column 'c' is greater than limit: 0", + "message": "Average 4.0 in column 'c' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7060,7 +7060,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "c_avg_group_by_a_greater_than_limit", - "message": "Avg 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", + "message": "Average 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7272,7 +7272,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): [ { "name": "c_avg_greater_than_limit", - "message": "Avg 4.0 in column 'c' is greater than limit: 0", + "message": "Average 4.0 in column 'c' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7334,7 +7334,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "c_avg_not_equal_to_limit", - "message": "Avg 4.0 in column 'c' is equal to limit: 4.0", + "message": "Average 4.0 in column 'c' is equal to limit: 4.0", "columns": ["c"], "filter": None, "function": "is_aggr_not_equal", @@ -7371,7 +7371,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "c_avg_greater_than_limit", - "message": "Avg 4.0 in column 'c' is greater than limit: 0", + "message": "Average 4.0 in column 'c' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7381,7 +7381,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": 
"c_avg_group_by_a_greater_than_limit", - "message": "Avg 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", + "message": "Average 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7443,7 +7443,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "c_avg_not_equal_to_limit", - "message": "Avg 4.0 in column 'c' is equal to limit: 4.0", + "message": "Average 4.0 in column 'c' is equal to limit: 4.0", "columns": ["c"], "filter": None, "function": "is_aggr_not_equal", @@ -7490,7 +7490,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "c_avg_greater_than_limit", - "message": "Avg 4.0 in column 'c' is greater than limit: 0", + "message": "Average 4.0 in column 'c' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7500,7 +7500,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "c_avg_group_by_a_greater_than_limit", - "message": "Avg 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", + "message": "Average 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7592,7 +7592,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "c_avg_not_equal_to_limit", - "message": "Avg 4.0 in column 'c' is equal to limit: 4.0", + "message": "Average 4.0 in column 'c' is equal to limit: 4.0", "columns": ["c"], "filter": None, "function": "is_aggr_not_equal", diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index 12e530619..525db0e48 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -226,10 +226,10 @@ def test_is_aggr_not_greater_than(spark: SparkSession): "Count 2 in column 'a' is greater than limit: 0", None, None, - "Avg 2.0 in column 'b' is greater than limit: 0.0", + "Average 2.0 in column 'b' is greater than limit: 0.0", "Sum 4 in column 'b' is greater than limit: 0.0", - "Min 1 in column 'b' is greater than limit: 0.0", - "Max 3 in column 'b' is greater than limit: 0.0", + "Min value 1 in column 'b' is greater than limit: 0.0", + "Max value 3 in column 'b' is greater than limit: 0.0", ], [ "a", @@ -238,10 +238,10 @@ def test_is_aggr_not_greater_than(spark: SparkSession): "Count 2 in column 'a' is greater than limit: 0", "Count 1 in column 'a' per group of columns 'a' is greater than limit: 0", "Count 1 in column 'b' per group of columns 'b' is greater than limit: 0", - "Avg 2.0 in column 'b' is greater than limit: 0.0", + "Average 2.0 in column 'b' is greater than limit: 0.0", "Sum 4 in column 'b' is greater than limit: 0.0", - "Min 1 in column 'b' is greater than limit: 0.0", - "Max 3 in column 'b' is greater than limit: 0.0", + "Min value 1 in column 'b' is greater than limit: 0.0", + "Max value 3 in column 'b' is greater than limit: 0.0", ], [ "b", @@ -250,10 +250,10 @@ def test_is_aggr_not_greater_than(spark: SparkSession): "Count 2 in column 'a' is greater than limit: 0", "Count 1 in column 'a' per group of columns 'a' is greater than limit: 0", "Count 1 in column 'b' per group of columns 'b' is greater than limit: 0", - "Avg 2.0 in column 'b' is greater than limit: 0.0", + "Average 2.0 in column 'b' is greater than limit: 0.0", "Sum 4 in column 'b' is greater than limit: 0.0", - "Min 1 in column 'b' is greater than limit: 0.0", - "Max 3 in column 'b' is 
greater than limit: 0.0", + "Min value 1 in column 'b' is greater than limit: 0.0", + "Max value 3 in column 'b' is greater than limit: 0.0", ], ], expected_schema, @@ -306,10 +306,10 @@ def test_is_aggr_not_less_than(spark: SparkSession): "Count 2 in column 'a' is less than limit: 3", "Count 0 in column 'a' per group of columns 'a' is less than limit: 2", "Count 0 in column 'b' per group of columns 'b' is less than limit: 2", - "Avg 2.0 in column 'b' is less than limit: 3.0", + "Average 2.0 in column 'b' is less than limit: 3.0", "Sum 4 in column 'b' is less than limit: 5.0", - "Min 1 in column 'b' is less than limit: 2.0", - "Max 3 in column 'b' is less than limit: 4.0", + "Min value 1 in column 'b' is less than limit: 2.0", + "Max value 3 in column 'b' is less than limit: 4.0", ], [ "a", @@ -318,10 +318,10 @@ def test_is_aggr_not_less_than(spark: SparkSession): "Count 2 in column 'a' is less than limit: 3", "Count 1 in column 'a' per group of columns 'a' is less than limit: 2", "Count 1 in column 'b' per group of columns 'b' is less than limit: 2", - "Avg 2.0 in column 'b' is less than limit: 3.0", + "Average 2.0 in column 'b' is less than limit: 3.0", "Sum 4 in column 'b' is less than limit: 5.0", - "Min 1 in column 'b' is less than limit: 2.0", - "Max 3 in column 'b' is less than limit: 4.0", + "Min value 1 in column 'b' is less than limit: 2.0", + "Max value 3 in column 'b' is less than limit: 4.0", ], [ "b", @@ -330,10 +330,10 @@ def test_is_aggr_not_less_than(spark: SparkSession): "Count 2 in column 'a' is less than limit: 3", "Count 1 in column 'a' per group of columns 'a' is less than limit: 2", "Count 1 in column 'b' per group of columns 'b' is less than limit: 2", - "Avg 2.0 in column 'b' is less than limit: 3.0", + "Average 2.0 in column 'b' is less than limit: 3.0", "Sum 4 in column 'b' is less than limit: 5.0", - "Min 1 in column 'b' is less than limit: 2.0", - "Max 3 in column 'b' is less than limit: 4.0", + "Min value 1 in column 'b' is less than limit: 2.0", + "Max value 3 in column 'b' is less than limit: 4.0", ], ], expected_schema, @@ -411,7 +411,7 @@ def test_is_aggr_equal(spark: SparkSession): None, "Sum 4 in column 'b' is not equal to limit: 10.0", None, - "Max 3 in column 'b' is not equal to limit: 5.0", + "Max value 3 in column 'b' is not equal to limit: 5.0", ], [ "a", @@ -423,7 +423,7 @@ def test_is_aggr_equal(spark: SparkSession): None, "Sum 4 in column 'b' is not equal to limit: 10.0", None, - "Max 3 in column 'b' is not equal to limit: 5.0", + "Max value 3 in column 'b' is not equal to limit: 5.0", ], [ "b", @@ -435,7 +435,7 @@ def test_is_aggr_equal(spark: SparkSession): None, "Sum 4 in column 'b' is not equal to limit: 10.0", None, - "Max 3 in column 'b' is not equal to limit: 5.0", + "Max value 3 in column 'b' is not equal to limit: 5.0", ], ], expected_schema, @@ -487,9 +487,9 @@ def test_is_aggr_not_equal(spark: SparkSession): None, None, None, - "Avg 2.0 in column 'b' is equal to limit: 2.0", + "Average 2.0 in column 'b' is equal to limit: 2.0", None, - "Min 1 in column 'b' is equal to limit: 1.0", + "Min value 1 in column 'b' is equal to limit: 1.0", None, ], [ @@ -499,9 +499,9 @@ def test_is_aggr_not_equal(spark: SparkSession): None, "Count 1 in column 'a' per group of columns 'a' is equal to limit: 1", None, - "Avg 2.0 in column 'b' is equal to limit: 2.0", + "Average 2.0 in column 'b' is equal to limit: 2.0", None, - "Min 1 in column 'b' is equal to limit: 1.0", + "Min value 1 in column 'b' is equal to limit: 1.0", None, ], [ @@ -511,9 +511,9 @@ def 
test_is_aggr_not_equal(spark: SparkSession): None, "Count 1 in column 'a' per group of columns 'a' is equal to limit: 1", None, - "Avg 2.0 in column 'b' is equal to limit: 2.0", + "Average 2.0 in column 'b' is equal to limit: 2.0", None, - "Min 1 in column 'b' is equal to limit: 1.0", + "Min value 1 in column 'b' is equal to limit: 1.0", None, ], ], @@ -577,9 +577,9 @@ def test_is_aggr_with_count_distinct_and_group_by(spark: SparkSession): expected = spark.createDataFrame( [ - ["group1", "val1", "Count_distinct 2 in column 'b' per group of columns 'a' is greater than limit: 1"], - ["group1", "val1", "Count_distinct 2 in column 'b' per group of columns 'a' is greater than limit: 1"], - ["group1", "val2", "Count_distinct 2 in column 'b' per group of columns 'a' is greater than limit: 1"], + ["group1", "val1", "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1"], + ["group1", "val1", "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1"], + ["group1", "val2", "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1"], ["group2", "val3", None], ["group2", "val3", None], ], @@ -614,17 +614,17 @@ def test_is_aggr_with_approx_count_distinct(spark: SparkSession): [ "group1", "val1", - "Approx_count_distinct 2 in column 'b' per group of columns 'a' is greater than limit: 1", + "Approximate distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", ], [ "group1", "val1", - "Approx_count_distinct 2 in column 'b' per group of columns 'a' is greater than limit: 1", + "Approximate distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", ], [ "group1", "val2", - "Approx_count_distinct 2 in column 'b' per group of columns 'a' is greater than limit: 1", + "Approximate distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", ], ["group2", "val3", None], ["group2", "val3", None], From 1f68600ab8df3868ffedfb413436ae0d5a8966b6 Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Tue, 2 Dec 2025 22:42:59 +0000 Subject: [PATCH 20/24] fix: stacklevel for warnings, trailing whitespace, add Column expression test --- src/databricks/labs/dqx/check_funcs.py | 7 +-- tests/integration/test_dataset_checks.py | 69 ++++++++++++++++++++++-- 2 files changed, 70 insertions(+), 6 deletions(-) diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index c7d48111c..bda3d67d2 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -2407,7 +2407,7 @@ def _is_aggr_compare( f"Curated functions: {', '.join(sorted(CURATED_AGGR_FUNCTIONS))}. " f"Non-curated aggregates must return a single numeric value per group.", UserWarning, - stacklevel=2, + stacklevel=3, ) aggr_col_str_norm, aggr_col_str, aggr_col_expr = _get_normalized_column_and_expr(column) @@ -2470,7 +2470,8 @@ def apply(df: DataFrame) -> DataFrame: _validate_aggregate_return_type(agg_df, aggr_type, metric_col) # Join aggregated metrics back to original DataFrame to maintain row-level granularity - # Use column names for join (extract names from Column objects if present) + # Note: Aliased Column expressions in group_by are not supported for window-incompatible + # aggregates (e.g., count_distinct). Use string column names or simple F.col() expressions. 
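+    # For example, group_by=["country", F.col("region")] resolves to join_cols ["country", "region"],
+    # and the left join below re-attaches each group's metric to every original row without
+    # changing the row count.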
join_cols = [col if isinstance(col, str) else get_column_name_or_alias(col) for col in group_by] df = df.join(agg_df, on=join_cols, how="left") else: @@ -2501,7 +2502,7 @@ def apply(df: DataFrame) -> DataFrame: # Get human-readable display name for aggregate function aggr_display_name = _get_aggregate_display_name(aggr_type) - + condition = make_condition( condition=F.col(condition_col), message=F.concat_ws( diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index 525db0e48..62c2c5a4b 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -577,9 +577,72 @@ def test_is_aggr_with_count_distinct_and_group_by(spark: SparkSession): expected = spark.createDataFrame( [ - ["group1", "val1", "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1"], - ["group1", "val1", "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1"], - ["group1", "val2", "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1"], + [ + "group1", + "val1", + "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", + ], + [ + "group1", + "val1", + "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", + ], + [ + "group1", + "val2", + "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", + ], + ["group2", "val3", None], + ["group2", "val3", None], + ], + "a: string, b: string, b_count_distinct_group_by_a_greater_than_limit: string", + ) + + assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True) + + +def test_is_aggr_with_count_distinct_and_column_expression_in_group_by(spark: SparkSession): + """Test count_distinct with Column expression (F.col) in group_by. + + This tests that the two-stage aggregation (groupBy + join) correctly handles + Column expressions (not just string column names) in group_by. + """ + test_df = spark.createDataFrame( + [ + ["group1", "val1"], + ["group1", "val2"], # 2 distinct values in group1 + ["group2", "val3"], + ["group2", "val3"], # 1 distinct value in group2 + ], + "a: string, b: string", + ) + + # Use Column expression (F.col) in group_by instead of string + checks = [ + is_aggr_not_greater_than( + "b", + limit=1, + aggr_type="count_distinct", + group_by=[F.col("a")], # Column expression without alias + ), + ] + + actual = _apply_checks(test_df, checks) + + # group1 has 2 distinct values > 1, should fail + # group2 has 1 distinct value <= 1, should pass + expected = spark.createDataFrame( + [ + [ + "group1", + "val1", + "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", + ], + [ + "group1", + "val2", + "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", + ], ["group2", "val3", None], ["group2", "val3", None], ], From 7cf0ccf2f81a352f446193461a97787ddd499d8a Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Tue, 2 Dec 2025 23:14:06 +0000 Subject: [PATCH 21/24] code hardening: 1. removed dead code which was "just in case" 2. added test for incorrect parameters 3. 
More permissive parameter passing to aggr functions --- src/databricks/labs/dqx/check_funcs.py | 16 +++++----------- tests/integration/test_dataset_checks.py | 21 +++++++++++++++++++++ 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index bda3d67d2..2a879c6fd 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -2309,12 +2309,12 @@ def _build_aggregate_expression( f"'{aggr_type}' requires aggr_params with 'percentile' key (e.g., {{'percentile': 0.95}})" ) pct = aggr_params["percentile"] + # Pass through any additional parameters to Spark (e.g., accuracy, frequency) + # Spark will validate parameter names and types at runtime + other_params = {k: v for k, v in aggr_params.items() if k != "percentile"} - if aggr_type == "percentile": - return F.percentile(filtered_expr, pct) - if "accuracy" in aggr_params: - return F.approx_percentile(filtered_expr, pct, aggr_params["accuracy"]) - return F.approx_percentile(filtered_expr, pct) + aggr_func = getattr(F, aggr_type) + return aggr_func(filtered_expr, pct, **other_params) try: aggr_func = getattr(F, aggr_type) @@ -2463,12 +2463,6 @@ def apply(df: DataFrame) -> DataFrame: group_cols = [F.col(col) if isinstance(col, str) else col for col in group_by] agg_df = df.groupBy(*group_cols).agg(aggr_expr.alias(metric_col)) - # Validate non-curated aggregates (type check on aggregated result before join) - # This ensures consistent validation regardless of whether window functions or two-stage - # aggregation is used (e.g., if a non-curated function is added to WINDOW_INCOMPATIBLE_AGGREGATES) - if not is_curated: - _validate_aggregate_return_type(agg_df, aggr_type, metric_col) - # Join aggregated metrics back to original DataFrame to maintain row-level granularity # Note: Aliased Column expressions in group_by are not supported for window-incompatible # aggregates (e.g., count_distinct). Use string column names or simple F.col() expressions. diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index 62c2c5a4b..f0d91ac95 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -865,6 +865,27 @@ def test_is_aggr_percentile_missing_params(spark: SparkSession): apply_fn(test_df) +def test_is_aggr_percentile_invalid_params_caught_by_spark(spark: SparkSession): + """Test that invalid aggr_params are caught by Spark at runtime. + + This verifies our permissive strategy: we don't validate extra parameters, + but Spark will raise an error for truly invalid ones. 
+ """ + test_df = spark.createDataFrame([(1, 10.0), (2, 20.0)], "id: int, value: double") + + # Pass an invalid parameter type (string instead of float for percentile) + # Spark should raise an error when the DataFrame is evaluated + with pytest.raises(Exception): # Spark will raise AnalysisException or similar + _, apply_fn = is_aggr_not_greater_than( + "value", + limit=100.0, + aggr_type="approx_percentile", + aggr_params={"percentile": "invalid_string"}, # Invalid: should be float + ) + result_df = apply_fn(test_df) + result_df.collect() # Force evaluation to trigger Spark error + + def test_is_aggr_with_invalid_aggregate_function(spark: SparkSession): """Test that invalid aggregate function names raise clear errors.""" test_df = spark.createDataFrame([(1, 10)], "id: int, value: int") From 24748ae6cb291b982c39a1e2972d7d6bba646af5 Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Wed, 3 Dec 2025 15:28:37 +0000 Subject: [PATCH 22/24] fix flaky test: test_apply_checks_and_save_in_tables_for_patterns_exclude_no_tables_matching --- .../test_apply_checks_and_save_in_table.py | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_apply_checks_and_save_in_table.py b/tests/integration/test_apply_checks_and_save_in_table.py index 9f71ee293..2290d7561 100644 --- a/tests/integration/test_apply_checks_and_save_in_table.py +++ b/tests/integration/test_apply_checks_and_save_in_table.py @@ -1803,14 +1803,27 @@ def test_apply_checks_and_save_in_tables_for_patterns_no_tables_matching(ws, spa ) -def test_apply_checks_and_save_in_tables_for_patterns_exclude_no_tables_matching(ws, spark): - # Test with empty list of table configs +def test_apply_checks_and_save_in_tables_for_patterns_exclude_no_tables_matching(ws, spark, make_schema, make_table): + # Create an isolated schema with a known table, then exclude it using exclude_patterns + catalog_name = TEST_CATALOG + schema = make_schema(catalog_name=catalog_name) + + # Create a table in the isolated schema + make_table( + catalog_name=catalog_name, + schema_name=schema.name, + ctas="SELECT 1 as id", + ) + engine = DQEngine(ws, spark=spark, extra_params=EXTRA_PARAMS) - # This should not raise an error + # Include the schema pattern, but exclude all tables with wildcard - should result in no tables + # Using exclude_patterns avoids the full catalog scan that exclude_matched=True triggers with pytest.raises(NotFound, match="No tables found matching include or exclude criteria"): engine.apply_checks_and_save_in_tables_for_patterns( - patterns=["*"], checks_location="some/location", exclude_matched=True + patterns=[f"{catalog_name}.{schema.name}.*"], + exclude_patterns=[f"{catalog_name}.{schema.name}.*"], + checks_location="some/location", ) From c500cff995d68f956fdfa29b20e48f5480f93482 Mon Sep 17 00:00:00 2001 From: Varun Bhandary Date: Fri, 5 Dec 2025 16:13:21 +0000 Subject: [PATCH 23/24] code improv: consistent aggregate violation messages with params support --- src/databricks/labs/dqx/check_funcs.py | 50 ++++--- tests/integration/test_apply_checks.py | 78 +++++------ tests/integration/test_dataset_checks.py | 167 ++++++++++++++--------- 3 files changed, 176 insertions(+), 119 deletions(-) diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py index 2a879c6fd..3cf7f4eab 100644 --- a/src/databricks/labs/dqx/check_funcs.py +++ b/src/databricks/labs/dqx/check_funcs.py @@ -35,10 +35,10 @@ "count": "Count", "sum": "Sum", "avg": "Average", - "min": "Min value", - "max": 
"Max value", - "count_distinct": "Distinct value count", - "approx_count_distinct": "Approximate distinct value count", + "min": "Min", + "max": "Max", + "count_distinct": "Distinct count", + "approx_count_distinct": "Approximate distinct count", "count_if": "Conditional count", "stddev": "Standard deviation", "stddev_pop": "Population standard deviation", @@ -2298,7 +2298,7 @@ def _build_aggregate_expression( Raises: MissingParameterError: If required parameters are missing for specific aggregates. - InvalidParameterError: If the aggregate function is not found. + InvalidParameterError: If the aggregate function is not found or parameters are invalid. """ if aggr_type == "count_distinct": return F.countDistinct(filtered_expr) @@ -2313,8 +2313,11 @@ def _build_aggregate_expression( # Spark will validate parameter names and types at runtime other_params = {k: v for k, v in aggr_params.items() if k != "percentile"} - aggr_func = getattr(F, aggr_type) - return aggr_func(filtered_expr, pct, **other_params) + try: + aggr_func = getattr(F, aggr_type) + return aggr_func(filtered_expr, pct, **other_params) + except Exception as exc: + raise InvalidParameterError(f"Failed to build '{aggr_type}' expression: {exc}") from exc try: aggr_func = getattr(F, aggr_type) @@ -2328,6 +2331,8 @@ def _build_aggregate_expression( f"Some newer aggregate functions (e.g., mode, median) require DBR 15.4+ (Spark 3.5+). " f"See: https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha" ) from exc + except Exception as exc: + raise InvalidParameterError(f"Failed to build '{aggr_type}' expression: {exc}") from exc def _validate_aggregate_return_type( @@ -2456,11 +2461,13 @@ def apply(df: DataFrame) -> DataFrame: aggr_expr = _build_aggregate_expression(aggr_type, filtered_expr, aggr_params) if group_by: + # Convert group_by to Column expressions (reused for both window and groupBy approaches) + group_cols = [F.col(col) if isinstance(col, str) else col for col in group_by] + # Check if aggregate is incompatible with window functions (e.g., count_distinct with DISTINCT) if aggr_type in WINDOW_INCOMPATIBLE_AGGREGATES: # Use two-stage aggregation: groupBy + join (instead of window functions) # This is required for aggregates like count_distinct that don't support window DISTINCT operations - group_cols = [F.col(col) if isinstance(col, str) else col for col in group_by] agg_df = df.groupBy(*group_cols).agg(aggr_expr.alias(metric_col)) # Join aggregated metrics back to original DataFrame to maintain row-level granularity @@ -2470,7 +2477,7 @@ def apply(df: DataFrame) -> DataFrame: df = df.join(agg_df, on=join_cols, how="left") else: # Use standard window function approach for window-compatible aggregates - window_spec = Window.partitionBy(*[F.col(col) if isinstance(col, str) else col for col in group_by]) + window_spec = Window.partitionBy(*group_cols) df = df.withColumn(metric_col, aggr_expr.over(window_spec)) # Validate non-curated aggregates (type check only - window functions return same row count) @@ -2494,14 +2501,14 @@ def apply(df: DataFrame) -> DataFrame: return df - # Get human-readable display name for aggregate function - aggr_display_name = _get_aggregate_display_name(aggr_type) + # Get human-readable display name for aggregate function (including params if present) + aggr_display_name = _get_aggregate_display_name(aggr_type, aggr_params) condition = make_condition( condition=F.col(condition_col), message=F.concat_ws( "", - F.lit(f"{aggr_display_name} "), + F.lit(f"{aggr_display_name} 
value "), F.col(metric_col).cast("string"), F.lit(f" in column '{aggr_col_str}'"), F.lit(f"{' per group of columns ' if group_by_list_str else ''}"), @@ -2630,7 +2637,7 @@ def _get_normalized_column_and_expr(column: str | Column) -> tuple[str, str, Col return col_str_norm, column_str, col_expr -def _get_aggregate_display_name(aggr_type: str) -> str: +def _get_aggregate_display_name(aggr_type: str, aggr_params: dict[str, Any] | None = None) -> str: """ Get a human-readable display name for an aggregate function. @@ -2640,12 +2647,23 @@ def _get_aggregate_display_name(aggr_type: str) -> str: Args: aggr_type: The aggregate function name (e.g., 'count_distinct', 'max', 'avg'). + aggr_params: Optional parameters passed to the aggregate function. Returns: - A human-readable display name for the aggregate function. If no mapping exists, - returns the capitalized function name. + A human-readable display name for the aggregate function, including parameters + if provided. For non-curated functions, returns the function name in quotes + with 'value' suffix. """ - return CURATED_AGGR_FUNCTIONS.get(aggr_type, aggr_type.capitalize()) + # Get base display name (curated functions have friendly names, others show function name in quotes) + base_name = CURATED_AGGR_FUNCTIONS.get(aggr_type, f"'{aggr_type}'") + + # Add parameters if present + if aggr_params: + # Format parameters as key=value pairs + param_str = ", ".join(f"{k}={v}" for k, v in aggr_params.items()) + return f"{base_name} ({param_str})" + + return base_name def _get_column_expr(column: Column | str) -> Column: diff --git a/tests/integration/test_apply_checks.py b/tests/integration/test_apply_checks.py index b5c9a287d..647336f5d 100644 --- a/tests/integration/test_apply_checks.py +++ b/tests/integration/test_apply_checks.py @@ -6872,7 +6872,7 @@ def test_apply_aggr_checks(ws, spark): [ { "name": "c_avg_greater_than_limit", - "message": "Average 4.0 in column 'c' is greater than limit: 0", + "message": "Average value 4.0 in column 'c' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -6894,7 +6894,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "a_count_greater_than_limit", - "message": "Count 2 in column 'a' is greater than limit: 0", + "message": "Count value 2 in column 'a' is greater than limit: 0", "columns": ["a"], "filter": None, "function": "is_aggr_not_greater_than", @@ -6921,7 +6921,7 @@ def test_apply_aggr_checks(ws, spark): [ { "name": "b_sum_group_by_a_not_equal_to_limit", - "message": "Sum 2 in column 'b' per group of columns 'a' is not equal to limit: 8", + "message": "Sum value 2 in column 'b' per group of columns 'a' is not equal to limit: 8", "columns": ["b"], "filter": None, "function": "is_aggr_equal", @@ -6931,7 +6931,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "a_count_group_by_a_greater_than_limit", - "message": "Count 2 in column 'a' per group of columns 'a' is greater than limit: 0", + "message": "Count value 2 in column 'a' per group of columns 'a' is greater than limit: 0", "columns": ["a"], "filter": None, "function": "is_aggr_not_greater_than", @@ -6941,7 +6941,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "row_count_group_by_a_b_greater_than_limit", - "message": "Count 1 in column 'a' per group of columns 'a, b' is greater than limit: 0", + "message": "Count value 1 in column 'a' per group of columns 'a, b' is greater than limit: 0", "columns": ["a"], "filter": None, "function": "is_aggr_not_greater_than", @@ -6951,7 +6951,7 @@ 
def test_apply_aggr_checks(ws, spark): }, { "name": "c_avg_greater_than_limit", - "message": "Average 4.0 in column 'c' is greater than limit: 0", + "message": "Average value 4.0 in column 'c' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -6961,7 +6961,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "c_avg_group_by_a_greater_than_limit", - "message": "Average 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", + "message": "Average value 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -6983,7 +6983,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "a_count_greater_than_limit", - "message": "Count 2 in column 'a' is greater than limit: 0", + "message": "Count value 2 in column 'a' is greater than limit: 0", "columns": ["a"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7010,7 +7010,7 @@ def test_apply_aggr_checks(ws, spark): [ { "name": "b_sum_group_by_a_not_equal_to_limit", - "message": "Sum 2 in column 'b' per group of columns 'a' is not equal to limit: 8", + "message": "Sum value 2 in column 'b' per group of columns 'a' is not equal to limit: 8", "columns": ["b"], "filter": None, "function": "is_aggr_equal", @@ -7020,7 +7020,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "a_count_group_by_a_greater_than_limit", - "message": "Count 2 in column 'a' per group of columns 'a' is greater than limit: 0", + "message": "Count value 2 in column 'a' per group of columns 'a' is greater than limit: 0", "columns": ["a"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7030,7 +7030,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "a_count_group_by_a_greater_than_limit_with_b_not_null", - "message": "Count 1 in column 'a' per group of columns 'a' is greater than limit: 0", + "message": "Count value 1 in column 'a' per group of columns 'a' is greater than limit: 0", "columns": ["a"], "filter": "b is not null", "function": "is_aggr_not_greater_than", @@ -7040,7 +7040,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "row_count_group_by_a_b_greater_than_limit", - "message": "Count 1 in column 'a' per group of columns 'a, b' is greater than limit: 0", + "message": "Count value 1 in column 'a' per group of columns 'a, b' is greater than limit: 0", "columns": ["a"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7050,7 +7050,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "c_avg_greater_than_limit", - "message": "Average 4.0 in column 'c' is greater than limit: 0", + "message": "Average value 4.0 in column 'c' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7060,7 +7060,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "c_avg_group_by_a_greater_than_limit", - "message": "Average 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", + "message": "Average value 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7070,7 +7070,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "a_count_group_by_a_less_than_limit_with_b_not_null", - "message": "Count 1 in column 'a' per group of columns 'a' is less than limit: 10", + "message": "Count value 1 in column 'a' per group of columns 'a' is less than limit: 10", "columns": ["a"], "filter": "b is not null", "function": 
"is_aggr_not_less_than", @@ -7092,7 +7092,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "a_count_greater_than_limit", - "message": "Count 2 in column 'a' is greater than limit: 0", + "message": "Count value 2 in column 'a' is greater than limit: 0", "columns": ["a"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7102,7 +7102,7 @@ def test_apply_aggr_checks(ws, spark): }, { "name": "a_count_greater_than_limit_with_b_not_null", - "message": "Count 1 in column 'a' is greater than limit: 0", + "message": "Count value 1 in column 'a' is greater than limit: 0", "columns": ["a"], "filter": "b is not null", "function": "is_aggr_not_greater_than", @@ -7272,7 +7272,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): [ { "name": "c_avg_greater_than_limit", - "message": "Average 4.0 in column 'c' is greater than limit: 0", + "message": "Average value 4.0 in column 'c' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7282,7 +7282,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "a_count_not_equal_to_limit", - "message": "Count 2 in column 'a' is equal to limit: 2", + "message": "Count value 2 in column 'a' is equal to limit: 2", "columns": ["a"], "filter": None, "function": "is_aggr_not_equal", @@ -7304,7 +7304,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "a_count_greater_than_limit", - "message": "Count 2 in column 'a' is greater than limit: 0", + "message": "Count value 2 in column 'a' is greater than limit: 0", "columns": ["a"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7334,7 +7334,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "c_avg_not_equal_to_limit", - "message": "Average 4.0 in column 'c' is equal to limit: 4.0", + "message": "Average value 4.0 in column 'c' is equal to limit: 4.0", "columns": ["c"], "filter": None, "function": "is_aggr_not_equal", @@ -7351,7 +7351,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): [ { "name": "a_count_group_by_a_greater_than_limit", - "message": "Count 2 in column 'a' per group of columns 'a' is greater than limit: 0", + "message": "Count value 2 in column 'a' per group of columns 'a' is greater than limit: 0", "columns": ["a"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7361,7 +7361,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "row_count_group_by_a_b_greater_than_limit", - "message": "Count 1 in column 'a' per group of columns 'a, b' is greater than limit: 0", + "message": "Count value 1 in column 'a' per group of columns 'a, b' is greater than limit: 0", "columns": ["a"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7371,7 +7371,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "c_avg_greater_than_limit", - "message": "Average 4.0 in column 'c' is greater than limit: 0", + "message": "Average value 4.0 in column 'c' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7381,7 +7381,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "c_avg_group_by_a_greater_than_limit", - "message": "Average 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", + "message": "Average value 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7391,7 +7391,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "a_count_not_equal_to_limit", - 
"message": "Count 2 in column 'a' is equal to limit: 2", + "message": "Count value 2 in column 'a' is equal to limit: 2", "columns": ["a"], "filter": None, "function": "is_aggr_not_equal", @@ -7413,7 +7413,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "a_count_greater_than_limit", - "message": "Count 2 in column 'a' is greater than limit: 0", + "message": "Count value 2 in column 'a' is greater than limit: 0", "columns": ["a"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7443,7 +7443,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "c_avg_not_equal_to_limit", - "message": "Average 4.0 in column 'c' is equal to limit: 4.0", + "message": "Average value 4.0 in column 'c' is equal to limit: 4.0", "columns": ["c"], "filter": None, "function": "is_aggr_not_equal", @@ -7460,7 +7460,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): [ { "name": "a_count_group_by_a_greater_than_limit", - "message": "Count 2 in column 'a' per group of columns 'a' is greater than limit: 0", + "message": "Count value 2 in column 'a' per group of columns 'a' is greater than limit: 0", "columns": ["a"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7470,7 +7470,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "a_count_group_by_a_greater_than_limit_with_b_not_null", - "message": "Count 1 in column 'a' per group of columns 'a' is greater than limit: 0", + "message": "Count value 1 in column 'a' per group of columns 'a' is greater than limit: 0", "columns": ["a"], "filter": "b is not null", "function": "is_aggr_not_greater_than", @@ -7480,7 +7480,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "row_count_group_by_a_b_greater_than_limit", - "message": "Count 1 in column 'a' per group of columns 'a, b' is greater than limit: 0", + "message": "Count value 1 in column 'a' per group of columns 'a, b' is greater than limit: 0", "columns": ["a"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7490,7 +7490,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "c_avg_greater_than_limit", - "message": "Average 4.0 in column 'c' is greater than limit: 0", + "message": "Average value 4.0 in column 'c' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7500,7 +7500,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "c_avg_group_by_a_greater_than_limit", - "message": "Average 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", + "message": "Average value 4.0 in column 'c' per group of columns 'a' is greater than limit: 0", "columns": ["c"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7510,7 +7510,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "a_count_group_by_a_less_than_limit_with_b_not_null", - "message": "Count 1 in column 'a' per group of columns 'a' is less than limit: 10", + "message": "Count value 1 in column 'a' per group of columns 'a' is less than limit: 10", "columns": ["a"], "filter": "b is not null", "function": "is_aggr_not_less_than", @@ -7520,7 +7520,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "a_count_not_equal_to_limit", - "message": "Count 2 in column 'a' is equal to limit: 2", + "message": "Count value 2 in column 'a' is equal to limit: 2", "columns": ["a"], "filter": None, "function": "is_aggr_not_equal", @@ -7530,7 +7530,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": 
"a_count_not_equal_to_limit_with_filter", - "message": "Count 1 in column 'a' is equal to limit: 1", + "message": "Count value 1 in column 'a' is equal to limit: 1", "columns": ["a"], "filter": "b is not null", "function": "is_aggr_not_equal", @@ -7552,7 +7552,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "a_count_greater_than_limit", - "message": "Count 2 in column 'a' is greater than limit: 0", + "message": "Count value 2 in column 'a' is greater than limit: 0", "columns": ["a"], "filter": None, "function": "is_aggr_not_greater_than", @@ -7562,7 +7562,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "a_count_greater_than_limit_with_b_not_null", - "message": "Count 1 in column 'a' is greater than limit: 0", + "message": "Count value 1 in column 'a' is greater than limit: 0", "columns": ["a"], "filter": "b is not null", "function": "is_aggr_not_greater_than", @@ -7592,7 +7592,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark): }, { "name": "c_avg_not_equal_to_limit", - "message": "Average 4.0 in column 'c' is equal to limit: 4.0", + "message": "Average value 4.0 in column 'c' is equal to limit: 4.0", "columns": ["c"], "filter": None, "function": "is_aggr_not_equal", diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index f0d91ac95..4944239f6 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -221,37 +221,37 @@ def test_is_aggr_not_greater_than(spark: SparkSession): [ "c", None, - "Count 3 in column 'a' is greater than limit: 1", + "Count value 3 in column 'a' is greater than limit: 1", # displayed since filtering is done after, filter only applied for calculation inside the check - "Count 2 in column 'a' is greater than limit: 0", + "Count value 2 in column 'a' is greater than limit: 0", None, None, - "Average 2.0 in column 'b' is greater than limit: 0.0", - "Sum 4 in column 'b' is greater than limit: 0.0", + "Average value 2.0 in column 'b' is greater than limit: 0.0", + "Sum value 4 in column 'b' is greater than limit: 0.0", "Min value 1 in column 'b' is greater than limit: 0.0", "Max value 3 in column 'b' is greater than limit: 0.0", ], [ "a", 1, - "Count 3 in column 'a' is greater than limit: 1", - "Count 2 in column 'a' is greater than limit: 0", - "Count 1 in column 'a' per group of columns 'a' is greater than limit: 0", - "Count 1 in column 'b' per group of columns 'b' is greater than limit: 0", - "Average 2.0 in column 'b' is greater than limit: 0.0", - "Sum 4 in column 'b' is greater than limit: 0.0", + "Count value 3 in column 'a' is greater than limit: 1", + "Count value 2 in column 'a' is greater than limit: 0", + "Count value 1 in column 'a' per group of columns 'a' is greater than limit: 0", + "Count value 1 in column 'b' per group of columns 'b' is greater than limit: 0", + "Average value 2.0 in column 'b' is greater than limit: 0.0", + "Sum value 4 in column 'b' is greater than limit: 0.0", "Min value 1 in column 'b' is greater than limit: 0.0", "Max value 3 in column 'b' is greater than limit: 0.0", ], [ "b", 3, - "Count 3 in column 'a' is greater than limit: 1", - "Count 2 in column 'a' is greater than limit: 0", - "Count 1 in column 'a' per group of columns 'a' is greater than limit: 0", - "Count 1 in column 'b' per group of columns 'b' is greater than limit: 0", - "Average 2.0 in column 'b' is greater than limit: 0.0", - "Sum 4 in column 'b' is greater than limit: 0.0", + "Count value 3 in column 'a' is greater than limit: 
1", + "Count value 2 in column 'a' is greater than limit: 0", + "Count value 1 in column 'a' per group of columns 'a' is greater than limit: 0", + "Count value 1 in column 'b' per group of columns 'b' is greater than limit: 0", + "Average value 2.0 in column 'b' is greater than limit: 0.0", + "Sum value 4 in column 'b' is greater than limit: 0.0", "Min value 1 in column 'b' is greater than limit: 0.0", "Max value 3 in column 'b' is greater than limit: 0.0", ], @@ -302,36 +302,36 @@ def test_is_aggr_not_less_than(spark: SparkSession): [ "c", None, - "Count 3 in column 'a' is less than limit: 4", - "Count 2 in column 'a' is less than limit: 3", - "Count 0 in column 'a' per group of columns 'a' is less than limit: 2", - "Count 0 in column 'b' per group of columns 'b' is less than limit: 2", - "Average 2.0 in column 'b' is less than limit: 3.0", - "Sum 4 in column 'b' is less than limit: 5.0", + "Count value 3 in column 'a' is less than limit: 4", + "Count value 2 in column 'a' is less than limit: 3", + "Count value 0 in column 'a' per group of columns 'a' is less than limit: 2", + "Count value 0 in column 'b' per group of columns 'b' is less than limit: 2", + "Average value 2.0 in column 'b' is less than limit: 3.0", + "Sum value 4 in column 'b' is less than limit: 5.0", "Min value 1 in column 'b' is less than limit: 2.0", "Max value 3 in column 'b' is less than limit: 4.0", ], [ "a", 1, - "Count 3 in column 'a' is less than limit: 4", - "Count 2 in column 'a' is less than limit: 3", - "Count 1 in column 'a' per group of columns 'a' is less than limit: 2", - "Count 1 in column 'b' per group of columns 'b' is less than limit: 2", - "Average 2.0 in column 'b' is less than limit: 3.0", - "Sum 4 in column 'b' is less than limit: 5.0", + "Count value 3 in column 'a' is less than limit: 4", + "Count value 2 in column 'a' is less than limit: 3", + "Count value 1 in column 'a' per group of columns 'a' is less than limit: 2", + "Count value 1 in column 'b' per group of columns 'b' is less than limit: 2", + "Average value 2.0 in column 'b' is less than limit: 3.0", + "Sum value 4 in column 'b' is less than limit: 5.0", "Min value 1 in column 'b' is less than limit: 2.0", "Max value 3 in column 'b' is less than limit: 4.0", ], [ "b", 3, - "Count 3 in column 'a' is less than limit: 4", - "Count 2 in column 'a' is less than limit: 3", - "Count 1 in column 'a' per group of columns 'a' is less than limit: 2", - "Count 1 in column 'b' per group of columns 'b' is less than limit: 2", - "Average 2.0 in column 'b' is less than limit: 3.0", - "Sum 4 in column 'b' is less than limit: 5.0", + "Count value 3 in column 'a' is less than limit: 4", + "Count value 2 in column 'a' is less than limit: 3", + "Count value 1 in column 'a' per group of columns 'a' is less than limit: 2", + "Count value 1 in column 'b' per group of columns 'b' is less than limit: 2", + "Average value 2.0 in column 'b' is less than limit: 3.0", + "Sum value 4 in column 'b' is less than limit: 5.0", "Min value 1 in column 'b' is less than limit: 2.0", "Max value 3 in column 'b' is less than limit: 4.0", ], @@ -405,11 +405,11 @@ def test_is_aggr_equal(spark: SparkSession): "c", None, None, - "Count 2 in column 'a' is not equal to limit: 1", - "Count 0 in column 'a' per group of columns 'a' is not equal to limit: 1", - "Count 0 in column 'b' per group of columns 'b' is not equal to limit: 2", + "Count value 2 in column 'a' is not equal to limit: 1", + "Count value 0 in column 'a' per group of columns 'a' is not equal to limit: 1", + "Count value 
0 in column 'b' per group of columns 'b' is not equal to limit: 2", None, - "Sum 4 in column 'b' is not equal to limit: 10.0", + "Sum value 4 in column 'b' is not equal to limit: 10.0", None, "Max value 3 in column 'b' is not equal to limit: 5.0", ], @@ -417,11 +417,11 @@ def test_is_aggr_equal(spark: SparkSession): "a", 1, None, - "Count 2 in column 'a' is not equal to limit: 1", + "Count value 2 in column 'a' is not equal to limit: 1", None, - "Count 1 in column 'b' per group of columns 'b' is not equal to limit: 2", + "Count value 1 in column 'b' per group of columns 'b' is not equal to limit: 2", None, - "Sum 4 in column 'b' is not equal to limit: 10.0", + "Sum value 4 in column 'b' is not equal to limit: 10.0", None, "Max value 3 in column 'b' is not equal to limit: 5.0", ], @@ -429,11 +429,11 @@ def test_is_aggr_equal(spark: SparkSession): "b", 3, None, - "Count 2 in column 'a' is not equal to limit: 1", + "Count value 2 in column 'a' is not equal to limit: 1", None, - "Count 1 in column 'b' per group of columns 'b' is not equal to limit: 2", + "Count value 1 in column 'b' per group of columns 'b' is not equal to limit: 2", None, - "Sum 4 in column 'b' is not equal to limit: 10.0", + "Sum value 4 in column 'b' is not equal to limit: 10.0", None, "Max value 3 in column 'b' is not equal to limit: 5.0", ], @@ -483,11 +483,11 @@ def test_is_aggr_not_equal(spark: SparkSession): [ "c", None, - "Count 3 in column 'a' is equal to limit: 3", + "Count value 3 in column 'a' is equal to limit: 3", None, None, None, - "Average 2.0 in column 'b' is equal to limit: 2.0", + "Average value 2.0 in column 'b' is equal to limit: 2.0", None, "Min value 1 in column 'b' is equal to limit: 1.0", None, @@ -495,11 +495,11 @@ def test_is_aggr_not_equal(spark: SparkSession): [ "a", 1, - "Count 3 in column 'a' is equal to limit: 3", + "Count value 3 in column 'a' is equal to limit: 3", None, - "Count 1 in column 'a' per group of columns 'a' is equal to limit: 1", + "Count value 1 in column 'a' per group of columns 'a' is equal to limit: 1", None, - "Average 2.0 in column 'b' is equal to limit: 2.0", + "Average value 2.0 in column 'b' is equal to limit: 2.0", None, "Min value 1 in column 'b' is equal to limit: 1.0", None, @@ -507,11 +507,11 @@ def test_is_aggr_not_equal(spark: SparkSession): [ "b", 3, - "Count 3 in column 'a' is equal to limit: 3", + "Count value 3 in column 'a' is equal to limit: 3", None, - "Count 1 in column 'a' per group of columns 'a' is equal to limit: 1", + "Count value 1 in column 'a' per group of columns 'a' is equal to limit: 1", None, - "Average 2.0 in column 'b' is equal to limit: 2.0", + "Average value 2.0 in column 'b' is equal to limit: 2.0", None, "Min value 1 in column 'b' is equal to limit: 1.0", None, @@ -580,17 +580,17 @@ def test_is_aggr_with_count_distinct_and_group_by(spark: SparkSession): [ "group1", "val1", - "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", + "Distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1", ], [ "group1", "val1", - "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", + "Distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1", ], [ "group1", "val2", - "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", + "Distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1", ], ["group2", "val3", None], ["group2", "val3", None], @@ -636,12 +636,12 @@ 
def test_is_aggr_with_count_distinct_and_column_expression_in_group_by(spark: Sp [ "group1", "val1", - "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", + "Distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1", ], [ "group1", "val2", - "Distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", + "Distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1", ], ["group2", "val3", None], ["group2", "val3", None], @@ -677,17 +677,17 @@ def test_is_aggr_with_approx_count_distinct(spark: SparkSession): [ "group1", "val1", - "Approximate distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", + "Approximate distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1", ], [ "group1", "val1", - "Approximate distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", + "Approximate distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1", ], [ "group1", "val2", - "Approximate distinct value count 2 in column 'b' per group of columns 'a' is greater than limit: 1", + "Approximate distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1", ], ["group2", "val3", None], ["group2", "val3", None], @@ -811,10 +811,10 @@ def test_is_aggr_with_mode_function(spark: SparkSession): # groupA should fail (mode=401 > limit=400), groupB should pass (mode=200 <= limit=400) expected = spark.createDataFrame( [ - ["groupA", 401, "Mode 401 in column 'b' per group of columns 'a' is greater than limit: 400"], - ["groupA", 401, "Mode 401 in column 'b' per group of columns 'a' is greater than limit: 400"], - ["groupA", 401, "Mode 401 in column 'b' per group of columns 'a' is greater than limit: 400"], - ["groupA", 500, "Mode 401 in column 'b' per group of columns 'a' is greater than limit: 400"], + ["groupA", 401, "Mode value 401 in column 'b' per group of columns 'a' is greater than limit: 400"], + ["groupA", 401, "Mode value 401 in column 'b' per group of columns 'a' is greater than limit: 400"], + ["groupA", 401, "Mode value 401 in column 'b' per group of columns 'a' is greater than limit: 400"], + ["groupA", 500, "Mode value 401 in column 'b' per group of columns 'a' is greater than limit: 400"], ["groupB", 200, None], ["groupB", 200, None], ["groupB", 404, None], @@ -855,6 +855,30 @@ def test_is_aggr_with_percentile_functions(spark: SparkSession): assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True) +def test_is_aggr_message_includes_params(spark: SparkSession): + """Test that violation messages include aggr_params for differentiation.""" + test_df = spark.createDataFrame( + [("group1", 100.0), ("group1", 200.0), ("group1", 300.0)], + "a: string, b: double", + ) + + # Create a check that will fail - P50 is 200, limit is 100 + checks = [ + is_aggr_not_greater_than( + "b", limit=100.0, aggr_type="percentile", aggr_params={"percentile": 0.5}, group_by=["a"] + ), + ] + + actual = _apply_checks(test_df, checks) + + # Verify message includes the percentile parameter + # Column name includes group_by info: b_percentile_group_by_a_greater_than_limit + messages = [row.b_percentile_group_by_a_greater_than_limit for row in actual.collect()] + assert all(msg is not None for msg in messages), "All rows should have violation messages" + assert all("percentile=0.5" in msg for msg in messages), "Message should include percentile 
parameter" + assert all("Percentile (percentile=0.5) value" in msg for msg in messages), "Message should have correct format" + + def test_is_aggr_percentile_missing_params(spark: SparkSession): """Test that percentile functions require percentile parameter.""" test_df = spark.createDataFrame([(1, 10.0)], "id: int, value: double") @@ -886,6 +910,21 @@ def test_is_aggr_percentile_invalid_params_caught_by_spark(spark: SparkSession): result_df.collect() # Force evaluation to trigger Spark error +def test_is_aggr_with_invalid_parameter_name(spark: SparkSession): + """Test that invalid parameter names in aggr_params raise clear errors.""" + test_df = spark.createDataFrame([(1, 10.0), (2, 20.0)], "id: int, value: double") + + # Pass an invalid parameter name - should raise InvalidParameterError with context + with pytest.raises(InvalidParameterError, match="Failed to build 'approx_percentile' expression"): + _, apply_fn = is_aggr_not_greater_than( + "value", + limit=100.0, + aggr_type="approx_percentile", + aggr_params={"percentile": 0.95, "invalid_param": 1}, # Invalid param name + ) + apply_fn(test_df) + + def test_is_aggr_with_invalid_aggregate_function(spark: SparkSession): """Test that invalid aggregate function names raise clear errors.""" test_df = spark.createDataFrame([(1, 10)], "id: int, value: int") From 7692dd68f27e4fd3eda51435ffc8ec4cc1c16d55 Mon Sep 17 00:00:00 2001 From: Greg Hansen Date: Sat, 6 Dec 2025 16:36:09 -0500 Subject: [PATCH 24/24] Update docs --- .../docs/guide/quality_checks_definition.mdx | 234 ++++++------ docs/dqx/docs/reference/quality_checks.mdx | 339 +++++++----------- 2 files changed, 237 insertions(+), 336 deletions(-) diff --git a/docs/dqx/docs/guide/quality_checks_definition.mdx b/docs/dqx/docs/guide/quality_checks_definition.mdx index f26ce73c6..3d49839b6 100644 --- a/docs/dqx/docs/guide/quality_checks_definition.mdx +++ b/docs/dqx/docs/guide/quality_checks_definition.mdx @@ -34,6 +34,10 @@ You can define quality checks programmatically using a list of DQX classes (list Checks can be defined using DQX classes such as `DQRowRule`, `DQDatasetRule`, and `DQForEachColRule`. This approach provides static type checking and autocompletion in IDEs, making it potentially easier to work with. + +The validation of arguments and keyword arguments for the check function is automatically performed upon creating a `DQRowRule`. 
+ + ```python @@ -123,12 +127,14 @@ This approach provides static type checking and autocompletion in IDEs, making i column="col1", check_func_kwargs={"aggr_type": "avg", "group_by": ["col2"], "limit": 1.2}, ), + DQDatasetRule( # dataset check working across group of rows criticality="error", check_func=check_funcs.is_aggr_equal, column="col1", check_func_kwargs={"aggr_type": "count", "group_by": ["col2"], "limit": 5}, ), + DQDatasetRule( # dataset check working across group of rows criticality="error", check_func=check_funcs.is_aggr_not_equal, @@ -136,20 +142,18 @@ This approach provides static type checking and autocompletion in IDEs, making i check_func_kwargs={"aggr_type": "avg", "group_by": ["col2"], "limit": 10.5}, ), - # Extended aggregate functions for advanced data quality checks - - DQDatasetRule( # Uniqueness check: each country should have one country code + DQDatasetRule( # dataset check for distinct value count for groups (each group should have 1 value) criticality="error", check_func=check_funcs.is_aggr_not_greater_than, column="country_code", check_func_kwargs={ - "aggr_type": "count_distinct", # Exact distinct count (automatically uses two-stage aggregation) + "aggr_type": "count_distinct", # Exact distinct count "group_by": ["country"], "limit": 1 }, ), - DQDatasetRule( # Anomaly detection: detect unusual variance in sensor readings + DQDatasetRule( # dataset check for standard deviation for groups criticality="warn", check_func=check_funcs.is_aggr_not_greater_than, column="temperature", @@ -160,7 +164,7 @@ This approach provides static type checking and autocompletion in IDEs, making i }, ), - DQDatasetRule( # SLA monitoring: P95 latency must be under 1 second + DQDatasetRule( # dataset check for percentile with the percentile value passed using aggr_params criticality="error", check_func=check_funcs.is_aggr_not_greater_than, column="latency_ms", @@ -175,114 +179,6 @@ This approach provides static type checking and autocompletion in IDEs, making i - -The validation of arguments and keyword arguments for the check function is automatically performed upon creating a `DQRowRule`. 
- - -## Practical Use Cases for Extended Aggregates - -### Uniqueness Validation with count_distinct - -Ensure referential integrity by verifying that each entity has exactly one identifier: - -```yaml -# Each country should have exactly one country code -- criticality: error - check: - function: is_aggr_not_greater_than - arguments: - column: country_code - aggr_type: count_distinct - group_by: - - country - limit: 1 -``` - -### Anomaly Detection with Statistical Functions - -Detect unusual patterns in sensor data or business metrics: - -```yaml -# Alert on unusually high temperature variance per machine -- criticality: warn - check: - function: is_aggr_not_greater_than - arguments: - column: temperature - aggr_type: stddev - group_by: - - machine_id - limit: 5.0 - -# Monitor revenue stability across product lines -- criticality: error - check: - function: is_aggr_not_greater_than - arguments: - column: daily_revenue - aggr_type: variance - group_by: - - product_line - limit: 1000000.0 -``` - -### SLA and Performance Monitoring with Percentiles - -Monitor service performance and ensure SLA compliance: - -```yaml -# P95 API latency must be under 1 second -- criticality: error - check: - function: is_aggr_not_greater_than - arguments: - column: latency_ms - aggr_type: percentile - aggr_params: - percentile: 0.95 - limit: 1000 - -# P99 response time check (fast approximate) -- criticality: warn - check: - function: is_aggr_not_greater_than - arguments: - column: response_time_ms - aggr_type: approx_percentile - aggr_params: - percentile: 0.99 - accuracy: 10000 - limit: 5000 - -# Median baseline for order processing time -- criticality: warn - check: - function: is_aggr_not_less_than - arguments: - column: processing_time_sec - aggr_type: median - group_by: - - order_type - limit: 30.0 -``` - - -**Uniqueness & Cardinality:** -- `count_distinct` - Exact distinct counts (uniqueness validation) -- `approx_count_distinct` - Fast approximate counts (very large datasets) - -**Statistical Monitoring:** -- `stddev` / `variance` - Detect anomalies and inconsistencies -- `median` - Baseline checks, central tendency -- `mode` - Most frequent value (categorical data) - -**Performance & SLAs:** -- `percentile` - Exact P95/P99 for SLA compliance -- `approx_percentile` - Fast percentile estimates for large datasets - -**Learn more:** See all [Databricks aggregate functions](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) - - ### Checks defined using metadata (list of dictionaries) Checks can be defined using declarative syntax as a list of dictionaries. @@ -385,6 +281,7 @@ but you can also define a list of dictionaries directly in python. group_by: - col2 limit: 1.2 + - criticality: error check: function: is_aggr_equal @@ -394,6 +291,7 @@ but you can also define a list of dictionaries directly in python. group_by: - col2 limit: 5 + - criticality: error check: function: is_aggr_not_equal @@ -403,6 +301,36 @@ but you can also define a list of dictionaries directly in python. 
group_by: - col2 limit: 10.5 + + - criticality: error + check: + function: is_aggr_not_greater_than + arguments: + column: country_code + aggr_type: count_distinct + group_by: + - country + limit: 1 + + - criticality: warn + check: + function: is_aggr_not_greater_than + arguments: + column: temperature + aggr_type: stddev + group_by: + - machine_id + limit: 5.0 + + - criticality: error + check: + function: is_aggr_not_greater_than + arguments: + column: latency_ms + aggr_type: percentile + aggr_params: + percentile: 0.95 + limit: 1000 """) ``` @@ -518,6 +446,7 @@ Example `yaml` file defining several checks: group_by: - col2 limit: 1.2 + - criticality: error check: function: is_aggr_equal @@ -527,6 +456,7 @@ Example `yaml` file defining several checks: group_by: - col2 limit: 5 + - criticality: error check: function: is_aggr_not_equal @@ -536,6 +466,36 @@ Example `yaml` file defining several checks: group_by: - col2 limit: 10.5 + +- criticality: error + check: + function: is_aggr_not_greater_than + arguments: + column: country_code + aggr_type: count_distinct + group_by: + - country + limit: 1 + +- criticality: warn + check: + function: is_aggr_not_greater_than + arguments: + column: temperature + aggr_type: stddev + group_by: + - machine_id + limit: 5.0 + +- criticality: error + check: + function: is_aggr_not_greater_than + arguments: + column: latency_ms + aggr_type: percentile + aggr_params: + percentile: 0.95 + limit: 1000 ``` ## JSON format (declarative approach) @@ -679,6 +639,44 @@ Checks defined in JSON files are supported by the DQX workflows. "limit": 10.5 } } + }, + { + "criticality": "error", + "check": { + "function": "is_aggr_not_greater_than", + "arguments": { + "column": "country_code", + "aggr_type": "count_distinct", + "group_by": ["country"], + "limit": 1 + } + } + }, + { + "criticality": "warn", + "check": { + "function": "is_aggr_not_greater_than", + "arguments": { + "column": "temperature", + "aggr_type": "stddev", + "group_by": ["machine_id"], + "limit": 5.0 + } + } + }, + { + "criticality": "error", + "check": { + "function": "is_aggr_not_greater_than", + "arguments": { + "column": "latency_ms", + "aggr_type": "percentile", + "aggr_params": { + "percentile": 0.95 + }, + "limit": 1000 + } + } } ] ``` @@ -725,6 +723,10 @@ This validation ensures that the checks are correctly defined and can be interpr The validation cannot be used for checks defined programmatically using DQX classes. When checks are defined programmatically with DQX classes, syntax validation is unnecessary because the application will fail to interpret them if the DQX objects are constructed incorrectly. + +Validating quality rules are typically done as part of the CI/CD process to ensure checks are ready to use in the application. + + ```python @@ -761,7 +763,3 @@ When checks are defined programmatically with DQX classes, syntax validation is - 'checks_location': file or table location of the quality checks - - -Validating quality rules are typically done as part of the CI/CD process to ensure checks are ready to use in the application. - \ No newline at end of file diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx index 8de8cf22a..1ef154fc1 100644 --- a/docs/dqx/docs/reference/quality_checks.mdx +++ b/docs/dqx/docs/reference/quality_checks.mdx @@ -1600,223 +1600,6 @@ Complex data types are supported as well. group_by: - col3 limit: 200 -``` -
- -## Aggregate Function Types {#aggregate-function-types} - -DQX provides four aggregate check functions for validating metrics across your entire dataset or within groups: - -- **`is_aggr_not_greater_than`** - Validates that aggregated values do not exceed a limit -- **`is_aggr_not_less_than`** - Validates that aggregated values meet a minimum threshold -- **`is_aggr_equal`** - Validates that aggregated values equal an expected value -- **`is_aggr_not_equal`** - Validates that aggregated values differ from a specific value - -Use these functions for scenarios like "each country has ≤ 1 country code" or "P95 latency < 1s per region". - -**When to use aggregate checks:** -- Validate cardinality and uniqueness constraints -- Monitor statistical properties (variance, outliers) -- Enforce SLAs and performance thresholds -- Detect data quality issues within groups - -DQX supports 20 **curated** aggregate functions (recommended for data quality) plus any other [Databricks built-in aggregate function](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha). - -### Curated Aggregate Functions {#curated-aggregate-functions} - -Validated and optimized for data quality use cases: - -#### Basic Aggregations -- `count` - Count records (nulls excluded) -- `sum` - Sum of values -- `avg` - Average of values -- `min` - Minimum value -- `max` - Maximum value - -#### Cardinality & Uniqueness -- `count_distinct` - Exact distinct count (use for uniqueness validation) -- `approx_count_distinct` - Fast approximate count (±5% accuracy, ideal for very large datasets) -- `count_if` - Conditional counting - -#### Statistical Analysis -- `stddev` / `stddev_samp` - Sample standard deviation (anomaly detection) -- `stddev_pop` - Population standard deviation -- `variance` / `var_samp` - Sample variance (stability checks) -- `var_pop` - Population variance -- `median` - 50th percentile baseline -- `mode` - Most frequent value (categorical data quality) -- `skewness` - Distribution skewness (detect asymmetry) -- `kurtosis` - Distribution kurtosis (detect heavy tails) - -#### SLA & Performance Monitoring -- `percentile` - Exact percentile (requires `aggr_params`) -- `approx_percentile` - Approximate percentile (faster, requires `aggr_params`) - - -Using `count_distinct` with `group_by` performs multiple expensive operations. **For large-scale datasets, use `approx_count_distinct` instead**—it processes efficiently and scales well with high-cardinality groups. - -**Performance benchmarks**: See [benchmarks](/docs/reference/benchmarks) for detailed comparison. - - -### Using `aggr_params` for Function Parameters {#using-aggr_params-for-function-parameters} - -Some aggregate functions require additional parameters beyond the column. Pass these as a dictionary: - -**Python:** -```python -check_func_kwargs={ - "aggr_type": "percentile", - "aggr_params": {"percentile": 0.95}, # Single parameter - "limit": 1000 -} - -# Multiple parameters: -check_func_kwargs={ - "aggr_type": "approx_percentile", - "aggr_params": { - "percentile": 0.99, - "accuracy": 10000 # Multiple parameters as dict - } -} -``` - -**YAML:** -```yaml -aggr_params: - percentile: 0.95 - accuracy: 10000 # Multiple parameters as nested YAML -``` - -The `aggr_params` dictionary is unpacked as keyword arguments to the Spark aggregate function. See [Databricks function documentation](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) for available parameters for each function. 
- -### Non-Curated Aggregate Functions {#non-curated-aggregates} - -Any [Databricks built-in aggregate function](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) can be used beyond the curated list. DQX will: -- Issue a warning (recommending curated functions when possible) -- Validate the function returns numeric values - -**Limitations:** -- Functions returning arrays, structs, or maps (e.g., `collect_list`, `collect_set`) will fail with a clear error message explaining they cannot be compared to numeric limits -- **User-Defined Aggregate Functions (UDAFs) are not currently supported** - only Databricks built-in functions in `pyspark.sql.functions` - - -Some aggregate functions require newer Spark versions. If a function is unavailable, the error message will indicate the minimum required Spark version. **Use the latest Databricks Runtime for best compatibility.** - - -### Extended Examples {#aggregate-extended-examples} - -```yaml -# Uniqueness: Each country has exactly one country code -- criticality: error - check: - function: is_aggr_not_greater_than - arguments: - column: country_code - aggr_type: count_distinct - group_by: - - country - limit: 1 - -# Cardinality: Total unique users across entire dataset -- criticality: error - check: - function: is_aggr_not_greater_than - arguments: - column: user_id - aggr_type: count_distinct - limit: 1000000 - -# Fast approximate count for very large datasets -- criticality: error - check: - function: is_aggr_not_greater_than - arguments: - column: user_id - aggr_type: approx_count_distinct - group_by: - - session_date - limit: 100000 - -# Standard deviation: Detect unusual variance in sensor readings per machine -- criticality: warn - check: - function: is_aggr_not_greater_than - arguments: - column: temperature - aggr_type: stddev - group_by: - - machine_id - limit: 5.0 - -# Percentile: Monitor P95 latency for SLA compliance -- criticality: error - check: - function: is_aggr_not_greater_than - arguments: - column: latency_ms - aggr_type: percentile - aggr_params: - percentile: 0.95 - limit: 1000 - -# Approximate percentile: Fast P99 check with accuracy control -- criticality: warn - check: - function: is_aggr_not_greater_than - arguments: - column: response_time - aggr_type: approx_percentile - aggr_params: - percentile: 0.99 - accuracy: 10000 - limit: 5000 - -# Median: Baseline check for central tendency -- criticality: warn - check: - function: is_aggr_not_less_than - arguments: - column: order_value - aggr_type: median - group_by: - - product_category - limit: 50.0 - -# Variance: Stability check for financial data -- criticality: error - check: - function: is_aggr_not_greater_than - arguments: - column: daily_revenue - aggr_type: variance - limit: 1000000.0 - -# Mode: Alert if any single error code dominates -- criticality: warn - check: - function: is_aggr_not_greater_than - arguments: - column: error_code - aggr_type: mode - group_by: - - service_name - limit: 100 - -# count_if: Monitor error rate per service -- criticality: error - check: - function: is_aggr_not_greater_than - arguments: - column: status_code >= 500 - aggr_type: count_if - group_by: - - service_name - limit: 10 -``` - -
-**Checks defined in YAML (continued)** -```yaml # foreign_key check using reference DataFrame - criticality: error check: @@ -2387,9 +2170,129 @@ The reference DataFrames are used in selected Dataset-level checks: ```
+### Usage examples for aggregate checks + +DQX supports several curated aggregate functions for use with aggregate value checks. Use `aggr_params` when defining aggregate checks to pass additional keyword arguments to aggregate functions. To use other aggregate functions, see [using non-curated aggregate functions](#using-non-curated-aggregate-functions). + +
+
+**Curated aggregate functions**
+
+| Function | Description | Parameters |
+| -------- | ----------- | ---------- |
+| `count` | Count of values | |
+| `sum` | Sum of values | |
+| `avg` | Arithmetic mean of values | |
+| `min` | Minimum value | |
+| `max` | Maximum value | |
+| `count_distinct` | Distinct count of values | |
+| `approx_count_distinct` | Approximate distinct count of values; efficient for large datasets | |
+| `stddev` | Sample standard deviation of values | |
+| `stddev_pop` | Population standard deviation of values | |
+| `variance` | Sample variance of values | |
+| `var_pop` | Population variance of values | |
+| `median` | Median (50th percentile) value | |
+| `mode` | Most frequent value | |
+| `skewness` | Degree of asymmetry in the distribution of values | |
+| `kurtosis` | Degree of tailedness in the distribution of values | |
+| `percentile` | Exact percentile value | `percentile`: Fraction of values (between 0 and 1) that are less than or equal to the result value |
+| `approx_percentile` | Approximate percentile value; efficient for large datasets | `percentile`: Fraction of values (between 0 and 1) that are less than or equal to the result value; `accuracy`: Optional positive numeric value controlling the approximation accuracy (default 10000). The relative estimation error is *1.0 / accuracy*. |
+
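+For checks defined programmatically with DQX classes, `aggr_params` is passed as a plain dictionary inside `check_func_kwargs`. The following is a minimal sketch mirroring the P95 latency example shown later in this section (import paths are assumed to follow the definition guide):
+
+```python
+from databricks.labs.dqx import check_funcs
+from databricks.labs.dqx.rule import DQDatasetRule  # import path assumed; adjust to your DQX version
+
+checks = [
+    DQDatasetRule(  # P95 latency must be under 1 second
+        criticality="error",
+        check_func=check_funcs.is_aggr_not_greater_than,
+        column="latency_ms",
+        check_func_kwargs={
+            "aggr_type": "percentile",
+            # aggr_params is forwarded as keyword arguments to the aggregate function
+            "aggr_params": {"percentile": 0.95},
+            "limit": 1000,
+        },
+    ),
+]
+```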
+
+#### Counting Distinct Values
+
+Distinct value counts can be used to validate uniqueness and consistency of values, for example to enforce referential integrity between datasets or to ensure consistent column values.
+
+
+Using `count_distinct` with `group_by` performs multiple expensive operations. For large-scale datasets where an exact count is not required, use `approx_count_distinct` for best performance.
+
+
+```yaml
+# Ensure only 1 country code exists per country
+- criticality: error
+  check:
+    function: is_aggr_not_greater_than
+    arguments:
+      column: country_code
+      aggr_type: count_distinct
+      group_by:
+      - country
+      limit: 1
+```
+
+#### Detecting Statistical Anomalies
+
+Descriptive statistics are often used for process control, forecasting, and regression analysis. Variation from expected values could signal that a process has changed or that a model's features have drifted. Monitor dataset statistics (e.g. `avg`, `stddev`) to detect significant variations.
+
+```yaml
+# Detect unusually high temperature variability per machine
+- criticality: warn
+  check:
+    function: is_aggr_not_greater_than
+    arguments:
+      column: temperature
+      aggr_type: stddev
+      group_by:
+      - machine_id
+      limit: 5.0
+
+# Detect excessive revenue variance across product lines
+- criticality: error
+  check:
+    function: is_aggr_not_greater_than
+    arguments:
+      column: daily_revenue
+      aggr_type: variance
+      group_by:
+      - product_line
+      limit: 1000000.0
+```
+
+#### Monitoring for Service Level Agreements
+
+Percentiles are useful for monitoring service-level agreements (SLAs). Monitor percentiles to detect when service metrics (e.g. API response time) violate their agreed thresholds.
+
+
+`percentile` calculates an exact percentile using all values in the input data. Use `approx_percentile` for scalable percentile computation when exact values are not required.
+
+
+```yaml
+# Detect when the P95 API latency is slower than 1 second
+- criticality: error
+  check:
+    function: is_aggr_not_greater_than
+    arguments:
+      column: latency_ms
+      aggr_type: percentile
+      aggr_params:
+        percentile: 0.95
+      limit: 1000
+
+# Detect when the P99 API latency is slower than 1 second using approximate percentiles
+- criticality: warn
+  check:
+    function: is_aggr_not_greater_than
+    arguments:
+      column: latency_ms
+      aggr_type: approx_percentile
+      aggr_params:
+        percentile: 0.99
+        accuracy: 10000
+      limit: 1000
+```
+
+#### Using Non-Curated Aggregate Functions
+
+Other [Databricks built-in aggregate functions](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) can also be used to check aggregate values.
+
+
+Using non-curated aggregate functions is supported with the following limitations:
+- Functions returning arrays, structs, or maps (e.g., `collect_list`, `collect_set`) are not supported
+- User-Defined Aggregate Functions (UDAFs) are not supported
+- Some aggregate functions are only available in newer versions of the Databricks Runtime. If a function is unavailable in your environment, DQX will raise an error indicating the minimum required runtime version. Use the latest Databricks Runtime version for best compatibility.
+
+
 ## Creating Custom Row-level Checks
 
 ### Using SQL Expression
 
 You can define custom checks using SQL Expression rule (`sql_expression`).