@@ -16,6 +16,7 @@ class SnowflakeSource(DataSource):
16
16
def __init__ (
17
17
self ,
18
18
database : Optional [str ] = None ,
19
+ warehouse : Optional [str ] = None ,
19
20
schema : Optional [str ] = None ,
20
21
table : Optional [str ] = None ,
21
22
query : Optional [str ] = None ,
@@ -33,6 +34,7 @@ def __init__(
33
34
34
35
Args:
35
36
database (optional): Snowflake database where the features are stored.
37
+ warehouse (optional): Snowflake warehouse where the database is stored.
36
38
schema (optional): Snowflake schema in which the table is located.
37
39
table (optional): Snowflake table where the features are stored.
38
40
event_timestamp_column (optional): Event timestamp column used for point in
@@ -55,7 +57,11 @@ def __init__(
55
57
_schema = "PUBLIC" if (database and table and not schema ) else schema
56
58
57
59
self .snowflake_options = SnowflakeOptions (
58
- database = database , schema = _schema , table = table , query = query
60
+ database = database ,
61
+ schema = _schema ,
62
+ table = table ,
63
+ query = query ,
64
+ warehouse = warehouse ,
59
65
)
60
66
61
67
# If no name, use the table as the default name
@@ -107,6 +113,7 @@ def from_proto(data_source: DataSourceProto):
107
113
database = data_source .snowflake_options .database ,
108
114
schema = data_source .snowflake_options .schema ,
109
115
table = data_source .snowflake_options .table ,
116
+ warehouse = data_source .snowflake_options .warehouse ,
110
117
event_timestamp_column = data_source .event_timestamp_column ,
111
118
created_timestamp_column = data_source .created_timestamp_column ,
112
119
query = data_source .snowflake_options .query ,
@@ -131,6 +138,7 @@ def __eq__(self, other):
131
138
and self .snowflake_options .schema == other .snowflake_options .schema
132
139
and self .snowflake_options .table == other .snowflake_options .table
133
140
and self .snowflake_options .query == other .snowflake_options .query
141
+ and self .snowflake_options .warehouse == other .snowflake_options .warehouse
134
142
and self .event_timestamp_column == other .event_timestamp_column
135
143
and self .created_timestamp_column == other .created_timestamp_column
136
144
and self .field_mapping == other .field_mapping
@@ -159,6 +167,11 @@ def query(self):
159
167
"""Returns the snowflake options of this snowflake source."""
160
168
return self .snowflake_options .query
161
169
170
+ @property
171
+ def warehouse (self ):
172
+ """Returns the warehouse of this snowflake source."""
173
+ return self .snowflake_options .warehouse
174
+
162
175
def to_proto (self ) -> DataSourceProto :
163
176
"""
164
177
Converts a SnowflakeSource object to its protobuf representation.
@@ -245,11 +258,13 @@ def __init__(
245
258
schema : Optional [str ],
246
259
table : Optional [str ],
247
260
query : Optional [str ],
261
+ warehouse : Optional [str ],
248
262
):
249
263
self ._database = database
250
264
self ._schema = schema
251
265
self ._table = table
252
266
self ._query = query
267
+ self ._warehouse = warehouse
253
268
254
269
@property
255
270
def query (self ):
@@ -291,6 +306,16 @@ def table(self, table):
291
306
"""Sets the table ref of this snowflake table."""
292
307
self ._table = table
293
308
309
+ @property
310
+ def warehouse (self ):
311
+ """Returns the warehouse name of this snowflake table."""
312
+ return self ._warehouse
313
+
314
+ @warehouse .setter
315
+ def warehouse (self , warehouse ):
316
+ """Sets the warehouse name of this snowflake table."""
317
+ self ._warehouse = warehouse
318
+
294
319
@classmethod
295
320
def from_proto (cls , snowflake_options_proto : DataSourceProto .SnowflakeOptions ):
296
321
"""
@@ -307,6 +332,7 @@ def from_proto(cls, snowflake_options_proto: DataSourceProto.SnowflakeOptions):
307
332
schema = snowflake_options_proto .schema ,
308
333
table = snowflake_options_proto .table ,
309
334
query = snowflake_options_proto .query ,
335
+ warehouse = snowflake_options_proto .warehouse ,
310
336
)
311
337
312
338
return snowflake_options
@@ -323,6 +349,7 @@ def to_proto(self) -> DataSourceProto.SnowflakeOptions:
323
349
schema = self .schema ,
324
350
table = self .table ,
325
351
query = self .query ,
352
+ warehouse = self .warehouse ,
326
353
)
327
354
328
355
return snowflake_options_proto
@@ -335,7 +362,7 @@ class SavedDatasetSnowflakeStorage(SavedDatasetStorage):
335
362
336
363
def __init__ (self , table_ref : str ):
337
364
self .snowflake_options = SnowflakeOptions (
338
- database = None , schema = None , table = table_ref , query = None
365
+ database = None , schema = None , table = table_ref , query = None , warehouse = None
339
366
)
340
367
341
368
@staticmethod
0 commit comments