Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ The behavior of this tool is influenced by the `writeMode` setting on its `bigqu
tools using the same source. This allows the `input_data` parameter to be a query that references temporary resources (e.g.,
`TEMP` tables) created within that session.

The tool's behavior is also influenced by the `allowedDatasets` restriction on the `bigquery` source:

- **Without `allowedDatasets` restriction:** The tool can use any table or query for the `input_data` parameter.
- **With `allowedDatasets` restriction:** The tool verifies that the `input_data` parameter only accesses tables within the allowed datasets.
  - If `input_data` is a table ID, the tool checks if the table's dataset is in the allowed list.
  - If `input_data` is a query, the tool performs a dry run to analyze the query and rejects it if it accesses any table outside the allowed list.


## Example

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
Expand All @@ -50,6 +51,8 @@ type compatibleSource interface {
BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
BigQuerySession() bigqueryds.BigQuerySessionProvider
}

Expand Down Expand Up @@ -86,8 +89,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}

inputDataParameter := tools.NewStringParameter("input_data",
"The data that contain the test and control data to analyze. Can be a fully qualified BigQuery table ID or a SQL query.")
allowedDatasets := s.BigQueryAllowedDatasets()
inputDataDescription := "The data that contain the test and control data to analyze. Can be a fully qualified BigQuery table ID or a SQL query."
if len(allowedDatasets) > 0 {
datasetIDs := []string{}
for _, ds := range allowedDatasets {
datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds))
}
inputDataDescription += fmt.Sprintf(" The query or table must only access datasets from the following list: %s.", strings.Join(datasetIDs, ", "))
}

inputDataParameter := tools.NewStringParameter("input_data", inputDataDescription)
contributionMetricParameter := tools.NewStringParameter("contribution_metric",
`The name of the column that contains the metric to analyze.
Provides the expression to use to calculate the metric you are analyzing.
Expand Down Expand Up @@ -123,17 +135,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)

// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
SessionProvider: s.BigQuerySession(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
SessionProvider: s.BigQuerySession(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
Expand All @@ -148,12 +162,14 @@ type Tool struct {
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`

Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
SessionProvider bigqueryds.BigQuerySessionProvider
manifest tools.Manifest
mcpManifest tools.McpManifest
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
SessionProvider bigqueryds.BigQuerySessionProvider
manifest tools.Manifest
mcpManifest tools.McpManifest
}

// Invoke runs the contribution analysis.
Expand All @@ -164,6 +180,22 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
}

bqClient := t.Client
restService := t.RestService
var err error

// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}

modelID := fmt.Sprintf("contribution_analysis_model_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))

var options []string
Expand Down Expand Up @@ -196,8 +228,54 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
var inputDataSource string
trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData))
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
if len(t.AllowedDatasets) > 0 {
var connProps []*bigqueryapi.ConnectionProperty
session, err := t.SessionProvider(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
if session != nil {
connProps = []*bigqueryapi.ConnectionProperty{
{Key: "session_id", Value: session.ID},
}
}
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps)
if err != nil {
return nil, fmt.Errorf("query validation failed: %w", err)
}
statementType := dryRunJob.Statistics.Query.StatementType
if statementType != "SELECT" {
return nil, fmt.Errorf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType)
}

queryStats := dryRunJob.Statistics.Query
if queryStats != nil {
for _, tableRef := range queryStats.ReferencedTables {
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
}
}
} else {
return nil, fmt.Errorf("could not analyze query in input_data to validate against allowed datasets")
}
}
inputDataSource = fmt.Sprintf("(%s)", inputData)
} else {
if len(t.AllowedDatasets) > 0 {
parts := strings.Split(inputData, ".")
var projectID, datasetID string
switch len(parts) {
case 3: // project.dataset.table
projectID, datasetID = parts[0], parts[1]
case 2: // dataset.table
projectID, datasetID = t.Client.Project(), parts[0]
default:
return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData)
}
if !t.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData)
}
}
inputDataSource = fmt.Sprintf("SELECT * FROM `%s`", inputData)
}

Expand All @@ -209,21 +287,6 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
inputDataSource,
)

bqClient := t.Client
var err error

// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}

createModelQuery := bqClient.Query(createModelSQL)

// Get session from provider if in protected mode.
Expand Down
116 changes: 113 additions & 3 deletions tests/bigquery/bigquery_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
}

func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Minute)
defer cancel()

client, err := initBigQueryConnection(BigqueryProject)
Expand All @@ -225,6 +225,9 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
allowedForecastTableName2 := "allowed_forecast_table_2"
disallowedForecastTableName := "disallowed_forecast_table"

allowedAnalyzeContributionTableName1 := "allowed_analyze_contribution_table_1"
allowedAnalyzeContributionTableName2 := "allowed_analyze_contribution_table_2"
disallowedAnalyzeContributionTableName := "disallowed_analyze_contribution_table"
// Setup allowed table
allowedTableNameParam1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedTableName1)
createAllowedTableStmt1 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam1)
Expand Down Expand Up @@ -259,6 +262,23 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
teardownDisallowedForecast := setupBigQueryTable(t, ctx, client, createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedDatasetName, disallowedForecastTableFullName, disallowedForecastParams)
defer teardownDisallowedForecast(t)

// Setup allowed analyze contribution table
allowedAnalyzeContributionTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedAnalyzeContributionTableName1)
createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, analyzeContributionParams1 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName1)
teardownAllowedAnalyzeContribution1 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, allowedDatasetName1, allowedAnalyzeContributionTableFullName1, analyzeContributionParams1)
defer teardownAllowedAnalyzeContribution1(t)

allowedAnalyzeContributionTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedAnalyzeContributionTableName2)
createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, analyzeContributionParams2 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName2)
teardownAllowedAnalyzeContribution2 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, allowedDatasetName2, allowedAnalyzeContributionTableFullName2, analyzeContributionParams2)
defer teardownAllowedAnalyzeContribution2(t)

// Setup disallowed analyze contribution table
disallowedAnalyzeContributionTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedAnalyzeContributionTableName)
createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedAnalyzeContributionParams := getBigQueryAnalyzeContributionToolInfo(disallowedAnalyzeContributionTableFullName)
teardownDisallowedAnalyzeContribution := setupBigQueryTable(t, ctx, client, createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedDatasetName, disallowedAnalyzeContributionTableFullName, disallowedAnalyzeContributionParams)
defer teardownDisallowedAnalyzeContribution(t)

// Configure source with dataset restriction.
sourceConfig := getBigQueryVars(t)
sourceConfig["allowedDatasets"] = []string{allowedDatasetName1, allowedDatasetName2}
Expand Down Expand Up @@ -300,6 +320,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
"source": "my-instance",
"description": "Tool to forecast",
},
"analyze-contribution-restricted": map[string]any{
"kind": "bigquery-analyze-contribution",
"source": "my-instance",
"description": "Tool to analyze contribution",
},
}

// Create config file
Expand Down Expand Up @@ -327,8 +352,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {

// Run tests
runListDatasetIdsWithRestriction(t, allowedDatasetName1, allowedDatasetName2)
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1, allowedAnalyzeContributionTableName1)
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2, allowedAnalyzeContributionTableName2)
runGetDatasetInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName)
runGetDatasetInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName)
runGetTableInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
Expand All @@ -339,6 +364,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
runConversationalAnalyticsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
runForecastWithRestriction(t, allowedForecastTableFullName1, disallowedForecastTableFullName)
runForecastWithRestriction(t, allowedForecastTableFullName2, disallowedForecastTableFullName)
runAnalyzeContributionWithRestriction(t, allowedAnalyzeContributionTableFullName1, disallowedAnalyzeContributionTableFullName)
runAnalyzeContributionWithRestriction(t, allowedAnalyzeContributionTableFullName2, disallowedAnalyzeContributionTableFullName)
}

func TestBigQueryWriteModeAllowed(t *testing.T) {
Expand Down Expand Up @@ -3125,3 +3152,86 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa
})
}
}

// runAnalyzeContributionWithRestriction invokes the
// "analyze-contribution-restricted" tool over HTTP and checks how it responds
// to input_data values that reference an allowed and a disallowed table.
//
// Both table arguments are backtick-quoted, dot-separated table names. Each
// table is tried twice: once as a bare table ID (backticks stripped) and once
// wrapped in a "SELECT * FROM" query. Allowed inputs must return 200 with a
// result containing "relative_difference"; disallowed inputs must return 400
// with the matching dataset-restriction error message.
func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, disallowedTableFullName string) {
	const invokeURL = "http://127.0.0.1:5000/api/tool/analyze-contribution-restricted/invoke"

	unquote := func(name string) string { return strings.ReplaceAll(name, "`", "") }
	allowedID := unquote(allowedTableFullName)
	disallowedID := unquote(disallowedTableFullName)

	// The first two dot-separated components of the table name are what the
	// expected error messages report as the offending dataset.
	disallowedDataset := strings.Join(strings.Split(disallowedID, ".")[:2], ".")

	type scenario struct {
		name           string
		inputData      string
		wantStatusCode int
		wantInResult   string
		wantInError    string
	}

	scenarios := []scenario{
		{
			name:           "invoke with allowed table name",
			inputData:      allowedID,
			wantStatusCode: http.StatusOK,
			wantInResult:   `"relative_difference"`,
		},
		{
			name:           "invoke with disallowed table name",
			inputData:      disallowedID,
			wantStatusCode: http.StatusBadRequest,
			wantInError:    fmt.Sprintf("access to dataset '%s' (from table '%s') is not allowed", disallowedDataset, disallowedID),
		},
		{
			name:           "invoke with query on allowed table",
			inputData:      fmt.Sprintf("SELECT * FROM %s", allowedTableFullName),
			wantStatusCode: http.StatusOK,
			wantInResult:   `"relative_difference"`,
		},
		{
			name:           "invoke with query on disallowed table",
			inputData:      fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName),
			wantStatusCode: http.StatusBadRequest,
			wantInError:    fmt.Sprintf("query in input_data accesses dataset '%s', which is not in the allowed list", disallowedDataset),
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			payload, err := json.Marshal(map[string]any{
				"input_data":          sc.inputData,
				"contribution_metric": "SUM(metric)",
				"is_test_col":         "is_test",
				"dimension_id_cols":   []string{"dim1", "dim2"},
			})
			if err != nil {
				t.Fatalf("failed to marshal request body: %v", err)
			}

			resp, respBytes := tests.RunRequest(t, http.MethodPost, invokeURL, bytes.NewBuffer(payload), nil)
			rawBody := string(respBytes)

			if got, want := resp.StatusCode, sc.wantStatusCode; got != want {
				t.Fatalf("unexpected status code: got %d, want %d. Body: %s", got, want, rawBody)
			}

			// The body is decoded for every scenario, so a non-JSON response
			// fails the subtest even when only the error text is being checked.
			var decoded map[string]any
			if err := json.Unmarshal(respBytes, &decoded); err != nil {
				t.Fatalf("error parsing response body: %v", err)
			}

			if want := sc.wantInResult; want != "" {
				result, isString := decoded["result"].(string)
				switch {
				case !isString:
					t.Fatalf("unable to find result in response body")
				case !strings.Contains(result, want):
					t.Errorf("unexpected result: got %q, want to contain %q", rawBody, want)
				}
			}

			// Error text is matched against the raw body rather than a
			// specific JSON field.
			if want := sc.wantInError; want != "" && !strings.Contains(rawBody, want) {
				t.Errorf("unexpected error message: got %q, want to contain %q", rawBody, want)
			}
		})
	}
}
Loading