Skip to content

Commit ccafac5

Browse files
committed
support dbt-mssql converting
1 parent a34ad63 commit ccafac5

File tree

5 files changed

+207
-7
lines changed

5 files changed

+207
-7
lines changed

wren-launcher/commands/dbt/converter.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,26 @@ func ConvertDbtProjectCore(opts ConvertOptions) (*ConvertResult, error) {
139139
"password": typedDS.Password,
140140
},
141141
}
142+
case *WrenMSSQLDataSource:
143+
var host string
144+
if opts.UsedByContainer {
145+
host = handleLocalhostForContainer(typedDS.Host)
146+
} else {
147+
host = typedDS.Host
148+
}
149+
wrenDataSource = map[string]interface{}{
150+
"type": "mssql",
151+
"properties": map[string]interface{}{
152+
"host": host,
153+
"port": typedDS.Port,
154+
"database": typedDS.Database,
155+
"user": typedDS.User,
156+
"password": typedDS.Password,
157+
"tds_version": typedDS.TdsVersion,
158+
"driver": typedDS.Driver,
159+
"kwargs": typedDS.Kwargs,
160+
},
161+
}
142162
case *WrenLocalFileDataSource:
143163
var url string
144164
if opts.UsedByContainer {

wren-launcher/commands/dbt/data_source.go

Lines changed: 117 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,29 @@ package dbt
33
import (
44
"fmt"
55
"path/filepath"
6+
"strconv"
67
"strings"
78

89
"github.com/pterm/pterm"
910
)
1011

1112
// Constants for data types
1213
const (
13-
integerType = "integer"
14-
varcharType = "varchar"
15-
dateType = "date"
16-
timestampType = "timestamp"
17-
doubleType = "double"
18-
booleanType = "boolean"
14+
integerType = "integer"
15+
smallintType = "smallint"
16+
bigintType = "bigint"
17+
floatType = "float"
18+
decimalType = "decimal"
19+
varcharType = "varchar"
20+
charType = "char"
21+
textType = "text"
22+
dateType = "date"
23+
timestampType = "timestamp"
24+
timestamptzType = "timestamptz"
25+
doubleType = "double"
26+
booleanType = "boolean"
27+
jsonType = "json"
28+
intervalType = "interval"
1929
)
2030

2131
// Constants for SQL data types
@@ -77,6 +87,8 @@ func convertConnectionToDataSource(conn DbtConnection, dbtHomePath, profileName,
7787
return convertToPostgresDataSource(conn)
7888
case "duckdb":
7989
return convertToLocalFileDataSource(conn, dbtHomePath)
90+
case "sqlserver":
91+
return convertToMSSQLDataSource(conn)
8092
default:
8193
// For unsupported database types, we can choose to ignore or return error
8294
// Here we choose to return nil and log a warning
@@ -103,6 +115,26 @@ func convertToPostgresDataSource(conn DbtConnection) (*WrenPostgresDataSource, e
103115
return ds, nil
104116
}
105117

118+
func convertToMSSQLDataSource(conn DbtConnection) (*WrenMSSQLDataSource, error) {
119+
port := strconv.Itoa(conn.Port)
120+
if conn.Port == 0 {
121+
port = "1433"
122+
}
123+
124+
ds := &WrenMSSQLDataSource{
125+
Database: conn.Database,
126+
Host: conn.Server,
127+
Port: port,
128+
User: conn.User,
129+
Password: conn.Password,
130+
TdsVersion: "8.0", // the default tds version for Wren engine image
131+
Driver: "ODBC Driver 18 for SQL Server", // the driver used by Wren engine image
132+
Kwargs: map[string]interface{}{"TrustServerCertificate": "YES"},
133+
}
134+
135+
return ds, nil
136+
}
137+
106138
// convertToLocalFileDataSource converts to local file data source
107139
func convertToLocalFileDataSource(conn DbtConnection, dbtHome string) (*WrenLocalFileDataSource, error) {
108140
// For file types, we need to get URL and format info from Additional fields
@@ -222,6 +254,85 @@ func (ds *WrenPostgresDataSource) MapType(sourceType string) string {
222254
return sourceType
223255
}
224256

257+
type WrenMSSQLDataSource struct {
258+
Database string `json:"database"`
259+
Host string `json:"host"`
260+
Port string `json:"port"`
261+
User string `json:"user"`
262+
Password string `json:"password"`
263+
TdsVersion string `json:"tds_version"`
264+
Driver string `json:"driver"`
265+
Kwargs map[string]interface{} `json:"kwargs"`
266+
}
267+
268+
func (ds *WrenMSSQLDataSource) GetType() string {
269+
return "mssql"
270+
}
271+
272+
func (ds *WrenMSSQLDataSource) Validate() error {
273+
if ds.Host == "" {
274+
return fmt.Errorf("host cannot be empty")
275+
}
276+
if ds.Database == "" {
277+
return fmt.Errorf("database cannot be empty")
278+
}
279+
if ds.User == "" {
280+
return fmt.Errorf("user cannot be empty")
281+
}
282+
if ds.Port == "" {
283+
return fmt.Errorf("port must be specified")
284+
}
285+
port, err := strconv.Atoi(ds.Port)
286+
if err != nil {
287+
return fmt.Errorf("port must be a valid number")
288+
}
289+
if port <= 0 || port > 65535 {
290+
return fmt.Errorf("port must be between 1 and 65535")
291+
}
292+
if ds.Password == "" {
293+
return fmt.Errorf("password cannot be empty")
294+
}
295+
return nil
296+
}
297+
298+
func (ds *WrenMSSQLDataSource) MapType(sourceType string) string {
299+
// This method is not used in WrenMSSQLDataSource, but required by DataSource interface
300+
switch strings.ToLower(sourceType) {
301+
case "char", "nchar":
302+
return charType
303+
case varcharType, "nvarchar":
304+
return varcharType
305+
case "text", "ntext":
306+
return textType
307+
case "bit", "tinyint":
308+
return booleanType
309+
case "smallint":
310+
return smallintType
311+
case "int":
312+
return integerType
313+
case "bigint":
314+
return bigintType
315+
case booleanType:
316+
return booleanType
317+
case "float", "real":
318+
return floatType
319+
case "decimal", "numeric", "money", "smallmoney":
320+
return decimalType
321+
case "date":
322+
return dateType
323+
case "datetime", "datetime2", "smalldatetime":
324+
return timestampType
325+
case "time":
326+
return intervalType
327+
case "datetimeoffset":
328+
return timestamptzType
329+
case "json":
330+
return jsonType
331+
default:
332+
return strings.ToLower(sourceType)
333+
}
334+
}
335+
225336
// GetActiveDataSources gets active data sources based on specified profile and target
226337
// If profileName is empty, it will use the first found profile
227338
// If targetName is empty, it will use the profile's default target

wren-launcher/commands/dbt/data_source_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,73 @@ func TestFromDbtProfiles_UnsupportedType(t *testing.T) {
193193
}
194194
}
195195

196+
func TestFromMssqlProfiles(t *testing.T) {
197+
// Test MSSQL connection conversion
198+
profiles := &DbtProfiles{
199+
Profiles: map[string]DbtProfile{
200+
"test_profile": {
201+
Target: "dev",
202+
Outputs: map[string]DbtConnection{
203+
"dev": {
204+
Type: "sqlserver",
205+
Server: testHost,
206+
Port: 1433,
207+
Database: "test_db",
208+
User: testUser,
209+
Password: testPassword,
210+
},
211+
},
212+
},
213+
},
214+
}
215+
216+
dataSources, err := FromDbtProfiles(profiles)
217+
if err != nil {
218+
t.Fatalf("FromDbtProfiles failed: %v", err)
219+
}
220+
221+
if len(dataSources) != 1 {
222+
t.Fatalf("Expected 1 data source, got %d", len(dataSources))
223+
}
224+
225+
ds, ok := dataSources[0].(*WrenMSSQLDataSource)
226+
if !ok {
227+
t.Fatalf("Expected WrenMSSQLDataSource, got %T", dataSources[0])
228+
}
229+
230+
if ds.Host != testHost {
231+
t.Errorf("Expected host '%s', got '%s'", testHost, ds.Host)
232+
}
233+
234+
if ds.Port != "1433" {
235+
t.Errorf("Expected port 1433, got %s", ds.Port)
236+
}
237+
238+
if ds.Database != "test_db" {
239+
t.Errorf("Expected database 'test_db', got '%s'", ds.Database)
240+
}
241+
242+
if ds.User != testUser {
243+
t.Errorf("Expected user '%s', got '%s'", testUser, ds.User)
244+
}
245+
246+
if ds.Password != testPassword {
247+
t.Errorf("Expected password '%s', got '%s'", testPassword, ds.Password)
248+
}
249+
250+
if ds.TdsVersion != "8.0" {
251+
t.Errorf("Expected TDS version '8.0', got '%s'", ds.TdsVersion)
252+
}
253+
254+
if ds.Driver != "ODBC Driver 18 for SQL Server" {
255+
t.Errorf("Expected driver 'ODBC Driver 18 for SQL Server', got '%s'", ds.Driver)
256+
}
257+
258+
if ds.Kwargs["TrustServerCertificate"] != "YES" {
259+
t.Errorf("Expected TrustServerCertificate 'YES', got '%s'", ds.Kwargs["TrustServerCertificate"])
260+
}
261+
}
262+
196263
func TestFromDbtProfiles_NilProfiles(t *testing.T) {
197264
// Test nil profiles
198265
_, err := FromDbtProfiles(nil)

wren-launcher/commands/dbt/profiles.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ type DbtProfile struct {
1616
type DbtConnection struct {
1717
Type string `yaml:"type" json:"type"`
1818
Host string `yaml:"host,omitempty" json:"host,omitempty"`
19+
Server string `yaml:"server,omitempty" json:"server,omitempty"` // MSSQL
1920
Port int `yaml:"port,omitempty" json:"port,omitempty"`
2021
User string `yaml:"user,omitempty" json:"user,omitempty"`
2122
Password string `yaml:"password,omitempty" json:"password,omitempty"`

wren-launcher/commands/dbt/profiles_analyzer.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,14 @@ func parseConnection(connectionMap map[string]interface{}) (*DbtConnection, erro
140140
connection.SearchPath = getString("search_path")
141141
connection.SSLMode = getString("sslmode")
142142
connection.Path = getString("path")
143+
connection.Server = getString("server")
143144

144145
// Store any additional fields that weren't mapped
145146
knownFields := map[string]bool{
146147
"type": true, "host": true, "port": true, "user": true, "password": true,
147148
"database": true, "schema": true, "project": true, "dataset": true,
148149
"keyfile": true, "account": true, "warehouse": true, "role": true,
149-
"keepalive": true, "search_path": true, "sslmode": true, "path": true,
150+
"keepalive": true, "search_path": true, "sslmode": true, "path": true, "server": true,
150151
}
151152

152153
for key, value := range connectionMap {

0 commit comments

Comments
 (0)