Skip to content

Commit a8cb8aa

Browse files
committed
refactor: support sql_predicate when compiling readtable in sqlglot
1 parent 7efdda8 commit a8cb8aa

File tree

4 files changed

+33
-0
lines changed

4 files changed

+33
-0
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def compile_readtable(node: nodes.ReadTableNode, child: ir.SQLGlotIR):
180180
col_names=[col.source_id for col in node.scan_list.items],
181181
alias_names=[col.id.sql for col in node.scan_list.items],
182182
uid_gen=child.uid_gen,
183+
sql_predicate=node.source.sql_predicate,
183184
system_time=node.source.at_time,
184185
)
185186

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def from_table(
120120
col_names: typing.Sequence[str],
121121
alias_names: typing.Sequence[str],
122122
uid_gen: guid.SequentialUIDGenerator,
123+
sql_predicate: typing.Optional[str] = None,
123124
system_time: typing.Optional[datetime.datetime] = None,
124125
) -> SQLGlotIR:
125126
"""Builds a SQLGlotIR expression from a BigQuery table.
@@ -131,6 +132,7 @@ def from_table(
131132
col_names (typing.Sequence[str]): The names of the columns to select.
132133
alias_names (typing.Sequence[str]): The aliases for the selected columns.
133134
uid_gen (guid.SequentialUIDGenerator): A generator for unique identifiers.
135+
sql_predicate (typing.Optional[str]): An optional SQL predicate for filtering.
134136
system_time (typing.Optional[datetime.datetime]): An optional system time for time-travel queries.
135137
"""
136138
selections = [
@@ -158,6 +160,10 @@ def from_table(
158160
version=version,
159161
)
160162
select_expr = sge.Select().select(*selections).from_(table_expr)
163+
if sql_predicate:
164+
select_expr = select_expr.where(
165+
sg.parse_one(sql_predicate, dialect="bigquery"), append=False
166+
)
161167
return cls(expr=select_expr, uid_gen=uid_gen)
162168

163169
@classmethod
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col`,
4+
`rowindex`,
5+
`string_col`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
WHERE
8+
`rowindex` > 0 AND `string_col` IN ('Hello, World!')
9+
)
10+
SELECT
11+
`rowindex`,
12+
`int64_col`,
13+
`string_col`
14+
FROM `bfcte_0`

tests/unit/core/compile/sqlglot/test_compile_readtable.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,15 @@ def test_compile_readtable_w_system_time(
6767
)
6868
bf_df = compiler_session.read_gbq_table(str(table_ref))
6969
snapshot.assert_match(bf_df.sql, "out.sql")
70+
71+
72+
def test_compile_readtable_w_columns_filters(compiler_session, snapshot):
73+
columns = ["rowindex", "int64_col", "string_col"]
74+
filters = [("rowindex", ">", 0), ("string_col", "in", ["Hello, World!"])]
75+
bf_df = compiler_session._loader.read_gbq_table(
76+
"bigframes-dev.sqlglot_test.scalar_types",
77+
enable_snapshot=False,
78+
columns=columns,
79+
filters=filters,
80+
)
81+
snapshot.assert_match(bf_df.sql, "out.sql")

0 commit comments

Comments
 (0)