diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index a16815994329..6f9e6f48ae5a 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -723,6 +723,25 @@ steps: "Oracle" \ oracle \ oracle + + - id: "serverless-spark" + name: golang:1 + waitFor: ["compile-test-binary"] + entrypoint: /bin/bash + env: + - "GOPATH=/gopath" + - "SERVERLESS_SPARK_PROJECT=$PROJECT_ID" + - "SERVERLESS_SPARK_LOCATION=$_REGION" + secretEnv: ["CLIENT_ID"] + volumes: + - name: "go" + path: "/gopath" + args: + - -c + - | + .ci/test_with_coverage.sh \ + "Serverless Spark" \ + serverlessspark availableSecrets: secretManager: diff --git a/cmd/root.go b/cmd/root.go index 6f84eeff7d8f..83162d042814 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -155,6 +155,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql" _ "github.com/googleapis/genai-toolbox/internal/tools/redis" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches" _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql" _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlisttables" _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannersql" @@ -196,6 +197,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/sources/oracle" _ "github.com/googleapis/genai-toolbox/internal/sources/postgres" _ "github.com/googleapis/genai-toolbox/internal/sources/redis" + _ "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" _ "github.com/googleapis/genai-toolbox/internal/sources/spanner" _ "github.com/googleapis/genai-toolbox/internal/sources/sqlite" _ "github.com/googleapis/genai-toolbox/internal/sources/tidb" diff --git a/cmd/root_test.go b/cmd/root_test.go index b8862592919d..3f4b1e2a628c 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -1254,6 
+1254,7 @@ func TestPrebuiltTools(t *testing.T) { cloudsqlpgobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-postgres-observability") cloudsqlmysqlobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-mysql-observability") cloudsqlmssqlobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-mssql-observability") + serverless_spark_config, _ := prebuiltconfigs.Get("serverless-spark") // Set environment variables t.Setenv("API_KEY", "your_api_key") @@ -1305,6 +1306,9 @@ func TestPrebuiltTools(t *testing.T) { t.Setenv("CLOUD_SQL_MSSQL_PASSWORD", "your_cloudsql_mssql_password") t.Setenv("CLOUD_SQL_POSTGRES_PASSWORD", "your_cloudsql_pg_password") + t.Setenv("SERVERLESS_SPARK_PROJECT", "your_gcp_project_id") + t.Setenv("SERVERLESS_SPARK_LOCATION", "your_gcp_location") + t.Setenv("POSTGRES_HOST", "localhost") t.Setenv("POSTGRES_PORT", "5432") t.Setenv("POSTGRES_DATABASE", "your_postgres_db") @@ -1457,6 +1461,16 @@ func TestPrebuiltTools(t *testing.T) { }, }, }, + { + name: "serverless spark prebuilt tools", + in: serverless_spark_config, + wantToolset: server.ToolsetConfigs{ + "serverless_spark_tools": tools.ToolsetConfig{ + Name: "serverless_spark_tools", + ToolNames: []string{"list_batches"}, + }, + }, + }, { name: "firestore prebuilt tools", in: firestoreconfig, diff --git a/docs/en/reference/prebuilt-tools.md b/docs/en/reference/prebuilt-tools.md index d6c399ff82cf..2a07d740be59 100644 --- a/docs/en/reference/prebuilt-tools.md +++ b/docs/en/reference/prebuilt-tools.md @@ -493,6 +493,20 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. * `list_invalid_indexes`: Lists invalid indexes in the database. * `get_query_plan`: Generate the execution plan of a statement. +## Google Cloud Serverless for Apache Spark + +* `--prebuilt` value: `serverless-spark` +* **Environment Variables:** + * `SERVERLESS_SPARK_PROJECT`: The GCP project ID + * `SERVERLESS_SPARK_LOCATION`: The GCP Location. 
+* **Permissions:** + * **Dataproc Serverless Viewer** (`roles/dataproc.serverlessViewer`) to + view serverless batches. + * **Dataproc Serverless Editor** (`roles/dataproc.serverlessEditor`) to + manage serverless batches. +* **Tools:** + * `list_batches`: Lists Spark batches. + ## Spanner (GoogleSQL dialect) * `--prebuilt` value: `spanner` diff --git a/docs/en/resources/sources/serverless-spark.md b/docs/en/resources/sources/serverless-spark.md new file mode 100644 index 000000000000..9484f16d05e9 --- /dev/null +++ b/docs/en/resources/sources/serverless-spark.md @@ -0,0 +1,57 @@ +--- +title: "Serverless for Apache Spark" +type: docs +weight: 1 +description: > + Google Cloud Serverless for Apache Spark lets you run Spark workloads without requiring you to provision and manage your own Spark cluster. +--- + +## About + +The [Serverless for Apache +Spark](https://cloud.google.com/dataproc-serverless/docs/overview) source allows +Toolbox to interact with Spark batches hosted on Google Cloud Serverless for +Apache Spark. + +## Available Tools + +- [`serverless-spark-list-batches`](../tools/serverless-spark/serverless-spark-list-batches.md) + List and filter Serverless Spark batches. + +## Requirements + +### IAM Permissions + +Serverless for Apache Spark uses [Identity and Access Management +(IAM)](https://cloud.google.com/dataproc-serverless/docs/concepts/iam) to control user +and group access to serverless Spark resources like batches and sessions. + +Toolbox will use your [Application Default Credentials +(ADC)](https://cloud.google.com/docs/authentication#adc) to authorize and +authenticate when interacting with Google Cloud Serverless for Apache Spark. +When using this method, you need to ensure the IAM identity associated with your +ADC has the correct +[permissions](https://cloud.google.com/dataproc-serverless/docs/concepts/iam) +for the actions you intend to perform. 
Common roles include +`roles/dataproc.serverlessEditor` (which includes permissions to run batches) or +`roles/dataproc.serverlessViewer`. Follow this +[guide](https://cloud.google.com/docs/authentication/provide-credentials-adc) to +set up your ADC. + +## Example + +```yaml +sources: + my-serverless-spark-source: + kind: serverless-spark + project: my-project-id + location: us-central1 +``` + +## Reference + +| **field** | **type** | **required** | **description** | +| --------- | :------: | :----------: | ----------------------------------------------------------------- | +| kind | string | true | Must be "serverless-spark". | +| project | string | true | ID of the GCP project with Serverless for Apache Spark resources. | +| location | string | true | Location containing Serverless for Apache Spark resources. | diff --git a/docs/en/resources/tools/serverless-spark/_index.md b/docs/en/resources/tools/serverless-spark/_index.md new file mode 100644 index 000000000000..c05149468bae --- /dev/null +++ b/docs/en/resources/tools/serverless-spark/_index.md @@ -0,0 +1,7 @@ +--- +title: "Serverless for Apache Spark" +type: docs +weight: 1 +description: > + Tools that work with Google Cloud Serverless for Apache Spark Sources. +--- \ No newline at end of file diff --git a/docs/en/resources/tools/serverless-spark/serverless-spark-list-batches.md b/docs/en/resources/tools/serverless-spark/serverless-spark-list-batches.md new file mode 100644 index 000000000000..54d68eaa2e65 --- /dev/null +++ b/docs/en/resources/tools/serverless-spark/serverless-spark-list-batches.md @@ -0,0 +1,74 @@ +--- +title: "serverless-spark-list-batches" +type: docs +weight: 1 +description: > + A "serverless-spark-list-batches" tool returns a list of Spark batches from the source. +aliases: + - /resources/tools/serverless-spark-list-batches +--- + +## About + +A `serverless-spark-list-batches` tool returns a list of Spark batches from a +Google Cloud Serverless for Apache Spark source. 
It's compatible with the +following sources: + +- [serverless-spark](../../sources/serverless-spark.md) + +`serverless-spark-list-batches` accepts the following parameters: + +- **`filter`** (optional): A filter expression to limit the batches returned. + Filters are case sensitive and may contain multiple clauses combined with + logical operators (AND/OR). Supported fields are `batch_id`, `batch_uuid`, + `state`, `create_time`, and `labels`. For example: `state = RUNNING AND +create_time < "2023-01-01T00:00:00Z"`. +- **`pageSize`** (optional): The maximum number of batches to return in a single + page. +- **`pageToken`** (optional): A page token, received from a previous call, to + retrieve the next page of results. + +The tool gets the `project` and `location` from the source configuration. + +## Example + +```yaml +tools: + list_spark_batches: + kind: serverless-spark-list-batches + source: my-serverless-spark-source + description: Use this tool to list and filter serverless spark batches. +``` + +## Response Format + +```json +{ + "batches": [ + { + "name": "projects/my-project/locations/us-central1/batches/batch-abc-123", + "uuid": "a1b2c3d4-e5f6-7890-1234-567890abcdef", + "state": "SUCCEEDED", + "creator": "alice@example.com", + "createTime": "2023-10-27T10:00:00Z" + }, + { + "name": "projects/my-project/locations/us-central1/batches/batch-def-456", + "uuid": "b2c3d4e5-f6a7-8901-2345-678901bcdefa", + "state": "FAILED", + "creator": "alice@example.com", + "createTime": "2023-10-27T11:30:00Z" + } + ], + "nextPageToken": "abcd1234" +} +``` + +## Reference + +| **field** | **type** | **required** | **description** | +| ------------ | :------: | :----------: | -------------------------------------------------- | +| kind | string | true | Must be "serverless-spark-list-batches". | +| source | string | true | Name of the source the tool should use. | +| description | string | true | Description of the tool that is passed to the LLM. 
| +| authRequired | string[] | false | List of auth services required to invoke this tool | diff --git a/go.mod b/go.mod index 0c9c7af5c9b8..183f14258cff 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( cloud.google.com/go/bigtable v1.40.1 cloud.google.com/go/cloudsqlconn v1.18.1 cloud.google.com/go/dataplex v1.27.1 + cloud.google.com/go/dataproc/v2 v2.14.1 cloud.google.com/go/firestore v1.20.0 cloud.google.com/go/geminidataanalytics v0.2.1 cloud.google.com/go/spanner v1.86.0 diff --git a/go.sum b/go.sum index cde1c00fda27..5bf0343583ab 100644 --- a/go.sum +++ b/go.sum @@ -241,6 +241,8 @@ cloud.google.com/go/dataplex v1.27.1/go.mod h1:VB+xlYJiJ5kreonXsa2cHPj0A3CfPh/mg cloud.google.com/go/dataproc v1.7.0/go.mod h1:CKAlMjII9H90RXaMpSxQ8EU6dQx6iAYNPcYPOkSbi8s= cloud.google.com/go/dataproc v1.8.0/go.mod h1:5OW+zNAH0pMpw14JVrPONsxMQYMBqJuzORhIBfBn9uI= cloud.google.com/go/dataproc v1.12.0/go.mod h1:zrF3aX0uV3ikkMz6z4uBbIKyhRITnxvr4i3IjKsKrw4= +cloud.google.com/go/dataproc/v2 v2.14.1 h1:Kxq0iomU0H4MlVP4HYeYPNJnV+YxNctf/hFrprmGy5Y= +cloud.google.com/go/dataproc/v2 v2.14.1/go.mod h1:tSdkodShfzrrUNPDVEL6MdH9/mIEvp/Z9s9PBdbsZg8= cloud.google.com/go/dataqna v0.5.0/go.mod h1:90Hyk596ft3zUQ8NkFfvICSIfHFh1Bc7C4cK3vbhkeo= cloud.google.com/go/dataqna v0.6.0/go.mod h1:1lqNpM7rqNLVgWBJyk5NF6Uen2PHym0jtVJonplVsDA= cloud.google.com/go/dataqna v0.7.0/go.mod h1:Lx9OcIIeqCrw1a6KdO3/5KMP1wAmTc0slZWwP12Qq3c= diff --git a/internal/prebuiltconfigs/prebuiltconfigs_test.go b/internal/prebuiltconfigs/prebuiltconfigs_test.go index 9ed78ccfb4b5..f2f34eb08b6d 100644 --- a/internal/prebuiltconfigs/prebuiltconfigs_test.go +++ b/internal/prebuiltconfigs/prebuiltconfigs_test.go @@ -44,6 +44,7 @@ var expectedToolSources = []string{ "neo4j", "oceanbase", "postgres", + "serverless-spark", "spanner-postgres", "spanner", "sqlite", diff --git a/internal/prebuiltconfigs/tools/serverless-spark.yaml b/internal/prebuiltconfigs/tools/serverless-spark.yaml new file mode 100644 index 
000000000000..e6fec1d4beec --- /dev/null +++ b/internal/prebuiltconfigs/tools/serverless-spark.yaml @@ -0,0 +1,28 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +sources: + serverless-spark-source: + kind: serverless-spark + project: ${SERVERLESS_SPARK_PROJECT} + location: ${SERVERLESS_SPARK_LOCATION} + +tools: + list_batches: + kind: serverless-spark-list-batches + source: serverless-spark-source + +toolsets: + serverless_spark_tools: + - list_batches diff --git a/internal/sources/serverlessspark/serverlessspark.go b/internal/sources/serverlessspark/serverlessspark.go new file mode 100644 index 000000000000..10cdaf1f7809 --- /dev/null +++ b/internal/sources/serverlessspark/serverlessspark.go @@ -0,0 +1,96 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package serverlessspark + +import ( + "context" + "fmt" + + dataproc "cloud.google.com/go/dataproc/v2/apiv1" + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util" + "go.opentelemetry.io/otel/trace" + "google.golang.org/api/option" +) + +const SourceKind string = "serverless-spark" + +// validate interface +var _ sources.SourceConfig = Config{} + +func init() { + if !sources.Register(SourceKind, newConfig) { + panic(fmt.Sprintf("source kind %q already registered", SourceKind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Project string `yaml:"project" validate:"required"` + Location string `yaml:"location" validate:"required"` +} + +func (r Config) SourceConfigKind() string { + return SourceKind +} + +func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) { + ua, err := util.UserAgentFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("error in User Agent retrieval: %s", err) + } + endpoint := fmt.Sprintf("%s-dataproc.googleapis.com:443", r.Location) + client, err := dataproc.NewBatchControllerClient(ctx, option.WithEndpoint(endpoint), option.WithUserAgent(ua)) + if err != nil { + return nil, fmt.Errorf("failed to create dataproc client: %w", err) + } + + s := &Source{ + Name: r.Name, + Kind: SourceKind, + Project: r.Project, + Location: r.Location, + Client: client, + } + return s, nil +} + +var _ sources.Source = &Source{} + +type Source struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Project string + Location string + Client *dataproc.BatchControllerClient +} + +func (s *Source) 
SourceKind() string { + return SourceKind +} + +func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient { + return s.Client +} diff --git a/internal/sources/serverlessspark/serverlessspark_test.go b/internal/sources/serverlessspark/serverlessspark_test.go new file mode 100644 index 000000000000..3162f259ef32 --- /dev/null +++ b/internal/sources/serverlessspark/serverlessspark_test.go @@ -0,0 +1,125 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package serverlessspark_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" + "github.com/googleapis/genai-toolbox/internal/testutils" +) + +func TestParseFromYamlServerlessSpark(t *testing.T) { + tcs := []struct { + desc string + in string + want server.SourceConfigs + }{ + { + desc: "basic example", + in: ` + sources: + my-instance: + kind: serverless-spark + project: my-project + location: my-location + `, + want: server.SourceConfigs{ + "my-instance": serverlessspark.Config{ + Name: "my-instance", + Kind: serverlessspark.SourceKind, + Project: "my-project", + Location: "my-location", + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if !cmp.Equal(tc.want, got.Sources) { + t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources) + } + }) + } + +} + +func TestFailParseFromYaml(t *testing.T) { + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "extra field", + in: ` + sources: + my-instance: + kind: serverless-spark + project: my-project + location: my-location + foo: bar + `, + err: "unable to parse source \"my-instance\" as \"serverless-spark\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: serverless-spark\n 3 | location: my-location\n 4 | project: my-project", + }, + { + desc: "missing required field project", + in: ` + sources: + my-instance: + kind: serverless-spark + location: my-location + `, + err: "unable to parse source \"my-instance\" as \"serverless-spark\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag", + }, + { + desc: "missing required 
field location", + in: ` + sources: + my-instance: + kind: serverless-spark + project: my-project + `, + err: "unable to parse source \"my-instance\" as \"serverless-spark\": Key: 'Config.Location' Error:Field validation for 'Location' failed on the 'required' tag", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) + } + }) + } +} diff --git a/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go new file mode 100644 index 000000000000..7a5cacb4da80 --- /dev/null +++ b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go @@ -0,0 +1,205 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package serverlesssparklistbatches + +import ( + "context" + "fmt" + "time" + + "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" + "github.com/googleapis/genai-toolbox/internal/tools" + "google.golang.org/api/iterator" +) + +const kind = "serverless-spark-list-batches" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +// ToolConfigKind returns the unique name for this tool. +func (cfg Config) ToolConfigKind() string { + return kind +} + +// Initialize creates a new Tool instance. +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("source %q not found", cfg.Source) + } + + ds, ok := rawS.(*serverlessspark.Source) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind) + } + + desc := cfg.Description + if desc == "" { + desc = "Lists available Serverless Spark (aka Dataproc Serverless) batches" + } + + allParameters := tools.Parameters{ + tools.NewStringParameterWithRequired("filter", `Filter expression to limit the batches. 
Filters are case sensitive, and may contain multiple clauses combined with logical operators (AND/OR, case sensitive). Supported fields are batch_id, batch_uuid, state, create_time, and labels. e.g. state = RUNNING AND create_time < "2023-01-01T00:00:00Z" filters for batches in state RUNNING that were created before 2023-01-01. state = RUNNING AND labels.environment=production filters for batches in state in a RUNNING state that have a production environment label. Valid states are STATE_UNSPECIFIED, PENDING, RUNNING, CANCELLING, CANCELLED, SUCCEEDED, FAILED. Valid operators are < > <= >= = !=, and : as "has" for labels, meaning any non-empty value)`, false), + tools.NewIntParameterWithDefault("pageSize", 20, "The maximum number of batches to return in a single page (default 20)"), + tools.NewStringParameterWithRequired("pageToken", "A page token, received from a previous `ListBatches` call", false), + } + inputSchema, _ := allParameters.McpManifest() + + mcpManifest := tools.McpManifest{ + Name: cfg.Name, + Description: desc, + InputSchema: inputSchema, + } + + return Tool{ + Name: cfg.Name, + Kind: kind, + Source: ds, + AuthRequired: cfg.AuthRequired, + manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()}, + mcpManifest: mcpManifest, + Parameters: allParameters, + }, nil +} + +// Tool is the implementation of the tool. +type Tool struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` + + Source *serverlessspark.Source + + manifest tools.Manifest + mcpManifest tools.McpManifest + Parameters tools.Parameters +} + +// ListBatchesResponse is the response from the list batches API. +type ListBatchesResponse struct { + Batches []Batch `json:"batches"` + NextPageToken string `json:"nextPageToken"` +} + +// Batch represents a single batch job. 
+type Batch struct { + Name string `json:"name"` + UUID string `json:"uuid"` + State string `json:"state"` + Creator string `json:"creator"` + CreateTime string `json:"createTime"` +} + +// Invoke executes the tool's operation. +func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { + client := t.Source.GetBatchControllerClient() + + parent := fmt.Sprintf("projects/%s/locations/%s", t.Source.Project, t.Source.Location) + req := &dataprocpb.ListBatchesRequest{ + Parent: parent, + OrderBy: "create_time desc", + } + + paramMap := params.AsMap() + if ps, ok := paramMap["pageSize"]; ok && ps != nil { + req.PageSize = int32(ps.(int)) + if (req.PageSize) <= 0 { + return nil, fmt.Errorf("pageSize must be positive: %d", req.PageSize) + } + } + if pt, ok := paramMap["pageToken"]; ok && pt != nil { + req.PageToken = pt.(string) + } + if filter, ok := paramMap["filter"]; ok && filter != nil { + req.Filter = filter.(string) + } + + it := client.ListBatches(ctx, req) + pager := iterator.NewPager(it, int(req.PageSize), req.PageToken) + + var batchPbs []*dataprocpb.Batch + nextPageToken, err := pager.NextPage(&batchPbs) + if err != nil { + return nil, fmt.Errorf("failed to list batches: %w", err) + } + + batches := ToBatches(batchPbs) + + return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil +} + +// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs. 
+func ToBatches(batchPbs []*dataprocpb.Batch) []Batch { + batches := make([]Batch, 0, len(batchPbs)) + for _, batchPb := range batchPbs { + batch := Batch{ + Name: batchPb.Name, + UUID: batchPb.Uuid, + State: batchPb.State.Enum().String(), + Creator: batchPb.Creator, + CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339), + } + batches = append(batches, batch) + } + return batches +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { + return tools.ParseParams(t.Parameters, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(services []string) bool { + return tools.IsAuthorized(t.AuthRequired, services) +} + +func (t Tool) RequiresClientAuthorization() bool { + // Client OAuth not supported, rely on ADCs. + return false +} diff --git a/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches_test.go b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches_test.go new file mode 100644 index 000000000000..95a11237408a --- /dev/null +++ b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches_test.go @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package serverlesssparklistbatches_test + +import ( + "testing" + + "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches" +) + +func TestParseFromYaml(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: serverless-spark-list-batches + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": serverlesssparklistbatches.Config{ + Name: "example_tool", + Kind: "serverless-spark-list-batches", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got, yaml.Strict()) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/tests/serverlessspark/serverless_spark_integration_test.go b/tests/serverlessspark/serverless_spark_integration_test.go new file mode 100644 index 000000000000..0057b038d92d --- /dev/null +++ b/tests/serverlessspark/serverless_spark_integration_test.go @@ -0,0 +1,340 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package serverlessspark + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "reflect" + "regexp" + "testing" + "time" + + dataproc "cloud.google.com/go/dataproc/v2/apiv1" + "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches" + "github.com/googleapis/genai-toolbox/tests" + "google.golang.org/api/iterator" + "google.golang.org/api/option" +) + +var ( + serverlessSparkProject = os.Getenv("SERVERLESS_SPARK_PROJECT") + serverlessSparkLocation = os.Getenv("SERVERLESS_SPARK_LOCATION") +) + +func getServerlessSparkVars(t *testing.T) map[string]any { + switch "" { + case serverlessSparkProject: + t.Fatal("'SERVERLESS_SPARK_PROJECT' not set") + case serverlessSparkLocation: + t.Fatal("'SERVERLESS_SPARK_LOCATION' not set") + } + + return map[string]any{ + "kind": "serverless-spark", + "project": serverlessSparkProject, + "location": serverlessSparkLocation, + } +} + +func TestServerlessSparkToolEndpoints(t *testing.T) { + sourceConfig := getServerlessSparkVars(t) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-spark": sourceConfig, + }, + "authServices": map[string]any{ + "my-google-auth": map[string]any{ + "kind": "google", + "clientId": tests.ClientId, + }, + }, + "tools": map[string]any{ + "list-batches": map[string]any{ + "kind": 
"serverless-spark-list-batches", + "source": "my-spark", + }, + "list-batches-with-auth": map[string]any{ + "kind": "serverless-spark-list-batches", + "source": "my-spark", + "authRequired": []string{"my-google-auth"}, + }, + }, + } + + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + endpoint := fmt.Sprintf("%s-dataproc.googleapis.com:443", serverlessSparkLocation) + client, err := dataproc.NewBatchControllerClient(ctx, option.WithEndpoint(endpoint)) + if err != nil { + t.Fatalf("failed to create dataproc client: %v", err) + } + defer client.Close() + + runListBatchesTest(t, client, ctx) + runListBatchesErrorTest(t) + runListBatchesAuthTest(t) +} + +// runListBatchesTest invokes the running list-batches tool and ensures it returns the correct +// number of results. It can run successfully against any GCP project that has at least 2 succeeded +// or failed Serverless Spark batches, of any age. 
+func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context) { + batch2 := listBatchesRpc(t, client, ctx, "", 2, true) + batch20 := listBatchesRpc(t, client, ctx, "", 20, false) + + tcs := []struct { + name string + filter string + pageSize int + numPages int + want []serverlesssparklistbatches.Batch + }{ + {name: "one page", pageSize: 2, numPages: 1, want: batch2}, + {name: "two pages", pageSize: 1, numPages: 2, want: batch2}, + {name: "20 batches", pageSize: 20, numPages: 1, want: batch20}, + {name: "omit page size", numPages: 1, want: batch20}, + { + name: "filtered", + filter: "state = SUCCEEDED", + pageSize: 2, + numPages: 1, + want: listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 2, true), + }, + { + name: "empty", + filter: "state = SUCCEEDED AND state = FAILED", + pageSize: 1, + numPages: 1, + want: nil, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + var actual []serverlesssparklistbatches.Batch + var pageToken string + for i := 0; i < tc.numPages; i++ { + request := map[string]any{ + "filter": tc.filter, + "pageToken": pageToken, + } + if tc.pageSize > 0 { + request["pageSize"] = tc.pageSize + } + + resp, err := invokeListBatches("list-batches", request, nil) + if err != nil { + t.Fatalf("invokeListBatches failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var body map[string]any + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatalf("error parsing response body: %v", err) + } + + result, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + var listResponse serverlesssparklistbatches.ListBatchesResponse + if err := json.Unmarshal([]byte(result), &listResponse); err != nil { + t.Fatalf("error unmarshalling result: %s", err) + } + actual 
= append(actual, listResponse.Batches...) + pageToken = listResponse.NextPageToken + } + + if !reflect.DeepEqual(actual, tc.want) { + t.Fatalf("unexpected batches: got %+v, want %+v", actual, tc.want) + } + }) + } +} + +func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, filter string, n int, exact bool) []serverlesssparklistbatches.Batch { + parent := fmt.Sprintf("projects/%s/locations/%s", serverlessSparkProject, serverlessSparkLocation) + req := &dataprocpb.ListBatchesRequest{ + Parent: parent, + PageSize: 2, + OrderBy: "create_time desc", + } + if filter != "" { + req.Filter = filter + } + + it := client.ListBatches(ctx, req) + pager := iterator.NewPager(it, n, "") + var batchPbs []*dataprocpb.Batch + _, err := pager.NextPage(&batchPbs) + if err != nil { + t.Fatalf("failed to list batches: %s", err) + } + if exact && len(batchPbs) != n { + t.Fatalf("expected exactly %d batches, got %d", n, len(batchPbs)) + } + if !exact && (len(batchPbs) == 0 || len(batchPbs) > n) { + t.Fatalf("expected between 1 and %d batches, got %d", n, len(batchPbs)) + } + + return serverlesssparklistbatches.ToBatches(batchPbs) +} + +func runListBatchesErrorTest(t *testing.T) { + tcs := []struct { + name string + pageSize int + wantCode int + wantMsg string + }{ + { + name: "zero page size", + pageSize: 0, + wantCode: http.StatusBadRequest, + wantMsg: "pageSize must be positive: 0", + }, + { + name: "negative page size", + pageSize: -1, + wantCode: http.StatusBadRequest, + wantMsg: "pageSize must be positive: -1", + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + request := map[string]any{ + "pageSize": tc.pageSize, + } + resp, err := invokeListBatches("list-batches", request, nil) + if err != nil { + t.Fatalf("invokeListBatches failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != tc.wantCode { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not %d, got %d: %s", 
tc.wantCode, resp.StatusCode, string(bodyBytes)) + } + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + + if !bytes.Contains(bodyBytes, []byte(tc.wantMsg)) { + t.Fatalf("response body does not contain %q: %s", tc.wantMsg, string(bodyBytes)) + } + }) + } +} + +func runListBatchesAuthTest(t *testing.T) { + idToken, err := tests.GetGoogleIdToken(tests.ClientId) + if err != nil { + t.Fatalf("error getting Google ID token: %s", err) + } + tcs := []struct { + name string + toolName string + headers map[string]string + wantStatus int + }{ + { + name: "valid auth token", + toolName: "list-batches-with-auth", + headers: map[string]string{"my-google-auth_token": idToken}, + wantStatus: http.StatusOK, + }, + { + name: "invalid auth token", + toolName: "list-batches-with-auth", + headers: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, + wantStatus: http.StatusUnauthorized, + }, + { + name: "no auth token", + toolName: "list-batches-with-auth", + headers: nil, + wantStatus: http.StatusUnauthorized, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + request := map[string]any{ + "pageSize": 1, + } + resp, err := invokeListBatches(tc.toolName, request, tc.headers) + if err != nil { + t.Fatalf("invokeListBatches failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != tc.wantStatus { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not %d, got %d: %s", tc.wantStatus, resp.StatusCode, string(bodyBytes)) + } + }) + } +} + +func invokeListBatches(toolName string, request map[string]any, headers map[string]string) (*http.Response, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + url := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", toolName) + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(requestBytes)) + if err != 
nil { + return nil, fmt.Errorf("unable to create request: %w", err) + } + req.Header.Add("Content-type", "application/json") + for k, v := range headers { + req.Header.Add(k, v) + } + + return http.DefaultClient.Do(req) +}