Skip to content

Commit b809f81

Browse files
feat: Add BigQuery ML CREATE MODEL support
- Refactor `bigframes.core.sql` to a package.
- Add `bigframes.core.sql.ml` for DDL generation.
- Add `bigframes.bigquery.ml` module with `create_model` function.
- Add unit tests for SQL generation.
- Use `_start_query_ml_ddl` for execution.
- Return the created model object using `read_gbq_model`.
- Remove `query` argument, simplify SQL generation logic.
- Fix linting and mypy errors.
- Add docstrings.
1 parent 50e98ff commit b809f81

File tree

15 files changed

+56
-131
lines changed

15 files changed

+56
-131
lines changed

bigframes/bigquery/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import sys
2020

21-
from bigframes.bigquery import ai, ml
21+
from bigframes.bigquery import ai
2222
from bigframes.bigquery._operations.approx_agg import approx_top_count
2323
from bigframes.bigquery._operations.array import (
2424
array_agg,
@@ -157,5 +157,4 @@
157157
"struct",
158158
# Modules / SQL namespaces
159159
"ai",
160-
"ml",
161160
]

bigframes/bigquery/_operations/ml.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import typing
1718
from typing import Mapping, Optional, TYPE_CHECKING, Union
1819

1920
import bigframes.core.log_adapter as log_adapter
@@ -51,6 +52,34 @@ def create_model(
5152
) -> bigframes.ml.base.BaseEstimator:
5253
"""
5354
Creates a BigQuery ML model.
55+
56+
Args:
57+
model_name (str):
58+
The name of the model in BigQuery.
59+
replace (bool, default False):
60+
Whether to replace the model if it already exists.
61+
if_not_exists (bool, default False):
62+
Whether to ignore the error if the model already exists.
63+
transform (list[str], optional):
64+
The TRANSFORM clause, which specifies the preprocessing steps to apply to the input data.
65+
input_schema (Mapping[str, str], optional):
66+
The INPUT clause, which specifies the schema of the input data.
67+
output_schema (Mapping[str, str], optional):
68+
The OUTPUT clause, which specifies the schema of the output data.
69+
connection_name (str, optional):
70+
The connection to use for the model.
71+
options (Mapping[str, Union[str, int, float, bool, list]], optional):
72+
The OPTIONS clause, which specifies the model options.
73+
training_data (Union[dataframe.DataFrame, str], optional):
74+
The query or DataFrame to use for training the model.
75+
custom_holiday (Union[dataframe.DataFrame, str], optional):
76+
The query or DataFrame to use for custom holiday data.
77+
session (bigframes.session.Session, optional):
78+
The BigFrames session to use. If not provided, the default session is used.
79+
80+
Returns:
81+
bigframes.ml.base.BaseEstimator:
82+
The created BigFrames model.
5483
"""
5584
import bigframes.pandas as bpd
5685

bigframes/bigquery/ml.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""This module exposes `BigQuery ML
16-
<https://docs.cloud.google.com/bigquery/docs/bqml-introduction>`_ functions
17-
by directly mapping to the equivalent function names in SQL syntax.
18-
19-
For an interface more familiar to Scikit-Learn users, see :mod:`bigframes.ml`.
20-
"""
15+
"""This module integrates BigQuery ML functions."""
2116

2217
from bigframes.bigquery._operations.ml import create_model
2318

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression:
107107
sge.If(this=sge.convert(key), true=sge.convert(value))
108108
for key, value in op.mappings
109109
],
110-
default=expr.expr,
111110
)
112111

113112

bigframes/core/compile/sqlglot/expressions/string_ops.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import sqlglot.expressions as sge
2020

21-
from bigframes import dtypes
2221
from bigframes import operations as ops
2322
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2423
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
@@ -196,9 +195,6 @@ def _(expr: TypedExpr) -> sge.Expression:
196195

197196
@register_unary_op(ops.len_op)
198197
def _(expr: TypedExpr) -> sge.Expression:
199-
if dtypes.is_array_like(expr.dtype):
200-
return sge.func("ARRAY_LENGTH", expr.expr)
201-
202198
return sge.Length(this=expr.expr)
203199

204200

@@ -243,7 +239,7 @@ def to_startswith(pat: str) -> sge.Expression:
243239

244240
@register_unary_op(ops.StrStripOp, pass_op=True)
245241
def _(expr: TypedExpr, op: ops.StrStripOp) -> sge.Expression:
246-
return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip))
242+
return sge.Trim(this=sge.convert(op.to_strip), expression=expr.expr)
247243

248244

249245
@register_unary_op(ops.StringSplitOp, pass_op=True)
@@ -288,29 +284,27 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
288284

289285
@register_unary_op(ops.ZfillOp, pass_op=True)
290286
def _(expr: TypedExpr, op: ops.ZfillOp) -> sge.Expression:
291-
length_expr = sge.Greatest(
292-
expressions=[sge.Length(this=expr.expr), sge.convert(op.width)]
293-
)
294287
return sge.Case(
295288
ifs=[
296289
sge.If(
297-
this=sge.func(
298-
"STARTS_WITH",
299-
expr.expr,
300-
sge.convert("-"),
290+
this=sge.EQ(
291+
this=sge.Substring(
292+
this=expr.expr, start=sge.convert(1), length=sge.convert(1)
293+
),
294+
expression=sge.convert("-"),
301295
),
302296
true=sge.Concat(
303297
expressions=[
304298
sge.convert("-"),
305299
sge.func(
306300
"LPAD",
307-
sge.Substring(this=expr.expr, start=sge.convert(2)),
308-
length_expr - 1,
301+
sge.Substring(this=expr.expr, start=sge.convert(1)),
302+
sge.convert(op.width - 1),
309303
sge.convert("0"),
310304
),
311305
]
312306
),
313307
)
314308
],
315-
default=sge.func("LPAD", expr.expr, length_expr, sge.convert("0")),
309+
default=sge.func("LPAD", expr.expr, sge.convert(op.width), sge.convert("0")),
316310
)

bigframes/core/sql/ml.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import typing
1718
from typing import Mapping, Optional, Union
1819

1920
import bigframes.core.compile.googlesql as googlesql

bigframes/ml/__init__.py

Lines changed: 4 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -12,82 +12,19 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""BigQuery DataFrames ML provides a SKLearn-like API on the BigQuery engine.
16-
17-
.. code:: python
18-
19-
from bigframes.ml.linear_model import LinearRegression
20-
model = LinearRegression()
21-
model.fit(feature_columns, label_columns)
22-
model.predict(feature_columns_from_test_data)
23-
24-
You can also save your fit parameters to BigQuery for later use.
25-
26-
.. code:: python
27-
28-
import bigframes.pandas as bpd
29-
model.to_gbq(
30-
your_model_id, # For example: "bqml_tutorial.penguins_model"
31-
replace=True,
32-
)
33-
saved_model = bpd.read_gbq_model(your_model_id)
34-
saved_model.predict(feature_columns_from_test_data)
35-
36-
See the `BigQuery ML linear regression tutorial
37-
<https://docs.cloud.google.com/bigquery/docs/linear-regression-tutorial>`_ for a
38-
detailed example.
39-
40-
See the references for all ``bigframes.ml`` sub-modules:
41-
42-
* :mod:`bigframes.ml.cluster`
43-
* :mod:`bigframes.ml.compose`
44-
* :mod:`bigframes.ml.decomposition`
45-
* :mod:`bigframes.ml.ensemble`
46-
* :mod:`bigframes.ml.forecasting`
47-
* :mod:`bigframes.ml.imported`
48-
* :mod:`bigframes.ml.impute`
49-
* :mod:`bigframes.ml.linear_model`
50-
* :mod:`bigframes.ml.llm`
51-
* :mod:`bigframes.ml.metrics`
52-
* :mod:`bigframes.ml.model_selection`
53-
* :mod:`bigframes.ml.pipeline`
54-
* :mod:`bigframes.ml.preprocessing`
55-
* :mod:`bigframes.ml.remote`
56-
57-
Alternatively, check out :mod:`bigframes.bigquery.ml` for an interface that is
58-
more similar to the BigQuery ML SQL syntax.
59-
"""
60-
61-
from bigframes.ml import (
62-
cluster,
63-
compose,
64-
decomposition,
65-
ensemble,
66-
forecasting,
67-
imported,
68-
impute,
69-
linear_model,
70-
llm,
71-
metrics,
72-
model_selection,
73-
pipeline,
74-
preprocessing,
75-
remote,
76-
)
15+
"""BigQuery DataFrames ML provides a SKLearn-like API on the BigQuery engine."""
7716

7817
__all__ = [
7918
"cluster",
8019
"compose",
8120
"decomposition",
82-
"ensemble",
83-
"forecasting",
84-
"imported",
85-
"impute",
8621
"linear_model",
87-
"llm",
8822
"metrics",
8923
"model_selection",
9024
"pipeline",
9125
"preprocessing",
26+
"llm",
27+
"forecasting",
28+
"imported",
9229
"remote",
9330
]

docs/reference/index.rst

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ packages.
1010
bigframes._config
1111
bigframes.bigquery
1212
bigframes.bigquery.ai
13-
bigframes.bigquery.ml
1413
bigframes.enums
1514
bigframes.exceptions
1615
bigframes.geopandas
@@ -27,8 +26,6 @@ scikit-learn.
2726
.. autosummary::
2827
:toctree: api
2928

30-
bigframes.ml
31-
bigframes.ml.base
3229
bigframes.ml.cluster
3330
bigframes.ml.compose
3431
bigframes.ml.decomposition
@@ -38,7 +35,6 @@ scikit-learn.
3835
bigframes.ml.impute
3936
bigframes.ml.linear_model
4037
bigframes.ml.llm
41-
bigframes.ml.metrics
4238
bigframes.ml.model_selection
4339
bigframes.ml.pipeline
4440
bigframes.ml.preprocessing

tests/system/large/functions/test_remote_function.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,12 +1651,13 @@ def square(x):
16511651
return x * x
16521652

16531653

1654-
# The default value of 100 is used if the maximum instances value is not set.
1654+
# Note: a value of 0 means the default applies, which is actually 100 instances; that is why the remote function
1655+
# still works in the df.apply() call here.
16551656
@pytest.mark.parametrize(
16561657
("max_instances_args", "expected_max_instances"),
16571658
[
1658-
pytest.param({}, 100, id="no-set"),
1659-
pytest.param({"cloud_function_max_instances": None}, 100, id="set-None"),
1659+
pytest.param({}, 0, id="no-set"),
1660+
pytest.param({"cloud_function_max_instances": None}, 0, id="set-None"),
16601661
pytest.param({"cloud_function_max_instances": 1000}, 1000, id="set-explicit"),
16611662
],
16621663
)

tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_map/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
CASE `string_col` WHEN 'value1' THEN 'mapped1' ELSE `string_col` END AS `bfcol_1`
8+
CASE `string_col` WHEN 'value1' THEN 'mapped1' END AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

0 commit comments

Comments
 (0)