51 changes: 36 additions & 15 deletions sdk/python/feast/api/registry/rest/metrics.py
@@ -1,5 +1,4 @@
 import json
-import logging
 from typing import Dict, List, Optional

 from fastapi import APIRouter, Depends, HTTPException, Query, Request
@@ -57,7 +56,6 @@ class PopularTagsResponse(BaseModel):


 def get_metrics_router(grpc_handler, server=None) -> APIRouter:
-    logger = logging.getLogger(__name__)
     router = APIRouter()

     @router.get("/metrics/resource_counts", tags=["Metrics"])
@@ -321,20 +319,43 @@ async def recently_visited(
         user = getattr(request.state, "user", None)
         if not user:
             user = "anonymous"
-        project_val = project or (server.store.project if server else None)
         key = f"recently_visited_{user}"
-        logger.info(
-            f"[/metrics/recently_visited] Project: {project_val}, Key: {key}, Object: {object_type}"
-        )
-        try:
-            visits_json = (
-                server.registry.get_project_metadata(project_val, key)
-                if server
-                else None
-            )
-            visits = json.loads(visits_json) if visits_json else []
-        except Exception:
-            visits = []
+        visits = []
+        if project:
+            try:
+                visits_json = (
+                    server.registry.get_project_metadata(project, key)
+                    if server
+                    else None
+                )
+                visits = json.loads(visits_json) if visits_json else []
+            except Exception:
+                visits = []
+        else:
+            try:
+                if server:
+                    projects_resp = grpc_call(
+                        grpc_handler.ListProjects,
+                        RegistryServer_pb2.ListProjectsRequest(allow_cache=True),
+                    )
+                    all_projects = [
+                        p["spec"]["name"] for p in projects_resp.get("projects", [])
+                    ]
+                    for project_name in all_projects:
+                        try:
+                            visits_json = server.registry.get_project_metadata(
+                                project_name, key
+                            )
+                            if visits_json:
+                                project_visits = json.loads(visits_json)
+                                visits.extend(project_visits)
+                        except Exception:
+                            continue
+                    visits = sorted(
+                        visits, key=lambda x: x.get("timestamp", ""), reverse=True
+                    )
+            except Exception:
+                visits = []
         if object_type:
             visits = [v for v in visits if v.get("object") == object_type]

Author's review note on this hunk: the earlier logic did not consider all projects when showing metrics. Fixed by listing the projects first and then getting the entries for every project.
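A minimal, self-contained sketch of that fan-out-and-merge pattern in isolation. Here merge_recent_visits and load_visits are hypothetical names: load_visits stands in for server.registry.get_project_metadata(project_name, key), which returns a JSON-encoded list or None.

import json
from typing import Callable, List, Optional

def merge_recent_visits(
    project_names: List[str],
    load_visits: Callable[[str], Optional[str]],
) -> List[dict]:
    # Collect each project's visit list; a per-project try/except ensures
    # one unreadable project does not break the whole feed.
    visits: List[dict] = []
    for name in project_names:
        try:
            raw = load_visits(name)
            if raw:
                visits.extend(json.loads(raw))
        except Exception:
            continue
    # Newest first; entries without a timestamp sort last.
    return sorted(visits, key=lambda v: v.get("timestamp", ""), reverse=True)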
3 changes: 0 additions & 3 deletions sdk/python/feast/infra/compute_engines/spark/utils.py
@@ -21,9 +21,6 @@ def get_or_create_new_spark_session(
         conf=SparkConf().setAll([(k, v) for k, v in spark_config.items()])
     )

-    spark_builder = spark_builder.config("spark.driver.host", "127.0.0.1")
-    spark_builder = spark_builder.config("spark.driver.bindAddress", "127.0.0.1")
-
    spark_session = spark_builder.getOrCreate()
    spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    return spark_session
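The helper no longer forces the Spark driver onto loopback. A hedged sketch, assuming get_or_create_new_spark_session accepts the spark_config dict applied via SparkConf in the hunk above: callers who relied on the old pinning can pass the same settings explicitly.

from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session

spark_config = {
    "spark.master": "local[*]",
    # Formerly hard-coded inside the helper; now opt-in for the caller.
    "spark.driver.host": "127.0.0.1",
    "spark.driver.bindAddress": "127.0.0.1",
}
session = get_or_create_new_spark_session(spark_config)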
77 changes: 39 additions & 38 deletions sdk/python/feast/infra/registry/caching_registry.py
@@ -425,9 +425,6 @@ def list_projects(
         return self._list_projects(tags)

     def refresh(self, project: Optional[str] = None):
-        if self._refresh_lock.locked():
-            logger.debug("Skipping refresh if already in progress")
-            return
         try:
             self.cached_registry_proto = self.proto()
             self.cached_registry_proto_created = _utc_now()
@@ -436,43 +433,47 @@

     def _refresh_cached_registry_if_necessary(self):
         if self.cache_mode == "sync":
-            # Try acquiring the lock without blocking
-            if not self._refresh_lock.acquire(blocking=False):
-                logger.debug(
-                    "Skipping refresh if lock is already held by another thread"
-                )
-                return
-            try:
-                if self.cached_registry_proto == RegistryProto():
-                    # Avoids the need to refresh the registry when cache is not populated yet
-                    # Specially during the __init__ phase
-                    # proto() will populate the cache with project metadata if no objects are registered
-                    expired = False
-                else:
-                    expired = (
-                        self.cached_registry_proto is None
-                        or self.cached_registry_proto_created is None
-                    ) or (
-                        self.cached_registry_proto_ttl.total_seconds()
-                        > 0  # 0 ttl means infinity
-                        and (
-                            _utc_now()
-                            > (
-                                self.cached_registry_proto_created
-                                + self.cached_registry_proto_ttl
-                            )
-                        )
-                    )

+            def is_cache_expired():
+                if (
+                    self.cached_registry_proto is None
+                    or self.cached_registry_proto == RegistryProto()
+                ):
+                    return True
+
+                # Cache is expired if creation time is None
+                if (
+                    not hasattr(self, "cached_registry_proto_created")
+                    or self.cached_registry_proto_created is None
+                ):
+                    return True
+
+                # Cache is expired if TTL > 0 and current time exceeds creation + TTL
+                if self.cached_registry_proto_ttl.total_seconds() > 0 and _utc_now() > (
+                    self.cached_registry_proto_created + self.cached_registry_proto_ttl
+                ):
+                    return True
+
+                return False
+
+            if is_cache_expired():
+                if not self._refresh_lock.acquire(blocking=False):
+                    logger.debug(
+                        "Skipping refresh if lock is already held by another thread"
+                    )
+                    return
+                try:
+                    logger.info(
+                        f"Registry cache expired(ttl: {self.cached_registry_proto_ttl.total_seconds()} seconds), so refreshing"
+                    )
-                if expired:
-                    logger.debug("Registry cache expired, so refreshing")
                     self.refresh()
-            except Exception as e:
-                logger.debug(
-                    f"Error in _refresh_cached_registry_if_necessary: {e}",
-                    exc_info=True,
-                )
-            finally:
-                self._refresh_lock.release()  # Always release the lock safely
+                except Exception as e:
+                    logger.debug(
+                        f"Error in _refresh_cached_registry_if_necessary: {e}",
+                        exc_info=True,
+                    )
+                finally:
+                    self._refresh_lock.release()

     def _start_thread_async_refresh(self, cache_ttl_seconds):
         self.refresh()
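The TTL rule that the new is_cache_expired() encodes is easy to state in isolation: a TTL of 0 means the cache never expires; otherwise it is stale once the current time passes created + ttl. A standalone sketch of just that predicate (ttl_expired is a hypothetical helper, not part of the registry API):

from datetime import datetime, timedelta, timezone

def ttl_expired(created: datetime, ttl: timedelta, now: datetime) -> bool:
    # 0 ttl means infinity, mirroring the registry's convention.
    return ttl.total_seconds() > 0 and now > created + ttl

now = datetime.now(timezone.utc)
assert not ttl_expired(now - timedelta(hours=1), timedelta(0), now)         # never expires
assert ttl_expired(now - timedelta(seconds=3), timedelta(seconds=2), now)   # stale
assert not ttl_expired(now, timedelta(seconds=60), now)                     # still fresh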
23 changes: 9 additions & 14 deletions sdk/python/feast/infra/registry/sql.py
@@ -15,6 +15,7 @@
     MetaData,
     String,
     Table,
+    Text,
     create_engine,
     delete,
     insert,
@@ -209,7 +210,7 @@ class FeastMetadataKeys(Enum):
     metadata,
     Column("project_id", String(255), primary_key=True),
     Column("metadata_key", String(50), primary_key=True),
-    Column("metadata_value", String(50), nullable=False),
+    Column("metadata_value", Text, nullable=False),
     Column("last_updated_timestamp", BigInteger, nullable=False),
 )

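The type change matters because metadata_value now stores JSON blobs (for example, the recently-visited lists above), which easily exceed 50 characters: String(50) maps to VARCHAR(50), while Text maps to an unbounded or very large type on most dialects. A minimal SQLAlchemy illustration, independent of Feast's actual schema (note SQLite itself does not enforce VARCHAR lengths, but stricter backends such as MySQL do):

from sqlalchemy import Column, MetaData, String, Table, Text, create_engine, insert

metadata = MetaData()
demo = Table(
    "demo_metadata",
    metadata,
    Column("metadata_key", String(50), primary_key=True),
    Column("metadata_value", Text, nullable=False),  # was String(50)
)
engine = create_engine("sqlite://")
metadata.create_all(engine)
with engine.begin() as conn:
    # A JSON payload far longer than 50 characters now fits.
    conn.execute(
        insert(demo).values(
            metadata_key="recently_visited_anonymous",
            metadata_value='[{"path": "/p/feature_view/fv1", "timestamp": "2024-01-01T00:00:00"}]',
        )
    )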
@@ -326,6 +327,7 @@ def teardown(self):
            entities,
            data_sources,
            feature_views,
+            stream_feature_views,
            feature_services,
            on_demand_feature_views,
            saved_datasets,
@@ -845,18 +847,6 @@ def process_project(project: Project):
             project_name = project.name
             last_updated_timestamp = project.last_updated_timestamp

-            try:
-                cached_project = self.get_project(project_name, True)
-            except ProjectObjectNotFoundException:
-                cached_project = None
-
-            allow_cache = False
-
-            if cached_project is not None:
-                allow_cache = (
-                    last_updated_timestamp <= cached_project.last_updated_timestamp
-                )
-
             r.projects.extend([project.to_proto()])
             last_updated_timestamps.append(last_updated_timestamp)
@@ -871,7 +861,7 @@
             (self.list_validation_references, r.validation_references),
             (self.list_permissions, r.permissions),
         ]:
-            objs: List[Any] = lister(project_name, allow_cache)  # type: ignore
+            objs: List[Any] = lister(project_name, allow_cache=False)  # type: ignore
             if objs:
                 obj_protos = [obj.to_proto() for obj in objs]
                 for obj_proto in obj_protos:
@@ -1020,6 +1010,9 @@ def _apply_object(
             if not self.purge_feast_metadata:
                 self._set_last_updated_metadata(update_datetime, project, conn)

+        if self.cache_mode == "sync":
+            self.refresh()
+
     def _maybe_init_project_metadata(self, project):
         # Initialize project metadata if needed
         with self.write_engine.begin() as conn:
@@ -1062,6 +1055,8 @@ def _delete_object(
             if not self.purge_feast_metadata:
                 self._set_last_updated_metadata(_utc_now(), project, conn)

+        if self.cache_mode == "sync":
+            self.refresh()
         return rows.rowcount

     def _get_object(
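With these two hooks, a write through _apply_object or _delete_object refreshes the cache inline whenever cache_mode is "sync", so a read that immediately follows a write sees the change without an explicit refresh(). The test fixture below exercises exactly this configuration; for reference, the same settings stated on their own (a sketch assuming SqlRegistry and SqlRegistryConfig are importable from feast.infra.registry.sql, as the test module uses them):

from feast.infra.registry.sql import SqlRegistry, SqlRegistryConfig

registry_config = SqlRegistryConfig(
    registry_type="sql",
    path="sqlite://",         # in-memory SQLite, as in the tests
    cache_ttl_seconds=2,
    cache_mode="sync",        # apply/delete now refresh the cache inline
)
registry = SqlRegistry(registry_config, "project", None)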
@@ -273,6 +273,8 @@ def sqlite_registry():
     registry_config = SqlRegistryConfig(
         registry_type="sql",
         path="sqlite://",
+        cache_ttl_seconds=2,
+        cache_mode="sync",
     )

     yield SqlRegistry(registry_config, "project", None)
@@ -1156,11 +1158,10 @@ def test_registry_cache(test_registry):
     registry_data_sources_cached = test_registry.list_data_sources(
         project, allow_cache=True
     )
-    # Not refreshed cache, so cache miss
-    assert len(registry_feature_views_cached) == 0
-    assert len(registry_data_sources_cached) == 0
+    assert len(registry_feature_views_cached) == 1
+    assert len(registry_data_sources_cached) == 1

-    test_registry.refresh(project)
-    # Now objects exist
     registry_feature_views_cached = test_registry.list_feature_views(
         project, allow_cache=True, tags=fv1.tags
     )
58 changes: 58 additions & 0 deletions sdk/python/tests/unit/infra/registry/test_registry.py
@@ -4,6 +4,7 @@
 import pytest

 from feast.infra.registry.caching_registry import CachingRegistry
+from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto


 class TestCachingRegistry(CachingRegistry):
@@ -188,6 +189,63 @@ def test_cache_expiry_triggers_refresh(registry):
     mock_refresh.assert_called_once()


+def test_empty_cache_refresh_with_ttl(registry):
+    """Test that empty cache is refreshed when TTL > 0"""
+    # Set up empty cache with TTL > 0
+    registry.cached_registry_proto = RegistryProto()
+    registry.cached_registry_proto_created = datetime.now(timezone.utc)
+    registry.cached_registry_proto_ttl = timedelta(seconds=10)  # TTL > 0
+
+    # Mock refresh to check if it's called
+    with patch.object(
+        CachingRegistry, "refresh", wraps=registry.refresh
+    ) as mock_refresh:
+        registry._refresh_cached_registry_if_necessary()
+        # Should refresh because cache is empty and TTL > 0
+        mock_refresh.assert_called_once()
+
+
+def test_concurrent_cache_refresh_race_condition(registry):
+    """Test that concurrent requests don't skip cache refresh when cache is expired"""
+    import threading
+    import time
+
+    # Set up expired cache
+    registry.cached_registry_proto = RegistryProto()
+    registry.cached_registry_proto_created = datetime.now(timezone.utc) - timedelta(
+        seconds=5
+    )
+    registry.cached_registry_proto_ttl = timedelta(
+        seconds=2
+    )  # TTL = 2 seconds, cache is expired
+
+    refresh_calls = []
+
+    def mock_refresh():
+        refresh_calls.append(threading.current_thread().ident)
+        time.sleep(0.1)  # Simulate refresh work
+
+    # Mock the refresh method to track calls
+    with patch.object(registry, "refresh", side_effect=mock_refresh):
+        # Simulate concurrent requests
+        threads = []
+        for i in range(3):
+            thread = threading.Thread(
+                target=registry._refresh_cached_registry_if_necessary
+            )
+            threads.append(thread)
+            thread.start()
+
+        # Wait for all threads to complete
+        for thread in threads:
+            thread.join()
+
+    # At least one thread should have called refresh (the first one to acquire the lock)
+    assert len(refresh_calls) >= 1, (
+        "At least one thread should have refreshed the cache"
+    )
+
+
 def test_skip_refresh_if_lock_held(registry):
     """Test that refresh is skipped if the lock is already held by another thread"""
     registry.cached_registry_proto = "some_cached_data"
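These additions can be run in isolation with a standard pytest filter (an assumed invocation, matching the test names above): pytest sdk/python/tests/unit/infra/registry/test_registry.py -k cache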