diff --git a/wren-launcher/commands/dbt/converter.go b/wren-launcher/commands/dbt/converter.go index afb10d03d7..f03abccf86 100644 --- a/wren-launcher/commands/dbt/converter.go +++ b/wren-launcher/commands/dbt/converter.go @@ -139,6 +139,26 @@ func ConvertDbtProjectCore(opts ConvertOptions) (*ConvertResult, error) { "password": typedDS.Password, }, } + case *WrenMSSQLDataSource: + var host string + if opts.UsedByContainer { + host = handleLocalhostForContainer(typedDS.Host) + } else { + host = typedDS.Host + } + wrenDataSource = map[string]interface{}{ + "type": "mssql", + "properties": map[string]interface{}{ + "host": host, + "port": typedDS.Port, + "database": typedDS.Database, + "user": typedDS.User, + "password": typedDS.Password, + "tds_version": typedDS.TdsVersion, + "driver": typedDS.Driver, + "kwargs": typedDS.Kwargs, + }, + } case *WrenLocalFileDataSource: var url string if opts.UsedByContainer { diff --git a/wren-launcher/commands/dbt/data_source.go b/wren-launcher/commands/dbt/data_source.go index a12aeac538..9ddfdf8e0c 100644 --- a/wren-launcher/commands/dbt/data_source.go +++ b/wren-launcher/commands/dbt/data_source.go @@ -11,13 +11,22 @@ import ( // Constants for data types const ( - integerType = "integer" - varcharType = "varchar" - dateType = "date" - timestampType = "timestamp" - doubleType = "double" - booleanType = "boolean" - postgresType = "postgres" + integerType = "integer" + smallintType = "smallint" + bigintType = "bigint" + floatType = "float" + decimalType = "decimal" + varcharType = "varchar" + charType = "char" + textType = "text" + dateType = "date" + timestampType = "timestamp" + timestamptzType = "timestamptz" + doubleType = "double" + booleanType = "boolean" + jsonType = "json" + intervalType = "interval" + postgresType = "postgres" ) // Constants for SQL data types @@ -79,6 +88,8 @@ func convertConnectionToDataSource(conn DbtConnection, dbtHomePath, profileName, return convertToPostgresDataSource(conn) case "duckdb": return convertToLocalFileDataSource(conn, dbtHomePath) + case "sqlserver": + return convertToMSSQLDataSource(conn) case "mysql": return convertToMysqlDataSource(conn) default: @@ -114,6 +125,26 @@ func convertToPostgresDataSource(conn DbtConnection) (*WrenPostgresDataSource, e return ds, nil } +func convertToMSSQLDataSource(conn DbtConnection) (*WrenMSSQLDataSource, error) { + port := strconv.Itoa(conn.Port) + if conn.Port == 0 { + port = "1433" + } + + ds := &WrenMSSQLDataSource{ + Database: conn.Database, + Host: conn.Server, + Port: port, + User: conn.User, + Password: conn.Password, + TdsVersion: "8.0", // the default tds version for Wren engine image + Driver: "ODBC Driver 18 for SQL Server", // the driver used by Wren engine image + Kwargs: map[string]interface{}{"TrustServerCertificate": "YES"}, + } + + return ds, nil +} + // convertToLocalFileDataSource converts to local file data source func convertToLocalFileDataSource(conn DbtConnection, dbtHome string) (*WrenLocalFileDataSource, error) { // For file types, we need to get URL and format info from Additional fields @@ -264,6 +295,85 @@ func (ds *WrenPostgresDataSource) MapType(sourceType string) string { return sourceType } +type WrenMSSQLDataSource struct { + Database string `json:"database"` + Host string `json:"host"` + Port string `json:"port"` + User string `json:"user"` + Password string `json:"password"` + TdsVersion string `json:"tds_version"` + Driver string `json:"driver"` + Kwargs map[string]interface{} `json:"kwargs"` +} + +func (ds *WrenMSSQLDataSource) GetType() string { + return "mssql" +} + +func (ds *WrenMSSQLDataSource) Validate() error { + if ds.Host == "" { + return fmt.Errorf("host cannot be empty") + } + if ds.Database == "" { + return fmt.Errorf("database cannot be empty") + } + if ds.User == "" { + return fmt.Errorf("user cannot be empty") + } + if ds.Port == "" { + return fmt.Errorf("port must be specified") + } + port, err := strconv.Atoi(ds.Port) + if err != nil { + return fmt.Errorf("port must be a valid number") + } + if port <= 0 || port > 65535 { + return fmt.Errorf("port must be between 1 and 65535") + } + if ds.Password == "" { + return fmt.Errorf("password cannot be empty") + } + return nil +} + +func (ds *WrenMSSQLDataSource) MapType(sourceType string) string { + // This method is not used in WrenMSSQLDataSource, but required by DataSource interface + switch strings.ToLower(sourceType) { + case charType, "nchar": + return charType + case varcharType, "nvarchar": + return varcharType + case textType, "ntext": + return textType + case "bit", "tinyint": + return booleanType + case "smallint": + return smallintType + case "int": + return integerType + case "bigint": + return bigintType + case booleanType: + return booleanType + case "float", "real": + return floatType + case "decimal", "numeric", "money", "smallmoney": + return decimalType + case "date": + return dateType + case "datetime", "datetime2", "smalldatetime": + return timestampType + case "time": + return intervalType + case "datetimeoffset": + return timestamptzType + case "json": + return jsonType + default: + return strings.ToLower(sourceType) + } +} + type WrenMysqlDataSource struct { Database string `json:"database"` Host string `json:"host"` diff --git a/wren-launcher/commands/dbt/data_source_test.go b/wren-launcher/commands/dbt/data_source_test.go index 4bdd3e7c1d..925d3c72c2 100644 --- a/wren-launcher/commands/dbt/data_source_test.go +++ b/wren-launcher/commands/dbt/data_source_test.go @@ -217,6 +217,73 @@ func TestFromDbtProfiles_UnsupportedType(t *testing.T) { } } +func TestFromMssqlProfiles(t *testing.T) { + // Test MSSQL connection conversion + profiles := &DbtProfiles{ + Profiles: map[string]DbtProfile{ + "test_profile": { + Target: "dev", + Outputs: map[string]DbtConnection{ + "dev": { + Type: "sqlserver", + Server: testHost, + Port: 1433, + Database: "test_db", + User: testUser, + Password: testPassword, + }, + }, + }, + }, + } + + dataSources, err := FromDbtProfiles(profiles) + if err != nil { + t.Fatalf("FromDbtProfiles failed: %v", err) + } + + if len(dataSources) != 1 { + t.Fatalf("Expected 1 data source, got %d", len(dataSources)) + } + + ds, ok := dataSources[0].(*WrenMSSQLDataSource) + if !ok { + t.Fatalf("Expected WrenMSSQLDataSource, got %T", dataSources[0]) + } + + if ds.Host != testHost { + t.Errorf("Expected host '%s', got '%s'", testHost, ds.Host) + } + + if ds.Port != "1433" { + t.Errorf("Expected port 1433, got %s", ds.Port) + } + + if ds.Database != "test_db" { + t.Errorf("Expected database 'test_db', got '%s'", ds.Database) + } + + if ds.User != testUser { + t.Errorf("Expected user '%s', got '%s'", testUser, ds.User) + } + + if ds.Password != testPassword { + t.Errorf("Expected password '%s', got '%s'", testPassword, ds.Password) + } + + if ds.TdsVersion != "8.0" { + t.Errorf("Expected TDS version '8.0', got '%s'", ds.TdsVersion) + } + + if ds.Driver != "ODBC Driver 18 for SQL Server" { + t.Errorf("Expected driver 'ODBC Driver 18 for SQL Server', got '%s'", ds.Driver) + } + + if ds.Kwargs["TrustServerCertificate"] != "YES" { + t.Errorf("Expected TrustServerCertificate 'YES', got '%s'", ds.Kwargs["TrustServerCertificate"]) + } +} + func TestFromDbtProfiles_NilProfiles(t *testing.T) { // Test nil profiles _, err := FromDbtProfiles(nil) diff --git a/wren-launcher/commands/dbt/profiles.go b/wren-launcher/commands/dbt/profiles.go index e87822e70b..231d256ecd 100644 --- a/wren-launcher/commands/dbt/profiles.go +++ b/wren-launcher/commands/dbt/profiles.go @@ -16,6 +16,7 @@ type DbtProfile struct { type DbtConnection struct { Type string `yaml:"type" json:"type"` Host string `yaml:"host,omitempty" json:"host,omitempty"` + Server string `yaml:"server,omitempty" json:"server,omitempty"` // MSSQL Port int `yaml:"port,omitempty" json:"port,omitempty"` User string `yaml:"user,omitempty" json:"user,omitempty"` Password string `yaml:"password,omitempty" json:"password,omitempty"` diff --git a/wren-launcher/commands/dbt/profiles_analyzer.go b/wren-launcher/commands/dbt/profiles_analyzer.go index 1a0698ba83..8100029de3 100644 --- a/wren-launcher/commands/dbt/profiles_analyzer.go +++ b/wren-launcher/commands/dbt/profiles_analyzer.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "runtime" + "strconv" "gopkg.in/yaml.v3" ) @@ -105,6 +106,12 @@ func parseConnection(connectionMap map[string]interface{}) (*DbtConnection, erro return v case float64: return int(v) + case int64: + return int(v) + case string: + if i, err := strconv.Atoi(v); err == nil { + return i + } } } return 0 @@ -142,13 +149,14 @@ func parseConnection(connectionMap map[string]interface{}) (*DbtConnection, erro connection.SSLMode = getString("sslmode") connection.Path = getString("path") connection.SslDisable = getBool("ssl_disable") // MySQL specific + connection.Server = getString("server") // Store any additional fields that weren't mapped knownFields := map[string]bool{ "type": true, "host": true, "port": true, "user": true, "password": true, "database": true, "dbname": true, "schema": true, "project": true, "dataset": true, "keyfile": true, "account": true, "warehouse": true, "role": true, - "keepalive": true, "search_path": true, "sslmode": true, "path": true, "ssl_disable": true, + "keepalive": true, "search_path": true, "sslmode": true, "path": true, "server": true, "ssl_disable": true, } for key, value := range connectionMap {