diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index 3d722bc71b..ab33f6e77c 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -1055,12 +1055,31 @@ func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGette } } } + // Drop unsupported parameters identified by the litellmcompat plugin + jsonBody = dropUnsupportedParams(ctx, jsonBody) return jsonBody, nil } else { return rawBody, nil } } +// dropUnsupportedParams removes top-level JSON fields listed in the +// BifrostContextKeyLiteLLMCompatDroppedParams context key. The drop list is +// computed by the litellmcompat plugin by comparing the request's Params +// against the model's supported parameters allowlist from the catalog. +func dropUnsupportedParams(ctx context.Context, jsonBody []byte) []byte { + droppedParams, ok := ctx.Value(schemas.BifrostContextKeyLiteLLMCompatDroppedParams).([]string) + if !ok || len(droppedParams) == 0 { + return jsonBody + } + for _, param := range droppedParams { + if modified, err := sjson.DeleteBytes(jsonBody, param); err == nil { + jsonBody = modified + } + } + return jsonBody +} + // SetExtraHeadersHTTP sets additional headers from NetworkConfig to the standard HTTP request. // This allows users to configure custom headers for their provider requests. // Header keys are canonicalized using textproto.CanonicalMIMEHeaderKey to avoid duplicates. 
diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index c2a64e28fb..8ac9137a9e 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -255,6 +255,7 @@ const ( BifrostContextKeySessionID BifrostContextKey = "bifrost-session-id" // string session ID for the request (session stickiness) BifrostContextKeySessionTTL BifrostContextKey = "bifrost-session-ttl" // time.Duration session TTL for the request (session stickiness) BifrostContextKeyMCPExtraHeaders BifrostContextKey = "bifrost-mcp-extra-headers" // map[string][]string (these headers are forwarded only to the MCP while tool execution if they are in the allowlist of the MCP client) + BifrostContextKeyLiteLLMCompatDroppedParams BifrostContextKey = "litellmcompat-dropped-params" // []string (set by litellmcompat plugin - parameter names to drop from the provider request body) ) const ( @@ -964,4 +965,4 @@ type BifrostErrorExtraFields struct { RawResponse interface{} `json:"raw_response,omitempty"` LiteLLMCompat bool `json:"litellm_compat,omitempty"` KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` -} \ No newline at end of file +} diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index 35730395ad..c40e4776b4 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -55,6 +55,10 @@ type ModelCatalog struct { // Values are normalized output types: "chat_completion", "responses", "text_completion" supportedOutputs map[string][]string + // Pre-parsed supported parameters index (keyed by model name, populated from model parameters supported_parameters) + // Values are parameter names the model accepts (e.g., "temperature", "top_p", "tools") + supportedParams map[string][]string + // Background sync worker syncTicker *time.Ticker done chan struct{} @@ -209,6 +213,7 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto unfilteredModelPool: make(map[schemas.ModelProvider][]string), baseModelIndex: 
make(map[string]string), supportedOutputs: make(map[string][]string), + supportedParams: make(map[string][]string), done: make(chan struct{}), shouldSyncPricingFunc: shouldSyncPricingFunc, distributedLockManager: configstore.NewDistributedLockManager(configStore, logger, configstore.WithDefaultTTL(30*time.Second)), @@ -907,36 +912,123 @@ func (mc *ModelCatalog) IsResponsesSupported(model string, provider schemas.Mode return ok && slices.Contains(outputs, "responses") } -// buildSupportedOutputsIndex parses supported_endpoints from model parameters data -// and rebuilds the supportedOutputs index with normalized output type names. +// GetSupportedParameters returns the list of supported parameter names for a model. +// Returns nil if the model is not found in the catalog. +func (mc *ModelCatalog) GetSupportedParameters(model string) []string { + mc.mu.RLock() + params, ok := mc.supportedParams[model] + mc.mu.RUnlock() + if !ok { + return nil + } + // Return a copy to prevent external modification + result := make([]string, len(params)) + copy(result, params) + return result +} + +// buildSupportedOutputsIndex parses supported_endpoints and model parameters/capabilities +// from model parameters data and rebuilds the in-memory indexes. 
func (mc *ModelCatalog) buildSupportedOutputsIndex(paramsData map[string]json.RawMessage) { - newIndex := make(map[string][]string, len(paramsData)) + newOutputsIndex := make(map[string][]string, len(paramsData)) + newParamsIndex := make(map[string][]string, len(paramsData)) for model, data := range paramsData { - var params struct { - SupportedEndpoints []string `json:"supported_endpoints"` - } - if err := json.Unmarshal(data, ¶ms); err != nil || len(params.SupportedEndpoints) == 0 { + var parsed modelParametersParseResult + if err := json.Unmarshal(data, &parsed); err != nil { continue } - outputs := make([]string, 0, len(params.SupportedEndpoints)) - for _, endpoint := range params.SupportedEndpoints { - if normalized := normalizeEndpointToOutputType(endpoint); normalized != "" { - if !slices.Contains(outputs, normalized) { - outputs = append(outputs, normalized) + + // Build supported outputs from endpoints + if len(parsed.SupportedEndpoints) > 0 { + outputs := make([]string, 0, len(parsed.SupportedEndpoints)) + for _, endpoint := range parsed.SupportedEndpoints { + if normalized := normalizeEndpointToOutputType(endpoint); normalized != "" { + if !slices.Contains(outputs, normalized) { + outputs = append(outputs, normalized) + } } } + if len(outputs) > 0 { + newOutputsIndex[model] = outputs + } } - if len(outputs) > 0 { - newIndex[model] = outputs + + // Build supported params from model_parameters IDs and supports_* flags + supported := extractSupportedParams(&parsed) + if len(supported) > 0 { + newParamsIndex[model] = supported } } mc.mu.Lock() - mc.supportedOutputs = newIndex + mc.supportedOutputs = newOutputsIndex + mc.supportedParams = newParamsIndex mc.mu.Unlock() } +// modelParametersParseResult is the parsed result type used by buildSupportedOutputsIndex. 
+type modelParametersParseResult struct { + SupportedEndpoints []string `json:"supported_endpoints"` + ModelParameters []struct { + ID string `json:"id"` + } `json:"model_parameters"` + SupportsFunctionCalling *bool `json:"supports_function_calling"` + SupportsParallelFunctionCalling *bool `json:"supports_parallel_function_calling"` + SupportsToolChoice *bool `json:"supports_tool_choice"` + SupportsReasoning *bool `json:"supports_reasoning"` + SupportsServiceTier *bool `json:"supports_service_tier"` + SupportsPromptCaching *bool `json:"supports_prompt_caching"` +} + +// extractSupportedParams builds a list of supported OpenAI-compatible parameter +// names from model_parameters[].id values and supports_* boolean flags. +func extractSupportedParams(parsed *modelParametersParseResult) []string { + var supported []string + addParam := func(name string) { + if !slices.Contains(supported, name) { + supported = append(supported, name) + } + } + + // From model_parameters[].id — map IDs to request param names + for _, mp := range parsed.ModelParameters { + switch mp.ID { + case "reasoning_effort", "reasoning_summary": + addParam("reasoning") + case "web_search": + addParam("web_search_options") + case "promptTools", "image_detail", "stream": + // skip — not top-level request parameters + default: + addParam(mp.ID) + } + } + + // From supports_* boolean flags + if parsed.SupportsFunctionCalling != nil && *parsed.SupportsFunctionCalling { + addParam("tools") + } + if parsed.SupportsParallelFunctionCalling != nil && *parsed.SupportsParallelFunctionCalling { + addParam("parallel_tool_calls") + } + if parsed.SupportsToolChoice != nil && *parsed.SupportsToolChoice { + addParam("tool_choice") + } + if parsed.SupportsReasoning != nil && *parsed.SupportsReasoning { + addParam("reasoning") + } + if parsed.SupportsServiceTier != nil && *parsed.SupportsServiceTier { + addParam("service_tier") + } + if parsed.SupportsPromptCaching != nil && *parsed.SupportsPromptCaching { + 
addParam("prompt_cache_key") + addParam("prompt_cache_retention") + } + + return supported +} + // populateModelPool populates the model pool with all available models per provider (thread-safe) func (mc *ModelCatalog) populateModelPoolFromPricingData() { // Acquire write lock for the entire rebuild operation @@ -1019,6 +1111,7 @@ func NewTestCatalog(baseModelIndex map[string]string) *ModelCatalog { baseModelIndex: baseModelIndex, pricingData: make(map[string]configstoreTables.TableModelPricing), supportedOutputs: make(map[string][]string), + supportedParams: make(map[string][]string), done: make(chan struct{}), } } diff --git a/plugins/litellmcompat/dropparams.go b/plugins/litellmcompat/dropparams.go new file mode 100644 index 0000000000..9a7169d57c --- /dev/null +++ b/plugins/litellmcompat/dropparams.go @@ -0,0 +1,175 @@ +package litellmcompat + +import ( + "slices" + + "github.com/maximhq/bifrost/core/schemas" +) + +// computeUnsupportedParams checks each parameter field on the request's Params +// and returns the JSON field names of parameters that are set but not in the +// model's supported parameters allowlist. It does NOT mutate the request. 
+func computeUnsupportedParams(req *schemas.BifrostRequest, supportedParams []string) []string { + if req == nil { + return nil + } + switch { + case req.ChatRequest != nil && req.ChatRequest.Params != nil: + return unsupportedChatParams(req.ChatRequest.Params, supportedParams) + case req.ResponsesRequest != nil && req.ResponsesRequest.Params != nil: + return unsupportedResponsesParams(req.ResponsesRequest.Params, supportedParams) + case req.TextCompletionRequest != nil && req.TextCompletionRequest.Params != nil: + return unsupportedTextCompletionParams(req.TextCompletionRequest.Params, supportedParams) + } + return nil +} + +func unsupportedChatParams(p *schemas.ChatParameters, supported []string) []string { + var dropped []string + if p.Audio != nil && !slices.Contains(supported, "audio") { + dropped = append(dropped, "audio") + } + if p.FrequencyPenalty != nil && !slices.Contains(supported, "frequency_penalty") { + dropped = append(dropped, "frequency_penalty") + } + if p.LogitBias != nil && !slices.Contains(supported, "logit_bias") { + dropped = append(dropped, "logit_bias") + } + if p.LogProbs != nil && !slices.Contains(supported, "logprobs") { + dropped = append(dropped, "logprobs") + } + if p.MaxCompletionTokens != nil && !slices.Contains(supported, "max_completion_tokens") { + dropped = append(dropped, "max_completion_tokens") + } + if p.Metadata != nil && !slices.Contains(supported, "metadata") { + dropped = append(dropped, "metadata") + } + if p.ParallelToolCalls != nil && !slices.Contains(supported, "parallel_tool_calls") { + dropped = append(dropped, "parallel_tool_calls") + } + if p.Prediction != nil && !slices.Contains(supported, "prediction") { + dropped = append(dropped, "prediction") + } + if p.PresencePenalty != nil && !slices.Contains(supported, "presence_penalty") { + dropped = append(dropped, "presence_penalty") + } + if p.PromptCacheKey != nil && !slices.Contains(supported, "prompt_cache_key") { + dropped = append(dropped, "prompt_cache_key") 
+ } + if p.PromptCacheRetention != nil && !slices.Contains(supported, "prompt_cache_retention") { + dropped = append(dropped, "prompt_cache_retention") + } + if p.Reasoning != nil && !slices.Contains(supported, "reasoning") { + dropped = append(dropped, "reasoning") + } + if p.ResponseFormat != nil && !slices.Contains(supported, "response_format") { + dropped = append(dropped, "response_format") + } + if p.Seed != nil && !slices.Contains(supported, "seed") { + dropped = append(dropped, "seed") + } + if p.ServiceTier != nil && !slices.Contains(supported, "service_tier") { + dropped = append(dropped, "service_tier") + } + if len(p.Stop) > 0 && !slices.Contains(supported, "stop") { + dropped = append(dropped, "stop") + } + if p.Temperature != nil && !slices.Contains(supported, "temperature") { + dropped = append(dropped, "temperature") + } + if p.TopLogProbs != nil && !slices.Contains(supported, "top_logprobs") { + dropped = append(dropped, "top_logprobs") + } + if p.TopP != nil && !slices.Contains(supported, "top_p") { + dropped = append(dropped, "top_p") + } + if p.ToolChoice != nil && !slices.Contains(supported, "tool_choice") { + dropped = append(dropped, "tool_choice") + } + if len(p.Tools) > 0 && !slices.Contains(supported, "tools") { + dropped = append(dropped, "tools") + } + if p.Verbosity != nil && !slices.Contains(supported, "verbosity") { + dropped = append(dropped, "verbosity") + } + return dropped +} + +func unsupportedResponsesParams(p *schemas.ResponsesParameters, supported []string) []string { + var dropped []string + if p.MaxOutputTokens != nil && !slices.Contains(supported, "max_output_tokens") { + dropped = append(dropped, "max_output_tokens") + } + if p.MaxToolCalls != nil && !slices.Contains(supported, "max_tool_calls") { + dropped = append(dropped, "max_tool_calls") + } + if p.Metadata != nil && !slices.Contains(supported, "metadata") { + dropped = append(dropped, "metadata") + } + if p.ParallelToolCalls != nil && !slices.Contains(supported, 
"parallel_tool_calls") { + dropped = append(dropped, "parallel_tool_calls") + } + if p.PromptCacheKey != nil && !slices.Contains(supported, "prompt_cache_key") { + dropped = append(dropped, "prompt_cache_key") + } + if p.Reasoning != nil && !slices.Contains(supported, "reasoning") { + dropped = append(dropped, "reasoning") + } + if p.ServiceTier != nil && !slices.Contains(supported, "service_tier") { + dropped = append(dropped, "service_tier") + } + if p.Temperature != nil && !slices.Contains(supported, "temperature") { + dropped = append(dropped, "temperature") + } + if p.Text != nil && !slices.Contains(supported, "text") { + dropped = append(dropped, "text") + } + if p.TopLogProbs != nil && !slices.Contains(supported, "top_logprobs") { + dropped = append(dropped, "top_logprobs") + } + if p.TopP != nil && !slices.Contains(supported, "top_p") { + dropped = append(dropped, "top_p") + } + if p.ToolChoice != nil && !slices.Contains(supported, "tool_choice") { + dropped = append(dropped, "tool_choice") + } + if len(p.Tools) > 0 && !slices.Contains(supported, "tools") { + dropped = append(dropped, "tools") + } + return dropped +} + +func unsupportedTextCompletionParams(p *schemas.TextCompletionParameters, supported []string) []string { + var dropped []string + if p.FrequencyPenalty != nil && !slices.Contains(supported, "frequency_penalty") { + dropped = append(dropped, "frequency_penalty") + } + if p.LogitBias != nil && !slices.Contains(supported, "logit_bias") { + dropped = append(dropped, "logit_bias") + } + if p.LogProbs != nil && !slices.Contains(supported, "logprobs") { + dropped = append(dropped, "logprobs") + } + if p.MaxTokens != nil && !slices.Contains(supported, "max_tokens") { + dropped = append(dropped, "max_tokens") + } + if p.N != nil && !slices.Contains(supported, "n") { + dropped = append(dropped, "n") + } + if p.PresencePenalty != nil && !slices.Contains(supported, "presence_penalty") { + dropped = append(dropped, "presence_penalty") + } + if p.Seed != 
nil && !slices.Contains(supported, "seed") { + dropped = append(dropped, "seed") + } + if len(p.Stop) > 0 && !slices.Contains(supported, "stop") { + dropped = append(dropped, "stop") + } + if p.Temperature != nil && !slices.Contains(supported, "temperature") { + dropped = append(dropped, "temperature") + } + if p.TopP != nil && !slices.Contains(supported, "top_p") { + dropped = append(dropped, "top_p") + } + return dropped +} diff --git a/plugins/litellmcompat/main.go b/plugins/litellmcompat/main.go index 59de6fedb8..5e9414ed5b 100644 --- a/plugins/litellmcompat/main.go +++ b/plugins/litellmcompat/main.go @@ -86,6 +86,20 @@ func (p *LiteLLMCompatPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schem // Apply request transforms in sequence req = transformTextToChatRequest(ctx, req, p.modelCatalog, p.logger) req = transformChatToResponsesRequest(ctx, req, p.modelCatalog, p.logger) + + // Compute unsupported parameters to drop based on model catalog allowlist + if ctx != nil && p.modelCatalog != nil { + model := getModelFromRequest(req) + if model != "" { + if supportedParams := p.modelCatalog.GetSupportedParameters(model); supportedParams != nil { + droppedParams := computeUnsupportedParams(req, supportedParams) + if len(droppedParams) > 0 { + ctx.SetValue(schemas.BifrostContextKeyLiteLLMCompatDroppedParams, droppedParams) + } + } + } + } + return req, nil, nil } @@ -103,7 +117,25 @@ func (p *LiteLLMCompatPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *s return result, bifrostErr, nil } +// getModelFromRequest extracts the model name from a BifrostRequest, +// checking each request type in order. 
+func getModelFromRequest(req *schemas.BifrostRequest) string { + if req == nil { + return "" + } + if req.ChatRequest != nil { + return req.ChatRequest.Model + } + if req.ResponsesRequest != nil { + return req.ResponsesRequest.Model + } + if req.TextCompletionRequest != nil { + return req.TextCompletionRequest.Model + } + return "" +} + // Cleanup performs plugin cleanup func (p *LiteLLMCompatPlugin) Cleanup() error { return nil }