Skip to content

Commit b36dbdb

Browse files
committed
Refactor worker tasks
1 parent 4fe6e05 commit b36dbdb

File tree

1 file changed

+47
-27
lines changed

1 file changed

+47
-27
lines changed

extraasync/pipeline.py

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -83,32 +83,50 @@ async def pipeline(
8383
next_to_yield = 0
8484
early_results: dict[int, R] = {}
8585

86+
active_worker_tasks = set()
87+
semaphore = asyncio.Semaphore(max_concurrency)
88+
89+
async def _work_step(item, index, *args, **kwargs):
90+
result = exception = None
91+
try:
92+
result: R | t.Awaitable[R]= func(item, *args, **kwargs)
93+
except Exception as exc:
94+
exception = exc
95+
96+
if not exception and isinstance(result, t.Awaitable):
97+
# cast in original code doesn't make sense:
98+
# mapping callback must await to a "R" otherwise a typing error is justified.
99+
# result = t.cast(R, await result)
100+
try:
101+
result = await result
102+
except Exception as exc:
103+
exception = exc
104+
await output_queue.put((index, result, exception))
105+
input_queue.task_done()
106+
#return index, result, exception
107+
108+
86109
async def _worker() -> None:
87110
while True:
88111

89-
index, item, _ = await input_queue.get()
90-
91-
result = exception = None
112+
index, item, input_exception = await input_queue.get()
92113

93-
if item is input_terminator:
114+
if input_exception or item is input_terminator:
94115
input_queue.task_done()
95116
break
96-
try:
97-
result: R | t.Awaitable[R]= func(item, *args, **kwargs)
98-
except Exception as exc:
99-
exception = exc
100117

101-
if not exception and isinstance(result, t.Awaitable):
102-
# cast in original code doesn't make sense:
103-
# mapping callback must await to a "R" otherwise a typing error is justified.
104-
# result = t.cast(R, await result)
105-
try:
106-
result = await result
107-
except Exception as exc:
108-
exception = exc
109-
await output_queue.put((index, result, exception))
110-
input_queue.task_done()
118+
if await semaphore.acquire():
119+
task = asyncio.create_task(_work_step(item, index, *args, **kwargs))
120+
active_worker_tasks.add(task)
121+
def cleanup(t):
122+
print(f"cleanng up {t}")
123+
worker_tasks.remove(t)
124+
semaphore.release()
125+
print(f"cleaned up {t}")
111126

127+
task.add_done_callback(lambda t: (active_worker_tasks.remove(t), semaphore.release()))
128+
129+
await asyncio.gather(*active_worker_tasks)
112130
await output_queue.put((-1, output_terminator, None))
113131

114132

@@ -143,13 +161,14 @@ async def _feeder() -> None:
143161

144162
async def _re_orderer() -> t.AsyncIterable[R]:
145163
nonlocal next_to_yield
146-
remaining_workers = max_concurrency
147-
while remaining_workers:
164+
finish = False
165+
while True:
148166
index, result, exception = await output_queue.get()
149167
if result is output_terminator:
150-
remaining_workers -= 1
151168
output_queue.task_done()
152-
continue
169+
finish = True
170+
171+
return
153172

154173
early_results[index] = result if exception is None else exception
155174
while next_to_yield in early_results:
@@ -161,18 +180,19 @@ async def _re_orderer() -> t.AsyncIterable[R]:
161180
next_to_yield += 1
162181
output_queue.task_done()
163182

164-
tasks = [
165-
asyncio.create_task(_worker()) for _ in range(max_concurrency)
166-
] + [asyncio.create_task(_feeder())]
183+
admin_tasks = {
184+
asyncio.create_task(_worker()),
185+
asyncio.create_task(_feeder())
186+
}
167187

168188
try:
169189
async for result in _re_orderer():
170190
yield result
171191
finally:
172-
for task in tasks:
192+
for task in admin_tasks:
173193
task.cancel()
174194

175-
await asyncio.gather(*tasks, return_exceptions=True)
195+
await asyncio.gather(*admin_tasks, return_exceptions=True)
176196

177197

178198
Pipeline = pipeline

0 commit comments

Comments
 (0)