Skip to content

Commit 7c2253b

Browse files
STEFANOVIVASmwojtyczkaCopilot
authored
Outlier detection numerical values (#944)
## Changes Added new check function for outlier detection of numeric values. The checkuses a statistical method called MAD (Median Absolute Deviation) to check whether the specified column's values are within the calculated limits. The lower limit is calculated as median - 3.5 * MAD and the upper limit as median + 3.5 * MAD. Values outside these limits are considered outliers. ### Linked issues <!-- DOC: Link issue with a keyword: close, closes, closed, fix, fixes, fixed, resolve, resolves, resolved. See https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword --> Resolves #359 ### Tests <!-- How is this tested? Please see the checklist below and also describe any other relevant tests --> - [x] manually tested - [x] added unit tests - [x] added integration tests - [ ] added end-to-end tests - [x] added performance tests --------- Co-authored-by: Marcin Wojtyczka <marcin.wojtyczka@databricks.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 0055e66 commit 7c2253b

File tree

7 files changed

+480
-1
lines changed

7 files changed

+480
-1
lines changed

docs/dqx/docs/reference/quality_checks.mdx

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ You can also define your own custom checks in Python (see [Creating custom check
6969
| `has_dimension` | Checks whether the values in the input column are geometries of the specified dimension (2D projected dimension). This function requires Databricks serverless compute or runtime >= 17.1. | `column`: column to check (can be a string column name or a column expression); `dimension`: dimension to check |
7070
| `has_x_coordinate_between` | Checks whether the values in the input column are geometries with x coordinate between the provided boundaries. This function requires Databricks serverless compute or runtime >= 17.1. | `column`: column to check (can be a string column name or a column expression); `min_value`: minimum value; `max_value`: maximum value |
7171
| `has_y_coordinate_between` | Checks whether the values in the input column are geometries with y coordinate between the provided boundaries. This function requires Databricks serverless compute or runtime >= 17.1. | `column`: column to check (can be a string column name or a column expression); `min_value`: minimum value; `max_value`: maximum value | | `column`: column to check (can be a string column name or a column expression); `min_value`: minimum value; `max_value`: maximum value |
72+
7273
</details>
7374

7475
<Admonition type="warning" title="Applicability">
@@ -1396,6 +1397,7 @@ You can also define your own custom dataset-level checks (see [Creating custom c
13961397
| `compare_datasets` | Compares two DataFrames at both row and column levels, providing detailed information about differences, including new or missing rows and column-level changes. Only columns present in both the source and reference DataFrames are compared. Use with caution if `check_missing_records` is enabled, as this may increase the number of rows in the output beyond the original input DataFrame. The comparison does not support Map types (any column comparison on map type is skipped automatically). Comparing datasets is valuable for validating data during migrations, detecting drift, performing regression testing, or verifying synchronization between source and target systems. | `columns`: columns to use for row matching with the reference DataFrame (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'), if not having primary keys or wanting to match against all columns you can pass 'df.columns'; `ref_columns`: list of columns in the reference DataFrame or Table to row match against the source DataFrame (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'), if not having primary keys or wanting to match against all columns you can pass 'ref_df.columns'; note that `columns` are matched with `ref_columns` by position, so the order of the provided columns in both lists must be exactly aligned; `exclude_columns`: (optional) list of columns to exclude from the value comparison but not from row matching (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'); the `exclude_columns` field does not alter the list of columns used to determine row matches (columns), it only controls which columns are skipped during the value comparison; `ref_df_name`: (optional) name of the reference DataFrame (dictionary of DataFrames can be passed when applying checks); `ref_table`: (optional) fully qualified reference table name; either `ref_df_name` or `ref_table` must be provided but never both; the number of passed `columns` and `ref_columns` must match and keys are checks in the given order; `check_missing_records`: perform a FULL OUTER JOIN to identify records that are missing from source or reference DataFrames, default is False; use with caution as this may increase the number of rows in the output, as unmatched rows from both sides are included; `null_safe_row_matching`: (optional) treat NULLs as equal when matching rows using `columns` and `ref_columns` (default: True); `null_safe_column_value_matching`: (optional) treat NULLs as equal when comparing column values (default: True) |
13971398
| `is_data_fresh_per_time_window` | Freshness check that validates whether at least X records arrive within every Y-minute time window. | `column`: timestamp column (can be a string column name or a column expression); `window_minutes`: time window in minutes to check for data arrival; `min_records_per_window`: minimum number of records expected per time window; `lookback_windows`: (optional) number of time windows to look back from `curr_timestamp`, it filters records to include only those within the specified number of time windows from `curr_timestamp` (if no lookback is provided, the check is applied to the entire dataset); `curr_timestamp`: (optional) current timestamp column (if not provided, current_timestamp() function is used) |
13981399
| `has_valid_schema` | Schema check that validates whether the DataFrame schema matches an expected schema. In non-strict mode, validates that all expected columns exist with compatible types (allows extra columns). In strict mode, validates exact schema match (same columns, same order, same types) for all columns by default or for all columns specified in `columns`. This check is applied at the dataset level and reports schema violations for all rows in the DataFrame when incompatibilities are detected. | `expected_schema`: expected schema as a DDL string (e.g., "id INT, name STRING") or StructType object; `columns`: (optional) list of columns to validate (if not provided, all columns are considered); `strict`: (optional) whether to perform strict schema validation (default: False) - False: validates that all expected columns exist with compatible types, True: validates exact schema match |
1400+
| `has_no_outliers` | Checks whether the values in the input column contain any outliers. This function implements a median absolute deviation (MAD) algorithm to find outliers. | `column`: column of type numeric to check (can be a string column name or a column expression); |
13991401

14001402
**Compare datasets check**
14011403

@@ -1735,6 +1737,13 @@ Complex data types are supported as well.
17351737
for_each_column: # apply the check for each column in the list
17361738
- [col3, col5]
17371739
- [col1]
1740+
1741+
# has_no_outliers check
1742+
- criticality: error
1743+
check:
1744+
function: has_no_outliers
1745+
arguments:
1746+
column: col1
17381747
```
17391748
17401749
**Providing Reference DataFrames programmatically**
@@ -2103,6 +2112,13 @@ checks = [
21032112
"strict": True,
21042113
},
21052114
),
2115+
2116+
# has_no_outliers check
2117+
DQDatasetRule(
2118+
criticality="error",
2119+
check_func=check_funcs.has_no_outliers,
2120+
column="col1" # or as expr: F.col("col1")
2121+
),
21062122

21072123
# has_valid_schema check with specific columns, expected schema defined using StructType
21082124
DQDatasetRule(

src/databricks/labs/dqx/check_funcs.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,85 @@ def is_data_fresh(
973973
)
974974

975975

976+
@register_rule("dataset")
977+
def has_no_outliers(column: str | Column, row_filter: str | None = None) -> tuple[Column, Callable]:
978+
"""
979+
Build an outlier check condition and closure for dataset-level validation.
980+
981+
This function uses a statistical method called MAD (Median Absolute Deviation) to check whether
982+
the specified column's values are within the calculated limits. The lower limit is calculated as
983+
median - 3.5 * MAD and the upper limit as median + 3.5 * MAD. Values outside these limits are considered outliers.
984+
985+
986+
Args:
987+
column: column to check; can be a string column name or a column expression
988+
row_filter: Optional SQL expression for filtering rows before checking for outliers.
989+
990+
991+
Returns:
992+
A tuple of:
993+
- A Spark Column representing the condition for outliers violations.
994+
- A closure that applies the outliers check and adds the necessary condition/count columns.
995+
"""
996+
column = F.col(column) if isinstance(column, str) else column
997+
998+
col_str_norm, col_expr_str, col_expr = _get_normalized_column_and_expr(column)
999+
1000+
unique_str = uuid.uuid4().hex # make sure any column added to the dataframe is unique
1001+
condition_col = f"__condition_{col_str_norm}_{unique_str}"
1002+
1003+
def apply(df: DataFrame) -> DataFrame:
1004+
"""
1005+
Apply the outlier detection logic to the DataFrame.
1006+
1007+
Adds columns indicating the median and MAD for the column.
1008+
1009+
Args:
1010+
df: The input DataFrame to validate for outliers.
1011+
1012+
Returns:
1013+
The DataFrame with additional median and MAD columns for outlier detection.
1014+
"""
1015+
column_type = df.schema[col_expr_str].dataType
1016+
if not isinstance(column_type, (types.NumericType)):
1017+
raise InvalidParameterError(
1018+
f"Column '{col_expr_str}' must be of numeric type to perform outlier detection using MAD method, "
1019+
f"but got type '{column_type.simpleString()}' instead."
1020+
)
1021+
filter_condition = F.expr(row_filter) if row_filter else F.lit(True)
1022+
median, mad = _calculate_median_absolute_deviation(df, col_expr_str, row_filter)
1023+
if median is not None and mad is not None:
1024+
median = float(median)
1025+
mad = float(mad)
1026+
# Create outlier condition
1027+
lower_bound = median - (3.5 * mad)
1028+
upper_bound = median + (3.5 * mad)
1029+
lower_bound_expr = _get_limit_expr(lower_bound)
1030+
upper_bound_expr = _get_limit_expr(upper_bound)
1031+
1032+
condition = (col_expr < (lower_bound_expr)) | (col_expr > (upper_bound_expr))
1033+
1034+
# Add outlier detection columns
1035+
result_df = df.withColumn(condition_col, F.when(filter_condition & condition, True).otherwise(False))
1036+
else:
1037+
# If median or mad could not be calculated, no outliers can be detected
1038+
result_df = df.withColumn(condition_col, F.lit(False))
1039+
1040+
return result_df
1041+
1042+
condition = make_condition(
1043+
condition=F.col(condition_col),
1044+
message=F.concat_ws(
1045+
"",
1046+
F.lit("Value '"),
1047+
col_expr.cast("string"),
1048+
F.lit(f"' in Column '{col_expr_str}' is an outlier as per MAD."),
1049+
),
1050+
alias=f"{col_str_norm}_has_outliers",
1051+
)
1052+
return condition, apply
1053+
1054+
9761055
@register_rule("dataset")
9771056
def is_unique(
9781057
columns: list[str | Column],
@@ -2920,3 +2999,36 @@ def _validate_sql_query_params(query: str, merge_columns: list[str]) -> None:
29202999
raise UnsafeSqlQueryError(
29213000
"Provided SQL query is not safe for execution. Please ensure it does not contain any unsafe operations."
29223001
)
3002+
3003+
3004+
def _calculate_median_absolute_deviation(df: DataFrame, column: str, filter_condition: str | None) -> tuple[Any, Any]:
3005+
"""
3006+
Calculate the Median Absolute Deviation (MAD) for a numeric column.
3007+
3008+
The MAD is a robust measure of variability based on the median, calculated as:
3009+
MAD = median(|X_i - median(X)|)
3010+
3011+
This is useful for outlier detection as it is more robust to outliers than
3012+
standard deviation.
3013+
3014+
Args:
3015+
df: PySpark DataFrame
3016+
column: Name of the numeric column to calculate MAD for
3017+
filter_condition: Filter to apply before calculation (optional)
3018+
3019+
Returns:
3020+
The Median and Absolute Deviation values
3021+
"""
3022+
if filter_condition is not None:
3023+
df = df.filter(filter_condition)
3024+
3025+
# Step 1: Calculate the median of the column
3026+
median_value = df.agg(F.percentile_approx(column, 0.5)).collect()[0][0]
3027+
3028+
# Step 2: Calculate absolute deviations from the median
3029+
df_with_deviations = df.select(F.abs(F.col(column) - F.lit(median_value)).alias("absolute_deviation"))
3030+
3031+
# Step 3: Calculate the median of absolute deviations
3032+
mad = df_with_deviations.agg(F.percentile_approx("absolute_deviation", 0.5)).collect()[0][0]
3033+
3034+
return median_value, mad

tests/integration/test_build_rules.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,15 @@ def test_build_quality_rules_from_dataframe(spark):
7979
"arguments": {"column": "test_col", "group_by": ["a"], "limit": 0, "aggr_type": "count"},
8080
},
8181
},
82+
{
83+
"name": "column_has_outliers",
84+
"criticality": "error",
85+
"filter": "test_col > 0",
86+
"check": {
87+
"function": "has_no_outliers",
88+
"arguments": {"column": "c"},
89+
},
90+
},
8291
]
8392

8493
df = deserialize_checks_to_dataframe(spark, test_checks)

0 commit comments

Comments
 (0)