diff --git a/.github/workflows/transport-ci.yml b/.github/workflows/transport-ci.yml new file mode 100644 index 0000000000..b350d4ba19 --- /dev/null +++ b/.github/workflows/transport-ci.yml @@ -0,0 +1,200 @@ +name: Transport CI - Dependency Update and Docker Build + +on: + push: + tags: + - "core/v*" # Triggers dependency update + - "transports/v*" # Triggers docker build (for manual tags) + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + REGISTRY: docker.io + ACCOUNT: maximhq + IMAGE_NAME: bifrost + +jobs: + update-transport-dependency: + if: startsWith(github.ref, 'refs/tags/core/v') + runs-on: ubuntu-latest + permissions: + contents: write + outputs: + new_transport_tag: ${{ steps.next_version.outputs.new_tag }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + ref: main + fetch-depth: 0 + fetch-tags: true + token: ${{ secrets.GH_TOKEN }} + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "1.24.1" + + - name: Get and validate core version from tag + id: get_version + run: | + TAG_NAME=${GITHUB_REF#refs/tags/core/} + + # Validate core tag format + if ! echo "$TAG_NAME" | grep -qE '^v[0-9]+\.[0-9]+\.[0-9]+$'; then + echo "Error: Invalid core tag format 'core/$TAG_NAME'. Expected format: core/vMAJOR.MINOR.PATCH" + exit 1 + fi + + echo "version=${TAG_NAME}" >> $GITHUB_OUTPUT + echo "Core version: ${TAG_NAME}" + + - name: Configure Git + run: | + git config user.name "GitHub Actions Bot" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Get latest transport version and increment + id: next_version + run: | + # Get the latest transport tag (using transports/ prefix to match docker build) + LATEST_TAG=$(git tag -l 'transports/v*' | sort -V | tail -n 1) + if [ -z "$LATEST_TAG" ]; then + # If no transport tag exists, start with v0.1.0 + NEW_TAG="transports/v0.1.0" + else + # Extract version numbers + VERSION=${LATEST_TAG#transports/v} + + # Validate version format + if ! echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+$'; then + echo "Error: Invalid tag format '$LATEST_TAG'. Expected format: transports/vMAJOR.MINOR.PATCH" + exit 1 + fi + + MAJOR=$(echo $VERSION | cut -d. -f1) + MINOR=$(echo $VERSION | cut -d. -f2) + PATCH=$(echo $VERSION | cut -d. -f3) + + # Increment patch version + NEW_PATCH=$((PATCH + 1)) + NEW_TAG="transports/v${MAJOR}.${MINOR}.${NEW_PATCH}" + fi + + # Check if the new tag already exists + if git tag --list | grep -q "^${NEW_TAG}$"; then + echo "Error: Tag '$NEW_TAG' already exists!" + exit 1 + fi + + echo "new_tag=${NEW_TAG}" >> $GITHUB_OUTPUT + echo "New transport version will be: ${NEW_TAG}" + + - name: Update transport dependency + working-directory: transports + run: | + echo "Updating core dependency to ${{ steps.get_version.outputs.version }}" + if ! go get github.com/maximhq/bifrost/core@${{ steps.get_version.outputs.version }}; then + echo "Error: Failed to fetch core version ${{ steps.get_version.outputs.version }}" + exit 1 + fi + go mod tidy + + - name: Build transport + working-directory: transports + run: go build ./... + + - name: Commit and push changes + run: | + git add transports/go.mod transports/go.sum + if git diff --staged --quiet; then + echo "No changes to commit. Dependency is already up to date." + else + git commit -m "chore: update transport's core dependency to ${{ steps.get_version.outputs.version }}" + git push + fi + + - name: Create and push transport tag + run: | + git tag ${{ steps.next_version.outputs.new_tag }} + git push origin ${{ steps.next_version.outputs.new_tag }} + + build-and-push-docker: + if: always() && needs.update-transport-dependency.result != 'failure' && (startsWith(github.ref, 'refs/tags/transports/v') || needs.update-transport-dependency.result == 'success') + needs: [update-transport-dependency] + runs-on: ubuntu-latest + permissions: + contents: read + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Extract and validate metadata + id: meta + run: | + # Determine the tag to use + if [ "${{ needs.update-transport-dependency.outputs.new_transport_tag }}" != "" ]; then + # Use the tag created by dependency update + TAG="${{ needs.update-transport-dependency.outputs.new_transport_tag }}" + else + # Use the tag that triggered this workflow (manual tag) + TAG=${GITHUB_REF#refs/tags/} + fi + + echo "tag=${TAG}" >> $GITHUB_OUTPUT + + # Extract version from tag + VERSION=${TAG#transports/v} + + # Validate version format + if ! echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+$'; then + echo "Error: Invalid tag format '$TAG'. Expected format: transports/vMAJOR.MINOR.PATCH" + exit 1 + fi + + # Create image tags (Docker tags cannot contain slashes, so use version only) + echo "tags<> $GITHUB_OUTPUT + echo "${{ env.REGISTRY }}/${{ env.ACCOUNT }}/${{ env.IMAGE_NAME }}:v${VERSION}" >> $GITHUB_OUTPUT + echo "${{ env.REGISTRY }}/${{ env.ACCOUNT }}/${{ env.IMAGE_NAME }}:latest" >> $GITHUB_OUTPUT + echo "EOF" >> $GITHUB_OUTPUT + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Generate timestamp + id: timestamp + run: echo "created_at=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT + + - name: Build and push Docker image + uses: docker/build-push-action@v5 + with: + context: ./transports + file: ./transports/Dockerfile + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: | + org.opencontainers.image.title=Bifrost LLM Gateway (HTTP) + org.opencontainers.image.description=The fastest LLM gateway written in Go. Learn more here: https://github.com/maximhq/bifrost + org.opencontainers.image.source=${{ github.server_url }}/${{ github.repository }} + org.opencontainers.image.version=${{ steps.meta.outputs.tag }} + org.opencontainers.image.created=${{ steps.timestamp.outputs.created_at }} + org.opencontainers.image.revision=${{ github.sha }} + build-args: | + TRANSPORT_TYPE=http + platforms: linux/amd64,linux/arm64 + cache-from: type=gha + cache-to: type=gha,mode=max + + - name: Image digest + run: echo "Image pushed successfully with tags from previous step" diff --git a/.gitignore b/.gitignore index 48303bc0a5..1ce9c03db4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ .env .vscode .DS_Store +*_creds* +**/venv/ +**/__pycache__/** +private.* diff --git a/README.md b/README.md index e227739c01..774011154d 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,186 @@ # Bifrost +[![Go Report Card](https://goreportcard.com/badge/github.com/maximhq/bifrost/core)](https://goreportcard.com/report/github.com/maximhq/bifrost/core) + Bifrost is an open-source middleware that serves as a unified gateway to various AI model providers, enabling seamless integration and fallback mechanisms for your AI-powered applications. +## ⚑ Quickstart + +### Prerequisites + +- Go 1.23 or higher (not needed if using Docker) +- Access to at least one AI model provider (OpenAI, Anthropic, etc.) +- API keys for the providers you wish to use + +### A. Using Bifrost as an HTTP Server + +1. **Create `config.json`**: This file should contain your provider settings and API keys. + + ```json + { + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o-mini"], + "weight": 1.0 + } + ] + } + } + } + ``` + +2. **Set Up Your Environment**: Add your environment variable to the session. + + ```bash + export OPENAI_API_KEY=your_openai_api_key + export ANTHROPIC_API_KEY=your_anthropic_api_key + ``` + + Note: Ensure you add all variables stated in your `config.json` file. + +3. **Start the Bifrost HTTP Server**: + + You can run the server using either a Go Binary or Docker (if Go is not installed). + + #### i) Using Go Binary + + - Install the transport package: + + ```bash + go install github.com/maximhq/bifrost/transports/bifrost-http@latest + ``` + + - Run the server (ensure Go is in your PATH): + + ```bash + bifrost-http -config config.json -port 8080 -pool-size 300 + ``` + + #### ii) OR Using Docker + + - Pull the Docker image: + + ```bash + docker pull maximhq/bifrost + ``` + + - Run the Docker container: + + ```bash + docker run -p 8080:8080 \ + -v $(pwd)/config.json:/app/config/config.json \ + -e OPENAI_API_KEY \ + -e ANTHROPIC_API_KEY \ + maximhq/bifrost + ``` + + Note: Ensure you mount your config file and add all environment variables referenced in your `config.json` file. + +4. **Using the API**: Once the server is running, you can send requests to the HTTP endpoints. + + ```bash + curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "model": "gpt-4o-mini", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me about Bifrost in Norse mythology."} + ] + }' + ``` + +For additional HTTP server configuration options, read [this](https://github.com/maximhq/bifrost/blob/main/transports/README.md). + +### B. Using Bifrost as a Go Package + +1. **Implement Your Account Interface**: First, create an account that follows [Bifrost's account interface](https://github.com/maximhq/bifrost/blob/main/core/schemas/account.go). + + ```golang + type BaseAccount struct{} + + func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{schemas.OpenAI}, nil + } + + func (baseAccount *BaseAccount) GetKeysForProvider(providerKey schemas.ModelProvider) ([]schemas.Key, error) { + return []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{"gpt-4o-mini"}, + Weight: 1.0, + }, + }, nil + } + + func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + } + ``` + + Bifrost uses these methods to get all the keys and configurations it needs to call the providers. See the [Additional Configurations](#additional-configurations) section for additional customization options. + +2. **Initialize Bifrost**: Set up the Bifrost instance by providing your account implementation. + + ```golang + account := BaseAccount{} + + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &account, + }) + ``` + +3. **Use Bifrost**: Make your First LLM Call! + + ```golang + bifrostResult, bifrostErr := bifrost.ChatCompletionRequest( + context.Background(), + &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", // make sure you have configured gpt-4o-mini in your account interface + Input: schemas.RequestInput{ + ChatCompletionInput: bifrost.Ptr([]schemas.BifrostMessage{{ + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("What is a LLM gateway?"), + }, + }}), + }, + }, + ) + ``` + + You can add model parameters by including `Params: &schemas.ModelParameters{...yourParams}` in ChatCompletionRequest. + ## πŸ“‘ Table of Contents - [Bifrost](#bifrost) + - [⚑ Quickstart](#-quickstart) + - [Prerequisites](#prerequisites) + - [A. Using Bifrost as an HTTP Server](#a-using-bifrost-as-an-http-server) + - [i) Using Go Binary](#i-using-go-binary) + - [ii) OR Using Docker](#ii-or-using-docker) + - [B. Using Bifrost as a Go Package](#b-using-bifrost-as-a-go-package) - [πŸ“‘ Table of Contents](#-table-of-contents) - [πŸ” Overview](#-overview) - [✨ Features](#-features) - [πŸ—οΈ Repository Structure](#️-repository-structure) + - [πŸš€ Getting Started](#-getting-started) + - [Package Structure](#package-structure) + - [Additional Configurations](#additional-configurations) - [πŸ“Š Benchmarks](#-benchmarks) - [Test Environment](#test-environment) - - [t3.medium Instance](#t3medium-instance) - - [t3.xlarge Instance](#t3xlarge-instance) + - [1. t3.medium(2 vCPUs, 4GB RAM)](#1-t3medium2-vcpus-4gb-ram) + - [2. t3.xlarge(4 vCPUs, 16GB RAM)](#2-t3xlarge4-vcpus-16gb-ram) - [Performance Metrics](#performance-metrics) - [Key Performance Highlights](#key-performance-highlights) - - [πŸš€ Getting Started](#-getting-started) - - [Package Structure](#package-structure) - - [Prerequisites](#prerequisites) - - [Setting up Bifrost](#setting-up-bifrost) - - [Additional Configurations](#additional-configurations) - [🀝 Contributing](#-contributing) - [πŸ“„ License](#-license) @@ -27,7 +188,7 @@ Bifrost is an open-source middleware that serves as a unified gateway to various ## πŸ” Overview -Bifrost acts as a bridge between your applications and multiple AI providers (OpenAI, Anthropic, Amazon Bedrock, etc.). It provides a consistent API interface while handling: +Bifrost acts as a bridge between your applications and multiple AI providers (OpenAI, Anthropic, Amazon Bedrock, Mistral, Ollama, etc.). It provides a consistent API while handling: - Authentication and key management - Request routing and load balancing @@ -41,13 +202,16 @@ With Bifrost, you can focus on building your AI-powered applications without wor ## ✨ Features -- **Multi-Provider Support**: Integrate with OpenAI, Anthropic, Amazon Bedrock, and more through a single API +- **Multi-Provider Support**: Integrate with OpenAI, Anthropic, Amazon Bedrock, Mistral, Ollama, and more through a single API - **Fallback Mechanisms**: Automatically retry failed requests with alternative models or providers -- **Dynamic Key Management**: Rotate and manage API keys efficiently +- **Dynamic Key Management**: Rotate and manage API keys efficiently with weighted distribution - **Connection Pooling**: Optimize network resources for better performance - **Concurrency Control**: Manage rate limits and parallel requests effectively -- **HTTP Transport**: RESTful API interface for easy integration -- **Custom Configuration**: Flexible JSON-based configuration +- **Flexible Transports**: Multiple transports for easy integration into your infra +- **Plugin First Architecture**: No callback hell, simple addition/creation of custom plugins +- **MCP Integration**: Built-in Model Context Protocol (MCP) support for external tool integration and execution +- **Custom Configuration**: Offers granular control over pool sizes, network retry settings, fallback providers, and network proxy configurations +- **Built-in Observability**: Native Prometheus metrics out of the box, no wrappers, no sidecars, just drop it in and scrape --- @@ -55,332 +219,149 @@ With Bifrost, you can focus on building your AI-powered applications without wor Bifrost is built with a modular architecture: -``` +```text bifrost/ β”œβ”€β”€ core/ # Core functionality and shared components β”‚ β”œβ”€β”€ providers/ # Provider-specific implementations β”‚ β”œβ”€β”€ schemas/ # Interfaces and structs used in bifrost -β”‚ β”œβ”€β”€ tests/ # Tests to make sure everything is in place β”‚ β”œβ”€β”€ bifrost.go # Main Bifrost implementation -β”‚ +β”‚ +β”œβ”€β”€ docs/ # Documentations for Bifrost's configurations and contribution guides +β”‚ └── ... +β”‚ +β”œβ”€β”€ tests/ # All test setups related to /core and /transports +β”‚ └── ... +β”‚ β”œβ”€β”€ transports/ # Interface layers (HTTP, gRPC, etc.) -β”‚ β”œβ”€β”€ http/ # HTTP transport implementation +β”‚ β”œβ”€β”€ bifrost-http/ # HTTP transport implementation β”‚ └── ... β”‚ └── plugins/ # Plugin Implementations - β”œβ”€β”€ maxim-logger.go + β”œβ”€β”€ maxim/ └── ... ``` -The system uses a provider-agnostic approach with well-defined interfaces to easily extend to new AI providers. All interfaces are defined in `core/schemas/` and can be used as a reference for adding new plugins. +The system uses a provider-agnostic approach with well-defined interfaces to easily extend to new AI providers. All interfaces are defined in `core/schemas/` and can be used as a reference for contributions. --- -## πŸ“Š Benchmarks +## πŸš€ Getting Started -Bifrost has been tested under high load conditions to ensure optimal performance. The following results were obtained from benchmark tests running at 5000 requests per second (RPS) on different AWS EC2 instances, with Bifrost running inside Docker containers. +If you want to **set up the Bifrost API quickly**, [check the transports documentation](https://github.com/maximhq/bifrost/tree/main/transports/README.md). -### Test Environment +### Package Structure -#### t3.medium Instance -- **Instance**: AWS EC2 t3.medium -- **vCPUs**: 2 -- **Memory**: 4GB RAM -- **Container**: Docker container with resource limits matching instance specs -- **Bifrost Configurations**: - - Buffer Size: 15,000 - - Initial Pool Size: 10,000 - -#### t3.xlarge Instance -- **Instance**: AWS EC2 t3.xlarge -- **vCPUs**: 4 -- **Memory**: 16GB RAM -- **Container**: Docker container with resource limits matching instance specs -- **Bifrost Configurations**: - - Buffer Size: 20,000 - - Initial Pool Size: 15,000 +Bifrost is divided into three Go packages: core, plugins, and transports. -### Performance Metrics +1. **core**: This package contains the core implementation of Bifrost as a Go package. +2. **plugins**: This package serves as an extension to core. Plugins support both traditional instantiation and dynamic loading via configuration files. -| Metric | t3.medium | t3.xlarge | -|--------|-----------|-----------| -| Success Rate | 100.00% | 100.00% | -| Average Request Size | 0.13 KB | 0.13 KB | -| **Average Response Size** | **`1.37 KB`** | **`10.32 KB`** | -| Average Latency | 2.12s | 1.61s | -| Peak Memory Usage | 1312.79 MB | 3340.44 MB | -| Queue Wait Time | 47.13 Β΅s | 1.67 Β΅s | -| Key Selection Time | 16 ns | 10 ns | -| Message Formatting | 2.19 Β΅s | 2.11 Β΅s | -| Params Preparation | 436 ns | 417 ns | -| Request Body Preparation | 2.65 Β΅s | 2.36 Β΅s | -| JSON Marshaling | 63.47 Β΅s | 26.80 Β΅s | -| Request Setup | 6.59 Β΅s | 7.17 Β΅s | -| HTTP Request | 1.56s | 1.50s | -| Error Handling | 189 ns | 162 ns | -| Response Parsing | 11.30 ms | 2.11 ms | + **Traditional Plugin Usage:** -### Key Performance Highlights + ```golang + // go get github.com/maximhq/bifrost/plugins/maxim + maximPlugin, err := maxim.NewMaximLoggerPlugin(os.Getenv("MAXIM_API_KEY"), os.Getenv("MAXIM_LOGGER_ID")) + if err != nil { + return nil, err + } -- **Perfect Success Rate**: 100% request success rate under high load on both instances -- **Efficient Queue Management**: Minimal queue wait time (1.67 Β΅s on t3.xlarge) -- **Fast Key Selection**: Near-instantaneous key selection (10 ns on t3.xlarge) -- **Optimized Memory Usage**: - - t3.medium: ~1.3GB at 5000 RPS - - t3.xlarge: ~3.3GB at 5000 RPS -- **Efficient Request Processing**: Most operations complete in microseconds -- **Network Efficiency**: - - Consistent small request sizes (0.13 KB) across instances - - Larger response sizes on t3.xlarge (10.32 KB vs 1.37 KB) due to more detailed responses -- **Improved Performance on t3.xlarge**: - - 24% faster average latency - - 81% faster response parsing - - 58% faster JSON marshaling - - Significantly reduced queue wait times - - Higher buffer and pool sizes enabled by increased resources - -These benchmarks demonstrate Bifrost's ability to handle high-throughput scenarios while maintaining reliability and performance, even when containerized. The t3.xlarge instance shows improved performance across most metrics, particularly in processing times and latency, while maintaining the same high reliability and success rate. The larger response sizes on t3.xlarge indicate its ability to handle more detailed responses without compromising performance. + // Initialize Bifrost + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{maximPlugin}, + }) + ``` -One of Bifrost's key strengths is its flexibility in configuration. You can freely decide the tradeoff between memory usage and processing speed by adjusting Bifrost's configurations: + **Dynamic Plugin Loading:** + All plugins must implement an `Init(config json.RawMessage) (schemas.Plugin, error)` function for dynamic loading via configuration files. See [Plugin Documentation](https://github.com/maximhq/bifrost/blob/main/docs/plugins.md) for details. -- **Memory vs Speed Tradeoff**: - - Higher buffer and pool sizes (like in t3.xlarge) improve speed but use more memory - - Lower configurations (like in t3.medium) use less memory but may have slightly higher latencies - - You can fine-tune these parameters based on your specific needs and available resources +3. **transports**: This package contains transport clients like HTTP to expose your Bifrost client. You can either `go get` this package or directly use the independent Dockerfile to quickly spin up your [Bifrost API](https://github.com/maximhq/bifrost/tree/main/transports/README.md) (read more on this). -- **Customizable Parameters**: - - Buffer Size: Controls the maximum number of concurrent requests - - Initial Pool Size: Determines the initial allocation of resources - - Concurrency Settings: Adjustable per provider - - Retry and Timeout Configurations: Customizable based on your requirements +### Additional Configurations -This flexibility allows you to optimize Bifrost for your specific use case, whether you prioritize speed, memory efficiency, or a balance between the two. +- [Memory Management](https://github.com/maximhq/bifrost/blob/main/docs/memory-management.md) +- [Logger](https://github.com/maximhq/bifrost/blob/main/docs/logger.md) +- [Plugins](https://github.com/maximhq/bifrost/blob/main/docs/plugins.md) +- [Provider Configurations](https://github.com/maximhq/bifrost/blob/main/docs/providers.md) +- [Fallbacks](https://github.com/maximhq/bifrost/blob/main/docs/fallbacks.md) +- [MCP Integration](https://github.com/maximhq/bifrost/blob/main/docs/mcp.md) --- -## πŸš€ Getting Started - -If you want to **set up the Bifrost API quickly**, [check the transports documentation](https://github.com/maximhq/bifrost/tree/main/transports/README.md). - -### Package Structure - -Bifrost is divided into three Go packages: core, plugins, and transports. - -1. **core**: This package contains the core implementation of Bifrost as a Go package. - -2. **plugins**: This package serves as an extension to core. You can download this package using `go get github.com/maximhq/bifrost/plugins` and pass the plugins while initializing Bifrost. - - ```golang - plugin, err := plugins.NewMaximLoggerPlugin(os.Getenv("MAXIM_API_KEY"), os.Getenv("MAXIM_LOGGER_ID")) - if err != nil { - return nil, err - } - - // Initialize Bifrost - client, err := bifrost.Init(schemas.BifrostConfig{ - Account: &account, - Plugins: []schemas.Plugin{plugin}, - }) - ``` - -3. **transports**: This package contains transport clients like HTTP to expose your Bifrost client. You can either `go get` this package or directly use the independent Dockerfile to quickly spin up your Bifrost API interface ([Click here](https://github.com/maximhq/bifrost/tree/main/transports/README.md) to read more on this). - -### Prerequisites - -- Go 1.23 or higher -- Access to at least one AI model provider (OpenAI, Anthropic, etc.) -- API keys for the providers you wish to use - -### Setting up Bifrost - -1. Setting up your account: You first need to create your account which follows [Bifrost's account interface](https://github.com/maximhq/bifrost/blob/main/core/schemas/account.go). - -Example: - ```golang - type BaseAccount struct{} - - func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { - return []schemas.ModelProvider{schemas.OpenAI}, nil - } - - func (baseAccount *BaseAccount) GetKeysForProvider(providerKey schemas.ModelProvider) ([]schemas.Key, error) { - switch providerKey { - case schemas.OpenAI: - return []schemas.Key{ - { - Value: os.Getenv("OPENAI_API_KEY"), - Models: []string{"gpt-4o-mini"}, - }, - }, nil - default: - return nil, fmt.Errorf("unsupported provider: %s", providerKey) - } - } - - func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { - switch providerKey { - case schemas.OpenAI: - return &schemas.ProviderConfig{ - ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - }, nil - default: - return nil, fmt.Errorf("unsupported provider: %s", providerKey) - } - } - ``` - -Bifrost uses these methods to get all the keys and configurations it needs to call the providers. You can check the [Additional Configurations](#additional-configurations) section for further customizations. - -2. Get bifrost core package: Simply run `go get github.com/maximhq/bifrost/core` to download bifrost/core package. - -3. Initialising Bifrost: Initialise bifrost by providing your account implementation - -```golang -client, err := bifrost.Init(schemas.BifrostConfig{ - Account: &yourAccount, -}) -``` +## πŸ“Š Benchmarks -4. Make your First LLM Call! - -```golang - msg = "What is a LLM gateway?" - messages := []schemas.Message{ - { Role: schemas.RoleUser, Content: &msg }, - } - - bifrostResult, bifrostErr := bifrost.ChatCompletionRequest( - schemas.OpenAI, &schemas.BifrostRequest{ - Model: "gpt-4o", // make sure you have configured gpt-4o in your account interface - Input: schemas.RequestInput{ - ChatCompletionInput: &messages, - }, - }, context.Background() - ) -``` +Bifrost has been tested under high load conditions to ensure optimal performance. The following results were obtained from benchmark tests running at 5000 requests per second (RPS) on different AWS EC2 instances. -you can add model parameters by passing them in `Params:&schemas.ModelParameters{...yourParams}` ChatCompletionRequest. +### Test Environment -### Additional Configurations +#### 1. t3.medium(2 vCPUs, 4GB RAM) -1. InitalPoolSize and DropExcessRequests: You can customise the initial pool size of the structs and channels bifrost creates on `bifrost.Init()`. A higher value would mean lesser run time allocations and lower latency but at the cost of more memory usage. Takes the defined default value if not provided. +- Buffer Size: 15,000 +- Initial Pool Size: 10,000 -```golang - client, err := bifrost.Init(schemas.BifrostConfig{ - Account: &yourAccount, - InitialPoolSize: 500, - DropExcessRequests: true, - }) -``` +#### 2. t3.xlarge(4 vCPUs, 16GB RAM) -When `DropExcessRequests` is set to true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. By default it is set to false. +- Buffer Size: 20,000 +- Initial Pool Size: 15,000 -2. Logger: Like account interface, bifrost also allows you to pass your custom logger if it follows [bifrost's logger interface](https://github.com/maximhq/bifrost/blob/main/core/schemas/logger.go). Takes in the [default logger](https://github.com/maximhq/bifrost/blob/main/core/logger.go) if not provided. +### Performance Metrics -```golang - client, err := bifrost.Init(schemas.BifrostConfig{ - Account: &yourAccount, - Logger: &yourLogger, - }) -``` +| Metric | t3.medium | t3.xlarge | +| ------------------------- | ------------- | -------------- | +| Success Rate | 100.00% | 100.00% | +| Average Request Size | 0.13 KB | 0.13 KB | +| **Average Response Size** | **`1.37 KB`** | **`10.32 KB`** | +| Average Latency | 2.12s | 1.61s | +| Peak Memory Usage | 1312.79 MB | 3340.44 MB | +| Queue Wait Time | 47.13 Β΅s | 1.67 Β΅s | +| Key Selection Time | 16 ns | 10 ns | +| Message Formatting | 2.19 Β΅s | 2.11 Β΅s | +| Params Preparation | 436 ns | 417 ns | +| Request Body Preparation | 2.65 Β΅s | 2.36 Β΅s | +| JSON Marshaling | 63.47 Β΅s | 26.80 Β΅s | +| Request Setup | 6.59 Β΅s | 7.17 Β΅s | +| HTTP Request | 1.56s | 1.50s | +| Error Handling | 189 ns | 162 ns | +| Response Parsing | 11.30 ms | 2.11 ms | +| **Bifrost's Overhead** | **`59 Β΅s\*`** | **`11 Β΅s\*`** | + +_\*Bifrost's overhead is measured at 59 Β΅s on t3.medium and 11 Β΅s on t3.xlarge, excluding the time taken for JSON marshalling and the HTTP call to the LLM, both of which are required in any custom implementation._ + +**Note**: On the t3.xlarge, we tested with significantly larger response payloads (~10 KB average vs ~1 KB on t3.medium). Even so, response parsing time dropped dramatically thanks to better CPU throughput and Bifrost's optimized memory reuse. -The default logger is set to level info by default. If you wish to use it but with a different log level, you can do it like this - +### Key Performance Highlights -```golang - client, err := bifrost.Init(schemas.BifrostConfig{ - Account: &yourAccount, - Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), - }) -``` +- **Perfect Success Rate**: 100% request success rate under high load on both instances +- **Total Overhead**: Less than only _15Β΅s added per request_ on average +- **Efficient Queue Management**: Minimal queue wait time (1.67 Β΅s on t3.xlarge) +- **Fast Key Selection**: Near-instantaneous key selection (10 ns on t3.xlarge) +- **Improved Performance on t3.xlarge**: + - 24% faster average latency + - 81% faster response parsing + - 58% faster JSON marshaling + - Significantly reduced queue wait times -3. Plugins: You can create and pass your custom pre-hook and post-hook plugins to bifrost as long as they follow [bifrost's plugin interface](https://github.com/maximhq/bifrost/blob/main/core/schemas/plugin.go). +One of Bifrost's key strengths is its flexibility in configuration. You can freely decide the tradeoff between memory usage and processing speed by adjusting Bifrost's configurations. This flexibility allows you to optimize Bifrost for your specific use case, whether you prioritize speed, memory efficiency, or a balance between the two. -```golang - client, err := bifrost.Init(schemas.BifrostConfig{ - Account: &yourAccount, - Plugins: []schemas.Plugin{yourPlugin1, yourPlugin2, ...}, - }) -``` +- Higher buffer and pool sizes (like in t3.xlarge) improve speed but use more memory +- Lower configurations (like in t3.medium) use less memory but may have slightly higher latencies +- You can fine-tune these parameters based on your specific needs and available resources -4. Customise your provider settings: You can customise proxy config, timeouts, retry settings, concurrency buffer sizes for each of your provider in your account interface's GetConfigForProvider() method. - -exmaple: -```golang - schemas.ProviderConfig{ - NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, - MaxRetries: 2, - RetryBackoffInitial: 100 * time.Millisecond, - RetryBackoffMax: 2 * time.Second, - }, - MetaConfig: &meta.BedrockMetaConfig{ - SecretAccessKey: os.Getenv("BEDROCK_ACCESS_KEY"), - Region: StrPtr("us-east-1"), - }, - ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - ProxyConfig: &schemas.ProxyConfig{ - Type: schemas.HttpProxy, - URL: yourProxyURL, - }, - } -``` + - Initial Pool Size: Determines the initial allocation of resources + - Buffer and Concurrency Settings: Controls the queue size and maximum number of concurrent requests (adjustable per provider). + - Retry and Timeout Configurations: Customizable based on your requirements for each provider. -You can manage buffer size (maximum number of requests you want to hold in the system) concurrency (maximum number of requests you want to be made concurrently) for each provider. You can manage user usage and provider limits by providing these custom provider settings Default values are taken for network config, concurrecy and buffer sizes if not provided. - -Bifrost also supports multiple API keys per provider, enabling both load balancing and redundancy. You can assign weights to each key to control how frequently they are selected for requests. By default, all keys are treated with equal weight unless specified otherwise. - -```golang - []schemas.Key{ - { - Value: os.Getenv("OPEN_AI_API_KEY1"), - Models: []string{"gpt-4o-mini", "gpt-4-turbo"}, - Weight: 0.6, - }, - { - Value: os.Getenv("OPEN_AI_API_KEY2"), - Models: []string{"gpt-4-turbo"}, - Weight: 0.3, - }, - { - Value: os.Getenv("OPEN_AI_API_KEY3"), - Models: []string{"gpt-4o-mini"}, - Weight: 0.1, - }, - } -``` +Curious? Run your own benchmarks. The [Bifrost Benchmarking](https://github.com/maximhq/bifrost-benchmarking) repo has everything you need to test it in your own environment. -You can check [this](https://github.com/maximhq/bifrost/blob/main/core/tests/account.go) file to refer all the customisation settings. - -5. Fallbacks: You can define fallback providers for each request, which will be used if all retry attempts with your primary provider fail. These fallback providers are attempted in the order you specify, provided they are configured in your account at runtime. Once a fallback is triggered, its own retry settings will apply, rather than those of the original provider. - -```golang - result, err := bifrost.ChatCompletionRequest( - schemas.OpenAI, &schemas.BifrostRequest{ - Model: "gpt-4o", - Input: schemas.RequestInput{ - ChatCompletionInput: &messages, - }, - Fallbacks: []schemas.Fallback{ - { - Provider: schemas.Anthropic, - Model: "claude-3-5-sonnet-20240620", // make sure you have configured this - }, - }, - }, context.Background() - ) -``` +**πŸ›οΈ Curious how we handle scales of 10k+ RPS?** Check out our [System Architecture Documentation](./docs/system-architecture.md) for detailed insights into Bifrost's high-performance design, memory management, and scaling strategies. --- ## 🀝 Contributing -Contributions are welcome! We welcome all kinds of contributions β€” bug fixes, features, docs, and ideas. Please feel free to submit a Pull Request. +We welcome contributions of all kindsβ€”whether it's bug fixes, features, documentation improvements, or new ideas. Feel free to open an issue, and once it's assigned, submit a Pull Request. + +Here's how to get started (after picking up an issue): 1. Fork the repository 2. Create your feature branch (`git checkout -b feature/amazing-feature`) @@ -392,8 +373,6 @@ Contributions are welcome! We welcome all kinds of contributions β€” bug fixes, ## πŸ“„ License -This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. - ---- +This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details. Built with ❀️ by [Maxim](https://github.com/maximhq) diff --git a/core/bifrost.go b/core/bifrost.go index a4ead1c278..40cc6219b2 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -7,11 +7,9 @@ import ( "context" "fmt" "math/rand" - "os" - "os/signal" "slices" + "strings" "sync" - "syscall" "time" "github.com/maximhq/bifrost/core/providers" @@ -30,6 +28,7 @@ const ( // It contains the request, response and error channels, and the request type. type ChannelMessage struct { schemas.BifrostRequest + Context context.Context Response chan *schemas.BifrostResponse Err chan schemas.BifrostError Type RequestType @@ -49,6 +48,79 @@ type Bifrost struct { logger schemas.Logger // logger instance, default logger is used if not provided dropExcessRequests bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. backgroundCtx context.Context // Shared background context for nil context handling + mcpManager *MCPManager // MCP integration manager (nil if MCP not configured) +} + +// PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation. +type PluginPipeline struct { + plugins []schemas.Plugin + logger schemas.Logger + + // Number of PreHooks that were executed (used to determine which PostHooks to run in reverse order) + executedPreHooks int + // Errors from PreHooks and PostHooks + preHookErrors []error + postHookErrors []error +} + +// NewPluginPipeline creates a new pipeline for a given plugin slice and logger. +func NewPluginPipeline(plugins []schemas.Plugin, logger schemas.Logger) *PluginPipeline { + return &PluginPipeline{ + plugins: plugins, + logger: logger, + } +} + +// RunPreHooks executes PreHooks in order, tracks how many ran, and returns the final request, any short-circuit decision, and the count. +func (p *PluginPipeline) RunPreHooks(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, int) { + var shortCircuit *schemas.PluginShortCircuit + var err error + for i, plugin := range p.plugins { + req, shortCircuit, err = plugin.PreHook(ctx, req) + if err != nil { + p.preHookErrors = append(p.preHookErrors, err) + p.logger.Warn(fmt.Sprintf("Error in PreHook for plugin %s: %v", plugin.GetName(), err)) + } + p.executedPreHooks = i + 1 + if shortCircuit != nil { + return req, shortCircuit, p.executedPreHooks // short-circuit: only plugins up to and including i ran + } + } + return req, nil, p.executedPreHooks +} + +// RunPostHooks executes PostHooks in reverse order for the plugins whose PreHook ran. +// Accepts the response and error, and allows plugins to transform either (e.g., recover from error, or invalidate a response). +// Returns the final response and error after all hooks. If both are set, error takes precedence unless error is nil. +func (p *PluginPipeline) RunPostHooks(ctx *context.Context, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError, count int) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Defensive: ensure count is within valid bounds + if count < 0 { + count = 0 + } + if count > len(p.plugins) { + count = len(p.plugins) + } + var err error + for i := count - 1; i >= 0; i-- { + plugin := p.plugins[i] + resp, bifrostErr, err = plugin.PostHook(ctx, resp, bifrostErr) + if err != nil { + p.postHookErrors = append(p.postHookErrors, err) + p.logger.Warn(fmt.Sprintf("Error in PostHook for plugin %s: %v", plugin.GetName(), err)) + } + // If a plugin recovers from an error (sets bifrostErr to nil and sets resp), allow that + // If a plugin invalidates a response (sets resp to nil and sets bifrostErr), allow that + } + // Final logic: if both are set, error takes precedence, unless error is nil + if bifrostErr != nil { + if resp != nil && bifrostErr.StatusCode == nil && bifrostErr.Error.Type == nil && + bifrostErr.Error.Message == "" && bifrostErr.Error.Error == nil { + // Defensive: treat as recovery if error is empty + return resp, nil + } + return resp, bifrostErr + } + return resp, nil } // createProviderFromProviderKey creates a new provider instance based on the provider key. @@ -60,11 +132,17 @@ func (bifrost *Bifrost) createProviderFromProviderKey(providerKey schemas.ModelP case schemas.Anthropic: return providers.NewAnthropicProvider(config, bifrost.logger), nil case schemas.Bedrock: - return providers.NewBedrockProvider(config, bifrost.logger), nil + return providers.NewBedrockProvider(config, bifrost.logger) case schemas.Cohere: return providers.NewCohereProvider(config, bifrost.logger), nil case schemas.Azure: - return providers.NewAzureProvider(config, bifrost.logger), nil + return providers.NewAzureProvider(config, bifrost.logger) + case schemas.Vertex: + return providers.NewVertexProvider(config, bifrost.logger) + case schemas.Mistral: + return providers.NewMistralProvider(config, bifrost.logger), nil + case schemas.Ollama: + return providers.NewOllamaProvider(config, bifrost.logger) default: return nil, fmt.Errorf("unsupported provider: %s", providerKey) } @@ -78,10 +156,12 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi return fmt.Errorf("failed to get config for provider: %v", err) } - // Check if the provider has any keys - keys, err := bifrost.account.GetKeysForProvider(providerKey) - if err != nil || len(keys) == 0 { - return fmt.Errorf("failed to get keys for provider: %v", err) + // Check if the provider has any keys (skip keyless providers) + if providerRequiresKey(providerKey) { + keys, err := bifrost.account.GetKeysForProvider(providerKey) + if err != nil || len(keys) == 0 { + return fmt.Errorf("failed to get keys for provider: %v", err) + } } queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider @@ -93,7 +173,7 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi provider, err := bifrost.createProviderFromProviderKey(providerKey, config) if err != nil { - return fmt.Errorf("failed to get provider for the given key: %v", err) + return fmt.Errorf("failed to create provider for the given key: %v", err) } for range providerConfig.ConcurrencyAndBufferSize.Concurrency { @@ -157,6 +237,17 @@ func Init(config schemas.BifrostConfig) (*Bifrost, error) { } bifrost.logger = config.Logger + // Initialize MCP manager if configured + if config.MCPConfig != nil { + mcpManager, err := newMCPManager(*config.MCPConfig, bifrost.logger) + if err != nil { + bifrost.logger.Warn(fmt.Sprintf("failed to initialize MCP manager: %v", err)) + } else { + bifrost.mcpManager = mcpManager + bifrost.logger.Info("MCP integration initialized successfully") + } + } + // Create buffered channels for each provider and start workers for _, providerKey := range providerKeys { config, err := bifrost.account.GetConfigForProvider(providerKey) @@ -166,7 +257,7 @@ func Init(config schemas.BifrostConfig) (*Bifrost, error) { } if err := bifrost.prepareProvider(providerKey, config); err != nil { - bifrost.logger.Warn(fmt.Sprintf("failed to prepare provider: %v", err)) + bifrost.logger.Warn(fmt.Sprintf("failed to prepare provider %s: %v", providerKey, err)) } } @@ -212,9 +303,9 @@ func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) { bifrost.channelMessagePool.Put(msg) } -// SelectKeyFromProviderForModel selects an appropriate API key for a given provider and model. +// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model. // It uses weighted random selection if multiple keys are available. -func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey schemas.ModelProvider, model string) (string, error) { +func (bifrost *Bifrost) selectKeyFromProviderForModel(providerKey schemas.ModelProvider, model string) (string, error) { keys, err := bifrost.account.GetKeysForProvider(providerKey) if err != nil { return "", err @@ -227,7 +318,7 @@ func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey schemas.ModelP // filter out keys which dont support the model var supportedKeys []schemas.Key for _, key := range keys { - if slices.Contains(key.Models, model) { + if slices.Contains(key.Models, model) && strings.TrimSpace(key.Value) != "" { supportedKeys = append(supportedKeys, key) } } @@ -263,13 +354,25 @@ func (bifrost *Bifrost) SelectKeyFromProviderForModel(providerKey schemas.ModelP return supportedKeys[0].Value, nil } +// Define a set of retryable status codes +var retryableStatusCodes = map[int]bool{ + 500: true, // Internal Server Error + 502: true, // Bad Gateway + 503: true, // Service Unavailable + 504: true, // Gateway Timeout + 429: true, // Too Many Requests +} + +// providerRequiresKey returns true if the given provider requires an API key for authentication. +// Some providers like Vertex and Ollama are keyless and don't require API keys. +func providerRequiresKey(providerKey schemas.ModelProvider) bool { + return providerKey != schemas.Vertex && providerKey != schemas.Ollama +} + // calculateBackoff implements exponential backoff with jitter for retry attempts. func (bifrost *Bifrost) calculateBackoff(attempt int, config *schemas.ProviderConfig) time.Duration { // Calculate an exponential backoff: initial * 2^attempt - backoff := config.NetworkConfig.RetryBackoffInitial * time.Duration(1< config.NetworkConfig.RetryBackoffMax { - backoff = config.NetworkConfig.RetryBackoffMax - } + backoff := min(config.NetworkConfig.RetryBackoffInitial*time.Duration(1< 0 { for _, fallback := range req.Fallbacks { // Check if we have config for this fallback provider @@ -459,16 +588,21 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey schemas.ModelProvider, continue } - // Create a new request with the fallback model + // Create a new request with the fallback provider and model fallbackReq := *req + fallbackReq.Provider = fallback.Provider fallbackReq.Model = fallback.Model // Try the fallback provider - result, fallbackErr := bifrost.tryTextCompletion(fallback.Provider, &fallbackReq, ctx) + result, fallbackErr := bifrost.tryTextCompletion(&fallbackReq, ctx) if fallbackErr == nil { bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) return result, nil } + if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { + return nil, fallbackErr + } + bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) } } @@ -479,67 +613,56 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey schemas.ModelProvider, // tryTextCompletion attempts a text completion request with a single provider. // This is a helper function used by TextCompletionRequest to handle individual provider attempts. -func (bifrost *Bifrost) tryTextCompletion(providerKey schemas.ModelProvider, req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { - queue, err := bifrost.GetProviderQueue(providerKey) +func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { + queue, err := bifrost.getProviderQueue(req.Provider) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, - } + return nil, newBifrostError(err) } - for _, plugin := range bifrost.plugins { - req, err = plugin.PreHook(&ctx, req) - if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, - } - } + // Add MCP tools to request if MCP is configured + if bifrost.mcpManager != nil { + req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) } - if req == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "bifrost request after plugin hooks cannot be nil", - }, + pipeline := NewPluginPipeline(bifrost.plugins, bifrost.logger) + preReq, shortCircuit, preCount := pipeline.RunPreHooks(&ctx, req) + if shortCircuit != nil { + // Handle short-circuit with response (success case) + if shortCircuit.Response != nil { + resp, bifrostErr := pipeline.RunPostHooks(&ctx, shortCircuit.Response, nil, preCount) + if bifrostErr != nil { + return nil, bifrostErr + } + return resp, nil + } + // Handle short-circuit with error + if shortCircuit.Error != nil { + resp, bifrostErr := pipeline.RunPostHooks(&ctx, nil, shortCircuit.Error, preCount) + if bifrostErr != nil { + return nil, bifrostErr + } + return resp, nil } } + if preReq == nil { + return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + } - // Get a ChannelMessage from the pool - msg := bifrost.getChannelMessage(*req, TextCompletionRequest) + msg := bifrost.getChannelMessage(*preReq, TextCompletionRequest) + msg.Context = ctx - // Handle queue send with context and proper cleanup select { case queue <- *msg: // Message was sent successfully case <-ctx.Done(): - // Request was cancelled by caller bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request cancelled while waiting for queue space", - }, - } + return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") default: if bifrost.dropExcessRequests { - // Drop request immediately if configured to do so bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request dropped: queue is full", - }, - } + return nil, newBifrostErrorFromMsg("request dropped: queue is full") } - // If not dropping excess requests, wait with context if ctx == nil { ctx = bifrost.backgroundCtx } @@ -548,91 +671,87 @@ func (bifrost *Bifrost) tryTextCompletion(providerKey schemas.ModelProvider, req // Message was sent successfully case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request cancelled while waiting for queue space", - }, - } + return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") } } - // Handle response var result *schemas.BifrostResponse + var resp *schemas.BifrostResponse select { case result = <-msg.Response: - // Run plugins in reverse order - for i := len(bifrost.plugins) - 1; i >= 0; i-- { - result, err = bifrost.plugins[i].PostHook(&ctx, result) - if err != nil { - bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, - } - } + resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, len(bifrost.plugins)) + if bifrostErr != nil { + bifrost.releaseChannelMessage(msg) + return nil, bifrostErr } - case err := <-msg.Err: bifrost.releaseChannelMessage(msg) - return nil, &err + return resp, nil + case bifrostErrVal := <-msg.Err: + bifrostErrPtr := &bifrostErrVal + resp, bifrostErrPtr = pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins)) + bifrost.releaseChannelMessage(msg) + if bifrostErrPtr != nil { + return nil, bifrostErrPtr + } + return resp, nil } - - // Return message to pool - bifrost.releaseChannelMessage(msg) - return result, nil } // ChatCompletionRequest sends a chat completion request to the specified provider. // It handles plugin hooks, request validation, response processing, and fallback providers. // If the primary provider fails, it will try each fallback provider in order until one succeeds. -func (bifrost *Bifrost) ChatCompletionRequest(providerKey schemas.ModelProvider, req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { if req == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "bifrost request cannot be nil", - }, - } + return nil, newBifrostErrorFromMsg("bifrost request cannot be nil") + } + + if req.Provider == "" { + return nil, newBifrostErrorFromMsg("provider is required") } if req.Model == "" { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "model is required", - }, - } + return nil, newBifrostErrorFromMsg("model is required") } // Try the primary provider first - primaryResult, primaryErr := bifrost.tryChatCompletion(providerKey, req, ctx) + primaryResult, primaryErr := bifrost.tryChatCompletion(req, ctx) if primaryErr == nil { return primaryResult, nil } + // Check if this is a short-circuit error that doesn't allow fallbacks + // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) + if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { + return nil, primaryErr + } + // If primary provider failed and we have fallbacks, try them in order + // This includes both regular provider errors and plugin short-circuit errors with AllowFallbacks=true/nil if len(req.Fallbacks) > 0 { for _, fallback := range req.Fallbacks { // Check if we have config for this fallback provider _, err := bifrost.account.GetConfigForProvider(fallback.Provider) if err != nil { - bifrost.logger.Warn(fmt.Sprintf("Skipping fallback provider %s: %v", fallback.Provider, err)) + bifrost.logger.Warn(fmt.Sprintf("Config not found for provider %s, skipping fallback: %v", fallback.Provider, err)) continue } - // Create a new request with the fallback model + // Create a new request with the fallback provider and model fallbackReq := *req + fallbackReq.Provider = fallback.Provider fallbackReq.Model = fallback.Model // Try the fallback provider - result, fallbackErr := bifrost.tryChatCompletion(fallback.Provider, &fallbackReq, ctx) + result, fallbackErr := bifrost.tryChatCompletion(&fallbackReq, ctx) if fallbackErr == nil { bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) return result, nil } - bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %v", fallback.Provider, fallbackErr.Error.Message)) + if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { + return nil, fallbackErr + } + + bifrost.logger.Warn(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) } } @@ -642,67 +761,56 @@ func (bifrost *Bifrost) ChatCompletionRequest(providerKey schemas.ModelProvider, // tryChatCompletion attempts a chat completion request with a single provider. // This is a helper function used by ChatCompletionRequest to handle individual provider attempts. -func (bifrost *Bifrost) tryChatCompletion(providerKey schemas.ModelProvider, req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { - queue, err := bifrost.GetProviderQueue(providerKey) +func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { + queue, err := bifrost.getProviderQueue(req.Provider) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, - } + return nil, newBifrostError(err) } - for _, plugin := range bifrost.plugins { - req, err = plugin.PreHook(&ctx, req) - if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, - } - } + // Add MCP tools to request if MCP is configured + if bifrost.mcpManager != nil { + req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) } - if req == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "bifrost request after plugin hooks cannot be nil", - }, + pipeline := NewPluginPipeline(bifrost.plugins, bifrost.logger) + preReq, shortCircuit, preCount := pipeline.RunPreHooks(&ctx, req) + if shortCircuit != nil { + // Handle short-circuit with response (success case) + if shortCircuit.Response != nil { + resp, bifrostErr := pipeline.RunPostHooks(&ctx, shortCircuit.Response, nil, preCount) + if bifrostErr != nil { + return nil, bifrostErr + } + return resp, nil } + // Handle short-circuit with error + if shortCircuit.Error != nil { + resp, bifrostErr := pipeline.RunPostHooks(&ctx, nil, shortCircuit.Error, preCount) + if bifrostErr != nil { + return nil, bifrostErr + } + return resp, nil + } + } + if preReq == nil { + return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") } - // Get a ChannelMessage from the pool - msg := bifrost.getChannelMessage(*req, ChatCompletionRequest) + msg := bifrost.getChannelMessage(*preReq, ChatCompletionRequest) + msg.Context = ctx - // Handle queue send with context and proper cleanup select { case queue <- *msg: // Message was sent successfully case <-ctx.Done(): - // Request was cancelled by caller bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request cancelled while waiting for queue space", - }, - } + return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") default: if bifrost.dropExcessRequests { - // Drop request immediately if configured to do so bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request dropped: queue is full", - }, - } + return nil, newBifrostErrorFromMsg("request dropped: queue is full") } - // If not dropping excess requests, wait with context if ctx == nil { ctx = bifrost.backgroundCtx } @@ -711,46 +819,101 @@ func (bifrost *Bifrost) tryChatCompletion(providerKey schemas.ModelProvider, req // Message was sent successfully case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request cancelled while waiting for queue space", - }, - } + return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") } } - // Handle response var result *schemas.BifrostResponse + var resp *schemas.BifrostResponse select { case result = <-msg.Response: - // Run plugins in reverse order - for i := len(bifrost.plugins) - 1; i >= 0; i-- { - result, err = bifrost.plugins[i].PostHook(&ctx, result) - if err != nil { - bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, - } - } + resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, len(bifrost.plugins)) + if bifrostErr != nil { + bifrost.releaseChannelMessage(msg) + return nil, bifrostErr } - case err := <-msg.Err: bifrost.releaseChannelMessage(msg) - return nil, &err + return resp, nil + case bifrostErrVal := <-msg.Err: + bifrostErrPtr := &bifrostErrVal + resp, bifrostErrPtr = pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins)) + bifrost.releaseChannelMessage(msg) + if bifrostErrPtr != nil { + return nil, bifrostErrPtr + } + return resp, nil + } +} + +// ExecuteMCPTool executes an MCP tool call and returns the result as a tool message. +// This is the main public API for manual MCP tool execution. +// +// Parameters: +// - ctx: Execution context +// - toolCall: The tool call to execute (from assistant message) +// +// Returns: +// - schemas.BifrostMessage: Tool message with execution result +// - schemas.BifrostError: Any execution error +func (bifrost *Bifrost) ExecuteMCPTool(ctx context.Context, toolCall schemas.ToolCall) (*schemas.BifrostMessage, *schemas.BifrostError) { + if bifrost.mcpManager == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "MCP is not configured in this Bifrost instance", + }, + } + } + + result, err := bifrost.mcpManager.executeTool(ctx, toolCall) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: err.Error(), + Error: err, + }, + } } - // Return message to pool - bifrost.releaseChannelMessage(msg) return result, nil } -// Shutdown gracefully stops all workers when triggered. +// RegisterMCPTool registers a typed tool handler with the MCP integration. +// This allows developers to easily add custom tools that will be available +// to all LLM requests processed by this Bifrost instance. +// +// Parameters: +// - name: Unique tool name +// - description: Human-readable tool description +// - handler: Function that handles tool execution +// - toolSchema: Bifrost tool schema for function calling +// +// Returns: +// - error: Any registration error +// +// Example: +// +// type EchoArgs struct { +// Message string `json:"message"` +// } +// +// err := bifrost.RegisterMCPTool("echo", "Echo a message", +// func(args EchoArgs) (string, error) { +// return args.Message, nil +// }, toolSchema) +func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(args any) (string, error), toolSchema schemas.Tool) error { + if bifrost.mcpManager == nil { + return fmt.Errorf("MCP is not configured in this Bifrost instance") + } + + return bifrost.mcpManager.registerTool(name, description, handler, toolSchema) +} + +// Cleanup gracefully stops all workers when triggered. // It closes all request channels and waits for workers to exit. -func (bifrost *Bifrost) Shutdown() { - bifrost.logger.Info("[BIFROST] Graceful Shutdown Initiated - Closing all request channels...") +func (bifrost *Bifrost) Cleanup() { + bifrost.logger.Info("Graceful Cleanup Initiated - Closing all request channels...") // Close all provider queues to signal workers to stop for _, queue := range bifrost.requestQueues { @@ -761,14 +924,22 @@ func (bifrost *Bifrost) Shutdown() { for _, waitGroup := range bifrost.waitGroups { waitGroup.Wait() } -} -// Cleanup handles SIGINT (Ctrl+C) to exit cleanly. -// It sets up signal handling and calls Shutdown when interrupted. -func (bifrost *Bifrost) Cleanup() { - signalChan := make(chan os.Signal, 1) - signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM) + // Cleanup MCP manager + if bifrost.mcpManager != nil { + err := bifrost.mcpManager.cleanup() + if err != nil { + bifrost.logger.Warn(fmt.Sprintf("Error cleaning up MCP manager: %s", err.Error())) + } + } + + // Cleanup plugins + for _, plugin := range bifrost.plugins { + err := plugin.Cleanup() + if err != nil { + bifrost.logger.Warn(fmt.Sprintf("Error cleaning up plugin: %s", err.Error())) + } + } - <-signalChan // Wait for interrupt signal - bifrost.Shutdown() // Gracefully shut down + bifrost.logger.Info("Graceful Cleanup Completed") } diff --git a/core/go.mod b/core/go.mod index af649c745f..1b689325bd 100644 --- a/core/go.mod +++ b/core/go.mod @@ -2,16 +2,17 @@ module github.com/maximhq/bifrost/core go 1.24.1 -require github.com/joho/godotenv v1.5.1 - require ( github.com/aws/aws-sdk-go-v2 v1.36.3 github.com/aws/aws-sdk-go-v2/config v1.29.14 - github.com/maximhq/bifrost/plugins v1.0.0 + github.com/goccy/go-json v0.10.5 + github.com/mark3labs/mcp-go v0.32.0 github.com/valyala/fasthttp v1.60.0 + golang.org/x/oauth2 v0.30.0 ) require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/andybalholm/brotli v1.1.1 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect @@ -24,10 +25,11 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect github.com/aws/smithy-go v1.22.3 // indirect - github.com/goccy/go-json v0.10.5 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/klauspost/compress v1.18.0 // indirect - github.com/maximhq/maxim-go v0.1.1 // indirect + github.com/spf13/cast v1.7.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect golang.org/x/net v0.39.0 // indirect golang.org/x/text v0.24.0 // indirect ) diff --git a/core/go.sum b/core/go.sum index d0f8edd171..9ffb1b5af4 100644 --- a/core/go.sum +++ b/core/go.sum @@ -1,3 +1,5 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= @@ -26,23 +28,45 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/Xv github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= -github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/maximhq/bifrost/plugins v1.0.0 h1:ul4tMMQHOdhyFQueyZwmQB3uX+s2buYSKzq1FW0m090= -github.com/maximhq/bifrost/plugins v1.0.0/go.mod h1:IUDZ2NMgCjIn1SVCvYbWZd/Lsk96MNytOvEKpinjvHo= -github.com/maximhq/maxim-go v0.1.1 h1:69uUQjjDPmUGcKg/M4/3AO0fbD+70Agt66pH/UCsI5M= -github.com/maximhq/maxim-go v0.1.1/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= +github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw= github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/core/mcp.go b/core/mcp.go new file mode 100644 index 0000000000..bc6a48a904 --- /dev/null +++ b/core/mcp.go @@ -0,0 +1,940 @@ +package bifrost + +import ( + "context" + "encoding/json" + "fmt" + "os" + "slices" + "strings" + "sync" + "time" + + "github.com/maximhq/bifrost/core/schemas" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// ============================================================================ +// CONSTANTS +// ============================================================================ + +const ( + // MCP defaults and identifiers + BifrostMCPVersion = "1.0.0" // Version identifier for Bifrost + BifrostMCPClientName = "BifrostClient" // Name for internal Bifrost MCP client + BifrostMCPClientKey = "bifrost-internal" // Key for internal Bifrost client in clientMap + MCPLogPrefix = "[Bifrost MCP]" // Consistent logging prefix + MCPClientConnectionEstablishTimeout = 30 * time.Second // Timeout for MCP client connection establishment + + // Context keys for client filtering in requests + MCPContextKeyIncludeClients = "mcp_include_clients" // Context key for whitelist client filtering + MCPContextKeyExcludeClients = "mcp_exclude_clients" // Context key for blacklist client filtering +) + +// ============================================================================ +// TYPE DEFINITIONS +// ============================================================================ + +// MCPManager manages MCP integration for Bifrost core. +// It provides a bridge between Bifrost and various MCP servers, supporting +// both local tool hosting and external MCP server connections. +type MCPManager struct { + server *server.MCPServer // Local MCP server instance for hosting tools (STDIO-based) + clientMap map[string]*MCPClient // Map of MCP client names to their configurations + mu sync.RWMutex // Read-write mutex for thread-safe operations + serverRunning bool // Track whether local MCP server is running + logger schemas.Logger // Logger instance for structured logging +} + +// MCPClient represents a connected MCP client with its configuration and tools. +type MCPClient struct { + Name string // Unique name for this client + Conn *client.Client // Active MCP client connection + ExecutionConfig schemas.MCPClientConfig // Tool filtering settings + ToolMap map[string]schemas.Tool // Available tools mapped by name + ConnectionInfo MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management + cancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized) +} + +// MCPClientConnectionInfo stores metadata about how a client is connected. +type MCPClientConnectionInfo struct { + Type schemas.MCPConnectionType `json:"type"` // Connection type (HTTP, STDIO, or SSE) + ConnectionURL *string `json:"connection_url,omitempty"` // HTTP/SSE endpoint URL (for HTTP/SSE connections) + StdioCommandString *string `json:"stdio_command_string,omitempty"` // Command string for display (for STDIO connections) +} + +// MCPToolHandler is a generic function type for handling tool calls with typed arguments. +// T represents the expected argument structure for the tool. +type MCPToolHandler[T any] func(args T) (string, error) + +// ============================================================================ +// CONSTRUCTOR AND INITIALIZATION +// ============================================================================ + +// newMCPManager creates and initializes a new MCP manager instance. +// +// Parameters: +// - config: MCP configuration including server port and client configs +// - logger: Logger instance for structured logging (uses default if nil) +// +// Returns: +// - *MCPManager: Initialized manager instance +// - error: Any initialization error +func newMCPManager(config schemas.MCPConfig, logger schemas.Logger) (*MCPManager, error) { + // Use provided logger or create default logger with info level + if logger == nil { + logger = NewDefaultLogger(schemas.LogLevelInfo) + } + + manager := &MCPManager{ + clientMap: make(map[string]*MCPClient), + logger: logger, + } + + // Process client configs: create client map entries and establish connections + for _, clientConfig := range config.ClientConfigs { + // Validate client configuration + if err := validateMCPClientConfig(&clientConfig); err != nil { + return nil, fmt.Errorf("invalid MCP client configuration: %w", err) + } + + // Create client map entry + manager.clientMap[clientConfig.Name] = &MCPClient{ + Name: clientConfig.Name, + ExecutionConfig: clientConfig, + ToolMap: make(map[string]schemas.Tool), + } + + // Attempt to establish connection + err := manager.connectToMCPClient(clientConfig) + if err != nil { + logger.Warn(fmt.Sprintf("%s Failed to connect to MCP client %s: %v", MCPLogPrefix, clientConfig.Name, err)) + // Continue with other connections even if one fails + } + } + + manager.logger.Info(MCPLogPrefix + " MCP Manager initialized") + + return manager, nil +} + +// ============================================================================ +// TOOL REGISTRATION AND DISCOVERY +// ============================================================================ + +// getAvailableTools returns all tools from connected MCP clients. +// Applies client filtering if specified in the context. +func (m *MCPManager) getAvailableTools(ctx context.Context) []schemas.Tool { + m.mu.RLock() + defer m.mu.RUnlock() + + var includeClients []string + var excludeClients []string + + // Extract client filtering from request context + if existingIncludeClients, ok := ctx.Value(MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil { + includeClients = existingIncludeClients + } + if existingExcludeClients, ok := ctx.Value(MCPContextKeyExcludeClients).([]string); ok && existingExcludeClients != nil { + excludeClients = existingExcludeClients + } + + tools := make([]schemas.Tool, 0) + for clientName, client := range m.clientMap { + // Apply client filtering logic + if !m.shouldIncludeClient(clientName, includeClients, excludeClients) { + continue + } + + // Add all tools from this client + for _, tool := range client.ToolMap { + tools = append(tools, tool) + } + } + return tools +} + +// registerTool registers a typed tool handler with the local MCP server. +// This is a convenience function that handles the conversion between typed Go +// handlers and the MCP protocol. +// +// Type Parameters: +// - T: The expected argument type for the tool (must be JSON-deserializable) +// +// Parameters: +// - name: Unique tool name +// - description: Human-readable tool description +// - handler: Typed function that handles tool execution +// - toolSchema: Bifrost tool schema for function calling +// +// Returns: +// - error: Any registration error +// +// Example: +// +// type EchoArgs struct { +// Message string `json:"message"` +// } +// +// err := bifrost.RegisterMCPTool("echo", "Echo a message", +// func(args EchoArgs) (string, error) { +// return args.Message, nil +// }, toolSchema) +func (m *MCPManager) registerTool(name, description string, handler MCPToolHandler[any], toolSchema schemas.Tool) error { + // Ensure local server is set up + if err := m.setupLocalHost(); err != nil { + return fmt.Errorf("failed to setup local host: %w", err) + } + + // Verify internal client exists + if _, ok := m.clientMap[BifrostMCPClientKey]; !ok { + return fmt.Errorf("bifrost client not found") + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Check if tool name already exists to prevent silent overwrites + if _, exists := m.clientMap[BifrostMCPClientKey].ToolMap[name]; exists { + return fmt.Errorf("tool '%s' is already registered", name) + } + + m.logger.Info(fmt.Sprintf("%s Registering typed tool: %s", MCPLogPrefix, name)) + + // Create MCP handler wrapper that converts between typed and MCP interfaces + mcpHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract arguments from the request using the request's methods + args := request.GetArguments() + result, err := handler(args) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Error: %s", err.Error())), nil + } + return mcp.NewToolResultText(result), nil + } + + // Register the tool with the local MCP server using AddTool + if m.server != nil { + tool := mcp.NewTool(name, mcp.WithDescription(description)) + m.server.AddTool(tool, mcpHandler) + } + + // Store tool definition for Bifrost integration + m.clientMap[BifrostMCPClientKey].ToolMap[name] = toolSchema + + return nil +} + +// setupLocalHost initializes the local MCP server and client if not already running. +// This creates a STDIO-based server for local tool hosting and a corresponding client. +// This is called automatically when tools are registered or when the server is needed. +// +// Returns: +// - error: Any setup error +func (m *MCPManager) setupLocalHost() error { + // Check if server is already running + if m.server != nil && m.serverRunning { + return nil + } + + // Create and configure local MCP server (STDIO-based) + server, err := m.createLocalMCPServer() + if err != nil { + return fmt.Errorf("failed to create local MCP server: %w", err) + } + m.server = server + + // Create and configure local MCP client (STDIO-based) + client, err := m.createLocalMCPClient() + if err != nil { + return fmt.Errorf("failed to create local MCP client: %w", err) + } + m.clientMap[BifrostMCPClientKey] = client + + // Start the server and initialize client connection + return m.startLocalMCPServer() +} + +// createLocalMCPServer creates a new local MCP server instance with STDIO transport. +// This server will host tools registered via RegisterTool function. +// +// Returns: +// - *server.MCPServer: Configured MCP server instance +// - error: Any creation error +func (m *MCPManager) createLocalMCPServer() (*server.MCPServer, error) { + // Create MCP server + mcpServer := server.NewMCPServer( + "Bifrost-MCP-Server", + "1.0.0", + server.WithToolCapabilities(true), + ) + + return mcpServer, nil +} + +// createLocalMCPClient creates a client that connects to the local MCP server via STDIO. +// This client is used internally by Bifrost to access locally hosted tools. +// +// Returns: +// - *MCPClient: Configured client for local server +// - error: Any creation error +func (m *MCPManager) createLocalMCPClient() (*MCPClient, error) { + // For local STDIO communication, we'll use the same process + // Create a STDIO transport that communicates with our local server + // This creates an in-process communication channel + stdioTransport := transport.NewStdio( + "", // Empty command means in-process + nil, // No environment variables needed + ) + + // Create the MCP client + mcpClient := client.NewClient(stdioTransport) + + return &MCPClient{ + Name: BifrostMCPClientName, + Conn: mcpClient, + ExecutionConfig: schemas.MCPClientConfig{ + Name: BifrostMCPClientName, + }, + ToolMap: make(map[string]schemas.Tool), + ConnectionInfo: MCPClientConnectionInfo{ + Type: schemas.MCPConnectionTypeSTDIO, + }, + }, nil +} + +// startLocalMCPServer starts the STDIO server in a background goroutine. +// +// Returns: +// - error: Any startup error +func (m *MCPManager) startLocalMCPServer() error { + m.mu.Lock() + defer m.mu.Unlock() + + // Check if server is already running + if m.server != nil && m.serverRunning { + return nil + } + + if m.server == nil { + return fmt.Errorf("server not initialized") + } + + // Start the STDIO server in background goroutine + go func() { + if err := server.ServeStdio(m.server); err != nil { + m.logger.Error(fmt.Errorf("MCP STDIO server error: %w", err)) + m.mu.Lock() + m.serverRunning = false + m.mu.Unlock() + } + }() + + // Mark server as running + m.serverRunning = true + + // Initialize the client connection to the server + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if _, ok := m.clientMap[BifrostMCPClientKey]; !ok { + return fmt.Errorf("bifrost client not found") + } + + // Start the local client transport first + if err := m.clientMap[BifrostMCPClientKey].Conn.Start(ctx); err != nil { + m.serverRunning = false + return fmt.Errorf("failed to start local MCP client transport: %v", err) + } + + // Create proper initialize request + initRequest := mcp.InitializeRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodInitialize), + }, + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: BifrostMCPClientName, + Version: BifrostMCPVersion, + }, + }, + } + + _, err := m.clientMap[BifrostMCPClientKey].Conn.Initialize(ctx, initRequest) + if err != nil { + m.serverRunning = false + return fmt.Errorf("failed to initialize MCP client: %v", err) + } + + return nil +} + +// executeTool executes a tool call and returns the result as a tool message. +// +// Parameters: +// - ctx: Execution context +// - toolCall: The tool call to execute (from assistant message) +// +// Returns: +// - schemas.BifrostMessage: Tool message with execution result +// - error: Any execution error +func (m *MCPManager) executeTool(ctx context.Context, toolCall schemas.ToolCall) (*schemas.BifrostMessage, error) { + if toolCall.Function.Name == nil { + return nil, fmt.Errorf("tool call missing function name") + } + toolName := *toolCall.Function.Name + + // Parse tool arguments + var arguments map[string]interface{} + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + return nil, fmt.Errorf("failed to parse tool arguments for '%s': %v", toolName, err) + } + + // Find which client has this tool + client := m.findMCPClientForTool(toolName) + if client == nil { + return nil, fmt.Errorf("tool '%s' not found in any connected MCP client", toolName) + } + + if client.Conn == nil { + return nil, fmt.Errorf("client '%s' has no active connection", client.Name) + } + + // Call the tool via MCP client -> MCP server + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsCall), + }, + Params: mcp.CallToolParams{ + Name: toolName, + Arguments: arguments, + }, + } + + m.logger.Debug(fmt.Sprintf("%s Starting tool execution: %s via client: %s", MCPLogPrefix, toolName, client.Name)) + + toolResponse, callErr := client.Conn.CallTool(ctx, callRequest) + if callErr != nil { + m.logger.Error(fmt.Errorf("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.Name, callErr)) + return nil, fmt.Errorf("MCP tool call failed: %v", callErr) + } + + m.logger.Debug(fmt.Sprintf("%s Tool execution completed: %s", MCPLogPrefix, toolName)) + + // Extract text from MCP response + responseText := m.extractTextFromMCPResponse(toolResponse, toolName) + + // Create tool response message + return m.createToolResponseMessage(toolCall, responseText), nil +} + +// ============================================================================ +// EXTERNAL MCP CONNECTION MANAGEMENT +// ============================================================================ + +// connectToMCPClient establishes a connection to an external MCP server and +// registers its available tools with the manager. +func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { + // First lock: Initialize or validate client entry + m.mu.Lock() + + // Initialize or validate client entry + if existingClient, exists := m.clientMap[config.Name]; exists { + // Client entry exists from config, check for existing connection + if existingClient.Conn != nil { + m.mu.Unlock() + return fmt.Errorf("client %s already has an active connection", config.Name) + } + // Update connection type for this connection attempt + existingClient.ConnectionInfo.Type = config.ConnectionType + } else { + // Create new client entry with configuration + m.clientMap[config.Name] = &MCPClient{ + Name: config.Name, + ExecutionConfig: config, + ToolMap: make(map[string]schemas.Tool), + ConnectionInfo: MCPClientConnectionInfo{ + Type: config.ConnectionType, + }, + } + } + m.mu.Unlock() + + // Heavy operations performed outside lock + var externalClient *client.Client + var connectionInfo MCPClientConnectionInfo + var err error + + // Create appropriate transport based on connection type + switch config.ConnectionType { + case schemas.MCPConnectionTypeHTTP: + externalClient, connectionInfo, err = m.createHTTPConnection(config) + case schemas.MCPConnectionTypeSTDIO: + externalClient, connectionInfo, err = m.createSTDIOConnection(config) + case schemas.MCPConnectionTypeSSE: + externalClient, connectionInfo, err = m.createSSEConnection(config) + default: + return fmt.Errorf("unknown connection type: %s", config.ConnectionType) + } + + if err != nil { + return fmt.Errorf("failed to create connection: %w", err) + } + + // Initialize the external client with timeout + // For SSE connections, we need a long-lived context, for others we can use timeout + var ctx context.Context + var cancel context.CancelFunc + + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + // SSE connections need a long-lived context for the persistent stream + ctx, cancel = context.WithCancel(context.Background()) + // Don't defer cancel here - SSE needs the context to remain active + } else { + // Other connection types can use timeout context + ctx, cancel = context.WithTimeout(context.Background(), MCPClientConnectionEstablishTimeout) + defer cancel() + } + + // Start the transport first (required for STDIO and SSE clients) + if err := externalClient.Start(ctx); err != nil { + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + cancel() // Cancel SSE context only on error + } + return fmt.Errorf("failed to start MCP client transport %s: %v", config.Name, err) + } + + // Create proper initialize request for external client + extInitRequest := mcp.InitializeRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodInitialize), + }, + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: fmt.Sprintf("Bifrost-%s", config.Name), + Version: "1.0.0", + }, + }, + } + + _, err = externalClient.Initialize(ctx, extInitRequest) + if err != nil { + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + cancel() // Cancel SSE context only on error + } + return fmt.Errorf("failed to initialize MCP client %s: %v", config.Name, err) + } + + // Retrieve tools from the external server (this also requires network I/O) + tools, err := m.retrieveExternalTools(ctx, externalClient, config) + if err != nil { + m.logger.Warn(fmt.Sprintf("%s Failed to retrieve tools from %s: %v", MCPLogPrefix, config.Name, err)) + // Continue with connection even if tool retrieval fails + tools = make(map[string]schemas.Tool) + } + + // Second lock: Update client with final connection details and tools + m.mu.Lock() + defer m.mu.Unlock() + + // Verify client still exists (could have been cleaned up during heavy operations) + if client, exists := m.clientMap[config.Name]; exists { + // Store the external client connection and details + client.Conn = externalClient + client.ConnectionInfo = connectionInfo + + // Store cancel function for SSE connections to enable proper cleanup + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + client.cancelFunc = cancel + } + + // Store discovered tools + for toolName, tool := range tools { + client.ToolMap[toolName] = tool + } + + m.logger.Info(fmt.Sprintf("%s Connected to MCP client: %s", MCPLogPrefix, config.Name)) + } else { + return fmt.Errorf("client %s was removed during connection setup", config.Name) + } + + return nil +} + +// retrieveExternalTools retrieves and filters tools from an external MCP server without holding locks. +func (m *MCPManager) retrieveExternalTools(ctx context.Context, client *client.Client, config schemas.MCPClientConfig) (map[string]schemas.Tool, error) { + // Get available tools from external server + listRequest := mcp.ListToolsRequest{ + PaginatedRequest: mcp.PaginatedRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsList), + }, + }, + } + + toolsResponse, err := client.ListTools(ctx, listRequest) + if err != nil { + return nil, fmt.Errorf("failed to list tools: %v", err) + } + + if toolsResponse == nil { + return make(map[string]schemas.Tool), nil // No tools available + } + + tools := make(map[string]schemas.Tool) + + // toolsResponse is already a ListToolsResult + for _, mcpTool := range toolsResponse.Tools { + // Check if tool should be skipped based on configuration + if m.shouldSkipToolForConfig(mcpTool.Name, config) { + continue + } + + // Convert MCP tool schema to Bifrost format + bifrostTool := m.convertMCPToolToBifrostSchema(&mcpTool) + tools[mcpTool.Name] = bifrostTool + } + + return tools, nil +} + +// shouldSkipToolForConfig checks if a tool should be skipped based on client configuration (without accessing clientMap). +func (m *MCPManager) shouldSkipToolForConfig(toolName string, config schemas.MCPClientConfig) bool { + // If ToolsToExecute is specified, only execute tools in that list + if len(config.ToolsToExecute) > 0 { + for _, allowedTool := range config.ToolsToExecute { + if allowedTool == toolName { + return false // Tool is allowed + } + } + return true // Tool not in allowed list + } + + // Check if tool is in skip list + for _, skipTool := range config.ToolsToSkip { + if skipTool == toolName { + return true // Tool should be skipped + } + } + + return false // Tool is allowed +} + +// convertMCPToolToBifrostSchema converts an MCP tool definition to Bifrost format. +func (m *MCPManager) convertMCPToolToBifrostSchema(mcpTool *mcp.Tool) schemas.Tool { + return schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: mcpTool.Name, + Description: mcpTool.Description, + Parameters: schemas.FunctionParameters{ + Type: mcpTool.InputSchema.Type, + Properties: mcpTool.InputSchema.Properties, + Required: mcpTool.InputSchema.Required, + }, + }, + } +} + +// extractTextFromMCPResponse extracts text content from an MCP tool response. +func (m *MCPManager) extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string { + if toolResponse == nil { + return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) + } + + var result strings.Builder + for _, contentBlock := range toolResponse.Content { + // Handle typed content + switch content := contentBlock.(type) { + case mcp.TextContent: + result.WriteString(fmt.Sprintf("[Text Response: %s]\n", content.Text)) + case mcp.ImageContent: + result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) + case mcp.AudioContent: + result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) + case mcp.EmbeddedResource: + result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type)) + default: + // Fallback: try to extract from map structure + if jsonBytes, err := json.Marshal(contentBlock); err == nil { + var contentMap map[string]interface{} + if json.Unmarshal(jsonBytes, &contentMap) == nil { + if text, ok := contentMap["text"].(string); ok { + result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text)) + continue + } + } + // Final fallback: serialize as JSON + result.WriteString(string(jsonBytes)) + } + } + } + + if result.Len() > 0 { + return strings.TrimSpace(result.String()) + } + return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) +} + +// createToolResponseMessage creates a tool response message with the execution result. +func (m *MCPManager) createToolResponseMessage(toolCall schemas.ToolCall, responseText string) *schemas.BifrostMessage { + return &schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleTool, + Content: schemas.MessageContent{ + ContentStr: &responseText, + }, + ToolMessage: &schemas.ToolMessage{ + ToolCallID: toolCall.ID, + }, + } +} + +func (m *MCPManager) addMCPToolsToBifrostRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { + mcpTools := m.getAvailableTools(ctx) + if len(mcpTools) > 0 { + // Initialize tools array if needed + if req.Params == nil { + req.Params = &schemas.ModelParameters{} + } + if req.Params.Tools == nil { + req.Params.Tools = &[]schemas.Tool{} + } + tools := *req.Params.Tools + + // Create a map of existing tool names for O(1) lookup + existingToolsMap := make(map[string]bool) + for _, tool := range tools { + existingToolsMap[tool.Function.Name] = true + } + + // Add MCP tools that are not already present + for _, mcpTool := range mcpTools { + if !existingToolsMap[mcpTool.Function.Name] { + tools = append(tools, mcpTool) + // Update the map to prevent duplicates within MCP tools as well + existingToolsMap[mcpTool.Function.Name] = true + } + } + req.Params.Tools = &tools + + } + return req +} + +func validateMCPClientConfig(config *schemas.MCPClientConfig) error { + if strings.TrimSpace(config.Name) == "" { + return fmt.Errorf("name is required for MCP client config") + } + + if config.ConnectionType == "" { + return fmt.Errorf("connection type is required for MCP client config") + } + + switch config.ConnectionType { + case schemas.MCPConnectionTypeHTTP: + if config.ConnectionString == nil { + return fmt.Errorf("ConnectionString is required for HTTP connection type in client '%s'", config.Name) + } + case schemas.MCPConnectionTypeSSE: + if config.ConnectionString == nil { + return fmt.Errorf("ConnectionString is required for SSE connection type in client '%s'", config.Name) + } + case schemas.MCPConnectionTypeSTDIO: + if config.StdioConfig == nil { + return fmt.Errorf("StdioConfig is required for STDIO connection type in client '%s'", config.Name) + } + default: + return fmt.Errorf("unknown connection type '%s' in client '%s'", config.ConnectionType, config.Name) + } + + // Check for overlapping tools between ToolsToSkip and ToolsToExecute + if len(config.ToolsToSkip) > 0 && len(config.ToolsToExecute) > 0 { + skipMap := make(map[string]bool) + for _, tool := range config.ToolsToSkip { + skipMap[tool] = true + } + + var overlapping []string + for _, tool := range config.ToolsToExecute { + if skipMap[tool] { + overlapping = append(overlapping, tool) + } + } + + if len(overlapping) > 0 { + return fmt.Errorf("tools cannot be both included and excluded in client '%s': %v", config.Name, overlapping) + } + } + + return nil +} + +// ============================================================================ +// HELPER METHODS +// ============================================================================ + +// findMCPClientForTool safely finds a client that has the specified tool. +func (m *MCPManager) findMCPClientForTool(toolName string) *MCPClient { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, client := range m.clientMap { + if _, exists := client.ToolMap[toolName]; exists { + return client + } + } + return nil +} + +// shouldIncludeClient determines if a client should be included based on filtering rules. +func (m *MCPManager) shouldIncludeClient(clientName string, includeClients, excludeClients []string) bool { + // If includeClients is specified, only include those clients (whitelist mode) + if len(includeClients) > 0 { + return slices.Contains(includeClients, clientName) + } + + // If excludeClients is specified, exclude those clients (blacklist mode) + if len(excludeClients) > 0 { + return !slices.Contains(excludeClients, clientName) + } + + // Default: include all clients + return true +} + +// createHTTPConnection creates an HTTP-based MCP client connection without holding locks. +func (m *MCPManager) createHTTPConnection(config schemas.MCPClientConfig) (*client.Client, MCPClientConnectionInfo, error) { + if config.ConnectionString == nil { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("HTTP connection string is required") + } + + // Prepare connection info + connectionInfo := MCPClientConnectionInfo{ + Type: config.ConnectionType, + ConnectionURL: config.ConnectionString, + } + + // Create StreamableHTTP transport + httpTransport, err := transport.NewStreamableHTTP(*config.ConnectionString) + if err != nil { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("failed to create HTTP transport: %w", err) + } + + client := client.NewClient(httpTransport) + + return client, connectionInfo, nil +} + +// createSTDIOConnection creates a STDIO-based MCP client connection without holding locks. +func (m *MCPManager) createSTDIOConnection(config schemas.MCPClientConfig) (*client.Client, MCPClientConnectionInfo, error) { + if config.StdioConfig == nil { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("stdio config is required") + } + + // Prepare STDIO command info for display + cmdString := fmt.Sprintf("%s %s", config.StdioConfig.Command, strings.Join(config.StdioConfig.Args, " ")) + + // Check if environment variables are set + for _, env := range config.StdioConfig.Envs { + if os.Getenv(env) == "" { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("environment variable %s is not set for MCP client %s", env, config.Name) + } + } + + // Create STDIO transport + stdioTransport := transport.NewStdio( + config.StdioConfig.Command, + config.StdioConfig.Envs, + config.StdioConfig.Args..., + ) + + // Prepare connection info + connectionInfo := MCPClientConnectionInfo{ + Type: config.ConnectionType, + StdioCommandString: &cmdString, + } + + client := client.NewClient(stdioTransport) + + // Return nil for cmd since mark3labs/mcp-go manages the process internally + return client, connectionInfo, nil +} + +// createSSEConnection creates a SSE-based MCP client connection without holding locks. +func (m *MCPManager) createSSEConnection(config schemas.MCPClientConfig) (*client.Client, MCPClientConnectionInfo, error) { + if config.ConnectionString == nil { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("SSE connection string is required") + } + + // Prepare connection info + connectionInfo := MCPClientConnectionInfo{ + Type: config.ConnectionType, + ConnectionURL: config.ConnectionString, // Reuse HTTPConnectionURL field for SSE URL display + } + + // Create SSE transport + sseTransport, err := transport.NewSSE(*config.ConnectionString) + if err != nil { + return nil, MCPClientConnectionInfo{}, fmt.Errorf("failed to create SSE transport: %w", err) + } + + client := client.NewClient(sseTransport) + + return client, connectionInfo, nil +} + +// cleanup performs cleanup of all MCP resources including clients and local server. +// This function safely disconnects all MCP clients (HTTP, STDIO, and SSE) and +// cleans up the local MCP server. It handles proper cancellation of SSE contexts +// and closes all transport connections. +// +// Returns: +// - error: Always returns nil, but maintains error interface for consistency +func (m *MCPManager) cleanup() error { + m.mu.Lock() + defer m.mu.Unlock() + + // Disconnect all external MCP clients + for name, client := range m.clientMap { + m.logger.Info(fmt.Sprintf("%s Disconnecting MCP client: %s", MCPLogPrefix, name)) + + // Cancel SSE context if present (required for proper SSE cleanup) + if client.cancelFunc != nil { + client.cancelFunc() + client.cancelFunc = nil + } + + // Close the client transport connection + // This handles cleanup for all transport types (HTTP, STDIO, SSE) + if client.Conn != nil { + if err := client.Conn.Close(); err != nil { + m.logger.Error(fmt.Errorf("%s Failed to close MCP client %s: %w", MCPLogPrefix, name, err)) + } + client.Conn = nil + } + + // Clear client tool map + client.ToolMap = make(map[string]schemas.Tool) + } + + // Clear the client map + m.clientMap = make(map[string]*MCPClient) + + // Clear local server reference + // Note: mark3labs/mcp-go STDIO server cleanup is handled automatically + if m.server != nil { + m.logger.Info(MCPLogPrefix + " Clearing local MCP server reference") + m.server = nil + m.serverRunning = false + } + + m.logger.Info(MCPLogPrefix + " MCP cleanup completed") + return nil +} diff --git a/core/providers/anthropic.go b/core/providers/anthropic.go index 881c0aceda..466a73a572 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -3,7 +3,9 @@ package providers import ( + "context" "fmt" + "strings" "sync" "time" @@ -67,10 +69,18 @@ type AnthropicError struct { } `json:"error"` // Error details } +type AnthropicImageContent struct { + Type ImageContentType `json:"type"` + URL string `json:"url"` + MediaType string `json:"media_type,omitempty"` +} + // AnthropicProvider implements the Provider interface for Anthropic's Claude API. type AnthropicProvider struct { - logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + apiVersion string // API version for the provider + networkConfig schemas.NetworkConfig // Network configuration including extra headers } // anthropicChatResponsePool provides a pool for Anthropic chat response objects. @@ -119,27 +129,35 @@ func releaseAnthropicTextResponse(resp *AnthropicTextResponse) { // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts, concurrency limits, and optional proxy settings. func NewAnthropicProvider(config *schemas.ProviderConfig, logger schemas.Logger) *AnthropicProvider { - setConfigDefaults(config) + config.CheckAndSetDefaults() client := &fasthttp.Client{ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), - MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency, } // Pre-warm response pools for range config.ConcurrencyAndBufferSize.Concurrency { anthropicTextResponsePool.Put(&AnthropicTextResponse{}) anthropicChatResponsePool.Put(&AnthropicChatResponse{}) - bifrostResponsePool.Put(&schemas.BifrostResponse{}) + } // Configure proxy if provided client = configureProxy(client, config.ProxyConfig, logger) + // Set default BaseURL if not provided + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = "https://api.anthropic.com" + } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + return &AnthropicProvider{ - logger: logger, - client: client, + logger: logger, + client: client, + apiVersion: "2023-06-01", + networkConfig: config.NetworkConfig, } } @@ -167,7 +185,7 @@ func (provider *AnthropicProvider) prepareTextCompletionParams(params map[string // completeRequest sends a request to Anthropic's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. -func (provider *AnthropicProvider) completeRequest(requestBody map[string]interface{}, url string, key string) ([]byte, *schemas.BifrostError) { +func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, url string, key string) ([]byte, *schemas.BifrostError) { // Marshal the request body jsonData, err := json.Marshal(requestBody) if err != nil { @@ -186,26 +204,27 @@ func (provider *AnthropicProvider) completeRequest(requestBody map[string]interf defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + req.SetRequestURI(url) req.Header.SetMethod("POST") req.Header.SetContentType("application/json") req.Header.Set("x-api-key", key) - req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Set("anthropic-version", provider.apiVersion) + req.SetBody(jsonData) // Send the request - if err := provider.client.Do(req, resp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr } // Handle error response if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from anthropic provider: %s", string(resp.Body()))) + var errorResp AnthropicError bifrostErr := handleProviderAPIError(resp, &errorResp) @@ -224,7 +243,7 @@ func (provider *AnthropicProvider) completeRequest(requestBody map[string]interf // TextCompletion performs a text completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { preparedParams := provider.prepareTextCompletionParams(prepareParams(params)) // Merge additional parameters @@ -233,7 +252,7 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param "prompt": fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", text), }, preparedParams) - responseBody, err := provider.completeRequest(requestBody, "https://api.anthropic.com/v1/complete", key) + responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/complete", key) if err != nil { return nil, err } @@ -242,83 +261,249 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param response := acquireAnthropicTextResponse() defer releaseAnthropicTextResponse(response) - // Create Bifrost response from pool - bifrostResponse := acquireBifrostResponse() - defer releaseBifrostResponse(bifrostResponse) - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) if bifrostErr != nil { return nil, bifrostErr } - bifrostResponse.ID = response.ID - bifrostResponse.Choices = []schemas.BifrostResponseChoice{ - { - Index: 0, - Message: schemas.BifrostResponseChoiceMessage{ - Role: schemas.RoleAssistant, - Content: &response.Completion, + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ID, + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentStr: &response.Completion, + }, + }, }, }, + Usage: schemas.LLMUsage{ + PromptTokens: response.Usage.InputTokens, + CompletionTokens: response.Usage.OutputTokens, + TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, + }, + Model: response.Model, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.Anthropic, + RawResponse: rawResponse, + }, } - bifrostResponse.Usage = schemas.LLMUsage{ - PromptTokens: response.Usage.InputTokens, - CompletionTokens: response.Usage.OutputTokens, - TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, + + if params != nil { + bifrostResponse.ExtraFields.Params = *params } - bifrostResponse.Model = response.Model + + return bifrostResponse, nil +} + +// ChatCompletion performs a chat completion request to Anthropic's API. +// It formats the request, sends it to Anthropic, and processes the response. +// Returns a BifrostResponse containing the completion results or an error if the request fails. +func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params) + + // Merge additional parameters + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/messages", key) + if err != nil { + return nil, err + } + + // Create response object from pool + response := acquireAnthropicChatResponse() + defer releaseAnthropicChatResponse(response) + + rawResponse, bifrostErr := handleProviderResponse(responseBody, response) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + bifrostResponse := &schemas.BifrostResponse{} + bifrostResponse, err = parseAnthropicResponse(response, bifrostResponse) + if err != nil { + return nil, err + } + bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ Provider: schemas.Anthropic, RawResponse: rawResponse, } + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + return bifrostResponse, nil } -// ChatCompletion performs a chat completion request to Anthropic's API. -// It formats the request, sends it to Anthropic, and processes the response. -// Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Format messages for Anthropic API - var formattedMessages []map[string]interface{} - for _, msg := range messages { - if msg.ImageContent != nil { - var content []map[string]interface{} +// buildAnthropicImageSourceMap creates the "source" map for an Anthropic image content part. +func buildAnthropicImageSourceMap(imgContent *schemas.ImageURLStruct) map[string]interface{} { + if imgContent == nil { + return nil + } - imageContent := map[string]interface{}{ - "type": "image", - "source": map[string]interface{}{ - "type": msg.ImageContent.Type, - }, + sanitizedURL, _ := SanitizeImageURL(imgContent.URL) + urlTypeInfo := ExtractURLTypeInfo(sanitizedURL) + + formattedImgContent := AnthropicImageContent{ + Type: urlTypeInfo.Type, + } + + if urlTypeInfo.MediaType != nil { + formattedImgContent.MediaType = *urlTypeInfo.MediaType + } + + if urlTypeInfo.DataURLWithoutPrefix != nil { + formattedImgContent.URL = *urlTypeInfo.DataURLWithoutPrefix + } else { + formattedImgContent.URL = sanitizedURL + } + + sourceMap := map[string]interface{}{ + "type": string(formattedImgContent.Type), // "base64" or "url" + } + + if formattedImgContent.Type == ImageContentTypeURL { + sourceMap["url"] = formattedImgContent.URL + } else { + if formattedImgContent.MediaType != "" { + sourceMap["media_type"] = formattedImgContent.MediaType + } + sourceMap["data"] = formattedImgContent.URL // URL field contains base64 data string + } + return sourceMap +} + +func prepareAnthropicChatRequest(messages []schemas.BifrostMessage, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) { + // Add system messages if present + var systemMessages []BedrockAnthropicSystemMessage + for _, msg := range messages { + if msg.Role == schemas.ModelChatMessageRoleSystem { + if msg.Content.ContentStr != nil { + systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ + Text: *msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ + Text: *block.Text, + }) + } + } } + } + } - // Handle different image source types - if *msg.ImageContent.Type == "url" { - imageContent["source"].(map[string]interface{})["url"] = msg.ImageContent.URL + // Format messages for Anthropic API + var formattedMessages []map[string]interface{} + for _, msg := range messages { + var content []interface{} + + if msg.Role != schemas.ModelChatMessageRoleSystem { + if msg.Role == schemas.ModelChatMessageRoleTool && msg.ToolMessage != nil && msg.ToolMessage.ToolCallID != nil { + toolCallResult := map[string]interface{}{ + "type": "tool_result", + "tool_use_id": *msg.ToolMessage.ToolCallID, + } + + var toolCallResultContent []map[string]interface{} + + if msg.Content.ContentStr != nil { + toolCallResultContent = append(toolCallResultContent, map[string]interface{}{ + "type": "text", + "text": *msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + toolCallResultContent = append(toolCallResultContent, map[string]interface{}{ + "type": "text", + "text": *block.Text, + }) + } + } + } + + toolCallResult["content"] = toolCallResultContent + content = append(content, toolCallResult) } else { - imageContent["source"].(map[string]interface{})["media_type"] = msg.ImageContent.MediaType - imageContent["source"].(map[string]interface{})["data"] = msg.ImageContent.URL + // Add text content if present + if msg.Content.ContentStr != nil && *msg.Content.ContentStr != "" { + content = append(content, map[string]interface{}{ + "type": "text", + "text": *msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil && *block.Text != "" { + content = append(content, map[string]interface{}{ + "type": "text", + "text": *block.Text, + }) + } + if block.ImageURL != nil { + imageSource := buildAnthropicImageSourceMap(block.ImageURL) + if imageSource != nil { + content = append(content, map[string]interface{}{ + "type": "image", + "source": imageSource, + }) + } + } + } + } + + // Add thinking content if present in AssistantMessage + if msg.AssistantMessage != nil && msg.AssistantMessage.Thought != nil { + content = append(content, map[string]interface{}{ + "type": "thinking", + "thinking": *msg.AssistantMessage.Thought, + }) + } + + // Add tool calls as content if present + if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + for _, toolCall := range *msg.AssistantMessage.ToolCalls { + if toolCall.Function.Name != nil { + var input map[string]interface{} + if toolCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { + // If unmarshaling fails, use a simple string representation + input = map[string]interface{}{"arguments": toolCall.Function.Arguments} + } + } + + toolUseContent := map[string]interface{}{ + "type": "tool_use", + "name": *toolCall.Function.Name, + "input": input, + } + + if toolCall.ID != nil { + toolUseContent["id"] = *toolCall.ID + } + + content = append(content, toolUseContent) + } + } + } } - content = append(content, imageContent) - - // Add text content if present - if msg.Content != nil { - content = append(content, map[string]interface{}{ - "type": "text", - "text": msg.Content, + if len(content) > 0 { + formattedMessages = append(formattedMessages, map[string]interface{}{ + "role": msg.Role, + "content": content, }) } - - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": content, - }) - } else { - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": msg.Content, - }) } } @@ -338,43 +523,118 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages [] preparedParams["tools"] = tools } - // Merge additional parameters - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - }, preparedParams) + // Transform tool choice if present + if params != nil && params.ToolChoice != nil { + if params.ToolChoice.ToolChoiceStr != nil { + preparedParams["tool_choice"] = map[string]interface{}{ + "type": *params.ToolChoice.ToolChoiceStr, + } + } else if params.ToolChoice.ToolChoiceStruct != nil { + switch toolChoice := params.ToolChoice.ToolChoiceStruct.Type; toolChoice { + case schemas.ToolChoiceTypeFunction: + fallthrough + case "tool": + preparedParams["tool_choice"] = map[string]interface{}{ + "type": "tool", + "name": params.ToolChoice.ToolChoiceStruct.Function.Name, + } + default: + preparedParams["tool_choice"] = map[string]interface{}{ + "type": toolChoice, + } + } + } + } - responseBody, err := provider.completeRequest(requestBody, "https://api.anthropic.com/v1/messages", key) - if err != nil { - return nil, err + if len(systemMessages) > 0 { + var messages []string + for _, message := range systemMessages { + messages = append(messages, message.Text) + } + + preparedParams["system"] = strings.Join(messages, " ") } - // Create response object from pool - response := acquireAnthropicChatResponse() - defer releaseAnthropicChatResponse(response) + // Post-process formattedMessages for tool call results + processedFormattedMessages := []map[string]interface{}{} // Use a new slice + i := 0 + for i < len(formattedMessages) { + currentMsg := formattedMessages[i] + currentRole, roleOk := getRoleFromMessage(currentMsg) + + if !roleOk || currentRole == "" { + // If role is of an unexpected type, missing, or empty, treat as non-tool message + processedFormattedMessages = append(processedFormattedMessages, currentMsg) + i++ + continue + } - // Create Bifrost response from pool - bifrostResponse := acquireBifrostResponse() - defer releaseBifrostResponse(bifrostResponse) + if currentRole == schemas.ModelChatMessageRoleTool { + // Content of a tool message is the toolCallResult map + // Initialize accumulatedToolResults with the content of the current tool message. + var accumulatedToolResults []interface{} - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) - if bifrostErr != nil { - return nil, bifrostErr + // Safely extract content from current message + if content, ok := currentMsg["content"].([]interface{}); ok { + accumulatedToolResults = content + } else { + // If content is not the expected type, skip this message + processedFormattedMessages = append(processedFormattedMessages, currentMsg) + i++ + continue + } + + // Look ahead for more sequential tool messages + j := i + 1 + for j < len(formattedMessages) { + nextMsg := formattedMessages[j] + nextRole, nextRoleOk := getRoleFromMessage(nextMsg) + + if !nextRoleOk || nextRole == "" || nextRole != schemas.ModelChatMessageRoleTool { + break // Not a sequential tool message or role is invalid/missing/empty + } + + // Safely extract content from next message + if nextContent, ok := nextMsg["content"].([]interface{}); ok { + accumulatedToolResults = append(accumulatedToolResults, nextContent...) + } + j++ + } + + // Create a new message with role User and accumulated content + mergedMsg := map[string]interface{}{ + "role": schemas.ModelChatMessageRoleUser, // Final role is User + "content": accumulatedToolResults, + } + processedFormattedMessages = append(processedFormattedMessages, mergedMsg) + i = j // Advance main loop index past all merged messages + } else { + // Not a tool message, add it as is + processedFormattedMessages = append(processedFormattedMessages, currentMsg) + i++ + } } + formattedMessages = processedFormattedMessages // Update with processed messages - // Process the response into our BifrostResponse format - var choices []schemas.BifrostResponseChoice + return formattedMessages, preparedParams +} - // Process content and tool calls - for i, c := range response.Content { - var content string - var toolCalls []schemas.ToolCall +func parseAnthropicResponse(response *AnthropicChatResponse, bifrostResponse *schemas.BifrostResponse) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Collect all content and tool calls into a single message + var toolCalls []schemas.ToolCall + var thinking string + var contentBlocks []schemas.ContentBlock + // Process content and tool calls + for _, c := range response.Content { switch c.Type { case "thinking": - content = c.Thinking + thinking = c.Thinking case "text": - content = c.Text + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: "text", + Text: &c.Text, + }) case "tool_use": function := schemas.FunctionCall{ Name: &c.Name, @@ -393,31 +653,44 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages [] Function: function, }) } + } + + // Create the assistant message + var assistantMessage *schemas.AssistantMessage - choices = append(choices, schemas.BifrostResponseChoice{ - Index: i, - Message: schemas.BifrostResponseChoiceMessage{ - Role: schemas.RoleAssistant, - Content: &content, - ToolCalls: &toolCalls, + // Create AssistantMessage if we have tool calls or thinking + if len(toolCalls) > 0 || thinking != "" { + assistantMessage = &schemas.AssistantMessage{} + if len(toolCalls) > 0 { + assistantMessage.ToolCalls = &toolCalls + } + if thinking != "" { + assistantMessage.Thought = &thinking + } + } + + // Create a single choice with the collected content + bifrostResponse.ID = response.ID + bifrostResponse.Choices = []schemas.BifrostResponseChoice{ + { + Index: 0, + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentBlocks: &contentBlocks, + }, + AssistantMessage: assistantMessage, }, FinishReason: &response.StopReason, StopString: response.StopSequence, - }) + }, } - - bifrostResponse.ID = response.ID - bifrostResponse.Choices = choices bifrostResponse.Usage = schemas.LLMUsage{ PromptTokens: response.Usage.InputTokens, CompletionTokens: response.Usage.OutputTokens, TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, } bifrostResponse.Model = response.Model - bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: schemas.Anthropic, - RawResponse: rawResponse, - } return bifrostResponse, nil } diff --git a/core/providers/azure.go b/core/providers/azure.go index 13e8e2ee1b..5867778334 100644 --- a/core/providers/azure.go +++ b/core/providers/azure.go @@ -3,6 +3,7 @@ package providers import ( + "context" "fmt" "sync" "time" @@ -51,6 +52,9 @@ type AzureError struct { } `json:"error"` } +// AzureAuthorizationTokenKey is the context key for the Azure authentication token. +const AzureAuthorizationTokenKey ContextKey = "azure-authorization-token" + // azureTextCompletionResponsePool provides a pool for Azure text completion response objects. var azureTextCompletionResponsePool = sync.Pool{ New: func() interface{} { @@ -95,38 +99,44 @@ func releaseAzureTextResponse(resp *AzureTextResponse) { // AzureProvider implements the Provider interface for Azure's OpenAI API. type AzureProvider struct { - logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests - meta schemas.MetaConfig // Azure-specific configuration + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + meta schemas.MetaConfig // Azure-specific configuration + networkConfig schemas.NetworkConfig // Network configuration including extra headers } // NewAzureProvider creates a new Azure provider instance. // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts, concurrency limits, and optional proxy settings. -func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) *AzureProvider { - setConfigDefaults(config) +func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*AzureProvider, error) { + config.CheckAndSetDefaults() + + if config.MetaConfig == nil { + return nil, fmt.Errorf("meta config is not set") + } client := &fasthttp.Client{ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), - MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency, } // Pre-warm response pools for range config.ConcurrencyAndBufferSize.Concurrency { azureChatResponsePool.Put(&AzureChatResponse{}) azureTextCompletionResponsePool.Put(&AzureTextResponse{}) - bifrostResponsePool.Put(&schemas.BifrostResponse{}) + } // Configure proxy if provided client = configureProxy(client, config.ProxyConfig, logger) return &AzureProvider{ - logger: logger, - client: client, - meta: config.MetaConfig, - } + logger: logger, + client: client, + meta: config.MetaConfig, + networkConfig: config.NetworkConfig, + }, nil } // GetProviderKey returns the provider identifier for Azure. @@ -137,7 +147,7 @@ func (provider *AzureProvider) GetProviderKey() schemas.ModelProvider { // completeRequest sends a request to Azure's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. -func (provider *AzureProvider) completeRequest(requestBody map[string]interface{}, path string, key string, model string) ([]byte, *schemas.BifrostError) { +func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, key string, model string) ([]byte, *schemas.BifrostError) { // Marshal the request body jsonData, err := json.Marshal(requestBody) if err != nil { @@ -193,25 +203,32 @@ func (provider *AzureProvider) completeRequest(requestBody map[string]interface{ defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + req.SetRequestURI(url) req.Header.SetMethod("POST") req.Header.SetContentType("application/json") - req.Header.Set("api-key", key) + if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken)) + // Ensure api-key is not accidentally present (from extra headers, etc.) + req.Header.Del("api-key") + } else { + req.Header.Set("api-key", key) + } + req.SetBody(jsonData) // Send the request - if err := provider.client.Do(req, resp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr } // Handle error response if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from azure provider: %s", string(resp.Body()))) + var errorResp AzureError bifrostErr := handleProviderAPIError(resp, &errorResp) @@ -230,7 +247,7 @@ func (provider *AzureProvider) completeRequest(requestBody map[string]interface{ // TextCompletion performs a text completion request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AzureProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AzureProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { preparedParams := prepareParams(params) // Merge additional parameters @@ -239,7 +256,7 @@ func (provider *AzureProvider) TextCompletion(model, key, text string, params *s "prompt": text, }, preparedParams) - responseBody, err := provider.completeRequest(requestBody, "completions", key, model) + responseBody, err := provider.completeRequest(ctx, requestBody, "completions", key, model) if err != nil { return nil, err } @@ -248,10 +265,6 @@ func (provider *AzureProvider) TextCompletion(model, key, text string, params *s response := acquireAzureTextResponse() defer releaseAzureTextResponse(response) - // Create Bifrost response from pool - bifrostResponse := acquireBifrostResponse() - defer releaseBifrostResponse(bifrostResponse) - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) if bifrostErr != nil { return nil, bifrostErr @@ -263,9 +276,11 @@ func (provider *AzureProvider) TextCompletion(model, key, text string, params *s if len(response.Choices) > 0 { choices = append(choices, schemas.BifrostResponseChoice{ Index: 0, - Message: schemas.BifrostResponseChoiceMessage{ - Role: schemas.RoleAssistant, - Content: &response.Choices[0].Text, + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentStr: &response.Choices[0].Text, + }, }, FinishReason: response.Choices[0].FinishReason, LogProbs: &schemas.LogProbs{ @@ -274,15 +289,22 @@ func (provider *AzureProvider) TextCompletion(model, key, text string, params *s }) } - bifrostResponse.ID = response.ID - bifrostResponse.Choices = choices - bifrostResponse.Model = response.Model - bifrostResponse.Created = response.Created - bifrostResponse.SystemFingerprint = response.SystemFingerprint - bifrostResponse.Usage = response.Usage - bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: schemas.Azure, - RawResponse: rawResponse, + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ID, + Choices: choices, + Model: response.Model, + Created: response.Created, + SystemFingerprint: response.SystemFingerprint, + Usage: response.Usage, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.Azure, + RawResponse: rawResponse, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params } return bifrostResponse, nil @@ -291,17 +313,8 @@ func (provider *AzureProvider) TextCompletion(model, key, text string, params *s // ChatCompletion performs a chat completion request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AzureProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - preparedParams := prepareParams(params) - - // Format messages for Azure API - var formattedMessages []map[string]interface{} - for _, msg := range messages { - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": msg.Content, - }) - } +func (provider *AzureProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) // Merge additional parameters requestBody := mergeConfig(map[string]interface{}{ @@ -309,7 +322,7 @@ func (provider *AzureProvider) ChatCompletion(model, key string, messages []sche "messages": formattedMessages, }, preparedParams) - responseBody, err := provider.completeRequest(requestBody, "chat/completions", key, model) + responseBody, err := provider.completeRequest(ctx, requestBody, "chat/completions", key, model) if err != nil { return nil, err } @@ -318,24 +331,27 @@ func (provider *AzureProvider) ChatCompletion(model, key string, messages []sche response := acquireAzureChatResponse() defer releaseAzureChatResponse(response) - // Create Bifrost response from pool - bifrostResponse := acquireBifrostResponse() - defer releaseBifrostResponse(bifrostResponse) - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) if bifrostErr != nil { return nil, bifrostErr } - bifrostResponse.ID = response.ID - bifrostResponse.Choices = response.Choices - bifrostResponse.Model = response.Model - bifrostResponse.Created = response.Created - bifrostResponse.SystemFingerprint = response.SystemFingerprint - bifrostResponse.Usage = response.Usage - bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: schemas.Azure, - RawResponse: rawResponse, + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ID, + Choices: response.Choices, + Model: response.Model, + Created: response.Created, + SystemFingerprint: response.SystemFingerprint, + Usage: response.Usage, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.Azure, + RawResponse: rawResponse, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params } return bifrostResponse, nil diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index 3a9b35c39b..169e2a53e1 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -7,6 +7,7 @@ import ( "context" "crypto/sha256" "encoding/hex" + "errors" "fmt" "io" "net/http" @@ -49,7 +50,9 @@ type BedrockChatResponse struct { Output struct { Message struct { Content []struct { - Text string `json:"text"` // Message content + Text *string `json:"text"` // Message content + // Bedrock returns a union type where either Text or ToolUse is present (mutually exclusive) + BedrockAnthropicToolUseMessage } `json:"content"` // Array of message content Role string `json:"role"` // Role of the message sender } `json:"message"` // Message structure @@ -94,8 +97,8 @@ type BedrockAnthropicImageMessage struct { // BedrockAnthropicImage represents image data for Anthropic models. type BedrockAnthropicImage struct { - Format string `json:"string"` // Image format - Source BedrockAnthropicImageSource `json:"source"` // Image source + Format string `json:"format,omitempty"` // Image format + Source BedrockAnthropicImageSource `json:"source,omitempty"` // Image source } // BedrockAnthropicImageSource represents the source of an image for Anthropic models. @@ -105,8 +108,18 @@ type BedrockAnthropicImageSource struct { // BedrockMistralToolCall represents a tool call for Mistral models. type BedrockMistralToolCall struct { - ID string `json:"id"` // Tool call ID - Function schemas.Function `json:"function"` // Function to call + ID string `json:"id"` // Tool call ID + Function schemas.FunctionCall `json:"function"` // Function to call +} + +type BedrockAnthropicToolUseMessage struct { + ToolUse *BedrockAnthropicToolUse `json:"toolUse"` +} + +type BedrockAnthropicToolUse struct { + ToolUseID string `json:"toolUseId"` + Name string `json:"name"` + Input map[string]interface{} `json:"input"` } // BedrockAnthropicToolCall represents a tool call for Anthropic models. @@ -130,9 +143,10 @@ type BedrockError struct { // BedrockProvider implements the Provider interface for AWS Bedrock. type BedrockProvider struct { - logger schemas.Logger // Logger for provider operations - client *http.Client // HTTP client for API requests - meta schemas.MetaConfig // AWS-specific configuration + logger schemas.Logger // Logger for provider operations + client *http.Client // HTTP client for API requests + meta schemas.MetaConfig // Bedrock-specific configuration + networkConfig schemas.NetworkConfig // Network configuration including extra headers } // bedrockChatResponsePool provides a pool for Bedrock response objects. @@ -159,22 +173,27 @@ func releaseBedrockChatResponse(resp *BedrockChatResponse) { // NewBedrockProvider creates a new Bedrock provider instance. // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts and AWS-specific settings. -func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) *BedrockProvider { - setConfigDefaults(config) +func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*BedrockProvider, error) { + config.CheckAndSetDefaults() + + if config.MetaConfig == nil { + return nil, fmt.Errorf("meta config is not set") + } client := &http.Client{Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)} // Pre-warm response pools for range config.ConcurrencyAndBufferSize.Concurrency { bedrockChatResponsePool.Put(&BedrockChatResponse{}) - bifrostResponsePool.Put(&schemas.BifrostResponse{}) + } return &BedrockProvider{ - logger: logger, - client: client, - meta: config.MetaConfig, - } + logger: logger, + client: client, + meta: config.MetaConfig, + networkConfig: config.NetworkConfig, + }, nil } // GetProviderKey returns the provider identifier for Bedrock. @@ -185,7 +204,7 @@ func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider { // CompleteRequest sends a request to Bedrock's API and handles the response. // It constructs the API URL, sets up AWS authentication, and processes the response. // Returns the response body or an error if the request fails. -func (provider *BedrockProvider) completeRequest(requestBody map[string]interface{}, path string, accessKey string) ([]byte, *schemas.BifrostError) { +func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, accessKey string) ([]byte, *schemas.BifrostError) { if provider.meta == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -202,6 +221,16 @@ func (provider *BedrockProvider) completeRequest(requestBody map[string]interfac jsonBody, err := json.Marshal(requestBody) if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Type: StrPtr(schemas.RequestCancelled), + Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), + Error: err, + }, + } + } return nil, &schemas.BifrostError{ IsBifrostError: true, Error: schemas.ErrorField{ @@ -212,7 +241,7 @@ func (provider *BedrockProvider) completeRequest(requestBody map[string]interfac } // Create the request with the JSON body - req, err := http.NewRequest("POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewBuffer(jsonBody)) + req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewBuffer(jsonBody)) if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: true, @@ -223,8 +252,20 @@ func (provider *BedrockProvider) completeRequest(requestBody map[string]interfac } } - if err := signAWSRequest(req, accessKey, *provider.meta.GetSecretAccessKey(), provider.meta.GetSessionToken(), region, "bedrock"); err != nil { - return nil, err + // Set any extra headers from network config + setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil) + + if provider.meta.GetSecretAccessKey() != nil { + if err := signAWSRequest(req, accessKey, *provider.meta.GetSecretAccessKey(), provider.meta.GetSessionToken(), region, "bedrock"); err != nil { + return nil, err + } + } else { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "secret access key not set", + }, + } } // Execute the request @@ -258,6 +299,7 @@ func (provider *BedrockProvider) completeRequest(requestBody map[string]interfac if err := json.Unmarshal(body, &errorResp); err != nil { return nil, &schemas.BifrostError{ IsBifrostError: true, + StatusCode: &resp.StatusCode, Error: schemas.ErrorField{ Message: schemas.ErrProviderResponseUnmarshal, Error: err, @@ -301,9 +343,11 @@ func (provider *BedrockProvider) getTextCompletionResult(result []byte, model st Choices: []schemas.BifrostResponseChoice{ { Index: 0, - Message: schemas.BifrostResponseChoiceMessage{ - Role: schemas.RoleAssistant, - Content: &response.Completion, + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentStr: &response.Completion, + }, }, FinishReason: &response.StopReason, StopString: &response.Stop, @@ -339,9 +383,11 @@ func (provider *BedrockProvider) getTextCompletionResult(result []byte, model st for i, output := range response.Outputs { choices = append(choices, schemas.BifrostResponseChoice{ Index: i, - Message: schemas.BifrostResponseChoiceMessage{ - Role: schemas.RoleAssistant, - Content: &output.Text, + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentStr: &output.Text, + }, }, FinishReason: &output.StopReason, }) @@ -364,10 +410,30 @@ func (provider *BedrockProvider) getTextCompletionResult(result []byte, model st } } +// parseBedrockAnthropicMessageToolCallContent parses the content of a tool call message. +// It handles both text and JSON content. +// Returns a map containing the parsed content. +func parseBedrockAnthropicMessageToolCallContent(content string) map[string]interface{} { + toolResultContentBlock := map[string]interface{}{} + var parsedJSON interface{} + err := json.Unmarshal([]byte(content), &parsedJSON) + if err == nil { + if arr, ok := parsedJSON.([]interface{}); ok { + toolResultContentBlock["json"] = map[string]interface{}{"content": arr} + } else { + toolResultContentBlock["json"] = parsedJSON + } + } else { + toolResultContentBlock["text"] = content + } + + return toolResultContentBlock +} + // PrepareChatCompletionMessages formats chat messages for Bedrock's API. // It handles different model types (Anthropic and Mistral) and formats messages accordingly. // Returns a map containing the formatted messages and any system messages, or an error if formatting fails. -func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schemas.Message, model string) (map[string]interface{}, *schemas.BifrostError) { +func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schemas.BifrostMessage, model string) (map[string]interface{}, *schemas.BifrostError) { switch model { case "anthropic.claude-instant-v1:2": fallthrough @@ -389,43 +455,189 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema // Add system messages if present var systemMessages []BedrockAnthropicSystemMessage for _, msg := range messages { - if msg.Role == schemas.RoleSystem { - //TODO handling image inputs here - systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ - Text: *msg.Content, - }) + if msg.Role == schemas.ModelChatMessageRoleSystem { + if msg.Content.ContentStr != nil { + systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ + Text: *msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ + Text: *block.Text, + }) + } + } + } } } // Format messages for Bedrock API var bedrockMessages []map[string]interface{} for _, msg := range messages { - if msg.Role != schemas.RoleSystem { - var content any - if msg.Content != nil { - content = BedrockAnthropicTextMessage{ - Type: "text", - Text: *msg.Content, + var content []interface{} + if msg.Role != schemas.ModelChatMessageRoleSystem { + if msg.Role == schemas.ModelChatMessageRoleTool && msg.ToolCallID != nil { + toolCallResult := map[string]interface{}{ + "toolUseId": *msg.ToolCallID, } - } else if msg.ImageContent != nil { - content = BedrockAnthropicImageMessage{ - Type: "image", - Image: BedrockAnthropicImage{ - Format: *msg.ImageContent.Type, - Source: BedrockAnthropicImageSource{ - Bytes: msg.ImageContent.URL, - }, - }, + var toolResultContentBlocks []map[string]interface{} + if msg.Content.ContentStr != nil { + toolResultContentBlocks = append(toolResultContentBlocks, parseBedrockAnthropicMessageToolCallContent(*msg.Content.ContentStr)) + } else if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + toolResultContentBlocks = append(toolResultContentBlocks, parseBedrockAnthropicMessageToolCallContent(*block.Text)) + } + } + } + toolCallResult["content"] = toolResultContentBlocks + content = append(content, map[string]interface{}{ + "toolResult": toolCallResult, + }) + } else { + if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + for _, toolCall := range *msg.AssistantMessage.ToolCalls { + var input map[string]interface{} + if toolCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { + input = map[string]interface{}{"arguments": toolCall.Function.Arguments} + } + } + + content = append(content, BedrockAnthropicToolUseMessage{ + ToolUse: &BedrockAnthropicToolUse{ + ToolUseID: *toolCall.ID, + Name: *toolCall.Function.Name, + Input: input, + }, + }) + } + } + + if msg.Content.ContentStr != nil { + content = append(content, BedrockAnthropicTextMessage{ + Type: "text", + Text: *msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + content = append(content, BedrockAnthropicTextMessage{ + Type: "text", + Text: *block.Text, + }) + } + if block.ImageURL != nil { + sanitizedURL, _ := SanitizeImageURL(block.ImageURL.URL) + urlTypeInfo := ExtractURLTypeInfo(sanitizedURL) + + formattedImgContent := AnthropicImageContent{ + Type: urlTypeInfo.Type, + } + + if urlTypeInfo.MediaType != nil { + formattedImgContent.MediaType = *urlTypeInfo.MediaType + } + + if urlTypeInfo.DataURLWithoutPrefix != nil { + formattedImgContent.URL = *urlTypeInfo.DataURLWithoutPrefix + } else { + formattedImgContent.URL = sanitizedURL + } + + content = append(content, BedrockAnthropicImageMessage{ + Type: "image", + Image: BedrockAnthropicImage{ + Format: func() string { + if formattedImgContent.MediaType != "" { + mediaType := formattedImgContent.MediaType + // Remove "image/" prefix if present, since normalizeMediaType ensures full format + mediaType = strings.TrimPrefix(mediaType, "image/") + return mediaType + } + return "" + }(), + Source: BedrockAnthropicImageSource{ + Bytes: formattedImgContent.URL, + }, + }, + }) + } + } } } - bedrockMessages = append(bedrockMessages, map[string]interface{}{ - "role": msg.Role, - "content": []interface{}{content}, - }) + if len(content) > 0 { + bedrockMessages = append(bedrockMessages, map[string]interface{}{ + "role": msg.Role, + "content": content, + }) + } } } + // Post-process bedrockMessages for tool call results + processedBedrockMessages := []map[string]interface{}{} + i := 0 + for i < len(bedrockMessages) { + currentMsg := bedrockMessages[i] + currentRole, roleOk := getRoleFromMessage(currentMsg) + + if !roleOk { + // If role is of an unexpected type or missing, treat as non-tool message + processedBedrockMessages = append(processedBedrockMessages, currentMsg) + i++ + continue + } + + if currentRole == schemas.ModelChatMessageRoleTool { + // Content of a tool message is the toolCallResult map + // Initialize accumulatedToolResults with the content of the current tool message. + var accumulatedToolResults []interface{} + + // Safely extract content from current message + if content, ok := currentMsg["content"].([]interface{}); ok { + accumulatedToolResults = content + } else { + // If content is not the expected type, skip this message + processedBedrockMessages = append(processedBedrockMessages, currentMsg) + i++ + continue + } + + // Look ahead for more sequential tool messages + j := i + 1 + for j < len(bedrockMessages) { + nextMsg := bedrockMessages[j] + nextRole, nextRoleOk := getRoleFromMessage(nextMsg) + + if !nextRoleOk || nextRole != schemas.ModelChatMessageRoleTool { + break // Not a sequential tool message or role is invalid/missing + } + + // Safely extract content from next message + if nextContent, ok := nextMsg["content"].([]interface{}); ok { + accumulatedToolResults = append(accumulatedToolResults, nextContent...) + } + j++ + } + + // Create a new message with role User and accumulated content + mergedMsg := map[string]interface{}{ + "role": schemas.ModelChatMessageRoleUser, // Final role is User + "content": accumulatedToolResults, + } + processedBedrockMessages = append(processedBedrockMessages, mergedMsg) + i = j // Advance main loop index past all merged messages + } else { + // Not a tool message, add it as is + processedBedrockMessages = append(processedBedrockMessages, currentMsg) + i++ + } + } + bedrockMessages = processedBedrockMessages // Update with processed messages + body := map[string]interface{}{ "messages": bedrockMessages, } @@ -436,7 +648,7 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema messages = append(messages, message.Text) } - body["system"] = strings.Join(messages, " ") + body["system"] = strings.Join(messages, " \n") } return body, nil @@ -447,20 +659,30 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema var bedrockMessages []BedrockMistralChatMessage for _, msg := range messages { var filteredToolCalls []BedrockMistralToolCall - if msg.ToolCalls != nil { - for _, toolCall := range *msg.ToolCalls { - filteredToolCalls = append(filteredToolCalls, BedrockMistralToolCall{ - ID: *toolCall.ID, - Function: toolCall.Function, - }) + if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + for _, toolCall := range *msg.AssistantMessage.ToolCalls { + if toolCall.ID != nil { + filteredToolCalls = append(filteredToolCalls, BedrockMistralToolCall{ + ID: *toolCall.ID, + Function: toolCall.Function, + }) + } } } message := BedrockMistralChatMessage{ Role: msg.Role, - Content: []BedrockMistralContent{ - {Text: *msg.Content}, - }, + } + + switch { + case msg.Content.ContentStr != nil: + message.Content = []BedrockMistralContent{{Text: *msg.Content.ContentStr}} + case msg.Content.ContentBlocks != nil: + for _, b := range *msg.Content.ContentBlocks { + if b.Text != nil { + message.Content = append(message.Content, BedrockMistralContent{Text: *b.Text}) + } + } } if len(filteredToolCalls) > 0 { @@ -553,19 +775,19 @@ func (provider *BedrockProvider) prepareTextCompletionParams(params map[string]i // TextCompletion performs a text completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *BedrockProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { preparedParams := provider.prepareTextCompletionParams(prepareParams(params), model) requestBody := mergeConfig(map[string]interface{}{ "prompt": text, }, preparedParams) - body, err := provider.completeRequest(requestBody, fmt.Sprintf("%s/invoke", model), key) + body, err := provider.completeRequest(ctx, requestBody, fmt.Sprintf("%s/invoke", model), key) if err != nil { return nil, err } - result, err := provider.getTextCompletionResult(body, model) + bifrostResponse, err := provider.getTextCompletionResult(body, model) if err != nil { return nil, err } @@ -582,15 +804,69 @@ func (provider *BedrockProvider) TextCompletion(model, key, text string, params } } - result.ExtraFields.RawResponse = rawResponse + bifrostResponse.ExtraFields.RawResponse = rawResponse + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} - return result, nil +// extractToolsFromHistory extracts minimal tool definitions from conversation history. +// It analyzes the messages to find tool-related content and returns whether tool content +// was found and a list of unique minimal tool definitions extracted from the conversation. +// This is needed when Bedrock requires toolConfig but no tools are provided in the current request. +func (provider *BedrockProvider) extractToolsFromHistory(messages []schemas.BifrostMessage) (bool, []BedrockAnthropicToolCall) { + hasToolContent := false + var toolsFromHistory []BedrockAnthropicToolCall + seenTools := make(map[string]BedrockAnthropicToolCall) + + for _, msg := range messages { + // Check for tool result messages + if msg.Role == schemas.ModelChatMessageRoleTool { + hasToolContent = true + } + // Check for assistant messages with tool calls + if msg.Role == schemas.ModelChatMessageRoleAssistant && msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + hasToolContent = true + // Extract tool definitions from tool calls for toolConfig + for _, toolCall := range *msg.AssistantMessage.ToolCalls { + if toolCall.Function.Name != nil { + toolName := *toolCall.Function.Name + if _, exists := seenTools[toolName]; !exists { + // Create a basic tool definition from the tool call + // Note: We can't fully reconstruct the original tool definition, + // but we can provide a minimal one that satisfies Bedrock's requirement + tool := BedrockAnthropicToolCall{ + ToolSpec: BedrockAnthropicToolSpec{ + Name: toolName, + Description: fmt.Sprintf("Tool: %s", toolName), + InputSchema: struct { + Json interface{} `json:"json"` + }{ + Json: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, + }, + } + seenTools[toolName] = tool + toolsFromHistory = append(toolsFromHistory, tool) + } + } + } + } + } + + return hasToolContent, toolsFromHistory } // ChatCompletion performs a chat completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *BedrockProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { messageBody, err := provider.prepareChatCompletionMessages(messages, model) if err != nil { return nil, err @@ -600,7 +876,21 @@ func (provider *BedrockProvider) ChatCompletion(model, key string, messages []sc // Transform tools if present if params != nil && params.Tools != nil && len(*params.Tools) > 0 { - preparedParams["tools"] = provider.getChatCompletionTools(params, model) + preparedParams["toolConfig"] = map[string]interface{}{ + "tools": provider.getChatCompletionTools(params, model), + } + } else { + // Check if conversation history contains tool use/result blocks + // Bedrock requires toolConfig when such blocks are present + hasToolContent, toolsFromHistory := provider.extractToolsFromHistory(messages) + + // If conversation contains tool content but no tools provided in current request, + // include the extracted tools to satisfy Bedrock's toolConfig requirement + if hasToolContent && len(toolsFromHistory) > 0 { + preparedParams["toolConfig"] = map[string]interface{}{ + "tools": toolsFromHistory, + } + } } requestBody := mergeConfig(messageBody, preparedParams) @@ -618,7 +908,7 @@ func (provider *BedrockProvider) ChatCompletion(model, key string, messages []sc } // Create the signed request - responseBody, err := provider.completeRequest(requestBody, path, key) + responseBody, err := provider.completeRequest(ctx, requestBody, path, key) if err != nil { return nil, err } @@ -627,40 +917,90 @@ func (provider *BedrockProvider) ChatCompletion(model, key string, messages []sc response := acquireBedrockChatResponse() defer releaseBedrockChatResponse(response) - // Create Bifrost response from pool - bifrostResponse := acquireBifrostResponse() - defer releaseBifrostResponse(bifrostResponse) - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) if bifrostErr != nil { return nil, bifrostErr } - var choices []schemas.BifrostResponseChoice - for i, choice := range response.Output.Message.Content { - choices = append(choices, schemas.BifrostResponseChoice{ - Index: i, - Message: schemas.BifrostResponseChoiceMessage{ - Role: schemas.RoleAssistant, - Content: &choice.Text, + // Collect all content and tool calls into a single message (similar to Anthropic aggregation) + var toolCalls []schemas.ToolCall + + var contentBlocks []schemas.ContentBlock + // Process content and tool calls + for _, choice := range response.Output.Message.Content { + if choice.Text != nil && *choice.Text != "" { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: "text", + Text: choice.Text, + }) + } + + if choice.ToolUse != nil { + input := choice.ToolUse.Input + if input == nil { + input = map[string]any{} + } + arguments, err := json.Marshal(input) + if err != nil { + arguments = []byte("{}") + } + + toolCalls = append(toolCalls, schemas.ToolCall{ + Type: StrPtr("function"), + ID: &choice.ToolUse.ToolUseID, + Function: schemas.FunctionCall{ + Name: &choice.ToolUse.Name, + Arguments: string(arguments), + }, + }) + } + } + + // Create the assistant message + var assistantMessage *schemas.AssistantMessage + + // Create AssistantMessage if we have tool calls + if len(toolCalls) > 0 { + assistantMessage = &schemas.AssistantMessage{ + ToolCalls: &toolCalls, + } + } + + // Create a single choice with the aggregated content + choices := []schemas.BifrostResponseChoice{ + { + Index: 0, + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentBlocks: &contentBlocks, + }, + AssistantMessage: assistantMessage, }, FinishReason: &response.StopReason, - }) + }, } latency := float64(response.Metrics.Latency) - bifrostResponse.Choices = choices - bifrostResponse.Usage = schemas.LLMUsage{ - PromptTokens: response.Usage.InputTokens, - CompletionTokens: response.Usage.OutputTokens, - TotalTokens: response.Usage.TotalTokens, - } - bifrostResponse.Model = model - bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - Latency: &latency, - Provider: schemas.Bedrock, - RawResponse: rawResponse, + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + Choices: choices, + Usage: schemas.LLMUsage{ + PromptTokens: response.Usage.InputTokens, + CompletionTokens: response.Usage.OutputTokens, + TotalTokens: response.Usage.TotalTokens, + }, + Model: model, + ExtraFields: schemas.BifrostResponseExtraFields{ + Latency: &latency, + Provider: schemas.Bedrock, + RawResponse: rawResponse, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params } return bifrostResponse, nil diff --git a/core/providers/cohere.go b/core/providers/cohere.go index 0240af086b..c8b17b3f74 100644 --- a/core/providers/cohere.go +++ b/core/providers/cohere.go @@ -3,8 +3,10 @@ package providers import ( + "context" "fmt" "slices" + "strings" "sync" "time" @@ -93,31 +95,39 @@ type CohereError struct { // CohereProvider implements the Provider interface for Cohere. type CohereProvider struct { - logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers } // NewCohereProvider creates a new Cohere provider instance. // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts and connection limits. func NewCohereProvider(config *schemas.ProviderConfig, logger schemas.Logger) *CohereProvider { - setConfigDefaults(config) + config.CheckAndSetDefaults() client := &fasthttp.Client{ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), - MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency, } // Pre-warm response pools for range config.ConcurrencyAndBufferSize.Concurrency { cohereResponsePool.Put(&CohereChatResponse{}) - bifrostResponsePool.Put(&schemas.BifrostResponse{}) + + } + + // Set default BaseURL if not provided + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = "https://api.cohere.ai" } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") return &CohereProvider{ - logger: logger, - client: client, + logger: logger, + client: client, + networkConfig: config.NetworkConfig, } } @@ -128,7 +138,7 @@ func (provider *CohereProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the Cohere provider. // Returns an error indicating that text completion is not supported. -func (provider *CohereProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CohereProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -140,7 +150,7 @@ func (provider *CohereProvider) TextCompletion(model, key, text string, params * // ChatCompletion performs a chat completion request to the Cohere API. // It formats the request, sends it to Cohere, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *CohereProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CohereProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { // Get the last message and chat history lastMessage := messages[len(messages)-1] chatHistory := messages[:len(messages)-1] @@ -148,21 +158,139 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []sch // Transform chat history var cohereHistory []map[string]interface{} for _, msg := range chatHistory { - cohereHistory = append(cohereHistory, map[string]interface{}{ - "role": msg.Role, - "message": msg.Content, - }) + historyMsg := map[string]interface{}{ + "role": msg.Role, + } + + if msg.Role == schemas.ModelChatMessageRoleAssistant { + if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + var toolCalls []map[string]interface{} + for _, toolCall := range *msg.AssistantMessage.ToolCalls { + var arguments map[string]interface{} + var parsedJSON interface{} + err := json.Unmarshal([]byte(toolCall.Function.Arguments), &parsedJSON) + if err == nil { + if arr, ok := parsedJSON.(map[string]interface{}); ok { + arguments = arr + } else { + arguments = map[string]interface{}{"content": parsedJSON} + } + } else { + arguments = map[string]interface{}{"content": toolCall.Function.Arguments} + } + + toolCalls = append(toolCalls, map[string]interface{}{ + "name": toolCall.Function.Name, + "parameters": arguments, + }) + } + historyMsg["tool_calls"] = toolCalls + } + } else if msg.Role == schemas.ModelChatMessageRoleTool { + // Find the original tool call parameters from conversation history + var toolCallParameters map[string]interface{} + + // Look back through the chat history to find the assistant message with the matching tool call + for i := len(chatHistory) - 1; i >= 0; i-- { + prevMsg := chatHistory[i] + if prevMsg.Role == schemas.ModelChatMessageRoleAssistant && + prevMsg.AssistantMessage != nil && + prevMsg.AssistantMessage.ToolCalls != nil { + + // Search through tool calls in this assistant message + for _, toolCall := range *prevMsg.AssistantMessage.ToolCalls { + if toolCall.ID != nil && msg.ToolMessage != nil && msg.ToolMessage.ToolCallID != nil && + *toolCall.ID == *msg.ToolMessage.ToolCallID { + + // Found the matching tool call, extract its parameters + var parsedJSON interface{} + err := json.Unmarshal([]byte(toolCall.Function.Arguments), &parsedJSON) + if err == nil { + if arr, ok := parsedJSON.(map[string]interface{}); ok { + toolCallParameters = arr + } else { + toolCallParameters = map[string]interface{}{"content": parsedJSON} + } + } else { + toolCallParameters = map[string]interface{}{"content": toolCall.Function.Arguments} + } + break + } + } + + // If we found the parameters, stop searching + if toolCallParameters != nil { + break + } + } + } + + // If no parameters found, use empty map as fallback + if toolCallParameters == nil { + toolCallParameters = map[string]interface{}{} + } + + toolResults := []map[string]interface{}{ + { + "call": map[string]interface{}{ + "name": *msg.ToolMessage.ToolCallID, + "parameters": toolCallParameters, + }, + "outputs": *msg.Content.ContentStr, + }, + } + + historyMsg["tool_results"] = toolResults + } + + if msg.Content.ContentStr != nil { + historyMsg["message"] = *msg.Content.ContentStr + } else if msg.Content.ContentBlocks != nil { + // Create content array with text and image + contentArray := []map[string]interface{}{} + + // Iterate over ContentBlocks to build the content array + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + contentArray = append(contentArray, map[string]interface{}{ + "type": "text", + "text": *block.Text, + }) + } + // Add image content using our helper function + // NOTE: Cohere v1 does not support image content + // if processedImageContent := processImageContent(block.ImageContent); processedImageContent != nil { + // contentArray = append(contentArray, processedImageContent) + // } + } + + historyMsg["content"] = contentArray + } + + cohereHistory = append(cohereHistory, historyMsg) } preparedParams := prepareParams(params) // Prepare request body requestBody := mergeConfig(map[string]interface{}{ - "message": lastMessage.Content, "chat_history": cohereHistory, "model": model, }, preparedParams) + // Handle the last message content based on whether it supports vision + if lastMessage.Content.ContentStr != nil { + requestBody["message"] = *lastMessage.Content.ContentStr + } else if lastMessage.Content.ContentBlocks != nil { + message := "" + for _, block := range *lastMessage.Content.ContentBlocks { + if block.Text != nil { + message += *block.Text + "\n" + } + } + requestBody["message"] = strings.TrimSuffix(message, "\n") + } + // Add tools if present if params != nil && params.Tools != nil && len(*params.Tools) > 0 { var tools []CohereTool @@ -196,6 +324,16 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []sch } requestBody["tools"] = tools } + // Add tool choice if present + if params != nil && params.ToolChoice != nil { + if params.ToolChoice.ToolChoiceStr != nil { + requestBody["tool_choice"] = *params.ToolChoice.ToolChoiceStr + } else if params.ToolChoice.ToolChoiceStruct != nil { + requestBody["tool_choice"] = map[string]interface{}{ + "type": strings.ToUpper(string(params.ToolChoice.ToolChoiceStruct.Type)), + } + } + } // Marshal request body jsonBody, err := json.Marshal(requestBody) @@ -215,25 +353,26 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []sch defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) - req.SetRequestURI("https://api.cohere.ai/v1/chat") + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat") req.Header.SetMethod("POST") req.Header.SetContentType("application/json") req.Header.Set("Authorization", "Bearer "+key) + req.SetBody(jsonBody) // Make request - if err := provider.client.Do(req, resp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr } // Handle error response if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from cohere provider: %s", string(resp.Body()))) + var errorResp CohereError bifrostErr := handleProviderAPIError(resp, &errorResp) @@ -249,10 +388,6 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []sch response := acquireCohereResponse() defer releaseCohereResponse(response) - // Create Bifrost response from pool - bifrostResponse := acquireBifrostResponse() - defer releaseBifrostResponse(bifrostResponse) - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) if bifrostErr != nil { return nil, bifrostErr @@ -287,49 +422,88 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []sch role = lastMsg.Role content = lastMsg.Message } else { - role = schemas.RoleChatbot + role = schemas.ModelChatMessageRoleChatbot content = response.Text } - bifrostResponse.ID = response.ResponseID - bifrostResponse.Choices = []schemas.BifrostResponseChoice{ - { - Index: 0, - Message: schemas.BifrostResponseChoiceMessage{ - Role: role, - Content: &content, - ToolCalls: &toolCalls, + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ResponseID, + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + Message: schemas.BifrostMessage{ + Role: role, + Content: schemas.MessageContent{ + ContentStr: &content, + }, + AssistantMessage: &schemas.AssistantMessage{ + ToolCalls: &toolCalls, + }, + }, + FinishReason: &response.FinishReason, }, - FinishReason: &response.FinishReason, }, - } - bifrostResponse.Usage = schemas.LLMUsage{ - PromptTokens: int(response.Meta.Tokens.InputTokens), - CompletionTokens: int(response.Meta.Tokens.OutputTokens), - TotalTokens: int(response.Meta.Tokens.InputTokens + response.Meta.Tokens.OutputTokens), - } - bifrostResponse.Model = model - bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: schemas.Cohere, - BilledUsage: &schemas.BilledLLMUsage{ - PromptTokens: float64Ptr(response.Meta.BilledUnits.InputTokens), - CompletionTokens: float64Ptr(response.Meta.BilledUnits.OutputTokens), + Usage: schemas.LLMUsage{ + PromptTokens: int(response.Meta.Tokens.InputTokens), + CompletionTokens: int(response.Meta.Tokens.OutputTokens), + TotalTokens: int(response.Meta.Tokens.InputTokens + response.Meta.Tokens.OutputTokens), + }, + Model: model, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.Cohere, + BilledUsage: &schemas.BilledLLMUsage{ + PromptTokens: float64Ptr(response.Meta.BilledUnits.InputTokens), + CompletionTokens: float64Ptr(response.Meta.BilledUnits.OutputTokens), + }, + ChatHistory: convertChatHistory(response.ChatHistory), + RawResponse: rawResponse, }, - ChatHistory: convertChatHistory(response.ChatHistory), - RawResponse: rawResponse, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params } return bifrostResponse, nil } +// processImageContent processes image content for Cohere API format. +// It creates a copy of the image content, normalizes and formats it, then returns the properly formatted map. +// This prevents unintended mutations to the original image content. +func processImageContent(imageContent *schemas.ImageURLStruct) map[string]interface{} { + if imageContent == nil { + return nil + } + + sanitizedURL, _ := SanitizeImageURL(imageContent.URL) + urlTypeInfo := ExtractURLTypeInfo(sanitizedURL) + + formattedImgContent := AnthropicImageContent{ + Type: urlTypeInfo.Type, + URL: sanitizedURL, + } + + if urlTypeInfo.MediaType != nil { + formattedImgContent.MediaType = *urlTypeInfo.MediaType + } + + return map[string]interface{}{ + "type": "image_url", + "image_url": map[string]interface{}{ + "url": formattedImgContent.URL, + }, + } +} + // convertChatHistory converts Cohere's chat history format to Bifrost's format for standardization. // It transforms the chat history messages and their tool calls. func convertChatHistory(history []struct { Role schemas.ModelChatMessageRole `json:"role"` Message string `json:"message"` ToolCalls []CohereToolCall `json:"tool_calls"` -}) *[]schemas.BifrostResponseChoiceMessage { - converted := make([]schemas.BifrostResponseChoiceMessage, len(history)) +}) *[]schemas.BifrostMessage { + converted := make([]schemas.BifrostMessage, len(history)) for i, msg := range history { var toolCalls []schemas.ToolCall if msg.ToolCalls != nil { @@ -350,10 +524,15 @@ func convertChatHistory(history []struct { }) } } - converted[i] = schemas.BifrostResponseChoiceMessage{ - Role: msg.Role, - Content: &msg.Message, - ToolCalls: &toolCalls, + + converted[i] = schemas.BifrostMessage{ + Role: msg.Role, + Content: schemas.MessageContent{ + ContentStr: &msg.Message, + }, + AssistantMessage: &schemas.AssistantMessage{ + ToolCalls: &toolCalls, + }, } } return &converted diff --git a/core/providers/mistral.go b/core/providers/mistral.go new file mode 100644 index 0000000000..6bc34da683 --- /dev/null +++ b/core/providers/mistral.go @@ -0,0 +1,187 @@ +// Package providers implements various LLM providers and their utility functions. +// This file contains the Mistral provider implementation. +package providers + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/goccy/go-json" + + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) + +// MistralResponse represents the response structure from the Mistral API. +type MistralResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Choices []schemas.BifrostResponseChoice `json:"choices"` + Model string `json:"model"` + Created int `json:"created"` + Usage schemas.LLMUsage `json:"usage"` +} + +// mistralResponsePool provides a pool for Mistral response objects. +var mistralResponsePool = sync.Pool{ + New: func() interface{} { + return &MistralResponse{} + }, +} + +// acquireMistralResponse gets a Mistral response from the pool and resets it. +func acquireMistralResponse() *MistralResponse { + resp := mistralResponsePool.Get().(*MistralResponse) + *resp = MistralResponse{} // Reset the struct + return resp +} + +// releaseMistralResponse returns a Mistral response to the pool. +func releaseMistralResponse(resp *MistralResponse) { + if resp != nil { + mistralResponsePool.Put(resp) + } +} + +// MistralProvider implements the Provider interface for Mistral's API. +type MistralProvider struct { + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers +} + +// NewMistralProvider creates a new Mistral provider instance. +// It initializes the HTTP client with the provided configuration and sets up response pools. +// The client is configured with timeouts, concurrency limits, and optional proxy settings. +func NewMistralProvider(config *schemas.ProviderConfig, logger schemas.Logger) *MistralProvider { + config.CheckAndSetDefaults() + + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency, + } + + // Pre-warm response pools + for range config.ConcurrencyAndBufferSize.Concurrency { + mistralResponsePool.Put(&MistralResponse{}) + } + + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + + // Set default BaseURL if not provided + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = "https://api.mistral.ai" + } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + + return &MistralProvider{ + logger: logger, + client: client, + networkConfig: config.NetworkConfig, + } +} + +// GetProviderKey returns the provider identifier for Mistral. +func (provider *MistralProvider) GetProviderKey() schemas.ModelProvider { + return schemas.Mistral +} + +// TextCompletion is not supported by the Mistral provider. +func (provider *MistralProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "text completion is not supported by mistral provider", + }, + } +} + +// ChatCompletion performs a chat completion request to the Mistral API. +func (provider *MistralProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderJSONMarshaling, + Error: err, + }, + } + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from mistral provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("Mistral error: %v", errorResp) + return nil, bifrostErr + } + + responseBody := resp.Body() + + // Pre-allocate response structs from pools + response := acquireMistralResponse() + defer releaseMistralResponse(response) + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ID, + Object: response.Object, + Choices: response.Choices, + Model: response.Model, + Created: response.Created, + Usage: response.Usage, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.Mistral, + RawResponse: rawResponse, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} diff --git a/core/providers/ollama.go b/core/providers/ollama.go new file mode 100644 index 0000000000..bed89846dd --- /dev/null +++ b/core/providers/ollama.go @@ -0,0 +1,190 @@ +// Package providers implements various LLM providers and their utility functions. +// This file contains the Ollama provider implementation. +package providers + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/goccy/go-json" + + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) + +// OllamaResponse represents the response structure from the Ollama API. +type OllamaResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Choices []schemas.BifrostResponseChoice `json:"choices"` + Model string `json:"model"` + Created int `json:"created"` + Usage schemas.LLMUsage `json:"usage"` +} + +// ollamaResponsePool provides a pool for Ollama response objects. +var ollamaResponsePool = sync.Pool{ + New: func() interface{} { + return &OllamaResponse{} + }, +} + +// acquireOllamaResponse gets a Ollama response from the pool and resets it. +func acquireOllamaResponse() *OllamaResponse { + resp := ollamaResponsePool.Get().(*OllamaResponse) + *resp = OllamaResponse{} // Reset the struct + return resp +} + +// releaseOllamaResponse returns a Ollama response to the pool. +func releaseOllamaResponse(resp *OllamaResponse) { + if resp != nil { + ollamaResponsePool.Put(resp) + } +} + +// OllamaProvider implements the Provider interface for Ollama's API. +type OllamaProvider struct { + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers +} + +// NewOllamaProvider creates a new Ollama provider instance. +// It initializes the HTTP client with the provided configuration and sets up response pools. +// The client is configured with timeouts, concurrency limits, and optional proxy settings. +func NewOllamaProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*OllamaProvider, error) { + config.CheckAndSetDefaults() + + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + } + + // Pre-warm response pools + for range config.ConcurrencyAndBufferSize.Concurrency { + ollamaResponsePool.Put(&OllamaResponse{}) + } + + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + + // BaseURL is required for Ollama + if config.NetworkConfig.BaseURL == "" { + return nil, fmt.Errorf("base_url is required for ollama provider") + } + + return &OllamaProvider{ + logger: logger, + client: client, + networkConfig: config.NetworkConfig, + }, nil +} + +// GetProviderKey returns the provider identifier for Ollama. +func (provider *OllamaProvider) GetProviderKey() schemas.ModelProvider { + return schemas.Ollama +} + +// TextCompletion is not supported by the Ollama provider. +func (provider *OllamaProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "text completion is not supported by ollama provider", + }, + } +} + +// ChatCompletion performs a chat completion request to the Ollama API. +func (provider *OllamaProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderJSONMarshaling, + Error: err, + }, + } + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + if key != "" { + req.Header.Set("Authorization", "Bearer "+key) + } + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from ollama provider: %s", string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("Ollama error: %v", errorResp) + return nil, bifrostErr + } + + responseBody := resp.Body() + + // Pre-allocate response structs from pools + response := acquireOllamaResponse() + defer releaseOllamaResponse(response) + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ID, + Object: response.Object, + Choices: response.Choices, + Model: response.Model, + Created: response.Created, + Usage: response.Usage, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.Ollama, + RawResponse: rawResponse, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} diff --git a/core/providers/openai.go b/core/providers/openai.go index ff96a69b77..bd345e3c0b 100644 --- a/core/providers/openai.go +++ b/core/providers/openai.go @@ -3,6 +3,9 @@ package providers import ( + "context" + "fmt" + "strings" "sync" "time" @@ -60,36 +63,43 @@ func releaseOpenAIResponse(resp *OpenAIResponse) { } } -// OpenAIProvider implements the Provider interface for OpenAI's API. +// OpenAIProvider implements the Provider interface for OpenAI's GPT API. type OpenAIProvider struct { - logger schemas.Logger // Logger for provider operations - client *fasthttp.Client // HTTP client for API requests + logger schemas.Logger // Logger for provider operations + client *fasthttp.Client // HTTP client for API requests + networkConfig schemas.NetworkConfig // Network configuration including extra headers } // NewOpenAIProvider creates a new OpenAI provider instance. // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts, concurrency limits, and optional proxy settings. func NewOpenAIProvider(config *schemas.ProviderConfig, logger schemas.Logger) *OpenAIProvider { - setConfigDefaults(config) + config.CheckAndSetDefaults() client := &fasthttp.Client{ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), - MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + MaxConnsPerHost: config.ConcurrencyAndBufferSize.Concurrency, } // Pre-warm response pools for range config.ConcurrencyAndBufferSize.Concurrency { openAIResponsePool.Put(&OpenAIResponse{}) - bifrostResponsePool.Put(&schemas.BifrostResponse{}) } // Configure proxy if provided client = configureProxy(client, config.ProxyConfig, logger) + // Set default BaseURL if not provided + if config.NetworkConfig.BaseURL == "" { + config.NetworkConfig.BaseURL = "https://api.openai.com" + } + config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") + return &OpenAIProvider{ - logger: logger, - client: client, + logger: logger, + client: client, + networkConfig: config.NetworkConfig, } } @@ -100,7 +110,7 @@ func (provider *OpenAIProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the OpenAI provider. // Returns an error indicating that text completion is not available. -func (provider *OpenAIProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -112,47 +122,8 @@ func (provider *OpenAIProvider) TextCompletion(model, key, text string, params * // ChatCompletion performs a chat completion request to the OpenAI API. // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Format messages for OpenAI API - var formattedMessages []map[string]interface{} - for _, msg := range messages { - if msg.ImageContent != nil { - var content []map[string]interface{} - - // Add text content if present - if msg.Content != nil { - content = append(content, map[string]interface{}{ - "type": "text", - "text": msg.Content, - }) - } - - imageContent := map[string]interface{}{ - "type": "image_url", - "image_url": map[string]interface{}{ - "url": msg.ImageContent.URL, - }, - } - - if msg.ImageContent.Detail != nil { - imageContent["image_url"].(map[string]interface{})["detail"] = msg.ImageContent.Detail - } - - content = append(content, imageContent) - - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": content, - }) - } else { - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": msg.Content, - }) - } - } - - preparedParams := prepareParams(params) +func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) requestBody := mergeConfig(map[string]interface{}{ "model": model, @@ -176,35 +147,40 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []sch defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) - req.SetRequestURI("https://api.openai.com/v1/chat/completions") + // Set any extra headers from network config + setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + + req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") req.Header.SetMethod("POST") req.Header.SetContentType("application/json") req.Header.Set("Authorization", "Bearer "+key) + req.SetBody(jsonBody) // Make request - if err := provider.client.Do(req, resp); err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderRequest, - Error: err, - }, - } + bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr } // Handle error response if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from openai provider: %s", string(resp.Body()))) + var errorResp OpenAIError bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.EventID = &errorResp.EventID + if errorResp.EventID != "" { + bifrostErr.EventID = &errorResp.EventID + } bifrostErr.Error.Type = &errorResp.Error.Type bifrostErr.Error.Code = &errorResp.Error.Code bifrostErr.Error.Message = errorResp.Error.Message bifrostErr.Error.Param = errorResp.Error.Param - bifrostErr.Error.EventID = &errorResp.Error.EventID + if errorResp.Error.EventID != "" { + bifrostErr.Error.EventID = &errorResp.Error.EventID + } return nil, bifrostErr } @@ -215,28 +191,76 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []sch response := acquireOpenAIResponse() defer releaseOpenAIResponse(response) - result := acquireBifrostResponse() - defer releaseBifrostResponse(result) - // Use enhanced response handler with pre-allocated response rawResponse, bifrostErr := handleProviderResponse(responseBody, response) if bifrostErr != nil { return nil, bifrostErr } - // Populate result from response - result.ID = response.ID - result.Choices = response.Choices - result.Object = response.Object - result.Usage = response.Usage - result.ServiceTier = response.ServiceTier - result.SystemFingerprint = response.SystemFingerprint - result.Model = response.Model - result.Created = response.Created - result.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - RawResponse: rawResponse, + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ID, + Object: response.Object, + Choices: response.Choices, + Model: response.Model, + Created: response.Created, + ServiceTier: response.ServiceTier, + SystemFingerprint: response.SystemFingerprint, + Usage: response.Usage, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.OpenAI, + RawResponse: rawResponse, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil +} + +func prepareOpenAIChatRequest(messages []schemas.BifrostMessage, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) { + // Format messages for OpenAI API + var formattedMessages []map[string]interface{} + for _, msg := range messages { + if msg.Role == schemas.ModelChatMessageRoleAssistant { + assistantMessage := map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + assistantMessage["tool_calls"] = *msg.AssistantMessage.ToolCalls + } + formattedMessages = append(formattedMessages, assistantMessage) + } else { + message := map[string]interface{}{ + "role": msg.Role, + } + + if msg.Content.ContentStr != nil { + message["content"] = *msg.Content.ContentStr + } else if msg.Content.ContentBlocks != nil { + contentBlocks := *msg.Content.ContentBlocks + for i := range contentBlocks { + if contentBlocks[i].Type == schemas.ContentBlockTypeImage && contentBlocks[i].ImageURL != nil { + sanitizedURL, _ := SanitizeImageURL(contentBlocks[i].ImageURL.URL) + contentBlocks[i].ImageURL.URL = sanitizedURL + } + } + + message["content"] = contentBlocks + } + + if msg.ToolMessage != nil && msg.ToolMessage.ToolCallID != nil { + message["tool_call_id"] = *msg.ToolMessage.ToolCallID + } + + formattedMessages = append(formattedMessages, message) + } } - return result, nil + preparedParams := prepareParams(params) + + return formattedMessages, preparedParams } diff --git a/core/providers/utils.go b/core/providers/utils.go index 2988fac358..94712f91dc 100644 --- a/core/providers/utils.go +++ b/core/providers/utils.go @@ -3,9 +3,14 @@ package providers import ( + "context" "fmt" + "net/http" + "net/textproto" "net/url" "reflect" + "regexp" + "slices" "strings" "sync" @@ -17,27 +22,46 @@ import ( "maps" ) -// bifrostResponsePool provides a pool for Bifrost response objects. -var bifrostResponsePool = sync.Pool{ - New: func() interface{} { - return &schemas.BifrostResponse{} - }, +// dataURIRegex is a precompiled regex for matching data URI format patterns. +// It matches patterns like: data:image/png;base64,iVBORw0KGgo... +var dataURIRegex = regexp.MustCompile(`^data:([^;]+)(;base64)?,(.+)$`) + +// base64Regex is a precompiled regex for matching base64 strings. +// It matches strings containing only valid base64 characters with optional padding. +var base64Regex = regexp.MustCompile(`^[A-Za-z0-9+/]*={0,2}$`) + +// fileExtensionToMediaType maps common image file extensions to their corresponding media types. +// This map is used to infer media types from file extensions in URLs. +var fileExtensionToMediaType = map[string]string{ + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".png": "image/png", + ".gif": "image/gif", + ".webp": "image/webp", + ".svg": "image/svg+xml", + ".bmp": "image/bmp", } -// acquireBifrostResponse gets a Bifrost response from the pool and resets it. -func acquireBifrostResponse() *schemas.BifrostResponse { - resp := bifrostResponsePool.Get().(*schemas.BifrostResponse) - *resp = schemas.BifrostResponse{} // Reset the struct - return resp -} +// ImageContentType represents the type of image content +type ImageContentType string -// releaseBifrostResponse returns a Bifrost response to the pool. -func releaseBifrostResponse(resp *schemas.BifrostResponse) { - if resp != nil { - bifrostResponsePool.Put(resp) - } +const ( + ImageContentTypeBase64 ImageContentType = "base64" + ImageContentTypeURL ImageContentType = "url" +) + +// URLTypeInfo contains extracted information about a URL +type URLTypeInfo struct { + Type ImageContentType + MediaType *string + DataURLWithoutPrefix *string // URL without the prefix (eg data:image/png;base64,iVBORw0KGgo...) } +// ContextKey is a custom type for context keys to prevent key collisions in the context. +// It provides type safety for context values and ensures that context keys are unique +// across different packages. +type ContextKey string + // mergeConfig merges a default configuration map with custom parameters. // It creates a new map containing all default values, then overrides them with any custom values. // Returns a new map containing the merged configuration. @@ -104,6 +128,49 @@ func prepareParams(params *schemas.ModelParameters) map[string]interface{} { return flatParams } +// IMPORTANT: This function does NOT truly cancel the underlying fasthttp network request if the +// context is done. The fasthttp client call will continue in its goroutine until it completes +// or times out based on its own settings. This function merely stops *waiting* for the +// fasthttp call and returns an error related to the context. +func makeRequestWithContext(ctx context.Context, client *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) *schemas.BifrostError { + errChan := make(chan error, 1) + + go func() { + // client.Do is a blocking call. + // It will send an error (or nil for success) to errChan when it completes. + errChan <- client.Do(req, resp) + }() + + select { + case <-ctx.Done(): + // Context was cancelled (e.g., deadline exceeded or manual cancellation). + // Return a BifrostError indicating this. + return &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Type: StrPtr(schemas.RequestCancelled), + Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), + Error: ctx.Err(), + }, + } + case err := <-errChan: + // The fasthttp.Do call completed. + if err != nil { + // The HTTP request itself failed (e.g., connection error, fasthttp timeout). + return &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderRequest, + Error: err, + }, + } + } + // HTTP request was successful from fasthttp's perspective (err is nil). + // The caller should check resp.StatusCode() for HTTP-level errors (4xx, 5xx). + return nil + } +} + // configureProxy sets up a proxy for the fasthttp client based on the provided configuration. // It supports HTTP, SOCKS5, and environment-based proxy configurations. // Returns the configured client or the original client if proxy configuration is invalid. @@ -157,13 +224,73 @@ func configureProxy(client *fasthttp.Client, proxyConfig *schemas.ProxyConfig, l return client } +// setExtraHeaders sets additional headers from NetworkConfig to the fasthttp request. +// This allows users to configure custom headers for their provider requests. +// Header keys are canonicalized using textproto.CanonicalMIMEHeaderKey to avoid duplicates. +// The Authorization header is excluded for security reasons. +// It accepts a list of headers (all canonicalized) to skip for security reasons. +// Headers are only set if they don't already exist on the request to avoid overwriting important headers. +func setExtraHeaders(req *fasthttp.Request, extraHeaders map[string]string, skipHeaders *[]string) { + if extraHeaders == nil { + return + } + + for key, value := range extraHeaders { + canonicalKey := textproto.CanonicalMIMEHeaderKey(key) + // Skip Authorization header for security reasons + if key == "Authorization" { + continue + } + if skipHeaders != nil { + if slices.Contains(*skipHeaders, key) { + continue + } + } + // Only set the header if it doesn't already exist to avoid overwriting important headers + if len(req.Header.Peek(canonicalKey)) == 0 { + req.Header.Set(canonicalKey, value) + } + } +} + +// setExtraHeadersHTTP sets additional headers from NetworkConfig to the standard HTTP request. +// This allows users to configure custom headers for their provider requests. +// Header keys are canonicalized using textproto.CanonicalMIMEHeaderKey to avoid duplicates. +// It accepts a list of headers (all canonicalized) to skip for security reasons. +// Headers are only set if they don't already exist on the request to avoid overwriting important headers. +func setExtraHeadersHTTP(req *http.Request, extraHeaders map[string]string, skipHeaders *[]string) { + if extraHeaders == nil { + return + } + + for key, value := range extraHeaders { + canonicalKey := textproto.CanonicalMIMEHeaderKey(key) + // Skip Authorization header for security reasons + if key == "Authorization" { + continue + } + if skipHeaders != nil { + if slices.Contains(*skipHeaders, key) { + continue + } + } + // Only set the header if it doesn't already exist to avoid overwriting important headers + if req.Header.Get(canonicalKey) == "" { + req.Header.Set(canonicalKey, value) + } + } +} + // handleProviderAPIError processes error responses from provider APIs. // It attempts to unmarshal the error response and returns a BifrostError // with the appropriate status code and error information. func handleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.BifrostError { + statusCode := resp.StatusCode() + if err := json.Unmarshal(resp.Body(), &errorResp); err != nil { return &schemas.BifrostError{ IsBifrostError: true, + StatusCode: &statusCode, Error: schemas.ErrorField{ Message: schemas.ErrProviderResponseUnmarshal, Error: err, @@ -171,8 +298,6 @@ func handleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.Bif } } - statusCode := resp.StatusCode() - return &schemas.BifrostError{ IsBifrostError: false, StatusCode: &statusCode, @@ -223,38 +348,196 @@ func handleProviderResponse[T any](responseBody []byte, response *T) (interface{ return rawResponse, nil } +// getRoleFromMessage extracts and validates the role from a message map. +func getRoleFromMessage(msg map[string]interface{}) (schemas.ModelChatMessageRole, bool) { + roleVal, exists := msg["role"] + if !exists { + return "", false // Role key doesn't exist + } + + // Try direct assertion to ModelChatMessageRole + roleAsModelType, ok := roleVal.(schemas.ModelChatMessageRole) + if ok { + return roleAsModelType, true + } + + // Try assertion to string and then convert + roleAsString, okStr := roleVal.(string) + if okStr { + return schemas.ModelChatMessageRole(roleAsString), true + } + + return "", false // Role is of an unexpected or invalid type +} + // float64Ptr creates a pointer to a float64 value. // This is a helper function for creating pointers to float64 values. func float64Ptr(f float64) *float64 { return &f } -func setConfigDefaults(config *schemas.ProviderConfig) { - if config.ConcurrencyAndBufferSize.Concurrency == 0 { - config.ConcurrencyAndBufferSize.Concurrency = schemas.DefaultConcurrency +// StrPtr creates a pointer to a string value. +// This is a helper function for creating pointers to string values. +func StrPtr(s string) *string { + return &s +} + +//* IMAGE UTILS *// + +// SanitizeImageURL sanitizes and validates an image URL. +// It handles both data URLs and regular HTTP/HTTPS URLs. +// It also detects raw base64 image data and adds proper data URL headers. +func SanitizeImageURL(rawURL string) (string, error) { + if rawURL == "" { + return rawURL, fmt.Errorf("URL cannot be empty") } - if config.ConcurrencyAndBufferSize.BufferSize == 0 { - config.ConcurrencyAndBufferSize.BufferSize = schemas.DefaultBufferSize + // Trim whitespace + rawURL = strings.TrimSpace(rawURL) + + // Check if it's already a proper data URL + if strings.HasPrefix(rawURL, "data:") { + // Validate data URL format + if !dataURIRegex.MatchString(rawURL) { + return rawURL, fmt.Errorf("invalid data URL format") + } + return rawURL, nil } - if config.NetworkConfig.DefaultRequestTimeoutInSeconds == 0 { - config.NetworkConfig.DefaultRequestTimeoutInSeconds = schemas.DefaultRequestTimeoutInSeconds + // Check if it looks like raw base64 image data + if isLikelyBase64(rawURL) { + // Detect the image type from the base64 data + mediaType := detectImageTypeFromBase64(rawURL) + + // Remove any whitespace/newlines from base64 data + cleanBase64 := strings.ReplaceAll(strings.ReplaceAll(rawURL, "\n", ""), " ", "") + + // Create proper data URL + return fmt.Sprintf("data:%s;base64,%s", mediaType, cleanBase64), nil } - if config.NetworkConfig.MaxRetries == 0 { - config.NetworkConfig.MaxRetries = schemas.DefaultMaxRetries + // Parse as regular URL + parsedURL, err := url.Parse(rawURL) + if err != nil { + return rawURL, fmt.Errorf("invalid URL format: %w", err) } - if config.NetworkConfig.RetryBackoffInitial == 0 { - config.NetworkConfig.RetryBackoffInitial = schemas.DefaultRetryBackoffInitial + // Validate scheme + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return rawURL, fmt.Errorf("URL must use http or https scheme") } - if config.NetworkConfig.RetryBackoffMax == 0 { - config.NetworkConfig.RetryBackoffMax = schemas.DefaultRetryBackoffMax + // Validate host + if parsedURL.Host == "" { + return rawURL, fmt.Errorf("URL must have a valid host") } + + return parsedURL.String(), nil } -func StrPtr(s string) *string { - return &s +// ExtractURLTypeInfo extracts type and media type information from a sanitized URL. +// For data URLs, it parses the media type and encoding. +// For regular URLs, it attempts to infer the media type from the file extension. +func ExtractURLTypeInfo(sanitizedURL string) URLTypeInfo { + if strings.HasPrefix(sanitizedURL, "data:") { + return extractDataURLInfo(sanitizedURL) + } + return extractRegularURLInfo(sanitizedURL) +} + +// extractDataURLInfo extracts information from a data URL +func extractDataURLInfo(dataURL string) URLTypeInfo { + // Parse data URL: data:[][;base64], + matches := dataURIRegex.FindStringSubmatch(dataURL) + + if len(matches) != 4 { + return URLTypeInfo{Type: ImageContentTypeBase64} + } + + mediaType := matches[1] + isBase64 := matches[2] == ";base64" + + dataURLWithoutPrefix := dataURL + if isBase64 { + dataURLWithoutPrefix = dataURL[len("data:")+len(mediaType)+len(";base64,"):] + } + + info := URLTypeInfo{ + MediaType: &mediaType, + DataURLWithoutPrefix: &dataURLWithoutPrefix, + } + + if isBase64 { + info.Type = ImageContentTypeBase64 + } else { + info.Type = ImageContentTypeURL // Non-base64 data URL + } + + return info +} + +// extractRegularURLInfo extracts information from a regular HTTP/HTTPS URL +func extractRegularURLInfo(regularURL string) URLTypeInfo { + info := URLTypeInfo{ + Type: ImageContentTypeURL, + } + + // Try to infer media type from file extension + parsedURL, err := url.Parse(regularURL) + if err != nil { + return info + } + + path := strings.ToLower(parsedURL.Path) + + // Check for known file extensions using the map + for ext, mediaType := range fileExtensionToMediaType { + if strings.HasSuffix(path, ext) { + info.MediaType = &mediaType + break + } + } + // For URLs without recognizable extensions, MediaType remains nil + + return info +} + +// detectImageTypeFromBase64 detects the image type from base64 data by examining the header bytes +func detectImageTypeFromBase64(base64Data string) string { + // Remove any whitespace or newlines + cleanData := strings.ReplaceAll(strings.ReplaceAll(base64Data, "\n", ""), " ", "") + + // Check common image format signatures in base64 + switch { + case strings.HasPrefix(cleanData, "/9j/") || strings.HasPrefix(cleanData, "/9k/"): + // JPEG images typically start with /9j/ or /9k/ in base64 (FFD8 in hex) + return "image/jpeg" + case strings.HasPrefix(cleanData, "iVBORw0KGgo"): + // PNG images start with iVBORw0KGgo in base64 (89504E470D0A1A0A in hex) + return "image/png" + case strings.HasPrefix(cleanData, "R0lGOD"): + // GIF images start with R0lGOD in base64 (474946 in hex) + return "image/gif" + case strings.HasPrefix(cleanData, "Qk"): + // BMP images start with Qk in base64 (424D in hex) + return "image/bmp" + case strings.HasPrefix(cleanData, "UklGR") && len(cleanData) >= 16 && cleanData[12:16] == "V0VC": + // WebP images start with RIFF header (UklGR in base64) and have WEBP signature at offset 8-11 (V0VC in base64) + return "image/webp" + case strings.HasPrefix(cleanData, "PHN2Zy") || strings.HasPrefix(cleanData, "PD94bW"): + // SVG images often start with 0 { + return nil, &schemas.BifrostError{ + StatusCode: &resp.StatusCode, + Type: &vertexErr[0].Error.Status, + Error: schemas.ErrorField{ + Message: vertexErr[0].Error.Message, + }, + } + } + } + + return nil, &schemas.BifrostError{ + StatusCode: &resp.StatusCode, + Error: schemas.ErrorField{ + Message: openAIErr.Error.Message, + }, + } + } + + if strings.Contains(model, "claude") { + // Create response object from pool + response := acquireAnthropicChatResponse() + defer releaseAnthropicChatResponse(response) + + rawResponse, bifrostErr := handleProviderResponse(body, response) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + bifrostResponse := &schemas.BifrostResponse{} + var err *schemas.BifrostError + bifrostResponse, err = parseAnthropicResponse(response, bifrostResponse) + if err != nil { + return nil, err + } + + bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ + Provider: schemas.Vertex, + RawResponse: rawResponse, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil + } else { + // Pre-allocate response structs from pools + response := acquireOpenAIResponse() + defer releaseOpenAIResponse(response) + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(body, response) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + bifrostResponse := &schemas.BifrostResponse{ + ID: response.ID, + Object: response.Object, + Choices: response.Choices, + Model: response.Model, + Created: response.Created, + ServiceTier: response.ServiceTier, + SystemFingerprint: response.SystemFingerprint, + Usage: response.Usage, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.Vertex, + RawResponse: rawResponse, + }, + } + + if params != nil { + bifrostResponse.ExtraFields.Params = *params + } + + return bifrostResponse, nil + } +} diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 4e3f06041e..7fb0320ae3 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -1,6 +1,11 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas +import ( + "encoding/json" + "fmt" +) + const ( DefaultInitialPoolSize = 100 ) @@ -12,19 +17,20 @@ type BifrostConfig struct { Account Account Plugins []Plugin Logger Logger - InitialPoolSize int // Initial pool size for sync pools in Bifrost. Higher values will reduce memory allocations but will increase memory usage. - DropExcessRequests bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. + InitialPoolSize int // Initial pool size for sync pools in Bifrost. Higher values will reduce memory allocations but will increase memory usage. + DropExcessRequests bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. + MCPConfig *MCPConfig // MCP (Model Context Protocol) configuration for tool integration } // ModelChatMessageRole represents the role of a chat message type ModelChatMessageRole string const ( - RoleAssistant ModelChatMessageRole = "assistant" - RoleUser ModelChatMessageRole = "user" - RoleSystem ModelChatMessageRole = "system" - RoleChatbot ModelChatMessageRole = "chatbot" - RoleTool ModelChatMessageRole = "tool" + ModelChatMessageRoleAssistant ModelChatMessageRole = "assistant" + ModelChatMessageRoleUser ModelChatMessageRole = "user" + ModelChatMessageRoleSystem ModelChatMessageRole = "system" + ModelChatMessageRoleChatbot ModelChatMessageRole = "chatbot" + ModelChatMessageRoleTool ModelChatMessageRole = "tool" ) // ModelProvider represents the different AI model providers supported by Bifrost. @@ -36,6 +42,9 @@ const ( Anthropic ModelProvider = "anthropic" Bedrock ModelProvider = "bedrock" Cohere ModelProvider = "cohere" + Vertex ModelProvider = "vertex" + Mistral ModelProvider = "mistral" + Ollama ModelProvider = "ollama" ) //* Request Structs @@ -43,28 +52,29 @@ const ( // RequestInput represents the input for a model request, which can be either // a text completion or a chat completion, but either one must be provided. type RequestInput struct { - TextCompletionInput *string - ChatCompletionInput *[]Message + TextCompletionInput *string `json:"text_completion_input,omitempty"` + ChatCompletionInput *[]BifrostMessage `json:"chat_completion_input,omitempty"` } // BifrostRequest represents a request to be processed by Bifrost. // It must be provided when calling the Bifrost for text completion or chat completion. // It contains the model identifier, input data, and parameters for the request. type BifrostRequest struct { - Model string - Input RequestInput - Params *ModelParameters + Provider ModelProvider `json:"provider"` + Model string `json:"model"` + Input RequestInput `json:"input"` + Params *ModelParameters `json:"params,omitempty"` // Fallbacks are tried in order, the first one to succeed is returned // Provider config must be available for each fallback's provider in account's GetConfigForProvider, // else it will be skipped. - Fallbacks []Fallback + Fallbacks []Fallback `json:"fallbacks,omitempty"` } // Fallback represents a fallback model to be used if the primary model is not available. type Fallback struct { - Provider ModelProvider - Model string + Provider ModelProvider `json:"provider"` + Model string `json:"model"` } // ModelParameters represents the parameters that can be used to configure @@ -73,15 +83,14 @@ type Fallback struct { type ModelParameters struct { ToolChoice *ToolChoice `json:"tool_choice,omitempty"` Tools *[]Tool `json:"tools,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output - TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling - TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling - MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate - StopSequences *[]string `json:"stop_sequences,omitempty"` // Sequences that stop generation - PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens - ParallelToolCalls *bool `json:"parallel_tool_calls"` // Enables parallel tool calls - + Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output + TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling + TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling + MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate + StopSequences *[]string `json:"stop_sequences,omitempty"` // Sequences that stop generation + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Enables parallel tool calls // Dynamic parameters that can be provider-specific, they are directly // added to the request as is. ExtraParams map[string]interface{} `json:"-"` @@ -89,10 +98,11 @@ type ModelParameters struct { // FunctionParameters represents the parameters for a function definition. type FunctionParameters struct { - Type string `json:"type,"` // Type of the parameters + Type string `json:"type"` // Type of the parameters Description *string `json:"description,omitempty"` // Description of the parameters - Required []string `json:"required"` // Required parameter names - Properties map[string]interface{} `json:"properties"` // Parameter properties + Required []string `json:"required,omitempty"` // Required parameter names + Properties map[string]interface{} `json:"properties,omitempty"` // Parameter properties + Enum *[]string `json:"enum,omitempty"` // Enum values for the parameters } // Function represents a function that can be called by the model. @@ -114,16 +124,16 @@ type Tool struct { type ToolChoiceType string const ( - // ToolChoiceNone means no tool will be called - ToolChoiceNone ToolChoiceType = "none" - // ToolChoiceAuto means the model can choose whether to call a tool - ToolChoiceAuto ToolChoiceType = "auto" - // ToolChoiceAny means any tool can be called - ToolChoiceAny ToolChoiceType = "any" - // ToolChoiceTool means a specific tool must be called - ToolChoiceTool ToolChoiceType = "tool" - // ToolChoiceRequired means a tool must be called - ToolChoiceRequired ToolChoiceType = "required" + // ToolChoiceTypeNone means no tool will be called + ToolChoiceTypeNone ToolChoiceType = "none" + // ToolChoiceTypeAuto means the model can choose whether to call a tool + ToolChoiceTypeAuto ToolChoiceType = "auto" + // ToolChoiceTypeAny means any tool can be called + ToolChoiceTypeAny ToolChoiceType = "any" + // ToolChoiceTypeFunction means a specific tool must be called (converted to "tool" for Anthropic) + ToolChoiceTypeFunction ToolChoiceType = "function" + // ToolChoiceTypeRequired means a tool must be called + ToolChoiceTypeRequired ToolChoiceType = "required" ) // ToolChoiceFunction represents a specific function to be called. @@ -131,26 +141,147 @@ type ToolChoiceFunction struct { Name string `json:"name"` // Name of the function to call } -// ToolChoice represents how a tool should be chosen for a request. +// ToolChoiceStruct represents a specific tool choice. +type ToolChoiceStruct struct { + Type ToolChoiceType `json:"type"` // Type of tool choice + Function ToolChoiceFunction `json:"function,omitempty"` // Function to call if type is ToolChoiceTypeFunction +} + +// ToolChoice represents how a tool should be chosen for a request. (either a string or a struct) type ToolChoice struct { - Type ToolChoiceType `json:"type"` // Type of tool choice - Function ToolChoiceFunction `json:"function"` // Function to call if type is ToolChoiceTool + ToolChoiceStr *string + ToolChoiceStruct *ToolChoiceStruct +} + +// MarshalJSON implements custom JSON marshalling for ToolChoice. +// It marshals either ToolChoiceStr or ToolChoiceStruct directly without wrapping. +func (tc ToolChoice) MarshalJSON() ([]byte, error) { + // Validation: ensure only one field is set at a time + if tc.ToolChoiceStr != nil && tc.ToolChoiceStruct != nil { + return nil, fmt.Errorf("both ToolChoiceStr and ToolChoiceStruct are set; only one should be non-nil") + } + + if tc.ToolChoiceStr != nil { + return json.Marshal(*tc.ToolChoiceStr) + } + if tc.ToolChoiceStruct != nil { + return json.Marshal(*tc.ToolChoiceStruct) + } + // If both are nil, return null + return json.Marshal(nil) +} + +// UnmarshalJSON implements custom JSON unmarshalling for ToolChoice. +// It determines whether "tool_choice" is a string or struct and assigns to the appropriate field. +// It also handles direct string/array content without a wrapper object. +func (tc *ToolChoice) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var stringContent string + if err := json.Unmarshal(data, &stringContent); err == nil { + tc.ToolChoiceStr = &stringContent + return nil + } + + // Try to unmarshal as a direct struct of ToolChoiceStruct + var toolChoiceStruct ToolChoiceStruct + if err := json.Unmarshal(data, &toolChoiceStruct); err == nil { + // Validate the Type field is not empty and is a valid value + if toolChoiceStruct.Type == "" { + return fmt.Errorf("tool_choice struct has empty type field") + } + + tc.ToolChoiceStruct = &toolChoiceStruct + return nil + } + + return fmt.Errorf("tool_choice field is neither a string nor a struct") +} + +// BifrostMessage represents a message in a chat conversation. +type BifrostMessage struct { + Role ModelChatMessageRole `json:"role"` + Content MessageContent `json:"content"` + + // Embedded pointer structs - when non-nil, their exported fields are flattened into the top-level JSON object + // IMPORTANT: Only one of the following can be non-nil at a time, otherwise the JSON marshalling will override the common fields + *ToolMessage + *AssistantMessage +} + +type MessageContent struct { + ContentStr *string + ContentBlocks *[]ContentBlock +} + +// MarshalJSON implements custom JSON marshalling for MessageContent. +// It marshals either ContentStr or ContentBlocks directly without wrapping. +func (mc MessageContent) MarshalJSON() ([]byte, error) { + // Validation: ensure only one field is set at a time + if mc.ContentStr != nil && mc.ContentBlocks != nil { + return nil, fmt.Errorf("both ContentStr and ContentBlocks are set; only one should be non-nil") + } + + if mc.ContentStr != nil { + return json.Marshal(*mc.ContentStr) + } + if mc.ContentBlocks != nil { + return json.Marshal(*mc.ContentBlocks) + } + // If both are nil, return null + return json.Marshal(nil) } -// Message represents a single message in a chat conversation. -type Message struct { - Role ModelChatMessageRole `json:"role"` - Content *string `json:"content,omitempty"` - ImageContent *ImageContent `json:"image_content,omitempty"` - ToolCalls *[]Tool `json:"tool_calls,omitempty"` +// UnmarshalJSON implements custom JSON unmarshalling for MessageContent. +// It determines whether "content" is a string or array and assigns to the appropriate field. +// It also handles direct string/array content without a wrapper object. +func (mc *MessageContent) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var stringContent string + if err := json.Unmarshal(data, &stringContent); err == nil { + mc.ContentStr = &stringContent + return nil + } + + // Try to unmarshal as a direct array of ContentBlock + var arrayContent []ContentBlock + if err := json.Unmarshal(data, &arrayContent); err == nil { + mc.ContentBlocks = &arrayContent + return nil + } + + return fmt.Errorf("content field is neither a string nor an array of ContentBlock") +} + +type ContentBlockType string + +const ( + ContentBlockTypeText ContentBlockType = "text" + ContentBlockTypeImage ContentBlockType = "image_url" +) + +type ContentBlock struct { + Type ContentBlockType `json:"type"` + Text *string `json:"text,omitempty"` + ImageURL *ImageURLStruct `json:"image_url,omitempty"` +} + +// ToolMessage represents a message from a tool +type ToolMessage struct { + ToolCallID *string `json:"tool_call_id,omitempty"` +} + +// AssistantMessage represents a message from an assistant +type AssistantMessage struct { + Refusal *string `json:"refusal,omitempty"` + Annotations []Annotation `json:"annotations,omitempty"` + ToolCalls *[]ToolCall `json:"tool_calls,omitempty"` + Thought *string `json:"thought,omitempty"` } // ImageContent represents image data in a message. -type ImageContent struct { - Type *string `json:"type"` - URL string `json:"url"` - MediaType *string `json:"media_type"` - Detail *string `json:"detail"` +type ImageURLStruct struct { + URL string `json:"url"` + Detail *string `json:"detail,omitempty"` } //* Response Structs @@ -260,41 +391,42 @@ type Annotation struct { Citation Citation `json:"url_citation"` } -// BifrostResponseChoiceMessage represents a choice in the completion response -type BifrostResponseChoiceMessage struct { - Role ModelChatMessageRole `json:"role"` - Content *string `json:"content,omitempty"` - Refusal *string `json:"refusal,omitempty"` - Annotations []Annotation `json:"annotations,omitempty"` - ToolCalls *[]ToolCall `json:"tool_calls,omitempty"` -} - // BifrostResponseChoice represents a choice in the completion result type BifrostResponseChoice struct { - Index int `json:"index"` - Message BifrostResponseChoiceMessage `json:"message"` - FinishReason *string `json:"finish_reason,omitempty"` - StopString *string `json:"stop,omitempty"` - LogProbs *LogProbs `json:"log_probs,omitempty"` + Index int `json:"index"` + Message BifrostMessage `json:"message"` + FinishReason *string `json:"finish_reason,omitempty"` + StopString *string `json:"stop,omitempty"` + LogProbs *LogProbs `json:"log_probs,omitempty"` } // BifrostResponseExtraFields contains additional fields in a response. type BifrostResponseExtraFields struct { - Provider ModelProvider `json:"provider"` - Params ModelParameters `json:"model_params"` - Latency *float64 `json:"latency,omitempty"` - ChatHistory *[]BifrostResponseChoiceMessage `json:"chat_history,omitempty"` - BilledUsage *BilledLLMUsage `json:"billed_usage,omitempty"` - RawResponse interface{} `json:"raw_response"` + Provider ModelProvider `json:"provider"` + Params ModelParameters `json:"model_params"` + Latency *float64 `json:"latency,omitempty"` + ChatHistory *[]BifrostMessage `json:"chat_history,omitempty"` + BilledUsage *BilledLLMUsage `json:"billed_usage,omitempty"` + RawResponse interface{} `json:"raw_response"` } +const ( + RequestCancelled = "request_cancelled" +) + // BifrostError represents an error from the Bifrost system. +// +// PLUGIN DEVELOPERS: When creating BifrostError in PreHook or PostHook, you can set AllowFallbacks: +// - AllowFallbacks = &true: Bifrost will try fallback providers if available +// - AllowFallbacks = &false: Bifrost will return this error immediately, no fallbacks +// - AllowFallbacks = nil: Treated as true by default (fallbacks allowed for resilience) type BifrostError struct { EventID *string `json:"event_id,omitempty"` Type *string `json:"type,omitempty"` IsBifrostError bool `json:"is_bifrost_error"` StatusCode *int `json:"status_code,omitempty"` Error ErrorField `json:"error"` + AllowFallbacks *bool `json:"allow_fallbacks,omitempty"` // Optional: Controls fallback behavior (nil = true by default) } // ErrorField represents detailed error information. diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go new file mode 100644 index 0000000000..22d189e5bf --- /dev/null +++ b/core/schemas/mcp.go @@ -0,0 +1,35 @@ +// Package schemas defines the core schemas and types used by the Bifrost system. +package schemas + +// MCPConfig represents the configuration for MCP integration in Bifrost. +// It enables tool auto-discovery and execution from local and external MCP servers. +type MCPConfig struct { + ServerPort *int `json:"server_port,omitempty"` // Port for local MCP server (only required for local tool setup, defaults to 8181) + ClientConfigs []MCPClientConfig `json:"client_configs,omitempty"` // Per-client execution configurations +} + +// MCPClientConfig defines tool filtering for an MCP client. +type MCPClientConfig struct { + Name string `json:"name"` // Client name + ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, or SSE) + ConnectionString *string `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) + StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) + ToolsToSkip []string `json:"tools_to_skip,omitempty"` // Tools to exclude from this client + ToolsToExecute []string `json:"tools_to_execute,omitempty"` // Tools to include from this client (if specified, only these are used) +} + +// MCPConnectionType defines the communication protocol for MCP connections +type MCPConnectionType string + +const ( + MCPConnectionTypeHTTP MCPConnectionType = "http" // HTTP-based MCP connection (streamable) + MCPConnectionTypeSTDIO MCPConnectionType = "stdio" // STDIO-based MCP connection + MCPConnectionTypeSSE MCPConnectionType = "sse" // Server-Sent Events MCP connection +) + +// MCPStdioConfig defines how to launch a STDIO-based MCP server. +type MCPStdioConfig struct { + Command string `json:"command"` // Executable command to run + Args []string `json:"args"` // Command line arguments + Envs []string `json:"envs"` // Environment variables required +} diff --git a/core/schemas/meta/azure.go b/core/schemas/meta/azure.go index df5fd163b9..58abbd071c 100644 --- a/core/schemas/meta/azure.go +++ b/core/schemas/meta/azure.go @@ -11,31 +11,6 @@ type AzureMetaConfig struct { APIVersion *string `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-02-01" } -// This is not used for Azure. -func (c *AzureMetaConfig) GetSecretAccessKey() *string { - return nil -} - -// This is not used for Azure. -func (c *AzureMetaConfig) GetRegion() *string { - return nil -} - -// This is not used for Azure. -func (c *AzureMetaConfig) GetSessionToken() *string { - return nil -} - -// This is not used for Azure. -func (c *AzureMetaConfig) GetARN() *string { - return nil -} - -// This is not used for Azure. -func (c *AzureMetaConfig) GetInferenceProfiles() map[string]string { - return nil -} - // GetEndpoint returns the Azure service endpoint. // This specifies the base URL for Azure API requests. func (c *AzureMetaConfig) GetEndpoint() *string { @@ -44,7 +19,7 @@ func (c *AzureMetaConfig) GetEndpoint() *string { // GetDeployments returns the deployment configurations. // This maps model names to their corresponding Azure deployment names. -// Eg. "gpt-4o": "your-deployment-name-for-gpt-4o" +// E.g. "gpt-4o": "your-deployment-name-for-gpt-4o" func (c *AzureMetaConfig) GetDeployments() map[string]string { return c.Deployments } @@ -54,3 +29,12 @@ func (c *AzureMetaConfig) GetDeployments() map[string]string { func (c *AzureMetaConfig) GetAPIVersion() *string { return c.APIVersion } + +// These are not used for Azure. +func (c *AzureMetaConfig) GetARN() *string { return nil } +func (c *AzureMetaConfig) GetAuthCredentials() *string { return nil } +func (c *AzureMetaConfig) GetInferenceProfiles() map[string]string { return nil } +func (c *AzureMetaConfig) GetProjectID() *string { return nil } +func (c *AzureMetaConfig) GetRegion() *string { return nil } +func (c *AzureMetaConfig) GetSecretAccessKey() *string { return nil } +func (c *AzureMetaConfig) GetSessionToken() *string { return nil } diff --git a/core/schemas/meta/bedrock.go b/core/schemas/meta/bedrock.go index 1a875d3f65..bdff19e76a 100644 --- a/core/schemas/meta/bedrock.go +++ b/core/schemas/meta/bedrock.go @@ -43,17 +43,9 @@ func (c *BedrockMetaConfig) GetInferenceProfiles() map[string]string { return c.InferenceProfiles } -// This is not used for Bedrock. -func (c *BedrockMetaConfig) GetEndpoint() *string { - return nil -} - -// This is not used for Bedrock. -func (c *BedrockMetaConfig) GetDeployments() map[string]string { - return nil -} - -// This is not used for Bedrock. -func (c *BedrockMetaConfig) GetAPIVersion() *string { - return nil -} +// These are not used for Bedrock. +func (c *BedrockMetaConfig) GetAPIVersion() *string { return nil } +func (c *BedrockMetaConfig) GetAuthCredentials() *string { return nil } +func (c *BedrockMetaConfig) GetDeployments() map[string]string { return nil } +func (c *BedrockMetaConfig) GetEndpoint() *string { return nil } +func (c *BedrockMetaConfig) GetProjectID() *string { return nil } diff --git a/core/schemas/meta/vertex.go b/core/schemas/meta/vertex.go new file mode 100644 index 0000000000..a82e46380d --- /dev/null +++ b/core/schemas/meta/vertex.go @@ -0,0 +1,39 @@ +// Package meta provides provider-specific configuration structures and schemas. +// This file contains the AWS Vertex-specific configuration implementation. + +package meta + +// VertexMetaConfig represents the Vertex-specific configuration. +// It contains Vertex-specific settings required for authentication and service access. +type VertexMetaConfig struct { + ProjectID string `json:"project_id,omitempty"` + Region string `json:"region,omitempty"` + AuthCredentials string `json:"auth_credentials,omitempty"` +} + +// GetRegion returns the Vertex region. +// This is the region for the Vertex project. +func (c *VertexMetaConfig) GetRegion() *string { + return &c.Region +} + +// GetProjectID returns the Vertex project ID. +// This is the project ID for the Vertex project. +func (c *VertexMetaConfig) GetProjectID() *string { + return &c.ProjectID +} + +// GetAuthCredentials returns the authentication credentials for the provider. +// This is the authentication credentials for the google cloud api. +func (c *VertexMetaConfig) GetAuthCredentials() *string { + return &c.AuthCredentials +} + +// These are not used for Vertex. +func (c *VertexMetaConfig) GetAPIVersion() *string { return nil } +func (c *VertexMetaConfig) GetARN() *string { return nil } +func (c *VertexMetaConfig) GetDeployments() map[string]string { return nil } +func (c *VertexMetaConfig) GetEndpoint() *string { return nil } +func (c *VertexMetaConfig) GetInferenceProfiles() map[string]string { return nil } +func (c *VertexMetaConfig) GetSecretAccessKey() *string { return nil } +func (c *VertexMetaConfig) GetSessionToken() *string { return nil } diff --git a/core/schemas/plugin.go b/core/schemas/plugin.go index c10adebf3e..7a8fdf554e 100644 --- a/core/schemas/plugin.go +++ b/core/schemas/plugin.go @@ -1,7 +1,17 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas -import "context" +import ( + "context" + "encoding/json" +) + +// PluginShortCircuit represents a plugin's decision to short-circuit the normal flow. +// It can contain either a response (success short-circuit) or an error (error short-circuit). +type PluginShortCircuit struct { + Response *BifrostResponse // If set, short-circuit with this response (skips provider call) + Error *BifrostError // If set, short-circuit with this error (can set AllowFallbacks field) +} // Plugin defines the interface for Bifrost plugins. // Plugins can intercept and modify requests and responses at different stages @@ -9,22 +19,52 @@ import "context" // User can provide multiple plugins in the BifrostConfig. // PreHooks are executed in the order they are registered. // PostHooks are executed in the reverse order of PreHooks. - +// // PreHooks and PostHooks can be used to implement custom logic, such as: // - Rate limiting // - Caching // - Logging // - Monitoring +// +// Plugin error handling: +// - No Plugin errors are returned to the caller; they are logged as warnings by the Bifrost instance. +// - PreHook and PostHook can both modify the request/response and the error. Plugins can recover from errors (set error to nil and provide a response), or invalidate a response (set response to nil and provide an error). +// - PostHook is always called with both the current response and error, and should handle either being nil. +// - Only truly empty errors (no message, no error, no status code, no type) are treated as recoveries by the pipeline. +// - If a PreHook returns a PluginShortCircuit, the provider call may be skipped and only the PostHook methods of plugins that had their PreHook executed are called in reverse order. +// - The plugin pipeline ensures symmetry: for every PreHook executed, the corresponding PostHook will be called in reverse order. +// +// IMPORTANT: When returning BifrostError from PreHook or PostHook: +// - You can set the AllowFallbacks field to control fallback behavior +// - AllowFallbacks = &true: Allow Bifrost to try fallback providers +// - AllowFallbacks = &false: Do not try fallbacks, return error immediately +// - AllowFallbacks = nil: Treated as true by default (allow fallbacks for resilience) +// +// Plugin authors should ensure their hooks are robust to both response and error being nil, and should not assume either is always present. type Plugin interface { + // GetName returns the name of the plugin. + GetName() string + // PreHook is called before a request is processed by a provider. // It allows plugins to modify the request before it is sent to the provider. // The context parameter can be used to maintain state across plugin calls. - // Returns the modified request and any error that occurred during processing. - PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error) + // Returns the modified request, an optional short-circuit decision, and any error that occurred during processing. + PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *PluginShortCircuit, error) + + // PostHook is called after a response is received from a provider or a PreHook short-circuit. + // It allows plugins to modify the response and/or error before it is returned to the caller. + // Plugins can recover from errors (set error to nil and provide a response), or invalidate a response (set response to nil and provide an error). + // Returns the modified response, bifrost error, and any error that occurred during processing. + PostHook(ctx *context.Context, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError, error) - // PostHook is called after a response is received from a provider. - // It allows plugins to modify the response before it is returned to the caller. - // Returns the modified response and any error that occurred during processing. - PostHook(ctx *context.Context, result *BifrostResponse) (*BifrostResponse, error) + // Cleanup is called on bifrost shutdown. + // It allows plugins to clean up any resources they have allocated. + // Returns any error that occurred during cleanup, which will be logged as a warning by the Bifrost instance. + Cleanup() error } + +// Init defines the standardized constructor signature for all plugins. +// All plugins should implement: Init(config json.RawMessage) (Plugin, error) +// This is enforced at development time through documentation and examples. +type Init func(config json.RawMessage) (Plugin, error) diff --git a/core/schemas/provider.go b/core/schemas/provider.go index 56376b730f..800178063f 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -1,7 +1,11 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas -import "time" +import ( + "context" + "maps" + "time" +) const ( DefaultMaxRetries = 0 @@ -23,11 +27,23 @@ const ( ) // NetworkConfig represents the network configuration for provider connections. +// ExtraHeaders is automatically copied during provider initialization to prevent data races. type NetworkConfig struct { - DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` // Default timeout for requests - MaxRetries int `json:"max_retries"` // Maximum number of retries - RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration - RetryBackoffMax time.Duration `json:"retry_backoff_max"` // Maximum backoff duration + // BaseURL is supported for OpenAI, Anthropic, Cohere, Mistral, and Ollama providers (required for Ollama) + BaseURL string `json:"base_url,omitempty"` // Base URL for the provider (optional) + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Additional headers to include in requests (optional) + DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` // Default timeout for requests + MaxRetries int `json:"max_retries"` // Maximum number of retries + RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration + RetryBackoffMax time.Duration `json:"retry_backoff_max"` // Maximum backoff duration +} + +// DefaultNetworkConfig is the default network configuration for provider connections. +var DefaultNetworkConfig = NetworkConfig{ + DefaultRequestTimeoutInSeconds: DefaultRequestTimeoutInSeconds, + MaxRetries: DefaultMaxRetries, + RetryBackoffInitial: DefaultRetryBackoffInitial, + RetryBackoffMax: DefaultRetryBackoffMax, } // MetaConfig defines the interface for provider-specific configuration. @@ -49,6 +65,10 @@ type MetaConfig interface { GetDeployments() map[string]string // GetAPIVersion returns the API version GetAPIVersion() *string + // GetProjectID returns the project ID + GetProjectID() *string + // GetAuthCredentials returns the authentication credentials for the provider + GetAuthCredentials() *string } // ConcurrencyAndBufferSize represents configuration for concurrent operations and buffer sizes. @@ -57,6 +77,12 @@ type ConcurrencyAndBufferSize struct { BufferSize int `json:"buffer_size"` // Size of the buffer } +// DefaultConcurrencyAndBufferSize is the default concurrency and buffer size for provider operations. +var DefaultConcurrencyAndBufferSize = ConcurrencyAndBufferSize{ + Concurrency: DefaultConcurrency, + BufferSize: DefaultBufferSize, +} + // ProxyType defines the type of proxy to use for connections. type ProxyType string @@ -91,12 +117,45 @@ type ProviderConfig struct { ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration } +func (config *ProviderConfig) CheckAndSetDefaults() { + if config.ConcurrencyAndBufferSize.Concurrency == 0 { + config.ConcurrencyAndBufferSize.Concurrency = DefaultConcurrency + } + + if config.ConcurrencyAndBufferSize.BufferSize == 0 { + config.ConcurrencyAndBufferSize.BufferSize = DefaultBufferSize + } + + if config.NetworkConfig.DefaultRequestTimeoutInSeconds == 0 { + config.NetworkConfig.DefaultRequestTimeoutInSeconds = DefaultRequestTimeoutInSeconds + } + + if config.NetworkConfig.MaxRetries == 0 { + config.NetworkConfig.MaxRetries = DefaultMaxRetries + } + + if config.NetworkConfig.RetryBackoffInitial == 0 { + config.NetworkConfig.RetryBackoffInitial = DefaultRetryBackoffInitial + } + + if config.NetworkConfig.RetryBackoffMax == 0 { + config.NetworkConfig.RetryBackoffMax = DefaultRetryBackoffMax + } + + // Create a defensive copy of ExtraHeaders to prevent data races + if config.NetworkConfig.ExtraHeaders != nil { + headersCopy := make(map[string]string, len(config.NetworkConfig.ExtraHeaders)) + maps.Copy(headersCopy, config.NetworkConfig.ExtraHeaders) + config.NetworkConfig.ExtraHeaders = headersCopy + } +} + // Provider defines the interface for AI model providers. type Provider interface { // GetProviderKey returns the provider's identifier GetProviderKey() ModelProvider // TextCompletion performs a text completion request - TextCompletion(model, key, text string, params *ModelParameters) (*BifrostResponse, *BifrostError) + TextCompletion(ctx context.Context, model, key, text string, params *ModelParameters) (*BifrostResponse, *BifrostError) // ChatCompletion performs a chat completion request - ChatCompletion(model, key string, messages []Message, params *ModelParameters) (*BifrostResponse, *BifrostError) + ChatCompletion(ctx context.Context, model, key string, messages []BifrostMessage, params *ModelParameters) (*BifrostResponse, *BifrostError) } diff --git a/core/tests/account.go b/core/tests/account.go deleted file mode 100644 index 53278b7ce3..0000000000 --- a/core/tests/account.go +++ /dev/null @@ -1,202 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "fmt" - "os" - "time" - - schemas "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/core/schemas/meta" -) - -// BaseAccount provides a test implementation of the Account interface. -// It implements basic account functionality for testing purposes, supporting -// multiple AI providers including OpenAI, Anthropic, Bedrock, Cohere, and Azure. -// The implementation uses environment variables from the .env file for API keys and provides -// default configurations suitable for testing. -type BaseAccount struct{} - -// GetConfiguredProviders returns the list of initially supported providers. -// This implementation returns OpenAI, Anthropic, and Bedrock as the default providers. -// -// Returns: -// - []schemas.SupportedModelProvider: A slice containing the supported provider identifiers -// - error: Always returns nil as this implementation doesn't produce errors -func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { - return []schemas.ModelProvider{schemas.OpenAI, schemas.Anthropic, schemas.Bedrock, schemas.Cohere, schemas.Azure}, nil -} - -// GetKeysForProvider returns the API keys and associated models for a given provider. -// It retrieves API keys from environment variables and maps them to their supported models. -// Each key includes a weight value for load balancing purposes. -// -// Parameters: -// - providerKey: The identifier of the provider to get keys for -// -// Returns: -// - []schemas.Key: A slice of Key objects containing API keys and their configurations -// - error: An error if the provider is not supported -// -// Environment Variables Used: -// - OPENAI_API_KEY: API key for OpenAI -// - ANTHROPIC_API_KEY: API key for Anthropic -// - BEDROCK_API_KEY: API key for AWS Bedrock -// - COHERE_API_KEY: API key for Cohere -// - AZURE_API_KEY: API key for Azure OpenAI -func (baseAccount *BaseAccount) GetKeysForProvider(providerKey schemas.ModelProvider) ([]schemas.Key, error) { - switch providerKey { - case schemas.OpenAI: - return []schemas.Key{ - { - Value: os.Getenv("OPENAI_API_KEY"), - Models: []string{"gpt-4o-mini", "gpt-4-turbo"}, - Weight: 1.0, - }, - }, nil - case schemas.Anthropic: - return []schemas.Key{ - { - Value: os.Getenv("ANTHROPIC_API_KEY"), - Models: []string{"claude-3-7-sonnet-20250219", "claude-3-5-sonnet-20240620", "claude-2.1"}, - Weight: 1.0, - }, - }, nil - case schemas.Bedrock: - return []schemas.Key{ - { - Value: os.Getenv("BEDROCK_API_KEY"), - Models: []string{"anthropic.claude-v2:1", "mistral.mixtral-8x7b-instruct-v0:1", "mistral.mistral-large-2402-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"}, - Weight: 1.0, - }, - }, nil - case schemas.Cohere: - return []schemas.Key{ - { - Value: os.Getenv("COHERE_API_KEY"), - Models: []string{"command-a-03-2025"}, - Weight: 1.0, - }, - }, nil - case schemas.Azure: - return []schemas.Key{ - { - Value: os.Getenv("AZURE_API_KEY"), - Models: []string{"gpt-4o"}, - Weight: 1.0, - }, - }, nil - default: - return nil, fmt.Errorf("unsupported provider: %s", providerKey) - } -} - -// GetConfigForProvider returns the configuration settings for a given provider. -// It provides standardized configuration settings for network operations, -// concurrency, and provider-specific metadata. -// -// Parameters: -// - providerKey: The identifier of the provider to get configuration for -// -// Returns: -// - *schemas.ProviderConfig: Configuration settings for the provider, including: -// - Network settings (timeouts, retries, backoff) -// - Concurrency and buffer size settings -// - Provider-specific metadata (for Bedrock and Azure) -// - error: An error if the provider is not supported -// -// Environment Variables Used: -// - BEDROCK_ACCESS_KEY: AWS access key for Bedrock configuration -// - AZURE_ENDPOINT: Azure endpoint for Azure OpenAI configuration -// -// Default Settings: -// - Request Timeout: 30 seconds -// - Max Retries: 1 -// - Initial Backoff: 100ms -// - Max Backoff: 2s -// - Concurrency: 3 -// - Buffer Size: 10 -func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { - switch providerKey { - case schemas.OpenAI: - return &schemas.ProviderConfig{ - NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, - MaxRetries: 1, - RetryBackoffInitial: 100 * time.Millisecond, - RetryBackoffMax: 2 * time.Second, - }, - ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - }, nil - case schemas.Anthropic: - return &schemas.ProviderConfig{ - NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, - MaxRetries: 1, - RetryBackoffInitial: 100 * time.Millisecond, - RetryBackoffMax: 2 * time.Second, - }, - ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - }, nil - case schemas.Bedrock: - return &schemas.ProviderConfig{ - NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, - MaxRetries: 1, - RetryBackoffInitial: 100 * time.Millisecond, - RetryBackoffMax: 2 * time.Second, - }, - MetaConfig: &meta.BedrockMetaConfig{ - SecretAccessKey: os.Getenv("BEDROCK_ACCESS_KEY"), - Region: StrPtr("us-east-1"), - }, - ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - }, nil - case schemas.Cohere: - return &schemas.ProviderConfig{ - NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, - MaxRetries: 1, - RetryBackoffInitial: 100 * time.Millisecond, - RetryBackoffMax: 2 * time.Second, - }, - ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - }, nil - case schemas.Azure: - return &schemas.ProviderConfig{ - NetworkConfig: schemas.NetworkConfig{ - DefaultRequestTimeoutInSeconds: 30, - MaxRetries: 1, - RetryBackoffInitial: 100 * time.Millisecond, - RetryBackoffMax: 2 * time.Second, - }, - MetaConfig: &meta.AzureMetaConfig{ - Endpoint: os.Getenv("AZURE_ENDPOINT"), - Deployments: map[string]string{ - "gpt-4o": "gpt-4o-aug", - }, - APIVersion: StrPtr("2024-08-01-preview"), - }, - ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ - Concurrency: 3, - BufferSize: 10, - }, - }, nil - default: - return nil, fmt.Errorf("unsupported provider: %s", providerKey) - } -} diff --git a/core/tests/anthropic_test.go b/core/tests/anthropic_test.go deleted file mode 100644 index 5df5170b45..0000000000 --- a/core/tests/anthropic_test.go +++ /dev/null @@ -1,37 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "testing" - - schemas "github.com/maximhq/bifrost/core/schemas" -) - -func TestAnthropic(t *testing.T) { - bifrost, err := getBifrost() - if err != nil { - t.Fatalf("Error initializing bifrost: %v", err) - return - } - - maxTokens := 4096 - - config := TestConfig{ - Provider: schemas.Anthropic, - TextModel: "claude-2.1", - ChatModel: "claude-3-5-sonnet-20240620", - SetupText: true, - SetupToolCalls: false, // available in 3.7 sonnet - SetupImage: true, - SetupBaseImage: true, - CustomParams: &schemas.ModelParameters{ - MaxTokens: &maxTokens, - }, - } - - SetupAllRequests(bifrost, config) - - bifrost.Cleanup() -} diff --git a/core/tests/azure_test.go b/core/tests/azure_test.go deleted file mode 100644 index 81f37b6308..0000000000 --- a/core/tests/azure_test.go +++ /dev/null @@ -1,30 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "testing" - - schemas "github.com/maximhq/bifrost/core/schemas" -) - -func TestAzure(t *testing.T) { - bifrost, err := getBifrost() - if err != nil { - t.Fatalf("Error initializing bifrost: %v", err) - return - } - - config := TestConfig{ - Provider: schemas.Azure, - ChatModel: "gpt-4o", - SetupText: false, // gpt-4o does not support text completion - SetupToolCalls: true, - SetupImage: true, - SetupBaseImage: false, - } - - SetupAllRequests(bifrost, config) - bifrost.Cleanup() -} diff --git a/core/tests/bedrock_test.go b/core/tests/bedrock_test.go deleted file mode 100644 index ed227629cc..0000000000 --- a/core/tests/bedrock_test.go +++ /dev/null @@ -1,38 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "testing" - - schemas "github.com/maximhq/bifrost/core/schemas" -) - -func TestBedrock(t *testing.T) { - bifrost, err := getBifrost() - if err != nil { - t.Fatalf("Error initializing bifrost: %v", err) - return - } - - maxTokens := 4096 - textCompletion := "\n\nHuman:\n\nAssistant:" - - config := TestConfig{ - Provider: schemas.Bedrock, - TextModel: "anthropic.claude-v2:1", - ChatModel: "anthropic.claude-3-sonnet-20240229-v1:0", - SetupText: true, - SetupToolCalls: true, - SetupImage: true, - SetupBaseImage: false, - CustomParams: &schemas.ModelParameters{ - MaxTokens: &maxTokens, - }, - CustomTextCompletion: &textCompletion, - } - - SetupAllRequests(bifrost, config) - bifrost.Cleanup() -} diff --git a/core/tests/cohere_test.go b/core/tests/cohere_test.go deleted file mode 100644 index 37a7bfb37b..0000000000 --- a/core/tests/cohere_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "testing" - - schemas "github.com/maximhq/bifrost/core/schemas" -) - -func TestCohere(t *testing.T) { - bifrost, err := getBifrost() - if err != nil { - t.Fatalf("Error initializing bifrost: %v", err) - return - } - - config := TestConfig{ - Provider: schemas.Cohere, - ChatModel: "command-a-03-2025", - SetupText: false, // Cohere does not support text completion - SetupToolCalls: true, - SetupImage: false, - SetupBaseImage: false, - } - - SetupAllRequests(bifrost, config) - - bifrost.Cleanup() -} diff --git a/core/tests/openai_test.go b/core/tests/openai_test.go deleted file mode 100644 index cae22a1b79..0000000000 --- a/core/tests/openai_test.go +++ /dev/null @@ -1,37 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "testing" - - schemas "github.com/maximhq/bifrost/core/schemas" -) - -func TestOpenAI(t *testing.T) { - bifrost, err := getBifrost() - if err != nil { - t.Fatalf("Error initializing bifrost: %v", err) - return - } - - config := TestConfig{ - Provider: schemas.OpenAI, - TextModel: "gpt-4o-mini", - ChatModel: "gpt-4o-mini", - SetupText: true, // OpenAI does not support text completion - SetupToolCalls: false, - SetupImage: false, - SetupBaseImage: false, - Fallbacks: []schemas.Fallback{ - { - Provider: schemas.Anthropic, - Model: "claude-3-5-sonnet-20240620", - }, - }, - } - - SetupAllRequests(bifrost, config) - bifrost.Cleanup() -} diff --git a/core/tests/setup.go b/core/tests/setup.go deleted file mode 100644 index 9af47664f3..0000000000 --- a/core/tests/setup.go +++ /dev/null @@ -1,106 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "fmt" - "log" - "os" - - bifrost "github.com/maximhq/bifrost/core" - schemas "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/plugins" - - "github.com/joho/godotenv" -) - -// loadEnv loads environment variables from a .env file into the process environment. -// It uses the godotenv package to load variables and fails if the .env file cannot be loaded. -// -// Environment Variables: -// - .env file: Contains configuration values for the test environment -// -// Returns: -// - None, but will log.Fatal if the .env file cannot be loaded -func loadEnv() { - err := godotenv.Load() - if err != nil { - log.Fatal("Error loading .env file:", err) - } -} - -// getPlugin initializes and returns a Plugin instance for testing purposes. -// It sets up the Maxim logger with configuration from environment variables. -// -// Environment Variables: -// - MAXIM_API_KEY: API key for Maxim SDK authentication -// - MAXIM_LOGGER_ID: ID for the Maxim logger instance -// -// Returns: -// - schemas.Plugin: A configured plugin instance for request/response tracing -// - error: Any error that occurred during plugin initialization -func getPlugin() (schemas.Plugin, error) { - loadEnv() - - // check if Maxim Logger variables are set - if os.Getenv("MAXIM_API_KEY") == "" { - return nil, fmt.Errorf("MAXIM_API_KEY is not set, please set it in your .env file or pass nil in the Plugins field when initializing Bifrost") - } - - if os.Getenv("MAXIM_LOGGER_ID") == "" { - return nil, fmt.Errorf("MAXIM_LOGGER_ID is not set, please set it in your .env file or pass nil in the Plugins field when initializing Bifrost") - } - - plugin, err := plugins.NewMaximLoggerPlugin(os.Getenv("MAXIM_API_KEY"), os.Getenv("MAXIM_LOGGER_ID")) - if err != nil { - return nil, err - } - - return plugin, nil -} - -// getBifrost initializes and returns a Bifrost instance for testing. -// It sets up the test account, plugin, and logger configuration. -// -// Environment Variables: -// - Uses environment variables loaded by loadEnv() -// -// Returns: -// - *bifrost.Bifrost: A configured Bifrost instance ready for testing -// - error: Any error that occurred during Bifrost initialization -// -// The function: -// 1. Loads environment variables -// 2. Creates a test account instance -// 3. Initializes a plugin for request tracing -// 4. Configures Bifrost with the account, plugin, and default logger -func getBifrost() (*bifrost.Bifrost, error) { - loadEnv() - - account := BaseAccount{} - - // You can pass nil in the Plugins field if you don't want to use the implemented example plugin. - plugin, err := getPlugin() - if err != nil { - fmt.Println("Error setting up the plugin:", err) - return nil, err - } - - // Initialize Bifrost - b, err := bifrost.Init(schemas.BifrostConfig{ - Account: &account, - // Plugins: nil, - Plugins: []schemas.Plugin{plugin}, - Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), - }) - if err != nil { - return nil, err - } - - return b, nil -} - -func StrPtr(s string) *string { - return &s -} diff --git a/core/tests/tests.go b/core/tests/tests.go deleted file mode 100644 index f8a70b8d53..0000000000 --- a/core/tests/tests.go +++ /dev/null @@ -1,314 +0,0 @@ -// Package tests provides test utilities and configurations for the Bifrost system. -// It includes test implementations of schemas, mock objects, and helper functions -// for testing the Bifrost functionality with various AI providers. -package tests - -import ( - "context" - "fmt" - "time" - - bifrost "github.com/maximhq/bifrost/core" - schemas "github.com/maximhq/bifrost/core/schemas" -) - -// TestConfig holds configuration for test requests across different AI providers. -// It provides a flexible way to configure test scenarios for various provider capabilities. -// -// Fields: -// - Provider: The AI provider to test (e.g., OpenAI, Anthropic, etc.) -// - ChatModel: The model to use for chat completion tests -// - TextModel: The model to use for text completion tests -// - Messages: Custom messages to use in chat tests (optional) -// - SetupText: Whether to run text completion tests -// - SetupToolCalls: Whether to run function calling tests -// - SetupImage: Whether to run image input tests -// - SetupBaseImage: Whether to run base64 image tests -// - CustomTextCompletion: Custom text for completion tests (optional) -// - CustomParams: Custom model parameters for requests (optional) -// - Fallbacks: List of fallback providers and models to try if primary provider fails -type TestConfig struct { - Provider schemas.ModelProvider - ChatModel string - TextModel string - Messages []string - SetupText bool - SetupToolCalls bool - SetupImage bool - SetupBaseImage bool - CustomTextCompletion *string - CustomParams *schemas.ModelParameters - Fallbacks []schemas.Fallback -} - -// CommonTestMessages contains default messages used across providers for testing. -// These messages are used when no custom messages are provided in the test configuration. -var CommonTestMessages = []string{ - "Hello! How are you today?", - "Tell me a joke!", - "What's your favorite programming language?", -} - -// WeatherToolParams defines the parameters for a weather function tool. -// This is used to test function calling capabilities of AI providers. -var WeatherToolParams = schemas.ModelParameters{ - Tools: &[]schemas.Tool{{ - Type: "function", - Function: schemas.Function{ - Name: "get_weather", - Description: "Get the current weather in a given location", - Parameters: schemas.FunctionParameters{ - Type: "object", - Properties: map[string]interface{}{ - "location": map[string]interface{}{ - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": map[string]interface{}{ - "type": "string", - "enum": []string{"celsius", "fahrenheit"}, - }, - }, - Required: []string{"location"}, - }, - }, - }}, -} - -// setupTextCompletionRequest sets up and executes a text completion test request. -// It runs asynchronously and prints the result or error to stdout. -// -// Parameters: -// - bifrost: The Bifrost instance to use for the request -// - config: Test configuration containing model and parameters -// - ctx: Context for the request -func setupTextCompletionRequest(bifrost *bifrost.Bifrost, config TestConfig, ctx context.Context) { - text := "Hello world!" - if config.CustomTextCompletion != nil { - text = *config.CustomTextCompletion - } - - params := schemas.ModelParameters{} - if config.CustomParams != nil { - params = *config.CustomParams - } - - go func() { - result, err := bifrost.TextCompletionRequest(config.Provider, &schemas.BifrostRequest{ - Model: config.TextModel, - Input: schemas.RequestInput{ - TextCompletionInput: &text, - }, - Params: ¶ms, - Fallbacks: config.Fallbacks, - }, ctx) - if err != nil { - fmt.Printf("\nError in %s text completion: %v\n", config.Provider, err.Error.Message) - } else { - fmt.Printf("\nπŸ’ %s Text Completion Result: %s\n", config.Provider, *result.Choices[0].Message.Content) - } - }() -} - -// setupChatCompletionRequests sets up and executes multiple chat completion test requests. -// It runs requests asynchronously with staggered delays and prints results to stdout. -// -// Parameters: -// - bifrost: The Bifrost instance to use for the requests -// - config: Test configuration containing model and parameters -// - ctx: Context for the requests -func setupChatCompletionRequests(bifrost *bifrost.Bifrost, config TestConfig, ctx context.Context) { - messages := config.Messages - if len(messages) == 0 { - messages = CommonTestMessages - } - - params := schemas.ModelParameters{} - if config.CustomParams != nil { - params = *config.CustomParams - } - - for i, message := range messages { - delay := time.Duration(100*(i+1)) * time.Millisecond - go func(msg string, delay time.Duration, index int) { - time.Sleep(delay) - messages := []schemas.Message{ - { - Role: schemas.RoleUser, - Content: &msg, - }, - } - result, err := bifrost.ChatCompletionRequest(config.Provider, &schemas.BifrostRequest{ - Model: config.ChatModel, - Input: schemas.RequestInput{ - ChatCompletionInput: &messages, - }, - Params: ¶ms, - Fallbacks: config.Fallbacks, - }, ctx) - if err != nil { - fmt.Printf("\nError in %s request %d: %v\n", config.Provider, index+1, err.Error.Message) - } else { - fmt.Printf("\nπŸ’ %s Chat Completion Result %d: %s\n", config.Provider, index+1, *result.Choices[0].Message.Content) - } - }(message, delay, i) - } -} - -// setupImageTests sets up and executes image input test requests. -// It tests both URL and base64 image inputs (if enabled) and prints results to stdout. -// -// Parameters: -// - bifrost: The Bifrost instance to use for the requests -// - config: Test configuration containing model and parameters -// - ctx: Context for the requests -func setupImageTests(bifrost *bifrost.Bifrost, config TestConfig, ctx context.Context) { - params := schemas.ModelParameters{} - if config.CustomParams != nil { - params = *config.CustomParams - } - - // URL image test - urlImageMessages := []schemas.Message{ - { - Role: schemas.RoleUser, - Content: StrPtr("What is Happening in this picture?"), - ImageContent: &schemas.ImageContent{ - Type: StrPtr("url"), - URL: "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg", - }, - }, - } - - if config.Provider == schemas.Anthropic { - urlImageMessages[0].ImageContent.Type = StrPtr("url") - } - - go func() { - result, err := bifrost.ChatCompletionRequest(config.Provider, &schemas.BifrostRequest{ - Model: config.ChatModel, - Input: schemas.RequestInput{ - ChatCompletionInput: &urlImageMessages, - }, - Params: ¶ms, - Fallbacks: config.Fallbacks, - }, ctx) - if err != nil { - fmt.Printf("\nError in %s URL image request: %v\n", config.Provider, err.Error.Message) - } else { - fmt.Printf("\nπŸ’ %s URL Image Result: %s\n", config.Provider, *result.Choices[0].Message.Content) - } - }() - - // Base64 image test (only for providers that support it) - if config.SetupBaseImage { - base64ImageMessages := []schemas.Message{ - { - Role: schemas.RoleUser, - Content: StrPtr("What is this image about?"), - ImageContent: &schemas.ImageContent{ - Type: StrPtr("base64"), - URL: "/9j/4AAQSkZJRgABAQEAYABgAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAAIAAoDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAX/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCdABmX/9k=", - MediaType: StrPtr("image/jpeg"), - }, - }, - } - - go func() { - result, err := bifrost.ChatCompletionRequest(config.Provider, &schemas.BifrostRequest{ - Model: config.ChatModel, - Input: schemas.RequestInput{ - ChatCompletionInput: &base64ImageMessages, - }, - Params: ¶ms, - Fallbacks: config.Fallbacks, - }, ctx) - if err != nil { - fmt.Printf("\nError in %s base64 image request: %v\n", config.Provider, err.Error.Message) - } else { - fmt.Printf("\nπŸ’ %s Base64 Image Result: %s\n", config.Provider, *result.Choices[0].Message.Content) - } - }() - } -} - -// setupToolCalls sets up and executes function calling test requests. -// It tests the provider's ability to handle tool/function calls and prints results to stdout. -// -// Parameters: -// - bifrost: The Bifrost instance to use for the requests -// - config: Test configuration containing model and parameters -// - ctx: Context for the requests -func setupToolCalls(bifrost *bifrost.Bifrost, config TestConfig, ctx context.Context) { - messages := []string{"What's the weather like in Mumbai?"} - - params := WeatherToolParams - if config.CustomParams != nil { - customParams := *config.CustomParams - if customParams.Tools != nil { - params.Tools = customParams.Tools - } - if customParams.MaxTokens != nil { - params.MaxTokens = customParams.MaxTokens - } - } - - for i, message := range messages { - delay := time.Duration(100*(i+1)) * time.Millisecond - go func(msg string, delay time.Duration, index int) { - time.Sleep(delay) - messages := []schemas.Message{ - { - Role: schemas.RoleUser, - Content: &msg, - }, - } - result, err := bifrost.ChatCompletionRequest(config.Provider, &schemas.BifrostRequest{ - Model: config.ChatModel, - Input: schemas.RequestInput{ - ChatCompletionInput: &messages, - }, - Params: ¶ms, - Fallbacks: config.Fallbacks, - }, ctx) - if err != nil { - fmt.Printf("\nError in %s tool call request %d: %v\n", config.Provider, index+1, err.Error.Message) - } else { - if result.Choices[0].Message.ToolCalls != nil && len(*result.Choices[0].Message.ToolCalls) > 0 { - toolCall := *result.Choices[0].Message.ToolCalls - fmt.Printf("\nπŸ’ %s Tool Call Result %d: %s\n", config.Provider, index+1, toolCall[0].Function.Arguments) - } else { - fmt.Printf("\nπŸ’ %s No tool calls in response %d\n", config.Provider, index+1) - if result.ExtraFields.RawResponse != nil { - fmt.Println("\nRaw JSON Response", result.ExtraFields.RawResponse) - } - } - } - }(message, delay, i) - } -} - -// SetupAllRequests sets up and executes all configured test requests for a provider. -// It coordinates the execution of text completion, chat completion, image, and tool call tests -// based on the provided configuration. -// -// Parameters: -// - bifrost: The Bifrost instance to use for the requests -// - config: Test configuration specifying which tests to run -func SetupAllRequests(bifrost *bifrost.Bifrost, config TestConfig) { - ctx := context.Background() - - if config.SetupText { - setupTextCompletionRequest(bifrost, config, ctx) - } - - setupChatCompletionRequests(bifrost, config, ctx) - - if config.SetupImage { - setupImageTests(bifrost, config, ctx) - } - - if config.SetupToolCalls { - setupToolCalls(bifrost, config, ctx) - } -} diff --git a/core/utils.go b/core/utils.go new file mode 100644 index 0000000000..f9b336cf89 --- /dev/null +++ b/core/utils.go @@ -0,0 +1,30 @@ +package bifrost + +import schemas "github.com/maximhq/bifrost/core/schemas" + +func Ptr[T any](v T) *T { + return &v +} + +// newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false. +// This helper function reduces code duplication when handling non-Bifrost errors. +func newBifrostError(err error) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: err.Error(), + Error: err, + }, + } +} + +// newBifrostErrorFromMsg creates a BifrostError with a custom message. +// This helper function is used for static error messages. +func newBifrostErrorFromMsg(message string) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: message, + }, + } +} diff --git a/docs/fallbacks.md b/docs/fallbacks.md new file mode 100644 index 0000000000..9651796bc3 --- /dev/null +++ b/docs/fallbacks.md @@ -0,0 +1,205 @@ +# Bifrost Fallback System + +Bifrost provides a robust fallback mechanism that allows you to define alternative providers and models to use when the primary provider fails. This ensures high availability and reliability for your AI-powered applications. + +## 1. How Fallbacks Work + +1. When a request is made to a primary provider, Bifrost first attempts to complete the request using that provider +2. If the primary provider fails after all retry attempts, Bifrost automatically tries the fallback providers in the order specified +3. Each fallback provider uses its own retry settings and configuration set in your account implementation +4. The first successful fallback response is returned to the client + +## 2. Configuring Fallbacks + +### Basic Fallback Configuration + +```golang +result, err := bifrost.ChatCompletionRequest( + context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Fallbacks: []schemas.Fallback{ + { + Provider: schemas.Anthropic, + Model: "claude-3-sonnet", + }, + }, + }, +) +``` + +### Multiple Fallbacks + +```golang +result, err := bifrost.ChatCompletionRequest( + context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Fallbacks: []schemas.Fallback{ + { + Provider: schemas.Anthropic, + Model: "claude-3-sonnet", + }, + { + Provider: schemas.Bedrock, + Model: "anthropic.claude-3-sonnet", + }, + { + Provider: schemas.Azure, + Model: "gpt-4", + }, + }, + }, +) +``` + +## 3. Important Considerations + +### Provider Configuration + +- Each fallback provider must be properly configured in your account +- If a fallback provider is not configured, it will be skipped +- Each provider's configuration (retries, timeouts, etc.) is independent + +### Model Compatibility + +- Ensure that the fallback models support the same capabilities as your primary model +- Consider model-specific parameters and limitations +- Verify that the fallback models are available in your account + +### Performance Impact + +- Fallbacks add latency when the primary provider fails +- Consider the order of fallbacks based on: + - Provider reliability + - Model performance + - Cost considerations + - Geographic location + +## 4. Best Practices + +1. **Provider Selection** + + - Choose fallback providers with different infrastructure + - Consider geographic distribution for high availability + - Balance cost and performance in fallback order + +2. **Model Selection** + + - Use models with similar capabilities + - Consider model-specific features (e.g., function calling) + - Account for different token limits and pricing + +3. **Error Handling** + + - Monitor fallback usage to identify provider issues + - Set up alerts for frequent fallback activations (can be done using bifrost's plugin interface) + - Regularly review and update fallback configurations + +4. **Testing** + - Test fallback scenarios in development + - Verify all fallback providers are properly configured + - Simulate provider failures to ensure smooth fallback + +## 5. HTTP Transport Examples + +### Basic HTTP Fallback Request + +```json +POST /v1/chat/completions +{ + "provider": "openai", + "model": "gpt-4", + "input": { + "chat_completion_input": [ + { + "role": "user", + "content": "Hello, how are you?" + } + ] + }, + "fallbacks": [ + { + "provider": "anthropic", + "model": "claude-3-sonnet" + } + ] +} +``` + +### HTTP Request with Multiple Fallbacks + +```json +POST /v1/chat/completions +{ + "provider": "openai", + "model": "gpt-4", + "input": { + "chat_completion_input": [ + { + "role": "user", + "content": "Explain quantum computing" + } + ] + }, + "fallbacks": [ + { + "provider": "anthropic", + "model": "claude-3-sonnet" + }, + { + "provider": "bedrock", + "model": "anthropic.claude-3-sonnet" + }, + { + "provider": "azure", + "model": "gpt-4" + } + ], + "params": { + "temperature": 0.7, + "max_tokens": 1000 + } +} +``` + +### HTTP Response Example + +```json +{ + "id": "chatcmpl-123", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Quantum computing is a type of computing..." + }, + "finish_reason": "stop" + } + ], + "model": "claude-3-sonnet", + "usage": { + "prompt_tokens": 10, + "completion_tokens": 100, + "total_tokens": 110 + }, + "extra_fields": { + "provider": "anthropic", + "latency": 1.234, + "billed_usage": { + "prompt_tokens": 10.0, + "completion_tokens": 100.0 + } + } +} +``` + +Note: The response includes metadata about which provider was used (in this case, the fallback provider "anthropic") and its performance metrics. diff --git a/docs/http-transport-api.md b/docs/http-transport-api.md new file mode 100644 index 0000000000..66234426d8 --- /dev/null +++ b/docs/http-transport-api.md @@ -0,0 +1,845 @@ +# Bifrost HTTP Transport API Reference + +This document provides comprehensive API documentation for the Bifrost HTTP transport, which exposes REST endpoints for text and chat completions using various AI model providers. + +## Base URL + +```text + http://localhost:8080 +``` + +> πŸ”§ **MCP (Model Context Protocol) Integration**: Bifrost HTTP transport includes built-in MCP support for external tool integration. When MCP is configured, tools are automatically discovered and added to model requests. For comprehensive MCP setup and usage, see the [**MCP Integration Guide**](mcp.md) and [**HTTP Transport MCP Configuration**](../transports/README.md#mcp-model-context-protocol-configuration). + +## OpenAPI Specification + +The complete OpenAPI 3.0 specification is available as a JSON file: + +πŸ“„ **[OpenAPI Specification (JSON)](openapi.json)** + +This machine-readable specification can be used with: + +- **Swagger UI** - Interactive API documentation +- **Postman** - Import for API testing +- **Code Generation** - Generate client SDKs in multiple languages +- **API Gateways** - Request/response validation +- **Testing Tools** - Automated API testing + +## Authentication + +API keys are configured through environment variables for each provider. See the [providers documentation](providers.md) for setup instructions. + +## Endpoints + +### 1. Chat Completions + +**POST** `/v1/chat/completions` + +Creates a chat completion using conversational messages. + +#### Request Body + + ```json + { + "provider": "openai", + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "What's the weather like today?" + } + ], + "params": { + "max_tokens": 1000, + "temperature": 0.7, + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + } + }, + "required": ["location"] + } + } + } + ] + }, + "fallbacks": [ + { + "provider": "anthropic", + "model": "claude-3-sonnet-20240229" + } + ] + } + ``` + +#### Request Body with Structured Content (Text and Image) + + ```json + { + "provider": "openai", + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What's happening in this image? What's the weather like?" + }, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/weather-photo.jpg" + } + } + ] + } + ], + "params": { + "max_tokens": 1000, + "temperature": 0.7 + } + } + ``` + +#### Response + + ```json + { + "id": "chatcmpl-123", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "I'd be happy to help you check the weather! However, I'll need to know your location first.", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\": \"user_location\"}" + } + } + ] + }, + "finish_reason": "tool_calls" + } + ], + "model": "gpt-4o", + "created": 1677652288, + "usage": { + "prompt_tokens": 56, + "completion_tokens": 31, + "total_tokens": 87 + }, + "extra_fields": { + "provider": "openai", + "model_params": { + "max_tokens": 1000, + "temperature": 0.7 + }, + "latency": 1.234, + "raw_response": {} + } + } + ``` + +### 2. Text Completions + +**POST** `/v1/text/completions` + +Creates a text completion from a prompt. + +#### Request Body + + ```json + { + "provider": "openai", + "model": "gpt-3.5-turbo-instruct", + "text": "The future of AI is", + "params": { + "max_tokens": 100, + "temperature": 0.7, + "stop_sequences": ["\n\n"] + }, + "fallbacks": [ + { + "provider": "cohere", + "model": "command" + } + ] + } + ``` + +#### Response + + ```json + { + "id": "cmpl-123", + "object": "text.completion", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The future of AI is incredibly promising, with advances in machine learning..." + }, + "finish_reason": "stop" + } + ], + "model": "gpt-3.5-turbo-instruct", + "created": 1677652288, + "usage": { + "prompt_tokens": 5, + "completion_tokens": 95, + "total_tokens": 100 + }, + "extra_fields": { + "provider": "openai", + "model_params": { + "max_tokens": 100, + "temperature": 0.7 + }, + "latency": 0.856, + "raw_response": {} + } + } + ``` + +### 3. MCP Tool Execution + +**POST** `/v1/mcp/tool/execute` + +Executes MCP (Model Context Protocol) tools that have been configured in Bifrost. This endpoint is used to execute tool calls returned by AI models during conversations. + +> **Note**: This endpoint requires MCP to be configured in Bifrost. See [MCP Integration Guide](mcp.md) for setup instructions. + +#### Request Body + +```json +{ + "type": "function", + "id": "toolu_01Vmq4gaU6tSy7ZRKVC7U2fg", + "function": { + "name": "google_search", + "arguments": "{\"gl\":\"us\",\"hl\":\"en\",\"num\":5,\"q\":\"San Francisco news yesterday\",\"tbs\":\"qdr:d\"}" + } +} +``` + +#### Response + +```json +{ + "role": "tool", + "content": "{\n \"searchParameters\": {\n \"q\": \"San Francisco news yesterday\",\n \"gl\": \"us\",\n \"hl\": \"en\",\n \"type\": \"search\",\n \"num\": 5,\n \"tbs\": \"qdr:d\",\n \"engine\": \"google\"\n },\n \"organic\": [\n {\n \"title\": \"San Francisco Chronicle Β· Giants' today\"\n },\n {\n \"query\": \"s.f. chronicle e edition\"\n }\n ],\n \"credits\": 1\n}", + "tool_call_id": "toolu_01Vmq4gaU6tSy7ZRKVC7U2fg" +} +``` + +#### Multi-Turn Tool Workflow + +The typical workflow for using MCP tools involves: + +1. **Send chat completion request** β†’ AI responds with `tool_calls` +2. **Execute tools via `/v1/mcp/tool/execute`** β†’ Get tool result messages +3. **Add tool results to conversation** β†’ Send back for final response + +```bash +# Step 1: Chat completion (AI decides to use tools) +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "model": "gpt-4o-mini", + "messages": [ + {"role": "user", "content": "Search for San Francisco news from yesterday"} + ] + }' + +# Step 2: Execute the tool call returned by AI +curl -X POST http://localhost:8080/v1/mcp/tool/execute \ + -H "Content-Type: application/json" \ + -d '{ + "type": "function", + "id": "toolu_01Vmq4gaU6tSy7ZRKVC7U2fg", + "function": { + "name": "google_search", + "arguments": "{\"q\":\"San Francisco news yesterday\"}" + } + }' + +# Step 3: Continue conversation with tool results +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "model": "gpt-4o-mini", + "messages": [ + {"role": "user", "content": "Search for San Francisco news from yesterday"}, + {"role": "assistant", "tool_calls": [...]}, + {"role": "tool", "content": "...", "tool_call_id": "toolu_01Vmq4gaU6tSy7ZRKVC7U2fg"} + ] + }' +``` + +For detailed MCP setup and multi-turn conversation examples, see [Multi-Turn Conversations with MCP Tools](../transports/README.md#multi-turn-conversations-with-mcp-tools). + +### 4. Metrics + +**GET** `/metrics` + +Returns Prometheus metrics for monitoring and observability. + +## Schema Definitions + +### CompletionRequest + +The main request object for both chat and text completions. + +| Field | Type | Required | Description | +| ----------- | ------------------------------------- | -------- | ------------------------------------------------------------------------------------------------------ | +| `provider` | `string` | βœ… | AI model provider (`openai`, `anthropic`, `azure`, `bedrock`, `cohere`, `vertex`, `mistral`, `ollama`) | +| `model` | `string` | βœ… | Model identifier (provider-specific) | +| `messages` | [`BifrostMessage[]`](#bifrostmessage) | βœ…\* | Array of chat messages (required for chat completions) | +| `text` | `string` | βœ…\* | Text prompt (required for text completions) | +| `params` | [`ModelParameters`](#modelparameters) | ❌ | Model parameters and configuration | +| `fallbacks` | [`Fallback[]`](#fallback) | ❌ | Fallback providers and models | + +\*Either `messages` or `text` is required depending on the endpoint. + +### BifrostMessage + +Represents a message in a chat conversation. + +| Field | Type | Required | Description | +| -------------- | --------------------------------------------- | -------- | ------------------------------------------------------------------------------- | +| `role` | `string` | βœ… | Message role (`user`, `assistant`, `system`, `tool`) | +| `content` | `string` or [`ContentBlock[]`](#contentblock) | βœ… | Message content - can be simple text or structured content with text and images | +| `tool_call_id` | `string` | ❌ | ID of the tool call (for tool messages) | +| `tool_calls` | [`ToolCall[]`](#toolcall) | ❌ | Tool calls made by assistant | +| `refusal` | `string` | ❌ | Refusal message from assistant | +| `annotations` | `Annotation[]` | ❌ | Message annotations | +| `thought` | `string` | ❌ | Assistant's internal thought process | + +### ContentBlock + +Represents a structured content block in a message (for text and image content). + +| Field | Type | Required | Description | +| ----------- | ----------------------------------- | -------- | ---------------------------------------------- | +| `type` | `string` | βœ… | Content type (`text` or `image_url`) | +| `text` | `string` | ❌ | Text content (required when type is `text`) | +| `image_url` | [`ImageURLStruct`](#imageurlstruct) | ❌ | Image data (required when type is `image_url`) | + +### ImageURLStruct + +Represents image data in a message. + +| Field | Type | Required | Description | +| -------- | -------- | -------- | ------------------------------------------ | +| `url` | `string` | βœ… | Image URL or data URI | +| `detail` | `string` | ❌ | Image detail level (`low`, `high`, `auto`) | + +### ModelParameters + +Configuration parameters for model behavior. + +| Field | Type | Description | +| --------------------- | --------------------------- | --------------------------------------- | +| `temperature` | `number` | Controls randomness (0.0-2.0) | +| `top_p` | `number` | Nucleus sampling parameter (0.0-1.0) | +| `top_k` | `integer` | Top-k sampling parameter | +| `max_tokens` | `integer` | Maximum tokens to generate | +| `stop_sequences` | `string[]` | Sequences that stop generation | +| `presence_penalty` | `number` | Penalizes repeated tokens (-2.0 to 2.0) | +| `frequency_penalty` | `number` | Penalizes frequent tokens (-2.0 to 2.0) | +| `tools` | [`Tool[]`](#tool) | Available tools for the model | +| `tool_choice` | [`ToolChoice`](#toolchoice) | How tools should be chosen | +| `parallel_tool_calls` | `boolean` | Enable parallel tool execution | + +### Tool + +Defines a function that the model can call. + +| Field | Type | Required | Description | +| ---------- | ----------------------- | -------- | --------------------------------------- | +| `id` | `string` | ❌ | Unique tool identifier | +| `type` | `string` | βœ… | Tool type (currently only `"function"`) | +| `function` | [`Function`](#function) | βœ… | Function definition | + +### Function + +Defines the function details for a tool. + +| Field | Type | Required | Description | +| ------------- | ------------------------------------------- | -------- | -------------------------- | +| `name` | `string` | βœ… | Function name | +| `description` | `string` | βœ… | Function description | +| `parameters` | [`FunctionParameters`](#functionparameters) | βœ… | Function parameters schema | + +### FunctionParameters + +JSON Schema defining function parameters. + +| Field | Type | Required | Description | +| ------------- | ---------- | -------- | ----------------------------------- | +| `type` | `string` | βœ… | Parameter type (usually `"object"`) | +| `description` | `string` | ❌ | Parameter description | +| `properties` | `object` | ❌ | Parameter properties (JSON Schema) | +| `required` | `string[]` | ❌ | Required parameter names | +| `enum` | `string[]` | ❌ | Enum values for parameters | + +### ToolChoice + +Specifies how the model should choose tools. + +| Field | Type | Required | Description | +| ---------- | ------------------------------------------- | -------- | ----------------------------------------------------------- | +| `type` | `string` | βœ… | Choice type (`none`, `auto`, `any`, `function`, `required`) | +| `function` | [`ToolChoiceFunction`](#toolchoicefunction) | ❌ | Specific function to call (when type is `function`) | + +### ToolChoiceFunction + +Specifies a particular function to call. + +| Field | Type | Required | Description | +| ------ | -------- | -------- | ---------------------------- | +| `name` | `string` | βœ… | Name of the function to call | + +### Fallback + +Defines a fallback provider and model. + +| Field | Type | Required | Description | +| ---------- | -------- | -------- | ---------------------- | +| `provider` | `string` | βœ… | Fallback provider name | +| `model` | `string` | βœ… | Fallback model name | + +### BifrostResponse + +The response object returned by both endpoints. + +| Field | Type | Description | +| -------------------- | ----------------------------------------------------------- | ------------------------------------------------------ | +| `id` | `string` | Unique response identifier | +| `object` | `string` | Response type (`chat.completion` or `text.completion`) | +| `choices` | [`BifrostResponseChoice[]`](#bifrostresponsechoice) | Array of completion choices | +| `model` | `string` | Model used for generation | +| `created` | `integer` | Unix timestamp of creation | +| `service_tier` | `string` | Service tier used | +| `system_fingerprint` | `string` | System fingerprint | +| `usage` | [`LLMUsage`](#llmusage) | Token usage statistics | +| `extra_fields` | [`BifrostResponseExtraFields`](#bifrostresponseextrafields) | Additional Bifrost-specific data | + +### BifrostResponseChoice + +A single completion choice. + +| Field | Type | Description | +| --------------- | ----------------------------------- | ----------------------------------- | +| `index` | `integer` | Choice index | +| `message` | [`BifrostMessage`](#bifrostmessage) | The completion message | +| `finish_reason` | `string` | Reason completion stopped | +| `stop` | `string` | Stop sequence that ended generation | +| `log_probs` | `LogProbs` | Log probabilities (if requested) | + +### LLMUsage + +Token usage statistics. + +| Field | Type | Description | +| --------------------------- | ------------------------- | ----------------------------------- | +| `prompt_tokens` | `integer` | Tokens in the prompt | +| `completion_tokens` | `integer` | Tokens in the completion | +| `total_tokens` | `integer` | Total tokens used | +| `completion_tokens_details` | `CompletionTokensDetails` | Detailed completion token breakdown | + +### BifrostResponseExtraFields + +Additional Bifrost-specific response data. + +| Field | Type | Description | +| -------------- | ------------------------------------- | ------------------------------- | +| `provider` | `string` | Provider used for the request | +| `model_params` | [`ModelParameters`](#modelparameters) | Parameters used for the request | +| `latency` | `number` | Request latency in seconds | +| `chat_history` | [`BifrostMessage[]`](#bifrostmessage) | Full conversation history | +| `billed_usage` | `BilledLLMUsage` | Billing usage information | +| `raw_response` | `object` | Raw provider response | + +### ToolCall + +Represents a tool call made by the assistant. + +| Field | Type | Description | +| ---------- | ------------------------------- | --------------------------- | +| `id` | `string` | Unique tool call identifier | +| `type` | `string` | Tool call type (`function`) | +| `function` | [`FunctionCall`](#functioncall) | Function call details | + +### FunctionCall + +Details of a function call. + +| Field | Type | Description | +| ----------- | -------- | --------------------------------- | +| `name` | `string` | Function name | +| `arguments` | `string` | JSON string of function arguments | + +### BifrostError + +Error response format. + +| Field | Type | Description | +| ------------------ | --------------------------- | ------------------------------------- | +| `event_id` | `string` | Unique error event ID | +| `type` | `string` | Error type | +| `is_bifrost_error` | `boolean` | Whether error originated from Bifrost | +| `status_code` | `integer` | HTTP status code | +| `error` | [`ErrorField`](#errorfield) | Detailed error information | + +### ErrorField + +Detailed error information. + +| Field | Type | Description | +| ---------- | -------- | ------------------------------- | +| `type` | `string` | Error type | +| `code` | `string` | Error code | +| `message` | `string` | Human-readable error message | +| `param` | `any` | Parameter that caused the error | +| `event_id` | `string` | Error event ID | + +## Supported Providers + +| Provider | Key | +| ---------------- | ----------- | +| OpenAI | `openai` | +| Anthropic | `anthropic` | +| Azure OpenAI | `azure` | +| AWS Bedrock | `bedrock` | +| Cohere | `cohere` | +| Google Vertex AI | `vertex` | +| Mistral | `mistral` | +| Ollama | `ollama` | + +## Error Codes + +| Status Code | Description | +| ----------- | --------------------------------------------------------------- | +| `400` | Bad Request - Invalid request format or missing required fields | +| `401` | Unauthorized - Invalid or missing API key | +| `429` | Too Many Requests - Rate limit exceeded | +| `500` | Internal Server Error - Server or provider error | +| `502` | Bad Gateway - Provider service unavailable | +| `503` | Service Unavailable - Bifrost service temporarily unavailable | + +## Rate Limiting + +Rate limiting is handled by the individual providers. Bifrost respects provider rate limits and will return appropriate error responses when limits are exceeded. + +## Examples + +### Simple Chat + + ```bash + curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "model": "gpt-4o", + "messages": [ + {"role": "user", "content": "Hello, world!"} + ] + }' + ``` + +### Chat with Images + + ```bash + curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do you see in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} + ] + } + ] + }' + ``` + +### Chat with Tools + + ```bash + curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "model": "gpt-4o", + "messages": [ + {"role": "user", "content": "What'\''s the weather in San Francisco?"} + ], + "params": { + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get current weather", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"} + }, + "required": ["location"] + } + } + } + ], + "tool_choice": {"type": "function", "function": {"name": "get_weather"}} + } + }' + ``` + +### Text Completion + + ```bash + curl -X POST http://localhost:8080/v1/text/completions \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "model": "gpt-3.5-turbo-instruct", + "text": "The benefits of artificial intelligence include", + "params": { + "max_tokens": 150, + "temperature": 0.7 + } + }' + ``` + +### Using Fallbacks + + ```bash + curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "model": "gpt-4o", + "messages": [ + {"role": "user", "content": "Explain quantum computing"} + ], + "fallbacks": [ + {"provider": "anthropic", "model": "claude-3-sonnet-20240229"}, + {"provider": "cohere", "model": "command"} + ] + }' + ``` + +## Integration Examples + +### Python + + ```python + import requests + + def chat_completion(messages, provider="openai", model="gpt-4o"): + response = requests.post( + "http://localhost:8080/v1/chat/completions", + json={ + "provider": provider, + "model": model, + "messages": messages, + "params": {"max_tokens": 1000} + } + ) + return response.json() + + # Simple text message + result = chat_completion([ + {"role": "user", "content": "Hello, how are you?"} + ]) + print(result["choices"][0]["message"]["content"]) + + # Structured content with image + result = chat_completion([ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} + ] + } + ]) + print(result["choices"][0]["message"]["content"]) + ``` + +### Node.js + + ```javascript + const axios = require("axios"); + + async function chatCompletion(messages, provider = "openai", model = "gpt-4o") { + try { + const response = await axios.post( + "http://localhost:8080/v1/chat/completions", + { + provider, + model, + messages, + params: { max_tokens: 1000 }, + } + ); + return response.data; + } catch (error) { + console.error("Error:", error.response?.data || error.message); + throw error; + } + } + + // Usage with structured content + chatCompletion([ + { + role: "user", + content: [ + { type: "text", text: "Describe this image" }, + { + type: "image_url", + image_url: { url: "https://example.com/image.jpg" }, + }, + ], + }, + ]).then((result) => { + console.log(result.choices[0].message.content); + }); + ``` + +### Go + + ```go + package main + + import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + ) + + type ChatRequest struct { + Provider string `json:"provider"` + Model string `json:"model"` + Messages []BifrostMessage `json:"messages"` + Params *ModelParameters `json:"params,omitempty"` + } + + type BifrostMessage struct { + Role string `json:"role"` + Content interface{} `json:"content"` // Can be string or []ContentBlock + } + + type ContentBlock struct { + Type string `json:"type"` + Text *string `json:"text,omitempty"` + ImageURL *ImageURLStruct `json:"image_url,omitempty"` + } + + type ImageURLStruct struct { + URL string `json:"url"` + Detail *string `json:"detail,omitempty"` + } + + type ModelParameters struct { + MaxTokens *int `json:"max_tokens,omitempty"` + } + + func chatCompletion(messages []BifrostMessage) error { + reqBody := ChatRequest{ + Provider: "openai", + Model: "gpt-4o", + Messages: messages, + Params: &ModelParameters{MaxTokens: intPtr(1000)}, + } + + jsonData, _ := json.Marshal(reqBody) + resp, err := http.Post( + "http://localhost:8080/v1/chat/completions", + "application/json", + bytes.NewBuffer(jsonData), + ) + if err != nil { + return err + } + defer resp.Body.Close() + + var result map[string]interface{} + json.NewDecoder(resp.Body).Decode(&result) + fmt.Println(result) + return nil + } + + func intPtr(i int) *int { return &i } + ``` + +## Configuration + +The HTTP transport can be configured via command-line flags and environment variables: + + ```bash + # Using environment variables for plugin configuration (optional) + export MAXIM_LOG_REPO_ID=your-repo-id + + ./bifrost-http \ + -config config.json \ + -port 8080 \ + -pool-size 300 \ + -drop-excess-requests \ + -plugins maxim \ + -prometheus-labels env,service + ``` + +### Configuration Flags + +| Flag | Description | Default | +| ----------------------- | -------------------------------- | -------- | +| `-config` | Path to configuration file | Required | +| `-port` | Server port | `8080` | +| `-pool-size` | Initial connection pool size | `300` | +| `-drop-excess-requests` | Drop requests when queue is full | `false` | +| `-plugins` | Comma-separated list of plugins | None | +| `-prometheus-labels` | Additional Prometheus labels | None | + +### Environment Variables for Plugins (Optional) + +Plugin-specific configuration should be provided via environment variables: + +| Environment Variable | Description | Default | +| -------------------- | --------------------------- | ------- | +| `MAXIM_LOG_REPO_ID` | Maxim logging repository ID | None | + +## Monitoring + +The `/metrics` endpoint provides Prometheus-compatible metrics for monitoring: + +- Request counts by provider, model, and status +- Request latency histograms +- Token usage metrics +- Error rates and types +- Connection pool statistics diff --git a/docs/logger.md b/docs/logger.md new file mode 100644 index 0000000000..1fd73e9c6e --- /dev/null +++ b/docs/logger.md @@ -0,0 +1,123 @@ +# Bifrost Logging System + +Bifrost provides a flexible logging system that allows you to either use the built-in logger or implement your own custom logger. + +## 1. Log Levels + +Bifrost supports four log levels: + +- `debug`: Detailed debugging information, typically only needed during development +- `info`: General informational messages about normal operation +- `warn`: Potentially harmful situations that don't prevent normal operation +- `error`: Serious problems that need attention and may prevent normal operation + +## 2. Using the Default Logger + +Bifrost comes with a built-in logger that writes to stdout/stderr with formatted timestamps and log levels. It's used by default if no custom logger is provided. + +### Default Configuration + +```golang +client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &yourAccount, + // Logger not specified, will use default logger with info level +}) +``` + +### Customizing Default Logger Level + +```golang +client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &yourAccount, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), // Set to debug level +}) +``` + +### Default Logger Output Format + +The default logger formats messages as: + +```text + [BIFROST-TIMESTAMP] LEVEL: message + [BIFROST-TIMESTAMP] ERROR: (error: error_message) +``` + +Example outputs: + +```text + [BIFROST-2024-03-20T10:15:30Z] INFO: Initializing provider OpenAI + [BIFROST-2024-03-20T10:15:31Z] ERROR: (error: failed to connect to provider) +``` + +## 3. Implementing a Custom Logger + +You can implement your own logger by following the `Logger` interface: + +```golang +type Logger interface { + // Debug logs a debug-level message + Debug(msg string) + + // Info logs an info-level message + Info(msg string) + + // Warn logs a warning-level message + Warn(msg string) + + // Error logs an error-level message + Error(err error) +} +``` + +### Example Custom Logger Implementation + +```golang +type CustomLogger struct { + // Your logger fields +} + +func (l *CustomLogger) Debug(msg string) { + // Implement debug logging +} + +func (l *CustomLogger) Info(msg string) { + // Implement info logging +} + +func (l *CustomLogger) Warn(msg string) { + // Implement warning logging +} + +func (l *CustomLogger) Error(err error) { + // Implement error logging +} + +// Using the custom logger +client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &yourAccount, + Logger: &CustomLogger{}, +}) +``` + +## 4. Best Practices + +1. **Log Level Selection** + + - Use `debug` for development and troubleshooting + - Use `info` for production monitoring + - Use `warn` for potential issues that don't affect functionality + - Use `error` for critical issues that need immediate attention + +2. **Custom Logger Implementation** + + - Ensure thread safety if your logger is used concurrently + - Consider implementing log rotation for production environments + - Include relevant context in log messages + - Handle errors appropriately in your logging implementation + +3. **Performance Considerations** + - Avoid expensive operations in logging methods + - Consider using async logging for high-throughput scenarios + - Be mindful of log volume in production environments + +Remember that logging is crucial for monitoring and debugging your Bifrost implementation. Choose the appropriate logging strategy based on your environment and requirements. diff --git a/docs/mcp.md b/docs/mcp.md new file mode 100644 index 0000000000..9bb0667d19 --- /dev/null +++ b/docs/mcp.md @@ -0,0 +1,1470 @@ +# Bifrost MCP Integration + +The **Bifrost MCP (Model Context Protocol) Integration** provides seamless connectivity between Bifrost and MCP servers, enabling dynamic tool discovery, registration, and execution from both local and external MCP sources. + +## Table of Contents + +- [Overview](#overview) +- [Features](#features) +- [Quick Start](#quick-start) +- [HTTP Transport Usage](#http-transport-usage) +- [Configuration](#configuration) +- [Usage Examples](#usage-examples) +- [Implementing Chat Conversations with MCP Tools](#implementing-chat-conversations-with-mcp-tools) +- [Architecture](#architecture) +- [API Reference](#api-reference) +- [Advanced Features](#advanced-features) +- [Troubleshooting](#troubleshooting) + +## Overview + +The MCP Integration acts as a bridge between Bifrost and the Model Context Protocol ecosystem, allowing you to: + +- **Host Local Tools**: Register Go functions as MCP tools directly in Bifrost core +- **Connect to External MCP Servers**: Integrate with existing MCP servers via HTTP or STDIO +- **Automatic Tool Discovery**: Automatically discover and register tools from connected MCP servers +- **Dynamic Tool Execution**: Seamless tool execution integrated into Bifrost's request flow +- **Client Filtering**: Control which MCP clients are active per request + +## Features + +### πŸ”§ **Tool Management** + +- **Local Tool Hosting**: Register typed Go functions as MCP tools +- **External Tool Integration**: Connect to HTTP or STDIO-based MCP servers +- **Dynamic Discovery**: Automatically discover tools from external servers +- **Tool Filtering**: Include/exclude specific tools or clients per request + +### πŸ”’ **Security & Control** + +- **Client Filtering**: Control which MCP clients are active per request +- **Tool Filtering**: Configure which tools are available from each client +- **Safe Defaults**: Comprehensive tool management and filtering + +### πŸ”Œ **Connection Types** + +- **HTTP**: Connect to web-based MCP servers with streaming support +- **STDIO**: Launch and communicate with command-line MCP tools +- **SSE**: Connect to Server-Sent Events based MCP services +- **Process Management**: Automatic cleanup of STDIO processes and SSE connections + +## Quick Start + +### 1. Basic Setup + +```go +package main + +import ( + "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +func main() { + // Create MCP configuration + mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{ + { + Name: "weather-service", + ToolsToSkip: []string{}, // No tools to skip + }, + }, + } + + // Create Bifrost instance with MCP integration + bifrost, err := bifrost.Init(schemas.BifrostConfig{ + Account: accountImplementation, + MCPConfig: mcpConfig, // MCP is configured directly in Bifrost + Logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo), + }) + if err != nil { + panic(err) + } + defer bifrost.Cleanup() +} +``` + +### 2. Register a Simple Tool + +```go +// Define tool arguments structure +type EchoArgs struct { + Message string `json:"message"` +} + +// Create tool schema +toolSchema := schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: "echo", + Description: "Echo a message back to the user", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to echo back", + }, + }, + Required: []string{"message"}, + }, + }, +} + +// Register the tool with Bifrost +err := bifrost.RegisterMCPTool("echo", "Echo a message", + func(args any) (string, error) { + // Type assertion for arguments + if echoArgs, ok := args.(map[string]interface{}); ok { + if message, exists := echoArgs["message"].(string); exists { + return fmt.Sprintf("Echo: %s", message), nil + } + } + return "", fmt.Errorf("invalid arguments") + }, toolSchema) +``` + +### 3. Connect to External MCP Server + +```go +mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{ + // HTTP-based MCP server + { + Name: "weather-service", + ConnectionType: schemas.MCPConnectionTypeHTTP, + ConnectionString: &[]string{"http://localhost:3000"}[0], + ToolsToSkip: []string{}, // No tools to skip + }, + // STDIO-based MCP tool + { + Name: "filesystem-tools", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"@modelcontextprotocol/server-filesystem", "/tmp"}, + }, + ToolsToSkip: []string{"rm", "delete"}, // Skip dangerous operations + }, + }, +} +``` + +## HTTP Transport Usage + +This section covers HTTP-specific MCP setup and usage patterns for integrating tools via Bifrost HTTP Transport. + +> πŸ“– **For detailed HTTP transport setup and configuration examples, see** [**Bifrost Transports Documentation**](../transports/README.md#mcp-model-context-protocol-configuration). + +### HTTP Transport Configuration + +Configure MCP in your JSON configuration file when using Bifrost HTTP Transport: + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o-mini"], + "weight": 1.0 + } + ] + } + }, + "mcp": { + "client_configs": [ + { + "name": "filesystem", + "connection_type": "stdio", + "stdio_config": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + "envs": [] + }, + "tools_to_skip": ["rm", "delete"], + "tools_to_execute": [] + }, + { + "name": "web-search", + "connection_type": "http", + "connection_string": "http://localhost:3001/mcp", + "tools_to_skip": [], + "tools_to_execute": [] + }, + { + "name": "real-time-data", + "connection_type": "sse", + "connection_string": "http://localhost:3002/sse", + "tools_to_skip": [], + "tools_to_execute": [] + } + ] + } +} +``` + +### Starting HTTP Transport with MCP + +```bash +# Start Bifrost HTTP server with MCP configuration +bifrost-http -config config.json -port 8080 -pool-size 300 + +# Or using Docker +docker run -p 8080:8080 \ + -v ./config.json:/app/config.json \ + -e OPENAI_API_KEY \ + bifrost-transports +``` + +### HTTP API Endpoints with MCP Tools + +When MCP is configured, tools are automatically added to chat completion requests. The HTTP transport provides two key endpoints: + +- `POST /v1/chat/completions` - Chat with automatic tool discovery +- `POST /v1/mcp/tool/execute` - Execute specific tool calls + +#### 1. Standard Chat Completion (Tools Auto-Added) + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "model": "gpt-4o-mini", + "messages": [ + {"role": "user", "content": "List the files in /tmp directory"} + ] + }' +``` + +**Response** (AI decides to use tools): + +```json +{ + "data": { + "choices": [ + { + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "list_files", + "arguments": "{\"path\": \"/tmp\"}" + } + } + ] + } + } + ] + } +} +``` + +#### 2. Multi-Turn Tool Execution Flow + +> πŸ“‹ **For complete multi-turn conversation examples with tool execution, see** [**HTTP Transport Multi-Turn Examples**](../transports/README.md#multi-turn-conversations-with-mcp-tools). + +The typical flow involves: + +1. **Initial Request** β†’ AI responds with tool calls +2. **Tool Execution** β†’ Use Bifrost's `/v1/mcp/tool/execute` endpoint +3. **Continue Conversation** β†’ Send conversation history with tool results +4. **Final Response** β†’ AI provides final answer + +```bash +# Step 2: Execute tool using Bifrost's MCP endpoint +curl -X POST http://localhost:8080/v1/mcp/tool/execute \ + -H "Content-Type: application/json" \ + -d ' { + "id": "call_abc123", + "type": "function", + "function": { + "name": "list_files", + "arguments": "{\"path\": \"/tmp\"}" + } + }' + +# Response: {"role": "tool", "content": "config.json\nreadme.txt\ndata.csv", "tool_call_id": "call_abc123"} + +# Step 3: Continue conversation with tool results +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "model": "gpt-4o-mini", + "messages": [ + {"role": "user", "content": "List the files in /tmp directory"}, + { + "role": "assistant", + "tool_calls": [{ + "id": "call_abc123", + "type": "function", + "function": { + "name": "list_files", + "arguments": "{\"path\": \"/tmp\"}" + } + }] + }, + { + "role": "tool", + "content": "config.json\nreadme.txt\ndata.csv", + "tool_call_id": "call_abc123" + } + ] + }' +``` + +### HTTP Headers for MCP Client Filtering + +Control which MCP clients are active per request using HTTP headers: + +```bash +# Include only specific MCP clients +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "X-MCP-Include-Clients: filesystem,weather" \ + -d '{...}' + +# Exclude specific MCP clients +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "X-MCP-Exclude-Clients: dangerous-tools" \ + -d '{...}' +``` + +### Tool Execution with HTTP Transport + +The HTTP transport provides a dedicated endpoint for tool execution: + +**Endpoint:** `POST /v1/mcp/tool/execute` + +**Workflow:** + +1. **Send chat completion request** β†’ Receive tool calls in response +2. **Execute tools via `/v1/mcp/tool/execute`** β†’ Get tool result messages +3. **Add tool results to conversation** β†’ Continue chat completion +4. **Receive final response** β†’ Complete conversation + +**Request Format:** (Tool Call Result) + +```json +{ + "id": "call_abc123", + "type": "function", + "function": { + "name": "tool_name", + "arguments": "{\"param\": \"value\"}" + } +} +``` + +**Response Format:** + +```json +{ + "role": "tool", + "content": "tool execution result", + "tool_call_id": "call_abc123" +} +``` + +This approach gives you control over when to execute tools while leveraging Bifrost's MCP infrastructure for the actual execution. + +### Environment Variables + +Set environment variables for MCP tools that require them: + +```bash +export OPENAI_API_KEY="your-api-key" +export FILESYSTEM_ROOT="/allowed/path" +export SEARCH_API_KEY="your-search-key" + +# Start HTTP transport +bifrost-http -config config.json +``` + +## Configuration + +### Bifrost Configuration with MCP + +```go +type BifrostConfig struct { + Account Account + Plugins []Plugin + Logger Logger + InitialPoolSize int + DropExcessRequests bool + MCPConfig *MCPConfig `json:"mcp_config,omitempty"` // MCP configuration +} +``` + +### MCP Configuration + +```go +type MCPConfig struct { + ClientConfigs []MCPClientConfig `json:"client_configs,omitempty"` // MCP client configurations (connection + filtering) +} +``` + +### Client Configuration (Connection + Tool Filtering) + +```go +type MCPClientConfig struct { + Name string `json:"name"` // Client name + ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, or SSE) + ConnectionString *string `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) + StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) + ToolsToSkip []string `json:"tools_to_skip,omitempty"` // Tools to exclude from this client + ToolsToExecute []string `json:"tools_to_execute,omitempty"` // Tools to include from this client (if specified, only these are used) +} +``` + +### Connection Types + +```go +type MCPConnectionType string + +const ( + MCPConnectionTypeHTTP MCPConnectionType = "http" // HTTP-based MCP connection (streamable) + MCPConnectionTypeSTDIO MCPConnectionType = "stdio" // STDIO-based MCP connection + MCPConnectionTypeSSE MCPConnectionType = "sse" // Server-Sent Events MCP connection +) +``` + +### STDIO Configuration + +```go +type MCPStdioConfig struct { + Command string `json:"command"` // Executable command to run + Args []string `json:"args"` // Command line arguments + Envs []string `json:"envs"` // Environment variables required +} +``` + +### Example Configuration + +```go +mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{ + { + Name: "weather-service", + ConnectionType: schemas.MCPConnectionTypeHTTP, + ConnectionString: &[]string{"http://localhost:3000"}[0], + ToolsToExecute: []string{"get_weather", "get_forecast"}, // Only these tools + }, + { + Name: "filesystem-tools", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"@modelcontextprotocol/server-filesystem", "/home/user/documents"}, + }, + ToolsToSkip: []string{"rm", "delete"}, // Skip dangerous operations + }, + { + Name: "local-tools-only", + // No ConnectionType means this client is for tool filtering only + // (for tools registered via RegisterMCPTool) + ToolsToExecute: []string{"echo", "calculate"}, + }, + }, +} +``` + +## Usage Examples + +### Example 1: File System Tools + +```go +mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{ + { + Name: "filesystem", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"@modelcontextprotocol/server-filesystem", "/home/user/documents"}, + }, + ToolsToExecute: []string{"read_file", "list_files"}, // Read-only operations + }, + }, +} + +bifrost, err := bifrost.Init(schemas.BifrostConfig{ + Account: account, + MCPConfig: mcpConfig, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo), +}) +``` + +### Example 2: Weather Service Integration + +```go +// Register weather tool +weatherSchema := schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: "get_weather", + Description: "Get current weather for a location", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "City name or coordinates", + }, + "units": map[string]interface{}{ + "type": "string", + "description": "Temperature units (celsius/fahrenheit)", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, +} + +err := bifrost.RegisterMCPTool("get_weather", "Get current weather", + func(args any) (string, error) { + // Extract arguments + argMap := args.(map[string]interface{}) + location := argMap["location"].(string) + units := "celsius" // default + if u, ok := argMap["units"].(string); ok { + units = u + } + + // Call external weather API + weather, err := getWeatherData(location, units) + if err != nil { + return "", err + } + return formatWeatherResponse(weather), nil + }, weatherSchema) +``` + +### Example 3: Client Filtering in Requests + +```go +// Create context with client filtering +ctx := context.Background() +ctx = context.WithValue(ctx, "mcp_include_clients", []string{"weather-service"}) +// Only tools from weather-service will be available + +ctx = context.WithValue(ctx, "mcp_exclude_clients", []string{"filesystem"}) +// All tools except filesystem tools will be available + +// Use in Bifrost request +request := &schemas.BifrostRequest{ + Provider: "openai", + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: &[]string{"What's the weather like today?"}[0], + }, + }, + }, + }, +} + +response, err := bifrost.ChatCompletionRequest(ctx, request) +``` + +> 🌐 **HTTP Transport Users**: When using Bifrost HTTP transport, use HTTP headers instead of context values: `X-MCP-Include-Clients` and `X-MCP-Exclude-Clients`. See [HTTP Headers for MCP Client Filtering](#http-headers-for-mcp-client-filtering). + +## Implementing Chat Conversations with MCP Tools + +This section explains how to build chat applications that leverage MCP tools using the Bifrost Go package. You'll learn the key patterns for tool call handling, conversation management, and implementing your own tool approval logic. + +> 🌐 **For HTTP Transport usage with MCP tools, see [HTTP Transport Usage](#http-transport-usage) and [Multi-Turn Conversations with MCP Tools](../transports/README.md#multi-turn-conversations-with-mcp-tools).** + +### Why You Control Tool Execution + +**Bifrost does NOT automatically execute tools for you.** Instead, it: + +1. **Discovers and registers** MCP tools from your configured clients +2. **Adds tools to LLM requests** automatically +3. **Provides the infrastructure** to execute tools when the LLM requests them +4. **Leaves the execution logic to you** - giving you full control over when and how tools run + +This design gives you complete control over: + +- **Security**: You decide which tools to run and when +- **User approval**: You can implement approval flows +- **Error handling**: You control how failures are handled +- **Logging and monitoring**: You can track all tool usage + +### MCP Tool Execution Flow + +The following diagram shows the complete flow from user input to tool execution, highlighting where **you** control the process: + +```mermaid +flowchart TD + A["πŸ‘€ User Message
\"List files in current directory\""] --> B["πŸ€– Bifrost Core"] + + B --> C["πŸ”§ MCP Manager
Auto-discovers and adds
available tools to request"] + + C --> D["🌐 LLM Provider
(OpenAI, Anthropic, etc.)"] + + D --> E{"πŸ” Response contains
tool_calls?"} + + E -->|No| F["βœ… Final Response
Display to user"] + + E -->|Yes| G["πŸ“ Add assistant message
with tool_calls to history"] + + G --> H["πŸ›‘οΈ YOUR EXECUTION LOGIC
(Security, Approval, Logging)"] + + H --> I{"πŸ€” User Decision Point
Execute this tool?"} + + I -->|Deny| J["❌ Create denial result
Add to conversation history"] + + I -->|Approve| K["βš™οΈ client.ExecuteMCPTool()
Bifrost executes via MCP"] + + K --> L["πŸ“Š Tool Result
Add to conversation history"] + + J --> M["πŸ”„ Continue conversation loop
Send updated history back to LLM"] + L --> M + + M --> D + + style A fill:#e1f5fe + style F fill:#e8f5e8 + style H fill:#fff3e0 + style I fill:#fce4ec + style K fill:#f3e5f5 +``` + +### Basic Chat Loop Pattern + +Here's the core pattern for handling tool-enabled conversations: + +```go +func processChatWithTools(client *bifrost.Bifrost, history []schemas.BifrostMessage) { + maxIterations := 10 // Prevent infinite loops + + for iteration := 0; iteration < maxIterations; iteration++ { + // 1. Send conversation to LLM (tools auto-added by Bifrost) + response, err := client.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ + Provider: "openai", + Model: "gpt-4", + Input: schemas.RequestInput{ChatCompletionInput: &history}, + }) + + assistantMessage := response.Data.Choices[0].Message + + // 2. Check if LLM wants to use tools + if assistantMessage.ToolCalls != nil && len(*assistantMessage.ToolCalls) > 0 { + // Add assistant message with tool calls to history + history = append(history, assistantMessage) + + // 3. Execute tools (YOUR LOGIC HERE) + for _, toolCall := range *assistantMessage.ToolCalls { + toolResult := executeToolWithApproval(client, toolCall) + history = append(history, *toolResult) + } + + continue // Get final response from LLM + } + + // 4. No more tools - conversation complete + return &assistantMessage + } +} +``` + +### Tool Approval Patterns + +Since you control tool execution, you can implement various approval mechanisms: + +#### 1. **Manual Approval** + +```go +func executeToolWithApproval(client *bifrost.Bifrost, toolCall schemas.ToolCall) *schemas.BifrostMessage { + // Ask user for approval + fmt.Printf("πŸ”§ Execute %s? (y/n): ", toolCall.Function.Name) + scanner := bufio.NewScanner(os.Stdin) + scanner.Scan() + + if strings.ToLower(scanner.Text()) != "y" { + return &schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleTool, + Content: schemas.MessageContent{ + ContentStr: &[]string{"Tool execution cancelled by user"}[0], + }, + ToolCallID: toolCall.ID, + } + } + + // User approved - execute via Bifrost's MCP infrastructure + return client.ExecuteMCPTool(ctx, toolCall) +} +``` + +#### 2. **Automatic with Allowlist** + +```go +func executeIfSafe(client *bifrost.Bifrost, toolCall schemas.ToolCall) *schemas.BifrostMessage { + safeFunctions := []string{"read_file", "list_files", "search_web"} + + for _, safe := range safeFunctions { + if toolCall.Function.Name == safe { + return client.ExecuteMCPTool(ctx, toolCall) // Auto-execute safe tools + } + } + + // Dangerous tool - require approval or reject + return askForApproval(client, toolCall) +} +``` + +#### 3. **Role-Based Approval** + +```go +func executeBasedOnRole(client *bifrost.Bifrost, toolCall schemas.ToolCall, userRole string) *schemas.BifrostMessage { + switch userRole { + case "admin": + return client.ExecuteMCPTool(ctx, toolCall) // Admins can run anything + case "user": + if isReadOnlyTool(toolCall.Function.Name) { + return client.ExecuteMCPTool(ctx, toolCall) // Users get read-only tools + } + return requireApproval(client, toolCall) + default: + return denyExecution(toolCall, "Insufficient permissions") + } +} +``` + +### Core Implementation Details + +#### 1. **MCP Configuration Setup** + +Configure your MCP clients with appropriate security controls: + +```go +mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{ + { + Name: "filesystem", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"@modelcontextprotocol/server-filesystem", "."}, + }, + ToolsToExecute: []string{"read_file", "list_files"}, // Whitelist safe tools + }, + }, +} +``` + +#### 2. **Critical Message Sequencing** + +The LLM expects this exact conversation flow: + +```text + 1. User Message -> "Can you read config.json?" + 2. Assistant Message -> [with tool_calls to read_file] + 3. Tool Result(s) -> [file contents] + 4. Assistant Message -> [final response with no tool_calls] +``` + +**Implementation:** + +```go +// MUST add assistant message with tool calls BEFORE executing tools +history = append(history, assistantMessage) + +// Execute tools and add results +for _, toolCall := range *assistantMessage.ToolCalls { + result := executeWithYourLogic(client, toolCall) + history = append(history, *result) // Each tool result +} + +// Send back to LLM for final response +``` + +#### 3. **Error Handling** + +Always return a valid tool result, even for errors: + +```go +func handleToolError(toolCall schemas.ToolCall, err error) *schemas.BifrostMessage { + return &schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleTool, + Content: schemas.MessageContent{ + ContentStr: &[]string{fmt.Sprintf("Error: %v", err)}[0], + }, + ToolCallID: toolCall.ID, // CRITICAL: Must match the tool call ID + } +} +``` + +### Essential Fields + +When implementing tool execution, these fields are critical: + +#### **Tool Call ID Matching** + +```go +// The tool result MUST have the same ID as the tool call +toolResult.ToolCallID = toolCall.ID +``` + +#### **Message Roles** + +```go +schemas.ModelChatMessageRoleUser // User input +schemas.ModelChatMessageRoleAssistant // LLM responses (with/without tool_calls) +schemas.ModelChatMessageRoleTool // Tool execution results +schemas.ModelChatMessageRoleSystem // System instructions +``` + +### Common Implementation Patterns + +#### **Automatic Execution with Safety Checks** + +```go +func executeWithSafetyChecks(client *bifrost.Bifrost, toolCall schemas.ToolCall) *schemas.BifrostMessage { + // Log all tool usage + log.Printf("Tool called: %s with args: %s", toolCall.Function.Name, toolCall.Function.Arguments) + + // Apply your business logic here + if requiresSpecialHandling(toolCall.Function.Name) { + return handleSpecialTool(client, toolCall) + } + + // Default: execute via Bifrost MCP infrastructure + result, err := client.ExecuteMCPTool(ctx, toolCall) + if err != nil { + return createErrorResult(toolCall, err) + } + + return result +} +``` + +#### **Context-Based Filtering** + +```go +// Runtime control over which MCP clients are active +ctx := context.WithValue(context.Background(), "mcp_include_clients", []string{"safe-tools"}) + +response, err := client.ChatCompletionRequest(ctx, request) +``` + +### Key Takeaways + +1. **Bifrost handles discovery and registration** - MCP tools are automatically added to LLM requests +2. **You control execution** - Implement approval, logging, and security in your tool execution logic +3. **Message sequencing matters** - Follow the exact conversation flow pattern +4. **Tool Call IDs must match** - Critical for proper conversation continuity +5. **Error handling is essential** - Always return valid tool results, even for failures + +For a complete working example, see `tests/core-chatbot/main.go` in the repository. + +## Architecture + +### Integration Architecture + +```text +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Bifrost Core β”‚ +β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ β”‚ MCP Manager β”‚ β”‚ +β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ +β”‚ β”‚ β”‚ Local MCP β”‚ β”‚ External MCP β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ Server β”‚ β”‚ Clients β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ - Host Tools β”‚ β”‚ - HTTP Clients β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ - HTTP Server β”‚ β”‚ - STDIO Procs β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ - SSE Clients β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ - Tool Reg. β”‚ β”‚ - Tool Discoveryβ”‚ β”‚ β”‚ +β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ +β”‚ β”‚ β”‚ β”‚ +β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ +β”‚ β”‚ β”‚ Client Manager β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ - Connection Lifecycle β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ - Tool Mapping β”‚ β”‚ β”‚ +β”‚ β”‚ β”‚ - Configuration Management β”‚ β”‚ β”‚ +β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ +β”‚ β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ β”‚ Request Processing β”‚ β”‚ +β”‚ β”‚ - Tool Auto-Discovery β”‚ β”‚ +β”‚ β”‚ - Tool Execution β”‚ β”‚ +β”‚ β”‚ - Client Filtering β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β–Ό +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ AI Model Providers β”‚ +β”‚ - OpenAI, Anthropic, Azure, etc. β”‚ +β”‚ - Tool-enabled requests β”‚ +β”‚ - Automatic tool calling β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +### Tool Execution Flow + +```text +User Request + β”‚ + β–Ό +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Add MCP Tools │───▢│ LLM Process │───▢│ Execute Tools β”‚ +β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ +β”‚ - Discovery β”‚ β”‚ - Generate β”‚ β”‚ - Call MCP β”‚ +β”‚ - Filter β”‚ β”‚ Response β”‚ β”‚ Servers β”‚ +β”‚ - Add to β”‚ β”‚ - Tool Calls β”‚ β”‚ - Return β”‚ +β”‚ Request β”‚ β”‚ β”‚ β”‚ Results β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +### Connection Types + +#### HTTP Connections + +- Direct HTTP communication with MCP servers +- Suitable for web services and remote tools +- Automatic connection management + +#### STDIO Connections + +- Launch command-line MCP tools as child processes +- Communicate via stdin/stdout +- Automatic process lifecycle management +- Process cleanup on Bifrost shutdown + +#### SSE Connections + +- Connect to Server-Sent Events streams +- Persistent, long-lived connections for real-time data +- Automatic connection management and reconnection +- Proper context cleanup on shutdown + +## API Reference + +### Core Integration Methods + +`bifrost.Init(config schemas.BifrostConfig) (*Bifrost, error)` + +Initializes Bifrost with MCP integration. The MCP configuration is provided as part of the main Bifrost configuration. + +`bifrost.RegisterMCPTool(name, description string, handler func(args any) (string, error), toolSchema schemas.Tool) error` + +Registers a typed Go function as an MCP tool in the local MCP server. + +`bifrost.ExecuteMCPTool(ctx context.Context, toolCall schemas.ToolCall) (*schemas.BifrostMessage, *schemas.BifrostError)` + +Executes an MCP tool with the given tool call and returns the result. + +### Configuration Structures + +`schemas.MCPConfig` + +Main configuration for MCP integration. + +```go +type MCPConfig struct { + ClientConfigs []MCPClientConfig `json:"client_configs,omitempty"` // MCP client configurations (connection + filtering) +} +``` + +`schemas.MCPClientConfig` + +Configuration for individual MCP clients, including both connection details and tool filtering. + +```go +type MCPClientConfig struct { + Name string `json:"name"` // Client name + ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, or SSE) + ConnectionString *string `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) + StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) + ToolsToSkip []string `json:"tools_to_skip,omitempty"` // Tools to exclude from this client + ToolsToExecute []string `json:"tools_to_execute,omitempty"` // Tools to include from this client (if specified, only these are used) +} +``` + +`schemas.MCPStdioConfig` + +STDIO-specific configuration for launching external MCP tools. + +```go +type MCPStdioConfig struct { + Command string `json:"command"` // Executable command to run + Args []string `json:"args"` // Command line arguments + Envs []string `json:"envs"` // Environment variables required +} +``` + +`schemas.MCPConnectionType` + +Enumeration of supported connection types. + +```go +type MCPConnectionType string + +const ( + MCPConnectionTypeHTTP MCPConnectionType = "http" // HTTP-based MCP connection (streamable) + MCPConnectionTypeSTDIO MCPConnectionType = "stdio" // STDIO-based MCP connection + MCPConnectionTypeSSE MCPConnectionType = "sse" // Server-Sent Events MCP connection +) +``` + +### Context Keys for Client Filtering + +- `"mcp_include_clients"`: Whitelist specific clients for a request +- `"mcp_exclude_clients"`: Blacklist specific clients for a request + +## Advanced Features + +### Tool and Client Filtering + +The MCP integration provides multiple levels of filtering to control which tools are available and how they execute. + +#### Configuration-Level Tool Filtering + +Configure which tools are available from each client at startup: + +```go +mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{ + { + Name: "filesystem-tools", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"@modelcontextprotocol/server-filesystem", "/home/user"}, + }, + ToolsToSkip: []string{"rm", "delete", "format", "chmod"}, // Exclude dangerous tools + + // Alternative: Include only specific tools (whitelist approach) + // If ToolsToExecute is specified, ONLY these tools will be available + // ToolsToExecute: []string{"read_file", "list_files", "write_file"}, + }, + { + Name: "safe-tools", + ConnectionType: schemas.MCPConnectionTypeHTTP, + ConnectionString: &[]string{"http://localhost:3000"}[0], + ToolsToExecute: []string{"search", "weather"}, // Only safe operations + }, + }, +} +``` + +**Configuration-Level Priority Rules:** + +1. **`ToolsToExecute` takes precedence**: If specified, only these tools are available (whitelist) +2. **`ToolsToSkip` is secondary**: Only applies when `ToolsToExecute` is empty (blacklist) +3. **Empty configurations**: All discovered tools are available + +#### Request-Level Client Filtering + +Control which MCP clients are active per individual request: + +```go +// Whitelist mode - only include specific clients +ctx = context.WithValue(ctx, "mcp_include_clients", []string{"weather", "calendar"}) + +// Blacklist mode - exclude specific clients +ctx = context.WithValue(ctx, "mcp_exclude_clients", []string{"filesystem", "admin-tools"}) + +// Use in request +response, err := bifrost.ChatCompletionRequest(ctx, request) +``` + +**Request-Level Priority Rules:** + +1. **Include takes absolute precedence**: If `mcp_include_clients` is set, only those clients are used +2. **Exclude is secondary**: Only applies when include list is empty +3. **Empty filters**: All configured clients are available + +#### Combined Example: Multi-Level Filtering + +```go +// 1. Configuration Level: Set up clients with tool filtering +mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{ + { + Name: "filesystem", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"@modelcontextprotocol/server-filesystem", "/home/user"}, + }, + ToolsToExecute: []string{"read_file", "list_files"}, // Only safe read operations + }, + { + Name: "weather", + ConnectionType: schemas.MCPConnectionTypeHTTP, + ConnectionString: &[]string{"http://localhost:3000"}[0], + // All tools available (no filtering) + }, + { + Name: "admin-tools", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"admin-server"}, + }, + ToolsToSkip: []string{"delete_user", "reset_system"}, // Exclude dangerous operations + }, + }, +} + +// 2. Request Level: Further filter clients per request +ctx := context.Background() + +// For safe operations - include filesystem and weather only +ctx = context.WithValue(ctx, "mcp_include_clients", []string{"filesystem", "weather"}) + +// For admin operations - exclude only high-risk client +// ctx = context.WithValue(ctx, "mcp_exclude_clients", []string{"admin-tools"}) +``` + +#### Filtering Priority Summary + +**Overall Priority Order (highest to lowest):** + +1. **Request-level include** (`mcp_include_clients`) - Absolute whitelist +2. **Request-level exclude** (`mcp_exclude_clients`) - Applied if no include list +3. **Config-level tool whitelist** (`ToolsToExecute`) - Per-client tool whitelist +4. **Config-level tool blacklist** (`ToolsToSkip`) - Per-client tool blacklist +5. **Default**: All tools from all clients available + +### Automatic Tool Discovery and Integration + +The MCP integration automatically: + +1. **Discovers Tools**: Connects to external MCP servers and discovers available tools +2. **Adds to Requests**: Automatically adds discovered tools to LLM requests +3. **Executes Tools**: Handles tool execution when LLMs make tool calls +4. **Manages Connections**: Maintains connections to external MCP servers + +#### Dynamic Tool Integration Flow + +```go +// Tools are automatically discovered and added to requests +request := &schemas.BifrostRequest{ + Provider: "openai", + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &conversationHistory, + }, + // No need to manually specify tools - they're added automatically +} + +// MCP integration automatically: +// 1. Discovers available tools from all connected clients +// 2. Filters tools based on configuration and context +// 3. Adds tools to the request +// 4. Handles tool execution if the LLM makes tool calls +response, err := bifrost.ChatCompletionRequest(ctx, request) +``` + +### External MCP Server Integration + +#### Connecting to NPM-based MCP Servers + +Many MCP servers are available as NPM packages: + +```go +mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{ + // File system server + { + Name: "filesystem", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"@modelcontextprotocol/server-filesystem", "/home/user"}, + }, + ToolsToSkip: []string{"rm", "delete"}, // Skip dangerous operations + }, + // Web search server + { + Name: "web-search", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"-y", "serper-search-scrape-mcp-server"}, + }, + ToolsToSkip: []string{}, // No tools to skip + }, + // Database server + { + Name: "database", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"@modelcontextprotocol/server-sqlite", "database.db"}, + }, + ToolsToSkip: []string{"drop_table", "delete_database"}, // Skip destructive operations + }, + }, +} +``` + +#### Connecting to HTTP-based MCP Servers + +For web-based MCP services: + +```go +mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{ + { + Name: "cloud-service", + ConnectionType: schemas.MCPConnectionTypeHTTP, + ConnectionString: &[]string{"https://api.example.com/mcp"}[0], + ToolsToSkip: []string{}, // No tools to skip + }, + }, +} +``` + +#### Connecting to SSE-based MCP Servers + +For Server-Sent Events based MCP services: + +```go +mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{ + { + Name: "real-time-service", + ConnectionType: schemas.MCPConnectionTypeSSE, + ConnectionString: &[]string{"https://api.example.com/sse"}[0], + ToolsToSkip: []string{}, // No tools to skip + }, + }, +} +``` + +### Security Best Practices + +**1. Default to Restrictive Filtering** + +```go +// Secure by default - only allow safe tools +clientConfig := schemas.MCPClientConfig{ + Name: "external-tools", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"external-mcp-server"}, + }, + ToolsToExecute: []string{"search", "weather", "read_file"}, // Whitelist approach +} +``` + +**2. Environment-based Tool Control** + +```go +func getSecureMCPConfig(environment string) *schemas.MCPConfig { + config := &schemas.MCPConfig{} + + switch environment { + case "production": + // Minimal tools in production + config.ClientConfigs = []schemas.MCPClientConfig{ + { + Name: "search", + ConnectionType: schemas.MCPConnectionTypeHTTP, + ConnectionString: &[]string{"http://search-service:3000"}[0], + ToolsToExecute: []string{"web_search"}, + }, + } + case "development": + // More permissive in development + config.ClientConfigs = []schemas.MCPClientConfig{ + { + Name: "filesystem", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"@modelcontextprotocol/server-filesystem", "/home/user"}, + }, + ToolsToSkip: []string{"rm", "delete"}, + }, + { + Name: "search", + ConnectionType: schemas.MCPConnectionTypeHTTP, + ConnectionString: &[]string{"http://localhost:3000"}[0], + }, + } + } + return config +} +``` + +**3. User Role-Based Filtering** + +```go +func getContextForUserRole(role string) context.Context { + ctx := context.Background() + + switch role { + case "admin": + // Admins get all tools + return ctx + case "user": + // Users get safe tools only + return context.WithValue(ctx, "mcp_include_clients", + []string{"weather", "search"}) + case "guest": + // Guests get minimal access + return context.WithValue(ctx, "mcp_include_clients", + []string{"weather"}) + default: + // No tools for unknown roles + return context.WithValue(ctx, "mcp_include_clients", []string{}) + } +} +``` + +## Troubleshooting + +### Common Issues + +#### 1. Connection Failures + +**STDIO Connection Issues:** + +```text +Error: failed to start command 'npx @modelcontextprotocol/server-filesystem' +``` + +**Solutions:** + +- Verify the command exists and is executable +- Check command arguments are correct +- Ensure required dependencies are installed (Node.js for NPM packages) +- Check file permissions + +**HTTP Connection Issues:** + +```text +Error: failed to initialize external MCP client: connection refused +``` + +**Solutions:** + +- Verify the HTTP server is running +- Check the URL is correct and accessible +- Verify network connectivity +- Check firewall settings + +**SSE Connection Issues:** + +```text +Error: SSE stream error: context canceled +``` + +**Solutions:** + +- Verify the SSE server is running and accessible +- Check the SSE endpoint URL is correct +- Ensure the server supports Server-Sent Events protocol +- Check for network connectivity issues +- Verify the SSE stream is properly formatted + +#### 2. Tool Registration Failures + +**Tool Already Exists:** + +```text +Error: tool 'echo' already registered +``` + +**Solutions:** + +- Use unique tool names across all MCP clients +- Check for duplicate registrations +- Clear existing tools if needed + +#### 3. Tool Filtering Issues + +**No Tools Available:** + +```text +Warning: No MCP tools found in response +``` + +**Common Causes & Solutions:** + +- **Over-restrictive filtering**: Check if `mcp_include_clients` is too narrow +- **All tools skipped**: Review `ToolsToSkip` configuration for each client +- **Client connection issues**: Verify external MCP clients are connected +- **Empty whitelist**: If `ToolsToExecute` is set but empty, no tools will be available + +**Unexpected Tool Availability:** + +```text +Warning: Restricted tool 'delete_all_files' is available when it shouldn't be +``` + +**Solutions:** + +- **Check priority order**: Ensure `ToolsToExecute` whitelist is properly configured +- **Verify client filtering**: Make sure dangerous clients are excluded at request level +- **Review configuration**: Confirm `ToolsToSkip` is correctly specified + +#### 4. Tool Execution Failures + +**Tool Not Found:** + +```text +Error: MCP tool 'unknown_tool' not found +``` + +**Solutions:** + +- Verify tool name spelling +- Check if tool is available from connected clients +- Verify client is not filtered out in the request context + +### Debugging Tips + +#### Enable Debug Logging + +```go +logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug) +bifrost, err := bifrost.Init(schemas.BifrostConfig{ + MCPConfig: mcpConfig, + Logger: logger, +}) +``` + +#### Check Tool Registration + +The MCP integration automatically discovers and registers tools. Check the logs for tool discovery messages. + +#### Debug Filtering Configuration + +```go +// Check what clients are active by examining the context +ctx := context.Background() +ctx = context.WithValue(ctx, "mcp_include_clients", []string{"filesystem"}) + +// The integration will automatically filter tools based on context +response, err := bifrost.ChatCompletionRequest(ctx, request) +``` + +#### Monitor Process Status + +External STDIO processes are managed automatically. Check the logs for process start/stop messages. + +--- + +For more information, see the [main Bifrost documentation](../README.md). diff --git a/docs/memory-management.md b/docs/memory-management.md new file mode 100644 index 0000000000..468ebc9577 --- /dev/null +++ b/docs/memory-management.md @@ -0,0 +1,104 @@ +# Bifrost Memory and Concurrency Management + +This document outlines the key configurations for managing memory usage and concurrency in Bifrost. + +## 1. Initial Pool Size + +The `InitialPoolSize` configuration determines the initial size of object pools that Bifrost creates during initialization. These pools are used to reduce runtime allocations and improve performance. + +### Default Value + +- Default: `100` + +### Configuration + +```golang +client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &yourAccount, + InitialPoolSize: 500, // Custom pool size + DropExcessRequests: true, +}) +``` + +### Impact + +- Higher values reduce runtime allocations and latency +- Higher values increase memory usage +- Recommended to set based on your expected concurrent request volume + +## 2. Drop Excess Requests + +The `DropExcessRequests` flag controls how Bifrost handles requests when queues are full. + +### Default Value + +- Default: `false` + +### Configuration + +```golang +client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &yourAccount, + InitialPoolSize: 500, + DropExcessRequests: true, // Enable dropping excess requests +}) +``` + +### Behavior + +- When `true`: Requests are dropped immediately if the queue is full +- When `false`: Requests wait for queue space to become available + +## 3. Provider Concurrency and Buffer Size + +Each provider can be configured with specific concurrency and buffer size settings. + +### Default Values + +- Default Concurrency: `10` workers +- Default Buffer Size: `100` requests + +### Configuration + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o-mini"], + "weight": 1.0 + } + ], + "concurrency_and_buffer_size": { + "concurrency": 20, // Number of concurrent workers + "buffer_size": 200 // Size of the request queue + } + } + } +} +``` + +### Impact + +- **Concurrency**: Controls the number of parallel workers processing requests + + - Higher values increase throughput but also increase resource usage + - Should be set based on your provider's rate limits and server capacity + +- **Buffer Size**: Controls the size of the request queue + - Higher values allow more requests to be queued + - Should be set based on your expected request volume and latency requirements + +### Best Practices + +1. Set `InitialPoolSize` to match your expected concurrent request volume +2. Enable `DropExcessRequests` if you want to fail fast when the system is overloaded +3. Configure provider concurrency based on: + - Provider's rate limits + - Available system resources + - Expected request patterns +4. Set buffer size to handle expected request spikes while considering memory constraints + +Remember that these configurations have direct impact on your system's performance and resource usage. It's recommended to test your configuration under expected load conditions to find the optimal settings for your use case. diff --git a/docs/openapi.json b/docs/openapi.json new file mode 100644 index 0000000000..a8467b8514 --- /dev/null +++ b/docs/openapi.json @@ -0,0 +1,1308 @@ +{ + "openapi": "3.0.3", + "info": { + "title": "Bifrost HTTP Transport API", + "description": "A unified HTTP API for accessing multiple AI model providers:\n\nβ€’ openai\nβ€’ anthropic\nβ€’ azure\nβ€’ bedrock\nβ€’ cohere\nβ€’ vertex\nβ€’ mistral\nβ€’ ollama\n\nBifrost provides standardized endpoints for text and chat completions with built-in fallback support and comprehensive monitoring.\n\n**MCP Integration**: Includes Model Context Protocol (MCP) support for external tool integration. Configure MCP servers to automatically add tools to model requests and execute them via dedicated endpoints.", + "version": "1.1.2", + "contact": { + "name": "Bifrost API Support", + "url": "https://github.com/maximhq/bifrost" + }, + "license": { + "name": "MIT", + "url": "https://opensource.org/licenses/MIT" + } + }, + "servers": [ + { + "url": "http://localhost:8080", + "description": "Local development server" + } + ], + "paths": { + "/v1/chat/completions": { + "post": { + "summary": "Create Chat Completion", + "description": "Creates a chat completion using conversational messages. Supports tool calling, image inputs, and multiple AI providers with automatic fallbacks.", + "operationId": "createChatCompletion", + "tags": ["Chat Completions"], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ChatCompletionRequest" + }, + "examples": { + "simple_chat": { + "summary": "Simple chat message", + "value": { + "provider": "openai", + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "Hello, how are you?" + } + ] + } + }, + "tool_calling": { + "summary": "Chat with tool calling", + "value": { + "provider": "openai", + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "What's the weather in San Francisco?" + } + ], + "params": { + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + } + }, + "required": ["location"] + } + } + } + ], + "tool_choice": { + "type": "function", + "function": { + "name": "get_weather" + } + } + } + } + }, + "with_fallbacks": { + "summary": "Chat with fallback providers", + "value": { + "provider": "openai", + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "Explain quantum computing" + } + ], + "fallbacks": [ + { + "provider": "anthropic", + "model": "claude-3-sonnet-20240229" + }, + { + "provider": "cohere", + "model": "command" + } + ] + } + }, + "structured_content": { + "summary": "Chat with structured content (text and image)", + "value": { + "provider": "openai", + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What's happening in this image? What's the weather like?" + }, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/weather-photo.jpg", + "detail": "high" + } + } + ] + } + ], + "params": { + "max_tokens": 1000, + "temperature": 0.7 + } + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Successful chat completion", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostResponse" + }, + "examples": { + "simple_response": { + "summary": "Simple chat response", + "value": { + "id": "chatcmpl-123", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! I'm doing well, thank you for asking. How can I help you today?" + }, + "finish_reason": "stop" + } + ], + "model": "gpt-4o", + "created": 1677652288, + "usage": { + "prompt_tokens": 12, + "completion_tokens": 19, + "total_tokens": 31 + }, + "extra_fields": { + "provider": "openai", + "model_params": {}, + "latency": 1.234, + "raw_response": {} + } + } + }, + "tool_response": { + "summary": "Tool calling response", + "value": { + "id": "chatcmpl-456", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\": \"San Francisco, CA\"}" + } + } + ] + }, + "finish_reason": "tool_calls" + } + ], + "model": "gpt-4o", + "created": 1677652288, + "usage": { + "prompt_tokens": 45, + "completion_tokens": 12, + "total_tokens": 57 + }, + "extra_fields": { + "provider": "openai", + "model_params": {}, + "latency": 0.856, + "raw_response": {} + } + } + } + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "401": { + "$ref": "#/components/responses/Unauthorized" + }, + "429": { + "$ref": "#/components/responses/RateLimited" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/v1/text/completions": { + "post": { + "summary": "Create Text Completion", + "description": "Creates a text completion from a prompt. Useful for text generation, summarization, and other non-conversational tasks.", + "operationId": "createTextCompletion", + "tags": ["Text Completions"], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TextCompletionRequest" + }, + "examples": { + "simple_text": { + "summary": "Simple text completion", + "value": { + "provider": "openai", + "model": "gpt-3.5-turbo-instruct", + "text": "The future of artificial intelligence is", + "params": { + "max_tokens": 100, + "temperature": 0.7 + } + } + }, + "with_stop_sequences": { + "summary": "Text completion with stop sequences", + "value": { + "provider": "cohere", + "model": "command", + "text": "Write a short story about a robot:", + "params": { + "max_tokens": 200, + "temperature": 0.8, + "stop_sequences": ["\n\n", "THE END"] + } + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Successful text completion", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostResponse" + }, + "examples": { + "text_response": { + "summary": "Text completion response", + "value": { + "id": "cmpl-789", + "object": "text.completion", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The future of artificial intelligence is incredibly promising, with advances in machine learning, natural language processing, and robotics reshaping industries and daily life." + }, + "finish_reason": "stop" + } + ], + "model": "gpt-3.5-turbo-instruct", + "created": 1677652288, + "usage": { + "prompt_tokens": 8, + "completion_tokens": 32, + "total_tokens": 40 + }, + "extra_fields": { + "provider": "openai", + "model_params": { + "max_tokens": 100, + "temperature": 0.7 + }, + "latency": 0.654, + "raw_response": {} + } + } + } + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "401": { + "$ref": "#/components/responses/Unauthorized" + }, + "429": { + "$ref": "#/components/responses/RateLimited" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/v1/mcp/tool/execute": { + "post": { + "summary": "Execute MCP Tool", + "description": "Executes an MCP (Model Context Protocol) tool that has been configured in Bifrost. This endpoint is used to execute tool calls returned by AI models during conversations. Requires MCP to be configured in Bifrost.", + "operationId": "executeMCPTool", + "tags": ["MCP Tools"], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ToolCall" + }, + "examples": { + "google_search": { + "summary": "Google Search Tool Execution", + "value": { + "type": "function", + "id": "toolu_01VfefsSy7ZRdawdw7U2fg", + "function": { + "name": "google_search", + "arguments": "{\"gl\":\"us\",\"hl\":\"en\",\"num\":5,\"q\":\"San Francisco news yesterday\",\"tbs\":\"qdr:d\"}" + } + } + }, + "file_read": { + "summary": "File Read Tool Execution", + "value": { + "type": "function", + "id": "call_abc123", + "function": { + "name": "read_file", + "arguments": "{\"path\": \"/tmp/config.json\"}" + } + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Tool execution successful", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostMessage" + }, + "examples": { + "search_result": { + "summary": "Google Search Result", + "value": { + "role": "tool", + "content": "{\n \"searchParameters\": {\n \"q\": \"San Francisco news yesterday\",\n \"gl\": \"us\",\n \"hl\": \"en\",\n \"type\": \"search\",\n \"num\": 5,\n \"tbs\": \"qdr:d\",\n \"engine\": \"google\"\n },\n \"organic\": [\n {\n \"title\": \"San Francisco Chronicle Β· Giants' today\"\n },\n {\n \"query\": \"s.f. chronicle e edition\"\n }\n ],\n \"credits\": 1\n}", + "tool_call_id": "toolu_01VfefsSy7ZRdawdw7U2fg" + } + }, + "file_content": { + "summary": "File Read Result", + "value": { + "role": "tool", + "content": "{\n \"provider\": \"openai\",\n \"model\": \"gpt-4o-mini\",\n \"api_key\": \"sk-***\"\n}", + "tool_call_id": "call_abc123" + } + } + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest" + }, + "429": { + "$ref": "#/components/responses/RateLimited" + }, + "500": { + "$ref": "#/components/responses/InternalServerError" + } + } + } + }, + "/metrics": { + "get": { + "summary": "Get Prometheus Metrics", + "description": "Returns Prometheus-compatible metrics for monitoring request counts, latency, token usage, and error rates.", + "operationId": "getMetrics", + "tags": ["Monitoring"], + "responses": { + "200": { + "description": "Prometheus metrics in text format", + "content": { + "text/plain": { + "schema": { + "type": "string" + }, + "example": "# HELP http_requests_total Total number of HTTP requests\n# TYPE http_requests_total counter\nhttp_requests_total{method=\"POST\",handler=\"/v1/chat/completions\",code=\"200\"} 42\n" + } + } + } + } + } + } + }, + "components": { + "schemas": { + "ChatCompletionRequest": { + "type": "object", + "required": ["provider", "model", "messages"], + "properties": { + "provider": { + "$ref": "#/components/schemas/ModelProvider" + }, + "model": { + "type": "string", + "description": "Model identifier (provider-specific)", + "example": "gpt-4o" + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/components/schemas/BifrostMessage" + }, + "description": "Array of chat messages", + "minItems": 1 + }, + "params": { + "$ref": "#/components/schemas/ModelParameters" + }, + "fallbacks": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Fallback" + }, + "description": "Fallback providers and models" + } + } + }, + "TextCompletionRequest": { + "type": "object", + "required": ["provider", "model", "text"], + "properties": { + "provider": { + "$ref": "#/components/schemas/ModelProvider" + }, + "model": { + "type": "string", + "description": "Model identifier (provider-specific)", + "example": "gpt-3.5-turbo-instruct" + }, + "text": { + "type": "string", + "description": "Text prompt for completion", + "example": "The benefits of artificial intelligence include" + }, + "params": { + "$ref": "#/components/schemas/ModelParameters" + }, + "fallbacks": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Fallback" + }, + "description": "Fallback providers and models" + } + } + }, + "ModelProvider": { + "type": "string", + "enum": [ + "openai", + "anthropic", + "azure", + "bedrock", + "cohere", + "vertex", + "mistral", + "ollama" + ], + "description": "AI model provider", + "example": "openai" + }, + "BifrostMessage": { + "type": "object", + "required": ["role"], + "properties": { + "role": { + "$ref": "#/components/schemas/MessageRole" + }, + "content": { + "oneOf": [ + { + "type": "string", + "description": "Simple text content", + "example": "Hello, how are you?" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/ContentBlock" + }, + "description": "Structured content with text and images" + } + ], + "description": "Message content - can be simple text or structured content with text and images" + }, + "tool_call_id": { + "type": "string", + "description": "ID of the tool call (for tool messages)" + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolCall" + }, + "description": "Tool calls made by assistant" + }, + "refusal": { + "type": "string", + "description": "Refusal message from assistant" + }, + "annotations": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Annotation" + }, + "description": "Message annotations" + }, + "thought": { + "type": "string", + "description": "Assistant's internal thought process" + } + } + }, + "MessageRole": { + "type": "string", + "enum": ["user", "assistant", "system", "tool"], + "description": "Role of the message sender", + "example": "user" + }, + "ContentBlock": { + "type": "object", + "required": ["type"], + "discriminator": { + "propertyName": "type" + }, + "oneOf": [ + { + "type": "object", + "required": ["type", "text"], + "properties": { + "type": { + "type": "string", + "enum": ["text"], + "description": "Content type for text blocks", + "example": "text" + }, + "text": { + "type": "string", + "description": "Text content", + "example": "What do you see in this image?" + } + }, + "additionalProperties": false + }, + { + "type": "object", + "required": ["type", "image_url"], + "properties": { + "type": { + "type": "string", + "enum": ["image_url"], + "description": "Content type for image blocks", + "example": "image_url" + }, + "image_url": { + "$ref": "#/components/schemas/ImageURLStruct", + "description": "Image data" + } + }, + "additionalProperties": false + } + ] + }, + "ImageURLStruct": { + "type": "object", + "required": ["url"], + "properties": { + "url": { + "type": "string", + "description": "Image URL or data URI", + "example": "https://example.com/image.jpg" + }, + "detail": { + "type": "string", + "enum": ["low", "high", "auto"], + "description": "Image detail level", + "example": "auto" + } + } + }, + "ModelParameters": { + "type": "object", + "properties": { + "temperature": { + "type": "number", + "minimum": 0.0, + "maximum": 2.0, + "description": "Controls randomness in the output", + "example": 0.7 + }, + "top_p": { + "type": "number", + "minimum": 0.0, + "maximum": 1.0, + "description": "Nucleus sampling parameter", + "example": 0.9 + }, + "top_k": { + "type": "integer", + "minimum": 1, + "description": "Top-k sampling parameter", + "example": 40 + }, + "max_tokens": { + "type": "integer", + "minimum": 1, + "description": "Maximum number of tokens to generate", + "example": 1000 + }, + "stop_sequences": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Sequences that stop generation", + "example": ["\n\n", "END"] + }, + "presence_penalty": { + "type": "number", + "minimum": -2.0, + "maximum": 2.0, + "description": "Penalizes repeated tokens", + "example": 0.0 + }, + "frequency_penalty": { + "type": "number", + "minimum": -2.0, + "maximum": 2.0, + "description": "Penalizes frequent tokens", + "example": 0.0 + }, + "tools": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Tool" + }, + "description": "Available tools for the model" + }, + "tool_choice": { + "$ref": "#/components/schemas/ToolChoice" + }, + "parallel_tool_calls": { + "type": "boolean", + "description": "Enable parallel tool execution", + "example": true + } + } + }, + "Tool": { + "type": "object", + "required": ["type", "function"], + "properties": { + "id": { + "type": "string", + "description": "Unique tool identifier" + }, + "type": { + "type": "string", + "enum": ["function"], + "description": "Tool type", + "example": "function" + }, + "function": { + "$ref": "#/components/schemas/Function" + } + } + }, + "Function": { + "type": "object", + "required": ["name", "description", "parameters"], + "properties": { + "name": { + "type": "string", + "description": "Function name", + "example": "get_weather" + }, + "description": { + "type": "string", + "description": "Function description", + "example": "Get current weather for a location" + }, + "parameters": { + "$ref": "#/components/schemas/FunctionParameters" + } + } + }, + "FunctionParameters": { + "type": "object", + "required": ["type"], + "properties": { + "type": { + "type": "string", + "description": "Parameter type", + "example": "object" + }, + "description": { + "type": "string", + "description": "Parameter description" + }, + "properties": { + "type": "object", + "additionalProperties": true, + "description": "Parameter properties (JSON Schema)" + }, + "required": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Required parameter names" + }, + "enum": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Enum values for parameters" + } + } + }, + "ToolChoice": { + "type": "object", + "required": ["type"], + "properties": { + "type": { + "type": "string", + "enum": ["none", "auto", "any", "function", "required"], + "description": "How tools should be chosen", + "example": "auto" + }, + "function": { + "$ref": "#/components/schemas/ToolChoiceFunction" + } + } + }, + "ToolChoiceFunction": { + "type": "object", + "required": ["name"], + "properties": { + "name": { + "type": "string", + "description": "Name of the function to call", + "example": "get_weather" + } + } + }, + "ToolCall": { + "type": "object", + "required": ["function"], + "properties": { + "id": { + "type": "string", + "description": "Unique tool call identifier", + "example": "tool_123" + }, + "type": { + "type": "string", + "enum": ["function"], + "description": "Tool call type", + "example": "function" + }, + "function": { + "$ref": "#/components/schemas/FunctionCall" + } + } + }, + "FunctionCall": { + "type": "object", + "required": ["name", "arguments"], + "properties": { + "name": { + "type": "string", + "description": "Function name", + "example": "get_weather" + }, + "arguments": { + "type": "string", + "description": "JSON string of function arguments", + "example": "{\"location\": \"San Francisco, CA\"}" + } + } + }, + "Annotation": { + "type": "object", + "required": ["type", "url_citation"], + "properties": { + "type": { + "type": "string", + "description": "Annotation type" + }, + "url_citation": { + "$ref": "#/components/schemas/Citation" + } + } + }, + "Citation": { + "type": "object", + "required": ["start_index", "end_index", "title"], + "properties": { + "start_index": { + "type": "integer", + "description": "Start index in the text" + }, + "end_index": { + "type": "integer", + "description": "End index in the text" + }, + "title": { + "type": "string", + "description": "Citation title" + }, + "url": { + "type": "string", + "description": "Citation URL" + }, + "sources": { + "description": "Citation sources" + }, + "type": { + "type": "string", + "description": "Citation type" + } + } + }, + "Fallback": { + "type": "object", + "required": ["provider", "model"], + "properties": { + "provider": { + "$ref": "#/components/schemas/ModelProvider" + }, + "model": { + "type": "string", + "description": "Fallback model name", + "example": "claude-3-sonnet-20240229" + } + } + }, + "BifrostResponse": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique response identifier", + "example": "chatcmpl-123" + }, + "object": { + "type": "string", + "enum": ["chat.completion", "text.completion"], + "description": "Response type", + "example": "chat.completion" + }, + "choices": { + "type": "array", + "items": { + "$ref": "#/components/schemas/BifrostResponseChoice" + }, + "description": "Array of completion choices" + }, + "model": { + "type": "string", + "description": "Model used for generation", + "example": "gpt-4o" + }, + "created": { + "type": "integer", + "description": "Unix timestamp of creation", + "example": 1677652288 + }, + "service_tier": { + "type": "string", + "description": "Service tier used" + }, + "system_fingerprint": { + "type": "string", + "description": "System fingerprint" + }, + "usage": { + "$ref": "#/components/schemas/LLMUsage" + }, + "extra_fields": { + "$ref": "#/components/schemas/BifrostResponseExtraFields" + } + } + }, + "BifrostResponseChoice": { + "type": "object", + "required": ["index", "message"], + "properties": { + "index": { + "type": "integer", + "description": "Choice index", + "example": 0 + }, + "message": { + "$ref": "#/components/schemas/BifrostMessage" + }, + "finish_reason": { + "type": "string", + "enum": [ + "stop", + "length", + "tool_calls", + "content_filter", + "function_call" + ], + "description": "Reason completion stopped", + "example": "stop" + }, + "stop": { + "type": "string", + "description": "Stop sequence that ended generation" + }, + "log_probs": { + "$ref": "#/components/schemas/LogProbs" + } + } + }, + "LLMUsage": { + "type": "object", + "properties": { + "prompt_tokens": { + "type": "integer", + "description": "Tokens in the prompt", + "example": 56 + }, + "completion_tokens": { + "type": "integer", + "description": "Tokens in the completion", + "example": 31 + }, + "total_tokens": { + "type": "integer", + "description": "Total tokens used", + "example": 87 + }, + "completion_tokens_details": { + "$ref": "#/components/schemas/CompletionTokensDetails" + } + } + }, + "CompletionTokensDetails": { + "type": "object", + "properties": { + "reasoning_tokens": { + "type": "integer", + "description": "Tokens used for reasoning" + }, + "audio_tokens": { + "type": "integer", + "description": "Tokens used for audio" + }, + "accepted_prediction_tokens": { + "type": "integer", + "description": "Accepted prediction tokens" + }, + "rejected_prediction_tokens": { + "type": "integer", + "description": "Rejected prediction tokens" + } + } + }, + "BifrostResponseExtraFields": { + "type": "object", + "properties": { + "provider": { + "$ref": "#/components/schemas/ModelProvider" + }, + "model_params": { + "$ref": "#/components/schemas/ModelParameters" + }, + "latency": { + "type": "number", + "description": "Request latency in seconds", + "example": 1.234 + }, + "chat_history": { + "type": "array", + "items": { + "$ref": "#/components/schemas/BifrostMessage" + }, + "description": "Full conversation history" + }, + "billed_usage": { + "$ref": "#/components/schemas/BilledLLMUsage" + }, + "raw_response": { + "type": "object", + "description": "Raw provider response" + } + } + }, + "BilledLLMUsage": { + "type": "object", + "properties": { + "prompt_tokens": { + "type": "number", + "description": "Billed prompt tokens" + }, + "completion_tokens": { + "type": "number", + "description": "Billed completion tokens" + }, + "search_units": { + "type": "number", + "description": "Billed search units" + }, + "classifications": { + "type": "number", + "description": "Billed classifications" + } + } + }, + "LogProbs": { + "type": "object", + "properties": { + "content": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ContentLogProb" + }, + "description": "Log probabilities for content" + }, + "refusal": { + "type": "array", + "items": { + "$ref": "#/components/schemas/LogProb" + }, + "description": "Log probabilities for refusal" + } + } + }, + "ContentLogProb": { + "type": "object", + "required": ["logprob", "token"], + "properties": { + "bytes": { + "type": "array", + "items": { + "type": "integer" + }, + "description": "Byte representation" + }, + "logprob": { + "type": "number", + "description": "Log probability", + "example": -0.123 + }, + "token": { + "type": "string", + "description": "Token", + "example": "hello" + }, + "top_logprobs": { + "type": "array", + "items": { + "$ref": "#/components/schemas/LogProb" + }, + "description": "Top log probabilities" + } + } + }, + "LogProb": { + "type": "object", + "required": ["logprob", "token"], + "properties": { + "bytes": { + "type": "array", + "items": { + "type": "integer" + }, + "description": "Byte representation" + }, + "logprob": { + "type": "number", + "description": "Log probability", + "example": -0.456 + }, + "token": { + "type": "string", + "description": "Token", + "example": "world" + } + } + }, + "BifrostError": { + "type": "object", + "required": ["is_bifrost_error", "error"], + "properties": { + "event_id": { + "type": "string", + "description": "Unique error event ID", + "example": "evt_123" + }, + "type": { + "type": "string", + "description": "Error type", + "example": "invalid_request_error" + }, + "is_bifrost_error": { + "type": "boolean", + "description": "Whether error originated from Bifrost", + "example": true + }, + "status_code": { + "type": "integer", + "description": "HTTP status code", + "example": 400 + }, + "error": { + "$ref": "#/components/schemas/ErrorField" + } + } + }, + "ErrorField": { + "type": "object", + "required": ["message"], + "properties": { + "type": { + "type": "string", + "description": "Error type", + "example": "invalid_request_error" + }, + "code": { + "type": "string", + "description": "Error code", + "example": "missing_required_parameter" + }, + "message": { + "type": "string", + "description": "Human-readable error message", + "example": "Provider is required" + }, + "param": { + "description": "Parameter that caused the error", + "example": "provider" + }, + "event_id": { + "type": "string", + "description": "Error event ID", + "example": "evt_123" + } + } + } + }, + "responses": { + "BadRequest": { + "description": "Bad Request - Invalid request format or missing required fields", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostError" + }, + "example": { + "is_bifrost_error": true, + "status_code": 400, + "error": { + "type": "invalid_request_error", + "code": "missing_required_parameter", + "message": "Provider is required", + "param": "provider" + } + } + } + } + }, + "Unauthorized": { + "description": "Unauthorized - Invalid or missing API key", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostError" + }, + "example": { + "is_bifrost_error": true, + "status_code": 401, + "error": { + "type": "authentication_error", + "message": "Invalid API key provided" + } + } + } + } + }, + "RateLimited": { + "description": "Too Many Requests - Rate limit exceeded", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostError" + }, + "example": { + "is_bifrost_error": false, + "status_code": 429, + "error": { + "type": "rate_limit_error", + "message": "Rate limit exceeded. Please try again later." + } + } + } + } + }, + "InternalServerError": { + "description": "Internal Server Error - Server or provider error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/BifrostError" + }, + "example": { + "is_bifrost_error": true, + "status_code": 500, + "error": { + "type": "api_error", + "message": "Internal server error occurred" + } + } + } + } + } + } + }, + "tags": [ + { + "name": "Chat Completions", + "description": "Create chat completions using conversational messages" + }, + { + "name": "Text Completions", + "description": "Create text completions from prompts" + }, + { + "name": "MCP Tools", + "description": "Execute MCP tools" + }, + { + "name": "Monitoring", + "description": "Monitoring and observability endpoint" + } + ] +} diff --git a/docs/plugins.md b/docs/plugins.md new file mode 100644 index 0000000000..2fe5ab80e4 --- /dev/null +++ b/docs/plugins.md @@ -0,0 +1,1079 @@ +# Bifrost Plugin System + +Bifrost provides a powerful plugin system that allows you to extend and customize the request/response pipeline. Plugins can implement rate limiting, caching, authentication, logging, monitoring, and more. + +## Table of Contents + +1. [Plugin Loading & Configuration](#1-plugin-loading--configuration) +2. [Transport Integration](#2-transport-integration) +3. [Plugin Architecture Overview](#3-plugin-architecture-overview) +4. [Plugin Interface](#4-plugin-interface) +5. [Plugin Lifecycle](#5-plugin-lifecycle) +6. [Plugin Execution Flow](#6-plugin-execution-flow) +7. [Short-Circuit Behavior](#7-short-circuit-behavior) +8. [Error Handling & Fallbacks](#8-error-handling--fallbacks) +9. [Building Custom Plugins](#9-building-custom-plugins) +10. [Plugin Examples](#10-plugin-examples) +11. [Best Practices](#11-best-practices) +12. [Plugin Development Guidelines](#12-plugin-development-guidelines) +13. [Troubleshooting Guide](#13-troubleshooting-guide) +14. [Performance Optimization](#14-performance-optimization) + +## 1. Plugin Loading & Configuration + +Bifrost supports **Dynamic Configuration-Based Loading** for plugins + +### Dynamic Plugin Loading + +Configure plugins in your `config.json` file for automatic Just-In-Time (JIT) compilation and loading: + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o-mini"], + "weight": 1.0 + } + ] + } + }, + "plugins": [ + { + "name": "maxim", + "source": "github.com/maximhq/bifrost/plugins/maxim", + "type": "remote", + "config": { + "api_key": "env.MAXIM_API_KEY", + "log_repo_id": "env.MAXIM_LOG_REPO_ID" + } + }, + { + "name": "custom-plugin", + "source": "./local-plugins/custom", + "type": "local", + "config": { + "setting": "value" + } + } + ] +} +``` + +#### Plugin Configuration Fields + +- **name**: Unique identifier for the plugin +- **source**: Go module path (for remote) or local directory path (for local) +- **type**: Either `"remote"` or `"local"` +- **config**: Plugin-specific configuration (supports `env.VARIABLE_NAME` substitution) + +#### Plugin Types + +**Remote Plugins**: Go modules hosted on public repositories + +```json +{ + "name": "maxim", + "source": "github.com/maximhq/bifrost/plugins/maxim", + "type": "remote", + "config": { ... } +} +``` + +**Local Plugins**: Directories containing Go plugin source code + +```json +{ + "name": "custom-plugin", + "source": "./plugins/custom", + "type": "local", + "config": { ... } +} +``` + +### Environment Variable Support + +Plugin configurations support environment variable substitution using the `env.` prefix: + +```json +{ + "config": { + "api_key": "env.MY_API_KEY", + "endpoint": "env.MY_ENDPOINT" + } +} +``` + +### System Requirements + +- **Go 1.21+**: Required for JIT compilation +- **CGO**: Must be enabled for plugin compilation +- **Platform**: Linux or macOS (Windows not supported due to Go plugin limitations) + +**Windows Users**: Direct Go binary plugin support is not available on Windows due to Go's plugin system limitations. However, Windows users can run Bifrost with full plugin functionality using Docker. + +### Docker Usage with Plugins + +The Bifrost Docker image fully supports dynamic plugin loading. Here are common usage patterns: + +#### Basic Configuration-Based Loading + +```bash +# Run with config file containing plugin definitions +docker run -d \ + -v /path/to/config.json:/app/config/config.json \ + -p 8080:8080 \ + maximhq/bifrost +``` + +#### With Environment Variables + +```bash +# Plugin configs can use environment variables +docker run -d \ + -v /path/to/config.json:/app/config/config.json \ + -e MAXIM_API_KEY=your-api-key \ + -e MAXIM_LOG_REPO_ID=your-repo-id \ + -e CUSTOM_PLUGIN_SETTING=value \ + -p 8080:8080 \ + maximhq/bifrost +``` + +#### Local Plugin Volumes + +For local plugins, mount your plugin directories: + +```bash +# Mount local plugins directory +docker run -d \ + --name bifrost \ + -v /path/to/config.json:/app/config/config.json \ + -v /path/to/your/plugins:/app/plugins:ro \ + -p 8080:8080 \ + maximhq/bifrost +``` + +**Config for local plugins in Docker:** + +```json +{ + "plugins": [ + { + "name": "my-local-plugin", + "source": "/app/plugins/my-plugin", + "type": "local", + "config": { + "setting": "env.MY_PLUGIN_SETTING" + } + } + ] +} +``` + +#### Version Compatibility + +Bifrost uses a hybrid dependency resolution approach: + +1. **Development/Local**: Reads local `go.mod` files to determine dependency versions +2. **Binary/Docker**: Uses `go list` to query runtime dependencies + +This ensures plugins are compiled with the same dependency versions as the main Bifrost binary, preventing version conflicts during plugin loading. + +## 2. Transport Integration + +### HTTP Transport Configuration + +The HTTP transport (`bifrost-http`) automatically loads plugins from the configuration file and provides additional transport-specific features: + +#### Configuration Options + +- **-config**: Path to configuration file (required) +- **-port**: Server port (default: 8080) +- **-pool-size**: Initial connection pool size (default: 300) +- **-drop-excess-requests**: Drop excess requests when pool is full +- **-prometheus-labels**: Labels to add to Prometheus metrics + +#### Built-in Transport Plugins + +The HTTP transport includes these built-in plugins: + +1. **Prometheus Plugin**: Automatically enabled for metrics collection + - Exposes `/metrics` endpoint + - Tracks request counts, latencies, and error rates + - Configurable via `-prometheus-labels` flag + +#### Plugin Loading Process + +1. Configuration-based plugins are loaded first +2. Built-in transport plugins are added automatically +3. All plugins are initialized with their respective configurations + +### Plugin Interface for Dynamic Loading + +All plugins must implement the standard interface and provide an `Init` function for dynamic loading: + +```go +package package_name + +import ( + "encoding/json" + "github.com/maximhq/bifrost/core/schemas" +) + +// Plugin configuration struct +type MyPluginConfig struct { + Setting1 string `json:"setting1"` + Setting2 int `json:"setting2"` +} + +// Required Init function for dynamic loading +func Init(configData json.RawMessage) (schemas.Plugin, error) { + var config MyPluginConfig + if err := json.Unmarshal(configData, &config); err != nil { + return nil, err + } + + // Initialize and return your plugin + return NewMyPlugin(config) +} + +// Your plugin implementation +type MyPlugin struct { + config MyPluginConfig +} + +func NewMyPlugin(config MyPluginConfig) *MyPlugin { + return &MyPlugin{config: config} +} + +func (p *MyPlugin) GetName() string { + return "my-plugin" +} + +func (p *MyPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + // Your pre-hook logic + return req, nil, nil +} + +func (p *MyPlugin) PostHook(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + // Your post-hook logic + return result, err, nil +} + +func (p *MyPlugin) Cleanup() error { + // Cleanup logic + return nil +} +``` + +## 3. Plugin Architecture Overview + +Bifrost plugins follow a **PreHook β†’ Provider β†’ PostHook** pattern with support for short-circuiting and fallback control. + +### Key Concepts + +- **PreHook**: Executed before provider call - can modify requests or short-circuit +- **PostHook**: Executed after provider response - can modify responses or recover from errors +- **Short-Circuit**: Plugin can skip provider call and return response/error directly +- **Fallback Control**: Plugins can control whether fallback providers should be tried +- **Pipeline Symmetry**: Every PreHook execution gets a corresponding PostHook call + +## 4. Plugin Interface + +```go +type Plugin interface { + // GetName returns the name of the plugin + GetName() string + + // PreHook is called before a request is processed by a provider + // Can modify request, short-circuit with response, or short-circuit with error + PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *PluginShortCircuit, error) + + // PostHook is called after a response or after PreHook short-circuit + // Can modify response/error or recover from errors + PostHook(ctx *context.Context, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError, error) + + // Cleanup is called on bifrost shutdown + Cleanup() error +} + +type PluginShortCircuit struct { + Response *BifrostResponse // If set, skip provider and return this response + Error *BifrostError // If set, skip provider and return this error +} +``` + +## 5. Plugin Lifecycle + +```mermaid +stateDiagram-v2 + [*] --> PluginInit: Plugin Creation + PluginInit --> Registered: Add to BifrostConfig + Registered --> PreHookCall: Request Received + + PreHookCall --> ModifyRequest: Normal Flow + PreHookCall --> ShortCircuitResponse: Return Response + PreHookCall --> ShortCircuitError: Return Error + + ModifyRequest --> ProviderCall: Send to Provider + ProviderCall --> PostHookCall: Receive Response + + ShortCircuitResponse --> PostHookCall: Skip Provider + ShortCircuitError --> PostHookCall: Pipeline Symmetry + + PostHookCall --> ModifyResponse: Process Result + PostHookCall --> RecoverError: Error Recovery + PostHookCall --> FallbackCheck: Check AllowFallbacks + PostHookCall --> ResponseReady: Pass Through + + FallbackCheck --> TryFallback: AllowFallbacks=true/nil + FallbackCheck --> ResponseReady: AllowFallbacks=false + TryFallback --> PreHookCall: Next Provider + + ModifyResponse --> ResponseReady: Modified + RecoverError --> ResponseReady: Recovered + ResponseReady --> [*]: Return to Client + + Registered --> CleanupCall: Bifrost Shutdown + CleanupCall --> [*]: Plugin Destroyed +``` + +## 6. Plugin Execution Flow + +### Normal Flow (No Short-Circuit) + +```mermaid +sequenceDiagram + participant Client + participant Bifrost + participant Plugin1 + participant Plugin2 + participant Provider + + Client->>Bifrost: Request + Bifrost->>Plugin1: PreHook(request) + Plugin1-->>Bifrost: modified request + Bifrost->>Plugin2: PreHook(request) + Plugin2-->>Bifrost: modified request + Bifrost->>Provider: API Call + Provider-->>Bifrost: response + Bifrost->>Plugin2: PostHook(response) + Plugin2-->>Bifrost: modified response + Bifrost->>Plugin1: PostHook(response) + Plugin1-->>Bifrost: modified response + Bifrost-->>Client: Final Response +``` + +### With Short-Circuit Response + +```mermaid +sequenceDiagram + participant Client + participant Bifrost + participant Plugin1 + participant Plugin2 + participant Provider + + Client->>Bifrost: Request + Bifrost->>Plugin1: PreHook(request) + Plugin1-->>Bifrost: PluginShortCircuit{Response} + Note over Provider: Provider call skipped + Bifrost->>Plugin1: PostHook(response) + Plugin1-->>Bifrost: modified response + Bifrost-->>Client: Final Response +``` + +### With Short-Circuit Error (Allow Fallbacks) + +```mermaid +sequenceDiagram + participant Client + participant Bifrost + participant Plugin1 + participant Provider1 + participant Provider2 + + Client->>Bifrost: Request (Provider1 + Fallback Provider2) + Bifrost->>Plugin1: PreHook(request) + Plugin1-->>Bifrost: PluginShortCircuit{Error, AllowFallbacks=true} + Note over Provider1: Provider1 call skipped + Bifrost->>Plugin1: PostHook(error) + Plugin1-->>Bifrost: error unchanged + + Note over Bifrost: Try fallback provider + Bifrost->>Plugin1: PreHook(request for Provider2) + Plugin1-->>Bifrost: modified request + Bifrost->>Provider2: API Call + Provider2-->>Bifrost: response + Bifrost->>Plugin1: PostHook(response) + Plugin1-->>Bifrost: modified response + Bifrost-->>Client: Final Response +``` + +### Complex Plugin Decision Flow + +```mermaid +graph TD + A["Client Request"] --> B["Bifrost"] + B --> C["Auth Plugin PreHook"] + C --> D{"Authenticated?"} + D -->|No| E["Return Auth Error
AllowFallbacks=false"] + D -->|Yes| F["RateLimit Plugin PreHook"] + F --> G{"Rate Limited?"} + G -->|Yes| H["Return Rate Error
AllowFallbacks=nil"] + G -->|No| I["Cache Plugin PreHook"] + I --> J{"Cache Hit?"} + J -->|Yes| K["Return Cached Response"] + J -->|No| L["Provider API Call"] + L --> M["Cache Plugin PostHook"] + M --> N["Store in Cache"] + N --> O["RateLimit Plugin PostHook"] + O --> P["Auth Plugin PostHook"] + P --> Q["Final Response"] + + E --> R["Skip Fallbacks"] + H --> S["Try Fallback Provider"] + K --> T["Skip Provider Call"] +``` + +## 7. Short-Circuit Behavior + +Plugins can short-circuit the normal flow in two ways: + +### 1. Short-Circuit with Response (Success) + +```go +func (p *CachePlugin) PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *PluginShortCircuit, error) { + if cachedResponse := p.getFromCache(req); cachedResponse != nil { + // Return cached response, skip provider call + return req, &PluginShortCircuit{ + Response: cachedResponse, + }, nil + } + return req, nil, nil +} +``` + +### 2. Short-Circuit with Error + +```go +func (p *AuthPlugin) PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *PluginShortCircuit, error) { + if !p.isAuthenticated(req) { + // Return error, skip provider call + return req, &PluginShortCircuit{ + Error: &BifrostError{ + Error: ErrorField{Message: "authentication failed"}, + AllowFallbacks: &false, // Don't try other providers + }, + }, nil + } + return req, nil, nil +} +``` + +## 8. Error Handling & Fallbacks + +When plugins return errors, they control whether Bifrost should try fallback providers: + +### AllowFallbacks Control + +```go +// Allow fallbacks (default behavior) +&BifrostError{ + Error: ErrorField{Message: "rate limit exceeded"}, + AllowFallbacks: nil, // nil = true by default +} + +// Explicitly allow fallbacks +&BifrostError{ + Error: ErrorField{Message: "temporary failure"}, + AllowFallbacks: &true, +} + +// Prevent fallbacks +&BifrostError{ + Error: ErrorField{Message: "authentication failed"}, + AllowFallbacks: &false, +} +``` + +### Fallback Decision Matrix + +| Error Type | AllowFallbacks | Behavior | +| ------------------ | --------------- | ---------------------------------------------------------- | +| Rate Limiting | `nil` or `true` | βœ… Try fallbacks (other providers may not be rate limited) | +| Temporary Failure | `nil` or `true` | βœ… Try fallbacks (may succeed with different provider) | +| Authentication | `false` | ❌ No fallbacks (fundamental failure) | +| Validation Error | `false` | ❌ No fallbacks (request is invalid) | +| Security Violation | `false` | ❌ No fallbacks (security concern) | + +### PostHook Error Recovery + +Plugins can recover from errors in PostHook: + +```go +func (p *RetryPlugin) PostHook(ctx *context.Context, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError, error) { + if err != nil && p.shouldRetry(err) { + // Recover by calling provider again + if retryResponse := p.retry(ctx); retryResponse != nil { + return retryResponse, nil, nil // Recovered successfully + } + } + return result, err, nil +} +``` + +## 9. Building Custom Plugins + +### Basic Plugin Structure + +**For Dynamic Loading (Required)** + +All plugins must provide an `Init` function for dynamic loading: + +```go +package plugin_name + +import ( + "encoding/json" + "github.com/maximhq/bifrost/core/schemas" +) + +type CustomPluginConfig struct { + Setting1 string `json:"setting1"` + Setting2 int `json:"setting2"` +} + +// Required Init function for dynamic loading +// Signature: Init(config json.RawMessage) (schemas.Plugin, error) +func Init(configData json.RawMessage) (schemas.Plugin, error) { + var config CustomPluginConfig + if err := json.Unmarshal(configData, &config); err != nil { + return nil, err + } + + return NewCustomPlugin(config), nil +} + +type CustomPlugin struct { + config CustomPluginConfig + // Add your fields here +} + +func NewCustomPlugin(config CustomPluginConfig) *CustomPlugin { + return &CustomPlugin{config: config} +} + +func (p *CustomPlugin) GetName() string { + return "custom-plugin" +} + +func (p *CustomPlugin) PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *PluginShortCircuit, error) { + // Modify request or short-circuit + return req, nil, nil +} + +func (p *CustomPlugin) PostHook(ctx *context.Context, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError, error) { + // Modify response/error or recover from errors + return result, err, nil +} + +func (p *CustomPlugin) Cleanup() error { + // Clean up resources + return nil +} +``` + +**For Direct Instantiation (Legacy)** + +```go +// Traditional constructor for direct usage +func NewCustomPlugin(config CustomPluginConfig) *CustomPlugin { + return &CustomPlugin{config: config} +} +``` + +### Plugin Development Checklist + +- [ ] **Implement required Init function** with signature: `Init(config json.RawMessage) (schemas.Plugin, error)` +- [ ] **Use proper package name** (not `package main`) +- [ ] Handle nil response and error in PostHook +- [ ] Set appropriate AllowFallbacks for errors +- [ ] Implement proper cleanup in Cleanup() +- [ ] Add configuration validation and JSON tags +- [ ] Write comprehensive tests including Init function +- [ ] Document behavior and configuration + +## 10. Plugin Examples + +### Rate Limiting Plugin + +```go +type RateLimitPlugin struct { + limiters map[ModelProvider]*rate.Limiter + mu sync.RWMutex +} + +func NewRateLimitPlugin(limits map[ModelProvider]float64) *RateLimitPlugin { + limiters := make(map[ModelProvider]*rate.Limiter) + for provider, limit := range limits { + limiters[provider] = rate.NewLimiter(rate.Limit(limit), 1) + } + return &RateLimitPlugin{limiters: limiters} +} + +func (p *RateLimitPlugin) GetName() string { + return "RateLimitPlugin" +} + +func (p *RateLimitPlugin) PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *PluginShortCircuit, error) { + p.mu.RLock() + limiter, exists := p.limiters[req.Provider] + p.mu.RUnlock() + + if exists && !limiter.Allow() { + // Rate limited - allow fallbacks to other providers + return req, &PluginShortCircuit{ + Error: &BifrostError{ + Error: ErrorField{ + Message: fmt.Sprintf("rate limit exceeded for %s", req.Provider), + }, + AllowFallbacks: nil, // Allow fallbacks by default + }, + }, nil + } + + return req, nil, nil +} + +func (p *RateLimitPlugin) PostHook(ctx *context.Context, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError, error) { + return result, err, nil +} + +func (p *RateLimitPlugin) Cleanup() error { + return nil +} +``` + +### Authentication Plugin + +```go +type AuthPlugin struct { + validator TokenValidator +} + +func NewAuthPlugin(validator TokenValidator) *AuthPlugin { + return &AuthPlugin{validator: validator} +} + +func (p *AuthPlugin) GetName() string { + return "AuthPlugin" +} + +func (p *AuthPlugin) PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *PluginShortCircuit, error) { + if !p.validator.IsValid(*ctx, req) { + // Authentication failed - don't try fallbacks + return req, &PluginShortCircuit{ + Error: &BifrostError{ + Error: ErrorField{ + Message: "authentication failed", + Type: &authErrorType, + }, + AllowFallbacks: &false, // Don't try other providers + }, + }, nil + } + + return req, nil, nil +} + +func (p *AuthPlugin) PostHook(ctx *context.Context, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError, error) { + return result, err, nil +} + +func (p *AuthPlugin) Cleanup() error { + return p.validator.Cleanup() +} +``` + +### Caching Plugin with Recovery + +```go +type CachePlugin struct { + cache Cache + ttl time.Duration +} + +func (p *CachePlugin) PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *PluginShortCircuit, error) { + key := p.generateKey(req) + if cachedResponse := p.cache.Get(key); cachedResponse != nil { + // Return cached response, skip provider + return req, &PluginShortCircuit{ + Response: cachedResponse, + }, nil + } + + return req, nil, nil +} + +func (p *CachePlugin) PostHook(ctx *context.Context, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError, error) { + if result != nil { + // Cache successful response + key := p.generateKeyFromResponse(result) + p.cache.Set(key, result, p.ttl) + } + + return result, err, nil +} +``` + +## 11. Best Practices + +### Plugin Design + +1. **Keep plugins focused** - Each plugin should have a single responsibility +2. **Make plugins configurable** - Use configuration structs for flexibility +3. **Handle edge cases** - Always check for nil values and error conditions +4. **Be mindful of performance** - Plugins add latency to every request + +### Error Handling + +1. **Default to allowing fallbacks** - Unless the error is fundamental +2. **Use appropriate error types** - Help categorize different failure modes +3. **Provide clear error messages** - Include context about what failed +4. **Consider error recovery** - PostHooks can recover from certain errors + +### Resource Management + +1. **Implement proper cleanup** - Release resources in Cleanup() +2. **Use context for cancellation** - Respect request timeouts +3. **Avoid memory leaks** - Clean up goroutines and connections +4. **Handle concurrent access** - Use proper synchronization + +### Testing + +1. **Test all code paths** - Including error conditions and edge cases +2. **Test short-circuit behavior** - Verify responses and error handling +3. **Test fallback control** - Ensure AllowFallbacks works correctly +4. **Test plugin interactions** - Verify behavior with multiple plugins + +## 12. Plugin Development Guidelines + +### Plugin Structure Requirements + +Each plugin should be organized as follows: + +```text +plugins/ +└── your-plugin-name/ + β”œβ”€β”€ main.go # Plugin implementation with Init function + β”œβ”€β”€ plugin_test.go # Comprehensive tests including Init + β”œβ”€β”€ README.md # Documentation with examples + └── go.mod # Module definition +``` + +### Required Components + +**1. Init Function (Mandatory)** + +Every plugin must implement the standardized Init function: + +```go +// Signature defined in schemas.Init +func Init(config json.RawMessage) (schemas.Plugin, error) +``` + +**2. Configuration Struct** + +Define a configuration struct with JSON tags: + +```go +type YourPluginConfig struct { + APIKey string `json:"api_key"` + Timeout int `json:"timeout"` + EnableX bool `json:"enable_x"` +} +``` + +**3. Plugin Implementation** + +Implement the schemas.Plugin interface: + +```go +func (p *YourPlugin) GetName() string +func (p *YourPlugin) PreHook(...) (...) +func (p *YourPlugin) PostHook(...) (...) +func (p *YourPlugin) Cleanup() error +``` + +### Using Plugins + +```go +import ( + "github.com/maximhq/bifrost/core" + "github.com/your-org/your-plugin" +) + +client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &yourAccount, + Plugins: []schemas.Plugin{ + your_plugin.NewYourPlugin(config), + // Add more plugins as needed + }, +}) +``` + +### Plugin Execution Order + +Plugins execute in the order they are registered: + +```go +Plugins: []schemas.Plugin{ + authPlugin, // PreHook: 1st, PostHook: 3rd + rateLimitPlugin, // PreHook: 2nd, PostHook: 2nd + loggingPlugin, // PreHook: 3rd, PostHook: 1st +} +``` + +**PreHook Order**: Auth β†’ RateLimit β†’ Logging β†’ Provider +**PostHook Order**: Provider β†’ Logging β†’ RateLimit β†’ Auth + +### Contribution Guidelines + +1. **Design Discussion** + + - Open an issue to discuss your plugin idea + - Explain the use case and design approach + - Get feedback before implementation + +2. **Implementation Standards** + + - Follow Go best practices and conventions + - Include comprehensive error handling + - Ensure thread safety where needed + - Add extensive test coverage (>80%) + +3. **Testing Requirements** + + - Unit tests for all functionality + - Integration tests with Bifrost + - Test error scenarios and edge cases + - Test short-circuit behavior + - Test fallback control + +4. **Documentation Standards** + - Clear, comprehensive README + - Code comments for complex logic + - Usage examples + - Performance characteristics + +### Plugin Testing Best Practices + +```go +func TestYourPlugin_PreHook(t *testing.T) { + tests := []struct { + name string + config YourPluginConfig + request *schemas.BifrostRequest + expectShortCircuit bool + expectError bool + expectFallbacks bool + }{ + { + name: "valid request passes through", + config: YourPluginConfig{EnableFeature: true}, + request: &schemas.BifrostRequest{/* valid request */}, + expectShortCircuit: false, + }, + { + name: "invalid request short-circuits with error", + config: YourPluginConfig{EnableFeature: true}, + request: &schemas.BifrostRequest{/* invalid request */}, + expectShortCircuit: true, + expectError: true, + expectFallbacks: false, + }, + // Add more test cases + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + plugin := NewYourPlugin(tt.config) + ctx := context.Background() + + req, shortCircuit, err := plugin.PreHook(&ctx, tt.request) + + // Assertions + if tt.expectError { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + } + + if tt.expectShortCircuit { + assert.NotNil(t, shortCircuit) + if shortCircuit.Error != nil && shortCircuit.Error.AllowFallbacks != nil { + assert.Equal(t, tt.expectFallbacks, *shortCircuit.Error.AllowFallbacks) + } + } else { + assert.Nil(t, shortCircuit) + } + }) + } +} +``` + +## 13. Troubleshooting Guide + +### Common Issues + +#### 1. Dynamic Plugin Loading Failures + +**Symptoms**: Plugin fails to compile or load dynamically +**Solutions**: + +```bash +# Check Go toolchain availability +go version +go env CGO_ENABLED + +# Verify plugin source exists +ls -la /path/to/plugin # for local plugins +go get github.com/user/plugin # test remote plugin access + +# Debug plugin compilation +export BIFROST_LOG_LEVEL=debug +./bifrost-http -config config.json +``` + +**Common dynamic loading errors**: + +- `plugin was built with a different version`: Version mismatch between main app and plugin dependencies +- `plugin initialization failed`: Plugin's `Init` function returned an error (check configuration) +- `go: module not found`: Remote plugin source is inaccessible or doesn't exist +- `plugin: symbol Init not found`: Plugin missing required `Init` function with correct signature +- `plugin: Init function signature mismatch`: Init function doesn't match required signature `func(json.RawMessage) (schemas.Plugin, error)` + +#### 2. Plugin Not Being Called + +**Symptoms**: Plugin hooks are never executed +**Solutions**: + +```go +// Ensure plugin is properly registered +client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{ + yourPlugin, // Make sure it's in the list + }, +}) + +// Check plugin implements interface correctly +var _ schemas.Plugin = (*YourPlugin)(nil) +``` + +#### 3. Short-Circuit Not Working + +**Symptoms**: Provider is still called despite returning PluginShortCircuit +**Solutions**: + +```go +// Correct: Either Response OR Error, not both +return req, &schemas.PluginShortCircuit{ + Response: cachedResponse, // OR Error, not both +}, nil + +// Incorrect: Don't return error with PluginShortCircuit +return req, &schemas.PluginShortCircuit{...}, fmt.Errorf("error") +``` + +#### 4. Fallback Behavior Not Working + +**Symptoms**: Fallbacks not tried when expected, or tried when they shouldn't be +**Solutions**: + +```go +// For PreHook short-circuits, use PluginShortCircuit +return req, &schemas.PluginShortCircuit{ + Error: &schemas.BifrostError{ + Error: schemas.ErrorField{Message: "error"}, + AllowFallbacks: &false, // Explicitly control fallbacks + }, +}, nil +``` + +#### 5. Memory Leaks + +**Solutions**: + +```go +func (p *YourPlugin) Cleanup() error { + // Close channels + close(p.stopChan) + + // Cancel contexts + p.cancel() + + // Close connections + if p.conn != nil { + p.conn.Close() + } + + // Wait for goroutines + p.wg.Wait() + + return nil +} +``` + +#### 6. Race Conditions + +**Solutions**: + +```go +type ThreadSafePlugin struct { + mu sync.RWMutex + state map[string]interface{} +} + +func (p *ThreadSafePlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + p.mu.Lock() + defer p.mu.Unlock() + + // Safe access to shared state + p.state[req.ID] = "processing" + return req, nil, nil +} +``` + +## 14. Performance Optimization + +1. **Minimize Hook Latency** + + - Avoid blocking operations in hooks + - Use goroutines for background work + - Cache expensive computations + +2. **Efficient Resource Usage** + + - Pool connections and resources + - Use sync.Pool for frequently allocated objects + - Implement proper cleanup + +3. **Monitor Memory Usage** + - Profile your plugin under load + - Watch for memory leaks + - Use appropriate data structures + +## Summary + +This documentation provides complete coverage for Bifrost plugin development: + +- **Architecture & Lifecycle** - Understanding the plugin system and execution flow +- **Interface & Behavior** - Exact method signatures and short-circuit capabilities +- **Error Handling** - Complete control over fallback behavior with AllowFallbacks +- **Practical Examples** - Real-world plugins for rate limiting, auth, and caching +- **Development Guidelines** - Best practices, testing, and contribution standards +- **Troubleshooting** - Solutions for common issues and performance optimization diff --git a/docs/providers.md b/docs/providers.md new file mode 100644 index 0000000000..475d2dd496 --- /dev/null +++ b/docs/providers.md @@ -0,0 +1,379 @@ +# Bifrost Provider System + +Bifrost supports multiple AI model providers, each with its own configuration options and capabilities. This document explains how to configure providers and develop new ones. + +## 1. Supported Providers + +Bifrost currently supports the following providers: + +- OpenAI +- Anthropic +- Azure +- Bedrock +- Cohere +- Vertex +- Mistral +- Ollama + +## 2. Provider Configuration + +### Basic Configuration Structure + +```golang +schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.custom-deployment.com", // Custom base URL (optional) + ExtraHeaders: map[string]string{ // Additional headers (optional) + "X-Organization-ID": "org-123", + "X-Environment": "production", + "User-Agent": "MyApp/1.0 Bifrost/1.0", + }, + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 2, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, // Number of concurrent requests + BufferSize: 10, // Maximum requests in queue + }, + ProxyConfig: &schemas.ProxyConfig{ + Type: schemas.HttpProxy, + URL: "http://your-proxy:port", + }, +} +``` + +### Default Values + +```golang +const ( + DefaultMaxRetries = 0 + DefaultRetryBackoffInitial = 500 * time.Millisecond + DefaultRetryBackoffMax = 5 * time.Second + DefaultRequestTimeoutInSeconds = 30 + DefaultBufferSize = 100 + DefaultConcurrency = 10 +) +``` + +## 3. Provider-Specific Meta Configurations + +Few providers new meta configs for their setup. + +### Azure + +```golang +meta.AzureMetaConfig{ + Endpoint: "https://your-resource.openai.azure.com", + APIVersion: "2024-02-15-preview", + Deployments: map[string]string{ + "gpt-4": "gpt-4-deployment", + "gpt-35-turbo": "gpt-35-turbo-deployment", + }, +} +``` + +### Bedrock + +```golang +meta.BedrockMetaConfig{ + SecretAccessKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), + Region: "us-east-1", + SessionToken: os.Getenv("BEDROCK_SESSION_TOKEN"), // Optional + ARN: os.Getenv("BEDROCK_ARN"), // Optional + InferenceProfiles: map[string]string{ + "gpt-4": "gpt-4-deployment-profile", + } +} +``` + +### Vertex + +```golang +meta.VertexMetaConfig{ + ProjectID: os.Getenv("VERTEX_PROJECT_ID"), + Location: "us-central1", + AuthCredentials: os.Getenv("VERTEX_AUTH_CREDENTIALS"), // GCP Auth creds +} +``` + +## 4. API Key Management + +### Key Weights + +Bifrost supports weighted distribution of requests across multiple API keys. The weight determines the relative frequency of key usage: + +- Weights are normalized (sum to 1.0) +- Higher weight = more frequent usage +- Equal weights if not specified +- Model-specific key assignment + +Example with weights: + +```golang +[]schemas.Key{ + { + Value: os.Getenv("OPEN_AI_API_KEY1"), + Models: []string{"gpt-4", "gpt-4-turbo"}, + Weight: 0.6, // 60% of requests for these models + }, + { + Value: os.Getenv("OPEN_AI_API_KEY2"), + Models: []string{"gpt-4-turbo"}, + Weight: 0.3, // 30% of requests for gpt-4-turbo + }, + { + Value: os.Getenv("OPEN_AI_API_KEY3"), + Models: []string{"gpt-4"}, + Weight: 0.1, // 10% of requests for gpt-4 + }, +} +``` + +### Key Selection Logic + +1. Filters keys that support the requested model +2. Normalizes weights of available keys +3. Uses weighted random selection +4. Falls back to first available key if selection fails + +## 5. Proxy Configuration + +Bifrost supports various proxy types for provider connections: + +### HTTP Proxy + +```golang +schemas.ProxyConfig{ + Type: schemas.HttpProxy, + URL: "http://proxy.example.com:8080", + Username: "user", // Optional + Password: "pass", // Optional +} +``` + +### SOCKS5 Proxy + +```golang +schemas.ProxyConfig{ + Type: schemas.Socks5Proxy, + URL: "socks5://proxy.example.com:1080", + Username: "user", // Optional + Password: "pass", // Optional +} +``` + +### Environment Proxy + +```golang +schemas.ProxyConfig{ + Type: schemas.EnvProxy, + // Uses HTTP_PROXY, HTTPS_PROXY environment variables +} +``` + +### Proxy Best Practices + +1. **Security** + + - Use HTTPS proxies when possible + - Rotate proxy credentials regularly + - Monitor proxy performance + +2. **Performance** + + - Choose geographically close proxies + - Monitor proxy latency + - Implement proxy fallbacks + +3. **Configuration** + + - Set appropriate timeouts + - Configure retry policies + - Monitor proxy errors + +## 6. Extra Headers Configuration + +Bifrost supports custom headers that can be added to all requests sent to a provider. This is useful for enterprise deployments, custom authentication, or provider-specific requirements. + +### Configuration + +Extra headers are configured in the `NetworkConfig` section: + +```golang +schemas.NetworkConfig{ + ExtraHeaders: map[string]string{ + "X-Organization-ID": "org-123", + "X-Environment": "production", + "User-Agent": "MyApp/1.0 Bifrost/1.0", + "X-Custom-Auth": "bearer-token-xyz", + }, +} +``` + +### JSON Configuration + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o-mini", "gpt-4"], + "weight": 1.0 + } + ], + "network_config": { + "extra_headers": { + "X-Organization-ID": "org-123", + "X-Environment": "production", + "User-Agent": "MyApp/1.0 Bifrost/1.0" + } + } + } + } +} +``` + +### Use Cases + +1. **Enterprise Deployments** + + - Organization or tenant identification + - Custom authentication headers + - Environment tracking (dev/staging/prod) + +2. **Self-hosted Providers** + + - Custom routing headers for Ollama deployments + - Load balancer identification + - Custom API versions + +3. **Monitoring & Observability** + + - Request source identification + - Custom correlation IDs + - Application version tracking + +4. **Provider-specific Requirements** + - Beta feature flags + - Custom API versions + - Regional preferences + +### Header Precedence + +Headers configured in `extra_headers` are applied before mandatory provider headers. If there are conflicts (such as duplicate header names), the mandatory headers will take precedence and overwrite or ignore the `extra_headers` values. This ensures that critical provider functionality is not compromised by custom header configurations. + +**Important Notes:** + +- Authorization headers are automatically filtered out from `extra_headers` for security reasons +- Provider-specific mandatory headers (like API keys, content-type, etc.) always take precedence +- Custom headers should not conflict with standard HTTP headers required by the provider + +### Best Practices + +1. **Security** + + - Use environment variables for sensitive headers + - Avoid hardcoding authentication tokens + - Review headers regularly for security implications + +2. **Performance** + + - Keep header count minimal for performance + - Use short, descriptive header names + - Monitor header impact on request size + +3. **Compliance** + - Document custom headers for audit purposes + - Ensure headers comply with HTTP standards + - Validate header values before deployment + +## 7. Provider Development Guidelines + +### 1. Provider Structure + +All providers should be implemented in the `core/providers` directory. The structure should be: + +```text +core/ +β”œβ”€β”€ providers/ +β”‚ β”œβ”€β”€ your_provider.go # Provider implementation +β”‚ └── ... # Other provider implementations +└── schemas/ + └── meta/ + └── your_provider.go # Provider-specific meta configuration +``` + +### 2. Provider Interface + +```golang +type Provider interface { + // GetProviderKey returns the provider's identifier + GetProviderKey() ModelProvider + + // TextCompletion performs a text completion request + TextCompletion(model, key, text string, params *ModelParameters) (*BifrostResponse, *BifrostError) + + // ChatCompletion performs a chat completion request + ChatCompletion(model, key string, messages []Message, params *ModelParameters) (*BifrostResponse, *BifrostError) +} +``` + +### 3. Meta Configuration + +If your provider requires additional configuration beyond the standard `ProviderConfig`, implement a meta configuration in `core/schemas/meta/your_provider.go`: + +```golang +// YourProviderMetaConfig implements the MetaConfig interface +type YourProviderMetaConfig struct { + // Add your provider-specific fields here + Endpoint string `json:"endpoint"` + APIVersion string `json:"api_version"` + // ... other fields +} + +// Implement all required methods from the MetaConfig interface +func (c *YourProviderMetaConfig) GetSecretAccessKey() *string { /* ... */ } +func (c *YourProviderMetaConfig) GetRegion() *string { /* ... */ } +// ... implement other interface methods +``` + +The meta configuration must implement all methods from the `MetaConfig` interface defined in `core/schemas/provider.go`. Return `nil` for methods that don't apply to your provider. + +### 4. Development Process + +1. Open an issue to discuss the new provider +2. Create a pull request with: + - Provider implementation in `core/providers/` + - Addition of provider key in `ModelProvider` in `/core/schemas/bifrost.go` + - Meta configuration in `core/schemas/meta/` (if needed) + - Tests in `core/tests` with `Test{ProviderName}` function name. + - Documentation update in `docs/providers.go` + +### 5. Implementation Requirements + +1. **Error Handling** + + - Use standard Bifrost error types + - Gracefully handling and logging (using bifrost logger) all runtime errors + +2. **Configuration** + + - Support provider-specific settings through meta configuration (if needed) + - Implement default values + - Validate configuration + - Implement sync pools for optimized resource allocations + +3. **Testing** + + - Unit tests for all methods (using `core/tests/setup.go` file) + - Integration tests + - Error case coverage + +4. **Documentation** + - Provider capabilities + - Configuration options + - Meta configuration usage diff --git a/docs/system-architecture.md b/docs/system-architecture.md new file mode 100644 index 0000000000..d7bf2befcb --- /dev/null +++ b/docs/system-architecture.md @@ -0,0 +1,671 @@ +# Bifrost System Architecture + +## Overview + +Bifrost is designed as a high-performance, horizontally scalable middleware that acts as a unified gateway to multiple AI model providers. The architecture is specifically optimized to handle **10,000+ requests per second (RPS)** through sophisticated concurrency management, memory optimization, and connection pooling strategies. + +## Core Architecture Principles + +### 1. **Asynchronous Request Processing** + +Bifrost uses a channel-based worker pool architecture where each provider maintains its own queue of workers to process requests concurrently. + +### 2. **Memory Pool Management** + +Advanced object pooling minimizes garbage collection pressure and memory allocations during high-load scenarios. + +### 3. **Provider Isolation** + +Each AI provider operates in its own isolated context with dedicated configuration, workers, and resource management. + +### 4. **Plugin-First Design** + +Extensible plugin architecture allows for custom logic injection without modifying core functionality. + +--- + +## High-Level System Architecture + +```mermaid +graph TB + subgraph "Client Layer" + HTTP[HTTP Transport] + SDK[Go SDK] + gRPC[gRPC Transport] + end + + subgraph "Bifrost Core" + LB[Load Balancer/Router] + PM[MCP Manager] + subgraph "Request Processing" + PP[Plugin Pipeline] + RQ[Request Queue Manager] + WP[Worker Pool Manager] + end + subgraph "Memory Management" + CP[Channel Pool] + RP[Response Pool] + MP[Message Pool] + end + end + + subgraph "Provider Layer" + subgraph "OpenAI Workers" + OW1[Worker 1] + OW2[Worker 2] + OWN[Worker N] + end + subgraph "Anthropic Workers" + AW1[Worker 1] + AW2[Worker 2] + AWN[Worker N] + end + subgraph "Other Providers" + PW1[Bedrock Workers] + PW2[Azure Workers] + PWN[Other Workers] + end + end + + subgraph "External Systems" + OPENAI[OpenAI API] + ANTHROPIC[Anthropic API] + BEDROCK[Amazon Bedrock] + AZURE[Azure OpenAI] + MCP[MCP Servers] + end + + HTTP --> LB + SDK --> LB + gRPC --> LB + LB --> PM + PM --> PP + PP --> RQ + RQ --> WP + WP --> CP + WP --> RP + WP --> MP + + WP --> OW1 + WP --> AW1 + WP --> PW1 + + OW1 --> OPENAI + OW2 --> OPENAI + OWN --> OPENAI + + AW1 --> ANTHROPIC + AW2 --> ANTHROPIC + AWN --> ANTHROPIC + + PW1 --> BEDROCK + PW2 --> AZURE + + PM --> MCP +``` + +## Getting Started + +To quickly deploy Bifrost and start using it at scale, see the [HTTP Transport API Documentation](./http-transport-api.md) for: + +- **Quick Setup**: Docker and binary deployment options +- **Configuration Examples**: Sample configs for different use cases +- **API Usage**: Complete API reference and examples +- **Performance Tuning**: Optimization settings for high-scale deployments + +--- + +## Detailed Component Architecture + +### 1. Request Flow Architecture + +The request processing pipeline is designed for maximum throughput and minimal latency: + +```mermaid +sequenceDiagram + participant Client + participant Transport + participant Bifrost + participant Plugin + participant Provider + participant AIService + + Client->>Transport: HTTP/SDK Request + Transport->>Bifrost: BifrostRequest + Bifrost->>Plugin: PreHook() + Plugin-->>Bifrost: Modified Request + + Bifrost->>Bifrost: Get Channel from Pool + Bifrost->>Bifrost: Select API Key (Weighted) + Bifrost->>Provider: Queue Request + + Provider->>Provider: Worker Picks Up Request + Provider->>AIService: HTTP Request + AIService-->>Provider: HTTP Response + + Provider->>Bifrost: Response/Error + Bifrost->>Plugin: PostHook() + Plugin-->>Bifrost: Modified Response + + Bifrost->>Bifrost: Return Channel to Pool + Bifrost-->>Transport: BifrostResponse + Transport-->>Client: HTTP/SDK Response +``` + +#### Key Components: + +- **Transport Layer**: HTTP, gRPC, or Go SDK entry points +- **Plugin Pipeline**: Pre/Post hooks for custom logic injection +- **Memory Pools**: Object reuse to minimize GC pressure +- **Worker Pools**: Provider-specific concurrent request processors +- **Key Management**: Weighted distribution across multiple API keys + +#### Detailed Request Processing Flow + +```mermaid +flowchart TD + subgraph "Request Processing Flow" + A[Incoming Request] --> B{Request Type?} + B -->|Text Completion| C[TextCompletionRequest] + B -->|Chat Completion| D[ChatCompletionRequest] + C --> E[Validate Request] + D --> E + E --> F[Get Channel Message from Pool] + F --> G[Apply Plugin PreHooks] + G --> H{Short Circuit?} + H -->|Yes| I[Return Early Response] + H -->|No| J[Select Provider & Model] + J --> K[Get API Key for Provider] + K --> L[Add to Provider Queue] + L --> M[Worker Processes Request] + M --> P[Make API Request] + P --> T[Parse Response] + T --> U[Apply Plugin PostHooks] + U --> V[Return Channel Message to Pool] + V --> W[Return Response to Client] + I --> V + end + + subgraph "Error Handling" + X[Error Occurred] --> Y{Retryable?} + Y -->|Yes| Z[Apply Backoff] + Z --> P + Y -->|No| BB{Fallback Available?} + BB -->|Yes| CC[Try Fallback Provider] + CC --> J + BB -->|No| AA[Return Error Response] + AA --> U + end +``` + +This diagram illustrates the complete request lifecycle including error handling and the plugin pipeline. Note that when tool calls are present in the response, Bifrost returns them to the client for execution rather than executing them automatically. + +### 2. Memory Management Architecture + +Bifrost's memory management system is optimized for high-throughput scenarios with minimal garbage collection impact. See [Memory Management Documentation](./memory-management.md) for detailed configuration options. + +#### Object Pooling Strategy: + +1. **Channel Pools**: Pre-allocated channels for request/response communication +2. **Message Pools**: Reusable `ChannelMessage` objects to reduce allocations +3. **Response Pools**: Pre-allocated response structures + +#### Configuration Impact: + +- `InitialPoolSize`: Controls initial memory allocation (default: 100) +- Higher values reduce runtime allocations but increase memory usage +- Optimal setting: Match expected concurrent request volume + +### 3. Provider Worker Pool Architecture + +Each AI provider operates with its own isolated worker pool system: + +```mermaid +graph TB + subgraph "Provider Worker Pool" + Queue[Request Queue] + subgraph "Workers" + W1[Worker 1] + W2[Worker 2] + W3[Worker 3] + WN[Worker N] + end + subgraph "Key Management" + KS[Key Selector
Weighted Distribution] + K1[API Key 1
Weight: 0.6] + K2[API Key 2
Weight: 0.3] + K3[API Key 3
Weight: 0.1] + end + end + + subgraph "Provider API" + API[AI Provider API
OpenAI/Anthropic/etc.] + end + + Queue --> W1 + Queue --> W2 + Queue --> W3 + Queue --> WN + + W1 --> KS + W2 --> KS + W3 --> KS + WN --> KS + + KS --> K1 + KS --> K2 + KS --> K3 + + K1 --> API + K2 --> API + K3 --> API +``` + +#### Worker Pool Characteristics: + +- **Isolated Queues**: Each provider has its own buffered channel queue +- **Configurable Concurrency**: Number of workers per provider (default: 10) +- **Buffer Management**: Configurable queue size (default: 100) +- **Load Distribution**: Weighted API key selection for load balancing + +#### Performance Tuning: + +- **Concurrency**: Higher values increase throughput but consume more resources +- **Buffer Size**: Larger buffers handle request spikes but use more memory +- **Drop Excess Requests**: Optional fail-fast behavior when queues are full + +See [Provider Configuration Documentation](./providers.md) for detailed configuration options. + +--- + +## High-Performance Features + +### 1. Connection Pooling and Keep-Alive + +Bifrost maintains persistent HTTP connections to reduce connection overhead: + +- **HTTP/2 Support**: Multiplexed connections where supported +- **Connection Reuse**: Persistent connections with keep-alive +- **Custom Timeouts**: Configurable request timeouts per provider +- **Retry Logic**: Exponential backoff for failed requests + +### 2. Dynamic Key Management + +Advanced API key management system for optimal performance: + +```go +type Key struct { + Value string // The actual API key value + Models []string // List of models this key can access + Weight float64 // Weight for load balancing (0.0-1.0) +} +``` + +#### Key Selection Process: + +1. **Model Filtering**: Keys are filtered by model compatibility +2. **Weight Normalization**: Weights are normalized to sum to 1.0 +3. **Weighted Random Selection**: Keys are selected based on weight distribution +4. **Fallback Logic**: Falls back to first available key if selection fails + +### 3. Fallback System Architecture + +Robust fallback mechanism for high availability. See [Fallback Documentation](./fallbacks.md) for complete configuration guide. + +```mermaid +graph TD + subgraph "Primary Request" + PR[Primary Provider
OpenAI gpt-4] + PF{Request Fails?} + end + + subgraph "Fallback Chain" + F1[Fallback 1
Anthropic claude-3-sonnet] + F1F{Fails?} + F2[Fallback 2
Bedrock claude-3-sonnet] + F2F{Fails?} + F3[Fallback 3
Azure gpt-4] + end + + subgraph "Response" + SUCCESS[Return Response] + ERROR[Return Error] + end + + PR --> PF + PF -->|Yes| F1 + PF -->|No| SUCCESS + + F1 --> F1F + F1F -->|Yes| F2 + F1F -->|No| SUCCESS + + F2 --> F2F + F2F -->|Yes| F3 + F2F -->|No| SUCCESS + + F3 --> SUCCESS + F3 -->|All Failed| ERROR +``` + +#### Fallback Characteristics: + +- **Sequential Processing**: Fallbacks are tried in order until one succeeds +- **Independent Configuration**: Each fallback provider uses its own settings +- **Model Compatibility**: Ensures fallback models support required features +- **Error Propagation**: Detailed error information from each attempt + +### 4. Plugin Architecture + +Extensible plugin system for custom logic injection. See [Plugin Documentation](./plugins.md) for usage and development guide. + +```go +type Plugin interface { + GetName() string + PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *PluginShortCircuit, error) + PostHook(ctx *context.Context, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError, error) + Cleanup() error +} +``` + +#### Plugin Pipeline Features: + +- **Pre-Hook Processing**: Request modification before provider call +- **Post-Hook Processing**: Response modification after provider call +- **Short-Circuit Support**: Skip provider calls for cached responses +- **Error Recovery**: Plugins can recover from errors or invalidate responses +- **Symmetric Execution**: PostHooks run in reverse order of PreHooks + +### 5. Model Context Protocol (MCP) Integration + +Built-in MCP support for external tool integration. Bifrost integrates with MCP servers to provide tool capabilities to AI models, but the actual tool execution is handled by the client application: + +```mermaid +graph TB + subgraph "MCP Architecture" + Client[Client Application] + Bifrost[Bifrost Core] + MCP[MCP Manager] + subgraph "MCP Servers" + MCP1[MCP Server 1
File System Tools] + MCP2[MCP Server 2
Database Tools] + MCP3[MCP Server 3
API Tools] + end + AI[AI Provider
OpenAI/Anthropic/etc.] + end + + Client -->|Chat request| Bifrost + Bifrost -->|Get available tools| MCP + MCP -->|Tool schemas| MCP1 + MCP -->|Tool schemas| MCP2 + MCP -->|Tool schemas| MCP3 + MCP1 -->|Tool schemas| MCP + MCP2 -->|Tool schemas| MCP + MCP3 -->|Tool schemas| MCP + MCP -->|Combined tool schemas| Bifrost + Bifrost -->|Request + tool schemas| AI + AI -->|Response + tool calls| Bifrost + Bifrost -->|Response + tool calls| Client + Client -->|Tool execution request| Bifrost + Bifrost -->|Execute tools| MCP + MCP -->|Tool execution| MCP1 + MCP -->|Tool execution| MCP2 + MCP -->|Tool execution| MCP3 + MCP1 -->|Tool results| MCP + MCP2 -->|Tool results| MCP + MCP3 -->|Tool results| MCP + MCP -->|Tool results| Bifrost + Bifrost -->|Tool results| Client + Client -->|Continue conversation| Bifrost +``` + +**Key Points:** + +- **Tool Discovery**: Bifrost fetches available tools from MCP servers and includes them in AI requests +- **Tool Calls**: AI models return tool calls in their responses, which Bifrost passes through to the client +- **Client-Side Execution**: The client application is responsible for executing tool calls via MCP +- **Conversation Continuation**: After tool execution, clients can continue the conversation with tool results +- **Connection Types**: Support for HTTP, STDIO, and SSE connections +- **Client Filtering**: Include/exclude specific MCP clients/tools per request +- **Local Tool Hosting**: Host custom tools within Bifrost and use them in your requests. + +See [MCP Documentation](./mcp.md) for detailed configuration and usage examples. + +--- + +## Performance Benchmarks + +### Benchmark Results (5000 RPS Test) + +| Instance Type | Success Rate | Avg Latency | Peak Memory | Bifrost Overhead | +| ------------- | ------------ | ----------- | ----------- | ---------------- | +| t3.medium | 100.00% | 2.12s | 1312.79 MB | **59 Β΅s** | +| t3.xlarge | 100.00% | 1.61s | 3340.44 MB | **11 Β΅s** | + +#### Key Performance Metrics: + +- **Queue Wait Time**: 1.67 Β΅s (t3.xlarge) +- **Key Selection**: 10 ns (t3.xlarge) +- **Message Formatting**: 2.11 Β΅s (t3.xlarge) +- **JSON Marshaling**: 26.80 Β΅s (t3.xlarge) + +### Scaling Configuration Examples + +#### High-Throughput Configuration (10k+ RPS) + +```go +// Bifrost Configuration +bifrost.Init(schemas.BifrostConfig{ + Account: &account, + InitialPoolSize: 20000, // High pool size for memory optimization + DropExcessRequests: true, // Fail-fast when overloaded +}) + +// Provider Configuration +schemas.ProviderConfig{ + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 20000, // High concurrency for throughput + BufferSize: 30000, // Large buffer for request spikes + }, + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 2, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, +} +``` + +#### Memory-Optimized Configuration + +```go +// Lower memory usage, slightly higher latency +bifrost.Init(schemas.BifrostConfig{ + Account: &account, + InitialPoolSize: 250, // Standard pool size + DropExcessRequests: false, // Queue requests instead of dropping +}) + +// Provider Configuration +schemas.ProviderConfig{ + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 100, // Moderate concurrency + BufferSize: 250, // Standard buffer size + }, +} +``` + +--- + +## Multi-Provider Support + +Bifrost supports 8 AI model providers with unified interfaces: + +1. **OpenAI** - GPT models with function calling +2. **Anthropic** - Claude models with tool use +3. **Amazon Bedrock** - Multi-model platform with inference profiles +4. **Azure OpenAI** - Enterprise GPT deployment +5. **Google Vertex AI** - Gemini and other Google models +6. **Cohere** - Command and embedding models +7. **Mistral AI** - Mistral model family +8. **Ollama** - Local model deployment + +--- + +### Logging Architecture + +Comprehensive logging system with configurable levels. See [Logger Documentation](./logger.md) for setup guide. + +#### Log Levels: + +- **Debug**: Detailed execution traces +- **Info**: General operational information +- **Warn**: Non-critical issues and fallback usage +- **Error**: Critical errors requiring attention + +--- + +## Network and Security Features + +### Proxy Support + +Enterprise-grade proxy support for secure deployments: + +- **HTTP Proxy**: Standard HTTP proxy with authentication +- **SOCKS5 Proxy**: SOCKS5 proxy support +- **Environment Proxy**: Automatic proxy detection from environment +- **Per-Provider Configuration**: Different proxies per provider + +### Security Features + +- **API Key Rotation**: Hot-swappable API keys without downtime +- **Rate Limiting**: Built-in rate limiting and backoff strategies +- **Request Isolation**: Provider-level request isolation +- **Secure Defaults**: Secure configuration defaults + +--- + +## Transport Layer Architecture + +Bifrost supports multiple transport mechanisms for flexible integration: + +### HTTP Transport + +Full-featured HTTP API with OpenAPI specification: + +- **RESTful Endpoints**: Standard HTTP API patterns +- **Request/Response Validation**: JSON schema validation +- **Error Handling**: Structured error responses +- **Documentation**: Complete OpenAPI 3.0 specification + +See [HTTP Transport API Documentation](./http-transport-api.md) for complete API reference. + +### Go SDK + +Native Go integration for embedded usage: + +- **Type Safety**: Compile-time type checking +- **Context Support**: Full context.Context integration +- **Error Handling**: Structured error types +- **Memory Efficiency**: Direct object access without serialization + +### Future Transports + +Planned transport implementations: + +- **gRPC Transport**: High-performance binary protocol +- **WebSocket Transport**: Real-time streaming support + +--- + +## Configuration Management + +### Account Interface + +Central configuration management through the Account interface: + +```go +type Account interface { + GetConfiguredProviders() ([]ModelProvider, error) + GetKeysForProvider(providerKey ModelProvider) ([]Key, error) + GetConfigForProvider(providerKey ModelProvider) (*ProviderConfig, error) +} +``` + +### Dynamic Configuration + +- **Hot Reloading**: Update configurations without restart +- **Environment Variables**: Support for environment-based config +- **Validation**: Configuration validation at startup +- **Defaults**: Sensible defaults for all settings + +--- + +## Error Handling and Resilience + +### Error Classification + +Bifrost provides structured error handling with detailed error information: + +```go +type BifrostError struct { + EventID *string `json:"event_id,omitempty"` + Type *string `json:"type,omitempty"` + IsBifrostError bool `json:"is_bifrost_error"` + StatusCode *int `json:"status_code,omitempty"` + Error ErrorField `json:"error"` + AllowFallbacks *bool `json:"allow_fallbacks,omitempty"` +} +``` + +### Resilience Patterns + +- **Circuit Breaker**: Automatic failure detection and recovery +- **Bulkhead**: Resource isolation between providers +- **Timeout**: Configurable request timeouts +- **Retry**: Exponential backoff with jitter +- **Fallback**: Multi-level fallback chains + +--- + +## Development and Extension + +### Custom Provider Development + +Bifrost's modular architecture supports custom provider implementation: + +```go +type Provider interface { + GetProviderKey() ModelProvider + TextCompletion(ctx context.Context, model, key, text string, params *ModelParameters) (*BifrostResponse, *BifrostError) + ChatCompletion(ctx context.Context, model, key string, messages []BifrostMessage, params *ModelParameters) (*BifrostResponse, *BifrostError) +} +``` + +### Plugin Development + +Extensible plugin system for custom functionality: + +- **Dynamic Loading**: All plugins must implement `Init(config json.RawMessage) (schemas.Plugin, error)` function +- **Request Processing**: Modify requests before provider calls +- **Response Processing**: Transform responses after provider calls +- **Caching**: Implement custom caching strategies +- **Monitoring**: Add custom metrics and logging +- **Authentication**: Implement custom auth mechanisms + +For complete plugin development guide, see [Plugin Documentation](./plugins.md). + +--- + +## Conclusion + +Bifrost's architecture is specifically designed to handle enterprise-scale AI workloads with **10,000+ RPS** through: + +- **Advanced Concurrency**: Channel-based worker pools with configurable parallelism +- **Memory Optimization**: Object pooling and GC pressure reduction +- **Provider Isolation**: Independent scaling and configuration per provider +- **Extensibility**: Plugin architecture for custom logic +- **Resilience**: Multi-level fallback and error handling +- **Observability**: Built-in metrics and comprehensive logging + +The modular design allows for horizontal scaling, custom integrations, and enterprise-grade reliability while maintaining sub-millisecond overhead in the request processing pipeline. diff --git a/plugins/go.mod b/plugins/go.mod deleted file mode 100644 index 82e50b301d..0000000000 --- a/plugins/go.mod +++ /dev/null @@ -1,8 +0,0 @@ -module github.com/maximhq/bifrost/plugins - -go 1.24.1 - -require ( - github.com/maximhq/bifrost/core v1.0.1 - github.com/maximhq/maxim-go v0.1.1 -) diff --git a/plugins/go.sum b/plugins/go.sum deleted file mode 100644 index b8cb7b66eb..0000000000 --- a/plugins/go.sum +++ /dev/null @@ -1,4 +0,0 @@ -github.com/maximhq/bifrost/core v1.0.1 h1:B0u6o13faUexA+V0EUU0bsLW2dHg9+R2TZKQzPzCxlY= -github.com/maximhq/bifrost/core v1.0.1/go.mod h1:4+Ept2EnX1EEjH/mBuSwK7eE56znI/BCoCbIrx25/x8= -github.com/maximhq/maxim-go v0.1.1 h1:69uUQjjDPmUGcKg/M4/3AO0fbD+70Agt66pH/UCsI5M= -github.com/maximhq/maxim-go v0.1.1/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw= diff --git a/plugins/maxim-sdk.go b/plugins/maxim-sdk.go deleted file mode 100644 index c70ad59e7d..0000000000 --- a/plugins/maxim-sdk.go +++ /dev/null @@ -1,128 +0,0 @@ -// Package plugins provides plugins for the Bifrost system. -// This file contains the Plugin implementation using maxim's logger plugin for bifrost. -package plugins - -import ( - "context" - "fmt" - "time" - - "github.com/maximhq/bifrost/core/schemas" - - "github.com/maximhq/maxim-go" - "github.com/maximhq/maxim-go/logging" -) - -// NewMaximLogger initializes and returns a Plugin instance for Maxim's logger. -// -// Parameters: -// - apiKey: API key for Maxim SDK authentication -// - loggerId: ID for the Maxim logger instance -// -// Returns: -// - schemas.Plugin: A configured plugin instance for request/response tracing -// - error: Any error that occurred during plugin initialization -func NewMaximLoggerPlugin(apiKey string, loggerId string) (schemas.Plugin, error) { - // check if Maxim Logger variables are set - if apiKey == "" { - return nil, fmt.Errorf("apiKey is not set") - } - - if loggerId == "" { - return nil, fmt.Errorf("loggerId is not set") - } - - mx := maxim.Init(&maxim.MaximSDKConfig{ApiKey: apiKey}) - - logger, err := mx.GetLogger(&logging.LoggerConfig{Id: loggerId}) - if err != nil { - return nil, err - } - - plugin := &Plugin{logger} - - return plugin, nil -} - -// contextKey is a custom type for context keys to prevent key collisions in the context. -// It provides type safety for context values and ensures that context keys are unique -// across different packages. -type contextKey string - -// traceIDKey is the context key used to store and retrieve trace IDs. -// This constant provides a consistent key for tracking request traces -// throughout the request/response lifecycle. -const ( - traceIDKey contextKey = "traceID" -) - -// Plugin implements the schemas.Plugin interface for Maxim's logger. -// It provides request and response tracing functionality using the Maxim logger, -// allowing detailed tracking of requests and responses. -// -// Fields: -// - logger: A Maxim logger instance used for tracing requests and responses -type Plugin struct { - logger *logging.Logger -} - -// PreHook is called before a request is processed by Bifrost. -// It creates a new trace for the incoming request and stores the trace ID in the context. -// The trace includes request details that can be used for debugging and monitoring. -// -// Parameters: -// - ctx: Pointer to the context.Context that will store the trace ID -// - req: The incoming Bifrost request to be traced -// -// Returns: -// - *schemas.BifrostRequest: The original request, unmodified -// - error: Always returns nil as this implementation doesn't produce errors -// -// The trace ID format is "YYYYMMDD_HHmmssSSS" based on the current time. -// If the context is nil, tracing information will still be logged but not stored in context. -func (plugin *Plugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, error) { - traceID := time.Now().Format("20060102_150405000") - - trace := plugin.logger.Trace(&logging.TraceConfig{ - Id: traceID, - Name: maxim.StrPtr("bifrost"), - }) - - trace.SetInput(fmt.Sprintf("New Request Incoming: %v", req)) - - if ctx != nil { - // Store traceID in context - *ctx = context.WithValue(*ctx, traceIDKey, traceID) - } - - return req, nil -} - -// PostHook is called after a request has been processed by Bifrost. -// It retrieves the trace ID from the context and logs the response details. -// This completes the request trace by adding response information. -// -// Parameters: -// - ctxRef: Pointer to the context.Context containing the trace ID -// - res: The Bifrost response to be traced -// -// Returns: -// - *schemas.BifrostResponse: The original response, unmodified -// - error: Returns an error if the trace ID cannot be retrieved from the context -// -// If the context is nil or the trace ID is not found, an error will be returned -// but the response will still be passed through unmodified. -func (plugin *Plugin) PostHook(ctxRef *context.Context, res *schemas.BifrostResponse) (*schemas.BifrostResponse, error) { - // Get traceID from context - if ctxRef != nil { - ctx := *ctxRef - traceID, ok := ctx.Value(traceIDKey).(string) - if !ok { - return res, fmt.Errorf("traceID not found in context") - } - - plugin.logger.SetTraceOutput(traceID, fmt.Sprintf("Response: %v", res)) - } - - return res, nil -} diff --git a/plugins/maxim/README.md b/plugins/maxim/README.md new file mode 100644 index 0000000000..3ed09ea3d7 --- /dev/null +++ b/plugins/maxim/README.md @@ -0,0 +1,118 @@ +# Maxim-SDK Plugin for Bifrost + +This plugin integrates the Maxim SDK into Bifrost, enabling seamless observability and evaluation of LLM interactions. It captures and forwards inputs/outputs from Bifrost to the Maxim's observability platform. This facilitates end-to-end tracing, evaluation, and monitoring of your LLM-based application. + +## Usage for Bifrost Go Package + +1. Download the Plugin + + ```bash + go get github.com/maximhq/bifrost/plugins/maxim + ``` + +2. Initialise the Plugin + + ```go + maximPlugin, err := maxim.NewMaximLoggerPlugin("your_maxim_api_key", "your_maxim_log_repo_id") + if err != nil { + return nil, err + } + ``` + +3. Pass the plugin to Bifrost + +```go + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &yourAccount, + Plugins: []schemas.Plugin{maximPlugin}, + }) +``` + +## Usage for Bifrost HTTP Transport + +1. Set up the environment variables + + ```bash + export MAXIM_API_KEY=your_maxim_api_key + export MAXIM_LOG_REPO_ID=your_maxim_log_repo_id + ``` + +2. Set up flags to add the plugin + Add `maxim` to the `--plugins` flag + + e.g., `bifrost-http -config config.json -env .env -plugins maxim` + + For docker build + + ```bash + docker build -t bifrost-transports . + ``` + + Running the docker container + + ```bash + docker run -d \ + -p 8080:8080 \ + -v $(pwd)/config.json:/app/config/config.json \ + -e APP_PORT=8080 \ + -e MAXIM_API_KEY \ + -e MAXIM_LOG_REPO_ID \ + bifrost-transport + ``` + +## Viewing Your Traces + +1. Log in to your [Maxim Dashboard](https://getmaxim.ai/dashboard) +2. Navigate to your repository +3. View detailed llm traces, including: + - LLM inputs/outputs + - Tool usage patterns + - Performance metrics + - Cost analytics + +## Additional Features + +The plugin also supports custom `session-id`, `trace-id` and `generation-id` if the user wishes to log the generations to their custom logging implementation. To use it, pass your trace ID to the request context with the key `trace-id`, and similarly `generation-id` for generation ID. In these cases, no new trace/generation is created and the output is logged to your provided generation. Likewise, `session-id` can be used to add the traces to your generated session. + +e.g. + +```go + ctx = context.WithValue(ctx, "generation-id", "123") + + result, err := bifrostClient.ChatCompletionRequest(schemas.OpenAI, &schemas.BifrostRequest{ + Model: "gpt-4o", + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: ¶ms, + }, ctx) +``` + +HTTP transport offers out-of-the-box support for this feature (when the Maxim plugin is used). Pass `x-bf-maxim-session-id`, `x-bf-maxim-trace-id`, or `x-bf-maxim-generation-id` headers with your request to use this feature. + +## Testing Maxim Logger + +To test the Maxim Logger plugin, you'll need to set up the following environment variables: + +```bash +# Required environment variables +export MAXIM_API_KEY=your_maxim_api_key +export MAXIM_LOGGER_ID=your_maxim_log_repo_id +export OPENAI_API_KEY=your_openai_api_key +``` + +Then you can run the tests using: + +```bash +go test -run TestMaximLoggerPlugin +``` + +The test suite includes: + +- Plugin initialization tests +- Integration tests with Bifrost +- Error handling for missing environment variables + +Note: The tests make actual API calls to both Maxim and OpenAI, so ensure you have valid API keys and sufficient quota before running the tests. + +After the test is complete, you can check your traces on [Maxim's Dashboard](https://www.getmaxim.ai) diff --git a/plugins/maxim/go.mod b/plugins/maxim/go.mod new file mode 100644 index 0000000000..e6ba6c0b5b --- /dev/null +++ b/plugins/maxim/go.mod @@ -0,0 +1,39 @@ +module github.com/maximhq/bifrost/plugins/maxim + +go 1.24.1 + +require ( + github.com/maximhq/bifrost/core v1.1.6 + github.com/maximhq/maxim-go v0.1.3 +) + +require github.com/google/uuid v1.6.0 + +require ( + cloud.google.com/go/compute/metadata v0.7.0 // indirect + github.com/andybalholm/brotli v1.1.1 // indirect + github.com/aws/aws-sdk-go-v2 v1.36.3 // indirect + github.com/aws/aws-sdk-go-v2/config v1.29.14 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect + github.com/aws/smithy-go v1.22.3 // indirect + github.com/goccy/go-json v0.10.5 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/mark3labs/mcp-go v0.32.0 // indirect + github.com/spf13/cast v1.7.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.62.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/net v0.40.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.25.0 // indirect +) diff --git a/plugins/maxim/go.sum b/plugins/maxim/go.sum new file mode 100644 index 0000000000..079a2d8be1 --- /dev/null +++ b/plugins/maxim/go.sum @@ -0,0 +1,78 @@ +cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeOCw78U8ytSU= +cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo= +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= +github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= +github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= +github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= +github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= +github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= +github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= +github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/maximhq/bifrost/core v1.1.6 h1:rZrfPVcAfNggfBaOTdu/w+xNwDhW79bfexXsw8LRoMQ= +github.com/maximhq/bifrost/core v1.1.6/go.mod h1:yMRCncTgKYBIrECSRVxMbY3BL8CjLbipJlc644jryxc= +github.com/maximhq/maxim-go v0.1.3 h1:nVzdz3hEjZVxmWHARWIM+Yrn1Jp50qrsK4BA/sz2jj8= +github.com/maximhq/maxim-go v0.1.3/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.62.0 h1:8dKRBX/y2rCzyc6903Zu1+3qN0H/d2MsxPPmVNamiH0= +github.com/valyala/fasthttp v1.62.0/go.mod h1:FCINgr4GKdKqV8Q0xv8b+UxPV+H/O5nNFo3D+r54Htg= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= +golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/maxim/main.go b/plugins/maxim/main.go new file mode 100644 index 0000000000..68e301034a --- /dev/null +++ b/plugins/maxim/main.go @@ -0,0 +1,312 @@ +// Package maxim provides integration for Maxim's SDK as a Bifrost plugin. +// This file contains the main plugin implementation. +package maxim + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" + + "github.com/maximhq/maxim-go" + "github.com/maximhq/maxim-go/logging" +) + +// PluginName is the canonical name for the bifrost-maxim plugin. +const PluginName = "bifrost-maxim" + +// MaximConfig represents the configuration for the Maxim plugin +type MaximConfig struct { + APIKey string `json:"api_key"` + LogRepoID string `json:"log_repo_id"` +} + +// Init initializes the Maxim plugin from JSON configuration (for dynamic loading) +func Init(configData json.RawMessage) (schemas.Plugin, error) { + var config MaximConfig + if err := json.Unmarshal(configData, &config); err != nil { + return nil, fmt.Errorf("invalid maxim plugin config: %w", err) + } + + if config.APIKey == "" { + return nil, fmt.Errorf("maxim api_key is required") + } + if config.LogRepoID == "" { + return nil, fmt.Errorf("maxim log_repo_id is required") + } + + return NewMaximLoggerPlugin(config.APIKey, config.LogRepoID) +} + +// NewMaximLogger initializes and returns a Plugin instance for Maxim's logger. +// +// Parameters: +// - apiKey: API key for Maxim SDK authentication +// - logRepoId: ID for the Maxim logger instance +// +// Returns: +// - schemas.Plugin: A configured plugin instance for request/response tracing +// - error: Any error that occurred during plugin initialization +func NewMaximLoggerPlugin(apiKey string, logRepoId string) (schemas.Plugin, error) { + // check if Maxim Logger variables are set + if apiKey == "" { + return nil, fmt.Errorf("apiKey is not set") + } + + if logRepoId == "" { + return nil, fmt.Errorf("log repo id is not set") + } + + mx := maxim.Init(&maxim.MaximSDKConfig{ApiKey: apiKey}) + + logger, err := mx.GetLogger(&logging.LoggerConfig{Id: logRepoId}) + if err != nil { + return nil, err + } + + plugin := &Plugin{logger} + + return plugin, nil +} + +// ContextKey is a custom type for context keys to prevent key collisions in the context. +// It provides type safety for context values and ensures that context keys are unique +// across different packages. +type ContextKey string + +// TraceIDKey is the context key used to store and retrieve trace IDs. +// This constant provides a consistent key for tracking request traces +// throughout the request/response lifecycle. +const ( + SessionIDKey ContextKey = "session-id" + TraceIDKey ContextKey = "trace-id" + GenerationIDKey ContextKey = "generation-id" +) + +// The plugin provides request/response tracing functionality by integrating with Maxim's logging system. +// It supports both chat completion and text completion requests, tracking the entire lifecycle of each request +// including inputs, parameters, and responses. +// +// Key Features: +// - Automatic trace and generation ID management +// - Support for both chat and text completion requests +// - Contextual tracking across request lifecycle +// - Graceful handling of existing trace/generation IDs +// +// The plugin uses context values to maintain trace and generation IDs throughout the request lifecycle. +// These IDs can be propagated from external systems through HTTP headers (x-bf-maxim-trace-id and x-bf-maxim-generation-id). + +// Plugin implements the schemas.Plugin interface for Maxim's logger. +// It provides request and response tracing functionality using the Maxim logger, +// allowing detailed tracking of requests and responses. +// +// Fields: +// - logger: A Maxim logger instance used for tracing requests and responses +type Plugin struct { + logger *logging.Logger +} + +// GetName returns the name of the plugin. +func (plugin *Plugin) GetName() string { + return PluginName +} + +// PreHook is called before a request is processed by Bifrost. +// It manages trace and generation tracking for incoming requests by either: +// - Creating a new trace if none exists +// - Reusing an existing trace ID from the context +// - Creating a new generation within an existing trace +// - Skipping trace/generation creation if they already exist +// +// The function handles both chat completion and text completion requests, +// capturing relevant metadata such as: +// - Request type (chat/text completion) +// - Model information +// - Message content and role +// - Model parameters +// +// Parameters: +// - ctx: Pointer to the context.Context that may contain existing trace/generation IDs +// - req: The incoming Bifrost request to be traced +// +// Returns: +// - *schemas.BifrostRequest: The original request, unmodified +// - error: Any error that occurred during trace/generation creation +func (plugin *Plugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + var traceID string + var sessionID string + + // Check if context already has traceID and generationID + if ctx != nil { + if existingGenerationID, ok := (*ctx).Value(GenerationIDKey).(string); ok && existingGenerationID != "" { + // If generationID exists, return early + return req, nil, nil + } + + if existingTraceID, ok := (*ctx).Value(TraceIDKey).(string); ok && existingTraceID != "" { + // If traceID exists, and no generationID, create a new generation on the trace + traceID = existingTraceID + } + + if existingSessionID, ok := (*ctx).Value(SessionIDKey).(string); ok && existingSessionID != "" { + sessionID = existingSessionID + } + } + + // Determine request type and set appropriate tags + var requestType string + var tags map[string]string + var messages []logging.CompletionRequest + var latestMessage string + + if req.Input.ChatCompletionInput != nil { + requestType = "chat_completion" + tags = map[string]string{ + "action": "chat_completion", + "model": req.Model, + } + for _, message := range *req.Input.ChatCompletionInput { + messages = append(messages, logging.CompletionRequest{ + Role: string(message.Role), + Content: message.Content, + }) + } + if len(*req.Input.ChatCompletionInput) > 0 { + lastMsg := (*req.Input.ChatCompletionInput)[len(*req.Input.ChatCompletionInput)-1] + if lastMsg.Content.ContentStr != nil { + latestMessage = *lastMsg.Content.ContentStr + } else if lastMsg.Content.ContentBlocks != nil { + // Find the last text content block + for i := len(*lastMsg.Content.ContentBlocks) - 1; i >= 0; i-- { + block := (*lastMsg.Content.ContentBlocks)[i] + if block.Type == "text" && block.Text != nil { + latestMessage = *block.Text + break + } + } + // If no text block found, use placeholder + if latestMessage == "" { + latestMessage = "-" + } + } + } + } else if req.Input.TextCompletionInput != nil { + requestType = "text_completion" + tags = map[string]string{ + "action": "text_completion", + "model": req.Model, + } + messages = append(messages, logging.CompletionRequest{ + Role: string(schemas.ModelChatMessageRoleUser), + Content: req.Input.TextCompletionInput, + }) + latestMessage = *req.Input.TextCompletionInput + } + + if traceID == "" { + // If traceID is not set, create a new trace + traceID = uuid.New().String() + + traceConfig := logging.TraceConfig{ + Id: traceID, + Name: maxim.StrPtr(fmt.Sprintf("bifrost_%s", requestType)), + Tags: &tags, + } + + if sessionID != "" { + traceConfig.SessionId = &sessionID + } + + trace := plugin.logger.Trace(&traceConfig) + + trace.SetInput(latestMessage) + } + + // Convert ModelParameters to map[string]interface{} + modelParams := make(map[string]interface{}) + if req.Params != nil { + // Convert the struct to a map using reflection or JSON marshaling + jsonData, err := json.Marshal(req.Params) + if err == nil { + json.Unmarshal(jsonData, &modelParams) + } + } + + generationID := uuid.New().String() + + plugin.logger.AddGenerationToTrace(traceID, &logging.GenerationConfig{ + Id: generationID, + Model: req.Model, + Provider: string(req.Provider), + Tags: &tags, + Messages: messages, + ModelParameters: modelParams, + }) + + if ctx != nil { + if _, ok := (*ctx).Value(TraceIDKey).(string); !ok { + *ctx = context.WithValue(*ctx, TraceIDKey, traceID) + } + *ctx = context.WithValue(*ctx, GenerationIDKey, generationID) + } + + return req, nil, nil +} + +// PostHook is called after a request has been processed by Bifrost. +// It completes the request trace by: +// - Adding response data to the generation if a generation ID exists +// - Logging error details if bifrostErr is provided +// - Ending the generation if it exists +// - Ending the trace if a trace ID exists +// - Flushing all pending log data +// +// The function gracefully handles cases where trace or generation IDs may be missing, +// ensuring that partial logging is still performed when possible. +// +// Parameters: +// - ctxRef: Pointer to the context.Context containing trace/generation IDs +// - res: The Bifrost response to be traced +// - bifrostErr: The BifrostError returned by the request, if any +// +// Returns: +// - *schemas.BifrostResponse: The original response, unmodified +// - *schemas.BifrostError: The original error, unmodified +// - error: Never returns an error as it handles missing IDs gracefully +func (plugin *Plugin) PostHook(ctxRef *context.Context, res *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + if ctxRef != nil { + ctx := *ctxRef + + generationID, ok := ctx.Value(GenerationIDKey).(string) + if ok { + if bifrostErr != nil { + genErr := logging.GenerationError{ + Message: bifrostErr.Error.Message, + Code: bifrostErr.Error.Code, + Type: bifrostErr.Error.Type, + } + plugin.logger.SetGenerationError(generationID, &genErr) + } else if res != nil { + plugin.logger.AddResultToGeneration(generationID, res) + } + + plugin.logger.EndGeneration(generationID) + } + + traceID, ok := ctx.Value(TraceIDKey).(string) + if ok { + plugin.logger.EndTrace(traceID) + } + } + plugin.logger.Flush() + + return res, bifrostErr, nil +} + +func (plugin *Plugin) Cleanup() error { + plugin.logger.Flush() + + return nil +} diff --git a/plugins/maxim/plugin_test.go b/plugins/maxim/plugin_test.go new file mode 100644 index 0000000000..459d3def06 --- /dev/null +++ b/plugins/maxim/plugin_test.go @@ -0,0 +1,128 @@ +// Package maxim provides integration for Maxim's SDK as a Bifrost plugin. +// It includes tests for plugin initialization, Bifrost integration, and request/response tracing. +package maxim + +import ( + "context" + "fmt" + "log" + "os" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// getPlugin initializes and returns a Plugin instance for testing purposes. +// It sets up the Maxim logger with configuration from environment variables. +// +// Environment Variables: +// - MAXIM_API_KEY: API key for Maxim SDK authentication +// - MAXIM_LOGGER_ID: ID for the Maxim logger instance +// +// Returns: +// - schemas.Plugin: A configured plugin instance for request/response tracing +// - error: Any error that occurred during plugin initialization +func getPlugin() (schemas.Plugin, error) { + // check if Maxim Logger variables are set + if os.Getenv("MAXIM_API_KEY") == "" { + return nil, fmt.Errorf("MAXIM_API_KEY is not set, please set it in your environment variables") + } + + if os.Getenv("MAXIM_LOGGER_ID") == "" { + return nil, fmt.Errorf("MAXIM_LOGGER_ID is not set, please set it in your environment variables") + } + + plugin, err := NewMaximLoggerPlugin(os.Getenv("MAXIM_API_KEY"), os.Getenv("MAXIM_LOGGER_ID")) + if err != nil { + return nil, err + } + + return plugin, nil +} + +// BaseAccount implements the schemas.Account interface for testing purposes. +// It provides mock implementations of the required methods to test the Maxim plugin +// with a basic OpenAI configuration. +type BaseAccount struct{} + +// GetConfiguredProviders returns a list of supported providers for testing. +// Currently only supports OpenAI for simplicity in testing. You are free to add more providers as needed. +func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{schemas.OpenAI}, nil +} + +// GetKeysForProvider returns a mock API key configuration for testing. +// Uses the OPENAI_API_KEY environment variable for authentication. +func (baseAccount *BaseAccount) GetKeysForProvider(providerKey schemas.ModelProvider) ([]schemas.Key, error) { + return []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{"gpt-4o-mini", "gpt-4-turbo"}, + Weight: 1.0, + }, + }, nil +} + +// GetConfigForProvider returns default provider configuration for testing. +// Uses standard network and concurrency settings. +func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil +} + +// TestMaximLoggerPlugin tests the integration of the Maxim Logger plugin with Bifrost. +// It performs the following steps: +// 1. Initializes the Maxim plugin with environment variables +// 2. Sets up a test Bifrost instance with the plugin +// 3. Makes a test chat completion request +// +// Required environment variables: +// - MAXIM_API_KEY: Your Maxim API key +// - MAXIM_LOGGER_ID: Your Maxim logger repository ID +// - OPENAI_API_KEY: Your OpenAI API key for the test request +func TestMaximLoggerPlugin(t *testing.T) { + // Initialize the Maxim plugin + plugin, err := getPlugin() + if err != nil { + log.Fatalf("Error setting up the plugin: %v", err) + } + + account := BaseAccount{} + + // Initialize Bifrost with the plugin + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), + }) + if err != nil { + log.Fatalf("Error initializing Bifrost: %v", err) + } + + // Make a test chat completion request + _, bifrostErr := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: "user", + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, how are you?"), + }, + }, + }, + }, + }) + + if bifrostErr != nil { + log.Printf("Error in Bifrost request: %v", bifrostErr) + } + + log.Println("Bifrost request completed, check your Maxim Dashboard for the trace") + + client.Cleanup() +} diff --git a/plugins/mocker/README.md b/plugins/mocker/README.md new file mode 100644 index 0000000000..7b8d0cce22 --- /dev/null +++ b/plugins/mocker/README.md @@ -0,0 +1,1301 @@ +# Bifrost Mocker Plugin + +The Mocker plugin for Bifrost allows you to intercept and mock AI provider responses for testing, development, and simulation purposes. It provides flexible rule-based mocking with support for custom responses, error simulation, latency injection, and comprehensive statistics tracking. + +**⚑ Performance Optimized** - Designed for high-throughput scenarios including benchmarking with minimal overhead. + +## Table of Contents + +1. [Quick Start](#quick-start) +2. [Installation](#installation) +3. [Basic Usage](#basic-usage) +4. [Configuration Reference](#configuration-reference) +5. [Advanced Features](#advanced-features) +6. [Faker Support](#faker-support) +7. [Examples](#examples) +8. [Statistics and Monitoring](#statistics-and-monitoring) +9. [Performance](#performance) +10. [Best Practices](#best-practices) +11. [Troubleshooting](#troubleshooting) + +## Quick Start + +### Minimal Configuration + +The simplest way to use the Mocker plugin is with no configuration - it will create a default catch-all rule: + +```go +package main + +import ( + "context" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + mocker "github.com/maximhq/bifrost/plugins/mocker" +) + +func main() { + // Create plugin with minimal config + plugin, err := mocker.NewMockerPlugin(mocker.MockerConfig{ + Enabled: true, // That's it! Default rule will be created automatically + }) + if err != nil { + panic(err) + } + + // Initialize Bifrost with the plugin + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &yourAccount, + Plugins: []schemas.Plugin{plugin}, + }) + if err != nil { + panic(err) + } + defer client.Cleanup() + + // All requests will now return: "This is a mock response from the Mocker plugin" + response, _ := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello!"), + }, + }, + }, + }, + }) + + // response.Choices[0].Message.Content.ContentStr == "This is a mock response from the Mocker plugin" +} +``` + +### Quick Custom Response + +```go +plugin, err := mocker.NewMockerPlugin(mocker.MockerConfig{ + Enabled: true, + Rules: []mocker.MockRule{ + { + Name: "openai-mock", + Enabled: true, + Probability: 1.0, // Always trigger + Conditions: mocker.Conditions{ + Providers: []string{"openai"}, + }, + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Content: &mocker.SuccessResponse{ + Message: "Hello! This is a custom mock response for OpenAI.", + Usage: &mocker.Usage{ + PromptTokens: 15, + CompletionTokens: 25, + TotalTokens: 40, + }, + }, + }, + }, + }, + }, +}) +``` + +## Installation + +### As a Go Module + +1. Add the plugin to your project: + + ```bash + go get github.com/maximhq/bifrost/plugins/mocker + ``` + +2. Import in your code: + + ```go + import mocker "github.com/maximhq/bifrost/plugins/mocker" + ``` + +### Development Setup + +1. Clone the repository: + + ```bash + git clone https://github.com/maximhq/bifrost.git + cd bifrost/plugins/mocker + ``` + +2. Install dependencies: + + ```bash + go mod tidy + ``` + +3. Run tests: + + ```bash + go test -v + ``` + +## Basic Usage + +### Creating the Plugin + +```go +// Basic configuration +config := mocker.MockerConfig{ + Enabled: true, + DefaultBehavior: mocker.DefaultBehaviorPassthrough, // Optional: "passthrough", "success", "error" + Rules: []mocker.MockRule{ + // Your rules here + }, +} + +plugin, err := mocker.NewMockerPlugin(config) +if err != nil { + // Handle configuration errors + log.Fatal(err) +} +``` + +### Adding to Bifrost + +```go +client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &yourAccount, + Plugins: []schemas.Plugin{ + plugin, // Add your mocker plugin + // Other plugins... + }, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo), +}) +``` + +### Disabling the Plugin + +```go +// Temporarily disable without removing +config := mocker.MockerConfig{ + Enabled: false, // All requests pass through to real providers +} +``` + +## Configuration Reference + +### MockerConfig + +| Field | Type | Required | Default | Description | +| ----------------- | ------------ | -------- | --------------- | ------------------------------------------------------------------- | +| `Enabled` | `bool` | Yes | `false` | Enable/disable the entire plugin | +| `DefaultBehavior` | `string` | No | `"passthrough"` | Action when no rules match: `"passthrough"`, `"success"`, `"error"` | +| `GlobalLatency` | `*Latency` | No | `nil` | Global latency applied to all rules (can be overridden per rule) | +| `Rules` | `[]MockRule` | No | `[]` | List of mock rules evaluated in priority order | + +### MockRule + +| Field | Type | Required | Default | Description | +| ------------- | ------------ | -------- | ------- | -------------------------------------------------- | +| `Name` | `string` | Yes | - | Unique rule name for identification and statistics | +| `Enabled` | `bool` | No | `true` | Enable/disable this specific rule | +| `Priority` | `int` | No | `0` | Higher numbers = higher priority (checked first) | +| `Probability` | `float64` | No | `1.0` | Activation probability (0.0=never, 1.0=always) | +| `Conditions` | `Conditions` | No | `{}` | Matching conditions (empty = match all) | +| `Responses` | `[]Response` | Yes | - | Possible responses (weighted random selection) | +| `Latency` | `*Latency` | No | `nil` | Rule-specific latency override | + +### Conditions + +| Field | Type | Required | Default | Description | +| -------------- | ------------ | -------- | ------- | --------------------------------------------------- | +| `Providers` | `[]string` | No | `[]` | Match specific providers: `["openai", "anthropic"]` | +| `Models` | `[]string` | No | `[]` | Match specific models: `["gpt-4", "claude-3"]` | +| `MessageRegex` | `*string` | No | `nil` | Regex pattern to match message content | +| `RequestSize` | `*SizeRange` | No | `nil` | Request size constraints in bytes | + +### Response + +| Field | Type | Required | Default | Description | +| ---------------- | ------------------ | ----------- | ------- | ------------------------------------------------------ | +| `Type` | `string` | Yes | - | Response type: `"success"` or `"error"` | +| `Weight` | `float64` | No | `1.0` | Weight for random selection (higher = more likely) | +| `Content` | `*SuccessResponse` | Conditional | - | Required if `Type="success"` | +| `Error` | `*ErrorResponse` | Conditional | - | Required if `Type="error"` | +| `AllowFallbacks` | `*bool` | No | `nil` | Control fallback behavior (`nil`=allow, `false`=block) | + +### SuccessResponse + +| Field | Type | Required | Default | Description | +| ----------------- | ------------------------ | ----------- | -------------- | ------------------------------------------------------------------- | +| `Message` | `string` | Conditional | - | Static response message (required if no template) | +| `MessageTemplate` | `*string` | Conditional | - | Template with variables: `{{provider}}`, `{{model}}`, `{{faker.*}}` | +| `Model` | `*string` | No | `nil` | Override model name in response | +| `Usage` | `*Usage` | No | Default values | Token usage information | +| `FinishReason` | `*string` | No | `"stop"` | Completion reason | +| `CustomFields` | `map[string]interface{}` | No | `{}` | Additional metadata fields | + +### ErrorResponse + +| Field | Type | Required | Default | Description | +| ------------ | --------- | -------- | ------- | ------------------------------------------------- | +| `Message` | `string` | Yes | - | Error message to return | +| `Type` | `*string` | No | `nil` | Error type (e.g., `"rate_limit"`, `"auth_error"`) | +| `Code` | `*string` | No | `nil` | Error code (e.g., `"429"`, `"401"`) | +| `StatusCode` | `*int` | No | `nil` | HTTP status code | + +### Latency + +| Field | Type | Required | Default | Description | +| ------ | --------------- | ----------- | ------- | ------------------------------------------------------------------ | +| `Type` | `string` | Yes | - | Latency type: `"fixed"` or `"uniform"` | +| `Min` | `time.Duration` | Yes | - | Minimum/exact latency (use `time.Millisecond`, NOT raw int) | +| `Max` | `time.Duration` | Conditional | - | Maximum latency (required for `"uniform"`, use `time.Millisecond`) | + +**⚠️ Important**: Use Go's `time.Duration` constants, not raw integers: + +- βœ… Correct: `100 * time.Millisecond` +- ❌ Wrong: `100` (this would be 100 nanoseconds, barely noticeable) + +### SizeRange + +| Field | Type | Required | Default | Description | +| ----- | ----- | -------- | ------- | ----------------------------- | +| `Min` | `int` | Yes | - | Minimum request size in bytes | +| `Max` | `int` | Yes | - | Maximum request size in bytes | + +### Usage + +| Field | Type | Required | Default | Description | +| ------------------ | ----- | -------- | ------- | ---------------------------------- | +| `PromptTokens` | `int` | No | `10` | Number of tokens in the prompt | +| `CompletionTokens` | `int` | No | `20` | Number of tokens in the completion | +| `TotalTokens` | `int` | No | `30` | Total tokens (prompt + completion) | + +## Advanced Features + +### Template Variables + +Use templates to create dynamic responses: + +```go +Response{ + Type: mocker.ResponseTypeSuccess, + Content: &mocker.SuccessResponse{ + MessageTemplate: stringPtr("Hello from {{provider}} using model {{model}}!"), + }, +} +``` + +**Available Variables:** + +- `{{provider}}` - Provider name (e.g., "openai", "anthropic") +- `{{model}}` - Model name (e.g., "gpt-4", "claude-3") +- `{{faker.*}}` - Fake data generation (see [Faker Support](#faker-support) section for full list) + +### Weighted Response Selection + +Configure multiple responses with different weights: + +```go +Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Weight: 0.8, // 80% chance + Content: &mocker.SuccessResponse{ + Message: "Success response", + }, + }, + { + Type: mocker.ResponseTypeError, + Weight: 0.2, // 20% chance + Error: &mocker.ErrorResponse{ + Message: "Simulated error", + StatusCode: intPtr(500), + }, + }, +} +``` + +### Latency Simulation + +Add realistic delays to responses: + +```go +// Fixed latency +Latency: &mocker.Latency{ + Type: mocker.LatencyTypeFixed, + Min: 100 * time.Millisecond, +} + +// Variable latency +Latency: &mocker.Latency{ + Type: mocker.LatencyTypeUniform, + Min: 50 * time.Millisecond, + Max: 200 * time.Millisecond, +} +``` + +**⚠️ Critical**: Always use `time.Duration` constants (e.g., `time.Millisecond`), never raw integers. Raw integers are interpreted as nanoseconds and will be barely noticeable. + +### Regex Message Matching + +Match specific message content: + +```go +Conditions: mocker.Conditions{ + MessageRegex: stringPtr(`(?i).*error.*|.*help.*`), // Case-insensitive match for "error" or "help" +} +``` + +### Request Size Filtering + +Match requests by size: + +```go +Conditions: mocker.Conditions{ + RequestSize: &mocker.SizeRange{ + Min: 100, // Minimum 100 bytes + Max: 1000, // Maximum 1000 bytes + }, +} +``` + +## Faker Support + +The Mocker plugin includes comprehensive fake data generation capabilities using the [jaswdr/faker](https://github.com/jaswdr/faker) library. This allows you to create realistic mock responses with dynamic, fake data that changes on each request. + +### Using Faker in Templates + +Faker variables can be used in the `MessageTemplate` field using the `{{faker.method}}` syntax: + +```go +Response{ + Type: mocker.ResponseTypeSuccess, + Content: &mocker.SuccessResponse{ + MessageTemplate: stringPtr(`Hello {{faker.first_name}}! Your email {{faker.email}} has been verified. +Your account ID is {{faker.uuid}} and your phone number {{faker.phone}} is on file.`), + }, +} +``` + +### Available Faker Methods + +#### Personal Information + +- `{{faker.name}}` - Full name (e.g., "John Smith") +- `{{faker.first_name}}` - First name (e.g., "John") +- `{{faker.last_name}}` - Last name (e.g., "Smith") +- `{{faker.email}}` - Email address (e.g., "john123@example.com") +- `{{faker.phone}}` - Phone number (e.g., "+1-555-123-4567") + +#### Location + +- `{{faker.address}}` - Full address (e.g., "123 Main St") +- `{{faker.city}}` - City name (e.g., "New York") +- `{{faker.state}}` - State name (e.g., "California") +- `{{faker.zip_code}}` - Postal code (e.g., "12345") + +#### Business + +- `{{faker.company}}` - Company name (e.g., "Tech Solutions Inc.") +- `{{faker.job_title}}` - Job title (e.g., "Software Engineer") + +#### Text and Lorem Ipsum + +- `{{faker.word}}` - Single word (e.g., "example") +- `{{faker.sentence}}` - Sentence with default 8 words +- `{{faker.sentence:5}}` - Sentence with 5 words +- `{{faker.lorem_ipsum}}` - Lorem ipsum text with default 10 words +- `{{faker.lorem_ipsum:20}}` - Lorem ipsum text with 20 words + +#### Identifiers and Data + +- `{{faker.uuid}}` - UUID v4 (e.g., "123e4567-e89b-12d3-a456-426614174000") +- `{{faker.hex_color}}` - Hex color code (e.g., "#FF5733") + +#### Numbers and Dates + +- `{{faker.integer}}` - Random integer between 1-100 +- `{{faker.integer:10,50}}` - Random integer between 10-50 +- `{{faker.float}}` - Random float between 0-100 (2 decimal places) +- `{{faker.float:1,10}}` - Random float between 1-10 +- `{{faker.boolean}}` - Random boolean (true/false) +- `{{faker.date}}` - Date in YYYY-MM-DD format +- `{{faker.datetime}}` - Datetime in YYYY-MM-DD HH:MM:SS format + +### Faker Examples + +#### Customer Support Simulation + +```go +{ + Name: "customer-support-faker", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: mocker.Conditions{ + MessageRegex: stringPtr(`(?i).*support.*|.*help.*`), + }, + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Content: &mocker.SuccessResponse{ + MessageTemplate: stringPtr(`Hello {{faker.first_name}}! + +Thank you for contacting {{faker.company}} support. I've reviewed your account and here are the details: + +**Account Information:** +- Name: {{faker.name}} +- Email: {{faker.email}} +- Phone: {{faker.phone}} +- Account ID: {{faker.uuid}} +- Address: {{faker.address}}, {{faker.city}}, {{faker.state}} {{faker.zip_code}} + +**Support Ticket:** #{{faker.integer:10000,99999}} +**Priority:** {{faker.boolean}} +**Estimated Resolution:** {{faker.date}} + +How can I help you today?`), + Usage: &mocker.Usage{ + PromptTokens: 25, + CompletionTokens: 150, + TotalTokens: 175, + }, + }, + }, + }, +} +``` + +#### E-commerce Order Confirmation + +```go +{ + Name: "ecommerce-order-faker", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: mocker.Conditions{ + MessageRegex: stringPtr(`(?i).*order.*|.*purchase.*`), + }, + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Content: &mocker.SuccessResponse{ + MessageTemplate: stringPtr(`πŸŽ‰ Order Confirmed! + +**Order Details:** +- Order ID: {{faker.uuid}} +- Customer: {{faker.name}} +- Email: {{faker.email}} +- Total: ${{faker.float:10,500}} +- Items: {{faker.integer:1,5}} items + +**Shipping Address:** +{{faker.address}} +{{faker.city}}, {{faker.state}} {{faker.zip_code}} + +**Estimated Delivery:** {{faker.date}} +**Tracking Number:** {{faker.integer:100000000,999999999}} + +Thank you for shopping with {{faker.company}}!`), + }, + }, + }, +} +``` + +#### User Profile Generation + +```go +{ + Name: "user-profile-faker", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: mocker.Conditions{ + MessageRegex: stringPtr(`(?i).*profile.*|.*user.*info.*`), + }, + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Content: &mocker.SuccessResponse{ + MessageTemplate: stringPtr(`**User Profile Generated:** + +**Personal Information:** +- Full Name: {{faker.name}} +- Email: {{faker.email}} +- Phone: {{faker.phone}} +- Preferred Color: {{faker.hex_color}} + +**Professional Details:** +- Company: {{faker.company}} +- Job Title: {{faker.job_title}} +- Work Phone: {{faker.phone}} + +**Address:** +{{faker.address}} +{{faker.city}}, {{faker.state}} {{faker.zip_code}} + +**Account Settings:** +- User ID: {{faker.uuid}} +- Account Created: {{faker.date}} +- Email Notifications: {{faker.boolean}} +- SMS Alerts: {{faker.boolean}} + +**Bio:** {{faker.lorem_ipsum:25}}`), + }, + }, + }, +} +``` + +### Faker with Weighted Responses + +You can combine faker with weighted response selection for even more realistic scenarios: + +```go +{ + Name: "mixed-faker-responses", + Enabled: true, + Priority: 100, + Probability: 1.0, + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Weight: 0.7, // 70% positive responses + Content: &mocker.SuccessResponse{ + MessageTemplate: stringPtr(`Great news, {{faker.first_name}}! Your request has been approved. +Reference number: {{faker.uuid}}. +Contact us at {{faker.phone}} if you have questions.`), + }, + }, + { + Type: mocker.ResponseTypeSuccess, + Weight: 0.2, // 20% neutral responses + Content: &mocker.SuccessResponse{ + MessageTemplate: stringPtr(`Hello {{faker.first_name}}, your request is being processed. +Ticket ID: {{faker.integer:1000,9999}}. +Expected completion: {{faker.date}}.`), + }, + }, + { + Type: mocker.ResponseTypeError, + Weight: 0.1, // 10% error responses + Error: &mocker.ErrorResponse{ + Message: fmt.Sprintf("Account validation failed for user %s. Please contact support.", "{{faker.email}}"), + Type: stringPtr("validation_error"), + Code: stringPtr("VAL_001"), + }, + }, + }, +} +``` + +### Important Notes + +- **Dynamic Generation**: Faker values are generated fresh on each request, ensuring unique responses +- **Performance**: Faker generation is highly optimized and adds minimal overhead +- **Parameters**: Some faker methods support parameters (e.g., `{{faker.sentence:10}}` for 10 words) +- **Reliability**: Uses the established [jaswdr/faker](https://github.com/jaswdr/faker) library with zero dependencies +- **Template Mixing**: You can freely mix faker variables with regular template variables like `{{provider}}` and `{{model}}` + +## Examples + +### Development Environment Mock + +```go +config := mocker.MockerConfig{ + Enabled: true, + DefaultBehavior: mocker.DefaultBehaviorPassthrough, + Rules: []mocker.MockRule{ + { + Name: "dev-openai-mock", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: mocker.Conditions{ + Providers: []string{"openai"}, + Models: []string{"gpt-4", "gpt-4-turbo"}, + }, + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Content: &mocker.SuccessResponse{ + MessageTemplate: stringPtr("Development mock response from {{model}} ({{provider}})"), + Usage: &mocker.Usage{ + PromptTokens: 20, + CompletionTokens: 30, + TotalTokens: 50, + }, + }, + }, + }, + }, + }, +} +``` + +### Error Simulation for Testing + +```go +config := mocker.MockerConfig{ + Enabled: true, + Rules: []mocker.MockRule{ + { + Name: "rate-limit-simulation", + Enabled: true, + Priority: 200, + Probability: 0.1, // 10% of requests + Conditions: mocker.Conditions{ + Providers: []string{"openai", "anthropic"}, + }, + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeError, + AllowFallbacks: boolPtr(true), // Allow fallback providers + Error: &mocker.ErrorResponse{ + Message: "Rate limit exceeded. Please try again later.", + Type: stringPtr("rate_limit"), + Code: stringPtr("429"), + StatusCode: intPtr(429), + }, + }, + }, + Latency: &mocker.Latency{ + Type: mocker.LatencyTypeFixed, + Min: 500 * time.Millisecond, // Simulate slow error response + }, + }, + }, +} +``` + +### A/B Testing Different Responses + +```go +config := mocker.MockerConfig{ + Enabled: true, + Rules: []mocker.MockRule{ + { + Name: "ab-test-responses", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: mocker.Conditions{ + MessageRegex: stringPtr(`(?i).*greeting.*|.*hello.*`), + }, + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Weight: 0.5, // 50% - Version A + Content: &mocker.SuccessResponse{ + Message: "Hello! How can I help you today?", + CustomFields: map[string]interface{}{ + "ab_test_version": "A", + "response_style": "formal", + }, + }, + }, + { + Type: mocker.ResponseTypeSuccess, + Weight: 0.5, // 50% - Version B + Content: &mocker.SuccessResponse{ + Message: "Hey there! What's up?", + CustomFields: map[string]interface{}{ + "ab_test_version": "B", + "response_style": "casual", + }, + }, + }, + }, + }, + }, +} +``` + +### Provider-Specific Behavior + +```go +config := mocker.MockerConfig{ + Enabled: true, + Rules: []mocker.MockRule{ + { + Name: "openai-success", + Enabled: true, + Priority: 100, + Probability: 0.9, // 90% success rate + Conditions: mocker.Conditions{ + Providers: []string{"openai"}, + }, + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Content: &mocker.SuccessResponse{ + Message: "OpenAI mock response - high reliability", + }, + }, + }, + }, + { + Name: "anthropic-mixed", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: mocker.Conditions{ + Providers: []string{"anthropic"}, + }, + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Weight: 0.7, // 70% success + Content: &mocker.SuccessResponse{ + Message: "Anthropic mock response", + }, + }, + { + Type: mocker.ResponseTypeError, + Weight: 0.3, // 30% error + AllowFallbacks: boolPtr(true), + Error: &mocker.ErrorResponse{ + Message: "Anthropic service temporarily unavailable", + StatusCode: intPtr(503), + }, + }, + }, + }, + }, +} +``` + +## Statistics and Monitoring + +### Getting Statistics + +```go +// Get current statistics +stats := plugin.GetStats() +fmt.Printf("Total Requests: %d\n", stats.TotalRequests) +fmt.Printf("Mocked Requests: %d\n", stats.MockedRequests) +fmt.Printf("Success Responses: %d\n", stats.ResponsesGenerated) +fmt.Printf("Error Responses: %d\n", stats.ErrorsGenerated) + +// Per-rule statistics +for ruleName, hits := range stats.RuleHits { + fmt.Printf("Rule '%s': %d hits\n", ruleName, hits) +} +``` + +### Statistics Structure + +```go +type MockStats struct { + TotalRequests int64 `json:"total_requests"` // Total requests processed + MockedRequests int64 `json:"mocked_requests"` // Requests that matched rules + RuleHits map[string]int64 `json:"rule_hits"` // Per-rule hit counts + ErrorsGenerated int64 `json:"errors_generated"` // Error responses generated + ResponsesGenerated int64 `json:"responses_generated"` // Success responses generated +} +``` + +### Monitoring Example + +```go +// Periodic monitoring +go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + stats := plugin.GetStats() + log.Printf("Mocker Stats - Total: %d, Mocked: %d, Success: %d, Errors: %d", + stats.TotalRequests, stats.MockedRequests, + stats.ResponsesGenerated, stats.ErrorsGenerated) + } + } +}() +``` + +## Performance + +The Mocker plugin has been extensively optimized for high-throughput scenarios, including benchmarking and load testing. Here are the key performance characteristics and optimizations: + +### πŸš€ **Key Optimizations** + +#### 1. **Pre-compiled Regex Patterns** + +- All regex patterns are compiled once during plugin initialization +- **Before**: `regexp.Compile()` on every request (~1000x slower) +- **After**: Pre-compiled patterns with direct matching + +#### 2. **Atomic Counters for Statistics** + +- Statistics use `sync/atomic` operations instead of mutex locks +- **Before**: Mutex lock/unlock for every counter increment +- **After**: Lock-free atomic operations + +#### 3. **Optimized String Operations** + +- Fast-path string matching and content extraction +- Efficient template processing with `strings.NewReplacer` +- Minimal memory allocations in hot paths + +#### 4. **Smart Memory Management** + +- Pre-allocated data structures +- Reduced map allocations +- Efficient string building for multi-message content + +### πŸ§ͺ **Running Benchmarks** + +Since performance varies significantly across different hardware configurations, you should run benchmarks on your specific system to get accurate measurements: + +```bash +# Run all benchmarks with memory allocation stats +go test -bench=. -benchmem + +# Run specific benchmark scenarios +go test -bench=BenchmarkMockerPlugin_PreHook_SimpleRule -benchmem +go test -bench=BenchmarkMockerPlugin_PreHook_RegexRule -benchmem +go test -bench=BenchmarkMockerPlugin_PreHook_NoMatch -benchmem + +# Run with CPU profiling for detailed analysis +go test -bench=. -cpuprofile=cpu.prof + +# Run with memory profiling +go test -bench=. -memprofile=mem.prof + +# Run benchmarks multiple times for statistical accuracy +go test -bench=. -count=5 +``` + +### πŸ“Š **Understanding Benchmark Output** + +The benchmark output will show: + +- **Operations per second**: How many operations can be performed per second +- **Nanoseconds per operation**: Average time per operation +- **Bytes per operation**: Memory allocated per operation +- **Allocations per operation**: Number of memory allocations per operation + +Example output format: + +```text +BenchmarkMockerPlugin_PreHook_SimpleRule-8 5000000 250 ns/op 400 B/op 5 allocs/op +``` + +### πŸ“ˆ **Performance Reference** + +As a reference, here are results from testing on a system with 16GB RAM: + +```text +BenchmarkMockerPlugin_PreHook_SimpleRule 6,306,205 ops 189.6 ns/op 416 B/op 5 allocs/op +BenchmarkMockerPlugin_PreHook_RegexRule 712,371 ops 1637 ns/op 420 B/op 5 allocs/op +BenchmarkMockerPlugin_PreHook_MultipleRules 5,604,916 ops 214.1 ns/op 416 B/op 5 allocs/op +BenchmarkMockerPlugin_PreHook_NoMatch 155,663,086 ops 7.7 ns/op 0 B/op 0 allocs/op +BenchmarkMockerPlugin_PreHook_Template 864,408 ops 1351 ns/op 1688 B/op 19 allocs/op +``` + +**Note**: Your results may vary based on your hardware configuration. Run the benchmarks yourself for accurate measurements on your system. + +### 🎯 **Performance Characteristics** + +Based on the optimizations implemented, you can expect: + +#### **Ultra-Fast No-Match Path** + +- Minimal overhead when no rules match +- Perfect for production with selective mocking +- Zero allocations when plugin is disabled + +#### **High-Speed Simple Rules** + +- Fast provider/model string matching +- Suitable for high-frequency benchmarking +- Minimal memory allocations + +#### **Efficient Regex Matching** + +- Pre-compiled patterns (much faster than runtime compilation) +- Good performance for pattern-based mocking +- Scales well with multiple regex rules + +#### **Multiple Rule Evaluation** + +- Priority-based early termination +- Performance doesn't degrade significantly with rule count +- Optimized rule traversal + +### ⚑ **Configuration for Maximum Performance** + +#### For **Benchmarking** (Maximum Speed): + +```go +config := mocker.MockerConfig{ + Enabled: true, + Rules: []mocker.MockRule{ + { + Name: "benchmark-rule", + Enabled: true, + Priority: 100, + Probability: 1.0, // Always match for consistent results + Conditions: mocker.Conditions{ + Providers: []string{"openai"}, // Simple string match (fastest) + }, + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Content: &mocker.SuccessResponse{ + Message: "Benchmark response", // Static message (no templates) + Usage: &mocker.Usage{ // Pre-defined usage + PromptTokens: 10, + CompletionTokens: 20, + TotalTokens: 30, + }, + }, + }, + }, + }, + }, +} +``` + +#### For **Production** (Minimal Overhead): + +```go +config := mocker.MockerConfig{ + Enabled: true, + DefaultBehavior: mocker.DefaultBehaviorPassthrough, // Fast no-match path + Rules: []mocker.MockRule{ + // Only critical error simulation rules + { + Name: "rate-limit-sim", + Enabled: true, + Probability: 0.01, // 1% activation rate + Conditions: mocker.Conditions{ + Providers: []string{"openai"}, // Simple conditions only + }, + // ... error response + }, + }, +} +``` + +### πŸ”§ **Performance Tuning Tips** + +#### 1. **Rule Optimization** + +```go +// βœ… FAST: Simple string matching +Conditions: mocker.Conditions{ + Providers: []string{"openai", "anthropic"}, + Models: []string{"gpt-4", "claude-3"}, +} + +// ⚠️ SLOWER: Regex patterns (but still fast with pre-compilation) +Conditions: mocker.Conditions{ + MessageRegex: stringPtr(`(?i)error|fail`), +} + +// ❌ AVOID: Complex regex patterns +Conditions: mocker.Conditions{ + MessageRegex: stringPtr(`^.*complex.*nested.*(pattern|match).*$`), +} +``` + +#### 2. **Response Optimization** + +```go +// βœ… FAST: Static responses +Content: &mocker.SuccessResponse{ + Message: "Static response", + Usage: &predefinedUsage, // Reuse objects +} + +// ⚠️ MODERATE: Simple templates +Content: &mocker.SuccessResponse{ + MessageTemplate: stringPtr("Response from {{provider}}"), // Minimal variables +} + +// ❌ AVOID: Complex templates with many variables +Content: &mocker.SuccessResponse{ + MessageTemplate: stringPtr("Complex {{var1}} {{var2}} {{var3}} template"), +} +``` + +#### 3. **Rule Priority** + +```go +// βœ… Put most common rules first (higher priority) +Rules: []mocker.MockRule{ + {Priority: 100, /* most common conditions */}, + {Priority: 50, /* less common conditions */}, + {Priority: 10, /* rare conditions */}, +} +``` + +### πŸ“ˆ **Monitoring Performance** + +Track performance metrics in your application: + +```go +func monitorMockerPerformance(plugin *mocker.MockerPlugin) { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + lastStats := plugin.GetStats() + lastTime := time.Now() + + for { + select { + case <-ticker.C: + currentStats := plugin.GetStats() + currentTime := time.Now() + + duration := currentTime.Sub(lastTime) + requests := currentStats.TotalRequests - lastStats.TotalRequests + + rps := float64(requests) / duration.Seconds() + mockRate := float64(currentStats.MockedRequests) / float64(currentStats.TotalRequests) * 100 + + log.Printf("Mocker Performance: %.0f req/s, %.1f%% mock rate", rps, mockRate) + + lastStats = currentStats + lastTime = currentTime + } + } +} +``` + +**πŸ† The Mocker plugin is optimized for high-throughput scenarios and adds minimal overhead to your application.** + +## Best Practices + +### 1. Rule Organization + +- **Use descriptive names**: `"rate-limit-openai"` instead of `"rule1"` +- **Set appropriate priorities**: Critical rules should have higher priority +- **Group related rules**: Keep similar functionality together + +### 2. Development vs Production + +```go +// Development - High mock rate +config := mocker.MockerConfig{ + Enabled: true, + DefaultBehavior: mocker.DefaultBehaviorSuccess, // Mock everything by default +} + +// Production - Selective mocking +config := mocker.MockerConfig{ + Enabled: true, + DefaultBehavior: mocker.DefaultBehaviorPassthrough, // Pass through by default + Rules: []mocker.MockRule{ + // Only specific error simulation rules + }, +} +``` + +### 3. Error Handling + +- **Use appropriate fallback settings**: Allow fallbacks for temporary errors +- **Provide meaningful error messages**: Help with debugging +- **Set realistic status codes**: Match actual provider behavior + +### 4. Performance Considerations + +- **Limit regex complexity**: Simple patterns perform better +- **Use probability wisely**: Don't mock 100% in production +- **Monitor statistics**: Watch for unexpected behavior + +### 5. Testing + +```go +func TestYourAppWithMocking(t *testing.T) { + // Create predictable mock responses for testing + plugin, _ := mocker.NewMockerPlugin(mocker.MockerConfig{ + Enabled: true, + Rules: []mocker.MockRule{ + { + Name: "test-success", + Enabled: true, + Probability: 1.0, // Always trigger for consistent tests + Responses: []mocker.Response{ + { + Type: mocker.ResponseTypeSuccess, + Content: &mocker.SuccessResponse{ + Message: "Predictable test response", + }, + }, + }, + }, + }, + }) + + // Use plugin in your tests... +} +``` + +## Troubleshooting + +### Common Issues + +#### 1. Plugin Not Triggering + +**Problem**: Requests pass through to real providers instead of being mocked. + +**Solutions**: + +- Check `Enabled: true` in config +- Verify rule conditions match your requests +- Check rule `Probability` (should be > 0) +- Ensure rule is `Enabled: true` + +#### 2. Configuration Validation Errors + +**Problem**: `NewMockerPlugin()` returns validation errors. + +**Common Issues**: + +```go +// ❌ Missing rule name +{ + Name: "", // Error: rule name required +} + +// ❌ Invalid probability +{ + Probability: 1.5, // Error: must be 0.0-1.0 +} + +// ❌ Invalid response type +{ + Type: "invalid", // Error: must be "success" or "error" +} + +// ❌ Missing response content +{ + Type: mocker.ResponseTypeSuccess, + // Error: Content required for success type +} +``` + +#### 3. Statistics Not Updating + +**Problem**: `GetStats()` shows zero values. + +**Solutions**: + +- Ensure rules are actually matching (check conditions) +- Verify plugin is enabled +- Call `GetStats()` before `Cleanup()` (cleanup clears stats) + +#### 4. Regex Not Matching + +**Problem**: `MessageRegex` conditions not working. + +**Solutions**: + +```go +// ❌ Invalid regex +MessageRegex: stringPtr("[invalid"), // Syntax error + +// βœ… Valid regex patterns +MessageRegex: stringPtr(`(?i)hello`), // Case-insensitive +MessageRegex: stringPtr(`error|fail|problem`), // Multiple options +MessageRegex: stringPtr(`\d+`), // Numbers only +``` + +#### 5. Unexpected Fallback Behavior + +**Problem**: Errors don't trigger fallbacks as expected. + +**Solutions**: + +```go +// Control fallback behavior explicitly +Response{ + Type: mocker.ResponseTypeError, + AllowFallbacks: boolPtr(true), // Explicitly allow fallbacks + // or + AllowFallbacks: boolPtr(false), // Explicitly block fallbacks + // or + AllowFallbacks: nil, // Default behavior (allow) +} +``` + +#### 6. Latency Not Working + +**Problem**: Latency simulation has no effect or causes errors. + +**Common Issue**: Using raw integers instead of `time.Duration` values. + +**Solutions**: + +```go +// ❌ WRONG: Raw integers (these are nanoseconds, barely noticeable) +Latency: &mocker.Latency{ + Type: mocker.LatencyTypeFixed, + Min: 100, // 100 nanoseconds = 0.0001ms + Max: 500, // 500 nanoseconds = 0.0005ms +} + +// βœ… CORRECT: Use time.Duration constants +Latency: &mocker.Latency{ + Type: mocker.LatencyTypeFixed, + Min: 100 * time.Millisecond, // 100ms +} + +// βœ… CORRECT: Various duration examples +Latency: &mocker.Latency{ + Type: mocker.LatencyTypeUniform, + Min: 50 * time.Millisecond, // 50ms + Max: 200 * time.Millisecond, // 200ms +} + +// βœ… CORRECT: Other duration units +Min: 1 * time.Second, // 1 second +Min: 500 * time.Microsecond, // 500 microseconds +Min: 2500 * time.Nanosecond, // 2500 nanoseconds (rarely used) +``` + +### Debug Mode + +Enable debug logging to troubleshoot issues: + +```go +client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), // Enable debug logs +}) +``` + +### Validation Testing + +Test your configuration before deployment: + +```go +func validateMockerConfig(config mocker.MockerConfig) error { + _, err := mocker.NewMockerPlugin(config) + return err +} + +// Test in your code +if err := validateMockerConfig(yourConfig); err != nil { + log.Fatalf("Invalid mocker configuration: %v", err) +} +``` + +--- + +**Need help?** Check the [Bifrost documentation](../../docs/plugins.md) or open an issue on GitHub. + +``` + +``` diff --git a/plugins/mocker/benchmark_test.go b/plugins/mocker/benchmark_test.go new file mode 100644 index 0000000000..8568062eee --- /dev/null +++ b/plugins/mocker/benchmark_test.go @@ -0,0 +1,296 @@ +package mocker + +import ( + "context" + "strconv" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// BenchmarkMockerPlugin_PreHook_SimpleRule benchmarks simple rule matching +func BenchmarkMockerPlugin_PreHook_SimpleRule(b *testing.B) { + plugin, err := NewMockerPlugin(MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "simple-rule", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{ + Providers: []string{"openai"}, + }, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "Benchmark response", + }, + }, + }, + }, + }, + }) + if err != nil { + b.Fatal(err) + } + + req := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, benchmark test"), + }, + }, + }, + }, + } + + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = plugin.PreHook(&ctx, req) + } +} + +// BenchmarkMockerPlugin_PreHook_RegexRule benchmarks regex rule matching +func BenchmarkMockerPlugin_PreHook_RegexRule(b *testing.B) { + plugin, err := NewMockerPlugin(MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "regex-rule", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{ + MessageRegex: bifrost.Ptr(`(?i).*hello.*`), + }, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "Regex matched response", + }, + }, + }, + }, + }, + }) + if err != nil { + b.Fatal(err) + } + + req := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, this should match the regex pattern"), + }, + }, + }, + }, + } + + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = plugin.PreHook(&ctx, req) + } +} + +// BenchmarkMockerPlugin_PreHook_MultipleRules benchmarks multiple rule evaluation +func BenchmarkMockerPlugin_PreHook_MultipleRules(b *testing.B) { + rules := make([]MockRule, 10) + for i := 0; i < 10; i++ { + rules[i] = MockRule{ + Name: "rule-" + strconv.Itoa(i), + Enabled: true, + Priority: 100 - i, // Descending priority + Probability: 1.0, + Conditions: Conditions{ + Models: []string{"gpt-" + strconv.Itoa(i)}, + }, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "Response from rule " + strconv.Itoa(i), + }, + }, + }, + } + } + + // Add a matching rule at the end + rules = append(rules, MockRule{ + Name: "matching-rule", + Enabled: true, + Priority: 50, + Probability: 1.0, + Conditions: Conditions{ + Models: []string{"gpt-4"}, + }, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "Matching rule response", + }, + }, + }, + }) + + plugin, err := NewMockerPlugin(MockerConfig{ + Enabled: true, + Rules: rules, + }) + if err != nil { + b.Fatal(err) + } + + req := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Test message"), + }, + }, + }, + }, + } + + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = plugin.PreHook(&ctx, req) + } +} + +// BenchmarkMockerPlugin_PreHook_NoMatch benchmarks when no rules match +func BenchmarkMockerPlugin_PreHook_NoMatch(b *testing.B) { + plugin, err := NewMockerPlugin(MockerConfig{ + Enabled: true, + DefaultBehavior: DefaultBehaviorPassthrough, + Rules: []MockRule{ + { + Name: "non-matching-rule", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{ + Providers: []string{"anthropic"}, // Won't match OpenAI + }, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "This won't match", + }, + }, + }, + }, + }, + }) + if err != nil { + b.Fatal(err) + } + + req := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, // Different from rule condition + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Test message"), + }, + }, + }, + }, + } + + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = plugin.PreHook(&ctx, req) + } +} + +// BenchmarkMockerPlugin_PreHook_Template benchmarks template processing +func BenchmarkMockerPlugin_PreHook_Template(b *testing.B) { + plugin, err := NewMockerPlugin(MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "template-rule", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{}, // Match all + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + MessageTemplate: bifrost.Ptr("Hello from {{provider}} using model {{model}}!"), + }, + }, + }, + }, + }, + }) + if err != nil { + b.Fatal(err) + } + + req := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Test message"), + }, + }, + }, + }, + } + + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _, _ = plugin.PreHook(&ctx, req) + } +} diff --git a/plugins/mocker/go.mod b/plugins/mocker/go.mod new file mode 100644 index 0000000000..26a56fa141 --- /dev/null +++ b/plugins/mocker/go.mod @@ -0,0 +1,37 @@ +module github.com/maximhq/bifrost/plugins/mocker + +go 1.24.1 + +require ( + github.com/jaswdr/faker/v2 v2.5.0 + github.com/maximhq/bifrost/core v1.1.6 +) + +require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect + github.com/andybalholm/brotli v1.1.1 // indirect + github.com/aws/aws-sdk-go-v2 v1.36.3 // indirect + github.com/aws/aws-sdk-go-v2/config v1.29.14 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect + github.com/aws/smithy-go v1.22.3 // indirect + github.com/goccy/go-json v0.10.5 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/mark3labs/mcp-go v0.32.0 // indirect + github.com/spf13/cast v1.7.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.60.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/net v0.39.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/text v0.24.0 // indirect +) diff --git a/plugins/mocker/go.sum b/plugins/mocker/go.sum new file mode 100644 index 0000000000..6aefe9aa2f --- /dev/null +++ b/plugins/mocker/go.sum @@ -0,0 +1,76 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= +github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= +github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= +github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= +github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= +github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= +github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jaswdr/faker/v2 v2.5.0 h1:KUYfnleIZMSHNp/q+rDk7XEuqUUL5FhfT19iTTFqF5o= +github.com/jaswdr/faker/v2 v2.5.0/go.mod h1:ROK8xwQV0hYOLDUtxCQgHGcl10jbVzIvqHxcIDdwY2Q= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= +github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/maximhq/bifrost/core v1.1.6 h1:rZrfPVcAfNggfBaOTdu/w+xNwDhW79bfexXsw8LRoMQ= +github.com/maximhq/bifrost/core v1.1.6/go.mod h1:yMRCncTgKYBIrECSRVxMbY3BL8CjLbipJlc644jryxc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw= +github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/mocker/main.go b/plugins/mocker/main.go new file mode 100644 index 0000000000..9be3eb508c --- /dev/null +++ b/plugins/mocker/main.go @@ -0,0 +1,1094 @@ +package mocker + +import ( + "context" + "encoding/json" + "fmt" + "maps" + "math/rand" + "regexp" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/jaswdr/faker/v2" + "github.com/maximhq/bifrost/core/schemas" +) + +// Init initializes the Mocker plugin from JSON configuration (for dynamic loading) +func Init(configData json.RawMessage) (schemas.Plugin, error) { + var config MockerConfig + if err := json.Unmarshal(configData, &config); err != nil { + return nil, fmt.Errorf("invalid mocker plugin config: %w", err) + } + + return NewMockerPlugin(config) +} + +const ( + PluginName = "bifrost-mocker" +) + +// Constants for type checking and validation +const ( + // Response types + ResponseTypeSuccess = "success" + ResponseTypeError = "error" + + // Default behaviors + DefaultBehaviorPassthrough = "passthrough" + DefaultBehaviorError = "error" + DefaultBehaviorSuccess = "success" + + // Latency types + LatencyTypeFixed = "fixed" + LatencyTypeUniform = "uniform" +) + +// compiledRule represents a rule with pre-compiled regex and normalized weights for performance +type compiledRule struct { + MockRule + compiledRegex *regexp.Regexp // Pre-compiled regex for fast matching + normalizedWeights []float64 // Pre-calculated normalized weights for fast response selection +} + +// MockerPlugin provides comprehensive request/response mocking capabilities +type MockerPlugin struct { + config MockerConfig + rules []MockRule + compiledRules []compiledRule // Pre-compiled rules for performance + mu sync.RWMutex + faker faker.Faker // Use jaswdr/faker library + + // Atomic counters for high-performance statistics tracking + totalRequests int64 + mockedRequests int64 + responsesGenerated int64 + errorsGenerated int64 + + // Rule hits tracking (still needs mutex for map access) + ruleHitsMu sync.RWMutex + ruleHits map[string]int64 +} + +// MockerConfig defines the overall configuration for the mocker plugin +type MockerConfig struct { + Enabled bool `json:"enabled"` // Enable/disable the mocker plugin + GlobalLatency *Latency `json:"global_latency"` // Global latency settings applied to all rules (can be overridden per rule) + Rules []MockRule `json:"rules"` // List of mock rules to be evaluated in priority order + DefaultBehavior string `json:"default_behavior"` // Action when no rules match: "passthrough", "error", or "success" +} + +// MockRule defines a single mocking rule with conditions and responses +// Rules are evaluated in priority order (higher numbers = higher priority) +type MockRule struct { + Name string `json:"name"` // Unique rule name for identification and statistics tracking + Enabled bool `json:"enabled"` // Enable/disable this rule (disabled rules are skipped) + Priority int `json:"priority"` // Higher priority rules are checked first (higher numbers = higher priority) + Conditions Conditions `json:"conditions"` // Conditions that must match for this rule to apply + Responses []Response `json:"responses"` // Possible responses (selected using weighted random selection) + Latency *Latency `json:"latency"` // Rule-specific latency override (overrides global latency if set) + Probability float64 `json:"probability"` // Probability of rule activation (0.0=never, 1.0=always, 0=disabled) +} + +// Conditions define when a mock rule should be applied +// All specified conditions must match for the rule to trigger +type Conditions struct { + Providers []string `json:"providers"` // Match specific providers (e.g., ["openai", "anthropic"]) + Models []string `json:"models"` // Match specific models (e.g., ["gpt-4", "claude-3"]) + MessageRegex *string `json:"message_regex"` // Regex pattern to match against message content + RequestSize *SizeRange `json:"request_size"` // Request size constraints in bytes +} + +// Response defines a mock response configuration +// Either Content (for success) or Error (for error) should be set, not both +type Response struct { + Type string `json:"type"` // Response type: "success" or "error" + Weight float64 `json:"weight"` // Weight for random selection (higher = more likely) + Content *SuccessResponse `json:"content"` // Success response content (required if Type="success") + Error *ErrorResponse `json:"error"` // Error response content (required if Type="error") + AllowFallbacks *bool `json:"allow_fallbacks"` // Control fallback behavior for errors (nil=true, false=no fallbacks) +} + +// SuccessResponse defines mock success response content +// Either Message or MessageTemplate should be set (MessageTemplate takes precedence) +type SuccessResponse struct { + Message string `json:"message"` // Static response message + Model *string `json:"model"` // Override model name in response (optional) + Usage *Usage `json:"usage"` // Token usage info (optional, defaults applied if nil) + FinishReason *string `json:"finish_reason"` // Completion reason (optional, defaults to "stop") + MessageTemplate *string `json:"message_template"` // Template with variables like {{model}}, {{provider}} (overrides Message) + CustomFields map[string]interface{} `json:"custom_fields"` // Additional fields stored in response metadata +} + +// ErrorResponse defines mock error response content +type ErrorResponse struct { + Message string `json:"message"` // Error message to return + Type *string `json:"type"` // Error type (e.g., "rate_limit", "auth_error") + Code *string `json:"code"` // Error code (e.g., "429", "401") + StatusCode *int `json:"status_code"` // HTTP status code for the error +} + +// Latency defines latency simulation settings +type Latency struct { + Min time.Duration `json:"min"` // Minimum latency as time.Duration (e.g., 100*time.Millisecond, NOT raw int) + Max time.Duration `json:"max"` // Maximum latency as time.Duration (e.g., 500*time.Millisecond, NOT raw int) + Type string `json:"type"` // Latency type: "fixed" or "uniform" +} + +// SizeRange defines request size constraints in bytes +type SizeRange struct { + Min int `json:"min"` // Minimum request size in bytes + Max int `json:"max"` // Maximum request size in bytes +} + +// Usage defines token usage information +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// MockStats tracks plugin statistics and rule execution counts +type MockStats struct { + TotalRequests int64 `json:"total_requests"` // Total number of requests processed + MockedRequests int64 `json:"mocked_requests"` // Number of requests that were mocked (rules matched) + RuleHits map[string]int64 `json:"rule_hits"` // Rule name -> hit count mapping + ErrorsGenerated int64 `json:"errors_generated"` // Number of error responses generated + ResponsesGenerated int64 `json:"responses_generated"` // Number of success responses generated +} + +// NewMockerPlugin creates a new mocker plugin instance with sensible defaults +// Returns an error if required configuration is invalid or missing +func NewMockerPlugin(config MockerConfig) (*MockerPlugin, error) { + // Validate configuration + if err := validateConfig(config); err != nil { + return nil, fmt.Errorf("invalid mocker plugin configuration: %w", err) + } + + // Apply defaults if not set + if config.DefaultBehavior == "" { + config.DefaultBehavior = DefaultBehaviorPassthrough // Default to passthrough if no rules match + } + + // If no rules provided, create a simple catch-all rule for quick testing + if len(config.Rules) == 0 && config.Enabled { + config.Rules = []MockRule{ + { + Name: "default-mock", + Enabled: true, + Priority: 1, + Conditions: Conditions{}, // Empty conditions = match all requests + Probability: 1.0, // Always activate + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Weight: 1.0, + Content: &SuccessResponse{ + Message: "This is a mock response from the Mocker plugin", + }, + }, + }, + }, + } + } + + plugin := &MockerPlugin{ + config: config, + rules: config.Rules, + ruleHits: make(map[string]int64), + faker: faker.New(), // Initialize faker + } + + // Pre-compile all regex patterns for performance + if err := plugin.compileRules(); err != nil { + return nil, fmt.Errorf("failed to compile rules: %w", err) + } + + return plugin, nil +} + +// compileRules pre-compiles all regex patterns and calculates normalized weights for performance +func (p *MockerPlugin) compileRules() error { + p.compiledRules = make([]compiledRule, 0, len(p.rules)) + + for _, rule := range p.rules { + compiled := compiledRule{MockRule: rule} + + // Pre-compile regex if present + if rule.Conditions.MessageRegex != nil { + regex, err := regexp.Compile(*rule.Conditions.MessageRegex) + if err != nil { + return fmt.Errorf("invalid regex in rule '%s': %w", rule.Name, err) + } + compiled.compiledRegex = regex + } + + // Pre-calculate normalized weights for fast response selection + compiled.normalizedWeights = p.calculateNormalizedWeights(rule.Responses) + + p.compiledRules = append(p.compiledRules, compiled) + } + + // Sort compiled rules by priority (higher first) + p.sortCompiledRulesByPriority() + + return nil +} + +// calculateNormalizedWeights pre-calculates normalized cumulative weights for fast response selection +func (p *MockerPlugin) calculateNormalizedWeights(responses []Response) []float64 { + if len(responses) == 0 { + return nil + } + + if len(responses) == 1 { + return []float64{1.0} // Single response always gets 100% probability + } + + // Calculate total weight, applying default weight of 1.0 if not specified + totalWeight := 0.0 + for _, response := range responses { + weight := response.Weight + if weight == 0 { + weight = 1.0 // Default weight + } + totalWeight += weight + } + + // Calculate normalized cumulative weights for O(1) selection + normalizedWeights := make([]float64, len(responses)) + cumulativeWeight := 0.0 + + for i, response := range responses { + weight := response.Weight + if weight == 0 { + weight = 1.0 // Default weight + } + cumulativeWeight += weight / totalWeight // Normalize to [0, 1] + normalizedWeights[i] = cumulativeWeight + } + + // Ensure the last weight is exactly 1.0 to handle floating point precision issues + if len(normalizedWeights) > 0 { + normalizedWeights[len(normalizedWeights)-1] = 1.0 + } + + return normalizedWeights +} + +// validateConfig validates the mocker plugin configuration +func validateConfig(config MockerConfig) error { + // Validate default behavior + if config.DefaultBehavior != "" { + switch config.DefaultBehavior { + case DefaultBehaviorPassthrough, DefaultBehaviorError, DefaultBehaviorSuccess: + // Valid + default: + return fmt.Errorf("invalid default_behavior '%s', must be one of: %s, %s, %s", + config.DefaultBehavior, DefaultBehaviorPassthrough, DefaultBehaviorError, DefaultBehaviorSuccess) + } + } + + // Validate global latency if provided + if config.GlobalLatency != nil { + if err := validateLatency(*config.GlobalLatency); err != nil { + return fmt.Errorf("invalid global_latency: %w", err) + } + } + + // Validate each rule + for i, rule := range config.Rules { + if err := validateRule(rule); err != nil { + return fmt.Errorf("invalid rule at index %d (%s): %w", i, rule.Name, err) + } + } + + return nil +} + +// validateRule validates a single mock rule +func validateRule(rule MockRule) error { + // Rule name is required + if rule.Name == "" { + return fmt.Errorf("rule name is required") + } + + // Priority should be reasonable (allow negative for low priority) + if rule.Priority < -1000 || rule.Priority > 1000 { + return fmt.Errorf("priority %d is out of reasonable range (-1000 to 1000)", rule.Priority) + } + + // Probability must be between 0 and 1 + if rule.Probability < 0 || rule.Probability > 1 { + return fmt.Errorf("probability %.2f must be between 0.0 and 1.0", rule.Probability) + } + + // At least one response is required + if len(rule.Responses) == 0 { + return fmt.Errorf("at least one response is required") + } + + // Validate rule-specific latency if provided + if rule.Latency != nil { + if err := validateLatency(*rule.Latency); err != nil { + return fmt.Errorf("invalid rule latency: %w", err) + } + } + + // Validate conditions + if err := validateConditions(rule.Conditions); err != nil { + return fmt.Errorf("invalid conditions: %w", err) + } + + // Validate each response + for i, response := range rule.Responses { + if err := validateResponse(response); err != nil { + return fmt.Errorf("invalid response at index %d: %w", i, err) + } + } + + return nil +} + +// validateLatency validates latency configuration +func validateLatency(latency Latency) error { + // Type is required + if latency.Type == "" { + return fmt.Errorf("latency type is required") + } + + // Validate type + switch latency.Type { + case LatencyTypeFixed, LatencyTypeUniform: + // Valid + default: + return fmt.Errorf("invalid latency type '%s', must be one of: %s, %s", + latency.Type, LatencyTypeFixed, LatencyTypeUniform) + } + + // Min latency should be non-negative + if latency.Min < 0 { + return fmt.Errorf("minimum latency cannot be negative") + } + + // For uniform type, max should be >= min + if latency.Type == LatencyTypeUniform { + if latency.Max < latency.Min { + return fmt.Errorf("maximum latency (%v) cannot be less than minimum latency (%v)", latency.Max, latency.Min) + } + } + + return nil +} + +// validateConditions validates rule conditions +func validateConditions(conditions Conditions) error { + // Validate regex if provided + if conditions.MessageRegex != nil { + _, err := regexp.Compile(*conditions.MessageRegex) + if err != nil { + return fmt.Errorf("invalid message regex '%s': %w", *conditions.MessageRegex, err) + } + } + + // Validate request size range if provided + if conditions.RequestSize != nil { + if conditions.RequestSize.Min < 0 { + return fmt.Errorf("request size minimum cannot be negative") + } + if conditions.RequestSize.Max < conditions.RequestSize.Min { + return fmt.Errorf("request size maximum (%d) cannot be less than minimum (%d)", + conditions.RequestSize.Max, conditions.RequestSize.Min) + } + } + + return nil +} + +// validateResponse validates a response configuration +func validateResponse(response Response) error { + // Type is required + if response.Type == "" { + return fmt.Errorf("response type is required") + } + + // Validate type + switch response.Type { + case ResponseTypeSuccess, ResponseTypeError: + // Valid + default: + return fmt.Errorf("invalid response type '%s', must be one of: %s, %s", + response.Type, ResponseTypeSuccess, ResponseTypeError) + } + + // Weight should be non-negative + if response.Weight < 0 { + return fmt.Errorf("response weight cannot be negative") + } + + // Validate response content based on type + if response.Type == ResponseTypeSuccess { + if response.Content == nil { + return fmt.Errorf("success response must have content") + } + if err := validateSuccessResponse(*response.Content); err != nil { + return fmt.Errorf("invalid success content: %w", err) + } + } else if response.Type == ResponseTypeError { + if response.Error == nil { + return fmt.Errorf("error response must have error content") + } + if err := validateErrorResponse(*response.Error); err != nil { + return fmt.Errorf("invalid error content: %w", err) + } + } + + return nil +} + +// validateSuccessResponse validates success response content +func validateSuccessResponse(content SuccessResponse) error { + // Either Message or MessageTemplate must be provided + if content.Message == "" && (content.MessageTemplate == nil || *content.MessageTemplate == "") { + return fmt.Errorf("either message or message_template is required") + } + + // If usage is provided, validate it + if content.Usage != nil { + if content.Usage.PromptTokens < 0 || content.Usage.CompletionTokens < 0 || content.Usage.TotalTokens < 0 { + return fmt.Errorf("token counts cannot be negative") + } + } + + return nil +} + +// validateErrorResponse validates error response content +func validateErrorResponse(errorContent ErrorResponse) error { + // Message is required + if errorContent.Message == "" { + return fmt.Errorf("error message is required") + } + + // Status code should be reasonable if provided + if errorContent.StatusCode != nil { + if *errorContent.StatusCode < 100 || *errorContent.StatusCode > 599 { + return fmt.Errorf("status code %d is out of valid HTTP range (100-599)", *errorContent.StatusCode) + } + } + + return nil +} + +// GetName returns the plugin name +func (p *MockerPlugin) GetName() string { + return PluginName +} + +// PreHook intercepts requests and applies mocking rules based on configuration +// This is called before the actual provider request and can short-circuit the flow +func (p *MockerPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + // Skip processing if plugin is disabled + if !p.config.Enabled { + return req, nil, nil + } + + // Track total request count using atomic operation (no lock needed) + atomic.AddInt64(&p.totalRequests, 1) + + // Find the first matching rule based on priority order + rule := p.findMatchingCompiledRule(req) + if rule == nil { + // No rules matched, handle according to default behavior + return p.handleDefaultBehavior(req) + } + + // Check if rule should activate based on probability (0.0 = never, 1.0 = always) + if rule.Probability > 0 && rand.Float64() > rule.Probability { + // Rule didn't activate due to probability, continue with normal flow + return req, nil, nil + } + + // Apply artificial latency simulation if configured + if latency := p.getLatency(&rule.MockRule); latency != nil { + delay := p.calculateLatency(latency) + time.Sleep(delay) + } + + // Select a response from the rule's possible responses using pre-calculated weights + response := p.selectResponse(rule) + if response == nil { + // No valid response configuration, continue with normal flow + return req, nil, nil + } + + // Update statistics using atomic operations and minimal locking + atomic.AddInt64(&p.mockedRequests, 1) + + // Rule hits still need a mutex since it's a map, but we minimize lock time + p.ruleHitsMu.Lock() + p.ruleHits[rule.Name]++ + p.ruleHitsMu.Unlock() + + // Generate appropriate mock response based on type + if response.Type == ResponseTypeSuccess { + return p.generateSuccessShortCircuit(req, response) + } else if response.Type == ResponseTypeError { + return p.generateErrorShortCircuit(req, response) + } + + // Fallback: continue with normal flow if response type is unrecognized + return req, nil, nil +} + +// PostHook processes responses after provider calls +func (p *MockerPlugin) PostHook(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + return result, err, nil +} + +// Cleanup performs plugin cleanup and frees memory +// IMPORTANT: Call GetStats() before Cleanup() if you need the statistics, +// as this method clears all statistics data to free memory +func (p *MockerPlugin) Cleanup() error { + p.mu.Lock() + defer p.mu.Unlock() + + // Clear all statistics to free memory using atomic operations + atomic.StoreInt64(&p.totalRequests, 0) + atomic.StoreInt64(&p.mockedRequests, 0) + atomic.StoreInt64(&p.responsesGenerated, 0) + atomic.StoreInt64(&p.errorsGenerated, 0) + + // Clear rule hits map + p.ruleHitsMu.Lock() + p.ruleHits = make(map[string]int64) + p.ruleHitsMu.Unlock() + + // Clear rules to free memory + p.rules = nil + p.compiledRules = nil + + return nil +} + +// findMatchingCompiledRule finds the first rule that matches the request using pre-compiled rules +func (p *MockerPlugin) findMatchingCompiledRule(req *schemas.BifrostRequest) *compiledRule { + for i := range p.compiledRules { + rule := &p.compiledRules[i] + if !rule.Enabled { + continue + } + + if p.matchesConditionsFast(req, &rule.Conditions, rule.compiledRegex) { + return rule + } + } + return nil +} + +// matchesConditionsFast checks if request matches rule conditions with optimized performance +func (p *MockerPlugin) matchesConditionsFast(req *schemas.BifrostRequest, conditions *Conditions, compiledRegex *regexp.Regexp) bool { + // Check providers - optimized string comparison + if len(conditions.Providers) > 0 { + providerStr := string(req.Provider) + found := false + for _, provider := range conditions.Providers { + if providerStr == provider { + found = true + break + } + } + if !found { + return false + } + } + + // Check models - direct string comparison + if len(conditions.Models) > 0 { + found := false + for _, model := range conditions.Models { + if req.Model == model { + found = true + break + } + } + if !found { + return false + } + } + + // Check message regex using pre-compiled regex (major performance improvement) + if compiledRegex != nil { + // Extract message content from request (cached if possible) + messageContent := p.extractMessageContentFast(req) + if !compiledRegex.MatchString(messageContent) { + return false + } + } + + // Check request size - only calculate if needed + if conditions.RequestSize != nil { + size := p.calculateRequestSizeFast(req) + if size < conditions.RequestSize.Min || size > conditions.RequestSize.Max { + return false + } + } + + // All conditions matched + return true +} + +// extractMessageContentFast extracts message content with optimized performance +func (p *MockerPlugin) extractMessageContentFast(req *schemas.BifrostRequest) string { + // Handle text completion input + if req.Input.TextCompletionInput != nil { + return *req.Input.TextCompletionInput + } + + // Handle chat completion input - optimized for common cases + if req.Input.ChatCompletionInput != nil { + messages := *req.Input.ChatCompletionInput + if len(messages) == 0 { + return "" + } + + // Fast path for single message + if len(messages) == 1 { + if messages[0].Content.ContentStr != nil { + return *messages[0].Content.ContentStr + } + return "" + } + + // Multiple messages - use string builder for efficiency + var builder strings.Builder + for i, message := range messages { + if message.Content.ContentStr != nil { + if i > 0 { + builder.WriteByte(' ') + } + builder.WriteString(*message.Content.ContentStr) + } + } + return builder.String() + } + + return "" +} + +// calculateRequestSizeFast calculates request size with minimal overhead +func (p *MockerPlugin) calculateRequestSizeFast(req *schemas.BifrostRequest) int { + // Approximate size calculation to avoid expensive JSON marshaling + size := len(req.Model) + len(string(req.Provider)) + + // Add input size + if req.Input.TextCompletionInput != nil { + size += len(*req.Input.TextCompletionInput) + } + + if req.Input.ChatCompletionInput != nil { + for _, message := range *req.Input.ChatCompletionInput { + if message.Content.ContentStr != nil { + size += len(*message.Content.ContentStr) + } + size += 50 // Approximate overhead for message structure + } + } + + return size +} + +// generateSuccessShortCircuit creates a success response short-circuit with optimized allocations +func (p *MockerPlugin) generateSuccessShortCircuit(req *schemas.BifrostRequest, response *Response) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + if response.Content == nil { + return req, nil, nil + } + + content := response.Content + message := content.Message + + // Apply message template if provided + if content.MessageTemplate != nil { + message = p.applyTemplate(*content.MessageTemplate, req) + } + + // Apply defaults for token usage if not provided + var usage schemas.LLMUsage + if content.Usage != nil { + usage = schemas.LLMUsage{ + PromptTokens: p.getOrDefault(content.Usage.PromptTokens, 10), + CompletionTokens: p.getOrDefault(content.Usage.CompletionTokens, 20), + TotalTokens: p.getOrDefault(content.Usage.TotalTokens, content.Usage.PromptTokens+content.Usage.CompletionTokens), + } + } else { + // Default usage when none specified + usage = schemas.LLMUsage{ + PromptTokens: 10, + CompletionTokens: 20, + TotalTokens: 30, + } + } + + // Get finish reason with minimal allocation + var finishReason *string + if content.FinishReason != nil { + finishReason = content.FinishReason + } else { + // Use a static string to avoid allocation + static := "stop" + finishReason = &static + } + + // Create mock response with proper structure + mockResponse := &schemas.BifrostResponse{ + Model: req.Model, + Usage: usage, + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentStr: &message, + }, + }, + FinishReason: finishReason, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: req.Provider, + }, + } + + // Override model if specified + if content.Model != nil { + mockResponse.Model = *content.Model + } + + // Only create raw response map if there are custom fields (avoid allocation) + if len(content.CustomFields) > 0 { + rawResponse := make(map[string]interface{}, len(content.CustomFields)+1) + + // Add custom fields + for key, value := range content.CustomFields { + rawResponse[key] = value + } + + // Add mock metadata + rawResponse["mock_rule"] = "success" + mockResponse.ExtraFields.RawResponse = rawResponse + } + + // Increment success response counter using atomic operation + atomic.AddInt64(&p.responsesGenerated, 1) + + return req, &schemas.PluginShortCircuit{ + Response: mockResponse, + }, nil +} + +// generateErrorShortCircuit creates an error response short-circuit with optimized performance +func (p *MockerPlugin) generateErrorShortCircuit(req *schemas.BifrostRequest, response *Response) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + if response.Error == nil { + return req, nil, nil + } + + errorContent := response.Error + allowFallbacks := response.AllowFallbacks + + // Create mock error + mockError := &schemas.BifrostError{ + Error: schemas.ErrorField{ + Message: errorContent.Message, + }, + AllowFallbacks: allowFallbacks, + } + + // Set error type + if errorContent.Type != nil { + mockError.Error.Type = errorContent.Type + } + + // Set error code + if errorContent.Code != nil { + mockError.Error.Code = errorContent.Code + } + + // Set status code + if errorContent.StatusCode != nil { + mockError.StatusCode = errorContent.StatusCode + } + + // Increment error counter using atomic operation + atomic.AddInt64(&p.errorsGenerated, 1) + + return req, &schemas.PluginShortCircuit{ + Error: mockError, + }, nil +} + +// selectResponse selects a response using pre-calculated normalized weights for optimal performance +func (p *MockerPlugin) selectResponse(rule *compiledRule) *Response { + responses := rule.Responses + normalizedWeights := rule.normalizedWeights + + if len(responses) == 0 { + return nil + } + + if len(responses) == 1 { + return &responses[0] + } + + // Fast O(log n) binary search using pre-calculated cumulative weights + randomValue := rand.Float64() + + // Binary search for the selected response + left, right := 0, len(normalizedWeights)-1 + for left < right { + mid := (left + right) / 2 + if randomValue <= normalizedWeights[mid] { + right = mid + } else { + left = mid + 1 + } + } + + return &responses[left] +} + +// getLatency returns the applicable latency configuration +func (p *MockerPlugin) getLatency(rule *MockRule) *Latency { + if rule.Latency != nil { + return rule.Latency + } + return p.config.GlobalLatency +} + +// calculateLatency calculates the actual delay based on latency configuration +func (p *MockerPlugin) calculateLatency(latency *Latency) time.Duration { + switch latency.Type { + case LatencyTypeFixed: + return latency.Min + case LatencyTypeUniform: + if latency.Max <= latency.Min { + return latency.Min + } + // Calculate random duration between Min and Max + diff := latency.Max - latency.Min + return latency.Min + time.Duration(rand.Float64()*float64(diff)) + default: + // Default to fixed latency + return latency.Min + } +} + +// handleDefaultBehavior handles requests when no rules match +func (p *MockerPlugin) handleDefaultBehavior(req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + switch p.config.DefaultBehavior { + case DefaultBehaviorError: + return req, &schemas.PluginShortCircuit{ + Error: &schemas.BifrostError{ + Error: schemas.ErrorField{ + Message: "Mock plugin default error", + }, + }, + }, nil + case DefaultBehaviorSuccess: + finishReason := "stop" + return req, &schemas.PluginShortCircuit{ + Response: &schemas.BifrostResponse{ + Model: req.Model, + Usage: schemas.LLMUsage{ + PromptTokens: 5, + CompletionTokens: 10, + TotalTokens: 15, + }, + Choices: []schemas.BifrostResponseChoice{ + { + Index: 0, + Message: schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleAssistant, + Content: schemas.MessageContent{ + ContentStr: func() *string { s := "Mock plugin default response"; return &s }(), + }, + }, + FinishReason: &finishReason, + }, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: req.Provider, + }, + }, + }, nil + default: // DefaultBehaviorPassthrough + return req, nil, nil + } +} + +// Helper functions + +// sortCompiledRulesByPriority sorts rules by priority (descending) +func (p *MockerPlugin) sortCompiledRulesByPriority() { + sort.Slice(p.compiledRules, func(i, j int) bool { + return p.compiledRules[i].Priority > p.compiledRules[j].Priority + }) +} + +// applyTemplate applies template variables with optimized string operations including faker support +func (p *MockerPlugin) applyTemplate(template string, req *schemas.BifrostRequest) string { + // Fast path: no template variables + if !strings.Contains(template, "{{") { + return template + } + + result := template + + // Replace basic variables first + replacer := strings.NewReplacer( + "{{provider}}", string(req.Provider), + "{{model}}", req.Model, + ) + result = replacer.Replace(result) + + // Handle faker variables with regex for more complex patterns + fakerRegex := regexp.MustCompile(`\{\{faker\.([^}]+)\}\}`) + result = fakerRegex.ReplaceAllStringFunc(result, func(match string) string { + // Extract the faker method name + submatch := fakerRegex.FindStringSubmatch(match) + if len(submatch) < 2 { + return match // Return original if no match + } + + fakerMethod := submatch[1] + return p.generateFakerValue(fakerMethod) + }) + + return result +} + +// generateFakerValue generates fake data based on the faker method name +func (p *MockerPlugin) generateFakerValue(method string) string { + // Parse method with potential parameters (e.g., "lorem_ipsum:20" for 20 words) + parts := strings.Split(method, ":") + baseMethod := parts[0] + + switch baseMethod { + case "name": + return p.faker.Person().Name() + case "first_name": + return p.faker.Person().FirstName() + case "last_name": + return p.faker.Person().LastName() + case "email": + return p.faker.Internet().Email() + case "phone": + return p.faker.Phone().Number() + case "address": + return p.faker.Address().Address() + case "city": + return p.faker.Address().City() + case "state": + return p.faker.Address().State() + case "zip_code": + return p.faker.Address().PostCode() + case "company": + return p.faker.Company().Name() + case "job_title": + return p.faker.Company().JobTitle() + case "lorem_ipsum": + wordCount := 10 // default + if len(parts) > 1 { + if count, err := fmt.Sscanf(parts[1], "%d", &wordCount); err != nil || count != 1 { + wordCount = 10 + } + } + return p.faker.Lorem().Sentence(wordCount) + case "uuid": + return p.faker.UUID().V4() + case "hex_color": + return p.faker.Color().Hex() + case "integer": + min, max := 1, 100 // defaults + if len(parts) > 1 { + params := strings.Split(parts[1], ",") + if len(params) >= 2 { + if _, err := fmt.Sscanf(params[0], "%d", &min); err != nil { + min = 1 // fallback to default on parse error + } + if _, err := fmt.Sscanf(params[1], "%d", &max); err != nil { + max = 100 // fallback to default on parse error + } + } + } + return fmt.Sprintf("%d", p.faker.IntBetween(min, max)) + case "float": + min, max := 0, 100 // defaults as integers + if len(parts) > 1 { + params := strings.Split(parts[1], ",") + if len(params) >= 2 { + if _, err := fmt.Sscanf(params[0], "%d", &min); err != nil { + min = 0 // fallback to default on parse error + } + if _, err := fmt.Sscanf(params[1], "%d", &max); err != nil { + max = 100 // fallback to default on parse error + } + } + } + return fmt.Sprintf("%.2f", p.faker.Float64(2, min, max)) + case "boolean": + return fmt.Sprintf("%t", p.faker.Bool()) + case "date": + return p.faker.Time().Time(time.Now()).Format("2006-01-02") + case "datetime": + return p.faker.Time().Time(time.Now()).Format("2006-01-02 15:04:05") + case "word": + return p.faker.Lorem().Word() + case "sentence": + wordCount := 8 // default + if len(parts) > 1 { + if count, err := fmt.Sscanf(parts[1], "%d", &wordCount); err != nil || count != 1 { + wordCount = 8 + } + } + return p.faker.Lorem().Sentence(wordCount) + default: + // Return the original placeholder if method is not recognized + return fmt.Sprintf("{{faker.%s}}", method) + } +} + +// getOrDefault returns value or default if 0 +func (p *MockerPlugin) getOrDefault(value, defaultValue int) int { + if value == 0 { + return defaultValue + } + return value +} + +// GetStats returns current plugin statistics +// IMPORTANT: Call this method before Cleanup() if you need the statistics, +// as Cleanup() clears all statistics data to free memory +func (p *MockerPlugin) GetStats() MockStats { + p.mu.RLock() + defer p.mu.RUnlock() + + // Create a deep copy using atomic reads for counters + statsCopy := MockStats{ + TotalRequests: atomic.LoadInt64(&p.totalRequests), + MockedRequests: atomic.LoadInt64(&p.mockedRequests), + ErrorsGenerated: atomic.LoadInt64(&p.errorsGenerated), + ResponsesGenerated: atomic.LoadInt64(&p.responsesGenerated), + RuleHits: make(map[string]int64), + } + + // Copy rule hits map (still needs lock) + p.ruleHitsMu.RLock() + maps.Copy(statsCopy.RuleHits, p.ruleHits) + p.ruleHitsMu.RUnlock() + + return statsCopy +} diff --git a/plugins/mocker/plugin_test.go b/plugins/mocker/plugin_test.go new file mode 100644 index 0000000000..2112df26c6 --- /dev/null +++ b/plugins/mocker/plugin_test.go @@ -0,0 +1,538 @@ +package mocker + +import ( + "context" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// BaseAccount implements the schemas.Account interface for testing purposes. +// It provides mock implementations of the required methods to test the Mocker plugin +// with a basic OpenAI configuration. +type BaseAccount struct{} + +// GetConfiguredProviders returns a list of supported providers for testing. +func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{schemas.OpenAI, schemas.Anthropic}, nil +} + +// GetKeysForProvider returns a dummy API key configuration for testing. +// Since we're testing the mocker plugin, these keys should never be used +// as the plugin intercepts requests before they reach the actual providers. +func (baseAccount *BaseAccount) GetKeysForProvider(providerKey schemas.ModelProvider) ([]schemas.Key, error) { + return []schemas.Key{ + { + Value: "dummy-api-key-for-testing", // Dummy key + Models: []string{"gpt-4", "gpt-4-turbo", "claude-3"}, + Weight: 1.0, + }, + }, nil +} + +// GetConfigForProvider returns default provider configuration for testing. +func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil +} + +// TestMockerPlugin_GetName tests the plugin name +func TestMockerPlugin_GetName(t *testing.T) { + plugin, err := NewMockerPlugin(MockerConfig{}) + if err != nil { + t.Fatalf("Expected no error creating plugin, got: %v", err) + } + if plugin.GetName() != PluginName { + t.Errorf("Expected '%s', got '%s'", PluginName, plugin.GetName()) + } +} + +// TestMockerPlugin_Disabled tests that disabled plugin doesn't interfere +func TestMockerPlugin_Disabled(t *testing.T) { + config := MockerConfig{ + Enabled: false, + } + plugin, err := NewMockerPlugin(config) + if err != nil { + t.Fatalf("Expected no error creating plugin, got: %v", err) + } + + account := BaseAccount{} + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelError), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Cleanup() + + // This should pass through to the real provider (but will fail due to dummy key) + _, bifrostErr := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), + }, + }, + }, + }, + }) + + // Should get an authentication error from OpenAI, not a mock response + // This proves the plugin is disabled and not intercepting requests + if bifrostErr == nil { + t.Error("Expected error from real provider with dummy API key") + } +} + +// TestMockerPlugin_DefaultMockRule tests the default catch-all rule +func TestMockerPlugin_DefaultMockRule(t *testing.T) { + config := MockerConfig{ + Enabled: true, // No rules provided, should create default rule + } + plugin, err := NewMockerPlugin(config) + if err != nil { + t.Fatalf("Expected no error creating plugin, got: %v", err) + } + + account := BaseAccount{} + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelError), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Cleanup() + + response, bifrostErr := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), + }, + }, + }, + }, + }) + + if bifrostErr != nil { + t.Fatalf("Expected no error, got: %v", bifrostErr) + } + if response == nil { + t.Fatal("Expected response") + } + if len(response.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if response.Choices[0].Message.Content.ContentStr == nil { + t.Fatal("Expected content string") + } + if *response.Choices[0].Message.Content.ContentStr != "This is a mock response from the Mocker plugin" { + t.Errorf("Expected default mock message, got: %s", *response.Choices[0].Message.Content.ContentStr) + } +} + +// TestMockerPlugin_CustomSuccessRule tests custom success response +func TestMockerPlugin_CustomSuccessRule(t *testing.T) { + config := MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "openai-success", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{ + Providers: []string{"openai"}, + }, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "Custom OpenAI mock response", + Usage: &Usage{ + PromptTokens: 15, + CompletionTokens: 25, + TotalTokens: 40, + }, + }, + }, + }, + }, + }, + } + plugin, err := NewMockerPlugin(config) + if err != nil { + t.Fatalf("Expected no error creating plugin, got: %v", err) + } + + account := BaseAccount{} + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelError), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Cleanup() + + response, bifrostErr := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), + }, + }, + }, + }, + }) + + if bifrostErr != nil { + t.Fatalf("Expected no error, got: %v", bifrostErr) + } + if response == nil { + t.Fatal("Expected response") + } + if len(response.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if response.Choices[0].Message.Content.ContentStr == nil { + t.Fatal("Expected content string") + } + if *response.Choices[0].Message.Content.ContentStr != "Custom OpenAI mock response" { + t.Errorf("Expected custom message, got: %s", *response.Choices[0].Message.Content.ContentStr) + } + if response.Usage.TotalTokens != 40 { + t.Errorf("Expected 40 total tokens, got %d", response.Usage.TotalTokens) + } +} + +// TestMockerPlugin_ErrorResponse tests error response generation +func TestMockerPlugin_ErrorResponse(t *testing.T) { + allowFallbacks := false + config := MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "rate-limit-error", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{ + Providers: []string{"openai"}, + }, + Responses: []Response{ + { + Type: ResponseTypeError, + AllowFallbacks: &allowFallbacks, + Error: &ErrorResponse{ + Message: "Rate limit exceeded", + Type: bifrost.Ptr("rate_limit"), + Code: bifrost.Ptr("429"), + StatusCode: bifrost.Ptr(429), + }, + }, + }, + }, + }, + } + plugin, err := NewMockerPlugin(config) + if err != nil { + t.Fatalf("Expected no error creating plugin, got: %v", err) + } + + account := BaseAccount{} + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelError), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Cleanup() + + _, bifrostErr := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), + }, + }, + }, + }, + }) + + if bifrostErr == nil { + t.Fatal("Expected error response") + } + if bifrostErr.Error.Message != "Rate limit exceeded" { + t.Errorf("Expected 'Rate limit exceeded', got: %s", bifrostErr.Error.Message) + } + if bifrostErr.StatusCode == nil || *bifrostErr.StatusCode != 429 { + t.Errorf("Expected status code 429, got: %v", bifrostErr.StatusCode) + } +} + +// TestMockerPlugin_MessageTemplate tests template variable substitution +func TestMockerPlugin_MessageTemplate(t *testing.T) { + config := MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "template-test", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{}, // Match all + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + MessageTemplate: bifrost.Ptr("Hello from {{provider}} using model {{model}}"), + }, + }, + }, + }, + }, + } + plugin, err := NewMockerPlugin(config) + if err != nil { + t.Fatalf("Expected no error creating plugin, got: %v", err) + } + + account := BaseAccount{} + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelError), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Cleanup() + + response, bifrostErr := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.Anthropic, + Model: "claude-3", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), + }, + }, + }, + }, + }) + + if bifrostErr != nil { + t.Fatalf("Expected no error, got: %v", bifrostErr) + } + if response == nil { + t.Fatal("Expected response") + } + if len(response.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if response.Choices[0].Message.Content.ContentStr == nil { + t.Fatal("Expected content string") + } + expectedMessage := "Hello from anthropic using model claude-3" + if *response.Choices[0].Message.Content.ContentStr != expectedMessage { + t.Errorf("Expected '%s', got: %s", expectedMessage, *response.Choices[0].Message.Content.ContentStr) + } +} + +// TestMockerPlugin_Statistics tests plugin statistics tracking +func TestMockerPlugin_Statistics(t *testing.T) { + config := MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "stats-test", + Enabled: true, + Priority: 100, + Probability: 1.0, + Conditions: Conditions{}, // Match all + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "Stats test response", + }, + }, + }, + }, + }, + } + plugin, err := NewMockerPlugin(config) + if err != nil { + t.Fatalf("Expected no error creating plugin, got: %v", err) + } + + account := BaseAccount{} + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: &account, + Plugins: []schemas.Plugin{plugin}, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelError), + }) + if err != nil { + t.Fatalf("Error initializing Bifrost: %v", err) + } + defer client.Cleanup() + + // Make multiple requests + for i := 0; i < 3; i++ { + _, _ = client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), + }, + }, + }, + }, + }) + } + + // Check statistics + stats := plugin.GetStats() + if stats.TotalRequests != 3 { + t.Errorf("Expected 3 total requests, got %d", stats.TotalRequests) + } + if stats.MockedRequests != 3 { + t.Errorf("Expected 3 mocked requests, got %d", stats.MockedRequests) + } + if stats.ResponsesGenerated != 3 { + t.Errorf("Expected 3 responses generated, got %d", stats.ResponsesGenerated) + } + if stats.RuleHits["stats-test"] != 3 { + t.Errorf("Expected 3 hits for 'stats-test' rule, got %d", stats.RuleHits["stats-test"]) + } +} + +// TestMockerPlugin_ValidationErrors tests configuration validation +func TestMockerPlugin_ValidationErrors(t *testing.T) { + tests := []struct { + name string + config MockerConfig + expectError bool + }{ + { + name: "invalid default behavior", + config: MockerConfig{ + Enabled: true, + DefaultBehavior: "invalid", + }, + expectError: true, + }, + { + name: "missing rule name", + config: MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "", // Missing name + Enabled: true, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "test", + }, + }, + }, + }, + }, + }, + expectError: true, + }, + { + name: "invalid probability", + config: MockerConfig{ + Enabled: true, + Rules: []MockRule{ + { + Name: "test", + Enabled: true, + Probability: 1.5, // Invalid probability > 1 + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "test", + }, + }, + }, + }, + }, + }, + expectError: true, + }, + { + name: "valid configuration", + config: MockerConfig{ + Enabled: true, + DefaultBehavior: DefaultBehaviorPassthrough, + Rules: []MockRule{ + { + Name: "valid-rule", + Enabled: true, + Probability: 0.5, + Responses: []Response{ + { + Type: ResponseTypeSuccess, + Content: &SuccessResponse{ + Message: "Valid response", + }, + }, + }, + }, + }, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewMockerPlugin(tt.config) + if tt.expectError && err == nil { + t.Error("Expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + }) + } +} diff --git a/tests/core-chatbot/go.mod b/tests/core-chatbot/go.mod new file mode 100644 index 0000000000..fc3cd00cb1 --- /dev/null +++ b/tests/core-chatbot/go.mod @@ -0,0 +1,34 @@ +module github.com/maximhq/bifrost/tests/core-chatbot + +go 1.24.1 + +require github.com/maximhq/bifrost/core v1.1.5 + +require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect + github.com/andybalholm/brotli v1.1.1 // indirect + github.com/aws/aws-sdk-go-v2 v1.36.3 // indirect + github.com/aws/aws-sdk-go-v2/config v1.29.14 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect + github.com/aws/smithy-go v1.22.3 // indirect + github.com/goccy/go-json v0.10.5 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/mark3labs/mcp-go v0.32.0 // indirect + github.com/spf13/cast v1.7.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.60.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/net v0.39.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/text v0.24.0 // indirect +) diff --git a/tests/core-chatbot/go.sum b/tests/core-chatbot/go.sum new file mode 100644 index 0000000000..affaa20bcf --- /dev/null +++ b/tests/core-chatbot/go.sum @@ -0,0 +1,74 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= +github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= +github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= +github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= +github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= +github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= +github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= +github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/maximhq/bifrost/core v1.1.5 h1:Nm9XlS9Nso+pn+U5/btsJD8qRDYGQ1BBOjgqWT3PYSc= +github.com/maximhq/bifrost/core v1.1.5/go.mod h1:yMRCncTgKYBIrECSRVxMbY3BL8CjLbipJlc644jryxc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw= +github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/tests/core-chatbot/main.go b/tests/core-chatbot/main.go new file mode 100644 index 0000000000..2a734f1c7e --- /dev/null +++ b/tests/core-chatbot/main.go @@ -0,0 +1,943 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "os" + "os/signal" + "strconv" + "strings" + "sync" + "syscall" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/meta" +) + +// ChatbotConfig holds configuration for the chatbot +type ChatbotConfig struct { + Provider schemas.ModelProvider + Model string + MCPAgenticMode bool + MCPServerPort int + Temperature *float64 + MaxTokens *int +} + +// ChatSession manages the conversation state +type ChatSession struct { + history []schemas.BifrostMessage + client *bifrost.Bifrost + config ChatbotConfig + systemPrompt string + account *ComprehensiveTestAccount +} + +// ComprehensiveTestAccount provides a test implementation of the Account interface for comprehensive testing. +type ComprehensiveTestAccount struct{} + +// getEnvWithDefault returns the value of the environment variable if set, otherwise returns the default value +func getEnvWithDefault(envVar, defaultValue string) string { + if value := os.Getenv(envVar); value != "" { + return value + } + return defaultValue +} + +// GetConfiguredProviders returns the list of initially supported providers. +func (account *ComprehensiveTestAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{ + schemas.OpenAI, + schemas.Anthropic, + schemas.Bedrock, + schemas.Cohere, + schemas.Azure, + schemas.Vertex, + schemas.Ollama, + schemas.Mistral, + }, nil +} + +// GetKeysForProvider returns the API keys and associated models for a given provider. +func (account *ComprehensiveTestAccount) GetKeysForProvider(providerKey schemas.ModelProvider) ([]schemas.Key, error) { + switch providerKey { + case schemas.OpenAI: + return []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{"gpt-4o-mini", "gpt-4-turbo", "gpt-4o"}, + Weight: 1.0, + }, + }, nil + case schemas.Anthropic: + return []schemas.Key{ + { + Value: os.Getenv("ANTHROPIC_API_KEY"), + Models: []string{"claude-3-7-sonnet-20250219", "claude-3-5-sonnet-20240620", "claude-2.1"}, + Weight: 1.0, + }, + }, nil + case schemas.Bedrock: + return []schemas.Key{ + { + Value: os.Getenv("BEDROCK_API_KEY"), + Models: []string{"anthropic.claude-v2:1", "mistral.mixtral-8x7b-instruct-v0:1", "mistral.mistral-large-2402-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"}, + Weight: 1.0, + }, + }, nil + case schemas.Cohere: + return []schemas.Key{ + { + Value: os.Getenv("COHERE_API_KEY"), + Models: []string{"command-a-03-2025", "c4ai-aya-vision-8b"}, + Weight: 1.0, + }, + }, nil + case schemas.Azure: + return []schemas.Key{ + { + Value: os.Getenv("AZURE_API_KEY"), + Models: []string{"gpt-4o"}, + Weight: 1.0, + }, + }, nil + case schemas.Vertex: + return []schemas.Key{ + { + Value: os.Getenv("VERTEX_API_KEY"), + Models: []string{"gemini-pro", "gemini-1.5-pro"}, + Weight: 1.0, + }, + }, nil + case schemas.Mistral: + return []schemas.Key{ + { + Value: os.Getenv("MISTRAL_API_KEY"), + Models: []string{"mistral-large-2411", "pixtral-12b-latest"}, + Weight: 1.0, + }, + }, nil + case schemas.Ollama: + return []schemas.Key{ + { + Value: "", // Ollama is keyless + Models: []string{"llama3.2", "llama3.1", "mistral", "codellama"}, + Weight: 1.0, + }, + }, nil + default: + return nil, fmt.Errorf("unsupported provider: %s", providerKey) + } +} + +// GetConfigForProvider returns the configuration settings for a given provider. +func (account *ComprehensiveTestAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + switch providerKey { + case schemas.OpenAI: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Anthropic: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Bedrock: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + MetaConfig: &meta.BedrockMetaConfig{ + SecretAccessKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), + Region: bifrost.Ptr(getEnvWithDefault("AWS_REGION", "us-east-1")), + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Cohere: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Azure: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + MetaConfig: &meta.AzureMetaConfig{ + Endpoint: os.Getenv("AZURE_ENDPOINT"), + Deployments: map[string]string{ + "gpt-4o": "gpt-4o-aug", + }, + APIVersion: bifrost.Ptr(getEnvWithDefault("AZURE_API_VERSION", "2024-08-01-preview")), + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Vertex: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + MetaConfig: &meta.VertexMetaConfig{ + ProjectID: os.Getenv("VERTEX_PROJECT_ID"), + Region: getEnvWithDefault("VERTEX_REGION", "us-central1"), + AuthCredentials: os.Getenv("VERTEX_CREDENTIALS"), + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Ollama: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Mistral: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + default: + return nil, fmt.Errorf("unsupported provider: %s", providerKey) + } +} + +// NewChatSession creates a new chat session with the given configuration +func NewChatSession(config ChatbotConfig) (*ChatSession, error) { + // Create MCP configuration for Bifrost + mcpConfig := &schemas.MCPConfig{ + ServerPort: bifrost.Ptr(config.MCPServerPort), + ClientConfigs: []schemas.MCPClientConfig{}, + } + + fmt.Println("πŸ”Œ Configuring Serper MCP server...") + mcpConfig.ClientConfigs = append(mcpConfig.ClientConfigs, schemas.MCPClientConfig{ + Name: "serper-web-search-mcp", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"-y", "serper-search-scrape-mcp-server"}, + Envs: []string{"SERPER_API_KEY"}, + }, + ToolsToSkip: []string{}, // No tools to skip for this client + }, + schemas.MCPClientConfig{ + Name: "gmail-mcp", + ConnectionType: schemas.MCPConnectionTypeSSE, + ConnectionString: bifrost.Ptr("https://mcp.composio.dev/composio/server/654c1e3f-ea7d-47b6-9e31-398d00449654/sse"), + }, + ) + + fmt.Println("πŸ”Œ Configuring Context7 MCP server...") + mcpConfig.ClientConfigs = append(mcpConfig.ClientConfigs, schemas.MCPClientConfig{ + Name: "context7", + ConnectionType: schemas.MCPConnectionTypeSTDIO, + StdioConfig: &schemas.MCPStdioConfig{ + Command: "npx", + Args: []string{"-y", "@upstash/context7-mcp"}, + }, + ToolsToSkip: []string{}, // No tools to skip for this client + }) + + // Initialize Bifrost with MCP configuration + account := &ComprehensiveTestAccount{} + + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: account, + Plugins: []schemas.Plugin{}, // No separate plugins needed - MCP is integrated + Logger: bifrost.NewDefaultLogger(schemas.LogLevelInfo), + MCPConfig: mcpConfig, // MCP is now configured here + }) + if err != nil { + return nil, fmt.Errorf("failed to initialize Bifrost: %w", err) + } + + session := &ChatSession{ + history: make([]schemas.BifrostMessage, 0), + client: client, + config: config, + account: account, + systemPrompt: "You are a helpful AI assistant with access to various tools. " + + "Use the available tools when they can help answer the user's questions more accurately or provide additional information.", + } + + // Add system message to history + if session.systemPrompt != "" { + session.history = append(session.history, schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleSystem, + Content: schemas.MessageContent{ + ContentStr: &session.systemPrompt, + }, + }) + } + + return session, nil +} + +// getAvailableProviders returns a list of providers that have valid configurations +func (s *ChatSession) getAvailableProviders() []schemas.ModelProvider { + configuredProviders, err := s.account.GetConfiguredProviders() + if err != nil { + return []schemas.ModelProvider{} + } + + var availableProviders []schemas.ModelProvider + for _, provider := range configuredProviders { + // Check if provider has valid keys (except for keyless providers) + if provider == schemas.Ollama || provider == schemas.Vertex { + availableProviders = append(availableProviders, provider) + continue + } + + keys, err := s.account.GetKeysForProvider(provider) + if err == nil && len(keys) > 0 && keys[0].Value != "" { + availableProviders = append(availableProviders, provider) + } + } + return availableProviders +} + +// getAvailableModels returns available models for a given provider +func (s *ChatSession) getAvailableModels(provider schemas.ModelProvider) []string { + keys, err := s.account.GetKeysForProvider(provider) + if err != nil || len(keys) == 0 { + return []string{} + } + return keys[0].Models +} + +// switchProvider handles switching to a different provider +func (s *ChatSession) switchProvider() error { + availableProviders := s.getAvailableProviders() + if len(availableProviders) == 0 { + fmt.Println("❌ No available providers found") + return fmt.Errorf("no available providers") + } + + fmt.Println("\nπŸ”„ Available Providers:") + fmt.Println("======================") + for i, provider := range availableProviders { + status := "" + if provider == s.config.Provider { + status = " (current)" + } + fmt.Printf("[%d] %s%s\n", i+1, provider, status) + } + + fmt.Print("\nSelect provider (number): ") + scanner := bufio.NewScanner(os.Stdin) + if !scanner.Scan() { + return fmt.Errorf("input cancelled") + } + + choice, err := strconv.Atoi(strings.TrimSpace(scanner.Text())) + if err != nil || choice < 1 || choice > len(availableProviders) { + return fmt.Errorf("invalid choice") + } + + newProvider := availableProviders[choice-1] + + // Get available models for the new provider + models := s.getAvailableModels(newProvider) + if len(models) == 0 { + return fmt.Errorf("no models available for provider %s", newProvider) + } + + // Auto-select first model or let user choose if multiple + var newModel string + if len(models) == 1 { + newModel = models[0] + } else { + fmt.Printf("\n🧠 Available Models for %s:\n", newProvider) + fmt.Println("================================") + for i, model := range models { + fmt.Printf("[%d] %s\n", i+1, model) + } + + fmt.Print("\nSelect model (number): ") + if !scanner.Scan() { + return fmt.Errorf("input cancelled") + } + + modelChoice, err := strconv.Atoi(strings.TrimSpace(scanner.Text())) + if err != nil || modelChoice < 1 || modelChoice > len(models) { + return fmt.Errorf("invalid model choice") + } + + newModel = models[modelChoice-1] + } + + // Update configuration + s.config.Provider = newProvider + s.config.Model = newModel + + fmt.Printf("βœ… Switched to %s with model %s\n", newProvider, newModel) + return nil +} + +// switchModel handles switching to a different model for the current provider +func (s *ChatSession) switchModel() error { + models := s.getAvailableModels(s.config.Provider) + if len(models) == 0 { + return fmt.Errorf("no models available for provider %s", s.config.Provider) + } + + if len(models) == 1 { + fmt.Printf("Only one model available for %s: %s\n", s.config.Provider, models[0]) + return nil + } + + fmt.Printf("\n🧠 Available Models for %s:\n", s.config.Provider) + fmt.Println("===============================") + for i, model := range models { + status := "" + if model == s.config.Model { + status = " (current)" + } + fmt.Printf("[%d] %s%s\n", i+1, model, status) + } + + fmt.Print("\nSelect model (number): ") + scanner := bufio.NewScanner(os.Stdin) + if !scanner.Scan() { + return fmt.Errorf("input cancelled") + } + + choice, err := strconv.Atoi(strings.TrimSpace(scanner.Text())) + if err != nil || choice < 1 || choice > len(models) { + return fmt.Errorf("invalid choice") + } + + newModel := models[choice-1] + s.config.Model = newModel + + fmt.Printf("βœ… Switched to model %s\n", newModel) + return nil +} + +// showCurrentConfig displays the current configuration +func (s *ChatSession) showCurrentConfig() { + fmt.Println("\nβš™οΈ Current Configuration:") + fmt.Println("=========================") + fmt.Printf("πŸ”§ Provider: %s\n", s.config.Provider) + fmt.Printf("🧠 Model: %s\n", s.config.Model) + fmt.Printf("πŸ”„ Agentic Mode: %t\n", s.config.MCPAgenticMode) + fmt.Printf("🌑️ Temperature: %.1f\n", *s.config.Temperature) + fmt.Printf("πŸ“ Max Tokens: %d\n", *s.config.MaxTokens) + fmt.Printf("πŸ”§ Tool Execution: Manual approval required\n") +} + +// AddUserMessage adds a user message to the conversation history +func (s *ChatSession) AddUserMessage(message string) { + userMessage := schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: &message, + }, + } + s.history = append(s.history, userMessage) +} + +// SendMessage sends a message and returns the assistant's response +func (s *ChatSession) SendMessage(message string) (string, error) { + // Add user message to history + s.AddUserMessage(message) + + // Prepare model parameters + params := &schemas.ModelParameters{} + if s.config.Temperature != nil { + params.Temperature = s.config.Temperature + } + if s.config.MaxTokens != nil { + params.MaxTokens = s.config.MaxTokens + } + params.ToolChoice = &schemas.ToolChoice{ + ToolChoiceStr: stringPtr("auto"), + } + + // Create request + request := &schemas.BifrostRequest{ + Provider: s.config.Provider, + Model: s.config.Model, + Input: schemas.RequestInput{ + ChatCompletionInput: &s.history, + }, + Params: params, + } + + // Start loading animation + stopChan, wg := startLoader() + + // Send request + response, err := s.client.ChatCompletionRequest(context.Background(), request) + + // Stop loading animation + stopLoader(stopChan, wg) + + if err != nil { + return "", fmt.Errorf("chat completion failed: %s", err.Error.Message) + } + + if response == nil || len(response.Choices) == 0 { + return "", fmt.Errorf("no response received") + } + + // Get the assistant's response + choice := response.Choices[0] + assistantMessage := choice.Message + + // Add assistant message to history + s.history = append(s.history, assistantMessage) + + // Check if assistant wants to use tools + if assistantMessage.ToolCalls != nil && len(*assistantMessage.ToolCalls) > 0 { + return s.handleToolCalls(assistantMessage) + } + + // Extract text content for regular responses + var responseText string + if assistantMessage.Content.ContentStr != nil { + responseText = *assistantMessage.Content.ContentStr + } else if assistantMessage.Content.ContentBlocks != nil { + var textParts []string + for _, block := range *assistantMessage.Content.ContentBlocks { + if block.Text != nil { + textParts = append(textParts, *block.Text) + } + } + responseText = strings.Join(textParts, "\n") + } + + return responseText, nil +} + +// handleToolCalls handles tool execution using the new Bifrost MCP integration +func (s *ChatSession) handleToolCalls(assistantMessage schemas.BifrostMessage) (string, error) { + toolCalls := *assistantMessage.ToolCalls + + // Display tools to user for approval + fmt.Println("\nπŸ”§ Assistant wants to use the following tools:") + fmt.Println("============================================") + + for i, toolCall := range toolCalls { + fmt.Printf("[%d] Tool: %s\n", i+1, *toolCall.Function.Name) + fmt.Printf(" Arguments: %s\n", toolCall.Function.Arguments) + fmt.Println() + } + + fmt.Print("Do you want to execute these tools? (y/n): ") + + scanner := bufio.NewScanner(os.Stdin) + if !scanner.Scan() { + return "❌ Tool execution cancelled by user.", nil + } + + input := strings.ToLower(strings.TrimSpace(scanner.Text())) + if input != "y" && input != "yes" { + return "❌ Tool execution cancelled by user.", nil + } + + fmt.Println("βœ… Executing tools...") + + // Execute each tool using Bifrost's ExecuteMCPTool method + toolResults := make([]schemas.BifrostMessage, 0) + for _, toolCall := range toolCalls { + // Start loading animation for this tool + stopChan, wg := startLoader() + + // Execute the tool using Bifrost's integrated MCP functionality + toolResult, err := s.client.ExecuteMCPTool(context.Background(), toolCall) + + // Stop loading animation + stopLoader(stopChan, wg) + + if err != nil { + fmt.Printf("❌ Error executing tool %s: %v\n", *toolCall.Function.Name, err) + // Create error message for this tool + errorResult := schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleTool, + Content: schemas.MessageContent{ + ContentStr: stringPtr(fmt.Sprintf("Error executing tool: %v", err)), + }, + ToolMessage: &schemas.ToolMessage{ + ToolCallID: toolCall.ID, + }, + } + toolResults = append(toolResults, errorResult) + } else { + fmt.Printf("βœ… Tool %s executed successfully\n", *toolCall.Function.Name) + toolResults = append(toolResults, *toolResult) + } + } + + // Add tool results to conversation history + s.history = append(s.history, toolResults...) + + // If agentic mode is enabled, send conversation back to LLM for synthesis + if s.config.MCPAgenticMode { + return s.synthesizeToolResults() + } + + // Non-agentic mode: return the results directly + var responseText strings.Builder + responseText.WriteString("πŸ”§ Tool execution completed:\n\n") + + for i, result := range toolResults { + if result.Content.ContentStr != nil { + responseText.WriteString(fmt.Sprintf("Tool %d result: %s\n", i+1, *result.Content.ContentStr)) + } + } + + return responseText.String(), nil +} + +// synthesizeToolResults sends the conversation with tool results back to LLM for synthesis +func (s *ChatSession) synthesizeToolResults() (string, error) { + // Add synthesis prompt + synthesisPrompt := schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: stringPtr("Please provide a comprehensive response based on the tool results above."), + }, + } + + // Temporarily add synthesis prompt for the request + conversationWithSynthesis := append(s.history, synthesisPrompt) + + // Create synthesis request + synthesisRequest := &schemas.BifrostRequest{ + Provider: s.config.Provider, + Model: s.config.Model, + Input: schemas.RequestInput{ + ChatCompletionInput: &conversationWithSynthesis, + }, + Params: &schemas.ModelParameters{ + Temperature: s.config.Temperature, + MaxTokens: s.config.MaxTokens, + }, + } + + fmt.Println("πŸ€– Synthesizing response...") + + // Start loading animation + stopChan, wg := startLoader() + + // Send synthesis request + synthesisResponse, err := s.client.ChatCompletionRequest(context.Background(), synthesisRequest) + + // Stop loading animation + stopLoader(stopChan, wg) + + if err != nil { + fmt.Printf("⚠️ Synthesis failed: %v. Returning tool results directly.\n", err) + // Fallback to direct tool results + var responseText strings.Builder + responseText.WriteString("πŸ”§ Tool execution completed (synthesis failed):\n\n") + + // Get tool results from history (last few messages that are tool messages) + for i := len(s.history) - 1; i >= 0; i-- { + if s.history[i].Role == schemas.ModelChatMessageRoleTool { + if s.history[i].Content.ContentStr != nil { + responseText.WriteString(fmt.Sprintf("Tool result: %s\n", *s.history[i].Content.ContentStr)) + } + } else { + break // Stop when we hit non-tool messages + } + } + + return responseText.String(), nil + } + + if synthesisResponse == nil || len(synthesisResponse.Choices) == 0 { + return "❌ No synthesis response received", nil + } + + // Get synthesized response + synthesizedMessage := synthesisResponse.Choices[0].Message + + // Add synthesized response to history (replace the temporary synthesis prompt effect) + s.history = append(s.history, synthesizedMessage) + + // Extract text content + var responseText string + if synthesizedMessage.Content.ContentStr != nil { + responseText = *synthesizedMessage.Content.ContentStr + } else if synthesizedMessage.Content.ContentBlocks != nil { + var textParts []string + for _, block := range *synthesizedMessage.Content.ContentBlocks { + if block.Text != nil { + textParts = append(textParts, *block.Text) + } + } + responseText = strings.Join(textParts, "\n") + } + + return responseText, nil +} + +// PrintHistory prints the conversation history +func (s *ChatSession) PrintHistory() { + fmt.Println("\nπŸ“œ Conversation History:") + fmt.Println("========================") + + for i, msg := range s.history { + if msg.Role == schemas.ModelChatMessageRoleSystem { + continue // Skip system messages in history display + } + + var content string + if msg.Content.ContentStr != nil { + content = *msg.Content.ContentStr + } else if msg.Content.ContentBlocks != nil { + var textParts []string + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + textParts = append(textParts, *block.Text) + } + } + content = strings.Join(textParts, "\n") + } + + role := strings.Title(string(msg.Role)) + timestamp := fmt.Sprintf("[%d]", i) + + fmt.Printf("%s %s: %s\n\n", timestamp, role, content) + } +} + +// Cleanup closes the chat session and cleans up resources +func (s *ChatSession) Cleanup() { + if s.client != nil { + s.client.Cleanup() + } +} + +// printWelcome prints the welcome message and instructions +func printWelcome(config ChatbotConfig) { + fmt.Println("πŸ€– Bifrost CLI Chatbot") + fmt.Println("======================") + fmt.Printf("πŸ”§ Provider: %s\n", config.Provider) + fmt.Printf("🧠 Model: %s\n", config.Model) + fmt.Printf("πŸ”„ Agentic Mode: %t\n", config.MCPAgenticMode) + fmt.Printf("πŸ”§ Tool Execution: Manual approval required\n") + fmt.Println() + fmt.Println("Commands:") + fmt.Println(" /help - Show this help message") + fmt.Println(" /history - Show conversation history") + fmt.Println(" /clear - Clear conversation history") + fmt.Println(" /config - Show current configuration") + fmt.Println(" /provider - Switch provider") + fmt.Println(" /model - Switch model") + fmt.Println(" /quit - Exit the chatbot") + fmt.Println() + fmt.Println("Type your message and press Enter to chat!") + fmt.Println("When the assistant wants to use tools, you'll be asked to approve them.") + fmt.Println("==========================================") +} + +// printHelp prints help information +func printHelp() { + fmt.Println("\nπŸ“– Help") + fmt.Println("========") + fmt.Println("Available commands:") + fmt.Println(" /help - Show this help message") + fmt.Println(" /history - Show conversation history") + fmt.Println(" /clear - Clear conversation history (keeps system prompt)") + fmt.Println(" /config - Show current provider, model, and settings") + fmt.Println(" /provider - Switch between different AI providers") + fmt.Println(" /model - Switch between models for current provider") + fmt.Println(" /quit - Exit the chatbot") + fmt.Println() + fmt.Println("Supported providers:") + fmt.Println("β€’ OpenAI (gpt-4o-mini, gpt-4-turbo, gpt-4o)") + fmt.Println("β€’ Anthropic (claude models)") + fmt.Println("β€’ Bedrock (AWS hosted models)") + fmt.Println("β€’ Cohere (command models)") + fmt.Println("β€’ Azure (Azure OpenAI models)") + fmt.Println("β€’ Vertex (Google Cloud models)") + fmt.Println("β€’ Mistral (mistral models)") + fmt.Println("β€’ Ollama (local models)") + fmt.Println() + fmt.Println("Tool execution:") + fmt.Println("β€’ When the assistant wants to use tools, you'll be asked to approve them") + fmt.Println("β€’ You can review the tool names and arguments before approving") + fmt.Println("β€’ Available tools include web search and Context7") + fmt.Println("β€’ In agentic mode, tool results are synthesized into natural responses") + fmt.Println("β€’ In non-agentic mode, raw tool results are displayed") + fmt.Println() +} + +// stringPtr is a helper function to create string pointers +func stringPtr(s string) *string { + return &s +} + +// startLoader starts a loading spinner animation +func startLoader() (chan bool, *sync.WaitGroup) { + stopChan := make(chan bool) + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + spinner := []string{"β ‹", "β ™", "β Ή", "β Έ", "β Ό", "β ΄", "β ¦", "β §", "β ‡", "⠏"} + i := 0 + + for { + select { + case <-stopChan: + // Clear the spinner + fmt.Print("\r\033[K") // Clear current line + return + default: + fmt.Printf("\rπŸ€– Assistant: %s Thinking...", spinner[i%len(spinner)]) + i++ + time.Sleep(100 * time.Millisecond) + } + } + }() + + return stopChan, &wg +} + +// stopLoader stops the loading animation +func stopLoader(stopChan chan bool, wg *sync.WaitGroup) { + close(stopChan) + wg.Wait() +} + +func main() { + // Check for required environment variables + if os.Getenv("OPENAI_API_KEY") == "" { + fmt.Println("❌ Error: OPENAI_API_KEY environment variable is required") + fmt.Println("πŸ’‘ Set additional provider API keys to access more models:") + fmt.Println(" - ANTHROPIC_API_KEY for Claude models") + fmt.Println(" - COHERE_API_KEY for Cohere models") + fmt.Println(" - MISTRAL_API_KEY for Mistral models") + fmt.Println(" - AWS credentials for Bedrock") + fmt.Println(" - AZURE_API_KEY and AZURE_ENDPOINT for Azure OpenAI") + fmt.Println(" - VERTEX_PROJECT_ID and credentials for Vertex AI") + os.Exit(1) + } + + // Default configuration + config := ChatbotConfig{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + MCPAgenticMode: true, + MCPServerPort: 8585, + Temperature: bifrost.Ptr(0.7), + MaxTokens: bifrost.Ptr(1000), + } + + // Create chat session + fmt.Println("πŸš€ Starting Bifrost CLI Chatbot...") + session, err := NewChatSession(config) + if err != nil { + fmt.Printf("❌ Failed to create chat session: %v\n", err) + os.Exit(1) + } + + // Setup graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigChan + fmt.Println("\n\nπŸ‘‹ Goodbye! Cleaning up...") + session.Cleanup() + os.Exit(0) + }() + + // Give MCP servers time to initialize + fmt.Println("⏳ Waiting for MCP servers to initialize...") + time.Sleep(3 * time.Second) + + // Print welcome message + printWelcome(config) + + // Main chat loop + scanner := bufio.NewScanner(os.Stdin) + for { + fmt.Print("\nπŸ’¬ You: ") + if !scanner.Scan() { + break + } + + input := strings.TrimSpace(scanner.Text()) + if input == "" { + continue + } + + // Handle commands + switch input { + case "/help": + printHelp() + continue + case "/history": + session.PrintHistory() + continue + case "/clear": + // Keep system prompt but clear conversation history + systemPrompt := session.history[0] // Assuming first message is system + session.history = []schemas.BifrostMessage{systemPrompt} + fmt.Println("🧹 Conversation history cleared!") + continue + case "/config": + session.showCurrentConfig() + continue + case "/provider": + if err := session.switchProvider(); err != nil { + fmt.Printf("❌ Error switching provider: %v\n", err) + } + continue + case "/model": + if err := session.switchModel(); err != nil { + fmt.Printf("❌ Error switching model: %v\n", err) + } + continue + case "/quit": + fmt.Println("πŸ‘‹ Goodbye!") + session.Cleanup() + return + } + + // Send message and get response + response, err := session.SendMessage(input) + if err != nil { + fmt.Printf("\rπŸ€– Assistant: ❌ Error: %v\n", err) + continue + } + + fmt.Printf("πŸ€– Assistant: %s\n", response) + } + + // Cleanup + session.Cleanup() +} diff --git a/tests/core-providers/README.md b/tests/core-providers/README.md new file mode 100644 index 0000000000..85d3556e39 --- /dev/null +++ b/tests/core-providers/README.md @@ -0,0 +1,417 @@ +# Bifrost Core Providers Test Suite πŸš€ + +This directory contains comprehensive tests for all Bifrost AI providers, ensuring compatibility and functionality across different AI services. + +## πŸ“‹ Supported Providers + +- **OpenAI** - GPT models and function calling +- **Anthropic** - Claude models +- **Azure OpenAI** - Azure-hosted OpenAI models +- **AWS Bedrock** - Amazon's managed AI service +- **Cohere** - Cohere's language models +- **Google Vertex AI** - Google Cloud's AI platform +- **Mistral** - Mistral AI models with vision capabilities +- **Ollama** - Local LLM serving platform + +## πŸƒβ€β™‚οΈ Running Tests + +### Development with Local Bifrost Core + +To test changes with a forked or local version of bifrost-core: + +1. **Uncomment the replace directive** in `tests/core-providers/go.mod`: + + ```go + // Uncomment this line to use your local bifrost-core + replace github.com/maximhq/bifrost/core => ../../core + ``` + +2. **Update dependencies**: + + ```bash + cd tests/core-providers + go mod tidy + ``` + +3. **Run tests** with your local changes: + + ```bash + go test -v ./tests/core-providers/ + ``` + +⚠️ **Important**: Ensure your local `../../core` directory contains your bifrost-core implementation. The path should be relative to the `tests/core-providers` directory. + +### Prerequisites + +Set up environment variables for the providers you want to test: + +```bash +# OpenAI +export OPENAI_API_KEY="your-openai-key" + +# Anthropic +export ANTHROPIC_API_KEY="your-anthropic-key" + +# Azure OpenAI +export AZURE_API_KEY="your-azure-key" +export AZURE_ENDPOINT="your-azure-endpoint" + +# AWS Bedrock +export AWS_ACCESS_KEY_ID="your-aws-access-key" +export AWS_SECRET_ACCESS_KEY="your-aws-secret-key" +export AWS_REGION="us-east-1" + +# Cohere +export COHERE_API_KEY="your-cohere-key" + +# Google Vertex AI +export GOOGLE_APPLICATION_CREDENTIALS="path/to/service-account.json" +export GOOGLE_PROJECT_ID="your-project-id" + +# Mistral AI +export MISTRAL_API_KEY="your-mistral-key" + +# Ollama (local installation) +# No API key required - ensure Ollama is running locally +# Default endpoint: http://localhost:11434 +``` + +### Run All Provider Tests + +```bash +# Run all tests with verbose output (recommended) +go test -v ./tests/core-providers/ + +# Run with debug logs +go test -v ./tests/core-providers/ -debug +``` + +### Run Specific Provider Tests + +```bash +# Test only OpenAI +go test -v ./tests/core-providers/ -run TestOpenAI + +# Test only Anthropic +go test -v ./tests/core-providers/ -run TestAnthropic + +# Test only Azure +go test -v ./tests/core-providers/ -run TestAzure + +# Test only Bedrock +go test -v ./tests/core-providers/ -run TestBedrock + +# Test only Cohere +go test -v ./tests/core-providers/ -run TestCohere + +# Test only Vertex AI +go test -v ./tests/core-providers/ -run TestVertex + +# Test only Mistral +go test -v ./tests/core-providers/ -run TestMistral + +# Test only Ollama +go test -v ./tests/core-providers/ -run TestOllama +``` + +### Run Specific Test Scenarios + +You can run specific scenarios across all providers: + +```bash +# Test only chat completion +go test -v ./tests/core-providers/ -run "Chat" + +# Test only function calling +go test -v ./tests/core-providers/ -run "Function" +``` + +### Run Specific Scenario for Specific Provider + +You can combine provider and scenario filters to test specific functionality: + +```bash +# Test only OpenAI simple chat +go test -v ./tests/core-providers/ -run "TestOpenAI/SimpleChat" + +# Test only Anthropic tool calls +go test -v ./tests/core-providers/ -run "TestAnthropic/ToolCalls" + +# Test only Azure multi-turn conversation +go test -v ./tests/core-providers/ -run "TestAzure/MultiTurnConversation" + +# Test only Bedrock text completion +go test -v ./tests/core-providers/ -run "TestBedrock/TextCompletion" + +# Test only Cohere image URL processing +go test -v ./tests/core-providers/ -run "TestCohere/ImageURL" + +# Test only Vertex automatic function calling +go test -v ./tests/core-providers/ -run "TestVertex/AutomaticFunctionCalling" + +# Test only Mistral image processing +go test -v ./tests/core-providers/ -run "TestMistral/ImageURL" + +# Test only Ollama simple chat +go test -v ./tests/core-providers/ -run "TestOllama/SimpleChat" +``` + +**Available Scenario Names:** + +- `SimpleChat` - Basic chat completion +- `TextCompletion` - Text completion (legacy models) +- `MultiTurnConversation` - Multi-turn chat conversations +- `ToolCalls` - Basic function/tool calling +- `MultipleToolCalls` - Multiple tool calls in one request +- `End2EndToolCalling` - Complete tool calling workflow +- `AutomaticFunctionCalling` - Automatic function selection +- `ImageURL` - Image processing from URLs +- `ImageBase64` - Image processing from base64 +- `MultipleImages` - Multiple image processing +- `CompleteEnd2End` - Full end-to-end test +- `ProviderSpecific` - Provider-specific features + +## πŸ§ͺ Test Scenarios + +Each provider is tested against these scenarios when supported: + +βœ… **Supported by Most Providers:** + +- Simple Text Completion +- Simple Chat Completion +- Multi-turn Chat Conversation +- Chat with System Message +- Text Completion with Parameters +- Chat Completion with Parameters +- Error Handling (Invalid Model) +- Model Information Retrieval +- Simple Function Calling + +❌ **Provider-Specific Support:** + +- **Automatic Function Calling**: OpenAI, Anthropic, Bedrock, Azure, Vertex, Mistral, Ollama +- **Vision/Image Analysis**: OpenAI, Anthropic, Bedrock, Azure, Vertex, Mistral (limited support for Cohere and Ollama) +- **Text Completion**: Legacy models only (most providers now focus on chat completion) + +## πŸ“Š Understanding Test Output + +The test suite provides rich visual feedback: + +- πŸš€ **Test suite starting** +- βœ… **Successful operations and supported tests** +- ❌ **Failed operations and unsupported features** +- ⏭️ **Skipped scenarios (not supported by provider)** +- πŸ“Š **Summary statistics** +- ℹ️ **Informational notes** + +Example output: + +```text +=== RUN TestOpenAI +πŸš€ Starting comprehensive test suite for OpenAI provider... +βœ… Simple Text Completion test completed successfully +βœ… Simple Chat Completion test completed successfully +⏭️ Automatic Function Calling not supported by this provider +πŸ“Š Test Summary for OpenAI: +βœ…βœ… Supported Tests: 11 +❌ Unsupported Tests: 1 +``` + +## πŸ”§ Adding New Providers + +To add a new provider to the test suite: + +### 1. Create Provider Test File + +Create a new file `{provider}_test.go`: + +```go +package tests + +import ( + "testing" + "github.com/BifrostDev/bifrost/pkg/client" +) + +func TestNewProvider(t *testing.T) { + config := client.Config{ + Provider: "newprovider", + APIKey: getEnvVar("NEW_PROVIDER_API_KEY"), + // Add other required config fields + } + + // Skip if no API key provided + if config.APIKey == "" { + t.Skip("NEW_PROVIDER_API_KEY not set, skipping NewProvider tests") + } + + runProviderTests(t, config, "NewProvider") +} +``` + +### 2. Update Provider Configuration + +Add your provider's capabilities in `tests.go`: + +```go +func getProviderCapabilities(providerName string) ProviderCapabilities { + switch providerName { + case "NewProvider": + return ProviderCapabilities{ + SupportsTextCompletion: true, + SupportsChatCompletion: true, + SupportsFunctionCalling: false, // Update based on provider + SupportsAutomaticFunctions: false, + SupportsVision: false, + SupportsSystemMessages: true, + SupportsMultiTurn: true, + SupportsParameters: true, + SupportsModelInfo: true, + SupportsErrorHandling: true, + } + // ... other cases + } +} +``` + +### 3. Add Default Models + +Add default models for your provider: + +```go +func getDefaultModel(providerName string) string { + switch providerName { + case "NewProvider": + return "newprovider-model-name" + // ... other cases + } +} +``` + +### 4. Environment Variables + +Document any required environment variables in this README and ensure they're handled in the test setup. + +### 5. Test Your Implementation + +Run your new provider tests: + +```bash +go test -v ./tests/core-providers/ -run TestNewProvider +``` + +## πŸ› οΈ Troubleshooting + +### Common Issues + +1. **Tests being skipped**: Make sure environment variables are set correctly +2. **Connection timeouts**: Check your network connection and API endpoints +3. **Authentication errors**: Verify your API keys are valid and have proper permissions +4. **Missing logs**: Use `-v` flag to see detailed test output +5. **Rate limiting**: Some providers have rate limits; tests may need delays +6. **Ollama connection issues**: Ensure Ollama is running locally (`ollama serve`) +7. **Mistral vision failures**: Check if your account has access to Pixtral models + +### Debug Mode + +Enable debug logging to see detailed API interactions: + +```bash +go test -v ./tests/core-providers/ -debug +``` + +### Provider-Specific Considerations + +#### Mistral AI + +- **Models**: Uses `pixtral-12b-latest` for vision tasks +- **Capabilities**: Full support for chat, tools, and vision +- **API Key**: Required via `MISTRAL_API_KEY` environment variable + +#### Ollama + +- **Local Setup**: Requires Ollama to be running locally (default: `http://localhost:11434`) +- **Models**: Uses `llama3.2` model by default +- **No API Key**: Authentication not required for local instances +- **Limitations**: No vision/image processing support +- **Installation**: [Download from ollama.ai](https://ollama.ai/) and ensure the service is running + +### Checking Provider Status + +If a provider seems to be failing, you can check their status pages: + +- [OpenAI Status](https://status.openai.com/) +- [Anthropic Status](https://status.anthropic.com/) +- [Azure Status](https://status.azure.com/) +- [AWS Status](https://status.aws.amazon.com/) +- [Mistral Status](https://status.mistral.ai/) + +## πŸ“ Test Coverage + +The comprehensive test suite covers: + +- βœ… **Text Completion** - Legacy completion models (where supported) +- βœ… **Simple Chat** - Basic chat completion functionality +- βœ… **Multi-Turn Conversations** - Context maintenance across messages +- βœ… **Tool Calls** - Basic function/tool calling capabilities +- βœ… **Multiple Tool Calls** - Multiple tools in a single request +- βœ… **End-to-End Tool Calling** - Complete tool workflow with result integration +- βœ… **Automatic Function Calling** - Provider-managed tool execution +- βœ… **Image URL Processing** - Image analysis from URLs +- βœ… **Image Base64 Processing** - Image analysis from base64 encoded data +- βœ… **Multiple Images** - Multi-image analysis and comparison +- βœ… **Complete End-to-End** - Full multimodal workflows +- βœ… **Provider-Specific Features** - Integration-unique capabilities + +### Provider Capability Matrix + +| Provider | Chat | Tools | Vision | Text Completion | Auto Functions | +| --------- | ---- | ----- | ------ | --------------- | -------------- | +| OpenAI | βœ… | βœ… | βœ… | ❌ | βœ… | +| Anthropic | βœ… | βœ… | βœ… | βœ… | βœ… | +| Azure | βœ… | βœ… | βœ… | βœ… | βœ… | +| Bedrock | βœ… | βœ… | βœ… | βœ… | βœ… | +| Vertex | βœ… | βœ… | βœ… | ❌ | βœ… | +| Cohere | βœ… | βœ… | ❌ | ❌ | ❌ | +| Mistral | βœ… | βœ… | βœ… | ❌ | βœ… | +| Ollama | βœ… | βœ… | ❌ | ❌ | βœ… | + +## 🀝 Contributing + +When adding new providers or test scenarios: + +### Adding New Providers + +1. **Create test file**: Add `{provider}_test.go` following the existing pattern +2. **Update config**: Add provider configuration in `config/account.go`: + - Add to `GetKeysForProvider()` (if API key required) + - Add to `GetConfigForProvider()` + - Add to `GetConfiguredProviders()` list +3. **Test scenarios**: Configure supported scenarios in the test file +4. **Documentation**: Update this README with environment variables and capabilities +5. **Testing**: Test with multiple scenarios to verify integration + +### Adding New Test Scenarios + +1. **Implement scenario**: Add new test function in `scenarios/` directory +2. **Update structure**: Add scenario to `TestScenarios` struct in `config/account.go` +3. **Configure providers**: Update each provider's scenario configuration +4. **Update runner**: Add scenario call to `runAllComprehensiveTests()` in `tests.go` +5. **Documentation**: Update README with scenario description and examples + +### Testing Your Changes + +```bash +# Test specific provider +go test -v ./tests/core-providers/ -run TestYourProvider + +# Test all providers +go test -v ./tests/core-providers/ + +# Test with debug output +go test -v ./tests/core-providers/ -debug +``` + +## πŸ“„ License + +This test suite is part of the Bifrost project and follows the same license terms. diff --git a/tests/core-providers/anthropic_test.go b/tests/core-providers/anthropic_test.go new file mode 100644 index 0000000000..d4760255ce --- /dev/null +++ b/tests/core-providers/anthropic_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestAnthropic(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Cleanup() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Anthropic, + ChatModel: "claude-3-7-sonnet-20250219", + TextModel: "", // Anthropic doesn't support text completion + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/azure_test.go b/tests/core-providers/azure_test.go new file mode 100644 index 0000000000..a574703e8f --- /dev/null +++ b/tests/core-providers/azure_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestAzure(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Cleanup() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Azure, + ChatModel: "gpt-4o", + TextModel: "", // Azure OpenAI doesn't support text completion in newer models + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/bedrock_test.go b/tests/core-providers/bedrock_test.go new file mode 100644 index 0000000000..f20c50ebe6 --- /dev/null +++ b/tests/core-providers/bedrock_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestBedrock(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Cleanup() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Bedrock, + ChatModel: "anthropic.claude-3-sonnet-20240229-v1:0", + TextModel: "", // Bedrock Claude doesn't support text completion + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported for Claude + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: true, + MultipleImages: false, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/cohere_test.go b/tests/core-providers/cohere_test.go new file mode 100644 index 0000000000..a1dc284f17 --- /dev/null +++ b/tests/core-providers/cohere_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestCohere(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Cleanup() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Cohere, + ChatModel: "command-a-03-2025", + TextModel: "", // Cohere focuses on chat + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not typical for Cohere + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: false, // May not support automatic + ImageURL: false, // Check if supported + ImageBase64: false, // Check if supported + MultipleImages: false, // Check if supported + CompleteEnd2End: true, + ProviderSpecific: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/config/account.go b/tests/core-providers/config/account.go new file mode 100644 index 0000000000..8957d2ae77 --- /dev/null +++ b/tests/core-providers/config/account.go @@ -0,0 +1,415 @@ +// Package config provides comprehensive test account and configuration management for the Bifrost system. +// It implements account functionality for testing purposes, supporting multiple AI providers +// and comprehensive test scenarios. +package config + +import ( + "fmt" + "os" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/meta" +) + +// TestScenarios defines the comprehensive test scenarios +type TestScenarios struct { + TextCompletion bool + SimpleChat bool + MultiTurnConversation bool + ToolCalls bool + MultipleToolCalls bool + End2EndToolCalling bool + AutomaticFunctionCall bool + ImageURL bool + ImageBase64 bool + MultipleImages bool + CompleteEnd2End bool + ProviderSpecific bool +} + +// ComprehensiveTestConfig extends TestConfig with additional scenarios +type ComprehensiveTestConfig struct { + Provider schemas.ModelProvider + ChatModel string + TextModel string + Scenarios TestScenarios + CustomParams *schemas.ModelParameters + Fallbacks []schemas.Fallback + SkipReason string // Reason to skip certain tests +} + +// ComprehensiveTestAccount provides a test implementation of the Account interface for comprehensive testing. +type ComprehensiveTestAccount struct{} + +// getEnvWithDefault returns the value of the environment variable if set, otherwise returns the default value +func getEnvWithDefault(envVar, defaultValue string) string { + if value := os.Getenv(envVar); value != "" { + return value + } + return defaultValue +} + +// GetConfiguredProviders returns the list of initially supported providers. +func (account *ComprehensiveTestAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{ + schemas.OpenAI, + schemas.Anthropic, + schemas.Bedrock, + schemas.Cohere, + schemas.Azure, + schemas.Vertex, + schemas.Ollama, + schemas.Mistral, + }, nil +} + +// GetKeysForProvider returns the API keys and associated models for a given provider. +func (account *ComprehensiveTestAccount) GetKeysForProvider(providerKey schemas.ModelProvider) ([]schemas.Key, error) { + switch providerKey { + case schemas.OpenAI: + return []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{"gpt-4o-mini", "gpt-4-turbo"}, + Weight: 1.0, + }, + }, nil + case schemas.Anthropic: + return []schemas.Key{ + { + Value: os.Getenv("ANTHROPIC_API_KEY"), + Models: []string{"claude-3-7-sonnet-20250219", "claude-3-5-sonnet-20240620", "claude-2.1"}, + Weight: 1.0, + }, + }, nil + case schemas.Bedrock: + return []schemas.Key{ + { + Value: os.Getenv("BEDROCK_API_KEY"), + Models: []string{"anthropic.claude-v2:1", "mistral.mixtral-8x7b-instruct-v0:1", "mistral.mistral-large-2402-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"}, + Weight: 1.0, + }, + }, nil + case schemas.Cohere: + return []schemas.Key{ + { + Value: os.Getenv("COHERE_API_KEY"), + Models: []string{"command-a-03-2025", "c4ai-aya-vision-8b"}, + Weight: 1.0, + }, + }, nil + case schemas.Azure: + return []schemas.Key{ + { + Value: os.Getenv("AZURE_API_KEY"), + Models: []string{"gpt-4o"}, + Weight: 1.0, + }, + }, nil + case schemas.Vertex: + return []schemas.Key{ + { + Value: os.Getenv("VERTEX_API_KEY"), + Models: []string{"gemini-pro"}, + Weight: 1.0, + }, + }, nil + case schemas.Mistral: + return []schemas.Key{ + { + Value: os.Getenv("MISTRAL_API_KEY"), + Models: []string{"mistral-large-2411", "pixtral-12b-latest"}, + Weight: 1.0, + }, + }, nil + default: + return nil, fmt.Errorf("unsupported provider: %s", providerKey) + } +} + +// GetConfigForProvider returns the configuration settings for a given provider. +func (account *ComprehensiveTestAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + switch providerKey { + case schemas.OpenAI: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Anthropic: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Bedrock: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + MetaConfig: &meta.BedrockMetaConfig{ + SecretAccessKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), + Region: bifrost.Ptr(getEnvWithDefault("AWS_REGION", "us-east-1")), + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Cohere: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Azure: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + MetaConfig: &meta.AzureMetaConfig{ + Endpoint: os.Getenv("AZURE_ENDPOINT"), + Deployments: map[string]string{ + "gpt-4o": "gpt-4o-aug", + }, + // Use environment variable for API version with fallback to current preview version + // Note: This is a preview API version that may change over time. Update as needed. + // Set AZURE_API_VERSION environment variable to override the default. + APIVersion: bifrost.Ptr(getEnvWithDefault("AZURE_API_VERSION", "2024-08-01-preview")), + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Vertex: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + MetaConfig: &meta.VertexMetaConfig{ + ProjectID: os.Getenv("VERTEX_PROJECT_ID"), + Region: getEnvWithDefault("VERTEX_REGION", "us-central1"), + AuthCredentials: os.Getenv("VERTEX_CREDENTIALS"), + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil + case schemas.Ollama: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + BaseURL: getEnvWithDefault("OLLAMA_BASE_URL", "http://localhost:11434"), + }, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + case schemas.Mistral: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + default: + return nil, fmt.Errorf("unsupported provider: %s", providerKey) + } +} + +// AllProviderConfigs contains test configurations for all providers +var AllProviderConfigs = []ComprehensiveTestConfig{ + { + Provider: schemas.OpenAI, + ChatModel: "gpt-4o-mini", + TextModel: "", // OpenAI doesn't support text completion in newer models + Scenarios: TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"}, + }, + }, + { + Provider: schemas.Anthropic, + ChatModel: "claude-3-7-sonnet-20250219", + TextModel: "", // Anthropic doesn't support text completion + Scenarios: TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.Bedrock, + ChatModel: "anthropic.claude-3-sonnet-20240229-v1:0", + TextModel: "", // Bedrock Claude doesn't support text completion + Scenarios: TestScenarios{ + TextCompletion: false, // Not supported for Claude + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.Cohere, + ChatModel: "command-a-03-2025", + TextModel: "", // Cohere focuses on chat + Scenarios: TestScenarios{ + TextCompletion: false, // Not typical for Cohere + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: false, // May not support automatic + ImageURL: false, // Check if supported + ImageBase64: false, // Check if supported + MultipleImages: false, // Check if supported + CompleteEnd2End: true, + ProviderSpecific: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.Azure, + ChatModel: "gpt-4o", + TextModel: "", // Azure OpenAI doesn't support text completion in newer models + Scenarios: TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.Vertex, + ChatModel: "gemini-pro", + TextModel: "", // Vertex focuses on chat + Scenarios: TestScenarios{ + TextCompletion: false, // Not typical + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.Mistral, + ChatModel: "mistral-large-2411", + TextModel: "", // Mistral focuses on chat + Scenarios: TestScenarios{ + TextCompletion: false, // Not typical + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, + { + Provider: schemas.Ollama, + ChatModel: "llama3.2", + TextModel: "", // Ollama focuses on chat + Scenarios: TestScenarios{ + TextCompletion: false, // Not typical + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, + }, + }, +} diff --git a/tests/core-providers/config/setup.go b/tests/core-providers/config/setup.go new file mode 100644 index 0000000000..694dfde0b2 --- /dev/null +++ b/tests/core-providers/config/setup.go @@ -0,0 +1,59 @@ +// Package config provides comprehensive test utilities and configurations for the Bifrost system. +// It includes comprehensive test implementations covering all major AI provider scenarios, +// including text completion, chat, tool calling, image processing, and end-to-end workflows. +package config + +import ( + "context" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// Constants for test configuration +const ( + // TestTimeout defines the maximum duration for comprehensive tests + // Set to 5 minutes to allow for complex multi-step operations + TestTimeout = 5 * time.Minute +) + +// getBifrost initializes and returns a Bifrost instance for comprehensive testing. +// It sets up the comprehensive test account, plugin, and logger configuration. +// +// Environment variables are expected to be set by the system or test runner before calling this function. +// The account configuration will read API keys and settings from these environment variables. +// +// Returns: +// - *bifrost.Bifrost: A configured Bifrost instance ready for comprehensive testing +// - error: Any error that occurred during Bifrost initialization +// +// The function: +// 1. Creates a comprehensive test account instance +// 2. Configures Bifrost with the account and default logger +func getBifrost() (*bifrost.Bifrost, error) { + account := ComprehensiveTestAccount{} + + // Initialize Bifrost + b, err := bifrost.Init(schemas.BifrostConfig{ + Account: &account, + Plugins: nil, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), + }) + if err != nil { + return nil, err + } + + return b, nil +} + +// SetupTest initializes a test environment with timeout context +func SetupTest() (*bifrost.Bifrost, context.Context, context.CancelFunc, error) { + client, err := getBifrost() + if err != nil { + return nil, nil, nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + return client, ctx, cancel, nil +} diff --git a/tests/core-providers/go.mod b/tests/core-providers/go.mod new file mode 100644 index 0000000000..1ed870d853 --- /dev/null +++ b/tests/core-providers/go.mod @@ -0,0 +1,42 @@ +module github.com/maximhq/bifrost/tests/core-providers + +go 1.24.1 + +require ( + github.com/maximhq/bifrost/core v1.1.5 + github.com/stretchr/testify v1.10.0 +) + +require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect + github.com/andybalholm/brotli v1.1.1 // indirect + github.com/aws/aws-sdk-go-v2 v1.36.3 // indirect + github.com/aws/aws-sdk-go-v2/config v1.29.14 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect + github.com/aws/smithy-go v1.22.3 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/goccy/go-json v0.10.5 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/mark3labs/mcp-go v0.32.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/cast v1.7.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.60.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/net v0.39.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/text v0.24.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +// replace github.com/maximhq/bifrost/core => ../../core diff --git a/tests/core-providers/go.sum b/tests/core-providers/go.sum new file mode 100644 index 0000000000..baed34d884 --- /dev/null +++ b/tests/core-providers/go.sum @@ -0,0 +1,76 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= +github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= +github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= +github.com/aws/aws-sdk-go-v2/config v1.29.14 h1:f+eEi/2cKCg9pqKBoAIwRGzVb70MRKqWX4dg1BDcSJM= +github.com/aws/aws-sdk-go-v2/config v1.29.14/go.mod h1:wVPHWcIFv3WO89w0rE10gzf17ZYy+UVS1Geq8Iei34g= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67 h1:9KxtdcIA/5xPNQyZRgUSpYOE6j9Bc4+D7nZua0KGYOM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.67/go.mod h1:p3C44m+cfnbv763s52gCqrjaqyPikj9Sg47kUVaNZQQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3 h1:1Gw+9ajCV1jogloEv1RRnvfRFia2cL6c9cuKV2Ps+G8= +github.com/aws/aws-sdk-go-v2/service/sso v1.25.3/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 h1:hXmVKytPfTy5axZ+fYbR5d0cFmC3JvwLm5kM83luako= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/XvaX32evhproijJEZY= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= +github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= +github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= +github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/maximhq/bifrost/core v1.1.5 h1:Nm9XlS9Nso+pn+U5/btsJD8qRDYGQ1BBOjgqWT3PYSc= +github.com/maximhq/bifrost/core v1.1.5/go.mod h1:yMRCncTgKYBIrECSRVxMbY3BL8CjLbipJlc644jryxc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw= +github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/tests/core-providers/mistral_test.go b/tests/core-providers/mistral_test.go new file mode 100644 index 0000000000..57a7acb637 --- /dev/null +++ b/tests/core-providers/mistral_test.go @@ -0,0 +1,40 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestMistral(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Cleanup() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Mistral, + ChatModel: "pixtral-12b-latest", + TextModel: "", // Mistral doesn't support text completion in newer models + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/ollama_test.go b/tests/core-providers/ollama_test.go new file mode 100644 index 0000000000..2d69dd3dae --- /dev/null +++ b/tests/core-providers/ollama_test.go @@ -0,0 +1,40 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestOllama(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Cleanup() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Ollama, + ChatModel: "llama3.2", + TextModel: "", // Ollama doesn't support text completion in newer models + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/openai_test.go b/tests/core-providers/openai_test.go new file mode 100644 index 0000000000..e944a748df --- /dev/null +++ b/tests/core-providers/openai_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestOpenAI(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Cleanup() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.OpenAI, + ChatModel: "gpt-4o-mini", + TextModel: "", // OpenAI doesn't support text completion in newer models + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"}, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/core-providers/scenarios/automatic_function_calling.go b/tests/core-providers/scenarios/automatic_function_calling.go new file mode 100644 index 0000000000..08d9b41916 --- /dev/null +++ b/tests/core-providers/scenarios/automatic_function_calling.go @@ -0,0 +1,76 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +// RunAutomaticFunctionCallingTest executes the automatic function calling test scenario +func RunAutomaticFunctionCallingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.AutomaticFunctionCall { + t.Logf("Automatic function calling not supported for provider %s", testConfig.Provider) + return + } + + t.Run("AutomaticFunctionCalling", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + CreateBasicChatMessage("Get the current time in UTC timezone"), + } + + params := MergeModelParameters(&schemas.ModelParameters{ + Tools: &[]schemas.Tool{TimeToolDefinition}, + ToolChoice: &schemas.ToolChoice{ + ToolChoiceStruct: &schemas.ToolChoiceStruct{ + Type: schemas.ToolChoiceTypeFunction, + Function: schemas.ToolChoiceFunction{ + Name: "get_current_time", + }, + }, + }, + MaxTokens: bifrost.Ptr(150), + }, testConfig.CustomParams) + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: params, + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Automatic function calling failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + // Find at least one choice with valid tool calls + foundValidToolCall := false + for i, choice := range response.Choices { + message := choice.Message + if message.AssistantMessage != nil && message.AssistantMessage.ToolCalls != nil { + toolCalls := *message.AssistantMessage.ToolCalls + // Iterate through all tool calls, not just the first one + for j, toolCall := range toolCalls { + if toolCall.Function.Name != nil && *toolCall.Function.Name == "get_current_time" { + foundValidToolCall = true + t.Logf("βœ… Automatic function call for choice %d, tool call %d: %s", i, j, toolCall.Function.Arguments) + break // Found valid tool call, can break from this inner loop + } + } + if foundValidToolCall { + break // Found valid tool call, can break from choices loop + } + } + } + + require.True(t, foundValidToolCall, "Expected at least one choice to have automatic tool call for 'get_current_time'. Response: %s", GetResultContent(response)) + }) +} diff --git a/tests/core-providers/scenarios/complete_end_to_end.go b/tests/core-providers/scenarios/complete_end_to_end.go new file mode 100644 index 0000000000..880485a066 --- /dev/null +++ b/tests/core-providers/scenarios/complete_end_to_end.go @@ -0,0 +1,118 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunCompleteEnd2EndTest executes the complete end-to-end test scenario +func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.CompleteEnd2End { + t.Logf("Complete end-to-end not supported for provider %s", testConfig.Provider) + return + } + + t.Run("CompleteEnd2End", func(t *testing.T) { + // Multi-step conversation with tools and images + userMessage1 := CreateBasicChatMessage("Hi, I'm planning a trip. Can you help me get the weather in Paris?") + + request1 := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{userMessage1}, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + Tools: &[]schemas.Tool{WeatherToolDefinition}, + MaxTokens: bifrost.Ptr(150), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response1, err := client.ChatCompletionRequest(ctx, request1) + require.Nilf(t, err, "First end-to-end request failed: %v", err) + require.NotNil(t, response1) + require.NotEmpty(t, response1.Choices) + + t.Logf("βœ… First response: %s", GetResultContent(response1)) + + // If tool was called, simulate result and continue conversation + var conversationHistory []schemas.BifrostMessage + conversationHistory = append(conversationHistory, userMessage1) + + // Add all choice messages to conversation history + for _, choice := range response1.Choices { + conversationHistory = append(conversationHistory, choice.Message) + } + + // Find any choice with tool calls for processing + var selectedToolCall *schemas.ToolCall + for _, choice := range response1.Choices { + message := choice.Message + if message.AssistantMessage != nil && message.AssistantMessage.ToolCalls != nil { + toolCalls := *message.AssistantMessage.ToolCalls + // Look for a valid weather tool call + for _, toolCall := range toolCalls { + if toolCall.Function.Name != nil && *toolCall.Function.Name == "get_weather" { + selectedToolCall = &toolCall + break + } + } + if selectedToolCall != nil { + break + } + } + } + + // If a tool call was found, simulate the result + if selectedToolCall != nil { + // Simulate tool result + toolResult := `{"temperature": "18", "unit": "celsius", "description": "Partly cloudy", "humidity": "70%"}` + toolCallID := "" + if selectedToolCall.ID != nil { + toolCallID = *selectedToolCall.ID + } else if selectedToolCall.Function.Name != nil { + toolCallID = *selectedToolCall.Function.Name + } + require.NotEmpty(t, toolCallID, "toolCallID must not be empty – provider did not return ID or Function.Name") + toolMessage := CreateToolMessage(toolResult, toolCallID) + conversationHistory = append(conversationHistory, toolMessage) + } + + // Continue with follow-up + followUpMessage := CreateBasicChatMessage("Thanks! Now can you tell me about this travel image?") + if testConfig.Scenarios.ImageURL { + followUpMessage = CreateImageMessage("Thanks! Now can you tell me what you see in this travel-related image?", TestImageURL) + } + conversationHistory = append(conversationHistory, followUpMessage) + + finalRequest := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &conversationHistory, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(200), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + finalResponse, err := client.ChatCompletionRequest(ctx, finalRequest) + require.Nilf(t, err, "Final end-to-end request failed: %v", err) + require.NotNil(t, finalResponse) + require.NotEmpty(t, finalResponse.Choices) + + finalContent := GetResultContent(finalResponse) + assert.NotEmpty(t, finalContent, "Final response content should not be empty") + + t.Logf("βœ… Complete end-to-end result: %s", finalContent) + }) +} diff --git a/tests/core-providers/scenarios/end_to_end_tool_calling.go b/tests/core-providers/scenarios/end_to_end_tool_calling.go new file mode 100644 index 0000000000..9995b61cc7 --- /dev/null +++ b/tests/core-providers/scenarios/end_to_end_tool_calling.go @@ -0,0 +1,121 @@ +package scenarios + +import ( + "context" + "strings" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunEnd2EndToolCallingTest executes the end-to-end tool calling test scenario +func RunEnd2EndToolCallingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.End2EndToolCalling { + t.Logf("End-to-end tool calling not supported for provider %s", testConfig.Provider) + return + } + + t.Run("End2EndToolCalling", func(t *testing.T) { + // Step 1: User asks for weather + userMessage := CreateBasicChatMessage("What's the weather in San Francisco?") + + params := MergeModelParameters(&schemas.ModelParameters{ + Tools: &[]schemas.Tool{WeatherToolDefinition}, + MaxTokens: bifrost.Ptr(150), + }, testConfig.CustomParams) + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{userMessage}, + }, + Params: params, + Fallbacks: testConfig.Fallbacks, + } + + // Execute first request + firstResponse, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "First request failed: %v", err) + require.NotNil(t, firstResponse) + require.NotEmpty(t, firstResponse.Choices) + + // Find a choice with valid tool calls + var toolCall schemas.ToolCall + foundValidChoice := false + + for _, choice := range firstResponse.Choices { + if choice.Message.AssistantMessage != nil && + choice.Message.AssistantMessage.ToolCalls != nil && + len(*choice.Message.AssistantMessage.ToolCalls) > 0 { + + firstToolCall := (*choice.Message.AssistantMessage.ToolCalls)[0] + if firstToolCall.Function.Name != nil && *firstToolCall.Function.Name == "get_weather" { + toolCall = firstToolCall + foundValidChoice = true + break + } + } + } + + require.True(t, foundValidChoice, "Expected at least one choice to have valid tool call for 'get_weather'") + + // Step 2: Simulate tool execution and provide result + toolResult := `{"temperature": "22", "unit": "celsius", "description": "Sunny with light clouds", "humidity": "65%"}` + + toolCallID := "" + if toolCall.ID != nil { + toolCallID = *toolCall.ID + } else { + toolCallID = *toolCall.Function.Name + } + + require.NotEmpty(t, toolCallID, "toolCallID must not be empty") + + // Build conversation history with all choice messages from first response + conversationMessages := []schemas.BifrostMessage{ + userMessage, + } + + // Add all choice messages from the first response + for _, choice := range firstResponse.Choices { + conversationMessages = append(conversationMessages, choice.Message) + } + + // Add the tool result message + conversationMessages = append(conversationMessages, CreateToolMessage(toolResult, toolCallID)) + + secondRequest := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &conversationMessages, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(200), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + // Execute second request + finalResponse, err := client.ChatCompletionRequest(ctx, secondRequest) + require.Nilf(t, err, "Second request failed: %v", err) + require.NotNil(t, finalResponse) + require.NotEmpty(t, finalResponse.Choices) + + content := GetResultContent(finalResponse) + require.NotEmpty(t, content, "Response content should not be empty") + + // Verify response contains expected information + assert.Contains(t, strings.ToLower(content), "san francisco", "Response should mention San Francisco") + assert.Contains(t, content, "22", "Response should mention temperature") + assert.Contains(t, strings.ToLower(content), "sunny", "Response should mention weather description") + + t.Logf("βœ… End-to-end tool calling result: %s", content) + }) +} diff --git a/tests/core-providers/scenarios/image_base64.go b/tests/core-providers/scenarios/image_base64.go new file mode 100644 index 0000000000..b10655d6f0 --- /dev/null +++ b/tests/core-providers/scenarios/image_base64.go @@ -0,0 +1,49 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunImageBase64Test executes the image base64 test scenario +func RunImageBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.ImageBase64 { + t.Logf("Image base64 not supported for provider %s", testConfig.Provider) + return + } + + t.Run("ImageBase64", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + CreateImageMessage("Describe this image briefly", TestImageBase64), + } + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(200), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Image base64 test failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + content := GetResultContent(response) + assert.NotEmpty(t, content, "Response content should not be empty") + + t.Logf("βœ… Image base64 result: %s", content) + }) +} diff --git a/tests/core-providers/scenarios/image_url.go b/tests/core-providers/scenarios/image_url.go new file mode 100644 index 0000000000..9f11d63289 --- /dev/null +++ b/tests/core-providers/scenarios/image_url.go @@ -0,0 +1,55 @@ +package scenarios + +import ( + "context" + "strings" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunImageURLTest executes the image URL test scenario +func RunImageURLTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.ImageURL { + t.Logf("Image URL not supported for provider %s", testConfig.Provider) + return + } + + t.Run("ImageURL", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + CreateImageMessage("What do you see in this image?", TestImageURL), + } + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(200), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Image URL test failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + content := GetResultContent(response) + assert.NotEmpty(t, content, "Response content should not be empty") + // Should mention something about the ant in the image + lowerContent := strings.ToLower(content) + assert.True(t, strings.Contains(lowerContent, "ant") || + strings.Contains(lowerContent, "insect"), + "Response should identify the ant/insect in the image") + + t.Logf("βœ… Image URL result: %s", content) + }) +} diff --git a/tests/core-providers/scenarios/multi_turn_conversation.go b/tests/core-providers/scenarios/multi_turn_conversation.go new file mode 100644 index 0000000000..10512578bc --- /dev/null +++ b/tests/core-providers/scenarios/multi_turn_conversation.go @@ -0,0 +1,84 @@ +package scenarios + +import ( + "context" + "strings" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunMultiTurnConversationTest executes the multi-turn conversation test scenario +func RunMultiTurnConversationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.MultiTurnConversation { + t.Logf("Multi-turn conversation not supported for provider %s", testConfig.Provider) + return + } + + t.Run("MultiTurnConversation", func(t *testing.T) { + // First message + userMessage1 := CreateBasicChatMessage("My name is Alice. Remember this.") + messages1 := []schemas.BifrostMessage{ + userMessage1, + } + + firstRequest := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages1, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(150), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response1, err := client.ChatCompletionRequest(ctx, firstRequest) + require.Nilf(t, err, "First conversation turn failed: %v", err) + require.NotNil(t, response1) + require.NotEmpty(t, response1.Choices) + + // Second message with conversation history + // Build conversation history with all choice messages + messages2 := []schemas.BifrostMessage{ + userMessage1, + } + + // Add all choice messages from the first response + for _, choice := range response1.Choices { + messages2 = append(messages2, choice.Message) + } + + // Add the follow-up question + messages2 = append(messages2, CreateBasicChatMessage("What's my name?")) + + secondRequest := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages2, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(150), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response2, err := client.ChatCompletionRequest(ctx, secondRequest) + require.Nilf(t, err, "Second conversation turn failed: %v", err) + require.NotNil(t, response2) + require.NotEmpty(t, response2.Choices) + + content := GetResultContent(response2) + assert.NotEmpty(t, content, "Response content should not be empty") + // Check if the model remembered the name + assert.Contains(t, strings.ToLower(content), "alice", "Model should remember the name Alice") + t.Logf("βœ… Multi-turn conversation result: %s", content) + }) +} diff --git a/tests/core-providers/scenarios/multiple_images.go b/tests/core-providers/scenarios/multiple_images.go new file mode 100644 index 0000000000..ba8d70a2b4 --- /dev/null +++ b/tests/core-providers/scenarios/multiple_images.go @@ -0,0 +1,71 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunMultipleImagesTest executes the multiple images test scenario +func RunMultipleImagesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.MultipleImages { + t.Logf("Multiple images not supported for provider %s", testConfig.Provider) + return + } + + t.Run("MultipleImages", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + { + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentBlocks: &[]schemas.ContentBlock{ + { + Type: schemas.ContentBlockTypeText, + Text: bifrost.Ptr("Compare these two images - what are the similarities and differences?"), + }, + { + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: TestImageURL, + }, + }, + { + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: TestImageBase64, + }, + }, + }, + }, + }, + } + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(300), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Multiple images test failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + content := GetResultContent(response) + assert.NotEmpty(t, content, "Response content should not be empty") + + t.Logf("βœ… Multiple images result: %s", content) + }) +} diff --git a/tests/core-providers/scenarios/multiple_tool_calls.go b/tests/core-providers/scenarios/multiple_tool_calls.go new file mode 100644 index 0000000000..dd65e663c3 --- /dev/null +++ b/tests/core-providers/scenarios/multiple_tool_calls.go @@ -0,0 +1,109 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "slices" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +// getKeysFromMap returns the keys of a map[string]bool as a slice +func getKeysFromMap(m map[string]bool) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + +// RunMultipleToolCallsTest executes the multiple tool calls test scenario +func RunMultipleToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.MultipleToolCalls { + t.Logf("Multiple tool calls not supported for provider %s", testConfig.Provider) + return + } + + t.Run("MultipleToolCalls", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + CreateBasicChatMessage("I need to know the weather in London and also calculate 15 * 23. Can you help with both?"), + } + + params := MergeModelParameters(&schemas.ModelParameters{ + Tools: &[]schemas.Tool{WeatherToolDefinition, CalculatorToolDefinition}, + MaxTokens: bifrost.Ptr(200), + }, testConfig.CustomParams) + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: params, + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Multiple tool calls failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + // Find at least one choice with multiple valid tool calls + expectedToolNames := []string{"get_weather", "calculate"} + foundValidMultipleToolCalls := false + for choiceIdx, choice := range response.Choices { + message := choice.Message + if message.AssistantMessage != nil && message.AssistantMessage.ToolCalls != nil { + toolCalls := *message.AssistantMessage.ToolCalls + if len(toolCalls) >= 2 { + validToolCalls := 0 + foundToolNames := make(map[string]bool) + + for _, toolCall := range toolCalls { + if toolCall.Function.Name != nil { + toolName := *toolCall.Function.Name + // Check if this is one of the expected tool names + isExpected := false + for _, expectedName := range expectedToolNames { + if toolName == expectedName { + isExpected = true + foundToolNames[toolName] = true + break + } + } + if isExpected { + validToolCalls++ + } + } + } + + // Require at least 2 valid tool calls with expected names + if validToolCalls >= 2 { + foundValidMultipleToolCalls = true + t.Logf("βœ… Number of tool calls for choice %d: %d", choiceIdx, len(toolCalls)) + t.Logf("βœ… Found expected tools: %v", getKeysFromMap(foundToolNames)) + + for i, toolCall := range toolCalls { + if toolCall.Function.Name != nil { + toolName := *toolCall.Function.Name + // Validate that each tool name is expected + isExpected := slices.Contains(expectedToolNames, toolName) + require.True(t, isExpected, "Unexpected tool call '%s' - expected one of %v", toolName, expectedToolNames) + t.Logf("βœ… Tool call %d for choice %d: %s with args: %s", i+1, choiceIdx, toolName, toolCall.Function.Arguments) + } + } + break // Found a valid choice with multiple tool calls + } + } + } + } + + require.True(t, foundValidMultipleToolCalls, "Expected at least one choice to have 2 or more valid tool calls. Response: %s", GetResultContent(response)) + }) +} diff --git a/tests/core-providers/scenarios/provider_specific.go b/tests/core-providers/scenarios/provider_specific.go new file mode 100644 index 0000000000..0c4b3bff0c --- /dev/null +++ b/tests/core-providers/scenarios/provider_specific.go @@ -0,0 +1,55 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunProviderSpecificTest executes the provider-specific test scenario +func RunProviderSpecificTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.ProviderSpecific { + t.Logf("Provider-specific tests not configured for provider %s", testConfig.Provider) + return + } + + t.Run("ProviderSpecific", func(t *testing.T) { + // This would contain provider-specific tests + // For now, we'll do a basic functionality test + messages := []schemas.BifrostMessage{ + CreateBasicChatMessage("Test provider-specific functionality. What makes you unique?"), + } + + // Initialize with default parameters and merge with custom parameters + defaultParams := &schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(150), + } + params := MergeModelParameters(defaultParams, testConfig.CustomParams) + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: params, + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Provider-specific test failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + content := GetResultContent(response) + assert.NotEmpty(t, content, "Response content should not be empty") + + t.Logf("βœ… Provider-specific result for %s: %s", testConfig.Provider, content) + }) +} diff --git a/tests/core-providers/scenarios/simple_chat.go b/tests/core-providers/scenarios/simple_chat.go new file mode 100644 index 0000000000..5665e4fc63 --- /dev/null +++ b/tests/core-providers/scenarios/simple_chat.go @@ -0,0 +1,48 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunSimpleChatTest executes the simple chat test scenario +func RunSimpleChatTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.SimpleChat { + t.Logf("Simple chat not supported for provider %s", testConfig.Provider) + return + } + + t.Run("SimpleChat", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + CreateBasicChatMessage("Hello! What's the capital of France?"), + } + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(150), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Simple chat failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + content := GetResultContent(response) + assert.NotEmpty(t, content, "Response content should not be empty") + t.Logf("βœ… Simple chat result: %s", content) + }) +} diff --git a/tests/core-providers/scenarios/text_completion.go b/tests/core-providers/scenarios/text_completion.go new file mode 100644 index 0000000000..ead8f5fe48 --- /dev/null +++ b/tests/core-providers/scenarios/text_completion.go @@ -0,0 +1,45 @@ +package scenarios + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RunTextCompletionTest tests text completion functionality +func RunTextCompletionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.TextCompletion || testConfig.TextModel == "" { + t.Logf("⏭️ Text completion not supported for provider %s", testConfig.Provider) + return + } + + t.Run("TextCompletion", func(t *testing.T) { + prompt := "The future of artificial intelligence is" + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.TextModel, + Input: schemas.RequestInput{ + TextCompletionInput: &prompt, + }, + Params: MergeModelParameters(&schemas.ModelParameters{ + MaxTokens: bifrost.Ptr(100), + }, testConfig.CustomParams), + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.TextCompletionRequest(ctx, request) + require.Nilf(t, err, "Text completion failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + content := GetResultContent(response) + assert.NotEmpty(t, content, "Response content should not be empty") + t.Logf("βœ… Text completion result: %s", content) + }) +} diff --git a/tests/core-providers/scenarios/tool_calls.go b/tests/core-providers/scenarios/tool_calls.go new file mode 100644 index 0000000000..7cb30edce1 --- /dev/null +++ b/tests/core-providers/scenarios/tool_calls.go @@ -0,0 +1,79 @@ +package scenarios + +import ( + "context" + "encoding/json" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +// RunToolCallsTest executes the tool calls test scenario +func RunToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if !testConfig.Scenarios.ToolCalls { + t.Logf("Tool calls not supported for provider %s", testConfig.Provider) + return + } + + t.Run("ToolCalls", func(t *testing.T) { + messages := []schemas.BifrostMessage{ + CreateBasicChatMessage("What's the weather like in New York?"), + } + + params := MergeModelParameters(&schemas.ModelParameters{ + Tools: &[]schemas.Tool{WeatherToolDefinition}, + MaxTokens: bifrost.Ptr(150), + }, testConfig.CustomParams) + + request := &schemas.BifrostRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: schemas.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: params, + Fallbacks: testConfig.Fallbacks, + } + + response, err := client.ChatCompletionRequest(ctx, request) + require.Nilf(t, err, "Tool calls failed: %v", err) + require.NotNil(t, response) + require.NotEmpty(t, response.Choices) + + // Find at least one choice with valid tool calls + foundValidToolCall := false + for i, choice := range response.Choices { + message := choice.Message + if message.AssistantMessage != nil && message.AssistantMessage.ToolCalls != nil { + toolCalls := *message.AssistantMessage.ToolCalls + // Iterate through all tool calls, not just the first one + for j, toolCall := range toolCalls { + if toolCall.Function.Name != nil && *toolCall.Function.Name == "get_weather" { + // Verify arguments contain location + var args map[string]interface{} + err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args) + if err == nil { + if _, hasLocation := args["location"]; hasLocation { + foundValidToolCall = true + t.Logf("βœ… Tool call arguments for choice %d, tool call %d: %s", i, j, toolCall.Function.Arguments) + break // Found valid tool call, can break from this inner loop + } + } + } + } + if foundValidToolCall { + break // Found valid tool call, can break from choices loop + } + } + } + + if !foundValidToolCall { + t.Logf("❌ No valid tool calls found in any choice, response: %s", GetResultContent(response)) + } + require.True(t, foundValidToolCall, "Expected at least one choice to have valid tool call for 'get_weather' with 'location' argument") + }) +} diff --git a/tests/core-providers/scenarios/utils.go b/tests/core-providers/scenarios/utils.go new file mode 100644 index 0000000000..9776aafeb6 --- /dev/null +++ b/tests/core-providers/scenarios/utils.go @@ -0,0 +1,229 @@ +package scenarios + +import ( + "strings" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// Tool definitions for testing +var WeatherToolDefinition = schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": map[string]interface{}{ + "type": "string", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, +} + +var CalculatorToolDefinition = schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: "calculate", + Description: "Perform basic mathematical calculations", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "expression": map[string]interface{}{ + "type": "string", + "description": "The mathematical expression to evaluate, e.g. '2 + 3' or '10 * 5'", + }, + }, + Required: []string{"expression"}, + }, + }, +} + +var TimeToolDefinition = schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: "get_current_time", + Description: "Get the current time in a specific timezone", + Parameters: schemas.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "timezone": map[string]interface{}{ + "type": "string", + "description": "The timezone identifier, e.g. 'America/New_York' or 'UTC'", + }, + }, + Required: []string{"timezone"}, + }, + }, +} + +// Test images for testing +const TestImageURL = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg" +const TestImageBase64 = "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAYABgAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAAIAAoDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAb/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAX/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCdABmX/9k=" + +// Helper functions for creating requests +func CreateBasicChatMessage(content string) schemas.BifrostMessage { + return schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr(content), + }, + } +} + +func CreateImageMessage(text, imageURL string) schemas.BifrostMessage { + return schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleUser, + Content: schemas.MessageContent{ + ContentBlocks: &[]schemas.ContentBlock{ + { + Type: schemas.ContentBlockTypeText, + Text: bifrost.Ptr(text), + }, + { + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: imageURL, + }, + }, + }, + }, + } +} + +func CreateToolMessage(content string, toolCallID string) schemas.BifrostMessage { + return schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleTool, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr(content), + }, + ToolMessage: &schemas.ToolMessage{ + ToolCallID: &toolCallID, + }, + } +} + +// GetResultContent returns the string content from a BifrostResponse +// It looks through all choices and returns content from the first choice that has any +func GetResultContent(result *schemas.BifrostResponse) string { + if result == nil || len(result.Choices) == 0 { + return "" + } + + // Try to find content from any choice, prioritizing non-empty content + for _, choice := range result.Choices { + if choice.Message.Content.ContentStr != nil && *choice.Message.Content.ContentStr != "" { + return *choice.Message.Content.ContentStr + } else if choice.Message.Content.ContentBlocks != nil { + var builder strings.Builder + for _, block := range *choice.Message.Content.ContentBlocks { + if block.Text != nil { + builder.WriteString(*block.Text) + } + } + content := builder.String() + if content != "" { + return content + } + } + } + + // Fallback to first choice if no content found + if result.Choices[0].Message.Content.ContentStr != nil { + return *result.Choices[0].Message.Content.ContentStr + } else if result.Choices[0].Message.Content.ContentBlocks != nil { + var builder strings.Builder + for _, block := range *result.Choices[0].Message.Content.ContentBlocks { + if block.Text != nil { + builder.WriteString(*block.Text) + } + } + return builder.String() + } + return "" +} + +// MergeModelParameters performs a shallow merge of two ModelParameters instances. +// Non-nil fields from the override parameter take precedence over the base parameter. +// Returns a new ModelParameters instance with the merged values. +func MergeModelParameters(base *schemas.ModelParameters, override *schemas.ModelParameters) *schemas.ModelParameters { + if base == nil && override == nil { + return &schemas.ModelParameters{} + } + if base == nil { + return copyModelParameters(override) + } + if override == nil { + return copyModelParameters(base) + } + + // Start with a copy of base parameters + result := copyModelParameters(base) + + // Override with non-nil fields from override + if override.MaxTokens != nil { + result.MaxTokens = override.MaxTokens + } + if override.Temperature != nil { + result.Temperature = override.Temperature + } + if override.TopP != nil { + result.TopP = override.TopP + } + if override.TopK != nil { + result.TopK = override.TopK + } + if override.FrequencyPenalty != nil { + result.FrequencyPenalty = override.FrequencyPenalty + } + if override.PresencePenalty != nil { + result.PresencePenalty = override.PresencePenalty + } + if override.StopSequences != nil { + result.StopSequences = override.StopSequences + } + if override.Tools != nil { + result.Tools = override.Tools + } + if override.ToolChoice != nil { + result.ToolChoice = override.ToolChoice + } + if override.ParallelToolCalls != nil { + result.ParallelToolCalls = override.ParallelToolCalls + } + if override.ExtraParams != nil { + result.ExtraParams = override.ExtraParams + } + + return result +} + +// copyModelParameters creates a deep copy of a ModelParameters instance +func copyModelParameters(src *schemas.ModelParameters) *schemas.ModelParameters { + if src == nil { + return &schemas.ModelParameters{} + } + + return &schemas.ModelParameters{ + MaxTokens: src.MaxTokens, + Temperature: src.Temperature, + TopP: src.TopP, + TopK: src.TopK, + FrequencyPenalty: src.FrequencyPenalty, + PresencePenalty: src.PresencePenalty, + StopSequences: src.StopSequences, + Tools: src.Tools, + ToolChoice: src.ToolChoice, + ParallelToolCalls: src.ParallelToolCalls, + ExtraParams: src.ExtraParams, + } +} diff --git a/tests/core-providers/tests.go b/tests/core-providers/tests.go new file mode 100644 index 0000000000..e229c0b583 --- /dev/null +++ b/tests/core-providers/tests.go @@ -0,0 +1,97 @@ +package tests + +import ( + "context" + "strings" + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + "github.com/maximhq/bifrost/tests/core-providers/scenarios" + + bifrost "github.com/maximhq/bifrost/core" +) + +// TestScenarioFunc defines the function signature for test scenario functions +type TestScenarioFunc func(*testing.T, *bifrost.Bifrost, context.Context, config.ComprehensiveTestConfig) + +// runAllComprehensiveTests executes all comprehensive test scenarios for a given configuration +func runAllComprehensiveTests(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig config.ComprehensiveTestConfig) { + if testConfig.SkipReason != "" { + t.Skipf("Skipping %s: %s", testConfig.Provider, testConfig.SkipReason) + return + } + + t.Logf("πŸš€ Running comprehensive tests for provider: %s", testConfig.Provider) + + // Define all test scenario functions in a slice + testScenarios := []TestScenarioFunc{ + scenarios.RunTextCompletionTest, + scenarios.RunSimpleChatTest, + scenarios.RunMultiTurnConversationTest, + scenarios.RunToolCallsTest, + scenarios.RunMultipleToolCallsTest, + scenarios.RunEnd2EndToolCallingTest, + scenarios.RunAutomaticFunctionCallingTest, + scenarios.RunImageURLTest, + scenarios.RunImageBase64Test, + scenarios.RunMultipleImagesTest, + scenarios.RunCompleteEnd2EndTest, + scenarios.RunProviderSpecificTest, + } + + // Execute all test scenarios + for _, scenarioFunc := range testScenarios { + scenarioFunc(t, client, ctx, testConfig) + } + + // Print comprehensive summary based on configuration + printTestSummary(t, testConfig) +} + +// printTestSummary prints a detailed summary of all test scenarios +func printTestSummary(t *testing.T, testConfig config.ComprehensiveTestConfig) { + testScenarios := []struct { + name string + supported bool + }{ + {"TextCompletion", testConfig.Scenarios.TextCompletion && testConfig.TextModel != ""}, + {"SimpleChat", testConfig.Scenarios.SimpleChat}, + {"MultiTurnConversation", testConfig.Scenarios.MultiTurnConversation}, + {"ToolCalls", testConfig.Scenarios.ToolCalls}, + {"MultipleToolCalls", testConfig.Scenarios.MultipleToolCalls}, + {"End2EndToolCalling", testConfig.Scenarios.End2EndToolCalling}, + {"AutomaticFunctionCall", testConfig.Scenarios.AutomaticFunctionCall}, + {"ImageURL", testConfig.Scenarios.ImageURL}, + {"ImageBase64", testConfig.Scenarios.ImageBase64}, + {"MultipleImages", testConfig.Scenarios.MultipleImages}, + {"CompleteEnd2End", testConfig.Scenarios.CompleteEnd2End}, + {"ProviderSpecific", testConfig.Scenarios.ProviderSpecific}, + } + + supported := 0 + unsupported := 0 + + t.Logf("\n%s", strings.Repeat("=", 80)) + t.Logf("COMPREHENSIVE TEST SUMMARY FOR PROVIDER: %s", strings.ToUpper(string(testConfig.Provider))) + t.Logf("%s", strings.Repeat("=", 80)) + + for _, scenario := range testScenarios { + if scenario.supported { + supported++ + t.Logf("βœ… SUPPORTED: %-25s βœ… Configured to run", scenario.name) + } else { + unsupported++ + t.Logf("❌ UNSUPPORTED: %-25s ❌ Not supported by provider", scenario.name) + } + } + + t.Logf("%s", strings.Repeat("-", 80)) + t.Logf("CONFIGURATION SUMMARY:") + t.Logf(" βœ… Supported Tests: %d", supported) + t.Logf(" ❌ Unsupported Tests: %d", unsupported) + t.Logf(" πŸ“Š Total Test Types: %d", len(testScenarios)) + t.Logf("") + t.Logf("ℹ️ NOTE: Actual PASS/FAIL results are shown in the individual test output above.") + t.Logf("ℹ️ Look for individual test results like 'PASS: TestOpenAI/SimpleChat' or 'FAIL: TestOpenAI/ToolCalls'") + t.Logf("%s\n", strings.Repeat("=", 80)) +} diff --git a/tests/core-providers/vertex_test.go b/tests/core-providers/vertex_test.go new file mode 100644 index 0000000000..c1907522d6 --- /dev/null +++ b/tests/core-providers/vertex_test.go @@ -0,0 +1,43 @@ +package tests + +import ( + "testing" + + "github.com/maximhq/bifrost/tests/core-providers/config" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestVertex(t *testing.T) { + client, ctx, cancel, err := config.SetupTest() + if err != nil { + t.Fatalf("Error initializing test setup: %v", err) + } + defer cancel() + defer client.Cleanup() + + testConfig := config.ComprehensiveTestConfig{ + Provider: schemas.Vertex, + ChatModel: "google/gemini-2.0-flash-001", + TextModel: "", // Vertex doesn't support text completion in newer models + Scenarios: config.TestScenarios{ + TextCompletion: false, // Not supported + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + MultipleToolCalls: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ProviderSpecific: true, + }, + Fallbacks: []schemas.Fallback{ + {Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"}, + }, + } + + runAllComprehensiveTests(t, client, ctx, testConfig) +} diff --git a/tests/transports-integrations/Makefile b/tests/transports-integrations/Makefile new file mode 100644 index 0000000000..2f0b2dc616 --- /dev/null +++ b/tests/transports-integrations/Makefile @@ -0,0 +1,120 @@ +# Bifrost Python E2E Test Makefile +# Provides convenient commands for running tests + +# Get the directory where this Makefile is located +SCRIPT_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) + +.PHONY: help install test test-all test-parallel test-verbose clean lint format check-env + +# Default target +help: + @echo "Bifrost Python E2E Test Commands:" + @echo "" + @echo "Setup:" + @echo " install Install Python dependencies" + @echo " check-env Check environment variables" + @echo "" + @echo "Testing:" + @echo " test Run all tests using master runner" + @echo " test-all Run all tests with pytest" + @echo " test-parallel Run tests in parallel" + @echo " test-verbose Run tests with verbose output" + @echo " test-openai Run OpenAI integration tests only" + @echo " test-anthropic Run Anthropic integration tests only" + @echo " test-litellm Run LiteLLM integration tests only" + @echo " test-langchain Run LangChain integration tests only" + @echo " test-langgraph Run LangGraph integration tests only" + @echo " test-mistral Run Mistral integration tests only" + @echo " test-genai Run Google GenAI integration tests only" + @echo "" + @echo "Development:" + @echo " lint Run code linting" + @echo " format Format code with black" + @echo " clean Clean up temporary files" + +# Setup commands +install: + pip install -r $(SCRIPT_DIR)requirements.txt + +check-env: + @echo "Checking environment variables..." + @python -c "import os; print('βœ“ BIFROST_BASE_URL:', os.getenv('BIFROST_BASE_URL', 'http://localhost:8080'))" + @python -c "import os; print('βœ“ OPENAI_API_KEY:', 'Set' if os.getenv('OPENAI_API_KEY') else 'Not set')" + @python -c "import os; print('βœ“ ANTHROPIC_API_KEY:', 'Set' if os.getenv('ANTHROPIC_API_KEY') else 'Not set')" + @python -c "import os; print('βœ“ MISTRAL_API_KEY:', 'Set' if os.getenv('MISTRAL_API_KEY') else 'Not set')" + @python -c "import os; print('βœ“ GOOGLE_API_KEY:', 'Set' if os.getenv('GOOGLE_API_KEY') else 'Not set')" + +# Testing commands using master runner +test: + python $(SCRIPT_DIR)run_all_tests.py + +test-parallel: + python $(SCRIPT_DIR)run_all_tests.py --parallel + +test-verbose: + python $(SCRIPT_DIR)run_all_tests.py --verbose + +test-list: + python $(SCRIPT_DIR)run_all_tests.py --list + +# Individual integration tests +test-openai: + python $(SCRIPT_DIR)run_all_tests.py --integration openai --verbose + +test-anthropic: + python $(SCRIPT_DIR)run_all_tests.py --integration anthropic --verbose + +test-litellm: + python $(SCRIPT_DIR)run_all_tests.py --integration litellm --verbose + +test-langchain: + python $(SCRIPT_DIR)run_all_tests.py --integration langchain --verbose + +test-langgraph: + python $(SCRIPT_DIR)run_all_tests.py --integration langgraph --verbose + +test-mistral: + python $(SCRIPT_DIR)run_all_tests.py --integration mistral --verbose + +test-genai: + python $(SCRIPT_DIR)run_all_tests.py --integration genai --verbose + +# Pytest commands +test-all: + pytest -v + +test-pytest-parallel: + pytest -v -n auto + +test-coverage: + pytest --cov=. --cov-report=html --cov-report=term + +# Development commands +lint: + @echo "Running flake8..." + cd $(SCRIPT_DIR) && flake8 *.py + @echo "Running mypy..." + cd $(SCRIPT_DIR) && mypy *.py + +format: + @echo "Formatting code with black..." + cd $(SCRIPT_DIR) && black *.py + +clean: + @echo "Cleaning up temporary files..." + cd $(SCRIPT_DIR) && rm -rf __pycache__/ + cd $(SCRIPT_DIR) && rm -rf .pytest_cache/ + cd $(SCRIPT_DIR) && rm -rf .coverage + cd $(SCRIPT_DIR) && rm -rf htmlcov/ + cd $(SCRIPT_DIR) && rm -rf .mypy_cache/ + cd $(SCRIPT_DIR) && find . -name "*.pyc" -delete + cd $(SCRIPT_DIR) && find . -name "*.pyo" -delete + +# Quick commands for common workflows +quick-test: check-env test + +all-tests: install check-env test-parallel + +dev-setup: install check-env + @echo "Development environment ready!" + @echo "Run 'make test' to execute all tests" \ No newline at end of file diff --git a/tests/transports-integrations/README.md b/tests/transports-integrations/README.md new file mode 100644 index 0000000000..ddb6760f8f --- /dev/null +++ b/tests/transports-integrations/README.md @@ -0,0 +1,1202 @@ +# Bifrost Integration Tests + +Production-ready end-to-end test suite for testing AI integrations through Bifrost proxy. This test suite provides uniform testing across multiple AI integrations with comprehensive coverage of chat, tool calling, image processing, and multimodal workflows. + +## πŸŒ‰ Architecture Overview + +The Bifrost integration tests use a centralized configuration system that routes all AI integration requests through Bifrost as a gateway/proxy: + +```text +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ Test Client │───▢│ Bifrost Gateway │───▢│ AI Integration β”‚ +β”‚ β”‚ β”‚ localhost:8080 β”‚ β”‚ (OpenAI, etc.) β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +### URL Structure + +- **Base URL**: `http://localhost:8080` (configurable via `BIFROST_BASE_URL`) +- **Integration Endpoints**: + - OpenAI: `http://localhost:8080/openai` + - Anthropic: `http://localhost:8080/anthropic` + - Google: `http://localhost:8080/genai` + - LiteLLM: `http://localhost:8080/litellm` + +## πŸš€ Features + +- **πŸŒ‰ Bifrost Gateway Integration**: All integrations route through Bifrost proxy +- **πŸ€– Centralized Configuration**: YAML-based configuration with environment variable support +- **πŸ”§ Integration-Specific Clients**: Type-safe, integration-optimized implementations +- **πŸ“‹ Comprehensive Test Coverage**: 11 categories covering all major AI functionality +- **βš™οΈ Flexible Execution**: Selective test running with command-line flags +- **πŸ›‘οΈ Robust Error Handling**: Graceful error handling and detailed error reporting +- **🎯 Production-Ready**: Async support, timeouts, retries, and logging + +## πŸ“‹ Test Categories + +Our test suite covers 11 comprehensive scenarios for each integration: + +1. **Simple Chat** - Basic single-message conversations +2. **Multi-turn Conversation** - Conversation history and context retention +3. **Single Tool Call** - Basic function calling capabilities +4. **Multiple Tool Calls** - Multiple tools in single request +5. **End-to-End Tool Calling** - Complete tool workflow with results +6. **Automatic Function Calling** - Integration-managed tool execution +7. **Image Analysis (URL)** - Image processing from URLs +8. **Image Analysis (Base64)** - Image processing from base64 data +9. **Multiple Images** - Multi-image analysis and comparison +10. **Complex End-to-End** - Comprehensive multimodal workflows +11. **Integration-Specific Features** - Integration-unique capabilities + +## πŸ“ Directory Structure + +```text +transports-integrations/ +β”œβ”€β”€ config.yml # Central configuration file +β”œβ”€β”€ requirements.txt # Python dependencies +β”œβ”€β”€ run_all_tests.py # Test runner script +β”œβ”€β”€ run_integration_tests.py # Integration-specific test runner +β”œβ”€β”€ pytest.ini # Pytest configuration +β”œβ”€β”€ Makefile # Convenience commands +β”œβ”€β”€ tests/ +β”‚ β”œβ”€β”€ conftest.py # Pytest configuration and fixtures +β”‚ β”œβ”€β”€ utils/ +β”‚ β”‚ β”œβ”€β”€ common.py # Shared test utilities and fixtures +β”‚ β”‚ β”œβ”€β”€ config_loader.py # Configuration system +β”‚ β”‚ └── models.py # Model configurations (compatibility layer) +β”‚ └── integrations/ +β”‚ β”œβ”€β”€ test_openai.py # OpenAI integration tests +β”‚ β”œβ”€β”€ test_anthropic.py # Anthropic integration tests +β”‚ β”œβ”€β”€ test_google.py # Google AI integration tests +β”‚ └── test_litellm.py # LiteLLM integration tests +``` + +## ⚑ Quick Start + +### 1. Installation + +```bash +# Clone the repository +git clone +cd bifrost/tests/transports-integrations + +# Option 1: Using Makefile (recommended) +make install + +# Option 2: Direct pip install +pip install -r requirements.txt +``` + +### 2. Configuration + +The system uses `config.yml` for centralized configuration. Set up your environment variables: + +```bash +# Required: Bifrost gateway +export BIFROST_BASE_URL="http://localhost:8080" + +# Required: Integration API keys +export OPENAI_API_KEY="your-openai-key" +export ANTHROPIC_API_KEY="your-anthropic-key" +export GOOGLE_API_KEY="your-google-api-key" + +# Optional: Integration-specific settings +export OPENAI_ORG_ID="org-..." +export OPENAI_PROJECT_ID="proj_..." +export GOOGLE_PROJECT_ID="your-project" +export GOOGLE_LOCATION="us-central1" +export TEST_ENV="development" + +# Quick check using Makefile +make check-env +``` + +### 3. Verify Configuration + +```bash +# Test the configuration system +python tests/utils/config_loader.py +``` + +This will display: + +- πŸŒ‰ Bifrost gateway URLs +- πŸ€– Model configurations +- βš™οΈ API settings +- βœ… Validation status + +### 4. Pytest Configuration + +The project includes a `pytest.ini` file with optimized settings: + +```ini +[pytest] +# Test discovery +testpaths = . +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Output formatting +addopts = + -v + --tb=short + --strict-markers + --disable-warnings + --color=yes + +# Timeout settings (3 minutes per test) +timeout = 180 + +# Markers for test categorization +markers = + integration: marks tests as integration tests + slow: marks tests as slow running + e2e: marks tests as end-to-end tests + tool_calling: marks tests as tool calling tests +``` + +### 5. Run Tests + +```bash +# Option 1: Using Makefile (recommended for convenience) +make test # Run all tests using master runner +make test-openai # Run OpenAI tests only +make test-anthropic # Run Anthropic tests only +make test-genai # Run Google GenAI tests only +make test-litellm # Run LiteLLM tests only +make test-verbose # Run all tests with verbose output +make test-parallel # Run tests in parallel + +# Option 2: Using test runner scripts directly +python run_all_tests.py + +# Run specific integration +python run_integration_tests.py openai +python run_integration_tests.py anthropic +python run_integration_tests.py google +python run_integration_tests.py litellm + +# Option 3: Using pytest directly +pytest tests/integrations/test_openai.py -v +``` + +#### Makefile Commands + +The project includes a `Makefile` with convenient commands: + +```bash +# Setup +make install # Install Python dependencies +make check-env # Check environment variables + +# Testing +make test # Run all tests using master runner +make test-all # Run all tests with pytest +make test-parallel # Run tests in parallel +make test-verbose # Run tests with verbose output +make test-openai # Run OpenAI integration tests only +make test-anthropic # Run Anthropic integration tests only +make test-genai # Run Google GenAI integration tests only +make test-litellm # Run LiteLLM integration tests only +make test-coverage # Run tests with coverage report + +# Development +make lint # Run code linting +make format # Format code with black +make clean # Clean up temporary files + +# Quick workflows +make quick-test # Check environment + run tests +make all-tests # Full install + check + parallel tests +make dev-setup # Setup development environment +``` + +## πŸ”§ Configuration System + +### Configuration Files + +#### 1. `config.yml` - Main Configuration + +Central configuration file containing: + +- Bifrost gateway settings and endpoints +- Model configurations for all integrations +- API settings (timeouts, retries) +- Test parameters and limits +- Environment-specific overrides +- Integration-specific settings + +#### 2. `tests/utils/config_loader.py` - Configuration Loader + +Python module that: + +- Loads and parses `config.yml` +- Expands environment variables with `${VAR:-default}` syntax +- Provides convenience functions for URLs and models +- Validates configuration completeness +- Handles error scenarios + +#### 3. `tests/utils/models.py` - Compatibility Layer + +Maintains backward compatibility while delegating to the new config system. + +### Key Configuration Sections + +#### Bifrost Gateway + +```yaml +bifrost: + base_url: "${BIFROST_BASE_URL:-http://localhost:8080}" + endpoints: + openai: "openai" + anthropic: "anthropic" + google: "genai" + litellm: "litellm" +``` + +#### Model Configurations + +```yaml +models: + openai: + chat: "gpt-3.5-turbo" + vision: "gpt-4o" + tools: "gpt-3.5-turbo" + alternatives: ["gpt-4", "gpt-4-turbo-preview", "gpt-4o", "gpt-4o-mini"] +``` + +#### API Settings + +```yaml +api: + timeout: 30 + max_retries: 3 + retry_delay: 1 +``` + +### Usage Examples + +#### Getting Integration URLs + +```python +from tests.utils.config_loader import get_integration_url + +# Get Bifrost URL for OpenAI +openai_url = get_integration_url("openai") +# Returns: http://localhost:8080/openai + +# Get integration URL through Bifrost +openai_url = get_integration_url("openai") +# Returns: http://localhost:8080/openai +``` + +#### Getting Model Names + +```python +from tests.utils.config_loader import get_model + +# Get chat model for OpenAI +chat_model = get_model("openai", "chat") +# Returns: gpt-3.5-turbo + +# Get vision model for Anthropic +vision_model = get_model("anthropic", "vision") +# Returns: claude-3-haiku-20240307 +``` + +## πŸ€– Integration Support + +### Currently Supported Integrations + +#### OpenAI + +- βœ… **Full Bifrost Integration**: Complete base URL support +- βœ… **Models**: gpt-3.5-turbo, gpt-4, gpt-4o, gpt-4o-mini +- βœ… **Features**: Chat, tools, vision +- βœ… **Settings**: Organization/project IDs, timeouts, retries +- βœ… **All Test Categories**: 11/11 scenarios supported + +#### Anthropic + +- βœ… **Full Bifrost Integration**: Complete base URL support +- βœ… **Models**: claude-3-haiku-20240307, claude-3-sonnet-20240229, claude-3-opus-20240229, claude-3-5-sonnet-20241022 +- βœ… **Features**: Chat, tools, vision +- βœ… **Settings**: API version headers, timeouts, retries +- βœ… **All Test Categories**: 11/11 scenarios supported + +#### Google AI + +- βœ… **Full Bifrost Integration**: Complete custom transport implementation +- βœ… **Models**: gemini-2.0-flash-001, gemini-1.5-pro, gemini-1.5-flash, gemini-1.0-pro +- βœ… **Features**: Chat, tools, vision, multimodal processing +- βœ… **Settings**: Project ID, location, API configuration +- βœ… **All Test Categories**: 11/11 scenarios supported +- βœ… **Custom Base64 Handling**: Resolved cross-language encoding compatibility + +#### LiteLLM + +- βœ… **Full Bifrost Integration**: Global base URL configuration +- βœ… **Models**: Supports all LiteLLM-compatible models +- βœ… **Features**: Chat, tools, vision (integration-dependent) +- βœ… **Settings**: Drop params, debug mode, integration-specific configs +- βœ… **All Test Categories**: 11/11 scenarios supported +- βœ… **Multi-Integration**: OpenAI, Anthropic, Google, Azure, Cohere, Mistral, etc. + +## πŸ§ͺ Running Tests + +### Test Execution Methods + +#### 1. Using Test Runner Scripts + +##### `run_integration_tests.py` - Advanced Integration Testing + +```bash +# Basic usage - run all available integrations +python run_integration_tests.py + +# Run specific integration +python run_integration_tests.py --integrations openai + +# Run multiple integrations +python run_integration_tests.py --integrations openai anthropic google + +# Run specific test across integrations +python run_integration_tests.py --integrations openai anthropic --test "test_03_single_tool_call" + +# Run test pattern (e.g., all tool calling tests) +python run_integration_tests.py --integrations google --test "tool_call" + +# Run with verbose output +python run_integration_tests.py --integrations openai --test "test_01_simple_chat" --verbose + +# Utility commands +python run_integration_tests.py --check-keys # Check API key availability +python run_integration_tests.py --show-models # Show model configuration +``` + +##### `run_all_tests.py` - Simple Sequential Testing + +```bash +# Run all integrations sequentially +python run_all_tests.py + +# Run with custom configuration +BIFROST_BASE_URL=https://your-bifrost.com python run_all_tests.py +``` + +#### 2. Using pytest Directly + +```bash +# Run all tests for a integration +pytest tests/integrations/test_openai.py -v + +# Run specific test categories +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_01_simple_chat -v + +# Run with coverage +pytest tests/integrations/ --cov=tests --cov-report=html + +# Run with custom markers +pytest tests/integrations/ -m "not slow" -v +``` + +#### 3. Selective Test Execution + +```bash +# Skip tests that require API keys you don't have +pytest tests/integrations/test_openai.py -v # Will skip if OPENAI_API_KEY not set + +# Run only specific test methods +pytest tests/integrations/test_anthropic.py -k "tool_call" -v + +# Run with timeout +pytest tests/integrations/ --timeout=300 -v +``` + +### πŸ” Checking and Running Specific Tests + +#### πŸš€ Quick Commands (Most Common) + +```bash +# Run specific test for specific integration (your example!) +python run_integration_tests.py --integrations google --test "test_03_single_tool_call" + +# Run all tool calling tests across multiple integrations +python run_integration_tests.py --integrations openai anthropic --test "tool_call" + +# Run all tests for one integration +python run_integration_tests.py --integrations openai -v + +# Check what integrations are available +python run_integration_tests.py --check-keys + +# Run specific test with pytest directly +pytest tests/integrations/test_google.py::TestGoogleIntegration::test_03_single_tool_call -v +``` + +#### Quick Reference: Test Categories + +```text +Test 01: Simple Chat - Basic single-message conversations +Test 02: Multi-turn Conversation - Conversation history and context +Test 03: Single Tool Call - Basic function calling +Test 04: Multiple Tool Calls - Multiple tools in one request +Test 05: End-to-End Tool Calling - Complete tool workflow with results +Test 06: Automatic Function Call - Integration-managed tool execution +Test 07: Image Analysis (URL) - Image processing from URLs +Test 08: Image Analysis (Base64) - Image processing from base64 +Test 09: Multiple Images - Multi-image analysis and comparison +Test 10: Complex End-to-End - Comprehensive multimodal workflows +Test 11: Integration-Specific - Integration-unique features +``` + +#### Listing Available Tests + +```bash +# List all tests for a specific integration +pytest tests/integrations/test_openai.py --collect-only + +# List all test methods with descriptions +pytest tests/integrations/test_openai.py --collect-only -q + +# Show test structure for all integrations +pytest tests/integrations/ --collect-only +``` + +#### Running Individual Test Categories + +```bash +# Test 1: Simple Chat +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_01_simple_chat -v + +# Test 3: Single Tool Call +pytest tests/integrations/test_anthropic.py::TestAnthropicIntegration::test_03_single_tool_call -v + +# Test 7: Image Analysis (URL) +pytest tests/integrations/test_google.py::TestGoogleIntegration::test_07_image_url -v + +# Test 9: Multiple Images +pytest tests/integrations/test_litellm.py::TestLiteLLMIntegration::test_09_multiple_images -v +``` + +#### Running Test Categories by Pattern + +```bash +# Run all simple chat tests across integrations +pytest tests/integrations/ -k "test_01_simple_chat" -v + +# Run all tool calling tests (single and multiple) +pytest tests/integrations/ -k "tool_call" -v + +# Run all image-related tests +pytest tests/integrations/ -k "image" -v + +# Run all end-to-end tests +pytest tests/integrations/ -k "end2end" -v + +# Run integration-specific feature tests +pytest tests/integrations/ -k "test_11_integration_specific" -v +``` + +#### Running Tests by Integration + +```bash +# Run all OpenAI tests +pytest tests/integrations/test_openai.py -v + +# Run all Anthropic tests with detailed output +pytest tests/integrations/test_anthropic.py -v -s + +# Run Google tests with coverage +pytest tests/integrations/test_google.py --cov=tests --cov-report=term-missing -v + +# Run LiteLLM tests with timing +pytest tests/integrations/test_litellm.py --durations=10 -v +``` + +#### Advanced Test Selection + +```bash +# Run tests 1-5 (basic functionality) for OpenAI +pytest tests/integrations/test_openai.py -k "test_01 or test_02 or test_03 or test_04 or test_05" -v + +# Run only vision tests (tests 7, 8, 9, 10) +pytest tests/integrations/ -k "test_07 or test_08 or test_09 or test_10" -v + +# Run tests excluding images (skip tests 7, 8, 9, 10) +pytest tests/integrations/ -k "not (test_07 or test_08 or test_09 or test_10)" -v + +# Run only tool-related tests (tests 3, 4, 5, 6) +pytest tests/integrations/ -k "test_03 or test_04 or test_05 or test_06" -v +``` + +#### Test Status and Validation + +```bash +# Check which tests would run (dry run) +pytest tests/integrations/test_openai.py --collect-only --quiet + +# Validate test setup without running +pytest tests/integrations/test_openai.py --setup-only -v + +# Run tests with immediate failure reporting +pytest tests/integrations/ -x -v # Stop on first failure + +# Run tests with detailed failure information +pytest tests/integrations/ --tb=long -v +``` + +#### Integration-Specific Test Validation + +```bash +# Check if integration supports all test categories +python -c " +from tests.integrations.test_openai import TestOpenAIIntegration +import inspect +methods = [m for m in dir(TestOpenAIIntegration) if m.startswith('test_')] +print('OpenAI Test Methods:') +for i, method in enumerate(sorted(methods), 1): + print(f' {i:2d}. {method}') +print(f'Total: {len(methods)} tests') +" + +# Verify integration configuration +python -c " +from tests.utils.config_loader import get_config, get_model +config = get_config() +integration = 'openai' +print(f'{integration.upper()} Configuration:') +for model_type in ['chat', 'vision', 'tools']: + try: + model = get_model(integration, model_type) + print(f' {model_type}: {model}') + except Exception as e: + print(f' {model_type}: ERROR - {e}') +" +``` + +#### Test Results Analysis + +```bash +# Run tests with detailed reporting +pytest tests/integrations/test_openai.py -v --tb=short --report=term-missing + +# Generate HTML test report +pytest tests/integrations/ --html=test_report.html --self-contained-html + +# Run tests with JSON output for analysis +pytest tests/integrations/test_openai.py --json-report --json-report-file=openai_results.json + +# Compare test results across integrations +pytest tests/integrations/ -v | grep -E "(PASSED|FAILED|SKIPPED)" | sort +``` + +#### Debugging Specific Tests + +```bash +# Debug a failing test with full output +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call -v -s --tb=long + +# Run test with Python debugger +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call --pdb + +# Run test with custom logging +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call --log-cli-level=DEBUG -s + +# Test with environment variable override +OPENAI_API_KEY=sk-test pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_01_simple_chat -v +``` + +#### Practical Testing Scenarios + +```bash +# Scenario 1: Test a new integration integration +# 1. Check configuration +python tests/utils/config_loader.py + +# 2. List available tests +pytest tests/integrations/test_your_integration.py --collect-only + +# 3. Run basic tests first (using test runner) +python run_integration_tests.py --integrations your_integration --test "test_01 or test_02" -v + +# 4. Test tool calling if supported (using test runner) +python run_integration_tests.py --integrations your_integration --test "tool_call" -v + +# Alternative: Direct pytest approach +pytest tests/integrations/test_your_integration.py -k "test_01 or test_02" -v +pytest tests/integrations/test_your_integration.py -k "tool_call" -v + +# Scenario 2: Debug a failing tool call test +# 1. Run with full debugging +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call -v -s --tb=long + +# 2. Check tool extraction function +python -c " +from tests.integrations.test_openai import extract_openai_tool_calls +print('Tool extraction function available:', callable(extract_openai_tool_calls)) +" + +# 3. Test with different model +OPENAI_CHAT_MODEL=gpt-4 pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call -v + +# Scenario 3: Compare integration capabilities +# Run the same test across all integrations (using test runner) +python run_integration_tests.py --integrations openai anthropic google litellm --test "test_01_simple_chat" -v + +# Alternative: Direct pytest approach +pytest tests/integrations/ -k "test_01_simple_chat" -v --tb=short + +# Scenario 4: Test only supported features +# For a integration that doesn't support images +pytest tests/integrations/test_your_integration.py -k "not (test_07 or test_08 or test_09 or test_10)" -v + +# Scenario 5: Performance testing +# Run with timing to identify slow tests +pytest tests/integrations/test_openai.py --durations=0 -v + +# Scenario 6: Continuous integration testing +# Run all tests with coverage and reports +pytest tests/integrations/ --cov=tests --cov-report=xml --junit-xml=test_results.xml -v +``` + +#### Test Output Examples + +```bash +# Successful test run +$ pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_01_simple_chat -v +========================= test session starts ========================= +tests/integrations/test_openai.py::TestOpenAIIntegration::test_01_simple_chat PASSED [100%] +βœ“ OpenAI simple chat test passed +Response: "Hello! I'm Claude, an AI assistant. How can I help you today?" + +# Failed test with debugging info +$ pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call -v -s +========================= FAILURES ========================= +_____________ TestOpenAIIntegration.test_03_single_tool_call _____________ +AssertionError: Expected tool calls but got none +Response content: "I can help with weather information, but I need a specific location." +Tool calls found: [] + +# Test collection output +$ pytest tests/integrations/test_openai.py --collect-only -q +tests/integrations/test_openai.py::TestOpenAIIntegration::test_01_simple_chat +tests/integrations/test_openai.py::TestOpenAIIntegration::test_02_multi_turn_conversation +tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call +tests/integrations/test_openai.py::TestOpenAIIntegration::test_04_multiple_tool_calls +tests/integrations/test_openai.py::TestOpenAIIntegration::test_05_end2end_tool_calling +tests/integrations/test_openai.py::TestOpenAIIntegration::test_06_automatic_function_calling +tests/integrations/test_openai.py::TestOpenAIIntegration::test_07_image_url +tests/integrations/test_openai.py::TestOpenAIIntegration::test_08_image_base64 +tests/integrations/test_openai.py::TestOpenAIIntegration::test_09_multiple_images +tests/integrations/test_openai.py::TestOpenAIIntegration::test_10_complex_end2end +tests/integrations/test_openai.py::TestOpenAIIntegration::test_11_integration_specific_features +11 tests collected + +# Test runner script output +$ python run_integration_tests.py --integrations google --test "test_03_single_tool_call" -v +πŸš€ Starting integration tests... +πŸ“‹ Testing integrations: google +============================================================ +πŸ§ͺ TESTING GOOGLE INTEGRATION +============================================================ +========================= test session starts ========================= +tests/integrations/test_google.py::TestGoogleIntegration::test_03_single_tool_call PASSED [100%] +βœ… GOOGLE tests PASSED + +================================================================================ +🎯 FINAL SUMMARY +================================================================================ + +πŸ”‘ API Key Status: + βœ… GOOGLE: Available + +πŸ“Š Test Results: + βœ… GOOGLE: All tests passed + +πŸ† Overall Results: + Integrations tested: 1 + Integrations passed: 1 + Success rate: 100.0% +``` + +### Environment Variables + +#### Required Variables + +```bash +# Bifrost gateway (required) +export BIFROST_BASE_URL="http://localhost:8080" + +# Integration API keys (at least one required) +export OPENAI_API_KEY="sk-..." +export ANTHROPIC_API_KEY="sk-ant-..." +export GOOGLE_API_KEY="AIza..." +``` + +#### Optional Variables + +```bash +# Integration-specific settings +export OPENAI_ORG_ID="org-..." +export OPENAI_PROJECT_ID="proj_..." +export GOOGLE_PROJECT_ID="your-project" +export GOOGLE_LOCATION="us-central1" + +# Environment configuration +export TEST_ENV="development" # or "production" +``` + +### Test Output and Debugging + +#### Understanding Test Results + +```bash +# Successful test output +βœ“ OpenAI Integration Tests + βœ“ test_01_simple_chat - Response: "Hello! How can I help you today?" + βœ“ test_03_single_tool_call - Tool called: get_weather(location="New York") + βœ“ test_07_image_url - Image analyzed successfully + +# Failed test output +βœ— test_03_single_tool_call - AssertionError: Expected tool calls but got none + Response content: "I can help with weather, but I need a specific location." +``` + +#### Debug Mode + +```bash +# Enable verbose output +pytest tests/integrations/test_openai.py -v -s + +# Show full tracebacks +pytest tests/integrations/test_openai.py --tb=long + +# Enable debug logging +pytest tests/integrations/test_openai.py --log-cli-level=DEBUG +``` + +## πŸ”¨ Adding New Integrations + +### Step-by-Step Guide + +#### 1. Update Configuration + +Add your integration to `config.yml`: + +```yaml +# Add to bifrost endpoints +bifrost: + endpoints: + your_integration: "/your_integration" + +# Add model configuration +models: + your_integration: + chat: "your-chat-model" + vision: "your-vision-model" + tools: "your-tools-model" + alternatives: ["alternative-model-1", "alternative-model-2"] + +# Add model capabilities +model_capabilities: + "your-chat-model": + chat: true + tools: true + vision: false + max_tokens: 4096 + context_window: 8192 + +# Add integration settings +integration_settings: + your_integration: + api_version: "v1" + custom_header: "value" +``` + +#### 2. Create Integration Test File + +Create `tests/integrations/test_your_integration.py`: + +```python +""" +Your Integration Tests + +Tests all 11 core scenarios using Your Integration SDK. +""" + +import pytest +from your_integration_sdk import YourIntegrationClient + +from ..utils.common import ( + Config, + SIMPLE_CHAT_MESSAGES, + MULTI_TURN_MESSAGES, + # ... import all test fixtures + get_api_key, + skip_if_no_api_key, + get_model, +) + + +@pytest.fixture +def your_integration_client(): + """Create Your Integration client for testing""" + from ..utils.config_loader import get_integration_url, get_config + + api_key = get_api_key("your_integration") + base_url = get_integration_url("your_integration") + + # Get additional integration settings + config = get_config() + integration_settings = config.get_integration_settings("your_integration") + api_config = config.get_api_config() + + client_kwargs = { + "api_key": api_key, + "base_url": base_url, + "timeout": api_config.get("timeout", 30), + "max_retries": api_config.get("max_retries", 3), + } + + # Add integration-specific settings + if integration_settings.get("api_version"): + client_kwargs["api_version"] = integration_settings["api_version"] + + return YourIntegrationClient(**client_kwargs) + + +@pytest.fixture +def test_config(): + """Test configuration""" + return Config() + + +class TestYourIntegrationIntegration: + """Test suite for Your Integration covering all 11 core scenarios""" + + @skip_if_no_api_key("your_integration") + def test_01_simple_chat(self, your_integration_client, test_config): + """Test Case 1: Simple chat interaction""" + response = your_integration_client.chat.create( + model=get_model("your_integration", "chat"), + messages=SIMPLE_CHAT_MESSAGES, + max_tokens=100, + ) + + assert_valid_chat_response(response) + assert response.content is not None + assert len(response.content) > 0 + + # ... implement all 11 test methods following the same pattern + # See existing integration test files for complete examples + + +def extract_your_integration_tool_calls(response) -> List[Dict[str, Any]]: + """Extract tool calls from Your Integration response format""" + tool_calls = [] + + # Implement based on your integration's response format + if hasattr(response, 'tool_calls') and response.tool_calls: + for tool_call in response.tool_calls: + tool_calls.append({ + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments) + }) + + return tool_calls +``` + +#### 3. Update Common Utilities + +Add your integration to `tests/utils/common.py`: + +```python +def get_api_key(integration: str) -> str: + """Get API key for integration""" + key_map = { + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "google": "GOOGLE_API_KEY", + "litellm": "LITELLM_API_KEY", + "your_integration": "YOUR_INTEGRATION_API_KEY", # Add this line + } + + env_var = key_map.get(integration) + if not env_var: + raise ValueError(f"Unknown integration: {integration}") + + api_key = os.getenv(env_var) + if not api_key: + raise ValueError(f"{env_var} environment variable not set") + + return api_key +``` + +#### 4. Add Integration-Specific Tool Extraction + +Update the tool extraction functions in your test file: + +```python +def extract_your_integration_tool_calls(response: Any) -> List[Dict[str, Any]]: + """Extract tool calls from Your Integration response format""" + tool_calls = [] + + try: + # Implement based on your integration's response structure + # Example for a hypothetical integration: + if hasattr(response, 'function_calls'): + for fc in response.function_calls: + tool_calls.append({ + "name": fc.name, + "arguments": fc.parameters + }) + + return tool_calls + + except Exception as e: + print(f"Error extracting tool calls: {e}") + return [] +``` + +#### 5. Test Your Implementation + +```bash +# Set up environment +export YOUR_INTEGRATION_API_KEY="your-api-key" +export BIFROST_BASE_URL="http://localhost:8080" + +# Test configuration +python tests/utils/config_loader.py + +# Run your integration tests +pytest tests/integrations/test_your_integration.py -v + +# Run specific test +pytest tests/integrations/test_your_integration.py::TestYourIntegrationIntegration::test_01_simple_chat -v +``` + +### 🎯 Key Implementation Points + +#### 1. **Follow the Pattern** + +- Use existing integration test files as templates +- Implement all 11 test scenarios +- Follow the same naming conventions and structure + +#### 2. **Handle Integration Differences** + +```python +# Example: Different response formats +def assert_valid_chat_response(response): + """Validate chat response - adapt for your integration""" + if hasattr(response, 'choices'): # OpenAI-style + assert response.choices[0].message.content + elif hasattr(response, 'content'): # Anthropic-style + assert response.content[0].text + elif hasattr(response, 'text'): # Google-style + assert response.text + # Add your integration's format here +``` + +#### 3. **Implement Tool Calling** + +```python +def convert_to_your_integration_tools(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert common tool format to your integration's format""" + your_integration_tools = [] + + for tool in tools: + # Convert to your integration's tool schema + your_integration_tools.append({ + "name": tool["name"], + "description": tool["description"], + "parameters": tool["parameters"], + # Add integration-specific fields + }) + + return your_integration_tools +``` + +#### 4. **Handle Image Processing** + +```python +def convert_to_your_integration_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert common message format to your integration's format""" + your_integration_messages = [] + + for msg in messages: + if isinstance(msg.get("content"), list): + # Handle multimodal content (text + images) + content = [] + for item in msg["content"]: + if item["type"] == "text": + content.append({"type": "text", "text": item["text"]}) + elif item["type"] == "image_url": + # Convert to your integration's image format + content.append({ + "type": "image", + "source": item["image_url"]["url"] + }) + your_integration_messages.append({"role": msg["role"], "content": content}) + else: + your_integration_messages.append(msg) + + return your_integration_messages +``` + +#### 5. **Error Handling** + +```python +@skip_if_no_api_key("your_integration") +def test_03_single_tool_call(self, your_integration_client, test_config): + """Test Case 3: Single tool call""" + try: + response = your_integration_client.chat.create( + model=get_model("your_integration", "tools"), + messages=SINGLE_TOOL_CALL_MESSAGES, + tools=convert_to_your_integration_tools([WEATHER_TOOL]), + max_tokens=100, + ) + + assert_has_tool_calls(response, expected_count=1) + tool_calls = extract_your_integration_tool_calls(response) + assert tool_calls[0]["name"] == "get_weather" + assert "location" in tool_calls[0]["arguments"] + + except Exception as e: + pytest.skip(f"Tool calling not supported or failed: {e}") +``` + +### πŸ” Testing Checklist + +Before submitting your integration implementation: + +- [ ] **Configuration**: Integration added to `config.yml` with all required sections +- [ ] **Environment**: API key environment variable documented and tested +- [ ] **All 11 Tests**: Every test scenario implemented and passing +- [ ] **Tool Extraction**: Integration-specific tool call extraction function +- [ ] **Message Conversion**: Proper handling of multimodal messages +- [ ] **Error Handling**: Graceful handling of unsupported features +- [ ] **Documentation**: Integration added to README with capabilities +- [ ] **Bifrost Integration**: Base URL properly configured and tested + +### 🚨 Common Pitfalls + +1. **Incorrect Response Parsing**: Each integration has different response formats +2. **Tool Schema Differences**: Tool calling schemas vary significantly +3. **Image Format Handling**: Base64 vs URL handling differs per integration +4. **Missing Error Handling**: Some integrations don't support all features +5. **Configuration Errors**: Forgetting to add integration to all config sections + +## πŸ”§ Troubleshooting + +### Common Issues + +#### 1. Configuration Problems + +```bash +# Error: Configuration file not found +FileNotFoundError: Configuration file not found: config.yml + +# Solution: Ensure config.yml exists in project root +ls -la config.yml +``` + +#### 2. Integration Connection Issues + +```bash +# Error: Connection refused to Bifrost +ConnectionError: Connection refused to localhost:8080 + +# Solutions: +# 1. Check if Bifrost is running +curl http://localhost:8080/health + +# 2. Ensure BIFROST_BASE_URL is set correctly +echo $BIFROST_BASE_URL +``` + +#### 3. API Key Issues + +```bash +# Error: API key not set +ValueError: OPENAI_API_KEY environment variable not set + +# Solution: Set required environment variables +export OPENAI_API_KEY="sk-..." +export ANTHROPIC_API_KEY="sk-ant-..." +export GOOGLE_API_KEY="AIza..." +``` + +#### 4. Model Configuration Errors + +```bash +# Error: Unknown model type +ValueError: Unknown model type 'vision' for integration 'your_integration' + +# Solution: Check config.yml has all model types defined +python tests/utils/config_loader.py +``` + +#### 5. Test Failures + +```bash +# Error: Tool calls not found +AssertionError: Response should contain tool calls + +# Debug steps: +# 1. Check if integration supports tool calling +# 2. Verify tool extraction function +# 3. Check integration-specific tool format +pytest tests/integrations/test_openai.py::TestOpenAIIntegration::test_03_single_tool_call -v -s +``` + +### Debug Mode + +Enable comprehensive debugging: + +```bash +# Full verbose output with debugging +pytest tests/integrations/test_openai.py -v -s --tb=long --log-cli-level=DEBUG + +# Test configuration system +python tests/utils/config_loader.py + +# Check specific integration URL +python -c " +from tests.utils.config_loader import get_integration_url, get_model +print('OpenAI URL:', get_integration_url('openai')) +print('OpenAI Chat Model:', get_model('openai', 'chat')) +" +``` + +## πŸ“š Additional Resources + +### Configuration Examples + +- See `config.yml` for complete configuration reference +- Check `tests/utils/config_loader.py` for usage examples +- Review integration test files for implementation patterns + +### Contributing + +1. Fork the repository +2. Create feature branch: `git checkout -b feature/new-integration` +3. Follow the integration implementation guide above +4. Add comprehensive tests and documentation +5. Submit pull request with test results + +## πŸ†˜ Support + +For issues and questions: + +- Create GitHub issues for bugs and feature requests +- Check existing issues for solutions +- Review integration-specific documentation +- Test configuration with `python tests/utils/config_loader.py` + +--- + +**Note**: This test suite is designed for testing AI integrations through Bifrost proxy. Ensure your Bifrost instance is properly configured and running before executing tests. The configuration system provides Bifrost routing for maximum flexibility. diff --git a/tests/transports-integrations/config.yml b/tests/transports-integrations/config.yml new file mode 100644 index 0000000000..2de74ff9f2 --- /dev/null +++ b/tests/transports-integrations/config.yml @@ -0,0 +1,204 @@ +# Bifrost Integration Tests Configuration +# This file centralizes all configuration for AI integration clients and test settings + +# Bifrost Gateway Configuration +# All integrations route through Bifrost as a proxy/gateway +bifrost: + base_url: "${BIFROST_BASE_URL:-http://localhost:8080}" + + # Integration-specific endpoints (suffixes appended to base_url) + endpoints: + openai: "openai" + anthropic: "anthropic" + google: "genai" + litellm: "litellm" + + # Full URLs constructed as: {base_url.rstrip('/')}/{endpoints[integration]} + # Examples: + # - OpenAI: http://localhost:8080/openai + # - Anthropic: http://localhost:8080/anthropic + # - Google: http://localhost:8080/genai + # - LiteLLM: http://localhost:8080/litellm + +# API Configuration +api: + timeout: 30 # seconds + max_retries: 3 + retry_delay: 1 # seconds + +# Model configurations for each integration +models: + openai: + chat: "gpt-3.5-turbo" + vision: "gpt-4o" + tools: "gpt-3.5-turbo" + alternatives: + - "gpt-4" + - "gpt-4-turbo-preview" + - "gpt-4o" + - "gpt-4o-mini" + + anthropic: + chat: "claude-3-haiku-20240307" + vision: "claude-3-haiku-20240307" + tools: "claude-3-haiku-20240307" + alternatives: + - "claude-3-sonnet-20240229" + - "claude-3-opus-20240229" + - "claude-3-5-sonnet-20241022" + + google: + chat: "gemini-2.0-flash-001" + vision: "gemini-2.0-flash-001" + tools: "gemini-2.0-flash-001" + alternatives: + - "gemini-1.5-pro" + - "gemini-1.5-flash" + - "gemini-1.0-pro" + + litellm: + chat: "gpt-3.5-turbo" # Uses OpenAI by default + vision: "gpt-4o" # Uses OpenAI vision + tools: "gpt-3.5-turbo" # Uses OpenAI for tools + alternatives: + - "claude-3-haiku-20240307" # Anthropic via LiteLLM + - "gemini-pro" # Google via LiteLLM + - "gpt-4" # OpenAI GPT-4 + - "command-r-plus" # Cohere via LiteLLM + +# Model capabilities matrix +model_capabilities: + # OpenAI Models + "gpt-3.5-turbo": + chat: true + tools: true + vision: false + max_tokens: 4096 + context_window: 4096 + + "gpt-4": + chat: true + tools: true + vision: false + max_tokens: 8192 + context_window: 8192 + + "gpt-4o": + chat: true + tools: true + vision: true + max_tokens: 4096 + context_window: 128000 + + "gpt-4o-mini": + chat: true + tools: true + vision: true + max_tokens: 4096 + context_window: 128000 + + # Anthropic Models + "claude-3-haiku-20240307": + chat: true + tools: true + vision: true + max_tokens: 4096 + context_window: 200000 + + "claude-3-sonnet-20240229": + chat: true + tools: true + vision: true + max_tokens: 4096 + context_window: 200000 + + "claude-3-opus-20240229": + chat: true + tools: true + vision: true + max_tokens: 4096 + context_window: 200000 + + # Google Models + "gemini-pro": + chat: true + tools: true + vision: false + max_tokens: 8192 + context_window: 32768 + + "gemini-2.0-flash-001": + chat: true + tools: true + vision: true + max_tokens: 8192 + context_window: 32768 + + "gemini-1.5-pro": + chat: true + tools: true + vision: true + max_tokens: 8192 + context_window: 1000000 + +# Test configuration +test_settings: + # Maximum tokens for test responses + max_tokens: + chat: 100 + vision: 200 + tools: 100 + complex: 300 + + # Timeout settings for tests + timeouts: + simple: 30 # seconds + complex: 60 # seconds + + # Retry settings for flaky tests + retries: + max_attempts: 3 + delay: 2 # seconds + +# Integration-specific settings +integration_settings: + openai: + organization: "${OPENAI_ORG_ID:-}" + project: "${OPENAI_PROJECT_ID:-}" + + anthropic: + version: "2023-06-01" + + google: + project_id: "${GOOGLE_PROJECT_ID:-}" + location: "${GOOGLE_LOCATION:-us-central1}" + + litellm: + drop_params: true + debug: false + +# Environment-specific overrides +environments: + development: + api: + timeout: 60 + max_retries: 5 + test_settings: + timeouts: + simple: 60 + complex: 120 + + production: + api: + timeout: 15 + max_retries: 2 + test_settings: + timeouts: + simple: 20 + complex: 40 + +# Logging configuration +logging: + level: "INFO" + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + file: "tests.log" diff --git a/tests/transports-integrations/pytest.ini b/tests/transports-integrations/pytest.ini new file mode 100644 index 0000000000..6c53a50eab --- /dev/null +++ b/tests/transports-integrations/pytest.ini @@ -0,0 +1,27 @@ +[pytest] +# Test discovery +testpaths = . +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Output formatting +addopts = + -v + --tb=short + --strict-markers + --disable-warnings + --color=yes + +# Timeout settings (3 minutes per test) +timeout = 180 + +# Markers for test categorization +markers = + integration: marks tests as integration tests + slow: marks tests as slow running + e2e: marks tests as end-to-end tests + tool_calling: marks tests as tool calling tests + +# Minimum version +minversion = 7.0 \ No newline at end of file diff --git a/tests/transports-integrations/requirements.txt b/tests/transports-integrations/requirements.txt new file mode 100644 index 0000000000..6bbe3d0acd --- /dev/null +++ b/tests/transports-integrations/requirements.txt @@ -0,0 +1,41 @@ +# Core testing framework +pytest>=7.0.0 +pytest-asyncio>=0.21.0 + +# Environment and configuration +python-dotenv>=1.0.0 +PyYAML>=6.0 + +# Image processing +Pillow>=9.0.0 + +# HTTP requests for debugging +requests>=2.28.0 + +# Type hints +typing-extensions>=4.0.0 + +# Optional: For better test reporting +pytest-html>=3.1.0 +pytest-cov>=4.0.0 + +# AI/ML SDK dependencies +openai>=1.30.0 +anthropic>=0.25.0 +litellm>=1.35.0 +langchain-openai>=0.1.0 +langchain-core>=0.2.0 +langchain-anthropic>=0.1.0 +langgraph>=0.1.0 +mistralai>=0.4.0 +google-genai>=1.0.0 + +# Optional testing utilities +httpx>=0.25.0 +pytest-timeout>=2.1.0 +pytest-mock>=3.11.0 + +# Development dependencies (optional) +black>=23.0.0 # Code formatting +flake8>=6.0.0 # Linting +mypy>=1.5.0 # Type checking \ No newline at end of file diff --git a/tests/transports-integrations/run_all_tests.py b/tests/transports-integrations/run_all_tests.py new file mode 100755 index 0000000000..900f63781c --- /dev/null +++ b/tests/transports-integrations/run_all_tests.py @@ -0,0 +1,330 @@ +#!/usr/bin/env python3 +""" +Bifrost Integration End-to-End Test Runner + +This script runs all integration end-to-end tests for Bifrost. +It can run tests individually or all together, providing comprehensive +reporting and flexible execution options. + +Usage: + python run_all_tests.py # Run all tests + python run_all_tests.py --integration openai # Run specific integration + python run_all_tests.py --list # List available integrations + python run_all_tests.py --parallel # Run tests in parallel + python run_all_tests.py --verbose # Verbose output +""" + +import argparse +import subprocess +import sys +import time +import os +from pathlib import Path +from typing import List, Dict, Optional +import concurrent.futures +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + + +class BifrostTestRunner: + """Main test runner for Bifrost integration tests""" + + def __init__(self): + self.test_dir = Path(__file__).parent + self.integrations = { + "openai": { + "file": "tests/integrations/test_openai.py", + "description": "OpenAI Python SDK integration tests", + "env_vars": ["OPENAI_API_KEY"], + }, + "anthropic": { + "file": "tests/integrations/test_anthropic.py", + "description": "Anthropic Python SDK integration tests", + "env_vars": ["ANTHROPIC_API_KEY"], + }, + "litellm": { + "file": "tests/integrations/test_litellm.py", + "description": "LiteLLM integration tests", + "env_vars": ["OPENAI_API_KEY"], # LiteLLM can use OpenAI key + }, + "google": { + "file": "tests/integrations/test_google.py", + "description": "Google GenAI integration tests", + "env_vars": ["GOOGLE_API_KEY"], + }, + } + + self.results = {} + + def check_environment(self, integration: str) -> bool: + """Check if required environment variables are set for an integration""" + config = self.integrations[integration] + missing_vars = [] + + for var in config["env_vars"]: + if not os.getenv(var): + missing_vars.append(var) + + if missing_vars: + print( + f"⚠ Skipping {integration}: Missing environment variables: {', '.join(missing_vars)}" + ) + return False + + return True + + def run_integration_test(self, integration: str, verbose: bool = False) -> Dict: + """Run tests for a specific integration""" + if integration not in self.integrations: + return {"success": False, "error": f"Unknown integration: {integration}"} + + config = self.integrations[integration] + test_file = self.test_dir / config["file"] + + if not test_file.exists(): + return {"success": False, "error": f"Test file not found: {test_file}"} + + # Check environment variables + if not self.check_environment(integration): + return { + "success": False, + "error": "Missing required environment variables", + "skipped": True, + } + + print(f"\n{'='*60}") + print(f"Running {integration.upper()} Integration Tests") + print(f"{'='*60}") + print(f"Description: {config['description']}") + print(f"Test file: {config['file']}") + + start_time = time.time() + + try: + # Run the test with pytest + cmd = [sys.executable, "-m", "pytest", str(test_file)] + + # Add pytest flags for better output + if verbose: + cmd.extend(["-v", "-s"]) # verbose and don't capture output + else: + cmd.append("-q") # quiet mode + + if verbose: + result = subprocess.run( + cmd, cwd=self.test_dir, text=True, capture_output=False, timeout=300 + ) + else: + result = subprocess.run( + cmd, cwd=self.test_dir, text=True, capture_output=True, timeout=300 + ) + + elapsed_time = time.time() - start_time + + success = result.returncode == 0 + + return { + "success": success, + "return_code": result.returncode, + "stdout": result.stdout if not verbose else "", + "stderr": result.stderr if not verbose else "", + "elapsed_time": elapsed_time, + } + + except subprocess.TimeoutExpired: + return { + "success": False, + "error": "Test timed out (5 minutes)", + "elapsed_time": 300, + } + except Exception as e: + return { + "success": False, + "error": str(e), + "elapsed_time": time.time() - start_time, + } + + def run_all_tests(self, parallel: bool = False, verbose: bool = False) -> None: + """Run all integration tests""" + print("Bifrost Integration End-to-End Test Suite") + print("=" * 50) + print(f"Running tests for {len(self.integrations)} integrations") + print(f"Parallel execution: {'Enabled' if parallel else 'Disabled'}") + print(f"Verbose output: {'Enabled' if verbose else 'Disabled'}") + + # Check Bifrost availability + bifrost_url = os.getenv("BIFROST_BASE_URL", "http://localhost:8080") + print(f"Bifrost URL: {bifrost_url}") + + start_time = time.time() + + if parallel: + self._run_parallel(verbose) + else: + self._run_sequential(verbose) + + total_time = time.time() - start_time + self._print_summary(total_time) + + def _run_sequential(self, verbose: bool) -> None: + """Run tests sequentially""" + for integration in self.integrations: + self.results[integration] = self.run_integration_test(integration, verbose) + + def _run_parallel(self, verbose: bool) -> None: + """Run tests in parallel""" + print("\nRunning tests in parallel...") + + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + # Submit all tests + future_to_integration = { + executor.submit( + self.run_integration_test, integration, verbose + ): integration + for integration in self.integrations + } + + # Collect results + for future in concurrent.futures.as_completed(future_to_integration): + integration = future_to_integration[future] + try: + self.results[integration] = future.result() + except Exception as e: + self.results[integration] = {"success": False, "error": str(e)} + + def _print_summary(self, total_time: float) -> None: + """Print test summary""" + print(f"\n{'='*60}") + print("TEST SUMMARY") + print(f"{'='*60}") + + passed = 0 + failed = 0 + skipped = 0 + + for integration, result in self.results.items(): + status = ( + "SKIPPED" + if result.get("skipped") + else ("PASSED" if result["success"] else "FAILED") + ) + elapsed = result.get("elapsed_time", 0) + + if result.get("skipped"): + skipped += 1 + print( + f"⚠ {integration:12} {status:8} - {result.get('error', 'Unknown error')}" + ) + elif result["success"]: + passed += 1 + print(f"βœ“ {integration:12} {status:8} - {elapsed:.2f}s") + else: + failed += 1 + error_msg = result.get("error", "Unknown error") + print(f"βœ— {integration:12} {status:8} - {error_msg}") + + # Print stderr if available + if "stderr" in result and result["stderr"]: + print(f" Error output: {result['stderr'][:200]}...") + + print(f"\n{'='*60}") + print( + f"Total: {len(self.integrations)} | Passed: {passed} | Failed: {failed} | Skipped: {skipped}" + ) + print(f"Total time: {total_time:.2f} seconds") + print(f"{'='*60}") + + # Exit with appropriate code + if failed > 0: + sys.exit(1) + else: + print("All tests completed successfully!") + + def list_integrations(self) -> None: + """List available integrations""" + print("Available Integrations:") + print("=" * 30) + + for integration, config in self.integrations.items(): + env_status = "βœ“" if self.check_environment(integration) else "βœ—" + print(f"{env_status} {integration:12} - {config['description']}") + print(f" Required env vars: {', '.join(config['env_vars'])}") + print() + + +def main(): + parser = argparse.ArgumentParser( + description="Run Bifrost integration end-to-end tests", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python run_all_tests.py # Run all tests + python run_all_tests.py --integration openai # Run OpenAI tests only + python run_all_tests.py --parallel --verbose # Run all tests in parallel with verbose output + python run_all_tests.py --list # List available integrations + """, + ) + + parser.add_argument( + "--integration", "-i", help="Run tests for specific integration only" + ) + + parser.add_argument( + "--list", + "-l", + action="store_true", + help="List available integrations and their status", + ) + + parser.add_argument( + "--parallel", + "-p", + action="store_true", + help="Run tests in parallel (faster but less readable output)", + ) + + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Enable verbose output (shows test output in real-time)", + ) + + args = parser.parse_args() + + runner = BifrostTestRunner() + + if args.list: + runner.list_integrations() + return + + if args.integration: + if args.integration not in runner.integrations: + print(f"Error: Unknown integration '{args.integration}'") + print(f"Available integrations: {', '.join(runner.integrations.keys())}") + sys.exit(1) + + result = runner.run_integration_test(args.integration, args.verbose) + if result["success"]: + print(f"\nβœ“ {args.integration} tests passed!") + else: + error_msg = result.get("error", "Unknown error") + print(f"\nβœ— {args.integration} tests failed: {error_msg}") + + # Show stdout/stderr if available + if result.get("stdout"): + print("\n--- Test Output ---") + print(result["stdout"]) + if result.get("stderr"): + print("\n--- Error Output ---") + print(result["stderr"]) + + sys.exit(1) + else: + runner.run_all_tests(args.parallel, args.verbose) + + +if __name__ == "__main__": + main() diff --git a/tests/transports-integrations/run_integration_tests.py b/tests/transports-integrations/run_integration_tests.py new file mode 100755 index 0000000000..169e7f0f2b --- /dev/null +++ b/tests/transports-integrations/run_integration_tests.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +""" +Integration-specific test runner for Bifrost integration tests. + +This script runs tests for each integration independently using their native SDKs. +No more complex gateway conversions - just direct testing! +""" + +import os +import sys +import argparse +import subprocess +from pathlib import Path +from typing import List, Optional + + +def check_api_keys(): + """Check which API keys are available""" + keys = { + "openai": os.getenv("OPENAI_API_KEY"), + "anthropic": os.getenv("ANTHROPIC_API_KEY"), + "google": os.getenv("GOOGLE_API_KEY"), + "litellm": os.getenv("LITELLM_API_KEY"), + } + + available = [integration for integration, key in keys.items() if key] + missing = [integration for integration, key in keys.items() if not key] + + return available, missing + + +def run_integration_tests( + integrations: List[str], test_pattern: Optional[str] = None, verbose: bool = False +): + """Run tests for specified integrations""" + + results = {} + + for integration in integrations: + print(f"\n{'='*60}") + print(f"πŸ§ͺ TESTING {integration.upper()} INTEGRATION") + print(f"{'='*60}") + + # Build pytest command with absolute path relative to script location + script_dir = Path(__file__).parent + test_file = script_dir / "tests" / "integrations" / f"test_{integration}.py" + + # Check if test file exists + if not test_file.exists(): + print(f"❌ Test file not found: {test_file}") + results[integration] = {"error": f"Test file not found: {test_file}"} + continue + + cmd = ["python", "-m", "pytest", str(test_file)] + + if test_pattern: + cmd.extend(["-k", test_pattern]) + + if verbose: + cmd.append("-v") + else: + cmd.append("-q") + + # Remove integration-specific marker (not needed for file-based selection) + # cmd.extend(["-m", integration]) + + # Run the tests + try: + result = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + check=True, + ) + results[integration] = { + "returncode": result.returncode, + "stdout": result.stdout, + "stderr": "", # stderr is now captured in stdout + } + + # Print results + print(f"βœ… {integration.upper()} tests PASSED") + + if verbose: + print(result.stdout) + + except subprocess.CalledProcessError as e: + print(f"❌ {integration.upper()} tests FAILED") + results[integration] = { + "returncode": e.returncode, + "stdout": e.stdout, + "stderr": "", # stderr is captured in stdout + } + + # Always print output on failure to show what went wrong + if e.stdout: + print(e.stdout) + + except Exception as e: + print(f"❌ Error running {integration} tests: {e}") + results[integration] = {"error": str(e)} + + return results + + +def print_summary( + results: dict, available_integrations: List[str], missing_integrations: List[str] +): + """Print final summary""" + print(f"\n{'='*80}") + print("🎯 FINAL SUMMARY") + print(f"{'='*80}") + + # API Key Status + print(f"\nπŸ”‘ API Key Status:") + for integration in available_integrations: + print(f" βœ… {integration.upper()}: Available") + + for integration in missing_integrations: + print(f" ❌ {integration.upper()}: Missing API key") + + # Test Results + print(f"\nπŸ“Š Test Results:") + passed_integrations = [] + failed_integrations = [] + + for integration, result in results.items(): + if "error" in result: + print(f" πŸ’₯ {integration.upper()}: Error - {result['error']}") + failed_integrations.append(integration) + elif result["returncode"] == 0: + print(f" βœ… {integration.upper()}: All tests passed") + passed_integrations.append(integration) + else: + print(f" ❌ {integration.upper()}: Some tests failed") + failed_integrations.append(integration) + + # Overall Status + total_tested = len(results) + total_passed = len(passed_integrations) + + print(f"\nπŸ† Overall Results:") + print(f" Integrations tested: {total_tested}") + print(f" Integrations passed: {total_passed}") + print( + f" Success rate: {(total_passed/total_tested)*100:.1f}%" + if total_tested > 0 + else " Success rate: N/A" + ) + + if failed_integrations: + print(f"\n⚠️ Failed integrations: {', '.join(failed_integrations)}") + print(" Check the detailed output above for specific test failures.") + + +def main(): + parser = argparse.ArgumentParser( + description="Run integration-specific integration tests" + ) + parser.add_argument( + "--integrations", + nargs="+", + choices=["openai", "anthropic", "google", "litellm", "all"], + default=["all"], + help="Integrations to test (default: all available)", + ) + parser.add_argument( + "--test", help="Run specific test pattern (e.g., 'test_01_simple_chat')" + ) + parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") + parser.add_argument( + "--check-keys", action="store_true", help="Only check API key availability" + ) + parser.add_argument( + "--show-models", + action="store_true", + help="Show model configuration for all integrations", + ) + + args = parser.parse_args() + + # Check API keys + available_integrations, missing_integrations = check_api_keys() + + if args.check_keys: + print("πŸ”‘ API Key Status:") + for integration in available_integrations: + print(f" βœ… {integration.upper()}: Available") + for integration in missing_integrations: + print(f" ❌ {integration.upper()}: Missing") + return + + if args.show_models: + # Import and show model configuration using absolute path + script_dir = Path(__file__).parent + models_path = script_dir / "tests" / "utils" / "models.py" + + if not models_path.exists(): + print(f"❌ Models file not found: {models_path}") + sys.exit(1) + + # Add the parent directory to sys.path to enable the import + models_parent_dir = str(script_dir) + if models_parent_dir not in sys.path: + sys.path.insert(0, models_parent_dir) + + try: + from tests.utils.models import print_model_summary + + print_model_summary() + except ImportError as e: + print(f"❌ Could not import print_model_summary: {e}") + print(f"Tried to import from: {models_path}") + sys.exit(1) + return + + # Determine which integrations to test + if "all" in args.integrations: + integrations_to_test = available_integrations + requested_integrations = [ + "openai", + "anthropic", + "google", + "litellm", + ] # all possible integrations + else: + integrations_to_test = [ + p for p in args.integrations if p in available_integrations + ] + requested_integrations = args.integrations + + if not integrations_to_test: + print("❌ No integrations available for testing. Please set API keys.") + print("\nRequired environment variables for requested integrations:") + for integration in requested_integrations: + if integration != "all": # Skip the "all" keyword + api_key_name = f"{integration.upper()}_API_KEY" + print(f" - {api_key_name}") + sys.exit(1) + + # Calculate which requested integrations are missing API keys + requested_missing_integrations = [ + integration + for integration in requested_integrations + if integration in missing_integrations + ] + + # Show what we're about to test + print("πŸš€ Starting integration tests...") + print(f"πŸ“‹ Testing integrations: {', '.join(integrations_to_test)}") + if requested_missing_integrations: + print( + f"⏭️ Skipping integrations (no API key): {', '.join(requested_missing_integrations)}" + ) + + # Run tests + results = run_integration_tests(integrations_to_test, args.test, args.verbose) + + # Print summary + print_summary(results, available_integrations, requested_missing_integrations) + + # Exit with appropriate code + failed_count = sum( + 1 for r in results.values() if r.get("returncode", 1) != 0 or "error" in r + ) + sys.exit(failed_count) + + +if __name__ == "__main__": + main() diff --git a/tests/transports-integrations/tests/__init__.py b/tests/transports-integrations/tests/__init__.py new file mode 100644 index 0000000000..92e4c036e6 --- /dev/null +++ b/tests/transports-integrations/tests/__init__.py @@ -0,0 +1,8 @@ +""" +Bifrost Integration Tests + +Production-ready test suite for testing various AI integrations through Bifrost proxy. +Supports multiple integrations with uniform test interface. +""" + +__version__ = "1.0.0" diff --git a/tests/transports-integrations/tests/conftest.py b/tests/transports-integrations/tests/conftest.py new file mode 100644 index 0000000000..9de5dca778 --- /dev/null +++ b/tests/transports-integrations/tests/conftest.py @@ -0,0 +1,159 @@ +""" +Pytest configuration for integration-specific tests. +""" + +import pytest +import os + + +def pytest_configure(config): + """Configure pytest with custom markers""" + config.addinivalue_line("markers", "openai: mark test as requiring OpenAI API key") + config.addinivalue_line( + "markers", "anthropic: mark test as requiring Anthropic API key" + ) + config.addinivalue_line("markers", "google: mark test as requiring Google API key") + config.addinivalue_line("markers", "litellm: mark test as requiring LiteLLM setup") + + +def pytest_collection_modifyitems(config, items): + """Modify test collection to add markers based on test file names""" + for item in items: + # Add markers based on test file location + if "test_openai" in item.nodeid: + item.add_marker(pytest.mark.openai) + elif "test_anthropic" in item.nodeid: + item.add_marker(pytest.mark.anthropic) + elif "test_google" in item.nodeid: + item.add_marker(pytest.mark.google) + elif "test_litellm" in item.nodeid: + item.add_marker(pytest.mark.litellm) + + +@pytest.fixture(scope="session") +def api_keys(): + """Collect all available API keys""" + return { + "openai": os.getenv("OPENAI_API_KEY"), + "anthropic": os.getenv("ANTHROPIC_API_KEY"), + "google": os.getenv("GOOGLE_API_KEY"), + "litellm": os.getenv("LITELLM_API_KEY"), + } + + +@pytest.fixture(scope="session") +def available_integrations(api_keys): + """Determine which integrations are available based on API keys""" + available = [] + + if api_keys["openai"]: + available.append("openai") + if api_keys["anthropic"]: + available.append("anthropic") + if api_keys["google"]: + available.append("google") + if api_keys["litellm"]: + available.append("litellm") + + return available + + +@pytest.fixture +def test_summary(): + """Fixture to collect test results for summary reporting""" + results = {"passed": [], "failed": [], "skipped": []} + return results + + +def pytest_runtest_makereport(item, call): + """Hook to capture test results""" + # Only record results during the "call" phase to avoid double counting + if call.when == "call": + # Extract integration and test info + integration = None + if "test_openai" in item.nodeid: + integration = "openai" + elif "test_anthropic" in item.nodeid: + integration = "anthropic" + elif "test_google" in item.nodeid: + integration = "google" + elif "test_litellm" in item.nodeid: + integration = "litellm" + + test_name = item.name + + # Store result info + result_info = { + "integration": integration, + "test": test_name, + "nodeid": item.nodeid, + } + + if hasattr(item.session, "test_results"): + if call.excinfo is None: + item.session.test_results["passed"].append(result_info) + else: + result_info["error"] = str(call.excinfo.value) + item.session.test_results["failed"].append(result_info) + + +def pytest_sessionstart(session): + """Initialize test results collection""" + session.test_results = {"passed": [], "failed": [], "skipped": []} + + +def pytest_sessionfinish(session, exitstatus): + """Print test summary at the end""" + results = session.test_results + + print("\n" + "=" * 80) + print("INTEGRATION TEST SUMMARY") + print("=" * 80) + + # Group results by integration + integration_results = {} + + for result in results["passed"] + results["failed"] + results["skipped"]: + integration = result.get("integration", "unknown") + if integration not in integration_results: + integration_results[integration] = {"passed": 0, "failed": 0, "skipped": 0} + + for result in results["passed"]: + integration = result.get("integration", "unknown") + integration_results[integration]["passed"] += 1 + + for result in results["failed"]: + integration = result.get("integration", "unknown") + integration_results[integration]["failed"] += 1 + + for result in results["skipped"]: + integration = result.get("integration", "unknown") + integration_results[integration]["skipped"] += 1 + + # Print summary by integration + for integration, counts in integration_results.items(): + total = counts["passed"] + counts["failed"] + counts["skipped"] + if total > 0: + print(f"\n{integration.upper()} Integration:") + print(f" βœ… Passed: {counts['passed']}") + print(f" ❌ Failed: {counts['failed']}") + print(f" ⏭️ Skipped: {counts['skipped']}") + print(f" πŸ“Š Total: {total}") + + if counts["passed"] > 0: + success_rate = ( + (counts["passed"] / (counts["passed"] + counts["failed"])) * 100 + if (counts["passed"] + counts["failed"]) > 0 + else 0 + ) + print(f" 🎯 Success Rate: {success_rate:.1f}%") + + # Print failed tests details + if results["failed"]: + print(f"\n❌ FAILED TESTS ({len(results['failed'])}):") + for result in results["failed"]: + print(f" β€’ {result['integration']}: {result['test']}") + if "error" in result: + print(f" Error: {result['error']}") + + print("\n" + "=" * 80) diff --git a/tests/transports-integrations/tests/integrations/__init__.py b/tests/transports-integrations/tests/integrations/__init__.py new file mode 100644 index 0000000000..ec4135e3b2 --- /dev/null +++ b/tests/transports-integrations/tests/integrations/__init__.py @@ -0,0 +1 @@ +# Integration-specific test packages diff --git a/tests/transports-integrations/tests/integrations/test_anthropic.py b/tests/transports-integrations/tests/integrations/test_anthropic.py new file mode 100644 index 0000000000..0e9c8a9907 --- /dev/null +++ b/tests/transports-integrations/tests/integrations/test_anthropic.py @@ -0,0 +1,568 @@ +""" +Anthropic Integration Tests + +πŸ€– MODELS USED: +- Chat: claude-3-haiku-20240307 +- Vision: claude-3-haiku-20240307 +- Tools: claude-3-haiku-20240307 +- Alternatives: claude-3-sonnet-20240229, claude-3-opus-20240229, claude-3-5-sonnet-20241022 + +Tests all 11 core scenarios using Anthropic SDK directly: +1. Simple chat +2. Multi turn conversation +3. Tool calls +4. Multiple tool calls +5. End2End tool calling +6. Automatic function calling +7. Image (url) +8. Image (base64) +9. Multiple images +10. Complete end2end test with conversation history, tool calls, tool results and images +11. Integration specific tests +""" + +import pytest +import base64 +import requests +from anthropic import Anthropic +from typing import List, Dict, Any + +from ..utils.common import ( + Config, + SIMPLE_CHAT_MESSAGES, + MULTI_TURN_MESSAGES, + SINGLE_TOOL_CALL_MESSAGES, + MULTIPLE_TOOL_CALL_MESSAGES, + IMAGE_URL, + BASE64_IMAGE, + WEATHER_TOOL, + CALCULATOR_TOOL, + mock_tool_response, + assert_valid_chat_response, + assert_has_tool_calls, + assert_valid_image_response, + extract_tool_calls, + get_api_key, + skip_if_no_api_key, + COMPARISON_KEYWORDS, + WEATHER_KEYWORDS, + LOCATION_KEYWORDS, +) +from ..utils.config_loader import get_model + + +@pytest.fixture +def anthropic_client(): + """Create Anthropic client for testing""" + from ..utils.config_loader import get_integration_url, get_config + + api_key = get_api_key("anthropic") + base_url = get_integration_url("anthropic") + + # Get additional integration settings + config = get_config() + integration_settings = config.get_integration_settings("anthropic") + api_config = config.get_api_config() + + client_kwargs = { + "api_key": api_key, + "base_url": base_url, + "timeout": api_config.get("timeout", 30), + "max_retries": api_config.get("max_retries", 3), + } + + # Add Anthropic-specific settings + if integration_settings.get("version"): + client_kwargs["default_headers"] = { + "anthropic-version": integration_settings["version"] + } + + return Anthropic(**client_kwargs) + + +@pytest.fixture +def test_config(): + """Test configuration""" + return Config() + + +def convert_to_anthropic_messages( + messages: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Convert common message format to Anthropic format""" + anthropic_messages = [] + + for msg in messages: + if msg["role"] == "system": + continue # System messages handled separately in Anthropic + + # Handle image messages + if isinstance(msg.get("content"), list): + content = [] + for item in msg["content"]: + if item["type"] == "text": + content.append({"type": "text", "text": item["text"]}) + elif item["type"] == "image_url": + url = item["image_url"]["url"] + if url.startswith("data:image"): + # Base64 image + media_type, data = url.split(",", 1) + content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": data, + }, + } + ) + else: + # URL image - need to download and convert to base64 + response = requests.get(url) + img_data = base64.b64encode(response.content).decode() + content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": img_data, + }, + } + ) + + anthropic_messages.append({"role": msg["role"], "content": content}) + else: + anthropic_messages.append({"role": msg["role"], "content": msg["content"]}) + + return anthropic_messages + + +def convert_to_anthropic_tools(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert common tool format to Anthropic format""" + anthropic_tools = [] + + for tool in tools: + anthropic_tools.append( + { + "name": tool["name"], + "description": tool["description"], + "input_schema": tool["parameters"], + } + ) + + return anthropic_tools + + +class TestAnthropicIntegration: + """Test suite for Anthropic integration covering all 11 core scenarios""" + + @skip_if_no_api_key("anthropic") + def test_01_simple_chat(self, anthropic_client, test_config): + """Test Case 1: Simple chat interaction""" + messages = convert_to_anthropic_messages(SIMPLE_CHAT_MESSAGES) + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), messages=messages, max_tokens=100 + ) + + assert_valid_chat_response(response) + assert len(response.content) > 0 + assert response.content[0].type == "text" + assert len(response.content[0].text) > 0 + + @skip_if_no_api_key("anthropic") + def test_02_multi_turn_conversation(self, anthropic_client, test_config): + """Test Case 2: Multi-turn conversation""" + messages = convert_to_anthropic_messages(MULTI_TURN_MESSAGES) + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), messages=messages, max_tokens=150 + ) + + assert_valid_chat_response(response) + content = response.content[0].text.lower() + # Should mention population or numbers since we asked about Paris population + assert any( + word in content + for word in ["population", "million", "people", "inhabitants"] + ) + + @skip_if_no_api_key("anthropic") + def test_03_single_tool_call(self, anthropic_client, test_config): + """Test Case 3: Single tool call""" + messages = convert_to_anthropic_messages(SINGLE_TOOL_CALL_MESSAGES) + tools = convert_to_anthropic_tools([WEATHER_TOOL]) + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=messages, + tools=tools, + max_tokens=100, + ) + + assert_has_tool_calls(response, expected_count=1) + tool_calls = extract_tool_calls(response) + assert tool_calls[0]["name"] == "get_weather" + assert "location" in tool_calls[0]["arguments"] + + @skip_if_no_api_key("anthropic") + def test_04_multiple_tool_calls(self, anthropic_client, test_config): + """Test Case 4: Multiple tool calls in one response""" + messages = convert_to_anthropic_messages(MULTIPLE_TOOL_CALL_MESSAGES) + tools = convert_to_anthropic_tools([WEATHER_TOOL, CALCULATOR_TOOL]) + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=messages, + tools=tools, + max_tokens=200, + ) + + # Anthropic might be more conservative with multiple tool calls + # Let's check if it made at least one tool call and prefer multiple if possible + assert_has_tool_calls(response) # At least 1 tool call + tool_calls = extract_anthropic_tool_calls(response) + tool_names = [tc["name"] for tc in tool_calls] + + # Should make relevant tool calls - either weather, calculate, or both + expected_tools = ["get_weather", "calculate"] + made_relevant_calls = any(name in expected_tools for name in tool_names) + assert ( + made_relevant_calls + ), f"Expected tool calls from {expected_tools}, got {tool_names}" + + @skip_if_no_api_key("anthropic") + def test_05_end2end_tool_calling(self, anthropic_client, test_config): + """Test Case 5: Complete tool calling flow with responses""" + messages = [{"role": "user", "content": "What's the weather in Boston?"}] + tools = convert_to_anthropic_tools([WEATHER_TOOL]) + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=messages, + tools=tools, + max_tokens=100, + ) + + assert_has_tool_calls(response, expected_count=1) + + # Add assistant's response to conversation + messages.append({"role": "assistant", "content": response.content}) + + # Add tool response + tool_calls = extract_anthropic_tool_calls(response) + tool_response = mock_tool_response( + tool_calls[0]["name"], tool_calls[0]["arguments"] + ) + + # Find the tool use block to get its ID + tool_use_id = None + for content in response.content: + if content.type == "tool_use": + tool_use_id = content.id + break + + messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": tool_response, + } + ], + } + ) + + # Get final response + final_response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), messages=messages, max_tokens=150 + ) + + # Anthropic might return empty content if tool result is sufficient + assert final_response is not None + if len(final_response.content) > 0: + assert_valid_chat_response(final_response) + content = final_response.content[0].text.lower() + weather_location_keywords = WEATHER_KEYWORDS + LOCATION_KEYWORDS + assert any(word in content for word in weather_location_keywords) + else: + # If no content, that's ok - tool result was sufficient + print("Model returned empty content - tool result was sufficient") + + @skip_if_no_api_key("anthropic") + def test_06_automatic_function_calling(self, anthropic_client, test_config): + """Test Case 6: Automatic function calling""" + messages = [{"role": "user", "content": "Calculate 25 * 4 for me"}] + tools = convert_to_anthropic_tools([CALCULATOR_TOOL]) + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=messages, + tools=tools, + max_tokens=100, + ) + + # Should automatically choose to use the calculator + assert_has_tool_calls(response, expected_count=1) + tool_calls = extract_tool_calls(response) + assert tool_calls[0]["name"] == "calculate" + + @skip_if_no_api_key("anthropic") + def test_07_image_url(self, anthropic_client, test_config): + """Test Case 7: Image analysis from URL""" + # Download image and convert to base64 for Anthropic + response_img = requests.get(IMAGE_URL) + img_data = base64.b64encode(response_img.content).decode() + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do you see in this image?"}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": img_data, + }, + }, + ], + } + ] + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), messages=messages, max_tokens=200 + ) + + assert_valid_image_response(response) + + @skip_if_no_api_key("anthropic") + def test_08_image_base64(self, anthropic_client, test_config): + """Test Case 8: Image analysis from base64""" + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": BASE64_IMAGE, + }, + }, + ], + } + ] + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), messages=messages, max_tokens=200 + ) + + assert_valid_image_response(response) + + @skip_if_no_api_key("anthropic") + def test_09_multiple_images(self, anthropic_client, test_config): + """Test Case 9: Multiple image analysis""" + # Download first image + response_img = requests.get(IMAGE_URL) + img_data = base64.b64encode(response_img.content).decode() + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Compare these two images"}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": img_data, + }, + }, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": BASE64_IMAGE, + }, + }, + ], + } + ] + + response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), messages=messages, max_tokens=300 + ) + + assert_valid_image_response(response) + content = response.content[0].text.lower() + # Should mention comparison or differences + assert any( + word in content for word in COMPARISON_KEYWORDS + ), f"Response should contain comparison keywords. Got content: {content}" + + @skip_if_no_api_key("anthropic") + def test_10_complex_end2end(self, anthropic_client, test_config): + """Test Case 10: Complex end-to-end with conversation, images, and tools""" + # Download image for Anthropic format + response_img = requests.get(IMAGE_URL) + img_data = base64.b64encode(response_img.content).decode() + + messages = [ + {"role": "user", "content": "Hello! I need help with some tasks."}, + { + "role": "assistant", + "content": "Hello! I'd be happy to help you with your tasks. What do you need assistance with?", + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "First, can you tell me what's in this image and then get the weather for the location shown?", + }, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": img_data, + }, + }, + ], + }, + ] + + tools = convert_to_anthropic_tools([WEATHER_TOOL]) + + response1 = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=messages, + tools=tools, + max_tokens=300, + ) + + # Should either describe image or call weather tool (or both) + assert len(response1.content) > 0 + + # Add response to conversation + messages.append({"role": "assistant", "content": response1.content}) + + # If there were tool calls, handle them + tool_calls = extract_anthropic_tool_calls(response1) + if tool_calls: + for i, tool_call in enumerate(tool_calls): + tool_response = mock_tool_response( + tool_call["name"], tool_call["arguments"] + ) + + # Find the corresponding tool use ID + tool_use_id = None + for content in response1.content: + if content.type == "tool_use" and content.name == tool_call["name"]: + tool_use_id = content.id + break + + messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": tool_response, + } + ], + } + ) + + # Get final response after tool calls + final_response = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), messages=messages, max_tokens=200 + ) + + # Anthropic might return empty content if tool result is sufficient + # This is valid behavior - just check that we got a response + assert final_response is not None + if len(final_response.content) > 0: + # If there is content, validate it + assert_valid_chat_response(final_response) + else: + # If no content, that's ok too - tool result was sufficient + print("Model returned empty content - tool result was sufficient") + + @skip_if_no_api_key("anthropic") + def test_11_integration_specific_features(self, anthropic_client, test_config): + """Test Case 11: Anthropic-specific features""" + + # Test 1: System message + response1 = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + system="You are a helpful assistant that always responds in exactly 5 words.", + messages=[{"role": "user", "content": "Hello, how are you?"}], + max_tokens=50, + ) + + assert_valid_chat_response(response1) + # Check if response is approximately 5 words (allow some flexibility) + word_count = len(response1.content[0].text.split()) + assert 3 <= word_count <= 7, f"Expected ~5 words, got {word_count}" + + # Test 2: Temperature parameter + response2 = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=[ + {"role": "user", "content": "Tell me a creative story in one sentence."} + ], + temperature=0.9, + max_tokens=100, + ) + + assert_valid_chat_response(response2) + + # Test 3: Tool choice (any tool) + tools = convert_to_anthropic_tools([CALCULATOR_TOOL, WEATHER_TOOL]) + response3 = anthropic_client.messages.create( + model=get_model("anthropic", "chat"), + messages=[{"role": "user", "content": "What's 15 + 27?"}], + tools=tools, + tool_choice={"type": "any"}, # Force tool use + max_tokens=100, + ) + + assert_has_tool_calls(response3) + tool_calls = extract_anthropic_tool_calls(response3) + # Should prefer calculator for math question + assert tool_calls[0]["name"] == "calculate" + + +# Additional helper functions specific to Anthropic +def extract_anthropic_tool_calls(response: Any) -> List[Dict[str, Any]]: + """Extract tool calls from Anthropic response format with proper type checking""" + tool_calls = [] + + # Type check for Anthropic Message response + if not hasattr(response, "content") or not response.content: + return tool_calls + + for content in response.content: + if hasattr(content, "type") and content.type == "tool_use": + if hasattr(content, "name") and hasattr(content, "input"): + try: + tool_calls.append( + {"name": content.name, "arguments": content.input} + ) + except AttributeError as e: + print(f"Warning: Failed to extract tool call from content: {e}") + continue + + return tool_calls diff --git a/tests/transports-integrations/tests/integrations/test_google.py b/tests/transports-integrations/tests/integrations/test_google.py new file mode 100644 index 0000000000..236a2fed19 --- /dev/null +++ b/tests/transports-integrations/tests/integrations/test_google.py @@ -0,0 +1,448 @@ +""" +Google GenAI Integration Tests + +Tests all 11 core scenarios using Google GenAI SDK directly: +1. Simple chat +2. Multi turn conversation +3. Tool calls +4. Multiple tool calls +5. End2End tool calling +6. Automatic function calling +7. Image (url) +8. Image (base64) +9. Multiple images +10. Complete end2end test with conversation history, tool calls, tool results and images +11. Integration specific tests +""" + +import pytest +import base64 +import requests +from PIL import Image +import io +from google import genai +from google.genai.types import HttpOptions +from typing import List, Dict, Any + +from ..utils.common import ( + Config, + SIMPLE_CHAT_MESSAGES, + SINGLE_TOOL_CALL_MESSAGES, + MULTIPLE_TOOL_CALL_MESSAGES, + IMAGE_URL, + BASE64_IMAGE, + WEATHER_TOOL, + CALCULATOR_TOOL, + assert_valid_chat_response, + assert_valid_image_response, + get_api_key, + skip_if_no_api_key, + COMPARISON_KEYWORDS, + WEATHER_KEYWORDS, + LOCATION_KEYWORDS, +) +from ..utils.config_loader import get_model + + +@pytest.fixture +def google_client(): + """Configure Google GenAI client for testing""" + from ..utils.config_loader import get_integration_url + + api_key = get_api_key("google") + base_url = get_integration_url("google") + + client_kwargs = { + "api_key": api_key, + } + + # Add base URL support and timeout through HttpOptions + http_options_kwargs = {} + if base_url: + http_options_kwargs["base_url"] = base_url + + if http_options_kwargs: + client_kwargs["http_options"] = HttpOptions(**http_options_kwargs) + + return genai.Client(**client_kwargs) + + +@pytest.fixture +def test_config(): + """Test configuration""" + return Config() + + +def convert_to_google_messages(messages: List[Dict[str, Any]]) -> str: + """Convert common message format to Google GenAI format""" + # Google GenAI uses a simpler format - just extract the first user message + for msg in messages: + if msg["role"] == "user": + if isinstance(msg["content"], str): + return msg["content"] + elif isinstance(msg["content"], list): + # Handle multimodal content + text_parts = [ + item["text"] for item in msg["content"] if item["type"] == "text" + ] + if text_parts: + return text_parts[0] + return "Hello" + + +def convert_to_google_tools(tools: List[Dict[str, Any]]) -> List[Any]: + """Convert common tool format to Google GenAI format using FunctionDeclaration""" + from google.genai import types + + google_tools = [] + + for tool in tools: + # Create a FunctionDeclaration for each tool + function_declaration = types.FunctionDeclaration( + name=tool["name"], + description=tool["description"], + parameters=types.Schema( + type=tool["parameters"]["type"].upper(), + properties={ + name: types.Schema( + type=prop["type"].upper(), + description=prop.get("description", ""), + ) + for name, prop in tool["parameters"]["properties"].items() + }, + required=tool["parameters"].get("required", []), + ), + ) + + # Create a Tool object containing the function declaration + google_tool = types.Tool(function_declarations=[function_declaration]) + google_tools.append(google_tool) + + return google_tools + + +def load_image_from_url(url: str): + """Load image from URL for Google GenAI""" + from google.genai import types + import io + import base64 + + if url.startswith("data:image"): + # Base64 image - extract the base64 data part + header, data = url.split(",", 1) + img_data = base64.b64decode(data) + image = Image.open(io.BytesIO(img_data)) + else: + # URL image + response = requests.get(url) + image = Image.open(io.BytesIO(response.content)) + + # Resize image to reduce payload size (max width/height of 512px) + max_size = 512 + if image.width > max_size or image.height > max_size: + image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS) + + # Convert to RGB if necessary (for JPEG compatibility) + if image.mode in ("RGBA", "LA", "P"): + # Create a white background + background = Image.new("RGB", image.size, (255, 255, 255)) + if image.mode == "P": + image = image.convert("RGBA") + background.paste( + image, mask=image.split()[-1] if image.mode in ("RGBA", "LA") else None + ) + image = background + + # Convert PIL Image to compressed JPEG bytes + img_byte_arr = io.BytesIO() + image.save(img_byte_arr, format="JPEG", quality=85, optimize=True) + img_byte_arr = img_byte_arr.getvalue() + + # Use the correct Part.from_bytes method as per Google GenAI documentation + return types.Part.from_bytes(data=img_byte_arr, mime_type="image/jpeg") + + +class TestGoogleIntegration: + """Test suite for Google GenAI integration covering all 11 core scenarios""" + + @skip_if_no_api_key("google") + def test_01_simple_chat(self, google_client, test_config): + """Test Case 1: Simple chat interaction""" + message = convert_to_google_messages(SIMPLE_CHAT_MESSAGES) + + response = google_client.models.generate_content( + model=get_model("google", "chat"), contents=message + ) + + assert_valid_chat_response(response) + assert response.text is not None + assert len(response.text) > 0 + + @skip_if_no_api_key("google") + def test_02_multi_turn_conversation(self, google_client, test_config): + """Test Case 2: Multi-turn conversation""" + # Start a chat session for multi-turn + chat = google_client.chats.create(model=get_model("google", "chat")) + + # Send first message + response1 = chat.send_message("What's the capital of France?") + assert_valid_chat_response(response1) + + # Send follow-up message + response2 = chat.send_message("What's the population of that city?") + assert_valid_chat_response(response2) + + content = response2.text.lower() + # Should mention population or numbers since we asked about Paris population + assert any( + word in content + for word in ["population", "million", "people", "inhabitants"] + ) + + @skip_if_no_api_key("google") + def test_03_single_tool_call(self, google_client, test_config): + """Test Case 3: Single tool call""" + from google.genai import types + + tools = convert_to_google_tools([WEATHER_TOOL]) + message = convert_to_google_messages(SINGLE_TOOL_CALL_MESSAGES) + + response = google_client.models.generate_content( + model=get_model("google", "tools"), + contents=message, + config=types.GenerateContentConfig(tools=tools), + ) + + # Check for function calls in response + assert response.candidates is not None + assert len(response.candidates) > 0 + + # Check if function call was made (Google GenAI might return function calls) + if hasattr(response, "function_calls") and response.function_calls: + assert len(response.function_calls) >= 1 + assert response.function_calls[0].name == "get_weather" + + @skip_if_no_api_key("google") + def test_04_multiple_tool_calls(self, google_client, test_config): + """Test Case 4: Multiple tool calls in one response""" + from google.genai import types + + tools = convert_to_google_tools([WEATHER_TOOL, CALCULATOR_TOOL]) + message = convert_to_google_messages(MULTIPLE_TOOL_CALL_MESSAGES) + + response = google_client.models.generate_content( + model=get_model("google", "tools"), + contents=message, + config=types.GenerateContentConfig(tools=tools), + ) + + # Check for function calls + assert response.candidates is not None + + # Check if function calls were made + if hasattr(response, "function_calls") and response.function_calls: + # Should have multiple function calls + assert len(response.function_calls) >= 1 + function_names = [fc.name for fc in response.function_calls] + # At least one of the expected tools should be called + assert any(name in ["get_weather", "calculate"] for name in function_names) + + @skip_if_no_api_key("google") + def test_05_end2end_tool_calling(self, google_client, test_config): + """Test Case 5: Complete tool calling flow with responses""" + from google.genai import types + + tools = convert_to_google_tools([WEATHER_TOOL]) + + # Start chat for tool calling flow + chat = google_client.chats.create(model=get_model("google", "tools")) + + response1 = chat.send_message( + "What's the weather in Boston?", + config=types.GenerateContentConfig(tools=tools), + ) + + # Check if function call was made + if hasattr(response1, "function_calls") and response1.function_calls: + # Simulate function execution and send result back + for fc in response1.function_calls: + if fc.name == "get_weather": + # Mock function result and send back + response2 = chat.send_message( + types.Part.from_function_response( + name=fc.name, + response={ + "result": "The weather in Boston is 72Β°F and sunny." + }, + ) + ) + assert_valid_chat_response(response2) + + content = response2.text.lower() + weather_location_keywords = WEATHER_KEYWORDS + LOCATION_KEYWORDS + assert any(word in content for word in weather_location_keywords) + + @skip_if_no_api_key("google") + def test_06_automatic_function_calling(self, google_client, test_config): + """Test Case 6: Automatic function calling""" + from google.genai import types + + tools = convert_to_google_tools([CALCULATOR_TOOL]) + + response = google_client.models.generate_content( + model=get_model("google", "tools"), + contents="Calculate 25 * 4 for me", + config=types.GenerateContentConfig(tools=tools), + ) + + # Should automatically choose to use the calculator + assert response.candidates is not None + + # Check if function calls were made + if hasattr(response, "function_calls") and response.function_calls: + assert response.function_calls[0].name == "calculate" + + @skip_if_no_api_key("google") + def test_07_image_url(self, google_client, test_config): + """Test Case 7: Image analysis from URL""" + image = load_image_from_url(IMAGE_URL) + + response = google_client.models.generate_content( + model=get_model("google", "vision"), + contents=["What do you see in this image?", image], + ) + + assert_valid_image_response(response) + + @skip_if_no_api_key("google") + def test_08_image_base64(self, google_client, test_config): + """Test Case 8: Image analysis from base64""" + image = load_image_from_url(f"data:image/png;base64,{BASE64_IMAGE}") + + response = google_client.models.generate_content( + model=get_model("google", "vision"), contents=["Describe this image", image] + ) + + assert_valid_image_response(response) + + @skip_if_no_api_key("google") + def test_09_multiple_images(self, google_client, test_config): + """Test Case 9: Multiple image analysis""" + image1 = load_image_from_url(IMAGE_URL) + image2 = load_image_from_url(f"data:image/png;base64,{BASE64_IMAGE}") + + response = google_client.models.generate_content( + model=get_model("google", "vision"), + contents=["Compare these two images", image1, image2], + ) + + assert_valid_image_response(response) + content = response.text.lower() + # Should mention comparison or differences + assert any( + word in content for word in COMPARISON_KEYWORDS + ), f"Response should contain comparison keywords. Got content: {content}" + + @skip_if_no_api_key("google") + def test_10_complex_end2end(self, google_client, test_config): + """Test Case 10: Complex end-to-end with conversation, images, and tools""" + from google.genai import types + + tools = convert_to_google_tools([WEATHER_TOOL]) + + image = load_image_from_url(IMAGE_URL) + + # Start complex conversation + chat = google_client.chats.create(model=get_model("google", "vision")) + + response1 = chat.send_message( + [ + "First, can you tell me what's in this image and then get the weather for the location shown?", + image, + ], + config=types.GenerateContentConfig(tools=tools), + ) + + # Should either describe image or call weather tool (or both) + assert response1.candidates is not None + + # Check for function calls and handle them + if hasattr(response1, "function_calls") and response1.function_calls: + for fc in response1.function_calls: + if fc.name == "get_weather": + # Send function result back + final_response = chat.send_message( + types.Part.from_function_response( + name=fc.name, + response={"result": "The weather is 72Β°F and sunny."}, + ) + ) + assert_valid_chat_response(final_response) + + @skip_if_no_api_key("google") + def test_11_integration_specific_features(self, google_client, test_config): + """Test Case 11: Google GenAI-specific features""" + + # Test 1: Generation config with temperature + from google.genai import types + + response1 = google_client.models.generate_content( + model=get_model("google", "chat"), + contents="Tell me a creative story in one sentence.", + config=types.GenerateContentConfig(temperature=0.9, max_output_tokens=100), + ) + + assert_valid_chat_response(response1) + + # Test 2: Safety settings + response2 = google_client.models.generate_content( + model=get_model("google", "chat"), + contents="Hello, how are you?", + config=types.GenerateContentConfig( + safety_settings=[ + types.SafetySetting( + category="HARM_CATEGORY_HARASSMENT", + threshold="BLOCK_MEDIUM_AND_ABOVE", + ) + ] + ), + ) + + assert_valid_chat_response(response2) + + # Test 3: System instruction + response3 = google_client.models.generate_content( + model=get_model("google", "chat"), + contents="high", + config=types.GenerateContentConfig( + system_instruction="I say high, you say low", + max_output_tokens=10, + ), + ) + + assert_valid_chat_response(response3) + + +# Additional helper functions specific to Google GenAI +def extract_google_function_calls(response: Any) -> List[Dict[str, Any]]: + """Extract function calls from Google GenAI response format with proper type checking""" + function_calls = [] + + # Type check for Google GenAI response + if not hasattr(response, "function_calls") or not response.function_calls: + return function_calls + + for fc in response.function_calls: + if hasattr(fc, "name") and hasattr(fc, "args"): + try: + function_calls.append( + { + "name": fc.name, + "arguments": dict(fc.args) if fc.args else {}, + } + ) + except (AttributeError, TypeError) as e: + print(f"Warning: Failed to extract Google function call: {e}") + continue + + return function_calls diff --git a/tests/transports-integrations/tests/integrations/test_litellm.py b/tests/transports-integrations/tests/integrations/test_litellm.py new file mode 100644 index 0000000000..0c9ff41e6f --- /dev/null +++ b/tests/transports-integrations/tests/integrations/test_litellm.py @@ -0,0 +1,377 @@ +""" +LiteLLM Integration Tests + +πŸ€– MODELS USED: +- Chat: gpt-3.5-turbo (OpenAI via LiteLLM) +- Vision: gpt-4o (OpenAI via LiteLLM) +- Tools: gpt-3.5-turbo (OpenAI via LiteLLM) +- Alternatives: claude-3-haiku-20240307, gemini-pro, gpt-4, command-r-plus + +Tests all 11 core scenarios using LiteLLM SDK directly: +1. Simple chat +2. Multi turn conversation +3. Tool calls +4. Multiple tool calls +5. End2End tool calling +6. Automatic function calling +7. Image (url) +8. Image (base64) +9. Multiple images +10. Complete end2end test with conversation history, tool calls, tool results and images +11. Integration specific tests +""" + +import pytest +import json +import litellm +from typing import List, Dict, Any + +from ..utils.common import ( + Config, + SIMPLE_CHAT_MESSAGES, + MULTI_TURN_MESSAGES, + SINGLE_TOOL_CALL_MESSAGES, + MULTIPLE_TOOL_CALL_MESSAGES, + IMAGE_URL_MESSAGES, + IMAGE_BASE64_MESSAGES, + MULTIPLE_IMAGES_MESSAGES, + COMPLEX_E2E_MESSAGES, + WEATHER_TOOL, + CALCULATOR_TOOL, + mock_tool_response, + assert_valid_chat_response, + assert_has_tool_calls, + assert_valid_image_response, + extract_tool_calls, + COMPARISON_KEYWORDS, + WEATHER_KEYWORDS, + LOCATION_KEYWORDS, +) +from ..utils.config_loader import get_model + + +@pytest.fixture +def test_config(): + """Test configuration""" + return Config() + + +@pytest.fixture(autouse=True) +def setup_litellm(): + """Setup LiteLLM with Bifrost configuration""" + from ..utils.config_loader import get_integration_url, get_config + + # Get Bifrost URL for LiteLLM + base_url = get_integration_url("litellm") + config = get_config() + integration_settings = config.get_integration_settings("litellm") + api_config = config.get_api_config() + + # Configure LiteLLM globally + if base_url: + litellm.api_base = base_url + + # Set timeout and other settings + litellm.request_timeout = api_config.get("timeout", 30) + + # Apply integration-specific settings + if integration_settings.get("drop_params"): + litellm.drop_params = integration_settings["drop_params"] + if integration_settings.get("debug"): + litellm.set_verbose = integration_settings["debug"] + + +def convert_to_litellm_tools(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert common tool format to LiteLLM format (OpenAI-compatible)""" + return [{"type": "function", "function": tool} for tool in tools] + + +class TestLiteLLMIntegration: + """Test suite for LiteLLM integration covering all 11 core scenarios""" + + def test_01_simple_chat(self, test_config): + """Test Case 1: Simple chat interaction""" + response = litellm.completion( + model=get_model("litellm", "chat"), + messages=SIMPLE_CHAT_MESSAGES, + max_tokens=100, + ) + + assert_valid_chat_response(response) + assert response.choices[0].message.content is not None + assert len(response.choices[0].message.content) > 0 + + def test_02_multi_turn_conversation(self, test_config): + """Test Case 2: Multi-turn conversation""" + response = litellm.completion( + model=get_model("litellm", "chat"), + messages=MULTI_TURN_MESSAGES, + max_tokens=150, + ) + + assert_valid_chat_response(response) + content = response.choices[0].message.content.lower() + # Should mention population or numbers since we asked about Paris population + assert any( + word in content + for word in ["population", "million", "people", "inhabitants"] + ) + + def test_03_single_tool_call(self, test_config): + """Test Case 3: Single tool call""" + tools = convert_to_litellm_tools([WEATHER_TOOL]) + + response = litellm.completion( + model=get_model("litellm", "chat"), + messages=SINGLE_TOOL_CALL_MESSAGES, + tools=tools, + max_tokens=100, + ) + + assert_has_tool_calls(response, expected_count=1) + tool_calls = extract_tool_calls(response) + assert tool_calls[0]["name"] == "get_weather" + assert "location" in tool_calls[0]["arguments"] + + def test_04_multiple_tool_calls(self, test_config): + """Test Case 4: Multiple tool calls in one response""" + tools = convert_to_litellm_tools([WEATHER_TOOL, CALCULATOR_TOOL]) + + response = litellm.completion( + model=get_model("litellm", "chat"), + messages=MULTIPLE_TOOL_CALL_MESSAGES, + tools=tools, + max_tokens=200, + ) + + assert_has_tool_calls(response, expected_count=2) + tool_calls = extract_tool_calls(response) + tool_names = [tc["name"] for tc in tool_calls] + assert "get_weather" in tool_names + assert "calculate" in tool_names + + def test_05_end2end_tool_calling(self, test_config): + """Test Case 5: Complete tool calling flow with responses""" + messages = [{"role": "user", "content": "What's the weather in Boston?"}] + tools = convert_to_litellm_tools([WEATHER_TOOL]) + + response = litellm.completion( + model=get_model("litellm", "chat"), + messages=messages, + tools=tools, + max_tokens=100, + ) + + assert_has_tool_calls(response, expected_count=1) + + # Add assistant's tool call to conversation + messages.append(response.choices[0].message) + + # Add tool response + tool_calls = extract_litellm_tool_calls(response) + tool_response = mock_tool_response( + tool_calls[0]["name"], tool_calls[0]["arguments"] + ) + + messages.append( + { + "role": "tool", + "tool_call_id": response.choices[0].message.tool_calls[0].id, + "content": tool_response, + } + ) + + # Get final response + final_response = litellm.completion( + model=get_model("litellm", "chat"), messages=messages, max_tokens=150 + ) + + assert_valid_chat_response(final_response) + content = final_response.choices[0].message.content.lower() + weather_location_keywords = WEATHER_KEYWORDS + LOCATION_KEYWORDS + assert any(word in content for word in weather_location_keywords) + + def test_06_automatic_function_calling(self, test_config): + """Test Case 6: Automatic function calling""" + tools = convert_to_litellm_tools([CALCULATOR_TOOL]) + + response = litellm.completion( + model=get_model("litellm", "chat"), + messages=[{"role": "user", "content": "Calculate 25 * 4 for me"}], + tools=tools, + tool_choice="auto", + max_tokens=100, + ) + + # Should automatically choose to use the calculator + assert_has_tool_calls(response, expected_count=1) + tool_calls = extract_litellm_tool_calls(response) + assert tool_calls[0]["name"] == "calculate" + + def test_07_image_url(self, test_config): + """Test Case 7: Image analysis from URL""" + response = litellm.completion( + model=get_model("litellm", "vision"), + messages=IMAGE_URL_MESSAGES, + max_tokens=200, + ) + + assert_valid_image_response(response) + + def test_08_image_base64(self, test_config): + """Test Case 8: Image analysis from base64""" + response = litellm.completion( + model=get_model("litellm", "vision"), + messages=IMAGE_BASE64_MESSAGES, + max_tokens=200, + ) + + assert_valid_image_response(response) + + def test_09_multiple_images(self, test_config): + """Test Case 9: Multiple image analysis""" + response = litellm.completion( + model=get_model("litellm", "vision"), + messages=MULTIPLE_IMAGES_MESSAGES, + max_tokens=300, + ) + + assert_valid_image_response(response) + content = response.choices[0].message.content.lower() + # Should mention comparison or differences + assert any( + word in content for word in COMPARISON_KEYWORDS + ), f"Response should contain comparison keywords. Got content: {content}" + + def test_10_complex_end2end(self, test_config): + """Test Case 10: Complex end-to-end with conversation, images, and tools""" + messages = COMPLEX_E2E_MESSAGES.copy() + tools = convert_to_litellm_tools([WEATHER_TOOL]) + + # First, analyze the image + response1 = litellm.completion( + model=get_model("litellm", "vision"), + messages=messages, + tools=tools, + max_tokens=300, + ) + + # Should either describe image or call weather tool (or both) + assert ( + response1.choices[0].message.content is not None + or response1.choices[0].message.tool_calls is not None + ) + + # Add response to conversation + messages.append(response1.choices[0].message) + + # If there were tool calls, handle them + if response1.choices[0].message.tool_calls: + for tool_call in response1.choices[0].message.tool_calls: + tool_name = tool_call.function.name + tool_args = json.loads(tool_call.function.arguments) + tool_response = mock_tool_response(tool_name, tool_args) + + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": tool_response, + } + ) + + # Get final response after tool calls + final_response = litellm.completion( + model=get_model("litellm", "vision"), messages=messages, max_tokens=200 + ) + + assert_valid_chat_response(final_response) + + def test_11_integration_specific_features(self, test_config): + """Test Case 11: LiteLLM-specific features""" + + # Test 1: Multiple integrations through LiteLLM + integrations_to_test = [ + "gpt-3.5-turbo", # OpenAI + "claude-3-haiku-20240307", # Anthropic + # Add more integrations as needed + ] + + for model in integrations_to_test: + try: + response = litellm.completion( + model=model, + messages=[{"role": "user", "content": "Hello, how are you?"}], + max_tokens=50, + ) + + assert_valid_chat_response(response) + + except Exception as e: + # Some integrations might not be available, skip gracefully + pytest.skip(f"Integration {model} not available: {e}") + + # Test 2: Function calling with specific tool choice + tools = convert_to_litellm_tools([CALCULATOR_TOOL, WEATHER_TOOL]) + + response2 = litellm.completion( + model=get_model("litellm", "chat"), + messages=[{"role": "user", "content": "What's 15 + 27?"}], + tools=tools, + tool_choice={"type": "function", "function": {"name": "calculate"}}, + max_tokens=100, + ) + + assert_has_tool_calls(response2, expected_count=1) + tool_calls = extract_litellm_tool_calls(response2) + assert tool_calls[0]["name"] == "calculate" + + # Test 3: Temperature and other parameters + response3 = litellm.completion( + model=get_model("litellm", "chat"), + messages=[ + {"role": "user", "content": "Tell me a creative story in one sentence."} + ], + temperature=0.9, + top_p=0.9, + max_tokens=100, + ) + + assert_valid_chat_response(response3) + + +# Additional helper functions specific to LiteLLM +def extract_litellm_tool_calls(response: Any) -> List[Dict[str, Any]]: + """Extract tool calls from LiteLLM response format (OpenAI-compatible) with proper type checking""" + tool_calls = [] + + # Type check for LiteLLM response (OpenAI-compatible format) + if not hasattr(response, "choices") or not response.choices: + return tool_calls + + choice = response.choices[0] + if not hasattr(choice, "message") or not hasattr(choice.message, "tool_calls"): + return tool_calls + + if not choice.message.tool_calls: + return tool_calls + + for tool_call in choice.message.tool_calls: + if hasattr(tool_call, "function") and hasattr(tool_call.function, "name"): + try: + arguments = ( + json.loads(tool_call.function.arguments) + if isinstance(tool_call.function.arguments, str) + else tool_call.function.arguments + ) + tool_calls.append( + { + "name": tool_call.function.name, + "arguments": arguments, + } + ) + except (json.JSONDecodeError, AttributeError) as e: + print(f"Warning: Failed to parse LiteLLM tool call arguments: {e}") + continue + + return tool_calls diff --git a/tests/transports-integrations/tests/integrations/test_openai.py b/tests/transports-integrations/tests/integrations/test_openai.py new file mode 100644 index 0000000000..09aa81ee51 --- /dev/null +++ b/tests/transports-integrations/tests/integrations/test_openai.py @@ -0,0 +1,391 @@ +""" +OpenAI Integration Tests + +πŸ€– MODELS USED: +- Chat: gpt-3.5-turbo +- Vision: gpt-4o +- Tools: gpt-3.5-turbo +- Alternatives: gpt-4, gpt-4-turbo-preview, gpt-4o, gpt-4o-mini + +Tests all 11 core scenarios using OpenAI SDK directly: +1. Simple chat +2. Multi turn conversation +3. Tool calls +4. Multiple tool calls +5. End2End tool calling +6. Automatic function calling +7. Image (url) +8. Image (base64) +9. Multiple images +10. Complete end2end test with conversation history, tool calls, tool results and images +11. Integration specific tests +""" + +import pytest +import json +from openai import OpenAI +from typing import List, Dict, Any + +from ..utils.common import ( + Config, + SIMPLE_CHAT_MESSAGES, + MULTI_TURN_MESSAGES, + SINGLE_TOOL_CALL_MESSAGES, + MULTIPLE_TOOL_CALL_MESSAGES, + IMAGE_URL_MESSAGES, + IMAGE_BASE64_MESSAGES, + MULTIPLE_IMAGES_MESSAGES, + COMPLEX_E2E_MESSAGES, + WEATHER_TOOL, + CALCULATOR_TOOL, + mock_tool_response, + assert_valid_chat_response, + assert_has_tool_calls, + assert_valid_image_response, + extract_tool_calls, + get_api_key, + skip_if_no_api_key, + COMPARISON_KEYWORDS, + WEATHER_KEYWORDS, + LOCATION_KEYWORDS, +) +from ..utils.config_loader import get_model + + +# Helper functions (defined early for use in test methods) +def extract_openai_tool_calls(response: Any) -> List[Dict[str, Any]]: + """Extract tool calls from OpenAI response format with proper type checking""" + tool_calls = [] + + # Type check for OpenAI ChatCompletion response + if not hasattr(response, "choices") or not response.choices: + return tool_calls + + choice = response.choices[0] + if not hasattr(choice, "message") or not hasattr(choice.message, "tool_calls"): + return tool_calls + + if not choice.message.tool_calls: + return tool_calls + + for tool_call in choice.message.tool_calls: + if hasattr(tool_call, "function") and hasattr(tool_call.function, "name"): + try: + arguments = ( + json.loads(tool_call.function.arguments) + if isinstance(tool_call.function.arguments, str) + else tool_call.function.arguments + ) + tool_calls.append( + { + "name": tool_call.function.name, + "arguments": arguments, + } + ) + except (json.JSONDecodeError, AttributeError) as e: + print(f"Warning: Failed to parse tool call arguments: {e}") + continue + + return tool_calls + + +def convert_to_openai_tools(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert common tool format to OpenAI format""" + return [{"type": "function", "function": tool} for tool in tools] + + +@pytest.fixture +def openai_client(): + """Create OpenAI client for testing""" + from ..utils.config_loader import get_integration_url, get_config + + api_key = get_api_key("openai") + base_url = get_integration_url("openai") + + # Get additional integration settings + config = get_config() + integration_settings = config.get_integration_settings("openai") + api_config = config.get_api_config() + + client_kwargs = { + "api_key": api_key, + "base_url": base_url, + "timeout": api_config.get("timeout", 30), + "max_retries": api_config.get("max_retries", 3), + } + + # Add optional OpenAI-specific settings + if integration_settings.get("organization"): + client_kwargs["organization"] = integration_settings["organization"] + if integration_settings.get("project"): + client_kwargs["project"] = integration_settings["project"] + + return OpenAI(**client_kwargs) + + +@pytest.fixture +def test_config(): + """Test configuration""" + return Config() + + +class TestOpenAIIntegration: + """Test suite for OpenAI integration covering all 11 core scenarios""" + + @skip_if_no_api_key("openai") + def test_01_simple_chat(self, openai_client, test_config): + """Test Case 1: Simple chat interaction""" + response = openai_client.chat.completions.create( + model=get_model("openai", "chat"), + messages=SIMPLE_CHAT_MESSAGES, + max_tokens=100, + ) + + assert_valid_chat_response(response) + assert response.choices[0].message.content is not None + assert len(response.choices[0].message.content) > 0 + + @skip_if_no_api_key("openai") + def test_02_multi_turn_conversation(self, openai_client, test_config): + """Test Case 2: Multi-turn conversation""" + response = openai_client.chat.completions.create( + model=get_model("openai", "chat"), + messages=MULTI_TURN_MESSAGES, + max_tokens=150, + ) + + assert_valid_chat_response(response) + content = response.choices[0].message.content.lower() + # Should mention population or numbers since we asked about Paris population + assert any( + word in content + for word in ["population", "million", "people", "inhabitants"] + ) + + @skip_if_no_api_key("openai") + def test_03_single_tool_call(self, openai_client, test_config): + """Test Case 3: Single tool call""" + response = openai_client.chat.completions.create( + model=get_model("openai", "tools"), + messages=SINGLE_TOOL_CALL_MESSAGES, + tools=[{"type": "function", "function": WEATHER_TOOL}], + max_tokens=100, + ) + + assert_has_tool_calls(response, expected_count=1) + tool_calls = extract_tool_calls(response) + assert tool_calls[0]["name"] == "get_weather" + assert "location" in tool_calls[0]["arguments"] + + @skip_if_no_api_key("openai") + def test_04_multiple_tool_calls(self, openai_client, test_config): + """Test Case 4: Multiple tool calls in one response""" + response = openai_client.chat.completions.create( + model=get_model("openai", "tools"), + messages=MULTIPLE_TOOL_CALL_MESSAGES, + tools=[ + {"type": "function", "function": WEATHER_TOOL}, + {"type": "function", "function": CALCULATOR_TOOL}, + ], + max_tokens=200, + ) + + assert_has_tool_calls(response, expected_count=2) + tool_calls = extract_openai_tool_calls(response) + tool_names = [tc["name"] for tc in tool_calls] + assert "get_weather" in tool_names + assert "calculate" in tool_names + + @skip_if_no_api_key("openai") + def test_05_end2end_tool_calling(self, openai_client, test_config): + """Test Case 5: Complete tool calling flow with responses""" + # Initial request + messages = [{"role": "user", "content": "What's the weather in Boston?"}] + + response = openai_client.chat.completions.create( + model=get_model("openai", "tools"), + messages=messages, + tools=[{"type": "function", "function": WEATHER_TOOL}], + max_tokens=100, + ) + + assert_has_tool_calls(response, expected_count=1) + + # Add assistant's tool call to conversation + messages.append(response.choices[0].message) + + # Add tool response + tool_calls = extract_openai_tool_calls(response) + tool_response = mock_tool_response( + tool_calls[0]["name"], tool_calls[0]["arguments"] + ) + + messages.append( + { + "role": "tool", + "tool_call_id": response.choices[0].message.tool_calls[0].id, + "content": tool_response, + } + ) + + # Get final response + final_response = openai_client.chat.completions.create( + model=get_model("openai", "tools"), messages=messages, max_tokens=150 + ) + + assert_valid_chat_response(final_response) + content = final_response.choices[0].message.content.lower() + weather_location_keywords = WEATHER_KEYWORDS + LOCATION_KEYWORDS + assert any(word in content for word in weather_location_keywords) + + @skip_if_no_api_key("openai") + def test_06_automatic_function_calling(self, openai_client, test_config): + """Test Case 6: Automatic function calling (tool_choice='auto')""" + response = openai_client.chat.completions.create( + model=get_model("openai", "tools"), + messages=[{"role": "user", "content": "Calculate 25 * 4 for me"}], + tools=[{"type": "function", "function": CALCULATOR_TOOL}], + tool_choice="auto", # Let model decide + max_tokens=100, + ) + + # Should automatically choose to use the calculator + assert_has_tool_calls(response, expected_count=1) + tool_calls = extract_openai_tool_calls(response) + assert tool_calls[0]["name"] == "calculate" + + @skip_if_no_api_key("openai") + def test_07_image_url(self, openai_client, test_config): + """Test Case 7: Image analysis from URL""" + response = openai_client.chat.completions.create( + model=get_model("openai", "vision"), + messages=IMAGE_URL_MESSAGES, + max_tokens=200, + ) + + assert_valid_image_response(response) + + @skip_if_no_api_key("openai") + def test_08_image_base64(self, openai_client, test_config): + """Test Case 8: Image analysis from base64""" + response = openai_client.chat.completions.create( + model=get_model("openai", "vision"), + messages=IMAGE_BASE64_MESSAGES, + max_tokens=200, + ) + + assert_valid_image_response(response) + + @skip_if_no_api_key("openai") + def test_09_multiple_images(self, openai_client, test_config): + """Test Case 9: Multiple image analysis""" + response = openai_client.chat.completions.create( + model=get_model("openai", "vision"), + messages=MULTIPLE_IMAGES_MESSAGES, + max_tokens=300, + ) + + assert_valid_image_response(response) + content = response.choices[0].message.content.lower() + # Should mention comparison or differences (flexible matching) + assert any( + word in content for word in COMPARISON_KEYWORDS + ), f"Response should contain comparison keywords. Got content: {content}" + + @skip_if_no_api_key("openai") + def test_10_complex_end2end(self, openai_client, test_config): + """Test Case 10: Complex end-to-end with conversation, images, and tools""" + messages = COMPLEX_E2E_MESSAGES.copy() + + # First, analyze the image + response1 = openai_client.chat.completions.create( + model=get_model("openai", "vision"), + messages=messages, + tools=[{"type": "function", "function": WEATHER_TOOL}], + max_tokens=300, + ) + + # Should either describe image or call weather tool (or both) + assert ( + response1.choices[0].message.content is not None + or response1.choices[0].message.tool_calls is not None + ) + + # Add response to conversation + messages.append(response1.choices[0].message) + + # If there were tool calls, handle them + if response1.choices[0].message.tool_calls: + for tool_call in response1.choices[0].message.tool_calls: + tool_name = tool_call.function.name + tool_args = json.loads(tool_call.function.arguments) + tool_response = mock_tool_response(tool_name, tool_args) + + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": tool_response, + } + ) + + # Get final response after tool calls + final_response = openai_client.chat.completions.create( + model=get_model("openai", "vision"), messages=messages, max_tokens=200 + ) + + assert_valid_chat_response(final_response) + + @skip_if_no_api_key("openai") + def test_11_integration_specific_features(self, openai_client, test_config): + """Test Case 11: OpenAI-specific features""" + + # Test 1: Function calling with specific tool choice + response1 = openai_client.chat.completions.create( + model=get_model("openai", "tools"), + messages=[{"role": "user", "content": "What's 15 + 27?"}], + tools=[ + {"type": "function", "function": CALCULATOR_TOOL}, + {"type": "function", "function": WEATHER_TOOL}, + ], + tool_choice={ + "type": "function", + "function": {"name": "calculate"}, + }, # Force specific tool + max_tokens=100, + ) + + assert_has_tool_calls(response1, expected_count=1) + tool_calls = extract_openai_tool_calls(response1) + assert tool_calls[0]["name"] == "calculate" + + # Test 2: System message + response2 = openai_client.chat.completions.create( + model=get_model("openai", "chat"), + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that always responds in exactly 5 words.", + }, + {"role": "user", "content": "Hello, how are you?"}, + ], + max_tokens=50, + ) + + assert_valid_chat_response(response2) + # Check if response is approximately 5 words (allow some flexibility) + word_count = len(response2.choices[0].message.content.split()) + assert 3 <= word_count <= 7, f"Expected ~5 words, got {word_count}" + + # Test 3: Temperature and top_p parameters + response3 = openai_client.chat.completions.create( + model=get_model("openai", "chat"), + messages=[ + {"role": "user", "content": "Tell me a creative story in one sentence."} + ], + temperature=0.9, + top_p=0.9, + max_tokens=100, + ) + + assert_valid_chat_response(response3) diff --git a/tests/transports-integrations/tests/utils/__init__.py b/tests/transports-integrations/tests/utils/__init__.py new file mode 100644 index 0000000000..d0ba24ae94 --- /dev/null +++ b/tests/transports-integrations/tests/utils/__init__.py @@ -0,0 +1 @@ +# Utils package for shared test utilities diff --git a/tests/transports-integrations/tests/utils/common.py b/tests/transports-integrations/tests/utils/common.py new file mode 100644 index 0000000000..64fea1cfbf --- /dev/null +++ b/tests/transports-integrations/tests/utils/common.py @@ -0,0 +1,484 @@ +""" +Common utilities and test data for all integration tests. +This module contains shared functions, test data, and assertions +that can be used across all integration-specific test files. +""" + +import ast +import base64 +import json +import operator +import os +from typing import Dict, List, Any, Optional +from dataclasses import dataclass + + +# Test Configuration +@dataclass +class Config: + """Configuration for test execution""" + + timeout: int = 30 + max_retries: int = 3 + debug: bool = False + + +# Common Test Data +SIMPLE_CHAT_MESSAGES = [{"role": "user", "content": "Hello! How are you today?"}] + +MULTI_TURN_MESSAGES = [ + {"role": "user", "content": "What's the capital of France?"}, + {"role": "assistant", "content": "The capital of France is Paris."}, + {"role": "user", "content": "What's the population of that city?"}, +] + +# Tool Definitions +WEATHER_TOOL = { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit", + }, + }, + "required": ["location"], + }, +} + +CALCULATOR_TOOL = { + "name": "calculate", + "description": "Perform basic mathematical calculations", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Mathematical expression to evaluate, e.g. '2 + 2'", + } + }, + "required": ["expression"], + }, +} + +SEARCH_TOOL = { + "name": "search_web", + "description": "Search the web for information", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string", "description": "Search query"}}, + "required": ["query"], + }, +} + +ALL_TOOLS = [WEATHER_TOOL, CALCULATOR_TOOL, SEARCH_TOOL] + +# Tool Call Test Messages +SINGLE_TOOL_CALL_MESSAGES = [ + {"role": "user", "content": "What's the weather like in San Francisco?"} +] + +MULTIPLE_TOOL_CALL_MESSAGES = [ + {"role": "user", "content": "What's the weather in New York and calculate 15 * 23?"} +] + +# Image Test Data +IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + +# Small test image as base64 (1x1 pixel red PNG) +BASE64_IMAGE = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + +IMAGE_URL_MESSAGES = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do you see in this image?"}, + {"type": "image_url", "image_url": {"url": IMAGE_URL}}, + ], + } +] + +IMAGE_BASE64_MESSAGES = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{BASE64_IMAGE}"}, + }, + ], + } +] + +MULTIPLE_IMAGES_MESSAGES = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Compare these two images"}, + {"type": "image_url", "image_url": {"url": IMAGE_URL}}, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{BASE64_IMAGE}"}, + }, + ], + } +] + +# Complex End-to-End Test Data +COMPLEX_E2E_MESSAGES = [ + {"role": "user", "content": "Hello! I need help with some tasks."}, + { + "role": "assistant", + "content": "Hello! I'd be happy to help you with your tasks. What do you need assistance with?", + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "First, can you tell me what's in this image and then get the weather for the location shown?", + }, + {"type": "image_url", "image_url": {"url": IMAGE_URL}}, + ], + }, +] + + +# Helper Functions +def safe_eval_arithmetic(expression: str) -> float: + """ + Safely evaluate arithmetic expressions using AST parsing. + Only allows basic arithmetic operations: +, -, *, /, **, (), and numbers. + + Args: + expression: String containing arithmetic expression + + Returns: + Evaluated result as float + + Raises: + ValueError: If expression contains unsupported operations + SyntaxError: If expression has invalid syntax + ZeroDivisionError: If division by zero occurs + """ + # Allowed operations mapping + ALLOWED_OPS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.Pow: operator.pow, + ast.USub: operator.neg, + ast.UAdd: operator.pos, + } + + def eval_node(node): + """Recursively evaluate AST nodes""" + if isinstance(node, ast.Constant): # Numbers + return node.value + elif isinstance(node, ast.Num): # Numbers (Python < 3.8 compatibility) + return node.n + elif isinstance(node, ast.UnaryOp): + if type(node.op) in ALLOWED_OPS: + return ALLOWED_OPS[type(node.op)](eval_node(node.operand)) + else: + raise ValueError( + f"Unsupported unary operation: {type(node.op).__name__}" + ) + elif isinstance(node, ast.BinOp): + if type(node.op) in ALLOWED_OPS: + left = eval_node(node.left) + right = eval_node(node.right) + return ALLOWED_OPS[type(node.op)](left, right) + else: + raise ValueError( + f"Unsupported binary operation: {type(node.op).__name__}" + ) + else: + raise ValueError(f"Unsupported expression type: {type(node).__name__}") + + try: + # Parse the expression into an AST + tree = ast.parse(expression, mode="eval") + # Evaluate the AST + return eval_node(tree.body) + except SyntaxError as e: + raise SyntaxError(f"Invalid syntax in expression '{expression}': {e}") + except ZeroDivisionError: + raise ZeroDivisionError(f"Division by zero in expression '{expression}'") + except Exception as e: + raise ValueError(f"Error evaluating expression '{expression}': {e}") + + +def mock_tool_response(tool_name: str, args: Dict[str, Any]) -> str: + """Generate mock responses for tool calls""" + if tool_name == "get_weather": + location = args.get("location", "Unknown") + unit = args.get("unit", "fahrenheit") + return f"The weather in {location} is 72Β°{'F' if unit == 'fahrenheit' else 'C'} and sunny." + + elif tool_name == "calculate": + expression = args.get("expression", "") + try: + # Clean the expression and safely evaluate it + cleaned_expression = expression.replace("x", "*").replace("Γ—", "*") + result = safe_eval_arithmetic(cleaned_expression) + return f"The result of {expression} is {result}" + except (ValueError, SyntaxError, ZeroDivisionError) as e: + return f"Could not calculate {expression}: {e}" + + elif tool_name == "search_web": + query = args.get("query", "") + return f"Here are the search results for '{query}': [Mock search results]" + + return f"Tool {tool_name} executed with args: {args}" + + +def validate_response_structure(response: Any, expected_fields: List[str]) -> bool: + """Validate that a response has the expected structure""" + if not hasattr(response, "__dict__") and not isinstance(response, dict): + return False + + response_dict = response.__dict__ if hasattr(response, "__dict__") else response + + for field in expected_fields: + if field not in response_dict: + return False + + return True + + +def extract_tool_calls(response: Any) -> List[Dict[str, Any]]: + """Extract tool calls from various response formats""" + tool_calls = [] + + # Handle OpenAI format: response.choices[0].message.tool_calls + if hasattr(response, "choices") and len(response.choices) > 0: + choice = response.choices[0] + if ( + hasattr(choice, "message") + and hasattr(choice.message, "tool_calls") + and choice.message.tool_calls + ): + for tool_call in choice.message.tool_calls: + if hasattr(tool_call, "function"): + tool_calls.append( + { + "name": tool_call.function.name, + "arguments": ( + json.loads(tool_call.function.arguments) + if isinstance(tool_call.function.arguments, str) + else tool_call.function.arguments + ), + } + ) + + # Handle direct tool_calls attribute (other formats) + elif hasattr(response, "tool_calls") and response.tool_calls: + for tool_call in response.tool_calls: + if hasattr(tool_call, "function"): + tool_calls.append( + { + "name": tool_call.function.name, + "arguments": ( + json.loads(tool_call.function.arguments) + if isinstance(tool_call.function.arguments, str) + else tool_call.function.arguments + ), + } + ) + + # Handle Anthropic format: response.content with tool_use blocks + elif hasattr(response, "content") and isinstance(response.content, list): + for content in response.content: + if hasattr(content, "type") and content.type == "tool_use": + tool_calls.append({"name": content.name, "arguments": content.input}) + + return tool_calls + + +def assert_valid_chat_response(response: Any, min_length: int = 1): + """Assert that a chat response is valid""" + assert response is not None, "Response should not be None" + + # Extract content from various response formats + content = "" + if hasattr(response, "text"): # Google GenAI + content = response.text + elif hasattr(response, "content"): # Anthropic + if isinstance(response.content, str): + content = response.content + elif isinstance(response.content, list) and len(response.content) > 0: + # Handle list content (like Anthropic) + text_content = [ + c for c in response.content if hasattr(c, "type") and c.type == "text" + ] + if text_content: + content = text_content[0].text + elif hasattr(response, "choices") and len(response.choices) > 0: # OpenAI + # Handle OpenAI format + choice = response.choices[0] + if hasattr(choice, "message") and hasattr(choice.message, "content"): + content = choice.message.content or "" + + assert ( + len(content) >= min_length + ), f"Response content should be at least {min_length} characters, got: {content}" + + +def assert_has_tool_calls(response: Any, expected_count: Optional[int] = None): + """Assert that a response contains tool calls""" + tool_calls = extract_tool_calls(response) + + assert len(tool_calls) > 0, "Response should contain tool calls" + + if expected_count is not None: + assert ( + len(tool_calls) == expected_count + ), f"Expected {expected_count} tool calls, got {len(tool_calls)}" + + # Validate tool call structure + for tool_call in tool_calls: + assert "name" in tool_call, "Tool call should have a name" + assert "arguments" in tool_call, "Tool call should have arguments" + + +def assert_valid_image_response(response: Any): + """Assert that an image analysis response is valid""" + assert_valid_chat_response(response, min_length=10) + + # Extract content for image-specific validation + content = "" + if hasattr(response, "text"): # Google GenAI + content = response.text.lower() + elif hasattr(response, "content"): # Anthropic + if isinstance(response.content, str): + content = response.content.lower() + elif isinstance(response.content, list): + text_content = [ + c for c in response.content if hasattr(c, "type") and c.type == "text" + ] + if text_content: + content = text_content[0].text.lower() + elif hasattr(response, "choices") and len(response.choices) > 0: # OpenAI + choice = response.choices[0] + if hasattr(choice, "message") and hasattr(choice.message, "content"): + content = (choice.message.content or "").lower() + + # Check for image-related keywords + image_keywords = [ + "image", + "picture", + "photo", + "see", + "visual", + "show", + "appear", + "color", + "scene", + ] + has_image_reference = any(keyword in content for keyword in image_keywords) + + assert ( + has_image_reference + ), f"Response should reference the image content. Got: {content}" + + +# Common keyword arrays for flexible assertions +COMPARISON_KEYWORDS = [ + "compare", + "comparison", + "different", + "difference", + "differences", + "both", + "two", + "first", + "second", + "images", + "image", + "versus", + "vs", + "contrast", + "unlike", + "while", + "whereas", +] + +WEATHER_KEYWORDS = [ + "weather", + "temperature", + "sunny", + "cloudy", + "rain", + "snow", + "celsius", + "fahrenheit", + "degrees", + "hot", + "cold", + "warm", + "cool", +] + +LOCATION_KEYWORDS = ["boston", "san francisco", "new york", "city", "location", "place"] + + +# Test Categories +class TestCategories: + """Constants for test categories""" + + SIMPLE_CHAT = "simple_chat" + MULTI_TURN = "multi_turn" + SINGLE_TOOL = "single_tool" + MULTIPLE_TOOLS = "multiple_tools" + E2E_TOOLS = "e2e_tools" + AUTO_FUNCTION = "auto_function" + IMAGE_URL = "image_url" + IMAGE_BASE64 = "image_base64" + MULTIPLE_IMAGES = "multiple_images" + COMPLEX_E2E = "complex_e2e" + INTEGRATION_SPECIFIC = "integration_specific" + + +# Environment helpers +def get_api_key(integration: str) -> str: + """Get API key for a integration from environment variables""" + key_map = { + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "google": "GOOGLE_API_KEY", + "litellm": "LITELLM_API_KEY", + } + + env_var = key_map.get(integration.lower()) + if not env_var: + raise ValueError(f"Unknown integration: {integration}") + + api_key = os.getenv(env_var) + if not api_key: + raise ValueError(f"Missing environment variable: {env_var}") + + return api_key + + +def skip_if_no_api_key(integration: str): + """Decorator to skip tests if API key is not available""" + import pytest + + def decorator(func): + try: + get_api_key(integration) + return func + except ValueError: + return pytest.mark.skip(f"No API key available for {integration}")(func) + + return decorator diff --git a/tests/transports-integrations/tests/utils/config_loader.py b/tests/transports-integrations/tests/utils/config_loader.py new file mode 100644 index 0000000000..ae683d6b0c --- /dev/null +++ b/tests/transports-integrations/tests/utils/config_loader.py @@ -0,0 +1,299 @@ +""" +Configuration loader for Bifrost integration tests. + +This module loads configuration from config.yml and provides utilities +for constructing integration URLs through the Bifrost gateway. +""" + +import os +import yaml +from typing import Dict, Any, Optional +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class BifrostConfig: + """Bifrost gateway configuration""" + + base_url: str + endpoints: Dict[str, str] + + +@dataclass +class IntegrationModels: + """Model configuration for a integration""" + + chat: str + vision: str + tools: str + alternatives: list + + +@dataclass +class TestConfig: + """Complete test configuration""" + + bifrost: BifrostConfig + api: Dict[str, Any] + models: Dict[str, IntegrationModels] + model_capabilities: Dict[str, Dict[str, Any]] + test_settings: Dict[str, Any] + integration_settings: Dict[str, Any] + environments: Dict[str, Any] + logging: Dict[str, Any] + + +class ConfigLoader: + """Configuration loader for Bifrost integration tests""" + + def __init__(self, config_path: Optional[str] = None): + """Initialize configuration loader + + Args: + config_path: Path to config.yml file. If None, looks for config.yml in project root. + """ + if config_path is None: + # Look for config.yml in project root + project_root = Path(__file__).parent.parent.parent + config_path = project_root / "config.yml" + + self.config_path = Path(config_path) + self._config = None + self._load_config() + + def _load_config(self): + """Load configuration from YAML file""" + if not self.config_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {self.config_path}") + + with open(self.config_path, "r") as f: + raw_config = yaml.safe_load(f) + + # Expand environment variables + self._config = self._expand_env_vars(raw_config) + + def _expand_env_vars(self, obj): + """Recursively expand environment variables in configuration""" + if isinstance(obj, dict): + return {k: self._expand_env_vars(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._expand_env_vars(item) for item in obj] + elif isinstance(obj, str): + # Handle ${VAR:-default} syntax + import re + + pattern = r"\$\{([^}]+)\}" + + def replace_var(match): + var_expr = match.group(1) + if ":-" in var_expr: + var_name, default_value = var_expr.split(":-", 1) + return os.getenv(var_name, default_value) + else: + return os.getenv(var_expr, "") + + return re.sub(pattern, replace_var, obj) + else: + return obj + + def get_integration_url(self, integration: str) -> str: + """Get the complete URL for a integration + + Args: + integration: Integration name (openai, anthropic, google, litellm) + + Returns: + Complete URL for the integration + + Examples: + get_integration_url("openai") -> "http://localhost:8080/openai" + """ + bifrost_config = self._config["bifrost"] + base_url = bifrost_config["base_url"] + endpoint = bifrost_config["endpoints"].get(integration, "") + + if not endpoint: + raise ValueError(f"No endpoint configured for integration: {integration}") + + return f"{base_url.rstrip('/')}/{endpoint}" + + def get_bifrost_config(self) -> BifrostConfig: + """Get Bifrost configuration""" + bifrost_data = self._config["bifrost"] + return BifrostConfig( + base_url=bifrost_data["base_url"], endpoints=bifrost_data["endpoints"] + ) + + def get_model(self, integration: str, model_type: str = "chat") -> str: + """Get model name for a integration and type""" + if integration not in self._config["models"]: + raise ValueError(f"Unknown integration: {integration}") + + integration_models = self._config["models"][integration] + + if model_type not in integration_models: + raise ValueError( + f"Unknown model type '{model_type}' for integration '{integration}'" + ) + + return integration_models[model_type] + + def get_model_alternatives(self, integration: str) -> list: + """Get alternative models for a integration""" + if integration not in self._config["models"]: + raise ValueError(f"Unknown integration: {integration}") + + return self._config["models"][integration].get("alternatives", []) + + def get_model_capabilities(self, model: str) -> Dict[str, Any]: + """Get capabilities for a specific model""" + return self._config["model_capabilities"].get( + model, + { + "chat": True, + "tools": False, + "vision": False, + "max_tokens": 4096, + "context_window": 4096, + }, + ) + + def supports_capability(self, model: str, capability: str) -> bool: + """Check if a model supports a specific capability""" + caps = self.get_model_capabilities(model) + return caps.get(capability, False) + + def get_api_config(self) -> Dict[str, Any]: + """Get API configuration (timeout, retries, etc.)""" + return self._config["api"] + + def get_test_settings(self) -> Dict[str, Any]: + """Get test configuration settings""" + return self._config["test_settings"] + + def get_integration_settings(self, integration: str) -> Dict[str, Any]: + """Get integration-specific settings""" + return self._config["integration_settings"].get(integration, {}) + + def get_environment_config(self, environment: str = None) -> Dict[str, Any]: + """Get environment-specific configuration + + Args: + environment: Environment name (development, production, etc.) + If None, uses TEST_ENV environment variable or 'development' + """ + if environment is None: + environment = os.getenv("TEST_ENV", "development") + + return self._config["environments"].get(environment, {}) + + def get_logging_config(self) -> Dict[str, Any]: + """Get logging configuration""" + return self._config["logging"] + + def list_integrations(self) -> list: + """List all configured integrations""" + return list(self._config["bifrost"]["endpoints"].keys()) + + def list_models(self, integration: str = None) -> Dict[str, Any]: + """List all models for a integration or all integrations""" + if integration: + if integration not in self._config["models"]: + raise ValueError(f"Unknown integration: {integration}") + return {integration: self._config["models"][integration]} + + return self._config["models"] + + def validate_config(self) -> bool: + """Validate configuration completeness""" + required_sections = ["bifrost", "models", "api", "test_settings"] + + for section in required_sections: + if section not in self._config: + raise ValueError(f"Missing required configuration section: {section}") + + # Validate Bifrost configuration + bifrost = self._config["bifrost"] + if "base_url" not in bifrost or "endpoints" not in bifrost: + raise ValueError("Bifrost configuration missing base_url or endpoints") + + # Validate that all integrations have model configurations + integrations = list(bifrost["endpoints"].keys()) + for integration in integrations: + if integration not in self._config["models"]: + raise ValueError( + f"No model configuration for integration: {integration}" + ) + + return True + + def print_config_summary(self): + """Print a summary of the configuration""" + print("πŸ”§ BIFROST INTEGRATION TEST CONFIGURATION") + print("=" * 80) + + # Bifrost configuration + bifrost = self.get_bifrost_config() + print(f"\nπŸŒ‰ BIFROST GATEWAY:") + print(f" Base URL: {bifrost.base_url}") + print(f" Endpoints:") + for integration, endpoint in bifrost.endpoints.items(): + full_url = f"{bifrost.base_url.rstrip('/')}/{endpoint}" + print(f" {integration}: {full_url}") + + # Model configurations + print(f"\nπŸ€– MODEL CONFIGURATIONS:") + for integration, models in self._config["models"].items(): + print(f" {integration.upper()}:") + print(f" Chat: {models['chat']}") + print(f" Vision: {models['vision']}") + print(f" Tools: {models['tools']}") + print(f" Alternatives: {len(models['alternatives'])} models") + + # API settings + api_config = self.get_api_config() + print(f"\nβš™οΈ API SETTINGS:") + print(f" Timeout: {api_config['timeout']}s") + print(f" Max Retries: {api_config['max_retries']}") + print(f" Retry Delay: {api_config['retry_delay']}s") + + print(f"\nβœ… Configuration loaded successfully from: {self.config_path}") + + +# Global configuration instance +_config_loader = None + + +def get_config() -> ConfigLoader: + """Get global configuration instance""" + global _config_loader + if _config_loader is None: + _config_loader = ConfigLoader() + return _config_loader + + +def get_integration_url(integration: str) -> str: + return get_config().get_integration_url(integration) + + +def get_model(integration: str, model_type: str = "chat") -> str: + """Convenience function to get model name""" + return get_config().get_model(integration, model_type) + + +def get_model_capabilities(model: str) -> Dict[str, Any]: + """Convenience function to get model capabilities""" + return get_config().get_model_capabilities(model) + + +def supports_capability(model: str, capability: str) -> bool: + """Convenience function to check model capability""" + return get_config().supports_capability(model, capability) + + +if __name__ == "__main__": + # Print configuration summary when run directly + config = get_config() + config.validate_config() + config.print_config_summary() diff --git a/tests/transports-integrations/tests/utils/models.py b/tests/transports-integrations/tests/utils/models.py new file mode 100644 index 0000000000..315e5410c0 --- /dev/null +++ b/tests/transports-integrations/tests/utils/models.py @@ -0,0 +1,66 @@ +""" +Model configurations for each integration. + +This file now acts as a compatibility layer and convenience wrapper +around the new configuration system in config.yml and config_loader.py. + +All model data is now centralized in config.yml for easier maintenance. +""" + +from typing import Dict, List +from dataclasses import dataclass +from .config_loader import get_config + + +@dataclass +class IntegrationModels: + """Model configuration for a integration""" + + chat: str # Primary chat model + vision: str # Vision/multimodal model + tools: str # Function calling model + alternatives: List[str] # Alternative models for testing + + +def get_integration_models() -> Dict[str, IntegrationModels]: + """Get all integration model configurations from config.yml""" + config = get_config() + integration_models = {} + + for integration in config.list_integrations(): + models_config = config.list_models(integration) + integration_models[integration] = IntegrationModels( + chat=models_config["chat"], + vision=models_config["vision"], + tools=models_config["tools"], + alternatives=models_config["alternatives"], + ) + + return integration_models + + +# Backward compatibility - load from config +INTEGRATION_MODELS = get_integration_models() + + +def get_alternatives(integration: str) -> List[str]: + """Get alternative models for a integration""" + config = get_config() + return config.get_model_alternatives(integration) + + +def list_all_models() -> Dict[str, Dict[str, str]]: + """List all models by integration and type""" + config = get_config() + return config.list_models() + + +# Print model summary for documentation +def print_model_summary(): + """Print a summary of all models and their capabilities""" + config = get_config() + config.print_config_summary() + + +if __name__ == "__main__": + print_model_summary() diff --git a/transports/.env.sample b/transports/.env.sample deleted file mode 100644 index 30e582a355..0000000000 --- a/transports/.env.sample +++ /dev/null @@ -1,10 +0,0 @@ -OPENAI_API_KEY = YOUR_OPENAI_API_KEY -ANTHROPIC_API_KEY = YOUR_ANTHROPIC_API_KEY -BEDROCK_API_KEY = YOUR_BEDROCK_API_KEY -BEDROCK_ACCESS_KEY = YOUR_BEDROCK_ACCESS_KEY -COHERE_API_KEY = YOUR_COHERE_API_KEY -AZURE_API_KEY = YOUR_AZURE_API_KEY -AZURE_ENDPOINT = YOUR_AZURE_ENDPOINT - -MAXIM_API_KEY = YOUR_MAXIM_API_KEY -MAXIM_LOGGER_ID = YOUR_MAXIM_LOGGER_ID \ No newline at end of file diff --git a/transports/Dockerfile b/transports/Dockerfile index df9ac9901c..d26d7b31c9 100644 --- a/transports/Dockerfile +++ b/transports/Dockerfile @@ -1,61 +1,52 @@ # --- First Stage: Builder image --- -FROM golang:1.24 AS builder +FROM golang:1.24-alpine AS builder WORKDIR /app +# Install dependencies in a single layer +RUN apk add --no-cache upx + # Set environment for static build -ENV CGO_ENABLED=0 -ENV GOOS=linux -ENV GOARCH=amd64 +ENV CGO_ENABLED=0 GOOS=linux GOARCH=amd64 # Define build-time variable for transport type ARG TRANSPORT_TYPE=http -# Initialize Go module and fetch the bifrost transport package -RUN go mod init bifrost-transports && \ - go get github.com/maximhq/bifrost/transports/${TRANSPORT_TYPE}@latest +# Initialize go module and get bifrost-http +RUN go mod init bifrost-build && \ + go get github.com/maximhq/bifrost/transports/bifrost-${TRANSPORT_TYPE}@latest + +# Build the binary locally +RUN go build \ + -ldflags="-w -s -extldflags '-static'" \ + -a -trimpath \ + -o /app/main \ + github.com/maximhq/bifrost/transports/bifrost-${TRANSPORT_TYPE} + +# Compress binary with upx +RUN upx --best --lzma /app/main -# Build the binary from the fetched package with static linking -RUN go build -ldflags="-w -s" -o /app/main github.com/maximhq/bifrost/transports/${TRANSPORT_TYPE} && \ - test -f /app/main || (echo "Build failed: /app/main not found" && exit 1) && \ - ls -lh /app/main +# Verify build succeeded +RUN test -f /app/main || (echo "Build failed" && exit 1) -# --- Second Stage: Runtime image --- -FROM alpine:latest +# --- Second Stage: Minimal runtime image --- +FROM golang:1.24-alpine WORKDIR /app -# Copy the compiled binary from the builder stage +# Copy necessary files and create structure in one layer COPY --from=builder /app/main . -# Ensure the binary is executable -RUN chmod +x /app/main -# Create a directory to store configuration files -RUN mkdir -p /app/config - -# Define build-time variables for config file paths -ARG CONFIG_PATH -ARG ENV_PATH -ARG PORT -ARG POOL_SIZE -ARG DROP_EXCESS_REQUESTS - -# Set default values if args are not provided -ENV APP_PORT=${PORT:-8080} -ENV APP_POOL_SIZE=${POOL_SIZE:-300} -ENV APP_DROP_EXCESS_REQUESTS=${DROP_EXCESS_REQUESTS:-false} - -# Copy the config and environment files into the image -COPY ${CONFIG_PATH} /app/config/config.json -COPY ${ENV_PATH} /app/config/.env - -# Write a small script to validate config presence and run the app -RUN echo '#!/bin/sh' > /app/entrypoint.sh && \ - echo 'if [ ! -f /app/config/config.json ]; then echo "Missing config.json"; exit 1; fi' >> /app/entrypoint.sh && \ - echo 'if [ ! -f /app/config/.env ]; then echo "Missing .env"; exit 1; fi' >> /app/entrypoint.sh && \ - echo 'if [ ! -f /app/main ]; then echo "Missing main binary"; exit 1; fi' >> /app/entrypoint.sh && \ - echo 'exec /app/main -config /app/config/config.json -env /app/config/.env -port "$APP_PORT" -pool-size "$APP_POOL_SIZE" -drop-excess-requests "$APP_DROP_EXCESS_REQUESTS"' >> /app/entrypoint.sh && \ - chmod +x /app/entrypoint.sh - -# Expose the port defined by argument -EXPOSE ${PORT:-8080} - -# Use the script as the entry point -ENTRYPOINT ["/app/entrypoint.sh"] \ No newline at end of file +RUN mkdir -p /app/config /app/plugins && \ + adduser -D -s /bin/sh appuser && \ + chown -R appuser:appuser /app +USER appuser + +# Environment variables with defaults +ENV APP_PORT=8080 \ + APP_POOL_SIZE=300 \ + APP_DROP_EXCESS_REQUESTS=false \ + APP_PLUGINS="" \ + APP_PROMETHEUS_LABELS="" + +EXPOSE 8080 + +# Direct entrypoint with environment variable expansion +ENTRYPOINT ["/bin/sh", "-c", "exec /app/main -config /app/config/config.json -port \"${APP_PORT}\" -pool-size \"${APP_POOL_SIZE}\" -drop-excess-requests \"${APP_DROP_EXCESS_REQUESTS}\" -plugins \"${APP_PLUGINS}\" -prometheus-labels \"${APP_PROMETHEUS_LABELS}\""] \ No newline at end of file diff --git a/transports/README.md b/transports/README.md index 0f0670a03b..8a519ce5b0 100644 --- a/transports/README.md +++ b/transports/README.md @@ -2,6 +2,8 @@ This package contains clients for various transports that can be used to spin up your Bifrost client with just a single line of code. +πŸ“– **Comprehensive HTTP API documentation is available in** _[`docs/http-transport-api.md`](../docs/http-transport-api.md)_. + ## πŸ“‘ Table of Contents - [Bifrost Transports](#bifrost-transports) @@ -9,12 +11,17 @@ This package contains clients for various transports that can be used to spin up - [πŸš€ Setting Up Transports](#-setting-up-transports) - [Prerequisites](#prerequisites) - [Configuration](#configuration) + - [MCP (Model Context Protocol) Configuration](#mcp-model-context-protocol-configuration) + - [MCP Environment Variables](#mcp-environment-variables) - [Docker Setup](#docker-setup) - [Go Setup](#go-setup) - [🧰 Usage](#-usage) - [Text Completions](#text-completions) - [Chat Completions](#chat-completions) + - [Multi-Turn Conversations with MCP Tools](#multi-turn-conversations-with-mcp-tools) - [πŸ”§ Advanced Features](#-advanced-features) + - [Prometheus Support](#prometheus-support) + - [Plugin Support](#plugin-support) - [Fallbacks](#fallbacks) --- @@ -22,6 +29,7 @@ This package contains clients for various transports that can be used to spin up ## πŸš€ Setting Up Transports ### Prerequisites + - Go 1.23 or higher (if not using Docker) - Access to at least one AI model provider (OpenAI, Anthropic, etc.) - API keys for the providers you wish to use @@ -31,52 +39,173 @@ This package contains clients for various transports that can be used to spin up Bifrost uses a combination of a JSON configuration file and environment variables: 1. **JSON Configuration File**: Bifrost requires a configuration file to set up the gateway. This includes all your provider-level settings, keys, and meta configs for each of your providers. - -2. **Environment Variables**: If you don't want to include your keys in your config file, you can provide a `.env` file and add a prefix of `env.` followed by its key in your `.env` file. +2. **Environment Variables**: If you don't want to include your keys in your config file, you can add a prefix of `env.` followed by its key in your environment. ```json { - "keys": [{ - "value": "env.OPENAI_API_KEY", - "models": ["gpt-4o-mini", "gpt-4-turbo"], - "weight": 1.0 - }] + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o-mini"], + "weight": 1.0 + } + ] + } + } } ``` -In this example, `OPENAI_API_KEY` refers to a key in the `.env` file. At runtime, its value will be used to replace the placeholder. +In this example config file, `OPENAI_API_KEY` refers to a key set in your environment. At runtime, its value will be used to replace the placeholder. The same setup applies to keys in meta configs of all providers: ```json { - "meta_config": { - "secret_access_key": "env.BEDROCK_ACCESS_KEY", - "region": "env.BEDROCK_REGION" + "providers": { + "bedrock": { + "keys": [ + { + "value": "env.BEDROCK_API_KEY", + "models": ["anthropic.claude-v2:1"], + "weight": 1.0 + } + ], + "meta_config": { + "secret_access_key": "env.AWS_SECRET_ACCESS_KEY", + "region": "env.AWS_REGION" + } + } + } +} +``` + +In this example, `AWS_SECRET_ACCESS_KEY` and `AWS_REGION` refer to keys in the environment. + +**Please refer to `config.example.json` for examples.** + +### MCP (Model Context Protocol) Configuration + +Bifrost supports MCP integration for tool usage with AI models. You can configure MCP servers and tools in your configuration file: + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o-mini"], + "weight": 1.0 + } + ] + } + }, + "mcp": { + "client_configs": [ + { + "name": "filesystem", + "connection_type": "stdio", + "stdio_config": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + "envs": ["NODE_ENV", "FILESYSTEM_ROOT"] + }, + "tools_to_skip": [], + "tools_to_execute": [] + }, + { + "name": "web-search", + "connection_type": "http", + "connection_string": "http://localhost:3001/mcp", + "tools_to_skip": [], + "tools_to_execute": [] + }, + { + "name": "real-time-data", + "connection_type": "sse", + "connection_string": "http://localhost:3002/sse", + "tools_to_skip": [], + "tools_to_execute": [] + } + ] + } +} +``` + +#### MCP Environment Variables + +The `envs` field in STDIO MCP configuration serves a different purpose than regular environment variables in Bifrost: + +- **Regular Bifrost environment variables** (like `"env.OPENAI_API_KEY"`) use the `env.` prefix and are accessed directly by Bifrost +- **MCP environment variables** (in the `envs` array) do **NOT** use the `env.` prefix and are not accessed by Bifrost directly + +Instead, Bifrost checks if the environment variables listed in `envs` are available in the environment **before establishing the MCP client connection**. This ensures that MCP tools that require specific environment variables (like API keys or configuration values) have their dependencies available before attempting to connect. + +For example: + +```json +{ + "name": "weather-service", + "connection_type": "stdio", + "stdio_config": { + "command": "npx", + "args": ["weather-mcp-server"], + "envs": ["WEATHER_API_KEY", "DEFAULT_LOCATION"] } } ``` -In this example, `BEDROCK_ACCESS_KEY` and `BEDROCK_REGION` refer to keys in the `.env` file. +In this case, Bifrost will verify that `WEATHER_API_KEY` and `DEFAULT_LOCATION` exist in the environment before attempting to start the weather MCP server. -Please refer to `config.example.json` and `.env.sample` for examples. +**Configuration Summary:** + +- Connects to a filesystem MCP tool via STDIO (requires `NODE_ENV` and `FILESYSTEM_ROOT` environment variables) +- Connects to a web-search MCP service via HTTP + +**For comprehensive MCP documentation including Go package usage, local tool registration, and advanced configurations, see [MCP Integration Guide](../docs/mcp.md).** This section focuses on HTTP transport specific MCP usage. ### Docker Setup -You can run Bifrost using our **independent Dockerfile**. Just copy our Dockerfile and run these commands to get your Bifrost instance up and running: +1. Pull the Docker image: + + ```bash + docker pull maximhq/bifrost + ``` + +2. Run the Docker container: + + ```bash + docker run -p 8080:8080 \ + -v $(pwd)/config.json:/app/config/config.json \ + -e OPENAI_API_KEY \ + -e ANTHROPIC_API_KEY \ + maximhq/bifrost + ``` + +Note: In the command above, `OPENAI_API_KEY` and `ANTHROPIC_API_KEY` are just example environment variables. +Ensure you mount your config file and use the `-e` flag to pass all environment variables referenced in your `config.json` that are prefixed with `env.` to the container. This ensures Docker sets them correctly inside the container. + +Example usage: Suppose your config.json only contains one environment variable placeholder, `env.COHERE_API_KEY`. Here's how you would run it: ```bash -docker build \ - --build-arg CONFIG_PATH=./config.example.json \ - --build-arg ENV_PATH=./.env.sample \ - --build-arg PORT=8080 \ - --build-arg POOL_SIZE=300 \ - -t bifrost-transports . - -docker run -p 8080:8080 bifrost-transports +export COHERE_API_KEY=your_cohere_api_key + +docker run -p 8080:8080 \ + -v $(pwd)/config.example.json:/app/config/config.json \ + -e COHERE_API_KEY \ + maximhq/bifrost ``` -You can also add a flag for `DROP_EXCESS_REQUESTS=false` in your Docker build command to drop excess requests when the buffer is full. Read more about `DROP_EXCESS_REQUESTS` and `POOL_SIZE` [here](https://github.com/maximhq/bifrost/tree/main?tab=README-ov-file#additional-configurations). +You can also set runtime environment variables for configuration: + +- `APP_PORT`: Server port (default: 8080) +- `APP_POOL_SIZE`: Connection pool size (default: 300) +- `APP_DROP_EXCESS_REQUESTS`: Drop excess requests when buffer is full (default: false) +- `APP_PLUGINS`: Comma-separated list of plugins + +Read more about these [configurations](https://github.com/maximhq/bifrost/tree/main?tab=README-ov-file#additional-configurations). --- @@ -86,38 +215,34 @@ If you wish to run Bifrost in your Go environment, follow these steps: 1. Install your binary: -```bash -go install github.com/maximhq/bifrost/transports/http@latest -``` + ```bash + go install github.com/maximhq/bifrost/transports/bifrost-http@latest + ``` -2. Run your binary: +2. Run your binary (ensure Go is in your PATH): -- If it's in your PATH: ```bash -http -config config.json -env .env -port 8080 -pool-size 300 +bifrost-http -config config.json -port 8080 -pool-size 300 ``` -- Otherwise: -```bash -./http -config config.json -env .env -port 8080 -pool-size 300 -``` - -You can also add a flag for `-drop-excess-requests=false` in your command to drop excess requests when the buffer is full. Read more about `DROP_EXCESS_REQUESTS` and `POOL_SIZE` [here](https://github.com/maximhq/bifrost/tree/main?tab=README-ov-file#additional-configurations). +You can also add a flag for `-drop-excess-requests=false` in your command to drop excess requests when the buffer is full. Read more about `DROP_EXCESS_REQUESTS` and `POOL_SIZE` in [additional configurations](https://github.com/maximhq/bifrost/tree/main?tab=README-ov-file#additional-configurations). ## 🧰 Usage Ensure that: + - Bifrost's HTTP server is running - The providers/models you use are configured in your JSON config file ### Text Completions ```bash +# Make sure to set up Anthropic and claude-2.1 in your config.json curl -X POST http://localhost:8080/v1/text/completions \ -H "Content-Type: application/json" \ -d '{ - "provider": "openai", - "model": "gpt-4o-mini", + "provider": "anthropic", + "model": "claude-2.1", "text": "Once upon a time in the land of AI,", "params": { "temperature": 0.7, @@ -145,10 +270,157 @@ curl -X POST http://localhost:8080/v1/chat/completions \ }' ``` +### Multi-Turn Conversations with MCP Tools + +When MCP is configured, Bifrost automatically adds available tools to requests. Here's an example of a multi-turn conversation where the AI uses tools: + +1. **Initial Request** (AI decides to use a tool): + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "model": "gpt-4o-mini", + "messages": [ + {"role": "user", "content": "Can you list the files in the /tmp directory?"} + ] + }' +``` + +Response includes tool calls: + + ```json + { + "data": { + "choices": [ + { + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "list_files", + "arguments": "{\"path\": \"/tmp\"}" + } + } + ] + } + } + ] + } + } + ``` + +2. **Execute Tool** (Use Bifrost's MCP tool execution endpoint): + + ```bash + curl -X POST http://localhost:8080/v1/mcp/tool/execute \ + -H "Content-Type: application/json" \ + -d '{ + "id": "call_abc123", + "type": "function", + "function": { + "name": "list_files", + "arguments": "{\"path\": \"/tmp\"}" + } + }' + ``` + + Response with tool result: + + ```json + { + "role": "tool", + "content": "config.json\nreadme.txt\ndata.csv", + "tool_call_id": "call_abc123" + } + ``` + +3. **Continue Conversation** (Add tool result and get final response): + + ```bash + curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "model": "gpt-4o-mini", + "messages": [ + {"role": "user", "content": "Can you list the files in the /tmp directory?"}, + { + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": "call_abc123", + "type": "function", + "function": { + "name": "list_files", + "arguments": "{\"path\": \"/tmp\"}" + } + }] + }, + { + "role": "tool", + "content": "config.json\nreadme.txt\ndata.csv", + "tool_call_id": "call_abc123" + } + ] + }' + ``` + + Final response: + + ```json + { + "data": { + "choices": [ + { + "message": { + "role": "assistant", + "content": "I found 3 files in the /tmp directory:\n1. config.json\n2. readme.txt\n3. data.csv\n\nWould you like me to read the contents of any of these files?" + } + } + ] + } + } + ``` + +**Tool Execution Flow Summary:** + +1. Send chat completion request β†’ AI responds with tool_calls +2. Send tool_calls to `/v1/mcp/tool/execute` β†’ Get tool_result message +3. Append tool_result to conversation β†’ Send back for final response + +**Key Endpoints:** + +- `POST /v1/chat/completions` - Chat with automatic tool discovery +- `POST /v1/mcp/tool/execute` - Execute tool calls returned by the AI + +> πŸ”§ **For Go package integration and advanced tool execution patterns, see [Implementing Chat Conversations with MCP Tools](../docs/mcp.md#implementing-chat-conversations-with-mcp-tools).** + --- ## πŸ”§ Advanced Features +### Prometheus Support + +HTTP transport supports Prometheus out of the box. By default, all metrics are available at the `/metrics` endpoint. It provides metrics for httpRequestsTotal, httpRequestDuration, httpRequestSizeBytes, httpResponseSizeBytes, bifrostUpstreamRequestsTotal, and bifrostUpstreamLatencySeconds. To add custom labels to these metrics, pass the `-prometheus-labels` flag while running the HTTP transport. + +e.g., `-prometheus-labels team-id,task-id,location` + +Values for labels are then picked up from the HTTP request headers with the prefix `x-bf-prom-`. + +### Plugin Support + +You can explore the [available plugins](https://github.com/maximhq/bifrost/tree/main/plugins). To attach these plugins to your HTTP transport, pass the `-plugins` flag. + +e.g., `-plugins maxim` + +Note: Check plugin-specific documentation (github.com/maximhq/bifrost/tree/main/plugins/{plugin_name}) for more granular control and additional setup requirements. + ### Fallbacks Configure fallback options in your requests: @@ -171,8 +443,6 @@ Configure fallback options in your requests: } ``` -Read more about fallbacks and other additional configurations [here](https://github.com/maximhq/bifrost/tree/main?tab=README-ov-file#additional-configurations). - ---- +Read more about fallbacks and other [additional configurations](https://github.com/maximhq/bifrost/tree/main?tab=README-ov-file#additional-configurations). -Built with ❀️ by [Maxim](https://github.com/maximhq) \ No newline at end of file +Built with ❀️ by [Maxim](https://github.com/maximhq) diff --git a/transports/bifrost-http/integrations/anthropic/router.go b/transports/bifrost-http/integrations/anthropic/router.go new file mode 100644 index 0000000000..81d2275997 --- /dev/null +++ b/transports/bifrost-http/integrations/anthropic/router.go @@ -0,0 +1,41 @@ +package anthropic + +import ( + "errors" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" +) + +// AnthropicRouter holds route registrations for Anthropic endpoints. +// It supports standard chat completions and image-enabled vision capabilities. +type AnthropicRouter struct { + *integrations.GenericRouter +} + +// NewAnthropicRouter creates a new AnthropicRouter with the given bifrost client. +func NewAnthropicRouter(client *bifrost.Bifrost) *AnthropicRouter { + routes := []integrations.RouteConfig{ + { + Path: "/anthropic/v1/messages", + Method: "POST", + GetRequestTypeInstance: func() interface{} { + return &AnthropicMessageRequest{} + }, + RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { + if anthropicReq, ok := req.(*AnthropicMessageRequest); ok { + return anthropicReq.ConvertToBifrostRequest(), nil + } + return nil, errors.New("invalid request type") + }, + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + return DeriveAnthropicFromBifrostResponse(resp), nil + }, + }, + } + + return &AnthropicRouter{ + GenericRouter: integrations.NewGenericRouter(client, routes), + } +} diff --git a/transports/bifrost-http/integrations/anthropic/types.go b/transports/bifrost-http/integrations/anthropic/types.go new file mode 100644 index 0000000000..93a28f7b62 --- /dev/null +++ b/transports/bifrost-http/integrations/anthropic/types.go @@ -0,0 +1,463 @@ +package anthropic + +import ( + "encoding/json" + "fmt" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +var fnTypePtr = bifrost.Ptr(string(schemas.ToolChoiceTypeFunction)) + +// AnthropicContentBlock represents content in Anthropic message format +type AnthropicContentBlock struct { + Type string `json:"type"` // "text", "image", "tool_use", "tool_result" + Text *string `json:"text,omitempty"` // For text content + ToolUseID *string `json:"tool_use_id,omitempty"` // For tool_result content + ID *string `json:"id,omitempty"` // For tool_use content + Name *string `json:"name,omitempty"` // For tool_use content + Input interface{} `json:"input,omitempty"` // For tool_use content + Content AnthropicContent `json:"content,omitempty"` // For tool_result content + Source *AnthropicImageSource `json:"source,omitempty"` // For image content +} + +// AnthropicImageSource represents image source in Anthropic format +type AnthropicImageSource struct { + Type string `json:"type"` // "base64" or "url" + MediaType *string `json:"media_type,omitempty"` // "image/jpeg", "image/png", etc. + Data *string `json:"data,omitempty"` // Base64-encoded image data + URL *string `json:"url,omitempty"` // URL of the image +} + +// AnthropicMessage represents a message in Anthropic format +type AnthropicMessage struct { + Role string `json:"role"` // "user", "assistant" + Content AnthropicContent `json:"content"` // Array of content blocks +} + +type AnthropicContent struct { + ContentStr *string + ContentBlocks *[]AnthropicContentBlock +} + +// AnthropicTool represents a tool in Anthropic format +type AnthropicTool struct { + Name string `json:"name"` + Type *string `json:"type,omitempty"` + Description string `json:"description"` + InputSchema *struct { + Type string `json:"type"` // "object" + Properties map[string]interface{} `json:"properties"` + Required []string `json:"required"` + } `json:"input_schema,omitempty"` +} + +// AnthropicToolChoice represents tool choice in Anthropic format +type AnthropicToolChoice struct { + Type string `json:"type"` // "auto", "any", "tool" + Name string `json:"name,omitempty"` // For type "tool" +} + +// AnthropicMessageRequest represents an Anthropic messages API request +type AnthropicMessageRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + Messages []AnthropicMessage `json:"messages"` + System *AnthropicContent `json:"system,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK *int `json:"top_k,omitempty"` + StopSequences *[]string `json:"stop_sequences,omitempty"` + Stream *bool `json:"stream,omitempty"` + Tools *[]AnthropicTool `json:"tools,omitempty"` + ToolChoice *AnthropicToolChoice `json:"tool_choice,omitempty"` +} + +// AnthropicMessageResponse represents an Anthropic messages API response +type AnthropicMessageResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []AnthropicContentBlock `json:"content"` + Model string `json:"model"` + StopReason *string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` + Usage *AnthropicUsage `json:"usage,omitempty"` +} + +// AnthropicUsage represents usage information in Anthropic format +type AnthropicUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +// MarshalJSON implements custom JSON marshalling for MessageContent. +// It marshals either ContentStr or ContentBlocks directly without wrapping. +func (mc AnthropicContent) MarshalJSON() ([]byte, error) { + // Validation: ensure only one field is set at a time + if mc.ContentStr != nil && mc.ContentBlocks != nil { + return nil, fmt.Errorf("both ContentStr and ContentBlocks are set; only one should be non-nil") + } + + if mc.ContentStr != nil { + return json.Marshal(*mc.ContentStr) + } + if mc.ContentBlocks != nil { + return json.Marshal(*mc.ContentBlocks) + } + // If both are nil, return null + return json.Marshal(nil) +} + +// UnmarshalJSON implements custom JSON unmarshalling for MessageContent. +// It determines whether "content" is a string or array and assigns to the appropriate field. +// It also handles direct string/array content without a wrapper object. +func (mc *AnthropicContent) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var stringContent string + if err := json.Unmarshal(data, &stringContent); err == nil { + mc.ContentStr = &stringContent + return nil + } + + // Try to unmarshal as a direct array of ContentBlock + var arrayContent []AnthropicContentBlock + if err := json.Unmarshal(data, &arrayContent); err == nil { + mc.ContentBlocks = &arrayContent + return nil + } + + return fmt.Errorf("content field is neither a string nor an array of ContentBlock") +} + +// ConvertToBifrostRequest converts an Anthropic messages request to Bifrost format +func (r *AnthropicMessageRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { + bifrostReq := &schemas.BifrostRequest{ + Provider: schemas.Anthropic, + Model: r.Model, + } + + messages := []schemas.BifrostMessage{} + + // Add system message if present + if r.System != nil { + if r.System.ContentStr != nil && *r.System.ContentStr != "" { + messages = append(messages, schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleSystem, + Content: schemas.MessageContent{ + ContentStr: r.System.ContentStr, + }, + }) + } else if r.System.ContentBlocks != nil { + contentBlocks := []schemas.ContentBlock{} + for _, block := range *r.System.ContentBlocks { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: block.Text, + }) + } + messages = append(messages, schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleSystem, + Content: schemas.MessageContent{ + ContentBlocks: &contentBlocks, + }, + }) + } + } + + // Convert messages + for _, msg := range r.Messages { + var bifrostMsg schemas.BifrostMessage + bifrostMsg.Role = schemas.ModelChatMessageRole(msg.Role) + + if msg.Content.ContentStr != nil { + bifrostMsg.Content = schemas.MessageContent{ + ContentStr: msg.Content.ContentStr, + } + } else if msg.Content.ContentBlocks != nil { + // Handle different content types + var toolCalls []schemas.ToolCall + var contentBlocks []schemas.ContentBlock + + for _, content := range *msg.Content.ContentBlocks { + switch content.Type { + case "text": + if content.Text != nil { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: content.Text, + }) + } + case "image": + if content.Source != nil { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: func() string { + if content.Source.Data != nil { + mime := "image/png" + if content.Source.MediaType != nil && *content.Source.MediaType != "" { + mime = *content.Source.MediaType + } + return "data:" + mime + ";base64," + *content.Source.Data + } + if content.Source.URL != nil { + return *content.Source.URL + } + return "" + }(), + }, + }) + } + case "tool_use": + if content.ID != nil && content.Name != nil { + tc := schemas.ToolCall{ + Type: fnTypePtr, + ID: content.ID, + Function: schemas.FunctionCall{ + Name: content.Name, + Arguments: jsonifyInput(content.Input), + }, + } + toolCalls = append(toolCalls, tc) + } + case "tool_result": + if content.ToolUseID != nil { + bifrostMsg.ToolMessage = &schemas.ToolMessage{ + ToolCallID: content.ToolUseID, + } + if content.Content.ContentStr != nil { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: content.Content.ContentStr, + }) + } else if content.Content.ContentBlocks != nil { + for _, block := range *content.Content.ContentBlocks { + if block.Text != nil { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: block.Text, + }) + } else if block.Source != nil { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: func() string { + if block.Source.Data != nil { + mime := "image/png" + if block.Source.MediaType != nil && *block.Source.MediaType != "" { + mime = *block.Source.MediaType + } + return "data:" + mime + ";base64," + *block.Source.Data + } + if block.Source.URL != nil { + return *block.Source.URL + } + return "" + }()}, + }) + } + } + } + bifrostMsg.Role = schemas.ModelChatMessageRoleTool + } + } + } + + // Concatenate all text contents + if len(contentBlocks) > 0 { + bifrostMsg.Content = schemas.MessageContent{ + ContentBlocks: &contentBlocks, + } + } + + if len(toolCalls) > 0 && msg.Role == string(schemas.ModelChatMessageRoleAssistant) { + bifrostMsg.AssistantMessage = &schemas.AssistantMessage{ + ToolCalls: &toolCalls, + } + } + } + messages = append(messages, bifrostMsg) + } + + bifrostReq.Input.ChatCompletionInput = &messages + + // Convert parameters + if r.MaxTokens > 0 || r.Temperature != nil || r.TopP != nil || r.TopK != nil || r.StopSequences != nil { + params := &schemas.ModelParameters{} + + if r.MaxTokens > 0 { + params.MaxTokens = &r.MaxTokens + } + if r.Temperature != nil { + params.Temperature = r.Temperature + } + if r.TopP != nil { + params.TopP = r.TopP + } + if r.TopK != nil { + params.TopK = r.TopK + } + if r.StopSequences != nil { + params.StopSequences = r.StopSequences + } + + bifrostReq.Params = params + } + + // Convert tools + if r.Tools != nil { + tools := []schemas.Tool{} + for _, tool := range *r.Tools { + // Convert input_schema to FunctionParameters + params := schemas.FunctionParameters{ + Type: "object", + } + if tool.InputSchema != nil { + params.Type = tool.InputSchema.Type + params.Required = tool.InputSchema.Required + params.Properties = tool.InputSchema.Properties + } + + tools = append(tools, schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: tool.Name, + Description: tool.Description, + Parameters: params, + }, + }) + } + if bifrostReq.Params == nil { + bifrostReq.Params = &schemas.ModelParameters{} + } + bifrostReq.Params.Tools = &tools + } + + // Convert tool choice + if r.ToolChoice != nil { + if bifrostReq.Params == nil { + bifrostReq.Params = &schemas.ModelParameters{} + } + toolChoice := &schemas.ToolChoice{ + ToolChoiceStruct: &schemas.ToolChoiceStruct{ + Type: func() schemas.ToolChoiceType { + if r.ToolChoice.Type == "tool" { + return schemas.ToolChoiceTypeFunction + } + return schemas.ToolChoiceType(r.ToolChoice.Type) + }(), + }, + } + if r.ToolChoice.Type == "tool" && r.ToolChoice.Name != "" { + toolChoice.ToolChoiceStruct.Function = schemas.ToolChoiceFunction{ + Name: r.ToolChoice.Name, + } + } + bifrostReq.Params.ToolChoice = toolChoice + } + + return bifrostReq +} + +// Helper function to convert interface{} to JSON string +func jsonifyInput(input interface{}) string { + if input == nil { + return "{}" + } + jsonBytes, err := json.Marshal(input) + if err != nil { + return "{}" + } + return string(jsonBytes) +} + +// DeriveAnthropicFromBifrostResponse converts a Bifrost response to Anthropic format +func DeriveAnthropicFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *AnthropicMessageResponse { + if bifrostResp == nil { + return nil + } + + anthropicResp := &AnthropicMessageResponse{ + ID: bifrostResp.ID, + Type: "message", + Role: string(schemas.ModelChatMessageRoleAssistant), + Model: bifrostResp.Model, + } + + // Convert usage information + if bifrostResp.Usage != (schemas.LLMUsage{}) { + anthropicResp.Usage = &AnthropicUsage{ + InputTokens: bifrostResp.Usage.PromptTokens, + OutputTokens: bifrostResp.Usage.CompletionTokens, + } + } + + // Convert choices to content + var content []AnthropicContentBlock + if len(bifrostResp.Choices) > 0 { + choice := bifrostResp.Choices[0] // Anthropic typically returns one choice + + if choice.FinishReason != nil { + anthropicResp.StopReason = choice.FinishReason + } + if choice.StopString != nil { + anthropicResp.StopSequence = choice.StopString + } + + // Add thinking content if present + if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.Thought != nil && *choice.Message.AssistantMessage.Thought != "" { + content = append(content, AnthropicContentBlock{ + Type: "thinking", + Text: choice.Message.AssistantMessage.Thought, + }) + } + + // Add text content + if choice.Message.Content.ContentStr != nil && *choice.Message.Content.ContentStr != "" { + content = append(content, AnthropicContentBlock{ + Type: "text", + Text: choice.Message.Content.ContentStr, + }) + } else if choice.Message.Content.ContentBlocks != nil { + for _, block := range *choice.Message.Content.ContentBlocks { + if block.Text != nil { + content = append(content, AnthropicContentBlock{ + Type: "text", + Text: block.Text, + }) + } + } + } + + // Add tool calls as tool_use content + if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.ToolCalls != nil { + for _, toolCall := range *choice.Message.AssistantMessage.ToolCalls { + // Parse arguments JSON string back to map + var input map[string]interface{} + if toolCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { + input = map[string]interface{}{} + } + } else { + input = map[string]interface{}{} + } + + content = append(content, AnthropicContentBlock{ + Type: "tool_use", + ID: toolCall.ID, + Name: toolCall.Function.Name, + Input: input, + }) + } + } + } + + if content == nil { + content = []AnthropicContentBlock{} + } + + anthropicResp.Content = content + return anthropicResp +} diff --git a/transports/bifrost-http/integrations/genai/router.go b/transports/bifrost-http/integrations/genai/router.go new file mode 100644 index 0000000000..8f0b470a9b --- /dev/null +++ b/transports/bifrost-http/integrations/genai/router.go @@ -0,0 +1,81 @@ +package genai + +import ( + "errors" + "fmt" + "strings" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/valyala/fasthttp" +) + +// GenAIRouter holds route registrations for genai endpoints. +type GenAIRouter struct { + *integrations.GenericRouter +} + +// NewGenAIRouter creates a new GenAIRouter with the given bifrost client. +func NewGenAIRouter(client *bifrost.Bifrost) *GenAIRouter { + routes := []integrations.RouteConfig{ + { + Path: "/genai/v1beta/models/{model}", + Method: "POST", + GetRequestTypeInstance: func() interface{} { + return &GeminiChatRequest{} + }, + RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { + if geminiReq, ok := req.(*GeminiChatRequest); ok { + return geminiReq.ConvertToBifrostRequest(), nil + } + return nil, errors.New("invalid request type") + }, + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + return DeriveGenAIFromBifrostResponse(resp), nil + }, + PreCallback: extractAndSetModelFromURL, + }, + } + + return &GenAIRouter{ + GenericRouter: integrations.NewGenericRouter(client, routes), + } +} + +// extractAndSetModelFromURL extracts model from URL and sets it in the request +func extractAndSetModelFromURL(ctx *fasthttp.RequestCtx, req interface{}) error { + model := ctx.UserValue("model") + if model == nil { + return fmt.Errorf("model parameter is required") + } + + modelStr := model.(string) + // Remove Google GenAI API endpoint suffixes if present + for _, sfx := range []string{ + ":streamGenerateContent", + ":generateContent", + ":countTokens", + } { + modelStr = strings.TrimSuffix(modelStr, sfx) + } + + // Remove trailing colon if present + if len(modelStr) > 0 && modelStr[len(modelStr)-1] == ':' { + modelStr = modelStr[:len(modelStr)-1] + } + + // Add google/ prefix for Bifrost if not already present + processedModel := modelStr + if !strings.HasPrefix(modelStr, "google/") { + processedModel = "google/" + modelStr + } + + // Set the model in the request + if geminiReq, ok := req.(*GeminiChatRequest); ok { + geminiReq.Model = processedModel + return nil + } + + return fmt.Errorf("invalid request type for GenAI") +} diff --git a/transports/bifrost-http/integrations/genai/types.go b/transports/bifrost-http/integrations/genai/types.go new file mode 100644 index 0000000000..70ad4aab98 --- /dev/null +++ b/transports/bifrost-http/integrations/genai/types.go @@ -0,0 +1,608 @@ +package genai + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + genai_sdk "google.golang.org/genai" +) + +var fnTypePtr = bifrost.Ptr(string(schemas.ToolChoiceTypeFunction)) + +// CustomBlob handles URL-safe base64 decoding for Google GenAI requests +type CustomBlob struct { + Data []byte `json:"data,omitempty"` + MIMEType string `json:"mimeType,omitempty"` +} + +// UnmarshalJSON custom unmarshalling to handle URL-safe base64 encoding +func (b *CustomBlob) UnmarshalJSON(data []byte) error { + // First unmarshal into a temporary struct with string data + var temp struct { + Data string `json:"data,omitempty"` + MIMEType string `json:"mimeType,omitempty"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + b.MIMEType = temp.MIMEType + + if temp.Data != "" { + // Convert URL-safe base64 to standard base64 + standardBase64 := strings.ReplaceAll(strings.ReplaceAll(temp.Data, "_", "/"), "-", "+") + + // Add padding if necessary + switch len(standardBase64) % 4 { + case 2: + standardBase64 += "==" + case 3: + standardBase64 += "=" + } + + decoded, err := base64.StdEncoding.DecodeString(standardBase64) + if err != nil { + return fmt.Errorf("failed to decode base64 data: %v", err) + } + b.Data = decoded + } + + return nil +} + +// CustomPart handles Google GenAI Part with custom Blob unmarshalling +type CustomPart struct { + VideoMetadata *genai_sdk.VideoMetadata `json:"videoMetadata,omitempty"` + Thought bool `json:"thought,omitempty"` + CodeExecutionResult *genai_sdk.CodeExecutionResult `json:"codeExecutionResult,omitempty"` + ExecutableCode *genai_sdk.ExecutableCode `json:"executableCode,omitempty"` + FileData *genai_sdk.FileData `json:"fileData,omitempty"` + FunctionCall *genai_sdk.FunctionCall `json:"functionCall,omitempty"` + FunctionResponse *genai_sdk.FunctionResponse `json:"functionResponse,omitempty"` + InlineData *CustomBlob `json:"inlineData,omitempty"` + Text string `json:"text,omitempty"` +} + +// ToGenAIPart converts CustomPart to genai_sdk.Part +func (p *CustomPart) ToGenAIPart() *genai_sdk.Part { + part := &genai_sdk.Part{ + VideoMetadata: p.VideoMetadata, + Thought: p.Thought, + CodeExecutionResult: p.CodeExecutionResult, + ExecutableCode: p.ExecutableCode, + FileData: p.FileData, + FunctionCall: p.FunctionCall, + FunctionResponse: p.FunctionResponse, + Text: p.Text, + } + + if p.InlineData != nil { + part.InlineData = &genai_sdk.Blob{ + Data: p.InlineData.Data, + MIMEType: p.InlineData.MIMEType, + } + } + + return part +} + +// CustomContent handles Google GenAI Content with custom Part unmarshalling +type CustomContent struct { + Parts []*CustomPart `json:"parts,omitempty"` + Role string `json:"role,omitempty"` +} + +// ToGenAIContent converts CustomContent to genai_sdk.Content +func (c *CustomContent) ToGenAIContent() genai_sdk.Content { + parts := make([]*genai_sdk.Part, len(c.Parts)) + for i, part := range c.Parts { + parts[i] = part.ToGenAIPart() + } + + return genai_sdk.Content{ + Parts: parts, + Role: c.Role, + } +} + +// ensureExtraParams ensures that bifrostReq.Params and bifrostReq.Params.ExtraParams are initialized +func ensureExtraParams(bifrostReq *schemas.BifrostRequest) { + if bifrostReq.Params == nil { + bifrostReq.Params = &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + } + if bifrostReq.Params.ExtraParams == nil { + bifrostReq.Params.ExtraParams = make(map[string]interface{}) + } +} + +type GeminiChatRequest struct { + Model string `json:"model,omitempty"` // Model field for explicit model specification + Contents []CustomContent `json:"contents"` + SystemInstruction *CustomContent `json:"systemInstruction,omitempty"` + GenerationConfig genai_sdk.GenerationConfig `json:"generationConfig,omitempty"` + SafetySettings []genai_sdk.SafetySetting `json:"safetySettings,omitempty"` + Tools []genai_sdk.Tool `json:"tools,omitempty"` + ToolConfig genai_sdk.ToolConfig `json:"toolConfig,omitempty"` + Labels map[string]string `json:"labels,omitempty"` + CachedContent string `json:"cachedContent,omitempty"` + ResponseModalities []string `json:"responseModalities,omitempty"` +} + +func (r *GeminiChatRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { + bifrostReq := &schemas.BifrostRequest{ + Provider: schemas.Vertex, + Model: r.Model, + Input: schemas.RequestInput{ + ChatCompletionInput: &[]schemas.BifrostMessage{}, + }, + } + + messages := []schemas.BifrostMessage{} + + allGenAiMessages := []genai_sdk.Content{} + if r.SystemInstruction != nil { + allGenAiMessages = append(allGenAiMessages, r.SystemInstruction.ToGenAIContent()) + } + for _, content := range r.Contents { + allGenAiMessages = append(allGenAiMessages, content.ToGenAIContent()) + } + + for _, content := range allGenAiMessages { + if len(content.Parts) == 0 { + continue + } + + // Handle multiple parts - collect all content and tool calls + var toolCalls []schemas.ToolCall + var contentBlocks []schemas.ContentBlock + var thoughtStr string // Track thought content for assistant/model + + for _, part := range content.Parts { + switch { + case part.Text != "": + // Handle thought content specially for assistant messages + if part.Thought && + (content.Role == string(schemas.ModelChatMessageRoleAssistant) || content.Role == string(genai_sdk.RoleModel)) { + thoughtStr = thoughtStr + part.Text + "\n" + } else { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: &part.Text, + }) + } + + case part.FunctionCall != nil: + // Only add function calls for assistant messages + if content.Role == string(schemas.ModelChatMessageRoleAssistant) || content.Role == string(genai_sdk.RoleModel) { + jsonArgs, err := json.Marshal(part.FunctionCall.Args) + if err != nil { + jsonArgs = []byte(fmt.Sprintf("%v", part.FunctionCall.Args)) + } + id := part.FunctionCall.ID // create local copy + name := part.FunctionCall.Name // create local copy + toolCall := schemas.ToolCall{ + ID: bifrost.Ptr(id), + Type: fnTypePtr, + Function: schemas.FunctionCall{ + Name: &name, + Arguments: string(jsonArgs), + }, + } + toolCalls = append(toolCalls, toolCall) + } + + case part.FunctionResponse != nil: + // Create a separate tool response message + responseContent, err := json.Marshal(part.FunctionResponse.Response) + if err != nil { + responseContent = []byte(fmt.Sprintf("%v", part.FunctionResponse.Response)) + } + + toolResponseMsg := schemas.BifrostMessage{ + Role: schemas.ModelChatMessageRoleTool, + Content: schemas.MessageContent{ + ContentStr: bifrost.Ptr(string(responseContent)), + }, + ToolMessage: &schemas.ToolMessage{ + ToolCallID: &part.FunctionResponse.Name, + }, + } + + messages = append(messages, toolResponseMsg) + + case part.InlineData != nil: + // Handle inline images/media - only append if it's actually an image + if isImageMimeType(part.InlineData.MIMEType) { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MIMEType, base64.StdEncoding.EncodeToString(part.InlineData.Data)), + }, + }) + } + + case part.FileData != nil: + // Handle file data - only append if it's actually an image + if isImageMimeType(part.FileData.MIMEType) { + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeImage, + ImageURL: &schemas.ImageURLStruct{ + URL: part.FileData.FileURI, + }, + }) + } + + case part.ExecutableCode != nil: + // Handle executable code as text content + codeText := fmt.Sprintf("```%s\n%s\n```", part.ExecutableCode.Language, part.ExecutableCode.Code) + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: &codeText, + }) + + case part.CodeExecutionResult != nil: + // Handle code execution results as text content + resultText := fmt.Sprintf("Code execution result (%s):\n%s", part.CodeExecutionResult.Outcome, part.CodeExecutionResult.Output) + contentBlocks = append(contentBlocks, schemas.ContentBlock{ + Type: schemas.ContentBlockTypeText, + Text: &resultText, + }) + } + } + + // Only create message if there's actual content, tool calls, or thought content + if len(contentBlocks) > 0 || len(toolCalls) > 0 || thoughtStr != "" { + // Create main message with content blocks + bifrostMsg := schemas.BifrostMessage{ + Role: func(r string) schemas.ModelChatMessageRole { + if r == string(genai_sdk.RoleModel) { // GenAI's internal alias + return schemas.ModelChatMessageRoleAssistant + } + return schemas.ModelChatMessageRole(r) + }(content.Role), + } + + // Set content only if there are content blocks + if len(contentBlocks) > 0 { + bifrostMsg.Content = schemas.MessageContent{ + ContentBlocks: &contentBlocks, + } + } + + // Set assistant-specific fields for assistant/model messages + if content.Role == string(schemas.ModelChatMessageRoleAssistant) || content.Role == string(genai_sdk.RoleModel) { + if len(toolCalls) > 0 || thoughtStr != "" { + bifrostMsg.AssistantMessage = &schemas.AssistantMessage{} + if len(toolCalls) > 0 { + bifrostMsg.AssistantMessage.ToolCalls = &toolCalls + } + if thoughtStr != "" { + bifrostMsg.AssistantMessage.Thought = &thoughtStr + } + } + } + + messages = append(messages, bifrostMsg) + } + } + + bifrostReq.Input.ChatCompletionInput = &messages + + // Convert generation config to parameters + if params := r.convertGenerationConfigToParams(); params != nil { + bifrostReq.Params = params + } + + // Convert safety settings + if len(r.SafetySettings) > 0 { + ensureExtraParams(bifrostReq) + bifrostReq.Params.ExtraParams["safety_settings"] = r.SafetySettings + } + + // Convert additional request fields + if r.CachedContent != "" { + ensureExtraParams(bifrostReq) + bifrostReq.Params.ExtraParams["cached_content"] = r.CachedContent + } + + // Convert response modalities + if len(r.ResponseModalities) > 0 { + ensureExtraParams(bifrostReq) + bifrostReq.Params.ExtraParams["response_modalities"] = r.ResponseModalities + } + + // Convert labels + if len(r.Labels) > 0 { + ensureExtraParams(bifrostReq) + bifrostReq.Params.ExtraParams["labels"] = r.Labels + } + + // Convert tools and tool config + if len(r.Tools) > 0 { + ensureExtraParams(bifrostReq) + + tools := make([]schemas.Tool, 0, len(r.Tools)) + for _, tool := range r.Tools { + if len(tool.FunctionDeclarations) > 0 { + for _, fn := range tool.FunctionDeclarations { + bifrostTool := schemas.Tool{ + Type: "function", + Function: schemas.Function{ + Name: fn.Name, + Description: fn.Description, + }, + } + // Convert parameters schema if present + if fn.Parameters != nil { + bifrostTool.Function.Parameters = r.convertSchemaToFunctionParameters(fn.Parameters) + } + tools = append(tools, bifrostTool) + } + } + // Handle other tool types (Retrieval, GoogleSearch, etc.) as ExtraParams + if tool.Retrieval != nil { + bifrostReq.Params.ExtraParams["retrieval"] = tool.Retrieval + } + if tool.GoogleSearch != nil { + bifrostReq.Params.ExtraParams["google_search"] = tool.GoogleSearch + } + if tool.CodeExecution != nil { + bifrostReq.Params.ExtraParams["code_execution"] = tool.CodeExecution + } + } + + if len(tools) > 0 { + bifrostReq.Params.Tools = &tools + } + } + + // Convert tool config + if r.ToolConfig.FunctionCallingConfig != nil || r.ToolConfig.RetrievalConfig != nil { + ensureExtraParams(bifrostReq) + bifrostReq.Params.ExtraParams["tool_config"] = r.ToolConfig + } + + return bifrostReq +} + +// convertGenerationConfigToParams converts Gemini GenerationConfig to ModelParameters +func (r *GeminiChatRequest) convertGenerationConfigToParams() *schemas.ModelParameters { + params := &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + + config := r.GenerationConfig + + // Map generation config fields to parameters + if config.Temperature != nil { + temp := float64(*config.Temperature) + params.Temperature = &temp + } + if config.TopP != nil { + params.TopP = bifrost.Ptr(float64(*config.TopP)) + } + if config.TopK != nil { + params.TopK = bifrost.Ptr(int(*config.TopK)) + } + if config.MaxOutputTokens > 0 { + maxTokens := int(config.MaxOutputTokens) + params.MaxTokens = &maxTokens + } + if config.CandidateCount > 0 { + params.ExtraParams["candidate_count"] = config.CandidateCount + } + if len(config.StopSequences) > 0 { + params.StopSequences = &config.StopSequences + } + if config.PresencePenalty != nil { + params.PresencePenalty = bifrost.Ptr(float64(*config.PresencePenalty)) + } + if config.FrequencyPenalty != nil { + params.FrequencyPenalty = bifrost.Ptr(float64(*config.FrequencyPenalty)) + } + if config.Seed != nil { + params.ExtraParams["seed"] = *config.Seed + } + if config.ResponseMIMEType != "" { + params.ExtraParams["response_mime_type"] = config.ResponseMIMEType + } + if config.ResponseLogprobs { + params.ExtraParams["response_logprobs"] = config.ResponseLogprobs + } + if config.Logprobs != nil { + params.ExtraParams["logprobs"] = *config.Logprobs + } + + return params +} + +// convertSchemaToFunctionParameters converts genai.Schema to schemas.FunctionParameters +func (r *GeminiChatRequest) convertSchemaToFunctionParameters(schema *genai_sdk.Schema) schemas.FunctionParameters { + params := schemas.FunctionParameters{ + Type: string(schema.Type), + } + + if schema.Description != "" { + params.Description = &schema.Description + } + + if len(schema.Required) > 0 { + params.Required = schema.Required + } + + if len(schema.Properties) > 0 { + params.Properties = make(map[string]interface{}) + for k, v := range schema.Properties { + params.Properties[k] = v + } + } + + if len(schema.Enum) > 0 { + params.Enum = &schema.Enum + } + + return params +} + +func DeriveGenAIFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *genai_sdk.GenerateContentResponse { + if bifrostResp == nil { + return nil + } + + genaiResp := &genai_sdk.GenerateContentResponse{ + Candidates: make([]*genai_sdk.Candidate, len(bifrostResp.Choices)), + } + + if bifrostResp.Usage != (schemas.LLMUsage{}) { + genaiResp.UsageMetadata = &genai_sdk.GenerateContentResponseUsageMetadata{ + PromptTokenCount: int32(bifrostResp.Usage.PromptTokens), + CandidatesTokenCount: int32(bifrostResp.Usage.CompletionTokens), + TotalTokenCount: int32(bifrostResp.Usage.TotalTokens), + } + } + + for i, choice := range bifrostResp.Choices { + candidate := &genai_sdk.Candidate{ + Index: int32(choice.Index), + } + if choice.FinishReason != nil { + candidate.FinishReason = genai_sdk.FinishReason(*choice.FinishReason) + } + + if bifrostResp.Usage != (schemas.LLMUsage{}) { + candidate.TokenCount = int32(bifrostResp.Usage.CompletionTokens) + } + + parts := []*genai_sdk.Part{} + if choice.Message.Content.ContentStr != nil && *choice.Message.Content.ContentStr != "" { + parts = append(parts, &genai_sdk.Part{Text: *choice.Message.Content.ContentStr}) + } else if choice.Message.Content.ContentBlocks != nil { + for _, block := range *choice.Message.Content.ContentBlocks { + if block.Text != nil { + parts = append(parts, &genai_sdk.Part{Text: *block.Text}) + } + } + } + + // Handle tool calls + if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.ToolCalls != nil { + for _, toolCall := range *choice.Message.AssistantMessage.ToolCalls { + argsMap := make(map[string]interface{}) + if toolCall.Function.Arguments != "" { + // Attempt to unmarshal arguments, but don't fail if it's not valid JSON, + // as BifrostResponse.FunctionCall.Arguments is a string. + // genai.FunctionCall.Args expects map[string]any. + json.Unmarshal([]byte(toolCall.Function.Arguments), &argsMap) + } + if toolCall.Function.Name != nil { + fc := &genai_sdk.FunctionCall{ + Name: *toolCall.Function.Name, + Args: argsMap, + } + if toolCall.ID != nil { + fc.ID = *toolCall.ID + } + parts = append(parts, &genai_sdk.Part{FunctionCall: fc}) + } + } + } + + // Handle thinking content if present + if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.Thought != nil && *choice.Message.AssistantMessage.Thought != "" { + parts = append(parts, &genai_sdk.Part{ + Text: *choice.Message.AssistantMessage.Thought, + Thought: true, + }) + } + + if len(parts) > 0 { + candidate.Content = &genai_sdk.Content{ + Parts: parts, + Role: string(choice.Message.Role), + } + } + + // Handle safety ratings if available (from ExtraFields) + if bifrostResp.ExtraFields.RawResponse != nil { + if rawMap, ok := bifrostResp.ExtraFields.RawResponse.(map[string]interface{}); ok { + if candidates, ok := rawMap["candidates"].([]interface{}); ok && len(candidates) > i { + if candidateMap, ok := candidates[i].(map[string]interface{}); ok { + if safetyRatings, ok := candidateMap["safetyRatings"].([]interface{}); ok { + var ratings []*genai_sdk.SafetyRating + for _, rating := range safetyRatings { + if ratingMap, ok := rating.(map[string]interface{}); ok { + sr := &genai_sdk.SafetyRating{} + if category, ok := ratingMap["category"].(string); ok { + sr.Category = genai_sdk.HarmCategory(category) + } + if probability, ok := ratingMap["probability"].(string); ok { + sr.Probability = genai_sdk.HarmProbability(probability) + } + if blocked, ok := ratingMap["blocked"].(bool); ok { + sr.Blocked = blocked + } + ratings = append(ratings, sr) + } + } + candidate.SafetyRatings = ratings + } + } + } + } + } + + genaiResp.Candidates[i] = candidate + } + + return genaiResp +} + +// isImageMimeType checks if a MIME type represents an image format +func isImageMimeType(mimeType string) bool { + if mimeType == "" { + return false + } + + // Convert to lowercase for case-insensitive comparison + mimeType = strings.ToLower(mimeType) + + // Remove any parameters (e.g., "image/jpeg; charset=utf-8" -> "image/jpeg") + if idx := strings.Index(mimeType, ";"); idx != -1 { + mimeType = strings.TrimSpace(mimeType[:idx]) + } + + // If it starts with "image/", it's an image + if strings.HasPrefix(mimeType, "image/") { + return true + } + + // Check for common image formats that might not have the "image/" prefix + commonImageTypes := []string{ + "jpeg", + "jpg", + "png", + "gif", + "webp", + "bmp", + "svg", + "tiff", + "ico", + "avif", + } + + // Check if the mimeType contains any of the common image type strings + for _, imageType := range commonImageTypes { + if strings.Contains(mimeType, imageType) { + return true + } + } + + return false +} diff --git a/transports/bifrost-http/integrations/litellm/router.go b/transports/bifrost-http/integrations/litellm/router.go new file mode 100644 index 0000000000..f8d2c25464 --- /dev/null +++ b/transports/bifrost-http/integrations/litellm/router.go @@ -0,0 +1,158 @@ +package litellm + +import ( + "encoding/json" + "errors" + "slices" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/anthropic" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/genai" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/openai" + "github.com/valyala/fasthttp" +) + +// LiteLLMRequestWrapper wraps any provider-specific request type +type LiteLLMRequestWrapper struct { + Model string `json:"model"` + ActualRequest interface{} `json:"-"` // This will hold the actual provider-specific request + Provider schemas.ModelProvider `json:"-"` +} + +// LiteLLMRouter holds route registrations for LiteLLM endpoints. +// It supports standard chat completions and image-enabled vision capabilities. +// LiteLLM is fully OpenAI-compatible, so we reuse OpenAI types +// with aliases for clarity and minimal LiteLLM-specific extensions +type LiteLLMRouter struct { + *integrations.GenericRouter +} + +// NewLiteLLMRouter creates a new LiteLLMRouter with the given bifrost client. +func NewLiteLLMRouter(client *bifrost.Bifrost) *LiteLLMRouter { + paths := []string{ + "/chat/completions", + "/v1/messages", + } + + getRequestTypeInstance := func() interface{} { + return &LiteLLMRequestWrapper{} + } + + availableProviders := []schemas.ModelProvider{ + schemas.OpenAI, + schemas.Anthropic, + schemas.Vertex, + schemas.Azure, + } + + // Pre-hook to determine provider and parse request with correct type + preHook := func(ctx *fasthttp.RequestCtx, req interface{}) error { + wrapper, ok := req.(*LiteLLMRequestWrapper) + if !ok { + return errors.New("invalid request wrapper type") + } + + if wrapper.Model == "" { + return errors.New("model field is required") + } + + // Determine provider from model + provider := integrations.GetProviderFromModel(wrapper.Model) + if !slices.Contains(availableProviders, provider) { + return errors.New("unsupported provider: " + string(provider)) + } + + // Get the request body + body := ctx.Request.Body() + if len(body) == 0 { + return errors.New("request body is required") + } + + // Create the appropriate request type based on provider and re-parse + var actualReq interface{} + switch provider { + case schemas.OpenAI, schemas.Azure: + actualReq = &openai.OpenAIChatRequest{} + case schemas.Anthropic: + actualReq = &anthropic.AnthropicMessageRequest{} + case schemas.Vertex: + actualReq = &genai.GeminiChatRequest{} + default: + return errors.New("unsupported provider: " + string(provider)) + } + + // Parse the body into the correct request type + if err := json.Unmarshal(body, actualReq); err != nil { + return errors.New("failed to parse request for provider " + string(provider) + ": " + err.Error()) + } + + // Store the parsed request and provider in the wrapper + wrapper.ActualRequest = actualReq + wrapper.Provider = provider + + return nil + } + + requestConverter := func(req interface{}) (*schemas.BifrostRequest, error) { + wrapper, ok := req.(*LiteLLMRequestWrapper) + if !ok { + return nil, errors.New("invalid request wrapper type") + } + + if wrapper.ActualRequest == nil { + return nil, errors.New("request was not properly processed by pre-hook") + } + + // Handle different provider-specific request types + switch actualReq := wrapper.ActualRequest.(type) { + case *openai.OpenAIChatRequest: + bifrostReq := actualReq.ConvertToBifrostRequest() + bifrostReq.Provider = wrapper.Provider + return bifrostReq, nil + + case *anthropic.AnthropicMessageRequest: + bifrostReq := actualReq.ConvertToBifrostRequest() + bifrostReq.Provider = wrapper.Provider + return bifrostReq, nil + + case *genai.GeminiChatRequest: + bifrostReq := actualReq.ConvertToBifrostRequest() + bifrostReq.Provider = wrapper.Provider + return bifrostReq, nil + + default: + return nil, errors.New("unsupported request type") + } + } + + responseConverter := func(resp *schemas.BifrostResponse) (interface{}, error) { + switch resp.ExtraFields.Provider { + case schemas.OpenAI, schemas.Azure: + return openai.DeriveOpenAIFromBifrostResponse(resp), nil + case schemas.Anthropic: + return anthropic.DeriveAnthropicFromBifrostResponse(resp), nil + case schemas.Vertex: + return genai.DeriveGenAIFromBifrostResponse(resp), nil + default: + return resp, nil + } + } + + routes := []integrations.RouteConfig{} + for _, path := range paths { + routes = append(routes, integrations.RouteConfig{ + Path: "/litellm" + path, + Method: "POST", + GetRequestTypeInstance: getRequestTypeInstance, + RequestConverter: requestConverter, + ResponseConverter: responseConverter, + PreCallback: preHook, + }) + } + + return &LiteLLMRouter{ + GenericRouter: integrations.NewGenericRouter(client, routes), + } +} diff --git a/transports/bifrost-http/integrations/openai/router.go b/transports/bifrost-http/integrations/openai/router.go new file mode 100644 index 0000000000..7371781f0b --- /dev/null +++ b/transports/bifrost-http/integrations/openai/router.go @@ -0,0 +1,41 @@ +package openai + +import ( + "errors" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" +) + +// OpenAIRouter holds route registrations for OpenAI endpoints. +// It supports standard chat completions and image-enabled vision capabilities. +type OpenAIRouter struct { + *integrations.GenericRouter +} + +// NewOpenAIRouter creates a new OpenAIRouter with the given bifrost client. +func NewOpenAIRouter(client *bifrost.Bifrost) *OpenAIRouter { + routes := []integrations.RouteConfig{ + { + Path: "/openai/chat/completions", + Method: "POST", + GetRequestTypeInstance: func() interface{} { + return &OpenAIChatRequest{} + }, + RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { + if openaiReq, ok := req.(*OpenAIChatRequest); ok { + return openaiReq.ConvertToBifrostRequest(), nil + } + return nil, errors.New("invalid request type") + }, + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + return DeriveOpenAIFromBifrostResponse(resp), nil + }, + }, + } + + return &OpenAIRouter{ + GenericRouter: integrations.NewGenericRouter(client, routes), + } +} diff --git a/transports/bifrost-http/integrations/openai/types.go b/transports/bifrost-http/integrations/openai/types.go new file mode 100644 index 0000000000..043b88b740 --- /dev/null +++ b/transports/bifrost-http/integrations/openai/types.go @@ -0,0 +1,129 @@ +package openai + +import ( + "github.com/maximhq/bifrost/core/schemas" +) + +// OpenAIChatRequest represents an OpenAI chat completion request +type OpenAIChatRequest struct { + Model string `json:"model"` + Messages []schemas.BifrostMessage `json:"messages"` + MaxTokens *int `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + N *int `json:"n,omitempty"` + Stop interface{} `json:"stop,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias map[string]float64 `json:"logit_bias,omitempty"` + User *string `json:"user,omitempty"` + Tools *[]schemas.Tool `json:"tools,omitempty"` // Reuse schema type + ToolChoice *schemas.ToolChoice `json:"tool_choice,omitempty"` + Stream *bool `json:"stream,omitempty"` + LogProbs *bool `json:"logprobs,omitempty"` + TopLogProbs *int `json:"top_logprobs,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` +} + +// OpenAIChatResponse represents an OpenAI chat completion response +type OpenAIChatResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + Choices []schemas.BifrostResponseChoice `json:"choices"` + Usage *schemas.LLMUsage `json:"usage,omitempty"` // Reuse schema type + ServiceTier *string `json:"service_tier,omitempty"` + SystemFingerprint *string `json:"system_fingerprint,omitempty"` +} + +// ConvertToBifrostRequest converts an OpenAI chat request to Bifrost format +func (r *OpenAIChatRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { + bifrostReq := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: r.Model, + Input: schemas.RequestInput{ + ChatCompletionInput: &r.Messages, + }, + } + + // Map extra parameters and tool settings + bifrostReq.Params = r.convertParameters() + + return bifrostReq +} + +// convertParameters converts OpenAI request parameters to Bifrost ModelParameters +// using direct field access for better performance and type safety. +func (r *OpenAIChatRequest) convertParameters() *schemas.ModelParameters { + params := &schemas.ModelParameters{ + ExtraParams: make(map[string]interface{}), + } + + params.Tools = r.Tools + params.ToolChoice = r.ToolChoice + + // Direct field mapping + if r.MaxTokens != nil { + params.MaxTokens = r.MaxTokens + } + if r.Temperature != nil { + params.Temperature = r.Temperature + } + if r.TopP != nil { + params.TopP = r.TopP + } + if r.PresencePenalty != nil { + params.PresencePenalty = r.PresencePenalty + } + if r.FrequencyPenalty != nil { + params.FrequencyPenalty = r.FrequencyPenalty + } + if r.N != nil { + params.ExtraParams["n"] = *r.N + } + if r.LogProbs != nil { + params.ExtraParams["logprobs"] = *r.LogProbs + } + if r.TopLogProbs != nil { + params.ExtraParams["top_logprobs"] = *r.TopLogProbs + } + if r.Stop != nil { + params.ExtraParams["stop"] = r.Stop + } + if r.LogitBias != nil { + params.ExtraParams["logit_bias"] = r.LogitBias + } + if r.User != nil { + params.ExtraParams["user"] = *r.User + } + if r.Stream != nil { + params.ExtraParams["stream"] = *r.Stream + } + if r.Seed != nil { + params.ExtraParams["seed"] = *r.Seed + } + + return params +} + +// DeriveOpenAIFromBifrostResponse converts a Bifrost response to OpenAI format +func DeriveOpenAIFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *OpenAIChatResponse { + if bifrostResp == nil { + return nil + } + + openaiResp := &OpenAIChatResponse{ + ID: bifrostResp.ID, + Object: bifrostResp.Object, + Created: bifrostResp.Created, + Model: bifrostResp.Model, + Choices: bifrostResp.Choices, + Usage: &bifrostResp.Usage, + ServiceTier: bifrostResp.ServiceTier, + SystemFingerprint: bifrostResp.SystemFingerprint, + } + + return openaiResp +} diff --git a/transports/bifrost-http/integrations/utils.go b/transports/bifrost-http/integrations/utils.go new file mode 100644 index 0000000000..de99c4024c --- /dev/null +++ b/transports/bifrost-http/integrations/utils.go @@ -0,0 +1,363 @@ +package integrations + +import ( + "encoding/json" + "fmt" + "log" + "strings" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// ExtensionRouter defines the interface that all integration routers must implement +// to register their routes with the main HTTP router. +type ExtensionRouter interface { + RegisterRoutes(r *router.Router) +} + +// RequestConverter is a function that converts integration-specific requests to Bifrost format. +// It takes the parsed request object and returns a BifrostRequest ready for processing. +type RequestConverter func(req interface{}) (*schemas.BifrostRequest, error) + +// ResponseConverter is a function that converts Bifrost responses to integration-specific format. +// It takes a BifrostResponse and returns the format expected by the specific integration. +type ResponseConverter func(*schemas.BifrostResponse) (interface{}, error) + +// PreRequestCallback is called before processing the request. +// It can be used to modify the request object (e.g., extract model from URL parameters) +// or perform validation. If it returns an error, the request processing stops. +type PreRequestCallback func(ctx *fasthttp.RequestCtx, req interface{}) error + +// PostRequestCallback is called after processing the request but before sending the response. +// It can be used to modify the response or perform additional logging/metrics. +// If it returns an error, an error response is sent instead of the success response. +type PostRequestCallback func(ctx *fasthttp.RequestCtx, req interface{}, resp *schemas.BifrostResponse) error + +// RouteConfig defines configuration for a single HTTP route in an integration. +// Each route specifies how to handle requests for a specific endpoint. +type RouteConfig struct { + Path string // HTTP path pattern (e.g., "/openai/v1/chat/completions") + Method string // HTTP method (POST, GET, PUT, DELETE) + GetRequestTypeInstance func() interface{} // Factory function to create request instance (SHOULD NOT BE NIL) + RequestConverter RequestConverter // Function to convert request to BifrostRequest (SHOULD NOT BE NIL) + ResponseConverter ResponseConverter // Function to convert BifrostResponse to integration format (SHOULD NOT BE NIL) + PreCallback PreRequestCallback // Optional: called before request processing + PostCallback PostRequestCallback // Optional: called after request processing +} + +// GenericRouter provides a reusable router implementation for all integrations. +// It handles the common flow of: parse request β†’ convert to Bifrost β†’ execute β†’ convert response. +// Integration-specific logic is handled through the RouteConfig callbacks and converters. +type GenericRouter struct { + client *bifrost.Bifrost // Bifrost client for executing requests + routes []RouteConfig // List of route configurations +} + +// NewGenericRouter creates a new generic router with the given bifrost client and route configurations. +// Each integration should create their own routes and pass them to this constructor. +func NewGenericRouter(client *bifrost.Bifrost, routes []RouteConfig) *GenericRouter { + return &GenericRouter{ + client: client, + routes: routes, + } +} + +// RegisterRoutes registers all configured routes on the given fasthttp router. +// This method implements the ExtensionRouter interface. +func (g *GenericRouter) RegisterRoutes(r *router.Router) { + for _, route := range g.routes { + // Validate route configuration at startup to fail fast + if route.GetRequestTypeInstance == nil { + log.Println("[WARN] route configuration is invalid: GetRequestTypeInstance cannot be nil for route " + route.Path) + continue + } + if route.RequestConverter == nil { + log.Println("[WARN] route configuration is invalid: RequestConverter cannot be nil for route " + route.Path) + continue + } + if route.ResponseConverter == nil { + log.Println("[WARN] route configuration is invalid: ResponseConverter cannot be nil for route " + route.Path) + continue + } + + // Test that GetRequestTypeInstance returns a valid instance + if testInstance := route.GetRequestTypeInstance(); testInstance == nil { + log.Println("[WARN] route configuration is invalid: GetRequestTypeInstance returned nil for route " + route.Path) + continue + } + + handler := g.createHandler(route) + switch strings.ToUpper(route.Method) { + case fasthttp.MethodPost: + r.POST(route.Path, handler) + case fasthttp.MethodGet: + r.GET(route.Path, handler) + case fasthttp.MethodPut: + r.PUT(route.Path, handler) + case fasthttp.MethodDelete: + r.DELETE(route.Path, handler) + default: + r.POST(route.Path, handler) // Default to POST + } + } +} + +// createHandler creates a fasthttp handler for the given route configuration. +// The handler follows this flow: +// 1. Parse JSON request body into the configured request type (for methods that expect bodies) +// 2. Execute pre-callback (if configured) for request modification/validation +// 3. Convert request to BifrostRequest using the configured converter +// 4. Execute the request through Bifrost +// 5. Execute post-callback (if configured) for response modification +// 6. Convert and send the response using the configured response converter +func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + // Parse request body into the integration-specific request type + // Note: config validation is performed at startup in RegisterRoutes + req := config.GetRequestTypeInstance() + + method := string(ctx.Method()) + + if method != fasthttp.MethodGet && method != fasthttp.MethodDelete { + // Use ctx.Request.Body() instead of ctx.PostBody() to support all HTTP methods + body := ctx.Request.Body() + if len(body) > 0 { + if err := json.Unmarshal(body, req); err != nil { + g.sendError(ctx, newBifrostError(err, "Invalid JSON")) + return + } + } + } + + // Execute pre-request callback if configured + // This is typically used for extracting data from URL parameters + // or performing request-specific validation + if config.PreCallback != nil { + if err := config.PreCallback(ctx, req); err != nil { + g.sendError(ctx, newBifrostError(err, "failed to execute pre-request callback")) + return + } + } + + // Convert the integration-specific request to Bifrost format + bifrostReq, err := config.RequestConverter(req) + if err != nil { + g.sendError(ctx, newBifrostError(err, "failed to convert request to Bifrost format")) + return + } + if bifrostReq == nil { + g.sendError(ctx, newBifrostError(nil, "Invalid request")) + return + } + if bifrostReq.Model == "" { + g.sendError(ctx, newBifrostError(nil, "Model parameter is required")) + return + } + + // Execute the request through Bifrost + bifrostCtx := lib.ConvertToBifrostContext(ctx) + result, bifrostErr := g.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) + if bifrostErr != nil { + g.sendError(ctx, bifrostErr) + return + } + // Execute post-request callback if configured + // This is typically used for response modification or additional processing + if config.PostCallback != nil { + if err := config.PostCallback(ctx, req, result); err != nil { + g.sendError(ctx, newBifrostError(err, "failed to execute post-request callback")) + return + } + } + + if result == nil { + g.sendError(ctx, newBifrostError(nil, "Bifrost response is nil after post-request callback")) + return + } + + // Convert Bifrost response to integration-specific format and send + response, err := config.ResponseConverter(result) + if err != nil { + g.sendError(ctx, newBifrostError(err, "failed to encode response")) + return + } + g.sendSuccess(ctx, response) + } +} + +// sendError sends an error response with the appropriate status code and JSON body. +// It handles different error types (string, error interface, or arbitrary objects). +func (g *GenericRouter) sendError(ctx *fasthttp.RequestCtx, err *schemas.BifrostError) { + if err.StatusCode != nil { + ctx.SetStatusCode(*err.StatusCode) + } else { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + } + + ctx.SetContentType("application/json") + if encodeErr := json.NewEncoder(ctx).Encode(err); encodeErr != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBodyString(fmt.Sprintf("failed to encode error response: %v", encodeErr)) + } +} + +// sendSuccess sends a successful response with HTTP 200 status and JSON body. +func (g *GenericRouter) sendSuccess(ctx *fasthttp.RequestCtx, response interface{}) { + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + + responseBody, err := json.Marshal(response) + if err != nil { + g.sendError(ctx, newBifrostError(err, "failed to encode response")) + return + } + + ctx.SetBody(responseBody) +} + +// GetProviderFromModel determines the appropriate provider based on model name patterns +// This function uses comprehensive pattern matching to identify the correct provider +// for various model naming conventions used across different AI providers. +func GetProviderFromModel(model string) schemas.ModelProvider { + // Normalize model name for case-insensitive matching + modelLower := strings.ToLower(strings.TrimSpace(model)) + + // Azure OpenAI Models - check first to prevent false positives from OpenAI "gpt" patterns + if isAzureModel(modelLower) { + return schemas.Azure + } + + // OpenAI Models - comprehensive pattern matching + if isOpenAIModel(modelLower) { + return schemas.OpenAI + } + + // Anthropic Models - Claude family + if isAnthropicModel(modelLower) { + return schemas.Anthropic + } + + // Google Vertex AI Models - Gemini and Palm family + if isVertexModel(modelLower) { + return schemas.Vertex + } + + // AWS Bedrock Models - various model providers through Bedrock + if isBedrockModel(modelLower) { + return schemas.Bedrock + } + + // Cohere Models - Command and Embed family + if isCohereModel(modelLower) { + return schemas.Cohere + } + + // Default to OpenAI for unknown models (most LiteLLM compatible) + return schemas.OpenAI +} + +// isOpenAIModel checks for OpenAI model patterns +func isOpenAIModel(model string) bool { + // Exclude Azure models to prevent overlap + if strings.Contains(model, "azure/") { + return false + } + + openaiPatterns := []string{ + "gpt", "davinci", "curie", "babbage", "ada", "o1", "o3", "o4", + "text-embedding", "dall-e", "whisper", "tts", "chatgpt", + } + + return matchesAnyPattern(model, openaiPatterns) +} + +// isAzureModel checks for Azure OpenAI specific patterns +func isAzureModel(model string) bool { + azurePatterns := []string{ + "azure", "model-router", "computer-use-preview", + } + + return matchesAnyPattern(model, azurePatterns) +} + +// isAnthropicModel checks for Anthropic Claude model patterns +func isAnthropicModel(model string) bool { + anthropicPatterns := []string{ + "claude", "anthropic/", + } + + return matchesAnyPattern(model, anthropicPatterns) +} + +// isVertexModel checks for Google Vertex AI model patterns +func isVertexModel(model string) bool { + vertexPatterns := []string{ + "gemini", "palm", "bison", "gecko", "vertex/", "google/", + } + + return matchesAnyPattern(model, vertexPatterns) +} + +// isBedrockModel checks for AWS Bedrock model patterns +func isBedrockModel(model string) bool { + bedrockPatterns := []string{ + "bedrock", "bedrock.amazonaws.com/", "bedrock/", + "amazon.titan", "amazon.nova", "aws/amazon.", + "ai21.jamba", "ai21.j2", "aws/ai21.", + "meta.llama", "aws/meta.", + "stability.stable-diffusion", "stability.sd3", "aws/stability.", + "anthropic.claude", "aws/anthropic.", + "cohere.command", "cohere.embed", "aws/cohere.", + "mistral.mistral", "mistral.mixtral", "aws/mistral.", + "titan-text", "titan-embed", "nova-micro", "nova-lite", "nova-pro", + "jamba-instruct", "j2-ultra", "j2-mid", + "llama-2", "llama-3", "llama-3.1", "llama-3.2", + "stable-diffusion-xl", "sd3-large", + } + + return matchesAnyPattern(model, bedrockPatterns) +} + +// isCohereModel checks for Cohere model patterns +func isCohereModel(model string) bool { + coherePatterns := []string{ + "command-", "embed-", "cohere", + } + + return matchesAnyPattern(model, coherePatterns) +} + +// matchesAnyPattern checks if the model matches any of the given patterns +func matchesAnyPattern(model string, patterns []string) bool { + for _, pattern := range patterns { + if strings.Contains(model, pattern) { + return true + } + } + return false +} + +// newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false. +// This helper function reduces code duplication when handling non-Bifrost errors. +func newBifrostError(err error, message string) *schemas.BifrostError { + if err == nil { + return &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: message, + }, + } + } + + return &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: message, + Error: err, + }, + } +} diff --git a/transports/bifrost-http/lib/account.go b/transports/bifrost-http/lib/account.go new file mode 100644 index 0000000000..0e96c6abd0 --- /dev/null +++ b/transports/bifrost-http/lib/account.go @@ -0,0 +1,175 @@ +// Package lib provides core functionality for the Bifrost HTTP service, +// including context propagation, header management, and integration with monitoring systems. +package lib + +import ( + "errors" + "fmt" + "os" + "reflect" + "strings" + "sync" + + "github.com/maximhq/bifrost/core/schemas" +) + +// BaseAccount implements the Account interface for Bifrost. +// It manages provider configurations and API keys. +type BaseAccount struct { + Config ConfigMap // Map of provider configurations + mu sync.Mutex // Mutex to protect Config access +} + +// GetConfiguredProviders returns a list of all configured providers. +// Implements the Account interface. +func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + baseAccount.mu.Lock() + defer baseAccount.mu.Unlock() + + providers := make([]schemas.ModelProvider, 0, len(baseAccount.Config)) + for provider := range baseAccount.Config { + providers = append(providers, provider) + } + return providers, nil +} + +// GetKeysForProvider returns the API keys configured for a specific provider. +// Implements the Account interface. +func (baseAccount *BaseAccount) GetKeysForProvider(providerKey schemas.ModelProvider) ([]schemas.Key, error) { + baseAccount.mu.Lock() + defer baseAccount.mu.Unlock() + + return baseAccount.Config[providerKey].Keys, nil +} + +// GetConfigForProvider returns the complete configuration for a specific provider. +// Implements the Account interface. +func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + baseAccount.mu.Lock() + defer baseAccount.mu.Unlock() + + config, exists := baseAccount.Config[providerKey] + if !exists { + return nil, errors.New("config for provider not found") + } + + providerConfig := &schemas.ProviderConfig{} + + if config.NetworkConfig != nil { + providerConfig.NetworkConfig = *config.NetworkConfig + } else { + providerConfig.NetworkConfig = schemas.DefaultNetworkConfig + } + + if config.MetaConfig != nil { + providerConfig.MetaConfig = *config.MetaConfig + } + + if config.ConcurrencyAndBufferSize != nil { + providerConfig.ConcurrencyAndBufferSize = *config.ConcurrencyAndBufferSize + } else { + providerConfig.ConcurrencyAndBufferSize = schemas.DefaultConcurrencyAndBufferSize + } + + return providerConfig, nil +} + +// ReadKeys reads environment variables from the environment and updates the provider configurations. +// It replaces values starting with "env." in the config with actual values from the environment. +// Returns an error if any required environment variable is missing. +func (baseAccount *BaseAccount) ReadKeys() error { + // Helper function to check and replace env values + replaceEnvValue := func(value string) (string, error) { + if strings.HasPrefix(value, "env.") { + envKey := strings.TrimPrefix(value, "env.") + if envValue := os.Getenv(envKey); envValue != "" { + return envValue, nil + } + return "", fmt.Errorf("environment variable %s not found in the environment", envKey) + } + return value, nil + } + + // Helper function to recursively check and replace env values in a struct + var processStruct func(interface{}) error + processStruct = func(v interface{}) error { + val := reflect.ValueOf(v) + + // Dereference pointer if present + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + // Handle interface types + if val.Kind() == reflect.Interface { + val = val.Elem() + // If the interface value is a pointer, dereference it + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + } + + if val.Kind() != reflect.Struct { + return nil + } + + typ := val.Type() + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + fieldType := typ.Field(i) + + // Skip unexported fields + if !field.CanSet() { + continue + } + + switch field.Kind() { + case reflect.String: + if field.CanSet() { + value := field.String() + if strings.HasPrefix(value, "env.") { + newValue, err := replaceEnvValue(value) + if err != nil { + return fmt.Errorf("field %s: %w", fieldType.Name, err) + } + field.SetString(newValue) + } + } + case reflect.Interface: + if !field.IsNil() { + if err := processStruct(field.Interface()); err != nil { + return err + } + } + } + } + return nil + } + + // Lock the config map for the entire update operation + baseAccount.mu.Lock() + defer baseAccount.mu.Unlock() + + // Check and replace values in provider configs + for provider, config := range baseAccount.Config { + // Check keys + for i, key := range config.Keys { + newValue, err := replaceEnvValue(key.Value) + if err != nil { + return fmt.Errorf("provider %s: %w", provider, err) + } + config.Keys[i].Value = newValue + } + + // Check meta config if it exists + if config.MetaConfig != nil { + if err := processStruct(config.MetaConfig); err != nil { + return fmt.Errorf("provider %s: %w", provider, err) + } + } + + baseAccount.Config[provider] = config + } + + return nil +} diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go new file mode 100644 index 0000000000..7e721ae3ec --- /dev/null +++ b/transports/bifrost-http/lib/config.go @@ -0,0 +1,188 @@ +// Package lib provides core functionality for the Bifrost HTTP service, +// including context propagation, header management, and integration with monitoring systems. +package lib + +import ( + "encoding/json" + "fmt" + "log" + "os" + "strings" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/meta" +) + +// ProviderConfig represents the configuration for a specific AI model provider. +// It includes API keys, network settings, provider-specific metadata, and concurrency settings. +type ProviderConfig struct { + Keys []schemas.Key `json:"keys"` // API keys for the provider + NetworkConfig *schemas.NetworkConfig `json:"network_config,omitempty"` // Network-related settings + MetaConfig *schemas.MetaConfig `json:"-"` // Provider-specific metadata + ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size,omitempty"` // Concurrency settings +} + +// ConfigMap maps provider names to their configurations. +type ConfigMap map[schemas.ModelProvider]ProviderConfig + +// PluginConfig represents the configuration for a dynamically loaded plugin +type PluginConfig struct { + Name string `json:"name"` // Plugin name (used for identification) + Source string `json:"source"` // Source path (Go module path or local directory) + Type string `json:"type"` // "remote" for Go modules, "local" for local paths + Config json.RawMessage `json:"config"` // Plugin-specific configuration as raw JSON +} + +// BifrostHTTPConfig represents the complete configuration structure for Bifrost HTTP transport. +// It includes provider configurations, MCP configuration, and plugin configurations. +type BifrostHTTPConfig struct { + ProviderConfig ConfigMap `json:"providers"` // Provider configurations + MCPConfig *schemas.MCPConfig `json:"mcp"` // MCP configuration (optional) + Plugins []PluginConfig `json:"plugins"` // Plugin configurations +} + +// ReadMCPKeys reads environment variables from the environment and updates the MCP configurations. +// It replaces values starting with "env." in the connection_string field with actual values from the environment. +// Returns an error if any required environment variable is missing. +func (config *BifrostHTTPConfig) ReadMCPKeys() error { + if config.MCPConfig == nil { + return nil // No MCP config to process + } + + // Helper function to check and replace env values + replaceEnvValue := func(value string) (string, error) { + if strings.HasPrefix(value, "env.") { + envKey := strings.TrimPrefix(value, "env.") + if envValue := os.Getenv(envKey); envValue != "" { + return envValue, nil + } + return "", fmt.Errorf("environment variable %s not found in the environment", envKey) + } + return value, nil + } + + // Process each client config + for i, clientConfig := range config.MCPConfig.ClientConfigs { + // Process ConnectionString if present + if clientConfig.ConnectionString != nil { + newValue, err := replaceEnvValue(*clientConfig.ConnectionString) + if err != nil { + return fmt.Errorf("MCP client %s: %w", clientConfig.Name, err) + } + config.MCPConfig.ClientConfigs[i].ConnectionString = &newValue + } + } + + return nil +} + +// readConfig reads and parses the configuration file. +// It handles case conversion for provider names and sets up provider-specific metadata. +// Returns a BifrostHTTPConfig containing both provider and MCP configurations. +// Panics if the config file cannot be read or parsed. +// +// In the config file, use placeholder keys (e.g., env.OPENAI_API_KEY) instead of hardcoding actual values. +// These placeholders will be replaced with the corresponding values from the environment variables. +// Example: +// +// "providers": { +// "openAI": { +// "keys":[{ +// "value": "env.OPENAI_API_KEY" +// "models": ["gpt-4o-mini", "gpt-4-turbo"], +// "weight": 1.0 +// }] +// } +// }, +// "mcp": { +// "client_configs": [...] +// } +// +// In this example, OPENAI_API_KEY refers to a key in the environment variables. At runtime, its value will be used to replace the placeholder. +// Same setup applies to keys in meta configs of all the providers. +// Example: +// +// "meta_config": { +// "secret_access_key": "env.AWS_SECRET_ACCESS_KEY" +// "region": "env.AWS_REGION" +// } +// +// In this example, AWS_SECRET_ACCESS_KEY and AWS_REGION refer to keys in environment variables. +func ReadConfig(configLocation string) *BifrostHTTPConfig { + data, err := os.ReadFile(configLocation) + if err != nil { + log.Fatalf("failed to read config JSON file: %v", err) + } + + // First unmarshal into the new structure + var fullConfig BifrostHTTPConfig + if err := json.Unmarshal(data, &fullConfig); err != nil { + log.Fatalf("failed to unmarshal JSON: %v", err) + } + + if fullConfig.ProviderConfig == nil { + log.Fatalf("providers section is required in config") + } + + // Process provider configurations - convert string keys to lowercase provider names and handle meta configs + processedProviders := make(ConfigMap) + + // First unmarshal providers into a map with string keys to handle case conversion + var rawProviders map[string]ProviderConfig + if providersBytes, err := json.Marshal(fullConfig.ProviderConfig); err != nil { + log.Fatalf("failed to marshal providers: %v", err) + } else if err := json.Unmarshal(providersBytes, &rawProviders); err != nil { + log.Fatalf("failed to unmarshal providers: %v", err) + } + + // Create a temporary structure to unmarshal the full JSON with proper meta configs + var tempConfig struct { + Providers map[string]struct { + MetaConfig json.RawMessage `json:"meta_config"` + } `json:"providers"` + } + + if err := json.Unmarshal(data, &tempConfig); err != nil { + log.Fatalf("failed to unmarshal configuration file: %v\n\n Please check your configuration file for proper JSON formatting and meta_config structure", err) + } else { + for rawProvider, cfg := range rawProviders { + provider := schemas.ModelProvider(strings.ToLower(rawProvider)) + + // Get the raw meta config for this provider + if tempProvider, exists := tempConfig.Providers[rawProvider]; exists && len(tempProvider.MetaConfig) > 0 { + switch provider { + case schemas.Azure: + var azureMetaConfig meta.AzureMetaConfig + if err := json.Unmarshal(tempProvider.MetaConfig, &azureMetaConfig); err != nil { + log.Printf("warning: failed to unmarshal Azure meta config: %v", err) + } else { + var metaConfig schemas.MetaConfig = &azureMetaConfig + cfg.MetaConfig = &metaConfig + } + case schemas.Bedrock: + var bedrockMetaConfig meta.BedrockMetaConfig + if err := json.Unmarshal(tempProvider.MetaConfig, &bedrockMetaConfig); err != nil { + log.Printf("warning: failed to unmarshal Bedrock meta config: %v", err) + } else { + var metaConfig schemas.MetaConfig = &bedrockMetaConfig + cfg.MetaConfig = &metaConfig + } + case schemas.Vertex: + var vertexMetaConfig meta.VertexMetaConfig + if err := json.Unmarshal(tempProvider.MetaConfig, &vertexMetaConfig); err != nil { + log.Printf("warning: failed to unmarshal Vertex meta config: %v", err) + } else { + var metaConfig schemas.MetaConfig = &vertexMetaConfig + cfg.MetaConfig = &metaConfig + } + } + } + + processedProviders[provider] = cfg + } + + } + + fullConfig.ProviderConfig = processedProviders + return &fullConfig +} diff --git a/transports/bifrost-http/lib/ctx.go b/transports/bifrost-http/lib/ctx.go new file mode 100644 index 0000000000..e0a2c360ca --- /dev/null +++ b/transports/bifrost-http/lib/ctx.go @@ -0,0 +1,77 @@ +// Package lib provides core functionality for the Bifrost HTTP service, +// including context propagation, header management, and integration with monitoring systems. +// +// This package handles the conversion of FastHTTP request contexts to Bifrost contexts, +// ensuring that important metadata and tracking information is preserved across the system. +// It supports propagation of both Prometheus metrics and Maxim tracing data through HTTP headers. +package lib + +import ( + "context" + "strings" + + "github.com/maximhq/bifrost/transports/bifrost-http/tracking" + "github.com/valyala/fasthttp" +) + +// ConvertToBifrostContext converts a FastHTTP RequestCtx to a Bifrost context, +// preserving important header values for monitoring and tracing purposes. +// +// The function processes two types of special headers: +// 1. Prometheus Headers (x-bf-prom-*): +// - All headers prefixed with 'x-bf-prom-' are copied to the context +// - The prefix is stripped and the remainder becomes the context key +// - Example: 'x-bf-prom-latency' becomes 'latency' in the context +// +// 2. Maxim Tracing Headers (x-bf-maxim-*): +// - Specifically handles 'x-bf-maxim-traceID' and 'x-bf-maxim-generationID' +// - These headers enable trace correlation across service boundaries +// - Values are stored using Maxim's context keys for consistency +// +// Parameters: +// - ctx: The FastHTTP request context containing the original headers +// +// Returns: +// - *context.Context: A new context.Context containing the propagated values +// +// Example Usage: +// +// fastCtx := &fasthttp.RequestCtx{...} +// bifrostCtx := ConvertToBifrostContext(fastCtx) +// // bifrostCtx now contains any prometheus and maxim header values + +type ContextKey string + +func ConvertToBifrostContext(ctx *fasthttp.RequestCtx) *context.Context { + bifrostCtx := context.Background() + + // Copy all prometheus header values to the new context + ctx.Request.Header.VisitAll(func(key, value []byte) { + keyStr := strings.ToLower(string(key)) + + if strings.HasPrefix(keyStr, "x-bf-prom-") { + labelName := strings.TrimPrefix(keyStr, "x-bf-prom-") + bifrostCtx = context.WithValue(bifrostCtx, tracking.PrometheusContextKey(labelName), string(value)) + } + + if strings.HasPrefix(keyStr, "x-bf-maxim-") { + labelName := strings.TrimPrefix(keyStr, "x-bf-maxim-") + + if labelName == "session-id" || labelName == "trace-id" || labelName == "generation-id" { + bifrostCtx = context.WithValue(bifrostCtx, ContextKey(labelName), string(value)) + return + } + } + + if strings.HasPrefix(keyStr, "x-bf-maxim-") { + labelName := strings.TrimPrefix(keyStr, "x-bf-maxim-") + + if labelName == "session-id" || labelName == "trace-id" || labelName == "generation-id" { + bifrostCtx = context.WithValue(bifrostCtx, ContextKey(labelName), string(value)) + return + } + } + }) + + return &bifrostCtx +} diff --git a/transports/bifrost-http/lib/plugin_compiler.go b/transports/bifrost-http/lib/plugin_compiler.go new file mode 100644 index 0000000000..29f6e386d3 --- /dev/null +++ b/transports/bifrost-http/lib/plugin_compiler.go @@ -0,0 +1,477 @@ +// Package lib provides plugin compilation and loading functionality for Bifrost HTTP transport. +package lib + +import ( + "encoding/json" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "plugin" + "strings" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +// PluginCompiler handles just-in-time compilation of plugins +type PluginCompiler struct { + tempDir string + compiledPlugins map[string]string // plugin name -> .so file path + sessionID string // unique session for cleanup +} + +// NewPluginCompiler creates a new plugin compiler instance +func NewPluginCompiler() *PluginCompiler { + sessionID := fmt.Sprintf("bifrost-plugins-%d", time.Now().Unix()) + tempDir := filepath.Join(os.TempDir(), sessionID) + + if err := os.MkdirAll(tempDir, 0755); err != nil { + log.Printf("warning: failed to create plugin temp directory: %v", err) + } + + return &PluginCompiler{ + tempDir: tempDir, + compiledPlugins: make(map[string]string), + sessionID: sessionID, + } +} + +// LoadPlugin handles the complete workflow for a single plugin +func (pc *PluginCompiler) LoadPlugin(config PluginConfig) (schemas.Plugin, error) { + log.Printf("loading plugin %s from %s", config.Name, config.Source) + + // 1. Setup workspace + workDir := filepath.Join(pc.tempDir, config.Name) + if err := os.MkdirAll(workDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create workspace: %w", err) + } + + // 2. Get source code + var err error + switch config.Type { + case "remote": + err = pc.downloadRemotePlugin(workDir, config.Source) + case "local": + err = pc.copyLocalPlugin(workDir, config.Source) + default: + return nil, fmt.Errorf("unsupported plugin type: %s", config.Type) + } + + if err != nil { + return nil, fmt.Errorf("failed to get plugin source: %w", err) + } + + // 3. Compile plugin + soPath, err := pc.compilePlugin(workDir, config.Name) + if err != nil { + return nil, fmt.Errorf("failed to compile plugin: %w", err) + } + + // 4. Load compiled plugin + pluginInstance, err := pc.loadCompiledPlugin(soPath, config.Config) + if err != nil { + return nil, fmt.Errorf("failed to load compiled plugin: %w", err) + } + + pc.compiledPlugins[config.Name] = soPath + return pluginInstance, nil +} + +// downloadRemotePlugin downloads a remote Go module plugin +func (pc *PluginCompiler) downloadRemotePlugin(workDir, modulePath string) error { + log.Printf("downloading remote plugin: %s", modulePath) + + // Initialize a new Go module in the workspace + cmd := exec.Command("go", "mod", "init", "plugin-temp") + cmd.Dir = workDir + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("go mod init failed: %s", string(output)) + } + + // Get the plugin module + cmd = exec.Command("go", "get", modulePath) + cmd.Dir = workDir + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("go get failed: %s", string(output)) + } + + // Note: We rely on go mod tidy to resolve compatible versions + + // Create a simple main.go that re-exports the plugin + mainContent := fmt.Sprintf(`package main + +import ( + "encoding/json" + "github.com/maximhq/bifrost/core/schemas" + plugin "%s" +) + +// Re-export the Init function +func Init(config json.RawMessage) (schemas.Plugin, error) { + return plugin.Init(config) +} + +func main() {} +`, modulePath) + + if err := os.WriteFile(filepath.Join(workDir, "main.go"), []byte(mainContent), 0644); err != nil { + return fmt.Errorf("failed to create main.go: %w", err) + } + + // Run go mod tidy to resolve dependencies + cmd = exec.Command("go", "mod", "tidy") + cmd.Dir = workDir + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("go mod tidy failed: %s", string(output)) + } + + return nil +} + +// copyLocalPlugin copies a local plugin to the workspace and modifies it to be a main package +func (pc *PluginCompiler) copyLocalPlugin(workDir, localPath string) error { + log.Printf("copying local plugin: %s", localPath) + + dir, err := os.Getwd() + if err != nil { + return fmt.Errorf("failed to get working directory: %w", err) + } + + // Copy all files from the source directory directly to workDir + cmd := exec.Command("cp", "-r", filepath.Join(dir, localPath)+"/.", workDir) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to copy local plugin: %s", string(output)) + } + + // Read the plugin's go.mod to get the module path + pluginModPath, err := pc.getPluginModulePath(workDir) + if err != nil { + return fmt.Errorf("failed to determine plugin module path: %w", err) + } + + // Modify all .go files to change package declaration to main + err = pc.convertPackageToMain(workDir, pluginModPath) + if err != nil { + return fmt.Errorf("failed to convert package to main: %w", err) + } + + // Enforce main application's dependency versions + err = pc.enforceMainAppDependencies(workDir, pluginModPath) + if err != nil { + return fmt.Errorf("failed to enforce main app dependencies: %w", err) + } + + // Run go mod tidy to resolve dependencies + cmd = exec.Command("go", "mod", "tidy") + cmd.Dir = workDir + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("go mod tidy failed: %s", string(output)) + } + + return nil +} + +// getPluginModulePath reads the go.mod file to extract the module path +func (pc *PluginCompiler) getPluginModulePath(pluginDir string) (string, error) { + goModPath := filepath.Join(pluginDir, "go.mod") + + // Check if go.mod exists + if _, err := os.Stat(goModPath); os.IsNotExist(err) { + return "", fmt.Errorf("go.mod not found in plugin directory") + } + + // Read go.mod file + content, err := os.ReadFile(goModPath) + if err != nil { + return "", fmt.Errorf("failed to read go.mod: %w", err) + } + + // Extract module name from first line + lines := strings.Split(string(content), "\n") + if len(lines) == 0 { + return "", fmt.Errorf("empty go.mod file") + } + + firstLine := strings.TrimSpace(lines[0]) + if !strings.HasPrefix(firstLine, "module ") { + return "", fmt.Errorf("invalid go.mod format: missing module declaration") + } + + modulePath := strings.TrimSpace(strings.TrimPrefix(firstLine, "module ")) + if modulePath == "" { + return "", fmt.Errorf("empty module path in go.mod") + } + + return modulePath, nil +} + +// convertPackageToMain modifies all .go files to change package declaration to main +func (pc *PluginCompiler) convertPackageToMain(pluginDir, originalPackage string) error { + // Extract the package name from the module path (last component) + packageName := filepath.Base(originalPackage) + + // Find all .go files in the plugin directory + files, err := filepath.Glob(filepath.Join(pluginDir, "*.go")) + if err != nil { + return fmt.Errorf("failed to find .go files: %w", err) + } + + for _, file := range files { + // Read the file + content, err := os.ReadFile(file) + if err != nil { + return fmt.Errorf("failed to read file %s: %w", file, err) + } + + // Convert the content + modifiedContent := string(content) + + // Replace package declaration + oldPackageDecl := fmt.Sprintf("package %s", packageName) + newPackageDecl := "package main" + + modifiedContent = strings.Replace(modifiedContent, oldPackageDecl, newPackageDecl, 1) + + // Write the modified content back + if err := os.WriteFile(file, []byte(modifiedContent), 0644); err != nil { + return fmt.Errorf("failed to write modified file %s: %w", file, err) + } + } + + return nil +} + +// compilePlugin compiles the plugin source to a .so file +func (pc *PluginCompiler) compilePlugin(workDir, pluginName string) (string, error) { + log.Printf("compiling plugin: %s", pluginName) + + soPath := filepath.Join(workDir, pluginName+".so") + + // Compile as plugin + cmd := exec.Command("go", "build", "-buildmode=plugin", "-o", soPath) + cmd.Dir = workDir + + // Set environment for compilation + cmd.Env = append(os.Environ(), "CGO_ENABLED=1") + + if output, err := cmd.CombinedOutput(); err != nil { + return "", fmt.Errorf("compilation failed: %s", string(output)) + } + + // Verify the .so file was created + if _, err := os.Stat(soPath); err != nil { + return "", fmt.Errorf("compiled plugin file not found: %w", err) + } + + log.Printf("successfully compiled plugin to: %s", soPath) + return soPath, nil +} + +// loadCompiledPlugin loads a compiled .so file and initializes the plugin +func (pc *PluginCompiler) loadCompiledPlugin(soPath string, config json.RawMessage) (schemas.Plugin, error) { + log.Printf("loading compiled plugin: %s", soPath) + + // Load the plugin + p, err := plugin.Open(soPath) + if err != nil { + return nil, fmt.Errorf("failed to open plugin: %w", err) + } + + // Look for the Init function + initSymbol, err := p.Lookup("Init") + if err != nil { + return nil, fmt.Errorf("init function not found in plugin: %w", err) + } + + // Cast to the expected function signature + initFunc, ok := initSymbol.(func(json.RawMessage) (schemas.Plugin, error)) + if !ok { + return nil, fmt.Errorf("init function has wrong signature") + } + + // Initialize the plugin with its config + pluginInstance, err := initFunc(config) + if err != nil { + return nil, fmt.Errorf("plugin initialization failed: %w", err) + } + + return pluginInstance, nil +} + +// Cleanup removes all temporary files and directories +func (pc *PluginCompiler) Cleanup() error { + if pc.tempDir == "" { + return nil + } + + log.Printf("cleaning up plugin temp directory: %s", pc.tempDir) + + if err := os.RemoveAll(pc.tempDir); err != nil { + return fmt.Errorf("failed to cleanup plugin temp directory: %w", err) + } + + return nil +} + +// enforceMainAppDependencies enforces main application's dependency versions on plugins +func (pc *PluginCompiler) enforceMainAppDependencies(workDir, pluginModPath string) error { + // Get main application's dependencies using hybrid approach + mainDeps, err := pc.getMainAppDependencies() + if err != nil { + log.Printf("warning: could not determine main app dependencies, using plugin's original versions: %v", err) + return nil // Don't fail, just use plugin's original versions + } + + // Read plugin's original go.mod to preserve plugin-specific dependencies + pluginGoModContent, err := os.ReadFile(filepath.Join(workDir, "go.mod")) + if err != nil { + return fmt.Errorf("failed to read plugin's go.mod: %w", err) + } + + pluginDeps := pc.parseGoModRequires(string(pluginGoModContent)) + + // Create new go.mod with main app's versions for shared dependencies + // Keep plugin's versions for plugin-specific dependencies + var requires []string + + // Add all main app dependencies + for dep, version := range mainDeps { + requires = append(requires, fmt.Sprintf("\t%s %s", dep, version)) + } + + // Add plugin-specific dependencies that aren't in main app + for dep, version := range pluginDeps { + if _, exists := mainDeps[dep]; !exists { + requires = append(requires, fmt.Sprintf("\t%s %s", dep, version)) + } + } + + // Create new go.mod content + newGoModContent := fmt.Sprintf(`module %s +go 1.21 + +require ( +%s +) +`, pluginModPath, strings.Join(requires, "\n")) + + // Write the new go.mod + if err := os.WriteFile(filepath.Join(workDir, "go.mod"), []byte(newGoModContent), 0644); err != nil { + return fmt.Errorf("failed to write updated go.mod: %w", err) + } + + log.Printf("Enforced main app dependency versions on plugin %s", pluginModPath) + return nil +} + +// parseGoModRequires extracts require dependencies from go.mod content +func (pc *PluginCompiler) parseGoModRequires(content string) map[string]string { + deps := make(map[string]string) + lines := strings.Split(content, "\n") + inRequireBlock := false + + for _, line := range lines { + line = strings.TrimSpace(line) + + // Handle require block start + if strings.HasPrefix(line, "require (") { + inRequireBlock = true + continue + } + + // Handle require block end + if inRequireBlock && line == ")" { + inRequireBlock = false + continue + } + + // Handle single-line require + if strings.HasPrefix(line, "require ") && !strings.Contains(line, "(") { + parts := strings.Fields(line) + if len(parts) >= 3 { + dep := parts[1] + version := parts[2] + deps[dep] = version + } + continue + } + + // Handle require block contents + if inRequireBlock && line != "" && !strings.HasPrefix(line, "//") { + parts := strings.Fields(line) + if len(parts) >= 2 { + dep := parts[0] + version := parts[1] + // Remove any trailing comments + if idx := strings.Index(version, "//"); idx != -1 { + version = strings.TrimSpace(version[:idx]) + } + deps[dep] = version + } + } + } + + return deps +} + +// getMainAppDependencies gets main application dependencies using hybrid approach +func (pc *PluginCompiler) getMainAppDependencies() (map[string]string, error) { + // Strategy 1: Try local go.mod (development case) + if deps, err := pc.getLocalGoModDeps(); err == nil { + log.Printf("Using local go.mod dependencies") + return deps, nil + } + + // Strategy 2: Use go list to get runtime dependencies (binary installation case) + log.Printf("Local go.mod not found, using runtime dependencies from go list") + return pc.getRuntimeDependencies() +} + +// getLocalGoModDeps attempts to read dependencies from local go.mod file +func (pc *PluginCompiler) getLocalGoModDeps() (map[string]string, error) { + cwd, _ := os.Getwd() + + // Try different possible locations for the transport's go.mod + possiblePaths := []string{ + filepath.Join(cwd, "go.mod"), // ./go.mod (current directory) + filepath.Join(cwd, "..", "go.mod"), // ../go.mod (parent directory) + filepath.Join(cwd, "..", "..", "go.mod"), // ../../go.mod (grandparent) + } + + for _, path := range possiblePaths { + if content, err := os.ReadFile(path); err == nil { + log.Printf("Reading local go.mod from: %s", path) + return pc.parseGoModRequires(string(content)), nil + } + } + + return nil, fmt.Errorf("no local go.mod file found") +} + +// getRuntimeDependencies gets the dependencies that were used to build the current binary +func (pc *PluginCompiler) getRuntimeDependencies() (map[string]string, error) { + cmd := exec.Command("go", "list", "-m", "-f", "{{.Path}} {{.Version}}", "all") + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("failed to get runtime dependencies: %w", err) + } + + deps := make(map[string]string) + lines := strings.Split(string(output), "\n") + for _, line := range lines { + parts := strings.Fields(strings.TrimSpace(line)) + if len(parts) >= 2 { + path := parts[0] + version := parts[1] + // Skip the main module itself + if path != "" && version != "" && !strings.Contains(version, "(main)") { + deps[path] = version + } + } + } + + log.Printf("Retrieved %d runtime dependencies via go list", len(deps)) + return deps, nil +} diff --git a/transports/bifrost-http/main.go b/transports/bifrost-http/main.go new file mode 100644 index 0000000000..d0134e11e7 --- /dev/null +++ b/transports/bifrost-http/main.go @@ -0,0 +1,376 @@ +// Package http provides an HTTP service using FastHTTP that exposes endpoints +// for text and chat completions using various AI model providers (OpenAI, Anthropic, Bedrock, Mistral, Ollama, etc.). +// +// The HTTP service provides three main endpoints: +// - /v1/text/completions: For text completion requests +// - /v1/chat/completions: For chat completion requests +// - /v1/mcp/tool/execute: For MCP tool execution requests +// +// Configuration is handled through a JSON config file and environment variables: +// - Use -config flag to specify the config file location +// - Use -port flag to specify the server port (default: 8080) +// - Use -pool-size flag to specify the initial connection pool size (default: 300) +// +// Example usage: +// +// go run main.go -config config.example.json -port 8080 -pool-size 300 +// after setting the environment variables present in config.example.json in the environment. +// +// Integration Support: +// Bifrost supports multiple AI provider integrations through dedicated HTTP endpoints. +// Each integration exposes API-compatible endpoints that accept the provider's native request format, +// automatically convert it to Bifrost's unified format, process it, and return the expected response format. +// +// Integration endpoints follow the pattern: /{provider}/{provider_api_path} +// Examples: +// - OpenAI: POST /openai/v1/chat/completions (accepts OpenAI ChatCompletion requests) +// - GenAI: POST /genai/v1beta/models/{model} (accepts Google GenAI requests) +// - Anthropic: POST /anthropic/v1/messages (accepts Anthropic Messages requests) +// +// This allows clients to use their existing integration code without modification while benefiting +// from Bifrost's unified model routing, fallbacks, and monitoring capabilities. +// +// NOTE: Streaming is not supported yet so all the flags related to streaming are ignored. (in both bifrost and its integrations) +package main + +import ( + "encoding/json" + "flag" + "fmt" + "log" + "strings" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/anthropic" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/genai" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/litellm" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations/openai" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/maximhq/bifrost/transports/bifrost-http/tracking" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/collectors" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttpadaptor" +) + +// Command line flags +var ( + initialPoolSize int // Initial size of the connection pool + dropExcessRequests bool // Drop excess requests + port string // Port to run the server on + configPath string // Path to the config file + pluginsToLoad []string // Path to the plugins + prometheusLabels []string // Labels to add to Prometheus metrics (optional) +) + +// init initializes command line flags and validates required configuration. +// It sets up the following flags: +// - pool-size: Initial connection pool size (default: 300) +// - port: Server port (default: 8080) +// - config: Path to config file (required) +// - drop-excess-requests: Whether to drop excess requests +func init() { + pluginString := "" + var prometheusLabelsString string + + flag.IntVar(&initialPoolSize, "pool-size", 300, "Initial pool size for Bifrost") + flag.StringVar(&port, "port", "8080", "Port to run the server on") + flag.StringVar(&configPath, "config", "", "Path to the config file") + flag.BoolVar(&dropExcessRequests, "drop-excess-requests", false, "Drop excess requests") + flag.StringVar(&pluginString, "plugins", "", "Comma separated list of plugins to load") + flag.StringVar(&prometheusLabelsString, "prometheus-labels", "", "Labels to add to Prometheus metrics") + flag.Parse() + + pluginsToLoad = strings.Split(pluginString, ",") + + if configPath == "" { + log.Fatalf("config path is required") + } + + if prometheusLabelsString != "" { + // Split and filter out empty strings + rawLabels := strings.Split(prometheusLabelsString, ",") + prometheusLabels = make([]string, 0, len(rawLabels)) + for _, label := range rawLabels { + if trimmed := strings.TrimSpace(label); trimmed != "" { + prometheusLabels = append(prometheusLabels, strings.ToLower(trimmed)) + } + } + } +} + +// CompletionRequest represents a request for either text or chat completion. +// It includes all necessary fields for both types of completions. +type CompletionRequest struct { + Provider schemas.ModelProvider `json:"provider"` // The AI model provider to use + Messages []schemas.BifrostMessage `json:"messages"` // Chat messages (for chat completion) + Text string `json:"text"` // Text input (for text completion) + Model string `json:"model"` // Model to use + Params *schemas.ModelParameters `json:"params"` // Additional model parameters + Fallbacks []schemas.Fallback `json:"fallbacks"` // Fallback providers and models +} + +// registerCollectorSafely attempts to register a Prometheus collector, +// handling the case where it may already be registered. +// It logs any errors that occur during registration, except for AlreadyRegisteredError. +func registerCollectorSafely(collector prometheus.Collector) { + if err := prometheus.Register(collector); err != nil { + if _, ok := err.(prometheus.AlreadyRegisteredError); !ok { + log.Printf("Failed to register collector: %v", err) + } + } +} + +// main is the entry point of the application. +// It: +// 1. Initializes Prometheus collectors for monitoring +// 2. Reads and parses configuration from the specified config file +// 3. Initializes the Bifrost client with the configuration +// 4. Sets up HTTP routes for text and chat completions +// 5. Starts the HTTP server on the specified port +// +// The server exposes the following endpoints: +// - POST /v1/text/completions: For text completion requests +// - POST /v1/chat/completions: For chat completion requests +// - GET /metrics: For Prometheus metrics +func main() { + // Register Prometheus collectors + registerCollectorSafely(collectors.NewGoCollector()) + registerCollectorSafely(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) + + tracking.InitPrometheusMetrics(prometheusLabels) + + log.Println("Prometheus Go/Process collectors registered.") + + config := lib.ReadConfig(configPath) + account := &lib.BaseAccount{Config: config.ProviderConfig} + + if err := account.ReadKeys(); err != nil { + log.Printf("warning: failed to read environment variables: %v", err) + } + + if err := config.ReadMCPKeys(); err != nil { + log.Printf("warning: failed to read MCP environment variables: %v", err) + } + + loadedPlugins := []schemas.Plugin{} + + // Load plugins from configuration + if len(config.Plugins) > 0 { + // Initialize plugin compiler for dynamic loading + pluginCompiler := lib.NewPluginCompiler() + defer pluginCompiler.Cleanup() + + for _, config := range config.Plugins { + plugin, err := pluginCompiler.LoadPlugin(config) + if err != nil { + log.Printf("warning: failed to load plugin %s: %v", config.Name, err) + continue + } + + if plugin != nil { + loadedPlugins = append(loadedPlugins, plugin) + log.Printf("successfully loaded plugin: %s", plugin.GetName()) + } + } + } + + // Always add Prometheus plugin + promPlugin := tracking.NewPrometheusPlugin() + loadedPlugins = append(loadedPlugins, promPlugin) + + log.Printf("Successfully loaded %d plugins in total:", len(loadedPlugins)) + for _, plugin := range loadedPlugins { + log.Printf(" - %s", plugin.GetName()) + } + + client, err := bifrost.Init(schemas.BifrostConfig{ + Account: account, + InitialPoolSize: initialPoolSize, + DropExcessRequests: dropExcessRequests, + Plugins: loadedPlugins, + MCPConfig: config.MCPConfig, + }) + if err != nil { + log.Fatalf("failed to initialize bifrost: %v", err) + } + + r := router.New() + + extensions := []integrations.ExtensionRouter{ + genai.NewGenAIRouter(client), + openai.NewOpenAIRouter(client), + anthropic.NewAnthropicRouter(client), + litellm.NewLiteLLMRouter(client), + } + + r.POST("/v1/text/completions", func(ctx *fasthttp.RequestCtx) { + handleCompletion(ctx, client, false) + }) + + r.POST("/v1/chat/completions", func(ctx *fasthttp.RequestCtx) { + handleCompletion(ctx, client, true) + }) + + r.POST("/v1/mcp/tool/execute", func(ctx *fasthttp.RequestCtx) { + handleMCPToolExecution(ctx, client) + }) + + for _, extension := range extensions { + extension.RegisterRoutes(r) + } + + // Add Prometheus /metrics endpoint + r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.Handler())) + + r.NotFound = func(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(fasthttp.StatusNotFound) + ctx.SetContentType("text/plain") + ctx.SetBodyString("Route not found: " + string(ctx.Path())) + } + + server := &fasthttp.Server{ + // A custom handler that excludes middleware from /metrics + Handler: func(ctx *fasthttp.RequestCtx) { + if string(ctx.Path()) == "/metrics" { + r.Handler(ctx) + return + } + tracking.PrometheusMiddleware(r.Handler)(ctx) + }, + } + + log.Println("Started Bifrost HTTP server on port", port) + if err := server.ListenAndServe(fmt.Sprintf(":%s", port)); err != nil { + log.Fatalf("failed to start server: %v", err) + } + + client.Cleanup() +} + +// handleCompletion processes both text and chat completion requests. +// It handles request parsing, validation, and response formatting. +// +// Parameters: +// - ctx: The FastHTTP request context +// - client: The Bifrost client instance +// - isChat: Whether this is a chat completion request (true) or text completion (false) +// +// The function: +// 1. Parses the request body into a CompletionRequest +// 2. Validates required fields based on the request type +// 3. Creates a BifrostRequest with the appropriate input type +// 4. Calls the appropriate completion method on the client +// 5. Handles any errors and formats the response +func handleCompletion(ctx *fasthttp.RequestCtx, client *bifrost.Bifrost, isChat bool) { + var req CompletionRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString(fmt.Sprintf("invalid request format: %v", err)) + return + } + + if req.Provider == "" { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Provider is required") + return + } + + bifrostReq := &schemas.BifrostRequest{ + Provider: req.Provider, + Model: req.Model, + Params: req.Params, + Fallbacks: req.Fallbacks, + } + + if isChat { + if len(req.Messages) == 0 { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Messages array is required") + return + } + bifrostReq.Input = schemas.RequestInput{ + ChatCompletionInput: &req.Messages, + } + } else { + if req.Text == "" { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Text is required") + return + } + bifrostReq.Input = schemas.RequestInput{ + TextCompletionInput: &req.Text, + } + } + + bifrostCtx := lib.ConvertToBifrostContext(ctx) + + var resp *schemas.BifrostResponse + var bifrostErr *schemas.BifrostError + + if bifrostCtx == nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBodyString("Failed to convert context") + return + } + + if isChat { + resp, bifrostErr = client.ChatCompletionRequest(*bifrostCtx, bifrostReq) + } else { + resp, bifrostErr = client.TextCompletionRequest(*bifrostCtx, bifrostReq) + } + + if bifrostErr != nil { + handleBifrostError(ctx, bifrostErr) + return + } + + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + if encodeErr := json.NewEncoder(ctx).Encode(resp); encodeErr != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBodyString(fmt.Sprintf("failed to encode response: %v", encodeErr)) + } +} + +func handleMCPToolExecution(ctx *fasthttp.RequestCtx, client *bifrost.Bifrost) { + var req schemas.ToolCall + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString(fmt.Sprintf("invalid request format: %v", err)) + return + } + + bifrostCtx := lib.ConvertToBifrostContext(ctx) + + resp, bifrostErr := client.ExecuteMCPTool(*bifrostCtx, req) + if bifrostErr != nil { + handleBifrostError(ctx, bifrostErr) + return + } + + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/json") + if encodeErr := json.NewEncoder(ctx).Encode(resp); encodeErr != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBodyString(fmt.Sprintf("failed to encode response: %v", encodeErr)) + } +} + +func handleBifrostError(ctx *fasthttp.RequestCtx, bifrostErr *schemas.BifrostError) { + if bifrostErr.StatusCode != nil { + ctx.SetStatusCode(*bifrostErr.StatusCode) + } else { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + } + + ctx.SetContentType("application/json") + if encodeErr := json.NewEncoder(ctx).Encode(bifrostErr); encodeErr != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBodyString(fmt.Sprintf("failed to encode error response: %v", encodeErr)) + } +} diff --git a/transports/bifrost-http/tracking/docker-compose.yml b/transports/bifrost-http/tracking/docker-compose.yml new file mode 100644 index 0000000000..26ebdad612 --- /dev/null +++ b/transports/bifrost-http/tracking/docker-compose.yml @@ -0,0 +1,29 @@ +# Prometheus and Grafana for tracking bifrost-http service (for development and testing purposes only, don't use in production without proper setup) +services: + prometheus: + image: prom/prometheus:latest + container_name: prometheus + ports: + - "9090:9090" # Expose Prometheus web UI + volumes: + - ./prometheus.yml:/etc/prometheus/prometheus.yml # Prometheus config file + restart: always + networks: + - bifrost_tracking_network + + grafana: + image: grafana/grafana:latest + container_name: grafana + ports: + - "3000:3000" # Expose Grafana web UI + depends_on: + - prometheus + environment: + GF_SECURITY_ADMIN_PASSWORD: "admin" # Default admin password for Grafana + restart: always + networks: + - bifrost_tracking_network + +networks: + bifrost_tracking_network: + driver: bridge diff --git a/transports/bifrost-http/tracking/plugin.go b/transports/bifrost-http/tracking/plugin.go new file mode 100644 index 0000000000..730387b817 --- /dev/null +++ b/transports/bifrost-http/tracking/plugin.go @@ -0,0 +1,111 @@ +// Package tracking provides Prometheus metrics collection and monitoring functionality +// for the Bifrost HTTP service. It includes middleware for HTTP request tracking +// and a plugin for tracking upstream provider metrics. +package tracking + +import ( + "context" + "fmt" + "log" + "time" + + schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/prometheus/client_golang/prometheus" +) + +// Define context key type for storing start time +type contextKey string + +// PrometheusContextKey is a custom type for prometheus context keys to prevent collisions +type PrometheusContextKey string + +const startTimeKey contextKey = "startTime" +const methodKey contextKey = "method" + +// PrometheusPlugin implements the schemas.Plugin interface for Prometheus metrics. +// It tracks metrics for upstream provider requests, including: +// - Total number of requests +// - Request latency +// - Error counts +type PrometheusPlugin struct { + // Metrics are defined using promauto for automatic registration + UpstreamRequestsTotal *prometheus.CounterVec + UpstreamLatency *prometheus.HistogramVec +} + +// NewPrometheusPlugin creates a new PrometheusPlugin with initialized metrics. +func NewPrometheusPlugin() *PrometheusPlugin { + return &PrometheusPlugin{ + UpstreamRequestsTotal: bifrostUpstreamRequestsTotal, + UpstreamLatency: bifrostUpstreamLatencySeconds, + } +} + +// GetName returns the name of the plugin. +func (p *PrometheusPlugin) GetName() string { + return "bifrost-http-prometheus" +} + +// PreHook records the start time of the request in the context. +// This time is used later in PostHook to calculate request duration. +func (p *PrometheusPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + *ctx = context.WithValue(*ctx, startTimeKey, time.Now()) + + if req.Input.ChatCompletionInput != nil { + *ctx = context.WithValue(*ctx, methodKey, "chat") + } else if req.Input.TextCompletionInput != nil { + *ctx = context.WithValue(*ctx, methodKey, "text") + } + + return req, nil, nil +} + +// PostHook calculates duration and records upstream metrics for successful requests. +// It records: +// - Request latency +// - Total request count +func (p *PrometheusPlugin) PostHook(ctx *context.Context, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + if result == nil { + return result, bifrostErr, nil + } + + startTime, ok := (*ctx).Value(startTimeKey).(time.Time) + if !ok { + log.Println("Warning: startTime not found in context for Prometheus PostHook") + return result, bifrostErr, nil + } + + method, ok := (*ctx).Value(methodKey).(string) + if !ok { + log.Println("Warning: method not found in context for Prometheus PostHook") + return result, bifrostErr, nil + } + + // Collect prometheus labels from context + labelValues := map[string]string{ + "target": fmt.Sprintf("%s/%s", result.ExtraFields.Provider, result.Model), + "method": method, + } + + // Get all prometheus labels from context + for _, key := range customLabels { + if value := (*ctx).Value(PrometheusContextKey(key)); value != nil { + if strValue, ok := value.(string); ok { + labelValues[key] = strValue + } + } + } + + // Get label values in the correct order + promLabelValues := getPrometheusLabelValues(append([]string{"target", "method"}, customLabels...), labelValues) + + duration := time.Since(startTime).Seconds() + p.UpstreamLatency.WithLabelValues(promLabelValues...).Observe(duration) + p.UpstreamRequestsTotal.WithLabelValues(promLabelValues...).Inc() + + return result, bifrostErr, nil +} + +func (p *PrometheusPlugin) Cleanup() error { + return nil +} diff --git a/transports/bifrost-http/tracking/prometheus.yml b/transports/bifrost-http/tracking/prometheus.yml new file mode 100644 index 0000000000..6682b021fa --- /dev/null +++ b/transports/bifrost-http/tracking/prometheus.yml @@ -0,0 +1,15 @@ +# Prometheus configuration for tracking bifrost-http service (for development and testing purposes only, don't use in production without proper setup) +global: + scrape_interval: 5s # Scrape every 5 seconds + +# Note: Target configuration depends on your deployment environment: +# - For local development: Use "host.docker.internal:8080" to access the service running on your host machine +# - For Docker deployment: Use "bifrost-api:8080" to access the service within the Docker network +# Make sure to replace "bifrost-api" and "8080" with your actual docker container name and port if different +# Also check that you have the bifrost container inside "bifrost_tracking_network". + +scrape_configs: + - job_name: "bifrost-api" + static_configs: + - targets: ["host.docker.internal:8080"] # Scrape from the /metrics endpoint + diff --git a/transports/bifrost-http/tracking/setup.go b/transports/bifrost-http/tracking/setup.go new file mode 100644 index 0000000000..c65499981e --- /dev/null +++ b/transports/bifrost-http/tracking/setup.go @@ -0,0 +1,205 @@ +// Package tracking provides Prometheus metrics collection and monitoring functionality +// for the Bifrost HTTP service. This file contains the setup and configuration +// for Prometheus metrics collection, including HTTP middleware and metric definitions. +package tracking + +import ( + "log" + "math" + "strconv" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/valyala/fasthttp" +) + +var ( + // httpRequestsTotal tracks the total number of HTTP requests + httpRequestsTotal *prometheus.CounterVec + // httpRequestDuration tracks the duration of HTTP requests + httpRequestDuration *prometheus.HistogramVec + // httpRequestSizeBytes tracks the size of incoming HTTP requests + httpRequestSizeBytes *prometheus.HistogramVec + // httpResponseSizeBytes tracks the size of outgoing HTTP responses + httpResponseSizeBytes *prometheus.HistogramVec + + // bifrostUpstreamRequestsTotal tracks the total number of requests forwarded to upstream providers by Bifrost. + bifrostUpstreamRequestsTotal *prometheus.CounterVec + + bifrostUpstreamLatencySeconds *prometheus.HistogramVec + + // customLabels stores the expected label names in order + customLabels []string + isInitialized bool +) + +func InitPrometheusMetrics(labels []string) { + if isInitialized { + return + } + + customLabels = labels + + httpDefaultLabels := []string{"path", "method", "status"} + bifrostDefaultLabels := []string{"target", "method"} + + httpRequestsTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "http_requests_total", + Help: "Total number of HTTP requests.", + }, + append(httpDefaultLabels, labels...), + ) + + // httpRequestDuration tracks the duration of HTTP requests + httpRequestDuration = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "http_request_duration_seconds", + Help: "Duration of HTTP requests.", + Buckets: prometheus.DefBuckets, // Default buckets: .005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10 + }, + append(httpDefaultLabels, labels...), + ) + + // httpRequestSizeBytes tracks the size of incoming HTTP requests + httpRequestSizeBytes = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "http_request_size_bytes", + Help: "Size of HTTP requests.", + Buckets: prometheus.ExponentialBuckets(100, 10, 8), // 100B to 1GB + }, + append(httpDefaultLabels, labels...), + ) + + // httpResponseSizeBytes tracks the size of outgoing HTTP responses + httpResponseSizeBytes = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "http_response_size_bytes", + Help: "Size of HTTP responses.", + Buckets: prometheus.ExponentialBuckets(100, 10, 8), // 100B to 1GB + }, + append(httpDefaultLabels, labels...), + ) + + // Bifrost Upstream Metrics (Defined globally, used by PrometheusPlugin) + bifrostUpstreamRequestsTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "bifrost_upstream_requests_total", + Help: "Total number of requests forwarded to upstream providers by Bifrost.", + }, + append(bifrostDefaultLabels, labels...), + ) + + bifrostUpstreamLatencySeconds = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "bifrost_upstream_latency_seconds", + Help: "Latency of requests forwarded to upstream providers by Bifrost.", + Buckets: prometheus.DefBuckets, + }, + append(bifrostDefaultLabels, labels...), + ) + + isInitialized = true +} + +// getPrometheusLabelValues takes an array of expected label keys and a map of header values, +// and returns an array of values in the same order as the keys, using empty string for missing values. +func getPrometheusLabelValues(expectedLabels []string, headerValues map[string]string) []string { + values := make([]string, len(expectedLabels)) + for i, label := range expectedLabels { + if value, exists := headerValues[label]; exists { + values[i] = value + } else { + values[i] = "" // Default empty value for missing labels + } + } + return values +} + +// collectPrometheusKeyValues collects all metrics for a request including: +// - Default metrics (path, method, status, request size) +// - Custom prometheus headers (x-bf-prom-*) +// Returns a map of all label values +func collectPrometheusKeyValues(ctx *fasthttp.RequestCtx) map[string]string { + path := string(ctx.Path()) + method := string(ctx.Method()) + + // Initialize with default metrics + labelValues := map[string]string{ + "path": path, + "method": method, + } + + // Collect custom prometheus headers + ctx.Request.Header.VisitAll(func(key, value []byte) { + keyStr := strings.ToLower(string(key)) + if strings.HasPrefix(keyStr, "x-bf-prom-") { + labelName := strings.TrimPrefix(keyStr, "x-bf-prom-") + labelValues[labelName] = string(value) + ctx.SetUserValue(keyStr, string(value)) + } + }) + + return labelValues +} + +// PrometheusMiddleware wraps a FastHTTP handler to collect Prometheus metrics. +// It tracks: +// - Total number of requests +// - Request duration +// - Request and response sizes +// - HTTP status codes +// - Bifrost upstream requests and errors +func PrometheusMiddleware(handler fasthttp.RequestHandler) fasthttp.RequestHandler { + if !isInitialized { + log.Println("Prometheus metrics are not initialized. Please call InitPrometheusMetrics first. Skipping metrics collection.") + return handler + } + + return func(ctx *fasthttp.RequestCtx) { + start := time.Now() + + // Collect request metrics and headers + promKeyValues := collectPrometheusKeyValues(ctx) + reqSize := float64(ctx.Request.Header.ContentLength()) + + // Process the request + handler(ctx) + + // Record metrics after request completion + duration := time.Since(start).Seconds() + status := strconv.Itoa(ctx.Response.StatusCode()) + respSize := float64(ctx.Response.Header.ContentLength()) + + // Add status to the label values + promKeyValues["status"] = status + + // Get label values in the correct order + promLabelValues := getPrometheusLabelValues(append([]string{"path", "method", "status"}, customLabels...), promKeyValues) + + // Record all metrics with prometheus labels + httpRequestsTotal.WithLabelValues(promLabelValues...).Inc() + httpRequestDuration.WithLabelValues(promLabelValues...).Observe(duration) + if reqSize >= 0 { + safeObserve(httpRequestSizeBytes, reqSize, promLabelValues...) + } + if respSize >= 0 { + safeObserve(httpResponseSizeBytes, respSize, promLabelValues...) + } + } +} + +// safeObserve safely records a value in a Prometheus histogram. +// It prevents recording invalid values (negative or infinite) that could cause issues. +func safeObserve(histogram *prometheus.HistogramVec, value float64, labels ...string) { + if value > 0 && value < math.MaxFloat64 { + metric, err := histogram.GetMetricWithLabelValues(labels...) + if err != nil { + log.Printf("Error getting metric with label values: %v", err) + } else { + metric.Observe(value) + } + } +} diff --git a/transports/config.example.json b/transports/config.example.json index 159aecac63..d54dcf6c22 100644 --- a/transports/config.example.json +++ b/transports/config.example.json @@ -1,117 +1,200 @@ { - "OpenAI": { - "keys": [ - { - "value": "env.OPENAI_API_KEY", - "models": ["gpt-4o-mini", "gpt-4-turbo"], - "weight": 1.0 + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": [ + "gpt-3.5-turbo", + "gpt-3.5-turbo-preview", + "gpt-4", + "gpt-4o", + "gpt-4o-mini", + "gpt-4-turbo", + "gpt-4-turbo-preview", + "gpt-4-vision-preview" + ], + "weight": 1.0 + } + ], + "network_config": { + "extra_headers": { + "X-Organization-ID": "org-123", + "X-Environment": "production" + }, + "default_request_timeout_in_seconds": 30, + "max_retries": 1, + "retry_backoff_initial_ms": 100, + "retry_backoff_max_ms": 2000 + }, + "concurrency_and_buffer_size": { + "concurrency": 3, + "buffer_size": 10 } - ], - "network_config": { - "default_request_timeout_in_seconds": 30, - "max_retries": 1, - "retry_backoff_initial_ms": 100, - "retry_backoff_max_ms": 2000 }, - "concurrency_and_buffer_size": { - "concurrency": 3, - "buffer_size": 10 - } - }, - "Anthropic": { - "keys": [ - { - "value": "env.ANTHROPIC_API_KEY", - "models": [ - "claude-3-7-sonnet-20250219", - "claude-3-5-sonnet-20240620", - "claude-2.1" - ], - "weight": 1.0 + "anthropic": { + "keys": [ + { + "value": "env.ANTHROPIC_API_KEY", + "models": [ + "claude-2.1", + "claude-3-sonnet-20240229", + "claude-3-haiku-20240307", + "claude-3-opus-20240229", + "claude-3-5-sonnet-20240620", + "claude-3-7-sonnet-20250219" + ], + "weight": 1.0 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 30, + "max_retries": 1, + "retry_backoff_initial_ms": 100, + "retry_backoff_max_ms": 2000 + }, + "concurrency_and_buffer_size": { + "concurrency": 3, + "buffer_size": 10 } - ], - "network_config": { - "default_request_timeout_in_seconds": 30, - "max_retries": 1, - "retry_backoff_initial_ms": 100, - "retry_backoff_max_ms": 2000 }, - "concurrency_and_buffer_size": { - "concurrency": 3, - "buffer_size": 10 - } - }, - "Bedrock": { - "keys": [ - { - "value": "env.BEDROCK_API_KEY", - "models": [ - "anthropic.claude-v2:1", - "mistral.mixtral-8x7b-instruct-v0:1", - "mistral.mistral-large-2402-v1:0", - "anthropic.claude-3-sonnet-20240229-v1:0" - ], - "weight": 1.0 + "bedrock": { + "keys": [ + { + "value": "env.BEDROCK_API_KEY", + "models": [ + "anthropic.claude-v2:1", + "mistral.mixtral-8x7b-instruct-v0:1", + "mistral.mistral-large-2402-v1:0", + "anthropic.claude-3-sonnet-20240229-v1:0" + ], + "weight": 1.0 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 30, + "max_retries": 1, + "retry_backoff_initial_ms": 100, + "retry_backoff_max_ms": 2000 + }, + "meta_config": { + "secret_access_key": "env.AWS_SECRET_ACCESS_KEY", + "region": "us-east-1" + }, + "concurrency_and_buffer_size": { + "concurrency": 3, + "buffer_size": 10 } - ], - "network_config": { - "default_request_timeout_in_seconds": 30, - "max_retries": 1, - "retry_backoff_initial_ms": 100, - "retry_backoff_max_ms": 2000 }, - "meta_config": { - "secret_access_key": "env.BEDROCK_ACCESS_KEY", - "region": "us-east-1" + "cohere": { + "keys": [ + { + "value": "env.COHERE_API_KEY", + "models": ["command-a-03-2025"], + "weight": 1.0 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 30, + "max_retries": 1, + "retry_backoff_initial_ms": 100, + "retry_backoff_max_ms": 2000 + }, + "concurrency_and_buffer_size": { + "concurrency": 3, + "buffer_size": 10 + } }, - "concurrency_and_buffer_size": { - "concurrency": 3, - "buffer_size": 10 - } - }, - "Cohere": { - "keys": [ - { - "value": "env.COHERE_API_KEY", - "models": ["command-a-03-2025"], - "weight": 1.0 + "azure": { + "keys": [ + { + "value": "env.AZURE_API_KEY", + "models": ["gpt-4o"], + "weight": 1.0 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 30, + "max_retries": 1, + "retry_backoff_initial_ms": 100, + "retry_backoff_max_ms": 2000 + }, + "meta_config": { + "endpoint": "env.AZURE_ENDPOINT", + "deployments": { + "gpt-4o": "gpt-4o-aug" + }, + "api_version": "2024-08-01-preview" + }, + "concurrency_and_buffer_size": { + "concurrency": 3, + "buffer_size": 10 } - ], - "network_config": { - "default_request_timeout_in_seconds": 30, - "max_retries": 1, - "retry_backoff_initial_ms": 100, - "retry_backoff_max_ms": 2000 }, - "concurrency_and_buffer_size": { - "concurrency": 3, - "buffer_size": 10 + "vertex": { + "keys": [], + "meta_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "env.VERTEX_CREDENTIALS" + }, + "concurrency_and_buffer_size": { + "concurrency": 3, + "buffer_size": 10 + } } }, - "Azure": { - "keys": [ + "mcp": { + "client_configs": [ { - "value": "env.AZURE_API_KEY", - "models": ["gpt-4o"], - "weight": 1.0 + "name": "your-mcp-server-name", + "connection_type": "stdio", + "stdio_config": { + "command": "npx", + "args": ["-y", "your-mcp-server-name"], + "envs": ["YOUR_MCP_SERVER_ENV_VAR"] + } + } + ] + }, + "plugins": [ + { + "name": "maxim", + "source": "../../plugins/maxim", + "type": "local", + "config": { + "api_key": "env.MAXIM_API_KEY", + "log_repo_id": "env.MAXIM_LOG_REPO_ID" } - ], - "network_config": { - "default_request_timeout_in_seconds": 30, - "max_retries": 1, - "retry_backoff_initial_ms": 100, - "retry_backoff_max_ms": 2000 - }, - "meta_config": { - "endpoint": "env.AZURE_ENDPOINT", - "deployments": { - "gpt-4o": "gpt-4o-aug" - }, - "api_version": "2024-08-01-preview" }, - "concurrency_and_buffer_size": { - "concurrency": 3, - "buffer_size": 10 + { + "name": "mocker", + "source": "../../plugins/mocker", + "type": "local", + "config": { + "enabled": true, + "default_behavior": "passthrough", + "rules": [ + { + "name": "test-mock", + "enabled": true, + "priority": 1, + "probability": 1, + "conditions": { + "providers": ["openai"] + }, + "responses": [ + { + "type": "success", + "weight": 1.0, + "content": { + "message": "This is a mock response for testing" + } + } + ] + } + ] + } } - } + ] } diff --git a/transports/go.mod b/transports/go.mod index c92d309e37..6e442b18b4 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -4,12 +4,16 @@ go 1.24.1 require ( github.com/fasthttp/router v1.5.4 - github.com/joho/godotenv v1.5.1 - github.com/maximhq/bifrost/core v1.0.2 - github.com/valyala/fasthttp v1.60.0 + github.com/maximhq/bifrost/core v1.1.6 + github.com/prometheus/client_golang v1.22.0 + github.com/valyala/fasthttp v1.62.0 + google.golang.org/genai v1.4.0 ) require ( + cloud.google.com/go v0.121.0 // indirect + cloud.google.com/go/auth v0.16.0 // indirect + cloud.google.com/go/compute/metadata v0.7.0 // indirect github.com/andybalholm/brotli v1.1.1 // indirect github.com/aws/aws-sdk-go-v2 v1.36.3 // indirect github.com/aws/aws-sdk-go-v2/config v1.29.14 // indirect @@ -24,10 +28,39 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.1 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect github.com/aws/smithy-go v1.22.3 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect github.com/goccy/go-json v0.10.5 // indirect + github.com/google/go-cmp v0.7.0 // indirect + github.com/google/s2a-go v0.1.9 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect + github.com/googleapis/gax-go/v2 v2.14.1 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/klauspost/compress v1.18.0 // indirect + github.com/mark3labs/mcp-go v0.32.0 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.62.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 // indirect + github.com/spf13/cast v1.7.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/net v0.39.0 // indirect - golang.org/x/text v0.24.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect + go.opentelemetry.io/otel v1.35.0 // indirect + go.opentelemetry.io/otel/metric v1.35.0 // indirect + go.opentelemetry.io/otel/trace v1.35.0 // indirect + golang.org/x/crypto v0.38.0 // indirect + golang.org/x/net v0.40.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.25.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250425173222-7b384671a197 // indirect + google.golang.org/grpc v1.72.0 // indirect + google.golang.org/protobuf v1.36.6 // indirect ) diff --git a/transports/go.sum b/transports/go.sum index bab9764a15..edbfab2f23 100644 --- a/transports/go.sum +++ b/transports/go.sum @@ -1,3 +1,9 @@ +cloud.google.com/go v0.121.0 h1:pgfwva8nGw7vivjZiRfrmglGWiCJBP+0OmDpenG/Fwg= +cloud.google.com/go v0.121.0/go.mod h1:rS7Kytwheu/y9buoDmu5EIpMMCI4Mb8ND4aeN4Vwj7Q= +cloud.google.com/go/auth v0.16.0 h1:Pd8P1s9WkcrBE2n/PhAwKsdrR35V3Sg2II9B+ndM3CU= +cloud.google.com/go/auth v0.16.0/go.mod h1:1howDHJ5IETh/LwYs3ZxvlkXF48aSqqJUM+5o02dNOI= +cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeOCw78U8ytSU= +cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= @@ -26,27 +32,112 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 h1:1XuUZ8mYJw9B6lzAkXhqHlJd/Xv github.com/aws/aws-sdk-go-v2/service/sts v1.33.19/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4= github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fasthttp/router v1.5.4 h1:oxdThbBwQgsDIYZ3wR1IavsNl6ZS9WdjKukeMikOnC8= github.com/fasthttp/router v1.5.4/go.mod h1:3/hysWq6cky7dTfzaaEPZGdptwjwx0qzTgFCKEWRjgc= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= -github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4= +github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= +github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q= +github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/maximhq/bifrost/core v1.0.1 h1:B0u6o13faUexA+V0EUU0bsLW2dHg9+R2TZKQzPzCxlY= -github.com/maximhq/bifrost/core v1.0.1/go.mod h1:4+Ept2EnX1EEjH/mBuSwK7eE56znI/BCoCbIrx25/x8= -github.com/maximhq/bifrost/core v1.0.2 h1:GG1CGrvbz5lbdDudlJodKHx9pHr0VAoUd5lhgxUWc00= -github.com/maximhq/bifrost/core v1.0.2/go.mod h1:ZF8LVnUwVzHZ3SkCQPvXXmu0w3b4sjRLS6ij9aPYcjg= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7MU8= +github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/maximhq/bifrost/core v1.1.6 h1:rZrfPVcAfNggfBaOTdu/w+xNwDhW79bfexXsw8LRoMQ= +github.com/maximhq/bifrost/core v1.1.6/go.mod h1:yMRCncTgKYBIrECSRVxMbY3BL8CjLbipJlc644jryxc= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= +github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= +github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc= github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.60.0 h1:kBRYS0lOhVJ6V+bYN8PqAHELKHtXqwq9zNMLKx1MBsw= -github.com/valyala/fasthttp v1.60.0/go.mod h1:iY4kDgV3Gc6EqhRZ8icqcmlG6bqhcDXfuHgTO4FXCvc= +github.com/valyala/fasthttp v1.62.0 h1:8dKRBX/y2rCzyc6903Zu1+3qN0H/d2MsxPPmVNamiH0= +github.com/valyala/fasthttp v1.62.0/go.mod h1:FCINgr4GKdKqV8Q0xv8b+UxPV+H/O5nNFo3D+r54Htg= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= -golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= -golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= -golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= -golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 h1:sbiXRNDSWJOTobXh5HyQKjq6wUC5tNybqjIqDpAY4CU= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0/go.mod h1:69uWxva0WgAA/4bu2Yy70SLDBwZXuQ6PbBpbsa5iZrQ= +go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= +go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= +go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= +go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= +go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY= +go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg= +go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o= +go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w= +go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= +go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= +golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= +golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +google.golang.org/genai v1.4.0 h1:i3D6q5UTLoAHuXOaDtJnA4Lcz6v+aBP3phGBYOgzEm4= +google.golang.org/genai v1.4.0/go.mod h1:TyfOKRz/QyCaj6f/ZDt505x+YreXnY40l2I6k8TvgqY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250425173222-7b384671a197 h1:29cjnHVylHwTzH66WfFZqgSQgnxzvWE+jvBwpZCLRxY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250425173222-7b384671a197/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/grpc v1.72.0 h1:S7UkcVa60b5AAQTaO6ZKamFp1zMZSU0fGDK2WZLbBnM= +google.golang.org/grpc v1.72.0/go.mod h1:wH5Aktxcg25y1I3w7H69nHfXdOG3UiadoBtjh3izSDM= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/transports/http/main.go b/transports/http/main.go deleted file mode 100644 index 8af6fb3171..0000000000 --- a/transports/http/main.go +++ /dev/null @@ -1,443 +0,0 @@ -// Package http provides an HTTP service using FastHTTP that exposes endpoints -// for text and chat completions using various AI model providers (OpenAI, Anthropic, Bedrock, etc.). - -// The HTTP service provides two main endpoints: -// - /v1/text/completions: For text completion requests -// - /v1/chat/completions: For chat completion requests - -// Configuration is handled through a JSON config file and environment variables: -// - Use -config flag to specify the config file location -// - Use -env flag to specify the .env file location -// - Use -port flag to specify the server port (default: 8080) -// - Use -pool-size flag to specify the initial connection pool size (default: 300) - -// try running the server with: -// go run http.go -config config.example.json -env .env -port 8080 -pool-size 300 -// after setting the environment variables present in config.example.json in your .env file. - -package main - -import ( - "encoding/json" - "errors" - "flag" - "fmt" - "log" - "os" - "reflect" - "strings" - "sync" - - "github.com/fasthttp/router" - "github.com/joho/godotenv" - bifrost "github.com/maximhq/bifrost/core" - schemas "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/core/schemas/meta" - "github.com/valyala/fasthttp" -) - -// Command line flags -var ( - initialPoolSize int // Initial size of the connection pool - dropExcessRequests bool // Drop excess requests - port string // Port to run the server on - configPath string // Path to the config file - envPath string // Path to the .env file -) - -// init initializes command line flags with default values. -// It also checks for environment variables that might override the defaults. -func init() { - flag.IntVar(&initialPoolSize, "pool-size", 300, "Initial pool size for Bifrost") - flag.StringVar(&port, "port", "8080", "Port to run the server on") - flag.StringVar(&configPath, "config", "", "Path to the config file") - flag.StringVar(&envPath, "env", "", "Path to the .env file") - flag.BoolVar(&dropExcessRequests, "drop-excess-requests", false, "Drop excess requests") - flag.Parse() - - if configPath == "" { - log.Fatalf("config path is required") - } - - if envPath == "" { - log.Fatalf("env path is required") - } -} - -// ProviderConfig represents the configuration for a specific AI model provider. -// It includes API keys, network settings, provider-specific metadata, and concurrency settings. -type ProviderConfig struct { - Keys []schemas.Key `json:"keys"` // API keys for the provider - NetworkConfig *schemas.NetworkConfig `json:"network_config,omitempty"` // Network-related settings - MetaConfig *schemas.MetaConfig `json:"-"` // Provider-specific metadata - ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size,omitempty"` // Concurrency settings -} - -// ConfigMap maps provider names to their configurations. -type ConfigMap map[schemas.ModelProvider]ProviderConfig - -// readConfig reads and parses the configuration file. -// It handles case conversion for provider names and sets up provider-specific metadata. -// Returns a ConfigMap containing all provider configurations. -// Panics if the config file cannot be read or parsed. -// -// In the config file, use placeholder keys (e.g., env.OPENAI_API_KEY) instead of hardcoding actual values. -// These placeholders will be replaced with the corresponding values from the .env file. -// Location of the .env file is specified by the -env flag. It -// Example: -// -// "keys":[{ -// "value": "env.OPENAI_API_KEY" -// "models": ["gpt-4o-mini", "gpt-4-turbo"], -// "weight": 1.0 -// }] -// -// In this example, OPENAI_API_KEY refers to a key in the .env file. At runtime, its value will be used to replace the placeholder. -// Same setup applies to keys in meta configs of all the providers. -// Example: -// -// "meta_config": { -// "secret_access_key": "env.BEDROCK_ACCESS_KEY" -// "region": "env.BEDROCK_REGION" -// } -// -// In this example, BEDROCK_ACCESS_KEY and BEDROCK_REGION refer to keys in the .env file. -func readConfig(configLocation string) ConfigMap { - data, err := os.ReadFile(configLocation) - if err != nil { - log.Fatalf("failed to read config JSON file: %v", err) - } - - // First unmarshal into a map with string keys to handle case conversion - var rawConfig map[string]ProviderConfig - if err := json.Unmarshal(data, &rawConfig); err != nil { - log.Fatalf("failed to unmarshal JSON: %v", err) - } - - if rawConfig == nil { - log.Fatalf("provided config is nil") - } - - // Create a new config map with lowercase provider names - config := make(ConfigMap) - for rawProvider, cfg := range rawConfig { - provider := schemas.ModelProvider(strings.ToLower(rawProvider)) - - switch provider { - case schemas.Azure: - var azureMetaConfig meta.AzureMetaConfig - if err := json.Unmarshal(data, &struct { - Azure struct { - MetaConfig *meta.AzureMetaConfig `json:"meta_config"` - } `json:"Azure"` - }{Azure: struct { - MetaConfig *meta.AzureMetaConfig `json:"meta_config"` - }{&azureMetaConfig}}); err != nil { - log.Printf("warning: failed to unmarshal Azure meta config: %v", err) - } - var metaConfig schemas.MetaConfig = &azureMetaConfig - cfg.MetaConfig = &metaConfig - case schemas.Bedrock: - var bedrockMetaConfig meta.BedrockMetaConfig - if err := json.Unmarshal(data, &struct { - Bedrock struct { - MetaConfig *meta.BedrockMetaConfig `json:"meta_config"` - } `json:"Bedrock"` - }{Bedrock: struct { - MetaConfig *meta.BedrockMetaConfig `json:"meta_config"` - }{&bedrockMetaConfig}}); err != nil { - log.Printf("warning: failed to unmarshal Bedrock meta config: %v", err) - } - var metaConfig schemas.MetaConfig = &bedrockMetaConfig - cfg.MetaConfig = &metaConfig - } - - config[provider] = cfg - } - - return config -} - -// BaseAccount implements the Account interface for Bifrost. -// It manages provider configurations and API keys. -type BaseAccount struct { - Config ConfigMap // Map of provider configurations - mu sync.Mutex // Mutex to protect Config access -} - -// GetConfiguredProviders returns a list of all configured providers. -// Implements the Account interface. -func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { - baseAccount.mu.Lock() - defer baseAccount.mu.Unlock() - - providers := make([]schemas.ModelProvider, 0, len(baseAccount.Config)) - for provider := range baseAccount.Config { - providers = append(providers, provider) - } - return providers, nil -} - -// GetKeysForProvider returns the API keys configured for a specific provider. -// Implements the Account interface. -func (baseAccount *BaseAccount) GetKeysForProvider(providerKey schemas.ModelProvider) ([]schemas.Key, error) { - baseAccount.mu.Lock() - defer baseAccount.mu.Unlock() - - return baseAccount.Config[providerKey].Keys, nil -} - -// GetConfigForProvider returns the complete configuration for a specific provider. -// Implements the Account interface. -func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { - baseAccount.mu.Lock() - defer baseAccount.mu.Unlock() - - config, exists := baseAccount.Config[providerKey] - if !exists { - return nil, errors.New("config for provider not found") - } - - providerConfig := &schemas.ProviderConfig{} - - if config.NetworkConfig != nil { - providerConfig.NetworkConfig = *config.NetworkConfig - } - - if config.MetaConfig != nil { - providerConfig.MetaConfig = *config.MetaConfig - } - - if config.ConcurrencyAndBufferSize != nil { - providerConfig.ConcurrencyAndBufferSize = *config.ConcurrencyAndBufferSize - } - - return providerConfig, nil -} - -// readKeys reads environment variables from a .env file and updates the provider configurations. -// It replaces values starting with "env." in the config with actual values from the environment. -// Returns an error if any required environment variable is missing. -func (baseAccount *BaseAccount) readKeys(envLocation string) error { - envVars, err := godotenv.Read(envLocation) - if err != nil { - return fmt.Errorf("failed to read .env file: %w", err) - } - - // Helper function to check and replace env values - replaceEnvValue := func(value string) (string, error) { - if strings.HasPrefix(value, "env.") { - envKey := strings.TrimPrefix(value, "env.") - if envValue, exists := envVars[envKey]; exists { - return envValue, nil - } - return "", fmt.Errorf("environment variable %s not found in .env file", envKey) - } - return value, nil - } - - // Helper function to recursively check and replace env values in a struct - var processStruct func(interface{}) error - processStruct = func(v interface{}) error { - val := reflect.ValueOf(v) - - // Dereference pointer if present - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - - // Handle interface types - if val.Kind() == reflect.Interface { - val = val.Elem() - // If the interface value is a pointer, dereference it - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - } - - if val.Kind() != reflect.Struct { - return nil - } - - typ := val.Type() - for i := 0; i < val.NumField(); i++ { - field := val.Field(i) - fieldType := typ.Field(i) - - // Skip unexported fields - if !field.CanSet() { - continue - } - - switch field.Kind() { - case reflect.String: - if field.CanSet() { - value := field.String() - if strings.HasPrefix(value, "env.") { - newValue, err := replaceEnvValue(value) - if err != nil { - return fmt.Errorf("field %s: %w", fieldType.Name, err) - } - field.SetString(newValue) - } - } - case reflect.Interface: - if !field.IsNil() { - if err := processStruct(field.Interface()); err != nil { - return err - } - } - } - } - return nil - } - - // Lock the config map for the entire update operation - baseAccount.mu.Lock() - defer baseAccount.mu.Unlock() - - // Check and replace values in provider configs - for provider, config := range baseAccount.Config { - // Check keys - for i, key := range config.Keys { - newValue, err := replaceEnvValue(key.Value) - if err != nil { - return fmt.Errorf("provider %s: %w", provider, err) - } - config.Keys[i].Value = newValue - } - - // Check meta config if it exists - if config.MetaConfig != nil { - if err := processStruct(config.MetaConfig); err != nil { - return fmt.Errorf("provider %s: %w", provider, err) - } - } - - baseAccount.Config[provider] = config - } - - return nil -} - -// CompletionRequest represents a request for either text or chat completion. -// It includes all necessary fields for both types of completions. -type CompletionRequest struct { - Provider schemas.ModelProvider `json:"provider"` // The AI model provider to use - Messages []schemas.Message `json:"messages"` // Chat messages (for chat completion) - Text string `json:"text"` // Text input (for text completion) - Model string `json:"model"` // Model to use - Params *schemas.ModelParameters `json:"params"` // Additional model parameters - Fallbacks []schemas.Fallback `json:"fallbacks"` // Fallback providers and models -} - -// handleCompletion processes both text and chat completion requests. -// It handles request parsing, validation, and response formatting. -func handleCompletion(ctx *fasthttp.RequestCtx, client *bifrost.Bifrost, isChat bool) { - var req CompletionRequest - if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { - ctx.SetStatusCode(fasthttp.StatusBadRequest) - ctx.SetBodyString(fmt.Sprintf("invalid request format: %v", err)) - return - } - - if req.Provider == "" { - ctx.SetStatusCode(fasthttp.StatusBadRequest) - ctx.SetBodyString("Provider is required") - return - } - - bifrostReq := &schemas.BifrostRequest{ - Model: req.Model, - Params: req.Params, - Fallbacks: req.Fallbacks, - } - - if isChat { - if len(req.Messages) == 0 { - ctx.SetStatusCode(fasthttp.StatusBadRequest) - ctx.SetBodyString("Messages array is required") - return - } - bifrostReq.Input = schemas.RequestInput{ - ChatCompletionInput: &req.Messages, - } - } else { - if req.Text == "" { - ctx.SetStatusCode(fasthttp.StatusBadRequest) - ctx.SetBodyString("Text is required") - return - } - bifrostReq.Input = schemas.RequestInput{ - TextCompletionInput: &req.Text, - } - } - - var resp *schemas.BifrostResponse - var err *schemas.BifrostError - if isChat { - resp, err = client.ChatCompletionRequest(req.Provider, bifrostReq, ctx) - } else { - resp, err = client.TextCompletionRequest(req.Provider, bifrostReq, ctx) - } - - if err != nil { - if err.IsBifrostError { - ctx.SetStatusCode(fasthttp.StatusInternalServerError) - } else { - ctx.SetStatusCode(fasthttp.StatusBadRequest) - } - ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(err) - return - } - - ctx.SetStatusCode(fasthttp.StatusOK) - ctx.SetContentType("application/json") - json.NewEncoder(ctx).Encode(resp) -} - -// main is the entry point of the application. -// It: -// 1. Reads and parses configuration -// 2. Initializes the Bifrost client -// 3. Sets up HTTP routes -// 4. Starts the HTTP server -func main() { - config := readConfig(configPath) - account := &BaseAccount{Config: config} - - if err := account.readKeys(envPath); err != nil { - log.Printf("warning: failed to read environment variables: %v", err) - } - - client, err := bifrost.Init(schemas.BifrostConfig{ - Account: account, - InitialPoolSize: initialPoolSize, - DropExcessRequests: dropExcessRequests, - }) - if err != nil { - log.Fatalf("failed to initialize bifrost: %v", err) - } - - r := router.New() - - r.POST("/v1/text/completions", func(ctx *fasthttp.RequestCtx) { - handleCompletion(ctx, client, false) - }) - - r.POST("/v1/chat/completions", func(ctx *fasthttp.RequestCtx) { - handleCompletion(ctx, client, true) - }) - - server := &fasthttp.Server{ - Handler: r.Handler, - } - - fmt.Printf("Starting HTTP server on port %s\n", port) - if err := server.ListenAndServe(fmt.Sprintf(":%s", port)); err != nil { - log.Fatalf("failed to start server: %v", err) - } - - client.Shutdown() -}