Skip to content

Commit 6bcf77c

Browse files
authored
fix: Fix bug where SQL registry was incorrectly writing infra config around online stores (#3394)
fix: Fix bug where SQL registry was incorrectly writing info around sqlite online store Signed-off-by: Danny Chiao <[email protected]> Signed-off-by: Danny Chiao <[email protected]>
1 parent fd97254 commit 6bcf77c

File tree

2 files changed

+122
-84
lines changed

2 files changed

+122
-84
lines changed

sdk/python/feast/infra/registry/sql.py

Lines changed: 94 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -207,14 +207,14 @@ def get_stream_feature_view(
207207
self, name: str, project: str, allow_cache: bool = False
208208
):
209209
return self._get_object(
210-
stream_feature_views,
211-
name,
212-
project,
213-
StreamFeatureViewProto,
214-
StreamFeatureView,
215-
"feature_view_name",
216-
"feature_view_proto",
217-
FeatureViewNotFoundException,
210+
table=stream_feature_views,
211+
name=name,
212+
project=project,
213+
proto_class=StreamFeatureViewProto,
214+
python_class=StreamFeatureView,
215+
id_field_name="feature_view_name",
216+
proto_field_name="feature_view_proto",
217+
not_found_exception=FeatureViewNotFoundException,
218218
)
219219

220220
def list_stream_feature_views(
@@ -230,101 +230,105 @@ def list_stream_feature_views(
230230

231231
def apply_entity(self, entity: Entity, project: str, commit: bool = True):
232232
return self._apply_object(
233-
entities, project, "entity_name", entity, "entity_proto"
233+
table=entities,
234+
project=project,
235+
id_field_name="entity_name",
236+
obj=entity,
237+
proto_field_name="entity_proto",
234238
)
235239

236240
def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity:
237241
return self._get_object(
238-
entities,
239-
name,
240-
project,
241-
EntityProto,
242-
Entity,
243-
"entity_name",
244-
"entity_proto",
245-
EntityNotFoundException,
242+
table=entities,
243+
name=name,
244+
project=project,
245+
proto_class=EntityProto,
246+
python_class=Entity,
247+
id_field_name="entity_name",
248+
proto_field_name="entity_proto",
249+
not_found_exception=EntityNotFoundException,
246250
)
247251

248252
def get_feature_view(
249253
self, name: str, project: str, allow_cache: bool = False
250254
) -> FeatureView:
251255
return self._get_object(
252-
feature_views,
253-
name,
254-
project,
255-
FeatureViewProto,
256-
FeatureView,
257-
"feature_view_name",
258-
"feature_view_proto",
259-
FeatureViewNotFoundException,
256+
table=feature_views,
257+
name=name,
258+
project=project,
259+
proto_class=FeatureViewProto,
260+
python_class=FeatureView,
261+
id_field_name="feature_view_name",
262+
proto_field_name="feature_view_proto",
263+
not_found_exception=FeatureViewNotFoundException,
260264
)
261265

262266
def get_on_demand_feature_view(
263267
self, name: str, project: str, allow_cache: bool = False
264268
) -> OnDemandFeatureView:
265269
return self._get_object(
266-
on_demand_feature_views,
267-
name,
268-
project,
269-
OnDemandFeatureViewProto,
270-
OnDemandFeatureView,
271-
"feature_view_name",
272-
"feature_view_proto",
273-
FeatureViewNotFoundException,
270+
table=on_demand_feature_views,
271+
name=name,
272+
project=project,
273+
proto_class=OnDemandFeatureViewProto,
274+
python_class=OnDemandFeatureView,
275+
id_field_name="feature_view_name",
276+
proto_field_name="feature_view_proto",
277+
not_found_exception=FeatureViewNotFoundException,
274278
)
275279

276280
def get_request_feature_view(self, name: str, project: str):
277281
return self._get_object(
278-
request_feature_views,
279-
name,
280-
project,
281-
RequestFeatureViewProto,
282-
RequestFeatureView,
283-
"feature_view_name",
284-
"feature_view_proto",
285-
FeatureViewNotFoundException,
282+
table=request_feature_views,
283+
name=name,
284+
project=project,
285+
proto_class=RequestFeatureViewProto,
286+
python_class=RequestFeatureView,
287+
id_field_name="feature_view_name",
288+
proto_field_name="feature_view_proto",
289+
not_found_exception=FeatureViewNotFoundException,
286290
)
287291

288292
def get_feature_service(
289293
self, name: str, project: str, allow_cache: bool = False
290294
) -> FeatureService:
291295
return self._get_object(
292-
feature_services,
293-
name,
294-
project,
295-
FeatureServiceProto,
296-
FeatureService,
297-
"feature_service_name",
298-
"feature_service_proto",
299-
FeatureServiceNotFoundException,
296+
table=feature_services,
297+
name=name,
298+
project=project,
299+
proto_class=FeatureServiceProto,
300+
python_class=FeatureService,
301+
id_field_name="feature_service_name",
302+
proto_field_name="feature_service_proto",
303+
not_found_exception=FeatureServiceNotFoundException,
300304
)
301305

302306
def get_saved_dataset(
303307
self, name: str, project: str, allow_cache: bool = False
304308
) -> SavedDataset:
305309
return self._get_object(
306-
saved_datasets,
307-
name,
308-
project,
309-
SavedDatasetProto,
310-
SavedDataset,
311-
"saved_dataset_name",
312-
"saved_dataset_proto",
313-
SavedDatasetNotFound,
310+
table=saved_datasets,
311+
name=name,
312+
project=project,
313+
proto_class=SavedDatasetProto,
314+
python_class=SavedDataset,
315+
id_field_name="saved_dataset_name",
316+
proto_field_name="saved_dataset_proto",
317+
not_found_exception=SavedDatasetNotFound,
314318
)
315319

316320
def get_validation_reference(
317321
self, name: str, project: str, allow_cache: bool = False
318322
) -> ValidationReference:
319323
return self._get_object(
320-
validation_references,
321-
name,
322-
project,
323-
ValidationReferenceProto,
324-
ValidationReference,
325-
"validation_reference_name",
326-
"validation_reference_proto",
327-
ValidationReferenceNotFound,
324+
table=validation_references,
325+
name=name,
326+
project=project,
327+
proto_class=ValidationReferenceProto,
328+
python_class=ValidationReference,
329+
id_field_name="validation_reference_name",
330+
proto_field_name="validation_reference_proto",
331+
not_found_exception=ValidationReferenceNotFound,
328332
)
329333

330334
def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]:
@@ -364,14 +368,14 @@ def get_data_source(
364368
self, name: str, project: str, allow_cache: bool = False
365369
) -> DataSource:
366370
return self._get_object(
367-
data_sources,
368-
name,
369-
project,
370-
DataSourceProto,
371-
DataSource,
372-
"data_source_name",
373-
"data_source_proto",
374-
DataSourceObjectNotFoundException,
371+
table=data_sources,
372+
name=name,
373+
project=project,
374+
proto_class=DataSourceProto,
375+
python_class=DataSource,
376+
id_field_name="data_source_name",
377+
proto_field_name="data_source_proto",
378+
not_found_exception=DataSourceObjectNotFoundException,
375379
)
376380

377381
def list_data_sources(
@@ -556,22 +560,28 @@ def delete_validation_reference(self, name: str, project: str, commit: bool = Tr
556560

557561
def update_infra(self, infra: Infra, project: str, commit: bool = True):
558562
self._apply_object(
559-
managed_infra, project, "infra_name", infra, "infra_proto", name="infra_obj"
563+
table=managed_infra,
564+
project=project,
565+
id_field_name="infra_name",
566+
obj=infra,
567+
proto_field_name="infra_proto",
568+
name="infra_obj",
560569
)
561570

562571
def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
563572
infra_object = self._get_object(
564-
managed_infra,
565-
"infra_obj",
566-
project,
567-
InfraProto,
568-
Infra,
569-
"infra_name",
570-
"infra_proto",
571-
None,
573+
table=managed_infra,
574+
name="infra_obj",
575+
project=project,
576+
proto_class=InfraProto,
577+
python_class=Infra,
578+
id_field_name="infra_name",
579+
proto_field_name="infra_proto",
580+
not_found_exception=None,
572581
)
573-
infra_object = infra_object or InfraProto()
574-
return Infra.from_proto(infra_object)
582+
if infra_object:
583+
return infra_object
584+
return Infra()
575585

576586
def apply_user_metadata(
577587
self,

sdk/python/tests/unit/test_sql_registry.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from feast.errors import FeatureViewNotFoundException
2929
from feast.feature_view import FeatureView
3030
from feast.field import Field
31+
from feast.infra.infra_object import Infra
32+
from feast.infra.online_stores.sqlite import SqliteTable
3133
from feast.infra.registry.sql import SqlRegistry
3234
from feast.on_demand_feature_view import on_demand_feature_view
3335
from feast.repo_config import RegistryConfig
@@ -258,10 +260,20 @@ def test_apply_feature_view_success(sql_registry):
258260
and feature_view.features[3].dtype == Array(Bytes)
259261
and feature_view.entities[0] == "fs1_my_entity_1"
260262
)
263+
assert feature_view.ttl == timedelta(minutes=5)
261264

262265
# After the first apply, the created_timestamp should be the same as the last_update_timestamp.
263266
assert feature_view.created_timestamp == feature_view.last_updated_timestamp
264267

268+
# Modify the feature view and apply again to test if diffing the online store table works
269+
fv1.ttl = timedelta(minutes=6)
270+
sql_registry.apply_feature_view(fv1, project)
271+
feature_views = sql_registry.list_feature_views(project)
272+
assert len(feature_views) == 1
273+
feature_view = sql_registry.get_feature_view("my_feature_view_1", project)
274+
assert feature_view.ttl == timedelta(minutes=6)
275+
276+
# Delete feature view
265277
sql_registry.delete_feature_view("my_feature_view_1", project)
266278
feature_views = sql_registry.list_feature_views(project)
267279
assert len(feature_views) == 0
@@ -570,6 +582,22 @@ def test_update_infra(sql_registry):
570582
project = "project"
571583
infra = sql_registry.get_infra(project=project)
572584

585+
assert len(infra.infra_objects) == 0
586+
573587
# Should run update infra successfully
574588
sql_registry.update_infra(infra, project)
589+
590+
# Should run update infra successfully when adding
591+
new_infra = Infra()
592+
new_infra.infra_objects.append(
593+
SqliteTable(
594+
path="/tmp/my_path.db",
595+
name="my_table",
596+
)
597+
)
598+
sql_registry.update_infra(new_infra, project)
599+
infra = sql_registry.get_infra(project=project)
600+
assert len(infra.infra_objects) == 1
601+
602+
# Try again since second time, infra should be not-empty
575603
sql_registry.teardown()

0 commit comments

Comments
 (0)