Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 119 additions & 57 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr
},
}
}
if ctx == nil {
ctx = bifrost.ctx
}

request := &schemas.BifrostListModelsRequest{
Provider: req.Provider,
Expand Down Expand Up @@ -258,23 +261,13 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr
baseProvider = config.CustomProviderConfig.BaseProviderType
}

// Get API key for the provider if required
key := schemas.Key{}
if providerRequiresKey(baseProvider) {
key, err = bifrost.selectKeyFromProviderForModel(&ctx, schemas.ListModelsRequest, req.Provider, "", baseProvider)
if err != nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: err.Error(),
Error: err,
},
}
}
keys, err := bifrost.getAllSupportedKeys(&ctx, req.Provider, baseProvider)
if err != nil {
Comment thread
TejasGhatte marked this conversation as resolved.
return nil, newBifrostError(err)
}

response, bifrostErr := executeRequestWithRetries(config, func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
return provider.ListModels(ctx, key, request)
return provider.ListModels(ctx, keys, request)
}, schemas.ListModelsRequest, req.Provider, "")
if bifrostErr != nil {
return nil, bifrostErr
Expand All @@ -285,8 +278,6 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr
// ListAllModels lists all models from all configured providers.
// It accumulates responses from all providers with a limit of 1000 per provider to get all results.
func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
startTime := time.Now()

if request == nil {
request = &schemas.BifrostListModelsRequest{}
}
Expand All @@ -296,64 +287,102 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: "failed to get configured providers",
Message: err.Error(),
Error: err,
},
}
}

// Accumulate all models from all providers
allModels := make([]schemas.Model, 0)
var firstError *schemas.BifrostError
startTime := time.Now()

// Result structure for collecting provider responses
type providerResult struct {
models []schemas.Model
err *schemas.BifrostError
}

results := make(chan providerResult, len(providerKeys))
var wg sync.WaitGroup

// Launch concurrent requests for all providers
for _, providerKey := range providerKeys {
if strings.TrimSpace(string(providerKey)) == "" {
continue
}

// Create request for this provider with limit of 1000
providerRequest := &schemas.BifrostListModelsRequest{
Provider: providerKey,
PageSize: schemas.DefaultPageSize,
}
wg.Add(1)
go func(providerKey schemas.ModelProvider) {
defer wg.Done()

iterations := 0
for {
iterations++
if iterations > schemas.MaxPaginationRequests {
bifrost.logger.Warn(fmt.Sprintf("reached maximum pagination requests (%d) for provider %s", schemas.MaxPaginationRequests, providerKey))
break
providerModels := make([]schemas.Model, 0)
var providerErr *schemas.BifrostError

// Create request for this provider with limit of 1000
providerRequest := &schemas.BifrostListModelsRequest{
Provider: providerKey,
PageSize: schemas.DefaultPageSize,
}

response, bifrostErr := bifrost.ListModelsRequest(ctx, providerRequest)
if bifrostErr != nil {
// Log the error but continue with other providers
// Skip logging "no keys found" and "not supported" errors as they are expected when a provider is not configured
if !strings.Contains(bifrostErr.Error.Message, "no keys found") &&
!strings.Contains(bifrostErr.Error.Message, "not supported") {
bifrost.logger.Warn(fmt.Sprintf("failed to list models for provider %s: %v", providerKey, bifrostErr.Error.Message))
iterations := 0
for {
// check for context cancellation
select {
case <-ctx.Done():
bifrost.logger.Warn(fmt.Sprintf("context cancelled for provider %s", providerKey))
return
default:
}
if firstError == nil {
firstError = bifrostErr

iterations++
if iterations > schemas.MaxPaginationRequests {
bifrost.logger.Warn(fmt.Sprintf("reached maximum pagination requests (%d) for provider %s, please increase the page size", schemas.MaxPaginationRequests, providerKey))
break
}
break
}

if response == nil {
break
}
response, bifrostErr := bifrost.ListModelsRequest(ctx, providerRequest)
if bifrostErr != nil {
// Skip logging "no keys found" and "not supported" errors as they are expected when a provider is not configured
if !strings.Contains(bifrostErr.Error.Message, "no keys found") &&
!strings.Contains(bifrostErr.Error.Message, "not supported") {
providerErr = bifrostErr
bifrost.logger.Warn(fmt.Sprintf("failed to list models for provider %s: %v", providerKey, bifrostErr.Error.Message))
}
break
}

if len(response.Data) > 0 {
allModels = append(allModels, response.Data...)
}
if response == nil || len(response.Data) == 0 {
break
}

// Check if there are more pages
if response.NextPageToken == "" {
break
providerModels = append(providerModels, response.Data...)

// Check if there are more pages
if response.NextPageToken == "" {
break
}

// Set the page token for the next request
providerRequest.PageToken = response.NextPageToken
}

// Set the page token for the next request
providerRequest.PageToken = response.NextPageToken
results <- providerResult{models: providerModels, err: providerErr}
}(providerKey)
}

// Wait for all goroutines to complete
wg.Wait()
close(results)

// Accumulate all models from all providers
allModels := make([]schemas.Model, 0)
var firstError *schemas.BifrostError

for result := range results {
if len(result.models) > 0 {
allModels = append(allModels, result.models...)
}
if result.err != nil && firstError == nil {
firstError = result.err
}
}

Expand All @@ -367,15 +396,12 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr
return allModels[i].ID < allModels[j].ID
})

// Calculate total elapsed time
elapsedTime := time.Since(startTime).Milliseconds()

// Return aggregated response with accumulated latency
response := &schemas.BifrostListModelsResponse{
Data: allModels,
ExtraFields: schemas.BifrostResponseExtraFields{
RequestType: schemas.ListModelsRequest,
Latency: elapsedTime,
Latency: time.Since(startTime).Milliseconds(),
},
}

Expand Down Expand Up @@ -2324,6 +2350,42 @@ func (bifrost *Bifrost) releaseBifrostRequest(req *schemas.BifrostRequest) {
bifrost.bifrostRequestPool.Put(req)
}

// getAllSupportedKeys retrieves all valid keys for a ListModels request.
// allowing the provider to aggregate results from multiple keys.
func (bifrost *Bifrost) getAllSupportedKeys(ctx *context.Context, providerKey schemas.ModelProvider, baseProviderType schemas.ModelProvider) ([]schemas.Key, error) {
// Check if key has been set in the context explicitly
if ctx != nil {
key, ok := (*ctx).Value(schemas.BifrostContextKeyDirectKey).(schemas.Key)
if ok {
// If a direct key is specified, return it as a single-element slice
return []schemas.Key{key}, nil
}
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

keys, err := bifrost.account.GetKeysForProvider(ctx, providerKey)
if err != nil {
return nil, err
}

if len(keys) == 0 {
return nil, fmt.Errorf("no keys found for provider: %v", providerKey)
}

// Filter keys for ListModels - only check if key has a value
var supportedKeys []schemas.Key
for _, k := range keys {
if strings.TrimSpace(k.Value) != "" || canProviderKeyValueBeEmpty(baseProviderType) {
supportedKeys = append(supportedKeys, k)
}
}

if len(supportedKeys) == 0 {
return nil, fmt.Errorf("no valid keys found for provider: %v", providerKey)
}

return supportedKeys, nil
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model.
// It uses weighted random selection if multiple keys are available.
func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) (schemas.Key, error) {
Expand Down
69 changes: 36 additions & 33 deletions core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,6 @@ func (provider *AnthropicProvider) GetProviderKey() schemas.ModelProvider {
return getProviderName(schemas.Anthropic, provider.customProviderConfig)
}

// parseStreamAnthropicError parses Anthropic streaming error responses.
func parseStreamAnthropicError(resp *http.Response, providerType schemas.ModelProvider) *schemas.BifrostError {
statusCode := resp.StatusCode
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()

var errorResp anthropic.AnthropicError
if err := sonic.Unmarshal(body, &errorResp); err != nil {
return newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerType)
}

return newProviderAPIError(errorResp.Error.Message, nil, statusCode, providerType, &errorResp.Error.Type, nil)
}

// completeRequest sends a request to Anthropic's API and handles the response.
// It constructs the API URL, sets up authentication, and processes the response.
// Returns the response body or an error if the request fails.
Expand Down Expand Up @@ -188,14 +174,9 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB
return bodyCopy, latency, nil
}

// ListModels performs a list models request to Anthropic's API.
func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ListModelsRequest); err != nil {
return nil, err
}

providerName := provider.GetProviderKey()

// listModelsByKey performs a list models request for a single key.
// Returns the response and latency, or an error if the request fails.
func (provider *AnthropicProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
// Create request
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
Expand All @@ -206,8 +187,7 @@ func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.K
setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil)

// Build URL using centralized URL construction
requestURL := anthropic.ToAnthropicListModelsURL(request, provider.networkConfig.BaseURL+"/v1/models")
req.SetRequestURI(requestURL)
req.SetRequestURI(fmt.Sprintf("%s/v1/models?limit=%d", provider.networkConfig.BaseURL, schemas.DefaultPageSize))
req.Header.SetMethod(http.MethodGet)
req.Header.SetContentType("application/json")
req.Header.Set("x-api-key", key.Value)
Expand All @@ -221,14 +201,10 @@ func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.K

// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body())))

var errorResp anthropic.AnthropicError

bifrostErr := handleProviderAPIError(resp, &errorResp)
bifrostErr.Error.Type = &errorResp.Error.Type
bifrostErr.Error.Message = errorResp.Error.Message

return nil, bifrostErr
}

Expand All @@ -240,11 +216,7 @@ func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.K
}

// Create final response
response := anthropicResponse.ToBifrostListModelsResponse(providerName)

// Set ExtraFields
response.ExtraFields.Provider = providerName
response.ExtraFields.RequestType = schemas.ListModelsRequest
response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey())
response.ExtraFields.Latency = latency.Milliseconds()

// Set raw response if enabled
Expand All @@ -255,6 +227,23 @@ func (provider *AnthropicProvider) ListModels(ctx context.Context, key schemas.K
return response, nil
}

// ListModels performs a list models request to Anthropic's API.
// It fetches models using all provided keys and aggregates the results.
// Uses a best-effort approach: continues with remaining keys even if some fail.
// Requests are made concurrently for improved performance.
func (provider *AnthropicProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ListModelsRequest); err != nil {
return nil, err
}
return handleMultipleListModelsRequests(
ctx,
keys,
request,
provider.listModelsByKey,
provider.logger,
)
}

// TextCompletion performs a text completion request to Anthropic's API.
// It formats the request, sends it to Anthropic, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
Expand Down Expand Up @@ -852,3 +841,17 @@ func (provider *AnthropicProvider) Transcription(ctx context.Context, key schema
func (provider *AnthropicProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) {
return nil, newUnsupportedOperationError("transcription stream", "anthropic")
}

// parseStreamAnthropicError parses Anthropic streaming error responses.
func parseStreamAnthropicError(resp *http.Response, providerType schemas.ModelProvider) *schemas.BifrostError {
statusCode := resp.StatusCode
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()

var errorResp anthropic.AnthropicError
if err := sonic.Unmarshal(body, &errorResp); err != nil {
return newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerType)
}

return newProviderAPIError(errorResp.Error.Message, nil, statusCode, providerType, &errorResp.Error.Type, nil)
}
Loading