
Commit 6feaf34

Dataset CTE: support all check types including custom sql (#2482)
* Dataset CTE: cleanup and run only in cloud test command
* Dataset CTE: support all check types except custom sql
* Dataset CTE: support all check types except custom sql
* Dataset CTE: support custom sql checks
* fix tests
* fix tests
* small remarks
* small remarks
* Review remarks
1 parent 58a1530 commit 6feaf34

27 files changed, +663 -104 lines changed

soda-athena/src/soda_athena/common/data_sources/athena_data_source.py

Lines changed: 1 addition & 6 deletions
@@ -33,7 +33,7 @@ def __init__(self, data_source_model: AthenaDataSourceModel, connection: Optiona
         super().__init__(data_source_model=data_source_model, connection=connection)

     def _create_sql_dialect(self) -> SqlDialect:
-        return AthenaSqlDialect(self)
+        return AthenaSqlDialect(data_source_impl=self)

     def _create_data_source_connection(self) -> DataSourceConnection:
         return AthenaDataSourceConnection(
@@ -138,11 +138,6 @@ class AthenaSqlDialect(SqlDialect):
         (SodaDataTypeName.TIME, SodaDataTypeName.VARCHAR),
     )

-    # We need to pass the data source impl to the dialect to be able to access connection properties (such as the staging dir)
-    def __init__(self, data_source_impl: AthenaDataSourceImpl):
-        super().__init__()
-        self.data_source_impl = data_source_impl
-
     def default_casify(self, identifier: str) -> str:
         return identifier.lower()

soda-bigquery/src/soda_bigquery/common/data_sources/bigquery_data_source.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def __init__(self, data_source_model: BigQueryDataSourceModel, connection: Optio
         self.cached_location = None

     def _create_sql_dialect(self) -> SqlDialect:
-        return BigQuerySqlDialect()
+        return BigQuerySqlDialect(data_source_impl=self)

     def _create_data_source_connection(self) -> DataSourceConnection:
         return BigQueryDataSourceConnection(

soda-core/setup.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
     "opentelemetry-exporter-otlp-proto-http>=1.16.0,<2.0.0",
     "tabulate[widechars]",
     "python-dotenv~=1.0",
+    "sqlglot",
 ]

 setup(

soda-core/src/soda_core/common/env_config_helper.py

Lines changed: 9 additions & 0 deletions
@@ -47,3 +47,12 @@ def soda_core_telemetry_local_debug_mode(self) -> bool:
     @property
     def soda_core_telemetry_test_mode(self) -> bool:
         return strtobool(os.getenv("SODA_CORE_TELEMETRY_TEST_MODE", "false"))
+
+    @property
+    def soda_instruction_id(self) -> str | None:
+        return os.getenv("SODA_INSTRUCTION_ID")
+
+    @property
+    def is_running_on_agent(self) -> bool:
+        # SODA_INSTRUCTION_ID is only set when running in Soda Agent
+        return self.soda_instruction_id is not None

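A minimal sketch of the new agent-detection behaviour. Only the SODA_INSTRUCTION_ID environment variable comes from the diff above; the standalone function and the example value are illustrative stand-ins for the new property.

import os

# Illustrative stand-in for the new is_running_on_agent property.
def is_running_on_agent() -> bool:
    # SODA_INSTRUCTION_ID is only set when running in Soda Agent
    return os.getenv("SODA_INSTRUCTION_ID") is not None

print(is_running_on_agent())                      # False while the variable is unset
os.environ["SODA_INSTRUCTION_ID"] = "example-id"  # hypothetical value
print(is_running_on_agent())                      # True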
soda-core/src/soda_core/common/sql_ast.py

Lines changed: 3 additions & 3 deletions
@@ -71,7 +71,7 @@ class FROM(BaseSqlExpression):
     table_name: str
     table_prefix: Optional[list[str]] = None
     alias: Optional[str] = None
-    sample_type: Optional[str] = None
+    sampler_type: Optional[str] = None
     sample_size: Optional[Number] = None

     def __post_init__(self):
@@ -85,8 +85,8 @@ def IN(self, table_prefix: str | list[str]) -> FROM:
         self.table_prefix = table_prefix if isinstance(table_prefix, list) else [table_prefix]
         return self

-    def SAMPLE(self, sample_type: str, sample_size: Number) -> FROM:
-        self.sample_type = sample_type
+    def SAMPLE(self, sampler_type: str, sample_size: Number) -> FROM:
+        self.sampler_type = sampler_type
         self.sample_size = sample_size
         return self

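A small usage sketch of the renamed builder, assuming FROM can be imported from soda_core.common.sql_ast as defined in this file; the table name, prefix, and sampler values are placeholders.

from soda_core.common.sql_ast import FROM

# Placeholder table, prefix, and sampler values; real callers pass the contract's
# sampler_type and sampler_limit (see invalidity_check.py below).
from_part = FROM("orders").IN(["analytics"]).SAMPLE(sampler_type="absolute_limit", sample_size=1000)
assert from_part.sampler_type == "absolute_limit"
assert from_part.sample_size == 1000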
soda-core/src/soda_core/common/sql_dialect.py

Lines changed: 29 additions & 5 deletions
@@ -5,7 +5,7 @@
 from datetime import date, datetime, time
 from numbers import Number
 from textwrap import indent
-from typing import Any, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Optional, Tuple

 from soda_core.common.data_source_results import QueryResult
 from soda_core.common.dataset_identifier import DatasetIdentifier
@@ -16,6 +16,7 @@
     SodaDataTypeName,
     SqlDataType,
 )
+from soda_core.common.soda_cloud_dto import SamplerType
 from soda_core.common.sql_ast import (
     ALTER_TABLE,
     ALTER_TABLE_ADD_COLUMN,
@@ -81,6 +82,10 @@
     SqlExpression,
     SqlExpressionStr,
 )
+from soda_core.common.sql_utils import apply_sampling_to_sql
+
+if TYPE_CHECKING:
+    from soda_core.common.data_source_impl import DataSourceImpl

 logger: logging.Logger = soda_logger

@@ -96,7 +101,12 @@ class SqlDialect:

     SODA_DATA_TYPE_SYNONYMS: tuple[tuple[SodaDataTypeName, ...]] = ()

-    def __init__(self):
+    def __init__(
+        self,
+        data_source_impl: DataSourceImpl,
+    ):
+        self.data_source_impl: DataSourceImpl = data_source_impl
+
         self._data_type_name_synonym_mappings: dict[str, str] = self._build_data_type_name_synonym_mappings(
             self._get_data_type_name_synonyms()
         )
@@ -727,8 +737,8 @@ def _build_from_part(self, from_part: FROM) -> str:
             )
         ]

-        if isinstance(from_part.sample_type, str) and isinstance(from_part.sample_size, Number):
-            from_parts.append(self._build_sample_sql(from_part.sample_type, from_part.sample_size))
+        if isinstance(from_part.sampler_type, str) and isinstance(from_part.sample_size, Number):
+            from_parts.append(self._build_sample_sql(from_part.sampler_type, from_part.sample_size))

         if isinstance(from_part.alias, str):
             from_parts.append(self._alias_format(from_part.alias))
@@ -976,7 +986,7 @@ def format_expr(e: SqlExpression) -> SqlExpression:
         string_to_hash = CONCAT_WS(separator="'||'", expressions=formatted_expressions)
         return self.build_expression_sql(STRING_HASH(string_to_hash))

-    def _build_sample_sql(self, sample_type: str, sample_size: Number) -> str:
+    def _build_sample_sql(self, sampler_type: str, sample_size: Number) -> str:
         raise NotImplementedError("Sampling not implemented for this dialect")

     def information_schema_namespace_elements(self, data_source_namespace: DataSourceNamespace) -> list[str]:
@@ -1191,6 +1201,20 @@ def get_sql_data_type_class(self) -> type:
     def supports_case_sensitive_column_names(self) -> bool:
         return True

+    def apply_sampling(
+        self,
+        sql: str,
+        sampler_limit: Number,
+        sampler_type: SamplerType,
+    ) -> str:
+        return apply_sampling_to_sql(
+            sql=sql,
+            sampler_limit=sampler_limit,
+            sampler_type=sampler_type,
+            read_dialect=self.data_source_impl.type_name,
+            write_dialect=self.data_source_impl.type_name,
+        )
+
     ########################################################
     # Metadata columns query
     ########################################################
soda-core/src/soda_core/common/sql_utils.py

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+from __future__ import annotations
+
+from numbers import Number
+
+import sqlglot
+from soda_core.common.soda_cloud_dto import SamplerType
+from sqlglot import exp
+
+
+def build_sample_clause(sampler_limit: Number, sampler_type: SamplerType) -> exp.TableSample:
+    if sampler_limit <= 0:
+        raise ValueError("sampler_limit must be positive")
+
+    size = exp.Literal.number(sampler_limit)
+    sample = exp.TableSample()
+
+    if sampler_type == SamplerType.ABSOLUTE_LIMIT:
+        sample.set("size", size)
+    else:
+        raise ValueError(f"Unsupported sample type: {sampler_type}")
+
+    return sample
+
+
+def attach_sample_to_relation(rel: exp.Expression, sampler_limit: Number, sampler_type: SamplerType) -> None:
+    """
+    Attach a TableSample clause to a relation (Table or Subquery),
+    unless it already has one.
+    """
+    if rel is None:
+        return
+
+    if rel.args.get("sample"):
+        return
+
+    if isinstance(rel, (exp.Table, exp.Subquery)):
+        rel.set("sample", build_sample_clause(sampler_limit, sampler_type))
+
+
+def apply_sampling_to_sql(
+    sql: str,
+    sampler_limit: Number,
+    sampler_type: SamplerType,
+    read_dialect: str | None = None,
+    write_dialect: str | None = None,
+) -> str:
+    """
+    Add TABLESAMPLE / SAMPLE to every table-like source in all FROM and JOIN clauses,
+    including inside CTEs and subqueries.
+
+    Exact rendering is dialect-specific.
+    """
+    tree = sqlglot.parse_one(sql, read=read_dialect) if read_dialect else sqlglot.parse_one(sql)
+
+    # FROM sources (top-level, CTE bodies, nested subqueries)
+    # Keep track of CTEs and skip them as they are already sampled at their definition
+    ctes = {cte.alias_or_name for cte in tree.find_all(exp.CTE)}
+    for from_ in tree.find_all(exp.From):
+        if isinstance(from_.this, exp.Table) and from_.this.alias_or_name in ctes:
+            continue
+
+        attach_sample_to_relation(from_.this, sampler_limit, sampler_type)
+
+    # JOIN targets
+    for join in tree.find_all(exp.Join):
+        attach_sample_to_relation(join.this, sampler_limit, sampler_type)
+
+    return tree.sql(dialect=write_dialect) if write_dialect else tree.sql()

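A rough usage sketch of the new helper. The input query and limit are made up, and the exact TABLESAMPLE / SAMPLE rendering is dialect-specific; the read/write dialect arguments are omitted here because SqlDialect.apply_sampling passes the data source's type_name for both.

from soda_core.common.soda_cloud_dto import SamplerType
from soda_core.common.sql_utils import apply_sampling_to_sql

custom_sql = """
WITH recent AS (SELECT * FROM orders WHERE order_date > CURRENT_DATE - 30)
SELECT recent.id FROM recent JOIN customers c ON recent.customer_id = c.id
"""

sampled_sql = apply_sampling_to_sql(
    sql=custom_sql,
    sampler_limit=1000,
    sampler_type=SamplerType.ABSOLUTE_LIMIT,
)

# The physical sources (orders, customers) each get a sample clause; the reference to
# the CTE "recent" in the outer FROM is skipped because the CTE body is already sampled.
print(sampled_sql)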
soda-core/src/soda_core/contracts/impl/check_types/failed_rows_check.py

Lines changed: 10 additions & 1 deletion
@@ -85,11 +85,20 @@ def setup_metrics(
         self.failed_rows_count_metric_impl = self._resolve_metric(
             FailedRowsQueryMetricImpl(contract_impl=contract_impl, column_impl=column_impl, check_impl=self)
         )
+
+        sql = self.failed_rows_check_yaml.query
+
+        if contract_impl.should_apply_sampling:
+            sql = contract_impl.data_source_impl.sql_dialect.apply_sampling(
+                sql=sql,
+                sampler_limit=contract_impl.sampler_limit,
+                sampler_type=contract_impl.sampler_type,
+            )
         if contract_impl.data_source_impl:
             failed_rows_count_query: Query = FailedRowsCountQuery(
                 data_source_impl=contract_impl.data_source_impl,
                 metrics=[self.failed_rows_count_metric_impl],
-                failed_rows_query=self.failed_rows_check_yaml.query,
+                failed_rows_query=sql,
             )
             self.queries.append(failed_rows_count_query)

soda-core/src/soda_core/contracts/impl/check_types/invalidity_check.py

Lines changed: 36 additions & 34 deletions
@@ -89,6 +89,9 @@ def setup_metrics(self, contract_impl: ContractImpl, column_impl: ColumnImpl, ch
         )
         # this is used in the check extension to extract failed keys and rows
         self.ref_query = InvalidReferenceCountQuery(
+            cte=contract_impl.cte,
+            sampler_type=contract_impl.sampler_type,
+            sampler_limit=contract_impl.sampler_limit,
             metric_impl=self.invalid_count_metric_impl,
             dataset_filter=self.contract_impl.filter,
             check_filter=self.check_yaml.filter,
@@ -191,62 +194,63 @@ class DatasetAlias(Enum):
 class InvalidReferenceCountQuery(Query):
     def __init__(
         self,
+        cte: CTE,
+        sampler_type: Optional[str],
+        sampler_limit: Optional[Number],
         metric_impl: InvalidReferenceCountMetricImpl,
         dataset_filter: Optional[str],
         check_filter: Optional[str],
         data_source_impl: Optional[DataSourceImpl],
     ):
         super().__init__(data_source_impl=data_source_impl, metrics=[metric_impl])
         self.metric_impl = metric_impl
-        self.dataset_filter = dataset_filter
         self.check_filter = check_filter

         self.referencing_alias: str = DatasetAlias.CONTRACT.value
         self.referenced_alias: str = DatasetAlias.REFERENCE.value
+        self._referenced_cte_name: str = "_soda_filtered_referenced_dataset"

-        sql_ast = self.build_query(SELECT(COUNT(STAR())))
+        self.sampler_type: Optional[str] = sampler_type
+        self.sampler_limit: Optional[Number] = sampler_limit
+
+        sql_ast = self.build_query(cte=cte)
         self.sql = self.data_source_impl.sql_dialect.build_select_sql(sql_ast)

-    def build_query(self, select_expression: SqlExpression) -> SqlExpression:
-        sql_ast: list = [select_expression]
-        sql_ast.extend(self.query_from())
+    def build_query(self, cte: CTE) -> list[SqlExpression]:
+        query = [
+            WITH([cte, self.referenced_cte()]),
+            SELECT(COUNT(STAR())),
+            FROM(cte.alias).AS(self.referencing_alias),
+            WHERE.optional(SqlExpressionStr.optional(self.check_filter)),
+        ]
+
+        query.extend(self.query_join())

-        if self.dataset_filter or self.check_filter:
-            dataset_filter_expr: Optional[SqlExpressionStr] = None
-            check_filter_expr: Optional[SqlExpressionStr] = None
-            combined_filter_expr: Optional[SqlExpression] = None
+        return query

-            if self.dataset_filter:
-                dataset_filter_expr = SqlExpressionStr(self.dataset_filter)
-                combined_filter_expr = dataset_filter_expr
+    def referenced_cte(self) -> CTE:
+        valid_reference_data: ValidReferenceData = self.metric_impl.missing_and_validity.valid_reference_data
+        referenced_dataset_name: str = valid_reference_data.dataset_name
+        referenced_dataset_prefix: Optional[list[str]] = valid_reference_data.dataset_prefix

-            if self.check_filter:
-                check_filter_expr = SqlExpressionStr(self.check_filter)
-                combined_filter_expr = check_filter_expr
+        cte = CTE(self._referenced_cte_name).AS(
+            [
+                SELECT(STAR()),
+                FROM(referenced_dataset_name).IN(referenced_dataset_prefix),
+            ]
+        )

-            if dataset_filter_expr and check_filter_expr:
-                combined_filter_expr = AND([dataset_filter_expr, check_filter_expr])
+        if self.sampler_type and self.sampler_limit:
+            cte.cte_query[1] = cte.cte_query[1].SAMPLE(self.sampler_type, self.sampler_limit)

-            original_from = sql_ast[1].AS(None)
-            sql_ast[1] = FROM("filtered_dataset").AS(self.referencing_alias)
-            sql_ast = [
-                WITH([CTE("filtered_dataset").AS([SELECT(STAR()), original_from, WHERE(combined_filter_expr)])]),
-            ] + sql_ast
-            return sql_ast
+        return cte

-    def query_from(self) -> SqlExpression:
+    def query_join(self) -> SqlExpression:
         valid_reference_data: ValidReferenceData = self.metric_impl.missing_and_validity.valid_reference_data

-        referencing_dataset_name: str = self.metric_impl.contract_impl.dataset_name
-        referencing_dataset_prefix: Optional[str] = self.metric_impl.contract_impl.dataset_prefix
         referencing_column_name: str = self.metric_impl.column_impl.column_yaml.name

-        referenced_dataset_name: str = valid_reference_data.dataset_name
-        referenced_dataset_prefix: Optional[list[str]] = (
-            valid_reference_data.dataset_prefix
-            if valid_reference_data.dataset_prefix is not None
-            else self.metric_impl.contract_impl.dataset_prefix
-        )
+        referenced_dataset_name: str = self._referenced_cte_name
         referenced_column: str = valid_reference_data.column

         # The variant to get the failed rows is:
@@ -265,9 +269,7 @@ def query_from(self) -> SqlExpression:
         )

         return [
-            FROM(referencing_dataset_name).IN(referencing_dataset_prefix).AS(self.referencing_alias),
             LEFT_INNER_JOIN(referenced_dataset_name)
-            .IN(referenced_dataset_prefix)
             .ON(
                 EQ(
                     COLUMN(referencing_column_name).IN(self.referencing_alias),

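The heart of the rewrite above: the referenced dataset now lives in its own CTE, sampling is attached once at the CTE's definition, and the join targets the CTE name instead of the raw table. A minimal sketch of that pattern, assuming the builders are importable from soda_core.common.sql_ast and using placeholder dataset and sampler values.

from soda_core.common.sql_ast import CTE, FROM, SELECT, STAR

# Placeholder referenced dataset; the real values come from valid_reference_data.
referenced_cte = CTE("_soda_filtered_referenced_dataset").AS(
    [
        SELECT(STAR()),
        FROM("dim_customer").IN(["analytics"]),
    ]
)

# When a sampler is configured, the SAMPLE clause is attached to the FROM inside the
# CTE, as referenced_cte() does above ("absolute_limit" and 1000 stand in for
# contract_impl.sampler_type and contract_impl.sampler_limit).
referenced_cte.cte_query[1] = referenced_cte.cte_query[1].SAMPLE("absolute_limit", 1000)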
soda-core/src/soda_core/contracts/impl/check_types/metric_check.py

Lines changed: 8 additions & 1 deletion
@@ -78,14 +78,21 @@ def setup_metrics(
             )

         elif self.metric_check_yaml.query:
+            sql = self.metric_check_yaml.query
+
+            if contract_impl.should_apply_sampling:
+                sql = contract_impl.data_source_impl.sql_dialect.apply_sampling(
+                    sql, contract_impl.sampler_limit, sampler_type=contract_impl.sampler_type
+                )
+
             self.numeric_metric_impl = self._resolve_metric(
                 MetricQueryMetricImpl(contract_impl=contract_impl, column_impl=column_impl, check_impl=self)
             )
             if contract_impl.data_source_impl:
                 metric_query: Query = MetricQuery(
                     data_source_impl=contract_impl.data_source_impl,
                     metrics=[self.numeric_metric_impl],
-                    sql=self.metric_check_yaml.query,
+                    sql=sql,
                 )
                 self.queries.append(metric_query)
