diff --git a/docs/en/resources/tools/bigquery/bigquery-get-dataset-info.md b/docs/en/resources/tools/bigquery/bigquery-get-dataset-info.md index 4f1f2cb716b1..59627c57cee0 100644 --- a/docs/en/resources/tools/bigquery/bigquery-get-dataset-info.md +++ b/docs/en/resources/tools/bigquery/bigquery-get-dataset-info.md @@ -15,10 +15,19 @@ It's compatible with the following sources: - [bigquery](../../sources/bigquery.md) -`bigquery-get-dataset-info` takes a `dataset` parameter to specify the dataset -on the given source. It also optionally accepts a `project` parameter to -define the Google Cloud project ID. If the `project` parameter is not provided, -the tool defaults to using the project defined in the source configuration. +`bigquery-get-dataset-info` accepts the following parameters: +- **`dataset`** (required unless a default applies, see below): Specifies the dataset for which to retrieve metadata. +- **`project`** (optional): Defines the Google Cloud project ID. If not provided, + the tool defaults to the project from the source configuration. + +How the tool handles these parameters depends on whether the +`allowedDatasets` restriction is set on the `bigquery` source: +- **Without `allowedDatasets` restriction:** The tool can retrieve metadata for + any dataset specified by the `dataset` and `project` parameters. +- **With `allowedDatasets` restriction:** Before retrieving metadata, the tool + verifies that the requested dataset is in the allowed list. If it is not, the + request is denied. If the `allowedDatasets` list contains exactly one dataset, + that dataset is used as the default value for the `dataset` parameter. 
## Example diff --git a/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go b/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go index e6f7b54ddcf7..6d3920c4ac2d 100644 --- a/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go +++ b/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go @@ -23,6 +23,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" + bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" ) const kind string = "bigquery-get-dataset-info" @@ -48,6 +49,8 @@ type compatibleSource interface { BigQueryClient() *bigqueryapi.Client BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool + IsDatasetAllowed(projectID, datasetID string) bool + BigQueryAllowedDatasets() []string } // validate compatible sources are still compatible @@ -83,23 +86,33 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) } - projectParameter := tools.NewStringParameterWithDefault(projectKey, s.BigQueryProject(), "The Google Cloud project ID containing the dataset.") - datasetParameter := tools.NewStringParameter(datasetKey, "The dataset to get metadata information.") + defaultProjectID := s.BigQueryProject() + projectDescription := "The Google Cloud project ID containing the dataset." + datasetDescription := "The dataset to get metadata information. Can be in `project.dataset` format." 
+ var datasetParameter tools.Parameter + var projectParameter tools.Parameter + + projectParameter, datasetParameter = bqutil.InitializeDatasetParameters( + s.BigQueryAllowedDatasets(), + defaultProjectID, + projectKey, datasetKey, + projectDescription, datasetDescription) parameters := tools.Parameters{projectParameter, datasetParameter} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters) // finish tool setup t := Tool{ - Name: cfg.Name, - Kind: kind, - Parameters: parameters, - AuthRequired: cfg.AuthRequired, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Name: cfg.Name, + Kind: kind, + Parameters: parameters, + AuthRequired: cfg.AuthRequired, + UseClientOAuth: s.UseClientAuthorization(), + ClientCreator: s.BigQueryClientCreator(), + Client: s.BigQueryClient(), + IsDatasetAllowed: s.IsDatasetAllowed, + manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -114,11 +127,12 @@ type Tool struct { UseClientOAuth bool `yaml:"useClientOAuth"` Parameters tools.Parameters `yaml:"parameters"` - Client *bigqueryapi.Client - ClientCreator bigqueryds.BigqueryClientCreator - Statement string - manifest tools.Manifest - mcpManifest tools.McpManifest + Client *bigqueryapi.Client + ClientCreator bigqueryds.BigqueryClientCreator + Statement string + IsDatasetAllowed func(projectID, datasetID string) bool + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { @@ -147,11 +161,16 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken return nil, 
fmt.Errorf("error creating client from OAuth access token: %w", err) } } + + if !t.IsDatasetAllowed(projectId, datasetId) { + return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) + } + dsHandle := bqClient.DatasetInProject(projectId, datasetId) metadata, err := dsHandle.Metadata(ctx) if err != nil { - return nil, fmt.Errorf("failed to get metadata for dataset %s (in project %s): %w", datasetId, bqClient.Project(), err) + return nil, fmt.Errorf("failed to get metadata for dataset %s (in project %s): %w", datasetId, projectId, err) } return metadata, nil diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index 9fffda9d1e88..40a0fee82530 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -275,6 +275,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { "source": "my-instance", "description": "Tool to list table within a dataset", }, + "get-dataset-info-restricted": map[string]any{ + "kind": "bigquery-get-dataset-info", + "source": "my-instance", + "description": "Tool to get dataset info", + }, "get-table-info-restricted": map[string]any{ "kind": "bigquery-get-table-info", "source": "my-instance", @@ -324,6 +329,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { runListDatasetIdsWithRestriction(t, allowedDatasetName1, allowedDatasetName2) runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1) runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2) + runGetDatasetInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName) + runGetDatasetInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName) runGetTableInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName, 
allowedTableName1, disallowedTableName) runGetTableInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName) runExecuteSqlWithRestriction(t, allowedTableNameParam1, disallowedTableNameParam) @@ -2474,7 +2481,7 @@ func runListDatasetIdsWithRestriction(t *testing.T, allowedDatasetName1, allowed testCases := []struct { name string wantStatusCode int - wantElements []string + wantElements []string }{ { name: "invoke list-dataset-ids with restriction", @@ -2499,7 +2506,7 @@ func runListDatasetIdsWithRestriction(t *testing.T, allowedDatasetName1, allowed if err := json.Unmarshal(bodyBytes, &respBody); err != nil { t.Fatalf("error parsing response body: %v", err) } - + gotJSON, ok := respBody["result"].(string) if !ok { t.Fatalf("unable to find 'result' as a string in response body: %s", string(bodyBytes)) @@ -2603,6 +2610,55 @@ func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowed } } +func runGetDatasetInfoWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName string) { + testCases := []struct { + name string + dataset string + wantStatusCode int + wantInError string + }{ + { + name: "invoke on allowed dataset", + dataset: allowedDatasetName, + wantStatusCode: http.StatusOK, + }, + { + name: "invoke on disallowed dataset", + dataset: disallowedDatasetName, + wantStatusCode: http.StatusBadRequest, + wantInError: fmt.Sprintf("access denied to dataset '%s'", disallowedDatasetName), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"dataset":"%s"}`, tc.dataset))) + req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/get-dataset-info-restricted/invoke", body) + if err != nil { + t.Fatalf("unable to create request: %s", err) + } + req.Header.Add("Content-type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", 
err) + } + defer resp.Body.Close() + + if resp.StatusCode != tc.wantStatusCode { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) + } + + if tc.wantInError != "" { + bodyBytes, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(bodyBytes), tc.wantInError) { + t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError) + } + } + }) + } +} + func runGetTableInfoWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName, allowedTableName, disallowedTableName string) { testCases := []struct { name string