diff --git a/core/go.mod b/core/go.mod index 1b689325bd..fda2986edf 100644 --- a/core/go.mod +++ b/core/go.mod @@ -5,7 +5,7 @@ go 1.24.1 require ( github.com/aws/aws-sdk-go-v2 v1.36.3 github.com/aws/aws-sdk-go-v2/config v1.29.14 - github.com/goccy/go-json v0.10.5 + github.com/bytedance/sonic v1.14.0 github.com/mark3labs/mcp-go v0.32.0 github.com/valyala/fasthttp v1.60.0 golang.org/x/oauth2 v0.30.0 @@ -25,11 +25,16 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect github.com/aws/smithy-go v1.22.3 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cloudwego/base64x v0.1.5 // indirect github.com/google/uuid v1.6.0 // indirect github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/spf13/cast v1.7.1 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect golang.org/x/net v0.39.0 // indirect golang.org/x/text v0.24.0 // indirect ) diff --git a/core/go.sum b/core/go.sum index 9ffb1b5af4..e232e2a55b 100644 --- a/core/go.sum +++ b/core/go.sum @@ -28,18 +28,28 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/Xv github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= +github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= -github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= -github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -52,8 +62,17 @@ github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZV github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw= @@ -62,11 +81,16 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= diff --git a/core/providers/anthropic.go b/core/providers/anthropic.go index b7c34a2b1e..fbd7b57b55 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -12,8 +12,7 @@ import ( "sync" "time" - "github.com/goccy/go-json" - + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) @@ -263,7 +262,7 @@ func (provider *AnthropicProvider) prepareTextCompletionParams(params map[string // Returns the response body or an error if the request fails. func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, url string, key string) ([]byte, *schemas.BifrostError) { // Marshal the request body - jsonData, err := json.Marshal(requestBody) + jsonData, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Anthropic) } @@ -548,7 +547,7 @@ func prepareAnthropicChatRequest(messages []schemas.BifrostMessage, params *sche if toolCall.Function.Name != nil { var input map[string]interface{} if toolCall.Function.Arguments != "" { - if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { + if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { // If unmarshaling fails, use a simple string representation input = map[string]interface{}{"arguments": toolCall.Function.Arguments} } @@ -712,7 +711,7 @@ func parseAnthropicResponse(response *AnthropicChatResponse, bifrostResponse *sc Name: &c.Name, } - args, err := json.Marshal(c.Input) + args, err := sonic.Marshal(c.Input) if err != nil { function.Arguments = fmt.Sprintf("%v", c.Input) } else { @@ -826,7 +825,7 @@ func handleAnthropicStreaming( logger schemas.Logger, ) (chan *schemas.BifrostStream, *schemas.BifrostError) { - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerType) } @@ -903,7 +902,7 @@ func handleAnthropicStreaming( switch eventType { case "message_start": var event AnthropicStreamEvent - if err := json.Unmarshal([]byte(eventData), &event); err != nil { + if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse message_start event: %v", err)) continue } @@ -914,7 +913,7 @@ func handleAnthropicStreaming( case "content_block_start": var event AnthropicStreamEvent - if err := json.Unmarshal([]byte(eventData), &event); err != nil { + if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse content_block_start event: %v", err)) continue } @@ -1002,7 +1001,7 @@ func handleAnthropicStreaming( case "content_block_delta": var event AnthropicStreamEvent - if err := json.Unmarshal([]byte(eventData), &event); err != nil { + if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse content_block_delta event: %v", err)) continue } @@ -1122,7 +1121,7 @@ func handleAnthropicStreaming( case "message_delta": var event AnthropicStreamEvent - if err := json.Unmarshal([]byte(eventData), &event); err != nil { + if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse message_delta event: %v", err)) continue } @@ -1160,7 +1159,7 @@ func handleAnthropicStreaming( case "message_stop": var event AnthropicStreamEvent - if err := json.Unmarshal([]byte(eventData), &event); err != nil { + if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse message_stop event: %v", err)) continue } @@ -1203,7 +1202,7 @@ func handleAnthropicStreaming( case "error": var event AnthropicStreamEvent - if err := json.Unmarshal([]byte(eventData), &event); err != nil { + if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse error event: %v", err)) continue } diff --git a/core/providers/azure.go b/core/providers/azure.go index 2be8db4da3..2cdea6eab4 100644 --- a/core/providers/azure.go +++ b/core/providers/azure.go @@ -9,8 +9,7 @@ import ( "sync" "time" - "github.com/goccy/go-json" - + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) @@ -169,7 +168,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody } // Marshal the request body - jsonData, err := json.Marshal(requestBody) + jsonData, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Azure) } @@ -392,7 +391,7 @@ func (provider *AzureProvider) Embedding(ctx context.Context, model string, key // Parse response var response AzureEmbeddingResponse - if err := json.Unmarshal(responseBody, &response); err != nil { + if err := sonic.Unmarshal(responseBody, &response); err != nil { return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Azure) } diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index ba6de3b957..8edf299c8c 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -7,6 +7,7 @@ import ( "context" "crypto/sha256" "encoding/hex" + "encoding/json" "errors" "fmt" "io" @@ -17,13 +18,12 @@ import ( "sync" "time" - "github.com/goccy/go-json" - "bufio" "github.com/aws/aws-sdk-go-v2/aws" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/aws/aws-sdk-go-v2/config" + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" ) @@ -269,7 +269,7 @@ func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBod region = *provider.meta.GetRegion() } - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil, &schemas.BifrostError{ @@ -346,7 +346,7 @@ func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBod if resp.StatusCode != http.StatusOK { var errorResp BedrockError - if err := json.Unmarshal(body, &errorResp); err != nil { + if err := sonic.Unmarshal(body, &errorResp); err != nil { return nil, &schemas.BifrostError{ IsBifrostError: true, StatusCode: &resp.StatusCode, @@ -379,7 +379,7 @@ func (provider *BedrockProvider) getTextCompletionResult(result []byte, model st fallthrough case "anthropic.claude-v2:1": var response BedrockAnthropicTextResponse - if err := json.Unmarshal(result, &response); err != nil { + if err := sonic.Unmarshal(result, &response); err != nil { return nil, &schemas.BifrostError{ IsBifrostError: true, Error: schemas.ErrorField{ @@ -421,7 +421,7 @@ func (provider *BedrockProvider) getTextCompletionResult(result []byte, model st fallthrough case "mistral.mistral-small-2402-v1:0": var response BedrockMistralTextResponse - if err := json.Unmarshal(result, &response); err != nil { + if err := sonic.Unmarshal(result, &response); err != nil { return nil, &schemas.BifrostError{ IsBifrostError: true, Error: schemas.ErrorField{ @@ -465,7 +465,7 @@ func (provider *BedrockProvider) getTextCompletionResult(result []byte, model st func parseBedrockAnthropicMessageToolCallContent(content string) map[string]interface{} { toolResultContentBlock := map[string]interface{}{} var parsedJSON interface{} - err := json.Unmarshal([]byte(content), &parsedJSON) + err := sonic.Unmarshal([]byte(content), &parsedJSON) if err == nil { if arr, ok := parsedJSON.([]interface{}); ok { toolResultContentBlock["json"] = map[string]interface{}{"content": arr} @@ -549,7 +549,7 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema for _, toolCall := range *msg.AssistantMessage.ToolCalls { var input map[string]interface{} if toolCall.Function.Arguments != "" { - if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { + if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { input = map[string]interface{}{"arguments": toolCall.Function.Arguments} } } @@ -838,7 +838,7 @@ func (provider *BedrockProvider) TextCompletion(ctx context.Context, model strin // Parse raw response var rawResponse interface{} - if err := json.Unmarshal(body, &rawResponse); err != nil { + if err := sonic.Unmarshal(body, &rawResponse); err != nil { return nil, newBifrostOperationError("error parsing raw response", err, schemas.Bedrock) } @@ -978,7 +978,7 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model strin if input == nil { input = map[string]any{} } - arguments, err := json.Marshal(input) + arguments, err := sonic.Marshal(input) if err != nil { arguments = []byte("{}") } @@ -1159,7 +1159,7 @@ func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model Embedding []float32 `json:"embedding"` InputTextTokenCount int `json:"inputTextTokenCount"` } - if err := json.Unmarshal(rawResponse, &titanResp); err != nil { + if err := sonic.Unmarshal(rawResponse, &titanResp); err != nil { return nil, newBifrostOperationError("error parsing Titan embedding response", err, schemas.Bedrock) } @@ -1210,7 +1210,7 @@ func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, mode ID string `json:"id"` Texts []string `json:"texts"` } - if err := json.Unmarshal(rawResponse, &cohereResp); err != nil { + if err := sonic.Unmarshal(rawResponse, &cohereResp); err != nil { return nil, newBifrostOperationError("error parsing Cohere embedding response", err, schemas.Bedrock) } @@ -1292,7 +1292,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH } // Create the streaming request - jsonBody, jsonErr := json.Marshal(requestBody) + jsonBody, jsonErr := sonic.Marshal(requestBody) if jsonErr != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, jsonErr, schemas.Bedrock) } @@ -1386,7 +1386,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH // Parse the JSON event var event map[string]interface{} - if err := json.Unmarshal([]byte(jsonStr), &event); err != nil { + if err := sonic.Unmarshal([]byte(jsonStr), &event); err != nil { provider.logger.Debug(fmt.Sprintf("Failed to parse JSON from stream: %v, data: %s", err, jsonStr)) continue } @@ -1456,7 +1456,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH // Extract and marshal input as arguments if input, hasInput := toolUse["input"].(map[string]interface{}); hasInput { - inputBytes, err := json.Marshal(input) + inputBytes, err := sonic.Marshal(input) if err != nil { toolCall.Function.Arguments = "{}" } else { diff --git a/core/providers/cohere.go b/core/providers/cohere.go index 8710efd7e3..765e727931 100644 --- a/core/providers/cohere.go +++ b/core/providers/cohere.go @@ -12,10 +12,9 @@ import ( "sync" "time" - "github.com/goccy/go-json" - "net/http" + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) @@ -205,7 +204,7 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model string } // Marshal request body - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: true, @@ -270,7 +269,7 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model string Name: &tool.Name, } - args, err := json.Marshal(tool.Parameters) + args, err := sonic.Marshal(tool.Parameters) if err != nil { function.Arguments = fmt.Sprintf("%v", tool.Parameters) } else { @@ -359,7 +358,7 @@ func prepareCohereChatRequest(messages []schemas.BifrostMessage, params *schemas for _, toolCall := range *msg.AssistantMessage.ToolCalls { var arguments map[string]interface{} var parsedJSON interface{} - err := json.Unmarshal([]byte(toolCall.Function.Arguments), &parsedJSON) + err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &parsedJSON) if err == nil { if arr, ok := parsedJSON.(map[string]interface{}); ok { arguments = arr @@ -395,7 +394,7 @@ func prepareCohereChatRequest(messages []schemas.BifrostMessage, params *schemas // Found the matching tool call, extract its parameters var parsedJSON interface{} - err := json.Unmarshal([]byte(toolCall.Function.Arguments), &parsedJSON) + err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &parsedJSON) if err == nil { if arr, ok := parsedJSON.(map[string]interface{}); ok { toolCallParameters = arr @@ -563,7 +562,7 @@ func convertChatHistory(history []struct { Name: &tool.Name, } - args, err := json.Marshal(tool.Parameters) + args, err := sonic.Marshal(tool.Parameters) if err != nil { function.Arguments = fmt.Sprintf("%v", tool.Parameters) } else { @@ -624,7 +623,7 @@ func (provider *CohereProvider) Embedding(ctx context.Context, model string, key } // Marshal request body - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Cohere) } @@ -664,13 +663,13 @@ func (provider *CohereProvider) Embedding(ctx context.Context, model string, key // Parse response var cohereResp CohereEmbeddingResponse - if err := json.Unmarshal(resp.Body(), &cohereResp); err != nil { + if err := sonic.Unmarshal(resp.Body(), &cohereResp); err != nil { return nil, newBifrostOperationError("error parsing Cohere embedding response", err, schemas.Cohere) } // Parse raw response for consistent format var rawResponse interface{} - if err := json.Unmarshal(resp.Body(), &rawResponse); err != nil { + if err := sonic.Unmarshal(resp.Body(), &rawResponse); err != nil { return nil, newBifrostOperationError("error parsing raw response for Cohere embedding", err, schemas.Cohere) } @@ -709,7 +708,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo return nil, newBifrostOperationError("failed to prepare Cohere chat request", err, schemas.Cohere) } - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Cohere) } @@ -773,7 +772,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo // Parse the streaming event var streamEvent map[string]interface{} - if err := json.Unmarshal([]byte(jsonData), &streamEvent); err != nil { + if err := sonic.Unmarshal([]byte(jsonData), &streamEvent); err != nil { provider.logger.Warn(fmt.Sprintf("Failed to parse Cohere stream event: %v", err)) continue } @@ -786,7 +785,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo switch eventType { case "stream-start": var startEvent CohereStreamStartEvent - if err := json.Unmarshal([]byte(jsonData), &startEvent); err != nil { + if err := sonic.Unmarshal([]byte(jsonData), &startEvent); err != nil { provider.logger.Warn(fmt.Sprintf("Failed to parse Cohere stream-start event: %v", err)) continue } @@ -823,7 +822,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo case "text-generation": var textEvent CohereStreamTextEvent - if err := json.Unmarshal([]byte(jsonData), &textEvent); err != nil { + if err := sonic.Unmarshal([]byte(jsonData), &textEvent); err != nil { provider.logger.Warn(fmt.Sprintf("Failed to parse Cohere text-generation event: %v", err)) continue } @@ -858,7 +857,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo case "tool-calls-chunk": var toolEvent CohereStreamToolCallEvent - if err := json.Unmarshal([]byte(jsonData), &toolEvent); err != nil { + if err := sonic.Unmarshal([]byte(jsonData), &toolEvent); err != nil { provider.logger.Warn(fmt.Sprintf("Failed to parse Cohere tool-use event: %v", err)) continue } @@ -902,7 +901,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo case "stream-end": var stopEvent CohereStreamStopEvent - if err := json.Unmarshal([]byte(jsonData), &stopEvent); err != nil { + if err := sonic.Unmarshal([]byte(jsonData), &stopEvent); err != nil { provider.logger.Warn(fmt.Sprintf("Failed to parse Cohere stream-end event: %v", err)) continue } @@ -914,7 +913,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo Name: &toolCall.Name, } - args, err := json.Marshal(toolCall.Parameters) + args, err := sonic.Marshal(toolCall.Parameters) if err != nil { function.Arguments = fmt.Sprintf("%v", toolCall.Parameters) } else { diff --git a/core/providers/groq.go b/core/providers/groq.go index 82b0dc0316..3df9c3402b 100644 --- a/core/providers/groq.go +++ b/core/providers/groq.go @@ -10,8 +10,7 @@ import ( "sync" "time" - "github.com/goccy/go-json" - + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) @@ -113,7 +112,7 @@ func (provider *GroqProvider) ChatCompletion(ctx context.Context, model string, "messages": formattedMessages, }, preparedParams) - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Groq) } diff --git a/core/providers/mistral.go b/core/providers/mistral.go index 59d89b69fe..c4ebaf85c4 100644 --- a/core/providers/mistral.go +++ b/core/providers/mistral.go @@ -10,8 +10,7 @@ import ( "sync" "time" - "github.com/goccy/go-json" - + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) @@ -127,7 +126,7 @@ func (provider *MistralProvider) ChatCompletion(ctx context.Context, model strin "messages": formattedMessages, }, preparedParams) - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Mistral) } @@ -234,7 +233,7 @@ func (provider *MistralProvider) Embedding(ctx context.Context, model string, ke } } - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Mistral) } @@ -271,18 +270,18 @@ func (provider *MistralProvider) Embedding(ctx context.Context, model string, ke return nil, bifrostErr } - // Parse response using json.RawMessage to avoid double parsing - var rawMessage json.RawMessage = resp.Body() + // Parse response using sonic.RawMessage to avoid double parsing + rawMessage := resp.Body() // Parse into structured response var mistralResp MistralEmbeddingResponse - if err := json.Unmarshal(rawMessage, &mistralResp); err != nil { + if err := sonic.Unmarshal(rawMessage, &mistralResp); err != nil { return nil, newBifrostOperationError("error parsing Mistral embedding response", err, schemas.Mistral) } // Parse raw response for consistent format var rawResponse interface{} - if err := json.Unmarshal(rawMessage, &rawResponse); err != nil { + if err := sonic.Unmarshal(rawMessage, &rawResponse); err != nil { return nil, newBifrostOperationError("error parsing raw response for Mistral embedding", err, schemas.Mistral) } diff --git a/core/providers/ollama.go b/core/providers/ollama.go index 1f52787af4..dd653760b7 100644 --- a/core/providers/ollama.go +++ b/core/providers/ollama.go @@ -10,8 +10,7 @@ import ( "sync" "time" - "github.com/goccy/go-json" - + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) @@ -114,7 +113,7 @@ func (provider *OllamaProvider) ChatCompletion(ctx context.Context, model string "messages": formattedMessages, }, preparedParams) - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Ollama) } diff --git a/core/providers/openai.go b/core/providers/openai.go index 66cda30dea..3fd8d3e512 100644 --- a/core/providers/openai.go +++ b/core/providers/openai.go @@ -17,8 +17,7 @@ import ( "sync" "time" - "github.com/goccy/go-json" - + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) @@ -41,36 +40,22 @@ type OpenAIResponse struct { Usage schemas.LLMUsage `json:"usage"` // Token usage statistics } -// OpenAIError represents the error response structure from the OpenAI API. -// It includes detailed error information and event tracking. -type OpenAIError struct { - EventID string `json:"event_id"` // Unique identifier for the error event - Type string `json:"type"` // Type of error - Error struct { - Type string `json:"type"` // Error type - Code string `json:"code"` // Error code - Message string `json:"message"` // Error message - Param interface{} `json:"param"` // Parameter that caused the error - EventID string `json:"event_id"` // Event ID for tracking - } `json:"error"` -} - // openAIResponsePool provides a pool for OpenAI response objects. var openAIResponsePool = sync.Pool{ New: func() interface{} { - return &OpenAIResponse{} + return &schemas.BifrostResponse{} }, } // acquireOpenAIResponse gets an OpenAI response from the pool and resets it. -func acquireOpenAIResponse() *OpenAIResponse { - resp := openAIResponsePool.Get().(*OpenAIResponse) - *resp = OpenAIResponse{} // Reset the struct +func acquireOpenAIResponse() *schemas.BifrostResponse { + resp := openAIResponsePool.Get().(*schemas.BifrostResponse) + *resp = schemas.BifrostResponse{} // Reset the struct return resp } // releaseOpenAIResponse returns an OpenAI response to the pool. -func releaseOpenAIResponse(resp *OpenAIResponse) { +func releaseOpenAIResponse(resp *schemas.BifrostResponse) { if resp != nil { openAIResponsePool.Put(resp) } @@ -103,7 +88,7 @@ func NewOpenAIProvider(config *schemas.ProviderConfig, logger schemas.Logger) *O // Pre-warm response pools for range config.ConcurrencyAndBufferSize.Concurrency { - openAIResponsePool.Put(&OpenAIResponse{}) + openAIResponsePool.Put(&schemas.BifrostResponse{}) } // Configure proxy if provided @@ -145,7 +130,7 @@ func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, model string "messages": formattedMessages, }, preparedParams) - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.OpenAI) } @@ -185,32 +170,16 @@ func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, model string defer releaseOpenAIResponse(response) // Use enhanced response handler with pre-allocated response - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) + _, bifrostErr = handleProviderResponse(responseBody, response) if bifrostErr != nil { return nil, bifrostErr } - // Create final response - bifrostResponse := &schemas.BifrostResponse{ - ID: response.ID, - Object: response.Object, - Choices: response.Choices, - Model: response.Model, - Created: response.Created, - ServiceTier: response.ServiceTier, - SystemFingerprint: response.SystemFingerprint, - Usage: &response.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - RawResponse: rawResponse, - }, - } - if params != nil { - bifrostResponse.ExtraFields.Params = *params + response.ExtraFields.Params = *params } - return bifrostResponse, nil + return response, nil } // prepareOpenAIChatRequest formats messages for the OpenAI API. @@ -293,7 +262,7 @@ func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key requestBody = mergeConfig(requestBody, params.ExtraParams) } - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.OpenAI) } @@ -328,7 +297,7 @@ func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key // Parse response var response OpenAIResponse - if err := json.Unmarshal(resp.Body(), &response); err != nil { + if err := sonic.Unmarshal(resp.Body(), &response); err != nil { return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.OpenAI) } @@ -446,7 +415,7 @@ func handleOpenAIStreaming( logger schemas.Logger, ) (chan *schemas.BifrostStream, *schemas.BifrostError) { - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.OpenAI) } @@ -516,41 +485,21 @@ func handleOpenAIStreaming( // First, check if this is an error response var errorCheck map[string]interface{} - if err := json.Unmarshal([]byte(jsonData), &errorCheck); err != nil { + if err := sonic.Unmarshal([]byte(jsonData), &errorCheck); err != nil { logger.Warn(fmt.Sprintf("Failed to parse stream data as JSON: %v", err)) continue } // Handle error responses if _, hasError := errorCheck["error"]; hasError { - var openAIError OpenAIError - if err := json.Unmarshal([]byte(jsonData), &openAIError); err != nil { + errorStream, err := parseOpenAIErrorForStreamDataLine(jsonData) + if err != nil { logger.Warn(fmt.Sprintf("Failed to parse error response: %v", err)) continue } - // Send error through channel - errorResponse := &schemas.BifrostStream{ - BifrostError: &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Type: &openAIError.Error.Type, - Code: &openAIError.Error.Code, - Message: openAIError.Error.Message, - Param: openAIError.Error.Param, - }, - }, - } - - if openAIError.EventID != "" { - errorResponse.BifrostError.EventID = &openAIError.EventID - } - if openAIError.Error.EventID != "" { - errorResponse.BifrostError.Error.EventID = &openAIError.Error.EventID - } - select { - case responseChan <- errorResponse: + case responseChan <- errorStream: case <-ctx.Done(): } return // Stop processing on error @@ -558,7 +507,7 @@ func handleOpenAIStreaming( // Parse into bifrost response var response schemas.BifrostResponse - if err := json.Unmarshal([]byte(jsonData), &response); err != nil { + if err := sonic.Unmarshal([]byte(jsonData), &response); err != nil { logger.Warn(fmt.Sprintf("Failed to parse stream response: %v", err)) continue } @@ -637,7 +586,7 @@ func (provider *OpenAIProvider) Speech(ctx context.Context, model string, key sc requestBody = mergeConfig(requestBody, params.ExtraParams) } - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.OpenAI) } @@ -716,7 +665,7 @@ func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner requestBody = mergeConfig(requestBody, params.ExtraParams) } - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.OpenAI) } @@ -794,41 +743,21 @@ func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner // First, check if this is an error response var errorCheck map[string]interface{} - if err := json.Unmarshal([]byte(jsonData), &errorCheck); err != nil { + if err := sonic.Unmarshal([]byte(jsonData), &errorCheck); err != nil { provider.logger.Warn(fmt.Sprintf("Failed to parse stream data as JSON: %v", err)) continue } // Handle error responses if _, hasError := errorCheck["error"]; hasError { - var openAIError OpenAIError - if err := json.Unmarshal([]byte(jsonData), &openAIError); err != nil { + errorStream, err := parseOpenAIErrorForStreamDataLine(jsonData) + if err != nil { provider.logger.Warn(fmt.Sprintf("Failed to parse error response: %v", err)) continue } - // Send error through channel - errorResponse := &schemas.BifrostStream{ - BifrostError: &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Type: &openAIError.Error.Type, - Code: &openAIError.Error.Code, - Message: openAIError.Error.Message, - Param: openAIError.Error.Param, - }, - }, - } - - if openAIError.EventID != "" { - errorResponse.BifrostError.EventID = &openAIError.EventID - } - if openAIError.Error.EventID != "" { - errorResponse.BifrostError.Error.EventID = &openAIError.Error.EventID - } - select { - case responseChan <- errorResponse: + case responseChan <- errorStream: case <-ctx.Done(): } return // Stop processing on error @@ -838,7 +767,7 @@ func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner var response schemas.BifrostResponse var speechResponse schemas.BifrostSpeech - if err := json.Unmarshal([]byte(jsonData), &speechResponse); err != nil { + if err := sonic.Unmarshal([]byte(jsonData), &speechResponse); err != nil { provider.logger.Warn(fmt.Sprintf("Failed to parse stream response: %v", err)) continue } @@ -914,13 +843,13 @@ func (provider *OpenAIProvider) Transcription(ctx context.Context, model string, BifrostTranscribeNonStreamResponse: &schemas.BifrostTranscribeNonStreamResponse{}, } - if err := json.Unmarshal(responseBody, transcribeResponse); err != nil { + if err := sonic.Unmarshal(responseBody, transcribeResponse); err != nil { return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.OpenAI) } // Parse raw response for RawResponse field var rawResponse interface{} - if err := json.Unmarshal(responseBody, &rawResponse); err != nil { + if err := sonic.Unmarshal(responseBody, &rawResponse); err != nil { return nil, newBifrostOperationError(schemas.ErrProviderDecodeRaw, err, schemas.OpenAI) } @@ -1028,41 +957,21 @@ func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHoo // First, check if this is an error response var errorCheck map[string]interface{} - if err := json.Unmarshal([]byte(jsonData), &errorCheck); err != nil { + if err := sonic.Unmarshal([]byte(jsonData), &errorCheck); err != nil { provider.logger.Warn(fmt.Sprintf("Failed to parse stream data as JSON: %v", err)) continue } // Handle error responses if _, hasError := errorCheck["error"]; hasError { - var openAIError OpenAIError - if err := json.Unmarshal([]byte(jsonData), &openAIError); err != nil { + errorStream, err := parseOpenAIErrorForStreamDataLine(jsonData) + if err != nil { provider.logger.Warn(fmt.Sprintf("Failed to parse error response: %v", err)) continue } - // Send error through channel - errorResponse := &schemas.BifrostStream{ - BifrostError: &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Type: &openAIError.Error.Type, - Code: &openAIError.Error.Code, - Message: openAIError.Error.Message, - Param: openAIError.Error.Param, - }, - }, - } - - if openAIError.EventID != "" { - errorResponse.BifrostError.EventID = &openAIError.EventID - } - if openAIError.Error.EventID != "" { - errorResponse.BifrostError.Error.EventID = &openAIError.Error.EventID - } - select { - case responseChan <- errorResponse: + case responseChan <- errorStream: case <-ctx.Done(): } return // Stop processing on error @@ -1071,7 +980,7 @@ func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHoo var response schemas.BifrostResponse var transcriptionResponse schemas.BifrostTranscribe - if err := json.Unmarshal([]byte(jsonData), &transcriptionResponse); err != nil { + if err := sonic.Unmarshal([]byte(jsonData), &transcriptionResponse); err != nil { provider.logger.Warn(fmt.Sprintf("Failed to parse stream response: %v", err)) continue } @@ -1173,32 +1082,32 @@ func parseTranscriptionFormDataBody(writer *multipart.Writer, input *schemas.Tra } func parseOpenAIError(resp *fasthttp.Response) *schemas.BifrostError { - var errorResp OpenAIError + var errorResp schemas.BifrostError bifrostErr := handleProviderAPIError(resp, &errorResp) - if errorResp.EventID != "" { - bifrostErr.EventID = &errorResp.EventID + if errorResp.EventID != nil { + bifrostErr.EventID = errorResp.EventID } - bifrostErr.Error.Type = &errorResp.Error.Type - bifrostErr.Error.Code = &errorResp.Error.Code + bifrostErr.Error.Type = errorResp.Error.Type + bifrostErr.Error.Code = errorResp.Error.Code bifrostErr.Error.Message = errorResp.Error.Message bifrostErr.Error.Param = errorResp.Error.Param - if errorResp.Error.EventID != "" { - bifrostErr.Error.EventID = &errorResp.Error.EventID + if errorResp.Error.EventID != nil { + bifrostErr.Error.EventID = errorResp.Error.EventID } return bifrostErr } func parseStreamOpenAIError(resp *http.Response) *schemas.BifrostError { - var errorResp OpenAIError + var errorResp schemas.BifrostError statusCode := resp.StatusCode body, _ := io.ReadAll(resp.Body) resp.Body.Close() - if err := json.Unmarshal(body, &errorResp); err != nil { + if err := sonic.Unmarshal(body, &errorResp); err != nil { return &schemas.BifrostError{ IsBifrostError: true, StatusCode: &statusCode, @@ -1215,16 +1124,45 @@ func parseStreamOpenAIError(resp *http.Response) *schemas.BifrostError { Error: schemas.ErrorField{}, } - if errorResp.EventID != "" { - bifrostErr.EventID = &errorResp.EventID + if errorResp.EventID != nil { + bifrostErr.EventID = errorResp.EventID } - bifrostErr.Error.Type = &errorResp.Error.Type - bifrostErr.Error.Code = &errorResp.Error.Code + bifrostErr.Error.Type = errorResp.Error.Type + bifrostErr.Error.Code = errorResp.Error.Code bifrostErr.Error.Message = errorResp.Error.Message bifrostErr.Error.Param = errorResp.Error.Param - if errorResp.Error.EventID != "" { - bifrostErr.Error.EventID = &errorResp.Error.EventID + if errorResp.Error.EventID != nil { + bifrostErr.Error.EventID = errorResp.Error.EventID } return bifrostErr } + +func parseOpenAIErrorForStreamDataLine(jsonData string) (*schemas.BifrostStream, error) { + var openAIError schemas.BifrostError + if err := sonic.Unmarshal([]byte(jsonData), &openAIError); err != nil { + return nil, err + } + + // Send error through channel + errorStream := &schemas.BifrostStream{ + BifrostError: &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Type: openAIError.Error.Type, + Code: openAIError.Error.Code, + Message: openAIError.Error.Message, + Param: openAIError.Error.Param, + }, + }, + } + + if openAIError.EventID != nil { + errorStream.BifrostError.EventID = openAIError.EventID + } + if openAIError.Error.EventID != nil { + errorStream.BifrostError.Error.EventID = openAIError.Error.EventID + } + + return errorStream, nil +} diff --git a/core/providers/sgl.go b/core/providers/sgl.go index 9186bc2c08..07b4add4e1 100644 --- a/core/providers/sgl.go +++ b/core/providers/sgl.go @@ -10,8 +10,7 @@ import ( "sync" "time" - "github.com/goccy/go-json" - + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) @@ -114,7 +113,7 @@ func (provider *SGLProvider) ChatCompletion(ctx context.Context, model string, k "messages": formattedMessages, }, preparedParams) - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: true, diff --git a/core/providers/utils.go b/core/providers/utils.go index 8080195b40..0fa9874637 100644 --- a/core/providers/utils.go +++ b/core/providers/utils.go @@ -14,7 +14,7 @@ import ( "strings" "sync" - "github.com/goccy/go-json" + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttpproxy" @@ -287,7 +287,7 @@ func setExtraHeadersHTTP(req *http.Request, extraHeaders map[string]string, skip func handleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.BifrostError { statusCode := resp.StatusCode() - if err := json.Unmarshal(resp.Body(), &errorResp); err != nil { + if err := sonic.Unmarshal(resp.Body(), &errorResp); err != nil { return &schemas.BifrostError{ IsBifrostError: true, StatusCode: &statusCode, @@ -309,7 +309,7 @@ func handleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.Bif // It attempts to parse the response body into the provided response type // and returns either the parsed response or a BifrostError if parsing fails. func handleProviderResponse[T any](responseBody []byte, response *T) (interface{}, *schemas.BifrostError) { - var rawResponse interface{} + // var rawResponse interface{} var wg sync.WaitGroup var structuredErr, rawErr error @@ -317,11 +317,11 @@ func handleProviderResponse[T any](responseBody []byte, response *T) (interface{ wg.Add(2) go func() { defer wg.Done() - structuredErr = json.Unmarshal(responseBody, response) + structuredErr = sonic.Unmarshal(responseBody, response) }() go func() { defer wg.Done() - rawErr = json.Unmarshal(responseBody, &rawResponse) + // rawErr = sonic.Unmarshal(responseBody, &rawResponse) }() wg.Wait() @@ -345,7 +345,7 @@ func handleProviderResponse[T any](responseBody []byte, response *T) (interface{ } } - return rawResponse, nil + return nil, nil } // getRoleFromMessage extracts and validates the role from a message map. diff --git a/core/providers/vertex.go b/core/providers/vertex.go index b1c58170c3..6fe5fd5e00 100644 --- a/core/providers/vertex.go +++ b/core/providers/vertex.go @@ -14,9 +14,9 @@ import ( "strings" "sync" - "github.com/goccy/go-json" "golang.org/x/oauth2/google" + "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" ) @@ -159,7 +159,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string delete(requestBody, "region") - jsonBody, err := json.Marshal(requestBody) + jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Vertex) } @@ -236,14 +236,14 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string removeVertexClient(key.VertexKeyConfig.AuthCredentials) } - var openAIErr OpenAIError + var openAIErr schemas.BifrostError var vertexErr []VertexError provider.logger.Debug(fmt.Sprintf("error from vertex provider: %s", string(body))) - if err := json.Unmarshal(body, &openAIErr); err != nil { + if err := sonic.Unmarshal(body, &openAIErr); err != nil { // Try Vertex error format if OpenAI format fails - if err := json.Unmarshal(body, &vertexErr); err != nil { + if err := sonic.Unmarshal(body, &vertexErr); err != nil { return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex) } @@ -303,7 +303,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string Created: response.Created, ServiceTier: response.ServiceTier, SystemFingerprint: response.SystemFingerprint, - Usage: &response.Usage, + Usage: response.Usage, ExtraFields: schemas.BifrostResponseExtraFields{ Provider: schemas.Vertex, RawResponse: rawResponse, diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 35bd09f470..734796d195 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -2,8 +2,9 @@ package schemas import ( - "encoding/json" "fmt" + + "github.com/bytedance/sonic" ) const ( @@ -93,13 +94,13 @@ func (tc SpeechVoiceInput) MarshalJSON() ([]byte, error) { } if tc.Voice != nil { - return json.Marshal(*tc.Voice) + return sonic.Marshal(*tc.Voice) } if len(tc.MultiVoiceConfig) > 0 { - return json.Marshal(tc.MultiVoiceConfig) + return sonic.Marshal(tc.MultiVoiceConfig) } // If both are nil, return null - return json.Marshal(nil) + return sonic.Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for SpeechVoiceInput. @@ -108,14 +109,14 @@ func (tc SpeechVoiceInput) MarshalJSON() ([]byte, error) { func (tc *SpeechVoiceInput) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := json.Unmarshal(data, &stringContent); err == nil { + if err := sonic.Unmarshal(data, &stringContent); err == nil { tc.Voice = &stringContent return nil } // Try to unmarshal as an array of VoiceConfig objects var voiceConfigs []VoiceConfig - if err := json.Unmarshal(data, &voiceConfigs); err == nil { + if err := sonic.Unmarshal(data, &voiceConfigs); err == nil { // Validate each VoiceConfig and append to MultiVoiceConfig for _, config := range voiceConfigs { if config.Voice == "" { @@ -245,13 +246,13 @@ func (tc ToolChoice) MarshalJSON() ([]byte, error) { } if tc.ToolChoiceStr != nil { - return json.Marshal(*tc.ToolChoiceStr) + return sonic.Marshal(*tc.ToolChoiceStr) } if tc.ToolChoiceStruct != nil { - return json.Marshal(*tc.ToolChoiceStruct) + return sonic.Marshal(*tc.ToolChoiceStruct) } // If both are nil, return null - return json.Marshal(nil) + return sonic.Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ToolChoice. @@ -260,14 +261,14 @@ func (tc ToolChoice) MarshalJSON() ([]byte, error) { func (tc *ToolChoice) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := json.Unmarshal(data, &stringContent); err == nil { + if err := sonic.Unmarshal(data, &stringContent); err == nil { tc.ToolChoiceStr = &stringContent return nil } // Try to unmarshal as a direct struct of ToolChoiceStruct var toolChoiceStruct ToolChoiceStruct - if err := json.Unmarshal(data, &toolChoiceStruct); err == nil { + if err := sonic.Unmarshal(data, &toolChoiceStruct); err == nil { // Validate the Type field is not empty and is a valid value if toolChoiceStruct.Type == "" { return fmt.Errorf("tool_choice struct has empty type field") @@ -305,13 +306,13 @@ func (mc MessageContent) MarshalJSON() ([]byte, error) { } if mc.ContentStr != nil { - return json.Marshal(*mc.ContentStr) + return sonic.Marshal(*mc.ContentStr) } if mc.ContentBlocks != nil { - return json.Marshal(*mc.ContentBlocks) + return sonic.Marshal(*mc.ContentBlocks) } // If both are nil, return null - return json.Marshal(nil) + return sonic.Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for MessageContent. @@ -320,14 +321,14 @@ func (mc MessageContent) MarshalJSON() ([]byte, error) { func (mc *MessageContent) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := json.Unmarshal(data, &stringContent); err == nil { + if err := sonic.Unmarshal(data, &stringContent); err == nil { mc.ContentStr = &stringContent return nil } // Try to unmarshal as a direct array of ContentBlock var arrayContent []ContentBlock - if err := json.Unmarshal(data, &arrayContent); err == nil { + if err := sonic.Unmarshal(data, &arrayContent); err == nil { mc.ContentBlocks = &arrayContent return nil }