Skip to content

Commit 6220b87

Browse files
committed
fix: Fix registry cache initialization
Signed-off-by: ntkathole <[email protected]>
1 parent 5b62733 commit 6220b87

File tree

7 files changed

+73
-69
lines changed

7 files changed

+73
-69
lines changed

sdk/python/feast/infra/registry/base_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,7 @@ def list_projects(
784784
raise NotImplementedError
785785

786786
@abstractmethod
787-
def proto(self) -> RegistryProto:
787+
def proto(self, force_refresh: bool = False) -> RegistryProto:
788788
"""
789789
Retrieves a proto version of the registry.
790790

sdk/python/feast/infra/registry/caching_registry.py

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -426,55 +426,46 @@ def list_projects(
426426

427427
def refresh(self, project: Optional[str] = None):
428428
try:
429-
self.cached_registry_proto = self.proto()
429+
self.cached_registry_proto = self.proto(force_refresh=True)
430430
self.cached_registry_proto_created = _utc_now()
431431
except Exception as e:
432432
logger.debug(f"Error while refreshing registry: {e}", exc_info=True)
433433

434434
def _refresh_cached_registry_if_necessary(self):
435435
if self.cache_mode == "sync":
436-
437-
def is_cache_expired():
436+
# Try acquiring the lock without blocking
437+
if not self._refresh_lock.acquire(blocking=False):
438+
logger.debug(
439+
"Skipping refresh if lock is already held by another thread"
440+
)
441+
return
442+
try:
438443
if self.cached_registry_proto == RegistryProto():
439-
if self.cached_registry_proto_ttl.total_seconds() == 0:
440-
return False
441-
else:
442-
return True
443-
444-
# Cache is expired if it's None or creation time is None
445-
if (
446-
self.cached_registry_proto is None
447-
or not hasattr(self, "cached_registry_proto_created")
448-
or self.cached_registry_proto_created is None
449-
):
450-
return True
451-
452-
# Cache is expired if TTL > 0 and current time exceeds creation + TTL
453-
if self.cached_registry_proto_ttl.total_seconds() > 0 and _utc_now() > (
454-
self.cached_registry_proto_created + self.cached_registry_proto_ttl
455-
):
456-
return True
457-
458-
return False
459-
460-
if is_cache_expired():
461-
if not self._refresh_lock.acquire(blocking=False):
462-
logger.debug(
463-
"Skipping refresh if lock is already held by another thread"
464-
)
465-
return
466-
try:
467-
logger.info(
468-
f"Registry cache expired(ttl: {self.cached_registry_proto_ttl.total_seconds()} seconds), so refreshing"
444+
expired = False
445+
else:
446+
expired = (
447+
self.cached_registry_proto is None
448+
or self.cached_registry_proto_created is None
449+
) or (
450+
self.cached_registry_proto_ttl.total_seconds() > 0
451+
and (
452+
_utc_now()
453+
> (
454+
self.cached_registry_proto_created
455+
+ self.cached_registry_proto_ttl
456+
)
457+
)
469458
)
459+
if expired:
460+
logger.debug("Registry cache expired, so refreshing")
470461
self.refresh()
471-
except Exception as e:
472-
logger.debug(
473-
f"Error in _refresh_cached_registry_if_necessary: {e}",
474-
exc_info=True,
475-
)
476-
finally:
477-
self._refresh_lock.release()
462+
except Exception as e:
463+
logger.debug(
464+
f"Error in _refresh_cached_registry_if_necessary: {e}",
465+
exc_info=True,
466+
)
467+
finally:
468+
self._refresh_lock.release() # Always release the lock safely
478469

479470
def _start_thread_async_refresh(self, cache_ttl_seconds):
480471
self.refresh()

sdk/python/feast/infra/registry/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@ def teardown(self):
890890
"""Tears down (removes) the registry."""
891891
self._registry_store.teardown()
892892

893-
def proto(self) -> RegistryProto:
893+
def proto(self, force_refresh: bool = False) -> RegistryProto:
894894
return self.cached_registry_proto or RegistryProto()
895895

896896
def _prepare_registry_for_changes(self, project_name: str):

sdk/python/feast/infra/registry/remote.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ def get_project_metadata(self, project: str, key: str) -> Optional[str]:
594594
return meta[key]
595595
return None
596596

597-
def proto(self) -> RegistryProto:
597+
def proto(self, force_refresh: bool = False) -> RegistryProto:
598598
return self.stub.Proto(Empty())
599599

600600
def commit(self):

sdk/python/feast/infra/registry/snowflake.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -493,9 +493,11 @@ def _delete_object(
493493
if cursor.rowcount < 1 and not_found_exception: # type: ignore
494494
raise not_found_exception(name, project)
495495
self._set_last_updated_metadata(_utc_now(), project)
496+
rowcount = cursor.rowcount
496497

497-
self.refresh()
498-
return cursor.rowcount
498+
self.refresh()
499+
500+
return rowcount
499501

500502
def delete_permission(self, name: str, project: str, commit: bool = True):
501503
return self._delete_object(
@@ -1119,7 +1121,7 @@ def get_user_metadata(
11191121
else:
11201122
raise FeatureViewNotFoundException(feature_view.name, project=project)
11211123

1122-
def proto(self) -> RegistryProto:
1124+
def proto(self, force_refresh: bool = False) -> RegistryProto:
11231125
r = RegistryProto()
11241126
last_updated_timestamps = []
11251127

@@ -1128,6 +1130,18 @@ def process_project(project: Project):
11281130
project_name = project.name
11291131
last_updated_timestamp = project.last_updated_timestamp
11301132

1133+
try:
1134+
cached_project = self.get_project(project_name, True)
1135+
except ProjectObjectNotFoundException:
1136+
cached_project = None
1137+
1138+
allow_cache = False
1139+
1140+
if cached_project is not None and not force_refresh:
1141+
allow_cache = (
1142+
last_updated_timestamp <= cached_project.last_updated_timestamp
1143+
)
1144+
11311145
r.projects.extend([project.to_proto()])
11321146
last_updated_timestamps.append(last_updated_timestamp)
11331147

@@ -1142,7 +1156,7 @@ def process_project(project: Project):
11421156
(self.list_validation_references, r.validation_references),
11431157
(self.list_permissions, r.permissions),
11441158
]:
1145-
objs: List[Any] = lister(project_name, allow_cache=False) # type: ignore
1159+
objs: List[Any] = lister(project_name, allow_cache) # type: ignore
11461160
if objs:
11471161
obj_protos = [obj.to_proto() for obj in objs]
11481162
for obj_proto in obj_protos:

sdk/python/feast/infra/registry/sql.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,7 @@ def get_user_metadata(
838838
else:
839839
raise FeatureViewNotFoundException(feature_view.name, project=project)
840840

841-
def proto(self) -> RegistryProto:
841+
def proto(self, force_refresh: bool = False) -> RegistryProto:
842842
r = RegistryProto()
843843
last_updated_timestamps = []
844844

@@ -847,6 +847,18 @@ def process_project(project: Project):
847847
project_name = project.name
848848
last_updated_timestamp = project.last_updated_timestamp
849849

850+
try:
851+
cached_project = self.get_project(project_name, True)
852+
except ProjectObjectNotFoundException:
853+
cached_project = None
854+
855+
allow_cache = False
856+
857+
if cached_project is not None and not force_refresh:
858+
allow_cache = (
859+
last_updated_timestamp <= cached_project.last_updated_timestamp
860+
)
861+
850862
r.projects.extend([project.to_proto()])
851863
last_updated_timestamps.append(last_updated_timestamp)
852864

@@ -861,7 +873,7 @@ def process_project(project: Project):
861873
(self.list_validation_references, r.validation_references),
862874
(self.list_permissions, r.permissions),
863875
]:
864-
objs: List[Any] = lister(project_name, allow_cache=False) # type: ignore
876+
objs: List[Any] = lister(project_name, allow_cache) # type: ignore
865877
if objs:
866878
obj_protos = [obj.to_proto() for obj in objs]
867879
for obj_proto in obj_protos:
@@ -1055,9 +1067,12 @@ def _delete_object(
10551067
if not self.purge_feast_metadata:
10561068
self._set_last_updated_metadata(_utc_now(), project, conn)
10571069

1058-
if self.cache_mode == "sync":
1059-
self.refresh()
1060-
return rows.rowcount
1070+
rowcount = rows.rowcount
1071+
1072+
if self.cache_mode == "sync":
1073+
self.refresh()
1074+
1075+
return rowcount
10611076

10621077
def _get_object(
10631078
self,

sdk/python/tests/unit/infra/registry/test_registry.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -189,22 +189,6 @@ def test_cache_expiry_triggers_refresh(registry):
189189
mock_refresh.assert_called_once()
190190

191191

192-
def test_empty_cache_refresh_with_ttl(registry):
193-
"""Test that empty cache is refreshed when TTL > 0"""
194-
# Set up empty cache with TTL > 0
195-
registry.cached_registry_proto = RegistryProto()
196-
registry.cached_registry_proto_created = datetime.now(timezone.utc)
197-
registry.cached_registry_proto_ttl = timedelta(seconds=10) # TTL > 0
198-
199-
# Mock refresh to check if it's called
200-
with patch.object(
201-
CachingRegistry, "refresh", wraps=registry.refresh
202-
) as mock_refresh:
203-
registry._refresh_cached_registry_if_necessary()
204-
# Should refresh because cache is empty and TTL > 0
205-
mock_refresh.assert_called_once()
206-
207-
208192
def test_empty_cache_no_refresh_with_infinite_ttl(registry):
209193
"""Test that empty cache is not refreshed when TTL = 0 (infinite)"""
210194
# Set up empty cache with TTL = 0 (infinite)
@@ -227,7 +211,7 @@ def test_concurrent_cache_refresh_race_condition(registry):
227211
import time
228212

229213
# Set up expired cache
230-
registry.cached_registry_proto = RegistryProto()
214+
registry.cached_registry_proto = "some_cached_data" # Not empty
231215
registry.cached_registry_proto_created = datetime.now(timezone.utc) - timedelta(
232216
seconds=5
233217
)

0 commit comments

Comments
 (0)