Skip to content

Commit 7b8b5d5

Browse files
committed
CBG-2807: Allow databases to override CORS config (#6179)
* Don't set db headers unless we auth to DB * Disallow setting max_age on database
1 parent 0e9999b commit 7b8b5d5

18 files changed

+505
-92
lines changed

auth/cors.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright 2023-Present Couchbase, Inc.
2+
//
3+
// Use of this software is governed by the Business Source License included
4+
// in the file licenses/BSL-Couchbase.txt. As of the Change Date specified
5+
// in that file, in accordance with the Business Source License, use of this
6+
// software will be governed by the Apache License, Version 2.0, included in
7+
// the file licenses/APL2.txt.
8+
9+
package auth
10+
11+
import (
12+
"net/http"
13+
"strings"
14+
)
15+
16+
// Configuration for Cross-Origin Resource Sharing
17+
// <https://en.wikipedia.org/wiki/Cross-origin_resource_sharing>
18+
type CORSConfig struct {
19+
Origin []string `json:"origin,omitempty" help:"List of allowed origins, use ['*'] to allow access from everywhere"`
20+
LoginOrigin []string `json:"login_origin,omitempty" help:"List of allowed login origins"`
21+
Headers []string `json:"headers,omitempty" help:"List of allowed headers"`
22+
MaxAge int `json:"max_age,omitempty" help:"Maximum age of the CORS Options request"`
23+
}
24+
25+
// Adds Access-Control-Allow-Origin, Access-Control-Allow-Credentials, Access-Control-Allow-Headers headers to an HTTP response.
26+
func (cors *CORSConfig) AddResponseHeaders(request *http.Request, response http.ResponseWriter) {
27+
if originHeader := request.Header["Origin"]; len(originHeader) > 0 {
28+
origin := MatchedOrigin(cors.Origin, originHeader)
29+
response.Header().Add("Access-Control-Allow-Origin", origin)
30+
response.Header().Add("Access-Control-Allow-Credentials", "true")
31+
response.Header().Add("Access-Control-Allow-Headers", strings.Join(cors.Headers, ", "))
32+
}
33+
}
34+
35+
func MatchedOrigin(allowOrigins []string, rqOrigins []string) string {
36+
for _, rv := range rqOrigins {
37+
for _, av := range allowOrigins {
38+
if rv == av {
39+
return av
40+
}
41+
}
42+
}
43+
for _, av := range allowOrigins {
44+
if av == "*" {
45+
return "*"
46+
}
47+
}
48+
return ""
49+
}

db/database.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ type DatabaseContext struct {
130130
CollectionNames map[string]map[string]struct{} // Map of scope, collection names
131131
MetadataKeys *base.MetadataKeys // Factory to generate metadata document keys
132132
RequireResync base.ScopeAndCollectionNames // Collections requiring resync before database can go online
133+
CORS *auth.CORSConfig // CORS configuration
133134
}
134135

135136
type Scope struct {

docs/api/components/schemas.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1762,6 +1762,25 @@ Database:
17621762
Defaults to true when running in serverless mode otherwise defaults to false.
17631763
type: boolean
17641764
default: false
1765+
cors:
1766+
description: CORS configuration for this database; if present, overrides server's config.
1767+
type: object
1768+
properties:
1769+
origin:
1770+
description: 'List of allowed origins, use [''*''] to allow access from everywhere'
1771+
type: array
1772+
items:
1773+
type: string
1774+
login_origin:
1775+
description: List of allowed login origins
1776+
type: array
1777+
items:
1778+
type: string
1779+
headers:
1780+
description: List of allowed headers
1781+
type: array
1782+
items:
1783+
type: string
17651784
title: Database-config
17661785
Event-config:
17671786
type: object

rest/admin_api.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ func (h *handler) handleCreateDB() error {
5454
}
5555

5656
config.Name = dbName
57-
5857
if h.server.persistentConfig {
5958
if err := config.validatePersistentDbConfig(); err != nil {
6059
return base.HTTPErrorf(http.StatusBadRequest, err.Error())

rest/api_test.go

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -229,45 +229,65 @@ func TestFunkyDocIDs(t *testing.T) {
229229
func TestCORSOrigin(t *testing.T) {
230230
rt := NewRestTester(t, nil)
231231
defer rt.Close()
232+
tests := []struct {
233+
origin string
234+
headerOutput string
235+
}{
236+
{
237+
origin: "http://example.com",
238+
headerOutput: "http://example.com",
239+
},
240+
{
241+
origin: "http://staging.example.com",
242+
headerOutput: "http://staging.example.com",
243+
},
232244

233-
reqHeaders := map[string]string{
234-
"Origin": "http://example.com",
245+
{
246+
origin: "http://hack0r.com",
247+
headerOutput: "*",
248+
},
235249
}
236-
response := rt.SendRequestWithHeaders("GET", "/{{.keyspace}}/", "", reqHeaders)
237-
assert.Equal(t, "http://example.com", response.Header().Get("Access-Control-Allow-Origin"))
238250

239-
// now test a non-listed origin
240-
// b/c * is in config we get *
241-
reqHeaders = map[string]string{
242-
"Origin": "http://hack0r.com",
243-
}
244-
response = rt.SendRequestWithHeaders("GET", "/{{.keyspace}}/", "", reqHeaders)
245-
assert.Equal(t, "*", response.Header().Get("Access-Control-Allow-Origin"))
251+
for _, tc := range tests {
252+
t.Run(tc.origin, func(t *testing.T) {
246253

247-
// now test another origin in config
248-
reqHeaders = map[string]string{
249-
"Origin": "http://staging.example.com",
250-
}
251-
response = rt.SendRequestWithHeaders("GET", "/{{.keyspace}}/", "", reqHeaders)
252-
assert.Equal(t, "http://staging.example.com", response.Header().Get("Access-Control-Allow-Origin"))
254+
invalidDatabaseName := "invalid database name"
255+
reqHeaders := map[string]string{
256+
"Origin": tc.origin,
257+
}
258+
response := rt.SendRequestWithHeaders("GET", "/{{.keyspace}}/", "", reqHeaders)
259+
assert.Equal(t, tc.headerOutput, response.Header().Get("Access-Control-Allow-Origin"))
260+
RequireStatus(t, response, http.StatusBadRequest)
261+
require.Contains(t, response.Body.String(), invalidDatabaseName)
262+
263+
response = rt.SendRequestWithHeaders("GET", "/{{.db}}/", "", reqHeaders)
264+
assert.Equal(t, tc.headerOutput, response.Header().Get("Access-Control-Allow-Origin"))
265+
RequireStatus(t, response, http.StatusUnauthorized)
266+
require.Contains(t, response.Body.String(), ErrLoginRequired.Message)
267+
268+
response = rt.SendRequestWithHeaders("GET", "/notadb/", "", reqHeaders)
269+
assert.Equal(t, tc.headerOutput, response.Header().Get("Access-Control-Allow-Origin"))
270+
RequireStatus(t, response, http.StatusUnauthorized)
271+
require.Contains(t, response.Body.String(), ErrLoginRequired.Message)
272+
273+
// admin port doesn't have CORS
274+
response = rt.SendAdminRequestWithHeaders("GET", "/{{.keyspace}}/_all_docs", "", reqHeaders)
275+
assert.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin"))
276+
RequireStatus(t, response, http.StatusOK)
253277

254-
// test no header on _admin apis
255-
reqHeaders = map[string]string{
256-
"Origin": "http://example.com",
257-
}
258-
response = rt.SendAdminRequestWithHeaders("GET", "/{{.keyspace}}/_all_docs", "", reqHeaders)
259-
assert.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin"))
260-
261-
// test with a config without * should reject non-matches
262-
sc := rt.ServerContext()
263-
sc.Config.API.CORS.Origin = []string{"http://example.com", "http://staging.example.com"}
264-
// now test a non-listed origin
265-
// b/c * is in config we get *
266-
reqHeaders = map[string]string{
267-
"Origin": "http://hack0r.com",
278+
// test with a config without * should reject non-matches
279+
sc := rt.ServerContext()
280+
defer func() {
281+
sc.Config.API.CORS.Origin = defaultTestingCORSOrigin
282+
}()
283+
284+
sc.Config.API.CORS.Origin = []string{"http://example.com", "http://staging.example.com"}
285+
if !base.StringSliceContains(sc.Config.API.CORS.Origin, tc.origin) {
286+
response = rt.SendRequestWithHeaders("GET", "/{{.keyspace}}/", "", reqHeaders)
287+
assert.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin"))
288+
}
289+
})
268290
}
269-
response = rt.SendRequestWithHeaders("GET", "/{{.keyspace}}/", "", reqHeaders)
270-
assert.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin"))
271291
}
272292

273293
// assertGatewayStatus is like requireStatus but with StatusGatewayTimeout error checking for temporary network failures.

rest/config.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ type DbConfig struct {
165165
GraphQL *functions.GraphQLConfig `json:"graphql,omitempty"` // GraphQL configuration & resolver fns
166166
UserFunctions *functions.FunctionsConfig `json:"functions,omitempty"` // Named JS fns for clients to call
167167
Suspendable *bool `json:"suspendable,omitempty"` // Allow the database to be suspended
168+
CORS *auth.CORSConfig `json:"cors,omitempty"`
168169
}
169170

170171
type ScopesConfig map[string]ScopeConfig
@@ -680,6 +681,10 @@ func (dbConfig *DbConfig) validateVersion(ctx context.Context, isEnterpriseEditi
680681
}
681682
}
682683

684+
if dbConfig.CORS != nil && dbConfig.CORS.MaxAge != 0 {
685+
multiError = multiError.Append(fmt.Errorf("cors.max_age can not be set on a database level"))
686+
687+
}
683688
if dbConfig.DeprecatedPool != nil {
684689
base.WarnfCtx(ctx, `"pool" config option is not supported. The pool will be set to "default". The option should be removed from config file.`)
685690
}

rest/config_legacy.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ func (lc *LegacyServerConfig) ToStartupConfig() (*StartupConfig, DbConfigMap, er
230230
}
231231

232232
if lc.CORS != nil {
233-
sc.API.CORS = &CORSConfig{
233+
sc.API.CORS = &auth.CORSConfig{
234234
Origin: lc.CORS.Origin,
235235
LoginOrigin: lc.CORS.LoginOrigin,
236236
Headers: lc.CORS.Headers,

rest/config_startup.go

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ type APIConfig struct {
121121
CompressResponses *bool `json:"compress_responses,omitempty" help:"If false, disables compression of HTTP responses"`
122122
HideProductVersion *bool `json:"hide_product_version,omitempty" help:"Whether product versions removed from Server headers and REST API responses"`
123123

124-
HTTPS HTTPSConfig `json:"https,omitempty"`
125-
CORS *CORSConfig `json:"cors,omitempty"`
124+
HTTPS HTTPSConfig `json:"https,omitempty"`
125+
CORS *auth.CORSConfig `json:"cors,omitempty"`
126126
}
127127

128128
type HTTPSConfig struct {
@@ -131,13 +131,6 @@ type HTTPSConfig struct {
131131
TLSKeyPath string `json:"tls_key_path,omitempty" help:"The TLS key file to use for the REST APIs"`
132132
}
133133

134-
type CORSConfig struct {
135-
Origin []string `json:"origin,omitempty" help:"List of allowed origins, use ['*'] to allow access from everywhere"`
136-
LoginOrigin []string `json:"login_origin,omitempty" help:"List of allowed login origins"`
137-
Headers []string `json:"headers,omitempty" help:"List of allowed headers"`
138-
MaxAge int `json:"max_age,omitempty" help:"Maximum age of the CORS Options request"`
139-
}
140-
141134
type AuthConfig struct {
142135
BcryptCost int `json:"bcrypt_cost,omitempty" help:"Cost to use for bcrypt password hashes"`
143136
}
@@ -220,7 +213,7 @@ func LoadStartupConfigFromPath(path string) (*StartupConfig, error) {
220213
func NewEmptyStartupConfig() StartupConfig {
221214
return StartupConfig{
222215
API: APIConfig{
223-
CORS: &CORSConfig{},
216+
CORS: &auth.CORSConfig{},
224217
},
225218
Logging: base.LoggingConfig{
226219
Console: &base.ConsoleLoggerConfig{},

rest/config_startup_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"testing"
1313
"time"
1414

15+
"github.com/couchbase/sync_gateway/auth"
1516
"github.com/couchbase/sync_gateway/base"
1617
"github.com/stretchr/testify/assert"
1718
"github.com/stretchr/testify/require"
@@ -92,15 +93,15 @@ func TestStartupConfigMerge(t *testing.T) {
9293
},
9394
{
9495
name: "Keep original *CORSconfig",
95-
config: StartupConfig{API: APIConfig{CORS: &CORSConfig{MaxAge: 5, Origin: []string{"Test"}}}},
96-
override: StartupConfig{API: APIConfig{CORS: &CORSConfig{}}},
97-
expected: StartupConfig{API: APIConfig{CORS: &CORSConfig{MaxAge: 5, Origin: []string{"Test"}}}},
96+
config: StartupConfig{API: APIConfig{CORS: &auth.CORSConfig{MaxAge: 5, Origin: []string{"Test"}}}},
97+
override: StartupConfig{API: APIConfig{CORS: &auth.CORSConfig{}}},
98+
expected: StartupConfig{API: APIConfig{CORS: &auth.CORSConfig{MaxAge: 5, Origin: []string{"Test"}}}},
9899
},
99100
{
100-
name: "Keep original *CORSConfig from override nil value",
101-
config: StartupConfig{API: APIConfig{CORS: &CORSConfig{MaxAge: 5, Origin: []string{"Test"}}}},
101+
name: "Keep original *auth.CORSConfig from override nil value",
102+
config: StartupConfig{API: APIConfig{CORS: &auth.CORSConfig{MaxAge: 5, Origin: []string{"Test"}}}},
102103
override: StartupConfig{},
103-
expected: StartupConfig{API: APIConfig{CORS: &CORSConfig{MaxAge: 5, Origin: []string{"Test"}}}},
104+
expected: StartupConfig{API: APIConfig{CORS: &auth.CORSConfig{MaxAge: 5, Origin: []string{"Test"}}}},
104105
},
105106
{
106107
name: "Override unset ConfigDuration",

0 commit comments

Comments
 (0)