Skip to content

Commit bea6b4d

Browse files
zepatrikory-bot
authored andcommitted
feat(hydra): split up persister
GitOrigin-RevId: 203cf926c1613fcbb20393c5b7d0af25c7aecb15
1 parent 4929a8d commit bea6b4d

38 files changed

+436
-543
lines changed

.reports/dep-licenses.csv

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,4 @@
33
"github.com/ory/x","Apache-2.0"
44
"github.com/stretchr/testify","MIT"
55
"go.opentelemetry.io/otel/sdk","Apache-2.0"
6-
"go.opentelemetry.io/otel/sdk","BSD-3-Clause"
76

aead/aesgcm.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func (c *AESGCM) Decrypt(ctx context.Context, ciphertext string, aad []byte) (pl
6262
return nil, err
6363
}
6464

65-
func (c *AESGCM) decrypt(ciphertext []byte, key, additionalData []byte) ([]byte, error) {
65+
func (*AESGCM) decrypt(ciphertext, key, additionalData []byte) ([]byte, error) {
6666
if len(key) != 32 {
6767
return nil, errors.Errorf("key must be exactly 32 long bytes, got %d bytes", len(key))
6868
}

cmd/cli/handler_migrate.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010

1111
"github.com/ory/hydra/v2/driver"
1212
"github.com/ory/hydra/v2/driver/config"
13-
"github.com/ory/hydra/v2/persistence"
13+
"github.com/ory/hydra/v2/persistence/sql"
1414
"github.com/ory/x/cmdx"
1515
"github.com/ory/x/configx"
1616
"github.com/ory/x/popx"
@@ -26,7 +26,7 @@ func newMigrateHandler(dOpts []driver.OptionsModifier) *MigrateHandler {
2626
}
2727
}
2828

29-
func (h *MigrateHandler) makePersister(cmd *cobra.Command, args []string) (p persistence.Persister, err error) {
29+
func (h *MigrateHandler) makeMigrationManager(cmd *cobra.Command, args []string) (*sql.MigrationManager, error) {
3030
opts := append([]driver.OptionsModifier{
3131
driver.WithConfigOptions(
3232
configx.SkipValidation(),
@@ -52,27 +52,27 @@ func (h *MigrateHandler) makePersister(cmd *cobra.Command, args []string) (p per
5252
return nil, cmdx.FailSilently(cmd)
5353
}
5454

55-
return d.Persister(), nil
55+
return d.Migrator(), nil
5656
}
5757

5858
func (h *MigrateHandler) MigrateSQLUp(cmd *cobra.Command, args []string) (err error) {
59-
p, err := h.makePersister(cmd, args)
59+
p, err := h.makeMigrationManager(cmd, args)
6060
if err != nil {
6161
return err
6262
}
6363
return popx.MigrateSQLUp(cmd, p)
6464
}
6565

6666
func (h *MigrateHandler) MigrateSQLDown(cmd *cobra.Command, args []string) (err error) {
67-
p, err := h.makePersister(cmd, args)
67+
p, err := h.makeMigrationManager(cmd, args)
6868
if err != nil {
6969
return err
7070
}
7171
return popx.MigrateSQLDown(cmd, p)
7272
}
7373

7474
func (h *MigrateHandler) MigrateStatus(cmd *cobra.Command, args []string) error {
75-
p, err := h.makePersister(cmd, args)
75+
p, err := h.makeMigrationManager(cmd, args)
7676
if err != nil {
7777
return err
7878
}

cmd/server/helper_cert.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"sync"
1414

1515
"github.com/go-jose/go-jose/v3"
16-
"github.com/gofrs/uuid"
1716
"github.com/pkg/errors"
1817

1918
"github.com/ory/hydra/v2/driver"
@@ -54,7 +53,7 @@ func GetOrCreateTLSCertificate(ctx context.Context, d *driver.RegistrySQL, tlsCo
5453
}
5554

5655
// no certificates configured: self-sign a new cert
57-
priv, err := jwk.GetOrGenerateKeys(ctx, d, d.SoftwareKeyManager(), TlsKeyName, uuid.Must(uuid.NewV4()).String(), "RS256")
56+
priv, err := jwk.GetOrGenerateKeys(ctx, d, d.KeyManager(), TlsKeyName, "RS256")
5857
if err != nil {
5958
d.Logger().WithError(err).Fatal("Unable to fetch or generate HTTPS TLS key pair")
6059
return nil // in case Fatal is hooked
@@ -68,12 +67,12 @@ func GetOrCreateTLSCertificate(ctx context.Context, d *driver.RegistrySQL, tlsCo
6867
}
6968

7069
AttachCertificate(priv, cert)
71-
if err := d.SoftwareKeyManager().DeleteKey(ctx, TlsKeyName, priv.KeyID); err != nil {
70+
if err := d.KeyManager().DeleteKey(ctx, TlsKeyName, priv.KeyID); err != nil {
7271
d.Logger().WithError(err).Fatal(`Could not update (delete) the self signed TLS certificate`)
7372
return nil // in case Fatal is hooked
7473
}
7574

76-
if err := d.SoftwareKeyManager().AddKey(ctx, TlsKeyName, priv); err != nil {
75+
if err := d.KeyManager().AddKey(ctx, TlsKeyName, priv); err != nil {
7776
d.Logger().WithError(err).Fatalf(`Could not update (add) the self signed TLS certificate: %s %x %d`, cert.SignatureAlgorithm, cert.Signature, len(cert.Signature))
7877
return nil // in case Fatalf is hooked
7978
}

driver/registry.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ import (
2525
)
2626

2727
type registry interface {
28-
dbal.Driver
29-
3028
x.HTTPClientProvider
3129

3230
contextx.Provider

driver/registry_sql.go

Lines changed: 42 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ type RegistrySQL struct {
6464
cv *client.Validator
6565
ctxer contextx.Contextualizer
6666
hh *healthx.Handler
67-
migrationStatus *popx.MigrationStatuses
6867
kc *aead.AESGCM
6968
flowc *aead.XChaCha20Poly1305
7069
cos consent.Strategy
@@ -76,7 +75,7 @@ type RegistrySQL struct {
7675
trc *otelx.Tracer
7776
tracerWrapper func(*otelx.Tracer) *otelx.Tracer
7877
arhs []oauth2.AccessRequestHook
79-
persister persistence.Persister
78+
basePersister *sql.BasePersister
8079
oc fosite.Configurator
8180
oidcs jwk.JWTSigner
8281
ats jwk.JWTSigner
@@ -87,10 +86,11 @@ type RegistrySQL struct {
8786
publicCORS *cors.Cors
8887
kratos kratos.Client
8988
fositeFactories []fositex.Factory
89+
migrator *sql.MigrationManager
9090

91-
defaultKeyManager jwk.Manager
92-
initialPing func(r *RegistrySQL) error
93-
middlewares []negroni.Handler
91+
keyManager jwk.Manager
92+
initialPing func(ctx context.Context, l *logrusx.Logger, p *sql.BasePersister) error
93+
middlewares []negroni.Handler
9494
}
9595

9696
var (
@@ -101,12 +101,10 @@ var (
101101
// defaultInitialPing is the default function that will be called within RegistrySQL.Init to make sure
102102
// the database is reachable. It can be injected for test purposes by changing the value
103103
// of RegistrySQL.initialPing.
104-
func defaultInitialPing(m *RegistrySQL) error {
105-
if err := resilience.Retry(m.l, 5*time.Second, 5*time.Minute, m.Ping); err != nil {
106-
m.Logger().Print("Could not ping database: ", err)
107-
return errors.WithStack(err)
108-
}
109-
return nil
104+
func defaultInitialPing(ctx context.Context, l *logrusx.Logger, p *sql.BasePersister) error {
105+
return errors.WithStack(resilience.Retry(l, 5*time.Second, 5*time.Minute, func() error {
106+
return p.Ping(ctx)
107+
}))
110108
}
111109

112110
func (m *RegistrySQL) Init(
@@ -116,7 +114,7 @@ func (m *RegistrySQL) Init(
116114
extraMigrations []fs.FS,
117115
goMigrations []popx.Migration,
118116
) error {
119-
if m.persister == nil {
117+
if m.basePersister == nil {
120118
if m.Config().CGroupsV1AutoMaxProcsEnabled() {
121119
_, err := maxprocs.Set(maxprocs.Logger(m.Logger().Infof))
122120
if err != nil {
@@ -146,59 +144,40 @@ func (m *RegistrySQL) Init(
146144
return errors.WithStack(err)
147145
}
148146

149-
p, err := sql.NewPersister(c, m, m.Config(), extraMigrations, goMigrations)
150-
if err != nil {
151-
return err
152-
}
153-
m.persister = p
154-
if err := m.initialPing(m); err != nil {
147+
m.basePersister = sql.NewBasePersister(c, m)
148+
if err := m.initialPing(ctx, m.Logger(), m.basePersister); err != nil {
149+
m.Logger().Print("Could not ping database: ", err)
155150
return err
156151
}
157152

158-
if m.Config().HSMEnabled() {
159-
hardwareKeyManager := hsm.NewKeyManager(m.HSMContext(), m.Config())
160-
m.defaultKeyManager = jwk.NewManagerStrategy(hardwareKeyManager, m.persister)
161-
} else {
162-
m.defaultKeyManager = m.persister
163-
}
153+
m.migrator = sql.NewMigrationManager(c, m, extraMigrations, goMigrations)
164154

165155
// if dsn is memory we have to run the migrations on every start
166156
// use case - such as
167157
// - just in memory
168158
// - shared connection
169159
// - shared but unique in the same process
170160
// see: https://sqlite.org/inmemorydb.html
171-
if dbal.IsMemorySQLite(m.Config().DSN()) {
172-
m.Logger().Print("Hydra is running migrations on every startup as DSN is memory.\n")
173-
m.Logger().Print("This means your data is lost when Hydra terminates.\n")
174-
if err := p.MigrateUp(context.Background()); err != nil {
175-
return err
176-
}
177-
} else if migrate {
178-
if err := p.MigrateUp(context.Background()); err != nil {
161+
switch {
162+
case dbal.IsMemorySQLite(m.Config().DSN()):
163+
m.Logger().Println("Hydra is running migrations on every startup as DSN is memory.")
164+
m.Logger().Println("This means your data is lost when Hydra terminates.")
165+
fallthrough
166+
case migrate:
167+
if err := m.migrator.MigrateUp(ctx); err != nil {
179168
return err
180169
}
181170
}
182171

183-
if skipNetworkInit {
184-
m.persister = p
185-
} else {
186-
net, err := p.DetermineNetwork(ctx)
172+
if !skipNetworkInit {
173+
net, err := m.basePersister.DetermineNetwork(ctx)
187174
if err != nil {
188175
m.Logger().WithError(err).Warnf("Unable to determine network, retrying.")
189176
return err
190177
}
191178

192-
m.persister = p.WithFallbackNetworkID(net.ID)
193-
}
194-
195-
if m.Config().HSMEnabled() {
196-
hardwareKeyManager := hsm.NewKeyManager(m.HSMContext(), m.Config())
197-
m.defaultKeyManager = jwk.NewManagerStrategy(hardwareKeyManager, m.persister)
198-
} else {
199-
m.defaultKeyManager = m.persister
179+
m.basePersister = m.basePersister.WithFallbackNetworkID(net.ID)
200180
}
201-
202181
}
203182

204183
return nil
@@ -211,11 +190,7 @@ func (m *RegistrySQL) alwaysCanHandle(dsn string) bool {
211190
}
212191

213192
func (m *RegistrySQL) PingContext(ctx context.Context) error {
214-
return m.Persister().Ping(ctx)
215-
}
216-
217-
func (m *RegistrySQL) Ping() error {
218-
return m.PingContext(context.Background())
193+
return m.basePersister.Ping(ctx)
219194
}
220195

221196
func (m *RegistrySQL) ClientManager() client.Manager {
@@ -231,11 +206,16 @@ func (m *RegistrySQL) OAuth2Storage() x.FositeStorer {
231206
}
232207

233208
func (m *RegistrySQL) KeyManager() jwk.Manager {
234-
return m.defaultKeyManager
235-
}
236-
237-
func (m *RegistrySQL) SoftwareKeyManager() jwk.Manager {
238-
return m.Persister()
209+
if m.keyManager == nil {
210+
softwareKeyManager := &sql.JWKPersister{BasePersister: m.basePersister}
211+
if m.Config().HSMEnabled() {
212+
hardwareKeyManager := hsm.NewKeyManager(m.HSMContext(), m.Config())
213+
m.keyManager = jwk.NewManagerStrategy(hardwareKeyManager, softwareKeyManager)
214+
} else {
215+
m.keyManager = softwareKeyManager
216+
}
217+
}
218+
return m.keyManager
239219
}
240220

241221
func (m *RegistrySQL) GrantManager() trust.GrantManager {
@@ -330,11 +310,7 @@ func (m *RegistrySQL) HealthHandler() *healthx.Handler {
330310
return m.PingContext(r.Context())
331311
},
332312
"migrations": func(r *http.Request) error {
333-
if m.migrationStatus != nil && !m.migrationStatus.HasPending() {
334-
return nil
335-
}
336-
337-
status, err := m.Persister().MigrationStatus(r.Context())
313+
status, err := m.migrator.MigrationStatus(r.Context())
338314
if err != nil {
339315
return err
340316
}
@@ -344,8 +320,6 @@ func (m *RegistrySQL) HealthHandler() *healthx.Handler {
344320
m.Logger().WithField("status", fmt.Sprintf("%+v", status)).WithError(err).Warn("Instance is not yet ready because migrations have not yet been fully applied.")
345321
return err
346322
}
347-
348-
m.migrationStatus = &status
349323
return nil
350324
},
351325
})
@@ -529,7 +503,7 @@ func (m *RegistrySQL) OpenIDConnectRequestValidator() *openid.OpenIDConnectReque
529503
}
530504

531505
func (m *RegistrySQL) Networker() x.Networker {
532-
return m.persister
506+
return m.basePersister
533507
}
534508

535509
func (m *RegistrySQL) SubjectIdentifierAlgorithm(ctx context.Context) map[string]consent.SubjectIdentifierAlgorithm {
@@ -569,7 +543,7 @@ func (m *RegistrySQL) Tracer(_ context.Context) *otelx.Tracer {
569543
}
570544

571545
func (m *RegistrySQL) Persister() persistence.Persister {
572-
return m.persister
546+
return sql.NewPersister(m.basePersister, m)
573547
}
574548

575549
// Config returns the configuration for the given context. It may or may not be the same as the global configuration.
@@ -609,3 +583,7 @@ func (m *RegistrySQL) Kratos() kratos.Client {
609583
func (m *RegistrySQL) HTTPMiddlewares() []negroni.Handler {
610584
return m.middlewares
611585
}
586+
587+
func (m *RegistrySQL) Migrator() *sql.MigrationManager {
588+
return m.migrator
589+
}

driver/registry_sql_test.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ func TestRegistrySQL_newKeyStrategy_handlesNetworkError(t *testing.T) {
6464

6565
l := logrusx.New("", "", logrusx.WithHook(&hook))
6666
l.Logrus().SetOutput(io.Discard)
67-
l.Logrus().ExitFunc = func(int) {} // Override the exit func to avoid call to os.Exit
6867

6968
// Create a config and set a valid but unresolvable DSN
7069
c := config.MustNew(t, l,
@@ -88,7 +87,7 @@ func TestRegistrySQL_newKeyStrategy_handlesNetworkError(t *testing.T) {
8887
"snizzles",
8988
)
9089

91-
assert.Equal(t, logrus.FatalLevel, hook.LastEntry().Level)
90+
assert.Equal(t, logrus.InfoLevel, hook.LastEntry().Level)
9291
assert.Contains(t, hook.LastEntry().Message, "snizzles")
9392
}
9493

@@ -158,8 +157,7 @@ func TestDefaultKeyManager_HsmDisabled(t *testing.T) {
158157
}),
159158
)
160159
require.NoError(t, err)
161-
assert.IsType(t, &sql.Persister{}, r.KeyManager())
162-
assert.IsType(t, &sql.Persister{}, r.SoftwareKeyManager())
160+
assert.IsType(t, &sql.JWKPersister{}, r.KeyManager())
163161
}
164162

165163
func TestDbUnknownTableColumns(t *testing.T) {
@@ -199,14 +197,13 @@ func TestDbUnknownTableColumns(t *testing.T) {
199197
})
200198
}
201199

202-
func sussessfulPing(r *RegistrySQL) error {
200+
func sussessfulPing(context.Context, *logrusx.Logger, *sql.BasePersister) error {
203201
// fake that ping is successful
204202
return nil
205203
}
206204

207-
func failedPing(err error) func(r *RegistrySQL) error {
208-
return func(r *RegistrySQL) error {
209-
r.Logger().Fatal(err.Error())
205+
func failedPing(err error) func(context.Context, *logrusx.Logger, *sql.BasePersister) error {
206+
return func(context.Context, *logrusx.Logger, *sql.BasePersister) error {
210207
return pkgerr.WithStack(err)
211208
}
212209
}

0 commit comments

Comments
 (0)