diff --git a/docs/features/semantic-caching.mdx b/docs/features/semantic-caching.mdx index d63211cfab..f25747c720 100644 --- a/docs/features/semantic-caching.mdx +++ b/docs/features/semantic-caching.mdx @@ -118,7 +118,6 @@ import ( cacheConfig := &semanticcache.Config{ // Embedding model configuration (Required) Provider: schemas.OpenAI, - Keys: []schemas.Key{{Value: "sk-..."}}, EmbeddingModel: "text-embedding-3-small", Dimension: 1536, @@ -155,22 +154,32 @@ bifrostConfig := schemas.BifrostConfig{  -**Note**: Make sure you have a vector store setup (using `config.json`) before configuring the semantic cache plugin. +**Prerequisites**: A vector store must be configured and enabled in `config.json`, and at least one provider must be configured, before the toggle becomes available. -1. **Navigate to Settings** - - Open Bifrost UI at `http://localhost:8080` - - Go to Settings. +1. **Navigate to the Config page** in the Bifrost UI and find the **Plugins** section. -2. **Configure Semantic Cache Plugin** +2. **Toggle** the **Enable Semantic Caching** switch to enable it. The configuration form expands below. -- Toggle the plugin switch to enable it, and fill in the required fields. +3. **Fill in the fields** across the four sections: -**Required Fields:** -- **Provider**: The provider to use for caching. -- **Embedding Model**: The embedding model to use for caching. -- **Dimension**: The embedding dimension for the configured embedding model. +**Provider and Model Settings** (required for semantic mode): +- **Configured Providers**: Dropdown of providers already set up in Bifrost. The selected provider's API keys are inherited automatically. +- **Embedding Model**: The embedding model to use (e.g. `text-embedding-3-small`). -**Note**: Changes will need a restart of the Bifrost server to take effect, because the plugin is loaded on startup only. +**Cache Settings**: +- **TTL (seconds)**: How long cached responses are kept (default: 300 s). +- **Similarity Threshold**: Cosine similarity cutoff for a cache hit (0–1, default: 0.8). +- **Dimension**: Vector dimension matching your embedding model (e.g. 1536 for `text-embedding-3-small`). + +**Conversation Settings**: +- **Conversation History Threshold**: Skip caching when the conversation has more than this many messages (default: 3). +- **Exclude System Prompt** (toggle): Exclude system messages from cache-key generation. + +**Cache Behavior**: +- **Cache by Model** (toggle): Include the model name in the cache key (default: on). +- **Cache by Provider** (toggle): Include the provider name in the cache key (default: on). + +4. Click **Save**. Changes are persisted and applied immediately for enabled plugins via the API reload path; other plugin changes (e.g. via `config.json`) may still require a restart. @@ -202,7 +211,7 @@ bifrostConfig := schemas.BifrostConfig{ } ``` -> **Note**: In `config.json` setups, provider keys are taken from the provider config on initialization, so you do not need to duplicate `keys` inside the plugin config. Any updates to the provider keys will not be reflected until next restart. +> **Note**: Provider API keys are inherited automatically from the global provider configuration. You do not need to (and cannot) specify keys inside the plugin config. **TTL Format Options:** - Duration strings: `"30s"`, `"5m"`, `"1h"`, `"24h"` @@ -228,7 +237,7 @@ Exact-match direct entries are stored and retrieved using a deterministic cache ### Setup -To enable direct-only mode globally, set `dimension: 1` and omit the `provider` and `keys` fields from the plugin config. The plugin will automatically fall back to direct search only. +To enable direct-only mode globally, set `dimension: 1` and omit the `provider` and `embedding_model` fields from the plugin config. The plugin will automatically fall back to direct search only. > **Important**: If you specify `dimension: 1` and also provide a `provider`, Bifrost treats the config as provider-backed semantic mode, not direct-only mode. To use direct-only mode, omit the `provider` field entirely. @@ -246,7 +255,7 @@ import ( ) cacheConfig := &semanticcache.Config{ - // No Provider, Keys, or EmbeddingModel -- direct hash mode only + // No Provider or EmbeddingModel -- direct hash mode only Dimension: 1, // Placeholder; entries are stored as metadata-only (no embedding vectors). Change dimension before switching to dual-layer mode to avoid mixed-dimension issues. TTL: 5 * time.Minute, diff --git a/plugins/semanticcache/main.go b/plugins/semanticcache/main.go index 148bbfef1a..c065ceff35 100644 --- a/plugins/semanticcache/main.go +++ b/plugins/semanticcache/main.go @@ -15,7 +15,6 @@ import ( bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/framework" "github.com/maximhq/bifrost/framework/vectorstore" ) @@ -25,7 +24,6 @@ import ( type Config struct { // Embedding Model settings - REQUIRED for semantic caching Provider schemas.ModelProvider `json:"provider"` - Keys []schemas.Key `json:"keys"` EmbeddingModel string `json:"embedding_model,omitempty"` // Model to use for generating embeddings (optional) // Plugin behavior settings @@ -48,19 +46,18 @@ type Config struct { func (c *Config) UnmarshalJSON(data []byte) error { // Define a temporary struct to avoid infinite recursion type TempConfig struct { - Provider string `json:"provider"` - Keys []schemas.Key `json:"keys"` - EmbeddingModel string `json:"embedding_model,omitempty"` - CleanUpOnShutdown bool `json:"cleanup_on_shutdown,omitempty"` - Dimension int `json:"dimension"` - TTL interface{} `json:"ttl,omitempty"` - Threshold float64 `json:"threshold,omitempty"` - VectorStoreNamespace string `json:"vector_store_namespace,omitempty"` - DefaultCacheKey string `json:"default_cache_key,omitempty"` - ConversationHistoryThreshold int `json:"conversation_history_threshold,omitempty"` - CacheByModel *bool `json:"cache_by_model,omitempty"` - CacheByProvider *bool `json:"cache_by_provider,omitempty"` - ExcludeSystemPrompt *bool `json:"exclude_system_prompt,omitempty"` + Provider string `json:"provider"` + EmbeddingModel string `json:"embedding_model,omitempty"` + CleanUpOnShutdown bool `json:"cleanup_on_shutdown,omitempty"` + Dimension int `json:"dimension"` + TTL interface{} `json:"ttl,omitempty"` + Threshold float64 `json:"threshold,omitempty"` + VectorStoreNamespace string `json:"vector_store_namespace,omitempty"` + DefaultCacheKey string `json:"default_cache_key,omitempty"` + ConversationHistoryThreshold int `json:"conversation_history_threshold,omitempty"` + CacheByModel *bool `json:"cache_by_model,omitempty"` + CacheByProvider *bool `json:"cache_by_provider,omitempty"` + ExcludeSystemPrompt *bool `json:"exclude_system_prompt,omitempty"` } var temp TempConfig @@ -70,7 +67,6 @@ func (c *Config) UnmarshalJSON(data []byte) error { // Set simple fields c.Provider = schemas.ModelProvider(temp.Provider) - c.Keys = temp.Keys c.EmbeddingModel = temp.EmbeddingModel c.CleanUpOnShutdown = temp.CleanUpOnShutdown c.Dimension = temp.Dimension @@ -129,6 +125,10 @@ type StreamAccumulator struct { mu sync.Mutex // Protects chunk operations } +// EmbeddingRequestExecutor is a function that executes a request and returns a response and an error. +// It maps to .EmbeddingRequest() of the bifrost client. +type EmbeddingRequestExecutor func(ctx *schemas.BifrostContext, req *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) + // Plugin implements the schemas.LLMPlugin interface for semantic caching. // It caches responses using a two-tier approach: direct hash matching for exact requests // and semantic similarity search for related content. The plugin supports configurable caching behavior @@ -139,12 +139,12 @@ type StreamAccumulator struct { // - config: Plugin configuration including semantic cache and caching settings // - logger: Logger instance for plugin operations type Plugin struct { - store vectorstore.VectorStore - config *Config - logger schemas.Logger - client *bifrost.Bifrost - streamAccumulators sync.Map // Track stream accumulators by request ID - waitGroup sync.WaitGroup + store vectorstore.VectorStore + config *Config + logger schemas.Logger + embeddingRequestExecutor EmbeddingRequestExecutor + streamAccumulators sync.Map // Track stream accumulators by request ID + waitGroup sync.WaitGroup } // Plugin constants @@ -201,45 +201,6 @@ var VectorStoreProperties = map[string]vectorstore.VectorStoreProperties{ }, } -type PluginAccount struct { - provider schemas.ModelProvider - keys []schemas.Key -} - -func (pa *PluginAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { - return []schemas.ModelProvider{pa.provider}, nil -} - -func (pa *PluginAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { - return pa.keys, nil -} - -func (pa *PluginAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { - return &schemas.ProviderConfig{ - NetworkConfig: schemas.DefaultNetworkConfig, - ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, - }, nil -} - -// Dependencies is a list of dependencies that the plugin requires. -var Dependencies []framework.FrameworkDependency = []framework.FrameworkDependency{framework.FrameworkDependencyVectorStore} - -// ProvidersWithEmbeddingSupport lists all providers that support embedding operations. -// Providers not in this list will return UnsupportedOperationError for embedding requests. -var ProvidersWithEmbeddingSupport = map[schemas.ModelProvider]bool{ - schemas.OpenAI: true, - schemas.Azure: true, - schemas.Bedrock: true, - schemas.Cohere: true, - schemas.Gemini: true, - schemas.Vertex: true, - schemas.Mistral: true, - schemas.Ollama: true, - schemas.Nebius: true, - schemas.HuggingFace: true, - schemas.SGL: true, -} - const ( CacheKey schemas.BifrostContextKey = "semantic_cache_key" // To set the cache key for a request - REQUIRED for all requests CacheTTLKey schemas.BifrostContextKey = "semantic_cache_ttl" // To explicitly set the TTL for a request @@ -323,26 +284,8 @@ func Init(ctx context.Context, config *Config, logger schemas.Logger, store vect if config.Provider == "" && config.Dimension == 1 { logger.Info(PluginLoggerPrefix + " Starting in direct-only mode (dimension=1, no embedding provider)") - } else if config.Provider == "" || len(config.Keys) == 0 { - logger.Warn(PluginLoggerPrefix + " Incomplete semantic mode config: missing provider or keys, falling back to direct search only") - } else { - // Validate that the provider supports embeddings - if bifrost.IsStandardProvider(config.Provider) && !ProvidersWithEmbeddingSupport[config.Provider] { - return nil, fmt.Errorf("provider '%s' does not support embedding operations required for semantic cache. Supported providers: openai, azure, bedrock, cohere, gemini, vertex, mistral, ollama, nebius, huggingface, sgl. Note: custom providers based on embedding-capable providers are also supported", config.Provider) - } - - bifrost, err := bifrost.Init(ctx, schemas.BifrostConfig{ - Logger: logger, - Account: &PluginAccount{ - provider: config.Provider, - keys: config.Keys, - }, - }) - if err != nil { - return nil, fmt.Errorf("failed to initialize bifrost for semantic cache: %w", err) - } - - plugin.client = bifrost + } else if config.Provider == "" { + logger.Warn(PluginLoggerPrefix + " Incomplete semantic mode config: missing provider, falling back to direct search only") } createCtx, cancel := context.WithTimeout(ctx, CreateNamespaceTimeout) @@ -378,19 +321,6 @@ func (plugin *Plugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, return chunk, nil } -func (plugin *Plugin) clearRequestScopedContext(ctx *schemas.BifrostContext) { - ctx.ClearValue(requestIDKey) - ctx.ClearValue(requestStorageIDKey) - ctx.ClearValue(requestHashKey) - ctx.ClearValue(requestParamsHashKey) - ctx.ClearValue(requestModelKey) - ctx.ClearValue(requestProviderKey) - ctx.ClearValue(requestEmbeddingKey) - ctx.ClearValue(requestEmbeddingTokensKey) - ctx.ClearValue(isCacheHitKey) - ctx.ClearValue(cacheHitTypeKey) -} - // PreLLMHook is called before a request is processed by Bifrost. // It performs a two-stage cache lookup: first direct hash matching, then semantic similarity search. // Uses UUID-based keys for entries stored in the VectorStore. @@ -465,7 +395,7 @@ func (plugin *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifro } } - if performSemanticSearch && plugin.client != nil { + if performSemanticSearch && plugin.embeddingRequestExecutor != nil { if req.EmbeddingRequest != nil || req.TranscriptionRequest != nil { plugin.logger.Debug(PluginLoggerPrefix + " Skipping semantic search for embedding/transcription input") // For vector stores that require vectors, set a zero vector placeholder @@ -488,7 +418,7 @@ func (plugin *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifro if shortCircuit != nil { return req, shortCircuit, nil } - } else if !performSemanticSearch && plugin.store.RequiresVectors() && plugin.client != nil { + } else if !performSemanticSearch && plugin.store.RequiresVectors() && plugin.embeddingRequestExecutor != nil { // Vector store requires vectors but we're in direct-only mode // Generate embeddings for storage purposes (not for searching) if req.EmbeddingRequest != nil || req.TranscriptionRequest != nil { @@ -759,11 +689,6 @@ func (plugin *Plugin) Cleanup() error { // Clean up old stream accumulators first plugin.cleanupOldStreamAccumulators() - // Shutdown the internal Bifrost client used for embeddings - if plugin.client != nil { - plugin.client.Shutdown() - } - // Only clean up cache entries if configured to do so if !plugin.config.CleanUpOnShutdown { plugin.logger.Debug(PluginLoggerPrefix + " Cleanup on shutdown is disabled, skipping cache cleanup") @@ -804,6 +729,15 @@ func (plugin *Plugin) Cleanup() error { return nil } +// SetEmbeddingRequestExecutor sets the embedding request executor for the plugin. +// Needs to be set before the plugin is used. +// +// Parameters: +// - executor: The embedding request executor to set +func (plugin *Plugin) SetEmbeddingRequestExecutor(executor EmbeddingRequestExecutor) { + plugin.embeddingRequestExecutor = executor +} + // Public Methods for External Use // ClearCacheForKey deletes cache entries for a specific cache key. @@ -869,3 +803,16 @@ func (plugin *Plugin) ClearCacheForRequestID(requestID string) error { return nil } + +func (plugin *Plugin) clearRequestScopedContext(ctx *schemas.BifrostContext) { + ctx.ClearValue(requestIDKey) + ctx.ClearValue(requestStorageIDKey) + ctx.ClearValue(requestHashKey) + ctx.ClearValue(requestParamsHashKey) + ctx.ClearValue(requestModelKey) + ctx.ClearValue(requestProviderKey) + ctx.ClearValue(requestEmbeddingKey) + ctx.ClearValue(requestEmbeddingTokensKey) + ctx.ClearValue(isCacheHitKey) + ctx.ClearValue(cacheHitTypeKey) +} diff --git a/plugins/semanticcache/plugin_core_test.go b/plugins/semanticcache/plugin_core_test.go index 822fc1f645..5bed26528d 100644 --- a/plugins/semanticcache/plugin_core_test.go +++ b/plugins/semanticcache/plugin_core_test.go @@ -2,7 +2,6 @@ package semanticcache import ( "context" - "strings" "testing" "time" @@ -389,9 +388,6 @@ func TestCacheConfiguration(t *testing.T) { EmbeddingModel: "text-embedding-3-small", Dimension: 1536, Threshold: 0.95, // Very high threshold - Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0}, - }, }, expectedBehavior: "strict_matching", }, @@ -402,9 +398,6 @@ func TestCacheConfiguration(t *testing.T) { EmbeddingModel: "text-embedding-3-small", Dimension: 1536, Threshold: 0.1, // Very low threshold - Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0}, - }, }, expectedBehavior: "loose_matching", }, @@ -416,9 +409,6 @@ func TestCacheConfiguration(t *testing.T) { Dimension: 1536, Threshold: 0.8, TTL: 1 * time.Hour, // Custom TTL - Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0}, - }, }, expectedBehavior: "custom_ttl", }, @@ -463,7 +453,7 @@ func (m *MockUnsupportedStore) Ping(ctx context.Context) error { } func (m *MockUnsupportedStore) CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]vectorstore.VectorStoreProperties) error { - return vectorstore.ErrNotSupported + return nil } func (m *MockUnsupportedStore) DeleteNamespace(ctx context.Context, namespace string) error { @@ -547,23 +537,13 @@ func TestInvalidProviderRejection(t *testing.T) { Dimension: 1536, Threshold: 0.8, CleanUpOnShutdown: false, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.TEST_API_KEY"), - Models: schemas.WhiteList{"*"}, - Weight: 1.0, - }, - }, } + // Provider validation was moved to request time (global client handles it). + // Init itself should succeed regardless of the provider set in config. _, err := Init(ctx, config, logger, mockStore) - if err == nil { - t.Errorf("Expected error for provider '%s' but got none", provider) - } - - expectedErrSubstring := "does not support embedding operations" - if err != nil && !strings.Contains(err.Error(), expectedErrSubstring) { - t.Errorf("Expected error message to contain '%s', but got: %v", expectedErrSubstring, err) + if err != nil { + t.Errorf("Init should succeed for provider '%s' (validation happens at request time), but got: %v", provider, err) } }) } @@ -584,18 +564,11 @@ func TestValidProviderAccepted(t *testing.T) { Dimension: 1536, Threshold: 0.8, CleanUpOnShutdown: false, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: schemas.WhiteList{"*"}, - Weight: 1.0, - }, - }, } - // Should fail due to namespace creation, not provider validation + // Init should succeed; provider validation happens at request time via the global client. _, err := Init(ctx, config, logger, mockStore) - if err != nil && strings.Contains(err.Error(), "does not support embedding operations") { - t.Errorf("Valid provider OpenAI should not be rejected for embedding support, but got: %v", err) + if err != nil { + t.Errorf("Valid provider OpenAI should be accepted at Init, but got: %v", err) } } diff --git a/plugins/semanticcache/plugin_image_generation_test.go b/plugins/semanticcache/plugin_image_generation_test.go index f50f3c5c9b..a65c06e81b 100644 --- a/plugins/semanticcache/plugin_image_generation_test.go +++ b/plugins/semanticcache/plugin_image_generation_test.go @@ -128,9 +128,6 @@ func TestImageGenerationSemanticSearch(t *testing.T) { EmbeddingModel: "text-embedding-3-small", Dimension: 1536, Threshold: 0.5, - Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: []string{"*"}, Weight: 1.0}, - }, } setup := NewTestSetupWithConfig(t, config) defer setup.Cleanup() diff --git a/plugins/semanticcache/plugin_vectorstore_test.go b/plugins/semanticcache/plugin_vectorstore_test.go index 5e390bbe80..f4ac8130f2 100644 --- a/plugins/semanticcache/plugin_vectorstore_test.go +++ b/plugins/semanticcache/plugin_vectorstore_test.go @@ -55,13 +55,6 @@ func getDefaultTestConfig() *Config { Dimension: 1536, Threshold: 0.8, CleanUpOnShutdown: true, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: schemas.WhiteList{"*"}, - Weight: 1.0, - }, - }, } } diff --git a/plugins/semanticcache/test_utils.go b/plugins/semanticcache/test_utils.go index d267254473..e9b847c6dc 100644 --- a/plugins/semanticcache/test_utils.go +++ b/plugins/semanticcache/test_utils.go @@ -371,13 +371,6 @@ func NewTestSetup(t *testing.T) *TestSetup { Dimension: 1536, Threshold: 0.8, CleanUpOnShutdown: true, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: schemas.WhiteList{"*"}, - Weight: 1.0, - }, - }, }) } @@ -427,6 +420,9 @@ func NewTestSetupWithVectorStore(t *testing.T, config *Config, storeType vectors // Get a mocked Bifrost client client := getMockedBifrostClient(t, ctx, logger, plugin) + // Wire the global client as the embedding executor so semantic search works. + pluginImpl.SetEmbeddingRequestExecutor(client.EmbeddingRequest) + return &TestSetup{ Logger: logger, Store: store, @@ -648,13 +644,6 @@ func CreateTestSetupWithConversationThreshold(t *testing.T, threshold int) *Test CleanUpOnShutdown: true, Threshold: 0.8, ConversationHistoryThreshold: threshold, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{"*"}, - Weight: 1.0, - }, - }, } return NewTestSetupWithConfig(t, config) @@ -669,13 +658,6 @@ func CreateTestSetupWithExcludeSystemPrompt(t *testing.T, excludeSystem bool) *T CleanUpOnShutdown: true, Threshold: 0.8, ExcludeSystemPrompt: &excludeSystem, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{"*"}, - Weight: 1.0, - }, - }, } return NewTestSetupWithConfig(t, config) @@ -691,13 +673,6 @@ func CreateTestSetupWithThresholdAndExcludeSystem(t *testing.T, threshold int, e Threshold: 0.8, ConversationHistoryThreshold: threshold, ExcludeSystemPrompt: &excludeSystem, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{"*"}, - Weight: 1.0, - }, - }, } return NewTestSetupWithConfig(t, config) diff --git a/plugins/semanticcache/utils.go b/plugins/semanticcache/utils.go index 712030051a..957115ee24 100644 --- a/plugins/semanticcache/utils.go +++ b/plugins/semanticcache/utils.go @@ -67,8 +67,16 @@ func (plugin *Plugin) generateEmbedding(ctx *schemas.BifrostContext, text string }, } - // Generate embedding using bifrost client - response, err := plugin.client.EmbeddingRequest(ctx, embeddingReq) + // Create a new context from incoming context. Parent ctx will be used for cancellation. + embeddingCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + defer embeddingCtx.ReleasePluginScope() + + embeddingCtx.SetValue(schemas.BifrostContextKeySkipPluginPipeline, true) + + if plugin.embeddingRequestExecutor == nil { + return nil, 0, fmt.Errorf("embedding request executor is not configured") + } + response, err := plugin.embeddingRequestExecutor(embeddingCtx, embeddingReq) if err != nil { return nil, 0, fmt.Errorf("failed to generate embedding: %v", err) } diff --git a/transports/bifrost-http/handlers/plugins.go b/transports/bifrost-http/handlers/plugins.go index 11b362dddf..3dc4f8353c 100644 --- a/transports/bifrost-http/handlers/plugins.go +++ b/transports/bifrost-http/handlers/plugins.go @@ -35,25 +35,23 @@ func NewPluginsHandler(pluginsLoader PluginsLoader, configStore configstore.Conf } } - - // CreatePluginRequest is the request body for creating a plugin type CreatePluginRequest struct { - Name string `json:"name"` - Enabled bool `json:"enabled"` - Config map[string]any `json:"config"` - Path *string `json:"path"` + Name string `json:"name"` + Enabled bool `json:"enabled"` + Config map[string]any `json:"config"` + Path *string `json:"path"` Placement *schemas.PluginPlacement `json:"placement,omitempty"` - Order *int `json:"order,omitempty"` + Order *int `json:"order,omitempty"` } // UpdatePluginRequest is the request body for updating a plugin type UpdatePluginRequest struct { - Enabled bool `json:"enabled"` - Path *string `json:"path"` - Config map[string]any `json:"config"` + Enabled bool `json:"enabled"` + Path *string `json:"path"` + Config map[string]any `json:"config"` Placement *schemas.PluginPlacement `json:"placement,omitempty"` - Order *int `json:"order,omitempty"` + Order *int `json:"order,omitempty"` } // RegisterRoutes registers the routes for the PluginsHandler @@ -66,15 +64,15 @@ func (h *PluginsHandler) RegisterRoutes(r *router.Router, middlewares ...schemas } type PluginResponse struct { - Name string `json:"name"` - ActualName string `json:"actualName"` - Enabled bool `json:"enabled"` - Config any `json:"config"` - IsCustom bool `json:"isCustom"` - Path *string `json:"path"` - Placement *schemas.PluginPlacement `json:"placement,omitempty"` - Order *int `json:"order,omitempty"` - Status schemas.PluginStatus `json:"status"` + Name string `json:"name"` + ActualName string `json:"actualName"` + Enabled bool `json:"enabled"` + Config any `json:"config"` + IsCustom bool `json:"isCustom"` + Path *string `json:"path"` + Placement *schemas.PluginPlacement `json:"placement,omitempty"` + Order *int `json:"order,omitempty"` + Status schemas.PluginStatus `json:"status"` } // buildPluginResponse constructs a PluginResponse with status for a given TablePlugin. diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index aa845ab0e1..77ed94437f 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -2623,8 +2623,8 @@ func loadPlugins(ctx context.Context, config *Config, configData *ConfigData) { Order: plugin.Order, } if plugin.Name == semanticcache.PluginName { - if err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig); err != nil { - logger.Warn("failed to add provider keys to semantic cache config: %v", err) + if err := config.ValidateSemanticCacheConfig(pluginConfig); err != nil { + logger.Warn("failed to validate semantic cache config: %v", err) } } config.PluginConfigs[i] = pluginConfig @@ -2693,16 +2693,6 @@ func mergePlugins(ctx context.Context, config *Config, configData *ConfigData) { } } - // Process semantic cache plugin - for i, plugin := range config.PluginConfigs { - if plugin.Name == semanticcache.PluginName { - if err := config.AddProviderKeysToSemanticCacheConfig(plugin); err != nil { - logger.Warn("failed to add provider keys to semantic cache config: %v", err) - } - config.PluginConfigs[i] = plugin - } - } - // Update store if config.ConfigStore != nil { logger.Debug("updating plugins in store") @@ -2724,11 +2714,6 @@ func mergePlugins(ctx context.Context, config *Config, configData *ConfigData) { Placement: plugin.Placement, Order: plugin.Order, } - if plugin.Name == semanticcache.PluginName { - if err := config.RemoveProviderKeysFromSemanticCacheConfig(pluginConfig); err != nil { - logger.Warn("failed to remove provider keys from semantic cache config: %v", err) - } - } if err := config.ConfigStore.UpsertPlugin(ctx, pluginConfig); err != nil { logger.Warn("failed to update plugin: %v", err) } @@ -4787,7 +4772,7 @@ func ValidateCustomProviderUpdate(newConfig, existingConfig configstore.Provider return nil } -func (c *Config) AddProviderKeysToSemanticCacheConfig(config *schemas.PluginConfig) error { +func (c *Config) ValidateSemanticCacheConfig(config *schemas.PluginConfig) error { if config.Name != semanticcache.PluginName { return nil } @@ -4856,13 +4841,11 @@ func (c *Config) AddProviderKeysToSemanticCacheConfig(config *schemas.PluginConf } configMap["embedding_model"] = embeddingModel - keys, err := c.GetProviderConfigRaw(schemas.ModelProvider(provider)) - if err != nil { + // Validate that the provider is configured in the global client (keys are inherited automatically). + if _, err := c.GetProviderConfigRaw(schemas.ModelProvider(provider)); err != nil { return fmt.Errorf("failed to get provider config for %s: %w", provider, err) } - configMap["keys"] = keys.Keys - return nil } @@ -4909,30 +4892,6 @@ func semanticCacheConfigDimension(configMap map[string]interface{}) (int, bool, return 0, false, fmt.Errorf("semantic_cache plugin 'dimension' field must be numeric, got %T", dimensionVal) } } - -func (c *Config) RemoveProviderKeysFromSemanticCacheConfig(config *configstoreTables.TablePlugin) error { - if config.Name != semanticcache.PluginName { - return nil - } - - // Check if config.Config exists - if config.Config == nil { - return fmt.Errorf("semantic_cache plugin config is nil") - } - - // Type assert config.Config to map[string]interface{} - configMap, ok := config.Config.(map[string]interface{}) - if !ok { - return fmt.Errorf("semantic_cache plugin config must be a map, got %T", config.Config) - } - - configMap["keys"] = []schemas.Key{} - - config.Config = configMap - - return nil -} - func (c *Config) GetAvailableProviders(model string) []schemas.ModelProvider { c.Mu.RLock() defer c.Mu.RUnlock() diff --git a/transports/bifrost-http/lib/semantic_cache_config_test.go b/transports/bifrost-http/lib/semantic_cache_config_test.go index 61c5da22c7..2d79bd9526 100644 --- a/transports/bifrost-http/lib/semantic_cache_config_test.go +++ b/transports/bifrost-http/lib/semantic_cache_config_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyMode(t *testing.T) { +func TestValidateSemanticCacheConfig_DirectOnlyMode(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -19,7 +19,7 @@ func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyMode(t *testing.T) { }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.NoError(t, err) configMap, ok := pluginConfig.Config.(map[string]interface{}) @@ -28,7 +28,7 @@ func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyMode(t *testing.T) { require.False(t, hasKeys, "direct-only mode should not inject provider keys") } -func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyModeRemovesStaleProviderBackedFields(t *testing.T) { +func TestValidateSemanticCacheConfig_DirectOnlyModeRemovesStaleProviderBackedFields(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -39,18 +39,16 @@ func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyModeRemovesStaleProvider }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.NoError(t, err) configMap, ok := pluginConfig.Config.(map[string]interface{}) require.True(t, ok) - _, hasKeys := configMap["keys"] - require.False(t, hasKeys, "direct-only mode should remove stale provider keys") _, hasEmbeddingModel := configMap["embedding_model"] require.False(t, hasEmbeddingModel, "direct-only mode should remove stale embedding_model") } -func TestAddProviderKeysToSemanticCacheConfig_InjectsProviderKeys(t *testing.T) { +func TestValidateSemanticCacheConfig_ProviderBackedModeValidationPasses(t *testing.T) { config := &Config{ Providers: map[schemas.ModelProvider]configstore.ProviderConfig{ schemas.OpenAI: { @@ -73,19 +71,17 @@ func TestAddProviderKeysToSemanticCacheConfig_InjectsProviderKeys(t *testing.T) }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.NoError(t, err) configMap, ok := pluginConfig.Config.(map[string]interface{}) require.True(t, ok) - keys, ok := configMap["keys"].([]schemas.Key) - require.True(t, ok, "provider-backed mode should inject provider keys") - require.Len(t, keys, 1) - require.Equal(t, "openai-key", keys[0].Name) + _, hasKeys := configMap["keys"] + require.False(t, hasKeys, "keys are inherited from global client; they must not be injected into the plugin config") require.Equal(t, "openai", configMap["provider"]) } -func TestAddProviderKeysToSemanticCacheConfig_SemanticModeMissingProvider(t *testing.T) { +func TestValidateSemanticCacheConfig_SemanticModeMissingProvider(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -94,12 +90,12 @@ func TestAddProviderKeysToSemanticCacheConfig_SemanticModeMissingProvider(t *tes }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "requires 'provider' for semantic mode") } -func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingDimension(t *testing.T) { +func TestValidateSemanticCacheConfig_ProviderBackedModeMissingDimension(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -109,12 +105,12 @@ func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingDimension }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "requires 'dimension' for provider-backed semantic mode") } -func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeDimensionOne(t *testing.T) { +func TestValidateSemanticCacheConfig_ProviderBackedModeDimensionOne(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -125,12 +121,12 @@ func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeDimensionOne(t * }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "requires 'dimension' > 1") } -func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingEmbeddingModel(t *testing.T) { +func TestValidateSemanticCacheConfig_ProviderBackedModeMissingEmbeddingModel(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -140,12 +136,12 @@ func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingEmbedding }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "requires 'embedding_model'") } -func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionZero(t *testing.T) { +func TestValidateSemanticCacheConfig_InvalidDimensionZero(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -154,12 +150,12 @@ func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionZero(t *testing.T) }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "'dimension' must be >= 1") } -func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionNegative(t *testing.T) { +func TestValidateSemanticCacheConfig_InvalidDimensionNegative(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -168,7 +164,7 @@ func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionNegative(t *testin }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "'dimension' must be >= 1") } diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index ca5a3da674..2bbdeadad5 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -969,6 +969,10 @@ func (s *BifrostHTTPServer) ReloadPlugin(ctx context.Context, name string, path if err != nil { return s.updatePluginErrorStatus(name, "loading", err) } + // Wire the embedding executor on the new instance before syncing. + if semanticCachePlugin, ok := plugin.(*semanticcache.Plugin); ok { + semanticCachePlugin.SetEmbeddingRequestExecutor(s.Client.EmbeddingRequest) + } return s.SyncLoadedPlugin(ctx, name, plugin, placement, order) } @@ -1372,7 +1376,6 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { } wg.Wait() } - logger.Info("models added to catalog") s.Config.SetBifrostClient(s.Client) // Initialize routes @@ -1394,6 +1397,11 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { apiMiddlewares = append(apiMiddlewares, s.AuthMiddleware.APIMiddleware()) } } + // Add semantic cache plugin embedding request executor if it exists + semanticCachePlugin, err := lib.FindPluginAs[*semanticcache.Plugin](s.Config, semanticcache.PluginName) + if err == nil && semanticCachePlugin != nil { + semanticCachePlugin.SetEmbeddingRequestExecutor(s.Client.EmbeddingRequest) + } // Register routes err = s.RegisterAPIRoutes(s.Ctx, s, apiMiddlewares...) if err != nil { diff --git a/transports/config.schema.json b/transports/config.schema.json index 80134f9287..fd69a78878 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -1393,13 +1393,6 @@ "huggingface" ] }, - "keys": { - "type": "array", - "description": "API keys for the embedding provider. These are injected at runtime for config-driven setups and are not needed for direct caching with dimension: 1.", - "items": { - "type": "string" - } - }, "embedding_model": { "type": "string", "description": "Model to use for generating embeddings in provider-backed semantic caching. Required when provider is set and not allowed in direct-only mode." diff --git a/ui/app/workspace/config/views/pluginsForm.tsx b/ui/app/workspace/config/views/pluginsForm.tsx index 33477ec8a2..dcd459de4c 100644 --- a/ui/app/workspace/config/views/pluginsForm.tsx +++ b/ui/app/workspace/config/views/pluginsForm.tsx @@ -48,9 +48,6 @@ const normalizeCacheConfigForSave = (config: EditorCacheConfig) => { if (config.updated_at !== undefined) { normalized.updated_at = config.updated_at; } - if (config.keys !== undefined) { - normalized.keys = config.keys; - } const provider = config.provider?.trim(); const embeddingModel = config.embedding_model?.trim(); @@ -375,7 +372,7 @@ export default function PluginsForm({ isVectorStoreEnabled }: PluginsFormProps)
API keys for the embedding provider will be inherited from the main provider configuration. The semantic cache will use - the configured provider's keys automatically. Updates in keys will be reflected on Bifrost restart. + the configured provider's keys automatically.
diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts index d4768fe981..25d645fda4 100644 --- a/ui/lib/types/config.ts +++ b/ui/lib/types/config.ts @@ -540,13 +540,11 @@ interface BaseCacheConfig { export interface DirectCacheConfig extends BaseCacheConfig { dimension: 1; provider?: undefined; - keys?: ModelProviderKey[]; embedding_model?: undefined; } export interface ProviderBackedCacheConfig extends BaseCacheConfig { provider: ModelProviderName; - keys?: ModelProviderKey[]; embedding_model: string; dimension: number; } @@ -555,7 +553,6 @@ export type CacheConfig = DirectCacheConfig | ProviderBackedCacheConfig; export interface EditorCacheConfig extends BaseCacheConfig { provider?: ModelProviderName; - keys?: ModelProviderKey[]; embedding_model?: string; dimension?: number; } diff --git a/ui/lib/types/schemas.ts b/ui/lib/types/schemas.ts index 60ef6c2572..7a1fbc0834 100644 --- a/ui/lib/types/schemas.ts +++ b/ui/lib/types/schemas.ts @@ -4,1020 +4,1115 @@ import { z } from "zod"; // Global error map - turns Zod's default messages into readable, human-friendly ones. // Individual schemas can still override by passing their own message. z.config({ - customError: (issue) => { - if (issue.code === "invalid_type") { - // Field is missing / undefined - if (issue.input === undefined || issue.input === null) { - return "This field is required"; - } - const expected = issue.expected; - const received = typeof issue.input; - if (expected === "number") return "Must be a valid number"; - if (expected === "string") return "Must be a valid text value"; - if (expected === "boolean") return "Must be true or false"; - return `Expected ${expected}, received ${received}`; - } - if (issue.code === "too_small") { - if (issue.origin === "string" && issue.minimum === 1) { - return "This field is required"; - } - if (issue.origin === "number") { - return `Must be at least ${issue.minimum}`; - } - if (issue.origin === "array" && issue.minimum === 1) { - return "At least one item is required"; - } - } - if (issue.code === "too_big") { - if (issue.origin === "number") { - return `Must be at most ${issue.maximum}`; - } - if (issue.origin === "string") { - return `Must be at most ${issue.maximum} characters`; - } - } - if (issue.code === "invalid_format") { - if (issue.format === "url") return "Must be a valid URL"; - if (issue.format === "email") return "Must be a valid email"; - } - return undefined; // fall back to Zod default - }, + customError: (issue) => { + if (issue.code === "invalid_type") { + // Field is missing / undefined + if (issue.input === undefined || issue.input === null) { + return "This field is required"; + } + const expected = issue.expected; + const received = typeof issue.input; + if (expected === "number") return "Must be a valid number"; + if (expected === "string") return "Must be a valid text value"; + if (expected === "boolean") return "Must be true or false"; + return `Expected ${expected}, received ${received}`; + } + if (issue.code === "too_small") { + if (issue.origin === "string" && issue.minimum === 1) { + return "This field is required"; + } + if (issue.origin === "number") { + return `Must be at least ${issue.minimum}`; + } + if (issue.origin === "array" && issue.minimum === 1) { + return "At least one item is required"; + } + } + if (issue.code === "too_big") { + if (issue.origin === "number") { + return `Must be at most ${issue.maximum}`; + } + if (issue.origin === "string") { + return `Must be at most ${issue.maximum} characters`; + } + } + if (issue.code === "invalid_format") { + if (issue.format === "url") return "Must be a valid URL"; + if (issue.format === "email") return "Must be a valid email"; + } + return undefined; // fall back to Zod default + }, }); // Base Zod schemas matching the TypeScript types // Known provider schema -export const knownProviderSchema = z.enum(KnownProvidersNames as unknown as [string, ...string[]]); +export const knownProviderSchema = z.enum( + KnownProvidersNames as unknown as [string, ...string[]], +); // Custom provider name schema (branded type simulation) -export const customProviderNameSchema = z.string().min(1, "Custom provider name is required"); +export const customProviderNameSchema = z + .string() + .min(1, "Custom provider name is required"); // Model provider name schema (union of known and custom providers) -export const modelProviderNameSchema = z.union([knownProviderSchema, customProviderNameSchema]); +export const modelProviderNameSchema = z.union([ + knownProviderSchema, + customProviderNameSchema, +]); // EnvVar schema - matches the Go EnvVar type from schemas/env.go export const _envVarBase = z.object({ - value: z.string().optional(), - env_var: z.string().optional(), - from_env: z.boolean().optional(), + value: z.string().optional(), + env_var: z.string().optional(), + from_env: z.boolean().optional(), }); // Extending the base schema export const envVarSchema = Object.assign(_envVarBase, { - required: (message: string) => _envVarBase.refine((v) => !!v?.value?.trim() || !!v?.env_var?.trim(), message), + required: (message: string) => + _envVarBase.refine( + (v) => !!v?.value?.trim() || !!v?.env_var?.trim(), + message, + ), }); // Helper to check if an envVar field has a value or env reference -function isEnvVarSet(v: { value?: string; env_var?: string } | undefined): boolean { - if (!v) return false; - return !!v.value?.trim() || !!v.env_var?.trim(); +function isEnvVarSet( + v: { value?: string; env_var?: string } | undefined, +): boolean { + if (!v) return false; + return !!v.value?.trim() || !!v.env_var?.trim(); } // Azure key config schema export const azureKeyConfigSchema = z - .object({ - _auth_type: z.enum(["api_key", "entra_id", "default_credential"]).optional(), - endpoint: envVarSchema.optional(), - api_version: envVarSchema.optional(), - client_id: envVarSchema.optional(), - client_secret: envVarSchema.optional(), - tenant_id: envVarSchema.optional(), - scopes: z.array(z.string()).optional(), - }) - .refine((data) => isEnvVarSet(data.endpoint), { - message: "Endpoint is required", - path: ["endpoint"], - }) - .refine( - (data) => { - // When using Entra ID, all three fields are required - if (data._auth_type === "entra_id") { - return isEnvVarSet(data.client_id) && isEnvVarSet(data.client_secret) && isEnvVarSet(data.tenant_id); - } - // Otherwise, if any Entra ID field is set, all three must be set - const hasClientId = isEnvVarSet(data.client_id); - const hasClientSecret = isEnvVarSet(data.client_secret); - const hasTenantId = isEnvVarSet(data.tenant_id); - const anyEntraField = hasClientId || hasClientSecret || hasTenantId; - if (!anyEntraField) return true; - return hasClientId && hasClientSecret && hasTenantId; - }, - { - message: "Client ID, Client Secret, and Tenant ID are all required for Entra ID authentication", - path: ["client_id"], - }, - ); + .object({ + _auth_type: z + .enum(["api_key", "entra_id", "default_credential"]) + .optional(), + endpoint: envVarSchema.optional(), + api_version: envVarSchema.optional(), + client_id: envVarSchema.optional(), + client_secret: envVarSchema.optional(), + tenant_id: envVarSchema.optional(), + scopes: z.array(z.string()).optional(), + }) + .refine((data) => isEnvVarSet(data.endpoint), { + message: "Endpoint is required", + path: ["endpoint"], + }) + .refine( + (data) => { + // When using Entra ID, all three fields are required + if (data._auth_type === "entra_id") { + return ( + isEnvVarSet(data.client_id) && + isEnvVarSet(data.client_secret) && + isEnvVarSet(data.tenant_id) + ); + } + // Otherwise, if any Entra ID field is set, all three must be set + const hasClientId = isEnvVarSet(data.client_id); + const hasClientSecret = isEnvVarSet(data.client_secret); + const hasTenantId = isEnvVarSet(data.tenant_id); + const anyEntraField = hasClientId || hasClientSecret || hasTenantId; + if (!anyEntraField) return true; + return hasClientId && hasClientSecret && hasTenantId; + }, + { + message: + "Client ID, Client Secret, and Tenant ID are all required for Entra ID authentication", + path: ["client_id"], + }, + ); // Vertex key config schema export const vertexKeyConfigSchema = z - .object({ - _auth_type: z.enum(["service_account", "service_account_json", "api_key"]).optional(), - project_id: envVarSchema.optional(), - project_number: envVarSchema.optional(), - region: envVarSchema.optional(), - auth_credentials: envVarSchema.optional(), - }) - .refine((data) => isEnvVarSet(data.project_id), { - message: "Project ID is required", - path: ["project_id"], - }) - .refine((data) => isEnvVarSet(data.region), { - message: "Region is required", - path: ["region"], - }) - .refine( - (data) => { - // When using service_account_json auth, auth_credentials is required - if (data._auth_type === "service_account_json") { - return isEnvVarSet(data.auth_credentials); - } - return true; - }, - { - message: "Auth Credentials is required for service account JSON authentication", - path: ["auth_credentials"], - }, - ); + .object({ + _auth_type: z + .enum(["service_account", "service_account_json", "api_key"]) + .optional(), + project_id: envVarSchema.optional(), + project_number: envVarSchema.optional(), + region: envVarSchema.optional(), + auth_credentials: envVarSchema.optional(), + }) + .refine((data) => isEnvVarSet(data.project_id), { + message: "Project ID is required", + path: ["project_id"], + }) + .refine((data) => isEnvVarSet(data.region), { + message: "Region is required", + path: ["region"], + }) + .refine( + (data) => { + // When using service_account_json auth, auth_credentials is required + if (data._auth_type === "service_account_json") { + return isEnvVarSet(data.auth_credentials); + } + return true; + }, + { + message: + "Auth Credentials is required for service account JSON authentication", + path: ["auth_credentials"], + }, + ); // S3 bucket configuration for Bedrock batch operations export const s3BucketConfigSchema = z.object({ - bucket_name: z.string().min(1, "Bucket name is required"), - prefix: z.string().optional(), - is_default: z.boolean().optional(), + bucket_name: z.string().min(1, "Bucket name is required"), + prefix: z.string().optional(), + is_default: z.boolean().optional(), }); export const batchS3ConfigSchema = z.object({ - buckets: z.array(s3BucketConfigSchema).optional(), + buckets: z.array(s3BucketConfigSchema).optional(), }); // Bedrock key config schema export const bedrockKeyConfigSchema = z - .object({ - _auth_type: z.enum(["iam_role", "explicit", "api_key"]).optional(), - access_key: envVarSchema.optional(), - secret_key: envVarSchema.optional(), - session_token: envVarSchema.optional(), - region: envVarSchema.optional(), - role_arn: envVarSchema.optional(), - external_id: envVarSchema.optional(), - session_name: envVarSchema.optional(), - arn: envVarSchema.optional(), - batch_s3_config: batchS3ConfigSchema.optional(), - }) - .refine( - (data) => { - // Region is required for Bedrock - return isEnvVarSet(data.region); - }, - { - message: "Region is required", - path: ["region"], - }, - ) - .refine( - (data) => { - // When using explicit credentials, both access_key and secret_key are required - if (data._auth_type === "explicit") { - return isEnvVarSet(data.access_key) && isEnvVarSet(data.secret_key); - } - // Otherwise, if either is set both must be set - const hasAccessKey = isEnvVarSet(data.access_key); - const hasSecretKey = isEnvVarSet(data.secret_key); - if (!hasAccessKey && !hasSecretKey) return true; - return hasAccessKey && hasSecretKey; - }, - { - message: "Both Access Key and Secret Key are required for explicit credentials", - path: ["access_key"], - }, - ); + .object({ + _auth_type: z.enum(["iam_role", "explicit", "api_key"]).optional(), + access_key: envVarSchema.optional(), + secret_key: envVarSchema.optional(), + session_token: envVarSchema.optional(), + region: envVarSchema.optional(), + role_arn: envVarSchema.optional(), + external_id: envVarSchema.optional(), + session_name: envVarSchema.optional(), + arn: envVarSchema.optional(), + batch_s3_config: batchS3ConfigSchema.optional(), + }) + .refine( + (data) => { + // Region is required for Bedrock + return isEnvVarSet(data.region); + }, + { + message: "Region is required", + path: ["region"], + }, + ) + .refine( + (data) => { + // When using explicit credentials, both access_key and secret_key are required + if (data._auth_type === "explicit") { + return isEnvVarSet(data.access_key) && isEnvVarSet(data.secret_key); + } + // Otherwise, if either is set both must be set + const hasAccessKey = isEnvVarSet(data.access_key); + const hasSecretKey = isEnvVarSet(data.secret_key); + if (!hasAccessKey && !hasSecretKey) return true; + return hasAccessKey && hasSecretKey; + }, + { + message: + "Both Access Key and Secret Key are required for explicit credentials", + path: ["access_key"], + }, + ); // VLLM key config schema export const vllmKeyConfigSchema = z - .object({ - url: envVarSchema.optional(), - model_name: z.string().trim().min(1, "Model name is required"), - }) - .refine((data) => isEnvVarSet(data.url), { - message: "Server URL is required", - path: ["url"], - }); + .object({ + url: envVarSchema.optional(), + model_name: z.string().trim().min(1, "Model name is required"), + }) + .refine((data) => isEnvVarSet(data.url), { + message: "Server URL is required", + path: ["url"], + }); export const replicateKeyConfigSchema = z.object({ - use_deployments_endpoint: z.boolean(), + use_deployments_endpoint: z.boolean(), }); // Ollama key config schema export const ollamaKeyConfigSchema = z - .object({ - url: envVarSchema.optional(), - }) - .refine((data) => isEnvVarSet(data.url), { - message: "Server URL is required", - path: ["url"], - }); + .object({ + url: envVarSchema.optional(), + }) + .refine((data) => isEnvVarSet(data.url), { + message: "Server URL is required", + path: ["url"], + }); // SGL key config schema export const sglKeyConfigSchema = z - .object({ - url: envVarSchema.optional(), - }) - .refine((data) => isEnvVarSet(data.url), { - message: "Server URL is required", - path: ["url"], - }); + .object({ + url: envVarSchema.optional(), + }) + .refine((data) => isEnvVarSet(data.url), { + message: "Server URL is required", + path: ["url"], + }); // Model provider key schema export const modelProviderKeySchema = z - .object({ - id: z.string().min(1, "Id is required"), - name: z.string().min(1, "Name is required"), - value: envVarSchema.optional(), - models: z.array(z.string()).optional().default(["*"]), - blacklisted_models: z.array(z.string()).default([]).optional(), - weight: z - .union([z.number(), z.string()]) - .transform((val, ctx) => { - if (typeof val === "number") return val; - if (val.trim() === "") return 1.0; - // Use Number() rather than parseFloat() so that strings like "0.5abc" - // are rejected outright instead of silently parsing to 0.5. - const num = Number(val); - if (!Number.isFinite(num)) { - ctx.addIssue({ - code: "custom", - message: "Weight must be a valid number between 0 and 1", - }); - return z.NEVER; - } - return num; - }) - .pipe(z.number().min(0, "Weight must be equal to or greater than 0").max(1, "Weight must be equal to or less than 1")), - aliases: z.record(z.string(), z.string()).optional(), - azure_key_config: azureKeyConfigSchema.optional(), - vertex_key_config: vertexKeyConfigSchema.optional(), - bedrock_key_config: bedrockKeyConfigSchema.optional(), - vllm_key_config: vllmKeyConfigSchema.optional(), - replicate_key_config: replicateKeyConfigSchema.optional(), - ollama_key_config: ollamaKeyConfigSchema.optional(), - sgl_key_config: sglKeyConfigSchema.optional(), - use_for_batch_api: z.boolean().optional(), - enabled: z.boolean().optional(), - }) - .refine( - (data) => { - // Providers with dedicated config that never need a top-level API key - if (data.vllm_key_config || data.replicate_key_config || data.ollama_key_config || data.sgl_key_config) { - return true; - } - // Azure requires API key only when using api_key auth - if (data.azure_key_config) { - if (data.azure_key_config._auth_type === "api_key") { - return isEnvVarSet(data.value); - } - return true; - } - // Bedrock only requires API key when using api_key auth - if (data.bedrock_key_config) { - if (data.bedrock_key_config._auth_type === "api_key") { - return isEnvVarSet(data.value); - } - return true; - } - // Vertex requires API key only when using api_key auth - if (data.vertex_key_config) { - if (data.vertex_key_config._auth_type === "api_key") { - return isEnvVarSet(data.value); - } - return true; - } - // Otherwise, value is required - return isEnvVarSet(data.value); - }, - { - message: "API Key is required", - path: ["value"], - }, - ); + .object({ + id: z.string().min(1, "Id is required"), + name: z.string().min(1, "Name is required"), + value: envVarSchema.optional(), + models: z.array(z.string()).optional().default(["*"]), + blacklisted_models: z.array(z.string()).default([]).optional(), + weight: z + .union([z.number(), z.string()]) + .transform((val, ctx) => { + if (typeof val === "number") return val; + if (val.trim() === "") return 1.0; + // Use Number() rather than parseFloat() so that strings like "0.5abc" + // are rejected outright instead of silently parsing to 0.5. + const num = Number(val); + if (!Number.isFinite(num)) { + ctx.addIssue({ + code: "custom", + message: "Weight must be a valid number between 0 and 1", + }); + return z.NEVER; + } + return num; + }) + .pipe( + z + .number() + .min(0, "Weight must be equal to or greater than 0") + .max(1, "Weight must be equal to or less than 1"), + ), + aliases: z.record(z.string(), z.string()).optional(), + azure_key_config: azureKeyConfigSchema.optional(), + vertex_key_config: vertexKeyConfigSchema.optional(), + bedrock_key_config: bedrockKeyConfigSchema.optional(), + vllm_key_config: vllmKeyConfigSchema.optional(), + replicate_key_config: replicateKeyConfigSchema.optional(), + ollama_key_config: ollamaKeyConfigSchema.optional(), + sgl_key_config: sglKeyConfigSchema.optional(), + use_for_batch_api: z.boolean().optional(), + enabled: z.boolean().optional(), + }) + .refine( + (data) => { + // Providers with dedicated config that never need a top-level API key + if ( + data.vllm_key_config || + data.replicate_key_config || + data.ollama_key_config || + data.sgl_key_config + ) { + return true; + } + // Azure requires API key only when using api_key auth + if (data.azure_key_config) { + if (data.azure_key_config._auth_type === "api_key") { + return isEnvVarSet(data.value); + } + return true; + } + // Bedrock only requires API key when using api_key auth + if (data.bedrock_key_config) { + if (data.bedrock_key_config._auth_type === "api_key") { + return isEnvVarSet(data.value); + } + return true; + } + // Vertex requires API key only when using api_key auth + if (data.vertex_key_config) { + if (data.vertex_key_config._auth_type === "api_key") { + return isEnvVarSet(data.value); + } + return true; + } + // Otherwise, value is required + return isEnvVarSet(data.value); + }, + { + message: "API Key is required", + path: ["value"], + }, + ); // Network config schema export const networkConfigSchema = z - .object({ - base_url: z.union([z.string().url("Must be a valid URL"), z.string().length(0)]).optional(), - extra_headers: z.record(z.string(), z.string()).optional(), - default_request_timeout_in_seconds: z - .number() - .min(1, "Timeout must be greater than 0 seconds") - .max(3600, "Timeout must be less than 3600 seconds"), - max_retries: z.number().min(0, "Max retries must be greater than 0").max(10, "Max retries must be less than 10"), - retry_backoff_initial: z.number().min(100), - retry_backoff_max: z.number().min(100), - insecure_skip_verify: z.boolean().optional(), - ca_cert_pem: envVarSchema.optional(), - stream_idle_timeout_in_seconds: z - .number() - .int("Stream idle timeout must be a whole number of seconds") - .min(5, "Stream idle timeout must be at least 5 seconds") - .max(3600, "Stream idle timeout must be at most 3600 seconds i.e. 60 minutes") - .optional(), - max_conns_per_host: z - .number() - .int("Max connections must be a whole number") - .min(1, "Max connections must be at least 1") - .max(10000, "Max connections must be at most 10000") - .optional(), - enforce_http2: z.boolean().optional(), - }) - .refine((d) => d.retry_backoff_initial <= d.retry_backoff_max, { - message: "retry_backoff_initial must be <= retry_backoff_max", - path: ["retry_backoff_initial"], - }); + .object({ + base_url: z + .union([z.string().url("Must be a valid URL"), z.string().length(0)]) + .optional(), + extra_headers: z.record(z.string(), z.string()).optional(), + default_request_timeout_in_seconds: z + .number() + .min(1, "Timeout must be greater than 0 seconds") + .max(3600, "Timeout must be less than 3600 seconds"), + max_retries: z + .number() + .min(0, "Max retries must be greater than 0") + .max(10, "Max retries must be less than 10"), + retry_backoff_initial: z.number().min(100), + retry_backoff_max: z.number().min(100), + insecure_skip_verify: z.boolean().optional(), + ca_cert_pem: envVarSchema.optional(), + stream_idle_timeout_in_seconds: z + .number() + .int("Stream idle timeout must be a whole number of seconds") + .min(5, "Stream idle timeout must be at least 5 seconds") + .max( + 3600, + "Stream idle timeout must be at most 3600 seconds i.e. 60 minutes", + ) + .optional(), + max_conns_per_host: z + .number() + .int("Max connections must be a whole number") + .min(1, "Max connections must be at least 1") + .max(10000, "Max connections must be at most 10000") + .optional(), + enforce_http2: z.boolean().optional(), + }) + .refine((d) => d.retry_backoff_initial <= d.retry_backoff_max, { + message: "retry_backoff_initial must be <= retry_backoff_max", + path: ["retry_backoff_initial"], + }); // Network form schema - more lenient for form inputs export const networkFormConfigSchema = z - .object({ - base_url: z - .union([ - z - .string() - .url("Must be a valid URL") - .refine((url) => url.startsWith("https://") || url.startsWith("http://"), { - message: "Must be a valid HTTP or HTTPS URL", - }), - z.string().length(0), - ]) - .optional(), - extra_headers: z.record(z.string(), z.string()).optional(), - default_request_timeout_in_seconds: z.coerce - .number("Timeout must be a number") - .min(1, "Timeout must be greater than 0 seconds") - .max(172800, "Timeout must be less than 172800 seconds i.e. 48 hours"), - max_retries: z.coerce - .number("Max retries must be a number") - .min(0, "Max retries must be greater than 0") - .max(10, "Max retries must be less than 10"), - retry_backoff_initial: z.coerce - .number("Retry backoff initial must be a number") - .min(100, "Retry backoff initial must be at least 100ms") - .max(1000000, "Retry backoff initial must be at most 1000000ms"), - retry_backoff_max: z.coerce - .number("Retry backoff max must be a number") - .min(100, "Retry backoff max must be at least 100ms") - .max(1000000, "Retry backoff max must be at most 1000000ms"), - insecure_skip_verify: z.boolean().optional(), - ca_cert_pem: envVarSchema.optional(), - stream_idle_timeout_in_seconds: z.coerce - .number("Stream idle timeout must be a number") - .int("Stream idle timeout must be a whole number of seconds") - .min(5, "Stream idle timeout must be at least 5 seconds") - .max(3600, "Stream idle timeout must be at most 3600 seconds i.e. 60 minutes") - .optional(), - max_conns_per_host: z.coerce - .number("Max connections must be a number") - .int("Max connections must be a whole number") - .min(1, "Max connections must be at least 1") - .max(10000, "Max connections must be at most 10000") - .optional(), - enforce_http2: z.boolean().optional(), - }) - .refine((d) => d.retry_backoff_initial <= d.retry_backoff_max, { - message: "Initial backoff must be less than or equal to max backoff", - path: ["retry_backoff_initial"], - }); + .object({ + base_url: z + .union([ + z + .string() + .url("Must be a valid URL") + .refine( + (url) => url.startsWith("https://") || url.startsWith("http://"), + { + message: "Must be a valid HTTP or HTTPS URL", + }, + ), + z.string().length(0), + ]) + .optional(), + extra_headers: z.record(z.string(), z.string()).optional(), + default_request_timeout_in_seconds: z.coerce + .number("Timeout must be a number") + .min(1, "Timeout must be greater than 0 seconds") + .max(172800, "Timeout must be less than 172800 seconds i.e. 48 hours"), + max_retries: z.coerce + .number("Max retries must be a number") + .min(0, "Max retries must be greater than 0") + .max(10, "Max retries must be less than 10"), + retry_backoff_initial: z.coerce + .number("Retry backoff initial must be a number") + .min(100, "Retry backoff initial must be at least 100ms") + .max(1000000, "Retry backoff initial must be at most 1000000ms"), + retry_backoff_max: z.coerce + .number("Retry backoff max must be a number") + .min(100, "Retry backoff max must be at least 100ms") + .max(1000000, "Retry backoff max must be at most 1000000ms"), + insecure_skip_verify: z.boolean().optional(), + ca_cert_pem: envVarSchema.optional(), + stream_idle_timeout_in_seconds: z.coerce + .number("Stream idle timeout must be a number") + .int("Stream idle timeout must be a whole number of seconds") + .min(5, "Stream idle timeout must be at least 5 seconds") + .max( + 3600, + "Stream idle timeout must be at most 3600 seconds i.e. 60 minutes", + ) + .optional(), + max_conns_per_host: z.coerce + .number("Max connections must be a number") + .int("Max connections must be a whole number") + .min(1, "Max connections must be at least 1") + .max(10000, "Max connections must be at most 10000") + .optional(), + enforce_http2: z.boolean().optional(), + }) + .refine((d) => d.retry_backoff_initial <= d.retry_backoff_max, { + message: "Initial backoff must be less than or equal to max backoff", + path: ["retry_backoff_initial"], + }); // Concurrency and buffer size schema export const concurrencyAndBufferSizeSchema = z.object({ - concurrency: z.number().min(1, "Concurrency must be greater than 0").max(100, "Concurrency must be less than or equal to 100"), - buffer_size: z.number().min(1, "Buffer size must be greater than 0").max(1000, "Buffer size must be less than or equal to 1000"), + concurrency: z + .number() + .min(1, "Concurrency must be greater than 0") + .max(100, "Concurrency must be less than or equal to 100"), + buffer_size: z + .number() + .min(1, "Buffer size must be greater than 0") + .max(1000, "Buffer size must be less than or equal to 1000"), }); // Proxy type schema -export const proxyTypeSchema = z.enum(["none", "http", "socks5", "environment"]); +export const proxyTypeSchema = z.enum([ + "none", + "http", + "socks5", + "environment", +]); // Proxy config schema export const proxyConfigSchema = z - .object({ - type: proxyTypeSchema, - url: envVarSchema.optional(), - username: envVarSchema.optional(), - password: envVarSchema.optional(), - ca_cert_pem: envVarSchema.optional(), - }) - .refine( - (data) => - !(data.type === "http" || data.type === "socks5") || - data.url?.from_env === true || - (data.url?.value && data.url.value.trim().length > 0), - { - message: "Proxy URL is required when using HTTP or SOCKS5 proxy", - path: ["url"], - }, - ) - .refine( - (data) => { - if ((data.type === "http" || data.type === "socks5") && data.url?.value?.trim()) { - if (data.url.from_env || data.url.env_var?.startsWith("env.")) { - return true; - } - try { - new URL(data.url.value); - return true; - } catch { - return false; - } - } - return true; - }, - { message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", path: ["url"] }, - ); + .object({ + type: proxyTypeSchema, + url: envVarSchema.optional(), + username: envVarSchema.optional(), + password: envVarSchema.optional(), + ca_cert_pem: envVarSchema.optional(), + }) + .refine( + (data) => + !(data.type === "http" || data.type === "socks5") || + data.url?.from_env === true || + (data.url?.value && data.url.value.trim().length > 0), + { + message: "Proxy URL is required when using HTTP or SOCKS5 proxy", + path: ["url"], + }, + ) + .refine( + (data) => { + if ( + (data.type === "http" || data.type === "socks5") && + data.url?.value?.trim() + ) { + if (data.url.from_env || data.url.env_var?.startsWith("env.")) { + return true; + } + try { + new URL(data.url.value); + return true; + } catch { + return false; + } + } + return true; + }, + { + message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", + path: ["url"], + }, + ); // Proxy form schema - more lenient for form inputs with conditional validation export const proxyFormConfigSchema = z - .object({ - type: proxyTypeSchema, - url: envVarSchema.optional(), - username: envVarSchema.optional(), - password: envVarSchema.optional(), - ca_cert_pem: envVarSchema.optional(), - }) - .refine( - (data) => { - if (data.type === "none") { - return true; - } - // URL is required when proxy type is http or socks5 - if (data.type === "http" || data.type === "socks5") { - // Env-backed URLs may have empty resolved value before env resolution. - if (data.url?.from_env || data.url?.env_var?.startsWith("env.")) return true; - // Literal URLs must be non-empty. - if (!data.url?.value || data.url.value.trim().length === 0) return false; - } - return true; - }, - { - message: "Proxy URL is required when using HTTP or SOCKS5 proxy", - path: ["url"], - }, - ) - .refine( - (data) => { - // URL must be valid format when provided and proxy type requires it - if ((data.type === "http" || data.type === "socks5") && data.url?.value && data.url.value.trim().length > 0) { - if (data.url.from_env || data.url.env_var?.startsWith("env.")) { - return true; - } - try { - new URL(data.url.value); - return true; - } catch { - return false; - } - } - return true; - }, - { - message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", - path: ["url"], - }, - ); + .object({ + type: proxyTypeSchema, + url: envVarSchema.optional(), + username: envVarSchema.optional(), + password: envVarSchema.optional(), + ca_cert_pem: envVarSchema.optional(), + }) + .refine( + (data) => { + if (data.type === "none") { + return true; + } + // URL is required when proxy type is http or socks5 + if (data.type === "http" || data.type === "socks5") { + // Env-backed URLs may have empty resolved value before env resolution. + if (data.url?.from_env || data.url?.env_var?.startsWith("env.")) + return true; + // Literal URLs must be non-empty. + if (!data.url?.value || data.url.value.trim().length === 0) + return false; + } + return true; + }, + { + message: "Proxy URL is required when using HTTP or SOCKS5 proxy", + path: ["url"], + }, + ) + .refine( + (data) => { + // URL must be valid format when provided and proxy type requires it + if ( + (data.type === "http" || data.type === "socks5") && + data.url?.value && + data.url.value.trim().length > 0 + ) { + if (data.url.from_env || data.url.env_var?.startsWith("env.")) { + return true; + } + try { + new URL(data.url.value); + return true; + } catch { + return false; + } + } + return true; + }, + { + message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", + path: ["url"], + }, + ); // OpenAI Config tab export const openaiConfigFormSchema = z.object({ - disable_store: z.boolean(), + disable_store: z.boolean(), }); export type OpenAIConfigFormSchema = z.infer