diff --git a/cmd/root.go b/cmd/root.go
index 5ddc1b88e4af..215c19bc403c 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -98,15 +98,15 @@ import (
 	_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/http"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardelement"
-	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerconversationalanalytics"
+	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdeleteprojectfile"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdevmode"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiondatabases"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnections"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectionschemas"
-	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontables"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontablecolumns"
+	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontables"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdashboards"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdimensions"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetexplores"
diff --git a/internal/sources/bigquery/bigquery.go b/internal/sources/bigquery/bigquery.go
index d9729e12d0bb..5fe43e58bb19 100644
--- a/internal/sources/bigquery/bigquery.go
+++ b/internal/sources/bigquery/bigquery.go
@@ -72,13 +72,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
 
 type Config struct {
 	// BigQuery configs
-	Name            string   `yaml:"name" validate:"required"`
-	Kind            string   `yaml:"kind" validate:"required"`
-	Project         string   `yaml:"project" validate:"required"`
-	Location        string   `yaml:"location"`
-	WriteMode       string   `yaml:"writeMode"`
-	AllowedDatasets []string `yaml:"allowedDatasets"`
-	UseClientOAuth  bool     `yaml:"useClientOAuth"`
+	Name                      string   `yaml:"name" validate:"required"`
+	Kind                      string   `yaml:"kind" validate:"required"`
+	Project                   string   `yaml:"project" validate:"required"`
+	Location                  string   `yaml:"location"`
+	WriteMode                 string   `yaml:"writeMode"`
+	AllowedDatasets           []string `yaml:"allowedDatasets"`
+	UseClientOAuth            bool     `yaml:"useClientOAuth"`
 	ImpersonateServiceAccount string `yaml:"impersonateServiceAccount"`
 }
 
@@ -86,13 +86,16 @@ func (r Config) SourceConfigKind() string {
 	// Returns BigQuery source kind
 	return SourceKind
 }
-
 func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
 	if r.WriteMode == "" {
 		r.WriteMode = WriteModeAllowed
 	}
 
 	if r.WriteMode == WriteModeProtected && r.UseClientOAuth {
+		// Protected mode only allows write operations to the session's temporary datasets.
+		// When using client OAuth, a new session is created every time a
+		// BigQuery tool is invoked, so no session data can be preserved
+		// as protected mode requires.
 		return nil, fmt.Errorf("writeMode 'protected' cannot be used with useClientOAuth 'true'")
 	}
 
@@ -106,17 +109,38 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
 	var clientCreator BigqueryClientCreator
 	var err error
 
+	s := &Source{
+		Name:                      r.Name,
+		Kind:                      SourceKind,
+		Project:                   r.Project,
+		Location:                  r.Location,
+		Client:                    client,
+		RestService:               restService,
+		TokenSource:               tokenSource,
+		MaxQueryResultRows:        50,
+		WriteMode:                 r.WriteMode,
+		UseClientOAuth:            r.UseClientOAuth,
+		ClientCreator:             clientCreator,
+		ImpersonateServiceAccount: r.ImpersonateServiceAccount,
+	}
+
 	if r.UseClientOAuth {
-		clientCreator, err = newBigQueryClientCreator(ctx, tracer, r.Project, r.Location, r.Name)
+		// use client OAuth
+		baseClientCreator, err := newBigQueryClientCreator(ctx, tracer, r.Project, r.Location, r.Name)
 		if err != nil {
 			return nil, fmt.Errorf("error constructing client creator: %w", err)
 		}
+		setupClientCaching(s, baseClientCreator)
+
 	} else {
 		// Initializes a BigQuery Google SQL source
 		client, restService, tokenSource, err = initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location, r.ImpersonateServiceAccount)
 		if err != nil {
 			return nil, fmt.Errorf("error creating client from ADC: %w", err)
 		}
+		s.Client = client
+		s.RestService = restService
+		s.TokenSource = tokenSource
 	}
 
 	allowedDatasets := make(map[string]struct{})
@@ -138,8 +162,8 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
 			allowedFullID = fmt.Sprintf("%s.%s", projectID, datasetID)
 		}
 
-		if client != nil {
-			dataset := client.DatasetInProject(projectID, datasetID)
+		if s.Client != nil {
+			dataset := s.Client.DatasetInProject(projectID, datasetID)
 			_, err := dataset.Metadata(ctx)
 			if err != nil {
 				if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound {
@@ -152,21 +176,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
 		}
 	}
 
-	s := &Source{
-		Name:                      r.Name,
-		Kind:                      SourceKind,
-		Project:                   r.Project,
-		Location:                  r.Location,
-		Client:                    client,
-		RestService:               restService,
-		TokenSource:               tokenSource,
-		MaxQueryResultRows:        50,
-		WriteMode:                 r.WriteMode,
-		AllowedDatasets:           allowedDatasets,
-		UseClientOAuth:            r.UseClientOAuth,
-		ClientCreator:             clientCreator,
-		ImpersonateServiceAccount: r.ImpersonateServiceAccount,
-	}
+	s.AllowedDatasets = allowedDatasets
 
 	s.SessionProvider = s.newBigQuerySessionProvider()
 	if r.WriteMode != WriteModeAllowed && r.WriteMode != WriteModeBlocked && r.WriteMode != WriteModeProtected {
@@ -176,6 +186,58 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
 	return s, nil
 }
 
+// setupClientCaching initializes caches and wraps the base client creator with caching logic.
+func setupClientCaching(s *Source, baseCreator BigqueryClientCreator) {
+	// Define eviction handlers
+	onBqEvict := func(key string, value interface{}) {
+		if client, ok := value.(*bigqueryapi.Client); ok && client != nil {
+			client.Close()
+		}
+	}
+	onDataplexEvict := func(key string, value interface{}) {
+		if client, ok := value.(*dataplexapi.CatalogClient); ok && client != nil {
+			client.Close()
+		}
+	}
+
+	// Initialize caches
+	s.bqClientCache = NewCache(onBqEvict)
+	s.bqRestCache = NewCache(nil)
+	s.dataplexCache = NewCache(onDataplexEvict)
+
+	// Create the caching wrapper for the client creator
+	s.ClientCreator = func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
+		// Check cache
+		bqClientVal, bqFound := s.bqClientCache.Get(tokenString)
+
+		if wantRestService {
+			restServiceVal, restFound := s.bqRestCache.Get(tokenString)
+			if bqFound && restFound {
+				// Cache hit for both
+				return bqClientVal.(*bigqueryapi.Client), restServiceVal.(*bigqueryrestapi.Service), nil
+			}
+		} else {
+			if bqFound {
+				return bqClientVal.(*bigqueryapi.Client), nil, nil
+			}
+		}
+
+		// Cache miss - call the client creator
+		client, restService, err := baseCreator(tokenString, wantRestService)
+		if err != nil {
+			return nil, nil, err
+		}
+
+		// Set in cache
+		s.bqClientCache.Set(tokenString, client)
+		if wantRestService && restService != nil {
+			s.bqRestCache.Set(tokenString, restService)
+		}
+
+		return client, restService, nil
+	}
+}
+
 var _ sources.Source = &Source{}
 
 type Source struct {
@@ -197,6 +259,11 @@ type Source struct {
 	makeDataplexCatalogClient func() (*dataplexapi.CatalogClient, DataplexClientCreator, error)
 	SessionProvider           BigQuerySessionProvider
 	Session                   *Session
+
+	// Caches for OAuth clients
+	bqClientCache *Cache
+	bqRestCache   *Cache
+	dataplexCache *Cache
 }
 
 type Session struct {
@@ -397,7 +464,29 @@ func (s *Source) lazyInitDataplexClient(ctx context.Context, tracer trace.Tracer
 			return
 		}
 		client = c
-		clientCreator = cc
+
+		// If using OAuth, wrap the provided client creator (cc) with caching logic
+		if s.UseClientOAuth && cc != nil {
+			clientCreator = func(tokenString string) (*dataplexapi.CatalogClient, error) {
+				// Check cache
+				if val, found := s.dataplexCache.Get(tokenString); found {
+					return val.(*dataplexapi.CatalogClient), nil
+				}
+
+				// Cache miss - call client creator
+				dpClient, err := cc(tokenString)
+				if err != nil {
+					return nil, err
+				}
+
+				// Set in cache
+				s.dataplexCache.Set(tokenString, dpClient)
+				return dpClient, nil
+			}
+		} else {
+			// Not using OAuth or no creator was returned
+			clientCreator = cc
+		}
 	})
 	return client, clientCreator, err
 }
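A minimal sketch of the reuse the wrapped creator above is meant to provide, assuming for illustration that the internal package could be imported; the helper reuseClient and its token handling are hypothetical, while Source.ClientCreator and its (tokenString, wantRestService) signature come from the change above.

package example

import (
	bq "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
)

// reuseClient shows that two tool invocations carrying the same end-user OAuth
// token are expected to share one cached BigQuery client instead of building a
// new connection per call.
func reuseClient(s *bq.Source, userToken string) error {
	// First call: cache miss, the base creator builds a client that is then
	// stored keyed by the raw token string.
	c1, _, err := s.ClientCreator(userToken, false)
	if err != nil {
		return err
	}
	// Second call within the 55-minute TTL: cache hit, the same client comes back.
	c2, _, err := s.ClientCreator(userToken, false)
	if err != nil {
		return err
	}
	_ = c1 == c2 // expected to hold while the cache entry is live
	return nil
}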
diff --git a/internal/sources/bigquery/cache.go b/internal/sources/bigquery/cache.go
new file mode 100644
index 000000000000..947c82238fd3
--- /dev/null
+++ b/internal/sources/bigquery/cache.go
@@ -0,0 +1,125 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package bigquery
+
+import (
+	"sync"
+	"time"
+)
+
+// Item holds the cached value and its expiration timestamp
+type Item struct {
+	Value     any
+	ExpiresAt int64 // Unix nano timestamp
+}
+
+// IsExpired checks if the item is expired
+func (item Item) IsExpired() bool {
+	return time.Now().UnixNano() > item.ExpiresAt
+}
+
+// OnEvictFunc is the signature for the callback
+type OnEvictFunc func(key string, value any)
+
+// Cache is a thread-safe, expiring key-value store
+type Cache struct {
+	mu      sync.RWMutex
+	items   map[string]Item
+	onEvict OnEvictFunc
+}
+
+// NewCache creates a new cache and cleans up every 55 min
+func NewCache(onEvict OnEvictFunc) *Cache {
+	const cleanupInterval = 55 * time.Minute
+
+	c := &Cache{
+		items:   make(map[string]Item),
+		onEvict: onEvict,
+	}
+
+	go c.startCleanup(cleanupInterval)
+	return c
+}
+
+// startCleanup runs a ticker to periodically delete expired items
+func (c *Cache) startCleanup(interval time.Duration) {
+	ticker := time.NewTicker(interval)
+	defer ticker.Stop()
+
+	for range ticker.C {
+		c.DeleteExpired()
+	}
+}
+
+// delete is an internal helper that assumes the write lock is held
+func (c *Cache) delete(key string, item Item) {
+	if c.onEvict != nil {
+		c.onEvict(key, item.Value)
+	}
+	delete(c.items, key)
+}
+
+// Set adds an item to the cache
+func (c *Cache) Set(key string, value any) {
+	const ttl = 55 * time.Minute
+	expires := time.Now().Add(ttl).UnixNano()
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	// If item already exists, evict the old one before replacing
+	if oldItem, found := c.items[key]; found {
+		c.delete(key, oldItem)
+	}
+
+	c.items[key] = Item{
+		Value:     value,
+		ExpiresAt: expires,
+	}
+}
+
+// Get retrieves an item from the cache
+func (c *Cache) Get(key string) (any, bool) {
+	c.mu.RLock()
+	item, found := c.items[key]
+	if !found || item.IsExpired() {
+		c.mu.RUnlock()
+		return nil, false
+	}
+	c.mu.RUnlock()
+
+	return item.Value, true
+}
+
+// Delete manually evicts an item
+func (c *Cache) Delete(key string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if item, found := c.items[key]; found {
+		c.delete(key, item)
+	}
+}
+
+// DeleteExpired removes all expired items
+func (c *Cache) DeleteExpired() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	for key, item := range c.items {
+		if item.IsExpired() {
+			c.delete(key, item)
+		}
+	}
+}
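A quick sketch of the Cache contract introduced in cache.go, written as a hypothetical package-level test; the test name and string values are illustrative and not part of this change. Set on an existing key, Delete, and the periodic cleanup all route evictions through the same callback passed to NewCache.

package bigquery

import "testing"

// TestCacheEvictionSketch exercises Set/Get/Delete and the eviction callback.
func TestCacheEvictionSketch(t *testing.T) {
	evicted := 0
	c := NewCache(func(key string, value any) { evicted++ })

	c.Set("token-a", "client-1")
	if v, ok := c.Get("token-a"); !ok || v.(string) != "client-1" {
		t.Fatalf("expected a cache hit for token-a")
	}

	// Replacing an existing key evicts the old value first.
	c.Set("token-a", "client-2")
	if evicted != 1 {
		t.Fatalf("expected 1 eviction after replacement, got %d", evicted)
	}

	// Manual deletion also runs the eviction callback.
	c.Delete("token-a")
	if evicted != 2 {
		t.Fatalf("expected 2 evictions after delete, got %d", evicted)
	}
	if _, ok := c.Get("token-a"); ok {
		t.Fatalf("expected token-a to be gone")
	}
}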