Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
_ "github.com/googleapis/genai-toolbox/internal/tools/redis"
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch"
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlisttables"
Expand Down
2 changes: 1 addition & 1 deletion cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1467,7 +1467,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"serverless_spark_tools": tools.ToolsetConfig{
Name: "serverless_spark_tools",
ToolNames: []string{"list_batches"},
ToolNames: []string{"list_batches", "get_batch"},
},
},
},
Expand Down
2 changes: 2 additions & 0 deletions docs/en/resources/sources/serverless-spark.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Apache Spark.

- [`serverless-spark-list-batches`](../tools/serverless-spark/serverless-spark-list-batches.md)
List and filter Serverless Spark batches.
- [`serverless-spark-get-batch`](../tools/serverless-spark/serverless-spark-get-batch.md)
Get a Serverless Spark batch.

## Requirements

Expand Down
5 changes: 4 additions & 1 deletion docs/en/resources/tools/serverless-spark/_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@ type: docs
weight: 1
description: >
Tools that work with Google Cloud Serverless for Apache Spark Sources.
---
---

- [serverless-spark-get-batch](./serverless-spark-get-batch.md)
- [serverless-spark-list-batches](./serverless-spark-list-batches.md)
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
---
title: "serverless-spark-get-batch"
type: docs
weight: 1
description: >
A "serverless-spark-get-batch" tool gets a single Spark batch from the source.
aliases:
- /resources/tools/serverless-spark-get-batch
---

# serverless-spark-get-batch

The `serverless-spark-get-batch` tool allows you to retrieve a specific
Serverless Spark batch job. It's compatible with the following sources:

- [serverless-spark](../../sources/serverless-spark.md)

`serverless-spark-get-batch` accepts the following parameters:

- **`name`**: The short name of the batch, e.g. for
  `projects/my-project/locations/us-central1/batches/my-batch`, pass `my-batch`.

The tool gets the `project` and `location` from the source configuration.

## Example

```yaml
tools:
get_my_batch:
kind: serverless-spark-get-batch
source: my-serverless-spark-source
description: Use this tool to get a serverless spark batch.
```

## Response Format

The response is a full Batch JSON object as defined in the [API
spec](https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#Batch).
Example with a reduced set of fields:

```json
{
"createTime": "2025-10-10T15:15:21.303146Z",
"creator": "alice@example.com",
"labels": {
"goog-dataproc-batch-uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
"goog-dataproc-location": "us-central1"
},
"name": "projects/google.com:hadoop-cloud-dev/locations/us-central1/batches/alice-20251010-abcd",
"operation": "projects/google.com:hadoop-cloud-dev/regions/us-central1/operations/11111111-2222-3333-4444-555555555555",
"runtimeConfig": {
"properties": {
"spark:spark.driver.cores": "4",
"spark:spark.driver.memory": "12200m"
}
},
"sparkBatch": {
"jarFileUris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
"mainClass": "org.apache.spark.examples.SparkPi"
},
"state": "SUCCEEDED",
"stateHistory": [
{
"state": "PENDING",
"stateStartTime": "2025-10-10T15:15:21.303146Z"
},
{
"state": "RUNNING",
"stateStartTime": "2025-10-10T15:16:41.291747Z"
}
],
"stateTime": "2025-10-10T15:17:21.265493Z",
"uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
}
```

## Reference

| **field** | **type** | **required** | **description** |
| ------------ | :------: | :----------: | -------------------------------------------------- |
| kind | string | true | Must be "serverless-spark-get-batch". |
| source | string | true | Name of the source the tool should use. |
| description | string | true | Description of the tool that is passed to the LLM. |
| authRequired | string[] | false | List of auth services required to invoke this tool |
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ require (
golang.org/x/oauth2 v0.32.0
google.golang.org/api v0.251.0
google.golang.org/genproto v0.0.0-20251007200510-49b9836ed3ff
google.golang.org/protobuf v1.36.10
modernc.org/sqlite v1.39.1
)

Expand Down Expand Up @@ -180,7 +181,6 @@ require (
google.golang.org/genproto/googleapis/api v0.0.0-20251002232023-7c0ddcbb5797 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251002232023-7c0ddcbb5797 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.66.10 // indirect
Expand Down
4 changes: 4 additions & 0 deletions internal/prebuiltconfigs/tools/serverless-spark.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ tools:
list_batches:
kind: serverless-spark-list-batches
source: serverless-spark-source
get_batch:
kind: serverless-spark-get-batch
source: serverless-spark-source

toolsets:
serverless_spark_tools:
- list_batches
- get_batch
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package serverlesssparkgetbatch

import (
"context"
"encoding/json"
"fmt"
"strings"

"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
"github.com/googleapis/genai-toolbox/internal/tools"
"google.golang.org/protobuf/encoding/protojson"
)

const kind = "serverless-spark-get-batch"

// init registers this tool kind with the global tool registry so it can be
// instantiated from YAML configuration. A duplicate registration indicates a
// programmer error, hence the panic.
func init() {
	if !tools.Register(kind, newConfig) {
		panic(fmt.Sprintf("tool kind %q already registered", kind))
	}
}

// newConfig decodes a Config for this tool kind from the YAML decoder,
// pre-populating the tool name from the config map key.
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
	cfg := Config{Name: name}
	err := decoder.DecodeContext(ctx, &cfg)
	if err != nil {
		return nil, err
	}
	return cfg, nil
}

// Config is the YAML-backed configuration for the serverless-spark-get-batch
// tool kind.
type Config struct {
	Name         string   `yaml:"name" validate:"required"`   // tool instance name (the key in the tools map)
	Kind         string   `yaml:"kind" validate:"required"`   // must be "serverless-spark-get-batch"
	Source       string   `yaml:"source" validate:"required"` // name of the serverless-spark source to use
	Description  string   `yaml:"description"`                // optional; a default is supplied in Initialize when empty
	AuthRequired []string `yaml:"authRequired"`               // auth services required to invoke this tool
}

// validate interface
var _ tools.ToolConfig = Config{}

// ToolConfigKind returns the unique kind string for this tool,
// "serverless-spark-get-batch".
func (cfg Config) ToolConfigKind() string {
	return kind
}

// Initialize creates a new Tool instance bound to the configured
// serverless-spark source, validating that the source exists and has the
// expected kind.
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
	raw, found := srcs[cfg.Source]
	if !found {
		return nil, fmt.Errorf("source %q not found", cfg.Source)
	}

	src, ok := raw.(*serverlessspark.Source)
	if !ok {
		return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind)
	}

	// Fall back to a generic description when none was configured.
	description := cfg.Description
	if description == "" {
		description = "Gets a Serverless Spark (aka Dataproc Serverless) batch"
	}

	params := tools.Parameters{
		tools.NewStringParameter("name", "The short name of the batch, e.g. for \"projects/my-project/locations/us-central1/batches/my-batch\", pass \"my-batch\" (the project and location are inherited from the source)"),
	}
	schema, _ := params.McpManifest()

	return Tool{
		Name:         cfg.Name,
		Kind:         kind,
		Source:       src,
		AuthRequired: cfg.AuthRequired,
		manifest:     tools.Manifest{Description: description, Parameters: params.Manifest()},
		mcpManifest: tools.McpManifest{
			Name:        cfg.Name,
			Description: description,
			InputSchema: schema,
		},
		Parameters: params,
	}, nil
}

// Tool is the implementation of the serverless-spark-get-batch tool. It is
// constructed by Config.Initialize and not mutated afterwards.
type Tool struct {
	Name         string   `yaml:"name"`
	Kind         string   `yaml:"kind"`
	Description  string   `yaml:"description"`
	AuthRequired []string `yaml:"authRequired"`

	// Source supplies the batch controller client plus the project and
	// location used to build fully-qualified batch resource names.
	Source *serverlessspark.Source

	manifest    tools.Manifest    // returned by Manifest()
	mcpManifest tools.McpManifest // returned by McpManifest()
	Parameters  tools.Parameters  // declared input parameters (just "name")
}

// Invoke fetches the named batch from the Serverless Spark API and returns it
// as a generic JSON object (map[string]any). The project and location come
// from the tool's source; the "name" parameter must be the short batch name.
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
	client := t.Source.GetBatchControllerClient()

	paramMap := params.AsMap()
	name, ok := paramMap["name"].(string)
	if !ok {
		return nil, fmt.Errorf("missing required parameter: name")
	}
	// Reject empty names up front: an empty name would yield a malformed
	// resource name (".../batches/") and a confusing server-side error.
	if name == "" {
		return nil, fmt.Errorf("parameter \"name\" must not be empty")
	}
	// Only short names are accepted; the full resource name is built below.
	if strings.Contains(name, "/") {
		return nil, fmt.Errorf("name must be a short batch name without '/': %s", name)
	}

	req := &dataprocpb.GetBatchRequest{
		Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", t.Source.Project, t.Source.Location, name),
	}

	batchPb, err := client.GetBatch(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("failed to get batch: %w", err)
	}

	// Round-trip through protojson so proto field names and enums are
	// rendered in their canonical JSON form, then decode into a plain map
	// for the tool response.
	jsonBytes, err := protojson.Marshal(batchPb)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
	}

	var result map[string]any
	if err := json.Unmarshal(jsonBytes, &result); err != nil {
		return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
	}

	return result, nil
}

// ParseParams validates raw request data and auth claims against the tool's
// declared parameters.
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
	return tools.ParseParams(t.Parameters, data, claims)
}

// Manifest returns the tool's manifest (description and parameters) built in
// Initialize.
func (t Tool) Manifest() tools.Manifest {
	return t.manifest
}

// McpManifest returns the tool's MCP manifest built in Initialize.
func (t Tool) McpManifest() tools.McpManifest {
	return t.mcpManifest
}

// Authorized reports whether the given verified auth services satisfy the
// tool's AuthRequired list.
func (t Tool) Authorized(services []string) bool {
	return tools.IsAuthorized(t.AuthRequired, services)
}

// RequiresClientAuthorization reports whether the tool needs a client-supplied
// OAuth token. It always returns false: this tool authenticates with
// Application Default Credentials via the source's client.
func (t Tool) RequiresClientAuthorization() bool {
	// Client OAuth not supported, rely on ADCs.
	return false
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package serverlesssparkgetbatch_test

import (
"testing"

"github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch"
)

// TestParseFromYaml verifies that a serverless-spark-get-batch tool config is
// decoded from YAML into the expected Config value, including the implicit
// Name taken from the tools map key.
func TestParseFromYaml(t *testing.T) {
	ctx, err := testutils.ContextWithNewLogger()
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	tcs := []struct {
		desc string
		in   string // raw YAML; indentation is normalized by testutils.FormatYaml
		want server.ToolConfigs
	}{
		{
			desc: "basic example",
			in: `
			tools:
				example_tool:
					kind: serverless-spark-get-batch
					source: my-instance
					description: some description
			`,
			want: server.ToolConfigs{
				"example_tool": serverlesssparkgetbatch.Config{
					Name:         "example_tool",
					Kind:         "serverless-spark-get-batch",
					Source:       "my-instance",
					Description:  "some description",
					AuthRequired: []string{},
				},
			},
		},
	}
	for _, tc := range tcs {
		t.Run(tc.desc, func(t *testing.T) {
			got := struct {
				Tools server.ToolConfigs `yaml:"tools"`
			}{}
			// Strict mode makes unknown YAML fields an error rather than
			// being silently dropped.
			err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got, yaml.Strict())
			if err != nil {
				t.Fatalf("unable to unmarshal: %s", err)
			}

			if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
				t.Fatalf("incorrect parse: diff %v", diff)
			}
		})
	}
}
Loading
Loading