diff --git a/core/bifrost.go b/core/bifrost.go index 66a8351307..816e8e775c 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -33,21 +33,22 @@ type ChannelMessage struct { // It handles request routing, provider management, and response processing. type Bifrost struct { ctx context.Context - account schemas.Account // account interface - plugins atomic.Pointer[[]schemas.Plugin] // list of plugins - requestQueues sync.Map // provider request queues (thread-safe) - waitGroups sync.Map // wait groups for each provider (thread-safe) - providerMutexes sync.Map // mutexes for each provider to prevent concurrent updates (thread-safe) - channelMessagePool sync.Pool // Pool for ChannelMessage objects, initial pool size is set in Init - responseChannelPool sync.Pool // Pool for response channels, initial pool size is set in Init - errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init - responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init - pluginPipelinePool sync.Pool // Pool for PluginPipeline objects - bifrostRequestPool sync.Pool // Pool for BifrostRequest objects - logger schemas.Logger // logger instance, default logger is used if not provided - mcpManager *MCPManager // MCP integration manager (nil if MCP not configured) - dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. - keySelector schemas.KeySelector // Custom key selector function + account schemas.Account // account interface + plugins atomic.Pointer[[]schemas.Plugin] // list of plugins + providers atomic.Pointer[[]schemas.Provider] // list of providers + requestQueues sync.Map // provider request queues (thread-safe) + waitGroups sync.Map // wait groups for each provider (thread-safe) + providerMutexes sync.Map // mutexes for each provider to prevent concurrent updates (thread-safe) + channelMessagePool sync.Pool // Pool for ChannelMessage objects, initial pool size is set in Init + responseChannelPool sync.Pool // Pool for response channels, initial pool size is set in Init + errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init + responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init + pluginPipelinePool sync.Pool // Pool for PluginPipeline objects + bifrostRequestPool sync.Pool // Pool for BifrostRequest objects + logger schemas.Logger // logger instance, default logger is used if not provided + mcpManager *MCPManager // MCP integration manager (nil if MCP not configured) + dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. + keySelector schemas.KeySelector // Custom key selector function } // PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation. 
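The struct change above introduces providers atomic.Pointer[[]schemas.Provider]: readers such as getProviderByKey (added later in this diff) take a lock-free snapshot while prepareProvider appends copy-on-write through a CAS loop. Below is a minimal, runnable sketch of that pattern; the names (registry, items, add, snapshot) are illustrative and not part of the diff.

package main

import (
	"fmt"
	"sync/atomic"
)

// registry mirrors the copy-on-write idea: readers Load a snapshot
// pointer, writers build a fresh slice and CompareAndSwap it in.
type registry struct {
	items atomic.Pointer[[]string]
}

func (r *registry) add(item string) {
	for {
		oldPtr := r.items.Load()
		var old []string
		if oldPtr != nil {
			old = *oldPtr
		}
		next := make([]string, len(old)+1)
		copy(next, old)
		next[len(old)] = item
		// Retry if another goroutine swapped in a new slice meanwhile.
		if r.items.CompareAndSwap(oldPtr, &next) {
			return
		}
	}
}

func (r *registry) snapshot() []string {
	if p := r.items.Load(); p != nil {
		return *p // slices are never mutated after publication, safe to share
	}
	return nil
}

func main() {
	var r registry
	r.add("openai")
	r.add("anthropic")
	fmt.Println(r.snapshot()) // [openai anthropic]
}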
@@ -91,6 +92,10 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
 		keySelector:        config.KeySelector,
 	}
 	bifrost.plugins.Store(&config.Plugins)
+
+	// Initialize providers slice
+	bifrost.providers.Store(&[]schemas.Provider{})
+
 	bifrost.dropExcessRequests.Store(config.DropExcessRequests)
 
 	if bifrost.keySelector == nil {
@@ -203,6 +208,169 @@ func (bifrost *Bifrost) ReloadConfig(config schemas.BifrostConfig) error {
 
 // PUBLIC API METHODS
 
+// ListModelsRequest sends a list models request to the specified provider.
+func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
+	if req == nil {
+		return nil, &schemas.BifrostError{
+			IsBifrostError: false,
+			Error: &schemas.ErrorField{
+				Message: "list models request is nil",
+			},
+		}
+	}
+	if req.Provider == "" {
+		return nil, &schemas.BifrostError{
+			IsBifrostError: false,
+			Error: &schemas.ErrorField{
+				Message: "provider is required for list models request",
+			},
+		}
+	}
+
+	request := &schemas.BifrostListModelsRequest{
+		Provider:    req.Provider,
+		PageSize:    req.PageSize,
+		PageToken:   req.PageToken,
+		ExtraParams: req.ExtraParams,
+	}
+
+	provider := bifrost.getProviderByKey(req.Provider)
+	if provider == nil {
+		return nil, &schemas.BifrostError{
+			IsBifrostError: false,
+			Error: &schemas.ErrorField{
+				Message: "provider not found for list models request",
+			},
+		}
+	}
+
+	// Determine the base provider type for key requirement checks
+	baseProvider := req.Provider
+	providerConfig, err := bifrost.account.GetConfigForProvider(req.Provider)
+	if err == nil && providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.BaseProviderType != "" {
+		baseProvider = providerConfig.CustomProviderConfig.BaseProviderType
+	}
+
+	// Get API key for the provider if required
+	key := schemas.Key{}
+	if providerRequiresKey(baseProvider) {
+		key, err = bifrost.selectKeyFromProviderForModel(&ctx, schemas.ListModelsRequest, req.Provider, "", baseProvider)
+		if err != nil {
+			return nil, &schemas.BifrostError{
+				IsBifrostError: false,
+				Error: &schemas.ErrorField{
+					Message: err.Error(),
+					Error:   err,
+				},
+			}
+		}
+	}
+
+	response, bifrostErr := provider.ListModels(ctx, key, request)
+	if bifrostErr != nil {
+		return nil, bifrostErr
+	}
+	return response, nil
+}
+
+// ListAllModels lists all models from all configured providers.
+// It pages through each provider with schemas.DefaultPageSize-sized requests
+// (capped at schemas.MaxPaginationRequests pages per provider) and accumulates the results.
+func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
+	startTime := time.Now()
+
+	if request == nil {
+		request = &schemas.BifrostListModelsRequest{}
+	}
+
+	providerKeys, err := bifrost.account.GetConfiguredProviders()
+	if err != nil {
+		return nil, &schemas.BifrostError{
+			IsBifrostError: false,
+			Error: &schemas.ErrorField{
+				Message: "failed to get configured providers",
+				Error:   err,
+			},
+		}
+	}
+
+	// Accumulate all models from all providers
+	allModels := make([]schemas.Model, 0)
+	var firstError *schemas.BifrostError
+
+	for _, providerKey := range providerKeys {
+		if strings.TrimSpace(string(providerKey)) == "" {
+			continue
+		}
+
+		// Create request for this provider with the default page size
+		providerRequest := &schemas.BifrostListModelsRequest{
+			Provider: providerKey,
+			PageSize: schemas.DefaultPageSize,
+		}
+
+		iterations := 0
+		for {
+			iterations++
+			if iterations > schemas.MaxPaginationRequests {
+				bifrost.logger.Warn(fmt.Sprintf("reached maximum pagination requests (%d) for provider %s", schemas.MaxPaginationRequests, providerKey))
+				break
+			}
+
+			response, bifrostErr := bifrost.ListModelsRequest(ctx, providerRequest)
+			if bifrostErr != nil {
+				// Log the error but continue with other providers
+				bifrost.logger.Warn(fmt.Sprintf("failed to list models for provider %s: %v", providerKey, bifrostErr.Error.Message))
+				if firstError == nil {
+					firstError = bifrostErr
+				}
+				break
+			}
+
+			if response == nil {
+				break
+			}
+
+			if len(response.Data) > 0 {
+				allModels = append(allModels, response.Data...)
+			}
+
+			// Check if there are more pages
+			if response.NextPageToken == "" {
+				break
+			}
+
+			// Set the page token for the next request
+			providerRequest.PageToken = response.NextPageToken
+		}
+	}
+
+	// If we couldn't get any models from any provider, return the first error
+	if len(allModels) == 0 && firstError != nil {
+		return nil, firstError
+	}
+
+	// Sort models alphabetically by ID
+	sort.Slice(allModels, func(i, j int) bool {
+		return allModels[i].ID < allModels[j].ID
+	})
+
+	// Calculate total elapsed time
+	elapsedTime := time.Since(startTime).Milliseconds()
+
+	// Return aggregated response with accumulated latency
+	response := &schemas.BifrostListModelsResponse{
+		Data: allModels,
+		ExtraFields: schemas.BifrostResponseExtraFields{
+			RequestType: schemas.ListModelsRequest,
+			Latency:     elapsedTime,
+		},
+	}
+
+	response = response.ApplyPagination(request.PageSize, request.PageToken)
+
+	return response, nil
+}
+
 // TextCompletionRequest sends a text completion request to the specified provider.
func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { if req == nil { @@ -1040,6 +1208,21 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi waitGroupValue, _ := bifrost.waitGroups.Load(providerKey) currentWaitGroup := waitGroupValue.(*sync.WaitGroup) + // Atomically append provider to the providers slice + for { + oldPtr := bifrost.providers.Load() + var oldSlice []schemas.Provider + if oldPtr != nil { + oldSlice = *oldPtr + } + newSlice := make([]schemas.Provider, len(oldSlice)+1) + copy(newSlice, oldSlice) + newSlice[len(oldSlice)] = provider + if bifrost.providers.CompareAndSwap(oldPtr, &newSlice) { + break + } + } + for range providerConfig.ConcurrencyAndBufferSize.Concurrency { currentWaitGroup.Add(1) go bifrost.requestWorker(provider, providerConfig, queue) @@ -1092,6 +1275,23 @@ func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (cha return queue, nil } +// getProviderByKey retrieves a provider instance from the providers array by its provider key. +// Returns the provider if found, or nil if no provider with the given key exists. +func (bifrost *Bifrost) getProviderByKey(providerKey schemas.ModelProvider) schemas.Provider { + providers := bifrost.providers.Load() + if providers == nil { + return nil + } + + for _, provider := range *providers { + if provider.GetProviderKey() == providerKey { + return provider + } + } + + return nil +} + // CORE INTERNAL LOGIC // shouldTryFallbacks handles the primary error and returns true if we should proceed with fallbacks, false if we should return immediately @@ -1589,7 +1789,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas key := schemas.Key{} if providerRequiresKey(baseProvider) { // Use the custom provider name for actual key selection, but pass base provider type for key validation - key, err = bifrost.selectKeyFromProviderForModel(&req.Context, provider.GetProviderKey(), model, baseProvider) + key, err = bifrost.selectKeyFromProviderForModel(&req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider) if err != nil { bifrost.logger.Warn("error selecting key for model %s: %v", model, err) req.Err <- schemas.BifrostError{ @@ -1954,7 +2154,7 @@ func (bifrost *Bifrost) releaseBifrostRequest(req *schemas.BifrostRequest) { // selectKeyFromProviderForModel selects an appropriate API key for a given provider and model. // It uses weighted random selection if multiple keys are available. 
-func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) (schemas.Key, error) { +func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) (schemas.Key, error) { // Check if key has been set in the context explicitly if ctx != nil { key, ok := (*ctx).Value(schemas.BifrostContextKeyDirectKey).(schemas.Key) @@ -1974,28 +2174,36 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov // filter out keys which dont support the model, if the key has no models, it is supported for all models var supportedKeys []schemas.Key - for _, key := range keys { - modelSupported := (slices.Contains(key.Models, model) && (strings.TrimSpace(key.Value) != "" || canProviderKeyValueBeEmpty(baseProviderType))) || len(key.Models) == 0 - - // Additional deployment checks for Azure and Bedrock - deploymentSupported := true - if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil { - // For Azure, check if deployment exists for this model - if len(key.AzureKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.AzureKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.Bedrock && key.BedrockKeyConfig != nil { - // For Bedrock, check if deployment exists for this model - if len(key.BedrockKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.BedrockKeyConfig.Deployments[model] + if requestType == schemas.ListModelsRequest { + // Skip deployment check but still check if the key has a value + for _, k := range keys { + if strings.TrimSpace(k.Value) != "" || canProviderKeyValueBeEmpty(baseProviderType) { + supportedKeys = append(supportedKeys, k) } } + } else { + for _, key := range keys { + modelSupported := (slices.Contains(key.Models, model) && (strings.TrimSpace(key.Value) != "" || canProviderKeyValueBeEmpty(baseProviderType))) || len(key.Models) == 0 - if modelSupported && deploymentSupported { - supportedKeys = append(supportedKeys, key) + // Additional deployment checks for Azure and Bedrock + deploymentSupported := true + if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil { + // For Azure, check if deployment exists for this model + if len(key.AzureKeyConfig.Deployments) > 0 { + _, deploymentSupported = key.AzureKeyConfig.Deployments[model] + } + } else if baseProviderType == schemas.Bedrock && key.BedrockKeyConfig != nil { + // For Bedrock, check if deployment exists for this model + if len(key.BedrockKeyConfig.Deployments) > 0 { + _, deploymentSupported = key.BedrockKeyConfig.Deployments[model] + } + } + + if modelSupported && deploymentSupported { + supportedKeys = append(supportedKeys, key) + } } } - if len(supportedKeys) == 0 { if baseProviderType == schemas.Azure || baseProviderType == schemas.Bedrock { return schemas.Key{}, fmt.Errorf("no keys found that support model/deployment: %s", model) diff --git a/core/changelog.md b/core/changelog.md index 184a81a5c4..37541b90e0 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -1,9 +1,5 @@ -- bug: fixed embedding request not being handled in `GetExtraFields()` method of `BifrostResponse` -- fix: added latency calculation for vertex native requests -- feat: added cached tokens and reasoning tokens to the usage metadata for chat completions -- feat: added global region support for vertex API -- fix: added filter for extra fields in 
chat completions request for Mistral provider -- fix: fixed ResponsesComputerToolCallPendingSafetyCheck code field \ No newline at end of file +- feat: added ListModels method to Provider interface +- feat: enabled provider tracking in Bifrost core for API exposure \ No newline at end of file diff --git a/core/providers/anthropic.go b/core/providers/anthropic.go index 4a76a22582..2b23bb351d 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -155,7 +155,7 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) req.SetRequestURI(url) - req.Header.SetMethod("POST") + req.Header.SetMethod(http.MethodPost) req.Header.SetContentType("application/json") req.Header.Set("x-api-key", key) req.Header.Set("anthropic-version", provider.apiVersion) @@ -188,6 +188,73 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB return bodyCopy, latency, nil } +// ListModels performs a list models request to Anthropic's API. +func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ListModelsRequest); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + // Build URL using centralized URL construction + requestURL := anthropic.ToAnthropicListModelsURL(request, provider.networkConfig.BaseURL+"/v1/models") + req.SetRequestURI(requestURL) + req.Header.SetMethod(http.MethodGet) + req.Header.SetContentType("application/json") + req.Header.Set("x-api-key", key.Value) + req.Header.Set("anthropic-version", provider.apiVersion) + + // Make request + latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body()))) + + var errorResp anthropic.AnthropicError + + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Type = &errorResp.Error.Type + bifrostErr.Error.Message = errorResp.Error.Message + + return nil, bifrostErr + } + + // Parse Anthropic's response + var anthropicResponse anthropic.AnthropicListModelsResponse + rawResponse, bifrostErr := handleProviderResponse(resp.Body(), &anthropicResponse, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + response := anthropicResponse.ToBifrostListModelsResponse(providerName) + + // Set ExtraFields + response.ExtraFields.Provider = providerName + response.ExtraFields.RequestType = schemas.ListModelsRequest + response.ExtraFields.Latency = latency.Milliseconds() + + // Set raw response if enabled + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + return response, nil +} + // TextCompletion performs a text completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. 
// Returns a BifrostResponse containing the completion results or an error if the request fails. @@ -345,7 +412,7 @@ func handleAnthropicChatCompletionStreaming( } // Create HTTP request for streaming - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody)) if err != nil { if errors.Is(err, context.Canceled) { return nil, &schemas.BifrostError{ @@ -593,7 +660,7 @@ func (provider *AnthropicProvider) ResponsesStream(ctx context.Context, postHook } // Create HTTP request for streaming - req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/v1/messages", bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.networkConfig.BaseURL+"/v1/messages", bytes.NewReader(jsonBody)) if err != nil { if errors.Is(err, context.Canceled) { return nil, &schemas.BifrostError{ diff --git a/core/providers/azure.go b/core/providers/azure.go index bb126c4d66..e0fccee831 100644 --- a/core/providers/azure.go +++ b/core/providers/azure.go @@ -10,6 +10,7 @@ import ( "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/providers/azure" "github.com/maximhq/bifrost/core/schemas/providers/openai" "github.com/valyala/fasthttp" ) @@ -88,7 +89,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { - apiVersion = schemas.Ptr("2024-02-01") + apiVersion = schemas.Ptr(azure.DefaultAzureAPIVersion) } url = fmt.Sprintf("%s/openai/deployments/%s/%s?api-version=%s", url, deployment, path, *apiVersion) @@ -106,7 +107,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) req.SetRequestURI(url) - req.Header.SetMethod("POST") + req.Header.SetMethod(http.MethodPost) req.Header.SetContentType("application/json") if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken)) @@ -143,6 +144,97 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody return bodyCopy, latency, nil } +// ListModels performs a list models request to Azure's API. 
+// It retrieves all models accessible by the Azure OpenAI resource +func (provider *AzureProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + // Validate Azure key configuration + if key.AzureKeyConfig == nil { + return nil, newConfigurationError("azure key config not set", schemas.Azure) + } + + if key.AzureKeyConfig.Endpoint == "" { + return nil, newConfigurationError("endpoint not set", schemas.Azure) + } + + // Get API version + apiVersion := key.AzureKeyConfig.APIVersion + if apiVersion == nil { + apiVersion = schemas.Ptr(azure.DefaultAzureAPIVersion) + } + + // Construct URL - list models is a resource-level operation, doesn't require deployment + url := fmt.Sprintf("%s/openai/models?api-version=%s", key.AzureKeyConfig.Endpoint, *apiVersion) + + // Create the request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(url) + req.Header.SetMethod(http.MethodGet) + req.Header.SetContentType("application/json") + + // Set Azure authentication - either Bearer token or api-key + if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken)) + // Ensure api-key is not accidentally present (from extra headers, etc.) + req.Header.Del("api-key") + } else { + req.Header.Set("api-key", key.Value) + } + + // Send the request and measure latency + latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from azure provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("%s error: %v", schemas.Azure, errorResp) + + return nil, bifrostErr + } + + // Read the response body and copy it before releasing the response + // to avoid use-after-free since resp.Body() references fasthttp's internal buffer + responseBody := append([]byte(nil), resp.Body()...) + + // Parse Azure-specific response + azureResponse := &azure.AzureListModelsResponse{} + rawResponse, bifrostErr := handleProviderResponse(responseBody, azureResponse, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Convert to Bifrost response + response := azureResponse.ToBifrostListModelsResponse() + if response == nil { + return nil, newBifrostOperationError("failed to convert Azure model list response", nil, schemas.Azure) + } + + response = response.ApplyPagination(request.PageSize, request.PageToken) + + response.ExtraFields.Provider = schemas.Azure + response.ExtraFields.Latency = latency.Milliseconds() + response.ExtraFields.RequestType = schemas.ListModelsRequest + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + return response, nil +} + // TextCompletion performs a text completion request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
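Several of the ListModels implementations in this diff (Azure, Bedrock, Mistral, and the shared OpenAI-compatible helper) fetch the full model list and then page it client-side via response.ApplyPagination(request.PageSize, request.PageToken). The diff never shows that helper's body; the following is a hypothetical offset-token sketch, assuming the page token simply encodes the start index as a decimal string - the real schemas implementation may differ.

package main

import (
	"fmt"
	"strconv"
)

// paginate returns one page of ids plus the token for the next page.
// Hypothetical stand-in for ApplyPagination, not taken from the diff.
func paginate(ids []string, pageSize int, pageToken string) (page []string, nextToken string) {
	start := 0
	if pageToken != "" {
		if n, err := strconv.Atoi(pageToken); err == nil && n >= 0 && n <= len(ids) {
			start = n
		}
	}
	if pageSize <= 0 {
		pageSize = len(ids) - start // no size given: return the rest
	}
	end := start + pageSize
	if end > len(ids) {
		end = len(ids)
	}
	page = ids[start:end]
	if end < len(ids) {
		nextToken = strconv.Itoa(end) // more results remain
	}
	return page, nextToken
}

func main() {
	ids := []string{"a", "b", "c", "d", "e"}
	page, next := paginate(ids, 2, "")
	fmt.Println(page, next) // [a b] 2
	page, next = paginate(ids, 2, next)
	fmt.Println(page, next) // [c d] 4
}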
@@ -202,7 +294,7 @@ func (provider *AzureProvider) TextCompletionStream(ctx context.Context, postHoo apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { - apiVersion = schemas.Ptr("2024-02-01") + apiVersion = schemas.Ptr(azure.DefaultAzureAPIVersion) } fullURL = fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", baseURL, deployment, *apiVersion) @@ -298,7 +390,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { - apiVersion = schemas.Ptr("2024-02-01") + apiVersion = schemas.Ptr(azure.DefaultAzureAPIVersion) } fullURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", baseURL, deployment, *apiVersion) diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index 25575236aa..d5c8d49d27 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -89,7 +89,7 @@ func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider { func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBody interface{}, path string, key schemas.Key) ([]byte, time.Duration, *schemas.BifrostError) { config := key.BedrockKeyConfig - region := "us-east-1" + region := bedrock.DefaultBedrockRegion if config.Region != nil { region = *config.Region } @@ -218,7 +218,7 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx context.Context, reque // Format the path with proper model identifier for streaming path := provider.getModelPath("converse-stream", model, key) - region := "us-east-1" + region := bedrock.DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil { region = *key.BedrockKeyConfig.Region } @@ -230,7 +230,7 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx context.Context, reque } // Create HTTP request for streaming - req, reqErr := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewReader(jsonBody)) + req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewReader(jsonBody)) if reqErr != nil { return nil, newBifrostOperationError("error creating request", reqErr, providerName) } @@ -345,6 +345,142 @@ func signAWSRequest(ctx context.Context, req *http.Request, accessKey, secretKey return nil } +// ListModels performs a list models request to Bedrock's API. +// It retrieves all foundation models available in Amazon Bedrock. 
+func (provider *BedrockProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
+	if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ListModelsRequest); err != nil {
+		return nil, err
+	}
+
+	providerName := provider.GetProviderKey()
+
+	if key.BedrockKeyConfig == nil {
+		return nil, newConfigurationError("bedrock key config is not provided", providerName)
+	}
+
+	config := key.BedrockKeyConfig
+
+	region := bedrock.DefaultBedrockRegion
+	if config.Region != nil {
+		region = *config.Region
+	}
+
+	// List models endpoint uses the bedrock service (not bedrock-runtime)
+	url := fmt.Sprintf("https://bedrock.%s.amazonaws.com/foundation-models", region)
+
+	// Create the GET request without a body
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+	if err != nil {
+		return nil, &schemas.BifrostError{
+			IsBifrostError: true,
+			Error: &schemas.ErrorField{
+				Message: "error creating request",
+				Error:   err,
+			},
+		}
+	}
+
+	// Set any extra headers from network config
+	setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil)
+
+	// If Value is set, use API Key authentication - else use IAM role authentication
+	if key.Value != "" {
+		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value))
+	} else {
+		// Sign the request using either explicit credentials or IAM role authentication
+		if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, region, "bedrock", providerName); err != nil {
+			return nil, err
+		}
+	}
+
+	// Execute the request and measure latency
+	startTime := time.Now()
+	resp, err := provider.client.Do(req)
+	latency := time.Since(startTime)
+	if err != nil {
+		if errors.Is(err, context.Canceled) {
+			return nil, &schemas.BifrostError{
+				IsBifrostError: false,
+				Error: &schemas.ErrorField{
+					Type:    schemas.Ptr(schemas.RequestCancelled),
+					Message: schemas.ErrRequestCancelled,
+					Error:   err,
+				},
+			}
+		}
+		if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
+			return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, providerName)
+		}
+		return nil, &schemas.BifrostError{
+			IsBifrostError: false,
+			Error: &schemas.ErrorField{
+				Message: schemas.ErrProviderRequest,
+				Error:   err,
+			},
+		}
+	}
+	defer resp.Body.Close()
+
+	// Read response body
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, &schemas.BifrostError{
+			IsBifrostError: true,
+			Error: &schemas.ErrorField{
+				Message: "error reading response",
+				Error:   err,
+			},
+		}
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		var errorResp bedrock.BedrockError
+
+		if err := sonic.Unmarshal(responseBody, &errorResp); err != nil {
+			return nil, &schemas.BifrostError{
+				IsBifrostError: true,
+				StatusCode:     &resp.StatusCode,
+				Error: &schemas.ErrorField{
+					Message: schemas.ErrProviderResponseUnmarshal,
+					Error:   err,
+				},
+			}
+		}
+
+		return nil, &schemas.BifrostError{
+			StatusCode: &resp.StatusCode,
+			Error: &schemas.ErrorField{
+				Message: errorResp.Message,
+			},
+		}
+	}
+
+	// Parse Bedrock-specific response
+	bedrockResponse := &bedrock.BedrockListModelsResponse{}
+	rawResponse, bifrostErr := handleProviderResponse(responseBody, bedrockResponse, provider.sendBackRawResponse)
+	if bifrostErr != nil {
+		return nil, bifrostErr
+	}
+
+	// Convert to Bifrost response
+	response := bedrockResponse.ToBifrostListModelsResponse(providerName)
+	if response
== nil { + return nil, newBifrostOperationError("failed to convert Bedrock model list response", nil, providerName) + } + + response = response.ApplyPagination(request.PageSize, request.PageToken) + + response.ExtraFields.Provider = providerName + response.ExtraFields.Latency = latency.Milliseconds() + response.ExtraFields.RequestType = schemas.ListModelsRequest + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + return response, nil +} + // TextCompletion performs a text completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. diff --git a/core/providers/cerebras.go b/core/providers/cerebras.go index 4f37281baf..d1b15f5c0a 100644 --- a/core/providers/cerebras.go +++ b/core/providers/cerebras.go @@ -61,6 +61,11 @@ func (provider *CerebrasProvider) GetProviderKey() schemas.ModelProvider { return schemas.Cerebras } +// ListModels performs a list models request to Cerebras's API. +func (provider *CerebrasProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return handleOpenAIListModelsRequest(ctx, provider.client, request, provider.networkConfig.BaseURL+"/v1/models", key, provider.networkConfig.ExtraHeaders, provider.GetProviderKey(), provider.sendBackRawResponse, provider.logger) +} + // TextCompletion performs a text completion request to Cerebras's API. // It formats the request, sends it to Cerebras, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. diff --git a/core/providers/cohere.go b/core/providers/cohere.go index fced620744..796587789c 100644 --- a/core/providers/cohere.go +++ b/core/providers/cohere.go @@ -95,6 +95,68 @@ func (provider *CohereProvider) GetProviderKey() schemas.ModelProvider { return getProviderName(schemas.Cohere, provider.customProviderConfig) } +// ListModels performs a list models request to Cohere's API. 
+func (provider *CohereProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ListModelsRequest); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + // Build URL using centralized URL construction + requestURL := cohere.ToCohereListModelsURL(request, provider.networkConfig.BaseURL+"/v1/models") + req.SetRequestURI(requestURL) + req.Header.SetMethod(http.MethodGet) + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + // Make request + latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + + var errorResp cohere.CohereError + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = errorResp.Message + + return nil, bifrostErr + } + + // Parse Cohere list models response + var cohereResponse cohere.CohereListModelsResponse + rawResponse, bifrostErr := handleProviderResponse(resp.Body(), &cohereResponse, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Convert Cohere v2 response to Bifrost response + response := cohereResponse.ToBifrostListModelsResponse(providerName) + + response.ExtraFields.Provider = providerName + response.ExtraFields.RequestType = schemas.ListModelsRequest + response.ExtraFields.Latency = latency.Milliseconds() + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + return response, nil +} + // TextCompletion is not supported by the Cohere provider. // Returns an error indicating that text completion is not supported. func (provider *CohereProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { @@ -131,19 +193,19 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas. 
} // Convert Cohere v2 response to Bifrost response - reponse := cohereResponse.ToBifrostChatResponse() + response := cohereResponse.ToBifrostChatResponse() - reponse.Model = request.Model - reponse.ExtraFields.Provider = providerName - reponse.ExtraFields.ModelRequested = request.Model - reponse.ExtraFields.RequestType = schemas.ChatCompletionRequest - reponse.ExtraFields.Latency = latency.Milliseconds() + response.Model = request.Model + response.ExtraFields.Provider = providerName + response.ExtraFields.ModelRequested = request.Model + response.ExtraFields.RequestType = schemas.ChatCompletionRequest + response.ExtraFields.Latency = latency.Milliseconds() if provider.sendBackRawResponse { - reponse.ExtraFields.RawResponse = rawResponse + response.ExtraFields.RawResponse = rawResponse } - return reponse, nil + return response, nil } func (provider *CohereProvider) handleCohereChatCompletionRequest(ctx context.Context, reqBody *cohere.CohereChatRequest, key schemas.Key) (*cohere.CohereChatResponse, interface{}, time.Duration, *schemas.BifrostError) { @@ -171,7 +233,7 @@ func (provider *CohereProvider) handleCohereChatCompletionRequest(ctx context.Co setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) req.SetRequestURI(provider.networkConfig.BaseURL + "/v2/chat") - req.Header.SetMethod("POST") + req.Header.SetMethod(http.MethodPost) req.Header.SetContentType("application/json") req.Header.Set("Authorization", "Bearer "+key.Value) @@ -246,7 +308,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo } // Create HTTP request for streaming - req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/v2/chat", bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.networkConfig.BaseURL+"/v2/chat", bytes.NewReader(jsonBody)) if err != nil { if errors.Is(err, context.Canceled) { return nil, &schemas.BifrostError{ @@ -528,7 +590,7 @@ func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRun } // Create HTTP request for streaming - req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/v2/chat", bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.networkConfig.BaseURL+"/v2/chat", bytes.NewReader(jsonBody)) if err != nil { if errors.Is(err, context.Canceled) { return nil, &schemas.BifrostError{ @@ -723,7 +785,7 @@ func (provider *CohereProvider) Embedding(ctx context.Context, key schemas.Key, setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) req.SetRequestURI(provider.networkConfig.BaseURL + "/v2/embed") - req.Header.SetMethod("POST") + req.Header.SetMethod(http.MethodPost) req.Header.SetContentType("application/json") req.Header.Set("Authorization", "Bearer "+key.Value) diff --git a/core/providers/gemini.go b/core/providers/gemini.go index fb0463782b..ae74b6b9d7 100644 --- a/core/providers/gemini.go +++ b/core/providers/gemini.go @@ -71,6 +71,61 @@ func (provider *GeminiProvider) GetProviderKey() schemas.ModelProvider { return getProviderName(schemas.Gemini, provider.customProviderConfig) } +// ListModels performs a list models request to Gemini's API. 
+func (provider *GeminiProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ListModelsRequest); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + // Build URL using centralized URL construction + requestURL := gemini.ToGeminiListModelsURL(request, provider.networkConfig.BaseURL+"/models") + req.SetRequestURI(requestURL) + req.Header.SetMethod(http.MethodGet) + req.Header.SetContentType("application/json") + req.Header.Set("x-goog-api-key", key.Value) + + // Make request + latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + return nil, parseGeminiError(providerName, resp) + } + + // Parse Gemini's response + var geminiResponse gemini.GeminiListModelsResponse + rawResponse, bifrostErr := handleProviderResponse(resp.Body(), &geminiResponse, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + response := geminiResponse.ToBifrostListModelsResponse(providerName) + + response.ExtraFields.Provider = providerName + response.ExtraFields.RequestType = schemas.ListModelsRequest + response.ExtraFields.Latency = latency.Milliseconds() + + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + return response, nil +} + // TextCompletion is not supported by the Gemini provider. func (provider *GeminiProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", string(provider.GetProviderKey())) @@ -113,7 +168,7 @@ func (provider *GeminiProvider) ChatCompletion(ctx context.Context, key schemas. 
setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) req.SetRequestURI(provider.networkConfig.BaseURL + "/openai/chat/completions") - req.Header.SetMethod("POST") + req.Header.SetMethod(http.MethodPost) req.Header.SetContentType("application/json") req.Header.Set("Authorization", "Bearer "+key.Value) @@ -319,7 +374,7 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner } // Create HTTP request for streaming - req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/models/"+request.Model+":streamGenerateContent?alt=sse", bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.networkConfig.BaseURL+"/models/"+request.Model+":streamGenerateContent?alt=sse", bytes.NewReader(jsonBody)) if err != nil { if errors.Is(err, context.Canceled) { return nil, &schemas.BifrostError{ @@ -597,7 +652,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo } // Create HTTP request for streaming - req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/models/"+request.Model+":streamGenerateContent?alt=sse", bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.networkConfig.BaseURL+"/models/"+request.Model+":streamGenerateContent?alt=sse", bytes.NewReader(jsonBody)) if err != nil { if errors.Is(err, context.Canceled) { return nil, &schemas.BifrostError{ @@ -856,7 +911,7 @@ func (provider *GeminiProvider) completeRequest(ctx context.Context, model strin // Use Gemini's generateContent endpoint req.SetRequestURI(provider.networkConfig.BaseURL + "/models/" + model + endpoint) - req.Header.SetMethod("POST") + req.Header.SetMethod(http.MethodPost) req.Header.SetContentType("application/json") req.Header.Set("x-goog-api-key", key.Value) diff --git a/core/providers/groq.go b/core/providers/groq.go index 368fa1fbe4..a11e8fcc21 100644 --- a/core/providers/groq.go +++ b/core/providers/groq.go @@ -66,6 +66,21 @@ func (provider *GroqProvider) GetProviderKey() schemas.ModelProvider { return schemas.Groq } +// ListModels performs a list models request to Groq's API. +func (provider *GroqProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return handleOpenAIListModelsRequest( + ctx, + provider.client, + request, + provider.networkConfig.BaseURL+"/v1/models", + key, + provider.networkConfig.ExtraHeaders, + schemas.Groq, + provider.sendBackRawResponse, + provider.logger, + ) +} + // TextCompletion is not supported by the Groq provider. func (provider *GroqProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { // Checking if litellm fallback is set diff --git a/core/providers/mistral.go b/core/providers/mistral.go index b78946abca..3cb9d9c397 100644 --- a/core/providers/mistral.go +++ b/core/providers/mistral.go @@ -9,6 +9,7 @@ import ( "time" schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/providers/mistral" "github.com/valyala/fasthttp" ) @@ -66,6 +67,58 @@ func (provider *MistralProvider) GetProviderKey() schemas.ModelProvider { return schemas.Mistral } +// ListModels performs a list models request to Mistral's API. 
+func (provider *MistralProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/models") + req.Header.SetMethod(http.MethodGet) + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + // Make request + latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + return nil, parseOpenAIError(resp) + } + + // Parse Mistral's response + var mistralResponse mistral.MistralListModelsResponse + rawResponse, bifrostErr := handleProviderResponse(resp.Body(), &mistralResponse, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + response := mistralResponse.ToBifrostListModelsResponse() + + response = response.ApplyPagination(request.PageSize, request.PageToken) + + // Set ExtraFields + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.RequestType = schemas.ListModelsRequest + response.ExtraFields.Latency = latency.Milliseconds() + + // Set raw response if enabled + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + return response, nil +} + // TextCompletion is not supported by the Mistral provider. func (provider *MistralProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", "mistral") diff --git a/core/providers/ollama.go b/core/providers/ollama.go index 20c4457d3e..a6f653fe25 100644 --- a/core/providers/ollama.go +++ b/core/providers/ollama.go @@ -68,6 +68,11 @@ func (provider *OllamaProvider) GetProviderKey() schemas.ModelProvider { return schemas.Ollama } +// ListModels performs a list models request to Ollama's API. +func (provider *OllamaProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return handleOpenAIListModelsRequest(ctx, provider.client, request, provider.networkConfig.BaseURL+"/v1/models", key, provider.networkConfig.ExtraHeaders, provider.GetProviderKey(), provider.sendBackRawResponse, provider.logger) +} + // TextCompletion performs a text completion request to the Ollama API. 
func (provider *OllamaProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return handleOpenAITextCompletionRequest( @@ -195,3 +200,4 @@ func (provider *OllamaProvider) Transcription(ctx context.Context, key schemas.K func (provider *OllamaProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription stream", "ollama") } + diff --git a/core/providers/openai.go b/core/providers/openai.go index 4d621ed2d3..d582b86073 100644 --- a/core/providers/openai.go +++ b/core/providers/openai.go @@ -77,6 +77,82 @@ func (provider *OpenAIProvider) GetProviderKey() schemas.ModelProvider { return getProviderName(schemas.OpenAI, provider.customProviderConfig) } +func (provider *OpenAIProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ListModelsRequest); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + return handleOpenAIListModelsRequest(ctx, provider.client, request, provider.networkConfig.BaseURL+"/v1/models", key, provider.networkConfig.ExtraHeaders, providerName, provider.sendBackRawResponse, provider.logger) + +} + +func handleOpenAIListModelsRequest( + ctx context.Context, + client *fasthttp.Client, + request *schemas.BifrostListModelsRequest, + url string, + key schemas.Key, + extraHeaders map[string]string, + providerName schemas.ModelProvider, + sendBackRawResponse bool, + logger schemas.Logger, +) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, extraHeaders, nil) + + req.SetRequestURI(url) + req.Header.SetMethod(http.MethodGet) + req.Header.SetContentType("application/json") + + if key.Value != "" { + req.Header.Set("Authorization", "Bearer "+key.Value) + } + // Make request + latency, bifrostErr := makeRequestWithContext(ctx, client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + return nil, parseOpenAIError(resp) + } + + responseBody := resp.Body() + + openaiResponse := &openai.OpenAIListModelsResponse{} + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, openaiResponse, sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + response := openaiResponse.ToBifrostListModelsResponse(providerName) + + response = response.ApplyPagination(request.PageSize, request.PageToken) + + // Set raw response if enabled + if sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + response.ExtraFields.Provider = providerName + response.ExtraFields.RequestType = schemas.ListModelsRequest + response.ExtraFields.Latency = latency.Milliseconds() + + return response, nil +} + // TextCompletion is not supported 
by the OpenAI provider. // Returns an error indicating that text completion is not available. func (provider *OpenAIProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { @@ -126,7 +202,7 @@ func handleOpenAITextCompletionRequest( setExtraHeaders(req, extraHeaders, nil) req.SetRequestURI(url) - req.Header.SetMethod("POST") + req.Header.SetMethod(http.MethodPost) req.Header.SetContentType("application/json") if key.Value != "" { @@ -235,7 +311,7 @@ func handleOpenAITextCompletionStreaming( } // Create HTTP request for streaming - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody)) if err != nil { if errors.Is(err, context.Canceled) { return nil, &schemas.BifrostError{ @@ -482,7 +558,7 @@ func handleOpenAIChatCompletionRequest( setExtraHeaders(req, extraHeaders, nil) req.SetRequestURI(url) - req.Header.SetMethod("POST") + req.Header.SetMethod(http.MethodPost) req.Header.SetContentType("application/json") if key.Value != "" { @@ -594,7 +670,7 @@ func handleOpenAIChatCompletionStreaming( } // Create HTTP request for streaming - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody)) if err != nil { if errors.Is(err, context.Canceled) { return nil, &schemas.BifrostError{ @@ -844,7 +920,7 @@ func handleOpenAIResponsesRequest( setExtraHeaders(req, extraHeaders, nil) req.SetRequestURI(url) - req.Header.SetMethod("POST") + req.Header.SetMethod(http.MethodPost) req.Header.SetContentType("application/json") if key.Value != "" { @@ -951,7 +1027,7 @@ func handleOpenAIResponsesStreaming( } // Create HTTP request for streaming - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody)) if err != nil { if errors.Is(err, context.Canceled) { return nil, &schemas.BifrostError{ @@ -1167,7 +1243,7 @@ func handleOpenAIEmbeddingRequest( setExtraHeaders(req, extraHeaders, nil) req.SetRequestURI(url) - req.Header.SetMethod("POST") + req.Header.SetMethod(http.MethodPost) req.Header.SetContentType("application/json") if key.Value != "" { @@ -1242,7 +1318,7 @@ func (provider *OpenAIProvider) Speech(ctx context.Context, key schemas.Key, req setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/audio/speech") - req.Header.SetMethod("POST") + req.Header.SetMethod(http.MethodPost) req.Header.SetContentType("application/json") req.Header.Set("Authorization", "Bearer "+key.Value) @@ -1310,7 +1386,7 @@ func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner } // Create HTTP request for streaming - req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/v1/audio/speech", bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.networkConfig.BaseURL+"/v1/audio/speech", bytes.NewReader(jsonBody)) if err != nil { if errors.Is(err, context.Canceled) { return nil, &schemas.BifrostError{ @@ -1497,7 +1573,7 @@ func (provider *OpenAIProvider) Transcription(ctx context.Context, key schemas.K setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) 
req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/audio/transcriptions") - req.Header.SetMethod("POST") + req.Header.SetMethod(http.MethodPost) req.Header.SetContentType(writer.FormDataContentType()) // This sets multipart/form-data with boundary req.Header.Set("Authorization", "Bearer "+key.Value) @@ -1577,7 +1653,7 @@ func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHoo } // Create HTTP request for streaming - req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/v1/audio/transcriptions", &body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.networkConfig.BaseURL+"/v1/audio/transcriptions", &body) if err != nil { if errors.Is(err, context.Canceled) { return nil, &schemas.BifrostError{ diff --git a/core/providers/openrouter.go b/core/providers/openrouter.go index 57e09422f6..8dafe7cb9c 100644 --- a/core/providers/openrouter.go +++ b/core/providers/openrouter.go @@ -4,6 +4,7 @@ package providers import ( "context" + "fmt" "net/http" "strings" "time" @@ -61,6 +62,59 @@ func (provider *OpenRouterProvider) GetProviderKey() schemas.ModelProvider { return schemas.OpenRouter } +// ListModels performs a list models request to OpenRouter's API. +func (provider *OpenRouterProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/models") + req.Header.SetMethod(http.MethodGet) + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + // Make request + latency, bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", schemas.OpenRouter, string(resp.Body()))) + return nil, parseOpenAIError(resp) + } + + var openrouterResponse schemas.BifrostListModelsResponse + rawResponse, bifrostErr := handleProviderResponse(resp.Body(), &openrouterResponse, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + for i := range openrouterResponse.Data { + openrouterResponse.Data[i].ID = string(schemas.OpenRouter) + "/" + openrouterResponse.Data[i].ID + } + + response := openrouterResponse.ApplyPagination(request.PageSize, request.PageToken) + + // Set ExtraFields + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.RequestType = schemas.ListModelsRequest + response.ExtraFields.Latency = latency.Milliseconds() + + // Set raw response if enabled + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + return response, nil +} + // TextCompletion performs a text completion request to the OpenRouter API. 
func (provider *OpenRouterProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return handleOpenAITextCompletionRequest( diff --git a/core/providers/parasail.go b/core/providers/parasail.go index 632f618784..ea201e2721 100644 --- a/core/providers/parasail.go +++ b/core/providers/parasail.go @@ -66,6 +66,21 @@ func (provider *ParasailProvider) GetProviderKey() schemas.ModelProvider { return schemas.Parasail } +// ListModels performs a list models request to Parasail's API. +func (provider *ParasailProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return handleOpenAIListModelsRequest( + ctx, + provider.client, + request, + provider.networkConfig.BaseURL+"/v1/models", + key, + provider.networkConfig.ExtraHeaders, + schemas.Parasail, + provider.sendBackRawResponse, + provider.logger, + ) +} + // TextCompletion is not supported by the Parasail provider. func (provider *ParasailProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", "parasail") diff --git a/core/providers/sgl.go b/core/providers/sgl.go index c2aec4268b..c7d7f99350 100644 --- a/core/providers/sgl.go +++ b/core/providers/sgl.go @@ -68,6 +68,21 @@ func (provider *SGLProvider) GetProviderKey() schemas.ModelProvider { return schemas.SGL } +// ListModels performs a list models request to SGL's API. +func (provider *SGLProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return handleOpenAIListModelsRequest( + ctx, + provider.client, + request, + provider.networkConfig.BaseURL+"/v1/models", + key, + provider.networkConfig.ExtraHeaders, + schemas.SGL, + provider.sendBackRawResponse, + provider.logger, + ) +} + // TextCompletion is not supported by the SGL provider. func (provider *SGLProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return handleOpenAITextCompletionRequest( diff --git a/core/providers/vertex.go b/core/providers/vertex.go index 0edb91b175..5c30cbddeb 100644 --- a/core/providers/vertex.go +++ b/core/providers/vertex.go @@ -132,6 +132,156 @@ func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider { return schemas.Vertex } +// ListModels performs a list models request to Vertex's API. 
+func (provider *VertexProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + providerName := provider.GetProviderKey() + + if key.VertexKeyConfig == nil { + return nil, newConfigurationError("vertex key config is not set", providerName) + } + + projectID := key.VertexKeyConfig.ProjectID + if projectID == "" { + return nil, newConfigurationError("project ID is not set", providerName) + } + + region := key.VertexKeyConfig.Region + if region == "" { + return nil, newConfigurationError("region is not set in key config", providerName) + } + + // Build URL using centralized URL construction + requestURL := vertex.ToVertexListModelsURL(request, fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/models", region, projectID, region)) + + // Create request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Type: schemas.Ptr(schemas.RequestCancelled), + Message: schemas.ErrRequestCancelled, + Error: err, + }, + } + } + if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) { + return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, providerName) + } + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: schemas.ErrProviderRequest, + Error: err, + }, + } + } + + // Set any extra headers from network config + setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil) + + req.Header.Set("Content-Type", "application/json") + + client, err := getAuthClient(key) + if err != nil { + // Remove client from pool if auth client creation fails + removeVertexClient(key.VertexKeyConfig.AuthCredentials) + return nil, newBifrostOperationError("error creating auth client", err, providerName) + } + + // Make request and measure latency + startTime := time.Now() + resp, err := client.Do(req) + latency := time.Since(startTime) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Type: schemas.Ptr(schemas.RequestCancelled), + Message: schemas.ErrRequestCancelled, + Error: err, + }, + } + } + if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) { + return nil, newBifrostOperationError(schemas.ErrProviderRequestTimedOut, err, providerName) + } + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: schemas.ErrProviderRequest, + Error: err, + }, + } + } + defer resp.Body.Close() + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: &schemas.ErrorField{ + Message: "error reading response", + Error: err, + }, + } + } + + // Handle error response + if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + removeVertexClient(key.VertexKeyConfig.AuthCredentials) + } + provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(body))) + + var errorResp VertexError + + if err := sonic.Unmarshal(body, &errorResp); err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: &resp.StatusCode, + Error: 
&schemas.ErrorField{ + Message: schemas.ErrProviderResponseUnmarshal, + Error: err, + }, + } + } + + return nil, &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: &resp.StatusCode, + Error: &schemas.ErrorField{ + Message: errorResp.Error.Message, + }, + } + } + + // Parse Vertex's response + var vertexResponse vertex.VertexListModelsResponse + rawResponse, bifrostErr := handleProviderResponse(body, &vertexResponse, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + response := vertexResponse.ToBifrostListModelsResponse() + + // Set ExtraFields + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.RequestType = schemas.ListModelsRequest + response.ExtraFields.Latency = latency.Milliseconds() + + // Set raw response if enabled + if provider.sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + return response, nil +} + // TextCompletion is not supported by the Vertex provider. // Returns an error indicating that text completion is not available. func (provider *VertexProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { @@ -219,7 +369,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas. } // Create request - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody)) if err != nil { if errors.Is(err, context.Canceled) { return nil, &schemas.BifrostError{ @@ -514,18 +664,25 @@ func (provider *VertexProvider) Embedding(ctx context.Context, key schemas.Key, return nil, newConfigurationError("embedding input texts are empty", schemas.Vertex) } + // All Vertex AI embedding models use the same native Vertex embedding API + return provider.handleVertexEmbedding(ctx, request.Model, key, reqBody) +} + +// handleVertexEmbedding handles embedding requests using Vertex's native embedding API +// This is used for all Vertex AI embedding models as they all use the same response format +func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model string, key schemas.Key, vertexReq *vertex.VertexEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { // Use the typed request directly - jsonBody, err := sonic.Marshal(reqBody) + jsonBody, err := sonic.Marshal(vertexReq) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Vertex) } // Build the native Vertex embedding API endpoint url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", - key.VertexKeyConfig.Region, key.VertexKeyConfig.ProjectID, key.VertexKeyConfig.Region, request.Model) + key.VertexKeyConfig.Region, key.VertexKeyConfig.ProjectID, key.VertexKeyConfig.Region, model) // Create request - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody)) if err != nil { if errors.Is(err, context.Canceled) { return nil, &schemas.BifrostError{ @@ -625,7 +782,7 @@ func (provider *VertexProvider) Embedding(ctx context.Context, key schemas.Key, // Set ExtraFields bifrostResponse.ExtraFields.Provider = schemas.Vertex - bifrostResponse.ExtraFields.ModelRequested = request.Model + 
bifrostResponse.ExtraFields.ModelRequested = model bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() diff --git a/core/schemas/account.go b/core/schemas/account.go index 9762beb940..dee76c12c2 100644 --- a/core/schemas/account.go +++ b/core/schemas/account.go @@ -21,7 +21,7 @@ type Key struct { type AzureKeyConfig struct { Endpoint string `json:"endpoint"` // Azure service endpoint URL Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model names to deployment names - APIVersion *string `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-08-01-preview" + APIVersion *string `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-10-21" } // VertexKeyConfig represents the Vertex-specific configuration. diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index e639a01827..eded6a4e40 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -79,6 +79,7 @@ var StandardProviders = []ModelProvider{ type RequestType string const ( + ListModelsRequest RequestType = "list_models" TextCompletionRequest RequestType = "text_completion" TextCompletionStreamRequest RequestType = "text_completion_stream" ChatCompletionRequest RequestType = "chat_completion" @@ -118,6 +119,7 @@ type Fallback struct { // BifrostRequest is the request struct for all bifrost requests. // only ONE of the following fields should be set: +// - ListModelsRequest // - TextCompletionRequest // - ChatRequest // - ResponsesRequest @@ -128,6 +130,7 @@ type Fallback struct { type BifrostRequest struct { RequestType RequestType + ListModelsRequest *BifrostListModelsRequest TextCompletionRequest *BifrostTextCompletionRequest ChatRequest *BifrostChatRequest ResponsesRequest *BifrostResponsesRequest @@ -250,8 +253,8 @@ func (r *BifrostResponse) GetExtraFields() *BifrostResponseExtraFields { // BifrostResponseExtraFields contains additional fields in a response. 
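// Provider and ModelRequested are omitempty below so that aggregated responses (for example ListAllModels output, which spans every configured provider) can leave the per-provider fields unset.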
type BifrostResponseExtraFields struct { RequestType RequestType `json:"request_type"` - Provider ModelProvider `json:"provider"` - ModelRequested string `json:"model_requested"` + Provider ModelProvider `json:"provider,omitempty"` + ModelRequested string `json:"model_requested,omitempty"` Latency int64 `json:"latency"` // in milliseconds (for streaming responses this will be each chunk latency, and the last chunk latency will be the total latency) ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses RawResponse interface{} `json:"raw_response,omitempty"` diff --git a/core/schemas/models.go b/core/schemas/models.go new file mode 100644 index 0000000000..9e4a718a3c --- /dev/null +++ b/core/schemas/models.go @@ -0,0 +1,226 @@ +package schemas + +import ( + "encoding/base64" + "fmt" + + "github.com/bytedance/sonic" +) + +// DefaultPageSize is the default page size for listing models +const DefaultPageSize = 1000 +// MaxPaginationRequests is the maximum number of pagination requests to make +const MaxPaginationRequests = 20 + +type BifrostListModelsRequest struct { + Provider ModelProvider `json:"provider"` + + PageSize int `json:"page_size"` + + // PageToken: Token received from previous request to retrieve next page + PageToken string `json:"page_token"` + + // ExtraParams: Additional provider-specific query parameters + // This allows for flexibility to pass any custom parameters that specific providers might support + ExtraParams map[string]interface{} `json:"-"` +} + +type BifrostListModelsResponse struct { + Data []Model `json:"data"` + ExtraFields BifrostResponseExtraFields `json:"extra_fields"` + NextPageToken string `json:"next_page_token,omitempty"` // Token to retrieve next page + + // Anthropic specific fields + FirstID *string `json:"-"` + LastID *string `json:"-"` + HasMore *bool `json:"-"` +} + +type Model struct { + ID string `json:"id"` + CanonicalSlug *string `json:"canonical_slug,omitempty"` + Name *string `json:"name,omitempty"` + Created *int64 `json:"created,omitempty"` + ContextLength *int `json:"context_length,omitempty"` + Architecture *Architecture `json:"architecture,omitempty"` + Pricing *Pricing `json:"pricing,omitempty"` + TopProvider *TopProvider `json:"top_provider,omitempty"` + PerRequestLimits *PerRequestLimits `json:"per_request_limits,omitempty"` + SupportedParameters []string `json:"supported_parameters,omitempty"` + DefaultParameters *DefaultParameters `json:"default_parameters,omitempty"` + HuggingFaceID *string `json:"hugging_face_id,omitempty"` + Description *string `json:"description,omitempty"` + + OwnedBy *string `json:"owned_by,omitempty"` + SupportedMethods []string `json:"supported_methods,omitempty"` +} + +type Architecture struct { + Modality *string `json:"modality,omitempty"` + Tokenizer *string `json:"tokenizer,omitempty"` + InstructType *string `json:"instruct_type,omitempty"` + InputModalities []string `json:"input_modalities,omitempty"` + OutputModalities []string `json:"output_modalities,omitempty"` +} + +type Pricing struct { + Prompt *string `json:"prompt,omitempty"` + Completion *string `json:"completion,omitempty"` + Request *string `json:"request,omitempty"` + Image *string `json:"image,omitempty"` + WebSearch *string `json:"web_search,omitempty"` + InternalReasoning *string `json:"internal_reasoning,omitempty"` + InputCacheRead *string `json:"input_cache_read,omitempty"` + InputCacheWrite *string `json:"input_cache_write,omitempty"` +} + +type 
TopProvider struct { + IsModerated *bool `json:"is_moderated,omitempty"` + ContextLength *int `json:"context_length,omitempty"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` +} + +type PerRequestLimits struct { + PromptTokens *int `json:"prompt_tokens,omitempty"` + CompletionTokens *int `json:"completion_tokens,omitempty"` +} + +type DefaultParameters struct { + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` +} + +// paginationCursor represents the internal cursor structure for pagination. +type paginationCursor struct { + Offset int `json:"o"` + LastID string `json:"l,omitempty"` +} + +// encodePaginationCursor creates an opaque base64-encoded page token from cursor data. +// Returns empty string if offset is 0 or negative. +func encodePaginationCursor(offset int, lastID string) (string, error) { + if offset <= 0 { + return "", nil + } + + cursor := paginationCursor{ + Offset: offset, + LastID: lastID, + } + + jsonData, err := sonic.Marshal(cursor) + if err != nil { + return "", fmt.Errorf("failed to marshal pagination cursor: %w", err) + } + + // Use URL-safe base64 encoding without padding for opaque token + encoded := base64.RawURLEncoding.EncodeToString(jsonData) + return encoded, nil +} + +// decodePaginationCursor extracts cursor data from an opaque base64-encoded page token. +// Returns cursor with 0 offset for empty or invalid tokens. +func decodePaginationCursor(token string) paginationCursor { + if token == "" { + return paginationCursor{} + } + + // Decode base64 + decoded, err := base64.RawURLEncoding.DecodeString(token) + if err != nil { + return paginationCursor{} + } + + var cursor paginationCursor + if err := sonic.Unmarshal(decoded, &cursor); err != nil { + return paginationCursor{} + } + + if cursor.Offset < 0 { + return paginationCursor{} + } + + return cursor +} + +// validatePaginationCursor validates that the cursor matches the expected position in the data. +// Returns true if the cursor is valid, false otherwise. +func validatePaginationCursor(cursor paginationCursor, data []Model) bool { + if cursor.LastID == "" { + return true + } + + if cursor.Offset <= 0 || cursor.Offset > len(data) { + return false + } + + prevIndex := cursor.Offset - 1 + if prevIndex >= 0 && prevIndex < len(data) { + return data[prevIndex].ID == cursor.LastID + } + + return true +} + +// ApplyPagination applies offset-based pagination to a BifrostListModelsResponse. +// Uses opaque tokens with LastID validation to ensure cursor integrity. +// Returns the paginated response with properly set NextPageToken. 
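+//
+// A minimal usage sketch (fetchCatalog is hypothetical): a provider without native paging fetches its full catalog once, then pages it locally:
+//
+//	full := fetchCatalog()                                  // *BifrostListModelsResponse holding every model
+//	page1 := full.ApplyPagination(100, "")                  // first 100 models; NextPageToken set when more remain
+//	page2 := full.ApplyPagination(100, page1.NextPageToken) // models 101-200
+//
+// Tokens are opaque base64 cursors: an empty, stale, or tampered token decodes to offset 0, so pagination restarts from the first page rather than returning an error.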
+func (response *BifrostListModelsResponse) ApplyPagination(pageSize int, pageToken string) *BifrostListModelsResponse { + if response == nil { + return nil + } + + totalItems := len(response.Data) + + if pageSize <= 0 { + return response + } + + cursor := decodePaginationCursor(pageToken) + offset := cursor.Offset + + // Validate cursor integrity if LastID is present + if cursor.LastID != "" && !validatePaginationCursor(cursor, response.Data) { + // Invalid cursor: reset to beginning + offset = 0 + } + + if offset >= totalItems { + // Return empty page, no next token + return &BifrostListModelsResponse{ + Data: []Model{}, + ExtraFields: response.ExtraFields, + NextPageToken: "", + } + } + + endIndex := offset + pageSize + if endIndex > totalItems { + endIndex = totalItems + } + + paginatedData := response.Data[offset:endIndex] + + paginatedResponse := &BifrostListModelsResponse{ + Data: paginatedData, + ExtraFields: response.ExtraFields, + } + + if endIndex < totalItems { + // Get the last item ID for cursor validation + var lastID string + if len(paginatedData) > 0 { + lastID = paginatedData[len(paginatedData)-1].ID + } + + nextToken, err := encodePaginationCursor(endIndex, lastID) + if err == nil { + paginatedResponse.NextPageToken = nextToken + } + } else { + paginatedResponse.NextPageToken = "" + } + + return paginatedResponse +} diff --git a/core/schemas/provider.go b/core/schemas/provider.go index d458b4eadb..c289caa4ec 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -87,6 +87,7 @@ type ProxyConfig struct { // A nil *AllowedRequests means "all operations allowed." // A non-nil value only allows fields set to true; omitted or false fields are disallowed. type AllowedRequests struct { + ListModels bool `json:"list_models"` TextCompletion bool `json:"text_completion"` TextCompletionStream bool `json:"text_completion_stream"` ChatCompletion bool `json:"chat_completion"` @@ -105,6 +106,8 @@ func (ar *AllowedRequests) IsOperationAllowed(operation RequestType) bool { } switch operation { + case ListModelsRequest: + return ar.ListModels case TextCompletionRequest: return ar.TextCompletion case TextCompletionStreamRequest: @@ -194,6 +197,8 @@ type PostHookRunner func(ctx *context.Context, result *BifrostResponse, err *Bif type Provider interface { // GetProviderKey returns the provider's identifier GetProviderKey() ModelProvider + // ListModels performs a list models request + ListModels(ctx context.Context, key Key, request *BifrostListModelsRequest) (*BifrostListModelsResponse, *BifrostError) // TextCompletion performs a text completion request TextCompletion(ctx context.Context, key Key, request *BifrostTextCompletionRequest) (*BifrostTextCompletionResponse, *BifrostError) // TextCompletionStream performs a text completion stream request diff --git a/core/schemas/providers/anthropic/models.go b/core/schemas/providers/anthropic/models.go new file mode 100644 index 0000000000..b9f6e839b8 --- /dev/null +++ b/core/schemas/providers/anthropic/models.go @@ -0,0 +1,103 @@ +package anthropic + +import ( + "net/url" + "strconv" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +func ToAnthropicListModelsURL(request *schemas.BifrostListModelsRequest, baseURL string) string { + // Add limit parameter (default to 1000) + pageSize := request.PageSize + if pageSize <= 0 { + pageSize = schemas.DefaultPageSize + } + + // Build query parameters + params := url.Values{} + params.Set("limit", strconv.Itoa(pageSize)) + + // Add cursor-based pagination parameters + if 
request.ExtraParams != nil { + // before_id for backward pagination + if beforeID, ok := request.ExtraParams["before_id"].(string); ok && beforeID != "" { + params.Set("before_id", beforeID) + } + // after_id for forward pagination + if afterID, ok := request.ExtraParams["after_id"].(string); ok && afterID != "" { + params.Set("after_id", afterID) + } + } + // Use page_token as after_id if not explicitly provided in ExtraParams + if request.PageToken != "" { + if request.ExtraParams == nil { + params.Set("after_id", request.PageToken) + } else if _, hasAfterID := request.ExtraParams["after_id"]; !hasAfterID { + params.Set("after_id", request.PageToken) + } + } + + return baseURL + "?" + params.Encode() +} + +func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider) *schemas.BifrostListModelsResponse { + if response == nil { + return nil + } + + bifrostResponse := &schemas.BifrostListModelsResponse{ + Data: make([]schemas.Model, 0, len(response.Data)), + FirstID: response.FirstID, + LastID: response.LastID, + HasMore: schemas.Ptr(response.HasMore), + } + + // Map Anthropic's cursor-based pagination to Bifrost's token-based pagination + // If there are more results, set next_page_token to last_id so it can be used in the next request + if response.HasMore && response.LastID != nil { + bifrostResponse.NextPageToken = *response.LastID + } + + for _, model := range response.Data { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ + ID: string(providerKey) + "/" + model.ID, + Name: schemas.Ptr(model.DisplayName), + Created: schemas.Ptr(model.CreatedAt.Unix()), + }) + } + + return bifrostResponse +} + +func ToAnthropicListModelsResponse(response *schemas.BifrostListModelsResponse) *AnthropicListModelsResponse { + if response == nil { + return nil + } + + anthropicResponse := &AnthropicListModelsResponse{ + Data: make([]AnthropicModel, 0, len(response.Data)), + } + if response.FirstID != nil { + anthropicResponse.FirstID = response.FirstID + } + if response.LastID != nil { + anthropicResponse.LastID = response.LastID + } + + for _, model := range response.Data { + anthropicModel := AnthropicModel{ + ID: model.ID, + } + if model.Name != nil { + anthropicModel.DisplayName = *model.Name + } + if model.Created != nil { + anthropicModel.CreatedAt = time.Unix(*model.Created, 0) + } + anthropicResponse.Data = append(anthropicResponse.Data, anthropicModel) + } + + return anthropicResponse +} diff --git a/core/schemas/providers/anthropic/types.go b/core/schemas/providers/anthropic/types.go index 8d2102d03c..378d5c56d6 100644 --- a/core/schemas/providers/anthropic/types.go +++ b/core/schemas/providers/anthropic/types.go @@ -3,6 +3,7 @@ package anthropic import ( "encoding/json" "fmt" + "time" "github.com/maximhq/bifrost/core/schemas" ) @@ -268,6 +269,22 @@ type AnthropicStreamDelta struct { StopSequence *string `json:"stop_sequence,omitempty"` } +// ==================== MODEL TYPES ==================== + +type AnthropicModel struct { + ID string `json:"id"` + DisplayName string `json:"display_name"` + CreatedAt time.Time `json:"created_at"` + Type string `json:"type"` +} + +type AnthropicListModelsResponse struct { + Data []AnthropicModel `json:"data"` + FirstID *string `json:"first_id,omitempty"` + HasMore bool `json:"has_more"` + LastID *string `json:"last_id,omitempty"` +} + // ==================== ERROR TYPES ==================== // AnthropicMessageError represents an Anthropic messages API error response diff --git 
a/core/schemas/providers/azure/models.go b/core/schemas/providers/azure/models.go new file mode 100644 index 0000000000..aa4eeca2fa --- /dev/null +++ b/core/schemas/providers/azure/models.go @@ -0,0 +1,22 @@ +package azure + +import "github.com/maximhq/bifrost/core/schemas" + +func (response *AzureListModelsResponse) ToBifrostListModelsResponse() *schemas.BifrostListModelsResponse { + if response == nil { + return nil + } + + bifrostResponse := &schemas.BifrostListModelsResponse{ + Data: make([]schemas.Model, 0, len(response.Data)), + } + + for _, model := range response.Data { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ + ID: string(schemas.Azure) + "/" + model.ID, + Created: schemas.Ptr(model.CreatedAt), + Name: schemas.Ptr(model.Model), + }) + } + return bifrostResponse +} diff --git a/core/schemas/providers/azure/types.go b/core/schemas/providers/azure/types.go new file mode 100644 index 0000000000..045b22627e --- /dev/null +++ b/core/schemas/providers/azure/types.go @@ -0,0 +1,34 @@ +package azure + +// DefaultAzureAPIVersion is the default Azure OpenAI API version to use when not specified. +const DefaultAzureAPIVersion = "2024-10-21" + +type AzureModelCapabilities struct { + FineTune bool `json:"fine_tune"` + Inference bool `json:"inference"` + Completion bool `json:"completion"` + ChatCompletion bool `json:"chat_completion"` + Embeddings bool `json:"embeddings"` +} + +type AzureModelDeprecation struct { + FineTune int64 `json:"fine_tune,omitempty"` + Inference int64 `json:"inference,omitempty"` +} + +type AzureModel struct { + Status string `json:"status"` + Model string `json:"model,omitempty"` + FineTune string `json:"fine_tune,omitempty"` + Capabilities AzureModelCapabilities `json:"capabilities,omitempty"` + LifecycleStatus string `json:"lifecycle_status"` + Deprecation *AzureModelDeprecation `json:"deprecation,omitempty"` + ID string `json:"id"` + CreatedAt int64 `json:"created_at"` + Object string `json:"object"` +} + +type AzureListModelsResponse struct { + Object string `json:"object"` + Data []AzureModel `json:"data"` +} diff --git a/core/schemas/providers/bedrock/models.go b/core/schemas/providers/bedrock/models.go new file mode 100644 index 0000000000..8e82a81dee --- /dev/null +++ b/core/schemas/providers/bedrock/models.go @@ -0,0 +1,27 @@ +package bedrock + +import "github.com/maximhq/bifrost/core/schemas" + +func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider) *schemas.BifrostListModelsResponse { + if response == nil { + return nil + } + + bifrostResponse := &schemas.BifrostListModelsResponse{ + Data: make([]schemas.Model, 0, len(response.ModelSummaries)), + } + + for _, model := range response.ModelSummaries { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ + ID: string(providerKey) + "/" + model.ModelID, + Name: schemas.Ptr(model.ModelName), + OwnedBy: schemas.Ptr(model.ProviderName), + Architecture: &schemas.Architecture{ + InputModalities: model.InputModalities, + OutputModalities: model.OutputModalities, + }, + }) + } + + return bifrostResponse +} diff --git a/core/schemas/providers/bedrock/types.go b/core/schemas/providers/bedrock/types.go index b6958d6ea9..0b4243394e 100644 --- a/core/schemas/providers/bedrock/types.go +++ b/core/schemas/providers/bedrock/types.go @@ -1,5 +1,8 @@ package bedrock +// DefaultBedrockRegion is the default region for Bedrock +const DefaultBedrockRegion = "us-east-1" + // ==================== REQUEST TYPES ==================== // 
BedrockTextCompletionRequest represents a Bedrock text completion request @@ -430,3 +433,26 @@ type BedrockTitanEmbeddingResponse struct { Embedding []float32 `json:"embedding"` // The embedding vector InputTextTokenCount int `json:"inputTextTokenCount"` // Number of tokens in input } + +// ==================== MODELS TYPES ==================== +type BedrockModelLifecycle struct { + Status string `json:"status"` +} + +type BedrockModel struct { + CustomizationsSupported []string `json:"customizationsSupported,omitempty"` + InferenceTypesSupported []string `json:"inferenceTypesSupported,omitempty"` + InputModalities []string `json:"inputModalities,omitempty"` + ModelArn string `json:"modelArn"` + ModelID string `json:"modelId"` + ModelLifecycle BedrockModelLifecycle `json:"modelLifecycle,omitempty"` + ModelName string `json:"modelName"` + OutputModalities []string `json:"outputModalities,omitempty"` + ProviderName string `json:"providerName"` + ResponseStreamingSupported bool `json:"responseStreamingSupported"` +} + +// BedrockListModelsResponse represents the response from AWS Bedrock's ListFoundationModels API +type BedrockListModelsResponse struct { + ModelSummaries []BedrockModel `json:"modelSummaries"` +} diff --git a/core/schemas/providers/cohere/models.go b/core/schemas/providers/cohere/models.go new file mode 100644 index 0000000000..4c13ec197c --- /dev/null +++ b/core/schemas/providers/cohere/models.go @@ -0,0 +1,56 @@ +package cohere + +import ( + "net/url" + "strconv" + + schemas "github.com/maximhq/bifrost/core/schemas" +) + +func ToCohereListModelsURL(request *schemas.BifrostListModelsRequest, baseURL string) string { + pageSize := request.PageSize + if pageSize <= 0 { + pageSize = schemas.DefaultPageSize + } + + // Build query parameters + params := url.Values{} + params.Set("page_size", strconv.Itoa(pageSize)) + + if request.PageToken != "" { + params.Set("page_token", request.PageToken) + } + + if request.ExtraParams != nil { + if endpoint, ok := request.ExtraParams["endpoint"].(string); ok && endpoint != "" { + params.Set("endpoint", endpoint) + } + if defaultOnly, ok := request.ExtraParams["default_only"].(bool); ok && defaultOnly { + params.Set("default_only", "true") + } + } + + return baseURL + "?" 
+ params.Encode() +} + +func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider) *schemas.BifrostListModelsResponse { + if response == nil { + return nil + } + + bifrostResponse := &schemas.BifrostListModelsResponse{ + Data: make([]schemas.Model, 0, len(response.Models)), + NextPageToken: response.NextPageToken, + } + + for _, model := range response.Models { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ + ID: string(providerKey) + "/" + model.Name, + Name: schemas.Ptr(model.Name), + ContextLength: schemas.Ptr(int(model.ContextLength)), + SupportedMethods: model.Endpoints, + }) + } + + return bifrostResponse +} diff --git a/core/schemas/providers/cohere/types.go b/core/schemas/providers/cohere/types.go index ac735cc56c..7f1e3e9024 100644 --- a/core/schemas/providers/cohere/types.go +++ b/core/schemas/providers/cohere/types.go @@ -523,3 +523,21 @@ type CohereError struct { Message string `json:"message"` // Error message Code *string `json:"code,omitempty"` // Optional error code } + +// ==================== MODEL TYPES ==================== + +type CohereModel struct { + Name string `json:"name"` + IsDeprecated bool `json:"is_deprecated"` + Endpoints []string `json:"endpoints"` + Finetuned bool `json:"finetuned"` + ContextLength int `json:"context_length"` + TokenizerURL string `json:"tokenizer_url"` + DefaultEndpoints []string `json:"default_endpoints"` + Features []string `json:"features"` +} + +type CohereListModelsResponse struct { + Models []CohereModel `json:"models"` + NextPageToken string `json:"next_page_token"` +} diff --git a/core/schemas/providers/gemini/chat.go b/core/schemas/providers/gemini/chat.go index 42ee936e43..01149cb78f 100644 --- a/core/schemas/providers/gemini/chat.go +++ b/core/schemas/providers/gemini/chat.go @@ -439,7 +439,9 @@ func ToGeminiChatResponse(bifrostResp *schemas.BifrostChatResponse) *GenerateCon for _, toolCall := range choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls { argsMap := make(map[string]interface{}) if toolCall.Function.Arguments != "" { - json.Unmarshal([]byte(toolCall.Function.Arguments), &argsMap) + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &argsMap); err != nil { + argsMap = map[string]interface{}{} + } } if toolCall.Function.Name != nil { fc := &FunctionCall{ diff --git a/core/schemas/providers/gemini/models.go b/core/schemas/providers/gemini/models.go new file mode 100644 index 0000000000..0e4b824493 --- /dev/null +++ b/core/schemas/providers/gemini/models.go @@ -0,0 +1,81 @@ +package gemini + +import ( + "net/url" + "strconv" + "strings" + + "github.com/maximhq/bifrost/core/schemas" +) + +func ToGeminiListModelsURL(request *schemas.BifrostListModelsRequest, baseURL string) string { + // Add limit parameter (default to 1000) + pageSize := request.PageSize + if pageSize <= 0 { + pageSize = schemas.DefaultPageSize + } + + // Build query parameters + params := url.Values{} + params.Set("pageSize", strconv.Itoa(pageSize)) + + if request.PageToken != "" { + params.Set("pageToken", request.PageToken) + } + + return baseURL + "?" 
+ params.Encode() +} + +func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider) *schemas.BifrostListModelsResponse { + if response == nil { + return nil + } + + bifrostResponse := &schemas.BifrostListModelsResponse{ + Data: make([]schemas.Model, 0, len(response.Models)), + NextPageToken: response.NextPageToken, + } + + for _, model := range response.Models { + contextLength := model.InputTokenLimit + model.OutputTokenLimit + // Remove prefix models/ from model.Name + modelName := strings.TrimPrefix(model.Name, "models/") + bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ + ID: string(providerKey) + "/" + modelName, + Name: schemas.Ptr(model.DisplayName), + Description: schemas.Ptr(model.Description), + ContextLength: schemas.Ptr(int(contextLength)), + SupportedMethods: model.SupportedGenerationMethods, + }) + } + + return bifrostResponse +} + +func ToGeminiListModelsResponse(resp *schemas.BifrostListModelsResponse) *GeminiListModelsResponse { + if resp == nil { + return nil + } + + geminiResponse := &GeminiListModelsResponse{ + Models: make([]GeminiModel, 0, len(resp.Data)), + NextPageToken: resp.NextPageToken, + } + + for _, model := range resp.Data { + geminiModel := GeminiModel{ + Name: model.ID, + SupportedGenerationMethods: model.SupportedMethods, + } + if model.Name != nil { + geminiModel.DisplayName = *model.Name + } + if model.Description != nil { + geminiModel.Description = *model.Description + } + + geminiResponse.Models = append(geminiResponse.Models, geminiModel) + } + + return geminiResponse +} diff --git a/core/schemas/providers/gemini/types.go b/core/schemas/providers/gemini/types.go index 3ac9e0580f..a49fa0f7ab 100644 --- a/core/schemas/providers/gemini/types.go +++ b/core/schemas/providers/gemini/types.go @@ -1249,3 +1249,27 @@ type GeminiGenerationError struct { } `json:"details"` } `json:"error"` } + +// ==================== MODEL TYPES ==================== + +type GeminiModel struct { + Name string `json:"name"` + BaseModelID string `json:"baseModelId"` + Version string `json:"version"` + DisplayName string `json:"displayName"` + Description string `json:"description"` + InputTokenLimit int `json:"inputTokenLimit"` + OutputTokenLimit int `json:"outputTokenLimit"` + SupportedGenerationMethods []string `json:"supportedGenerationMethods"` + Thinking bool `json:"thinking"` + Temperature float64 `json:"temperature"` + MaxTemperature float64 `json:"maxTemperature"` + TopP float64 `json:"topP"` + TopK int `json:"topK"` +} + +// GeminiListModelsResponse represents the response from Google Gemini's list models API. 
+type GeminiListModelsResponse struct { + Models []GeminiModel `json:"models"` + NextPageToken string `json:"nextPageToken"` +} diff --git a/core/schemas/providers/mistral/models.go b/core/schemas/providers/mistral/models.go new file mode 100644 index 0000000000..181b9f2b79 --- /dev/null +++ b/core/schemas/providers/mistral/models.go @@ -0,0 +1,27 @@ +package mistral + +import "github.com/maximhq/bifrost/core/schemas" + +func (response *MistralListModelsResponse) ToBifrostListModelsResponse() *schemas.BifrostListModelsResponse { + if response == nil { + return nil + } + + bifrostResponse := &schemas.BifrostListModelsResponse{ + Data: make([]schemas.Model, 0, len(response.Data)), + } + + for _, model := range response.Data { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ + ID: string(schemas.Mistral) + "/" + model.ID, + Name: schemas.Ptr(model.Name), + Description: schemas.Ptr(model.Description), + Created: schemas.Ptr(model.Created), + ContextLength: schemas.Ptr(int(model.MaxContextLength)), + OwnedBy: schemas.Ptr(model.OwnedBy), + }) + + } + + return bifrostResponse +} diff --git a/core/schemas/providers/mistral/types.go b/core/schemas/providers/mistral/types.go new file mode 100644 index 0000000000..3a8431ab0c --- /dev/null +++ b/core/schemas/providers/mistral/types.go @@ -0,0 +1,34 @@ +package mistral + +// MistralModel represents a single model in the Mistral Models API response +type MistralModel struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` + Capabilities Capabilities `json:"capabilities"` + Name string `json:"name"` + Description string `json:"description"` + MaxContextLength int `json:"max_context_length"` + Aliases []string `json:"aliases"` + Deprecation *string `json:"deprecation,omitempty"` + DeprecationReplacementModel *string `json:"deprecation_replacement_model,omitempty"` + DefaultModelTemperature float64 `json:"default_model_temperature"` + Type string `json:"type"` +} + +// Capabilities describes the model's supported features +type Capabilities struct { + CompletionChat bool `json:"completion_chat"` + CompletionFim bool `json:"completion_fim"` + FunctionCalling bool `json:"function_calling"` + FineTuning bool `json:"fine_tuning"` + Vision bool `json:"vision"` + Classification bool `json:"classification"` +} + +// MistralListModelsResponse is the root response object from the Mistral Models API +type MistralListModelsResponse struct { + Object string `json:"object"` + Data []MistralModel `json:"data"` +} diff --git a/core/schemas/providers/openai/models.go b/core/schemas/providers/openai/models.go new file mode 100644 index 0000000000..7d815f35c0 --- /dev/null +++ b/core/schemas/providers/openai/models.go @@ -0,0 +1,54 @@ +package openai + +import "github.com/maximhq/bifrost/core/schemas" + +func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider) *schemas.BifrostListModelsResponse { + if response == nil { + return nil + } + + bifrostResponse := &schemas.BifrostListModelsResponse{ + Data: make([]schemas.Model, 0, len(response.Data)), + } + + for _, model := range response.Data { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ + ID: string(providerKey) + "/" + model.ID, + Created: model.Created, + OwnedBy: schemas.Ptr(model.OwnedBy), + ContextLength: model.ContextWindow, + }) + + } + + return bifrostResponse +} + +func ToOpenAIListModelsResponse(response *schemas.BifrostListModelsResponse) 
*OpenAIListModelsResponse { + + if response == nil { + return nil + } + + openaiResponse := &OpenAIListModelsResponse{ + Data: make([]OpenAIModel, 0, len(response.Data)), + } + + for _, model := range response.Data { + openaiModel := OpenAIModel{ + ID: model.ID, + Object: "model", + } + if model.Created != nil { + openaiModel.Created = model.Created + } + if model.OwnedBy != nil { + openaiModel.OwnedBy = *model.OwnedBy + } + + openaiResponse.Data = append(openaiResponse.Data, openaiModel) + + } + + return openaiResponse +} diff --git a/core/schemas/providers/openai/types.go b/core/schemas/providers/openai/types.go index 5fe1ba6474..e3c0440177 100644 --- a/core/schemas/providers/openai/types.go +++ b/core/schemas/providers/openai/types.go @@ -124,3 +124,20 @@ func (r *OpenAISpeechRequest) IsStreamingRequested() bool { func (r *OpenAITranscriptionRequest) IsStreamingRequested() bool { return r.Stream != nil && *r.Stream } + +// MODEL TYPES +type OpenAIModel struct { + ID string `json:"id"` + Object string `json:"object"` + OwnedBy string `json:"owned_by"` + Created *int64 `json:"created,omitempty"` + + // GROQ specific fields + Active *bool `json:"active,omitempty"` + ContextWindow *int `json:"context_window,omitempty"` +} + +type OpenAIListModelsResponse struct { + Object string `json:"object"` + Data []OpenAIModel `json:"data"` +} diff --git a/core/schemas/providers/vertex/models.go b/core/schemas/providers/vertex/models.go new file mode 100644 index 0000000000..2ad4cb9702 --- /dev/null +++ b/core/schemas/providers/vertex/models.go @@ -0,0 +1,48 @@ +package vertex + +import ( + "net/url" + "strconv" + + "github.com/maximhq/bifrost/core/schemas" +) + +func ToVertexListModelsURL(request *schemas.BifrostListModelsRequest, baseURL string) string { + // Add limit parameter (default to 100 for Vertex) + pageSize := request.PageSize + if pageSize <= 0 { + pageSize = DefaultPageSize + } + + // Build query parameters + params := url.Values{} + params.Set("pageSize", strconv.Itoa(pageSize)) + + if request.PageToken != "" { + params.Set("pageToken", request.PageToken) + } + + return baseURL + "?" 
+ params.Encode() +} + +func (response *VertexListModelsResponse) ToBifrostListModelsResponse() *schemas.BifrostListModelsResponse { + if response == nil { + return nil + } + + bifrostResponse := &schemas.BifrostListModelsResponse{ + Data: make([]schemas.Model, 0, len(response.Models)), + NextPageToken: response.NextPageToken, + } + + for _, model := range response.Models { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ + ID: string(schemas.Vertex) + "/" + model.Name, + Name: schemas.Ptr(model.DisplayName), + Description: schemas.Ptr(model.Description), + Created: schemas.Ptr(model.VersionCreateTime.Unix()), + }) + } + + return bifrostResponse +} diff --git a/core/schemas/providers/vertex/types.go b/core/schemas/providers/vertex/types.go index 797ccfa1d5..a974d8f704 100644 --- a/core/schemas/providers/vertex/types.go +++ b/core/schemas/providers/vertex/types.go @@ -1,12 +1,14 @@ package vertex +import "time" + // Vertex AI Embedding API types // VertexEmbeddingInstance represents a single embedding instance in the request type VertexEmbeddingInstance struct { - Content string `json:"content"` // The text to generate embeddings for + Content string `json:"content"` // The text to generate embeddings for TaskType *string `json:"task_type,omitempty"` // Intended downstream application (optional) - Title *string `json:"title,omitempty"` // Used to help the model produce better embeddings (optional) + Title *string `json:"title,omitempty"` // Used to help the model produce better embeddings (optional) } // VertexEmbeddingParameters represents the parameters for the embedding request @@ -42,3 +44,21 @@ type VertexEmbeddingPrediction struct { type VertexEmbeddingResponse struct { Predictions []VertexEmbeddingPrediction `json:"predictions"` // List of embedding predictions } + +// ================================ Model Types ================================ + +const DefaultPageSize = 100 + +type VertexModel struct { + Name string `json:"name"` + VersionId string `json:"versionId"` + VersionAliases []string `json:"versionAliases"` + VersionCreateTime time.Time `json:"versionCreateTime"` + DisplayName string `json:"displayName"` + Description string `json:"description"` +} + +type VertexListModelsResponse struct { + Models []VertexModel `json:"models"` + NextPageToken string `json:"nextPageToken"` +} diff --git a/core/version b/core/version index 0b1f1edf11..fd9d1a5aca 100644 --- a/core/version +++ b/core/version @@ -1 +1 @@ -1.2.13 +1.2.14 diff --git a/docs/apis/openapi.json b/docs/apis/openapi.json index 3dfc9be5f3..082067c63e 100644 --- a/docs/apis/openapi.json +++ b/docs/apis/openapi.json @@ -51,6 +51,10 @@ { "name": "MCP", "description": "Endpoints for Model Context Protocol (MCP) integrations." + }, + { + "name": "Models", + "description": "Endpoint for listing available models." } ], "paths": { @@ -284,6 +288,66 @@ } } }, + "/v1/models": { + "get": { + "summary": "List Models", + "description": "Lists available models. If a provider is specified, returns that provider's models; otherwise, returns models from all configured providers.\n\n**Note:** Only fields returned by the provider API are included in the response. Fields not provided by the provider are omitted from the JSON response.", + "operationId": "listModels", + "tags": [ + "Bifrost Core", + "Models" + ], + "parameters": [ + { + "name": "provider", + "in": "query", + "required": false, + "description": "The provider to list models from. 
If not set, returns models from all configured providers in Bifrost.", + "schema": { + "$ref": "#/components/schemas/ModelProvider" + }, + "example": "openai" + }, + { + "name": "page_size", + "in": "query", + "required": false, + "description": "Number of models to return per page. If not specified, the provider's default page size will be used.", + "schema": { + "type": "integer" + }, + "example": 100 + }, + { + "name": "page_token", + "in": "query", + "required": false, + "description": "Token received from previous request to retrieve the next page of results. Omit this parameter to fetch the first page.", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successful response with list of models", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListModelsResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, "/health": { "get": { "summary": "Get Health", @@ -3195,6 +3259,7 @@ "RequestType": { "type": "string", "enum": [ + "list_models", "text_completion", "chat_completion", "chat_completion_stream", @@ -3916,8 +3981,8 @@ }, "latency": { "type": "number", - "description": "Request latency in seconds", - "example": 1.234 + "description": "Request latency in milliseconds", + "example": 1234 }, "billed_usage": { "$ref": "#/components/schemas/BilledLLMUsage" @@ -4559,8 +4624,8 @@ }, "latency": { "type": "number", - "description": "Request latency in seconds", - "example": 1.234 + "description": "Request latency in milliseconds", + "example": 1234 }, "tokens": { "type": "integer", @@ -6491,6 +6556,258 @@ "$ref": "#/components/schemas/SearchStats" } } + }, + "ListModelsResponse": { + "type": "object", + "required": [ + "data" + ], + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ModelInfo" + }, + "description": "Array of model information objects" + }, + "next_page_token": { + "type": "string", + "description": "Token to retrieve the next page of results. Omitted if there are no more pages.", + "example": "eyJwYWdlIjoxfQ==" + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "$ref": "#/components/schemas/ModelProvider" + }, + "request_type": { + "$ref": "#/components/schemas/RequestType" + }, + "latency": { + "type": "number", + "description": "Request latency in milliseconds", + "example": 1234 + }, + "raw_response": { + "type": "object", + "description": "Raw provider response" + } + } + } + } + }, + "ModelInfo": { + "type": "object", + "description": "Model information object. 
Only fields returned by the provider API are included; unset fields are omitted from the JSON response.", + "required": [ + "id" + ], + "properties": { + "id": { + "type": "string", + "description": "Unique model identifier", + "example": "openrouter/openai/gpt-4o" + }, + "canonical_slug": { + "type": "string", + "description": "Canonical slug for the model", + "example": "openai/gpt-4o", + "nullable": true + }, + "name": { + "type": "string", + "description": "Human-readable model name", + "example": "OpenAI: GPT-4o" + }, + "created": { + "type": "integer", + "description": "Unix timestamp of model creation", + "example": 1715558400, + "nullable": true + }, + "context_length": { + "type": "integer", + "description": "Maximum context length in tokens", + "example": 128000, + "nullable": true + }, + "architecture": { + "$ref": "#/components/schemas/ModelArchitecture", + "description": "Object describing the model's technical capabilities", + "nullable": true + }, + "pricing": { + "$ref": "#/components/schemas/ModelPricing", + "description": "Lowest price structure for using this model (all values in USD per token/request/unit)", + "nullable": true + }, + "top_provider": { + "$ref": "#/components/schemas/ModelTopProvider", + "description": "Configuration details for the primary provider", + "nullable": true + }, + "per_request_limits": { + "description": "Rate limiting information (null if no limits)", + "$ref": "#/components/schemas/ModelPerRequestLimits", + "nullable": true + }, + "supported_parameters": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of supported parameter names", + "nullable": true + }, + "default_parameters": { + "$ref": "#/components/schemas/ModelDefaultParameters", + "nullable": true + }, + "hugging_face_id": { + "type": "string", + "description": "Hugging Face model identifier", + "nullable": true + }, + "description": { + "type": "string", + "description": "Description of the model and its capabilities", + "example": "GPT-4o (\"o\" for \"omni\") is OpenAI's latest AI model, supporting both text and image inputs with text outputs", + "nullable": true + }, + "owned_by": { + "type": "string", + "description": "Organization that owns the model", + "example": "openai", + "nullable": true + }, + "supported_methods": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of supported API methods", + "nullable": true + } + } + }, + "ModelArchitecture": { + "type": "object", + "properties": { + "tokenizer": { + "type": "string", + "example": "GPT", + "description": "Tokenization method used by the model" + }, + "instruct_type": { + "description": "Instruction format type (null if not applicable)", + "type": "string", + "nullable": true + }, + "input_modalities": { + "type": "array", + "items": { + "type": "string", + "enum": ["file", "image", "text", "audio", "video"] + }, + "example": ["text", "image", "file"], + "description": "Supported input types for the model" + }, + "output_modalities": { + "type": "array", + "items": { + "type": "string", + "enum": ["text", "image", "audio", "video"] + }, + "example": ["text"], + "description": "Supported output types for the model" + }, + "modality": { + "type": "string", + "description": "Primary model modality", + "example": "text+image->text" + } + } + }, + "ModelPricing": { + "type": "object", + "description": "All pricing values are in USD per token/request/unit. 
A value of '0' indicates the feature is free.", + "properties": { + "prompt": { + "type": "string", + "description": "Cost per input token in USD", + "example": "0.0000025" + }, + "completion": { + "type": "string", + "description": "Cost per output token in USD", + "example": "0.00001" + }, + "request": { + "type": "string", + "description": "Fixed cost per API request in USD", + "example": "0" + }, + "image": { + "type": "string", + "description": "Cost per image input in USD", + "example": "0.003613" + }, + "web_search": { + "type": "string", + "description": "Cost per web search operation in USD", + "example": "0" + }, + "internal_reasoning": { + "type": "string", + "description": "Cost for internal reasoning tokens in USD", + "example": "0" + }, + "input_cache_read": { + "type": "string", + "description": "Cost per cached input token read in USD", + "example": "0.00000125" + }, + "input_cache_write": { + "type": "string", + "description": "Cost per cached input token write in USD", + "example": "0" + } + } + }, + "ModelTopProvider": { + "type": "object", + "description": "Configuration details for the primary provider", + "properties": { + "context_length": { + "type": "integer", + "description": "Provider-specific context limit in tokens" + }, + "max_completion_tokens": { + "type": "integer", + "description": "Maximum completion tokens" + }, + "is_moderated": { + "type": "boolean", + "description": "Whether content moderation is applied to the model output" + } + } + }, + "ModelPerRequestLimits": { + "type": "object", + "properties": { + "prompt_tokens": { + "type": "integer", + "description": "Maximum prompt tokens per request" + }, + "completion_tokens": { + "type": "integer", + "description": "Maximum completion tokens per request" + } + } + }, + "ModelDefaultParameters": { + "type": "object" } }, "responses": { diff --git a/docs/features/custom-providers.mdx b/docs/features/custom-providers.mdx index 484cd2e205..a248b4ca7d 100644 --- a/docs/features/custom-providers.mdx +++ b/docs/features/custom-providers.mdx @@ -61,6 +61,7 @@ curl --location 'http://localhost:8080/api/providers' \ "custom_provider_config": { "base_provider_type": "openai", "allowed_requests": { + "list_models": false, "text_completion": false, "chat_completion": true, "chat_completion_stream": true, @@ -92,6 +93,7 @@ curl --location 'http://localhost:8080/api/providers' \ "custom_provider_config": { "base_provider_type": "openai", "allowed_requests": { + "list_models": false, "text_completion": false, "chat_completion": true, "chat_completion_stream": true, diff --git a/docs/features/unified-interface.mdx b/docs/features/unified-interface.mdx index 7487471ce0..0b6a581eab 100644 --- a/docs/features/unified-interface.mdx +++ b/docs/features/unified-interface.mdx @@ -85,24 +85,25 @@ response, err := client.ChatCompletionRequest(ctx, &schemas.BifrostChatRequest{ The following table summarizes which operations are supported by each provider via Bifrost’s unified interface. 
-| Provider | Text | Text (stream) | Chat | Chat (stream) | Responses | Responses (stream) | Embeddings | TTS | TTS (stream) | STT | STT (stream) | -|----------|------|----------------|------|---------------|-----------|--------------------|------------|-----|-------------|-----|--------------| -| Anthropic (`anthropic/`) | Yes | No | Yes | Yes | Yes | Yes | No | No | No | No | No | -| Azure OpenAI (`azure/`) | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | -| Bedrock (`bedrock/`) | Yes | No | Yes | Yes | Yes | Yes | Yes | No | No | No | No | -| Cerebras (`cerebras/`) | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | -| Cohere (`cohere/`) | No | No | Yes | Yes | Yes | Yes | Yes | No | No | No | No | -| Gemini (`gemini/`) | No | No | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | -| Groq (`groq/`) | No | No | Yes | Yes | Yes | Yes | No | No | No | No | No | -| Mistral (`mistral/`) | No | No | Yes | Yes | Yes | Yes | Yes | No | No | No | No | -| Ollama (`ollama/`) | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | -| OpenAI (`openai/`) | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | -| OpenRouter (`openrouter/`) | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | -| Parasail (`parasail/`) | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | -| SGL (`sgl/`) | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | -| Vertex AI (`vertex/`) | No | No | Yes | Yes | Yes | Yes | Yes | No | No | No | No | +| Provider | Models | Text | Text (stream) | Chat | Chat (stream) | Responses | Responses (stream) | Embeddings | TTS | TTS (stream) | STT | STT (stream) | +|----------|--------|------|----------------|------|---------------|-----------|--------------------|------------|-----|-------------|-----|--------------| +| Anthropic (`anthropic/`) | Yes | Yes | No | Yes | Yes | Yes | Yes | No | No | No | No | No | +| Azure OpenAI (`azure/`) | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | +| Bedrock (`bedrock/`) | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | No | No | No | No | +| Cerebras (`cerebras/`) | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | +| Cohere (`cohere/`) | Yes | No | No | Yes | Yes | Yes | Yes | Yes | No | No | No | No | +| Gemini (`gemini/`) | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | +| Groq (`groq/`) | Yes | No | No | Yes | Yes | Yes | Yes | No | No | No | No | No | +| Mistral (`mistral/`) | Yes | No | No | Yes | Yes | Yes | Yes | Yes | No | No | No | No | +| Ollama (`ollama/`) | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | +| OpenAI (`openai/`) | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | +| OpenRouter (`openrouter/`) | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | +| Parasail (`parasail/`) | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | +| SGL (`sgl/`) | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | +| Vertex AI (`vertex/`) | Yes | No | No | Yes | Yes | Yes | Yes | Yes | No | No | No | No | Notes: +- “Models” refers to the list models operation (`/v1/models`). - “Text” refers to the classic text completion interface (`/v1/completions`). - “Responses” refers to the OpenAI-style Responses API (`/v1/responses`). Non-OpenAI providers map this to their native chat API under the hood. - TTS corresponds to `/v1/audio/speech` and STT to `/v1/audio/transcriptions`. 
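For Go callers, listing models follows the same client pattern as the chat completion example above. A minimal sketch, assuming `client` is an initialized `*bifrost.Bifrost` and the standard `context`/`fmt`/`log` imports:

```go
// List the first 100 OpenAI models through the unified interface.
resp, bifrostErr := client.ListModelsRequest(context.Background(), &schemas.BifrostListModelsRequest{
	Provider: schemas.OpenAI,
	PageSize: 100,
})
if bifrostErr != nil {
	log.Fatalf("list models failed: %s", bifrostErr.Error.Message)
}
for _, model := range resp.Data {
	fmt.Println(model.ID) // IDs come back provider-prefixed, e.g. "openai/gpt-4o"
}
// A non-empty resp.NextPageToken can be passed back as PageToken to fetch the next page.
```

Over HTTP, the equivalent call is `GET /v1/models?provider=openai&page_size=100` against the gateway.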
diff --git a/framework/changelog.md b/framework/changelog.md index d539ca8293..40ba9b6d85 100644 --- a/framework/changelog.md +++ b/framework/changelog.md @@ -1,5 +1,4 @@ -- chore: version update core to 1.2.13 -- feat: added support for vertex provider/model format in pricing lookup \ No newline at end of file +- chore: version update core to 1.2.14 \ No newline at end of file diff --git a/framework/version b/framework/version index 645377eea8..63b283b23a 100644 --- a/framework/version +++ b/framework/version @@ -1 +1 @@ -1.1.15 +1.1.16 diff --git a/plugins/governance/changelog.md b/plugins/governance/changelog.md index 51f9eb7087..747b283dbc 100644 --- a/plugins/governance/changelog.md +++ b/plugins/governance/changelog.md @@ -1,4 +1,4 @@ -- chore: version update core to 1.2.13 and framework to 1.1.15 +- chore: version update core to 1.2.14 and framework to 1.1.16 diff --git a/plugins/governance/version b/plugins/governance/version index 25b22e060e..ef40e4d0f7 100644 --- a/plugins/governance/version +++ b/plugins/governance/version @@ -1 +1 @@ -1.3.16 +1.3.17 diff --git a/plugins/jsonparser/changelog.md b/plugins/jsonparser/changelog.md index 51f9eb7087..747b283dbc 100644 --- a/plugins/jsonparser/changelog.md +++ b/plugins/jsonparser/changelog.md @@ -1,4 +1,4 @@ -- chore: version update core to 1.2.13 and framework to 1.1.15 +- chore: version update core to 1.2.14 and framework to 1.1.16 diff --git a/plugins/jsonparser/version b/plugins/jsonparser/version index 92ee6ac2fa..287e5b8247 100644 --- a/plugins/jsonparser/version +++ b/plugins/jsonparser/version @@ -1 +1 @@ -1.3.15 \ No newline at end of file +1.3.16 \ No newline at end of file diff --git a/plugins/logging/changelog.md b/plugins/logging/changelog.md index 51f9eb7087..747b283dbc 100644 --- a/plugins/logging/changelog.md +++ b/plugins/logging/changelog.md @@ -1,4 +1,4 @@ -- chore: version update core to 1.2.13 and framework to 1.1.15 +- chore: version update core to 1.2.14 and framework to 1.1.16 diff --git a/plugins/logging/version b/plugins/logging/version index 5bdcf5c395..25b22e060e 100644 --- a/plugins/logging/version +++ b/plugins/logging/version @@ -1 +1 @@ -1.3.15 +1.3.16 diff --git a/plugins/maxim/changelog.md b/plugins/maxim/changelog.md index 51f9eb7087..747b283dbc 100644 --- a/plugins/maxim/changelog.md +++ b/plugins/maxim/changelog.md @@ -1,4 +1,4 @@ -- chore: version update core to 1.2.13 and framework to 1.1.15 +- chore: version update core to 1.2.14 and framework to 1.1.16 diff --git a/plugins/maxim/version b/plugins/maxim/version index 8a3b8ac20b..6ee81aba98 100644 --- a/plugins/maxim/version +++ b/plugins/maxim/version @@ -1 +1 @@ -1.4.15 +1.4.16 diff --git a/plugins/mocker/changelog.md b/plugins/mocker/changelog.md index c7113783b7..747b283dbc 100644 --- a/plugins/mocker/changelog.md +++ b/plugins/mocker/changelog.md @@ -1,6 +1,4 @@ -- chore: version update core to 1.2.13 and framework to 1.1.15 -- feat: added support for responses request -- feat: added "skip-mocker" context key to skip mocker plugin per request +- chore: version update core to 1.2.14 and framework to 1.1.16 diff --git a/plugins/mocker/version b/plugins/mocker/version index 92ee6ac2fa..287e5b8247 100644 --- a/plugins/mocker/version +++ b/plugins/mocker/version @@ -1 +1 @@ -1.3.15 \ No newline at end of file +1.3.16 \ No newline at end of file diff --git a/plugins/otel/changelog.md b/plugins/otel/changelog.md index bc74baa33f..747b283dbc 100644 --- a/plugins/otel/changelog.md +++ b/plugins/otel/changelog.md @@ -1,6 +1,4 @@ -- chore: version 
update core to 1.2.13 and framework to 1.1.15 -- feat: added headers support for OTel configuration. Value prefixed with env will be fetched from environment variables (env.) -- feat: emission of OTel resource spans is completely async - this brings down inference overhead to < 1µsecond \ No newline at end of file +- chore: version update core to 1.2.14 and framework to 1.1.16 diff --git a/plugins/otel/version b/plugins/otel/version index 758a46e9bd..d941c12bd0 100644 --- a/plugins/otel/version +++ b/plugins/otel/version @@ -1 +1 @@ -1.0.15 \ No newline at end of file +1.0.16 \ No newline at end of file diff --git a/plugins/semanticcache/changelog.md b/plugins/semanticcache/changelog.md index 5781539e61..747b283dbc 100644 --- a/plugins/semanticcache/changelog.md +++ b/plugins/semanticcache/changelog.md @@ -1,5 +1,4 @@ -- chore: version update core to 1.2.13 and framework to 1.1.15 -- tests: added mocker plugin to all chat/responses tests +- chore: version update core to 1.2.14 and framework to 1.1.16 diff --git a/plugins/semanticcache/version b/plugins/semanticcache/version index 5bdcf5c395..25b22e060e 100644 --- a/plugins/semanticcache/version +++ b/plugins/semanticcache/version @@ -1 +1 @@ -1.3.15 +1.3.16 diff --git a/plugins/telemetry/changelog.md b/plugins/telemetry/changelog.md index 51f9eb7087..747b283dbc 100644 --- a/plugins/telemetry/changelog.md +++ b/plugins/telemetry/changelog.md @@ -1,4 +1,4 @@ -- chore: version update core to 1.2.13 and framework to 1.1.15 +- chore: version update core to 1.2.14 and framework to 1.1.16 diff --git a/plugins/telemetry/version b/plugins/telemetry/version index 92ee6ac2fa..287e5b8247 100644 --- a/plugins/telemetry/version +++ b/plugins/telemetry/version @@ -1 +1 @@ -1.3.15 \ No newline at end of file +1.3.16 \ No newline at end of file diff --git a/tests/core-providers/anthropic_test.go b/tests/core-providers/anthropic_test.go index 3229593bbd..34a722ad9e 100644 --- a/tests/core-providers/anthropic_test.go +++ b/tests/core-providers/anthropic_test.go @@ -43,6 +43,7 @@ func TestAnthropic(t *testing.T) { CompleteEnd2End: true, Embedding: false, Reasoning: true, + ListModels: true, }, } diff --git a/tests/core-providers/azure_test.go b/tests/core-providers/azure_test.go index 2d702dc036..d6e76574ff 100644 --- a/tests/core-providers/azure_test.go +++ b/tests/core-providers/azure_test.go @@ -44,6 +44,7 @@ func TestAzure(t *testing.T) { MultipleImages: true, CompleteEnd2End: true, Embedding: true, + ListModels: true, }, } diff --git a/tests/core-providers/bedrock_test.go b/tests/core-providers/bedrock_test.go index 82b68853e4..cb796811f7 100644 --- a/tests/core-providers/bedrock_test.go +++ b/tests/core-providers/bedrock_test.go @@ -45,6 +45,7 @@ func TestBedrock(t *testing.T) { CompleteEnd2End: true, Embedding: true, Reasoning: true, + ListModels: true, }, } diff --git a/tests/core-providers/cerebras_test.go b/tests/core-providers/cerebras_test.go index d5c9fbffe6..37cb53fc16 100644 --- a/tests/core-providers/cerebras_test.go +++ b/tests/core-providers/cerebras_test.go @@ -44,6 +44,7 @@ func TestCerebras(t *testing.T) { MultipleImages: false, CompleteEnd2End: true, Embedding: false, + ListModels: true, }, } diff --git a/tests/core-providers/cohere_test.go b/tests/core-providers/cohere_test.go index 7fcaec0025..907bc5891a 100644 --- a/tests/core-providers/cohere_test.go +++ b/tests/core-providers/cohere_test.go @@ -41,6 +41,7 @@ func TestCohere(t *testing.T) { CompleteEnd2End: false, Embedding: true, Reasoning: true, + ListModels: true, }, } diff --git 
a/tests/core-providers/config/account.go b/tests/core-providers/config/account.go index 0ab40f146c..332a4ad77f 100644 --- a/tests/core-providers/config/account.go +++ b/tests/core-providers/config/account.go @@ -37,6 +37,7 @@ type TestScenarios struct { TranscriptionStream bool // Streaming speech-to-text functionality Embedding bool // Embedding functionality Reasoning bool // Reasoning/thinking functionality via Responses API + ListModels bool // List available models functionality } // ComprehensiveTestConfig extends TestConfig with additional scenarios @@ -179,10 +180,9 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx *context.Context Deployments: map[string]string{ "text-embedding-ada-002": "text-embedding-ada-002", }, - // Use environment variable for API version with fallback to current preview version - // Note: This is a preview API version that may change over time. Update as needed. + // Use environment variable for API version with fallback to current stable version // Set AZURE_API_VERSION environment variable to override the default. - APIVersion: bifrost.Ptr(getEnvWithDefault("AZURE_API_VERSION", "2024-08-01-preview")), + APIVersion: bifrost.Ptr(getEnvWithDefault("AZURE_API_VERSION", "2024-10-21")), }, }, }, nil @@ -500,6 +500,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ TranscriptionStream: true, // OpenAI supports streaming STT Embedding: true, Reasoning: true, // OpenAI supports reasoning via o1 models + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"}, @@ -527,6 +528,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ Transcription: false, // Not supported TranscriptionStream: false, // Not supported Embedding: false, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -554,6 +556,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ Transcription: false, // Not supported TranscriptionStream: false, // Not supported Embedding: true, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -581,6 +584,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ Transcription: false, // Not supported TranscriptionStream: false, // Not supported Embedding: true, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -608,6 +612,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ Transcription: false, // Not supported yet TranscriptionStream: false, // Not supported yet Embedding: true, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -635,6 +640,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ Transcription: false, // Not supported TranscriptionStream: false, // Not supported Embedding: true, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -661,6 +667,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ Transcription: false, // Not supported TranscriptionStream: false, // Not supported Embedding: true, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -688,6 +695,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ Transcription: false, // Not supported TranscriptionStream: false, // Not supported Embedding: false, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -715,6 
+723,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ Transcription: false, // Not supported TranscriptionStream: false, // Not supported Embedding: false, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -742,6 +751,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ Transcription: false, // Not supported TranscriptionStream: false, // Not supported Embedding: false, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -772,6 +782,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ Transcription: true, TranscriptionStream: true, Embedding: true, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -799,6 +810,7 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ Transcription: false, TranscriptionStream: false, Embedding: false, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, diff --git a/tests/core-providers/gemini_test.go b/tests/core-providers/gemini_test.go index a9c731567c..718c994335 100644 --- a/tests/core-providers/gemini_test.go +++ b/tests/core-providers/gemini_test.go @@ -50,6 +50,7 @@ func TestGemini(t *testing.T) { SpeechSynthesis: true, SpeechSynthesisStream: true, Reasoning: false, //TODO: Supported but lost since we map Gemini's responses via chat completions, fix is a native Gemini handler or reasoning support in chat completions + ListModels: true, }, } diff --git a/tests/core-providers/groq_test.go b/tests/core-providers/groq_test.go index 5cfd8662c3..b3b8be693d 100644 --- a/tests/core-providers/groq_test.go +++ b/tests/core-providers/groq_test.go @@ -47,6 +47,7 @@ func TestGroq(t *testing.T) { MultipleImages: false, CompleteEnd2End: true, Embedding: false, + ListModels: true, }, } diff --git a/tests/core-providers/mistral_test.go b/tests/core-providers/mistral_test.go index 17c241fdbc..82731b3fe0 100644 --- a/tests/core-providers/mistral_test.go +++ b/tests/core-providers/mistral_test.go @@ -42,6 +42,7 @@ func TestMistral(t *testing.T) { MultipleImages: true, CompleteEnd2End: true, Embedding: true, + ListModels: true, }, } diff --git a/tests/core-providers/ollama_test.go b/tests/core-providers/ollama_test.go index c43383988b..ee97daacc2 100644 --- a/tests/core-providers/ollama_test.go +++ b/tests/core-providers/ollama_test.go @@ -39,6 +39,7 @@ func TestOllama(t *testing.T) { MultipleImages: false, CompleteEnd2End: true, Embedding: false, + ListModels: true, }, } diff --git a/tests/core-providers/openai_test.go b/tests/core-providers/openai_test.go index 8bb6491f68..774d7cb944 100644 --- a/tests/core-providers/openai_test.go +++ b/tests/core-providers/openai_test.go @@ -55,6 +55,7 @@ func TestOpenAI(t *testing.T) { TranscriptionStream: true, Embedding: true, Reasoning: true, + ListModels: true, }, } diff --git a/tests/core-providers/openrouter_test.go b/tests/core-providers/openrouter_test.go index bb4b805aa9..488f58a341 100644 --- a/tests/core-providers/openrouter_test.go +++ b/tests/core-providers/openrouter_test.go @@ -39,6 +39,7 @@ func TestOpenRouter(t *testing.T) { ImageBase64: false, // OpenRouter's responses API is in Beta MultipleImages: false, // OpenRouter's responses API is in Beta CompleteEnd2End: false, // OpenRouter's responses API is in Beta + ListModels: true, }, } diff --git a/tests/core-providers/parasail_test.go b/tests/core-providers/parasail_test.go index e8f03ee4e2..070fa572a2 100644 --- 
a/tests/core-providers/parasail_test.go +++ b/tests/core-providers/parasail_test.go @@ -39,6 +39,7 @@ func TestParasail(t *testing.T) { MultipleImages: false, // Not supported yet CompleteEnd2End: true, Embedding: false, // Not supported yet + ListModels: true, }, } diff --git a/tests/core-providers/scenarios/list_models.go b/tests/core-providers/scenarios/list_models.go new file mode 100644 index 0000000000..419249fcea --- /dev/null +++ b/tests/core-providers/scenarios/list_models.go @@ -0,0 +1,161 @@ +package scenarios + +import ( + "context" + "os" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// RunListModelsTest executes the list models test scenario +func RunListModelsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.ListModels { + t.Logf("List models not supported for provider %s", testConfig.Provider) + return + } + + t.Run("ListModels", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + // Create basic list models request + request := &schemas.BifrostListModelsRequest{ + Provider: testConfig.Provider, + } + + // Execute list models request + response, bifrostErr := client.ListModelsRequest(ctx, request) + if bifrostErr != nil { + t.Fatalf("❌ List models request failed: %v", GetErrorMessage(bifrostErr)) + } + + // Validate response structure + if response == nil { + t.Fatal("❌ List models response is nil") + } + + // Validate that we have models in the response + if len(response.Data) == 0 { + t.Fatal("❌ List models response contains no models") + } + + t.Logf("✅ List models returned %d models", len(response.Data)) + + // Validate individual model entries + validModels := 0 + for i, model := range response.Data { + if model.ID == "" { + t.Errorf("❌ Model at index %d has empty ID", i) + continue + } + + // Log a few sample models for verification + if i < 5 { + t.Logf(" Model %d: ID=%s", i+1, model.ID) + } + + validModels++ + } + + if validModels == 0 { + t.Fatal("❌ No valid models found in response") + } + + t.Logf("✅ Validated %d models with proper structure", validModels) + + // Validate extra fields + if response.ExtraFields.Provider != testConfig.Provider { + t.Errorf("❌ Provider mismatch: expected %s, got %s", testConfig.Provider, response.ExtraFields.Provider) + } + + if response.ExtraFields.RequestType != schemas.ListModelsRequest { + t.Errorf("❌ Request type mismatch: expected %s, got %s", schemas.ListModelsRequest, response.ExtraFields.RequestType) + } + + // Validate latency is reasonable (non-negative and not absurdly high) + if response.ExtraFields.Latency < 0 { + t.Errorf("❌ Invalid latency: %d ms (should be non-negative)", response.ExtraFields.Latency) + } else if response.ExtraFields.Latency > 30000 { + t.Logf("⚠️ Warning: High latency detected: %d ms", response.ExtraFields.Latency) + } else { + t.Logf("✅ Request latency: %d ms", response.ExtraFields.Latency) + } + + t.Logf("🎉 List models test passed successfully!") + }) +} + +// RunListModelsPaginationTest executes pagination test for list models +func RunListModelsPaginationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.ListModels { + t.Logf("List models not supported for provider %s", testConfig.Provider) + return + } + + t.Run("ListModelsPagination", func(t *testing.T) { + if 
os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + // Test pagination with page size + pageSize := 5 + request := &schemas.BifrostListModelsRequest{ + Provider: testConfig.Provider, + PageSize: pageSize, + } + + response, bifrostErr := client.ListModelsRequest(ctx, request) + if bifrostErr != nil { + t.Fatalf("❌ List models pagination request failed: %v", GetErrorMessage(bifrostErr)) + } + + if response == nil { + t.Fatal("❌ List models pagination response is nil") + } + + // Check that pagination was applied + if len(response.Data) > pageSize { + t.Errorf("❌ Expected at most %d models, got %d", pageSize, len(response.Data)) + } else { + t.Logf("✅ Pagination working: returned %d models (page size: %d)", len(response.Data), pageSize) + } + + // Test with page token if provided + if response.NextPageToken != "" { + t.Logf("✅ Next page token available: %s", response.NextPageToken) + + // Fetch next page + nextPageRequest := &schemas.BifrostListModelsRequest{ + Provider: testConfig.Provider, + PageSize: pageSize, + PageToken: response.NextPageToken, + } + + nextPageResponse, nextPageErr := client.ListModelsRequest(ctx, nextPageRequest) + if nextPageErr != nil { + t.Errorf("❌ Failed to fetch next page: %v", GetErrorMessage(nextPageErr)) + } else if nextPageResponse != nil { + t.Logf("✅ Successfully fetched next page with %d models", len(nextPageResponse.Data)) + + // Verify that the next page contains different models + if len(response.Data) > 0 && len(nextPageResponse.Data) > 0 { + firstPageFirstModel := response.Data[0].ID + secondPageFirstModel := nextPageResponse.Data[0].ID + if firstPageFirstModel != secondPageFirstModel { + t.Logf("✅ Pages contain different models (first page: %s, second page: %s)", + firstPageFirstModel, secondPageFirstModel) + } + } + } + } else { + t.Logf("ℹ️ No next page token - all models returned in single page") + } + + t.Logf("🎉 List models pagination test completed!") + }) +} diff --git a/tests/core-providers/sgl_test.go b/tests/core-providers/sgl_test.go index 7cd9024cc5..2b48b41446 100644 --- a/tests/core-providers/sgl_test.go +++ b/tests/core-providers/sgl_test.go @@ -40,6 +40,7 @@ func TestSGL(t *testing.T) { MultipleImages: true, CompleteEnd2End: true, Embedding: true, + ListModels: true, }, } diff --git a/tests/core-providers/tests.go b/tests/core-providers/tests.go index 338d2958a2..c5d300ff4e 100644 --- a/tests/core-providers/tests.go +++ b/tests/core-providers/tests.go @@ -49,6 +49,8 @@ func runAllComprehensiveTests(t *testing.T, client *bifrost.Bifrost, ctx context scenarios.RunTranscriptionStreamAdvancedTest, scenarios.RunEmbeddingTest, scenarios.RunReasoningTest, + scenarios.RunListModelsTest, + scenarios.RunListModelsPaginationTest, } // Execute all test scenarios @@ -84,6 +86,7 @@ func printTestSummary(t *testing.T, testConfig config.ComprehensiveTestConfig) { {"TranscriptionStream", testConfig.Scenarios.TranscriptionStream}, {"Embedding", testConfig.Scenarios.Embedding && testConfig.EmbeddingModel != ""}, {"Reasoning", testConfig.Scenarios.Reasoning && testConfig.ReasoningModel != ""}, + {"ListModels", testConfig.Scenarios.ListModels}, } supported := 0 diff --git a/tests/core-providers/vertex_test.go b/tests/core-providers/vertex_test.go index 6e611ff959..ff826d551c 100644 --- a/tests/core-providers/vertex_test.go +++ b/tests/core-providers/vertex_test.go @@ -40,6 +40,7 @@ func TestVertex(t *testing.T) { MultipleImages: true, CompleteEnd2End: true, Embedding: true, + ListModels: true, }, } diff --git 
a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index 7ab776e1a2..ee04d73225 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -274,6 +274,9 @@ const ( // RegisterRoutes registers all completion-related routes func (h *CompletionHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { + // Model endpoints + r.GET("/v1/models", lib.ChainMiddlewares(h.listModels, middlewares...)) + // Completion endpoints r.POST("/v1/completions", lib.ChainMiddlewares(h.textCompletion, middlewares...)) r.POST("/v1/chat/completions", lib.ChainMiddlewares(h.chatCompletion, middlewares...)) @@ -283,6 +286,64 @@ func (h *CompletionHandler) RegisterRoutes(r *router.Router, middlewares ...lib. r.POST("/v1/audio/transcriptions", lib.ChainMiddlewares(h.transcription, middlewares...)) } +// listModels handles GET /v1/models - Process list models requests +// If provider is not specified, lists all models from all configured providers +func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) { + // Get provider from query parameters + provider := string(ctx.QueryArgs().Peek("provider")) + + // Convert context + bifrostCtx := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys()) + if bifrostCtx == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context", h.logger) + return + } + + var resp *schemas.BifrostListModelsResponse + var bifrostErr *schemas.BifrostError + + pageSize := 0 + if pageSizeStr := ctx.QueryArgs().Peek("page_size"); len(pageSizeStr) > 0 { + if n, err := strconv.Atoi(string(pageSizeStr)); err == nil && n >= 0 { + pageSize = n + } + } + pageToken := string(ctx.QueryArgs().Peek("page_token")) + + bifrostListModelsReq := &schemas.BifrostListModelsRequest{ + Provider: schemas.ModelProvider(provider), + PageSize: pageSize, + PageToken: pageToken, + } + + // Pass-through unknown query params for provider-specific features + extraParams := map[string]interface{}{} + for k, v := range ctx.QueryArgs().All() { + s := string(k) + if s != "provider" && s != "page_size" && s != "page_token" { + extraParams[s] = string(v) + } + } + if len(extraParams) > 0 { + bifrostListModelsReq.ExtraParams = extraParams + } + + // If provider is empty, list all models from all providers + if provider == "" { + resp, bifrostErr = h.client.ListAllModels(*bifrostCtx, bifrostListModelsReq) + } else { + resp, bifrostErr = h.client.ListModelsRequest(*bifrostCtx, bifrostListModelsReq) + } + + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr, h.logger) + return + } + + // Send successful response + SendJSON(ctx, resp, h.logger) +} + // textCompletion handles POST /v1/completions - Process text completion requests func (h *CompletionHandler) textCompletion(ctx *fasthttp.RequestCtx) { var req TextRequest diff --git a/transports/bifrost-http/integrations/anthropic.go b/transports/bifrost-http/integrations/anthropic.go index cd11fd38cb..bf4ba9e8c1 100644 --- a/transports/bifrost-http/integrations/anthropic.go +++ b/transports/bifrost-http/integrations/anthropic.go @@ -2,11 +2,14 @@ package integrations import ( "errors" + "fmt" + "strconv" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/core/schemas/providers/anthropic" "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" ) // AnthropicRouter handles Anthropic-compatible API endpoints @@ 
-18,7 +21,7 @@ type AnthropicRouter struct { func CreateAnthropicRouteConfigs(pathPrefix string) []RouteConfig { return []RouteConfig{ { - Type: RouteConfigTypeAnthropic, + Type: RouteConfigTypeAnthropic, Path: pathPrefix + "/v1/complete", Method: "POST", GetRequestTypeInstance: func() interface{} { @@ -71,9 +74,71 @@ func CreateAnthropicRouteConfigs(pathPrefix string) []RouteConfig { } } +func CreateAnthropicListModelsRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) []RouteConfig { + return []RouteConfig{ + { + Type: RouteConfigTypeAnthropic, + Path: pathPrefix + "/v1/models", + Method: "GET", + GetRequestTypeInstance: func() interface{} { + return &schemas.BifrostListModelsRequest{} + }, + RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { + if listModelsReq, ok := req.(*schemas.BifrostListModelsRequest); ok { + return &schemas.BifrostRequest{ + ListModelsRequest: listModelsReq, + }, nil + } + return nil, errors.New("invalid request type") + }, + ListModelsResponseConverter: func(resp *schemas.BifrostListModelsResponse) (interface{}, error) { + return anthropic.ToAnthropicListModelsResponse(resp), nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return anthropic.ToAnthropicChatCompletionError(err) + }, + PreCallback: extractAnthropicListModelsParams, + }, + } +} + +// extractAnthropicListModelsParams extracts query parameters for list models request +func extractAnthropicListModelsParams(ctx *fasthttp.RequestCtx, req interface{}) error { + if listModelsReq, ok := req.(*schemas.BifrostListModelsRequest); ok { + // Set provider to Anthropic + listModelsReq.Provider = schemas.Anthropic + + // Extract limit from query parameters + if limitStr := string(ctx.QueryArgs().Peek("limit")); limitStr != "" { + if limit, err := strconv.Atoi(limitStr); err == nil { + listModelsReq.PageSize = limit + } else { + return fmt.Errorf("invalid limit parameter: %w", err) + } + } + + if beforeID := string(ctx.QueryArgs().Peek("before_id")); beforeID != "" { + if listModelsReq.ExtraParams == nil { + listModelsReq.ExtraParams = make(map[string]interface{}) + } + listModelsReq.ExtraParams["before_id"] = beforeID + } + + if afterID := string(ctx.QueryArgs().Peek("after_id")); afterID != "" { + if listModelsReq.ExtraParams == nil { + listModelsReq.ExtraParams = make(map[string]interface{}) + } + listModelsReq.ExtraParams["after_id"] = afterID + } + + return nil + } + return errors.New("invalid request type for Anthropic list models") +} + // NewAnthropicRouter creates a new AnthropicRouter with the given bifrost client. 
func NewAnthropicRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *AnthropicRouter { return &AnthropicRouter{ - GenericRouter: NewGenericRouter(client, handlerStore, CreateAnthropicRouteConfigs("/anthropic"), logger), + GenericRouter: NewGenericRouter(client, handlerStore, append(CreateAnthropicRouteConfigs("/anthropic"), CreateAnthropicListModelsRouteConfigs("/anthropic", handlerStore)...), logger), } } diff --git a/transports/bifrost-http/integrations/genai.go b/transports/bifrost-http/integrations/genai.go index 092af54361..d1b2eff638 100644 --- a/transports/bifrost-http/integrations/genai.go +++ b/transports/bifrost-http/integrations/genai.go @@ -3,6 +3,7 @@ package integrations import ( "errors" "fmt" + "strconv" "strings" bifrost "github.com/maximhq/bifrost/core" @@ -23,7 +24,7 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { // Chat completions endpoint routes = append(routes, RouteConfig{ - Type: RouteConfigTypeGenAI, + Type: RouteConfigTypeGenAI, Path: pathPrefix + "/v1beta/models/{model:*}", Method: "POST", GetRequestTypeInstance: func() interface{} { @@ -63,6 +64,30 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { PreCallback: extractAndSetModelFromURL, }) + routes = append(routes, RouteConfig{ + Type: RouteConfigTypeGenAI, + Path: pathPrefix + "/v1beta/models", + Method: "GET", + GetRequestTypeInstance: func() interface{} { + return &schemas.BifrostListModelsRequest{} + }, + RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { + if listModelsReq, ok := req.(*schemas.BifrostListModelsRequest); ok { + return &schemas.BifrostRequest{ + ListModelsRequest: listModelsReq, + }, nil + } + return nil, errors.New("invalid request type") + }, + ListModelsResponseConverter: func(resp *schemas.BifrostListModelsResponse) (interface{}, error) { + return gemini.ToGeminiListModelsResponse(resp), nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return gemini.ToGeminiError(err) + }, + PreCallback: extractGeminiListModelsParams, + }) + return routes } @@ -127,3 +152,26 @@ func extractAndSetModelFromURL(ctx *fasthttp.RequestCtx, req interface{}) error return fmt.Errorf("invalid request type for GenAI") } + +// extractGeminiListModelsParams extracts query parameters for list models request +func extractGeminiListModelsParams(ctx *fasthttp.RequestCtx, req interface{}) error { + if listModelsReq, ok := req.(*schemas.BifrostListModelsRequest); ok { + // Set provider to Gemini + listModelsReq.Provider = schemas.Gemini + + // Extract pageSize from query parameters (Gemini uses pageSize instead of limit) + if pageSizeStr := string(ctx.QueryArgs().Peek("pageSize")); pageSizeStr != "" { + if pageSize, err := strconv.Atoi(pageSizeStr); err == nil { + listModelsReq.PageSize = pageSize + } + } + + // Extract pageToken from query parameters + if pageToken := string(ctx.QueryArgs().Peek("pageToken")); pageToken != "" { + listModelsReq.PageToken = pageToken + } + + return nil + } + return errors.New("invalid request type for Gemini list models") +} diff --git a/transports/bifrost-http/integrations/openai.go b/transports/bifrost-http/integrations/openai.go index 5db0c8b563..50b423a6a7 100644 --- a/transports/bifrost-http/integrations/openai.go +++ b/transports/bifrost-http/integrations/openai.go @@ -55,6 +55,8 @@ func AzureEndpointPreHook(handlerStore lib.HandlerStore) func(ctx *fasthttp.Requ r.Model = setAzureModelName(r.Model, deploymentIDStr) case *openai.OpenAIEmbeddingRequest: r.Model = 
setAzureModelName(r.Model, deploymentIDStr) + case *schemas.BifrostListModelsRequest: + r.Provider = schemas.Azure } if deploymentEndpoint == nil || azureKey == nil || !handlerStore.ShouldAllowDirectKeys() { @@ -101,7 +103,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) "/openai/deployments/{deployment-id}/completions", } { routes = append(routes, RouteConfig{ - Type: RouteConfigTypeOpenAI, + Type: RouteConfigTypeOpenAI, Path: pathPrefix + path, Method: "POST", GetRequestTypeInstance: func() interface{} { @@ -140,7 +142,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) "/openai/deployments/{deployment-id}/chat/completions", } { routes = append(routes, RouteConfig{ - Type: RouteConfigTypeOpenAI, + Type: RouteConfigTypeOpenAI, Path: pathPrefix + path, Method: "POST", GetRequestTypeInstance: func() interface{} { @@ -179,7 +181,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) "/openai/deployments/{deployment-id}/responses", } { routes = append(routes, RouteConfig{ - Type: RouteConfigTypeOpenAI, + Type: RouteConfigTypeOpenAI, Path: pathPrefix + path, Method: "POST", GetRequestTypeInstance: func() interface{} { @@ -219,7 +221,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) "/openai/deployments/{deployment-id}/embeddings", } { routes = append(routes, RouteConfig{ - Type: RouteConfigTypeOpenAI, + Type: RouteConfigTypeOpenAI, Path: pathPrefix + path, Method: "POST", GetRequestTypeInstance: func() interface{} { @@ -250,7 +252,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) "/openai/deployments/{deployment-id}/audio/speech", } { routes = append(routes, RouteConfig{ - Type: RouteConfigTypeOpenAI, + Type: RouteConfigTypeOpenAI, Path: pathPrefix + path, Method: "POST", GetRequestTypeInstance: func() interface{} { @@ -286,7 +288,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) "/openai/deployments/{deployment-id}/audio/transcriptions", } { routes = append(routes, RouteConfig{ - Type: RouteConfigTypeOpenAI, + Type: RouteConfigTypeOpenAI, Path: pathPrefix + path, Method: "POST", GetRequestTypeInstance: func() interface{} { @@ -322,10 +324,74 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) return routes } +func CreateOpenAIListModelsRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) []RouteConfig { + var routes []RouteConfig + + // Models endpoint + for _, path := range []string{ + "/v1/models", + "/models", + "/openai/deployments/{deployment-id}/models", + } { + routes = append(routes, RouteConfig{ + Type: RouteConfigTypeOpenAI, + Path: pathPrefix + path, + Method: "GET", + GetRequestTypeInstance: func() interface{} { + return &schemas.BifrostListModelsRequest{} + }, + RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { + if listModelsReq, ok := req.(*schemas.BifrostListModelsRequest); ok { + return &schemas.BifrostRequest{ + ListModelsRequest: listModelsReq, + }, nil + } + return nil, errors.New("invalid request type") + }, + ListModelsResponseConverter: func(resp *schemas.BifrostListModelsResponse) (interface{}, error) { + return openai.ToOpenAIListModelsResponse(resp), nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return err + }, + PreCallback: setQueryParamsAndAzureEndpointPreHook(handlerStore), + }) + } + + return routes +} + +// setQueryParamsAndAzureEndpointPreHook creates a combined 
pre-callback for OpenAI list models +// that handles both Azure endpoint preprocessing and query parameter extraction +func setQueryParamsAndAzureEndpointPreHook(handlerStore lib.HandlerStore) PreRequestCallback { + azureHook := AzureEndpointPreHook(handlerStore) + + return func(ctx *fasthttp.RequestCtx, req interface{}) error { + // First run the Azure endpoint pre-hook if needed + if azureHook != nil { + if err := azureHook(ctx, req); err != nil { + return err + } + } + + // Then extract query parameters for list models + if listModelsReq, ok := req.(*schemas.BifrostListModelsRequest); ok { + // Default the provider to OpenAI unless the Azure pre-hook already set it + if listModelsReq.Provider == "" { + listModelsReq.Provider = schemas.OpenAI + } + + return nil + } + + return nil + } +} + // NewOpenAIRouter creates a new OpenAIRouter with the given bifrost client. func NewOpenAIRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *OpenAIRouter { return &OpenAIRouter{ - GenericRouter: NewGenericRouter(client, handlerStore, CreateOpenAIRouteConfigs("/openai", handlerStore), logger), + GenericRouter: NewGenericRouter(client, handlerStore, append(CreateOpenAIRouteConfigs("/openai", handlerStore), CreateOpenAIListModelsRouteConfigs("/openai", handlerStore)...), logger), } } diff --git a/transports/bifrost-http/integrations/utils.go b/transports/bifrost-http/integrations/utils.go index 7351b65122..4c1c562abd 100644 --- a/transports/bifrost-http/integrations/utils.go +++ b/transports/bifrost-http/integrations/utils.go @@ -115,6 +115,8 @@ type RequestConverter func(req interface{}) (*schemas.BifrostRequest, error) // ResponseConverter is a function that converts Bifrost responses to integration-specific format. // It takes a BifrostResponse and returns the format expected by the specific integration. +type ListModelsResponseConverter func(*schemas.BifrostListModelsResponse) (interface{}, error) + type TextResponseConverter func(*schemas.BifrostTextCompletionResponse) (interface{}, error) type ChatResponseConverter func(*schemas.BifrostChatResponse) (interface{}, error) @@ -193,20 +195,21 @@ type StreamConfig struct { type RouteConfigType string const ( - RouteConfigTypeOpenAI RouteConfigType = "openai" + RouteConfigTypeOpenAI RouteConfigType = "openai" RouteConfigTypeAnthropic RouteConfigType = "anthropic" - RouteConfigTypeGenAI RouteConfigType = "genai" + RouteConfigTypeGenAI RouteConfigType = "genai" ) // RouteConfig defines the configuration for a single route in an integration. // It specifies the path, method, and handlers for request/response conversion. 
type RouteConfig struct { - Type RouteConfigType // Type of the route (e.g., "chat", "text", "embedding", "responses", "speech", "transcription") + Type RouteConfigType // Type of the route Path string // HTTP path pattern (e.g., "/openai/v1/chat/completions") Method string // HTTP method (POST, GET, PUT, DELETE) GetRequestTypeInstance func() interface{} // Factory function to create request instance (SHOULD NOT BE NIL) RequestParser RequestParser // Optional: custom request parsing (e.g., multipart/form-data) RequestConverter RequestConverter // Function to convert request to BifrostRequest (SHOULD NOT BE NIL) + ListModelsResponseConverter ListModelsResponseConverter // Function to convert BifrostListModelsResponse to integration format (SHOULD NOT BE NIL) TextResponseConverter TextResponseConverter // Function to convert BifrostTextCompletionResponse to integration format (SHOULD NOT BE NIL) ChatResponseConverter ChatResponseConverter // Function to convert BifrostChatResponse to integration format (SHOULD NOT BE NIL) ResponsesResponseConverter ResponsesResponseConverter // Function to convert BifrostResponsesResponse to integration format (SHOULD NOT BE NIL) @@ -244,27 +247,37 @@ func NewGenericRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, ro func (g *GenericRouter) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { for _, route := range g.routes { // Validate route configuration at startup to fail fast + method := strings.ToUpper(route.Method) + if route.GetRequestTypeInstance == nil { g.logger.Warn("route configuration is invalid: GetRequestTypeInstance cannot be nil for route " + route.Path) continue } + + // Test that GetRequestTypeInstance returns a valid instance + if testInstance := route.GetRequestTypeInstance(); testInstance == nil { + g.logger.Warn("route configuration is invalid: GetRequestTypeInstance returned nil for route " + route.Path) + continue + } + + // For list models endpoints, verify ListModelsResponseConverter is set + if method == fasthttp.MethodGet && route.ListModelsResponseConverter == nil { + g.logger.Warn("route configuration is invalid: ListModelsResponseConverter cannot be nil for GET route " + route.Path) + continue + } + if route.RequestConverter == nil { g.logger.Warn("route configuration is invalid: RequestConverter cannot be nil for route " + route.Path) continue } + if route.ErrorConverter == nil { g.logger.Warn("route configuration is invalid: ErrorConverter cannot be nil for route " + route.Path) continue } - // Test that GetRequestTypeInstance returns a valid instance - if testInstance := route.GetRequestTypeInstance(); testInstance == nil { - g.logger.Warn("route configuration is invalid: GetRequestTypeInstance returned nil for route " + route.Path) - continue - } - handler := g.createHandler(route) - switch strings.ToUpper(route.Method) { + switch method { case fasthttp.MethodPost: r.POST(route.Path, lib.ChainMiddlewares(handler, middlewares...)) case fasthttp.MethodGet: @@ -289,14 +302,14 @@ func (g *GenericRouter) RegisterRoutes(r *router.Router, middlewares ...lib.Bifr // 6. 
Convert and send the response using the configured response converter func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { + method := string(ctx.Method()) + // Parse request body into the integration-specific request type // Note: config validation is performed at startup in RegisterRoutes req := config.GetRequestTypeInstance() - method := string(ctx.Method()) - // Parse request body based on configuration - if method != fasthttp.MethodGet && method != fasthttp.MethodDelete { + if method != fasthttp.MethodGet { if config.RequestParser != nil { // Use custom parser (e.g., for multipart/form-data) if err := config.RequestParser(ctx, req); err != nil { @@ -372,6 +385,26 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf var err error switch { + case bifrostReq.ListModelsRequest != nil: + listModelsResponse, bifrostErr := g.client.ListModelsRequest(*bifrostCtx, bifrostReq.ListModelsRequest) + if bifrostErr != nil { + g.sendError(ctx, config.ErrorConverter, bifrostErr) + return + } + + if config.PostCallback != nil { + if err := config.PostCallback(ctx, req, listModelsResponse); err != nil { + g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "failed to execute post-request callback")) + return + } + } + + if listModelsResponse == nil { + g.sendError(ctx, config.ErrorConverter, newBifrostError(nil, "Bifrost response is nil after post-request callback")) + return + } + + response, err = config.ListModelsResponseConverter(listModelsResponse) case bifrostReq.TextCompletionRequest != nil: textCompletionResponse, bifrostErr := g.client.TextCompletionRequest(*bifrostCtx, bifrostReq.TextCompletionRequest) if bifrostErr != nil { diff --git a/transports/changelog.md b/transports/changelog.md index 59bf4c7e69..01b3b85c2a 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -1,13 +1,5 @@ -- chore: version update core to 1.2.13 and framework to 1.1.15 -- feat: added headers support for OTel configuration. Value prefixed with env will be fetched from environment variables (env.) 
-- feat: emission of OTel resource spans is completely async - this brings down inference overhead to < 1µsecond -- fix: added latency calculation for vertex native requests -- feat: added cached tokens and reasoning tokens to the usage in ui -- fix: cost calculation for vertex requests -- feat: added global region support for vertex API -- fix: added filter for extra fields in chat completions request for Mistral provider -- fix: added wildcard validation for allowed origins in UI security settings -- fix: fixed code field in pending_safety_checks for Responses API \ No newline at end of file +- chore: version update core to 1.2.14 and framework to 1.1.16 +- feat: added `/v1/models` endpoint to list models of configured providers \ No newline at end of file diff --git a/transports/version b/transports/version index 0c00f61081..17e63e7aff 100644 --- a/transports/version +++ b/transports/version @@ -1 +1 @@ -1.3.10 +1.3.11 diff --git a/ui/app/providers/dialogs/addNewCustomProviderDialog.tsx b/ui/app/providers/dialogs/addNewCustomProviderDialog.tsx index 6f5e3b0a79..2bd3e90942 100644 --- a/ui/app/providers/dialogs/addNewCustomProviderDialog.tsx +++ b/ui/app/providers/dialogs/addNewCustomProviderDialog.tsx @@ -21,6 +21,7 @@ const allowedRequestsSchema = z.object({ speech_stream: z.boolean(), transcription: z.boolean(), transcription_stream: z.boolean(), + list_models: z.boolean(), }); const formSchema = z.object({ @@ -55,6 +56,7 @@ export default function AddCustomProviderDialog({ show, onClose, onSave }: Props speech_stream: true, transcription: true, transcription_stream: true, + list_models: true, }, }, }); diff --git a/ui/app/providers/fragments/allowedRequestsFields.tsx b/ui/app/providers/fragments/allowedRequestsFields.tsx index aa3c71e8b4..ed4ded37af 100644 --- a/ui/app/providers/fragments/allowedRequestsFields.tsx +++ b/ui/app/providers/fragments/allowedRequestsFields.tsx @@ -10,6 +10,7 @@ interface AllowedRequestsFieldsProps { } const REQUEST_TYPES = [ + { key: "list_models", label: "List Models" }, { key: "text_completion", label: "Text Completion" }, { key: "chat_completion", label: "Chat Completion" }, { key: "chat_completion_stream", label: "Chat Completion Stream" }, diff --git a/ui/app/providers/fragments/apiStructureFormFragment.tsx b/ui/app/providers/fragments/apiStructureFormFragment.tsx index 8af2aeebd2..dcb2195bad 100644 --- a/ui/app/providers/fragments/apiStructureFormFragment.tsx +++ b/ui/app/providers/fragments/apiStructureFormFragment.tsx @@ -45,6 +45,7 @@ export function ApiStructureFormFragment({ provider, showRestartAlert }: Props) speech_stream: provider.custom_provider_config?.allowed_requests?.speech_stream ?? true, transcription: provider.custom_provider_config?.allowed_requests?.transcription ?? true, transcription_stream: provider.custom_provider_config?.allowed_requests?.transcription_stream ?? true, + list_models: provider.custom_provider_config?.allowed_requests?.list_models ?? 
true, }, }, }); diff --git a/ui/lib/constants/config.ts b/ui/lib/constants/config.ts index 1df36f1f33..019bee53df 100644 --- a/ui/lib/constants/config.ts +++ b/ui/lib/constants/config.ts @@ -65,6 +65,7 @@ export const DEFAULT_ALLOWED_REQUESTS = { speech_stream: true, transcription: true, transcription_stream: true, + list_models: true, } as const satisfies Required; export const IS_ENTERPRISE = process.env.NEXT_PUBLIC_IS_ENTERPRISE === "true"; diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts index 90ace5fad3..6057461ac3 100644 --- a/ui/lib/types/config.ts +++ b/ui/lib/types/config.ts @@ -120,30 +120,15 @@ export interface AllowedRequests { speech_stream: boolean; transcription: boolean; transcription_stream: boolean; + list_models: boolean; } -export const DefaultAllowedRequests: AllowedRequests = { - text_completion: true, - chat_completion: true, - chat_completion_stream: true, - embedding: true, - speech: true, - speech_stream: true, - transcription: true, - transcription_stream: true, -} as const satisfies Required; - // CustomProviderConfig matching Go's schemas.CustomProviderConfig export interface CustomProviderConfig { base_provider_type: KnownProvider; allowed_requests?: AllowedRequests; } -export const DefaultCustomProviderConfig: CustomProviderConfig = { - base_provider_type: "openai", - allowed_requests: DefaultAllowedRequests, -} as const satisfies Required; - // ProviderConfig matching Go's lib.ProviderConfig export interface ModelProviderConfig { keys: ModelProviderKey[]; diff --git a/ui/lib/types/schemas.ts b/ui/lib/types/schemas.ts index a104c95496..0095665ac8 100644 --- a/ui/lib/types/schemas.ts +++ b/ui/lib/types/schemas.ts @@ -281,6 +281,7 @@ export const allowedRequestsSchema = z.object({ speech_stream: z.boolean(), transcription: z.boolean(), transcription_stream: z.boolean(), + list_models: z.boolean(), }); // Custom provider config schema
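As a usage note for the new transport route registered in `transports/bifrost-http/handlers/inference.go`, here is a hedged sketch of calling the gateway endpoint over plain HTTP. The listen address is an assumption; the query parameter names (`provider`, `page_size`, `page_token`) mirror the handler, any other query parameters are passed through to the provider via `ExtraParams`, and omitting `provider` makes the handler call `ListAllModels` across every configured provider.

```go
package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
)

func main() {
	base := "http://localhost:8080" // assumed gateway address; adjust to your deployment
	params := url.Values{}
	params.Set("provider", "openai") // drop this to list models from all configured providers
	params.Set("page_size", "5")     // "page_token" can be added to fetch subsequent pages

	resp, err := http.Get(base + "/v1/models?" + params.Encode())
	if err != nil {
		log.Fatalf("request failed: %v", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Fatalf("reading response failed: %v", err)
	}
	fmt.Println(resp.Status)
	fmt.Println(string(body)) // JSON-encoded BifrostListModelsResponse
}
```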
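The integration routers expose the same operation under each provider's native path and parameter conventions, per the diffs to `integrations/anthropic.go`, `integrations/genai.go`, and `integrations/openai.go`. A sketch under the same assumed address:

```go
package main

import (
	"fmt"
	"log"
	"net/http"
)

func main() {
	for _, u := range []string{
		// Anthropic-compatible surface: "limit" maps to PageSize, while
		// "before_id"/"after_id" travel to the provider via ExtraParams.
		"http://localhost:8080/anthropic/v1/models?limit=10",
		// Gemini-compatible surface: "pageSize" maps to PageSize and
		// "pageToken" to PageToken.
		"http://localhost:8080/genai/v1beta/models?pageSize=10",
		// OpenAI-compatible surface: the provider defaults to OpenAI unless
		// the Azure deployment pre-hook has already set it.
		"http://localhost:8080/openai/v1/models",
	} {
		resp, err := http.Get(u)
		if err != nil {
			log.Fatalf("GET %s failed: %v", u, err)
		}
		fmt.Printf("%s -> %s\n", u, resp.Status)
		resp.Body.Close()
	}
}
```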