Skip to content

Commit 61d7638

Browse files
authored
feat: add short URL support for OAuth2 authentication (#634)
1 parent 473b911 commit 61d7638

File tree

7 files changed

+95
-12
lines changed

7 files changed

+95
-12
lines changed

.github/workflows/ci.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ jobs:
2727

2828
- run: go build
2929
- run: go test ./... -timeout 20s -race -covermode=atomic -coverprofile=coverage.out -coverpkg=./...
30+
id: test
31+
shell: bash
32+
33+
- name: go test ./... -timeout 20s -race -covermode=atomic -coverprofile=coverage.out -coverpkg=./... (retry)
34+
if: steps.test.outcome == 'failure'
35+
shell: bash
36+
run: go test ./... -timeout 20s -race -covermode=atomic -coverprofile=coverage.out -coverpkg=./...
37+
3038
- run: go test ./... -timeout 20s -run='^$' -bench=. -benchmem -count 3
3139

3240
- uses: codecov/codecov-action@5a1091511ad55cbe89839c7260b706298ca349f7 # v5.5.1

internal/config/config_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ openvpn:
132132
http:
133133
listen: ":9001"
134134
secret: "1jd93h5b6s82lf03jh5b2hf9"
135-
enable-proxy-headers: true
135+
enable-proxy-headers: false
136+
short-url: false
136137
assets-path: "."
137138
template: "../../README.md"
138139
check:
@@ -162,7 +163,8 @@ http:
162163
Check: config.HTTPCheck{
163164
IPAddr: true,
164165
},
165-
EnableProxyHeaders: true,
166+
EnableProxyHeaders: false,
167+
ShortURL: false,
166168
Listen: ":9001",
167169
Secret: "1jd93h5b6s82lf03jh5b2hf9",
168170
Template: func() types.Template {

internal/config/flags.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ func (c *Config) flagSetHTTP(flagSet *flag.FlagSet) {
100100
lookupEnvOrDefault("http.enable-proxy-headers", c.HTTP.EnableProxyHeaders),
101101
"Use X-Forward-For http header for client ips",
102102
)
103+
flagSet.BoolVar(
104+
&c.HTTP.ShortURL,
105+
"http.short-url",
106+
lookupEnvOrDefault("http.short-url", c.HTTP.ShortURL),
107+
"Enable short URL. The URL which is used for initial authentication will be reduced to /?s=... instead of /oauth2/start?state=...",
108+
)
103109
flagSet.TextVar(
104110
&c.HTTP.AssetPath,
105111
"http.assets-path",

internal/config/types.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type HTTP struct {
3737
TLS bool `json:"tls" yaml:"tls"`
3838
Check HTTPCheck `json:"check" yaml:"check"`
3939
EnableProxyHeaders bool `json:"enable-proxy-headers" yaml:"enable-proxy-headers"`
40+
ShortURL bool `json:"short-url" yaml:"short-url"`
4041
}
4142

4243
type HTTPCheck struct {

internal/httphandler/handler.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,19 @@ func New(conf config.Config, oAuth2Client *oauth2.Client) *http.ServeMux {
2424
basePath := strings.TrimSuffix(conf.HTTP.BaseURL.Path, "/")
2525

2626
mux := http.NewServeMux()
27-
mux.Handle("/", http.NotFoundHandler())
27+
if basePath != "" {
28+
mux.Handle("/", http.NotFoundHandler())
29+
}
30+
31+
mux.Handle(fmt.Sprintf("GET %s/", basePath), noCacheHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
32+
if conf.HTTP.ShortURL && r.URL.Query().Has("s") {
33+
http.Redirect(w, r, fmt.Sprintf("%s/oauth2/start?state=%s", basePath, r.URL.Query().Get("s")), http.StatusFound)
34+
35+
return
36+
}
37+
38+
http.NotFound(w, r)
39+
})))
2840
mux.Handle(fmt.Sprintf("GET %s/ready", basePath), http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
2941
w.WriteHeader(http.StatusOK)
3042
_, _ = w.Write([]byte("OK"))

internal/oauth2/handler_test.go

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,30 @@ func TestHandler(t *testing.T) {
233233
true,
234234
true,
235235
},
236+
{
237+
"with short-url",
238+
func() config.Config {
239+
conf := config.Defaults
240+
conf.HTTP.Secret = testutils.Secret
241+
conf.HTTP.ShortURL = true
242+
conf.OAuth2.Provider = generic.Name
243+
conf.OAuth2.Endpoints = config.OAuth2Endpoints{}
244+
conf.OAuth2.Scopes = []string{oauth2types.ScopeOpenID, oauth2types.ScopeProfile}
245+
conf.OAuth2.Validate.Groups = make([]string, 0)
246+
conf.OAuth2.Validate.Roles = make([]string, 0)
247+
conf.OAuth2.Validate.Issuer = true
248+
conf.OAuth2.Validate.IPAddr = false
249+
conf.OpenVPN.Bypass.CommonNames = make(types.RegexpSlice, 0)
250+
conf.OpenVPN.AuthTokenUser = true
251+
252+
return conf
253+
}(),
254+
state.New(state.ClientIdentifier{CID: 0, KID: 1, CommonName: "name"}, "127.0.0.1", "12345", ""),
255+
false,
256+
"",
257+
true,
258+
true,
259+
},
236260
{
237261
"with ipaddr + forwarded-for",
238262
func() config.Config {
@@ -429,6 +453,10 @@ func TestHandler(t *testing.T) {
429453

430454
conf, openVPNClient, managementInterface, _, httpClientListener, httpClient, logger := testutils.SetupMockEnvironment(ctx, t, tc.conf, nil, nil)
431455

456+
httpClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
457+
return http.ErrUseLastResponse
458+
}
459+
432460
managementInterfaceConn, errOpenVPNClientCh, err := testutils.ConnectToManagementInterface(t, managementInterface, openVPNClient)
433461
require.NoError(t, err)
434462

@@ -477,23 +505,44 @@ func TestHandler(t *testing.T) {
477505
case tc.state == (state.State{}):
478506
session = ""
479507
default:
480-
session, err = tc.state.Encode(tc.conf.HTTP.Secret.String())
508+
session, err = tc.state.Encode(conf.HTTP.Secret.String())
481509
require.NoError(t, err)
482510
}
483511

512+
urlPath := "/oauth2/start?state="
513+
if conf.HTTP.ShortURL {
514+
urlPath = "/?s="
515+
}
516+
484517
request, err = http.NewRequestWithContext(t.Context(), http.MethodGet,
485-
fmt.Sprintf("%s/oauth2/start?state=%s", httpClientListener.URL, session),
518+
fmt.Sprintf("%s%s%s", httpClientListener.URL, urlPath, session),
486519
nil,
487520
)
488521

489522
require.NoError(t, err)
490523

491-
if tc.xForwardedFor != "" {
492-
request.Header.Set("X-Forwarded-For", tc.xForwardedFor)
524+
if conf.HTTP.ShortURL {
525+
resp, err = httpClient.Do(request) //nolint:bodyclose
526+
require.NoError(t, err)
527+
528+
_, err = io.Copy(io.Discard, resp.Body)
529+
require.NoError(t, err)
530+
531+
err = resp.Body.Close()
532+
require.NoError(t, err)
533+
534+
require.Equal(t, http.StatusFound, resp.StatusCode)
535+
require.NotEmpty(t, resp.Header.Get("Location"))
536+
537+
request, err = http.NewRequestWithContext(t.Context(), http.MethodGet,
538+
httpClientListener.URL+resp.Header.Get("Location"),
539+
nil,
540+
)
541+
require.NoError(t, err)
493542
}
494543

495-
httpClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
496-
return http.ErrUseLastResponse
544+
if tc.xForwardedFor != "" {
545+
request.Header.Set("X-Forwarded-For", tc.xForwardedFor)
497546
}
498547

499548
reqErrCh := make(chan error, 1)
@@ -562,7 +611,7 @@ func TestHandler(t *testing.T) {
562611
case tc.state.Client.UsernameIsDefined == 1:
563612
testutils.ExpectMessage(t, managementInterfaceConn, reader, "client-auth-nt 0 1")
564613
testutils.SendMessagef(t, managementInterfaceConn, "SUCCESS: client-auth command succeeded")
565-
case tc.conf.OpenVPN.ClientConfig.Enabled:
614+
case conf.OpenVPN.ClientConfig.Enabled:
566615
if tc.state.Client.CommonName == "name" {
567616
testutils.ExpectMessage(t, managementInterfaceConn, reader, "client-auth 0 1\r\n"+
568617
"push \"ping 60\"\r\n"+
@@ -576,7 +625,7 @@ func TestHandler(t *testing.T) {
576625
testutils.SendMessagef(t, managementInterfaceConn, "SUCCESS: client-auth command succeeded")
577626
}
578627
default:
579-
if tc.conf.OAuth2.UserInfo {
628+
if conf.OAuth2.UserInfo {
580629
testutils.ExpectMessage(t, managementInterfaceConn, reader, "client-auth 0 1\r\npush \"auth-token-user dGVzdC11c2VyQGxvY2FsaG9zdA==\"\r\nEND")
581630
} else {
582631
testutils.ExpectMessage(t, managementInterfaceConn, reader, "client-auth 0 1\r\npush \"auth-token-user bmFtZQ==\"\r\nEND")

internal/openvpn/client.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,12 @@ func (c *Client) startClientAuth(ctx context.Context, logger *slog.Logger, clien
128128
return fmt.Errorf("error encoding state: %w", err)
129129
}
130130

131-
startURL := utils.StringConcat(strings.TrimSuffix(c.conf.HTTP.BaseURL.String(), "/"), "/oauth2/start?state=", encodedSession)
131+
urlPath := "/?s="
132+
if !c.conf.HTTP.ShortURL {
133+
urlPath = "/oauth2/start?state="
134+
}
135+
136+
startURL := fmt.Sprintf("%s%s%s", strings.TrimSuffix(c.conf.HTTP.BaseURL.String(), "/"), urlPath, encodedSession)
132137

133138
if len(startURL) >= 245 {
134139
return fmt.Errorf("url %s (%d chars) too long! OpenVPN support up to 245 chars. "+

0 commit comments

Comments
 (0)