Skip to content

Commit f1b3d95

Browse files
Jan Waśwendigo
authored andcommitted
Enable TLS in test Trino container
1 parent aa08ec4 commit f1b3d95

File tree

3 files changed

+130
-0
lines changed

3 files changed

+130
-0
lines changed

trino/etc/config.properties

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,11 @@ node-scheduler.include-coordinator=true
44
http-server.http.port=8080
55
discovery-server.enabled=true
66
discovery.uri=http://localhost:8080
7+
8+
http-server.authentication.type=JWT
9+
http-server.authentication.jwt.key-file=/etc/trino/secrets/public_key.pem
10+
http-server.https.enabled=true
11+
http-server.https.port=8443
12+
http-server.authentication.allow-insecure-over-http=true
13+
http-server.https.keystore.path=/etc/trino/secrets/certificate_with_key.pem
14+
internal-communication.shared-secret=gotrino

trino/etc/secrets/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.pem

trino/integration_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,20 @@ package trino
1616

1717
import (
1818
"context"
19+
"crypto/rand"
20+
"crypto/rsa"
21+
"crypto/tls"
22+
"crypto/x509"
23+
"crypto/x509/pkix"
1924
"database/sql"
2025
"database/sql/driver"
26+
"encoding/pem"
2127
"errors"
2228
"flag"
29+
"fmt"
2330
"io"
2431
"log"
32+
"math/big"
2533
"net/http"
2634
"os"
2735
"strings"
@@ -55,6 +63,7 @@ var (
5563
false,
5664
"do not delete containers on exit",
5765
)
66+
tlsServer = ""
5867
)
5968

6069
func TestMain(m *testing.M) {
@@ -79,6 +88,10 @@ func TestMain(m *testing.M) {
7988
resource, ok = pool.ContainerByName(name)
8089

8190
if !ok {
91+
err = generateCerts(wd + "/etc/secrets")
92+
if err != nil {
93+
log.Fatalf("Could not generate TLS certificates: %s", err)
94+
}
8295
if *trinoImageTagFlag == "" {
8396
*trinoImageTagFlag = "latest"
8497
}
@@ -87,6 +100,10 @@ func TestMain(m *testing.M) {
87100
Repository: "trinodb/trino",
88101
Tag: *trinoImageTagFlag,
89102
Mounts: []string{wd + "/etc:/etc/trino"},
103+
ExposedPorts: []string{
104+
"8080/tcp",
105+
"8443/tcp",
106+
},
90107
})
91108
if err != nil {
92109
log.Fatalf("Could not start resource: %s", err)
@@ -106,6 +123,12 @@ func TestMain(m *testing.M) {
106123
log.Fatalf("Timed out waiting for container to get ready: %s", err)
107124
}
108125
*integrationServerFlag = "http://test@localhost:" + resource.GetPort("8080/tcp")
126+
tlsServer = "https://test@localhost:" + resource.GetPort("8443/tcp")
127+
128+
http.DefaultTransport.(*http.Transport).TLSClientConfig, err = getTLSConfig(wd + "/etc/secrets")
129+
if err != nil {
130+
log.Fatalf("Failed to set the default TLS config: %s", err)
131+
}
109132
}
110133

111134
code := m.Run()
@@ -120,6 +143,104 @@ func TestMain(m *testing.M) {
120143
os.Exit(code)
121144
}
122145

146+
func generateCerts(dir string) error {
147+
priv, err := rsa.GenerateKey(rand.Reader, 2048)
148+
if err != nil {
149+
return fmt.Errorf("failed to generate private key: %w", err)
150+
}
151+
152+
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
153+
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
154+
if err != nil {
155+
return fmt.Errorf("failed to generate serial number: %w", err)
156+
}
157+
158+
template := x509.Certificate{
159+
SerialNumber: serialNumber,
160+
Subject: pkix.Name{
161+
Organization: []string{"Trino Software Foundation"},
162+
},
163+
DNSNames: []string{"localhost"},
164+
NotBefore: time.Now(),
165+
NotAfter: time.Now().Add(1 * time.Hour),
166+
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
167+
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
168+
BasicConstraintsValid: true,
169+
}
170+
171+
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
172+
if err != nil {
173+
return fmt.Errorf("unable to marshal private key: %w", err)
174+
}
175+
privBlock := &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}
176+
err = writePEM(dir+"/private_key.pem", privBlock)
177+
if err != nil {
178+
return err
179+
}
180+
181+
pubBytes, err := x509.MarshalPKIXPublicKey(&priv.PublicKey)
182+
if err != nil {
183+
return fmt.Errorf("unable to marshal public key: %w", err)
184+
}
185+
pubBlock := &pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes}
186+
err = writePEM(dir+"/public_key.pem", pubBlock)
187+
if err != nil {
188+
return err
189+
}
190+
191+
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
192+
if err != nil {
193+
return fmt.Errorf("failed to create certificate: %w", err)
194+
}
195+
certBlock := &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}
196+
err = writePEM(dir+"/certificate.pem", certBlock)
197+
if err != nil {
198+
return err
199+
}
200+
201+
err = writePEM(dir+"/certificate_with_key.pem", certBlock, privBlock, pubBlock)
202+
if err != nil {
203+
return err
204+
}
205+
206+
return nil
207+
}
208+
209+
func writePEM(filename string, blocks ...*pem.Block) error {
210+
// all files are world-readable, so they can be read inside the Trino container
211+
out, err := os.Create(filename)
212+
if err != nil {
213+
return fmt.Errorf("failed to open %s for writing: %w", filename, err)
214+
}
215+
for _, block := range blocks {
216+
if err := pem.Encode(out, block); err != nil {
217+
return fmt.Errorf("failed to write %s data to %s: %w", block.Type, filename, err)
218+
}
219+
}
220+
if err := out.Close(); err != nil {
221+
return fmt.Errorf("error closing %s: %w", filename, err)
222+
}
223+
return nil
224+
}
225+
226+
func getTLSConfig(dir string) (*tls.Config, error) {
227+
certPool, err := x509.SystemCertPool()
228+
if err != nil {
229+
return nil, fmt.Errorf("failed to read the system cert pool: %s", err)
230+
}
231+
caCertPEM, err := os.ReadFile(dir + "/certificate.pem")
232+
if err != nil {
233+
return nil, fmt.Errorf("failed to read the certificate: %s", err)
234+
}
235+
ok := certPool.AppendCertsFromPEM(caCertPEM)
236+
if !ok {
237+
return nil, fmt.Errorf("failed to parse the certificate: %s", err)
238+
}
239+
return &tls.Config{
240+
RootCAs: certPool,
241+
}, nil
242+
}
243+
123244
// integrationOpen opens a connection to the integration test server.
124245
func integrationOpen(t *testing.T, dsn ...string) *sql.DB {
125246
if testing.Short() {

0 commit comments

Comments
 (0)