Skip to content

Commit 3c7a022

Browse files
authored
fix: Fixed Registry Cache Refresh Issues (#5604)
* fix: Fixed cache refresh
  Signed-off-by: ntkathole <[email protected]>
* fix: Fixed registry cache init
  Signed-off-by: ntkathole <[email protected]>
---------
Signed-off-by: ntkathole <[email protected]>
1 parent 1d08786 commit 3c7a022

File tree

6 files changed: +147 additions, -74 deletions

sdk/python/feast/api/registry/rest/metrics.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import json
2-
import logging
32
from typing import Dict, List, Optional
43

54
from fastapi import APIRouter, Depends, HTTPException, Query, Request
@@ -57,7 +56,6 @@ class PopularTagsResponse(BaseModel):
5756

5857

5958
def get_metrics_router(grpc_handler, server=None) -> APIRouter:
60-
logger = logging.getLogger(__name__)
6159
router = APIRouter()
6260

6361
@router.get("/metrics/resource_counts", tags=["Metrics"])
@@ -321,20 +319,43 @@ async def recently_visited(
321319
user = getattr(request.state, "user", None)
322320
if not user:
323321
user = "anonymous"
324-
project_val = project or (server.store.project if server else None)
325322
key = f"recently_visited_{user}"
326-
logger.info(
327-
f"[/metrics/recently_visited] Project: {project_val}, Key: {key}, Object: {object_type}"
328-
)
329-
try:
330-
visits_json = (
331-
server.registry.get_project_metadata(project_val, key)
332-
if server
333-
else None
334-
)
335-
visits = json.loads(visits_json) if visits_json else []
336-
except Exception:
337-
visits = []
323+
visits = []
324+
if project:
325+
try:
326+
visits_json = (
327+
server.registry.get_project_metadata(project, key)
328+
if server
329+
else None
330+
)
331+
visits = json.loads(visits_json) if visits_json else []
332+
except Exception:
333+
visits = []
334+
else:
335+
try:
336+
if server:
337+
projects_resp = grpc_call(
338+
grpc_handler.ListProjects,
339+
RegistryServer_pb2.ListProjectsRequest(allow_cache=True),
340+
)
341+
all_projects = [
342+
p["spec"]["name"] for p in projects_resp.get("projects", [])
343+
]
344+
for project_name in all_projects:
345+
try:
346+
visits_json = server.registry.get_project_metadata(
347+
project_name, key
348+
)
349+
if visits_json:
350+
project_visits = json.loads(visits_json)
351+
visits.extend(project_visits)
352+
except Exception:
353+
continue
354+
visits = sorted(
355+
visits, key=lambda x: x.get("timestamp", ""), reverse=True
356+
)
357+
except Exception:
358+
visits = []
338359
if object_type:
339360
visits = [v for v in visits if v.get("object") == object_type]
340361

sdk/python/feast/infra/compute_engines/spark/utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ def get_or_create_new_spark_session(
2121
conf=SparkConf().setAll([(k, v) for k, v in spark_config.items()])
2222
)
2323

24-
spark_builder = spark_builder.config("spark.driver.host", "127.0.0.1")
25-
spark_builder = spark_builder.config("spark.driver.bindAddress", "127.0.0.1")
26-
2724
spark_session = spark_builder.getOrCreate()
2825
spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
2926
return spark_session

sdk/python/feast/infra/registry/caching_registry.py

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -425,9 +425,6 @@ def list_projects(
425425
return self._list_projects(tags)
426426

427427
def refresh(self, project: Optional[str] = None):
428-
if self._refresh_lock.locked():
429-
logger.debug("Skipping refresh if already in progress")
430-
return
431428
try:
432429
self.cached_registry_proto = self.proto()
433430
self.cached_registry_proto_created = _utc_now()
@@ -436,43 +433,47 @@ def refresh(self, project: Optional[str] = None):
436433

437434
def _refresh_cached_registry_if_necessary(self):
438435
if self.cache_mode == "sync":
439-
# Try acquiring the lock without blocking
440-
if not self._refresh_lock.acquire(blocking=False):
441-
logger.debug(
442-
"Skipping refresh if lock is already held by another thread"
443-
)
444-
return
445-
try:
446-
if self.cached_registry_proto == RegistryProto():
447-
# Avoids the need to refresh the registry when cache is not populated yet
448-
# Specially during the __init__ phase
449-
# proto() will populate the cache with project metadata if no objects are registered
450-
expired = False
451-
else:
452-
expired = (
453-
self.cached_registry_proto is None
454-
or self.cached_registry_proto_created is None
455-
) or (
456-
self.cached_registry_proto_ttl.total_seconds()
457-
> 0 # 0 ttl means infinity
458-
and (
459-
_utc_now()
460-
> (
461-
self.cached_registry_proto_created
462-
+ self.cached_registry_proto_ttl
463-
)
464-
)
436+
437+
def is_cache_expired():
438+
if (
439+
self.cached_registry_proto is None
440+
or self.cached_registry_proto == RegistryProto()
441+
):
442+
return True
443+
444+
# Cache is expired if creation time is None
445+
if (
446+
not hasattr(self, "cached_registry_proto_created")
447+
or self.cached_registry_proto_created is None
448+
):
449+
return True
450+
451+
# Cache is expired if TTL > 0 and current time exceeds creation + TTL
452+
if self.cached_registry_proto_ttl.total_seconds() > 0 and _utc_now() > (
453+
self.cached_registry_proto_created + self.cached_registry_proto_ttl
454+
):
455+
return True
456+
457+
return False
458+
459+
if is_cache_expired():
460+
if not self._refresh_lock.acquire(blocking=False):
461+
logger.debug(
462+
"Skipping refresh if lock is already held by another thread"
463+
)
464+
return
465+
try:
466+
logger.info(
467+
f"Registry cache expired(ttl: {self.cached_registry_proto_ttl.total_seconds()} seconds), so refreshing"
465468
)
466-
if expired:
467-
logger.debug("Registry cache expired, so refreshing")
468469
self.refresh()
469-
except Exception as e:
470-
logger.debug(
471-
f"Error in _refresh_cached_registry_if_necessary: {e}",
472-
exc_info=True,
473-
)
474-
finally:
475-
self._refresh_lock.release() # Always release the lock safely
470+
except Exception as e:
471+
logger.debug(
472+
f"Error in _refresh_cached_registry_if_necessary: {e}",
473+
exc_info=True,
474+
)
475+
finally:
476+
self._refresh_lock.release()
476477

477478
def _start_thread_async_refresh(self, cache_ttl_seconds):
478479
self.refresh()

sdk/python/feast/infra/registry/sql.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
MetaData,
1616
String,
1717
Table,
18+
Text,
1819
create_engine,
1920
delete,
2021
insert,
@@ -209,7 +210,7 @@ class FeastMetadataKeys(Enum):
209210
metadata,
210211
Column("project_id", String(255), primary_key=True),
211212
Column("metadata_key", String(50), primary_key=True),
212-
Column("metadata_value", String(50), nullable=False),
213+
Column("metadata_value", Text, nullable=False),
213214
Column("last_updated_timestamp", BigInteger, nullable=False),
214215
)
215216

@@ -326,6 +327,7 @@ def teardown(self):
326327
entities,
327328
data_sources,
328329
feature_views,
330+
stream_feature_views,
329331
feature_services,
330332
on_demand_feature_views,
331333
saved_datasets,
@@ -845,18 +847,6 @@ def process_project(project: Project):
845847
project_name = project.name
846848
last_updated_timestamp = project.last_updated_timestamp
847849

848-
try:
849-
cached_project = self.get_project(project_name, True)
850-
except ProjectObjectNotFoundException:
851-
cached_project = None
852-
853-
allow_cache = False
854-
855-
if cached_project is not None:
856-
allow_cache = (
857-
last_updated_timestamp <= cached_project.last_updated_timestamp
858-
)
859-
860850
r.projects.extend([project.to_proto()])
861851
last_updated_timestamps.append(last_updated_timestamp)
862852

@@ -871,7 +861,7 @@ def process_project(project: Project):
871861
(self.list_validation_references, r.validation_references),
872862
(self.list_permissions, r.permissions),
873863
]:
874-
objs: List[Any] = lister(project_name, allow_cache) # type: ignore
864+
objs: List[Any] = lister(project_name, allow_cache=False) # type: ignore
875865
if objs:
876866
obj_protos = [obj.to_proto() for obj in objs]
877867
for obj_proto in obj_protos:
@@ -1020,6 +1010,9 @@ def _apply_object(
10201010
if not self.purge_feast_metadata:
10211011
self._set_last_updated_metadata(update_datetime, project, conn)
10221012

1013+
if self.cache_mode == "sync":
1014+
self.refresh()
1015+
10231016
def _maybe_init_project_metadata(self, project):
10241017
# Initialize project metadata if needed
10251018
with self.write_engine.begin() as conn:
@@ -1062,6 +1055,8 @@ def _delete_object(
10621055
if not self.purge_feast_metadata:
10631056
self._set_last_updated_metadata(_utc_now(), project, conn)
10641057

1058+
if self.cache_mode == "sync":
1059+
self.refresh()
10651060
return rows.rowcount
10661061

10671062
def _get_object(

sdk/python/tests/integration/registration/test_universal_registry.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,8 @@ def sqlite_registry():
273273
registry_config = SqlRegistryConfig(
274274
registry_type="sql",
275275
path="sqlite://",
276+
cache_ttl_seconds=2,
277+
cache_mode="sync",
276278
)
277279

278280
yield SqlRegistry(registry_config, "project", None)
@@ -1156,11 +1158,10 @@ def test_registry_cache(test_registry):
11561158
registry_data_sources_cached = test_registry.list_data_sources(
11571159
project, allow_cache=True
11581160
)
1159-
# Not refreshed cache, so cache miss
1160-
assert len(registry_feature_views_cached) == 0
1161-
assert len(registry_data_sources_cached) == 0
1161+
assert len(registry_feature_views_cached) == 1
1162+
assert len(registry_data_sources_cached) == 1
1163+
11621164
test_registry.refresh(project)
1163-
# Now objects exist
11641165
registry_feature_views_cached = test_registry.list_feature_views(
11651166
project, allow_cache=True, tags=fv1.tags
11661167
)

sdk/python/tests/unit/infra/registry/test_registry.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
from feast.infra.registry.caching_registry import CachingRegistry
7+
from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
78

89

910
class TestCachingRegistry(CachingRegistry):
@@ -188,6 +189,63 @@ def test_cache_expiry_triggers_refresh(registry):
188189
mock_refresh.assert_called_once()
189190

190191

192+
def test_empty_cache_refresh_with_ttl(registry):
193+
"""Test that empty cache is refreshed when TTL > 0"""
194+
# Set up empty cache with TTL > 0
195+
registry.cached_registry_proto = RegistryProto()
196+
registry.cached_registry_proto_created = datetime.now(timezone.utc)
197+
registry.cached_registry_proto_ttl = timedelta(seconds=10) # TTL > 0
198+
199+
# Mock refresh to check if it's called
200+
with patch.object(
201+
CachingRegistry, "refresh", wraps=registry.refresh
202+
) as mock_refresh:
203+
registry._refresh_cached_registry_if_necessary()
204+
# Should refresh because cache is empty and TTL > 0
205+
mock_refresh.assert_called_once()
206+
207+
208+
def test_concurrent_cache_refresh_race_condition(registry):
209+
"""Test that concurrent requests don't skip cache refresh when cache is expired"""
210+
import threading
211+
import time
212+
213+
# Set up expired cache
214+
registry.cached_registry_proto = RegistryProto()
215+
registry.cached_registry_proto_created = datetime.now(timezone.utc) - timedelta(
216+
seconds=5
217+
)
218+
registry.cached_registry_proto_ttl = timedelta(
219+
seconds=2
220+
) # TTL = 2 seconds, cache is expired
221+
222+
refresh_calls = []
223+
224+
def mock_refresh():
225+
refresh_calls.append(threading.current_thread().ident)
226+
time.sleep(0.1) # Simulate refresh work
227+
228+
# Mock the refresh method to track calls
229+
with patch.object(registry, "refresh", side_effect=mock_refresh):
230+
# Simulate concurrent requests
231+
threads = []
232+
for i in range(3):
233+
thread = threading.Thread(
234+
target=registry._refresh_cached_registry_if_necessary
235+
)
236+
threads.append(thread)
237+
thread.start()
238+
239+
# Wait for all threads to complete
240+
for thread in threads:
241+
thread.join()
242+
243+
# At least one thread should have called refresh (the first one to acquire the lock)
244+
assert len(refresh_calls) >= 1, (
245+
"At least one thread should have refreshed the cache"
246+
)
247+
248+
191249
def test_skip_refresh_if_lock_held(registry):
192250
"""Test that refresh is skipped if the lock is already held by another thread"""
193251
registry.cached_registry_proto = "some_cached_data"

0 commit comments

Comments
 (0)