Skip to content
This repository was archived by the owner on Dec 26, 2022. It is now read-only.

Commit 58dc68d

Browse files
SigmanificientEnderchieftrag1c
authored
✨ Adding close, is_closed and saving loop (#371)
* ✨ Storing loop * ✨ close & is_closed * 📝 improving documentation * ✨ Adding proper Sig INT handling * 🎨 reformat code * Update pincer/client.py Co-authored-by: Endercheif <[email protected]> * Update pincer/client.py Co-authored-by: trag1c <[email protected]> Co-authored-by: Endercheif <[email protected]> Co-authored-by: trag1c <[email protected]>
1 parent 7f23a3d commit 58dc68d

File tree

2 files changed

+53
-31
lines changed

2 files changed

+53
-31
lines changed

pincer/client.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
import logging
7+
import signal
78
from asyncio import (
89
iscoroutinefunction,
910
ensure_future,
@@ -201,6 +202,17 @@ def __init__(
201202
throttler: ThrottleInterface = DefaultThrottleHandler,
202203
reconnect: bool = True,
203204
):
205+
def sigint_handler(_signal, _frame):
206+
_log.info("SIGINT received, shutting down...")
207+
208+
# A print statement to make sure the user sees the message
209+
print("Closing the client loop, this can take a few seconds...")
210+
211+
create_task(self.http.close())
212+
if self.loop.is_running():
213+
self.loop.stop()
214+
215+
signal.signal(signal.SIGINT, sigint_handler)
204216

205217
if isinstance(intents, Iterable):
206218
intents = sum(intents)
@@ -218,13 +230,14 @@ def __init__(
218230
APIObject.link(self)
219231

220232
self.throttler = throttler
221-
self.event_mgr = EventMgr()
222233

223234
async def get_gateway():
224235
return GatewayInfo.from_dict(await self.http.get("gateway/bot"))
225236

226-
loop = get_event_loop()
227-
self.gateway: GatewayInfo = loop.run_until_complete(get_gateway())
237+
self.loop = get_event_loop()
238+
self.event_mgr = EventMgr(self.loop)
239+
240+
self.gateway: GatewayInfo = self.loop.run_until_complete(get_gateway())
228241

229242
# The guild and channel value is only registered if the Client has the GUILDS
230243
# intent.
@@ -495,9 +508,8 @@ def execute_event(calls: List[Coro], gateway: Gateway, *args, **kwargs):
495508

496509
def run(self):
497510
"""Start the bot."""
498-
loop = get_event_loop()
499-
ensure_future(self.start_shard(0, 1), loop=loop)
500-
loop.run_forever()
511+
ensure_future(self.start_shard(0, 1), loop=self.loop)
512+
self.loop.run_forever()
501513

502514
def run_autosharded(self):
503515
"""
@@ -515,12 +527,10 @@ def run_shards(self, shards: Iterable, num_shards: int):
515527
num_shards: int
516528
The total amount of shards.
517529
"""
518-
loop = get_event_loop()
519-
520530
for shard in shards:
521-
ensure_future(self.start_shard(shard, num_shards), loop=loop)
531+
ensure_future(self.start_shard(shard, num_shards), loop=self.loop)
522532

523-
loop.run_forever()
533+
self.loop.run_forever()
524534

525535
async def start_shard(self, shard: int, num_shards: int):
526536
"""|coro|
@@ -554,11 +564,33 @@ async def start_shard(self, shard: int, num_shards: int):
554564

555565
create_task(gateway.start_loop())
556566

557-
def __del__(self):
558-
"""Ensure close of the http client."""
567+
@property
568+
def is_closed(self) -> bool:
569+
"""
570+
Returns
571+
-------
572+
bool
573+
Whether the bot is closed.
574+
"""
575+
return self.loop.is_running()
576+
577+
def close(self):
578+
"""
579+
Ensure close of the http client.
580+
Allow for script execution to continue.
581+
"""
559582
if hasattr(self, "http"):
560583
create_task(self.http.close())
561584

585+
self.loop.stop()
586+
587+
def __del__(self):
588+
if self.loop.is_running():
589+
self.loop.stop()
590+
591+
if not self.loop.is_closed():
592+
self.close()
593+
562594
async def handle_middleware(
563595
self,
564596
payload: GatewayDispatch,

pincer/utils/event_mgr.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,19 @@
44
from __future__ import annotations
55

66
from abc import ABC, abstractmethod
7-
from asyncio import Event, wait_for as _wait_for, get_running_loop, TimeoutError
7+
from asyncio import Event, wait_for as _wait_for, TimeoutError
88
from collections import deque
99
from typing import TYPE_CHECKING
1010

1111
from ..exceptions import TimeoutError as PincerTimeoutError
1212

1313
if TYPE_CHECKING:
14+
from asyncio import AbstractEventLoop
1415
from typing import Any, List, Union, Optional
1516
from .types import CheckFunction
1617

1718

1819
class _Processable(ABC):
19-
2020
@abstractmethod
2121
def process(self, event_name: str, event_value: Any):
2222
"""
@@ -89,11 +89,7 @@ class _Event(_Processable):
8989
returned later.
9090
"""
9191

92-
def __init__(
93-
self,
94-
event_name: str,
95-
check: CheckFunction
96-
):
92+
def __init__(self, event_name: str, check: CheckFunction):
9793
self.event_name = event_name
9894
self.check = check
9995
self.event = Event()
@@ -194,8 +190,9 @@ class EventMgr:
194190
The List of events that need to be processed.
195191
"""
196192

197-
def __init__(self):
193+
def __init__(self, loop: AbstractEventLoop):
198194
self.event_list: List[_Processable] = []
195+
self.loop = loop
199196

200197
def process_events(self, event_name, event_value):
201198
"""
@@ -210,10 +207,7 @@ def process_events(self, event_name, event_value):
210207
event.process(event_name, event_value)
211208

212209
async def wait_for(
213-
self,
214-
event_name: str,
215-
check: CheckFunction,
216-
timeout: Optional[float]
210+
self, event_name: str, check: CheckFunction, timeout: Optional[float]
217211
) -> Any:
218212
"""
219213
Parameters
@@ -274,17 +268,13 @@ async def loop_for(
274268
loop_mgr = _LoopMgr(event_name, check)
275269
self.event_list.append(loop_mgr)
276270

277-
loop = get_running_loop()
278-
279271
while True:
280-
start_time = loop.time()
272+
start_time = self.loop.time()
281273

282274
try:
283275
yield await _wait_for(
284276
loop_mgr.get_next(),
285-
timeout=_lowest_value(
286-
loop_timeout, iteration_timeout
287-
)
277+
timeout=_lowest_value(loop_timeout, iteration_timeout),
288278
)
289279

290280
except TimeoutError:
@@ -302,7 +292,7 @@ async def loop_for(
302292
# `not` can't be used here because there is a check for
303293
# `loop_timeout == 0`
304294
if loop_timeout is not None:
305-
loop_timeout -= loop.time() - start_time
295+
loop_timeout -= self.loop.time() - start_time
306296

307297
# loop_timeout can be below 0 if the user's code in the for loop
308298
# takes longer than the time left in loop_timeout

0 commit comments

Comments
 (0)