diff --git a/core/bifrost.go b/core/bifrost.go index 7db6663758..0cadac7718 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -23,6 +23,7 @@ import ( "github.com/maximhq/bifrost/core/providers/azure" "github.com/maximhq/bifrost/core/providers/bedrock" "github.com/maximhq/bifrost/core/providers/cerebras" + "github.com/maximhq/bifrost/core/providers/codex" "github.com/maximhq/bifrost/core/providers/cohere" "github.com/maximhq/bifrost/core/providers/elevenlabs" "github.com/maximhq/bifrost/core/providers/fireworks" @@ -3636,6 +3637,8 @@ func (bifrost *Bifrost) createBaseProvider(providerKey schemas.ModelProvider, co switch targetProviderKey { case schemas.OpenAI: return openai.NewOpenAIProvider(config, bifrost.logger), nil + case schemas.Codex: + return codex.NewCodexProvider(config, bifrost.logger) case schemas.Anthropic: return anthropic.NewAnthropicProvider(config, bifrost.logger), nil case schemas.Bedrock: diff --git a/core/providers/codex/auth.go b/core/providers/codex/auth.go new file mode 100644 index 0000000000..75d2881e94 --- /dev/null +++ b/core/providers/codex/auth.go @@ -0,0 +1,212 @@ +package codex + +import ( + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/bytedance/sonic" +) + +const ( + OAuthClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + OAuthIssuer = "https://auth.openai.com" + DeviceVerificationURL = OAuthIssuer + "/codex/device" + deviceCallbackRedirect = OAuthIssuer + "/deviceauth/callback" + defaultPollingMarginSecs = 3 +) + +type TokenResponse struct { + IDToken string `json:"id_token"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type DeviceAuthorizationResponse struct { + DeviceAuthID string `json:"device_auth_id"` + UserCode string `json:"user_code"` + Interval string `json:"interval"` +} + +type DeviceTokenResponse struct { + AuthorizationCode string `json:"authorization_code"` + CodeVerifier string `json:"code_verifier"` +} + +type IDTokenClaims struct { + ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"` + Organizations []struct { + ID string `json:"id"` + } `json:"organizations,omitempty"` + OpenAIAuth *struct { + ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"` + } `json:"https://api.openai.com/auth,omitempty"` +} + +func RefreshAccessToken(ctx context.Context, client *http.Client, refreshToken string) (*TokenResponse, error) { + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("refresh_token", refreshToken) + form.Set("client_id", OAuthClientID) + return executeTokenRequest(ctx, client, OAuthIssuer+"/oauth/token", strings.NewReader(form.Encode())) +} + +func StartDeviceAuthorization(ctx context.Context, client *http.Client, userAgent string) (*DeviceAuthorizationResponse, error) { + requestBody, err := sonic.Marshal(map[string]string{"client_id": OAuthClientID}) + if err != nil { + return nil, err + } + request, err := http.NewRequestWithContext(ctx, http.MethodPost, OAuthIssuer+"/api/accounts/deviceauth/usercode", strings.NewReader(string(requestBody))) + if err != nil { + return nil, err + } + request.Header.Set("Content-Type", "application/json") + if userAgent != "" { + request.Header.Set("User-Agent", userAgent) + } + response, err := client.Do(request) + if err != nil { + return nil, err + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("device authorization failed with status %d", response.StatusCode) + } + var result DeviceAuthorizationResponse + if err := sonic.ConfigDefault.NewDecoder(response.Body).Decode(&result); err != nil { + return nil, err + } + return &result, nil +} + +func PollDeviceAuthorization(ctx context.Context, client *http.Client, deviceAuthID, userCode, userAgent string) (*DeviceTokenResponse, int, error) { + requestBody, err := sonic.Marshal(map[string]string{"device_auth_id": deviceAuthID, "user_code": userCode}) + if err != nil { + return nil, 0, err + } + request, err := http.NewRequestWithContext(ctx, http.MethodPost, OAuthIssuer+"/api/accounts/deviceauth/token", strings.NewReader(string(requestBody))) + if err != nil { + return nil, 0, err + } + request.Header.Set("Content-Type", "application/json") + if userAgent != "" { + request.Header.Set("User-Agent", userAgent) + } + response, err := client.Do(request) + if err != nil { + return nil, 0, err + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return nil, response.StatusCode, nil + } + var result DeviceTokenResponse + if err := sonic.ConfigDefault.NewDecoder(response.Body).Decode(&result); err != nil { + return nil, response.StatusCode, err + } + return &result, response.StatusCode, nil +} + +func ExchangeDeviceAuthorizationCode(ctx context.Context, client *http.Client, code, codeVerifier string) (*TokenResponse, error) { + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("code", code) + form.Set("redirect_uri", deviceCallbackRedirect) + form.Set("client_id", OAuthClientID) + form.Set("code_verifier", codeVerifier) + return executeTokenRequest(ctx, client, OAuthIssuer+"/oauth/token", strings.NewReader(form.Encode())) +} + +func ExtractAccountID(tokens *TokenResponse) string { + if tokens == nil { + return "" + } + for _, candidate := range []string{tokens.IDToken, tokens.AccessToken} { + claims := parseJWTClaims(candidate) + if claims == nil { + continue + } + if claims.ChatGPTAccountID != "" { + return claims.ChatGPTAccountID + } + if claims.OpenAIAuth != nil && claims.OpenAIAuth.ChatGPTAccountID != "" { + return claims.OpenAIAuth.ChatGPTAccountID + } + if len(claims.Organizations) > 0 && claims.Organizations[0].ID != "" { + return claims.Organizations[0].ID + } + } + return "" +} + +func ExpiresAtFromNow(expiresIn int) string { + if expiresIn <= 0 { + expiresIn = 3600 + } + return time.Now().Add(time.Duration(expiresIn) * time.Second).UTC().Format(time.RFC3339) +} + +func NextPollTime(intervalSeconds int) time.Time { + if intervalSeconds <= 0 { + intervalSeconds = 5 + } + return time.Now().Add(time.Duration(intervalSeconds+defaultPollingMarginSecs) * time.Second) +} + +func executeTokenRequest(ctx context.Context, client *http.Client, endpoint string, body *strings.Reader) (*TokenResponse, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body) + if err != nil { + return nil, err + } + request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + response, err := client.Do(request) + if err != nil { + return nil, err + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token request failed with status %d", response.StatusCode) + } + var tokenResponse TokenResponse + if err := sonic.ConfigDefault.NewDecoder(response.Body).Decode(&tokenResponse); err != nil { + return nil, err + } + return &tokenResponse, nil +} + +func generateRandomString(length int) (string, error) { + const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" + bytes := make([]byte, length) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + result := make([]byte, length) + for i, value := range bytes { + result[i] = chars[int(value)%len(chars)] + } + return string(result), nil +} + +func parseJWTClaims(token string) *IDTokenClaims { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return nil + } + decoded, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil + } + var claims IDTokenClaims + if err := sonic.Unmarshal(decoded, &claims); err != nil { + return nil + } + return &claims +} diff --git a/core/providers/codex/auth_test.go b/core/providers/codex/auth_test.go new file mode 100644 index 0000000000..b51a4d5f2f --- /dev/null +++ b/core/providers/codex/auth_test.go @@ -0,0 +1,19 @@ +package codex + +import "testing" + +func TestExtractAccountID(t *testing.T) { + token := &TokenResponse{ + AccessToken: "eyJhbGciOiJub25lIn0.eyJodHRwczovL2FwaS5vcGVuYWkuY29tL2F1dGgiOnsiY2hhdGdwdF9hY2NvdW50X2lkIjoib3JnXzEyMyJ9fQ.", + } + if accountID := ExtractAccountID(token); accountID != "org_123" { + t.Fatalf("expected account id org_123, got %q", accountID) + } +} + +func TestExpiresAtFromNow(t *testing.T) { + value := ExpiresAtFromNow(60) + if value == "" { + t.Fatal("expected non-empty RFC3339 expiry") + } +} diff --git a/core/providers/codex/codex.go b/core/providers/codex/codex.go new file mode 100644 index 0000000000..4cd2f2ae4d --- /dev/null +++ b/core/providers/codex/codex.go @@ -0,0 +1,919 @@ +package codex + +import ( + "context" + "fmt" + "maps" + "net/http" + "slices" + "strings" + "time" + + "github.com/google/uuid" + "github.com/maximhq/bifrost/core/providers/openai" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" + "golang.org/x/sync/singleflight" +) + +const defaultBaseURL = "https://chatgpt.com/backend-api/codex" +const defaultInstructions = "You are Codex, a helpful coding assistant. Follow the user's request exactly." + +var defaultModels = []string{ + "gpt-5.1-codex", + "gpt-5.1-codex-max", + "gpt-5.1-codex-mini", + "gpt-5.2", + "gpt-5.2-codex", + "gpt-5.3-codex", + "gpt-5.4", + "gpt-5.4-mini", +} + +type CodexProvider struct { + logger schemas.Logger + client *fasthttp.Client + authHTTPClient *http.Client + refreshGroup singleflight.Group + networkConfig schemas.NetworkConfig + sendBackRawRequest bool + sendBackRawResponse bool +} + +func NewCodexProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*CodexProvider, error) { + config.CheckAndSetDefaults() + + requestTimeout := time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds) + client := &fasthttp.Client{ + ReadTimeout: requestTimeout, + WriteTimeout: requestTimeout, + MaxConnsPerHost: config.NetworkConfig.MaxConnsPerHost, + MaxIdleConnDuration: 30 * time.Second, + MaxConnWaitTimeout: requestTimeout, + MaxConnDuration: time.Second * time.Duration(schemas.DefaultMaxConnDurationInSeconds), + ConnPoolStrategy: fasthttp.FIFO, + } + + client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger) + client = providerUtils.ConfigureDialer(client) + client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) + + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = defaultBaseURL + } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + + return &CodexProvider{ + logger: logger, + client: client, + authHTTPClient: &http.Client{Timeout: 20 * time.Second}, + networkConfig: config.NetworkConfig, + sendBackRawRequest: config.SendBackRawRequest, + sendBackRawResponse: config.SendBackRawResponse, + }, nil +} + +func (provider *CodexProvider) GetProviderKey() schemas.ModelProvider { + return schemas.Codex +} + +func (provider *CodexProvider) buildRequestURL(path string) string { + return provider.networkConfig.BaseURL + path +} + +func (provider *CodexProvider) authHeaders(ctx *schemas.BifrostContext, key schemas.Key) (map[string]string, *schemas.BifrostError) { + headers := map[string]string{} + if key.CodexKeyConfig != nil { + if key.CodexKeyConfig.AccessToken != nil { + if token := strings.TrimSpace(key.CodexKeyConfig.AccessToken.GetValue()); token != "" && !accessTokenExpired(key.CodexKeyConfig.AccessTokenExpiresAt) { + headers["Authorization"] = "Bearer " + token + } + } + if _, ok := headers["Authorization"]; !ok { + if refreshToken := strings.TrimSpace(key.CodexKeyConfig.RefreshToken.GetValue()); refreshToken != "" { + requestCtx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + refreshKey := key.ID + if refreshKey == "" { + refreshKey = refreshToken + } + result, err, _ := provider.refreshGroup.Do(refreshKey, func() (interface{}, error) { + return RefreshAccessToken(requestCtx, provider.authHTTPClient, refreshToken) + }) + if err != nil { + statusCode := http.StatusBadGateway + message := fmt.Sprintf("failed to refresh Codex access token: %v", err) + return nil, &schemas.BifrostError{IsBifrostError: true, StatusCode: &statusCode, Error: &schemas.ErrorField{Message: message}, ExtraFields: schemas.BifrostErrorExtraFields{Provider: provider.GetProviderKey()}} + } + tokens := result.(*TokenResponse) + headers["Authorization"] = "Bearer " + tokens.AccessToken + accountID := ExtractAccountID(tokens) + if accountID != "" { + headers["ChatGPT-Account-Id"] = accountID + } + provider.persistRefreshedCredentials(ctx, key, tokens, accountID) + } + } + if key.CodexKeyConfig.AccountID != nil { + if accountID := strings.TrimSpace(key.CodexKeyConfig.AccountID.GetValue()); accountID != "" { + if _, ok := headers["ChatGPT-Account-Id"]; !ok { + headers["ChatGPT-Account-Id"] = accountID + } + } + } + } + if _, ok := headers["Authorization"]; !ok { + if token := strings.TrimSpace(key.Value.GetValue()); token != "" { + headers["Authorization"] = "Bearer " + token + } + } + if _, ok := headers["Authorization"]; !ok { + statusCode := http.StatusUnauthorized + message := "Codex provider requires an authenticated key with a refresh token or access token" + return nil, &schemas.BifrostError{IsBifrostError: true, StatusCode: &statusCode, Error: &schemas.ErrorField{Message: message}, ExtraFields: schemas.BifrostErrorExtraFields{Provider: provider.GetProviderKey()}} + } + if _, ok := headers["User-Agent"]; !ok { + headers["User-Agent"] = "Bifrost Codex Provider" + } + if _, ok := headers["originator"]; !ok { + headers["originator"] = "opencode" + } + if _, ok := headers["session_id"]; !ok { + requestID := uuid.NewString() + if ctx != nil { + if value, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); ok && strings.TrimSpace(value) != "" { + requestID = value + } + } + headers["session_id"] = requestID + } + return headers, nil +} + +func (provider *CodexProvider) persistRefreshedCredentials(ctx *schemas.BifrostContext, key schemas.Key, tokens *TokenResponse, accountID string) { + if ctx == nil || tokens == nil || key.ID == "" { + return + } + persister, ok := ctx.Value(schemas.BifrostContextKeyCodexCredentialPersister).(schemas.CodexCredentialPersister) + if !ok || persister == nil { + return + } + refreshValue := tokens.RefreshToken + if strings.TrimSpace(refreshValue) == "" && key.CodexKeyConfig != nil { + refreshValue = key.CodexKeyConfig.RefreshToken.GetValue() + } + refreshedConfig := &schemas.CodexKeyConfig{ + RefreshToken: *schemas.NewEnvVar(refreshValue), + AccessToken: schemas.NewEnvVar(tokens.AccessToken), + AccessTokenExpiresAt: schemas.Ptr(ExpiresAtFromNow(tokens.ExpiresIn)), + AuthMethod: schemas.CodexAuthMethodDevice, + } + if key.CodexKeyConfig != nil && key.CodexKeyConfig.AuthMethod != "" { + refreshedConfig.AuthMethod = key.CodexKeyConfig.AuthMethod + } + if accountID != "" { + refreshedConfig.AccountID = schemas.NewEnvVar(accountID) + } else if key.CodexKeyConfig != nil && key.CodexKeyConfig.AccountID != nil { + refreshedConfig.AccountID = key.CodexKeyConfig.AccountID + } + if err := persister(key.ID, refreshedConfig); err != nil && provider.logger != nil { + provider.logger.Warn("failed to persist refreshed Codex credentials for key %s: %v", key.ID, err) + } +} + +func (provider *CodexProvider) responsesRequestFromChat(request *schemas.BifrostChatRequest) *schemas.BifrostResponsesRequest { + if len(request.Input) == 0 { + return request.ToResponsesRequest() + } + + instructions := make([]string, 0, 2) + remainingIndex := 0 + for idx, message := range request.Input { + if message.Role == "system" || message.Role == "developer" { + if text := extractChatMessageText(message); strings.TrimSpace(text) != "" { + instructions = append(instructions, text) + remainingIndex = idx + 1 + continue + } + } + break + } + + clone := *request + clone.Input = request.Input[remainingIndex:] + responsesRequest := clone.ToResponsesRequest() + ensureCodexResponseDefaults(nil, responsesRequest) + if len(instructions) > 0 { + if responsesRequest.Params == nil { + responsesRequest.Params = &schemas.ResponsesParameters{} + } + joined := strings.Join(instructions, "\n\n") + responsesRequest.Params.Instructions = &joined + } + return responsesRequest +} + +func (provider *CodexProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + ownedBy := string(provider.GetProviderKey()) + data := make([]schemas.Model, 0, len(defaultModels)) + for _, model := range defaultModels { + modelID := model + data = append(data, schemas.Model{ID: modelID, OwnedBy: &ownedBy}) + } + return &schemas.BifrostListModelsResponse{Data: data}, nil +} + +func (provider *CodexProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + response, err := provider.accumulateResponsesRequest(ctx, key, provider.responsesRequestFromChat(request)) + if err != nil { + return nil, err + } + return response.ToBifrostChatResponse(), nil +} + +func (provider *CodexProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + innerStream, bifrostErr := provider.ResponsesStream(ctx, codexNoOpPostHookRunner, key, provider.responsesRequestFromChat(request)) + if bifrostErr != nil { + return nil, bifrostErr + } + responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) + state := newCodexChatStreamState(ctx, request.Model) + go func() { + defer close(responseChan) + for chunk := range innerStream { + if chunk == nil { + continue + } + if chunk.BifrostError != nil { + bifrostErr := *chunk.BifrostError + bifrostErr.ExtraFields.RequestType = schemas.ChatCompletionStreamRequest + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger) + return + } + if chunk.BifrostResponsesStreamResponse == nil { + continue + } + for _, chatChunk := range state.convert(chunk.BifrostResponsesStreamResponse) { + if chatChunk == nil { + continue + } + if chatChunk.Choices != nil && len(chatChunk.Choices) > 0 && chatChunk.Choices[0].FinishReason != nil { + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + } + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, chatChunk, nil, nil, nil, nil), responseChan) + } + } + }() + return responseChan, nil +} + +func (provider *CodexProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + return provider.accumulateResponsesRequest(ctx, key, request) +} + +func (provider *CodexProvider) sendResponsesRequest(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + ensureCodexResponseDefaults(ctx, request) + normalizeCodexInput(request) + authHeaders, bifrostErr := provider.authHeaders(ctx, key) + if bifrostErr != nil { + return nil, bifrostErr + } + return openai.HandleOpenAIResponsesRequest( + ctx, + provider.client, + provider.buildRequestURL("/responses"), + request, + key, + mergeExtraHeaders(provider.networkConfig.ExtraHeaders, authHeaders), + providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), + providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), + provider.GetProviderKey(), + nil, + provider.parseError, + provider.logger, + ) +} + +func (provider *CodexProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + ensureCodexResponseDefaults(ctx, request) + normalizeCodexInput(request) + authHeaders, bifrostErr := provider.authHeaders(ctx, key) + if bifrostErr != nil { + return nil, bifrostErr + } + return openai.HandleOpenAIResponsesStreaming( + ctx, + provider.client, + provider.buildRequestURL("/responses"), + request, + authHeaders, + provider.networkConfig.ExtraHeaders, + providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), + providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), + provider.GetProviderKey(), + postHookRunner, + nil, + provider.parseError, + nil, + nil, + provider.logger, + ) +} + +func (provider *CodexProvider) parseError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { + bifrostErr := openai.ParseOpenAIError(resp, requestType, providerName, model) + if bifrostErr != nil && bifrostErr.Error != nil && bifrostErr.Error.Message == "provider API error (status 400)" { + if body := strings.TrimSpace(string(resp.Body())); body != "" { + bifrostErr.Error.Message = bifrostErr.Error.Message + ": " + body + } + } + if bifrostErr != nil && bifrostErr.Error != nil && bifrostErr.Error.Code != nil && *bifrostErr.Error.Code == "usage_not_included" { + bifrostErr.Error.Message = bifrostErr.Error.Message + " Visit https://chatgpt.com/#pricing to upgrade your ChatGPT plan for Codex usage." + } + return bifrostErr +} + +func (provider *CodexProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) VideoGeneration(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoGenerationRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) VideoRetrieve(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRetrieveRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) VideoDownload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDownloadRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, provider.GetProviderKey()) +} + +func mergeExtraHeaders(extraHeaders map[string]string, authHeaders map[string]string) map[string]string { + if len(authHeaders) == 0 { + return extraHeaders + } + headers := make(map[string]string, len(extraHeaders)+len(authHeaders)) + if len(extraHeaders) > 0 { + maps.Copy(headers, extraHeaders) + } + maps.Copy(headers, authHeaders) + return headers +} + +func extractChatMessageText(message schemas.ChatMessage) string { + if message.Content == nil { + return "" + } + if message.Content.ContentStr != nil { + return *message.Content.ContentStr + } + parts := make([]string, 0, len(message.Content.ContentBlocks)) + for _, block := range message.Content.ContentBlocks { + if block.Text != nil && strings.TrimSpace(*block.Text) != "" { + parts = append(parts, *block.Text) + } + } + return strings.Join(parts, "\n") +} + +func accessTokenExpired(expiresAt *string) bool { + if expiresAt == nil || strings.TrimSpace(*expiresAt) == "" { + return true + } + parsed, err := time.Parse(time.RFC3339, *expiresAt) + if err != nil { + return true + } + return time.Now().After(parsed.Add(-30 * time.Second)) +} + +func ensureCodexResponseDefaults(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest) { + if request == nil { + return + } + if request.Params == nil { + request.Params = &schemas.ResponsesParameters{} + } + stripCodexUnsupportedParams(request.Params) + store := false + request.Params.Store = &store + if request.Params.Instructions == nil || strings.TrimSpace(*request.Params.Instructions) == "" { + instructions := defaultInstructions + request.Params.Instructions = &instructions + } + if request.Params.PromptCacheKey == nil { + promptCacheKey := uuid.NewString() + if ctx != nil { + if value, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); ok && strings.TrimSpace(value) != "" { + promptCacheKey = value + } + } + request.Params.PromptCacheKey = &promptCacheKey + } +} + +func stripCodexUnsupportedParams(params *schemas.ResponsesParameters) { + if params == nil { + return + } + + // Codex rejects explicit tuning/limit fields that are valid for the public OpenAI API. + // Keep the request working by stripping only the parameters we verified against the upstream endpoint. + params.Temperature = nil + params.TopP = nil + params.MaxOutputTokens = nil + if len(params.ExtraParams) == 0 { + return + } + delete(params.ExtraParams, "temperature") + delete(params.ExtraParams, "top_p") + delete(params.ExtraParams, "max_output_tokens") + delete(params.ExtraParams, "presence_penalty") + delete(params.ExtraParams, "frequency_penalty") +} + +func normalizeCodexInput(request *schemas.BifrostResponsesRequest) { + if request == nil || len(request.Input) == 0 { + return + } + for idx := range request.Input { + message := request.Input[idx] + role := "" + if message.Role != nil { + role = string(*message.Role) + } + if message.Content == nil { + continue + } + if message.Content.ContentStr != nil { + text := strings.TrimSpace(*message.Content.ContentStr) + blockType := codexTextBlockTypeForRole(role) + block := schemas.ResponsesMessageContentBlock{ + Type: blockType, + Text: &text, + } + if blockType == schemas.ResponsesOutputMessageContentTypeText { + block.ResponsesOutputMessageContentText = &schemas.ResponsesOutputMessageContentText{} + } + message.Content = &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{block}, + } + request.Input[idx] = message + continue + } + if len(message.Content.ContentBlocks) == 0 { + continue + } + blocks := make([]schemas.ResponsesMessageContentBlock, 0, len(message.Content.ContentBlocks)) + for _, block := range message.Content.ContentBlocks { + if block.Text != nil { + block.Type = codexTextBlockTypeForRole(role) + if block.Type == schemas.ResponsesOutputMessageContentTypeText && block.ResponsesOutputMessageContentText == nil { + block.ResponsesOutputMessageContentText = &schemas.ResponsesOutputMessageContentText{} + } + } + blocks = append(blocks, block) + } + message.Content = &schemas.ResponsesMessageContent{ContentBlocks: blocks} + request.Input[idx] = message + } +} + +func codexTextBlockTypeForRole(role string) schemas.ResponsesMessageContentBlockType { + switch role { + case string(schemas.ResponsesInputMessageRoleAssistant): + return schemas.ResponsesOutputMessageContentTypeText + default: + return schemas.ResponsesInputMessageContentBlockTypeText + } +} + +func (provider *CodexProvider) accumulateResponsesRequest(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + streamCtx := cloneBifrostContext(ctx) + stream, bifrostErr := provider.ResponsesStream(streamCtx, codexNoOpPostHookRunner, key, request) + if bifrostErr != nil { + return nil, bifrostErr + } + + accumulator := newCodexResponsesAccumulator(request.Model) + for chunk := range stream { + if chunk == nil { + continue + } + if chunk.BifrostError != nil { + return nil, chunk.BifrostError + } + if chunk.BifrostResponsesStreamResponse == nil { + continue + } + accumulator.add(chunk.BifrostResponsesStreamResponse) + if isFinalCodexResponsesChunk(chunk.BifrostResponsesStreamResponse) { + return accumulator.response(), nil + } + } + + statusCode := http.StatusBadGateway + message := "codex stream ended before a final response was accumulated" + return nil, &schemas.BifrostError{IsBifrostError: true, StatusCode: &statusCode, Error: &schemas.ErrorField{Message: message}, ExtraFields: schemas.BifrostErrorExtraFields{Provider: provider.GetProviderKey(), ModelRequested: request.Model, RequestType: schemas.ResponsesRequest}} +} + +func cloneBifrostContext(parent *schemas.BifrostContext) *schemas.BifrostContext { + if parent == nil { + return schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + } + deadline, ok := parent.Deadline() + if !ok { + deadline = schemas.NoDeadline + } + return schemas.NewBifrostContext(parent, deadline) +} + +func isFinalCodexResponsesChunk(resp *schemas.BifrostResponsesStreamResponse) bool { + if resp == nil { + return false + } + switch resp.Type { + case schemas.ResponsesStreamResponseTypeCompleted, schemas.ResponsesStreamResponseTypeFailed, schemas.ResponsesStreamResponseTypeIncomplete: + return true + default: + return false + } +} + +type codexResponsesAccumulator struct { + latest *schemas.BifrostResponsesResponse + itemsBySlot map[int]schemas.ResponsesMessage + model string +} + +type codexChatStreamState struct { + id string + model string + created int + chunkIndex int +} + +func newCodexResponsesAccumulator(model string) *codexResponsesAccumulator { + return &codexResponsesAccumulator{ + itemsBySlot: make(map[int]schemas.ResponsesMessage), + model: model, + } +} + +func newCodexChatStreamState(ctx *schemas.BifrostContext, requestedModel string) *codexChatStreamState { + id := uuid.NewString() + if ctx != nil { + if requestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); ok && strings.TrimSpace(requestID) != "" { + id = requestID + } + } + return &codexChatStreamState{ + id: id, + model: requestedModel, + created: int(time.Now().Unix()), + chunkIndex: -1, + } +} + +func codexNoOpPostHookRunner(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + return result, err +} + +func (s *codexChatStreamState) convert(resp *schemas.BifrostResponsesStreamResponse) []*schemas.BifrostChatResponse { + if resp == nil { + return nil + } + if resp.Response != nil { + if resp.Response.ID != nil && *resp.Response.ID != "" { + s.id = *resp.Response.ID + } + if resp.Response.Model != "" { + s.model = resp.Response.Model + } + if resp.Response.CreatedAt > 0 { + s.created = resp.Response.CreatedAt + } + } + + switch resp.Type { + case schemas.ResponsesStreamResponseTypeOutputItemAdded: + return s.outputItemAdded(resp) + case schemas.ResponsesStreamResponseTypeOutputTextDelta: + if resp.Delta == nil { + return nil + } + return []*schemas.BifrostChatResponse{s.chunk(&schemas.ChatStreamResponseChoiceDelta{Content: resp.Delta}, nil, nil)} + case schemas.ResponsesStreamResponseTypeReasoningSummaryTextDelta: + if resp.Delta == nil { + return nil + } + return []*schemas.BifrostChatResponse{s.chunk(&schemas.ChatStreamResponseChoiceDelta{Reasoning: resp.Delta}, nil, nil)} + case schemas.ResponsesStreamResponseTypeRefusalDelta: + if resp.Refusal == nil { + return nil + } + return []*schemas.BifrostChatResponse{s.chunk(&schemas.ChatStreamResponseChoiceDelta{Refusal: resp.Refusal}, nil, nil)} + case schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta: + if resp.Delta == nil { + return nil + } + toolType := "function" + toolCall := schemas.ChatAssistantMessageToolCall{ + Index: uint16(valueOrZero(resp.OutputIndex)), + Type: &toolType, + ID: resp.ItemID, + Function: schemas.ChatAssistantMessageToolCallFunction{Arguments: *resp.Delta}, + } + return []*schemas.BifrostChatResponse{s.chunk(&schemas.ChatStreamResponseChoiceDelta{ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall}}, nil, nil)} + case schemas.ResponsesStreamResponseTypeCompleted: + finishReason := s.finishReason(resp) + usage := responsesUsageToChatUsage(resp.Response) + return []*schemas.BifrostChatResponse{s.chunk(&schemas.ChatStreamResponseChoiceDelta{}, &finishReason, usage)} + case schemas.ResponsesStreamResponseTypeIncomplete: + finishReason := "length" + usage := responsesUsageToChatUsage(resp.Response) + return []*schemas.BifrostChatResponse{s.chunk(&schemas.ChatStreamResponseChoiceDelta{}, &finishReason, usage)} + case schemas.ResponsesStreamResponseTypeFailed: + finishReason := "stop" + usage := responsesUsageToChatUsage(resp.Response) + return []*schemas.BifrostChatResponse{s.chunk(&schemas.ChatStreamResponseChoiceDelta{}, &finishReason, usage)} + default: + return nil + } +} + +func (s *codexChatStreamState) outputItemAdded(resp *schemas.BifrostResponsesStreamResponse) []*schemas.BifrostChatResponse { + if resp.Item == nil { + return nil + } + responses := make([]*schemas.BifrostChatResponse, 0, 2) + if resp.Item.Role != nil && *resp.Item.Role == schemas.ResponsesInputMessageRoleAssistant { + role := "assistant" + responses = append(responses, s.chunk(&schemas.ChatStreamResponseChoiceDelta{Role: &role}, nil, nil)) + } + if resp.Item.Type != nil && *resp.Item.Type == schemas.ResponsesMessageTypeFunctionCall && resp.Item.ResponsesToolMessage != nil { + toolType := "function" + toolCall := schemas.ChatAssistantMessageToolCall{ + Index: uint16(valueOrZero(resp.OutputIndex)), + Type: &toolType, + ID: resp.Item.CallID, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: resp.Item.Name, + Arguments: "", + }, + } + responses = append(responses, s.chunk(&schemas.ChatStreamResponseChoiceDelta{ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall}}, nil, nil)) + } + return responses +} + +func (s *codexChatStreamState) chunk(delta *schemas.ChatStreamResponseChoiceDelta, finishReason *string, usage *schemas.BifrostLLMUsage) *schemas.BifrostChatResponse { + s.chunkIndex++ + return &schemas.BifrostChatResponse{ + ID: s.id, + Object: "chat.completion.chunk", + Created: s.created, + Model: s.model, + Usage: usage, + Choices: []schemas.BifrostResponseChoice{{ + Index: 0, + ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{Delta: delta}, + FinishReason: finishReason, + }}, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.ChatCompletionStreamRequest, + Provider: schemas.Codex, + ModelRequested: s.model, + ChunkIndex: s.chunkIndex, + }, + } +} + +func (s *codexChatStreamState) finishReason(resp *schemas.BifrostResponsesStreamResponse) string { + if resp != nil && resp.Response != nil { + for _, item := range resp.Response.Output { + if item.Type != nil && *item.Type == schemas.ResponsesMessageTypeFunctionCall { + return "tool_calls" + } + } + } + return "stop" +} + +func responsesUsageToChatUsage(resp *schemas.BifrostResponsesResponse) *schemas.BifrostLLMUsage { + if resp == nil || resp.Usage == nil { + return nil + } + return resp.Usage.ToBifrostLLMUsage() +} + +func valueOrZero(value *int) int { + if value == nil { + return 0 + } + return *value +} + +func (a *codexResponsesAccumulator) add(event *schemas.BifrostResponsesStreamResponse) { + if event == nil { + return + } + if event.Response != nil { + copy := *event.Response + a.latest = © + if copy.Model != "" { + a.model = copy.Model + } + } + if event.Item != nil && event.OutputIndex != nil && event.Type == schemas.ResponsesStreamResponseTypeOutputItemDone { + a.itemsBySlot[*event.OutputIndex] = *event.Item + } +} + +func (a *codexResponsesAccumulator) response() *schemas.BifrostResponsesResponse { + resp := &schemas.BifrostResponsesResponse{ + Object: "response", + Model: a.model, + } + if a.latest != nil { + copy := *a.latest + resp = © + } + if len(a.itemsBySlot) > 0 { + indices := make([]int, 0, len(a.itemsBySlot)) + for index := range a.itemsBySlot { + indices = append(indices, index) + } + slices.Sort(indices) + output := make([]schemas.ResponsesMessage, 0, len(indices)) + for _, index := range indices { + output = append(output, a.itemsBySlot[index]) + } + resp.Output = output + } + if resp.Model == "" { + resp.Model = a.model + } + if resp.Object == "" { + resp.Object = "response" + } + return resp +} + +func (provider *CodexProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey()) +} + +func (provider *CodexProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) +} diff --git a/core/providers/codex/codex_test.go b/core/providers/codex/codex_test.go new file mode 100644 index 0000000000..44fffc52a1 --- /dev/null +++ b/core/providers/codex/codex_test.go @@ -0,0 +1,171 @@ +package codex + +import ( + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +func TestEnsureCodexResponseDefaults_StripsUnsupportedParams(t *testing.T) { + temperature := 0.2 + topP := 0.5 + maxOutputTokens := 32 + promptCacheKey := "keep-me" + store := true + request := &schemas.BifrostResponsesRequest{ + Params: &schemas.ResponsesParameters{ + Temperature: &temperature, + TopP: &topP, + MaxOutputTokens: &maxOutputTokens, + PromptCacheKey: &promptCacheKey, + Store: &store, + ExtraParams: map[string]interface{}{ + "temperature": 0.2, + "top_p": 0.5, + "max_output_tokens": 32, + "presence_penalty": 0, + "frequency_penalty": 0, + "parallel_tool_calls": true, + }, + }, + } + + ensureCodexResponseDefaults(nil, request) + + require.NotNil(t, request.Params) + require.Nil(t, request.Params.Temperature) + require.Nil(t, request.Params.TopP) + require.Nil(t, request.Params.MaxOutputTokens) + require.NotNil(t, request.Params.Store) + require.False(t, *request.Params.Store) + require.NotNil(t, request.Params.PromptCacheKey) + require.Equal(t, promptCacheKey, *request.Params.PromptCacheKey) + require.NotContains(t, request.Params.ExtraParams, "temperature") + require.NotContains(t, request.Params.ExtraParams, "top_p") + require.NotContains(t, request.Params.ExtraParams, "max_output_tokens") + require.NotContains(t, request.Params.ExtraParams, "presence_penalty") + require.NotContains(t, request.Params.ExtraParams, "frequency_penalty") + require.Contains(t, request.Params.ExtraParams, "parallel_tool_calls") +} + +func TestEnsureCodexResponseDefaults_AddsInstructionsWhenMissing(t *testing.T) { + request := &schemas.BifrostResponsesRequest{} + + ensureCodexResponseDefaults(nil, request) + + require.NotNil(t, request.Params) + require.NotNil(t, request.Params.Instructions) + require.Equal(t, defaultInstructions, *request.Params.Instructions) +} + +func TestNormalizeCodexInput_PreservesAssistantOutputText(t *testing.T) { + userRole := schemas.ResponsesInputMessageRoleUser + assistantRole := schemas.ResponsesInputMessageRoleAssistant + userText := "user text" + assistantText := "assistant text" + request := &schemas.BifrostResponsesRequest{ + Input: []schemas.ResponsesMessage{ + { + Role: &userRole, + Content: &schemas.ResponsesMessageContent{ContentStr: &userText}, + }, + { + Role: &assistantRole, + Content: &schemas.ResponsesMessageContent{ContentStr: &assistantText}, + }, + }, + } + + normalizeCodexInput(request) + + require.Len(t, request.Input, 2) + require.Len(t, request.Input[0].Content.ContentBlocks, 1) + require.Equal(t, schemas.ResponsesInputMessageContentBlockTypeText, request.Input[0].Content.ContentBlocks[0].Type) + require.Len(t, request.Input[1].Content.ContentBlocks, 1) + require.Equal(t, schemas.ResponsesOutputMessageContentTypeText, request.Input[1].Content.ContentBlocks[0].Type) + require.NotNil(t, request.Input[1].Content.ContentBlocks[0].ResponsesOutputMessageContentText) +} + +func TestCodexResponsesAccumulator_UsesCompletedItems(t *testing.T) { + assistantRole := schemas.ResponsesInputMessageRoleAssistant + itemType := schemas.ResponsesMessageTypeMessage + text := "hello" + status := "completed" + accumulator := newCodexResponsesAccumulator("gpt-5.4-mini") + accumulator.add(&schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemDone, + OutputIndex: schemas.Ptr(0), + Item: &schemas.ResponsesMessage{ + Type: &itemType, + Role: &assistantRole, + Status: &status, + Content: &schemas.ResponsesMessageContent{ContentBlocks: []schemas.ResponsesMessageContentBlock{{ + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: &text, + }}}, + }, + }) + accumulator.add(&schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeCompleted, + Response: &schemas.BifrostResponsesResponse{Model: "gpt-5.4-mini", Object: "response", Status: schemas.Ptr("completed")}, + }) + + response := accumulator.response() + require.NotNil(t, response) + require.Equal(t, "gpt-5.4-mini", response.Model) + require.Len(t, response.Output, 1) + require.NotNil(t, response.Output[0].Content) + require.Len(t, response.Output[0].Content.ContentBlocks, 1) + require.Equal(t, "hello", *response.Output[0].Content.ContentBlocks[0].Text) +} + +func TestPersistRefreshedCredentials_UsesContextPersister(t *testing.T) { + ctx := schemas.NewBifrostContext(nil, schemas.NoDeadline) + called := false + ctx.SetValue(schemas.BifrostContextKeyCodexCredentialPersister, schemas.CodexCredentialPersister(func(keyID string, keyConfig *schemas.CodexKeyConfig) error { + called = true + require.Equal(t, "key-1", keyID) + require.NotNil(t, keyConfig) + require.Equal(t, "refresh-1", keyConfig.RefreshToken.GetValue()) + require.NotNil(t, keyConfig.AccessToken) + require.Equal(t, "access-1", keyConfig.AccessToken.GetValue()) + require.NotNil(t, keyConfig.AccountID) + require.Equal(t, "acct-1", keyConfig.AccountID.GetValue()) + return nil + })) + provider := &CodexProvider{} + key := schemas.Key{ID: "key-1", CodexKeyConfig: &schemas.CodexKeyConfig{RefreshToken: *schemas.NewEnvVar("refresh-1"), AuthMethod: schemas.CodexAuthMethodDevice}} + + provider.persistRefreshedCredentials(ctx, key, &TokenResponse{AccessToken: "access-1", RefreshToken: "", ExpiresIn: 60}, "acct-1") + + require.True(t, called) +} + +func TestCodexChatStreamState_ConvertsResponsesDeltasToChatChunks(t *testing.T) { + state := newCodexChatStreamState(nil, "codex/gpt-5.4-mini") + assistantRole := schemas.ResponsesInputMessageRoleAssistant + messageType := schemas.ResponsesMessageTypeMessage + text := "hi" + chunks := state.convert(&schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeOutputItemAdded, + Item: &schemas.ResponsesMessage{Role: &assistantRole, Type: &messageType}, + }) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].ChatStreamResponseChoice) + require.Equal(t, "assistant", *chunks[0].Choices[0].ChatStreamResponseChoice.Delta.Role) + + chunks = state.convert(&schemas.BifrostResponsesStreamResponse{Type: schemas.ResponsesStreamResponseTypeOutputTextDelta, Delta: &text}) + require.Len(t, chunks, 1) + require.Equal(t, "hi", *chunks[0].Choices[0].ChatStreamResponseChoice.Delta.Content) + + completed := state.convert(&schemas.BifrostResponsesStreamResponse{ + Type: schemas.ResponsesStreamResponseTypeCompleted, + Response: &schemas.BifrostResponsesResponse{Usage: &schemas.ResponsesResponseUsage{InputTokens: 1, OutputTokens: 2, TotalTokens: 3}}, + }) + require.Len(t, completed, 1) + require.NotNil(t, completed[0].Choices[0].FinishReason) + require.Equal(t, "stop", *completed[0].Choices[0].FinishReason) + require.NotNil(t, completed[0].Usage) + require.Equal(t, 3, completed[0].Usage.TotalTokens) +} diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index 597bf6c239..8257103907 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -70,7 +70,7 @@ func NewMistralProvider(config *schemas.ProviderConfig, logger schemas.Logger) * // GetProviderKey returns the provider identifier for Mistral. func (provider *MistralProvider) GetProviderKey() schemas.ModelProvider { - return schemas.Mistral + return providerUtils.GetProviderName(schemas.Mistral, provider.customProviderConfig) } // listModelsByKey performs a list models request for a single key. @@ -138,6 +138,9 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke // ListModels performs a list models request to Mistral's API. // Requests are made concurrently for improved performance. func (provider *MistralProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.Mistral, provider.customProviderConfig, schemas.ListModelsRequest); err != nil { + return nil, err + } return providerUtils.HandleMultipleListModelsRequests( ctx, keys, @@ -160,11 +163,16 @@ func (provider *MistralProvider) TextCompletionStream(ctx *schemas.BifrostContex // ChatCompletion performs a chat completion request to the Mistral API. func (provider *MistralProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.Mistral, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { + return nil, mistralUnsupportedWithModel(err, request.Model) + } + requestCopy := *request + requestCopy.Provider = schemas.Mistral return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), - request, + &requestCopy, key, provider.networkConfig.ExtraHeaders, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), @@ -181,6 +189,11 @@ func (provider *MistralProvider) ChatCompletion(ctx *schemas.BifrostContext, key // Uses Mistral's OpenAI-compatible streaming format. // Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails. func (provider *MistralProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.Mistral, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { + return nil, mistralUnsupportedWithModel(err, request.Model) + } + requestCopy := *request + requestCopy.Provider = schemas.Mistral var authHeader map[string]string if key.Value.GetValue() != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()} @@ -190,7 +203,7 @@ func (provider *MistralProvider) ChatCompletionStream(ctx *schemas.BifrostContex ctx, provider.client, provider.networkConfig.BaseURL+"/v1/chat/completions", - request, + &requestCopy, authHeader, provider.networkConfig.ExtraHeaders, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), @@ -208,6 +221,9 @@ func (provider *MistralProvider) ChatCompletionStream(ctx *schemas.BifrostContex // Responses performs a responses request to the Mistral API. func (provider *MistralProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.Mistral, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { + return nil, mistralUnsupportedWithModel(err, request.Model) + } chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { return nil, err @@ -223,6 +239,9 @@ func (provider *MistralProvider) Responses(ctx *schemas.BifrostContext, key sche // ResponsesStream performs a streaming responses request to the Mistral API. func (provider *MistralProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.Mistral, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { + return nil, mistralUnsupportedWithModel(err, request.Model) + } ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, @@ -235,6 +254,9 @@ func (provider *MistralProvider) ResponsesStream(ctx *schemas.BifrostContext, po // Embedding generates embeddings for the given input text(s) using the Mistral API. // Supports Mistral's embedding models and returns a BifrostResponse containing the embedding(s). func (provider *MistralProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.Mistral, provider.customProviderConfig, schemas.EmbeddingRequest); err != nil { + return nil, mistralUnsupportedWithModel(err, request.Model) + } // Use the shared embedding request handler return openai.HandleOpenAIEmbeddingRequest( ctx, @@ -264,6 +286,9 @@ func (provider *MistralProvider) Rerank(ctx *schemas.BifrostContext, key schemas // OCR performs an OCR request to the Mistral API. // It sends a JSON request to Mistral's OCR endpoint and returns the extracted content. func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.Mistral, provider.customProviderConfig, schemas.OCRRequest); err != nil { + return nil, mistralUnsupportedWithModel(err, request.Model) + } providerName := provider.GetProviderKey() // Convert Bifrost request to Mistral format @@ -387,6 +412,9 @@ func (provider *MistralProvider) SpeechStream(ctx *schemas.BifrostContext, postH // It creates a multipart form with the audio file and sends it to Mistral's transcription endpoint. // Returns the transcribed text and metadata, or an error if the request fails. func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.Mistral, provider.customProviderConfig, schemas.TranscriptionRequest); err != nil { + return nil, mistralUnsupportedWithModel(err, request.Model) + } providerName := provider.GetProviderKey() // Convert Bifrost request to Mistral format @@ -492,6 +520,9 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key // It creates a multipart form with the audio file and streams transcription events. // Returns a channel of BifrostStreamChunk objects containing transcription deltas. func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.Mistral, provider.customProviderConfig, schemas.TranscriptionStreamRequest); err != nil { + return nil, mistralUnsupportedWithModel(err, request.Model) + } providerName := provider.GetProviderKey() // Convert Bifrost request to Mistral format @@ -896,3 +927,11 @@ func (provider *MistralProvider) Passthrough(_ *schemas.BifrostContext, _ schema func (provider *MistralProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey()) } + +func mistralUnsupportedWithModel(err *schemas.BifrostError, model string) *schemas.BifrostError { + if err == nil || model == "" || err.ExtraFields.ModelRequested != "" { + return err + } + err.ExtraFields.ModelRequested = model + return err +} diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index 209cd5b1b2..b0649efef4 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -1116,7 +1116,7 @@ func HandleOpenAIChatCompletionStreaming( // Skip scanner for non-SSE responses — avoids bufio.Scanner buffer bloat // on non-line-delimited data (e.g. provider returned JSON instead of SSE). - if providerUtils.DrainNonSSEStreamResponse(resp) { + if providerName != schemas.Codex && providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, streamRequestType, providerName, request.Model, logger) return @@ -1718,7 +1718,7 @@ func HandleOpenAIResponsesStreaming( // Skip scanner for non-SSE responses — avoids bufio.Scanner buffer bloat // on non-line-delimited data (e.g. provider returned JSON instead of SSE). - if providerUtils.DrainNonSSEStreamResponse(resp) { + if providerName != schemas.Codex && providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, logger) return diff --git a/core/schemas/account.go b/core/schemas/account.go index ceaeb2de8a..fdc0789fe2 100644 --- a/core/schemas/account.go +++ b/core/schemas/account.go @@ -25,6 +25,7 @@ type Key struct { HuggingFaceKeyConfig *HuggingFaceKeyConfig `json:"huggingface_key_config,omitempty"` // Hugging Face-specific key configuration ReplicateKeyConfig *ReplicateKeyConfig `json:"replicate_key_config,omitempty"` // Replicate-specific key configuration VLLMKeyConfig *VLLMKeyConfig `json:"vllm_key_config,omitempty"` // vLLM-specific key configuration + CodexKeyConfig *CodexKeyConfig `json:"codex_key_config,omitempty"` // Codex subscription-specific key configuration Enabled *bool `json:"enabled,omitempty"` // Whether the key is active (default:true) UseForBatchAPI *bool `json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations (default:false for new keys, migrated keys default to true) ConfigHash string `json:"config_hash,omitempty"` // Hash of config.json version, used for change detection @@ -113,6 +114,27 @@ type VLLMKeyConfig struct { ModelName string `json:"model_name"` // Exact model name served on this VLLM instance (used for key selection) } +type CodexAuthMethod string + +const ( + CodexAuthMethodDevice CodexAuthMethod = "device" + CodexAuthMethodManual CodexAuthMethod = "manual" +) + +// CodexKeyConfig holds ChatGPT Plus/Pro subscription credentials used by the Codex provider. +// Refresh tokens are the durable credential; access tokens are optional cached tokens for request reuse. +type CodexKeyConfig struct { + RefreshToken EnvVar `json:"refresh_token"` + AccessToken *EnvVar `json:"access_token,omitempty"` + AccessTokenExpiresAt *string `json:"access_token_expires_at,omitempty"` + AccountID *EnvVar `json:"account_id,omitempty"` + AuthMethod CodexAuthMethod `json:"auth_method,omitempty"` +} + +// CodexCredentialPersister persists refreshed Codex credentials for a stored key. +// It is injected by transports that can update persistent provider configuration. +type CodexCredentialPersister func(keyID string, keyConfig *CodexKeyConfig) error + // Account defines the interface for managing provider accounts and their configurations. // It provides methods to access provider-specific settings, API keys, and configurations. type Account interface { diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index c00c5c7078..901a7291c1 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -58,6 +58,7 @@ const ( VLLM ModelProvider = "vllm" Runway ModelProvider = "runway" Fireworks ModelProvider = "fireworks" + Codex ModelProvider = "codex" ) // SupportedBaseProviders is the list of base providers allowed for custom providers. @@ -96,6 +97,7 @@ var StandardProviders = []ModelProvider{ VLLM, Runway, Fireworks, + Codex, } // RequestType represents the type of request being made to a provider. @@ -205,7 +207,7 @@ const ( BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func() (callback to complete trace after streaming - set by tracing middleware) BifrostContextKeyPostHookSpanFinalizer BifrostContextKey = "bifrost-posthook-span-finalizer" // func(context.Context) (callback to finalize post-hook spans after streaming - set by bifrost) BifrostContextKeyAccumulatorID BifrostContextKey = "bifrost-accumulator-id" // string (ID for streaming accumulator lookup - set by tracer for accumulator operations) - BifrostContextKeyHasEmittedMessageDelta BifrostContextKey = "bifrost-has-emitted-message-delta" // bool (tracks whether message_delta was already emitted during streaming - avoids duplicates) + BifrostContextKeyHasEmittedMessageDelta BifrostContextKey = "bifrost-has-emitted-message-delta" // bool (tracks whether message_delta was already emitted during streaming - avoids duplicates) BifrostContextKeySkipDBUpdate BifrostContextKey = "bifrost-skip-db-update" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyGovernancePluginName BifrostContextKey = "governance-plugin-name" // string (name of the governance plugin that processed the request - set by bifrost) BifrostContextKeyIsEnterprise BifrostContextKey = "is-enterprise" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) @@ -215,6 +217,7 @@ const ( BifrostContextKeyIsCustomProvider BifrostContextKey = "bifrost-is-custom-provider" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyHTTPRequestType BifrostContextKey = "bifrost-http-request-type" // RequestType (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyPassthroughExtraParams BifrostContextKey = "bifrost-passthrough-extra-params" // bool + BifrostContextKeyCodexCredentialPersister BifrostContextKey = "bifrost-codex-credential-persister" // CodexCredentialPersister (set by bifrost transport for persisting refreshed Codex credentials) BifrostContextKeyRoutingEnginesUsed BifrostContextKey = "bifrost-routing-engines-used" // []string (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engines used ("routing-rule", "governance", "loadbalancing", etc.) BifrostContextKeyRoutingEngineLogs BifrostContextKey = "bifrost-routing-engine-logs" // []RoutingEngineLogEntry (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engine log entries BifrostContextKeySkipPluginPipeline BifrostContextKey = "bifrost-skip-plugin-pipeline" // bool - skip plugin pipeline for the request diff --git a/core/schemas/context.go b/core/schemas/context.go index 1ff4663eae..b61ef0bb3d 100644 --- a/core/schemas/context.go +++ b/core/schemas/context.go @@ -24,6 +24,7 @@ var reservedKeys = []any{ BifrostContextKeySkipKeySelection, BifrostContextKeyURLPath, BifrostContextKeyDeferTraceCompletion, + BifrostContextKeyCodexCredentialPersister, } // BifrostContext is a custom context.Context implementation that tracks user-set values. diff --git a/core/schemas/provider.go b/core/schemas/provider.go index 6fe0615b06..61a76d536f 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -8,18 +8,18 @@ import ( ) const ( - DefaultMaxRetries = 0 - DefaultRetryBackoffInitial = 500 * time.Millisecond - DefaultRetryBackoffMax = 5 * time.Second + DefaultMaxRetries = 0 + DefaultRetryBackoffInitial = 500 * time.Millisecond + DefaultRetryBackoffMax = 5 * time.Second DefaultRequestTimeoutInSeconds = 30 - DefaultMaxConnDurationInSeconds = 300 // 5 minutes — forces connection recycling to prevent stale connections from NAT/LB silent drops - DefaultBufferSize = 5000 - DefaultConcurrency = 1000 - DefaultStreamBufferSize = 256 - DefaultStreamIdleTimeoutInSeconds = 60 // Idle timeout per stream chunk — if no data for this many seconds, bifrost closes the connection - DefaultMaxConnsPerHost = 5000 - MaxConnsPerHostUpperBound = 10000 - DefaultMaxIdleConnsPerHost = 40 + DefaultMaxConnDurationInSeconds = 300 // 5 minutes — forces connection recycling to prevent stale connections from NAT/LB silent drops + DefaultBufferSize = 5000 + DefaultConcurrency = 1000 + DefaultStreamBufferSize = 256 + DefaultStreamIdleTimeoutInSeconds = 60 // Idle timeout per stream chunk — if no data for this many seconds, bifrost closes the connection + DefaultMaxConnsPerHost = 5000 + MaxConnsPerHostUpperBound = 10000 + DefaultMaxIdleConnsPerHost = 40 ) // Pre-defined errors for provider operations @@ -52,18 +52,18 @@ const ( // - When marshaling to JSON: a time.Duration is converted to milliseconds type NetworkConfig struct { // BaseURL is supported for OpenAI, Anthropic, Cohere, Mistral, and Ollama providers (required for Ollama) - BaseURL string `json:"base_url,omitempty"` // Base URL for the provider (optional) - ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Additional headers to include in requests (optional) - DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` // Default timeout for requests - MaxRetries int `json:"max_retries"` // Maximum number of retries - RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration (stored as nanoseconds, JSON as milliseconds) - RetryBackoffMax time.Duration `json:"retry_backoff_max"` // Maximum backoff duration (stored as nanoseconds, JSON as milliseconds) - InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` // Disables TLS certificate verification for provider connections - CACertPEM string `json:"ca_cert_pem,omitempty"` // PEM-encoded CA certificate to trust for provider endpoint connections + BaseURL string `json:"base_url,omitempty"` // Base URL for the provider (optional) + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Additional headers to include in requests (optional) + DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` // Default timeout for requests + MaxRetries int `json:"max_retries"` // Maximum number of retries + RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration (stored as nanoseconds, JSON as milliseconds) + RetryBackoffMax time.Duration `json:"retry_backoff_max"` // Maximum backoff duration (stored as nanoseconds, JSON as milliseconds) + InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` // Disables TLS certificate verification for provider connections + CACertPEM string `json:"ca_cert_pem,omitempty"` // PEM-encoded CA certificate to trust for provider endpoint connections StreamIdleTimeoutInSeconds int `json:"stream_idle_timeout_in_seconds,omitempty"` // Idle timeout per stream chunk (0 = use default 60s) - MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // Max TCP connections per provider host (default: 5000) - EnforceHTTP2 bool `json:"enforce_http2,omitempty"` // Force HTTP/2 on provider connections (relevant for net/http-based providers like Bedrock) - BetaHeaderOverrides map[string]bool `json:"beta_header_overrides,omitempty"` // Override default beta header support per provider (keys are prefixes like "redact-thinking-") + MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // Max TCP connections per provider host (default: 5000) + EnforceHTTP2 bool `json:"enforce_http2,omitempty"` // Force HTTP/2 on provider connections (relevant for net/http-based providers like Bedrock) + BetaHeaderOverrides map[string]bool `json:"beta_header_overrides,omitempty"` // Override default beta header support per provider (keys are prefixes like "redact-thinking-") } // UnmarshalJSON customizes JSON unmarshaling for NetworkConfig. @@ -485,14 +485,14 @@ type ProviderConfig struct { NetworkConfig NetworkConfig `json:"network_config"` // Network configuration ConcurrencyAndBufferSize ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings // Logger instance, can be provided by the user or bifrost default logger is used if not provided - Logger Logger `json:"-"` - ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration - SendBackRawRequest bool `json:"send_back_raw_request"` // Send raw request back in the bifrost response (default: false) - SendBackRawResponse bool `json:"send_back_raw_response"` // Send raw response back in the bifrost response (default: false) - StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false) - CustomProviderConfig *CustomProviderConfig `json:"custom_provider_config,omitempty"` - OpenAIConfig *OpenAIConfig `json:"openai_config,omitempty"` - PricingOverrides []ProviderPricingOverride `json:"pricing_overrides,omitempty"` + Logger Logger `json:"-"` + ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration + SendBackRawRequest bool `json:"send_back_raw_request"` // Send raw request back in the bifrost response (default: false) + SendBackRawResponse bool `json:"send_back_raw_response"` // Send raw response back in the bifrost response (default: false) + StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false) + CustomProviderConfig *CustomProviderConfig `json:"custom_provider_config,omitempty"` + OpenAIConfig *OpenAIConfig `json:"openai_config,omitempty"` + PricingOverrides []ProviderPricingOverride `json:"pricing_overrides,omitempty"` } // OpenAIConfig holds OpenAI-specific provider configuration. diff --git a/core/utils.go b/core/utils.go index ed8f40ebf4..85590e3ad2 100644 --- a/core/utils.go +++ b/core/utils.go @@ -98,7 +98,7 @@ func providerRequiresKey(providerKey schemas.ModelProvider, customConfig *schema // canProviderKeyValueBeEmpty returns true if the given provider allows the API key to be empty. // Some providers like Vertex and Bedrock have their credentials in additional key configs.. func CanProviderKeyValueBeEmpty(providerKey schemas.ModelProvider) bool { - return providerKey == schemas.Vertex || providerKey == schemas.Bedrock || providerKey == schemas.VLLM || providerKey == schemas.Azure + return providerKey == schemas.Vertex || providerKey == schemas.Bedrock || providerKey == schemas.VLLM || providerKey == schemas.Azure || providerKey == schemas.Codex } func isKeySkippingAllowed(providerKey schemas.ModelProvider) bool { diff --git a/docs/docs.json b/docs/docs.json index 0a537b322c..3d5a7a2f51 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -127,6 +127,7 @@ "providers/supported-providers/bedrock", "providers/supported-providers/cerebras", "providers/supported-providers/cohere", + "providers/supported-providers/codex", "providers/supported-providers/databricks", "providers/supported-providers/elevenlabs", "providers/supported-providers/fireworks", diff --git a/docs/media/ui-codex-screenshot-1.png b/docs/media/ui-codex-screenshot-1.png new file mode 100644 index 0000000000..2431ac805f Binary files /dev/null and b/docs/media/ui-codex-screenshot-1.png differ diff --git a/docs/media/ui-codex-screenshot-2.png b/docs/media/ui-codex-screenshot-2.png new file mode 100644 index 0000000000..5d2b02c305 Binary files /dev/null and b/docs/media/ui-codex-screenshot-2.png differ diff --git a/docs/media/ui-codex-screenshot-3.png b/docs/media/ui-codex-screenshot-3.png new file mode 100644 index 0000000000..03cbe2b86c Binary files /dev/null and b/docs/media/ui-codex-screenshot-3.png differ diff --git a/docs/providers/supported-providers/codex.mdx b/docs/providers/supported-providers/codex.mdx new file mode 100644 index 0000000000..b881153c5f --- /dev/null +++ b/docs/providers/supported-providers/codex.mdx @@ -0,0 +1,91 @@ +--- +title: "Codex" +description: "ChatGPT Plus/Pro-backed Codex provider guide and caveats" +icon: "code" +--- + +## Overview + +The `codex` provider lets Bifrost route requests through ChatGPT Plus/Pro-backed Codex access instead of the standard OpenAI API key flow. + + + This provider depends on the ChatGPT/Codex subscription flow and a non-public upstream surface that can change without notice. It is best + treated as an experimental compatibility path, not the default production path for most deployments. + + + + For production workloads, prefer standard API-backed providers such as [OpenAI](/providers/supported-providers/openai), + [Anthropic](/providers/supported-providers/anthropic), [Bedrock](/providers/supported-providers/bedrock), or [Vertex + AI](/providers/supported-providers/vertex). Those providers have stable public APIs, clearer support boundaries, and more predictable + operational behavior. + + +## Supported Operations + +| Operation | Non-Streaming | Streaming | Notes | +| ---------------- | ------------- | --------- | ---------------------------------------------------------------------------- | +| List Models | ✅ | - | Returns a curated Codex model list | +| Chat Completions | ✅ | ✅ | Non-stream requests are internally adapted to the stream-first upstream path | +| Responses API | ✅ | ✅ | Native upstream path | + +All other OpenAI-style operations are unsupported. + +## Authentication + +Codex uses a device-code flow in the Bifrost UI: + +1. Create or open a Codex key in the Providers UI. +2. Click **Connect with OpenAI**. +3. Open the verification link. +4. Enter the displayed code on the OpenAI page. +5. Bifrost stores the resulting refresh/access credentials in the config store. + +For gateway deployments, use a persistent config store such as Postgres. Ephemeral local storage inside a pod will lose Codex credentials on restart. + +## Configuration Mode + +`config.json` support remains available for manual/static credentials, but interactive device authentication is best suited to DB-backed provider management. + +When using config mode, provide a Codex key via `codex_key_config`: + +```json +{ + "providers": { + "codex": { + "keys": [ + { + "key_id": "codex-primary", + "name": "Codex Primary", + "codex_key_config": { + "refresh_token": "env.CODEX_REFRESH_TOKEN", + "account_id": "env.CODEX_ACCOUNT_ID", + "auth_method": "manual" + } + } + ] + } + } +} +``` + +## Request Compatibility Notes + +The upstream Codex endpoint is stricter than the standard OpenAI API. Bifrost adapts requests for compatibility, including: + +- forcing `store=false` +- adding default `instructions` when missing +- converting non-stream requests to the stream-first upstream flow internally +- stripping unsupported tuning parameters such as: + - `temperature` + - `top_p` + - `max_output_tokens` + - `presence_penalty` + - `frequency_penalty` + +Tool-calling and stream handling are normalized back into Bifrost's OpenAI-compatible response shapes. + +## Operational Guidance + +- Prefer API-key providers for production automation and long-term reliability. +- Keep Codex credentials in a persistent DB-backed config store. +- Expect upstream behavior to change more often than stable public provider APIs. diff --git a/docs/providers/supported-providers/overview.mdx b/docs/providers/supported-providers/overview.mdx index 98d13ffa73..af3043dbcc 100644 --- a/docs/providers/supported-providers/overview.mdx +++ b/docs/providers/supported-providers/overview.mdx @@ -22,6 +22,7 @@ The following table summarizes which operations are supported by each provider v | Bedrock (`bedrock/`) | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | | Cerebras (`cerebras/`) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | Cohere (`cohere/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | +| Codex (`codex/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | Elevenlabs (`elevenlabs/`) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | | Fireworks (`fireworks/`) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | Gemini (`gemini/`) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | @@ -45,6 +46,8 @@ The following table summarizes which operations are supported by each provider v - ❌ Not supported by the downstream provider, hence not supported by Bifrost. - ✅ Fully supported by the downstream provider, or internally implemented by Bifrost. +Codex is an experimental subscription-backed provider that depends on a non-public upstream surface. Prefer API-key providers such as OpenAI, Anthropic, Bedrock, or Vertex AI for production use. + Some operations are not supported by the downstream provider, and their internal implementation in Bifrost is optional. 🟡 diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index 00970ce25e..49b6dd9f41 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -418,6 +418,24 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { vllmConfig.URL = *key.VLLMKeyConfig.URL.Redacted() redactedConfig.Keys[i].VLLMKeyConfig = vllmConfig } + + if key.CodexKeyConfig != nil { + codexConfig := &schemas.CodexKeyConfig{ + RefreshToken: key.CodexKeyConfig.RefreshToken, + AuthMethod: key.CodexKeyConfig.AuthMethod, + } + codexConfig.RefreshToken = *key.CodexKeyConfig.RefreshToken.Redacted() + if key.CodexKeyConfig.AccessToken != nil { + codexConfig.AccessToken = key.CodexKeyConfig.AccessToken.Redacted() + } + if key.CodexKeyConfig.AccountID != nil { + codexConfig.AccountID = key.CodexKeyConfig.AccountID.Redacted() + } + if key.CodexKeyConfig.AccessTokenExpiresAt != nil { + codexConfig.AccessTokenExpiresAt = key.CodexKeyConfig.AccessTokenExpiresAt + } + redactedConfig.Keys[i].CodexKeyConfig = codexConfig + } } return &redactedConfig } @@ -585,6 +603,13 @@ func GenerateKeyHash(key schemas.Key) (string, error) { } hash.Write(data) } + if key.CodexKeyConfig != nil { + data, err := sonic.Marshal(key.CodexKeyConfig) + if err != nil { + return "", err + } + hash.Write(data) + } // Hash Enabled (nil = false, only true produces different hash) if key.Enabled != nil && *key.Enabled { hash.Write([]byte("enabled:true")) diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 0a7855fdc9..2bccbfd7d5 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -277,6 +277,12 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddProviderPricingOverridesColumn(ctx, db); err != nil { return err } + if err := migrationAddCodexKeyColumns(ctx, db); err != nil { + return err + } + if err := migrationAddCodexAuthSessionsTable(ctx, db); err != nil { + return err + } if err := migrationAddEncryptionColumns(ctx, db); err != nil { return err } @@ -356,20 +362,21 @@ func migrationAddStoreRawRequestResponseColumn(ctx context.Context, db *gorm.DB) // dirty after upgrade. StoreRawRequestResponse is now part of the // hash input; rows written before this migration have stale hashes. var providers []tables.TableProvider + selectColumns := []string{ + "id", + "name", + "network_config_json", + "concurrency_buffer_json", + "proxy_config_json", + "custom_provider_config_json", + "pricing_overrides_json", + "send_back_raw_request", + "send_back_raw_response", + "store_raw_request_response", + "encryption_status", + } if err := tx. - Select( - "id", - "name", - "network_config_json", - "concurrency_buffer_json", - "proxy_config_json", - "custom_provider_config_json", - "pricing_overrides_json", - "send_back_raw_request", - "send_back_raw_response", - "store_raw_request_response", - "encryption_status", - ). + Select(selectColumns). Find(&providers).Error; err != nil { return fmt.Errorf("failed to fetch providers for hash backfill: %w", err) } @@ -3962,6 +3969,73 @@ func migrationAddProviderPricingOverridesColumn(ctx context.Context, db *gorm.DB return nil } +// migrationAddCodexKeyColumns adds Codex credential columns to config_keys. +func migrationAddCodexKeyColumns(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_codex_key_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + columns := []string{"CodexRefreshToken", "CodexAccessToken", "CodexAccessTokenExpiresAt", "CodexAccountID", "CodexAuthMethod"} + for _, column := range columns { + if !mg.HasColumn(&tables.TableKey{}, column) { + if err := mg.AddColumn(&tables.TableKey{}, column); err != nil { + return fmt.Errorf("failed to add %s column: %w", column, err) + } + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + columns := []string{"codex_auth_method", "codex_account_id", "codex_access_token_expires_at", "codex_access_token", "codex_refresh_token"} + for _, column := range columns { + if mg.HasColumn(&tables.TableKey{}, column) { + if err := mg.DropColumn(&tables.TableKey{}, column); err != nil { + return fmt.Errorf("failed to drop %s column: %w", column, err) + } + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running codex key columns migration: %s", err.Error()) + } + return nil +} + +func migrationAddCodexAuthSessionsTable(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_codex_auth_sessions_table", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + if !mg.HasTable(&tables.TableCodexAuthSession{}) { + if err := mg.CreateTable(&tables.TableCodexAuthSession{}); err != nil { + return fmt.Errorf("failed to create codex auth sessions table: %w", err) + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + if mg.HasTable(&tables.TableCodexAuthSession{}) { + if err := mg.DropTable(&tables.TableCodexAuthSession{}); err != nil { + return fmt.Errorf("failed to drop codex auth sessions table: %w", err) + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running codex auth sessions table migration: %s", err.Error()) + } + return nil +} + // migrationAddEncryptionColumns adds the encryption_status column to the config_keys, governance_virtual_keys, sessions, oauth_configs, oauth_tokens, config_mcp_clients, config_providers, config_vector_store, and config_plugins tables func migrationAddEncryptionColumns(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index 343e502cd2..8eab1d35b2 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -299,6 +299,7 @@ func (s *RDBConfigStore) UpdateProvidersConfig(ctx context.Context, providers ma BedrockKeyConfig: key.BedrockKeyConfig, ReplicateKeyConfig: key.ReplicateKeyConfig, VLLMKeyConfig: key.VLLMKeyConfig, + CodexKeyConfig: key.CodexKeyConfig, ConfigHash: keyHash, Status: string(key.Status), Description: key.Description, @@ -469,6 +470,7 @@ func (s *RDBConfigStore) UpdateProvider(ctx context.Context, provider schemas.Mo BedrockKeyConfig: key.BedrockKeyConfig, ReplicateKeyConfig: key.ReplicateKeyConfig, VLLMKeyConfig: key.VLLMKeyConfig, + CodexKeyConfig: key.CodexKeyConfig, ConfigHash: keyHash, Status: string(key.Status), Description: key.Description, @@ -592,6 +594,7 @@ func (s *RDBConfigStore) AddProvider(ctx context.Context, provider schemas.Model BedrockKeyConfig: key.BedrockKeyConfig, ReplicateKeyConfig: key.ReplicateKeyConfig, VLLMKeyConfig: key.VLLMKeyConfig, + CodexKeyConfig: key.CodexKeyConfig, ConfigHash: key.ConfigHash, Status: string(key.Status), Description: key.Description, @@ -714,6 +717,7 @@ func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.Mo BedrockKeyConfig: dbKey.BedrockKeyConfig, ReplicateKeyConfig: dbKey.ReplicateKeyConfig, VLLMKeyConfig: dbKey.VLLMKeyConfig, + CodexKeyConfig: dbKey.CodexKeyConfig, ConfigHash: dbKey.ConfigHash, Status: schemas.KeyStatusType(dbKey.Status), Description: dbKey.Description, @@ -765,6 +769,7 @@ func (s *RDBConfigStore) GetProviderConfig(ctx context.Context, provider schemas BedrockKeyConfig: dbKey.BedrockKeyConfig, ReplicateKeyConfig: dbKey.ReplicateKeyConfig, VLLMKeyConfig: dbKey.VLLMKeyConfig, + CodexKeyConfig: dbKey.CodexKeyConfig, ConfigHash: dbKey.ConfigHash, Status: schemas.KeyStatusType(dbKey.Status), Description: dbKey.Description, @@ -863,6 +868,26 @@ func (s *RDBConfigStore) UpdateStatus(ctx context.Context, provider schemas.Mode return fmt.Errorf("either keyID or provider must be non-empty") } +func (s *RDBConfigStore) PersistCodexKeyConfig(ctx context.Context, keyID string, keyConfig *schemas.CodexKeyConfig) error { + if keyConfig == nil { + return nil + } + + var dbKey tables.TableKey + if err := s.db.WithContext(ctx).Where("provider = ? AND key_id = ?", string(schemas.Codex), keyID).First(&dbKey).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return fmt.Errorf("failed to load codex key for persistence: %w", err) + } + + dbKey.CodexKeyConfig = keyConfig + if err := s.db.WithContext(ctx).Save(&dbKey).Error; err != nil { + return fmt.Errorf("failed to persist codex key config: %w", err) + } + return nil +} + // GetMCPConfig retrieves the MCP configuration from the database. func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) { var dbMCPClients []tables.TableMCPClient @@ -3601,3 +3626,36 @@ func (s *RDBConfigStore) GetOauthConfigByTokenID(ctx context.Context, tokenID st } return &config, nil } + +func (s *RDBConfigStore) GetCodexAuthSessionByID(ctx context.Context, id string) (*tables.TableCodexAuthSession, error) { + var session tables.TableCodexAuthSession + result := s.db.WithContext(ctx).Where("id = ?", id).First(&session) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get codex auth session: %w", result.Error) + } + return &session, nil +} + +func (s *RDBConfigStore) CreateCodexAuthSession(ctx context.Context, session *tables.TableCodexAuthSession) error { + if err := s.db.WithContext(ctx).Create(session).Error; err != nil { + return fmt.Errorf("failed to create codex auth session: %w", err) + } + return nil +} + +func (s *RDBConfigStore) UpdateCodexAuthSession(ctx context.Context, session *tables.TableCodexAuthSession) error { + if err := s.db.WithContext(ctx).Save(session).Error; err != nil { + return fmt.Errorf("failed to update codex auth session: %w", err) + } + return nil +} + +func (s *RDBConfigStore) DeleteCodexAuthSession(ctx context.Context, id string) error { + if err := s.db.WithContext(ctx).Where("id = ?", id).Delete(&tables.TableCodexAuthSession{}).Error; err != nil { + return fmt.Errorf("failed to delete codex auth session: %w", err) + } + return nil +} diff --git a/framework/configstore/store.go b/framework/configstore/store.go index 4d6d960bbd..608dd94de4 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -88,6 +88,7 @@ type ConfigStore interface { GetProviders(ctx context.Context) ([]tables.TableProvider, error) GetProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error) UpdateStatus(ctx context.Context, provider schemas.ModelProvider, keyID string, status, errorMsg string) error + PersistCodexKeyConfig(ctx context.Context, keyID string, keyConfig *schemas.CodexKeyConfig) error // MCP config CRUD GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) @@ -270,6 +271,12 @@ type ConfigStore interface { UpdateOauthToken(ctx context.Context, token *tables.TableOauthToken) error DeleteOauthToken(ctx context.Context, id string) error + // Codex auth sessions + GetCodexAuthSessionByID(ctx context.Context, id string) (*tables.TableCodexAuthSession, error) + CreateCodexAuthSession(ctx context.Context, session *tables.TableCodexAuthSession) error + UpdateCodexAuthSession(ctx context.Context, session *tables.TableCodexAuthSession) error + DeleteCodexAuthSession(ctx context.Context, id string) error + // Not found retry wrapper RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error) diff --git a/framework/configstore/tables/codexauth.go b/framework/configstore/tables/codexauth.go new file mode 100644 index 0000000000..dd7609b5b2 --- /dev/null +++ b/framework/configstore/tables/codexauth.go @@ -0,0 +1,60 @@ +package tables + +import ( + "fmt" + "time" + + "github.com/maximhq/bifrost/framework/encrypt" + "gorm.io/gorm" +) + +type TableCodexAuthSession struct { + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` + Provider string `gorm:"type:varchar(50);index;not null" json:"provider"` + KeyID string `gorm:"type:varchar(255);index;not null" json:"key_id"` + FlowType string `gorm:"type:varchar(32);index;not null" json:"flow_type"` + Status string `gorm:"type:varchar(50);index;not null" json:"status"` + DeviceAuthID *string `gorm:"type:text" json:"-"` + UserCode *string `gorm:"type:varchar(64)" json:"user_code,omitempty"` + VerificationURI *string `gorm:"type:text" json:"verification_uri,omitempty"` + IntervalSeconds *int `gorm:"type:int" json:"interval_seconds,omitempty"` + NextPollAt *time.Time `gorm:"index" json:"next_poll_at,omitempty"` + LastError *string `gorm:"type:text" json:"last_error,omitempty"` + EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` + CompletedAt *time.Time `gorm:"index" json:"completed_at,omitempty"` +} + +func (TableCodexAuthSession) TableName() string { return "codex_auth_sessions" } + +func (s *TableCodexAuthSession) BeforeSave(tx *gorm.DB) error { + if s.Status == "" { + s.Status = "pending" + } + if encrypt.IsEnabled() { + encrypted := false + if s.DeviceAuthID != nil && *s.DeviceAuthID != "" { + if err := encryptString(s.DeviceAuthID); err != nil { + return fmt.Errorf("failed to encrypt codex device auth id: %w", err) + } + encrypted = true + } + if encrypted { + s.EncryptionStatus = EncryptionStatusEncrypted + } + } + return nil +} + +func (s *TableCodexAuthSession) AfterFind(tx *gorm.DB) error { + if s.EncryptionStatus == EncryptionStatusEncrypted { + if s.DeviceAuthID != nil && *s.DeviceAuthID != "" { + if err := decryptString(s.DeviceAuthID); err != nil { + return fmt.Errorf("failed to decrypt codex device auth id: %w", err) + } + } + } + return nil +} diff --git a/framework/configstore/tables/key.go b/framework/configstore/tables/key.go index c0763e9045..8edf5609e3 100644 --- a/framework/configstore/tables/key.go +++ b/framework/configstore/tables/key.go @@ -64,6 +64,13 @@ type TableKey struct { VLLMUrl *schemas.EnvVar `gorm:"type:text" json:"vllm_url,omitempty"` VLLMModelName *string `gorm:"type:varchar(255)" json:"vllm_model_name,omitempty"` + // Codex config fields (embedded) + CodexRefreshToken *schemas.EnvVar `gorm:"type:text" json:"codex_refresh_token,omitempty"` + CodexAccessToken *schemas.EnvVar `gorm:"type:text" json:"codex_access_token,omitempty"` + CodexAccessTokenExpiresAt *string `gorm:"type:varchar(255)" json:"codex_access_token_expires_at,omitempty"` + CodexAccountID *schemas.EnvVar `gorm:"type:text" json:"codex_account_id,omitempty"` + CodexAuthMethod *string `gorm:"type:varchar(32)" json:"codex_auth_method,omitempty"` + // Batch API configuration UseForBatchAPI *bool `gorm:"default:false" json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations @@ -80,6 +87,7 @@ type TableKey struct { BedrockKeyConfig *schemas.BedrockKeyConfig `gorm:"-" json:"bedrock_key_config,omitempty"` ReplicateKeyConfig *schemas.ReplicateKeyConfig `gorm:"-" json:"replicate_key_config,omitempty"` VLLMKeyConfig *schemas.VLLMKeyConfig `gorm:"-" json:"vllm_key_config,omitempty"` + CodexKeyConfig *schemas.CodexKeyConfig `gorm:"-" json:"codex_key_config,omitempty"` } // TableName sets the table name for each model @@ -342,6 +350,41 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { k.VLLMModelName = nil } + if k.CodexKeyConfig != nil { + refreshToken := k.CodexKeyConfig.RefreshToken + k.CodexRefreshToken = &refreshToken + if k.CodexKeyConfig.AccessToken != nil { + accessToken := *k.CodexKeyConfig.AccessToken + k.CodexAccessToken = &accessToken + } else { + k.CodexAccessToken = nil + } + if k.CodexKeyConfig.AccountID != nil { + accountID := *k.CodexKeyConfig.AccountID + k.CodexAccountID = &accountID + } else { + k.CodexAccountID = nil + } + if k.CodexKeyConfig.AccessTokenExpiresAt != nil { + expiresAt := *k.CodexKeyConfig.AccessTokenExpiresAt + k.CodexAccessTokenExpiresAt = &expiresAt + } else { + k.CodexAccessTokenExpiresAt = nil + } + if k.CodexKeyConfig.AuthMethod != "" { + authMethod := string(k.CodexKeyConfig.AuthMethod) + k.CodexAuthMethod = &authMethod + } else { + k.CodexAuthMethod = nil + } + } else { + k.CodexRefreshToken = nil + k.CodexAccessToken = nil + k.CodexAccountID = nil + k.CodexAccessTokenExpiresAt = nil + k.CodexAuthMethod = nil + } + // Encrypt sensitive fields after serialization if encrypt.IsEnabled() { if err := encryptEnvVar(&k.Value); err != nil { @@ -411,6 +454,15 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { if err := encryptEnvVarPtr(&k.VLLMUrl); err != nil { return fmt.Errorf("failed to encrypt vllm url: %w", err) } + if err := encryptEnvVarPtr(&k.CodexRefreshToken); err != nil { + return fmt.Errorf("failed to encrypt codex refresh token: %w", err) + } + if err := encryptEnvVarPtr(&k.CodexAccessToken); err != nil { + return fmt.Errorf("failed to encrypt codex access token: %w", err) + } + if err := encryptEnvVarPtr(&k.CodexAccountID); err != nil { + return fmt.Errorf("failed to encrypt codex account id: %w", err) + } k.EncryptionStatus = EncryptionStatusEncrypted } return nil @@ -489,6 +541,15 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { if err := decryptEnvVarPtr(&k.VLLMUrl); err != nil { return fmt.Errorf("failed to decrypt vllm url: %w", err) } + if err := decryptEnvVarPtr(&k.CodexRefreshToken); err != nil { + return fmt.Errorf("failed to decrypt codex refresh token: %w", err) + } + if err := decryptEnvVarPtr(&k.CodexAccessToken); err != nil { + return fmt.Errorf("failed to decrypt codex access token: %w", err) + } + if err := decryptEnvVarPtr(&k.CodexAccountID); err != nil { + return fmt.Errorf("failed to decrypt codex account id: %w", err) + } } if k.ModelsJSON != "" { @@ -638,5 +699,28 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { } else { k.VLLMKeyConfig = nil } + if k.CodexRefreshToken != nil || k.CodexAccessToken != nil || k.CodexAccountID != nil || k.CodexAccessTokenExpiresAt != nil || k.CodexAuthMethod != nil { + codexConfig := &schemas.CodexKeyConfig{ + RefreshToken: *schemas.NewEnvVar(""), + } + if k.CodexRefreshToken != nil { + codexConfig.RefreshToken = *k.CodexRefreshToken + } + if k.CodexAccessToken != nil { + codexConfig.AccessToken = k.CodexAccessToken + } + if k.CodexAccountID != nil { + codexConfig.AccountID = k.CodexAccountID + } + if k.CodexAccessTokenExpiresAt != nil { + codexConfig.AccessTokenExpiresAt = k.CodexAccessTokenExpiresAt + } + if k.CodexAuthMethod != nil { + codexConfig.AuthMethod = schemas.CodexAuthMethod(*k.CodexAuthMethod) + } + k.CodexKeyConfig = codexConfig + } else { + k.CodexKeyConfig = nil + } return nil } diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index 6603d04db7..96538cd439 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -88,8 +88,8 @@ type PricingEntry struct { InputCostPerAudioPerSecondAbove128kTokens *float64 `json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"` OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"` // Costs - 200k Tier - InputCostPerTokenAbove200kTokens *float64 `json:"input_cost_per_token_above_200k_tokens,omitempty"` - InputCostPerTokenAbove200kTokensPriority *float64 `json:"input_cost_per_token_above_200k_tokens_priority,omitempty"` + InputCostPerTokenAbove200kTokens *float64 `json:"input_cost_per_token_above_200k_tokens,omitempty"` + InputCostPerTokenAbove200kTokensPriority *float64 `json:"input_cost_per_token_above_200k_tokens_priority,omitempty"` OutputCostPerTokenAbove200kTokens *float64 `json:"output_cost_per_token_above_200k_tokens,omitempty"` OutputCostPerTokenAbove200kTokensPriority *float64 `json:"output_cost_per_token_above_200k_tokens_priority,omitempty"` // Costs - 272k Tier diff --git a/framework/modelcatalog/main_test.go b/framework/modelcatalog/main_test.go index 324b28c791..60888873ea 100644 --- a/framework/modelcatalog/main_test.go +++ b/framework/modelcatalog/main_test.go @@ -3,6 +3,7 @@ package modelcatalog import ( "testing" + bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" @@ -22,6 +23,7 @@ func newTestCatalog(modelPool map[schemas.ModelProvider][]string, baseModelIndex baseModelIndex: baseModelIndex, pricingData: make(map[string]configstoreTables.TableModelPricing), compiledOverrides: make(map[schemas.ModelProvider][]compiledProviderPricingOverride), + logger: bifrost.NewNoOpLogger(), } } diff --git a/framework/modelcatalog/pricing.go b/framework/modelcatalog/pricing.go index 53eb46bc2d..0264caa7ed 100644 --- a/framework/modelcatalog/pricing.go +++ b/framework/modelcatalog/pricing.go @@ -833,7 +833,6 @@ func populateOutputImageCount(imageUsage *schemas.ImageUsage, dataLen int) { // resolvePricing resolves the pricing entry for a model, trying deployment as fallback. func (mc *ModelCatalog) resolvePricing(provider, model, deployment string, requestType schemas.RequestType) *configstoreTables.TableModelPricing { mc.logger.Debug("looking up pricing for model %s and provider %s of request type %s", model, provider, normalizeRequestType(requestType)) - pricing, exists := mc.getPricing(model, provider, requestType) if exists { return pricing diff --git a/plugins/maxim/plugin_test.go b/plugins/maxim/plugin_test.go index d5d70d81cd..e8d2916acc 100644 --- a/plugins/maxim/plugin_test.go +++ b/plugins/maxim/plugin_test.go @@ -84,6 +84,9 @@ func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelPr // - MAXIM_LOGGER_ID: Your Maxim logger repository ID // - OPENAI_API_KEY: Your OpenAI API key for the test request func TestMaximLoggerPlugin(t *testing.T) { + if os.Getenv("MAXIM_API_KEY") == "" || os.Getenv("OPENAI_API_KEY") == "" { + t.Skip("Skipping Maxim integration test because MAXIM_API_KEY or OPENAI_API_KEY is not set") + } ctx := context.Background() // Initialize the Maxim plugin plugin, err := getPlugin() diff --git a/transports/bifrost-http/handlers/codexauth.go b/transports/bifrost-http/handlers/codexauth.go new file mode 100644 index 0000000000..246b35d07e --- /dev/null +++ b/transports/bifrost-http/handlers/codexauth.go @@ -0,0 +1,303 @@ +package handlers + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/fasthttp/router" + "github.com/google/uuid" + providerCodex "github.com/maximhq/bifrost/core/providers/codex" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +const ( + codexAuthSessionPending = "pending" + codexAuthSessionSucceeded = "authorized" + codexAuthSessionFailed = "failed" + codexAuthSessionExpired = "expired" + codexAuthSessionCancelled = "cancelled" + defaultCodexSessionTTL = 15 * time.Minute +) + +type CodexAuthHandler struct { + store *lib.Config + httpClient *http.Client +} + +func NewCodexAuthHandler(store *lib.Config) *CodexAuthHandler { + return &CodexAuthHandler{ + store: store, + httpClient: &http.Client{ + Timeout: 20 * time.Second, + }, + } +} + +func (h *CodexAuthHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { + r.POST("/api/providers/codex/keys/{keyId}/auth/device/start", lib.ChainMiddlewares(h.startDeviceAuth, middlewares...)) + r.GET("/api/providers/codex/auth/sessions/{id}", lib.ChainMiddlewares(h.getAuthSessionStatus, middlewares...)) + r.DELETE("/api/providers/codex/auth/sessions/{id}", lib.ChainMiddlewares(h.cancelAuthSession, middlewares...)) +} + +func (h *CodexAuthHandler) startDeviceAuth(ctx *fasthttp.RequestCtx) { + keyID := ctx.UserValue("keyId").(string) + if _, _, err := h.getEditableCodexKey(keyID); err != nil { + h.sendAuthError(ctx, err) + return + } + requestCtx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + deviceAuth, err := providerCodex.StartDeviceAuthorization(requestCtx, h.httpClient, h.userAgent()) + if err != nil { + SendError(ctx, fasthttp.StatusBadGateway, fmt.Sprintf("Failed to start device authorization: %v", err)) + return + } + intervalSeconds := 5 + if parsed, parseErr := time.ParseDuration(deviceAuth.Interval + "s"); parseErr == nil { + intervalSeconds = max(1, int(parsed.Seconds())) + } + verificationURI := providerCodex.DeviceVerificationURL + deviceAuthID := deviceAuth.DeviceAuthID + userCode := deviceAuth.UserCode + nextPollAt := providerCodex.NextPollTime(intervalSeconds) + session := &configstoreTables.TableCodexAuthSession{ + ID: uuid.NewString(), + Provider: string(schemas.Codex), + KeyID: keyID, + FlowType: string(schemas.CodexAuthMethodDevice), + Status: codexAuthSessionPending, + DeviceAuthID: &deviceAuthID, + UserCode: &userCode, + VerificationURI: &verificationURI, + IntervalSeconds: &intervalSeconds, + NextPollAt: &nextPollAt, + ExpiresAt: time.Now().Add(defaultCodexSessionTTL), + } + if err := h.store.ConfigStore.CreateCodexAuthSession(ctx, session); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create auth session: %v", err)) + return + } + SendJSON(ctx, h.sessionResponse(session)) +} + +func (h *CodexAuthHandler) getAuthSessionStatus(ctx *fasthttp.RequestCtx) { + sessionID := ctx.UserValue("id").(string) + session, err := h.store.ConfigStore.GetCodexAuthSessionByID(ctx, sessionID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get auth session: %v", err)) + return + } + if session == nil { + SendError(ctx, fasthttp.StatusNotFound, "Auth session not found") + return + } + if err := h.refreshSessionState(ctx, session); err != nil { + SendError(ctx, fasthttp.StatusBadGateway, fmt.Sprintf("Failed to refresh auth session: %v", err)) + return + } + SendJSON(ctx, h.sessionResponse(session)) +} + +func (h *CodexAuthHandler) cancelAuthSession(ctx *fasthttp.RequestCtx) { + sessionID := ctx.UserValue("id").(string) + session, err := h.store.ConfigStore.GetCodexAuthSessionByID(ctx, sessionID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get auth session: %v", err)) + return + } + if session == nil { + SendError(ctx, fasthttp.StatusNotFound, "Auth session not found") + return + } + session.Status = codexAuthSessionCancelled + now := time.Now() + session.CompletedAt = &now + if err := h.store.ConfigStore.UpdateCodexAuthSession(ctx, session); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to cancel auth session: %v", err)) + return + } + SendJSON(ctx, h.sessionResponse(session)) +} + +func (h *CodexAuthHandler) refreshSessionState(ctx context.Context, session *configstoreTables.TableCodexAuthSession) error { + if session.Status != codexAuthSessionPending { + return nil + } + if time.Now().After(session.ExpiresAt) { + session.Status = codexAuthSessionExpired + return h.store.ConfigStore.UpdateCodexAuthSession(ctx, session) + } + if session.FlowType != string(schemas.CodexAuthMethodDevice) { + return nil + } + if session.NextPollAt != nil && time.Now().Before(*session.NextPollAt) { + return nil + } + if session.DeviceAuthID == nil || session.UserCode == nil { + session.Status = codexAuthSessionFailed + message := "Missing device authorization state" + session.LastError = &message + return h.store.ConfigStore.UpdateCodexAuthSession(ctx, session) + } + pollCtx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + deviceToken, statusCode, err := providerCodex.PollDeviceAuthorization(pollCtx, h.httpClient, *session.DeviceAuthID, *session.UserCode, h.userAgent()) + if err != nil { + return err + } + if statusCode == http.StatusOK && deviceToken != nil { + exchangeCtx, exchangeCancel := context.WithTimeout(context.Background(), 20*time.Second) + defer exchangeCancel() + tokens, err := providerCodex.ExchangeDeviceAuthorizationCode(exchangeCtx, h.httpClient, deviceToken.AuthorizationCode, deviceToken.CodeVerifier) + if err != nil { + session.Status = codexAuthSessionFailed + message := err.Error() + session.LastError = &message + return h.store.ConfigStore.UpdateCodexAuthSession(ctx, session) + } + if err := h.persistTokensToKey(ctx, session.KeyID, tokens, schemas.CodexAuthMethodDevice); err != nil { + session.Status = codexAuthSessionFailed + message := err.Error() + session.LastError = &message + return h.store.ConfigStore.UpdateCodexAuthSession(ctx, session) + } + h.completeSession(session) + return nil + } + if statusCode == fasthttp.StatusForbidden || statusCode == fasthttp.StatusNotFound { + nextPollAt := providerCodex.NextPollTime(valueOrDefault(session.IntervalSeconds, 5)) + session.NextPollAt = &nextPollAt + return h.store.ConfigStore.UpdateCodexAuthSession(ctx, session) + } + session.Status = codexAuthSessionFailed + message := fmt.Sprintf("Device authorization failed with status %d", statusCode) + session.LastError = &message + return h.store.ConfigStore.UpdateCodexAuthSession(ctx, session) +} + +func (h *CodexAuthHandler) persistTokensToKey(ctx context.Context, keyID string, tokens *providerCodex.TokenResponse, authMethod schemas.CodexAuthMethod) error { + providerConfig, _, err := h.getEditableCodexKey(keyID) + if err != nil { + return err + } + updatedKeys := append([]schemas.Key(nil), providerConfig.Keys...) + for idx := range updatedKeys { + if updatedKeys[idx].ID != keyID { + continue + } + var existingAccountID *schemas.EnvVar + if updatedKeys[idx].CodexKeyConfig != nil { + existingAccountID = updatedKeys[idx].CodexKeyConfig.AccountID + } + refreshToken := schemas.NewEnvVar(tokens.RefreshToken) + accessToken := schemas.NewEnvVar(tokens.AccessToken) + accountIDValue := providerCodex.ExtractAccountID(tokens) + accessTokenExpiresAt := providerCodex.ExpiresAtFromNow(tokens.ExpiresIn) + updatedKeys[idx].CodexKeyConfig = &schemas.CodexKeyConfig{ + RefreshToken: *refreshToken, + AccessToken: accessToken, + AccessTokenExpiresAt: &accessTokenExpiresAt, + AuthMethod: authMethod, + } + if accountIDValue != "" { + updatedKeys[idx].CodexKeyConfig.AccountID = schemas.NewEnvVar(accountIDValue) + } else if existingAccountID != nil { + updatedKeys[idx].CodexKeyConfig.AccountID = existingAccountID + } + providerConfig.Keys = updatedKeys + return h.store.UpdateProviderConfig(ctx, schemas.Codex, *providerConfig) + } + return fmt.Errorf("Codex key %s not found", keyID) +} + +func (h *CodexAuthHandler) getEditableCodexKey(keyID string) (*configstore.ProviderConfig, *schemas.Key, error) { + if h.store == nil || h.store.ConfigStore == nil { + return nil, nil, fmt.Errorf("database-backed config store is required for Codex authentication") + } + providerConfig, err := h.store.GetProviderConfigRaw(schemas.Codex) + if err != nil { + return nil, nil, fmt.Errorf("failed to load Codex provider config: %w", err) + } + if providerConfig == nil { + return nil, nil, fmt.Errorf("Codex provider is not configured") + } + for idx := range providerConfig.Keys { + if providerConfig.Keys[idx].ID != keyID { + continue + } + // ConfigHash is used for general reconciliation and is not a reliable indicator + // that a key is currently managed by config.json. Allow Codex reauth for saved keys. + return providerConfig, &providerConfig.Keys[idx], nil + } + return nil, nil, fmt.Errorf("Codex key %s not found", keyID) +} + +func (h *CodexAuthHandler) completeSession(session *configstoreTables.TableCodexAuthSession) { + now := time.Now() + session.Status = codexAuthSessionSucceeded + session.CompletedAt = &now + session.LastError = nil + _ = h.store.ConfigStore.UpdateCodexAuthSession(context.Background(), session) +} + +func (h *CodexAuthHandler) failSession(session *configstoreTables.TableCodexAuthSession, message string) { + now := time.Now() + session.Status = codexAuthSessionFailed + session.CompletedAt = &now + session.LastError = &message + _ = h.store.ConfigStore.UpdateCodexAuthSession(context.Background(), session) +} + +func (h *CodexAuthHandler) sessionResponse(session *configstoreTables.TableCodexAuthSession) map[string]any { + response := map[string]any{ + "id": session.ID, + "flow_type": session.FlowType, + "status": session.Status, + "expires_at": session.ExpiresAt, + } + if session.VerificationURI != nil { + response["verification_uri"] = *session.VerificationURI + } + if session.UserCode != nil { + response["user_code"] = *session.UserCode + } + if session.IntervalSeconds != nil { + response["interval_seconds"] = *session.IntervalSeconds + } + if session.NextPollAt != nil { + response["next_poll_at"] = *session.NextPollAt + } + if session.LastError != nil { + response["last_error"] = *session.LastError + } + if session.CompletedAt != nil { + response["completed_at"] = *session.CompletedAt + } + return response +} + +func (h *CodexAuthHandler) userAgent() string { + return "bifrost-codex-auth" +} + +func (h *CodexAuthHandler) sendAuthError(ctx *fasthttp.RequestCtx, err error) { + status := fasthttp.StatusBadRequest + if strings.Contains(err.Error(), "database-backed config store") { + status = fasthttp.StatusServiceUnavailable + } + SendError(ctx, status, err.Error()) +} + +func valueOrDefault(value *int, fallback int) int { + if value == nil || *value <= 0 { + return fallback + } + return *value +} diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index f39fa19106..d9b38c7ab5 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -3112,6 +3112,52 @@ func (c *Config) UpdateProviderConfig(ctx context.Context, provider schemas.Mode return nil } +// PersistCodexKeyCredentials updates the persisted and in-memory credentials for a Codex key +// without triggering a provider reload. The DB remains the source of truth; the in-memory +// update is only to avoid repeated refreshes on the current instance. +func (c *Config) PersistCodexKeyCredentials(ctx context.Context, keyID string, refreshed *schemas.CodexKeyConfig) error { + if refreshed == nil { + return nil + } + + if c.ConfigStore != nil { + if err := c.ConfigStore.PersistCodexKeyConfig(ctx, keyID, refreshed); err != nil { + if errors.Is(err, configstore.ErrNotFound) { + return ErrNotFound + } + return fmt.Errorf("failed to persist codex key credentials: %w", err) + } + } + + c.Mu.Lock() + defer c.Mu.Unlock() + existingConfig, exists := c.Providers[schemas.Codex] + if !exists { + return ErrNotFound + } + for idx := range existingConfig.Keys { + if existingConfig.Keys[idx].ID != keyID { + continue + } + current := existingConfig.Keys[idx].CodexKeyConfig + merged := &schemas.CodexKeyConfig{} + if current != nil { + *merged = *current + } + merged.RefreshToken = refreshed.RefreshToken + merged.AccessToken = refreshed.AccessToken + merged.AccessTokenExpiresAt = refreshed.AccessTokenExpiresAt + merged.AuthMethod = refreshed.AuthMethod + if refreshed.AccountID != nil { + merged.AccountID = refreshed.AccountID + } + existingConfig.Keys[idx].CodexKeyConfig = merged + c.Providers[schemas.Codex] = existingConfig + return nil + } + return ErrNotFound +} + // RemoveProvider removes a provider configuration from memory. func (c *Config) RemoveProvider(ctx context.Context, provider schemas.ModelProvider) error { c.Mu.Lock() diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index f64d54d31f..f010a41bf9 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -444,6 +444,22 @@ func (m *MockConfigStore) UpdateProvider(ctx context.Context, provider schemas.M return nil } +func (m *MockConfigStore) PersistCodexKeyConfig(ctx context.Context, keyID string, keyConfig *schemas.CodexKeyConfig) error { + provider, ok := m.providers[schemas.Codex] + if !ok { + return configstore.ErrNotFound + } + for idx := range provider.Keys { + if provider.Keys[idx].ID != keyID { + continue + } + provider.Keys[idx].CodexKeyConfig = keyConfig + m.providers[schemas.Codex] = provider + return nil + } + return configstore.ErrNotFound +} + func (m *MockConfigStore) DeleteProvider(ctx context.Context, provider schemas.ModelProvider, tx ...*gorm.DB) error { delete(m.providers, provider) return nil @@ -1022,6 +1038,22 @@ func (m *MockConfigStore) DeleteOauthToken(ctx context.Context, id string) error return nil } +func (m *MockConfigStore) GetCodexAuthSessionByID(ctx context.Context, id string) (*tables.TableCodexAuthSession, error) { + return nil, nil +} + +func (m *MockConfigStore) CreateCodexAuthSession(ctx context.Context, session *tables.TableCodexAuthSession) error { + return nil +} + +func (m *MockConfigStore) UpdateCodexAuthSession(ctx context.Context, session *tables.TableCodexAuthSession) error { + return nil +} + +func (m *MockConfigStore) DeleteCodexAuthSession(ctx context.Context, id string) error { + return nil +} + // Routing rules func (m *MockConfigStore) GetRoutingRules(ctx context.Context) ([]tables.TableRoutingRule, error) { return nil, nil @@ -15752,13 +15784,13 @@ func TestConfigSchemaSyncTopLevel(t *testing.T) { // Enterprise-only features: These fields exist in the JSON schema for documentation // and validation purposes, but are only available in the enterprise version. enterpriseSchemaFields := map[string]bool{ - "$schema": true, - "audit_logs": true, - "cluster_config": true, - "saml_config": true, - "load_balancer_config": true, - "guardrails_config": true, - "large_payload_optimization": true, + "$schema": true, + "audit_logs": true, + "cluster_config": true, + "saml_config": true, + "load_balancer_config": true, + "guardrails_config": true, + "large_payload_optimization": true, } schema := loadJSONSchema(t) diff --git a/transports/bifrost-http/lib/ctx.go b/transports/bifrost-http/lib/ctx.go index ea0c8b0a72..7b2678a203 100644 --- a/transports/bifrost-http/lib/ctx.go +++ b/transports/bifrost-http/lib/ctx.go @@ -152,8 +152,8 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat "transfer-encoding": true, // prevent auth/key overrides via x-bf-eh-* - "x-api-key": true, - "x-goog-api-key": true, + "x-api-key": true, + "x-goog-api-key": true, "x-bf-api-key": true, "x-bf-api-key-id": true, "x-bf-vk": true, diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 757d712510..48b366d6bc 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -1019,6 +1019,7 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser // lib.ChainMiddlewares chains multiple middlewares together healthHandler := handlers.NewHealthHandler(s.Config) providerHandler := handlers.NewProviderHandler(callbacks, s.Config, s.Client) + codexAuthHandler := handlers.NewCodexAuthHandler(s.Config) oauthHandler := handlers.NewOAuthHandler(s.Config.OAuthProvider, s.Client, s.Config) mcpHandler := handlers.NewMCPHandler(callbacks, s.Client, s.Config, oauthHandler) configHandler := handlers.NewConfigHandler(callbacks, s.Config) @@ -1028,6 +1029,7 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser // Going ahead with API handlers healthHandler.RegisterRoutes(s.Router, middlewares...) providerHandler.RegisterRoutes(s.Router, middlewares...) + codexAuthHandler.RegisterRoutes(s.Router, middlewares...) mcpHandler.RegisterRoutes(s.Router, middlewares...) configHandler.RegisterRoutes(s.Router, middlewares...) oauthHandler.RegisterRoutes(s.Router, middlewares...) @@ -1122,6 +1124,14 @@ func (s *BifrostHTTPServer) GetAllRedactedRoutingRules(ctx context.Context, ids // PrepareCommonMiddlewares gets the common middlewares for the Bifrost HTTP server func (s *BifrostHTTPServer) PrepareCommonMiddlewares() []schemas.BifrostHTTPMiddleware { commonMiddlewares := []schemas.BifrostHTTPMiddleware{} + if s.Config != nil && s.Config.ConfigStore != nil { + commonMiddlewares = append(commonMiddlewares, func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + ctx.SetUserValue(schemas.BifrostContextKeyCodexCredentialPersister, schemas.CodexCredentialPersister(s.persistCodexCredentialRefresh)) + next(ctx) + } + }) + } // Preparing middlewares // Initializing prometheus plugin prometheusPlugin, err := lib.FindPluginAs[*telemetry.PrometheusPlugin](s.Config, telemetry.PluginName) @@ -1133,6 +1143,15 @@ func (s *BifrostHTTPServer) PrepareCommonMiddlewares() []schemas.BifrostHTTPMidd return commonMiddlewares } +func (s *BifrostHTTPServer) persistCodexCredentialRefresh(keyID string, refreshed *schemas.CodexKeyConfig) error { + if s.Config == nil || refreshed == nil { + return nil + } + ctx := schemas.NewBifrostContext(context.Background(), time.Now().Add(5*time.Second)) + defer ctx.Cancel() + return s.Config.PersistCodexKeyCredentials(ctx, keyID, refreshed) +} + // Bootstrap initializes the Bifrost HTTP server with all necessary components. // It: // 1. Initializes Prometheus collectors for monitoring diff --git a/transports/config.schema.json b/transports/config.schema.json index 7e64fb2440..3ca690d295 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -209,6 +209,9 @@ "cohere": { "$ref": "#/$defs/provider" }, + "codex": { + "$ref": "#/$defs/provider_with_codex_config" + }, "azure": { "$ref": "#/$defs/provider_with_azure_config" }, @@ -2307,6 +2310,51 @@ } ] }, + "codex_key": { + "allOf": [ + { + "$ref": "#/$defs/base_key" + }, + { + "type": "object", + "properties": { + "codex_key_config": { + "type": "object", + "properties": { + "refresh_token": { + "type": "string", + "description": "Codex refresh token (can use env. prefix)" + }, + "access_token": { + "type": "string", + "description": "Optional cached Codex access token (can use env. prefix)" + }, + "access_token_expires_at": { + "type": "string", + "description": "Optional RFC3339 access token expiry" + }, + "account_id": { + "type": "string", + "description": "Optional ChatGPT account id (can use env. prefix)" + }, + "auth_method": { + "type": "string", + "enum": ["device", "manual"], + "description": "How this Codex credential set was provisioned" + } + }, + "required": [ + "refresh_token" + ], + "additionalProperties": false + } + }, + "required": [ + "codex_key_config" + ] + } + ] + }, "azure_key": { "allOf": [ { @@ -2635,6 +2683,51 @@ ], "additionalProperties": false }, + "provider_with_codex_config": { + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": { + "$ref": "#/$defs/codex_key" + }, + "minItems": 1, + "description": "Codex subscription credentials for this provider" + }, + "network_config": { + "$ref": "#/$defs/network_config" + }, + "concurrency_and_buffer_size": { + "$ref": "#/$defs/concurrency_config" + }, + "proxy_config": { + "$ref": "#/$defs/proxy_config" + }, + "send_back_raw_request": { + "type": "boolean", + "description": "Include raw request in BifrostResponse (default: false)" + }, + "send_back_raw_response": { + "type": "boolean", + "description": "Include raw response in BifrostResponse (default: false)" + }, + "store_raw_request_response": { + "type": "boolean", + "description": "Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false)" + }, + "pricing_overrides": { + "type": "array", + "items": { + "$ref": "#/$defs/provider_pricing_override" + }, + "description": "Provider-level pricing overrides matched by model pattern" + } + }, + "required": [ + "keys" + ], + "additionalProperties": false + }, "mcp_client_config": { "type": "object", "properties": { diff --git a/ui/app/workspace/providers/dialogs/providerConfigSheet.tsx b/ui/app/workspace/providers/dialogs/providerConfigSheet.tsx index d257598224..6cdfb6ac89 100644 --- a/ui/app/workspace/providers/dialogs/providerConfigSheet.tsx +++ b/ui/app/workspace/providers/dialogs/providerConfigSheet.tsx @@ -4,7 +4,13 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { ModelProvider } from "@/lib/types/config"; import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; import { useEffect, useMemo, useState } from "react"; -import { ApiStructureFormFragment, BetaHeadersFormFragment, GovernanceFormFragment, OpenAIConfigFormFragment, ProxyFormFragment } from "../fragments"; +import { + ApiStructureFormFragment, + BetaHeadersFormFragment, + GovernanceFormFragment, + OpenAIConfigFormFragment, + ProxyFormFragment, +} from "../fragments"; import { DebuggingFormFragment } from "../fragments/debuggingFormFragment"; import { NetworkFormFragment } from "../fragments/networkFormFragment"; import { PerformanceFormFragment } from "../fragments/performanceFormFragment"; diff --git a/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx b/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx index 62653fd545..d3d21afe7a 100644 --- a/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx +++ b/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx @@ -2,6 +2,7 @@ import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; +import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; import { EnvVarInput } from "@/components/ui/envVarInput"; import { FormControl, FormDescription, FormField, FormItem, FormLabel, FormMessage } from "@/components/ui/form"; import { Input } from "@/components/ui/input"; @@ -13,9 +14,10 @@ import { TagInput } from "@/components/ui/tagInput"; import { Textarea } from "@/components/ui/textarea"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; import { isRedacted } from "@/lib/utils/validation"; -import { Info, Plus, Trash2 } from "lucide-react"; +import { ChevronDown, Info, Plus, Trash2 } from "lucide-react"; import { useEffect, useState } from "react"; import { Control, UseFormReturn } from "react-hook-form"; +import { CodexAuthControls } from "./codexAuthControls"; // Providers that support batch APIs const BATCH_SUPPORTED_PROVIDERS = ["openai", "bedrock", "anthropic", "gemini", "azure"]; @@ -24,6 +26,9 @@ interface Props { control: Control; providerName: string; form: UseFormReturn; + isEditing?: boolean; + isConfigManaged?: boolean; + onEnsurePersisted?: (authMethod?: "device" | "manual") => Promise; } // Batch API form field for all providers @@ -49,13 +54,17 @@ function BatchAPIFormField({ control, form }: { control: Control; form: Use ); } -export function ApiKeyFormFragment({ control, providerName, form }: Props) { +export function ApiKeyFormFragment({ control, providerName, form, isEditing = false, isConfigManaged = false, onEnsurePersisted }: Props) { const isBedrock = providerName === "bedrock"; const isVertex = providerName === "vertex"; const isAzure = providerName === "azure"; const isReplicate = providerName === "replicate"; const isVLLM = providerName === "vllm"; + const isCodex = providerName === "codex"; const supportsBatchAPI = BATCH_SUPPORTED_PROVIDERS.includes(providerName); + const [showManualCodexCredentials, setShowManualCodexCredentials] = useState( + isConfigManaged || form.getValues("key.codex_key_config.auth_method") === "manual", + ); // Auth type state for Azure: 'api_key', 'entra_id', or 'default_credential' const [azureAuthType, setAzureAuthType] = useState<"api_key" | "entra_id" | "default_credential">("api_key"); @@ -171,7 +180,7 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { /> {/* Hide API Key field for Azure when using Entra ID/Default Credential, and for Bedrock when not using API Key auth */} - {!isAzure && !isBedrock && ( + {!isAzure && !isBedrock && !isCodex && ( )} + {isCodex && ( + { + setShowManualCodexCredentials(open); + if (open) { + form.setValue("key.codex_key_config.auth_method", "manual", { shouldDirty: true }); + } + }} + > +
+ +
+
Manual credentials (advanced)
+

+ Most users should use the connect flow above. Only open this if you already have Codex tokens and want to manage them + manually. +

+
+ +
+ +
+ ( + + Refresh Token + + + + + + )} + /> + ( + + Cached Access Token + + + + + + )} + /> +
+ ( + + ChatGPT Account ID + + + + + + )} + /> + ( + + Access Token Expiry + + + + + + )} + /> +
+
+
+
+
+ )} + {isCodex && ( + + )} {!isVLLM && ( <> ; + isEditing: boolean; + isConfigManaged: boolean; + onEnsurePersisted?: (authMethod?: "device" | "manual") => Promise; +} + +export function CodexAuthControls({ providerName, keyId, form, isEditing, isConfigManaged, onEnsurePersisted }: CodexAuthControlsProps) { + const [startDeviceAuth, { isLoading: isStartingDevice }] = useStartCodexDeviceAuthMutation(); + const [cancelCodexAuthSession] = useCancelCodexAuthSessionMutation(); + const [getCodexAuthSession] = useLazyGetCodexAuthSessionQuery(); + const [getProvider] = useLazyGetProviderQuery(); + const [session, setSession] = useState(null); + const [isOpen, setIsOpen] = useState(false); + const [statusMessage, setStatusMessage] = useState(null); + const popupRef = useRef(null); + const watchedRefreshToken = useWatch({ control: form.control, name: "key.codex_key_config.refresh_token" }) as EnvVar | undefined; + const watchedAccessToken = useWatch({ control: form.control, name: "key.codex_key_config.access_token" }) as EnvVar | undefined; + const watchedAccountID = useWatch({ control: form.control, name: "key.codex_key_config.account_id" }) as EnvVar | undefined; + + const isConnected = useMemo(() => { + return Boolean( + watchedRefreshToken?.value || + watchedRefreshToken?.env_var || + watchedAccessToken?.value || + watchedAccessToken?.env_var || + watchedAccountID?.value || + watchedAccountID?.env_var, + ); + }, [watchedAccessToken, watchedAccountID, watchedRefreshToken]); + + const keyStatus = useMemo(() => { + if (isConfigManaged) { + return { + label: "Config managed", + description: "This key is managed from config.json, so the UI cannot change its authentication state.", + variant: "secondary" as const, + }; + } + if (session?.status === "pending") { + return { + label: "Waiting for authorization", + description: "The sign-in flow has started. Complete the OpenAI verification step to finish connecting this key.", + variant: "secondary" as const, + }; + } + if (session?.status === "authorized" || isConnected) { + return { + label: "Connected", + description: "Bifrost has a stored Codex credential for this key and can use it for requests.", + variant: "default" as const, + }; + } + if (session?.status === "failed" || session?.status === "expired" || session?.status === "cancelled") { + return { + label: "Not connected", + description: + session.last_error || `The last authorization attempt ended as ${session.status}. Start a new connection flow to continue.`, + variant: "destructive" as const, + }; + } + return { + label: "Not connected", + description: "This key does not have a stored Codex credential yet.", + variant: "outline" as const, + }; + }, [isConfigManaged, isConnected, session]); + + const syncFormFromProvider = useCallback(async () => { + const updatedProvider = await getProvider(providerName).unwrap(); + const updatedKey = updatedProvider.keys.find((key) => key.id === keyId); + if (updatedKey?.codex_key_config) { + form.setValue("key.codex_key_config", updatedKey.codex_key_config, { shouldDirty: false }); + } + }, [form, getProvider, keyId, providerName]); + + useEffect(() => { + if (!session || session.status !== "pending") { + return; + } + const timer = window.setInterval(async () => { + try { + const nextSession = await getCodexAuthSession(session.id).unwrap(); + setSession(nextSession); + if (nextSession.status === "authorized") { + setStatusMessage("Authorization successful"); + await syncFormFromProvider(); + } + if (nextSession.status === "failed" || nextSession.status === "expired" || nextSession.status === "cancelled") { + setStatusMessage(nextSession.last_error || `Authorization ${nextSession.status}`); + } + } catch (error) { + setStatusMessage(getErrorMessage(error)); + } + }, 2000); + return () => window.clearInterval(timer); + }, [getCodexAuthSession, session, syncFormFromProvider]); + + const ensureKeyID = useCallback( + async (authMethod: "device") => { + form.setValue("key.codex_key_config.auth_method", authMethod, { shouldDirty: true }); + if (!isEditing && onEnsurePersisted) { + const persistedKeyID = await onEnsurePersisted(authMethod); + if (!persistedKeyID) { + throw new Error("Key name is required before starting Codex authentication"); + } + return persistedKeyID; + } + return keyId; + }, + [form, isEditing, keyId, onEnsurePersisted], + ); + + const beginDeviceFlow = async () => { + setStatusMessage(null); + const resolvedKeyID = await ensureKeyID("device"); + const nextSession = await startDeviceAuth(resolvedKeyID).unwrap(); + setSession(nextSession); + setIsOpen(true); + if (nextSession.verification_uri) { + popupRef.current = window.open( + nextSession.verification_uri, + "codex_device_auth", + "width=640,height=760,resizable=yes,scrollbars=yes", + ); + } + }; + + const handleCancel = async () => { + if (session) { + await cancelCodexAuthSession(session.id) + .unwrap() + .catch(() => undefined); + } + setIsOpen(false); + setSession(null); + setStatusMessage(null); + }; + + const handleCloseDialog = () => { + if (popupRef.current && !popupRef.current.closed) { + popupRef.current.close(); + } + setIsOpen(false); + }; + + if (isConfigManaged) { + return ( + + Interactive auth disabled + + Codex browser/device authentication is disabled for keys managed from `config.json`. Update the Codex credentials in config mode + instead. + + + ); + } + + return ( +
+
+
+
Key Status
+

{keyStatus.description}

+
+ {keyStatus.label} +
+
+
Interactive Authentication
+

+ Click connect, sign in with OpenAI, and enter the verification code shown here. Bifrost stores the resulting Codex credentials for + you. +

+
+ {!isEditing ? ( + + Ready to connect + + Click connect below. If this is a new key, Bifrost will save the draft automatically before opening the sign-in link and code + flow. + + + ) : null} +
+ +
+ {isConnected ?

This key already has Codex credentials configured.

: null} + + setIsOpen(open)}> + + + Connect ChatGPT Plus/Pro + Open the verification link, sign in, then enter the code below on the OpenAI page. + + +
+ {session?.verification_uri ? ( +
+
Step 1: Open sign-in link
+ + Open verification page + + +
+ ) : null} + + {session?.user_code ? ( +
+
Step 2: Enter this code
+ +
+ ) : null} + +
+ Status: {session?.status ?? "pending"} + {statusMessage ?
{statusMessage}
: null} +
+
+ + + {session?.status === "pending" ? ( + + ) : null} + + +
+
+
+ ); +} diff --git a/ui/app/workspace/providers/fragments/index.ts b/ui/app/workspace/providers/fragments/index.ts index 9f6d506fe2..cb8bcc9779 100644 --- a/ui/app/workspace/providers/fragments/index.ts +++ b/ui/app/workspace/providers/fragments/index.ts @@ -1,5 +1,6 @@ export { AllowedRequestsFields } from "./allowedRequestsFields"; export { BetaHeadersFormFragment } from "./betaHeadersFormFragment"; +export { CodexAuthControls } from "./codexAuthControls"; export { ApiKeyFormFragment } from "./apiKeysFormFragment"; export { ApiStructureFormFragment } from "./apiStructureFormFragment"; export { DebuggingFormFragment } from "./debuggingFormFragment"; diff --git a/ui/app/workspace/providers/views/providerKeyForm.tsx b/ui/app/workspace/providers/views/providerKeyForm.tsx index f9a550fc83..8885c12ed1 100644 --- a/ui/app/workspace/providers/views/providerKeyForm.tsx +++ b/ui/app/workspace/providers/views/providerKeyForm.tsx @@ -3,7 +3,7 @@ import { ConfigSyncAlert } from "@/components/ui/configSyncAlert"; import { Form } from "@/components/ui/form"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; import { getErrorMessage, useUpdateProviderMutation } from "@/lib/store"; -import { ModelProvider } from "@/lib/types/config"; +import { DefaultCodexKeyConfig, ModelProvider } from "@/lib/types/config"; import { modelProviderKeySchema } from "@/lib/types/schemas"; import { zodResolver } from "@hookform/resolvers/zod"; import { Save } from "lucide-react"; @@ -30,6 +30,7 @@ type ProviderKeyFormValues = z.infer; export default function ProviderKeyForm({ provider, keyIndex, onCancel, onSave }: Props) { const [updateProvider, { isLoading: isUpdatingProvider }] = useUpdateProviderMutation(); const isEditing = provider?.keys?.[keyIndex] !== undefined; + const isCodex = provider.name === "codex"; const currentKey = provider?.keys?.[keyIndex]; const form = useForm({ @@ -37,14 +38,17 @@ export default function ProviderKeyForm({ provider, keyIndex, onCancel, onSave } mode: "onChange", reValidateMode: "onChange", defaultValues: { - key: (provider?.keys?.[keyIndex] as ProviderKeyFormValues) ?? { - id: uuid(), - name: "", - models: [], - blacklisted_models: [], - weight: 1.0, - enabled: true, - }, + key: + (provider?.keys?.[keyIndex] as ProviderKeyFormValues) ?? + ({ + id: uuid(), + name: "", + models: [], + blacklisted_models: [], + weight: 1.0, + enabled: true, + ...(provider.name === "codex" ? { codex_key_config: { ...DefaultCodexKeyConfig, auth_method: "device" } } : {}), + } as ProviderKeyFormValues), }, }); @@ -65,14 +69,41 @@ export default function ProviderKeyForm({ provider, keyIndex, onCancel, onSave } return null; }, [form?.formState.errors, form?.formState.isValid, form?.formState.isDirty]); + const persistDraftKey = useCallback( + async (authMethod?: "device" | "manual") => { + if (authMethod) { + form.setValue("key.codex_key_config.auth_method", authMethod, { shouldDirty: true }); + } + + const isValid = await form.trigger(["key.name"]); + if (!isValid) { + return null; + } + + const value = modelProviderKeySchema.parse(form.getValues("key")); + if (provider.name === "codex") { + value.codex_key_config = { + ...DefaultCodexKeyConfig, + ...value.codex_key_config, + }; + } + const keys = provider.keys ?? []; + const normalizedValue = value as ProviderKeyFormValues; + const updatedKeys = [...keys.slice(0, keyIndex), normalizedValue, ...keys.slice(keyIndex + 1)] as typeof provider.keys; + const updatedProvider = await updateProvider({ + ...provider, + keys: updatedKeys, + }).unwrap(); + + const persistedKey = updatedProvider.keys[keyIndex] ?? value; + form.reset({ key: persistedKey }); + return persistedKey.id; + }, + [form, keyIndex, provider, updateProvider], + ); + const onSubmit = (value: any) => { - const keys = provider.keys ?? []; - const updatedKeys = [...keys.slice(0, keyIndex), value.key, ...keys.slice(keyIndex + 1)]; - updateProvider({ - ...provider, - keys: updatedKeys, - }) - .unwrap() + persistDraftKey(value.key?.codex_key_config?.auth_method) .then(() => { onSave(); }) @@ -86,8 +117,15 @@ export default function ProviderKeyForm({ provider, keyIndex, onCancel, onSave } return (
- - {isEditing && currentKey?.config_hash && } + + {isEditing && !isCodex && currentKey?.config_hash && }