diff --git a/.gitignore b/.gitignore index 540b0b388d..a5e2010b87 100644 --- a/.gitignore +++ b/.gitignore @@ -53,4 +53,5 @@ test-reports # Cursor specific -.cursor/ \ No newline at end of file +.cursor/ +build/ \ No newline at end of file diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index daacc8b316..0c84330a7f 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -5,8 +5,6 @@ import ( "context" "encoding/json" "errors" - - "github.com/bytedance/sonic" ) const ( @@ -468,17 +466,17 @@ type BifrostStream struct { // This ensures that only the non-nil embedded struct is marshaled, func (bs BifrostStream) MarshalJSON() ([]byte, error) { if bs.BifrostTextCompletionResponse != nil { - return sonic.Marshal(bs.BifrostTextCompletionResponse) + return Marshal(bs.BifrostTextCompletionResponse) } else if bs.BifrostChatResponse != nil { - return sonic.Marshal(bs.BifrostChatResponse) + return Marshal(bs.BifrostChatResponse) } else if bs.BifrostResponsesStreamResponse != nil { - return sonic.Marshal(bs.BifrostResponsesStreamResponse) + return Marshal(bs.BifrostResponsesStreamResponse) } else if bs.BifrostSpeechStreamResponse != nil { - return sonic.Marshal(bs.BifrostSpeechStreamResponse) + return Marshal(bs.BifrostSpeechStreamResponse) } else if bs.BifrostTranscriptionStreamResponse != nil { - return sonic.Marshal(bs.BifrostTranscriptionStreamResponse) + return Marshal(bs.BifrostTranscriptionStreamResponse) } else if bs.BifrostError != nil { - return sonic.Marshal(bs.BifrostError) + return Marshal(bs.BifrostError) } // Return empty object if both are nil (shouldn't happen in practice) return []byte("{}"), nil diff --git a/core/schemas/chatcompletions.go b/core/schemas/chatcompletions.go index 4f9091b965..5aae9d6d2f 100644 --- a/core/schemas/chatcompletions.go +++ b/core/schemas/chatcompletions.go @@ -4,8 +4,6 @@ import ( "bytes" "fmt" "sort" - - "github.com/bytedance/sonic" ) // BifrostChatRequest is the request struct for chat completion requests @@ -200,7 +198,7 @@ func (cp *ChatParameters) UnmarshalJSON(data []byte) error { aux.Alias = (*Alias)(cp) // Single unmarshal - if err := sonic.Unmarshal(data, &aux); err != nil { + if err := Unmarshal(data, &aux); err != nil { return err } @@ -288,11 +286,11 @@ type ToolFunctionParameters struct { func (t *ToolFunctionParameters) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a JSON string (xAI format) var jsonStr string - if err := sonic.Unmarshal(data, &jsonStr); err == nil { + if err := Unmarshal(data, &jsonStr); err == nil { // It's a string, so parse the string as JSON type Alias ToolFunctionParameters var temp Alias - if err := sonic.Unmarshal([]byte(jsonStr), &temp); err != nil { + if err := Unmarshal([]byte(jsonStr), &temp); err != nil { return fmt.Errorf("failed to unmarshal parameters string: %w", err) } *t = ToolFunctionParameters(temp) @@ -302,7 +300,7 @@ func (t *ToolFunctionParameters) UnmarshalJSON(data []byte) error { // Otherwise, unmarshal as a normal JSON object type Alias ToolFunctionParameters var temp Alias - if err := sonic.Unmarshal(data, &temp); err != nil { + if err := Unmarshal(data, &temp); err != nil { return err } *t = ToolFunctionParameters(temp) @@ -370,7 +368,7 @@ func (om OrderedMap) MarshalJSON() ([]byte, error) { } // key - keyBytes, err := sonic.Marshal(k) + keyBytes, err := Marshal(k) if err != nil { return nil, err } @@ -378,7 +376,7 @@ func (om OrderedMap) MarshalJSON() ([]byte, error) { buf.WriteByte(':') // value - valBytes, err := sonic.Marshal(norm[k]) + valBytes, err := Marshal(norm[k]) if err != nil { return nil, err } @@ -443,13 +441,13 @@ func (ctc ChatToolChoice) MarshalJSON() ([]byte, error) { } if ctc.ChatToolChoiceStr != nil { - return sonic.Marshal(ctc.ChatToolChoiceStr) + return Marshal(ctc.ChatToolChoiceStr) } if ctc.ChatToolChoiceStruct != nil { - return sonic.Marshal(ctc.ChatToolChoiceStruct) + return Marshal(ctc.ChatToolChoiceStruct) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ChatMessageContent. @@ -458,7 +456,7 @@ func (ctc ChatToolChoice) MarshalJSON() ([]byte, error) { func (ctc *ChatToolChoice) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var toolChoiceStr string - if err := sonic.Unmarshal(data, &toolChoiceStr); err == nil { + if err := Unmarshal(data, &toolChoiceStr); err == nil { ctc.ChatToolChoiceStr = &toolChoiceStr ctc.ChatToolChoiceStruct = nil return nil @@ -466,7 +464,7 @@ func (ctc *ChatToolChoice) UnmarshalJSON(data []byte) error { // Try to unmarshal as a direct array of ContentBlock var chatToolChoice ChatToolChoiceStruct - if err := sonic.Unmarshal(data, &chatToolChoice); err == nil { + if err := Unmarshal(data, &chatToolChoice); err == nil { ctc.ChatToolChoiceStr = nil ctc.ChatToolChoiceStruct = &chatToolChoice return nil @@ -523,7 +521,7 @@ type ChatMessage struct { // UnmarshalJSON implements custom JSON unmarshalling for ChatMessage. // This is needed because ChatAssistantMessage has a custom UnmarshalJSON method, -// which interferes with sonic's handling of other fields in ChatMessage. +// which interferes with the JSON library's handling of other fields in ChatMessage. func (cm *ChatMessage) UnmarshalJSON(data []byte) error { // Unmarshal the base fields directly type baseFields struct { @@ -532,7 +530,7 @@ func (cm *ChatMessage) UnmarshalJSON(data []byte) error { Content *ChatMessageContent `json:"content,omitempty"` } var base baseFields - if err := sonic.Unmarshal(data, &base); err != nil { + if err := Unmarshal(data, &base); err != nil { return err } cm.Name = base.Name @@ -542,7 +540,7 @@ func (cm *ChatMessage) UnmarshalJSON(data []byte) error { // Unmarshal ChatToolMessage fields type toolMsgAlias ChatToolMessage var toolMsg toolMsgAlias - if err := sonic.Unmarshal(data, &toolMsg); err != nil { + if err := Unmarshal(data, &toolMsg); err != nil { return err } if toolMsg.ToolCallID != nil { @@ -551,7 +549,7 @@ func (cm *ChatMessage) UnmarshalJSON(data []byte) error { // Unmarshal ChatAssistantMessage (which has its own custom unmarshaller) var assistantMsg ChatAssistantMessage - if err := sonic.Unmarshal(data, &assistantMsg); err != nil { + if err := Unmarshal(data, &assistantMsg); err != nil { return err } // Only set if any field is populated @@ -579,13 +577,13 @@ func (mc ChatMessageContent) MarshalJSON() ([]byte, error) { } if mc.ContentStr != nil { - return sonic.Marshal(*mc.ContentStr) + return Marshal(*mc.ContentStr) } if mc.ContentBlocks != nil { - return sonic.Marshal(mc.ContentBlocks) + return Marshal(mc.ContentBlocks) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ChatMessageContent. @@ -601,7 +599,7 @@ func (mc *ChatMessageContent) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { + if err := Unmarshal(data, &stringContent); err == nil { mc.ContentStr = &stringContent mc.ContentBlocks = nil return nil @@ -609,7 +607,7 @@ func (mc *ChatMessageContent) UnmarshalJSON(data []byte) error { // Try to unmarshal as a direct array of ContentBlock var arrayContent []ChatContentBlock - if err := sonic.Unmarshal(data, &arrayContent); err == nil { + if err := Unmarshal(data, &arrayContent); err == nil { mc.ContentBlocks = arrayContent mc.ContentStr = nil return nil @@ -709,7 +707,7 @@ func (cm *ChatAssistantMessage) UnmarshalJSON(data []byte) error { ReasoningContent *string `json:"reasoning_content,omitempty"` // xAI uses this field name } - if err := sonic.Unmarshal(data, &aux); err != nil { + if err := Unmarshal(data, &aux); err != nil { return err } @@ -856,7 +854,7 @@ func (d *ChatStreamResponseChoiceDelta) UnmarshalJSON(data []byte) error { ReasoningContent *string `json:"reasoning_content,omitempty"` // xAI uses this field name } - if err := sonic.Unmarshal(data, &aux); err != nil { + if err := Unmarshal(data, &aux); err != nil { return err } @@ -945,7 +943,7 @@ type BifrostCost struct { func (bc *BifrostCost) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct float var costFloat float64 - if err := sonic.Unmarshal(data, &costFloat); err == nil { + if err := Unmarshal(data, &costFloat); err == nil { bc.TotalCost = costFloat return nil } @@ -954,7 +952,7 @@ func (bc *BifrostCost) UnmarshalJSON(data []byte) error { // Use a type alias to avoid infinite recursion type Alias BifrostCost var costStruct Alias - if err := sonic.Unmarshal(data, &costStruct); err == nil { + if err := Unmarshal(data, &costStruct); err == nil { *bc = BifrostCost(costStruct) return nil } diff --git a/core/schemas/embedding.go b/core/schemas/embedding.go index 73f0d2664c..e1fbbe9f3d 100644 --- a/core/schemas/embedding.go +++ b/core/schemas/embedding.go @@ -2,8 +2,6 @@ package schemas import ( "fmt" - - "github.com/bytedance/sonic" ) type BifrostEmbeddingRequest struct { @@ -58,16 +56,16 @@ func (e *EmbeddingInput) MarshalJSON() ([]byte, error) { } if e.Text != nil { - return sonic.Marshal(*e.Text) + return Marshal(*e.Text) } if e.Texts != nil { - return sonic.Marshal(e.Texts) + return Marshal(e.Texts) } if e.Embedding != nil { - return sonic.Marshal(e.Embedding) + return Marshal(e.Embedding) } if e.Embeddings != nil { - return sonic.Marshal(e.Embeddings) + return Marshal(e.Embeddings) } return nil, fmt.Errorf("invalid embedding input") @@ -80,25 +78,25 @@ func (e *EmbeddingInput) UnmarshalJSON(data []byte) error { e.Embeddings = nil // Try string var s string - if err := sonic.Unmarshal(data, &s); err == nil { + if err := Unmarshal(data, &s); err == nil { e.Text = &s return nil } // Try []string var ss []string - if err := sonic.Unmarshal(data, &ss); err == nil { + if err := Unmarshal(data, &ss); err == nil { e.Texts = ss return nil } // Try []int var i []int - if err := sonic.Unmarshal(data, &i); err == nil { + if err := Unmarshal(data, &i); err == nil { e.Embedding = i return nil } // Try [][]int var i2 [][]int - if err := sonic.Unmarshal(data, &i2); err == nil { + if err := Unmarshal(data, &i2); err == nil { e.Embeddings = i2 return nil } @@ -129,13 +127,13 @@ type EmbeddingStruct struct { func (be EmbeddingStruct) MarshalJSON() ([]byte, error) { if be.EmbeddingStr != nil { - return sonic.Marshal(be.EmbeddingStr) + return Marshal(be.EmbeddingStr) } if be.EmbeddingArray != nil { - return sonic.Marshal(be.EmbeddingArray) + return Marshal(be.EmbeddingArray) } if be.Embedding2DArray != nil { - return sonic.Marshal(be.Embedding2DArray) + return Marshal(be.Embedding2DArray) } return nil, fmt.Errorf("no embedding found") } @@ -143,21 +141,21 @@ func (be EmbeddingStruct) MarshalJSON() ([]byte, error) { func (be *EmbeddingStruct) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { + if err := Unmarshal(data, &stringContent); err == nil { be.EmbeddingStr = &stringContent return nil } // Try to unmarshal as a direct array of float32 var arrayContent []float32 - if err := sonic.Unmarshal(data, &arrayContent); err == nil { + if err := Unmarshal(data, &arrayContent); err == nil { be.EmbeddingArray = arrayContent return nil } // Try to unmarshal as a direct 2D array of float32 var arrayContent2D [][]float32 - if err := sonic.Unmarshal(data, &arrayContent2D); err == nil { + if err := Unmarshal(data, &arrayContent2D); err == nil { be.Embedding2DArray = arrayContent2D return nil } diff --git a/core/schemas/json_native.go b/core/schemas/json_native.go new file mode 100644 index 0000000000..3d21c91444 --- /dev/null +++ b/core/schemas/json_native.go @@ -0,0 +1,20 @@ +//go:build !tinygo && !wasm + +package schemas + +import "github.com/bytedance/sonic" + +// Marshal encodes v to JSON bytes using the high-performance sonic library. +func Marshal(v interface{}) ([]byte, error) { + return sonic.Marshal(v) +} + +// MarshalString encodes v to a JSON string using sonic. +func MarshalString(v interface{}) (string, error) { + return sonic.MarshalString(v) +} + +// Unmarshal decodes JSON data into v using sonic. +func Unmarshal(data []byte, v interface{}) error { + return sonic.Unmarshal(data, v) +} diff --git a/core/schemas/json_wasm.go b/core/schemas/json_wasm.go new file mode 100644 index 0000000000..f04c328d2f --- /dev/null +++ b/core/schemas/json_wasm.go @@ -0,0 +1,24 @@ +//go:build tinygo || wasm + +package schemas + +import "encoding/json" + +// Marshal encodes v to JSON bytes using the standard library. +func Marshal(v interface{}) ([]byte, error) { + return json.Marshal(v) +} + +// MarshalString encodes v to a JSON string using the standard library. +func MarshalString(v interface{}) (string, error) { + data, err := json.Marshal(v) + if err != nil { + return "", err + } + return string(data), nil +} + +// Unmarshal decodes JSON data into v using the standard library. +func Unmarshal(data []byte, v interface{}) error { + return json.Unmarshal(data, v) +} diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go index f54998cc5d..a9437a29cd 100644 --- a/core/schemas/mcp.go +++ b/core/schemas/mcp.go @@ -1,3 +1,5 @@ +//go:build !tinygo && !wasm + // Package schemas defines the core schemas and types used by the Bifrost system. package schemas diff --git a/core/schemas/mcp_wasm.go b/core/schemas/mcp_wasm.go new file mode 100644 index 0000000000..1a34e39b26 --- /dev/null +++ b/core/schemas/mcp_wasm.go @@ -0,0 +1,7 @@ +//go:build tinygo || wasm + +package schemas + +// MCPConfig is a stub for WASM builds. +// MCP functionality is not available in WASM plugins. +type MCPConfig struct{} diff --git a/core/schemas/models.go b/core/schemas/models.go index ab1e627559..4226e6e31e 100644 --- a/core/schemas/models.go +++ b/core/schemas/models.go @@ -3,8 +3,6 @@ package schemas import ( "encoding/base64" "fmt" - - "github.com/bytedance/sonic" ) // DefaultPageSize is the default page size for listing models @@ -182,7 +180,7 @@ func encodePaginationCursor(offset int, lastID string) (string, error) { LastID: lastID, } - jsonData, err := sonic.Marshal(cursor) + jsonData, err := Marshal(cursor) if err != nil { return "", fmt.Errorf("failed to marshal pagination cursor: %w", err) } @@ -206,7 +204,7 @@ func decodePaginationCursor(token string) paginationCursor { } var cursor paginationCursor - if err := sonic.Unmarshal(decoded, &cursor); err != nil { + if err := Unmarshal(decoded, &cursor); err != nil { return paginationCursor{} } diff --git a/core/schemas/plugin.go b/core/schemas/plugin.go index 6aa54a176c..3ab4cd8f02 100644 --- a/core/schemas/plugin.go +++ b/core/schemas/plugin.go @@ -3,8 +3,7 @@ package schemas import ( "context" - - "github.com/valyala/fasthttp" + "sync" ) // PluginShortCircuit represents a plugin's decision to short-circuit the normal flow. @@ -33,9 +32,58 @@ type PluginStatus struct { Logs []string `json:"logs"` } -// BifrostHTTPMiddleware is a middleware function for the Bifrost HTTP transport -// It follows the standard pattern: receives the next handler and returns a new handler -type BifrostHTTPMiddleware func(next fasthttp.RequestHandler) fasthttp.RequestHandler +// HTTPRequest is a serializable representation of an HTTP request. +// Used for plugin HTTP transport interception (supports both native .so and WASM plugins). +// This type is pooled for allocation control - use AcquireHTTPRequest and ReleaseHTTPRequest. +type HTTPRequest struct { + Method string `json:"method"` + Path string `json:"path"` + Headers map[string]string `json:"headers"` + Query map[string]string `json:"query"` + Body []byte `json:"body"` +} + +// HTTPResponse is a serializable representation of an HTTP response. +// Used for short-circuit responses in plugin HTTP transport interception. +type HTTPResponse struct { + StatusCode int `json:"status_code"` + Headers map[string]string `json:"headers"` + Body []byte `json:"body"` +} + +// httpRequestPool is the pool for HTTPRequest objects to reduce allocations. +var httpRequestPool = sync.Pool{ + New: func() any { + return &HTTPRequest{ + Headers: make(map[string]string, 16), + Query: make(map[string]string, 8), + } + }, +} + +// AcquireHTTPRequest gets an HTTPRequest from the pool. +// The returned HTTPRequest is ready to use with pre-allocated maps. +// Call ReleaseHTTPRequest when done to return it to the pool. +func AcquireHTTPRequest() *HTTPRequest { + return httpRequestPool.Get().(*HTTPRequest) +} + +// ReleaseHTTPRequest returns an HTTPRequest to the pool. +// The HTTPRequest is reset before being returned to the pool. +// Do not use the HTTPRequest after calling this function. +func ReleaseHTTPRequest(req *HTTPRequest) { + if req == nil { + return + } + // Clear the maps + clear(req.Headers) + clear(req.Query) + // Reset fields + req.Method = "" + req.Path = "" + req.Body = nil + httpRequestPool.Put(req) +} // Plugin defines the interface for Bifrost plugins. // Plugins can intercept and modify requests and responses at different stages @@ -45,7 +93,7 @@ type BifrostHTTPMiddleware func(next fasthttp.RequestHandler) fasthttp.RequestHa // PostHooks are executed in the reverse order of PreHooks. // // Execution order: -// 1. HTTPTransportMiddleware (HTTP transport only, modifies raw headers/body before entering Bifrost core) +// 1. HTTPTransportIntercept (HTTP transport only, modifies raw headers/body before entering Bifrost core) // 2. PreHook (executed in registration order) // 3. Provider call // 4. PostHook (executed in reverse order of PreHooks) @@ -72,11 +120,18 @@ type Plugin interface { // GetName returns the name of the plugin. GetName() string - // HTTPTransportMiddleware is called at the HTTP transport layer before requests enter Bifrost core. - // It allows plugins to modify the request and response before they are processed by the next middleware. + // HTTPTransportIntercept is called at the HTTP transport layer before requests enter Bifrost core. + // It receives a serializable HTTPRequest and allows plugins to modify it in-place. // Only invoked when using HTTP transport (bifrost-http), not when using Bifrost as a Go SDK directly. - // Returns a new handler that will be called next in the middleware chain. - HTTPTransportMiddleware() BifrostHTTPMiddleware + // Works with both native .so plugins and WASM plugins due to serializable types. + // + // Return values: + // - (nil, nil): Continue to next plugin/handler, request modifications are applied + // - (*HTTPResponse, nil): Short-circuit with this response, skip remaining plugins and provider call + // - (nil, error): Short-circuit with error response + // + // Return nil for both values if the plugin doesn't need HTTP transport interception. + HTTPTransportIntercept(ctx *BifrostContext, req *HTTPRequest) (*HTTPResponse, error) // PreHook is called before a request is processed by a provider. // It allows plugins to modify the request before it is sent to the provider. diff --git a/core/schemas/plugin_native.go b/core/schemas/plugin_native.go new file mode 100644 index 0000000000..672433142e --- /dev/null +++ b/core/schemas/plugin_native.go @@ -0,0 +1,12 @@ +//go:build !tinygo && !wasm + +package schemas + +import ( + "github.com/valyala/fasthttp" +) + +// BifrostHTTPMiddleware is a middleware function for the Bifrost HTTP transport. +// It follows the standard pattern: receives the next handler and returns a new handler. +// Used internally for CORS, Auth, Tracing middleware. Plugins use HTTPTransportIntercept instead. +type BifrostHTTPMiddleware func(next fasthttp.RequestHandler) fasthttp.RequestHandler diff --git a/core/schemas/responses.go b/core/schemas/responses.go index 526717720c..384fbbb5cf 100644 --- a/core/schemas/responses.go +++ b/core/schemas/responses.go @@ -2,8 +2,6 @@ package schemas import ( "fmt" - - "github.com/bytedance/sonic" ) // ============================================================================= @@ -155,13 +153,13 @@ func (rc ResponsesResponseConversation) MarshalJSON() ([]byte, error) { } if rc.ResponsesResponseConversationStr != nil { - return sonic.Marshal(*rc.ResponsesResponseConversationStr) + return Marshal(*rc.ResponsesResponseConversationStr) } if rc.ResponsesResponseConversationStruct != nil { - return sonic.Marshal(rc.ResponsesResponseConversationStruct) + return Marshal(rc.ResponsesResponseConversationStruct) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ResponsesMessageContent. @@ -170,14 +168,14 @@ func (rc ResponsesResponseConversation) MarshalJSON() ([]byte, error) { func (rc *ResponsesResponseConversation) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { + if err := Unmarshal(data, &stringContent); err == nil { rc.ResponsesResponseConversationStr = &stringContent return nil } // Try to unmarshal as a direct array of ContentBlock var structContent ResponsesResponseConversationStruct - if err := sonic.Unmarshal(data, &structContent); err == nil { + if err := Unmarshal(data, &structContent); err == nil { rc.ResponsesResponseConversationStruct = &structContent return nil } @@ -199,13 +197,13 @@ func (rc ResponsesResponseInstructions) MarshalJSON() ([]byte, error) { } if rc.ResponsesResponseInstructionsStr != nil { - return sonic.Marshal(*rc.ResponsesResponseInstructionsStr) + return Marshal(*rc.ResponsesResponseInstructionsStr) } if rc.ResponsesResponseInstructionsArray != nil { - return sonic.Marshal(rc.ResponsesResponseInstructionsArray) + return Marshal(rc.ResponsesResponseInstructionsArray) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ResponsesMessageContent. @@ -214,14 +212,14 @@ func (rc ResponsesResponseInstructions) MarshalJSON() ([]byte, error) { func (rc *ResponsesResponseInstructions) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { + if err := Unmarshal(data, &stringContent); err == nil { rc.ResponsesResponseInstructionsStr = &stringContent return nil } // Try to unmarshal as a direct array of ContentBlock var arrayContent []ResponsesMessage - if err := sonic.Unmarshal(data, &arrayContent); err == nil { + if err := Unmarshal(data, &arrayContent); err == nil { rc.ResponsesResponseInstructionsArray = arrayContent return nil } @@ -359,13 +357,13 @@ func (rc ResponsesMessageContent) MarshalJSON() ([]byte, error) { } if rc.ContentStr != nil { - return sonic.Marshal(*rc.ContentStr) + return Marshal(*rc.ContentStr) } if rc.ContentBlocks != nil { - return sonic.Marshal(rc.ContentBlocks) + return Marshal(rc.ContentBlocks) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ResponsesMessageContent. @@ -374,14 +372,14 @@ func (rc ResponsesMessageContent) MarshalJSON() ([]byte, error) { func (rc *ResponsesMessageContent) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { + if err := Unmarshal(data, &stringContent); err == nil { rc.ContentStr = &stringContent return nil } // Try to unmarshal as a direct array of ContentBlock var arrayContent []ResponsesMessageContentBlock - if err := sonic.Unmarshal(data, &arrayContent); err == nil { + if err := Unmarshal(data, &arrayContent); err == nil { rc.ContentBlocks = arrayContent return nil } @@ -501,38 +499,38 @@ type ResponsesToolMessageActionStruct struct { func (action ResponsesToolMessageActionStruct) MarshalJSON() ([]byte, error) { if action.ResponsesComputerToolCallAction != nil { - return sonic.Marshal(action.ResponsesComputerToolCallAction) + return Marshal(action.ResponsesComputerToolCallAction) } if action.ResponsesWebSearchToolCallAction != nil { - return sonic.Marshal(action.ResponsesWebSearchToolCallAction) + return Marshal(action.ResponsesWebSearchToolCallAction) } if action.ResponsesLocalShellToolCallAction != nil { - return sonic.Marshal(action.ResponsesLocalShellToolCallAction) + return Marshal(action.ResponsesLocalShellToolCallAction) } if action.ResponsesMCPApprovalRequestAction != nil { - return sonic.Marshal(action.ResponsesMCPApprovalRequestAction) + return Marshal(action.ResponsesMCPApprovalRequestAction) } return nil, fmt.Errorf("responses tool message action struct is neither a computer tool call action nor a web search tool call action nor a local shell tool call action nor a mcp approval request action") } func (action *ResponsesToolMessageActionStruct) UnmarshalJSON(data []byte) error { var computerToolCallAction ResponsesComputerToolCallAction - if err := sonic.Unmarshal(data, &computerToolCallAction); err == nil { + if err := Unmarshal(data, &computerToolCallAction); err == nil { action.ResponsesComputerToolCallAction = &computerToolCallAction return nil } var webSearchToolCallAction ResponsesWebSearchToolCallAction - if err := sonic.Unmarshal(data, &webSearchToolCallAction); err == nil { + if err := Unmarshal(data, &webSearchToolCallAction); err == nil { action.ResponsesWebSearchToolCallAction = &webSearchToolCallAction return nil } var localShellToolCallAction ResponsesLocalShellToolCallAction - if err := sonic.Unmarshal(data, &localShellToolCallAction); err == nil { + if err := Unmarshal(data, &localShellToolCallAction); err == nil { action.ResponsesLocalShellToolCallAction = &localShellToolCallAction return nil } var mcpApprovalRequestAction ResponsesMCPApprovalRequestAction - if err := sonic.Unmarshal(data, &mcpApprovalRequestAction); err == nil { + if err := Unmarshal(data, &mcpApprovalRequestAction); err == nil { action.ResponsesMCPApprovalRequestAction = &mcpApprovalRequestAction return nil } @@ -547,29 +545,29 @@ type ResponsesToolMessageOutputStruct struct { func (output ResponsesToolMessageOutputStruct) MarshalJSON() ([]byte, error) { if output.ResponsesToolCallOutputStr != nil { - return sonic.Marshal(*output.ResponsesToolCallOutputStr) + return Marshal(*output.ResponsesToolCallOutputStr) } if output.ResponsesFunctionToolCallOutputBlocks != nil { - return sonic.Marshal(output.ResponsesFunctionToolCallOutputBlocks) + return Marshal(output.ResponsesFunctionToolCallOutputBlocks) } if output.ResponsesComputerToolCallOutput != nil { - return sonic.Marshal(output.ResponsesComputerToolCallOutput) + return Marshal(output.ResponsesComputerToolCallOutput) } return nil, fmt.Errorf("responses tool message output struct is neither a string nor an array of responses message content blocks nor a computer tool call output data") } func (output *ResponsesToolMessageOutputStruct) UnmarshalJSON(data []byte) error { var str string - if err := sonic.Unmarshal(data, &str); err == nil { + if err := Unmarshal(data, &str); err == nil { output.ResponsesToolCallOutputStr = &str return nil } var array []ResponsesMessageContentBlock - if err := sonic.Unmarshal(data, &array); err == nil { + if err := Unmarshal(data, &array); err == nil { output.ResponsesFunctionToolCallOutputBlocks = array return nil } var computerToolCallOutput ResponsesComputerToolCallOutputData - if err := sonic.Unmarshal(data, &computerToolCallOutput); err == nil { + if err := Unmarshal(data, &computerToolCallOutput); err == nil { output.ResponsesComputerToolCallOutput = &computerToolCallOutput return nil } @@ -685,13 +683,13 @@ func (rf ResponsesFunctionToolCallOutput) MarshalJSON() ([]byte, error) { } if rf.ResponsesFunctionToolCallOutputStr != nil { - return sonic.Marshal(*rf.ResponsesFunctionToolCallOutputStr) + return Marshal(*rf.ResponsesFunctionToolCallOutputStr) } if rf.ResponsesFunctionToolCallOutputBlocks != nil { - return sonic.Marshal(rf.ResponsesFunctionToolCallOutputBlocks) + return Marshal(rf.ResponsesFunctionToolCallOutputBlocks) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ResponsesFunctionToolCallOutput. @@ -700,7 +698,7 @@ func (rf ResponsesFunctionToolCallOutput) MarshalJSON() ([]byte, error) { func (rf *ResponsesFunctionToolCallOutput) UnmarshalJSON(data []byte) error { // Parse as generic object to check if it contains content-like fields var genericObj map[string]interface{} - if err := sonic.Unmarshal(data, &genericObj); err != nil { + if err := Unmarshal(data, &genericObj); err != nil { return err } @@ -720,14 +718,14 @@ func (rf *ResponsesFunctionToolCallOutput) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { + if err := Unmarshal(data, &stringContent); err == nil { rf.ResponsesFunctionToolCallOutputStr = &stringContent return nil } // Try to unmarshal as a direct array of ContentBlock var arrayContent []ResponsesMessageContentBlock - if err := sonic.Unmarshal(data, &arrayContent); err == nil { + if err := Unmarshal(data, &arrayContent); err == nil { rf.ResponsesFunctionToolCallOutputBlocks = arrayContent return nil } @@ -794,10 +792,10 @@ func (o ResponsesCodeInterpreterOutput) MarshalJSON() ([]byte, error) { // Marshal whichever one is present if o.ResponsesCodeInterpreterOutputLogs != nil { - return sonic.Marshal(o.ResponsesCodeInterpreterOutputLogs) + return Marshal(o.ResponsesCodeInterpreterOutputLogs) } if o.ResponsesCodeInterpreterOutputImage != nil { - return sonic.Marshal(o.ResponsesCodeInterpreterOutputImage) + return Marshal(o.ResponsesCodeInterpreterOutputImage) } // Return null if neither is set @@ -815,7 +813,7 @@ func (o *ResponsesCodeInterpreterOutput) UnmarshalJSON(data []byte) error { var typeStruct struct { Type string `json:"type"` } - if err := sonic.Unmarshal(data, &typeStruct); err != nil { + if err := Unmarshal(data, &typeStruct); err != nil { return fmt.Errorf("failed to read type field: %w", err) } @@ -823,7 +821,7 @@ func (o *ResponsesCodeInterpreterOutput) UnmarshalJSON(data []byte) error { switch typeStruct.Type { case "logs": var logs ResponsesCodeInterpreterOutputLogs - if err := sonic.Unmarshal(data, &logs); err != nil { + if err := Unmarshal(data, &logs); err != nil { return fmt.Errorf("failed to unmarshal logs output: %w", err) } o.ResponsesCodeInterpreterOutputLogs = &logs @@ -832,7 +830,7 @@ func (o *ResponsesCodeInterpreterOutput) UnmarshalJSON(data []byte) error { case "image": var image ResponsesCodeInterpreterOutputImage - if err := sonic.Unmarshal(data, &image); err != nil { + if err := Unmarshal(data, &image); err != nil { return fmt.Errorf("failed to unmarshal image output: %w", err) } o.ResponsesCodeInterpreterOutputImage = &image @@ -982,13 +980,13 @@ func (tc ResponsesToolChoice) MarshalJSON() ([]byte, error) { } if tc.ResponsesToolChoiceStr != nil { - return sonic.Marshal(tc.ResponsesToolChoiceStr) + return Marshal(tc.ResponsesToolChoiceStr) } if tc.ResponsesToolChoiceStruct != nil { - return sonic.Marshal(tc.ResponsesToolChoiceStruct) + return Marshal(tc.ResponsesToolChoiceStruct) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ChatMessageContent. @@ -997,14 +995,14 @@ func (tc ResponsesToolChoice) MarshalJSON() ([]byte, error) { func (tc *ResponsesToolChoice) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var toolChoiceStr string - if err := sonic.Unmarshal(data, &toolChoiceStr); err == nil { + if err := Unmarshal(data, &toolChoiceStr); err == nil { tc.ResponsesToolChoiceStr = &toolChoiceStr return nil } // Try to unmarshal as a direct array of ContentBlock var responsesToolChoiceStruct ResponsesToolChoiceStruct - if err := sonic.Unmarshal(data, &responsesToolChoiceStruct); err == nil { + if err := Unmarshal(data, &responsesToolChoiceStruct); err == nil { tc.ResponsesToolChoiceStruct = &responsesToolChoiceStruct return nil } @@ -1115,14 +1113,14 @@ func (f *ResponsesToolFileSearchFilter) MarshalJSON() ([]byte, error) { return nil, fmt.Errorf("unknown filter type: %s", f.Type) } - return sonic.Marshal(result) + return Marshal(result) } // UnmarshalJSON implements custom JSON unmarshaling for ResponsesToolFileSearchFilter func (f *ResponsesToolFileSearchFilter) UnmarshalJSON(data []byte) error { // First, unmarshal into a map to inspect the type field var raw map[string]interface{} - if err := sonic.Unmarshal(data, &raw); err != nil { + if err := Unmarshal(data, &raw); err != nil { return fmt.Errorf("failed to unmarshal filter JSON: %w", err) } @@ -1147,7 +1145,7 @@ func (f *ResponsesToolFileSearchFilter) UnmarshalJSON(data []byte) error { f.ResponsesToolFileSearchCompoundFilter = nil // Unmarshal into the comparison filter - if err := sonic.Unmarshal(data, f.ResponsesToolFileSearchComparisonFilter); err != nil { + if err := Unmarshal(data, f.ResponsesToolFileSearchComparisonFilter); err != nil { return fmt.Errorf("failed to unmarshal comparison filter: %w", err) } @@ -1165,7 +1163,7 @@ func (f *ResponsesToolFileSearchFilter) UnmarshalJSON(data []byte) error { f.ResponsesToolFileSearchComparisonFilter = nil // Unmarshal into the compound filter - if err := sonic.Unmarshal(data, f.ResponsesToolFileSearchCompoundFilter); err != nil { + if err := Unmarshal(data, f.ResponsesToolFileSearchCompoundFilter); err != nil { return fmt.Errorf("failed to unmarshal compound filter: %w", err) } @@ -1273,7 +1271,7 @@ func (as ResponsesToolMCPAllowedToolsApprovalSetting) MarshalJSON() ([]byte, err } if as.Setting != nil { - return sonic.Marshal(*as.Setting) + return Marshal(*as.Setting) } if as.Always != nil || as.Never != nil { // Marshal as an object with always/never fields @@ -1284,17 +1282,17 @@ func (as ResponsesToolMCPAllowedToolsApprovalSetting) MarshalJSON() ([]byte, err if as.Never != nil { obj["never"] = as.Never } - return sonic.Marshal(obj) + return Marshal(obj) } // If all are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ResponsesToolMCPAllowedToolsApprovalSetting func (as *ResponsesToolMCPAllowedToolsApprovalSetting) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var settingStr string - if err := sonic.Unmarshal(data, &settingStr); err == nil { + if err := Unmarshal(data, &settingStr); err == nil { as.Setting = &settingStr return nil } @@ -1304,7 +1302,7 @@ func (as *ResponsesToolMCPAllowedToolsApprovalSetting) UnmarshalJSON(data []byte Always *ResponsesToolMCPAllowedToolsApprovalFilter `json:"always,omitempty"` Never *ResponsesToolMCPAllowedToolsApprovalFilter `json:"never,omitempty"` } - if err := sonic.Unmarshal(data, &obj); err == nil { + if err := Unmarshal(data, &obj); err == nil { as.Always = obj.Always as.Never = obj.Never return nil diff --git a/core/schemas/speech.go b/core/schemas/speech.go index 6dcf4fec86..afea821177 100644 --- a/core/schemas/speech.go +++ b/core/schemas/speech.go @@ -2,8 +2,6 @@ package schemas import ( "fmt" - - "github.com/bytedance/sonic" ) type BifrostSpeechRequest struct { @@ -81,13 +79,13 @@ func (vi *SpeechVoiceInput) MarshalJSON() ([]byte, error) { } if vi.Voice != nil { - return sonic.Marshal(*vi.Voice) + return Marshal(*vi.Voice) } if len(vi.MultiVoiceConfig) > 0 { - return sonic.Marshal(vi.MultiVoiceConfig) + return Marshal(vi.MultiVoiceConfig) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for SpeechVoiceInput. @@ -100,14 +98,14 @@ func (vi *SpeechVoiceInput) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { + if err := Unmarshal(data, &stringContent); err == nil { vi.Voice = &stringContent return nil } // Try to unmarshal as an array of VoiceConfig objects var voiceConfigs []VoiceConfig - if err := sonic.Unmarshal(data, &voiceConfigs); err == nil { + if err := Unmarshal(data, &voiceConfigs); err == nil { // Validate each VoiceConfig and build a new slice deterministically validConfigs := make([]VoiceConfig, 0, len(voiceConfigs)) for _, config := range voiceConfigs { diff --git a/core/schemas/textcompletions.go b/core/schemas/textcompletions.go index c65f0db2f9..071673b51a 100644 --- a/core/schemas/textcompletions.go +++ b/core/schemas/textcompletions.go @@ -2,8 +2,6 @@ package schemas import ( "fmt" - - "github.com/bytedance/sonic" ) // BifrostTextCompletionRequest is the request struct for text completion requests @@ -96,20 +94,20 @@ func (t *TextCompletionInput) MarshalJSON() ([]byte, error) { return nil, fmt.Errorf("text completion input must set exactly one of: prompt_str or prompt_array") } if t.PromptStr != nil { - return sonic.Marshal(*t.PromptStr) + return Marshal(*t.PromptStr) } - return sonic.Marshal(t.PromptArray) + return Marshal(t.PromptArray) } func (t *TextCompletionInput) UnmarshalJSON(data []byte) error { var prompt string - if err := sonic.Unmarshal(data, &prompt); err == nil { + if err := Unmarshal(data, &prompt); err == nil { t.PromptStr = &prompt t.PromptArray = nil return nil } var promptArray []string - if err := sonic.Unmarshal(data, &promptArray); err == nil { + if err := Unmarshal(data, &promptArray); err == nil { t.PromptStr = nil t.PromptArray = promptArray return nil diff --git a/core/schemas/utils.go b/core/schemas/utils.go index e8fc6976e5..e1a8c46470 100644 --- a/core/schemas/utils.go +++ b/core/schemas/utils.go @@ -7,8 +7,6 @@ import ( "regexp" "strconv" "strings" - - "github.com/bytedance/sonic" ) // Ptr creates a pointer to any value. @@ -267,7 +265,7 @@ func JsonifyInput(input interface{}) string { if input == nil { return "{}" } - jsonString, err := sonic.MarshalString(input) + jsonString, err := MarshalString(input) if err != nil { return "{}" } diff --git a/docs/enterprise/setting-up-okta.mdx b/docs/enterprise/setting-up-okta.mdx index bf122ebaf9..cf0bb754c6 100644 --- a/docs/enterprise/setting-up-okta.mdx +++ b/docs/enterprise/setting-up-okta.mdx @@ -71,9 +71,21 @@ Configure the following settings for your application: --- -## Step 3: Configure Authorization Server +## Step 3: Configure Authorization Server (optional) -Bifrost uses Okta's Authorization Server to issue tokens. You can use the default authorization server or create a custom one. + +The default authorization server (`/oauth2/default`) is available to all Okta plans and **supports custom claims**, including role claims. The API Access Management paid add-on is only required to create additional custom authorization servers beyond the default. + + +Bifrost uses Okta's Authorization Server to issue tokens. You have three options: + +1. **Use `/oauth2/default` with role claims (recommended)** — Complete Steps 4-7 to configure custom role claims on the default authorization server. This enables automatic RBAC synchronization. + +2. **Use `/oauth2/default` without role claims** — Skip Steps 4-7. The first user to sign in automatically receives the Admin role and can manage RBAC for all subsequent users through the Bifrost dashboard. + +3. **Skip Step 3 entirely** — Authorization is not configured through Okta. You'll need an alternative authentication mechanism. + +### Configuring the Authorization Server 1. Navigate to **Security** → **API** 2. Click on **Authorization Servers** @@ -224,6 +236,10 @@ Adjust the group filter expression based on your naming convention. The example 4. Click **Save and Go Back** + +Role claims are available only when you configure custom claims on your authorization server. Ensure you add role claims to your chosen authorization server (for example, `/oauth2/default`) to enable RBAC. If you skipped Steps 4-7, the first user to sign in automatically receives the **Admin** role and can manage RBAC for all subsequent users through the Bifrost dashboard. + + --- ## Step 8: Configure Bifrost @@ -239,7 +255,7 @@ Now configure Bifrost to use Okta as the identity provider. | Field | Value | |-------|-------| | **Client ID** | Your Okta application Client ID | -| **Issuer URL** | `https://your-domain.okta.com/oauth2/default` | +| **Issuer URL** | Issuer URL | | **Audience** | Your API audience (e.g., `api://default` or custom) | | **Client Secret** | Your Okta application Client Secret (optional, for token revocation) | diff --git a/docs/plugins/migration-guide.mdx b/docs/plugins/migration-guide.mdx index 20313d7081..b2ceab6d67 100644 --- a/docs/plugins/migration-guide.mdx +++ b/docs/plugins/migration-guide.mdx @@ -6,7 +6,7 @@ icon: "arrow-up-right-dots" ## Overview -Bifrost v1.4.x introduces a new plugin interface for HTTP transport layer interception. This guide helps you migrate existing plugins from the v1.3.x `TransportInterceptor` pattern to the v1.4.x `HTTPTransportMiddleware` pattern. +Bifrost v1.4.x introduces a new plugin interface for HTTP transport layer interception. This guide helps you migrate existing plugins from the v1.3.x `TransportInterceptor` pattern to the v1.4.x `HTTPTransportIntercept` pattern. If your plugin doesn't use `TransportInterceptor`, no migration is needed. The `PreHook`, `PostHook`, `Init`, `GetName`, and `Cleanup` functions remain unchanged. @@ -14,39 +14,42 @@ If your plugin doesn't use `TransportInterceptor`, no migration is needed. The ` ## What Changed? -The HTTP transport interception mechanism changed from a simple function that receives and returns headers/body to a middleware pattern that wraps the entire request handler chain. +The HTTP transport interception mechanism changed from a simple function that receives and returns headers/body to a serializable intercept pattern that works with both native `.so` plugins and WASM plugins. ### Key Differences -| Aspect | v1.3.x (TransportInterceptor) | v1.4.x+ (HTTPTransportMiddleware) | -|--------|-------------------------------|-----------------------------------| -| Function signature | `TransportInterceptor(ctx, url, headers, body)` | `HTTPTransportMiddleware()` | -| Return type | `(headers, body, error)` | `BifrostHTTPMiddleware` | -| Access scope | Headers and body as maps | Full `*fasthttp.RequestCtx` | -| Flow control | Implicit (return modified values) | Explicit (`next(ctx)` call) | -| Capability | Request modification only | Request AND response modification | +| Aspect | v1.3.x (TransportInterceptor) | v1.4.x+ (HTTPTransportIntercept) | +|--------|-------------------------------|----------------------------------| +| Signature | `TransportInterceptor(ctx, url, headers, body)` | `HTTPTransportIntercept(ctx, req)` | +| Return type | `(headers, body, error)` | `(*HTTPResponse, error)` | +| Request type | Separate `headers map`, `body map` | Unified `*HTTPRequest` struct | +| Modification | Return modified maps | Modify `req` in-place | +| Short-circuit | Return error | Return `*HTTPResponse` | +| WASM support | No | Yes | +| Context | Limited `BifrostContext` | Full `*BifrostContext` | ### Why the Change? -The new middleware pattern provides: +The new intercept pattern provides: -1. **Full HTTP control** - Access to the complete `*fasthttp.RequestCtx` including method, path, query params, and all headers -2. **Response interception** - Ability to modify responses after they return from downstream handlers -3. **Better composability** - Standard middleware pattern that chains naturally with other handlers -4. **More flexibility** - Can short-circuit requests, add timing, implement retries, etc. +1. **WASM plugin support** - Serializable types work across WASM boundary +2. **Simpler API** - No middleware wrapper, direct function call +3. **Better testability** - No fasthttp dependency in plugin tests +4. **Full context access** - BifrostContext available for request metadata +5. **Custom response short-circuits** - Return a full response to short-circuit ## Migration Steps ### Step 1: Update Imports -Add the `fasthttp` import to your plugin: +Remove the `fasthttp` import if present: ```go import ( "fmt" "github.com/maximhq/bifrost/core/schemas" - "github.com/valyala/fasthttp" // Add this import + // Remove: "github.com/valyala/fasthttp" ) ``` @@ -70,29 +73,26 @@ func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[s **After (v1.4.x+):** ```go -// HTTPTransportMiddleware returns a middleware for HTTP transport -func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - // Add custom header - ctx.Request.Header.Set("X-Custom-Header", "value") - - // Modify body (if needed) - // Note: Body modification requires parsing and re-serializing - // ctx.Request.SetBody(modifiedBody) - - // Call next handler in chain - next(ctx) - - // Can also modify response here after next() returns - } - } +// HTTPTransportIntercept intercepts requests at the transport layer +// Modify req in-place. Return (*HTTPResponse, nil) to short-circuit. +func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + // Add custom header (in-place modification) + req.Headers["X-Custom-Header"] = "value" + + // Modify body (in-place modification) + var body map[string]any + sonic.Unmarshal(req.Body, &body) + body["custom_field"] = "custom_value" + req.Body, _ = sonic.Marshal(body) + + // Return nil to continue, or return &HTTPResponse{} to short-circuit + return nil, nil } ``` ### Step 3: Update Body Modification Logic -In v1.3.x, you received the body as a `map[string]any`. In v1.4.x, you work with raw bytes: +In v1.3.x, you received the body as a `map[string]any`. In v1.4.x, you work with `req.Body` bytes: **Before (v1.3.x):** @@ -109,46 +109,16 @@ func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[s ```go import "github.com/bytedance/sonic" -func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - // Parse existing body - var body map[string]any - if err := sonic.Unmarshal(ctx.Request.Body(), &body); err == nil { - // Modify body - body["model"] = "gpt-4" - - // Re-serialize and set - if newBody, err := sonic.Marshal(body); err == nil { - ctx.Request.SetBody(newBody) - } - } - - next(ctx) - } - } -} -``` - -### Step 4: Handle Response Modification (New Capability) - -The new pattern allows you to modify responses, which wasn't possible in v1.3.x: - -```go -func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - // Before request - startTime := time.Now() - - // Process request - next(ctx) - - // After response - NEW CAPABILITY - duration := time.Since(startTime) - ctx.Response.Header.Set("X-Processing-Time", duration.String()) - } +func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + // Parse body + var body map[string]any + if err := sonic.Unmarshal(req.Body, &body); err == nil { + // Modify body + body["model"] = "gpt-4" + // Update req.Body in-place + req.Body, _ = sonic.Marshal(body) } + return nil, nil } ``` @@ -164,8 +134,8 @@ return headers, body, nil **v1.4.x+:** ```go -ctx.Request.Header.Set("Authorization", "Bearer " + token) -next(ctx) +req.Headers["Authorization"] = "Bearer " + token +return nil, nil ``` ### Reading Headers @@ -177,7 +147,7 @@ apiKey := headers["X-API-Key"] **v1.4.x+:** ```go -apiKey := string(ctx.Request.Header.Peek("X-API-Key")) +apiKey := req.Headers["X-API-Key"] ``` ### Conditional Processing @@ -195,21 +165,16 @@ func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[s **v1.4.x+:** ```go -func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - if string(ctx.Request.Header.Peek("X-Skip-Processing")) == "true" { - next(ctx) - return - } - // Process... - next(ctx) - } +func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + if req.Headers["X-Skip-Processing"] == "true" { + return nil, nil // Continue without modification } + // Process... + return nil, nil } ``` -### Error Handling +### Error Handling / Short-Circuit **v1.3.x:** ```go @@ -223,17 +188,38 @@ func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[s **v1.4.x+:** ```go -func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - if len(ctx.Request.Header.Peek("X-API-Key")) == 0 { - ctx.SetStatusCode(401) - ctx.SetBodyString(`{"error": "missing API key"}`) - return // Don't call next - short-circuit the request - } - next(ctx) - } +func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + if req.Headers["X-API-Key"] == "" { + // Return a custom response to short-circuit + return &schemas.HTTPResponse{ + StatusCode: 401, + Headers: map[string]string{"Content-Type": "application/json"}, + Body: []byte(`{"error": "missing API key"}`), + }, nil } + return nil, nil +} +``` + +### Accessing Request Method and Path + +**v1.3.x:** +```go +// url parameter contained the full URL +func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { + // Limited access to URL + return headers, body, nil +} +``` + +**v1.4.x+:** +```go +func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + // Full access to request properties + method := req.Method // "GET", "POST", etc. + path := req.Path // "/v1/chat/completions" + query := req.Query // map[string]string of query params + return nil, nil } ``` @@ -256,9 +242,9 @@ func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { -d '{"model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello"}]}' ``` -4. **Verify logs show the new middleware being called:** +4. **Verify logs show the new intercept being called:** ``` - HTTPTransportMiddleware called + HTTPTransportIntercept called PreHook called PostHook called ``` @@ -270,33 +256,38 @@ func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { **Error:** `plugin: symbol TransportInterceptor not found` This error occurs if Bifrost v1.4.x is looking for the old function. Make sure: -1. You've updated to `HTTPTransportMiddleware` -2. The function signature matches exactly +1. You've updated to `HTTPTransportIntercept` +2. The function signature matches exactly: `func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error)` 3. You've rebuilt the plugin with the correct core version ### Body modification not working -Make sure you're calling `ctx.Request.SetBody()` with the serialized bytes, not the map directly: +Make sure you're assigning back to `req.Body`: ```go -// Wrong -ctx.Request.SetBody(body) // body is map[string]any - -// Correct -bodyBytes, _ := sonic.Marshal(body) -ctx.Request.SetBody(bodyBytes) +// Wrong - body changes lost +var body map[string]any +sonic.Unmarshal(req.Body, &body) +body["model"] = "gpt-4" +// Missing: req.Body = ... + +// Correct - body changes applied +var body map[string]any +sonic.Unmarshal(req.Body, &body) +body["model"] = "gpt-4" +req.Body, _ = sonic.Marshal(body) // Assign back! ``` ### Headers not being set -Remember that `fasthttp` header methods are case-sensitive for custom headers: +Make sure you're modifying `req.Headers` directly: ```go -// Set header -ctx.Request.Header.Set("X-Custom-Header", "value") +// Set header (case-sensitive keys) +req.Headers["X-Custom-Header"] = "value" -// Read header - use Peek for []byte or string conversion -value := string(ctx.Request.Header.Peek("X-Custom-Header")) +// Read header +value := req.Headers["X-Custom-Header"] ``` ## Need Help? diff --git a/docs/plugins/writing-plugin.mdx b/docs/plugins/writing-plugin.mdx index aa4c9e58be..528219aa35 100644 --- a/docs/plugins/writing-plugin.mdx +++ b/docs/plugins/writing-plugin.mdx @@ -71,7 +71,6 @@ import ( "fmt" "github.com/maximhq/bifrost/core/schemas" - "github.com/valyala/fasthttp" ) // Init is called when the plugin is loaded @@ -87,18 +86,15 @@ func GetName() string { return "Hello World Plugin" } -// HTTPTransportMiddleware returns a middleware for HTTP transport +// HTTPTransportIntercept intercepts requests at the HTTP transport layer +// Modify req in-place. Return (*HTTPResponse, nil) to short-circuit. // Only called when using HTTP transport (bifrost-http) -func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - fmt.Println("HTTPTransportMiddleware called") - // Modify request headers/body via ctx.Request before calling next - // Call next handler in the chain - next(ctx) - // Can also modify response via ctx.Response after next returns - } - } +func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + fmt.Println("HTTPTransportIntercept called") + // Modify request in-place (headers, body, query params) + req.Headers["X-Custom-Header"] = "custom-value" + // Return nil to continue, or return &schemas.HTTPResponse{} to short-circuit + return nil, nil } // PreHook is called before the request is sent to the provider @@ -212,15 +208,20 @@ Returns a unique identifier for your plugin. This name appears in logs and statu -#### `HTTPTransportMiddleware()` +#### `HTTPTransportIntercept(ctx, req)` -**HTTP transport only.** Returns a middleware that wraps the HTTP request handler chain. Use this to: -- Intercept and modify requests before they enter Bifrost core -- Intercept and modify responses before they're returned to clients -- Implement authentication or logging at the transport layer -- Access raw `*fasthttp.RequestCtx` for full HTTP control +**HTTP transport only.** Intercepts requests at the transport layer. Use this to: +- Modify request headers, body, or query params in-place +- Short-circuit with a custom response +- Access `BifrostContext` for request metadata +- Works with both native .so and WASM plugins -The middleware pattern requires calling `next(ctx)` to pass control to subsequent handlers. +Key points: +- Receives serializable `*HTTPRequest` (not raw fasthttp) +- Modify `req.Headers`, `req.Body`, `req.Query` directly +- Return `(nil, nil)` to continue to next plugin/handler +- Return `(*HTTPResponse, nil)` to short-circuit with response +- Return `(nil, error)` to short-circuit with error This function is **only called** when using `bifrost-http`. It's **not invoked** when using Bifrost as a Go SDK. @@ -429,7 +430,7 @@ Check the logs for plugin hook calls: ``` -HTTPTransportMiddleware called +HTTPTransportIntercept called PreHook called PostHook called ``` diff --git a/examples/plugins/hello-world-wasm-go/Makefile b/examples/plugins/hello-world-wasm-go/Makefile new file mode 100644 index 0000000000..713f911bd3 --- /dev/null +++ b/examples/plugins/hello-world-wasm-go/Makefile @@ -0,0 +1,74 @@ +.PHONY: all build clean help check-tinygo + +# Colors +COLOR_RESET = \033[0m +COLOR_INFO = \033[36m +COLOR_SUCCESS = \033[32m +COLOR_WARNING = \033[33m +COLOR_ERROR = \033[31m +COLOR_BOLD = \033[1m + +# Plugin configuration +PLUGIN_NAME = hello-world +OUTPUT_DIR = build +OUTPUT = $(OUTPUT_DIR)/$(PLUGIN_NAME).wasm + +# TinyGo build flags +TINYGO_TARGET = wasi +TINYGO_SCHEDULER = none + +help: ## Show this help message + @echo '$(COLOR_BOLD)Hello World WASM Plugin$(COLOR_RESET)' + @echo '' + @echo '$(COLOR_BOLD)Usage:$(COLOR_RESET) make [target]' + @echo '' + @echo '$(COLOR_BOLD)Prerequisites:$(COLOR_RESET)' + @echo ' - TinyGo (https://tinygo.org/getting-started/install/)' + @echo ' macOS: brew install tinygo' + @echo ' Linux: See TinyGo installation docs' + @echo '' + @echo '$(COLOR_BOLD)Available targets:$(COLOR_RESET)' + @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf " $(COLOR_INFO)%-15s$(COLOR_RESET) %s\n", $$1, $$2}' $(MAKEFILE_LIST) + +check-tinygo: ## Check if TinyGo is installed + @which tinygo > /dev/null 2>&1 || (echo "$(COLOR_ERROR)Error: TinyGo is not installed$(COLOR_RESET)"; \ + echo "$(COLOR_INFO)Install TinyGo:$(COLOR_RESET)"; \ + echo " macOS: brew install tinygo"; \ + echo " Linux: See https://tinygo.org/getting-started/install/"; \ + exit 1) + @echo "$(COLOR_SUCCESS)✓ TinyGo found: $$(tinygo version)$(COLOR_RESET)" + +build: check-tinygo ## Build the WASM plugin + @mkdir -p $(OUTPUT_DIR) + @echo "$(COLOR_INFO)Building WASM plugin...$(COLOR_RESET)" + GOWORK=off tinygo build -o $(OUTPUT) -target=$(TINYGO_TARGET) -scheduler=$(TINYGO_SCHEDULER) . + @echo "$(COLOR_SUCCESS)✓ Plugin built successfully: $(OUTPUT)$(COLOR_RESET)" + @ls -lh $(OUTPUT) | awk '{print " Size: " $$5}' + +build-optimized: check-tinygo ## Build the WASM plugin with size optimizations + @mkdir -p $(OUTPUT_DIR) + @echo "$(COLOR_INFO)Building optimized WASM plugin...$(COLOR_RESET)" + GOWORK=off tinygo build -o $(OUTPUT) -target=$(TINYGO_TARGET) -scheduler=$(TINYGO_SCHEDULER) -no-debug -gc=leaking . + @echo "$(COLOR_SUCCESS)✓ Optimized plugin built: $(OUTPUT)$(COLOR_RESET)" + @ls -lh $(OUTPUT) | awk '{print " Size: " $$5}' + +clean: ## Remove build artifacts + @echo "$(COLOR_INFO)Cleaning build artifacts...$(COLOR_RESET)" + @rm -rf $(OUTPUT_DIR) + @echo "$(COLOR_SUCCESS)✓ Clean complete$(COLOR_RESET)" + +info: ## Show build information + @echo "$(COLOR_BOLD)Build Configuration$(COLOR_RESET)" + @echo " Plugin Name: $(PLUGIN_NAME)" + @echo " Output: $(OUTPUT)" + @echo " Target: $(TINYGO_TARGET)" + @echo " Scheduler: $(TINYGO_SCHEDULER)" + @echo "" + @if [ -f "$(OUTPUT)" ]; then \ + echo "$(COLOR_SUCCESS)Plugin exists:$(COLOR_RESET)"; \ + ls -lh $(OUTPUT) | awk '{print " " $$9 " (" $$5 ")"}'; \ + else \ + echo "$(COLOR_WARNING)Plugin not built yet$(COLOR_RESET)"; \ + fi + +.DEFAULT_GOAL := help diff --git a/examples/plugins/hello-world-wasm-go/README.md b/examples/plugins/hello-world-wasm-go/README.md new file mode 100644 index 0000000000..3b12233419 --- /dev/null +++ b/examples/plugins/hello-world-wasm-go/README.md @@ -0,0 +1,170 @@ +# Hello World WASM Plugin + +A minimal example of a Bifrost plugin written in Go and compiled to WebAssembly using TinyGo. + +## Prerequisites + +### TinyGo Installation + +TinyGo is required to compile Go code to WebAssembly with a small binary size. + +**macOS:** +```bash +brew install tinygo +``` + +**Linux (Ubuntu/Debian):** +```bash +wget https://github.com/tinygo-org/tinygo/releases/download/v0.32.0/tinygo_0.32.0_amd64.deb +sudo dpkg -i tinygo_0.32.0_amd64.deb +``` + +**Other platforms:** +See [TinyGo Installation Guide](https://tinygo.org/getting-started/install/) + +## Building + +```bash +# Build the WASM plugin +make build + +# Build with size optimizations +make build-optimized + +# Clean build artifacts +make clean +``` + +The compiled plugin will be at `build/hello-world.wasm`. + +## Plugin Structure + +WASM plugins must export the following functions: + +| Export | Signature | Description | +|--------|-----------|-------------| +| `plugin_malloc` | `(size: u32) -> u32` | Allocate memory for host to write data (or `malloc` for non-TinyGo) | +| `plugin_free` | `(ptr: u32)` | Free allocated memory (optional, or `free` for non-TinyGo) | +| `get_name` | `() -> u64` | Returns packed ptr+len of plugin name | +| `http_transport_intercept` | `(ctx_ptr, ctx_len, req_ptr, req_len: u32) -> u64` | HTTP transport intercept | +| `pre_hook` | `(ctx_ptr, ctx_len, req_ptr, req_len: u32) -> u64` | Pre-request hook | +| `post_hook` | `(ctx_ptr, ctx_len, resp_ptr, resp_len, err_ptr, err_len: u32) -> u64` | Post-response hook | +| `cleanup` | `() -> i32` | Cleanup resources (0 = success) | +| `init` | `(config_ptr, config_len: u32) -> i32` | Initialize with config (optional) | + +### Return Value Format + +Functions returning data use a packed `u64` format: +- Upper 32 bits: pointer to data in WASM memory +- Lower 32 bits: length of data + +### Data Exchange + +All complex data is exchanged as JSON: + +**HTTPTransportIntercept Input:** +- `ctx`: `{"request_id": "..."}` (context info) +- `req`: HTTP request JSON +```json +{ + "method": "POST", + "path": "/v1/chat/completions", + "headers": {"Content-Type": "application/json"}, + "query": {}, + "body": "base64-encoded-body" +} +``` + +**HTTPTransportIntercept Output:** +```json +{ + "response": null, + "error": "" +} +``` +To short-circuit, return a response: +```json +{ + "response": { + "status_code": 401, + "headers": {"Content-Type": "application/json"}, + "body": "base64-encoded-body" + }, + "error": "" +} +``` + +**PreHook Input:** +- `ctx`: `{"request_id": "..."}` (context info) +- `req`: Bifrost request JSON + +**PreHook Output:** +```json +{ + "request": { ... }, + "short_circuit": null, + "error": "" +} +``` + +**PostHook Input:** +- `ctx`: Context JSON +- `resp`: Bifrost response JSON +- `err`: Bifrost error JSON (or null) + +**PostHook Output:** +```json +{ + "response": { ... }, + "bifrost_error": null, + "error": "" +} +``` + +## Usage with Bifrost + +Configure the plugin in your Bifrost config: + +```json +{ + "plugins": [ + { + "path": "/path/to/hello-world.wasm", + "name": "hello-world-wasm", + "enabled": true + } + ] +} +``` + +Or load from URL: + +```json +{ + "plugins": [ + { + "path": "https://example.com/plugins/hello-world.wasm", + "name": "hello-world-wasm", + "enabled": true + } + ] +} +``` + +## Limitations + +WASM plugins have some limitations compared to native `.so` plugins: + +1. **Performance**: JSON serialization/deserialization adds overhead compared to native plugins. + +2. **Memory**: WASM modules have a linear memory model with limited addressing. + +3. **TinyGo Constraints**: Some Go standard library features are not available in TinyGo. + +## Benefits + +1. **Cross-platform**: Single `.wasm` binary runs on any OS/architecture +2. **Security**: WASM provides sandboxed execution +3. **No CGO**: Pure Go compilation, no C dependencies needed on the host +4. **Portability**: Easy to distribute and deploy +5. **Full feature parity**: HTTP transport intercept, PreHook, and PostHook all supported \ No newline at end of file diff --git a/examples/plugins/hello-world-wasm-go/go.mod b/examples/plugins/hello-world-wasm-go/go.mod new file mode 100644 index 0000000000..64a44e2780 --- /dev/null +++ b/examples/plugins/hello-world-wasm-go/go.mod @@ -0,0 +1,32 @@ +module github.com/maximhq/bifrost/examples/plugins/hello-world-wasm + +go 1.25.5 + +require github.com/maximhq/bifrost/core v0.0.0-00010101000000-000000000000 + +replace github.com/maximhq/bifrost/core => ../../../core + +require ( + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.14.2 // indirect + github.com/bytedance/sonic/loader v0.4.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/klauspost/compress v1.18.2 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.1 // indirect + github.com/mark3labs/mcp-go v0.43.2 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.68.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/arch v0.23.0 // indirect + golang.org/x/sys v0.39.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/examples/plugins/hello-world-wasm-go/go.sum b/examples/plugins/hello-world-wasm-go/go.sum new file mode 100644 index 0000000000..fee36f9db2 --- /dev/null +++ b/examples/plugins/hello-world-wasm-go/go.sum @@ -0,0 +1,78 @@ +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPIIE= +github.com/bytedance/sonic v1.14.2/go.mod h1:T80iDELeHiHKSc0C9tubFygiuXoGzrkjKzX2quAx980= +github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2NYzevs+o= +github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= +github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= +github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= +github.com/maximhq/bifrost/core v1.3.3 h1:r2llMAfzIHeSxwY2L55UaSOsY17JSg5zYcqF2JtaRVY= +github.com/maximhq/bifrost/core v1.3.3/go.mod h1:abKQRnJQPZz8/UMxCcbuNHEyq19Db+IX4KlGJdlLY8E= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.68.0 h1:v12Nx16iepr8r9ySOwqI+5RBJ/DqTxhOy1HrHoDFnok= +github.com/valyala/fasthttp v1.68.0/go.mod h1:5EXiRfYQAoiO/khu4oU9VISC/eVY6JqmSpPJoHCKsz4= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/arch v0.23.0 h1:lKF64A2jF6Zd8L0knGltUnegD62JMFBiCPBmQpToHhg= +golang.org/x/arch v0.23.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/plugins/hello-world-wasm-go/main.go b/examples/plugins/hello-world-wasm-go/main.go new file mode 100644 index 0000000000..66eaeefd3d --- /dev/null +++ b/examples/plugins/hello-world-wasm-go/main.go @@ -0,0 +1,180 @@ +// Package main provides a hello-world WASM plugin example for Bifrost. +// This plugin demonstrates the basic structure and exports required for WASM plugins. +// +// Build with TinyGo: +// +// tinygo build -o build/hello-world.wasm -target=wasi -scheduler=none main.go +package main + +import ( + "encoding/json" +) + +// ============================================================================ +// Plugin Exports +// ============================================================================ + +//export get_name +func get_name() uint64 { + return writeBytes([]byte("Hello World WASM Plugin")) +} + +//export init +func init_plugin(configPtr, configLen uint32) int32 { + println("WASM Plugin: Init called") + if configLen > 0 { + configData := readInput(configPtr, configLen) + println("WASM Plugin: Config received:", string(configData)) + } + return 0 +} + +//export http_intercept +func http_intercept(inputPtr, inputLen uint32) uint64 { + println("WASM Plugin: http_intercept called") + + inputData := readInput(inputPtr, inputLen) + if inputData == nil { + return writeError("no input data") + } + + // Parse input + var input HTTPInterceptInput + if err := json.Unmarshal(inputData, &input); err != nil { + println("WASM Plugin: parse error:", err.Error()) + return writeError("parse error: " + err.Error()) + } + + // Log parsed data + println("WASM Plugin: HTTP", input.Request.Method, input.Request.Path) + if ct, ok := input.Request.Headers["content-type"]; ok { + println("WASM Plugin: Content-Type:", ct) + } + input.Context["from-http"] = "123" + // Return pass-through + output := HTTPInterceptOutput{ + Context: input.Context, + Request: input.Request, + HasResponse: false, + Error: "", + } + + data, _ := json.Marshal(output) + return writeBytes(data) +} + +//export pre_hook +func pre_hook(inputPtr, inputLen uint32) uint64 { + println("WASM Plugin: pre_hook called") + + inputData := readInput(inputPtr, inputLen) + if inputData == nil { + return writePreHookError("no input data") + } + + println("WASM Plugin: Pre-hook input:", string(inputData)) + + // Parse input + var input PreHookInput + if err := json.Unmarshal(inputData, &input); err != nil { + println("WASM Plugin: parse error:", err.Error()) + return writePreHookError("parse error: " + err.Error()) + } + + // Print existing context + for k, v := range input.Context { + println("WASM Plugin: Context", k, "=", v) + } + + input.Context["from-pre-hook"] = "789" + + // Return with custom context value + output := PreHookOutput{ + Context: input.Context, + Request: input.Request, + HasShortCircuit: false, + Error: "", + } + + data, _ := json.Marshal(output) + return writeBytes(data) +} + +//export post_hook +func post_hook(inputPtr, inputLen uint32) uint64 { + println("WASM Plugin: post_hook called") + + inputData := readInput(inputPtr, inputLen) + if inputData == nil { + return writePostHookError("no input data") + } + + // Parse input + var input PostHookInput + if err := json.Unmarshal(inputData, &input); err != nil { + println("WASM Plugin: parse error:", err.Error()) + return writePostHookError("parse error: " + err.Error()) + } + + println("WASM Plugin: Post-hook input:", string(inputData)) + // Print existing context + for k, v := range input.Context { + println("WASM Plugin: Context", k, "=", v) + } + + // Parse response for logging + + if processed, ok := input.Context["wasm_plugin_processed"].(bool); ok && processed { + println("WASM Plugin: Pre-hook context value present") + } + + input.Context["from-post-hook"] = "456" + // Return pass-through + output := PostHookOutput{ + Context: input.Context, + Response: input.Response, + Error: input.Error, + HasError: false, + HookError: "", + } + + data, _ := json.Marshal(output) + return writeBytes(data) +} + +//export cleanup +func cleanup() int32 { + println("WASM Plugin: Cleanup called") + return 0 +} + +// Helper functions for error responses +func writeError(msg string) uint64 { + output := HTTPInterceptOutput{HasResponse: false, Error: msg} + data, _ := json.Marshal(output) + return writeBytes(data) +} + +func writePreHookError(msg string) uint64 { + output := PreHookOutput{ + Context: map[string]interface{}{}, + Request: nil, + HasShortCircuit: false, + Error: msg, + } + data, _ := json.Marshal(output) + return writeBytes(data) +} + +func writePostHookError(msg string) uint64 { + output := PostHookOutput{ + Context: map[string]interface{}{}, + Response: nil, + HasError: false, + HookError: msg, + } + data, _ := json.Marshal(output) + return writeBytes(data) +} + +func main() {} diff --git a/examples/plugins/hello-world-wasm-go/memory.go b/examples/plugins/hello-world-wasm-go/memory.go new file mode 100644 index 0000000000..63b8fa3af1 --- /dev/null +++ b/examples/plugins/hello-world-wasm-go/memory.go @@ -0,0 +1,89 @@ +package main + +import "unsafe" + +// ============================================================================ +// Memory Management +// ============================================================================ + +// heapSize is the fixed size of the pre-allocated heap. +// This must be large enough to handle all allocations during the plugin lifetime. +// The heap is never reallocated to ensure all pointers remain valid. +const heapSize = 4 * 1024 * 1024 // 4MB fixed heap + +// heapBase is a fixed-size buffer that is never reallocated. +// All allocations come from this buffer to ensure pointer stability. +var heapBase []byte + +// heapOffset tracks the next available position in heapBase. +var heapOffset uint32 = 0 + +// heapBasePtr caches the base pointer of heapBase for efficient offset-to-pointer conversion. +var heapBasePtr uintptr + +func init() { + // Pre-allocate the fixed heap once at startup. + // This ensures heapBase is never reallocated after pointers are handed out. + heapBase = make([]byte, heapSize) + heapBasePtr = uintptr(unsafe.Pointer(&heapBase[0])) +} + +//export plugin_malloc +func plugin_malloc(size uint32) uint32 { + if size == 0 { + return 0 + } + // Align to 8-byte boundary + alignedSize := (size + 7) &^ 7 + // Check if we have enough space (no reallocation allowed) + if heapOffset+alignedSize > uint32(len(heapBase)) { + // Allocation failure - heap exhausted + // Return 0 to indicate failure rather than reallocating + return 0 + } + // Return pointer to the allocated region + ptr := uint32(heapBasePtr + uintptr(heapOffset)) + heapOffset += alignedSize + return ptr +} + +//export plugin_free +func plugin_free(ptr uint32) { + // No-op: we use a simple bump allocator without individual frees. + // Memory is reclaimed when the plugin is unloaded. +} + +// plugin_reset resets the heap allocator, allowing memory to be reused. +// This should only be called when no allocated memory is in use. +// +//export plugin_reset +func plugin_reset() { + heapOffset = 0 +} + +func packResult(ptr uint32, length uint32) uint64 { + return (uint64(ptr) << 32) | uint64(length) +} + +func writeBytes(data []byte) uint64 { + if len(data) == 0 { + return 0 + } + // Allocate from the stable heap + ptr := plugin_malloc(uint32(len(data))) + if ptr == 0 { + // Allocation failed + return 0 + } + // Copy data into the allocated region + offset := ptr - uint32(heapBasePtr) + copy(heapBase[offset:offset+uint32(len(data))], data) + return packResult(ptr, uint32(len(data))) +} + +func readInput(ptr, length uint32) []byte { + if length == 0 { + return nil + } + return unsafe.Slice((*byte)(unsafe.Pointer(uintptr(ptr))), length) +} diff --git a/examples/plugins/hello-world-wasm-go/types.go b/examples/plugins/hello-world-wasm-go/types.go new file mode 100644 index 0000000000..6333b802e6 --- /dev/null +++ b/examples/plugins/hello-world-wasm-go/types.go @@ -0,0 +1,54 @@ +package main + +import "github.com/maximhq/bifrost/core/schemas" + +// ============================================================================ +// Input/Output Structs +// ============================================================================ + +// HTTPInterceptInput is the input for http_intercept +type HTTPInterceptInput struct { + Context map[string]interface{} `json:"context"` + Request *schemas.HTTPRequest `json:"request,omitempty"` +} + +// HTTPInterceptOutput is the output for http_intercept +type HTTPInterceptOutput struct { + Context map[string]interface{} `json:"context"` + Request *schemas.HTTPRequest `json:"request,omitempty"` + Response *schemas.HTTPResponse `json:"response,omitempty"` + HasResponse bool `json:"has_response"` + Error string `json:"error"` +} + +// PreHookInput is the input for pre_hook +type PreHookInput struct { + Context map[string]interface{} `json:"context"` + Request *schemas.BifrostRequest `json:"request,omitempty"` // Keep raw for pass-through +} + +// PreHookOutput is the output for pre_hook +type PreHookOutput struct { + Context map[string]interface{} `json:"context"` + Request *schemas.BifrostRequest `json:"request,omitempty"` + ShortCircuit *schemas.PluginShortCircuit `json:"short_circuit,omitempty"` + HasShortCircuit bool `json:"has_short_circuit"` + Error string `json:"error"` +} + +// PostHookInput is the input for post_hook +type PostHookInput struct { + Context map[string]interface{} `json:"context"` + Response *schemas.BifrostResponse `json:"response,omitempty"` + Error *schemas.BifrostError `json:"error,omitempty"` + HasError bool `json:"has_error"` +} + +// PostHookOutput is the output for post_hook +type PostHookOutput struct { + Context map[string]interface{} `json:"context"` + Response *schemas.BifrostResponse `json:"response,omitempty"` + Error *schemas.BifrostError `json:"error,omitempty"` + HasError bool `json:"has_error"` + HookError string `json:"hook_error"` +} diff --git a/examples/plugins/hello-world-wasm-rust/Cargo.toml b/examples/plugins/hello-world-wasm-rust/Cargo.toml new file mode 100644 index 0000000000..1b97bd30ab --- /dev/null +++ b/examples/plugins/hello-world-wasm-rust/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "hello-world-wasm-rust" +version = "0.1.0" +edition = "2021" +description = "A minimal Bifrost WASM plugin example in Rust" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" + +[profile.release] +opt-level = "s" +lto = true +strip = true +panic = "abort" diff --git a/examples/plugins/hello-world-wasm-rust/Makefile b/examples/plugins/hello-world-wasm-rust/Makefile new file mode 100644 index 0000000000..152dd8a39d --- /dev/null +++ b/examples/plugins/hello-world-wasm-rust/Makefile @@ -0,0 +1,80 @@ +.PHONY: all build build-optimized clean help check-rust + +# Colors +COLOR_RESET = \033[0m +COLOR_INFO = \033[36m +COLOR_SUCCESS = \033[32m +COLOR_WARNING = \033[33m +COLOR_ERROR = \033[31m +COLOR_BOLD = \033[1m + +# Plugin configuration +PLUGIN_NAME = hello-world +OUTPUT_DIR = build +OUTPUT = $(OUTPUT_DIR)/$(PLUGIN_NAME).wasm +TARGET = wasm32-unknown-unknown + +help: ## Show this help message + @echo '$(COLOR_BOLD)Hello World WASM Plugin (Rust)$(COLOR_RESET)' + @echo '' + @echo '$(COLOR_BOLD)Usage:$(COLOR_RESET) make [target]' + @echo '' + @echo '$(COLOR_BOLD)Prerequisites:$(COLOR_RESET)' + @echo ' - Rust with wasm32-unknown-unknown target' + @echo ' rustup target add wasm32-unknown-unknown' + @echo '' + @echo '$(COLOR_BOLD)Available targets:$(COLOR_RESET)' + @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf " $(COLOR_INFO)%-15s$(COLOR_RESET) %s\n", $$1, $$2}' $(MAKEFILE_LIST) + +check-rust: ## Check if Rust and WASM target are installed + @which cargo > /dev/null 2>&1 || (echo "$(COLOR_ERROR)Error: Rust/Cargo is not installed$(COLOR_RESET)"; \ + echo "$(COLOR_INFO)Install Rust: https://rustup.rs/$(COLOR_RESET)"; \ + exit 1) + @rustup target list --installed | grep -q $(TARGET) || (echo "$(COLOR_ERROR)Error: WASM target not installed$(COLOR_RESET)"; \ + echo "$(COLOR_INFO)Install with: rustup target add $(TARGET)$(COLOR_RESET)"; \ + exit 1) + @echo "$(COLOR_SUCCESS)✓ Rust found: $$(rustc --version)$(COLOR_RESET)" + @echo "$(COLOR_SUCCESS)✓ WASM target: $(TARGET)$(COLOR_RESET)" + +build: check-rust ## Build the WASM plugin + @mkdir -p $(OUTPUT_DIR) + @echo "$(COLOR_INFO)Building WASM plugin...$(COLOR_RESET)" + cargo build --release --target $(TARGET) + @cp target/$(TARGET)/release/hello_world_wasm_rust.wasm $(OUTPUT) + @echo "$(COLOR_SUCCESS)✓ Plugin built successfully: $(OUTPUT)$(COLOR_RESET)" + @ls -lh $(OUTPUT) | awk '{print " Size: " $$5}' + +build-optimized: check-rust ## Build with wasm-opt optimization (requires wasm-opt) + @mkdir -p $(OUTPUT_DIR) + @echo "$(COLOR_INFO)Building optimized WASM plugin...$(COLOR_RESET)" + cargo build --release --target $(TARGET) + @cp target/$(TARGET)/release/hello_world_wasm_rust.wasm $(OUTPUT) + @if which wasm-opt > /dev/null 2>&1; then \ + echo "$(COLOR_INFO)Running wasm-opt...$(COLOR_RESET)"; \ + wasm-opt -Os -o $(OUTPUT) $(OUTPUT); \ + else \ + echo "$(COLOR_WARNING)wasm-opt not found, skipping optimization$(COLOR_RESET)"; \ + fi + @echo "$(COLOR_SUCCESS)✓ Plugin built: $(OUTPUT)$(COLOR_RESET)" + @ls -lh $(OUTPUT) | awk '{print " Size: " $$5}' + +clean: ## Remove build artifacts + @echo "$(COLOR_INFO)Cleaning build artifacts...$(COLOR_RESET)" + @cargo clean + @rm -rf $(OUTPUT_DIR) + @echo "$(COLOR_SUCCESS)✓ Clean complete$(COLOR_RESET)" + +info: ## Show build information + @echo "$(COLOR_BOLD)Build Configuration$(COLOR_RESET)" + @echo " Plugin Name: $(PLUGIN_NAME)" + @echo " Output: $(OUTPUT)" + @echo " Target: $(TARGET)" + @echo "" + @if [ -f "$(OUTPUT)" ]; then \ + echo "$(COLOR_SUCCESS)Plugin exists:$(COLOR_RESET)"; \ + ls -lh $(OUTPUT) | awk '{print " " $$9 " (" $$5 ")"}'; \ + else \ + echo "$(COLOR_WARNING)Plugin not built yet$(COLOR_RESET)"; \ + fi + +.DEFAULT_GOAL := help diff --git a/examples/plugins/hello-world-wasm-rust/README.md b/examples/plugins/hello-world-wasm-rust/README.md new file mode 100644 index 0000000000..625794c422 --- /dev/null +++ b/examples/plugins/hello-world-wasm-rust/README.md @@ -0,0 +1,528 @@ +# Bifrost WASM Plugin (Rust) + +A comprehensive example of a Bifrost plugin written in Rust and compiled to WebAssembly. This plugin demonstrates proper structure definitions with serde, JSON parsing, context handling, and request/response modification patterns. + +## Prerequisites + +### Rust Installation + +Install Rust from [rustup.rs](https://rustup.rs/) and add the WASM target: + +```bash +# Install Rust (if not already installed) +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh + +# Add WASM target +rustup target add wasm32-unknown-unknown +``` + +### Optional: wasm-opt + +For smaller binaries, install `wasm-opt` from [binaryen](https://github.com/WebAssembly/binaryen): + +```bash +# macOS +brew install binaryen + +# Linux +apt install binaryen +``` + +## Building + +```bash +# Build the WASM plugin +make build + +# Build with wasm-opt optimization +make build-optimized + +# Clean build artifacts +make clean +``` + +The compiled plugin will be at `build/hello-world.wasm`. + +## File Structure + +``` +src/ +├── lib.rs # Plugin implementation (hooks) +├── memory.rs # Memory management utilities +└── types.rs # Type definitions (mirrors Go SDK) +``` + +## Plugin Structure + +WASM plugins must export the following functions: + +| Export | Signature | Description | +|--------|-----------|-------------| +| `malloc` | `(size: u32) -> u32` | Allocate memory for host to write data | +| `free` | `(ptr: u32, size: u32)` | Free allocated memory | +| `get_name` | `() -> u64` | Returns packed ptr+len of plugin name | +| `init` | `(config_ptr, config_len: u32) -> i32` | Initialize with config (optional) | +| `http_intercept` | `(input_ptr, input_len: u32) -> u64` | HTTP transport intercept | +| `pre_hook` | `(input_ptr, input_len: u32) -> u64` | Pre-request hook | +| `post_hook` | `(input_ptr, input_len: u32) -> u64` | Post-response hook | +| `cleanup` | `() -> i32` | Cleanup resources (0 = success) | + +### Return Value Format + +Functions returning data use a packed `u64` format: +- Upper 32 bits: pointer to data in WASM memory +- Lower 32 bits: length of data + +## Data Structures + +This plugin uses `serde` with derive macros for JSON serialization. All structures mirror the Go SDK types: + +### Context + +```rust +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostContext { + pub request_id: Option, + + // Custom values via HashMap + #[serde(flatten)] + pub values: HashMap, +} + +impl BifrostContext { + pub fn set_value(&mut self, key: &str, value: impl Into); + pub fn get_string(&self, key: &str) -> Option<&str>; + pub fn get_bool(&self, key: &str) -> Option; +} +``` + +### HTTP Transport Types + +```rust +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HTTPRequest { + pub method: String, + pub path: String, + pub headers: HashMap, + pub query: HashMap, + pub body: String, // base64 encoded +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HTTPResponse { + pub status_code: i32, + pub headers: HashMap, + pub body: String, // base64 encoded +} +``` + +### Chat Completion Types + +```rust +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ChatMessageRole { + User, + Assistant, + System, + Tool, + Developer, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ChatMessageContent { + Text(String), + Blocks(Vec), +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatMessage { + pub role: ChatMessageRole, + pub content: Option, + pub name: Option, + pub tool_call_id: Option, + pub tool_calls: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatParameters { + pub temperature: Option, + pub max_completion_tokens: Option, + pub top_p: Option, + pub frequency_penalty: Option, + pub presence_penalty: Option, + pub stop: Option>, + pub tools: Option>, + + #[serde(flatten)] + pub extra: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostChatRequest { + pub provider: String, + pub model: String, + pub input: Vec, + pub params: Option, + pub fallbacks: Option>, +} +``` + +### Response Types + +```rust +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct LLMUsage { + pub prompt_tokens: i32, + pub completion_tokens: i32, + pub total_tokens: i32, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ResponseChoice { + pub index: i32, + pub message: Option, + pub delta: Option, + pub finish_reason: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostChatResponse { + pub id: String, + pub model: String, + pub choices: Vec, + pub usage: Option, + pub created: Option, + pub object: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostResponse { + pub chat_response: Option, +} +``` + +### Error Types + +```rust +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ErrorField { + pub message: String, + #[serde(rename = "type")] + pub error_type: Option, + pub code: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostError { + pub error: ErrorField, + pub status_code: Option, + pub allow_fallbacks: Option, +} + +impl BifrostError { + pub fn new(message: &str) -> Self; + pub fn with_type(self, error_type: &str) -> Self; + pub fn with_code(self, code: &str) -> Self; + pub fn with_status(self, status: i32) -> Self; +} +``` + +### Short Circuit + +```rust +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PluginShortCircuit { + pub response: Option, + pub error: Option, +} +``` + +## Hook Input/Output Structures + +### http_intercept + +**Input:** +```json +{ + "context": { "request_id": "abc-123" }, + "request": { + "method": "POST", + "path": "/v1/chat/completions", + "headers": { "Content-Type": "application/json" }, + "query": {}, + "body": "" + } +} +``` + +**Output:** +```json +{ + "context": { "request_id": "abc-123" }, + "request": {}, + "response": { "status_code": 200, "headers": {}, "body": "" }, + "has_response": false, + "error": "" +} +``` + +### pre_hook + +**Input:** +```json +{ + "context": { "request_id": "abc-123" }, + "request": { + "provider": "openai", + "model": "gpt-4", + "input": [{ "role": "user", "content": "Hello" }], + "params": { "temperature": 0.7 } + } +} +``` + +**Output:** +```json +{ + "context": { "request_id": "abc-123", "plugin_processed": true }, + "request": {}, + "short_circuit": { + "response": { "chat_response": { ... } } + }, + "has_short_circuit": false, + "error": "" +} +``` + +### post_hook + +**Input:** +```json +{ + "context": { "request_id": "abc-123", "plugin_processed": true }, + "response": { + "chat_response": { + "id": "chatcmpl-123", + "model": "gpt-4", + "choices": [{ "index": 0, "message": { "role": "assistant", "content": "Hi!" } }], + "usage": { "prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15 } + } + }, + "error": {}, + "has_error": false +} +``` + +**Output:** +```json +{ + "context": { "request_id": "abc-123", "post_hook_completed": true }, + "response": {}, + "error": {}, + "has_error": false, + "hook_error": "" +} +``` + +## Usage Examples + +### Modifying Context + +```rust +#[no_mangle] +pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 { + let input_str = read_string(input_ptr, input_len); + let input: PreHookInput = serde_json::from_str(&input_str).unwrap(); + + let mut output = PreHookOutput { + context: input.context.clone(), + ..Default::default() + }; + + // Add custom values to context + output.context.set_value("plugin_processed", serde_json::json!(true)); + output.context.set_value("plugin_name", serde_json::json!("my-rust-plugin")); + + write_string(&serde_json::to_string(&output).unwrap()) +} +``` + +### Short-Circuit with Mock Response + +```rust +#[no_mangle] +pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 { + let input_str = read_string(input_ptr, input_len); + let input: PreHookInput = serde_json::from_str(&input_str).unwrap(); + + let (provider, model) = input.get_provider_model(); + + // Check if this should be mocked + if model == "mock-model" { + let mut output = PreHookOutput { + context: input.context.clone(), + has_short_circuit: true, + ..Default::default() + }; + + // Build mock response + let mock_response = BifrostResponse { + chat_response: Some(BifrostChatResponse { + id: format!("mock-{}", input.context.request_id.unwrap_or_default()), + model: "mock-model".to_string(), + choices: vec![ResponseChoice { + index: 0, + message: Some(ChatMessage { + role: ChatMessageRole::Assistant, + content: Some(ChatMessageContent::Text( + "This is a mock response!".to_string() + )), + ..Default::default() + }), + finish_reason: Some("stop".to_string()), + ..Default::default() + }], + usage: Some(LLMUsage { + prompt_tokens: 10, + completion_tokens: 15, + total_tokens: 25, + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + + output.short_circuit = Some(PluginShortCircuit { + response: Some(mock_response), + error: None, + }); + + return write_string(&serde_json::to_string(&output).unwrap()); + } + + // Pass through + let output = PreHookOutput { + context: input.context, + ..Default::default() + }; + write_string(&serde_json::to_string(&output).unwrap()) +} +``` + +### Short-Circuit with Error + +```rust +#[no_mangle] +pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 { + let input_str = read_string(input_ptr, input_len); + let input: PreHookInput = serde_json::from_str(&input_str).unwrap(); + + // Check rate limit (example) + if should_rate_limit(&input.context) { + let mut output = PreHookOutput { + context: input.context.clone(), + has_short_circuit: true, + ..Default::default() + }; + + output.short_circuit = Some(PluginShortCircuit { + response: None, + error: Some( + BifrostError::new("Rate limit exceeded") + .with_type("rate_limit") + .with_code("429") + .with_status(429) + ), + }); + + return write_string(&serde_json::to_string(&output).unwrap()); + } + + // Pass through + let output = PreHookOutput { + context: input.context, + ..Default::default() + }; + write_string(&serde_json::to_string(&output).unwrap()) +} +``` + +### Modifying Responses in post_hook + +```rust +#[no_mangle] +pub extern "C" fn post_hook(input_ptr: u32, input_len: u32) -> u64 { + let input_str = read_string(input_ptr, input_len); + let input: PostHookInput = serde_json::from_str(&input_str).unwrap(); + + let mut output = PostHookOutput { + context: input.context.clone(), + ..Default::default() + }; + + // Handle errors + if input.has_error { + output.has_error = true; + output.error = input.error.clone(); + + // Optionally modify the error + if let Some(mut error) = input.parse_error() { + error.error.message = format!("{} (via rust plugin)", error.error.message); + output.error = serde_json::to_value(&error).unwrap_or_default(); + } + + return write_string(&serde_json::to_string(&output).unwrap()); + } + + // Pass through or modify response + if let Some(mut response) = input.parse_response() { + if let Some(ref mut chat) = response.chat_response { + // Add a marker to the model name + chat.model = format!("{} (via rust-wasm)", chat.model); + } + output.response = serde_json::to_value(&response).unwrap_or_default(); + } + + write_string(&serde_json::to_string(&output).unwrap()) +} +``` + +## Usage with Bifrost + +Configure the plugin in your Bifrost config: + +```json +{ + "plugins": [ + { + "path": "/path/to/hello-world.wasm", + "name": "hello-world-wasm-rust", + "enabled": true, + "config": { + "custom_option": "value" + } + } + ] +} +``` + +## Testing + +The plugin includes unit tests that can be run with: + +```bash +cargo test +``` + +## Benefits + +1. **Performance**: Rust compiles to highly optimized WASM +2. **Safety**: Memory safety without garbage collection +3. **Small binaries**: Rust WASM binaries are typically very small +4. **Cross-platform**: Single `.wasm` binary runs on any OS/architecture +5. **Security**: WASM provides sandboxed execution +6. **Type Safety**: Strongly typed structures with serde derive macros +7. **Excellent JSON**: serde_json provides robust JSON handling diff --git a/examples/plugins/hello-world-wasm-rust/src/lib.rs b/examples/plugins/hello-world-wasm-rust/src/lib.rs new file mode 100644 index 0000000000..8816f9dd9e --- /dev/null +++ b/examples/plugins/hello-world-wasm-rust/src/lib.rs @@ -0,0 +1,291 @@ +//! Bifrost WASM Plugin for Rust +//! +//! This plugin demonstrates the proper structure for parsing inputs, +//! building responses, and handling context - similar to Go plugin patterns. +//! +//! Build with: cargo build --release --target wasm32-unknown-unknown + +mod memory; +mod types; + +use memory::{read_string, write_string}; +use types::*; + +// Global configuration storage +static mut PLUGIN_CONFIG: Option = None; + +// ============================================================================= +// Exported Plugin Functions +// ============================================================================= + +/// Return the plugin name +#[no_mangle] +pub extern "C" fn get_name() -> u64 { + write_string("hello-world-wasm-rust") +} + +/// Initialize the plugin with config +/// Returns 0 on success, non-zero on error +#[no_mangle] +pub extern "C" fn init(config_ptr: u32, config_len: u32) -> i32 { + let config_str = read_string(config_ptr, config_len); + + // Parse configuration + let config: PluginConfig = if config_str.is_empty() { + PluginConfig::default() + } else { + match serde_json::from_str(&config_str) { + Ok(c) => c, + Err(_) => return 1, // Config parse error + } + }; + + // Store configuration + unsafe { + PLUGIN_CONFIG = Some(config); + } + + 0 // Success +} + +/// HTTP transport intercept +/// Called at the HTTP layer before request enters Bifrost core. +/// Can modify headers, query params, or short-circuit with a response. +#[no_mangle] +pub extern "C" fn http_intercept(input_ptr: u32, input_len: u32) -> u64 { + let input_str = read_string(input_ptr, input_len); + + // Parse input + let input: HTTPInterceptInput = match serde_json::from_str(&input_str) { + Ok(i) => i, + Err(e) => { + let output = HTTPInterceptOutput { + error: format!("Failed to parse input: {}", e), + ..Default::default() + }; + return write_string(&serde_json::to_string(&output).unwrap_or_default()); + } + }; + + // Create output with context preserved + let output = HTTPInterceptOutput { + context: input.context, + request: input.request, + has_response: false, + ..Default::default() + }; + + // Example: Short-circuit health check endpoint + // Uncomment to test: + /* + if input.request.path == "/health" { + output.has_response = true; + output.response = Some(HTTPResponse { + status_code: 200, + headers: HashMap::new(), + body: base64::encode(r#"{"status":"ok"}"#), + }); + return write_string(&serde_json::to_string(&output).unwrap_or_default()); + } + */ + + // Pass through + write_string(&serde_json::to_string(&output).unwrap_or_default()) +} + +/// Pre-request hook +/// Called before request is sent to the provider. +/// Can modify the request or short-circuit with a response/error. +#[no_mangle] +pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 { + let input_str = read_string(input_ptr, input_len); + + // Parse input + let input: PreHookInput = match serde_json::from_str(&input_str) { + Ok(i) => i, + Err(e) => { + let output = PreHookOutput { + error: format!("Failed to parse input: {}", e), + ..Default::default() + }; + return write_string(&serde_json::to_string(&output).unwrap_or_default()); + } + }; + + // Create output with context preserved + let mut output = PreHookOutput { + context: input.context.clone(), + request: input.request.clone(), + has_short_circuit: false, + ..Default::default() + }; + + // Add custom values to context for tracking + output.context.set_value("plugin_processed", serde_json::json!(true)); + output.context.set_value("plugin_name", serde_json::json!("hello-world-wasm-rust")); + + // Get provider and model for potential modifications + let (_provider, model) = input.get_provider_model(); + + // Example: Short-circuit with mock response for specific model + // Uncomment to test: + /* + if model == "mock-model" { + output.has_short_circuit = true; + + let mock_response = BifrostResponse { + chat_response: Some(BifrostChatResponse { + id: format!("mock-{}", input.context.request_id.unwrap_or_default()), + model: "mock-model".to_string(), + choices: vec![ResponseChoice { + index: 0, + message: Some(ChatMessage { + role: ChatMessageRole::Assistant, + content: Some(ChatMessageContent::Text( + "This is a mock response from the Rust WASM plugin!".to_string() + )), + ..Default::default() + }), + finish_reason: Some("stop".to_string()), + ..Default::default() + }], + usage: Some(LLMUsage { + prompt_tokens: 10, + completion_tokens: 15, + total_tokens: 25, + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + + output.short_circuit = Some(PluginShortCircuit { + response: Some(mock_response), + error: None, + }); + + return write_string(&serde_json::to_string(&output).unwrap_or_default()); + } + */ + + // Example: Short-circuit with rate limit error + // Uncomment to test: + /* + if should_rate_limit(&input.context) { + output.has_short_circuit = true; + output.short_circuit = Some(PluginShortCircuit { + response: None, + error: Some( + BifrostError::new("Rate limit exceeded") + .with_type("rate_limit") + .with_code("429") + .with_status(429) + ), + }); + return write_string(&serde_json::to_string(&output).unwrap_or_default()); + } + */ + + // Silence unused variable warning in example code + let _ = model; + + // Pass through - empty request means use original + write_string(&serde_json::to_string(&output).unwrap_or_default()) +} + +/// Post-response hook +/// Called after response is received from provider. +/// Can modify the response or error. +#[no_mangle] +pub extern "C" fn post_hook(input_ptr: u32, input_len: u32) -> u64 { + let input_str = read_string(input_ptr, input_len); + + // Parse input + let input: PostHookInput = match serde_json::from_str(&input_str) { + Ok(i) => i, + Err(e) => { + let output = PostHookOutput { + hook_error: format!("Failed to parse input: {}", e), + ..Default::default() + }; + return write_string(&serde_json::to_string(&output).unwrap_or_default()); + } + }; + + // Create output with context preserved + let mut output = PostHookOutput { + context: input.context.clone(), + response: serde_json::json!({}), + error: serde_json::json!({}), + has_error: false, + hook_error: String::new(), + }; + + // Check if our plugin processed this request + let plugin_processed = input.context.get_bool("plugin_processed").unwrap_or(false); + + if plugin_processed { + // Plugin was involved in pre_hook, add completion marker + output.context.set_value("post_hook_completed", serde_json::json!(true)); + } + + // Handle error case + if input.has_error { + output.has_error = true; + output.error = input.error.clone(); + + // Example: Modify error message + // Uncomment to test: + /* + if let Some(mut error) = input.parse_error() { + error.error.message = format!("{} (processed by Rust WASM plugin)", error.error.message); + output.error = serde_json::to_value(&error).unwrap_or_default(); + } + */ + + return write_string(&serde_json::to_string(&output).unwrap_or_default()); + } + + // Handle success case - pass through response + output.response = input.response; + + // Example: Modify response + // Uncomment to test: + /* + if let Some(mut response) = input.parse_response() { + // Add custom metadata, modify model name, etc. + if let Some(ref mut chat) = response.chat_response { + // Add a marker to the model name + chat.model = format!("{} (via rust-wasm)", chat.model); + } + output.response = serde_json::to_value(&response).unwrap_or_default(); + } + */ + + write_string(&serde_json::to_string(&output).unwrap_or_default()) +} + +/// Cleanup resources +/// Called when plugin is being unloaded. +/// Returns 0 on success, non-zero on error +#[no_mangle] +pub extern "C" fn cleanup() -> i32 { + // Clear stored configuration + unsafe { + PLUGIN_CONFIG = None; + } + + 0 // Success +} + +// ============================================================================= +// Helper Functions +// ============================================================================= + +/// Example rate limit check function +#[allow(dead_code)] +fn should_rate_limit(_context: &BifrostContext) -> bool { + // Implement your rate limiting logic here + false +} diff --git a/examples/plugins/hello-world-wasm-rust/src/memory.rs b/examples/plugins/hello-world-wasm-rust/src/memory.rs new file mode 100644 index 0000000000..bab6fecc1d --- /dev/null +++ b/examples/plugins/hello-world-wasm-rust/src/memory.rs @@ -0,0 +1,70 @@ +//! Memory management utilities for WASM plugins. +//! Handles allocation, deallocation, and string read/write operations. + +use std::alloc::{alloc, dealloc, Layout}; +use std::slice; + +/// Pack a pointer and length into a single u64 +/// Upper 32 bits: pointer, Lower 32 bits: length +pub fn pack_result(ptr: u32, len: u32) -> u64 { + ((ptr as u64) << 32) | (len as u64) +} + +/// Write a string to WASM memory and return packed pointer+length +pub fn write_string(s: &str) -> u64 { + if s.is_empty() { + return 0; + } + let bytes = s.as_bytes(); + let ptr = unsafe { malloc(bytes.len() as u32) }; + if ptr == 0 { + return 0; + } + unsafe { + std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr as *mut u8, bytes.len()); + } + pack_result(ptr, bytes.len() as u32) +} + +/// Read a string from WASM memory given pointer and length +pub fn read_string(ptr: u32, len: u32) -> String { + if len == 0 { + return String::new(); + } + let bytes = unsafe { slice::from_raw_parts(ptr as *const u8, len as usize) }; + String::from_utf8_lossy(bytes).into_owned() +} + +/// Allocate memory for the host to write data +/// +/// # Safety +/// This function is marked as safe but performs unsafe operations internally. +/// It is intended to be called from WASM host. +#[no_mangle] +pub extern "C" fn malloc(size: u32) -> u32 { + if size == 0 { + return 0; + } + let layout = match Layout::from_size_align(size as usize, 1) { + Ok(l) => l, + Err(_) => return 0, + }; + unsafe { alloc(layout) as u32 } +} + +/// Free allocated memory +/// +/// # Safety +/// This function is marked as safe but performs unsafe operations internally. +/// It is intended to be called from WASM host. +#[no_mangle] +pub extern "C" fn free(ptr: u32, size: u32) { + if ptr == 0 || size == 0 { + return; + } + let layout = match Layout::from_size_align(size as usize, 1) { + Ok(l) => l, + Err(_) => return, + }; + unsafe { dealloc(ptr as *mut u8, layout) } +} diff --git a/examples/plugins/hello-world-wasm-rust/src/types.rs b/examples/plugins/hello-world-wasm-rust/src/types.rs new file mode 100644 index 0000000000..0018e5be4d --- /dev/null +++ b/examples/plugins/hello-world-wasm-rust/src/types.rs @@ -0,0 +1,641 @@ +//! Type definitions for Bifrost WASM plugins. +//! These structures mirror the Go SDK types for interoperability. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +// ============================================================================= +// Context Structure +// ============================================================================= + +/// BifrostContext holds request-scoped values passed between hooks. +/// Common keys include: +/// - request_id: Unique identifier for the request +/// - Custom plugin values can be added and will be persisted across hooks +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostContext { + #[serde(default)] + pub request_id: Option, + + /// Custom values set by plugins + #[serde(flatten)] + pub values: HashMap, +} + +impl BifrostContext { + pub fn new() -> Self { + Self::default() + } + + /// Set a custom value in the context + pub fn set_value(&mut self, key: &str, value: impl Into) { + self.values.insert(key.to_string(), value.into()); + } + + /// Get a string value from the context + pub fn get_string(&self, key: &str) -> Option<&str> { + self.values.get(key).and_then(|v| v.as_str()) + } + + /// Get a boolean value from the context + pub fn get_bool(&self, key: &str) -> Option { + self.values.get(key).and_then(|v| v.as_bool()) + } +} + +// ============================================================================= +// HTTP Transport Structures +// ============================================================================= + +/// HTTPRequest represents an incoming HTTP request at the transport layer. +/// Body is base64-encoded. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HTTPRequest { + #[serde(default)] + pub method: String, + + #[serde(default)] + pub path: String, + + #[serde(default)] + pub headers: HashMap, + + #[serde(default)] + pub query: HashMap, + + /// Base64-encoded request body + #[serde(default)] + pub body: String, +} + +/// HTTPResponse represents an HTTP response to return. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HTTPResponse { + #[serde(default)] + pub status_code: i32, + + #[serde(default)] + pub headers: HashMap, + + /// Base64-encoded response body + #[serde(default)] + pub body: String, +} + +/// HTTPInterceptInput is the input for http_intercept hook. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HTTPInterceptInput { + #[serde(default)] + pub context: BifrostContext, + + #[serde(default)] + pub request: HTTPRequest, +} + +/// HTTPInterceptOutput is the output for http_intercept hook. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HTTPInterceptOutput { + pub context: BifrostContext, + + #[serde(default)] + pub request: serde_json::Value, + + #[serde(skip_serializing_if = "Option::is_none")] + pub response: Option, + + #[serde(default)] + pub has_response: bool, + + #[serde(default)] + pub error: String, +} + +// ============================================================================= +// Chat Completion Structures (BifrostRequest) +// ============================================================================= + +/// ChatMessageRole represents the role of a message sender. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ChatMessageRole { + User, + Assistant, + System, + Tool, + Developer, +} + +impl Default for ChatMessageRole { + fn default() -> Self { + ChatMessageRole::User + } +} + +/// ChatMessageContent can be either a string or an array of content blocks. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ChatMessageContent { + Text(String), + Blocks(Vec), +} + +impl Default for ChatMessageContent { + fn default() -> Self { + ChatMessageContent::Text(String::new()) + } +} + +/// ChatContentBlock represents a content block in a message. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatContentBlock { + #[serde(rename = "type")] + pub block_type: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub image_url: Option, +} + +/// ImageUrl represents an image URL in a content block. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ImageUrl { + pub url: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub detail: Option, +} + +/// ChatMessage represents a message in the conversation. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatMessage { + #[serde(default)] + pub role: ChatMessageRole, + + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, +} + +/// ToolCall represents a tool call made by the assistant. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ToolCall { + #[serde(default)] + pub id: Option, + + #[serde(rename = "type", default)] + pub call_type: Option, + + #[serde(default)] + pub function: ToolCallFunction, +} + +/// ToolCallFunction represents the function being called. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ToolCallFunction { + #[serde(default)] + pub name: Option, + + #[serde(default)] + pub arguments: String, +} + +/// ChatParameters contains optional parameters for chat completion. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatParameters { + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub max_completion_tokens: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + /// Catch-all for additional parameters + #[serde(flatten)] + pub extra: HashMap, +} + +/// ChatTool represents a tool definition. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatTool { + #[serde(rename = "type")] + pub tool_type: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub function: Option, +} + +/// ChatToolFunction represents a function definition. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatToolFunction { + pub name: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option, +} + +/// BifrostChatRequest represents a chat completion request. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostChatRequest { + #[serde(default)] + pub provider: String, + + #[serde(default)] + pub model: String, + + #[serde(default)] + pub input: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub fallbacks: Option>, +} + +/// Fallback represents a fallback provider/model. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct Fallback { + pub provider: String, + pub model: String, +} + +/// BifrostRequest is the unified request structure. +/// Only one of the request types should be present. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostRequest { + #[serde(skip_serializing_if = "Option::is_none")] + pub chat_request: Option, + + // Add other request types as needed + #[serde(flatten)] + pub extra: HashMap, +} + +impl BifrostRequest { + /// Get provider and model from the request + pub fn get_provider_model(&self) -> (String, String) { + if let Some(ref chat) = self.chat_request { + return (chat.provider.clone(), chat.model.clone()); + } + (String::new(), String::new()) + } +} + +// ============================================================================= +// Response Structures (BifrostResponse) +// ============================================================================= + +/// LLMUsage contains token usage information. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct LLMUsage { + #[serde(default)] + pub prompt_tokens: i32, + + #[serde(default)] + pub completion_tokens: i32, + + #[serde(default)] + pub total_tokens: i32, + + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_tokens_details: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_tokens_details: Option, +} + +/// ResponseChoice represents a single completion choice. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ResponseChoice { + #[serde(default)] + pub index: i32, + + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub delta: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, +} + +/// BifrostChatResponse represents a chat completion response. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostChatResponse { + #[serde(default)] + pub id: String, + + #[serde(default)] + pub model: String, + + #[serde(default)] + pub choices: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub created: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub object: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, +} + +/// BifrostResponse is the unified response structure. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostResponse { + #[serde(skip_serializing_if = "Option::is_none")] + pub chat_response: Option, + + #[serde(flatten)] + pub extra: HashMap, +} + +// ============================================================================= +// Error Structure +// ============================================================================= + +/// ErrorField contains the error details. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ErrorField { + #[serde(default)] + pub message: String, + + #[serde(skip_serializing_if = "Option::is_none", rename = "type")] + pub error_type: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub param: Option, +} + +/// BifrostError represents an error response. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostError { + #[serde(default)] + pub error: ErrorField, + + #[serde(skip_serializing_if = "Option::is_none")] + pub status_code: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub allow_fallbacks: Option, +} + +impl BifrostError { + /// Create a new error with a message + pub fn new(message: &str) -> Self { + Self { + error: ErrorField { + message: message.to_string(), + ..Default::default() + }, + ..Default::default() + } + } + + /// Set the error type + pub fn with_type(mut self, error_type: &str) -> Self { + self.error.error_type = Some(error_type.to_string()); + self + } + + /// Set the error code + pub fn with_code(mut self, code: &str) -> Self { + self.error.code = Some(code.to_string()); + self + } + + /// Set the status code + pub fn with_status(mut self, status: i32) -> Self { + self.status_code = Some(status); + self + } +} + +// ============================================================================= +// Short Circuit Structure +// ============================================================================= + +/// PluginShortCircuit allows plugins to short-circuit the request flow. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PluginShortCircuit { + #[serde(skip_serializing_if = "Option::is_none")] + pub response: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +// ============================================================================= +// Hook Input/Output Structures +// ============================================================================= + +/// PreHookInput is the input for pre_hook. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PreHookInput { + #[serde(default)] + pub context: BifrostContext, + + #[serde(default)] + pub request: serde_json::Value, +} + +impl PreHookInput { + /// Parse the request as a BifrostRequest + pub fn parse_request(&self) -> Option { + serde_json::from_value(self.request.clone()).ok() + } + + /// Get provider and model from the request + pub fn get_provider_model(&self) -> (String, String) { + if let Some(req) = self.parse_request() { + return req.get_provider_model(); + } + // Try direct access for simpler structures + let provider = self.request.get("provider") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let model = self.request.get("model") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + (provider, model) + } +} + +/// PreHookOutput is the output for pre_hook. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PreHookOutput { + pub context: BifrostContext, + + #[serde(default)] + pub request: serde_json::Value, + + #[serde(skip_serializing_if = "Option::is_none")] + pub short_circuit: Option, + + #[serde(default)] + pub has_short_circuit: bool, + + #[serde(default)] + pub error: String, +} + +/// PostHookInput is the input for post_hook. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PostHookInput { + #[serde(default)] + pub context: BifrostContext, + + #[serde(default)] + pub response: serde_json::Value, + + #[serde(default)] + pub error: serde_json::Value, + + #[serde(default)] + pub has_error: bool, +} + +impl PostHookInput { + /// Parse the response as a BifrostResponse + pub fn parse_response(&self) -> Option { + serde_json::from_value(self.response.clone()).ok() + } + + /// Parse the error as a BifrostError + pub fn parse_error(&self) -> Option { + if self.has_error { + serde_json::from_value(self.error.clone()).ok() + } else { + None + } + } +} + +/// PostHookOutput is the output for post_hook. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PostHookOutput { + pub context: BifrostContext, + + #[serde(default)] + pub response: serde_json::Value, + + #[serde(default)] + pub error: serde_json::Value, + + #[serde(default)] + pub has_error: bool, + + #[serde(default)] + pub hook_error: String, +} + +// ============================================================================= +// Plugin Configuration +// ============================================================================= + +/// Plugin configuration (customize as needed) +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PluginConfig { + #[serde(flatten)] + pub values: HashMap, +} + +// ============================================================================= +// Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_context_serialization() { + let mut ctx = BifrostContext::new(); + ctx.request_id = Some("test-123".to_string()); + ctx.set_value("custom_key", "custom_value"); + + let json = serde_json::to_string(&ctx).unwrap(); + assert!(json.contains("request_id")); + assert!(json.contains("custom_key")); + } + + #[test] + fn test_chat_message() { + let msg = ChatMessage { + role: ChatMessageRole::User, + content: Some(ChatMessageContent::Text("Hello!".to_string())), + ..Default::default() + }; + + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains("user")); + assert!(json.contains("Hello!")); + } + + #[test] + fn test_bifrost_error() { + let error = BifrostError::new("Test error") + .with_type("test_type") + .with_code("500") + .with_status(500); + + let json = serde_json::to_string(&error).unwrap(); + assert!(json.contains("Test error")); + assert!(json.contains("test_type")); + } + + #[test] + fn test_pre_hook_input_parsing() { + let json = r#"{ + "context": {"request_id": "test-123"}, + "request": {"provider": "openai", "model": "gpt-4"} + }"#; + + let input: PreHookInput = serde_json::from_str(json).unwrap(); + assert_eq!(input.context.request_id, Some("test-123".to_string())); + + let (provider, model) = input.get_provider_model(); + assert_eq!(provider, "openai"); + assert_eq!(model, "gpt-4"); + } +} diff --git a/examples/plugins/hello-world-wasm-typescript/Makefile b/examples/plugins/hello-world-wasm-typescript/Makefile new file mode 100644 index 0000000000..bb4c2e1a7a --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/Makefile @@ -0,0 +1,70 @@ +.PHONY: all build build-debug clean help install check-node + +# Colors +COLOR_RESET = \033[0m +COLOR_INFO = \033[36m +COLOR_SUCCESS = \033[32m +COLOR_WARNING = \033[33m +COLOR_ERROR = \033[31m +COLOR_BOLD = \033[1m + +# Plugin configuration +PLUGIN_NAME = hello-world +OUTPUT_DIR = build +OUTPUT = $(OUTPUT_DIR)/$(PLUGIN_NAME).wasm + +help: ## Show this help message + @echo '$(COLOR_BOLD)Hello World WASM Plugin (TypeScript/AssemblyScript)$(COLOR_RESET)' + @echo '' + @echo '$(COLOR_BOLD)Usage:$(COLOR_RESET) make [target]' + @echo '' + @echo '$(COLOR_BOLD)Prerequisites:$(COLOR_RESET)' + @echo ' - Node.js (https://nodejs.org/)' + @echo ' - npm install (to install AssemblyScript)' + @echo '' + @echo '$(COLOR_BOLD)Available targets:$(COLOR_RESET)' + @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf " $(COLOR_INFO)%-15s$(COLOR_RESET) %s\n", $$1, $$2}' $(MAKEFILE_LIST) + +check-node: ## Check if Node.js is installed + @which node > /dev/null 2>&1 || (echo "$(COLOR_ERROR)Error: Node.js is not installed$(COLOR_RESET)"; \ + echo "$(COLOR_INFO)Install Node.js: https://nodejs.org/$(COLOR_RESET)"; \ + exit 1) + @echo "$(COLOR_SUCCESS)✓ Node.js found: $$(node --version)$(COLOR_RESET)" + +install: check-node ## Install dependencies + @echo "$(COLOR_INFO)Installing dependencies...$(COLOR_RESET)" + npm install + @echo "$(COLOR_SUCCESS)✓ Dependencies installed$(COLOR_RESET)" + +build: install ## Build the WASM plugin + @mkdir -p $(OUTPUT_DIR) + @echo "$(COLOR_INFO)Building WASM plugin...$(COLOR_RESET)" + npm run build + @echo "$(COLOR_SUCCESS)✓ Plugin built successfully: $(OUTPUT)$(COLOR_RESET)" + @ls -lh $(OUTPUT) | awk '{print " Size: " $$5}' + +build-debug: install ## Build with debug info + @mkdir -p $(OUTPUT_DIR) + @echo "$(COLOR_INFO)Building WASM plugin (debug)...$(COLOR_RESET)" + npm run build:debug + @echo "$(COLOR_SUCCESS)✓ Debug plugin built: $(OUTPUT)$(COLOR_RESET)" + @ls -lh $(OUTPUT) | awk '{print " Size: " $$5}' + +clean: ## Remove build artifacts + @echo "$(COLOR_INFO)Cleaning build artifacts...$(COLOR_RESET)" + @rm -rf $(OUTPUT_DIR) node_modules + @echo "$(COLOR_SUCCESS)✓ Clean complete$(COLOR_RESET)" + +info: ## Show build information + @echo "$(COLOR_BOLD)Build Configuration$(COLOR_RESET)" + @echo " Plugin Name: $(PLUGIN_NAME)" + @echo " Output: $(OUTPUT)" + @echo "" + @if [ -f "$(OUTPUT)" ]; then \ + echo "$(COLOR_SUCCESS)Plugin exists:$(COLOR_RESET)"; \ + ls -lh $(OUTPUT) | awk '{print " " $$9 " (" $$5 ")"}'; \ + else \ + echo "$(COLOR_WARNING)Plugin not built yet$(COLOR_RESET)"; \ + fi + +.DEFAULT_GOAL := help diff --git a/examples/plugins/hello-world-wasm-typescript/README.md b/examples/plugins/hello-world-wasm-typescript/README.md new file mode 100644 index 0000000000..d573350193 --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/README.md @@ -0,0 +1,453 @@ +# Bifrost WASM Plugin (TypeScript/AssemblyScript) + +A comprehensive example of a Bifrost plugin written in TypeScript and compiled to WebAssembly using AssemblyScript. This plugin demonstrates proper structure definitions, JSON parsing, context handling, and request/response modification patterns. + +## Prerequisites + +### Node.js Installation + +Node.js is required to run AssemblyScript: + +**macOS:** +```bash +brew install node +``` + +**Linux:** +```bash +curl -fsSL https://deb.nodesource.com/setup_20.x | sudo -E bash - +sudo apt install -y nodejs +``` + +**Other platforms:** +See [Node.js Downloads](https://nodejs.org/en/download/) + +## Building + +```bash +# Install dependencies and build +make build + +# Build with debug info +make build-debug + +# Clean build artifacts +make clean +``` + +The compiled plugin will be at `build/hello-world.wasm`. + +## File Structure + +``` +assembly/ +├── index.ts # Plugin implementation (hooks) +├── memory.ts # Memory management utilities +├── types.ts # Type definitions (mirrors Go SDK) +└── tsconfig.json # AssemblyScript config +``` + +## Plugin Structure + +WASM plugins must export the following functions: + +| Export | Signature | Description | +|--------|-----------|-------------| +| `malloc` | `(size: u32) -> u32` | Allocate memory for host to write data | +| `free` | `(ptr: u32)` | Free allocated memory | +| `get_name` | `() -> u64` | Returns packed ptr+len of plugin name | +| `init` | `(config_ptr, config_len: u32) -> i32` | Initialize with config (optional) | +| `http_intercept` | `(input_ptr, input_len: u32) -> u64` | HTTP transport intercept | +| `pre_hook` | `(input_ptr, input_len: u32) -> u64` | Pre-request hook | +| `post_hook` | `(input_ptr, input_len: u32) -> u64` | Post-response hook | +| `cleanup` | `() -> i32` | Cleanup resources (0 = success) | + +### Return Value Format + +Functions returning data use a packed `u64` format: +- Upper 32 bits: pointer to data in WASM memory +- Lower 32 bits: length of data + +## Data Structures + +This plugin uses `json-as` with `@json` decorators for automatic JSON serialization. All structures mirror the Go SDK types: + +### Context + +```typescript +@json +class BifrostContext { + request_id: string = '' // Unique request identifier + plugin_processed: string = '' // Custom plugin values + plugin_name: string = '' +} +``` + +### HTTP Transport Types + +```typescript +@json +class HTTPRequest { + method: string = '' // GET, POST, etc. + path: string = '' // /v1/chat/completions + body: string = '' // base64 encoded +} + +@json +class HTTPResponse { + status_code: i32 = 200 // HTTP status code + body: string = '' // base64 encoded +} +``` + +### Chat Completion Types + +```typescript +@json +class ChatMessage { + role: string = '' // "user", "assistant", "system", "tool" + content: string = '' + name: string = '' + tool_call_id: string = '' +} + +@json +class ChatParameters { + temperature: f64 = 0 + max_completion_tokens: i32 = 0 + top_p: f64 = 0 +} + +@json +class BifrostChatRequest { + provider: string = '' // "openai", "anthropic", etc. + model: string = '' // "gpt-4", "claude-3", etc. + input: ChatMessage[] = [] + params: ChatParameters = new ChatParameters() +} +``` + +### Response Types + +```typescript +@json +class LLMUsage { + prompt_tokens: i32 = 0 + completion_tokens: i32 = 0 + total_tokens: i32 = 0 +} + +@json +class ResponseChoice { + index: i32 = 0 + message: ChatMessage = new ChatMessage() + finish_reason: string = 'stop' // "stop", "length", "tool_calls" +} + +@json +class BifrostChatResponse { + id: string = '' + model: string = '' + choices: ResponseChoice[] = [] + usage: LLMUsage = new LLMUsage() +} +``` + +### Error Types + +```typescript +@json +class ErrorField { + message: string = '' + type: string = '' // "rate_limit", "auth_error", etc. + code: string = '' // "429", "401", etc. +} + +@json +class BifrostError { + error: ErrorField = new ErrorField() + status_code: i32 = 0 +} +``` + +### Short Circuit + +```typescript +@json +class PluginShortCircuit { + response: BifrostResponse | null = null // Success short-circuit + error: BifrostError | null = null // Error short-circuit +} +``` + +## Hook Input/Output Structures + +### http_intercept + +**Input:** +```json +{ + "context": { "request_id": "abc-123" }, + "request": { + "method": "POST", + "path": "/v1/chat/completions", + "headers": { "Content-Type": "application/json" }, + "query": {}, + "body": "" + } +} +``` + +**Output:** +```json +{ + "context": { "request_id": "abc-123", "custom_key": "value" }, + "request": {}, + "response": { "status_code": 200, "headers": {}, "body": "" }, + "has_response": false, + "error": "" +} +``` + +### pre_hook + +**Input:** +```json +{ + "context": { "request_id": "abc-123" }, + "request": { + "provider": "openai", + "model": "gpt-4", + "input": [{ "role": "user", "content": "Hello" }], + "params": { "temperature": 0.7 } + } +} +``` + +**Output:** +```json +{ + "context": { "request_id": "abc-123", "plugin_processed": "true" }, + "request": {}, + "short_circuit": { + "response": { "chat_response": { ... } } + }, + "has_short_circuit": false, + "error": "" +} +``` + +### post_hook + +**Input:** +```json +{ + "context": { "request_id": "abc-123", "plugin_processed": "true" }, + "response": { + "chat_response": { + "id": "chatcmpl-123", + "model": "gpt-4", + "choices": [{ "index": 0, "message": { "role": "assistant", "content": "Hi!" } }], + "usage": { "prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15 } + } + }, + "error": {}, + "has_error": false +} +``` + +**Output:** +```json +{ + "context": { "request_id": "abc-123", "post_hook_completed": "true" }, + "response": {}, + "error": {}, + "has_error": false, + "hook_error": "" +} +``` + +## Usage Examples + +### Modifying Context + +```typescript +import { JSON } from 'json-as' + +export function pre_hook(inputPtr: u32, inputLen: u32): u64 { + const inputJson = readString(inputPtr, inputLen) + const input = JSON.parse(inputJson) + + const output = new PreHookOutput() + output.context = input.context + + // Add custom values to context + output.context.plugin_processed = 'true' + output.context.plugin_name = 'my-plugin' + + return writeString(JSON.stringify(output)) +} +``` + +### Short-Circuit with Mock Response + +```typescript +import { JSON } from 'json-as' + +export function pre_hook(inputPtr: u32, inputLen: u32): u64 { + const inputJson = readString(inputPtr, inputLen) + const input = JSON.parse(inputJson) + + // Check if this should be mocked + const model = input.request.model + if (model === 'mock-model') { + const output = new PreHookOutput() + output.context = input.context + output.has_short_circuit = true + output.short_circuit = new PluginShortCircuit() + + // Build mock response + const mockResponse = new BifrostResponse() + mockResponse.chat_response = new BifrostChatResponse() + mockResponse.chat_response!.id = 'mock-' + input.context.request_id + mockResponse.chat_response!.model = 'mock-model' + + const choice = new ResponseChoice() + choice.message.role = 'assistant' + choice.message.content = 'This is a mock response!' + mockResponse.chat_response!.choices.push(choice) + + mockResponse.chat_response!.usage.prompt_tokens = 10 + mockResponse.chat_response!.usage.completion_tokens = 15 + mockResponse.chat_response!.usage.total_tokens = 25 + + output.short_circuit!.response = mockResponse + return writeString(JSON.stringify(output)) + } + + // Pass through + const output = new PreHookOutput() + output.context = input.context + return writeString(JSON.stringify(output)) +} +``` + +### Short-Circuit with Error + +```typescript +import { JSON } from 'json-as' + +export function pre_hook(inputPtr: u32, inputLen: u32): u64 { + const inputJson = readString(inputPtr, inputLen) + const input = JSON.parse(inputJson) + + // Check rate limit (example) + if (shouldRateLimit(input.context.request_id)) { + const output = new PreHookOutput() + output.context = input.context + output.has_short_circuit = true + output.short_circuit = new PluginShortCircuit() + + const error = new BifrostError() + error.error.message = 'Rate limit exceeded' + error.error.type = 'rate_limit' + error.error.code = '429' + error.status_code = 429 + + output.short_circuit!.error = error + return writeString(JSON.stringify(output)) + } + + // Pass through + const output = new PreHookOutput() + output.context = input.context + return writeString(JSON.stringify(output)) +} +``` + +### Modifying Responses in post_hook + +```typescript +import { JSON } from 'json-as' + +export function post_hook(inputPtr: u32, inputLen: u32): u64 { + const inputJson = readString(inputPtr, inputLen) + const input = JSON.parse(inputJson) + + const output = new PostHookOutput() + output.context = input.context + + // Handle errors + if (input.has_error && input.error !== null) { + output.has_error = true + output.error = input.error + // Could modify error here if needed + return writeString(JSON.stringify(output)) + } + + // Modify response + if (input.response !== null && input.response!.chat_response !== null) { + output.response = input.response + // Could add logging, metrics, or modify response here + } + + return writeString(JSON.stringify(output)) +} +``` + +## Usage with Bifrost + +Configure the plugin in your Bifrost config: + +```json +{ + "plugins": [ + { + "path": "/path/to/hello-world.wasm", + "name": "hello-world-wasm-typescript", + "enabled": true, + "config": { + "custom_option": "value" + } + } + ] +} +``` + +## AssemblyScript Notes + +AssemblyScript is similar to TypeScript but with some differences: + +1. **Types are required**: All variables must have explicit types +2. **No closures**: Functions cannot capture variables from outer scope +3. **Limited stdlib**: Not all JavaScript/TypeScript features are available +4. **Strict null handling**: Null checks are required +5. **JSON via json-as**: Uses the `json-as` package with `@json` decorators for serialization + +This plugin uses `json-as` for JSON parsing/serialization: + +```typescript +import { JSON } from 'json-as' + +@json +class MyClass { + name: string = '' + value: i32 = 0 +} + +// Parse JSON +const obj = JSON.parse('{"name":"test","value":42}') + +// Stringify to JSON +const json = JSON.stringify(obj) +``` + +See [AssemblyScript Documentation](https://www.assemblyscript.org/introduction.html) and [json-as Documentation](https://github.com/JairusSW/as-json) for more details. + +## Benefits + +1. **Familiar syntax**: TypeScript-like syntax for JS/TS developers +2. **Cross-platform**: Single `.wasm` binary runs on any OS/architecture +3. **Security**: WASM provides sandboxed execution +4. **Type Safety**: Strongly typed structures catch errors at compile time +5. **npm ecosystem**: Can use npm for dependency management diff --git a/examples/plugins/hello-world-wasm-typescript/assembly/index.ts b/examples/plugins/hello-world-wasm-typescript/assembly/index.ts new file mode 100644 index 0000000000..1df93bad6b --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/assembly/index.ts @@ -0,0 +1,251 @@ +/** + * Bifrost WASM Plugin for TypeScript/AssemblyScript + * + * This plugin demonstrates the proper structure for parsing inputs, + * building responses, and handling context - similar to Go plugin patterns. + * + * Build with: npm run build + */ + +import { JSON } from 'json-as' + +// Memory management exports +import { free as _free, malloc as _malloc, readString, writeString } from './memory' + +// Type definitions +import { + HTTPInterceptInput, + HTTPInterceptOutput, + PostHookInput, + PostHookOutput, + PreHookInput, + PreHookOutput +} from './types' + +// ============================================================================= +// Re-export memory functions for WASM host +// ============================================================================= + +export function malloc(size: u32): u32 { + return _malloc(size) +} + +export function free(ptr: u32): void { + _free(ptr) +} + +// ============================================================================= +// Plugin Configuration +// ============================================================================= + +// Plugin configuration storage +let pluginConfig: string = '' + +// ============================================================================= +// Exported Plugin Functions +// ============================================================================= + +// Return the plugin name +export function get_name(): u64 { + return writeString('hello-world-wasm-typescript') +} + +// Initialize the plugin with config +// Returns 0 on success, non-zero on error +export function init(configPtr: u32, configLen: u32): i32 { + // Parse and store configuration + pluginConfig = readString(configPtr, configLen) + + // Validate configuration if needed + // For this example, we just accept any config + + return 0 // Success +} + +/** + * HTTP transport intercept + * Called at the HTTP layer before request enters Bifrost core. + * Can modify headers, query params, or short-circuit with a response. + */ +export function http_intercept(inputPtr: u32, inputLen: u32): u64 { + // Parse input using json-as + const inputJson = readString(inputPtr, inputLen) + const input = JSON.parse(inputJson) + + // Create output with context preserved + const output = new HTTPInterceptOutput() + output.context = input.context + output.has_response = false + + // Example: Short-circuit health check endpoint + // Uncomment to test: + /* + if (input.request.path === '/health') { + output.has_response = true + output.response = new HTTPResponse() + output.response!.status_code = 200 + output.response!.body = 'eyJzdGF0dXMiOiJvayJ9' // base64 of {"status":"ok"} + return writeString(JSON.stringify(output)) + } + */ + + return writeString(JSON.stringify(output)) +} + +/** + * Pre-request hook + * Called before request is sent to the provider. + * Can modify the request or short-circuit with a response/error. + */ +export function pre_hook(inputPtr: u32, inputLen: u32): u64 { + // Parse input using json-as + const inputJson = readString(inputPtr, inputLen) + const input = JSON.parse(inputJson) + + // Create output with context preserved + const output = new PreHookOutput() + output.context = input.context + + // Add custom values to context for tracking in post_hook + output.context.plugin_processed = 'true' + output.context.plugin_name = 'hello-world-wasm-typescript' + + // Get provider and model from request + let provider = input.request.provider + let model = input.request.model + if (input.request.chat_request !== null) { + provider = input.request.chat_request!.provider + model = input.request.chat_request!.model + } + + // Example: Short-circuit with mock response for specific model + // Uncomment to test: + /* + if (model === 'mock-model') { + output.has_short_circuit = true + output.short_circuit = new PluginShortCircuit() + + const mockResponse = new BifrostResponse() + mockResponse.chat_response = new BifrostChatResponse() + mockResponse.chat_response!.id = 'mock-' + input.context.request_id + mockResponse.chat_response!.model = 'mock-model' + + const choice = new ResponseChoice() + choice.message.role = 'assistant' + choice.message.content = 'This is a mock response from the WASM plugin!' + choice.finish_reason = 'stop' + mockResponse.chat_response!.choices.push(choice) + + mockResponse.chat_response!.usage.prompt_tokens = 10 + mockResponse.chat_response!.usage.completion_tokens = 15 + mockResponse.chat_response!.usage.total_tokens = 25 + + output.short_circuit!.response = mockResponse + return writeString(JSON.stringify(output)) + } + */ + + // Example: Short-circuit with rate limit error + // Uncomment to test: + /* + if (shouldRateLimit(input.context.request_id)) { + output.has_short_circuit = true + output.short_circuit = new PluginShortCircuit() + + const error = new BifrostError() + error.error.message = 'Rate limit exceeded' + error.error.type = 'rate_limit' + error.error.code = '429' + error.status_code = 429 + + output.short_circuit!.error = error + return writeString(JSON.stringify(output)) + } + */ + + // Pass through - null request means use original + return writeString(JSON.stringify(output)) +} + +/** + * Post-response hook + * Called after response is received from provider. + * Can modify the response or error. + */ +export function post_hook(inputPtr: u32, inputLen: u32): u64 { + // Parse input using json-as + const inputJson = readString(inputPtr, inputLen) + const input = JSON.parse(inputJson) + + // Create output with context preserved + const output = new PostHookOutput() + output.context = input.context + + // Check if our plugin processed this request + const pluginProcessed = input.context.plugin_processed + + // Add completion marker if plugin was involved + if (pluginProcessed === 'true') { + output.context.post_hook_completed = 'true' + } + + // Handle error case + if (input.has_error && input.error !== null) { + output.has_error = true + output.error = input.error + + // Example: Modify error message + // Uncomment to test: + /* + const modifiedError = new BifrostError() + modifiedError.error.message = input.error!.error.message + ' (processed by WASM plugin)' + modifiedError.error.type = input.error!.error.type + modifiedError.error.code = input.error!.error.code + modifiedError.status_code = input.error!.status_code + output.error = modifiedError + */ + + return writeString(JSON.stringify(output)) + } + + // Handle success case - pass through response + if (input.response !== null) { + output.response = input.response + + // Example: Modify response model name + // Uncomment to test: + /* + if (input.response!.chat_response !== null) { + const modifiedResponse = new BifrostResponse() + modifiedResponse.chat_response = input.response!.chat_response + modifiedResponse.chat_response!.model += ' (via wasm-ts)' + output.response = modifiedResponse + } + */ + } + + return writeString(JSON.stringify(output)) +} + +/** + * Cleanup resources + * Called when plugin is being unloaded. + * Returns 0 on success, non-zero on error + */ +export function cleanup(): i32 { + // Clear any stored configuration + pluginConfig = '' + + // Perform any necessary cleanup + return 0 // Success +} + +// ============================================================================= +// Helper Functions +// ============================================================================= + +// Example rate limit check function +function shouldRateLimit(_requestId: string): bool { + // Implement your rate limiting logic here + return false +} diff --git a/examples/plugins/hello-world-wasm-typescript/assembly/memory.ts b/examples/plugins/hello-world-wasm-typescript/assembly/memory.ts new file mode 100644 index 0000000000..fcfb425e68 --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/assembly/memory.ts @@ -0,0 +1,45 @@ +/** + * Memory management utilities for WASM plugins. + * Handles allocation, deallocation, and string read/write operations. + */ + +// Pack a pointer and length into a single u64 +// Upper 32 bits: pointer, Lower 32 bits: length +export function packResult(ptr: u32, len: u32): u64 { + return (u64(ptr) << 32) | u64(len) +} + +// Write a string to memory and return packed pointer+length +export function writeString(s: string): u64 { + if (s.length === 0) { + return 0 + } + const encoded = String.UTF8.encode(s) + const ptr = changetype(encoded) + return packResult(ptr, encoded.byteLength) +} + +// Read a string from memory given pointer and length +export function readString(ptr: u32, len: u32): string { + if (len === 0) { + return '' + } + const buffer = new ArrayBuffer(len) + memory.copy(changetype(buffer), ptr, len) + return String.UTF8.decode(buffer) +} + +// Allocate memory for the host to write data +export function malloc(size: u32): u32 { + if (size === 0) { + return 0 + } + const buffer = new ArrayBuffer(size) + return changetype(buffer) +} + +// Free allocated memory (handled by AssemblyScript runtime) +export function free(_ptr: u32): void { + // AssemblyScript handles garbage collection + // This is provided for API compatibility +} diff --git a/examples/plugins/hello-world-wasm-typescript/assembly/tsconfig.json b/examples/plugins/hello-world-wasm-typescript/assembly/tsconfig.json new file mode 100644 index 0000000000..798b474eab --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/assembly/tsconfig.json @@ -0,0 +1,4 @@ +{ + "extends": "assemblyscript/std/assembly.json", + "include": ["./**/*.ts"] +} diff --git a/examples/plugins/hello-world-wasm-typescript/assembly/types.ts b/examples/plugins/hello-world-wasm-typescript/assembly/types.ts new file mode 100644 index 0000000000..073a3caee9 --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/assembly/types.ts @@ -0,0 +1,251 @@ +/** + * Type definitions for Bifrost WASM plugins. + * These structures mirror the Go SDK types for interoperability. + */ + +import { JSON } from 'json-as' + +// ============================================================================= +// Context Structure +// ============================================================================= + +/** + * BifrostContext holds request-scoped values passed between hooks. + * Common keys include: + * - request_id: Unique identifier for the request + * - Custom plugin values can be added and will be persisted across hooks + */ +@json +export class BifrostContext { + request_id: string = '' + + // Custom values for plugin use (add more as needed) + plugin_processed: string = '' + plugin_name: string = '' + post_hook_completed: string = '' +} + +// ============================================================================= +// HTTP Transport Structures +// ============================================================================= + +/** + * HTTPRequest represents an incoming HTTP request at the transport layer. + * Body is base64-encoded. + */ +@json +export class HTTPRequest { + method: string = '' + path: string = '' + body: string = '' // base64 encoded + headers: Map = new Map() + query: Map = new Map() +} + +/** + * HTTPResponse represents an HTTP response to return. + */ +@json +export class HTTPResponse { + status_code: i32 = 200 + body: string = '' // base64 encoded + headers: Map = new Map() +} + +/** + * HTTPInterceptInput is the input for http_intercept hook. + */ +@json +export class HTTPInterceptInput { + context: BifrostContext = new BifrostContext() + request: HTTPRequest = new HTTPRequest() +} + +/** + * HTTPInterceptOutput is the output for http_intercept hook. + */ +@json +export class HTTPInterceptOutput { + context: BifrostContext = new BifrostContext() + request: HTTPRequest | null = null + response: HTTPResponse | null = null + has_response: bool = false + error: string = '' +} + +// ============================================================================= +// Chat Completion Structures (BifrostRequest) +// ============================================================================= + +/** + * ChatMessage represents a message in the conversation. + */ +@json +export class ChatMessage { + role: string = '' // "user", "assistant", "system", "tool" + content: string = '' + name: string = '' + tool_call_id: string = '' +} + +/** + * ChatParameters contains optional parameters for chat completion. + */ +@json +export class ChatParameters { + temperature: f64 = 0 + max_completion_tokens: i32 = 0 + top_p: f64 = 0 +} + +/** + * BifrostChatRequest represents a chat completion request. + */ +@json +export class BifrostChatRequest { + provider: string = '' + model: string = '' + input: ChatMessage[] = [] + params: ChatParameters = new ChatParameters() +} + +/** + * BifrostRequest is the unified request structure. + */ +@json +export class BifrostRequest { + chat_request: BifrostChatRequest | null = null + + // Direct fields for simpler request structures + provider: string = '' + model: string = '' + input: ChatMessage[] = [] + params: ChatParameters | null = null +} + +// ============================================================================= +// Response Structures (BifrostResponse) +// ============================================================================= + +/** + * LLMUsage contains token usage information. + */ +@json +export class LLMUsage { + prompt_tokens: i32 = 0 + completion_tokens: i32 = 0 + total_tokens: i32 = 0 +} + +/** + * ResponseChoice represents a single completion choice. + */ +@json +export class ResponseChoice { + index: i32 = 0 + message: ChatMessage = new ChatMessage() + finish_reason: string = 'stop' +} + +/** + * BifrostChatResponse represents a chat completion response. + */ +@json +export class BifrostChatResponse { + id: string = '' + model: string = '' + choices: ResponseChoice[] = [] + usage: LLMUsage = new LLMUsage() +} + +/** + * BifrostResponse is the unified response structure. + */ +@json +export class BifrostResponse { + chat_response: BifrostChatResponse | null = null +} + +// ============================================================================= +// Error Structure +// ============================================================================= + +/** + * ErrorField contains the error details. + */ +@json +export class ErrorField { + message: string = '' + type: string = '' + code: string = '' +} + +/** + * BifrostError represents an error response. + */ +@json +export class BifrostError { + error: ErrorField = new ErrorField() + status_code: i32 = 0 +} + +// ============================================================================= +// Short Circuit Structure +// ============================================================================= + +/** + * PluginShortCircuit allows plugins to short-circuit the request flow. + */ +@json +export class PluginShortCircuit { + response: BifrostResponse | null = null + error: BifrostError | null = null +} + +// ============================================================================= +// Hook Input/Output Structures +// ============================================================================= + +/** + * PreHookInput is the input for pre_hook. + */ +@json +export class PreHookInput { + context: BifrostContext = new BifrostContext() + request: BifrostRequest = new BifrostRequest() +} + +/** + * PreHookOutput is the output for pre_hook. + */ +@json +export class PreHookOutput { + context: BifrostContext = new BifrostContext() + request: BifrostRequest | null = null + short_circuit: PluginShortCircuit | null = null + has_short_circuit: bool = false + error: string = '' +} + +/** + * PostHookInput is the input for post_hook. + */ +@json +export class PostHookInput { + context: BifrostContext = new BifrostContext() + response: BifrostResponse | null = null + error: BifrostError | null = null + has_error: bool = false +} + +/** + * PostHookOutput is the output for post_hook. + */ +@json +export class PostHookOutput { + context: BifrostContext = new BifrostContext() + response: BifrostResponse | null = null + error: BifrostError | null = null + has_error: bool = false + hook_error: string = '' +} diff --git a/examples/plugins/hello-world-wasm-typescript/package-lock.json b/examples/plugins/hello-world-wasm-typescript/package-lock.json new file mode 100644 index 0000000000..b66ee621e1 --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/package-lock.json @@ -0,0 +1,65 @@ +{ + "name": "hello-world-wasm-typescript", + "version": "0.1.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "hello-world-wasm-typescript", + "version": "0.1.0", + "dependencies": { + "json-as": "^1.0.0" + }, + "devDependencies": { + "assemblyscript": "^0.27.29" + } + }, + "node_modules/assemblyscript": { + "version": "0.27.37", + "resolved": "https://registry.npmjs.org/assemblyscript/-/assemblyscript-0.27.37.tgz", + "integrity": "sha512-YtY5k3PiV3SyUQ6gRlR2OCn8dcVRwkpiG/k2T5buoL2ymH/Z/YbaYWbk/f9mO2HTgEtGWjPiAQrIuvA7G/63Gg==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "binaryen": "116.0.0-nightly.20240114", + "long": "^5.2.4" + }, + "bin": { + "asc": "bin/asc.js", + "asinit": "bin/asinit.js" + }, + "engines": { + "node": ">=18", + "npm": ">=10" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/assemblyscript" + } + }, + "node_modules/binaryen": { + "version": "116.0.0-nightly.20240114", + "resolved": "https://registry.npmjs.org/binaryen/-/binaryen-116.0.0-nightly.20240114.tgz", + "integrity": "sha512-0GZrojJnuhoe+hiwji7QFaL3tBlJoA+KFUN7ouYSDGZLSo9CKM8swQX8n/UcbR0d1VuZKU+nhogNzv423JEu5A==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "wasm-opt": "bin/wasm-opt", + "wasm2js": "bin/wasm2js" + } + }, + "node_modules/json-as": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/json-as/-/json-as-1.2.3.tgz", + "integrity": "sha512-yvRkR0Lv8597jHbsf+e93fo+pQctbsiDl7HGuBl71GqKhNT9KtyqtNzal7L7nEIfUq1NNkdACaT1O5D8KtX2zw==", + "license": "MIT" + }, + "node_modules/long": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/long/-/long-5.3.2.tgz", + "integrity": "sha512-mNAgZ1GmyNhD7AuqnTG3/VQ26o760+ZYBPKjPvugO8+nLbYfX6TVpJPseBvopbdY+qpZ/lKUnmEc1LeZYS3QAA==", + "dev": true, + "license": "Apache-2.0" + } + } +} diff --git a/examples/plugins/hello-world-wasm-typescript/package.json b/examples/plugins/hello-world-wasm-typescript/package.json new file mode 100644 index 0000000000..a9e9d98122 --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/package.json @@ -0,0 +1,15 @@ +{ + "name": "hello-world-wasm-typescript", + "version": "0.1.0", + "description": "A Bifrost WASM plugin example in TypeScript (AssemblyScript)", + "scripts": { + "build": "asc assembly/index.ts --outFile build/hello-world.wasm --optimize --exportRuntime", + "build:debug": "asc assembly/index.ts --outFile build/hello-world.wasm --debug --exportRuntime" + }, + "dependencies": { + "json-as": "^1.0.0" + }, + "devDependencies": { + "assemblyscript": "^0.27.29" + } +} diff --git a/examples/plugins/hello-world/main.go b/examples/plugins/hello-world/main.go index 0872b591e4..51bd541aa8 100644 --- a/examples/plugins/hello-world/main.go +++ b/examples/plugins/hello-world/main.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/maximhq/bifrost/core/schemas" - "github.com/valyala/fasthttp" ) func Init(config any) error { @@ -16,14 +15,14 @@ func GetName() string { return "Hello World Plugin" } -func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - fmt.Println("HTTPTransportMiddleware called") - ctx.SetUserValue(schemas.BifrostContextKey("hello-world-plugin-transport-interceptor"), "transport-interceptor-value") - next(ctx) - } - } +func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + fmt.Println("HTTPTransportIntercept called") + // Modify request in-place + req.Headers["X-Hello-World-Plugin"] = "transport-interceptor-value" + // Store value in context for PreHook/PostHook + ctx.SetValue(schemas.BifrostContextKey("hello-world-plugin-transport-interceptor"), "transport-interceptor-value") + // Return nil to continue processing, or return &schemas.HTTPResponse{} to short-circuit + return nil, nil } func PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index 3cd01229d4..60846eb5ae 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -55,6 +55,7 @@ func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientC MCPToolExecutionTimeout: config.MCPToolExecutionTimeout, MCPCodeModeBindingLevel: config.MCPCodeModeBindingLevel, HeaderFilterConfig: config.HeaderFilterConfig, + ConfigHash: config.ConfigHash, } // Delete existing client config and create new one in a transaction return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { @@ -212,6 +213,7 @@ func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, er MCPToolExecutionTimeout: dbConfig.MCPToolExecutionTimeout, MCPCodeModeBindingLevel: dbConfig.MCPCodeModeBindingLevel, HeaderFilterConfig: dbConfig.HeaderFilterConfig, + ConfigHash: dbConfig.ConfigHash, }, nil } diff --git a/framework/plugins/dynamicplugin.go b/framework/plugins/dynamicplugin.go deleted file mode 100644 index 4b4fe47df8..0000000000 --- a/framework/plugins/dynamicplugin.go +++ /dev/null @@ -1,177 +0,0 @@ -package plugins - -import ( - "fmt" - "os" - "plugin" - "strings" - "time" - - "github.com/maximhq/bifrost/core/schemas" - "github.com/valyala/fasthttp" -) - -// DynamicPlugin is the interface for a dynamic plugin -type DynamicPlugin struct { - Enabled bool - Path string - - Config any - - filename string - plugin *plugin.Plugin - - getName func() string - httpTransportMiddleware func() schemas.BifrostHTTPMiddleware - preHook func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) - postHook func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) - cleanup func() error -} - -// GetName returns the name of the plugin -func (dp *DynamicPlugin) GetName() string { - return dp.getName() -} - -// HTTPTransportMiddleware returns the HTTP transport middleware function for this plugin -func (dp *DynamicPlugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - if dp.httpTransportMiddleware == nil { - return nil - } - return dp.httpTransportMiddleware() -} - -// PreHook is not used for dynamic plugins -func (dp *DynamicPlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { - return dp.preHook(ctx, req) -} - -// PostHook is not used for dynamic plugins -func (dp *DynamicPlugin) PostHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - return dp.postHook(ctx, resp, bifrostErr) -} - -// Cleanup is not used for dynamic plugins -func (dp *DynamicPlugin) Cleanup() error { - return dp.cleanup() -} - -// loadDynamicPlugin loads a dynamic plugin from a path -func loadDynamicPlugin(path string, config any) (schemas.Plugin, error) { - dp := &DynamicPlugin{ - Path: path, - } - // Checking if path is URL or file path - if strings.HasPrefix(dp.Path, "http") { - // Download the file - req := fasthttp.AcquireRequest() - defer fasthttp.ReleaseRequest(req) - response := fasthttp.AcquireResponse() - defer fasthttp.ReleaseResponse(response) - - req.SetRequestURI(dp.Path) - req.Header.SetMethod(fasthttp.MethodGet) - req.Header.Set("Accept", "application/octet-stream") - req.Header.Set("Accept-Encoding", "gzip") - req.Header.Set("Accept-Language", "en-US,en;q=0.9") - err := fasthttp.DoTimeout(req, response, 120*time.Second) - if err != nil { - return nil, err - } - if response.StatusCode() != fasthttp.StatusOK { - return nil, fmt.Errorf("failed to download plugin: %d", response.StatusCode()) - } - // Create a unique temporary file for the plugin - tempFile, err := os.CreateTemp(os.TempDir(), "bifrost-plugin-*.so") - if err != nil { - return nil, fmt.Errorf("failed to create temporary file: %w", err) - } - tempPath := tempFile.Name() - // Write the downloaded body to the temporary file - _, err = tempFile.Write(response.Body()) - if err != nil { - tempFile.Close() - os.Remove(tempPath) - return nil, fmt.Errorf("failed to write plugin to temporary file: %w", err) - } - // Close the file - err = tempFile.Close() - if err != nil { - os.Remove(tempPath) - return nil, fmt.Errorf("failed to close temporary file: %w", err) - } - // Set file permissions to be executable - err = os.Chmod(tempPath, 0755) - if err != nil { - os.Remove(tempPath) - return nil, fmt.Errorf("failed to set executable permissions on plugin: %w", err) - } - dp.Path = tempPath - } - plugin, err := plugin.Open(dp.Path) - if err != nil { - return nil, err - } - ok := false - // Looking up for optional Init method - initSym, err := plugin.Lookup("Init") - if err != nil { - if strings.Contains(err.Error(), "symbol Init not found") { - initSym = nil - } else { - return nil, err - } - } - if initSym != nil { - initFunc, ok := initSym.(func(config any) error) - if !ok { - return nil, fmt.Errorf("failed to cast Init to func(config any) error") - } - err := initFunc(config) - if err != nil { - return nil, err - } - } - // Looking up for GetName method - getNameSym, err := plugin.Lookup("GetName") - if err != nil { - return nil, err - } - if dp.getName, ok = getNameSym.(func() string); !ok { - return nil, fmt.Errorf("failed to cast GetName to func() string") - } - // Looking up for HTTPTransportMiddleware method - httpTransportMiddlewareSym, err := plugin.Lookup("HTTPTransportMiddleware") - if err != nil { - return nil, err - } - if dp.httpTransportMiddleware, ok = httpTransportMiddlewareSym.(func() schemas.BifrostHTTPMiddleware); !ok { - return nil, fmt.Errorf("failed to cast HTTPTransportMiddleware to func() schemas.BifrostHTTPMiddleware") - } - // Looking up for PreHook method - preHookSym, err := plugin.Lookup("PreHook") - if err != nil { - return nil, err - } - if dp.preHook, ok = preHookSym.(func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error)); !ok { - return nil, fmt.Errorf("failed to cast PreHook to func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error)") - } - // Looking up for PostHook method - postHookSym, err := plugin.Lookup("PostHook") - if err != nil { - return nil, err - } - if dp.postHook, ok = postHookSym.(func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)); !ok { - return nil, fmt.Errorf("failed to cast PostHook to func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)") - } - // Looking up for Cleanup method - cleanupSym, err := plugin.Lookup("Cleanup") - if err != nil { - return nil, err - } - if dp.cleanup, ok = cleanupSym.(func() error); !ok { - return nil, fmt.Errorf("failed to cast Cleanup to func() error") - } - dp.plugin = plugin - return dp, nil -} diff --git a/framework/plugins/loader.go b/framework/plugins/loader.go new file mode 100644 index 0000000000..1ad11a2eba --- /dev/null +++ b/framework/plugins/loader.go @@ -0,0 +1,8 @@ +package plugins + +import "github.com/maximhq/bifrost/core/schemas" + +// PluginLoader is the contract for a plugin loader +type PluginLoader interface { + LoadDynamicPlugin(path string, config any) (schemas.Plugin, error) +} diff --git a/framework/plugins/main.go b/framework/plugins/main.go index dde83b4ccf..ee1a6dc07e 100644 --- a/framework/plugins/main.go +++ b/framework/plugins/main.go @@ -14,11 +14,12 @@ type DynamicPluginConfig struct { // Config is the configuration for the plugins framework type Config struct { + Plugins []DynamicPluginConfig `json:"plugins"` } // LoadPlugins loads the plugins from the config -func LoadPlugins(config *Config) ([]schemas.Plugin, error) { +func LoadPlugins(loader PluginLoader, config *Config) ([]schemas.Plugin, error) { plugins := []schemas.Plugin{} if config == nil { return plugins, nil @@ -27,7 +28,7 @@ func LoadPlugins(config *Config) ([]schemas.Plugin, error) { if !dp.Enabled { continue } - plugin, err := loadDynamicPlugin(dp.Path, dp.Config) + plugin, err := loader.LoadDynamicPlugin(dp.Path, dp.Config) if err != nil { return nil, err } diff --git a/framework/plugins/soloader.go b/framework/plugins/soloader.go new file mode 100644 index 0000000000..9738fe057d --- /dev/null +++ b/framework/plugins/soloader.go @@ -0,0 +1,94 @@ +package plugins + +import ( + "fmt" + "plugin" + "strings" + + "github.com/maximhq/bifrost/core/schemas" +) + +// SharedObjectPluginLoader is the loader for shared object plugins +type SharedObjectPluginLoader struct{} + +// LoadDynamicPlugin loads a dynamic plugin from a shared object file +func (l *SharedObjectPluginLoader) LoadDynamicPlugin(path string, config any) (schemas.Plugin, error) { + dp := &DynamicPlugin{ + Path: path, + } + // Checking if path is URL or file path + if strings.HasPrefix(dp.Path, "http") { + // Download the file + tempPath, err := DownloadPlugin(dp.Path, ".so") + if err != nil { + return nil, err + } + dp.Path = tempPath + } + plugin, err := plugin.Open(dp.Path) + if err != nil { + return nil, err + } + ok := false + // Looking up for optional Init method + initSym, err := plugin.Lookup("Init") + if err != nil { + if strings.Contains(err.Error(), "symbol Init not found") { + initSym = nil + } else { + return nil, err + } + } + if initSym != nil { + initFunc, ok := initSym.(func(config any) error) + if !ok { + return nil, fmt.Errorf("failed to cast Init to func(config any) error") + } + err := initFunc(config) + if err != nil { + return nil, err + } + } + // Looking up for GetName method + getNameSym, err := plugin.Lookup("GetName") + if err != nil { + return nil, err + } + if dp.getName, ok = getNameSym.(func() string); !ok { + return nil, fmt.Errorf("failed to cast GetName to func() string") + } + // Looking up for HTTPTransportIntercept method + httpTransportInterceptSym, err := plugin.Lookup("HTTPTransportIntercept") + if err != nil { + return nil, err + } + if dp.httpTransportIntercept, ok = httpTransportInterceptSym.(func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error)); !ok { + return nil, fmt.Errorf("failed to cast HTTPTransportIntercept to func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error)") + } + // Looking up for PreHook method + preHookSym, err := plugin.Lookup("PreHook") + if err != nil { + return nil, err + } + if dp.preHook, ok = preHookSym.(func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error)); !ok { + return nil, fmt.Errorf("failed to cast PreHook to func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error)") + } + // Looking up for PostHook method + postHookSym, err := plugin.Lookup("PostHook") + if err != nil { + return nil, err + } + if dp.postHook, ok = postHookSym.(func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)); !ok { + return nil, fmt.Errorf("failed to cast PostHook to func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)") + } + // Looking up for Cleanup method + cleanupSym, err := plugin.Lookup("Cleanup") + if err != nil { + return nil, err + } + if dp.cleanup, ok = cleanupSym.(func() error); !ok { + return nil, fmt.Errorf("failed to cast Cleanup to func() error") + } + dp.plugin = plugin + return dp, nil +} diff --git a/framework/plugins/soplugin.go b/framework/plugins/soplugin.go new file mode 100644 index 0000000000..c909181cb2 --- /dev/null +++ b/framework/plugins/soplugin.go @@ -0,0 +1,52 @@ +package plugins + +import ( + "plugin" + + "github.com/maximhq/bifrost/core/schemas" +) + +// DynamicPlugin is the interface for a dynamic plugin +type DynamicPlugin struct { + Enabled bool + Path string + + Config any + + filename string + plugin *plugin.Plugin + + getName func() string + httpTransportIntercept func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) + preHook func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) + postHook func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) + cleanup func() error +} + +// GetName returns the name of the plugin +func (dp *DynamicPlugin) GetName() string { + return dp.getName() +} + +// HTTPTransportIntercept intercepts HTTP requests at the transport layer for this plugin +func (dp *DynamicPlugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + if dp.httpTransportIntercept == nil { + return nil, nil + } + return dp.httpTransportIntercept(ctx, req) +} + +// PreHook is not used for dynamic plugins +func (dp *DynamicPlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + return dp.preHook(ctx, req) +} + +// PostHook is not used for dynamic plugins +func (dp *DynamicPlugin) PostHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + return dp.postHook(ctx, resp, bifrostErr) +} + +// Cleanup is not used for dynamic plugins +func (dp *DynamicPlugin) Cleanup() error { + return dp.cleanup() +} diff --git a/framework/plugins/dynamicplugin_test.go b/framework/plugins/soplugin_test.go similarity index 87% rename from framework/plugins/dynamicplugin_test.go rename to framework/plugins/soplugin_test.go index 2b37577439..04ad60dd8c 100644 --- a/framework/plugins/dynamicplugin_test.go +++ b/framework/plugins/soplugin_test.go @@ -13,7 +13,6 @@ import ( "github.com/maximhq/bifrost/core/schemas" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/valyala/fasthttp" ) const ( @@ -39,7 +38,8 @@ func TestDynamicPluginLifecycle(t *testing.T) { }, } - plugins, err := LoadPlugins(config) + loader := &SharedObjectPluginLoader{} + plugins, err := LoadPlugins(loader, config) require.NoError(t, err, "Failed to load plugins") require.Len(t, plugins, 1, "Expected exactly one plugin to be loaded") @@ -51,36 +51,31 @@ func TestDynamicPluginLifecycle(t *testing.T) { assert.Equal(t, "Hello World Plugin", name, "Plugin name should match") }) - // Test HTTPTransportMiddleware - t.Run("HTTPTransportMiddleware", func(t *testing.T) { - // Track if the next handler was called - nextHandlerCalled := false + // Test HTTPTransportIntercept + t.Run("HTTPTransportIntercept", func(t *testing.T) { + ctx := context.Background() + pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second) + defer cancel() - // Create a mock next handler - nextHandler := func(ctx *fasthttp.RequestCtx) { - nextHandlerCalled = true + // Create a test HTTP request + req := &schemas.HTTPRequest{ + Method: "POST", + Path: "/api", + Headers: map[string]string{ + "Content-Type": "application/json", + "Authorization": "Bearer token123", + }, + Query: map[string]string{}, + Body: []byte(`{"test": "data"}`), } - // Get the middleware function - middleware := plugin.HTTPTransportMiddleware() - require.NotNil(t, middleware, "HTTPTransportMiddleware should return a middleware function") + // Call HTTPTransportIntercept + resp, err := plugin.HTTPTransportIntercept(pluginCtx, req) + require.NoError(t, err, "HTTPTransportIntercept should not return error") + assert.Nil(t, resp, "HTTPTransportIntercept should return nil response to continue") - // Wrap the next handler with the middleware - wrappedHandler := middleware(nextHandler) - require.NotNil(t, wrappedHandler, "Middleware should return a wrapped handler") - - // Create a test request context - ctx := &fasthttp.RequestCtx{} - ctx.Request.SetRequestURI("http://example.com/api") - ctx.Request.Header.SetMethod("POST") - ctx.Request.Header.Set("Content-Type", "application/json") - ctx.Request.Header.Set("Authorization", "Bearer token123") - - // Call the wrapped handler - wrappedHandler(ctx) - - // Verify the next handler was called - assert.True(t, nextHandlerCalled, "Next handler should have been called") + // Verify headers were modified (hello-world plugin adds a header) + assert.Equal(t, "transport-interceptor-value", req.Headers["X-Hello-World-Plugin"], "Plugin should have added custom header") }) // Test PreHook @@ -183,7 +178,8 @@ func TestLoadPlugins_DisabledPlugin(t *testing.T) { }, } - plugins, err := LoadPlugins(config) + loader := &SharedObjectPluginLoader{} + plugins, err := LoadPlugins(loader, config) require.NoError(t, err, "LoadPlugins should not error for disabled plugins") assert.Len(t, plugins, 0, "No plugins should be loaded when all are disabled") } @@ -210,7 +206,8 @@ func TestLoadPlugins_MultiplePlugins(t *testing.T) { }, } - plugins, err := LoadPlugins(config) + loader := &SharedObjectPluginLoader{} + plugins, err := LoadPlugins(loader, config) require.NoError(t, err, "LoadPlugins should succeed for multiple plugins") assert.Len(t, plugins, 2, "Two plugins should be loaded") @@ -232,7 +229,8 @@ func TestLoadPlugins_InvalidPath(t *testing.T) { }, } - plugins, err := LoadPlugins(config) + loader := &SharedObjectPluginLoader{} + plugins, err := LoadPlugins(loader, config) assert.Error(t, err, "LoadPlugins should return error for invalid path") assert.Nil(t, plugins, "No plugins should be loaded on error") } @@ -242,8 +240,8 @@ func TestLoadPlugins_EmptyConfig(t *testing.T) { config := &Config{ Plugins: []DynamicPluginConfig{}, } - - plugins, err := LoadPlugins(config) + loader := &SharedObjectPluginLoader{} + plugins, err := LoadPlugins(loader, config) require.NoError(t, err, "LoadPlugins should succeed with empty config") assert.Len(t, plugins, 0, "No plugins should be loaded with empty config") } @@ -253,7 +251,8 @@ func TestDynamicPlugin_ContextPropagation(t *testing.T) { pluginPath := buildHelloWorldPlugin(t) defer cleanupHelloWorldPlugin(t) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(t, err, "Failed to load plugin") // Create a context with a value @@ -288,7 +287,8 @@ func TestDynamicPlugin_ConcurrentCalls(t *testing.T) { pluginPath := buildHelloWorldPlugin(t) defer cleanupHelloWorldPlugin(t) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(t, err, "Failed to load plugin") // Run multiple goroutines calling plugin methods @@ -395,7 +395,8 @@ func TestLoadDynamicPlugin_DirectCall(t *testing.T) { pluginPath := buildHelloWorldPlugin(t) defer cleanupHelloWorldPlugin(t) - plugin, err := loadDynamicPlugin(pluginPath, map[string]interface{}{ + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, map[string]interface{}{ "test": "config", }) require.NoError(t, err, "loadDynamicPlugin should succeed") @@ -412,7 +413,8 @@ func TestDynamicPlugin_NilConfig(t *testing.T) { pluginPath := buildHelloWorldPlugin(t) defer cleanupHelloWorldPlugin(t) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(t, err, "loadDynamicPlugin should succeed with nil config") assert.NotNil(t, plugin, "Plugin should not be nil") @@ -426,7 +428,8 @@ func TestDynamicPlugin_ShortCircuitNil(t *testing.T) { pluginPath := buildHelloWorldPlugin(t) defer cleanupHelloWorldPlugin(t) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(t, err, "Failed to load plugin") ctx := context.Background() @@ -451,7 +454,8 @@ func BenchmarkDynamicPlugin_PreHook(b *testing.B) { pluginPath := buildHelloWorldPluginForBenchmark(b) defer cleanupHelloWorldPluginForBenchmark(b) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(b, err, "Failed to load plugin") ctx := context.Background() @@ -476,7 +480,8 @@ func BenchmarkDynamicPlugin_PostHook(b *testing.B) { pluginPath := buildHelloWorldPluginForBenchmark(b) defer cleanupHelloWorldPluginForBenchmark(b) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(b, err, "Failed to load plugin") ctx := context.Background() @@ -499,7 +504,8 @@ func BenchmarkDynamicPlugin_GetName(b *testing.B) { pluginPath := buildHelloWorldPluginForBenchmark(b) defer cleanupHelloWorldPluginForBenchmark(b) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(b, err, "Failed to load plugin") b.ResetTimer() @@ -566,7 +572,8 @@ func TestDynamicPlugin_GetNameNotEmpty(t *testing.T) { pluginPath := buildHelloWorldPlugin(t) defer cleanupHelloWorldPlugin(t) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(t, err, "Failed to load plugin") name := plugin.GetName() diff --git a/framework/plugins/utils.go b/framework/plugins/utils.go new file mode 100644 index 0000000000..f9f4465c5c --- /dev/null +++ b/framework/plugins/utils.go @@ -0,0 +1,72 @@ +package plugins + +import ( + "fmt" + "os" + "time" + + "github.com/valyala/fasthttp" +) + +// DownloadPlugin downloads a plugin from a URL and returns the local file path +func DownloadPlugin(url string, extension string) (string, error) { + req := fasthttp.AcquireRequest() + defer fasthttp.ReleaseRequest(req) + response := fasthttp.AcquireResponse() + defer fasthttp.ReleaseResponse(response) + + req.SetRequestURI(url) + req.Header.SetMethod(fasthttp.MethodGet) + req.Header.Set("Accept", "application/octet-stream") + req.Header.Set("Accept-Encoding", "gzip") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + + err := fasthttp.DoTimeout(req, response, 120*time.Second) + if err != nil { + return "", err + } + + if response.StatusCode() != fasthttp.StatusOK { + return "", fmt.Errorf("failed to download plugin: %d", response.StatusCode()) + } + + // Decompress the response body if it was gzip/deflate compressed + // BodyUncompressed handles both gzip and deflate encodings based on Content-Encoding header + body, err := response.BodyUncompressed() + if err != nil { + return "", fmt.Errorf("failed to decompress response body: %w", err) + } + + // Create a unique temporary file for the plugin + tempFile, err := os.CreateTemp(os.TempDir(), "bifrost-plugin-*"+extension) + if err != nil { + return "", fmt.Errorf("failed to create temporary file: %w", err) + } + tempPath := tempFile.Name() + + // Write the downloaded body to the temporary file + _, err = tempFile.Write(body) + if err != nil { + tempFile.Close() + os.Remove(tempPath) + return "", fmt.Errorf("failed to write plugin to temporary file: %w", err) + } + + // Close the file + err = tempFile.Close() + if err != nil { + os.Remove(tempPath) + return "", fmt.Errorf("failed to close temporary file: %w", err) + } + + // Set file permissions to be executable (for .so files) + if extension == ".so" { + err = os.Chmod(tempPath, 0755) + if err != nil { + os.Remove(tempPath) + return "", fmt.Errorf("failed to set executable permissions on plugin: %w", err) + } + } + + return tempPath, nil +} diff --git a/plugins/governance/go.mod b/plugins/governance/go.mod index 83abaf9407..871d87bc48 100644 --- a/plugins/governance/go.mod +++ b/plugins/governance/go.mod @@ -9,7 +9,6 @@ require ( github.com/maximhq/bifrost/core v1.3.3 github.com/maximhq/bifrost/framework v1.2.3 github.com/stretchr/testify v1.11.1 - github.com/valyala/fasthttp v1.68.0 ) require ( @@ -102,6 +101,7 @@ require ( github.com/spf13/cast v1.10.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.68.0 // indirect github.com/weaviate/weaviate v1.34.5 // indirect github.com/weaviate/weaviate-go-client/v5 v5.6.0 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect diff --git a/plugins/governance/main.go b/plugins/governance/main.go index 51ce382e63..d8d6000d9f 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -16,7 +16,6 @@ import ( "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/modelcatalog" - "github.com/valyala/fasthttp" ) // PluginName is the name of the governance plugin @@ -41,7 +40,7 @@ type InMemoryStore interface { type BaseGovernancePlugin interface { GetName() string - HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware + HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) PostHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) Cleanup() error @@ -230,13 +229,13 @@ func (p *GovernancePlugin) GetName() string { return PluginName } -func parseVirtualKey(ctx *fasthttp.RequestCtx) *string { +func parseVirtualKeyFromHTTPRequest(req *schemas.HTTPRequest) *string { var virtualKeyValue string - vkHeader := ctx.Request.Header.Peek("x-bf-vk") - if string(vkHeader) != "" { - return bifrost.Ptr(string(vkHeader)) + vkHeader := req.Headers["x-bf-vk"] + if vkHeader != "" { + return bifrost.Ptr(vkHeader) } - authHeader := string(ctx.Request.Header.Peek("Authorization")) + authHeader := req.Headers["Authorization"] if authHeader != "" { if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { authHeaderValue := strings.TrimSpace(authHeader[7:]) // Remove "Bearer " prefix @@ -248,69 +247,59 @@ func parseVirtualKey(ctx *fasthttp.RequestCtx) *string { if virtualKeyValue != "" { return bifrost.Ptr(virtualKeyValue) } - xAPIKey := string(ctx.Request.Header.Peek("x-api-key")) + xAPIKey := req.Headers["x-api-key"] if xAPIKey != "" && strings.HasPrefix(strings.ToLower(xAPIKey), VirtualKeyPrefix) { return bifrost.Ptr(xAPIKey) } // Checking x-goog-api-key header - xGoogleAPIKey := string(ctx.Request.Header.Peek("x-goog-api-key")) + xGoogleAPIKey := req.Headers["x-goog-api-key"] if xGoogleAPIKey != "" && strings.HasPrefix(strings.ToLower(xGoogleAPIKey), VirtualKeyPrefix) { return bifrost.Ptr(xGoogleAPIKey) } return nil } -// HTTPTransportMiddleware intercepts requests before they are processed (governance decision point) -func (p *GovernancePlugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - virtualKeyValue := parseVirtualKey(ctx) - if virtualKeyValue == nil { - next(ctx) - return - } - // Get the virtual key from the store - virtualKey, ok := p.store.GetVirtualKey(*virtualKeyValue) - if !ok || virtualKey == nil || !virtualKey.IsActive { - next(ctx) - return - } - headers, err := p.addMCPIncludeTools(nil, virtualKey) - if err != nil { - p.logger.Error("failed to add MCP include tools: %v", err) - next(ctx) - return - } - for header, value := range headers { - ctx.Request.Header.Set(header, value) - } - if ctx.Request.Body() == nil { - next(ctx) - return - } - var payload map[string]any - err = sonic.Unmarshal(ctx.Request.Body(), &payload) - if err != nil { - p.logger.Error("failed to unmarshal request body to check for virtual key: %v", err) - next(ctx) - return - } - payload, err = p.loadBalanceProvider(payload, virtualKey) - if err != nil { - p.logger.Error("failed to load balance provider: %v", err) - next(ctx) - return - } - body, err := sonic.Marshal(payload) - if err != nil { - p.logger.Error("failed to marshal request body to check for virtual key: %v", err) - next(ctx) - return - } - ctx.Request.SetBody(body) - next(ctx) - } +// HTTPTransportIntercept intercepts requests before they are processed (governance decision point) +// It modifies the request in-place and returns nil to continue, or an HTTPResponse to short-circuit. +func (p *GovernancePlugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + virtualKeyValue := parseVirtualKeyFromHTTPRequest(req) + if virtualKeyValue == nil { + return nil, nil + } + // Get the virtual key from the store + virtualKey, ok := p.store.GetVirtualKey(*virtualKeyValue) + if !ok || virtualKey == nil || !virtualKey.IsActive { + return nil, nil + } + headers, err := p.addMCPIncludeTools(nil, virtualKey) + if err != nil { + p.logger.Error("failed to add MCP include tools: %v", err) + return nil, nil + } + for header, value := range headers { + req.Headers[header] = value + } + if len(req.Body) == 0 { + return nil, nil + } + var payload map[string]any + err = sonic.Unmarshal(req.Body, &payload) + if err != nil { + p.logger.Error("failed to unmarshal request body to check for virtual key: %v", err) + return nil, nil + } + payload, err = p.loadBalanceProvider(payload, virtualKey) + if err != nil { + p.logger.Error("failed to load balance provider: %v", err) + return nil, nil + } + body, err := sonic.Marshal(payload) + if err != nil { + p.logger.Error("failed to marshal request body to check for virtual key: %v", err) + return nil, nil } + req.Body = body + return nil, nil } // loadBalanceProvider loads balances the provider for the request diff --git a/plugins/jsonparser/main.go b/plugins/jsonparser/main.go index d4dbcb5c02..5ce59be483 100644 --- a/plugins/jsonparser/main.go +++ b/plugins/jsonparser/main.go @@ -83,9 +83,9 @@ func (p *JsonParserPlugin) GetName() string { return PluginName } -// HTTPTransportMiddleware is not used for this plugin -func (p *JsonParserPlugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return nil +// HTTPTransportIntercept is not used for this plugin +func (p *JsonParserPlugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil } // PreHook is not used for this plugin as we only process responses diff --git a/plugins/logging/main.go b/plugins/logging/main.go index 7da12e7348..8678a07172 100644 --- a/plugins/logging/main.go +++ b/plugins/logging/main.go @@ -191,9 +191,9 @@ func (p *LoggerPlugin) GetName() string { return PluginName } -// HTTPTransportMiddleware is not used for this plugin -func (p *LoggerPlugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return nil +// HTTPTransportIntercept is not used for this plugin +func (p *LoggerPlugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil } // PreHook is called before a request is processed - FULLY ASYNC, NO DATABASE I/O diff --git a/plugins/maxim/main.go b/plugins/maxim/main.go index 1d15f32eca..bf1abf823a 100644 --- a/plugins/maxim/main.go +++ b/plugins/maxim/main.go @@ -156,9 +156,9 @@ func (plugin *Plugin) GetName() string { return PluginName } -// HTTPTransportMiddleware is not used for this plugin -func (plugin *Plugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return nil +// HTTPTransportIntercept is not used for this plugin +func (plugin *Plugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil } // getEffectiveLogRepoID determines which single log repo ID to use based on priority: diff --git a/plugins/mocker/main.go b/plugins/mocker/main.go index 32cfdb0a65..dbfc183803 100644 --- a/plugins/mocker/main.go +++ b/plugins/mocker/main.go @@ -478,9 +478,9 @@ func (p *MockerPlugin) GetName() string { return PluginName } -// HTTPTransportMiddleware is not used for this plugin -func (p *MockerPlugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return nil +// HTTPTransportIntercept is not used for this plugin +func (p *MockerPlugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil } // PreHook intercepts requests and applies mocking rules based on configuration diff --git a/plugins/otel/main.go b/plugins/otel/main.go index 31aab3dc00..cca58c13ac 100644 --- a/plugins/otel/main.go +++ b/plugins/otel/main.go @@ -145,9 +145,9 @@ func (p *OtelPlugin) GetName() string { return PluginName } -// HTTPTransportMiddleware is not used for this plugin -func (p *OtelPlugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return nil +// HTTPTransportIntercept is not used for this plugin +func (p *OtelPlugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil } // ValidateConfig function for the OTEL plugin diff --git a/plugins/semanticcache/main.go b/plugins/semanticcache/main.go index 3c47b00b00..0a55299432 100644 --- a/plugins/semanticcache/main.go +++ b/plugins/semanticcache/main.go @@ -335,9 +335,9 @@ func (plugin *Plugin) GetName() string { return PluginName } -// HTTPTransportMiddleware is not used for this plugin -func (plugin *Plugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return nil +// HTTPTransportIntercept is not used for this plugin +func (plugin *Plugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil } // PreHook is called before a request is processed by Bifrost. diff --git a/plugins/telemetry/main.go b/plugins/telemetry/main.go index d80eadad59..dfbdafc7a0 100644 --- a/plugins/telemetry/main.go +++ b/plugins/telemetry/main.go @@ -276,9 +276,9 @@ func (p *PrometheusPlugin) GetName() string { return PluginName } -// HTTPTransportMiddleware is not used for this plugin -func (p *PrometheusPlugin) HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return nil +// HTTPTransportIntercept is not used for this plugin +func (p *PrometheusPlugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil } // PreHook records the start time of the request in the context. diff --git a/tests/integrations/config.json b/tests/integrations/config.json index b5a75f4f65..2e2730b924 100644 --- a/tests/integrations/config.json +++ b/tests/integrations/config.json @@ -180,8 +180,8 @@ "*" ], "enable_logging": true, - "enable_governance": true, - "enforce_governance_header": true, + "enable_governance": false, + "enforce_governance_header": false, "allow_direct_keys": false, "max_request_body_size_mb": 100, "enable_litellm_fallbacks": false diff --git a/transports/bifrost-http/handlers/middlewares.go b/transports/bifrost-http/handlers/middlewares.go index f084f12e04..9b6a17843a 100644 --- a/transports/bifrost-http/handlers/middlewares.go +++ b/transports/bifrost-http/handlers/middlewares.go @@ -46,7 +46,9 @@ func CorsMiddleware(config *lib.Config) schemas.BifrostHTTPMiddleware { } } -// TransportInterceptorMiddleware collects all plugin HTTP transport middleware and chains them. +// TransportInterceptorMiddleware runs all plugin HTTP transport interceptors. +// It converts the fasthttp request to a serializable HTTPRequest, runs all plugin interceptors, +// and applies any modifications back to the fasthttp context. func TransportInterceptorMiddleware(config *lib.Config) schemas.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { @@ -55,18 +57,101 @@ func TransportInterceptorMiddleware(config *lib.Config) schemas.BifrostHTTPMiddl next(ctx) return } - pluginsMiddlewareChain := []schemas.BifrostHTTPMiddleware{} + + // Get or create BifrostContext from fasthttp context + bifrostCtx := getBifrostContextFromFastHTTP(ctx) + // Acquire pooled request + req := schemas.AcquireHTTPRequest() + defer schemas.ReleaseHTTPRequest(req) + fasthttpToHTTPRequest(ctx, req) + // Run plugin interceptors for _, plugin := range plugins { - middleware := plugin.HTTPTransportMiddleware() - // Collect plugin HTTP transport middleware - if middleware == nil { - continue + resp, err := plugin.HTTPTransportIntercept(bifrostCtx, req) + if err != nil { + // Short-circuit with error + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBodyString(err.Error()) + return } - pluginsMiddlewareChain = append(pluginsMiddlewareChain, middleware) + if resp != nil { + // Short-circuit with response + applyHTTPResponseToCtx(ctx, resp) + return + } + // If we got here, the plugin may have modified req in-place + } + // Apply modifications back to fasthttp context + applyHTTPRequestToCtx(ctx, req) + // Adding user values + for key, value := range bifrostCtx.GetUserValues() { + ctx.SetUserValue(key, value) } - lib.ChainMiddlewares(next, pluginsMiddlewareChain...)(ctx) + next(ctx) + } + } +} + +// getBifrostContextFromFastHTTP gets or creates a BifrostContext from fasthttp context. +func getBifrostContextFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.BifrostContext { + return schemas.NewBifrostContext(ctx, schemas.NoDeadline) +} + +// fasthttpToHTTPRequest populates a pooled HTTPRequest from fasthttp context. +func fasthttpToHTTPRequest(ctx *fasthttp.RequestCtx, req *schemas.HTTPRequest) { + req.Method = string(ctx.Method()) + req.Path = string(ctx.Path()) + + // Copy headers + for key, values := range ctx.Request.Header.All() { + req.Headers[string(key)] = string(values) + } + + // Copy query params + for key, values := range ctx.Request.URI().QueryArgs().All() { + for _, value := range values { + req.Query[string(key)] = string(value) } } + + // Copy body + body := ctx.Request.Body() + if len(body) > 0 { + req.Body = make([]byte, len(body)) + copy(req.Body, body) + } +} + +// applyHTTPRequestToCtx applies modifications from HTTPRequest back to fasthttp context. +func applyHTTPRequestToCtx(ctx *fasthttp.RequestCtx, req *schemas.HTTPRequest) { + // If path/method is different, throw error + if req.Method != string(ctx.Method()) || req.Path != string(ctx.Path()) { + logger.Error("request method/path mismatch: %s %s != %s %s", req.Method, req.Path, string(ctx.Method()), string(ctx.Path())) + SendError(ctx, fasthttp.StatusConflict, "request method/path was modified by a plugin, this is not allowed") + return + } + // Apply headers + for key, value := range req.Headers { + ctx.Request.Header.Set(key, value) + } + // Apply query params + for key, value := range req.Query { + ctx.Request.URI().QueryArgs().Set(key, value) + } + // Apply body if set + if req.Body != nil { + ctx.Request.SetBody(req.Body) + } +} + +// applyHTTPResponseToCtx writes a short-circuit response to fasthttp context. +func applyHTTPResponseToCtx(ctx *fasthttp.RequestCtx, resp *schemas.HTTPResponse) { + ctx.SetStatusCode(resp.StatusCode) + for key, value := range resp.Headers { + ctx.Response.Header.Set(key, value) + } + if resp.Body != nil { + ctx.SetBody(resp.Body) + } } // validateSession checks if a session token is valid diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 825d8756e5..0ecea411eb 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -25,6 +25,7 @@ import ( "github.com/maximhq/bifrost/framework/encrypt" "github.com/maximhq/bifrost/framework/logstore" "github.com/maximhq/bifrost/framework/modelcatalog" + plugins "github.com/maximhq/bifrost/framework/plugins" "github.com/maximhq/bifrost/framework/vectorstore" "github.com/maximhq/bifrost/plugins/semanticcache" "gorm.io/gorm" @@ -232,7 +233,8 @@ type Config struct { EnvKeys map[string][]configstore.EnvKeyInfo // Plugin configs - atomic for lock-free reads with CAS updates - Plugins atomic.Pointer[[]schemas.Plugin] + Plugins atomic.Pointer[[]schemas.Plugin] + PluginLoader plugins.PluginLoader // Plugin configs from config file/database PluginConfigs []*schemas.PluginConfig @@ -473,7 +475,7 @@ func initStoresFromFile(ctx context.Context, config *Config, configData *ConfigD return nil } -// loadClientConfigFromFile loads and merges client config from file with store +// loadClientConfigFromFile loads and merges client config from file with store using hash-based reconciliation func loadClientConfigFromFile(ctx context.Context, config *Config, configData *ConfigData) { var clientConfig *configstore.ClientConfig var err error @@ -485,77 +487,30 @@ func loadClientConfigFromFile(ctx context.Context, config *Config, configData *C } } - if clientConfig != nil { - config.ClientConfig = *clientConfig - // For backward compatibility, handle cases where max request body size is not set - if config.ClientConfig.MaxRequestBodySizeMB == 0 { - config.ClientConfig.MaxRequestBodySizeMB = DefaultClientConfig.MaxRequestBodySizeMB - } - - // Merge with config file if present - if configData.Client != nil { - logger.Debug("merging client config from config file with store") - // DB takes priority, but fill in empty/zero values from config file - if config.ClientConfig.InitialPoolSize == 0 && configData.Client.InitialPoolSize != 0 { - config.ClientConfig.InitialPoolSize = configData.Client.InitialPoolSize - } - if len(config.ClientConfig.PrometheusLabels) == 0 && len(configData.Client.PrometheusLabels) > 0 { - config.ClientConfig.PrometheusLabels = configData.Client.PrometheusLabels - } - if len(config.ClientConfig.AllowedOrigins) == 0 && len(configData.Client.AllowedOrigins) > 0 { - config.ClientConfig.AllowedOrigins = configData.Client.AllowedOrigins - } - if config.ClientConfig.MaxRequestBodySizeMB == 0 && configData.Client.MaxRequestBodySizeMB != 0 { - config.ClientConfig.MaxRequestBodySizeMB = configData.Client.MaxRequestBodySizeMB - } - // Boolean fields: only override if DB has false and config file has true - if !config.ClientConfig.DropExcessRequests && configData.Client.DropExcessRequests { - config.ClientConfig.DropExcessRequests = configData.Client.DropExcessRequests - } - if !config.ClientConfig.EnableLogging && configData.Client.EnableLogging { - config.ClientConfig.EnableLogging = configData.Client.EnableLogging - } - if !config.ClientConfig.DisableContentLogging && configData.Client.DisableContentLogging { - config.ClientConfig.DisableContentLogging = configData.Client.DisableContentLogging - } - if !config.ClientConfig.EnableGovernance && configData.Client.EnableGovernance { - config.ClientConfig.EnableGovernance = configData.Client.EnableGovernance - } - if !config.ClientConfig.EnforceGovernanceHeader && configData.Client.EnforceGovernanceHeader { - config.ClientConfig.EnforceGovernanceHeader = configData.Client.EnforceGovernanceHeader - } - if !config.ClientConfig.AllowDirectKeys && configData.Client.AllowDirectKeys { - config.ClientConfig.AllowDirectKeys = configData.Client.AllowDirectKeys - } - if !config.ClientConfig.EnableLiteLLMFallbacks && configData.Client.EnableLiteLLMFallbacks { - config.ClientConfig.EnableLiteLLMFallbacks = configData.Client.EnableLiteLLMFallbacks - } - if config.ClientConfig.MCPAgentDepth == 0 && configData.Client.MCPAgentDepth != 0 { - config.ClientConfig.MCPAgentDepth = configData.Client.MCPAgentDepth - } - if config.ClientConfig.MCPToolExecutionTimeout == 0 && configData.Client.MCPToolExecutionTimeout != 0 { - config.ClientConfig.MCPToolExecutionTimeout = configData.Client.MCPToolExecutionTimeout - } - if config.ClientConfig.MCPCodeModeBindingLevel == "" && configData.Client.MCPCodeModeBindingLevel != "" { - config.ClientConfig.MCPCodeModeBindingLevel = configData.Client.MCPCodeModeBindingLevel - } - // Update store with merged config - if config.ConfigStore != nil { - logger.Debug("updating merged client config in store") - if err = config.ConfigStore.UpdateClientConfig(ctx, &config.ClientConfig); err != nil { - logger.Warn("failed to update merged client config: %v", err) - } - } - } - } else { + // Case 1: No config in DB - use file config (or defaults) + if clientConfig == nil { logger.Debug("client config not found in store, using config file") if configData.Client != nil { config.ClientConfig = *configData.Client if config.ClientConfig.MaxRequestBodySizeMB == 0 { config.ClientConfig.MaxRequestBodySizeMB = DefaultClientConfig.MaxRequestBodySizeMB } + // Generate hash for the file config + fileHash, hashErr := configData.Client.GenerateClientConfigHash() + if hashErr != nil { + logger.Warn("failed to generate client config hash: %v", hashErr) + } else { + config.ClientConfig.ConfigHash = fileHash + } } else { config.ClientConfig = DefaultClientConfig + // Generate hash for default config + defaultHash, hashErr := config.ClientConfig.GenerateClientConfigHash() + if hashErr != nil { + logger.Warn("failed to generate default client config hash: %v", hashErr) + } else { + config.ClientConfig.ConfigHash = defaultHash + } } if config.ConfigStore != nil { logger.Debug("updating client config in store") @@ -563,6 +518,48 @@ func loadClientConfigFromFile(ctx context.Context, config *Config, configData *C logger.Warn("failed to update client config: %v", err) } } + return + } + + // Case 2: Config exists in DB + config.ClientConfig = *clientConfig + // For backward compatibility, handle cases where max request body size is not set + if config.ClientConfig.MaxRequestBodySizeMB == 0 { + config.ClientConfig.MaxRequestBodySizeMB = DefaultClientConfig.MaxRequestBodySizeMB + } + + // Case 2a: No file config - use DB config as-is + if configData.Client == nil { + logger.Debug("no client config in file, using DB config") + return + } + + // Case 2b: Both DB and file config exist - use hash-based reconciliation + fileHash, hashErr := configData.Client.GenerateClientConfigHash() + if hashErr != nil { + logger.Warn("failed to generate client config hash from file: %v", hashErr) + return + } + + if clientConfig.ConfigHash != fileHash { + // Hash mismatch - config.json was changed, sync from file + logger.Debug("client config hash mismatch, syncing from config file") + config.ClientConfig = *configData.Client + config.ClientConfig.ConfigHash = fileHash + // Apply defaults for zero values + if config.ClientConfig.MaxRequestBodySizeMB == 0 { + config.ClientConfig.MaxRequestBodySizeMB = DefaultClientConfig.MaxRequestBodySizeMB + } + // Update store with file config + if config.ConfigStore != nil { + logger.Debug("updating client config in store from file") + if err = config.ConfigStore.UpdateClientConfig(ctx, &config.ClientConfig); err != nil { + logger.Warn("failed to update client config: %v", err) + } + } + } else { + // Hash matches - keep DB config (preserves UI changes) + logger.Debug("client config hash matches, keeping DB config") } } diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 61b87cfa26..3c4c3a9b5a 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -220,7 +220,7 @@ func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, path *string if path != nil { logger.Info("loading dynamic plugin %s from path %s", name, *path) // Load dynamic plugin - plugins, err := dynamicPlugins.LoadPlugins(&dynamicPlugins.Config{ + plugins, err := dynamicPlugins.LoadPlugins(bifrostConfig.PluginLoader, &dynamicPlugins.Config{ Plugins: []dynamicPlugins.DynamicPluginConfig{ { Path: *path, @@ -1194,6 +1194,8 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to load config %v", err) } + // Initializing plugin loader + s.Config.PluginLoader = &dynamicPlugins.SharedObjectPluginLoader{} // Initialize log retention cleaner if log store is configured if s.Config.LogsStore != nil { // If log retention days remains 0, then we wont be initializing the log retention cleaner