Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 141 additions & 23 deletions internal/tools/bigquery/bigquerysql/bigquerysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ package bigquerysql
import (
"context"
"fmt"
"reflect"
"strings"

bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)

Expand All @@ -45,6 +47,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T

type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQueryRestService() *bigqueryrestapi.Service
}

// validate compatible sources are still compatible
Expand Down Expand Up @@ -101,6 +104,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
Expand All @@ -117,15 +121,17 @@ type Tool struct {
Parameters tools.Parameters `yaml:"parameters"`
TemplateParameters tools.Parameters `yaml:"templateParameters"`
AllParams tools.Parameters `yaml:"allParams"`

Client *bigqueryapi.Client
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
Statement string
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
manifest tools.Manifest
mcpManifest tools.McpManifest
}

func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
namedArgs := make([]bigqueryapi.QueryParameter, 0, len(params))
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))

paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
Expand All @@ -136,14 +142,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
name := p.GetName()
value := paramsMap[name]

// BigQuery's QueryParameter only accepts typed slices as input
// This checks if the param is an array.
// If yes, convert []any to typed slice (e.g []string, []int)
switch arrayParam := p.(type) {
case *tools.ArrayParameter:
// This block for converting []any to typed slices is still necessary and correct.
if arrayParam, ok := p.(*tools.ArrayParameter); ok {
arrayParamValue, ok := value.([]any)
if !ok {
return nil, fmt.Errorf("unable to convert parameter `%s` to []any %w", name, err)
return nil, fmt.Errorf("unable to convert parameter `%s` to []any", name)
}
itemType := arrayParam.GetItems().GetType()
var err error
Expand All @@ -153,22 +156,69 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
}
}

if strings.Contains(t.Statement, "@"+name) {
namedArgs = append(namedArgs, bigqueryapi.QueryParameter{
Name: name,
Value: value,
})
// Determine if the parameter is named or positional for the high-level client.
var paramNameForHighLevel string
if strings.Contains(newStatement, "@"+name) {
paramNameForHighLevel = name
}

// 1. Create the high-level parameter for the final query execution.
highLevelParams = append(highLevelParams, bigqueryapi.QueryParameter{
Name: paramNameForHighLevel,
Value: value,
})

// 2. Create the low-level parameter for the dry run, using the defined type from `p`.
lowLevelParam := &bigqueryrestapi.QueryParameter{
Name: paramNameForHighLevel,
ParameterType: &bigqueryrestapi.QueryParameterType{},
ParameterValue: &bigqueryrestapi.QueryParameterValue{},
}

if arrayParam, ok := p.(*tools.ArrayParameter); ok {
// Handle array types based on their defined item type.
lowLevelParam.ParameterType.Type = "ARRAY"
itemType, err := BQTypeStringFromToolType(arrayParam.GetItems().GetType())
if err != nil {
return nil, err
}
lowLevelParam.ParameterType.ArrayType = &bigqueryrestapi.QueryParameterType{Type: itemType}

// Build the array values.
sliceVal := reflect.ValueOf(value)
arrayValues := make([]*bigqueryrestapi.QueryParameterValue, sliceVal.Len())
for i := 0; i < sliceVal.Len(); i++ {
arrayValues[i] = &bigqueryrestapi.QueryParameterValue{
Value: fmt.Sprintf("%v", sliceVal.Index(i).Interface()),
}
}
lowLevelParam.ParameterValue.ArrayValues = arrayValues
} else {
namedArgs = append(namedArgs, bigqueryapi.QueryParameter{
Value: value,
})
// Handle scalar types based on their defined type.
bqType, err := BQTypeStringFromToolType(p.GetType())
if err != nil {
return nil, err
}
lowLevelParam.ParameterType.Type = bqType
lowLevelParam.ParameterValue.Value = fmt.Sprintf("%v", value)
}
lowLevelParams = append(lowLevelParams, lowLevelParam)
}

query := t.Client.Query(newStatement)
query.Parameters = namedArgs
query.Parameters = highLevelParams
query.Location = t.Client.Location

dryRunJob, err := dryRunQuery(ctx, t.RestService, t.Client.Project(), t.Client.Location, newStatement, lowLevelParams, query.ConnectionProperties)
if err != nil {
// This is a fallback check in case the switch logic was bypassed.
return nil, fmt.Errorf("final query validation failed: %w", err)
}
statementType := dryRunJob.Statistics.Query.StatementType

// This block handles SELECT statements, which return a row set.
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
it, err := query.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
Expand All @@ -177,7 +227,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
var out []any
for {
var row map[string]bigqueryapi.Value
err := it.Next(&row)
err = it.Next(&row)
if err == iterator.Done {
break
}
Expand All @@ -190,8 +240,21 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
}
out = append(out, vMap)
}
// If the query returned any rows, return them directly.
if len(out) > 0 {
return out, nil
}

return out, nil
// This handles the standard case for a SELECT query that successfully
// executes but returns zero rows.
if statementType == "SELECT" {
return "The query returned 0 rows.", nil
}
// This is the fallback for a successful query that doesn't return content.
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
// However, it is also possible that this was a query that was expected to return rows
// but returned none, a case that we cannot distinguish here.
return "Query executed successfully and returned no content.", nil
}

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
Expand All @@ -209,3 +272,58 @@ func (t Tool) McpManifest() tools.McpManifest {
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

func BQTypeStringFromToolType(toolType string) (string, error) {
switch toolType {
case "string":
return "STRING", nil
case "integer":
return "INT64", nil
case "float":
return "FLOAT64", nil
case "boolean":
return "BOOL", nil
// Note: 'array' is handled separately as it has a nested item type.
default:
return "", fmt.Errorf("unsupported tool parameter type for BigQuery: %s", toolType)
}
}

func dryRunQuery(
ctx context.Context,
restService *bigqueryrestapi.Service,
projectID string,
location string,
sql string,
params []*bigqueryrestapi.QueryParameter,
connProps []*bigqueryapi.ConnectionProperty,
) (*bigqueryrestapi.Job, error) {
useLegacySql := false

restConnProps := make([]*bigqueryrestapi.ConnectionProperty, len(connProps))
for i, prop := range connProps {
restConnProps[i] = &bigqueryrestapi.ConnectionProperty{Key: prop.Key, Value: prop.Value}
}

jobToInsert := &bigqueryrestapi.Job{
JobReference: &bigqueryrestapi.JobReference{
ProjectId: projectID,
Location: location,
},
Configuration: &bigqueryrestapi.JobConfiguration{
DryRun: true,
Query: &bigqueryrestapi.JobConfigurationQuery{
Query: sql,
UseLegacySql: &useLegacySql,
ConnectionProperties: restConnProps,
QueryParameters: params,
},
},
}

insertResponse, err := restService.Jobs.Insert(projectID, jobToInsert).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to insert dry run job: %w", err)
}
return insertResponse, nil
}
Loading
Loading