diff --git a/cmd/root.go b/cmd/root.go index 83162d042814..5fb121b28b2b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -155,6 +155,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql" _ "github.com/googleapis/genai-toolbox/internal/tools/redis" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch" _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches" _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql" _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlisttables" diff --git a/cmd/root_test.go b/cmd/root_test.go index 451e31f20130..d057c5b9a4d4 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -1467,7 +1467,7 @@ func TestPrebuiltTools(t *testing.T) { wantToolset: server.ToolsetConfigs{ "serverless_spark_tools": tools.ToolsetConfig{ Name: "serverless_spark_tools", - ToolNames: []string{"list_batches"}, + ToolNames: []string{"list_batches", "get_batch"}, }, }, }, diff --git a/docs/en/resources/sources/serverless-spark.md b/docs/en/resources/sources/serverless-spark.md index 9484f16d05e9..c6ebbfc5d834 100644 --- a/docs/en/resources/sources/serverless-spark.md +++ b/docs/en/resources/sources/serverless-spark.md @@ -17,6 +17,8 @@ Apache Spark. - [`serverless-spark-list-batches`](../tools/serverless-spark/serverless-spark-list-batches.md) List and filter Serverless Spark batches. +- [`serverless-spark-get-batch`](../tools/serverless-spark/serverless-spark-get-batch.md) + Get a Serverless Spark batch. ## Requirements diff --git a/docs/en/resources/tools/serverless-spark/_index.md b/docs/en/resources/tools/serverless-spark/_index.md index c05149468bae..7e9867aeb639 100644 --- a/docs/en/resources/tools/serverless-spark/_index.md +++ b/docs/en/resources/tools/serverless-spark/_index.md @@ -4,4 +4,7 @@ type: docs weight: 1 description: > Tools that work with Google Cloud Serverless for Apache Spark Sources. ---- \ No newline at end of file +--- + +- [serverless-spark-get-batch](./serverless-spark-get-batch.md) +- [serverless-spark-list-batches](./serverless-spark-list-batches.md) diff --git a/docs/en/resources/tools/serverless-spark/serverless-spark-get-batch.md b/docs/en/resources/tools/serverless-spark/serverless-spark-get-batch.md new file mode 100644 index 000000000000..532af65344a1 --- /dev/null +++ b/docs/en/resources/tools/serverless-spark/serverless-spark-get-batch.md @@ -0,0 +1,84 @@ +--- +title: "serverless-spark-get-batch" +type: docs +weight: 1 +description: > + A "serverless-spark-get-batch" tool gets a single Spark batch from the source. +aliases: + - /resources/tools/serverless-spark-get-batch +--- + +# serverless-spark-get-batch + +The `serverless-spark-get-batch` tool allows you to retrieve a specific +Serverless Spark batch job. It's compatible with the following sources: + +- [serverless-spark](../../sources/serverless-spark.md) + +`serverless-spark-list-batches` accepts the following parameters: + +- **`name`**: The short name of the batch, e.g. for + `projects/my-project/locations/us-central1/my-batch`, pass `my-batch`. + +The tool gets the `project` and `location` from the source configuration. + +## Example + +```yaml +tools: + get_my_batch: + kind: serverless-spark-get-batch + source: my-serverless-spark-source + description: Use this tool to get a serverless spark batch. +``` + +## Response Format + +The response is a full Batch JSON object as defined in the [API +spec](https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#Batch). +Example with a reduced set of fields: + +```json +{ + "createTime": "2025-10-10T15:15:21.303146Z", + "creator": "alice@example.com", + "labels": { + "goog-dataproc-batch-uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + "goog-dataproc-location": "us-central1" + }, + "name": "projects/google.com:hadoop-cloud-dev/locations/us-central1/batches/alice-20251010-abcd", + "operation": "projects/google.com:hadoop-cloud-dev/regions/us-central1/operations/11111111-2222-3333-4444-555555555555", + "runtimeConfig": { + "properties": { + "spark:spark.driver.cores": "4", + "spark:spark.driver.memory": "12200m" + } + }, + "sparkBatch": { + "jarFileUris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "mainClass": "org.apache.spark.examples.SparkPi" + }, + "state": "SUCCEEDED", + "stateHistory": [ + { + "state": "PENDING", + "stateStartTime": "2025-10-10T15:15:21.303146Z" + }, + { + "state": "RUNNING", + "stateStartTime": "2025-10-10T15:16:41.291747Z" + } + ], + "stateTime": "2025-10-10T15:17:21.265493Z", + "uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" +} +``` + +## Reference + +| **field** | **type** | **required** | **description** | +| ------------ | :------: | :----------: | -------------------------------------------------- | +| kind | string | true | Must be "serverless-spark-get-batch". | +| source | string | true | Name of the source the tool should use. | +| description | string | true | Description of the tool that is passed to the LLM. | +| authRequired | string[] | false | List of auth services required to invoke this tool | diff --git a/go.mod b/go.mod index 4086b5c1ed78..5e5496fd2dfb 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,7 @@ require ( golang.org/x/oauth2 v0.32.0 google.golang.org/api v0.251.0 google.golang.org/genproto v0.0.0-20251007200510-49b9836ed3ff + google.golang.org/protobuf v1.36.10 modernc.org/sqlite v1.39.1 ) @@ -180,7 +181,6 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20251002232023-7c0ddcbb5797 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251002232023-7c0ddcbb5797 // indirect google.golang.org/grpc v1.75.1 // indirect - google.golang.org/protobuf v1.36.10 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect modernc.org/libc v1.66.10 // indirect diff --git a/internal/prebuiltconfigs/tools/serverless-spark.yaml b/internal/prebuiltconfigs/tools/serverless-spark.yaml index e6fec1d4beec..3ef0a2834aa9 100644 --- a/internal/prebuiltconfigs/tools/serverless-spark.yaml +++ b/internal/prebuiltconfigs/tools/serverless-spark.yaml @@ -22,7 +22,11 @@ tools: list_batches: kind: serverless-spark-list-batches source: serverless-spark-source + get_batch: + kind: serverless-spark-get-batch + source: serverless-spark-source toolsets: serverless_spark_tools: - list_batches + - get_batch diff --git a/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go b/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go new file mode 100644 index 000000000000..0087a43127e0 --- /dev/null +++ b/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go @@ -0,0 +1,171 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package serverlesssparkgetbatch + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" + "github.com/googleapis/genai-toolbox/internal/tools" + "google.golang.org/protobuf/encoding/protojson" +) + +const kind = "serverless-spark-get-batch" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +// ToolConfigKind returns the unique name for this tool. +func (cfg Config) ToolConfigKind() string { + return kind +} + +// Initialize creates a new Tool instance. +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("source %q not found", cfg.Source) + } + + ds, ok := rawS.(*serverlessspark.Source) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind) + } + + desc := cfg.Description + if desc == "" { + desc = "Gets a Serverless Spark (aka Dataproc Serverless) batch" + } + + allParameters := tools.Parameters{ + tools.NewStringParameter("name", "The short name of the batch, e.g. for \"projects/my-project/locations/us-central1/batches/my-batch\", pass \"my-batch\" (the project and location are inherited from the source)"), + } + inputSchema, _ := allParameters.McpManifest() + + mcpManifest := tools.McpManifest{ + Name: cfg.Name, + Description: desc, + InputSchema: inputSchema, + } + + return Tool{ + Name: cfg.Name, + Kind: kind, + Source: ds, + AuthRequired: cfg.AuthRequired, + manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()}, + mcpManifest: mcpManifest, + Parameters: allParameters, + }, nil +} + +// Tool is the implementation of the tool. +type Tool struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` + + Source *serverlessspark.Source + + manifest tools.Manifest + mcpManifest tools.McpManifest + Parameters tools.Parameters +} + +// Invoke executes the tool's operation. +func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { + client := t.Source.GetBatchControllerClient() + + paramMap := params.AsMap() + name, ok := paramMap["name"].(string) + if !ok { + return nil, fmt.Errorf("missing required parameter: name") + } + + if strings.Contains(name, "/") { + return nil, fmt.Errorf("name must be a short batch name without '/': %s", name) + } + + req := &dataprocpb.GetBatchRequest{ + Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", t.Source.Project, t.Source.Location, name), + } + + batchPb, err := client.GetBatch(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to get batch: %w", err) + } + + jsonBytes, err := protojson.Marshal(batchPb) + if err != nil { + return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err) + } + + var result map[string]any + if err := json.Unmarshal(jsonBytes, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err) + } + + return result, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { + return tools.ParseParams(t.Parameters, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(services []string) bool { + return tools.IsAuthorized(t.AuthRequired, services) +} + +func (t Tool) RequiresClientAuthorization() bool { + // Client OAuth not supported, rely on ADCs. + return false +} diff --git a/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch_test.go b/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch_test.go new file mode 100644 index 000000000000..f7589f7b077d --- /dev/null +++ b/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch_test.go @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package serverlesssparkgetbatch_test + +import ( + "testing" + + "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch" +) + +func TestParseFromYaml(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: serverless-spark-get-batch + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": serverlesssparkgetbatch.Config{ + Name: "example_tool", + Kind: "serverless-spark-get-batch", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got, yaml.Strict()) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/tests/serverlessspark/serverless_spark_integration_test.go b/tests/serverlessspark/serverless_spark_integration_test.go index 0057b038d92d..ffe87c9a4fb7 100644 --- a/tests/serverlessspark/serverless_spark_integration_test.go +++ b/tests/serverlessspark/serverless_spark_integration_test.go @@ -24,16 +24,20 @@ import ( "os" "reflect" "regexp" + "strings" "testing" "time" dataproc "cloud.google.com/go/dataproc/v2/apiv1" "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" + "github.com/google/go-cmp/cmp" "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches" "github.com/googleapis/genai-toolbox/tests" "google.golang.org/api/iterator" "google.golang.org/api/option" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/testing/protocmp" ) var ( @@ -81,6 +85,15 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { "source": "my-spark", "authRequired": []string{"my-google-auth"}, }, + "get-batch": map[string]any{ + "kind": "serverless-spark-get-batch", + "source": "my-spark", + }, + "get-batch-with-auth": map[string]any{ + "kind": "serverless-spark-get-batch", + "source": "my-spark", + "authRequired": []string{"my-google-auth"}, + }, }, } @@ -106,13 +119,20 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { defer client.Close() runListBatchesTest(t, client, ctx) - runListBatchesErrorTest(t) - runListBatchesAuthTest(t) + + fullName := listBatchesRpc(t, client, ctx, "", 1, true)[0].Name + runGetBatchTest(t, client, ctx, fullName) + + runErrorTest(t) + + // Get the most recent batch, which is all we need for this test. + runAuthTest(t, "list-batches-with-auth", map[string]any{"pageSize": 1}) + runAuthTest(t, "get-batch-with-auth", map[string]any{"name": shortName(fullName)}) } // runListBatchesTest invokes the running list-batches tool and ensures it returns the correct -// number of results. It can run successfully against any GCP project that has at least 2 succeeded -// or failed Serverless Spark batches, of any age. +// number of results. It can run successfully against any GCP project that contains at least 2 total +// Serverless Spark batches. func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context) { batch2 := listBatchesRpc(t, client, ctx, "", 2, true) batch20 := listBatchesRpc(t, client, ctx, "", 20, false) @@ -145,7 +165,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct } for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { + t.Run("list-batches "+tc.name, func(t *testing.T) { var actual []serverlesssparklistbatches.Batch var pageToken string for i := 0; i < tc.numPages; i++ { @@ -157,9 +177,9 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct request["pageSize"] = tc.pageSize } - resp, err := invokeListBatches("list-batches", request, nil) + resp, err := invokeTool("list-batches", request, nil) if err != nil { - t.Fatalf("invokeListBatches failed: %v", err) + t.Fatalf("invokeTool failed: %v", err) } defer resp.Body.Close() @@ -221,106 +241,181 @@ func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx co return serverlesssparklistbatches.ToBatches(batchPbs) } -func runListBatchesErrorTest(t *testing.T) { +func runAuthTest(t *testing.T, toolName string, request map[string]any) { + idToken, err := tests.GetGoogleIdToken(tests.ClientId) + if err != nil { + t.Fatalf("error getting Google ID token: %s", err) + } tcs := []struct { - name string - pageSize int - wantCode int - wantMsg string + name string + headers map[string]string + wantStatus int }{ { - name: "zero page size", - pageSize: 0, - wantCode: http.StatusBadRequest, - wantMsg: "pageSize must be positive: 0", + name: "valid auth token", + headers: map[string]string{"my-google-auth_token": idToken}, + wantStatus: http.StatusOK, }, { - name: "negative page size", - pageSize: -1, - wantCode: http.StatusBadRequest, - wantMsg: "pageSize must be positive: -1", + name: "invalid auth token", + headers: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, + wantStatus: http.StatusUnauthorized, + }, + { + name: "no auth token", + headers: nil, + wantStatus: http.StatusUnauthorized, }, } for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - request := map[string]any{ - "pageSize": tc.pageSize, - } - resp, err := invokeListBatches("list-batches", request, nil) + t.Run(toolName+" "+tc.name, func(t *testing.T) { + resp, err := invokeTool(toolName, request, tc.headers) if err != nil { - t.Fatalf("invokeListBatches failed: %v", err) + t.Fatalf("invokeTool failed: %v", err) } defer resp.Body.Close() - if resp.StatusCode != tc.wantCode { + if resp.StatusCode != tc.wantStatus { bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("response status code is not %d, got %d: %s", tc.wantCode, resp.StatusCode, string(bodyBytes)) + t.Fatalf("response status code is not %d, got %d: %s", tc.wantStatus, resp.StatusCode, string(bodyBytes)) } + }) + } +} - bodyBytes, err := io.ReadAll(resp.Body) +func runGetBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, fullName string) { + // First get the batch details directly from the Go proto API. + req := &dataprocpb.GetBatchRequest{ + Name: fullName, + } + rawWantBatchPb, err := client.GetBatch(ctx, req) + if err != nil { + t.Fatalf("failed to get batch: %s", err) + } + + // Trim unknown fields from the proto by marshalling and unmarshalling. + jsonBytes, err := protojson.Marshal(rawWantBatchPb) + if err != nil { + t.Fatalf("failed to marshal batch to JSON: %s", err) + } + var wantBatchPb dataprocpb.Batch + if err := protojson.Unmarshal(jsonBytes, &wantBatchPb); err != nil { + t.Fatalf("error unmarshalling result: %s", err) + } + + tcs := []struct { + name string + batchName string + want *dataprocpb.Batch + }{ + { + name: "found batch", + batchName: shortName(fullName), + want: &wantBatchPb, + }, + } + + for _, tc := range tcs { + t.Run("get-batch "+tc.name, func(t *testing.T) { + request := map[string]any{"name": tc.batchName} + resp, err := invokeTool("get-batch", request, nil) if err != nil { - t.Fatalf("failed to read response body: %v", err) + t.Fatalf("invokeTool failed: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + var body map[string]any + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatalf("error parsing response body: %v", err) + } + result, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") } - if !bytes.Contains(bodyBytes, []byte(tc.wantMsg)) { - t.Fatalf("response body does not contain %q: %s", tc.wantMsg, string(bodyBytes)) + // Unmarshal JSON to proto for proto-aware deep comparison. + var batch dataprocpb.Batch + if err := protojson.Unmarshal([]byte(result), &batch); err != nil { + t.Fatalf("error unmarshalling result: %s", err) + } + + if !cmp.Equal(&batch, tc.want, protocmp.Transform()) { + diff := cmp.Diff(&batch, tc.want, protocmp.Transform()) + t.Errorf("GetBatch() returned diff (-got +want):\n%s", diff) } }) } } -func runListBatchesAuthTest(t *testing.T) { - idToken, err := tests.GetGoogleIdToken(tests.ClientId) - if err != nil { - t.Fatalf("error getting Google ID token: %s", err) - } +func runErrorTest(t *testing.T) { + missingBatchFullName := fmt.Sprintf("projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation) tcs := []struct { - name string - toolName string - headers map[string]string - wantStatus int + name string + toolName string + request map[string]any + wantCode int + wantMsg string }{ { - name: "valid auth token", - toolName: "list-batches-with-auth", - headers: map[string]string{"my-google-auth_token": idToken}, - wantStatus: http.StatusOK, + name: "list-batches zero page size", + toolName: "list-batches", + request: map[string]any{"pageSize": 0}, + wantCode: http.StatusBadRequest, + wantMsg: "pageSize must be positive: 0", }, { - name: "invalid auth token", - toolName: "list-batches-with-auth", - headers: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, - wantStatus: http.StatusUnauthorized, + name: "list-batches negative page size", + toolName: "list-batches", + request: map[string]any{"pageSize": -1}, + wantCode: http.StatusBadRequest, + wantMsg: "pageSize must be positive: -1", }, { - name: "no auth token", - toolName: "list-batches-with-auth", - headers: nil, - wantStatus: http.StatusUnauthorized, + name: "get-batch missing batch", + toolName: "get-batch", + request: map[string]any{"name": "INVALID_BATCH"}, + wantCode: http.StatusBadRequest, + wantMsg: fmt.Sprintf("Not found: Batch projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation), + }, + { + name: "get-batch full batch name", + toolName: "get-batch", + request: map[string]any{"name": missingBatchFullName}, + wantCode: http.StatusBadRequest, + wantMsg: fmt.Sprintf("name must be a short batch name without '/': %s", missingBatchFullName), }, } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - request := map[string]any{ - "pageSize": 1, - } - resp, err := invokeListBatches(tc.toolName, request, tc.headers) + resp, err := invokeTool(tc.toolName, tc.request, nil) if err != nil { - t.Fatalf("invokeListBatches failed: %v", err) + t.Fatalf("invokeTool failed: %v", err) } defer resp.Body.Close() - if resp.StatusCode != tc.wantStatus { + if resp.StatusCode != tc.wantCode { bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("response status code is not %d, got %d: %s", tc.wantStatus, resp.StatusCode, string(bodyBytes)) + t.Fatalf("response status code is not %d, got %d: %s", tc.wantCode, resp.StatusCode, string(bodyBytes)) + } + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + + if !bytes.Contains(bodyBytes, []byte(tc.wantMsg)) { + t.Fatalf("response body does not contain %q: %s", tc.wantMsg, string(bodyBytes)) } }) } } -func invokeListBatches(toolName string, request map[string]any, headers map[string]string) (*http.Response, error) { +func invokeTool(toolName string, request map[string]any, headers map[string]string) (*http.Response, error) { requestBytes, err := json.Marshal(request) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) @@ -338,3 +433,8 @@ func invokeListBatches(toolName string, request map[string]any, headers map[stri return http.DefaultClient.Do(req) } + +func shortName(fullName string) string { + parts := strings.Split(fullName, "/") + return parts[len(parts)-1] +}