Skip to content

Commit 966823b

Browse files
author
Chanwit Kaewkasi
committed
implement auth middleware for s3
Signed-off-by: Chanwit Kaewkasi <[email protected]>
1 parent d7a45d0 commit 966823b

File tree

3 files changed

+448
-6
lines changed

3 files changed

+448
-6
lines changed

cmd/gitops-bucket-server/main.go

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,49 @@ import (
1111
"github.com/johannesboyne/gofakes3"
1212
"github.com/johannesboyne/gofakes3/backend/s3mem"
1313
"github.com/weaveworks/weave-gitops/pkg/http"
14+
"github.com/weaveworks/weave-gitops/pkg/s3"
1415
)
1516

1617
func main() {
18+
logger := log.New(os.Stdout, "", 0)
19+
20+
awsAccessKeyID := os.Getenv("AWS_ACCESS_KEY_ID")
21+
if awsAccessKeyID == "" {
22+
minioRootUser := os.Getenv("MINIO_ROOT_USER")
23+
if minioRootUser == "" {
24+
logger.Fatal("AWS_ACCESS_KEY_ID or MINIO_ROOT_USER must be set")
25+
return
26+
}
27+
28+
awsAccessKeyID = minioRootUser
29+
}
30+
31+
awsSecretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY")
32+
if awsSecretAccessKey == "" {
33+
minioRootPassword := os.Getenv("MINIO_ROOT_PASSWORD")
34+
if minioRootPassword == "" {
35+
logger.Fatal("AWS_SECRET_ACCESS_KEY or MINIO_ROOT_PASSWORD must be set")
36+
return
37+
}
38+
39+
awsSecretAccessKey = minioRootPassword
40+
}
41+
1742
ctx, cancel := signal.NotifyContext(
1843
context.Background(),
1944
syscall.SIGINT,
2045
syscall.SIGTERM)
2146
defer cancel()
2247

23-
logger := log.New(os.Stdout, "", 0)
24-
backend := s3mem.New()
25-
s3 := gofakes3.New(backend,
48+
s3Server := gofakes3.New(s3mem.New(),
2649
gofakes3.WithAutoBucket(true),
27-
gofakes3.WithLogger(gofakes3.StdLog(logger, gofakes3.LogErr, gofakes3.LogWarn, gofakes3.LogInfo)))
28-
s3Server := s3.Server()
50+
gofakes3.WithLogger(
51+
gofakes3.StdLog(
52+
logger,
53+
gofakes3.LogErr,
54+
gofakes3.LogWarn,
55+
gofakes3.LogInfo,
56+
))).Server()
2957

3058
var (
3159
httpPort, httpsPort int
@@ -54,7 +82,7 @@ func main() {
5482
Logger: logger,
5583
}
5684

57-
if err := srv.Start(ctx, s3Server); err != nil {
85+
if err := srv.Start(ctx, s3.AuthMiddleware(awsAccessKeyID, awsSecretAccessKey, s3Server)); err != nil {
5886
logger.Fatalf("server exited unexpectedly: %s", err)
5987
}
6088
}

pkg/s3/auth_middleware.go

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
package s3
2+
3+
import (
4+
"bytes"
5+
"crypto/hmac"
6+
"crypto/sha256"
7+
"encoding/hex"
8+
"fmt"
9+
"net/http"
10+
"sort"
11+
"strings"
12+
"time"
13+
14+
"github.com/minio/minio-go/v7/pkg/s3utils"
15+
)
16+
17+
// Signature and API related constants.
18+
const (
19+
unsignedPayload = "UNSIGNED-PAYLOAD"
20+
signV4Algorithm = "AWS4-HMAC-SHA256"
21+
iso8601DateFormat = "20060102T150405Z"
22+
yyyymmdd = "20060102"
23+
)
24+
25+
type credential struct {
26+
AccessKeyID string
27+
Time string
28+
Location string
29+
Service string
30+
Request string
31+
}
32+
33+
func parseCredential(str string) (credential, error) {
34+
parts := strings.Split(str, "/")
35+
if len(parts) != 5 {
36+
return credential{}, fmt.Errorf("invalid credential format")
37+
}
38+
39+
return credential{
40+
AccessKeyID: parts[0],
41+
Time: parts[1],
42+
Location: parts[2],
43+
Service: parts[3],
44+
Request: parts[4],
45+
}, nil
46+
}
47+
48+
func AuthMiddleware(accessKeyID, secretAccessKey string, handler http.Handler) http.Handler {
49+
return http.HandlerFunc(func(w http.ResponseWriter, rq *http.Request) {
50+
if err := verifySignature(*rq, accessKeyID, secretAccessKey); err != nil {
51+
authorizedError(w, err)
52+
return
53+
}
54+
55+
handler.ServeHTTP(w, rq)
56+
})
57+
}
58+
59+
func authorizedError(w http.ResponseWriter, err error) {
60+
w.WriteHeader(http.StatusUnauthorized)
61+
w.Header().Set("Content-Type", "application/xml")
62+
63+
xml := fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
64+
<Error>
65+
<Code>Unauthorized</Code>
66+
<Message>%s</Message>
67+
</Error>`, err.Error())
68+
69+
if _, err := w.Write([]byte(xml)); err != nil {
70+
return
71+
}
72+
}
73+
74+
// sum256 calculate sha256 sum for an input byte array.
75+
func sum256(data []byte) []byte {
76+
hash := sha256.New()
77+
hash.Write(data)
78+
79+
return hash.Sum(nil)
80+
}
81+
82+
// sumHMAC calculate hmac between two input byte array.
83+
func sumHMAC(key []byte, data []byte) []byte {
84+
hash := hmac.New(sha256.New, key)
85+
hash.Write(data)
86+
87+
return hash.Sum(nil)
88+
}
89+
90+
func getHashedPayload(req http.Request) string {
91+
hashedPayload := req.Header.Get("X-Amz-Content-Sha256")
92+
if hashedPayload == "" {
93+
// Presign does not have a payload, use S3 recommended value.
94+
hashedPayload = unsignedPayload
95+
}
96+
97+
return hashedPayload
98+
}
99+
100+
func headerExists(key string, headers []string) bool {
101+
for _, k := range headers {
102+
if k == key {
103+
return true
104+
}
105+
}
106+
107+
return false
108+
}
109+
110+
// getHostAddr returns host header if available, otherwise returns host from URL
111+
func getHostAddr(req *http.Request) string {
112+
host := req.Header.Get("host")
113+
if host != "" && req.Host != host {
114+
return host
115+
}
116+
117+
if req.Host != "" {
118+
return req.Host
119+
}
120+
121+
return req.URL.Host
122+
}
123+
124+
// Trim leading and trailing spaces and replace sequential spaces with one space, following Trimall()
125+
// in http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
126+
func signV4TrimAll(input string) string {
127+
// Compress adjacent spaces (a space is determined by
128+
// unicode.IsSpace() internally here) to one space and return
129+
return strings.Join(strings.Fields(input), " ")
130+
}
131+
132+
func getCanonicalHeaders(req http.Request, signedHeaders string) string {
133+
headers := []string{}
134+
vals := make(map[string][]string)
135+
136+
for _, k := range strings.Split(signedHeaders, ";") {
137+
lk := strings.ToLower(k)
138+
headers = append(headers, lk)
139+
vals[lk] = req.Header.Values(k)
140+
}
141+
142+
if !headerExists("host", headers) {
143+
headers = append(headers, "host")
144+
}
145+
146+
sort.Strings(headers)
147+
148+
var buf bytes.Buffer
149+
// Save all the headers in canonical form <header>:<value> newline
150+
// separated for each header.
151+
for _, k := range headers {
152+
buf.WriteString(k)
153+
buf.WriteByte(':')
154+
155+
switch {
156+
case k == "host":
157+
buf.WriteString(getHostAddr(&req))
158+
buf.WriteByte('\n')
159+
default:
160+
for idx, v := range vals[k] {
161+
if idx > 0 {
162+
buf.WriteByte(',')
163+
}
164+
165+
buf.WriteString(signV4TrimAll(v))
166+
}
167+
168+
buf.WriteByte('\n')
169+
}
170+
}
171+
172+
return buf.String()
173+
}
174+
175+
// getSigningKey hmac seed to calculate final signature.
176+
func getSigningKey(secret, loc string, t time.Time, serviceType string) []byte {
177+
date := sumHMAC([]byte("AWS4"+secret), []byte(t.Format(yyyymmdd)))
178+
location := sumHMAC(date, []byte(loc))
179+
service := sumHMAC(location, []byte(serviceType))
180+
signingKey := sumHMAC(service, []byte("aws4_request"))
181+
182+
return signingKey
183+
}
184+
185+
// getSignature final signature in hexadecimal form.
186+
func getSignature(signingKey []byte, stringToSign string) string {
187+
return hex.EncodeToString(sumHMAC(signingKey, []byte(stringToSign)))
188+
}
189+
190+
// getScope generate a string of a specific date, an AWS region, and a
191+
// service.
192+
func getScope(location string, t time.Time, serviceType string) string {
193+
scope := strings.Join([]string{
194+
t.Format(yyyymmdd),
195+
location,
196+
serviceType,
197+
"aws4_request",
198+
}, "/")
199+
200+
return scope
201+
}
202+
203+
// getCredential generate a credential string.
204+
func getCredential(accessKeyID, location string, t time.Time, serviceType string) string {
205+
scope := getScope(location, t, serviceType)
206+
return accessKeyID + "/" + scope
207+
}
208+
209+
// getCanonicalRequest generate a canonical request of style.
210+
//
211+
// canonicalRequest =
212+
//
213+
// <HTTPMethod>\n
214+
// <CanonicalURI>\n
215+
// <CanonicalQueryString>\n
216+
// <CanonicalHeaders>\n
217+
// <SignedHeaders>\n
218+
// <HashedPayload>
219+
func getCanonicalRequest(req http.Request, signedHeaders string, hashedPayload string) string {
220+
req.URL.RawQuery = strings.ReplaceAll(req.URL.Query().Encode(), "+", "%20")
221+
canonicalRequest := strings.Join([]string{
222+
req.Method,
223+
s3utils.EncodePath(req.URL.Path),
224+
req.URL.RawQuery,
225+
getCanonicalHeaders(req, signedHeaders),
226+
signedHeaders,
227+
hashedPayload,
228+
}, "\n")
229+
230+
return canonicalRequest
231+
}
232+
233+
// getStringToSignV4 a string based on selected query values.
234+
func getStringToSignV4(t time.Time, location, canonicalRequest, serviceType string) string {
235+
stringToSign := signV4Algorithm + "\n" + t.Format(iso8601DateFormat) + "\n"
236+
stringToSign = stringToSign + getScope(location, t, serviceType) + "\n"
237+
stringToSign += hex.EncodeToString(sum256([]byte(canonicalRequest)))
238+
239+
return stringToSign
240+
}
241+
242+
// verifySignature - verify signature for S3 version '4'
243+
func verifySignature(req http.Request, accessKeyID string, secretAccessKey string) error {
244+
auth := req.Header.Get("Authorization")
245+
if auth == "" {
246+
return fmt.Errorf("header Authorization is missing")
247+
}
248+
249+
auth = strings.TrimPrefix(auth, signV4Algorithm+" ")
250+
251+
parts := strings.Split(auth, ", ")
252+
if len(parts) != 3 {
253+
return fmt.Errorf("invalid Authorization header")
254+
}
255+
256+
credentialStr := strings.SplitN(parts[0], "=", 2)[1]
257+
signedHeaders := strings.SplitN(parts[1], "=", 2)[1]
258+
parsedSignature := strings.SplitN(parts[2], "=", 2)[1]
259+
260+
hashedPayload := getHashedPayload(req)
261+
262+
amzDate := req.Header.Get("X-Amz-Date")
263+
if amzDate == "" {
264+
return fmt.Errorf("header X-Amz-Date is missing")
265+
}
266+
267+
t, err := time.Parse(iso8601DateFormat, amzDate)
268+
if err != nil {
269+
return err
270+
}
271+
272+
credential, err := parseCredential(credentialStr)
273+
if err != nil {
274+
return err
275+
}
276+
277+
// Get canonical request
278+
canonicalRequest := getCanonicalRequest(req, signedHeaders, hashedPayload)
279+
280+
// Get string to sign from canonical request.
281+
stringToSign := getStringToSignV4(t, credential.Location, canonicalRequest, credential.Service)
282+
283+
// Get hmac signing key.
284+
signingKey := getSigningKey(secretAccessKey, credential.Location, t, credential.Service)
285+
286+
// Get credential string.
287+
computedCredential := getCredential(accessKeyID, credential.Location, t, credential.Service)
288+
289+
// Calculate parsedSignature.
290+
computedSignature := getSignature(signingKey, stringToSign)
291+
292+
if computedCredential != credentialStr {
293+
return fmt.Errorf("access denied: credential does not match")
294+
}
295+
296+
if computedSignature != parsedSignature {
297+
return fmt.Errorf("access denied: signature does not match")
298+
}
299+
300+
return nil
301+
}

0 commit comments

Comments
 (0)