Skip to content

Commit 2348c85

Browse files
committed
fix: Fixed cache refresh
Signed-off-by: ntkathole <[email protected]>
1 parent 02c3006 commit 2348c85

File tree

6 files changed

+167
-83
lines changed

6 files changed

+167
-83
lines changed

sdk/python/feast/api/registry/rest/metrics.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import json
2-
import logging
32
from typing import Dict, List, Optional
43

54
from fastapi import APIRouter, Depends, HTTPException, Query, Request
@@ -57,7 +56,6 @@ class PopularTagsResponse(BaseModel):
5756

5857

5958
def get_metrics_router(grpc_handler, server=None) -> APIRouter:
60-
logger = logging.getLogger(__name__)
6159
router = APIRouter()
6260

6361
@router.get("/metrics/resource_counts", tags=["Metrics"])
@@ -321,20 +319,43 @@ async def recently_visited(
321319
user = getattr(request.state, "user", None)
322320
if not user:
323321
user = "anonymous"
324-
project_val = project or (server.store.project if server else None)
325322
key = f"recently_visited_{user}"
326-
logger.info(
327-
f"[/metrics/recently_visited] Project: {project_val}, Key: {key}, Object: {object_type}"
328-
)
329-
try:
330-
visits_json = (
331-
server.registry.get_project_metadata(project_val, key)
332-
if server
333-
else None
334-
)
335-
visits = json.loads(visits_json) if visits_json else []
336-
except Exception:
337-
visits = []
323+
visits = []
324+
if project:
325+
try:
326+
visits_json = (
327+
server.registry.get_project_metadata(project, key)
328+
if server
329+
else None
330+
)
331+
visits = json.loads(visits_json) if visits_json else []
332+
except Exception:
333+
visits = []
334+
else:
335+
try:
336+
if server:
337+
projects_resp = grpc_call(
338+
grpc_handler.ListProjects,
339+
RegistryServer_pb2.ListProjectsRequest(allow_cache=True),
340+
)
341+
all_projects = [
342+
p["spec"]["name"] for p in projects_resp.get("projects", [])
343+
]
344+
for project_name in all_projects:
345+
try:
346+
visits_json = server.registry.get_project_metadata(
347+
project_name, key
348+
)
349+
if visits_json:
350+
project_visits = json.loads(visits_json)
351+
visits.extend(project_visits)
352+
except Exception:
353+
continue
354+
visits = sorted(
355+
visits, key=lambda x: x.get("timestamp", ""), reverse=True
356+
)
357+
except Exception:
358+
visits = []
338359
if object_type:
339360
visits = [v for v in visits if v.get("object") == object_type]
340361

sdk/python/feast/infra/registry/caching_registry.py

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -425,9 +425,6 @@ def list_projects(
425425
return self._list_projects(tags)
426426

427427
def refresh(self, project: Optional[str] = None):
428-
if self._refresh_lock.locked():
429-
logger.debug("Skipping refresh if already in progress")
430-
return
431428
try:
432429
self.cached_registry_proto = self.proto()
433430
self.cached_registry_proto_created = _utc_now()
@@ -436,43 +433,48 @@ def refresh(self, project: Optional[str] = None):
436433

437434
def _refresh_cached_registry_if_necessary(self):
438435
if self.cache_mode == "sync":
439-
# Try acquiring the lock without blocking
440-
if not self._refresh_lock.acquire(blocking=False):
441-
logger.debug(
442-
"Skipping refresh if lock is already held by another thread"
443-
)
444-
return
445-
try:
436+
437+
def is_cache_expired():
446438
if self.cached_registry_proto == RegistryProto():
447-
# Avoids the need to refresh the registry when cache is not populated yet
448-
# Specially during the __init__ phase
449-
# proto() will populate the cache with project metadata if no objects are registered
450-
expired = False
451-
else:
452-
expired = (
453-
self.cached_registry_proto is None
454-
or self.cached_registry_proto_created is None
455-
) or (
456-
self.cached_registry_proto_ttl.total_seconds()
457-
> 0 # 0 ttl means infinity
458-
and (
459-
_utc_now()
460-
> (
461-
self.cached_registry_proto_created
462-
+ self.cached_registry_proto_ttl
463-
)
464-
)
439+
if self.cached_registry_proto_ttl.total_seconds() == 0:
440+
return False
441+
else:
442+
return True
443+
444+
# Cache is expired if it's None or creation time is None
445+
if (
446+
self.cached_registry_proto is None
447+
or not hasattr(self, "cached_registry_proto_created")
448+
or self.cached_registry_proto_created is None
449+
):
450+
return True
451+
452+
# Cache is expired if TTL > 0 and current time exceeds creation + TTL
453+
if self.cached_registry_proto_ttl.total_seconds() > 0 and _utc_now() > (
454+
self.cached_registry_proto_created + self.cached_registry_proto_ttl
455+
):
456+
return True
457+
458+
return False
459+
460+
if is_cache_expired():
461+
if not self._refresh_lock.acquire(blocking=False):
462+
logger.debug(
463+
"Skipping refresh if lock is already held by another thread"
464+
)
465+
return
466+
try:
467+
logger.info(
468+
f"Registry cache expired(ttl: {self.cached_registry_proto_ttl.total_seconds()} seconds), so refreshing"
465469
)
466-
if expired:
467-
logger.debug("Registry cache expired, so refreshing")
468470
self.refresh()
469-
except Exception as e:
470-
logger.debug(
471-
f"Error in _refresh_cached_registry_if_necessary: {e}",
472-
exc_info=True,
473-
)
474-
finally:
475-
self._refresh_lock.release() # Always release the lock safely
471+
except Exception as e:
472+
logger.debug(
473+
f"Error in _refresh_cached_registry_if_necessary: {e}",
474+
exc_info=True,
475+
)
476+
finally:
477+
self._refresh_lock.release()
476478

477479
def _start_thread_async_refresh(self, cache_ttl_seconds):
478480
self.refresh()

sdk/python/feast/infra/registry/snowflake.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,8 @@ def _apply_object(
404404
if not self.purge_feast_metadata:
405405
self._set_last_updated_metadata(update_datetime, project)
406406

407+
self.refresh()
408+
407409
def apply_permission(
408410
self, permission: Permission, project: str, commit: bool = True
409411
):
@@ -492,6 +494,7 @@ def _delete_object(
492494
raise not_found_exception(name, project)
493495
self._set_last_updated_metadata(_utc_now(), project)
494496

497+
self.refresh()
495498
return cursor.rowcount
496499

497500
def delete_permission(self, name: str, project: str, commit: bool = True):
@@ -1125,18 +1128,6 @@ def process_project(project: Project):
11251128
project_name = project.name
11261129
last_updated_timestamp = project.last_updated_timestamp
11271130

1128-
try:
1129-
cached_project = self.get_project(project_name, True)
1130-
except ProjectObjectNotFoundException:
1131-
cached_project = None
1132-
1133-
allow_cache = False
1134-
1135-
if cached_project is not None:
1136-
allow_cache = (
1137-
last_updated_timestamp <= cached_project.last_updated_timestamp
1138-
)
1139-
11401131
r.projects.extend([project.to_proto()])
11411132
last_updated_timestamps.append(last_updated_timestamp)
11421133

@@ -1151,7 +1142,7 @@ def process_project(project: Project):
11511142
(self.list_validation_references, r.validation_references),
11521143
(self.list_permissions, r.permissions),
11531144
]:
1154-
objs: List[Any] = lister(project_name, allow_cache) # type: ignore
1145+
objs: List[Any] = lister(project_name, allow_cache=False) # type: ignore
11551146
if objs:
11561147
obj_protos = [obj.to_proto() for obj in objs]
11571148
for obj_proto in obj_protos:

sdk/python/feast/infra/registry/sql.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
MetaData,
1616
String,
1717
Table,
18+
Text,
1819
create_engine,
1920
delete,
2021
insert,
@@ -209,7 +210,7 @@ class FeastMetadataKeys(Enum):
209210
metadata,
210211
Column("project_id", String(255), primary_key=True),
211212
Column("metadata_key", String(50), primary_key=True),
212-
Column("metadata_value", String(50), nullable=False),
213+
Column("metadata_value", Text, nullable=False),
213214
Column("last_updated_timestamp", BigInteger, nullable=False),
214215
)
215216

@@ -326,6 +327,7 @@ def teardown(self):
326327
entities,
327328
data_sources,
328329
feature_views,
330+
stream_feature_views,
329331
feature_services,
330332
on_demand_feature_views,
331333
saved_datasets,
@@ -845,18 +847,6 @@ def process_project(project: Project):
845847
project_name = project.name
846848
last_updated_timestamp = project.last_updated_timestamp
847849

848-
try:
849-
cached_project = self.get_project(project_name, True)
850-
except ProjectObjectNotFoundException:
851-
cached_project = None
852-
853-
allow_cache = False
854-
855-
if cached_project is not None:
856-
allow_cache = (
857-
last_updated_timestamp <= cached_project.last_updated_timestamp
858-
)
859-
860850
r.projects.extend([project.to_proto()])
861851
last_updated_timestamps.append(last_updated_timestamp)
862852

@@ -871,7 +861,7 @@ def process_project(project: Project):
871861
(self.list_validation_references, r.validation_references),
872862
(self.list_permissions, r.permissions),
873863
]:
874-
objs: List[Any] = lister(project_name, allow_cache) # type: ignore
864+
objs: List[Any] = lister(project_name, allow_cache=False) # type: ignore
875865
if objs:
876866
obj_protos = [obj.to_proto() for obj in objs]
877867
for obj_proto in obj_protos:
@@ -1020,6 +1010,9 @@ def _apply_object(
10201010
if not self.purge_feast_metadata:
10211011
self._set_last_updated_metadata(update_datetime, project, conn)
10221012

1013+
if self.cache_mode == "sync":
1014+
self.refresh()
1015+
10231016
def _maybe_init_project_metadata(self, project):
10241017
# Initialize project metadata if needed
10251018
with self.write_engine.begin() as conn:
@@ -1062,6 +1055,8 @@ def _delete_object(
10621055
if not self.purge_feast_metadata:
10631056
self._set_last_updated_metadata(_utc_now(), project, conn)
10641057

1058+
if self.cache_mode == "sync":
1059+
self.refresh()
10651060
return rows.rowcount
10661061

10671062
def _get_object(

sdk/python/tests/integration/registration/test_universal_registry.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,8 @@ def sqlite_registry():
273273
registry_config = SqlRegistryConfig(
274274
registry_type="sql",
275275
path="sqlite://",
276+
cache_ttl_seconds=2,
277+
cache_mode="sync",
276278
)
277279

278280
yield SqlRegistry(registry_config, "project", None)
@@ -1156,11 +1158,10 @@ def test_registry_cache(test_registry):
11561158
registry_data_sources_cached = test_registry.list_data_sources(
11571159
project, allow_cache=True
11581160
)
1159-
# Not refreshed cache, so cache miss
1160-
assert len(registry_feature_views_cached) == 0
1161-
assert len(registry_data_sources_cached) == 0
1161+
assert len(registry_feature_views_cached) == 1
1162+
assert len(registry_data_sources_cached) == 1
1163+
11621164
test_registry.refresh(project)
1163-
# Now objects exist
11641165
registry_feature_views_cached = test_registry.list_feature_views(
11651166
project, allow_cache=True, tags=fv1.tags
11661167
)

sdk/python/tests/unit/infra/registry/test_registry.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
from feast.infra.registry.caching_registry import CachingRegistry
7+
from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
78

89

910
class TestCachingRegistry(CachingRegistry):
@@ -188,6 +189,79 @@ def test_cache_expiry_triggers_refresh(registry):
188189
mock_refresh.assert_called_once()
189190

190191

192+
def test_empty_cache_refresh_with_ttl(registry):
193+
"""Test that empty cache is refreshed when TTL > 0"""
194+
# Set up empty cache with TTL > 0
195+
registry.cached_registry_proto = RegistryProto()
196+
registry.cached_registry_proto_created = datetime.now(timezone.utc)
197+
registry.cached_registry_proto_ttl = timedelta(seconds=10) # TTL > 0
198+
199+
# Mock refresh to check if it's called
200+
with patch.object(
201+
CachingRegistry, "refresh", wraps=registry.refresh
202+
) as mock_refresh:
203+
registry._refresh_cached_registry_if_necessary()
204+
# Should refresh because cache is empty and TTL > 0
205+
mock_refresh.assert_called_once()
206+
207+
208+
def test_empty_cache_no_refresh_with_infinite_ttl(registry):
209+
"""Test that empty cache is not refreshed when TTL = 0 (infinite)"""
210+
# Set up empty cache with TTL = 0 (infinite)
211+
registry.cached_registry_proto = RegistryProto()
212+
registry.cached_registry_proto_created = datetime.now(timezone.utc)
213+
registry.cached_registry_proto_ttl = timedelta(seconds=0) # TTL = 0 (infinite)
214+
215+
# Mock refresh to check if it's called
216+
with patch.object(
217+
CachingRegistry, "refresh", wraps=registry.refresh
218+
) as mock_refresh:
219+
registry._refresh_cached_registry_if_necessary()
220+
# Should not refresh because TTL = 0 (infinite)
221+
mock_refresh.assert_not_called()
222+
223+
224+
def test_concurrent_cache_refresh_race_condition(registry):
225+
"""Test that concurrent requests don't skip cache refresh when cache is expired"""
226+
import threading
227+
import time
228+
229+
# Set up expired cache
230+
registry.cached_registry_proto = RegistryProto()
231+
registry.cached_registry_proto_created = datetime.now(timezone.utc) - timedelta(
232+
seconds=5
233+
)
234+
registry.cached_registry_proto_ttl = timedelta(
235+
seconds=2
236+
) # TTL = 2 seconds, cache is expired
237+
238+
refresh_calls = []
239+
240+
def mock_refresh():
241+
refresh_calls.append(threading.current_thread().ident)
242+
time.sleep(0.1) # Simulate refresh work
243+
244+
# Mock the refresh method to track calls
245+
with patch.object(registry, "refresh", side_effect=mock_refresh):
246+
# Simulate concurrent requests
247+
threads = []
248+
for i in range(3):
249+
thread = threading.Thread(
250+
target=registry._refresh_cached_registry_if_necessary
251+
)
252+
threads.append(thread)
253+
thread.start()
254+
255+
# Wait for all threads to complete
256+
for thread in threads:
257+
thread.join()
258+
259+
# At least one thread should have called refresh (the first one to acquire the lock)
260+
assert len(refresh_calls) >= 1, (
261+
"At least one thread should have refreshed the cache"
262+
)
263+
264+
191265
def test_skip_refresh_if_lock_held(registry):
192266
"""Test that refresh is skipped if the lock is already held by another thread"""
193267
registry.cached_registry_proto = "some_cached_data"

0 commit comments

Comments
 (0)