1
- # SPDX-License-Identifier: CC-PDM-1.0
1
+ # SPDX-License-Identifier: LGPL-3.0+
2
2
# author: Martin Jurča, Joao S. O. Bueno
3
3
import asyncio
4
4
from functools import partial
5
5
from logging import getLogger
6
6
from inspect import isawaitable
7
7
from itertools import chain
8
8
from collections .abc import MutableSet
9
+ from numbers import Real as NReal # for typing purposes
10
+ from decimal import Decimal
9
11
10
12
import inspect
11
13
import heapq
15
17
16
18
import typing as t
17
19
20
+ # decimal.Decimal is not registered as a numbers.Real subclass, so it is added to the alias explicitly.
21
+ Real = NReal | Decimal
18
22
19
23
logger = getLogger (__name__ )
20
24
21
25
T = t .TypeVar ("T" )
22
26
R = t .TypeVar ("R" )
23
27
24
- TIME_UNIT = t .Literal ["second" ] | t .Literal ["minute" ] | t .Literal ["hour" ] | t .Literal ["day" ]
25
- NUMBER = int | float
26
-
27
-
28
- def normalize_freq (value : NUMBER , unit : TIME_UNIT ) -> float : # normalizes frequency to 'day'
29
- match unit :
30
- case "second" :
31
- value *= (60 * 60 * 24 )
32
- case "minute" :
33
- value *= (60 * 24 )
34
- case "hour" :
35
- value *= 24
36
- case "day" :
37
- pass
38
- case _:
39
- raise ValueError (f"Invalid time unit for frequency throttle - should be one of { TIME_UNIT } " )
40
- return value
28
+ TIME_UNIT = (
29
+ t .Literal ["second" ] | t .Literal ["minute" ] | t .Literal ["hour" ] | t .Literal ["day" ]
30
+ )
41
31
42
32
43
33
# sentinels:
@@ -112,6 +102,8 @@ def discard(self, value):
112
102
def _as_async_iterable (
113
103
iterable : t .AsyncIterable [T ] | t .Iterable [T ],
114
104
) -> t .AsyncIterable [T ]:
105
+ # author: Martin Jurča
106
+ # License: CC-PDM-1.0
115
107
if isinstance (iterable , t .AsyncIterable ):
116
108
return iterable
117
109
@@ -125,15 +117,80 @@ async def _sync_to_async_iterable() -> t.AsyncIterable[T]:
125
117
PipelineErrors = t .Literal ["strict" , "ignore" , "lazy_raise" ]
126
118
127
119
120
+ class RateLimiter :
121
+ """Intended to limit rates for running a given Stage -
122
+
123
+ Use, for example, to respect the rate limit of
124
+ external APIs.
125
+
126
+ Just await the instance before executing each action that should be throttled.
127
+ """
128
+
129
+ # This is offset to a separate class so that it can be plugable
130
+ # (e.g. for an off-process coordinated limiter)
131
+ def __init__ (self , rate_limit : Real , unit : TIME_UNIT = "second" ):
132
+ self .rate_limit = rate_limit
133
+ self .unit = unit
134
+ self .last_reset : None | float = None
135
+
136
+ def reset (self ):
137
+ # self.event = asyncio.Event()
138
+ loop = asyncio .get_running_loop ()
139
+ self .last_reset = loop .time ()
140
+
141
+ def __await__ (self ):
142
+ loop = asyncio .get_running_loop ()
143
+ if (
144
+ self .last_reset is None
145
+ or (remaining := self .normalized - (loop .time () - self .last_reset )) < 0
146
+ ):
147
+ yield None
148
+ self .reset ()
149
+ return
150
+ fut = loop .create_future ()
151
+ loop .call_later (remaining , lambda : fut .set_result (None ))
152
+ yield from fut
153
+ self .reset ()
154
+ return None
155
+
156
+ def __copy__ (self ):
157
+ instance = type (self )()
158
+ instance .__dict__ .update (self .__dict__ )
159
+ instance .last_reset = None
160
+ return instance
161
+
162
+ @property
163
+ def normalized (self ):
164
+ """normalizes frequency to 'second' and returns interval between calls"""
165
+ value = self .rate_limit
166
+ match self .unit :
167
+ case "second" :
168
+ pass
169
+ case "minute" :
170
+ value /= 60
171
+ case "hour" :
172
+ value /= 3600
173
+ case "day" :
174
+ value /= 24 * 3600
175
+ case _:
176
+ raise ValueError (
177
+ f"Invalid time unit for frequency throtle - should be one of { TIME_UNIT } "
178
+ )
179
+ return 1 / value
180
+
181
+ def __repr__ (self ):
182
+ return f"{ self .__class__ .__name__ } ({ self .rate_limit } , { self .unit } )"
183
+
184
+
128
185
class Stage :
129
186
tasks = None
130
187
131
188
def __init__ (
132
189
self ,
133
190
code ,
134
191
max_concurrency : t .Optional [int ] = None ,
135
- rate_limit : t . Optional [ NUMBER ] = None ,
136
- rate_limit_unit : TIME_UNIT = second ,
192
+ rate_limit : None | RateLimiter = None ,
193
+ rate_limit_unit : TIME_UNIT = " second" ,
137
194
preserve_order : bool = True ,
138
195
force_concurrency : bool = True ,
139
196
parent : "Pipeline" = None ,
@@ -147,19 +204,15 @@ def __init__(
147
204
"""
148
205
self .code = code
149
206
self .max_concurrency = max_concurrency
150
- self .rate_limit = normalize_freq (rate_limit , rate_limit_unit ) if rate_limit is not None else None
207
+ self .rate_limiter = (
208
+ rate_limit
209
+ if isinstance (rate_limit , RateLimiter )
210
+ else RateLimiter (rate_limit , rate_limit_unit ) if rate_limit else None
211
+ )
151
212
self .preserve_order = preserve_order
152
213
self .parent = parent
153
214
self .reset ()
154
215
155
- @property
156
- def rate_limit (self ):
157
- return self ._rate_limit if self ._rate_limit else self .parent .rate_limit
158
-
159
- @rate_limit .setter
160
- def rate_limit (self , value ):
161
- self ._rate_limit = value
162
-
163
216
def add_next_stage (self , next_ ):
164
217
self .next .add (next_ )
165
218
@@ -213,7 +266,6 @@ def __repr__(self):
213
266
return f"{ self .__class__ .__name__ } { self .code } "
214
267
215
268
216
-
217
269
class Pipeline :
218
270
"""
219
271
Pipeline class
@@ -232,8 +284,8 @@ def __init__(
232
284
source : t .Optional [t .AsyncIterable [T ] | t .Iterable [T ]],
233
285
* stages : t .Sequence [t .Callable | Stage ],
234
286
max_concurrency : t .Optional [int ] = None ,
235
- rate_limit : t . Optional [ int ] = None ,
236
- rate_limit_unit : T_TIME_UNIT = "second" ,
287
+ rate_limit : None | RateLimiter | Real = None ,
288
+ rate_limit_unit : TIME_UNIT = "second" ,
237
289
on_error : PipelineErrors = "strict" ,
238
290
preserve_order : bool = False ,
239
291
max_simultaneous_records : t .Optional [int ] = None ,
@@ -249,19 +301,31 @@ def __init__(
249
301
limited to 4)
250
302
- on_error: What to do if any stage raises an exception - defaults to re-raise the
251
303
exception and stop the whole pipeline
304
+ - rate_limit: An overall rate-limiting parameter which can be used to throttle all stages.
305
+ If any one stage should have a limit different from the limit for the whole pipeline,
306
+ create it as an explicit Stage instance and configure the limiter there.
307
+ - rate_limit_unit: if rate_limit is given as a number, this states the time unit to be used in the rate limiting ratio.
308
+ Not used otherwise.
252
309
- preserve_order: whether to yield the final results in the same order they were acquired from data.
253
310
- max_simultaneous_records: limit on amount of records to hold across all stages and input in internal
254
311
data structures: the idea is throtle data consumption in order to limit the
255
312
amount of memory used by the Pipeline
256
313
257
314
"""
258
315
self .max_concurrency = max_concurrency
259
- self .data = _as_async_iterable (source ) if source not in (None , Placeholder ) else None
316
+ self .data = (
317
+ _as_async_iterable (source ) if source not in (None , Placeholder ) else None
318
+ )
260
319
self .preserve_order = preserve_order
261
320
# TBD: maybe allow limitting total memory usage instead of elements in the pipeline?
262
321
self .max_simultaneous_records = max_simultaneous_records
263
322
self .on_error = on_error
264
323
self .raw_stages = stages
324
+ self .rate_limiter = (
325
+ rate_limit
326
+ if isinstance (rate_limit , RateLimiter )
327
+ else RateLimiter (rate_limit , rate_limit_unit ) if rate_limit else None
328
+ )
265
329
self .reset ()
266
330
267
331
def _create_stages (self , stages ):
0 commit comments