Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions wren-launcher/commands/dbt/converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,26 @@ func ConvertDbtProjectCore(opts ConvertOptions) (*ConvertResult, error) {
"password": typedDS.Password,
},
}
case *WrenMSSQLDataSource:
var host string
if opts.UsedByContainer {
host = handleLocalhostForContainer(typedDS.Host)
} else {
host = typedDS.Host
}
wrenDataSource = map[string]interface{}{
"type": "mssql",
"properties": map[string]interface{}{
"host": host,
"port": typedDS.Port,
"database": typedDS.Database,
"user": typedDS.User,
"password": typedDS.Password,
"tds_version": typedDS.TdsVersion,
"driver": typedDS.Driver,
"kwargs": typedDS.Kwargs,
},
}
case *WrenLocalFileDataSource:
var url string
if opts.UsedByContainer {
Expand Down
124 changes: 117 additions & 7 deletions wren-launcher/commands/dbt/data_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,22 @@ import (

// Constants for data types
const (
integerType = "integer"
varcharType = "varchar"
dateType = "date"
timestampType = "timestamp"
doubleType = "double"
booleanType = "boolean"
postgresType = "postgres"
integerType = "integer"
smallintType = "smallint"
bigintType = "bigint"
floatType = "float"
decimalType = "decimal"
varcharType = "varchar"
charType = "char"
textType = "text"
dateType = "date"
timestampType = "timestamp"
timestamptzType = "timestamptz"
doubleType = "double"
booleanType = "boolean"
jsonType = "json"
intervalType = "interval"
postgresType = "postgres"
)

// Constants for SQL data types
Expand Down Expand Up @@ -79,6 +88,8 @@ func convertConnectionToDataSource(conn DbtConnection, dbtHomePath, profileName,
return convertToPostgresDataSource(conn)
case "duckdb":
return convertToLocalFileDataSource(conn, dbtHomePath)
case "sqlserver":
return convertToMSSQLDataSource(conn)
case "mysql":
return convertToMysqlDataSource(conn)
default:
Expand Down Expand Up @@ -114,6 +125,26 @@ func convertToPostgresDataSource(conn DbtConnection) (*WrenPostgresDataSource, e
return ds, nil
}

func convertToMSSQLDataSource(conn DbtConnection) (*WrenMSSQLDataSource, error) {
port := strconv.Itoa(conn.Port)
if conn.Port == 0 {
port = "1433"
}

ds := &WrenMSSQLDataSource{
Database: conn.Database,
Host: conn.Server,
Port: port,
User: conn.User,
Password: conn.Password,
TdsVersion: "8.0", // the default tds version for Wren engine image
Driver: "ODBC Driver 18 for SQL Server", // the driver used by Wren engine image
Kwargs: map[string]interface{}{"TrustServerCertificate": "YES"},
}

return ds, nil
}

// convertToLocalFileDataSource converts to local file data source
func convertToLocalFileDataSource(conn DbtConnection, dbtHome string) (*WrenLocalFileDataSource, error) {
// For file types, we need to get URL and format info from Additional fields
Expand Down Expand Up @@ -264,6 +295,85 @@ func (ds *WrenPostgresDataSource) MapType(sourceType string) string {
return sourceType
}

type WrenMSSQLDataSource struct {
Database string `json:"database"`
Host string `json:"host"`
Port string `json:"port"`
User string `json:"user"`
Password string `json:"password"`
TdsVersion string `json:"tds_version"`
Driver string `json:"driver"`
Kwargs map[string]interface{} `json:"kwargs"`
}

func (ds *WrenMSSQLDataSource) GetType() string {
return "mssql"
}

func (ds *WrenMSSQLDataSource) Validate() error {
if ds.Host == "" {
return fmt.Errorf("host cannot be empty")
}
if ds.Database == "" {
return fmt.Errorf("database cannot be empty")
}
if ds.User == "" {
return fmt.Errorf("user cannot be empty")
}
if ds.Port == "" {
return fmt.Errorf("port must be specified")
}
port, err := strconv.Atoi(ds.Port)
if err != nil {
return fmt.Errorf("port must be a valid number")
}
if port <= 0 || port > 65535 {
return fmt.Errorf("port must be between 1 and 65535")
}
if ds.Password == "" {
return fmt.Errorf("password cannot be empty")
}
return nil
}

func (ds *WrenMSSQLDataSource) MapType(sourceType string) string {
// This method is not used in WrenMSSQLDataSource, but required by DataSource interface
switch strings.ToLower(sourceType) {
case charType, "nchar":
return charType
case varcharType, "nvarchar":
return varcharType
case textType, "ntext":
return textType
case "bit", "tinyint":
return booleanType
case "smallint":
return smallintType
case "int":
return integerType
case "bigint":
return bigintType
case booleanType:
return booleanType
case "float", "real":
return floatType
case "decimal", "numeric", "money", "smallmoney":
return decimalType
case "date":
return dateType
case "datetime", "datetime2", "smalldatetime":
return timestampType
case "time":
return intervalType
case "datetimeoffset":
return timestamptzType
case "json":
return jsonType
default:
return strings.ToLower(sourceType)
}
}

type WrenMysqlDataSource struct {
Database string `json:"database"`
Host string `json:"host"`
Expand Down
67 changes: 67 additions & 0 deletions wren-launcher/commands/dbt/data_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,73 @@ func TestFromDbtProfiles_UnsupportedType(t *testing.T) {
}
}

func TestFromMssqlProfiles(t *testing.T) {
// Test MSSQL connection conversion
profiles := &DbtProfiles{
Profiles: map[string]DbtProfile{
"test_profile": {
Target: "dev",
Outputs: map[string]DbtConnection{
"dev": {
Type: "sqlserver",
Server: testHost,
Port: 1433,
Database: "test_db",
User: testUser,
Password: testPassword,
},
},
},
},
}

dataSources, err := FromDbtProfiles(profiles)
if err != nil {
t.Fatalf("FromDbtProfiles failed: %v", err)
}

if len(dataSources) != 1 {
t.Fatalf("Expected 1 data source, got %d", len(dataSources))
}

ds, ok := dataSources[0].(*WrenMSSQLDataSource)
if !ok {
t.Fatalf("Expected WrenMSSQLDataSource, got %T", dataSources[0])
}

if ds.Host != testHost {
t.Errorf("Expected host '%s', got '%s'", testHost, ds.Host)
}

if ds.Port != "1433" {
t.Errorf("Expected port 1433, got %s", ds.Port)
}

if ds.Database != "test_db" {
t.Errorf("Expected database 'test_db', got '%s'", ds.Database)
}

if ds.User != testUser {
t.Errorf("Expected user '%s', got '%s'", testUser, ds.User)
}

if ds.Password != testPassword {
t.Errorf("Expected password '%s', got '%s'", testPassword, ds.Password)
}

if ds.TdsVersion != "8.0" {
t.Errorf("Expected TDS version '8.0', got '%s'", ds.TdsVersion)
}

if ds.Driver != "ODBC Driver 18 for SQL Server" {
t.Errorf("Expected driver 'ODBC Driver 18 for SQL Server', got '%s'", ds.Driver)
}

if ds.Kwargs["TrustServerCertificate"] != "YES" {
t.Errorf("Expected TrustServerCertificate 'YES', got '%s'", ds.Kwargs["TrustServerCertificate"])
}
}

func TestFromDbtProfiles_NilProfiles(t *testing.T) {
// Test nil profiles
_, err := FromDbtProfiles(nil)
Expand Down
1 change: 1 addition & 0 deletions wren-launcher/commands/dbt/profiles.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type DbtProfile struct {
type DbtConnection struct {
Type string `yaml:"type" json:"type"`
Host string `yaml:"host,omitempty" json:"host,omitempty"`
Server string `yaml:"server,omitempty" json:"server,omitempty"` // MSSQL
Port int `yaml:"port,omitempty" json:"port,omitempty"`
User string `yaml:"user,omitempty" json:"user,omitempty"`
Password string `yaml:"password,omitempty" json:"password,omitempty"`
Expand Down
10 changes: 9 additions & 1 deletion wren-launcher/commands/dbt/profiles_analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"

"gopkg.in/yaml.v3"
)
Expand Down Expand Up @@ -105,6 +106,12 @@ func parseConnection(connectionMap map[string]interface{}) (*DbtConnection, erro
return v
case float64:
return int(v)
case int64:
return int(v)
case string:
if i, err := strconv.Atoi(v); err == nil {
return i
}
}
}
return 0
Expand Down Expand Up @@ -142,13 +149,14 @@ func parseConnection(connectionMap map[string]interface{}) (*DbtConnection, erro
connection.SSLMode = getString("sslmode")
connection.Path = getString("path")
connection.SslDisable = getBool("ssl_disable") // MySQL specific
connection.Server = getString("server")

// Store any additional fields that weren't mapped
knownFields := map[string]bool{
"type": true, "host": true, "port": true, "user": true, "password": true,
"database": true, "dbname": true, "schema": true, "project": true, "dataset": true,
"keyfile": true, "account": true, "warehouse": true, "role": true,
"keepalive": true, "search_path": true, "sslmode": true, "path": true, "ssl_disable": true,
"keepalive": true, "search_path": true, "sslmode": true, "path": true, "server": true, "ssl_disable": true,
}

for key, value := range connectionMap {
Expand Down