Skip to content

Commit 5872843

Browse files
committed
Address comments to use context for ending goroutines
Signed-off-by: Guilherme Carvalho <[email protected]>
1 parent 85abb9a commit 5872843

File tree

6 files changed

+52
-71
lines changed

6 files changed

+52
-71
lines changed

support/oidc-discovery-provider/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ workload_api {
177177
```hcl
178178
log_level = "debug"
179179
domains = ["mypublicdomain.test"]
180-
serving_certificate {
180+
serving_cert_file {
181181
cert_file_path = "/some/path/on/disk/to/cert.pem"
182182
key_file_path = "/some/path/on/disk/to/key.pem"
183183
}
@@ -191,7 +191,7 @@ server_api {
191191
```hcl
192192
log_level = "debug"
193193
domains = ["mypublicdomain.test"]
194-
serving_certificate {
194+
serving_cert_file {
195195
cert_file_path = "/some/path/on/disk/to/cert.pem"
196196
key_file_path = "/some/path/on/disk/to/key.pem"
197197
}

support/oidc-discovery-provider/cert_manager.go

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
package main
22

33
import (
4+
"context"
45
"crypto/tls"
56
"crypto/x509"
67
"errors"
78
"fmt"
89
"io/fs"
910
"os"
10-
"strings"
1111
"sync"
1212
"time"
1313

1414
"github.com/sirupsen/logrus"
15-
"golang.org/x/net/idna"
1615
)
1716

1817
// DiskCertManager is a certificate manager that loads certificates from disk, and watches for changes.
@@ -43,8 +42,6 @@ func NewDiskCertManager(config *ServingCertFileConfig, log logrus.FieldLogger) (
4342
return nil, fmt.Errorf("failed to load certificate: %w", err)
4443
}
4544

46-
go dm.watchFileChanges()
47-
4845
return dm, nil
4946
}
5047

@@ -61,46 +58,27 @@ func (m *DiskCertManager) TLSConfig() *tls.Config {
6158

6259
// getCertificate is called by the TLS stack when a new TLS connection is established.
6360
func (m *DiskCertManager) getCertificate(chInfo *tls.ClientHelloInfo) (*tls.Certificate, error) {
64-
name := chInfo.ServerName
65-
if name == "" {
66-
return nil, errors.New("missing server name")
67-
}
68-
if !strings.Contains(strings.Trim(name, "."), ".") {
69-
return nil, errors.New("server name component count invalid")
70-
}
71-
72-
// Note that this conversion is necessary because some server names in the handshakes
73-
// started by some clients (such as cURL) are not converted to Punycode, which will
74-
// prevent us from obtaining certificates for them. In addition, we should also treat
75-
// example.com and EXAMPLE.COM as equivalent and return the same certificate for them.
76-
// Fortunately, this conversion also helped us deal with this kind of mixedcase problems.
77-
//
78-
// Due to the "σςΣ" problem (see https://unicode.org/faq/idn.html#22), we can't use
79-
// idna.Punycode.ToASCII (or just idna.ToASCII) here.
80-
name, err := idna.Lookup.ToASCII(name)
81-
if err != nil {
82-
return nil, errors.New("server name contains invalid character")
83-
}
84-
8561
m.certMtx.RLock()
8662
defer m.certMtx.RUnlock()
8763
cert := m.cert
8864

89-
// Verify that the certificate is valid for the requested server name.
90-
if name != cert.Leaf.Subject.CommonName {
91-
if err := cert.Leaf.VerifyHostname(name); err != nil {
92-
return nil, fmt.Errorf("server name mismatch: %w", err)
93-
}
94-
}
95-
9665
return cert, nil
9766
}
9867

99-
// watchFileChanges starts a file watcher to watch for changes to the cert and key files.
100-
func (m *DiskCertManager) watchFileChanges() {
68+
// WatchFileChanges starts a file watcher to watch for changes to the cert and key files.
69+
func (m *DiskCertManager) WatchFileChanges(ctx context.Context) {
70+
m.log.WithField("interval", m.fileSyncInterval).Info("Started watching certificate files")
71+
10172
ticker := time.NewTicker(m.fileSyncInterval)
102-
for range ticker.C {
103-
m.syncCertificateFiles()
73+
defer ticker.Stop()
74+
for {
75+
select {
76+
case <-ctx.Done():
77+
m.log.Info("Stopping file watcher")
78+
return
79+
case <-ticker.C:
80+
m.syncCertificateFiles()
81+
}
10482
}
10583
}
10684

support/oidc-discovery-provider/cert_manager_test_posix.go renamed to support/oidc-discovery-provider/cert_manager_posix_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ var (
2020
func writeFile(t *testing.T, name string, data []byte) {
2121
err := os.WriteFile(name, data, 0600)
2222
require.NoError(t, err)
23-
_, err = os.Stat(name)
24-
require.NoError(t, err)
2523
}
2624

2725
func removeFile(t *testing.T, name string) {

support/oidc-discovery-provider/cert_manager_test.go

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package main
22

33
import (
4+
"context"
45
"crypto/tls"
56
"crypto/x509"
67
"crypto/x509/pkix"
@@ -82,13 +83,16 @@ func TestTLSConfig(t *testing.T) {
8283
ServerName: "oidc-provider-discovery.example.com",
8384
}
8485

86+
ctx, cancelFn := context.WithCancel(context.Background())
8587
certManager, err := NewDiskCertManager(&ServingCertFileConfig{
8688
CertFilePath: tmpDir + certFilePath,
8789
KeyFilePath: tmpDir + keyFilePath,
8890
FileSyncInterval: 10 * time.Millisecond,
8991
}, logger)
9092
require.NoError(t, err)
9193

94+
certManager.WatchFileChanges(ctx)
95+
9296
tlsConfig := certManager.TLSConfig()
9397

9498
t.Run("error when configuration does not contain serving cert file settings", func(t *testing.T) {
@@ -150,32 +154,6 @@ func TestTLSConfig(t *testing.T) {
150154
require.EqualError(t, err, "failed to load certificate: tls: failed to find any PEM data in key input")
151155
})
152156

153-
t.Run("error when client misses server name", func(t *testing.T) {
154-
_, err := tlsConfig.GetCertificate(&tls.ClientHelloInfo{})
155-
require.EqualError(t, err, "missing server name")
156-
})
157-
158-
t.Run("error when client send server name with invalid character", func(t *testing.T) {
159-
_, err := tlsConfig.GetCertificate(&tls.ClientHelloInfo{
160-
ServerName: "example.com:8080",
161-
})
162-
require.EqualError(t, err, "server name contains invalid character")
163-
})
164-
165-
t.Run("error when client send server name with invalid component count", func(t *testing.T) {
166-
_, err := tlsConfig.GetCertificate(&tls.ClientHelloInfo{
167-
ServerName: "example",
168-
})
169-
require.EqualError(t, err, "server name component count invalid")
170-
})
171-
172-
t.Run("error when client send wrong server name", func(t *testing.T) {
173-
_, err := tlsConfig.GetCertificate(&tls.ClientHelloInfo{
174-
ServerName: "example.com",
175-
})
176-
require.EqualError(t, err, `server name mismatch: x509: certificate is not valid for any names, but wanted to match example.com`)
177-
})
178-
179157
t.Run("success loading initial certificate from disk", func(t *testing.T) {
180158
cert, err := tlsConfig.GetCertificate(chInfo)
181159
require.NoError(t, err)
@@ -398,4 +376,13 @@ func TestTLSConfig(t *testing.T) {
398376
return reflect.DeepEqual(oidcServerCertUpdated3, x509Cert)
399377
}, 10*time.Second, 100*time.Millisecond, "Failed to assert updated certificate")
400378
})
379+
380+
t.Run("stop file watcher when context is canceled", func(t *testing.T) {
381+
cancelFn()
382+
383+
require.Eventuallyf(t, func() bool {
384+
lastEntry := logHook.LastEntry()
385+
return lastEntry.Level == logrus.InfoLevel && lastEntry.Message == "Stopping file watcher"
386+
}, 10*time.Second, 10*time.Millisecond, "Failed to assert file watcher stop log")
387+
})
401388
}

support/oidc-discovery-provider/main.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
package main
22

33
import (
4+
"context"
45
"crypto/tls"
56
"flag"
67
"fmt"
78
"net"
89
"net/http"
910
"os"
11+
"os/signal"
12+
"syscall"
1013
"time"
1114

1215
"github.com/sirupsen/logrus"
@@ -49,6 +52,9 @@ func run(configPath string) error {
4952
}
5053
defer log.Close()
5154

55+
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
56+
defer stop()
57+
5258
source, err := newSource(log, config)
5359
if err != nil {
5460
return err
@@ -66,7 +72,7 @@ func run(configPath string) error {
6672
handler = logHandler(log, handler)
6773
}
6874

69-
listener, err := buildNetListener(config, log)
75+
listener, err := buildNetListener(ctx, config, log)
7076
if err != nil {
7177
return err
7278
}
@@ -91,10 +97,19 @@ func run(configPath string) error {
9197
Handler: handler,
9298
ReadHeaderTimeout: 10 * time.Second,
9399
}
100+
101+
go func() {
102+
<-ctx.Done()
103+
err = server.Shutdown(context.Background())
104+
if err != nil {
105+
log.Error(err)
106+
}
107+
}()
108+
94109
return server.Serve(listener)
95110
}
96111

97-
func buildNetListener(config *Config, log *log.Logger) (listener net.Listener, err error) {
112+
func buildNetListener(ctx context.Context, config *Config, log *log.Logger) (listener net.Listener, err error) {
98113
switch {
99114
case config.InsecureAddr != "":
100115
listener, err = net.Listen("tcp", config.InsecureAddr)
@@ -112,7 +127,7 @@ func buildNetListener(config *Config, log *log.Logger) (listener net.Listener, e
112127
telemetry.Address: listener.Addr().String(),
113128
}).Info("Serving HTTP")
114129
case config.ServingCertFile != nil:
115-
listener, err = newListenerWithServingCert(log, config)
130+
listener, err = newListenerWithServingCert(ctx, log, config)
116131
if err != nil {
117132
return nil, err
118133
}
@@ -152,15 +167,18 @@ func newSource(log logrus.FieldLogger, config *Config) (JWKSSource, error) {
152167
}
153168
}
154169

155-
func newListenerWithServingCert(log logrus.FieldLogger, config *Config) (net.Listener, error) {
170+
func newListenerWithServingCert(ctx context.Context, log logrus.FieldLogger, config *Config) (net.Listener, error) {
156171
certManager, err := NewDiskCertManager(config.ServingCertFile, log)
157172
if err != nil {
158173
return nil, err
159174
}
175+
go func() {
176+
certManager.WatchFileChanges(ctx)
177+
}()
160178

161179
tlsConfig := certManager.TLSConfig()
162180

163-
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: 443})
181+
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: 8080})
164182
if err != nil {
165183
return nil, fmt.Errorf("failed to create listener using certificate from disk: %w", err)
166184
}

0 commit comments

Comments
 (0)