diff --git a/core/bifrost.go b/core/bifrost.go index 685437a70a..3a6ccf5244 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -820,6 +820,8 @@ func (bifrost *Bifrost) createProviderFromProviderKey(providerKey schemas.ModelP return providers.NewOllamaProvider(config, bifrost.logger) case schemas.Groq: return providers.NewGroqProvider(config, bifrost.logger) + case schemas.SGL: + return providers.NewSGLProvider(config, bifrost.logger) default: return nil, fmt.Errorf("unsupported provider: %s", providerKey) } diff --git a/core/providers/sgl.go b/core/providers/sgl.go new file mode 100644 index 0000000000..74479cce8a --- /dev/null +++ b/core/providers/sgl.go @@ -0,0 +1,238 @@ +// Package providers implements various LLM providers and their utility functions. +// This file contains the SGL provider implementation. +package providers + +import ( + "context" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/goccy/go-json" + + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) + +// SGLResponse represents the response structure from the SGL API. +type SGLResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Choices []schemas.BifrostResponseChoice `json:"choices"` + Model string `json:"model"` + Created int `json:"created"` + Usage schemas.LLMUsage `json:"usage"` +} + +// sglResponsePool provides a pool for SGL response objects. +var sglResponsePool = sync.Pool{ + New: func() interface{} { + return &SGLResponse{} + }, +} + +// acquireSGLResponse gets a SGL response from the pool and resets it. +func acquireSGLResponse() *SGLResponse { + resp := sglResponsePool.Get().(*SGLResponse) + *resp = SGLResponse{} // Reset the struct + return resp +} + +// releaseSGLResponse returns a SGL response to the pool. +func releaseSGLResponse(resp *SGLResponse) { + if resp != nil { + sglResponsePool.Put(resp) + } +} + +// SGLProvider implements the Provider interface for SGL's API. 
+type SGLProvider struct { + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + streamClient *http.Client // HTTP client for streaming requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers +} + +// NewSGLProvider creates a new SGL provider instance. +// It initializes the HTTP client with the provided configuration and sets up response pools. +// The client is configured with timeouts, concurrency limits, and optional proxy settings. +func NewSGLProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*SGLProvider, error) { + config.CheckAndSetDefaults() + + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + } + + // Initialize streaming HTTP client + streamClient := &http.Client{ + Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + } + + // Pre-warm response pools + for range config.ConcurrencyAndBufferSize.Concurrency { + sglResponsePool.Put(&SGLResponse{}) + } + + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + + // BaseURL is required for SGLang + if config.NetworkConfig.BaseURL == "" { + return nil, fmt.Errorf("base_url is required for sgl provider") + } + + return &SGLProvider{ + logger: logger, + client: client, + streamClient: streamClient, + networkConfig: config.NetworkConfig, + }, nil +} + +// GetProviderKey returns the provider identifier for SGL. +func (provider *SGLProvider) GetProviderKey() schemas.ModelProvider { + return schemas.SGL +} + +// TextCompletion is not supported by the SGL provider. 
+func (provider *SGLProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("text completion", "sgl") +} + +// ChatCompletion performs a chat completion request to the SGL API. +func (provider *SGLProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderJSONMarshaling, + Error: err, + }, + } + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + if key != "" { + req.Header.Set("Authorization", "Bearer "+key) + } + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from sgl provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("SGL error: %v", errorResp) + return nil, 
bifrostErr + } + + responseBody := resp.Body() + + // Pre-allocate response structs from pools + response := acquireSGLResponse() + defer releaseSGLResponse(response) + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ID, + Object: response.Object, + Choices: response.Choices, + Model: response.Model, + Created: response.Created, + Usage: &response.Usage, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.SGL, + RawResponse: rawResponse, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} + +// Embedding is not supported by the SGL provider. +func (provider *SGLProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("embedding", "sgl") +} + +// ChatCompletionStream performs a streaming chat completion request to the SGL API. +// It supports real-time streaming of responses using Server-Sent Events (SSE). +// Uses SGL's OpenAI-compatible streaming format. +// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
+func (provider *SGLProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + "stream": true, + }, preparedParams) + + // Prepare SGL headers (SGL typically doesn't require authorization, but we include it if provided) + headers := map[string]string{ + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + // Only add Authorization header if key is provided (SGL can run without auth) + if key != "" { + headers["Authorization"] = "Bearer " + key + } + + // Use shared OpenAI-compatible streaming logic + return handleOpenAIStreaming( + ctx, + provider.streamClient, + provider.networkConfig.BaseURL+"/v1/chat/completions", + requestBody, + headers, + provider.networkConfig.ExtraHeaders, + schemas.SGL, + params, + postHookRunner, + provider.logger, + ) +} diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 594e300077..c75e039dbc 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -46,6 +46,7 @@ const ( Mistral ModelProvider = "mistral" Ollama ModelProvider = "ollama" Groq ModelProvider = "groq" + SGL ModelProvider = "sgl" ) //* Request Structs diff --git a/core/utils.go b/core/utils.go index 2fbb79cb6c..8aa329d5c5 100644 --- a/core/utils.go +++ b/core/utils.go @@ -14,7 +14,7 @@ func Ptr[T any](v T) *T { // providerRequiresKey returns true if the given provider requires an API key for authentication. // Some providers like Vertex and Ollama are keyless and don't require API keys. 
func providerRequiresKey(providerKey schemas.ModelProvider) bool {
-	return providerKey != schemas.Vertex && providerKey != schemas.Ollama
+	return providerKey != schemas.Vertex && providerKey != schemas.Ollama && providerKey != schemas.SGL
 }
 
 // calculateBackoff implements exponential backoff with jitter for retry attempts.
diff --git a/docs/usage/providers.md b/docs/usage/providers.md
index 438f5c0bae..cdae51ed5f 100644
--- a/docs/usage/providers.md
+++ b/docs/usage/providers.md
@@ -14,7 +14,8 @@ Multi-provider support with unified API across all AI providers. Switch between
 | **Cohere** | Command, Embed, Rerank | Enterprise NLP, multilingual | ✅ |
 | **Mistral** | Mistral Large, Medium, Small | European AI, cost-effective | ✅ |
 | **Ollama** | Llama, Mistral, CodeLlama | Local deployment, privacy | ✅ |
-| **Groq** | Mixtral, Llama, Gemma | Enterprise AI platform | ✅ |
+| **Groq** | Mixtral, Llama, Gemma | Enterprise AI platform | ✅ |
+| **SGLang** | Qwen | High-performance self-hosted inference | ✅ |
 
 ---
 
diff --git a/tests/core-providers/config/account.go b/tests/core-providers/config/account.go
index 4c68a319f0..f8bbf414cb 100644
--- a/tests/core-providers/config/account.go
+++ b/tests/core-providers/config/account.go
@@ -64,6 +64,7 @@ func (account *ComprehensiveTestAccount) GetConfiguredProviders() ([]schemas.Mod
 		schemas.Ollama,
 		schemas.Mistral,
 		schemas.Groq,
+		schemas.SGL,
 	}, nil
 }
 
@@ -244,6 +245,17 @@ func (account *ComprehensiveTestAccount) GetConfigForProvider(providerKey schema
 			NetworkConfig:            schemas.DefaultNetworkConfig,
 			ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize,
 		}, nil
+	case schemas.SGL:
+		return &schemas.ProviderConfig{
+			NetworkConfig: schemas.NetworkConfig{
+				BaseURL:                        os.Getenv("SGL_BASE_URL"),
+				DefaultRequestTimeoutInSeconds: 30,
+				MaxRetries:                     1,
+				RetryBackoffInitial:            100 * time.Millisecond,
+				RetryBackoffMax:                2 * time.Second,
+			},
+			ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize,
+		}, nil
 	default:
 		return nil,
fmt.Errorf("unsupported provider: %s", providerKey) } diff --git a/tests/core-providers/sgl_test.go b/tests/core-providers/sgl_test.go new file mode 100644 index 0000000000..38e3646edb --- /dev/null +++ b/tests/core-providers/sgl_test.go @@ -0,0 +1,41 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestSGL(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Cleanup() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.SGL, + ChatModel: "Qwen2.5-VL-7B-Instruct", + TextModel: "", // SGL doesn't support text completion + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + ChatCompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/transports/README.md b/transports/README.md index 6c3f3b9e72..bcad8b3803 100644 --- a/transports/README.md +++ b/transports/README.md @@ -53,7 +53,7 @@ docker run -p 8080:8080 -v $(pwd)/data:/app/data maximhq/bifrost | Feature | Description | Learn More | | ----------------------------- | ------------------------------------------------------------------- | ---------------------------------------------------------- | | **🖥️ Built-in Web UI** | Visual configuration, live monitoring, request logs, and analytics | Open `http://localhost:8080` after startup | -| **🔄 Multi-Provider Support** | OpenAI, Anthropic, Azure, Bedrock, Vertex, Cohere, Mistral, Ollama, Groq | [Provider Setup](../docs/usage/providers.md) | +| **🔄 Multi-Provider Support** | OpenAI, 
Anthropic, Azure, Bedrock, Vertex, Cohere, Mistral, Ollama, Groq, SGLang | [Provider Setup](../docs/usage/providers.md) | | **🔌 Drop-in Compatibility** | Replace OpenAI/Anthropic/GenAI APIs with zero code changes | [Integrations](../docs/usage/http-transport/integrations/) | | **🛠️ MCP Tool Calling** | Enable AI models to use external tools (filesystem, web, databases) | [MCP Guide](../docs/mcp.md) | | **⚡ Plugin System** | Add analytics, caching, rate limiting, custom logic | [Plugin System](../docs/plugins.md) | diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go index 0dd5afd964..8e83f2879e 100644 --- a/transports/bifrost-http/handlers/providers.go +++ b/transports/bifrost-http/handlers/providers.go @@ -153,7 +153,7 @@ func (h *ProviderHandler) AddProvider(ctx *fasthttp.RequestCtx) { } // Validate required keys - if len(req.Keys) == 0 && req.Provider != schemas.Vertex && req.Provider != schemas.Ollama { + if len(req.Keys) == 0 && req.Provider != schemas.Vertex && req.Provider != schemas.Ollama && req.Provider != schemas.SGL { SendError(ctx, fasthttp.StatusBadRequest, "At least one API key is required", h.logger) return } @@ -250,7 +250,7 @@ func (h *ProviderHandler) UpdateProvider(ctx *fasthttp.RequestCtx) { // Validate and process keys if req.Keys != nil { - if len(req.Keys) == 0 && provider != schemas.Vertex && provider != schemas.Ollama { + if len(req.Keys) == 0 && provider != schemas.Vertex && provider != schemas.Ollama && provider != schemas.SGL { SendError(ctx, fasthttp.StatusBadRequest, "At least one API key is required", h.logger) return } diff --git a/transports/bifrost-http/integrations/utils.go b/transports/bifrost-http/integrations/utils.go index 6b551fea18..30bb1e751f 100644 --- a/transports/bifrost-http/integrations/utils.go +++ b/transports/bifrost-http/integrations/utils.go @@ -140,7 +140,7 @@ type RouteConfig struct { ErrorConverter ErrorConverter // Function to convert BifrostError to 
integration format (SHOULD NOT BE NIL) StreamConfig *StreamConfig // Optional: Streaming configuration (if nil, streaming not supported) PreCallback PreRequestCallback // Optional: called before request processing - PostCallback PostRequestCallback // Optional: called after request processing (not supported for streaming) + PostCallback PostRequestCallback // Optional: called after request processing } // GenericRouter provides a reusable router implementation for all integrations. diff --git a/transports/bifrost-http/ui/404.html b/transports/bifrost-http/ui/404.html index 0bc5d2a1c8..a0fd813772 100644 --- a/transports/bifrost-http/ui/404.html +++ b/transports/bifrost-http/ui/404.html @@ -119,4 +119,4 @@ from { left: -80%; width: 80%; } to { left: 110%; width: 10%; } } -