Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ require (
cloud.google.com/go/dataplex v1.28.0
cloud.google.com/go/dataproc/v2 v2.15.0
cloud.google.com/go/firestore v1.20.0
cloud.google.com/go/geminidataanalytics v0.3.0
cloud.google.com/go/geminidataanalytics v0.5.0
cloud.google.com/go/logging v1.13.1
cloud.google.com/go/longrunning v0.7.0
cloud.google.com/go/spanner v1.86.1
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2
cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w=
cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM=
cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0=
cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI=
cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
cloud.google.com/go/geminidataanalytics v0.5.0 h1:+1usY81Cb+hE8BokpqCM7EgJtRCKzUKx7FvrHbT5hCA=
cloud.google.com/go/geminidataanalytics v0.5.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60=
cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo=
cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg=
Expand Down
114 changes: 38 additions & 76 deletions internal/sources/cloudgda/cloud_gda.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,23 @@
package cloudgda

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"

geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
)

const SourceType string = "cloud-gemini-data-analytics"
const Endpoint string = "https://geminidataanalytics.googleapis.com"

// NewDataChatClient can be overridden for testing.
var NewDataChatClient = geminidataanalytics.NewDataChatClient

// validate interface
var _ sources.SourceConfig = Config{}
Expand Down Expand Up @@ -67,38 +67,27 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
}

var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
// Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}

s := &Source{
Config: r,
Client: client,
BaseURL: Endpoint,
userAgent: ua,
}

if !r.UseClientOAuth {
client, err := NewDataChatClient(ctx, option.WithUserAgent(ua))
if err != nil {
return nil, fmt.Errorf("failed to create DataChatClient: %w", err)
}
s.Client = client
}

return s, nil
}

var _ sources.Source = &Source{}

type Source struct {
Config
Client *http.Client
BaseURL string
Client *geminidataanalytics.DataChatClient
userAgent string
}

Expand All @@ -114,63 +103,36 @@ func (s *Source) GetProjectID() string {
return s.ProjectID
}

func (s *Source) GetBaseURL() string {
return s.BaseURL
}

func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
if s.UseClientOAuth {
if accessToken == "" {
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: accessToken}
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport)
return baseClient, nil
}
return s.Client, nil
}

func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}

func (s *Source) RunQuery(ctx context.Context, tokenStr string, bodyBytes []byte) (any, error) {
// The API endpoint itself always uses the "global" location.
apiLocation := "global"
apiParent := fmt.Sprintf("projects/%s/locations/%s", s.GetProjectID(), apiLocation)
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", s.GetBaseURL(), apiParent)

client, err := s.GetClient(ctx, tokenStr)
if err != nil {
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
func (s *Source) GetClient(ctx context.Context, tokenStr string) (*geminidataanalytics.DataChatClient, func(), error) {
if s.UseClientOAuth {
if tokenStr == "" {
return nil, nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: tokenStr}
opts := []option.ClientOption{
option.WithUserAgent(s.userAgent),
option.WithTokenSource(oauth2.StaticTokenSource(token)),
}

resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request: %w", err)
client, err := NewDataChatClient(ctx, opts...)
if err != nil {
return nil, nil, fmt.Errorf("failed to create per-request DataChatClient: %w", err)
}
return client, func() { client.Close() }, nil
}
defer resp.Body.Close()
return s.Client, func() {}, nil
}

respBody, err := io.ReadAll(resp.Body)
func (s *Source) RunQuery(ctx context.Context, tokenStr string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
client, cleanup, err := s.GetClient(ctx, tokenStr)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
}

var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
return nil, err
}
defer cleanup()

return result, nil
return client.QueryData(ctx, req)
}
13 changes: 6 additions & 7 deletions internal/sources/cloudgda/cloud_gda_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,9 @@ func TestInitialize(t *testing.T) {
if gdaSrc.Client == nil && !tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for ADC, got nil")
}
// When client OAuth is true, the source's client should be initialized with a base HTTP client
// that includes the user agent round tripper, but not the OAuth token. The token-aware
// client is created by GetClient.
if gdaSrc.Client == nil && tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for client OAuth config, got nil")
// When client OAuth is true, the source's client should be nil.
if gdaSrc.Client != nil && tc.wantClientOAuth {
t.Fatal("expected nil HTTP client for client OAuth config, got non-nil")
}

// Test UseClientAuthorization method
Expand All @@ -186,15 +184,16 @@ func TestInitialize(t *testing.T) {

// Test GetClient with accessToken for client OAuth scenarios
if tc.wantClientOAuth {
client, err := gdaSrc.GetClient(ctx, "dummy-token")
client, cleanup, err := gdaSrc.GetClient(ctx, "dummy-token")
if err != nil {
t.Fatalf("GetClient with token failed: %v", err)
}
defer cleanup()
if client == nil {
t.Fatal("expected non-nil HTTP client from GetClient with token, got nil")
}
// Ensure passing empty token with UseClientOAuth enabled returns error
_, err = gdaSrc.GetClient(ctx, "")
_, _, err = gdaSrc.GetClient(ctx, "")
if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" {
t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err)
}
Expand Down
71 changes: 59 additions & 12 deletions internal/tools/cloudgda/cloudgda.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ import (
"fmt"
"net/http"

"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/protobuf/encoding/protojson"
)

const resourceType string = "cloud-gemini-data-analytics-query"
Expand Down Expand Up @@ -62,7 +64,49 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetProjectID() string
UseClientAuthorization() bool
RunQuery(context.Context, string, []byte) (any, error)
RunQuery(context.Context, string, *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error)
}

// QueryDataContext wraps geminidataanalyticspb.QueryDataContext to support YAML decoding via protojson.
type QueryDataContext struct {
*geminidataanalyticspb.QueryDataContext
}

func (q *QueryDataContext) UnmarshalYAML(b []byte) error {
var raw map[string]any
if err := yaml.Unmarshal(b, &raw); err != nil {
return fmt.Errorf("failed to unmarshal context from yaml: %w", err)
}
jsonBytes, err := json.Marshal(raw)
if err != nil {
return fmt.Errorf("failed to marshal context map: %w", err)
}
q.QueryDataContext = &geminidataanalyticspb.QueryDataContext{}
if err := protojson.Unmarshal(jsonBytes, q.QueryDataContext); err != nil {
return fmt.Errorf("failed to unmarshal context to proto: %w", err)
}
return nil
}

// GenerationOptions wraps geminidataanalyticspb.GenerationOptions to support YAML decoding via protojson.
type GenerationOptions struct {
*geminidataanalyticspb.GenerationOptions
}

func (g *GenerationOptions) UnmarshalYAML(b []byte) error {
var raw map[string]any
if err := yaml.Unmarshal(b, &raw); err != nil {
return fmt.Errorf("failed to unmarshal generation options from yaml: %w", err)
}
jsonBytes, err := json.Marshal(raw)
if err != nil {
return fmt.Errorf("failed to marshal generation options map: %w", err)
}
g.GenerationOptions = &geminidataanalyticspb.GenerationOptions{}
if err := protojson.Unmarshal(jsonBytes, g.GenerationOptions); err != nil {
return fmt.Errorf("failed to unmarshal generation options to proto: %w", err)
}
return nil
}

type Config struct {
Expand Down Expand Up @@ -99,12 +143,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)

return Tool{
t := Tool{
Config: cfg,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}, nil
}

return t, nil
}

// validate interface
Expand Down Expand Up @@ -146,19 +192,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// The parent in the request payload uses the tool's configured location.
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)

payload := &QueryDataRequest{
Parent: payloadParent,
Prompt: query,
Context: t.Context,
GenerationOptions: t.GenerationOptions,
req := &geminidataanalyticspb.QueryDataRequest{
Parent: payloadParent,
Prompt: query,
}

bodyBytes, err := json.Marshal(payload)
if err != nil {
return nil, util.NewClientServerError("failed to marshal request payload", http.StatusInternalServerError, err)
if t.Context != nil {
req.Context = t.Context.QueryDataContext
}

if t.GenerationOptions != nil {
req.GenerationOptions = t.GenerationOptions.GenerationOptions
}

resp, err := source.RunQuery(ctx, tokenStr, bodyBytes)
resp, err := source.RunQuery(ctx, tokenStr, req)
if err != nil {
return nil, util.ProcessGcpError(err)
}
Expand Down
Loading
Loading