Skip to content
This repository was archived by the owner on Dec 26, 2022. It is now read-only.

Commit 72993cb

Browse files
authored
🔀 Merge pull request #329 from Lunarmagpie/sharding
🐛 ✨ Added VERY basic sharding support and fixed the gateway
2 parents 6a4b2b5 + 05f7e3d commit 72993cb

File tree

78 files changed

+1605
-679
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

78 files changed

+1605
-679
lines changed

docs/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
1818

1919
# <img src="../assets/svg/pincer.svg" height="24px" alt="Pincer Logo"> Pincer
20-
The snappy asynchronous discord api wrapper API wrapper written with aiohttp & websockets.
20+
The snappy asynchronous discord api wrapper API wrapper written with aiohttp.
2121

2222
| :exclamation: | The package is currently within Pre-Alpha phase |
2323
| ------------- | :---------------------------------------------- |

docs/api/pincer.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ Exceptions
8989

9090
.. autoexception:: UnavailableGuildError()
9191

92+
.. autoexception:: TimeoutError()
93+
94+
.. autoexception:: GatewayConnectionError()
95+
9296
.. autoexception:: HTTPError()
9397

9498
.. autoexception:: NotModifiedError()
@@ -149,6 +153,8 @@ Exception Hierarchy
149153
- :exc:`InvalidTokenError`
150154
- :exc:`HeartbeatError`
151155
- :exc:`UnavailableGuildError`
156+
- :exc:`TimeoutError`
157+
- :exc:`GatewayConnectionError`
152158
- :exc:`HTTPError`
153159
- :exc:`NotModifiedError`
154160
- :exc:`BadRequestError`

pincer/_config.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,23 @@
88
@dataclass(repr=False)
99
class GatewayConfig:
1010
"""This file is to make maintaining the library and its gateway
11-
configuration easier.
11+
configuration easier. Leave compression blank for no compression.
1212
"""
13-
socket_base_url: str = "wss://gateway.discord.gg/"
13+
MAX_RETRIES: int = 5
1414
version: int = 9
1515
encoding: str = "json"
16-
compression: Optional[str] = "zlib-stream"
16+
compression: str = "zlib-stream"
1717

1818
@classmethod
19-
def uri(cls) -> str:
19+
def make_uri(cls, uri) -> str:
2020
"""
2121
Returns
2222
-------
2323
:class:`str`:
2424
The GatewayConfig's uri.
2525
"""
2626
return (
27-
f"{cls.socket_base_url}"
27+
f"{uri}"
2828
f"?v={cls.version}"
2929
f"&encoding={cls.encoding}"
3030
) + f"&compress={cls.compression}" * cls.compressed()

pincer/client.py

Lines changed: 125 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,26 @@
44
from __future__ import annotations
55

66
import logging
7-
from asyncio import iscoroutinefunction, run, ensure_future
7+
from asyncio import iscoroutinefunction, ensure_future, create_task, get_event_loop
88
from collections import defaultdict
9+
from functools import partial
910
from importlib import import_module
1011
from inspect import isasyncgenfunction
1112
from typing import (
1213
Any,
1314
Dict,
1415
List,
1516
Optional,
17+
Iterable,
1618
Tuple,
1719
Union,
1820
overload,
19-
AsyncIterator,
2021
TYPE_CHECKING
2122
)
2223
from . import __package__
2324
from .commands import ChatCommandHandler
2425
from .core import HTTPClient
25-
from .core.gateway import Dispatcher
26+
from .core.gateway import GatewayInfo, Gateway
2627
from .exceptions import (
2728
InvalidEventName,
2829
TooManySetupArguments,
@@ -49,7 +50,7 @@
4950
from .utils.conversion import construct_client_dict, remove_none
5051
from .utils.event_mgr import EventMgr
5152
from .utils.extraction import get_index
52-
from .utils.insertion import should_pass_cls
53+
from .utils.insertion import should_pass_cls, should_pass_gateway
5354
from .utils.signature import get_params
5455
from .utils.types import CheckFunction
5556
from .utils.types import Coro
@@ -137,12 +138,10 @@ def decorator(func: Coro):
137138
"already been registered"
138139
)
139140

140-
async def wrapper(cls, payload: GatewayDispatch):
141+
async def wrapper(cls, gateway: Gateway, payload: GatewayDispatch):
141142
_log.debug("`%s` middleware has been invoked", call)
142143

143-
return await (
144-
func(cls, payload) if should_pass_cls(func) else func(payload)
145-
)
144+
return await func(cls, gateway, payload)
146145

147146
_events[call] = wrapper
148147
return wrapper
@@ -154,7 +153,7 @@ async def wrapper(cls, payload: GatewayDispatch):
154153
event_middleware(event)(middleware_)
155154

156155

157-
class Client(Dispatcher):
156+
class Client:
158157
"""The client is the main instance which is between the programmer
159158
and the discord API.
160159
@@ -200,24 +199,27 @@ def __init__(
200199
if isinstance(intents, Iterable):
201200
intents = sum(intents)
202201

203-
super().__init__(
204-
token,
205-
handlers={
206-
# Gets triggered on all events
207-
-1: self.payload_event_handler,
208-
# Use this event handler for opcode 0.
209-
0: self.event_handler,
210-
},
211-
intents=intents or Intents.all(),
212-
reconnect=reconnect,
213-
)
202+
if intents is None:
203+
intents = Intents.all()
204+
205+
self.intents = intents
206+
self.reconnect = reconnect
207+
self.token = token
214208

215209
self.bot: Optional[User] = None
216210
self.received_message = received or "Command arrived successfully!"
217211
self.http = HTTPClient(token)
218212
self.throttler = throttler
219213
self.event_mgr = EventMgr()
220214

215+
async def get_gateway():
216+
return GatewayInfo.from_dict(
217+
await self.http.get("gateway/bot")
218+
)
219+
220+
loop = get_event_loop()
221+
self.gateway: GatewayInfo = loop.run_until_complete(get_gateway())
222+
221223
# The guild and channel value is only registered if the Client has the GUILDS
222224
# intent.
223225
self.guilds: Dict[Snowflake, Optional[Guild]] = {}
@@ -236,7 +238,6 @@ def chat_commands(self) -> List[str]:
236238
cmd.app.name for cmd in ChatCommandHandler.register.values()
237239
]
238240

239-
240241
@property
241242
def guild_ids(self) -> List[Snowflake]:
242243
"""
@@ -341,7 +342,6 @@ def get_event_coro(name: str) -> List[Optional[Coro]]:
341342
]
342343
)
343344

344-
345345
def load_cog(self, path: str, package: Optional[str] = None):
346346
"""Load a cog from a string path, setup method in COG may
347347
optionally have a first argument which will contain the client!
@@ -461,7 +461,7 @@ async def unload_cog(self, path: str):
461461
await ChatCommandHandler(self).remove_commands(to_remove)
462462

463463
@staticmethod
464-
def execute_event(calls: List[Coro], *args, **kwargs):
464+
def execute_event(calls: List[Coro], gateway: Gateway, *args, **kwargs):
465465
"""Invokes an event.
466466
467467
Parameters
@@ -484,19 +484,86 @@ def execute_event(calls: List[Coro], *args, **kwargs):
484484
*remove_none(args),
485485
)
486486

487+
if should_pass_gateway(call):
488+
call_args = (call_args[0], gateway, *call_args[1:])
489+
487490
ensure_future(call(*call_args, **kwargs))
488491

489492
def run(self):
490-
"""Start the event listener."""
491-
self.start_loop()
493+
"""Start the bot."""
494+
loop = get_event_loop()
495+
ensure_future(self.start_shard(0, 1), loop=loop)
496+
loop.run_forever()
497+
498+
def run_autosharded(self):
499+
"""
500+
Runs the bot with the amount of shards specified by the Discord gateway.
501+
"""
502+
num_shards = self.gateway.shards
503+
return self.run_shards(range(num_shards), num_shards)
504+
505+
def run_shards(self, shards: Iterable, num_shards: int):
506+
"""
507+
Runs shards that you specify.
508+
509+
shards: Iterable
510+
The shards to run.
511+
num_shards: int
512+
The total amount of shards.
513+
"""
514+
loop = get_event_loop()
515+
516+
for shard in shards:
517+
ensure_future(self.start_shard(shard, num_shards), loop=loop)
518+
519+
loop.run_forever()
520+
521+
async def start_shard(
522+
self,
523+
shard: int,
524+
num_shards: int
525+
):
526+
"""|coro|
527+
Starts a shard
528+
This should not be run most of the time. ``run_shards`` and ``run_autosharded``
529+
will likely do what you want.
530+
531+
shard : int
532+
The number of the shard to start.
533+
num_shards : int
534+
The total number of shards.
535+
"""
536+
537+
gateway = Gateway(
538+
self.token,
539+
intents=self.intents,
540+
url=self.gateway.url,
541+
shard=shard,
542+
num_shards=num_shards
543+
)
544+
await gateway.init_session()
545+
546+
gateway.append_handlers({
547+
# Gets triggered on all events
548+
-1: partial(self.payload_event_handler, gateway),
549+
# Use this event handler for opcode 0.
550+
0: partial(self.event_handler, gateway)
551+
})
552+
553+
create_task(gateway.start_loop())
492554

493555
def __del__(self):
494556
"""Ensure close of the http client."""
495557
if hasattr(self, "http"):
496-
run(self.http.close())
558+
create_task(self.http.close())
497559

498560
async def handle_middleware(
499-
self, payload: GatewayDispatch, key: str, *args, **kwargs
561+
self,
562+
payload: GatewayDispatch,
563+
key: str,
564+
gateway: Gateway,
565+
*args,
566+
**kwargs
500567
) -> Tuple[Optional[Coro], List[Any], Dict[str, Any]]:
501568
"""|coro|
502569
@@ -527,7 +594,7 @@ async def handle_middleware(
527594
next_call, arguments, params = ware, [], {}
528595

529596
if iscoroutinefunction(ware):
530-
extractable = await ware(self, payload, *args, **kwargs)
597+
extractable = await ware(self, gateway, payload, *args, **kwargs)
531598

532599
if not isinstance(extractable, tuple):
533600
raise RuntimeError(
@@ -544,11 +611,16 @@ async def handle_middleware(
544611
return (next_call, ret_object)
545612

546613
return await self.handle_middleware(
547-
payload, next_call, *arguments, **params
614+
payload, next_call, gateway, *arguments, **params
548615
)
549616

550617
async def execute_error(
551-
self, error: Exception, name: str = "on_error", *args, **kwargs
618+
self,
619+
error: Exception,
620+
gateway: Gateway,
621+
name: str = "on_error",
622+
*args,
623+
**kwargs
552624
):
553625
"""|coro|
554626
@@ -567,11 +639,16 @@ async def execute_error(
567639
if ``call := self.get_event_coro(name)`` is :data:`False`
568640
"""
569641
if calls := self.get_event_coro(name):
570-
self.execute_event(calls, error, *args, **kwargs)
642+
self.execute_event(calls, gateway, error, *args, **kwargs)
571643
else:
572644
raise error
573645

574-
async def process_event(self, name: str, payload: GatewayDispatch):
646+
async def process_event(
647+
self,
648+
name: str,
649+
payload: GatewayDispatch,
650+
gateway: Gateway
651+
):
575652
"""|coro|
576653
577654
Processes and invokes an event and its middleware
@@ -587,16 +664,20 @@ async def process_event(self, name: str, payload: GatewayDispatch):
587664
what specifically happened.
588665
"""
589666
try:
590-
key, args = await self.handle_middleware(payload, name)
667+
key, args = await self.handle_middleware(payload, name, gateway)
591668
self.event_mgr.process_events(key, args)
592669

593670
if calls := self.get_event_coro(key):
594-
self.execute_event(calls, args)
671+
self.execute_event(calls, gateway, args)
595672

596673
except Exception as e:
597-
await self.execute_error(e)
674+
await self.execute_error(e, gateway)
598675

599-
async def event_handler(self, _, payload: GatewayDispatch):
676+
async def event_handler(
677+
self,
678+
gateway: Gateway,
679+
payload: GatewayDispatch
680+
):
600681
"""|coro|
601682
602683
Handles all payload events with opcode 0.
@@ -611,9 +692,13 @@ async def event_handler(self, _, payload: GatewayDispatch):
611692
required data for the client to know what event it is and
612693
what specifically happened.
613694
"""
614-
await self.process_event(payload.event_name.lower(), payload)
695+
await self.process_event(payload.event_name.lower(), payload, gateway)
615696

616-
async def payload_event_handler(self, _, payload: GatewayDispatch):
697+
async def payload_event_handler(
698+
self,
699+
gateway: Gateway,
700+
payload: GatewayDispatch
701+
):
617702
"""|coro|
618703
619704
Special event which activates the on_payload event.
@@ -628,7 +713,7 @@ async def payload_event_handler(self, _, payload: GatewayDispatch):
628713
required data for the client to know what event it is and
629714
what specifically happened.
630715
"""
631-
await self.process_event("payload", payload)
716+
await self.process_event("payload", payload, gateway)
632717

633718
@overload
634719
async def create_guild(
@@ -907,7 +992,6 @@ async def get_webhook(
907992
"""
908993
return await Webhook.from_id(self, id, token)
909994

910-
911995
async def get_current_user(self) -> User:
912996
"""|coro|
913997
The user object of the requester's account.
@@ -923,7 +1007,7 @@ async def get_current_user(self) -> User:
9231007
"""
9241008
return User.from_dict(
9251009
construct_client_dict(
926-
self,
1010+
self,
9271011
await self.http.get("users/@me")
9281012
)
9291013
)

0 commit comments

Comments
 (0)