44from __future__ import annotations
55
66import logging
7- from asyncio import iscoroutinefunction , run , ensure_future
7+ from asyncio import iscoroutinefunction , ensure_future , create_task , get_event_loop
88from collections import defaultdict
9+ from functools import partial
910from importlib import import_module
1011from inspect import isasyncgenfunction
1112from typing import (
1213 Any ,
1314 Dict ,
1415 List ,
1516 Optional ,
17+ Iterable ,
1618 Tuple ,
1719 Union ,
1820 overload ,
19- AsyncIterator ,
2021 TYPE_CHECKING
2122)
2223from . import __package__
2324from .commands import ChatCommandHandler
2425from .core import HTTPClient
25- from .core .gateway import Dispatcher
26+ from .core .gateway import GatewayInfo , Gateway
2627from .exceptions import (
2728 InvalidEventName ,
2829 TooManySetupArguments ,
4950from .utils .conversion import construct_client_dict , remove_none
5051from .utils .event_mgr import EventMgr
5152from .utils .extraction import get_index
52- from .utils .insertion import should_pass_cls
53+ from .utils .insertion import should_pass_cls , should_pass_gateway
5354from .utils .signature import get_params
5455from .utils .types import CheckFunction
5556from .utils .types import Coro
@@ -137,12 +138,10 @@ def decorator(func: Coro):
137138 "already been registered"
138139 )
139140
140- async def wrapper (cls , payload : GatewayDispatch ):
141+ async def wrapper (cls , gateway : Gateway , payload : GatewayDispatch ):
141142 _log .debug ("`%s` middleware has been invoked" , call )
142143
143- return await (
144- func (cls , payload ) if should_pass_cls (func ) else func (payload )
145- )
144+ return await func (cls , gateway , payload )
146145
147146 _events [call ] = wrapper
148147 return wrapper
@@ -154,7 +153,7 @@ async def wrapper(cls, payload: GatewayDispatch):
154153 event_middleware (event )(middleware_ )
155154
156155
157- class Client ( Dispatcher ) :
156+ class Client :
158157 """The client is the main instance which is between the programmer
159158 and the discord API.
160159
@@ -200,24 +199,27 @@ def __init__(
200199 if isinstance (intents , Iterable ):
201200 intents = sum (intents )
202201
203- super ().__init__ (
204- token ,
205- handlers = {
206- # Gets triggered on all events
207- - 1 : self .payload_event_handler ,
208- # Use this event handler for opcode 0.
209- 0 : self .event_handler ,
210- },
211- intents = intents or Intents .all (),
212- reconnect = reconnect ,
213- )
202+ if intents is None :
203+ intents = Intents .all ()
204+
205+ self .intents = intents
206+ self .reconnect = reconnect
207+ self .token = token
214208
215209 self .bot : Optional [User ] = None
216210 self .received_message = received or "Command arrived successfully!"
217211 self .http = HTTPClient (token )
218212 self .throttler = throttler
219213 self .event_mgr = EventMgr ()
220214
215+ async def get_gateway ():
216+ return GatewayInfo .from_dict (
217+ await self .http .get ("gateway/bot" )
218+ )
219+
220+ loop = get_event_loop ()
221+ self .gateway : GatewayInfo = loop .run_until_complete (get_gateway ())
222+
221223 # The guild and channel value is only registered if the Client has the GUILDS
222224 # intent.
223225 self .guilds : Dict [Snowflake , Optional [Guild ]] = {}
@@ -236,7 +238,6 @@ def chat_commands(self) -> List[str]:
236238 cmd .app .name for cmd in ChatCommandHandler .register .values ()
237239 ]
238240
239-
240241 @property
241242 def guild_ids (self ) -> List [Snowflake ]:
242243 """
@@ -341,7 +342,6 @@ def get_event_coro(name: str) -> List[Optional[Coro]]:
341342 ]
342343 )
343344
344-
345345 def load_cog (self , path : str , package : Optional [str ] = None ):
346346 """Load a cog from a string path, setup method in COG may
347347 optionally have a first argument which will contain the client!
@@ -461,7 +461,7 @@ async def unload_cog(self, path: str):
461461 await ChatCommandHandler (self ).remove_commands (to_remove )
462462
463463 @staticmethod
464- def execute_event (calls : List [Coro ], * args , ** kwargs ):
464+ def execute_event (calls : List [Coro ], gateway : Gateway , * args , ** kwargs ):
465465 """Invokes an event.
466466
467467 Parameters
@@ -484,19 +484,86 @@ def execute_event(calls: List[Coro], *args, **kwargs):
484484 * remove_none (args ),
485485 )
486486
487+ if should_pass_gateway (call ):
488+ call_args = (call_args [0 ], gateway , * call_args [1 :])
489+
487490 ensure_future (call (* call_args , ** kwargs ))
488491
489492 def run (self ):
490- """Start the event listener."""
491- self .start_loop ()
493+ """Start the bot."""
494+ loop = get_event_loop ()
495+ ensure_future (self .start_shard (0 , 1 ), loop = loop )
496+ loop .run_forever ()
497+
498+ def run_autosharded (self ):
499+ """
500+ Runs the bot with the amount of shards specified by the Discord gateway.
501+ """
502+ num_shards = self .gateway .shards
503+ return self .run_shards (range (num_shards ), num_shards )
504+
505+ def run_shards (self , shards : Iterable , num_shards : int ):
506+ """
507+ Runs shards that you specify.
508+
509+ shards: Iterable
510+ The shards to run.
511+ num_shards: int
512+ The total amount of shards.
513+ """
514+ loop = get_event_loop ()
515+
516+ for shard in shards :
517+ ensure_future (self .start_shard (shard , num_shards ), loop = loop )
518+
519+ loop .run_forever ()
520+
521+ async def start_shard (
522+ self ,
523+ shard : int ,
524+ num_shards : int
525+ ):
526+ """|coro|
527+ Starts a shard
528+ This should not be run most of the time. ``run_shards`` and ``run_autosharded``
529+ will likely do what you want.
530+
531+ shard : int
532+ The number of the shard to start.
533+ num_shards : int
534+ The total number of shards.
535+ """
536+
537+ gateway = Gateway (
538+ self .token ,
539+ intents = self .intents ,
540+ url = self .gateway .url ,
541+ shard = shard ,
542+ num_shards = num_shards
543+ )
544+ await gateway .init_session ()
545+
546+ gateway .append_handlers ({
547+ # Gets triggered on all events
548+ - 1 : partial (self .payload_event_handler , gateway ),
549+ # Use this event handler for opcode 0.
550+ 0 : partial (self .event_handler , gateway )
551+ })
552+
553+ create_task (gateway .start_loop ())
492554
493555 def __del__ (self ):
494556 """Ensure close of the http client."""
495557 if hasattr (self , "http" ):
496- run (self .http .close ())
558+ create_task (self .http .close ())
497559
498560 async def handle_middleware (
499- self , payload : GatewayDispatch , key : str , * args , ** kwargs
561+ self ,
562+ payload : GatewayDispatch ,
563+ key : str ,
564+ gateway : Gateway ,
565+ * args ,
566+ ** kwargs
500567 ) -> Tuple [Optional [Coro ], List [Any ], Dict [str , Any ]]:
501568 """|coro|
502569
@@ -527,7 +594,7 @@ async def handle_middleware(
527594 next_call , arguments , params = ware , [], {}
528595
529596 if iscoroutinefunction (ware ):
530- extractable = await ware (self , payload , * args , ** kwargs )
597+ extractable = await ware (self , gateway , payload , * args , ** kwargs )
531598
532599 if not isinstance (extractable , tuple ):
533600 raise RuntimeError (
@@ -544,11 +611,16 @@ async def handle_middleware(
544611 return (next_call , ret_object )
545612
546613 return await self .handle_middleware (
547- payload , next_call , * arguments , ** params
614+ payload , next_call , gateway , * arguments , ** params
548615 )
549616
550617 async def execute_error (
551- self , error : Exception , name : str = "on_error" , * args , ** kwargs
618+ self ,
619+ error : Exception ,
620+ gateway : Gateway ,
621+ name : str = "on_error" ,
622+ * args ,
623+ ** kwargs
552624 ):
553625 """|coro|
554626
@@ -567,11 +639,16 @@ async def execute_error(
567639 if ``call := self.get_event_coro(name)`` is :data:`False`
568640 """
569641 if calls := self .get_event_coro (name ):
570- self .execute_event (calls , error , * args , ** kwargs )
642+ self .execute_event (calls , gateway , error , * args , ** kwargs )
571643 else :
572644 raise error
573645
574- async def process_event (self , name : str , payload : GatewayDispatch ):
646+ async def process_event (
647+ self ,
648+ name : str ,
649+ payload : GatewayDispatch ,
650+ gateway : Gateway
651+ ):
575652 """|coro|
576653
577654 Processes and invokes an event and its middleware
@@ -587,16 +664,20 @@ async def process_event(self, name: str, payload: GatewayDispatch):
587664 what specifically happened.
588665 """
589666 try :
590- key , args = await self .handle_middleware (payload , name )
667+ key , args = await self .handle_middleware (payload , name , gateway )
591668 self .event_mgr .process_events (key , args )
592669
593670 if calls := self .get_event_coro (key ):
594- self .execute_event (calls , args )
671+ self .execute_event (calls , gateway , args )
595672
596673 except Exception as e :
597- await self .execute_error (e )
674+ await self .execute_error (e , gateway )
598675
599- async def event_handler (self , _ , payload : GatewayDispatch ):
676+ async def event_handler (
677+ self ,
678+ gateway : Gateway ,
679+ payload : GatewayDispatch
680+ ):
600681 """|coro|
601682
602683 Handles all payload events with opcode 0.
@@ -611,9 +692,13 @@ async def event_handler(self, _, payload: GatewayDispatch):
611692 required data for the client to know what event it is and
612693 what specifically happened.
613694 """
614- await self .process_event (payload .event_name .lower (), payload )
695+ await self .process_event (payload .event_name .lower (), payload , gateway )
615696
616- async def payload_event_handler (self , _ , payload : GatewayDispatch ):
697+ async def payload_event_handler (
698+ self ,
699+ gateway : Gateway ,
700+ payload : GatewayDispatch
701+ ):
617702 """|coro|
618703
619704 Special event which activates the on_payload event.
@@ -628,7 +713,7 @@ async def payload_event_handler(self, _, payload: GatewayDispatch):
628713 required data for the client to know what event it is and
629714 what specifically happened.
630715 """
631- await self .process_event ("payload" , payload )
716+ await self .process_event ("payload" , payload , gateway )
632717
633718 @overload
634719 async def create_guild (
@@ -907,7 +992,6 @@ async def get_webhook(
907992 """
908993 return await Webhook .from_id (self , id , token )
909994
910-
911995 async def get_current_user (self ) -> User :
912996 """|coro|
913997 The user object of the requester's account.
@@ -923,7 +1007,7 @@ async def get_current_user(self) -> User:
9231007 """
9241008 return User .from_dict (
9251009 construct_client_dict (
926- self ,
1010+ self ,
9271011 await self .http .get ("users/@me" )
9281012 )
9291013 )
0 commit comments