diff --git a/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go b/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go index 54d29d160582..226cca46039d 100644 --- a/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go +++ b/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go @@ -482,30 +482,33 @@ func handleTextResponse(resp *TextResponse) map[string]any { } func handleSchemaResponse(resp *SchemaResponse) map[string]any { + res := make(map[string]any) if resp.Query != nil { - return map[string]any{"Question": resp.Query.Question} + res["Question"] = resp.Query.Question } if resp.Result != nil { var formattedSources []map[string]any for _, ds := range resp.Result.Datasources { formattedSources = append(formattedSources, formatDatasourceAsDict(&ds)) } - return map[string]any{"Schema Resolved": formattedSources} + res["Schema Resolved"] = formattedSources } - return nil + if len(res) == 0 { + return nil + } + return res } func handleDataResponse(resp *DataResponse, maxRows int) map[string]any { + res := make(map[string]any) if resp.Query != nil { - return map[string]any{ - "Retrieval Query": map[string]any{ - "Query Name": resp.Query.Name, - "Question": resp.Query.Question, - }, + res["Retrieval Query"] = map[string]any{ + "Query Name": resp.Query.Name, + "Question": resp.Query.Question, } } if resp.GeneratedSQL != "" { - return map[string]any{"SQL Generated": resp.GeneratedSQL} + res["SQL Generated"] = resp.GeneratedSQL } if resp.Result != nil { var headers []string @@ -533,15 +536,16 @@ func handleDataResponse(resp *DataResponse, maxRows int) map[string]any { summary = fmt.Sprintf("Showing the first %d of %d total rows.", numRowsToDisplay, totalRows) } - return map[string]any{ - "Data Retrieved": map[string]any{ - "headers": headers, - "rows": compactRows, - "summary": summary, - }, + res["Data Retrieved"] = map[string]any{ + "headers": headers, + "rows": compactRows, + "summary": summary, } } - return nil + if len(res) == 0 { + return nil + } + return res } func handleError(resp *ErrorResponse) map[string]any { @@ -557,9 +561,17 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s if newMessage == nil { return messages } - if len(messages) > 0 { - if _, ok := messages[len(messages)-1]["Data Retrieved"]; ok { - messages = messages[:len(messages)-1] + + if _, hasData := newMessage["Data Retrieved"]; hasData { + // Only keep the last data result while preserving SQL and other metadata. + for i := len(messages) - 1; i >= 0; i-- { + if _, ok := messages[i]["Data Retrieved"]; ok { + delete(messages[i], "Data Retrieved") + if len(messages[i]) == 0 { + messages = append(messages[:i], messages[i+1:]...) + } + break + } } } return append(messages, newMessage) diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index 30307296f1e4..575c101b7580 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -173,7 +173,7 @@ func TestBigQueryToolEndpoints(t *testing.T) { datasetInfoWant := "\"Location\":\"US\",\"DefaultTableExpiration\":0,\"Labels\":null,\"Access\":" tableInfoWant := "{\"Name\":\"\",\"Location\":\"US\",\"Description\":\"\",\"Schema\":[{\"Name\":\"id\"" ddlWant := `"Query executed successfully and returned no content."` - dataInsightsWant := `(?s)Schema Resolved.*Retrieval Query.*SQL Generated.*Answer` + dataInsightsWant := `(?s)(Schema Resolved.*)?(Retrieval Query.*)?SQL Generated.*Data Retrieved.*Answer` // Partial message; the full error message is too long. mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing GCP request: failed to insert dry run job: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"f0_\":1}"}]}}` @@ -2393,7 +2393,7 @@ func runBigQueryConversationalAnalyticsInvokeTest(t *testing.T, datasetName, tab `{"user_query_with_context": "What are the names in the table?", "table_references": %q}`, tableRefsJSON, ))), - want: "[{\"f0_\":1}]", + want: dataInsightsWant, isErr: false, }, {