diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index 17cfa1d3bf44..3dc6896ce021 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -751,8 +751,9 @@ steps: entrypoint: /bin/bash env: - "GOPATH=/gopath" - - "SERVERLESS_SPARK_PROJECT=$PROJECT_ID" - "SERVERLESS_SPARK_LOCATION=$_REGION" + - "SERVERLESS_SPARK_PROJECT=$PROJECT_ID" + - "SERVERLESS_SPARK_SERVICE_ACCOUNT=$SERVICE_ACCOUNT_EMAIL" secretEnv: ["CLIENT_ID"] volumes: - name: "go" diff --git a/cmd/root.go b/cmd/root.go index 885ba0b68710..9fa25c6775dc 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -163,6 +163,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistviews" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql" _ "github.com/googleapis/genai-toolbox/internal/tools/redis" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch" _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch" _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches" _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql" diff --git a/cmd/root_test.go b/cmd/root_test.go index 0449adca008a..cb971d1a0497 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -1474,7 +1474,7 @@ func TestPrebuiltTools(t *testing.T) { wantToolset: server.ToolsetConfigs{ "serverless_spark_tools": tools.ToolsetConfig{ Name: "serverless_spark_tools", - ToolNames: []string{"list_batches", "get_batch"}, + ToolNames: []string{"list_batches", "get_batch", "cancel_batch"}, }, }, }, diff --git a/docs/en/resources/sources/serverless-spark.md b/docs/en/resources/sources/serverless-spark.md index c6ebbfc5d834..0d137d36b7e2 100644 --- a/docs/en/resources/sources/serverless-spark.md +++ b/docs/en/resources/sources/serverless-spark.md @@ -19,6 +19,8 @@ Apache Spark. 
List and filter Serverless Spark batches. - [`serverless-spark-get-batch`](../tools/serverless-spark/serverless-spark-get-batch.md) Get a Serverless Spark batch. +- [`serverless-spark-cancel-batch`](../tools/serverless-spark/serverless-spark-cancel-batch.md) + Cancel a running Serverless Spark batch operation. ## Requirements diff --git a/docs/en/resources/tools/serverless-spark/_index.md b/docs/en/resources/tools/serverless-spark/_index.md index 7e9867aeb639..4974a07b19dc 100644 --- a/docs/en/resources/tools/serverless-spark/_index.md +++ b/docs/en/resources/tools/serverless-spark/_index.md @@ -8,3 +8,4 @@ description: > - [serverless-spark-get-batch](./serverless-spark-get-batch.md) - [serverless-spark-list-batches](./serverless-spark-list-batches.md) +- [serverless-spark-cancel-batch](./serverless-spark-cancel-batch.md) diff --git a/docs/en/resources/tools/serverless-spark/serverless-spark-cancel-batch.md b/docs/en/resources/tools/serverless-spark/serverless-spark-cancel-batch.md new file mode 100644 index 000000000000..4321d64fee6f --- /dev/null +++ b/docs/en/resources/tools/serverless-spark/serverless-spark-cancel-batch.md @@ -0,0 +1,51 @@ +--- +title: "serverless-spark-cancel-batch" +type: docs +weight: 2 +description: > + A "serverless-spark-cancel-batch" tool cancels a running Spark batch operation. +aliases: + - /resources/tools/serverless-spark-cancel-batch +--- + +## About + + The `serverless-spark-cancel-batch` tool cancels a running Spark batch operation in + a Google Cloud Serverless for Apache Spark source. The cancellation request is + asynchronous, so the batch state will not change immediately after the tool + returns; it can take a minute or so for the cancellation to be reflected. + +It's compatible with the following sources: + +- [serverless-spark](../../sources/serverless-spark.md) + +`serverless-spark-cancel-batch` accepts the following parameters: + +- **`operation`** (required): The name of the operation to cancel. 
For example, for `projects/my-project/locations/us-central1/operations/my-operation`, you would pass `my-operation`. + +The tool inherits the `project` and `location` from the source configuration. + +## Example + +```yaml +tools: + cancel_spark_batch: + kind: serverless-spark-cancel-batch + source: my-serverless-spark-source + description: Use this tool to cancel a running serverless spark batch operation. +``` + +## Response Format + +```json +"Cancelled [my-operation]." +``` + +## Reference + +| **field** | **type** | **required** | **description** | | ------------ | :------: | :----------: | -------------------------------------------------- | | kind | string | true | Must be "serverless-spark-cancel-batch". | | source | string | true | Name of the source the tool should use. | | description | string | true | Description of the tool that is passed to the LLM. | | authRequired | string[] | false | List of auth services required to invoke this tool | diff --git a/go.mod b/go.mod index b194e13f4e99..2e91d9d3fe41 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( cloud.google.com/go/dataproc/v2 v2.15.0 cloud.google.com/go/firestore v1.20.0 cloud.google.com/go/geminidataanalytics v0.2.1 + cloud.google.com/go/longrunning v0.7.0 cloud.google.com/go/spanner v1.86.1 github.com/ClickHouse/clickhouse-go/v2 v2.40.3 github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0 @@ -80,7 +81,6 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect cloud.google.com/go/iam v1.5.3 // indirect - cloud.google.com/go/longrunning v0.7.0 // indirect cloud.google.com/go/monitoring v1.24.3 // indirect cloud.google.com/go/trace v1.11.7 // indirect filippo.io/edwards25519 v1.1.0 // indirect diff --git a/internal/prebuiltconfigs/tools/serverless-spark.yaml b/internal/prebuiltconfigs/tools/serverless-spark.yaml index 
3ef0a2834aa9..7d78b18a95cf 100644 --- a/internal/prebuiltconfigs/tools/serverless-spark.yaml +++ b/internal/prebuiltconfigs/tools/serverless-spark.yaml @@ -25,8 +25,12 @@ tools: get_batch: kind: serverless-spark-get-batch source: serverless-spark-source + cancel_batch: + kind: serverless-spark-cancel-batch + source: serverless-spark-source toolsets: serverless_spark_tools: - list_batches - get_batch + - cancel_batch diff --git a/internal/sources/serverlessspark/serverlessspark.go b/internal/sources/serverlessspark/serverlessspark.go index 10cdaf1f7809..7a8a769635fa 100644 --- a/internal/sources/serverlessspark/serverlessspark.go +++ b/internal/sources/serverlessspark/serverlessspark.go @@ -22,6 +22,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" + "cloud.google.com/go/longrunning/autogen" "go.opentelemetry.io/otel/trace" "google.golang.org/api/option" ) @@ -66,13 +67,18 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So if err != nil { return nil, fmt.Errorf("failed to create dataproc client: %w", err) } + opsClient, err := longrunning.NewOperationsClient(ctx, option.WithEndpoint(endpoint), option.WithUserAgent(ua)) + if err != nil { + return nil, fmt.Errorf("failed to create longrunning client: %w", err) + } s := &Source{ - Name: r.Name, - Kind: SourceKind, - Project: r.Project, - Location: r.Location, - Client: client, + Name: r.Name, + Kind: SourceKind, + Project: r.Project, + Location: r.Location, + Client: client, + OpsClient: opsClient, } return s, nil } @@ -80,11 +86,12 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So var _ sources.Source = &Source{} type Source struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Project string - Location string - Client *dataproc.BatchControllerClient + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Project string + Location string + 
Client *dataproc.BatchControllerClient + OpsClient *longrunning.OperationsClient } func (s *Source) SourceKind() string { @@ -94,3 +101,17 @@ func (s *Source) SourceKind() string { func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient { return s.Client } + +func (s *Source) GetOperationsClient(ctx context.Context) (*longrunning.OperationsClient, error) { + return s.OpsClient, nil +} + +func (s *Source) Close() error { + if err := s.Client.Close(); err != nil { + return err + } + if err := s.OpsClient.Close(); err != nil { + return err + } + return nil +} diff --git a/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go new file mode 100644 index 000000000000..36a0be409531 --- /dev/null +++ b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go @@ -0,0 +1,162 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package serverlesssparkcancelbatch + +import ( + "context" + "fmt" + "strings" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" + "github.com/googleapis/genai-toolbox/internal/tools" +) + +const kind = "serverless-spark-cancel-batch" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +// ToolConfigKind returns the unique name for this tool. +func (cfg Config) ToolConfigKind() string { + return kind +} + +// Initialize creates a new Tool instance. +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("source %q not found", cfg.Source) + } + + ds, ok := rawS.(*serverlessspark.Source) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind) + } + + desc := cfg.Description + if desc == "" { + desc = "Cancels a running Serverless Spark (aka Dataproc Serverless) batch operation. Note that the batch state will not change immediately after the tool returns; it can take a minute or so for the cancellation to be reflected." 
+ } + + allParameters := tools.Parameters{ + tools.NewStringParameter("operation", "The name of the operation to cancel, e.g. for \"projects/my-project/locations/us-central1/operations/my-operation\", pass \"my-operation\""), + } + inputSchema, _ := allParameters.McpManifest() + + mcpManifest := tools.McpManifest{ + Name: cfg.Name, + Description: desc, + InputSchema: inputSchema, + } + + return &Tool{ + Name: cfg.Name, + Kind: kind, + Source: ds, + AuthRequired: cfg.AuthRequired, + manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()}, + mcpManifest: mcpManifest, + Parameters: allParameters, + }, nil +} + +// Tool is the implementation of the tool. +type Tool struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` + + Source *serverlessspark.Source + + manifest tools.Manifest + mcpManifest tools.McpManifest + Parameters tools.Parameters +} + +// Invoke executes the tool's operation. 
+func (t *Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { + client, err := t.Source.GetOperationsClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get operations client: %w", err) + } + + paramMap := params.AsMap() + operation, ok := paramMap["operation"].(string) + if !ok { + return nil, fmt.Errorf("missing required parameter: operation") + } + + if strings.Contains(operation, "/") { + return nil, fmt.Errorf("operation must be a short operation name without '/': %s", operation) + } + + req := &longrunningpb.CancelOperationRequest{ + Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", t.Source.Project, t.Source.Location, operation), + } + + err = client.CancelOperation(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to cancel operation: %w", err) + } + + return fmt.Sprintf("Cancelled [%s].", operation), nil +} + +func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { + return tools.ParseParams(t.Parameters, data, claims) +} + +func (t *Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t *Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t *Tool) Authorized(services []string) bool { + return tools.IsAuthorized(t.AuthRequired, services) +} + +func (t *Tool) RequiresClientAuthorization() bool { + // Client OAuth not supported, rely on ADCs. 
+ return false +} diff --git a/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch_test.go b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch_test.go new file mode 100644 index 000000000000..5348399a321c --- /dev/null +++ b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch_test.go @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package serverlesssparkcancelbatch_test + +import ( + "testing" + + "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch" +) + +func TestParseFromYaml(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: serverless-spark-cancel-batch + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": serverlesssparkcancelbatch.Config{ + Name: "example_tool", + Kind: "serverless-spark-cancel-batch", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc 
:= range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got, yaml.Strict()) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go index 7a5cacb4da80..a45e504a9bf6 100644 --- a/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go +++ b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go @@ -127,6 +127,7 @@ type Batch struct { State string `json:"state"` Creator string `json:"creator"` CreateTime string `json:"createTime"` + Operation string `json:"operation"` } // Invoke executes the tool's operation. 
@@ -177,6 +178,7 @@ func ToBatches(batchPbs []*dataprocpb.Batch) []Batch { State: batchPb.State.Enum().String(), Creator: batchPb.Creator, CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339), + Operation: batchPb.Operation, } batches = append(batches, batch) } diff --git a/tests/serverlessspark/serverless_spark_integration_test.go b/tests/serverlessspark/serverless_spark_integration_test.go index ffe87c9a4fb7..f2fa106f8a31 100644 --- a/tests/serverlessspark/serverless_spark_integration_test.go +++ b/tests/serverlessspark/serverless_spark_integration_test.go @@ -24,6 +24,7 @@ import ( "os" "reflect" "regexp" + "slices" "strings" "testing" "time" @@ -41,16 +42,19 @@ import ( ) var ( - serverlessSparkProject = os.Getenv("SERVERLESS_SPARK_PROJECT") - serverlessSparkLocation = os.Getenv("SERVERLESS_SPARK_LOCATION") + serverlessSparkLocation = os.Getenv("SERVERLESS_SPARK_LOCATION") + serverlessSparkProject = os.Getenv("SERVERLESS_SPARK_PROJECT") + serverlessSparkServiceAccount = os.Getenv("SERVERLESS_SPARK_SERVICE_ACCOUNT") ) func getServerlessSparkVars(t *testing.T) map[string]any { switch "" { - case serverlessSparkProject: - t.Fatal("'SERVERLESS_SPARK_PROJECT' not set") case serverlessSparkLocation: t.Fatal("'SERVERLESS_SPARK_LOCATION' not set") + case serverlessSparkProject: + t.Fatal("'SERVERLESS_SPARK_PROJECT' not set") + case serverlessSparkServiceAccount: + t.Fatal("'SERVERLESS_SPARK_SERVICE_ACCOUNT' not set") } return map[string]any{ @@ -94,6 +98,15 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { "source": "my-spark", "authRequired": []string{"my-google-auth"}, }, + "cancel-batch": map[string]any{ + "kind": "serverless-spark-cancel-batch", + "source": "my-spark", + }, + "cancel-batch-with-auth": map[string]any{ + "kind": "serverless-spark-cancel-batch", + "source": "my-spark", + "authRequired": []string{"my-google-auth"}, + }, }, } @@ -118,16 +131,264 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { } defer client.Close() - 
runListBatchesTest(t, client, ctx) + t.Run("list-batches", func(t *testing.T) { + // list-batches is sensitive to state changes, so this test must run sequentially. + t.Run("success", func(t *testing.T) { + runListBatchesTest(t, client, ctx) + }) + t.Run("errors", func(t *testing.T) { + t.Parallel() + tcs := []struct { + name string + toolName string + request map[string]any + wantCode int + wantMsg string + }{ + { + name: "zero page size", + toolName: "list-batches", + request: map[string]any{"pageSize": 0}, + wantCode: http.StatusBadRequest, + wantMsg: "pageSize must be positive: 0", + }, + { + name: "negative page size", + toolName: "list-batches", + request: map[string]any{"pageSize": -1}, + wantCode: http.StatusBadRequest, + wantMsg: "pageSize must be positive: -1", + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + testError(t, tc.toolName, tc.request, tc.wantCode, tc.wantMsg) + }) + } + }) + t.Run("auth", func(t *testing.T) { + t.Parallel() + runAuthTest(t, "list-batches-with-auth", map[string]any{"pageSize": 1}, http.StatusOK) + }) + }) + + // The following tool tests are independent and can run in parallel with each other. 
+ t.Run("parallel-tool-tests", func(t *testing.T) { + t.Run("get-batch", func(t *testing.T) { + t.Parallel() + fullName := listBatchesRpc(t, client, ctx, "", 1, true)[0].Name + t.Run("success", func(t *testing.T) { + t.Parallel() + runGetBatchTest(t, client, ctx, fullName) + }) + t.Run("errors", func(t *testing.T) { + t.Parallel() + missingBatchFullName := fmt.Sprintf("projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation) + tcs := []struct { + name string + toolName string + request map[string]any + wantCode int + wantMsg string + }{ + { + name: "missing batch", + toolName: "get-batch", + request: map[string]any{"name": "INVALID_BATCH"}, + wantCode: http.StatusBadRequest, + wantMsg: fmt.Sprintf("Not found: Batch projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation), + }, + { + name: "full batch name", + toolName: "get-batch", + request: map[string]any{"name": missingBatchFullName}, + wantCode: http.StatusBadRequest, + wantMsg: fmt.Sprintf("name must be a short batch name without '/': %s", missingBatchFullName), + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + testError(t, tc.toolName, tc.request, tc.wantCode, tc.wantMsg) + }) + } + }) + t.Run("auth", func(t *testing.T) { + t.Parallel() + runAuthTest(t, "get-batch-with-auth", map[string]any{"name": shortName(fullName)}, http.StatusOK) + }) + }) + + t.Run("cancel-batch", func(t *testing.T) { + t.Parallel() + t.Run("success", func(t *testing.T) { + t.Parallel() + tcs := []struct { + name string + getBatchName func(t *testing.T) string + }{ + { + name: "running batch", + getBatchName: func(t *testing.T) string { + return createBatch(t, client, ctx) + }, + }, + { + name: "succeeded batch", + getBatchName: func(t *testing.T) string { + return listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 1, true)[0].Name + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + 
t.Parallel() + runCancelBatchTest(t, client, ctx, tc.getBatchName(t)) + }) + } + }) + t.Run("errors", func(t *testing.T) { + t.Parallel() + // Find a batch that's already completed. + completedBatchOp := listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 1, true)[0].Operation + fullOpName := fmt.Sprintf("projects/%s/locations/%s/operations/%s", serverlessSparkProject, serverlessSparkLocation, shortName(completedBatchOp)) + tcs := []struct { + name string + toolName string + request map[string]any + wantCode int + wantMsg string + }{ + { + name: "missing op parameter", + toolName: "cancel-batch", + request: map[string]any{}, + wantCode: http.StatusBadRequest, + wantMsg: "parameter \\\"operation\\\" is required", + }, + { + name: "nonexistent op", + toolName: "cancel-batch", + request: map[string]any{"operation": "INVALID_OPERATION"}, + wantCode: http.StatusBadRequest, + wantMsg: "Operation not found", + }, + { + name: "full op name", + toolName: "cancel-batch", + request: map[string]any{"operation": fullOpName}, + wantCode: http.StatusBadRequest, + wantMsg: fmt.Sprintf("operation must be a short operation name without '/': %s", fullOpName), + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + testError(t, tc.toolName, tc.request, tc.wantCode, tc.wantMsg) + }) + } + }) + t.Run("auth", func(t *testing.T) { + t.Parallel() + runAuthTest(t, "cancel-batch-with-auth", map[string]any{"operation": "INVALID_OPERATION"}, http.StatusBadRequest) + }) + }) + }) +} + +func waitForBatch(t *testing.T, client *dataproc.BatchControllerClient, parentCtx context.Context, batch string, desiredStates []dataprocpb.Batch_State, timeout time.Duration) { + ctx, cancel := context.WithTimeout(parentCtx, timeout) + defer cancel() - fullName := listBatchesRpc(t, client, ctx, "", 1, true)[0].Name - runGetBatchTest(t, client, ctx, fullName) + for { + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for batch %s to reach one of states %v", batch, 
desiredStates) + default: + } + + getReq := &dataprocpb.GetBatchRequest{Name: batch} + batch, err := client.GetBatch(ctx, getReq) + if err != nil { + t.Fatalf("failed to get batch %s: %v", batch, err) + } + + if slices.Contains(desiredStates, batch.State) { + return + } + + if batch.State == dataprocpb.Batch_FAILED || batch.State == dataprocpb.Batch_CANCELLED || batch.State == dataprocpb.Batch_SUCCEEDED { + t.Fatalf("batch op %s is in a terminal state %s, but wanted one of %v. State message: %s", batch, batch.State, desiredStates, batch.StateMessage) + } + time.Sleep(2 * time.Second) + } +} - runErrorTest(t) +// createBatch creates a test batch and immediately returns the batch name, without waiting for the +// batch to start or complete. +func createBatch(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context) string { + parent := fmt.Sprintf("projects/%s/locations/%s", serverlessSparkProject, serverlessSparkLocation) + req := &dataprocpb.CreateBatchRequest{ + Parent: parent, + Batch: &dataprocpb.Batch{ + BatchConfig: &dataprocpb.Batch_SparkBatch{ + SparkBatch: &dataprocpb.SparkBatch{ + Driver: &dataprocpb.SparkBatch_MainClass{ + MainClass: "org.apache.spark.examples.SparkPi", + }, + JarFileUris: []string{ + "file:///usr/lib/spark/examples/jars/spark-examples.jar", + }, + Args: []string{"1000"}, + }, + }, + EnvironmentConfig: &dataprocpb.EnvironmentConfig{ + ExecutionConfig: &dataprocpb.ExecutionConfig{ + ServiceAccount: serverlessSparkServiceAccount, + }, + }, + }, + } + + createOp, err := client.CreateBatch(ctx, req) + if err != nil { + t.Fatalf("failed to create batch: %v", err) + } + meta, err := createOp.Metadata() + if err != nil { + t.Fatalf("failed to get batch metadata: %v", err) + } - // Get the most recent batch, which is all we need for this test. 
- runAuthTest(t, "list-batches-with-auth", map[string]any{"pageSize": 1}) - runAuthTest(t, "get-batch-with-auth", map[string]any{"name": shortName(fullName)}) + // Wait for the batch to become at least PENDING; it typically takes >10s to go from PENDING to + // RUNNING, giving us plenty of time to cancel it before it completes. + waitForBatch(t, client, ctx, meta.Batch, []dataprocpb.Batch_State{dataprocpb.Batch_PENDING, dataprocpb.Batch_RUNNING}, 1*time.Minute) + return meta.Batch +} + +func runCancelBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, batchName string) { + // First get the batch details directly from the Go proto API. + batch, err := client.GetBatch(ctx, &dataprocpb.GetBatchRequest{Name: batchName}) + if err != nil { + t.Fatalf("failed to get batch: %s", err) + } + + request := map[string]any{"operation": shortName(batch.Operation)} + resp, err := invokeTool("cancel-batch", request, nil) + if err != nil { + t.Fatalf("invokeTool failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + if batch.State != dataprocpb.Batch_SUCCEEDED { + waitForBatch(t, client, ctx, batchName, []dataprocpb.Batch_State{dataprocpb.Batch_CANCELLING, dataprocpb.Batch_CANCELLED}, 2*time.Minute) + } } // runListBatchesTest invokes the running list-batches tool and ensures it returns the correct @@ -165,7 +426,8 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct } for _, tc := range tcs { - t.Run("list-batches "+tc.name, func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() var actual []serverlesssparklistbatches.Batch var pageToken string for i := 0; i < tc.numPages; i++ { @@ -241,7 +503,7 @@ func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx co return 
serverlesssparklistbatches.ToBatches(batchPbs) } -func runAuthTest(t *testing.T, toolName string, request map[string]any) { +func runAuthTest(t *testing.T, toolName string, request map[string]any, wantStatus int) { idToken, err := tests.GetGoogleIdToken(tests.ClientId) if err != nil { t.Fatalf("error getting Google ID token: %s", err) @@ -254,7 +516,7 @@ func runAuthTest(t *testing.T, toolName string, request map[string]any) { { name: "valid auth token", headers: map[string]string{"my-google-auth_token": idToken}, - wantStatus: http.StatusOK, + wantStatus: wantStatus, }, { name: "invalid auth token", @@ -269,7 +531,8 @@ func runAuthTest(t *testing.T, toolName string, request map[string]any) { } for _, tc := range tcs { - t.Run(toolName+" "+tc.name, func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() resp, err := invokeTool(toolName, request, tc.headers) if err != nil { t.Fatalf("invokeTool failed: %v", err) @@ -317,7 +580,8 @@ func runGetBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx c } for _, tc := range tcs { - t.Run("get-batch "+tc.name, func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() request := map[string]any{"name": tc.batchName} resp, err := invokeTool("get-batch", request, nil) if err != nil { @@ -351,67 +615,25 @@ func runGetBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx c } } -func runErrorTest(t *testing.T) { - missingBatchFullName := fmt.Sprintf("projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation) - tcs := []struct { - name string - toolName string - request map[string]any - wantCode int - wantMsg string - }{ - { - name: "list-batches zero page size", - toolName: "list-batches", - request: map[string]any{"pageSize": 0}, - wantCode: http.StatusBadRequest, - wantMsg: "pageSize must be positive: 0", - }, - { - name: "list-batches negative page size", - toolName: "list-batches", - request: map[string]any{"pageSize": -1}, - 
wantCode: http.StatusBadRequest, - wantMsg: "pageSize must be positive: -1", - }, - { - name: "get-batch missing batch", - toolName: "get-batch", - request: map[string]any{"name": "INVALID_BATCH"}, - wantCode: http.StatusBadRequest, - wantMsg: fmt.Sprintf("Not found: Batch projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation), - }, - { - name: "get-batch full batch name", - toolName: "get-batch", - request: map[string]any{"name": missingBatchFullName}, - wantCode: http.StatusBadRequest, - wantMsg: fmt.Sprintf("name must be a short batch name without '/': %s", missingBatchFullName), - }, +func testError(t *testing.T, toolName string, request map[string]any, wantCode int, wantMsg string) { + resp, err := invokeTool(toolName, request, nil) + if err != nil { + t.Fatalf("invokeTool failed: %v", err) } + defer resp.Body.Close() - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - resp, err := invokeTool(tc.toolName, tc.request, nil) - if err != nil { - t.Fatalf("invokeTool failed: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != tc.wantCode { - bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("response status code is not %d, got %d: %s", tc.wantCode, resp.StatusCode, string(bodyBytes)) - } + if resp.StatusCode != wantCode { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not %d, got %d: %s", wantCode, resp.StatusCode, string(bodyBytes)) + } - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("failed to read response body: %v", err) - } + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } - if !bytes.Contains(bodyBytes, []byte(tc.wantMsg)) { - t.Fatalf("response body does not contain %q: %s", tc.wantMsg, string(bodyBytes)) - } - }) + if !bytes.Contains(bodyBytes, []byte(wantMsg)) { + t.Fatalf("response body does not contain %q: %s", wantMsg, string(bodyBytes)) } }