Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/dqx/docs/reference/quality_checks.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ You can also define your own custom checks in Python (see [Creating custom check
| `has_dimension` | Checks whether the values in the input column are geometries of the specified dimension (2D projected dimension). This function requires Databricks serverless compute or runtime >= 17.1. | `column`: column to check (can be a string column name or a column expression); `dimension`: dimension to check |
| `has_x_coordinate_between` | Checks whether the values in the input column are geometries with x coordinate between the provided boundaries. This function requires Databricks serverless compute or runtime >= 17.1. | `column`: column to check (can be a string column name or a column expression); `min_value`: minimum value; `max_value`: maximum value |
| `has_y_coordinate_between` | Checks whether the values in the input column are geometries with y coordinate between the provided boundaries. This function requires Databricks serverless compute or runtime >= 17.1. | `column`: column to check (can be a string column name or a column expression); `min_value`: minimum value; `max_value`: maximum value | | `column`: column to check (can be a string column name or a column expression); `min_value`: minimum value; `max_value`: maximum value |

</details>

<Admonition type="warning" title="Applicability">
Expand Down Expand Up @@ -1396,6 +1397,7 @@ You can also define your own custom dataset-level checks (see [Creating custom c
| `compare_datasets` | Compares two DataFrames at both row and column levels, providing detailed information about differences, including new or missing rows and column-level changes. Only columns present in both the source and reference DataFrames are compared. Use with caution if `check_missing_records` is enabled, as this may increase the number of rows in the output beyond the original input DataFrame. The comparison does not support Map types (any column comparison on map type is skipped automatically). Comparing datasets is valuable for validating data during migrations, detecting drift, performing regression testing, or verifying synchronization between source and target systems. | `columns`: columns to use for row matching with the reference DataFrame (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'), if not having primary keys or wanting to match against all columns you can pass 'df.columns'; `ref_columns`: list of columns in the reference DataFrame or Table to row match against the source DataFrame (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'), if not having primary keys or wanting to match against all columns you can pass 'ref_df.columns'; note that `columns` are matched with `ref_columns` by position, so the order of the provided columns in both lists must be exactly aligned; `exclude_columns`: (optional) list of columns to exclude from the value comparison but not from row matching (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'); the `exclude_columns` field does not alter the list of columns used to determine row matches (columns), it only controls which columns are skipped during the value comparison; `ref_df_name`: (optional) name of the reference DataFrame (dictionary of DataFrames can be passed when applying checks); `ref_table`: (optional) fully qualified reference table name; either `ref_df_name` or `ref_table` must be provided but never both; the number of passed `columns` and `ref_columns` must match and keys are checks in the given order; `check_missing_records`: perform a FULL OUTER JOIN to identify records that are missing from source or reference DataFrames, default is False; use with caution as this may increase the number of rows in the output, as unmatched rows from both sides are included; `null_safe_row_matching`: (optional) treat NULLs as equal when matching rows using `columns` and `ref_columns` (default: True); `null_safe_column_value_matching`: (optional) treat NULLs as equal when comparing column values (default: True) |
| `is_data_fresh_per_time_window` | Freshness check that validates whether at least X records arrive within every Y-minute time window. | `column`: timestamp column (can be a string column name or a column expression); `window_minutes`: time window in minutes to check for data arrival; `min_records_per_window`: minimum number of records expected per time window; `lookback_windows`: (optional) number of time windows to look back from `curr_timestamp`, it filters records to include only those within the specified number of time windows from `curr_timestamp` (if no lookback is provided, the check is applied to the entire dataset); `curr_timestamp`: (optional) current timestamp column (if not provided, current_timestamp() function is used) |
| `has_valid_schema` | Schema check that validates whether the DataFrame schema matches an expected schema. In non-strict mode, validates that all expected columns exist with compatible types (allows extra columns). In strict mode, validates exact schema match (same columns, same order, same types) for all columns by default or for all columns specified in `columns`. This check is applied at the dataset level and reports schema violations for all rows in the DataFrame when incompatibilities are detected. | `expected_schema`: expected schema as a DDL string (e.g., "id INT, name STRING") or StructType object; `columns`: (optional) list of columns to validate (if not provided, all columns are considered); `strict`: (optional) whether to perform strict schema validation (default: False) - False: validates that all expected columns exist with compatible types, True: validates exact schema match |
| `has_no_outliers` | Checks whether the values in the input column contain any outliers. This function implements a median absolute deviation algorithm to find outliers. | `column`: column of type numeric to check (can be a string column name or a column expression); |

**Compare datasets check**

Expand Down Expand Up @@ -1736,6 +1738,13 @@ Complex data types are supported as well.
for_each_column: # apply the check for each column in the list
- [col3, col5]
- [col1]

# has_no_outliers check
- criticality: error
check:
function: has_no_outliers
arguments:
column: col1
```

**Providing Reference DataFrames programmatically**
Expand Down Expand Up @@ -2104,6 +2113,13 @@ checks = [
"strict": True,
},
),

# has_no_outliers check
DQDatasetRule(
criticality="error",
check_func=check_funcs.has_no_outliers,
column="col1" # or as expr: F.col("col1")
),

# has_valid_schema check with specific columns, expected schema defined using StructType
DQDatasetRule(
Expand Down
119 changes: 119 additions & 0 deletions src/databricks/labs/dqx/check_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from enum import Enum
from itertools import zip_longest
import operator as py_operator
from typing import Any
import pandas as pd # type: ignore[import-untyped]
import pyspark.sql.functions as F
from pyspark.sql import types
Expand Down Expand Up @@ -940,6 +941,85 @@ def is_data_fresh(
)


@register_rule("dataset")
def has_no_outliers(column: str | Column, row_filter: str | None = None) -> tuple[Column, Callable]:
"""
Build an outlier check condition and closure for dataset-level validation.

This function uses a statistical method called MAD (Median Absolute Deviation) to checks whether
the specified column's values are between the calculated limits. The upper limit is calculated from
the median + 3.5 * MAD. Values outside these limits are considered outliers.


Args:
column: column to check; can be a string column name or a column expression
row_filter: Optional SQL expression for filtering rows before checking for outliers.


Returns:
A tuple of:
- A Spark Column representing the condition for outliers violations.
- A closure that applies the outliers check and adds the necessary condition/count columns.
"""
column = F.col(column) if isinstance(column, str) else column

col_str_norm, col_expr_str, col_expr = _get_normalized_column_and_expr(column)

unique_str = uuid.uuid4().hex # make sure any column added to the dataframe is unique
condition_col = f"__condition_{col_str_norm}_{unique_str}"

def apply(df: DataFrame) -> DataFrame:
"""
Apply the outlier detection logic to the DataFrame.

Adds columns indicating the median and MAD for the column.

Args:
df: The input DataFrame to validate for outliers.

Returns:
The DataFrame with additional median and MAD columns for outlier detection.
"""
column_type = df.schema[col_expr_str].dataType
if not isinstance(column_type, (types.NumericType)):
raise InvalidParameterError(
f"Column '{col_expr_str}' must be of numeric type to perform outlier detection using MAD method, "
f"but got type '{column_type.simpleString()}' instead."
)
filter_condition = F.expr(row_filter) if row_filter else F.lit(True)
median, mad = _calculate_median_absolute_deviation(df, col_expr_str, row_filter)
if median is not None and mad is not None:
median = float(median)
mad = float(mad)
# Create outlier condition
lower_bound = median - (3.5 * mad)
upper_bound = median + (3.5 * mad)
lower_bound_expr = _get_limit_expr(lower_bound)
upper_bound_expr = _get_limit_expr(upper_bound)

condition = (col_expr < (lower_bound_expr)) | (col_expr > (upper_bound_expr))

# Add outlier detection columns
result_df = df.withColumn(condition_col, F.when(filter_condition & condition, True).otherwise(False))
else:
# If median or mad could not be calculated, no outliers can be detected
result_df = df.withColumn(condition_col, F.lit(False))

return result_df

condition = make_condition(
condition=F.col(condition_col),
message=F.concat_ws(
"",
F.lit("Value '"),
col_expr.cast("string"),
F.lit(f"' in Column '{col_expr_str}' is an outlier as per MAD."),
),
alias=f"{col_str_norm}_has_outliers",
)
return condition, apply


@register_rule("dataset")
def is_unique(
columns: list[str | Column],
Expand Down Expand Up @@ -2707,3 +2787,42 @@ def _validate_sql_query_params(query: str, merge_columns: list[str]) -> None:
raise UnsafeSqlQueryError(
"Provided SQL query is not safe for execution. Please ensure it does not contain any unsafe operations."
)


def _calculate_median_absolute_deviation(df: DataFrame, column: str, filter_condition: str | None) -> tuple[Any, Any]:
"""
Calculate the Median Absolute Deviation (MAD) for a numeric column.

The MAD is a robust measure of variability based on the median, calculated as:
MAD = median(|X_i - median(X)|)

This is useful for outlier detection as it is more robust to outliers than
standard deviation.

Args:
df: PySpark DataFrame
column: Name of the numeric column to calculate MAD for
filter_condition: Filter to apply before calculation (optional)

Returns:
The Median and Absolute Deviation values


Example:
>>> df = spark.createDataFrame([(1.0,), (2.0,), (3.0,), (4.0,), (5.0,)], ["values"])
>>> mad = _calculate_median_absolute_deviation(df, "values")
>>> print(f"MAD: {mad}")
"""
if filter_condition is not None:
df = df.filter(filter_condition)

# Step 1: Calculate the median of the column
median_value = df.agg(F.percentile_approx(column, 0.5)).collect()[0][0]

# Step 2: Calculate absolute deviations from the median
df_with_deviations = df.select(F.abs(F.col(column) - F.lit(median_value)).alias("absolute_deviation"))

# Step 3: Calculate the median of absolute deviations
mad = df_with_deviations.agg(F.percentile_approx("absolute_deviation", 0.5)).collect()[0][0]

return median_value, mad
9 changes: 9 additions & 0 deletions tests/integration/test_build_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ def test_build_quality_rules_from_dataframe(spark):
"arguments": {"column": "test_col", "group_by": ["a"], "limit": 0, "aggr_type": "count"},
},
},
{
"name": "column_has_outliers",
"criticality": "error",
"filter": "test_col > 0",
"check": {
"function": "has_no_outliers",
"arguments": {"column": "c"},
},
},
]

df = deserialize_checks_to_dataframe(spark, test_checks)
Expand Down
Loading
Loading