diff --git a/wren-launcher/commands/dbt/README.md b/wren-launcher/commands/dbt/README.md index 4d872dc554..4639f16bd0 100644 --- a/wren-launcher/commands/dbt/README.md +++ b/wren-launcher/commands/dbt/README.md @@ -1,3 +1,17 @@ +# Requirement for DBT project +This part outlines some requirements for the target dbt project: +- Ensure the DBT project is qualified and generates the required files: + - `catalog.json` + - `manifest.json` + Execute the following commands: + ``` + dbt build + dbt docs generate + ``` +- Prepare the profile of the dbt project for the connection info of your database. + - `profiles.yml` + + # How to Support a New Data Source This document outlines the steps required to add support for a new data source to the dbt project converter. diff --git a/wren-launcher/commands/dbt/converter.go b/wren-launcher/commands/dbt/converter.go index 350ab39928..afb10d03d7 100644 --- a/wren-launcher/commands/dbt/converter.go +++ b/wren-launcher/commands/dbt/converter.go @@ -154,6 +154,18 @@ func ConvertDbtProjectCore(opts ConvertOptions) (*ConvertResult, error) { "format": typedDS.Format, }, } + case *WrenMysqlDataSource: + wrenDataSource = map[string]interface{}{ + "type": "mysql", + "properties": map[string]interface{}{ + "host": typedDS.Host, + "port": typedDS.Port, + "database": typedDS.Database, + "user": typedDS.User, + "password": typedDS.Password, + "sslMode": typedDS.SslMode, + }, + } default: pterm.Warning.Printf("Warning: Unsupported data source type: %s\n", ds.GetType()) wrenDataSource = map[string]interface{}{ diff --git a/wren-launcher/commands/dbt/data_source.go b/wren-launcher/commands/dbt/data_source.go index b96d4e2b39..a12aeac538 100644 --- a/wren-launcher/commands/dbt/data_source.go +++ b/wren-launcher/commands/dbt/data_source.go @@ -3,6 +3,7 @@ package dbt import ( "fmt" "path/filepath" + "strconv" "strings" "github.com/pterm/pterm" @@ -16,6 +17,7 @@ const ( timestampType = "timestamp" doubleType = "double" booleanType = "boolean" + postgresType = "postgres" ) // Constants for SQL data types @@ -73,10 +75,12 @@ func FromDbtProfiles(profiles *DbtProfiles) ([]DataSource, error) { // convertConnectionToDataSource converts connection to corresponding DataSource based on connection type func convertConnectionToDataSource(conn DbtConnection, dbtHomePath, profileName, outputName string) (DataSource, error) { switch strings.ToLower(conn.Type) { - case "postgres", "postgresql": + case postgresType, "postgresql": return convertToPostgresDataSource(conn) case "duckdb": return convertToLocalFileDataSource(conn, dbtHomePath) + case "mysql": + return convertToMysqlDataSource(conn) default: // For unsupported database types, we can choose to ignore or return error // Here we choose to return nil and log a warning @@ -87,19 +91,26 @@ func convertConnectionToDataSource(conn DbtConnection, dbtHomePath, profileName, // convertToPostgresDataSource converts to PostgreSQL data source func convertToPostgresDataSource(conn DbtConnection) (*WrenPostgresDataSource, error) { + // For PostgreSQL, prefer dbname over database field + dbName := conn.DbName + if dbName == "" { + dbName = conn.Database + } + + pterm.Info.Printf("Converting Postgres data source: %s:%d/%s\n", conn.Host, conn.Port, dbName) + port := strconv.Itoa(conn.Port) + if conn.Port == 0 { + port = "5432" + } + ds := &WrenPostgresDataSource{ Host: conn.Host, - Port: conn.Port, - Database: conn.Database, + Port: port, + Database: dbName, User: conn.User, Password: conn.Password, } - // If no port is specified, use PostgreSQL default port - if ds.Port == 0 { - ds.Port = 5432 - } - return ds, nil } @@ -143,6 +154,30 @@ func convertToLocalFileDataSource(conn DbtConnection, dbtHome string) (*WrenLoca }, nil } +func convertToMysqlDataSource(conn DbtConnection) (*WrenMysqlDataSource, error) { + pterm.Info.Printf("Converting MySQL data source: %s:%d/%s\n", conn.Host, conn.Port, conn.Database) + + sslMode := "ENABLED" // Default SSL mode + if conn.SslDisable { + sslMode = "DISABLED" + } + port := strconv.Itoa(conn.Port) + if conn.Port == 0 { + port = "3306" + } + + ds := &WrenMysqlDataSource{ + Host: conn.Host, + Port: port, + Database: conn.Database, + User: conn.User, + Password: conn.Password, + SslMode: sslMode, + } + + return ds, nil +} + type WrenLocalFileDataSource struct { Url string `json:"url"` Format string `json:"format"` @@ -189,7 +224,7 @@ func (ds *WrenLocalFileDataSource) MapType(sourceType string) string { type WrenPostgresDataSource struct { Host string `json:"host"` - Port int `json:"port"` + Port string `json:"port"` Database string `json:"database"` User string `json:"user"` Password string `json:"password"` @@ -197,7 +232,7 @@ type WrenPostgresDataSource struct { // GetType implements DataSource interface func (ds *WrenPostgresDataSource) GetType() string { - return "postgres" + return postgresType } // Validate implements DataSource interface @@ -211,7 +246,14 @@ func (ds *WrenPostgresDataSource) Validate() error { if ds.User == "" { return fmt.Errorf("user cannot be empty") } - if ds.Port <= 0 || ds.Port > 65535 { + if ds.Port == "" { + return fmt.Errorf("port must be specified") + } + port, err := strconv.Atoi(ds.Port) + if err != nil { + return fmt.Errorf("port must be a valid number") + } + if port <= 0 || port > 65535 { return fmt.Errorf("port must be between 1 and 65535") } return nil @@ -222,6 +264,83 @@ func (ds *WrenPostgresDataSource) MapType(sourceType string) string { return sourceType } +type WrenMysqlDataSource struct { + Database string `json:"database"` + Host string `json:"host"` + Password string `json:"password"` + Port string `json:"port"` + User string `json:"user"` + SslCA string `json:"ssl_ca,omitempty"` // Optional SSL CA file for MySQL + SslMode string `json:"ssl_mode,omitempty"` // Optional SSL mode for MySQL +} + +// GetType implements DataSource interface +func (ds *WrenMysqlDataSource) GetType() string { + return "mysql" +} + +// Validate implements DataSource interface +func (ds *WrenMysqlDataSource) Validate() error { + if ds.Host == "" { + return fmt.Errorf("host cannot be empty") + } + if ds.Database == "" { + return fmt.Errorf("database cannot be empty") + } + if ds.User == "" { + return fmt.Errorf("user cannot be empty") + } + if ds.Port == "" { + return fmt.Errorf("port must be specified") + } + port, err := strconv.Atoi(ds.Port) + if err != nil { + return fmt.Errorf("port must be a valid number") + } + if port <= 0 || port > 65535 { + return fmt.Errorf("port must be between 1 and 65535") + } + return nil +} + +func (ds *WrenMysqlDataSource) MapType(sourceType string) string { + // This method is not used in WrenMysqlDataSource, but required by DataSource interface + sourceType = strings.ToUpper(sourceType) + switch sourceType { + case "CHAR": + return "char" + case "VARCHAR": + return varcharType + case "TEXT", "TINYTEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET": + return "text" + case "BIT", "TINYINT": + return "TINYINT" + case "SMALLINT": + return "SMALLINT" + case "MEDIUMINT", "INT", "INTEGER": + return "INTEGER" + case "BIGINT": + return "BIGINT" + case "FLOAT", "DOUBLE": + return "DOUBLE" + case "DECIMAL", "NUMERIC": + return "DECIMAL" + case "DATE": + return "DATE" + case "DATETIME": + return "DATETIME" + case "TIMESTAMP": + return "TIMESTAMPTZ" + case "BOOLEAN", "BOOL": + return "BOOLEAN" + case "JSON": + return "JSON" + default: + // Return the original type if no mapping is found + return strings.ToLower(sourceType) + } +} + // GetActiveDataSources gets active data sources based on specified profile and target // If profileName is empty, it will use the first found profile // If targetName is empty, it will use the profile's default target @@ -326,7 +445,7 @@ func (d *DefaultDataSource) MapType(sourceType string) string { case "integer", "int", "bigint", "int64": return "integer" case "varchar", "text", "string", "char": - return "varchar" + return varcharType case "timestamp", "datetime", "date": return "timestamp" case "double", "float", "decimal", "numeric": diff --git a/wren-launcher/commands/dbt/data_source_test.go b/wren-launcher/commands/dbt/data_source_test.go index 201d8c4430..4bdd3e7c1d 100644 --- a/wren-launcher/commands/dbt/data_source_test.go +++ b/wren-launcher/commands/dbt/data_source_test.go @@ -20,8 +20,8 @@ func validatePostgresDataSource(t *testing.T, ds *WrenPostgresDataSource, expect if ds.Host != testHost { t.Errorf("Expected host '%s', got '%s'", testHost, ds.Host) } - if ds.Port != 5432 { - t.Errorf("Expected port 5432, got %d", ds.Port) + if ds.Port != "5432" { + t.Errorf("Expected port 5432, got %s", ds.Port) } if ds.Database != expectedDB { t.Errorf("Expected database '%s', got '%s'", expectedDB, ds.Database) @@ -89,12 +89,12 @@ func TestFromDbtProfiles_PostgresWithDbName(t *testing.T) { Target: "dev", Outputs: map[string]DbtConnection{ "dev": { - Type: pgType, - Host: testHost, + Type: "postgres", + Host: "localhost", Port: 5432, - Database: "jaffle_shop", - User: testUser, - Password: testPassword, + DbName: "jaffle_shop", // Using dbname instead of database + User: "test_user", + Password: "test_pass", }, }, }, @@ -115,7 +115,31 @@ func TestFromDbtProfiles_PostgresWithDbName(t *testing.T) { t.Fatalf("Expected WrenPostgresDataSource, got %T", dataSources[0]) } - validatePostgresDataSource(t, ds, "jaffle_shop") + if ds.Host != "localhost" { + t.Errorf("Expected host 'localhost', got '%s'", ds.Host) + } + if ds.Port != "5432" { + t.Errorf("Expected port 5432, got %s", ds.Port) + } + if ds.Database != "jaffle_shop" { + t.Errorf("Expected database 'jaffle_shop', got '%s'", ds.Database) + } + if ds.User != "test_user" { + t.Errorf("Expected user 'test_user', got '%s'", ds.User) + } + if ds.Password != "test_pass" { + t.Errorf("Expected password 'test_pass', got '%s'", ds.Password) + } + + // Test validation + if err := ds.Validate(); err != nil { + t.Errorf("Validation failed: %v", err) + } + + // Test type + if ds.GetType() != "postgres" { + t.Errorf("Expected type 'postgres', got '%s'", ds.GetType()) + } } func TestFromDbtProfiles_LocalFile(t *testing.T) { @@ -231,7 +255,7 @@ func testDataSourceValidation(t *testing.T, testName string, validDS Validator, func TestPostgresDataSourceValidation(t *testing.T) { validDS := &WrenPostgresDataSource{ Host: testHost, - Port: 5432, + Port: "5432", Database: "test", User: "user", } @@ -243,7 +267,7 @@ func TestPostgresDataSourceValidation(t *testing.T) { { "empty host", &WrenPostgresDataSource{ - Port: 5432, + Port: "5432", Database: "test", User: "user", }, @@ -252,7 +276,7 @@ func TestPostgresDataSourceValidation(t *testing.T) { "empty database", &WrenPostgresDataSource{ Host: testHost, - Port: 5432, + Port: "5432", User: "user", }, }, @@ -260,7 +284,7 @@ func TestPostgresDataSourceValidation(t *testing.T) { "invalid port", &WrenPostgresDataSource{ Host: testHost, - Port: 0, + Port: "0", Database: "test", User: "user", }, @@ -270,6 +294,63 @@ func TestPostgresDataSourceValidation(t *testing.T) { testDataSourceValidation(t, "postgres", validDS, invalidCases) } +func TestMysqlDataSourceValidation(t *testing.T) { + // Test MySQL data source validation + tests := []struct { + name string + ds *WrenMysqlDataSource + wantErr bool + }{ + { + name: "valid", + ds: &WrenMysqlDataSource{ + Host: "localhost", + Port: "3306", + Database: "test", + User: "user", + }, + wantErr: false, + }, + { + name: "empty host", + ds: &WrenMysqlDataSource{ + Port: "3306", + Database: "test", + User: "user", + }, + wantErr: true, + }, + { + name: "empty database", + ds: &WrenMysqlDataSource{ + Host: "localhost", + Port: "3306", + User: "user", + }, + wantErr: true, + }, + { + name: "invalid port", + ds: &WrenMysqlDataSource{ + Host: "localhost", + Port: "", + Database: "test", + User: "user", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.ds.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + func TestGetActiveDataSources(t *testing.T) { profiles := &DbtProfiles{ Profiles: map[string]DbtProfile{ diff --git a/wren-launcher/commands/dbt/profiles.go b/wren-launcher/commands/dbt/profiles.go index 3bc9c38852..e87822e70b 100644 --- a/wren-launcher/commands/dbt/profiles.go +++ b/wren-launcher/commands/dbt/profiles.go @@ -19,6 +19,7 @@ type DbtConnection struct { Port int `yaml:"port,omitempty" json:"port,omitempty"` User string `yaml:"user,omitempty" json:"user,omitempty"` Password string `yaml:"password,omitempty" json:"password,omitempty"` + DbName string `yaml:"dbname,omitempty" json:"dbname,omitempty"` // Postgres Database string `yaml:"database,omitempty" json:"database,omitempty"` Schema string `yaml:"schema,omitempty" json:"schema,omitempty"` // Additional fields for different database types @@ -33,6 +34,8 @@ type DbtConnection struct { SearchPath string `yaml:"search_path,omitempty" json:"search_path,omitempty"` // Postgres SSLMode string `yaml:"sslmode,omitempty" json:"sslmode,omitempty"` // Postgres + SslDisable bool `yaml:"ssl_disable,omitempty" json:"ssl_disable,omitempty"` // MySQL + Path string `yaml:"path,omitempty" json:"path,omitempty"` // DuckDB // Flexible additional properties Additional map[string]interface{} `yaml:",inline" json:"additional,omitempty"` diff --git a/wren-launcher/commands/dbt/profiles_analyzer.go b/wren-launcher/commands/dbt/profiles_analyzer.go index b01f79c1e0..1a0698ba83 100644 --- a/wren-launcher/commands/dbt/profiles_analyzer.go +++ b/wren-launcher/commands/dbt/profiles_analyzer.go @@ -127,6 +127,7 @@ func parseConnection(connectionMap map[string]interface{}) (*DbtConnection, erro connection.User = getString("user") connection.Password = getString("password") connection.Database = getString("database") + connection.DbName = getString("dbname") // PostgreSQL specific connection.Schema = getString("schema") // Extract database-specific fields @@ -140,13 +141,14 @@ func parseConnection(connectionMap map[string]interface{}) (*DbtConnection, erro connection.SearchPath = getString("search_path") connection.SSLMode = getString("sslmode") connection.Path = getString("path") + connection.SslDisable = getBool("ssl_disable") // MySQL specific // Store any additional fields that weren't mapped knownFields := map[string]bool{ "type": true, "host": true, "port": true, "user": true, "password": true, - "database": true, "schema": true, "project": true, "dataset": true, + "database": true, "dbname": true, "schema": true, "project": true, "dataset": true, "keyfile": true, "account": true, "warehouse": true, "role": true, - "keepalive": true, "search_path": true, "sslmode": true, "path": true, + "keepalive": true, "search_path": true, "sslmode": true, "path": true, "ssl_disable": true, } for key, value := range connectionMap {