Skip to content

Commit abc48fe

Browse files
committed
Add max_concurrency feature to ExtraTaskGroup
1 parent a9fb701 commit abc48fe

File tree

4 files changed

+68
-7
lines changed

4 files changed

+68
-7
lines changed

README.md

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,17 @@ any exception is raised in any of the taskgroup tasks,
4444
all sibling incomplete tasks get cancelled immediatelly.
4545

4646
With ExtraTaskGroup, all created tasks are run to completion,
47-
and any exceptions are bubbled up as ExceptionGroups on
48-
the host task.
47+
by default and any exceptions are bubbled up as ExceptionGroups on
48+
the host task. If the classic "cancel all other tasks" behavior
49+
is desired, the named argument `default_abort=True` can be
50+
passed when the group is created.
51+
52+
ExtraTaskGroup will also optionally limit the number
53+
of concurrent tasks that are executed. If `max_concurrency` is given,
54+
tasks created with an inner `.create_task` call will be
55+
bounded by a Semaphore, meaning at most "max_concurrency" tasks
56+
will be running at the same time, and others will start running as the first
57+
ones are completed.
4958

5059
```python
5160
import asyncio
@@ -70,6 +79,9 @@ asyncio.run(main())
7079

7180
```
7281

82+
new in 0.4: the max_concurrency and default_abort paramters
83+
84+
7385
sync_to_async
7486
----------------------
7587
Allows calling an async function from a synchronous context.

extraasync/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .sync_async_bridge import sync_to_async, async_to_sync
44
from .async_hooks import at_loop_stop_callback, remove_loop_stop_callback
55

6-
__version__ = "0.3.0"
6+
__version__ = "0.4.0"
77

88

99
__all__ = ["aenumerate", "ExtraTaskGroup", "sync_to_async", "async_to_sync", "at_loop_stop_callback", "remove_loop_stop_callback", "__version__"]

extraasync/taskgroup.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
import asyncio
12
from asyncio import TaskGroup
23

4+
from typing import Optional
5+
6+
37
# Idea originally developed for an answer on StackOverflow
48
# at: https://stackoverflow.com/questions/75250788/how-to-prevent-python3-11-taskgroup-from-canceling-all-the-tasks/75261668#75261668
59

@@ -17,7 +21,7 @@
1721

1822

1923
class ExtraTaskGroup(TaskGroup):
20-
def __init__(self, *, default_abort: bool = False):
24+
def __init__(self, *, max_concurrency: Optional[int] = None, default_abort: bool = False):
2125
"""A subclass of asyncio.TaskGroup
2226
2327
By default, the different behavior is that if a
@@ -33,16 +37,34 @@ def __init__(self, *, default_abort: bool = False):
3337
single exception.
3438
3539
Args:
40+
max_concurrency: If given, the maximum number of active spawned tasks permited at each time.
41+
The asynchronous "acreate_task" method should be used then, and calling "create_task" will raise
42+
an error.
43+
Defaults to None. implemented internally as using a semaphore.
44+
3645
default_abort: if True, allows the default asyncio.TaskGroup behavior
37-
or aborting all other running tasks when the first one raises an exception.
46+
or aborting all other running tasks when the first one raises an exception.
3847
3948
4049
"""
41-
self.default_abort = default_abort
50+
self.__max_concurrency = max_concurrency
51+
if max_concurrency != None:
52+
self.__semaphore = asyncio.Semaphore(max_concurrency)
53+
self.__default_abort = default_abort
4254
super().__init__()
4355

56+
57+
def create_task(self, coro, *args, **kwargs):
58+
if self.__max_concurrency is None:
59+
return super().create_task(coro, *args, **kwargs)
60+
return super().create_task(self.__managed_task(coro), *args, **kwargs)
61+
62+
async def __managed_task(self, coro):
63+
async with self.__semaphore:
64+
return await coro
65+
4466
def _abort(self):
45-
if self.default_abort:
67+
if self.__default_abort:
4668
return super()._abort()
4769
return None
4870

tests/test_taskgroup.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,30 @@ class Dummy:
6666

6767

6868

69+
@pytest.mark.parametrize(
70+
["max_tasks", "max_concurrency"],[
71+
(2, 1),
72+
(5, 1),
73+
(5, 2),
74+
(30, 10),
75+
])
76+
@pytest.mark.asyncio
77+
async def test_extrataskgroup_limits_concurrency(max_tasks, max_concurrency):
78+
running = 0
79+
max_running = 0
80+
async def blah():
81+
nonlocal max_running, running
82+
running += 1
83+
max_running = max(max_running, running)
84+
await asyncio.sleep(0.05)
85+
running -= 1
86+
return 23
87+
88+
tasks = set()
89+
90+
async with ExtraTaskGroup(max_concurrency=max_concurrency) as tg:
91+
for _ in range(max_tasks):
92+
tasks.add(tg.create_task(blah()))
93+
94+
assert all(task.result() == 23 for task in tasks)
95+
assert max_running == max_concurrency

0 commit comments

Comments
 (0)