Skip to content

Commit d83e545

Browse files
Fix bug in validation of multiple audiences (#441)
* Fix bug in validation of multiple audiences In a situation where multiple audiences are validated by the validator, the order of evaluation of the for-range loop affects the result. If we produce matches such as: ``` { "example.org": true, "example.com": false, } ``` and we configured the validator to expect a single match on audience, the code would either: 1. produce "token has invalid audience" if "example.org" was evaluated first 2. produce a passing result if "example.com" was evaluated first This commit fixes this bug, and adds a suite of tests as well as regression tests to prevent this issue in future. * Adding three more test cases to be sure * Removing required alltogether form verifyAudience * Removing required --------- Co-authored-by: Christian Banse <[email protected]>
1 parent 75740f1 commit d83e545

File tree

2 files changed

+133
-32
lines changed

2 files changed

+133
-32
lines changed

validator.go

Lines changed: 20 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package jwt
22

33
import (
44
"fmt"
5+
"slices"
56
"time"
67
)
78

@@ -124,7 +125,7 @@ func (v *Validator) Validate(claims Claims) error {
124125

125126
// If we have an expected audience, we also require the audience claim
126127
if len(v.expectedAud) > 0 {
127-
if err = v.verifyAudience(claims, v.expectedAud, v.expectAllAud, true); err != nil {
128+
if err = v.verifyAudience(claims, v.expectedAud, v.expectAllAud); err != nil {
128129
errs = append(errs, err)
129130
}
130131
}
@@ -229,52 +230,39 @@ func (v *Validator) verifyNotBefore(claims Claims, cmp time.Time, required bool)
229230
//
230231
// Additionally, if any error occurs while retrieving the claim, e.g., when its
231232
// the wrong type, an ErrTokenUnverifiable error will be returned.
232-
func (v *Validator) verifyAudience(claims Claims, cmp []string, expectAllAud bool, required bool) error {
233+
func (v *Validator) verifyAudience(claims Claims, cmp []string, expectAllAud bool) error {
233234
aud, err := claims.GetAudience()
234235
if err != nil {
235236
return err
236237
}
237238

238-
if len(aud) == 0 {
239+
// Check that aud exists and is not empty. We only require the aud claim
240+
// if we expect at least one audience to be present.
241+
if len(aud) == 0 || len(aud) == 1 && aud[0] == "" {
242+
required := len(v.expectedAud) > 0
239243
return errorIfRequired(required, "aud")
240244
}
241245

242-
// use a var here to keep constant time compare when looping over a number of claims
243-
matching := make(map[string]bool, 0)
244-
245-
// build a matching hashmap out of the expected aud
246-
for _, expected := range cmp {
247-
matching[expected] = false
248-
}
249-
250-
// compare the expected aud with the actual aud in a constant time manner by looping over all actual values
251-
var stringClaims string
252-
for _, a := range aud {
253-
a := a
254-
_, ok := matching[a]
255-
if ok {
256-
matching[a] = true
246+
if !expectAllAud {
247+
for _, a := range aud {
248+
// If we only expect one match, we can stop early if we find a match
249+
if slices.Contains(cmp, a) {
250+
return nil
251+
}
257252
}
258253

259-
stringClaims = stringClaims + a
254+
return ErrTokenInvalidAudience
260255
}
261256

262-
// check if all expected auds are present
263-
result := true
264-
for _, match := range matching {
265-
if !expectAllAud && match {
266-
break
267-
} else if !match {
268-
result = false
257+
// Note that we are looping cmp here to ensure that all expected audiences
258+
// are present in the aud claim.
259+
for _, a := range cmp {
260+
if !slices.Contains(aud, a) {
261+
return ErrTokenInvalidAudience
269262
}
270263
}
271264

272-
// case where "" is sent in one or many aud claims
273-
if stringClaims == "" {
274-
return errorIfRequired(required, "aud")
275-
}
276-
277-
return errorIfFalse(result, ErrTokenInvalidAudience)
265+
return nil
278266
}
279267

280268
// verifyIssuer compares the iss claim in claims against cmp.

validator_test.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,116 @@ func Test_Validator_verifyIssuedAt(t *testing.T) {
261261
})
262262
}
263263
}
264+
265+
func Test_Validator_verifyAudience(t *testing.T) {
266+
type fields struct {
267+
expectedAud []string
268+
}
269+
type args struct {
270+
claims Claims
271+
cmp []string
272+
expectAllAud bool
273+
}
274+
tests := []struct {
275+
name string
276+
fields fields
277+
args args
278+
wantErr error
279+
}{
280+
{
281+
name: "fail without audience when expecting one aud match",
282+
fields: fields{expectedAud: []string{"example.com"}},
283+
args: args{
284+
claims: MapClaims{},
285+
cmp: []string{"example.com"},
286+
expectAllAud: false,
287+
},
288+
wantErr: ErrTokenRequiredClaimMissing,
289+
},
290+
{
291+
name: "fail without audience when expecting all aud matches",
292+
fields: fields{expectedAud: []string{"example.com"}},
293+
args: args{
294+
claims: MapClaims{},
295+
cmp: []string{"example.com"},
296+
expectAllAud: true,
297+
},
298+
wantErr: ErrTokenRequiredClaimMissing,
299+
},
300+
{
301+
name: "good when audience matches",
302+
fields: fields{expectedAud: []string{"example.com"}},
303+
args: args{
304+
claims: RegisteredClaims{Audience: ClaimStrings{"example.com"}},
305+
cmp: []string{"example.com"},
306+
expectAllAud: false,
307+
},
308+
wantErr: nil,
309+
},
310+
{
311+
name: "fail when audience matches with one value",
312+
fields: fields{expectedAud: []string{"example.org", "example.com"}},
313+
args: args{
314+
claims: RegisteredClaims{Audience: ClaimStrings{"example.com"}},
315+
cmp: []string{"example.org", "example.com"},
316+
expectAllAud: false,
317+
},
318+
wantErr: nil,
319+
},
320+
{
321+
name: "fail when audience matches with all values",
322+
fields: fields{expectedAud: []string{"example.org", "example.com"}},
323+
args: args{
324+
claims: RegisteredClaims{Audience: ClaimStrings{"example.org", "example.com"}},
325+
cmp: []string{"example.org", "example.com"},
326+
expectAllAud: true,
327+
},
328+
wantErr: nil,
329+
},
330+
{
331+
name: "fail when audience not matching",
332+
fields: fields{expectedAud: []string{"example.org", "example.com"}},
333+
args: args{
334+
claims: RegisteredClaims{Audience: ClaimStrings{"example.net"}},
335+
cmp: []string{"example.org", "example.com"},
336+
expectAllAud: false,
337+
},
338+
wantErr: ErrTokenInvalidAudience,
339+
},
340+
{
341+
name: "fail when audience not matching all values",
342+
fields: fields{expectedAud: []string{"example.org", "example.com"}},
343+
args: args{
344+
claims: RegisteredClaims{Audience: ClaimStrings{"example.org", "example.net"}},
345+
cmp: []string{"example.org", "example.com"},
346+
expectAllAud: true,
347+
},
348+
wantErr: ErrTokenInvalidAudience,
349+
},
350+
{
351+
name: "fail when audience missing",
352+
fields: fields{expectedAud: []string{"example.org", "example.com"}},
353+
args: args{
354+
claims: MapClaims{},
355+
cmp: []string{"example.org", "example.com"},
356+
expectAllAud: true,
357+
},
358+
wantErr: ErrTokenRequiredClaimMissing,
359+
},
360+
}
361+
for _, tt := range tests {
362+
t.Run(tt.name, func(t *testing.T) {
363+
v := &Validator{
364+
expectedAud: tt.fields.expectedAud,
365+
expectAllAud: tt.args.expectAllAud,
366+
}
367+
368+
err := v.verifyAudience(tt.args.claims, tt.args.cmp, tt.args.expectAllAud)
369+
if tt.wantErr == nil && err != nil {
370+
t.Errorf("validator.verifyAudience() error = %v, wantErr %v", err, tt.wantErr)
371+
} else if tt.wantErr != nil && !errors.Is(err, tt.wantErr) {
372+
t.Errorf("validator.verifyAudience() error = %v, wantErr %v", err, tt.wantErr)
373+
}
374+
})
375+
}
376+
}

0 commit comments

Comments
 (0)