20
20
Any ,
21
21
Awaitable ,
22
22
Callable ,
23
+ Collection ,
23
24
Dict ,
24
25
Generic ,
25
26
Hashable ,
@@ -69,13 +70,21 @@ def __init__(
69
70
self ,
70
71
orig : Callable [..., Any ],
71
72
num_args : Optional [int ],
73
+ uncached_args : Optional [Collection [str ]] = None ,
72
74
cache_context : bool = False ,
73
75
):
74
76
self .orig = orig
75
77
76
78
arg_spec = inspect .getfullargspec (orig )
77
79
all_args = arg_spec .args
78
80
81
+ # There's no reason that keyword-only arguments couldn't be supported,
82
+ # but right now they're buggy so do not allow them.
83
+ if arg_spec .kwonlyargs :
84
+ raise ValueError (
85
+ "_CacheDescriptorBase does not support keyword-only arguments."
86
+ )
87
+
79
88
if "cache_context" in all_args :
80
89
if not cache_context :
81
90
raise ValueError (
@@ -88,6 +97,9 @@ def __init__(
88
97
" named `cache_context`"
89
98
)
90
99
100
+ if num_args is not None and uncached_args is not None :
101
+ raise ValueError ("Cannot provide both num_args and uncached_args" )
102
+
91
103
if num_args is None :
92
104
num_args = len (all_args ) - 1
93
105
if cache_context :
@@ -105,6 +117,12 @@ def __init__(
105
117
# list of the names of the args used as the cache key
106
118
self .arg_names = all_args [1 : num_args + 1 ]
107
119
120
+ # If there are args to not cache on, filter them out (and fix the size of num_args).
121
+ if uncached_args is not None :
122
+ include_arg_in_cache_key = [n not in uncached_args for n in self .arg_names ]
123
+ else :
124
+ include_arg_in_cache_key = [True ] * len (self .arg_names )
125
+
108
126
# self.arg_defaults is a map of arg name to its default value for each
109
127
# argument that has a default value
110
128
if arg_spec .defaults :
@@ -119,8 +137,8 @@ def __init__(
119
137
120
138
self .add_cache_context = cache_context
121
139
122
- self .cache_key_builder = get_cache_key_builder (
123
- self .arg_names , self .arg_defaults
140
+ self .cache_key_builder = _get_cache_key_builder (
141
+ self .arg_names , include_arg_in_cache_key , self .arg_defaults
124
142
)
125
143
126
144
@@ -130,8 +148,7 @@ class _LruCachedFunction(Generic[F]):
130
148
131
149
132
150
def lru_cache (
133
- max_entries : int = 1000 ,
134
- cache_context : bool = False ,
151
+ * , max_entries : int = 1000 , cache_context : bool = False
135
152
) -> Callable [[F ], _LruCachedFunction [F ]]:
136
153
"""A method decorator that applies a memoizing cache around the function.
137
154
@@ -186,7 +203,9 @@ def __init__(
186
203
max_entries : int = 1000 ,
187
204
cache_context : bool = False ,
188
205
):
189
- super ().__init__ (orig , num_args = None , cache_context = cache_context )
206
+ super ().__init__ (
207
+ orig , num_args = None , uncached_args = None , cache_context = cache_context
208
+ )
190
209
self .max_entries = max_entries
191
210
192
211
def __get__ (self , obj : Optional [Any ], owner : Optional [Type ]) -> Callable [..., Any ]:
@@ -260,6 +279,9 @@ def foo(self, key, cache_context):
260
279
num_args: number of positional arguments (excluding ``self`` and
261
280
``cache_context``) to use as cache keys. Defaults to all named
262
281
args of the function.
282
+ uncached_args: a list of argument names to not use as the cache key.
283
+ (``self`` and ``cache_context`` are always ignored.) Cannot be used
284
+ with num_args.
263
285
tree:
264
286
cache_context:
265
287
iterable:
@@ -273,12 +295,18 @@ def __init__(
273
295
orig : Callable [..., Any ],
274
296
max_entries : int = 1000 ,
275
297
num_args : Optional [int ] = None ,
298
+ uncached_args : Optional [Collection [str ]] = None ,
276
299
tree : bool = False ,
277
300
cache_context : bool = False ,
278
301
iterable : bool = False ,
279
302
prune_unread_entries : bool = True ,
280
303
):
281
- super ().__init__ (orig , num_args = num_args , cache_context = cache_context )
304
+ super ().__init__ (
305
+ orig ,
306
+ num_args = num_args ,
307
+ uncached_args = uncached_args ,
308
+ cache_context = cache_context ,
309
+ )
282
310
283
311
if tree and self .num_args < 2 :
284
312
raise RuntimeError (
@@ -369,7 +397,7 @@ def __init__(
369
397
but including list_name) to use as cache keys. Defaults to all
370
398
named args of the function.
371
399
"""
372
- super ().__init__ (orig , num_args = num_args )
400
+ super ().__init__ (orig , num_args = num_args , uncached_args = None )
373
401
374
402
self .list_name = list_name
375
403
@@ -530,8 +558,10 @@ def get_instance(
530
558
531
559
532
560
def cached (
561
+ * ,
533
562
max_entries : int = 1000 ,
534
563
num_args : Optional [int ] = None ,
564
+ uncached_args : Optional [Collection [str ]] = None ,
535
565
tree : bool = False ,
536
566
cache_context : bool = False ,
537
567
iterable : bool = False ,
@@ -541,6 +571,7 @@ def cached(
541
571
orig ,
542
572
max_entries = max_entries ,
543
573
num_args = num_args ,
574
+ uncached_args = uncached_args ,
544
575
tree = tree ,
545
576
cache_context = cache_context ,
546
577
iterable = iterable ,
@@ -551,7 +582,7 @@ def cached(
551
582
552
583
553
584
def cachedList (
554
- cached_method_name : str , list_name : str , num_args : Optional [int ] = None
585
+ * , cached_method_name : str , list_name : str , num_args : Optional [int ] = None
555
586
) -> Callable [[F ], _CachedFunction [F ]]:
556
587
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
557
588
@@ -590,13 +621,16 @@ def batch_do_something(self, first_arg, second_args):
590
621
return cast (Callable [[F ], _CachedFunction [F ]], func )
591
622
592
623
593
- def get_cache_key_builder (
594
- param_names : Sequence [str ], param_defaults : Mapping [str , Any ]
624
+ def _get_cache_key_builder (
625
+ param_names : Sequence [str ],
626
+ include_params : Sequence [bool ],
627
+ param_defaults : Mapping [str , Any ],
595
628
) -> Callable [[Sequence [Any ], Mapping [str , Any ]], CacheKey ]:
596
629
"""Construct a function which will build cache keys suitable for a cached function
597
630
598
631
Args:
599
632
param_names: list of formal parameter names for the cached function
633
+ include_params: list of bools of whether to include the parameter name in the cache key
600
634
param_defaults: a mapping from parameter name to default value for that param
601
635
602
636
Returns:
@@ -608,6 +642,7 @@ def get_cache_key_builder(
608
642
609
643
if len (param_names ) == 1 :
610
644
nm = param_names [0 ]
645
+ assert include_params [0 ] is True
611
646
612
647
def get_cache_key (args : Sequence [Any ], kwargs : Mapping [str , Any ]) -> CacheKey :
613
648
if nm in kwargs :
@@ -620,13 +655,18 @@ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
620
655
else :
621
656
622
657
def get_cache_key (args : Sequence [Any ], kwargs : Mapping [str , Any ]) -> CacheKey :
623
- return tuple (_get_cache_key_gen (param_names , param_defaults , args , kwargs ))
658
+ return tuple (
659
+ _get_cache_key_gen (
660
+ param_names , include_params , param_defaults , args , kwargs
661
+ )
662
+ )
624
663
625
664
return get_cache_key
626
665
627
666
628
667
def _get_cache_key_gen (
629
668
param_names : Iterable [str ],
669
+ include_params : Iterable [bool ],
630
670
param_defaults : Mapping [str , Any ],
631
671
args : Sequence [Any ],
632
672
kwargs : Mapping [str , Any ],
@@ -637,16 +677,18 @@ def _get_cache_key_gen(
637
677
This is essentially the same operation as `inspect.getcallargs`, but optimised so
638
678
that we don't need to inspect the target function for each call.
639
679
"""
640
-
641
680
# We loop through each arg name, looking up if its in the `kwargs`,
642
681
# otherwise using the next argument in `args`. If there are no more
643
682
# args then we try looking the arg name up in the defaults.
644
683
pos = 0
645
- for nm in param_names :
684
+ for nm , inc in zip ( param_names , include_params ) :
646
685
if nm in kwargs :
647
- yield kwargs [nm ]
686
+ if inc :
687
+ yield kwargs [nm ]
648
688
elif pos < len (args ):
649
- yield args [pos ]
689
+ if inc :
690
+ yield args [pos ]
650
691
pos += 1
651
692
else :
652
- yield param_defaults [nm ]
693
+ if inc :
694
+ yield param_defaults [nm ]
0 commit comments