Skip to content

Commit 02bb00b

Browse files
authored
feat(wren-launcher) support MSSQL data source for dbt-tools (#1887)
1 parent a940a52 commit 02bb00b

File tree

5 files changed

+214
-8
lines changed

5 files changed

+214
-8
lines changed

wren-launcher/commands/dbt/converter.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,26 @@ func ConvertDbtProjectCore(opts ConvertOptions) (*ConvertResult, error) {
139139
"password": typedDS.Password,
140140
},
141141
}
142+
case *WrenMSSQLDataSource:
143+
var host string
144+
if opts.UsedByContainer {
145+
host = handleLocalhostForContainer(typedDS.Host)
146+
} else {
147+
host = typedDS.Host
148+
}
149+
wrenDataSource = map[string]interface{}{
150+
"type": "mssql",
151+
"properties": map[string]interface{}{
152+
"host": host,
153+
"port": typedDS.Port,
154+
"database": typedDS.Database,
155+
"user": typedDS.User,
156+
"password": typedDS.Password,
157+
"tds_version": typedDS.TdsVersion,
158+
"driver": typedDS.Driver,
159+
"kwargs": typedDS.Kwargs,
160+
},
161+
}
142162
case *WrenLocalFileDataSource:
143163
var url string
144164
if opts.UsedByContainer {

wren-launcher/commands/dbt/data_source.go

Lines changed: 117 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,22 @@ import (
1111

1212
// Constants for data types
1313
const (
14-
integerType = "integer"
15-
varcharType = "varchar"
16-
dateType = "date"
17-
timestampType = "timestamp"
18-
doubleType = "double"
19-
booleanType = "boolean"
20-
postgresType = "postgres"
14+
integerType = "integer"
15+
smallintType = "smallint"
16+
bigintType = "bigint"
17+
floatType = "float"
18+
decimalType = "decimal"
19+
varcharType = "varchar"
20+
charType = "char"
21+
textType = "text"
22+
dateType = "date"
23+
timestampType = "timestamp"
24+
timestamptzType = "timestamptz"
25+
doubleType = "double"
26+
booleanType = "boolean"
27+
jsonType = "json"
28+
intervalType = "interval"
29+
postgresType = "postgres"
2130
)
2231

2332
// Constants for SQL data types
@@ -79,6 +88,8 @@ func convertConnectionToDataSource(conn DbtConnection, dbtHomePath, profileName,
7988
return convertToPostgresDataSource(conn)
8089
case "duckdb":
8190
return convertToLocalFileDataSource(conn, dbtHomePath)
91+
case "sqlserver":
92+
return convertToMSSQLDataSource(conn)
8293
case "mysql":
8394
return convertToMysqlDataSource(conn)
8495
default:
@@ -114,6 +125,26 @@ func convertToPostgresDataSource(conn DbtConnection) (*WrenPostgresDataSource, e
114125
return ds, nil
115126
}
116127

128+
func convertToMSSQLDataSource(conn DbtConnection) (*WrenMSSQLDataSource, error) {
129+
port := strconv.Itoa(conn.Port)
130+
if conn.Port == 0 {
131+
port = "1433"
132+
}
133+
134+
ds := &WrenMSSQLDataSource{
135+
Database: conn.Database,
136+
Host: conn.Server,
137+
Port: port,
138+
User: conn.User,
139+
Password: conn.Password,
140+
TdsVersion: "8.0", // the default tds version for Wren engine image
141+
Driver: "ODBC Driver 18 for SQL Server", // the driver used by Wren engine image
142+
Kwargs: map[string]interface{}{"TrustServerCertificate": "YES"},
143+
}
144+
145+
return ds, nil
146+
}
147+
117148
// convertToLocalFileDataSource converts to local file data source
118149
func convertToLocalFileDataSource(conn DbtConnection, dbtHome string) (*WrenLocalFileDataSource, error) {
119150
// For file types, we need to get URL and format info from Additional fields
@@ -264,6 +295,85 @@ func (ds *WrenPostgresDataSource) MapType(sourceType string) string {
264295
return sourceType
265296
}
266297

298+
type WrenMSSQLDataSource struct {
299+
Database string `json:"database"`
300+
Host string `json:"host"`
301+
Port string `json:"port"`
302+
User string `json:"user"`
303+
Password string `json:"password"`
304+
TdsVersion string `json:"tds_version"`
305+
Driver string `json:"driver"`
306+
Kwargs map[string]interface{} `json:"kwargs"`
307+
}
308+
309+
func (ds *WrenMSSQLDataSource) GetType() string {
310+
return "mssql"
311+
}
312+
313+
func (ds *WrenMSSQLDataSource) Validate() error {
314+
if ds.Host == "" {
315+
return fmt.Errorf("host cannot be empty")
316+
}
317+
if ds.Database == "" {
318+
return fmt.Errorf("database cannot be empty")
319+
}
320+
if ds.User == "" {
321+
return fmt.Errorf("user cannot be empty")
322+
}
323+
if ds.Port == "" {
324+
return fmt.Errorf("port must be specified")
325+
}
326+
port, err := strconv.Atoi(ds.Port)
327+
if err != nil {
328+
return fmt.Errorf("port must be a valid number")
329+
}
330+
if port <= 0 || port > 65535 {
331+
return fmt.Errorf("port must be between 1 and 65535")
332+
}
333+
if ds.Password == "" {
334+
return fmt.Errorf("password cannot be empty")
335+
}
336+
return nil
337+
}
338+
339+
func (ds *WrenMSSQLDataSource) MapType(sourceType string) string {
340+
// This method is not used in WrenMSSQLDataSource, but required by DataSource interface
341+
switch strings.ToLower(sourceType) {
342+
case charType, "nchar":
343+
return charType
344+
case varcharType, "nvarchar":
345+
return varcharType
346+
case textType, "ntext":
347+
return textType
348+
case "bit", "tinyint":
349+
return booleanType
350+
case "smallint":
351+
return smallintType
352+
case "int":
353+
return integerType
354+
case "bigint":
355+
return bigintType
356+
case booleanType:
357+
return booleanType
358+
case "float", "real":
359+
return floatType
360+
case "decimal", "numeric", "money", "smallmoney":
361+
return decimalType
362+
case "date":
363+
return dateType
364+
case "datetime", "datetime2", "smalldatetime":
365+
return timestampType
366+
case "time":
367+
return intervalType
368+
case "datetimeoffset":
369+
return timestamptzType
370+
case "json":
371+
return jsonType
372+
default:
373+
return strings.ToLower(sourceType)
374+
}
375+
}
376+
267377
type WrenMysqlDataSource struct {
268378
Database string `json:"database"`
269379
Host string `json:"host"`

wren-launcher/commands/dbt/data_source_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,73 @@ func TestFromDbtProfiles_UnsupportedType(t *testing.T) {
217217
}
218218
}
219219

220+
func TestFromMssqlProfiles(t *testing.T) {
221+
// Test MSSQL connection conversion
222+
profiles := &DbtProfiles{
223+
Profiles: map[string]DbtProfile{
224+
"test_profile": {
225+
Target: "dev",
226+
Outputs: map[string]DbtConnection{
227+
"dev": {
228+
Type: "sqlserver",
229+
Server: testHost,
230+
Port: 1433,
231+
Database: "test_db",
232+
User: testUser,
233+
Password: testPassword,
234+
},
235+
},
236+
},
237+
},
238+
}
239+
240+
dataSources, err := FromDbtProfiles(profiles)
241+
if err != nil {
242+
t.Fatalf("FromDbtProfiles failed: %v", err)
243+
}
244+
245+
if len(dataSources) != 1 {
246+
t.Fatalf("Expected 1 data source, got %d", len(dataSources))
247+
}
248+
249+
ds, ok := dataSources[0].(*WrenMSSQLDataSource)
250+
if !ok {
251+
t.Fatalf("Expected WrenMSSQLDataSource, got %T", dataSources[0])
252+
}
253+
254+
if ds.Host != testHost {
255+
t.Errorf("Expected host '%s', got '%s'", testHost, ds.Host)
256+
}
257+
258+
if ds.Port != "1433" {
259+
t.Errorf("Expected port 1433, got %s", ds.Port)
260+
}
261+
262+
if ds.Database != "test_db" {
263+
t.Errorf("Expected database 'test_db', got '%s'", ds.Database)
264+
}
265+
266+
if ds.User != testUser {
267+
t.Errorf("Expected user '%s', got '%s'", testUser, ds.User)
268+
}
269+
270+
if ds.Password != testPassword {
271+
t.Errorf("Expected password '%s', got '%s'", testPassword, ds.Password)
272+
}
273+
274+
if ds.TdsVersion != "8.0" {
275+
t.Errorf("Expected TDS version '8.0', got '%s'", ds.TdsVersion)
276+
}
277+
278+
if ds.Driver != "ODBC Driver 18 for SQL Server" {
279+
t.Errorf("Expected driver 'ODBC Driver 18 for SQL Server', got '%s'", ds.Driver)
280+
}
281+
282+
if ds.Kwargs["TrustServerCertificate"] != "YES" {
283+
t.Errorf("Expected TrustServerCertificate 'YES', got '%s'", ds.Kwargs["TrustServerCertificate"])
284+
}
285+
}
286+
220287
func TestFromDbtProfiles_NilProfiles(t *testing.T) {
221288
// Test nil profiles
222289
_, err := FromDbtProfiles(nil)

wren-launcher/commands/dbt/profiles.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ type DbtProfile struct {
1616
type DbtConnection struct {
1717
Type string `yaml:"type" json:"type"`
1818
Host string `yaml:"host,omitempty" json:"host,omitempty"`
19+
Server string `yaml:"server,omitempty" json:"server,omitempty"` // MSSQL
1920
Port int `yaml:"port,omitempty" json:"port,omitempty"`
2021
User string `yaml:"user,omitempty" json:"user,omitempty"`
2122
Password string `yaml:"password,omitempty" json:"password,omitempty"`

wren-launcher/commands/dbt/profiles_analyzer.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"os"
66
"path/filepath"
77
"runtime"
8+
"strconv"
89

910
"gopkg.in/yaml.v3"
1011
)
@@ -105,6 +106,12 @@ func parseConnection(connectionMap map[string]interface{}) (*DbtConnection, erro
105106
return v
106107
case float64:
107108
return int(v)
109+
case int64:
110+
return int(v)
111+
case string:
112+
if i, err := strconv.Atoi(v); err == nil {
113+
return i
114+
}
108115
}
109116
}
110117
return 0
@@ -142,13 +149,14 @@ func parseConnection(connectionMap map[string]interface{}) (*DbtConnection, erro
142149
connection.SSLMode = getString("sslmode")
143150
connection.Path = getString("path")
144151
connection.SslDisable = getBool("ssl_disable") // MySQL specific
152+
connection.Server = getString("server")
145153

146154
// Store any additional fields that weren't mapped
147155
knownFields := map[string]bool{
148156
"type": true, "host": true, "port": true, "user": true, "password": true,
149157
"database": true, "dbname": true, "schema": true, "project": true, "dataset": true,
150158
"keyfile": true, "account": true, "warehouse": true, "role": true,
151-
"keepalive": true, "search_path": true, "sslmode": true, "path": true, "ssl_disable": true,
159+
"keepalive": true, "search_path": true, "sslmode": true, "path": true, "server": true, "ssl_disable": true,
152160
}
153161

154162
for key, value := range connectionMap {

0 commit comments

Comments
 (0)