Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .ci/integration.cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -751,8 +751,9 @@ steps:
entrypoint: /bin/bash
env:
- "GOPATH=/gopath"
- "SERVERLESS_SPARK_PROJECT=$PROJECT_ID"
- "SERVERLESS_SPARK_LOCATION=$_REGION"
- "SERVERLESS_SPARK_PROJECT=$PROJECT_ID"
- "SERVERLESS_SPARK_SERVICE_ACCOUNT=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["CLIENT_ID"]
volumes:
- name: "go"
Expand Down
1 change: 1 addition & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistviews"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
_ "github.com/googleapis/genai-toolbox/internal/tools/redis"
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch"
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch"
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql"
Expand Down
2 changes: 1 addition & 1 deletion cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1474,7 +1474,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"serverless_spark_tools": tools.ToolsetConfig{
Name: "serverless_spark_tools",
ToolNames: []string{"list_batches", "get_batch"},
ToolNames: []string{"list_batches", "get_batch", "cancel_batch"},
},
},
},
Expand Down
2 changes: 2 additions & 0 deletions docs/en/resources/sources/serverless-spark.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Apache Spark.
List and filter Serverless Spark batches.
- [`serverless-spark-get-batch`](../tools/serverless-spark/serverless-spark-get-batch.md)
Get a Serverless Spark batch.
- [`serverless-spark-cancel-batch`](../tools/serverless-spark/serverless-spark-cancel-batch.md)
Cancel a running Serverless Spark batch operation.

## Requirements

Expand Down
1 change: 1 addition & 0 deletions docs/en/resources/tools/serverless-spark/_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ description: >

- [serverless-spark-get-batch](./serverless-spark-get-batch.md)
- [serverless-spark-list-batches](./serverless-spark-list-batches.md)
- [serverless-spark-cancel-batch](./serverless-spark-cancel-batch.md)
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
---
title: "serverless-spark-cancel-batch"
type: docs
weight: 2
description: >
A "serverless-spark-cancel-batch" tool cancels a running Spark batch operation.
aliases:
- /resources/tools/serverless-spark-cancel-batch
---

## About

`serverless-spark-cancel-batch` tool cancels a running Spark batch operation in
a Google Cloud Serverless for Apache Spark source. The cancellation request is
asynchronous, so the batch state will not change immediately after the tool
returns; it can take a minute or so for the cancellation to be reflected.

It's compatible with the following sources:

- [serverless-spark](../../sources/serverless-spark.md)

`serverless-spark-cancel-batch` accepts the following parameters:

- **`operation`** (required): The name of the operation to cancel. For example, for `projects/my-project/locations/us-central1/operations/my-operation`, you would pass `my-operation`.

The tool inherits the `project` and `location` from the source configuration.

## Example

```yaml
tools:
cancel_spark_batch:
kind: serverless-spark-cancel-batch
source: my-serverless-spark-source
description: Use this tool to cancel a running serverless spark batch operation.
```

## Response Format

```json
"Cancelled [my-operation]."
```

## Reference

| **field** | **type** | **required** | **description** |
| ------------ | :------: | :----------: | -------------------------------------------------- |
| kind | string | true | Must be "serverless-spark-cancel-batch". |
| source | string | true | Name of the source the tool should use. |
| description | string | false | Description of the tool that is passed to the LLM. A built-in description is used when omitted. |
| authRequired | string[] | false | List of auth services required to invoke this tool |
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ require (
cloud.google.com/go/dataproc/v2 v2.15.0
cloud.google.com/go/firestore v1.20.0
cloud.google.com/go/geminidataanalytics v0.2.1
cloud.google.com/go/longrunning v0.7.0
cloud.google.com/go/spanner v1.86.1
github.com/ClickHouse/clickhouse-go/v2 v2.40.3
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0
Expand Down Expand Up @@ -80,7 +81,6 @@ require (
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect
cloud.google.com/go/iam v1.5.3 // indirect
cloud.google.com/go/longrunning v0.7.0 // indirect
cloud.google.com/go/monitoring v1.24.3 // indirect
cloud.google.com/go/trace v1.11.7 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions internal/prebuiltconfigs/tools/serverless-spark.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ tools:
get_batch:
kind: serverless-spark-get-batch
source: serverless-spark-source
cancel_batch:
kind: serverless-spark-cancel-batch
source: serverless-spark-source

toolsets:
serverless_spark_tools:
- list_batches
- get_batch
- cancel_batch
41 changes: 31 additions & 10 deletions internal/sources/serverlessspark/serverlessspark.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"cloud.google.com/go/longrunning/autogen"
"go.opentelemetry.io/otel/trace"
"google.golang.org/api/option"
)
Expand Down Expand Up @@ -66,25 +67,31 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
if err != nil {
return nil, fmt.Errorf("failed to create dataproc client: %w", err)
}
opsClient, err := longrunning.NewOperationsClient(ctx, option.WithEndpoint(endpoint), option.WithUserAgent(ua))
if err != nil {
return nil, fmt.Errorf("failed to create longrunning client: %w", err)
}

s := &Source{
Name: r.Name,
Kind: SourceKind,
Project: r.Project,
Location: r.Location,
Client: client,
Name: r.Name,
Kind: SourceKind,
Project: r.Project,
Location: r.Location,
Client: client,
OpsClient: opsClient,
}
return s, nil
}

var _ sources.Source = &Source{}

type Source struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Project string
Location string
Client *dataproc.BatchControllerClient
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Project string
Location string
Client *dataproc.BatchControllerClient
OpsClient *longrunning.OperationsClient
}

func (s *Source) SourceKind() string {
Expand All @@ -94,3 +101,17 @@ func (s *Source) SourceKind() string {
// GetBatchControllerClient returns the Dataproc batch controller client stored
// on the source.
func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient {
	return s.Client
}

// GetOperationsClient returns the long-running operations client stored on the
// source. The ctx argument is not used and the returned error is always nil.
func (s *Source) GetOperationsClient(ctx context.Context) (*longrunning.OperationsClient, error) {
	return s.OpsClient, nil
}

// Close releases both API clients held by the source.
//
// Both clients are always closed, even when closing the first one fails, so an
// error from the batch controller client cannot leak the operations client.
// When both calls fail, the batch controller client's error is returned.
func (s *Source) Close() error {
	clientErr := s.Client.Close()
	opsErr := s.OpsClient.Close()
	if clientErr != nil {
		return clientErr
	}
	return opsErr
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package serverlesssparkcancelbatch

import (
"context"
"fmt"
"strings"

"cloud.google.com/go/longrunning/autogen/longrunningpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
"github.com/googleapis/genai-toolbox/internal/tools"
)

const kind = "serverless-spark-cancel-batch"

// init registers this tool kind with the tools registry. It panics when the
// registry reports that the kind could not be registered.
func init() {
	if registered := tools.Register(kind, newConfig); registered {
		return
	}
	panic(fmt.Sprintf("tool kind %q already registered", kind))
}

// newConfig decodes a tool configuration from decoder. The Name field is
// seeded with name before decoding; decoding errors are returned unchanged.
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
	cfg := Config{Name: name}
	err := decoder.DecodeContext(ctx, &cfg)
	if err != nil {
		return nil, err
	}
	return cfg, nil
}

// Config is the YAML-decoded configuration of a
// "serverless-spark-cancel-batch" tool.
type Config struct {
// Name is the tool name; newConfig seeds it from its name argument before decoding.
Name string `yaml:"name" validate:"required"`
// Kind is the tool kind declared in the configuration.
Kind string `yaml:"kind" validate:"required"`
// Source is the key Initialize uses to look up the backing source.
Source string `yaml:"source" validate:"required"`
// Description is optional; Initialize substitutes a built-in description when it is empty.
Description string `yaml:"description"`
// AuthRequired is copied onto the Tool, where Authorized checks it.
AuthRequired []string `yaml:"authRequired"`
}

// Compile-time assertion that Config satisfies the tools.ToolConfig interface.
var _ tools.ToolConfig = Config{}

// ToolConfigKind reports the kind identifier under which this tool is
// registered.
func (cfg Config) ToolConfigKind() string {
	return kind
}

// Initialize builds the Tool described by cfg.
//
// The source named by cfg.Source is looked up in srcs and must be a
// *serverlessspark.Source; an error is returned when it is missing or has a
// different type. When cfg.Description is empty a built-in description is
// used. The tool exposes a single string parameter, "operation".
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
	rawS, ok := srcs[cfg.Source]
	if !ok {
		return nil, fmt.Errorf("source %q not found", cfg.Source)
	}

	ds, ok := rawS.(*serverlessspark.Source)
	if !ok {
		return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind)
	}

	desc := cfg.Description
	if desc == "" {
		desc = "Cancels a running Serverless Spark (aka Dataproc Serverless) batch operation. Note that the batch state will not change immediately after the tool returns; it can take a minute or so for the cancellation to be reflected."
	}

	allParameters := tools.Parameters{
		tools.NewStringParameter("operation", "The name of the operation to cancel, e.g. for \"projects/my-project/locations/us-central1/operations/my-operation\", pass \"my-operation\""),
	}
	// Only the input schema is needed here; the second return value is discarded.
	inputSchema, _ := allParameters.McpManifest()

	mcpManifest := tools.McpManifest{
		Name:        cfg.Name,
		Description: desc,
		InputSchema: inputSchema,
	}

	return &Tool{
		Name: cfg.Name,
		Kind: kind,
		// Record the effective description (configured or built-in) so the
		// Tool's Description field agrees with both manifests.
		Description:  desc,
		Source:       ds,
		AuthRequired: cfg.AuthRequired,
		manifest:     tools.Manifest{Description: desc, Parameters: allParameters.Manifest()},
		mcpManifest:  mcpManifest,
		Parameters:   allParameters,
	}, nil
}

// Tool cancels Serverless Spark batch operations through the operations
// client of its Source. Instances are created by Config.Initialize.
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Description string `yaml:"description"`
// AuthRequired is the list of auth services checked by Authorized.
AuthRequired []string `yaml:"authRequired"`

// Source supplies the operations client, project and location used by Invoke.
Source *serverlessspark.Source

// manifest and mcpManifest are the precomputed values returned by Manifest
// and McpManifest.
manifest tools.Manifest
mcpManifest tools.McpManifest
// Parameters is the parameter set ParseParams validates input against.
Parameters tools.Parameters
}

// Invoke cancels the long-running operation named by the "operation"
// parameter.
//
// The parameter must be a short operation name without '/'. It is expanded to
// "projects/<project>/locations/<location>/operations/<operation>" using the
// project and location of the tool's source. A missing, non-string or empty
// value is rejected before any request is sent, because an empty name would
// produce a malformed resource name ending in "operations/".
//
// On success a confirmation string containing the short operation name is
// returned. accessToken is not used.
func (t *Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
	client, err := t.Source.GetOperationsClient(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get operations client: %w", err)
	}

	paramMap := params.AsMap()
	operation, ok := paramMap["operation"].(string)
	if !ok || operation == "" {
		return nil, fmt.Errorf("missing required parameter: operation")
	}

	// Only the trailing identifier is accepted; the fully qualified name is
	// always built from the source's own project and location.
	if strings.Contains(operation, "/") {
		return nil, fmt.Errorf("operation must be a short operation name without '/': %s", operation)
	}

	req := &longrunningpb.CancelOperationRequest{
		Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", t.Source.Project, t.Source.Location, operation),
	}

	if err := client.CancelOperation(ctx, req); err != nil {
		return nil, fmt.Errorf("failed to cancel operation: %w", err)
	}

	return fmt.Sprintf("Cancelled [%s].", operation), nil
}

// ParseParams parses data and claims against the tool's declared parameters by
// delegating to tools.ParseParams.
func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
	return tools.ParseParams(t.Parameters, data, claims)
}

// Manifest returns the manifest stored on the tool.
func (t *Tool) Manifest() tools.Manifest {
	return t.manifest
}

// McpManifest returns the MCP manifest stored on the tool.
func (t *Tool) McpManifest() tools.McpManifest {
	return t.mcpManifest
}

// Authorized reports the result of checking services against the tool's
// AuthRequired list via tools.IsAuthorized.
func (t *Tool) Authorized(services []string) bool {
	return tools.IsAuthorized(t.AuthRequired, services)
}

// RequiresClientAuthorization always reports false: this tool does not
// support client OAuth and relies on Application Default Credentials.
func (t *Tool) RequiresClientAuthorization() bool {
	return false
}
Loading
Loading