Skip to content

Commit 550274a

Browse files
aeneasrory-bot
authored andcommitted
chore: simplify consent challenge decoding
GitOrigin-RevId: c22df92e87df48f7884ae4a85f591a3e6e71e22e
1 parent 4d61925 commit 550274a

10 files changed

+255
-376
lines changed

consent/handler.go

Lines changed: 39 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ func (h *Handler) getOAuth2LoginRequest(w http.ResponseWriter, r *http.Request)
353353
var err error
354354
ctx, span := h.r.Tracer(r.Context()).Tracer().Start(r.Context(), "consent.getOAuth2LoginRequest")
355355
defer otelx.End(span, &err)
356+
356357
challenge := cmp.Or(
357358
r.URL.Query().Get("login_challenge"),
358359
r.URL.Query().Get("challenge"),
@@ -652,6 +653,10 @@ type getOAuth2ConsentRequest struct {
652653
// 410: oAuth2RedirectTo
653654
// default: errorOAuth2
654655
func (h *Handler) getOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request) {
656+
var err error
657+
ctx, span := h.r.Tracer(r.Context()).Tracer().Start(r.Context(), "consent.getOAuth2ConsentRequest")
658+
defer otelx.End(span, &err)
659+
655660
challenge := cmp.Or(
656661
r.URL.Query().Get("consent_challenge"),
657662
r.URL.Query().Get("challenge"),
@@ -661,28 +666,23 @@ func (h *Handler) getOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request
661666
return
662667
}
663668

664-
request, err := h.r.ConsentManager().GetConsentRequest(r.Context(), challenge)
669+
f, err := flow.DecodeFromConsentChallenge(ctx, h.r, challenge)
665670
if err != nil {
666671
h.r.Writer().WriteError(w, r, err)
667672
return
668673
}
669-
if request.WasHandled {
674+
675+
if f.ConsentWasHandled {
670676
h.r.Writer().WriteCode(w, r, http.StatusGone, &flow.OAuth2RedirectTo{
671-
RedirectTo: request.RequestURL,
677+
RedirectTo: f.RequestURL,
672678
})
673679
return
674680
}
675681

676-
if request.RequestedScope == nil {
677-
request.RequestedScope = []string{}
678-
}
679-
680-
if request.RequestedAudience == nil {
681-
request.RequestedAudience = []string{}
682-
}
683-
684-
request.Client = sanitizeClient(request.Client)
685-
h.r.Writer().Write(w, r, request)
682+
// Transform flow to the existing API format.
683+
req := f.GetConsentRequest(challenge)
684+
req.Client.Secret = ""
685+
h.r.Writer().Write(w, r, req)
686686
}
687687

688688
// Accept OAuth 2.0 Consent Request
@@ -734,7 +734,9 @@ type acceptOAuth2ConsentRequest struct {
734734
// 200: oAuth2RedirectTo
735735
// default: errorOAuth2
736736
func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Request) {
737-
ctx := r.Context()
737+
var err error
738+
ctx, span := h.r.Tracer(r.Context()).Tracer().Start(r.Context(), "consent.acceptOAuth2ConsentRequest")
739+
defer otelx.End(span, &err)
738740

739741
challenge := cmp.Or(
740742
r.URL.Query().Get("consent_challenge"),
@@ -745,39 +747,32 @@ func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ
745747
return
746748
}
747749

748-
var p flow.AcceptOAuth2ConsentRequest
750+
var payload flow.AcceptOAuth2ConsentRequest
749751
d := json.NewDecoder(r.Body)
750752
d.DisallowUnknownFields()
751-
if err := d.Decode(&p); err != nil {
753+
if err := d.Decode(&payload); err != nil {
752754
h.r.Writer().WriteErrorCode(w, r, http.StatusBadRequest, errors.WithStack(err))
753755
return
754756
}
755757

756-
cr, err := h.r.ConsentManager().GetConsentRequest(ctx, challenge)
758+
f, err := flow.DecodeFromConsentChallenge(ctx, h.r, challenge)
757759
if err != nil {
758760
h.r.Writer().WriteError(w, r, errors.WithStack(err))
759761
return
760762
}
761763

762-
p.ConsentRequestID = cr.ConsentRequestID
763-
p.RequestedAt = cr.RequestedAt
764-
p.HandledAt = sqlxx.NullTime(time.Now().UTC())
765-
766-
f, err := flow.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, flow.AsConsentChallenge)
767-
if err != nil {
768-
h.r.Writer().WriteError(w, r, err)
769-
return
770-
}
764+
payload.ConsentRequestID = f.ConsentRequestID.String()
765+
payload.RequestedAt = f.RequestedAt
766+
payload.HandledAt = sqlxx.NullTime(time.Now().UTC())
771767

772-
hr, err := h.r.ConsentManager().HandleConsentRequest(ctx, f, &p)
773-
if err != nil {
768+
if err := f.HandleConsentRequest(&payload); err != nil {
774769
h.r.Writer().WriteError(w, r, errors.WithStack(err))
775770
return
776-
} else if hr.Skip {
777-
p.Remember = false
771+
} else if f.ConsentSkip {
772+
payload.Remember = false
778773
}
779774

780-
ru, err := url.Parse(hr.RequestURL)
775+
ru, err := url.Parse(f.RequestURL)
781776
if err != nil {
782777
h.r.Writer().WriteError(w, r, err)
783778
return
@@ -789,8 +784,7 @@ func (h *Handler) acceptOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ
789784
return
790785
}
791786

792-
events.Trace(ctx, events.ConsentAccepted, events.WithClientID(cr.Client.GetID()), events.WithSubject(cr.Subject))
793-
787+
events.Trace(ctx, events.ConsentAccepted, events.WithClientID(f.Client.GetID()), events.WithSubject(f.Subject))
794788
h.r.Writer().Write(w, r, &flow.OAuth2RedirectTo{
795789
RedirectTo: urlx.SetQuery(ru, url.Values{"consent_verifier": {verifier}}).String(),
796790
})
@@ -855,41 +849,33 @@ func (h *Handler) rejectOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ
855849
return
856850
}
857851

858-
var p flow.RequestDeniedError
852+
var payload flow.RequestDeniedError
859853
d := json.NewDecoder(r.Body)
860854
d.DisallowUnknownFields()
861-
if err := d.Decode(&p); err != nil {
855+
if err := d.Decode(&payload); err != nil {
862856
h.r.Writer().WriteErrorCode(w, r, http.StatusBadRequest, errors.WithStack(err))
863857
return
864858
}
865859

866-
p.Valid = true
867-
p.SetDefaults(flow.ConsentRequestDeniedErrorName)
868-
hr, err := h.r.ConsentManager().GetConsentRequest(ctx, challenge)
860+
payload.Valid = true
861+
payload.SetDefaults(flow.ConsentRequestDeniedErrorName)
862+
f, err := flow.DecodeFromConsentChallenge(ctx, h.r, challenge)
869863
if err != nil {
870864
h.r.Writer().WriteError(w, r, errors.WithStack(err))
871865
return
872866
}
873867

874-
f, err := flow.Decode[flow.Flow](ctx, h.r.FlowCipher(), challenge, flow.AsConsentChallenge)
875-
if err != nil {
876-
h.r.Writer().WriteError(w, r, err)
877-
return
878-
}
879-
cr := f.GetConsentRequest(challenge)
880-
881-
request, err := h.r.ConsentManager().HandleConsentRequest(ctx, f, &flow.AcceptOAuth2ConsentRequest{
882-
Error: &p,
883-
ConsentRequestID: cr.ConsentRequestID,
884-
RequestedAt: hr.RequestedAt,
868+
if err := f.HandleConsentRequest(&flow.AcceptOAuth2ConsentRequest{
869+
Error: &payload,
870+
ConsentRequestID: f.ConsentRequestID.String(),
871+
RequestedAt: f.RequestedAt,
885872
HandledAt: sqlxx.NullTime(time.Now().UTC()),
886-
})
887-
if err != nil {
873+
}); err != nil {
888874
h.r.Writer().WriteError(w, r, errors.WithStack(err))
889875
return
890876
}
891877

892-
ru, err := url.Parse(request.RequestURL)
878+
ru, err := url.Parse(f.RequestURL)
893879
if err != nil {
894880
h.r.Writer().WriteError(w, r, err)
895881
return
@@ -901,7 +887,7 @@ func (h *Handler) rejectOAuth2ConsentRequest(w http.ResponseWriter, r *http.Requ
901887
return
902888
}
903889

904-
events.Trace(ctx, events.ConsentRejected, events.WithClientID(request.Client.GetID()), events.WithSubject(request.Subject))
890+
events.Trace(ctx, events.ConsentRejected, events.WithClientID(f.Client.GetID()), events.WithSubject(f.Subject))
905891

906892
h.r.Writer().Write(w, r, &flow.OAuth2RedirectTo{
907893
RedirectTo: urlx.SetQuery(ru, url.Values{"consent_verifier": {verifier}}).String(),

consent/manager.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@ func (ForcedObfuscatedLoginSession) TableName() string {
2626

2727
type (
2828
Manager interface {
29-
GetConsentRequest(ctx context.Context, challenge string) (*flow.OAuth2ConsentRequest, error)
30-
HandleConsentRequest(ctx context.Context, f *flow.Flow, r *flow.AcceptOAuth2ConsentRequest) (*flow.OAuth2ConsentRequest, error)
31-
3229
RevokeSubjectConsentSession(ctx context.Context, user string) error
3330
RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error
3431
RevokeConsentSessionByID(ctx context.Context, consentRequestID string) error

consent/test/manager_test_helpers.go

Lines changed: 12 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ import (
1111
"time"
1212

1313
"github.com/gofrs/uuid"
14+
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
16+
1417
"github.com/ory/fosite"
1518
"github.com/ory/fosite/handler/openid"
1619
"github.com/ory/hydra/v2/aead"
@@ -20,11 +23,8 @@ import (
2023
"github.com/ory/hydra/v2/flow"
2124
"github.com/ory/hydra/v2/oauth2"
2225
"github.com/ory/hydra/v2/x"
23-
"github.com/ory/x/assertx"
2426
"github.com/ory/x/contextx"
2527
"github.com/ory/x/sqlxx"
26-
"github.com/stretchr/testify/assert"
27-
"github.com/stretchr/testify/require"
2828
)
2929

3030
func MockConsentRequest(key string, remember bool, rememberFor int, hasError bool, skip bool, authAt bool, loginChallengeBase string, network string) (c *flow.OAuth2ConsentRequest, h *flow.AcceptOAuth2ConsentRequest, f *flow.Flow) {
@@ -239,8 +239,7 @@ func SaneMockHandleConsentRequest(t *testing.T, m consent.Manager, f *flow.Flow,
239239
HandledAt: sqlxx.NullTime(time.Now().UTC().Add(-time.Minute)),
240240
}
241241

242-
_, err := m.HandleConsentRequest(context.Background(), f, h)
243-
require.NoError(t, err)
242+
require.NoError(t, f.HandleConsentRequest(h))
244243

245244
return h
246245
}
@@ -584,26 +583,11 @@ func ManagerTests(deps Deps, m consent.Manager, clientManager client.Manager, fo
584583
_ = clientManager.CreateClient(ctx, consentRequest.Client) // Ignore errors that are caused by duplication
585584
f.NID = deps.Contextualizer().Network(context.Background(), uuid.Nil)
586585

587-
consentChallenge := makeID("challenge", network, tc.key)
588-
589-
_, err := m.GetConsentRequest(ctx, consentChallenge)
590-
require.Error(t, err)
591-
592-
consentChallenge = x.Must(f.ToConsentChallenge(ctx, deps))
593-
594-
got1, err := m.GetConsentRequest(ctx, consentChallenge)
595-
require.NoError(t, err)
596-
compareConsentRequest(t, consentRequest, got1)
597-
assert.False(t, got1.WasHandled)
598-
599-
got1, err = m.HandleConsentRequest(ctx, f, h)
600-
require.NoError(t, err)
601-
assertx.TimeDifferenceLess(t, time.Now(), time.Time(h.HandledAt), 5)
602-
compareConsentRequest(t, consentRequest, got1)
586+
require.NoError(t, f.HandleConsentRequest(h))
587+
assert.WithinDuration(t, time.Now(), time.Time(h.HandledAt), 5*time.Second)
603588

604589
h.GrantedAudience = sqlxx.StringSliceJSONFormat{"new-audience"}
605-
_, err = m.HandleConsentRequest(ctx, f, h)
606-
require.NoError(t, err)
590+
require.NoError(t, f.HandleConsentRequest(h))
607591

608592
consentVerifier := x.Must(f.ToConsentVerifier(ctx, deps))
609593

@@ -711,10 +695,8 @@ func ManagerTests(deps Deps, m consent.Manager, clientManager client.Manager, fo
711695
_ = clientManager.CreateClient(ctx, cr1.Client)
712696
_ = clientManager.CreateClient(ctx, cr2.Client)
713697

714-
_, err := m.HandleConsentRequest(ctx, f1, hcr1)
715-
require.NoError(t, err)
716-
_, err = m.HandleConsentRequest(ctx, f2, hcr2)
717-
require.NoError(t, err)
698+
require.NoError(t, f1.HandleConsentRequest(hcr1))
699+
require.NoError(t, f2.HandleConsentRequest(hcr2))
718700

719701
crr1, err := m.VerifyAndInvalidateConsentRequest(ctx, x.Must(f1.ToConsentVerifier(ctx, deps)))
720702
require.NoError(t, err)
@@ -778,13 +760,6 @@ func ManagerTests(deps Deps, m consent.Manager, clientManager client.Manager, fo
778760
require.NoError(t, m.RevokeSubjectClientConsentSession(ctx, tc.subject, tc.client))
779761
}
780762

781-
for _, id := range tc.ids {
782-
t.Run(fmt.Sprintf("id=%s", id), func(t *testing.T) {
783-
_, err := m.GetConsentRequest(ctx, id)
784-
assert.True(t, errors.Is(err, x.ErrNotFound))
785-
})
786-
}
787-
788763
r, err := fositeManager.GetAccessTokenSession(ctx, tc.at, nil)
789764
assert.Error(t, err, "%+v", r)
790765
r, err = fositeManager.GetRefreshTokenSession(ctx, tc.rt, nil)
@@ -823,10 +798,8 @@ func ManagerTests(deps Deps, m consent.Manager, clientManager client.Manager, fo
823798
flow.WithConsentCSRF(cr2.CSRF),
824799
)
825800

826-
_, err := m.HandleConsentRequest(ctx, f1, hcr1)
827-
require.NoError(t, err)
828-
_, err = m.HandleConsentRequest(ctx, f2, hcr2)
829-
require.NoError(t, err)
801+
require.NoError(t, f1.HandleConsentRequest(hcr1))
802+
require.NoError(t, f2.HandleConsentRequest(hcr2))
830803
handledConsentRequest1, err := m.VerifyAndInvalidateConsentRequest(ctx, x.Must(f1.ToConsentVerifier(ctx, deps)))
831804
require.NoError(t, err)
832805
handledConsentRequest2, err := m.VerifyAndInvalidateConsentRequest(ctx, x.Must(f2.ToConsentVerifier(ctx, deps)))
@@ -1119,46 +1092,8 @@ func ManagerTests(deps Deps, m consent.Manager, clientManager client.Manager, fo
11191092
require.NoError(t, m.CreateLoginSession(ctx, &s))
11201093
require.NoError(t, m.ConfirmLoginSession(ctx, &s))
11211094

1122-
f := &flow.Flow{
1123-
ID: uuid.Must(uuid.NewV4()).String(),
1124-
Subject: uuid.Must(uuid.NewV4()).String(),
1125-
LoginVerifier: uuid.Must(uuid.NewV4()).String(),
1126-
Client: cl,
1127-
LoginAuthenticatedAt: sqlxx.NullTime(time.Now()),
1128-
RequestedAt: time.Now(),
1129-
SessionID: sqlxx.NullString(s.ID),
1130-
NID: deps.Contextualizer().Network(ctx, uuid.Nil),
1131-
}
1132-
1133-
expected := &flow.OAuth2ConsentRequest{
1134-
ConsentRequestID: uuid.Must(uuid.NewV4()).String(),
1135-
Skip: true,
1136-
Subject: subject,
1137-
OpenIDConnectContext: nil,
1138-
Client: cl,
1139-
ClientID: cl.ID,
1140-
RequestURL: "",
1141-
LoginChallenge: sqlxx.NullString(f.ID),
1142-
LoginSessionID: sqlxx.NullString(s.ID),
1143-
Verifier: uuid.Must(uuid.NewV4()).String(),
1144-
CSRF: uuid.Must(uuid.NewV4()).String(),
1145-
}
1146-
1147-
f.ConsentRequestID = sqlxx.NullString(expected.ConsentRequestID)
1148-
1149-
consentChallenge, err := f.ToConsentChallenge(ctx, deps)
1150-
require.NoError(t, err)
1151-
1152-
result, err := m.GetConsentRequest(ctx, consentChallenge)
1153-
require.NoError(t, err)
1154-
assert.EqualValues(t, expected.ConsentRequestID, result.ConsentRequestID)
1155-
1156-
_, err = m.DeleteLoginSession(ctx, s.ID)
1157-
require.NoError(t, err)
1158-
1159-
result, err = m.GetConsentRequest(ctx, consentChallenge)
1095+
_, err := m.DeleteLoginSession(ctx, s.ID)
11601096
require.NoError(t, err)
1161-
assert.EqualValues(t, expected.ConsentRequestID, result.ConsentRequestID)
11621097
})
11631098
}
11641099
}

0 commit comments

Comments
 (0)