@@ -83,32 +83,50 @@ async def pipeline(
83
83
next_to_yield = 0
84
84
early_results : dict [int , R ] = {}
85
85
86
+ active_worker_tasks = set ()
87
+ semaphore = asyncio .Semaphore (max_concurrency )
88
+
89
+ async def _work_step (item , index , * args , ** kwargs ):
90
+ result = exception = None
91
+ try :
92
+ result : R | t .Awaitable [R ]= func (item , * args , ** kwargs )
93
+ except Exception as exc :
94
+ exception = exc
95
+
96
+ if not exception and isinstance (result , t .Awaitable ):
97
+ # cast in original code doesn't make sense:
98
+ # mapping callback must await to a "R" otherwise a typing error is justified.
99
+ # result = t.cast(R, await result)
100
+ try :
101
+ result = await result
102
+ except Exception as exc :
103
+ exception = exc
104
+ await output_queue .put ((index , result , exception ))
105
+ input_queue .task_done ()
106
+ #return index, result, exception
107
+
108
+
86
109
async def _worker () -> None :
87
110
while True :
88
111
89
- index , item , _ = await input_queue .get ()
90
-
91
- result = exception = None
112
+ index , item , input_exception = await input_queue .get ()
92
113
93
- if item is input_terminator :
114
+ if input_exception or item is input_terminator :
94
115
input_queue .task_done ()
95
116
break
96
- try :
97
- result : R | t .Awaitable [R ]= func (item , * args , ** kwargs )
98
- except Exception as exc :
99
- exception = exc
100
117
101
- if not exception and isinstance (result , t .Awaitable ):
102
- # cast in original code doesn't make sense:
103
- # mapping callback must await to a "R" otherwise a typing error is justified.
104
- # result = t.cast(R, await result)
105
- try :
106
- result = await result
107
- except Exception as exc :
108
- exception = exc
109
- await output_queue .put ((index , result , exception ))
110
- input_queue .task_done ()
118
+ if await semaphore .acquire ():
119
+ task = asyncio .create_task (_work_step (item , index , * args , ** kwargs ))
120
+ active_worker_tasks .add (task )
121
+ def cleanup (t ):
122
+ print (f"cleanng up { t } " )
123
+ worker_tasks .remove (t )
124
+ semaphore .release ()
125
+ print (f"cleaned up { t } " )
111
126
127
+ task .add_done_callback (lambda t : (active_worker_tasks .remove (t ), semaphore .release ()))
128
+
129
+ await asyncio .gather (* active_worker_tasks )
112
130
await output_queue .put ((- 1 , output_terminator , None ))
113
131
114
132
@@ -143,13 +161,14 @@ async def _feeder() -> None:
143
161
144
162
async def _re_orderer () -> t .AsyncIterable [R ]:
145
163
nonlocal next_to_yield
146
- remaining_workers = max_concurrency
147
- while remaining_workers :
164
+ finish = False
165
+ while True :
148
166
index , result , exception = await output_queue .get ()
149
167
if result is output_terminator :
150
- remaining_workers -= 1
151
168
output_queue .task_done ()
152
- continue
169
+ finish = True
170
+
171
+ return
153
172
154
173
early_results [index ] = result if exception is None else exception
155
174
while next_to_yield in early_results :
@@ -161,18 +180,19 @@ async def _re_orderer() -> t.AsyncIterable[R]:
161
180
next_to_yield += 1
162
181
output_queue .task_done ()
163
182
164
- tasks = [
165
- asyncio .create_task (_worker ()) for _ in range (max_concurrency )
166
- ] + [asyncio .create_task (_feeder ())]
183
+ admin_tasks = {
184
+ asyncio .create_task (_worker ()),
185
+ asyncio .create_task (_feeder ())
186
+ }
167
187
168
188
try :
169
189
async for result in _re_orderer ():
170
190
yield result
171
191
finally :
172
- for task in tasks :
192
+ for task in admin_tasks :
173
193
task .cancel ()
174
194
175
- await asyncio .gather (* tasks , return_exceptions = True )
195
+ await asyncio .gather (* admin_tasks , return_exceptions = True )
176
196
177
197
178
198
Pipeline = pipeline
0 commit comments