Skip to content

Commit 81dc7f4

Browse files
committed
[WIP] Pipeline: Rate limiter class
1 parent d7c3962 commit 81dc7f4

File tree

3 files changed

+123
-36
lines changed

3 files changed

+123
-36
lines changed

extraasync/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
from .taskgroup import ExtraTaskGroup
33
from .sync_async_bridge import sync_to_async, async_to_sync
44
from .async_hooks import at_loop_stop_callback, remove_loop_stop_callback
5-
from .pipeline import Pipeline
5+
from .pipeline import Pipeline, RateLimiter
66

77
__version__ = "0.3.0"
88

99

10-
__all__ = ["aenumerate", "ExtraTaskGroup", "sync_to_async", "async_to_sync", "at_loop_stop_callback", "pipeline", "remove_loop_stop_callback", "__version__"]
10+
__all__ = ["aenumerate", "ExtraTaskGroup", "sync_to_async", "async_to_sync", "at_loop_stop_callback", "Pipeline", "remove_loop_stop_callback", "RateLimiter", "__version__"]

extraasync/pipeline.py

Lines changed: 97 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
# SPDX-License-Identifier: CC-PDM-1.0
1+
# SPDX-License-Identifier: LGPL-3.0+
22
# author: Martin Jurča, Joao S. O. Bueno
33
import asyncio
44
from functools import partial
55
from logging import getLogger
66
from inspect import isawaitable
77
from itertools import chain
88
from collections.abc import MutableSet
9+
from numbers import Real as NReal # for typing purposes
10+
from decimal import Decimal
911

1012
import inspect
1113
import heapq
@@ -15,29 +17,17 @@
1517

1618
import typing as t
1719

20+
# for some reason, decimal is not a subtype of real.
21+
Real = NReal | Decimal
1822

1923
logger = getLogger(__name__)
2024

2125
T = t.TypeVar("T")
2226
R = t.TypeVar("R")
2327

24-
TIME_UNIT = t.Literal["second"] | t.Literal["minute"] | t.Literal["hour"] | t.Literal["day"]
25-
NUMBER = int | float
26-
27-
28-
def normalize_freq(value: NUMBER, unit: TIME_UNIT) -> float: # normalizes frequency to 'day'
29-
match unit:
30-
case "second":
31-
value *= (60 * 60 * 24)
32-
case "minute":
33-
value *= (60 * 24)
34-
case "hour":
35-
value *= 24
36-
case "day":
37-
pass
38-
case _:
39-
raise ValueError(f"Invalid time unit for frequency throttle - should be one of {TIME_UNIT}")
40-
return value
28+
TIME_UNIT = (
29+
t.Literal["second"] | t.Literal["minute"] | t.Literal["hour"] | t.Literal["day"]
30+
)
4131

4232

4333
# sentinels:
@@ -112,6 +102,8 @@ def discard(self, value):
112102
def _as_async_iterable(
113103
iterable: t.AsyncIterable[T] | t.Iterable[T],
114104
) -> t.AsyncIterable[T]:
105+
# author: Martin Jurča
106+
# License: CC-PDM-1.0
115107
if isinstance(iterable, t.AsyncIterable):
116108
return iterable
117109

@@ -125,15 +117,80 @@ async def _sync_to_async_iterable() -> t.AsyncIterable[T]:
125117
PipelineErrors = t.Literal["strict", "ignore", "lazy_raise"]
126118

127119

120+
class RateLimiter:
121+
"""Intended to limit rates for running a given Stage -
122+
123+
Use, for example, to respect the rate limit of
124+
external APIs.
125+
126+
Just await the instance before executing each action that should be throttled.
127+
"""
128+
129+
# This is offset to a separate class so that it can be plugable
130+
# (e.g. for an off-process coordinated limiter)
131+
def __init__(self, rate_limit: Real, unit: TIME_UNIT = "second"):
132+
self.rate_limit = rate_limit
133+
self.unit = unit
134+
self.last_reset: None | float = None
135+
136+
def reset(self):
137+
# self.event = asyncio.Event()
138+
loop = asyncio.get_running_loop()
139+
self.last_reset = loop.time()
140+
141+
def __await__(self):
142+
loop = asyncio.get_running_loop()
143+
if (
144+
self.last_reset is None
145+
or (remaining := self.normalized - (loop.time() - self.last_reset)) < 0
146+
):
147+
yield None
148+
self.reset()
149+
return
150+
fut = loop.create_future()
151+
loop.call_later(remaining, lambda: fut.set_result(None))
152+
yield from fut
153+
self.reset()
154+
return None
155+
156+
def __copy__(self):
157+
instance = type(self)()
158+
instance.__dict__.update(self.__dict__)
159+
instance.last_reset = None
160+
return instance
161+
162+
@property
163+
def normalized(self):
164+
"""normalizes frequency to 'second' and returns interval between calls"""
165+
value = self.rate_limit
166+
match self.unit:
167+
case "second":
168+
pass
169+
case "minute":
170+
value /= 60
171+
case "hour":
172+
value /= 3600
173+
case "day":
174+
value /= 24 * 3600
175+
case _:
176+
raise ValueError(
177+
f"Invalid time unit for frequency throtle - should be one of {TIME_UNIT}"
178+
)
179+
return 1 / value
180+
181+
def __repr__(self):
182+
return f"{self.__class__.__name__}({self.rate_limit}, {self.unit})"
183+
184+
128185
class Stage:
129186
tasks = None
130187

131188
def __init__(
132189
self,
133190
code,
134191
max_concurrency: t.Optional[int] = None,
135-
rate_limit: t.Optional[NUMBER] = None,
136-
rate_limit_unit: TIME_UNIT = second,
192+
rate_limit: None | RateLimiter = None,
193+
rate_limit_unit: TIME_UNIT = "second",
137194
preserve_order: bool = True,
138195
force_concurrency: bool = True,
139196
parent: "Pipeline" = None,
@@ -147,19 +204,15 @@ def __init__(
147204
"""
148205
self.code = code
149206
self.max_concurrency = max_concurrency
150-
self.rate_limit = normalize_freq(rate_limit, rate_limit_unit) if rate_limit is not None else None
207+
self.rate_limiter = (
208+
rate_limit
209+
if isinstance(rate_limit, RateLimiter)
210+
else RateLimiter(rate_limit, rate_limit_unit) if rate_limit else None
211+
)
151212
self.preserve_order = preserve_order
152213
self.parent = parent
153214
self.reset()
154215

155-
@property
156-
def rate_limit(self):
157-
return self._rate_limit if self._rate_limit else self.parent.rate_limit
158-
159-
@rate_limit.setter
160-
def rate_limit(self, value):
161-
self._rate_limit = value
162-
163216
def add_next_stage(self, next_):
164217
self.next.add(next_)
165218

@@ -213,7 +266,6 @@ def __repr__(self):
213266
return f"{self.__class__.__name__}{self.code}"
214267

215268

216-
217269
class Pipeline:
218270
"""
219271
Pipeline class
@@ -232,8 +284,8 @@ def __init__(
232284
source: t.Optional[t.AsyncIterable[T] | t.Iterable[T]],
233285
*stages: t.Sequence[t.Callable | Stage],
234286
max_concurrency: t.Optional[int] = None,
235-
rate_limit: t.Optional[int] = None,
236-
rate_limit_unit: T_TIME_UNIT = "second",
287+
rate_limit: None | RateLimiter | Real = None,
288+
rate_limit_unit: TIME_UNIT = "second",
237289
on_error: PipelineErrors = "strict",
238290
preserve_order: bool = False,
239291
max_simultaneous_records: t.Optional[int] = None,
@@ -249,19 +301,31 @@ def __init__(
249301
limited to 4)
250302
- on_error: WHat to do if any stage raises an exeception - defaults to re-raise the
251303
exception and stop the whole pipeline
304+
- rate_limit: An overall rate-limitting parameter which can be used to throtle all stages.
305+
If anyone stage should have a limit different from the limit to the whole pipeline,
306+
create it as an explicit Stage instance and configure the limiter there.
307+
- rate_limit_unit: if rate_limit is given as a number, this states the time unit to be used in the rate limiting ratio.
308+
Not used otherwise.
252309
- preserve_order: whether to yield the final results in the same order they were acquired from data.
253310
- max_simultaneous_records: limit on amount of records to hold across all stages and input in internal
254311
data structures: the idea is throtle data consumption in order to limit the
255312
amount of memory used by the Pipeline
256313
257314
"""
258315
self.max_concurrency = max_concurrency
259-
self.data = _as_async_iterable(source) if source not in (None, Placeholder) else None
316+
self.data = (
317+
_as_async_iterable(source) if source not in (None, Placeholder) else None
318+
)
260319
self.preserve_order = preserve_order
261320
# TBD: maybe allow limitting total memory usage instead of elements in the pipeline?
262321
self.max_simultaneous_records = max_simultaneous_records
263322
self.on_error = on_error
264323
self.raw_stages = stages
324+
self.rate_limiter = (
325+
rate_limit
326+
if isinstance(rate_limit, RateLimiter)
327+
else RateLimiter(rate_limit, rate_limit_unit) if rate_limit else None
328+
)
265329
self.reset()
266330

267331
def _create_stages(self, stages):

tests/test_pipeline.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import pytest
88

9-
from extraasync import Pipeline
9+
from extraasync import Pipeline, RateLimiter
1010

1111

1212
@pytest.mark.asyncio
@@ -332,3 +332,26 @@ async def test_pipeline_store_result_r_rshift_operator(): ...
332332
@pytest.mark.skip
333333
@pytest.mark.asyncio
334334
async def test_pipeline_fine_tune_stages(): ...
335+
336+
337+
@pytest.mark.asyncio
338+
async def test_rate_limiter_starts_immediately():
339+
loop = asyncio.get_running_loop()
340+
threshold = 0.005 # ~sys.getswitchinterval()
341+
limiter = RateLimiter(1)
342+
start_time = loop.time()
343+
await limiter
344+
assert loop.time() - start_time < threshold
345+
346+
347+
@pytest.mark.asyncio
348+
async def test_rate_limiter_throtles_rate():
349+
loop = asyncio.get_running_loop()
350+
threshold = 0.005 # ~sys.getswitchinterval()
351+
limiter = RateLimiter(20, "second")
352+
start_time = loop.time()
353+
for i in range(11):
354+
await limiter
355+
assert (
356+
loop.time() - start_time >= 0.5
357+
) # should be equal or greater than half second

0 commit comments

Comments
 (0)