diff --git a/core/bifrost.go b/core/bifrost.go index 34eadfeac2..00270990a7 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -5547,6 +5547,25 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas // Tells the logging plugin whether to persist raw bytes in log records. req.Context.SetValue(schemas.BifrostContextKeyShouldStoreRawInLogs, effectiveStore) + // ChatGPT OAuth: extract Bearer token from request headers and inject as direct key. + // This bypasses the need for allow_direct_keys — the provider's chatgpt_oauth flag + // is sufficient authorization to forward the caller's token. Logic lives in the + // openai package to keep all ChatGPT-specific code in one module. + if config.OpenAIConfig != nil && config.OpenAIConfig.ChatGPTOAuth { + if _, alreadySet := req.Context.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key); !alreadySet { + if headers, ok := req.Context.Value(schemas.BifrostContextKeyRequestHeaders).(map[string]string); ok { + if token := openai.ExtractChatGPTOAuthBearerToken(headers); token != "" { + req.Context.SetValue(schemas.BifrostContextKeyDirectKey, schemas.Key{ + ID: openai.ChatGPTOAuthDirectKeyID, + Value: *schemas.NewEnvVar(token), + Models: []string{}, + Weight: 1.0, + }) + } + } + } + } + var keys []schemas.Key // keyProvider is passed to executeRequestWithRetries to manage key selection and rotation. // It is nil when no key is required (e.g. 
providerRequiresKey=false) or for multi-key diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index 2686084531..03753eafa4 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -824,6 +824,7 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post nil, nil, nil, + nil, provider.logger, postHookSpanFinalizer, ) diff --git a/core/providers/fireworks/fireworks.go b/core/providers/fireworks/fireworks.go index 827d1777df..f86f7e845f 100644 --- a/core/providers/fireworks/fireworks.go +++ b/core/providers/fireworks/fireworks.go @@ -80,7 +80,6 @@ func (provider *FireworksProvider) ListModels(ctx *schemas.BifrostContext, keys ) } - // TextCompletion performs a text completion request to the Fireworks AI API. func (provider *FireworksProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionRequest( @@ -188,6 +187,7 @@ func (provider *FireworksProvider) Responses(ctx *schemas.BifrostContext, key sc nil, nil, provider.logger, + nil, ) } @@ -212,6 +212,7 @@ func (provider *FireworksProvider) ResponsesStream(ctx *schemas.BifrostContext, nil, nil, nil, + nil, provider.logger, postHookSpanFinalizer, ) diff --git a/core/providers/openai/chatgpt_oauth.go b/core/providers/openai/chatgpt_oauth.go new file mode 100644 index 0000000000..a121d8c1d8 --- /dev/null +++ b/core/providers/openai/chatgpt_oauth.go @@ -0,0 +1,353 @@ +package openai + +import ( + "encoding/base64" + "errors" + "fmt" + "strings" + + "github.com/bytedance/sonic" + schemas "github.com/maximhq/bifrost/core/schemas" +) + +// errChatGPTOAuthRequiresStreaming is returned when a non-streaming Responses +// request is issued against a chatgpt_oauth-enabled provider. The ChatGPT +// backend only accepts stream=true so the non-streaming path is unsupported. 
+var errChatGPTOAuthRequiresStreaming = errors.New("chatgpt_oauth requires streaming /responses") + +// ChatGPTOAuthDefaultBaseURL is the default base URL for ChatGPT's backend API. +// When chatgpt_oauth is enabled and no custom base URL is set, this is used. +const ChatGPTOAuthDefaultBaseURL = "https://chatgpt.com/backend-api/codex" + +// ChatGPT OAuth Route Map +// +// The ChatGPT backend API (chatgpt.com/backend-api/codex) uses different paths +// from the standard OpenAI API (api.openai.com/v1). When chatgpt_oauth is enabled, +// the /v1 prefix is stripped. Routes supported by the ChatGPT backend: +// +// Standard OpenAI Path → ChatGPT Backend Path Method Notes +// ───────────────────────────────────────────────────────────────────────────────── +// /v1/responses → /responses POST(SSE) Primary inference +// /v1/responses (WS upgrade) → /responses (WS upgrade) WebSocket Preferred transport, falls back to SSE +// /v1/responses/compact → /responses/compact POST Context compaction (OpenAI+Azure only) +// /v1/responses/input_tokens → /responses/input_tokens POST Token counting +// /v1/models → /models?client_version= GET Returns {models:[{slug}]} format +// /v1/realtime/calls → /realtime/calls POST Voice/realtime (creates WebRTC call) +// /v1/realtime → /realtime WebSocket Voice/realtime session +// N/A → /memories/trace_summarize POST Memory summarization +// N/A → /files POST File upload (note: NOT under /codex/) +// N/A → /files/{id}/uploaded POST File upload completion +// +// Required headers on every request: +// - Authorization: Bearer (handled by direct key passthrough) +// - chatgpt-account-id: (extracted from JWT, added here) +// - OpenAI-Beta: responses=experimental (added here) +// +// Required body mutations for /responses: +// - instructions: must exist (default "") +// - store: must be false +// - max_output_tokens: must be deleted +// - stream: must be true (backend only accepts streaming) + +// ChatGPTOAuthClientVersionFallback is injected on 
outbound /models requests
// when the inbound caller didn't supply a ?client_version= query param. The
// ChatGPT backend requires the parameter to exist on /models but tolerates any
// value. If the inbound /v1/models?client_version=... query string survives the
// transport and reaches chatGPTOAuthPath, the caller's value is kept and this
// fallback is never used.
// Matches the openai-oauth proxy fallback.
const ChatGPTOAuthClientVersionFallback = "0.111.0"

// ChatGPTOAuthDirectKeyID is the key ID used when Bifrost auto-injects a Bearer
// token from the inbound Authorization header as a direct key.
const ChatGPTOAuthDirectKeyID = "chatgpt-oauth"

// ExtractChatGPTOAuthBearerToken pulls a Bearer token out of a request-header
// map, looking the "authorization" name up case-insensitively. It returns ""
// when the header is missing, empty, or carries a non-Bearer scheme. Exported
// so core/bifrost.go can reuse it on the direct-key auto-inject path.
func ExtractChatGPTOAuthBearerToken(headers map[string]string) string {
	if headers == nil {
		return ""
	}
	authHeader, found := headers["authorization"]
	if !found {
		// The caller may not have lowercased header names; scan case-insensitively.
		for name, value := range headers {
			if strings.EqualFold(name, "authorization") {
				authHeader = value
				found = true
				break
			}
		}
	}
	if !found || authHeader == "" {
		return ""
	}
	const bearerPrefix = "bearer "
	if !strings.HasPrefix(strings.ToLower(authHeader), bearerPrefix) {
		return ""
	}
	return strings.TrimSpace(authHeader[len(bearerPrefix):])
}

// chatGPTOAuthPath rewrites a standard OpenAI /v1/... path for the ChatGPT
// backend: the /v1 prefix is dropped, any caller-supplied query string is kept,
// and /models gains a client_version parameter when the caller did not send one
// (the backend requires it to be present). Paths that don't start with /v1 pass
// through unchanged.
func chatGPTOAuthPath(standardPath string) string {
	// Separate the query string so caller-supplied parameters survive the rewrite.
	pathOnly, query, hasQuery := strings.Cut(standardPath, "?")

	mapped := pathOnly
	switch {
	case pathOnly == "/v1":
		mapped = "/"
	case strings.HasPrefix(pathOnly, "/v1/"):
		mapped = strings.TrimPrefix(pathOnly, "/v1") // keeps the leading "/"
	}

	if mapped != "/models" {
		if hasQuery {
			return mapped + "?" + query
		}
		return mapped
	}

	// The ChatGPT backend insists on a client_version query parameter for /models.
	switch {
	case !hasQuery:
		return mapped + "?client_version=" + ChatGPTOAuthClientVersionFallback
	case queryContainsKey(query, "client_version"):
		return mapped + "?" + query // caller supplied one; preserve it
	default:
		return mapped + "?" + query + "&client_version=" + ChatGPTOAuthClientVersionFallback
	}
}

// queryContainsKey reports whether the given raw query string contains the named key.
// Does not URL-decode — callers pass already-valid query strings.
func queryContainsKey(rawQuery, key string) bool {
	for _, pair := range strings.Split(rawQuery, "&") {
		if name, _, _ := strings.Cut(pair, "="); name == key {
			return true
		}
	}
	return false
}

// chatGPTOAuthWebSocketURL derives the upstream WebSocket URL for the ChatGPT
// backend: http(s):// becomes ws(s)://, and the /v1 prefix is stripped from the
// path via chatGPTOAuthPath.
func chatGPTOAuthWebSocketURL(baseURL, standardPath string) string {
	wsBase := strings.Replace(baseURL, "https://", "wss://", 1)
	wsBase = strings.Replace(wsBase, "http://", "ws://", 1)
	return wsBase + chatGPTOAuthPath(standardPath)
}

// chatGPTOAuthWebSocketHeaders builds the OAuth-specific headers required for the
// ChatGPT upstream WebSocket connection: Authorization (Bearer from key),
// chatgpt-account-id (extracted from JWT), and OpenAI-Beta.
+// +// It does NOT inject identity defaults (originator, version). Those defaults are +// the responsibility of the caller — mergeClientWSHeaders in wsresponses.go fills +// them in only when the client-sent headers do not already carry those values. +// This keeps the two concerns separate: this function owns OAuth credentials; +// mergeClientWSHeaders owns the identity-fallback policy. +// +// Merge order (lowest → highest priority): +// 1. existingExtraHeaders — provider-level static extra headers from config +// 2. OAuth headers — Authorization (Bearer from key), chatgpt-account-id +// (extracted from JWT), OpenAI-Beta. These always win. +// +// forwardedHeaders is accepted for API compatibility but is no longer used here; +// pass nil or an empty map. +func chatGPTOAuthWebSocketHeaders(key schemas.Key, existingExtraHeaders map[string]string, forwardedHeaders map[string]string, logger schemas.Logger) map[string]string { + authHeader := map[string]string{"Authorization": "Bearer " + key.Value.GetValue()} + accountID, err := extractChatGPTAccountID(key.Value.GetValue()) + if err != nil { + if logger != nil { + logger.Warn("chatgpt_oauth: failed to extract account ID for WebSocket: %v", err) + } + // OAuth Authorization still wins; skip chatgpt-account-id and OpenAI-Beta. + return mergeHeadersCaseInsensitive(existingExtraHeaders, authHeader) + } + oauth := chatGPTOAuthExtraHeaders(accountID) + oauth["Authorization"] = authHeader["Authorization"] + + // Merge order: existingExtraHeaders → oauth (oauth always wins) + return mergeHeadersCaseInsensitive(existingExtraHeaders, oauth) +} + +// mapContainsKeyCI reports whether m contains a key that matches target +// case-insensitively. 
+func mapContainsKeyCI(m map[string]string, target string) bool { + for k := range m { + if strings.EqualFold(k, target) { + return true + } + } + return false +} + +// extractChatGPTAccountID decodes the JWT access token payload and extracts +// the chatgpt_account_id from the "https://api.openai.com/auth" claim. +// No signature verification is performed — we only need the claim value. +func extractChatGPTAccountID(accessToken string) (string, error) { + if accessToken == "" { + return "", fmt.Errorf("empty access token") + } + + parts := strings.Split(accessToken, ".") + if len(parts) != 3 { + return "", fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) + } + + // base64url decode the payload (second segment) + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("failed to decode JWT payload: %w", err) + } + + var claims map[string]interface{} + if err := sonic.Unmarshal(payload, &claims); err != nil { + return "", fmt.Errorf("failed to parse JWT claims: %w", err) + } + + authClaim, ok := claims["https://api.openai.com/auth"] + if !ok { + return "", fmt.Errorf("missing https://api.openai.com/auth claim in JWT") + } + + authMap, ok := authClaim.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("https://api.openai.com/auth claim is not an object") + } + + accountID, ok := authMap["chatgpt_account_id"].(string) + if !ok || accountID == "" { + return "", fmt.Errorf("chatgpt_account_id not found or empty in JWT auth claim") + } + + // Sanitize: reject account IDs containing newlines or carriage returns + // to prevent HTTP header injection attacks via crafted JWTs. 
+ if strings.ContainsAny(accountID, "\r\n") { + return "", fmt.Errorf("chatgpt_account_id contains invalid characters") + } + + return accountID, nil +} + +// transformChatGPTResponsesBody modifies the JSON request body for the ChatGPT backend API: +// - ensures "instructions" field exists (defaults to "") +// - forces "store" to false (the backend rejects store=true for OAuth callers) +// - deletes "max_output_tokens" +// - forces "stream" to true (the ChatGPT backend API only accepts streaming requests) +func transformChatGPTResponsesBody(body []byte) ([]byte, error) { + var data map[string]interface{} + if err := sonic.Unmarshal(body, &data); err != nil { + return nil, fmt.Errorf("failed to parse request body: %w", err) + } + + // Ensure instructions field exists + if _, ok := data["instructions"]; !ok { + data["instructions"] = "" + } + + // Force store to false — the ChatGPT backend API rejects store=true for OAuth + // callers regardless of caller intent. + data["store"] = false + + // Remove max_output_tokens + delete(data, "max_output_tokens") + + // Force stream to true — the ChatGPT backend API only accepts streaming + data["stream"] = true + + return sonic.Marshal(data) +} + +// chatGPTOAuthExtraHeaders returns the extra headers required for ChatGPT OAuth requests. +func chatGPTOAuthExtraHeaders(accountID string) map[string]string { + return map[string]string{ + "chatgpt-account-id": accountID, + "OpenAI-Beta": "responses=experimental", + } +} + +// chatGPTOAuthPrepare extracts the account ID from the bearer token, builds the +// merged extra headers (OAuth-specific headers merged with any existing headers), +// and maps the standard OpenAI path to the ChatGPT backend path. +// This is the single entry point for all ChatGPT OAuth header/path logic — +// openai.go calls this instead of duplicating the logic. 
+func chatGPTOAuthPrepare(key schemas.Key, existingExtraHeaders map[string]string, standardPath string, logger schemas.Logger) (extraHeaders map[string]string, path string, err error) { + accountID, err := extractChatGPTAccountID(key.Value.GetValue()) + if err != nil { + return nil, "", err + } + return mergeHeadersCaseInsensitive(existingExtraHeaders, chatGPTOAuthExtraHeaders(accountID)), chatGPTOAuthPath(standardPath), nil +} + +// chatGPTOAuthMergeHeaders merges ChatGPT OAuth headers (chatgpt-account-id, OpenAI-Beta) +// into the existing extraHeaders. Safe to call unconditionally — returns existingExtraHeaders +// unchanged when enabled=false or when JWT extraction fails (logged). +// Use for request types that don't need body transformation (ListModels, ChatCompletion, etc). +func chatGPTOAuthMergeHeaders(enabled bool, key schemas.Key, existingExtraHeaders map[string]string, logger schemas.Logger) map[string]string { + if !enabled { + return existingExtraHeaders + } + accountID, err := extractChatGPTAccountID(key.Value.GetValue()) + if err != nil { + if logger != nil { + logger.Warn("chatgpt_oauth: failed to extract account ID: %v", err) + } + return existingExtraHeaders + } + return mergeHeadersCaseInsensitive(existingExtraHeaders, chatGPTOAuthExtraHeaders(accountID)) +} + +// chatGPTOAuthApplyRequest is a convenience wrapper that applies ChatGPT OAuth +// transformations for the Responses request: merged headers + body transformer. +// Path mapping is handled separately by buildRequestURL, which auto-strips /v1 +// when chatgpt_oauth is enabled. +// If enabled is false, returns the inputs unchanged and nil bodyTransformer. +// If enabled is true and JWT extraction fails, returns an error so the caller +// can surface a structured "invalid ChatGPT OAuth token" error rather than +// letting the upstream reject a mutated body with no account-id header. 
+func chatGPTOAuthApplyRequest(enabled bool, key schemas.Key, existingExtraHeaders map[string]string, logger schemas.Logger) (headers map[string]string, bodyTransformer func([]byte) ([]byte, error), err error) { + if !enabled { + return existingExtraHeaders, nil, nil + } + accountID, extractErr := extractChatGPTAccountID(key.Value.GetValue()) + if extractErr != nil { + return nil, nil, fmt.Errorf("invalid ChatGPT OAuth token: %w", extractErr) + } + oauthHeaders := chatGPTOAuthExtraHeaders(accountID) + merged := mergeHeadersCaseInsensitive(existingExtraHeaders, oauthHeaders) + return merged, transformChatGPTResponsesBody, nil +} + +// mergeHeadersCaseInsensitive merges two header maps, treating header names +// case-insensitively. OAuth overrides always win. Keys from the OAuth map are +// preserved as-is; duplicates from existingHeaders (by case-insensitive match) +// are dropped. This prevents both "openai-beta" and "OpenAI-Beta" from ending +// up in the result map where Go's unordered iteration would cause intermittent +// behavior in SetExtraHeaders. +func mergeHeadersCaseInsensitive(existingHeaders, oauthHeaders map[string]string) map[string]string { + // Build case-insensitive lookup of OAuth keys so we can skip duplicates from existingHeaders. 
+ oauthKeysLower := make(map[string]bool, len(oauthHeaders)) + for k := range oauthHeaders { + oauthKeysLower[strings.ToLower(k)] = true + } + merged := make(map[string]string, len(existingHeaders)+len(oauthHeaders)) + for k, v := range existingHeaders { + if oauthKeysLower[strings.ToLower(k)] { + continue // OAuth override wins + } + merged[k] = v + } + for k, v := range oauthHeaders { + merged[k] = v + } + return merged +} diff --git a/core/providers/openai/chatgpt_oauth_test.go b/core/providers/openai/chatgpt_oauth_test.go new file mode 100644 index 0000000000..0644e9506f --- /dev/null +++ b/core/providers/openai/chatgpt_oauth_test.go @@ -0,0 +1,854 @@ +package openai + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "testing" + + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// buildTestJWT creates a minimal JWT with the given payload for testing. +// No signature verification is needed — we only decode the payload. +func buildTestJWT(payload map[string]interface{}) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`)) + payloadBytes, err := json.Marshal(payload) + if err != nil { + panic(fmt.Sprintf("buildTestJWT: failed to marshal test payload: %v", err)) + } + payloadB64 := base64.RawURLEncoding.EncodeToString(payloadBytes) + sig := base64.RawURLEncoding.EncodeToString([]byte("fake-signature")) + return header + "." + payloadB64 + "." + sig +} + +// captureLogger is a minimal schemas.Logger stub that records Debug and Warn +// calls for assertion in tests. All other methods are no-ops. 
+type captureLogger struct { + debugs []string + warns []string +} + +func (l *captureLogger) Debug(msg string, args ...any) { + l.debugs = append(l.debugs, fmt.Sprintf(msg, args...)) +} +func (l *captureLogger) Info(msg string, args ...any) {} +func (l *captureLogger) Warn(msg string, args ...any) { + l.warns = append(l.warns, fmt.Sprintf(msg, args...)) +} +func (l *captureLogger) Error(msg string, args ...any) {} +func (l *captureLogger) Fatal(msg string, args ...any) {} +func (l *captureLogger) SetLevel(level schemas.LogLevel) {} +func (l *captureLogger) SetOutputType(outputType schemas.LoggerOutputType) {} +func (l *captureLogger) LogHTTPRequest(level schemas.LogLevel, msg string) schemas.LogEventBuilder { + return nil +} + +// buildTestJWTRaw creates a JWT with a raw base64url-encoded payload string. +// Useful for testing invalid JSON payloads. +func buildTestJWTRaw(payloadB64 string) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`)) + sig := base64.RawURLEncoding.EncodeToString([]byte("fake-signature")) + return header + "." + payloadB64 + "." 
+ sig +} + +// --------------------------------------------------------------------------- +// extractChatGPTAccountID +// --------------------------------------------------------------------------- + +func TestExtractChatGPTAccountID(t *testing.T) { + t.Run("valid token with account ID", func(t *testing.T) { + token := buildTestJWT(map[string]interface{}{ + "sub": "google-oauth2|12345", + "https://api.openai.com/auth": map[string]interface{}{ + "chatgpt_account_id": "9774aee9-daa9-4327-afe5-3efbeed7e328", + "chatgpt_user_id": "user-FcJBIsPIye2kIwcIet4nIvx4", + }, + }) + accountID, err := extractChatGPTAccountID(token) + require.NoError(t, err) + assert.Equal(t, "9774aee9-daa9-4327-afe5-3efbeed7e328", accountID) + }) + + t.Run("empty token", func(t *testing.T) { + _, err := extractChatGPTAccountID("") + assert.Error(t, err) + assert.Contains(t, err.Error(), "empty access token") + }) + + t.Run("malformed JWT - no dots", func(t *testing.T) { + _, err := extractChatGPTAccountID("not-a-jwt") + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid JWT format") + }) + + t.Run("malformed JWT - two parts only", func(t *testing.T) { + _, err := extractChatGPTAccountID("header.payload") + assert.Error(t, err) + assert.Contains(t, err.Error(), "expected 3 parts, got 2") + }) + + t.Run("malformed JWT - four parts", func(t *testing.T) { + _, err := extractChatGPTAccountID("a.b.c.d") + assert.Error(t, err) + assert.Contains(t, err.Error(), "expected 3 parts, got 4") + }) + + t.Run("malformed JWT - invalid base64 payload", func(t *testing.T) { + _, err := extractChatGPTAccountID("header.!!!invalid!!!.sig") + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode JWT payload") + }) + + t.Run("malformed JWT - payload is not valid JSON", func(t *testing.T) { + notJSON := base64.RawURLEncoding.EncodeToString([]byte("this is not json")) + token := buildTestJWTRaw(notJSON) + _, err := extractChatGPTAccountID(token) + assert.Error(t, err) + 
assert.Contains(t, err.Error(), "failed to parse JWT claims") + }) + + t.Run("missing auth claim", func(t *testing.T) { + token := buildTestJWT(map[string]interface{}{ + "sub": "google-oauth2|12345", + }) + _, err := extractChatGPTAccountID(token) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing https://api.openai.com/auth claim") + }) + + t.Run("auth claim is not an object", func(t *testing.T) { + token := buildTestJWT(map[string]interface{}{ + "https://api.openai.com/auth": "not-an-object", + }) + _, err := extractChatGPTAccountID(token) + assert.Error(t, err) + assert.Contains(t, err.Error(), "claim is not an object") + }) + + t.Run("auth claim is an array", func(t *testing.T) { + token := buildTestJWT(map[string]interface{}{ + "https://api.openai.com/auth": []string{"a", "b"}, + }) + _, err := extractChatGPTAccountID(token) + assert.Error(t, err) + assert.Contains(t, err.Error(), "claim is not an object") + }) + + t.Run("missing account_id in auth claim", func(t *testing.T) { + token := buildTestJWT(map[string]interface{}{ + "https://api.openai.com/auth": map[string]interface{}{ + "chatgpt_user_id": "user-FcJBIsPIye2kIwcIet4nIvx4", + }, + }) + _, err := extractChatGPTAccountID(token) + assert.Error(t, err) + assert.Contains(t, err.Error(), "chatgpt_account_id not found or empty") + }) + + t.Run("account_id is not a string", func(t *testing.T) { + token := buildTestJWT(map[string]interface{}{ + "https://api.openai.com/auth": map[string]interface{}{ + "chatgpt_account_id": 12345, + }, + }) + _, err := extractChatGPTAccountID(token) + assert.Error(t, err) + assert.Contains(t, err.Error(), "chatgpt_account_id not found or empty") + }) + + t.Run("empty account_id", func(t *testing.T) { + token := buildTestJWT(map[string]interface{}{ + "https://api.openai.com/auth": map[string]interface{}{ + "chatgpt_account_id": "", + }, + }) + _, err := extractChatGPTAccountID(token) + assert.Error(t, err) + assert.Contains(t, err.Error(), "chatgpt_account_id not 
found or empty") + }) + + // Security: HTTP header injection prevention + t.Run("rejects account_id with newline (header injection)", func(t *testing.T) { + token := buildTestJWT(map[string]interface{}{ + "https://api.openai.com/auth": map[string]interface{}{ + "chatgpt_account_id": "legit-id\r\nX-Injected: malicious", + }, + }) + _, err := extractChatGPTAccountID(token) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid characters") + }) + + t.Run("rejects account_id with bare newline", func(t *testing.T) { + token := buildTestJWT(map[string]interface{}{ + "https://api.openai.com/auth": map[string]interface{}{ + "chatgpt_account_id": "legit-id\ninjection", + }, + }) + _, err := extractChatGPTAccountID(token) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid characters") + }) + + t.Run("rejects account_id with bare carriage return", func(t *testing.T) { + token := buildTestJWT(map[string]interface{}{ + "https://api.openai.com/auth": map[string]interface{}{ + "chatgpt_account_id": "legit-id\rinjection", + }, + }) + _, err := extractChatGPTAccountID(token) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid characters") + }) +} + +// --------------------------------------------------------------------------- +// transformChatGPTResponsesBody +// --------------------------------------------------------------------------- + +func TestTransformChatGPTResponsesBody(t *testing.T) { + t.Run("adds instructions, store, stream and removes max_output_tokens", func(t *testing.T) { + input := []byte(`{"model":"gpt-5.4","input":[{"role":"user","content":"hello"}],"max_output_tokens":4096}`) + output, err := transformChatGPTResponsesBody(input) + require.NoError(t, err) + + var result map[string]interface{} + require.NoError(t, json.Unmarshal(output, &result)) + + assert.Equal(t, "", result["instructions"], "instructions should default to empty string") + assert.Equal(t, false, result["store"], "store should default to false") + 
assert.Equal(t, true, result["stream"], "stream must be forced to true") + assert.Equal(t, "gpt-5.4", result["model"], "model should be preserved") + assert.NotNil(t, result["input"], "input should be preserved") + _, hasMaxOutputTokens := result["max_output_tokens"] + assert.False(t, hasMaxOutputTokens, "max_output_tokens must be removed") + }) + + t.Run("preserves existing instructions", func(t *testing.T) { + input := []byte(`{"model":"gpt-5.4","instructions":"You are a coding assistant"}`) + output, err := transformChatGPTResponsesBody(input) + require.NoError(t, err) + + var result map[string]interface{} + require.NoError(t, json.Unmarshal(output, &result)) + + assert.Equal(t, "You are a coding assistant", result["instructions"]) + }) + + t.Run("forces store false even when caller sets true", func(t *testing.T) { + input := []byte(`{"model":"gpt-5.4","store":true}`) + output, err := transformChatGPTResponsesBody(input) + require.NoError(t, err) + + var result map[string]interface{} + require.NoError(t, json.Unmarshal(output, &result)) + + assert.Equal(t, false, result["store"], "ChatGPT backend rejects store=true for OAuth callers; transformer must force false") + }) + + t.Run("keeps store false when caller already sets false", func(t *testing.T) { + input := []byte(`{"model":"gpt-5.4","store":false}`) + output, err := transformChatGPTResponsesBody(input) + require.NoError(t, err) + + var result map[string]interface{} + require.NoError(t, json.Unmarshal(output, &result)) + + assert.Equal(t, false, result["store"]) + }) + + t.Run("forces stream true when not present", func(t *testing.T) { + input := []byte(`{"model":"gpt-5.4","input":"hello"}`) + output, err := transformChatGPTResponsesBody(input) + require.NoError(t, err) + + var result map[string]interface{} + require.NoError(t, json.Unmarshal(output, &result)) + + assert.Equal(t, true, result["stream"]) + }) + + t.Run("overrides stream false with true", func(t *testing.T) { + input := 
[]byte(`{"model":"gpt-5.4","input":"hello","stream":false}`) + output, err := transformChatGPTResponsesBody(input) + require.NoError(t, err) + + var result map[string]interface{} + require.NoError(t, json.Unmarshal(output, &result)) + + assert.Equal(t, true, result["stream"], "stream must be forced to true even if caller set false") + }) + + t.Run("preserves stream true", func(t *testing.T) { + input := []byte(`{"model":"gpt-5.4","stream":true}`) + output, err := transformChatGPTResponsesBody(input) + require.NoError(t, err) + + var result map[string]interface{} + require.NoError(t, json.Unmarshal(output, &result)) + + assert.Equal(t, true, result["stream"]) + }) + + t.Run("preserves all other fields", func(t *testing.T) { + input := []byte(`{"model":"gpt-5.4","input":[{"role":"user","content":"hi"}],"temperature":0.7,"top_p":0.9,"tools":[{"type":"function"}]}`) + output, err := transformChatGPTResponsesBody(input) + require.NoError(t, err) + + var result map[string]interface{} + require.NoError(t, json.Unmarshal(output, &result)) + + assert.Equal(t, "gpt-5.4", result["model"]) + assert.Equal(t, float64(0.7), result["temperature"]) + assert.Equal(t, float64(0.9), result["top_p"]) + assert.NotNil(t, result["input"]) + assert.NotNil(t, result["tools"]) + }) + + t.Run("handles empty JSON object", func(t *testing.T) { + input := []byte(`{}`) + output, err := transformChatGPTResponsesBody(input) + require.NoError(t, err) + + var result map[string]interface{} + require.NoError(t, json.Unmarshal(output, &result)) + + assert.Equal(t, "", result["instructions"]) + assert.Equal(t, false, result["store"]) + assert.Equal(t, true, result["stream"]) + }) + + t.Run("invalid JSON returns error", func(t *testing.T) { + _, err := transformChatGPTResponsesBody([]byte(`not json`)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse request body") + }) + + t.Run("empty input returns error", func(t *testing.T) { + _, err := transformChatGPTResponsesBody([]byte(``)) 
+ assert.Error(t, err) + }) +} + +// --------------------------------------------------------------------------- +// chatGPTOAuthExtraHeaders +// --------------------------------------------------------------------------- + +func TestChatGPTOAuthExtraHeaders(t *testing.T) { + t.Run("returns correct headers", func(t *testing.T) { + headers := chatGPTOAuthExtraHeaders("9774aee9-daa9-4327-afe5-3efbeed7e328") + + assert.Equal(t, "9774aee9-daa9-4327-afe5-3efbeed7e328", headers["chatgpt-account-id"]) + assert.Equal(t, "responses=experimental", headers["OpenAI-Beta"]) + assert.Len(t, headers, 2) + }) + + t.Run("works with any account ID format", func(t *testing.T) { + headers := chatGPTOAuthExtraHeaders("simple-id") + assert.Equal(t, "simple-id", headers["chatgpt-account-id"]) + }) +} + +// --------------------------------------------------------------------------- +// ChatGPTOAuthDefaultBaseURL constant +// --------------------------------------------------------------------------- + +func TestChatGPTOAuthDefaultBaseURL(t *testing.T) { + assert.Equal(t, "https://chatgpt.com/backend-api/codex", ChatGPTOAuthDefaultBaseURL) +} + +// --------------------------------------------------------------------------- +// chatGPTOAuthPath — route mapping for all documented ChatGPT backend routes +// --------------------------------------------------------------------------- + +func TestChatGPTOAuthPath(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + // Core inference routes + {"responses POST/SSE", "/v1/responses", "/responses"}, + {"responses WebSocket", "/v1/responses", "/responses"}, // same path, different upgrade + {"responses compact", "/v1/responses/compact", "/responses/compact"}, + {"responses input_tokens", "/v1/responses/input_tokens", "/responses/input_tokens"}, + + // Models — appends required client_version query param when not present + {"models no query injects fallback", "/v1/models", "/models?client_version=" + 
ChatGPTOAuthClientVersionFallback}, + // Models — preserves caller-supplied client_version + {"models preserves caller client_version", "/v1/models?client_version=1.2.3", "/models?client_version=1.2.3"}, + // Models — injects fallback alongside other caller query params + {"models preserves other query, adds fallback", "/v1/models?foo=bar", "/models?foo=bar&client_version=" + ChatGPTOAuthClientVersionFallback}, + // Models — preserves multiple params including caller's client_version + {"models preserves all when client_version present", "/v1/models?foo=bar&client_version=9.9.9", "/models?foo=bar&client_version=9.9.9"}, + + // Realtime/voice + {"realtime calls", "/v1/realtime/calls", "/realtime/calls"}, + {"realtime session", "/v1/realtime", "/realtime"}, + // Query-string preservation for non-/models routes + {"preserves query for non-models", "/v1/responses?foo=bar", "/responses?foo=bar"}, + + // Edge cases + {"bare /v1", "/v1", "/"}, + {"already stripped path", "/responses", "/responses"}, + {"non-v1 path passthrough", "/custom/path", "/custom/path"}, + {"empty path", "", ""}, + {"root path", "/", "/"}, + {"v1 without slash", "/v1files", "/v1files"}, // must not strip partial match + + // Files (note: in ChatGPT backend these are NOT under /codex/) + {"files upload", "/v1/files", "/files"}, + {"files uploaded", "/v1/files/file-abc123/uploaded", "/files/file-abc123/uploaded"}, + + // Memory + {"memories trace", "/v1/memories/trace_summarize", "/memories/trace_summarize"}, + + // Batches (standard OpenAI, may not exist on ChatGPT backend but path should still map) + {"batches", "/v1/batches", "/batches"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, chatGPTOAuthPath(tt.input)) + }) + } +} + +// --------------------------------------------------------------------------- +// chatGPTOAuthPrepare — integration of all helpers +// --------------------------------------------------------------------------- + +func 
// TestChatGPTOAuthPrepare exercises the full prepare pipeline: JWT account-ID
// extraction from the key, OAuth header merging, and /v1 path mapping,
// including the invalid-token and empty-token error paths.
func TestChatGPTOAuthPrepare(t *testing.T) {
	validToken := buildTestJWT(map[string]interface{}{
		"https://api.openai.com/auth": map[string]interface{}{
			"chatgpt_account_id": "acct-123",
		},
	})

	t.Run("returns mapped path and merged headers for /v1/responses", func(t *testing.T) {
		key := schemas.Key{Value: schemas.EnvVar{Val: validToken}}
		existing := map[string]string{"X-Custom": "value"}

		headers, path, err := chatGPTOAuthPrepare(key, existing, "/v1/responses", nil)

		require.NoError(t, err)
		assert.Equal(t, "/responses", path)
		assert.Equal(t, "acct-123", headers["chatgpt-account-id"])
		assert.Equal(t, "responses=experimental", headers["OpenAI-Beta"])
		assert.Equal(t, "value", headers["X-Custom"], "existing headers must be preserved")
	})

	t.Run("maps /v1/models with client_version", func(t *testing.T) {
		key := schemas.Key{Value: schemas.EnvVar{Val: validToken}}

		_, path, err := chatGPTOAuthPrepare(key, nil, "/v1/models", nil)
		require.NoError(t, err)
		assert.Equal(t, "/models?client_version="+ChatGPTOAuthClientVersionFallback, path)
	})

	t.Run("maps /v1/responses/compact", func(t *testing.T) {
		key := schemas.Key{Value: schemas.EnvVar{Val: validToken}}

		_, path, err := chatGPTOAuthPrepare(key, nil, "/v1/responses/compact", nil)
		require.NoError(t, err)
		assert.Equal(t, "/responses/compact", path)
	})

	t.Run("maps /v1/responses/input_tokens", func(t *testing.T) {
		key := schemas.Key{Value: schemas.EnvVar{Val: validToken}}

		_, path, err := chatGPTOAuthPrepare(key, nil, "/v1/responses/input_tokens", nil)
		require.NoError(t, err)
		assert.Equal(t, "/responses/input_tokens", path)
	})

	t.Run("maps /v1/realtime/calls", func(t *testing.T) {
		key := schemas.Key{Value: schemas.EnvVar{Val: validToken}}

		_, path, err := chatGPTOAuthPrepare(key, nil, "/v1/realtime/calls", nil)
		require.NoError(t, err)
		assert.Equal(t, "/realtime/calls", path)
	})

	t.Run("maps /v1/files", func(t *testing.T) {
		key := schemas.Key{Value: schemas.EnvVar{Val: validToken}}

		_, path, err := chatGPTOAuthPrepare(key, nil, "/v1/files", nil)
		require.NoError(t, err)
		assert.Equal(t, "/files", path)
	})

	t.Run("oauth headers override existing conflicting headers", func(t *testing.T) {
		key := schemas.Key{Value: schemas.EnvVar{Val: validToken}}
		existing := map[string]string{
			"chatgpt-account-id": "old-id",
			"OpenAI-Beta":        "old-beta",
			"X-Keep":             "keep",
		}

		headers, _, err := chatGPTOAuthPrepare(key, existing, "/v1/responses", nil)

		require.NoError(t, err)
		assert.Equal(t, "acct-123", headers["chatgpt-account-id"], "OAuth header must override existing")
		assert.Equal(t, "responses=experimental", headers["OpenAI-Beta"], "OAuth header must override existing")
		assert.Equal(t, "keep", headers["X-Keep"], "non-conflicting headers preserved")
	})

	t.Run("nil existing headers does not panic", func(t *testing.T) {
		key := schemas.Key{Value: schemas.EnvVar{Val: validToken}}

		headers, path, err := chatGPTOAuthPrepare(key, nil, "/v1/responses", nil)

		require.NoError(t, err)
		assert.Equal(t, "/responses", path)
		assert.Equal(t, "acct-123", headers["chatgpt-account-id"])
		assert.Equal(t, "responses=experimental", headers["OpenAI-Beta"])
		assert.Len(t, headers, 2)
	})

	t.Run("empty existing headers map", func(t *testing.T) {
		key := schemas.Key{Value: schemas.EnvVar{Val: validToken}}

		headers, _, err := chatGPTOAuthPrepare(key, map[string]string{}, "/v1/responses", nil)

		require.NoError(t, err)
		assert.Len(t, headers, 2)
	})

	t.Run("returns error on invalid token", func(t *testing.T) {
		key := schemas.Key{Value: schemas.EnvVar{Val: "not-a-jwt"}}

		_, _, err := chatGPTOAuthPrepare(key, nil, "/v1/responses", nil)

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid JWT format")
	})

	t.Run("returns error on empty token", func(t *testing.T) {
		key := schemas.Key{Value: schemas.EnvVar{Val: ""}}

		_, _, err := chatGPTOAuthPrepare(key, nil, "/v1/responses", nil)

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "empty access token")
	})
}

// ---------------------------------------------------------------------------
// ExtractChatGPTOAuthBearerToken (public helper used by core/bifrost.go)
// ---------------------------------------------------------------------------

// TestExtractChatGPTOAuthBearerToken covers case-insensitive header lookup,
// case-insensitive "Bearer" prefix handling, whitespace trimming, and every
// empty/absent/non-Bearer case (all of which must yield "").
func TestExtractChatGPTOAuthBearerToken(t *testing.T) {
	t.Run("extracts bearer token from lowercased key", func(t *testing.T) {
		headers := map[string]string{"authorization": "Bearer abc123"}
		assert.Equal(t, "abc123", ExtractChatGPTOAuthBearerToken(headers))
	})

	t.Run("extracts bearer token case-insensitively", func(t *testing.T) {
		headers := map[string]string{"Authorization": "Bearer abc123"}
		assert.Equal(t, "abc123", ExtractChatGPTOAuthBearerToken(headers))
	})

	t.Run("accepts mixed-case Bearer prefix", func(t *testing.T) {
		headers := map[string]string{"authorization": "bearer xyz"}
		assert.Equal(t, "xyz", ExtractChatGPTOAuthBearerToken(headers))
	})

	t.Run("trims whitespace", func(t *testing.T) {
		headers := map[string]string{"authorization": "Bearer padded "}
		assert.Equal(t, "padded", ExtractChatGPTOAuthBearerToken(headers))
	})

	t.Run("returns empty when no auth header", func(t *testing.T) {
		assert.Equal(t, "", ExtractChatGPTOAuthBearerToken(map[string]string{"x-other": "v"}))
	})

	t.Run("returns empty when auth is not Bearer", func(t *testing.T) {
		headers := map[string]string{"authorization": "Basic dXNlcjpwYXNz"}
		assert.Equal(t, "", ExtractChatGPTOAuthBearerToken(headers))
	})

	t.Run("returns empty when auth header is empty", func(t *testing.T) {
		headers := map[string]string{"authorization": ""}
		assert.Equal(t, "", ExtractChatGPTOAuthBearerToken(headers))
	})

	t.Run("returns empty when headers is nil", func(t *testing.T) {
		assert.Equal(t, "", ExtractChatGPTOAuthBearerToken(nil))
	})
}
// ---------------------------------------------------------------------------
// chatGPTOAuthMergeHeaders (non-request variant used for headers-only routes)
// ---------------------------------------------------------------------------

// TestChatGPTOAuthMergeHeaders verifies the headers-only merge path: disabled
// passthrough, OAuth header injection, invalid-token fallback (unchanged
// input), and case-insensitive conflict resolution.
func TestChatGPTOAuthMergeHeaders(t *testing.T) {
	validToken := buildTestJWT(map[string]interface{}{
		"https://api.openai.com/auth": map[string]interface{}{"chatgpt_account_id": "acct-xyz"},
	})
	key := schemas.Key{Value: schemas.EnvVar{Val: validToken}}

	t.Run("disabled returns input unchanged", func(t *testing.T) {
		existing := map[string]string{"X-Custom": "v"}
		got := chatGPTOAuthMergeHeaders(false, key, existing, nil)
		assert.Equal(t, existing, got)
	})

	t.Run("enabled merges OAuth headers", func(t *testing.T) {
		existing := map[string]string{"X-Custom": "v"}
		got := chatGPTOAuthMergeHeaders(true, key, existing, nil)
		assert.Equal(t, "acct-xyz", got["chatgpt-account-id"])
		assert.Equal(t, "responses=experimental", got["OpenAI-Beta"])
		assert.Equal(t, "v", got["X-Custom"])
	})

	t.Run("enabled with invalid token returns unchanged headers", func(t *testing.T) {
		existing := map[string]string{"X-Custom": "v"}
		badKey := schemas.Key{Value: schemas.EnvVar{Val: "not-a-jwt"}}
		got := chatGPTOAuthMergeHeaders(true, badKey, existing, nil)
		assert.Equal(t, existing, got)
	})

	t.Run("case-insensitive override drops conflicting existing header", func(t *testing.T) {
		existing := map[string]string{
			"chatgpt-account-id": "stale-id",
			"OPENAI-BETA":        "stale-beta",
			"X-Keep":             "keep",
		}
		got := chatGPTOAuthMergeHeaders(true, key, existing, nil)
		// OAuth values win even when existing had different casing
		assert.Equal(t, "acct-xyz", got["chatgpt-account-id"])
		assert.Equal(t, "responses=experimental", got["OpenAI-Beta"])
		// Stale casing should NOT appear as a duplicate key
		_, hasStaleBeta := got["OPENAI-BETA"]
		assert.False(t, hasStaleBeta, "existing header with conflicting case must be dropped")
		assert.Equal(t, "keep", got["X-Keep"])
	})
}

// ---------------------------------------------------------------------------
// chatGPTOAuthApplyRequest — fail-fast on invalid token
// ---------------------------------------------------------------------------

// TestChatGPTOAuthApplyRequest verifies the request-preparation entry point:
// disabled passthrough (nil transformer), merged headers plus a non-nil body
// transformer when enabled, and a hard error (fail fast) on an invalid token —
// unlike chatGPTOAuthMergeHeaders, which degrades gracefully.
func TestChatGPTOAuthApplyRequest(t *testing.T) {
	validToken := buildTestJWT(map[string]interface{}{
		"https://api.openai.com/auth": map[string]interface{}{"chatgpt_account_id": "acct-apply"},
	})

	t.Run("disabled returns unchanged headers and nil transformer", func(t *testing.T) {
		existing := map[string]string{"X-Custom": "v"}
		headers, transformer, err := chatGPTOAuthApplyRequest(false, schemas.Key{}, existing, nil)
		require.NoError(t, err)
		assert.Equal(t, existing, headers)
		assert.Nil(t, transformer)
	})

	t.Run("enabled with valid token returns merged headers + transformer", func(t *testing.T) {
		key := schemas.Key{Value: schemas.EnvVar{Val: validToken}}
		headers, transformer, err := chatGPTOAuthApplyRequest(true, key, nil, nil)
		require.NoError(t, err)
		assert.Equal(t, "acct-apply", headers["chatgpt-account-id"])
		assert.NotNil(t, transformer)
	})

	t.Run("enabled with invalid token returns error (fail fast)", func(t *testing.T) {
		key := schemas.Key{Value: schemas.EnvVar{Val: "not-a-jwt"}}
		headers, transformer, err := chatGPTOAuthApplyRequest(true, key, nil, nil)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "invalid ChatGPT OAuth token")
		assert.Nil(t, headers)
		assert.Nil(t, transformer)
	})
}
// ---------------------------------------------------------------------------
// chatGPTOAuthWebSocketURL / chatGPTOAuthWebSocketHeaders
// ---------------------------------------------------------------------------

// TestChatGPTOAuthWebSocketURL verifies scheme mapping (https→wss, http→ws)
// combined with the /v1 path rewrite.
func TestChatGPTOAuthWebSocketURL(t *testing.T) {
	t.Run("https with /v1 prefix", func(t *testing.T) {
		got := chatGPTOAuthWebSocketURL("https://chatgpt.com/backend-api/codex", "/v1/responses")
		assert.Equal(t, "wss://chatgpt.com/backend-api/codex/responses", got)
	})

	t.Run("http maps to ws", func(t *testing.T) {
		got := chatGPTOAuthWebSocketURL("http://localhost:8080", "/v1/responses")
		assert.Equal(t, "ws://localhost:8080/responses", got)
	})
}

// TestChatGPTOAuthWebSocketHeaders verifies the WebSocket handshake headers:
// Authorization + OAuth headers on success, auth-only fallback (with optional
// warning log) when the JWT is invalid, and that forwardedHeaders and identity
// defaults are deliberately NOT processed here (mergeClientWSHeaders owns that).
func TestChatGPTOAuthWebSocketHeaders(t *testing.T) {
	validToken := buildTestJWT(map[string]interface{}{
		"https://api.openai.com/auth": map[string]interface{}{"chatgpt_account_id": "acct-ws"},
	})
	key := schemas.Key{Value: schemas.EnvVar{Val: validToken}}

	t.Run("sets Authorization + chatgpt headers", func(t *testing.T) {
		got := chatGPTOAuthWebSocketHeaders(key, nil, nil, nil)
		assert.Equal(t, "Bearer "+validToken, got["Authorization"])
		assert.Equal(t, "acct-ws", got["chatgpt-account-id"])
		assert.Equal(t, "responses=experimental", got["OpenAI-Beta"])
	})

	t.Run("merges extra headers but skips Authorization override", func(t *testing.T) {
		extra := map[string]string{
			"authorization": "Bearer should-be-ignored",
			"X-Custom":      "kept",
		}
		got := chatGPTOAuthWebSocketHeaders(key, extra, nil, nil)
		assert.Equal(t, "Bearer "+validToken, got["Authorization"])
		assert.Equal(t, "kept", got["X-Custom"])
	})

	t.Run("falls back to auth-only when JWT invalid", func(t *testing.T) {
		badKey := schemas.Key{Value: schemas.EnvVar{Val: "not-a-jwt"}}
		got := chatGPTOAuthWebSocketHeaders(badKey, nil, nil, nil)
		assert.Equal(t, "Bearer not-a-jwt", got["Authorization"])
		_, hasAccountID := got["chatgpt-account-id"]
		assert.False(t, hasAccountID, "account-id should not be set when JWT extraction fails")
	})

	t.Run("logs warning when JWT invalid and logger provided", func(t *testing.T) {
		badKey := schemas.Key{Value: schemas.EnvVar{Val: "not-a-jwt"}}
		logger := &captureLogger{}
		got := chatGPTOAuthWebSocketHeaders(badKey, nil, nil, logger)
		assert.Equal(t, "Bearer not-a-jwt", got["Authorization"])
		require.Len(t, logger.warns, 1)
		assert.Contains(t, logger.warns[0], "failed to extract account ID for WebSocket")
	})

	// -----------------------------------------------------------------
	// forwardedHeaders: first-party Codex identity header passthrough
	// -----------------------------------------------------------------

	t.Run("forwardedHeaders parameter is accepted but no longer processed here", func(t *testing.T) {
		// Identity headers (originator, version, user-agent) are now merged by
		// mergeClientWSHeaders in wsresponses.go, not by this function.
		// chatGPTOAuthWebSocketHeaders only returns OAuth-specific headers.
		forwarded := map[string]string{
			"originator": "codex_cli_rs",
			"version":    "0.121.0",
			"user-agent": "codex/0.121.0 (Linux; amd64)",
		}
		got := chatGPTOAuthWebSocketHeaders(key, nil, forwarded, nil)
		// OAuth headers must be present
		assert.Equal(t, "Bearer "+validToken, got["Authorization"])
		assert.Equal(t, "acct-ws", got["chatgpt-account-id"])
		assert.Equal(t, "responses=experimental", got["OpenAI-Beta"])
		// Identity headers are NOT in the output — they come via mergeClientWSHeaders
		_, hasOriginator := got["originator"]
		_, hasVersion := got["version"]
		_, hasUserAgent := got["user-agent"]
		assert.False(t, hasOriginator, "originator is not processed by chatGPTOAuthWebSocketHeaders")
		assert.False(t, hasVersion, "version is not processed by chatGPTOAuthWebSocketHeaders")
		assert.False(t, hasUserAgent, "user-agent is not processed by chatGPTOAuthWebSocketHeaders")
	})

	t.Run("forwarded authorization is ignored (forwardedHeaders no longer processed)", func(t *testing.T) {
		forwarded := map[string]string{
			"authorization": "Bearer client-should-not-appear",
		}
		got := chatGPTOAuthWebSocketHeaders(key, nil, forwarded, nil)
		assert.Equal(t, "Bearer "+validToken, got["Authorization"],
			"provider OAuth token must be present; forwarded authorization is not processed")
	})

	t.Run("empty forwarded map does NOT inject identity defaults (caller's responsibility)", func(t *testing.T) {
		// chatGPTOAuthWebSocketHeaders only injects OAuth headers; identity defaults
		// (originator, version) are the responsibility of mergeClientWSHeaders in
		// wsresponses.go so that real client values always win over any fallback.
		log := &captureLogger{}
		got := chatGPTOAuthWebSocketHeaders(key, nil, map[string]string{}, log)
		_, hasOriginator := got["originator"]
		_, hasVersion := got["version"]
		assert.False(t, hasOriginator, "chatGPTOAuthWebSocketHeaders must NOT inject originator default — that is mergeClientWSHeaders' job")
		assert.False(t, hasVersion, "chatGPTOAuthWebSocketHeaders must NOT inject version default — that is mergeClientWSHeaders' job")
		// No debug logs should be emitted from this function
		assert.Empty(t, log.debugs, "no debug logs expected from chatGPTOAuthWebSocketHeaders")
	})

	t.Run("nil forwarded map does NOT inject identity defaults (caller's responsibility)", func(t *testing.T) {
		got := chatGPTOAuthWebSocketHeaders(key, nil, nil, nil)
		_, hasOriginator := got["originator"]
		_, hasVersion := got["version"]
		assert.False(t, hasOriginator, "chatGPTOAuthWebSocketHeaders must NOT inject originator default")
		assert.False(t, hasVersion, "chatGPTOAuthWebSocketHeaders must NOT inject version default")
	})
}

// ---------------------------------------------------------------------------
// Logger-nil branches for chatGPTOAuthMergeHeaders
// ---------------------------------------------------------------------------

// TestChatGPTOAuthMergeHeaders_LoggerBranch confirms that an invalid token with
// a non-nil logger emits exactly one warning and leaves headers untouched.
func TestChatGPTOAuthMergeHeaders_LoggerBranch(t *testing.T) {
	t.Run("invalid token with logger emits warning", func(t *testing.T) {
		badKey := schemas.Key{Value: schemas.EnvVar{Val: "not-a-jwt"}}
		logger := &captureLogger{}
		existing := map[string]string{"X-Keep": "v"}
		got := chatGPTOAuthMergeHeaders(true, badKey, existing, logger)
		assert.Equal(t, existing, got)
		require.Len(t, logger.warns, 1)
		assert.Contains(t, logger.warns[0], "failed to extract account ID")
	})
}
// ---------------------------------------------------------------------------
// OpenAIListModelsResponse.UnmarshalJSON — dual-shape handling
// ---------------------------------------------------------------------------

// TestOpenAIListModelsResponse_UnmarshalStandard checks that the standard
// OpenAI {"object":"list","data":[...]} shape decodes as-is.
func TestOpenAIListModelsResponse_UnmarshalStandard(t *testing.T) {
	body := []byte(`{"object":"list","data":[{"id":"gpt-4","object":"model","owned_by":"openai"}]}`)
	var resp OpenAIListModelsResponse
	require.NoError(t, json.Unmarshal(body, &resp))
	assert.Equal(t, "list", resp.Object)
	require.Len(t, resp.Data, 1)
	assert.Equal(t, "gpt-4", resp.Data[0].ID)
	assert.Equal(t, "openai", resp.Data[0].OwnedBy)
}

// TestOpenAIListModelsResponse_UnmarshalChatGPT checks that the ChatGPT
// backend's {"models":[{"slug":...}]} shape is projected into the standard
// list shape (object="list", per-entry object="model", owned_by="chatgpt-oauth").
func TestOpenAIListModelsResponse_UnmarshalChatGPT(t *testing.T) {
	body := []byte(`{"models":[{"slug":"gpt-5.3-codex"},{"slug":"gpt-5.4"}]}`)
	var resp OpenAIListModelsResponse
	require.NoError(t, json.Unmarshal(body, &resp))
	assert.Equal(t, "list", resp.Object, "projected object must be list")
	require.Len(t, resp.Data, 2)
	assert.Equal(t, "gpt-5.3-codex", resp.Data[0].ID)
	assert.Equal(t, "model", resp.Data[0].Object)
	assert.Equal(t, "chatgpt-oauth", resp.Data[0].OwnedBy)
	assert.Equal(t, "gpt-5.4", resp.Data[1].ID)
}

// TestOpenAIListModelsResponse_UnmarshalChatGPT_SkipsEmptySlug checks that
// entries with an empty slug are dropped during projection.
func TestOpenAIListModelsResponse_UnmarshalChatGPT_SkipsEmptySlug(t *testing.T) {
	body := []byte(`{"models":[{"slug":"gpt-5.4"},{"slug":""},{"slug":"gpt-5.2"}]}`)
	var resp OpenAIListModelsResponse
	require.NoError(t, json.Unmarshal(body, &resp))
	require.Len(t, resp.Data, 2)
	assert.Equal(t, "gpt-5.4", resp.Data[0].ID)
	assert.Equal(t, "gpt-5.2", resp.Data[1].ID)
}

// TestOpenAIListModelsResponse_UnmarshalEmpty checks that an empty JSON object
// yields an empty Data slice (neither shape present).
func TestOpenAIListModelsResponse_UnmarshalEmpty(t *testing.T) {
	body := []byte(`{}`)
	var resp OpenAIListModelsResponse
	require.NoError(t, json.Unmarshal(body, &resp))
	assert.Empty(t, resp.Data)
}

// TestOpenAIListModelsResponse_UnmarshalInvalidJSON checks that malformed JSON
// surfaces an error from the custom UnmarshalJSON implementation.
func TestOpenAIListModelsResponse_UnmarshalInvalidJSON(t *testing.T) {
	body := []byte(`{invalid}`)
	var resp OpenAIListModelsResponse
	err := resp.UnmarshalJSON(body)
	require.Error(t, err)
}

// ---------------------------------------------------------------------------
// Non-streaming Responses path rejects chatgpt_oauth cleanly via error sentinel
// ---------------------------------------------------------------------------

// TestChatGPTOAuthRequiresStreaming_Error pins the sentinel's existence and
// its message so Responses() can surface a clear streaming-required error.
func TestChatGPTOAuthRequiresStreaming_Error(t *testing.T) {
	// The sentinel must be defined at package scope so Responses() can reference it.
	// (It is deliberately unexported; package-internal visibility is all that is needed.)
	assert.NotNil(t, errChatGPTOAuthRequiresStreaming)
	assert.Contains(t, errChatGPTOAuthRequiresStreaming.Error(), "streaming")
}
@@ -62,8 +63,13 @@ func NewOpenAIProvider(config *schemas.ProviderConfig, logger schemas.Logger) *O client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) streamingClient := providerUtils.BuildStreamingClient(client) // Set default BaseURL if not provided + chatgptOAuth := config.OpenAIConfig != nil && config.OpenAIConfig.ChatGPTOAuth if config.NetworkConfig.BaseURL == "" { - config.NetworkConfig.BaseURL = "https://api.openai.com" + if chatgptOAuth { + config.NetworkConfig.BaseURL = ChatGPTOAuthDefaultBaseURL + } else { + config.NetworkConfig.BaseURL = "https://api.openai.com" + } } config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") @@ -76,6 +82,7 @@ func NewOpenAIProvider(config *schemas.ProviderConfig, logger schemas.Logger) *O sendBackRawResponse: config.SendBackRawResponse, customProviderConfig: config.CustomProviderConfig, disableStore: config.OpenAIConfig != nil && config.OpenAIConfig.DisableStore, + chatgptOAuth: chatgptOAuth, } } @@ -84,8 +91,32 @@ func (provider *OpenAIProvider) GetProviderKey() schemas.ModelProvider { return providerUtils.GetProviderName(schemas.OpenAI, provider.customProviderConfig) } +// effectiveExtraHeaders returns the network config's ExtraHeaders, merged with +// ChatGPT OAuth headers (chatgpt-account-id, OpenAI-Beta) when chatgpt_oauth is enabled. +// This is the canonical way for every upstream request to get its headers so that +// ChatGPT OAuth-specific headers are automatically injected across all routes. +func (provider *OpenAIProvider) effectiveExtraHeaders(key schemas.Key) map[string]string { + return chatGPTOAuthMergeHeaders(provider.chatgptOAuth, key, provider.networkConfig.ExtraHeaders, provider.logger) +} + +// buildFullURL concatenates the base URL with a standard OpenAI /v1 path, +// applying the chatgpt_oauth path rewrite when enabled. 
Use this instead of +// manually appending to provider.networkConfig.BaseURL for routes that don't +// already go through buildRequestURL (e.g. dynamic paths like /v1/files/{id}). +func (provider *OpenAIProvider) buildFullURL(standardPath string) string { + if provider.chatgptOAuth { + standardPath = chatGPTOAuthPath(standardPath) + } + return provider.networkConfig.BaseURL + standardPath +} + // buildRequestURL constructs the full request URL using the provider's configuration. +// When chatgpt_oauth is enabled, the /v1 prefix is stripped from paths so requests +// route correctly to chatgpt.com/backend-api/codex/ instead of /codex/v1/. func (provider *OpenAIProvider) buildRequestURL(ctx *schemas.BifrostContext, defaultPath string, requestType schemas.RequestType) string { + if provider.chatgptOAuth { + defaultPath = chatGPTOAuthPath(defaultPath) + } path, isCompleteURL := providerUtils.GetRequestPath(ctx, defaultPath, provider.customProviderConfig, requestType) if isCompleteURL { return path @@ -99,6 +130,14 @@ func (provider *OpenAIProvider) ListModels(ctx *schemas.BifrostContext, keys []s } providerName := provider.GetProviderKey() + // Pick a representative key for ChatGPT OAuth header extraction. + // All keys for a ChatGPT OAuth-enabled provider share the same account, + // so the first key's JWT is sufficient for the chatgpt-account-id header. 
+ var headerKey schemas.Key + if len(keys) > 0 { + headerKey = keys[0] + } + if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { return providerUtils.HandleKeylessListModelsRequest(providerName, func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { return ListModelsByKey( @@ -107,7 +146,7 @@ func (provider *OpenAIProvider) ListModels(ctx *schemas.BifrostContext, keys []s provider.buildRequestURL(ctx, "/v1/models", schemas.ListModelsRequest), schemas.Key{}, request.Unfiltered, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(headerKey), providerName, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), @@ -120,7 +159,7 @@ func (provider *OpenAIProvider) ListModels(ctx *schemas.BifrostContext, keys []s request, provider.buildRequestURL(ctx, "/v1/models", schemas.ListModelsRequest), keys, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(headerKey), providerName, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), @@ -240,7 +279,7 @@ func (provider *OpenAIProvider) TextCompletion(ctx *schemas.BifrostContext, key provider.buildRequestURL(ctx, "/v1/completions", schemas.TextCompletionRequest), request, key, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), @@ -400,7 +439,7 @@ func (provider *OpenAIProvider) TextCompletionStream(ctx *schemas.BifrostContext provider.buildRequestURL(ctx, "/v1/completions", schemas.TextCompletionStreamRequest), request, authHeader, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), providerUtils.ShouldSendBackRawRequest(ctx, 
provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), @@ -747,7 +786,7 @@ func (provider *OpenAIProvider) ChatCompletion(ctx *schemas.BifrostContext, key provider.buildRequestURL(ctx, "/v1/chat/completions", schemas.ChatCompletionRequest), request, key, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), @@ -916,7 +955,7 @@ func (provider *OpenAIProvider) ChatCompletionStream(ctx *schemas.BifrostContext provider.buildRequestURL(ctx, "/v1/chat/completions", schemas.ChatCompletionStreamRequest), request, authHeader, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), @@ -1345,6 +1384,16 @@ func (provider *OpenAIProvider) Responses(ctx *schemas.BifrostContext, key schem return nil, err } + // ChatGPT OAuth rejects non-streaming /responses entirely — the upstream only + // accepts stream=true. Non-streaming callers would receive SSE that the standard + // handler can't parse. Surface a clear error instead of a confusing decode failure. 
+ if provider.chatgptOAuth { + return nil, providerUtils.NewBifrostOperationError( + "non-streaming /responses is not supported when chatgpt_oauth is enabled; use ResponsesStream (stream=true) instead", + errChatGPTOAuthRequiresStreaming, + ) + } + if provider.disableStore { if request.Params == nil { request.Params = &schemas.ResponsesParameters{} @@ -1352,23 +1401,35 @@ func (provider *OpenAIProvider) Responses(ctx *schemas.BifrostContext, key schem request.Params.Store = schemas.Ptr(false) } + // Pass raw ExtraHeaders (not effectiveExtraHeaders) to avoid double JWT decoding. + // chatGPTOAuthApplyRequest merges OAuth headers in itself; calling effectiveExtraHeaders + // here would extract the account ID, then chatGPTOAuthApplyRequest would extract it + // again — each extraction base64url-decodes and JSON-parses the JWT. + extraHeaders, bodyTransformer, oauthErr := chatGPTOAuthApplyRequest(provider.chatgptOAuth, key, provider.networkConfig.ExtraHeaders, provider.logger) + if oauthErr != nil { + return nil, providerUtils.NewBifrostOperationError(oauthErr.Error(), oauthErr) + } + return HandleOpenAIResponsesRequest( ctx, provider.client, provider.buildRequestURL(ctx, "/v1/responses", schemas.ResponsesRequest), request, key, - provider.networkConfig.ExtraHeaders, + extraHeaders, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), nil, nil, provider.logger, + bodyTransformer, ) } // HandleOpenAIResponsesRequest handles a responses request to OpenAI's API. +// bodyTransformer, if non-nil, is applied to the serialized JSON body before it is sent. +// Pass nil for standard behavior. 
func HandleOpenAIResponsesRequest( ctx *schemas.BifrostContext, client *fasthttp.Client, @@ -1382,6 +1443,7 @@ func HandleOpenAIResponsesRequest( customResponseHandler responseHandler[schemas.BifrostResponsesResponse], customErrorConverter ErrorConverter, logger schemas.Logger, + bodyTransformer func([]byte) ([]byte, error), ) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { // Create request req := fasthttp.AcquireRequest() @@ -1407,23 +1469,27 @@ func HandleOpenAIResponsesRequest( req.Header.Set("Authorization", "Bearer "+key.Value.GetValue()) } - // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { - if lpErr != nil { - return nil, lpErr - } - if len(lpResult.ResponseBody) > 0 { - response := &schemas.BifrostResponsesResponse{} - if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) + // Large payload passthrough: stream body directly without JSON marshaling. + // Skip it when a body transformer is set (e.g. chatgpt_oauth) — the transformer + // needs to mutate the JSON, which can't happen on a streamed passthrough. 
+ if bodyTransformer == nil { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { + if lpErr != nil { + return nil, lpErr } - response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} - return response, nil + if len(lpResult.ResponseBody) > 0 { + response := &schemas.BifrostResponsesResponse{} + if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) + } + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} + return response, nil + } + return &schemas.BifrostResponsesResponse{ + Model: request.Model, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, + }, nil } - return &schemas.BifrostResponsesResponse{ - Model: request.Model, - ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, - }, nil } // Use centralized converter @@ -1437,6 +1503,14 @@ func HandleOpenAIResponsesRequest( return nil, bifrostErr } + if bodyTransformer != nil { + transformed, transformErr := bodyTransformer(jsonData) + if transformErr != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, transformErr) + } + jsonData = transformed + } + req.SetBody(jsonData) // Make request @@ -1511,6 +1585,7 @@ func (provider *OpenAIProvider) ResponsesStream(ctx *schemas.BifrostContext, pos if key.Value.GetValue() != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()} } + if provider.disableStore { if request.Params == nil { request.Params = &schemas.ResponsesParameters{} @@ -1518,6 +1593,12 @@ func (provider *OpenAIProvider) ResponsesStream(ctx *schemas.BifrostContext, pos request.Params.Store = schemas.Ptr(false) } + // Pass raw ExtraHeaders (see comment in Responses above) to avoid double JWT decoding. 
+ extraHeaders, streamBodyTransformer, oauthErr := chatGPTOAuthApplyRequest(provider.chatgptOAuth, key, provider.networkConfig.ExtraHeaders, provider.logger) + if oauthErr != nil { + return nil, providerUtils.NewBifrostOperationError(oauthErr.Error(), oauthErr) + } + // Use shared streaming logic return HandleOpenAIResponsesStreaming( ctx, @@ -1525,7 +1606,7 @@ func (provider *OpenAIProvider) ResponsesStream(ctx *schemas.BifrostContext, pos provider.buildRequestURL(ctx, "/v1/responses", schemas.ResponsesStreamRequest), request, authHeader, - provider.networkConfig.ExtraHeaders, + extraHeaders, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), @@ -1534,6 +1615,7 @@ func (provider *OpenAIProvider) ResponsesStream(ctx *schemas.BifrostContext, pos nil, nil, nil, + streamBodyTransformer, provider.logger, postHookSpanFinalizer, ) @@ -1556,6 +1638,7 @@ func HandleOpenAIResponsesStreaming( customErrorConverter ErrorConverter, postRequestConverter func(*OpenAIResponsesRequest) *OpenAIResponsesRequest, postResponseConverter func(*schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse, + bodyTransformer func([]byte) ([]byte, error), logger schemas.Logger, postHookSpanFinalizer func(context.Context), ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { @@ -1588,6 +1671,18 @@ func HandleOpenAIResponsesStreaming( return nil, bifrostErr } + // When a body transformer is set (e.g. chatgpt_oauth), the transformed JSON must + // be the one sent upstream — large-payload streaming passthrough must be bypassed + // since it streams the raw original body, defeating the transformer. 
+ hasBodyTransformer := bodyTransformer != nil + if hasBodyTransformer { + transformed, transformErr := bodyTransformer(jsonBody) + if transformErr != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, transformErr) + } + jsonBody = transformed + } + // Create HTTP request for streaming req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -1606,7 +1701,12 @@ func HandleOpenAIResponsesStreaming( req.Header.Set(key, value) } - setStreamingRequestBody(ctx, req, jsonBody, providerName) + if hasBodyTransformer { + // Always use the transformed JSON body; skip large-payload streaming passthrough. + req.SetBody(jsonBody) + } else { + setStreamingRequestBody(ctx, req, jsonBody, providerName) + } // Use streaming-aware client when large payload optimization is active — ensures // MaxResponseBodySize > 0 so ErrBodyTooLarge triggers StreamBody for Content-Length responses. @@ -1832,7 +1932,7 @@ func (provider *OpenAIProvider) Embedding(ctx *schemas.BifrostContext, key schem provider.buildRequestURL(ctx, "/v1/embeddings", schemas.EmbeddingRequest), request, key, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), @@ -1987,7 +2087,7 @@ func (provider *OpenAIProvider) Speech(ctx *schemas.BifrostContext, key schemas. 
provider.buildRequestURL(ctx, "/v1/audio/speech", schemas.SpeechRequest), request, key, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), @@ -2128,7 +2228,7 @@ func (provider *OpenAIProvider) SpeechStream(ctx *schemas.BifrostContext, postHo provider.buildRequestURL(ctx, "/v1/audio/speech", schemas.SpeechStreamRequest), request, authHeader, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), @@ -2382,7 +2482,7 @@ func (provider *OpenAIProvider) Transcription(ctx *schemas.BifrostContext, key s provider.buildRequestURL(ctx, "/v1/audio/transcriptions", schemas.TranscriptionRequest), request, key, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), provider.GetProviderKey(), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), nil, @@ -2568,7 +2668,7 @@ func (provider *OpenAIProvider) TranscriptionStream(ctx *schemas.BifrostContext, provider.buildRequestURL(ctx, "/v1/audio/transcriptions", schemas.TranscriptionStreamRequest), request, authHeader, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), false, provider.GetProviderKey(), @@ -2838,7 +2938,7 @@ func (provider *OpenAIProvider) ImageGeneration(ctx *schemas.BifrostContext, key provider.buildRequestURL(ctx, "/v1/images/generations", schemas.ImageGenerationRequest), req, key, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), 
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), @@ -2997,7 +3097,7 @@ func (provider *OpenAIProvider) ImageGenerationStream( provider.buildRequestURL(ctx, "/v1/images/generations", schemas.ImageGenerationStreamRequest), request, authHeader, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), @@ -3402,7 +3502,7 @@ func (provider *OpenAIProvider) VideoGeneration(ctx *schemas.BifrostContext, key provider.buildRequestURL(ctx, "/v1/videos", schemas.VideoGenerationRequest), request, key, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), @@ -3428,7 +3528,7 @@ func (provider *OpenAIProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s provider.buildRequestURL(ctx, "/v1/videos/"+videoID, schemas.VideoRetrieveRequest), request, key, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), nil, // OpenAI uses Bearer from key providerName, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), @@ -3458,7 +3558,7 @@ func (provider *OpenAIProvider) VideoDownload(ctx *schemas.BifrostContext, key s defer fasthttp.ReleaseResponse(resp) // Set headers - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) // Build URL: /v1/videos/{video_id}/content requestURL := provider.buildRequestURL(ctx, "/v1/videos/"+videoID+"/content", schemas.VideoDownloadRequest) @@ -3536,7 +3636,7 @@ func (provider *OpenAIProvider) VideoDelete(ctx *schemas.BifrostContext, key sch provider.buildRequestURL(ctx, 
"/v1/videos/"+videoID, schemas.VideoDeleteRequest), videoID, key, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), providerName, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), @@ -3556,7 +3656,7 @@ func (provider *OpenAIProvider) VideoList(ctx *schemas.BifrostContext, key schem provider.buildRequestURL(ctx, "/v1/videos", schemas.VideoListRequest), request, key, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), @@ -3964,7 +4064,7 @@ func (provider *OpenAIProvider) CountTokens(ctx *schemas.BifrostContext, key sch provider.buildRequestURL(ctx, "/v1/responses/input_tokens", schemas.CountTokensRequest), request, key, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), @@ -4102,7 +4202,7 @@ func (provider *OpenAIProvider) ImageEdit(ctx *schemas.BifrostContext, key schem provider.buildRequestURL(ctx, "/v1/images/edits", schemas.ImageEditRequest), request, key, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), false, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), @@ -4241,7 +4341,7 @@ func (provider *OpenAIProvider) ImageEditStream(ctx *schemas.BifrostContext, pos provider.buildRequestURL(ctx, "/v1/images/edits", schemas.ImageEditStreamRequest), request, authHeader, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), false, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), @@ -4621,7 
+4721,7 @@ func (provider *OpenAIProvider) ImageVariation(ctx *schemas.BifrostContext, key provider.buildRequestURL(ctx, "/v1/images/variations", schemas.ImageVariationRequest), request, key, - provider.networkConfig.ExtraHeaders, + provider.effectiveExtraHeaders(key), false, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), @@ -4796,7 +4896,7 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche defer fasthttp.ReleaseResponse(resp) // Set headers - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) req.SetRequestURI(provider.buildRequestURL(ctx, "/v1/files", schemas.FileUploadRequest)) req.Header.SetMethod(http.MethodPost) req.Header.SetContentType(writer.FormDataContentType()) @@ -4893,7 +4993,7 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch } // Set headers - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) req.SetRequestURI(requestURL) req.Header.SetMethod(http.MethodGet) req.Header.SetContentType("application/json") @@ -4986,8 +5086,8 @@ func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ resp := fasthttp.AcquireResponse() // Set headers - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/files/" + request.FileID) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) + req.SetRequestURI(provider.buildFullURL("/v1/files/" + request.FileID)) req.Header.SetMethod(http.MethodGet) req.Header.SetContentType("application/json") @@ -5062,8 +5162,8 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s resp := fasthttp.AcquireResponse() // Set 
headers - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/files/" + request.FileID) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) + req.SetRequestURI(provider.buildFullURL("/v1/files/" + request.FileID)) req.Header.SetMethod(http.MethodDelete) req.Header.SetContentType("application/json") @@ -5152,8 +5252,8 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys [] resp := fasthttp.AcquireResponse() // Set headers - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/files/" + request.FileID + "/content") + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) + req.SetRequestURI(provider.buildFullURL("/v1/files/" + request.FileID + "/content")) req.Header.SetMethod(http.MethodGet) if key.Value.GetValue() != "" { @@ -5247,7 +5347,7 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche defer fasthttp.ReleaseResponse(resp) // Set headers - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) req.SetRequestURI(provider.buildRequestURL(ctx, "/v1/videos/"+videoID+"/remix", schemas.VideoRemixRequest)) req.Header.SetMethod(http.MethodPost) req.Header.SetContentType("application/json") @@ -5351,7 +5451,7 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch defer fasthttp.ReleaseResponse(resp) // Set headers - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) req.SetRequestURI(provider.buildRequestURL(ctx, "/v1/batches", schemas.BatchCreateRequest)) req.Header.SetMethod(http.MethodPost) 
req.Header.SetContentType("application/json") @@ -5458,7 +5558,7 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc } // Set headers - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) req.SetRequestURI(requestURL) req.Header.SetMethod(http.MethodGet) req.Header.SetContentType("application/json") @@ -5538,8 +5638,8 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys resp := fasthttp.AcquireResponse() // Set headers - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/batches/" + request.BatchID) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) + req.SetRequestURI(provider.buildFullURL("/v1/batches/" + request.BatchID)) req.Header.SetMethod(http.MethodGet) req.Header.SetContentType("application/json") @@ -5612,8 +5712,8 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] resp := fasthttp.AcquireResponse() // Set headers - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/batches/" + request.BatchID + "/cancel") + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) + req.SetRequestURI(provider.buildFullURL("/v1/batches/" + request.BatchID + "/cancel")) req.Header.SetMethod(http.MethodPost) req.Header.SetContentType("application/json") @@ -5729,8 +5829,8 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ resp := fasthttp.AcquireResponse() // Set headers - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/files/" + *batchResp.OutputFileID + "/content") + providerUtils.SetExtraHeaders(ctx, req, 
provider.effectiveExtraHeaders(key), nil) + req.SetRequestURI(provider.buildFullURL("/v1/files/" + *batchResp.OutputFileID + "/content")) req.Header.SetMethod(http.MethodGet) if key.Value.GetValue() != "" { @@ -5853,7 +5953,7 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) req.SetRequestURI(provider.buildRequestURL(ctx, "/v1/containers", schemas.ContainerCreateRequest)) req.Header.SetMethod(http.MethodPost) @@ -5983,7 +6083,7 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) req.SetRequestURI(requestURL) req.Header.SetMethod(http.MethodGet) @@ -6082,7 +6182,7 @@ func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, k req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) req.SetRequestURI(provider.buildRequestURL(ctx, "/v1/containers/"+request.ContainerID, schemas.ContainerRetrieveRequest)) req.Header.SetMethod(http.MethodGet) @@ -6189,7 +6289,7 @@ func (provider *OpenAIProvider) ContainerDelete(ctx *schemas.BifrostContext, key req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) req.SetRequestURI(provider.buildRequestURL(ctx, 
"/v1/containers/"+request.ContainerID, schemas.ContainerDeleteRequest)) req.Header.SetMethod(http.MethodDelete) @@ -6282,7 +6382,7 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) endpoint := fmt.Sprintf("/v1/containers/%s/files", request.ContainerID) req.SetRequestURI(provider.buildRequestURL(ctx, endpoint, schemas.ContainerFileCreateRequest)) @@ -6443,7 +6543,7 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) req.SetRequestURI(requestURL) req.Header.SetMethod(http.MethodGet) @@ -6547,7 +6647,7 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) endpoint := fmt.Sprintf("/v1/containers/%s/files/%s", request.ContainerID, request.FileID) req.SetRequestURI(provider.buildRequestURL(ctx, endpoint, schemas.ContainerFileRetrieveRequest)) @@ -6661,7 +6761,7 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) endpoint := fmt.Sprintf("/v1/containers/%s/files/%s/content", request.ContainerID, request.FileID) 
req.SetRequestURI(provider.buildRequestURL(ctx, endpoint, schemas.ContainerFileContentRequest)) @@ -6760,7 +6860,7 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() - providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, req, provider.effectiveExtraHeaders(key), nil) endpoint := fmt.Sprintf("/v1/containers/%s/files/%s", request.ContainerID, request.FileID) req.SetRequestURI(provider.buildRequestURL(ctx, endpoint, schemas.ContainerFileDeleteRequest)) @@ -6850,7 +6950,7 @@ func (provider *OpenAIProvider) Passthrough( path = after } - url := provider.networkConfig.BaseURL + "/v1" + path + url := provider.buildFullURL("/v1" + path) if req.RawQuery != "" { url += "?" + req.RawQuery } @@ -6863,7 +6963,7 @@ func (provider *OpenAIProvider) Passthrough( fasthttpReq.Header.SetMethod(req.Method) fasthttpReq.SetRequestURI(url) - providerUtils.SetExtraHeaders(ctx, fasthttpReq, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, fasthttpReq, provider.effectiveExtraHeaders(key), nil) for k, v := range req.SafeHeaders { fasthttpReq.Header.Set(k, v) @@ -6925,7 +7025,7 @@ func (provider *OpenAIProvider) PassthroughStream( if after, ok := strings.CutPrefix(path, "/v1"); ok { path = after } - url := provider.networkConfig.BaseURL + "/v1" + path + url := provider.buildFullURL("/v1" + path) if req.RawQuery != "" { url += "?" 
+ req.RawQuery } @@ -6938,7 +7038,7 @@ func (provider *OpenAIProvider) PassthroughStream( fasthttpReq.Header.SetMethod(req.Method) fasthttpReq.SetRequestURI(url) - providerUtils.SetExtraHeaders(ctx, fasthttpReq, provider.networkConfig.ExtraHeaders, nil) + providerUtils.SetExtraHeaders(ctx, fasthttpReq, provider.effectiveExtraHeaders(key), nil) for k, v := range req.SafeHeaders { fasthttpReq.Header.Set(k, v) diff --git a/core/providers/openai/types.go b/core/providers/openai/types.go index e2eab5245a..1b41224659 100644 --- a/core/providers/openai/types.go +++ b/core/providers/openai/types.go @@ -864,12 +864,59 @@ type OpenAIModel struct { ContextWindow *int `json:"context_window,omitempty"` } -// OpenAIListModelsResponse represents an OpenAI list models response +// OpenAIListModelsResponse represents an OpenAI list models response. +// Supports two wire formats via a custom UnmarshalJSON: +// - Standard OpenAI: {"object":"list","data":[{"id":"...","object":"model",...}]} +// - ChatGPT backend: {"models":[{"slug":"..."}]} — only "slug" is significant; +// other OpenAIModel fields (object, owned_by, created) are populated with +// sensible defaults so downstream code treats it uniformly. type OpenAIListModelsResponse struct { Object string `json:"object"` Data []OpenAIModel `json:"data"` } +// UnmarshalJSON accepts both the standard OpenAI shape ("data":[...]) and the +// ChatGPT backend shape ("models":[{"slug"}]). When the payload looks like the +// ChatGPT shape (no "data", has "models"), it is projected into Data with +// OpenAIModel.ID=slug, Object="model", OwnedBy="chatgpt-oauth". +func (r *OpenAIListModelsResponse) UnmarshalJSON(data []byte) error { + // Raw inspection to decide which shape we're dealing with. 
+ var raw struct { + Object string `json:"object"` + Data []OpenAIModel `json:"data"` + Models []chatGPTRawModel `json:"models"` + } + if err := sonic.Unmarshal(data, &raw); err != nil { + return err + } + r.Object = raw.Object + if len(raw.Data) > 0 { + r.Data = raw.Data + return nil + } + // ChatGPT backend shape: project slug → ID. + if len(raw.Models) > 0 { + r.Object = "list" + r.Data = make([]OpenAIModel, 0, len(raw.Models)) + for _, m := range raw.Models { + if m.Slug == "" { + continue + } + r.Data = append(r.Data, OpenAIModel{ + ID: m.Slug, + Object: "model", + OwnedBy: "chatgpt-oauth", + }) + } + } + return nil +} + +// chatGPTRawModel is the ChatGPT backend's model entry shape. Only Slug is used. +type chatGPTRawModel struct { + Slug string `json:"slug"` +} + // OpenAIImageGenerationRequest is the struct for Image Generation requests by OpenAI. type OpenAIImageGenerationRequest struct { Model string `json:"model"` diff --git a/core/providers/openai/websocket.go b/core/providers/openai/websocket.go index 878443cc7a..316ff35fe3 100644 --- a/core/providers/openai/websocket.go +++ b/core/providers/openai/websocket.go @@ -7,21 +7,44 @@ import ( ) // SupportsWebSocketMode returns true since OpenAI natively supports the Responses API WebSocket Mode. +// This applies to both the standard OpenAI path (api.openai.com) and the ChatGPT OAuth path +// (chatgpt.com/backend-api/codex). func (provider *OpenAIProvider) SupportsWebSocketMode() bool { return true } // WebSocketResponsesURL returns the WebSocket URL for the OpenAI Responses API. -// Converts the HTTP base URL to a WSS URL: https://api.openai.com -> wss://api.openai.com/v1/responses +// Converts the HTTP base URL to a WSS URL: https://api.openai.com -> wss://api.openai.com/v1/responses. +// When chatgpt_oauth is enabled, routes to chatgpt.com/backend-api/codex/responses (no /v1 prefix). 
func (provider *OpenAIProvider) WebSocketResponsesURL(key schemas.Key) string { + if provider.chatgptOAuth { + return chatGPTOAuthWebSocketURL(provider.networkConfig.BaseURL, "/v1/responses") + } base := provider.networkConfig.BaseURL base = strings.Replace(base, "https://", "wss://", 1) base = strings.Replace(base, "http://", "ws://", 1) return base + "/v1/responses" } -// WebSocketHeaders returns the headers required for the upstream WebSocket connection to OpenAI. +// ChatGPTOAuthEnabled reports whether this provider instance was configured with +// chatgpt_oauth enabled. Used by the transport layer to decide whether to inject +// Codex identity defaults (originator, version) into the upstream WS connection. +func (provider *OpenAIProvider) ChatGPTOAuthEnabled() bool { + return provider.chatgptOAuth +} + +// WebSocketHeaders returns the OAuth-specific headers for the upstream WebSocket connection. +// For chatgpt_oauth, it returns Authorization, chatgpt-account-id, and OpenAI-Beta only; +// it does NOT inject Codex identity headers (originator, version). +// +// The caller (tryNativeWSUpstream in wsresponses.go) then passes these provider headers to +// mergeClientWSHeaders, which layers the real client headers on top (so a Codex client's own +// originator and version always win) and injects identity defaults only as a last resort if +// neither the client nor the provider supplied them. 
func (provider *OpenAIProvider) WebSocketHeaders(key schemas.Key) map[string]string { + if provider.chatgptOAuth { + return chatGPTOAuthWebSocketHeaders(key, provider.networkConfig.ExtraHeaders, nil, provider.logger) + } headers := map[string]string{ "Authorization": "Bearer " + key.Value.GetValue(), } diff --git a/core/providers/openrouter/openrouter.go b/core/providers/openrouter/openrouter.go index 36e4ff0566..cd12758fb0 100644 --- a/core/providers/openrouter/openrouter.go +++ b/core/providers/openrouter/openrouter.go @@ -371,6 +371,7 @@ func (provider *OpenRouterProvider) Responses(ctx *schemas.BifrostContext, key s nil, nil, provider.logger, + nil, ) } @@ -396,6 +397,7 @@ func (provider *OpenRouterProvider) ResponsesStream(ctx *schemas.BifrostContext, nil, nil, nil, + nil, provider.logger, postHookSpanFinalizer, ) diff --git a/core/providers/xai/xai.go b/core/providers/xai/xai.go index e787f307fd..ff73fbf500 100644 --- a/core/providers/xai/xai.go +++ b/core/providers/xai/xai.go @@ -189,6 +189,7 @@ func (provider *XAIProvider) Responses(ctx *schemas.BifrostContext, key schemas. nil, ParseXAIError, provider.logger, + nil, ) } @@ -213,6 +214,7 @@ func (provider *XAIProvider) ResponsesStream(ctx *schemas.BifrostContext, postHo ParseXAIError, nil, nil, + nil, provider.logger, postHookSpanFinalizer, ) diff --git a/core/schemas/provider.go b/core/schemas/provider.go index e087747c07..413d1453e0 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -448,6 +448,7 @@ type ProviderConfig struct { // OpenAIConfig holds OpenAI-specific provider configuration. 
type OpenAIConfig struct { DisableStore bool `json:"disable_store"` // When true, forces store=false on all outgoing OpenAI requests (default: false) + ChatGPTOAuth bool `json:"chatgpt_oauth"` // When true, routes requests through ChatGPT's backend API (chatgpt.com/backend-api/codex) for subscription-based access } func (config *ProviderConfig) CheckAndSetDefaults() { diff --git a/docs/cli-agents/codex-cli.mdx b/docs/cli-agents/codex-cli.mdx index eb0ddaa5a7..c173454bc0 100644 --- a/docs/cli-agents/codex-cli.mdx +++ b/docs/cli-agents/codex-cli.mdx @@ -18,12 +18,14 @@ npm install -g @openai/codex ## Configuring Codex CLI with Bifrost +Bifrost supports two authentication methods for Codex CLI: **API Key** and **ChatGPT OAuth**. Choose the one that matches your setup. + +### Option 1: API Key (Usage-Based) + -Codex CLI always prefers OAuth over custom API keys. Make sure you run `/logout` before configuring the Bifrost gateway with Codex. +Codex CLI always prefers OAuth over API keys. If you are logged in via ChatGPT OAuth, run `/logout` in Codex before using this method. -### Update codex.toml - Add the Bifrost base URL and credentials to your global `~/.codex/config.toml` or project-specific `.codex/config.toml`: ```bash @@ -38,7 +40,37 @@ model = "openai/gpt-5.4" Always run `codex` from the same terminal session where you exported variables, or restart the terminal after changing your profile. GUI-launched terminals or IDEs may not pick up shell-profile exports unless the environment is configured there as well. -Codex CLI defaults to [websocket mode](https://developers.openai.com/api/docs/guides/websocket-mode) for the Responses API and automatically falls back to HTTPS if the websocket connection fails. To enable https for Codex CLI by default, add these settings in your `config.toml`: +### Option 2: ChatGPT OAuth (Subscription-Based) + +If you have a ChatGPT Plus, Pro, Business, or Enterprise subscription, you can use your ChatGPT OAuth login through Bifrost. 
This routes requests through ChatGPT's backend API using your subscription-based access. + + + + Go to **Providers > OpenAI > OpenAI Config** and enable the **ChatGPT OAuth Passthrough** toggle. When enabled, Bifrost automatically: + - Routes requests to `chatgpt.com/backend-api/codex` instead of `api.openai.com` + - Injects the required `chatgpt-account-id` and `OpenAI-Beta` headers + - Transforms request bodies to match the ChatGPT backend format + + + Set the Bifrost URL in your `~/.codex/config.toml`: + ```toml + openai_base_url = "http://localhost:8080/openai/v1" + model = "gpt-5.4" + ``` + No API key, `env_key`, or `openai/` model prefix is needed — Bifrost automatically extracts the OAuth token from the request and the OpenAI provider is the default when no provider prefix is specified. + + + If you are not already logged in, run `codex` and complete the ChatGPT OAuth login flow when prompted. + + + + +The ChatGPT OAuth passthrough requires no API keys configured on the OpenAI provider and no additional security settings. Bifrost automatically extracts the OAuth token from the `Authorization` header and forwards it to the ChatGPT backend. + + +### Transport mode (WebSocket vs HTTPS) + +Codex CLI defaults to [websocket mode](https://developers.openai.com/api/docs/guides/websocket-mode) for the Responses API and automatically falls back to HTTPS if the websocket connection fails. 
To force HTTPS by default, add these settings in your `config.toml`: ```toml diff --git a/transports/bifrost-http/handlers/wsresponses.go b/transports/bifrost-http/handlers/wsresponses.go index ca293a116e..c41ec045ac 100644 --- a/transports/bifrost-http/handlers/wsresponses.go +++ b/transports/bifrost-http/handlers/wsresponses.go @@ -2,12 +2,15 @@ package handlers import ( "context" + "encoding/hex" "strings" + "unicode/utf8" "github.com/bytedance/sonic" "github.com/fasthttp/router" ws "github.com/fasthttp/websocket" bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/providers/openai" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/transports/bifrost-http/integrations" "github.com/maximhq/bifrost/transports/bifrost-http/lib" @@ -108,9 +111,90 @@ type authHeaders struct { googAPIKey string baggage string extraHeaders map[string]string + + // clientHeaders contains all request headers forwarded from the incoming WS + // upgrade that are not Bifrost-internal or hop-by-hop. These are forwarded + // to the upstream WS connection so that first-party identity headers (e.g. + // originator, version, user-agent sent by Codex) reach the upstream backend. + // Keys are stored in their original capitalisation. + // See wsUpgradeHeaderBlocklist for the blocklist rationale. + clientHeaders map[string]string +} + +// wsUpgradeHeaderBlocklist contains lowercase header names that must NOT be +// forwarded from the client WS upgrade to the upstream WS connection. +// +// Design decision: forward-all-except-blocklist. +// An allowlist would have to be updated every time Codex adds a new identity +// header, causing silent regressions. 
The blocklist is stable: it covers +// hop-by-hop headers (which change meaning at each hop), security-sensitive +// Bifrost-internal routing headers, credential headers that Bifrost manages +// itself (authorization is re-injected by the provider with the right token), +// and network-topology headers that must not be leaked to external upstreams. +// +// Hop-by-hop headers (RFC 7230 §6.1 + WebSocket specific): +// +// connection, upgrade, sec-websocket-key, sec-websocket-version, +// sec-websocket-extensions, sec-websocket-protocol, keep-alive, +// proxy-authorization, proxy-authenticate, te, trailer, trailers, +// transfer-encoding +// +// Security / Bifrost-internal: +// +// host — must reflect the upstream hostname, not the client's +// origin — forwarding Origin: http://localhost... would cause +// chatgpt.com to CORS-reject the connection +// cookie — session cookies must not be proxied to a different origin +// authorization — re-injected by the provider with the correct OAuth token +// x-api-key — Bifrost routing credential; not forwarded to upstream +// x-goog-api-key — Bifrost routing credential; not forwarded to upstream +// x-bf-* — Bifrost-internal routing/config headers +// baggage — W3C baggage contains Bifrost session metadata (session-id) +// +// Network topology (must not leak internal addresses to external upstreams): +// +// x-forwarded-for, x-forwarded-host, x-forwarded-proto, x-real-ip +// +// Content metadata (not meaningful for WS): +// +// content-length, content-type +var wsUpgradeHeaderBlocklist = map[string]bool{ + // Hop-by-hop + "connection": true, + "upgrade": true, + "sec-websocket-key": true, + "sec-websocket-version": true, + "sec-websocket-extensions": true, + "sec-websocket-protocol": true, + "keep-alive": true, + "proxy-authorization": true, + "proxy-authenticate": true, + "te": true, + "trailer": true, + "trailers": true, + "transfer-encoding": true, + // Security / Bifrost-internal + "host": true, + "origin": true, + 
"cookie": true, + "authorization": true, + "x-api-key": true, + "x-goog-api-key": true, + "baggage": true, + // Network topology — must not leak internal addresses to external upstreams + "x-forwarded-for": true, + "x-forwarded-host": true, + "x-forwarded-proto": true, + "x-real-ip": true, + // Content metadata + "content-length": true, + "content-type": true, } // captureAuthHeaders captures the auth headers from the request. +// In addition to the named auth fields, it builds clientHeaders: all request +// headers not in wsUpgradeHeaderBlocklist and not matching the x-bf-* prefix +// (those are handled separately in extraHeaders). func captureAuthHeaders(ctx *fasthttp.RequestCtx) *authHeaders { ah := &authHeaders{ authorization: string(ctx.Request.Header.Peek("Authorization")), @@ -119,6 +203,7 @@ func captureAuthHeaders(ctx *fasthttp.RequestCtx) *authHeaders { googAPIKey: string(ctx.Request.Header.Peek("x-goog-api-key")), baggage: string(ctx.Request.Header.Peek("baggage")), extraHeaders: make(map[string]string), + clientHeaders: make(map[string]string), } for key, value := range ctx.Request.Header.All() { @@ -126,11 +211,132 @@ func captureAuthHeaders(ctx *fasthttp.RequestCtx) *authHeaders { lk := strings.ToLower(k) if strings.HasPrefix(lk, "x-bf-") { ah.extraHeaders[k] = string(value) + // x-bf-* headers are NOT forwarded to upstream as clientHeaders. + continue } + if wsUpgradeHeaderBlocklist[lk] { + continue + } + ah.clientHeaders[k] = string(value) } return ah } +// mergeClientWSHeaders merges client-supplied first-party headers into the +// provider-built upstream headers map, and optionally injects Codex identity +// defaults for any identity headers the client did not supply. +// +// Merge semantics: +// - Client headers (from the incoming WS upgrade, pre-filtered by captureAuthHeaders) +// are applied first as a base layer so that identity headers like originator, +// version, and user-agent reach the upstream backend. 
+// - Provider headers (built by WebSocketHeaders, which already contains OAuth +// credentials and any config-level extra headers) are applied on top so that +// Authorization, chatgpt-account-id, and OpenAI-Beta are never overwritten. +// +// This is effectively: clientHeaders → providerHeaders (provider wins on conflict). +// The wsUpgradeHeaderBlocklist already ensured clientHeaders contains no auth or +// hop-by-hop entries, so there is no security risk from the client layer. +// +// Identity fallbacks (injected ONLY when injectCodexDefaults is true AND the final +// merged map lacks the header): +// - originator: defaults to "codex_cli_rs" +// - version: defaults to openai.ChatGPTOAuthClientVersionFallback ("0.111.0") +// +// injectCodexDefaults must only be true on the ChatGPT OAuth path +// (provider is OpenAI + chatgpt_oauth enabled). Standard api.openai.com connections +// must not have these chatgpt.com-specific identity markers injected. +// +// These fallbacks satisfy chatgpt.com's anti-abuse gate, which requires both +// headers to be present. A real Codex client always sends its own values, which +// take precedence. The defaults are logged at DEBUG level when applied. +func mergeClientWSHeaders(providerHeaders, clientHeaders map[string]string, injectCodexDefaults bool) map[string]string { + // Build case-insensitive lookup of provider keys so client cannot silently + // shadow a provider header with different capitalisation. 
+ providerLower := make(map[string]bool, len(providerHeaders)) + for k := range providerHeaders { + providerLower[strings.ToLower(k)] = true + } + + merged := make(map[string]string, len(providerHeaders)+len(clientHeaders)+2) + for k, v := range clientHeaders { + if providerLower[strings.ToLower(k)] { + continue // provider header wins + } + merged[k] = v + } + for k, v := range providerHeaders { + merged[k] = v + } + + if !injectCodexDefaults { + return merged + } + + // Inject identity defaults only when neither the client nor the provider + // supplied the header. Check case-insensitively to avoid duplicates. + if !mapContainsKeyCI(merged, "originator") { + merged["originator"] = chatGPTOAuthCodexDefaultOriginator + if logger != nil { + logger.Debug("chatgpt_oauth: injecting default originator=%s (not present in client headers)", chatGPTOAuthCodexDefaultOriginator) + } + } + if !mapContainsKeyCI(merged, "version") { + merged["version"] = chatGPTOAuthCodexDefaultVersionFallback + if logger != nil { + logger.Debug("chatgpt_oauth: injecting default version=%s (not present in client headers)", chatGPTOAuthCodexDefaultVersionFallback) + } + } + + return merged +} + +// mapContainsKeyCI reports whether m contains a key that matches target +// case-insensitively. +func mapContainsKeyCI(m map[string]string, target string) bool { + for k := range m { + if strings.EqualFold(k, target) { + return true + } + } + return false +} + +// chatGPTOAuthCodexDefaultOriginator is the originator value injected by +// mergeClientWSHeaders when the client did not supply an originator header. +const chatGPTOAuthCodexDefaultOriginator = "codex_cli_rs" + +// chatGPTOAuthCodexDefaultVersionFallback is the version value injected by +// mergeClientWSHeaders when the client did not supply a version header. +// Mirrors openai.ChatGPTOAuthClientVersionFallback. 
+const chatGPTOAuthCodexDefaultVersionFallback = "0.111.0"
+
+// wsFramePreview returns a preview string of the first 600 bytes of data.
+// For valid UTF-8 payloads it returns the raw string; for binary it returns a
+// hex dump of the first 300 bytes (which expands to at most 600 chars).
+// The second return value is true only when len(data) > 600; a 301–600-byte binary payload is hex-dumped from its first 300 bytes with the flag still false.
+func wsFramePreview(data []byte) (string, bool) {
+	const maxPreviewBytes = 600
+	const maxHexInputBytes = 300 // hex.EncodeToString doubles length → 600 chars max
+
+	truncated := len(data) > maxPreviewBytes
+	preview := data
+	if len(preview) > maxPreviewBytes {
+		preview = preview[:maxPreviewBytes]
+	}
+
+	if utf8.Valid(preview) {
+		return string(preview), truncated
+	}
+
+	// Binary: hex-dump up to maxHexInputBytes of the raw payload
+	raw := data
+	if len(raw) > maxHexInputBytes {
+		raw = raw[:maxHexInputBytes]
+	}
+	return hex.EncodeToString(raw), truncated
+}
+
 // eventLoop reads events from the client WebSocket and processes them.
 func (h *WSResponsesHandler) eventLoop(conn *ws.Conn, session *bfws.Session, auth *authHeaders) {
 	for {
@@ -209,7 +415,7 @@ func (h *WSResponsesHandler) handleResponseCreate(session *bfws.Session, auth *a
 	}
 
 	// Try native WS upstream first
-	if h.tryNativeWSUpstream(session, bifrostCtx, bifrostReq, message) {
+	if h.tryNativeWSUpstream(session, bifrostCtx, auth, bifrostReq, message) {
 		cancel()
 		return
 	}
@@ -221,9 +427,12 @@ func (h *WSResponsesHandler) handleResponseCreate(session *bfws.Session, auth *a
 // tryNativeWSUpstream attempts to forward the event to a native WS upstream connection.
 // Returns true if the event was handled (successfully or with error sent to client).
 // Returns false if the provider doesn't support WS and we should fall back to HTTP bridge.
+// auth is accepted for parity with the HTTP path; client headers captured at upgrade time
+// are currently NOT merged into the upstream WS headers (see mergeClientWSHeaders call).
 func (h *WSResponsesHandler) tryNativeWSUpstream(
 	session *bfws.Session,
 	ctx *schemas.BifrostContext,
+	auth *authHeaders,
 	req *schemas.BifrostResponsesRequest,
 	rawEvent []byte,
 ) bool {
@@ -256,7 +465,21 @@ func (h *WSResponsesHandler) tryNativeWSUpstream(
 
 	// If no upstream connection pinned, get one from the pool or dial
 	if upstream == nil || upstream.IsClosed() {
-		headers := wsProvider.WebSocketHeaders(key)
+		// Build upstream headers from the provider only (for chatgpt_oauth these
+		// already include the OAuth credentials: Authorization, chatgpt-account-id,
+		// OpenAI-Beta). Client-upgrade headers are NOT merged here — nil is passed
+		// to mergeClientWSHeaders below — so at most the Codex identity defaults
+		// (originator, version) are added on top of the provider headers, and only
+		// on the ChatGPT OAuth path.
+		baseHeaders := wsProvider.WebSocketHeaders(key)
+		// Inject Codex identity defaults (originator, version) only on the ChatGPT OAuth
+		// path — those are chatgpt.com-specific anti-abuse markers that must not be sent
+		// to the standard api.openai.com endpoint.
+		// x-bf-eh-* header forwarding is handled by mergeWSExtraHeaders (extracted to
+		// PR #3014 against main); client-upgrade header forwarding has been removed here.
+		openaiProvider, _ := provider.(*openai.OpenAIProvider)
+		injectCodexDefaults := openaiProvider != nil && openaiProvider.ChatGPTOAuthEnabled()
+		headers := mergeClientWSHeaders(baseHeaders, nil, injectCodexDefaults)
 		poolKey := bfws.PoolKey{
 			Provider: req.Provider,
 			KeyID:    key.ID,
@@ -290,30 +513,73 @@ func (h *WSResponsesHandler) tryNativeWSUpstream(
 		return true
 	}
 
+	// Retrieve tracer and traceID for chunk accumulation
+	tracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
+	traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string)
+
+	// DIAGNOSTIC: track frames and exit reason for end-of-stream summary log.
+ // Logged at debug level; silent at LOG_LEVEL=info (default). Enable with LOG_LEVEL=debug. + diagFrameIndex := 0 + diagExitReason := "unknown" + diagTerminalType := "" + defer func() { + args := []any{ + "provider", string(req.Provider), + "model", req.Model, + "frames_forwarded", diagFrameIndex, + "exit_reason", diagExitReason, + } + if diagTerminalType != "" { + args = append(args, "terminal_type", diagTerminalType) + } + logger.Debug("ws upstream stream ended", args...) + }() + // Forward the raw event to upstream if err := upstream.WriteMessage(ws.TextMessage, rawEvent); err != nil { logger.Warn("upstream WS write failed for %s: %v, falling back to HTTP bridge", req.Provider, err) h.pool.Discard(upstream) session.SetUpstream(nil) + diagExitReason = "upstream_dial_or_write_failed" return false } - // Retrieve tracer and traceID for chunk accumulation - tracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer) - traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string) - // Read response events from upstream and relay to client, running post-hooks per chunk forwardedAny := false for { msgType, data, readErr := upstream.ReadMessage() + + // DIAGNOSTIC: log each raw upstream frame at debug level (silent at LOG_LEVEL=info). + if readErr == nil { + diagFrameIndex++ + preview, truncated := wsFramePreview(data) + frameArgs := []any{ + "provider", string(req.Provider), + "model", req.Model, + "msg_type", msgType, + "len", len(data), + "frame_index", diagFrameIndex, + "preview", preview, + } + if truncated { + frameArgs = append(frameArgs, "truncated", true) + } + logger.Debug("ws upstream frame", frameArgs...) 
+ } + if readErr != nil { - logger.Warn("upstream WS read failed for %s: %v, falling back to HTTP bridge", req.Provider, readErr) h.pool.Discard(upstream) session.SetUpstream(nil) + if !forwardedAny { + logger.Warn("upstream WS read failed for %s: %v, falling back to HTTP bridge", req.Provider, readErr) + diagExitReason = "upstream_dial_or_write_failed" return false } + + logger.Warn("upstream WS read failed for %s after forwarding data: %v", req.Provider, readErr) writeWSError(session, 502, "upstream_connection_error", "upstream websocket stream interrupted") + diagExitReason = "abnormal_close" return true } @@ -336,6 +602,7 @@ func (h *WSResponsesHandler) tryNativeWSUpstream( h.pool.Discard(upstream) session.SetUpstream(nil) writeWSBifrostError(session, postErr) + diagExitReason = "post_hook_error" return true } } @@ -343,12 +610,17 @@ func (h *WSResponsesHandler) tryNativeWSUpstream( if writeErr := session.WriteMessage(msgType, data); writeErr != nil { h.pool.Discard(upstream) session.SetUpstream(nil) + diagExitReason = "client_write_failed" return true } forwardedAny = true if isTerminal { h.trackResponseID(session, data) + diagExitReason = "terminal_event_detected" + if streamResp != nil { + diagTerminalType = string(streamResp.Type) + } return true } } @@ -562,12 +834,40 @@ func createBifrostContextFromAuth(handlerStore lib.HandlerStore, auth *authHeade } } - // Forward x-bf-* headers + // Populate BifrostContextKeyRequestHeaders so downstream consumers (governance, + // logging plugins, provider-specific header forwarding) can access the full + // set of client headers. Keys are lowercased for case-insensitive lookup, + // matching the behaviour of the HTTP path in lib/ctx.go. 
+ // We merge clientHeaders (non-Bifrost, non-hop-by-hop) with the named auth + // fields so the map is complete (auth fields were excluded from clientHeaders + // by the blocklist to prevent accidental forwarding upstream, but they are + // still relevant for downstream plugin inspection). + allHeaders := make(map[string]string, len(auth.clientHeaders)+6) + for k, v := range auth.clientHeaders { + allHeaders[strings.ToLower(k)] = v + } + if auth.authorization != "" { + allHeaders["authorization"] = auth.authorization + } + if auth.virtualKey != "" { + allHeaders["x-bf-vk"] = auth.virtualKey + } + if auth.apiKey != "" { + allHeaders["x-api-key"] = auth.apiKey + } + if auth.googAPIKey != "" { + allHeaders["x-goog-api-key"] = auth.googAPIKey + } + if auth.baggage != "" { + allHeaders["baggage"] = auth.baggage + } + // Process extraHeaders in a single pass: add to allHeaders AND handle x-bf-* context keys. for k, v := range auth.extraHeaders { lk := strings.ToLower(k) + allHeaders[lk] = v switch { case lk == "x-bf-vk": - // Already handled above + // Already handled above via auth.virtualKey case lk == "x-bf-api-key": ctx.SetValue(schemas.BifrostContextKeyAPIKeyName, v) case strings.HasPrefix(lk, "x-bf-eh-"): @@ -580,6 +880,7 @@ func createBifrostContextFromAuth(handlerStore lib.HandlerStore, auth *authHeade ctx.SetValue(schemas.BifrostContextKeyExtraHeaders, existing) } } + ctx.SetValue(schemas.BifrostContextKeyRequestHeaders, allHeaders) return ctx, cancel } @@ -685,7 +986,7 @@ var wsResponsesKnownFields = map[string]bool{ } var ( - errModelFormat = errorf("model should be in provider/model format") + errModelFormat = errorf("model is required") errInputRequired = errorf("input is required for responses") ) diff --git a/transports/bifrost-http/handlers/wsresponses_test.go b/transports/bifrost-http/handlers/wsresponses_test.go index aad3b15e9c..2bdb740767 100644 --- a/transports/bifrost-http/handlers/wsresponses_test.go +++ 
b/transports/bifrost-http/handlers/wsresponses_test.go @@ -1,6 +1,7 @@ package handlers import ( + "strings" "testing" "github.com/maximhq/bifrost/core/schemas" @@ -66,3 +67,301 @@ func TestCreateBifrostContextFromAuth_EmptyBaggageSessionIDIgnored(t *testing.T) t.Fatalf("parent request id should be unset, got %#v", got) } } + +// --------------------------------------------------------------------------- +// wsUpgradeHeaderBlocklist: verify hop-by-hop and Bifrost-internal headers are blocked +// --------------------------------------------------------------------------- + +func TestWSUpgradeHeaderBlocklist_HopByHopDropped(t *testing.T) { + hopByHop := []string{ + "connection", + "upgrade", + "sec-websocket-key", + "sec-websocket-version", + "sec-websocket-extensions", + "sec-websocket-protocol", + "keep-alive", + "proxy-authorization", + "proxy-authenticate", + "te", + "trailer", + "trailers", + "transfer-encoding", + } + for _, h := range hopByHop { + if !wsUpgradeHeaderBlocklist[strings.ToLower(h)] { + t.Errorf("expected hop-by-hop header %q to be in blocklist, but it is not", h) + } + } +} + +func TestWSUpgradeHeaderBlocklist_OriginDropped(t *testing.T) { + // origin must be blocked: forwarding Origin: http://localhost... would cause + // chatgpt.com to CORS-reject the upstream WebSocket connection. + if !wsUpgradeHeaderBlocklist["origin"] { + t.Error("expected \"origin\" to be in blocklist") + } +} + +func TestWSUpgradeHeaderBlocklist_NetworkTopologyDropped(t *testing.T) { + // These headers must not leak internal network topology to external upstreams. 
+ topology := []string{ + "x-forwarded-for", + "x-forwarded-host", + "x-forwarded-proto", + "x-real-ip", + } + for _, h := range topology { + if !wsUpgradeHeaderBlocklist[strings.ToLower(h)] { + t.Errorf("expected network topology header %q to be in blocklist, but it is not", h) + } + } +} + +func TestWSUpgradeHeaderBlocklist_SecurityHeadersDropped(t *testing.T) { + security := []string{ + "host", + "cookie", + "authorization", + "x-api-key", + "x-goog-api-key", + "baggage", + } + for _, h := range security { + if !wsUpgradeHeaderBlocklist[strings.ToLower(h)] { + t.Errorf("expected security header %q to be in blocklist, but it is not", h) + } + } +} + +func TestWSUpgradeHeaderBlocklist_CodexIdentityHeadersAllowed(t *testing.T) { + // These first-party Codex headers must NOT be in the blocklist so they + // flow through to the upstream WS connection. + allowed := []string{ + "originator", + "version", + "user-agent", + "session_id", + } + for _, h := range allowed { + if wsUpgradeHeaderBlocklist[strings.ToLower(h)] { + t.Errorf("expected Codex identity header %q to be allowed (not in blocklist)", h) + } + } +} + +// --------------------------------------------------------------------------- +// mergeClientWSHeaders: verify merge semantics +// --------------------------------------------------------------------------- + +// TestMergeClientWSHeaders_ClientHeadersFlowThrough models the production case: +// providerHeaders comes from chatGPTOAuthWebSocketHeaders (OAuth only, no identity +// defaults), and the client sends real Codex identity headers. The merge must +// preserve the client's values unchanged. +func TestMergeClientWSHeaders_ClientHeadersFlowThrough(t *testing.T) { + // Production-accurate: chatGPTOAuthWebSocketHeaders returns only OAuth headers. + provider := map[string]string{ + "Authorization": "Bearer token", + "chatgpt-account-id": "acct-123", + "OpenAI-Beta": "responses=experimental", + } + // Client (Codex) sends its own identity headers. 
+ client := map[string]string{ + "originator": "codex_cli_rs", + "version": "0.121.0", + "user-agent": "codex/0.121.0 (Linux; amd64)", + } + got := mergeClientWSHeaders(provider, client, true) + + // Client identity headers must flow through unchanged. + if got["originator"] != "codex_cli_rs" { + t.Errorf("originator = %q, want %q", got["originator"], "codex_cli_rs") + } + // Critical: client's version (0.121.0) must NOT be replaced by the default (0.111.0). + if got["version"] != "0.121.0" { + t.Errorf("version = %q, want %q (client value must win over default)", got["version"], "0.121.0") + } + if got["user-agent"] != "codex/0.121.0 (Linux; amd64)" { + t.Errorf("user-agent = %q, want codex/0.121.0 (Linux; amd64)", got["user-agent"]) + } + // OAuth headers must also be present. + if got["Authorization"] != "Bearer token" { + t.Errorf("Authorization = %q, want %q", got["Authorization"], "Bearer token") + } +} + +// TestMergeClientWSHeaders_DefaultsInjectedWhenClientSendsNeither verifies that +// mergeClientWSHeaders injects the identity fallbacks when neither the client nor +// the provider supplied originator or version. +func TestMergeClientWSHeaders_DefaultsInjectedWhenClientSendsNeither(t *testing.T) { + provider := map[string]string{ + "Authorization": "Bearer token", + "chatgpt-account-id": "acct-123", + "OpenAI-Beta": "responses=experimental", + } + // Client sends no identity headers at all. 
+ client := map[string]string{ + "user-agent": "some-agent/1.0", + } + got := mergeClientWSHeaders(provider, client, true) + + if got["originator"] != chatGPTOAuthCodexDefaultOriginator { + t.Errorf("originator = %q, want default %q", got["originator"], chatGPTOAuthCodexDefaultOriginator) + } + if got["version"] != chatGPTOAuthCodexDefaultVersionFallback { + t.Errorf("version = %q, want default %q", got["version"], chatGPTOAuthCodexDefaultVersionFallback) + } +} + +// TestMergeClientWSHeaders_ClientOriginatorWinsOverDefault verifies that when the +// client supplies a custom originator, it wins over the injected default. +func TestMergeClientWSHeaders_ClientOriginatorWinsOverDefault(t *testing.T) { + provider := map[string]string{ + "Authorization": "Bearer token", + } + client := map[string]string{ + "originator": "something-else", + "version": "9.9.9", + } + got := mergeClientWSHeaders(provider, client, true) + + if got["originator"] != "something-else" { + t.Errorf("originator = %q, want %q (client value must win)", got["originator"], "something-else") + } + if got["version"] != "9.9.9" { + t.Errorf("version = %q, want %q (client value must win)", got["version"], "9.9.9") + } +} + +func TestMergeClientWSHeaders_ProviderHeadersWinOnConflict(t *testing.T) { + provider := map[string]string{ + "Authorization": "Bearer oauth-token", + } + client := map[string]string{ + // Client should never win on Authorization — blocklist prevents this in + // captureAuthHeaders, but mergeClientWSHeaders defends in depth. 
+ "Authorization": "Bearer client-should-lose", + "originator": "codex_cli_rs", + } + got := mergeClientWSHeaders(provider, client, true) + + if got["Authorization"] != "Bearer oauth-token" { + t.Errorf("Authorization = %q, want provider value %q", got["Authorization"], "Bearer oauth-token") + } + if got["originator"] != "codex_cli_rs" { + t.Errorf("originator = %q, want %q", got["originator"], "codex_cli_rs") + } +} + +func TestMergeClientWSHeaders_EmptyClientHeadersGetDefaultsInjected(t *testing.T) { + // Even with no client headers, the identity defaults are always injected so + // chatgpt.com's anti-abuse gate receives the required identity markers. + provider := map[string]string{ + "Authorization": "Bearer tok", + } + got := mergeClientWSHeaders(provider, nil, true) + if got["Authorization"] != "Bearer tok" { + t.Errorf("Authorization = %q, want %q", got["Authorization"], "Bearer tok") + } + if got["originator"] != chatGPTOAuthCodexDefaultOriginator { + t.Errorf("originator = %q, want default %q", got["originator"], chatGPTOAuthCodexDefaultOriginator) + } + if got["version"] != chatGPTOAuthCodexDefaultVersionFallback { + t.Errorf("version = %q, want default %q", got["version"], chatGPTOAuthCodexDefaultVersionFallback) + } +} + +// TestMergeClientWSHeaders_NonOAuthPathNoDefaultsInjected verifies that when +// injectCodexDefaults is false (standard api.openai.com, non-OAuth provider), +// mergeClientWSHeaders does NOT inject the chatgpt.com-specific originator or +// version headers even when the client sends none. +func TestMergeClientWSHeaders_NonOAuthPathNoDefaultsInjected(t *testing.T) { + // Standard OpenAI provider headers — no OAuth headers. + provider := map[string]string{ + "Authorization": "Bearer sk-abc123", + } + // Client sends no identity headers. 
+ client := map[string]string{ + "user-agent": "myapp/1.0", + } + got := mergeClientWSHeaders(provider, client, false) + + if _, hasOriginator := got["originator"]; hasOriginator { + t.Errorf("originator must NOT be injected on non-OAuth path, but found %q", got["originator"]) + } + if _, hasVersion := got["version"]; hasVersion { + t.Errorf("version must NOT be injected on non-OAuth path, but found %q", got["version"]) + } + // Provider and client headers must still be merged correctly. + if got["Authorization"] != "Bearer sk-abc123" { + t.Errorf("Authorization = %q, want %q", got["Authorization"], "Bearer sk-abc123") + } + if got["user-agent"] != "myapp/1.0" { + t.Errorf("user-agent = %q, want %q", got["user-agent"], "myapp/1.0") + } +} + +// --------------------------------------------------------------------------- +// createBifrostContextFromAuth: verify BifrostContextKeyRequestHeaders is set +// --------------------------------------------------------------------------- + +// --------------------------------------------------------------------------- +// isTerminalStreamType: terminal detection +// --------------------------------------------------------------------------- + +func TestIsTerminalStreamType_TerminalTypes(t *testing.T) { + terminals := []schemas.ResponsesStreamResponseType{ + schemas.ResponsesStreamResponseTypeCompleted, + schemas.ResponsesStreamResponseTypeFailed, + schemas.ResponsesStreamResponseTypeIncomplete, + schemas.ResponsesStreamResponseTypeError, + } + for _, tt := range terminals { + if !isTerminalStreamType(tt) { + t.Errorf("expected %q to be terminal", tt) + } + } +} + +func TestIsTerminalStreamType_NonTerminalTypes(t *testing.T) { + nonTerminals := []schemas.ResponsesStreamResponseType{ + schemas.ResponsesStreamResponseTypeOutputTextDelta, + schemas.ResponsesStreamResponseTypeCreated, + schemas.ResponsesStreamResponseTypeInProgress, + schemas.ResponsesStreamResponseType("codex.rate_limits"), + 
schemas.ResponsesStreamResponseType(""), + } + for _, tt := range nonTerminals { + if isTerminalStreamType(tt) { + t.Errorf("expected %q to be non-terminal", tt) + } + } +} + +func TestCreateBifrostContextFromAuth_RequestHeadersPopulated(t *testing.T) { + auth := &authHeaders{ + authorization: "Bearer some-token", + apiKey: "", + clientHeaders: map[string]string{ + "originator": "codex_cli_rs", + "version": "0.121.0", + "user-agent": "codex/0.121.0", + }, + extraHeaders: make(map[string]string), + } + + ctx, cancel := createBifrostContextFromAuth(testWSHandlerStore{}, auth) + defer cancel() + + headers, ok := ctx.Value(schemas.BifrostContextKeyRequestHeaders).(map[string]string) + if !ok || headers == nil { + t.Fatal("BifrostContextKeyRequestHeaders not set or wrong type") + } + if headers["originator"] != "codex_cli_rs" { + t.Errorf("originator = %q, want %q", headers["originator"], "codex_cli_rs") + } + if headers["authorization"] != "Bearer some-token" { + t.Errorf("authorization = %q, want %q", headers["authorization"], "Bearer some-token") + } +} + diff --git a/transports/bifrost-http/websocket/pool_test.go b/transports/bifrost-http/websocket/pool_test.go index 9f734656ab..c507cf1110 100644 --- a/transports/bifrost-http/websocket/pool_test.go +++ b/transports/bifrost-http/websocket/pool_test.go @@ -1,6 +1,8 @@ package websocket import ( + "errors" + "net" "net/http" "net/http/httptest" "strings" @@ -128,6 +130,159 @@ func TestPoolClose(t *testing.T) { assert.Error(t, err) } +// --------------------------------------------------------------------------- +// Idle-timeout / SetReadDeadline behaviour tests. +// These tests exercise UpstreamConn.SetReadDeadline directly so that the +// tryNativeWSUpstream idle-timeout logic (which calls SetReadDeadline before +// each ReadMessage) is covered at the lowest possible level. 
+// --------------------------------------------------------------------------- + +// TestUpstreamConnReadDeadline_Timeout verifies that a read that is given a +// very short deadline fails with a timeout error when the server never sends. +func TestUpstreamConnReadDeadline_Timeout(t *testing.T) { + // Server that upgrades but never writes any frames (simulates a stalling + // upstream: e.g. silent rate-limit hold after accepting the WS upgrade). + upgrader := ws.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + // Intentionally block forever — simulates upstream stall. + select {} + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + wsConn, _, err := Dial(wsURL, nil) + require.NoError(t, err) + uc := newUpstreamConn(wsConn, schemas.OpenAI, "k1", wsURL) + defer uc.Close() + + const shortDeadline = 100 * time.Millisecond + start := time.Now() + require.NoError(t, uc.SetReadDeadline(time.Now().Add(shortDeadline))) + _, _, readErr := uc.ReadMessage() + elapsed := time.Since(start) + + require.Error(t, readErr, "expected read to fail with timeout") + + var netErr net.Error + require.True(t, errors.As(readErr, &netErr) && netErr.Timeout(), + "expected a net.Error with Timeout()=true, got: %v", readErr) + + // Elapsed time should be close to the deadline, not many seconds. + assert.Less(t, elapsed, shortDeadline+500*time.Millisecond, + "read should have timed out quickly") +} + +// TestUpstreamConnReadDeadline_PeriodicFramesNoTimeout verifies that an +// upstream that sends a frame every shortInterval does not trigger a timeout +// when the deadline is longer than the interval. Each successful read clears +// the deadline (as tryNativeWSUpstream does) so the stream stays alive. 
+func TestUpstreamConnReadDeadline_PeriodicFramesNoTimeout(t *testing.T) { + const frameInterval = 80 * time.Millisecond + const idleTimeout = 300 * time.Millisecond + const numFrames = 4 + + upgrader := ws.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + for range numFrames { + time.Sleep(frameInterval) + if werr := conn.WriteMessage(ws.TextMessage, []byte(`{"type":"ping"}`)); werr != nil { + return + } + } + // After sending all frames, close normally. + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + wsConn, _, err := Dial(wsURL, nil) + require.NoError(t, err) + uc := newUpstreamConn(wsConn, schemas.OpenAI, "k1", wsURL) + defer uc.Close() + + received := 0 + for { + // Replicate the per-read deadline pattern from tryNativeWSUpstream. + require.NoError(t, uc.SetReadDeadline(time.Now().Add(idleTimeout))) + _, data, readErr := uc.ReadMessage() + if readErr != nil { + // Server closed cleanly after numFrames — not a timeout. + var netErr net.Error + if errors.As(readErr, &netErr) && netErr.Timeout() { + t.Fatalf("unexpected timeout after %d frames (interval %v < deadline %v)", received, frameInterval, idleTimeout) + } + break + } + // Clear deadline after successful read (mirrors tryNativeWSUpstream). + _ = uc.SetReadDeadline(time.Time{}) + _ = data + received++ + } + assert.Equal(t, numFrames, received, "expected to receive all frames without timeout") +} + +// TestUpstreamConnReadDeadline_OneThenSilent verifies that a timeout fires +// after idleness FOLLOWING the first frame, not at request-start + timeout. +// The server sends one frame immediately and then goes silent. 
+func TestUpstreamConnReadDeadline_OneThenSilent(t *testing.T) { + const idleTimeout = 150 * time.Millisecond + + upgrader := ws.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + // Send exactly one frame, then stall. + conn.WriteMessage(ws.TextMessage, []byte(`{"type":"response.created"}`)) //nolint:errcheck + // Block indefinitely — simulates upstream stall after initial frame. + select {} + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + wsConn, _, err := Dial(wsURL, nil) + require.NoError(t, err) + uc := newUpstreamConn(wsConn, schemas.OpenAI, "k1", wsURL) + defer uc.Close() + + // First read: should succeed within idleTimeout. + require.NoError(t, uc.SetReadDeadline(time.Now().Add(idleTimeout))) + _, _, firstErr := uc.ReadMessage() + require.NoError(t, firstErr, "first read (one frame sent) should succeed") + // Clear deadline — mirrors tryNativeWSUpstream on a successful read. + _ = uc.SetReadDeadline(time.Time{}) + + // Second read: server is now silent. Set a new idle deadline. + start := time.Now() + require.NoError(t, uc.SetReadDeadline(time.Now().Add(idleTimeout))) + _, _, secondErr := uc.ReadMessage() + elapsed := time.Since(start) + + require.Error(t, secondErr, "second read should fail (upstream stalled)") + var netErr net.Error + require.True(t, errors.As(secondErr, &netErr) && netErr.Timeout(), + "expected timeout error on second read, got: %v", secondErr) + + // The timeout should have fired approximately idleTimeout after the SECOND + // read attempt, not at request-start + idleTimeout. We verify it did NOT + // fire instantly (i.e. the first read succeeded and reset the clock). 
+ assert.GreaterOrEqual(t, elapsed, idleTimeout/2, + "timeout should not fire before the idle deadline expires") + assert.Less(t, elapsed, idleTimeout+500*time.Millisecond, + "timeout should fire close to idleTimeout after stall begins") +} + func TestPoolExpiredConnection(t *testing.T) { server := startTestWSServer(t) defer server.Close() diff --git a/ui/app/workspace/providers/fragments/openaiConfigFormFragment.tsx b/ui/app/workspace/providers/fragments/openaiConfigFormFragment.tsx index 79646109b9..b929bf6cd7 100644 --- a/ui/app/workspace/providers/fragments/openaiConfigFormFragment.tsx +++ b/ui/app/workspace/providers/fragments/openaiConfigFormFragment.tsx @@ -26,6 +26,7 @@ export function OpenAIConfigFormFragment({ provider }: OpenAIConfigFormFragmentP reValidateMode: "onChange", defaultValues: { disable_store: provider.openai_config?.disable_store ?? false, + chatgpt_oauth: provider.openai_config?.chatgpt_oauth ?? false, }, }); @@ -36,14 +37,16 @@ export function OpenAIConfigFormFragment({ provider }: OpenAIConfigFormFragmentP useEffect(() => { form.reset({ disable_store: provider.openai_config?.disable_store ?? false, + chatgpt_oauth: provider.openai_config?.chatgpt_oauth ?? false, }); - }, [form, provider.name, provider.openai_config?.disable_store]); + }, [form, provider.name, provider.openai_config?.disable_store, provider.openai_config?.chatgpt_oauth]); const onSubmit = (data: OpenAIConfigFormSchema) => { updateProvider( buildProviderUpdatePayload(provider, { openai_config: { disable_store: data.disable_store, + chatgpt_oauth: data.chatgpt_oauth, }, }), ) @@ -94,6 +97,37 @@ export function OpenAIConfigFormFragment({ provider }: OpenAIConfigFormFragmentP )} /> + ( + +
+
+ ChatGPT OAuth Passthrough +

+											Route requests through ChatGPT&apos;s backend API for subscription-based (seat/plan) access. When enabled, Bifrost
+											automatically extracts the OAuth token from the request, routes to chatgpt.com/backend-api/codex, and injects the
+											required headers and body transformations. No additional security settings or API keys are needed.

+
+ + { + field.onChange(checked); + form.trigger("chatgpt_oauth"); + }} + /> + +
+ +
+ )} + />
diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts index 368e835c7c..894b1e8ec3 100644 --- a/ui/lib/types/config.ts +++ b/ui/lib/types/config.ts @@ -295,6 +295,7 @@ export interface CustomProviderConfig { // OpenAIConfig holds OpenAI-specific provider configuration. export interface OpenAIConfig { disable_store?: boolean; + chatgpt_oauth?: boolean; } // ProviderConfig matching Go's lib.ProviderConfig diff --git a/ui/lib/types/schemas.ts b/ui/lib/types/schemas.ts index 7599162721..ec88de0e79 100644 --- a/ui/lib/types/schemas.ts +++ b/ui/lib/types/schemas.ts @@ -473,6 +473,7 @@ export const proxyFormConfigSchema = z // OpenAI Config tab export const openaiConfigFormSchema = z.object({ disable_store: z.boolean(), + chatgpt_oauth: z.boolean(), }); export type OpenAIConfigFormSchema = z.infer;