Skip to content

Commit d74921d

Browse files
make it possible to use a custom tls.Config for listening and dialing (#22)
1 parent 2823159 commit d74921d

File tree

3 files changed

+184
-56
lines changed

3 files changed

+184
-56
lines changed

p2p/transport/webtransport/listener.go

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@ const queueLen = 16
2727
const handshakeTimeout = 10 * time.Second
2828

2929
type listener struct {
30-
transport tpt.Transport
31-
noise *noise.Transport
32-
certManager *certManager
33-
rcmgr network.ResourceManager
34-
gater connmgr.ConnectionGater
30+
transport tpt.Transport
31+
noise *noise.Transport
32+
certManager *certManager
33+
staticTLSConf *tls.Config
34+
35+
rcmgr network.ResourceManager
36+
gater connmgr.ConnectionGater
3537

3638
server webtransport.Server
3739

@@ -48,7 +50,7 @@ type listener struct {
4850

4951
var _ tpt.Listener = &listener{}
5052

51-
func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Listener, error) {
53+
func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, tlsConf *tls.Config, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Listener, error) {
5254
network, addr, err := manet.DialArgs(laddr)
5355
if err != nil {
5456
return nil, err
@@ -65,23 +67,23 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans
6567
if err != nil {
6668
return nil, err
6769
}
70+
if tlsConf == nil {
71+
tlsConf = &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
72+
return certManager.GetConfig(), nil
73+
}}
74+
}
6875
ln := &listener{
69-
transport: transport,
70-
noise: noise,
71-
certManager: certManager,
72-
rcmgr: rcmgr,
73-
gater: gater,
74-
queue: make(chan tpt.CapableConn, queueLen),
75-
serverClosed: make(chan struct{}),
76-
addr: udpConn.LocalAddr(),
77-
multiaddr: localMultiaddr,
78-
server: webtransport.Server{
79-
H3: http3.Server{
80-
TLSConfig: &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
81-
return certManager.GetConfig(), nil
82-
}},
83-
},
84-
},
76+
transport: transport,
77+
noise: noise,
78+
certManager: certManager,
79+
staticTLSConf: tlsConf,
80+
rcmgr: rcmgr,
81+
gater: gater,
82+
queue: make(chan tpt.CapableConn, queueLen),
83+
serverClosed: make(chan struct{}),
84+
addr: udpConn.LocalAddr(),
85+
multiaddr: localMultiaddr,
86+
server: webtransport.Server{H3: http3.Server{TLSConfig: tlsConf}},
8587
}
8688
ln.ctx, ln.ctxCancel = context.WithCancel(context.Background())
8789
mux := http.NewServeMux()
@@ -198,6 +200,9 @@ func (l *listener) Addr() net.Addr {
198200
}
199201

200202
func (l *listener) Multiaddr() ma.Multiaddr {
203+
if l.certManager == nil {
204+
return l.multiaddr
205+
}
201206
return l.multiaddr.Encapsulate(l.certManager.AddrComponent())
202207
}
203208

p2p/transport/webtransport/transport.go

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"crypto/tls"
66
"crypto/x509"
7+
"errors"
78
"fmt"
89
"io"
910
"sync"
@@ -43,6 +44,27 @@ func WithClock(cl clock.Clock) Option {
4344
}
4445
}
4546

47+
// WithTLSConfig sets a tls.Config used for listening.
48+
// When used, the certificate from that config will be used, and no /certhash will be added to the listener's multiaddr.
49+
// This is most useful when running a listener that has a valid (CA-signed) certificate.
50+
func WithTLSConfig(c *tls.Config) Option {
51+
return func(t *transport) error {
52+
t.staticTLSConf = c
53+
return nil
54+
}
55+
}
56+
57+
// WithTLSClientConfig sets a custom tls.Config used for dialing.
58+
// This option is most useful for setting a custom tls.Config.RootCAs certificate pool.
59+
// When dialing a multiaddr that contains a /certhash component, this library will set InsecureSkipVerify and
60+
// overwrite the VerifyPeerCertificate callback.
61+
func WithTLSClientConfig(c *tls.Config) Option {
62+
return func(t *transport) error {
63+
t.tlsClientConf = c
64+
return nil
65+
}
66+
}
67+
4668
type transport struct {
4769
privKey ic.PrivKey
4870
pid peer.ID
@@ -54,6 +76,8 @@ type transport struct {
5476
listenOnce sync.Once
5577
listenOnceErr error
5678
certManager *certManager
79+
staticTLSConf *tls.Config
80+
tlsClientConf *tls.Config
5781

5882
noise *noise.Transport
5983
}
@@ -129,15 +153,21 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
129153

130154
func (t *transport) dial(ctx context.Context, addr string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) {
131155
url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint)
156+
var tlsConf *tls.Config
157+
if t.tlsClientConf != nil {
158+
tlsConf = t.tlsClientConf.Clone()
159+
} else {
160+
tlsConf = &tls.Config{}
161+
}
162+
163+
if len(certHashes) > 0 {
164+
tlsConf.InsecureSkipVerify = true // this is not insecure. We verify the certificate ourselves.
165+
tlsConf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
166+
return verifyRawCerts(rawCerts, certHashes)
167+
}
168+
}
132169
dialer := webtransport.Dialer{
133-
RoundTripper: &http3.RoundTripper{
134-
TLSClientConfig: &tls.Config{
135-
InsecureSkipVerify: true, // this is not insecure. We verify the certificate ourselves.
136-
VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
137-
return verifyRawCerts(rawCerts, certHashes)
138-
},
139-
},
140-
},
170+
RoundTripper: &http3.RoundTripper{TLSClientConfig: tlsConf},
141171
}
142172
rsp, sess, err := dialer.Dial(ctx, url, nil)
143173
if err != nil {
@@ -193,6 +223,14 @@ func (t *transport) checkEarlyData(b []byte) error {
193223
return fmt.Errorf("failed to unmarshal early data protobuf: %w", err)
194224
}
195225
hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes))
226+
227+
if t.staticTLSConf != nil {
228+
if len(hashes) > 0 {
229+
return errors.New("using static TLS config, didn't expect any certificate hashes")
230+
}
231+
return nil
232+
}
233+
196234
for _, h := range msg.CertHashes {
197235
dh, err := multihash.Decode(h)
198236
if err != nil {
@@ -224,13 +262,15 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) {
224262
if !webtransportMatcher.Matches(laddr) {
225263
return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr)
226264
}
227-
t.listenOnce.Do(func() {
228-
t.certManager, t.listenOnceErr = newCertManager(t.clock)
229-
})
230-
if t.listenOnceErr != nil {
231-
return nil, t.listenOnceErr
265+
if t.staticTLSConf == nil {
266+
t.listenOnce.Do(func() {
267+
t.certManager, t.listenOnceErr = newCertManager(t.clock)
268+
})
269+
if t.listenOnceErr != nil {
270+
return nil, t.listenOnceErr
271+
}
232272
}
233-
return newListener(laddr, t, t.noise, t.certManager, t.gater, t.rcmgr)
273+
return newListener(laddr, t, t.noise, t.certManager, t.staticTLSConf, t.gater, t.rcmgr)
234274
}
235275

236276
func (t *transport) Protocols() []int {

p2p/transport/webtransport/transport_test.go

Lines changed: 103 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,19 @@ package libp2pwebtransport_test
22

33
import (
44
"context"
5+
"crypto/ecdsa"
6+
"crypto/elliptic"
57
"crypto/rand"
68
"crypto/sha256"
9+
"crypto/tls"
10+
"crypto/x509"
11+
"crypto/x509/pkix"
712
"errors"
813
"fmt"
914
"io"
15+
"math/big"
1016
"net"
17+
"strings"
1118
"testing"
1219
"time"
1320

@@ -64,6 +71,19 @@ func stripCertHashes(addr ma.Multiaddr) ma.Multiaddr {
6471
}
6572
}
6673

74+
// create a /certhash multiaddr component using the SHA256 of foobar
75+
func getCerthashComponent(t *testing.T, b []byte) ma.Multiaddr {
76+
t.Helper()
77+
h := sha256.Sum256(b)
78+
mh, err := multihash.Encode(h[:], multihash.SHA2_256)
79+
require.NoError(t, err)
80+
certStr, err := multibase.Encode(multibase.Base58BTC, mh)
81+
require.NoError(t, err)
82+
ha, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr)
83+
require.NoError(t, err)
84+
return ha
85+
}
86+
6787
func TestTransport(t *testing.T) {
6888
serverID, serverKey := newIdentity(t)
6989
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager)
@@ -129,29 +149,11 @@ func TestHashVerification(t *testing.T) {
129149
require.NoError(t, err)
130150
defer tr2.(io.Closer).Close()
131151

132-
// create a hash component using the SHA256 of foobar
133-
h := sha256.Sum256([]byte("foobar"))
134-
mh, err := multihash.Encode(h[:], multihash.SHA2_256)
135-
require.NoError(t, err)
136-
certStr, err := multibase.Encode(multibase.Base58BTC, mh)
137-
require.NoError(t, err)
138-
foobarHash, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr)
139-
require.NoError(t, err)
152+
foobarHash := getCerthashComponent(t, []byte("foobar"))
140153

141154
t.Run("fails using only a wrong hash", func(t *testing.T) {
142155
// replace the certificate hash in the multiaddr with a fake hash
143-
addr := ln.Multiaddr()
144-
// strip off all certhash components
145-
for {
146-
a, comp := ma.SplitLast(addr)
147-
if comp.Protocol().Code != ma.P_CERTHASH {
148-
break
149-
}
150-
addr = a
151-
}
152-
153-
addr = addr.Encapsulate(foobarHash)
154-
156+
addr := stripCertHashes(ln.Multiaddr()).Encapsulate(foobarHash)
155157
_, err := tr2.Dial(context.Background(), addr, serverID)
156158
require.Error(t, err)
157159
require.Contains(t, err.Error(), "CRYPTO_ERROR (0x12a): cert hash not found")
@@ -424,3 +426,84 @@ func TestConnectionGaterInterceptSecured(t *testing.T) {
424426
t.Fatal("timeout")
425427
}
426428
}
429+
430+
func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config {
431+
t.Helper()
432+
certTempl := &x509.Certificate{
433+
SerialNumber: big.NewInt(1234),
434+
Subject: pkix.Name{Organization: []string{"webtransport"}},
435+
NotBefore: start,
436+
NotAfter: end,
437+
IsCA: true,
438+
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
439+
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
440+
BasicConstraintsValid: true,
441+
IPAddresses: []net.IP{ip},
442+
}
443+
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
444+
require.NoError(t, err)
445+
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &priv.PublicKey, priv)
446+
require.NoError(t, err)
447+
cert, err := x509.ParseCertificate(caBytes)
448+
require.NoError(t, err)
449+
return &tls.Config{
450+
Certificates: []tls.Certificate{{
451+
Certificate: [][]byte{cert.Raw},
452+
PrivateKey: priv,
453+
Leaf: cert,
454+
}},
455+
}
456+
}
457+
458+
func TestStaticTLSConf(t *testing.T) {
459+
tlsConf := getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(365*24*time.Hour))
460+
461+
serverID, serverKey := newIdentity(t)
462+
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager, libp2pwebtransport.WithTLSConfig(tlsConf))
463+
require.NoError(t, err)
464+
defer tr.(io.Closer).Close()
465+
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
466+
require.NoError(t, err)
467+
defer ln.Close()
468+
require.Empty(t, extractCertHashes(ln.Multiaddr()), "listener address shouldn't contain any certhash")
469+
470+
t.Run("fails when the certificate is invalid", func(t *testing.T) {
471+
_, key := newIdentity(t)
472+
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
473+
require.NoError(t, err)
474+
defer cl.(io.Closer).Close()
475+
476+
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
477+
require.Error(t, err)
478+
if !strings.Contains(err.Error(), "certificate is not trusted") &&
479+
!strings.Contains(err.Error(), "certificate signed by unknown authority") {
480+
t.Fatalf("expected a certificate error, got %+v", err)
481+
}
482+
})
483+
484+
t.Run("fails when dialing with a wrong certhash", func(t *testing.T) {
485+
_, key := newIdentity(t)
486+
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
487+
require.NoError(t, err)
488+
defer cl.(io.Closer).Close()
489+
490+
addr := ln.Multiaddr().Encapsulate(getCerthashComponent(t, []byte("foo")))
491+
_, err = cl.Dial(context.Background(), addr, serverID)
492+
require.Error(t, err)
493+
require.Contains(t, err.Error(), "cert hash not found")
494+
})
495+
496+
t.Run("accepts a valid TLS certificate", func(t *testing.T) {
497+
_, key := newIdentity(t)
498+
store := x509.NewCertPool()
499+
store.AddCert(tlsConf.Certificates[0].Leaf)
500+
tlsConf := &tls.Config{RootCAs: store}
501+
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager, libp2pwebtransport.WithTLSClientConfig(tlsConf))
502+
require.NoError(t, err)
503+
defer cl.(io.Closer).Close()
504+
505+
conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID)
506+
require.NoError(t, err)
507+
defer conn.Close()
508+
})
509+
}

0 commit comments

Comments
 (0)