Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions core/bot/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package bot

import (
"encoding/base64"
"net/http"
"strings"

"github.com/DisgoOrg/disgo/core"
Expand Down Expand Up @@ -66,11 +65,8 @@ func buildBot(token string, config Config) (*core.Bot, error) {
if config.RestClientConfig.Logger == nil {
config.RestClientConfig.Logger = config.Logger
}
if config.RestClientConfig.Headers == nil {
config.RestClientConfig.Headers = http.Header{}
}
if _, ok := config.RestClientConfig.Headers["Authorization"]; !ok {
config.RestClientConfig.Headers["Authorization"] = []string{discord.TokenTypeBot.Apply(token)}
if config.RestClientConfig.BotTokenFunc == nil {
config.RestClientConfig.BotTokenFunc = func() string {return bot.Token}
}
config.RestClient = rest.NewClient(config.RestClientConfig)
}
Expand Down
2 changes: 1 addition & 1 deletion oauth2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (c *Client) GetGuilds(session Session, opts ...rest.RequestOpt) ([]discord.
return nil, ErrMissingOAuth2Scope(discord.ApplicationScopeGuilds)
}

return c.OAuth2Service.GetCurrentUserGuilds(session.AccessToken(), opts...)
return c.OAuth2Service.GetCurrentUserGuilds(session.AccessToken(), "", "", 0, opts...)
}

// GetConnections returns the discord.Connection(s) the user has connected. This requires the discord.ApplicationScopeConnections scope in the session
Expand Down
48 changes: 32 additions & 16 deletions rest/oauth2_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ func NewOAuth2Service(restClient Client) OAuth2Service {
type OAuth2Service interface {
Service
GetBotApplicationInfo(opts ...RequestOpt) (*discord.Application, error)
GetAuthorizationInfo(opts ...RequestOpt) (*discord.AuthorizationInformation, error)

GetCurrentUserGuilds(token string, opts ...RequestOpt) ([]discord.OAuth2Guild, error)
GetCurrentUser(token string, opts ...RequestOpt) (*discord.OAuth2User, error)
GetCurrentUserConnections(token string, opts ...RequestOpt) ([]discord.Connection, error)
GetCurrentAuthorizationInfo(bearerToken string, opts ...RequestOpt) (*discord.AuthorizationInformation, error)
GetCurrentUser(bearerToken string, opts ...RequestOpt) (*discord.OAuth2User, error)
GetCurrentUserGuilds(bearerToken string, before discord.Snowflake, after discord.Snowflake, limit int, opts ...RequestOpt) ([]discord.OAuth2Guild, error)
GetCurrentUserConnections(bearerToken string, opts ...RequestOpt) ([]discord.Connection, error)

GetAccessToken(clientID discord.Snowflake, clientSecret string, code string, redirectURI string, opts ...RequestOpt) (*discord.AccessTokenExchange, error)
RefreshAccessToken(clientID discord.Snowflake, clientSecret string, refreshToken string, opts ...RequestOpt) (*discord.AccessTokenExchange, error)
Expand All @@ -33,6 +33,13 @@ type oAuth2ServiceImpl struct {
restClient Client
}

func (s *oAuth2ServiceImpl) botOrBearerToken(bearerToken string, opts []RequestOpt) []RequestOpt {
if bearerToken == "" {
return applyBotToken(s.RestClient().Config().BotTokenFunc(), opts)
}
return applyBearerToken(bearerToken, opts)
}

func (s *oAuth2ServiceImpl) RestClient() Client {
return s.restClient
}
Expand All @@ -47,48 +54,57 @@ func (s *oAuth2ServiceImpl) GetBotApplicationInfo(opts ...RequestOpt) (applicati
return
}

func (s *oAuth2ServiceImpl) GetAuthorizationInfo(opts ...RequestOpt) (info *discord.AuthorizationInformation, err error) {
func (s *oAuth2ServiceImpl) GetCurrentAuthorizationInfo(bearerToken string, opts ...RequestOpt) (info *discord.AuthorizationInformation, err error) {
var compiledRoute *route.CompiledAPIRoute
compiledRoute, err = route.GetAuthorizationInfo.Compile(nil)
if err != nil {
return
}
err = s.restClient.Do(compiledRoute, nil, &info, opts...)
err = s.restClient.Do(compiledRoute, nil, &info, applyBearerToken(bearerToken, opts)...)
return
}

func (s *oAuth2ServiceImpl) GetCurrentUserGuilds(token string, opts ...RequestOpt) (guilds []discord.OAuth2Guild, err error) {
queryParams := route.QueryValues{}

func (s *oAuth2ServiceImpl) GetCurrentUser(bearerToken string, opts ...RequestOpt) (user *discord.OAuth2User, err error) {
var compiledRoute *route.CompiledAPIRoute
compiledRoute, err = route.GetCurrentUserGuilds.Compile(queryParams)
compiledRoute, err = route.GetCurrentUser.Compile(nil)
if err != nil {
return
}

err = s.restClient.Do(compiledRoute, nil, &guilds, append(opts, WithHeader("authorization", discord.TokenTypeBearer.Apply(token)))...)
err = s.restClient.Do(compiledRoute, nil, &user, s.botOrBearerToken(bearerToken, opts)...)
return
}

func (s *oAuth2ServiceImpl) GetCurrentUser(token string, opts ...RequestOpt) (user *discord.OAuth2User, err error) {
func (s *oAuth2ServiceImpl) GetCurrentUserGuilds(bearerToken string, before discord.Snowflake, after discord.Snowflake, limit int, opts ...RequestOpt) (guilds []discord.OAuth2Guild, err error) {
queryParams := route.QueryValues{}
if before != "" {
queryParams["before"] = before
}
if after != "" {
queryParams["after"] = after
}
if limit != 0 {
queryParams["limit"] = limit
}

var compiledRoute *route.CompiledAPIRoute
compiledRoute, err = route.GetCurrentUser.Compile(nil)
compiledRoute, err = route.GetCurrentUserGuilds.Compile(queryParams)
if err != nil {
return
}

err = s.restClient.Do(compiledRoute, nil, &user, append(opts, WithHeader("authorization", discord.TokenTypeBearer.Apply(token)))...)
err = s.restClient.Do(compiledRoute, nil, &guilds, s.botOrBearerToken(bearerToken, opts)...)
return
}

func (s *oAuth2ServiceImpl) GetCurrentUserConnections(token string, opts ...RequestOpt) (connections []discord.Connection, err error) {
func (s *oAuth2ServiceImpl) GetCurrentUserConnections(bearerToken string, opts ...RequestOpt) (connections []discord.Connection, err error) {
var compiledRoute *route.CompiledAPIRoute
compiledRoute, err = route.GetCurrentUserConnections.Compile(nil)
if err != nil {
return
}

err = s.restClient.Do(compiledRoute, nil, &connections, append(opts, WithHeader("authorization", discord.TokenTypeBearer.Apply(token)))...)
err = s.restClient.Do(compiledRoute, nil, &connections, s.botOrBearerToken(bearerToken, opts)...)
return
}

Expand Down
26 changes: 17 additions & 9 deletions rest/rest_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type Client interface {
// Close closes the rest client and awaits all pending requests to finish. You can use a cancelling context to abort the waiting
Close(ctx context.Context) error

// Do makes a request to the given route and marshals the given interface{} as json and unmarshalls the response into the given interface
// Do makes a request to the given route.CompiledAPIRoute and marshals the given interface{} as json and unmarshalls the response into the given interface
Do(route *route.CompiledAPIRoute, rqBody interface{}, rsBody interface{}, opts ...RequestOpt) error
}

Expand Down Expand Up @@ -121,19 +121,15 @@ func (c *clientImpl) retry(cRoute *route.CompiledAPIRoute, rqBody interface{}, r
return err
}

// write all headers to the request
if headers := c.Config().Headers; headers != nil {
for key, values := range headers {
for _, value := range values {
rq.Header.Add(key, value)
}
}
}
rq.Header.Set("User-Agent", c.Config().UserAgent)
if contentType != "" {
rq.Header.Set("Content-Type", contentType)
}

if cRoute.APIRoute.NeedsAuth() {
opts = applyBotToken(c.Config().BotTokenFunc(), opts)
}

config := &RequestConfig{Request: rq}
config.Apply(opts)

Expand Down Expand Up @@ -209,3 +205,15 @@ func (c *clientImpl) retry(cRoute *route.CompiledAPIRoute, rqBody interface{}, r
func (c *clientImpl) Do(cRoute *route.CompiledAPIRoute, rqBody interface{}, rsBody interface{}, opts ...RequestOpt) error {
return c.retry(cRoute, rqBody, rsBody, 1, opts)
}

func applyToken(tokenType discord.TokenType, token string, opts []RequestOpt) []RequestOpt {
return append(opts, WithHeader("authorization", tokenType.Apply(token)))
}

func applyBotToken(botToken string, opts []RequestOpt) []RequestOpt {
return applyToken(discord.TokenTypeBot, botToken, opts)
}

func applyBearerToken(bearerToken string, opts []RequestOpt) []RequestOpt {
return applyToken(discord.TokenTypeBearer, bearerToken, opts)
}
12 changes: 6 additions & 6 deletions rest/rest_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rest
import (
"fmt"
"net/http"
"time"

"github.com/DisgoOrg/disgo/info"
"github.com/DisgoOrg/disgo/rest/rrate"
Expand All @@ -11,9 +12,8 @@ import (

// DefaultConfig is the configuration which is used by default
var DefaultConfig = Config{
HTTPClient: http.DefaultClient,
HTTPClient: &http.Client{Timeout: 20*time.Second},
RateLimiterConfig: &rrate.DefaultConfig,
Headers: http.Header{},
UserAgent: fmt.Sprintf("DiscordBot (%s, %s)", info.GitHub, info.Version),
}

Expand All @@ -23,7 +23,7 @@ type Config struct {
HTTPClient *http.Client
RateLimiter rrate.Limiter
RateLimiterConfig *rrate.Config
Headers http.Header
BotTokenFunc func() string
UserAgent string
}

Expand Down Expand Up @@ -80,11 +80,11 @@ func WithRateLimiterConfigOpts(opts ...rrate.ConfigOpt) ConfigOpt {
}
}

// WithHeaders adds a custom header to all requests
// WithBotTokenFunc sets the function to get the bot token
//goland:noinspection GoUnusedExportedFunction
func WithHeaders(headers http.Header) ConfigOpt {
func WithBotTokenFunc(botTokenFunc func() string) ConfigOpt {
return func(config *Config) {
config.Headers = headers
config.BotTokenFunc = botTokenFunc
}
}

Expand Down
15 changes: 15 additions & 0 deletions rest/route/api_route.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ import (

// NewAPIRoute generates a new discord api path struct
func NewAPIRoute(method Method, path string, queryParams ...string) *APIRoute {
return newAPIRoute(method, path, queryParams, true)
}

func NewAPIRouteNoAuth(method Method, path string, queryParams ...string) *APIRoute {
return newAPIRoute(method, path, queryParams, false)
}

func newAPIRoute(method Method, path string, queryParams []string, needsAuth bool) *APIRoute {
params := map[string]struct{}{}
for _, param := range queryParams {
params[param] = struct{}{}
Expand All @@ -19,6 +27,7 @@ func NewAPIRoute(method Method, path string, queryParams ...string) *APIRoute {
queryParams: params,
urlParamCount: countURLParams(path),
method: method,
needsAuth: needsAuth,
}
}

Expand All @@ -37,6 +46,7 @@ type APIRoute struct {
queryParams map[string]struct{}
urlParamCount int
method Method
needsAuth bool
}

// Compile returns a CompiledAPIRoute
Expand Down Expand Up @@ -89,6 +99,11 @@ func (r *APIRoute) Path() string {
return r.path
}

// NeedsAuth returns whether the route requires authentication
func (r *APIRoute) NeedsAuth() bool {
return r.needsAuth
}

// CompiledAPIRoute is APIRoute compiled with all URL args
type CompiledAPIRoute struct {
APIRoute *APIRoute
Expand Down
40 changes: 20 additions & 20 deletions rest/route/route_endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@ var (
// OAuth2
var (
GetBotApplicationInfo = NewAPIRoute(GET, "/oauth2/applications/@me")
GetAuthorizationInfo = NewAPIRoute(GET, "/oauth2/@me")
GetAuthorizationInfo = NewAPIRouteNoAuth(GET, "/oauth2/@me")
Authorize = NewRoute("/oauth2/authorize", "client_id", "permissions", "redirect_uri", "response_type", "scope", "state", "guild_id", "disable_guild_select")
Token = NewAPIRoute(POST, "/oauth2/token")
Token = NewAPIRouteNoAuth(POST, "/oauth2/token")
)

// Users
var (
GetUser = NewAPIRoute(GET, "/users/{user.id}")
GetCurrentUser = NewAPIRoute(GET, "/users/@me")
UpdateSelfUser = NewAPIRoute(PATCH, "/users/@me")
GetCurrentUserConnections = NewAPIRoute(GET, "/users/@me/connections")
GetCurrentUserGuilds = NewAPIRoute(GET, "/users/@me/guilds", "before", "after", "limit")
GetCurrentUserConnections = NewAPIRouteNoAuth(GET, "/users/@me/connections")
GetCurrentUserGuilds = NewAPIRouteNoAuth(GET, "/users/@me/guilds", "before", "after", "limit")
LeaveGuild = NewAPIRoute(DELETE, "/users/@me/guilds/{guild.id}")
GetDMChannels = NewAPIRoute(GET, "/users/@me/channels")
CreateDMChannel = NewAPIRoute(POST, "/users/@me/channels")
Expand Down Expand Up @@ -209,15 +209,15 @@ var (
UpdateWebhook = NewAPIRoute(PATCH, "/webhooks/{webhook.id}")
DeleteWebhook = NewAPIRoute(DELETE, "/webhooks/{webhook.id}")

GetWebhookWithToken = NewAPIRoute(GET, "/webhooks/{webhook.id}/{webhook.token}")
UpdateWebhookWithToken = NewAPIRoute(PATCH, "/webhooks/{webhook.id}/{webhook.token}")
DeleteWebhookWithToken = NewAPIRoute(DELETE, "/webhooks/{webhook.id}/{webhook.token}")
GetWebhookWithToken = NewAPIRouteNoAuth(GET, "/webhooks/{webhook.id}/{webhook.token}")
UpdateWebhookWithToken = NewAPIRouteNoAuth(PATCH, "/webhooks/{webhook.id}/{webhook.token}")
DeleteWebhookWithToken = NewAPIRouteNoAuth(DELETE, "/webhooks/{webhook.id}/{webhook.token}")

CreateWebhookMessage = NewAPIRoute(POST, "/webhooks/{webhook.id}/{webhook.token}", "wait", "thread_id")
CreateWebhookMessageSlack = NewAPIRoute(POST, "/webhooks/{webhook.id}/{webhook.token}/slack", "wait")
CreateWebhookMessageGitHub = NewAPIRoute(POST, "/webhooks/{webhook.id}/{webhook.token}/github", "wait")
UpdateWebhookMessage = NewAPIRoute(PATCH, "/webhooks/{webhook.id}/{webhook.token}/messages/{message.id}")
DeleteWebhookMessage = NewAPIRoute(DELETE, "/webhooks/{webhook.id}/{webhook.token}/messages/{message.id}")
CreateWebhookMessage = NewAPIRouteNoAuth(POST, "/webhooks/{webhook.id}/{webhook.token}", "wait", "thread_id")
CreateWebhookMessageSlack = NewAPIRouteNoAuth(POST, "/webhooks/{webhook.id}/{webhook.token}/slack", "wait")
CreateWebhookMessageGitHub = NewAPIRouteNoAuth(POST, "/webhooks/{webhook.id}/{webhook.token}/github", "wait")
UpdateWebhookMessage = NewAPIRouteNoAuth(PATCH, "/webhooks/{webhook.id}/{webhook.token}/messages/{message.id}")
DeleteWebhookMessage = NewAPIRouteNoAuth(DELETE, "/webhooks/{webhook.id}/{webhook.token}/messages/{message.id}")
)

// Invites
Expand Down Expand Up @@ -251,15 +251,15 @@ var (
SetGuildCommandsPermissions = NewAPIRoute(PUT, "/applications/{application.id}/guilds/{guild.id}/commands/permissions")
SetGuildCommandPermissions = NewAPIRoute(PUT, "/applications/{application.id}/guilds/{guild.id}/commands/{command.id}/permissions")

GetInteractionResponse = NewAPIRoute(GET, "/webhooks/{application.id}/{interaction.token}/messages/@original")
CreateInteractionResponse = NewAPIRoute(POST, "/interactions/{interaction.id}/{interaction.token}/callback")
UpdateInteractionResponse = NewAPIRoute(PATCH, "/webhooks/{application.id}/{interaction.token}/messages/@original")
DeleteInteractionResponse = NewAPIRoute(DELETE, "/webhooks/{application.id}/{interaction.token}/messages/@original")
GetInteractionResponse = NewAPIRouteNoAuth(GET, "/webhooks/{application.id}/{interaction.token}/messages/@original")
CreateInteractionResponse = NewAPIRouteNoAuth(POST, "/interactions/{interaction.id}/{interaction.token}/callback")
UpdateInteractionResponse = NewAPIRouteNoAuth(PATCH, "/webhooks/{application.id}/{interaction.token}/messages/@original")
DeleteInteractionResponse = NewAPIRouteNoAuth(DELETE, "/webhooks/{application.id}/{interaction.token}/messages/@original")

GetFollowupMessage = NewAPIRoute(GET, "/webhooks/{application.id}/{interaction.token}")
CreateFollowupMessage = NewAPIRoute(POST, "/webhooks/{application.id}/{interaction.token}")
UpdateFollowupMessage = NewAPIRoute(PATCH, "/webhooks/{application.id}/{interaction.token}/messages/{message.id}")
DeleteFollowupMessage = NewAPIRoute(DELETE, "/webhooks/{application.id}/{interaction.token}/messages/{message.id}")
GetFollowupMessage = NewAPIRouteNoAuth(GET, "/webhooks/{application.id}/{interaction.token}")
CreateFollowupMessage = NewAPIRouteNoAuth(POST, "/webhooks/{application.id}/{interaction.token}")
UpdateFollowupMessage = NewAPIRouteNoAuth(PATCH, "/webhooks/{application.id}/{interaction.token}/messages/{message.id}")
DeleteFollowupMessage = NewAPIRouteNoAuth(DELETE, "/webhooks/{application.id}/{interaction.token}/messages/{message.id}")
)

// CDN
Expand Down
11 changes: 0 additions & 11 deletions rest/user_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ func NewUserService(restClient Client) UserService {
type UserService interface {
Service
GetUser(userID discord.Snowflake, opts ...RequestOpt) (*discord.User, error)
GetSelfUser(opts ...RequestOpt) (*discord.OAuth2User, error)
UpdateSelfUser(selfUserUpdate discord.SelfUserUpdate, opts ...RequestOpt) (*discord.OAuth2User, error)
GetGuilds(before int, after int, limit int, opts ...RequestOpt) ([]discord.OAuth2Guild, error)
LeaveGuild(guildID discord.Snowflake, opts ...RequestOpt) error
Expand All @@ -43,16 +42,6 @@ func (s *userServiceImpl) GetUser(userID discord.Snowflake, opts ...RequestOpt)
return
}

func (s *userServiceImpl) GetSelfUser(opts ...RequestOpt) (selfUser *discord.OAuth2User, err error) {
var compiledRoute *route.CompiledAPIRoute
compiledRoute, err = route.GetCurrentUser.Compile(nil)
if err != nil {
return
}
err = s.restClient.Do(compiledRoute, nil, &selfUser, opts...)
return
}

func (s *userServiceImpl) UpdateSelfUser(updateSelfUser discord.SelfUserUpdate, opts ...RequestOpt) (selfUser *discord.OAuth2User, err error) {
var compiledRoute *route.CompiledAPIRoute
compiledRoute, err = route.UpdateSelfUser.Compile(nil)
Expand Down