diff --git a/docs/integrations.md b/docs/integrations.md new file mode 100644 index 0000000000..b420c8314a --- /dev/null +++ b/docs/integrations.md @@ -0,0 +1,478 @@ +# Bifrost Integrations Guide + +This guide shows how to use popular AI framework SDKs (LangChain, LangGraph, LiteLLM) with Bifrost by simply changing the base URL. This allows you to leverage all the benefits of Bifrost (fallbacks, load balancing, unified error handling) while using the familiar SDKs you already know. + +## Overview + +Bifrost provides integration endpoints that are compatible with popular AI framework SDKs: + +- **LangChain**: Compatible with `langchain-openai` and other LangChain providers +- **LangGraph**: Full compatibility with LangGraph workflows and state management +- **LiteLLM**: Direct OpenAI-compatible proxy for 100+ models +- **Direct Provider SDKs**: OpenAI, Anthropic, Google, Mistral SDKs + +## Quick Start + +### 1. LangChain Integration + +Use your existing LangChain code with Bifrost by changing the `base_url`: + +```python +# Before: Direct OpenAI +from langchain_openai import ChatOpenAI + +llm = ChatOpenAI( + model="gpt-4o", + api_key="your-openai-key" +) + +# After: Through Bifrost +llm = ChatOpenAI( + model="gpt-4o", + api_key="your-openai-key", + base_url="http://localhost:8080/integrations/langchain" # Bifrost LangChain endpoint +) +``` + +**Complete LangChain Example:** + +```python +from langchain_openai import ChatOpenAI +from langchain_core.messages import HumanMessage +import os + +# Configure your API keys +os.environ["OPENAI_API_KEY"] = "your-openai-key" + +# Create LangChain client pointing to Bifrost +llm = ChatOpenAI( + model="gpt-4o", + base_url="http://localhost:8080/integrations/langchain", + temperature=0.7 +) + +# Use normally - all LangChain features work +messages = [HumanMessage(content="Hello, how are you?")] +response = llm.invoke(messages) +print(response.content) + +# Streaming works too +for chunk in llm.stream(messages): + print(chunk.content, end="", flush=True) + +# Tool calling, batch processing, etc. all work as expected +``` + +**LangChain with Multiple Providers:** + +```python +from langchain_openai import ChatOpenAI +from langchain_anthropic import ChatAnthropic + +# OpenAI through Bifrost +openai_llm = ChatOpenAI( + model="gpt-4o", + base_url="http://localhost:8080/integrations/langchain" +) + +# Anthropic through Bifrost +anthropic_llm = ChatAnthropic( + model="claude-3-opus-20240229", + base_url="http://localhost:8080/integrations/langchain" +) + +# Both get Bifrost's benefits: fallbacks, load balancing, etc. +``` + +### 2. LangGraph Integration + +LangGraph workflows work seamlessly with Bifrost endpoints: + +```python +from langgraph.prebuilt import create_react_agent +from langchain_openai import ChatOpenAI +from langchain_core.tools import tool + +# Define a tool +@tool +def get_weather(location: str) -> str: + """Get the weather for a location.""" + return f"The weather in {location} is sunny." 
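+# Note: the docstring above doubles as the tool's description sent to the model,
+# so keep it short and descriptive.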
+ +# Create LLM with Bifrost endpoint +llm = ChatOpenAI( + model="gpt-4o", + base_url="http://localhost:8080/integrations/langgraph" # LangGraph-specific endpoint +) + +# Create agent normally +agent = create_react_agent(llm, [get_weather]) + +# Run the agent +response = agent.invoke({ + "messages": [{"role": "user", "content": "What's the weather in Paris?"}] +}) +print(response["messages"][-1].content) +``` + +**Advanced LangGraph with State Management:** + +```python +from langgraph.graph import StateGraph, MessagesState +from langchain_openai import ChatOpenAI +from langchain_core.messages import HumanMessage + +# Define LLM with Bifrost +llm = ChatOpenAI( + model="gpt-4o", + base_url="http://localhost:8080/integrations/langgraph" +) + +# Define your graph nodes +def chatbot(state: MessagesState): + return {"messages": [llm.invoke(state["messages"])]} + +def human_feedback(state: MessagesState): + # Add human-in-the-loop logic + pass + +# Build the graph +graph = StateGraph(MessagesState) +graph.add_node("chatbot", chatbot) +graph.add_node("human", human_feedback) +graph.set_entry_point("chatbot") + +# Compile and run +app = graph.compile() +result = app.invoke({"messages": [HumanMessage(content="Hello")]}) +``` + +### 3. LiteLLM Integration + +LiteLLM provides the most direct integration - just change the base URL: + +```python +import litellm + +# Before: Direct to OpenAI +response = litellm.completion( + model="gpt-4o", + messages=[{"role": "user", "content": "Hello"}] +) + +# After: Through Bifrost +response = litellm.completion( + model="gpt-4o", + messages=[{"role": "user", "content": "Hello"}], + api_base="http://localhost:8080/integrations/litellm" # Bifrost LiteLLM endpoint +) +``` + +**LiteLLM with Multiple Providers:** + +```python +import litellm + +# Configure base URL globally +litellm.api_base = "http://localhost:8080/integrations/litellm" + +# Now all calls go through Bifrost +openai_response = litellm.completion( + model="gpt-4o", + messages=[{"role": "user", "content": "Hello from OpenAI"}] +) + +anthropic_response = litellm.completion( + model="claude-3-opus-20240229", + messages=[{"role": "user", "content": "Hello from Claude"}] +) + +# Embeddings work too +embeddings = litellm.embedding( + model="text-embedding-3-small", + input=["Hello world"] +) +``` + +**LiteLLM Router with Bifrost:** + +```python +import litellm +from litellm import Router + +# Create router with Bifrost endpoints +model_list = [ + { + "model_name": "gpt-4o", + "litellm_params": { + "model": "gpt-4o", + "api_base": "http://localhost:8080/integrations/litellm" + } + }, + { + "model_name": "claude-3", + "litellm_params": { + "model": "claude-3-opus-20240229", + "api_base": "http://localhost:8080/integrations/litellm" + } + } +] + +router = Router(model_list=model_list) + +# Router handles load balancing, Bifrost handles provider fallbacks +response = router.completion( + model="gpt-4o", + messages=[{"role": "user", "content": "Hello"}] +) +``` + +### 4. 
Direct Provider SDKs + +You can also use provider SDKs directly with Bifrost endpoints: + +**OpenAI Python SDK:** + +```python +from openai import OpenAI + +# Create client pointing to Bifrost +client = OpenAI( + api_key="your-openai-key", + base_url="http://localhost:8080/integrations/openai" +) + +# Use normally +response = client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "Hello"}] +) + +print(response.choices[0].message.content) +``` + +**Anthropic SDK:** + +```python +import anthropic + +# Create client pointing to Bifrost +client = anthropic.Anthropic( + api_key="your-anthropic-key", + base_url="http://localhost:8080/integrations/anthropic" +) + +# Use normally +response = client.messages.create( + model="claude-3-opus-20240229", + max_tokens=1000, + messages=[{"role": "user", "content": "Hello"}] +) + +print(response.content[0].text) +``` + +## Configuration + +### Environment Variables + +Set your provider API keys as usual: + +```bash +export OPENAI_API_KEY="your-openai-key" +export ANTHROPIC_API_KEY="your-anthropic-key" +export GOOGLE_API_KEY="your-google-key" +export MISTRAL_API_KEY="your-mistral-key" +``` + +### Bifrost Configuration + +Configure Bifrost with your providers in a JSON configuration file. Create a `config.json` file with your provider settings: + +```json +{ + "OpenAI": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"], + "weight": 1.0 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 30, + "max_retries": 3, + "retry_backoff_initial_ms": 100, + "retry_backoff_max_ms": 2000 + }, + "concurrency_and_buffer_size": { + "concurrency": 5, + "buffer_size": 10 + } + }, + "Anthropic": { + "keys": [ + { + "value": "env.ANTHROPIC_API_KEY", + "models": ["claude-3-5-sonnet-20240620", "claude-3-haiku-20240307"], + "weight": 1.0 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 30, + "max_retries": 3 + } + }, + "Bedrock": { + "keys": [ + { + "value": "env.BEDROCK_API_KEY", + "models": ["anthropic.claude-3-sonnet-20240229-v1:0"], + "weight": 1.0 + } + ], + "meta_config": { + "secret_access_key": "env.BEDROCK_ACCESS_KEY", + "region": "us-east-1" + } + }, + "Azure": { + "keys": [ + { + "value": "env.AZURE_API_KEY", + "models": ["gpt-4o"], + "weight": 1.0 + } + ], + "meta_config": { + "endpoint": "env.AZURE_ENDPOINT", + "deployments": { + "gpt-4o": "gpt-4o-deployment-name" + }, + "api_version": "2024-08-01-preview" + } + } +} +``` + +**Environment Variables Setup:** + +Set the corresponding environment variables referenced in your config: + +```bash +export OPENAI_API_KEY="your-openai-key" +export ANTHROPIC_API_KEY="your-anthropic-key" +export BEDROCK_API_KEY="your-bedrock-key" +export BEDROCK_ACCESS_KEY="your-aws-access-key" +export AZURE_API_KEY="your-azure-key" +export AZURE_ENDPOINT="https://your-resource.openai.azure.com/" +``` + +**Running Bifrost HTTP Server:** + +```bash +# Start the server with your configuration +go run transports/bifrost-http/main.go \ + -config config.json \ + -port 8080 \ + -pool-size 300 +``` + +**Configuration Options:** + +- **Keys**: Multiple API keys per provider with model restrictions and load balancing weights +- **Network Config**: Timeout settings, retry policies, and backoff strategies +- **Meta Config**: Provider-specific settings (AWS regions, Azure endpoints, etc.) 
+- **Concurrency**: Control request concurrency and buffer sizes per provider + +**Load Balancing:** + +Bifrost automatically load balances requests across multiple keys for the same provider based on the `weight` parameter and model availability. + +## Integration Endpoints + +Bifrost provides these integration endpoints: + +| Framework | Endpoint | Purpose | +| --------- | ------------------------- | ---------------------------------- | +| LangChain | `/integrations/langchain` | LangChain SDK compatibility | +| LangGraph | `/integrations/langgraph` | LangGraph workflows and agents | +| LiteLLM | `/integrations/litellm` | LiteLLM proxy compatibility | +| OpenAI | `/integrations/openai` | Direct OpenAI SDK compatibility | +| Anthropic | `/integrations/anthropic` | Direct Anthropic SDK compatibility | +| Google | `/integrations/genai` | Google GenAI SDK compatibility | +| Mistral | `/integrations/mistral` | Mistral SDK compatibility | + +## Benefits + +By using SDKs with Bifrost endpoints, you get: + +### 🔄 **Automatic Fallbacks** + +- If one provider fails, Bifrost automatically tries configured fallbacks +- No code changes needed - handled transparently + +### ⚖️ **Load Balancing** + +- Distribute requests across multiple provider instances +- Reduce rate limiting and improve reliability + +### 📊 **Unified Monitoring** + +- All requests flow through Bifrost for consistent logging and metrics +- Track usage, costs, and performance across all providers + +### 🛡️ **Error Handling** + +- Standardized error handling and retry logic +- Graceful degradation when providers are unavailable + +### 💰 **Cost Optimization** + +- Route to cost-effective alternatives based on your rules +- Track spending across all providers in one place + +## Migration + +### From Direct Provider Calls + +```python +# Before +from openai import OpenAI +client = OpenAI(api_key="...") + +# After +from openai import OpenAI +client = OpenAI( + api_key="...", + base_url="http://localhost:8080/integrations/openai" +) +``` + +### From LangChain + +```python +# Before +from langchain_openai import ChatOpenAI +llm = ChatOpenAI(model="gpt-4o") + +# After +from langchain_openai import ChatOpenAI +llm = ChatOpenAI( + model="gpt-4o", + base_url="http://localhost:8080/integrations/langchain" +) +``` + +### From LiteLLM + +```python +# Before +import litellm +response = litellm.completion(model="gpt-4o", ...) + +# After +import litellm +litellm.api_base = "http://localhost:8080/integrations/litellm" +response = litellm.completion(model="gpt-4o", ...) +``` diff --git a/transports/bifrost-http/integrations/anthropic/router.go b/transports/bifrost-http/integrations/anthropic/router.go new file mode 100644 index 0000000000..72657963f2 --- /dev/null +++ b/transports/bifrost-http/integrations/anthropic/router.go @@ -0,0 +1,69 @@ +package anthropic + +import ( + "encoding/json" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// AnthropicRouter holds route registrations for Anthropic endpoints. +type AnthropicRouter struct { + client *bifrost.Bifrost +} + +// NewAnthropicRouter creates a new AnthropicRouter with the given bifrost client. +func NewAnthropicRouter(client *bifrost.Bifrost) *AnthropicRouter { + return &AnthropicRouter{client: client} +} + +// RegisterRoutes registers all Anthropic routes on the given router. 
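+// Only the Messages API is exposed for now. An illustrative request body
+// (field names follow AnthropicMessageRequest in types.go):
+//
+//	POST /anthropic/v1/messages
+//	{"model": "claude-3-opus-20240229", "max_tokens": 256,
+//	 "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}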
+func (a *AnthropicRouter) RegisterRoutes(r *router.Router) { + r.POST("/anthropic/v1/messages", a.handleMessages) +} + +// handleMessages handles POST /v1/messages +func (a *AnthropicRouter) handleMessages(ctx *fasthttp.RequestCtx) { + var req AnthropicMessageRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetContentType("application/json") + errResponse := map[string]string{"error": err.Error()} + jsonBytes, _ := json.Marshal(errResponse) + ctx.SetBody(jsonBytes) + return + } + + if req.Model == "" { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Model parameter is required") + return + } + + if req.MaxTokens == 0 { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("max_tokens parameter is required") + return + } + + bifrostReq := req.ConvertToBifrostRequest() + + bifrostCtx := lib.ConvertToBifrostContext(ctx) + + result, err := a.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) + if err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetContentType("application/json") + jsonBytes, _ := json.Marshal(err) + ctx.SetBody(jsonBytes) + return + } + + anthropicResponse := DeriveAnthropicFromBifrostResponse(result) + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + jsonBytes, _ := json.Marshal(anthropicResponse) + ctx.SetBody(jsonBytes) +} diff --git a/transports/bifrost-http/integrations/anthropic/types.go b/transports/bifrost-http/integrations/anthropic/types.go new file mode 100644 index 0000000000..c44ada870f --- /dev/null +++ b/transports/bifrost-http/integrations/anthropic/types.go @@ -0,0 +1,336 @@ +package anthropic + +import ( + "encoding/json" + + "github.com/maximhq/bifrost/core/schemas" +) + +// Helper function to create string pointer +func stringPtr(s string) *string { + return &s +} + +var fnTypePtr = stringPtr(string(schemas.ToolChoiceTypeFunction)) + +// AnthropicContent represents content in Anthropic message format +type AnthropicContent struct { + Type string `json:"type"` // "text", "image", "tool_use", "tool_result" + Text *string `json:"text,omitempty"` // For text content + ToolUseID *string `json:"tool_use_id,omitempty"` // For tool_result content + ID *string `json:"id,omitempty"` // For tool_use content + Name *string `json:"name,omitempty"` // For tool_use content + Input interface{} `json:"input,omitempty"` // For tool_use content + Content interface{} `json:"content,omitempty"` // For tool_result content + Source *AnthropicImageSource `json:"source,omitempty"` // For image content +} + +// AnthropicImageSource represents image source in Anthropic format +type AnthropicImageSource struct { + Type string `json:"type"` // "base64" + MediaType string `json:"media_type"` // "image/jpeg", "image/png", etc. 
+ Data string `json:"data"` // Base64-encoded image data +} + +// AnthropicMessage represents a message in Anthropic format +type AnthropicMessage struct { + Role string `json:"role"` // "user", "assistant" + Content []AnthropicContent `json:"content"` // Array of content blocks +} + +// AnthropicTool represents a tool in Anthropic format +type AnthropicTool struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema interface{} `json:"input_schema"` +} + +// AnthropicToolChoice represents tool choice in Anthropic format +type AnthropicToolChoice struct { + Type string `json:"type"` // "auto", "any", "tool" + Name string `json:"name,omitempty"` // For type "tool" +} + +// AnthropicMessageRequest represents an Anthropic messages API request +type AnthropicMessageRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + Messages []AnthropicMessage `json:"messages"` + System *string `json:"system,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK *int `json:"top_k,omitempty"` + StopSequences *[]string `json:"stop_sequences,omitempty"` + Stream *bool `json:"stream,omitempty"` + Tools *[]AnthropicTool `json:"tools,omitempty"` + ToolChoice *AnthropicToolChoice `json:"tool_choice,omitempty"` +} + +// ConvertToBifrostRequest converts an Anthropic messages request to Bifrost format +func (r *AnthropicMessageRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { + bifrostReq := &schemas.BifrostRequest{ + Provider: schemas.Anthropic, + Model: r.Model, + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{}, + }, + } + + // Add system message if present + if r.System != nil && *r.System != "" { + systemMsg := schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleSystem, + Content: r.System, + } + *bifrostReq.Input.ChatCompletionInput = append(*bifrostReq.Input.ChatCompletionInput, systemMsg) + } + + // Convert messages + for _, msg := range r.Messages { + var bifrostMsg schemas.BifrostMessage + bifrostMsg.Role = schemas.ModelChatMessageRole(msg.Role) + + // Handle different content types + var textContent string + var toolCalls []schemas.ToolCall + var toolCallID *string + + for _, content := range msg.Content { + switch content.Type { + case "text": + if content.Text != nil { + textContent += *content.Text + } + case "tool_use": + if content.ID != nil && content.Name != nil { + tc := schemas.ToolCall{ + Type: fnTypePtr, + ID: content.ID, + Function: schemas.FunctionCall{ + Name: content.Name, + Arguments: jsonifyInput(content.Input), + }, + } + toolCalls = append(toolCalls, tc) + } + case "tool_result": + if content.ToolUseID != nil { + toolCallID = content.ToolUseID + if content.Content != nil { + if contentStr, ok := content.Content.(string); ok { + textContent += contentStr + } + } + } + } + } + + if textContent != "" { + bifrostMsg.Content = &textContent + } + + if len(toolCalls) > 0 { + bifrostMsg.AssistantMessage = &schemas.AssistantMessage{ + ToolCalls: &toolCalls, + } + } + + if toolCallID != nil { + bifrostMsg.ToolMessage = &schemas.ToolMessage{ + ToolCallID: toolCallID, + } + } + + *bifrostReq.Input.ChatCompletionInput = append(*bifrostReq.Input.ChatCompletionInput, bifrostMsg) + } + + // Convert parameters + if r.MaxTokens > 0 || r.Temperature != nil || r.TopP != nil || r.TopK != nil || r.StopSequences != nil { + params := &schemas.ModelParameters{} + + if r.MaxTokens > 0 { + params.MaxTokens = &r.MaxTokens + } + if r.Temperature 
!= nil { + params.Temperature = r.Temperature + } + if r.TopP != nil { + params.TopP = r.TopP + } + if r.TopK != nil { + params.TopK = r.TopK + } + if r.StopSequences != nil { + params.StopSequences = r.StopSequences + } + + bifrostReq.Params = params + } + + // Convert tools + if r.Tools != nil { + tools := []schemas.Tool{} + for _, tool := range *r.Tools { + // Convert input_schema to FunctionParameters + params := schemas.FunctionParameters{ + Type: "object", + } + if tool.InputSchema != nil { + if schemaMap, ok := tool.InputSchema.(map[string]interface{}); ok { + if typeVal, ok := schemaMap["type"].(string); ok { + params.Type = typeVal + } + if desc, ok := schemaMap["description"].(string); ok { + params.Description = &desc + } + if required, ok := schemaMap["required"].([]interface{}); ok { + reqStrings := make([]string, len(required)) + for i, req := range required { + if reqStr, ok := req.(string); ok { + reqStrings[i] = reqStr + } + } + params.Required = reqStrings + } + if properties, ok := schemaMap["properties"].(map[string]interface{}); ok { + params.Properties = properties + } + if enum, ok := schemaMap["enum"].([]interface{}); ok { + enumStrings := make([]string, len(enum)) + for i, e := range enum { + if eStr, ok := e.(string); ok { + enumStrings[i] = eStr + } + } + params.Enum = &enumStrings + } + } + } + + t := schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: tool.Name, + Description: tool.Description, + Parameters: params, + }, + } + tools = append(tools, t) + } + if bifrostReq.Params == nil { + bifrostReq.Params = &schemas.ModelParameters{} + } + bifrostReq.Params.Tools = &tools + + // Convert tool choice + if r.ToolChoice != nil { + toolChoice := &schemas.ToolChoice{ + Type: schemas.ToolChoiceType(r.ToolChoice.Type), + } + if r.ToolChoice.Type == "tool" && r.ToolChoice.Name != "" { + toolChoice.Function = schemas.ToolChoiceFunction{ + Name: r.ToolChoice.Name, + } + } + bifrostReq.Params.ToolChoice = toolChoice + } + } + + return bifrostReq +} + +// Helper function to convert interface{} to JSON string +func jsonifyInput(input interface{}) string { + if input == nil { + return "{}" + } + jsonBytes, err := json.Marshal(input) + if err != nil { + return "{}" + } + return string(jsonBytes) +} + +// AnthropicMessageResponse represents an Anthropic messages API response +type AnthropicMessageResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []AnthropicContent `json:"content"` + Model string `json:"model"` + StopReason *string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` + Usage *AnthropicUsage `json:"usage,omitempty"` +} + +// AnthropicUsage represents usage information in Anthropic format +type AnthropicUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +// DeriveAnthropicFromBifrostResponse converts a Bifrost response to Anthropic format +func DeriveAnthropicFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *AnthropicMessageResponse { + if bifrostResp == nil { + return nil + } + + anthropicResp := &AnthropicMessageResponse{ + ID: bifrostResp.ID, + Type: "message", + Role: "assistant", + Model: bifrostResp.Model, + } + + // Convert usage information + if bifrostResp.Usage != (schemas.LLMUsage{}) { + anthropicResp.Usage = &AnthropicUsage{ + InputTokens: bifrostResp.Usage.PromptTokens, + OutputTokens: bifrostResp.Usage.CompletionTokens, + } + } + + // Convert choices to content + var content 
[]AnthropicContent + if len(bifrostResp.Choices) > 0 { + choice := bifrostResp.Choices[0] // Anthropic typically returns one choice + + if choice.FinishReason != nil { + anthropicResp.StopReason = choice.FinishReason + } + + // Add text content + if choice.Message.Content != nil && *choice.Message.Content != "" { + content = append(content, AnthropicContent{ + Type: "text", + Text: choice.Message.Content, + }) + } + + // Add tool calls as tool_use content + if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.ToolCalls != nil { + for _, toolCall := range *choice.Message.AssistantMessage.ToolCalls { + // Parse arguments JSON string back to map + var input map[string]interface{} + if toolCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { + input = map[string]interface{}{} + } + } else { + input = map[string]interface{}{} + } + + tc := AnthropicContent{ + Type: "tool_use", + ID: toolCall.ID, + Name: toolCall.Function.Name, + Input: input, + } + content = append(content, tc) + } + } + } + + anthropicResp.Content = content + return anthropicResp +} diff --git a/transports/bifrost-http/integrations/genai/router.go b/transports/bifrost-http/integrations/genai/router.go index 798dd184be..e2ddb5e367 100644 --- a/transports/bifrost-http/integrations/genai/router.go +++ b/transports/bifrost-http/integrations/genai/router.go @@ -41,7 +41,9 @@ func (g *GenAIRouter) handleChatCompletion(ctx *fasthttp.RequestCtx) { var req GeminiChatRequest if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { ctx.SetStatusCode(fasthttp.StatusBadRequest) - json.NewEncoder(ctx).Encode(err) + ctx.SetContentType("application/json") + jsonBytes, _ := json.Marshal(err) + ctx.SetBody(jsonBytes) return } @@ -52,12 +54,15 @@ func (g *GenAIRouter) handleChatCompletion(ctx *fasthttp.RequestCtx) { result, err := g.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) if err != nil { ctx.SetStatusCode(fasthttp.StatusInternalServerError) - json.NewEncoder(ctx).Encode(err) + ctx.SetContentType("application/json") + jsonBytes, _ := json.Marshal(err) + ctx.SetBody(jsonBytes) return } genAIResponse := DeriveGenAIFromBifrostResponse(result) ctx.SetStatusCode(fasthttp.StatusOK) ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(genAIResponse) + jsonBytes, _ := json.Marshal(genAIResponse) + ctx.SetBody(jsonBytes) } diff --git a/transports/bifrost-http/integrations/langchain/router.go b/transports/bifrost-http/integrations/langchain/router.go new file mode 100644 index 0000000000..ed6742d3e7 --- /dev/null +++ b/transports/bifrost-http/integrations/langchain/router.go @@ -0,0 +1,145 @@ +package langchain + +import ( + "encoding/json" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// LangChainRouter holds route registrations for LangChain endpoints. +type LangChainRouter struct { + client *bifrost.Bifrost +} + +// NewLangChainRouter creates a new LangChainRouter with the given bifrost client. +func NewLangChainRouter(client *bifrost.Bifrost) *LangChainRouter { + return &LangChainRouter{client: client} +} + +// RegisterRoutes registers all LangChain routes on the given router. 
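+// The /langchain/stream route currently returns a non-streaming response
+// (see handleStream). An illustrative /langchain/chat request body, based on
+// LangChainChatRequest in types.go:
+//
+//	{"model": "gpt-4o", "messages": [{"type": "human", "content": "Hello"}]}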
+func (l *LangChainRouter) RegisterRoutes(r *router.Router) { + r.POST("/langchain/chat", l.handleChatInvoke) + r.POST("/langchain/invoke", l.handleInvoke) + r.POST("/langchain/batch", l.handleBatch) + r.POST("/langchain/stream", l.handleStream) +} + +// handleChatInvoke handles POST /langchain/chat - simplified chat interface +func (l *LangChainRouter) handleChatInvoke(ctx *fasthttp.RequestCtx) { + var req LangChainChatRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + json.NewEncoder(ctx).Encode(map[string]string{"error": err.Error()}) + return + } + + if req.Model == "" { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Model parameter is required") + return + } + + bifrostReq := req.ConvertToBifrostRequest() + + bifrostCtx := lib.ConvertToBifrostContext(ctx) + + result, err := l.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) + if err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + json.NewEncoder(ctx).Encode(err) + return + } + + langchainResponse := DeriveLangChainFromBifrostResponse(result) + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + json.NewEncoder(ctx).Encode(langchainResponse) +} + +// handleInvoke handles POST /langchain/invoke - general invoke interface +func (l *LangChainRouter) handleInvoke(ctx *fasthttp.RequestCtx) { + var req LangChainInvokeRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + json.NewEncoder(ctx).Encode(map[string]string{"error": err.Error()}) + return + } + + bifrostReq := req.ConvertToBifrostRequest() + + bifrostCtx := lib.ConvertToBifrostContext(ctx) + + var result *schemas.BifrostResponse + var err *schemas.BifrostError + + if req.Type == "chat" { + result, err = l.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) + } else { + result, err = l.client.TextCompletionRequest(*bifrostCtx, bifrostReq) + } + + if err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + json.NewEncoder(ctx).Encode(err) + return + } + + langchainResponse := DeriveLangChainInvokeFromBifrostResponse(result, req.Type) + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + json.NewEncoder(ctx).Encode(langchainResponse) +} + +// handleBatch handles POST /langchain/batch - batch processing +func (l *LangChainRouter) handleBatch(ctx *fasthttp.RequestCtx) { + var req LangChainBatchRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + json.NewEncoder(ctx).Encode(map[string]string{"error": err.Error()}) + return + } + + bifrostCtx := lib.ConvertToBifrostContext(ctx) + var responses []LangChainInvokeResponse + + // Process each request in the batch + for _, input := range req.Inputs { + bifrostReq := input.ConvertToBifrostRequest() + + var result *schemas.BifrostResponse + var err *schemas.BifrostError + + if input.Type == "chat" { + result, err = l.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) + } else { + result, err = l.client.TextCompletionRequest(*bifrostCtx, bifrostReq) + } + + if err != nil { + // Add error response to batch + responses = append(responses, LangChainInvokeResponse{ + Error: err.Error.Message, + }) + } else { + langchainResponse := DeriveLangChainInvokeFromBifrostResponse(result, input.Type) + responses = append(responses, *langchainResponse) + } + } + + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + 
json.NewEncoder(ctx).Encode(LangChainBatchResponse{ + Results: responses, + }) +} + +// handleStream handles POST /langchain/stream - streaming interface (simplified) +func (l *LangChainRouter) handleStream(ctx *fasthttp.RequestCtx) { + // For now, we'll return a non-streaming response + // In production, you'd implement Server-Sent Events or WebSocket streaming + l.handleInvoke(ctx) +} diff --git a/transports/bifrost-http/integrations/langchain/types.go b/transports/bifrost-http/integrations/langchain/types.go new file mode 100644 index 0000000000..4ac3a11039 --- /dev/null +++ b/transports/bifrost-http/integrations/langchain/types.go @@ -0,0 +1,447 @@ +package langchain + +import ( + "encoding/json" + + "github.com/maximhq/bifrost/core/schemas" +) + +// LangChain-style message types + +// LangChainMessage represents a message in LangChain format +type LangChainMessage struct { + Type string `json:"type"` // "human", "ai", "system", "tool" + Content string `json:"content"` // Text content + Name *string `json:"name,omitempty"` // Optional name + ToolCalls *[]LangChainToolCall `json:"tool_calls,omitempty"` // For AI messages with tool calls + ToolCallID *string `json:"tool_call_id,omitempty"` // For tool messages +} + +// LangChainToolCall represents a tool call in LangChain format +type LangChainToolCall struct { + Name string `json:"name"` + Args map[string]interface{} `json:"args"` + ID *string `json:"id,omitempty"` +} + +// LangChainTool represents a tool definition in LangChain format +type LangChainTool struct { + Name string `json:"name"` + Description string `json:"description"` + ArgsSchema map[string]interface{} `json:"args_schema"` +} + +// Request structures + +// LangChainChatRequest represents a simplified LangChain chat request +type LangChainChatRequest struct { + Model string `json:"model"` + Provider *string `json:"provider,omitempty"` // Optional explicit provider + Messages []LangChainMessage `json:"messages"` + Temperature *float64 `json:"temperature,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Tools *[]LangChainTool `json:"tools,omitempty"` + StopWords *[]string `json:"stop_words,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// ConvertToBifrostRequest converts a LangChain chat request to Bifrost format +func (r *LangChainChatRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { + provider := schemas.OpenAI // Default + if r.Provider != nil { + provider = schemas.ModelProvider(*r.Provider) + } + + bifrostReq := &schemas.BifrostRequest{ + Provider: provider, + Model: r.Model, + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{}, + }, + } + + // Convert messages + for _, msg := range r.Messages { + var bifrostMsg schemas.BifrostMessage + + // Map LangChain message types to Bifrost roles + switch msg.Type { + case "human": + bifrostMsg.Role = schemas.ModelChatMessageRoleUser + case "ai": + bifrostMsg.Role = schemas.ModelChatMessageRoleAssistant + case "system": + bifrostMsg.Role = schemas.ModelChatMessageRoleSystem + case "tool": + bifrostMsg.Role = schemas.ModelChatMessageRoleTool + default: + bifrostMsg.Role = schemas.ModelChatMessageRoleUser + } + + bifrostMsg.Content = &msg.Content + + // Handle tool calls for AI messages + if msg.ToolCalls != nil { + toolCalls := []schemas.ToolCall{} + for _, toolCall := range *msg.ToolCalls { + // Convert args map to JSON string + argsJSON := "{}" + if len(toolCall.Args) > 0 { + // In production, use proper 
JSON marshaling + argsJSON = mapToJSONString(toolCall.Args) + } + + tc := schemas.ToolCall{ + Type: stringPtr("function"), + ID: toolCall.ID, + Function: schemas.FunctionCall{ + Name: &toolCall.Name, + Arguments: argsJSON, + }, + } + toolCalls = append(toolCalls, tc) + } + bifrostMsg.AssistantMessage = &schemas.AssistantMessage{ + ToolCalls: &toolCalls, + } + } + + // Handle tool messages + if msg.ToolCallID != nil { + bifrostMsg.ToolMessage = &schemas.ToolMessage{ + ToolCallID: msg.ToolCallID, + } + } + + *bifrostReq.Input.ChatCompletionInput = append(*bifrostReq.Input.ChatCompletionInput, bifrostMsg) + } + + // Convert parameters + if r.Temperature != nil || r.MaxTokens != nil || r.TopP != nil || r.StopWords != nil { + params := &schemas.ModelParameters{} + + if r.Temperature != nil { + params.Temperature = r.Temperature + } + if r.MaxTokens != nil { + params.MaxTokens = r.MaxTokens + } + if r.TopP != nil { + params.TopP = r.TopP + } + if r.StopWords != nil { + params.StopSequences = r.StopWords + } + + bifrostReq.Params = params + } + + // Convert tools + if r.Tools != nil { + tools := []schemas.Tool{} + for _, tool := range *r.Tools { + // Convert args_schema to FunctionParameters + params := schemas.FunctionParameters{ + Type: "object", + } + if tool.ArgsSchema != nil { + if typeVal, ok := tool.ArgsSchema["type"].(string); ok { + params.Type = typeVal + } + if desc, ok := tool.ArgsSchema["description"].(string); ok { + params.Description = &desc + } + if required, ok := tool.ArgsSchema["required"].([]interface{}); ok { + reqStrings := make([]string, len(required)) + for i, req := range required { + if reqStr, ok := req.(string); ok { + reqStrings[i] = reqStr + } + } + params.Required = reqStrings + } + if properties, ok := tool.ArgsSchema["properties"].(map[string]interface{}); ok { + params.Properties = properties + } + if enum, ok := tool.ArgsSchema["enum"].([]interface{}); ok { + enumStrings := make([]string, len(enum)) + for i, e := range enum { + if eStr, ok := e.(string); ok { + enumStrings[i] = eStr + } + } + params.Enum = &enumStrings + } + } + + t := schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: tool.Name, + Description: tool.Description, + Parameters: params, + }, + } + tools = append(tools, t) + } + if bifrostReq.Params == nil { + bifrostReq.Params = &schemas.ModelParameters{} + } + bifrostReq.Params.Tools = &tools + } + + return bifrostReq +} + +// LangChainInvokeRequest represents a general LangChain invoke request +type LangChainInvokeRequest struct { + Type string `json:"type"` // "chat" or "completion" + Model string `json:"model"` + Provider *string `json:"provider,omitempty"` + Input interface{} `json:"input"` // Can be string or messages array + Config map[string]interface{} `json:"config,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// ConvertToBifrostRequest converts a LangChain invoke request to Bifrost format +func (r *LangChainInvokeRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { + provider := schemas.OpenAI // Default + if r.Provider != nil { + provider = schemas.ModelProvider(*r.Provider) + } + + bifrostReq := &schemas.BifrostRequest{ + Provider: provider, + Model: r.Model, + } + + if r.Type == "chat" { + // Handle chat input - could be messages array or string + if messages, ok := r.Input.([]interface{}); ok { + chatInput := []schemas.BifrostMessage{} + for _, msgInterface := range messages { + if msgMap, ok := msgInterface.(map[string]interface{}); ok { + msg := schemas.BifrostMessage{ 
+ Role: schemas.ModelChatMessageRoleUser, + } + if content, ok := msgMap["content"].(string); ok { + msg.Content = &content + } + if role, ok := msgMap["type"].(string); ok { + switch role { + case "human": + msg.Role = schemas.ModelChatMessageRoleUser + case "ai": + msg.Role = schemas.ModelChatMessageRoleAssistant + case "system": + msg.Role = schemas.ModelChatMessageRoleSystem + } + } + chatInput = append(chatInput, msg) + } + } + bifrostReq.Input = schemas.RequestInput{ + ChatCompletionInput: &chatInput, + } + } else if inputStr, ok := r.Input.(string); ok { + // Single string input - convert to user message + chatInput := []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: &inputStr, + }, + } + bifrostReq.Input = schemas.RequestInput{ + ChatCompletionInput: &chatInput, + } + } + } else { + // Text completion + if inputStr, ok := r.Input.(string); ok { + bifrostReq.Input = schemas.RequestInput{ + TextCompletionInput: &inputStr, + } + } + } + + // Convert config to parameters + if r.Config != nil { + params := &schemas.ModelParameters{} + + if temp, ok := r.Config["temperature"].(float64); ok { + params.Temperature = &temp + } + if maxTokens, ok := r.Config["max_tokens"].(float64); ok { + maxTokensInt := int(maxTokens) + params.MaxTokens = &maxTokensInt + } + if topP, ok := r.Config["top_p"].(float64); ok { + params.TopP = &topP + } + + bifrostReq.Params = params + } + + return bifrostReq +} + +// LangChainBatchRequest represents a batch request +type LangChainBatchRequest struct { + Inputs []LangChainInvokeRequest `json:"inputs"` + Config map[string]interface{} `json:"config,omitempty"` +} + +// Response structures + +// LangChainChatResponse represents a LangChain chat response +type LangChainChatResponse struct { + Type string `json:"type"` // "ai" + Content string `json:"content"` + ToolCalls *[]LangChainToolCall `json:"tool_calls,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Usage *LangChainUsage `json:"usage,omitempty"` +} + +// LangChainInvokeResponse represents a general invoke response +type LangChainInvokeResponse struct { + Output interface{} `json:"output"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Usage *LangChainUsage `json:"usage,omitempty"` + Error string `json:"error,omitempty"` +} + +// LangChainBatchResponse represents a batch response +type LangChainBatchResponse struct { + Results []LangChainInvokeResponse `json:"results"` +} + +// LangChainUsage represents usage information +type LangChainUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// Conversion functions + +// DeriveLangChainFromBifrostResponse converts a Bifrost response to LangChain chat format +func DeriveLangChainFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *LangChainChatResponse { + if bifrostResp == nil { + return nil + } + + response := &LangChainChatResponse{ + Type: "ai", + Metadata: make(map[string]interface{}), + } + + // Convert usage information + if bifrostResp.Usage != (schemas.LLMUsage{}) { + response.Usage = &LangChainUsage{ + InputTokens: bifrostResp.Usage.PromptTokens, + OutputTokens: bifrostResp.Usage.CompletionTokens, + TotalTokens: bifrostResp.Usage.TotalTokens, + } + } + + // Get first choice + if len(bifrostResp.Choices) > 0 { + choice := bifrostResp.Choices[0] + + if choice.Message.Content != nil { + response.Content = *choice.Message.Content + } + + // Convert tool calls + if 
choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.ToolCalls != nil { + toolCalls := []LangChainToolCall{} + for _, toolCall := range *choice.Message.AssistantMessage.ToolCalls { + tc := LangChainToolCall{ + Name: *toolCall.Function.Name, + Args: stringToArgsMap(toolCall.Function.Arguments), + ID: toolCall.ID, + } + toolCalls = append(toolCalls, tc) + } + response.ToolCalls = &toolCalls + } + + // Add metadata + response.Metadata["model"] = bifrostResp.Model + if choice.FinishReason != nil { + response.Metadata["finish_reason"] = *choice.FinishReason + } + } + + return response +} + +// DeriveLangChainInvokeFromBifrostResponse converts a Bifrost response to LangChain invoke format +func DeriveLangChainInvokeFromBifrostResponse(bifrostResp *schemas.BifrostResponse, requestType string) *LangChainInvokeResponse { + if bifrostResp == nil { + return nil + } + + response := &LangChainInvokeResponse{ + Metadata: make(map[string]interface{}), + } + + // Convert usage information + if bifrostResp.Usage != (schemas.LLMUsage{}) { + response.Usage = &LangChainUsage{ + InputTokens: bifrostResp.Usage.PromptTokens, + OutputTokens: bifrostResp.Usage.CompletionTokens, + TotalTokens: bifrostResp.Usage.TotalTokens, + } + } + + // Set output based on request type + if requestType == "chat" { + if len(bifrostResp.Choices) > 0 { + choice := bifrostResp.Choices[0] + chatMsg := map[string]interface{}{ + "type": "ai", + } + if choice.Message.Content != nil { + chatMsg["content"] = *choice.Message.Content + } + response.Output = chatMsg + } + } else { + // Text completion + if len(bifrostResp.Choices) > 0 && bifrostResp.Choices[0].Message.Content != nil { + response.Output = *bifrostResp.Choices[0].Message.Content + } + } + + // Add metadata + response.Metadata["model"] = bifrostResp.Model + if len(bifrostResp.Choices) > 0 && bifrostResp.Choices[0].FinishReason != nil { + response.Metadata["finish_reason"] = *bifrostResp.Choices[0].FinishReason + } + + return response +} + +// Helper functions + +func stringPtr(s string) *string { + return &s +} + +func mapToJSONString(m map[string]interface{}) string { + jsonBytes, err := json.Marshal(m) + if err != nil { + return "{}" + } + return string(jsonBytes) +} + +func stringToArgsMap(s string) map[string]interface{} { + var result map[string]interface{} + if err := json.Unmarshal([]byte(s), &result); err != nil { + return make(map[string]interface{}) + } + return result +} diff --git a/transports/bifrost-http/integrations/langgraph/router.go b/transports/bifrost-http/integrations/langgraph/router.go new file mode 100644 index 0000000000..9dfeedbb76 --- /dev/null +++ b/transports/bifrost-http/integrations/langgraph/router.go @@ -0,0 +1,329 @@ +package langgraph + +import ( + "encoding/json" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// LangGraphRouter holds route registrations for LangGraph endpoints. +type LangGraphRouter struct { + client *bifrost.Bifrost +} + +// NewLangGraphRouter creates a new LangGraphRouter with the given bifrost client. +func NewLangGraphRouter(client *bifrost.Bifrost) *LangGraphRouter { + return &LangGraphRouter{client: client} +} + +// RegisterRoutes registers all LangGraph routes on the given router. 
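+// The graph create/execute routes are currently stubs (see the handler
+// comments); invoke, stream, and batch execute the supplied graph definition
+// node by node. An illustrative /langgraph/invoke request body, based on the
+// types in types.go:
+//
+//	{"graph": {"name": "demo",
+//	           "nodes": [{"id": "n1", "type": "chat", "model": "gpt-4o"}],
+//	           "edges": [], "start_node": "n1", "end_nodes": ["n1"]},
+//	 "input": "Hello"}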
+func (lg *LangGraphRouter) RegisterRoutes(r *router.Router) { + r.POST("/langgraph/invoke", lg.handleInvoke) + r.POST("/langgraph/stream", lg.handleStream) + r.POST("/langgraph/batch", lg.handleBatch) + r.POST("/langgraph/astream", lg.handleAsyncStream) + r.POST("/langgraph/graph/create", lg.handleCreateGraph) + r.POST("/langgraph/graph/execute", lg.handleExecuteGraph) +} + +// handleInvoke handles POST /langgraph/invoke - execute a graph +func (lg *LangGraphRouter) handleInvoke(ctx *fasthttp.RequestCtx) { + var req LangGraphInvokeRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetContentType("application/json") + errResponse := map[string]string{"error": err.Error()} + jsonBytes, _ := json.Marshal(errResponse) + ctx.SetBody(jsonBytes) + return + } + + response := lg.executeGraph(ctx, &req) + + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + jsonBytes, _ := json.Marshal(response) + ctx.SetBody(jsonBytes) +} + +// handleStream handles POST /langgraph/stream - streaming graph execution +func (lg *LangGraphRouter) handleStream(ctx *fasthttp.RequestCtx) { + var req LangGraphStreamRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetContentType("application/json") + errResponse := map[string]string{"error": err.Error()} + jsonBytes, _ := json.Marshal(errResponse) + ctx.SetBody(jsonBytes) + return + } + + // For now, we'll execute synchronously and return the final result + // In production, you'd implement streaming using Server-Sent Events + invokeReq := LangGraphInvokeRequest{ + Graph: req.Graph, + Input: req.Input, + Config: req.Config, + } + + response := lg.executeGraph(ctx, &invokeReq) + + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + json.NewEncoder(ctx).Encode(response) +} + +// handleBatch handles POST /langgraph/batch - batch graph execution +func (lg *LangGraphRouter) handleBatch(ctx *fasthttp.RequestCtx) { + var req LangGraphBatchRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + json.NewEncoder(ctx).Encode(map[string]string{"error": err.Error()}) + return + } + + var responses []LangGraphInvokeResponse + for _, input := range req.Inputs { + response := lg.executeGraph(ctx, &input) + responses = append(responses, *response) + } + + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + json.NewEncoder(ctx).Encode(LangGraphBatchResponse{ + Results: responses, + }) +} + +// handleAsyncStream handles POST /langgraph/astream - async streaming +func (lg *LangGraphRouter) handleAsyncStream(ctx *fasthttp.RequestCtx) { + // For simplicity, redirect to regular stream + lg.handleStream(ctx) +} + +// handleCreateGraph handles POST /langgraph/graph/create - create a graph definition +func (lg *LangGraphRouter) handleCreateGraph(ctx *fasthttp.RequestCtx) { + var req LangGraphCreateRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + json.NewEncoder(ctx).Encode(map[string]string{"error": err.Error()}) + return + } + + // In production, you'd store the graph definition in a database + response := LangGraphCreateResponse{ + GraphID: generateGraphID(), + Status: "created", + Message: "Graph created successfully", + Graph: req.Graph, + } + + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + 
json.NewEncoder(ctx).Encode(response) +} + +// handleExecuteGraph handles POST /langgraph/graph/execute - execute a stored graph +func (lg *LangGraphRouter) handleExecuteGraph(ctx *fasthttp.RequestCtx) { + var req LangGraphExecuteRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + json.NewEncoder(ctx).Encode(map[string]string{"error": err.Error()}) + return + } + + // In production, you'd load the graph from storage using req.GraphID + // For now, we'll create a simple mock execution + response := &LangGraphInvokeResponse{ + Output: map[string]interface{}{ + "result": "Graph execution completed", + "graph_id": req.GraphID, + }, + Metadata: map[string]interface{}{ + "execution_id": generateExecutionID(), + "status": "completed", + }, + } + + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + json.NewEncoder(ctx).Encode(response) +} + +// executeGraph executes a graph by processing nodes sequentially +func (lg *LangGraphRouter) executeGraph(ctx *fasthttp.RequestCtx, req *LangGraphInvokeRequest) *LangGraphInvokeResponse { + if req.Graph == nil { + return &LangGraphInvokeResponse{ + Error: "Graph definition is required", + } + } + + bifrostCtx := lib.ConvertToBifrostContext(ctx) + + // Start with the initial input + currentState := req.Input + var finalOutput interface{} + var totalUsage *LangGraphUsage + + // Execute nodes in sequence (simplified graph execution) + for _, node := range req.Graph.Nodes { + // Convert node to bifrost request + bifrostReq := lg.nodeToRequest(node, currentState, req.Config) + + var result *schemas.BifrostResponse + var err *schemas.BifrostError + + // Execute based on node type + if node.Type == "chat" { + result, err = lg.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) + } else { + result, err = lg.client.TextCompletionRequest(*bifrostCtx, bifrostReq) + } + + if err != nil { + return &LangGraphInvokeResponse{ + Error: err.Error.Message, + } + } + + // Update state with result + if len(result.Choices) > 0 && result.Choices[0].Message.Content != nil { + currentState = map[string]interface{}{ + "content": *result.Choices[0].Message.Content, + "previous_state": currentState, + } + finalOutput = currentState + } + + // Accumulate usage + if result.Usage != (schemas.LLMUsage{}) { + if totalUsage == nil { + totalUsage = &LangGraphUsage{} + } + totalUsage.InputTokens += result.Usage.PromptTokens + totalUsage.OutputTokens += result.Usage.CompletionTokens + totalUsage.TotalTokens += result.Usage.TotalTokens + } + } + + return &LangGraphInvokeResponse{ + Output: finalOutput, + Metadata: map[string]interface{}{ + "execution_id": generateExecutionID(), + "nodes_executed": len(req.Graph.Nodes), + }, + Usage: totalUsage, + } +} + +// nodeToRequest converts a graph node to a Bifrost request +func (lg *LangGraphRouter) nodeToRequest(node LangGraphNode, state interface{}, config map[string]interface{}) *schemas.BifrostRequest { + provider := schemas.OpenAI // Default + if node.Provider != nil { + provider = schemas.ModelProvider(*node.Provider) + } + + bifrostReq := &schemas.BifrostRequest{ + Provider: provider, + Model: node.Model, + } + + // Convert input based on node type + if node.Type == "chat" { + // Create a user message from the current state + var content string + if stateMap, ok := state.(map[string]interface{}); ok { + if contentStr, ok := stateMap["content"].(string); ok { + content = contentStr + } else { + content = "Continue the conversation" + } + } else if 
stateStr, ok := state.(string); ok { + content = stateStr + } else { + content = "Start conversation" + } + + messages := []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: &content, + }, + } + + bifrostReq.Input = schemas.RequestInput{ + ChatCompletionInput: &messages, + } + } else { + // Text completion + var prompt string + if stateMap, ok := state.(map[string]interface{}); ok { + if contentStr, ok := stateMap["content"].(string); ok { + prompt = contentStr + } else { + prompt = "Complete the following:" + } + } else if stateStr, ok := state.(string); ok { + prompt = stateStr + } else { + prompt = "Generate text" + } + + bifrostReq.Input = schemas.RequestInput{ + TextCompletionInput: &prompt, + } + } + + // Apply node-specific parameters + if len(node.Parameters) > 0 || config != nil { + params := &schemas.ModelParameters{} + + // Apply node parameters + if temp, ok := node.Parameters["temperature"].(float64); ok { + params.Temperature = &temp + } + if maxTokens, ok := node.Parameters["max_tokens"].(float64); ok { + maxTokensInt := int(maxTokens) + params.MaxTokens = &maxTokensInt + } + + // Apply config parameters + if config != nil { + if temp, ok := config["temperature"].(float64); ok { + params.Temperature = &temp + } + if maxTokens, ok := config["max_tokens"].(float64); ok { + maxTokensInt := int(maxTokens) + params.MaxTokens = &maxTokensInt + } + } + + bifrostReq.Params = params + } + + return bifrostReq +} + +// Helper functions +func generateGraphID() string { + // In production, use UUID or similar + return "graph_" + randomString(8) +} + +func generateExecutionID() string { + // In production, use UUID or similar + return "exec_" + randomString(8) +} + +func randomString(length int) string { + const charset = "abcdefghijklmnopqrstuvwxyz0123456789" + b := make([]byte, length) + for i := range b { + b[i] = charset[i%len(charset)] + } + return string(b) +} diff --git a/transports/bifrost-http/integrations/langgraph/types.go b/transports/bifrost-http/integrations/langgraph/types.go new file mode 100644 index 0000000000..407fea5781 --- /dev/null +++ b/transports/bifrost-http/integrations/langgraph/types.go @@ -0,0 +1,405 @@ +package langgraph + +import ( + "encoding/json" + + "github.com/maximhq/bifrost/core/schemas" +) + +// LangGraph core types + +// LangGraphNode represents a node in a LangGraph workflow +type LangGraphNode struct { + ID string `json:"id"` + Type string `json:"type"` // "chat", "completion", "tool", "conditional" + Model string `json:"model"` + Provider *string `json:"provider,omitempty"` + Parameters map[string]interface{} `json:"parameters,omitempty"` + Prompt *string `json:"prompt,omitempty"` + Tools *[]LangGraphTool `json:"tools,omitempty"` + Condition *string `json:"condition,omitempty"` // For conditional nodes +} + +// LangGraphEdge represents an edge connecting nodes in the graph +type LangGraphEdge struct { + From string `json:"from"` + To string `json:"to"` + Condition *string `json:"condition,omitempty"` // Optional condition for conditional edges + Transform *string `json:"transform,omitempty"` // Optional data transformation +} + +// LangGraphTool represents a tool available to graph nodes +type LangGraphTool struct { + Name string `json:"name"` + Description string `json:"description"` + ArgsSchema map[string]interface{} `json:"args_schema"` + Function *string `json:"function,omitempty"` // Function reference +} + +// LangGraphDefinition represents a complete graph workflow +type LangGraphDefinition struct { + Name 
string `json:"name"` + Description *string `json:"description,omitempty"` + Nodes []LangGraphNode `json:"nodes"` + Edges []LangGraphEdge `json:"edges"` + StartNode string `json:"start_node"` + EndNodes []string `json:"end_nodes"` + Variables map[string]interface{} `json:"variables,omitempty"` +} + +// Request types + +// LangGraphInvokeRequest represents a request to execute a graph +type LangGraphInvokeRequest struct { + Graph *LangGraphDefinition `json:"graph"` + Input interface{} `json:"input"` + Config map[string]interface{} `json:"config,omitempty"` + ThreadID *string `json:"thread_id,omitempty"` // For conversation threading +} + +// LangGraphStreamRequest represents a streaming graph execution request +type LangGraphStreamRequest struct { + Graph *LangGraphDefinition `json:"graph"` + Input interface{} `json:"input"` + Config map[string]interface{} `json:"config,omitempty"` + StreamMode *string `json:"stream_mode,omitempty"` // "values", "updates", "debug" +} + +// LangGraphBatchRequest represents a batch graph execution request +type LangGraphBatchRequest struct { + Inputs []LangGraphInvokeRequest `json:"inputs"` + Config map[string]interface{} `json:"config,omitempty"` +} + +// LangGraphCreateRequest represents a request to create/store a graph +type LangGraphCreateRequest struct { + Graph LangGraphDefinition `json:"graph"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// LangGraphExecuteRequest represents a request to execute a stored graph +type LangGraphExecuteRequest struct { + GraphID string `json:"graph_id"` + Input interface{} `json:"input"` + Config map[string]interface{} `json:"config,omitempty"` +} + +// Response types + +// LangGraphInvokeResponse represents the response from graph execution +type LangGraphInvokeResponse struct { + Output interface{} `json:"output"` + State map[string]interface{} `json:"state,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Usage *LangGraphUsage `json:"usage,omitempty"` + Error string `json:"error,omitempty"` + ThreadID *string `json:"thread_id,omitempty"` +} + +// LangGraphBatchResponse represents a batch execution response +type LangGraphBatchResponse struct { + Results []LangGraphInvokeResponse `json:"results"` +} + +// LangGraphCreateResponse represents the response from creating a graph +type LangGraphCreateResponse struct { + GraphID string `json:"graph_id"` + Status string `json:"status"` + Message string `json:"message"` + Graph LangGraphDefinition `json:"graph"` +} + +// LangGraphStreamEvent represents a single event in a stream +type LangGraphStreamEvent struct { + Event string `json:"event"` // "on_node_start", "on_node_end", "on_edge", etc. 
+ NodeID *string `json:"node_id,omitempty"` + Data interface{} `json:"data"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// LangGraphUsage represents usage information for graph execution +type LangGraphUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` + NodesExecuted int `json:"nodes_executed"` + ToolCallsMade int `json:"tool_calls_made,omitempty"` +} + +// State management types + +// LangGraphState represents the current state of a graph execution +type LangGraphState struct { + CurrentNode string `json:"current_node"` + Variables map[string]interface{} `json:"variables"` + History []LangGraphStep `json:"history"` + ThreadID *string `json:"thread_id,omitempty"` +} + +// LangGraphStep represents a single execution step +type LangGraphStep struct { + NodeID string `json:"node_id"` + Input interface{} `json:"input"` + Output interface{} `json:"output"` + Timestamp string `json:"timestamp"` + Duration *float64 `json:"duration,omitempty"` // In milliseconds + Usage *LangGraphUsage `json:"usage,omitempty"` +} + +// ConvertToBifrostRequest converts a graph invoke request to a simplified Bifrost request +// This is used when the graph has only a single node or for simplified execution +func (r *LangGraphInvokeRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { + if r.Graph == nil || len(r.Graph.Nodes) == 0 { + return nil + } + + // Use the first node for simplified conversion + node := r.Graph.Nodes[0] + + provider := schemas.OpenAI // Default + if node.Provider != nil { + provider = schemas.ModelProvider(*node.Provider) + } + + bifrostReq := &schemas.BifrostRequest{ + Provider: provider, + Model: node.Model, + } + + // Convert input based on node type and input + if node.Type == "chat" { + // Handle chat input + var messages []schemas.BifrostMessage + + if inputStr, ok := r.Input.(string); ok { + // Simple string input + messages = []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: &inputStr, + }, + } + } else if inputMap, ok := r.Input.(map[string]interface{}); ok { + // Structured input + if messagesArray, ok := inputMap["messages"].([]interface{}); ok { + for _, msgInterface := range messagesArray { + if msgMap, ok := msgInterface.(map[string]interface{}); ok { + msg := schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleUser, + } + if content, ok := msgMap["content"].(string); ok { + msg.Content = &content + } + if role, ok := msgMap["role"].(string); ok { + msg.Role = schemas.ModelChatMessageRole(role) + } + messages = append(messages, msg) + } + } + } else if content, ok := inputMap["content"].(string); ok { + messages = []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: &content, + }, + } + } + } + + if len(messages) == 0 { + // Fallback + defaultContent := "Start conversation" + messages = []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: &defaultContent, + }, + } + } + + bifrostReq.Input = schemas.RequestInput{ + ChatCompletionInput: &messages, + } + } else { + // Text completion + var prompt string + if inputStr, ok := r.Input.(string); ok { + prompt = inputStr + } else if inputMap, ok := r.Input.(map[string]interface{}); ok { + if promptStr, ok := inputMap["prompt"].(string); ok { + prompt = promptStr + } else if contentStr, ok := inputMap["content"].(string); ok { + prompt = contentStr + } else { + prompt = "Generate text" + } + } else { + prompt = "Generate 
text" + } + + bifrostReq.Input = schemas.RequestInput{ + TextCompletionInput: &prompt, + } + } + + // Convert parameters + if len(node.Parameters) > 0 || r.Config != nil { + params := &schemas.ModelParameters{} + + // Apply node parameters + if temp, ok := node.Parameters["temperature"].(float64); ok { + params.Temperature = &temp + } + if maxTokens, ok := node.Parameters["max_tokens"].(float64); ok { + maxTokensInt := int(maxTokens) + params.MaxTokens = &maxTokensInt + } + if topP, ok := node.Parameters["top_p"].(float64); ok { + params.TopP = &topP + } + + // Apply config parameters (override node parameters) + if r.Config != nil { + if temp, ok := r.Config["temperature"].(float64); ok { + params.Temperature = &temp + } + if maxTokens, ok := r.Config["max_tokens"].(float64); ok { + maxTokensInt := int(maxTokens) + params.MaxTokens = &maxTokensInt + } + if topP, ok := r.Config["top_p"].(float64); ok { + params.TopP = &topP + } + } + + bifrostReq.Params = params + } + + // Convert tools if available + if node.Tools != nil { + tools := []schemas.Tool{} + for _, tool := range *node.Tools { + // Convert args_schema to FunctionParameters + params := schemas.FunctionParameters{ + Type: "object", + } + if tool.ArgsSchema != nil { + if typeVal, ok := tool.ArgsSchema["type"].(string); ok { + params.Type = typeVal + } + if desc, ok := tool.ArgsSchema["description"].(string); ok { + params.Description = &desc + } + if required, ok := tool.ArgsSchema["required"].([]interface{}); ok { + reqStrings := make([]string, len(required)) + for i, req := range required { + if reqStr, ok := req.(string); ok { + reqStrings[i] = reqStr + } + } + params.Required = reqStrings + } + if properties, ok := tool.ArgsSchema["properties"].(map[string]interface{}); ok { + params.Properties = properties + } + if enum, ok := tool.ArgsSchema["enum"].([]interface{}); ok { + enumStrings := make([]string, len(enum)) + for i, e := range enum { + if eStr, ok := e.(string); ok { + enumStrings[i] = eStr + } + } + params.Enum = &enumStrings + } + } + + t := schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: tool.Name, + Description: tool.Description, + Parameters: params, + }, + } + tools = append(tools, t) + } + if bifrostReq.Params == nil { + bifrostReq.Params = &schemas.ModelParameters{} + } + bifrostReq.Params.Tools = &tools + } + + return bifrostReq +} + +// Helper functions for graph execution + +// GetStartNode returns the starting node of the graph +func (g *LangGraphDefinition) GetStartNode() *LangGraphNode { + for _, node := range g.Nodes { + if node.ID == g.StartNode { + return &node + } + } + return nil +} + +// GetNode returns a node by ID +func (g *LangGraphDefinition) GetNode(id string) *LangGraphNode { + for _, node := range g.Nodes { + if node.ID == id { + return &node + } + } + return nil +} + +// GetNextNodes returns the next nodes connected to the given node +func (g *LangGraphDefinition) GetNextNodes(nodeID string) []LangGraphNode { + var nextNodes []LangGraphNode + for _, edge := range g.Edges { + if edge.From == nodeID { + if nextNode := g.GetNode(edge.To); nextNode != nil { + nextNodes = append(nextNodes, *nextNode) + } + } + } + return nextNodes +} + +// IsEndNode checks if a node is an end node +func (g *LangGraphDefinition) IsEndNode(nodeID string) bool { + for _, endNodeID := range g.EndNodes { + if endNodeID == nodeID { + return true + } + } + return false +} + +// Validate checks if the graph definition is valid +func (g *LangGraphDefinition) Validate() error { + // Check if start 
node exists
+	if g.GetStartNode() == nil {
+		return newGraphValidationError("Start node not found")
+	}
+
+	// Check if all referenced nodes in edges exist
+	for _, edge := range g.Edges {
+		if g.GetNode(edge.From) == nil {
+			return newGraphValidationError("Edge references non-existent 'from' node: " + edge.From)
+		}
+		if g.GetNode(edge.To) == nil {
+			return newGraphValidationError("Edge references non-existent 'to' node: " + edge.To)
+		}
+	}
+
+	// Check if end nodes exist
+	for _, endNodeID := range g.EndNodes {
+		if g.GetNode(endNodeID) == nil {
+			return newGraphValidationError("End node not found: " + endNodeID)
+		}
+	}
+
+	return nil
+}
+
+// graphValidationError reports a graph validation failure. Its Error string
+// carries the same {"error": "..."} JSON shape that the HTTP handlers in this
+// package return to clients, while still satisfying Go's error interface.
+type graphValidationError struct {
+	Message string `json:"error"`
+}
+
+func (e *graphValidationError) Error() string {
+	payload, _ := json.Marshal(e)
+	return string(payload)
+}
+
+func newGraphValidationError(message string) error {
+	return &graphValidationError{Message: message}
+}
diff --git a/transports/bifrost-http/integrations/litellm/router.go b/transports/bifrost-http/integrations/litellm/router.go
new file mode 100644
index 0000000000..26f5b01990
--- /dev/null
+++ b/transports/bifrost-http/integrations/litellm/router.go
@@ -0,0 +1,104 @@
+package litellm
+
+import (
+	"encoding/json"
+
+	"github.com/fasthttp/router"
+	bifrost "github.com/maximhq/bifrost/core"
+	"github.com/maximhq/bifrost/transports/bifrost-http/lib"
+	"github.com/valyala/fasthttp"
+)
+
+// LiteLLMRouter holds route registrations for LiteLLM endpoints.
+type LiteLLMRouter struct {
+	client *bifrost.Bifrost
+}
+
+// NewLiteLLMRouter creates a new LiteLLMRouter with the given bifrost client.
+func NewLiteLLMRouter(client *bifrost.Bifrost) *LiteLLMRouter {
+	return &LiteLLMRouter{client: client}
+}
+
+// RegisterRoutes registers all LiteLLM routes on the given router.
+func (l *LiteLLMRouter) RegisterRoutes(r *router.Router) {
+	r.POST("/litellm/chat/completions", l.handleChatCompletion)
+	r.POST("/litellm/v1/chat/completions", l.handleChatCompletion)
+	r.POST("/litellm/completions", l.handleCompletion)
+	r.POST("/litellm/v1/completions", l.handleCompletion)
+}
+
+// handleChatCompletion handles POST /chat/completions and /v1/chat/completions
+func (l *LiteLLMRouter) handleChatCompletion(ctx *fasthttp.RequestCtx) {
+	var req LiteLLMChatRequest
+	if err := json.Unmarshal(ctx.PostBody(), &req); err != nil {
+		ctx.SetStatusCode(fasthttp.StatusBadRequest)
+		ctx.SetContentType("application/json")
+		errResponse := map[string]string{"error": err.Error()}
+		jsonBytes, _ := json.Marshal(errResponse)
+		ctx.SetBody(jsonBytes)
+		return
+	}
+
+	if req.Model == "" {
+		ctx.SetStatusCode(fasthttp.StatusBadRequest)
+		ctx.SetBodyString("Model parameter is required")
+		return
+	}
+
+	bifrostReq := req.ConvertToBifrostRequest()
+
+	bifrostCtx := lib.ConvertToBifrostContext(ctx)
+
+	result, err := l.client.ChatCompletionRequest(*bifrostCtx, bifrostReq)
+	if err != nil {
+		ctx.SetStatusCode(fasthttp.StatusInternalServerError)
+		ctx.SetContentType("application/json")
+		jsonBytes, _ := json.Marshal(err)
+		ctx.SetBody(jsonBytes)
+		return
+	}
+
+	litellmResponse := DeriveLiteLLMFromBifrostResponse(result)
+	ctx.SetStatusCode(fasthttp.StatusOK)
+	ctx.SetContentType("application/json")
+	jsonBytes, _ := json.Marshal(litellmResponse)
+	ctx.SetBody(jsonBytes)
+}
+
+// handleCompletion handles POST /completions and /v1/completions
+func (l *LiteLLMRouter) handleCompletion(ctx *fasthttp.RequestCtx) {
+	var req LiteLLMCompletionRequest
+	if err := json.Unmarshal(ctx.PostBody(), &req); err != nil {
+		ctx.SetStatusCode(fasthttp.StatusBadRequest)
+		ctx.SetContentType("application/json")
+		errResponse := map[string]string{"error": err.Error()}
+		jsonBytes, _ :=
json.Marshal(errResponse) + ctx.SetBody(jsonBytes) + return + } + + if req.Model == "" { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Model parameter is required") + return + } + + bifrostReq := req.ConvertToBifrostRequest() + + bifrostCtx := lib.ConvertToBifrostContext(ctx) + + result, err := l.client.TextCompletionRequest(*bifrostCtx, bifrostReq) + if err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetContentType("application/json") + jsonBytes, _ := json.Marshal(err) + ctx.SetBody(jsonBytes) + return + } + + litellmResponse := DeriveLiteLLMCompletionFromBifrostResponse(result) + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + jsonBytes, _ := json.Marshal(litellmResponse) + ctx.SetBody(jsonBytes) +} diff --git a/transports/bifrost-http/integrations/litellm/types.go b/transports/bifrost-http/integrations/litellm/types.go new file mode 100644 index 0000000000..7c76569072 --- /dev/null +++ b/transports/bifrost-http/integrations/litellm/types.go @@ -0,0 +1,494 @@ +package litellm + +import ( + "strings" + + "github.com/maximhq/bifrost/core/schemas" +) + +// Helper function to create string pointer +func stringPtr(s string) *string { + return &s +} + +// LiteLLM provides OpenAI-compatible API, so we'll use similar structures +// with support for multiple provider model routing + +// LiteLLMMessage represents a message in LiteLLM chat format (OpenAI-compatible) +type LiteLLMMessage struct { + Role string `json:"role"` + Content *string `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls *[]LiteLLMToolCall `json:"tool_calls,omitempty"` + ToolCallID *string `json:"tool_call_id,omitempty"` +} + +// LiteLLMToolCall represents a tool call in LiteLLM format +type LiteLLMToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function LiteLLMFunctionCall `json:"function"` +} + +// LiteLLMFunctionCall represents a function call in LiteLLM format +type LiteLLMFunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// LiteLLMTool represents a tool in LiteLLM format +type LiteLLMTool struct { + Type string `json:"type"` + Function LiteLLMFunction `json:"function"` +} + +// LiteLLMFunction represents a function definition in LiteLLM format +type LiteLLMFunction struct { + Name string `json:"name"` + Description *string `json:"description,omitempty"` + Parameters interface{} `json:"parameters,omitempty"` +} + +// LiteLLMChatRequest represents a LiteLLM chat completion request +type LiteLLMChatRequest struct { + Model string `json:"model"` + Messages []LiteLLMMessage `json:"messages"` + MaxTokens *int `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + N *int `json:"n,omitempty"` + Stop interface{} `json:"stop,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias map[string]float64 `json:"logit_bias,omitempty"` + User *string `json:"user,omitempty"` + Tools *[]LiteLLMTool `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + Stream *bool `json:"stream,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + // LiteLLM-specific parameters + APIBase *string `json:"api_base,omitempty"` + APIVersion *string `json:"api_version,omitempty"` + APIKey *string `json:"api_key,omitempty"` + Drop_params 
*bool `json:"drop_params,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// ConvertToBifrostRequest converts a LiteLLM chat request to Bifrost format +func (r *LiteLLMChatRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { + // LiteLLM can route to any provider, but we'll determine the appropriate provider from the model + provider := determineProviderFromModel(r.Model) + + bifrostReq := &schemas.BifrostRequest{ + Provider: provider, + Model: r.Model, + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{}, + }, + } + + // Convert messages + for _, msg := range r.Messages { + var bifrostMsg schemas.BifrostMessage + bifrostMsg.Role = schemas.ModelChatMessageRole(msg.Role) + bifrostMsg.Content = msg.Content + + // Handle tool calls for assistant messages + if msg.ToolCalls != nil { + toolCalls := []schemas.ToolCall{} + for _, toolCall := range *msg.ToolCalls { + tc := schemas.ToolCall{ + Type: stringPtr(toolCall.Type), + ID: &toolCall.ID, + Function: schemas.FunctionCall{ + Name: &toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + }, + } + toolCalls = append(toolCalls, tc) + } + bifrostMsg.AssistantMessage = &schemas.AssistantMessage{ + ToolCalls: &toolCalls, + } + } + + // Handle tool messages + if msg.ToolCallID != nil { + bifrostMsg.ToolMessage = &schemas.ToolMessage{ + ToolCallID: msg.ToolCallID, + } + } + + *bifrostReq.Input.ChatCompletionInput = append(*bifrostReq.Input.ChatCompletionInput, bifrostMsg) + } + + // Convert parameters + if r.MaxTokens != nil || r.Temperature != nil || r.TopP != nil || r.PresencePenalty != nil || + r.FrequencyPenalty != nil || r.N != nil { + params := &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + + if r.MaxTokens != nil { + params.MaxTokens = r.MaxTokens + } + if r.Temperature != nil { + params.Temperature = r.Temperature + } + if r.TopP != nil { + params.TopP = r.TopP + } + if r.PresencePenalty != nil { + params.PresencePenalty = r.PresencePenalty + } + if r.FrequencyPenalty != nil { + params.FrequencyPenalty = r.FrequencyPenalty + } + if r.N != nil { + params.ExtraParams["n"] = r.N + } + + // Add LiteLLM-specific params + if r.APIBase != nil { + params.ExtraParams["api_base"] = r.APIBase + } + if r.APIVersion != nil { + params.ExtraParams["api_version"] = r.APIVersion + } + if r.Drop_params != nil { + params.ExtraParams["drop_params"] = r.Drop_params + } + if r.Metadata != nil { + params.ExtraParams["metadata"] = r.Metadata + } + + bifrostReq.Params = params + } + + // Convert tools + if r.Tools != nil { + tools := []schemas.Tool{} + for _, tool := range *r.Tools { + description := "" + if tool.Function.Description != nil { + description = *tool.Function.Description + } + + // Convert parameters interface{} to FunctionParameters + params := schemas.FunctionParameters{ + Type: "object", + } + if tool.Function.Parameters != nil { + if paramMap, ok := tool.Function.Parameters.(map[string]interface{}); ok { + if typeVal, ok := paramMap["type"].(string); ok { + params.Type = typeVal + } + if desc, ok := paramMap["description"].(string); ok { + params.Description = &desc + } + if required, ok := paramMap["required"].([]interface{}); ok { + reqStrings := make([]string, len(required)) + for i, req := range required { + if reqStr, ok := req.(string); ok { + reqStrings[i] = reqStr + } + } + params.Required = reqStrings + } + if properties, ok := paramMap["properties"].(map[string]interface{}); ok { + params.Properties = properties + } + if enum, ok := 
paramMap["enum"].([]interface{}); ok { + enumStrings := make([]string, len(enum)) + for i, e := range enum { + if eStr, ok := e.(string); ok { + enumStrings[i] = eStr + } + } + params.Enum = &enumStrings + } + } + } + + t := schemas.Tool{ + Type: tool.Type, + Function: schemas.Function{ + Name: tool.Function.Name, + Description: description, + Parameters: params, + }, + } + tools = append(tools, t) + } + if bifrostReq.Params == nil { + bifrostReq.Params = &schemas.ModelParameters{} + } + bifrostReq.Params.Tools = &tools + } + + return bifrostReq +} + +// LiteLLMCompletionRequest represents a LiteLLM text completion request +type LiteLLMCompletionRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + MaxTokens *int `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + N *int `json:"n,omitempty"` + Stream *bool `json:"stream,omitempty"` + LogProbs *int `json:"logprobs,omitempty"` + Echo *bool `json:"echo,omitempty"` + Stop interface{} `json:"stop,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + BestOf *int `json:"best_of,omitempty"` + LogitBias map[string]float64 `json:"logit_bias,omitempty"` + User *string `json:"user,omitempty"` + // LiteLLM-specific parameters + APIBase *string `json:"api_base,omitempty"` + APIVersion *string `json:"api_version,omitempty"` + APIKey *string `json:"api_key,omitempty"` + Drop_params *bool `json:"drop_params,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// ConvertToBifrostRequest converts a LiteLLM completion request to Bifrost format +func (r *LiteLLMCompletionRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { + provider := determineProviderFromModel(r.Model) + + bifrostReq := &schemas.BifrostRequest{ + Provider: provider, + Model: r.Model, + Input: schemas.RequestInput{ + TextCompletionInput: &r.Prompt, + }, + } + + // Convert parameters + if r.MaxTokens != nil || r.Temperature != nil || r.TopP != nil || r.PresencePenalty != nil || + r.FrequencyPenalty != nil || r.N != nil { + params := &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + + if r.MaxTokens != nil { + params.MaxTokens = r.MaxTokens + } + if r.Temperature != nil { + params.Temperature = r.Temperature + } + if r.TopP != nil { + params.TopP = r.TopP + } + if r.PresencePenalty != nil { + params.PresencePenalty = r.PresencePenalty + } + if r.FrequencyPenalty != nil { + params.FrequencyPenalty = r.FrequencyPenalty + } + if r.N != nil { + params.ExtraParams["n"] = r.N + } + if r.LogProbs != nil { + params.ExtraParams["logprobs"] = r.LogProbs + } + if r.Echo != nil { + params.ExtraParams["echo"] = r.Echo + } + if r.BestOf != nil { + params.ExtraParams["best_of"] = r.BestOf + } + + // Add LiteLLM-specific params + if r.APIBase != nil { + params.ExtraParams["api_base"] = r.APIBase + } + if r.APIVersion != nil { + params.ExtraParams["api_version"] = r.APIVersion + } + if r.Drop_params != nil { + params.ExtraParams["drop_params"] = r.Drop_params + } + if r.Metadata != nil { + params.ExtraParams["metadata"] = r.Metadata + } + + bifrostReq.Params = params + } + + return bifrostReq +} + +// Helper function to determine provider from model name +func determineProviderFromModel(model string) schemas.ModelProvider { + // LiteLLM uses prefixes or model names to determine provider + // This is a simplified version - in production you'd have more 
sophisticated routing + if contains(model, "gpt") || contains(model, "o1") { + return schemas.OpenAI + } else if contains(model, "claude") { + return schemas.Anthropic + } else if contains(model, "gemini") || contains(model, "vertex") { + return schemas.Vertex + } else if contains(model, "bedrock") { + return schemas.Bedrock + } else if contains(model, "cohere") { + return schemas.Cohere + } + // Default to OpenAI for unknown models + return schemas.OpenAI +} + +// Helper function to check if string contains substring +func contains(s, substr string) bool { + return strings.Contains(s, substr) +} + +// Response structures + +// LiteLLMChatResponse represents a LiteLLM chat completion response +type LiteLLMChatResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + Choices []LiteLLMChoice `json:"choices"` + Usage *LiteLLMUsage `json:"usage,omitempty"` +} + +// LiteLLMChoice represents a choice in the LiteLLM response +type LiteLLMChoice struct { + Index int `json:"index"` + Message LiteLLMMessage `json:"message"` + FinishReason *string `json:"finish_reason,omitempty"` +} + +// LiteLLMCompletionResponse represents a LiteLLM text completion response +type LiteLLMCompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + Choices []LiteLLMCompletionChoice `json:"choices"` + Usage *LiteLLMUsage `json:"usage,omitempty"` +} + +// LiteLLMCompletionChoice represents a choice in the LiteLLM completion response +type LiteLLMCompletionChoice struct { + Text string `json:"text"` + Index int `json:"index"` + FinishReason *string `json:"finish_reason,omitempty"` +} + +// LiteLLMUsage represents usage information in LiteLLM format +type LiteLLMUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// DeriveLiteLLMFromBifrostResponse converts a Bifrost response to LiteLLM chat format +func DeriveLiteLLMFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *LiteLLMChatResponse { + if bifrostResp == nil { + return nil + } + + litellmResp := &LiteLLMChatResponse{ + ID: bifrostResp.ID, + Object: "chat.completion", + Created: bifrostResp.Created, + Model: bifrostResp.Model, + Choices: make([]LiteLLMChoice, len(bifrostResp.Choices)), + } + + // Convert usage information + if bifrostResp.Usage != (schemas.LLMUsage{}) { + litellmResp.Usage = &LiteLLMUsage{ + PromptTokens: bifrostResp.Usage.PromptTokens, + CompletionTokens: bifrostResp.Usage.CompletionTokens, + TotalTokens: bifrostResp.Usage.TotalTokens, + } + } + + // Convert choices + for i, choice := range bifrostResp.Choices { + litellmChoice := LiteLLMChoice{ + Index: choice.Index, + FinishReason: choice.FinishReason, + } + + // Convert message + msg := LiteLLMMessage{ + Role: string(choice.Message.Role), + Content: choice.Message.Content, + } + + // Convert tool calls + if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.ToolCalls != nil { + toolCalls := []LiteLLMToolCall{} + for _, toolCall := range *choice.Message.AssistantMessage.ToolCalls { + tc := LiteLLMToolCall{ + Type: *toolCall.Type, + Function: LiteLLMFunctionCall{ + Name: *toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + }, + } + if toolCall.ID != nil { + tc.ID = *toolCall.ID + } + toolCalls = append(toolCalls, tc) + } + msg.ToolCalls = &toolCalls + } + + litellmChoice.Message = msg + 
litellmResp.Choices[i] = litellmChoice + } + + return litellmResp +} + +// DeriveLiteLLMCompletionFromBifrostResponse converts a Bifrost response to LiteLLM completion format +func DeriveLiteLLMCompletionFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *LiteLLMCompletionResponse { + if bifrostResp == nil { + return nil + } + + litellmResp := &LiteLLMCompletionResponse{ + ID: bifrostResp.ID, + Object: "text_completion", + Created: bifrostResp.Created, + Model: bifrostResp.Model, + Choices: make([]LiteLLMCompletionChoice, len(bifrostResp.Choices)), + } + + // Convert usage information + if bifrostResp.Usage != (schemas.LLMUsage{}) { + litellmResp.Usage = &LiteLLMUsage{ + PromptTokens: bifrostResp.Usage.PromptTokens, + CompletionTokens: bifrostResp.Usage.CompletionTokens, + TotalTokens: bifrostResp.Usage.TotalTokens, + } + } + + // Convert choices + for i, choice := range bifrostResp.Choices { + text := "" + if choice.Message.Content != nil { + text = *choice.Message.Content + } + + litellmChoice := LiteLLMCompletionChoice{ + Text: text, + Index: choice.Index, + FinishReason: choice.FinishReason, + } + + litellmResp.Choices[i] = litellmChoice + } + + return litellmResp +} diff --git a/transports/bifrost-http/integrations/mistral/router.go b/transports/bifrost-http/integrations/mistral/router.go new file mode 100644 index 0000000000..68604e601d --- /dev/null +++ b/transports/bifrost-http/integrations/mistral/router.go @@ -0,0 +1,74 @@ +package mistral + +import ( + "encoding/json" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// MistralRouter holds route registrations for Mistral endpoints. +type MistralRouter struct { + client *bifrost.Bifrost +} + +// NewMistralRouter creates a new MistralRouter with the given bifrost client. +func NewMistralRouter(client *bifrost.Bifrost) *MistralRouter { + return &MistralRouter{client: client} +} + +// RegisterRoutes registers all Mistral routes on the given router. 
+func (m *MistralRouter) RegisterRoutes(r *router.Router) { + r.POST("/mistral/v1/chat/completions", m.handleChatCompletion) +} + +// handleChatCompletion handles POST /v1/chat/completions +func (m *MistralRouter) handleChatCompletion(ctx *fasthttp.RequestCtx) { + var req MistralChatRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetContentType("application/json") + errResponse := map[string]string{"error": err.Error()} + jsonBytes, _ := json.Marshal(errResponse) + ctx.SetBody(jsonBytes) + return + } + + if req.Model == "" { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Model parameter is required") + return + } + + bifrostReq := req.ConvertToBifrostRequest() + + bifrostCtx := lib.ConvertToBifrostContext(ctx) + + result, err := m.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) + if err != nil { + // Determine appropriate HTTP status code based on error details + statusCode := fasthttp.StatusInternalServerError // Default to 500 + + // If the error has a specific status code from the provider, use it + if err.StatusCode != nil { + statusCode = *err.StatusCode + } else if !err.IsBifrostError { + // If it's not a Bifrost internal error, treat as client error + statusCode = fasthttp.StatusBadRequest + } + + ctx.SetStatusCode(statusCode) + ctx.SetContentType("application/json") + jsonBytes, _ := json.Marshal(err) + ctx.SetBody(jsonBytes) + return + } + + mistralResponse := DeriveMistralFromBifrostResponse(result) + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + jsonBytes, _ := json.Marshal(mistralResponse) + ctx.SetBody(jsonBytes) +} diff --git a/transports/bifrost-http/integrations/mistral/types.go b/transports/bifrost-http/integrations/mistral/types.go new file mode 100644 index 0000000000..cc72f73a70 --- /dev/null +++ b/transports/bifrost-http/integrations/mistral/types.go @@ -0,0 +1,285 @@ +package mistral + +import ( + "github.com/maximhq/bifrost/core/schemas" +) + +// Helper function to create string pointer +func stringPtr(s string) *string { + return &s +} + +// Since Mistral uses OpenAI-compatible format, we'll reuse similar structures +// but map to Mistral provider + +// MistralMessage represents a message in the Mistral chat format (OpenAI-compatible) +type MistralMessage struct { + Role string `json:"role"` + Content *string `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls *[]MistralToolCall `json:"tool_calls,omitempty"` + ToolCallID *string `json:"tool_call_id,omitempty"` +} + +// MistralToolCall represents a tool call in Mistral format +type MistralToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function MistralFunctionCall `json:"function"` +} + +// MistralFunctionCall represents a function call in Mistral format +type MistralFunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// MistralTool represents a tool in Mistral format +type MistralTool struct { + Type string `json:"type"` + Function MistralFunction `json:"function"` +} + +// MistralFunction represents a function definition in Mistral format +type MistralFunction struct { + Name string `json:"name"` + Description *string `json:"description,omitempty"` + Parameters interface{} `json:"parameters,omitempty"` +} + +// MistralChatRequest represents a Mistral chat completion request +type MistralChatRequest struct { + Model string `json:"model"` + Messages []MistralMessage `json:"messages"` + 
MaxTokens *int `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + RandomSeed *int `json:"random_seed,omitempty"` + SafePrompt *bool `json:"safe_prompt,omitempty"` + Tools *[]MistralTool `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + Stream *bool `json:"stream,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` +} + +// ConvertToBifrostRequest converts a Mistral chat request to Bifrost format +func (r *MistralChatRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { + // Note: Mistral uses OpenAI-compatible API format, so we use OpenAI provider + // This is the correct approach since Mistral follows OpenAI's API specification + bifrostReq := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, // Mistral is OpenAI-compatible + Model: r.Model, + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{}, + }, + } + + // Convert messages + for _, msg := range r.Messages { + var bifrostMsg schemas.BifrostMessage + bifrostMsg.Role = schemas.ModelChatMessageRole(msg.Role) + bifrostMsg.Content = msg.Content + + // Handle tool calls for assistant messages + if msg.ToolCalls != nil { + toolCalls := []schemas.ToolCall{} + for _, toolCall := range *msg.ToolCalls { + tc := schemas.ToolCall{ + Type: stringPtr(toolCall.Type), + ID: &toolCall.ID, + Function: schemas.FunctionCall{ + Name: &toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + }, + } + toolCalls = append(toolCalls, tc) + } + bifrostMsg.AssistantMessage = &schemas.AssistantMessage{ + ToolCalls: &toolCalls, + } + } + + // Handle tool messages + if msg.ToolCallID != nil { + bifrostMsg.ToolMessage = &schemas.ToolMessage{ + ToolCallID: msg.ToolCallID, + } + } + + *bifrostReq.Input.ChatCompletionInput = append(*bifrostReq.Input.ChatCompletionInput, bifrostMsg) + } + + // Convert parameters + if r.MaxTokens != nil || r.Temperature != nil || r.TopP != nil || r.RandomSeed != nil || r.SafePrompt != nil { + params := &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + + if r.MaxTokens != nil { + params.MaxTokens = r.MaxTokens + } + if r.Temperature != nil { + params.Temperature = r.Temperature + } + if r.TopP != nil { + params.TopP = r.TopP + } + if r.RandomSeed != nil { + params.ExtraParams["random_seed"] = r.RandomSeed + } + if r.SafePrompt != nil { + params.ExtraParams["safe_prompt"] = r.SafePrompt + } + + bifrostReq.Params = params + } + + // Convert tools + if r.Tools != nil { + tools := []schemas.Tool{} + for _, tool := range *r.Tools { + description := "" + if tool.Function.Description != nil { + description = *tool.Function.Description + } + + // Convert parameters interface{} to FunctionParameters + params := schemas.FunctionParameters{ + Type: "object", + } + if tool.Function.Parameters != nil { + if paramMap, ok := tool.Function.Parameters.(map[string]interface{}); ok { + if typeVal, ok := paramMap["type"].(string); ok { + params.Type = typeVal + } + if desc, ok := paramMap["description"].(string); ok { + params.Description = &desc + } + if required, ok := paramMap["required"].([]interface{}); ok { + reqStrings := make([]string, len(required)) + for i, req := range required { + if reqStr, ok := req.(string); ok { + reqStrings[i] = reqStr + } + } + params.Required = reqStrings + } + if properties, ok := paramMap["properties"].(map[string]interface{}); ok { + params.Properties = properties + } + if enum, ok := 
paramMap["enum"].([]interface{}); ok { + enumStrings := make([]string, len(enum)) + for i, e := range enum { + if eStr, ok := e.(string); ok { + enumStrings[i] = eStr + } + } + params.Enum = &enumStrings + } + } + } + + t := schemas.Tool{ + Type: tool.Type, + Function: schemas.Function{ + Name: tool.Function.Name, + Description: description, + Parameters: params, + }, + } + tools = append(tools, t) + } + if bifrostReq.Params == nil { + bifrostReq.Params = &schemas.ModelParameters{} + } + bifrostReq.Params.Tools = &tools + } + + return bifrostReq +} + +// MistralChatResponse represents a Mistral chat completion response +type MistralChatResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + Choices []MistralChoice `json:"choices"` + Usage *MistralUsage `json:"usage,omitempty"` +} + +// MistralChoice represents a choice in the Mistral response +type MistralChoice struct { + Index int `json:"index"` + Message MistralMessage `json:"message"` + FinishReason *string `json:"finish_reason,omitempty"` +} + +// MistralUsage represents usage information in Mistral format +type MistralUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// DeriveMistralFromBifrostResponse converts a Bifrost response to Mistral format +func DeriveMistralFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *MistralChatResponse { + if bifrostResp == nil { + return nil + } + + mistralResp := &MistralChatResponse{ + ID: bifrostResp.ID, + Object: "chat.completion", + Created: bifrostResp.Created, + Model: bifrostResp.Model, + Choices: make([]MistralChoice, len(bifrostResp.Choices)), + } + + // Convert usage information + if bifrostResp.Usage != (schemas.LLMUsage{}) { + mistralResp.Usage = &MistralUsage{ + PromptTokens: bifrostResp.Usage.PromptTokens, + CompletionTokens: bifrostResp.Usage.CompletionTokens, + TotalTokens: bifrostResp.Usage.TotalTokens, + } + } + + // Convert choices + for i, choice := range bifrostResp.Choices { + mistralChoice := MistralChoice{ + Index: choice.Index, + FinishReason: choice.FinishReason, + } + + // Convert message + msg := MistralMessage{ + Role: string(choice.Message.Role), + Content: choice.Message.Content, + } + + // Convert tool calls + if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.ToolCalls != nil { + toolCalls := []MistralToolCall{} + for _, toolCall := range *choice.Message.AssistantMessage.ToolCalls { + tc := MistralToolCall{ + Type: *toolCall.Type, + Function: MistralFunctionCall{ + Name: *toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + }, + } + if toolCall.ID != nil { + tc.ID = *toolCall.ID + } + toolCalls = append(toolCalls, tc) + } + msg.ToolCalls = &toolCalls + } + + mistralChoice.Message = msg + mistralResp.Choices[i] = mistralChoice + } + + return mistralResp +} diff --git a/transports/bifrost-http/integrations/openai/router.go b/transports/bifrost-http/integrations/openai/router.go new file mode 100644 index 0000000000..16a902831e --- /dev/null +++ b/transports/bifrost-http/integrations/openai/router.go @@ -0,0 +1,60 @@ +package openai + +import ( + "encoding/json" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// OpenAIRouter holds route registrations for OpenAI endpoints. 
+type OpenAIRouter struct { + client *bifrost.Bifrost +} + +// NewOpenAIRouter creates a new OpenAIRouter with the given bifrost client. +func NewOpenAIRouter(client *bifrost.Bifrost) *OpenAIRouter { + return &OpenAIRouter{client: client} +} + +// RegisterRoutes registers all OpenAI routes on the given router. +func (o *OpenAIRouter) RegisterRoutes(r *router.Router) { + r.POST("/openai/v1/chat/completions", o.handleChatCompletion) +} + +// handleChatCompletion handles POST /v1/chat/completions +func (o *OpenAIRouter) handleChatCompletion(ctx *fasthttp.RequestCtx) { + var req OpenAIChatRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + errorResponse, _ := json.Marshal(map[string]string{"error": err.Error()}) + ctx.SetBody(errorResponse) + return + } + + if req.Model == "" { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Model parameter is required") + return + } + + bifrostReq := req.ConvertToBifrostRequest() + + bifrostCtx := lib.ConvertToBifrostContext(ctx) + + result, err := o.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) + if err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + errorResponse, _ := json.Marshal(err) + ctx.SetBody(errorResponse) + return + } + + openaiResponse := DeriveOpenAIFromBifrostResponse(result) + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + responseBody, _ := json.Marshal(openaiResponse) + ctx.SetBody(responseBody) +} diff --git a/transports/bifrost-http/integrations/openai/types.go b/transports/bifrost-http/integrations/openai/types.go new file mode 100644 index 0000000000..72a5282c9c --- /dev/null +++ b/transports/bifrost-http/integrations/openai/types.go @@ -0,0 +1,482 @@ +package openai + +import ( + "github.com/maximhq/bifrost/core/schemas" +) + +// Helper function to create string pointer +func stringPtr(s string) *string { + return &s +} + +var fnTypePtr = stringPtr(string(schemas.ToolChoiceTypeFunction)) + +// OpenAIMessage represents a message in the OpenAI chat format +type OpenAIMessage struct { + Role string `json:"role"` + Content *string `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls *[]OpenAIToolCall `json:"tool_calls,omitempty"` + ToolCallID *string `json:"tool_call_id,omitempty"` + FunctionCall *OpenAIFunctionCall `json:"function_call,omitempty"` +} + +// OpenAIToolCall represents a tool call in OpenAI format +type OpenAIToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function OpenAIFunctionCall `json:"function"` +} + +// OpenAIFunctionCall represents a function call in OpenAI format +type OpenAIFunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// OpenAITool represents a tool in OpenAI format +type OpenAITool struct { + Type string `json:"type"` + Function OpenAIFunction `json:"function"` +} + +// OpenAIFunction represents a function definition in OpenAI format +type OpenAIFunction struct { + Name string `json:"name"` + Description *string `json:"description,omitempty"` + Parameters interface{} `json:"parameters,omitempty"` +} + +// OpenAIChatRequest represents an OpenAI chat completion request +type OpenAIChatRequest struct { + Model string `json:"model"` + Messages []OpenAIMessage `json:"messages"` + MaxTokens *int `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + N *int `json:"n,omitempty"` + Stop interface{} 
`json:"stop,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias map[string]float64 `json:"logit_bias,omitempty"` + User *string `json:"user,omitempty"` + Functions *[]OpenAIFunction `json:"functions,omitempty"` + FunctionCall interface{} `json:"function_call,omitempty"` + Tools *[]OpenAITool `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + Stream *bool `json:"stream,omitempty"` + LogProbs *bool `json:"logprobs,omitempty"` + TopLogProbs *int `json:"top_logprobs,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` +} + +// ConvertToBifrostRequest converts an OpenAI chat request to Bifrost format +func (r *OpenAIChatRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { + bifrostReq := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: r.Model, + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{}, + }, + } + + // Convert messages + for _, msg := range r.Messages { + var bifrostMsg schemas.BifrostMessage + bifrostMsg.Role = schemas.ModelChatMessageRole(msg.Role) + bifrostMsg.Content = msg.Content + + // Handle tool calls and function calls for assistant messages + var toolCalls []schemas.ToolCall + + // Add modern tool calls + if msg.ToolCalls != nil { + for _, toolCall := range *msg.ToolCalls { + tc := schemas.ToolCall{ + Type: &toolCall.Type, + ID: &toolCall.ID, + Function: schemas.FunctionCall{ + Name: &toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + }, + } + toolCalls = append(toolCalls, tc) + } + } + + // Add legacy function calls + if msg.FunctionCall != nil { + tc := schemas.ToolCall{ + Type: fnTypePtr, + Function: schemas.FunctionCall{ + Name: &msg.FunctionCall.Name, + Arguments: msg.FunctionCall.Arguments, + }, + } + toolCalls = append(toolCalls, tc) + } + + // Assign AssistantMessage only if we have tool calls + if len(toolCalls) > 0 { + bifrostMsg.AssistantMessage = &schemas.AssistantMessage{ + ToolCalls: &toolCalls, + } + } + + // Handle tool messages + if msg.ToolCallID != nil { + bifrostMsg.ToolMessage = &schemas.ToolMessage{ + ToolCallID: msg.ToolCallID, + } + } + + *bifrostReq.Input.ChatCompletionInput = append(*bifrostReq.Input.ChatCompletionInput, bifrostMsg) + } + + // Convert parameters + if r.MaxTokens != nil || r.Temperature != nil || r.TopP != nil || r.PresencePenalty != nil || + r.FrequencyPenalty != nil || r.N != nil || r.LogProbs != nil || r.TopLogProbs != nil || + r.Stop != nil || r.LogitBias != nil { + params := &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + + if r.MaxTokens != nil { + params.MaxTokens = r.MaxTokens + } + if r.Temperature != nil { + params.Temperature = r.Temperature + } + if r.TopP != nil { + params.TopP = r.TopP + } + if r.PresencePenalty != nil { + params.PresencePenalty = r.PresencePenalty + } + if r.FrequencyPenalty != nil { + params.FrequencyPenalty = r.FrequencyPenalty + } + if r.N != nil { + params.ExtraParams["n"] = r.N + } + if r.LogProbs != nil { + params.ExtraParams["logprobs"] = r.LogProbs + } + if r.TopLogProbs != nil { + params.ExtraParams["top_logprobs"] = r.TopLogProbs + } + if r.Stop != nil { + params.ExtraParams["stop"] = r.Stop + } + if r.LogitBias != nil { + params.ExtraParams["logit_bias"] = r.LogitBias + } + + bifrostReq.Params = params + } + + // Convert tools and functions (legacy) + var allTools []schemas.Tool + + // Handle modern Tools 
field + if r.Tools != nil { + for _, tool := range *r.Tools { + description := "" + if tool.Function.Description != nil { + description = *tool.Function.Description + } + + // Convert parameters interface{} to FunctionParameters + params := schemas.FunctionParameters{ + Type: "object", + } + if tool.Function.Parameters != nil { + if paramMap, ok := tool.Function.Parameters.(map[string]interface{}); ok { + if typeVal, ok := paramMap["type"].(string); ok { + params.Type = typeVal + } + if desc, ok := paramMap["description"].(string); ok { + params.Description = &desc + } + if required, ok := paramMap["required"].([]interface{}); ok { + reqStrings := make([]string, len(required)) + for i, req := range required { + if reqStr, ok := req.(string); ok { + reqStrings[i] = reqStr + } + } + params.Required = reqStrings + } + if properties, ok := paramMap["properties"].(map[string]interface{}); ok { + params.Properties = properties + } + if enum, ok := paramMap["enum"].([]interface{}); ok { + enumStrings := make([]string, len(enum)) + for i, e := range enum { + if eStr, ok := e.(string); ok { + enumStrings[i] = eStr + } + } + params.Enum = &enumStrings + } + } + } + + t := schemas.Tool{ + Type: tool.Type, + Function: schemas.Function{ + Name: tool.Function.Name, + Description: description, + Parameters: params, + }, + } + allTools = append(allTools, t) + } + } + + // Handle legacy Functions field + if r.Functions != nil { + for _, function := range *r.Functions { + description := "" + if function.Description != nil { + description = *function.Description + } + + // Convert parameters interface{} to FunctionParameters + params := schemas.FunctionParameters{ + Type: "object", + } + if function.Parameters != nil { + if paramMap, ok := function.Parameters.(map[string]interface{}); ok { + if typeVal, ok := paramMap["type"].(string); ok { + params.Type = typeVal + } + if desc, ok := paramMap["description"].(string); ok { + params.Description = &desc + } + if required, ok := paramMap["required"].([]interface{}); ok { + reqStrings := make([]string, len(required)) + for i, req := range required { + if reqStr, ok := req.(string); ok { + reqStrings[i] = reqStr + } + } + params.Required = reqStrings + } + if properties, ok := paramMap["properties"].(map[string]interface{}); ok { + params.Properties = properties + } + if enum, ok := paramMap["enum"].([]interface{}); ok { + enumStrings := make([]string, len(enum)) + for i, e := range enum { + if eStr, ok := e.(string); ok { + enumStrings[i] = eStr + } + } + params.Enum = &enumStrings + } + } + } + + t := schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: function.Name, + Description: description, + Parameters: params, + }, + } + allTools = append(allTools, t) + } + } + + // Set tools if any were found + if len(allTools) > 0 { + if bifrostReq.Params == nil { + bifrostReq.Params = &schemas.ModelParameters{} + } + bifrostReq.Params.Tools = &allTools + } + + // Convert tool choice (from either tool_choice or function_call) + if r.ToolChoice != nil || r.FunctionCall != nil { + if bifrostReq.Params == nil { + bifrostReq.Params = &schemas.ModelParameters{} + } + + // Handle ToolChoice (modern format) + if r.ToolChoice != nil { + toolChoice := &schemas.ToolChoice{} + + switch tc := r.ToolChoice.(type) { + case string: + // Handle "none", "auto", etc. 
+ switch tc { + case "none": + toolChoice.Type = schemas.ToolChoiceTypeNone + case "auto": + toolChoice.Type = schemas.ToolChoiceTypeAuto + case "required": + toolChoice.Type = schemas.ToolChoiceTypeRequired + default: + toolChoice.Type = schemas.ToolChoiceTypeAuto // fallback + } + case map[string]interface{}: + // Handle object format like {"type": "function", "function": {"name": "get_weather"}} + if typeVal, ok := tc["type"].(string); ok { + switch typeVal { + case "function": + toolChoice.Type = schemas.ToolChoiceTypeFunction + if functionVal, ok := tc["function"].(map[string]interface{}); ok { + if name, ok := functionVal["name"].(string); ok { + toolChoice.Function = schemas.ToolChoiceFunction{Name: name} + } + } + case "none": + toolChoice.Type = schemas.ToolChoiceTypeNone + case "auto": + toolChoice.Type = schemas.ToolChoiceTypeAuto + case "required": + toolChoice.Type = schemas.ToolChoiceTypeRequired + default: + toolChoice.Type = schemas.ToolChoiceTypeAuto // fallback + } + } + } + + bifrostReq.Params.ToolChoice = toolChoice + } else if r.FunctionCall != nil { + // Handle legacy FunctionCall + toolChoice := &schemas.ToolChoice{} + + switch fc := r.FunctionCall.(type) { + case string: + // Handle "none", "auto" + switch fc { + case "none": + toolChoice.Type = schemas.ToolChoiceTypeNone + case "auto": + toolChoice.Type = schemas.ToolChoiceTypeAuto + default: + toolChoice.Type = schemas.ToolChoiceTypeAuto // fallback + } + case map[string]interface{}: + // Handle object format like {"name": "get_weather"} + if name, ok := fc["name"].(string); ok { + toolChoice.Type = schemas.ToolChoiceTypeFunction + toolChoice.Function = schemas.ToolChoiceFunction{Name: name} + } + } + + bifrostReq.Params.ToolChoice = toolChoice + } + } + + return bifrostReq +} + +// OpenAIChatResponse represents an OpenAI chat completion response +type OpenAIChatResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + Choices []OpenAIChoice `json:"choices"` + Usage *OpenAIUsage `json:"usage,omitempty"` + SystemFingerprint *string `json:"system_fingerprint,omitempty"` +} + +// OpenAIChoice represents a choice in the OpenAI response +type OpenAIChoice struct { + Index int `json:"index"` + Message OpenAIMessage `json:"message"` + FinishReason *string `json:"finish_reason,omitempty"` + LogProbs interface{} `json:"logprobs,omitempty"` +} + +// OpenAIUsage represents usage information in OpenAI format +type OpenAIUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// DeriveOpenAIFromBifrostResponse converts a Bifrost response to OpenAI format +func DeriveOpenAIFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *OpenAIChatResponse { + if bifrostResp == nil { + return nil + } + + openaiResp := &OpenAIChatResponse{ + ID: bifrostResp.ID, + Object: "chat.completion", + Created: bifrostResp.Created, + Model: bifrostResp.Model, + Choices: make([]OpenAIChoice, len(bifrostResp.Choices)), + } + + if bifrostResp.SystemFingerprint != nil { + openaiResp.SystemFingerprint = bifrostResp.SystemFingerprint + } + + // Convert usage information + if bifrostResp.Usage != (schemas.LLMUsage{}) { + openaiResp.Usage = &OpenAIUsage{ + PromptTokens: bifrostResp.Usage.PromptTokens, + CompletionTokens: bifrostResp.Usage.CompletionTokens, + TotalTokens: bifrostResp.Usage.TotalTokens, + } + } + + // Convert choices + for i, choice := range bifrostResp.Choices { + 
openaiChoice := OpenAIChoice{ + Index: choice.Index, + FinishReason: choice.FinishReason, + } + + // Convert message + msg := OpenAIMessage{ + Role: string(choice.Message.Role), + Content: choice.Message.Content, + } + + // Convert tool calls for assistant messages + if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.ToolCalls != nil { + toolCalls := []OpenAIToolCall{} + for _, toolCall := range *choice.Message.AssistantMessage.ToolCalls { + tc := OpenAIToolCall{ + Type: *toolCall.Type, + Function: OpenAIFunctionCall{ + Name: *toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + }, + } + if toolCall.ID != nil { + tc.ID = *toolCall.ID + } + toolCalls = append(toolCalls, tc) + } + msg.ToolCalls = &toolCalls + + // Re-emit legacy function_call field when exactly one function tool-call is present + if len(toolCalls) == 1 && toolCalls[0].Type == "function" { + msg.FunctionCall = &OpenAIFunctionCall{ + Name: toolCalls[0].Function.Name, + Arguments: toolCalls[0].Function.Arguments, + } + } + } + + // Handle tool messages - propagate tool_call_id + if choice.Message.ToolMessage != nil && choice.Message.ToolMessage.ToolCallID != nil { + msg.ToolCallID = choice.Message.ToolMessage.ToolCallID + } + + openaiChoice.Message = msg + openaiResp.Choices[i] = openaiChoice + } + + return openaiResp +} diff --git a/transports/bifrost-http/main.go b/transports/bifrost-http/main.go index 8e953ace17..413c8718dc 100644 --- a/transports/bifrost-http/main.go +++ b/transports/bifrost-http/main.go @@ -29,7 +29,13 @@ import ( schemas "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/plugins/maxim" "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/anthropic" "github.com/maximhq/bifrost/transports/bifrost-http/integrations/genai" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/langchain" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/langgraph" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/litellm" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/mistral" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/openai" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/maximhq/bifrost/transports/bifrost-http/tracking" "github.com/prometheus/client_golang/prometheus" @@ -176,7 +182,15 @@ func main() { r := router.New() - extensions := []integrations.ExtensionRouter{genai.NewGenAIRouter(client)} + extensions := []integrations.ExtensionRouter{ + genai.NewGenAIRouter(client), + openai.NewOpenAIRouter(client), + anthropic.NewAnthropicRouter(client), + mistral.NewMistralRouter(client), + litellm.NewLiteLLMRouter(client), + langchain.NewLangChainRouter(client), + langgraph.NewLangGraphRouter(client), + } r.POST("/v1/text/completions", func(ctx *fasthttp.RequestCtx) { handleCompletion(ctx, client, false) @@ -297,11 +311,13 @@ func handleCompletion(ctx *fasthttp.RequestCtx, client *bifrost.Bifrost, isChat ctx.SetStatusCode(fasthttp.StatusBadRequest) } ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(err) + jsonBytes, _ := json.Marshal(err) + ctx.SetBody(jsonBytes) return } ctx.SetStatusCode(fasthttp.StatusOK) ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(resp) + jsonBytes, _ := json.Marshal(resp) + ctx.SetBody(jsonBytes) }
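diff --git a/transports/bifrost-http/integrations/litellm/types_test.go b/transports/bifrost-http/integrations/litellm/types_test.go
new file mode 100644
--- /dev/null
+++ b/transports/bifrost-http/integrations/litellm/types_test.go
@@ -0,0 +1,32 @@
+package litellm
+
+// NOTE: illustrative sketch only, not part of the change above. It exercises the
+// determineProviderFromModel helper defined in litellm/types.go; the file name and
+// the sample model strings are assumptions, not values taken from the original patch.
+
+import (
+	"testing"
+
+	"github.com/maximhq/bifrost/core/schemas"
+)
+
+func TestDetermineProviderFromModel(t *testing.T) {
+	cases := []struct {
+		model string
+		want  schemas.ModelProvider
+	}{
+		{"gpt-4o", schemas.OpenAI},
+		{"o1-preview", schemas.OpenAI},
+		{"claude-3-opus-20240229", schemas.Anthropic},
+		{"gemini-1.5-pro", schemas.Vertex},
+		{"bedrock/amazon.titan-text-express-v1", schemas.Bedrock},
+		{"cohere.command-r", schemas.Cohere},
+		{"some-unknown-model", schemas.OpenAI}, // unknown names fall back to OpenAI
+	}
+
+	for _, c := range cases {
+		if got := determineProviderFromModel(c.model); got != c.want {
+			t.Errorf("determineProviderFromModel(%q) = %v, want %v", c.model, got, c.want)
+		}
+	}
+}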