Skip to content
This repository was archived by the owner on Dec 26, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 44 additions & 12 deletions pincer/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import logging
import signal
from asyncio import (
iscoroutinefunction,
ensure_future,
Expand Down Expand Up @@ -201,6 +202,17 @@ def __init__(
throttler: ThrottleInterface = DefaultThrottleHandler,
reconnect: bool = True,
):
def sigint_handler(_signal, _frame):
_log.info("SIGINT received, shutting down...")

# A print statement to make sure the user sees the message
print("Closing the client loop, this can take a few seconds...")

create_task(self.http.close())
if self.loop.is_running():
self.loop.stop()

signal.signal(signal.SIGINT, sigint_handler)

if isinstance(intents, Iterable):
intents = sum(intents)
Expand All @@ -218,13 +230,14 @@ def __init__(
APIObject.link(self)

self.throttler = throttler
self.event_mgr = EventMgr()

async def get_gateway():
return GatewayInfo.from_dict(await self.http.get("gateway/bot"))

loop = get_event_loop()
self.gateway: GatewayInfo = loop.run_until_complete(get_gateway())
self.loop = get_event_loop()
self.event_mgr = EventMgr(self.loop)

self.gateway: GatewayInfo = self.loop.run_until_complete(get_gateway())

# The guild and channel value is only registered if the Client has the GUILDS
# intent.
Expand Down Expand Up @@ -495,9 +508,8 @@ def execute_event(calls: List[Coro], gateway: Gateway, *args, **kwargs):

def run(self):
"""Start the bot."""
loop = get_event_loop()
ensure_future(self.start_shard(0, 1), loop=loop)
loop.run_forever()
ensure_future(self.start_shard(0, 1), loop=self.loop)
self.loop.run_forever()

def run_autosharded(self):
"""
Expand All @@ -515,12 +527,10 @@ def run_shards(self, shards: Iterable, num_shards: int):
num_shards: int
The total amount of shards.
"""
loop = get_event_loop()

for shard in shards:
ensure_future(self.start_shard(shard, num_shards), loop=loop)
ensure_future(self.start_shard(shard, num_shards), loop=self.loop)

loop.run_forever()
self.loop.run_forever()

async def start_shard(self, shard: int, num_shards: int):
"""|coro|
Expand Down Expand Up @@ -554,11 +564,33 @@ async def start_shard(self, shard: int, num_shards: int):

create_task(gateway.start_loop())

def __del__(self):
"""Ensure close of the http client."""
@property
def is_closed(self) -> bool:
"""
Returns
-------
bool
Whether the bot is closed.
"""
return self.loop.is_running()

def close(self):
"""
Ensure close of the http client.
Allow for script execution to continue.
"""
if hasattr(self, "http"):
create_task(self.http.close())

self.loop.stop()

def __del__(self):
if self.loop.is_running():
self.loop.stop()

if not self.loop.is_closed():
self.close()

async def handle_middleware(
self,
payload: GatewayDispatch,
Expand Down
28 changes: 9 additions & 19 deletions pincer/utils/event_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from asyncio import Event, wait_for as _wait_for, get_running_loop, TimeoutError
from asyncio import Event, wait_for as _wait_for, TimeoutError
from collections import deque
from typing import TYPE_CHECKING

from ..exceptions import TimeoutError as PincerTimeoutError

if TYPE_CHECKING:
from asyncio import AbstractEventLoop
from typing import Any, List, Union, Optional
from .types import CheckFunction


class _Processable(ABC):

@abstractmethod
def process(self, event_name: str, event_value: Any):
"""
Expand Down Expand Up @@ -89,11 +89,7 @@ class _Event(_Processable):
returned later.
"""

def __init__(
self,
event_name: str,
check: CheckFunction
):
def __init__(self, event_name: str, check: CheckFunction):
self.event_name = event_name
self.check = check
self.event = Event()
Expand Down Expand Up @@ -197,8 +193,9 @@ class EventMgr:
The List of events that need to be processed.
"""

def __init__(self):
def __init__(self, loop: AbstractEventLoop):
self.event_list: List[_Processable] = []
self.loop = loop

def process_events(self, event_name, event_value):
"""
Expand All @@ -213,10 +210,7 @@ def process_events(self, event_name, event_value):
event.process(event_name, event_value)

async def wait_for(
self,
event_name: str,
check: CheckFunction,
timeout: Optional[float]
self, event_name: str, check: CheckFunction, timeout: Optional[float]
) -> Any:
"""
Parameters
Expand Down Expand Up @@ -277,17 +271,13 @@ async def loop_for(
loop_mgr = _LoopMgr(event_name, check)
self.event_list.append(loop_mgr)

loop = get_running_loop()

while True:
start_time = loop.time()
start_time = self.loop.time()

try:
yield await _wait_for(
loop_mgr.get_next(),
timeout=_lowest_value(
loop_timeout, iteration_timeout
)
timeout=_lowest_value(loop_timeout, iteration_timeout),
)

except TimeoutError:
Expand All @@ -305,7 +295,7 @@ async def loop_for(
# `not` can't be used here because there is a check for
# `loop_timeout == 0`
if loop_timeout is not None:
loop_timeout -= loop.time() - start_time
loop_timeout -= self.loop.time() - start_time

# loop_timeout can be below 0 if the user's code in the for loop
# takes longer than the time left in loop_timeout
Expand Down