35 changes: 25 additions & 10 deletions src/databricks/labs/dqx/profiler/profiler.py
@@ -5,6 +5,7 @@
import os
from concurrent import futures
from dataclasses import dataclass
from datetime import timezone
from decimal import Decimal, Context
from difflib import SequenceMatcher
from fnmatch import fnmatch
@@ -576,13 +577,13 @@ def _get_min_max(
f"stddev={stddev}, min={metrics.get('min')}"
)
# we need to preserve type at the end
min_limit, max_limit = self._adjust_min_max_limits(min_limit, max_limit, avg, typ, metrics)
min_limit, max_limit = self._adjust_min_max_limits(min_limit, max_limit, avg, typ, metrics, opts)
else:
logger.info(f"Can't get min/max for field {col_name}")
return descr, max_limit, min_limit

def _adjust_min_max_limits(
self, min_limit: Any, max_limit: Any, avg: Any, typ: T.DataType, metrics: dict[str, Any]
self, min_limit: Any, max_limit: Any, avg: Any, typ: T.DataType, metrics: dict[str, Any], opts: dict[str, Any]
) -> tuple[Any, Any]:
"""
Adjusts the minimum and maximum limits based on the data type of the column.
@@ -592,23 +593,26 @@ def _adjust_min_max_limits(
:param avg: The average value of the column.
:param typ: The PySpark data type of the column.
:param metrics: A dictionary containing the calculated metrics.
:param opts: A dictionary of options for min/max limit adjustment.
:return: A tuple containing the adjusted minimum and maximum limits.
"""
if isinstance(typ, T.IntegralType):
min_limit = int(self._round_value(min_limit, "down", {"round": True}))
max_limit = int(self._round_value(max_limit, "up", {"round": True}))
min_limit = int(self._round_value(min_limit, "down", opts))
max_limit = int(self._round_value(max_limit, "up", opts))
elif typ == T.DateType():
min_limit = datetime.date.fromtimestamp(int(min_limit))
max_limit = datetime.date.fromtimestamp(int(max_limit))
metrics["min"] = datetime.date.fromtimestamp(int(metrics["min"]))
metrics["max"] = datetime.date.fromtimestamp(int(metrics["max"]))
metrics["mean"] = datetime.date.fromtimestamp(int(avg))
elif typ == T.TimestampType():
min_limit = self._round_value(datetime.datetime.fromtimestamp(int(min_limit)), "down", {"round": True})
max_limit = self._round_value(datetime.datetime.fromtimestamp(int(max_limit)), "up", {"round": True})
metrics["min"] = datetime.datetime.fromtimestamp(int(metrics["min"]))
metrics["max"] = datetime.datetime.fromtimestamp(int(metrics["max"]))
metrics["mean"] = datetime.datetime.fromtimestamp(int(avg))
min_limit = self._round_value(
datetime.datetime.fromtimestamp(int(min_limit), tz=timezone.utc), "down", opts
)
max_limit = self._round_value(datetime.datetime.fromtimestamp(int(max_limit), tz=timezone.utc), "up", opts)
metrics["min"] = datetime.datetime.fromtimestamp(int(metrics["min"]), tz=timezone.utc)
metrics["max"] = datetime.datetime.fromtimestamp(int(metrics["max"]), tz=timezone.utc)
metrics["mean"] = datetime.datetime.fromtimestamp(int(avg), tz=timezone.utc)
return min_limit, max_limit

@staticmethod
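The hunk above threads the caller's `opts` into `_adjust_min_max_limits` instead of the previously hard-coded `{"round": True}`, so the profiler's `round` option actually controls rounding, and epoch seconds are now converted with `tz=timezone.utc`, yielding timezone-aware UTC datetimes rather than naive local-time values. A minimal standalone sketch of the `fromtimestamp` difference (the epoch value is an arbitrary example, not taken from this diff):

from datetime import datetime, timezone

ts = 1673085611  # 2023-01-07T10:00:11 UTC, arbitrary example value

naive_local = datetime.fromtimestamp(ts)                 # naive; interpreted in the host's local timezone
aware_utc = datetime.fromtimestamp(ts, tz=timezone.utc)  # timezone-aware; always UTC

print(naive_local.tzinfo)  # None -> result depends on where the code runs
print(aware_utc)           # 2023-01-07 10:00:11+00:00, reproducible everywhere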
@@ -674,10 +678,21 @@ def _type_supports_min_max(typ: T.DataType) -> bool:

@staticmethod
def _round_datetime(value: datetime.datetime, direction: str) -> datetime.datetime:
"""
Rounds a datetime value to midnight based on the specified direction.

:param value: The datetime value to round.
:param direction: The rounding direction ("up" or "down").
:return: The rounded datetime value.
"""
if direction == "down":
return value.replace(hour=0, minute=0, second=0, microsecond=0)
if direction == "up":
return value.replace(hour=0, minute=0, second=0, microsecond=0) + datetime.timedelta(days=1)
try:
return value.replace(hour=0, minute=0, second=0, microsecond=0) + datetime.timedelta(days=1)
except OverflowError:
logger.warning("Rounding datetime up caused overflow; returning datetime.max instead.")
return datetime.datetime.max
return value

@staticmethod
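The new `OverflowError` guard in `_round_datetime` matters for sentinel dates at the end of the calendar: rounding a timestamp on 9999-12-31 up to the next midnight would land in year 10000, which `datetime` cannot represent, so the method now logs a warning and returns `datetime.max` instead. A small standalone illustration of that failure mode (not the profiler code itself):

import datetime

value = datetime.datetime(9999, 12, 31, 10, 0, 11)
floored = value.replace(hour=0, minute=0, second=0, microsecond=0)

try:
    rounded_up = floored + datetime.timedelta(days=1)  # year 10000 does not exist -> OverflowError
except OverflowError:
    rounded_up = datetime.datetime.max  # same fallback the guarded code path uses

print(rounded_up)  # 9999-12-31 23:59:59.999999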
195 changes: 193 additions & 2 deletions tests/integration/test_profiler.py
@@ -1,4 +1,4 @@
from datetime import date, datetime
from datetime import date, datetime, timezone
from decimal import Decimal

import pytest
@@ -84,7 +84,10 @@ def test_profiler(spark, ws):
name="min_max",
column="s1.ns1",
description="Real min/max values were used",
parameters={"min": datetime(2023, 1, 6, 0, 0), "max": datetime(2023, 1, 9, 0, 0)},
parameters={
"min": datetime(2023, 1, 6, 0, 0, tzinfo=timezone.utc),
"max": datetime(2023, 1, 9, 0, 0, tzinfo=timezone.utc),
},
),
DQProfile(name="is_not_null", column="s1.s2.ns2", description=None, parameters=None),
DQProfile(name="is_not_null", column="s1.s2.ns3", description=None, parameters=None),
@@ -195,6 +198,194 @@ def test_profiler_non_default_profile_options(spark, ws):
assert rules == expected_rules


def test_profiler_non_default_profile_options_remove_outliers_no_outlier_columns(spark, ws):
inp_schema = T.StructType(
[
T.StructField("t1", T.IntegerType()),
T.StructField("t2", T.StringType()),
T.StructField(
"s1",
T.StructType(
[
T.StructField("ns1", T.TimestampType()),
T.StructField(
"s2",
T.StructType([T.StructField("ns2", T.StringType()), T.StructField("ns3", T.DateType())]),
),
]
),
),
]
)
inp_df = spark.createDataFrame(
[
[
1,
" test ",
{
"ns1": datetime.fromisoformat("9999-12-31T10:00:11+00:00"),
"s2": {"ns2": "test", "ns3": date.fromisoformat("9999-12-31")},
},
],
[
2,
" ",
{
"ns1": datetime.fromisoformat("2023-01-07T10:00:11+00:00"),
"s2": {"ns2": "test2", "ns3": date.fromisoformat("2023-01-07")},
},
],
[
3,
None,
{
"ns1": datetime.fromisoformat("2023-01-06T10:00:11+00:00"),
"s2": {"ns2": "test", "ns3": date.fromisoformat("2023-01-06")},
},
],
],
schema=inp_schema,
)

profiler = DQProfiler(ws)

profile_options = {
"round": False, # do not round the min/max values
"max_in_count": 1, # generate is_in if we have less than 1 percent of distinct values
"distinct_ratio": 0.01, # generate is_in if we have less than 1 percent of distinct values
"remove_outliers": True, # remove outliers
"num_sigmas": 1, # number of sigmas to use when remove_outliers is True
"trim_strings": False, # trim whitespace from strings
"max_empty_ratio": 0.01, # generate is_not_null_or_empty rule if we have less than 1 percent of empty strings
"sample_fraction": 1.0, # fraction of data to sample
"sample_seed": None, # seed for sampling
"limit": 1000, # limit the number of samples
}

stats, rules = profiler.profile(inp_df, columns=inp_df.columns, options=profile_options)

expected_rules = [
DQProfile(name="is_not_null", column="t1", description=None, parameters=None),
DQProfile(
name="min_max", column="t1", description="Real min/max values were used", parameters={"min": 1, "max": 3}
),
DQProfile(name='is_not_null_or_empty', column='t2', description=None, parameters={'trim_strings': False}),
DQProfile(name="is_not_null", column="s1.ns1", description=None, parameters=None),
DQProfile(
name="min_max",
column="s1.ns1",
description="Real min/max values were used",
parameters={
'max': datetime(9999, 12, 31, 10, 0, 11, tzinfo=timezone.utc),
'min': datetime(2023, 1, 6, 10, 0, 11, tzinfo=timezone.utc),
},
),
DQProfile(name="is_not_null", column="s1.s2.ns2", description=None, parameters=None),
DQProfile(name="is_not_null", column="s1.s2.ns3", description=None, parameters=None),
DQProfile(
name="min_max",
column="s1.s2.ns3",
description="Real min/max values were used",
parameters={"min": date(2023, 1, 6), "max": date(9999, 12, 31)},
),
]
assert len(stats.keys()) > 0
assert rules == expected_rules
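The `remove_outliers` and `num_sigmas` options exercised above, together with the `stddev` value logged in `_get_min_max`, point at sigma-based clipping of the min/max range. The exact computation is not shown in this diff; the sketch below only illustrates the usual mean ± num_sigmas * stddev reading of such an option and is not DQProfiler's implementation:

import statistics

values = [1, 2, 3, 4, 100]          # 100 is an obvious outlier
num_sigmas = 1                      # mirrors the num_sigmas option used above

mean = statistics.mean(values)      # 22.0
stddev = statistics.pstdev(values)  # population standard deviation

# Hypothetical sigma-based bounds; the profiler's actual formula may differ.
min_limit = mean - num_sigmas * stddev
max_limit = mean + num_sigmas * stddev
print(min_limit, max_limit)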


def test_profiler_non_default_profile_options_with_rounding_enabled(spark, ws):
inp_schema = T.StructType(
[
T.StructField("t1", T.IntegerType()),
T.StructField("t2", T.StringType()),
T.StructField(
"s1",
T.StructType(
[
T.StructField("ns1", T.TimestampType()),
T.StructField(
"s2",
T.StructType([T.StructField("ns2", T.StringType()), T.StructField("ns3", T.DateType())]),
),
]
),
),
]
)
inp_df = spark.createDataFrame(
[
[
1,
" test ",
{
"ns1": datetime.fromisoformat("9999-12-31T10:00:11+00:00"),
"s2": {"ns2": "test", "ns3": date.fromisoformat("9999-12-31")},
},
],
[
2,
" ",
{
"ns1": datetime.fromisoformat("2023-01-07T10:00:11+00:00"),
"s2": {"ns2": "test2", "ns3": date.fromisoformat("2023-01-07")},
},
],
[
3,
None,
{
"ns1": datetime.fromisoformat("2023-01-06T10:00:11+00:00"),
"s2": {"ns2": "test", "ns3": date.fromisoformat("2023-01-06")},
},
],
],
schema=inp_schema,
)

profiler = DQProfiler(ws)

profile_options = {
"round": True, # round the min/max values
"max_in_count": 1, # generate is_in if we have less than 1 percent of distinct values
"distinct_ratio": 0.01, # generate is_in if we have less than 1 percent of distinct values
"remove_outliers": False, # do not remove outliers
"outlier_columns": ["t1", "s1"], # remove outliers in all columns of appropriate type
"num_sigmas": 1, # number of sigmas to use when remove_outliers is True
"trim_strings": False, # trim whitespace from strings
"max_empty_ratio": 0.01, # generate is_not_null_or_empty rule if we have less than 1 percent of empty strings
"sample_fraction": 1.0, # fraction of data to sample
"sample_seed": None, # seed for sampling
"limit": 1000, # limit the number of samples
}

stats, rules = profiler.profile(inp_df, columns=inp_df.columns, options=profile_options)

expected_rules = [
DQProfile(name="is_not_null", column="t1", description=None, parameters=None),
DQProfile(
name="min_max", column="t1", description="Real min/max values were used", parameters={"min": 1, "max": 3}
),
DQProfile(name='is_not_null_or_empty', column='t2', description=None, parameters={'trim_strings': False}),
DQProfile(name="is_not_null", column="s1.ns1", description=None, parameters=None),
DQProfile(
name="min_max",
column="s1.ns1",
description="Real min/max values were used",
parameters={'max': datetime.max, 'min': datetime(2023, 1, 6)},
),
DQProfile(name="is_not_null", column="s1.s2.ns2", description=None, parameters=None),
DQProfile(name="is_not_null", column="s1.s2.ns3", description=None, parameters=None),
DQProfile(
name="min_max",
column="s1.s2.ns3",
description="Real min/max values were used",
parameters={"min": date(2023, 1, 6), "max": date(9999, 12, 31)},
),
]
assert len(stats.keys()) > 0
assert rules == expected_rules


def test_profiler_empty_df(spark, ws):
test_df = spark.createDataFrame([], "data: string")
