diff --git a/docs/features/semantic-caching.mdx b/docs/features/semantic-caching.mdx index d63211cfab..f25747c720 100644 --- a/docs/features/semantic-caching.mdx +++ b/docs/features/semantic-caching.mdx @@ -118,7 +118,6 @@ import ( cacheConfig := &semanticcache.Config{ // Embedding model configuration (Required) Provider: schemas.OpenAI, - Keys: []schemas.Key{{Value: "sk-..."}}, EmbeddingModel: "text-embedding-3-small", Dimension: 1536, @@ -155,22 +154,32 @@ bifrostConfig := schemas.BifrostConfig{ ![Semantic Cache Plugin Configuration](../media/ui-semantic-cache-config.png) -**Note**: Make sure you have a vector store setup (using `config.json`) before configuring the semantic cache plugin. +**Prerequisites**: A vector store must be configured and enabled in `config.json`, and at least one provider must be configured, before the toggle becomes available. -1. **Navigate to Settings** - - Open Bifrost UI at `http://localhost:8080` - - Go to Settings. +1. **Navigate to the Config page** in the Bifrost UI and find the **Plugins** section. -2. **Configure Semantic Cache Plugin** +2. **Toggle** the **Enable Semantic Caching** switch to enable it. The configuration form expands below. -- Toggle the plugin switch to enable it, and fill in the required fields. +3. **Fill in the fields** across the four sections: -**Required Fields:** -- **Provider**: The provider to use for caching. -- **Embedding Model**: The embedding model to use for caching. -- **Dimension**: The embedding dimension for the configured embedding model. +**Provider and Model Settings** (required for semantic mode): +- **Configured Providers**: Dropdown of providers already set up in Bifrost. The selected provider's API keys are inherited automatically. +- **Embedding Model**: The embedding model to use (e.g. `text-embedding-3-small`). -**Note**: Changes will need a restart of the Bifrost server to take effect, because the plugin is loaded on startup only. +**Cache Settings**: +- **TTL (seconds)**: How long cached responses are kept (default: 300 s). +- **Similarity Threshold**: Cosine similarity cutoff for a cache hit (0–1, default: 0.8). +- **Dimension**: Vector dimension matching your embedding model (e.g. 1536 for `text-embedding-3-small`). + +**Conversation Settings**: +- **Conversation History Threshold**: Skip caching when the conversation has more than this many messages (default: 3). +- **Exclude System Prompt** (toggle): Exclude system messages from cache-key generation. + +**Cache Behavior**: +- **Cache by Model** (toggle): Include the model name in the cache key (default: on). +- **Cache by Provider** (toggle): Include the provider name in the cache key (default: on). + +4. Click **Save**. Changes are persisted and applied immediately for enabled plugins via the API reload path; other plugin changes (e.g. via `config.json`) may still require a restart. @@ -202,7 +211,7 @@ bifrostConfig := schemas.BifrostConfig{ } ``` -> **Note**: In `config.json` setups, provider keys are taken from the provider config on initialization, so you do not need to duplicate `keys` inside the plugin config. Any updates to the provider keys will not be reflected until next restart. +> **Note**: Provider API keys are inherited automatically from the global provider configuration. You do not need to (and cannot) specify keys inside the plugin config. **TTL Format Options:** - Duration strings: `"30s"`, `"5m"`, `"1h"`, `"24h"` @@ -228,7 +237,7 @@ Exact-match direct entries are stored and retrieved using a deterministic cache ### Setup -To enable direct-only mode globally, set `dimension: 1` and omit the `provider` and `keys` fields from the plugin config. The plugin will automatically fall back to direct search only. +To enable direct-only mode globally, set `dimension: 1` and omit the `provider` and `embedding_model` fields from the plugin config. The plugin will automatically fall back to direct search only. > **Important**: If you specify `dimension: 1` and also provide a `provider`, Bifrost treats the config as provider-backed semantic mode, not direct-only mode. To use direct-only mode, omit the `provider` field entirely. @@ -246,7 +255,7 @@ import ( ) cacheConfig := &semanticcache.Config{ - // No Provider, Keys, or EmbeddingModel -- direct hash mode only + // No Provider or EmbeddingModel -- direct hash mode only Dimension: 1, // Placeholder; entries are stored as metadata-only (no embedding vectors). Change dimension before switching to dual-layer mode to avoid mixed-dimension issues. TTL: 5 * time.Minute, diff --git a/plugins/semanticcache/main.go b/plugins/semanticcache/main.go index 148bbfef1a..c065ceff35 100644 --- a/plugins/semanticcache/main.go +++ b/plugins/semanticcache/main.go @@ -15,7 +15,6 @@ import ( bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/framework" "github.com/maximhq/bifrost/framework/vectorstore" ) @@ -25,7 +24,6 @@ import ( type Config struct { // Embedding Model settings - REQUIRED for semantic caching Provider schemas.ModelProvider `json:"provider"` - Keys []schemas.Key `json:"keys"` EmbeddingModel string `json:"embedding_model,omitempty"` // Model to use for generating embeddings (optional) // Plugin behavior settings @@ -48,19 +46,18 @@ type Config struct { func (c *Config) UnmarshalJSON(data []byte) error { // Define a temporary struct to avoid infinite recursion type TempConfig struct { - Provider string `json:"provider"` - Keys []schemas.Key `json:"keys"` - EmbeddingModel string `json:"embedding_model,omitempty"` - CleanUpOnShutdown bool `json:"cleanup_on_shutdown,omitempty"` - Dimension int `json:"dimension"` - TTL interface{} `json:"ttl,omitempty"` - Threshold float64 `json:"threshold,omitempty"` - VectorStoreNamespace string `json:"vector_store_namespace,omitempty"` - DefaultCacheKey string `json:"default_cache_key,omitempty"` - ConversationHistoryThreshold int `json:"conversation_history_threshold,omitempty"` - CacheByModel *bool `json:"cache_by_model,omitempty"` - CacheByProvider *bool `json:"cache_by_provider,omitempty"` - ExcludeSystemPrompt *bool `json:"exclude_system_prompt,omitempty"` + Provider string `json:"provider"` + EmbeddingModel string `json:"embedding_model,omitempty"` + CleanUpOnShutdown bool `json:"cleanup_on_shutdown,omitempty"` + Dimension int `json:"dimension"` + TTL interface{} `json:"ttl,omitempty"` + Threshold float64 `json:"threshold,omitempty"` + VectorStoreNamespace string `json:"vector_store_namespace,omitempty"` + DefaultCacheKey string `json:"default_cache_key,omitempty"` + ConversationHistoryThreshold int `json:"conversation_history_threshold,omitempty"` + CacheByModel *bool `json:"cache_by_model,omitempty"` + CacheByProvider *bool `json:"cache_by_provider,omitempty"` + ExcludeSystemPrompt *bool `json:"exclude_system_prompt,omitempty"` } var temp TempConfig @@ -70,7 +67,6 @@ func (c *Config) UnmarshalJSON(data []byte) error { // Set simple fields c.Provider = schemas.ModelProvider(temp.Provider) - c.Keys = temp.Keys c.EmbeddingModel = temp.EmbeddingModel c.CleanUpOnShutdown = temp.CleanUpOnShutdown c.Dimension = temp.Dimension @@ -129,6 +125,10 @@ type StreamAccumulator struct { mu sync.Mutex // Protects chunk operations } +// EmbeddingRequestExecutor is a function that executes a request and returns a response and an error. +// It maps to .EmbeddingRequest() of the bifrost client. +type EmbeddingRequestExecutor func(ctx *schemas.BifrostContext, req *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) + // Plugin implements the schemas.LLMPlugin interface for semantic caching. // It caches responses using a two-tier approach: direct hash matching for exact requests // and semantic similarity search for related content. The plugin supports configurable caching behavior @@ -139,12 +139,12 @@ type StreamAccumulator struct { // - config: Plugin configuration including semantic cache and caching settings // - logger: Logger instance for plugin operations type Plugin struct { - store vectorstore.VectorStore - config *Config - logger schemas.Logger - client *bifrost.Bifrost - streamAccumulators sync.Map // Track stream accumulators by request ID - waitGroup sync.WaitGroup + store vectorstore.VectorStore + config *Config + logger schemas.Logger + embeddingRequestExecutor EmbeddingRequestExecutor + streamAccumulators sync.Map // Track stream accumulators by request ID + waitGroup sync.WaitGroup } // Plugin constants @@ -201,45 +201,6 @@ var VectorStoreProperties = map[string]vectorstore.VectorStoreProperties{ }, } -type PluginAccount struct { - provider schemas.ModelProvider - keys []schemas.Key -} - -func (pa *PluginAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { - return []schemas.ModelProvider{pa.provider}, nil -} - -func (pa *PluginAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { - return pa.keys, nil -} - -func (pa *PluginAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { - return &schemas.ProviderConfig{ - NetworkConfig: schemas.DefaultNetworkConfig, - ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, - }, nil -} - -// Dependencies is a list of dependencies that the plugin requires. -var Dependencies []framework.FrameworkDependency = []framework.FrameworkDependency{framework.FrameworkDependencyVectorStore} - -// ProvidersWithEmbeddingSupport lists all providers that support embedding operations. -// Providers not in this list will return UnsupportedOperationError for embedding requests. -var ProvidersWithEmbeddingSupport = map[schemas.ModelProvider]bool{ - schemas.OpenAI: true, - schemas.Azure: true, - schemas.Bedrock: true, - schemas.Cohere: true, - schemas.Gemini: true, - schemas.Vertex: true, - schemas.Mistral: true, - schemas.Ollama: true, - schemas.Nebius: true, - schemas.HuggingFace: true, - schemas.SGL: true, -} - const ( CacheKey schemas.BifrostContextKey = "semantic_cache_key" // To set the cache key for a request - REQUIRED for all requests CacheTTLKey schemas.BifrostContextKey = "semantic_cache_ttl" // To explicitly set the TTL for a request @@ -323,26 +284,8 @@ func Init(ctx context.Context, config *Config, logger schemas.Logger, store vect if config.Provider == "" && config.Dimension == 1 { logger.Info(PluginLoggerPrefix + " Starting in direct-only mode (dimension=1, no embedding provider)") - } else if config.Provider == "" || len(config.Keys) == 0 { - logger.Warn(PluginLoggerPrefix + " Incomplete semantic mode config: missing provider or keys, falling back to direct search only") - } else { - // Validate that the provider supports embeddings - if bifrost.IsStandardProvider(config.Provider) && !ProvidersWithEmbeddingSupport[config.Provider] { - return nil, fmt.Errorf("provider '%s' does not support embedding operations required for semantic cache. Supported providers: openai, azure, bedrock, cohere, gemini, vertex, mistral, ollama, nebius, huggingface, sgl. Note: custom providers based on embedding-capable providers are also supported", config.Provider) - } - - bifrost, err := bifrost.Init(ctx, schemas.BifrostConfig{ - Logger: logger, - Account: &PluginAccount{ - provider: config.Provider, - keys: config.Keys, - }, - }) - if err != nil { - return nil, fmt.Errorf("failed to initialize bifrost for semantic cache: %w", err) - } - - plugin.client = bifrost + } else if config.Provider == "" { + logger.Warn(PluginLoggerPrefix + " Incomplete semantic mode config: missing provider, falling back to direct search only") } createCtx, cancel := context.WithTimeout(ctx, CreateNamespaceTimeout) @@ -378,19 +321,6 @@ func (plugin *Plugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, return chunk, nil } -func (plugin *Plugin) clearRequestScopedContext(ctx *schemas.BifrostContext) { - ctx.ClearValue(requestIDKey) - ctx.ClearValue(requestStorageIDKey) - ctx.ClearValue(requestHashKey) - ctx.ClearValue(requestParamsHashKey) - ctx.ClearValue(requestModelKey) - ctx.ClearValue(requestProviderKey) - ctx.ClearValue(requestEmbeddingKey) - ctx.ClearValue(requestEmbeddingTokensKey) - ctx.ClearValue(isCacheHitKey) - ctx.ClearValue(cacheHitTypeKey) -} - // PreLLMHook is called before a request is processed by Bifrost. // It performs a two-stage cache lookup: first direct hash matching, then semantic similarity search. // Uses UUID-based keys for entries stored in the VectorStore. @@ -465,7 +395,7 @@ func (plugin *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifro } } - if performSemanticSearch && plugin.client != nil { + if performSemanticSearch && plugin.embeddingRequestExecutor != nil { if req.EmbeddingRequest != nil || req.TranscriptionRequest != nil { plugin.logger.Debug(PluginLoggerPrefix + " Skipping semantic search for embedding/transcription input") // For vector stores that require vectors, set a zero vector placeholder @@ -488,7 +418,7 @@ func (plugin *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifro if shortCircuit != nil { return req, shortCircuit, nil } - } else if !performSemanticSearch && plugin.store.RequiresVectors() && plugin.client != nil { + } else if !performSemanticSearch && plugin.store.RequiresVectors() && plugin.embeddingRequestExecutor != nil { // Vector store requires vectors but we're in direct-only mode // Generate embeddings for storage purposes (not for searching) if req.EmbeddingRequest != nil || req.TranscriptionRequest != nil { @@ -759,11 +689,6 @@ func (plugin *Plugin) Cleanup() error { // Clean up old stream accumulators first plugin.cleanupOldStreamAccumulators() - // Shutdown the internal Bifrost client used for embeddings - if plugin.client != nil { - plugin.client.Shutdown() - } - // Only clean up cache entries if configured to do so if !plugin.config.CleanUpOnShutdown { plugin.logger.Debug(PluginLoggerPrefix + " Cleanup on shutdown is disabled, skipping cache cleanup") @@ -804,6 +729,15 @@ func (plugin *Plugin) Cleanup() error { return nil } +// SetEmbeddingRequestExecutor sets the embedding request executor for the plugin. +// Needs to be set before the plugin is used. +// +// Parameters: +// - executor: The embedding request executor to set +func (plugin *Plugin) SetEmbeddingRequestExecutor(executor EmbeddingRequestExecutor) { + plugin.embeddingRequestExecutor = executor +} + // Public Methods for External Use // ClearCacheForKey deletes cache entries for a specific cache key. @@ -869,3 +803,16 @@ func (plugin *Plugin) ClearCacheForRequestID(requestID string) error { return nil } + +func (plugin *Plugin) clearRequestScopedContext(ctx *schemas.BifrostContext) { + ctx.ClearValue(requestIDKey) + ctx.ClearValue(requestStorageIDKey) + ctx.ClearValue(requestHashKey) + ctx.ClearValue(requestParamsHashKey) + ctx.ClearValue(requestModelKey) + ctx.ClearValue(requestProviderKey) + ctx.ClearValue(requestEmbeddingKey) + ctx.ClearValue(requestEmbeddingTokensKey) + ctx.ClearValue(isCacheHitKey) + ctx.ClearValue(cacheHitTypeKey) +} diff --git a/plugins/semanticcache/plugin_core_test.go b/plugins/semanticcache/plugin_core_test.go index 822fc1f645..5bed26528d 100644 --- a/plugins/semanticcache/plugin_core_test.go +++ b/plugins/semanticcache/plugin_core_test.go @@ -2,7 +2,6 @@ package semanticcache import ( "context" - "strings" "testing" "time" @@ -389,9 +388,6 @@ func TestCacheConfiguration(t *testing.T) { EmbeddingModel: "text-embedding-3-small", Dimension: 1536, Threshold: 0.95, // Very high threshold - Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0}, - }, }, expectedBehavior: "strict_matching", }, @@ -402,9 +398,6 @@ func TestCacheConfiguration(t *testing.T) { EmbeddingModel: "text-embedding-3-small", Dimension: 1536, Threshold: 0.1, // Very low threshold - Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0}, - }, }, expectedBehavior: "loose_matching", }, @@ -416,9 +409,6 @@ func TestCacheConfiguration(t *testing.T) { Dimension: 1536, Threshold: 0.8, TTL: 1 * time.Hour, // Custom TTL - Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0}, - }, }, expectedBehavior: "custom_ttl", }, @@ -463,7 +453,7 @@ func (m *MockUnsupportedStore) Ping(ctx context.Context) error { } func (m *MockUnsupportedStore) CreateNamespace(ctx context.Context, namespace string, dimension int, properties map[string]vectorstore.VectorStoreProperties) error { - return vectorstore.ErrNotSupported + return nil } func (m *MockUnsupportedStore) DeleteNamespace(ctx context.Context, namespace string) error { @@ -547,23 +537,13 @@ func TestInvalidProviderRejection(t *testing.T) { Dimension: 1536, Threshold: 0.8, CleanUpOnShutdown: false, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.TEST_API_KEY"), - Models: schemas.WhiteList{"*"}, - Weight: 1.0, - }, - }, } + // Provider validation was moved to request time (global client handles it). + // Init itself should succeed regardless of the provider set in config. _, err := Init(ctx, config, logger, mockStore) - if err == nil { - t.Errorf("Expected error for provider '%s' but got none", provider) - } - - expectedErrSubstring := "does not support embedding operations" - if err != nil && !strings.Contains(err.Error(), expectedErrSubstring) { - t.Errorf("Expected error message to contain '%s', but got: %v", expectedErrSubstring, err) + if err != nil { + t.Errorf("Init should succeed for provider '%s' (validation happens at request time), but got: %v", provider, err) } }) } @@ -584,18 +564,11 @@ func TestValidProviderAccepted(t *testing.T) { Dimension: 1536, Threshold: 0.8, CleanUpOnShutdown: false, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: schemas.WhiteList{"*"}, - Weight: 1.0, - }, - }, } - // Should fail due to namespace creation, not provider validation + // Init should succeed; provider validation happens at request time via the global client. _, err := Init(ctx, config, logger, mockStore) - if err != nil && strings.Contains(err.Error(), "does not support embedding operations") { - t.Errorf("Valid provider OpenAI should not be rejected for embedding support, but got: %v", err) + if err != nil { + t.Errorf("Valid provider OpenAI should be accepted at Init, but got: %v", err) } } diff --git a/plugins/semanticcache/plugin_image_generation_test.go b/plugins/semanticcache/plugin_image_generation_test.go index f50f3c5c9b..a65c06e81b 100644 --- a/plugins/semanticcache/plugin_image_generation_test.go +++ b/plugins/semanticcache/plugin_image_generation_test.go @@ -128,9 +128,6 @@ func TestImageGenerationSemanticSearch(t *testing.T) { EmbeddingModel: "text-embedding-3-small", Dimension: 1536, Threshold: 0.5, - Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: []string{"*"}, Weight: 1.0}, - }, } setup := NewTestSetupWithConfig(t, config) defer setup.Cleanup() diff --git a/plugins/semanticcache/plugin_vectorstore_test.go b/plugins/semanticcache/plugin_vectorstore_test.go index 5e390bbe80..f4ac8130f2 100644 --- a/plugins/semanticcache/plugin_vectorstore_test.go +++ b/plugins/semanticcache/plugin_vectorstore_test.go @@ -55,13 +55,6 @@ func getDefaultTestConfig() *Config { Dimension: 1536, Threshold: 0.8, CleanUpOnShutdown: true, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: schemas.WhiteList{"*"}, - Weight: 1.0, - }, - }, } } diff --git a/plugins/semanticcache/test_utils.go b/plugins/semanticcache/test_utils.go index d267254473..e9b847c6dc 100644 --- a/plugins/semanticcache/test_utils.go +++ b/plugins/semanticcache/test_utils.go @@ -371,13 +371,6 @@ func NewTestSetup(t *testing.T) *TestSetup { Dimension: 1536, Threshold: 0.8, CleanUpOnShutdown: true, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: schemas.WhiteList{"*"}, - Weight: 1.0, - }, - }, }) } @@ -427,6 +420,9 @@ func NewTestSetupWithVectorStore(t *testing.T, config *Config, storeType vectors // Get a mocked Bifrost client client := getMockedBifrostClient(t, ctx, logger, plugin) + // Wire the global client as the embedding executor so semantic search works. + pluginImpl.SetEmbeddingRequestExecutor(client.EmbeddingRequest) + return &TestSetup{ Logger: logger, Store: store, @@ -648,13 +644,6 @@ func CreateTestSetupWithConversationThreshold(t *testing.T, threshold int) *Test CleanUpOnShutdown: true, Threshold: 0.8, ConversationHistoryThreshold: threshold, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{"*"}, - Weight: 1.0, - }, - }, } return NewTestSetupWithConfig(t, config) @@ -669,13 +658,6 @@ func CreateTestSetupWithExcludeSystemPrompt(t *testing.T, excludeSystem bool) *T CleanUpOnShutdown: true, Threshold: 0.8, ExcludeSystemPrompt: &excludeSystem, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{"*"}, - Weight: 1.0, - }, - }, } return NewTestSetupWithConfig(t, config) @@ -691,13 +673,6 @@ func CreateTestSetupWithThresholdAndExcludeSystem(t *testing.T, threshold int, e Threshold: 0.8, ConversationHistoryThreshold: threshold, ExcludeSystemPrompt: &excludeSystem, - Keys: []schemas.Key{ - { - Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{"*"}, - Weight: 1.0, - }, - }, } return NewTestSetupWithConfig(t, config) diff --git a/plugins/semanticcache/utils.go b/plugins/semanticcache/utils.go index 712030051a..957115ee24 100644 --- a/plugins/semanticcache/utils.go +++ b/plugins/semanticcache/utils.go @@ -67,8 +67,16 @@ func (plugin *Plugin) generateEmbedding(ctx *schemas.BifrostContext, text string }, } - // Generate embedding using bifrost client - response, err := plugin.client.EmbeddingRequest(ctx, embeddingReq) + // Create a new context from incoming context. Parent ctx will be used for cancellation. + embeddingCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + defer embeddingCtx.ReleasePluginScope() + + embeddingCtx.SetValue(schemas.BifrostContextKeySkipPluginPipeline, true) + + if plugin.embeddingRequestExecutor == nil { + return nil, 0, fmt.Errorf("embedding request executor is not configured") + } + response, err := plugin.embeddingRequestExecutor(embeddingCtx, embeddingReq) if err != nil { return nil, 0, fmt.Errorf("failed to generate embedding: %v", err) } diff --git a/transports/bifrost-http/handlers/plugins.go b/transports/bifrost-http/handlers/plugins.go index 11b362dddf..3dc4f8353c 100644 --- a/transports/bifrost-http/handlers/plugins.go +++ b/transports/bifrost-http/handlers/plugins.go @@ -35,25 +35,23 @@ func NewPluginsHandler(pluginsLoader PluginsLoader, configStore configstore.Conf } } - - // CreatePluginRequest is the request body for creating a plugin type CreatePluginRequest struct { - Name string `json:"name"` - Enabled bool `json:"enabled"` - Config map[string]any `json:"config"` - Path *string `json:"path"` + Name string `json:"name"` + Enabled bool `json:"enabled"` + Config map[string]any `json:"config"` + Path *string `json:"path"` Placement *schemas.PluginPlacement `json:"placement,omitempty"` - Order *int `json:"order,omitempty"` + Order *int `json:"order,omitempty"` } // UpdatePluginRequest is the request body for updating a plugin type UpdatePluginRequest struct { - Enabled bool `json:"enabled"` - Path *string `json:"path"` - Config map[string]any `json:"config"` + Enabled bool `json:"enabled"` + Path *string `json:"path"` + Config map[string]any `json:"config"` Placement *schemas.PluginPlacement `json:"placement,omitempty"` - Order *int `json:"order,omitempty"` + Order *int `json:"order,omitempty"` } // RegisterRoutes registers the routes for the PluginsHandler @@ -66,15 +64,15 @@ func (h *PluginsHandler) RegisterRoutes(r *router.Router, middlewares ...schemas } type PluginResponse struct { - Name string `json:"name"` - ActualName string `json:"actualName"` - Enabled bool `json:"enabled"` - Config any `json:"config"` - IsCustom bool `json:"isCustom"` - Path *string `json:"path"` - Placement *schemas.PluginPlacement `json:"placement,omitempty"` - Order *int `json:"order,omitempty"` - Status schemas.PluginStatus `json:"status"` + Name string `json:"name"` + ActualName string `json:"actualName"` + Enabled bool `json:"enabled"` + Config any `json:"config"` + IsCustom bool `json:"isCustom"` + Path *string `json:"path"` + Placement *schemas.PluginPlacement `json:"placement,omitempty"` + Order *int `json:"order,omitempty"` + Status schemas.PluginStatus `json:"status"` } // buildPluginResponse constructs a PluginResponse with status for a given TablePlugin. diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index aa845ab0e1..77ed94437f 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -2623,8 +2623,8 @@ func loadPlugins(ctx context.Context, config *Config, configData *ConfigData) { Order: plugin.Order, } if plugin.Name == semanticcache.PluginName { - if err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig); err != nil { - logger.Warn("failed to add provider keys to semantic cache config: %v", err) + if err := config.ValidateSemanticCacheConfig(pluginConfig); err != nil { + logger.Warn("failed to validate semantic cache config: %v", err) } } config.PluginConfigs[i] = pluginConfig @@ -2693,16 +2693,6 @@ func mergePlugins(ctx context.Context, config *Config, configData *ConfigData) { } } - // Process semantic cache plugin - for i, plugin := range config.PluginConfigs { - if plugin.Name == semanticcache.PluginName { - if err := config.AddProviderKeysToSemanticCacheConfig(plugin); err != nil { - logger.Warn("failed to add provider keys to semantic cache config: %v", err) - } - config.PluginConfigs[i] = plugin - } - } - // Update store if config.ConfigStore != nil { logger.Debug("updating plugins in store") @@ -2724,11 +2714,6 @@ func mergePlugins(ctx context.Context, config *Config, configData *ConfigData) { Placement: plugin.Placement, Order: plugin.Order, } - if plugin.Name == semanticcache.PluginName { - if err := config.RemoveProviderKeysFromSemanticCacheConfig(pluginConfig); err != nil { - logger.Warn("failed to remove provider keys from semantic cache config: %v", err) - } - } if err := config.ConfigStore.UpsertPlugin(ctx, pluginConfig); err != nil { logger.Warn("failed to update plugin: %v", err) } @@ -4787,7 +4772,7 @@ func ValidateCustomProviderUpdate(newConfig, existingConfig configstore.Provider return nil } -func (c *Config) AddProviderKeysToSemanticCacheConfig(config *schemas.PluginConfig) error { +func (c *Config) ValidateSemanticCacheConfig(config *schemas.PluginConfig) error { if config.Name != semanticcache.PluginName { return nil } @@ -4856,13 +4841,11 @@ func (c *Config) AddProviderKeysToSemanticCacheConfig(config *schemas.PluginConf } configMap["embedding_model"] = embeddingModel - keys, err := c.GetProviderConfigRaw(schemas.ModelProvider(provider)) - if err != nil { + // Validate that the provider is configured in the global client (keys are inherited automatically). + if _, err := c.GetProviderConfigRaw(schemas.ModelProvider(provider)); err != nil { return fmt.Errorf("failed to get provider config for %s: %w", provider, err) } - configMap["keys"] = keys.Keys - return nil } @@ -4909,30 +4892,6 @@ func semanticCacheConfigDimension(configMap map[string]interface{}) (int, bool, return 0, false, fmt.Errorf("semantic_cache plugin 'dimension' field must be numeric, got %T", dimensionVal) } } - -func (c *Config) RemoveProviderKeysFromSemanticCacheConfig(config *configstoreTables.TablePlugin) error { - if config.Name != semanticcache.PluginName { - return nil - } - - // Check if config.Config exists - if config.Config == nil { - return fmt.Errorf("semantic_cache plugin config is nil") - } - - // Type assert config.Config to map[string]interface{} - configMap, ok := config.Config.(map[string]interface{}) - if !ok { - return fmt.Errorf("semantic_cache plugin config must be a map, got %T", config.Config) - } - - configMap["keys"] = []schemas.Key{} - - config.Config = configMap - - return nil -} - func (c *Config) GetAvailableProviders(model string) []schemas.ModelProvider { c.Mu.RLock() defer c.Mu.RUnlock() diff --git a/transports/bifrost-http/lib/semantic_cache_config_test.go b/transports/bifrost-http/lib/semantic_cache_config_test.go index 61c5da22c7..2d79bd9526 100644 --- a/transports/bifrost-http/lib/semantic_cache_config_test.go +++ b/transports/bifrost-http/lib/semantic_cache_config_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyMode(t *testing.T) { +func TestValidateSemanticCacheConfig_DirectOnlyMode(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -19,7 +19,7 @@ func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyMode(t *testing.T) { }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.NoError(t, err) configMap, ok := pluginConfig.Config.(map[string]interface{}) @@ -28,7 +28,7 @@ func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyMode(t *testing.T) { require.False(t, hasKeys, "direct-only mode should not inject provider keys") } -func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyModeRemovesStaleProviderBackedFields(t *testing.T) { +func TestValidateSemanticCacheConfig_DirectOnlyModeRemovesStaleProviderBackedFields(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -39,18 +39,16 @@ func TestAddProviderKeysToSemanticCacheConfig_DirectOnlyModeRemovesStaleProvider }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.NoError(t, err) configMap, ok := pluginConfig.Config.(map[string]interface{}) require.True(t, ok) - _, hasKeys := configMap["keys"] - require.False(t, hasKeys, "direct-only mode should remove stale provider keys") _, hasEmbeddingModel := configMap["embedding_model"] require.False(t, hasEmbeddingModel, "direct-only mode should remove stale embedding_model") } -func TestAddProviderKeysToSemanticCacheConfig_InjectsProviderKeys(t *testing.T) { +func TestValidateSemanticCacheConfig_ProviderBackedModeValidationPasses(t *testing.T) { config := &Config{ Providers: map[schemas.ModelProvider]configstore.ProviderConfig{ schemas.OpenAI: { @@ -73,19 +71,17 @@ func TestAddProviderKeysToSemanticCacheConfig_InjectsProviderKeys(t *testing.T) }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.NoError(t, err) configMap, ok := pluginConfig.Config.(map[string]interface{}) require.True(t, ok) - keys, ok := configMap["keys"].([]schemas.Key) - require.True(t, ok, "provider-backed mode should inject provider keys") - require.Len(t, keys, 1) - require.Equal(t, "openai-key", keys[0].Name) + _, hasKeys := configMap["keys"] + require.False(t, hasKeys, "keys are inherited from global client; they must not be injected into the plugin config") require.Equal(t, "openai", configMap["provider"]) } -func TestAddProviderKeysToSemanticCacheConfig_SemanticModeMissingProvider(t *testing.T) { +func TestValidateSemanticCacheConfig_SemanticModeMissingProvider(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -94,12 +90,12 @@ func TestAddProviderKeysToSemanticCacheConfig_SemanticModeMissingProvider(t *tes }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "requires 'provider' for semantic mode") } -func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingDimension(t *testing.T) { +func TestValidateSemanticCacheConfig_ProviderBackedModeMissingDimension(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -109,12 +105,12 @@ func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingDimension }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "requires 'dimension' for provider-backed semantic mode") } -func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeDimensionOne(t *testing.T) { +func TestValidateSemanticCacheConfig_ProviderBackedModeDimensionOne(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -125,12 +121,12 @@ func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeDimensionOne(t * }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "requires 'dimension' > 1") } -func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingEmbeddingModel(t *testing.T) { +func TestValidateSemanticCacheConfig_ProviderBackedModeMissingEmbeddingModel(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -140,12 +136,12 @@ func TestAddProviderKeysToSemanticCacheConfig_ProviderBackedModeMissingEmbedding }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "requires 'embedding_model'") } -func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionZero(t *testing.T) { +func TestValidateSemanticCacheConfig_InvalidDimensionZero(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -154,12 +150,12 @@ func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionZero(t *testing.T) }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "'dimension' must be >= 1") } -func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionNegative(t *testing.T) { +func TestValidateSemanticCacheConfig_InvalidDimensionNegative(t *testing.T) { config := &Config{} pluginConfig := &schemas.PluginConfig{ Name: semanticcache.PluginName, @@ -168,7 +164,7 @@ func TestAddProviderKeysToSemanticCacheConfig_InvalidDimensionNegative(t *testin }, } - err := config.AddProviderKeysToSemanticCacheConfig(pluginConfig) + err := config.ValidateSemanticCacheConfig(pluginConfig) require.Error(t, err) require.Contains(t, err.Error(), "'dimension' must be >= 1") } diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index ca5a3da674..2bbdeadad5 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -969,6 +969,10 @@ func (s *BifrostHTTPServer) ReloadPlugin(ctx context.Context, name string, path if err != nil { return s.updatePluginErrorStatus(name, "loading", err) } + // Wire the embedding executor on the new instance before syncing. + if semanticCachePlugin, ok := plugin.(*semanticcache.Plugin); ok { + semanticCachePlugin.SetEmbeddingRequestExecutor(s.Client.EmbeddingRequest) + } return s.SyncLoadedPlugin(ctx, name, plugin, placement, order) } @@ -1372,7 +1376,6 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { } wg.Wait() } - logger.Info("models added to catalog") s.Config.SetBifrostClient(s.Client) // Initialize routes @@ -1394,6 +1397,11 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { apiMiddlewares = append(apiMiddlewares, s.AuthMiddleware.APIMiddleware()) } } + // Add semantic cache plugin embedding request executor if it exists + semanticCachePlugin, err := lib.FindPluginAs[*semanticcache.Plugin](s.Config, semanticcache.PluginName) + if err == nil && semanticCachePlugin != nil { + semanticCachePlugin.SetEmbeddingRequestExecutor(s.Client.EmbeddingRequest) + } // Register routes err = s.RegisterAPIRoutes(s.Ctx, s, apiMiddlewares...) if err != nil { diff --git a/transports/config.schema.json b/transports/config.schema.json index 80134f9287..fd69a78878 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -1393,13 +1393,6 @@ "huggingface" ] }, - "keys": { - "type": "array", - "description": "API keys for the embedding provider. These are injected at runtime for config-driven setups and are not needed for direct caching with dimension: 1.", - "items": { - "type": "string" - } - }, "embedding_model": { "type": "string", "description": "Model to use for generating embeddings in provider-backed semantic caching. Required when provider is set and not allowed in direct-only mode." diff --git a/ui/app/workspace/config/views/pluginsForm.tsx b/ui/app/workspace/config/views/pluginsForm.tsx index 33477ec8a2..dcd459de4c 100644 --- a/ui/app/workspace/config/views/pluginsForm.tsx +++ b/ui/app/workspace/config/views/pluginsForm.tsx @@ -48,9 +48,6 @@ const normalizeCacheConfigForSave = (config: EditorCacheConfig) => { if (config.updated_at !== undefined) { normalized.updated_at = config.updated_at; } - if (config.keys !== undefined) { - normalized.keys = config.keys; - } const provider = config.provider?.trim(); const embeddingModel = config.embedding_model?.trim(); @@ -375,7 +372,7 @@ export default function PluginsForm({ isVectorStoreEnabled }: PluginsFormProps)

API keys for the embedding provider will be inherited from the main provider configuration. The semantic cache will use - the configured provider's keys automatically. Updates in keys will be reflected on Bifrost restart. + the configured provider's keys automatically.

diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts index d4768fe981..25d645fda4 100644 --- a/ui/lib/types/config.ts +++ b/ui/lib/types/config.ts @@ -540,13 +540,11 @@ interface BaseCacheConfig { export interface DirectCacheConfig extends BaseCacheConfig { dimension: 1; provider?: undefined; - keys?: ModelProviderKey[]; embedding_model?: undefined; } export interface ProviderBackedCacheConfig extends BaseCacheConfig { provider: ModelProviderName; - keys?: ModelProviderKey[]; embedding_model: string; dimension: number; } @@ -555,7 +553,6 @@ export type CacheConfig = DirectCacheConfig | ProviderBackedCacheConfig; export interface EditorCacheConfig extends BaseCacheConfig { provider?: ModelProviderName; - keys?: ModelProviderKey[]; embedding_model?: string; dimension?: number; } diff --git a/ui/lib/types/schemas.ts b/ui/lib/types/schemas.ts index 60ef6c2572..7a1fbc0834 100644 --- a/ui/lib/types/schemas.ts +++ b/ui/lib/types/schemas.ts @@ -4,1020 +4,1115 @@ import { z } from "zod"; // Global error map - turns Zod's default messages into readable, human-friendly ones. // Individual schemas can still override by passing their own message. z.config({ - customError: (issue) => { - if (issue.code === "invalid_type") { - // Field is missing / undefined - if (issue.input === undefined || issue.input === null) { - return "This field is required"; - } - const expected = issue.expected; - const received = typeof issue.input; - if (expected === "number") return "Must be a valid number"; - if (expected === "string") return "Must be a valid text value"; - if (expected === "boolean") return "Must be true or false"; - return `Expected ${expected}, received ${received}`; - } - if (issue.code === "too_small") { - if (issue.origin === "string" && issue.minimum === 1) { - return "This field is required"; - } - if (issue.origin === "number") { - return `Must be at least ${issue.minimum}`; - } - if (issue.origin === "array" && issue.minimum === 1) { - return "At least one item is required"; - } - } - if (issue.code === "too_big") { - if (issue.origin === "number") { - return `Must be at most ${issue.maximum}`; - } - if (issue.origin === "string") { - return `Must be at most ${issue.maximum} characters`; - } - } - if (issue.code === "invalid_format") { - if (issue.format === "url") return "Must be a valid URL"; - if (issue.format === "email") return "Must be a valid email"; - } - return undefined; // fall back to Zod default - }, + customError: (issue) => { + if (issue.code === "invalid_type") { + // Field is missing / undefined + if (issue.input === undefined || issue.input === null) { + return "This field is required"; + } + const expected = issue.expected; + const received = typeof issue.input; + if (expected === "number") return "Must be a valid number"; + if (expected === "string") return "Must be a valid text value"; + if (expected === "boolean") return "Must be true or false"; + return `Expected ${expected}, received ${received}`; + } + if (issue.code === "too_small") { + if (issue.origin === "string" && issue.minimum === 1) { + return "This field is required"; + } + if (issue.origin === "number") { + return `Must be at least ${issue.minimum}`; + } + if (issue.origin === "array" && issue.minimum === 1) { + return "At least one item is required"; + } + } + if (issue.code === "too_big") { + if (issue.origin === "number") { + return `Must be at most ${issue.maximum}`; + } + if (issue.origin === "string") { + return `Must be at most ${issue.maximum} characters`; + } + } + if (issue.code === "invalid_format") { + if (issue.format === "url") return "Must be a valid URL"; + if (issue.format === "email") return "Must be a valid email"; + } + return undefined; // fall back to Zod default + }, }); // Base Zod schemas matching the TypeScript types // Known provider schema -export const knownProviderSchema = z.enum(KnownProvidersNames as unknown as [string, ...string[]]); +export const knownProviderSchema = z.enum( + KnownProvidersNames as unknown as [string, ...string[]], +); // Custom provider name schema (branded type simulation) -export const customProviderNameSchema = z.string().min(1, "Custom provider name is required"); +export const customProviderNameSchema = z + .string() + .min(1, "Custom provider name is required"); // Model provider name schema (union of known and custom providers) -export const modelProviderNameSchema = z.union([knownProviderSchema, customProviderNameSchema]); +export const modelProviderNameSchema = z.union([ + knownProviderSchema, + customProviderNameSchema, +]); // EnvVar schema - matches the Go EnvVar type from schemas/env.go export const _envVarBase = z.object({ - value: z.string().optional(), - env_var: z.string().optional(), - from_env: z.boolean().optional(), + value: z.string().optional(), + env_var: z.string().optional(), + from_env: z.boolean().optional(), }); // Extending the base schema export const envVarSchema = Object.assign(_envVarBase, { - required: (message: string) => _envVarBase.refine((v) => !!v?.value?.trim() || !!v?.env_var?.trim(), message), + required: (message: string) => + _envVarBase.refine( + (v) => !!v?.value?.trim() || !!v?.env_var?.trim(), + message, + ), }); // Helper to check if an envVar field has a value or env reference -function isEnvVarSet(v: { value?: string; env_var?: string } | undefined): boolean { - if (!v) return false; - return !!v.value?.trim() || !!v.env_var?.trim(); +function isEnvVarSet( + v: { value?: string; env_var?: string } | undefined, +): boolean { + if (!v) return false; + return !!v.value?.trim() || !!v.env_var?.trim(); } // Azure key config schema export const azureKeyConfigSchema = z - .object({ - _auth_type: z.enum(["api_key", "entra_id", "default_credential"]).optional(), - endpoint: envVarSchema.optional(), - api_version: envVarSchema.optional(), - client_id: envVarSchema.optional(), - client_secret: envVarSchema.optional(), - tenant_id: envVarSchema.optional(), - scopes: z.array(z.string()).optional(), - }) - .refine((data) => isEnvVarSet(data.endpoint), { - message: "Endpoint is required", - path: ["endpoint"], - }) - .refine( - (data) => { - // When using Entra ID, all three fields are required - if (data._auth_type === "entra_id") { - return isEnvVarSet(data.client_id) && isEnvVarSet(data.client_secret) && isEnvVarSet(data.tenant_id); - } - // Otherwise, if any Entra ID field is set, all three must be set - const hasClientId = isEnvVarSet(data.client_id); - const hasClientSecret = isEnvVarSet(data.client_secret); - const hasTenantId = isEnvVarSet(data.tenant_id); - const anyEntraField = hasClientId || hasClientSecret || hasTenantId; - if (!anyEntraField) return true; - return hasClientId && hasClientSecret && hasTenantId; - }, - { - message: "Client ID, Client Secret, and Tenant ID are all required for Entra ID authentication", - path: ["client_id"], - }, - ); + .object({ + _auth_type: z + .enum(["api_key", "entra_id", "default_credential"]) + .optional(), + endpoint: envVarSchema.optional(), + api_version: envVarSchema.optional(), + client_id: envVarSchema.optional(), + client_secret: envVarSchema.optional(), + tenant_id: envVarSchema.optional(), + scopes: z.array(z.string()).optional(), + }) + .refine((data) => isEnvVarSet(data.endpoint), { + message: "Endpoint is required", + path: ["endpoint"], + }) + .refine( + (data) => { + // When using Entra ID, all three fields are required + if (data._auth_type === "entra_id") { + return ( + isEnvVarSet(data.client_id) && + isEnvVarSet(data.client_secret) && + isEnvVarSet(data.tenant_id) + ); + } + // Otherwise, if any Entra ID field is set, all three must be set + const hasClientId = isEnvVarSet(data.client_id); + const hasClientSecret = isEnvVarSet(data.client_secret); + const hasTenantId = isEnvVarSet(data.tenant_id); + const anyEntraField = hasClientId || hasClientSecret || hasTenantId; + if (!anyEntraField) return true; + return hasClientId && hasClientSecret && hasTenantId; + }, + { + message: + "Client ID, Client Secret, and Tenant ID are all required for Entra ID authentication", + path: ["client_id"], + }, + ); // Vertex key config schema export const vertexKeyConfigSchema = z - .object({ - _auth_type: z.enum(["service_account", "service_account_json", "api_key"]).optional(), - project_id: envVarSchema.optional(), - project_number: envVarSchema.optional(), - region: envVarSchema.optional(), - auth_credentials: envVarSchema.optional(), - }) - .refine((data) => isEnvVarSet(data.project_id), { - message: "Project ID is required", - path: ["project_id"], - }) - .refine((data) => isEnvVarSet(data.region), { - message: "Region is required", - path: ["region"], - }) - .refine( - (data) => { - // When using service_account_json auth, auth_credentials is required - if (data._auth_type === "service_account_json") { - return isEnvVarSet(data.auth_credentials); - } - return true; - }, - { - message: "Auth Credentials is required for service account JSON authentication", - path: ["auth_credentials"], - }, - ); + .object({ + _auth_type: z + .enum(["service_account", "service_account_json", "api_key"]) + .optional(), + project_id: envVarSchema.optional(), + project_number: envVarSchema.optional(), + region: envVarSchema.optional(), + auth_credentials: envVarSchema.optional(), + }) + .refine((data) => isEnvVarSet(data.project_id), { + message: "Project ID is required", + path: ["project_id"], + }) + .refine((data) => isEnvVarSet(data.region), { + message: "Region is required", + path: ["region"], + }) + .refine( + (data) => { + // When using service_account_json auth, auth_credentials is required + if (data._auth_type === "service_account_json") { + return isEnvVarSet(data.auth_credentials); + } + return true; + }, + { + message: + "Auth Credentials is required for service account JSON authentication", + path: ["auth_credentials"], + }, + ); // S3 bucket configuration for Bedrock batch operations export const s3BucketConfigSchema = z.object({ - bucket_name: z.string().min(1, "Bucket name is required"), - prefix: z.string().optional(), - is_default: z.boolean().optional(), + bucket_name: z.string().min(1, "Bucket name is required"), + prefix: z.string().optional(), + is_default: z.boolean().optional(), }); export const batchS3ConfigSchema = z.object({ - buckets: z.array(s3BucketConfigSchema).optional(), + buckets: z.array(s3BucketConfigSchema).optional(), }); // Bedrock key config schema export const bedrockKeyConfigSchema = z - .object({ - _auth_type: z.enum(["iam_role", "explicit", "api_key"]).optional(), - access_key: envVarSchema.optional(), - secret_key: envVarSchema.optional(), - session_token: envVarSchema.optional(), - region: envVarSchema.optional(), - role_arn: envVarSchema.optional(), - external_id: envVarSchema.optional(), - session_name: envVarSchema.optional(), - arn: envVarSchema.optional(), - batch_s3_config: batchS3ConfigSchema.optional(), - }) - .refine( - (data) => { - // Region is required for Bedrock - return isEnvVarSet(data.region); - }, - { - message: "Region is required", - path: ["region"], - }, - ) - .refine( - (data) => { - // When using explicit credentials, both access_key and secret_key are required - if (data._auth_type === "explicit") { - return isEnvVarSet(data.access_key) && isEnvVarSet(data.secret_key); - } - // Otherwise, if either is set both must be set - const hasAccessKey = isEnvVarSet(data.access_key); - const hasSecretKey = isEnvVarSet(data.secret_key); - if (!hasAccessKey && !hasSecretKey) return true; - return hasAccessKey && hasSecretKey; - }, - { - message: "Both Access Key and Secret Key are required for explicit credentials", - path: ["access_key"], - }, - ); + .object({ + _auth_type: z.enum(["iam_role", "explicit", "api_key"]).optional(), + access_key: envVarSchema.optional(), + secret_key: envVarSchema.optional(), + session_token: envVarSchema.optional(), + region: envVarSchema.optional(), + role_arn: envVarSchema.optional(), + external_id: envVarSchema.optional(), + session_name: envVarSchema.optional(), + arn: envVarSchema.optional(), + batch_s3_config: batchS3ConfigSchema.optional(), + }) + .refine( + (data) => { + // Region is required for Bedrock + return isEnvVarSet(data.region); + }, + { + message: "Region is required", + path: ["region"], + }, + ) + .refine( + (data) => { + // When using explicit credentials, both access_key and secret_key are required + if (data._auth_type === "explicit") { + return isEnvVarSet(data.access_key) && isEnvVarSet(data.secret_key); + } + // Otherwise, if either is set both must be set + const hasAccessKey = isEnvVarSet(data.access_key); + const hasSecretKey = isEnvVarSet(data.secret_key); + if (!hasAccessKey && !hasSecretKey) return true; + return hasAccessKey && hasSecretKey; + }, + { + message: + "Both Access Key and Secret Key are required for explicit credentials", + path: ["access_key"], + }, + ); // VLLM key config schema export const vllmKeyConfigSchema = z - .object({ - url: envVarSchema.optional(), - model_name: z.string().trim().min(1, "Model name is required"), - }) - .refine((data) => isEnvVarSet(data.url), { - message: "Server URL is required", - path: ["url"], - }); + .object({ + url: envVarSchema.optional(), + model_name: z.string().trim().min(1, "Model name is required"), + }) + .refine((data) => isEnvVarSet(data.url), { + message: "Server URL is required", + path: ["url"], + }); export const replicateKeyConfigSchema = z.object({ - use_deployments_endpoint: z.boolean(), + use_deployments_endpoint: z.boolean(), }); // Ollama key config schema export const ollamaKeyConfigSchema = z - .object({ - url: envVarSchema.optional(), - }) - .refine((data) => isEnvVarSet(data.url), { - message: "Server URL is required", - path: ["url"], - }); + .object({ + url: envVarSchema.optional(), + }) + .refine((data) => isEnvVarSet(data.url), { + message: "Server URL is required", + path: ["url"], + }); // SGL key config schema export const sglKeyConfigSchema = z - .object({ - url: envVarSchema.optional(), - }) - .refine((data) => isEnvVarSet(data.url), { - message: "Server URL is required", - path: ["url"], - }); + .object({ + url: envVarSchema.optional(), + }) + .refine((data) => isEnvVarSet(data.url), { + message: "Server URL is required", + path: ["url"], + }); // Model provider key schema export const modelProviderKeySchema = z - .object({ - id: z.string().min(1, "Id is required"), - name: z.string().min(1, "Name is required"), - value: envVarSchema.optional(), - models: z.array(z.string()).optional().default(["*"]), - blacklisted_models: z.array(z.string()).default([]).optional(), - weight: z - .union([z.number(), z.string()]) - .transform((val, ctx) => { - if (typeof val === "number") return val; - if (val.trim() === "") return 1.0; - // Use Number() rather than parseFloat() so that strings like "0.5abc" - // are rejected outright instead of silently parsing to 0.5. - const num = Number(val); - if (!Number.isFinite(num)) { - ctx.addIssue({ - code: "custom", - message: "Weight must be a valid number between 0 and 1", - }); - return z.NEVER; - } - return num; - }) - .pipe(z.number().min(0, "Weight must be equal to or greater than 0").max(1, "Weight must be equal to or less than 1")), - aliases: z.record(z.string(), z.string()).optional(), - azure_key_config: azureKeyConfigSchema.optional(), - vertex_key_config: vertexKeyConfigSchema.optional(), - bedrock_key_config: bedrockKeyConfigSchema.optional(), - vllm_key_config: vllmKeyConfigSchema.optional(), - replicate_key_config: replicateKeyConfigSchema.optional(), - ollama_key_config: ollamaKeyConfigSchema.optional(), - sgl_key_config: sglKeyConfigSchema.optional(), - use_for_batch_api: z.boolean().optional(), - enabled: z.boolean().optional(), - }) - .refine( - (data) => { - // Providers with dedicated config that never need a top-level API key - if (data.vllm_key_config || data.replicate_key_config || data.ollama_key_config || data.sgl_key_config) { - return true; - } - // Azure requires API key only when using api_key auth - if (data.azure_key_config) { - if (data.azure_key_config._auth_type === "api_key") { - return isEnvVarSet(data.value); - } - return true; - } - // Bedrock only requires API key when using api_key auth - if (data.bedrock_key_config) { - if (data.bedrock_key_config._auth_type === "api_key") { - return isEnvVarSet(data.value); - } - return true; - } - // Vertex requires API key only when using api_key auth - if (data.vertex_key_config) { - if (data.vertex_key_config._auth_type === "api_key") { - return isEnvVarSet(data.value); - } - return true; - } - // Otherwise, value is required - return isEnvVarSet(data.value); - }, - { - message: "API Key is required", - path: ["value"], - }, - ); + .object({ + id: z.string().min(1, "Id is required"), + name: z.string().min(1, "Name is required"), + value: envVarSchema.optional(), + models: z.array(z.string()).optional().default(["*"]), + blacklisted_models: z.array(z.string()).default([]).optional(), + weight: z + .union([z.number(), z.string()]) + .transform((val, ctx) => { + if (typeof val === "number") return val; + if (val.trim() === "") return 1.0; + // Use Number() rather than parseFloat() so that strings like "0.5abc" + // are rejected outright instead of silently parsing to 0.5. + const num = Number(val); + if (!Number.isFinite(num)) { + ctx.addIssue({ + code: "custom", + message: "Weight must be a valid number between 0 and 1", + }); + return z.NEVER; + } + return num; + }) + .pipe( + z + .number() + .min(0, "Weight must be equal to or greater than 0") + .max(1, "Weight must be equal to or less than 1"), + ), + aliases: z.record(z.string(), z.string()).optional(), + azure_key_config: azureKeyConfigSchema.optional(), + vertex_key_config: vertexKeyConfigSchema.optional(), + bedrock_key_config: bedrockKeyConfigSchema.optional(), + vllm_key_config: vllmKeyConfigSchema.optional(), + replicate_key_config: replicateKeyConfigSchema.optional(), + ollama_key_config: ollamaKeyConfigSchema.optional(), + sgl_key_config: sglKeyConfigSchema.optional(), + use_for_batch_api: z.boolean().optional(), + enabled: z.boolean().optional(), + }) + .refine( + (data) => { + // Providers with dedicated config that never need a top-level API key + if ( + data.vllm_key_config || + data.replicate_key_config || + data.ollama_key_config || + data.sgl_key_config + ) { + return true; + } + // Azure requires API key only when using api_key auth + if (data.azure_key_config) { + if (data.azure_key_config._auth_type === "api_key") { + return isEnvVarSet(data.value); + } + return true; + } + // Bedrock only requires API key when using api_key auth + if (data.bedrock_key_config) { + if (data.bedrock_key_config._auth_type === "api_key") { + return isEnvVarSet(data.value); + } + return true; + } + // Vertex requires API key only when using api_key auth + if (data.vertex_key_config) { + if (data.vertex_key_config._auth_type === "api_key") { + return isEnvVarSet(data.value); + } + return true; + } + // Otherwise, value is required + return isEnvVarSet(data.value); + }, + { + message: "API Key is required", + path: ["value"], + }, + ); // Network config schema export const networkConfigSchema = z - .object({ - base_url: z.union([z.string().url("Must be a valid URL"), z.string().length(0)]).optional(), - extra_headers: z.record(z.string(), z.string()).optional(), - default_request_timeout_in_seconds: z - .number() - .min(1, "Timeout must be greater than 0 seconds") - .max(3600, "Timeout must be less than 3600 seconds"), - max_retries: z.number().min(0, "Max retries must be greater than 0").max(10, "Max retries must be less than 10"), - retry_backoff_initial: z.number().min(100), - retry_backoff_max: z.number().min(100), - insecure_skip_verify: z.boolean().optional(), - ca_cert_pem: envVarSchema.optional(), - stream_idle_timeout_in_seconds: z - .number() - .int("Stream idle timeout must be a whole number of seconds") - .min(5, "Stream idle timeout must be at least 5 seconds") - .max(3600, "Stream idle timeout must be at most 3600 seconds i.e. 60 minutes") - .optional(), - max_conns_per_host: z - .number() - .int("Max connections must be a whole number") - .min(1, "Max connections must be at least 1") - .max(10000, "Max connections must be at most 10000") - .optional(), - enforce_http2: z.boolean().optional(), - }) - .refine((d) => d.retry_backoff_initial <= d.retry_backoff_max, { - message: "retry_backoff_initial must be <= retry_backoff_max", - path: ["retry_backoff_initial"], - }); + .object({ + base_url: z + .union([z.string().url("Must be a valid URL"), z.string().length(0)]) + .optional(), + extra_headers: z.record(z.string(), z.string()).optional(), + default_request_timeout_in_seconds: z + .number() + .min(1, "Timeout must be greater than 0 seconds") + .max(3600, "Timeout must be less than 3600 seconds"), + max_retries: z + .number() + .min(0, "Max retries must be greater than 0") + .max(10, "Max retries must be less than 10"), + retry_backoff_initial: z.number().min(100), + retry_backoff_max: z.number().min(100), + insecure_skip_verify: z.boolean().optional(), + ca_cert_pem: envVarSchema.optional(), + stream_idle_timeout_in_seconds: z + .number() + .int("Stream idle timeout must be a whole number of seconds") + .min(5, "Stream idle timeout must be at least 5 seconds") + .max( + 3600, + "Stream idle timeout must be at most 3600 seconds i.e. 60 minutes", + ) + .optional(), + max_conns_per_host: z + .number() + .int("Max connections must be a whole number") + .min(1, "Max connections must be at least 1") + .max(10000, "Max connections must be at most 10000") + .optional(), + enforce_http2: z.boolean().optional(), + }) + .refine((d) => d.retry_backoff_initial <= d.retry_backoff_max, { + message: "retry_backoff_initial must be <= retry_backoff_max", + path: ["retry_backoff_initial"], + }); // Network form schema - more lenient for form inputs export const networkFormConfigSchema = z - .object({ - base_url: z - .union([ - z - .string() - .url("Must be a valid URL") - .refine((url) => url.startsWith("https://") || url.startsWith("http://"), { - message: "Must be a valid HTTP or HTTPS URL", - }), - z.string().length(0), - ]) - .optional(), - extra_headers: z.record(z.string(), z.string()).optional(), - default_request_timeout_in_seconds: z.coerce - .number("Timeout must be a number") - .min(1, "Timeout must be greater than 0 seconds") - .max(172800, "Timeout must be less than 172800 seconds i.e. 48 hours"), - max_retries: z.coerce - .number("Max retries must be a number") - .min(0, "Max retries must be greater than 0") - .max(10, "Max retries must be less than 10"), - retry_backoff_initial: z.coerce - .number("Retry backoff initial must be a number") - .min(100, "Retry backoff initial must be at least 100ms") - .max(1000000, "Retry backoff initial must be at most 1000000ms"), - retry_backoff_max: z.coerce - .number("Retry backoff max must be a number") - .min(100, "Retry backoff max must be at least 100ms") - .max(1000000, "Retry backoff max must be at most 1000000ms"), - insecure_skip_verify: z.boolean().optional(), - ca_cert_pem: envVarSchema.optional(), - stream_idle_timeout_in_seconds: z.coerce - .number("Stream idle timeout must be a number") - .int("Stream idle timeout must be a whole number of seconds") - .min(5, "Stream idle timeout must be at least 5 seconds") - .max(3600, "Stream idle timeout must be at most 3600 seconds i.e. 60 minutes") - .optional(), - max_conns_per_host: z.coerce - .number("Max connections must be a number") - .int("Max connections must be a whole number") - .min(1, "Max connections must be at least 1") - .max(10000, "Max connections must be at most 10000") - .optional(), - enforce_http2: z.boolean().optional(), - }) - .refine((d) => d.retry_backoff_initial <= d.retry_backoff_max, { - message: "Initial backoff must be less than or equal to max backoff", - path: ["retry_backoff_initial"], - }); + .object({ + base_url: z + .union([ + z + .string() + .url("Must be a valid URL") + .refine( + (url) => url.startsWith("https://") || url.startsWith("http://"), + { + message: "Must be a valid HTTP or HTTPS URL", + }, + ), + z.string().length(0), + ]) + .optional(), + extra_headers: z.record(z.string(), z.string()).optional(), + default_request_timeout_in_seconds: z.coerce + .number("Timeout must be a number") + .min(1, "Timeout must be greater than 0 seconds") + .max(172800, "Timeout must be less than 172800 seconds i.e. 48 hours"), + max_retries: z.coerce + .number("Max retries must be a number") + .min(0, "Max retries must be greater than 0") + .max(10, "Max retries must be less than 10"), + retry_backoff_initial: z.coerce + .number("Retry backoff initial must be a number") + .min(100, "Retry backoff initial must be at least 100ms") + .max(1000000, "Retry backoff initial must be at most 1000000ms"), + retry_backoff_max: z.coerce + .number("Retry backoff max must be a number") + .min(100, "Retry backoff max must be at least 100ms") + .max(1000000, "Retry backoff max must be at most 1000000ms"), + insecure_skip_verify: z.boolean().optional(), + ca_cert_pem: envVarSchema.optional(), + stream_idle_timeout_in_seconds: z.coerce + .number("Stream idle timeout must be a number") + .int("Stream idle timeout must be a whole number of seconds") + .min(5, "Stream idle timeout must be at least 5 seconds") + .max( + 3600, + "Stream idle timeout must be at most 3600 seconds i.e. 60 minutes", + ) + .optional(), + max_conns_per_host: z.coerce + .number("Max connections must be a number") + .int("Max connections must be a whole number") + .min(1, "Max connections must be at least 1") + .max(10000, "Max connections must be at most 10000") + .optional(), + enforce_http2: z.boolean().optional(), + }) + .refine((d) => d.retry_backoff_initial <= d.retry_backoff_max, { + message: "Initial backoff must be less than or equal to max backoff", + path: ["retry_backoff_initial"], + }); // Concurrency and buffer size schema export const concurrencyAndBufferSizeSchema = z.object({ - concurrency: z.number().min(1, "Concurrency must be greater than 0").max(100, "Concurrency must be less than or equal to 100"), - buffer_size: z.number().min(1, "Buffer size must be greater than 0").max(1000, "Buffer size must be less than or equal to 1000"), + concurrency: z + .number() + .min(1, "Concurrency must be greater than 0") + .max(100, "Concurrency must be less than or equal to 100"), + buffer_size: z + .number() + .min(1, "Buffer size must be greater than 0") + .max(1000, "Buffer size must be less than or equal to 1000"), }); // Proxy type schema -export const proxyTypeSchema = z.enum(["none", "http", "socks5", "environment"]); +export const proxyTypeSchema = z.enum([ + "none", + "http", + "socks5", + "environment", +]); // Proxy config schema export const proxyConfigSchema = z - .object({ - type: proxyTypeSchema, - url: envVarSchema.optional(), - username: envVarSchema.optional(), - password: envVarSchema.optional(), - ca_cert_pem: envVarSchema.optional(), - }) - .refine( - (data) => - !(data.type === "http" || data.type === "socks5") || - data.url?.from_env === true || - (data.url?.value && data.url.value.trim().length > 0), - { - message: "Proxy URL is required when using HTTP or SOCKS5 proxy", - path: ["url"], - }, - ) - .refine( - (data) => { - if ((data.type === "http" || data.type === "socks5") && data.url?.value?.trim()) { - if (data.url.from_env || data.url.env_var?.startsWith("env.")) { - return true; - } - try { - new URL(data.url.value); - return true; - } catch { - return false; - } - } - return true; - }, - { message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", path: ["url"] }, - ); + .object({ + type: proxyTypeSchema, + url: envVarSchema.optional(), + username: envVarSchema.optional(), + password: envVarSchema.optional(), + ca_cert_pem: envVarSchema.optional(), + }) + .refine( + (data) => + !(data.type === "http" || data.type === "socks5") || + data.url?.from_env === true || + (data.url?.value && data.url.value.trim().length > 0), + { + message: "Proxy URL is required when using HTTP or SOCKS5 proxy", + path: ["url"], + }, + ) + .refine( + (data) => { + if ( + (data.type === "http" || data.type === "socks5") && + data.url?.value?.trim() + ) { + if (data.url.from_env || data.url.env_var?.startsWith("env.")) { + return true; + } + try { + new URL(data.url.value); + return true; + } catch { + return false; + } + } + return true; + }, + { + message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", + path: ["url"], + }, + ); // Proxy form schema - more lenient for form inputs with conditional validation export const proxyFormConfigSchema = z - .object({ - type: proxyTypeSchema, - url: envVarSchema.optional(), - username: envVarSchema.optional(), - password: envVarSchema.optional(), - ca_cert_pem: envVarSchema.optional(), - }) - .refine( - (data) => { - if (data.type === "none") { - return true; - } - // URL is required when proxy type is http or socks5 - if (data.type === "http" || data.type === "socks5") { - // Env-backed URLs may have empty resolved value before env resolution. - if (data.url?.from_env || data.url?.env_var?.startsWith("env.")) return true; - // Literal URLs must be non-empty. - if (!data.url?.value || data.url.value.trim().length === 0) return false; - } - return true; - }, - { - message: "Proxy URL is required when using HTTP or SOCKS5 proxy", - path: ["url"], - }, - ) - .refine( - (data) => { - // URL must be valid format when provided and proxy type requires it - if ((data.type === "http" || data.type === "socks5") && data.url?.value && data.url.value.trim().length > 0) { - if (data.url.from_env || data.url.env_var?.startsWith("env.")) { - return true; - } - try { - new URL(data.url.value); - return true; - } catch { - return false; - } - } - return true; - }, - { - message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", - path: ["url"], - }, - ); + .object({ + type: proxyTypeSchema, + url: envVarSchema.optional(), + username: envVarSchema.optional(), + password: envVarSchema.optional(), + ca_cert_pem: envVarSchema.optional(), + }) + .refine( + (data) => { + if (data.type === "none") { + return true; + } + // URL is required when proxy type is http or socks5 + if (data.type === "http" || data.type === "socks5") { + // Env-backed URLs may have empty resolved value before env resolution. + if (data.url?.from_env || data.url?.env_var?.startsWith("env.")) + return true; + // Literal URLs must be non-empty. + if (!data.url?.value || data.url.value.trim().length === 0) + return false; + } + return true; + }, + { + message: "Proxy URL is required when using HTTP or SOCKS5 proxy", + path: ["url"], + }, + ) + .refine( + (data) => { + // URL must be valid format when provided and proxy type requires it + if ( + (data.type === "http" || data.type === "socks5") && + data.url?.value && + data.url.value.trim().length > 0 + ) { + if (data.url.from_env || data.url.env_var?.startsWith("env.")) { + return true; + } + try { + new URL(data.url.value); + return true; + } catch { + return false; + } + } + return true; + }, + { + message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", + path: ["url"], + }, + ); // OpenAI Config tab export const openaiConfigFormSchema = z.object({ - disable_store: z.boolean(), + disable_store: z.boolean(), }); export type OpenAIConfigFormSchema = z.infer; // Allowed requests schema export const allowedRequestsSchema = z.object({ - text_completion: z.boolean(), - text_completion_stream: z.boolean(), - chat_completion: z.boolean(), - chat_completion_stream: z.boolean(), - responses: z.boolean(), - responses_stream: z.boolean(), - embedding: z.boolean(), - speech: z.boolean(), - speech_stream: z.boolean(), - transcription: z.boolean(), - transcription_stream: z.boolean(), - image_generation: z.boolean(), - image_generation_stream: z.boolean(), - image_edit: z.boolean(), - image_edit_stream: z.boolean(), - image_variation: z.boolean(), - ocr: z.boolean().optional(), - ocr_stream: z.boolean().optional(), - rerank: z.boolean(), - video_generation: z.boolean(), - video_retrieve: z.boolean(), - video_download: z.boolean(), - video_delete: z.boolean(), - video_list: z.boolean(), - video_remix: z.boolean(), - count_tokens: z.boolean(), - list_models: z.boolean(), - websocket_responses: z.boolean(), - realtime: z.boolean(), + text_completion: z.boolean(), + text_completion_stream: z.boolean(), + chat_completion: z.boolean(), + chat_completion_stream: z.boolean(), + responses: z.boolean(), + responses_stream: z.boolean(), + embedding: z.boolean(), + speech: z.boolean(), + speech_stream: z.boolean(), + transcription: z.boolean(), + transcription_stream: z.boolean(), + image_generation: z.boolean(), + image_generation_stream: z.boolean(), + image_edit: z.boolean(), + image_edit_stream: z.boolean(), + image_variation: z.boolean(), + ocr: z.boolean().optional(), + ocr_stream: z.boolean().optional(), + rerank: z.boolean(), + video_generation: z.boolean(), + video_retrieve: z.boolean(), + video_download: z.boolean(), + video_delete: z.boolean(), + video_list: z.boolean(), + video_remix: z.boolean(), + count_tokens: z.boolean(), + list_models: z.boolean(), + websocket_responses: z.boolean(), + realtime: z.boolean(), }); // Custom provider config schema export const customProviderConfigSchema = z - .object({ - base_provider_type: knownProviderSchema, - is_key_less: z.boolean().optional(), - allowed_requests: allowedRequestsSchema.optional(), - request_path_overrides: z.record(z.string(), z.string().optional()).optional(), - }) - .refine( - (data) => { - if (data.base_provider_type === "bedrock") { - return !data.is_key_less; - } - return true; - }, - { - message: "Is keyless is not allowed for Bedrock", - path: ["is_key_less"], - }, - ); + .object({ + base_provider_type: knownProviderSchema, + is_key_less: z.boolean().optional(), + allowed_requests: allowedRequestsSchema.optional(), + request_path_overrides: z + .record(z.string(), z.string().optional()) + .optional(), + }) + .refine( + (data) => { + if (data.base_provider_type === "bedrock") { + return !data.is_key_less; + } + return true; + }, + { + message: "Is keyless is not allowed for Bedrock", + path: ["is_key_less"], + }, + ); // Form-specific custom provider config schema export const formCustomProviderConfigSchema = z - .object({ - base_provider_type: z.string().min(1, "Base provider type is required"), - is_key_less: z.boolean().optional(), - allowed_requests: allowedRequestsSchema.optional(), - request_path_overrides: z.record(z.string(), z.string().optional()).optional(), - }) - .refine( - (data) => { - if (data.base_provider_type === "bedrock") { - return !data.is_key_less; - } - return true; - }, - { - message: "Is keyless is not allowed for Bedrock", - path: ["is_key_less"], - }, - ); + .object({ + base_provider_type: z.string().min(1, "Base provider type is required"), + is_key_less: z.boolean().optional(), + allowed_requests: allowedRequestsSchema.optional(), + request_path_overrides: z + .record(z.string(), z.string().optional()) + .optional(), + }) + .refine( + (data) => { + if (data.base_provider_type === "bedrock") { + return !data.is_key_less; + } + return true; + }, + { + message: "Is keyless is not allowed for Bedrock", + path: ["is_key_less"], + }, + ); // Full model provider config schema export const modelProviderConfigSchema = z.object({ - keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), - network_config: networkConfigSchema.optional(), - concurrency_and_buffer_size: concurrencyAndBufferSizeSchema.optional(), - proxy_config: proxyConfigSchema.optional(), - send_back_raw_request: z.boolean().optional(), - send_back_raw_response: z.boolean().optional(), - store_raw_request_response: z.boolean().optional(), - custom_provider_config: customProviderConfigSchema.optional(), + keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), + network_config: networkConfigSchema.optional(), + concurrency_and_buffer_size: concurrencyAndBufferSizeSchema.optional(), + proxy_config: proxyConfigSchema.optional(), + send_back_raw_request: z.boolean().optional(), + send_back_raw_response: z.boolean().optional(), + store_raw_request_response: z.boolean().optional(), + custom_provider_config: customProviderConfigSchema.optional(), }); // Model provider schema export const modelProviderSchema = modelProviderConfigSchema.extend({ - name: modelProviderNameSchema, + name: modelProviderNameSchema, }); // Form-specific model provider config schema export const formModelProviderConfigSchema = z.object({ - keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), - network_config: networkConfigSchema.optional(), - concurrency_and_buffer_size: concurrencyAndBufferSizeSchema.optional(), - proxy_config: proxyConfigSchema.optional(), - send_back_raw_request: z.boolean().optional(), - send_back_raw_response: z.boolean().optional(), - store_raw_request_response: z.boolean().optional(), - custom_provider_config: formCustomProviderConfigSchema.optional(), + keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), + network_config: networkConfigSchema.optional(), + concurrency_and_buffer_size: concurrencyAndBufferSizeSchema.optional(), + proxy_config: proxyConfigSchema.optional(), + send_back_raw_request: z.boolean().optional(), + send_back_raw_response: z.boolean().optional(), + store_raw_request_response: z.boolean().optional(), + custom_provider_config: formCustomProviderConfigSchema.optional(), }); // Flexible model provider schema for form data - allows any string for name export const formModelProviderSchema = formModelProviderConfigSchema.extend({ - name: z.string().min(1, "Provider name is required"), + name: z.string().min(1, "Provider name is required"), }); // Add provider request schema export const addProviderRequestSchema = z.object({ - provider: modelProviderNameSchema, - keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), - network_config: networkConfigSchema.optional(), - concurrency_and_buffer_size: concurrencyAndBufferSizeSchema.optional(), - proxy_config: proxyConfigSchema.optional(), - send_back_raw_request: z.boolean().optional(), - send_back_raw_response: z.boolean().optional(), - store_raw_request_response: z.boolean().optional(), - custom_provider_config: customProviderConfigSchema.optional(), - openai_config: openaiConfigFormSchema.optional(), + provider: modelProviderNameSchema, + keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), + network_config: networkConfigSchema.optional(), + concurrency_and_buffer_size: concurrencyAndBufferSizeSchema.optional(), + proxy_config: proxyConfigSchema.optional(), + send_back_raw_request: z.boolean().optional(), + send_back_raw_response: z.boolean().optional(), + store_raw_request_response: z.boolean().optional(), + custom_provider_config: customProviderConfigSchema.optional(), + openai_config: openaiConfigFormSchema.optional(), }); // Update provider request schema export const updateProviderRequestSchema = z.object({ - keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), - network_config: networkConfigSchema, - concurrency_and_buffer_size: concurrencyAndBufferSizeSchema, - proxy_config: proxyConfigSchema, - send_back_raw_request: z.boolean().optional(), - send_back_raw_response: z.boolean().optional(), - store_raw_request_response: z.boolean().optional(), - custom_provider_config: customProviderConfigSchema.optional(), - openai_config: openaiConfigFormSchema.optional(), + keys: z.array(modelProviderKeySchema).min(1, "At least one key is required"), + network_config: networkConfigSchema, + concurrency_and_buffer_size: concurrencyAndBufferSizeSchema, + proxy_config: proxyConfigSchema, + send_back_raw_request: z.boolean().optional(), + send_back_raw_response: z.boolean().optional(), + store_raw_request_response: z.boolean().optional(), + custom_provider_config: customProviderConfigSchema.optional(), + openai_config: openaiConfigFormSchema.optional(), }); // Cache config schema const baseCacheConfigSchema = z.object({ - ttl_seconds: z.number().int().min(1).default(3600), - threshold: z.number().min(0).max(1).default(0.8), - conversation_history_threshold: z.number().int().min(0).optional(), - exclude_system_prompt: z.boolean().optional(), - cache_by_model: z.boolean().default(false), - cache_by_provider: z.boolean().default(false), - created_at: z.string().optional(), - updated_at: z.string().optional(), + ttl_seconds: z.number().int().min(1).default(3600), + threshold: z.number().min(0).max(1).default(0.8), + conversation_history_threshold: z.number().int().min(0).optional(), + exclude_system_prompt: z.boolean().optional(), + cache_by_model: z.boolean().default(false), + cache_by_provider: z.boolean().default(false), + created_at: z.string().optional(), + updated_at: z.string().optional(), }); const directCacheConfigSchema = baseCacheConfigSchema - .extend({ - dimension: z.literal(1), - keys: z.array(modelProviderKeySchema).optional(), - }) - .strict(); + .extend({ + dimension: z.literal(1), + keys: z.array(modelProviderKeySchema).optional(), + }) + .strict(); const providerBackedCacheConfigSchema = baseCacheConfigSchema - .extend({ - provider: modelProviderNameSchema, - keys: z.array(modelProviderKeySchema).optional(), - embedding_model: z.string().min(1, "Embedding model is required"), - dimension: z.number().int().min(2, "Dimension must be greater than 1 for provider-backed semantic cache"), - }) - .strict(); - -export const cacheConfigSchema = z.union([directCacheConfigSchema, providerBackedCacheConfigSchema]); + .extend({ + provider: modelProviderNameSchema, + keys: z.array(modelProviderKeySchema).optional(), + embedding_model: z.string().min(1, "Embedding model is required"), + dimension: z + .number() + .int() + .min( + 2, + "Dimension must be greater than 1 for provider-backed semantic cache", + ), + }) + .strict(); + +export const cacheConfigSchema = z.union([ + directCacheConfigSchema, + providerBackedCacheConfigSchema, +]); // Core config schema export const coreConfigSchema = z.object({ - drop_excess_requests: z.boolean().default(false), - initial_pool_size: z.number().min(1).default(10), - prometheus_labels: z.array(z.string()).default([]), - enable_logging: z.boolean().default(true), - disable_content_logging: z.boolean().default(false), - enforce_auth_on_inference: z.boolean().default(false), - allow_direct_keys: z.boolean().default(false), - hide_deleted_virtual_keys_in_filters: z.boolean().default(false), - allowed_origins: z.array(z.string()).default(["*"]), - max_request_body_size_mb: z.number().min(1).default(100), - mcp_agent_depth: z.number().min(1).default(10), - mcp_tool_execution_timeout: z.number().min(1).default(30), - mcp_code_mode_binding_level: z.enum(["server", "tool"]).default("server"), - mcp_disable_auto_tool_inject: z.boolean().default(false), + drop_excess_requests: z.boolean().default(false), + initial_pool_size: z.number().min(1).default(10), + prometheus_labels: z.array(z.string()).default([]), + enable_logging: z.boolean().default(true), + disable_content_logging: z.boolean().default(false), + enforce_auth_on_inference: z.boolean().default(false), + allow_direct_keys: z.boolean().default(false), + hide_deleted_virtual_keys_in_filters: z.boolean().default(false), + allowed_origins: z.array(z.string()).default(["*"]), + max_request_body_size_mb: z.number().min(1).default(100), + mcp_agent_depth: z.number().min(1).default(10), + mcp_tool_execution_timeout: z.number().min(1).default(30), + mcp_code_mode_binding_level: z.enum(["server", "tool"]).default("server"), + mcp_disable_auto_tool_inject: z.boolean().default(false), }); // Bifrost config schema export const bifrostConfigSchema = z.object({ - client_config: coreConfigSchema, - is_db_connected: z.boolean(), - is_cache_connected: z.boolean(), - is_logs_connected: z.boolean(), + client_config: coreConfigSchema, + is_db_connected: z.boolean(), + is_cache_connected: z.boolean(), + is_logs_connected: z.boolean(), }); // Network and proxy form schema - combined for the NetworkFormFragment export const networkAndProxyFormSchema = z.object({ - network_config: networkFormConfigSchema.optional(), - proxy_config: proxyFormConfigSchema.optional(), + network_config: networkFormConfigSchema.optional(), + proxy_config: proxyFormConfigSchema.optional(), }); // Proxy-only form schema for the ProxyFormFragment export const proxyOnlyFormSchema = z.object({ - proxy_config: proxyFormConfigSchema.optional(), + proxy_config: proxyFormConfigSchema.optional(), }); // Network-only form schema for the NetworkFormFragment export const networkOnlyFormSchema = z.object({ - network_config: networkFormConfigSchema.optional(), + network_config: networkFormConfigSchema.optional(), }); // Performance form schema for the PerformanceFormFragment (concurrency/buffer only; raw request/response are in Debugging tab) export const performanceFormSchema = z.object({ - concurrency_and_buffer_size: z - .object({ - concurrency: z - .number({ error: "Concurrency must be a number" }) - .min(1, "Concurrency must be greater than 0") - .max(100000, "Concurrency must be less than 100000"), - buffer_size: z - .number({ error: "Buffer size must be a number" }) - .min(1, "Buffer size must be greater than 0") - .max(100000, "Buffer size must be less than 100000"), - }) - .refine((data) => data.concurrency <= data.buffer_size, { - message: "Concurrency must be less than or equal to buffer size", - path: ["concurrency"], - }), + concurrency_and_buffer_size: z + .object({ + concurrency: z + .number({ error: "Concurrency must be a number" }) + .min(1, "Concurrency must be greater than 0") + .max(100000, "Concurrency must be less than 100000"), + buffer_size: z + .number({ error: "Buffer size must be a number" }) + .min(1, "Buffer size must be greater than 0") + .max(100000, "Buffer size must be less than 100000"), + }) + .refine((data) => data.concurrency <= data.buffer_size, { + message: "Concurrency must be less than or equal to buffer size", + path: ["concurrency"], + }), }); // Debugging tab (raw request/response toggles) export const debuggingFormSchema = z.object({ - send_back_raw_request: z.boolean(), - send_back_raw_response: z.boolean(), - store_raw_request_response: z.boolean(), + send_back_raw_request: z.boolean(), + send_back_raw_response: z.boolean(), + store_raw_request_response: z.boolean(), }); export type DebuggingFormSchema = z.infer; // Beta Headers tab export const betaHeadersFormSchema = z.object({ - beta_header_overrides: z.record(z.string(), z.boolean()).optional(), + beta_header_overrides: z.record(z.string(), z.boolean()).optional(), }); export type BetaHeadersFormSchema = z.infer; // OTEL Configuration Schema export const otelConfigSchema = z - .object({ - service_name: z.string().optional(), - collector_url: z.string().default(""), - trace_type: z - .enum(["genai_extension", "vercel", "open_inference"], { - message: "Please select a trace type", - }) - .default("genai_extension"), - headers: z.record(z.string(), z.string()).optional(), - protocol: z - .enum(["http", "grpc"], { - message: "Please select a protocol", - }) - .default("http"), - // TLS configuration - tls_ca_cert: z.string().optional(), - insecure: z.boolean().default(true), - // Metrics push configuration - metrics_enabled: z.boolean().default(false), - metrics_endpoint: z.string().optional(), - metrics_push_interval: z.number().int().min(1).max(300).default(15), - }) - .superRefine((data, ctx) => { - const protocol = data.protocol; - const hostPortRegex = /^(?!https?:\/\/)([a-zA-Z0-9.-]+|\[[0-9a-fA-F:]+\]|\d{1,3}(?:\.\d{1,3}){3}):(\d{1,5})$/; - - // Helper to validate URL format - const validateHttpUrl = (url: string, path: string[]) => { - try { - const u = new URL(url); - if (!(u.protocol === "http:" || u.protocol === "https:")) { - ctx.addIssue({ - code: "custom", - path, - message: "Must be a valid HTTP or HTTPS URL", - }); - return false; - } - return true; - } catch { - ctx.addIssue({ - code: "custom", - path, - message: "Must be a valid HTTP or HTTPS URL", - }); - return false; - } - }; - - // Helper to validate host:port format - const validateHostPort = (value: string, path: string[], example: string) => { - const match = value.match(hostPortRegex); - if (!match) { - ctx.addIssue({ - code: "custom", - path, - message: `Must be in the format : for gRPC (e.g. ${example})`, - }); - return false; - } - const port = Number(match[2]); - if (!(port >= 1 && port <= 65535)) { - ctx.addIssue({ - code: "custom", - path, - message: "Port must be between 1 and 65535", - }); - return false; - } - return true; - }; - - // Validate collector_url format (emptiness check is at form level, gated by enabled) - const collectorUrl = (data.collector_url || "").trim(); - if (collectorUrl && protocol === "http") { - validateHttpUrl(collectorUrl, ["collector_url"]); - } else if (collectorUrl && protocol === "grpc") { - validateHostPort(collectorUrl, ["collector_url"], "otel-collector:4317"); - } - - // Validate metrics_endpoint when metrics_enabled is true - if (data.metrics_enabled) { - const metricsEndpoint = (data.metrics_endpoint || "").trim(); - if (!metricsEndpoint) { - ctx.addIssue({ - code: "custom", - path: ["metrics_endpoint"], - message: "Metrics endpoint is required when metrics push is enabled", - }); - } else if (protocol === "http") { - validateHttpUrl(metricsEndpoint, ["metrics_endpoint"]); - } else if (protocol === "grpc") { - validateHostPort(metricsEndpoint, ["metrics_endpoint"], "otel-collector:4317"); - } - } - }); + .object({ + service_name: z.string().optional(), + collector_url: z.string().default(""), + trace_type: z + .enum(["genai_extension", "vercel", "open_inference"], { + message: "Please select a trace type", + }) + .default("genai_extension"), + headers: z.record(z.string(), z.string()).optional(), + protocol: z + .enum(["http", "grpc"], { + message: "Please select a protocol", + }) + .default("http"), + // TLS configuration + tls_ca_cert: z.string().optional(), + insecure: z.boolean().default(true), + // Metrics push configuration + metrics_enabled: z.boolean().default(false), + metrics_endpoint: z.string().optional(), + metrics_push_interval: z.number().int().min(1).max(300).default(15), + }) + .superRefine((data, ctx) => { + const protocol = data.protocol; + const hostPortRegex = + /^(?!https?:\/\/)([a-zA-Z0-9.-]+|\[[0-9a-fA-F:]+\]|\d{1,3}(?:\.\d{1,3}){3}):(\d{1,5})$/; + + // Helper to validate URL format + const validateHttpUrl = (url: string, path: string[]) => { + try { + const u = new URL(url); + if (!(u.protocol === "http:" || u.protocol === "https:")) { + ctx.addIssue({ + code: "custom", + path, + message: "Must be a valid HTTP or HTTPS URL", + }); + return false; + } + return true; + } catch { + ctx.addIssue({ + code: "custom", + path, + message: "Must be a valid HTTP or HTTPS URL", + }); + return false; + } + }; + + // Helper to validate host:port format + const validateHostPort = ( + value: string, + path: string[], + example: string, + ) => { + const match = value.match(hostPortRegex); + if (!match) { + ctx.addIssue({ + code: "custom", + path, + message: `Must be in the format : for gRPC (e.g. ${example})`, + }); + return false; + } + const port = Number(match[2]); + if (!(port >= 1 && port <= 65535)) { + ctx.addIssue({ + code: "custom", + path, + message: "Port must be between 1 and 65535", + }); + return false; + } + return true; + }; + + // Validate collector_url format (emptiness check is at form level, gated by enabled) + const collectorUrl = (data.collector_url || "").trim(); + if (collectorUrl && protocol === "http") { + validateHttpUrl(collectorUrl, ["collector_url"]); + } else if (collectorUrl && protocol === "grpc") { + validateHostPort(collectorUrl, ["collector_url"], "otel-collector:4317"); + } + + // Validate metrics_endpoint when metrics_enabled is true + if (data.metrics_enabled) { + const metricsEndpoint = (data.metrics_endpoint || "").trim(); + if (!metricsEndpoint) { + ctx.addIssue({ + code: "custom", + path: ["metrics_endpoint"], + message: "Metrics endpoint is required when metrics push is enabled", + }); + } else if (protocol === "http") { + validateHttpUrl(metricsEndpoint, ["metrics_endpoint"]); + } else if (protocol === "grpc") { + validateHostPort( + metricsEndpoint, + ["metrics_endpoint"], + "otel-collector:4317", + ); + } + } + }); // OTEL form schema for the OtelFormFragment export const otelFormSchema = z - .object({ - enabled: z.boolean().default(true), - otel_config: otelConfigSchema, - }) - .superRefine((data, ctx) => { - if (data.enabled) { - const collectorUrl = (data.otel_config.collector_url || "").trim(); - if (!collectorUrl) { - ctx.addIssue({ - code: "custom", - path: ["otel_config", "collector_url"], - message: "Collector address is required", - }); - } - } - }); + .object({ + enabled: z.boolean().default(true), + otel_config: otelConfigSchema, + }) + .superRefine((data, ctx) => { + if (data.enabled) { + const collectorUrl = (data.otel_config.collector_url || "").trim(); + if (!collectorUrl) { + ctx.addIssue({ + code: "custom", + path: ["otel_config", "collector_url"], + message: "Collector address is required", + }); + } + } + }); // Maxim Configuration Schema export const maximConfigSchema = z.object({ - api_key: z.string().default(""), - log_repo_id: z.string().optional(), + api_key: z.string().default(""), + log_repo_id: z.string().optional(), }); // Maxim form schema for the MaximFormFragment export const maximFormSchema = z - .object({ - enabled: z.boolean().default(true), - maxim_config: maximConfigSchema, - }) - .superRefine((data, ctx) => { - if (data.enabled) { - const apiKey = (data.maxim_config.api_key || "").trim(); - if (!apiKey) { - ctx.addIssue({ - code: "custom", - path: ["maxim_config", "api_key"], - message: "API key is required", - }); - } else if (!apiKey.startsWith("sk_mx_")) { - ctx.addIssue({ - code: "custom", - path: ["maxim_config", "api_key"], - message: "API key must start with 'sk_mx_'", - }); - } - } - }); + .object({ + enabled: z.boolean().default(true), + maxim_config: maximConfigSchema, + }) + .superRefine((data, ctx) => { + if (data.enabled) { + const apiKey = (data.maxim_config.api_key || "").trim(); + if (!apiKey) { + ctx.addIssue({ + code: "custom", + path: ["maxim_config", "api_key"], + message: "API key is required", + }); + } else if (!apiKey.startsWith("sk_mx_")) { + ctx.addIssue({ + code: "custom", + path: ["maxim_config", "api_key"], + message: "API key must start with 'sk_mx_'", + }); + } + } + }); // Prometheus Push Gateway Configuration Schema export const prometheusConfigSchema = z - .object({ - push_gateway_url: z.string().optional(), - job_name: z.string().default("bifrost"), - instance_id: z.string().optional(), - push_interval: z.number().min(1).max(300).default(15), - basic_auth_username: z.string().optional(), - basic_auth_password: z.string().optional(), - }) - .superRefine((data, ctx) => { - // Validate push_gateway_url format - const url = (data.push_gateway_url || "").trim(); - if (url) { - try { - const u = new URL(url); - if (!(u.protocol === "http:" || u.protocol === "https:")) { - ctx.addIssue({ - code: "custom", - path: ["push_gateway_url"], - message: "Must be a valid HTTP or HTTPS URL", - }); - } - } catch { - ctx.addIssue({ - code: "custom", - path: ["push_gateway_url"], - message: "Must be a valid URL (e.g., http://pushgateway:9091)", - }); - } - } - - // Validate basic auth: if one credential is provided, both must be provided - const hasUsername = !!data.basic_auth_username?.trim(); - const hasPassword = !!data.basic_auth_password?.trim(); - if (hasUsername && !hasPassword) { - ctx.addIssue({ - code: "custom", - path: ["basic_auth_password"], - message: "Password is required when username is provided", - }); - } - if (hasPassword && !hasUsername) { - ctx.addIssue({ - code: "custom", - path: ["basic_auth_username"], - message: "Username is required when password is provided", - }); - } - }); + .object({ + push_gateway_url: z.string().optional(), + job_name: z.string().default("bifrost"), + instance_id: z.string().optional(), + push_interval: z.number().min(1).max(300).default(15), + basic_auth_username: z.string().optional(), + basic_auth_password: z.string().optional(), + }) + .superRefine((data, ctx) => { + // Validate push_gateway_url format + const url = (data.push_gateway_url || "").trim(); + if (url) { + try { + const u = new URL(url); + if (!(u.protocol === "http:" || u.protocol === "https:")) { + ctx.addIssue({ + code: "custom", + path: ["push_gateway_url"], + message: "Must be a valid HTTP or HTTPS URL", + }); + } + } catch { + ctx.addIssue({ + code: "custom", + path: ["push_gateway_url"], + message: "Must be a valid URL (e.g., http://pushgateway:9091)", + }); + } + } + + // Validate basic auth: if one credential is provided, both must be provided + const hasUsername = !!data.basic_auth_username?.trim(); + const hasPassword = !!data.basic_auth_password?.trim(); + if (hasUsername && !hasPassword) { + ctx.addIssue({ + code: "custom", + path: ["basic_auth_password"], + message: "Password is required when username is provided", + }); + } + if (hasPassword && !hasUsername) { + ctx.addIssue({ + code: "custom", + path: ["basic_auth_username"], + message: "Username is required when password is provided", + }); + } + }); // Prometheus form schema for the PrometheusFormFragment export const prometheusFormSchema = z - .object({ - enabled: z.boolean().default(true), - prometheus_config: prometheusConfigSchema, - }) - .superRefine((data, ctx) => { - // When enabled, push_gateway_url is required - if (data.enabled) { - const url = (data.prometheus_config.push_gateway_url || "").trim(); - if (!url) { - ctx.addIssue({ - code: "custom", - path: ["prometheus_config", "push_gateway_url"], - message: "Push Gateway URL is required when enabled", - }); - } - } - }); + .object({ + enabled: z.boolean().default(true), + prometheus_config: prometheusConfigSchema, + }) + .superRefine((data, ctx) => { + // When enabled, push_gateway_url is required + if (data.enabled) { + const url = (data.prometheus_config.push_gateway_url || "").trim(); + if (!url) { + ctx.addIssue({ + code: "custom", + path: ["prometheus_config", "push_gateway_url"], + message: "Push Gateway URL is required when enabled", + }); + } + } + }); // MCP Client update schema export const mcpClientUpdateSchema = z.object({ - is_code_mode_client: z.boolean().optional(), - is_ping_available: z.boolean().optional(), - allow_on_all_virtual_keys: z.boolean().optional(), - name: z - .string() - .min(1, "Name is required") - .refine((val) => !val.includes("-"), { - message: "Client name cannot contain hyphens", - }) - .refine((val) => !val.includes(" "), { - message: "Client name cannot contain spaces", - }) - .refine((val) => !/^[0-9]/.test(val), { - message: "Client name cannot start with a number", - }), - headers: z.record(z.string(), envVarSchema).optional().nullable(), - tools_to_execute: z - .array(z.string()) - .optional() - .refine( - (tools) => { - if (!tools || tools.length === 0) return true; - const hasWildcard = tools.includes("*"); - return !hasWildcard || tools.length === 1; - }, - { message: "Wildcard '*' cannot be combined with other tool names" }, - ) - .refine( - (tools) => { - if (!tools) return true; - return tools.length === new Set(tools).size; - }, - { message: "Duplicate tool names are not allowed" }, - ), - tools_to_auto_execute: z - .array(z.string()) - .optional() - .refine( - (tools) => { - if (!tools || tools.length === 0) return true; - const hasWildcard = tools.includes("*"); - return !hasWildcard || tools.length === 1; - }, - { message: "Wildcard '*' cannot be combined with other tool names" }, - ) - .refine( - (tools) => { - if (!tools) return true; - return tools.length === new Set(tools).size; - }, - { message: "Duplicate tool names are not allowed" }, - ), - tool_pricing: z.record(z.string(), z.number().min(0, "Cost must be non-negative")).optional(), - tool_sync_interval: z.number().optional(), // -1 = disabled, 0 = use global, >0 = custom interval in minutes - allowed_extra_headers: z - .array(z.string()) - .optional() - .refine( - (headers) => { - if (!headers || headers.length === 0) return true; - const hasWildcard = headers.includes("*"); - return !hasWildcard || headers.length === 1; - }, - { message: "Wildcard '*' cannot be combined with specific header names" }, - ), + is_code_mode_client: z.boolean().optional(), + is_ping_available: z.boolean().optional(), + allow_on_all_virtual_keys: z.boolean().optional(), + name: z + .string() + .min(1, "Name is required") + .refine((val) => !val.includes("-"), { + message: "Client name cannot contain hyphens", + }) + .refine((val) => !val.includes(" "), { + message: "Client name cannot contain spaces", + }) + .refine((val) => !/^[0-9]/.test(val), { + message: "Client name cannot start with a number", + }), + headers: z.record(z.string(), envVarSchema).optional().nullable(), + tools_to_execute: z + .array(z.string()) + .optional() + .refine( + (tools) => { + if (!tools || tools.length === 0) return true; + const hasWildcard = tools.includes("*"); + return !hasWildcard || tools.length === 1; + }, + { message: "Wildcard '*' cannot be combined with other tool names" }, + ) + .refine( + (tools) => { + if (!tools) return true; + return tools.length === new Set(tools).size; + }, + { message: "Duplicate tool names are not allowed" }, + ), + tools_to_auto_execute: z + .array(z.string()) + .optional() + .refine( + (tools) => { + if (!tools || tools.length === 0) return true; + const hasWildcard = tools.includes("*"); + return !hasWildcard || tools.length === 1; + }, + { message: "Wildcard '*' cannot be combined with other tool names" }, + ) + .refine( + (tools) => { + if (!tools) return true; + return tools.length === new Set(tools).size; + }, + { message: "Duplicate tool names are not allowed" }, + ), + tool_pricing: z + .record(z.string(), z.number().min(0, "Cost must be non-negative")) + .optional(), + tool_sync_interval: z.number().optional(), // -1 = disabled, 0 = use global, >0 = custom interval in minutes + allowed_extra_headers: z + .array(z.string()) + .optional() + .refine( + (headers) => { + if (!headers || headers.length === 0) return true; + const hasWildcard = headers.includes("*"); + return !hasWildcard || headers.length === 1; + }, + { message: "Wildcard '*' cannot be combined with specific header names" }, + ), }); // Global proxy type schema @@ -1025,88 +1120,102 @@ export const globalProxyTypeSchema = z.enum(["http", "socks5", "tcp"]); // Global proxy configuration schema export const globalProxyConfigSchema = z - .object({ - enabled: z.boolean(), - type: globalProxyTypeSchema, - url: z.string(), - username: z.string().optional(), - password: z.string().optional(), - ca_cert_pem: z.string().optional(), - no_proxy: z.string().optional(), - timeout: z.number().min(0).optional(), - skip_tls_verify: z.boolean().optional(), - enable_for_scim: z.boolean(), - enable_for_inference: z.boolean(), - enable_for_api: z.boolean(), - }) - .refine( - (data) => { - // URL is required when proxy is enabled - if (data.enabled && (!data.url || data.url.trim().length === 0)) { - return false; - } - return true; - }, - { - message: "Proxy URL is required when proxy is enabled", - path: ["url"], - }, - ) - .refine( - (data) => { - // Validate URL format when provided and enabled - if (data.enabled && data.url && data.url.trim().length > 0) { - try { - new URL(data.url); - return true; - } catch { - return false; - } - } - return true; - }, - { - message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", - path: ["url"], - }, - ); + .object({ + enabled: z.boolean(), + type: globalProxyTypeSchema, + url: z.string(), + username: z.string().optional(), + password: z.string().optional(), + ca_cert_pem: z.string().optional(), + no_proxy: z.string().optional(), + timeout: z.number().min(0).optional(), + skip_tls_verify: z.boolean().optional(), + enable_for_scim: z.boolean(), + enable_for_inference: z.boolean(), + enable_for_api: z.boolean(), + }) + .refine( + (data) => { + // URL is required when proxy is enabled + if (data.enabled && (!data.url || data.url.trim().length === 0)) { + return false; + } + return true; + }, + { + message: "Proxy URL is required when proxy is enabled", + path: ["url"], + }, + ) + .refine( + (data) => { + // Validate URL format when provided and enabled + if (data.enabled && data.url && data.url.trim().length > 0) { + try { + new URL(data.url); + return true; + } catch { + return false; + } + } + return true; + }, + { + message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", + path: ["url"], + }, + ); // Global proxy form schema for the ProxyView export const globalProxyFormSchema = z.object({ - proxy_config: globalProxyConfigSchema, + proxy_config: globalProxyConfigSchema, }); // Global header filter configuration schema // Controls which headers with the x-bf-eh-* prefix are forwarded to LLM providers export const globalHeaderFilterConfigSchema = z.object({ - allowlist: z.array(z.string()).optional(), // If non-empty, only these headers are allowed - denylist: z.array(z.string()).optional(), // Headers to always block + allowlist: z.array(z.string()).optional(), // If non-empty, only these headers are allowed + denylist: z.array(z.string()).optional(), // Headers to always block }); // Global header filter form schema for the HeaderFilterView export const globalHeaderFilterFormSchema = z.object({ - header_filter_config: globalHeaderFilterConfigSchema, + header_filter_config: globalHeaderFilterConfigSchema, }); // Routing rule creation schema export const routingRuleSchema = z - .object({ - name: z.string().min(1, "Rule name is required").max(255, "Rule name must be less than 255 characters"), - description: z.string().max(1000, "Description must be less than 1000 characters").optional(), - cel_expression: z.string().optional(), - provider: z.string().min(1, "Provider is required"), - model: z.string().optional(), - fallbacks: z.array(z.string()).optional().default([]), - scope: z.enum(["global", "team", "customer", "virtual_key"]), - scope_id: z.string().optional(), - priority: z.number().min(0, "Priority must be 0 or greater").max(1000, "Priority must be 1000 or less"), - enabled: z.boolean().default(true), - chain_rule: z.boolean().default(false), - }) - .refine((data) => data.scope === "global" || (data.scope_id != null && data.scope_id.trim() !== ""), { - message: "Scope ID is required when scope is not global", - path: ["scope_id"], - }); + .object({ + name: z + .string() + .min(1, "Rule name is required") + .max(255, "Rule name must be less than 255 characters"), + description: z + .string() + .max(1000, "Description must be less than 1000 characters") + .optional(), + cel_expression: z.string().optional(), + provider: z.string().min(1, "Provider is required"), + model: z.string().optional(), + fallbacks: z.array(z.string()).optional().default([]), + scope: z.enum(["global", "team", "customer", "virtual_key"]), + scope_id: z.string().optional(), + priority: z + .number() + .min(0, "Priority must be 0 or greater") + .max(1000, "Priority must be 1000 or less"), + enabled: z.boolean().default(true), + chain_rule: z.boolean().default(false), + }) + .refine( + (data) => + data.scope === "global" || + (data.scope_id != null && data.scope_id.trim() !== ""), + { + message: "Scope ID is required when scope is not global", + path: ["scope_id"], + }, + ); // Export type inference helpers export type EnvVar = z.infer; @@ -1115,7 +1224,9 @@ export type ModelProviderKeySchema = z.infer; export type NetworkConfigSchema = z.infer; export type NetworkFormConfigSchema = z.infer; export type ProxyFormConfigSchema = z.infer; -export type NetworkAndProxyFormSchema = z.infer; +export type NetworkAndProxyFormSchema = z.infer< + typeof networkAndProxyFormSchema +>; export type ProxyOnlyFormSchema = z.infer; export type OtelConfigSchema = z.infer; export type OtelFormSchema = z.infer; @@ -1125,9 +1236,15 @@ export type PrometheusConfigSchema = z.infer; export type PrometheusFormSchema = z.infer; export type NetworkOnlyFormSchema = z.infer; export type PerformanceFormSchema = z.infer; -export type CustomProviderConfigSchema = z.infer; +export type CustomProviderConfigSchema = z.infer< + typeof customProviderConfigSchema +>; export type GlobalProxyConfigSchema = z.infer; export type GlobalProxyFormSchema = z.infer; -export type GlobalHeaderFilterConfigSchema = z.infer; -export type GlobalHeaderFilterFormSchema = z.infer; -export type RoutingRuleSchema = z.infer; \ No newline at end of file +export type GlobalHeaderFilterConfigSchema = z.infer< + typeof globalHeaderFilterConfigSchema +>; +export type GlobalHeaderFilterFormSchema = z.infer< + typeof globalHeaderFilterFormSchema +>; +export type RoutingRuleSchema = z.infer;