diff --git a/docs/dqx/docs/reference/quality_checks.mdx b/docs/dqx/docs/reference/quality_checks.mdx
index 1ef154fc1..19851eecb 100644
--- a/docs/dqx/docs/reference/quality_checks.mdx
+++ b/docs/dqx/docs/reference/quality_checks.mdx
@@ -69,6 +69,7 @@ You can also define your own custom checks in Python (see [Creating custom check
 | `has_dimension` | Checks whether the values in the input column are geometries of the specified dimension (2D projected dimension). This function requires Databricks serverless compute or runtime >= 17.1. | `column`: column to check (can be a string column name or a column expression); `dimension`: dimension to check |
 | `has_x_coordinate_between` | Checks whether the values in the input column are geometries with x coordinate between the provided boundaries. This function requires Databricks serverless compute or runtime >= 17.1. | `column`: column to check (can be a string column name or a column expression); `min_value`: minimum value; `max_value`: maximum value |
 | `has_y_coordinate_between` | Checks whether the values in the input column are geometries with y coordinate between the provided boundaries. This function requires Databricks serverless compute or runtime >= 17.1. | `column`: column to check (can be a string column name or a column expression); `min_value`: minimum value; `max_value`: maximum value |
+
@@ -1396,6 +1397,7 @@ You can also define your own custom dataset-level checks (see [Creating custom c
 | `compare_datasets` | Compares two DataFrames at both row and column levels, providing detailed information about differences, including new or missing rows and column-level changes. Only columns present in both the source and reference DataFrames are compared. Use with caution if `check_missing_records` is enabled, as this may increase the number of rows in the output beyond the original input DataFrame. The comparison does not support Map types (any column comparison on map type is skipped automatically). Comparing datasets is valuable for validating data during migrations, detecting drift, performing regression testing, or verifying synchronization between source and target systems.
| `columns`: columns to use for row matching with the reference DataFrame (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'), if not having primary keys or wanting to match against all columns you can pass 'df.columns'; `ref_columns`: list of columns in the reference DataFrame or Table to row match against the source DataFrame (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'), if not having primary keys or wanting to match against all columns you can pass 'ref_df.columns'; note that `columns` are matched with `ref_columns` by position, so the order of the provided columns in both lists must be exactly aligned; `exclude_columns`: (optional) list of columns to exclude from the value comparison but not from row matching (can be a list of string column names or column expressions, but only simple column expressions are allowed such as 'F.col("col1")'); the `exclude_columns` field does not alter the list of columns used to determine row matches (columns), it only controls which columns are skipped during the value comparison; `ref_df_name`: (optional) name of the reference DataFrame (dictionary of DataFrames can be passed when applying checks); `ref_table`: (optional) fully qualified reference table name; either `ref_df_name` or `ref_table` must be provided but never both; the number of passed `columns` and `ref_columns` must match and keys are checks in the given order; `check_missing_records`: perform a FULL OUTER JOIN to identify records that are missing from source or reference DataFrames, default is False; use with caution as this may increase the number of rows in the output, as unmatched rows from both sides are included; `null_safe_row_matching`: (optional) treat NULLs as equal when matching rows using `columns` and `ref_columns` (default: True); `null_safe_column_value_matching`: (optional) treat NULLs as equal when comparing column values (default: True) | | `is_data_fresh_per_time_window` | Freshness check that validates whether at least X records arrive within every Y-minute time window. | `column`: timestamp column (can be a string column name or a column expression); `window_minutes`: time window in minutes to check for data arrival; `min_records_per_window`: minimum number of records expected per time window; `lookback_windows`: (optional) number of time windows to look back from `curr_timestamp`, it filters records to include only those within the specified number of time windows from `curr_timestamp` (if no lookback is provided, the check is applied to the entire dataset); `curr_timestamp`: (optional) current timestamp column (if not provided, current_timestamp() function is used) | | `has_valid_schema` | Schema check that validates whether the DataFrame schema matches an expected schema. In non-strict mode, validates that all expected columns exist with compatible types (allows extra columns). In strict mode, validates exact schema match (same columns, same order, same types) for all columns by default or for all columns specified in `columns`. This check is applied at the dataset level and reports schema violations for all rows in the DataFrame when incompatibilities are detected. 
| `expected_schema`: expected schema as a DDL string (e.g., "id INT, name STRING") or StructType object; `columns`: (optional) list of columns to validate (if not provided, all columns are considered); `strict`: (optional) whether to perform strict schema validation (default: False) - False: validates that all expected columns exist with compatible types, True: validates exact schema match |
+| `has_no_outliers` | Checks whether the values in the input column contain outliers. Outliers are detected using the median absolute deviation (MAD): values outside the range median - 3.5 * MAD to median + 3.5 * MAD are flagged. | `column`: numeric column to check (can be a string column name or a column expression) |

**Compare datasets check**

@@ -1735,6 +1737,13 @@ Complex data types are supported as well.
     for_each_column: # apply the check for each column in the list
     - [col3, col5]
     - [col1]
+
+# has_no_outliers check
+- criticality: error
+  check:
+    function: has_no_outliers
+    arguments:
+      column: col1
 ```

**Providing Reference DataFrames programmatically**

@@ -2103,6 +2112,13 @@ checks = [
             "strict": True,
         },
     ),
+
+    # has_no_outliers check
+    DQDatasetRule(
+        criticality="error",
+        check_func=check_funcs.has_no_outliers,
+        column="col1",  # or as expr: F.col("col1")
+    ),
     # has_valid_schema check with specific columns, expected schema defined using StructType
     DQDatasetRule(
diff --git a/src/databricks/labs/dqx/check_funcs.py b/src/databricks/labs/dqx/check_funcs.py
index 3cf7f4eab..9a968370d 100644
--- a/src/databricks/labs/dqx/check_funcs.py
+++ b/src/databricks/labs/dqx/check_funcs.py
@@ -973,6 +973,85 @@ def is_data_fresh(
     )
+
+@register_rule("dataset")
+def has_no_outliers(column: str | Column, row_filter: str | None = None) -> tuple[Column, Callable]:
+    """
+    Build an outlier check condition and closure for dataset-level validation.
+
+    This function uses a statistical method called MAD (Median Absolute Deviation) to check whether
+    the specified column's values are within the calculated limits. The lower limit is calculated as
+    median - 3.5 * MAD and the upper limit as median + 3.5 * MAD. Values outside these limits are considered outliers.
+
+    Args:
+        column: column to check; can be a string column name or a column expression
+        row_filter: Optional SQL expression for filtering rows before checking for outliers.
+
+    Returns:
+        A tuple of:
+        - A Spark Column representing the condition for outlier violations.
+        - A closure that applies the outlier check and adds the necessary condition column.
+    """
+    column = F.col(column) if isinstance(column, str) else column
+
+    col_str_norm, col_expr_str, col_expr = _get_normalized_column_and_expr(column)
+
+    unique_str = uuid.uuid4().hex  # make sure any column added to the dataframe is unique
+    condition_col = f"__condition_{col_str_norm}_{unique_str}"
+
+    def apply(df: DataFrame) -> DataFrame:
+        """
+        Apply the outlier detection logic to the DataFrame.
+
+        Adds a boolean condition column flagging values that fall outside the median +/- 3.5 * MAD bounds.
+
+        Args:
+            df: The input DataFrame to validate for outliers.
+
+        Returns:
+            The DataFrame with an additional condition column used for outlier detection.
+        """
+        column_type = df.schema[col_expr_str].dataType
+        if not isinstance(column_type, types.NumericType):
+            raise InvalidParameterError(
+                f"Column '{col_expr_str}' must be of numeric type to perform outlier detection using MAD method, "
+                f"but got type '{column_type.simpleString()}' instead."
+ ) + filter_condition = F.expr(row_filter) if row_filter else F.lit(True) + median, mad = _calculate_median_absolute_deviation(df, col_expr_str, row_filter) + if median is not None and mad is not None: + median = float(median) + mad = float(mad) + # Create outlier condition + lower_bound = median - (3.5 * mad) + upper_bound = median + (3.5 * mad) + lower_bound_expr = _get_limit_expr(lower_bound) + upper_bound_expr = _get_limit_expr(upper_bound) + + condition = (col_expr < (lower_bound_expr)) | (col_expr > (upper_bound_expr)) + + # Add outlier detection columns + result_df = df.withColumn(condition_col, F.when(filter_condition & condition, True).otherwise(False)) + else: + # If median or mad could not be calculated, no outliers can be detected + result_df = df.withColumn(condition_col, F.lit(False)) + + return result_df + + condition = make_condition( + condition=F.col(condition_col), + message=F.concat_ws( + "", + F.lit("Value '"), + col_expr.cast("string"), + F.lit(f"' in Column '{col_expr_str}' is an outlier as per MAD."), + ), + alias=f"{col_str_norm}_has_outliers", + ) + return condition, apply + + @register_rule("dataset") def is_unique( columns: list[str | Column], @@ -2920,3 +2999,36 @@ def _validate_sql_query_params(query: str, merge_columns: list[str]) -> None: raise UnsafeSqlQueryError( "Provided SQL query is not safe for execution. Please ensure it does not contain any unsafe operations." ) + + +def _calculate_median_absolute_deviation(df: DataFrame, column: str, filter_condition: str | None) -> tuple[Any, Any]: + """ + Calculate the Median Absolute Deviation (MAD) for a numeric column. + + The MAD is a robust measure of variability based on the median, calculated as: + MAD = median(|X_i - median(X)|) + + This is useful for outlier detection as it is more robust to outliers than + standard deviation. 
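+
+    Example (illustrative; the numbers mirror the integer unit test added in this change): for the
+    column values [10, 12, 14, 13, 11, 20, 9, 15, 14, 13] the median is 13 and the absolute
+    deviations are [3, 1, 1, 0, 2, 7, 4, 2, 1, 0], giving an approximate MAD of 1. `has_no_outliers`
+    then flags values outside median +/- 3.5 * MAD, i.e. roughly [9.5, 16.5] here, so 9 and 20 are
+    reported as outliers.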
+ + Args: + df: PySpark DataFrame + column: Name of the numeric column to calculate MAD for + filter_condition: Filter to apply before calculation (optional) + + Returns: + The Median and Absolute Deviation values + """ + if filter_condition is not None: + df = df.filter(filter_condition) + + # Step 1: Calculate the median of the column + median_value = df.agg(F.percentile_approx(column, 0.5)).collect()[0][0] + + # Step 2: Calculate absolute deviations from the median + df_with_deviations = df.select(F.abs(F.col(column) - F.lit(median_value)).alias("absolute_deviation")) + + # Step 3: Calculate the median of absolute deviations + mad = df_with_deviations.agg(F.percentile_approx("absolute_deviation", 0.5)).collect()[0][0] + + return median_value, mad diff --git a/tests/integration/test_build_rules.py b/tests/integration/test_build_rules.py index 80537e816..daa6ceedd 100644 --- a/tests/integration/test_build_rules.py +++ b/tests/integration/test_build_rules.py @@ -79,6 +79,15 @@ def test_build_quality_rules_from_dataframe(spark): "arguments": {"column": "test_col", "group_by": ["a"], "limit": 0, "aggr_type": "count"}, }, }, + { + "name": "column_has_outliers", + "criticality": "error", + "filter": "test_col > 0", + "check": { + "function": "has_no_outliers", + "arguments": {"column": "c"}, + }, + }, ] df = deserialize_checks_to_dataframe(spark, test_checks) diff --git a/tests/integration/test_dataset_checks.py b/tests/integration/test_dataset_checks.py index 4944239f6..a325c3fd7 100644 --- a/tests/integration/test_dataset_checks.py +++ b/tests/integration/test_dataset_checks.py @@ -3,8 +3,8 @@ from typing import Any import json import itertools +from decimal import Decimal import pytest - import pyspark.sql.functions as F from chispa.dataframe_comparer import assert_df_equality # type: ignore from pyspark.sql import Column, DataFrame, SparkSession @@ -16,6 +16,7 @@ is_aggr_not_less_than, is_aggr_equal, is_aggr_not_equal, + has_no_outliers, foreign_key, compare_datasets, is_data_fresh_per_time_window, @@ -30,6 +31,279 @@ SCHEMA = "a: string, b: int" +def test_has_no_outliers_int_numeric_types(spark: SparkSession): + test_df = spark.createDataFrame( + [ + [1, 10], + [2, 12], + [3, 14], + [4, 13], + [5, 11], + [6, 20], # outlier + [7, 9], # outlier + [8, 15], + [9, 14], + [10, 13], + ], + "a: int, b: int", + ) + + condition, apply_method = has_no_outliers("b") + actual_apply_df = apply_method(test_df) + actual_condition_df = actual_apply_df.select("a", "b", condition) + + expected_condition_df = spark.createDataFrame( + [ + [1, 10, None], + [2, 12, None], + [3, 14, None], + [4, 13, None], + [5, 11, None], + [6, 20, "Value '20' in Column 'b' is an outlier as per MAD."], + [7, 9, "Value '9' in Column 'b' is an outlier as per MAD."], + [8, 15, None], + [9, 14, None], + [10, 13, None], + ], + "a: int, b: int, b_has_outliers: string", + ) + assert_df_equality(actual_condition_df, expected_condition_df, ignore_nullable=True, ignore_row_order=True) + + +def test_has_no_outliers_float_numeric_types(spark: SparkSession): + test_df = spark.createDataFrame( + [ + [1, 10.0], + [2, 12.0], + [3, 14.0], + [4, 13.0], + [5, 11.0], + [6, 20.0], # outlier + [7, 9.0], # outlier + [8, 15.0], + [9, 14.0], + [10, 13.0], + ], + "a: int, b: float", + ) + + condition, apply_method = has_no_outliers("b") + actual_apply_df = apply_method(test_df) + actual_condition_df = actual_apply_df.select("a", "b", condition) + + expected_condition_df = spark.createDataFrame( + [ + [1, 10.0, None], + [2, 12.0, None], + [3, 14.0, 
None], + [4, 13.0, None], + [5, 11.0, None], + [6, 20.0, "Value '20.0' in Column 'b' is an outlier as per MAD."], + [7, 9.0, "Value '9.0' in Column 'b' is an outlier as per MAD."], + [8, 15.0, None], + [9, 14.0, None], + [10, 13.0, None], + ], + "a: int, b: float, b_has_outliers: string", + ) + assert_df_equality(actual_condition_df, expected_condition_df, ignore_nullable=True, ignore_row_order=True) + + +def test_has_no_outliers_long_numeric_types(spark: SparkSession): + test_df = spark.createDataFrame( + [ + [1, 10], + [2, 12], + [3, 14], + [4, 13], + [5, 11], + [6, 20], # outlier + [7, 9], # outlier + [8, 15], + [9, 14], + [10, 13], + ], + "a: int, b: long", + ) + + condition, apply_method = has_no_outliers("b") + actual_apply_df = apply_method(test_df) + actual_condition_df = actual_apply_df.select("a", "b", condition) + + expected_condition_df = spark.createDataFrame( + [ + [1, 10, None], + [2, 12, None], + [3, 14, None], + [4, 13, None], + [5, 11, None], + [6, 20, "Value '20' in Column 'b' is an outlier as per MAD."], + [7, 9, "Value '9' in Column 'b' is an outlier as per MAD."], + [8, 15, None], + [9, 14, None], + [10, 13, None], + ], + "a: int, b: long, b_has_outliers: string", + ) + assert_df_equality(actual_condition_df, expected_condition_df, ignore_nullable=True, ignore_row_order=True) + + +def test_has_no_outliers_decimal_numeric_types(spark: SparkSession): + test_df = spark.createDataFrame( + [ + [1, Decimal("10.00")], + [2, Decimal("12.00")], + [3, Decimal("14.00")], + [4, Decimal("13.00")], + [5, Decimal("11.00")], + [6, Decimal("20.00")], # outlier + [7, Decimal("9.00")], # outlier + [8, Decimal("15.00")], + [9, Decimal("14.00")], + [10, Decimal("13.00")], + ], + "a: int, b: decimal(10,2)", + ) + + condition, apply_method = has_no_outliers("b") + actual_apply_df = apply_method(test_df) + actual_condition_df = actual_apply_df.select("a", "b", condition) + + expected_condition_df = spark.createDataFrame( + [ + [1, Decimal("10.00"), None], + [2, Decimal("12.00"), None], + [3, Decimal("14.00"), None], + [4, Decimal("13.00"), None], + [5, Decimal("11.00"), None], + [6, Decimal("20.00"), "Value '20.00' in Column 'b' is an outlier as per MAD."], + [7, Decimal("9.00"), "Value '9.00' in Column 'b' is an outlier as per MAD."], + [8, Decimal("15.00"), None], + [9, Decimal("14.00"), None], + [10, Decimal("13.00"), None], + ], + "a: int, b: decimal(10,2), b_has_outliers: string", + ) + assert_df_equality(actual_condition_df, expected_condition_df, ignore_nullable=True, ignore_row_order=True) + + +def test_has_no_outliers_empty_dataframe(spark: SparkSession): + test_df = spark.createDataFrame( + [], + "a: int, b: decimal(10,2)", + ) + + condition, apply_method = has_no_outliers("b") + actual_apply_df = apply_method(test_df) + actual_condition_df = actual_apply_df.select("a", "b", condition) + + expected_condition_df = spark.createDataFrame( + [], + "a: int, b: decimal(10,2), b_has_outliers: string", + ) + assert_df_equality(actual_condition_df, expected_condition_df, ignore_nullable=True, ignore_row_order=True) + + +def test_has_no_outliers_with_row_filter(spark: SparkSession): + test_df = spark.createDataFrame( + [ + [1, 10], + [2, 12], + [3, 5], # outlier + [3, 13], + [3, 11], + [3, 22], # outlier + [7, 9], + [3, 15], + [4, 16], + [10, 25], + ], + "a: int, b: int", + ) + + condition, apply_method = has_no_outliers("b", row_filter="a = 3") + actual_apply_df = apply_method(test_df) + actual_condition_df = actual_apply_df.select("a", "b", condition) + + expected_condition_df = 
spark.createDataFrame(
+        [
+            [1, 10, None],
+            [2, 12, None],
+            [3, 5, "Value '5' in Column 'b' is an outlier as per MAD."],
+            [3, 13, None],
+            [3, 11, None],
+            [3, 22, "Value '22' in Column 'b' is an outlier as per MAD."],  # outlier
+            [7, 9, None],
+            [3, 15, None],
+            [4, 16, None],
+            [10, 25, None],
+        ],
+        "a: int, b: int, b_has_outliers: string",
+    )
+    assert_df_equality(actual_condition_df, expected_condition_df, ignore_nullable=True, ignore_row_order=True)
+
+
+def test_has_no_outliers_with_none_median(spark: SparkSession):
+    test_df = spark.createDataFrame(
+        [
+            [1, None],
+            [2, None],
+            [3, None],
+            [3, None],
+            [3, None],
+            [3, None],
+            [7, None],
+            [3, None],
+            [4, None],
+            [10, None],
+        ],
+        "a: int, b: int",
+    )
+
+    condition, apply_method = has_no_outliers("b")
+    actual_apply_df = apply_method(test_df)
+    actual_condition_df = actual_apply_df.select("a", "b", condition)
+
+    expected_condition_df = spark.createDataFrame(
+        [
+            [1, None, None],
+            [2, None, None],
+            [3, None, None],
+            [3, None, None],
+            [3, None, None],
+            [3, None, None],
+            [7, None, None],
+            [3, None, None],
+            [4, None, None],
+            [10, None, None],
+        ],
+        "a: int, b: int, b_has_outliers: string",
+    )
+    assert_df_equality(actual_condition_df, expected_condition_df, ignore_nullable=True, ignore_row_order=True)
+
+
+def test_has_no_outliers_in_string_columns(spark: SparkSession):
+    test_df = spark.createDataFrame(
+        [
+            ["str1", 10],
+            ["str2", 12],
+            ["str3", 14],
+            ["str4", 13],
+            ["str5", 11],
+            ["str6", 20],
+        ],
+        "a: string, b: int",
+    )
+    apply_method = has_no_outliers("a")[1]
+
+    with pytest.raises(
+        InvalidParameterError,
+        match="Column 'a' must be of numeric type to perform outlier detection using MAD method, "
+        "but got type 'string' instead.",
+    ):
+        apply_method(test_df)
+
+
 def test_is_unique(spark: SparkSession):
     test_df = spark.createDataFrame(
         [
diff --git a/tests/perf/test_apply_checks.py b/tests/perf/test_apply_checks.py
index 1a4dba33f..4785d8f3f 100644
@@ -1217,6 +1217,42 @@ def test_benchmark_foreach_is_data_fresh_per_time_window(benchmark, ws, generate
     assert actual_count == EXPECTED_ROWS
+
+def test_benchmark_has_no_outliers(benchmark, ws, generated_df):
+    dq_engine = DQEngine(workspace_client=ws, extra_params=EXTRA_PARAMS)
+    checks = [
+        DQDatasetRule(
+            criticality="warn",
+            check_func=check_funcs.has_no_outliers,
+            column="col2",
+        )
+    ]
+    checked = dq_engine.apply_checks(generated_df, checks)
+    actual_count = benchmark(lambda: checked.count())
+    assert actual_count == EXPECTED_ROWS
+
+
+@pytest.mark.parametrize(
+    "generated_integer_df",
+    [{"n_rows": DEFAULT_ROWS, "n_columns": 5}],
+    indirect=True,
+    ids=lambda param: f"n_rows_{param['n_rows']}_n_columns_{param['n_columns']}",
+)
+@pytest.mark.benchmark(group="test_benchmark_foreach_has_no_outliers")
+def test_benchmark_foreach_has_no_outliers(benchmark, ws, generated_integer_df):
+    columns, df, n_rows = generated_integer_df
+    dq_engine = DQEngine(workspace_client=ws, extra_params=EXTRA_PARAMS)
+    checks = [
+        *DQForEachColRule(
+            criticality="error",
+            check_func=check_funcs.has_no_outliers,
+            columns=columns,
+        ).get_rules()
+    ]
+    benchmark.group += f"_{n_rows}_rows_{len(columns)}_columns"
+    result = benchmark(lambda: dq_engine.apply_checks(df, checks).count())
+    assert result == EXPECTED_ROWS
+
+
 @pytest.mark.parametrize(
     "column",
     [
diff --git a/tests/resources/all_dataset_checks.yaml b/tests/resources/all_dataset_checks.yaml
index 4d5e523ef..39402900b 100644
--- a/tests/resources/all_dataset_checks.yaml +++ b/tests/resources/all_dataset_checks.yaml @@ -216,3 +216,10 @@ function: has_valid_schema arguments: expected_schema: "col1 STRING, col2 INT, col3 INT" + +# has_no_outliers check +- criticality: error + check: + function: has_no_outliers + arguments: + column: col2 \ No newline at end of file diff --git a/tests/unit/test_build_rules.py b/tests/unit/test_build_rules.py index 00c4f3fe0..a34e26dfd 100644 --- a/tests/unit/test_build_rules.py +++ b/tests/unit/test_build_rules.py @@ -16,6 +16,7 @@ is_in_list, is_not_null_and_not_empty_array, is_not_null_and_is_in_list, + has_no_outliers, is_unique, is_aggr_not_greater_than, is_aggr_not_less_than, @@ -231,6 +232,7 @@ def test_build_rules(): column="i", check_func_kwargs={"column": "i_as_kwargs"}, ), + DQDatasetRule(criticality="warn", check_func=has_no_outliers, column="c"), DQDatasetRule(criticality="warn", check_func=is_unique, columns=["g"]), DQDatasetRule(criticality="warn", check_func=is_unique, check_func_kwargs={"columns": ["g_as_kwargs"]}), # columns field should be used instead of columns kwargs @@ -426,6 +428,7 @@ def test_build_rules(): check_func_kwargs={"column": "i_as_kwargs"}, check_func_args=[[1, 2]], ), + DQDatasetRule(name="c_has_outliers", criticality="warn", check_func=has_no_outliers, column="c"), DQDatasetRule( name="g_is_not_unique", criticality="warn", @@ -548,6 +551,15 @@ def test_build_rules_by_metadata(): "criticality": "error", "check": {"function": "is_not_null", "for_each_column": ["a", "b"], "arguments": {}}, }, + { + "name": "c_has_no_outliers", + "criticality": "error", + "filter": "a=0", + "check": { + "function": "has_no_outliers", + "arguments": {"column": "c"}, + }, + }, { "criticality": "error", "check": { @@ -723,6 +735,13 @@ def test_build_rules_by_metadata(): check_func=is_not_null, column="b", ), + DQDatasetRule( + name="c_has_no_outliers", + criticality="error", + check_func=has_no_outliers, + filter="a=0", + column="c", + ), DQDatasetRule( name="struct_a_b_is_not_unique", criticality="error", @@ -1148,6 +1167,7 @@ def test_convert_dq_rules_to_metadata(): DQDatasetRule( criticality="error", check_func=is_unique, columns=["col1"], check_func_kwargs={"row_filter": "col2 > 0"} ), + DQDatasetRule(criticality="error", check_func=has_no_outliers, column="col2"), ] actual_metadata = serialize_checks(checks) @@ -1351,6 +1371,11 @@ def test_convert_dq_rules_to_metadata(): 'criticality': 'error', 'check': {'function': 'is_unique', 'arguments': {'columns': ['col1'], 'row_filter': 'col2 > 0'}}, }, + { + "name": "col2_has_outliers", + "criticality": "error", + "check": {"function": "has_no_outliers", "arguments": {"column": "col2"}}, + }, ] assert actual_metadata == expected_metadata
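For reviewers, a minimal usage sketch of the new `has_no_outliers` check (not part of the diff itself). It assumes an active `SparkSession` named `spark` (e.g., a Databricks notebook) with ambient workspace authentication, and the import paths for `DQEngine` and `DQDatasetRule` are inferred from this repository's layout, so they may need adjusting. The sample data mirrors the integer unit test above, so the values 9 and 20 should be flagged.

```python
from databricks.sdk import WorkspaceClient

from databricks.labs.dqx import check_funcs
from databricks.labs.dqx.engine import DQEngine
from databricks.labs.dqx.rule import DQDatasetRule

# Sample data: with median 13 and MAD of roughly 1, the values 9 and 20
# fall outside the median +/- 3.5 * MAD bounds for column "amount".
input_df = spark.createDataFrame(
    [(1, 10), (2, 12), (3, 14), (4, 13), (5, 11), (6, 20), (7, 9), (8, 15), (9, 14), (10, 13)],
    "id: int, amount: int",
)

checks = [
    DQDatasetRule(
        criticality="warn",
        check_func=check_funcs.has_no_outliers,
        column="amount",  # a column expression such as F.col("amount") also works
        filter="id > 0",  # optional row filter, applied before the MAD calculation
    )
]

dq_engine = DQEngine(workspace_client=WorkspaceClient())
checked_df = dq_engine.apply_checks(input_df, checks)  # outlier rows are reported as warnings
checked_df.show(truncate=False)
```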