Skip to content

Commit 1181a9e

Browse files
authored
fix: Snowflake api update (#2487)
* Update snowflake source Signed-off-by: Kevin Zhang <[email protected]> * Fix snowflake Signed-off-by: Kevin Zhang <[email protected]> * Fix Signed-off-by: Kevin Zhang <[email protected]>
1 parent ce35835 commit 1181a9e

File tree

6 files changed

+43
-3
lines changed

6 files changed

+43
-3
lines changed

protos/feast/core/DataSource.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ message DataSource {
164164

165165
// Snowflake schema name
166166
string database = 4;
167+
168+
// Snowflake warehouse name
169+
string warehouse = 5;
167170
}
168171

169172
// Defines configuration for custom third-party data sources.

sdk/python/feast/infra/offline_stores/snowflake.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ def pull_latest_from_table_or_query(
128128
+ '"'
129129
)
130130

131+
if data_source.snowflake_options.warehouse:
132+
config.offline_store.warehouse = data_source.snowflake_options.warehouse
133+
131134
snowflake_conn = get_snowflake_conn(config.offline_store)
132135

133136
query = f"""
@@ -173,6 +176,9 @@ def pull_all_from_table_or_query(
173176
+ '"'
174177
)
175178

179+
if data_source.snowflake_options.warehouse:
180+
config.offline_store.warehouse = data_source.snowflake_options.warehouse
181+
176182
snowflake_conn = get_snowflake_conn(config.offline_store)
177183

178184
start_date = start_date.astimezone(tz=utc)

sdk/python/feast/infra/offline_stores/snowflake_source.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class SnowflakeSource(DataSource):
1616
def __init__(
1717
self,
1818
database: Optional[str] = None,
19+
warehouse: Optional[str] = None,
1920
schema: Optional[str] = None,
2021
table: Optional[str] = None,
2122
query: Optional[str] = None,
@@ -33,6 +34,7 @@ def __init__(
3334
3435
Args:
3536
database (optional): Snowflake database where the features are stored.
37+
warehouse (optional): Snowflake warehouse where the database is stored.
3638
schema (optional): Snowflake schema in which the table is located.
3739
table (optional): Snowflake table where the features are stored.
3840
event_timestamp_column (optional): Event timestamp column used for point in
@@ -55,7 +57,11 @@ def __init__(
5557
_schema = "PUBLIC" if (database and table and not schema) else schema
5658

5759
self.snowflake_options = SnowflakeOptions(
58-
database=database, schema=_schema, table=table, query=query
60+
database=database,
61+
schema=_schema,
62+
table=table,
63+
query=query,
64+
warehouse=warehouse,
5965
)
6066

6167
# If no name, use the table as the default name
@@ -107,6 +113,7 @@ def from_proto(data_source: DataSourceProto):
107113
database=data_source.snowflake_options.database,
108114
schema=data_source.snowflake_options.schema,
109115
table=data_source.snowflake_options.table,
116+
warehouse=data_source.snowflake_options.warehouse,
110117
event_timestamp_column=data_source.event_timestamp_column,
111118
created_timestamp_column=data_source.created_timestamp_column,
112119
query=data_source.snowflake_options.query,
@@ -131,6 +138,7 @@ def __eq__(self, other):
131138
and self.snowflake_options.schema == other.snowflake_options.schema
132139
and self.snowflake_options.table == other.snowflake_options.table
133140
and self.snowflake_options.query == other.snowflake_options.query
141+
and self.snowflake_options.warehouse == other.snowflake_options.warehouse
134142
and self.event_timestamp_column == other.event_timestamp_column
135143
and self.created_timestamp_column == other.created_timestamp_column
136144
and self.field_mapping == other.field_mapping
@@ -159,6 +167,11 @@ def query(self):
159167
"""Returns the snowflake options of this snowflake source."""
160168
return self.snowflake_options.query
161169

170+
@property
171+
def warehouse(self):
172+
"""Returns the warehouse of this snowflake source."""
173+
return self.snowflake_options.warehouse
174+
162175
def to_proto(self) -> DataSourceProto:
163176
"""
164177
Converts a SnowflakeSource object to its protobuf representation.
@@ -245,11 +258,13 @@ def __init__(
245258
schema: Optional[str],
246259
table: Optional[str],
247260
query: Optional[str],
261+
warehouse: Optional[str],
248262
):
249263
self._database = database
250264
self._schema = schema
251265
self._table = table
252266
self._query = query
267+
self._warehouse = warehouse
253268

254269
@property
255270
def query(self):
@@ -291,6 +306,16 @@ def table(self, table):
291306
"""Sets the table ref of this snowflake table."""
292307
self._table = table
293308

309+
@property
310+
def warehouse(self):
311+
"""Returns the warehouse name of this snowflake table."""
312+
return self._warehouse
313+
314+
@warehouse.setter
315+
def warehouse(self, warehouse):
316+
"""Sets the warehouse name of this snowflake table."""
317+
self._warehouse = warehouse
318+
294319
@classmethod
295320
def from_proto(cls, snowflake_options_proto: DataSourceProto.SnowflakeOptions):
296321
"""
@@ -307,6 +332,7 @@ def from_proto(cls, snowflake_options_proto: DataSourceProto.SnowflakeOptions):
307332
schema=snowflake_options_proto.schema,
308333
table=snowflake_options_proto.table,
309334
query=snowflake_options_proto.query,
335+
warehouse=snowflake_options_proto.warehouse,
310336
)
311337

312338
return snowflake_options
@@ -323,6 +349,7 @@ def to_proto(self) -> DataSourceProto.SnowflakeOptions:
323349
schema=self.schema,
324350
table=self.table,
325351
query=self.query,
352+
warehouse=self.warehouse,
326353
)
327354

328355
return snowflake_options_proto
@@ -335,7 +362,7 @@ class SavedDatasetSnowflakeStorage(SavedDatasetStorage):
335362

336363
def __init__(self, table_ref: str):
337364
self.snowflake_options = SnowflakeOptions(
338-
database=None, schema=None, table=table_ref, query=None
365+
database=None, schema=None, table=table_ref, query=None, warehouse=None
339366
)
340367

341368
@staticmethod

sdk/python/feast/templates/snowflake/bootstrap.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def bootstrap():
6868

6969
repo_path = pathlib.Path(__file__).parent.absolute()
7070
config_file = repo_path / "feature_store.yaml"
71-
71+
driver_file = repo_path / "driver_repo.py"
7272
replace_str_in_file(
7373
config_file, "SNOWFLAKE_DEPLOYMENT_URL", snowflake_deployment_url
7474
)
@@ -78,6 +78,8 @@ def bootstrap():
7878
replace_str_in_file(config_file, "SNOWFLAKE_WAREHOUSE", snowflake_warehouse)
7979
replace_str_in_file(config_file, "SNOWFLAKE_DATABASE", snowflake_database)
8080

81+
replace_str_in_file(driver_file, "SNOWFLAKE_WAREHOUSE", snowflake_warehouse)
82+
8183

8284
def replace_str_in_file(file_path, match_str, sub_str):
8385
with open(file_path, "r") as f:

sdk/python/feast/templates/snowflake/driver_repo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# The Snowflake table where features can be found
2525
database=yaml.safe_load(open("feature_store.yaml"))["offline_store"]["database"],
2626
table=f"{project_name}_feast_driver_hourly_stats",
27+
warehouse="SNOWFLAKE_WAREHOUSE",
2728
# The event timestamp is used for point-in-time joins and for ensuring only
2829
# features within the TTL are returned
2930
event_timestamp_column="event_timestamp",

sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def create_data_source(
5656
event_timestamp_column=event_timestamp_column,
5757
created_timestamp_column=created_timestamp_column,
5858
field_mapping=field_mapping or {"ts_1": "ts"},
59+
warehouse=self.offline_store_config.warehouse,
5960
)
6061

6162
def create_saved_dataset_destination(self) -> SavedDatasetSnowflakeStorage:

0 commit comments

Comments
 (0)