@@ -229,45 +229,65 @@ func TestFunkyDocIDs(t *testing.T) {
229
229
func TestCORSOrigin (t * testing.T ) {
230
230
rt := NewRestTester (t , nil )
231
231
defer rt .Close ()
232
+ tests := []struct {
233
+ origin string
234
+ headerOutput string
235
+ }{
236
+ {
237
+ origin : "http://example.com" ,
238
+ headerOutput : "http://example.com" ,
239
+ },
240
+ {
241
+ origin : "http://staging.example.com" ,
242
+ headerOutput : "http://staging.example.com" ,
243
+ },
232
244
233
- reqHeaders := map [string ]string {
234
- "Origin" : "http://example.com" ,
245
+ {
246
+ origin : "http://hack0r.com" ,
247
+ headerOutput : "*" ,
248
+ },
235
249
}
236
- response := rt .SendRequestWithHeaders ("GET" , "/{{.keyspace}}/" , "" , reqHeaders )
237
- assert .Equal (t , "http://example.com" , response .Header ().Get ("Access-Control-Allow-Origin" ))
238
250
239
- // now test a non-listed origin
240
- // b/c * is in config we get *
241
- reqHeaders = map [string ]string {
242
- "Origin" : "http://hack0r.com" ,
243
- }
244
- response = rt .SendRequestWithHeaders ("GET" , "/{{.keyspace}}/" , "" , reqHeaders )
245
- assert .Equal (t , "*" , response .Header ().Get ("Access-Control-Allow-Origin" ))
251
+ for _ , tc := range tests {
252
+ t .Run (tc .origin , func (t * testing.T ) {
246
253
247
- // now test another origin in config
248
- reqHeaders = map [string ]string {
249
- "Origin" : "http://staging.example.com" ,
250
- }
251
- response = rt .SendRequestWithHeaders ("GET" , "/{{.keyspace}}/" , "" , reqHeaders )
252
- assert .Equal (t , "http://staging.example.com" , response .Header ().Get ("Access-Control-Allow-Origin" ))
254
+ invalidDatabaseName := "invalid database name"
255
+ reqHeaders := map [string ]string {
256
+ "Origin" : tc .origin ,
257
+ }
258
+ response := rt .SendRequestWithHeaders ("GET" , "/{{.keyspace}}/" , "" , reqHeaders )
259
+ assert .Equal (t , tc .headerOutput , response .Header ().Get ("Access-Control-Allow-Origin" ))
260
+ RequireStatus (t , response , http .StatusBadRequest )
261
+ require .Contains (t , response .Body .String (), invalidDatabaseName )
262
+
263
+ response = rt .SendRequestWithHeaders ("GET" , "/{{.db}}/" , "" , reqHeaders )
264
+ assert .Equal (t , tc .headerOutput , response .Header ().Get ("Access-Control-Allow-Origin" ))
265
+ RequireStatus (t , response , http .StatusUnauthorized )
266
+ require .Contains (t , response .Body .String (), ErrLoginRequired .Message )
267
+
268
+ response = rt .SendRequestWithHeaders ("GET" , "/notadb/" , "" , reqHeaders )
269
+ assert .Equal (t , tc .headerOutput , response .Header ().Get ("Access-Control-Allow-Origin" ))
270
+ RequireStatus (t , response , http .StatusUnauthorized )
271
+ require .Contains (t , response .Body .String (), ErrLoginRequired .Message )
272
+
273
+ // admin port doesn't have CORS
274
+ response = rt .SendAdminRequestWithHeaders ("GET" , "/{{.keyspace}}/_all_docs" , "" , reqHeaders )
275
+ assert .Equal (t , "" , response .Header ().Get ("Access-Control-Allow-Origin" ))
276
+ RequireStatus (t , response , http .StatusOK )
253
277
254
- // test no header on _admin apis
255
- reqHeaders = map [string ]string {
256
- "Origin" : "http://example.com" ,
257
- }
258
- response = rt .SendAdminRequestWithHeaders ("GET" , "/{{.keyspace}}/_all_docs" , "" , reqHeaders )
259
- assert .Equal (t , "" , response .Header ().Get ("Access-Control-Allow-Origin" ))
260
-
261
- // test with a config without * should reject non-matches
262
- sc := rt .ServerContext ()
263
- sc .Config .API .CORS .Origin = []string {"http://example.com" , "http://staging.example.com" }
264
- // now test a non-listed origin
265
- // b/c * is in config we get *
266
- reqHeaders = map [string ]string {
267
- "Origin" : "http://hack0r.com" ,
278
+ // test with a config without * should reject non-matches
279
+ sc := rt .ServerContext ()
280
+ defer func () {
281
+ sc .Config .API .CORS .Origin = defaultTestingCORSOrigin
282
+ }()
283
+
284
+ sc .Config .API .CORS .Origin = []string {"http://example.com" , "http://staging.example.com" }
285
+ if ! base .StringSliceContains (sc .Config .API .CORS .Origin , tc .origin ) {
286
+ response = rt .SendRequestWithHeaders ("GET" , "/{{.keyspace}}/" , "" , reqHeaders )
287
+ assert .Equal (t , "" , response .Header ().Get ("Access-Control-Allow-Origin" ))
288
+ }
289
+ })
268
290
}
269
- response = rt .SendRequestWithHeaders ("GET" , "/{{.keyspace}}/" , "" , reqHeaders )
270
- assert .Equal (t , "" , response .Header ().Get ("Access-Control-Allow-Origin" ))
271
291
}
272
292
273
293
// assertGatewayStatus is like requireStatus but with StatusGatewayTimeout error checking for temporary network failures.
0 commit comments