Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions transports/bifrost-http/handlers/wsresponses.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ func (h *WSResponsesHandler) RegisterRoutes(r *router.Router, middlewares ...sch

// handleUpgrade upgrades the HTTP connection to WebSocket and starts the event loop.
func (h *WSResponsesHandler) handleUpgrade(ctx *fasthttp.RequestCtx) {
// Infer the default provider from the request path.
// Integration-scoped paths (e.g. /openai/v1/responses) are already bound to
// OpenAI by the URL, so bare model strings like "gpt-4o" are unambiguous and
// should be accepted without a provider prefix.
// The unified /v1/responses path is multi-provider by design and requires an
// explicit "provider/model" format, so its default remains empty.
defaultProvider := inferDefaultProviderFromPath(string(ctx.Path()))

err := h.upgrader.Upgrade(ctx, func(conn *ws.Conn) {
defer conn.Close()

Expand All @@ -93,13 +101,25 @@ func (h *WSResponsesHandler) handleUpgrade(ctx *fasthttp.RequestCtx) {
// Capture auth headers from the upgrade request for per-event context creation
authHeaders := captureAuthHeaders(ctx)

h.eventLoop(conn, session, authHeaders)
h.eventLoop(conn, session, authHeaders, defaultProvider)
})
if err != nil {
logger.Warn("websocket upgrade failed for /v1/responses: %v", err)
}
}

// inferDefaultProviderFromPath reports which ModelProvider should act as the
// default when a bare model string arrives on the given request path.
// Requests routed through the /openai/ integration prefix default to OpenAI;
// every other path (the unified /v1/responses included) yields the empty
// provider, so callers keep requiring the explicit "provider/model" format.
func inferDefaultProviderFromPath(path string) schemas.ModelProvider {
	const openAIPrefix = "/openai/"
	if !strings.HasPrefix(path, openAIPrefix) {
		return ""
	}
	return schemas.OpenAI
}

// authHeaders holds auth-related headers captured during the WS upgrade.
type authHeaders struct {
authorization string
Expand Down Expand Up @@ -132,7 +152,9 @@ func captureAuthHeaders(ctx *fasthttp.RequestCtx) *authHeaders {
}

// eventLoop reads events from the client WebSocket and processes them.
func (h *WSResponsesHandler) eventLoop(conn *ws.Conn, session *bfws.Session, auth *authHeaders) {
// defaultProvider is the provider inferred from the upgrade path and is
// forwarded to handleResponseCreate for bare-model-string resolution.
func (h *WSResponsesHandler) eventLoop(conn *ws.Conn, session *bfws.Session, auth *authHeaders, defaultProvider schemas.ModelProvider) {
for {
_, message, err := conn.ReadMessage()
if err != nil {
Expand All @@ -153,7 +175,7 @@ func (h *WSResponsesHandler) eventLoop(conn *ws.Conn, session *bfws.Session, aut

switch schemas.WebSocketEventType(envelope.Type) {
case schemas.WSEventResponseCreate:
h.handleResponseCreate(session, auth, message)
h.handleResponseCreate(session, auth, message, defaultProvider)
default:
writeWSError(session, 400, "invalid_request_error", "unsupported event type: "+envelope.Type)
}
Expand All @@ -163,7 +185,10 @@ func (h *WSResponsesHandler) eventLoop(conn *ws.Conn, session *bfws.Session, aut
// handleResponseCreate processes a response.create event.
// Strategy: try native WS upstream for providers that support it, otherwise use HTTP bridge.
// If native WS upstream fails mid-stream, falls back to HTTP bridge.
func (h *WSResponsesHandler) handleResponseCreate(session *bfws.Session, auth *authHeaders, message []byte) {
// defaultProvider is the provider inferred from the upgrade path (e.g. schemas.OpenAI for
// /openai/v1/responses). It is forwarded to ParseModelString so that bare model strings
// like "gpt-4o" are accepted on integration-scoped paths without an explicit prefix.
func (h *WSResponsesHandler) handleResponseCreate(session *bfws.Session, auth *authHeaders, message []byte, defaultProvider schemas.ModelProvider) {
var event schemas.WebSocketResponsesEvent

if err := sonic.Unmarshal(message, &event); err != nil {
Expand All @@ -174,7 +199,7 @@ func (h *WSResponsesHandler) handleResponseCreate(session *bfws.Session, auth *a
// Store override: default to store=true (Codex sends false by default but expects true).
// If DisableStore is set in provider config, force store=false.
// If client explicitly sets store, respect that value unless DisableStore overrides it.
provider, modelName := schemas.ParseModelString(event.Model, "")
provider, modelName := schemas.ParseModelString(event.Model, defaultProvider)
if provider == "" || modelName == "" {
writeWSError(session, 400, "invalid_request_error", "failed to parse model string")
return
Expand All @@ -187,7 +212,7 @@ func (h *WSResponsesHandler) handleResponseCreate(session *bfws.Session, auth *a
event.Store = schemas.Ptr(true)
}

bifrostReq, err := h.convertEventToRequest(&event)
bifrostReq, err := h.convertEventToRequest(&event, defaultProvider)
if err != nil {
writeWSError(session, 400, "invalid_request_error", err.Error())
return
Expand Down Expand Up @@ -418,8 +443,10 @@ func (h *WSResponsesHandler) trackResponseID(session *bfws.Session, data []byte)
}

// convertEventToRequest converts a WebSocket response.create event to a BifrostResponsesRequest.
func (h *WSResponsesHandler) convertEventToRequest(event *schemas.WebSocketResponsesEvent) (*schemas.BifrostResponsesRequest, error) {
provider, modelName := schemas.ParseModelString(event.Model, "")
// defaultProvider is forwarded to ParseModelString so that bare model strings are resolved
// correctly on integration-scoped paths (e.g. /openai/v1/responses supplies schemas.OpenAI).
func (h *WSResponsesHandler) convertEventToRequest(event *schemas.WebSocketResponsesEvent, defaultProvider schemas.ModelProvider) (*schemas.BifrostResponsesRequest, error) {
provider, modelName := schemas.ParseModelString(event.Model, defaultProvider)
if provider == "" || modelName == "" {
return nil, errModelFormat
}
Expand Down
135 changes: 135 additions & 0 deletions transports/bifrost-http/handlers/wsresponses_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,138 @@ func TestCreateBifrostContextFromAuth_EmptyBaggageSessionIDIgnored(t *testing.T)
t.Fatalf("parent request id should be unset, got %#v", got)
}
}

// ---------------------------------------------------------------------------
// inferDefaultProviderFromPath
// ---------------------------------------------------------------------------

func TestInferDefaultProviderFromPath_OpenAIIntegrationPrefix(t *testing.T) {
	// Every path under the /openai/ integration prefix must resolve to OpenAI.
	for _, path := range []string{
		"/openai/v1/responses",
		"/openai/responses",
		"/openai/openai/responses",
		"/openai/v1/chat/completions",
	} {
		if got := inferDefaultProviderFromPath(path); got != schemas.OpenAI {
			t.Errorf("inferDefaultProviderFromPath(%q) = %q, want %q", path, got, schemas.OpenAI)
		}
	}
}

func TestInferDefaultProviderFromPath_UnifiedPathNoDefault(t *testing.T) {
	// Paths outside the /openai/ prefix — including the unified /v1/responses
	// route and other integrations — must yield no default provider.
	for _, path := range []string{
		"/v1/responses",
		"/v1/chat/completions",
		"/",
		"",
		"/anthropic/v1/messages",
	} {
		if got := inferDefaultProviderFromPath(path); got != "" {
			t.Errorf("inferDefaultProviderFromPath(%q) = %q, want empty (no default)", path, got)
		}
	}
}

// ---------------------------------------------------------------------------
// convertEventToRequest: bare vs prefixed model on integration and unified paths
// ---------------------------------------------------------------------------

// buildMinimalEvent constructs the smallest WebSocketResponsesEvent that
// convertEventToRequest will accept. inputJSON must be valid JSON (for
// example, a serialized JSON array).
func buildMinimalEvent(model string, inputJSON []byte) *schemas.WebSocketResponsesEvent {
	event := schemas.WebSocketResponsesEvent{
		Model: model,
		Input: inputJSON,
	}
	return &event
}

var minimalInput = []byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":"hi"}]}]`)

// TestConvertEventToRequest_BareModelIntegrationPath verifies that a bare model
// string (e.g. "gpt-4o") is accepted when defaultProvider is schemas.OpenAI,
// matching the /openai/v1/responses integration path behavior.
func TestConvertEventToRequest_BareModelIntegrationPath(t *testing.T) {
	handler := &WSResponsesHandler{config: nil}
	event := buildMinimalEvent("gpt-4o", minimalInput)

	req, err := handler.convertEventToRequest(event, schemas.OpenAI)
	if err != nil {
		t.Fatalf("expected no error for bare model on integration path, got: %v", err)
	}
	if got, want := req.Provider, schemas.OpenAI; got != want {
		t.Errorf("Provider = %q, want %q", got, want)
	}
	if got, want := req.Model, "gpt-4o"; got != want {
		t.Errorf("Model = %q, want %q", got, want)
	}
}

// TestConvertEventToRequest_BareModelUnifiedPathRejected verifies that a bare
// model string is rejected when defaultProvider is empty, matching the unified
// /v1/responses path behavior (multi-provider, requires explicit prefix).
func TestConvertEventToRequest_BareModelUnifiedPathRejected(t *testing.T) {
	handler := &WSResponsesHandler{config: nil}

	if _, err := handler.convertEventToRequest(buildMinimalEvent("gpt-4o", minimalInput), ""); err == nil {
		t.Fatal("expected error for bare model on unified path, got nil")
	}
}

// TestConvertEventToRequest_PrefixedModelIntegrationPath verifies that an
// explicitly prefixed model string (e.g. "openai/gpt-4o") works on the
// integration path (defaultProvider = schemas.OpenAI).
func TestConvertEventToRequest_PrefixedModelIntegrationPath(t *testing.T) {
	handler := &WSResponsesHandler{config: nil}
	event := buildMinimalEvent("openai/gpt-4o", minimalInput)

	req, err := handler.convertEventToRequest(event, schemas.OpenAI)
	if err != nil {
		t.Fatalf("expected no error for prefixed model on integration path, got: %v", err)
	}
	if got, want := req.Provider, schemas.OpenAI; got != want {
		t.Errorf("Provider = %q, want %q", got, want)
	}
	if got, want := req.Model, "gpt-4o"; got != want {
		t.Errorf("Model = %q, want %q", got, want)
	}
}

// TestConvertEventToRequest_PrefixedModelUnifiedPath verifies that an explicitly
// prefixed model string (e.g. "openai/gpt-4o") is accepted on the unified path
// (defaultProvider = ""), i.e. the current working mode is not broken.
func TestConvertEventToRequest_PrefixedModelUnifiedPath(t *testing.T) {
	handler := &WSResponsesHandler{config: nil}

	req, err := handler.convertEventToRequest(buildMinimalEvent("openai/gpt-4o", minimalInput), "")
	if err != nil {
		t.Fatalf("expected no error for prefixed model on unified path, got: %v", err)
	}
	if got, want := req.Provider, schemas.OpenAI; got != want {
		t.Errorf("Provider = %q, want %q", got, want)
	}
}

// TestConvertEventToRequest_AnthropicPrefixedModelUnifiedPath verifies that an
// Anthropic-prefixed model string works on the unified path without ambiguity.
func TestConvertEventToRequest_AnthropicPrefixedModelUnifiedPath(t *testing.T) {
	const model = "claude-3-5-sonnet-20241022"
	handler := &WSResponsesHandler{config: nil}

	req, err := handler.convertEventToRequest(buildMinimalEvent("anthropic/"+model, minimalInput), "")
	if err != nil {
		t.Fatalf("expected no error for anthropic-prefixed model on unified path, got: %v", err)
	}
	if got, want := req.Provider, schemas.Anthropic; got != want {
		t.Errorf("Provider = %q, want %q", got, want)
	}
	if got, want := req.Model, model; got != want {
		t.Errorf("Model = %q, want %q", got, want)
	}
}