4 changes: 2 additions & 2 deletions cmd/root.go
@@ -98,15 +98,15 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules"
_ "github.com/googleapis/genai-toolbox/internal/tools/http"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardelement"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerconversationalanalytics"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdeleteprojectfile"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdevmode"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiondatabases"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnections"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectionschemas"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontables"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontablecolumns"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontables"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdashboards"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdimensions"
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetexplores"
143 changes: 116 additions & 27 deletions internal/sources/bigquery/bigquery.go
@@ -72,27 +72,30 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources

type Config struct {
// BigQuery configs
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Location string `yaml:"location"`
WriteMode string `yaml:"writeMode"`
AllowedDatasets []string `yaml:"allowedDatasets"`
UseClientOAuth bool `yaml:"useClientOAuth"`
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Location string `yaml:"location"`
WriteMode string `yaml:"writeMode"`
AllowedDatasets []string `yaml:"allowedDatasets"`
UseClientOAuth bool `yaml:"useClientOAuth"`
ImpersonateServiceAccount string `yaml:"impersonateServiceAccount"`
}

func (r Config) SourceConfigKind() string {
// Returns BigQuery source kind
return SourceKind
}

func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
if r.WriteMode == "" {
r.WriteMode = WriteModeAllowed
}

if r.WriteMode == WriteModeProtected && r.UseClientOAuth {
// Protected mode only allows write operations to the session's temporary datasets.
// When using client OAuth, a new session is created every
// time a BigQuery tool is invoked, so no session data can
// be preserved as protected mode requires.
return nil, fmt.Errorf("writeMode 'protected' cannot be used with useClientOAuth 'true'")
}

@@ -106,17 +109,38 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var clientCreator BigqueryClientCreator
var err error

s := &Source{
Name: r.Name,
Kind: SourceKind,
Project: r.Project,
Location: r.Location,
Client: client,
RestService: restService,
TokenSource: tokenSource,
MaxQueryResultRows: 50,
WriteMode: r.WriteMode,
UseClientOAuth: r.UseClientOAuth,
ClientCreator: clientCreator,
ImpersonateServiceAccount: r.ImpersonateServiceAccount,
}

if r.UseClientOAuth {
clientCreator, err = newBigQueryClientCreator(ctx, tracer, r.Project, r.Location, r.Name)
// Client OAuth: build a per-token client creator, then wrap it with caching.
baseClientCreator, err := newBigQueryClientCreator(ctx, tracer, r.Project, r.Location, r.Name)
if err != nil {
return nil, fmt.Errorf("error constructing client creator: %w", err)
}
setupClientCaching(s, baseClientCreator)

} else {
// Shared credentials: initialize the BigQuery clients once via ADC or an impersonated service account.
client, restService, tokenSource, err = initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location, r.ImpersonateServiceAccount)
if err != nil {
return nil, fmt.Errorf("error creating client from ADC: %w", err)
}
s.Client = client
s.RestService = restService
s.TokenSource = tokenSource
}

allowedDatasets := make(map[string]struct{})
@@ -138,8 +162,8 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
allowedFullID = fmt.Sprintf("%s.%s", projectID, datasetID)
}

if client != nil {
dataset := client.DatasetInProject(projectID, datasetID)
if s.Client != nil {
dataset := s.Client.DatasetInProject(projectID, datasetID)
_, err := dataset.Metadata(ctx)
if err != nil {
if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound {
@@ -152,21 +176,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
}
}

s := &Source{
Name: r.Name,
Kind: SourceKind,
Project: r.Project,
Location: r.Location,
Client: client,
RestService: restService,
TokenSource: tokenSource,
MaxQueryResultRows: 50,
WriteMode: r.WriteMode,
AllowedDatasets: allowedDatasets,
UseClientOAuth: r.UseClientOAuth,
ClientCreator: clientCreator,
ImpersonateServiceAccount: r.ImpersonateServiceAccount,
}
s.AllowedDatasets = allowedDatasets
s.SessionProvider = s.newBigQuerySessionProvider()

if r.WriteMode != WriteModeAllowed && r.WriteMode != WriteModeBlocked && r.WriteMode != WriteModeProtected {
@@ -176,6 +186,58 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return s, nil
}

// setupClientCaching initializes caches and wraps the base client creator with caching logic.
func setupClientCaching(s *Source, baseCreator BigqueryClientCreator) {
// Define eviction handlers
onBqEvict := func(key string, value interface{}) {
if client, ok := value.(*bigqueryapi.Client); ok && client != nil {
client.Close()
}
}
onDataplexEvict := func(key string, value interface{}) {
if client, ok := value.(*dataplexapi.CatalogClient); ok && client != nil {
client.Close()
}
}

// Initialize caches
s.bqClientCache = NewCache(onBqEvict)
s.bqRestCache = NewCache(nil)
s.dataplexCache = NewCache(onDataplexEvict)

// Create the caching wrapper for the client creator
s.ClientCreator = func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
// Check cache
bqClientVal, bqFound := s.bqClientCache.Get(tokenString)

if wantRestService {
restServiceVal, restFound := s.bqRestCache.Get(tokenString)
if bqFound && restFound {
// Cache hit for both
return bqClientVal.(*bigqueryapi.Client), restServiceVal.(*bigqueryrestapi.Service), nil
}
} else {
if bqFound {
return bqClientVal.(*bigqueryapi.Client), nil, nil
}
}

// Cache miss - call the client creator
client, restService, err := baseCreator(tokenString, wantRestService)
if err != nil {
return nil, nil, err
}

// Set in cache
s.bqClientCache.Set(tokenString, client)
if wantRestService && restService != nil {
s.bqRestCache.Set(tokenString, restService)
}

return client, restService, nil
}
}

var _ sources.Source = &Source{}

type Source struct {
@@ -197,6 +259,11 @@ type Source struct {
makeDataplexCatalogClient func() (*dataplexapi.CatalogClient, DataplexClientCreator, error)
SessionProvider BigQuerySessionProvider
Session *Session

// Caches for OAuth clients
bqClientCache *Cache
bqRestCache *Cache
dataplexCache *Cache
}

type Session struct {
@@ -397,7 +464,29 @@ func (s *Source) lazyInitDataplexClient(ctx context.Context, tracer trace.Tracer
return
}
client = c
clientCreator = cc

// If using OAuth, wrap the provided client creator (cc) with caching logic
if s.UseClientOAuth && cc != nil {
clientCreator = func(tokenString string) (*dataplexapi.CatalogClient, error) {
// Check cache
if val, found := s.dataplexCache.Get(tokenString); found {
return val.(*dataplexapi.CatalogClient), nil
}

// Cache miss - call client creator
dpClient, err := cc(tokenString)
if err != nil {
return nil, err
}

// Set in cache
s.dataplexCache.Set(tokenString, dpClient)
return dpClient, nil
}
} else {
// Not using OAuth or no creator was returned
clientCreator = cc
}
})
return client, clientCreator, err
}
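
The caching wrapper added above memoizes one BigQuery client (and optionally one REST service) per end-user OAuth token, so repeated tool invocations with the same token reuse an existing connection instead of rebuilding it. Below is a minimal, self-contained sketch of that per-token memoization pattern with hypothetical names; it omits the TTL, eviction, and locking that the real Cache type provides.

package main

import "fmt"

// creator mimics the shape of a per-token client factory: it is called once
// per cache miss and its result is reused for later calls with the same token.
type creator func(token string) (string, error)

// withCache wraps a creator so results are memoized per token.
// Not concurrency-safe; the real implementation guards its map with a mutex.
func withCache(base creator) creator {
	cache := map[string]string{}
	return func(token string) (string, error) {
		if v, ok := cache[token]; ok {
			return v, nil // cache hit: reuse the client built for this token
		}
		v, err := base(token)
		if err != nil {
			return "", err
		}
		cache[token] = v // cache miss: build once, remember for later calls
		return v, nil
	}
}

func main() {
	calls := 0
	base := func(token string) (string, error) {
		calls++
		return "client-for-" + token, nil
	}
	cached := withCache(base)
	cached("token-a")
	cached("token-a")
	fmt.Println(calls) // 1: the second call was served from the cache
}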
125 changes: 125 additions & 0 deletions internal/sources/bigquery/cache.go
@@ -0,0 +1,125 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bigquery

import (
"sync"
"time"
)

// Item holds the cached value and its expiration timestamp
type Item struct {
Value any
ExpiresAt int64 // Unix nano timestamp
}

// IsExpired checks if the item is expired
func (item Item) IsExpired() bool {
return time.Now().UnixNano() > item.ExpiresAt
}

// OnEvictFunc is the signature for the callback
type OnEvictFunc func(key string, value any)

// Cache is a thread-safe, expiring key-value store
type Cache struct {
mu sync.RWMutex
items map[string]Item
onEvict OnEvictFunc
}

// NewCache creates a new cache whose expired items are cleaned up every 55 minutes
func NewCache(onEvict OnEvictFunc) *Cache {
const cleanupInterval = 55 * time.Minute

c := &Cache{
items: make(map[string]Item),
onEvict: onEvict,
}

go c.startCleanup(cleanupInterval)
return c
}

// startCleanup runs a ticker to periodically delete expired items
func (c *Cache) startCleanup(interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()

for range ticker.C {
c.DeleteExpired()
}
}

// delete is an internal helper that assumes the write lock is held
func (c *Cache) delete(key string, item Item) {
if c.onEvict != nil {
c.onEvict(key, item.Value)
}
delete(c.items, key)
}

// Set adds an item to the cache
func (c *Cache) Set(key string, value any) {
const ttl = 55 * time.Minute
expires := time.Now().Add(ttl).UnixNano()

c.mu.Lock()
defer c.mu.Unlock()

// If item already exists, evict the old one before replacing
if oldItem, found := c.items[key]; found {
c.delete(key, oldItem)
}

c.items[key] = Item{
Value: value,
ExpiresAt: expires,
}
}

// Get retrieves an item from the cache
func (c *Cache) Get(key string) (any, bool) {
c.mu.RLock()
item, found := c.items[key]
if !found || item.IsExpired() {
c.mu.RUnlock()
return nil, false
}
c.mu.RUnlock()

return item.Value, true
}

// Delete manually evicts an item
func (c *Cache) Delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()

if item, found := c.items[key]; found {
c.delete(key, item)
}
}

// DeleteExpired removes all expired items
func (c *Cache) DeleteExpired() {
c.mu.Lock()
defer c.mu.Unlock()

for key, item := range c.items {
if item.IsExpired() {
c.delete(key, item)
}
}
}
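
For reference, a small usage sketch of the cache (not part of the change): values are keyed by token, and replaced, deleted, or expired entries are passed to the eviction callback so the caller can release the underlying client. The ExampleCache function and the stored strings are illustrative placeholders.

package bigquery

import "fmt"

// ExampleCache is a minimal usage sketch: evicted values are handed to the
// callback so the caller can release them, e.g. by closing a client.
func ExampleCache() {
	evicted := []string{}
	c := NewCache(func(key string, value any) {
		evicted = append(evicted, key)
	})

	c.Set("token-a", "client-a")
	if v, ok := c.Get("token-a"); ok {
		fmt.Println(v)
	}

	c.Set("token-a", "client-a2") // replacing a key evicts the old value first
	c.Delete("token-a")           // manual eviction also triggers the callback

	fmt.Println(evicted)
	// Output:
	// client-a
	// [token-a token-a]
}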