Skip to content

Commit 19a20bb

Browse files
authored
Merge pull request #159 from brunobrn/feature/snowflake-rsa-support
Add optional RSA key authentication and token regeneration support for Snowflake
2 parents be2aea6 + f64bf84 commit 19a20bb

File tree

2 files changed

+106
-22
lines changed

2 files changed

+106
-22
lines changed

config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/prometheus/client_golang/prometheus"
1616
"github.com/prometheus/client_golang/prometheus/promauto"
1717
"github.com/robfig/cron/v3"
18+
"github.com/snowflakedb/gosnowflake"
1819
"gopkg.in/yaml.v2"
1920
)
2021

@@ -164,6 +165,8 @@ type connection struct {
164165
user string
165166
tokenExpirationTime time.Time
166167
iteratorValues []string
168+
snowflakeConfig *gosnowflake.Config
169+
snowflakeDSN string
167170
}
168171

169172
// Query is an SQL query that is executed on a connection

job.go

Lines changed: 103 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ import (
99
"strconv"
1010
"strings"
1111
"time"
12+
"crypto/rsa"
13+
"crypto/x509"
14+
"encoding/pem"
1215

1316
_ "github.com/ClickHouse/clickhouse-go/v2" // register the ClickHouse driver
1417
"github.com/cenkalti/backoff"
@@ -374,36 +377,86 @@ func (j *Job) updateConnections() {
374377
}
375378
}
376379
if newConn.driver == "snowflake" {
380+
u, err := url.Parse(conn)
381+
if err != nil {
382+
level.Error(j.log).Log("msg", "Failed to parse Snowflake URL", "url", conn, "err", err)
383+
continue
384+
}
385+
386+
queryParams := u.Query()
387+
privateKeyPath := os.ExpandEnv(queryParams.Get("private_key_file"))
388+
377389
cfg := &gosnowflake.Config{
378390
Account: u.Host,
379391
User: u.User.Username(),
392+
Role: queryParams.Get("role"),
393+
Database: queryParams.Get("database"),
394+
Schema: queryParams.Get("schema"),
380395
}
381-
382-
pw, set := u.User.Password()
383-
if set {
384-
cfg.Password = pw
385-
}
386-
387-
if u.Port() != "" {
388-
portStr, err := strconv.Atoi(u.Port())
396+
397+
if privateKeyPath != "" {
398+
// RSA key auth
399+
keyBytes, err := os.ReadFile(privateKeyPath)
389400
if err != nil {
390-
level.Error(j.log).Log("msg", "Failed to parse Snowflake port", "connection", conn, "err", err)
401+
level.Error(j.log).Log("msg", "Failed to read private key file", "path", privateKeyPath, "err", err)
402+
continue
403+
}
404+
405+
keyBlock, _ := pem.Decode(keyBytes)
406+
if keyBlock == nil {
407+
level.Error(j.log).Log("msg", "Failed to decode PEM block", "path", privateKeyPath)
408+
continue
409+
}
410+
411+
var privateKey *rsa.PrivateKey
412+
if parsedKey, err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes); err == nil {
413+
privateKey, _ = parsedKey.(*rsa.PrivateKey)
414+
} else if parsedKey, err := x509.ParsePKCS1PrivateKey(keyBlock.Bytes); err == nil {
415+
privateKey = parsedKey
416+
} else {
417+
level.Error(j.log).Log("msg", "Failed to parse private key", "err", err)
418+
continue
419+
}
420+
421+
cfg.Authenticator = gosnowflake.AuthTypeJwt
422+
cfg.PrivateKey = privateKey
423+
424+
dsn, err := gosnowflake.DSN(cfg)
425+
if err != nil {
426+
level.Error(j.log).Log("msg", "Failed to create Snowflake DSN with RSA", "err", err)
427+
continue
428+
}
429+
430+
newConn.snowflakeConfig = cfg
431+
newConn.snowflakeDSN = dsn
432+
newConn.host = u.Host
433+
newConn.tokenExpirationTime = time.Now().Add(time.Hour)
434+
} else {
435+
// Password auth
436+
if pw, set := u.User.Password(); set {
437+
cfg.Password = pw
438+
}
439+
if u.Port() != "" {
440+
if port, err := strconv.Atoi(u.Port()); err == nil {
441+
cfg.Port = port
442+
}
443+
}
444+
445+
dsn, err := gosnowflake.DSN(cfg)
446+
if err != nil {
447+
level.Error(j.log).Log("msg", "Failed to create Snowflake DSN with password", "err", err)
448+
continue
449+
}
450+
451+
newConn.conn, err = sqlx.Open("snowflake", dsn)
452+
if err != nil {
453+
level.Error(j.log).Log("msg", "Failed to open Snowflake connection", "err", err)
391454
continue
392455
}
393-
cfg.Port = portStr
394-
}
395-
396-
dsn, err := gosnowflake.DSN(cfg)
397-
if err != nil {
398-
level.Error(j.log).Log("msg", "Failed to create Snowflake DSN", "connection", conn, "err", err)
399-
continue
400-
}
401-
402-
newConn.conn, err = sqlx.Open("snowflake", dsn)
403-
if err != nil {
404-
level.Error(j.log).Log("msg", "Failed to open Snowflake connection", "connection", conn, "err", err)
405-
continue
406456
}
457+
458+
j.conns = append(j.conns, newConn)
459+
continue
407460
}
408461

409462
j.conns = append(j.conns, newConn)
@@ -570,6 +623,34 @@ func (c *connection) connect(job *Job) error {
570623
}
571624
return nil
572625
}
626+
if c.driver == "snowflake" {
627+
if c.snowflakeDSN != "" {
628+
if time.Now().After(c.tokenExpirationTime) {
629+
if c.conn != nil {
630+
c.conn.Close()
631+
c.conn = nil
632+
}
633+
c.tokenExpirationTime = time.Now().Add(time.Hour)
634+
}
635+
636+
db, err := sqlx.Open("snowflake", c.snowflakeDSN)
637+
if err != nil {
638+
return fmt.Errorf("failed to open Snowflake connection: %w (host: %s)", err, c.host)
639+
}
640+
641+
db.SetMaxOpenConns(1)
642+
db.SetMaxIdleConns(0)
643+
db.SetConnMaxLifetime(30 * time.Minute)
644+
645+
if err := db.Ping(); err != nil {
646+
db.Close()
647+
return fmt.Errorf("failed to ping Snowflake: %w (host: %s)", err, c.host)
648+
}
649+
650+
c.conn = db
651+
return nil
652+
}
653+
}
573654
dsn := c.url
574655
switch c.driver {
575656
case "mysql":

0 commit comments

Comments
 (0)