Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/robfig/cron/v3"
"github.com/snowflakedb/gosnowflake"
"gopkg.in/yaml.v2"
)

Expand Down Expand Up @@ -164,6 +165,8 @@ type connection struct {
user string
tokenExpirationTime time.Time
iteratorValues []string
snowflakeConfig *gosnowflake.Config
snowflakeDSN string
}

// Query is an SQL query that is executed on a connection
Expand Down
125 changes: 103 additions & 22 deletions job.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
"strconv"
"strings"
"time"
"crypto/rsa"
"crypto/x509"
"encoding/pem"

_ "github.com/ClickHouse/clickhouse-go/v2" // register the ClickHouse driver
"github.com/cenkalti/backoff"
Expand Down Expand Up @@ -374,36 +377,86 @@ func (j *Job) updateConnections() {
}
}
if newConn.driver == "snowflake" {
u, err := url.Parse(conn)
if err != nil {
level.Error(j.log).Log("msg", "Failed to parse Snowflake URL", "url", conn, "err", err)
continue
}

queryParams := u.Query()
privateKeyPath := os.ExpandEnv(queryParams.Get("private_key_file"))

cfg := &gosnowflake.Config{
Account: u.Host,
User: u.User.Username(),
Role: queryParams.Get("role"),
Database: queryParams.Get("database"),
Schema: queryParams.Get("schema"),
}

pw, set := u.User.Password()
if set {
cfg.Password = pw
}

if u.Port() != "" {
portStr, err := strconv.Atoi(u.Port())

if privateKeyPath != "" {
// RSA key auth
keyBytes, err := os.ReadFile(privateKeyPath)
if err != nil {
level.Error(j.log).Log("msg", "Failed to parse Snowflake port", "connection", conn, "err", err)
level.Error(j.log).Log("msg", "Failed to read private key file", "path", privateKeyPath, "err", err)
continue
}

keyBlock, _ := pem.Decode(keyBytes)
if keyBlock == nil {
level.Error(j.log).Log("msg", "Failed to decode PEM block", "path", privateKeyPath)
continue
}

var privateKey *rsa.PrivateKey
if parsedKey, err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes); err == nil {
privateKey, _ = parsedKey.(*rsa.PrivateKey)
} else if parsedKey, err := x509.ParsePKCS1PrivateKey(keyBlock.Bytes); err == nil {
privateKey = parsedKey
} else {
level.Error(j.log).Log("msg", "Failed to parse private key", "err", err)
continue
}

cfg.Authenticator = gosnowflake.AuthTypeJwt
cfg.PrivateKey = privateKey

dsn, err := gosnowflake.DSN(cfg)
if err != nil {
level.Error(j.log).Log("msg", "Failed to create Snowflake DSN with RSA", "err", err)
continue
}

newConn.snowflakeConfig = cfg
newConn.snowflakeDSN = dsn
newConn.host = u.Host
newConn.tokenExpirationTime = time.Now().Add(time.Hour)
} else {
// Password auth
if pw, set := u.User.Password(); set {
cfg.Password = pw
}
if u.Port() != "" {
if port, err := strconv.Atoi(u.Port()); err == nil {
cfg.Port = port
}
}

dsn, err := gosnowflake.DSN(cfg)
if err != nil {
level.Error(j.log).Log("msg", "Failed to create Snowflake DSN with password", "err", err)
continue
}

newConn.conn, err = sqlx.Open("snowflake", dsn)
if err != nil {
level.Error(j.log).Log("msg", "Failed to open Snowflake connection", "err", err)
continue
}
cfg.Port = portStr
}

dsn, err := gosnowflake.DSN(cfg)
if err != nil {
level.Error(j.log).Log("msg", "Failed to create Snowflake DSN", "connection", conn, "err", err)
continue
}

newConn.conn, err = sqlx.Open("snowflake", dsn)
if err != nil {
level.Error(j.log).Log("msg", "Failed to open Snowflake connection", "connection", conn, "err", err)
continue
}

j.conns = append(j.conns, newConn)
continue
}

j.conns = append(j.conns, newConn)
Expand Down Expand Up @@ -570,6 +623,34 @@ func (c *connection) connect(job *Job) error {
}
return nil
}
if c.driver == "snowflake" {
if c.snowflakeDSN != "" {
if time.Now().After(c.tokenExpirationTime) {
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
c.tokenExpirationTime = time.Now().Add(time.Hour)
}

db, err := sqlx.Open("snowflake", c.snowflakeDSN)
if err != nil {
return fmt.Errorf("failed to open Snowflake connection: %w (host: %s)", err, c.host)
}

db.SetMaxOpenConns(1)
db.SetMaxIdleConns(0)
db.SetConnMaxLifetime(30 * time.Minute)

if err := db.Ping(); err != nil {
db.Close()
return fmt.Errorf("failed to ping Snowflake: %w (host: %s)", err, c.host)
}

c.conn = db
return nil
}
}
dsn := c.url
switch c.driver {
case "mysql":
Expand Down