129 changes: 129 additions & 0 deletions src/databricks/labs/dqx/check_funcs.py
@@ -938,6 +938,85 @@ def is_data_fresh(
)


@register_rule("dataset")
def has_no_outliers(column: str | Column, row_filter: str | None = None) -> tuple[Column, Callable]:
"""
Build an outlier check condition and closure for dataset-level validation.

This function uses the MAD (Median Absolute Deviation) statistical method to check whether
the specified column's values fall within calculated limits. The upper limit is median + 3.5 * MAD
and the lower limit is median - 3.5 * MAD; values outside these limits are considered outliers.
The MAD formula is defined in the helper method _calculate_median_absolute_deviation.

Args:
column: column to check; can be a string column name or a column expression
row_filter: Optional SQL expression for filtering rows before checking for outliers.


Returns:
A tuple of:
- A Spark Column representing the condition for outlier violations.
- A closure that applies the outliers check and adds the necessary condition/count columns.
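
Example:
Illustrative sketch (assumes a SparkSession named `spark`; the "values" column and variable names are arbitrary):
>>> df = spark.createDataFrame([(10,), (12,), (14,), (13,), (11,), (100,)], ["values"])
>>> condition, apply_fn = has_no_outliers("values")
>>> checked_df = apply_fn(df).select("values", condition)
>>> # rows flagged as outliers carry a non-null message in the 'values_has_outliers' column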
"""
column = F.col(column) if isinstance(column, str) else column

col_str_norm, col_expr_str, col_expr = _get_normalized_column_and_expr(column)

unique_str = uuid.uuid4().hex # make sure any column added to the dataframe is unique
condition_col = f"__condition_{col_str_norm}_{unique_str}"

def apply(df: DataFrame) -> DataFrame:
"""
Apply the outlier detection logic to the DataFrame.

Adds median, MAD, and outlier-condition columns for the checked column.

Args:
df: The input DataFrame to validate for outliers.

Returns:
The DataFrame with additional median, MAD, and condition columns for outlier detection.
"""
column_type = df.schema[col_expr_str].dataType
if not isinstance(column_type, types.NumericType):
raise InvalidParameterError(
f"Column '{col_expr_str}' must be of numeric type to perform outlier detection using MAD method, "
f"but got type '{column_type.simpleString()}' instead."
)
filter_expr = F.expr(row_filter) if row_filter else F.lit(True)
filtered_df = df.filter(filter_expr)
mad, median = _calculate_median_absolute_deviation(filtered_df, col_expr_str)

# build the outlier condition
lower_bound = median - (3.5 * mad)
upper_bound = median + (3.5 * mad)
lower_bound_expr = _get_limit_expr(lower_bound)
upper_bound_expr = _get_limit_expr(upper_bound)

condition = (col_expr < lower_bound_expr) | (col_expr > upper_bound_expr)

# Add outlier detection columns
result_df = (
filtered_df.withColumn('median_value', F.lit(median))
.withColumn('mad_value', F.lit(mad))
.withColumn(condition_col, F.when(condition, True).otherwise(False))
)

return result_df

condition = make_condition(
condition=F.col(condition_col),
message=F.concat_ws(
"",
F.lit("Value '"),
col_expr.cast("string"),
F.lit(f"' in Column '{col_expr_str}' is considered an outlier, according to the MAD statistical method."),
),
alias=f"{col_str_norm}_has_outliers",
)
return condition, apply


@register_rule("dataset")
def is_unique(
columns: list[str | Column],
@@ -2705,3 +2784,53 @@ def _validate_sql_query_params(query: str, merge_columns: list[str]) -> None:
raise UnsafeSqlQueryError(
"Provided SQL query is not safe for execution. Please ensure it does not contain any unsafe operations."
)


def _calculate_median_absolute_deviation(df: DataFrame, column: str) -> tuple[float, float]:
"""
Calculate the Median Absolute Deviation (MAD) and the median of a numeric column.

The MAD is a robust measure of variability based on the median, calculated as:
MAD = median(|X_i - median(X)|)

This is useful for outlier detection as it is more robust to outliers than
standard deviation.
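For example (illustrative): for the values [10, 11, 12, 13, 14, 20], the median is 12.5 and the
absolute deviations are [2.5, 1.5, 0.5, 0.5, 1.5, 7.5], so MAD = 1.5. Note that the implementation
uses percentile_approx, so the computed median and MAD may be approximate rather than exact.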

Args:
df: PySpark DataFrame
column: Name of the numeric column to calculate MAD for

Returns:
A tuple of (MAD, median) for the column, both as floats.

Raises:
ValueError: If the column doesn't exist or contains no non-null values

Example:
>>> df = spark.createDataFrame([(1.0,), (2.0,), (3.0,), (4.0,), (5.0,)], ["values"])
>>> mad, median = _calculate_median_absolute_deviation(df, "values")
>>> print(f"MAD: {mad}, median: {median}")
"""
# Validate column exists
if column not in df.columns:
raise ValueError(f"Column '{column}' does not exist in DataFrame")

# Remove null values
df_clean = df.filter(F.col(column).isNotNull())

if df_clean.count() == 0:
raise ValueError(f"Column '{column}' contains no numeric values")

# Step 1: Calculate the median of the column
median_value = df_clean.agg(F.percentile_approx(column, 0.5)).collect()[0][0]

if median_value is None:
raise ValueError(f"Could not calculate median for column '{column}'")

# Step 2: Calculate absolute deviations from the median
df_with_deviations = df_clean.select(F.abs(F.col(column) - F.lit(median_value)).alias("absolute_deviation"))

# Step 3: Calculate the median of absolute deviations
mad = df_with_deviations.agg(F.percentile_approx("absolute_deviation", 0.5)).collect()[0][0]

return float(mad) if mad is not None else 0.0, float(median_value) if median_value is not None else 0.0
9 changes: 9 additions & 0 deletions tests/integration/test_build_rules.py
@@ -79,6 +79,15 @@ def test_build_quality_rules_from_dataframe(spark):
"arguments": {"column": "test_col", "group_by": ["a"], "limit": 0, "aggr_type": "count"},
},
},
{
"name": "column_has_no_outliers",
"criticality": "error",
"filter": "test_col > 0",
"check": {
"function": "has_no_outliers",
"arguments": {"column": "c"},
},
},
]

df = deserialize_checks_to_dataframe(spark, test_checks)
62 changes: 62 additions & 0 deletions tests/integration/test_dataset_checks.py
@@ -16,6 +16,7 @@
is_aggr_not_less_than,
is_aggr_equal,
is_aggr_not_equal,
has_no_outliers,
foreign_key,
compare_datasets,
is_data_fresh_per_time_window,
@@ -30,6 +31,67 @@
SCHEMA = "a: string, b: int"


def test_has_no_outliers(spark: SparkSession):
test_df = spark.createDataFrame(
[
[1, 10],
[2, 12],
[3, 14],
[4, 13],
[5, 11],
[6, 20], # outlier
[7, 9],
[8, 15],
[9, 14],
[10, 13],
],
"a: int, b: int",
)

condition, apply_method = has_no_outliers("b")
actual_apply_df = apply_method(test_df)
actual_condition_df = actual_apply_df.select("a", "b", condition)

expected_condition_df = spark.createDataFrame(
[
[1, 10, None],
[2, 12, None],
[3, 14, None],
[4, 13, None],
[5, 11, None],
[6, 20, "Value '20' in Column 'b' is considered an outlier, according to the MAD statistical method."],
[7, 9, "Value '9' in Column 'b' is considered an outlier, according to the MAD statistical method."],
[8, 15, None],
[9, 14, None],
[10, 13, None],
],
"a: int, b: int, b_has_outliers: string",
)
assert_df_equality(actual_condition_df, expected_condition_df, ignore_nullable=True, ignore_row_order=True)


def test_has_no_outliers_in_string_columns(spark: SparkSession):
test_df = spark.createDataFrame(
[
["str1", 10],
["str2", 12],
["str3", 14],
["str4", 13],
["str5", 11],
["str6", 20], # outlier
],
"a: string, b: int",
)
apply_method = has_no_outliers("a")[1]

with pytest.raises(
InvalidParameterError,
match="Column 'a' must be of numeric type to perform outlier detection using MAD method, "
"but got type 'string' instead.",
):
apply_method(test_df)


def test_is_unique(spark: SparkSession):
test_df = spark.createDataFrame(
[
25 changes: 25 additions & 0 deletions tests/unit/test_build_rules.py
@@ -16,6 +16,7 @@
is_in_list,
is_not_null_and_not_empty_array,
is_not_null_and_is_in_list,
has_no_outliers,
is_unique,
is_aggr_not_greater_than,
is_aggr_not_less_than,
@@ -231,6 +232,7 @@ def test_build_rules():
column="i",
check_func_kwargs={"column": "i_as_kwargs"},
),
DQDatasetRule(criticality="warn", check_func=has_no_outliers, column="c"),
DQDatasetRule(criticality="warn", check_func=is_unique, columns=["g"]),
DQDatasetRule(criticality="warn", check_func=is_unique, check_func_kwargs={"columns": ["g_as_kwargs"]}),
# columns field should be used instead of columns kwargs
@@ -426,6 +428,7 @@ def test_build_rules():
check_func_kwargs={"column": "i_as_kwargs"},
check_func_args=[[1, 2]],
),
DQDatasetRule(name="c_has_outliers", criticality="warn", check_func=has_no_outliers, column="c"),
DQDatasetRule(
name="g_is_not_unique",
criticality="warn",
@@ -548,6 +551,15 @@ def test_build_rules_by_metadata():
"criticality": "error",
"check": {"function": "is_not_null", "for_each_column": ["a", "b"], "arguments": {}},
},
{
"name": "c_has_no_outliers",
"criticality": "error",
"filter": "a=0",
"check": {
"function": "has_no_outliers",
"arguments": {"column": "c"},
},
},
{
"criticality": "error",
"check": {
@@ -723,6 +735,13 @@ def test_build_rules_by_metadata():
check_func=is_not_null,
column="b",
),
DQDatasetRule(
name="c_has_no_outliers",
criticality="error",
check_func=has_no_outliers,
filter="a=0",
column="c",
),
DQDatasetRule(
name="struct_a_b_is_not_unique",
criticality="error",
@@ -1148,6 +1167,7 @@ def test_convert_dq_rules_to_metadata():
DQDatasetRule(
criticality="error", check_func=is_unique, columns=["col1"], check_func_kwargs={"row_filter": "col2 > 0"}
),
DQDatasetRule(criticality="error", check_func=has_no_outliers, column="col2"),
]
actual_metadata = serialize_checks(checks)

@@ -1351,6 +1371,11 @@ def test_convert_dq_rules_to_metadata():
'criticality': 'error',
'check': {'function': 'is_unique', 'arguments': {'columns': ['col1'], 'row_filter': 'col2 > 0'}},
},
{
"name": "col2_has_outliers",
"criticality": "error",
"check": {"function": "has_no_outliers", "arguments": {"column": "col2"}},
},
]

assert actual_metadata == expected_metadata