diff --git a/wren-launcher/commands/dbt.go b/wren-launcher/commands/dbt.go index f23643cefe..60c85a6534 100644 --- a/wren-launcher/commands/dbt.go +++ b/wren-launcher/commands/dbt.go @@ -12,10 +12,11 @@ import ( // then converts them to WrenDataSource and Wren MDL format func DbtAutoConvert() { var opts struct { - ProjectPath string - OutputDir string - ProfileName string - Target string + ProjectPath string + OutputDir string + ProfileName string + Target string + IncludeStagingModels bool } // Define command line flags @@ -23,6 +24,7 @@ func DbtAutoConvert() { flag.StringVar(&opts.OutputDir, "output", "", "Output directory for generated JSON files") flag.StringVar(&opts.ProfileName, "profile", "", "Specific profile name to use (optional, uses first found if not provided)") flag.StringVar(&opts.Target, "target", "", "Specific target to use (optional, uses profile default if not provided)") + flag.BoolVar(&opts.IncludeStagingModels, "include-staging-models", false, "If set, staging models will be included during conversion") flag.Parse() // Validate required parameters @@ -40,11 +42,12 @@ func DbtAutoConvert() { // ConvertOptions struct for core conversion logic convertOpts := dbt.ConvertOptions{ - ProjectPath: opts.ProjectPath, - OutputDir: opts.OutputDir, - ProfileName: opts.ProfileName, - Target: opts.Target, - RequireCatalog: true, // DbtAutoConvert requires catalog.json to exist + ProjectPath: opts.ProjectPath, + OutputDir: opts.OutputDir, + ProfileName: opts.ProfileName, + Target: opts.Target, + RequireCatalog: true, // DbtAutoConvert requires catalog.json to exist + IncludeStagingModels: opts.IncludeStagingModels, } // Call the core conversion logic @@ -57,14 +60,15 @@ func DbtAutoConvert() { // DbtConvertProject is a public wrapper function for processDbtProject to use // It converts a dbt project without requiring catalog.json to exist -func DbtConvertProject(projectPath, outputDir, profileName, target string, usedByContainer bool) (*dbt.ConvertResult, 
error) { +func DbtConvertProject(projectPath, outputDir, profileName, target string, usedByContainer bool, IncludeStagingModels bool) (*dbt.ConvertResult, error) { convertOpts := dbt.ConvertOptions{ - ProjectPath: projectPath, - OutputDir: outputDir, - ProfileName: profileName, - Target: target, - RequireCatalog: false, // Allow processDbtProject to continue without catalog.json - UsedByContainer: usedByContainer, + ProjectPath: projectPath, + OutputDir: outputDir, + ProfileName: profileName, + Target: target, + RequireCatalog: false, // Allow processDbtProject to continue without catalog.json + UsedByContainer: usedByContainer, + IncludeStagingModels: IncludeStagingModels, } return dbt.ConvertDbtProjectCore(convertOpts) diff --git a/wren-launcher/commands/dbt/converter.go b/wren-launcher/commands/dbt/converter.go index afb10d03d7..acc1b3aa16 100644 --- a/wren-launcher/commands/dbt/converter.go +++ b/wren-launcher/commands/dbt/converter.go @@ -5,20 +5,25 @@ import ( "fmt" "os" "path/filepath" + "regexp" "sort" "strings" "github.com/pterm/pterm" ) +// Note: All struct definitions (WrenMDLManifest, WrenModel, etc.) are defined +// in wren_mdl.go to prevent "redeclared in this block" compilation errors. 
+ // ConvertOptions holds the options for dbt project conversion type ConvertOptions struct { - ProjectPath string - OutputDir string - ProfileName string - Target string - RequireCatalog bool // if true, missing catalog.json is an error; if false, it's a warning - UsedByContainer bool // if true, used by container, no need to print usage info + ProjectPath string + OutputDir string + ProfileName string + Target string + RequireCatalog bool // if true, missing catalog.json is an error; if false, it's a warning + UsedByContainer bool // if true, used by container, no need to print usage info + IncludeStagingModels bool // if true, staging models will be included in the conversion } // ConvertResult holds the result of dbt project conversion @@ -51,10 +56,11 @@ func ConvertDbtProjectCore(opts ConvertOptions) (*ConvertResult, error) { pterm.Info.Println("Skipping data source conversion...") } - // Search for catalog.json and manifest.json in target directory + // Search for catalog.json, manifest.json, and semantic_manifest.json in target directory targetDir := filepath.Join(opts.ProjectPath, "target") catalogPath := filepath.Join(targetDir, "catalog.json") manifestPath := filepath.Join(targetDir, "manifest.json") + semanticManifestPath := filepath.Join(targetDir, "semantic_manifest.json") if !FileExists(catalogPath) { if opts.RequireCatalog { @@ -66,14 +72,23 @@ func ConvertDbtProjectCore(opts ConvertOptions) (*ConvertResult, error) { } } - // Check for manifest.json (optional but recommended for descriptions) + // Check for manifest.json (optional but recommended for descriptions and relationships) var manifestPathForConversion string if FileExists(manifestPath) { pterm.Info.Printf("Found manifest.json at: %s\n", manifestPath) manifestPathForConversion = manifestPath } else { pterm.Warning.Printf("Warning: manifest.json not found at: %s\n", manifestPath) - pterm.Info.Println("Model and column descriptions will not be included") + pterm.Info.Println("Model 
descriptions, column descriptions, and relationships will not be included") + } + + // Check for semantic_manifest.json (optional) + var semanticManifestPathForConversion string + if FileExists(semanticManifestPath) { + pterm.Info.Printf("Found semantic_manifest.json at: %s\n", semanticManifestPath) + semanticManifestPathForConversion = semanticManifestPath + } else { + pterm.Info.Println("semantic_manifest.json not found, skipping metric and primary key conversion.") } // Convert profiles.yml to WrenDataSource (if profiles found) @@ -154,6 +169,15 @@ func ConvertDbtProjectCore(opts ConvertOptions) (*ConvertResult, error) { "format": typedDS.Format, }, } + case *WrenBigQueryDataSource: + wrenDataSource = map[string]interface{}{ + "type": "bigquery", + "properties": map[string]interface{}{ + "project_id": typedDS.Project, + "dataset_id": typedDS.Dataset, + "credentials": typedDS.Credentials, + }, + } case *WrenMysqlDataSource: wrenDataSource = map[string]interface{}{ "type": "mysql", @@ -198,7 +222,7 @@ func ConvertDbtProjectCore(opts ConvertOptions) (*ConvertResult, error) { ds = &DefaultDataSource{} } - manifest, err := ConvertDbtCatalogToWrenMDL(catalogPath, ds, manifestPathForConversion) + manifest, err := ConvertDbtCatalogToWrenMDL(catalogPath, ds, manifestPathForConversion, semanticManifestPathForConversion, opts.IncludeStagingModels) if err != nil { return nil, fmt.Errorf("failed to convert catalog: %w", err) } @@ -219,6 +243,9 @@ func ConvertDbtProjectCore(opts ConvertOptions) (*ConvertResult, error) { // Summary pterm.Success.Println("\n🎉 Conversion completed successfully!") pterm.Info.Printf("Models converted: %d\n", len(manifest.Models)) + pterm.Info.Printf("Relationships generated: %d\n", len(manifest.Relationships)) + pterm.Info.Printf("Metrics generated: %d\n", len(manifest.Metrics)) + pterm.Info.Printf("Enums generated: %d\n", len(manifest.EnumDefinitions)) if dataSourceGenerated { pterm.Info.Println("Generated files:") @@ -253,133 +280,256 @@ func
handleLocalhostForContainer(host string) string { return host } -// ConvertDbtCatalogToWrenMDL converts dbt catalog.json to Wren MDL format -func ConvertDbtCatalogToWrenMDL(catalogPath string, data_source DataSource, manifestPath string) (*WrenMDLManifest, error) { - // Read and parse the catalog.json file - data, err := os.ReadFile(catalogPath) // #nosec G304 -- catalogPath is controlled by application +// ConvertDbtCatalogToWrenMDL is the main function to convert a dbt catalog into a Wren MDL manifest. +// It orchestrates the reading of dbt artifacts and processes each dbt node to convert it into a Wren model. +func ConvertDbtCatalogToWrenMDL(catalogPath string, dataSource DataSource, manifestPath string, semanticManifestPath string, includeStagingModels bool) (*WrenMDLManifest, error) { + // --- 1. Read and Parse All Necessary DBT Artifact Files --- + + // Read and unmarshal the primary catalog.json file. + catalogBytes, err := os.ReadFile(filepath.Clean(catalogPath)) if err != nil { return nil, fmt.Errorf("failed to read catalog file %s: %w", catalogPath, err) } - var catalogData map[string]interface{} - if err := json.Unmarshal(data, &catalogData); err != nil { + if err := json.Unmarshal(catalogBytes, &catalogData); err != nil { return nil, fmt.Errorf("failed to parse catalog JSON: %w", err) } - // Parse manifest.json for descriptions (optional) + // Read and unmarshal the manifest.json file, which contains rich metadata. 
var manifestData map[string]interface{} if manifestPath != "" { - pterm.Info.Printf("Reading manifest.json for descriptions from: %s\n", manifestPath) - manifestBytes, err := os.ReadFile(manifestPath) // #nosec G304 -- manifestPath is controlled by application + pterm.Info.Printf("Reading manifest.json for descriptions and relationships from: %s\n", manifestPath) + manifestBytes, err := os.ReadFile(filepath.Clean(manifestPath)) // #nosec G304 -- manifestPath is controlled by application if err != nil { - pterm.Warning.Printf("Warning: Failed to read manifest file %s: %v\n", manifestPath, err) - } else { - if err := json.Unmarshal(manifestBytes, &manifestData); err != nil { - pterm.Warning.Printf("Warning: Failed to parse manifest JSON: %v\n", err) - } + pterm.Warning.Printf("Could not read manifest file %s: %v. Descriptions and relationships will be missing.\n", manifestPath, err) + } else if err := json.Unmarshal(manifestBytes, &manifestData); err != nil { + pterm.Warning.Printf("Could not parse manifest file %s: %v. Descriptions and relationships will be missing.\n", manifestPath, err) } } - // Extract nodes + // Read and unmarshal the semantic_manifest.json file for metrics and primary keys. + var semanticManifestData map[string]interface{} + if semanticManifestPath != "" { + semanticBytes, err := os.ReadFile(filepath.Clean(semanticManifestPath)) + if err != nil { + pterm.Warning.Printf("Could not read semantic_manifest.json: %v. Metrics and primary keys will be missing.\n", err) + } else if err := json.Unmarshal(semanticBytes, &semanticManifestData); err != nil { + pterm.Warning.Printf("Could not parse semantic_manifest.json: %v. Metrics and primary keys will be missing.\n", err) + } + } + + // --- 2. 
Initialize Wren Manifest and Pre-process Metadata --- + + manifest := &WrenMDLManifest{ + Catalog: "wren", + Schema: "public", + EnumDefinitions: []EnumDefinition{}, + Models: []WrenModel{}, + Relationships: []Relationship{}, + Metrics: []Metric{}, + Views: []View{}, + DataSources: dataSource.GetType(), + } + + // Create lookup maps to store pre-processed information for quick access. + enumValueToNameMap := make(map[string]string) + columnToEnumNameMap := make(map[string]string) + columnToNotNullMap := make(map[string]bool) + modelToPrimaryKeyMap := make(map[string]string) + + // Pre-process the manifest to extract test data (enums, not-null constraints). + if manifestData != nil { + preprocessManifestForTests(manifestData, &manifest.EnumDefinitions, enumValueToNameMap, columnToEnumNameMap, columnToNotNullMap) + } + + // Pre-process the semantic manifest to extract primary key information. + if semanticManifestData != nil { + preprocessSemanticManifestForPrimaryKeys(semanticManifestData, modelToPrimaryKeyMap) + } + + // --- 3. Convert dbt Nodes to Wren Models --- + nodesValue, exists := catalogData["nodes"] if !exists { return nil, fmt.Errorf("no 'nodes' section found in catalog") } - nodesMap, ok := nodesValue.(map[string]interface{}) if !ok { return nil, fmt.Errorf("invalid 'nodes' format in catalog") } - // Initialize Wren MDL manifest - manifest := &WrenMDLManifest{ - Catalog: "wren", // Default catalog name - Schema: "public", // Default schema name - Models: []WrenModel{}, - Relationships: []Relationship{}, - Views: []View{}, - DataSources: data_source.GetType(), // Default data source name - } - - // Convert each dbt model to Wren model + // Iterate through each node in the catalog and convert it to a Wren model. for nodeKey, nodeValue := range nodesMap { nodeMap, ok := nodeValue.(map[string]interface{}) if !ok { continue } - - // Only process models (skip seeds, tests, etc.) + // We are only interested in nodes that represent dbt models. 
if !strings.HasPrefix(nodeKey, "model.") { continue } - // Skip staging models - if strings.Contains(nodeKey, ".stg_") || strings.Contains(nodeKey, ".staging_") { + // Skip staging models if the user has opted to exclude them. + modelName := getModelNameFromNodeKey(nodeKey) + if !includeStagingModels && (strings.HasPrefix(modelName, "stg_") || strings.HasPrefix(modelName, "staging_")) { continue } - model, err := convertDbtNodeToWrenModel(nodeKey, nodeMap, data_source, manifestData) + // Perform the conversion for the single node. + model, err := convertDbtNodeToWrenModel(nodeKey, nodeMap, dataSource, manifestData, columnToEnumNameMap, columnToNotNullMap, modelToPrimaryKeyMap) if err != nil { - pterm.Warning.Printf("Warning: Failed to convert model %s: %v\n", nodeKey, err) + pterm.Warning.Printf("Failed to convert model %s: %v\n", nodeKey, err) continue } - manifest.Models = append(manifest.Models, *model) } - return manifest, nil -} + // --- 4. Generate Relationships and Metrics --- -// convertDbtNodeToWrenModel converts a single dbt node to Wren model -func convertDbtNodeToWrenModel(nodeKey string, nodeData map[string]interface{}, data_source DataSource, manifestData map[string]interface{}) (*WrenModel, error) { - // Extract model name from node key (e.g., "model.jaffle_shop.customers" -> "customers") - parts := strings.Split(nodeKey, ".") - if len(parts) < 3 { - return nil, fmt.Errorf("invalid node key format: %s", nodeKey) + // Generate relationships between models based on the dbt manifest. + if manifestData != nil { + manifest.Relationships = generateRelationships(manifestData) } - modelName := parts[len(parts)-1] - // Extract metadata - metadataValue, exists := nodeData["metadata"] - if !exists { - return nil, fmt.Errorf("no metadata found for model %s", nodeKey) + // Generate metrics from the semantic manifest. 
+ if semanticManifestData != nil { + manifest.Metrics = convertDbtMetricsToWrenMetrics(semanticManifestData) } - metadata, ok := metadataValue.(map[string]interface{}) + return manifest, nil +} + +// preprocessManifestForTests extracts information from dbt tests (like 'not_null' and 'accepted_values') +// and populates maps that will be used later during model conversion. +func preprocessManifestForTests(manifestData map[string]interface{}, enums *[]EnumDefinition, enumValueToNameMap, columnToEnumNameMap map[string]string, columnToNotNullMap map[string]bool) { + nodes, ok := manifestData["nodes"].(map[string]interface{}) if !ok { - return nil, fmt.Errorf("invalid metadata format for model %s", nodeKey) + return } - // Create table reference - tableRef := TableReference{ - Table: getStringFromMap(metadata, "name", modelName), + for nodeKey, nodeValue := range nodes { + nodeMap, ok := nodeValue.(map[string]interface{}) + if !ok { + continue + } + + // Process tests defined directly on model columns. + if strings.HasPrefix(nodeKey, "model.") { + modelName := getModelNameFromNodeKey(nodeKey) + if columns, ok := nodeMap["columns"].(map[string]interface{}); ok { + for columnName, colData := range columns { + if colMap, ok := colData.(map[string]interface{}); ok { + processColumnForTests(nodeKey, modelName, columnName, colMap, enums, enumValueToNameMap, columnToEnumNameMap, columnToNotNullMap) + } + } + } + } + + // Process compiled test nodes which are separate entries in the manifest. 
+ if strings.HasPrefix(nodeKey, "test.") { + testMeta, _ := nodeMap["test_metadata"].(map[string]interface{}) + testName := getStringFromMap(testMeta, "name", "") + attachedNodeID := getStringFromMap(nodeMap, "attached_node", "") + columnName := getStringFromMap(nodeMap, "column_name", "") + + if attachedNodeID != "" && columnName != "" { + columnKey := fmt.Sprintf("%s.%s", attachedNodeID, columnName) + modelName := getModelNameFromNodeKey(attachedNodeID) + + if testName == "not_null" { + columnToNotNullMap[columnKey] = true + } + + if testName == "accepted_values" { + if kwargs, ok := testMeta["kwargs"].(map[string]interface{}); ok { + if values, ok := kwargs["values"].([]interface{}); ok && len(values) > 0 { + createOrLinkEnum(modelName, columnName, columnKey, values, enums, enumValueToNameMap, columnToEnumNameMap) + } + } + } + } + } } +} - if catalog := getStringFromMap(metadata, "database", ""); catalog != "" { - tableRef.Catalog = catalog +// preprocessSemanticManifestForPrimaryKeys extracts primary key information from the semantic manifest. 
+func preprocessSemanticManifestForPrimaryKeys(semanticData map[string]interface{}, modelToPrimaryKeyMap map[string]string) { + semanticModels, ok := semanticData["semantic_models"].([]interface{}) + if !ok { + return } - if schema := getStringFromMap(metadata, "schema", ""); schema != "" { - tableRef.Schema = schema + + for _, sm := range semanticModels { + smMap, ok := sm.(map[string]interface{}) + if !ok { + continue + } + + var modelName string + if nr, ok := smMap["node_relation"].(map[string]interface{}); ok { + modelName = getStringFromMap(nr, "alias", "") + } + + if entities, ok := smMap["entities"].([]interface{}); ok { + for _, entity := range entities { + if entityMap, ok := entity.(map[string]interface{}); ok { + if getStringFromMap(entityMap, "type", "") == "primary" { + pk := getStringFromMap(entityMap, "expr", "") + if modelName != "" && pk != "" { + modelToPrimaryKeyMap[modelName] = pk + } + } + } + } + } } +} - // Extract descriptions from manifest.json if available - var modelDescription string - var columnDescriptions map[string]string +// generateRelationships iterates through the manifest and creates relationship definitions. 
+func generateRelationships(manifestData map[string]interface{}) []Relationship { + var relationships []Relationship + if nodes, ok := manifestData["nodes"].(map[string]interface{}); ok { + for nodeKey, nodeValue := range nodes { + nodeMap, ok := nodeValue.(map[string]interface{}) + if !ok { + continue + } - if manifestData != nil { - if nodes, ok := manifestData["nodes"].(map[string]interface{}); ok { - if manifestNode, ok := nodes[nodeKey].(map[string]interface{}); ok { - // Extract model description - modelDescription = getStringFromMap(manifestNode, "description", "") - - // Extract column descriptions - if manifestColumns, ok := manifestNode["columns"].(map[string]interface{}); ok { - columnDescriptions = make(map[string]string) - for colName, colData := range manifestColumns { + // Case 1: Handle tests on model columns (including structs) + if strings.HasPrefix(nodeKey, "model.") { + fromModelName := getModelNameFromNodeKey(nodeKey) + if fromModelName == "" { + continue + } + if columns, ok := nodeMap["columns"].(map[string]interface{}); ok { + for columnName, colData := range columns { if colMap, ok := colData.(map[string]interface{}); ok { - description := getStringFromMap(colMap, "description", "") - if description != "" { - columnDescriptions[colName] = description + relationships = append(relationships, parseTestsForRelationships(fromModelName, columnName, colMap)...) 
+ } + } + } + } + + // Case 2: Handle compiled test nodes for simple columns + if strings.HasPrefix(nodeKey, "test.") { + if testMeta, ok := nodeMap["test_metadata"].(map[string]interface{}); ok { + if getStringFromMap(testMeta, "name", "") == "relationships" { + if kwargs, ok := testMeta["kwargs"].(map[string]interface{}); ok { + toRef := getStringFromMap(kwargs, "to", "") + toField := getStringFromMap(kwargs, "field", "") + toModelName := parseRef(toRef) + fromColumnName := getStringFromMap(nodeMap, "column_name", "") + attachedNodeID := getStringFromMap(nodeMap, "attached_node", "") + fromModelName := getModelNameFromNodeKey(attachedNodeID) + + if toModelName != "" && toField != "" && fromModelName != "" && fromColumnName != "" { + rel := Relationship{ + Name: fmt.Sprintf("%s_to_%s_by_%s", fromModelName, toModelName, fromColumnName), + Models: []string{fromModelName, toModelName}, + JoinType: "MANY_TO_ONE", + Condition: fmt.Sprintf("%s.%s = %s.%s", fromModelName, fromColumnName, toModelName, toField), + } + relationships = append(relationships, rel) } } } @@ -387,69 +537,515 @@ func convertDbtNodeToWrenModel(nodeKey string, nodeData map[string]interface{}, } } } + seen := make(map[string]struct{}, len(relationships)) + var unique []Relationship + for _, r := range relationships { + key := r.Name + "|" + r.JoinType + "|" + r.Condition + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + unique = append(unique, r) + } + return unique +} - // Convert columns - columnsValue, exists := nodeData["columns"] +// parseTestsForRelationships is a helper function to extract relationship tests from a column or its fields. +func parseTestsForRelationships(fromModelName, columnName string, colMap map[string]interface{}) []Relationship { + var relationships []Relationship + // Case 1: Tests are directly on the column. 
+ if tests, ok := colMap["tests"].([]interface{}); ok { + relationships = append(relationships, extractRelationshipsFromTests(fromModelName, columnName, tests)...) + } + // Case 2: Tests are on fields within a struct column. + if fields, ok := colMap["fields"].([]interface{}); ok { + for _, fieldData := range fields { + if fieldMap, ok := fieldData.(map[string]interface{}); ok { + fieldName := getStringFromMap(fieldMap, "name", "") + if fieldName == "" { + continue + } + if tests, ok := fieldMap["tests"].([]interface{}); ok { + relationships = append(relationships, extractRelationshipsFromTests(fromModelName, fieldName, tests)...) + } + } + } + } + return relationships +} + +// extractRelationshipsFromTests extracts relationship info from a 'tests' array. +func extractRelationshipsFromTests(fromModelName, fromColumnName string, tests []interface{}) []Relationship { + var relationships []Relationship + for _, test := range tests { + if relTest, ok := test.(map[string]interface{}); ok { + if relData, ok := relTest["relationships"].(map[string]interface{}); ok { + toRef := getStringFromMap(relData, "to", "") + toField := getStringFromMap(relData, "field", "") + toModelName := parseRef(toRef) + + if toModelName != "" && toField != "" { + rel := Relationship{ + Name: fmt.Sprintf("%s_to_%s_by_%s", fromModelName, toModelName, fromColumnName), + Models: []string{fromModelName, toModelName}, + JoinType: "MANY_TO_ONE", + Condition: fmt.Sprintf("%s.%s = %s.%s", fromModelName, fromColumnName, toModelName, toField), + } + relationships = append(relationships, rel) + } + } + } + } + return relationships +} + +// createOrLinkEnum is a helper to de-duplicate and manage enum creation based on 'accepted_values' tests. 
+func createOrLinkEnum(modelName, columnName, columnKey string, values []interface{}, allEnums *[]EnumDefinition, enumValueToNameMap, columnToEnumNameMap map[string]string) { + var strValues []string + for _, v := range values { + if s, ok := v.(string); ok { + strValues = append(strValues, s) + } + } + if len(strValues) == 0 { + return + } + sort.Strings(strValues) + valueKey := strings.Join(strValues, ",") + + enumName, exists := enumValueToNameMap[valueKey] if !exists { - return nil, fmt.Errorf("no columns found for model %s", nodeKey) + enumName = fmt.Sprintf("%s_%s_Enum", modelName, columnName) + // Sanitize enum name to be a valid identifier + re := regexp.MustCompile(`[^a-zA-Z0-9_]`) + enumName = re.ReplaceAllString(enumName, "_") + if len(enumName) > 0 && enumName[0] >= '0' && enumName[0] <= '9' { + enumName = "_" + enumName + } + *allEnums = append(*allEnums, EnumDefinition{ + Name: enumName, + Values: strValues, + }) + enumValueToNameMap[valueKey] = enumName + } + columnToEnumNameMap[columnKey] = enumName +} + +// processColumnForTests finds tests in a column definition (including nested fields) and processes them. 
+func processColumnForTests(nodeKey, modelName, columnName string, colMap map[string]interface{}, allEnums *[]EnumDefinition, enumValueToNameMap, columnToEnumNameMap map[string]string, columnToNotNullMap map[string]bool) { + // Helper to handle the actual test processing for a given column/field + processTests := func(currentColumnKey, currentColumnName string, tests []interface{}) { + for _, test := range tests { + // Handle not_null test (string format) + if testStr, ok := test.(string); ok && testStr == "not_null" { + columnToNotNullMap[currentColumnKey] = true + } + + // Handle tests in map format (e.g., accepted_values) + if testMap, ok := test.(map[string]interface{}); ok { + if accepted, ok := testMap["accepted_values"].(map[string]interface{}); ok { + if values, ok := accepted["values"].([]interface{}); ok && len(values) > 0 { + createOrLinkEnum(modelName, currentColumnName, currentColumnKey, values, allEnums, enumValueToNameMap, columnToEnumNameMap) + } + } + } + } } - columnsMap, ok := columnsValue.(map[string]interface{}) + // Case 1: Tests are directly on the column itself. + if tests, ok := colMap["tests"].([]interface{}); ok { + columnKey := fmt.Sprintf("%s.%s", nodeKey, columnName) + processTests(columnKey, columnName, tests) + } + + // Case 2: The column is a struct, and tests are on its fields. + if fields, ok := colMap["fields"].([]interface{}); ok { + for _, fieldData := range fields { + if fieldMap, ok := fieldData.(map[string]interface{}); ok { + fieldName := getStringFromMap(fieldMap, "name", "") + if fieldName == "" { + continue + } + if tests, ok := fieldMap["tests"].([]interface{}); ok { + // The unique key for a field is based on the field name. + columnKey := fmt.Sprintf("%s.%s", nodeKey, fieldName) + processTests(columnKey, fieldName, tests) + } + } + } + } +} + +// convertDbtMetricsToWrenMetrics converts dbt metrics from the semantic manifest into the Wren MDL format. 
+// It serves as the main entry point for metric conversion, orchestrating the creation of lookup tables +// and processing each metric definition. +func convertDbtMetricsToWrenMetrics(semanticData map[string]interface{}) []Metric { + var wrenMetrics []Metric + + // --- 1. Pre-process semantic models to build fast lookup maps --- + // These maps are essential for quickly finding the model a measure belongs to and its details. + measureToModelMap, measureDataLookup := buildMeasureLookups(semanticData) + + // --- 2. Iterate through each metric and convert it --- + metrics, ok := semanticData["metrics"].([]interface{}) if !ok { - return nil, fmt.Errorf("invalid columns format for model %s", nodeKey) + // If there's no 'metrics' array, there's nothing to do. + return wrenMetrics } - var wrenColumns []WrenColumn - for _, colValue := range columnsMap { - colMap, ok := colValue.(map[string]interface{}) + for _, m := range metrics { + metricMap, ok := m.(map[string]interface{}) + if !ok { + continue // Skip if the item is not a valid map. + } + + // --- 3. Extract basic metric information --- + metricName := getStringFromMap(metricMap, "name", "") + if metricName == "" { + continue // A metric must have a name. + } + + wrenMetric := Metric{ + Name: metricName, + DisplayName: getStringFromMap(metricMap, "label", metricName), + Description: getStringFromMap(metricMap, "description", ""), + } + + typeParams, _ := metricMap["type_params"].(map[string]interface{}) + + // --- 4. Determine the base model and time dimensions for the metric --- + baseModel := findBaseModelForMetric(typeParams, measureToModelMap) + if baseModel == "" { + pterm.Warning.Printf("Could not find a parent model for metric '%s'\n", metricName) + continue // Skip metric if we can't associate it with a model. + } + + wrenMetric.Models = []string{baseModel} + wrenMetric.Dimensions = findTimeDimensionsForModel(semanticData, baseModel) + + // --- 5. 
Build the specific aggregation expression based on the metric type --- + metricType := getStringFromMap(metricMap, "type", "") + wrenMetric.Aggregation = buildAggregationExpression(metricType, typeParams, measureDataLookup) + + // --- 6. Final validation before adding to the list --- + // A metric is only valid if it has a base model and a valid aggregation expression. + if wrenMetric.Aggregation != "" && len(wrenMetric.Models) > 0 { + wrenMetrics = append(wrenMetrics, wrenMetric) + } + } + + return wrenMetrics +} + +// buildMeasureLookups preprocesses the semantic models to create two essential maps: +// 1. measureToModelMap: Maps a measure's name to the name of the model it belongs to. +// 2. measureDataLookup: Maps a measure's name to its full data map for easy access to properties like `agg` and `expr`. +func buildMeasureLookups(semanticData map[string]interface{}) (map[string]string, map[string]map[string]interface{}) { + measureToModelMap := make(map[string]string) + measureDataLookup := make(map[string]map[string]interface{}) + + semanticModels, ok := semanticData["semantic_models"].([]interface{}) + if !ok { + return measureToModelMap, measureDataLookup + } + + for _, sm := range semanticModels { + smMap, ok := sm.(map[string]interface{}) if !ok { continue } - column := WrenColumn{ - Name: getStringFromMap(colMap, "name", ""), - Type: data_source.MapType(getStringFromMap(colMap, "type", "")), + modelName := getStringFromMap(smMap, "name", "") + if modelName == "" { + continue } - // Initialize properties map if needed - if column.Properties == nil { - column.Properties = make(map[string]string) + if measures, ok := smMap["measures"].([]interface{}); ok { + for _, m := range measures { + if measureMap, ok := m.(map[string]interface{}); ok { + measureName := getStringFromMap(measureMap, "name", "") + if measureName != "" { + measureToModelMap[measureName] = modelName + measureDataLookup[measureName] = measureMap + } + } + } } + } + return measureToModelMap, 
measureDataLookup +} - // Set description from manifest if available - if columnDescriptions != nil { - if description, exists := columnDescriptions[column.Name]; exists && description != "" { - column.Properties["description"] = description +// findBaseModelForMetric identifies the underlying base model for a given metric +// by looking at its "input_measures". +func findBaseModelForMetric(typeParams map[string]interface{}, measureToModelMap map[string]string) string { + inputMeasuresValue, ok := typeParams["input_measures"] + if !ok { + // Fallback for simple metrics that use "measure" instead of "input_measures" + if measureValue, ok := typeParams["measure"]; ok { + if measureMap, ok := measureValue.(map[string]interface{}); ok { + measureName := getStringFromMap(measureMap, "name", "") + if model, exists := measureToModelMap[measureName]; exists { + return model + } } } + return "" + } - // Set notNull based on comment or other indicators - // This is a basic implementation - you might need more sophisticated logic - if comment := getStringFromMap(colMap, "comment", ""); comment != "" { - column.Properties["comment"] = comment + inputMeasuresList, ok := inputMeasuresValue.([]interface{}) + if !ok || len(inputMeasuresList) == 0 { + return "" + } + + // Assume all measures for a given metric come from the same base model. + // We only need to find the first valid one. + for _, inputMeasure := range inputMeasuresList { + if imMap, ok := inputMeasure.(map[string]interface{}); ok { + imName := getStringFromMap(imMap, "name", "") + if model, exists := measureToModelMap[imName]; exists { + return model // Return the first model we find. + } } + } + return "" +} - wrenColumns = append(wrenColumns, column) +// findTimeDimensionsForModel scans the semantic models to find all columns +// marked with type "time" for a specific model name. 
+func findTimeDimensionsForModel(semanticData map[string]interface{}, baseModelName string) []string { + var timeDimensions []string + semanticModels, ok := semanticData["semantic_models"].([]interface{}) + if !ok { + return timeDimensions + } + + for _, sm := range semanticModels { + smMap, ok := sm.(map[string]interface{}) + if !ok { + continue + } + + if getStringFromMap(smMap, "name", "") == baseModelName { + if dims, ok := smMap["dimensions"].([]interface{}); ok { + for _, d := range dims { + if dimMap, ok := d.(map[string]interface{}); ok { + if getStringFromMap(dimMap, "type", "") == "time" { + timeDimensions = append(timeDimensions, getStringFromMap(dimMap, "name", "")) + } + } + } + } + break // Found the model, no need to continue looping. + } + } + return timeDimensions +} + +// buildAggregationExpression constructs the SQL aggregation string for a Wren metric +// based on its dbt type ('simple', 'ratio', or 'derived'). +func buildAggregationExpression(metricType string, typeParams map[string]interface{}, measureDataLookup map[string]map[string]interface{}) string { + switch metricType { + case "simple": + // A simple metric is a direct aggregation of one measure (e.g., SUM(revenue)). + if measure, ok := typeParams["measure"].(map[string]interface{}); ok { + measureName := getStringFromMap(measure, "name", "") + if measureData, ok := measureDataLookup[measureName]; ok { + agg := getStringFromMap(measureData, "agg", "sum") // Default to SUM + expr := getStringFromMap(measureData, "expr", measureName) // Fallback to measure name + return fmt.Sprintf("%s(%s)", strings.ToUpper(agg), expr) + } + } + case "ratio": + // A ratio metric is a division of two measures (e.g., SUM(profit) / SUM(revenue)). 
+ num, numOK := typeParams["numerator"].(map[string]interface{}) + den, denOK := typeParams["denominator"].(map[string]interface{}) + if !numOK || !denOK { + return "" + } + + numName := getStringFromMap(num, "name", "") + denName := getStringFromMap(den, "name", "") + numData, numDataOK := measureDataLookup[numName] + denData, denDataOK := measureDataLookup[denName] + + if numDataOK && denDataOK { + numAgg := strings.ToUpper(getStringFromMap(numData, "agg", "sum")) + denAgg := strings.ToUpper(getStringFromMap(denData, "agg", "sum")) + numExpr := getStringFromMap(numData, "expr", numName) + denExpr := getStringFromMap(denData, "expr", denName) + return fmt.Sprintf("(%s(%s)) / (%s(%s))", numAgg, numExpr, denAgg, denExpr) + } + case "derived": + // A derived metric uses a freeform SQL expression. + return getStringFromMap(typeParams, "expr", "") + } + return "" // Return empty string if no valid aggregation could be built. +} + +// extractDescriptionsFromManifest parses the manifest.json data to find the +// model-level description and a map of all column-level descriptions. 
+func extractDescriptionsFromManifest(manifestData map[string]interface{}, nodeKey string) (string, map[string]string) { + if manifestData == nil { + return "", nil + } + + nodes, ok := manifestData["nodes"].(map[string]interface{}) + if !ok { + return "", nil + } + + manifestNode, ok := nodes[nodeKey].(map[string]interface{}) + if !ok { + return "", nil + } + + // Extract the top-level model description + modelDescription := getStringFromMap(manifestNode, "description", "") + columnDescriptions := make(map[string]string) + + manifestColumns, ok := manifestNode["columns"].(map[string]interface{}) + if !ok { + // Return the model description even if columns aren't found + return modelDescription, nil + } + + // Iterate through columns to extract their descriptions + for colName, colData := range manifestColumns { + if colMap, ok := colData.(map[string]interface{}); ok { + if description := getStringFromMap(colMap, "description", ""); description != "" { + columnDescriptions[colName] = description + } + } + } + + return modelDescription, columnDescriptions +} + +// buildWrenColumn creates a single WrenColumn from its corresponding dbt column data map. +// It populates the name, type, and properties like enums, descriptions, and comments. 
+func buildWrenColumn(colMap map[string]interface{}, nodeKey string, dataSource DataSource, columnDescriptions map[string]string, columnToEnumNameMap map[string]string, columnToNotNullMap map[string]bool) WrenColumn { + columnName := getStringFromMap(colMap, "name", "") + columnKey := fmt.Sprintf("%s.%s", nodeKey, columnName) + + column := WrenColumn{ + Name: columnName, + DisplayName: getStringFromMap(getMapFromMap(colMap, "meta", nil), "label", ""), + Type: dataSource.MapType(getStringFromMap(colMap, "type", "")), + NotNull: columnToNotNullMap[columnKey], // Defaults to false if not found + } + + // Assign an enum if one was derived from dbt tests + if enumName, ok := columnToEnumNameMap[columnKey]; ok { + column.Enum = enumName + } + + // Use a temporary map to build the properties + properties := make(map[string]string) + if description, exists := columnDescriptions[column.Name]; exists && description != "" { + properties["description"] = description + } + if comment := getStringFromMap(colMap, "comment", ""); comment != "" { + properties["comment"] = comment + } + + // Assign the properties map only if it's not empty + if len(properties) > 0 { + column.Properties = properties + } + + return column +} + +// convertAndSortColumns extracts, sorts, and converts dbt columns to the WrenColumn format. 
+func convertAndSortColumns(nodeData map[string]interface{}, nodeKey string, dataSource DataSource, columnDescriptions map[string]string, columnToEnumNameMap map[string]string, columnToNotNullMap map[string]bool) ([]WrenColumn, error) { + columnsValue, exists := nodeData["columns"] + if !exists { + return nil, fmt.Errorf("no columns found for model %s", nodeKey) + } + + columnsMap, ok := columnsValue.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid columns format for model %s", nodeKey) } - // Sort columns by index if available - sort.Slice(wrenColumns, func(i, j int) bool { - // This is a simplified sort - you might want to use the index from dbt - return wrenColumns[i].Name < wrenColumns[j].Name + // Convert map to a slice for sorting + var columnsData []map[string]interface{} + for _, colValue := range columnsMap { + if colMap, ok := colValue.(map[string]interface{}); ok { + columnsData = append(columnsData, colMap) + } + } + + // Sort columns by the 'index' field, falling back to name + sort.Slice(columnsData, func(i, j int) bool { + indexI, okI := columnsData[i]["index"].(float64) + indexJ, okJ := columnsData[j]["index"].(float64) + if okI && okJ { + return indexI < indexJ + } + return getStringFromMap(columnsData[i], "name", "") < getStringFromMap(columnsData[j], "name", "") }) + // Build the final slice of WrenColumns + var wrenColumns []WrenColumn + for _, colMap := range columnsData { + if getStringFromMap(colMap, "name", "") == "" { + continue + } + column := buildWrenColumn(colMap, nodeKey, dataSource, columnDescriptions, columnToEnumNameMap, columnToNotNullMap) + wrenColumns = append(wrenColumns, column) + } + + return wrenColumns, nil +} + +// convertDbtNodeToWrenModel converts a single dbt node to a Wren model. +// This function now orchestrates calls to helpers to perform the conversion. 
+func convertDbtNodeToWrenModel(nodeKey string, nodeData map[string]interface{}, dataSource DataSource, manifestData map[string]interface{}, columnToEnumNameMap map[string]string, columnToNotNullMap map[string]bool, modelToPrimaryKeyMap map[string]string) (*WrenModel, error) { + modelName := getModelNameFromNodeKey(nodeKey) + if modelName == "" { + return nil, fmt.Errorf("invalid node key format: %s", nodeKey) + } + + // --- 1. Extract Metadata and Table Reference --- + metadataValue, exists := nodeData["metadata"] + if !exists { + return nil, fmt.Errorf("no metadata found for model %s", nodeKey) + } + metadata, ok := metadataValue.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid metadata format for model %s", nodeKey) + } + tableRef := TableReference{ + Table: getStringFromMap(metadata, "name", modelName), + Catalog: getStringFromMap(metadata, "database", ""), + Schema: getStringFromMap(metadata, "schema", ""), + } + + // --- 2. Extract Descriptions from Manifest --- + modelDescription, columnDescriptions := extractDescriptionsFromManifest(manifestData, nodeKey) + + // --- 3. Convert and Sort Columns --- + wrenColumns, err := convertAndSortColumns(nodeData, nodeKey, dataSource, columnDescriptions, columnToEnumNameMap, columnToNotNullMap) + if err != nil { + return nil, err + } + + // --- 4. 
Assemble the Final WrenModel --- model := &WrenModel{ Name: modelName, TableReference: tableRef, Columns: wrenColumns, } - // Set model description from manifest if available + // Set primary key if available + if pk, ok := modelToPrimaryKeyMap[modelName]; ok { + model.PrimaryKey = pk + } + + // Set model description if available if modelDescription != "" { - if model.Properties == nil { - model.Properties = make(map[string]string) - } - model.Properties["description"] = modelDescription + model.Properties = map[string]string{"description": modelDescription} } return model, nil @@ -457,6 +1053,9 @@ func convertDbtNodeToWrenModel(nodeKey string, nodeData map[string]interface{}, // getStringFromMap safely extracts a string value from a map func getStringFromMap(m map[string]interface{}, key, defaultValue string) string { + if m == nil { + return defaultValue + } if value, exists := m[key]; exists { if str, ok := value.(string); ok { return str @@ -464,3 +1063,38 @@ func getStringFromMap(m map[string]interface{}, key, defaultValue string) string } return defaultValue } + +// getMapFromMap safely extracts a map value from a map +func getMapFromMap(m map[string]interface{}, key string, defaultValue map[string]interface{}) map[string]interface{} { + if value, exists := m[key]; exists { + if str, ok := value.(map[string]interface{}); ok { + return str + } + } + return defaultValue +} + +// getModelNameFromNodeKey extracts the model name from a dbt node key. +// e.g., "model.jaffle_shop.customers" -> "customers" +func getModelNameFromNodeKey(nodeKey string) string { + parts := strings.Split(nodeKey, ".") + if len(parts) > 0 { + return parts[len(parts)-1] + } + return "" +} + +var refRegex = regexp.MustCompile(`ref\s*\(\s*['"]([^'"]+)['"]\s*\)`) + +// parseRef extracts the model name from a dbt ref string. +// e.g., "ref('stg_orders')" +func parseRef(refStr string) string { + // Use the precompiled regex to find matches. 
+ matches := refRegex.FindStringSubmatch(refStr) + if len(matches) > 1 { + // The first submatch (index 1) is the captured group, + // which is the model name we want. + return matches[1] + } + return "" +} diff --git a/wren-launcher/commands/dbt/data_source.go b/wren-launcher/commands/dbt/data_source.go index a12aeac538..f274b4224a 100644 --- a/wren-launcher/commands/dbt/data_source.go +++ b/wren-launcher/commands/dbt/data_source.go @@ -1,7 +1,10 @@ package dbt import ( + "encoding/base64" + "encoding/json" "fmt" + "os" "path/filepath" "strconv" "strings" @@ -81,6 +84,9 @@ func convertConnectionToDataSource(conn DbtConnection, dbtHomePath, profileName, return convertToLocalFileDataSource(conn, dbtHomePath) case "mysql": return convertToMysqlDataSource(conn) + case "bigquery": + // Pass the dbtHomePath to the BigQuery converter + return convertToBigQueryDataSource(conn, dbtHomePath) default: // For unsupported database types, we can choose to ignore or return error // Here we choose to return nil and log a warning @@ -178,6 +184,97 @@ func convertToMysqlDataSource(conn DbtConnection) (*WrenMysqlDataSource, error) return ds, nil } +// convertToBigQueryDataSource converts to BigQuery data source +func convertToBigQueryDataSource(conn DbtConnection, dbtHomePath string) (*WrenBigQueryDataSource, error) { + method := strings.ToLower(strings.TrimSpace(conn.Method)) + var credentials string + + // Helper: validate JSON and base64 encode + encodeJSON := func(b []byte) (string, error) { + var js map[string]interface{} + if err := json.Unmarshal(b, &js); err != nil { + return "", fmt.Errorf("service account JSON is invalid: %w", err) + } + return base64.StdEncoding.EncodeToString(b), nil + } + + switch method { + case "service-account-json": + // Extract inline JSON from Additional["keyfile_json"] + var keyfileJSON string + if kfj, exists := conn.Additional["keyfile_json"]; exists { + if kfjStr, ok := kfj.(string); ok { + keyfileJSON = kfjStr + } + } + if keyfileJSON == "" { 
+ return nil, fmt.Errorf("bigquery: method 'service-account-json' requires 'keyfile_json'") + } + enc, err := encodeJSON([]byte(keyfileJSON)) + if err != nil { + return nil, err + } + credentials = enc + case "service-account", "": + // Prefer structured field; fall back to Additional["keyfile"] + keyfilePath := strings.TrimSpace(conn.Keyfile) + if keyfilePath == "" { + if kf, ok := conn.Additional["keyfile"]; ok { + if kfStr, ok := kf.(string); ok { + keyfilePath = strings.TrimSpace(kfStr) + } + } + } + if keyfilePath == "" { + // If method was omitted (""), try as a fallback to inline json + if kfj, ok := conn.Additional["keyfile_json"]; ok { + if kfjStr, ok := kfj.(string); ok && kfjStr != "" { + enc, err := encodeJSON([]byte(kfjStr)) + if err != nil { + return nil, err + } + credentials = enc + } + } + if credentials == "" { + return nil, fmt.Errorf("bigquery: method 'service-account' requires 'keyfile' path") + } + } else { + // If keyfile path is not absolute, join it + // with the dbt project's home directory path. 
+ resolvedKeyfilePath := keyfilePath + if !filepath.IsAbs(keyfilePath) && dbtHomePath != "" { + resolvedKeyfilePath = filepath.Join(dbtHomePath, keyfilePath) + } + + cleanPath := filepath.Clean(resolvedKeyfilePath) + b, err := os.ReadFile(cleanPath) + if err != nil { + // Update the error message to show the path that was attempted + return nil, fmt.Errorf("failed to read keyfile '%s': %w", cleanPath, err) + } + enc, err := encodeJSON(b) + if err != nil { + return nil, err + } + credentials = enc + } + case "oauth": + pterm.Warning.Println("bigquery: oauth auth method is not supported; skipping data source") + return nil, nil + default: + pterm.Warning.Printf("bigquery: unsupported auth method '%s'; supported: service-account, service-account-json\n", method) + return nil, nil + } + + ds := &WrenBigQueryDataSource{ + Project: conn.Project, + Dataset: conn.Dataset, + Credentials: credentials, + } + return ds, nil +} + type WrenLocalFileDataSource struct { Url string `json:"url"` Format string `json:"format"` @@ -341,6 +438,57 @@ func (ds *WrenMysqlDataSource) MapType(sourceType string) string { } } +type WrenBigQueryDataSource struct { + Project string `json:"project_id"` + Dataset string `json:"dataset_id"` + Credentials string `json:"credentials"` +} + +// GetType implements DataSource interface +func (ds *WrenBigQueryDataSource) GetType() string { + return "bigquery" +} + +// Validate implements DataSource interface +func (ds *WrenBigQueryDataSource) Validate() error { + if strings.TrimSpace(ds.Project) == "" { + return fmt.Errorf("project_id cannot be empty") + } + if strings.TrimSpace(ds.Dataset) == "" { + return fmt.Errorf("dataset_id cannot be empty") + } + if strings.TrimSpace(ds.Credentials) == "" { + return fmt.Errorf("credentials cannot be empty") + } + return nil +} + +// MapType implements DataSource interface +func (ds *WrenBigQueryDataSource) MapType(sourceType string) string { + switch strings.ToUpper(sourceType) { + case "INT64", "INTEGER": + return 
integerType + case "FLOAT64", "FLOAT": + return doubleType + case "STRING": + return varcharType + case "BOOL", "BOOLEAN": + return booleanType + case "DATE": + return dateType + case "TIMESTAMP", "DATETIME": + return timestampType + case "NUMERIC", "DECIMAL", "BIGNUMERIC": + return doubleType + case "BYTES": + return varcharType + case "JSON": + return varcharType + default: + return strings.ToLower(sourceType) + } +} + // GetActiveDataSources gets active data sources based on specified profile and target // If profileName is empty, it will use the first found profile // If targetName is empty, it will use the profile's default target diff --git a/wren-launcher/commands/dbt/data_source_test.go b/wren-launcher/commands/dbt/data_source_test.go index 4bdd3e7c1d..09fdd5d7ce 100644 --- a/wren-launcher/commands/dbt/data_source_test.go +++ b/wren-launcher/commands/dbt/data_source_test.go @@ -1,6 +1,9 @@ package dbt import ( + "encoding/base64" + "os" + "path/filepath" "testing" ) @@ -81,8 +84,8 @@ func TestFromDbtProfiles_Postgres(t *testing.T) { validatePostgresDataSource(t, ds, "test_db") } -func TestFromDbtProfiles_PostgresWithDbName(t *testing.T) { - // Test PostgreSQL connection conversion with dbname field (PostgreSQL specific) +func TestFromDbtProfiles_PostgresWithDefaultPort(t *testing.T) { + // Test PostgreSQL connection conversion when port is not specified profiles := &DbtProfiles{ Profiles: map[string]DbtProfile{ "test_profile": { @@ -172,7 +175,7 @@ func TestFromDbtProfiles_LocalFile(t *testing.T) { t.Fatalf("Expected WrenLocalFileDataSource, got %T", dataSources[0]) } - if ds.Url != "/abs_path" { + if filepath.ToSlash(ds.Url) != "/abs_path" { t.Errorf("Expected url '/abs_path', got '%s'", ds.Url) } if ds.Format != duckdbType { @@ -225,28 +228,256 @@ func TestFromDbtProfiles_NilProfiles(t *testing.T) { } } -// Validator interface for data sources -type Validator interface { - Validate() error +func TestValidateAllDataSources(t *testing.T) { + // Test valid 
profiles + validProfiles := &DbtProfiles{ + Profiles: map[string]DbtProfile{ + "valid_project": { + Target: "dev", + Outputs: map[string]DbtConnection{ + "dev": { + Type: "postgres", + Host: "localhost", + Port: 5432, + Database: "test_db", + User: "user", + }, + }, + }, + }, + } + + err := ValidateAllDataSources(validProfiles) + if err != nil { + t.Errorf("ValidateAllDataSources failed for valid profiles: %v", err) + } + + // Test invalid profiles + invalidProfiles := &DbtProfiles{ + Profiles: map[string]DbtProfile{ + "invalid_project": { + Target: "dev", + Outputs: map[string]DbtConnection{ + "dev": { + Type: "postgres", + Host: "localhost", + // Missing required fields + }, + }, + }, + }, + } + + err = ValidateAllDataSources(invalidProfiles) + if err == nil { + t.Error("ValidateAllDataSources should fail for invalid profiles") + } } -// Helper function to test data source validation -func testDataSourceValidation(t *testing.T, testName string, validDS Validator, invalidDSCases []struct { - name string - ds Validator -}) { - t.Helper() +func TestFromDbtProfiles_BigQuery(t *testing.T) { + tempDir, err := os.MkdirTemp("", "test-dbt-home") + if err != nil { + t.Fatal(err) + } + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Logf("Failed to remove temporary directory %s: %v", tempDir, err) + } + }() + + t.Run("service-account-json", func(t *testing.T) { + keyfileContent := `{"type": "service_account", "project_id": "test-project", "private_key_id": "test-key-id", "private_key": "test-private-key", "client_email": "test-client-email", "client_id": "test-client-id", "auth_uri": "test-auth-uri", "token_uri": "test-token-uri", "auth_provider_x509_cert_url": "test-cert-url", "client_x509_cert_url": "test-client-cert-url"}` // #nosec G101 + profiles := &DbtProfiles{ + Profiles: map[string]DbtProfile{ + "test_profile": { + Target: "dev", + Outputs: map[string]DbtConnection{ + "dev": { + Type: "bigquery", + Method: "service-account-json", + Project: 
"test-project", + Dataset: "test-dataset", + Additional: map[string]interface{}{ + "keyfile_json": keyfileContent, + }, + }, + }, + }, + }, + } + + dataSources, err := GetActiveDataSources(profiles, "", "test_profile", "dev") + if err != nil { + t.Fatalf("GetActiveDataSources failed: %v", err) + } + + if len(dataSources) != 1 { + t.Fatalf("Expected 1 data source, got %d", len(dataSources)) + } + + ds, ok := dataSources[0].(*WrenBigQueryDataSource) + if !ok { + t.Fatalf("Expected WrenBigQueryDataSource, got %T", dataSources[0]) + } + + if ds.Project != "test-project" { + t.Errorf("Expected project 'test-project', got '%s'", ds.Project) + } + + if ds.Dataset != "test-dataset" { + t.Errorf("Expected dataset 'test-dataset', got '%s'", ds.Dataset) + } + + encodedContent, _ := base64.StdEncoding.DecodeString(ds.Credentials) + if string(encodedContent) != keyfileContent { + t.Errorf("Expected base64-encoded keyfile JSON content, got different content") + } + }) + + t.Run("service-account-with-absolute-keyfile-path", func(t *testing.T) { + keyfileContent := `{"type": "service_account"}` // #nosec G101 + keyfilePath := filepath.Join(tempDir, "keyfile.json") + if err := os.WriteFile(keyfilePath, []byte(keyfileContent), 0600); err != nil { + t.Fatal(err) + } + + profiles := &DbtProfiles{ + Profiles: map[string]DbtProfile{ + "test_profile": { + Target: "dev", + Outputs: map[string]DbtConnection{ + "dev": { + Type: "bigquery", + Method: "service-account", + Project: "test-project", + Dataset: "test-dataset", + Keyfile: keyfilePath, + }, + }, + }, + }, + } + + dataSources, err := GetActiveDataSources(profiles, "", "test_profile", "dev") + if err != nil { + t.Fatalf("GetActiveDataSources failed: %v", err) + } + + if len(dataSources) != 1 { + t.Fatalf("Expected 1 data source, got %d", len(dataSources)) + } + + ds, ok := dataSources[0].(*WrenBigQueryDataSource) + if !ok { + t.Fatalf("Expected WrenBigQueryDataSource, got %T", dataSources[0]) + } + + encodedContent, _ := 
base64.StdEncoding.DecodeString(ds.Credentials) + if string(encodedContent) != keyfileContent { + t.Errorf("Expected base64-encoded keyfile content, got different content") + } + }) + + t.Run("service-account-with-relative-keyfile-path", func(t *testing.T) { + dbtHomePath := tempDir + keyfileContent := `{"type": "service_account"}` // #nosec G101 + keyfilePath := "keys/keyfile.json" + fullKeyfilePath := filepath.Join(dbtHomePath, keyfilePath) + + if err := os.MkdirAll(filepath.Dir(fullKeyfilePath), 0750); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(fullKeyfilePath, []byte(keyfileContent), 0600); err != nil { + t.Fatal(err) + } + + profiles := &DbtProfiles{ + Profiles: map[string]DbtProfile{ + "test_profile": { + Target: "dev", + Outputs: map[string]DbtConnection{ + "dev": { + Type: "bigquery", + Method: "service-account", + Project: "test-project", + Dataset: "test-dataset", + Keyfile: keyfilePath, + }, + }, + }, + }, + } - t.Run(testName+" valid", func(t *testing.T) { - if err := validDS.Validate(); err != nil { - t.Errorf("Valid data source validation failed: %v", err) + dataSources, err := GetActiveDataSources(profiles, dbtHomePath, "test_profile", "dev") + if err != nil { + t.Fatalf("GetActiveDataSources failed: %v", err) + } + + if len(dataSources) != 1 { + t.Fatalf("Expected 1 data source, got %d", len(dataSources)) + } + + ds, ok := dataSources[0].(*WrenBigQueryDataSource) + if !ok { + t.Fatalf("Expected WrenBigQueryDataSource, got %T", dataSources[0]) + } + + encodedContent, _ := base64.StdEncoding.DecodeString(ds.Credentials) + if string(encodedContent) != keyfileContent { + t.Errorf("Expected base64-encoded keyfile content, got different content") } }) +} - for _, tt := range invalidDSCases { - t.Run(testName+" "+tt.name, func(t *testing.T) { - if err := tt.ds.Validate(); err == nil { - t.Errorf("Expected validation error for %s, but got none", tt.name) +func TestBigQueryDataSourceValidation(t *testing.T) { + tests := []struct { + name string 
+ ds *WrenBigQueryDataSource + wantErr bool + }{ + { + name: "valid", + ds: &WrenBigQueryDataSource{ + Project: "test-project", + Dataset: "test-dataset", + Credentials: "dGVzdC1jcmVkZW50aWFscw==", // "test-credentials" + }, + wantErr: false, + }, + { + name: "invalid - missing project", + ds: &WrenBigQueryDataSource{ + Project: "", + Dataset: "test-dataset", + Credentials: "dGVzdC1jcmVkZW50aWFscw==", + }, + wantErr: true, + }, + { + name: "invalid - missing dataset", + ds: &WrenBigQueryDataSource{ + Project: "test-project", + Dataset: "", + Credentials: "dGVzdC1jcmVkZW50aWFscw==", + }, + wantErr: true, + }, + { + name: "invalid - missing credentials", + ds: &WrenBigQueryDataSource{ + Project: "test-project", + Dataset: "test-dataset", + Credentials: "", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.ds.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) } }) } @@ -474,48 +705,57 @@ func TestGetDataSourceByType(t *testing.T) { } } -func TestValidateAllDataSources(t *testing.T) { - // Test valid profiles - validProfiles := &DbtProfiles{ - Profiles: map[string]DbtProfile{ - "valid_project": { - Target: "dev", - Outputs: map[string]DbtConnection{ - "dev": { - Type: "postgres", - Host: "localhost", - Port: 5432, - Database: "test_db", - User: "user", - }, - }, - }, +func TestMapType(t *testing.T) { + tests := []struct { + name string + dataSource DataSource + sourceType string + want string + }{ + { + name: "BigQuery INT64 to integer", + dataSource: &WrenBigQueryDataSource{}, + sourceType: "INT64", + want: "integer", }, - } - - err := ValidateAllDataSources(validProfiles) - if err != nil { - t.Errorf("ValidateAllDataSources failed for valid profiles: %v", err) - } - - // Test invalid profiles - invalidProfiles := &DbtProfiles{ - Profiles: map[string]DbtProfile{ - "invalid_project": { - Target: "dev", - Outputs: map[string]DbtConnection{ - "dev": { 
- Type: "postgres", - Host: "localhost", - // Missing required fields - }, - }, - }, + { + name: "BigQuery STRING to varchar", + dataSource: &WrenBigQueryDataSource{}, + sourceType: "STRING", + want: "varchar", + }, + { + name: "LocalFile INTEGER to integer", + dataSource: &WrenLocalFileDataSource{}, + sourceType: "INTEGER", + want: "integer", + }, + { + name: "LocalFile VARCHAR to varchar", + dataSource: &WrenLocalFileDataSource{}, + sourceType: "VARCHAR", + want: "varchar", + }, + { + name: "DefaultDataSource int to integer", + dataSource: &DefaultDataSource{}, + sourceType: "int", + want: "integer", + }, + { + name: "PostgresDataSource (no mapping)", + dataSource: &WrenPostgresDataSource{}, + sourceType: "unknown_type", + want: "unknown_type", }, } - err = ValidateAllDataSources(invalidProfiles) - if err == nil { - t.Error("ValidateAllDataSources should fail for invalid profiles") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.dataSource.MapType(tt.sourceType) + if got != tt.want { + t.Errorf("MapType(%s) = %s; want %s", tt.sourceType, got, tt.want) + } + }) } } diff --git a/wren-launcher/commands/dbt/profiles.go b/wren-launcher/commands/dbt/profiles.go index e87822e70b..62ca16746a 100644 --- a/wren-launcher/commands/dbt/profiles.go +++ b/wren-launcher/commands/dbt/profiles.go @@ -26,6 +26,7 @@ type DbtConnection struct { Project string `yaml:"project,omitempty" json:"project,omitempty"` // BigQuery Dataset string `yaml:"dataset,omitempty" json:"dataset,omitempty"` // BigQuery Keyfile string `yaml:"keyfile,omitempty" json:"keyfile,omitempty"` // BigQuery + Method string `yaml:"method,omitempty" json:"method,omitempty"` // BigQuery Account string `yaml:"account,omitempty" json:"account,omitempty"` // Snowflake Warehouse string `yaml:"warehouse,omitempty" json:"warehouse,omitempty"` // Snowflake Role string `yaml:"role,omitempty" json:"role,omitempty"` // Snowflake diff --git a/wren-launcher/commands/dbt/profiles_analyzer.go 
b/wren-launcher/commands/dbt/profiles_analyzer.go index 1a0698ba83..f120f4dadf 100644 --- a/wren-launcher/commands/dbt/profiles_analyzer.go +++ b/wren-launcher/commands/dbt/profiles_analyzer.go @@ -134,6 +134,7 @@ func parseConnection(connectionMap map[string]interface{}) (*DbtConnection, erro connection.Project = getString("project") connection.Dataset = getString("dataset") connection.Keyfile = getString("keyfile") + connection.Method = getString("method") connection.Account = getString("account") connection.Warehouse = getString("warehouse") connection.Role = getString("role") @@ -147,7 +148,7 @@ func parseConnection(connectionMap map[string]interface{}) (*DbtConnection, erro knownFields := map[string]bool{ "type": true, "host": true, "port": true, "user": true, "password": true, "database": true, "dbname": true, "schema": true, "project": true, "dataset": true, - "keyfile": true, "account": true, "warehouse": true, "role": true, + "keyfile": true, "method": true, "account": true, "warehouse": true, "role": true, "keepalive": true, "search_path": true, "sslmode": true, "path": true, "ssl_disable": true, } diff --git a/wren-launcher/commands/dbt/wren_mdl.go b/wren-launcher/commands/dbt/wren_mdl.go index cc0b44021c..60c084f9be 100644 --- a/wren-launcher/commands/dbt/wren_mdl.go +++ b/wren-launcher/commands/dbt/wren_mdl.go @@ -2,12 +2,20 @@ package dbt // WrenMDLManifest represents the complete Wren MDL structure type WrenMDLManifest struct { - Catalog string `json:"catalog"` - Schema string `json:"schema"` - Models []WrenModel `json:"models"` - Relationships []Relationship `json:"relationships"` - Views []View `json:"views"` - DataSources string `json:"dataSources,omitempty"` + Catalog string `json:"catalog"` + Schema string `json:"schema"` + EnumDefinitions []EnumDefinition `json:"enumDefinitions,omitempty"` + Models []WrenModel `json:"models"` + Relationships []Relationship `json:"relationships"` + Metrics []Metric `json:"metrics,omitempty"` + Views []View 
`json:"views"` + DataSources string `json:"dataSources,omitempty"` +} + +// EnumDefinition represents a named list of values that can be used by columns. +type EnumDefinition struct { + Name string `json:"name"` + Values []string `json:"values"` } // WrenModel represents a model in the Wren MDL format @@ -31,7 +39,9 @@ type TableReference struct { // WrenColumn represents a column in the Wren MDL format type WrenColumn struct { Name string `json:"name"` + DisplayName string `json:"displayName,omitempty"` Type string `json:"type"` + Enum string `json:"enum,omitempty"` Relationship string `json:"relationship,omitempty"` IsCalculated bool `json:"isCalculated,omitempty"` NotNull bool `json:"notNull,omitempty"` @@ -48,6 +58,16 @@ type Relationship struct { Properties map[string]string `json:"properties,omitempty"` } +// Metric defines a business-level calculation in Wren MDL. +type Metric struct { + Name string `json:"name"` + Models []string `json:"models"` + Dimensions []string `json:"dimensions"` + Aggregation string `json:"aggregation"` + DisplayName string `json:"displayName"` + Description string `json:"description,omitempty"` +} + // View represents a view in the Wren MDL format type View struct { Name string `json:"name"` diff --git a/wren-launcher/commands/launch.go b/wren-launcher/commands/launch.go index 4edb0f6d4c..01fe9f13ed 100644 --- a/wren-launcher/commands/launch.go +++ b/wren-launcher/commands/launch.go @@ -187,6 +187,18 @@ func askForDbtTarget() (string, error) { return result, nil } +func askForIncludeStagingModels() (bool, error) { + prompt := promptui.Select{ + Label: "Include staging models (stg_*, staging_*)?", + Items: []string{"No", "Yes"}, + } + _, result, err := prompt.Run() + if err != nil { + return false, err + } + return result == "Yes", nil +} + func Launch() { // recover from panic defer func() { @@ -521,8 +533,15 @@ func processDbtProject(projectDir string) (string, error) { return "", err } - // Use the core conversion function from 
dbt package - result, err := DbtConvertProject(dbtProjectPath, targetDir, profileName, target, true) + // Ask the user whether to include staging models + includeStagingModels, err := askForIncludeStagingModels() + if err != nil { + pterm.Warning.Println("Could not get staging model preference, defaulting to 'No'.") + includeStagingModels = false + } + + // Use the core conversion function from dbt package, passing the user's choice + result, err := DbtConvertProject(dbtProjectPath, targetDir, profileName, target, true, includeStagingModels) if err != nil { return "", fmt.Errorf("failed to convert dbt project: %w", err) } diff --git a/wren-launcher/utils/docker.go b/wren-launcher/utils/docker.go index 280ee28791..f49d67c480 100644 --- a/wren-launcher/utils/docker.go +++ b/wren-launcher/utils/docker.go @@ -23,7 +23,7 @@ import ( const ( // please change the version when the version is updated - WREN_PRODUCT_VERSION string = "0.28.0" + WREN_PRODUCT_VERSION string = "0.28.0" DOCKER_COMPOSE_YAML_URL string = "https://raw.githubusercontent.com/Canner/WrenAI/" + WREN_PRODUCT_VERSION + "/docker/docker-compose.yaml" DOCKER_COMPOSE_ENV_URL string = "https://raw.githubusercontent.com/Canner/WrenAI/" + WREN_PRODUCT_VERSION + "/docker/.env.example" AI_SERVICE_CONFIG_URL string = "https://raw.githubusercontent.com/Canner/WrenAI/" + WREN_PRODUCT_VERSION + "/docker/config.example.yaml"