Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ jobs:

- run: go build
- run: go test ./... -timeout 20s -race -covermode=atomic -coverprofile=coverage.out -coverpkg=./...
id: test
shell: bash

- name: go test ./... -timeout 20s -race -covermode=atomic -coverprofile=coverage.out -coverpkg=./... (retry)
if: steps.test.outcome == 'failure'
shell: bash
run: go test ./... -timeout 20s -race -covermode=atomic -coverprofile=coverage.out -coverpkg=./...

- run: go test ./... -timeout 20s -run='^$' -bench=. -benchmem -count 3

- uses: codecov/codecov-action@5a1091511ad55cbe89839c7260b706298ca349f7 # v5.5.1
Expand Down
6 changes: 4 additions & 2 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ openvpn:
http:
listen: ":9001"
secret: "1jd93h5b6s82lf03jh5b2hf9"
enable-proxy-headers: true
enable-proxy-headers: false
short-url: false
assets-path: "."
template: "../../README.md"
check:
Expand Down Expand Up @@ -162,7 +163,8 @@ http:
Check: config.HTTPCheck{
IPAddr: true,
},
EnableProxyHeaders: true,
EnableProxyHeaders: false,
ShortURL: false,
Listen: ":9001",
Secret: "1jd93h5b6s82lf03jh5b2hf9",
Template: func() types.Template {
Expand Down
6 changes: 6 additions & 0 deletions internal/config/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ func (c *Config) flagSetHTTP(flagSet *flag.FlagSet) {
lookupEnvOrDefault("http.enable-proxy-headers", c.HTTP.EnableProxyHeaders),
"Use X-Forward-For http header for client ips",
)
flagSet.BoolVar(
&c.HTTP.ShortURL,
"http.short-url",
lookupEnvOrDefault("http.short-url", c.HTTP.ShortURL),
"Enable short URL. The URL which is used for initial authentication will be reduced to /?s=... instead of /oauth2/start?state=...",
)
flagSet.TextVar(
&c.HTTP.AssetPath,
"http.assets-path",
Expand Down
1 change: 1 addition & 0 deletions internal/config/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type HTTP struct {
TLS bool `json:"tls" yaml:"tls"`
Check HTTPCheck `json:"check" yaml:"check"`
EnableProxyHeaders bool `json:"enable-proxy-headers" yaml:"enable-proxy-headers"`
ShortURL bool `json:"short-url" yaml:"short-url"`
}

type HTTPCheck struct {
Expand Down
14 changes: 13 additions & 1 deletion internal/httphandler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,19 @@ func New(conf config.Config, oAuth2Client *oauth2.Client) *http.ServeMux {
basePath := strings.TrimSuffix(conf.HTTP.BaseURL.Path, "/")

mux := http.NewServeMux()
mux.Handle("/", http.NotFoundHandler())
if basePath != "" {
mux.Handle("/", http.NotFoundHandler())
}

mux.Handle(fmt.Sprintf("GET %s/", basePath), noCacheHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if conf.HTTP.ShortURL && r.URL.Query().Has("s") {
http.Redirect(w, r, fmt.Sprintf("%s/oauth2/start?state=%s", basePath, r.URL.Query().Get("s")), http.StatusFound)

return
}

http.NotFound(w, r)
})))
mux.Handle(fmt.Sprintf("GET %s/ready", basePath), http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
Expand Down
65 changes: 57 additions & 8 deletions internal/oauth2/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,30 @@ func TestHandler(t *testing.T) {
true,
true,
},
{
"with short-url",
func() config.Config {
conf := config.Defaults
conf.HTTP.Secret = testutils.Secret
conf.HTTP.ShortURL = true
conf.OAuth2.Provider = generic.Name
conf.OAuth2.Endpoints = config.OAuth2Endpoints{}
conf.OAuth2.Scopes = []string{oauth2types.ScopeOpenID, oauth2types.ScopeProfile}
conf.OAuth2.Validate.Groups = make([]string, 0)
conf.OAuth2.Validate.Roles = make([]string, 0)
conf.OAuth2.Validate.Issuer = true
conf.OAuth2.Validate.IPAddr = false
conf.OpenVPN.Bypass.CommonNames = make(types.RegexpSlice, 0)
conf.OpenVPN.AuthTokenUser = true

return conf
}(),
state.New(state.ClientIdentifier{CID: 0, KID: 1, CommonName: "name"}, "127.0.0.1", "12345", ""),
false,
"",
true,
true,
},
{
"with ipaddr + forwarded-for",
func() config.Config {
Expand Down Expand Up @@ -429,6 +453,10 @@ func TestHandler(t *testing.T) {

conf, openVPNClient, managementInterface, _, httpClientListener, httpClient, logger := testutils.SetupMockEnvironment(ctx, t, tc.conf, nil, nil)

httpClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
return http.ErrUseLastResponse
}

managementInterfaceConn, errOpenVPNClientCh, err := testutils.ConnectToManagementInterface(t, managementInterface, openVPNClient)
require.NoError(t, err)

Expand Down Expand Up @@ -477,23 +505,44 @@ func TestHandler(t *testing.T) {
case tc.state == (state.State{}):
session = ""
default:
session, err = tc.state.Encode(tc.conf.HTTP.Secret.String())
session, err = tc.state.Encode(conf.HTTP.Secret.String())
require.NoError(t, err)
}

urlPath := "/oauth2/start?state="
if conf.HTTP.ShortURL {
urlPath = "/?s="
}

request, err = http.NewRequestWithContext(t.Context(), http.MethodGet,
fmt.Sprintf("%s/oauth2/start?state=%s", httpClientListener.URL, session),
fmt.Sprintf("%s%s%s", httpClientListener.URL, urlPath, session),
nil,
)

require.NoError(t, err)

if tc.xForwardedFor != "" {
request.Header.Set("X-Forwarded-For", tc.xForwardedFor)
if conf.HTTP.ShortURL {
resp, err = httpClient.Do(request) //nolint:bodyclose
require.NoError(t, err)

_, err = io.Copy(io.Discard, resp.Body)
require.NoError(t, err)

err = resp.Body.Close()
require.NoError(t, err)

require.Equal(t, http.StatusFound, resp.StatusCode)
require.NotEmpty(t, resp.Header.Get("Location"))

request, err = http.NewRequestWithContext(t.Context(), http.MethodGet,
httpClientListener.URL+resp.Header.Get("Location"),
nil,
)
require.NoError(t, err)
}

httpClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
return http.ErrUseLastResponse
if tc.xForwardedFor != "" {
request.Header.Set("X-Forwarded-For", tc.xForwardedFor)
}

reqErrCh := make(chan error, 1)
Expand Down Expand Up @@ -562,7 +611,7 @@ func TestHandler(t *testing.T) {
case tc.state.Client.UsernameIsDefined == 1:
testutils.ExpectMessage(t, managementInterfaceConn, reader, "client-auth-nt 0 1")
testutils.SendMessagef(t, managementInterfaceConn, "SUCCESS: client-auth command succeeded")
case tc.conf.OpenVPN.ClientConfig.Enabled:
case conf.OpenVPN.ClientConfig.Enabled:
if tc.state.Client.CommonName == "name" {
testutils.ExpectMessage(t, managementInterfaceConn, reader, "client-auth 0 1\r\n"+
"push \"ping 60\"\r\n"+
Expand All @@ -576,7 +625,7 @@ func TestHandler(t *testing.T) {
testutils.SendMessagef(t, managementInterfaceConn, "SUCCESS: client-auth command succeeded")
}
default:
if tc.conf.OAuth2.UserInfo {
if conf.OAuth2.UserInfo {
testutils.ExpectMessage(t, managementInterfaceConn, reader, "client-auth 0 1\r\npush \"auth-token-user dGVzdC11c2VyQGxvY2FsaG9zdA==\"\r\nEND")
} else {
testutils.ExpectMessage(t, managementInterfaceConn, reader, "client-auth 0 1\r\npush \"auth-token-user bmFtZQ==\"\r\nEND")
Expand Down
7 changes: 6 additions & 1 deletion internal/openvpn/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,12 @@ func (c *Client) startClientAuth(ctx context.Context, logger *slog.Logger, clien
return fmt.Errorf("error encoding state: %w", err)
}

startURL := utils.StringConcat(strings.TrimSuffix(c.conf.HTTP.BaseURL.String(), "/"), "/oauth2/start?state=", encodedSession)
urlPath := "/?s="
if !c.conf.HTTP.ShortURL {
urlPath = "/oauth2/start?state="
}

startURL := fmt.Sprintf("%s%s%s", strings.TrimSuffix(c.conf.HTTP.BaseURL.String(), "/"), urlPath, encodedSession)

if len(startURL) >= 245 {
return fmt.Errorf("url %s (%d chars) too long! OpenVPN support up to 245 chars. "+
Expand Down