Skip to content
This repository was archived by the owner on Apr 2, 2024. It is now read-only.

Commit 98b9f0d

Browse files
Blagoj Atanasovskicevian
authored andcommitted
Add proper CORS handling by the HTTP API
A flag can be set by the user to specify the allowed origins to access the HTTP API. A wrapper handler sets the proper CORS headers if the flag is enabled.
1 parent 6693256 commit 98b9f0d

File tree

14 files changed

+226
-68
lines changed

14 files changed

+226
-68
lines changed

cmd/timescale-prometheus/main.go

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,23 @@ import (
1313
"net/http"
1414
pprof "net/http/pprof"
1515
"os"
16+
"regexp"
1617
"sync/atomic"
1718
"time"
1819

1920
"github.com/prometheus/common/route"
2021

2122
"github.com/jackc/pgx/v4/pgxpool"
2223
_ "github.com/jackc/pgx/v4/stdlib"
23-
24+
"github.com/jamiealquiza/envy"
25+
"github.com/prometheus/client_golang/prometheus"
26+
"github.com/prometheus/client_golang/prometheus/promhttp"
2427
"github.com/timescale/timescale-prometheus/pkg/api"
2528
"github.com/timescale/timescale-prometheus/pkg/log"
2629
"github.com/timescale/timescale-prometheus/pkg/pgclient"
2730
"github.com/timescale/timescale-prometheus/pkg/pgmodel"
28-
"github.com/timescale/timescale-prometheus/pkg/util"
29-
30-
"github.com/jamiealquiza/envy"
31-
32-
"github.com/prometheus/client_golang/prometheus"
33-
"github.com/prometheus/client_golang/prometheus/promhttp"
3431
"github.com/timescale/timescale-prometheus/pkg/query"
32+
"github.com/timescale/timescale-prometheus/pkg/util"
3533
)
3634

3735
type config struct {
@@ -44,6 +42,7 @@ type config struct {
4442
prometheusTimeout time.Duration
4543
electionInterval time.Duration
4644
migrate bool
45+
corsOrigin *regexp.Regexp
4746
}
4847

4948
const (
@@ -148,8 +147,12 @@ func init() {
148147
}
149148

150149
func main() {
151-
cfg := parseFlags()
152-
err := log.Init(cfg.logLevel)
150+
cfg, err := parseFlags()
151+
if err != nil {
152+
fmt.Println("Version: ", Version, "Commit Hash: ", CommitHash)
153+
fmt.Println("Fatal error: cannot parse flags ", err)
154+
}
155+
err = log.Init(cfg.logLevel)
153156
if err != nil {
154157
fmt.Println("Version: ", Version, "Commit Hash: ", CommitHash)
155158
fmt.Println("Fatal error: cannot start logger", err)
@@ -231,7 +234,6 @@ func main() {
231234
prometheus.MustRegister(labelsCacheCap)
232235

233236
router := route.New()
234-
235237
promMetrics := api.Metrics{
236238
LeaderGauge: leaderGauge,
237239
ReceivedSamples: receivedSamples,
@@ -253,21 +255,22 @@ func main() {
253255
router.Get("/read", readHandler)
254256
router.Post("/read", readHandler)
255257

258+
apiConf := &api.Config{AllowedOrigin: cfg.corsOrigin}
256259
queryable := client.GetQueryable()
257260
queryEngine := query.NewEngine(log.GetLogger(), time.Minute)
258-
queryHandler := timeHandler(httpRequestDuration, "query", api.Query(queryEngine, queryable))
261+
queryHandler := timeHandler(httpRequestDuration, "query", api.Query(apiConf, queryEngine, queryable))
259262
router.Get("/api/v1/query", queryHandler)
260263
router.Post("/api/v1/query", queryHandler)
261264

262-
queryRangeHandler := timeHandler(httpRequestDuration, "query_range", api.QueryRange(queryEngine, queryable))
265+
queryRangeHandler := timeHandler(httpRequestDuration, "query_range", api.QueryRange(apiConf, queryEngine, queryable))
263266
router.Get("/api/v1/query_range", queryRangeHandler)
264267
router.Post("/api/v1/query_range", queryRangeHandler)
265268

266-
labelsHandler := timeHandler(httpRequestDuration, "labels", api.Labels(queryable))
269+
labelsHandler := timeHandler(httpRequestDuration, "labels", api.Labels(apiConf, queryable))
267270
router.Get("/api/v1/labels", labelsHandler)
268271
router.Post("/api/v1/labels", labelsHandler)
269272

270-
labelValuesHandler := timeHandler(httpRequestDuration, "label/:name/values", api.LabelValues(queryable))
273+
labelValuesHandler := timeHandler(httpRequestDuration, "label/:name/values", api.LabelValues(apiConf, queryable))
271274
router.Get("/api/v1/label/:name/values", labelValuesHandler)
272275

273276
router.Get("/healthz", api.Health(client))
@@ -291,14 +294,23 @@ func main() {
291294
}
292295
}
293296

294-
func parseFlags() *config {
297+
func parseFlags() (*config, error) {
295298

296299
cfg := &config{}
297300

298301
pgclient.ParseFlags(&cfg.pgmodelCfg)
299302

300303
flag.StringVar(&cfg.listenAddr, "web-listen-address", ":9201", "Address to listen on for web endpoints.")
301304
flag.StringVar(&cfg.telemetryPath, "web-telemetry-path", "/metrics", "Address to listen on for web endpoints.")
305+
306+
var corsOriginFlag string
307+
flag.StringVar(&corsOriginFlag, "web-cors-origin", ".*", `Regex for CORS origin. It is fully anchored. Example: 'https?://(domain1|domain2)\.com'`)
308+
corsOriginRegex, err := compileAnchoredRegexString(corsOriginFlag)
309+
if err != nil {
310+
err = fmt.Errorf("could not compile CORS regex string %v: %w", corsOriginFlag, err)
311+
return nil, err
312+
}
313+
cfg.corsOrigin = corsOriginRegex
302314
flag.StringVar(&cfg.logLevel, "log-level", "debug", "The log level to use [ \"error\", \"warn\", \"info\", \"debug\" ].")
303315
flag.IntVar(&cfg.haGroupLockID, "leader-election-pg-advisory-lock-id", 0, "Unique advisory lock id per adapter high-availability group. Set it if you want to use leader election implementation based on PostgreSQL advisory lock.")
304316
flag.DurationVar(&cfg.prometheusTimeout, "leader-election-pg-advisory-lock-prometheus-timeout", -1, "Adapter will resign if there are no requests from Prometheus within a given timeout (0 means no timeout). "+
@@ -309,7 +321,7 @@ func parseFlags() *config {
309321
envy.Parse("TS_PROM")
310322
flag.Parse()
311323

312-
return cfg
324+
return cfg, nil
313325
}
314326

315327
func initElector(cfg *config) (*util.Elector, error) {
@@ -388,3 +400,11 @@ func timeHandler(histogramVec prometheus.ObserverVec, path string, handler http.
388400
histogramVec.WithLabelValues(path).Observe(float64(elapsedMs))
389401
}
390402
}
403+
404+
func compileAnchoredRegexString(s string) (*regexp.Regexp, error) {
405+
r, err := regexp.Compile("^(?:" + s + ")$")
406+
if err != nil {
407+
return nil, err
408+
}
409+
return r, nil
410+
}

pkg/api/common.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ import (
66
"io"
77
"math"
88
"net/http"
9+
"regexp"
910
"strconv"
1011
"time"
1112

1213
"github.com/pkg/errors"
1314
"github.com/prometheus/common/model"
1415
"github.com/prometheus/prometheus/promql/parser"
1516
"github.com/prometheus/prometheus/storage"
17+
"github.com/prometheus/prometheus/util/httputil"
1618
"github.com/timescale/timescale-prometheus/pkg/promql"
1719
)
1820

@@ -24,8 +26,21 @@ var (
2426
maxTimeFormatted = maxTime.Format(time.RFC3339Nano)
2527
)
2628

27-
func setHeaders(w http.ResponseWriter, res *promql.Result, warnings storage.Warnings) {
28-
w.Header().Set("Access-Control-Allow-Origin", "*")
29+
type Config struct {
30+
AllowedOrigin *regexp.Regexp
31+
}
32+
33+
func corsWrapper(conf *Config, f http.HandlerFunc) http.HandlerFunc {
34+
if conf.AllowedOrigin == nil {
35+
return f
36+
}
37+
return func(w http.ResponseWriter, r *http.Request) {
38+
httputil.SetCORS(w, conf.AllowedOrigin, r)
39+
f(w, r)
40+
}
41+
}
42+
43+
func setResponseHeaders(w http.ResponseWriter, res *promql.Result, warnings storage.Warnings) {
2944
w.Header().Set("Content-Type", "application/json")
3045
if warnings != nil && len(warnings) > 0 {
3146
w.Header().Set("Cache-Control", "no-store")
@@ -38,7 +53,7 @@ func setHeaders(w http.ResponseWriter, res *promql.Result, warnings storage.Warn
3853
}
3954

4055
func respondQuery(w http.ResponseWriter, res *promql.Result, warnings storage.Warnings) {
41-
setHeaders(w, res, warnings)
56+
setResponseHeaders(w, res, warnings)
4257
switch resVal := res.Value.(type) {
4358
case promql.Vector:
4459
warnings := make([]string, 0, len(res.Warnings))

pkg/api/common_test.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package api
2+
3+
import (
4+
"context"
5+
"github.com/timescale/timescale-prometheus/pkg/log"
6+
"net/http"
7+
"net/http/httptest"
8+
"reflect"
9+
"regexp"
10+
"testing"
11+
)
12+
13+
func TestCORSWrapper(t *testing.T) {
14+
_ = log.Init("debug")
15+
acceptSpecific, _ := regexp.Compile("^(?:" + "http://some-site.com" + ")$")
16+
acceptAny, _ := regexp.Compile("^(?:" + ".*" + ")$")
17+
18+
testCases := []struct {
19+
name string
20+
requestOrigin string
21+
acceptedOrigin *regexp.Regexp
22+
expectHeaders map[string][]string
23+
}{
24+
{
25+
name: "No origin",
26+
requestOrigin: "",
27+
acceptedOrigin: acceptSpecific,
28+
expectHeaders: map[string][]string{},
29+
}, {
30+
name: "Origin doesn't match accepted",
31+
requestOrigin: "http://some-unknown-site.com",
32+
acceptedOrigin: acceptSpecific,
33+
expectHeaders: map[string][]string{
34+
"Access-Control-Allow-Headers": {"Accept, Authorization, Content-Type, Origin"},
35+
"Access-Control-Allow-Methods": {"GET, POST, OPTIONS"},
36+
"Access-Control-Expose-Headers": {"Date"},
37+
"Vary": {"Origin"},
38+
},
39+
},
40+
{
41+
name: "Origin matches accepted",
42+
requestOrigin: "http://some-site.com",
43+
acceptedOrigin: acceptSpecific,
44+
expectHeaders: map[string][]string{
45+
"Access-Control-Allow-Headers": {"Accept, Authorization, Content-Type, Origin"},
46+
"Access-Control-Allow-Methods": {"GET, POST, OPTIONS"},
47+
"Access-Control-Expose-Headers": {"Date"},
48+
"Access-Control-Allow-Origin": {"http://some-site.com"},
49+
"Vary": {"Origin"},
50+
},
51+
}, {
52+
name: "Wildcard allowed origin",
53+
requestOrigin: "http://any-site.com",
54+
acceptedOrigin: acceptAny,
55+
expectHeaders: map[string][]string{
56+
"Access-Control-Allow-Headers": {"Accept, Authorization, Content-Type, Origin"},
57+
"Access-Control-Allow-Methods": {"GET, POST, OPTIONS"},
58+
"Access-Control-Expose-Headers": {"Date"},
59+
"Access-Control-Allow-Origin": {"*"},
60+
"Vary": {"Origin"},
61+
},
62+
},
63+
}
64+
for _, tc := range testCases {
65+
t.Run(tc.name, func(t *testing.T) {
66+
conf := &Config{}
67+
if tc.acceptedOrigin != nil {
68+
conf.AllowedOrigin = tc.acceptedOrigin
69+
} else {
70+
tc.acceptedOrigin = &regexp.Regexp{}
71+
}
72+
internalHandlerCalled := false
73+
handler := corsWrapper(conf, func(http.ResponseWriter, *http.Request) {
74+
internalHandlerCalled = true
75+
})
76+
w := doCORSWrapperRequest(t, handler, "http://localhost/", tc.requestOrigin)
77+
if !internalHandlerCalled {
78+
t.Fatalf("internal handler not called by CORS wrapper")
79+
return
80+
}
81+
returnedHeaders := w.Header()
82+
if len(returnedHeaders) != len(tc.expectHeaders) {
83+
t.Fatalf("expected %d headers, got %d", len(tc.expectHeaders), len(returnedHeaders))
84+
return
85+
}
86+
for hName, hValues := range tc.expectHeaders {
87+
returnedValues := returnedHeaders[hName]
88+
if !reflect.DeepEqual(hValues, returnedValues) {
89+
t.Errorf("expected header %s with value %v; got %v", hName, hValues, returnedValues)
90+
}
91+
}
92+
})
93+
94+
}
95+
96+
}
97+
98+
func doCORSWrapperRequest(t *testing.T, queryHandler http.Handler, url, origin string) *httptest.ResponseRecorder {
99+
req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
100+
if err != nil {
101+
t.Errorf("%v", err)
102+
}
103+
104+
req.Header.Set("Origin", origin)
105+
w := httptest.NewRecorder()
106+
queryHandler.ServeHTTP(w, req)
107+
return w
108+
}

pkg/api/label_values.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,30 @@ package api
33
import (
44
"context"
55
"fmt"
6+
"math"
7+
"net/http"
8+
69
"github.com/NYTimes/gziphandler"
710
"github.com/prometheus/common/model"
811
"github.com/prometheus/common/route"
912
"github.com/timescale/timescale-prometheus/pkg/promql"
1013
"github.com/timescale/timescale-prometheus/pkg/query"
11-
"math"
12-
"net/http"
1314
)
1415

15-
func LabelValues(queriable *query.Queryable) http.Handler {
16-
hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
16+
func LabelValues(conf *Config, queryable *query.Queryable) http.Handler {
17+
hf := corsWrapper(conf, labelValues(queryable))
18+
return gziphandler.GzipHandler(hf)
19+
}
20+
21+
func labelValues(queryable *query.Queryable) http.HandlerFunc {
22+
return func(w http.ResponseWriter, r *http.Request) {
1723
ctx := r.Context()
1824
name := route.Param(ctx, "name")
1925
if !model.LabelNameRE.MatchString(name) {
2026
respondError(w, http.StatusBadRequest, fmt.Errorf("invalid label name: %s", name), "bad_data")
2127
return
2228
}
23-
querier, err := queriable.Querier(context.Background(), math.MinInt64, math.MaxInt64)
29+
querier, err := queryable.Querier(context.Background(), math.MinInt64, math.MaxInt64)
2430
if err != nil {
2531
respondError(w, http.StatusInternalServerError, err, "internal")
2632
return
@@ -35,7 +41,5 @@ func LabelValues(queriable *query.Queryable) http.Handler {
3541
respondLabels(w, &promql.Result{
3642
Value: values,
3743
}, warnings)
38-
})
39-
40-
return gziphandler.GzipHandler(hf)
44+
}
4145
}

pkg/api/labels.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@ package api
33
import (
44
"context"
55
"encoding/json"
6+
"math"
7+
"net/http"
8+
"strings"
9+
610
"github.com/NYTimes/gziphandler"
711
"github.com/prometheus/prometheus/promql/parser"
812
"github.com/prometheus/prometheus/storage"
913
"github.com/timescale/timescale-prometheus/pkg/promql"
1014
"github.com/timescale/timescale-prometheus/pkg/query"
11-
"math"
12-
"net/http"
13-
"strings"
1415
)
1516

1617
type labelsValue []string
@@ -23,9 +24,14 @@ func (l labelsValue) String() string {
2324
return strings.Join(l, "\n")
2425
}
2526

26-
func Labels(queriable *query.Queryable) http.Handler {
27-
hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
28-
querier, err := queriable.Querier(context.Background(), math.MinInt64, math.MaxInt64)
27+
func Labels(conf *Config, queryable *query.Queryable) http.Handler {
28+
hf := corsWrapper(conf, labelsHandler(queryable))
29+
return gziphandler.GzipHandler(hf)
30+
}
31+
32+
func labelsHandler(queryable *query.Queryable) http.HandlerFunc {
33+
return func(w http.ResponseWriter, r *http.Request) {
34+
querier, err := queryable.Querier(context.Background(), math.MinInt64, math.MaxInt64)
2935
if err != nil {
3036
respondError(w, http.StatusInternalServerError, err, "internal")
3137
return
@@ -39,13 +45,11 @@ func Labels(queriable *query.Queryable) http.Handler {
3945
respondLabels(w, &promql.Result{
4046
Value: names,
4147
}, warnings)
42-
})
43-
44-
return gziphandler.GzipHandler(hf)
48+
}
4549
}
4650

4751
func respondLabels(w http.ResponseWriter, res *promql.Result, warnings storage.Warnings) {
48-
setHeaders(w, res, warnings)
52+
setResponseHeaders(w, res, warnings)
4953
resp := &response{
5054
Status: "success",
5155
Data: res.Value,

pkg/api/labels_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func TestLabels(t *testing.T) {
3535
}
3636
for _, tc := range testCases {
3737
t.Run(tc.name, func(t *testing.T) {
38-
handler := Labels(query.NewQueryable(tc.querier))
38+
handler := labelsHandler(query.NewQueryable(tc.querier))
3939
w := doLabels(t, handler)
4040

4141
if w.Code != tc.expectCode {

0 commit comments

Comments
 (0)