diff --git a/demos/dqx_demo_library.py b/demos/dqx_demo_library.py
index d414317cb..19d3bbbe3 100644
--- a/demos/dqx_demo_library.py
+++ b/demos/dqx_demo_library.py
@@ -690,6 +690,227 @@ def not_ends_with(column: str, suffix: str) -> Column:
# COMMAND ----------
+# MAGIC %md
+# MAGIC ### Extended Aggregate Functions for Data Quality Checks
+# MAGIC
+# MAGIC DQX now supports 20 curated aggregate functions for advanced data quality monitoring:
+# MAGIC - **Statistical functions**: `stddev`, `variance`, `median`, `mode`, `skewness`, `kurtosis` for anomaly detection
+# MAGIC - **Percentile functions**: `percentile`, `approx_percentile` for SLA monitoring
+# MAGIC - **Cardinality functions**: `count_distinct`, `approx_count_distinct` (uses HyperLogLog++)
+# MAGIC - **Any Databricks built-in aggregate**: Supported with runtime validation
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC #### Example 1: Statistical Functions - Anomaly Detection with Standard Deviation
+# MAGIC
+# MAGIC Detect unusual variance in sensor readings per machine. High standard deviation indicates unstable sensors that may need calibration.
+
+# COMMAND ----------
+
+from databricks.labs.dqx.engine import DQEngine
+from databricks.labs.dqx.rule import DQDatasetRule
+from databricks.labs.dqx import check_funcs
+from databricks.sdk import WorkspaceClient
+
+# Manufacturing sensor data with readings from multiple machines
+manufacturing_df = spark.createDataFrame([
+ ["M1", "2024-01-01", 20.1],
+ ["M1", "2024-01-02", 20.3],
+ ["M1", "2024-01-03", 20.2], # Machine 1: stable readings (low stddev)
+ ["M2", "2024-01-01", 18.5],
+ ["M2", "2024-01-02", 25.7],
+ ["M2", "2024-01-03", 15.2], # Machine 2: unstable readings (high stddev) - should FAIL
+ ["M3", "2024-01-01", 19.8],
+ ["M3", "2024-01-02", 20.1],
+ ["M3", "2024-01-03", 19.9], # Machine 3: stable readings
+], "machine_id: string, date: string, temperature: double")
+
+# Quality check: Standard deviation should not exceed 3.0 per machine
+checks = [
+ DQDatasetRule(
+ criticality="error",
+ check_func=check_funcs.is_aggr_not_greater_than,
+ column="temperature",
+ check_func_kwargs={
+ "aggr_type": "stddev",
+ "group_by": ["machine_id"],
+ "limit": 3.0
+ },
+ ),
+]
+
+dq_engine = DQEngine(WorkspaceClient())
+result_df = dq_engine.apply_checks(manufacturing_df, checks)
+display(result_df)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC #### Example 2: Approximate Aggregate Functions - Efficient Cardinality Estimation
+# MAGIC
+# MAGIC **`approx_count_distinct`** provides fast, memory-efficient cardinality estimation for large datasets.
+# MAGIC
+# MAGIC **From [Databricks Documentation](https://docs.databricks.com/aws/en/sql/language-manual/functions/approx_count_distinct.html):**
+# MAGIC - Uses **HyperLogLog++** (HLL++) algorithm, a state-of-the-art cardinality estimator
+# MAGIC - **Accurate within 5%** by default (configurable via `relativeSD` parameter)
+# MAGIC - **Memory efficient**: Uses fixed memory regardless of cardinality
+# MAGIC - **Ideal for**: High-cardinality columns, large datasets, real-time analytics
+# MAGIC
+# MAGIC **Use Case**: Monitor daily active users without expensive exact counting.
+
+# COMMAND ----------
+
+# User activity data with high cardinality
+user_activity_df = spark.createDataFrame([
+ ["2024-01-01", f"user_{i}"] for i in range(1, 95001) # 95,000 distinct users on day 1
+] + [
+ ["2024-01-02", f"user_{i}"] for i in range(1, 50001) # 50,000 distinct users on day 2
+], "activity_date: string, user_id: string")
+
+# Quality check: Ensure daily active users don't drop below 60,000
+# Using approx_count_distinct is much faster than count_distinct for large datasets
+checks = [
+ DQDatasetRule(
+ criticality="warn",
+ check_func=check_funcs.is_aggr_not_less_than,
+ column="user_id",
+ check_func_kwargs={
+ "aggr_type": "approx_count_distinct", # Fast approximate counting
+ "group_by": ["activity_date"],
+ "limit": 60000
+ },
+ ),
+]
+
+result_df = dq_engine.apply_checks(user_activity_df, checks)
+display(result_df)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC #### Example 3: Non-Curated Aggregate Functions with Runtime Validation
+# MAGIC
+# MAGIC DQX supports any Databricks built-in aggregate function beyond the curated list:
+# MAGIC - **Warning**: Functions outside the curated list trigger a `UserWarning`
+# MAGIC - **Runtime validation**: Ensures the function returns numeric values compatible with comparisons
+# MAGIC - **Graceful errors**: Invalid aggregates (e.g., `collect_list` returning arrays) fail with clear messages
+# MAGIC
+# MAGIC **Note**: User-Defined Aggregate Functions (UDAFs) are not currently supported.
+
+# COMMAND ----------
+
+import warnings
+
+# Sensor data with multiple readings per sensor
+sensor_sample_df = spark.createDataFrame([
+ ["S1", 45.2],
+ ["S1", 45.8],
+ ["S2", 78.1],
+ ["S2", 78.5],
+], "sensor_id: string, reading: double")
+
+# Using a valid but non-curated aggregate function: any_value
+# This will work but produce a warning
+checks = [
+ DQDatasetRule(
+ criticality="warn",
+ check_func=check_funcs.is_aggr_not_greater_than,
+ column="reading",
+ check_func_kwargs={
+ "aggr_type": "any_value", # Not in curated list - triggers warning
+ "group_by": ["sensor_id"],
+ "limit": 100.0
+ },
+ ),
+]
+
+# Capture warnings during execution
+with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ result_df = dq_engine.apply_checks(sensor_sample_df, checks)
+
+ # Display warning message if present
+ if w:
+ print(f"⚠️ Warning: {w[0].message}")
+
+display(result_df)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC #### Example 4: Percentile Functions for SLA Monitoring
+# MAGIC
+# MAGIC Monitor P95 latency to ensure 95% of API requests meet SLA requirements.
+# MAGIC
+# MAGIC **Using `aggr_params`:** Pass aggregate function parameters as a dictionary.
+# MAGIC - Single parameter: `aggr_params={"percentile": 0.95}`
+# MAGIC - Multiple parameters: `aggr_params={"percentile": 0.99, "accuracy": 10000}`
+
+# COMMAND ----------
+
+# API latency data in milliseconds
+api_latency_df = spark.createDataFrame([
+ ["2024-01-01", i * 10.0] for i in range(1, 101) # 10ms to 1000ms latencies
+], "date: string, latency_ms: double")
+
+# Quality check: P95 latency must be under 950ms
+checks = [
+ DQDatasetRule(
+ criticality="error",
+ check_func=check_funcs.is_aggr_not_greater_than,
+ column="latency_ms",
+ check_func_kwargs={
+ "aggr_type": "percentile",
+ "aggr_params": {"percentile": 0.95}, # aggr_params as dict
+ "group_by": ["date"],
+ "limit": 950.0
+ },
+ ),
+]
+
+result_df = dq_engine.apply_checks(api_latency_df, checks)
+display(result_df)
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC #### Example 5: Uniqueness Validation with Count Distinct
+# MAGIC
+# MAGIC Ensure referential integrity: each country should have exactly one country code.
+# MAGIC
+# MAGIC Use `count_distinct` for exact cardinality validation across groups.
+
+# COMMAND ----------
+
+# Country data with potential duplicates
+country_df = spark.createDataFrame([
+ ["US", "USA"],
+ ["US", "USA"], # OK: same code
+ ["FR", "FRA"],
+ ["FR", "FRN"], # ERROR: different codes for same country
+ ["DE", "DEU"],
+], "country: string, country_code: string")
+
+# Quality check: Each country must have exactly one distinct country code
+checks = [
+ DQDatasetRule(
+ criticality="error",
+ check_func=check_funcs.is_aggr_not_greater_than,
+ column="country_code",
+ check_func_kwargs={
+ "aggr_type": "count_distinct", # Exact distinct count per group
+ "group_by": ["country"],
+ "limit": 1
+ },
+ ),
+]
+
+result_df = dq_engine.apply_checks(country_df, checks)
+display(result_df)
+
+# COMMAND ----------
+
# MAGIC %md
# MAGIC ### Creating custom dataset-level checks
# MAGIC Requirement: Fail all readings from a sensor if any reading for that sensor exceeds a specified threshold from the sensor specification table.
diff --git a/docs/dqx/docs/guide/quality_checks_definition.mdx b/docs/dqx/docs/guide/quality_checks_definition.mdx
index 621824e4f..f26ce73c6 100644
--- a/docs/dqx/docs/guide/quality_checks_definition.mdx
+++ b/docs/dqx/docs/guide/quality_checks_definition.mdx
@@ -135,6 +135,41 @@ This approach provides static type checking and autocompletion in IDEs, making i
column="col1",
check_func_kwargs={"aggr_type": "avg", "group_by": ["col2"], "limit": 10.5},
),
+
+ # Extended aggregate functions for advanced data quality checks
+
+ DQDatasetRule( # Uniqueness check: each country should have one country code
+ criticality="error",
+ check_func=check_funcs.is_aggr_not_greater_than,
+ column="country_code",
+ check_func_kwargs={
+ "aggr_type": "count_distinct", # Exact distinct count (automatically uses two-stage aggregation)
+ "group_by": ["country"],
+ "limit": 1
+ },
+ ),
+
+ DQDatasetRule( # Anomaly detection: detect unusual variance in sensor readings
+ criticality="warn",
+ check_func=check_funcs.is_aggr_not_greater_than,
+ column="temperature",
+ check_func_kwargs={
+ "aggr_type": "stddev",
+ "group_by": ["machine_id"],
+ "limit": 5.0
+ },
+ ),
+
+ DQDatasetRule( # SLA monitoring: P95 latency must be under 1 second
+ criticality="error",
+ check_func=check_funcs.is_aggr_not_greater_than,
+ column="latency_ms",
+ check_func_kwargs={
+ "aggr_type": "percentile",
+ "aggr_params": {"percentile": 0.95},
+ "limit": 1000
+ },
+ ),
]
```
@@ -144,6 +179,110 @@ This approach provides static type checking and autocompletion in IDEs, making i
The validation of arguments and keyword arguments for the check function is automatically performed upon creating a `DQRowRule`.
+### Practical Use Cases for Extended Aggregates
+
+#### Uniqueness Validation with count_distinct
+
+Ensure referential integrity by verifying that each entity has exactly one identifier:
+
+```yaml
+# Each country should have exactly one country code
+- criticality: error
+ check:
+ function: is_aggr_not_greater_than
+ arguments:
+ column: country_code
+ aggr_type: count_distinct
+ group_by:
+ - country
+ limit: 1
+```
+
+#### Anomaly Detection with Statistical Functions
+
+Detect unusual patterns in sensor data or business metrics:
+
+```yaml
+# Alert on unusually high temperature variance per machine
+- criticality: warn
+ check:
+ function: is_aggr_not_greater_than
+ arguments:
+ column: temperature
+ aggr_type: stddev
+ group_by:
+ - machine_id
+ limit: 5.0
+
+# Monitor revenue stability across product lines
+- criticality: error
+ check:
+ function: is_aggr_not_greater_than
+ arguments:
+ column: daily_revenue
+ aggr_type: variance
+ group_by:
+ - product_line
+ limit: 1000000.0
+```
+
+#### SLA and Performance Monitoring with Percentiles
+
+Monitor service performance and ensure SLA compliance:
+
+```yaml
+# P95 API latency must be under 1 second
+- criticality: error
+ check:
+ function: is_aggr_not_greater_than
+ arguments:
+ column: latency_ms
+ aggr_type: percentile
+ aggr_params:
+ percentile: 0.95
+ limit: 1000
+
+# P99 response time check (fast approximate)
+- criticality: warn
+ check:
+ function: is_aggr_not_greater_than
+ arguments:
+ column: response_time_ms
+ aggr_type: approx_percentile
+ aggr_params:
+ percentile: 0.99
+ accuracy: 10000
+ limit: 5000
+
+# Median baseline for order processing time
+- criticality: warn
+ check:
+ function: is_aggr_not_less_than
+ arguments:
+ column: processing_time_sec
+ aggr_type: median
+ group_by:
+ - order_type
+ limit: 30.0
+```
+
+
+#### Choosing the Right Aggregate Function
+
+**Uniqueness & Cardinality:**
+- `count_distinct` - Exact distinct counts (uniqueness validation)
+- `approx_count_distinct` - Fast approximate counts (very large datasets)
+
+**Statistical Monitoring:**
+- `stddev` / `variance` - Detect anomalies and inconsistencies
+- `median` - Baseline checks, central tendency
+- `mode` - Most frequent value (categorical data)
+
+**Performance & SLAs:**
+- `percentile` - Exact P95/P99 for SLA compliance
+- `approx_percentile` - Fast percentile estimates for large datasets
+
+**Learn more:** See all [Databricks aggregate functions](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha)
+
+
### Checks defined using metadata (list of dictionaries)
Checks can be defined using declarative syntax as a list of dictionaries.
diff --git a/docs/dqx/docs/reference/benchmarks.mdx b/docs/dqx/docs/reference/benchmarks.mdx
index 5bddeb18e..18c1bc423 100644
--- a/docs/dqx/docs/reference/benchmarks.mdx
+++ b/docs/dqx/docs/reference/benchmarks.mdx
@@ -57,6 +57,9 @@ sidebar_position: 13
| test_benchmark_has_valid_schema | 0.172078 | 0.172141 | 0.163793 | 0.181081 | 0.006715 | 0.009295 | 0.167010 | 0.176305 | 6 | 0 | 2 | 5.81 |
| test_benchmark_has_x_coordinate_between | 0.217192 | 0.213656 | 0.209310 | 0.236233 | 0.011150 | 0.012638 | 0.209410 | 0.222048 | 5 | 0 | 1 | 4.60 |
| test_benchmark_has_y_coordinate_between | 0.218497 | 0.219630 | 0.209352 | 0.234111 | 0.010103 | 0.013743 | 0.209584 | 0.223327 | 5 | 0 | 1 | 4.58 |
+| test_benchmark_is_aggr_approx_count_distinct_with_group_by | 0.242859 | 0.239567 | 0.223938 | 0.272970 | 0.018260 | 0.017771 | 0.232410 | 0.250181 | 5 | 0 | 2 | 4.12 |
+| test_benchmark_is_aggr_count_distinct_no_group_by | 0.277439 | 0.272561 | 0.250425 | 0.320653 | 0.027686 | 0.038256 | 0.256090 | 0.294346 | 5 | 0 | 1 | 3.60 |
+| test_benchmark_is_aggr_count_distinct_with_group_by | 0.248015 | 0.249024 | 0.231550 | 0.270066 | 0.016044 | 0.026386 | 0.233048 | 0.259434 | 5 | 0 | 2 | 4.03 |
| test_benchmark_is_aggr_equal | 0.304401 | 0.305693 | 0.266624 | 0.330403 | 0.026888 | 0.044641 | 0.284540 | 0.329181 | 5 | 0 | 1 | 3.29 |
| test_benchmark_is_aggr_not_equal | 0.296462 | 0.296800 | 0.275119 | 0.312035 | 0.013498 | 0.013448 | 0.291054 | 0.304502 | 5 | 0 | 2 | 3.37 |
| test_benchmark_is_aggr_not_greater_than | 0.307771 | 0.315185 | 0.277924 | 0.316280 | 0.016705 | 0.010701 | 0.304974 | 0.315675 | 5 | 1 | 1 | 3.25 |
diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx
index cf364d5c6..8de8cf22a 100644
--- a/docs/dqx/docs/reference/quality_checks.mdx
+++ b/docs/dqx/docs/reference/quality_checks.mdx
@@ -1387,10 +1387,10 @@ You can also define your own custom dataset-level checks (see [Creating custom c
| Check | Description | Arguments |
| ---------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `is_unique` | Checks whether the values in the input column are unique and reports an issue for each row that contains a duplicate value. It supports uniqueness check for multiple columns (composite key). Null values are not considered duplicates by default, following the ANSI SQL standard. | `columns`: columns to check (can be a list of column names or column expressions); `nulls_distinct`: controls how null values are treated, default is True, thus nulls are not duplicates, eg. (NULL, NULL) not equals (NULL, NULL) and (1, NULL) not equals (1, NULL) |
-| `is_aggr_not_greater_than` | Checks whether the aggregated values over group of rows or all rows are not greater than the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression; `aggr_type`: aggregation function to use, such as "count" (default), "sum", "avg", "min", and "max"; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default) |
-| `is_aggr_not_less_than` | Checks whether the aggregated values over group of rows or all rows are not less than the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression; `aggr_type`: aggregation function to use, such as "count" (default), "sum", "avg", "min", and "max"; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default) |
-| `is_aggr_equal` | Checks whether the aggregated values over group of rows or all rows are equal to the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression, where string literals must be single quoted, e.g. 'string_value'; `aggr_type`: aggregation function to use, such as "count" (default), "sum", "avg", "min", and "max"; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default) |
-| `is_aggr_not_equal` | Checks whether the aggregated values over group of rows or all rows are not equal to the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression, where string literals must be single quoted, e.g. 'string_value'; `aggr_type`: aggregation function to use, such as "count" (default), "sum", "avg", "min", and "max"; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default) |
+| `is_aggr_not_greater_than` | Checks whether the aggregated values over group of rows or all rows are not greater than the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression (string literals must be single quoted, e.g. 'string_value'); `aggr_type`: aggregation function (default: "count"), supports 20 curated functions (count, sum, avg, stddev, percentile, etc.) plus any Databricks built-in aggregate; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default); `row_filter`: (optional) SQL expression to filter rows before aggregation; `aggr_params`: (optional) dict of parameters for aggregates requiring them |
+| `is_aggr_not_less_than` | Checks whether the aggregated values over group of rows or all rows are not less than the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression (string literals must be single quoted, e.g. 'string_value'); `aggr_type`: aggregation function (default: "count"), supports 20 curated functions (count, sum, avg, stddev, percentile, etc.) plus any Databricks built-in aggregate; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default); `row_filter`: (optional) SQL expression to filter rows before aggregation; `aggr_params`: (optional) dict of parameters for aggregates requiring them |
+| `is_aggr_equal` | Checks whether the aggregated values over group of rows or all rows are equal to the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression (string literals must be single quoted, e.g. 'string_value'); `aggr_type`: aggregation function (default: "count"), supports 20 curated functions (count, sum, avg, stddev, percentile, etc.) plus any Databricks built-in aggregate; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default); `row_filter`: (optional) SQL expression to filter rows before aggregation; `aggr_params`: (optional) dict of parameters for aggregates requiring them |
+| `is_aggr_not_equal` | Checks whether the aggregated values over group of rows or all rows are not equal to the provided limit. | `column`: column to check (can be a string column name or a column expression), optional for 'count' aggregation; `limit`: limit as number, column name or sql expression (string literals must be single quoted, e.g. 'string_value'); `aggr_type`: aggregation function (default: "count"), supports 20 curated functions (count, sum, avg, stddev, percentile, etc.) plus any Databricks built-in aggregate; `group_by`: (optional) list of columns or column expressions to group the rows for aggregation (no grouping by default); `row_filter`: (optional) SQL expression to filter rows before aggregation; `aggr_params`: (optional) dict of parameters for aggregates requiring them |
| `foreign_key` (aka is_in_list) | Checks whether input column or columns can be found in the reference DataFrame or Table (foreign key check). It supports foreign key check on single and composite keys. This check can be used to validate whether values in the input column(s) exist in a predefined list of allowed values (stored in the reference DataFrame or Table). It serves as a scalable alternative to `is_in_list` row-level checks, when working with large lists. | `columns`: columns to check (can be a list of string column names or column expressions); `ref_columns`: columns to check for existence in the reference DataFrame or Table (can be a list string column name or a column expression); `ref_df_name`: (optional) name of the reference DataFrame (dictionary of DataFrames can be passed when applying checks); `ref_table`: (optional) fully qualified reference table name; either `ref_df_name` or `ref_table` must be provided but never both; the number of passed `columns` and `ref_columns` must match and keys are checks in the given order; negate: if True the condition is negated (i.e. the check fails when the foreign key values exist in the reference DataFrame/Table), if False the check fails when the foreign key values do not exist in the reference |
| `sql_query` | Checks whether the condition column produced by a SQL query is satisfied. The check expects the query to return a boolean condition column indicating whether a record meets the requirement (True = fail, False = pass), and one or more merge columns so that results can be joined back to the input DataFrame to preserve all original records. Important considerations: if merge columns aren't unique, multiple query rows can attach to a single input row, potentially causing false positives. Performance tip: since the check must join back to the input DataFrame to retain original records, writing a custom dataset-level rule is usually more performant than `sql_query` check. | `query`: query string, must return all merge columns and condition column; `input_placeholder`: name to be used in the sql query as `{{ input_placeholder }}` to refer to the input DataFrame, optional reference DataFrames are referred by the name provided in the dictionary of reference DataFrames (e.g. `{{ ref_df_key }}`, dictionary of DataFrames can be passed when applying checks); `merge_columns`: list of columns used for merging with the input DataFrame which must exist in the input DataFrame and be present in output of the sql query; `condition_column`: name of the column indicating a violation (False = pass, True = fail); `msg`: (optional) message to output; `name`: (optional) name of the resulting check (it can be overwritten by `name` specified at the check level); `negate`: if the condition should be negated |
| `compare_datasets` | Compares two DataFrames at both row and column levels, providing detailed information about differences, including new or missing rows and column-level changes. Only columns present in both the source and reference DataFrames are compared. Use with caution if `check_missing_records` is enabled, as this may increase the number of rows in the output beyond the original input DataFrame. The comparison does not support Map types (any column comparison on map type is skipped automatically). Comparing datasets is valuable for validating data during migrations, detecting drift, performing regression testing, or verifying synchronization between source and target systems. | `columns`: columns to use for row matching with the reference DataFrame (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'), if not having primary keys or wanting to match against all columns you can pass 'df.columns'; `ref_columns`: list of columns in the reference DataFrame or Table to row match against the source DataFrame (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'), if not having primary keys or wanting to match against all columns you can pass 'ref_df.columns'; note that `columns` are matched with `ref_columns` by position, so the order of the provided columns in both lists must be exactly aligned; `exclude_columns`: (optional) list of columns to exclude from the value comparison but not from row matching (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'); the `exclude_columns` field does not alter the list of columns used to determine row matches (columns), it only controls which columns are skipped during the value comparison; `ref_df_name`: (optional) name of the reference DataFrame (dictionary of DataFrames can be passed when applying checks); `ref_table`: (optional) fully qualified reference table name; either `ref_df_name` or `ref_table` must be provided but never both; the number of passed `columns` and `ref_columns` must match and keys are checks in the given order; `check_missing_records`: perform a FULL OUTER JOIN to identify records that are missing from source or reference DataFrames, default is False; use with caution as this may increase the number of rows in the output, as unmatched rows from both sides are included; `null_safe_row_matching`: (optional) treat NULLs as equal when matching rows using `columns` and `ref_columns` (default: True); `null_safe_column_value_matching`: (optional) treat NULLs as equal when comparing column values (default: True) |
@@ -1600,7 +1600,223 @@ Complex data types are supported as well.
group_by:
- col3
limit: 200
+```
+
+
+## Aggregate Function Types {#aggregate-function-types}
+
+DQX provides four aggregate check functions for validating metrics across your entire dataset or within groups:
+
+- **`is_aggr_not_greater_than`** - Validates that aggregated values do not exceed a limit
+- **`is_aggr_not_less_than`** - Validates that aggregated values meet a minimum threshold
+- **`is_aggr_equal`** - Validates that aggregated values equal an expected value
+- **`is_aggr_not_equal`** - Validates that aggregated values differ from a specific value
+
+Use these functions for scenarios like "each country has ≤ 1 country code" or "P95 latency < 1s per region".
+
+**When to use aggregate checks:**
+- Validate cardinality and uniqueness constraints
+- Monitor statistical properties (variance, outliers)
+- Enforce SLAs and performance thresholds
+- Detect data quality issues within groups
+
+DQX supports 20 **curated** aggregate functions (recommended for data quality) plus any other [Databricks built-in aggregate function](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha).
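+
+As a minimal Python sketch of the country-code scenario above (assuming an input DataFrame `df` with `country` and `country_code` columns; the equivalent YAML form appears in the extended examples below):
+
+```python
+from databricks.labs.dqx.engine import DQEngine
+from databricks.labs.dqx.rule import DQDatasetRule
+from databricks.labs.dqx import check_funcs
+from databricks.sdk import WorkspaceClient
+
+# Each country must map to at most one distinct country code
+checks = [
+    DQDatasetRule(
+        criticality="error",
+        check_func=check_funcs.is_aggr_not_greater_than,
+        column="country_code",
+        check_func_kwargs={"aggr_type": "count_distinct", "group_by": ["country"], "limit": 1},
+    ),
+]
+
+dq_engine = DQEngine(WorkspaceClient())
+result_df = dq_engine.apply_checks(df, checks)  # `df` is an assumed input DataFrame
+```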
+
+### Curated Aggregate Functions {#curated-aggregate-functions}
+
+Validated and optimized for data quality use cases:
+
+#### Basic Aggregations
+- `count` - Count records (nulls excluded)
+- `sum` - Sum of values
+- `avg` - Average of values
+- `min` - Minimum value
+- `max` - Maximum value
+
+#### Cardinality & Uniqueness
+- `count_distinct` - Exact distinct count (use for uniqueness validation)
+- `approx_count_distinct` - Fast approximate count (±5% accuracy, ideal for very large datasets)
+- `count_if` - Conditional counting
+
+#### Statistical Analysis
+- `stddev` / `stddev_samp` - Sample standard deviation (anomaly detection)
+- `stddev_pop` - Population standard deviation
+- `variance` / `var_samp` - Sample variance (stability checks)
+- `var_pop` - Population variance
+- `median` - 50th percentile baseline
+- `mode` - Most frequent value (categorical data quality)
+- `skewness` - Distribution skewness (detect asymmetry)
+- `kurtosis` - Distribution kurtosis (detect heavy tails)
+
+#### SLA & Performance Monitoring
+- `percentile` - Exact percentile (requires `aggr_params`)
+- `approx_percentile` - Approximate percentile (faster, requires `aggr_params`)
+
+
+Using `count_distinct` with `group_by` cannot be expressed as a window function (Spark does not support DISTINCT in window functions), so DQX falls back to a two-stage aggregation: a `groupBy` followed by a join back to the input DataFrame, which adds a shuffle and a join (see the sketch below). **For large-scale datasets, prefer `approx_count_distinct`**, which works with window functions and scales well with high-cardinality groups.
+
+**Performance benchmarks**: See [benchmarks](/docs/reference/benchmarks) for a detailed comparison.
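+
+Conceptually, the two-stage aggregation used for window-incompatible aggregates such as `count_distinct` looks roughly like the following PySpark sketch (a simplified illustration, assuming `country`/`country_code` columns, not the exact internal implementation):
+
+```python
+import pyspark.sql.functions as F
+
+# Stage 1: compute the distinct count per group (DISTINCT is not allowed in window functions)
+agg_df = df.groupBy("country").agg(F.countDistinct("country_code").alias("metric"))
+
+# Stage 2: join the metric back so every input row carries its group's value
+df_with_metric = df.join(agg_df, on=["country"], how="left")
+```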
+
+
+### Using `aggr_params` for Function Parameters {#using-aggr_params-for-function-parameters}
+
+Some aggregate functions require additional parameters beyond the column. Pass these as a dictionary:
+
+**Python:**
+```python
+check_func_kwargs={
+ "aggr_type": "percentile",
+ "aggr_params": {"percentile": 0.95}, # Single parameter
+ "limit": 1000
+}
+
+# Multiple parameters:
+check_func_kwargs={
+ "aggr_type": "approx_percentile",
+ "aggr_params": {
+ "percentile": 0.99,
+ "accuracy": 10000 # Multiple parameters as dict
+ }
+}
+```
+
+**YAML:**
+```yaml
+aggr_params:
+ percentile: 0.95
+ accuracy: 10000 # Multiple parameters as nested YAML
+```
+
+The `aggr_params` dictionary is unpacked as keyword arguments to the corresponding Spark aggregate function (for the percentile functions, the `percentile` value is passed positionally and any remaining parameters, such as `accuracy`, are passed as keyword arguments). See the [Databricks function documentation](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) for the parameters available for each function.
+
+### Non-Curated Aggregate Functions {#non-curated-aggregates}
+
+Any [Databricks built-in aggregate function](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha) can be used beyond the curated list (see the sketch below). DQX will:
+- Issue a warning recommending curated functions where possible
+- Validate at runtime (a schema-only check) that the function returns a value that can be compared to a numeric limit
+
+**Limitations:**
+- Functions returning arrays, structs, or maps (e.g., `collect_list`, `collect_set`) fail with a clear error message explaining that they cannot be compared to numeric limits
+- **User-Defined Aggregate Functions (UDAFs) are not currently supported**; only Databricks built-in functions available in `pyspark.sql.functions` can be used
+
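+
+For illustration, a minimal Python sketch (mirroring the demo notebook) that uses the non-curated `any_value` aggregate; `df` is an assumed DataFrame with `sensor_id` and `reading` columns, and applying the check emits a `UserWarning`:
+
+```python
+import warnings
+
+from databricks.labs.dqx.engine import DQEngine
+from databricks.labs.dqx.rule import DQDatasetRule
+from databricks.labs.dqx import check_funcs
+from databricks.sdk import WorkspaceClient
+
+checks = [
+    DQDatasetRule(
+        criticality="warn",
+        check_func=check_funcs.is_aggr_not_greater_than,
+        column="reading",
+        check_func_kwargs={
+            "aggr_type": "any_value",  # not in the curated list -> triggers a warning
+            "group_by": ["sensor_id"],
+            "limit": 100.0,
+        },
+    ),
+]
+
+dq_engine = DQEngine(WorkspaceClient())
+with warnings.catch_warnings(record=True) as caught:
+    warnings.simplefilter("always")
+    result_df = dq_engine.apply_checks(df, checks)  # `df` is an assumed input DataFrame
+
+for warning in caught:
+    print(warning.message)
+```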
+
+Some aggregate functions require newer Spark versions. If a function is unavailable in your runtime, DQX raises an error suggesting that you verify the function name and check whether your Databricks Runtime version supports it. **Use the latest Databricks Runtime for best compatibility.**
+
+
+### Extended Examples {#aggregate-extended-examples}
+
+```yaml
+# Uniqueness: Each country has exactly one country code
+- criticality: error
+ check:
+ function: is_aggr_not_greater_than
+ arguments:
+ column: country_code
+ aggr_type: count_distinct
+ group_by:
+ - country
+ limit: 1
+
+# Cardinality: Total unique users across entire dataset
+- criticality: error
+ check:
+ function: is_aggr_not_greater_than
+ arguments:
+ column: user_id
+ aggr_type: count_distinct
+ limit: 1000000
+
+# Fast approximate count for very large datasets
+- criticality: error
+ check:
+ function: is_aggr_not_greater_than
+ arguments:
+ column: user_id
+ aggr_type: approx_count_distinct
+ group_by:
+ - session_date
+ limit: 100000
+
+# Standard deviation: Detect unusual variance in sensor readings per machine
+- criticality: warn
+ check:
+ function: is_aggr_not_greater_than
+ arguments:
+ column: temperature
+ aggr_type: stddev
+ group_by:
+ - machine_id
+ limit: 5.0
+
+# Percentile: Monitor P95 latency for SLA compliance
+- criticality: error
+ check:
+ function: is_aggr_not_greater_than
+ arguments:
+ column: latency_ms
+ aggr_type: percentile
+ aggr_params:
+ percentile: 0.95
+ limit: 1000
+
+# Approximate percentile: Fast P99 check with accuracy control
+- criticality: warn
+ check:
+ function: is_aggr_not_greater_than
+ arguments:
+ column: response_time
+ aggr_type: approx_percentile
+ aggr_params:
+ percentile: 0.99
+ accuracy: 10000
+ limit: 5000
+
+# Median: Baseline check for central tendency
+- criticality: warn
+ check:
+ function: is_aggr_not_less_than
+ arguments:
+ column: order_value
+ aggr_type: median
+ group_by:
+ - product_category
+ limit: 50.0
+
+# Variance: Stability check for financial data
+- criticality: error
+ check:
+ function: is_aggr_not_greater_than
+ arguments:
+ column: daily_revenue
+ aggr_type: variance
+ limit: 1000000.0
+
+# Mode: the most frequent error_code per service must not exceed 100
+- criticality: warn
+ check:
+ function: is_aggr_not_greater_than
+ arguments:
+ column: error_code
+ aggr_type: mode
+ group_by:
+ - service_name
+ limit: 100
+
+# count_if: Monitor error rate per service
+- criticality: error
+ check:
+ function: is_aggr_not_greater_than
+ arguments:
+ column: status_code >= 500
+ aggr_type: count_if
+ group_by:
+ - service_name
+ limit: 10
+```
+
+
+**Checks defined in YAML (continued)**
+```yaml
# foreign_key check using reference DataFrame
- criticality: error
check:
diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py
index 40def0770..3cf7f4eab 100644
--- a/src/databricks/labs/dqx/check_funcs.py
+++ b/src/databricks/labs/dqx/check_funcs.py
@@ -7,6 +7,7 @@
from enum import Enum
from itertools import zip_longest
import operator as py_operator
+from typing import Any
import pandas as pd # type: ignore[import-untyped]
import pyspark.sql.functions as F
from pyspark.sql import types
@@ -27,6 +28,40 @@
IPV4_MAX_OCTET_COUNT = 4
IPV4_BIT_LENGTH = 32
+# Curated aggregate functions for data quality checks
+# These are univariate (single-column) aggregate functions suitable for DQ monitoring
+# Maps function names to human-readable display names for error messages
+CURATED_AGGR_FUNCTIONS = {
+ "count": "Count",
+ "sum": "Sum",
+ "avg": "Average",
+ "min": "Min",
+ "max": "Max",
+ "count_distinct": "Distinct count",
+ "approx_count_distinct": "Approximate distinct count",
+ "count_if": "Conditional count",
+ "stddev": "Standard deviation",
+ "stddev_pop": "Population standard deviation",
+ "stddev_samp": "Sample standard deviation",
+ "variance": "Variance",
+ "var_pop": "Population variance",
+ "var_samp": "Sample variance",
+ "median": "Median",
+ "mode": "Mode",
+ "skewness": "Skewness",
+ "kurtosis": "Kurtosis",
+ "percentile": "Percentile",
+ "approx_percentile": "Approximate percentile",
+}
+
+# Aggregate functions incompatible with Spark window functions
+# These require two-stage aggregation (groupBy + join) instead of window functions when used with group_by
+# Spark limitation: DISTINCT operations are not supported in window functions
+WINDOW_INCOMPATIBLE_AGGREGATES = {
+ "count_distinct", # DISTINCT_WINDOW_FUNCTION_UNSUPPORTED error
+ # Future: Add other aggregates that don't work with windows (e.g., collect_set with DISTINCT)
+}
+
class DQPattern(Enum):
"""Enum class to represent DQ patterns used to match data in columns."""
@@ -461,7 +496,6 @@ def is_equal_to(
column (str | Column): Column to check. Can be a string column name or a column expression.
value (int | float | str | datetime.date | datetime.datetime | Column | None, optional):
The value to compare with. Can be a literal or a Spark Column. Defaults to None.
- String literals must be single quoted, e.g. 'string_value'.
Returns:
Column: A Spark Column condition that fails if the column value is not equal to the given value.
@@ -493,7 +527,6 @@ def is_not_equal_to(
column (str | Column): Column to check. Can be a string column name or a column expression.
value (int | float | str | datetime.date | datetime.datetime | Column | None, optional):
The value to compare with. Can be a literal or a Spark Column. Defaults to None.
- String literals must be single quoted, e.g. 'string_value'.
Returns:
Column: A Spark Column condition that fails if the column value is equal to the given value.
@@ -1267,20 +1300,25 @@ def is_aggr_not_greater_than(
aggr_type: str = "count",
group_by: list[str | Column] | None = None,
row_filter: str | None = None,
+ aggr_params: dict[str, Any] | None = None,
) -> tuple[Column, Callable]:
"""
Build an aggregation check condition and closure for dataset-level validation.
- This function verifies that an aggregation (count, sum, avg, min, max) on a column
- or group of columns does not exceed a specified limit. Rows where the aggregation
- result exceeds the limit are flagged.
+ This function verifies that an aggregation on a column or group of columns does not exceed
+ a specified limit. Supports curated aggregate functions (count, sum, avg, stddev, percentile, etc.)
+ and any Databricks built-in aggregate. Rows where the aggregation result exceeds the limit are flagged.
Args:
column: Column name (str) or Column expression to aggregate.
- limit: Numeric value, column name, or SQL expression for the limit.
- aggr_type: Aggregation type: 'count', 'sum', 'avg', 'min', or 'max' (default: 'count').
+ limit: Numeric value, column name, or SQL expression for the limit. String literals must be single quoted, e.g. 'string_value'.
+ aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max,
+ count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported.
group_by: Optional list of column names or Column expressions to group by.
row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter.
+ aggr_params: Optional dict of parameters for aggregates requiring them (e.g., percentile value for
+ percentile functions, accuracy for approximate aggregates). Parameters are passed as keyword
+ arguments to the Spark function.
Returns:
A tuple of:
@@ -1291,6 +1329,7 @@ def is_aggr_not_greater_than(
column,
limit,
aggr_type,
+ aggr_params,
group_by,
row_filter,
compare_op=py_operator.gt,
@@ -1306,20 +1345,25 @@ def is_aggr_not_less_than(
aggr_type: str = "count",
group_by: list[str | Column] | None = None,
row_filter: str | None = None,
+ aggr_params: dict[str, Any] | None = None,
) -> tuple[Column, Callable]:
"""
Build an aggregation check condition and closure for dataset-level validation.
- This function verifies that an aggregation (count, sum, avg, min, max) on a column
- or group of columns is not below a specified limit. Rows where the aggregation
- result is below the limit are flagged.
+ This function verifies that an aggregation on a column or group of columns is not below
+ a specified limit. Supports curated aggregate functions (count, sum, avg, stddev, percentile, etc.)
+ and any Databricks built-in aggregate. Rows where the aggregation result is below the limit are flagged.
Args:
column: Column name (str) or Column expression to aggregate.
- limit: Numeric value, column name, or SQL expression for the limit.
- aggr_type: Aggregation type: 'count', 'sum', 'avg', 'min', or 'max' (default: 'count').
+ limit: Numeric value, column name, or SQL expression for the limit. String literals must be single quoted, e.g. 'string_value'.
+ aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max,
+ count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported.
group_by: Optional list of column names or Column expressions to group by.
row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter.
+ aggr_params: Optional dict of parameters for aggregates requiring them (e.g., percentile value for
+ percentile functions, accuracy for approximate aggregates). Parameters are passed as keyword
+ arguments to the Spark function.
Returns:
A tuple of:
@@ -1330,6 +1374,7 @@ def is_aggr_not_less_than(
column,
limit,
aggr_type,
+ aggr_params,
group_by,
row_filter,
compare_op=py_operator.lt,
@@ -1345,20 +1390,25 @@ def is_aggr_equal(
aggr_type: str = "count",
group_by: list[str | Column] | None = None,
row_filter: str | None = None,
+ aggr_params: dict[str, Any] | None = None,
) -> tuple[Column, Callable]:
"""
Build an aggregation check condition and closure for dataset-level validation.
- This function verifies that an aggregation (count, sum, avg, min, max) on a column
- or group of columns is equal to a specified limit. Rows where the aggregation
- result is not equal to the limit are flagged.
+ This function verifies that an aggregation on a column or group of columns is equal to
+ a specified limit. Supports curated aggregate functions (count, sum, avg, stddev, percentile, etc.)
+ and any Databricks built-in aggregate. Rows where the aggregation result is not equal to the limit are flagged.
Args:
column: Column name (str) or Column expression to aggregate.
limit: Numeric value, column name, or SQL expression for the limit. String literals must be single quoted, e.g. 'string_value'.
- aggr_type: Aggregation type: 'count', 'sum', 'avg', 'min', or 'max' (default: 'count').
+ aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max,
+ count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported.
group_by: Optional list of column names or Column expressions to group by.
row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter.
+ aggr_params: Optional dict of parameters for aggregates requiring them (e.g., percentile value for
+ percentile functions, accuracy for approximate aggregates). Parameters are passed as keyword
+ arguments to the Spark function.
Returns:
A tuple of:
@@ -1369,6 +1419,7 @@ def is_aggr_equal(
column,
limit,
aggr_type,
+ aggr_params,
group_by,
row_filter,
compare_op=py_operator.ne,
@@ -1384,20 +1435,25 @@ def is_aggr_not_equal(
aggr_type: str = "count",
group_by: list[str | Column] | None = None,
row_filter: str | None = None,
+ aggr_params: dict[str, Any] | None = None,
) -> tuple[Column, Callable]:
"""
Build an aggregation check condition and closure for dataset-level validation.
- This function verifies that an aggregation (count, sum, avg, min, max) on a column
- or group of columns is not equal to a specified limit. Rows where the aggregation
- result is equal to the limit are flagged.
+ This function verifies that an aggregation on a column or group of columns is not equal to
+ a specified limit. Supports curated aggregate functions (count, sum, avg, stddev, percentile, etc.)
+ and any Databricks built-in aggregate. Rows where the aggregation result is equal to the limit are flagged.
Args:
column: Column name (str) or Column expression to aggregate.
limit: Numeric value, column name, or SQL expression for the limit. String literals must be single quoted, e.g. 'string_value'.
- aggr_type: Aggregation type: 'count', 'sum', 'avg', 'min', or 'max' (default: 'count').
+ aggr_type: Aggregation type (default: 'count'). Curated types include count, sum, avg, min, max,
+ count_distinct, stddev, percentile, and more. Any Databricks built-in aggregate is supported.
group_by: Optional list of column names or Column expressions to group by.
row_filter: Optional SQL expression to filter rows before aggregation. Auto-injected from the check filter.
+ aggr_params: Optional dict of parameters for aggregates requiring them (e.g., percentile value for
+ percentile functions, accuracy for approximate aggregates). Parameters are passed as keyword
+ arguments to the Spark function.
Returns:
A tuple of:
@@ -1408,6 +1464,7 @@ def is_aggr_not_equal(
column,
limit,
aggr_type,
+ aggr_params,
group_by,
row_filter,
compare_op=py_operator.eq,
@@ -2223,10 +2280,95 @@ def _add_compare_condition(
)
+def _build_aggregate_expression(
+ aggr_type: str,
+ filtered_expr: Column,
+ aggr_params: dict[str, Any] | None,
+) -> Column:
+ """
+ Build the appropriate Spark aggregate expression based on function type and parameters.
+
+ Args:
+ aggr_type: Name of the aggregate function.
+ filtered_expr: Column expression with filters applied.
+ aggr_params: Optional parameters for the aggregate function.
+
+ Returns:
+ Spark Column expression for the aggregate.
+
+ Raises:
+ MissingParameterError: If required parameters are missing for specific aggregates.
+ InvalidParameterError: If the aggregate function is not found or parameters are invalid.
+ """
+ if aggr_type == "count_distinct":
+ return F.countDistinct(filtered_expr)
+
+ if aggr_type in {"percentile", "approx_percentile"}:
+ if not aggr_params or "percentile" not in aggr_params:
+ raise MissingParameterError(
+ f"'{aggr_type}' requires aggr_params with 'percentile' key (e.g., {{'percentile': 0.95}})"
+ )
+ pct = aggr_params["percentile"]
+ # Pass through any additional parameters to Spark (e.g., accuracy, frequency)
+ # Spark will validate parameter names and types at runtime
+ other_params = {k: v for k, v in aggr_params.items() if k != "percentile"}
+
+ try:
+ aggr_func = getattr(F, aggr_type)
+ return aggr_func(filtered_expr, pct, **other_params)
+ except Exception as exc:
+ raise InvalidParameterError(f"Failed to build '{aggr_type}' expression: {exc}") from exc
+
+ try:
+ aggr_func = getattr(F, aggr_type)
+ if aggr_params:
+ return aggr_func(filtered_expr, **aggr_params)
+ return aggr_func(filtered_expr)
+ except AttributeError as exc:
+ raise InvalidParameterError(
+ f"Aggregate function '{aggr_type}' not found in pyspark.sql.functions. "
+ f"Verify the function name is correct, or check if your Databricks Runtime version supports this function. "
+ f"Some newer aggregate functions (e.g., mode, median) require DBR 15.4+ (Spark 3.5+). "
+ f"See: https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-functions-builtin-alpha"
+ ) from exc
+ except Exception as exc:
+ raise InvalidParameterError(f"Failed to build '{aggr_type}' expression: {exc}") from exc
+
+
+def _validate_aggregate_return_type(
+ df: DataFrame,
+ aggr_type: str,
+ metric_col: str,
+) -> None:
+ """
+ Validate aggregate returns a numeric type that can be compared to limits.
+
+ This is a schema-only validation (no data scanning) that checks whether the aggregate
+ function returns a type compatible with numeric comparisons.
+
+ Args:
+ df: DataFrame containing the aggregate result column.
+ aggr_type: Name of the aggregate function being validated.
+ metric_col: Column name containing the aggregate result.
+
+ Raises:
+ InvalidParameterError: If the aggregate returns a non-numeric type (Array, Map, Struct)
+ that cannot be compared to limits.
+ """
+ result_type = df.schema[metric_col].dataType
+ if isinstance(result_type, (types.ArrayType, types.MapType, types.StructType)):
+ raise InvalidParameterError(
+ f"Aggregate function '{aggr_type}' returned {result_type.typeName()} "
+ f"which cannot be compared to numeric limits. "
+ f"Use aggregate functions that return numeric values (e.g., count, sum, avg)."
+ )
+
+
def _is_aggr_compare(
column: str | Column,
limit: int | float | str | Column,
aggr_type: str,
+ aggr_params: dict[str, Any] | None,
group_by: list[str | Column] | None,
row_filter: str | None,
compare_op: Callable[[Column, Column], Column],
@@ -2241,8 +2383,12 @@ def _is_aggr_compare(
Args:
column: Column name (str) or Column expression to aggregate.
- limit: Numeric value, column name, or SQL expression for the limit.
- aggr_type: Aggregation type: 'count', 'sum', 'avg', 'min', or 'max'.
+ limit: Numeric value, column name, or SQL expression for the limit. String literals must be single quoted, e.g. 'string_value'.
+ aggr_type: Aggregation type. Curated functions include 'count', 'sum', 'avg', 'min', 'max',
+ 'count_distinct', 'stddev', 'percentile', and more. Any Databricks built-in aggregate
+ function is supported (will trigger a warning for non-curated functions).
+ aggr_params: Optional dictionary of parameters for aggregate functions that require them
+ (e.g., percentile functions need {"percentile": 0.95}).
group_by: Optional list of columns or Column expressions to group by.
row_filter: Optional SQL expression to filter rows before aggregation.
compare_op: Comparison operator (e.g., operator.gt, operator.lt).
@@ -2255,11 +2401,19 @@ def _is_aggr_compare(
- A closure that applies the aggregation check logic.
Raises:
- InvalidParameterError: If an unsupported aggregation type is provided.
- """
- supported_aggr_types = {"count", "sum", "avg", "min", "max"}
- if aggr_type not in supported_aggr_types:
- raise InvalidParameterError(f"Unsupported aggregation type: {aggr_type}. Supported: {supported_aggr_types}")
+ InvalidParameterError: If an aggregate returns non-numeric types or is not found.
+ MissingParameterError: If required parameters for specific aggregates are not provided.
+ """
+ # Warn if using non-curated aggregate function
+ is_curated = aggr_type in CURATED_AGGR_FUNCTIONS
+ if not is_curated:
+ warnings.warn(
+ f"Using non-curated aggregate function '{aggr_type}'. "
+ f"Curated functions: {', '.join(sorted(CURATED_AGGR_FUNCTIONS))}. "
+ f"Non-curated aggregates must return a single numeric value per group.",
+ UserWarning,
+ stacklevel=3,
+ )
aggr_col_str_norm, aggr_col_str, aggr_col_expr = _get_normalized_column_and_expr(column)
@@ -2302,11 +2456,33 @@ def apply(df: DataFrame) -> DataFrame:
"""
filter_col = F.expr(row_filter) if row_filter else F.lit(True)
filtered_expr = F.when(filter_col, aggr_col_expr) if row_filter else aggr_col_expr
- aggr_expr = getattr(F, aggr_type)(filtered_expr)
+
+ # Build aggregation expression
+ aggr_expr = _build_aggregate_expression(aggr_type, filtered_expr, aggr_params)
if group_by:
- window_spec = Window.partitionBy(*[F.col(col) if isinstance(col, str) else col for col in group_by])
- df = df.withColumn(metric_col, aggr_expr.over(window_spec))
+ # Convert group_by to Column expressions (reused for both window and groupBy approaches)
+ group_cols = [F.col(col) if isinstance(col, str) else col for col in group_by]
+
+ # Check if aggregate is incompatible with window functions (e.g., count_distinct with DISTINCT)
+ if aggr_type in WINDOW_INCOMPATIBLE_AGGREGATES:
+ # Use two-stage aggregation: groupBy + join (instead of window functions)
+ # This is required for aggregates like count_distinct that don't support window DISTINCT operations
+ agg_df = df.groupBy(*group_cols).agg(aggr_expr.alias(metric_col))
+
+ # Join aggregated metrics back to original DataFrame to maintain row-level granularity
+ # Note: Aliased Column expressions in group_by are not supported for window-incompatible
+ # aggregates (e.g., count_distinct). Use string column names or simple F.col() expressions.
+ join_cols = [col if isinstance(col, str) else get_column_name_or_alias(col) for col in group_by]
+ df = df.join(agg_df, on=join_cols, how="left")
+ else:
+ # Use standard window function approach for window-compatible aggregates
+ window_spec = Window.partitionBy(*group_cols)
+ df = df.withColumn(metric_col, aggr_expr.over(window_spec))
+
+ # Validate non-curated aggregates (type check only - window functions return same row count)
+ if not is_curated:
+ _validate_aggregate_return_type(df, aggr_type, metric_col)
else:
# When no group-by columns are provided, using partitionBy would move all rows into a single partition,
# forcing the window function to process the entire dataset in one task.
@@ -2314,17 +2490,25 @@ def apply(df: DataFrame) -> DataFrame:
# Note: The aggregation naturally returns a single row without a groupBy clause,
# so no explicit limit is required (informational only).
agg_df = df.select(aggr_expr.alias(metric_col)).limit(1)
+
+ # Validate non-curated aggregates (type check only - we already limited to 1 row)
+ if not is_curated:
+ _validate_aggregate_return_type(agg_df, aggr_type, metric_col)
+
df = df.crossJoin(agg_df) # bring the metric across all rows
df = df.withColumn(condition_col, compare_op(F.col(metric_col), limit_expr))
return df
+ # Get human-readable display name for aggregate function (including params if present)
+ aggr_display_name = _get_aggregate_display_name(aggr_type, aggr_params)
+
condition = make_condition(
condition=F.col(condition_col),
message=F.concat_ws(
"",
- F.lit(f"{aggr_type.capitalize()} "),
+ F.lit(f"{aggr_display_name} value "),
F.col(metric_col).cast("string"),
F.lit(f" in column '{aggr_col_str}'"),
F.lit(f"{' per group of columns ' if group_by_list_str else ''}"),
@@ -2453,6 +2637,35 @@ def _get_normalized_column_and_expr(column: str | Column) -> tuple[str, str, Col
return col_str_norm, column_str, col_expr
+def _get_aggregate_display_name(aggr_type: str, aggr_params: dict[str, Any] | None = None) -> str:
+ """
+ Get a human-readable display name for an aggregate function.
+
+ This helper provides user-friendly names for aggregate functions in error messages,
+ transforming technical function names (e.g., 'count_distinct') into readable text
+    (e.g., 'Distinct count').
+
+ Args:
+ aggr_type: The aggregate function name (e.g., 'count_distinct', 'max', 'avg').
+ aggr_params: Optional parameters passed to the aggregate function.
+
+ Returns:
+ A human-readable display name for the aggregate function, including parameters
+ if provided. For non-curated functions, returns the function name in quotes
+ with 'value' suffix.
+ """
+ # Get base display name (curated functions have friendly names, others show function name in quotes)
+ base_name = CURATED_AGGR_FUNCTIONS.get(aggr_type, f"'{aggr_type}'")
+
+ # Add parameters if present
+ if aggr_params:
+ # Format parameters as key=value pairs
+ param_str = ", ".join(f"{k}={v}" for k, v in aggr_params.items())
+ return f"{base_name} ({param_str})"
+
+ return base_name
+
+
def _get_column_expr(column: Column | str) -> Column:
"""
Convert a column input (string or Column) into a Spark Column expression.
diff --git a/tests/integration/test_apply_checks.py b/tests/integration/test_apply_checks.py
index db2a25ad1..647336f5d 100644
--- a/tests/integration/test_apply_checks.py
+++ b/tests/integration/test_apply_checks.py
@@ -6872,7 +6872,7 @@ def test_apply_aggr_checks(ws, spark):
[
{
"name": "c_avg_greater_than_limit",
- "message": "Avg 4.0 in column 'c' is greater than limit: 0",
+ "message": "Average value 4.0 in column 'c' is greater than limit: 0",
"columns": ["c"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -6894,7 +6894,7 @@ def test_apply_aggr_checks(ws, spark):
},
{
"name": "a_count_greater_than_limit",
- "message": "Count 2 in column 'a' is greater than limit: 0",
+ "message": "Count value 2 in column 'a' is greater than limit: 0",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -6921,7 +6921,7 @@ def test_apply_aggr_checks(ws, spark):
[
{
"name": "b_sum_group_by_a_not_equal_to_limit",
- "message": "Sum 2 in column 'b' per group of columns 'a' is not equal to limit: 8",
+ "message": "Sum value 2 in column 'b' per group of columns 'a' is not equal to limit: 8",
"columns": ["b"],
"filter": None,
"function": "is_aggr_equal",
@@ -6931,7 +6931,7 @@ def test_apply_aggr_checks(ws, spark):
},
{
"name": "a_count_group_by_a_greater_than_limit",
- "message": "Count 2 in column 'a' per group of columns 'a' is greater than limit: 0",
+ "message": "Count value 2 in column 'a' per group of columns 'a' is greater than limit: 0",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -6941,7 +6941,7 @@ def test_apply_aggr_checks(ws, spark):
},
{
"name": "row_count_group_by_a_b_greater_than_limit",
- "message": "Count 1 in column 'a' per group of columns 'a, b' is greater than limit: 0",
+ "message": "Count value 1 in column 'a' per group of columns 'a, b' is greater than limit: 0",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -6951,7 +6951,7 @@ def test_apply_aggr_checks(ws, spark):
},
{
"name": "c_avg_greater_than_limit",
- "message": "Avg 4.0 in column 'c' is greater than limit: 0",
+ "message": "Average value 4.0 in column 'c' is greater than limit: 0",
"columns": ["c"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -6961,7 +6961,7 @@ def test_apply_aggr_checks(ws, spark):
},
{
"name": "c_avg_group_by_a_greater_than_limit",
- "message": "Avg 4.0 in column 'c' per group of columns 'a' is greater than limit: 0",
+ "message": "Average value 4.0 in column 'c' per group of columns 'a' is greater than limit: 0",
"columns": ["c"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -6983,7 +6983,7 @@ def test_apply_aggr_checks(ws, spark):
},
{
"name": "a_count_greater_than_limit",
- "message": "Count 2 in column 'a' is greater than limit: 0",
+ "message": "Count value 2 in column 'a' is greater than limit: 0",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7010,7 +7010,7 @@ def test_apply_aggr_checks(ws, spark):
[
{
"name": "b_sum_group_by_a_not_equal_to_limit",
- "message": "Sum 2 in column 'b' per group of columns 'a' is not equal to limit: 8",
+ "message": "Sum value 2 in column 'b' per group of columns 'a' is not equal to limit: 8",
"columns": ["b"],
"filter": None,
"function": "is_aggr_equal",
@@ -7020,7 +7020,7 @@ def test_apply_aggr_checks(ws, spark):
},
{
"name": "a_count_group_by_a_greater_than_limit",
- "message": "Count 2 in column 'a' per group of columns 'a' is greater than limit: 0",
+ "message": "Count value 2 in column 'a' per group of columns 'a' is greater than limit: 0",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7030,7 +7030,7 @@ def test_apply_aggr_checks(ws, spark):
},
{
"name": "a_count_group_by_a_greater_than_limit_with_b_not_null",
- "message": "Count 1 in column 'a' per group of columns 'a' is greater than limit: 0",
+ "message": "Count value 1 in column 'a' per group of columns 'a' is greater than limit: 0",
"columns": ["a"],
"filter": "b is not null",
"function": "is_aggr_not_greater_than",
@@ -7040,7 +7040,7 @@ def test_apply_aggr_checks(ws, spark):
},
{
"name": "row_count_group_by_a_b_greater_than_limit",
- "message": "Count 1 in column 'a' per group of columns 'a, b' is greater than limit: 0",
+ "message": "Count value 1 in column 'a' per group of columns 'a, b' is greater than limit: 0",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7050,7 +7050,7 @@ def test_apply_aggr_checks(ws, spark):
},
{
"name": "c_avg_greater_than_limit",
- "message": "Avg 4.0 in column 'c' is greater than limit: 0",
+ "message": "Average value 4.0 in column 'c' is greater than limit: 0",
"columns": ["c"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7060,7 +7060,7 @@ def test_apply_aggr_checks(ws, spark):
},
{
"name": "c_avg_group_by_a_greater_than_limit",
- "message": "Avg 4.0 in column 'c' per group of columns 'a' is greater than limit: 0",
+ "message": "Average value 4.0 in column 'c' per group of columns 'a' is greater than limit: 0",
"columns": ["c"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7070,7 +7070,7 @@ def test_apply_aggr_checks(ws, spark):
},
{
"name": "a_count_group_by_a_less_than_limit_with_b_not_null",
- "message": "Count 1 in column 'a' per group of columns 'a' is less than limit: 10",
+ "message": "Count value 1 in column 'a' per group of columns 'a' is less than limit: 10",
"columns": ["a"],
"filter": "b is not null",
"function": "is_aggr_not_less_than",
@@ -7092,7 +7092,7 @@ def test_apply_aggr_checks(ws, spark):
},
{
"name": "a_count_greater_than_limit",
- "message": "Count 2 in column 'a' is greater than limit: 0",
+ "message": "Count value 2 in column 'a' is greater than limit: 0",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7102,7 +7102,7 @@ def test_apply_aggr_checks(ws, spark):
},
{
"name": "a_count_greater_than_limit_with_b_not_null",
- "message": "Count 1 in column 'a' is greater than limit: 0",
+ "message": "Count value 1 in column 'a' is greater than limit: 0",
"columns": ["a"],
"filter": "b is not null",
"function": "is_aggr_not_greater_than",
@@ -7272,7 +7272,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
[
{
"name": "c_avg_greater_than_limit",
- "message": "Avg 4.0 in column 'c' is greater than limit: 0",
+ "message": "Average value 4.0 in column 'c' is greater than limit: 0",
"columns": ["c"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7282,7 +7282,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "a_count_not_equal_to_limit",
- "message": "Count 2 in column 'a' is equal to limit: 2",
+ "message": "Count value 2 in column 'a' is equal to limit: 2",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_equal",
@@ -7304,7 +7304,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "a_count_greater_than_limit",
- "message": "Count 2 in column 'a' is greater than limit: 0",
+ "message": "Count value 2 in column 'a' is greater than limit: 0",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7334,7 +7334,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "c_avg_not_equal_to_limit",
- "message": "Avg 4.0 in column 'c' is equal to limit: 4.0",
+ "message": "Average value 4.0 in column 'c' is equal to limit: 4.0",
"columns": ["c"],
"filter": None,
"function": "is_aggr_not_equal",
@@ -7351,7 +7351,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
[
{
"name": "a_count_group_by_a_greater_than_limit",
- "message": "Count 2 in column 'a' per group of columns 'a' is greater than limit: 0",
+ "message": "Count value 2 in column 'a' per group of columns 'a' is greater than limit: 0",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7361,7 +7361,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "row_count_group_by_a_b_greater_than_limit",
- "message": "Count 1 in column 'a' per group of columns 'a, b' is greater than limit: 0",
+ "message": "Count value 1 in column 'a' per group of columns 'a, b' is greater than limit: 0",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7371,7 +7371,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "c_avg_greater_than_limit",
- "message": "Avg 4.0 in column 'c' is greater than limit: 0",
+ "message": "Average value 4.0 in column 'c' is greater than limit: 0",
"columns": ["c"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7381,7 +7381,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "c_avg_group_by_a_greater_than_limit",
- "message": "Avg 4.0 in column 'c' per group of columns 'a' is greater than limit: 0",
+ "message": "Average value 4.0 in column 'c' per group of columns 'a' is greater than limit: 0",
"columns": ["c"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7391,7 +7391,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "a_count_not_equal_to_limit",
- "message": "Count 2 in column 'a' is equal to limit: 2",
+ "message": "Count value 2 in column 'a' is equal to limit: 2",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_equal",
@@ -7413,7 +7413,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "a_count_greater_than_limit",
- "message": "Count 2 in column 'a' is greater than limit: 0",
+ "message": "Count value 2 in column 'a' is greater than limit: 0",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7443,7 +7443,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "c_avg_not_equal_to_limit",
- "message": "Avg 4.0 in column 'c' is equal to limit: 4.0",
+ "message": "Average value 4.0 in column 'c' is equal to limit: 4.0",
"columns": ["c"],
"filter": None,
"function": "is_aggr_not_equal",
@@ -7460,7 +7460,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
[
{
"name": "a_count_group_by_a_greater_than_limit",
- "message": "Count 2 in column 'a' per group of columns 'a' is greater than limit: 0",
+ "message": "Count value 2 in column 'a' per group of columns 'a' is greater than limit: 0",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7470,7 +7470,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "a_count_group_by_a_greater_than_limit_with_b_not_null",
- "message": "Count 1 in column 'a' per group of columns 'a' is greater than limit: 0",
+ "message": "Count value 1 in column 'a' per group of columns 'a' is greater than limit: 0",
"columns": ["a"],
"filter": "b is not null",
"function": "is_aggr_not_greater_than",
@@ -7480,7 +7480,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "row_count_group_by_a_b_greater_than_limit",
- "message": "Count 1 in column 'a' per group of columns 'a, b' is greater than limit: 0",
+ "message": "Count value 1 in column 'a' per group of columns 'a, b' is greater than limit: 0",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7490,7 +7490,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "c_avg_greater_than_limit",
- "message": "Avg 4.0 in column 'c' is greater than limit: 0",
+ "message": "Average value 4.0 in column 'c' is greater than limit: 0",
"columns": ["c"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7500,7 +7500,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "c_avg_group_by_a_greater_than_limit",
- "message": "Avg 4.0 in column 'c' per group of columns 'a' is greater than limit: 0",
+ "message": "Average value 4.0 in column 'c' per group of columns 'a' is greater than limit: 0",
"columns": ["c"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7510,7 +7510,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "a_count_group_by_a_less_than_limit_with_b_not_null",
- "message": "Count 1 in column 'a' per group of columns 'a' is less than limit: 10",
+ "message": "Count value 1 in column 'a' per group of columns 'a' is less than limit: 10",
"columns": ["a"],
"filter": "b is not null",
"function": "is_aggr_not_less_than",
@@ -7520,7 +7520,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "a_count_not_equal_to_limit",
- "message": "Count 2 in column 'a' is equal to limit: 2",
+ "message": "Count value 2 in column 'a' is equal to limit: 2",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_equal",
@@ -7530,7 +7530,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "a_count_not_equal_to_limit_with_filter",
- "message": "Count 1 in column 'a' is equal to limit: 1",
+ "message": "Count value 1 in column 'a' is equal to limit: 1",
"columns": ["a"],
"filter": "b is not null",
"function": "is_aggr_not_equal",
@@ -7552,7 +7552,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "a_count_greater_than_limit",
- "message": "Count 2 in column 'a' is greater than limit: 0",
+ "message": "Count value 2 in column 'a' is greater than limit: 0",
"columns": ["a"],
"filter": None,
"function": "is_aggr_not_greater_than",
@@ -7562,7 +7562,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "a_count_greater_than_limit_with_b_not_null",
- "message": "Count 1 in column 'a' is greater than limit: 0",
+ "message": "Count value 1 in column 'a' is greater than limit: 0",
"columns": ["a"],
"filter": "b is not null",
"function": "is_aggr_not_greater_than",
@@ -7592,7 +7592,7 @@ def test_apply_aggr_checks_by_metadata(ws, spark):
},
{
"name": "c_avg_not_equal_to_limit",
- "message": "Avg 4.0 in column 'c' is equal to limit: 4.0",
+ "message": "Average value 4.0 in column 'c' is equal to limit: 4.0",
"columns": ["c"],
"filter": None,
"function": "is_aggr_not_equal",
diff --git a/tests/integration/test_apply_checks_and_save_in_table.py b/tests/integration/test_apply_checks_and_save_in_table.py
index 9f71ee293..2290d7561 100644
--- a/tests/integration/test_apply_checks_and_save_in_table.py
+++ b/tests/integration/test_apply_checks_and_save_in_table.py
@@ -1803,14 +1803,27 @@ def test_apply_checks_and_save_in_tables_for_patterns_no_tables_matching(ws, spa
)
-def test_apply_checks_and_save_in_tables_for_patterns_exclude_no_tables_matching(ws, spark):
- # Test with empty list of table configs
+def test_apply_checks_and_save_in_tables_for_patterns_exclude_no_tables_matching(ws, spark, make_schema, make_table):
+ # Create an isolated schema with a known table, then exclude it using exclude_patterns
+ catalog_name = TEST_CATALOG
+ schema = make_schema(catalog_name=catalog_name)
+
+ # Create a table in the isolated schema
+ make_table(
+ catalog_name=catalog_name,
+ schema_name=schema.name,
+ ctas="SELECT 1 as id",
+ )
+
engine = DQEngine(ws, spark=spark, extra_params=EXTRA_PARAMS)
- # This should not raise an error
+ # Include the schema pattern, but exclude all tables with wildcard - should result in no tables
+ # Using exclude_patterns avoids the full catalog scan that exclude_matched=True triggers
with pytest.raises(NotFound, match="No tables found matching include or exclude criteria"):
engine.apply_checks_and_save_in_tables_for_patterns(
- patterns=["*"], checks_location="some/location", exclude_matched=True
+ patterns=[f"{catalog_name}.{schema.name}.*"],
+ exclude_patterns=[f"{catalog_name}.{schema.name}.*"],
+ checks_location="some/location",
)
diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py
index 99eafa264..4944239f6 100644
--- a/tests/integration/test_dataset_checks.py
+++ b/tests/integration/test_dataset_checks.py
@@ -22,7 +22,7 @@
has_valid_schema,
)
from databricks.labs.dqx.utils import get_column_name_or_alias
-from databricks.labs.dqx.errors import InvalidParameterError
+from databricks.labs.dqx.errors import InvalidParameterError, MissingParameterError
from tests.conftest import TEST_CATALOG
@@ -221,39 +221,39 @@ def test_is_aggr_not_greater_than(spark: SparkSession):
[
"c",
None,
- "Count 3 in column 'a' is greater than limit: 1",
+ "Count value 3 in column 'a' is greater than limit: 1",
# displayed since filtering is done after, filter only applied for calculation inside the check
- "Count 2 in column 'a' is greater than limit: 0",
+ "Count value 2 in column 'a' is greater than limit: 0",
None,
None,
- "Avg 2.0 in column 'b' is greater than limit: 0.0",
- "Sum 4 in column 'b' is greater than limit: 0.0",
- "Min 1 in column 'b' is greater than limit: 0.0",
- "Max 3 in column 'b' is greater than limit: 0.0",
+ "Average value 2.0 in column 'b' is greater than limit: 0.0",
+ "Sum value 4 in column 'b' is greater than limit: 0.0",
+ "Min value 1 in column 'b' is greater than limit: 0.0",
+ "Max value 3 in column 'b' is greater than limit: 0.0",
],
[
"a",
1,
- "Count 3 in column 'a' is greater than limit: 1",
- "Count 2 in column 'a' is greater than limit: 0",
- "Count 1 in column 'a' per group of columns 'a' is greater than limit: 0",
- "Count 1 in column 'b' per group of columns 'b' is greater than limit: 0",
- "Avg 2.0 in column 'b' is greater than limit: 0.0",
- "Sum 4 in column 'b' is greater than limit: 0.0",
- "Min 1 in column 'b' is greater than limit: 0.0",
- "Max 3 in column 'b' is greater than limit: 0.0",
+ "Count value 3 in column 'a' is greater than limit: 1",
+ "Count value 2 in column 'a' is greater than limit: 0",
+ "Count value 1 in column 'a' per group of columns 'a' is greater than limit: 0",
+ "Count value 1 in column 'b' per group of columns 'b' is greater than limit: 0",
+ "Average value 2.0 in column 'b' is greater than limit: 0.0",
+ "Sum value 4 in column 'b' is greater than limit: 0.0",
+ "Min value 1 in column 'b' is greater than limit: 0.0",
+ "Max value 3 in column 'b' is greater than limit: 0.0",
],
[
"b",
3,
- "Count 3 in column 'a' is greater than limit: 1",
- "Count 2 in column 'a' is greater than limit: 0",
- "Count 1 in column 'a' per group of columns 'a' is greater than limit: 0",
- "Count 1 in column 'b' per group of columns 'b' is greater than limit: 0",
- "Avg 2.0 in column 'b' is greater than limit: 0.0",
- "Sum 4 in column 'b' is greater than limit: 0.0",
- "Min 1 in column 'b' is greater than limit: 0.0",
- "Max 3 in column 'b' is greater than limit: 0.0",
+ "Count value 3 in column 'a' is greater than limit: 1",
+ "Count value 2 in column 'a' is greater than limit: 0",
+ "Count value 1 in column 'a' per group of columns 'a' is greater than limit: 0",
+ "Count value 1 in column 'b' per group of columns 'b' is greater than limit: 0",
+ "Average value 2.0 in column 'b' is greater than limit: 0.0",
+ "Sum value 4 in column 'b' is greater than limit: 0.0",
+ "Min value 1 in column 'b' is greater than limit: 0.0",
+ "Max value 3 in column 'b' is greater than limit: 0.0",
],
],
expected_schema,
@@ -302,38 +302,38 @@ def test_is_aggr_not_less_than(spark: SparkSession):
[
"c",
None,
- "Count 3 in column 'a' is less than limit: 4",
- "Count 2 in column 'a' is less than limit: 3",
- "Count 0 in column 'a' per group of columns 'a' is less than limit: 2",
- "Count 0 in column 'b' per group of columns 'b' is less than limit: 2",
- "Avg 2.0 in column 'b' is less than limit: 3.0",
- "Sum 4 in column 'b' is less than limit: 5.0",
- "Min 1 in column 'b' is less than limit: 2.0",
- "Max 3 in column 'b' is less than limit: 4.0",
+ "Count value 3 in column 'a' is less than limit: 4",
+ "Count value 2 in column 'a' is less than limit: 3",
+ "Count value 0 in column 'a' per group of columns 'a' is less than limit: 2",
+ "Count value 0 in column 'b' per group of columns 'b' is less than limit: 2",
+ "Average value 2.0 in column 'b' is less than limit: 3.0",
+ "Sum value 4 in column 'b' is less than limit: 5.0",
+ "Min value 1 in column 'b' is less than limit: 2.0",
+ "Max value 3 in column 'b' is less than limit: 4.0",
],
[
"a",
1,
- "Count 3 in column 'a' is less than limit: 4",
- "Count 2 in column 'a' is less than limit: 3",
- "Count 1 in column 'a' per group of columns 'a' is less than limit: 2",
- "Count 1 in column 'b' per group of columns 'b' is less than limit: 2",
- "Avg 2.0 in column 'b' is less than limit: 3.0",
- "Sum 4 in column 'b' is less than limit: 5.0",
- "Min 1 in column 'b' is less than limit: 2.0",
- "Max 3 in column 'b' is less than limit: 4.0",
+ "Count value 3 in column 'a' is less than limit: 4",
+ "Count value 2 in column 'a' is less than limit: 3",
+ "Count value 1 in column 'a' per group of columns 'a' is less than limit: 2",
+ "Count value 1 in column 'b' per group of columns 'b' is less than limit: 2",
+ "Average value 2.0 in column 'b' is less than limit: 3.0",
+ "Sum value 4 in column 'b' is less than limit: 5.0",
+ "Min value 1 in column 'b' is less than limit: 2.0",
+ "Max value 3 in column 'b' is less than limit: 4.0",
],
[
"b",
3,
- "Count 3 in column 'a' is less than limit: 4",
- "Count 2 in column 'a' is less than limit: 3",
- "Count 1 in column 'a' per group of columns 'a' is less than limit: 2",
- "Count 1 in column 'b' per group of columns 'b' is less than limit: 2",
- "Avg 2.0 in column 'b' is less than limit: 3.0",
- "Sum 4 in column 'b' is less than limit: 5.0",
- "Min 1 in column 'b' is less than limit: 2.0",
- "Max 3 in column 'b' is less than limit: 4.0",
+ "Count value 3 in column 'a' is less than limit: 4",
+ "Count value 2 in column 'a' is less than limit: 3",
+ "Count value 1 in column 'a' per group of columns 'a' is less than limit: 2",
+ "Count value 1 in column 'b' per group of columns 'b' is less than limit: 2",
+ "Average value 2.0 in column 'b' is less than limit: 3.0",
+ "Sum value 4 in column 'b' is less than limit: 5.0",
+ "Min value 1 in column 'b' is less than limit: 2.0",
+ "Max value 3 in column 'b' is less than limit: 4.0",
],
],
expected_schema,
@@ -405,37 +405,37 @@ def test_is_aggr_equal(spark: SparkSession):
"c",
None,
None,
- "Count 2 in column 'a' is not equal to limit: 1",
- "Count 0 in column 'a' per group of columns 'a' is not equal to limit: 1",
- "Count 0 in column 'b' per group of columns 'b' is not equal to limit: 2",
+ "Count value 2 in column 'a' is not equal to limit: 1",
+ "Count value 0 in column 'a' per group of columns 'a' is not equal to limit: 1",
+ "Count value 0 in column 'b' per group of columns 'b' is not equal to limit: 2",
None,
- "Sum 4 in column 'b' is not equal to limit: 10.0",
+ "Sum value 4 in column 'b' is not equal to limit: 10.0",
None,
- "Max 3 in column 'b' is not equal to limit: 5.0",
+ "Max value 3 in column 'b' is not equal to limit: 5.0",
],
[
"a",
1,
None,
- "Count 2 in column 'a' is not equal to limit: 1",
+ "Count value 2 in column 'a' is not equal to limit: 1",
None,
- "Count 1 in column 'b' per group of columns 'b' is not equal to limit: 2",
+ "Count value 1 in column 'b' per group of columns 'b' is not equal to limit: 2",
None,
- "Sum 4 in column 'b' is not equal to limit: 10.0",
+ "Sum value 4 in column 'b' is not equal to limit: 10.0",
None,
- "Max 3 in column 'b' is not equal to limit: 5.0",
+ "Max value 3 in column 'b' is not equal to limit: 5.0",
],
[
"b",
3,
None,
- "Count 2 in column 'a' is not equal to limit: 1",
+ "Count value 2 in column 'a' is not equal to limit: 1",
None,
- "Count 1 in column 'b' per group of columns 'b' is not equal to limit: 2",
+ "Count value 1 in column 'b' per group of columns 'b' is not equal to limit: 2",
None,
- "Sum 4 in column 'b' is not equal to limit: 10.0",
+ "Sum value 4 in column 'b' is not equal to limit: 10.0",
None,
- "Max 3 in column 'b' is not equal to limit: 5.0",
+ "Max value 3 in column 'b' is not equal to limit: 5.0",
],
],
expected_schema,
@@ -483,37 +483,37 @@ def test_is_aggr_not_equal(spark: SparkSession):
[
"c",
None,
- "Count 3 in column 'a' is equal to limit: 3",
+ "Count value 3 in column 'a' is equal to limit: 3",
None,
None,
None,
- "Avg 2.0 in column 'b' is equal to limit: 2.0",
+ "Average value 2.0 in column 'b' is equal to limit: 2.0",
None,
- "Min 1 in column 'b' is equal to limit: 1.0",
+ "Min value 1 in column 'b' is equal to limit: 1.0",
None,
],
[
"a",
1,
- "Count 3 in column 'a' is equal to limit: 3",
+ "Count value 3 in column 'a' is equal to limit: 3",
None,
- "Count 1 in column 'a' per group of columns 'a' is equal to limit: 1",
+ "Count value 1 in column 'a' per group of columns 'a' is equal to limit: 1",
None,
- "Avg 2.0 in column 'b' is equal to limit: 2.0",
+ "Average value 2.0 in column 'b' is equal to limit: 2.0",
None,
- "Min 1 in column 'b' is equal to limit: 1.0",
+ "Min value 1 in column 'b' is equal to limit: 1.0",
None,
],
[
"b",
3,
- "Count 3 in column 'a' is equal to limit: 3",
+ "Count value 3 in column 'a' is equal to limit: 3",
None,
- "Count 1 in column 'a' per group of columns 'a' is equal to limit: 1",
+ "Count value 1 in column 'a' per group of columns 'a' is equal to limit: 1",
None,
- "Avg 2.0 in column 'b' is equal to limit: 2.0",
+ "Average value 2.0 in column 'b' is equal to limit: 2.0",
None,
- "Min 1 in column 'b' is equal to limit: 1.0",
+ "Min value 1 in column 'b' is equal to limit: 1.0",
None,
],
],
@@ -523,6 +523,463 @@ def test_is_aggr_not_equal(spark: SparkSession):
assert_df_equality(actual, expected, ignore_nullable=True)
+def test_is_aggr_with_count_distinct(spark: SparkSession):
+ """Test count_distinct for exact cardinality (works without group_by)."""
+ test_df = spark.createDataFrame(
+ [
+ ["val1", "data1"],
+ ["val1", "data2"], # Same first column
+ ["val2", "data3"], # Different first column
+ ["val3", "data4"],
+ ],
+ "a: string, b: string",
+ )
+
+ checks = [
+ # Global count_distinct (no group_by) - 3 distinct values in 'a', limit is 5, should pass
+ is_aggr_not_greater_than("a", limit=5, aggr_type="count_distinct"),
+ ]
+
+ actual = _apply_checks(test_df, checks)
+
+ expected = spark.createDataFrame(
+ [
+ ["val1", "data1", None],
+ ["val1", "data2", None],
+ ["val2", "data3", None],
+ ["val3", "data4", None],
+ ],
+ "a: string, b: string, a_count_distinct_greater_than_limit: string",
+ )
+
+ assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True)
+
+
+def test_is_aggr_with_count_distinct_and_group_by(spark: SparkSession):
+ """Test that count_distinct with group_by works using two-stage aggregation."""
+ test_df = spark.createDataFrame(
+ [
+ ["group1", "val1"],
+ ["group1", "val1"], # Same value
+ ["group1", "val2"], # Different value - 2 distinct
+ ["group2", "val3"],
+ ["group2", "val3"], # Same value - only 1 distinct
+ ],
+ "a: string, b: string",
+ )
+
+ checks = [
+ # group1 has 2 distinct values, should exceed limit of 1
+ is_aggr_not_greater_than("b", limit=1, aggr_type="count_distinct", group_by=["a"]),
+ ]
+
+ actual = _apply_checks(test_df, checks)
+
+ expected = spark.createDataFrame(
+ [
+ [
+ "group1",
+ "val1",
+ "Distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1",
+ ],
+ [
+ "group1",
+ "val1",
+ "Distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1",
+ ],
+ [
+ "group1",
+ "val2",
+ "Distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1",
+ ],
+ ["group2", "val3", None],
+ ["group2", "val3", None],
+ ],
+ "a: string, b: string, b_count_distinct_group_by_a_greater_than_limit: string",
+ )
+
+ assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True)
+
+
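
As a reference for the "two-stage aggregation" mentioned in the docstrings above and below, here is a minimal sketch of the groupBy-plus-join pattern, under the assumption that this is the shape of what the check performs; it illustrates the technique, not the library's actual code.

from pyspark.sql import DataFrame, functions as F

def count_distinct_per_group(df: DataFrame, value_col: str, group_cols: list[str]) -> DataFrame:
    # Stage 1: exact distinct count per group (DISTINCT aggregates cannot be evaluated over a window).
    metrics = df.groupBy(*group_cols).agg(F.count_distinct(value_col).alias("metric"))
    # Stage 2: join the per-group metric back so every row carries its group's value,
    # which the check then compares against the configured limit.
    return df.join(metrics, on=group_cols, how="left")

# For the DataFrame in the test above, count_distinct_per_group(test_df, "b", ["a"]) yields
# metric = 2 for group1 rows and metric = 1 for group2 rows.
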
+def test_is_aggr_with_count_distinct_and_column_expression_in_group_by(spark: SparkSession):
+ """Test count_distinct with Column expression (F.col) in group_by.
+
+ This tests that the two-stage aggregation (groupBy + join) correctly handles
+ Column expressions (not just string column names) in group_by.
+ """
+ test_df = spark.createDataFrame(
+ [
+ ["group1", "val1"],
+ ["group1", "val2"], # 2 distinct values in group1
+ ["group2", "val3"],
+ ["group2", "val3"], # 1 distinct value in group2
+ ],
+ "a: string, b: string",
+ )
+
+ # Use Column expression (F.col) in group_by instead of string
+ checks = [
+ is_aggr_not_greater_than(
+ "b",
+ limit=1,
+ aggr_type="count_distinct",
+ group_by=[F.col("a")], # Column expression without alias
+ ),
+ ]
+
+ actual = _apply_checks(test_df, checks)
+
+ # group1 has 2 distinct values > 1, should fail
+ # group2 has 1 distinct value <= 1, should pass
+ expected = spark.createDataFrame(
+ [
+ [
+ "group1",
+ "val1",
+ "Distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1",
+ ],
+ [
+ "group1",
+ "val2",
+ "Distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1",
+ ],
+ ["group2", "val3", None],
+ ["group2", "val3", None],
+ ],
+ "a: string, b: string, b_count_distinct_group_by_a_greater_than_limit: string",
+ )
+
+ assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True)
+
+
+def test_is_aggr_with_approx_count_distinct(spark: SparkSession):
+ """Test approx_count_distinct for fast cardinality estimation with group_by."""
+ test_df = spark.createDataFrame(
+ [
+ ["group1", "val1"],
+ ["group1", "val1"], # Same value
+ ["group1", "val2"], # Different value
+ ["group2", "val3"],
+ ["group2", "val3"], # Same value - only 1 distinct
+ ],
+ "a: string, b: string",
+ )
+
+ checks = [
+ # group1 has 2 distinct values, should exceed limit of 1
+ is_aggr_not_greater_than("b", limit=1, aggr_type="approx_count_distinct", group_by=["a"]),
+ ]
+
+ actual = _apply_checks(test_df, checks)
+
+ expected = spark.createDataFrame(
+ [
+ [
+ "group1",
+ "val1",
+ "Approximate distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1",
+ ],
+ [
+ "group1",
+ "val1",
+ "Approximate distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1",
+ ],
+ [
+ "group1",
+ "val2",
+ "Approximate distinct count value 2 in column 'b' per group of columns 'a' is greater than limit: 1",
+ ],
+ ["group2", "val3", None],
+ ["group2", "val3", None],
+ ],
+ "a: string, b: string, b_approx_count_distinct_group_by_a_greater_than_limit: string",
+ )
+
+ assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True)
+
+
+def test_is_aggr_with_aggr_params_generic(spark: SparkSession):
+ """Test aggr_params passed through to generic aggregate function (not percentile)."""
+ test_df = spark.createDataFrame(
+ [
+ ["group1", "val1"],
+ ["group1", "val1"],
+ ["group1", "val2"],
+ ["group1", "val3"],
+ ["group2", "valA"],
+ ["group2", "valA"],
+ ],
+ "a: string, b: string",
+ )
+
+ # Test approx_count_distinct with rsd (relative standard deviation) parameter
+ # This exercises the generic aggr_params pass-through in check_funcs.py
+ checks = [
+ is_aggr_not_greater_than(
+ "b",
+ limit=5,
+ aggr_type="approx_count_distinct",
+ aggr_params={"rsd": 0.01}, # More accurate approximation (1% relative error)
+ group_by=["a"],
+ ),
+ ]
+
+ actual = _apply_checks(test_df, checks)
+
+ # All rows should pass (group1 has ~3 distinct, group2 has ~1 distinct, both <= 5)
+ expected = spark.createDataFrame(
+ [
+ ["group1", "val1", None],
+ ["group1", "val1", None],
+ ["group1", "val2", None],
+ ["group1", "val3", None],
+ ["group2", "valA", None],
+ ["group2", "valA", None],
+ ],
+ "a: string, b: string, b_approx_count_distinct_group_by_a_greater_than_limit: string",
+ )
+
+ assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True)
+
+
+def test_is_aggr_with_statistical_functions(spark: SparkSession):
+ """Test statistical aggregate functions: stddev, variance, median."""
+ test_df = spark.createDataFrame(
+ [
+ ["A", 10.0],
+ ["A", 20.0],
+ ["A", 30.0],
+ ["B", 5.0],
+ ["B", 5.0],
+ ["B", 5.0],
+ ],
+ "a: string, b: double",
+ )
+
+ checks = [
+ # Standard deviation check (Spark's stddev is the sample stddev: group A = 10.0, group B = 0.0, both <= 10.0)
+ is_aggr_not_greater_than("b", limit=10.0, aggr_type="stddev", group_by=["a"]),
+ # Variance check (sample variance: group A = 100.0, group B = 0.0, both <= 100.0)
+ is_aggr_not_greater_than("b", limit=100.0, aggr_type="variance", group_by=["a"]),
+ # Median check (dataset-level median 7.5, passes < 25.0)
+ is_aggr_not_greater_than("b", limit=25.0, aggr_type="median"),
+ ]
+
+ actual = _apply_checks(test_df, checks)
+
+ # All checks should pass
+ expected = spark.createDataFrame(
+ [
+ ["A", 10.0, None, None, None],
+ ["A", 20.0, None, None, None],
+ ["A", 30.0, None, None, None],
+ ["B", 5.0, None, None, None],
+ ["B", 5.0, None, None, None],
+ ["B", 5.0, None, None, None],
+ ],
+ "a: string, b: double, b_stddev_group_by_a_greater_than_limit: string, "
+ "b_variance_group_by_a_greater_than_limit: string, b_median_greater_than_limit: string",
+ )
+
+ assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True)
+
+
+def test_is_aggr_with_mode_function(spark: SparkSession):
+ """Test mode aggregate function for detecting most common numeric value."""
+ test_df = spark.createDataFrame(
+ [
+ # groupA: most common error code is 401 (appears 3 times)
+ ["groupA", 401],
+ ["groupA", 401],
+ ["groupA", 401],
+ ["groupA", 500],
+ # groupB: most common error code is 200 (appears 2 times)
+ ["groupB", 200],
+ ["groupB", 200],
+ ["groupB", 404],
+ ],
+ "a: string, b: int",
+ )
+
+ # Check that the most common error code value doesn't exceed threshold
+ checks = [
+ is_aggr_not_greater_than("b", limit=400, aggr_type="mode", group_by=["a"]),
+ ]
+
+ actual = _apply_checks(test_df, checks)
+
+ # groupA should fail (mode=401 > limit=400), groupB should pass (mode=200 <= limit=400)
+ expected = spark.createDataFrame(
+ [
+ ["groupA", 401, "Mode value 401 in column 'b' per group of columns 'a' is greater than limit: 400"],
+ ["groupA", 401, "Mode value 401 in column 'b' per group of columns 'a' is greater than limit: 400"],
+ ["groupA", 401, "Mode value 401 in column 'b' per group of columns 'a' is greater than limit: 400"],
+ ["groupA", 500, "Mode value 401 in column 'b' per group of columns 'a' is greater than limit: 400"],
+ ["groupB", 200, None],
+ ["groupB", 200, None],
+ ["groupB", 404, None],
+ ],
+ "a: string, b: int, b_mode_group_by_a_greater_than_limit: string",
+ )
+
+ assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True)
+
+
+def test_is_aggr_with_percentile_functions(spark: SparkSession):
+ """Test percentile and approx_percentile with aggr_params."""
+ test_df = spark.createDataFrame(
+ [(f"row{i}", float(i)) for i in range(1, 101)],
+ "a: string, b: double",
+ )
+
+ checks = [
+ # P95 should be around 95 (dataset-level), all pass (< 100.0)
+ is_aggr_not_greater_than("b", limit=100.0, aggr_type="percentile", aggr_params={"percentile": 0.95}),
+ # P99 with approx_percentile should be around 99 (dataset-level), all pass (< 100.0)
+ is_aggr_not_greater_than("b", limit=100.0, aggr_type="approx_percentile", aggr_params={"percentile": 0.99}),
+ # P50 (median) should be around 50.5 (dataset-level), all pass (>= 40.0)
+ is_aggr_not_less_than(
+ "b", limit=40.0, aggr_type="approx_percentile", aggr_params={"percentile": 0.50, "accuracy": 100}
+ ),
+ ]
+
+ actual = _apply_checks(test_df, checks)
+
+ # All checks should pass (P95 ~95 < 100, P99 ~99 < 100, P50 ~50 >= 40)
+ expected = spark.createDataFrame(
+ [(f"row{i}", float(i), None, None, None) for i in range(1, 101)],
+ "a: string, b: double, b_percentile_greater_than_limit: string, "
+ "b_approx_percentile_greater_than_limit: string, b_approx_percentile_less_than_limit: string",
+ )
+
+ assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True)
+
+
+def test_is_aggr_message_includes_params(spark: SparkSession):
+ """Test that violation messages include aggr_params for differentiation."""
+ test_df = spark.createDataFrame(
+ [("group1", 100.0), ("group1", 200.0), ("group1", 300.0)],
+ "a: string, b: double",
+ )
+
+ # Create a check that will fail - P50 is 200, limit is 100
+ checks = [
+ is_aggr_not_greater_than(
+ "b", limit=100.0, aggr_type="percentile", aggr_params={"percentile": 0.5}, group_by=["a"]
+ ),
+ ]
+
+ actual = _apply_checks(test_df, checks)
+
+ # Verify message includes the percentile parameter
+ # Column name includes group_by info: b_percentile_group_by_a_greater_than_limit
+ messages = [row.b_percentile_group_by_a_greater_than_limit for row in actual.collect()]
+ assert all(msg is not None for msg in messages), "All rows should have violation messages"
+ assert all("percentile=0.5" in msg for msg in messages), "Message should include percentile parameter"
+ assert all("Percentile (percentile=0.5) value" in msg for msg in messages), "Message should have correct format"
+
+
+def test_is_aggr_percentile_missing_params(spark: SparkSession):
+ """Test that percentile functions require percentile parameter."""
+ test_df = spark.createDataFrame([(1, 10.0)], "id: int, value: double")
+
+ # Should raise error when percentile param is missing
+ with pytest.raises(MissingParameterError, match="percentile.*requires aggr_params"):
+ _, apply_fn = is_aggr_not_greater_than("value", limit=100.0, aggr_type="percentile")
+ apply_fn(test_df)
+
+
+def test_is_aggr_percentile_invalid_params_caught_by_spark(spark: SparkSession):
+ """Test that invalid aggr_params are caught by Spark at runtime.
+
+ This verifies our permissive strategy: we don't validate extra parameters,
+ but Spark will raise an error for truly invalid ones.
+ """
+ test_df = spark.createDataFrame([(1, 10.0), (2, 20.0)], "id: int, value: double")
+
+ # Pass an invalid parameter type (string instead of float for percentile)
+ # Spark should raise an error when the DataFrame is evaluated
+ with pytest.raises(Exception): # Spark will raise AnalysisException or similar
+ _, apply_fn = is_aggr_not_greater_than(
+ "value",
+ limit=100.0,
+ aggr_type="approx_percentile",
+ aggr_params={"percentile": "invalid_string"}, # Invalid: should be float
+ )
+ result_df = apply_fn(test_df)
+ result_df.collect() # Force evaluation to trigger Spark error
+
+
+def test_is_aggr_with_invalid_parameter_name(spark: SparkSession):
+ """Test that invalid parameter names in aggr_params raise clear errors."""
+ test_df = spark.createDataFrame([(1, 10.0), (2, 20.0)], "id: int, value: double")
+
+ # Pass an invalid parameter name - should raise InvalidParameterError with context
+ with pytest.raises(InvalidParameterError, match="Failed to build 'approx_percentile' expression"):
+ _, apply_fn = is_aggr_not_greater_than(
+ "value",
+ limit=100.0,
+ aggr_type="approx_percentile",
+ aggr_params={"percentile": 0.95, "invalid_param": 1}, # Invalid param name
+ )
+ apply_fn(test_df)
+
+
+def test_is_aggr_with_invalid_aggregate_function(spark: SparkSession):
+ """Test that invalid aggregate function names raise clear errors."""
+ test_df = spark.createDataFrame([(1, 10)], "id: int, value: int")
+
+ # Non-existent function should raise error
+ with pytest.raises(InvalidParameterError, match="not found in pyspark.sql.functions"):
+ _, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="nonexistent_function")
+ apply_fn(test_df)
+
+
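
The construction-time warning and runtime failures asserted in these tests suggest the aggregate name is resolved lazily from pyspark.sql.functions; a hedged sketch of that kind of lookup (the helper name and exception type here are illustrative, not the library's internals):

from pyspark.sql import Column, functions as F

def resolve_aggr_func(aggr_type: str):
    # Look the name up only when the check is applied, so an unknown name can warn at
    # definition time and fail with a clear error at runtime.
    func = getattr(F, aggr_type, None)
    if func is None:
        raise ValueError(f"Function '{aggr_type}' not found in pyspark.sql.functions")
    return func

stddev_expr: Column = resolve_aggr_func("stddev")("value")   # built-in aggregate: resolves fine
# resolve_aggr_func("nonexistent_function")                  # would raise: not found in pyspark.sql.functions
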
+def test_is_aggr_with_collect_list_fails(spark: SparkSession):
+ """Test that collect_list (returns array) fails with clear error message - no group_by."""
+ test_df = spark.createDataFrame([(1, 10), (2, 20)], "id: int, value: int")
+
+ # collect_list returns array which cannot be compared to numeric limit
+ with pytest.raises(InvalidParameterError, match="array.*cannot be compared"):
+ _, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="collect_list")
+ apply_fn(test_df)
+
+
+def test_is_aggr_with_collect_list_fails_with_group_by(spark: SparkSession):
+ """Test that collect_list with group_by also fails with clear error message - bug fix verification."""
+ test_df = spark.createDataFrame(
+ [("A", 10), ("A", 20), ("B", 30)],
+ "category: string, value: int",
+ )
+
+ # This covers the bug fix: collect_list with group_by should fail gracefully rather than with a cryptic Spark error
+ with pytest.raises(InvalidParameterError, match="array.*cannot be compared"):
+ _, apply_fn = is_aggr_not_greater_than("value", limit=100, aggr_type="collect_list", group_by=["category"])
+ apply_fn(test_df)
+
+
+def test_is_aggr_non_curated_aggregate_with_warning(spark: SparkSession):
+ """Test that non-curated (built-in but not in curated list) aggregates work but produce warning."""
+ test_df = spark.createDataFrame(
+ [("A", 10), ("B", 20), ("C", 30)],
+ "a: string, b: int",
+ )
+
+ # Use a valid aggregate that's not in the curated list (e.g., any_value)
+ # any_value returns one arbitrary non-null value (10, 20, or 30 here), all < 100
+ with pytest.warns(UserWarning, match="non-curated.*any_value"):
+ checks = [is_aggr_not_greater_than("b", limit=100, aggr_type="any_value")]
+ actual = _apply_checks(test_df, checks)
+
+ # Should still work - any_value returns an arbitrary value, all should pass (< 100)
+ expected = spark.createDataFrame(
+ [("A", 10, None), ("B", 20, None), ("C", 30, None)],
+ "a: string, b: int, b_any_value_greater_than_limit: string",
+ )
+
+ assert_df_equality(actual, expected, ignore_nullable=True, ignore_row_order=True)
+
+
def test_dataset_compare(spark: SparkSession, set_utc_timezone):
schema = "id1 long, id2 long, name string, dt date, ts timestamp, score float, likes bigint, active boolean"
diff --git a/tests/perf/.benchmarks/baseline.json b/tests/perf/.benchmarks/baseline.json
index b3f449bb9..66e28fffd 100644
--- a/tests/perf/.benchmarks/baseline.json
+++ b/tests/perf/.benchmarks/baseline.json
@@ -1681,6 +1681,111 @@
"iterations": 1
}
},
+ {
+ "group": null,
+ "name": "test_benchmark_is_aggr_approx_count_distinct_with_group_by",
+ "fullname": "tests/perf/test_apply_checks.py::test_benchmark_is_aggr_approx_count_distinct_with_group_by",
+ "params": null,
+ "param": null,
+ "extra_info": {},
+ "options": {
+ "disable_gc": false,
+ "timer": "perf_counter",
+ "min_rounds": 5,
+ "max_time": 1.0,
+ "min_time": 0.000005,
+ "warmup": false
+ },
+ "stats": {
+ "min": 0.22393849300033253,
+ "max": 0.27297041700012414,
+ "mean": 0.2428591086000779,
+ "stddev": 0.01826011249364107,
+ "rounds": 5,
+ "median": 0.23956732599981478,
+ "iqr": 0.01777076375026354,
+ "q1": 0.23241047199996956,
+ "q3": 0.2501812357502331,
+ "iqr_outliers": 0,
+ "stddev_outliers": 2,
+ "outliers": "2;0",
+ "ld15iqr": 0.22393849300033253,
+ "hd15iqr": 0.27297041700012414,
+ "ops": 4.117613729887829,
+ "total": 1.2142955430003894,
+ "iterations": 1
+ }
+ },
+ {
+ "group": null,
+ "name": "test_benchmark_is_aggr_count_distinct_no_group_by",
+ "fullname": "tests/perf/test_apply_checks.py::test_benchmark_is_aggr_count_distinct_no_group_by",
+ "params": null,
+ "param": null,
+ "extra_info": {},
+ "options": {
+ "disable_gc": false,
+ "timer": "perf_counter",
+ "min_rounds": 5,
+ "max_time": 1.0,
+ "min_time": 0.000005,
+ "warmup": false
+ },
+ "stats": {
+ "min": 0.2504249900002833,
+ "max": 0.3206533539996599,
+ "mean": 0.2774389961999987,
+ "stddev": 0.02768560893059817,
+ "rounds": 5,
+ "median": 0.2725611239998216,
+ "iqr": 0.03825604774954172,
+ "q1": 0.25609008650030773,
+ "q3": 0.29434613424984946,
+ "iqr_outliers": 0,
+ "stddev_outliers": 1,
+ "outliers": "1;0",
+ "ld15iqr": 0.2504249900002833,
+ "hd15iqr": 0.3206533539996599,
+ "ops": 3.6043959706339397,
+ "total": 1.3871949809999933,
+ "iterations": 1
+ }
+ },
+ {
+ "group": null,
+ "name": "test_benchmark_is_aggr_count_distinct_with_group_by",
+ "fullname": "tests/perf/test_apply_checks.py::test_benchmark_is_aggr_count_distinct_with_group_by",
+ "params": null,
+ "param": null,
+ "extra_info": {},
+ "options": {
+ "disable_gc": false,
+ "timer": "perf_counter",
+ "min_rounds": 5,
+ "max_time": 1.0,
+ "min_time": 0.000005,
+ "warmup": false
+ },
+ "stats": {
+ "min": 0.2315504329999385,
+ "max": 0.2700657520003915,
+ "mean": 0.24801521500012313,
+ "stddev": 0.01604423432706884,
+ "rounds": 5,
+ "median": 0.2490236530002221,
+ "iqr": 0.026385781999579194,
+ "q1": 0.2330477210002755,
+ "q3": 0.2594335029998547,
+ "iqr_outliers": 0,
+ "stddev_outliers": 2,
+ "outliers": "2;0",
+ "ld15iqr": 0.2315504329999385,
+ "hd15iqr": 0.2700657520003915,
+ "ops": 4.0320106974062195,
+ "total": 1.2400760750006157,
+ "iterations": 1
+ }
+ },
{
"group": null,
"name": "test_benchmark_is_aggr_equal",
diff --git a/tests/perf/test_apply_checks.py b/tests/perf/test_apply_checks.py
index 398461ae1..1a4dba33f 100644
--- a/tests/perf/test_apply_checks.py
+++ b/tests/perf/test_apply_checks.py
@@ -1592,3 +1592,62 @@ def test_benchmark_foreach_has_valid_schema(benchmark, ws, generated_string_df):
benchmark.group += f"_{n_rows}_rows_{len(columns)}_columns"
actual_count = benchmark(lambda: dq_engine.apply_checks(df, checks).count())
assert actual_count == EXPECTED_ROWS
+
+
+def test_benchmark_is_aggr_count_distinct_with_group_by(benchmark, ws, generated_df):
+ """Benchmark count_distinct with group_by (uses two-stage aggregation: groupBy + join)."""
+ dq_engine = DQEngine(workspace_client=ws, extra_params=EXTRA_PARAMS)
+ checks = [
+ DQDatasetRule(
+ criticality="warn",
+ check_func=check_funcs.is_aggr_not_greater_than,
+ column="col2",
+ check_func_kwargs={
+ "aggr_type": "count_distinct",
+ "group_by": ["col3"],
+ "limit": 1000000,
+ },
+ )
+ ]
+ checked = dq_engine.apply_checks(generated_df, checks)
+ actual_count = benchmark(lambda: checked.count())
+ assert actual_count == EXPECTED_ROWS
+
+
+def test_benchmark_is_aggr_approx_count_distinct_with_group_by(benchmark, ws, generated_df):
+ """Benchmark approx_count_distinct with group_by (uses window functions - should be faster)."""
+ dq_engine = DQEngine(workspace_client=ws, extra_params=EXTRA_PARAMS)
+ checks = [
+ DQDatasetRule(
+ criticality="warn",
+ check_func=check_funcs.is_aggr_not_greater_than,
+ column="col2",
+ check_func_kwargs={
+ "aggr_type": "approx_count_distinct",
+ "group_by": ["col3"],
+ "limit": 1000000,
+ },
+ )
+ ]
+ checked = dq_engine.apply_checks(generated_df, checks)
+ actual_count = benchmark(lambda: checked.count())
+ assert actual_count == EXPECTED_ROWS
+
+
+def test_benchmark_is_aggr_count_distinct_no_group_by(benchmark, ws, generated_df):
+ """Benchmark count_distinct without group_by (baseline - uses standard aggregation)."""
+ dq_engine = DQEngine(workspace_client=ws, extra_params=EXTRA_PARAMS)
+ checks = [
+ DQDatasetRule(
+ criticality="warn",
+ check_func=check_funcs.is_aggr_not_greater_than,
+ column="col2",
+ check_func_kwargs={
+ "aggr_type": "count_distinct",
+ "limit": 1000000,
+ },
+ )
+ ]
+ checked = dq_engine.apply_checks(generated_df, checks)
+ actual_count = benchmark(lambda: checked.count())
+ assert actual_count == EXPECTED_ROWS
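
A minimal sketch of the window-based strategy that the approx_count_distinct benchmark's docstring refers to, contrasted with the groupBy-plus-join path used for exact count_distinct; illustrative only, not the library's implementation.

from pyspark.sql import DataFrame, Window, functions as F

def approx_distinct_per_group(df: DataFrame, value_col: str, group_col: str) -> DataFrame:
    # approx_count_distinct can be evaluated over a window, so each row receives its
    # group's estimate in a single pass, with no join back to the original rows.
    return df.withColumn(
        "metric", F.approx_count_distinct(value_col).over(Window.partitionBy(group_col))
    )
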
diff --git a/tests/unit/test_row_checks.py b/tests/unit/test_row_checks.py
index 2fcaada2a..ef8b80d3f 100644
--- a/tests/unit/test_row_checks.py
+++ b/tests/unit/test_row_checks.py
@@ -51,8 +51,14 @@ def test_col_is_in_list_missing_allowed_list():
def test_incorrect_aggr_type():
- with pytest.raises(InvalidParameterError, match="Unsupported aggregation type"):
- is_aggr_not_greater_than("a", 1, aggr_type="invalid")
+ # With new implementation, invalid aggr_type triggers a warning (not immediate error)
+ # The error occurs at runtime when the apply function is called
+ with pytest.warns(UserWarning, match="non-curated.*invalid"):
+ condition, apply_fn = is_aggr_not_greater_than("a", 1, aggr_type="invalid")
+
+ # Function should return successfully (error will happen at runtime when applied to DataFrame)
+ assert condition is not None
+ assert apply_fn is not None
def test_col_is_ipv4_address_in_cidr_missing_cidr_block():