|
9 | 9 | "strconv"
|
10 | 10 | "strings"
|
11 | 11 | "time"
|
| 12 | + "crypto/rsa" |
| 13 | + "crypto/x509" |
| 14 | + "encoding/pem" |
12 | 15 |
|
13 | 16 | _ "github.com/ClickHouse/clickhouse-go/v2" // register the ClickHouse driver
|
14 | 17 | "github.com/cenkalti/backoff"
|
@@ -374,36 +377,86 @@ func (j *Job) updateConnections() {
|
374 | 377 | }
|
375 | 378 | }
|
376 | 379 | if newConn.driver == "snowflake" {
|
| 380 | + u, err := url.Parse(conn) |
| 381 | + if err != nil { |
| 382 | + level.Error(j.log).Log("msg", "Failed to parse Snowflake URL", "url", conn, "err", err) |
| 383 | + continue |
| 384 | + } |
| 385 | + |
| 386 | + queryParams := u.Query() |
| 387 | + privateKeyPath := os.ExpandEnv(queryParams.Get("private_key_file")) |
| 388 | + |
377 | 389 | cfg := &gosnowflake.Config{
|
378 | 390 | Account: u.Host,
|
379 | 391 | User: u.User.Username(),
|
| 392 | + Role: queryParams.Get("role"), |
| 393 | + Database: queryParams.Get("database"), |
| 394 | + Schema: queryParams.Get("schema"), |
380 | 395 | }
|
381 |
| - |
382 |
| - pw, set := u.User.Password() |
383 |
| - if set { |
384 |
| - cfg.Password = pw |
385 |
| - } |
386 |
| - |
387 |
| - if u.Port() != "" { |
388 |
| - portStr, err := strconv.Atoi(u.Port()) |
| 396 | + |
| 397 | + if privateKeyPath != "" { |
| 398 | + // RSA key auth |
| 399 | + keyBytes, err := os.ReadFile(privateKeyPath) |
389 | 400 | if err != nil {
|
390 |
| - level.Error(j.log).Log("msg", "Failed to parse Snowflake port", "connection", conn, "err", err) |
| 401 | + level.Error(j.log).Log("msg", "Failed to read private key file", "path", privateKeyPath, "err", err) |
| 402 | + continue |
| 403 | + } |
| 404 | + |
| 405 | + keyBlock, _ := pem.Decode(keyBytes) |
| 406 | + if keyBlock == nil { |
| 407 | + level.Error(j.log).Log("msg", "Failed to decode PEM block", "path", privateKeyPath) |
| 408 | + continue |
| 409 | + } |
| 410 | + |
| 411 | + var privateKey *rsa.PrivateKey |
| 412 | + if parsedKey, err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes); err == nil { |
| 413 | + privateKey, _ = parsedKey.(*rsa.PrivateKey) |
| 414 | + } else if parsedKey, err := x509.ParsePKCS1PrivateKey(keyBlock.Bytes); err == nil { |
| 415 | + privateKey = parsedKey |
| 416 | + } else { |
| 417 | + level.Error(j.log).Log("msg", "Failed to parse private key", "err", err) |
| 418 | + continue |
| 419 | + } |
| 420 | + |
| 421 | + cfg.Authenticator = gosnowflake.AuthTypeJwt |
| 422 | + cfg.PrivateKey = privateKey |
| 423 | + |
| 424 | + dsn, err := gosnowflake.DSN(cfg) |
| 425 | + if err != nil { |
| 426 | + level.Error(j.log).Log("msg", "Failed to create Snowflake DSN with RSA", "err", err) |
| 427 | + continue |
| 428 | + } |
| 429 | + |
| 430 | + newConn.snowflakeConfig = cfg |
| 431 | + newConn.snowflakeDSN = dsn |
| 432 | + newConn.host = u.Host |
| 433 | + newConn.tokenExpirationTime = time.Now().Add(time.Hour) |
| 434 | + } else { |
| 435 | + // Password auth |
| 436 | + if pw, set := u.User.Password(); set { |
| 437 | + cfg.Password = pw |
| 438 | + } |
| 439 | + if u.Port() != "" { |
| 440 | + if port, err := strconv.Atoi(u.Port()); err == nil { |
| 441 | + cfg.Port = port |
| 442 | + } |
| 443 | + } |
| 444 | + |
| 445 | + dsn, err := gosnowflake.DSN(cfg) |
| 446 | + if err != nil { |
| 447 | + level.Error(j.log).Log("msg", "Failed to create Snowflake DSN with password", "err", err) |
| 448 | + continue |
| 449 | + } |
| 450 | + |
| 451 | + newConn.conn, err = sqlx.Open("snowflake", dsn) |
| 452 | + if err != nil { |
| 453 | + level.Error(j.log).Log("msg", "Failed to open Snowflake connection", "err", err) |
391 | 454 | continue
|
392 | 455 | }
|
393 |
| - cfg.Port = portStr |
394 |
| - } |
395 |
| - |
396 |
| - dsn, err := gosnowflake.DSN(cfg) |
397 |
| - if err != nil { |
398 |
| - level.Error(j.log).Log("msg", "Failed to create Snowflake DSN", "connection", conn, "err", err) |
399 |
| - continue |
400 |
| - } |
401 |
| - |
402 |
| - newConn.conn, err = sqlx.Open("snowflake", dsn) |
403 |
| - if err != nil { |
404 |
| - level.Error(j.log).Log("msg", "Failed to open Snowflake connection", "connection", conn, "err", err) |
405 |
| - continue |
406 | 456 | }
|
| 457 | + |
| 458 | + j.conns = append(j.conns, newConn) |
| 459 | + continue |
407 | 460 | }
|
408 | 461 |
|
409 | 462 | j.conns = append(j.conns, newConn)
|
@@ -570,6 +623,34 @@ func (c *connection) connect(job *Job) error {
|
570 | 623 | }
|
571 | 624 | return nil
|
572 | 625 | }
|
| 626 | + if c.driver == "snowflake" { |
| 627 | + if c.snowflakeDSN != "" { |
| 628 | + if time.Now().After(c.tokenExpirationTime) { |
| 629 | + if c.conn != nil { |
| 630 | + c.conn.Close() |
| 631 | + c.conn = nil |
| 632 | + } |
| 633 | + c.tokenExpirationTime = time.Now().Add(time.Hour) |
| 634 | + } |
| 635 | + |
| 636 | + db, err := sqlx.Open("snowflake", c.snowflakeDSN) |
| 637 | + if err != nil { |
| 638 | + return fmt.Errorf("failed to open Snowflake connection: %w (host: %s)", err, c.host) |
| 639 | + } |
| 640 | + |
| 641 | + db.SetMaxOpenConns(1) |
| 642 | + db.SetMaxIdleConns(0) |
| 643 | + db.SetConnMaxLifetime(30 * time.Minute) |
| 644 | + |
| 645 | + if err := db.Ping(); err != nil { |
| 646 | + db.Close() |
| 647 | + return fmt.Errorf("failed to ping Snowflake: %w (host: %s)", err, c.host) |
| 648 | + } |
| 649 | + |
| 650 | + c.conn = db |
| 651 | + return nil |
| 652 | + } |
| 653 | + } |
573 | 654 | dsn := c.url
|
574 | 655 | switch c.driver {
|
575 | 656 | case "mysql":
|
|
0 commit comments