Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions docs/en/resources/tools/bigquery/bigquery-get-dataset-info.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,19 @@ It's compatible with the following sources:

- [bigquery](../../sources/bigquery.md)

`bigquery-get-dataset-info` takes a `dataset` parameter to specify the dataset
on the given source. It also optionally accepts a `project` parameter to
define the Google Cloud project ID. If the `project` parameter is not provided,
the tool defaults to using the project defined in the source configuration.
`bigquery-get-dataset-info` accepts the following parameters:
- **`dataset`** (required): Specifies the dataset for which to retrieve metadata.
- **`project`** (optional): Defines the Google Cloud project ID. If not provided,
the tool defaults to the project from the source configuration.

The tool's behavior regarding these parameters is influenced by the
`allowedDatasets` restriction on the `bigquery` source:
- **Without `allowedDatasets` restriction:** The tool can retrieve metadata for
any dataset specified by the `dataset` and `project` parameters.
- **With `allowedDatasets` restriction:** Before retrieving metadata, the tool
verifies that the requested dataset is in the allowed list. If it is not, the
request is denied. If only one dataset is specified in the `allowedDatasets`
list, it will be used as the default value for the `dataset` parameter.

## Example

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
)

const kind string = "bigquery-get-dataset-info"
Expand All @@ -48,6 +49,8 @@ type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
}

// validate compatible sources are still compatible
Expand Down Expand Up @@ -83,23 +86,33 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}

projectParameter := tools.NewStringParameterWithDefault(projectKey, s.BigQueryProject(), "The Google Cloud project ID containing the dataset.")
datasetParameter := tools.NewStringParameter(datasetKey, "The dataset to get metadata information.")
defaultProjectID := s.BigQueryProject()
projectDescription := "The Google Cloud project ID containing the dataset."
datasetDescription := "The dataset to get metadata information. Can be in `project.dataset` format."
var datasetParameter tools.Parameter
var projectParameter tools.Parameter

projectParameter, datasetParameter = bqutil.InitializeDatasetParameters(
s.BigQueryAllowedDatasets(),
defaultProjectID,
projectKey, datasetKey,
projectDescription, datasetDescription)
parameters := tools.Parameters{projectParameter, datasetParameter}

mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters)

// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
IsDatasetAllowed: s.IsDatasetAllowed,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
Expand All @@ -114,11 +127,12 @@ type Tool struct {
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`

Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
IsDatasetAllowed func(projectID, datasetID string) bool
manifest tools.Manifest
mcpManifest tools.McpManifest
}

func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
Expand Down Expand Up @@ -147,11 +161,16 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}

if !t.IsDatasetAllowed(projectId, datasetId) {
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}

dsHandle := bqClient.DatasetInProject(projectId, datasetId)

metadata, err := dsHandle.Metadata(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get metadata for dataset %s (in project %s): %w", datasetId, bqClient.Project(), err)
return nil, fmt.Errorf("failed to get metadata for dataset %s (in project %s): %w", datasetId, projectId, err)
}

return metadata, nil
Expand Down
60 changes: 58 additions & 2 deletions tests/bigquery/bigquery_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
"source": "my-instance",
"description": "Tool to list table within a dataset",
},
"get-dataset-info-restricted": map[string]any{
"kind": "bigquery-get-dataset-info",
"source": "my-instance",
"description": "Tool to get dataset info",
},
"get-table-info-restricted": map[string]any{
"kind": "bigquery-get-table-info",
"source": "my-instance",
Expand Down Expand Up @@ -324,6 +329,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
runListDatasetIdsWithRestriction(t, allowedDatasetName1, allowedDatasetName2)
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
runGetDatasetInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName)
runGetDatasetInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName)
runGetTableInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
runGetTableInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
runExecuteSqlWithRestriction(t, allowedTableNameParam1, disallowedTableNameParam)
Expand Down Expand Up @@ -2474,7 +2481,7 @@ func runListDatasetIdsWithRestriction(t *testing.T, allowedDatasetName1, allowed
testCases := []struct {
name string
wantStatusCode int
wantElements []string
wantElements []string
}{
{
name: "invoke list-dataset-ids with restriction",
Expand All @@ -2499,7 +2506,7 @@ func runListDatasetIdsWithRestriction(t *testing.T, allowedDatasetName1, allowed
if err := json.Unmarshal(bodyBytes, &respBody); err != nil {
t.Fatalf("error parsing response body: %v", err)
}

gotJSON, ok := respBody["result"].(string)
if !ok {
t.Fatalf("unable to find 'result' as a string in response body: %s", string(bodyBytes))
Expand Down Expand Up @@ -2603,6 +2610,55 @@ func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowed
}
}

// runGetDatasetInfoWithRestriction invokes the "get-dataset-info-restricted"
// tool over HTTP, once for allowedDatasetName and once for
// disallowedDatasetName, and checks the response of each call.
//
// The allowed dataset must produce 200 OK. The disallowed dataset must produce
// 400 Bad Request with a response body containing the "access denied to
// dataset" message for that dataset.
func runGetDatasetInfoWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName string) {
	testCases := []struct {
		name           string
		dataset        string
		wantStatusCode int
		wantInError    string
	}{
		{
			name:           "invoke on allowed dataset",
			dataset:        allowedDatasetName,
			wantStatusCode: http.StatusOK,
		},
		{
			name:           "invoke on disallowed dataset",
			dataset:        disallowedDatasetName,
			wantStatusCode: http.StatusBadRequest,
			wantInError:    fmt.Sprintf("access denied to dataset '%s'", disallowedDatasetName),
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			body := bytes.NewBufferString(fmt.Sprintf(`{"dataset":"%s"}`, tc.dataset))
			req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/get-dataset-info-restricted/invoke", body)
			if err != nil {
				t.Fatalf("unable to create request: %s", err)
			}
			req.Header.Add("Content-type", "application/json")
			resp, err := http.DefaultClient.Do(req)
			if err != nil {
				t.Fatalf("unable to send request: %s", err)
			}
			defer resp.Body.Close()

			// Read the body exactly once and surface read failures: a
			// truncated body would otherwise show up as a confusing
			// "unexpected error message" or an empty body in the status
			// failure below.
			bodyBytes, err := io.ReadAll(resp.Body)
			if err != nil {
				t.Fatalf("unable to read response body: %s", err)
			}

			if resp.StatusCode != tc.wantStatusCode {
				t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
			}

			if tc.wantInError != "" && !strings.Contains(string(bodyBytes), tc.wantInError) {
				t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError)
			}
		})
	}
}

func runGetTableInfoWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName, allowedTableName, disallowedTableName string) {
testCases := []struct {
name string
Expand Down
Loading