|
4 | 4 | import pytest
|
5 | 5 |
|
6 | 6 | from feast.infra.registry.caching_registry import CachingRegistry
|
| 7 | +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto |
7 | 8 |
|
8 | 9 |
|
9 | 10 | class TestCachingRegistry(CachingRegistry):
|
@@ -188,6 +189,79 @@ def test_cache_expiry_triggers_refresh(registry):
|
188 | 189 | mock_refresh.assert_called_once()
|
189 | 190 |
|
190 | 191 |
|
| 192 | +def test_empty_cache_refresh_with_ttl(registry): |
| 193 | + """Test that empty cache is refreshed when TTL > 0""" |
| 194 | + # Set up empty cache with TTL > 0 |
| 195 | + registry.cached_registry_proto = RegistryProto() |
| 196 | + registry.cached_registry_proto_created = datetime.now(timezone.utc) |
| 197 | + registry.cached_registry_proto_ttl = timedelta(seconds=10) # TTL > 0 |
| 198 | + |
| 199 | + # Mock refresh to check if it's called |
| 200 | + with patch.object( |
| 201 | + CachingRegistry, "refresh", wraps=registry.refresh |
| 202 | + ) as mock_refresh: |
| 203 | + registry._refresh_cached_registry_if_necessary() |
| 204 | + # Should refresh because cache is empty and TTL > 0 |
| 205 | + mock_refresh.assert_called_once() |
| 206 | + |
| 207 | + |
| 208 | +def test_empty_cache_no_refresh_with_infinite_ttl(registry): |
| 209 | + """Test that empty cache is not refreshed when TTL = 0 (infinite)""" |
| 210 | + # Set up empty cache with TTL = 0 (infinite) |
| 211 | + registry.cached_registry_proto = RegistryProto() |
| 212 | + registry.cached_registry_proto_created = datetime.now(timezone.utc) |
| 213 | + registry.cached_registry_proto_ttl = timedelta(seconds=0) # TTL = 0 (infinite) |
| 214 | + |
| 215 | + # Mock refresh to check if it's called |
| 216 | + with patch.object( |
| 217 | + CachingRegistry, "refresh", wraps=registry.refresh |
| 218 | + ) as mock_refresh: |
| 219 | + registry._refresh_cached_registry_if_necessary() |
| 220 | + # Should not refresh because TTL = 0 (infinite) |
| 221 | + mock_refresh.assert_not_called() |
| 222 | + |
| 223 | + |
| 224 | +def test_concurrent_cache_refresh_race_condition(registry): |
| 225 | + """Test that concurrent requests don't skip cache refresh when cache is expired""" |
| 226 | + import threading |
| 227 | + import time |
| 228 | + |
| 229 | + # Set up expired cache |
| 230 | + registry.cached_registry_proto = RegistryProto() |
| 231 | + registry.cached_registry_proto_created = datetime.now(timezone.utc) - timedelta( |
| 232 | + seconds=5 |
| 233 | + ) |
| 234 | + registry.cached_registry_proto_ttl = timedelta( |
| 235 | + seconds=2 |
| 236 | + ) # TTL = 2 seconds, cache is expired |
| 237 | + |
| 238 | + refresh_calls = [] |
| 239 | + |
| 240 | + def mock_refresh(): |
| 241 | + refresh_calls.append(threading.current_thread().ident) |
| 242 | + time.sleep(0.1) # Simulate refresh work |
| 243 | + |
| 244 | + # Mock the refresh method to track calls |
| 245 | + with patch.object(registry, "refresh", side_effect=mock_refresh): |
| 246 | + # Simulate concurrent requests |
| 247 | + threads = [] |
| 248 | + for i in range(3): |
| 249 | + thread = threading.Thread( |
| 250 | + target=registry._refresh_cached_registry_if_necessary |
| 251 | + ) |
| 252 | + threads.append(thread) |
| 253 | + thread.start() |
| 254 | + |
| 255 | + # Wait for all threads to complete |
| 256 | + for thread in threads: |
| 257 | + thread.join() |
| 258 | + |
| 259 | + # At least one thread should have called refresh (the first one to acquire the lock) |
| 260 | + assert len(refresh_calls) >= 1, ( |
| 261 | + "At least one thread should have refreshed the cache" |
| 262 | + ) |
| 263 | + |
| 264 | + |
191 | 265 | def test_skip_refresh_if_lock_held(registry):
|
192 | 266 | """Test that refresh is skipped if the lock is already held by another thread"""
|
193 | 267 | registry.cached_registry_proto = "some_cached_data"
|
|
0 commit comments