Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 690cb4f

Browse files
authored
Allow for ignoring some arguments when caching. (#12189)
* `@cached` can now take an `uncached_args` argument, which is an iterable of argument names to not use in the cache key. * Requires `@cached`, `@cachedList` and `@lru_cache` to use keyword arguments for clarity. * Asserts that keyword-only arguments in cached functions are not accepted. (I tested this briefly and I don't believe keyword-only arguments work properly with the cache.)
1 parent 0326888 commit 690cb4f

File tree

4 files changed

+142
-21
lines changed

4 files changed

+142
-21
lines changed

changelog.d/12189.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Support skipping some arguments when generating cache keys.

synapse/storage/databases/main/events_worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,7 +1286,7 @@ async def have_seen_events(
12861286
)
12871287
return {eid for ((_rid, eid), have_event) in res.items() if have_event}
12881288

1289-
@cachedList("have_seen_event", "keys")
1289+
@cachedList(cached_method_name="have_seen_event", list_name="keys")
12901290
async def _have_seen_events_dict(
12911291
self, keys: Iterable[Tuple[str, str]]
12921292
) -> Dict[Tuple[str, str], bool]:
@@ -1954,7 +1954,7 @@ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
19541954
get_event_id_for_timestamp_txn,
19551955
)
19561956

1957-
@cachedList("is_partial_state_event", list_name="event_ids")
1957+
@cachedList(cached_method_name="is_partial_state_event", list_name="event_ids")
19581958
async def get_partial_state_events(
19591959
self, event_ids: Collection[str]
19601960
) -> Dict[str, bool]:

synapse/util/caches/descriptors.py

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Any,
2121
Awaitable,
2222
Callable,
23+
Collection,
2324
Dict,
2425
Generic,
2526
Hashable,
@@ -69,13 +70,21 @@ def __init__(
6970
self,
7071
orig: Callable[..., Any],
7172
num_args: Optional[int],
73+
uncached_args: Optional[Collection[str]] = None,
7274
cache_context: bool = False,
7375
):
7476
self.orig = orig
7577

7678
arg_spec = inspect.getfullargspec(orig)
7779
all_args = arg_spec.args
7880

81+
# There's no reason that keyword-only arguments couldn't be supported,
82+
# but right now they're buggy so do not allow them.
83+
if arg_spec.kwonlyargs:
84+
raise ValueError(
85+
"_CacheDescriptorBase does not support keyword-only arguments."
86+
)
87+
7988
if "cache_context" in all_args:
8089
if not cache_context:
8190
raise ValueError(
@@ -88,6 +97,9 @@ def __init__(
8897
" named `cache_context`"
8998
)
9099

100+
if num_args is not None and uncached_args is not None:
101+
raise ValueError("Cannot provide both num_args and uncached_args")
102+
91103
if num_args is None:
92104
num_args = len(all_args) - 1
93105
if cache_context:
@@ -105,6 +117,12 @@ def __init__(
105117
# list of the names of the args used as the cache key
106118
self.arg_names = all_args[1 : num_args + 1]
107119

120+
# If there are args to not cache on, filter them out (and fix the size of num_args).
121+
if uncached_args is not None:
122+
include_arg_in_cache_key = [n not in uncached_args for n in self.arg_names]
123+
else:
124+
include_arg_in_cache_key = [True] * len(self.arg_names)
125+
108126
# self.arg_defaults is a map of arg name to its default value for each
109127
# argument that has a default value
110128
if arg_spec.defaults:
@@ -119,8 +137,8 @@ def __init__(
119137

120138
self.add_cache_context = cache_context
121139

122-
self.cache_key_builder = get_cache_key_builder(
123-
self.arg_names, self.arg_defaults
140+
self.cache_key_builder = _get_cache_key_builder(
141+
self.arg_names, include_arg_in_cache_key, self.arg_defaults
124142
)
125143

126144

@@ -130,8 +148,7 @@ class _LruCachedFunction(Generic[F]):
130148

131149

132150
def lru_cache(
133-
max_entries: int = 1000,
134-
cache_context: bool = False,
151+
*, max_entries: int = 1000, cache_context: bool = False
135152
) -> Callable[[F], _LruCachedFunction[F]]:
136153
"""A method decorator that applies a memoizing cache around the function.
137154
@@ -186,7 +203,9 @@ def __init__(
186203
max_entries: int = 1000,
187204
cache_context: bool = False,
188205
):
189-
super().__init__(orig, num_args=None, cache_context=cache_context)
206+
super().__init__(
207+
orig, num_args=None, uncached_args=None, cache_context=cache_context
208+
)
190209
self.max_entries = max_entries
191210

192211
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
@@ -260,6 +279,9 @@ def foo(self, key, cache_context):
260279
num_args: number of positional arguments (excluding ``self`` and
261280
``cache_context``) to use as cache keys. Defaults to all named
262281
args of the function.
282+
uncached_args: a list of argument names to not use as the cache key.
283+
(``self`` and ``cache_context`` are always ignored.) Cannot be used
284+
with num_args.
263285
tree:
264286
cache_context:
265287
iterable:
@@ -273,12 +295,18 @@ def __init__(
273295
orig: Callable[..., Any],
274296
max_entries: int = 1000,
275297
num_args: Optional[int] = None,
298+
uncached_args: Optional[Collection[str]] = None,
276299
tree: bool = False,
277300
cache_context: bool = False,
278301
iterable: bool = False,
279302
prune_unread_entries: bool = True,
280303
):
281-
super().__init__(orig, num_args=num_args, cache_context=cache_context)
304+
super().__init__(
305+
orig,
306+
num_args=num_args,
307+
uncached_args=uncached_args,
308+
cache_context=cache_context,
309+
)
282310

283311
if tree and self.num_args < 2:
284312
raise RuntimeError(
@@ -369,7 +397,7 @@ def __init__(
369397
but including list_name) to use as cache keys. Defaults to all
370398
named args of the function.
371399
"""
372-
super().__init__(orig, num_args=num_args)
400+
super().__init__(orig, num_args=num_args, uncached_args=None)
373401

374402
self.list_name = list_name
375403

@@ -530,8 +558,10 @@ def get_instance(
530558

531559

532560
def cached(
561+
*,
533562
max_entries: int = 1000,
534563
num_args: Optional[int] = None,
564+
uncached_args: Optional[Collection[str]] = None,
535565
tree: bool = False,
536566
cache_context: bool = False,
537567
iterable: bool = False,
@@ -541,6 +571,7 @@ def cached(
541571
orig,
542572
max_entries=max_entries,
543573
num_args=num_args,
574+
uncached_args=uncached_args,
544575
tree=tree,
545576
cache_context=cache_context,
546577
iterable=iterable,
@@ -551,7 +582,7 @@ def cached(
551582

552583

553584
def cachedList(
554-
cached_method_name: str, list_name: str, num_args: Optional[int] = None
585+
*, cached_method_name: str, list_name: str, num_args: Optional[int] = None
555586
) -> Callable[[F], _CachedFunction[F]]:
556587
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
557588
@@ -590,13 +621,16 @@ def batch_do_something(self, first_arg, second_args):
590621
return cast(Callable[[F], _CachedFunction[F]], func)
591622

592623

593-
def get_cache_key_builder(
594-
param_names: Sequence[str], param_defaults: Mapping[str, Any]
624+
def _get_cache_key_builder(
625+
param_names: Sequence[str],
626+
include_params: Sequence[bool],
627+
param_defaults: Mapping[str, Any],
595628
) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
596629
"""Construct a function which will build cache keys suitable for a cached function
597630
598631
Args:
599632
param_names: list of formal parameter names for the cached function
633+
include_params: list of bools, one per entry in param_names, indicating whether to include that parameter's value in the cache key
600634
param_defaults: a mapping from parameter name to default value for that param
601635
602636
Returns:
@@ -608,6 +642,7 @@ def get_cache_key_builder(
608642

609643
if len(param_names) == 1:
610644
nm = param_names[0]
645+
assert include_params[0] is True
611646

612647
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
613648
if nm in kwargs:
@@ -620,13 +655,18 @@ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
620655
else:
621656

622657
def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
623-
return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
658+
return tuple(
659+
_get_cache_key_gen(
660+
param_names, include_params, param_defaults, args, kwargs
661+
)
662+
)
624663

625664
return get_cache_key
626665

627666

628667
def _get_cache_key_gen(
629668
param_names: Iterable[str],
669+
include_params: Iterable[bool],
630670
param_defaults: Mapping[str, Any],
631671
args: Sequence[Any],
632672
kwargs: Mapping[str, Any],
@@ -637,16 +677,18 @@ def _get_cache_key_gen(
637677
This is essentially the same operation as `inspect.getcallargs`, but optimised so
638678
that we don't need to inspect the target function for each call.
639679
"""
640-
641680
# We loop through each arg name, looking up if it's in the `kwargs`,
642681
# otherwise using the next argument in `args`. If there are no more
643682
# args then we try looking the arg name up in the defaults.
644683
pos = 0
645-
for nm in param_names:
684+
for nm, inc in zip(param_names, include_params):
646685
if nm in kwargs:
647-
yield kwargs[nm]
686+
if inc:
687+
yield kwargs[nm]
648688
elif pos < len(args):
649-
yield args[pos]
689+
if inc:
690+
yield args[pos]
650691
pos += 1
651692
else:
652-
yield param_defaults[nm]
693+
if inc:
694+
yield param_defaults[nm]

tests/util/caches/test_descriptors.py

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,84 @@ def fn(self, arg1, arg2):
141141
self.assertEqual(r, "chips")
142142
obj.mock.assert_not_called()
143143

144+
@defer.inlineCallbacks
145+
def test_cache_uncached_args(self):
146+
"""
147+
Only the arguments not named in uncached_args should matter to the cache
148+
149+
Note that this is identical to test_cache_num_args, but provides the
150+
arguments differently.
151+
"""
152+
153+
class Cls:
154+
# Note that it is important that this is not the last argument to
155+
# test behaviour of skipping arguments properly.
156+
@descriptors.cached(uncached_args=("arg2",))
157+
def fn(self, arg1, arg2, arg3):
158+
return self.mock(arg1, arg2, arg3)
159+
160+
def __init__(self):
161+
self.mock = mock.Mock()
162+
163+
obj = Cls()
164+
obj.mock.return_value = "fish"
165+
r = yield obj.fn(1, 2, 3)
166+
self.assertEqual(r, "fish")
167+
obj.mock.assert_called_once_with(1, 2, 3)
168+
obj.mock.reset_mock()
169+
170+
# a call with different params should call the mock again
171+
obj.mock.return_value = "chips"
172+
r = yield obj.fn(2, 3, 4)
173+
self.assertEqual(r, "chips")
174+
obj.mock.assert_called_once_with(2, 3, 4)
175+
obj.mock.reset_mock()
176+
177+
# the two values should now be cached; we should be able to vary
178+
# the second argument and still get the cached result.
179+
r = yield obj.fn(1, 4, 3)
180+
self.assertEqual(r, "fish")
181+
r = yield obj.fn(2, 5, 4)
182+
self.assertEqual(r, "chips")
183+
obj.mock.assert_not_called()
184+
185+
@defer.inlineCallbacks
186+
def test_cache_kwargs(self):
187+
"""Test that keyword arguments are treated properly"""
188+
189+
class Cls:
190+
def __init__(self):
191+
self.mock = mock.Mock()
192+
193+
@descriptors.cached()
194+
def fn(self, arg1, kwarg1=2):
195+
return self.mock(arg1, kwarg1=kwarg1)
196+
197+
obj = Cls()
198+
obj.mock.return_value = "fish"
199+
r = yield obj.fn(1, kwarg1=2)
200+
self.assertEqual(r, "fish")
201+
obj.mock.assert_called_once_with(1, kwarg1=2)
202+
obj.mock.reset_mock()
203+
204+
# a call with different params should call the mock again
205+
obj.mock.return_value = "chips"
206+
r = yield obj.fn(1, kwarg1=3)
207+
self.assertEqual(r, "chips")
208+
obj.mock.assert_called_once_with(1, kwarg1=3)
209+
obj.mock.reset_mock()
210+
211+
# the values should now be cached.
212+
r = yield obj.fn(1, kwarg1=2)
213+
self.assertEqual(r, "fish")
214+
# We should be able to not provide kwarg1 and get the cached value back.
215+
r = yield obj.fn(1)
216+
self.assertEqual(r, "fish")
217+
# Keyword arguments can be in any order.
218+
r = yield obj.fn(kwarg1=2, arg1=1)
219+
self.assertEqual(r, "fish")
220+
obj.mock.assert_not_called()
221+
144222
def test_cache_with_sync_exception(self):
145223
"""If the wrapped function throws synchronously, things should continue to work"""
146224

@@ -656,7 +734,7 @@ def __init__(self):
656734
def fn(self, arg1, arg2):
657735
pass
658736

659-
@descriptors.cachedList("fn", "args1")
737+
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
660738
async def list_fn(self, args1, arg2):
661739
assert current_context().name == "c1"
662740
# we want this to behave like an asynchronous function
@@ -715,7 +793,7 @@ def __init__(self):
715793
def fn(self, arg1):
716794
pass
717795

718-
@descriptors.cachedList("fn", "args1")
796+
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
719797
def list_fn(self, args1) -> "Deferred[dict]":
720798
return self.mock(args1)
721799

@@ -758,7 +836,7 @@ def __init__(self):
758836
def fn(self, arg1, arg2):
759837
pass
760838

761-
@descriptors.cachedList("fn", "args1")
839+
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
762840
async def list_fn(self, args1, arg2):
763841
# we want this to behave like an asynchronous function
764842
await run_on_reactor()

0 commit comments

Comments
 (0)