From dcc0bb673f9e57d638ea8a866f1ec223f7e9006b Mon Sep 17 00:00:00 2001 From: Juexin Wang Date: Wed, 14 Jan 2026 13:43:10 -0800 Subject: [PATCH] refactor(cloudgda): update to use google-cloud-go sdk types This introduces the geminidataanalytics Go SDK and replaces manually defined types with the official protocol buffers and wrappers across both tool and source execution. --- go.mod | 2 +- go.sum | 4 +- internal/sources/cloudgda/cloud_gda.go | 114 +++----- internal/sources/cloudgda/cloud_gda_test.go | 13 +- internal/tools/cloudgda/cloudgda.go | 71 ++++- internal/tools/cloudgda/cloudgda_test.go | 264 +++++++------------ internal/tools/cloudgda/types.go | 116 -------- tests/cloudgda/cloud_gda_integration_test.go | 181 ++++++------- 8 files changed, 284 insertions(+), 481 deletions(-) delete mode 100644 internal/tools/cloudgda/types.go diff --git a/go.mod b/go.mod index 9df36dbc2071..1c84e0c11092 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( cloud.google.com/go/dataplex v1.28.0 cloud.google.com/go/dataproc/v2 v2.15.0 cloud.google.com/go/firestore v1.20.0 - cloud.google.com/go/geminidataanalytics v0.3.0 + cloud.google.com/go/geminidataanalytics v0.5.0 cloud.google.com/go/logging v1.13.1 cloud.google.com/go/longrunning v0.7.0 cloud.google.com/go/spanner v1.86.1 diff --git a/go.sum b/go.sum index b1676bac848a..622acb84f943 100644 --- a/go.sum +++ b/go.sum @@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2 cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w= cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM= cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0= -cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI= -cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg= +cloud.google.com/go/geminidataanalytics v0.5.0 h1:+1usY81Cb+hE8BokpqCM7EgJtRCKzUKx7FvrHbT5hCA= +cloud.google.com/go/geminidataanalytics v0.5.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg= cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60= cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo= cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg= diff --git a/internal/sources/cloudgda/cloud_gda.go b/internal/sources/cloudgda/cloud_gda.go index 80e8df431c7c..4c977418c6a9 100644 --- a/internal/sources/cloudgda/cloud_gda.go +++ b/internal/sources/cloudgda/cloud_gda.go @@ -14,23 +14,23 @@ package cloudgda import ( - "bytes" "context" - "encoding/json" "fmt" - "io" - "net/http" + geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta" + "cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" "go.opentelemetry.io/otel/trace" "golang.org/x/oauth2" - "golang.org/x/oauth2/google" + "google.golang.org/api/option" ) const SourceType string = "cloud-gemini-data-analytics" -const Endpoint string = "https://geminidataanalytics.googleapis.com" + +// NewDataChatClient can be overridden for testing. +var NewDataChatClient = geminidataanalytics.NewDataChatClient // validate interface var _ sources.SourceConfig = Config{} @@ -67,29 +67,19 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("error in User Agent retrieval: %s", err) } - var client *http.Client - if r.UseClientOAuth { - client = &http.Client{ - Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport), - } - } else { - // Use Application Default Credentials - // Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA - creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform") - if err != nil { - return nil, fmt.Errorf("failed to find default credentials: %w", err) - } - baseClient := oauth2.NewClient(ctx, creds.TokenSource) - baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport) - client = baseClient - } - s := &Source{ Config: r, - Client: client, - BaseURL: Endpoint, userAgent: ua, } + + if !r.UseClientOAuth { + client, err := NewDataChatClient(ctx, option.WithUserAgent(ua)) + if err != nil { + return nil, fmt.Errorf("failed to create DataChatClient: %w", err) + } + s.Client = client + } + return s, nil } @@ -97,8 +87,7 @@ var _ sources.Source = &Source{} type Source struct { Config - Client *http.Client - BaseURL string + Client *geminidataanalytics.DataChatClient userAgent string } @@ -114,63 +103,36 @@ func (s *Source) GetProjectID() string { return s.ProjectID } -func (s *Source) GetBaseURL() string { - return s.BaseURL -} - -func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) { - if s.UseClientOAuth { - if accessToken == "" { - return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided") - } - token := &oauth2.Token{AccessToken: accessToken} - baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)) - baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport) - return baseClient, nil - } - return s.Client, nil -} - func (s *Source) UseClientAuthorization() bool { return s.UseClientOAuth } -func (s *Source) RunQuery(ctx context.Context, tokenStr string, bodyBytes []byte) (any, error) { - // The API endpoint itself always uses the "global" location. - apiLocation := "global" - apiParent := fmt.Sprintf("projects/%s/locations/%s", s.GetProjectID(), apiLocation) - apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", s.GetBaseURL(), apiParent) - - client, err := s.GetClient(ctx, tokenStr) - if err != nil { - return nil, fmt.Errorf("failed to get HTTP client: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") +func (s *Source) GetClient(ctx context.Context, tokenStr string) (*geminidataanalytics.DataChatClient, func(), error) { + if s.UseClientOAuth { + if tokenStr == "" { + return nil, nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided") + } + token := &oauth2.Token{AccessToken: tokenStr} + opts := []option.ClientOption{ + option.WithUserAgent(s.userAgent), + option.WithTokenSource(oauth2.StaticTokenSource(token)), + } - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to execute request: %w", err) + client, err := NewDataChatClient(ctx, opts...) + if err != nil { + return nil, nil, fmt.Errorf("failed to create per-request DataChatClient: %w", err) + } + return client, func() { client.Close() }, nil } - defer resp.Body.Close() + return s.Client, func() {}, nil +} - respBody, err := io.ReadAll(resp.Body) +func (s *Source) RunQuery(ctx context.Context, tokenStr string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) { + client, cleanup, err := s.GetClient(ctx, tokenStr) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody)) - } - - var result map[string]any - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) + return nil, err } + defer cleanup() - return result, nil + return client.QueryData(ctx, req) } diff --git a/internal/sources/cloudgda/cloud_gda_test.go b/internal/sources/cloudgda/cloud_gda_test.go index 6ec771f60120..b081d84753c3 100644 --- a/internal/sources/cloudgda/cloud_gda_test.go +++ b/internal/sources/cloudgda/cloud_gda_test.go @@ -172,11 +172,9 @@ func TestInitialize(t *testing.T) { if gdaSrc.Client == nil && !tc.wantClientOAuth { t.Fatal("expected non-nil HTTP client for ADC, got nil") } - // When client OAuth is true, the source's client should be initialized with a base HTTP client - // that includes the user agent round tripper, but not the OAuth token. The token-aware - // client is created by GetClient. - if gdaSrc.Client == nil && tc.wantClientOAuth { - t.Fatal("expected non-nil HTTP client for client OAuth config, got nil") + // When client OAuth is true, the source's client should be nil. + if gdaSrc.Client != nil && tc.wantClientOAuth { + t.Fatal("expected nil HTTP client for client OAuth config, got non-nil") } // Test UseClientAuthorization method @@ -186,15 +184,16 @@ func TestInitialize(t *testing.T) { // Test GetClient with accessToken for client OAuth scenarios if tc.wantClientOAuth { - client, err := gdaSrc.GetClient(ctx, "dummy-token") + client, cleanup, err := gdaSrc.GetClient(ctx, "dummy-token") if err != nil { t.Fatalf("GetClient with token failed: %v", err) } + defer cleanup() if client == nil { t.Fatal("expected non-nil HTTP client from GetClient with token, got nil") } // Ensure passing empty token with UseClientOAuth enabled returns error - _, err = gdaSrc.GetClient(ctx, "") + _, _, err = gdaSrc.GetClient(ctx, "") if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" { t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err) } diff --git a/internal/tools/cloudgda/cloudgda.go b/internal/tools/cloudgda/cloudgda.go index 14862909b43f..be351a1e1ca3 100644 --- a/internal/tools/cloudgda/cloudgda.go +++ b/internal/tools/cloudgda/cloudgda.go @@ -20,12 +20,14 @@ import ( "fmt" "net/http" + "cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/protobuf/encoding/protojson" ) const resourceType string = "cloud-gemini-data-analytics-query" @@ -62,7 +64,49 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { GetProjectID() string UseClientAuthorization() bool - RunQuery(context.Context, string, []byte) (any, error) + RunQuery(context.Context, string, *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) +} + +// QueryDataContext wraps geminidataanalyticspb.QueryDataContext to support YAML decoding via protojson. +type QueryDataContext struct { + *geminidataanalyticspb.QueryDataContext +} + +func (q *QueryDataContext) UnmarshalYAML(b []byte) error { + var raw map[string]any + if err := yaml.Unmarshal(b, &raw); err != nil { + return fmt.Errorf("failed to unmarshal context from yaml: %w", err) + } + jsonBytes, err := json.Marshal(raw) + if err != nil { + return fmt.Errorf("failed to marshal context map: %w", err) + } + q.QueryDataContext = &geminidataanalyticspb.QueryDataContext{} + if err := protojson.Unmarshal(jsonBytes, q.QueryDataContext); err != nil { + return fmt.Errorf("failed to unmarshal context to proto: %w", err) + } + return nil +} + +// GenerationOptions wraps geminidataanalyticspb.GenerationOptions to support YAML decoding via protojson. +type GenerationOptions struct { + *geminidataanalyticspb.GenerationOptions +} + +func (g *GenerationOptions) UnmarshalYAML(b []byte) error { + var raw map[string]any + if err := yaml.Unmarshal(b, &raw); err != nil { + return fmt.Errorf("failed to unmarshal generation options from yaml: %w", err) + } + jsonBytes, err := json.Marshal(raw) + if err != nil { + return fmt.Errorf("failed to marshal generation options map: %w", err) + } + g.GenerationOptions = &geminidataanalyticspb.GenerationOptions{} + if err := protojson.Unmarshal(jsonBytes, g.GenerationOptions); err != nil { + return fmt.Errorf("failed to unmarshal generation options to proto: %w", err) + } + return nil } type Config struct { @@ -99,12 +143,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - return Tool{ + t := Tool{ Config: cfg, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, - }, nil + } + + return t, nil } // validate interface @@ -146,19 +192,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // The parent in the request payload uses the tool's configured location. payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location) - payload := &QueryDataRequest{ - Parent: payloadParent, - Prompt: query, - Context: t.Context, - GenerationOptions: t.GenerationOptions, + req := &geminidataanalyticspb.QueryDataRequest{ + Parent: payloadParent, + Prompt: query, } - bodyBytes, err := json.Marshal(payload) - if err != nil { - return nil, util.NewClientServerError("failed to marshal request payload", http.StatusInternalServerError, err) + if t.Context != nil { + req.Context = t.Context.QueryDataContext + } + + if t.GenerationOptions != nil { + req.GenerationOptions = t.GenerationOptions.GenerationOptions } - resp, err := source.RunQuery(ctx, tokenStr, bodyBytes) + resp, err := source.RunQuery(ctx, tokenStr, req) if err != nil { return nil, util.ProcessGcpError(err) } diff --git a/internal/tools/cloudgda/cloudgda_test.go b/internal/tools/cloudgda/cloudgda_test.go index d5e73658ea3c..73b29ccf174b 100644 --- a/internal/tools/cloudgda/cloudgda_test.go +++ b/internal/tools/cloudgda/cloudgda_test.go @@ -16,18 +16,15 @@ package cloudgda_test import ( "context" - "encoding/json" "fmt" - "io" - "net/http" - "net/http/httptest" "testing" + "cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/googleapis/genai-toolbox/internal/server" "github.com/googleapis/genai-toolbox/internal/server/resources" "github.com/googleapis/genai-toolbox/internal/sources" - cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/internal/tools" cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" @@ -77,23 +74,29 @@ func TestParseFromYaml(t *testing.T) { Location: "us-central1", AuthRequired: []string{}, Context: &cloudgdatool.QueryDataContext{ - DatasourceReferences: &cloudgdatool.DatasourceReferences{ - SpannerReference: &cloudgdatool.SpannerReference{ - DatabaseReference: &cloudgdatool.SpannerDatabaseReference{ - ProjectID: "cloud-db-nl2sql", - Region: "us-central1", - InstanceID: "evalbench", - DatabaseID: "financial", - Engine: cloudgdatool.SpannerEngineGoogleSQL, - }, - AgentContextReference: &cloudgdatool.AgentContextReference{ - ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates", + QueryDataContext: &geminidataanalyticspb.QueryDataContext{ + DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{ + References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{ + SpannerReference: &geminidataanalyticspb.SpannerReference{ + DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{ + ProjectId: "cloud-db-nl2sql", + Region: "us-central1", + InstanceId: "evalbench", + DatabaseId: "financial", + Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL, + }, + AgentContextReference: &geminidataanalyticspb.AgentContextReference{ + ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates", + }, + }, }, }, }, }, GenerationOptions: &cloudgdatool.GenerationOptions{ - GenerateQueryResult: true, + GenerationOptions: &geminidataanalyticspb.GenerationOptions{ + GenerateQueryResult: true, + }, }, }, }, @@ -107,68 +110,63 @@ func TestParseFromYaml(t *testing.T) { if err != nil { t.Fatalf("unable to unmarshal: %s", err) } - if !cmp.Equal(tc.want, got) { + if !cmp.Equal(tc.want, got, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataContext{}, geminidataanalyticspb.DatasourceReferences{}, geminidataanalyticspb.SpannerReference{}, geminidataanalyticspb.SpannerDatabaseReference{}, geminidataanalyticspb.AgentContextReference{}, geminidataanalyticspb.GenerationOptions{}, geminidataanalyticspb.DatasourceReferences_SpannerReference{})) { t.Fatalf("incorrect parse: want %v, got %v", tc.want, got) } }) } } -// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header. -type authRoundTripper struct { - Token string - Next http.RoundTripper +// fakeSource implements the compatibleSource interface for testing. +type fakeSource struct { + projectID string + useClientOAuth bool + expectedQuery string + expectedParent string + response *geminidataanalyticspb.QueryDataResponse } -func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - newReq := *req - newReq.Header = make(http.Header) - for k, v := range req.Header { - newReq.Header[k] = v - } - newReq.Header.Set("Authorization", rt.Token) - if rt.Next == nil { - return http.DefaultTransport.RoundTrip(&newReq) - } - return rt.Next.RoundTrip(&newReq) +func (f *fakeSource) GetProjectID() string { + return f.projectID +} + +func (f *fakeSource) UseClientAuthorization() bool { + return f.useClientOAuth +} + +func (f *fakeSource) SourceType() string { + return "cloud-gemini-data-analytics" } -type mockSource struct { - Type string - client *http.Client // Can be used to inject a specific client - baseURL string // BaseURL is needed to implement sources.Source.BaseURL - config cloudgdasrc.Config // to return from ToConfig +func (f *fakeSource) ToConfig() sources.SourceConfig { + return nil } -func (m *mockSource) SourceType() string { return m.Type } -func (m *mockSource) ToConfig() sources.SourceConfig { return m.config } -func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) { - if m.client != nil { - return m.client, nil +func (f *fakeSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) { + return f, nil +} + +func (f *fakeSource) RunQuery(ctx context.Context, token string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) { + if req.Prompt != f.expectedQuery { + return nil, fmt.Errorf("unexpected query: got %q, want %q", req.Prompt, f.expectedQuery) } - // Default client for testing if not explicitly set - transport := &http.Transport{} - authTransport := &authRoundTripper{ - Token: "Bearer test-access-token", // Dummy token - Next: transport, + if req.Parent != f.expectedParent { + return nil, fmt.Errorf("unexpected parent: got %q, want %q", req.Parent, f.expectedParent) } - return &http.Client{Transport: authTransport}, nil -} -func (m *mockSource) UseClientAuthorization() bool { return false } -func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) { - return m, nil + // Basic validation of context/options could be added here if needed, + // but the test case mainly checks if they are passed correctly via successful invocation. + + return f.response, nil } -func (m *mockSource) BaseURL() string { return m.baseURL } func TestInitialize(t *testing.T) { t.Parallel() + // Minimal fake source + fake := &fakeSource{projectID: "test-project"} + srcs := map[string]sources.Source{ - "gda-api-source": &cloudgdasrc.Source{ - Config: cloudgdasrc.Config{Name: "gda-api-source", Type: cloudgdasrc.SourceType, ProjectID: "test-project"}, - Client: &http.Client{}, - BaseURL: cloudgdasrc.Endpoint, - }, + "gda-api-source": fake, } tcs := []struct { @@ -187,9 +185,7 @@ func TestInitialize(t *testing.T) { }, } - // Add an incompatible source for testing - srcs["incompatible-source"] = &mockSource{Type: "another-type"} - + // No incompatible source for testing needed with fakeSource for _, tc := range tcs { tc := tc t.Run(tc.desc, func(t *testing.T) { @@ -206,92 +202,27 @@ func TestInitialize(t *testing.T) { func TestInvoke(t *testing.T) { t.Parallel() - // Mock the HTTP client and server for Invoke testing - serverMux := http.NewServeMux() - // Update expected URL path to include the location "us-central1" - serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - t.Errorf("expected POST method, got %s", r.Method) - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - if r.Header.Get("Content-Type") != "application/json" { - t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) - http.Error(w, "Bad request", http.StatusBadRequest) - return - } - - // Read and unmarshal the request body - bodyBytes, err := io.ReadAll(r.Body) - if err != nil { - t.Errorf("failed to read request body: %v", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - var reqPayload cloudgdatool.QueryDataRequest - if err := json.Unmarshal(bodyBytes, &reqPayload); err != nil { - t.Errorf("failed to unmarshal request payload: %v", err) - http.Error(w, "Bad request", http.StatusBadRequest) - return - } - // Verify expected fields - if r.Header.Get("Authorization") == "" { - t.Errorf("expected Authorization header, got empty") - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } - if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" { - t.Errorf("unexpected prompt: %s", reqPayload.Prompt) - } + projectID := "test-project" + location := "us-central1" + query := "How many accounts who have region in Prague are eligible for loans?" + expectedParent := fmt.Sprintf("projects/%s/locations/%s", projectID, location) - // Verify payload's parent uses the tool's configured location - if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") { - t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1")) - } - - // Verify context from config - if reqPayload.Context == nil || - reqPayload.Context.DatasourceReferences == nil || - reqPayload.Context.DatasourceReferences.SpannerReference == nil || - reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil || - reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" { - t.Errorf("unexpected context: %v", reqPayload.Context) - } - - // Verify generation options from config - if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult { - t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions) - } - - // Simulate a successful response - resp := map[string]any{ - "queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;", - "naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.", - } - _ = json.NewEncoder(w).Encode(resp) - }) - - mockServer := httptest.NewServer(serverMux) - defer mockServer.Close() - - ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent") - - // Create an authenticated client that uses the mock server - authTransport := &authRoundTripper{ - Token: "Bearer test-access-token", - Next: mockServer.Client().Transport, + // Prepare expected response + expectedResp := &geminidataanalyticspb.QueryDataResponse{ + GeneratedQuery: "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;", + NaturalLanguageAnswer: "There are 5 accounts in Prague eligible for loans.", } - authClient := &http.Client{Transport: authTransport} - // Create a real cloudgdasrc.Source but inject the authenticated client - mockGdaSource := &cloudgdasrc.Source{ - Config: cloudgdasrc.Config{Name: "mock-gda-source", Type: cloudgdasrc.SourceType, ProjectID: "test-project"}, - Client: authClient, - BaseURL: mockServer.URL, + fake := &fakeSource{ + projectID: projectID, + expectedQuery: query, + expectedParent: expectedParent, + response: expectedResp, } + srcs := map[string]sources.Source{ - "mock-gda-source": mockGdaSource, + "mock-gda-source": fake, } // Initialize the tool config with context @@ -300,25 +231,31 @@ func TestInvoke(t *testing.T) { Type: "cloud-gemini-data-analytics-query", Source: "mock-gda-source", Description: "Query Gemini Data Analytics", - Location: "us-central1", // Set location for the test + Location: location, Context: &cloudgdatool.QueryDataContext{ - DatasourceReferences: &cloudgdatool.DatasourceReferences{ - SpannerReference: &cloudgdatool.SpannerReference{ - DatabaseReference: &cloudgdatool.SpannerDatabaseReference{ - ProjectID: "cloud-db-nl2sql", - Region: "us-central1", - InstanceID: "evalbench", - DatabaseID: "financial", - Engine: cloudgdatool.SpannerEngineGoogleSQL, - }, - AgentContextReference: &cloudgdatool.AgentContextReference{ - ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates", + QueryDataContext: &geminidataanalyticspb.QueryDataContext{ + DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{ + References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{ + SpannerReference: &geminidataanalyticspb.SpannerReference{ + DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{ + ProjectId: "cloud-db-nl2sql", + Region: "us-central1", + InstanceId: "evalbench", + DatabaseId: "financial", + Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL, + }, + AgentContextReference: &geminidataanalyticspb.AgentContextReference{ + ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates", + }, + }, }, }, }, }, GenerationOptions: &cloudgdatool.GenerationOptions{ - GenerateQueryResult: true, + GenerationOptions: &geminidataanalyticspb.GenerationOptions{ + GenerateQueryResult: true, + }, }, } @@ -329,24 +266,25 @@ func TestInvoke(t *testing.T) { // Prepare parameters for invocation - ONLY query params := parameters.ParamValues{ - {Name: "query", Value: "How many accounts who have region in Prague are eligible for loans?"}, + {Name: "query", Value: query}, } resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil) + ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent") + // Invoke the tool - result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client + result, err := tool.Invoke(ctx, resourceMgr, params, "") if err != nil { t.Fatalf("tool invocation failed: %v", err) } - // Validate the result - expectedResult := map[string]any{ - "queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;", - "naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.", + gotResp, ok := result.(*geminidataanalyticspb.QueryDataResponse) + if !ok { + t.Fatalf("expected result type *geminidataanalyticspb.QueryDataResponse, got %T", result) } - if !cmp.Equal(expectedResult, result) { - t.Errorf("unexpected result: got %v, want %v", result, expectedResult) + if diff := cmp.Diff(expectedResp, gotResp, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataResponse{})); diff != "" { + t.Errorf("unexpected result mismatch (-want +got):\n%s", diff) } } diff --git a/internal/tools/cloudgda/types.go b/internal/tools/cloudgda/types.go deleted file mode 100644 index 8e82cb50c226..000000000000 --- a/internal/tools/cloudgda/types.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cloudgda - -// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto - -// QueryDataRequest represents the JSON body for the queryData API -type QueryDataRequest struct { - Parent string `json:"parent"` - Prompt string `json:"prompt"` - Context *QueryDataContext `json:"context,omitempty"` - GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"` -} - -// QueryDataContext reflects the proto definition for the query context. -type QueryDataContext struct { - DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"` -} - -// DatasourceReferences reflects the proto definition for datasource references, using a oneof. -type DatasourceReferences struct { - SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"` - AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"` - CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"` -} - -// SpannerReference reflects the proto definition for Spanner database reference. -type SpannerReference struct { - DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"` - AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"` -} - -// SpannerDatabaseReference reflects the proto definition for a Spanner database reference. -type SpannerDatabaseReference struct { - Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"` - ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"` - Region string `json:"region,omitempty" yaml:"region,omitempty"` - InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"` - DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"` - TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"` -} - -// SpannerEngine represents the engine of the Spanner instance. -type SpannerEngine string - -const ( - SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED" - SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL" - SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL" -) - -// AlloyDBReference reflects the proto definition for an AlloyDB database reference. -type AlloyDBReference struct { - DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"` - AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"` -} - -// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference. -type AlloyDBDatabaseReference struct { - ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"` - Region string `json:"region,omitempty" yaml:"region,omitempty"` - ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"` - InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"` - DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"` - TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"` -} - -// CloudSQLReference reflects the proto definition for a Cloud SQL database reference. -type CloudSQLReference struct { - DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"` - AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"` -} - -// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference. -type CloudSQLDatabaseReference struct { - Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"` - ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"` - Region string `json:"region,omitempty" yaml:"region,omitempty"` - InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"` - DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"` - TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"` -} - -// CloudSQLEngine represents the engine of the Cloud SQL instance. -type CloudSQLEngine string - -const ( - CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED" - CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL" - CloudSQLEngineMySQL CloudSQLEngine = "MYSQL" -) - -// AgentContextReference reflects the proto definition for agent context. -type AgentContextReference struct { - ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"` -} - -// GenerationOptions reflects the proto definition for generation options. -type GenerationOptions struct { - GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"` - GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"` - GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"` - GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"` -} diff --git a/tests/cloudgda/cloud_gda_integration_test.go b/tests/cloudgda/cloud_gda_integration_test.go index 24c0cab1cbe1..557f80bdd93e 100644 --- a/tests/cloudgda/cloud_gda_integration_test.go +++ b/tests/cloudgda/cloud_gda_integration_test.go @@ -18,78 +18,75 @@ import ( "bytes" "context" "encoding/json" + "fmt" + "net" "net/http" - "net/http/httptest" - "net/url" "regexp" "strings" "testing" "time" + geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta" + "cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" + source "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" "github.com/googleapis/genai-toolbox/tests" + "google.golang.org/api/option" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" ) var ( cloudGdaToolType = "cloud-gemini-data-analytics-query" ) -type cloudGdaTransport struct { - transport http.RoundTripper - url *url.URL -} - -func (t *cloudGdaTransport) RoundTrip(req *http.Request) (*http.Response, error) { - if strings.HasPrefix(req.URL.String(), "https://geminidataanalytics.googleapis.com") { - req.URL.Scheme = t.url.Scheme - req.URL.Host = t.url.Host - } - return t.transport.RoundTrip(req) -} - -type masterHandler struct { +type mockDataChatServer struct { + geminidataanalyticspb.UnimplementedDataChatServiceServer t *testing.T } -func (h *masterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if !strings.Contains(r.UserAgent(), "genai-toolbox/") { - h.t.Errorf("User-Agent header not found") - } - - if r.Method != http.MethodPost { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - // Verify URL structure - // Expected: /v1beta/projects/{project}/locations/global:queryData - if !strings.Contains(r.URL.Path, ":queryData") || !strings.Contains(r.URL.Path, "locations/global") { - h.t.Errorf("unexpected URL path: %s", r.URL.Path) - http.Error(w, "Not found", http.StatusNotFound) - return - } - - var reqBody cloudgda.QueryDataRequest - if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { - h.t.Fatalf("failed to decode request body: %v", err) - } - - if reqBody.Prompt == "" { - http.Error(w, "missing prompt", http.StatusBadRequest) - return +func (s *mockDataChatServer) QueryData(ctx context.Context, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) { + if req.Prompt == "" { + s.t.Errorf("missing prompt") + return nil, fmt.Errorf("missing prompt") } - response := map[string]any{ - "queryResult": "SELECT * FROM table;", - "naturalLanguageAnswer": "Here is the answer.", - } + return &geminidataanalyticspb.QueryDataResponse{ + GeneratedQuery: "SELECT * FROM table;", + NaturalLanguageAnswer: "Here is the answer.", + }, nil +} - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(response); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) +func getCloudGdaToolsConfig() map[string]any { + return map[string]any{ + "sources": map[string]any{ + "my-gda-source": map[string]any{ + "type": "cloud-gemini-data-analytics", + "projectId": "test-project", + }, + }, + "tools": map[string]any{ + "cloud-gda-query": map[string]any{ + "type": cloudGdaToolType, + "source": "my-gda-source", + "description": "Test GDA Tool", + "location": "us-central1", + "context": map[string]any{ + "datasourceReferences": map[string]any{ + "spannerReference": map[string]any{ + "databaseReference": map[string]any{ + "projectId": "test-project", + "instanceId": "test-instance", + "databaseId": "test-db", + "engine": "GOOGLE_SQL", + }, + }, + }, + }, + }, + }, } } @@ -97,27 +94,38 @@ func TestCloudGdaToolEndpoints(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - handler := &masterHandler{t: t} - server := httptest.NewServer(handler) - defer server.Close() - - serverURL, err := url.Parse(server.URL) + // Start a gRPC server + lis, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { - t.Fatalf("failed to parse server URL: %v", err) + t.Fatalf("failed to listen: %v", err) + } + s := grpc.NewServer() + geminidataanalyticspb.RegisterDataChatServiceServer(s, &mockDataChatServer{t: t}) + go func() { + if err := s.Serve(lis); err != nil { + // This might happen on strict shutdown, log if unexpected + t.Logf("server executed: %v", err) + } + }() + defer s.Stop() + + // Configure toolbox to use the gRPC server + endpoint := lis.Addr().String() + + // Override client creation + origFunc := source.NewDataChatClient + defer func() { + source.NewDataChatClient = origFunc + }() + + source.NewDataChatClient = func(ctx context.Context, opts ...option.ClientOption) (*geminidataanalytics.DataChatClient, error) { + opts = append(opts, + option.WithEndpoint(endpoint), + option.WithoutAuthentication(), + option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials()))) + return origFunc(ctx, opts...) } - originalTransport := http.DefaultClient.Transport - if originalTransport == nil { - originalTransport = http.DefaultTransport - } - http.DefaultClient.Transport = &cloudGdaTransport{ - transport: originalTransport, - url: serverURL, - } - t.Cleanup(func() { - http.DefaultClient.Transport = originalTransport - }) - var args []string toolsFile := getCloudGdaToolsConfig() cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) @@ -156,7 +164,7 @@ func TestCloudGdaToolEndpoints(t *testing.T) { // 2. RunToolInvokeParametersTest params := []byte(`{"query": "test question"}`) - tests.RunToolInvokeParametersTest(t, toolName, params, "\"queryResult\":\"SELECT * FROM table;\"") + tests.RunToolInvokeParametersTest(t, toolName, params, "\"generated_query\":\"SELECT * FROM table;\"") // 3. Manual MCP Tool Call Test // Initialize MCP session @@ -196,38 +204,3 @@ func TestCloudGdaToolEndpoints(t *testing.T) { t.Errorf("MCP response does not contain expected query result: %s", respStr) } } - -func getCloudGdaToolsConfig() map[string]any { - // Mocked responses and a dummy `projectId` are used in this integration - // test due to limited project-specific allowlisting. API functionality is - // verified via internal monitoring; this test specifically validates the - // integration flow between the source and the tool. - return map[string]any{ - "sources": map[string]any{ - "my-gda-source": map[string]any{ - "type": "cloud-gemini-data-analytics", - "projectId": "test-project", - }, - }, - "tools": map[string]any{ - "cloud-gda-query": map[string]any{ - "type": cloudGdaToolType, - "source": "my-gda-source", - "description": "Test GDA Tool", - "location": "us-central1", - "context": map[string]any{ - "datasourceReferences": map[string]any{ - "spannerReference": map[string]any{ - "databaseReference": map[string]any{ - "projectId": "test-project", - "instanceId": "test-instance", - "databaseId": "test-db", - "engine": "GOOGLE_SQL", - }, - }, - }, - }, - }, - }, - } -}