diff --git a/.editorconfig b/.editorconfig index 7223b342a2..71e2db6637 100644 --- a/.editorconfig +++ b/.editorconfig @@ -5,5 +5,9 @@ insert_final_newline = false end_of_line = lf charset = utf-8 +[*.go] +indent_style = tab +indent_size = 4 + [*.{js,jsx,ts,tsx,mjs,json,md,css,scss,html}] insert_final_newline = false diff --git a/.github/workflows/release-pipeline.yml b/.github/workflows/release-pipeline.yml index 27f6391ba5..f052d1b11f 100644 --- a/.github/workflows/release-pipeline.yml +++ b/.github/workflows/release-pipeline.yml @@ -3,7 +3,7 @@ name: Release Pipeline # Triggers automatically on push to main when any version file changes on: push: - branches: ["main"] + branches: ["main", "v1.4.0"] # Prevent concurrent runs concurrency: @@ -606,7 +606,7 @@ jobs: fi # Build the message with proper formatting - MESSAGE=$(printf "šŸš€ **Release Pipeline Complete**\n\n**Components:**\n• Core: %s\n• Framework: %s\n• Plugins: %s\n• Bifrost HTTP: %s\n\n**Details:**\n• Branch: \`main\`\n• Commit: \`%.8s\`\n• Author: %s\n\n[View Workflow Run](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }})" "$CORE_STATUS" "$FRAMEWORK_STATUS" "$PLUGINS_STATUS" "$BIFROST_STATUS" "${{ github.sha }}" "${{ github.actor }}") + MESSAGE=$(printf "šŸš€ **Release Pipeline Complete**\n\n**Components:**\n• Core: %s\n• Framework: %s\n• Plugins: %s\n• Bifrost HTTP: %s\n\n**Details:**\n• Branch: \`${{ github.ref_name }}\`\n• Commit: \`%.8s\`\n• Author: %s\n\n[View Workflow Run](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }})" "$CORE_STATUS" "$FRAMEWORK_STATUS" "$PLUGINS_STATUS" "$BIFROST_STATUS" "${{ github.sha }}" "${{ github.actor }}") payload="$(jq -n --arg content "$MESSAGE" '{content:$content}')" curl -sS -H "Content-Type: application/json" -d "$payload" "$DISCORD_WEBHOOK" diff --git a/.github/workflows/scripts/push-mintlify-changelog.sh b/.github/workflows/scripts/push-mintlify-changelog.sh index cb322ef422..2e4c835f18 100755 --- a/.github/workflows/scripts/push-mintlify-changelog.sh +++ b/.github/workflows/scripts/push-mintlify-changelog.sh @@ -236,7 +236,18 @@ if ! grep -q "\"$route\"" docs/docs.json; then fi # Pulling again before committing -git pull origin main +CURRENT_BRANCH="$(git rev-parse --abbrev-ref HEAD)" +if [ "$CURRENT_BRANCH" = "HEAD" ]; then + # In detached HEAD state (common in CI), use GITHUB_REF_NAME or default to main + CURRENT_BRANCH="${GITHUB_REF_NAME:-main}" +fi + +echo "Pulling latest changes from origin/$CURRENT_BRANCH..." +if ! git pull origin "$CURRENT_BRANCH"; then + echo "āŒ Error: git pull origin $CURRENT_BRANCH failed" + exit 1 +fi + # Commit and push changes git add docs/changelogs/$VERSION.mdx git add docs/docs.json @@ -247,4 +258,4 @@ done git config user.name "github-actions[bot]" git config user.email "41898282+github-actions[bot]@users.noreply.github.com" git commit -m "Adds changelog for $VERSION --skip-pipeline" -git push origin main +git push origin "$CURRENT_BRANCH" diff --git a/.github/workflows/scripts/release-bifrost-http.sh b/.github/workflows/scripts/release-bifrost-http.sh index 8b10586fdf..0e47988676 100755 --- a/.github/workflows/scripts/release-bifrost-http.sh +++ b/.github/workflows/scripts/release-bifrost-http.sh @@ -237,7 +237,7 @@ for config in "${CONFIGS_TO_TEST[@]}"; do SERVER_LOG=$(mktemp) # Start the server in background with a timeout, logging to file and console - timeout 180s $TEST_BINARY --app-dir "$config_path" --port 18080 --log-level debug 2>&1 | tee "$SERVER_LOG" & + timeout 120s $TEST_BINARY --app-dir "$config_path" --port 18080 --log-level debug 2>&1 | tee "$SERVER_LOG" & SERVER_PID=$! # Wait for server to be ready by looking for the startup message @@ -313,7 +313,17 @@ echo "āœ… Transport build validation successful" # Commit and push changes if any # First, pull latest changes to avoid conflicts -git pull origin main +CURRENT_BRANCH="$(git rev-parse --abbrev-ref HEAD)" +if [ "$CURRENT_BRANCH" = "HEAD" ]; then + # In detached HEAD state (common in CI), use GITHUB_REF_NAME or default to main + CURRENT_BRANCH="${GITHUB_REF_NAME:-main}" +fi + +echo "Pulling latest changes from origin/$CURRENT_BRANCH..." +if ! git pull origin "$CURRENT_BRANCH"; then + echo "āŒ Error: git pull origin $CURRENT_BRANCH failed" + exit 1 +fi # Stage any changes made to transports/ git add transports/ diff --git a/.github/workflows/scripts/release-framework.sh b/.github/workflows/scripts/release-framework.sh index d2ead7c25a..b5fef1e8f4 100755 --- a/.github/workflows/scripts/release-framework.sh +++ b/.github/workflows/scripts/release-framework.sh @@ -26,7 +26,31 @@ TAG_NAME="framework/${VERSION}" echo "šŸ“¦ Releasing framework $VERSION..." # Ensure we have the latest version -git pull origin +CURRENT_BRANCH="$(git rev-parse --abbrev-ref HEAD)" +if [ "$CURRENT_BRANCH" = "HEAD" ]; then + # In detached HEAD state (common in CI), use GITHUB_REF_NAME or default to main + CURRENT_BRANCH="${GITHUB_REF_NAME:-main}" +fi + +echo "Pulling latest changes from origin/$CURRENT_BRANCH..." +if ! git pull origin "$CURRENT_BRANCH"; then + echo "āŒ Error: git pull origin $CURRENT_BRANCH failed" + exit 1 +fi + +# Check for merge conflicts or unexpected working-tree changes +if ! git diff --quiet; then + echo "āŒ Error: Unstaged changes detected after pull (possible merge conflict)" + git status --short + exit 1 +fi + +if ! git diff --cached --quiet; then + echo "āŒ Error: Staged changes detected after pull (unexpected state)" + git status --short + exit 1 +fi + # Fetching all tags git fetch --tags >/dev/null 2>&1 || true @@ -106,6 +130,10 @@ if ! git diff --cached --quiet; then git commit -m "framework: bump core to $CORE_VERSION --skip-pipeline" # Push the bump so go.mod/go.sum changes are recorded on the branch CURRENT_BRANCH="$(git rev-parse --abbrev-ref HEAD)" + if [ "$CURRENT_BRANCH" = "HEAD" ]; then + # In detached HEAD state (common in CI), use GITHUB_REF_NAME or default to main + CURRENT_BRANCH="${GITHUB_REF_NAME:-main}" + fi git push origin "$CURRENT_BRANCH" echo "šŸ”§ Pushed framework bump to $CURRENT_BRANCH" else diff --git a/.github/workflows/scripts/release-single-plugin.sh b/.github/workflows/scripts/release-single-plugin.sh index 5c20f9bbd8..6f793354e1 100755 --- a/.github/workflows/scripts/release-single-plugin.sh +++ b/.github/workflows/scripts/release-single-plugin.sh @@ -28,7 +28,17 @@ else fi # Ensure we have the latest version -git pull origin +CURRENT_BRANCH="$(git rev-parse --abbrev-ref HEAD)" +if [ "$CURRENT_BRANCH" = "HEAD" ]; then + # In detached HEAD state (common in CI), use GITHUB_REF_NAME or default to main + CURRENT_BRANCH="${GITHUB_REF_NAME:-main}" +fi + +echo "Pulling latest changes from origin/$CURRENT_BRANCH..." +if ! git pull origin "$CURRENT_BRANCH"; then + echo "āŒ Error: git pull origin $CURRENT_BRANCH failed" + exit 1 +fi echo "šŸ”Œ Releasing plugin: $PLUGIN_NAME" echo "šŸ”§ Core version: $CORE_VERSION" @@ -66,19 +76,24 @@ if [ -f "go.mod" ]; then # Run tests with coverage if any exist if go list ./... | grep -q .; then - echo "🧪 Running plugin tests with coverage..." - go test -coverprofile=coverage.txt -coverpkg=./... ./... - - # Upload coverage to Codecov - if [ -n "${CODECOV_TOKEN:-}" ]; then - echo "šŸ“Š Uploading coverage to Codecov..." - curl -Os https://uploader.codecov.io/latest/linux/codecov - chmod +x codecov - ./codecov -t "$CODECOV_TOKEN" -f coverage.txt -F "plugin-${PLUGIN_NAME}" - rm -f codecov coverage.txt + # Skip tests for governance plugin (no tests yet) + if [ "$PLUGIN_NAME" = "governance" ]; then + echo "ā„¹ļø Skipping tests for governance plugin" else - echo "ā„¹ļø CODECOV_TOKEN not set, skipping coverage upload" - rm -f coverage.txt + echo "🧪 Running plugin tests with coverage..." + go test -coverprofile=coverage.txt -coverpkg=./... ./... + + # Upload coverage to Codecov + if [ -n "${CODECOV_TOKEN:-}" ]; then + echo "šŸ“Š Uploading coverage to Codecov..." + curl -Os https://uploader.codecov.io/latest/linux/codecov + chmod +x codecov + ./codecov -t "$CODECOV_TOKEN" -f coverage.txt -F "plugin-${PLUGIN_NAME}" + rm -f codecov coverage.txt + else + echo "ā„¹ļø CODECOV_TOKEN not set, skipping coverage upload" + rm -f coverage.txt + fi fi fi diff --git a/.gitignore b/.gitignore index 540b0b388d..3c1908a6ce 100644 --- a/.gitignore +++ b/.gitignore @@ -53,4 +53,6 @@ test-reports # Cursor specific -.cursor/ \ No newline at end of file +.cursor/ +build/ +target/ \ No newline at end of file diff --git a/Makefile b/Makefile index b136c5eceb..ebb4ac5fb3 100644 --- a/Makefile +++ b/Makefile @@ -26,7 +26,7 @@ include recipes/fly.mk include recipes/ecs.mk include recipes/local-k8s.mk -.PHONY: all help dev build-ui build run install-air clean test install-ui setup-workspace work-init work-clean docs build-docker-image cleanup-enterprise mod-tidy +.PHONY: all help dev build-ui build run install-air clean test install-ui setup-workspace work-init work-clean docs build-docker-image cleanup-enterprise mod-tidy test-integrations-py test-integrations-ts all: help @@ -44,7 +44,7 @@ help: ## Show this help message @echo " LOG_LEVEL Logger level: debug|info|warn|error (default: info)" @echo " APP_DIR App data directory inside container (default: /app/data)" @echo " LOCAL Use local go.work for builds (e.g., make build LOCAL=1)" - @echo " DEBUG Enable air + delve debugger on port 2345 (e.g., make dev DEBUG=1)" + @echo " DEBUG Enable delve debugger on port 2345 (e.g., make dev DEBUG=1, make test-core DEBUG=1)" @echo "" @echo "$(YELLOW)Test Configuration:$(NC)" @echo " TEST_REPORTS_DIR Directory for HTML test reports (default: test-reports)" @@ -333,7 +333,7 @@ test: install-gotestsum ## Run tests for bifrost-http echo "$(CYAN)JUnit XML report: $(TEST_REPORTS_DIR)/bifrost-http.xml$(NC)"; \ fi -test-core: install-gotestsum ## Run core tests (Usage: make test-core PROVIDER=openai TESTCASE=TestName or PATTERN=substring) +test-core: install-gotestsum $(if $(DEBUG),install-delve) ## Run core tests (Usage: make test-core PROVIDER=openai TESTCASE=TestName or PATTERN=substring, DEBUG=1 for debugger) @echo "$(GREEN)Running core tests...$(NC)" @mkdir -p $(TEST_REPORTS_DIR) @if [ -n "$(PATTERN)" ] && [ -n "$(TESTCASE)" ]; then \ @@ -356,6 +356,10 @@ test-core: install-gotestsum ## Run core tests (Usage: make test-core PROVIDER=o echo "$(YELLOW)Loading environment variables from .env...$(NC)"; \ set -a; . ./.env; set +a; \ fi; \ + if [ -n "$(DEBUG)" ]; then \ + echo "$(CYAN)Debug mode enabled - delve debugger will listen on port 2345$(NC)"; \ + echo "$(YELLOW)Attach your debugger to localhost:2345$(NC)"; \ + fi; \ if [ -n "$(PROVIDER)" ]; then \ PROVIDER_TEST_NAME=$$(echo "$(PROVIDER)" | awk '{print toupper(substr($$0,1,1)) tolower(substr($$0,2))}' | sed 's/openai/OpenAI/i; s/sgl/SGL/i'); \ if [ -n "$(TESTCASE)" ]; then \ @@ -365,10 +369,14 @@ test-core: install-gotestsum ## Run core tests (Usage: make test-core PROVIDER=o CLEAN_TESTCASE=$$(echo "$$CLEAN_TESTCASE" | sed 's|^Test[A-Z][A-Za-z]*/[A-Z][A-Za-z]*Tests/||'); \ echo "$(CYAN)Running Test$${PROVIDER_TEST_NAME}/$${PROVIDER_TEST_NAME}Tests/$$CLEAN_TESTCASE...$(NC)"; \ REPORT_FILE="$(TEST_REPORTS_DIR)/core-$(PROVIDER)-$$(echo $$CLEAN_TESTCASE | sed 's|/|_|g').xml"; \ - cd core/providers/$(PROVIDER) && GOWORK=off gotestsum \ - --format=$(GOTESTSUM_FORMAT) \ - --junitfile=../../../$$REPORT_FILE \ - -- -v -run "^Test$${PROVIDER_TEST_NAME}$$/.*Tests/$$CLEAN_TESTCASE$$" || TEST_FAILED=1; \ + if [ -n "$(DEBUG)" ]; then \ + cd core/providers/$(PROVIDER) && GOWORK=off dlv test --headless --listen=:2345 --api-version=2 -- -test.v -test.run "^Test$${PROVIDER_TEST_NAME}$$/.*Tests/$$CLEAN_TESTCASE$$" || TEST_FAILED=1; \ + else \ + cd core/providers/$(PROVIDER) && GOWORK=off gotestsum \ + --format=$(GOTESTSUM_FORMAT) \ + --junitfile=../../../$$REPORT_FILE \ + -- -v -run "^Test$${PROVIDER_TEST_NAME}$$/.*Tests/$$CLEAN_TESTCASE$$" || TEST_FAILED=1; \ + fi; \ cd ../../..; \ $(MAKE) cleanup-junit-xml REPORT_FILE=$$REPORT_FILE; \ if [ -z "$$CI" ] && [ -z "$$GITHUB_ACTIONS" ] && [ -z "$$GITLAB_CI" ] && [ -z "$$CIRCLECI" ] && [ -z "$$JENKINS_HOME" ]; then \ @@ -389,10 +397,14 @@ test-core: install-gotestsum ## Run core tests (Usage: make test-core PROVIDER=o elif [ -n "$(PATTERN)" ]; then \ echo "$(CYAN)Running tests matching '$(PATTERN)' for $${PROVIDER_TEST_NAME}...$(NC)"; \ REPORT_FILE="$(TEST_REPORTS_DIR)/core-$(PROVIDER)-$(PATTERN).xml"; \ - cd core/providers/$(PROVIDER) && GOWORK=off gotestsum \ - --format=$(GOTESTSUM_FORMAT) \ - --junitfile=../../../$$REPORT_FILE \ - -- -v -run ".*$(PATTERN).*" || TEST_FAILED=1; \ + if [ -n "$(DEBUG)" ]; then \ + cd core/providers/$(PROVIDER) && GOWORK=off dlv test --headless --listen=:2345 --api-version=2 -- -test.v -test.run ".*$(PATTERN).*" || TEST_FAILED=1; \ + else \ + cd core/providers/$(PROVIDER) && GOWORK=off gotestsum \ + --format=$(GOTESTSUM_FORMAT) \ + --junitfile=../../../$$REPORT_FILE \ + -- -v -run ".*$(PATTERN).*" || TEST_FAILED=1; \ + fi; \ cd ../../..; \ $(MAKE) cleanup-junit-xml REPORT_FILE=$$REPORT_FILE; \ if [ -z "$$CI" ] && [ -z "$$GITHUB_ACTIONS" ] && [ -z "$$GITLAB_CI" ] && [ -z "$$CIRCLECI" ] && [ -z "$$JENKINS_HOME" ]; then \ @@ -413,10 +425,14 @@ test-core: install-gotestsum ## Run core tests (Usage: make test-core PROVIDER=o else \ echo "$(CYAN)Running Test$${PROVIDER_TEST_NAME}...$(NC)"; \ REPORT_FILE="$(TEST_REPORTS_DIR)/core-$(PROVIDER).xml"; \ - cd core/providers/$(PROVIDER) && GOWORK=off gotestsum \ - --format=$(GOTESTSUM_FORMAT) \ - --junitfile=../../../$$REPORT_FILE \ - -- -v -run "^Test$${PROVIDER_TEST_NAME}$$" || TEST_FAILED=1; \ + if [ -n "$(DEBUG)" ]; then \ + cd core/providers/$(PROVIDER) && GOWORK=off dlv test --headless --listen=:2345 --api-version=2 -- -test.v -test.run "^Test$${PROVIDER_TEST_NAME}$$" || TEST_FAILED=1; \ + else \ + cd core/providers/$(PROVIDER) && GOWORK=off gotestsum \ + --format=$(GOTESTSUM_FORMAT) \ + --junitfile=../../../$$REPORT_FILE \ + -- -v -run "^Test$${PROVIDER_TEST_NAME}$$" || TEST_FAILED=1; \ + fi; \ cd ../../..; \ $(MAKE) cleanup-junit-xml REPORT_FILE=$$REPORT_FILE; \ if [ -z "$$CI" ] && [ -z "$$GITHUB_ACTIONS" ] && [ -z "$$GITLAB_CI" ] && [ -z "$$CIRCLECI" ] && [ -z "$$JENKINS_HOME" ]; then \ @@ -444,16 +460,24 @@ test-core: install-gotestsum ## Run core tests (Usage: make test-core PROVIDER=o if [ -n "$(PATTERN)" ]; then \ echo "$(CYAN)Running tests matching '$(PATTERN)' across all providers...$(NC)"; \ REPORT_FILE="$(TEST_REPORTS_DIR)/core-all-$(PATTERN).xml"; \ - cd core && GOWORK=off gotestsum \ - --format=$(GOTESTSUM_FORMAT) \ - --junitfile=../$$REPORT_FILE \ - -- -v -run ".*$(PATTERN).*" ./providers/... || TEST_FAILED=1; \ + if [ -n "$(DEBUG)" ]; then \ + cd core && GOWORK=off dlv test --headless --listen=:2345 --api-version=2 ./providers/... -- -test.v -test.run ".*$(PATTERN).*" || TEST_FAILED=1; \ + else \ + cd core && GOWORK=off gotestsum \ + --format=$(GOTESTSUM_FORMAT) \ + --junitfile=../$$REPORT_FILE \ + -- -v -run ".*$(PATTERN).*" ./providers/... || TEST_FAILED=1; \ + fi; \ else \ REPORT_FILE="$(TEST_REPORTS_DIR)/core-all.xml"; \ - cd core && GOWORK=off gotestsum \ - --format=$(GOTESTSUM_FORMAT) \ - --junitfile=../$$REPORT_FILE \ - -- -v ./providers/... || TEST_FAILED=1; \ + if [ -n "$(DEBUG)" ]; then \ + cd core && GOWORK=off dlv test --headless --listen=:2345 --api-version=2 ./providers/... -- -test.v || TEST_FAILED=1; \ + else \ + cd core && GOWORK=off gotestsum \ + --format=$(GOTESTSUM_FORMAT) \ + --junitfile=../$$REPORT_FILE \ + -- -v ./providers/... || TEST_FAILED=1; \ + fi; \ fi; \ cd ..; \ $(MAKE) cleanup-junit-xml REPORT_FILE=$$REPORT_FILE; \ @@ -606,10 +630,10 @@ test-chatbot: ## Run interactive chatbot integration test (Usage: RUN_CHATBOT_TE fi @cd core && RUN_CHATBOT_TEST=1 go test -v -run TestChatbot -test-integrations: ## Run Python integration tests (Usage: make test-integrations [INTEGRATION=openai] [TESTCASE=test_name] [PATTERN=substring] [VERBOSE=1]) +test-integrations-py: ## Run Python integration tests (Usage: make test-integrations-py [INTEGRATION=openai] [TESTCASE=test_name] [PATTERN=substring] [VERBOSE=1]) @echo "$(GREEN)Running Python integration tests...$(NC)" - @if [ ! -d "tests/integrations" ]; then \ - echo "$(RED)Error: tests/integrations directory not found$(NC)"; \ + @if [ ! -d "tests/integrations/python" ]; then \ + echo "$(RED)Error: tests/integrations/python directory not found$(NC)"; \ exit 1; \ fi; \ if [ -n "$(PATTERN)" ] && [ -n "$(TESTCASE)" ]; then \ @@ -619,7 +643,7 @@ test-integrations: ## Run Python integration tests (Usage: make test-integration fi; \ if [ -n "$(TESTCASE)" ] && [ -z "$(INTEGRATION)" ]; then \ echo "$(RED)Error: TESTCASE requires INTEGRATION to be specified$(NC)"; \ - echo "$(YELLOW)Usage: make test-integrations INTEGRATION=anthropic TESTCASE=test_05_end2end_tool_calling$(NC)"; \ + echo "$(YELLOW)Usage: make test-integrations-py INTEGRATION=anthropic TESTCASE=test_05_end2end_tool_calling$(NC)"; \ exit 1; \ fi; \ if [ -f .env ]; then \ @@ -636,7 +660,7 @@ test-integrations: ## Run Python integration tests (Usage: make test-integration echo "$(GREEN)āœ“ Bifrost is already running$(NC)"; \ else \ echo "$(YELLOW)Bifrost not running, starting it...$(NC)"; \ - ./tmp/bifrost-http -host "$$TEST_HOST" -port "$$TEST_PORT" -log-style "$(LOG_STYLE)" -log-level "$(LOG_LEVEL)" -app-dir tests/integrations > /tmp/bifrost-test.log 2>&1 & \ + ./tmp/bifrost-http -host "$$TEST_HOST" -port "$$TEST_PORT" -log-style "$(LOG_STYLE)" -log-level "$(LOG_LEVEL)" -app-dir tests/integrations/python > /tmp/bifrost-test.log 2>&1 & \ BIFROST_PID=$$!; \ BIFROST_STARTED=1; \ echo "$(YELLOW)Waiting for Bifrost to be ready...$(NC)"; \ @@ -674,30 +698,26 @@ test-integrations: ## Run Python integration tests (Usage: make test-integration if [ -n "$(INTEGRATION)" ]; then \ if [ -n "$(TESTCASE)" ]; then \ echo "$(CYAN)Running $(INTEGRATION) integration test: $(TESTCASE)...$(NC)"; \ - cd tests/integrations && pytest tests/test_$(INTEGRATION).py::$(TESTCASE) $(if $(VERBOSE),-v,-q) || TEST_FAILED=1; \ + cd tests/integrations/python && pytest tests/test_$(INTEGRATION).py::$(TESTCASE) $(if $(VERBOSE),-v,-q) || TEST_FAILED=1; \ elif [ -n "$(PATTERN)" ]; then \ echo "$(CYAN)Running $(INTEGRATION) integration tests matching '$(PATTERN)'...$(NC)"; \ - cd tests/integrations && pytest tests/test_$(INTEGRATION).py -k "$(PATTERN)" $(if $(VERBOSE),-v,-q) || TEST_FAILED=1; \ + cd tests/integrations/python && pytest tests/test_$(INTEGRATION).py -k "$(PATTERN)" $(if $(VERBOSE),-v,-q) || TEST_FAILED=1; \ else \ echo "$(CYAN)Running $(INTEGRATION) integration tests...$(NC)"; \ - cd tests/integrations && pytest tests/test_$(INTEGRATION).py $(if $(VERBOSE),-v,-q) || TEST_FAILED=1; \ + cd tests/integrations/python && pytest tests/test_$(INTEGRATION).py $(if $(VERBOSE),-v,-q) || TEST_FAILED=1; \ fi; \ else \ if [ -n "$(PATTERN)" ]; then \ echo "$(CYAN)Running all integration tests matching '$(PATTERN)'...$(NC)"; \ - cd tests/integrations && pytest -k "$(PATTERN)" $(if $(VERBOSE),-v,-q) || TEST_FAILED=1; \ + cd tests/integrations/python && pytest -k "$(PATTERN)" $(if $(VERBOSE),-v,-q) || TEST_FAILED=1; \ else \ echo "$(CYAN)Running all integration tests...$(NC)"; \ - cd tests/integrations && pytest $(if $(VERBOSE),-v,-q) || TEST_FAILED=1; \ + cd tests/integrations/python && pytest $(if $(VERBOSE),-v,-q) || TEST_FAILED=1; \ fi; \ fi; \ else \ echo "$(CYAN)Using uv (fast mode)$(NC)"; \ - cd tests/integrations && \ - if [ ! -f .venv/bin/python ]; then \ - echo "$(YELLOW)Installing dependencies with uv...$(NC)"; \ - uv venv && uv pip install -r requirements.txt; \ - fi; \ + cd tests/integrations/python && \ if [ -n "$(INTEGRATION)" ]; then \ if [ -n "$(TESTCASE)" ]; then \ echo "$(CYAN)Running $(INTEGRATION) integration test: $(TESTCASE)...$(NC)"; \ @@ -740,6 +760,114 @@ test-integrations: ## Run Python integration tests (Usage: make test-integration echo "$(GREEN)āœ“ Integration tests complete$(NC)"; \ fi +test-integrations-ts: ## Run TypeScript integration tests (Usage: make test-integrations-ts [INTEGRATION=openai] [TESTCASE=test_name] [PATTERN=substring] [VERBOSE=1]) + @echo "$(GREEN)Running TypeScript integration tests...$(NC)" + @if [ ! -d "tests/integrations/typescript" ]; then \ + echo "$(RED)Error: tests/integrations/typescript directory not found$(NC)"; \ + exit 1; \ + fi; \ + if [ -n "$(PATTERN)" ] && [ -n "$(TESTCASE)" ]; then \ + echo "$(RED)Error: PATTERN and TESTCASE are mutually exclusive$(NC)"; \ + echo "$(YELLOW)Use PATTERN for substring matching or TESTCASE for exact match$(NC)"; \ + exit 1; \ + fi; \ + if [ -n "$(TESTCASE)" ] && [ -z "$(INTEGRATION)" ]; then \ + echo "$(RED)Error: TESTCASE requires INTEGRATION to be specified$(NC)"; \ + echo "$(YELLOW)Usage: make test-integrations-ts INTEGRATION=openai TESTCASE=test_simple_chat$(NC)"; \ + exit 1; \ + fi; \ + if [ -f .env ]; then \ + echo "$(YELLOW)Loading environment variables from .env...$(NC)"; \ + set -a; . ./.env; set +a; \ + fi; \ + BIFROST_STARTED=0; \ + BIFROST_PID=""; \ + TAIL_PID=""; \ + TEST_PORT=$${PORT:-8080}; \ + TEST_HOST=$${HOST:-localhost}; \ + echo "$(CYAN)Checking if Bifrost is running on $$TEST_HOST:$$TEST_PORT...$(NC)"; \ + if curl -s -o /dev/null -w "%{http_code}" http://$$TEST_HOST:$$TEST_PORT/health 2>/dev/null | grep -q "200\|404"; then \ + echo "$(GREEN)āœ“ Bifrost is already running$(NC)"; \ + else \ + echo "$(YELLOW)Bifrost not running, starting it...$(NC)"; \ + ./tmp/bifrost-http -host "$$TEST_HOST" -port "$$TEST_PORT" -log-style "$(LOG_STYLE)" -log-level "$(LOG_LEVEL)" -app-dir tests/integrations/typescript > /tmp/bifrost-test.log 2>&1 & \ + BIFROST_PID=$$!; \ + BIFROST_STARTED=1; \ + echo "$(YELLOW)Waiting for Bifrost to be ready...$(NC)"; \ + echo "$(CYAN)Bifrost logs: /tmp/bifrost-test.log$(NC)"; \ + (tail -f /tmp/bifrost-test.log 2>/dev/null | grep -E "error|panic|Error|ERRO|fatal|Fatal|FATAL" --line-buffered &) & \ + TAIL_PID=$$!; \ + for i in 1 2 3 4 5 6 7 8 9 10; do \ + if curl -s -o /dev/null http://$$TEST_HOST:$$TEST_PORT/health 2>/dev/null; then \ + echo "$(GREEN)āœ“ Bifrost is ready (PID: $$BIFROST_PID)$(NC)"; \ + break; \ + fi; \ + if [ $$i -eq 10 ]; then \ + echo "$(RED)Failed to start Bifrost$(NC)"; \ + echo "$(YELLOW)Bifrost logs:$(NC)"; \ + cat /tmp/bifrost-test.log 2>/dev/null || echo "No log file found"; \ + [ -n "$$BIFROST_PID" ] && kill $$BIFROST_PID 2>/dev/null; \ + [ -n "$$TAIL_PID" ] && kill $$TAIL_PID 2>/dev/null; \ + exit 1; \ + fi; \ + sleep 1; \ + done; \ + fi; \ + TEST_FAILED=0; \ + if ! which npm > /dev/null 2>&1; then \ + echo "$(RED)Error: npm not found$(NC)"; \ + echo "$(YELLOW)Install Node.js: https://nodejs.org/$(NC)"; \ + [ $$BIFROST_STARTED -eq 1 ] && [ -n "$$BIFROST_PID" ] && kill $$BIFROST_PID 2>/dev/null; \ + [ -n "$$TAIL_PID" ] && kill $$TAIL_PID 2>/dev/null; \ + exit 1; \ + fi; \ + echo "$(CYAN)Using npm$(NC)"; \ + cd tests/integrations/typescript && \ + if [ ! -d "node_modules" ]; then \ + echo "$(YELLOW)Installing dependencies...$(NC)"; \ + npm install; \ + fi; \ + if [ -n "$(INTEGRATION)" ]; then \ + if [ -n "$(TESTCASE)" ]; then \ + echo "$(CYAN)Running $(INTEGRATION) integration test: $(TESTCASE)...$(NC)"; \ + npm test -- tests/test-$(INTEGRATION).test.ts -t "$(TESTCASE)" $(if $(VERBOSE),--reporter=verbose,) || TEST_FAILED=1; \ + elif [ -n "$(PATTERN)" ]; then \ + echo "$(CYAN)Running $(INTEGRATION) integration tests matching '$(PATTERN)'...$(NC)"; \ + npm test -- tests/test-$(INTEGRATION).test.ts -t "$(PATTERN)" $(if $(VERBOSE),--reporter=verbose,) || TEST_FAILED=1; \ + else \ + echo "$(CYAN)Running $(INTEGRATION) integration tests...$(NC)"; \ + npm test -- tests/test-$(INTEGRATION).test.ts $(if $(VERBOSE),--reporter=verbose,) || TEST_FAILED=1; \ + fi; \ + else \ + if [ -n "$(PATTERN)" ]; then \ + echo "$(CYAN)Running all integration tests matching '$(PATTERN)'...$(NC)"; \ + npm test -- -t "$(PATTERN)" $(if $(VERBOSE),--reporter=verbose,) || TEST_FAILED=1; \ + else \ + echo "$(CYAN)Running all integration tests...$(NC)"; \ + npm test $(if $(VERBOSE),-- --reporter=verbose,) || TEST_FAILED=1; \ + fi; \ + fi; \ + if [ $$BIFROST_STARTED -eq 1 ] && [ -n "$$BIFROST_PID" ]; then \ + echo "$(YELLOW)Stopping Bifrost (PID: $$BIFROST_PID)...$(NC)"; \ + kill $$BIFROST_PID 2>/dev/null || true; \ + [ -n "$$TAIL_PID" ] && kill $$TAIL_PID 2>/dev/null || true; \ + wait $$BIFROST_PID 2>/dev/null || true; \ + echo "$(GREEN)āœ“ Bifrost stopped$(NC)"; \ + if [ $$TEST_FAILED -eq 1 ]; then \ + echo ""; \ + echo "$(YELLOW)Last 50 lines of Bifrost logs:$(NC)"; \ + tail -50 /tmp/bifrost-test.log 2>/dev/null || echo "No log file found"; \ + fi; \ + fi; \ + echo ""; \ + if [ $$TEST_FAILED -eq 1 ]; then \ + echo "$(RED)āœ— TypeScript integration tests failed$(NC)"; \ + echo "$(CYAN)Full Bifrost logs: /tmp/bifrost-test.log$(NC)"; \ + exit 1; \ + else \ + echo "$(GREEN)āœ“ TypeScript integration tests complete$(NC)"; \ + fi + # Quick start with example config quick-start: ## Quick start with example config and maxim plugin @echo "$(GREEN)Quick starting Bifrost with example configuration...$(NC)" diff --git a/core/bifrost.go b/core/bifrost.go index f2b0bcda95..6e5e26dbd7 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -15,6 +15,7 @@ import ( "time" "github.com/google/uuid" + "github.com/maximhq/bifrost/core/mcp" "github.com/maximhq/bifrost/core/providers/anthropic" "github.com/maximhq/bifrost/core/providers/azure" "github.com/maximhq/bifrost/core/providers/bedrock" @@ -43,7 +44,7 @@ import ( // It contains the request, response and error channels, and the request type. type ChannelMessage struct { schemas.BifrostRequest - Context context.Context + Context *schemas.BifrostContext Response chan *schemas.BifrostResponse ResponseStream chan chan *schemas.BifrostStream Err chan schemas.BifrostError @@ -52,7 +53,7 @@ type ChannelMessage struct { // Bifrost manages providers and maintains specified open channels for concurrent processing. // It handles request routing, provider management, and response processing. type Bifrost struct { - ctx context.Context + ctx *schemas.BifrostContext cancel context.CancelFunc account schemas.Account // account interface plugins atomic.Pointer[[]schemas.Plugin] // list of plugins @@ -67,7 +68,9 @@ type Bifrost struct { pluginPipelinePool sync.Pool // Pool for PluginPipeline objects bifrostRequestPool sync.Pool // Pool for BifrostRequest objects logger schemas.Logger // logger instance, default logger is used if not provided - mcpManager *MCPManager // MCP integration manager (nil if MCP not configured) + tracer atomic.Value // tracer for distributed tracing (stores schemas.Tracer, NoOpTracer if not configured) + mcpManager *mcp.MCPManager // MCP integration manager (nil if MCP not configured) + mcpInitOnce sync.Once // Ensures MCP manager is initialized only once dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. keySelector schemas.KeySelector // Custom key selector function } @@ -76,12 +79,32 @@ type Bifrost struct { type PluginPipeline struct { plugins []schemas.Plugin logger schemas.Logger + tracer schemas.Tracer // Number of PreHooks that were executed (used to determine which PostHooks to run in reverse order) executedPreHooks int // Errors from PreHooks and PostHooks preHookErrors []error postHookErrors []error + + // Streaming post-hook timing accumulation (for aggregated spans) + postHookTimings map[string]*pluginTimingAccumulator // keyed by plugin name + postHookPluginOrder []string // order in which post-hooks ran (for nested span creation) + chunkCount int +} + +// pluginTimingAccumulator accumulates timing information for a plugin across streaming chunks +type pluginTimingAccumulator struct { + totalDuration time.Duration + invocations int + errors int +} + +// tracerWrapper wraps a Tracer to ensure atomic.Value stores consistent types. +// This is necessary because atomic.Value.Store() panics if called with values +// of different concrete types, even if they implement the same interface. +type tracerWrapper struct { + tracer schemas.Tracer } // Global logger instance which is set in the Init function @@ -103,7 +126,14 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { } providerUtils.SetLogger(config.Logger) - bifrostCtx, cancel := context.WithCancel(ctx) + + // Initialize tracer (use NoOpTracer if not provided) + tracer := config.Tracer + if tracer == nil { + tracer = schemas.DefaultTracer() + } + + bifrostCtx, cancel := schemas.NewBifrostContextWithCancel(ctx) bifrost := &Bifrost{ ctx: bifrostCtx, cancel: cancel, @@ -114,6 +144,7 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { keySelector: config.KeySelector, logger: config.Logger, } + bifrost.tracer.Store(&tracerWrapper{tracer: tracer}) bifrost.plugins.Store(&config.Plugins) // Initialize providers slice @@ -180,13 +211,10 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { // Initialize MCP manager if configured if config.MCPConfig != nil { - mcpManager, err := newMCPManager(bifrostCtx, *config.MCPConfig, bifrost.logger) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("failed to initialize MCP manager: %v", err)) - } else { - bifrost.mcpManager = mcpManager + bifrost.mcpInitOnce.Do(func() { + bifrost.mcpManager = mcp.NewMCPManager(bifrostCtx, *config.MCPConfig, bifrost.logger) bifrost.logger.Info("MCP integration initialized successfully") - } + }) } // Create buffered channels for each provider and start workers @@ -223,6 +251,20 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { return bifrost, nil } +// SetTracer sets the tracer for the Bifrost instance. +func (bifrost *Bifrost) SetTracer(tracer schemas.Tracer) { + if tracer == nil { + // Fall back to no-op tracer if not provided + tracer = schemas.DefaultTracer() + } + bifrost.tracer.Store(&tracerWrapper{tracer: tracer}) +} + +// getTracer returns the tracer from atomic storage with type assertion. +func (bifrost *Bifrost) getTracer() schemas.Tracer { + return bifrost.tracer.Load().(*tracerWrapper).tracer +} + // ReloadConfig reloads the config from DB // Currently we only update account and drop excess requests // We will keep on adding other aspects as required @@ -234,7 +276,7 @@ func (bifrost *Bifrost) ReloadConfig(config schemas.BifrostConfig) error { // PUBLIC API METHODS // ListModelsRequest sends a list models request to the specified provider. -func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) ListModelsRequest(ctx *schemas.BifrostContext, req *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -307,7 +349,7 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr var keys []schemas.Key if providerRequiresKey(baseProvider, config.CustomProviderConfig) { - keys, err = bifrost.getAllSupportedKeys(&ctx, req.Provider, baseProvider) + keys, err = bifrost.getAllSupportedKeys(ctx, req.Provider, baseProvider) if err != nil { bifrostErr := newBifrostError(err) bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ @@ -318,9 +360,15 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr } } - response, bifrostErr := executeRequestWithRetries(&ctx, config, func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + // Store tracer in context BEFORE calling requestHandler, so streaming goroutines + // have access to it for completing deferred spans when the stream ends. + // The streaming goroutine captures the context when it starts, so these values + // must be set before requestHandler() is called. + ctx.SetValue(schemas.BifrostContextKeyTracer, bifrost.getTracer()) + + response, bifrostErr := executeRequestWithRetries(ctx, config, func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { return provider.ListModels(ctx, keys, request) - }, schemas.ListModelsRequest, req.Provider, "") + }, schemas.ListModelsRequest, req.Provider, "", nil) if bifrostErr != nil { bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ RequestType: schemas.ListModelsRequest, @@ -333,7 +381,7 @@ func (bifrost *Bifrost) ListModelsRequest(ctx context.Context, req *schemas.Bifr // ListAllModels lists all models from all configured providers. // It accumulates responses from all providers with a limit of 1000 per provider to get all results. -func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) ListAllModels(ctx *schemas.BifrostContext, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if request == nil { request = &schemas.BifrostListModelsRequest{} } @@ -470,7 +518,7 @@ func (bifrost *Bifrost) ListAllModels(ctx context.Context, request *schemas.Bifr } // TextCompletionRequest sends a text completion request to the specified provider. -func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) TextCompletionRequest(ctx *schemas.BifrostContext, req *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -509,7 +557,7 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas. } // TextCompletionStreamRequest sends a streaming text completion request to the specified provider. -func (bifrost *Bifrost) TextCompletionStreamRequest(ctx context.Context, req *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (bifrost *Bifrost) TextCompletionStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -540,8 +588,7 @@ func (bifrost *Bifrost) TextCompletionStreamRequest(ctx context.Context, req *sc return bifrost.handleStreamRequest(ctx, bifrostReq) } -// ChatCompletionRequest sends a chat completion request to the specified provider. -func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) makeChatCompletionRequest(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -575,12 +622,37 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. if err != nil { return nil, err } - //TODO: Release the response + return response.ChatResponse, nil } +// ChatCompletionRequest sends a chat completion request to the specified provider. +func (bifrost *Bifrost) ChatCompletionRequest(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + // If ctx is nil, use the bifrost context (defensive check for mcp agent mode) + if ctx == nil { + ctx = bifrost.ctx + } + + response, err := bifrost.makeChatCompletionRequest(ctx, req) + if err != nil { + return nil, err + } + + // Check if we should enter agent mode + if bifrost.mcpManager != nil { + return bifrost.mcpManager.CheckAndExecuteAgentForChatRequest( + ctx, + req, + response, + bifrost.makeChatCompletionRequest, + ) + } + + return response, nil +} + // ChatCompletionStreamRequest sends a chat completion stream request to the specified provider. -func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx context.Context, req *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -613,8 +685,7 @@ func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx context.Context, req *sc return bifrost.handleStreamRequest(ctx, bifrostReq) } -// ResponsesRequest sends a responses request to the specified provider. -func (bifrost *Bifrost) ResponsesRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) makeResponsesRequest(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -648,12 +719,36 @@ func (bifrost *Bifrost) ResponsesRequest(ctx context.Context, req *schemas.Bifro if err != nil { return nil, err } - //TODO: Release the response return response.ResponsesResponse, nil } +// ResponsesRequest sends a responses request to the specified provider. +func (bifrost *Bifrost) ResponsesRequest(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + // If ctx is nil, use the bifrost context (defensive check for mcp agent mode) + if ctx == nil { + ctx = bifrost.ctx + } + + response, err := bifrost.makeResponsesRequest(ctx, req) + if err != nil { + return nil, err + } + + // Check if we should enter agent mode + if bifrost.mcpManager != nil { + return bifrost.mcpManager.CheckAndExecuteAgentForResponsesRequest( + ctx, + req, + response, + bifrost.makeResponsesRequest, + ) + } + + return response, nil +} + // ResponsesStreamRequest sends a responses stream request to the specified provider. -func (bifrost *Bifrost) ResponsesStreamRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (bifrost *Bifrost) ResponsesStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -687,7 +782,7 @@ func (bifrost *Bifrost) ResponsesStreamRequest(ctx context.Context, req *schemas } // CountTokensRequest sends a count tokens request to the specified provider. -func (bifrost *Bifrost) CountTokensRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) CountTokensRequest(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -726,7 +821,7 @@ func (bifrost *Bifrost) CountTokensRequest(ctx context.Context, req *schemas.Bif } // EmbeddingRequest sends an embedding request to the specified provider. -func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -765,7 +860,7 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.Bifro } // SpeechRequest sends a speech request to the specified provider. -func (bifrost *Bifrost) SpeechRequest(ctx context.Context, req *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) SpeechRequest(ctx *schemas.BifrostContext, req *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -804,7 +899,7 @@ func (bifrost *Bifrost) SpeechRequest(ctx context.Context, req *schemas.BifrostS } // SpeechStreamRequest sends a speech stream request to the specified provider. -func (bifrost *Bifrost) SpeechStreamRequest(ctx context.Context, req *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (bifrost *Bifrost) SpeechStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -838,7 +933,7 @@ func (bifrost *Bifrost) SpeechStreamRequest(ctx context.Context, req *schemas.Bi } // TranscriptionRequest sends a transcription request to the specified provider. -func (bifrost *Bifrost) TranscriptionRequest(ctx context.Context, req *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) TranscriptionRequest(ctx *schemas.BifrostContext, req *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -877,7 +972,7 @@ func (bifrost *Bifrost) TranscriptionRequest(ctx context.Context, req *schemas.B } // TranscriptionStreamRequest sends a transcription stream request to the specified provider. -func (bifrost *Bifrost) TranscriptionStreamRequest(ctx context.Context, req *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (bifrost *Bifrost) TranscriptionStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -911,7 +1006,7 @@ func (bifrost *Bifrost) TranscriptionStreamRequest(ctx context.Context, req *sch } // BatchCreateRequest creates a new batch job for asynchronous processing. -func (bifrost *Bifrost) BatchCreateRequest(ctx context.Context, req *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) BatchCreateRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -962,7 +1057,7 @@ func (bifrost *Bifrost) BatchCreateRequest(ctx context.Context, req *schemas.Bif } // BatchListRequest lists batch jobs for the specified provider. -func (bifrost *Bifrost) BatchListRequest(ctx context.Context, req *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) BatchListRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -995,7 +1090,7 @@ func (bifrost *Bifrost) BatchListRequest(ctx context.Context, req *schemas.Bifro } // BatchRetrieveRequest retrieves a specific batch job. -func (bifrost *Bifrost) BatchRetrieveRequest(ctx context.Context, req *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) BatchRetrieveRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1036,7 +1131,7 @@ func (bifrost *Bifrost) BatchRetrieveRequest(ctx context.Context, req *schemas.B } // BatchCancelRequest cancels a batch job. -func (bifrost *Bifrost) BatchCancelRequest(ctx context.Context, req *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) BatchCancelRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1077,7 +1172,7 @@ func (bifrost *Bifrost) BatchCancelRequest(ctx context.Context, req *schemas.Bif } // BatchResultsRequest retrieves results from a completed batch job. -func (bifrost *Bifrost) BatchResultsRequest(ctx context.Context, req *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) BatchResultsRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1128,7 +1223,7 @@ func (bifrost *Bifrost) BatchResultsRequest(ctx context.Context, req *schemas.Bi } // FileUploadRequest uploads a file to the specified provider. -func (bifrost *Bifrost) FileUploadRequest(ctx context.Context, req *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) FileUploadRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1179,7 +1274,7 @@ func (bifrost *Bifrost) FileUploadRequest(ctx context.Context, req *schemas.Bifr } // FileListRequest lists files from the specified provider. -func (bifrost *Bifrost) FileListRequest(ctx context.Context, req *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) FileListRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1218,7 +1313,7 @@ func (bifrost *Bifrost) FileListRequest(ctx context.Context, req *schemas.Bifros } // FileRetrieveRequest retrieves file metadata from the specified provider. -func (bifrost *Bifrost) FileRetrieveRequest(ctx context.Context, req *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) FileRetrieveRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1259,7 +1354,7 @@ func (bifrost *Bifrost) FileRetrieveRequest(ctx context.Context, req *schemas.Bi } // FileDeleteRequest deletes a file from the specified provider. -func (bifrost *Bifrost) FileDeleteRequest(ctx context.Context, req *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) FileDeleteRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1300,7 +1395,7 @@ func (bifrost *Bifrost) FileDeleteRequest(ctx context.Context, req *schemas.Bifr } // FileContentRequest downloads file content from the specified provider. -func (bifrost *Bifrost) FileContentRequest(ctx context.Context, req *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) FileContentRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1684,10 +1779,10 @@ func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(a return fmt.Errorf("MCP is not configured in this Bifrost instance") } - return bifrost.mcpManager.registerTool(name, description, handler, toolSchema) + return bifrost.mcpManager.RegisterTool(name, description, handler, toolSchema) } -// ExecuteMCPTool executes an MCP tool call and returns the result as a tool message. +// ExecuteChatMCPTool executes an MCP tool call and returns the result as a chat message. // This is the main public API for manual MCP tool execution. // // Parameters: @@ -1697,7 +1792,7 @@ func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(a // Returns: // - schemas.ChatMessage: Tool message with execution result // - schemas.BifrostError: Any execution error -func (bifrost *Bifrost) ExecuteMCPTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError) { +func (bifrost *Bifrost) ExecuteChatMCPTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError) { if bifrost.mcpManager == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1710,13 +1805,12 @@ func (bifrost *Bifrost) ExecuteMCPTool(ctx context.Context, toolCall schemas.Cha } } - result, err := bifrost.mcpManager.executeTool(ctx, toolCall) + result, err := bifrost.mcpManager.ExecuteChatTool(ctx, toolCall) if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: err.Error(), - Error: err, }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ChatCompletionRequest, // MCP tools are used with chat completions @@ -1727,6 +1821,38 @@ func (bifrost *Bifrost) ExecuteMCPTool(ctx context.Context, toolCall schemas.Cha return result, nil } +// ExecuteResponsesMCPTool executes an MCP tool call and returns the result as a responses message. + +// ExecuteResponsesMCPTool executes an MCP tool call and returns the result as a responses message. +func (bifrost *Bifrost) ExecuteResponsesMCPTool(ctx *schemas.BifrostContext, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, *schemas.BifrostError) { + if bifrost.mcpManager == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "MCP is not configured in this Bifrost instance", + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.ResponsesRequest, // MCP tools are used with responses requests + }, + } + } + + result, err := bifrost.mcpManager.ExecuteResponsesTool(ctx, toolCall) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: err.Error(), + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.ResponsesRequest, // MCP tools are used with responses requests + }, + } + } + + return result, nil +} + // IMPORTANT: Running the MCP client management operations (GetMCPClients, AddMCPClient, RemoveMCPClient, EditMCPClientTools) // may temporarily increase latency for incoming requests while the operations are being processed. // These operations involve network I/O and connection management that require mutex locks @@ -1742,12 +1868,9 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { return nil, fmt.Errorf("MCP is not configured in this Bifrost instance") } - clients, err := bifrost.mcpManager.GetClients() - if err != nil { - return nil, err - } - + clients := bifrost.mcpManager.GetClients() clientsInConfig := make([]schemas.MCPClient, 0, len(clients)) + for _, client := range clients { tools := make([]schemas.ChatToolFunction, 0, len(client.ToolMap)) for _, tool := range client.ToolMap { @@ -1760,21 +1883,27 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { return tools[i].Name < tools[j].Name }) - state := schemas.MCPConnectionStateConnected - if client.Conn == nil { - state = schemas.MCPConnectionStateDisconnected - } - clientsInConfig = append(clientsInConfig, schemas.MCPClient{ Config: client.ExecutionConfig, Tools: tools, - State: state, + State: client.State, }) } return clientsInConfig, nil } +// GetAvailableTools returns the available tools for the given context. +// +// Returns: +// - []schemas.ChatTool: List of available tools +func (bifrost *Bifrost) GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool { + if bifrost.mcpManager == nil { + return nil + } + return bifrost.mcpManager.GetAvailableTools(ctx) +} + // AddMCPClient adds a new MCP client to the Bifrost instance. // This allows for dynamic MCP client management at runtime. // @@ -1793,13 +1922,17 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { // }) func (bifrost *Bifrost) AddMCPClient(config schemas.MCPClientConfig) error { if bifrost.mcpManager == nil { - manager := &MCPManager{ - ctx: bifrost.ctx, - clientMap: make(map[string]*MCPClient), - logger: bifrost.logger, - } + // Use sync.Once to ensure thread-safe initialization + bifrost.mcpInitOnce.Do(func() { + bifrost.mcpManager = mcp.NewMCPManager(bifrost.ctx, schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{config}, + }, bifrost.logger) + }) + } - bifrost.mcpManager = manager + // Handle case where initialization succeeded elsewhere but manager is still nil + if bifrost.mcpManager == nil { + return fmt.Errorf("MCP manager is not initialized") } return bifrost.mcpManager.AddClient(config) @@ -1867,6 +2000,21 @@ func (bifrost *Bifrost) ReconnectMCPClient(id string) error { return bifrost.mcpManager.ReconnectClient(id) } +// UpdateToolManagerConfig updates the tool manager config for the MCP manager. +// This allows for hot-reloading of the tool manager config at runtime. +func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error { + if bifrost.mcpManager == nil { + return fmt.Errorf("MCP is not configured in this Bifrost instance") + } + + bifrost.mcpManager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{ + MaxAgentDepth: maxAgentDepth, + ToolExecutionTimeout: time.Duration(toolExecutionTimeoutInSeconds) * time.Second, + CodeModeBindingLevel: schemas.CodeModeBindingLevel(codeModeBindingLevel), + }) + return nil +} + // PROVIDER MANAGEMENT // createBaseProvider creates a provider based on the base provider type @@ -2196,7 +2344,7 @@ func (bifrost *Bifrost) shouldContinueWithFallbacks(fallback schemas.Fallback, f // It handles plugin hooks, request validation, response processing, and fallback providers. // If the primary provider fails, it will try each fallback provider in order until one succeeds. // It is the wrapper for all non-streaming public API methods. -func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { defer bifrost.releaseBifrostRequest(req) provider, model, fallbacks := req.GetRequestFields() if err := validateRequest(req); err != nil { @@ -2216,11 +2364,11 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR bifrost.logger.Debug(fmt.Sprintf("primary provider %s with model %s and %d fallbacks", provider, model, len(fallbacks))) // Try the primary provider first - ctx = context.WithValue(ctx, schemas.BifrostContextKeyFallbackIndex, 0) + ctx.SetValue(schemas.BifrostContextKeyFallbackIndex, 0) // Ensure request ID is set in context before PreHooks if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok { requestID := uuid.New().String() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestID, requestID) + ctx.SetValue(schemas.BifrostContextKeyRequestID, requestID) } primaryResult, primaryErr := bifrost.tryRequest(ctx, req) if primaryErr != nil { @@ -2249,13 +2397,23 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR // Try fallbacks in order for i, fallback := range fallbacks { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyFallbackIndex, i+1) + ctx.SetValue(schemas.BifrostContextKeyFallbackIndex, i+1) bifrost.logger.Debug(fmt.Sprintf("trying fallback provider %s with model %s", fallback.Provider, fallback.Model)) - ctx = context.WithValue(ctx, schemas.BifrostContextKeyFallbackRequestID, uuid.New().String()) + ctx.SetValue(schemas.BifrostContextKeyFallbackRequestID, uuid.New().String()) + + // Start span for fallback attempt + tracer := bifrost.getTracer() + spanCtx, handle := tracer.StartSpan(ctx, fmt.Sprintf("fallback.%s.%s", fallback.Provider, fallback.Model), schemas.SpanKindFallback) + tracer.SetAttribute(handle, schemas.AttrProviderName, string(fallback.Provider)) + tracer.SetAttribute(handle, schemas.AttrRequestModel, fallback.Model) + tracer.SetAttribute(handle, "fallback.index", i+1) + ctx.SetValue(schemas.BifrostContextKeySpanID, spanCtx.Value(schemas.BifrostContextKeySpanID)) fallbackReq := bifrost.prepareFallbackRequest(req, fallback) if fallbackReq == nil { bifrost.logger.Debug(fmt.Sprintf("fallback provider %s with model %s is nil", fallback.Provider, fallback.Model)) + tracer.SetAttribute(handle, "error", "fallback request preparation failed") + tracer.EndSpan(handle, schemas.SpanStatusError, "fallback request preparation failed") continue } @@ -2263,9 +2421,16 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR result, fallbackErr := bifrost.tryRequest(ctx, fallbackReq) if fallbackErr == nil { bifrost.logger.Debug(fmt.Sprintf("successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) + tracer.EndSpan(handle, schemas.SpanStatusOk, "") return result, nil } + // End span with error status + if fallbackErr.Error != nil { + tracer.SetAttribute(handle, "error", fallbackErr.Error.Message) + } + tracer.EndSpan(handle, schemas.SpanStatusError, "fallback failed") + // Check if we should continue with more fallbacks if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { fallbackErr.ExtraFields = schemas.BifrostErrorExtraFields{ @@ -2293,7 +2458,7 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR // It handles plugin hooks, request validation, response processing, and fallback providers. // If the primary provider fails, it will try each fallback provider in order until one succeeds. // It is the wrapper for all streaming public API methods. -func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (bifrost *Bifrost) handleStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { defer bifrost.releaseBifrostRequest(req) provider, model, fallbacks := req.GetRequestFields() @@ -2314,11 +2479,11 @@ func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.Bi } // Try the primary provider first - ctx = context.WithValue(ctx, schemas.BifrostContextKeyFallbackIndex, 0) + ctx.SetValue(schemas.BifrostContextKeyFallbackIndex, 0) // Ensure request ID is set in context before PreHooks if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok { requestID := uuid.New().String() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestID, requestID) + ctx.SetValue(schemas.BifrostContextKeyRequestID, requestID) } primaryResult, primaryErr := bifrost.tryStreamRequest(ctx, req) @@ -2337,11 +2502,21 @@ func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.Bi // Try fallbacks in order for i, fallback := range fallbacks { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyFallbackIndex, i+1) - ctx = context.WithValue(ctx, schemas.BifrostContextKeyFallbackRequestID, uuid.New().String()) + ctx.SetValue(schemas.BifrostContextKeyFallbackIndex, i+1) + ctx.SetValue(schemas.BifrostContextKeyFallbackRequestID, uuid.New().String()) + + // Start span for fallback attempt + tracer := bifrost.getTracer() + spanCtx, handle := tracer.StartSpan(ctx, fmt.Sprintf("fallback.%s.%s", fallback.Provider, fallback.Model), schemas.SpanKindFallback) + tracer.SetAttribute(handle, schemas.AttrProviderName, string(fallback.Provider)) + tracer.SetAttribute(handle, schemas.AttrRequestModel, fallback.Model) + tracer.SetAttribute(handle, "fallback.index", i+1) + ctx.SetValue(schemas.BifrostContextKeySpanID, spanCtx.Value(schemas.BifrostContextKeySpanID)) fallbackReq := bifrost.prepareFallbackRequest(req, fallback) if fallbackReq == nil { + tracer.SetAttribute(handle, "error", "fallback request preparation failed") + tracer.EndSpan(handle, schemas.SpanStatusError, "fallback request preparation failed") continue } @@ -2349,9 +2524,16 @@ func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.Bi result, fallbackErr := bifrost.tryStreamRequest(ctx, fallbackReq) if fallbackErr == nil { bifrost.logger.Debug(fmt.Sprintf("successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) + tracer.EndSpan(handle, schemas.SpanStatusOk, "") return result, nil } + // End span with error status + if fallbackErr.Error != nil { + tracer.SetAttribute(handle, "error", fallbackErr.Error.Message) + } + tracer.EndSpan(handle, schemas.SpanStatusError, "fallback failed") + // Check if we should continue with more fallbacks if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { fallbackErr.ExtraFields = schemas.BifrostErrorExtraFields{ @@ -2377,7 +2559,7 @@ func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.Bi // tryRequest is a generic function that handles common request processing logic // It consolidates queue setup, plugin pipeline execution, enqueue logic, and response handling -func (bifrost *Bifrost) tryRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { provider, model, _ := req.GetRequestFields() queue, err := bifrost.getProviderQueue(provider) if err != nil { @@ -2391,21 +2573,29 @@ func (bifrost *Bifrost) tryRequest(ctx context.Context, req *schemas.BifrostRequ } // Add MCP tools to request if MCP is configured and requested - if req.RequestType != schemas.EmbeddingRequest && - req.RequestType != schemas.SpeechRequest && - req.RequestType != schemas.TranscriptionRequest && - bifrost.mcpManager != nil { - req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) + if bifrost.mcpManager != nil { + req = bifrost.mcpManager.AddToolsToRequest(ctx, req) } + tracer := bifrost.getTracer() + if tracer == nil { + return nil, newBifrostErrorFromMsg("tracer not found in context") + } + + // Store tracer in context BEFORE calling requestHandler, so streaming goroutines + // have access to it for completing deferred spans when the stream ends. + // The streaming goroutine captures the context when it starts, so these values + // must be set before requestHandler() is called. + ctx.SetValue(schemas.BifrostContextKeyTracer, tracer) + pipeline := bifrost.getPluginPipeline() defer bifrost.releasePluginPipeline(pipeline) - preReq, shortCircuit, preCount := pipeline.RunPreHooks(&ctx, req) + preReq, shortCircuit, preCount := pipeline.RunPreHooks(ctx, req) if shortCircuit != nil { // Handle short-circuit with response (success case) if shortCircuit.Response != nil { - resp, bifrostErr := pipeline.RunPostHooks(&ctx, shortCircuit.Response, nil, preCount) + resp, bifrostErr := pipeline.RunPostHooks(ctx, shortCircuit.Response, nil, preCount) if bifrostErr != nil { return nil, bifrostErr } @@ -2413,7 +2603,7 @@ func (bifrost *Bifrost) tryRequest(ctx context.Context, req *schemas.BifrostRequ } // Handle short-circuit with error if shortCircuit.Error != nil { - resp, bifrostErr := pipeline.RunPostHooks(&ctx, nil, shortCircuit.Error, preCount) + resp, bifrostErr := pipeline.RunPostHooks(ctx, nil, shortCircuit.Error, preCount) if bifrostErr != nil { return nil, bifrostErr } @@ -2476,7 +2666,7 @@ func (bifrost *Bifrost) tryRequest(ctx context.Context, req *schemas.BifrostRequ pluginCount := len(*bifrost.plugins.Load()) select { case result = <-msg.Response: - resp, bifrostErr := pipeline.RunPostHooks(&msg.Context, result, nil, pluginCount) + resp, bifrostErr := pipeline.RunPostHooks(msg.Context, result, nil, pluginCount) if bifrostErr != nil { bifrost.releaseChannelMessage(msg) return nil, bifrostErr @@ -2485,7 +2675,7 @@ func (bifrost *Bifrost) tryRequest(ctx context.Context, req *schemas.BifrostRequ return resp, nil case bifrostErrVal := <-msg.Err: bifrostErrPtr := &bifrostErrVal - resp, bifrostErrPtr = pipeline.RunPostHooks(&msg.Context, nil, bifrostErrPtr, pluginCount) + resp, bifrostErrPtr = pipeline.RunPostHooks(msg.Context, nil, bifrostErrPtr, pluginCount) bifrost.releaseChannelMessage(msg) if bifrostErrPtr != nil { return nil, bifrostErrPtr @@ -2496,7 +2686,7 @@ func (bifrost *Bifrost) tryRequest(ctx context.Context, req *schemas.BifrostRequ // tryStreamRequest is a generic function that handles common request processing logic // It consolidates queue setup, plugin pipeline execution, enqueue logic, and response handling -func (bifrost *Bifrost) tryStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { provider, model, _ := req.GetRequestFields() queue, err := bifrost.getProviderQueue(provider) if err != nil { @@ -2511,17 +2701,28 @@ func (bifrost *Bifrost) tryStreamRequest(ctx context.Context, req *schemas.Bifro // Add MCP tools to request if MCP is configured and requested if req.RequestType != schemas.SpeechStreamRequest && req.RequestType != schemas.TranscriptionStreamRequest && bifrost.mcpManager != nil { - req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) + req = bifrost.mcpManager.AddToolsToRequest(ctx, req) } + tracer := bifrost.getTracer() + if tracer == nil { + return nil, newBifrostErrorFromMsg("tracer not found in context") + } + + // Store tracer in context BEFORE calling RunPreHooks, so plugins and streaming goroutines + // have access to it for completing deferred spans when the stream ends. + // The streaming goroutine captures the context when it starts, so these values + // must be set before requestHandler() is called. + ctx.SetValue(schemas.BifrostContextKeyTracer, tracer) + pipeline := bifrost.getPluginPipeline() defer bifrost.releasePluginPipeline(pipeline) - preReq, shortCircuit, preCount := pipeline.RunPreHooks(&ctx, req) + preReq, shortCircuit, preCount := pipeline.RunPreHooks(ctx, req) if shortCircuit != nil { // Handle short-circuit with response (success case) if shortCircuit.Response != nil { - resp, bifrostErr := pipeline.RunPostHooks(&ctx, shortCircuit.Response, nil, preCount) + resp, bifrostErr := pipeline.RunPostHooks(ctx, shortCircuit.Response, nil, preCount) if bifrostErr != nil { return nil, bifrostErr } @@ -2532,7 +2733,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx context.Context, req *schemas.Bifro outputStream := make(chan *schemas.BifrostStream) // Create a post hook runner cause pipeline object is put back in the pool on defer - pipelinePostHookRunner := func(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + pipelinePostHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { return pipeline.RunPostHooks(ctx, result, err, preCount) } @@ -2562,7 +2763,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx context.Context, req *schemas.Bifro } // Run post hooks on the stream message - processedResponse, processedError := pipelinePostHookRunner(&ctx, bifrostResponse, streamMsg.BifrostError) + processedResponse, processedError := pipelinePostHookRunner(ctx, bifrostResponse, streamMsg.BifrostError) streamResponse := &schemas.BifrostStream{} if processedResponse != nil { @@ -2587,7 +2788,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx context.Context, req *schemas.Bifro } // Handle short-circuit with error if shortCircuit.Error != nil { - resp, bifrostErr := pipeline.RunPostHooks(&ctx, nil, shortCircuit.Error, preCount) + resp, bifrostErr := pipeline.RunPostHooks(ctx, nil, shortCircuit.Error, preCount) if bifrostErr != nil { return nil, bifrostErr } @@ -2657,9 +2858,9 @@ func (bifrost *Bifrost) tryStreamRequest(ctx context.Context, req *schemas.Bifro bifrost.logger.Debug("error while executing stream request: %+v", bifrostErrVal) } // Marking final chunk - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) // On error we will complete post-hooks - recoveredResp, recoveredErr := pipeline.RunPostHooks(&ctx, nil, &bifrostErrVal, len(*bifrost.plugins.Load())) + recoveredResp, recoveredErr := pipeline.RunPostHooks(ctx, nil, &bifrostErrVal, len(*bifrost.plugins.Load())) bifrost.releaseChannelMessage(msg) if recoveredErr != nil { return nil, recoveredErr @@ -2675,19 +2876,20 @@ func (bifrost *Bifrost) tryStreamRequest(ctx context.Context, req *schemas.Bifro // It consolidates retry logic, backoff calculation, and error handling // It is not a bifrost method because interface methods in go cannot be generic func executeRequestWithRetries[T any]( - ctx *context.Context, + ctx *schemas.BifrostContext, config *schemas.ProviderConfig, requestHandler func() (T, *schemas.BifrostError), requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, + req *schemas.BifrostRequest, ) (T, *schemas.BifrostError) { var result T var bifrostError *schemas.BifrostError var attempts int for attempts = 0; attempts <= config.NetworkConfig.MaxRetries; attempts++ { - *ctx = context.WithValue(*ctx, schemas.BifrostContextKeyNumberOfRetries, attempts) + ctx.SetValue(schemas.BifrostContextKeyNumberOfRetries, attempts) if attempts > 0 { // Log retry attempt var retryMsg string @@ -2703,15 +2905,81 @@ func executeRequestWithRetries[T any]( // Calculate and apply backoff backoff := calculateBackoff(attempts-1, config) - logger.Debug("sleeping for %s", backoff) + logger.Debug("sleeping for %s before retry", backoff) time.Sleep(backoff) } logger.Debug("attempting %s request for provider %s", requestType, providerKey) + // Start span for LLM call (or retry attempt) + tracer, ok := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer) + if !ok || tracer == nil { + logger.Error("tracer not found in context of executeRequestWithRetries") + return result, newBifrostErrorFromMsg("tracer not found in context") + } + var spanName string + var spanKind schemas.SpanKind + if attempts > 0 { + spanName = fmt.Sprintf("retry.attempt.%d", attempts) + spanKind = schemas.SpanKindRetry + } else { + spanName = "llm.call" + spanKind = schemas.SpanKindLLMCall + } + spanCtx, handle := tracer.StartSpan(ctx, spanName, spanKind) + tracer.SetAttribute(handle, schemas.AttrProviderName, string(providerKey)) + tracer.SetAttribute(handle, schemas.AttrRequestModel, model) + tracer.SetAttribute(handle, "request.type", string(requestType)) + if attempts > 0 { + tracer.SetAttribute(handle, "retry.count", attempts) + } + + // Populate LLM request attributes (messages, parameters, etc.) + if req != nil { + tracer.PopulateLLMRequestAttributes(handle, req) + } + + // Update context with span ID + ctx.SetValue(schemas.BifrostContextKeySpanID, spanCtx.Value(schemas.BifrostContextKeySpanID)) + + // Record stream start time for TTFT calculation (only for streaming requests) + // This is also used by RunPostHooks to detect streaming mode + if IsStreamRequestType(requestType) { + streamStartTime := time.Now() + ctx.SetValue(schemas.BifrostContextKeyStreamStartTime, streamStartTime) + } + // Attempt the request result, bifrostError = requestHandler() + // Check if result is a streaming channel - if so, defer span completion + if _, isStreamChan := any(result).(chan *schemas.BifrostStream); isStreamChan { + // For streaming requests, store the span handle in TraceStore keyed by trace ID + // This allows the provider's streaming goroutine to retrieve it later + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { + tracer.StoreDeferredSpan(traceID, handle) + } + // Don't end the span here - it will be ended when streaming completes + } else { + // Populate LLM response attributes for non-streaming responses + if resp, ok := any(result).(*schemas.BifrostResponse); ok { + tracer.PopulateLLMResponseAttributes(handle, resp, bifrostError) + } + + // End span with appropriate status + if bifrostError != nil { + if bifrostError.Error != nil { + tracer.SetAttribute(handle, "error", bifrostError.Error.Message) + } + if bifrostError.StatusCode != nil { + tracer.SetAttribute(handle, "status_code", *bifrostError.StatusCode) + } + tracer.EndSpan(handle, schemas.SpanStatusError, "request failed") + } else { + tracer.EndSpan(handle, schemas.SpanStatusOk, "") + } + } + logger.Debug("request %s for provider %s completed", requestType, providerKey) // Check if successful or if we should retry @@ -2788,7 +3056,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas if model != "" { modelPtr = &model } - keys, err = bifrost.getKeysForBatchAndFileOps(&req.Context, provider.GetProviderKey(), baseProvider, modelPtr, isMultiKeyBatchOp) + keys, err = bifrost.getKeysForBatchAndFileOps(req.Context, provider.GetProviderKey(), baseProvider, modelPtr, isMultiKeyBatchOp) if err != nil { bifrost.logger.Debug("error getting keys for batch/file operation: %v", err) req.Err <- schemas.BifrostError{ @@ -2807,8 +3075,16 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas } } else { // Use the custom provider name for actual key selection, but pass base provider type for key validation - key, err = bifrost.selectKeyFromProviderForModel(&req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider) + // Start span for key selection + keyTracer := bifrost.getTracer() + keySpanCtx, keyHandle := keyTracer.StartSpan(req.Context, "key.selection", schemas.SpanKindInternal) + keyTracer.SetAttribute(keyHandle, schemas.AttrProviderName, string(provider.GetProviderKey())) + keyTracer.SetAttribute(keyHandle, schemas.AttrRequestModel, model) + + key, err = bifrost.selectKeyFromProviderForModel(req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider) if err != nil { + keyTracer.SetAttribute(keyHandle, "error", err.Error()) + keyTracer.EndSpan(keyHandle, schemas.SpanStatusError, err.Error()) bifrost.logger.Debug("error selecting key for model %s: %v", model, err) req.Err <- schemas.BifrostError{ IsBifrostError: false, @@ -2824,8 +3100,13 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas } continue } - req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKeyID, key.ID) - req.Context = context.WithValue(req.Context, schemas.BifrostContextKeySelectedKeyName, key.Name) + keyTracer.SetAttribute(keyHandle, "key.id", key.ID) + keyTracer.SetAttribute(keyHandle, "key.name", key.Name) + keyTracer.EndSpan(keyHandle, schemas.SpanStatusOk, "") + // Update context with span ID for subsequent operations + req.Context.SetValue(schemas.BifrostContextKeySpanID, keySpanCtx.Value(schemas.BifrostContextKeySpanID)) + req.Context.SetValue(schemas.BifrostContextKeySelectedKeyID, key.ID) + req.Context.SetValue(schemas.BifrostContextKeySelectedKeyName, key.Name) } } // Create plugin pipeline for streaming requests outside retry loop to prevent leaks @@ -2833,27 +3114,38 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas var pipeline *PluginPipeline if IsStreamRequestType(req.RequestType) { pipeline = bifrost.getPluginPipeline() - postHookRunner = func(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + postHookRunner = func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { resp, bifrostErr := pipeline.RunPostHooks(ctx, result, err, len(*bifrost.plugins.Load())) if bifrostErr != nil { return nil, bifrostErr } return resp, nil } + // Store a finalizer callback to create aggregated post-hook spans at stream end + // This closure captures the pipeline reference and releases it after finalization + postHookSpanFinalizer := func(ctx context.Context) { + pipeline.FinalizeStreamingPostHookSpans(ctx) + // Release the pipeline AFTER finalizing spans (not before streaming completes) + bifrost.releasePluginPipeline(pipeline) + } + req.Context.SetValue(schemas.BifrostContextKeyPostHookSpanFinalizer, postHookSpanFinalizer) } // Execute request with retries if IsStreamRequestType(req.RequestType) { - stream, bifrostError = executeRequestWithRetries(&req.Context, config, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + stream, bifrostError = executeRequestWithRetries(req.Context, config, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { return bifrost.handleProviderStreamRequest(provider, req, key, postHookRunner) - }, req.RequestType, provider.GetProviderKey(), model) + }, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest) } else { - result, bifrostError = executeRequestWithRetries(&req.Context, config, func() (*schemas.BifrostResponse, *schemas.BifrostError) { + result, bifrostError = executeRequestWithRetries(req.Context, config, func() (*schemas.BifrostResponse, *schemas.BifrostError) { return bifrost.handleProviderRequest(provider, req, key, keys) - }, req.RequestType, provider.GetProviderKey(), model) + }, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest) } - if pipeline != nil { + // Release pipeline immediately for non-streaming requests only + // For streaming, the pipeline is released in the postHookSpanFinalizer after streaming completes + // Exception: if streaming request has an error, release immediately since finalizer won't be called + if pipeline != nil && (!IsStreamRequestType(req.RequestType) || bifrostError != nil) { bifrost.releasePluginPipeline(pipeline) } @@ -3063,21 +3355,38 @@ func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, r // PLUGIN MANAGEMENT // RunPreHooks executes PreHooks in order, tracks how many ran, and returns the final request, any short-circuit decision, and the count. -func (p *PluginPipeline) RunPreHooks(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, int) { +func (p *PluginPipeline) RunPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, int) { var shortCircuit *schemas.PluginShortCircuit var err error - pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(*ctx, 10*time.Second) - defer cancel() - defer func() { - *ctx = pluginCtx.GetParentCtxWithUserValues() - }() + ctx.BlockRestrictedWrites() + defer ctx.UnblockRestrictedWrites() for i, plugin := range p.plugins { - p.logger.Debug("running pre-hook for plugin %s", plugin.GetName()) - req, shortCircuit, err = plugin.PreHook(pluginCtx, req) + pluginName := plugin.GetName() + p.logger.Debug("running pre-hook for plugin %s", pluginName) + // Start span for this plugin's PreHook + spanCtx, handle := p.tracer.StartSpan(ctx, fmt.Sprintf("plugin.%s.prehook", sanitizeSpanName(pluginName)), schemas.SpanKindPlugin) + // Update pluginCtx with span context for nested operations + if spanCtx != nil { + if spanID, ok := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); ok { + ctx.SetValue(schemas.BifrostContextKeySpanID, spanID) + } + } + + req, shortCircuit, err = plugin.PreHook(ctx, req) + + // End span with appropriate status if err != nil { + p.tracer.SetAttribute(handle, "error", err.Error()) + p.tracer.EndSpan(handle, schemas.SpanStatusError, err.Error()) p.preHookErrors = append(p.preHookErrors, err) - p.logger.Warn("error in PreHook for plugin %s: %v", plugin.GetName(), err) + p.logger.Warn("error in PreHook for plugin %s: %s", pluginName, err.Error()) + } else if shortCircuit != nil { + p.tracer.SetAttribute(handle, "short_circuit", true) + p.tracer.EndSpan(handle, schemas.SpanStatusOk, "short-circuit") + } else { + p.tracer.EndSpan(handle, schemas.SpanStatusOk, "") } + p.executedPreHooks = i + 1 if shortCircuit != nil { return req, shortCircuit, p.executedPreHooks // short-circuit: only plugins up to and including i ran @@ -3090,7 +3399,8 @@ func (p *PluginPipeline) RunPreHooks(ctx *context.Context, req *schemas.BifrostR // Accepts the response and error, and allows plugins to transform either (e.g., recover from error, or invalidate a response). // Returns the final response and error after all hooks. If both are set, error takes precedence unless error is nil. // runFrom is the count of plugins whose PreHooks ran; PostHooks will run in reverse from index (runFrom - 1) down to 0 -func (p *PluginPipeline) RunPostHooks(ctx *context.Context, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostResponse, *schemas.BifrostError) { +// For streaming requests, it accumulates timing per plugin instead of creating individual spans per chunk. +func (p *PluginPipeline) RunPostHooks(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostResponse, *schemas.BifrostError) { // Defensive: ensure count is within valid bounds if runFrom < 0 { runFrom = 0 @@ -3098,22 +3408,55 @@ func (p *PluginPipeline) RunPostHooks(ctx *context.Context, resp *schemas.Bifros if runFrom > len(p.plugins) { runFrom = len(p.plugins) } + // Detect streaming mode - if StreamStartTime is set, we're in a streaming context + isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil + ctx.BlockRestrictedWrites() + defer ctx.UnblockRestrictedWrites() var err error - pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(*ctx, 10*time.Second) - defer cancel() for i := runFrom - 1; i >= 0; i-- { plugin := p.plugins[i] - p.logger.Debug("running post-hook for plugin %s", plugin.GetName()) - resp, bifrostErr, err = plugin.PostHook(pluginCtx, resp, bifrostErr) - if err != nil { - p.postHookErrors = append(p.postHookErrors, err) - p.logger.Warn("error in PostHook for plugin %s: %v", plugin.GetName(), err) + pluginName := plugin.GetName() + p.logger.Debug("running post-hook for plugin %s", pluginName) + if isStreaming { + // For streaming: accumulate timing, don't create individual spans per chunk + start := time.Now() + resp, bifrostErr, err = plugin.PostHook(ctx, resp, bifrostErr) + duration := time.Since(start) + + p.accumulatePluginTiming(pluginName, duration, err != nil) + if err != nil { + p.postHookErrors = append(p.postHookErrors, err) + p.logger.Warn("error in PostHook for plugin %s: %v", pluginName, err) + } + } else { + // For non-streaming: create span per plugin (existing behavior) + spanCtx, handle := p.tracer.StartSpan(ctx, fmt.Sprintf("plugin.%s.posthook", sanitizeSpanName(pluginName)), schemas.SpanKindPlugin) + // Update pluginCtx with span context for nested operations + if spanCtx != nil { + if spanID, ok := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); ok { + ctx.SetValue(schemas.BifrostContextKeySpanID, spanID) + } + } + + resp, bifrostErr, err = plugin.PostHook(ctx, resp, bifrostErr) + + // End span with appropriate status + if err != nil { + p.tracer.SetAttribute(handle, "error", err.Error()) + p.tracer.EndSpan(handle, schemas.SpanStatusError, err.Error()) + p.postHookErrors = append(p.postHookErrors, err) + p.logger.Warn("error in PostHook for plugin %s: %v", pluginName, err) + } else { + p.tracer.EndSpan(handle, schemas.SpanStatusOk, "") + } } // If a plugin recovers from an error (sets bifrostErr to nil and sets resp), allow that // If a plugin invalidates a response (sets resp to nil and sets bifrostErr), allow that } - // Capturing plugin ctx values and putting them in the request context - *ctx = pluginCtx.GetParentCtxWithUserValues() + // Increment chunk count for streaming + if isStreaming { + p.chunkCount++ + } // Final logic: if both are set, error takes precedence, unless error is nil if bifrostErr != nil { if resp != nil && bifrostErr.StatusCode == nil && bifrostErr.Error != nil && bifrostErr.Error.Type == nil && @@ -3131,6 +3474,91 @@ func (p *PluginPipeline) resetPluginPipeline() { p.executedPreHooks = 0 p.preHookErrors = p.preHookErrors[:0] p.postHookErrors = p.postHookErrors[:0] + // Reset streaming timing accumulation + p.chunkCount = 0 + if p.postHookTimings != nil { + clear(p.postHookTimings) + } + p.postHookPluginOrder = p.postHookPluginOrder[:0] +} + +// accumulatePluginTiming accumulates timing for a plugin during streaming +func (p *PluginPipeline) accumulatePluginTiming(pluginName string, duration time.Duration, hasError bool) { + if p.postHookTimings == nil { + p.postHookTimings = make(map[string]*pluginTimingAccumulator) + } + timing, ok := p.postHookTimings[pluginName] + if !ok { + timing = &pluginTimingAccumulator{} + p.postHookTimings[pluginName] = timing + // Track order on first occurrence (first chunk) + p.postHookPluginOrder = append(p.postHookPluginOrder, pluginName) + } + timing.totalDuration += duration + timing.invocations++ + if hasError { + timing.errors++ + } +} + +// FinalizeStreamingPostHookSpans creates aggregated spans for each plugin after streaming completes. +// This should be called once at the end of streaming to create one span per plugin with average timing. +// Spans are nested to mirror the pre-hook hierarchy (each post-hook is a child of the previous one). +func (p *PluginPipeline) FinalizeStreamingPostHookSpans(ctx context.Context) { + if p.postHookTimings == nil || len(p.postHookPluginOrder) == 0 { + return + } + + // Collect handles and timing info to end spans in reverse order + type spanInfo struct { + handle schemas.SpanHandle + hasErrors bool + } + spans := make([]spanInfo, 0, len(p.postHookPluginOrder)) + currentCtx := ctx + + // Start spans in execution order (nested: each is a child of the previous) + for _, pluginName := range p.postHookPluginOrder { + timing, ok := p.postHookTimings[pluginName] + if !ok || timing.invocations == 0 { + continue + } + + // Create span as child of the previous span (nested hierarchy) + newCtx, handle := p.tracer.StartSpan(currentCtx, fmt.Sprintf("plugin.%s.posthook", sanitizeSpanName(pluginName)), schemas.SpanKindPlugin) + if handle == nil { + continue + } + + // Calculate average duration in milliseconds + avgMs := float64(timing.totalDuration.Milliseconds()) / float64(timing.invocations) + + // Set aggregated attributes + p.tracer.SetAttribute(handle, schemas.AttrPluginInvocations, timing.invocations) + p.tracer.SetAttribute(handle, schemas.AttrPluginAvgDurationMs, avgMs) + p.tracer.SetAttribute(handle, schemas.AttrPluginTotalDurationMs, timing.totalDuration.Milliseconds()) + + if timing.errors > 0 { + p.tracer.SetAttribute(handle, schemas.AttrPluginErrorCount, timing.errors) + } + + spans = append(spans, spanInfo{handle: handle, hasErrors: timing.errors > 0}) + currentCtx = newCtx + } + + // End spans in reverse order (innermost first, like unwinding a call stack) + for i := len(spans) - 1; i >= 0; i-- { + if spans[i].hasErrors { + p.tracer.EndSpan(spans[i].handle, schemas.SpanStatusError, "some invocations failed") + } else { + p.tracer.EndSpan(spans[i].handle, schemas.SpanStatusOk, "") + } + } +} + +// GetChunkCount returns the number of chunks processed during streaming +func (p *PluginPipeline) GetChunkCount() int { + return p.chunkCount } // getPluginPipeline gets a PluginPipeline from the pool and configures it @@ -3138,6 +3566,7 @@ func (bifrost *Bifrost) getPluginPipeline() *PluginPipeline { pipeline := bifrost.pluginPipelinePool.Get().(*PluginPipeline) pipeline.plugins = *bifrost.plugins.Load() pipeline.logger = bifrost.logger + pipeline.tracer = bifrost.getTracer() return pipeline } @@ -3248,10 +3677,10 @@ func (bifrost *Bifrost) releaseBifrostRequest(req *schemas.BifrostRequest) { // getAllSupportedKeys retrieves all valid keys for a ListModels request. // allowing the provider to aggregate results from multiple keys. -func (bifrost *Bifrost) getAllSupportedKeys(ctx *context.Context, providerKey schemas.ModelProvider, baseProviderType schemas.ModelProvider) ([]schemas.Key, error) { +func (bifrost *Bifrost) getAllSupportedKeys(ctx *schemas.BifrostContext, providerKey schemas.ModelProvider, baseProviderType schemas.ModelProvider) ([]schemas.Key, error) { // Check if key has been set in the context explicitly if ctx != nil { - key, ok := (*ctx).Value(schemas.BifrostContextKeyDirectKey).(schemas.Key) + key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key) if ok { // If a direct key is specified, return it as a single-element slice return []schemas.Key{key}, nil @@ -3289,10 +3718,10 @@ func (bifrost *Bifrost) getAllSupportedKeys(ctx *context.Context, providerKey sc // getKeysForBatchAndFileOps retrieves keys for batch and file operations with model filtering. // For batch operations, only keys with UseForBatchAPI enabled are included. // Model filtering: if model is specified and key has model restrictions, only include if model is in list. -func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *context.Context, providerKey schemas.ModelProvider, baseProviderType schemas.ModelProvider, model *string, isBatchOp bool) ([]schemas.Key, error) { +func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, providerKey schemas.ModelProvider, baseProviderType schemas.ModelProvider, model *string, isBatchOp bool) ([]schemas.Key, error) { // Check if key has been set in the context explicitly if ctx != nil { - key, ok := (*ctx).Value(schemas.BifrostContextKeyDirectKey).(schemas.Key) + key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key) if ok { // If a direct key is specified, return it as a single-element slice return []schemas.Key{key}, nil @@ -3358,16 +3787,16 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *context.Context, provider // selectKeyFromProviderForModel selects an appropriate API key for a given provider and model. // It uses weighted random selection if multiple keys are available. -func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) (schemas.Key, error) { +func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) (schemas.Key, error) { // Check if key has been set in the context explicitly if ctx != nil { - key, ok := (*ctx).Value(schemas.BifrostContextKeyDirectKey).(schemas.Key) + key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key) if ok { return key, nil } } // Check if key skipping is allowed - if skipKeySelection, ok := (*ctx).Value(schemas.BifrostContextKeySkipKeySelection).(bool); ok && skipKeySelection && isKeySkippingAllowed(providerKey) { + if skipKeySelection, ok := ctx.Value(schemas.BifrostContextKeySkipKeySelection).(bool); ok && skipKeySelection && isKeySkippingAllowed(providerKey) { return schemas.Key{}, nil } // Get keys for provider @@ -3452,7 +3881,7 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, requ var requestedKeyName string if ctx != nil { - if keyName, ok := (*ctx).Value(schemas.BifrostContextKeyAPIKeyName).(string); ok { + if keyName, ok := ctx.Value(schemas.BifrostContextKeyAPIKeyName).(string); ok { requestedKeyName = strings.TrimSpace(keyName) } } @@ -3479,7 +3908,7 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, requ } -func WeightedRandomKeySelector(ctx *context.Context, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { +func WeightedRandomKeySelector(ctx *schemas.BifrostContext, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { // Use a weighted random selection based on key weights totalWeight := 0 for _, key := range keys { @@ -3529,12 +3958,17 @@ func (bifrost *Bifrost) Shutdown() { // Cleanup MCP manager if bifrost.mcpManager != nil { - err := bifrost.mcpManager.cleanup() + err := bifrost.mcpManager.Cleanup() if err != nil { bifrost.logger.Warn(fmt.Sprintf("Error cleaning up MCP manager: %s", err.Error())) } } + // Stop the tracerWrapper to clean up background goroutines + if tracerWrapper := bifrost.tracer.Load().(*tracerWrapper); tracerWrapper != nil && tracerWrapper.tracer != nil { + tracerWrapper.tracer.Stop() + } + // Cleanup plugins for _, plugin := range *bifrost.plugins.Load() { err := plugin.Cleanup() diff --git a/core/bifrost_test.go b/core/bifrost_test.go index 6b642d632a..2053515196 100644 --- a/core/bifrost_test.go +++ b/core/bifrost_test.go @@ -51,8 +51,9 @@ func createBifrostError(message string, statusCode *int, errorType *string, isBi // Test executeRequestWithRetries - success scenarios func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) { config := createTestConfig(3, 100*time.Millisecond, 1*time.Second) - ctx := context.Background() - + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + // Adding dummy tracer to the context + ctx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{}) // Test immediate success t.Run("ImmediateSuccess", func(t *testing.T) { callCount := 0 @@ -62,12 +63,13 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) { } result, err := executeRequestWithRetries( - &ctx, + ctx, config, handler, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", + nil, ) if callCount != 1 { @@ -95,12 +97,13 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) { } result, err := executeRequestWithRetries( - &ctx, + ctx, config, handler, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", + nil, ) if callCount != 3 { @@ -118,7 +121,8 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) { // Test executeRequestWithRetries - retry limits func TestExecuteRequestWithRetries_RetryLimits(t *testing.T) { config := createTestConfig(2, 100*time.Millisecond, 1*time.Second) - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{}) t.Run("ExceedsMaxRetries", func(t *testing.T) { callCount := 0 handler := func() (string, *schemas.BifrostError) { @@ -128,12 +132,13 @@ func TestExecuteRequestWithRetries_RetryLimits(t *testing.T) { } result, err := executeRequestWithRetries( - &ctx, + ctx, config, handler, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", + nil, ) // Should try: initial + 2 retries = 3 total attempts @@ -158,7 +163,8 @@ func TestExecuteRequestWithRetries_RetryLimits(t *testing.T) { // Test executeRequestWithRetries - non-retryable errors func TestExecuteRequestWithRetries_NonRetryableErrors(t *testing.T) { config := createTestConfig(3, 100*time.Millisecond, 1*time.Second) - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{}) testCases := []struct { name string error *schemas.BifrostError @@ -190,12 +196,13 @@ func TestExecuteRequestWithRetries_NonRetryableErrors(t *testing.T) { } result, err := executeRequestWithRetries( - &ctx, + ctx, config, handler, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", + nil, ) if callCount != 1 { @@ -214,7 +221,8 @@ func TestExecuteRequestWithRetries_NonRetryableErrors(t *testing.T) { // Test executeRequestWithRetries - retryable conditions func TestExecuteRequestWithRetries_RetryableConditions(t *testing.T) { config := createTestConfig(1, 100*time.Millisecond, 1*time.Second) - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{}) testCases := []struct { name string error *schemas.BifrostError @@ -262,12 +270,13 @@ func TestExecuteRequestWithRetries_RetryableConditions(t *testing.T) { } result, err := executeRequestWithRetries( - &ctx, + ctx, config, handler, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", + nil, ) // Should try: initial + 1 retry = 2 total attempts @@ -287,7 +296,7 @@ func TestExecuteRequestWithRetries_RetryableConditions(t *testing.T) { // Test calculateBackoff - exponential growth (base calculations without jitter) func TestCalculateBackoff_ExponentialGrowth(t *testing.T) { config := createTestConfig(5, 100*time.Millisecond, 5*time.Second) - + // Test the base exponential calculation by checking that results fall within expected ranges // Since we can't easily mock rand.Float64, we'll test the bounds instead testCases := []struct { @@ -471,8 +480,8 @@ func TestIsRateLimitError_EdgeCases(t *testing.T) { // Test retry logging and attempt counting func TestExecuteRequestWithRetries_LoggingAndCounting(t *testing.T) { config := createTestConfig(2, 50*time.Millisecond, 1*time.Second) - ctx := context.Background() - + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{}) // Capture calls and timing for verification var attemptCounts []int callCount := 0 @@ -490,12 +499,13 @@ func TestExecuteRequestWithRetries_LoggingAndCounting(t *testing.T) { } result, err := executeRequestWithRetries( - &ctx, + ctx, config, handler, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", + nil, ) // Verify call progression @@ -632,7 +642,7 @@ func (ma *MockAccount) GetConfigForProvider(provider schemas.ModelProvider) (*sc return nil, fmt.Errorf("provider %s not configured", provider) } -func (ma *MockAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { +func (ma *MockAccount) GetKeysForProvider(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { if keys, exists := ma.keys[provider]; exists { return keys, nil } @@ -647,7 +657,7 @@ func TestUpdateProvider(t *testing.T) { account.AddProvider(schemas.OpenAI, 5, 1000) // Initialize Bifrost - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) bifrost, err := Init(ctx, schemas.BifrostConfig{ Account: account, Logger: NewDefaultLogger(schemas.LogLevelError), // Keep tests quiet diff --git a/core/changelog.md b/core/changelog.md index e69de29bb2..e9993725b0 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -0,0 +1 @@ +- chore: added case-insensitive helper methods for header and query parameter lookups in HTTPRequest \ No newline at end of file diff --git a/core/chatbot_test.go b/core/chatbot_test.go index ff8cef7a23..8d1258d391 100644 --- a/core/chatbot_test.go +++ b/core/chatbot_test.go @@ -56,7 +56,7 @@ func (account *ComprehensiveTestAccount) GetConfiguredProviders() ([]schemas.Mod } // GetKeysForProvider returns the API keys and associated models for a given provider. -func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { +func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { switch providerKey { case schemas.OpenAI: return []schemas.Key{ @@ -306,8 +306,8 @@ func (s *ChatSession) getAvailableProviders() []schemas.ModelProvider { availableProviders = append(availableProviders, provider) continue } - ctx := context.Background() - keys, err := s.account.GetKeysForProvider(&ctx, provider) + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + keys, err := s.account.GetKeysForProvider(ctx, provider) if err == nil && len(keys) > 0 && keys[0].Value != "" { availableProviders = append(availableProviders, provider) } @@ -317,8 +317,8 @@ func (s *ChatSession) getAvailableProviders() []schemas.ModelProvider { // getAvailableModels returns available models for a given provider func (s *ChatSession) getAvailableModels(provider schemas.ModelProvider) []string { - ctx := context.Background() - keys, err := s.account.GetKeysForProvider(&ctx, provider) + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + keys, err := s.account.GetKeysForProvider(ctx, provider) if err != nil || len(keys) == 0 { return []string{} } @@ -486,7 +486,7 @@ func (s *ChatSession) SendMessage(message string) (string, error) { stopChan, wg := startLoader() // Send request - response, err := s.client.ChatCompletionRequest(context.Background(), request) + response, err := s.client.ChatCompletionRequest(schemas.NewBifrostContext(context.Background(), schemas.NoDeadline), request) // Stop loading animation stopLoader(stopChan, wg) @@ -563,7 +563,7 @@ func (s *ChatSession) handleToolCalls(assistantMessage schemas.ChatMessage) (str stopChan, wg := startLoader() // Execute the tool using Bifrost's integrated MCP functionality - toolResult, err := s.client.ExecuteMCPTool(context.Background(), toolCall) + toolResult, err := s.client.ExecuteChatMCPTool(schemas.NewBifrostContext(context.Background(), schemas.NoDeadline), toolCall) // Stop loading animation stopLoader(stopChan, wg) @@ -638,7 +638,7 @@ func (s *ChatSession) synthesizeToolResults() (string, error) { stopChan, wg := startLoader() // Send synthesis request - synthesisResponse, err := s.client.ChatCompletionRequest(context.Background(), synthesisRequest) + synthesisResponse, err := s.client.ChatCompletionRequest(schemas.NewBifrostContext(context.Background(), schemas.NoDeadline), synthesisRequest) // Stop loading animation stopLoader(stopChan, wg) diff --git a/core/go.mod b/core/go.mod index 681293bcec..1cbbc4ecf5 100644 --- a/core/go.mod +++ b/core/go.mod @@ -12,6 +12,8 @@ require ( github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0 github.com/aws/smithy-go v1.24.0 github.com/bytedance/sonic v1.14.2 + github.com/clarkmcc/go-typescript v0.7.0 + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 github.com/google/uuid v1.6.0 github.com/hajimehoshi/go-mp3 v0.3.4 github.com/mark3labs/mcp-go v0.43.2 @@ -26,6 +28,7 @@ require ( cloud.google.com/go/compute/metadata v0.9.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect + github.com/Masterminds/semver/v3 v3.3.1 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 // indirect @@ -46,7 +49,10 @@ require ( github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect diff --git a/core/go.sum b/core/go.sum index eca31022e0..82b68dfd53 100644 --- a/core/go.sum +++ b/core/go.sum @@ -12,6 +12,8 @@ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= +github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= @@ -62,6 +64,8 @@ github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPII github.com/bytedance/sonic v1.14.2/go.mod h1:T80iDELeHiHKSc0C9tubFygiuXoGzrkjKzX2quAx980= github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2NYzevs+o= github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -69,13 +73,21 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= @@ -161,6 +173,8 @@ golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/core/internal/testutil/account.go b/core/internal/testutil/account.go index 826e209531..9f2da6f1c9 100644 --- a/core/internal/testutil/account.go +++ b/core/internal/testutil/account.go @@ -123,7 +123,7 @@ func (account *ComprehensiveTestAccount) GetConfiguredProviders() ([]schemas.Mod } // GetKeysForProvider returns the API keys and associated models for a given provider. -func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { +func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { switch providerKey { case schemas.OpenAI: return []schemas.Key{ diff --git a/core/internal/testutil/automatic_function_calling.go b/core/internal/testutil/automatic_function_calling.go index 93b9711a5e..7a3b775fbd 100644 --- a/core/internal/testutil/automatic_function_calling.go +++ b/core/internal/testutil/automatic_function_calling.go @@ -65,6 +65,7 @@ func RunAutomaticFunctionCallingTest(t *testing.T, client *bifrost.Bifrost, ctx // Create operations for both Chat Completions and Responses API chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) chatReq := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -85,10 +86,11 @@ func RunAutomaticFunctionCallingTest(t *testing.T, client *bifrost.Bifrost, ctx Fallbacks: testConfig.Fallbacks, } - return client.ChatCompletionRequest(ctx, chatReq) + return client.ChatCompletionRequest(bfCtx, chatReq) } responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) responsesReq := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -107,7 +109,7 @@ func RunAutomaticFunctionCallingTest(t *testing.T, client *bifrost.Bifrost, ctx Fallbacks: testConfig.Fallbacks, } - return client.ResponsesRequest(ctx, responsesReq) + return client.ResponsesRequest(bfCtx, responsesReq) } // Execute dual API test - passes only if BOTH APIs succeed diff --git a/core/internal/testutil/batch.go b/core/internal/testutil/batch.go index cd1a53ecaa..92f5e89c07 100644 --- a/core/internal/testutil/batch.go +++ b/core/internal/testutil/batch.go @@ -38,8 +38,8 @@ func RunBatchCreateTest(t *testing.T, client *bifrost.Bifrost, ctx context.Conte CompletionWindow: "24h", ExtraParams: testConfig.BatchExtraParams, } - - response, err := client.BatchCreateRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + response, err := client.BatchCreateRequest(bfCtx, request) if err != nil { // Check if this is an unsupported operation error if err.Error != nil && (err.Error.Code != nil && *err.Error.Code == "unsupported_operation") { @@ -79,7 +79,8 @@ func RunBatchListTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context Limit: 10, } - response, err := client.BatchListRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + response, err := client.BatchListRequest(bfCtx, request) if err != nil { // Check if this is an unsupported operation error if err.Error != nil && (err.Error.Code != nil && *err.Error.Code == "unsupported_operation") { @@ -130,7 +131,8 @@ func RunBatchRetrieveTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con ExtraParams: testConfig.BatchExtraParams, } - createResponse, createErr := client.BatchCreateRequest(ctx, createRequest) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + createResponse, createErr := client.BatchCreateRequest(bfCtx, createRequest) if createErr != nil { // Check if this is an unsupported operation error if createErr.Error != nil && (createErr.Error.Code != nil && *createErr.Error.Code == "unsupported_operation") { @@ -152,7 +154,8 @@ func RunBatchRetrieveTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con BatchID: createResponse.ID, } - response, err := client.BatchRetrieveRequest(ctx, retrieveRequest) + bfCtx2 := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + response, err := client.BatchRetrieveRequest(bfCtx2, retrieveRequest) if err != nil { t.Errorf("BatchRetrieve failed: %v", err) return @@ -202,7 +205,8 @@ func RunBatchCancelTest(t *testing.T, client *bifrost.Bifrost, ctx context.Conte ExtraParams: testConfig.BatchExtraParams, } - createResponse, createErr := client.BatchCreateRequest(ctx, createRequest) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + createResponse, createErr := client.BatchCreateRequest(bfCtx, createRequest) if createErr != nil { // Check if this is an unsupported operation error if createErr.Error != nil && (createErr.Error.Code != nil && *createErr.Error.Code == "unsupported_operation") { @@ -224,7 +228,8 @@ func RunBatchCancelTest(t *testing.T, client *bifrost.Bifrost, ctx context.Conte BatchID: createResponse.ID, } - response, err := client.BatchCancelRequest(ctx, cancelRequest) + bfCtx2 := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + response, err := client.BatchCancelRequest(bfCtx2, cancelRequest) if err != nil { // Note: Cancel might fail if batch has already completed t.Logf("[WARNING] BatchCancel failed (batch may have already completed): %v", err) @@ -261,7 +266,8 @@ func RunBatchResultsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Cont BatchID: "test-batch-id", // This would be a real batch ID in practice } - _, err := client.BatchResultsRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + _, err := client.BatchResultsRequest(bfCtx, request) if err != nil { // This is expected to fail with a "batch not found" error since we're using a fake ID // In a real test, you would use an actual completed batch ID @@ -310,7 +316,8 @@ func RunBatchUnsupportedTest(t *testing.T, client *bifrost.Bifrost, ctx context. }, } - _, err := client.BatchCreateRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + _, err := client.BatchCreateRequest(bfCtx, request) if err == nil { t.Error("BatchCreate should have failed for unsupported provider") return @@ -352,7 +359,8 @@ func RunFileUploadTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex ExtraParams: testConfig.FileExtraParams, } - response, err := client.FileUploadRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + response, err := client.FileUploadRequest(bfCtx, request) if err != nil { // Check if this is an unsupported operation error if err.Error != nil && (err.Error.Code != nil && *err.Error.Code == "unsupported_operation") { @@ -393,7 +401,8 @@ func RunFileListTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, ExtraParams: testConfig.FileExtraParams, } - response, err := client.FileListRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + response, err := client.FileListRequest(bfCtx, request) if err != nil { // Check if this is an unsupported operation error if err.Error != nil && (err.Error.Code != nil && *err.Error.Code == "unsupported_operation") { @@ -435,7 +444,8 @@ func RunFileRetrieveTest(t *testing.T, client *bifrost.Bifrost, ctx context.Cont ExtraParams: testConfig.FileExtraParams, } - uploadResponse, uploadErr := client.FileUploadRequest(ctx, uploadRequest) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + uploadResponse, uploadErr := client.FileUploadRequest(bfCtx, uploadRequest) if uploadErr != nil { if uploadErr.Error != nil && (uploadErr.Error.Code != nil && *uploadErr.Error.Code == "unsupported_operation") { t.Logf("[EXPECTED] Provider %s returned unsupported operation error for upload", testConfig.Provider) @@ -456,7 +466,8 @@ func RunFileRetrieveTest(t *testing.T, client *bifrost.Bifrost, ctx context.Cont FileID: uploadResponse.ID, } - response, err := client.FileRetrieveRequest(ctx, retrieveRequest) + bfCtx2 := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + response, err := client.FileRetrieveRequest(bfCtx2, retrieveRequest) if err != nil { t.Errorf("FileRetrieve failed: %v", err) return @@ -498,7 +509,8 @@ func RunFileDeleteTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex ExtraParams: testConfig.FileExtraParams, } - uploadResponse, uploadErr := client.FileUploadRequest(ctx, uploadRequest) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + uploadResponse, uploadErr := client.FileUploadRequest(bfCtx, uploadRequest) if uploadErr != nil { if uploadErr.Error != nil && (uploadErr.Error.Code != nil && *uploadErr.Error.Code == "unsupported_operation") { t.Logf("[EXPECTED] Provider %s returned unsupported operation error for upload", testConfig.Provider) @@ -519,7 +531,8 @@ func RunFileDeleteTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex FileID: uploadResponse.ID, } - response, err := client.FileDeleteRequest(ctx, deleteRequest) + bfCtx2 := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + response, err := client.FileDeleteRequest(bfCtx2, deleteRequest) if err != nil { t.Errorf("FileDelete failed: %v", err) return @@ -561,7 +574,8 @@ func RunFileContentTest(t *testing.T, client *bifrost.Bifrost, ctx context.Conte ExtraParams: testConfig.FileExtraParams, } - uploadResponse, uploadErr := client.FileUploadRequest(ctx, uploadRequest) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + uploadResponse, uploadErr := client.FileUploadRequest(bfCtx, uploadRequest) if uploadErr != nil { if uploadErr.Error != nil && (uploadErr.Error.Code != nil && *uploadErr.Error.Code == "unsupported_operation") { t.Logf("[EXPECTED] Provider %s returned unsupported operation error for upload", testConfig.Provider) @@ -582,7 +596,8 @@ func RunFileContentTest(t *testing.T, client *bifrost.Bifrost, ctx context.Conte FileID: uploadResponse.ID, } - response, err := client.FileContentRequest(ctx, contentRequest) + bfCtx2 := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + response, err := client.FileContentRequest(bfCtx2, contentRequest) if err != nil { t.Errorf("FileContent failed: %v", err) return @@ -630,7 +645,8 @@ func RunFileUnsupportedTest(t *testing.T, client *bifrost.Bifrost, ctx context.C Purpose: "batch", } - _, err := client.FileUploadRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + _, err := client.FileUploadRequest(bfCtx, request) if err == nil { t.Error("FileUpload should have failed for unsupported provider") return @@ -671,7 +687,8 @@ func RunFileAndBatchIntegrationTest(t *testing.T, client *bifrost.Bifrost, ctx c ExtraParams: testConfig.FileExtraParams, } - uploadResponse, uploadErr := client.FileUploadRequest(ctx, uploadRequest) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + uploadResponse, uploadErr := client.FileUploadRequest(bfCtx, uploadRequest) if uploadErr != nil { if uploadErr.Error != nil && (uploadErr.Error.Code != nil && *uploadErr.Error.Code == "unsupported_operation") { t.Logf("[EXPECTED] Provider %s returned unsupported operation error for upload", testConfig.Provider) @@ -698,7 +715,8 @@ func RunFileAndBatchIntegrationTest(t *testing.T, client *bifrost.Bifrost, ctx c ExtraParams: testConfig.BatchExtraParams, } - batchResponse, batchErr := client.BatchCreateRequest(ctx, batchRequest) + bfCtx2 := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + batchResponse, batchErr := client.BatchCreateRequest(bfCtx2, batchRequest) if batchErr != nil { if batchErr.Error != nil && (batchErr.Error.Code != nil && *batchErr.Error.Code == "unsupported_operation") { t.Logf("[EXPECTED] Provider %s returned unsupported operation error for batch create", testConfig.Provider) diff --git a/core/internal/testutil/chat_audio.go b/core/internal/testutil/chat_audio.go index 98cc6de050..60c3c5cc85 100644 --- a/core/internal/testutil/chat_audio.go +++ b/core/internal/testutil/chat_audio.go @@ -74,7 +74,8 @@ func RunChatAudioTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context }, Fallbacks: testConfig.Fallbacks, } - response, err := client.ChatCompletionRequest(ctx, chatReq) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + response, err := client.ChatCompletionRequest(bfCtx, chatReq) if err != nil { return nil, err } @@ -202,7 +203,8 @@ func RunChatAudioStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.C } responseChannel, bifrostErr := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.ChatCompletionStreamRequest(ctx, chatReq) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ChatCompletionStreamRequest(bfCtx, chatReq) }) // Enhanced error handling diff --git a/core/internal/testutil/chat_completion_stream.go b/core/internal/testutil/chat_completion_stream.go index e4c7ed9735..4876680d6b 100644 --- a/core/internal/testutil/chat_completion_stream.go +++ b/core/internal/testutil/chat_completion_stream.go @@ -55,7 +55,8 @@ func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont // Use proper streaming retry wrapper for the stream request responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.ChatCompletionStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ChatCompletionStreamRequest(bfCtx, request) }) // Enhanced error handling @@ -259,7 +260,8 @@ func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.ChatCompletionStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ChatCompletionStreamRequest(bfCtx, request) }, func(responseChannel chan *schemas.BifrostStream) ChatStreamValidationResult { var toolCallDetected bool @@ -401,7 +403,8 @@ func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont // Use proper streaming retry wrapper for the stream request responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.ChatCompletionStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ChatCompletionStreamRequest(bfCtx, request) }) RequireNoError(t, err, "Chat completion stream with reasoning failed") @@ -578,7 +581,8 @@ func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.ChatCompletionStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ChatCompletionStreamRequest(bfCtx, request) }, func(responseChannel chan *schemas.BifrostStream) ChatStreamValidationResult { var reasoningDetected bool diff --git a/core/internal/testutil/complete_end_to_end.go b/core/internal/testutil/complete_end_to_end.go index 852135f492..ad79b88ce0 100644 --- a/core/internal/testutil/complete_end_to_end.go +++ b/core/internal/testutil/complete_end_to_end.go @@ -6,7 +6,6 @@ import ( "strings" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) @@ -61,6 +60,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C // Create operations for both APIs chatOperation1 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) chatReq := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -74,10 +74,11 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C }, Fallbacks: testConfig.Fallbacks, } - return client.ChatCompletionRequest(ctx, chatReq) + return client.ChatCompletionRequest(bfCtx, chatReq) } responsesOperation1 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) responsesReq := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -90,7 +91,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C MaxOutputTokens: bifrost.Ptr(150), }, } - return client.ResponsesRequest(ctx, responsesReq) + return client.ResponsesRequest(bfCtx, responsesReq) } // Execute dual API test for Step 1 @@ -198,6 +199,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C // Create operations for both APIs - Step 2 (processing tool results) chatOperation2 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) chatReq := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -207,10 +209,11 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C }, Fallbacks: testConfig.Fallbacks, } - return client.ChatCompletionRequest(ctx, chatReq) + return client.ChatCompletionRequest(bfCtx, chatReq) } responsesOperation2 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) responsesReq := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -219,7 +222,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C MaxOutputTokens: bifrost.Ptr(200), }, } - return client.ResponsesRequest(ctx, responsesReq) + return client.ResponsesRequest(bfCtx, responsesReq) } // Execute dual API test for Step 2 (processing tool results) @@ -334,6 +337,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C // Create operations for both APIs - Step 3 chatOperation3 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) chatReq := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: model, @@ -343,10 +347,11 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C }, Fallbacks: testConfig.Fallbacks, } - return client.ChatCompletionRequest(ctx, chatReq) + return client.ChatCompletionRequest(bfCtx, chatReq) } responsesOperation3 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) responsesReq := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: model, @@ -355,7 +360,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C MaxOutputTokens: bifrost.Ptr(200), }, } - return client.ResponsesRequest(ctx, responsesReq) + return client.ResponsesRequest(bfCtx, responsesReq) } // Execute dual API test for Step 3 diff --git a/core/internal/testutil/count_tokens.go b/core/internal/testutil/count_tokens.go index 3b170b8a62..27032b25be 100644 --- a/core/internal/testutil/count_tokens.go +++ b/core/internal/testutil/count_tokens.go @@ -70,7 +70,8 @@ func RunCountTokenTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex expectations, "CountTokens", func() (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { - return client.CountTokensRequest(ctx, countTokensReq) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.CountTokensRequest(bfCtx, countTokensReq) }, ) diff --git a/core/internal/testutil/cross_provider_scenarios.go b/core/internal/testutil/cross_provider_scenarios.go index b8fa6fc789..3af9b33d5a 100644 --- a/core/internal/testutil/cross_provider_scenarios.go +++ b/core/internal/testutil/cross_provider_scenarios.go @@ -1,7 +1,6 @@ package testutil import ( - "context" "encoding/json" "fmt" "strings" @@ -389,7 +388,7 @@ func NewOpenAIJudge(client *bifrost.Bifrost, judgeModel string, t *testing.T) *O } // EvaluateResponse judges an LLM response -func (judge *OpenAIJudge) EvaluateResponse(ctx context.Context, evaluation EvaluationRequest) (*EvaluationResult, error) { +func (judge *OpenAIJudge) EvaluateResponse(ctx *schemas.BifrostContext, evaluation EvaluationRequest) (*EvaluationResult, error) { prompt := fmt.Sprintf(`You are an expert AI system evaluator. Evaluate this LLM response. SCENARIO: %s @@ -532,7 +531,7 @@ func NewOpenAIConversationDriver(client *bifrost.Bifrost, driverModel string, t } // GenerateNextMessage creates a natural followup message -func (driver *OpenAIConversationDriver) GenerateNextMessage(ctx context.Context, request NextMessageRequest) (*GeneratedFollowup, error) { +func (driver *OpenAIConversationDriver) GenerateNextMessage(ctx *schemas.BifrostContext, request NextMessageRequest) (*GeneratedFollowup, error) { conversationHistory := driver.formatConversationHistory(request.ConversationHistory) prompt := fmt.Sprintf(`Generate the next realistic user message for a %s scenario. @@ -646,7 +645,7 @@ func (driver *OpenAIConversationDriver) generateFallbackMessage(request NextMess // ============================================================================= // RunCrossProviderScenarioTest executes a complete scenario -func RunCrossProviderScenarioTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, config CrossProviderTestConfig, scenario CrossProviderScenario, useResponsesAPI bool) { +func RunCrossProviderScenarioTest(t *testing.T, client *bifrost.Bifrost, ctx *schemas.BifrostContext, config CrossProviderTestConfig, scenario CrossProviderScenario, useResponsesAPI bool) { apiType := "Chat Completions" if useResponsesAPI { apiType = "Responses API" @@ -778,7 +777,7 @@ func RunCrossProviderScenarioTest(t *testing.T, client *bifrost.Bifrost, ctx con // ============================================================================= // RunCrossProviderConsistencyTest tests same prompt across providers -func RunCrossProviderConsistencyTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, config CrossProviderTestConfig, useResponsesAPI bool) { +func RunCrossProviderConsistencyTest(t *testing.T, client *bifrost.Bifrost, ctx *schemas.BifrostContext, config CrossProviderTestConfig, useResponsesAPI bool) { apiType := "Chat Completions" if useResponsesAPI { apiType = "Responses API" @@ -872,7 +871,7 @@ type ConsistencyResult struct { // HELPER FUNCTIONS // ============================================================================= -func executeStepWithProvider(t *testing.T, client *bifrost.Bifrost, ctx context.Context, +func executeStepWithProvider(t *testing.T, client *bifrost.Bifrost, ctx *schemas.BifrostContext, provider ProviderConfig, history []schemas.ChatMessage, step ScenarioStep, useResponsesAPI bool) (*schemas.BifrostResponse, *schemas.BifrostError) { // Prepare request parameters diff --git a/core/internal/testutil/cross_provider_test.go b/core/internal/testutil/cross_provider_test.go index 0244ffc4df..2a976d0aac 100644 --- a/core/internal/testutil/cross_provider_test.go +++ b/core/internal/testutil/cross_provider_test.go @@ -101,11 +101,13 @@ func TestCrossProviderScenarios(t *testing.T) { for _, scenario := range scenariosList { // Test each scenario with both Chat Completions and Responses API t.Run(scenario.Name+"_ChatCompletions", func(t *testing.T) { - RunCrossProviderScenarioTest(t, client, ctx, testConfig, scenario, false) // false = Chat Completions API + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + RunCrossProviderScenarioTest(t, client, bfCtx, testConfig, scenario, false) // false = Chat Completions API }) t.Run(scenario.Name+"_ResponsesAPI", func(t *testing.T) { - RunCrossProviderScenarioTest(t, client, ctx, testConfig, scenario, true) // true = Responses API + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + RunCrossProviderScenarioTest(t, client, bfCtx, testConfig, scenario, true) // true = Responses API }) } } @@ -138,10 +140,12 @@ func TestCrossProviderConsistency(t *testing.T) { // Test same prompt across different providers t.Run("SamePrompt_DifferentProviders_ChatCompletions", func(t *testing.T) { - RunCrossProviderConsistencyTest(t, client, ctx, testConfig, false) // Chat Completions + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + RunCrossProviderConsistencyTest(t, client, bfCtx, testConfig, false) // Chat Completions }) t.Run("SamePrompt_DifferentProviders_ResponsesAPI", func(t *testing.T) { - RunCrossProviderConsistencyTest(t, client, ctx, testConfig, true) // Responses API + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + RunCrossProviderConsistencyTest(t, client, bfCtx, testConfig, true) // Responses API }) } diff --git a/core/internal/testutil/embedding.go b/core/internal/testutil/embedding.go index b044858b97..632f64b826 100644 --- a/core/internal/testutil/embedding.go +++ b/core/internal/testutil/embedding.go @@ -8,7 +8,6 @@ import ( "strings" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) @@ -100,7 +99,8 @@ func RunEmbeddingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context } embeddingResponse, bifrostErr := WithEmbeddingTestRetry(t, embeddingRetryConfig, retryContext, expectations, "Embedding", func() (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { - return client.EmbeddingRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.EmbeddingRequest(bfCtx, request) }) if bifrostErr != nil { diff --git a/core/internal/testutil/end_to_end_tool_calling.go b/core/internal/testutil/end_to_end_tool_calling.go index cd9294253d..2fbe1d2344 100644 --- a/core/internal/testutil/end_to_end_tool_calling.go +++ b/core/internal/testutil/end_to_end_tool_calling.go @@ -6,7 +6,6 @@ import ( "strings" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) @@ -59,6 +58,7 @@ func RunEnd2EndToolCallingTest(t *testing.T, client *bifrost.Bifrost, ctx contex // Create operations for both APIs chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) chatReq := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -69,10 +69,11 @@ func RunEnd2EndToolCallingTest(t *testing.T, client *bifrost.Bifrost, ctx contex }, Fallbacks: testConfig.Fallbacks, } - return client.ChatCompletionRequest(ctx, chatReq) + return client.ChatCompletionRequest(bfCtx, chatReq) } responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) responsesReq := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -81,7 +82,7 @@ func RunEnd2EndToolCallingTest(t *testing.T, client *bifrost.Bifrost, ctx contex Tools: []schemas.ResponsesTool{*responsesTool}, }, } - return client.ResponsesRequest(ctx, responsesReq) + return client.ResponsesRequest(bfCtx, responsesReq) } // Execute dual API test for Step 1 @@ -175,6 +176,7 @@ func RunEnd2EndToolCallingTest(t *testing.T, client *bifrost.Bifrost, ctx contex // Create operations for both APIs - Step 2 chatOperation2 := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) chatReq := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -184,10 +186,11 @@ func RunEnd2EndToolCallingTest(t *testing.T, client *bifrost.Bifrost, ctx contex }, Fallbacks: testConfig.Fallbacks, } - return client.ChatCompletionRequest(ctx, chatReq) + return client.ChatCompletionRequest(bfCtx, chatReq) } responsesOperation2 := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) responsesReq := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -196,7 +199,7 @@ func RunEnd2EndToolCallingTest(t *testing.T, client *bifrost.Bifrost, ctx contex MaxOutputTokens: bifrost.Ptr(200), }, } - return client.ResponsesRequest(ctx, responsesReq) + return client.ResponsesRequest(bfCtx, responsesReq) } // Execute dual API test for Step 2 diff --git a/core/internal/testutil/file_base64.go b/core/internal/testutil/file_base64.go index 9f4d63126d..1f2759abea 100644 --- a/core/internal/testutil/file_base64.go +++ b/core/internal/testutil/file_base64.go @@ -10,8 +10,8 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -// Base64 encoded PDF file containing "Hello World!" text -// This is a minimal valid PDF for testing document input functionality +// HelloWorldPDFBase64 is a base64 encoded PDF file containing "Hello World!" text. +// This is a minimal valid PDF for testing document input functionality. const HelloWorldPDFBase64 = "data:application/pdf;base64,JVBERi0xLjcKCjEgMCBvYmogICUgZW50cnkgcG9pbnQKPDwKICAvVHlwZSAvQ2F0YWxvZwogIC" + "9QYWdlcyAyIDAgUgo+PgplbmRvYmoKCjIgMCBvYmoKPDwKICAvVHlwZSAvUGFnZXwKICAvTWV" + "kaWFCb3ggWyAwIDAgMjAwIDIwMCBdCiAgL0NvdW50IDEKICAvS2lkcyBbIDMgMCBSIF0KPj4K" + @@ -64,30 +64,41 @@ func CreateDocumentResponsesMessage(text, documentBase64 string) schemas.Respons } } -// RunFileBase64Test executes the PDF file input test scenario using dual API testing framework +// RunFileBase64Test executes the PDF file input test scenario with separate subtests for each API func RunFileBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { if !testConfig.Scenarios.FileBase64 { t.Logf("File base64 not supported for provider %s", testConfig.Provider) return } - t.Run("FileBase64", func(t *testing.T) { + // Run Chat Completions subtest + RunFileBase64ChatCompletionsTest(t, client, ctx, testConfig) + + // Run Responses API subtest + RunFileBase64ResponsesTest(t, client, ctx, testConfig) +} + +// RunFileBase64ChatCompletionsTest executes the file base64 test using Chat Completions API +func RunFileBase64ChatCompletionsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if !testConfig.Scenarios.FileBase64 { + t.Logf("File base64 not supported for provider %s", testConfig.Provider) + return + } + + t.Run("FileBase64-ChatCompletions", func(t *testing.T) { if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { t.Parallel() } - // Create messages for both APIs with base64 PDF document + // Create messages for Chat Completions API with base64 PDF document chatMessages := []schemas.ChatMessage{ CreateDocumentChatMessage("What is the main content of this PDF document? Summarize it.", HelloWorldPDFBase64), } - responsesMessages := []schemas.ResponsesMessage{ - CreateDocumentResponsesMessage("What is the main content of this PDF document? Summarize it.", HelloWorldPDFBase64), - } // Use retry framework for document input requests retryConfig := GetTestRetryConfigForScenario("FileInput", testConfig) retryContext := TestRetryContext{ - ScenarioName: "FileBase64", + ScenarioName: "FileBase64-ChatCompletions", ExpectedBehavior: map[string]interface{}{ "should_process_pdf": true, "should_read_document": true, @@ -104,7 +115,7 @@ func RunFileBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Contex }, } - // Enhanced validation for PDF document processing (same for both APIs) + // Enhanced validation for PDF document processing expectations := GetExpectationsForScenario("FileInput", testConfig, map[string]interface{}{}) expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider) expectations.ShouldContainKeywords = append(expectations.ShouldContainKeywords, "hello", "world") @@ -113,8 +124,17 @@ func RunFileBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Contex "unable to read", "no file", "corrupted", "unsupported", }...) // PDF processing failure indicators - // Create operations for both Chat Completions and Responses API - chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + chatRetryConfig := ChatRetryConfig{ + MaxAttempts: retryConfig.MaxAttempts, + BaseDelay: retryConfig.BaseDelay, + MaxDelay: retryConfig.MaxDelay, + Conditions: []ChatRetryCondition{}, + OnRetry: retryConfig.OnRetry, + OnFinalFail: retryConfig.OnFinalFail, + } + + response, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "FileBase64", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) chatReq := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -124,10 +144,78 @@ func RunFileBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Contex }, Fallbacks: testConfig.Fallbacks, } - return client.ChatCompletionRequest(ctx, chatReq) + return client.ChatCompletionRequest(bfCtx, chatReq) + }) + + if chatError != nil { + t.Fatalf("āŒ FileBase64 Chat Completions test failed: %v", GetErrorMessage(chatError)) + } + + // Additional validation for PDF document processing + content := GetChatContent(response) + validateDocumentContent(t, content, "Chat Completions") + + t.Logf("šŸŽ‰ Chat Completions API passed FileBase64 test!") + }) +} + +// RunFileBase64ResponsesTest executes the file base64 test using Responses API +func RunFileBase64ResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if !testConfig.Scenarios.FileBase64 { + t.Logf("File base64 not supported for provider %s", testConfig.Provider) + return + } + + t.Run("FileBase64-Responses", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + // Create messages for Responses API with base64 PDF document + responsesMessages := []schemas.ResponsesMessage{ + CreateDocumentResponsesMessage("What is the main content of this PDF document? Summarize it.", HelloWorldPDFBase64), + } + + // Set up retry context for document input requests + retryContext := TestRetryContext{ + ScenarioName: "FileBase64-Responses", + ExpectedBehavior: map[string]interface{}{ + "should_process_pdf": true, + "should_read_document": true, + "should_extract_content": true, + "document_understanding": true, + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ChatModel, + "file_type": "pdf", + "encoding": "base64", + "test_content": "Hello World!", + "expected_keywords": []string{"hello", "world", "pdf", "document"}, + }, } - responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + // Enhanced validation for PDF document processing + expectations := GetExpectationsForScenario("FileInput", testConfig, map[string]interface{}{}) + expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider) + expectations.ShouldContainKeywords = append(expectations.ShouldContainKeywords, "hello", "world") + expectations.ShouldNotContainWords = append(expectations.ShouldNotContainWords, []string{ + "cannot process", "invalid format", "decode error", + "unable to read", "no file", "corrupted", "unsupported", + }...) // PDF processing failure indicators + + retryConfig := GetTestRetryConfigForScenario("FileInput", testConfig) + responsesRetryConfig := ResponsesRetryConfig{ + MaxAttempts: retryConfig.MaxAttempts, + BaseDelay: retryConfig.BaseDelay, + MaxDelay: retryConfig.MaxDelay, + Conditions: []ResponsesRetryCondition{}, + OnRetry: retryConfig.OnRetry, + OnFinalFail: retryConfig.OnFinalFail, + } + + response, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "FileBase64", func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) responsesReq := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -137,54 +225,18 @@ func RunFileBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Contex }, Fallbacks: testConfig.Fallbacks, } - return client.ResponsesRequest(ctx, responsesReq) - } + return client.ResponsesRequest(bfCtx, responsesReq) + }) - // Execute dual API test - passes only if BOTH APIs succeed - result := WithDualAPITestRetry(t, - retryConfig, - retryContext, - expectations, - "FileBase64", - chatOperation, - responsesOperation) - - // Validate both APIs succeeded - if !result.BothSucceeded { - var errors []string - if result.ChatCompletionsError != nil { - errors = append(errors, "Chat Completions: "+GetErrorMessage(result.ChatCompletionsError)) - } - if result.ResponsesAPIError != nil { - errors = append(errors, "Responses API: "+GetErrorMessage(result.ResponsesAPIError)) - } - if len(errors) == 0 { - errors = append(errors, "One or both APIs failed validation (see logs above)") - } - t.Fatalf("āŒ FileBase64 dual API test failed: %v", errors) - } - - // Additional validation for PDF document processing using universal content extraction - validateChatDocumentProcessing := func(response *schemas.BifrostChatResponse, apiName string) { - content := GetChatContent(response) - validateDocumentContent(t, content, apiName) - } - - validateResponsesDocumentProcessing := func(response *schemas.BifrostResponsesResponse, apiName string) { - content := GetResponsesContent(response) - validateDocumentContent(t, content, apiName) - } - - // Validate both API responses - if result.ChatCompletionsResponse != nil { - validateChatDocumentProcessing(result.ChatCompletionsResponse, "Chat Completions") + if responsesError != nil { + t.Fatalf("āŒ FileBase64 Responses test failed: %v", GetErrorMessage(responsesError)) } - if result.ResponsesAPIResponse != nil { - validateResponsesDocumentProcessing(result.ResponsesAPIResponse, "Responses") - } + // Additional validation for PDF document processing + content := GetResponsesContent(response) + validateDocumentContent(t, content, "Responses") - t.Logf("šŸŽ‰ Both Chat Completions and Responses APIs passed FileBase64 test!") + t.Logf("šŸŽ‰ Responses API passed FileBase64 test!") }) } diff --git a/core/internal/testutil/file_url.go b/core/internal/testutil/file_url.go index 7aa1764afd..a941248970 100644 --- a/core/internal/testutil/file_url.go +++ b/core/internal/testutil/file_url.go @@ -125,6 +125,7 @@ func RunFileURLChatCompletionsTest(t *testing.T, client *bifrost.Bifrost, ctx co } response, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "FileURL", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) chatReq := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -134,7 +135,7 @@ func RunFileURLChatCompletionsTest(t *testing.T, client *bifrost.Bifrost, ctx co }, Fallbacks: testConfig.Fallbacks, } - return client.ChatCompletionRequest(ctx, chatReq) + return client.ChatCompletionRequest(bfCtx, chatReq) }) if chatError != nil { @@ -166,8 +167,7 @@ func RunFileURLResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx context. CreateFileURLResponsesMessage("What is this document about? Please provide a summary of its main topics.", TestFileURL), } - // Use retry framework for file URL requests - retryConfig := GetTestRetryConfigForScenario("FileInput", testConfig) + // Set up retry context for file URL requests retryContext := TestRetryContext{ ScenarioName: "FileURL-Responses", ExpectedBehavior: map[string]interface{}{ @@ -197,16 +197,10 @@ func RunFileURLResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx context. "cannot fetch", "download failed", "url not found", }...) // File URL processing failure indicators - responsesRetryConfig := ResponsesRetryConfig{ - MaxAttempts: retryConfig.MaxAttempts, - BaseDelay: retryConfig.BaseDelay, - MaxDelay: retryConfig.MaxDelay, - Conditions: []ResponsesRetryCondition{}, - OnRetry: retryConfig.OnRetry, - OnFinalFail: retryConfig.OnFinalFail, - } + responsesRetryConfig := FileInputResponsesRetryConfig() response, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "FileURL", func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) responsesReq := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -216,7 +210,7 @@ func RunFileURLResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx context. }, Fallbacks: testConfig.Fallbacks, } - return client.ResponsesRequest(ctx, responsesReq) + return client.ResponsesRequest(bfCtx, responsesReq) }) if responsesError != nil { diff --git a/core/internal/testutil/image_base64.go b/core/internal/testutil/image_base64.go index a4ba88c643..41d9377bad 100644 --- a/core/internal/testutil/image_base64.go +++ b/core/internal/testutil/image_base64.go @@ -66,6 +66,7 @@ func RunImageBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Conte // Create operations for both Chat Completions and Responses API chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) chatReq := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: testConfig.VisionModel, @@ -75,10 +76,11 @@ func RunImageBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Conte }, Fallbacks: testConfig.Fallbacks, } - return client.ChatCompletionRequest(ctx, chatReq) + return client.ChatCompletionRequest(bfCtx, chatReq) } responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) responsesReq := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: testConfig.VisionModel, @@ -88,7 +90,7 @@ func RunImageBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Conte }, Fallbacks: testConfig.Fallbacks, } - return client.ResponsesRequest(ctx, responsesReq) + return client.ResponsesRequest(bfCtx, responsesReq) } // Execute dual API test - passes only if BOTH APIs succeed diff --git a/core/internal/testutil/image_url.go b/core/internal/testutil/image_url.go index db9d931950..2dccfa2b5e 100644 --- a/core/internal/testutil/image_url.go +++ b/core/internal/testutil/image_url.go @@ -6,7 +6,6 @@ import ( "strings" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) @@ -58,6 +57,7 @@ func RunImageURLTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, // Create operations for both Chat Completions and Responses API chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) chatReq := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: testConfig.VisionModel, @@ -67,10 +67,11 @@ func RunImageURLTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, Fallbacks: testConfig.Fallbacks, } chatReq.Input = chatMessages - return client.ChatCompletionRequest(ctx, chatReq) + return client.ChatCompletionRequest(bfCtx, chatReq) } responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) responsesReq := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: testConfig.VisionModel, @@ -80,7 +81,7 @@ func RunImageURLTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, Fallbacks: testConfig.Fallbacks, } responsesReq.Input = responsesMessages - return client.ResponsesRequest(ctx, responsesReq) + return client.ResponsesRequest(bfCtx, responsesReq) } // Execute dual API test - passes only if BOTH APIs succeed diff --git a/core/internal/testutil/list_models.go b/core/internal/testutil/list_models.go index 43be92f2a5..086a084bbf 100644 --- a/core/internal/testutil/list_models.go +++ b/core/internal/testutil/list_models.go @@ -5,7 +5,6 @@ import ( "os" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) @@ -60,7 +59,8 @@ func RunListModelsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex } response, bifrostErr := WithListModelsTestRetry(t, listModelsRetryConfig, retryContext, expectations, "ListModels", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - return client.ListModelsRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ListModelsRequest(bfCtx, request) }) if bifrostErr != nil { @@ -161,7 +161,8 @@ func RunListModelsPaginationTest(t *testing.T, client *bifrost.Bifrost, ctx cont } response, bifrostErr := WithListModelsTestRetry(t, listModelsRetryConfig, retryContext, expectations, "ListModelsPagination", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - return client.ListModelsRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ListModelsRequest(bfCtx, request) }) if bifrostErr != nil { @@ -203,7 +204,8 @@ func RunListModelsPaginationTest(t *testing.T, client *bifrost.Bifrost, ctx cont } nextPageResponse, nextPageErr := WithListModelsTestRetry(t, listModelsRetryConfig, nextPageRetryContext, expectations, "ListModelsPagination_NextPage", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - return client.ListModelsRequest(ctx, nextPageRequest) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ListModelsRequest(bfCtx, nextPageRequest) }) if nextPageErr != nil { diff --git a/core/internal/testutil/multi_turn_conversation.go b/core/internal/testutil/multi_turn_conversation.go index e7d471c92e..e1c04ec691 100644 --- a/core/internal/testutil/multi_turn_conversation.go +++ b/core/internal/testutil/multi_turn_conversation.go @@ -66,7 +66,8 @@ func RunMultiTurnConversationTest(t *testing.T, client *bifrost.Bifrost, ctx con expectations1 = ModifyExpectationsForProvider(expectations1, testConfig.Provider) response1, bifrostErr := WithChatTestRetry(t, chatRetryConfig1, retryContext1, expectations1, "MultiTurnConversation_Step1", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { - return client.ChatCompletionRequest(ctx, firstRequest) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ChatCompletionRequest(bfCtx, firstRequest) }) if bifrostErr != nil { @@ -133,7 +134,8 @@ func RunMultiTurnConversationTest(t *testing.T, client *bifrost.Bifrost, ctx con expectations2.ShouldNotContainWords = []string{"don't know", "can't remember", "forgot"} // Memory failure indicators response2, bifrostErr := WithChatTestRetry(t, chatRetryConfig2, retryContext2, expectations2, "MultiTurnConversation_Step2", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { - return client.ChatCompletionRequest(ctx, secondRequest) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ChatCompletionRequest(bfCtx, secondRequest) }) if bifrostErr != nil { diff --git a/core/internal/testutil/multiple_images.go b/core/internal/testutil/multiple_images.go index 9c1ea07b3a..bf5f2d3e45 100644 --- a/core/internal/testutil/multiple_images.go +++ b/core/internal/testutil/multiple_images.go @@ -6,7 +6,6 @@ import ( "strings" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) @@ -101,7 +100,8 @@ func RunMultipleImagesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co }...) // Failure to process multiple images indicators response, bifrostError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "MultipleImages", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { - return client.ChatCompletionRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ChatCompletionRequest(bfCtx, request) }) // Validation now happens inside WithTestRetry - no need to check again diff --git a/core/internal/testutil/multiple_tool_calls.go b/core/internal/testutil/multiple_tool_calls.go index d1cf7bfbdb..70d53d7000 100644 --- a/core/internal/testutil/multiple_tool_calls.go +++ b/core/internal/testutil/multiple_tool_calls.go @@ -5,7 +5,6 @@ import ( "os" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) @@ -74,6 +73,7 @@ func RunMultipleToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context // Create operations for both Chat Completions and Responses API chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) chatReq := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -83,10 +83,11 @@ func RunMultipleToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context Fallbacks: testConfig.Fallbacks, } chatReq.Input = chatMessages - return client.ChatCompletionRequest(ctx, chatReq) + return client.ChatCompletionRequest(bfCtx, chatReq) } responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) responsesReq := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -96,7 +97,7 @@ func RunMultipleToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context Fallbacks: testConfig.Fallbacks, } responsesReq.Input = responsesMessages - return client.ResponsesRequest(ctx, responsesReq) + return client.ResponsesRequest(bfCtx, responsesReq) } // Execute dual API test - passes only if BOTH APIs succeed diff --git a/core/internal/testutil/prompt_caching.go b/core/internal/testutil/prompt_caching.go index 00e02097c5..b1cdd7579f 100644 --- a/core/internal/testutil/prompt_caching.go +++ b/core/internal/testutil/prompt_caching.go @@ -380,7 +380,8 @@ func RunPromptCachingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con // Execute with retry framework operation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { - return client.ChatCompletionRequest(ctx, chatReq) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ChatCompletionRequest(bfCtx, chatReq) } response, err := WithChatTestRetry(t, retryConfig, retryContext, expectations, query.name, operation) diff --git a/core/internal/testutil/reasoning.go b/core/internal/testutil/reasoning.go index 17b7c45834..752d58a1ad 100644 --- a/core/internal/testutil/reasoning.go +++ b/core/internal/testutil/reasoning.go @@ -85,7 +85,8 @@ func RunResponsesReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx contex expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider) response, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "Reasoning", func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { - return client.ResponsesRequest(ctx, responsesReq) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ResponsesRequest(bfCtx, responsesReq) }) if responsesError != nil { @@ -278,7 +279,8 @@ func RunChatCompletionReasoningTest(t *testing.T, client *bifrost.Bifrost, ctx c expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider) response, chatError := WithChatTestRetry(t, chatRetryConfig, retryContext, expectations, "Reasoning", func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { - return client.ChatCompletionRequest(ctx, chatReq) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ChatCompletionRequest(bfCtx, chatReq) }) if chatError != nil { diff --git a/core/internal/testutil/responses_stream.go b/core/internal/testutil/responses_stream.go index 3059ad9e98..f4bd0e7eb2 100644 --- a/core/internal/testutil/responses_stream.go +++ b/core/internal/testutil/responses_stream.go @@ -63,7 +63,8 @@ func RunResponsesStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.C // Use validation retry wrapper that validates stream content and retries on validation failures validationResult := WithResponsesStreamValidationRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.ResponsesStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ResponsesStreamRequest(bfCtx, request) }, func(responseChannel chan *schemas.BifrostStream) ResponsesStreamValidationResult { var fullContent strings.Builder @@ -332,7 +333,8 @@ func RunResponsesStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.C // Use proper streaming retry wrapper for the stream request responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.ResponsesStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ResponsesStreamRequest(bfCtx, request) }) RequireNoError(t, err, "Responses stream with tools failed") @@ -465,7 +467,8 @@ func RunResponsesStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.C // Use proper streaming retry wrapper for the stream request responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.ResponsesStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ResponsesStreamRequest(bfCtx, request) }) RequireNoError(t, err, "Responses stream with reasoning failed") @@ -589,7 +592,8 @@ func RunResponsesStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.C // Use validation retry wrapper that validates lifecycle events and retries on validation failures validationResult := WithResponsesStreamValidationRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.ResponsesStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ResponsesStreamRequest(bfCtx, request) }, func(responseChannel chan *schemas.BifrostStream) ResponsesStreamValidationResult { // Track lifecycle events diff --git a/core/internal/testutil/simple_chat.go b/core/internal/testutil/simple_chat.go index df9f95d81b..8492d96f97 100644 --- a/core/internal/testutil/simple_chat.go +++ b/core/internal/testutil/simple_chat.go @@ -5,7 +5,6 @@ import ( "os" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) @@ -71,6 +70,7 @@ func RunSimpleChatTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex // Test Chat Completions API chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) chatReq := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -80,7 +80,7 @@ func RunSimpleChatTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex }, Fallbacks: testConfig.Fallbacks, } - response, err := client.ChatCompletionRequest(ctx, chatReq) + response, err := client.ChatCompletionRequest(bfCtx, chatReq) if err != nil { return nil, err } @@ -99,13 +99,14 @@ func RunSimpleChatTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex // Test Responses API responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) responsesReq := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, Input: responsesMessages, Fallbacks: testConfig.Fallbacks, } - response, err := client.ResponsesRequest(ctx, responsesReq) + response, err := client.ResponsesRequest(bfCtx, responsesReq) if err != nil { return nil, err } diff --git a/core/internal/testutil/speech_synthesis.go b/core/internal/testutil/speech_synthesis.go index c916f0e5d7..8ab5dc6e49 100644 --- a/core/internal/testutil/speech_synthesis.go +++ b/core/internal/testutil/speech_synthesis.go @@ -110,9 +110,9 @@ func RunSpeechSynthesisTest(t *testing.T, client *bifrost.Bifrost, ctx context.C OnFinalFail: retryConfig.OnFinalFail, } - requestCtx := context.Background() - + speechResponse, bifrostErr := WithSpeechTestRetry(t, speechRetryConfig, retryContext, expectations, "SpeechSynthesis_"+tc.name, func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { + requestCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) return client.SpeechRequest(requestCtx, request) }) @@ -216,9 +216,10 @@ func RunSpeechSynthesisAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx c OnFinalFail: retryConfig.OnFinalFail, } - requestCtx := context.Background() + speechResponse, bifrostErr := WithSpeechTestRetry(t, speechRetryConfig, retryContext, expectations, "SpeechSynthesis_HD", func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { + requestCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) return client.SpeechRequest(requestCtx, request) }) if bifrostErr != nil { @@ -298,9 +299,9 @@ func RunSpeechSynthesisAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx c OnFinalFail: voiceRetryConfig.OnFinalFail, } - requestCtx := context.Background() - + speechResponse, bifrostErr := WithSpeechTestRetry(t, voiceSpeechRetryConfig, voiceRetryContext, expectations, "SpeechSynthesis_VoiceType_"+voiceType, func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { + requestCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) return client.SpeechRequest(requestCtx, request) }) diff --git a/core/internal/testutil/speech_synthesis_stream.go b/core/internal/testutil/speech_synthesis_stream.go index 19955ea201..3ff02118f8 100644 --- a/core/internal/testutil/speech_synthesis_stream.go +++ b/core/internal/testutil/speech_synthesis_stream.go @@ -115,9 +115,10 @@ func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx con }, } - requestCtx := context.Background() + responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + requestCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) return client.SpeechStreamRequest(requestCtx, request) }) @@ -306,9 +307,9 @@ func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, }, } - requestCtx := context.Background() - + responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + requestCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) return client.SpeechStreamRequest(requestCtx, request) }) @@ -452,16 +453,16 @@ func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, }, } - requestCtx := context.Background() - + // Use retry framework with stream validation var accumulatedAudio bytes.Buffer // Accumulate audio for codec validation validationResult := WithSpeechStreamValidationRetry( t, retryConfig, retryContext, - func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + func() (chan *schemas.BifrostStream, *schemas.BifrostError) { accumulatedAudio.Reset() // Reset buffer on retry + requestCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) return client.SpeechStreamRequest(requestCtx, request) }, func(responseChannel chan *schemas.BifrostStream) SpeechStreamValidationResult { diff --git a/core/internal/testutil/structured_outputs.go b/core/internal/testutil/structured_outputs.go index 865e9054ad..0d9b12af73 100644 --- a/core/internal/testutil/structured_outputs.go +++ b/core/internal/testutil/structured_outputs.go @@ -102,12 +102,12 @@ func testStructuredOutputChatWithValue(t *testing.T, client *bifrost.Bifrost, ct chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { // Add Anthropic beta header for structured outputs if model contains "claude" - reqCtx := ctx + reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) if strings.Contains(strings.ToLower(testConfig.ChatModel), "claude") { extraHeaders := map[string][]string{ "anthropic-beta": {"structured-outputs-2025-11-13"}, } - reqCtx = context.WithValue(ctx, schemas.BifrostContextKeyExtraHeaders, extraHeaders) + reqCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders) } chatReq := &schemas.BifrostChatRequest{ @@ -217,12 +217,12 @@ func RunStructuredOutputChatStreamTest(t *testing.T, client *bifrost.Bifrost, ct } // Add Anthropic beta header for structured outputs if model contains "claude" - reqCtx := ctx + reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) if strings.Contains(strings.ToLower(testConfig.ChatModel), "claude") { extraHeaders := map[string][]string{ "anthropic-beta": {"structured-outputs-2025-11-13"}, } - reqCtx = context.WithValue(ctx, schemas.BifrostContextKeyExtraHeaders, extraHeaders) + reqCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders) } request := &schemas.BifrostChatRequest{ @@ -373,12 +373,12 @@ func RunStructuredOutputResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx } // Add Anthropic beta header for structured outputs if model contains "claude" - reqCtx := ctx + reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) if strings.Contains(strings.ToLower(testConfig.ChatModel), "claude") { extraHeaders := map[string][]string{ "anthropic-beta": {"structured-outputs-2025-11-13"}, } - reqCtx = context.WithValue(ctx, schemas.BifrostContextKeyExtraHeaders, extraHeaders) + reqCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders) } retryConfig := GetTestRetryConfigForScenario("StructuredOutputResponses", testConfig) @@ -510,12 +510,12 @@ func RunStructuredOutputResponsesStreamTest(t *testing.T, client *bifrost.Bifros } // Add Anthropic beta header for structured outputs if model contains "claude" - reqCtx := ctx + reqCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) if strings.Contains(strings.ToLower(testConfig.ChatModel), "claude") { extraHeaders := map[string][]string{ "anthropic-beta": {"structured-outputs-2025-11-13"}, } - reqCtx = context.WithValue(ctx, schemas.BifrostContextKeyExtraHeaders, extraHeaders) + reqCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders) } typeStr := "object" diff --git a/core/internal/testutil/test_retry_conditions.go b/core/internal/testutil/test_retry_conditions.go index dbbceaa221..744f2de700 100644 --- a/core/internal/testutil/test_retry_conditions.go +++ b/core/internal/testutil/test_retry_conditions.go @@ -979,3 +979,172 @@ func (c *InvalidCountTokensCondition) ShouldRetry(response *schemas.BifrostRespo func (c *InvalidCountTokensCondition) GetConditionName() string { return "InvalidCountTokens" } + +// ============================================================================= +// RESPONSES API CONDITIONS +// These implement ResponsesRetryCondition for use with WithResponsesTestRetry +// ============================================================================= + +// ResponsesEmptyCondition checks for empty Responses API responses +type ResponsesEmptyCondition struct{} + +func (c *ResponsesEmptyCondition) ShouldRetry(response *schemas.BifrostResponsesResponse, err *schemas.BifrostError, context TestRetryContext) (bool, string) { + if err != nil { + return false, "" + } + if response == nil { + return true, "response is nil" + } + content := GetResponsesContent(response) + if strings.TrimSpace(content) == "" { + return true, "response has empty content" + } + return false, "" +} + +func (c *ResponsesEmptyCondition) GetConditionName() string { + return "ResponsesEmpty" +} + +// ResponsesFileNotProcessedCondition checks if file/document was not properly processed in Responses API +type ResponsesFileNotProcessedCondition struct{} + +func (c *ResponsesFileNotProcessedCondition) ShouldRetry(response *schemas.BifrostResponsesResponse, err *schemas.BifrostError, context TestRetryContext) (bool, string) { + if err != nil || response == nil { + return false, "" + } + + content := strings.ToLower(GetResponsesContent(response)) + + // Check for generic responses that don't indicate file/document processing + fileProcessingFailurePhrases := []string{ + "i can't read", + "i cannot read", + "unable to read", + "can't access", + "cannot access", + "no file", + "no document", + "not able to read", + "i don't see", + "i cannot process", + "unable to process", + "can't open", + "cannot open", + "invalid file", + "corrupted", + "unsupported format", + "failed to load", + "no pdf", + "cannot view", + } + + for _, phrase := range fileProcessingFailurePhrases { + if strings.Contains(content, phrase) { + return true, fmt.Sprintf("response suggests file was not processed: contains '%s'", phrase) + } + } + + // If content is suspiciously short for document analysis + if len(strings.TrimSpace(content)) < 15 { + return true, "response too short for meaningful document analysis" + } + + return false, "" +} + +func (c *ResponsesFileNotProcessedCondition) GetConditionName() string { + return "ResponsesFileNotProcessed" +} + +// ResponsesGenericResponseCondition checks for generic/template responses in Responses API +type ResponsesGenericResponseCondition struct{} + +func (c *ResponsesGenericResponseCondition) ShouldRetry(response *schemas.BifrostResponsesResponse, err *schemas.BifrostError, context TestRetryContext) (bool, string) { + if err != nil || response == nil { + return false, "" + } + + content := strings.ToLower(GetResponsesContent(response)) + + // Generic phrases that suggest the model didn't engage with the specific request + genericPhrases := []string{ + "as an ai", + "as a language model", + "i'm an ai", + "i am an ai", + "i'm a language model", + "i am a language model", + "i can help you with", + "how can i assist you", + "what would you like to know", + "is there anything else", + } + + // Check if response starts with generic phrases (more concerning) + for _, phrase := range genericPhrases { + if strings.HasPrefix(content, phrase) { + return true, fmt.Sprintf("response starts with generic phrase: '%s'", phrase) + } + } + + // Check for overly generic responses (short and generic) + if len(strings.TrimSpace(content)) < 30 { + for _, phrase := range genericPhrases { + if strings.Contains(content, phrase) { + return true, fmt.Sprintf("short response contains generic phrase: '%s'", phrase) + } + } + } + + return false, "" +} + +func (c *ResponsesGenericResponseCondition) GetConditionName() string { + return "ResponsesGenericResponse" +} + +// ResponsesContentValidationCondition checks if response fails basic content validation for Responses API +type ResponsesContentValidationCondition struct{} + +func (c *ResponsesContentValidationCondition) ShouldRetry(response *schemas.BifrostResponsesResponse, err *schemas.BifrostError, context TestRetryContext) (bool, string) { + if err != nil || response == nil { + return false, "" + } + + content := strings.ToLower(GetResponsesContent(response)) + + // Skip if response is too short (other conditions will handle these) + if len(content) < 10 { + return false, "" + } + + // Check for file/document processing scenarios + scenarioName := strings.ToLower(context.ScenarioName) + if strings.Contains(scenarioName, "file") || strings.Contains(scenarioName, "document") || strings.Contains(scenarioName, "pdf") { + // Check if this test has expected keywords from the TestRetryContext + if testMetadata, exists := context.TestMetadata["expected_keywords"]; exists { + if expectedKeywords, ok := testMetadata.([]string); ok && len(expectedKeywords) > 0 { + // Check if ANY of the expected keywords are present + foundExpectedKeyword := false + for _, keyword := range expectedKeywords { + if strings.Contains(content, strings.ToLower(keyword)) { + foundExpectedKeyword = true + break + } + } + + // If valid response but missing ALL expected keywords, retry + if !foundExpectedKeyword && len(content) > 20 && len(content) < 2000 { + return true, fmt.Sprintf("response missing expected keywords %v, might include them on retry", expectedKeywords) + } + } + } + } + + return false, "" +} + +func (c *ResponsesContentValidationCondition) GetConditionName() string { + return "ResponsesContentValidation" +} diff --git a/core/internal/testutil/test_retry_framework.go b/core/internal/testutil/test_retry_framework.go index aa0fbb317d..477e8dec98 100644 --- a/core/internal/testutil/test_retry_framework.go +++ b/core/internal/testutil/test_retry_framework.go @@ -849,6 +849,24 @@ func FileInputRetryConfig() TestRetryConfig { } } +// FileInputResponsesRetryConfig creates a retry config for file/document input tests using Responses API +func FileInputResponsesRetryConfig() ResponsesRetryConfig { + return ResponsesRetryConfig{ + MaxAttempts: 10, + BaseDelay: 2000 * time.Millisecond, + MaxDelay: 10 * time.Second, + Conditions: []ResponsesRetryCondition{ + &ResponsesEmptyCondition{}, + &ResponsesFileNotProcessedCondition{}, + &ResponsesGenericResponseCondition{}, + &ResponsesContentValidationCondition{}, + }, + OnRetry: func(attempt int, reason string, t *testing.T) { + t.Logf("šŸ”„ Retrying file input test (attempt %d): %s", attempt, reason) + }, + } +} + // StreamingRetryConfig creates a retry config for streaming tests func StreamingRetryConfig() TestRetryConfig { return TestRetryConfig{ diff --git a/core/internal/testutil/text_completion.go b/core/internal/testutil/text_completion.go index f1f5348717..7979595617 100644 --- a/core/internal/testutil/text_completion.go +++ b/core/internal/testutil/text_completion.go @@ -66,7 +66,8 @@ func RunTextCompletionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co } response, bifrostErr := WithTextCompletionTestRetry(t, textCompletionRetryConfig, retryContext, expectations, "TextCompletion", func() (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { - return client.TextCompletionRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.TextCompletionRequest(bfCtx, request) }) if bifrostErr != nil { diff --git a/core/internal/testutil/text_completion_stream.go b/core/internal/testutil/text_completion_stream.go index c731bd3382..e6a10e1724 100644 --- a/core/internal/testutil/text_completion_stream.go +++ b/core/internal/testutil/text_completion_stream.go @@ -8,7 +8,6 @@ import ( "testing" "time" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" ) @@ -65,7 +64,8 @@ func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont // Use proper streaming retry wrapper for the stream request responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.TextCompletionStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.TextCompletionStreamRequest(bfCtx, request) }) // Enhanced error handling @@ -264,7 +264,8 @@ func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont } responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.TextCompletionStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.TextCompletionStreamRequest(bfCtx, request) }) RequireNoError(t, err, "Text completion stream with varied prompts failed") @@ -406,7 +407,8 @@ func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont } responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.TextCompletionStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.TextCompletionStreamRequest(bfCtx, request) }) RequireNoError(t, err, "Text completion stream with parameters failed") diff --git a/core/internal/testutil/tool_calls.go b/core/internal/testutil/tool_calls.go index 739684b1f4..72850d59e7 100644 --- a/core/internal/testutil/tool_calls.go +++ b/core/internal/testutil/tool_calls.go @@ -7,7 +7,6 @@ import ( "strings" "testing" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/stretchr/testify/require" @@ -61,6 +60,7 @@ func RunToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context // Create operations for both Chat Completions and Responses API chatOperation := func() (*schemas.BifrostChatResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) chatReq := &schemas.BifrostChatRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -71,10 +71,11 @@ func RunToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context }, Fallbacks: testConfig.Fallbacks, } - return client.ChatCompletionRequest(ctx, chatReq) + return client.ChatCompletionRequest(bfCtx, chatReq) } responsesOperation := func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) responsesReq := &schemas.BifrostResponsesRequest{ Provider: testConfig.Provider, Model: testConfig.ChatModel, @@ -83,7 +84,7 @@ func RunToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context Tools: []schemas.ResponsesTool{*responsesTool}, }, } - return client.ResponsesRequest(ctx, responsesReq) + return client.ResponsesRequest(bfCtx, responsesReq) } // Execute dual API test - passes only if BOTH APIs succeed diff --git a/core/internal/testutil/tool_calls_streaming.go b/core/internal/testutil/tool_calls_streaming.go index 96be6927c8..4ad74eecef 100644 --- a/core/internal/testutil/tool_calls_streaming.go +++ b/core/internal/testutil/tool_calls_streaming.go @@ -260,7 +260,8 @@ func RunToolCallsStreamingTest(t *testing.T, client *bifrost.Bifrost, ctx contex } responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.ChatCompletionStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ChatCompletionStreamRequest(bfCtx, request) }) RequireNoError(t, err, "Chat completion stream with tools failed") @@ -387,7 +388,8 @@ func RunToolCallsStreamingTest(t *testing.T, client *bifrost.Bifrost, ctx contex // Use validation retry wrapper that validates tool calls and retries on validation failures validationResult := WithResponsesStreamValidationRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.ResponsesStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ResponsesStreamRequest(bfCtx, request) }, func(responseChannel chan *schemas.BifrostStream) ResponsesStreamValidationResult { accumulator := NewStreamingToolCallAccumulator() diff --git a/core/internal/testutil/transcription.go b/core/internal/testutil/transcription.go index a1d0d62bd1..4508c27ef3 100644 --- a/core/internal/testutil/transcription.go +++ b/core/internal/testutil/transcription.go @@ -139,7 +139,8 @@ func RunTranscriptionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con } ttsResponse, err := WithSpeechTestRetry(t, speechRetryConfig, ttsRetryContext, ttsExpectations, "Transcription_RoundTrip_TTS_"+tc.name, func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { - return client.SpeechRequest(ctx, ttsRequest) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.SpeechRequest(bfCtx, ttsRequest) }) if err != nil { t.Fatalf("āŒ TTS generation failed for round-trip test after retries: %v", GetErrorMessage(err)) @@ -207,7 +208,8 @@ func RunTranscriptionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con } transcriptionResponse, bifrostErr := WithTranscriptionTestRetry(t, transcriptionRetryConfig, retryContext, expectations, "Transcription_RoundTrip_"+tc.name, func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { - return client.TranscriptionRequest(ctx, transcriptionRequest) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.TranscriptionRequest(bfCtx, transcriptionRequest) }) if bifrostErr != nil { @@ -315,7 +317,8 @@ func RunTranscriptionTest(t *testing.T, client *bifrost.Bifrost, ctx context.Con } response, err := WithTranscriptionTestRetry(t, customTranscriptionRetryConfig, customRetryContext, customExpectations, "Transcription_Custom_"+tc.name, func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { - return client.TranscriptionRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.TranscriptionRequest(bfCtx, request) }) if err != nil { errorMsg := GetErrorMessage(err) @@ -424,7 +427,8 @@ func RunTranscriptionAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx con } response, err := WithTranscriptionTestRetry(t, formatTranscriptionRetryConfig, formatRetryContext, formatExpectations, "Transcription_Format_"+format, func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { - return client.TranscriptionRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.TranscriptionRequest(bfCtx, request) }) if err != nil { errorMsg := GetErrorMessage(err) @@ -519,7 +523,8 @@ func RunTranscriptionAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx con } response, err := WithTranscriptionTestRetry(t, advancedTranscriptionRetryConfig, advancedRetryContext, advancedExpectations, "Transcription_Advanced_CustomParams", func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { - return client.TranscriptionRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.TranscriptionRequest(bfCtx, request) }) if err != nil { errorMsg := GetErrorMessage(err) @@ -616,7 +621,8 @@ func RunTranscriptionAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx con } response, err := WithTranscriptionTestRetry(t, langTranscriptionRetryConfig, langRetryContext, langExpectations, "Transcription_Language_"+lang, func() (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { - return client.TranscriptionRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.TranscriptionRequest(bfCtx, request) }) if err != nil { errorMsg := GetErrorMessage(err) diff --git a/core/internal/testutil/transcription_stream.go b/core/internal/testutil/transcription_stream.go index 1f27dcf1ff..e7a1855dea 100644 --- a/core/internal/testutil/transcription_stream.go +++ b/core/internal/testutil/transcription_stream.go @@ -113,7 +113,8 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte } ttsResponse, err := WithSpeechTestRetry(t, ttsSpeechRetryConfig, ttsRetryContext, ttsExpectations, "TranscriptionStream_TTS", func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { - return client.SpeechRequest(ctx, ttsRequest) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.SpeechRequest(bfCtx, ttsRequest) }) if err != nil { t.Fatalf("āŒ TTS generation failed for stream round-trip test after retries: %v", GetErrorMessage(err)) @@ -170,7 +171,8 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte } responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.TranscriptionStreamRequest(ctx, streamRequest) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.TranscriptionStreamRequest(bfCtx, streamRequest) }) RequireNoError(t, err, "Transcription stream initiation failed") @@ -387,7 +389,8 @@ func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, c } responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.TranscriptionStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.TranscriptionStreamRequest(bfCtx, request) }) RequireNoError(t, err, "JSON streaming failed") @@ -487,7 +490,8 @@ func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, c } responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.TranscriptionStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.TranscriptionStreamRequest(bfCtx, request) }) RequireNoError(t, err, fmt.Sprintf("Streaming failed for language %s", lang)) @@ -581,7 +585,8 @@ func RunTranscriptionStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, c } responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return client.TranscriptionStreamRequest(ctx, request) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.TranscriptionStreamRequest(bfCtx, request) }) RequireNoError(t, err, "Custom prompt streaming failed") diff --git a/core/internal/testutil/utils.go b/core/internal/testutil/utils.go index b268996293..1f4d259c94 100644 --- a/core/internal/testutil/utils.go +++ b/core/internal/testutil/utils.go @@ -616,7 +616,8 @@ func GenerateTTSAudioForTest(ctx context.Context, t *testing.T, client *bifrost. } resp, err := WithSpeechTestRetry(t, speechRetryConfig, retryContext, expectations, "GenerateTTSAudioForTest", func() (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { - return client.SpeechRequest(ctx, req) + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.SpeechRequest(bfCtx, req) }) if err != nil { t.Fatalf("TTS request failed after retries: %v", GetErrorMessage(err)) diff --git a/core/mcp.go b/core/mcp.go index b1eb7a73e5..be9fb580b0 100644 --- a/core/mcp.go +++ b/core/mcp.go @@ -1135,15 +1135,8 @@ func (m *MCPManager) createInProcessConnection(config schemas.MCPClientConfig) ( if config.InProcessServer == nil { return nil, MCPClientConnectionInfo{}, fmt.Errorf("InProcess connection requires a server instance") } - - // Type assert to ensure we have a proper MCP server - mcpServer, ok := config.InProcessServer.(*server.MCPServer) - if !ok { - return nil, MCPClientConnectionInfo{}, fmt.Errorf("InProcessServer must be a *server.MCPServer instance") - } - // Create in-process client directly connected to the provided server - inProcessClient, err := client.NewInProcessClient(mcpServer) + inProcessClient, err := client.NewInProcessClient(config.InProcessServer) if err != nil { return nil, MCPClientConnectionInfo{}, fmt.Errorf("failed to create in-process client: %w", err) } diff --git a/core/mcp/agent.go b/core/mcp/agent.go new file mode 100644 index 0000000000..1992864d76 --- /dev/null +++ b/core/mcp/agent.go @@ -0,0 +1,472 @@ +package mcp + +import ( + "fmt" + "strings" + "sync" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" +) + +// ExecuteAgentForChatRequest handles the agent mode execution loop for Chat API. +// It orchestrates iterative tool execution up to the maximum depth, handling +// auto-executable and non-auto-executable tools appropriately. +// +// Parameters: +// - ctx: Context for agent execution +// - maxAgentDepth: Maximum number of agent iterations allowed +// - originalReq: The original chat request +// - initialResponse: The initial chat response containing tool calls +// - makeReq: Function to make subsequent chat requests during agent execution +// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration +// - executeToolFunc: Function to execute individual tool calls +// - clientManager: Client manager for accessing MCP clients and tools +// +// Returns: +// - *schemas.BifrostChatResponse: The final response after agent execution +// - *schemas.BifrostError: Any error that occurred during agent execution +func ExecuteAgentForChatRequest( + ctx *schemas.BifrostContext, + maxAgentDepth int, + originalReq *schemas.BifrostChatRequest, + initialResponse *schemas.BifrostChatResponse, + makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), + fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, + executeToolFunc func(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error), + clientManager ClientManager, +) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + // Create adapter for Chat API + adapter := &chatAPIAdapter{ + originalReq: originalReq, + initialResponse: initialResponse, + makeReq: makeReq, + } + + result, err := executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager) + if err != nil { + return nil, err + } + + chatResponse, ok := result.(*schemas.BifrostChatResponse) + // Should never happen, but just in case + if !ok { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "Failed to convert result to schemas.BifrostChatResponse", + }, + } + } + + return chatResponse, nil +} + +// ExecuteAgentForResponsesRequest handles the agent mode execution loop for Responses API. +// It orchestrates iterative tool execution up to the maximum depth, handling +// auto-executable and non-auto-executable tools appropriately. +// +// Parameters: +// - ctx: Context for agent execution +// - maxAgentDepth: Maximum number of agent iterations allowed +// - originalReq: The original responses request +// - initialResponse: The initial responses response containing tool calls +// - makeReq: Function to make subsequent responses requests during agent execution +// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration +// - executeToolFunc: Function to execute individual tool calls +// - clientManager: Client manager for accessing MCP clients and tools +// +// Returns: +// - *schemas.BifrostResponsesResponse: The final response after agent execution +// - *schemas.BifrostError: Any error that occurred during agent execution +func ExecuteAgentForResponsesRequest( + ctx *schemas.BifrostContext, + maxAgentDepth int, + originalReq *schemas.BifrostResponsesRequest, + initialResponse *schemas.BifrostResponsesResponse, + makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError), + fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, + executeToolFunc func(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error), + clientManager ClientManager, +) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + // Create adapter for Responses API + adapter := &responsesAPIAdapter{ + originalReq: originalReq, + initialResponse: initialResponse, + makeReq: makeReq, + } + + result, err := executeAgent(ctx, maxAgentDepth, adapter, fetchNewRequestIDFunc, executeToolFunc, clientManager) + if err != nil { + return nil, err + } + + responsesResponse, ok := result.(*schemas.BifrostResponsesResponse) + // Should never happen, but just in case + if !ok { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "Failed to convert result to schemas.BifrostResponsesResponse", + }, + } + } + + return responsesResponse, nil +} + +// executeAgent handles the generic agent mode execution loop using an API adapter pattern. +// It iteratively executes tools, separates auto-executable from non-auto-executable tools, +// executes auto-executable tools in parallel, and continues the loop until no more tool +// calls are present or the maximum depth is reached. +// +// Parameters: +// - ctx: Context for agent execution (may be modified to add request IDs) +// - maxAgentDepth: Maximum number of agent iterations allowed +// - adapter: API adapter that abstracts differences between Chat and Responses APIs +// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for each iteration +// - executeToolFunc: Function to execute individual tool calls +// - clientManager: Client manager for accessing MCP clients and tools +// +// Returns: +// - interface{}: The final response after agent execution (type depends on adapter) +// - *schemas.BifrostError: Any error that occurred during agent execution +func executeAgent( + ctx *schemas.BifrostContext, + maxAgentDepth int, + adapter agentAPIAdapter, + fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, + executeToolFunc func(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error), + clientManager ClientManager, +) (interface{}, *schemas.BifrostError) { + logger.Debug("Entering agent mode - detected tool calls in response") + + // Get initial response from adapter + currentResponse := adapter.getInitialResponse() + + // Create conversation history starting with original messages + conversationHistory := adapter.getConversationHistory() + + depth := 0 + + // Track all executed tool results and tool calls across all iterations + allExecutedToolResults := make([]*schemas.ChatMessage, 0) + allExecutedToolCalls := make([]schemas.ChatAssistantMessageToolCall, 0) + + originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string) + if ok { + ctx.SetValue(schemas.BifrostMCPAgentOriginalRequestID, originalRequestID) + } + + for depth < maxAgentDepth { + depth++ + toolCalls := adapter.extractToolCalls(currentResponse) + if len(toolCalls) == 0 { + logger.Debug("No more tool calls found, exiting agent mode") + break + } + + logger.Debug(fmt.Sprintf("Agent mode depth %d: executing %d tool calls", depth, len(toolCalls))) + + // Separate tools into auto-executable and non-auto-executable groups + var autoExecutableTools []schemas.ChatAssistantMessageToolCall + var nonAutoExecutableTools []schemas.ChatAssistantMessageToolCall + + for _, toolCall := range toolCalls { + if toolCall.Function.Name == nil { + // Skip tools without names + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + toolName := *toolCall.Function.Name + client := clientManager.GetClientForTool(toolName) + if client == nil { + // Allow code mode list and read tool tools + if toolName == ToolTypeListToolFiles || toolName == ToolTypeReadToolFile { + autoExecutableTools = append(autoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s can be auto-executed", toolName)) + continue + } else if toolName == ToolTypeExecuteToolCode { + // Build allowed auto-execution tools map for code mode validation + allClientNames, allowedAutoExecutionTools := buildAllowedAutoExecutionTools(ctx, clientManager) + + // Parse tool arguments + var arguments map[string]interface{} + if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + logger.Debug(fmt.Sprintf("%s Failed to parse tool arguments: %v", CodeModeLogPrefix, err)) + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + code, ok := arguments["code"].(string) + if !ok || code == "" { + logger.Debug(fmt.Sprintf("%s Code parameter missing or empty", CodeModeLogPrefix)) + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + // Step 1: Convert literal \n escape sequences to actual newlines for parsing + codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n") + if len(codeWithNewlines) != len(code) { + logger.Debug(fmt.Sprintf("%s Converted literal \\n escape sequences to actual newlines", CodeModeLogPrefix)) + } + + // Step 2: Extract tool calls from code during AST formation + extractedToolCalls, err := extractToolCallsFromCode(codeWithNewlines) + if err != nil { + logger.Debug(fmt.Sprintf("%s Failed to parse code for tool calls: %v", CodeModeLogPrefix, err)) + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + logger.Debug(fmt.Sprintf("%s Extracted %d tool call(s) from code", CodeModeLogPrefix, len(extractedToolCalls))) + + // Step 3: Validate all tool calls against allowedAutoExecutionTools + canAutoExecute := true + if len(extractedToolCalls) > 0 { + // If there are tool calls, we need allowedAutoExecutionTools to validate them + if len(allowedAutoExecutionTools) == 0 { + logger.Debug(fmt.Sprintf("%s Validation failed: no allowed auto-execution tools configured", CodeModeLogPrefix)) + canAutoExecute = false + } else { + logger.Debug(fmt.Sprintf("%s Validating %d tool call(s) against %d allowed server(s)", CodeModeLogPrefix, len(extractedToolCalls), len(allowedAutoExecutionTools))) + + // Validate each tool call + for _, extractedToolCall := range extractedToolCalls { + isAllowed := isToolCallAllowedForCodeMode(extractedToolCall.serverName, extractedToolCall.toolName, allClientNames, allowedAutoExecutionTools) + if !isAllowed { + logger.Debug(fmt.Sprintf("%s Tool call %s.%s: allowed=%v", CodeModeLogPrefix, extractedToolCall.serverName, extractedToolCall.toolName, isAllowed)) + logger.Debug(fmt.Sprintf("%s Validation failed: tool call %s.%s not in auto-execute list", CodeModeLogPrefix, extractedToolCall.serverName, extractedToolCall.toolName)) + canAutoExecute = false + break + } + } + if canAutoExecute { + logger.Debug(fmt.Sprintf("%s All tool calls validated successfully", CodeModeLogPrefix)) + } + } + } else { + logger.Debug(fmt.Sprintf("%s No tool calls found in code, skipping validation", CodeModeLogPrefix)) + } + + // Add to appropriate list based on validation result + if canAutoExecute { + autoExecutableTools = append(autoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s can be auto-executed (validation passed)", toolName)) + } else { + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s cannot be auto-executed (validation failed)", toolName)) + } + continue + } + // Else, if client not found, treat as non-auto-executable (can be a manually passed tool) + logger.Debug(fmt.Sprintf("Client not found for tool %s, treating as non-auto-executable", toolName)) + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + continue + } + + // Check if tool can be auto-executed + if canAutoExecuteTool(toolName, client.ExecutionConfig) { + autoExecutableTools = append(autoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s can be auto-executed", toolName)) + } else { + nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) + logger.Debug(fmt.Sprintf("Tool %s cannot be auto-executed", toolName)) + } + } + + logger.Debug(fmt.Sprintf("Auto-executable tools: %d", len(autoExecutableTools))) + logger.Debug(fmt.Sprintf("Non-auto-executable tools: %d", len(nonAutoExecutableTools))) + + // Execute auto-executable tools first + var executedToolResults []*schemas.ChatMessage + if len(autoExecutableTools) > 0 { + // Add assistant message with auto-executable tool calls to conversation + conversationHistory = adapter.addAssistantMessage(conversationHistory, currentResponse) + + // Execute all auto-executable tool calls parallelly + wg := sync.WaitGroup{} + wg.Add(len(autoExecutableTools)) + channelToolResults := make(chan *schemas.ChatMessage, len(autoExecutableTools)) + for _, toolCall := range autoExecutableTools { + go func(toolCall schemas.ChatAssistantMessageToolCall) { + defer wg.Done() + toolResult, toolErr := executeToolFunc(ctx, toolCall) + if toolErr != nil { + logger.Warn(fmt.Sprintf("Tool execution failed: %v", toolErr)) + channelToolResults <- createToolResultMessage(toolCall, "", toolErr) + } else { + channelToolResults <- toolResult + } + }(toolCall) + } + wg.Wait() + close(channelToolResults) + + // Collect tool results + executedToolResults = make([]*schemas.ChatMessage, 0, len(autoExecutableTools)) + for toolResult := range channelToolResults { + executedToolResults = append(executedToolResults, toolResult) + } + + // Track executed tool results and calls across all iterations + allExecutedToolResults = append(allExecutedToolResults, executedToolResults...) + allExecutedToolCalls = append(allExecutedToolCalls, autoExecutableTools...) + + // Add tool results to conversation history + conversationHistory = adapter.addToolResults(conversationHistory, executedToolResults) + } + + // If there are non-auto-executable tools, return them immediately without continuing the loop + if len(nonAutoExecutableTools) > 0 { + logger.Debug(fmt.Sprintf("Found %d non-auto-executable tools, returning them immediately without continuing the loop", len(nonAutoExecutableTools))) + // Return as is if its the first iteration + if depth == 1 && len(allExecutedToolResults) == 0 { + return currentResponse, nil + } + // Create response with all executed tool results from all iterations, and non-auto-executable tool calls + return adapter.createResponseWithExecutedTools(currentResponse, allExecutedToolResults, allExecutedToolCalls, nonAutoExecutableTools), nil + } + + // Create new request with updated conversation history + newReq := adapter.createNewRequest(conversationHistory) + + if fetchNewRequestIDFunc != nil { + newID := fetchNewRequestIDFunc(ctx) + if newID != "" { + ctx.SetValue(schemas.BifrostContextKeyRequestID, newID) + } + } + + // Make new LLM request + response, err := adapter.makeLLMCall(ctx, newReq) + if err != nil { + logger.Error("Agent mode: LLM request failed: %v", err) + return nil, err + } + + currentResponse = response + } + + logger.Debug(fmt.Sprintf("Agent mode completed after %d iterations", depth)) + return currentResponse, nil +} + +// extractToolCalls extracts all tool calls from a chat response. +// It iterates through all choices in the response and collects tool calls +// from assistant messages. +// +// Parameters: +// - response: The chat response to extract tool calls from +// +// Returns: +// - []schemas.ChatAssistantMessageToolCall: List of extracted tool calls, or nil if none found +func extractToolCalls(response *schemas.BifrostChatResponse) []schemas.ChatAssistantMessageToolCall { + if !hasToolCallsForChatResponse(response) { + return nil + } + + var toolCalls []schemas.ChatAssistantMessageToolCall + for _, choice := range response.Choices { + if choice.ChatNonStreamResponseChoice != nil && + choice.ChatNonStreamResponseChoice.Message != nil && + choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil { + toolCalls = append(toolCalls, choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls...) + } + } + + return toolCalls +} + +// createToolResultMessage creates a tool result message from tool execution. +// It formats the result or error into a chat message with the appropriate tool call ID. +// +// Parameters: +// - toolCall: The original tool call that was executed +// - result: The successful execution result (ignored if err is not nil) +// - err: Any error that occurred during tool execution +// +// Returns: +// - *schemas.ChatMessage: A tool message containing the execution result or error +func createToolResultMessage(toolCall schemas.ChatAssistantMessageToolCall, result string, err error) *schemas.ChatMessage { + var content string + if err != nil { + content = fmt.Sprintf("Error executing tool %s: %s", + func() string { + if toolCall.Function.Name != nil { + return *toolCall.Function.Name + } + return "unknown" + }(), err.Error()) + } else { + content = result + } + + return &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: &content, + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: toolCall.ID, + }, + } +} + +// buildAllowedAutoExecutionTools builds a map of client names to their auto-executable tools. +// It processes code mode clients and parses their ToolsToAutoExecute configuration to create +// a map of allowed tools. Tool names are parsed to match their appearance in JavaScript code. +// +// Parameters: +// - ctx: Context for accessing client tools +// - clientManager: Client manager for accessing MCP clients +// +// Returns: +// - []string: List of all client names +// - map[string][]string: Map of client names to their auto-executable tool names (as they appear in code) +func buildAllowedAutoExecutionTools(ctx *schemas.BifrostContext, clientManager ClientManager) ([]string, map[string][]string) { + allowedTools := make(map[string][]string) + availableToolsPerClient := clientManager.GetToolPerClient(ctx) + allClientNames := []string{} + + for clientName := range availableToolsPerClient { + client := clientManager.GetClientByName(clientName) + if client == nil { + continue + } + allClientNames = append(allClientNames, clientName) + + // Only include code mode clients + if !client.ExecutionConfig.IsCodeModeClient { + continue + } + + // Get auto-executable tools from config + toolsToAutoExecute := client.ExecutionConfig.ToolsToAutoExecute + if len(toolsToAutoExecute) == 0 { + // No auto-executable tools configured for this client + continue + } + + // Parse tool names (as they appear in JavaScript code) + autoExecutableTools := []string{} + for _, originalToolName := range toolsToAutoExecute { + // Handle wildcard "*" - means all tools are auto-executable + if originalToolName == "*" { + autoExecutableTools = append(autoExecutableTools, "*") + continue + } + // Use parsed tool name (as it appears in code) + parsedToolName := parseToolName(originalToolName) + autoExecutableTools = append(autoExecutableTools, parsedToolName) + } + + // Add to map if there are auto-executable tools + if len(autoExecutableTools) > 0 { + allowedTools[clientName] = autoExecutableTools + } + } + + return allClientNames, allowedTools +} diff --git a/core/mcp/agent_test.go b/core/mcp/agent_test.go new file mode 100644 index 0000000000..350a5ac2ae --- /dev/null +++ b/core/mcp/agent_test.go @@ -0,0 +1,719 @@ +package mcp + +import ( + "context" + "encoding/json" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +// MockLLMCaller implements schemas.BifrostLLMCaller for testing +type MockLLMCaller struct { + chatResponses []*schemas.BifrostChatResponse + responsesResponses []*schemas.BifrostResponsesResponse + chatCallCount int + responsesCallCount int +} + +func (m *MockLLMCaller) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + if m.chatCallCount >= len(m.chatResponses) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "no more mock chat responses available", + }, + } + } + + response := m.chatResponses[m.chatCallCount] + m.chatCallCount++ + return response, nil +} + +func (m *MockLLMCaller) ResponsesRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + if m.responsesCallCount >= len(m.responsesResponses) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "no more mock responses api responses available", + }, + } + } + + response := m.responsesResponses[m.responsesCallCount] + m.responsesCallCount++ + return response, nil +} + +// MockLogger implements schemas.Logger for testing +type MockLogger struct{} + +func (m *MockLogger) Debug(msg string, args ...any) {} +func (m *MockLogger) Info(msg string, args ...any) {} +func (m *MockLogger) Warn(msg string, args ...any) {} +func (m *MockLogger) Error(msg string, args ...any) {} +func (m *MockLogger) Fatal(msg string, args ...any) {} +func (m *MockLogger) SetLevel(level schemas.LogLevel) {} +func (m *MockLogger) SetOutputType(outputType schemas.LoggerOutputType) {} + +// MockClientManager implements ClientManager for testing +type MockClientManager struct{} + +func (m *MockClientManager) GetClientForTool(toolName string) *schemas.MCPClientState { + return nil // Return nil to simulate no client found +} + +func (m *MockClientManager) GetClientByName(clientName string) *schemas.MCPClientState { + return nil +} + +func (m *MockClientManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool { + return make(map[string][]schemas.ChatTool) +} + +func TestHasToolCallsForChatResponse(t *testing.T) { + // Test nil response + if hasToolCallsForChatResponse(nil) { + t.Error("Should return false for nil response") + } + + // Test empty choices + emptyResponse := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{}, + } + if hasToolCallsForChatResponse(emptyResponse) { + t.Error("Should return false for response with empty choices") + } + + // Test response with tool_calls finish reason + toolCallsResponse := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("tool_calls"), + }, + }, + } + if !hasToolCallsForChatResponse(toolCallsResponse) { + t.Error("Should return true for response with tool_calls finish reason") + } + + // Test response with actual tool calls + responseWithToolCalls := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr("test_tool"), + }, + }, + }, + }, + }, + }, + }, + }, + } + if !hasToolCallsForChatResponse(responseWithToolCalls) { + t.Error("Should return true for response with tool calls in message") + } + + // Test response with stop finish reason (should return false even with tool calls) + responseWithStopReason := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("stop"), + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr("test_tool"), + }, + }, + }, + }, + }, + }, + }, + }, + } + if hasToolCallsForChatResponse(responseWithStopReason) { + t.Error("Should return false for response with stop finish reason even with tool calls") + } +} + +func TestExtractToolCalls(t *testing.T) { + // Test response without tool calls + responseNoTools := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("stop"), + }, + }, + } + + toolCalls := extractToolCalls(responseNoTools) + if len(toolCalls) != 0 { + t.Error("Should return empty slice for response without tool calls") + } + + // Test response with tool calls + expectedToolCalls := []schemas.ChatAssistantMessageToolCall{ + { + ID: schemas.Ptr("call_123"), + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr("test_tool"), + Arguments: `{"param": "value"}`, + }, + }, + } + + responseWithTools := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: expectedToolCalls, + }, + }, + }, + }, + }, + } + + actualToolCalls := extractToolCalls(responseWithTools) + if len(actualToolCalls) != 1 { + t.Errorf("Expected 1 tool call, got %d", len(actualToolCalls)) + } + + if actualToolCalls[0].Function.Name == nil || *actualToolCalls[0].Function.Name != "test_tool" { + t.Error("Tool call name mismatch") + } +} + +func TestExecuteAgentForChatRequest(t *testing.T) { + // Set up logger for the test + SetLogger(&MockLogger{}) + + // Test with response that has no tool calls - should return immediately + responseNoTools := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("stop"), + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Hello, how can I help you?"), + }, + }, + }, + }, + }, + } + + llmCaller := &MockLLMCaller{} + makeReq := func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + return llmCaller.ChatCompletionRequest(ctx, req) + } + originalReq := &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Hello"), + }, + }, + }, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + + result, err := ExecuteAgentForChatRequest(ctx, 10, originalReq, responseNoTools, makeReq, nil, nil, &MockClientManager{}) + if err != nil { + t.Errorf("Expected no error for response without tool calls, got: %v", err) + } + if result != responseNoTools { + t.Error("Expected same response to be returned for response without tool calls") + } +} + +func TestExecuteAgentForChatRequest_WithNonAutoExecutableTools(t *testing.T) { + // Set up logger for the test + SetLogger(&MockLogger{}) + + // Create a response with tool calls that will NOT be auto-executed + responseWithNonAutoTools := &schemas.BifrostChatResponse{ + Choices: []schemas.BifrostResponseChoice{ + { + FinishReason: schemas.Ptr("tool_calls"), + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("I need to call a tool"), + }, + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + ID: schemas.Ptr("call_123"), + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr("non_auto_executable_tool"), + Arguments: `{"param": "value"}`, + }, + }, + }, + }, + }, + }, + }, + }, + } + + llmCaller := &MockLLMCaller{} + makeReq := func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + return llmCaller.ChatCompletionRequest(ctx, req) + } + originalReq := &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Test message"), + }, + }, + }, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + + // Execute agent mode - should return immediately with non-auto-executable tools + result, err := ExecuteAgentForChatRequest(ctx, 10, originalReq, responseWithNonAutoTools, makeReq, nil, nil, &MockClientManager{}) + + // Should not return error for non-auto-executable tools + if err != nil { + t.Errorf("Expected no error for non-auto-executable tools, got: %v", err) + } + + // Should return a response with the non-auto-executable tool calls + if result == nil { + t.Error("Expected result to be returned for non-auto-executable tools") + } + + // Verify that no LLM calls were made (since tools are non-auto-executable) + if llmCaller.chatCallCount != 0 { + t.Errorf("Expected 0 LLM calls for non-auto-executable tools, got %d", llmCaller.chatCallCount) + } +} + +func TestHasToolCallsForResponsesResponse(t *testing.T) { + // Test nil response + if hasToolCallsForResponsesResponse(nil) { + t.Error("Should return false for nil response") + } + + // Test empty output + emptyResponse := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{}, + } + if hasToolCallsForResponsesResponse(emptyResponse) { + t.Error("Should return false for response with empty output") + } + + // Test response with function call + responseWithFunctionCall := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr("call_123"), + Name: schemas.Ptr("test_tool"), + }, + }, + }, + } + if !hasToolCallsForResponsesResponse(responseWithFunctionCall) { + t.Error("Should return true for response with function call") + } + + // Test response with function call but no ResponsesToolMessage + responseWithoutToolMessage := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + // No ResponsesToolMessage + }, + }, + } + if hasToolCallsForResponsesResponse(responseWithoutToolMessage) { + t.Error("Should return false for response with function call type but no ResponsesToolMessage") + } + + // Test response with regular message + responseWithRegularMessage := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Hello"), + }, + }, + }, + } + if hasToolCallsForResponsesResponse(responseWithRegularMessage) { + t.Error("Should return false for response with regular message") + } +} + +func TestExecuteAgentForResponsesRequest(t *testing.T) { + // Set up logger for the test + SetLogger(&MockLogger{}) + + // Test with response that has no tool calls - should return immediately + responseNoTools := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Hello, how can I help you?"), + }, + }, + }, + } + + llmCaller := &MockLLMCaller{} + makeReq := func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + return llmCaller.ResponsesRequest(ctx, req) + } + originalReq := &schemas.BifrostResponsesRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Hello"), + }, + }, + }, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + + result, err := ExecuteAgentForResponsesRequest(ctx, 10, originalReq, responseNoTools, makeReq, nil, nil, &MockClientManager{}) + if err != nil { + t.Errorf("Expected no error for response without tool calls, got: %v", err) + } + if result != responseNoTools { + t.Error("Expected same response to be returned for response without tool calls") + } +} + +func TestExecuteAgentForResponsesRequest_WithNonAutoExecutableTools(t *testing.T) { + // Set up logger for the test + SetLogger(&MockLogger{}) + + // Create a response with tool calls that will NOT be auto-executed + responseWithNonAutoTools := &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr("call_123"), + Name: schemas.Ptr("non_auto_executable_tool"), + Arguments: schemas.Ptr(`{"param": "value"}`), + }, + }, + }, + } + + llmCaller := &MockLLMCaller{} + makeReq := func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + return llmCaller.ResponsesRequest(ctx, req) + } + originalReq := &schemas.BifrostResponsesRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Test message"), + }, + }, + }, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + + // Execute agent mode - should return immediately with non-auto-executable tools + result, err := ExecuteAgentForResponsesRequest(ctx, 10, originalReq, responseWithNonAutoTools, makeReq, nil, nil, &MockClientManager{}) + + // Should not return error for non-auto-executable tools + if err != nil { + t.Errorf("Expected no error for non-auto-executable tools, got: %v", err) + } + + // Should return a response with the non-auto-executable tool calls + if result == nil { + t.Error("Expected result to be returned for non-auto-executable tools") + } + + // Verify that no LLM calls were made (since tools are non-auto-executable) + if llmCaller.responsesCallCount != 0 { + t.Errorf("Expected 0 LLM calls for non-auto-executable tools, got %d", llmCaller.responsesCallCount) + } +} + +// ============================================================================ +// CONVERTER TESTS (Phase 2) +// ============================================================================ + +// TestResponsesToolMessageToChatAssistantMessageToolCall tests conversion of Responses tool message to Chat tool call +func TestResponsesToolMessageToChatAssistantMessageToolCall(t *testing.T) { + // Test with valid tool message + responsesToolMsg := &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr("call-123"), + Name: schemas.Ptr("calculate"), + Arguments: schemas.Ptr("{\"x\": 10, \"y\": 20}"), + } + + chatToolCall := responsesToolMsg.ToChatAssistantMessageToolCall() + + if chatToolCall == nil { + t.Fatal("Expected non-nil ChatAssistantMessageToolCall") + } + + if chatToolCall.Type == nil || *chatToolCall.Type != "function" { + t.Errorf("Expected Type 'function', got %v", chatToolCall.Type) + } + + if chatToolCall.Function.Name == nil || *chatToolCall.Function.Name != "calculate" { + t.Errorf("Expected Name 'calculate', got %v", chatToolCall.Function.Name) + } + + if chatToolCall.Function.Arguments != `{"x": 10, "y": 20}` { + t.Errorf("Expected Arguments '{\"x\": 10, \"y\": 20}', got %s", chatToolCall.Function.Arguments) + } +} + +// TestResponsesToolMessageToChatAssistantMessageToolCall_Nil tests nil handling +func TestResponsesToolMessageToChatAssistantMessageToolCall_Nil(t *testing.T) { + responsesToolMsg := &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr("call-123"), + Name: schemas.Ptr("calculate"), + Arguments: nil, // Test nil Arguments case + } + + chatToolCall := responsesToolMsg.ToChatAssistantMessageToolCall() + if chatToolCall == nil { + t.Fatal("Expected non-nil ChatAssistantMessageToolCall") + } + + // Assert that nil Arguments produces a valid empty JSON object + if chatToolCall.Function.Arguments != "{}" { + t.Errorf("Expected Arguments '{}' for nil input, got %q", chatToolCall.Function.Arguments) + } + + // Verify it's valid JSON by attempting to unmarshal + var args map[string]interface{} + if err := json.Unmarshal([]byte(chatToolCall.Function.Arguments), &args); err != nil { + t.Errorf("Expected valid JSON, but unmarshaling failed: %v", err) + } +} + +// TestChatMessageToResponsesToolMessage tests conversion of Chat tool result to Responses tool message +func TestChatMessageToResponsesToolMessage(t *testing.T) { + // Test with valid chat tool message + chatMsg := &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: schemas.Ptr("call-123"), + }, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Result: 30"), + }, + } + + responsesMsg := chatMsg.ToResponsesToolMessage() + + if responsesMsg == nil { + t.Fatal("Expected non-nil ResponsesMessage") + } + + if responsesMsg.Type == nil || *responsesMsg.Type != schemas.ResponsesMessageTypeFunctionCallOutput { + t.Errorf("Expected Type 'function_call_output', got %v", responsesMsg.Type) + } + + if responsesMsg.ResponsesToolMessage == nil { + t.Fatal("Expected non-nil ResponsesToolMessage") + } + + if responsesMsg.ResponsesToolMessage.CallID == nil || *responsesMsg.ResponsesToolMessage.CallID != "call-123" { + t.Errorf("Expected CallID 'call-123', got %v", responsesMsg.ResponsesToolMessage.CallID) + } + + if responsesMsg.ResponsesToolMessage.Output == nil { + t.Fatal("Expected non-nil Output") + } + + if responsesMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr == nil { + t.Fatal("Expected non-nil ResponsesToolCallOutputStr") + } + + if *responsesMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != "Result: 30" { + t.Errorf("Expected Output 'Result: 30', got %s", *responsesMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr) + } +} + +// TestChatMessageToResponsesToolMessage_Nil tests nil handling +func TestChatMessageToResponsesToolMessage_Nil(t *testing.T) { + var chatMsg *schemas.ChatMessage + + responsesMsg := chatMsg.ToResponsesToolMessage() + + if responsesMsg != nil { + t.Errorf("Expected nil for nil input, got %v", responsesMsg) + } +} + +// TestChatMessageToResponsesToolMessage_NoToolMessage tests with non-tool message +func TestChatMessageToResponsesToolMessage_NoToolMessage(t *testing.T) { + // Chat message without ChatToolMessage + chatMsg := &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + } + + responsesMsg := chatMsg.ToResponsesToolMessage() + + if responsesMsg != nil { + t.Errorf("Expected nil for non-tool message, got %v", responsesMsg) + } +} + +// ============================================================================ +// RESPONSES API TOOL CONVERSION TESTS (Phase 3) +// ============================================================================ + +// TestExecuteAgentForResponsesRequest_ConversionRoundTrip tests that tool calls survive format conversion +// This is a unit test of the conversion logic only, not full agent execution +func TestExecuteAgentForResponsesRequest_ConversionRoundTrip(t *testing.T) { + // Create a tool message in Responses format + responsesToolMsg := &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr("call-456"), + Name: schemas.Ptr("readToolFile"), + Arguments: schemas.Ptr("{\"file\": \"test.txt\"}"), + } + + // Step 1: Convert Responses format to Chat format + chatToolCall := responsesToolMsg.ToChatAssistantMessageToolCall() + + if chatToolCall == nil { + t.Fatal("Failed to convert Responses to Chat format") + } + + if *chatToolCall.ID != "call-456" { + t.Errorf("ID lost in conversion: expected 'call-456', got %s", *chatToolCall.ID) + } + + if *chatToolCall.Function.Name != "readToolFile" { + t.Errorf("Name lost in conversion: expected 'readToolFile', got %s", *chatToolCall.Function.Name) + } + + if chatToolCall.Function.Arguments != "{\"file\": \"test.txt\"}" { + t.Errorf("Arguments lost in conversion: expected '%s', got %s", + "{\"file\": \"test.txt\"}", chatToolCall.Function.Arguments) + } + + // Step 2: Simulate tool execution by creating a result message + chatResultMsg := &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: chatToolCall.ID, + }, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("File contents here"), + }, + } + + // Step 3: Convert tool result back to Responses format + responsesResultMsg := chatResultMsg.ToResponsesToolMessage() + + if responsesResultMsg == nil { + t.Fatal("Failed to convert Chat result to Responses format") + } + + if responsesResultMsg.ResponsesToolMessage.CallID == nil { + t.Error("CallID lost in round-trip conversion") + } else if *responsesResultMsg.ResponsesToolMessage.CallID != "call-456" { + t.Errorf("CallID changed in round-trip: expected 'call-456', got %s", *responsesResultMsg.ResponsesToolMessage.CallID) + } + + // Verify output is preserved + if responsesResultMsg.ResponsesToolMessage.Output == nil { + t.Error("Output lost in conversion") + } else if responsesResultMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr == nil { + t.Error("Output content lost in conversion") + } else if *responsesResultMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr != "File contents here" { + t.Errorf("Output content changed: expected 'File contents here', got %s", + *responsesResultMsg.ResponsesToolMessage.Output.ResponsesToolCallOutputStr) + } + + // Verify message type is correct + if responsesResultMsg.Type == nil || *responsesResultMsg.Type != schemas.ResponsesMessageTypeFunctionCallOutput { + t.Errorf("Expected message type 'function_call_output', got %v", responsesResultMsg.Type) + } +} + +// TestExecuteAgentForResponsesRequest_OutputStructured tests conversion with structured output blocks +func TestExecuteAgentForResponsesRequest_OutputStructured(t *testing.T) { + chatResultMsg := &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: schemas.Ptr("call-789"), + }, + Content: &schemas.ChatMessageContent{ + ContentBlocks: []schemas.ChatContentBlock{ + { + Type: schemas.ChatContentBlockTypeText, + Text: schemas.Ptr("Block 1"), + }, + { + Type: schemas.ChatContentBlockTypeText, + Text: schemas.Ptr("Block 2"), + }, + }, + }, + } + + responsesMsg := chatResultMsg.ToResponsesToolMessage() + + if responsesMsg == nil { + t.Fatal("Expected non-nil ResponsesMessage for structured output") + } + + if responsesMsg.ResponsesToolMessage.Output == nil { + t.Fatal("Expected non-nil Output for structured content") + } + + if responsesMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks == nil { + t.Error("Expected output blocks for structured content") + } else if len(responsesMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) != 2 { + t.Errorf("Expected 2 output blocks, got %d", len(responsesMsg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks)) + } +} diff --git a/core/mcp/agentadaptors.go b/core/mcp/agentadaptors.go new file mode 100644 index 0000000000..3a32694d3e --- /dev/null +++ b/core/mcp/agentadaptors.go @@ -0,0 +1,563 @@ +package mcp + +import ( + "fmt" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" +) + +// agentAPIAdapter defines the interface for API-specific operations in agent mode. +// This adapter pattern allows the agent execution logic to work with both Chat Completions +// and Responses APIs without requiring API-specific code in the agent loop. +// +// The adapter handles format conversions at the boundaries: +// - Responses API requests/responses are converted to/from Chat API format +// - Tool calls are extracted in Chat format for uniform processing +// - Results are converted back to the original API format for the response +// +// This design ensures that: +// 1. Tool execution logic is format-agnostic +// 2. Both APIs have feature parity +// 3. Conversions are localized to adapters +// 4. The agent loop remains API-neutral +type agentAPIAdapter interface { + // Extract conversation history from the original request + getConversationHistory() []interface{} + + // Get original request + getOriginalRequest() interface{} + + // Get initial response + getInitialResponse() interface{} + + // Check if response has tool calls + hasToolCalls(response interface{}) bool + + // Extract tool calls from response. + // For Chat API: Returns tool calls directly from the response. + // For Responses API: Converts ResponsesMessage tool calls to ChatAssistantMessageToolCall for processing. + extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall + + // Add assistant message with tool calls to conversation + addAssistantMessage(conversation []interface{}, response interface{}) []interface{} + + // Add tool results to conversation. + // For Chat API: Adds ChatMessage results directly. + // For Responses API: Converts ChatMessage results to ResponsesMessage via ToResponsesToolMessage(). + addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{} + + // Create new request with updated conversation + createNewRequest(conversation []interface{}) interface{} + + // Make LLM call + makeLLMCall(ctx *schemas.BifrostContext, request interface{}) (interface{}, *schemas.BifrostError) + + // Create response with executed tools and non-auto-executable calls + createResponseWithExecutedTools( + response interface{}, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, + ) interface{} +} + +// chatAPIAdapter implements agentAPIAdapter for Chat API +type chatAPIAdapter struct { + originalReq *schemas.BifrostChatRequest + initialResponse *schemas.BifrostChatResponse + makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) +} + +// responsesAPIAdapter implements agentAPIAdapter for Responses API. +// It enables the agent mode execution loop to work with Responses API requests and responses +// by handling format conversions transparently. +// +// Key conversions performed: +// - extractToolCalls(): Converts ResponsesMessage tool calls to ChatAssistantMessageToolCall +// via BifrostResponsesResponse.ToBifrostChatResponse() and existing extraction logic +// - addToolResults(): Converts ChatMessage tool results back to ResponsesMessage +// via ChatMessage.ToResponsesMessages() and ToResponsesToolMessage() +// - createNewRequest(): Builds a new BifrostResponsesRequest from converted conversation +// - createResponseWithExecutedTools(): Creates a Responses response with results and pending tools +// +// This adapter enables full feature parity between Chat Completions and Responses APIs +// for tool execution in agent mode. +type responsesAPIAdapter struct { + originalReq *schemas.BifrostResponsesRequest + initialResponse *schemas.BifrostResponsesResponse + makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) +} + +// Chat API adapter implementations +func (c *chatAPIAdapter) getConversationHistory() []interface{} { + history := make([]interface{}, 0) + if c.originalReq.Input != nil { + for _, msg := range c.originalReq.Input { + history = append(history, msg) + } + } + return history +} + +func (c *chatAPIAdapter) getOriginalRequest() interface{} { + return c.originalReq +} + +func (c *chatAPIAdapter) getInitialResponse() interface{} { + return c.initialResponse +} + +func (c *chatAPIAdapter) hasToolCalls(response interface{}) bool { + chatResponse := response.(*schemas.BifrostChatResponse) + return hasToolCallsForChatResponse(chatResponse) +} + +func (c *chatAPIAdapter) extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall { + chatResponse := response.(*schemas.BifrostChatResponse) + return extractToolCalls(chatResponse) +} + +func (c *chatAPIAdapter) addAssistantMessage(conversation []interface{}, response interface{}) []interface{} { + chatResponse := response.(*schemas.BifrostChatResponse) + for _, choice := range chatResponse.Choices { + if choice.ChatNonStreamResponseChoice != nil && choice.ChatNonStreamResponseChoice.Message != nil { + conversation = append(conversation, *choice.ChatNonStreamResponseChoice.Message) + } + } + return conversation +} + +func (c *chatAPIAdapter) addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{} { + for _, toolResult := range toolResults { + conversation = append(conversation, *toolResult) + } + return conversation +} + +func (c *chatAPIAdapter) createNewRequest(conversation []interface{}) interface{} { + // Convert conversation back to ChatMessage slice + chatMessages := make([]schemas.ChatMessage, 0, len(conversation)) + for _, msg := range conversation { + if msg == nil { + continue + } + if chatMessage, ok := msg.(schemas.ChatMessage); ok { + chatMessages = append(chatMessages, chatMessage) + } + } + + return &schemas.BifrostChatRequest{ + Provider: c.originalReq.Provider, + Model: c.originalReq.Model, + Fallbacks: c.originalReq.Fallbacks, + Params: c.originalReq.Params, + Input: chatMessages, + } +} + +func (c *chatAPIAdapter) makeLLMCall(ctx *schemas.BifrostContext, request interface{}) (interface{}, *schemas.BifrostError) { + chatRequest := request.(*schemas.BifrostChatRequest) + return c.makeReq(ctx, chatRequest) +} + +func (c *chatAPIAdapter) createResponseWithExecutedTools( + response interface{}, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, +) interface{} { + chatResponse := response.(*schemas.BifrostChatResponse) + return createChatResponseWithExecutedToolsAndNonAutoExecutableCalls( + chatResponse, + executedToolResults, + executedToolCalls, + nonAutoExecutableToolCalls, + ) +} + +// createChatResponseWithExecutedToolsAndNonAutoExecutableCalls creates a chat response +// that includes executed tool results and non-auto-executable tool calls. The response +// contains a formatted text summary of executed tool results and includes the non-auto-executable +// tool calls for the caller to handle. The finish reason is set to "stop" to prevent +// further agent loop iterations. +// +// Parameters: +// - originalResponse: The original chat response to copy metadata from +// - executedToolResults: List of tool execution results from auto-executable tools +// - executedToolCalls: List of tool calls that were executed +// - nonAutoExecutableToolCalls: List of tool calls that require manual execution +// +// Returns: +// - *schemas.BifrostChatResponse: A new chat response with executed results and pending tool calls +func createChatResponseWithExecutedToolsAndNonAutoExecutableCalls( + originalResponse *schemas.BifrostChatResponse, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, +) *schemas.BifrostChatResponse { + // Start with a copy of the original response metadata + response := &schemas.BifrostChatResponse{ + ID: originalResponse.ID, + Object: originalResponse.Object, + Created: originalResponse.Created, + Model: originalResponse.Model, + Choices: make([]schemas.BifrostResponseChoice, 0), + ServiceTier: originalResponse.ServiceTier, + SystemFingerprint: originalResponse.SystemFingerprint, + Usage: originalResponse.Usage, + ExtraFields: originalResponse.ExtraFields, + SearchResults: originalResponse.SearchResults, + Videos: originalResponse.Videos, + Citations: originalResponse.Citations, + } + + // Build a map from tool call ID to tool name for easy lookup + toolCallIDToName := make(map[string]string) + for _, toolCall := range executedToolCalls { + if toolCall.ID != nil && toolCall.Function.Name != nil { + toolCallIDToName[*toolCall.ID] = *toolCall.Function.Name + } + } + + // Build content text showing executed tool results + var contentText string + if len(executedToolResults) > 0 { + // Format tool results as JSON-like structure + toolResultsMap := make(map[string]interface{}) + for _, toolResult := range executedToolResults { + // Get tool name from tool call ID mapping + var toolName string + if toolResult.ChatToolMessage != nil && toolResult.ChatToolMessage.ToolCallID != nil { + toolCallID := *toolResult.ChatToolMessage.ToolCallID + if name, ok := toolCallIDToName[toolCallID]; ok { + toolName = name + } else { + toolName = toolCallID // Fallback to tool call ID if name not found + } + } else { + toolName = "unknown_tool" + } + + // Extract output from tool result + var output interface{} + if toolResult.Content != nil { + if toolResult.Content.ContentStr != nil { + output = *toolResult.Content.ContentStr + } else if toolResult.Content.ContentBlocks != nil { + // Convert content blocks to a readable format + blocks := make([]map[string]interface{}, 0) + for _, block := range toolResult.Content.ContentBlocks { + blockMap := make(map[string]interface{}) + blockMap["type"] = string(block.Type) + if block.Text != nil { + blockMap["text"] = *block.Text + } + blocks = append(blocks, blockMap) + } + output = blocks + } + } + toolResultsMap[toolName] = output + } + + // Convert to JSON string for display + jsonBytes, err := sonic.Marshal(toolResultsMap) + if err != nil { + // Fallback to simple string representation + contentText = fmt.Sprintf("The Output from allowed tools calls is - %v\n\nNow I shall call these tools next...", toolResultsMap) + } else { + contentText = fmt.Sprintf("The Output from allowed tools calls is - %s\n\nNow I shall call these tools next...", string(jsonBytes)) + } + } else { + contentText = "Now I shall call these tools next..." + } + + // Create content with the formatted text + content := &schemas.ChatMessageContent{ + ContentStr: &contentText, + } + + // Determine finish reason + // Note: We set finish_reason to "stop" (not "tool_calls") for non-auto-executable tools + // to prevent the agent loop from retrying. The tool calls are still included in the response + // for the caller to handle, but setting finish_reason to "stop" ensures hasToolCalls returns false + // and the agent loop exits properly. + finishReason := "stop" + + // Create a single choice with the formatted content and non-auto-executable tool calls + response.Choices = append(response.Choices, schemas.BifrostResponseChoice{ + Index: 0, + FinishReason: &finishReason, + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: content, + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: nonAutoExecutableToolCalls, + }, + }, + }, + }) + + return response +} + +// Responses API adapter implementations +func (r *responsesAPIAdapter) getConversationHistory() []interface{} { + history := make([]interface{}, 0) + if r.originalReq.Input != nil { + for _, msg := range r.originalReq.Input { + history = append(history, msg) + } + } + return history +} + +func (r *responsesAPIAdapter) getOriginalRequest() interface{} { + return r.originalReq +} + +func (r *responsesAPIAdapter) getInitialResponse() interface{} { + return r.initialResponse +} + +func (r *responsesAPIAdapter) hasToolCalls(response interface{}) bool { + responsesResponse := response.(*schemas.BifrostResponsesResponse) + return hasToolCallsForResponsesResponse(responsesResponse) +} + +func (r *responsesAPIAdapter) extractToolCalls(response interface{}) []schemas.ChatAssistantMessageToolCall { + responsesResponse := response.(*schemas.BifrostResponsesResponse) + // Convert to Chat format and extract tool calls using existing logic + chatResponse := responsesResponse.ToBifrostChatResponse() + return extractToolCalls(chatResponse) +} + +func (r *responsesAPIAdapter) addAssistantMessage(conversation []interface{}, response interface{}) []interface{} { + responsesResponse := response.(*schemas.BifrostResponsesResponse) + for _, output := range responsesResponse.Output { + conversation = append(conversation, output) + } + return conversation +} + +func (r *responsesAPIAdapter) addToolResults(conversation []interface{}, toolResults []*schemas.ChatMessage) []interface{} { + for _, toolResult := range toolResults { + // Convert using existing converter + responsesMessages := toolResult.ToResponsesMessages() + for _, respMsg := range responsesMessages { + conversation = append(conversation, respMsg) + } + } + return conversation +} + +func (r *responsesAPIAdapter) createNewRequest(conversation []interface{}) interface{} { + // Convert conversation back to ResponsesMessage slice + responsesMessages := make([]schemas.ResponsesMessage, 0, len(conversation)) + for _, msg := range conversation { + responsesMessages = append(responsesMessages, msg.(schemas.ResponsesMessage)) + } + + return &schemas.BifrostResponsesRequest{ + Provider: r.originalReq.Provider, + Model: r.originalReq.Model, + Fallbacks: r.originalReq.Fallbacks, + Params: r.originalReq.Params, + Input: responsesMessages, + } +} + +func (r *responsesAPIAdapter) makeLLMCall(ctx *schemas.BifrostContext, request interface{}) (interface{}, *schemas.BifrostError) { + responsesRequest := request.(*schemas.BifrostResponsesRequest) + return r.makeReq(ctx, responsesRequest) +} + +func (r *responsesAPIAdapter) createResponseWithExecutedTools( + response interface{}, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, +) interface{} { + responsesResponse := response.(*schemas.BifrostResponsesResponse) + + // Create response with executed tools directly on Responses schema + return createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls( + responsesResponse, + executedToolResults, + executedToolCalls, + nonAutoExecutableToolCalls, + ) +} + +// createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls creates a responses response +// that includes executed tool results and non-auto-executable tool calls. The response +// contains a formatted text summary of executed tool results and includes the non-auto-executable +// tool calls for the caller to handle. All Response-specific fields are preserved. +// +// Parameters: +// - originalResponse: The original responses response to copy metadata from +// - executedToolResults: List of tool execution results from auto-executable tools +// - executedToolCalls: List of tool calls that were executed +// - nonAutoExecutableToolCalls: List of tool calls that require manual execution +// +// Returns: +// - *schemas.BifrostResponsesResponse: A new responses response with executed results and pending tool calls +func createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls( + originalResponse *schemas.BifrostResponsesResponse, + executedToolResults []*schemas.ChatMessage, + executedToolCalls []schemas.ChatAssistantMessageToolCall, + nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, +) *schemas.BifrostResponsesResponse { + // Start with a copy of the original response, preserving all Response-specific fields + response := &schemas.BifrostResponsesResponse{ + ID: originalResponse.ID, + Background: originalResponse.Background, + Conversation: originalResponse.Conversation, + CreatedAt: originalResponse.CreatedAt, + Error: originalResponse.Error, + Include: originalResponse.Include, + IncompleteDetails: originalResponse.IncompleteDetails, + Instructions: originalResponse.Instructions, + MaxOutputTokens: originalResponse.MaxOutputTokens, + MaxToolCalls: originalResponse.MaxToolCalls, + Metadata: originalResponse.Metadata, + ParallelToolCalls: originalResponse.ParallelToolCalls, + PreviousResponseID: originalResponse.PreviousResponseID, + Prompt: originalResponse.Prompt, + PromptCacheKey: originalResponse.PromptCacheKey, + Reasoning: originalResponse.Reasoning, + SafetyIdentifier: originalResponse.SafetyIdentifier, + ServiceTier: originalResponse.ServiceTier, + StreamOptions: originalResponse.StreamOptions, + Store: originalResponse.Store, + Temperature: originalResponse.Temperature, + Text: originalResponse.Text, + TopLogProbs: originalResponse.TopLogProbs, + TopP: originalResponse.TopP, + ToolChoice: originalResponse.ToolChoice, + Tools: originalResponse.Tools, + Truncation: originalResponse.Truncation, + Usage: originalResponse.Usage, + ExtraFields: originalResponse.ExtraFields, + // Perplexity-specific fields + SearchResults: originalResponse.SearchResults, + Videos: originalResponse.Videos, + Citations: originalResponse.Citations, + Output: make([]schemas.ResponsesMessage, 0), + } + + // Build a map from tool call ID to tool name for easy lookup + toolCallIDToName := make(map[string]string) + for _, toolCall := range executedToolCalls { + if toolCall.ID != nil && toolCall.Function.Name != nil { + toolCallIDToName[*toolCall.ID] = *toolCall.Function.Name + } + } + + // Build content text showing executed tool results + var contentText string + if len(executedToolResults) > 0 { + // Format tool results as JSON-like structure + toolResultsMap := make(map[string]interface{}) + for _, toolResult := range executedToolResults { + // Get tool name from tool call ID mapping + var toolName string + if toolResult.ChatToolMessage != nil && toolResult.ChatToolMessage.ToolCallID != nil { + toolCallID := *toolResult.ChatToolMessage.ToolCallID + if name, ok := toolCallIDToName[toolCallID]; ok { + toolName = name + } else { + toolName = toolCallID // Fallback to tool call ID if name not found + } + } else { + toolName = "unknown_tool" + } + + // Extract output from tool result + var output interface{} + if toolResult.Content != nil { + if toolResult.Content.ContentStr != nil { + output = *toolResult.Content.ContentStr + } else if toolResult.Content.ContentBlocks != nil { + // Convert content blocks to a readable format + blocks := make([]map[string]interface{}, 0) + for _, block := range toolResult.Content.ContentBlocks { + blockMap := make(map[string]interface{}) + blockMap["type"] = string(block.Type) + if block.Text != nil { + blockMap["text"] = *block.Text + } + blocks = append(blocks, blockMap) + } + output = blocks + } + } + toolResultsMap[toolName] = output + } + + // Convert to JSON string for display + jsonBytes, err := sonic.Marshal(toolResultsMap) + if err != nil { + // Fallback to simple string representation + contentText = fmt.Sprintf("The Output from allowed tools calls is - %v\n\nNow I shall call these tools next...", toolResultsMap) + } else { + contentText = fmt.Sprintf("The Output from allowed tools calls is - %s\n\nNow I shall call these tools next...", string(jsonBytes)) + } + } else { + contentText = "Now I shall call these tools next..." + } + + // Create assistant message with the formatted text content + messageType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant + assistantMessage := schemas.ResponsesMessage{ + Type: &messageType, + Role: &role, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesOutputMessageContentTypeText, + Text: &contentText, + }, + }, + }, + } + response.Output = append(response.Output, assistantMessage) + + // Add non-auto-executable tool calls as separate function_call messages + for _, toolCall := range nonAutoExecutableToolCalls { + functionCallType := schemas.ResponsesMessageTypeFunctionCall + assistantRole := schemas.ResponsesInputMessageRoleAssistant + + var callID *string + if toolCall.ID != nil && *toolCall.ID != "" { + callID = toolCall.ID + } + + var namePtr *string + if toolCall.Function.Name != nil && *toolCall.Function.Name != "" { + namePtr = toolCall.Function.Name + } + + var argumentsPtr *string + if toolCall.Function.Arguments != "" { + argumentsPtr = &toolCall.Function.Arguments + } + + toolCallMessage := schemas.ResponsesMessage{ + Type: &functionCallType, + Role: &assistantRole, + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: callID, + Name: namePtr, + Arguments: argumentsPtr, + }, + } + + response.Output = append(response.Output, toolCallMessage) + } + + return response +} diff --git a/core/mcp/clientmanager.go b/core/mcp/clientmanager.go new file mode 100644 index 0000000000..fe7392abbe --- /dev/null +++ b/core/mcp/clientmanager.go @@ -0,0 +1,700 @@ +package mcp + +import ( + "context" + "fmt" + "maps" + "os" + "strings" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/maximhq/bifrost/core/schemas" +) + +// GetClients returns all MCP clients managed by the manager. +// +// Returns: +// - []*schemas.MCPClientState: List of all MCP clients +func (m *MCPManager) GetClients() []schemas.MCPClientState { + m.mu.RLock() + defer m.mu.RUnlock() + + clients := make([]schemas.MCPClientState, 0, len(m.clientMap)) + for _, client := range m.clientMap { + snapshot := *client + if client.ToolMap != nil { + snapshot.ToolMap = make(map[string]schemas.ChatTool, len(client.ToolMap)) + maps.Copy(snapshot.ToolMap, client.ToolMap) + } + clients = append(clients, snapshot) + } + + return clients +} + +// ReconnectClient attempts to reconnect an MCP client if it is disconnected. +// It validates that the client exists and then establishes a new connection using +// the client's existing configuration. +// +// Parameters: +// - id: ID of the client to reconnect +// +// Returns: +// - error: Any error that occurred during reconnection +func (m *MCPManager) ReconnectClient(id string) error { + m.mu.Lock() + client, ok := m.clientMap[id] + if !ok { + m.mu.Unlock() + return fmt.Errorf("client %s not found", id) + } + config := client.ExecutionConfig + m.mu.Unlock() + + // connectToMCPClient handles locking internally + err := m.connectToMCPClient(config) + if err != nil { + return fmt.Errorf("failed to connect to MCP client %s: %w", id, err) + } + + return nil +} + +// AddClient adds a new MCP client to the manager. +// It validates the client configuration and establishes a connection. +// If connection fails, the client entry is automatically cleaned up. +// +// Parameters: +// - config: MCP client configuration +// +// Returns: +// - error: Any error that occurred during client addition or connection +func (m *MCPManager) AddClient(config schemas.MCPClientConfig) error { + if err := validateMCPClientConfig(&config); err != nil { + return fmt.Errorf("invalid MCP client configuration: %w", err) + } + + // Make a copy of the config to use after unlocking + configCopy := config + + m.mu.Lock() + + if _, ok := m.clientMap[config.ID]; ok { + m.mu.Unlock() + return fmt.Errorf("client %s already exists", config.Name) + } + + // Create placeholder entry + m.clientMap[config.ID] = &schemas.MCPClientState{ + ExecutionConfig: config, + ToolMap: make(map[string]schemas.ChatTool), + } + + // Temporarily unlock for the connection attempt + // This is to avoid deadlocks when the connection attempt is made + m.mu.Unlock() + + // Connect using the copied config + if err := m.connectToMCPClient(configCopy); err != nil { + // Re-lock to clean up the failed entry + m.mu.Lock() + delete(m.clientMap, config.ID) + m.mu.Unlock() + return fmt.Errorf("failed to connect to MCP client %s: %w", config.Name, err) + } + + return nil +} + +// RemoveClient removes an MCP client from the manager. +// It handles cleanup for all transport types (HTTP, STDIO, SSE). +// +// Parameters: +// - id: ID of the client to remove +func (m *MCPManager) RemoveClient(id string) error { + m.mu.Lock() + defer m.mu.Unlock() + + return m.removeClientUnsafe(id) +} + +// removeClientUnsafe removes an MCP client from the manager without acquiring locks. +// This is an internal method that should only be called when the caller already holds +// the appropriate lock. It handles cleanup for all transport types including cancellation +// of SSE contexts and closing of transport connections. +// +// Parameters: +// - id: ID of the client to remove +// +// Returns: +// - error: Any error that occurred during client removal +func (m *MCPManager) removeClientUnsafe(id string) error { + client, ok := m.clientMap[id] + if !ok { + return fmt.Errorf("client %s not found", id) + } + + logger.Info(fmt.Sprintf("%s Disconnecting MCP server '%s'", MCPLogPrefix, client.ExecutionConfig.Name)) + + // Stop health monitoring for this client + m.healthMonitorManager.StopMonitoring(id) + + // Cancel SSE context if present (required for proper SSE cleanup) + if client.CancelFunc != nil { + client.CancelFunc() + client.CancelFunc = nil + } + + // Close the client transport connection + // This handles cleanup for all transport types (HTTP, STDIO, SSE) + if client.Conn != nil { + if err := client.Conn.Close(); err != nil { + logger.Error("%s Failed to close MCP server '%s': %v", MCPLogPrefix, client.ExecutionConfig.Name, err) + } + client.Conn = nil + } + + // Clear client tool map + client.ToolMap = make(map[string]schemas.ChatTool) + + delete(m.clientMap, id) + return nil +} + +// EditClient updates an existing MCP client's configuration and refreshes its tool list. +// It updates the client's execution config with new settings and retrieves updated tools +// from the MCP server if the client is connected. +// This method does not refresh the client's tool list. +// To refresh the client's tool list, use the ReconnectClient method. +// +// Parameters: +// - id: ID of the client to edit +// - updatedConfig: Updated client configuration with new settings +// +// Returns: +// - error: Any error that occurred during client update or tool retrieval +func (m *MCPManager) EditClient(id string, updatedConfig schemas.MCPClientConfig) error { + m.mu.Lock() + defer m.mu.Unlock() + + client, ok := m.clientMap[id] + if !ok { + return fmt.Errorf("client %s not found", id) + } + + if err := validateMCPClientName(updatedConfig.Name); err != nil { + return fmt.Errorf("invalid MCP client configuration: %w", err) + } + + // Update the client's execution config with new tool filters + config := client.ExecutionConfig + config.Name = updatedConfig.Name + config.IsCodeModeClient = updatedConfig.IsCodeModeClient + config.Headers = updatedConfig.Headers + config.ToolsToExecute = updatedConfig.ToolsToExecute + config.ToolsToAutoExecute = updatedConfig.ToolsToAutoExecute + + // Store the updated config + client.ExecutionConfig = config + return nil +} + +// registerTool registers a typed tool handler with the local MCP server. +// This is a convenience function that handles the conversion between typed Go +// handlers and the MCP protocol. +// +// Type Parameters: +// - T: The expected argument type for the tool (must be JSON-deserializable) +// +// Parameters: +// - name: Unique tool name +// - description: Human-readable tool description +// - handler: Typed function that handles tool execution +// - toolSchema: Bifrost tool schema for function calling +// +// Returns: +// - error: Any registration error +// +// Example: +// +// type EchoArgs struct { +// Message string `json:"message"` +// } +// +// err := bifrost.RegisterMCPTool("echo", "Echo a message", +// func(args EchoArgs) (string, error) { +// return args.Message, nil +// }, toolSchema) +func (m *MCPManager) RegisterTool(name, description string, toolFunction MCPToolFunction[any], toolSchema schemas.ChatTool) error { + // Ensure local server is set up + if err := m.setupLocalHost(); err != nil { + return fmt.Errorf("failed to setup local host: %w", err) + } + + // Validate tool name + if strings.TrimSpace(name) == "" { + return fmt.Errorf("tool name is required") + } + if strings.Contains(name, "-") { + return fmt.Errorf("tool name cannot contain hyphens") + } + if strings.Contains(name, " ") { + return fmt.Errorf("tool name cannot contain spaces") + } + if len(name) > 0 && name[0] >= '0' && name[0] <= '9' { + return fmt.Errorf("tool name cannot start with a number") + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Verify internal client exists + internalClient, ok := m.clientMap[BifrostMCPClientKey] + if !ok { + return fmt.Errorf("bifrost client not found") + } + + // Check if tool name already exists to prevent silent overwrites + if _, exists := internalClient.ToolMap[name]; exists { + return fmt.Errorf("tool '%s' is already registered", name) + } + + logger.Info(fmt.Sprintf("%s Registering typed tool: %s", MCPLogPrefix, name)) + + // Create MCP handler wrapper that converts between typed and MCP interfaces + mcpHandler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract arguments from the request using the request's methods + args := request.GetArguments() + result, err := toolFunction(args) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Error: %s", err.Error())), nil + } + return mcp.NewToolResultText(result), nil + } + + // Register the tool with the local MCP server using AddTool + if m.server != nil { + tool := mcp.NewTool(name, mcp.WithDescription(description)) + m.server.AddTool(tool, mcpHandler) + } + + // Store tool definition for Bifrost integration + internalClient.ToolMap[name] = toolSchema + + return nil +} + +// ============================================================================ +// CONNECTION HELPER METHODS +// ============================================================================ + +// connectToMCPClient establishes a connection to an external MCP server and +// registers its available tools with the manager. +func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { + // First lock: Initialize or validate client entry + m.mu.Lock() + + // Initialize or validate client entry + if existingClient, exists := m.clientMap[config.ID]; exists { + // Client entry exists from config, check for existing connection, if it does then close + if existingClient.CancelFunc != nil { + existingClient.CancelFunc() + existingClient.CancelFunc = nil + } + if existingClient.Conn != nil { + existingClient.Conn.Close() + } + // Update connection type for this connection attempt + existingClient.ConnectionInfo.Type = config.ConnectionType + } + // Create new client entry with configuration + m.clientMap[config.ID] = &schemas.MCPClientState{ + ExecutionConfig: config, + ToolMap: make(map[string]schemas.ChatTool), + ConnectionInfo: schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + }, + } + m.mu.Unlock() + + // Heavy operations performed outside lock + var externalClient *client.Client + var connectionInfo schemas.MCPClientConnectionInfo + var err error + + // Create appropriate transport based on connection type + switch config.ConnectionType { + case schemas.MCPConnectionTypeHTTP: + externalClient, connectionInfo, err = m.createHTTPConnection(config) + case schemas.MCPConnectionTypeSTDIO: + externalClient, connectionInfo, err = m.createSTDIOConnection(config) + case schemas.MCPConnectionTypeSSE: + externalClient, connectionInfo, err = m.createSSEConnection(config) + case schemas.MCPConnectionTypeInProcess: + externalClient, connectionInfo, err = m.createInProcessConnection(config) + default: + return fmt.Errorf("unknown connection type: %s", config.ConnectionType) + } + + if err != nil { + return fmt.Errorf("failed to create connection: %w", err) + } + + // Initialize the external client with timeout + // For SSE connections, we need a long-lived context, for others we can use timeout + var ctx context.Context + var cancel context.CancelFunc + + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + // SSE connections need a long-lived context for the persistent stream + ctx, cancel = context.WithCancel(m.ctx) + // Don't defer cancel here - SSE needs the context to remain active + } else { + // Other connection types can use timeout context + ctx, cancel = context.WithTimeout(m.ctx, MCPClientConnectionEstablishTimeout) + defer cancel() + } + + // Start the transport first (required for STDIO and SSE clients) + if err := externalClient.Start(ctx); err != nil { + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + cancel() // Cancel SSE context only on error + } + return fmt.Errorf("failed to start MCP client transport %s: %v", config.Name, err) + } + + // Create proper initialize request for external client + extInitRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: fmt.Sprintf("Bifrost-%s", config.Name), + Version: "1.0.0", + }, + }, + } + + _, err = externalClient.Initialize(ctx, extInitRequest) + if err != nil { + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + cancel() // Cancel SSE context only on error + } + return fmt.Errorf("failed to initialize MCP client %s: %v", config.Name, err) + } + + // Retrieve tools from the external server (this also requires network I/O) + tools, err := retrieveExternalTools(ctx, externalClient, config.Name) + if err != nil { + logger.Warn(fmt.Sprintf("%s Failed to retrieve tools from %s: %v", MCPLogPrefix, config.Name, err)) + // Continue with connection even if tool retrieval fails + tools = make(map[string]schemas.ChatTool) + } + + // Second lock: Update client with final connection details and tools + m.mu.Lock() + defer m.mu.Unlock() + + // Verify client still exists (could have been cleaned up during heavy operations) + if client, exists := m.clientMap[config.ID]; exists { + // Store the external client connection and details + client.Conn = externalClient + client.ConnectionInfo = connectionInfo + client.State = schemas.MCPConnectionStateConnected + + // Store cancel function for SSE connections to enable proper cleanup + if config.ConnectionType == schemas.MCPConnectionTypeSSE { + client.CancelFunc = cancel + } + + // Store discovered tools + for toolName, tool := range tools { + client.ToolMap[toolName] = tool + } + + logger.Info(fmt.Sprintf("%s Connected to MCP server '%s'", MCPLogPrefix, config.Name)) + } else { + // Clean up resources before returning error: client was removed during connection setup + // Cancel SSE context if it was created + if config.ConnectionType == schemas.MCPConnectionTypeSSE && cancel != nil { + cancel() + } + // Close external client connection to prevent transport/goroutine leaks + if externalClient != nil { + if err := externalClient.Close(); err != nil { + logger.Warn(fmt.Sprintf("%s Failed to close external client during cleanup: %v", MCPLogPrefix, err)) + } + } + return fmt.Errorf("client %s was removed during connection setup", config.Name) + } + + // Register OnConnectionLost hook for SSE connections to detect idle timeouts + if config.ConnectionType == schemas.MCPConnectionTypeSSE && externalClient != nil { + externalClient.OnConnectionLost(func(err error) { + logger.Warn(fmt.Sprintf("%s SSE connection lost for MCP server '%s': %v", MCPLogPrefix, config.Name, err)) + // Update state to disconnected + m.mu.Lock() + if client, exists := m.clientMap[config.ID]; exists { + client.State = schemas.MCPConnectionStateDisconnected + } + m.mu.Unlock() + }) + } + + // Start health monitoring for the client + monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval) + m.healthMonitorManager.StartMonitoring(monitor) + + return nil +} + +// createHTTPConnection creates an HTTP-based MCP client connection without holding locks. +func (m *MCPManager) createHTTPConnection(config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { + if config.ConnectionString == nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("HTTP connection string is required") + } + + // Prepare connection info + connectionInfo := schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + ConnectionURL: config.ConnectionString, + } + + // Create StreamableHTTP transport + httpTransport, err := transport.NewStreamableHTTP(*config.ConnectionString, transport.WithHTTPHeaders(config.Headers)) + if err != nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create HTTP transport: %w", err) + } + + client := client.NewClient(httpTransport) + + return client, connectionInfo, nil +} + +// createSTDIOConnection creates a STDIO-based MCP client connection without holding locks. +func (m *MCPManager) createSTDIOConnection(config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { + if config.StdioConfig == nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("stdio config is required") + } + + // Prepare STDIO command info for display + cmdString := fmt.Sprintf("%s %s", config.StdioConfig.Command, strings.Join(config.StdioConfig.Args, " ")) + + // Check if environment variables are set + for _, env := range config.StdioConfig.Envs { + if os.Getenv(env) == "" { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("environment variable %s is not set for MCP client %s", env, config.Name) + } + } + + // Create STDIO transport + stdioTransport := transport.NewStdio( + config.StdioConfig.Command, + config.StdioConfig.Envs, + config.StdioConfig.Args..., + ) + + // Prepare connection info + connectionInfo := schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + StdioCommandString: &cmdString, + } + + client := client.NewClient(stdioTransport) + + // Return nil for cmd since mark3labs/mcp-go manages the process internally + return client, connectionInfo, nil +} + +// createSSEConnection creates a SSE-based MCP client connection without holding locks. +func (m *MCPManager) createSSEConnection(config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { + if config.ConnectionString == nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("SSE connection string is required") + } + + // Prepare connection info + connectionInfo := schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + ConnectionURL: config.ConnectionString, // Reuse HTTPConnectionURL field for SSE URL display + } + + // Create SSE transport + sseTransport, err := transport.NewSSE(*config.ConnectionString, transport.WithHeaders(config.Headers)) + if err != nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create SSE transport: %w", err) + } + + client := client.NewClient(sseTransport) + + return client, connectionInfo, nil +} + +// createInProcessConnection creates an in-process MCP client connection without holding locks. +// This allows direct connection to an MCP server running in the same process, providing +// the lowest latency and highest performance for tool execution. +func (m *MCPManager) createInProcessConnection(config schemas.MCPClientConfig) (*client.Client, schemas.MCPClientConnectionInfo, error) { + if config.InProcessServer == nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("InProcess connection requires a server instance") + } + + // Create in-process client directly connected to the provided server + inProcessClient, err := client.NewInProcessClient(config.InProcessServer) + if err != nil { + return nil, schemas.MCPClientConnectionInfo{}, fmt.Errorf("failed to create in-process client: %w", err) + } + + // Prepare connection info + connectionInfo := schemas.MCPClientConnectionInfo{ + Type: config.ConnectionType, + } + + return inProcessClient, connectionInfo, nil +} + +// ============================================================================ +// LOCAL MCP SERVER AND CLIENT MANAGEMENT +// ============================================================================ + +// setupLocalHost initializes the local MCP server and client if not already running. +// This creates a STDIO-based server for local tool hosting and a corresponding client. +// This is called automatically when tools are registered or when the server is needed. +// +// Returns: +// - error: Any setup error +func (m *MCPManager) setupLocalHost() error { + // First check: fast path if already initialized + m.mu.Lock() + if m.server != nil && m.serverRunning { + m.mu.Unlock() + return nil + } + m.mu.Unlock() + + // Create server and client into local variables (outside lock to avoid + // holding lock during object creation, even though it's lightweight) + server, err := m.createLocalMCPServer() + if err != nil { + return fmt.Errorf("failed to create local MCP server: %w", err) + } + + client, err := m.createLocalMCPClient() + if err != nil { + return fmt.Errorf("failed to create local MCP client: %w", err) + } + + // Second check and assignment: hold lock for atomic check-and-set + m.mu.Lock() + // Double-check: another goroutine might have initialized while we were creating + if m.server != nil && m.serverRunning { + m.mu.Unlock() + return nil + } + + // Assign server and client atomically while holding the lock + m.server = server + m.clientMap[BifrostMCPClientKey] = client + m.mu.Unlock() + + // Start the server and initialize client connection + // (startLocalMCPServer already locks internally) + return m.startLocalMCPServer() +} + +// createLocalMCPServer creates a new local MCP server instance with STDIO transport. +// This server will host tools registered via RegisterTool function. +// +// Returns: +// - *server.MCPServer: Configured MCP server instance +// - error: Any creation error +func (m *MCPManager) createLocalMCPServer() (*server.MCPServer, error) { + // Create MCP server + mcpServer := server.NewMCPServer( + "Bifrost-MCP-Server", + "1.0.0", + server.WithToolCapabilities(true), + ) + + return mcpServer, nil +} + +// createLocalMCPClient creates a placeholder client entry for the local MCP server. +// The actual in-process client connection will be established in startLocalMCPServer. +// +// Returns: +// - *schemas.MCPClientState: Placeholder client for local server +// - error: Any creation error +func (m *MCPManager) createLocalMCPClient() (*schemas.MCPClientState, error) { + // Don't create the actual client connection here - it will be created + // after the server is ready using NewInProcessClient + return &schemas.MCPClientState{ + ExecutionConfig: schemas.MCPClientConfig{ + ID: BifrostMCPClientKey, + Name: BifrostMCPClientName, + ToolsToExecute: []string{"*"}, // Allow all tools for internal client + }, + ToolMap: make(map[string]schemas.ChatTool), + ConnectionInfo: schemas.MCPClientConnectionInfo{ + Type: schemas.MCPConnectionTypeInProcess, // Accurate: in-process (in-memory) transport + }, + }, nil +} + +// startLocalMCPServer creates an in-process connection between the local server and client. +// +// Returns: +// - error: Any startup error +func (m *MCPManager) startLocalMCPServer() error { + m.mu.Lock() + defer m.mu.Unlock() + + // Check if server is already running + if m.server != nil && m.serverRunning { + return nil + } + + if m.server == nil { + return fmt.Errorf("server not initialized") + } + + // Create in-process client directly connected to the server + inProcessClient, err := client.NewInProcessClient(m.server) + if err != nil { + return fmt.Errorf("failed to create in-process MCP client: %w", err) + } + + // Update the client connection + clientEntry, ok := m.clientMap[BifrostMCPClientKey] + if !ok { + return fmt.Errorf("bifrost client not found") + } + clientEntry.Conn = inProcessClient + + // Initialize the in-process client + ctx, cancel := context.WithTimeout(m.ctx, MCPClientConnectionEstablishTimeout) + defer cancel() + + // Create proper initialize request with correct structure + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: BifrostMCPClientName, + Version: BifrostMCPVersion, + }, + }, + } + + _, err = inProcessClient.Initialize(ctx, initRequest) + if err != nil { + return fmt.Errorf("failed to initialize MCP client: %w", err) + } + + // Mark server as running + m.serverRunning = true + + return nil +} diff --git a/core/mcp/codemodeexecutecode.go b/core/mcp/codemodeexecutecode.go new file mode 100644 index 0000000000..2019339203 --- /dev/null +++ b/core/mcp/codemodeexecutecode.go @@ -0,0 +1,1035 @@ +package mcp + +import ( + "context" + "fmt" + "regexp" + "strings" + "time" + + "github.com/bytedance/sonic" + "github.com/clarkmcc/go-typescript" + "github.com/dop251/goja" + "github.com/mark3labs/mcp-go/mcp" + "github.com/maximhq/bifrost/core/schemas" +) + +// toolBinding represents a tool binding for the VM +type toolBinding struct { + toolName string + clientName string +} + +// toolCallInfo represents a tool call extracted from code +type toolCallInfo struct { + serverName string + toolName string +} + +// ExecutionResult represents the result of code execution +type ExecutionResult struct { + Result interface{} `json:"result"` + Logs []string `json:"logs"` + Errors *ExecutionError `json:"errors,omitempty"` + Environment ExecutionEnvironment `json:"environment"` +} + +type ExecutionErrorType string + +const ( + ExecutionErrorTypeCompile ExecutionErrorType = "compile" + ExecutionErrorTypeTypescript ExecutionErrorType = "typescript" + ExecutionErrorTypeRuntime ExecutionErrorType = "runtime" +) + +// ExecutionError represents an error during code execution +type ExecutionError struct { + Kind ExecutionErrorType `json:"kind"` // "compile", "typescript", or "runtime" + Message string `json:"message"` + Hints []string `json:"hints"` +} + +// ExecutionEnvironment contains information about the execution environment +type ExecutionEnvironment struct { + ServerKeys []string `json:"serverKeys"` + ImportsStripped bool `json:"importsStripped"` + StrippedLines []int `json:"strippedLines"` + TypeScriptUsed bool `json:"typescriptUsed"` +} + +const ( + CodeModeLogPrefix = "[CODE MODE]" +) + +// createExecuteToolCodeTool creates the executeToolCode tool definition for code mode. +// This tool allows executing TypeScript code in a sandboxed VM with access to MCP server tools. +// +// Returns: +// - schemas.ChatTool: The tool definition for executing tool code +func (m *ToolsManager) createExecuteToolCodeTool() schemas.ChatTool { + executeToolCodeProps := schemas.OrderedMap{ + "code": map[string]interface{}{ + "type": "string", + "description": "TypeScript code to execute. The code will be transpiled to JavaScript and validated before execution. Import/export statements will be stripped. You can use async/await syntax for async operations. For simple use cases, directly return results. Check keys and value types only for debugging. Do not print entire outputs in console logs - only print structure (keys, types) when debugging. ALWAYS retry if code fails. Example (simple): const result = await serverName.toolName({arg: 'value'}); return result; Example (debugging): const result = await serverName.toolName({arg: 'value'}); const getStruct = (o, d=0) => d>2 ? '...' : o===null ? 'null' : Array.isArray(o) ? `Array[${o.length}]` : typeof o !== 'object' ? typeof o : Object.keys(o).reduce((a,k) => (a[k]=getStruct(o[k],d+1), a), {}); console.log('Structure:', getStruct(result)); return result;", + }, + } + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: ToolTypeExecuteToolCode, + Description: schemas.Ptr( + "Executes TypeScript code inside a sandboxed goja-based VM with access to all connected MCP servers' tools. " + + "TypeScript code is automatically transpiled to JavaScript and validated before execution, providing type checking and validation. " + + "All connected servers are exposed as global objects named after their configuration keys, and each server " + + "provides async (Promise-returning) functions for every tool available on that server. The canonical usage " + + "pattern is: const result = await .({ ...args }); Both and " + + "should be discovered using listToolFiles and readToolFile. " + + + "IMPORTANT WORKFLOW: Always follow this order — first use listToolFiles to see available servers and tools, " + + "then use readToolFile to understand the tool definitions and their parameters, and finally use executeToolCode " + + "to execute your code. Check listToolFiles whenever you are unsure about what tools you have available or if you want to verify available servers and their tools. " + + + "LOGGING GUIDELINES: For simple use cases, you can directly return results without logging. Check for keys and value types only " + + "for debugging purposes when you need to understand the response structure. Do not print the entire output in console logs. " + + "When debugging, use console logs to print just the output structure to understand its type. For nested objects, use a recursive helper to show types at all levels. " + + "For example: const getStruct = (o, d=0) => d>2 ? '...' : o===null ? 'null' : Array.isArray(o) ? `Array[${o.length}]` : typeof o !== 'object' ? typeof o : Object.keys(o).reduce((a,k) => (a[k]=getStruct(o[k],d+1), a), {}); " + + "console.log('Structure:', getStruct(result)); Only print the entire data if absolutely necessary for debugging. " + + "This helps understand the response structure without cluttering the output with full object contents. " + + + "RETRY POLICY: ALWAYS retry if a code block fails. If execution produces an error or unexpected result, analyze the error, " + + "adjust your code accordingly for better results or debugging, and retry the execution. Do not give up after a single failure — iterate and improve your code until it succeeds. " + + + "The environment is intentionally minimal and has several constraints: " + + "• ES modules are not supported — any leading import/export statements are automatically stripped and imported symbols will not exist. " + + "• Browser and Node APIs such as fetch, XMLHttpRequest, axios, require, setTimeout, setInterval, window, and document do not exist. " + + "• async/await syntax is supported and automatically transpiled to Promise chains compatible with goja. " + + "• Using undefined server names or tool names will result in reference or function errors. " + + "• The VM does not emulate a browser or Node.js environment — no DOM, timers, modules, or network APIs are available. " + + "• Only ES5.1+ features supported by goja are guaranteed to work. " + + "• TypeScript type checking occurs during transpilation — type errors will prevent execution. " + + + "If you want a value returned from the code, write a top-level 'return '; otherwise the return value will be null. " + + "Console output (log, error, warn, info) is captured and returned. " + + "Long-running or blocked operations are interrupted via execution timeout. " + + "This tool is designed specifically for orchestrating MCP tool calls and lightweight TypeScript computation.", + ), + + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &executeToolCodeProps, + Required: []string{"code"}, + }, + }, + } +} + +// handleExecuteToolCode handles the executeToolCode tool call. +// It parses the code argument, executes it in a sandboxed VM, and formats the response +// with execution results, logs, errors, and environment information. +// +// Parameters: +// - ctx: Context for code execution +// - toolCall: The tool call request containing the TypeScript code to execute +// +// Returns: +// - *schemas.ChatMessage: A tool response message containing execution results +// - error: Any error that occurred during processing +func (m *ToolsManager) handleExecuteToolCode(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + toolName := "unknown" + if toolCall.Function.Name != nil { + toolName = *toolCall.Function.Name + } + logger.Debug(fmt.Sprintf("%s Handling executeToolCode tool call: %s", CodeModeLogPrefix, toolName)) + + // Parse tool arguments + var arguments map[string]interface{} + if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + logger.Debug(fmt.Sprintf("%s Failed to parse tool arguments: %v", CodeModeLogPrefix, err)) + return nil, fmt.Errorf("failed to parse tool arguments: %v", err) + } + + code, ok := arguments["code"].(string) + if !ok || code == "" { + logger.Debug(fmt.Sprintf("%s Code parameter missing or empty", CodeModeLogPrefix)) + return nil, fmt.Errorf("code parameter is required and must be a non-empty string") + } + + logger.Debug(fmt.Sprintf("%s Starting code execution", CodeModeLogPrefix)) + result := m.executeCode(ctx, code) + logger.Debug(fmt.Sprintf("%s Code execution completed. Success: %v, Has errors: %v, Log count: %d", CodeModeLogPrefix, result.Errors == nil, result.Errors != nil, len(result.Logs))) + + // Format response text + var responseText string + var executionSuccess bool = true // Track if execution was successful (has data) + if result.Errors != nil { + logger.Debug(fmt.Sprintf("%s Formatting error response. Error kind: %s, Message length: %d, Hints count: %d", CodeModeLogPrefix, result.Errors.Kind, len(result.Errors.Message), len(result.Errors.Hints))) + logsText := "" + if len(result.Logs) > 0 { + logsText = fmt.Sprintf("\n\nConsole/Log Output:\n%s\n", + strings.Join(result.Logs, "\n")) + } + errorKindLabel := result.Errors.Kind + + responseText = fmt.Sprintf( + "Execution %s error:\n\n%s\n\nHints:\n%s%s\n\nEnvironment:\n Available server keys: %s\n TypeScript used: %s\n Imports stripped: %s", + errorKindLabel, + result.Errors.Message, + strings.Join(result.Errors.Hints, "\n"), + logsText, + strings.Join(result.Environment.ServerKeys, ", "), + map[bool]string{true: "Yes", false: "No"}[result.Environment.TypeScriptUsed], + map[bool]string{true: "Yes", false: "No"}[result.Environment.ImportsStripped], + ) + if len(result.Environment.StrippedLines) > 0 { + strippedStr := make([]string, len(result.Environment.StrippedLines)) + for i, line := range result.Environment.StrippedLines { + strippedStr[i] = fmt.Sprintf("%d", line) + } + responseText += fmt.Sprintf("\n Stripped lines: %s", strings.Join(strippedStr, ", ")) + } + logger.Debug(fmt.Sprintf("%s Error response formatted. Response length: %d chars", CodeModeLogPrefix, len(responseText))) + } else { + // Success case - check if execution produced any data + hasLogs := len(result.Logs) > 0 + hasResult := result.Result != nil + logger.Debug(fmt.Sprintf("%s Formatting success response. Has logs: %v, Has result: %v", CodeModeLogPrefix, hasLogs, hasResult)) + + // If execution completed but produced no data (no logs, no return value), treat as failure + if !hasLogs && !hasResult { + executionSuccess = false + logger.Debug(fmt.Sprintf("%s Execution completed with no data (no logs, no result), marking as failure", CodeModeLogPrefix)) + hints := []string{ + "Add console.log() statements throughout your code to debug and see what's happening at each step", + "Ensure your code has a top-level return statement if you want to return a value", + "Check that your tool calls are actually executing and returning data", + "Verify that async operations (like await) are properly handled", + } + responseText = fmt.Sprintf( + "Execution completed but produced no data:\n\n"+ + "The code executed without errors but returned no output (no console logs and no return value).\n\n"+ + "Hints:\n%s\n\n"+ + "Environment:\n Available server keys: %s\n TypeScript used: %s\n Imports stripped: %s", + strings.Join(hints, "\n"), + strings.Join(result.Environment.ServerKeys, ", "), + map[bool]string{true: "Yes", false: "No"}[result.Environment.TypeScriptUsed], + map[bool]string{true: "Yes", false: "No"}[result.Environment.ImportsStripped], + ) + if len(result.Environment.StrippedLines) > 0 { + strippedStr := make([]string, len(result.Environment.StrippedLines)) + for i, line := range result.Environment.StrippedLines { + strippedStr[i] = fmt.Sprintf("%d", line) + } + responseText += fmt.Sprintf("\n Stripped lines: %s", strings.Join(strippedStr, ", ")) + } + logger.Debug(fmt.Sprintf("%s No-data failure response formatted. Response length: %d chars", CodeModeLogPrefix, len(responseText))) + } else { + // Normal success case with data + if hasLogs { + responseText = fmt.Sprintf("Console output:\n%s\n\nExecution completed successfully.", + strings.Join(result.Logs, "\n")) + } else { + responseText = "Execution completed successfully." + } + if hasResult { + resultJSON, err := sonic.MarshalIndent(result.Result, "", " ") + if err == nil { + responseText += fmt.Sprintf("\nReturn value: %s", string(resultJSON)) + logger.Debug(fmt.Sprintf("%s Added return value to response (JSON length: %d chars)", CodeModeLogPrefix, len(resultJSON))) + } else { + logger.Debug(fmt.Sprintf("%s Failed to marshal result to JSON: %v", CodeModeLogPrefix, err)) + } + } + + // Add environment information for successful executions + responseText += fmt.Sprintf("\n\nEnvironment:\n Available server keys: %s\n TypeScript used: %s\n Imports stripped: %s", + strings.Join(result.Environment.ServerKeys, ", "), + map[bool]string{true: "Yes", false: "No"}[result.Environment.TypeScriptUsed], + map[bool]string{true: "Yes", false: "No"}[result.Environment.ImportsStripped]) + if len(result.Environment.StrippedLines) > 0 { + strippedStr := make([]string, len(result.Environment.StrippedLines)) + for i, line := range result.Environment.StrippedLines { + strippedStr[i] = fmt.Sprintf("%d", line) + } + responseText += fmt.Sprintf("\n Stripped lines: %s", strings.Join(strippedStr, ", ")) + } + responseText += "\nNote: Browser APIs like fetch, setTimeout are not available. Use MCP tools for external interactions." + logger.Debug(fmt.Sprintf("%s Success response formatted. Response length: %d chars, Server keys: %v", CodeModeLogPrefix, len(responseText), result.Environment.ServerKeys)) + } + } + + logger.Debug(fmt.Sprintf("%s Returning tool response message. Execution success: %v", CodeModeLogPrefix, executionSuccess)) + return createToolResponseMessage(toolCall, responseText), nil +} + +// executeCode executes TypeScript code in a sandboxed VM with MCP tool bindings. +// It handles code preprocessing (stripping imports/exports), TypeScript transpilation, +// VM setup with tool bindings, and promise-based async execution with timeout handling. +// +// Parameters: +// - ctx: Context for code execution (used for timeout and tool access) +// - code: TypeScript code string to execute +// +// Returns: +// - ExecutionResult: Result containing execution output, logs, errors, and environment info +func (m *ToolsManager) executeCode(ctx context.Context, code string) ExecutionResult { + logs := []string{} + strippedLines := []int{} + + logger.Debug(fmt.Sprintf("%s Starting TypeScript code execution", CodeModeLogPrefix)) + + // Step 1: Convert literal \n escape sequences to actual newlines first + // This ensures multiline code and import/export stripping work correctly + codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n") + + // Step 2: Strip import/export statements + cleanedCode, strippedLineNumbers := stripImportsAndExports(codeWithNewlines) + strippedLines = append(strippedLines, strippedLineNumbers...) + if len(strippedLineNumbers) > 0 { + logger.Debug(fmt.Sprintf("%s Stripped %d import/export lines", CodeModeLogPrefix, len(strippedLineNumbers))) + } + + // Step 3: Handle empty code after stripping (in case stripping made it empty) + trimmedCode := strings.TrimSpace(cleanedCode) + if trimmedCode == "" { + // Empty code should return null - return early without VM execution + return ExecutionResult{ + Result: nil, + Logs: logs, + Errors: nil, + Environment: ExecutionEnvironment{ + ServerKeys: []string{}, // Will be populated below if needed, but empty code doesn't need tools + ImportsStripped: len(strippedLines) > 0, + StrippedLines: strippedLines, + TypeScriptUsed: true, + }, + } + } + + // Step 4: Wrap code in async function for proper await transpilation + // TypeScript needs an async function context to properly transpile await expressions + // Check if code is already an async IIFE - if so, await it + trimmedLower := strings.ToLower(strings.TrimSpace(trimmedCode)) + isAsyncIIFE := strings.HasPrefix(trimmedLower, "(async") && strings.Contains(trimmedCode, ")()") + + var codeToTranspile string + if isAsyncIIFE { + // Code is already an async IIFE - await it to get the result + codeToTranspile = fmt.Sprintf("async function __execute__() {\nreturn await %s\n}", trimmedCode) + } else { + // Regular code - wrap in async function + codeToTranspile = fmt.Sprintf("async function __execute__() {\n%s\n}", trimmedCode) + } + + // Step 5: Transpile TypeScript to JavaScript with validation + // Configure TypeScript compiler to transpile async/await to Promise chains (ES5 compatible) + logger.Debug(fmt.Sprintf("%s Transpiling TypeScript code", CodeModeLogPrefix)) + compileOptions := map[string]interface{}{ + "target": "ES5", // Target ES5 for goja compatibility + "module": "None", // No module system + "lib": []string{}, // No lib (minimal environment) + "downlevelIteration": true, // Support async/await transpilation + } + jsCode, transpileErr := typescript.TranspileString(codeToTranspile, typescript.WithCompileOptions(compileOptions)) + if transpileErr != nil { + logger.Debug(fmt.Sprintf("%s TypeScript transpilation failed: %v", CodeModeLogPrefix, transpileErr)) + // Build bindings to get server keys for error hints + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + serverKeys := make([]string, 0, len(availableToolsPerClient)) + for clientName := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if !client.ExecutionConfig.IsCodeModeClient { + continue + } + serverKeys = append(serverKeys, clientName) + } + + errorMessage := transpileErr.Error() + hints := generateTypeScriptErrorHints(errorMessage, serverKeys) + + return ExecutionResult{ + Result: nil, + Logs: logs, + Errors: &ExecutionError{ + Kind: ExecutionErrorTypeTypescript, + Message: fmt.Sprintf("TypeScript compilation error: %s", errorMessage), + Hints: hints, + }, + Environment: ExecutionEnvironment{ + ServerKeys: serverKeys, + ImportsStripped: len(strippedLines) > 0, + StrippedLines: strippedLines, + TypeScriptUsed: true, + }, + } + } + + logger.Debug(fmt.Sprintf("%s TypeScript transpiled successfully", CodeModeLogPrefix)) + + // Step 5: Create timeout context early so goroutines can use it + toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) + timeoutCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + // Step 6: Build bindings for all connected servers + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + bindings := make(map[string]map[string]toolBinding) + serverKeys := make([]string, 0, len(availableToolsPerClient)) + + for clientName, tools := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 { + continue + } + serverKeys = append(serverKeys, clientName) + + toolFunctions := make(map[string]toolBinding) + + // Create a function for each tool + for _, tool := range tools { + if tool.Function == nil || tool.Function.Name == "" { + continue + } + + originalToolName := tool.Function.Name + // Parse tool name for property name compatibility (used as property name in the runtime) + parsedToolName := parseToolName(originalToolName) + + // Store tool binding + toolFunctions[parsedToolName] = toolBinding{ + toolName: originalToolName, + clientName: clientName, + } + } + + bindings[clientName] = toolFunctions + } + + if len(serverKeys) > 0 { + logger.Debug(fmt.Sprintf("%s Bound %d servers with tools", CodeModeLogPrefix, len(serverKeys))) + } + + // Step 7: Wrap transpiled code to execute the async function and return its result + // The transpiled code contains an async function __execute__() that we need to call + // Trim trailing newlines to avoid issues when wrapping + codeToWrap := strings.TrimRight(jsCode, "\n\r") + // Wrap in IIFE that calls the transpiled async function and returns the promise + wrappedCode := fmt.Sprintf("(function() {\n%s\nreturn __execute__();\n})()", codeToWrap) + + // Step 8: Create goja runtime + vm := goja.New() + + // Step 9: Set up thread-safe logging + appendLog := func(msg string) { + m.logMu.Lock() + defer m.logMu.Unlock() + logs = append(logs, msg) + } + + // Step 10: Set up console + consoleObj := vm.NewObject() + consoleObj.Set("log", func(args ...interface{}) { + message := formatConsoleArgs(args) + appendLog(message) + }) + consoleObj.Set("error", func(args ...interface{}) { + message := formatConsoleArgs(args) + appendLog(fmt.Sprintf("[ERROR] %s", message)) + }) + consoleObj.Set("warn", func(args ...interface{}) { + message := formatConsoleArgs(args) + appendLog(fmt.Sprintf("[WARN] %s", message)) + }) + consoleObj.Set("info", func(args ...interface{}) { + message := formatConsoleArgs(args) + appendLog(fmt.Sprintf("[INFO] %s", message)) + }) + vm.Set("console", consoleObj) + + // Step 11: Set up server bindings + for serverKey, tools := range bindings { + serverObj := vm.NewObject() + for toolName, binding := range tools { + // Capture variables for closure + toolNameFinal := binding.toolName + clientNameFinal := binding.clientName + + serverObj.Set(toolName, func(call goja.FunctionCall) goja.Value { + args := call.Argument(0).Export() + + // Convert args to map[string]interface{} + argsMap, ok := args.(map[string]interface{}) + if !ok { + logger.Debug(fmt.Sprintf("%s Invalid args type for %s.%s: expected object, got %T", + CodeModeLogPrefix, clientNameFinal, toolNameFinal, args)) + // Return rejected promise for invalid args + promise, _, reject := vm.NewPromise() + err := fmt.Errorf("expected object argument, got %T", args) + reject(vm.ToValue(err)) + return vm.ToValue(promise) + } + + // Create promise on VM goroutine (thread-safe) + promise, resolve, reject := vm.NewPromise() + + // Define result struct for channel communication + type toolResult struct { + result interface{} + err error + } + + // Create buffered channel for worker communication + resultChan := make(chan toolResult, 1) + + // Call tool asynchronously with timeout context and panic recovery + // Worker goroutine - NO VM calls allowed here + go func() { + defer func() { + if r := recover(); r != nil { + logger.Debug(fmt.Sprintf("%s Panic in tool call goroutine for %s.%s: %v", + CodeModeLogPrefix, clientNameFinal, toolNameFinal, r)) + // Send panic as error through channel (no VM calls in worker) + select { + case resultChan <- toolResult{nil, fmt.Errorf("tool call panic: %v", r)}: + case <-timeoutCtx.Done(): + // Context cancelled, ignore + } + } + }() + + // Check if context is already cancelled before starting + select { + case <-timeoutCtx.Done(): + // Send timeout error through channel (no VM calls in worker) + select { + case resultChan <- toolResult{nil, fmt.Errorf("execution timeout")}: + case <-timeoutCtx.Done(): + // Already cancelled, ignore + } + return + default: + } + + result, err := m.callMCPTool(timeoutCtx, clientNameFinal, toolNameFinal, argsMap, appendLog) + + // Check if context was cancelled during execution + select { + case <-timeoutCtx.Done(): + // Send timeout error through channel (no VM calls in worker) + select { + case resultChan <- toolResult{nil, fmt.Errorf("execution timeout")}: + case <-timeoutCtx.Done(): + // Already cancelled, ignore + } + return + default: + } + + // Send result through channel (no VM calls in worker) + select { + case resultChan <- toolResult{result, err}: + case <-timeoutCtx.Done(): + // Context cancelled, ignore + } + }() + + // Process result synchronously on VM goroutine to ensure thread safety + // This blocks the VM goroutine until the tool call completes, but ensures + // all VM operations (vm.ToValue, resolve, reject) happen on the correct thread + select { + case res := <-resultChan: + if res.err != nil { + logger.Debug(fmt.Sprintf("%s Tool call failed: %s.%s - %v", + CodeModeLogPrefix, clientNameFinal, toolNameFinal, res.err)) + reject(vm.ToValue(res.err)) + } else { + resolve(vm.ToValue(res.result)) + } + case <-timeoutCtx.Done(): + reject(vm.ToValue(fmt.Errorf("execution timeout"))) + } + + return vm.ToValue(promise) + }) + } + vm.Set(serverKey, serverObj) + } + + // Step 12: Set up environment info + envObj := vm.NewObject() + envObj.Set("serverKeys", serverKeys) + envObj.Set("version", "1.0.0") + vm.Set("__MCP_ENV__", envObj) + + // Step 13: Execute code with timeout + + // Set up interrupt handler + interruptDone := make(chan struct{}) + go func() { + select { + case <-timeoutCtx.Done(): + logger.Debug(fmt.Sprintf("%s Execution timeout reached", CodeModeLogPrefix)) + vm.Interrupt("execution timeout") + case <-interruptDone: + } + }() + + var result interface{} + var executionErr error + + func() { + defer close(interruptDone) + val, err := vm.RunString(wrappedCode) + if err != nil { + logger.Debug(fmt.Sprintf("%s VM execution error: %v", CodeModeLogPrefix, err)) + executionErr = err + return + } + + // Check if the result is a promise by checking its type + // First check if val is nil or undefined (these can't be converted to objects) + if val == nil || val == goja.Undefined() { + result = nil + return + } + + // Try to convert to object to check if it's a promise + // Use recover to safely handle null values that can't be converted to objects + var valObj *goja.Object + func() { + defer func() { + if r := recover(); r != nil { + // Value is null or can't be converted to object, just export it + valObj = nil + } + }() + valObj = val.ToObject(vm) + }() + + if valObj != nil { + // Check if it has a 'then' method (Promise-like) + if then := valObj.Get("then"); then != nil && then != goja.Undefined() { + // It's a promise, we need to await it + // Use buffered channels to prevent blocking if handlers are called after timeout + resultChan := make(chan interface{}, 1) + errChan := make(chan error, 1) + + // Set up promise handlers + thenFunc, ok := goja.AssertFunction(then) + if ok { + // Call then with resolve and reject handlers + _, err := thenFunc(val, + vm.ToValue(func(res goja.Value) { + select { + case resultChan <- res.Export(): + case <-timeoutCtx.Done(): + // Timeout already occurred, ignore result + } + }), + vm.ToValue(func(err goja.Value) { + var errMsg string + if err == nil || err == goja.Undefined() { + errMsg = "unknown error" + } else { + // Try to get error message from Error object + if errObj := err.ToObject(vm); errObj != nil { + if msg := errObj.Get("message"); msg != nil && msg != goja.Undefined() { + errMsg = msg.String() + } else if name := errObj.Get("name"); name != nil && name != goja.Undefined() { + errMsg = name.String() + } else { + errMsg = err.String() + } + } else { + // Fallback to string conversion + errMsg = err.String() + } + } + select { + case errChan <- fmt.Errorf("%s", errMsg): + case <-timeoutCtx.Done(): + // Timeout already occurred, ignore error + } + }), + ) + if err != nil { + executionErr = err + return + } + + // Wait for result or error with timeout + select { + case res := <-resultChan: + result = res + case err := <-errChan: + logger.Debug(fmt.Sprintf("%s Promise rejected: %v", CodeModeLogPrefix, err)) + executionErr = err + case <-timeoutCtx.Done(): + logger.Debug(fmt.Sprintf("%s Promise timeout while waiting for result", CodeModeLogPrefix)) + executionErr = fmt.Errorf("execution timeout") + } + } else { + result = val.Export() + } + } else { + result = val.Export() + } + } else { + // Not an object (or null/undefined), just export the value + result = val.Export() + } + }() + + if executionErr != nil { + errorMessage := executionErr.Error() + hints := generateErrorHints(errorMessage, serverKeys) + logger.Debug(fmt.Sprintf("%s Execution failed: %s", CodeModeLogPrefix, errorMessage)) + + return ExecutionResult{ + Result: nil, + Logs: logs, + Errors: &ExecutionError{ + Kind: ExecutionErrorTypeRuntime, + Message: errorMessage, + Hints: hints, + }, + Environment: ExecutionEnvironment{ + ServerKeys: serverKeys, + ImportsStripped: len(strippedLines) > 0, + StrippedLines: strippedLines, + TypeScriptUsed: true, + }, + } + } + + logger.Debug(fmt.Sprintf("%s Execution completed successfully", CodeModeLogPrefix)) + return ExecutionResult{ + Result: result, + Logs: logs, + Errors: nil, + Environment: ExecutionEnvironment{ + ServerKeys: serverKeys, + ImportsStripped: len(strippedLines) > 0, + StrippedLines: strippedLines, + TypeScriptUsed: true, + }, + } +} + +// callMCPTool calls an MCP tool and returns the result. +// It locates the client by name, constructs the MCP tool call request, executes it +// with timeout handling, and parses the response as JSON or returns it as a string. +// +// Parameters: +// - ctx: Context for tool execution (used for timeout) +// - clientName: Name of the MCP client/server to call +// - toolName: Name of the tool to execute +// - args: Tool arguments as a map +// - appendLog: Function to append log messages during execution +// +// Returns: +// - interface{}: Parsed tool result (JSON object or string) +// - error: Any error that occurred during tool execution +func (m *ToolsManager) callMCPTool(ctx context.Context, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { + // Get available tools per client + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + + // Find the client by name + tools, exists := availableToolsPerClient[clientName] + if !exists || len(tools) == 0 { + return nil, fmt.Errorf("client not found for server name: %s", clientName) + } + + // Get client using a tool from this client + // Find the first tool with a valid Function to use for client lookup + var client *schemas.MCPClientState + for _, tool := range tools { + if tool.Function != nil && tool.Function.Name != "" { + client = m.clientManager.GetClientForTool(tool.Function.Name) + if client != nil { + break + } + } + } + + if client == nil { + return nil, fmt.Errorf("client not found for server name: %s", clientName) + } + + // Strip the client name prefix from tool name before calling MCP server + // The MCP server expects the original tool name, not the prefixed version + originalToolName := stripClientPrefix(toolName, clientName) + + // Call the tool via MCP client + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsCall), + }, + Params: mcp.CallToolParams{ + Name: originalToolName, + Arguments: args, + }, + } + + // Create timeout context + toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) + toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) + if callErr != nil { + logger.Debug(fmt.Sprintf("%s Tool call failed: %s.%s - %v", CodeModeLogPrefix, clientName, toolName, callErr)) + appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, toolName, callErr)) + return nil, fmt.Errorf("tool call failed for %s.%s: %v", clientName, toolName, callErr) + } + + // Extract result + rawResult := extractTextFromMCPResponse(toolResponse, toolName) + + // Check if this is an error result (from NewToolResultError) + // Error results start with "Error: " prefix + if after, ok := strings.CutPrefix(rawResult, "Error: "); ok { + errorMsg := after + logger.Debug(fmt.Sprintf("%s Tool returned error result: %s.%s - %s", CodeModeLogPrefix, clientName, toolName, errorMsg)) + appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, toolName, errorMsg)) + return nil, fmt.Errorf("%s", errorMsg) + } + + // Try to parse as JSON, otherwise use as string + var finalResult interface{} + if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil { + // Not JSON, use as string + finalResult = rawResult + } + + // Log the result + resultStr := formatResultForLog(finalResult) + appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, toolName, resultStr)) + + return finalResult, nil +} + +// HELPER FUNCTIONS + +// formatResultForLog formats a result value for logging purposes. +// It attempts to marshal to JSON for structured output, falling back to string representation. +// +// Parameters: +// - result: The result value to format +// +// Returns: +// - string: Formatted string representation of the result +func formatResultForLog(result interface{}) string { + var resultStr string + if result == nil { + resultStr = "null" + } else if resultBytes, err := sonic.Marshal(result); err == nil { + resultStr = string(resultBytes) + } else { + resultStr = fmt.Sprintf("%v", result) + } + return resultStr +} + +// formatConsoleArgs formats console arguments for logging. +// It formats each argument as JSON if possible, otherwise uses string representation. +// +// Parameters: +// - args: Array of console arguments to format +// +// Returns: +// - string: Formatted string with all arguments joined by spaces +func formatConsoleArgs(args []interface{}) string { + parts := make([]string, len(args)) + for i, arg := range args { + if argBytes, err := sonic.MarshalIndent(arg, "", " "); err == nil { + parts[i] = string(argBytes) + } else { + parts[i] = fmt.Sprintf("%v", arg) + } + } + return strings.Join(parts, " ") +} + +// stripImportsAndExports strips import and export statements from code. +// It removes lines that start with import or export keywords and returns +// the cleaned code along with 1-based line numbers of stripped lines. +// +// Parameters: +// - code: Source code string to process +// +// Returns: +// - string: Code with import/export statements removed +// - []int: 1-based line numbers of stripped lines +func stripImportsAndExports(code string) (string, []int) { + lines := strings.Split(code, "\n") + keptLines := []string{} + strippedLineNumbers := []int{} + + importExportRegex := regexp.MustCompile(`^\s*(import|export)\b`) + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + + // Skip empty lines + if trimmed == "" { + keptLines = append(keptLines, line) + continue + } + + // Check if this is an import or export statement + isImportOrExport := importExportRegex.MatchString(line) + + if isImportOrExport { + strippedLineNumbers = append(strippedLineNumbers, i+1) // 1-based line numbers + continue // Skip import/export lines + } + + // Keep comment lines and all other non-import/export lines + keptLines = append(keptLines, line) + } + + return strings.Join(keptLines, "\n"), strippedLineNumbers +} + +// generateTypeScriptErrorHints generates helpful hints for TypeScript compilation errors. +// It analyzes the error message and provides context-specific guidance based on error patterns. +// +// Parameters: +// - errorMessage: The TypeScript compilation error message +// - serverKeys: List of available MCP server keys for context +// +// Returns: +// - []string: Array of helpful hint messages +func generateTypeScriptErrorHints(errorMessage string, serverKeys []string) []string { + hints := []string{} + + // TypeScript-specific error patterns + if strings.Contains(errorMessage, "Cannot find name") || strings.Contains(errorMessage, "is not defined") { + hints = append(hints, "TypeScript compilation error: undefined variable or identifier.") + hints = append(hints, "Check that all variables are properly declared and typed.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + hints = append(hints, "Use server keys to access MCP tools: .(args)") + } + } else if strings.Contains(errorMessage, "Type") && (strings.Contains(errorMessage, "is not assignable") || strings.Contains(errorMessage, "does not exist")) { + hints = append(hints, "TypeScript type error detected.") + hints = append(hints, "Check that variable types match their usage.") + hints = append(hints, "Ensure function arguments match the expected types.") + } else if strings.Contains(errorMessage, "Expected") { + hints = append(hints, "TypeScript syntax error detected.") + hints = append(hints, "Check for missing parentheses, brackets, or semicolons.") + hints = append(hints, "Ensure all code blocks are properly closed.") + } else if strings.Contains(errorMessage, "async") || strings.Contains(errorMessage, "await") { + hints = append(hints, "async/await syntax should be supported. If you see this error, it may be a TypeScript compilation issue.") + hints = append(hints, "Ensure async functions are properly declared: async function myFunction() { ... }") + hints = append(hints, "Example: const result = await serverName.toolName({...});") + } else { + hints = append(hints, "TypeScript compilation error detected.") + hints = append(hints, "Review the error message above for specific details.") + hints = append(hints, "Ensure your TypeScript code follows valid syntax and type rules.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + } + + return hints +} + +// generateErrorHints generates helpful hints based on runtime error messages. +// It analyzes common runtime error patterns (undefined variables, missing functions, etc.) +// and provides context-specific guidance including available server keys and usage examples. +// +// Parameters: +// - errorMessage: The runtime error message +// - serverKeys: List of available MCP server keys for context +// +// Returns: +// - []string: Array of helpful hint messages +func generateErrorHints(errorMessage string, serverKeys []string) []string { + hints := []string{} + + if strings.Contains(errorMessage, "is not defined") { + re := regexp.MustCompile(`(\w+)\s+is not defined`) + if match := re.FindStringSubmatch(errorMessage); len(match) > 1 { + undefinedVar := match[1] + + // Special handling for common browser/Node.js APIs + if undefinedVar == "fetch" { + hints = append(hints, "The 'fetch' API is not available in this runtime environment.") + hints = append(hints, "Instead of using fetch for HTTP requests, use the available MCP tools.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + hints = append(hints, fmt.Sprintf("Example: const result = await %s.({ url: 'https://example.com' });", serverKeys[0])) + } + hints = append(hints, "MCP tools handle HTTP requests, file operations, and other external interactions.") + return hints + } else if undefinedVar == "XMLHttpRequest" || undefinedVar == "axios" { + hints = append(hints, fmt.Sprintf("The '%s' API is not available in this runtime environment.", undefinedVar)) + hints = append(hints, "Use MCP tools instead for HTTP requests and external API calls.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + return hints + } else if undefinedVar == "setTimeout" || undefinedVar == "setInterval" { + hints = append(hints, fmt.Sprintf("The '%s' API is not available in this runtime environment.", undefinedVar)) + hints = append(hints, "This is a sandboxed environment focused on MCP tool interactions.") + hints = append(hints, "Use Promise chains with MCP tools instead of timing functions.") + return hints + } else if undefinedVar == "require" || undefinedVar == "import" { + hints = append(hints, "Module imports are not supported in this runtime environment.") + hints = append(hints, "Use the available MCP tools for external functionality.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + return hints + } + + // Generic undefined variable handling + hints = append(hints, fmt.Sprintf("Variable or identifier '%s' is not defined.", undefinedVar)) + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Use one of the available server keys as the object name: %s", strings.Join(serverKeys, ", "))) + hints = append(hints, "Then access tools using: .(args)") + hints = append(hints, fmt.Sprintf("For example: const result = await %s.({ ... });", serverKeys[0])) + } + } + } else if strings.Contains(errorMessage, "is not a function") { + re := regexp.MustCompile(`(\w+(?:\.\w+)?)\s+is not a function`) + if match := re.FindStringSubmatch(errorMessage); len(match) > 1 { + notFunction := match[1] + hints = append(hints, fmt.Sprintf("'%s' is not a function.", notFunction)) + hints = append(hints, "Ensure you're using the correct server key and tool name.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + hints = append(hints, "To see available tools for a server, use listToolFiles and readToolFile.") + } + } else if strings.Contains(errorMessage, "Cannot read property") || + strings.Contains(errorMessage, "Cannot read properties") || + strings.Contains(errorMessage, "is not an object") { + hints = append(hints, "You're trying to access a property that doesn't exist or is undefined.") + hints = append(hints, "The tool response structure might be different than expected.") + hints = append(hints, "Check the console logs above to see the actual response structure from the tool.") + hints = append(hints, "Add console.log() statements to inspect the response before accessing properties.") + hints = append(hints, "Example: console.log('searchResults:', searchResults);") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + } else { + hints = append(hints, "Check the error message above for details.") + hints = append(hints, "Check the console logs above to see tool responses and debug the issue.") + if len(serverKeys) > 0 { + hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) + } + hints = append(hints, "Ensure you're using the correct syntax: const result = await .({ ...args });") + } + + return hints +} diff --git a/core/mcp/codemodelistfiles.go b/core/mcp/codemodelistfiles.go new file mode 100644 index 0000000000..730a1083dd --- /dev/null +++ b/core/mcp/codemodelistfiles.go @@ -0,0 +1,229 @@ +package mcp + +import ( + "context" + "fmt" + "strings" + + "github.com/maximhq/bifrost/core/schemas" +) + +// createListToolFilesTool creates the listToolFiles tool definition for code mode. +// This tool allows listing all available virtual .d.ts declaration files for connected MCP servers. +// The description is dynamically generated based on the configured CodeModeBindingLevel. +// +// Returns: +// - schemas.ChatTool: The tool definition for listing tool files +func (m *ToolsManager) createListToolFilesTool() schemas.ChatTool { + bindingLevel := m.GetCodeModeBindingLevel() + var description string + + if bindingLevel == schemas.CodeModeBindingLevelServer { + description = "Returns a tree structure listing all virtual .d.ts declaration files available for connected MCP servers. " + + "Each server has a corresponding file (e.g., servers/.d.ts) that contains definitions for all tools in that server. " + + "Use readToolFile to read a specific server file and see all available tools. " + + "In code, access tools via: await serverName.toolName({ args }). " + + "The server names used in code correspond to the human-readable names shown in this listing. " + + "This tool is generic and works with any set of servers connected at runtime. " + + "Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools. " + + "If you have even the SLIGHTEST DOUBT that the current tools might not be useful for the task, check listToolFiles to discover all available tools." + } else { + description = "Returns a tree structure listing all virtual .d.ts declaration files available for connected MCP servers, organized by individual tool. " + + "Each tool has a corresponding file (e.g., servers//.d.ts) that contains definitions for that specific tool. " + + "Use readToolFile to read a specific tool file and see its parameters and usage. " + + "In code, access tools via: await serverName.toolName({ args }). " + + "The server names used in code correspond to the human-readable names shown in this listing. " + + "This tool is generic and works with any set of servers connected at runtime. " + + "Always check this tool whenever you are unsure about what tools you have available or if you want to verify available servers and their tools. " + + "If you have even the SLIGHTEST DOUBT that the current tools might not be useful for the task, check listToolFiles to discover all available tools." + } + + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: ToolTypeListToolFiles, + Description: schemas.Ptr(description), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &schemas.OrderedMap{}, + Required: []string{}, + }, + }, + } +} + +// handleListToolFiles handles the listToolFiles tool call. +// It builds a tree structure listing all virtual .d.ts files available for code mode clients. +// The structure depends on the CodeModeBindingLevel: +// - "server": servers/.d.ts (one file per server) +// - "tool": servers//.d.ts (one file per tool) +// +// Parameters: +// - ctx: Context for accessing client tools +// - toolCall: The tool call request containing no arguments +// +// Returns: +// - *schemas.ChatMessage: A tool response message containing the file tree structure +// - error: Any error that occurred during processing +func (m *ToolsManager) handleListToolFiles(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + + if len(availableToolsPerClient) == 0 { + responseText := "No servers are currently connected. There are no virtual .d.ts files available. " + + "Please ensure servers are connected before using this tool." + return createToolResponseMessage(toolCall, responseText), nil + } + + // Get the code mode binding level + bindingLevel := m.GetCodeModeBindingLevel() + + // Build file list based on binding level + var files []string + codeModeServerCount := 0 + + for clientName, tools := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if !client.ExecutionConfig.IsCodeModeClient { + continue + } + codeModeServerCount++ + + if bindingLevel == schemas.CodeModeBindingLevelServer { + // Server-level: one file per server + files = append(files, fmt.Sprintf("servers/%s.d.ts", clientName)) + } else { + // Tool-level: one file per tool + for _, tool := range tools { + if tool.Function != nil && tool.Function.Name != "" { + toolFileName := fmt.Sprintf("servers/%s/%s.d.ts", clientName, tool.Function.Name) + files = append(files, toolFileName) + } + } + } + } + + if codeModeServerCount == 0 { + responseText := "Servers are connected but none are configured for code mode. " + + "There are no virtual .d.ts files available." + return createToolResponseMessage(toolCall, responseText), nil + } + + // Build tree structure from file list + responseText := buildVFSTree(files) + return createToolResponseMessage(toolCall, responseText), nil +} + +// VFS tree node structure for building hierarchical file structure +type treeNode struct { + isDirectory bool + children map[string]*treeNode +} + +// buildVFSTree creates a hierarchical tree structure from a flat list of file paths. +// It groups files by directory and formats them with proper indentation. +// +// Example input: +// - ["servers/calculator.d.ts", "servers/youtube.d.ts"] +// - ["servers/calculator/add.d.ts", "servers/youtube/GET_CHANNELS.d.ts"] +// +// Example output for server-level: +// servers/ +// calculator.d.ts +// youtube.d.ts +// +// Example output for tool-level: +// servers/ +// calculator/ +// add.d.ts +// youtube/ +// GET_CHANNELS.d.ts +func buildVFSTree(files []string) string { + if len(files) == 0 { + return "" + } + + root := &treeNode{ + isDirectory: true, + children: make(map[string]*treeNode), + } + + // Parse all files and build tree structure + for _, file := range files { + parts := strings.Split(file, "/") + current := root + + // Create all intermediate directories and final file + for i, part := range parts { + if _, exists := current.children[part]; !exists { + current.children[part] = &treeNode{ + isDirectory: i < len(parts)-1, // Last part is file, not directory + children: make(map[string]*treeNode), + } + } + current = current.children[part] + } + } + + // Render tree structure with proper indentation + var lines []string + renderTreeNode(root, "", &lines, true) + + return strings.Join(lines, "\n") +} + +// renderTreeNode recursively renders a tree node and its children with proper indentation. +func renderTreeNode(node *treeNode, indent string, lines *[]string, isRoot bool) { + // Get sorted keys for consistent output + var keys []string + for key := range node.children { + keys = append(keys, key) + } + + // Simple bubble sort for small lists (good enough for this use case) + for i := 0; i < len(keys); i++ { + for j := i + 1; j < len(keys); j++ { + if keys[j] < keys[i] { + keys[i], keys[j] = keys[j], keys[i] + } + } + } + + for _, key := range keys { + child := node.children[key] + + // Format the line + var line string + if isRoot { + // Root level - no indentation + if child.isDirectory { + line = key + "/" + } else { + line = key + } + } else { + // Non-root levels - add indentation + if child.isDirectory { + line = indent + key + "/" + } else { + line = indent + key + } + } + + *lines = append(*lines, line) + + // Recurse into children + if child.isDirectory && len(child.children) > 0 { + var nextIndent string + if isRoot { + nextIndent = " " + } else { + nextIndent = indent + " " + } + renderTreeNode(child, nextIndent, lines, false) + } + } +} diff --git a/core/mcp/codemodereadfile.go b/core/mcp/codemodereadfile.go new file mode 100644 index 0000000000..776a6ac3b3 --- /dev/null +++ b/core/mcp/codemodereadfile.go @@ -0,0 +1,503 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/maximhq/bifrost/core/schemas" +) + +// createReadToolFileTool creates the readToolFile tool definition for code mode. +// This tool allows reading virtual .d.ts declaration files for specific MCP servers/tools, +// generating TypeScript type definitions from the server's tool schemas. +// The description is dynamically generated based on the configured CodeModeBindingLevel. +// +// Returns: +// - schemas.ChatTool: The tool definition for reading tool files +func (m *ToolsManager) createReadToolFileTool() schemas.ChatTool { + bindingLevel := m.GetCodeModeBindingLevel() + + var fileNameDescription, toolDescription string + + if bindingLevel == schemas.CodeModeBindingLevelServer { + fileNameDescription = "The virtual filename from listToolFiles in format: servers/.d.ts (e.g., 'calculator.d.ts')" + toolDescription = "Reads a virtual .d.ts declaration file for a specific MCP server, generating TypeScript type definitions " + + "for all tools available on that server. The fileName should be in format servers/.d.ts as listed by listToolFiles. " + + "The function performs case-insensitive matching and removes the .d.ts extension. " + + "Optionally, you can specify startLine and endLine (1-based, inclusive) to read only a portion of the file. " + + "IMPORTANT: Line numbers are 1-based, not 0-based. The first line is line 1, not line 0. " + + "This generates TypeScript type definitions describing all tools in the server and their argument types, " + + "enabling code-mode execution. Each tool can be accessed in code via: await serverName.toolName({ args }). " + + "Always follow this workflow: first use listToolFiles to see available servers, then use readToolFile to understand " + + "all available tool definitions for a server, and finally use executeToolCode to execute your code." + } else { + fileNameDescription = "The virtual filename from listToolFiles in format: servers//.d.ts (e.g., 'calculator/add.d.ts')" + toolDescription = "Reads a virtual .d.ts declaration file for a specific tool, generating TypeScript type definitions " + + "for that individual tool. The fileName should be in format servers//.d.ts as listed by listToolFiles. " + + "The function performs case-insensitive matching and removes the .d.ts extension. " + + "Optionally, you can specify startLine and endLine (1-based, inclusive) to read only a portion of the file. " + + "IMPORTANT: Line numbers are 1-based, not 0-based. The first line is line 1, not line 0. " + + "This generates TypeScript type definitions for a single tool, describing its parameters and usage, " + + "enabling focused code-mode execution. The tool can be accessed in code via: await serverName.toolName({ args }). " + + "Always follow this workflow: first use listToolFiles to see available tools, then use readToolFile to understand " + + "a specific tool's definition, and finally use executeToolCode to execute your code." + } + + readToolFileProps := schemas.OrderedMap{ + "fileName": map[string]interface{}{ + "type": "string", + "description": fileNameDescription, + }, + "startLine": map[string]interface{}{ + "type": "number", + "description": "Optional 1-based starting line number for partial file read (inclusive). Note: Line numbers start at 1, not 0. The first line is line 1.", + }, + "endLine": map[string]interface{}{ + "type": "number", + "description": "Optional 1-based ending line number for partial file read (inclusive)", + }, + } + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: ToolTypeReadToolFile, + Description: schemas.Ptr(toolDescription), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &readToolFileProps, + Required: []string{"fileName"}, + }, + }, + } +} + +// handleReadToolFile handles the readToolFile tool call. +// It reads a virtual .d.ts file for a specific MCP server/tool, generates TypeScript type definitions, +// and optionally returns a portion of the file based on line range parameters. +// Supports both server-level files (e.g., "calculator.d.ts") and tool-level files (e.g., "calculator/add.d.ts"). +// +// Parameters: +// - ctx: Context for accessing client tools +// - toolCall: The tool call request containing fileName and optional startLine/endLine +// +// Returns: +// - *schemas.ChatMessage: A tool response message containing the TypeScript definitions +// - error: Any error that occurred during processing +func (m *ToolsManager) handleReadToolFile(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + // Parse tool arguments + var arguments map[string]interface{} + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + return nil, fmt.Errorf("failed to parse tool arguments: %v", err) + } + + fileName, ok := arguments["fileName"].(string) + if !ok || fileName == "" { + return nil, fmt.Errorf("fileName parameter is required and must be a string") + } + + // Parse the file path to extract server name and optional tool name + serverName, toolName, isToolLevel := parseVFSFilePath(fileName) + + // Get available tools per client + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + + // Find matching client + var matchedClientName string + var matchedTools []schemas.ChatTool + matchCount := 0 + + for clientName, tools := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if !client.ExecutionConfig.IsCodeModeClient || len(tools) == 0 { + continue + } + + clientNameLower := strings.ToLower(clientName) + serverNameLower := strings.ToLower(serverName) + + if clientNameLower == serverNameLower { + matchCount++ + if matchCount > 1 { + // Multiple matches found + errorMsg := fmt.Sprintf("Multiple servers match filename '%s':\n", fileName) + for name := range availableToolsPerClient { + if strings.ToLower(name) == serverNameLower { + errorMsg += fmt.Sprintf(" - %s\n", name) + } + } + errorMsg += "\nPlease use a more specific filename. Use the exact display name from listToolFiles to avoid ambiguity." + return createToolResponseMessage(toolCall, errorMsg), nil + } + + matchedClientName = clientName + + if isToolLevel { + // Tool-level: filter to specific tool + var foundTool *schemas.ChatTool + toolNameLower := strings.ToLower(toolName) + for i, tool := range tools { + if tool.Function != nil && strings.ToLower(tool.Function.Name) == toolNameLower { + foundTool = &tools[i] + break + } + } + + if foundTool == nil { + availableTools := make([]string, 0) + for _, tool := range tools { + if tool.Function != nil { + availableTools = append(availableTools, tool.Function.Name) + } + } + errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools in this server are:\n", toolName, clientName) + for _, t := range availableTools { + errorMsg += fmt.Sprintf(" - %s/%s.d.ts\n", clientName, t) + } + return createToolResponseMessage(toolCall, errorMsg), nil + } + + matchedTools = []schemas.ChatTool{*foundTool} + } else { + // Server-level: use all tools + matchedTools = tools + } + } + } + + if matchedClientName == "" { + // Build helpful error message with available files + bindingLevel := m.GetCodeModeBindingLevel() + var availableFiles []string + + for name := range availableToolsPerClient { + if bindingLevel == schemas.CodeModeBindingLevelServer { + availableFiles = append(availableFiles, fmt.Sprintf("%s.d.ts", name)) + } else { + client := m.clientManager.GetClientByName(name) + if client != nil && client.ExecutionConfig.IsCodeModeClient { + if tools, ok := availableToolsPerClient[name]; ok { + for _, tool := range tools { + if tool.Function != nil { + availableFiles = append(availableFiles, fmt.Sprintf("%s/%s.d.ts", name, tool.Function.Name)) + } + } + } + } + } + } + + errorMsg := fmt.Sprintf("No server found matching '%s'. Available virtual files are:\n", serverName) + for _, f := range availableFiles { + errorMsg += fmt.Sprintf(" - %s\n", f) + } + return createToolResponseMessage(toolCall, errorMsg), nil + } + + // Generate TypeScript definitions + fileContent := generateTypeDefinitions(matchedClientName, matchedTools, isToolLevel) + lines := strings.Split(fileContent, "\n") + totalLines := len(lines) + + // Handle line slicing if provided + var startLine, endLine *int + if sl, ok := arguments["startLine"].(float64); ok { + slInt := int(sl) + startLine = &slInt + } + if el, ok := arguments["endLine"].(float64); ok { + elInt := int(el) + endLine = &elInt + } + + if startLine != nil || endLine != nil { + start := 1 + if startLine != nil { + start = *startLine + } + end := totalLines + if endLine != nil { + end = *endLine + } + + // Validate line numbers + if start < 1 || start > totalLines { + errorMsg := fmt.Sprintf("Invalid startLine: %d. Must be between 1 and %d (total lines in file). Provided: startLine=%d, endLine=%v, totalLines=%d", + start, totalLines, start, endLine, totalLines) + return createToolResponseMessage(toolCall, errorMsg), nil + } + if end < 1 || end > totalLines { + errorMsg := fmt.Sprintf("Invalid endLine: %d. Must be between 1 and %d (total lines in file). Provided: startLine=%d, endLine=%d, totalLines=%d", + end, totalLines, start, end, totalLines) + return createToolResponseMessage(toolCall, errorMsg), nil + } + if start > end { + errorMsg := fmt.Sprintf("Invalid line range: startLine (%d) must be less than or equal to endLine (%d). Total lines in file: %d", + start, end, totalLines) + return createToolResponseMessage(toolCall, errorMsg), nil + } + + // Slice lines (convert to 0-based indexing) + selectedLines := lines[start-1 : end] + fileContent = strings.Join(selectedLines, "\n") + } + + return createToolResponseMessage(toolCall, fileContent), nil +} + +// HELPER FUNCTIONS + +// parseVFSFilePath parses a VFS file path and extracts the server name and optional tool name. +// For server-level paths (e.g., "calculator.d.ts"), returns (serverName="calculator", toolName="", isToolLevel=false) +// For tool-level paths (e.g., "calculator/add.d.ts"), returns (serverName="calculator", toolName="add", isToolLevel=true) +// +// Parameters: +// - fileName: The virtual file path from listToolFiles +// +// Returns: +// - serverName: The name of the MCP server +// - toolName: The name of the tool (empty for server-level) +// - isToolLevel: Whether this is a tool-level path +func parseVFSFilePath(fileName string) (serverName, toolName string, isToolLevel bool) { + // Remove .d.ts extension + basePath := strings.TrimSuffix(fileName, ".d.ts") + + // Remove "servers/" prefix if present + basePath = strings.TrimPrefix(basePath, "servers/") + + // Check for path separator + parts := strings.Split(basePath, "/") + if len(parts) == 2 { + // Tool-level: "serverName/toolName" + return parts[0], parts[1], true + } + // Server-level: "serverName" + return basePath, "", false +} + +// generateTypeDefinitions generates TypeScript type definitions from ChatTool schemas +// with comprehensive comments to help LLMs understand how to use the tools. +// It creates interfaces for tool inputs and responses, along with function declarations. +// +// Parameters: +// - clientName: Name of the MCP client/server +// - tools: List of chat tools to generate definitions for +// - isToolLevel: Whether this is a tool-level definition (single tool) or server-level (all tools) +// +// Returns: +// - string: Complete TypeScript declaration file content +func generateTypeDefinitions(clientName string, tools []schemas.ChatTool, isToolLevel bool) string { + var sb strings.Builder + + // Write comprehensive header comment + sb.WriteString("// ============================================================================\n") + if isToolLevel && len(tools) == 1 && tools[0].Function != nil { + // Tool-level: show individual tool name + sb.WriteString(fmt.Sprintf("// Type definitions for %s.%s tool\n", clientName, tools[0].Function.Name)) + } else { + // Server-level: show all tools in server + sb.WriteString(fmt.Sprintf("// Type definitions for %s MCP server\n", clientName)) + } + sb.WriteString("// ============================================================================\n") + sb.WriteString("//\n") + if isToolLevel && len(tools) == 1 { + sb.WriteString("// This file contains TypeScript type definitions for a specific tool on this MCP server.\n") + } else { + sb.WriteString("// This file contains TypeScript type definitions for all tools available on this MCP server.\n") + } + sb.WriteString("// These definitions enable code-mode execution as described in the MCP code execution pattern.\n") + sb.WriteString("//\n") + sb.WriteString("// USAGE INSTRUCTIONS:\n") + sb.WriteString("// 1. Each tool has an input interface (e.g., ToolNameInput) that defines the required parameters\n") + sb.WriteString("// 2. Each tool has a function declaration showing how to call it\n") + sb.WriteString("// 3. To use these tools in executeToolCode, you would call them like:\n") + sb.WriteString("// const result = await .({ ...args });\n") + sb.WriteString("//\n") + sb.WriteString("// NOTE: The server name used in executeToolCode is the same as the display name shown here.\n") + sb.WriteString("// ============================================================================\n\n") + + // Generate interfaces and function declarations for each tool + for _, tool := range tools { + if tool.Function == nil || tool.Function.Name == "" { + continue + } + + originalToolName := tool.Function.Name + // Parse tool name for property name compatibility (used in virtual TypeScript files) + toolName := parseToolName(originalToolName) + description := "" + if tool.Function.Description != nil { + description = *tool.Function.Description + } + + // Generate input interface with detailed comments + inputInterfaceName := toPascalCase(toolName) + "Input" + sb.WriteString("// ----------------------------------------------------------------------------\n") + sb.WriteString(fmt.Sprintf("// Tool: %s\n", toolName)) + sb.WriteString("// ----------------------------------------------------------------------------\n") + if description != "" { + sb.WriteString(fmt.Sprintf("// Description: %s\n", description)) + } + sb.WriteString(fmt.Sprintf("// Input interface for %s\n", toolName)) + sb.WriteString(fmt.Sprintf("// This interface defines all parameters that can be passed to the %s tool.\n", toolName)) + sb.WriteString(fmt.Sprintf("interface %s {\n", inputInterfaceName)) + + if tool.Function.Parameters != nil && tool.Function.Parameters.Properties != nil { + props := *tool.Function.Parameters.Properties + required := make(map[string]bool) + if tool.Function.Parameters.Required != nil { + for _, req := range tool.Function.Parameters.Required { + required[req] = true + } + } + + // Sort properties for consistent output + propNames := make([]string, 0, len(props)) + for name := range props { + propNames = append(propNames, name) + } + // Simple alphabetical sort + for i := 0; i < len(propNames)-1; i++ { + for j := i + 1; j < len(propNames); j++ { + if propNames[i] > propNames[j] { + propNames[i], propNames[j] = propNames[j], propNames[i] + } + } + } + + for _, propName := range propNames { + prop := props[propName] + propMap, ok := prop.(map[string]interface{}) + if !ok { + continue + } + + tsType := jsonSchemaToTypeScript(propMap) + optional := "" + if !required[propName] { + optional = "?" + } + + propDesc := "" + if desc, ok := propMap["description"].(string); ok && desc != "" { + propDesc = fmt.Sprintf(" // %s", desc) + } else { + propDesc = fmt.Sprintf(" // %s parameter", propName) + } + + requiredNote := "" + if required[propName] { + requiredNote = " (required)" + } else { + requiredNote = " (optional)" + } + + sb.WriteString(fmt.Sprintf(" %s%s: %s;%s%s\n", propName, optional, tsType, propDesc, requiredNote)) + } + } + + sb.WriteString("}\n\n") + + // Generate response interface with helpful comments + responseInterfaceName := toPascalCase(toolName) + "Response" + sb.WriteString(fmt.Sprintf("// Response interface for %s\n", toolName)) + sb.WriteString("// The actual response structure depends on the tool implementation.\n") + sb.WriteString("// This is a placeholder interface - the actual response may contain different fields.\n") + sb.WriteString(fmt.Sprintf("interface %s {\n", responseInterfaceName)) + sb.WriteString(" // Response structure depends on the tool implementation\n") + sb.WriteString(" // Common fields may include: result, error, data, etc.\n") + sb.WriteString(" [key: string]: any;\n") + sb.WriteString("}\n\n") + + // Generate function declaration with usage example + sb.WriteString(fmt.Sprintf("// Function declaration for %s\n", toolName)) + if description != "" { + sb.WriteString(fmt.Sprintf("// %s\n", description)) + } + sb.WriteString("//\n") + sb.WriteString("// Usage example in executeToolCode:\n") + sb.WriteString(fmt.Sprintf("// const result = await .%s({ ... });\n", toolName)) + sb.WriteString("// // Replace with the actual server name/ID\n") + sb.WriteString(fmt.Sprintf("// // Replace { ... } with the appropriate %sInput object\n", inputInterfaceName)) + sb.WriteString(fmt.Sprintf("export async function %s(input: %s): Promise<%s>;\n\n", toolName, inputInterfaceName, responseInterfaceName)) + } + + return sb.String() +} + +// jsonSchemaToTypeScript converts a JSON Schema type definition to a TypeScript type string. +// It handles basic types, arrays, enums, and defaults to "any" for unknown types. +// +// Parameters: +// - prop: JSON Schema property definition map +// +// Returns: +// - string: TypeScript type string representation +func jsonSchemaToTypeScript(prop map[string]interface{}) string { + // Check for explicit type + if typeVal, ok := prop["type"].(string); ok { + switch typeVal { + case "string": + return "string" + case "number", "integer": + return "number" + case "boolean": + return "boolean" + case "array": + itemsType := "any" + if items, ok := prop["items"].(map[string]interface{}); ok { + itemsType = jsonSchemaToTypeScript(items) + } + return fmt.Sprintf("%s[]", itemsType) + case "object": + return "object" + case "null": + return "null" + } + } + + // Check for enum + if enum, ok := prop["enum"].([]interface{}); ok && len(enum) > 0 { + enumStrs := make([]string, 0, len(enum)) + for _, e := range enum { + enumStrs = append(enumStrs, fmt.Sprintf("%q", e)) + } + return strings.Join(enumStrs, " | ") + } + + // Default to any + return "any" +} + +// toPascalCase converts a string to PascalCase format. +// It splits on underscores, hyphens, and spaces, then capitalizes the first letter +// of each word and lowercases the rest. +// +// Parameters: +// - s: Input string to convert +// +// Returns: +// - string: PascalCase formatted string +func toPascalCase(s string) string { + if s == "" { + return s + } + parts := strings.FieldsFunc(s, func(r rune) bool { + return r == '_' || r == '-' || r == ' ' + }) + result := "" + for _, part := range parts { + if len(part) > 0 { + result += strings.ToUpper(part[:1]) + strings.ToLower(part[1:]) + } + } + if result == "" { + return strings.ToUpper(s[:1]) + strings.ToLower(s[1:]) + } + return result +} diff --git a/core/mcp/health_monitor.go b/core/mcp/health_monitor.go new file mode 100644 index 0000000000..6a55938fef --- /dev/null +++ b/core/mcp/health_monitor.go @@ -0,0 +1,231 @@ +package mcp + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +const ( + // Health check configuration + DefaultHealthCheckInterval = 10 * time.Second // Interval between health checks + DefaultHealthCheckTimeout = 5 * time.Second // Timeout for each health check + MaxConsecutiveFailures = 5 // Number of failures before marking as unhealthy +) + +// ClientHealthMonitor tracks the health status of an MCP client +type ClientHealthMonitor struct { + manager *MCPManager + clientID string + interval time.Duration + timeout time.Duration + maxConsecutiveFailures int + mu sync.Mutex + ticker *time.Ticker + ctx context.Context + cancel context.CancelFunc + isMonitoring bool + consecutiveFailures int +} + +// NewClientHealthMonitor creates a new health monitor for an MCP client +func NewClientHealthMonitor( + manager *MCPManager, + clientID string, + interval time.Duration, +) *ClientHealthMonitor { + if interval == 0 { + interval = DefaultHealthCheckInterval + } + + return &ClientHealthMonitor{ + manager: manager, + clientID: clientID, + interval: interval, + timeout: DefaultHealthCheckTimeout, + maxConsecutiveFailures: MaxConsecutiveFailures, + isMonitoring: false, + consecutiveFailures: 0, + } +} + +// Start begins monitoring the client's health in a background goroutine +func (chm *ClientHealthMonitor) Start() { + chm.mu.Lock() + defer chm.mu.Unlock() + + if chm.isMonitoring { + return // Already monitoring + } + + chm.isMonitoring = true + chm.ctx, chm.cancel = context.WithCancel(context.Background()) + chm.ticker = time.NewTicker(chm.interval) + + go chm.monitorLoop() + logger.Debug(fmt.Sprintf("%s Health monitor started for client %s (interval: %v)", MCPLogPrefix, chm.clientID, chm.interval)) +} + +// Stop stops monitoring the client's health +func (chm *ClientHealthMonitor) Stop() { + chm.mu.Lock() + defer chm.mu.Unlock() + + if !chm.isMonitoring { + return // Not monitoring + } + + chm.isMonitoring = false + if chm.ticker != nil { + chm.ticker.Stop() + } + if chm.cancel != nil { + chm.cancel() + } + logger.Debug(fmt.Sprintf("%s Health monitor stopped for client %s", MCPLogPrefix, chm.clientID)) +} + +// monitorLoop runs the health check loop +func (chm *ClientHealthMonitor) monitorLoop() { + for { + select { + case <-chm.ctx.Done(): + return + case <-chm.ticker.C: + chm.performHealthCheck() + } + } +} + +// performHealthCheck performs a health check on the client +func (chm *ClientHealthMonitor) performHealthCheck() { + // Get the client connection + chm.manager.mu.RLock() + clientState, exists := chm.manager.clientMap[chm.clientID] + chm.manager.mu.RUnlock() + + if !exists { + chm.Stop() + return + } + + if clientState.Conn == nil { + // Client not connected, mark as disconnected + chm.updateClientState(schemas.MCPConnectionStateDisconnected) + chm.incrementFailures() + return + } + + // Perform ping with timeout + ctx, cancel := context.WithTimeout(context.Background(), chm.timeout) + defer cancel() + + err := clientState.Conn.Ping(ctx) + if err != nil { + chm.incrementFailures() + + // After max consecutive failures, mark as disconnected + if chm.getConsecutiveFailures() >= chm.maxConsecutiveFailures { + chm.updateClientState(schemas.MCPConnectionStateDisconnected) + } + } else { + // Health check passed + chm.resetFailures() + chm.updateClientState(schemas.MCPConnectionStateConnected) + } +} + +// updateClientState updates the client's connection state +func (chm *ClientHealthMonitor) updateClientState(state schemas.MCPConnectionState) { + chm.manager.mu.Lock() + clientState, exists := chm.manager.clientMap[chm.clientID] + if !exists { + chm.manager.mu.Unlock() + return + } + + // Only update if state changed + stateChanged := clientState.State != state + if stateChanged { + clientState.State = state + } + chm.manager.mu.Unlock() + + // Log after releasing the lock + if stateChanged { + logger.Info(fmt.Sprintf("%s Client %s connection state changed to: %s", MCPLogPrefix, chm.clientID, state)) + } +} + +// incrementFailures increments the consecutive failure counter +func (chm *ClientHealthMonitor) incrementFailures() { + chm.mu.Lock() + defer chm.mu.Unlock() + chm.consecutiveFailures++ +} + +// resetFailures resets the consecutive failure counter +func (chm *ClientHealthMonitor) resetFailures() { + chm.mu.Lock() + defer chm.mu.Unlock() + chm.consecutiveFailures = 0 +} + +// getConsecutiveFailures returns the current consecutive failure count +func (chm *ClientHealthMonitor) getConsecutiveFailures() int { + chm.mu.Lock() + defer chm.mu.Unlock() + return chm.consecutiveFailures +} + +// HealthMonitorManager manages all client health monitors +type HealthMonitorManager struct { + monitors map[string]*ClientHealthMonitor + mu sync.RWMutex +} + +// NewHealthMonitorManager creates a new health monitor manager +func NewHealthMonitorManager() *HealthMonitorManager { + return &HealthMonitorManager{ + monitors: make(map[string]*ClientHealthMonitor), + } +} + +// StartMonitoring starts monitoring a specific client +func (hmm *HealthMonitorManager) StartMonitoring(monitor *ClientHealthMonitor) { + hmm.mu.Lock() + defer hmm.mu.Unlock() + + // Stop any existing monitor for this client + if existing, ok := hmm.monitors[monitor.clientID]; ok { + existing.Stop() + } + + hmm.monitors[monitor.clientID] = monitor + monitor.Start() +} + +// StopMonitoring stops monitoring a specific client +func (hmm *HealthMonitorManager) StopMonitoring(clientID string) { + hmm.mu.Lock() + defer hmm.mu.Unlock() + + if monitor, ok := hmm.monitors[clientID]; ok { + monitor.Stop() + delete(hmm.monitors, clientID) + } +} + +// StopAll stops all monitoring +func (hmm *HealthMonitorManager) StopAll() { + hmm.mu.Lock() + defer hmm.mu.Unlock() + + for _, monitor := range hmm.monitors { + monitor.Stop() + } + hmm.monitors = make(map[string]*ClientHealthMonitor) +} diff --git a/core/mcp/init.go b/core/mcp/init.go new file mode 100644 index 0000000000..d0eb389c18 --- /dev/null +++ b/core/mcp/init.go @@ -0,0 +1,9 @@ +package mcp + +import "github.com/maximhq/bifrost/core/schemas" + +var logger schemas.Logger + +func SetLogger(l schemas.Logger) { + logger = l +} diff --git a/core/mcp/mcp.go b/core/mcp/mcp.go new file mode 100644 index 0000000000..89e9603eb7 --- /dev/null +++ b/core/mcp/mcp.go @@ -0,0 +1,288 @@ +package mcp + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/maximhq/bifrost/core/schemas" + + "github.com/mark3labs/mcp-go/server" +) + +// ============================================================================ +// CONSTANTS +// ============================================================================ + +const ( + // MCP defaults and identifiers + BifrostMCPVersion = "1.0.0" // Version identifier for Bifrost + BifrostMCPClientName = "BifrostClient" // Name for internal Bifrost MCP client + BifrostMCPClientKey = "bifrostInternal" // Key for internal Bifrost client in clientMap + MCPLogPrefix = "[Bifrost MCP]" // Consistent logging prefix + MCPClientConnectionEstablishTimeout = 30 * time.Second // Timeout for MCP client connection establishment + + // Context keys for client filtering in requests + // NOTE: []string is used for both keys, and by default all clients/tools are included (when nil). + // If "*" is present, all clients/tools are included, and [] means no clients/tools are included. + // Request context filtering takes priority over client config - context can override client exclusions. + MCPContextKeyIncludeClients schemas.BifrostContextKey = "mcp-include-clients" // Context key for whitelist client filtering + MCPContextKeyIncludeTools schemas.BifrostContextKey = "mcp-include-tools" // Context key for whitelist tool filtering (Note: toolName should be in "clientName/toolName" format) +) + +// ============================================================================ +// TYPE DEFINITIONS +// ============================================================================ + +// MCPManager manages MCP integration for Bifrost core. +// It provides a bridge between Bifrost and various MCP servers, supporting +// both local tool hosting and external MCP server connections. +type MCPManager struct { + ctx context.Context + toolsManager *ToolsManager // Handler for MCP tools + server *server.MCPServer // Local MCP server instance for hosting tools (STDIO-based) + clientMap map[string]*schemas.MCPClientState // Map of MCP client names to their configurations + mu sync.RWMutex // Read-write mutex for thread-safe operations + serverRunning bool // Track whether local MCP server is running + healthMonitorManager *HealthMonitorManager // Manager for client health monitors +} + +// MCPToolFunction is a generic function type for handling tool calls with typed arguments. +// T represents the expected argument structure for the tool. +type MCPToolFunction[T any] func(args T) (string, error) + +// ============================================================================ +// CONSTRUCTOR AND INITIALIZATION +// ============================================================================ + +// NewMCPManager creates and initializes a new MCP manager instance. +// +// Parameters: +// - config: MCP configuration including server port and client configs +// - logger: Logger instance for structured logging (uses default if nil) +// +// Returns: +// - *MCPManager: Initialized manager instance +// - error: Any initialization error +func NewMCPManager(ctx context.Context, config schemas.MCPConfig, logger schemas.Logger) *MCPManager { + SetLogger(logger) + // Set default values + if config.ToolManagerConfig == nil { + config.ToolManagerConfig = &schemas.MCPToolManagerConfig{ + ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout, + MaxAgentDepth: schemas.DefaultMaxAgentDepth, + } + } + // Creating new instance + manager := &MCPManager{ + ctx: ctx, + clientMap: make(map[string]*schemas.MCPClientState), + healthMonitorManager: NewHealthMonitorManager(), + } + manager.toolsManager = NewToolsManager(config.ToolManagerConfig, manager, config.FetchNewRequestIDFunc) + // Process client configs: create client map entries and establish connections + if len(config.ClientConfigs) > 0 { + for _, clientConfig := range config.ClientConfigs { + if err := manager.AddClient(clientConfig); err != nil { + logger.Warn(fmt.Sprintf("%s Failed to add MCP client %s: %v", MCPLogPrefix, clientConfig.Name, err)) + } + } + } + logger.Info(MCPLogPrefix + " MCP Manager initialized") + return manager +} + +// AddToolsToRequest parses available MCP tools from the context and adds them to the request. +// It respects context-based filtering for clients and tools, and returns the modified request +// with tools attached. +// +// Parameters: +// - ctx: Context containing optional client/tool filtering keys +// - req: The Bifrost request to add tools to +// +// Returns: +// - *schemas.BifrostRequest: The request with tools added +func (m *MCPManager) AddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { + return m.toolsManager.ParseAndAddToolsToRequest(ctx, req) +} + +func (m *MCPManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool { + return m.toolsManager.GetAvailableTools(ctx) +} + +// ExecuteChatTool executes a single tool call and returns the result as a chat message. +// This is the primary tool executor and is used by both Chat Completions and Responses APIs. +// +// The method accepts tool calls in Chat API format (ChatAssistantMessageToolCall) and returns +// results in Chat API format (ChatMessage). For Responses API users: +// - Convert ResponsesToolMessage to ChatAssistantMessageToolCall using ToChatAssistantMessageToolCall() +// - Execute the tool with this method +// - Convert the result back using ChatMessage.ToResponsesToolMessage() +// +// Alternatively, use ExecuteResponsesTool() in the ToolsManager for a type-safe wrapper +// that handles format conversions automatically. +// +// Parameters: +// - ctx: Context for the tool execution +// - toolCall: The tool call to execute in Chat API format +// +// Returns: +// - *schemas.ChatMessage: The result message containing tool execution output +// - error: Any error that occurred during tool execution +func (m *MCPManager) ExecuteChatTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + return m.toolsManager.ExecuteChatTool(ctx, toolCall) +} + +// ExecuteResponsesTool executes a single tool call and returns the result as a responses message. + +// - ctx: Context for the tool execution +// - toolCall: The tool call to execute in Responses API format +// +// Returns: +// - *schemas.ResponsesMessage: The result message containing tool execution output +// - error: Any error that occurred during tool execution +func (m *MCPManager) ExecuteResponsesTool(ctx *schemas.BifrostContext, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, error) { + return m.toolsManager.ExecuteResponsesTool(ctx, toolCall) +} + +// UpdateToolManagerConfig updates the configuration for the tool manager. +// This allows runtime updates to settings like execution timeout and max agent depth. +// +// Parameters: +// - config: The new tool manager configuration to apply +func (m *MCPManager) UpdateToolManagerConfig(config *schemas.MCPToolManagerConfig) { + m.toolsManager.UpdateConfig(config) +} + +// CheckAndExecuteAgentForChatRequest checks if the chat response contains tool calls, +// and if so, executes agent mode to handle the tool calls iteratively. If no tool calls +// are present, it returns the original response unchanged. +// +// Agent mode enables autonomous tool execution where: +// 1. Tool calls are automatically executed +// 2. Results are fed back to the LLM +// 3. The loop continues until no more tool calls are made or max depth is reached +// 4. Non-auto-executable tools are returned to the caller +// +// This method is available for both Chat Completions and Responses APIs. +// For Responses API, use CheckAndExecuteAgentForResponsesRequest(). +// +// Parameters: +// - ctx: Context for the agent execution +// - req: The original chat request +// - response: The initial chat response that may contain tool calls +// - makeReq: Function to make subsequent chat requests during agent execution +// +// Returns: +// - *schemas.BifrostChatResponse: The final response after agent execution (or original if no tool calls) +// - *schemas.BifrostError: Any error that occurred during agent execution +func (m *MCPManager) CheckAndExecuteAgentForChatRequest( + ctx *schemas.BifrostContext, + req *schemas.BifrostChatRequest, + response *schemas.BifrostChatResponse, + makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), +) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + if makeReq == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "makeReq is required to execute agent mode", + }, + } + } + // Check if initial response has tool calls + if !hasToolCallsForChatResponse(response) { + logger.Debug("No tool calls detected, returning response") + return response, nil + } + // Execute agent mode + return m.toolsManager.ExecuteAgentForChatRequest(ctx, req, response, makeReq) +} + +// CheckAndExecuteAgentForResponsesRequest checks if the responses response contains tool calls, +// and if so, executes agent mode to handle the tool calls iteratively. If no tool calls +// are present, it returns the original response unchanged. +// +// Agent mode for Responses API works identically to Chat API: +// 1. Detects tool calls in the response (function_call messages) +// 2. Automatically executes tools in parallel when possible +// 3. Feeds results back to the LLM in Responses API format +// 4. Continues the loop until no more tool calls or max depth reached +// 5. Returns non-auto-executable tools to the caller +// +// Format Handling: +// This method automatically handles format conversions: +// - Responses tool calls (ResponsesToolMessage) are converted to Chat format for execution +// - Tool execution results are converted back to Responses format (ResponsesMessage) +// - All conversions use the adapters in agent_adaptors.go and converters in schemas/mux.go +// +// This provides full feature parity between Chat Completions and Responses APIs for tool execution. +// +// Parameters: +// - ctx: Context for the agent execution +// - req: The original responses request +// - response: The initial responses response that may contain tool calls +// - makeReq: Function to make subsequent responses requests during agent execution +// +// Returns: +// - *schemas.BifrostResponsesResponse: The final response after agent execution (or original if no tool calls) +// - *schemas.BifrostError: Any error that occurred during agent execution +func (m *MCPManager) CheckAndExecuteAgentForResponsesRequest( + ctx *schemas.BifrostContext, + req *schemas.BifrostResponsesRequest, + response *schemas.BifrostResponsesResponse, + makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError), +) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + if makeReq == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "makeReq is required to execute agent mode", + }, + } + } + // Check if initial response has tool calls + if !hasToolCallsForResponsesResponse(response) { + logger.Debug("No tool calls detected, returning response") + return response, nil + } + // Execute agent mode + return m.toolsManager.ExecuteAgentForResponsesRequest(ctx, req, response, makeReq) +} + +// Cleanup performs cleanup of all MCP resources including clients and local server. +// This function safely disconnects all MCP clients (HTTP, STDIO, and SSE) and +// cleans up the local MCP server. It handles proper cancellation of SSE contexts +// and closes all transport connections. +// +// Returns: +// - error: Always returns nil, but maintains error interface for consistency +func (m *MCPManager) Cleanup() error { + // Stop all health monitors first + m.healthMonitorManager.StopAll() + + m.mu.Lock() + defer m.mu.Unlock() + + // Disconnect all external MCP clients + for id := range m.clientMap { + if err := m.removeClientUnsafe(id); err != nil { + logger.Error("%s Failed to remove MCP client %s: %v", MCPLogPrefix, id, err) + } + } + + // Clear the client map + m.clientMap = make(map[string]*schemas.MCPClientState) + + // Clear local server reference + // Note: mark3labs/mcp-go STDIO server cleanup is handled automatically + if m.server != nil { + logger.Info(MCPLogPrefix + " Clearing local MCP server reference") + m.server = nil + m.serverRunning = false + } + + logger.Info(MCPLogPrefix + " MCP cleanup completed") + return nil +} diff --git a/core/mcp/toolmanager.go b/core/mcp/toolmanager.go new file mode 100644 index 0000000000..3197b7fd55 --- /dev/null +++ b/core/mcp/toolmanager.go @@ -0,0 +1,556 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/maximhq/bifrost/core/schemas" +) + +type ClientManager interface { + GetClientByName(clientName string) *schemas.MCPClientState + GetClientForTool(toolName string) *schemas.MCPClientState + GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool +} + +type ToolsManager struct { + toolExecutionTimeout atomic.Value + maxAgentDepth atomic.Int32 + codeModeBindingLevel atomic.Value // Stores CodeModeBindingLevel + clientManager ClientManager + logMu sync.Mutex // Protects concurrent access to logs slice in codemode execution + + // Function to fetch a new request ID for each tool call result message in agent mode, + // this is used to ensure that the tool call result messages are unique and can be tracked in plugins or by the user. + // This id is attached to ctx.Value(schemas.BifrostContextKeyRequestID) in the agent mode. + // If not provided, same request ID is used for all tool call result messages without any overrides. + fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string +} + +const ( + ToolTypeListToolFiles string = "listToolFiles" + ToolTypeReadToolFile string = "readToolFile" + ToolTypeExecuteToolCode string = "executeToolCode" +) + +// NewToolsManager creates and initializes a new tools manager instance. +// It validates the configuration, sets defaults if needed, and initializes atomic values +// for thread-safe configuration updates. +// +// Parameters: +// - config: Tool manager configuration with execution timeout and max agent depth +// - clientManager: Client manager interface for accessing MCP clients and tools +// - fetchNewRequestIDFunc: Optional function to generate unique request IDs for agent mode +// +// Returns: +// - *ToolsManager: Initialized tools manager instance +func NewToolsManager(config *schemas.MCPToolManagerConfig, clientManager ClientManager, fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string) *ToolsManager { + if config == nil { + config = &schemas.MCPToolManagerConfig{ + ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout, + MaxAgentDepth: schemas.DefaultMaxAgentDepth, + CodeModeBindingLevel: schemas.CodeModeBindingLevelServer, + } + } + if config.MaxAgentDepth <= 0 { + config.MaxAgentDepth = schemas.DefaultMaxAgentDepth + } + if config.ToolExecutionTimeout <= 0 { + config.ToolExecutionTimeout = schemas.DefaultToolExecutionTimeout + } + // Default to server-level binding if not specified + if config.CodeModeBindingLevel == "" { + config.CodeModeBindingLevel = schemas.CodeModeBindingLevelServer + } + manager := &ToolsManager{ + clientManager: clientManager, + fetchNewRequestIDFunc: fetchNewRequestIDFunc, + } + // Initialize atomic values + manager.toolExecutionTimeout.Store(config.ToolExecutionTimeout) + manager.maxAgentDepth.Store(int32(config.MaxAgentDepth)) + manager.codeModeBindingLevel.Store(config.CodeModeBindingLevel) + + logger.Info(fmt.Sprintf("%s tool manager initialized with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel)) + return manager +} + +// GetAvailableTools returns the available tools for the given context. +func (m *ToolsManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool { + availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) + // Flatten tools from all clients into a single slice, avoiding duplicates + var availableTools []schemas.ChatTool + var includeCodeModeTools bool + // Track tool names to prevent duplicates + seenToolNames := make(map[string]bool) + + for clientName, clientTools := range availableToolsPerClient { + client := m.clientManager.GetClientByName(clientName) + if client == nil { + logger.Warn(fmt.Sprintf("%s Client %s not found, skipping", MCPLogPrefix, clientName)) + continue + } + if client.ExecutionConfig.IsCodeModeClient { + includeCodeModeTools = true + } else { + // Add tools from this client, checking for duplicates + for _, tool := range clientTools { + if tool.Function != nil && tool.Function.Name != "" { + if !seenToolNames[tool.Function.Name] { + availableTools = append(availableTools, tool) + seenToolNames[tool.Function.Name] = true + } + } + } + } + } + + if includeCodeModeTools { + codeModeTools := []schemas.ChatTool{ + m.createListToolFilesTool(), + m.createReadToolFileTool(), + m.createExecuteToolCodeTool(), + } + // Add code mode tools, checking for duplicates + for _, tool := range codeModeTools { + if tool.Function != nil && tool.Function.Name != "" { + if !seenToolNames[tool.Function.Name] { + availableTools = append(availableTools, tool) + seenToolNames[tool.Function.Name] = true + } + } + } + } + + return availableTools +} + +// buildIntegrationDuplicateCheckMap builds a map of tool names to check for duplicates +// based on the integration user agent. This includes both direct tool names and +// integration-specific naming patterns from existing tools in the request. +// +// Parameters: +// - existingTools: List of existing tools in the request +// - integrationUserAgent: Integration user agent string (e.g., "claude-cli") +// +// Returns: +// - map[string]bool: Map of tool names/patterns to check against +func buildIntegrationDuplicateCheckMap(existingTools []schemas.ChatTool, integrationUserAgent string) map[string]bool { + duplicateCheckMap := make(map[string]bool) + + // Add direct tool names + for _, tool := range existingTools { + if tool.Function != nil && tool.Function.Name != "" { + duplicateCheckMap[tool.Function.Name] = true + } + } + + // Add integration-specific patterns from existing tools + switch integrationUserAgent { + case "claude-cli": + // Claude CLI uses pattern: mcp__{foreign_name}__{tool_name} + // The middle part is a foreign name we cannot check for, so we extract the last part + // Examples: + // mcp__bifrost__executeToolCode -> executeToolCode + // mcp__bifrost__listToolFiles -> listToolFiles + // mcp__bifrost__readToolFile -> readToolFile + // mcp__calculator__calculator_add -> calculator_add + for _, tool := range existingTools { + if tool.Function != nil && tool.Function.Name != "" { + existingToolName := tool.Function.Name + // Check if existing tool matches Claude CLI pattern: mcp__*__{tool_name} + if strings.HasPrefix(existingToolName, "mcp__") { + // Split on __ and take the last entry (the tool_name) + parts := strings.Split(existingToolName, "__") + if len(parts) >= 3 { + toolName := parts[len(parts)-1] // Last part is the tool name + // Map Claude CLI pattern back to our tool name format + // This handles both regular MCP tools and code mode tools + if toolName != "" { + duplicateCheckMap[toolName] = true + // Also keep the original pattern for direct matching + duplicateCheckMap[existingToolName] = true + } + } + } + } + } + // Add more integration-specific patterns here as needed + // case "another-integration": + // // Add patterns for other integrations + } + + return duplicateCheckMap +} + +// ParseAndAddToolsToRequest parses the available tools per client and adds them to the Bifrost request. +// +// Parameters: +// - ctx: Execution context +// - req: Bifrost request +// - availableToolsPerClient: Map of client name to its available tools +// +// Returns: +// - *schemas.BifrostRequest: Bifrost request with MCP tools added +func (m *ToolsManager) ParseAndAddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { + // MCP is only supported for chat and responses requests + if req.ChatRequest == nil && req.ResponsesRequest == nil { + return req + } + + availableTools := m.GetAvailableTools(ctx) + + if len(availableTools) == 0 { + return req + } + + // Get integration user agent for duplicate checking + var integrationUserAgentStr string + integrationUserAgent := ctx.Value(schemas.BifrostContextKey("integration-user-agent")) + if integrationUserAgent != nil { + if str, ok := integrationUserAgent.(string); ok { + integrationUserAgentStr = str + } + } + + if len(availableTools) > 0 { + switch req.RequestType { + case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: + // Only allocate new Params if it's nil to preserve caller-supplied settings + if req.ChatRequest.Params == nil { + req.ChatRequest.Params = &schemas.ChatParameters{} + } + + tools := req.ChatRequest.Params.Tools + + // Build integration-aware duplicate check map + duplicateCheckMap := buildIntegrationDuplicateCheckMap(tools, integrationUserAgentStr) + + // Add MCP tools that are not already present + for _, mcpTool := range availableTools { + // Skip tools with nil Function or empty Name + if mcpTool.Function == nil || mcpTool.Function.Name == "" { + continue + } + + toolName := mcpTool.Function.Name + + // Check for duplicates using integration-aware logic + if !duplicateCheckMap[toolName] { + tools = append(tools, mcpTool) + // Update the map to prevent duplicates within MCP tools as well + duplicateCheckMap[toolName] = true + } + } + req.ChatRequest.Params.Tools = tools + case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + // Only allocate new Params if it's nil to preserve caller-supplied settings + if req.ResponsesRequest.Params == nil { + req.ResponsesRequest.Params = &schemas.ResponsesParameters{} + } + + tools := req.ResponsesRequest.Params.Tools + + // Convert Responses tools to ChatTool format for duplicate checking + existingChatTools := make([]schemas.ChatTool, 0, len(tools)) + for _, tool := range tools { + if tool.Name != nil { + existingChatTools = append(existingChatTools, schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: *tool.Name, + }, + }) + } + } + + // Build integration-aware duplicate check map + duplicateCheckMap := buildIntegrationDuplicateCheckMap(existingChatTools, integrationUserAgentStr) + + // Add MCP tools that are not already present + for _, mcpTool := range availableTools { + // Skip tools with nil Function or empty Name + if mcpTool.Function == nil || mcpTool.Function.Name == "" { + continue + } + + toolName := mcpTool.Function.Name + + // Check for duplicates using integration-aware logic + if !duplicateCheckMap[toolName] { + responsesTool := mcpTool.ToResponsesTool() + // Skip if the converted tool has nil Name + if responsesTool.Name == nil { + continue + } + + tools = append(tools, *responsesTool) + // Update the map to prevent duplicates within MCP tools as well + duplicateCheckMap[toolName] = true + } + } + req.ResponsesRequest.Params.Tools = tools + } + } + return req +} + +// ============================================================================ +// TOOL REGISTRATION AND DISCOVERY +// ============================================================================ + +// ExecuteChatTool executes a tool call in Chat Completions API format and returns the result as a chat tool message. +// This is the primary tool executor that works with both Chat Completions and Responses APIs. +// +// For Responses API users, use ExecuteResponsesTool() for a more type-safe interface. +// However, internally this method is format-agnostic - it executes the tool and returns +// a ChatMessage which can then be converted to ResponsesMessage via ToResponsesToolMessage(). +// +// Parameters: +// - ctx: Execution context +// - toolCall: The tool call to execute (from assistant message) +// +// Returns: +// - *schemas.ChatMessage: Tool message with execution result +// - error: Any execution error +func (m *ToolsManager) ExecuteChatTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { + if toolCall.Function.Name == nil { + return nil, fmt.Errorf("tool call missing function name") + } + toolName := *toolCall.Function.Name + + // Handle code mode tools + switch toolName { + case ToolTypeListToolFiles: + return m.handleListToolFiles(ctx, toolCall) + case ToolTypeReadToolFile: + return m.handleReadToolFile(ctx, toolCall) + case ToolTypeExecuteToolCode: + return m.handleExecuteToolCode(ctx, toolCall) + default: + // Check if the user has permission to execute the tool call + availableTools := m.clientManager.GetToolPerClient(ctx) + toolFound := false + for _, tools := range availableTools { + for _, mcpTool := range tools { + if mcpTool.Function != nil && mcpTool.Function.Name == toolName { + toolFound = true + break + } + } + if toolFound { + break + } + } + + if !toolFound { + return nil, fmt.Errorf("tool '%s' is not available or not permitted", toolName) + } + + client := m.clientManager.GetClientForTool(toolName) + if client == nil { + return nil, fmt.Errorf("client not found for tool %s", toolName) + } + + // Parse tool arguments + var arguments map[string]interface{} + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + return nil, fmt.Errorf("failed to parse tool arguments for '%s': %v", toolName, err) + } + + // Strip the client name prefix from tool name before calling MCP server + // The MCP server expects the original tool name, not the prefixed version + originalToolName := stripClientPrefix(toolName, client.ExecutionConfig.Name) + + // Call the tool via MCP client -> MCP server + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsCall), + }, + Params: mcp.CallToolParams{ + Name: originalToolName, + Arguments: arguments, + }, + } + + logger.Debug(fmt.Sprintf("%s Starting tool execution: %s via client: %s", MCPLogPrefix, toolName, client.ExecutionConfig.Name)) + + // Create timeout context for tool execution + toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) + toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) + if callErr != nil { + // Check if it was a timeout error + if toolCtx.Err() == context.DeadlineExceeded { + return nil, fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName) + } + logger.Error("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.ExecutionConfig.Name, callErr) + return nil, fmt.Errorf("MCP tool call failed: %v", callErr) + } + + logger.Debug(fmt.Sprintf("%s Tool execution completed: %s", MCPLogPrefix, toolName)) + + // Extract text from MCP response + responseText := extractTextFromMCPResponse(toolResponse, toolName) + + // Create tool response message + return createToolResponseMessage(toolCall, responseText), nil + } +} + +// ExecuteToolForResponses executes a tool call from a Responses API tool message and returns +// the result in Responses API format. This is a type-safe wrapper around ExecuteTool that +// handles the conversion between Responses and Chat API formats. +// +// This method: +// 1. Converts the Responses tool message to Chat API format +// 2. Executes the tool using the standard tool executor +// 3. Converts the result back to Responses API format +// +// Parameters: +// - ctx: Execution context +// - toolMessage: The Responses API tool message to execute +// - callID: The original call ID from the Responses API +// +// Returns: +// - *schemas.ResponsesMessage: Tool result message in Responses API format +// - error: Any execution error +// +// Example: +// +// responsesToolMsg := &schemas.ResponsesToolMessage{ +// Name: Ptr("calculate"), +// Arguments: Ptr("{\"x\": 10, \"y\": 20}"), +// } +// resultMsg, err := toolsManager.ExecuteResponsesTool(ctx, responsesToolMsg, "call-123") +// // resultMsg is a ResponsesMessage with type=function_call_output +func (m *ToolsManager) ExecuteResponsesTool( + ctx *schemas.BifrostContext, + toolMessage *schemas.ResponsesToolMessage, +) (*schemas.ResponsesMessage, error) { + if toolMessage == nil { + return nil, fmt.Errorf("tool message is nil") + } + if toolMessage.Name == nil { + return nil, fmt.Errorf("tool call missing function name") + } + + // Convert Responses format to Chat format for execution + chatToolCall := toolMessage.ToChatAssistantMessageToolCall() + if chatToolCall == nil { + return nil, fmt.Errorf("failed to convert Responses tool message to Chat format") + } + + // Execute the tool using the standard executor + chatResult, err := m.ExecuteChatTool(ctx, *chatToolCall) + if err != nil { + return nil, err + } + + // Convert the result back to Responses format + responsesMessage := chatResult.ToResponsesToolMessage() + if responsesMessage == nil { + return nil, fmt.Errorf("failed to convert tool result to Responses format") + } + + return responsesMessage, nil +} + +// ExecuteAgentForChatRequest executes agent mode for a chat request, handling +// iterative tool calls up to the configured maximum depth. It delegates to the +// shared agent execution logic with the manager's configuration and dependencies. +// +// Parameters: +// - ctx: Context for agent execution +// - req: The original chat request +// - resp: The initial chat response containing tool calls +// - makeReq: Function to make subsequent chat requests during agent execution +// +// Returns: +// - *schemas.BifrostChatResponse: The final response after agent execution +// - *schemas.BifrostError: Any error that occurred during agent execution +func (m *ToolsManager) ExecuteAgentForChatRequest( + ctx *schemas.BifrostContext, + req *schemas.BifrostChatRequest, + resp *schemas.BifrostChatResponse, + makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), +) (*schemas.BifrostChatResponse, *schemas.BifrostError) { + return ExecuteAgentForChatRequest( + ctx, + int(m.maxAgentDepth.Load()), + req, + resp, + makeReq, + m.fetchNewRequestIDFunc, + m.ExecuteChatTool, + m.clientManager, + ) +} + +// ExecuteAgentForResponsesRequest executes agent mode for a responses request, handling +// iterative tool calls up to the configured maximum depth. It delegates to the +// shared agent execution logic with the manager's configuration and dependencies. +// +// Parameters: +// - ctx: Context for agent execution +// - req: The original responses request +// - resp: The initial responses response containing tool calls +// - makeReq: Function to make subsequent responses requests during agent execution +// +// Returns: +// - *schemas.BifrostResponsesResponse: The final response after agent execution +// - *schemas.BifrostError: Any error that occurred during agent execution +func (m *ToolsManager) ExecuteAgentForResponsesRequest( + ctx *schemas.BifrostContext, + req *schemas.BifrostResponsesRequest, + resp *schemas.BifrostResponsesResponse, + makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError), +) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { + return ExecuteAgentForResponsesRequest( + ctx, + int(m.maxAgentDepth.Load()), + req, + resp, + makeReq, + m.fetchNewRequestIDFunc, + m.ExecuteChatTool, + m.clientManager, + ) +} + +// UpdateConfig updates tool manager configuration atomically. +// This method is safe to call concurrently from multiple goroutines. +func (m *ToolsManager) UpdateConfig(config *schemas.MCPToolManagerConfig) { + if config == nil { + return + } + if config.ToolExecutionTimeout > 0 { + m.toolExecutionTimeout.Store(config.ToolExecutionTimeout) + } + if config.MaxAgentDepth > 0 { + m.maxAgentDepth.Store(int32(config.MaxAgentDepth)) + } + if config.CodeModeBindingLevel != "" { + m.codeModeBindingLevel.Store(config.CodeModeBindingLevel) + } + + logger.Info(fmt.Sprintf("%s tool manager configuration updated with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel)) +} + +// GetCodeModeBindingLevel returns the current code mode binding level. +// This method is safe to call concurrently from multiple goroutines. +func (m *ToolsManager) GetCodeModeBindingLevel() schemas.CodeModeBindingLevel { + val := m.codeModeBindingLevel.Load() + if val == nil { + return schemas.CodeModeBindingLevelServer + } + return val.(schemas.CodeModeBindingLevel) +} diff --git a/core/mcp/utils.go b/core/mcp/utils.go new file mode 100644 index 0000000000..3fe8b9e7c7 --- /dev/null +++ b/core/mcp/utils.go @@ -0,0 +1,567 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "maps" + "regexp" + "slices" + "strings" + "unicode" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" + "github.com/maximhq/bifrost/core/schemas" +) + +// GetClientForTool safely finds a client that has the specified tool. +// Returns a copy of the client state to avoid data races. Callers should be aware +// that fields like Conn and ToolMap are still shared references and may be modified +// by other goroutines, but the struct itself is safe from concurrent modification. +func (m *MCPManager) GetClientForTool(toolName string) *schemas.MCPClientState { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, client := range m.clientMap { + if _, exists := client.ToolMap[toolName]; exists { + // Return a copy to prevent TOCTOU race conditions + // The caller receives a snapshot of the client state at this point in time + clientCopy := *client + return &clientCopy + } + } + return nil +} + +// GetToolPerClient returns all tools from connected MCP clients. +// Applies client filtering if specified in the context. +// Returns a map of client name to its available tools. +// Parameters: +// - ctx: Execution context +// +// Returns: +// - map[string][]schemas.ChatTool: Map of client name to its available tools +func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool { + m.mu.RLock() + defer m.mu.RUnlock() + + var includeClients []string + + // Extract client filtering from request context + if existingIncludeClients, ok := ctx.Value(MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil { + includeClients = existingIncludeClients + } + + tools := make(map[string][]schemas.ChatTool) + for _, client := range m.clientMap { + // Use client name as the key (not ID) + clientName := client.ExecutionConfig.Name + + // Apply client filtering logic + if !shouldIncludeClient(clientName, includeClients) { + logger.Debug(fmt.Sprintf("%s Skipping MCP client %s: not in include clients list", MCPLogPrefix, clientName)) + continue + } + + logger.Debug(fmt.Sprintf("Checking tools for MCP client %s with tools to execute: %v", clientName, client.ExecutionConfig.ToolsToExecute)) + + // Add all tools from this client + for toolName, tool := range client.ToolMap { + // Check if tool should be skipped based on client configuration + if shouldSkipToolForConfig(toolName, client.ExecutionConfig) { + logger.Debug(fmt.Sprintf("%s Skipping MCP tool %s: not in tools to execute list", MCPLogPrefix, toolName)) + continue + } + + // Check if tool should be skipped based on request context + if shouldSkipToolForRequest(ctx, clientName, toolName) { + logger.Debug(fmt.Sprintf("%s Skipping MCP tool %s: not in include tools list", MCPLogPrefix, toolName)) + continue + } + + tools[clientName] = append(tools[clientName], tool) + } + if len(tools[clientName]) > 0 { + logger.Debug(fmt.Sprintf("%s Added %d tools for MCP client %s", MCPLogPrefix, len(tools[clientName]), clientName)) + } + } + return tools +} + +// GetClientByName returns a client by name. +// +// Parameters: +// - clientName: Name of the client to get +// +// Returns: +// - *schemas.MCPClientState: Client state if found, nil otherwise +func (m *MCPManager) GetClientByName(clientName string) *schemas.MCPClientState { + m.mu.RLock() + defer m.mu.RUnlock() + for _, client := range m.clientMap { + if client.ExecutionConfig.Name == clientName { + // Return a copy to prevent TOCTOU race conditions + // The caller receives a snapshot of the client state at this point in time + clientCopy := *client + return &clientCopy + } + } + return nil +} + +// retrieveExternalTools retrieves and filters tools from an external MCP server without holding locks. +func retrieveExternalTools(ctx context.Context, client *client.Client, clientName string) (map[string]schemas.ChatTool, error) { + // Get available tools from external server + listRequest := mcp.ListToolsRequest{ + PaginatedRequest: mcp.PaginatedRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsList), + }, + }, + } + + toolsResponse, err := client.ListTools(ctx, listRequest) + if err != nil { + return nil, fmt.Errorf("failed to list tools: %v", err) + } + + if toolsResponse == nil { + return make(map[string]schemas.ChatTool), nil // No tools available + } + + tools := make(map[string]schemas.ChatTool) + + // toolsResponse is already a ListToolsResult + for _, mcpTool := range toolsResponse.Tools { + // Convert MCP tool schema to Bifrost format + bifrostTool := convertMCPToolToBifrostSchema(&mcpTool) + // Prefix tool name with client name to make it permanent + prefixedToolName := fmt.Sprintf("%s_%s", clientName, mcpTool.Name) + // Update the tool's function name to match the prefixed name + if bifrostTool.Function != nil { + bifrostTool.Function.Name = prefixedToolName + } + tools[prefixedToolName] = bifrostTool + } + + return tools, nil +} + +// shouldIncludeClient determines if a client should be included based on filtering rules. +func shouldIncludeClient(clientName string, includeClients []string) bool { + // If includeClients is specified (not nil), apply whitelist filtering + if includeClients != nil { + // Handle empty array [] - means no clients are included + if len(includeClients) == 0 { + return false // No clients allowed + } + + // Handle wildcard "*" - if present, all clients are included + if slices.Contains(includeClients, "*") { + return true // All clients allowed + } + + // Check if specific client is in the list + return slices.Contains(includeClients, clientName) + } + + // Default: include all clients when no filtering specified (nil case) + return true +} + +// shouldSkipToolForConfig checks if a tool should be skipped based on client configuration (without accessing clientMap). +func shouldSkipToolForConfig(toolName string, config schemas.MCPClientConfig) bool { + // If ToolsToExecute is specified (not nil), apply filtering + if config.ToolsToExecute != nil { + // Handle empty array [] - means no tools are allowed + if len(config.ToolsToExecute) == 0 { + return true // No tools allowed + } + + // Handle wildcard "*" - if present, all tools are allowed + if slices.Contains(config.ToolsToExecute, "*") { + return false // All tools allowed + } + + // Check if specific tool is in the allowed list + return !slices.Contains(config.ToolsToExecute, toolName) // Tool not in allowed list + } + + return true // Tool is skipped (nil is treated as [] - no tools) +} + +// canAutoExecuteTool checks if a tool can be auto-executed based on client configuration. +// Returns true if the tool can be auto-executed, false otherwise. +func canAutoExecuteTool(toolName string, config schemas.MCPClientConfig) bool { + // First check if tool is in ToolsToExecute (must be executable first) + if shouldSkipToolForConfig(toolName, config) { + return false // Tool is not in ToolsToExecute, so it cannot be auto-executed + } + + // If ToolsToAutoExecute is specified (not nil), apply filtering + if config.ToolsToAutoExecute != nil { + // Handle empty array [] - means no tools are auto-executed + if len(config.ToolsToAutoExecute) == 0 { + return false // No tools auto-executed + } + + // Handle wildcard "*" - if present, all tools are auto-executed + if slices.Contains(config.ToolsToAutoExecute, "*") { + return true // All tools auto-executed + } + + // Check if specific tool is in the auto-execute list + return slices.Contains(config.ToolsToAutoExecute, toolName) + } + + return false // Tool is not auto-executed (nil is treated as [] - no tools) +} + +// shouldSkipToolForRequest checks if a tool should be skipped based on the request context. +func shouldSkipToolForRequest(ctx context.Context, clientName, toolName string) bool { + includeTools := ctx.Value(MCPContextKeyIncludeTools) + + if includeTools != nil { + // Try []string first (preferred type) + if includeToolsList, ok := includeTools.([]string); ok { + // Handle empty array [] - means no tools are included + if len(includeToolsList) == 0 { + return true // No tools allowed + } + + // Handle wildcard "clientName/*" - if present, all tools are included for this client + if slices.Contains(includeToolsList, fmt.Sprintf("%s/*", clientName)) { + return false // All tools allowed + } + + // Check if specific tool is in the list (format: clientName/toolName) + fullToolName := fmt.Sprintf("%s/%s", clientName, toolName) + if slices.Contains(includeToolsList, fullToolName) { + return false // Tool is explicitly allowed + } + + // If includeTools is specified but this tool is not in it, skip it + return true + } + } + + return false // Tool is allowed (default when no filtering specified) +} + +// convertMCPToolToBifrostSchema converts an MCP tool definition to Bifrost format. +func convertMCPToolToBifrostSchema(mcpTool *mcp.Tool) schemas.ChatTool { + var properties *schemas.OrderedMap + if len(mcpTool.InputSchema.Properties) > 0 { + orderedProps := make(schemas.OrderedMap, len(mcpTool.InputSchema.Properties)) + maps.Copy(orderedProps, mcpTool.InputSchema.Properties) + properties = &orderedProps + } + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: mcpTool.Name, + Description: schemas.Ptr(mcpTool.Description), + Parameters: &schemas.ToolFunctionParameters{ + Type: mcpTool.InputSchema.Type, + Properties: properties, + Required: mcpTool.InputSchema.Required, + }, + }, + } +} + +// extractTextFromMCPResponse extracts text content from an MCP tool response. +func extractTextFromMCPResponse(toolResponse *mcp.CallToolResult, toolName string) string { + if toolResponse == nil { + return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) + } + + var result strings.Builder + for _, contentBlock := range toolResponse.Content { + // Handle typed content + switch content := contentBlock.(type) { + case mcp.TextContent: + result.WriteString(content.Text) + case mcp.ImageContent: + result.WriteString(fmt.Sprintf("[Image Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) + case mcp.AudioContent: + result.WriteString(fmt.Sprintf("[Audio Response: %s, MIME: %s]\n", content.Data, content.MIMEType)) + case mcp.EmbeddedResource: + result.WriteString(fmt.Sprintf("[Embedded Resource Response: %s]\n", content.Type)) + default: + // Fallback: try to extract from map structure + if jsonBytes, err := json.Marshal(contentBlock); err == nil { + var contentMap map[string]interface{} + if json.Unmarshal(jsonBytes, &contentMap) == nil { + if text, ok := contentMap["text"].(string); ok { + result.WriteString(fmt.Sprintf("[Text Response: %s]\n", text)) + continue + } + } + // Final fallback: serialize as JSON + result.WriteString(string(jsonBytes)) + } + } + } + + if result.Len() > 0 { + return strings.TrimSpace(result.String()) + } + return fmt.Sprintf("MCP tool '%s' executed successfully", toolName) +} + +// createToolResponseMessage creates a tool response message with the execution result. +func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage { + return &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: &responseText, + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: toolCall.ID, + }, + } +} + +// validateMCPClientConfig validates an MCP client configuration. +func validateMCPClientConfig(config *schemas.MCPClientConfig) error { + if strings.TrimSpace(config.ID) == "" { + return fmt.Errorf("id is required for MCP client config") + } + if err := validateMCPClientName(config.Name); err != nil { + return fmt.Errorf("invalid name for MCP client: %w", err) + } + if config.ConnectionType == "" { + return fmt.Errorf("connection type is required for MCP client config") + } + switch config.ConnectionType { + case schemas.MCPConnectionTypeHTTP: + if config.ConnectionString == nil { + return fmt.Errorf("ConnectionString is required for HTTP connection type in client '%s'", config.Name) + } + case schemas.MCPConnectionTypeSSE: + if config.ConnectionString == nil { + return fmt.Errorf("ConnectionString is required for SSE connection type in client '%s'", config.Name) + } + case schemas.MCPConnectionTypeSTDIO: + if config.StdioConfig == nil { + return fmt.Errorf("StdioConfig is required for STDIO connection type in client '%s'", config.Name) + } + case schemas.MCPConnectionTypeInProcess: + // InProcess requires a server instance to be provided programmatically + // This cannot be validated from JSON config - the server must be set when using the Go package + if config.InProcessServer == nil { + return fmt.Errorf("InProcessServer is required for InProcess connection type in client '%s' (Go package only)", config.Name) + } + default: + return fmt.Errorf("unknown connection type '%s' in client '%s'", config.ConnectionType, config.Name) + } + return nil +} + +func validateMCPClientName(name string) error { + if strings.TrimSpace(name) == "" { + return fmt.Errorf("name is required for MCP client") + } + for _, r := range name { + if r > 127 { // non-ASCII + return fmt.Errorf("name must contain only ASCII characters") + } + } + if strings.Contains(name, "-") { + return fmt.Errorf("name cannot contain hyphens") + } + if strings.Contains(name, " ") { + return fmt.Errorf("name cannot contain spaces") + } + if len(name) > 0 && name[0] >= '0' && name[0] <= '9' { + return fmt.Errorf("name cannot start with a number") + } + return nil +} + +// parseToolName parses the tool name to be JavaScript-compatible. +// It converts spaces and hyphens to underscores, removes invalid characters, and ensures +// the name starts with a valid JavaScript identifier character. +func parseToolName(toolName string) string { + if toolName == "" { + return "" + } + + var result strings.Builder + runes := []rune(toolName) + + // Process first character - must be letter, underscore, or dollar sign + if len(runes) > 0 { + first := runes[0] + if unicode.IsLetter(first) || first == '_' || first == '$' { + result.WriteRune(unicode.ToLower(first)) + } else { + // If first char is invalid, prefix with underscore + result.WriteRune('_') + if unicode.IsDigit(first) { + result.WriteRune(first) + } + } + } + + // Process remaining characters + for i := 1; i < len(runes); i++ { + r := runes[i] + if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' { + result.WriteRune(unicode.ToLower(r)) + } else if unicode.IsSpace(r) || r == '-' { + // Replace spaces and hyphens with single underscore + // Avoid consecutive underscores + if result.Len() > 0 && result.String()[result.Len()-1] != '_' { + result.WriteRune('_') + } + } + // Skip other invalid characters + } + + parsed := result.String() + + // Remove trailing underscores + parsed = strings.TrimRight(parsed, "_") + + // Ensure we have at least one character + // Should never happen, but just in case + if parsed == "" { + return "tool" + } + + return parsed +} + +// extractToolCallsFromCode extracts tool calls from TypeScript code +// Tool calls are in the format: serverName.toolName(...) or await serverName.toolName(...) +func extractToolCallsFromCode(code string) ([]toolCallInfo, error) { + toolCalls := []toolCallInfo{} + + // Regex pattern to match tool calls: + // - Optional "await" keyword + // - Server name (identifier) + // - Dot + // - Tool name (identifier) + // - Opening parenthesis + // This pattern matches: await serverName.toolName( or serverName.toolName( + toolCallPattern := regexp.MustCompile(`(?:await\s+)?([a-zA-Z_$][a-zA-Z0-9_$]*)\s*\.\s*([a-zA-Z_$][a-zA-Z0-9_$]*)\s*\(`) + + // Find all matches + matches := toolCallPattern.FindAllStringSubmatch(code, -1) + for _, match := range matches { + if len(match) >= 3 { + serverName := match[1] + toolName := match[2] + toolCalls = append(toolCalls, toolCallInfo{ + serverName: serverName, + toolName: toolName, + }) + } + } + + return toolCalls, nil +} + +// isToolCallAllowedForCodeMode checks if a tool call is allowed based on allowedAutoExecutionTools map +func isToolCallAllowedForCodeMode(serverName, toolName string, allClientNames []string, allowedAutoExecutionTools map[string][]string) bool { + // Check if the server name is in the list of all client names + if !slices.Contains(allClientNames, serverName) { + // It can be a built-in JavaScript/TypeScript object, if not then downstream execution will fail with a runtime error. + return true + } + + // Get allowed tools for this server + allowedTools, exists := allowedAutoExecutionTools[serverName] + if !exists { + // Server not in allowed list, return false to prevent downstream execution. + return false + } + + // Check if wildcard "*" is present (all tools allowed) + if slices.Contains(allowedTools, "*") { + return true + } + + // Check if specific tool is in the allowed list + if slices.Contains(allowedTools, toolName) { + return true + } + + return false // Tool not in allowed list +} + +// hasToolCalls checks if a chat response contains tool calls that need to be executed +func hasToolCallsForChatResponse(response *schemas.BifrostChatResponse) bool { + if response == nil || len(response.Choices) == 0 { + return false + } + + choice := response.Choices[0] + + // If finish_reason is "stop", this indicates non-auto-executable tools that require user approval. + // Don't return true even if tool calls are present, as the agent loop should not process them. + if choice.FinishReason != nil && *choice.FinishReason == "stop" { + return false + } + + // Check finish reason + if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" { + return true + } + + // Check if message has tool calls + if choice.ChatNonStreamResponseChoice != nil && + choice.ChatNonStreamResponseChoice.Message != nil && + choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage != nil && + len(choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls) > 0 { + return true + } + + return false +} + +func hasToolCallsForResponsesResponse(response *schemas.BifrostResponsesResponse) bool { + if response == nil || len(response.Output) == 0 { + return false + } + + // Check if any output message is a tool call + for _, output := range response.Output { + if output.Type == nil { + continue + } + + // Check for tool call types + switch *output.Type { + case schemas.ResponsesMessageTypeFunctionCall, schemas.ResponsesMessageTypeCustomToolCall: + // Verify that ResponsesToolMessage is actually set + if output.ResponsesToolMessage != nil { + return true + } + } + } + + return false +} + +// stripClientPrefix removes the client name prefix from a tool name. +// Tool names are stored with format "{clientName}_{toolName}", but when calling +// the MCP server, we need the original tool name without the prefix. +// +// Parameters: +// - prefixedToolName: Tool name with client prefix (e.g., "calculator_add") +// - clientName: Client name to strip (e.g., "calculator") +// +// Returns: +// - string: Original tool name without prefix (e.g., "add") +func stripClientPrefix(prefixedToolName, clientName string) string { + prefix := clientName + "_" + if strings.HasPrefix(prefixedToolName, prefix) { + return strings.TrimPrefix(prefixedToolName, prefix) + } + // If prefix doesn't match, return as-is (shouldn't happen, but be safe) + return prefixedToolName +} diff --git a/core/providers/anthropic/anthropic.go b/core/providers/anthropic/anthropic.go index 211a641698..0d5eced0e9 100644 --- a/core/providers/anthropic/anthropic.go +++ b/core/providers/anthropic/anthropic.go @@ -120,14 +120,14 @@ func (provider *AnthropicProvider) GetProviderKey() schemas.ModelProvider { } // buildRequestURL constructs the full request URL using the provider's configuration. -func (provider *AnthropicProvider) buildRequestURL(ctx context.Context, defaultPath string, requestType schemas.RequestType) string { +func (provider *AnthropicProvider) buildRequestURL(ctx *schemas.BifrostContext, defaultPath string, requestType schemas.RequestType) string { return provider.networkConfig.BaseURL + providerUtils.GetRequestPath(ctx, defaultPath, provider.customProviderConfig, requestType) } // completeRequest sends a request to Anthropic's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. -func (provider *AnthropicProvider) completeRequest(ctx context.Context, jsonData []byte, url string, key string, meta *providerUtils.RequestMetadata) ([]byte, time.Duration, *schemas.BifrostError) { +func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, meta *providerUtils.RequestMetadata) ([]byte, time.Duration, *schemas.BifrostError) { // Create the request with the JSON body req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -173,7 +173,7 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, jsonData // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. -func (provider *AnthropicProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -234,7 +234,7 @@ func (provider *AnthropicProvider) listModelsByKey(ctx context.Context, key sche // It fetches models using all provided keys and aggregates the results. // Uses a best-effort approach: continues with remaining keys even if some fail. // Requests are made concurrently for improved performance. -func (provider *AnthropicProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ListModelsRequest); err != nil { return nil, err } @@ -253,7 +253,7 @@ func (provider *AnthropicProvider) ListModels(ctx context.Context, keys []schema // TextCompletion performs a text completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AnthropicProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.TextCompletionRequest); err != nil { return nil, err } @@ -311,14 +311,14 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, key schem // TextCompletionStream performs a streaming text completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *AnthropicProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AnthropicProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } // ChatCompletion performs a chat completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { return nil, err } @@ -376,7 +376,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schem // ChatCompletionStream performs a streaming chat completion request to the Anthropic API. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AnthropicProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err } @@ -434,7 +434,7 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, pos // HandleAnthropicChatCompletionStreaming handles streaming for Anthropic-compatible APIs. // This shared function reduces code duplication between providers that use the same SSE event format. func HandleAnthropicChatCompletionStreaming( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, jsonBody []byte, @@ -496,20 +496,35 @@ func HandleAnthropicChatCompletionStreaming( // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + model := "unknown" + if meta != nil { + model = meta.Model + } + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) - + if resp.BodyStream() == nil { bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", fmt.Errorf("provider returned an empty response"), providerName, ) - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) return } + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() + scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) scanner.Buffer(buf, 10*1024*1024) @@ -531,13 +546,15 @@ func HandleAnthropicChatCompletionStreaming( var eventData string for scanner.Scan() { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } line := scanner.Text() - // Skip empty lines and comments if line == "" || strings.HasPrefix(line, ":") { continue } - // Parse SSE event - track event type and data separately if after, ok := strings.CutPrefix(line, "event: "); ok { eventType = after @@ -547,22 +564,18 @@ func HandleAnthropicChatCompletionStreaming( } else { continue } - // Skip if we don't have both event type and data if eventType == "" || eventData == "" { continue } - var event AnthropicStreamEvent if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse message_start event: %v", err)) continue } - if event.Type == AnthropicStreamEventTypeMessageStart && event.Message != nil && event.Message.ID != "" { messageID = event.Message.ID } - // Check for usage in both top-level event.Usage and nested event.Message.Usage // message_start events have usage nested in message.usage, while message_delta has it at top level var usageToProcess *AnthropicUsage @@ -571,7 +584,6 @@ func HandleAnthropicChatCompletionStreaming( } else if event.Message != nil && event.Message.Usage != nil { usageToProcess = event.Message.Usage } - if usageToProcess != nil { // Collect usage information and send at the end of the stream // Here in some cases usage comes before final message @@ -606,7 +618,6 @@ func HandleAnthropicChatCompletionStreaming( } } } - if event.Delta != nil && event.Delta.StopReason != nil { mappedReason := ConvertAnthropicFinishReasonToBifrost(*event.Delta.StopReason) finishReason = &mappedReason @@ -615,7 +626,6 @@ func HandleAnthropicChatCompletionStreaming( // Handle different event types modelName = event.Message.Model } - response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream() if bifrostErr != nil { bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ @@ -623,7 +633,7 @@ func HandleAnthropicChatCompletionStreaming( Provider: providerName, ModelRequested: modelName, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) break } @@ -652,36 +662,40 @@ func HandleAnthropicChatCompletionStreaming( providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) } - if isLastChunk { break } - // Reset for next event eventType = "" eventData = "" } - if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading %s stream: %v", providerName, err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, modelName, logger) - } else { - response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, modelName) - if postResponseConverter != nil { - response = postResponseConverter(response) - if response == nil { - logger.Warn("postResponseConverter returned nil; skipping chunk") - return - } - } - // Set raw request if enabled - if sendBackRawRequest { - providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) + return + } + response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, modelName) + if postResponseConverter != nil { + response = postResponseConverter(response) + if response == nil { + logger.Warn("postResponseConverter returned nil; skipping chunk") + // Setting error on the context to signal to the defer that we need to close the stream + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + return } - response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) } + // Set raw request if enabled + if sendBackRawRequest { + providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) + } + response.ExtraFields.Latency = time.Since(startTime).Milliseconds() + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) }() return responseChan, nil @@ -690,7 +704,7 @@ func HandleAnthropicChatCompletionStreaming( // Responses performs a chat completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { return nil, err } @@ -742,7 +756,7 @@ func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Ke } // ResponsesStream performs a streaming responses request to the Anthropic API. -func (provider *AnthropicProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AnthropicProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { return nil, err } @@ -788,7 +802,7 @@ func (provider *AnthropicProvider) ResponsesStream(ctx context.Context, postHook // HandleAnthropicResponsesStream handles streaming for Anthropic-compatible APIs. // This shared function reduces code duplication between providers that use the same SSE event format. func HandleAnthropicResponsesStream( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, jsonBody []byte, @@ -850,16 +864,30 @@ func HandleAnthropicResponsesStream( // Start streaming in a goroutine go func() { + defer func() { + model := "" + if meta != nil { + model = meta.Model + } + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) - defer close(responseChan) - + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() + // If body stream is nil, return an error if resp.BodyStream() == nil { bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", fmt.Errorf("provider returned an empty response"), providerName, ) - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) return } @@ -883,13 +911,15 @@ func HandleAnthropicResponsesStream( var modelName string for scanner.Scan() { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } line := scanner.Text() - // Skip empty lines and comments if line == "" || strings.HasPrefix(line, ":") { continue } - // Parse SSE event - track event type and data separately if after, ok := strings.CutPrefix(line, "event: "); ok { eventType = after @@ -899,22 +929,18 @@ func HandleAnthropicResponsesStream( } else { continue } - // Skip if we don't have both event type and data if eventType == "" || eventData == "" { continue } - var event AnthropicStreamEvent if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse message_start event: %v", err)) continue } - if event.Message != nil && modelName == "" { modelName = event.Message.Model } - // Note: response.created and response.in_progress are now emitted by ToBifrostResponsesStream // from the message_start event, so we don't need to call them manually here @@ -969,7 +995,11 @@ func HandleAnthropicResponsesStream( Provider: providerName, ModelRequested: modelName, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) break } @@ -1008,7 +1038,7 @@ func HandleAnthropicResponsesStream( providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) return } @@ -1020,8 +1050,12 @@ func HandleAnthropicResponsesStream( eventType = "" eventData = "" } - if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading %s stream: %v", providerName, err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, modelName, logger) } @@ -1031,7 +1065,7 @@ func HandleAnthropicResponsesStream( } // BatchCreate creates a new batch job. -func (provider *AnthropicProvider) BatchCreate(ctx context.Context, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.BatchCreateRequest); err != nil { return nil, err } @@ -1111,7 +1145,7 @@ func (provider *AnthropicProvider) BatchCreate(ctx context.Context, key schemas. // BatchList lists batch jobs using serial pagination across keys. // Exhausts all pages from one key before moving to the next. -func (provider *AnthropicProvider) BatchList(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.BatchListRequest); err != nil { return nil, err } @@ -1229,7 +1263,7 @@ func (provider *AnthropicProvider) BatchList(ctx context.Context, keys []schemas } // BatchRetrieve retrieves a specific batch job by trying each key until found. -func (provider *AnthropicProvider) BatchRetrieve(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.BatchRetrieveRequest); err != nil { return nil, err } @@ -1311,7 +1345,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx context.Context, keys []sch } // BatchCancel cancels a batch job by trying each key until successful. -func (provider *AnthropicProvider) BatchCancel(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.BatchCancelRequest); err != nil { return nil, err } @@ -1419,7 +1453,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx context.Context, keys []schem } // BatchResults retrieves batch results by trying each key until found. -func (provider *AnthropicProvider) BatchResults(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.BatchResultsRequest); err != nil { return nil, err } @@ -1544,39 +1578,32 @@ func splitJSONL(data []byte) [][]byte { } // Embedding is not supported by the Anthropic provider. -func (provider *AnthropicProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, input *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } // Speech is not supported by the Anthropic provider. -func (provider *AnthropicProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) } // SpeechStream is not supported by the Anthropic provider. -func (provider *AnthropicProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AnthropicProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } // Transcription is not supported by the Anthropic provider. -func (provider *AnthropicProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey()) } // TranscriptionStream is not supported by the Anthropic provider. -func (provider *AnthropicProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AnthropicProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } -// parseStreamAnthropicError parses Anthropic streaming error responses. -func parseStreamAnthropicError(resp *fasthttp.Response, providerType schemas.ModelProvider) *schemas.BifrostError { - statusCode := resp.StatusCode() - body := resp.Body() - return providerUtils.NewProviderAPIError(string(body), nil, statusCode, providerType, nil, nil) -} - // FileUpload uploads a file to Anthropic's Files API. -func (provider *AnthropicProvider) FileUpload(ctx context.Context, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.FileUploadRequest); err != nil { return nil, err } @@ -1659,7 +1686,7 @@ func (provider *AnthropicProvider) FileUpload(ctx context.Context, key schemas.K // FileList lists files from all provided keys and aggregates results. // FileList lists files using serial pagination across keys. // Exhausts all pages from one key before moving to the next. -func (provider *AnthropicProvider) FileList(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.FileListRequest); err != nil { return nil, err } @@ -1783,7 +1810,7 @@ func (provider *AnthropicProvider) FileList(ctx context.Context, keys []schemas. } // FileRetrieve retrieves file metadata from Anthropic's Files API by trying each key until found. -func (provider *AnthropicProvider) FileRetrieve(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.FileRetrieveRequest); err != nil { return nil, err } @@ -1864,7 +1891,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx context.Context, keys []sche } // FileDelete deletes a file from Anthropic's Files API by trying each key until successful. -func (provider *AnthropicProvider) FileDelete(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.FileDeleteRequest); err != nil { return nil, err } @@ -1977,7 +2004,7 @@ func (provider *AnthropicProvider) FileDelete(ctx context.Context, keys []schema // FileContent downloads file content from Anthropic's Files API by trying each key until found. // Note: Only files created by skills or the code execution tool can be downloaded. -func (provider *AnthropicProvider) FileContent(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.FileContentRequest); err != nil { return nil, err } @@ -2057,7 +2084,7 @@ func (provider *AnthropicProvider) FileContent(ctx context.Context, keys []schem } // CountTokens counts tokens for a given request using Anthropic's API. -func (provider *AnthropicProvider) CountTokens(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.CountTokensRequest); err != nil { return nil, err } diff --git a/core/providers/anthropic/responses.go b/core/providers/anthropic/responses.go index de020c9f35..21ce814a95 100644 --- a/core/providers/anthropic/responses.go +++ b/core/providers/anthropic/responses.go @@ -1502,6 +1502,17 @@ func ToAnthropicResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (* // Set system message if present if systemContent != nil { anthropicReq.System = systemContent + } else if bifrostReq.Params != nil && bifrostReq.Params.Instructions != nil && *bifrostReq.Params.Instructions != "" { + // if no system content, check if instructions are present + // system messages take precedence over instructions + anthropicReq.System = &AnthropicContent{ + ContentBlocks: []AnthropicContentBlock{ + { + Type: AnthropicContentBlockTypeText, + Text: bifrostReq.Params.Instructions, + }, + }, + } } // Set regular messages diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index 38fc1e2f0e..f9bc2f13ba 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -72,7 +72,7 @@ func (p *AzureProvider) getOrCreateAuth( // 1. Service Principal (client ID/secret/tenant ID) - Bearer token // 2. Context token - Bearer token // 3. API key - api-key or x-api-key header -func (provider *AzureProvider) getAzureAuthHeaders(ctx context.Context, key schemas.Key, isAnthropicModel bool) (map[string]string, *schemas.BifrostError) { +func (provider *AzureProvider) getAzureAuthHeaders(ctx *schemas.BifrostContext, key schemas.Key, isAnthropicModel bool) (map[string]string, *schemas.BifrostError) { authHeader := make(map[string]string) // Service Principal authentication @@ -148,7 +148,7 @@ func (provider *AzureProvider) GetProviderKey() schemas.ModelProvider { // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body, request latency, or an error if the request fails. func (provider *AzureProvider) completeRequest( - ctx context.Context, + ctx *schemas.BifrostContext, jsonData []byte, path string, key schemas.Key, @@ -224,7 +224,7 @@ func (provider *AzureProvider) completeRequest( // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. -func (provider *AzureProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { // Validate Azure key configuration if key.AzureKeyConfig == nil { return nil, providerUtils.NewConfigurationError("azure key config not set", schemas.Azure) @@ -313,7 +313,7 @@ func (provider *AzureProvider) listModelsByKey(ctx context.Context, key schemas. // ListModels performs a list models request to Azure's API. // It retrieves all models accessible by the Azure resource // Requests are made concurrently for improved performance. -func (provider *AzureProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *AzureProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { return providerUtils.HandleMultipleListModelsRequests( ctx, keys, @@ -326,7 +326,7 @@ func (provider *AzureProvider) ListModels(ctx context.Context, keys []schemas.Ke // TextCompletion performs a text completion request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AzureProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { if err := provider.validateKeyConfig(key); err != nil { return nil, err } @@ -388,7 +388,7 @@ func (provider *AzureProvider) TextCompletion(ctx context.Context, key schemas.K // TextCompletionStream performs a streaming text completion request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *AzureProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := provider.validateKeyConfig(key); err != nil { return nil, err } @@ -426,6 +426,7 @@ func (provider *AzureProvider) TextCompletionStream(ctx context.Context, postHoo providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, postHookRunner, customPostResponseConverter, provider.logger, @@ -435,7 +436,7 @@ func (provider *AzureProvider) TextCompletionStream(ctx context.Context, postHoo // ChatCompletion performs a chat completion request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AzureProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { if err := provider.validateKeyConfig(key); err != nil { return nil, err } @@ -529,7 +530,7 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, key schemas.K // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Azure-specific URL construction with deployments and supports both api-key and Bearer token authentication. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := provider.validateKeyConfig(key); err != nil { return nil, err } @@ -617,6 +618,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo postHookRunner, nil, nil, + nil, postResponseConverter, provider.logger, ) @@ -626,7 +628,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo // Responses performs a responses request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AzureProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { if err := provider.validateKeyConfig(key); err != nil { return nil, err } @@ -716,7 +718,7 @@ func (provider *AzureProvider) Responses(ctx context.Context, key schemas.Key, r } // ResponsesStream performs a streaming responses request to Azure's API. -func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := provider.validateKeyConfig(key); err != nil { return nil, err } @@ -789,6 +791,7 @@ func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunn providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), postHookRunner, + nil, postRequestConverter, postResponseConverter, provider.logger, @@ -799,7 +802,7 @@ func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunn // Embedding generates embeddings for the given input text(s) using Azure. // The input can be either a single string or a slice of strings for batch embedding. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. -func (provider *AzureProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { if err := provider.validateKeyConfig(key); err != nil { return nil, err } @@ -860,7 +863,7 @@ func (provider *AzureProvider) Embedding(ctx context.Context, key schemas.Key, r } // Speech is not supported by the Azure provider. -func (provider *AzureProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { if err := provider.validateKeyConfig(key); err != nil { return nil, err } @@ -901,7 +904,7 @@ func (provider *AzureProvider) Speech(ctx context.Context, key schemas.Key, requ // SpeechStream handles streaming for speech synthesis with Azure. // Azure sends raw binary audio bytes in SSE format, unlike OpenAI which sends JSON. -func (provider *AzureProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := provider.validateKeyConfig(key); err != nil { return nil, err } @@ -1003,7 +1006,19 @@ func (provider *AzureProvider) SpeechStream(ctx context.Context, postHookRunner // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.SpeechStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.SpeechStreamRequest, provider.logger) + } + close(responseChan) + }() + // Always release response on exit; bodyStream close should prevent indefinite blocking. + defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) + defer stopCancellation() // Check if response is compressed bodyStream := resp.BodyStream() @@ -1018,13 +1033,10 @@ func (provider *AzureProvider) SpeechStream(ctx context.Context, postHookRunner var accumulated []byte for { - // Check if context is done - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } - // Read from stream n, readErr := bodyStream.Read(readBuffer) if n > 0 { @@ -1054,7 +1066,6 @@ func (provider *AzureProvider) SpeechStream(ctx context.Context, postHookRunner // Check if this has "data: " prefix (standard SSE format) if bytes.HasPrefix(event, []byte("data: ")) { audioData = event[6:] // Skip "data: " prefix - // Check for [DONE] marker if bytes.Equal(audioData, []byte("[DONE]")) { return @@ -1073,7 +1084,7 @@ func (provider *AzureProvider) SpeechStream(ctx context.Context, postHookRunner ModelRequested: request.Model, RequestType: schemas.SpeechStreamRequest, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger) return } @@ -1112,6 +1123,10 @@ func (provider *AzureProvider) SpeechStream(ctx context.Context, postHookRunner // Handle read errors if readErr != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } if readErr != io.EOF { provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", readErr)) } @@ -1137,7 +1152,7 @@ func (provider *AzureProvider) SpeechStream(ctx context.Context, postHookRunner providerUtils.ParseAndSetRawRequest(&finalResponse.ExtraFields, jsonBody) } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &finalResponse, nil), responseChan) } @@ -1150,7 +1165,7 @@ func (provider *AzureProvider) SpeechStream(ctx context.Context, postHookRunner } // Transcription is not supported by the Azure provider. -func (provider *AzureProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *AzureProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { if err := provider.validateKeyConfig(key); err != nil { return nil, err } @@ -1190,7 +1205,7 @@ func (provider *AzureProvider) Transcription(ctx context.Context, key schemas.Ke } // TranscriptionStream is not supported by the Azure provider. -func (provider *AzureProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AzureProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -1239,7 +1254,7 @@ func (provider *AzureProvider) getModelDeployment(key schemas.Key, model string) } // FileUpload uploads a file to Azure OpenAI. -func (provider *AzureProvider) FileUpload(ctx context.Context, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { if err := provider.validateKeyConfigForFiles(key); err != nil { return nil, err } @@ -1342,7 +1357,7 @@ func (provider *AzureProvider) FileUpload(ctx context.Context, key schemas.Key, // FileList lists files from all provided Azure keys and aggregates results. // FileList lists files using serial pagination across keys. // Exhausts all pages from one key before moving to the next. -func (provider *AzureProvider) FileList(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() if len(keys) == 0 { @@ -1478,7 +1493,7 @@ func (provider *AzureProvider) FileList(ctx context.Context, keys []schemas.Key, } // FileRetrieve retrieves file metadata from Azure OpenAI by trying each key until found. -func (provider *AzureProvider) FileRetrieve(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() if request.FileID == "" { @@ -1570,7 +1585,7 @@ func (provider *AzureProvider) FileRetrieve(ctx context.Context, keys []schemas. } // FileDelete deletes a file from Azure OpenAI by trying each key until successful. -func (provider *AzureProvider) FileDelete(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() if request.FileID == "" { @@ -1700,7 +1715,7 @@ func (provider *AzureProvider) FileDelete(ctx context.Context, keys []schemas.Ke } // FileContent downloads file content from Azure OpenAI by trying each key until found. -func (provider *AzureProvider) FileContent(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() if request.FileID == "" { @@ -1801,7 +1816,7 @@ func (provider *AzureProvider) FileContent(ctx context.Context, keys []schemas.K // BatchCreate creates a new batch job on Azure OpenAI. // Azure Batch API uses the same format as OpenAI but with Azure-specific URL patterns. -func (provider *AzureProvider) BatchCreate(ctx context.Context, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { if err := provider.validateKeyConfigForFiles(key); err != nil { return nil, err } @@ -1916,7 +1931,7 @@ func (provider *AzureProvider) BatchCreate(ctx context.Context, key schemas.Key, // BatchList lists batch jobs from all provided Azure keys and aggregates results. // BatchList lists batch jobs using serial pagination across keys. // Exhausts all pages from one key before moving to the next. -func (provider *AzureProvider) BatchList(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) @@ -2040,7 +2055,7 @@ func (provider *AzureProvider) BatchList(ctx context.Context, keys []schemas.Key } // BatchRetrieve retrieves a specific batch job from Azure OpenAI by trying each key until found. -func (provider *AzureProvider) BatchRetrieve(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() if request.BatchID == "" { @@ -2138,7 +2153,7 @@ func (provider *AzureProvider) BatchRetrieve(ctx context.Context, keys []schemas } // BatchCancel cancels a batch job on Azure OpenAI by trying each key until successful. -func (provider *AzureProvider) BatchCancel(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() if request.BatchID == "" { @@ -2264,7 +2279,7 @@ func (provider *AzureProvider) BatchCancel(ctx context.Context, keys []schemas.K // BatchResults retrieves batch results from Azure OpenAI by trying each key until successful. // For Azure (like OpenAI), batch results are obtained by downloading the output_file_id. -func (provider *AzureProvider) BatchResults(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *AzureProvider) BatchResults(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() // First, retrieve the batch to get the output_file_id (using all keys) @@ -2320,6 +2335,6 @@ func (provider *AzureProvider) BatchResults(ctx context.Context, keys []schemas. } // CountTokens is not supported by the Azure provider. -func (provider *AzureProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *AzureProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) } diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index ddb66f1767..373b90aea9 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -88,7 +88,7 @@ func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider { // completeRequest sends a request to Bedrock's API and handles the response. // It constructs the API URL, sets up AWS authentication, and processes the response. // Returns the response body, request latency, or an error if the request fails. -func (provider *BedrockProvider) completeRequest(ctx context.Context, jsonData []byte, path string, key schemas.Key) ([]byte, time.Duration, *schemas.BifrostError) { +func (provider *BedrockProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, path string, key schemas.Key) ([]byte, time.Duration, *schemas.BifrostError) { config := key.BedrockKeyConfig region := DefaultBedrockRegion @@ -189,7 +189,7 @@ func (provider *BedrockProvider) completeRequest(ctx context.Context, jsonData [ // makeStreamingRequest creates a streaming request to Bedrock's API. // It formats the request, sends it to Bedrock, and returns the response. // Returns the response body and an error if the request fails. -func (provider *BedrockProvider) makeStreamingRequest(ctx context.Context, jsonData []byte, key schemas.Key, model string, action string) (*http.Response, string, *schemas.BifrostError) { +func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContext, jsonData []byte, key schemas.Key, model string, action string) (*http.Response, string, *schemas.BifrostError) { providerName := provider.GetProviderKey() if key.BedrockKeyConfig == nil { @@ -256,7 +256,7 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx context.Context, jsonD // It sets required headers, calculates the request body hash, and signs the request // using the provided AWS credentials. // Returns a BifrostError if signing fails. -func signAWSRequest(ctx context.Context, req *http.Request, accessKey, secretKey string, sessionToken *string, region, service string, providerName schemas.ModelProvider) *schemas.BifrostError { +func signAWSRequest(ctx *schemas.BifrostContext, req *http.Request, accessKey, secretKey string, sessionToken *string, region, service string, providerName schemas.ModelProvider) *schemas.BifrostError { // Set required headers before signing (only if not already set) if req.Header.Get("Content-Type") == "" { req.Header.Set("Content-Type", "application/json") @@ -334,7 +334,7 @@ func signAWSRequest(ctx context.Context, req *http.Request, accessKey, secretKey // listModelsByKey performs a list models request to Bedrock's API for a single key. // It retrieves all foundation models available in Amazon Bedrock for a specific key. -func (provider *BedrockProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() if key.BedrockKeyConfig == nil { @@ -484,7 +484,7 @@ func (provider *BedrockProvider) listModelsByKey(ctx context.Context, key schema // ListModels performs a list models request to Bedrock's API. // It retrieves all foundation models available in Amazon Bedrock. // Requests are made concurrently for improved performance. -func (provider *BedrockProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ListModelsRequest); err != nil { return nil, err } @@ -500,7 +500,7 @@ func (provider *BedrockProvider) ListModels(ctx context.Context, keys []schemas. // TextCompletion performs a text completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *BedrockProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.TextCompletionRequest); err != nil { return nil, err } @@ -574,7 +574,7 @@ func (provider *BedrockProvider) TextCompletion(ctx context.Context, key schemas // TextCompletionStream performs a streaming text completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *BedrockProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.TextCompletionStreamRequest); err != nil { return nil, err } @@ -604,8 +604,18 @@ func (provider *BedrockProvider) TextCompletionStream(ctx context.Context, postH // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, provider.logger) + } + close(responseChan) + }() defer resp.Body.Close() + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.Body, provider.logger) + defer stopCancellation() // Process AWS Event Stream format startTime := time.Now() @@ -613,14 +623,22 @@ func (provider *BedrockProvider) TextCompletionStream(ctx context.Context, postH payloadBuf := make([]byte, 0, 1024*1024) // 1MB payload buffer for { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } // Decode a single EventStream message message, err := decoder.Decode(resp.Body, payloadBuf) if err != nil { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } if err == io.EOF { // End of stream - this is normal break } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error decoding %s EventStream message: %v", providerName, err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger) return @@ -681,7 +699,7 @@ func (provider *BedrockProvider) TextCompletionStream(ctx context.Context, postH // ChatCompletion performs a chat completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *BedrockProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { return nil, err } @@ -696,7 +714,7 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, key schemas jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (any, error) { return ToBedrockChatCompletionRequest(&ctx, request) }, + func() (any, error) { return ToBedrockChatCompletionRequest(ctx, request) }, provider.GetProviderKey()) if bifrostErr != nil { return nil, bifrostErr @@ -752,7 +770,7 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, key schemas // ChatCompletionStream performs a streaming chat completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the streaming response. // Returns a channel for streaming BifrostResponse objects or an error if the request fails. -func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err } @@ -762,7 +780,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (any, error) { return ToBedrockChatCompletionRequest(&ctx, request) }, + func() (any, error) { return ToBedrockChatCompletionRequest(ctx, request) }, provider.GetProviderKey()) if bifrostErr != nil { return nil, bifrostErr @@ -778,8 +796,18 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + } + close(responseChan) + }() defer resp.Body.Close() + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.Body, provider.logger) + defer stopCancellation() // Process AWS Event Stream format usage := &schemas.BifrostLLMUsage{} @@ -796,13 +824,22 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH id := uuid.New().String() for { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } // Decode a single EventStream message message, err := decoder.Decode(resp.Body, payloadBuf) if err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + // End of stream - this is normal if err == io.EOF { - // End of stream - this is normal break } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error decoding %s EventStream message: %v", providerName, err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) return @@ -875,7 +912,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH Provider: providerName, ModelRequested: request.Model, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return } @@ -910,7 +947,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonData) } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) }() @@ -920,7 +957,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH // Responses performs a chat completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { return nil, err } @@ -935,7 +972,7 @@ func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key, jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (any, error) { return ToBedrockResponsesRequest(&ctx, request) }, + func() (any, error) { return ToBedrockResponsesRequest(ctx, request) }, provider.GetProviderKey()) if bifrostErr != nil { return nil, bifrostErr @@ -960,7 +997,7 @@ func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key, } // Convert using the new response converter - bifrostResponse, err := bedrockResponse.ToBifrostResponsesResponse(&ctx) + bifrostResponse, err := bedrockResponse.ToBifrostResponsesResponse(ctx) if err != nil { return nil, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err, providerName) } @@ -993,7 +1030,7 @@ func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key, // ResponsesStream performs a streaming chat completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the streaming response. // Returns a channel for streaming BifrostResponse objects or an error if the request fails. -func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { return nil, err } @@ -1003,7 +1040,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRu jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (any, error) { return ToBedrockResponsesRequest(&ctx, request) }, + func() (any, error) { return ToBedrockResponsesRequest(ctx, request) }, provider.GetProviderKey()) if bifrostErr != nil { return nil, bifrostErr @@ -1019,8 +1056,19 @@ func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRu // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + } + close(responseChan) + }() + // Always release response on exit; bodyStream close should prevent indefinite blocking. defer resp.Body.Close() + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.Body, provider.logger) + defer stopCancellation() // Process AWS Event Stream format usage := &schemas.ResponsesResponseUsage{} @@ -1038,9 +1086,17 @@ func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRu payloadBuf := make([]byte, 0, 1024*1024) // 1MB payload buffer for { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } // Decode a single EventStream message message, err := decoder.Decode(resp.Body, payloadBuf) if err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } if err == io.EOF { // End of stream - finalize any open items finalResponses := FinalizeBedrockStream(streamState, chunkIndex, usage) @@ -1062,7 +1118,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRu if i == len(finalResponses)-1 { // Set raw request if enabled - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&finalResponse.ExtraFields, jsonData) } @@ -1072,8 +1128,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRu providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil), responseChan) } break - } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error decoding %s EventStream message: %v", providerName, err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) return @@ -1134,7 +1190,6 @@ func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRu } } } - responses, bifrostErr, _ := streamEvent.ToBifrostResponsesStream(chunkIndex, streamState) if bifrostErr != nil { bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ @@ -1142,7 +1197,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRu Provider: providerName, ModelRequested: request.Model, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return } @@ -1175,7 +1230,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRu // Embedding generates embeddings for the given input text(s) using Amazon Bedrock. // Supports Titan and Cohere embedding models. Returns a BifrostResponse containing the embedding(s) and any error that occurred. -func (provider *BedrockProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.EmbeddingRequest); err != nil { return nil, err } @@ -1270,27 +1325,27 @@ func (provider *BedrockProvider) Embedding(ctx context.Context, key schemas.Key, } // Speech is not supported by the Bedrock provider. -func (provider *BedrockProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, schemas.Bedrock) } // SpeechStream is not supported by the Bedrock provider. -func (provider *BedrockProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *BedrockProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, schemas.Bedrock) } // Transcription is not supported by the Bedrock provider. -func (provider *BedrockProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, schemas.Bedrock) } // TranscriptionStream is not supported by the Bedrock provider. -func (provider *BedrockProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *BedrockProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, schemas.Bedrock) } // FileUpload uploads a file to S3 for Bedrock batch processing. -func (provider *BedrockProvider) FileUpload(ctx context.Context, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.FileUploadRequest); err != nil { if err.Error != nil { @@ -1422,7 +1477,7 @@ func (provider *BedrockProvider) FileUpload(ctx context.Context, key schemas.Key // FileList lists files in the S3 bucket used for Bedrock batch processing from all provided keys. // FileList lists S3 files using serial pagination across keys. // Exhausts all pages from one key before moving to the next. -func (provider *BedrockProvider) FileList(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.FileListRequest); err != nil { return nil, err } @@ -1589,7 +1644,7 @@ func (provider *BedrockProvider) FileList(ctx context.Context, keys []schemas.Ke } // FileRetrieve retrieves S3 object metadata for Bedrock batch processing by trying each key until found. -func (provider *BedrockProvider) FileRetrieve(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.FileRetrieveRequest); err != nil { return nil, err } @@ -1696,7 +1751,7 @@ func (provider *BedrockProvider) FileRetrieve(ctx context.Context, keys []schema } // FileDelete deletes an S3 object used for Bedrock batch processing by trying each key until successful. -func (provider *BedrockProvider) FileDelete(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.FileDeleteRequest); err != nil { return nil, err } @@ -1786,7 +1841,7 @@ func (provider *BedrockProvider) FileDelete(ctx context.Context, keys []schemas. } // FileContent downloads S3 object content for Bedrock batch processing by trying each key until found. -func (provider *BedrockProvider) FileContent(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.FileContentRequest); err != nil { return nil, err } @@ -1885,7 +1940,7 @@ func (provider *BedrockProvider) FileContent(ctx context.Context, keys []schemas } // BatchCreate creates a new batch inference job on AWS Bedrock. -func (provider *BedrockProvider) BatchCreate(ctx context.Context, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.BatchCreateRequest); err != nil { provider.logger.Error("batch create is not allowed for Bedrock provider", "error", err) return nil, err @@ -2131,7 +2186,7 @@ func (provider *BedrockProvider) BatchCreate(ctx context.Context, key schemas.Ke // BatchList lists batch inference jobs using serial pagination across keys. // Exhausts all pages from one key before moving to the next. -func (provider *BedrockProvider) BatchList(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.BatchListRequest); err != nil { return nil, err } @@ -2286,7 +2341,7 @@ func (provider *BedrockProvider) BatchList(ctx context.Context, keys []schemas.K // fetchBatchManifest fetches the manifest.json.out from S3 to get record counts. // Returns nil if manifest doesn't exist (job still in progress) or on error. -func (provider *BedrockProvider) fetchBatchManifest(ctx context.Context, key schemas.Key, region, outputS3Uri string) *BedrockBatchManifest { +func (provider *BedrockProvider) fetchBatchManifest(ctx *schemas.BifrostContext, key schemas.Key, region, outputS3Uri string) *BedrockBatchManifest { if outputS3Uri == "" { return nil } @@ -2347,7 +2402,7 @@ func (provider *BedrockProvider) fetchBatchManifest(ctx context.Context, key sch } // BatchRetrieve retrieves a specific batch inference job from AWS Bedrock by trying each key until found. -func (provider *BedrockProvider) BatchRetrieve(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.BatchRetrieveRequest); err != nil { return nil, err } @@ -2496,7 +2551,7 @@ func (provider *BedrockProvider) BatchRetrieve(ctx context.Context, keys []schem } // BatchCancel stops a batch inference job on AWS Bedrock by trying each key until successful. -func (provider *BedrockProvider) BatchCancel(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.BatchCancelRequest); err != nil { return nil, err } @@ -2610,7 +2665,7 @@ func (provider *BedrockProvider) BatchCancel(ctx context.Context, keys []schemas // BatchResults retrieves batch results from AWS Bedrock by trying each key until successful. // For Bedrock, results are stored in S3 at the output S3 URI prefix. // The output includes JSONL files with results (*.jsonl.out) and a manifest file. -func (provider *BedrockProvider) BatchResults(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.BatchResultsRequest); err != nil { return nil, err } @@ -2743,6 +2798,6 @@ func (provider *BedrockProvider) getModelPath(basePath string, model string, key return path, deployment } -func (provider *BedrockProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) } diff --git a/core/providers/bedrock/bedrock_test.go b/core/providers/bedrock/bedrock_test.go index 76df97ca28..98e724bcc3 100644 --- a/core/providers/bedrock/bedrock_test.go +++ b/core/providers/bedrock/bedrock_test.go @@ -677,8 +677,8 @@ func TestBifrostToBedrockRequestConversion(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - actual, err := bedrock.ToBedrockChatCompletionRequest(&ctx, tt.input) + ctx := schemas.NewBifrostContext(context.Background(),schemas.NoDeadline) + actual, err := bedrock.ToBedrockChatCompletionRequest(ctx, tt.input) if tt.wantErr { assert.Error(t, err) assert.Nil(t, actual) @@ -705,7 +705,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { trace := testTrace latency := testLatency props := testProps - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(),schemas.NoDeadline) tests := []struct { name string @@ -1194,9 +1194,9 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { var err error if tt.input == nil { var bedrockReq *bedrock.BedrockConverseRequest - actual, err = bedrockReq.ToBifrostResponsesRequest(&ctx) + actual, err = bedrockReq.ToBifrostResponsesRequest(ctx) } else { - actual, err = tt.input.ToBifrostResponsesRequest(&ctx) + actual, err = tt.input.ToBifrostResponsesRequest(ctx) } if tt.wantErr { assert.Error(t, err) @@ -1775,7 +1775,7 @@ func TestBedrockToBifrostResponseConversion(t *testing.T) { toolInput := map[string]interface{}{ "location": "NYC", } - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(),schemas.NoDeadline) tests := []struct { name string @@ -1919,9 +1919,9 @@ func TestBedrockToBifrostResponseConversion(t *testing.T) { var err error if tt.input == nil { var bedrockResp *bedrock.BedrockConverseResponse - actual, err = bedrockResp.ToBifrostResponsesResponse(&ctx) + actual, err = bedrockResp.ToBifrostResponsesResponse(ctx) } else { - actual, err = tt.input.ToBifrostResponsesResponse(&ctx) + actual, err = tt.input.ToBifrostResponsesResponse(ctx) } if tt.wantErr { assert.Error(t, err) @@ -1986,8 +1986,8 @@ func TestToBedrockResponsesRequest_AdditionalFields(t *testing.T) { }, } - ctx := context.Background() - bedrockReq, err := bedrock.ToBedrockResponsesRequest(&ctx, req) + ctx := schemas.NewBifrostContext(context.Background(),schemas.NoDeadline) + bedrockReq, err := bedrock.ToBedrockResponsesRequest(ctx, req) require.NoError(t, err) require.NotNil(t, bedrockReq) @@ -2013,8 +2013,8 @@ func TestToBedrockResponsesRequest_AdditionalFields_InterfaceSlice(t *testing.T) }, } - ctx := context.Background() - bedrockReq, err := bedrock.ToBedrockResponsesRequest(&ctx, req) + ctx := schemas.NewBifrostContext(context.Background(),schemas.NoDeadline) + bedrockReq, err := bedrock.ToBedrockResponsesRequest(ctx, req) require.NoError(t, err) require.NotNil(t, bedrockReq) diff --git a/core/providers/bedrock/chat.go b/core/providers/bedrock/chat.go index f2434e703f..97a635a5a6 100644 --- a/core/providers/bedrock/chat.go +++ b/core/providers/bedrock/chat.go @@ -12,7 +12,7 @@ import ( ) // ToBedrockChatCompletionRequest converts a Bifrost request to Bedrock Converse API format -func ToBedrockChatCompletionRequest(ctx *context.Context, bifrostReq *schemas.BifrostChatRequest) (*BedrockConverseRequest, error) { +func ToBedrockChatCompletionRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.BifrostChatRequest) (*BedrockConverseRequest, error) { if bifrostReq == nil { return nil, fmt.Errorf("bifrost request is nil") } diff --git a/core/providers/bedrock/responses.go b/core/providers/bedrock/responses.go index 6478d5c7c1..b7290d2cd5 100644 --- a/core/providers/bedrock/responses.go +++ b/core/providers/bedrock/responses.go @@ -1,7 +1,6 @@ package bedrock import ( - "context" "encoding/base64" "encoding/json" "fmt" @@ -1205,7 +1204,7 @@ func (event *BedrockStreamEvent) ToEncodedEvents() []BedrockEncodedEvent { } // ToBifrostResponsesRequest converts a BedrockConverseRequest to Bifrost Responses Request format -func (request *BedrockConverseRequest) ToBifrostResponsesRequest(ctx *context.Context) (*schemas.BifrostResponsesRequest, error) { +func (request *BedrockConverseRequest) ToBifrostResponsesRequest(ctx *schemas.BifrostContext) (*schemas.BifrostResponsesRequest, error) { if request == nil { return nil, fmt.Errorf("bedrock request is nil") } @@ -1435,7 +1434,7 @@ func (request *BedrockConverseRequest) ToBifrostResponsesRequest(ctx *context.Co } // ToBedrockResponsesRequest converts a BifrostRequest (Responses structure) back to BedrockConverseRequest -func ToBedrockResponsesRequest(ctx *context.Context, bifrostReq *schemas.BifrostResponsesRequest) (*BedrockConverseRequest, error) { +func ToBedrockResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.BifrostResponsesRequest) (*BedrockConverseRequest, error) { if bifrostReq == nil { return nil, fmt.Errorf("bifrost request is nil") } @@ -1453,6 +1452,15 @@ func ToBedrockResponsesRequest(ctx *context.Context, bifrostReq *schemas.Bifrost bedrockReq.Messages = messages if len(systemMessages) > 0 { bedrockReq.System = systemMessages + } else { + if bifrostReq.Params != nil && bifrostReq.Params.Instructions != nil { + // if no system messages, check if instructions are present + bedrockReq.System = []BedrockSystemMessage{ + { + Text: bifrostReq.Params.Instructions, + }, + } + } } } @@ -1697,7 +1705,7 @@ func ToBedrockResponsesRequest(ctx *context.Context, bifrostReq *schemas.Bifrost } // ToBifrostResponsesResponse converts BedrockConverseResponse to BifrostResponsesResponse -func (response *BedrockConverseResponse) ToBifrostResponsesResponse(ctx *context.Context) (*schemas.BifrostResponsesResponse, error) { +func (response *BedrockConverseResponse) ToBifrostResponsesResponse(ctx *schemas.BifrostContext) (*schemas.BifrostResponsesResponse, error) { if response == nil { return nil, fmt.Errorf("bedrock response is nil") } @@ -2484,7 +2492,7 @@ func ConvertBifrostMessagesToBedrockMessages(bifrostMessages []schemas.Responses // ConvertBedrockMessagesToBifrostMessages converts an array of Bedrock messages to Bifrost ResponsesMessage format // This is the main conversion method from Bedrock to Bifrost - handles all message types and content blocks -func ConvertBedrockMessagesToBifrostMessages(ctx *context.Context, bedrockMessages []BedrockMessage, systemMessages []BedrockSystemMessage, isOutputMessage bool) []schemas.ResponsesMessage { +func ConvertBedrockMessagesToBifrostMessages(ctx *schemas.BifrostContext, bedrockMessages []BedrockMessage, systemMessages []BedrockSystemMessage, isOutputMessage bool) []schemas.ResponsesMessage { var bifrostMessages []schemas.ResponsesMessage // Convert system messages first @@ -2630,14 +2638,14 @@ func createTextMessage( } // convertSingleBedrockMessageToBifrostMessages converts a single Bedrock message to Bifrost messages -func convertSingleBedrockMessageToBifrostMessages(ctx *context.Context, msg *BedrockMessage, isOutputMessage bool) []schemas.ResponsesMessage { +func convertSingleBedrockMessageToBifrostMessages(ctx *schemas.BifrostContext, msg *BedrockMessage, isOutputMessage bool) []schemas.ResponsesMessage { var outputMessages []schemas.ResponsesMessage var reasoningContentBlocks []schemas.ResponsesMessageContentBlock // Check if we have a structured output tool var structuredOutputToolName string - if ctx != nil && *ctx != nil { - if toolName, ok := (*ctx).Value(schemas.BifrostContextKeyStructuredOutputToolName).(string); ok { + if ctx != nil { + if toolName, ok := ctx.Value(schemas.BifrostContextKeyStructuredOutputToolName).(string); ok { structuredOutputToolName = toolName } } diff --git a/core/providers/bedrock/utils.go b/core/providers/bedrock/utils.go index 87ab52a2fa..b9d8948569 100644 --- a/core/providers/bedrock/utils.go +++ b/core/providers/bedrock/utils.go @@ -1,7 +1,6 @@ package bedrock import ( - "context" "encoding/base64" "encoding/json" "fmt" @@ -46,7 +45,7 @@ func normalizeBedrockFilename(filename string) string { } // convertParameters handles parameter conversion -func convertChatParameters(ctx *context.Context, bifrostReq *schemas.BifrostChatRequest, bedrockReq *BedrockConverseRequest) error { +func convertChatParameters(ctx *schemas.BifrostContext, bifrostReq *schemas.BifrostChatRequest, bedrockReq *BedrockConverseRequest) error { // Parameters are optional - if not provided, just skip conversion if bifrostReq.Params == nil { return nil @@ -277,7 +276,7 @@ func convertChatParameters(ctx *context.Context, bifrostReq *schemas.BifrostChat } // Handle request metadata if reqMetadata, exists := bifrostReq.Params.ExtraParams["requestMetadata"]; exists { - if metadata, ok := reqMetadata.(map[string]string); ok { + if metadata, ok := schemas.SafeExtractStringMap(reqMetadata); ok { bedrockReq.RequestMetadata = metadata } } @@ -676,7 +675,7 @@ func convertImageToBedrockSource(imageURL string) (*BedrockImageSource, error) { // convertResponseFormatToTool converts a response_format parameter to a Bedrock tool // Returns nil if no response_format is present or if it's not a json_schema type -func convertResponseFormatToTool(ctx *context.Context, params *schemas.ChatParameters) *BedrockTool { +func convertResponseFormatToTool(ctx *schemas.BifrostContext, params *schemas.ChatParameters) *BedrockTool { if params == nil || params.ResponseFormat == nil { return nil } @@ -718,7 +717,7 @@ func convertResponseFormatToTool(ctx *context.Context, params *schemas.ChatParam // set bifrost context key structured output tool name toolName = fmt.Sprintf("bf_so_%s", toolName) - (*ctx) = context.WithValue(*ctx, schemas.BifrostContextKeyStructuredOutputToolName, toolName) + ctx.SetValue(schemas.BifrostContextKeyStructuredOutputToolName, toolName) // Create the Bedrock tool return &BedrockTool{ @@ -733,7 +732,7 @@ func convertResponseFormatToTool(ctx *context.Context, params *schemas.ChatParam } // convertTextFormatToTool converts a text config to a Bedrock tool for structured outpute -func convertTextFormatToTool(ctx *context.Context, textConfig *schemas.ResponsesTextConfig) *BedrockTool { +func convertTextFormatToTool(ctx *schemas.BifrostContext, textConfig *schemas.ResponsesTextConfig) *BedrockTool { if textConfig == nil || textConfig.Format == nil { return nil } @@ -754,7 +753,7 @@ func convertTextFormatToTool(ctx *context.Context, textConfig *schemas.Responses } toolName = fmt.Sprintf("bf_so_%s", toolName) - (*ctx) = context.WithValue(*ctx, schemas.BifrostContextKeyStructuredOutputToolName, toolName) + ctx.SetValue(schemas.BifrostContextKeyStructuredOutputToolName, toolName) var schemaObj any if format.JSONSchema != nil { diff --git a/core/providers/cerebras/cerebras.go b/core/providers/cerebras/cerebras.go index 69838c7b5f..47891f004c 100644 --- a/core/providers/cerebras/cerebras.go +++ b/core/providers/cerebras/cerebras.go @@ -2,7 +2,6 @@ package cerebras import ( - "context" "strings" "time" @@ -59,7 +58,7 @@ func (provider *CerebrasProvider) GetProviderKey() schemas.ModelProvider { } // ListModels performs a list models request to Cerebras's API. -func (provider *CerebrasProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { return openai.HandleOpenAIListModelsRequest( ctx, provider.client, @@ -77,7 +76,7 @@ func (provider *CerebrasProvider) ListModels(ctx context.Context, keys []schemas // TextCompletion performs a text completion request to Cerebras's API. // It formats the request, sends it to Cerebras, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *CerebrasProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionRequest( ctx, provider.client, @@ -88,6 +87,7 @@ func (provider *CerebrasProvider) TextCompletion(ctx context.Context, key schema provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), + nil, provider.logger, ) } @@ -95,7 +95,7 @@ func (provider *CerebrasProvider) TextCompletion(ctx context.Context, key schema // TextCompletionStream performs a streaming text completion request to Cerebras's API. // It formats the request, sends it to Cerebras, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *CerebrasProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CerebrasProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { var authHeader map[string]string if key.Value != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value} @@ -111,6 +111,7 @@ func (provider *CerebrasProvider) TextCompletionStream(ctx context.Context, post providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, postHookRunner, nil, provider.logger, @@ -118,7 +119,7 @@ func (provider *CerebrasProvider) TextCompletionStream(ctx context.Context, post } // ChatCompletion performs a chat completion request to the Cerebras API. -func (provider *CerebrasProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, @@ -129,6 +130,7 @@ func (provider *CerebrasProvider) ChatCompletion(ctx context.Context, key schema providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, provider.logger, ) } @@ -137,7 +139,7 @@ func (provider *CerebrasProvider) ChatCompletion(ctx context.Context, key schema // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Cerebras's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *CerebrasProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CerebrasProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { var authHeader map[string]string if key.Value != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value} @@ -157,11 +159,12 @@ func (provider *CerebrasProvider) ChatCompletionStream(ctx context.Context, post nil, nil, nil, + nil, provider.logger, ) } -func (provider *CerebrasProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { return nil, err @@ -176,8 +179,8 @@ func (provider *CerebrasProvider) Responses(ctx context.Context, key schemas.Key } // ResponsesStream performs a streaming responses request to the Cerebras API. -func (provider *CerebrasProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) +func (provider *CerebrasProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, @@ -187,81 +190,81 @@ func (provider *CerebrasProvider) ResponsesStream(ctx context.Context, postHookR } // Embedding is not supported by the Cerebras provider. -func (provider *CerebrasProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } // Speech is not supported by the Cerebras provider. -func (provider *CerebrasProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) } // SpeechStream is not supported by the Cerebras provider. -func (provider *CerebrasProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CerebrasProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } // Transcription is not supported by the Cerebras provider. -func (provider *CerebrasProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey()) } // TranscriptionStream is not supported by the Cerebras provider. -func (provider *CerebrasProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CerebrasProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } // FileUpload is not supported by Cerebras provider. -func (provider *CerebrasProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) } // FileList is not supported by Cerebras provider. -func (provider *CerebrasProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) } // FileRetrieve is not supported by Cerebras provider. -func (provider *CerebrasProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) } // FileDelete is not supported by Cerebras provider. -func (provider *CerebrasProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) } // FileContent is not supported by Cerebras provider. -func (provider *CerebrasProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) } // BatchCreate is not supported by Cerebras provider. -func (provider *CerebrasProvider) BatchCreate(_ context.Context, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) } // BatchList is not supported by Cerebras provider. -func (provider *CerebrasProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) } // BatchRetrieve is not supported by Cerebras provider. -func (provider *CerebrasProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) } // BatchCancel is not supported by Cerebras provider. -func (provider *CerebrasProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) } // BatchResults is not supported by Cerebras provider. -func (provider *CerebrasProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) } // CountTokens is not supported by the Cerebras provider. -func (provider *CerebrasProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) } diff --git a/core/providers/cohere/cohere.go b/core/providers/cohere/cohere.go index c086a7c433..2adbfd3ac5 100644 --- a/core/providers/cohere/cohere.go +++ b/core/providers/cohere/cohere.go @@ -118,14 +118,14 @@ func (provider *CohereProvider) GetProviderKey() schemas.ModelProvider { } // buildRequestURL constructs the full request URL using the provider's configuration. -func (provider *CohereProvider) buildRequestURL(ctx context.Context, defaultPath string, requestType schemas.RequestType) string { +func (provider *CohereProvider) buildRequestURL(ctx *schemas.BifrostContext, defaultPath string, requestType schemas.RequestType) string { return provider.networkConfig.BaseURL + providerUtils.GetRequestPath(ctx, defaultPath, provider.customProviderConfig, requestType) } // completeRequest sends a request to Cohere's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. -func (provider *CohereProvider) completeRequest(ctx context.Context, jsonData []byte, url string, key string, meta *providerUtils.RequestMetadata) ([]byte, time.Duration, *schemas.BifrostError) { +func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, meta *providerUtils.RequestMetadata) ([]byte, time.Duration, *schemas.BifrostError) { // Create the request with the JSON body req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -170,7 +170,7 @@ func (provider *CohereProvider) completeRequest(ctx context.Context, jsonData [] // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. -func (provider *CohereProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Create request @@ -248,7 +248,7 @@ func (provider *CohereProvider) listModelsByKey(ctx context.Context, key schemas // ListModels performs a list models request to Cohere's API. // Requests are made concurrently for improved performance. -func (provider *CohereProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *CohereProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ListModelsRequest); err != nil { return nil, err } @@ -266,21 +266,21 @@ func (provider *CohereProvider) ListModels(ctx context.Context, keys []schemas.K // TextCompletion is not supported by the Cohere provider. // Returns an error indicating that text completion is not supported. -func (provider *CohereProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *CohereProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey()) } // TextCompletionStream performs a streaming text completion request to Cohere's API. // It formats the request, sends it to Cohere, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *CohereProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CohereProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } // ChatCompletion performs a chat completion request to the Cohere API using v2 converter. // It formats the request, sends it to Cohere, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { // Check if chat completion is allowed if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { return nil, err @@ -338,7 +338,7 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas. // ChatCompletionStream performs a streaming chat completion request to the Cohere API. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if chat completion stream is allowed if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err @@ -417,8 +417,18 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) @@ -430,6 +440,10 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo var responseID string for scanner.Scan() { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } line := scanner.Text() // Skip empty lines and comments @@ -466,7 +480,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo Provider: providerName, ModelRequested: request.Model, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) break } @@ -493,7 +507,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) break } @@ -503,6 +517,11 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo } if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) } @@ -512,7 +531,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo } // Responses performs a responses request to the Cohere API using v2 converter. -func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { // Check if chat completion is allowed if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { return nil, err @@ -570,7 +589,7 @@ func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key, } // ResponsesStream performs a streaming responses request to the Cohere API. -func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if responses stream is allowed if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { return nil, err @@ -650,8 +669,18 @@ func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRun // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) @@ -671,6 +700,10 @@ func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRun var eventData string for scanner.Scan() { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } line := scanner.Text() // Skip empty lines and comments @@ -713,7 +746,7 @@ func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRun Provider: providerName, ModelRequested: request.Model, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) break } @@ -743,7 +776,7 @@ func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRun providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) return } @@ -756,6 +789,11 @@ func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRun } if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading %s stream: %v", providerName, err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) } @@ -766,7 +804,7 @@ func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRun // Embedding generates embeddings for the given input text(s) using the Cohere API. // Supports Cohere's embedding models and returns a BifrostResponse containing the embedding(s). -func (provider *CohereProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { // Check if embedding is allowed if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.EmbeddingRequest); err != nil { return nil, err @@ -822,77 +860,77 @@ func (provider *CohereProvider) Embedding(ctx context.Context, key schemas.Key, } // Speech is not supported by the Cohere provider. -func (provider *CohereProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *CohereProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) } // SpeechStream is not supported by the Cohere provider. -func (provider *CohereProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CohereProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } // Transcription is not supported by the Cohere provider. -func (provider *CohereProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *CohereProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey()) } // TranscriptionStream is not supported by the Cohere provider. -func (provider *CohereProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CohereProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } // BatchCreate is not supported by Cohere provider. -func (provider *CohereProvider) BatchCreate(_ context.Context, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *CohereProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) } // BatchList is not supported by Cohere provider. -func (provider *CohereProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *CohereProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) } // BatchRetrieve is not supported by Cohere provider. -func (provider *CohereProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *CohereProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) } // BatchCancel is not supported by Cohere provider. -func (provider *CohereProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *CohereProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) } // BatchResults is not supported by Cohere provider. -func (provider *CohereProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *CohereProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) } // FileUpload is not supported by Cohere provider. -func (provider *CohereProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *CohereProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) } // FileList is not supported by Cohere provider. -func (provider *CohereProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *CohereProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) } // FileRetrieve is not supported by Cohere provider. -func (provider *CohereProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *CohereProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) } // FileDelete is not supported by Cohere provider. -func (provider *CohereProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *CohereProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) } // FileContent is not supported by Cohere provider. -func (provider *CohereProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *CohereProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) } // CountTokens performs a token counting request via Cohere's /v1/tokenize API. -func (provider *CohereProvider) CountTokens(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.CountTokensRequest); err != nil { return nil, err } diff --git a/core/providers/cohere/responses.go b/core/providers/cohere/responses.go index 5d98d0958b..29971514c5 100644 --- a/core/providers/cohere/responses.go +++ b/core/providers/cohere/responses.go @@ -1031,7 +1031,7 @@ func ToCohereResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*Coh // Process ResponsesInput (which contains the Responses items) if bifrostReq.Input != nil { - cohereReq.Messages = ConvertBifrostMessagesToCohereMessages(bifrostReq.Input) + cohereReq.Messages = ConvertBifrostMessagesToCohereMessages(bifrostReq.Input, bifrostReq.Params) } return cohereReq, nil @@ -1082,7 +1082,7 @@ func (response *CohereChatResponse) ToBifrostResponsesResponse() *schemas.Bifros // ConvertBifrostMessagesToCohereMessages converts an array of Bifrost ResponsesMessage to Cohere message format // This is the main conversion method from Bifrost to Cohere - handles all message types and returns messages -func ConvertBifrostMessagesToCohereMessages(bifrostMessages []schemas.ResponsesMessage) []CohereMessage { +func ConvertBifrostMessagesToCohereMessages(bifrostMessages []schemas.ResponsesMessage, params *schemas.ResponsesParameters) []CohereMessage { var cohereMessages []CohereMessage var systemContent []string var pendingReasoningContentBlocks []CohereContentBlock @@ -1212,6 +1212,13 @@ func ConvertBifrostMessagesToCohereMessages(bifrostMessages []schemas.ResponsesM Content: NewStringContent(strings.Join(systemContent, "\n")), } cohereMessages = append([]CohereMessage{systemMsg}, cohereMessages...) + } else if params != nil && params.Instructions != nil { + // if no system messages, check if instructions are present + systemMsg := CohereMessage{ + Role: "system", + Content: NewStringContent(*params.Instructions), + } + cohereMessages = append([]CohereMessage{systemMsg}, cohereMessages...) } return cohereMessages diff --git a/core/providers/elevenlabs/elevenlabs.go b/core/providers/elevenlabs/elevenlabs.go index 0e8d082b39..07d27b059d 100644 --- a/core/providers/elevenlabs/elevenlabs.go +++ b/core/providers/elevenlabs/elevenlabs.go @@ -69,7 +69,7 @@ func (provider *ElevenlabsProvider) GetProviderKey() schemas.ModelProvider { // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. -func (provider *ElevenlabsProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Create request @@ -127,7 +127,7 @@ func (provider *ElevenlabsProvider) listModelsByKey(ctx context.Context, key sch // ListModels performs a list models request to Elevenlabs' API. // Requests are made concurrently for improved performance. -func (provider *ElevenlabsProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Elevenlabs, provider.customProviderConfig, schemas.ListModelsRequest); err != nil { return nil, err } @@ -141,42 +141,42 @@ func (provider *ElevenlabsProvider) ListModels(ctx context.Context, keys []schem } // TextCompletion is not supported by the Elevenlabs provider -func (provider *ElevenlabsProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey()) } // TextCompletionStream is not supported by the Elevenlabs provider -func (provider *ElevenlabsProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } // ChatCompletion is not supported by the Elevenlabs provider -func (provider *ElevenlabsProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ChatCompletionRequest, provider.GetProviderKey()) } // ChatCompletionStream is not supported by the Elevenlabs provider -func (provider *ElevenlabsProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ChatCompletionStreamRequest, provider.GetProviderKey()) } // Responses is not supported by the Elevenlabs provider -func (provider *ElevenlabsProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ResponsesRequest, provider.GetProviderKey()) } // ResponsesStream is not supported by the Elevenlabs provider -func (provider *ElevenlabsProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ResponsesStreamRequest, provider.GetProviderKey()) } // Embedding is not supported by the Elevenlabs provider. -func (provider *ElevenlabsProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, input *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } // Speech performs a text to speech request -func (provider *ElevenlabsProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Elevenlabs, provider.customProviderConfig, schemas.SpeechRequest); err != nil { return nil, err } @@ -296,7 +296,7 @@ func (provider *ElevenlabsProvider) Speech(ctx context.Context, key schemas.Key, } // SpeechStream performs a text to speech stream request -func (provider *ElevenlabsProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Elevenlabs, provider.customProviderConfig, schemas.SpeechStreamRequest); err != nil { return nil, err } @@ -376,8 +376,18 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx context.Context, postHookRu responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) go func() { + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) - defer close(responseChan) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) + defer stopCancellation() // read binary audio chunks from the stream // 4KB buffer for reading chunks @@ -387,18 +397,20 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx context.Context, postHookRu lastChunkTime := time.Now() for { - // Check if context is done before processing - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } - n, err := bodyStream.Read(buffer) if err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } if err == io.EOF { break - } + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger) return @@ -448,7 +460,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx context.Context, postHookRu if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&finalResponse.ExtraFields, jsonBody) } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, finalResponse, nil), responseChan) }() @@ -456,7 +468,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx context.Context, postHookRu } // Transcription performs a transcription request -func (provider *ElevenlabsProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Elevenlabs, provider.customProviderConfig, schemas.TranscriptionRequest); err != nil { return nil, err } @@ -687,12 +699,12 @@ func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTr } // TranscriptionStream is not supported by the Elevenlabs provider -func (provider *ElevenlabsProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } // buildSpeechRequestURL constructs the full request URL using the provider's configuration for speech. -func (provider *ElevenlabsProvider) buildBaseSpeechRequestURL(ctx context.Context, defaultPath string, requestType schemas.RequestType, request *schemas.BifrostSpeechRequest) string { +func (provider *ElevenlabsProvider) buildBaseSpeechRequestURL(ctx *schemas.BifrostContext, defaultPath string, requestType schemas.RequestType, request *schemas.BifrostSpeechRequest) string { baseURL := provider.networkConfig.BaseURL requestPath := providerUtils.GetRequestPath(ctx, defaultPath, provider.customProviderConfig, requestType) @@ -724,56 +736,56 @@ func (provider *ElevenlabsProvider) buildBaseSpeechRequestURL(ctx context.Contex } // BatchCreate is not supported by Elevenlabs provider. -func (provider *ElevenlabsProvider) BatchCreate(_ context.Context, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) } // BatchList is not supported by Elevenlabs provider. -func (provider *ElevenlabsProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) } // BatchRetrieve is not supported by Elevenlabs provider. -func (provider *ElevenlabsProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) } // BatchCancel is not supported by Elevenlabs provider. -func (provider *ElevenlabsProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) } // BatchResults is not supported by Elevenlabs provider. -func (provider *ElevenlabsProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) } // FileUpload is not supported by Elevenlabs provider. -func (provider *ElevenlabsProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) } // FileList is not supported by Elevenlabs provider. -func (provider *ElevenlabsProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) } // FileRetrieve is not supported by Elevenlabs provider. -func (provider *ElevenlabsProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) } // FileDelete is not supported by Elevenlabs provider. -func (provider *ElevenlabsProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) } // FileContent is not supported by Elevenlabs provider. -func (provider *ElevenlabsProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) } // CountTokens is not supported by the Elevenlabs provider. -func (provider *ElevenlabsProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *ElevenlabsProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) } diff --git a/core/providers/gemini/chat.go b/core/providers/gemini/chat.go index bd5fef5f14..6bc416f90e 100644 --- a/core/providers/gemini/chat.go +++ b/core/providers/gemini/chat.go @@ -37,7 +37,7 @@ func ToGeminiChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) *Gemi if bifrostReq.Params.ExtraParams != nil { // Safety settings if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok { - if settings, ok := safetySettings.([]SafetySetting); ok { + if settings, ok := SafeExtractSafetySettings(safetySettings); ok { geminiReq.SafetySettings = settings } } @@ -49,7 +49,7 @@ func ToGeminiChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) *Gemi // Labels if labels, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "labels"); ok { - if labelMap, ok := labels.(map[string]string); ok { + if labelMap, ok := schemas.SafeExtractStringMap(labels); ok { geminiReq.Labels = labelMap } } diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go index 9eb2b98ad8..50894ffb1c 100644 --- a/core/providers/gemini/gemini.go +++ b/core/providers/gemini/gemini.go @@ -70,7 +70,7 @@ func (provider *GeminiProvider) GetProviderKey() schemas.ModelProvider { } // completeRequest handles the common HTTP request pattern for Gemini API calls -func (provider *GeminiProvider) completeRequest(ctx context.Context, model string, key schemas.Key, jsonBody []byte, endpoint string, meta *providerUtils.RequestMetadata) (*GenerateContentResponse, interface{}, time.Duration, *schemas.BifrostError) { +func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, model string, key schemas.Key, jsonBody []byte, endpoint string, meta *providerUtils.RequestMetadata) (*GenerateContentResponse, interface{}, time.Duration, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Create request @@ -130,7 +130,7 @@ func (provider *GeminiProvider) completeRequest(ctx context.Context, model strin // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. -func (provider *GeminiProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Create request @@ -190,7 +190,7 @@ func (provider *GeminiProvider) listModelsByKey(ctx context.Context, key schemas // ListModels performs a list models request to Gemini's API. // Requests are made concurrently for improved performance. -func (provider *GeminiProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ListModelsRequest); err != nil { return nil, err } @@ -207,19 +207,19 @@ func (provider *GeminiProvider) ListModels(ctx context.Context, keys []schemas.K } // TextCompletion is not supported by the Gemini provider. -func (provider *GeminiProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey()) } // TextCompletionStream performs a streaming text completion request to Gemini's API. // It formats the request, sends it to Gemini, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *GeminiProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GeminiProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } // ChatCompletion performs a chat completion request to the Gemini API. -func (provider *GeminiProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { // Check if chat completion is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { return nil, err @@ -268,7 +268,7 @@ func (provider *GeminiProvider) ChatCompletion(ctx context.Context, key schemas. // ChatCompletionStream performs a streaming chat completion request to the Gemini API. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *GeminiProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GeminiProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if chat completion stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err @@ -318,7 +318,7 @@ func (provider *GeminiProvider) ChatCompletionStream(ctx context.Context, postHo // HandleGeminiChatCompletionStream handles streaming for Gemini-compatible APIs. func HandleGeminiChatCompletionStream( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, jsonBody []byte, @@ -380,9 +380,31 @@ func HandleGeminiChatCompletionStream( // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + if resp.BodyStream() == nil { + bifrostErr := providerUtils.NewBifrostOperationError( + "Provider returned an empty response", + fmt.Errorf("provider returned an empty response"), + providerName, + ) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) + return + } + + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() + scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) scanner.Buffer(buf, 10*1024*1024) @@ -395,10 +417,9 @@ func HandleGeminiChatCompletionStream( var modelName string for scanner.Scan() { - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } line := scanner.Text() @@ -406,19 +427,15 @@ func HandleGeminiChatCompletionStream( if line == "" || strings.HasPrefix(line, ":") { continue } - // Parse SSE data if !strings.HasPrefix(line, "data: ") { continue } - eventData := strings.TrimPrefix(line, "data: ") - // Skip empty data if strings.TrimSpace(eventData) == "" { continue } - // Process chunk using shared function geminiResponse, err := processGeminiStreamChunk(eventData) if err != nil { @@ -437,7 +454,7 @@ func HandleGeminiChatCompletionStream( ModelRequested: model, }, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) return } @@ -461,7 +478,7 @@ func HandleGeminiChatCompletionStream( Provider: providerName, ModelRequested: model, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) return } @@ -499,7 +516,7 @@ func HandleGeminiChatCompletionStream( providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) break } @@ -511,6 +528,11 @@ func HandleGeminiChatCompletionStream( // Handle scanner errors if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, model, logger) } @@ -522,7 +544,7 @@ func HandleGeminiChatCompletionStream( // Responses performs a chat completion request to Gemini's API. // It formats the request, sends it to Gemini, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *GeminiProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { return nil, err } @@ -576,7 +598,7 @@ func (provider *GeminiProvider) Responses(ctx context.Context, key schemas.Key, } // ResponsesStream performs a streaming responses request to the Gemini API. -func (provider *GeminiProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GeminiProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if responses stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { return nil, err @@ -625,7 +647,7 @@ func (provider *GeminiProvider) ResponsesStream(ctx context.Context, postHookRun // HandleGeminiResponsesStream handles streaming for Gemini-compatible APIs. func HandleGeminiResponsesStream( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, jsonBody []byte, @@ -687,9 +709,32 @@ func HandleGeminiResponsesStream( // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + } + close(responseChan) + }() + defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() + + if resp.BodyStream() == nil { + bifrostErr := providerUtils.NewBifrostOperationError( + "Provider returned an empty response", + fmt.Errorf("provider returned an empty response"), + providerName, + ) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) + return + } + scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) scanner.Buffer(buf, 10*1024*1024) @@ -706,11 +751,11 @@ func HandleGeminiResponsesStream( var lastUsageMetadata *GenerateContentResponseUsageMetadata for scanner.Scan() { - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } + line := scanner.Text() // Skip empty lines and comments @@ -748,7 +793,7 @@ func HandleGeminiResponsesStream( ModelRequested: model, }, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) return } @@ -769,7 +814,7 @@ func HandleGeminiResponsesStream( Provider: providerName, ModelRequested: model, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) return } @@ -811,7 +856,7 @@ func HandleGeminiResponsesStream( providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) return } @@ -829,43 +874,50 @@ func HandleGeminiResponsesStream( // Handle scanner errors if err := scanner.Err(); err != nil { + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, model, logger) - } else { - // Finalize the stream by closing any open items - finalResponses := FinalizeGeminiResponsesStream(streamState, lastUsageMetadata, sequenceNumber) - for i, finalResponse := range finalResponses { - finalResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), - } - - if postResponseConverter != nil { - finalResponse = postResponseConverter(finalResponse) - if finalResponse == nil { - logger.Warn("postResponseConverter returned nil; skipping final response") - continue - } - } - - chunkIndex++ - sequenceNumber++ + return + } + // Finalize the stream by closing any open items + finalResponses := FinalizeGeminiResponsesStream(streamState, lastUsageMetadata, sequenceNumber) + for i, finalResponse := range finalResponses { + if finalResponse == nil { + logger.Warn("FinalizeGeminiResponsesStream returned nil; skipping final response") + continue + } + finalResponse.ExtraFields = schemas.BifrostResponseExtraFields{ + RequestType: schemas.ResponsesStreamRequest, + Provider: providerName, + ModelRequested: model, + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), + } - if sendBackRawResponse { - finalResponse.ExtraFields.RawResponse = "{}" // Final event has no payload + if postResponseConverter != nil { + finalResponse = postResponseConverter(finalResponse) + if finalResponse == nil { + logger.Warn("postResponseConverter returned nil; skipping final response") + continue } + } - // Set final latency on the last response (completed event) - if i == len(finalResponses)-1 { - finalResponse.ExtraFields.Latency = time.Since(startTime).Milliseconds() - } + chunkIndex++ + sequenceNumber++ - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil), responseChan) + if sendBackRawResponse { + finalResponse.ExtraFields.RawResponse = "{}" // Final event has no payload + } + isLast := i == len(finalResponses)-1 + // Set final latency on the last response (completed event) + if isLast { + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + finalResponse.ExtraFields.Latency = time.Since(startTime).Milliseconds() } + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil), responseChan) } }() @@ -873,7 +925,7 @@ func HandleGeminiResponsesStream( } // Embedding performs an embedding request to the Gemini API. -func (provider *GeminiProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { // Check if embedding is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.EmbeddingRequest); err != nil { return nil, err @@ -966,7 +1018,7 @@ func (provider *GeminiProvider) Embedding(ctx context.Context, key schemas.Key, } // Speech performs a speech synthesis request to the Gemini API. -func (provider *GeminiProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { // Check if speech is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.SpeechRequest); err != nil { return nil, err @@ -991,7 +1043,9 @@ func (provider *GeminiProvider) Speech(ctx context.Context, key schemas.Key, req if bifrostErr != nil { return nil, bifrostErr } - ctx = context.WithValue(ctx, BifrostContextKeyResponseFormat, request.Params.ResponseFormat) + if request.Params != nil { + ctx.SetValue(BifrostContextKeyResponseFormat, request.Params.ResponseFormat) + } response, convErr := geminiResponse.ToBifrostSpeechResponse(ctx) if convErr != nil { return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) @@ -1015,7 +1069,7 @@ func (provider *GeminiProvider) Speech(ctx context.Context, key schemas.Key, req } // SpeechStream performs a streaming speech synthesis request to the Gemini API. -func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if speech stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.SpeechStreamRequest); err != nil { return nil, err @@ -1091,8 +1145,20 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner // Start streaming in a goroutine go func() { + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + } + close(responseChan) + }() + defer providerUtils.ReleaseStreamingResponse(resp) - defer close(responseChan) + + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) // Increase buffer size to handle large chunks (especially for audio data) @@ -1104,11 +1170,11 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner lastChunkTime := startTime for scanner.Scan() { - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } + line := scanner.Text() // Skip empty lines @@ -1148,7 +1214,7 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner ModelRequested: request.Model, }, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return } @@ -1208,38 +1274,40 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil), responseChan) } } - // Handle scanner errors if err := scanner.Err(); err != nil { + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger) - } else { - response := &schemas.BifrostSpeechStreamResponse{ - Type: schemas.SpeechStreamResponseTypeDone, - Usage: usage, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex + 1, - Latency: time.Since(startTime).Milliseconds(), - }, - } - - // Set raw request if enabled - if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { - providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) - } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil), responseChan) + return + } + response := &schemas.BifrostSpeechStreamResponse{ + Type: schemas.SpeechStreamResponseTypeDone, + Usage: usage, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.SpeechStreamRequest, + Provider: providerName, + ModelRequested: request.Model, + ChunkIndex: chunkIndex + 1, + Latency: time.Since(startTime).Milliseconds(), + }, + } + // Set raw request if enabled + if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { + providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil), responseChan) }() return responseChan, nil } // Transcription performs a speech-to-text request to the Gemini API. -func (provider *GeminiProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { // Check if transcription is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.TranscriptionRequest); err != nil { return nil, err @@ -1285,7 +1353,7 @@ func (provider *GeminiProvider) Transcription(ctx context.Context, key schemas.K } // TranscriptionStream performs a streaming speech-to-text request to the Gemini API. -func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if transcription stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.TranscriptionStreamRequest); err != nil { return nil, err @@ -1360,8 +1428,18 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) // Increase buffer size to handle large chunks (especially for audio data) @@ -1375,11 +1453,11 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo var fullTranscriptionText string for scanner.Scan() { - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } + line := scanner.Text() // Skip empty lines @@ -1422,7 +1500,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo ModelRequested: request.Model, }, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return } @@ -1489,34 +1567,39 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo // Handle scanner errors if err := scanner.Err(); err != nil { + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, provider.logger) - } else { - response := &schemas.BifrostTranscriptionStreamResponse{ - Type: schemas.TranscriptionStreamResponseTypeDone, - Text: fullTranscriptionText, - Usage: &schemas.TranscriptionUsage{ - Type: "tokens", - InputTokens: usage.InputTokens, - OutputTokens: usage.OutputTokens, - TotalTokens: usage.TotalTokens, - }, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex + 1, - Latency: time.Since(startTime).Milliseconds(), - }, - } + return + } + response := &schemas.BifrostTranscriptionStreamResponse{ + Type: schemas.TranscriptionStreamResponseTypeDone, + Text: fullTranscriptionText, + Usage: &schemas.TranscriptionUsage{ + Type: "tokens", + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + TotalTokens: usage.TotalTokens, + }, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.TranscriptionStreamRequest, + Provider: providerName, + ModelRequested: request.Model, + ChunkIndex: chunkIndex + 1, + Latency: time.Since(startTime).Milliseconds(), + }, + } - // Set raw request if enabled - if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { - providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) - } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response), responseChan) + // Set raw request if enabled + if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { + providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response), responseChan) + }() return responseChan, nil @@ -1527,7 +1610,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo // BatchCreate creates a new batch job for Gemini. // Uses the asynchronous batchGenerateContent endpoint as per official documentation. // Supports both inline requests and file-based input (via InputFileID). -func (provider *GeminiProvider) BatchCreate(ctx context.Context, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.BatchCreateRequest); err != nil { return nil, err } @@ -1698,7 +1781,7 @@ func (provider *GeminiProvider) BatchCreate(ctx context.Context, key schemas.Key } // batchListByKey lists batch jobs for Gemini for a single key. -func (provider *GeminiProvider) batchListByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, time.Duration, *schemas.BifrostError) { +func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, time.Duration, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Create HTTP request @@ -1808,7 +1891,7 @@ func (provider *GeminiProvider) batchListByKey(ctx context.Context, key schemas. // Note: The consumer API may have limited list functionality. // BatchList lists batch jobs using serial pagination across keys. // Exhausts all pages from one key before moving to the next. -func (provider *GeminiProvider) BatchList(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.BatchListRequest); err != nil { return nil, err } @@ -1881,7 +1964,7 @@ func (provider *GeminiProvider) BatchList(ctx context.Context, keys []schemas.Ke } // batchRetrieveByKey retrieves a specific batch job for Gemini for a single key. -func (provider *GeminiProvider) batchRetrieveByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) batchRetrieveByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Create HTTP request @@ -1965,7 +2048,7 @@ func (provider *GeminiProvider) batchRetrieveByKey(ctx context.Context, key sche } // BatchRetrieve retrieves a specific batch job for Gemini, trying each key until successful. -func (provider *GeminiProvider) BatchRetrieve(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.BatchRetrieveRequest); err != nil { return nil, err } @@ -1996,7 +2079,7 @@ func (provider *GeminiProvider) BatchRetrieve(ctx context.Context, keys []schema } // batchCancelByKey cancels a batch job for Gemini for a single key. -func (provider *GeminiProvider) batchCancelByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) batchCancelByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Create HTTP request @@ -2062,7 +2145,7 @@ func (provider *GeminiProvider) batchCancelByKey(ctx context.Context, key schema // BatchCancel cancels a batch job for Gemini, trying each key until successful. // Note: Cancellation support depends on the API version and batch state. -func (provider *GeminiProvider) BatchCancel(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) BatchCancel(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.BatchCancelRequest); err != nil { return nil, err } @@ -2115,7 +2198,7 @@ func processGeminiStreamChunk(jsonData string) (*GenerateContentResponse, error) } // batchResultsByKey retrieves batch results for Gemini for a single key. -func (provider *GeminiProvider) batchResultsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() // We need to get the full batch response with results, so make the API call directly @@ -2275,7 +2358,7 @@ func (provider *GeminiProvider) batchResultsByKey(ctx context.Context, key schem // BatchResults retrieves batch results for Gemini, trying each key until successful. // Results are extracted from dest.inlinedResponses for inline batches, // or downloaded from dest.fileName for file-based batches. -func (provider *GeminiProvider) BatchResults(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) BatchResults(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.BatchResultsRequest); err != nil { return nil, err } @@ -2306,7 +2389,7 @@ func (provider *GeminiProvider) BatchResults(ctx context.Context, keys []schemas } // FileUpload uploads a file to Gemini. -func (provider *GeminiProvider) FileUpload(ctx context.Context, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.FileUploadRequest); err != nil { return nil, err } @@ -2443,7 +2526,7 @@ func (provider *GeminiProvider) FileUpload(ctx context.Context, key schemas.Key, } // fileListByKey lists files from Gemini for a single key. -func (provider *GeminiProvider) fileListByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, time.Duration, *schemas.BifrostError) { +func (provider *GeminiProvider) fileListByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, time.Duration, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Create request @@ -2548,7 +2631,7 @@ func (provider *GeminiProvider) fileListByKey(ctx context.Context, key schemas.K // FileList lists files from Gemini across all provided keys. // FileList lists files using serial pagination across keys. // Exhausts all pages from one key before moving to the next. -func (provider *GeminiProvider) FileList(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.FileListRequest); err != nil { return nil, err } @@ -2621,7 +2704,7 @@ func (provider *GeminiProvider) FileList(ctx context.Context, keys []schemas.Key } // fileRetrieveByKey retrieves file metadata from Gemini for a single key. -func (provider *GeminiProvider) fileRetrieveByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) fileRetrieveByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Create request @@ -2705,7 +2788,7 @@ func (provider *GeminiProvider) fileRetrieveByKey(ctx context.Context, key schem } // FileRetrieve retrieves file metadata from Gemini, trying each key until successful. -func (provider *GeminiProvider) FileRetrieve(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) FileRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.FileRetrieveRequest); err != nil { return nil, err } @@ -2736,7 +2819,7 @@ func (provider *GeminiProvider) FileRetrieve(ctx context.Context, keys []schemas } // fileDeleteByKey deletes a file from Gemini for a single key. -func (provider *GeminiProvider) fileDeleteByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) fileDeleteByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Create request @@ -2787,7 +2870,7 @@ func (provider *GeminiProvider) fileDeleteByKey(ctx context.Context, key schemas } // FileDelete deletes a file from Gemini, trying each key until successful. -func (provider *GeminiProvider) FileDelete(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) FileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.FileDeleteRequest); err != nil { return nil, err } @@ -2820,7 +2903,7 @@ func (provider *GeminiProvider) FileDelete(ctx context.Context, keys []schemas.K // FileContent downloads file content from Gemini. // Note: Gemini Files API doesn't support direct content download. // Files are accessed via their URI in API requests. -func (provider *GeminiProvider) FileContent(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) FileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.FileContentRequest); err != nil { return nil, err } @@ -2837,7 +2920,7 @@ func (provider *GeminiProvider) FileContent(ctx context.Context, keys []schemas. } // CountTokens performs a token counting request to Gemini's countTokens endpoint. -func (provider *GeminiProvider) CountTokens(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.CountTokensRequest); err != nil { return nil, err } diff --git a/core/providers/gemini/responses.go b/core/providers/gemini/responses.go index 87d6a8c079..1368e0eea1 100644 --- a/core/providers/gemini/responses.go +++ b/core/providers/gemini/responses.go @@ -27,23 +27,25 @@ func (request *GeminiGenerationRequest) ToBifrostResponsesRequest() *schemas.Bif params := request.convertGenerationConfigToResponsesParameters() + // Convert SystemInstruction to system messages first + var inputMessages []schemas.ResponsesMessage + if request.SystemInstruction != nil && len(request.SystemInstruction.Parts) > 0 { + systemMsg := convertGeminiSystemInstructionToResponsesMessage(request.SystemInstruction) + if systemMsg != nil { + inputMessages = append(inputMessages, *systemMsg) + } + } + // Convert Contents to Input messages if len(request.Contents) > 0 { - bifrostReq.Input = convertGeminiContentsToResponsesMessages(request.Contents) + contentsMessages := convertGeminiContentsToResponsesMessages(request.Contents) + if len(contentsMessages) > 0 { + inputMessages = append(inputMessages, contentsMessages...) + } } - if request.SystemInstruction != nil { - var systemInstructionText string - if len(request.SystemInstruction.Parts) > 0 { - for _, part := range request.SystemInstruction.Parts { - if part.Text != "" { - systemInstructionText += part.Text - } - } - } - if systemInstructionText != "" { - params.Instructions = &systemInstructionText - } + if len(inputMessages) > 0 { + bifrostReq.Input = inputMessages } if len(request.Tools) > 0 { @@ -106,14 +108,27 @@ func ToGeminiResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *Gemi } } - if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil { - if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok { - if settings, ok := safetySettings.([]SafetySetting); ok { - geminiReq.SafetySettings = settings + if bifrostReq.Params != nil { + if bifrostReq.Params.Instructions != nil { + // check if system instruction is already set + if geminiReq.SystemInstruction == nil { + geminiReq.SystemInstruction = &Content{ + Parts: []*Part{ + {Text: *bifrostReq.Params.Instructions}, + }, + } } } - if cachedContent, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["cached_content"]); ok { - geminiReq.CachedContent = cachedContent + + if bifrostReq.Params.ExtraParams != nil { + if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok { + if settings, ok := SafeExtractSafetySettings(safetySettings); ok { + geminiReq.SafetySettings = settings + } + } + if cachedContent, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["cached_content"]); ok { + geminiReq.CachedContent = cachedContent + } } } @@ -1362,6 +1377,48 @@ func FinalizeGeminiResponsesStream(state *GeminiResponsesStreamState, usage *Gen return closeGeminiOpenItems(state, usage, sequenceNumber) } +// convertGeminiSystemInstructionToResponsesMessage converts Gemini SystemInstruction to a system role message +func convertGeminiSystemInstructionToResponsesMessage(systemInstruction *Content) *schemas.ResponsesMessage { + if systemInstruction == nil || len(systemInstruction.Parts) == 0 { + return nil + } + + var contentBlocks []schemas.ResponsesMessageContentBlock + var hasTextContent bool + + for _, part := range systemInstruction.Parts { + if part.Text != "" { + contentBlocks = append(contentBlocks, schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeText, + Text: &part.Text, + }) + hasTextContent = true + } + } + + if !hasTextContent { + return nil + } + + // If single text block, use ContentStr + if len(contentBlocks) == 1 { + return &schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem), + Content: &schemas.ResponsesMessageContent{ + ContentStr: contentBlocks[0].Text, + }, + } + } + + // Multiple blocks, use ContentBlocks + return &schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem), + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: contentBlocks, + }, + } +} + func convertGeminiContentsToResponsesMessages(contents []Content) []schemas.ResponsesMessage { var messages []schemas.ResponsesMessage // Track function call IDs by name to match with responses diff --git a/core/providers/gemini/transcription.go b/core/providers/gemini/transcription.go index 7dfb99ebc2..c846ee0c2b 100644 --- a/core/providers/gemini/transcription.go +++ b/core/providers/gemini/transcription.go @@ -115,7 +115,7 @@ func ToGeminiTranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionReques if bifrostReq.Params.ExtraParams != nil { // Safety settings if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok { - if settings, ok := safetySettings.([]SafetySetting); ok { + if settings, ok := SafeExtractSafetySettings(safetySettings); ok { geminiReq.SafetySettings = settings } } @@ -127,7 +127,7 @@ func ToGeminiTranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionReques // Labels if labels, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "labels"); ok { - if labelMap, ok := labels.(map[string]string); ok { + if labelMap, ok := schemas.SafeExtractStringMap(labels); ok { geminiReq.Labels = labelMap } } diff --git a/core/providers/gemini/types.go b/core/providers/gemini/types.go index f0c5ae953a..6d2f85f04e 100644 --- a/core/providers/gemini/types.go +++ b/core/providers/gemini/types.go @@ -98,6 +98,40 @@ type SafetySetting struct { Threshold string `json:"threshold,omitempty"` } +// SafeExtractSafetySettings safely extracts []SafetySetting from an interface{} with type checking. +// Handles both direct []SafetySetting and JSON-deserialized []interface{} cases. +func SafeExtractSafetySettings(value interface{}) ([]SafetySetting, bool) { + if value == nil { + return nil, false + } + switch v := value.(type) { + case []SafetySetting: + return v, true + case []interface{}: + settings := make([]SafetySetting, 0, len(v)) + for _, item := range v { + if m, ok := item.(map[string]interface{}); ok { + setting := SafetySetting{} + if method, ok := m["method"].(string); ok { + setting.Method = method + } + if category, ok := m["category"].(string); ok { + setting.Category = category + } + if threshold, ok := m["threshold"].(string); ok { + setting.Threshold = threshold + } + settings = append(settings, setting) + } else { + return nil, false + } + } + return settings, true + default: + return nil, false + } +} + // FunctionCallingConfig represents function calling configuration. type FunctionCallingConfig struct { // Optional. Function calling mode. diff --git a/core/providers/groq/groq.go b/core/providers/groq/groq.go index 25c93f670b..fc5c37be28 100644 --- a/core/providers/groq/groq.go +++ b/core/providers/groq/groq.go @@ -2,7 +2,6 @@ package groq import ( - "context" "strings" "time" @@ -64,7 +63,7 @@ func (provider *GroqProvider) GetProviderKey() schemas.ModelProvider { } // ListModels performs a list models request to Groq's API. -func (provider *GroqProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *GroqProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { return openai.HandleOpenAIListModelsRequest( ctx, provider.client, @@ -80,7 +79,7 @@ func (provider *GroqProvider) ListModels(ctx context.Context, keys []schemas.Key } // TextCompletion is not supported by the Groq provider. -func (provider *GroqProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *GroqProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { // Checking if litellm fallback is set if _, ok := ctx.Value(schemas.BifrostContextKey("x-litellm-fallback")).(string); !ok { return nil, providerUtils.NewUnsupportedOperationError("text completion", "groq") @@ -109,7 +108,7 @@ func (provider *GroqProvider) TextCompletion(ctx context.Context, key schemas.Ke // TextCompletionStream performs a streaming text completion request to Groq's API. // It formats the request, sends it to Groq, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *GroqProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GroqProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Checking if litellm fallback is set if _, ok := ctx.Value(schemas.BifrostContextKey("x-litellm-fallback")).(string); !ok { return nil, providerUtils.NewUnsupportedOperationError("text completion", "groq") @@ -155,7 +154,7 @@ func (provider *GroqProvider) TextCompletionStream(ctx context.Context, postHook } // ChatCompletion performs a chat completion request to the Groq API. -func (provider *GroqProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *GroqProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, @@ -166,6 +165,7 @@ func (provider *GroqProvider) ChatCompletion(ctx context.Context, key schemas.Ke providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, provider.logger, ) } @@ -174,7 +174,7 @@ func (provider *GroqProvider) ChatCompletion(ctx context.Context, key schemas.Ke // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Groq's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *GroqProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GroqProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { var authHeader map[string]string if key.Value != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value} @@ -194,12 +194,13 @@ func (provider *GroqProvider) ChatCompletionStream(ctx context.Context, postHook nil, nil, nil, + nil, provider.logger, ) } // Responses performs a responses request to the Groq API. -func (provider *GroqProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *GroqProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { return nil, err @@ -214,8 +215,8 @@ func (provider *GroqProvider) Responses(ctx context.Context, key schemas.Key, re } // ResponsesStream performs a streaming responses request to the Groq API. -func (provider *GroqProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) +func (provider *GroqProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, @@ -225,81 +226,81 @@ func (provider *GroqProvider) ResponsesStream(ctx context.Context, postHookRunne } // Embedding is not supported by the Groq provider. -func (provider *GroqProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *GroqProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } // Speech is not supported by the Groq provider. -func (provider *GroqProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *GroqProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) } // SpeechStream is not supported by the Groq provider. -func (provider *GroqProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GroqProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } // Transcription is not supported by the Groq provider. -func (provider *GroqProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *GroqProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey()) } // TranscriptionStream is not supported by the Groq provider. -func (provider *GroqProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GroqProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } // BatchCreate is not supported by Groq provider. -func (provider *GroqProvider) BatchCreate(_ context.Context, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *GroqProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) } // BatchList is not supported by Groq provider. -func (provider *GroqProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *GroqProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) } // BatchRetrieve is not supported by Groq provider. -func (provider *GroqProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *GroqProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) } // BatchCancel is not supported by Groq provider. -func (provider *GroqProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *GroqProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) } // BatchResults is not supported by Groq provider. -func (provider *GroqProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *GroqProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) } // FileUpload is not supported by Groq provider. -func (provider *GroqProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *GroqProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) } // FileList is not supported by Groq provider. -func (provider *GroqProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *GroqProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) } // FileRetrieve is not supported by Groq provider. -func (provider *GroqProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *GroqProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) } // FileDelete is not supported by Groq provider. -func (provider *GroqProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *GroqProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) } // FileContent is not supported by Groq provider. -func (provider *GroqProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *GroqProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) } // CountTokens is not supported by the Groq provider. -func (provider *GroqProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *GroqProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) } diff --git a/core/providers/huggingface/huggingface.go b/core/providers/huggingface/huggingface.go index 2df8d7f87c..2103022b13 100644 --- a/core/providers/huggingface/huggingface.go +++ b/core/providers/huggingface/huggingface.go @@ -1,7 +1,6 @@ package huggingface import ( - "context" "fmt" "net/http" "strings" @@ -104,13 +103,13 @@ func (provider *HuggingFaceProvider) GetProviderKey() schemas.ModelProvider { } // buildRequestURL composes the final request URL based on context overrides. -func (provider *HuggingFaceProvider) buildRequestURL(ctx context.Context, defaultPath string, requestType schemas.RequestType) string { +func (provider *HuggingFaceProvider) buildRequestURL(ctx *schemas.BifrostContext, defaultPath string, requestType schemas.RequestType) string { return provider.networkConfig.BaseURL + providerUtils.GetRequestPath(ctx, defaultPath, provider.customProviderConfig, requestType) } // completeRequestWithModelAliasCache performs a request and retries once on 404 by clearing the cache and refetching model info func (provider *HuggingFaceProvider) completeRequestWithModelAliasCache( - ctx context.Context, + ctx *schemas.BifrostContext, jsonData []byte, key string, isHFInferenceAudioRequest bool, @@ -187,7 +186,7 @@ func (provider *HuggingFaceProvider) completeRequestWithModelAliasCache( return responseBody, latency, nil } -func (provider *HuggingFaceProvider) completeRequest(ctx context.Context, jsonData []byte, url string, key string, isHFInferenceAudioRequest bool) ([]byte, time.Duration, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, isHFInferenceAudioRequest bool) ([]byte, time.Duration, *schemas.BifrostError) { req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() defer fasthttp.ReleaseRequest(req) @@ -249,7 +248,7 @@ func (provider *HuggingFaceProvider) completeRequest(ctx context.Context, jsonDa return bodyCopy, latency, nil } -func (provider *HuggingFaceProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() type providerResult struct { @@ -398,7 +397,7 @@ func (provider *HuggingFaceProvider) listModelsByKey(ctx context.Context, key sc } // ListModels queries the Hugging Face model hub API to list models served by the inference provider. -func (provider *HuggingFaceProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.HuggingFace, provider.customProviderConfig, schemas.ListModelsRequest); err != nil { return nil, err @@ -416,15 +415,15 @@ func (provider *HuggingFaceProvider) ListModels(ctx context.Context, keys []sche } -func (provider *HuggingFaceProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey()) } -func (provider *HuggingFaceProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } -func (provider *HuggingFaceProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.HuggingFace, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { return nil, err } @@ -507,7 +506,7 @@ func (provider *HuggingFaceProvider) ChatCompletion(ctx context.Context, key sch return bifrostResponse, nil } -func (provider *HuggingFaceProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.HuggingFace, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err } @@ -559,11 +558,12 @@ func (provider *HuggingFaceProvider) ChatCompletionStream(ctx context.Context, p customRequestConverter, nil, nil, + nil, provider.logger, ) } -func (provider *HuggingFaceProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.HuggingFace, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { return nil, err } @@ -581,12 +581,12 @@ func (provider *HuggingFaceProvider) Responses(ctx context.Context, key schemas. return response, nil } -func (provider *HuggingFaceProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.HuggingFace, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { return nil, err } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) + ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, @@ -595,7 +595,7 @@ func (provider *HuggingFaceProvider) ResponsesStream(ctx context.Context, postHo ) } -func (provider *HuggingFaceProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.HuggingFace, provider.customProviderConfig, schemas.EmbeddingRequest); err != nil { return nil, err } @@ -680,7 +680,7 @@ func (provider *HuggingFaceProvider) Embedding(ctx context.Context, key schemas. return bifrostResponse, nil } -func (provider *HuggingFaceProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { // Check if Speech is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.HuggingFace, provider.customProviderConfig, schemas.SpeechRequest); err != nil { return nil, err @@ -762,11 +762,11 @@ func (provider *HuggingFaceProvider) Speech(ctx context.Context, key schemas.Key return bifrostResponse, nil } -func (provider *HuggingFaceProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } -func (provider *HuggingFaceProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { // Check if Transcription is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.HuggingFace, provider.customProviderConfig, schemas.TranscriptionRequest); err != nil { return nil, err @@ -856,61 +856,61 @@ func (provider *HuggingFaceProvider) Transcription(ctx context.Context, key sche } // TranscriptionStream is not supported by the Hugging Face provider. -func (provider *HuggingFaceProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } // BatchCreate is not supported by the Hugging Face provider. -func (provider *HuggingFaceProvider) BatchCreate(_ context.Context, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) } // BatchList is not supported by the Hugging Face provider. -func (provider *HuggingFaceProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) } // BatchRetrieve is not supported by the Hugging Face provider. -func (provider *HuggingFaceProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) } // BatchCancel is not supported by the Hugging Face provider. -func (provider *HuggingFaceProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) } // BatchResults is not supported by the Hugging Face provider. -func (provider *HuggingFaceProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) } // FileUpload is not supported by the Hugging Face provider. -func (provider *HuggingFaceProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) } // FileList is not supported by the Hugging Face provider. -func (provider *HuggingFaceProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) } // FileRetrieve is not supported by the Hugging Face provider. -func (provider *HuggingFaceProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) } // FileDelete is not supported by the Hugging Face provider. -func (provider *HuggingFaceProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) } // FileContent is not supported by the Hugging Face provider. -func (provider *HuggingFaceProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) } // CountTokens is not supported by the Hugging Face provider. -func (provider *HuggingFaceProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *HuggingFaceProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) } diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index df8de4a469..34e6c32d13 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -70,7 +70,7 @@ func (provider *MistralProvider) GetProviderKey() schemas.ModelProvider { // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. -func (provider *MistralProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Create request @@ -131,7 +131,7 @@ func (provider *MistralProvider) listModelsByKey(ctx context.Context, key schema // ListModels performs a list models request to Mistral's API. // Requests are made concurrently for improved performance. -func (provider *MistralProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *MistralProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { return providerUtils.HandleMultipleListModelsRequests( ctx, keys, @@ -142,19 +142,19 @@ func (provider *MistralProvider) ListModels(ctx context.Context, keys []schemas. } // TextCompletion is not supported by the Mistral provider. -func (provider *MistralProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *MistralProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey()) } // TextCompletionStream performs a streaming text completion request to Mistral's API. // It formats the request, sends it to Mistral, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *MistralProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *MistralProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } // ChatCompletion performs a chat completion request to the Mistral API. -func (provider *MistralProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *MistralProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, @@ -165,6 +165,7 @@ func (provider *MistralProvider) ChatCompletion(ctx context.Context, key schemas providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, provider.logger, ) } @@ -173,7 +174,7 @@ func (provider *MistralProvider) ChatCompletion(ctx context.Context, key schemas // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Mistral's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *MistralProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *MistralProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { var authHeader map[string]string if key.Value != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value} @@ -193,12 +194,13 @@ func (provider *MistralProvider) ChatCompletionStream(ctx context.Context, postH nil, nil, nil, + nil, provider.logger, ) } // Responses performs a responses request to the Mistral API. -func (provider *MistralProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *MistralProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { return nil, err @@ -213,8 +215,8 @@ func (provider *MistralProvider) Responses(ctx context.Context, key schemas.Key, } // ResponsesStream performs a streaming responses request to the Mistral API. -func (provider *MistralProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) +func (provider *MistralProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, @@ -225,7 +227,7 @@ func (provider *MistralProvider) ResponsesStream(ctx context.Context, postHookRu // Embedding generates embeddings for the given input text(s) using the Mistral API. // Supports Mistral's embedding models and returns a BifrostResponse containing the embedding(s). -func (provider *MistralProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *MistralProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { // Use the shared embedding request handler return openai.HandleOpenAIEmbeddingRequest( ctx, @@ -242,19 +244,19 @@ func (provider *MistralProvider) Embedding(ctx context.Context, key schemas.Key, } // Speech is not supported by the Mistral provider. -func (provider *MistralProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *MistralProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) } // SpeechStream is not supported by the Mistral provider. -func (provider *MistralProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *MistralProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } // Transcription performs an audio transcription request to the Mistral API. // It creates a multipart form with the audio file and sends it to Mistral's transcription endpoint. // Returns the transcribed text and metadata, or an error if the request fails. -func (provider *MistralProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Convert Bifrost request to Mistral format @@ -358,7 +360,7 @@ func (provider *MistralProvider) Transcription(ctx context.Context, key schemas. // TranscriptionStream performs a streaming transcription request to Mistral's API. // It creates a multipart form with the audio file and streams transcription events. // Returns a channel of BifrostStream objects containing transcription deltas. -func (provider *MistralProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Convert Bifrost request to Mistral format @@ -436,8 +438,18 @@ func (provider *MistralProvider) TranscriptionStream(ctx context.Context, postHo // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) // Increase buffer size to handle large chunks @@ -452,13 +464,12 @@ func (provider *MistralProvider) TranscriptionStream(ctx context.Context, postHo var currentData string for scanner.Scan() { - // Check if context is done before processing - select { - case <-ctx.Done(): + + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } - + line := scanner.Text() // Skip empty lines (event delimiter) @@ -496,6 +507,11 @@ func (provider *MistralProvider) TranscriptionStream(ctx context.Context, postHo // Handle scanner errors if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, provider.logger) } @@ -506,7 +522,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx context.Context, postHo // processTranscriptionStreamEvent processes a single SSE event and sends it to the response channel. func (provider *MistralProvider) processTranscriptionStreamEvent( - ctx context.Context, + ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, eventType string, jsonData string, @@ -531,7 +547,7 @@ func (provider *MistralProvider) processTranscriptionStreamEvent( ModelRequested: model, RequestType: schemas.TranscriptionStreamRequest, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger) return } @@ -573,7 +589,7 @@ func (provider *MistralProvider) processTranscriptionStreamEvent( // Check for done event (handle both "transcription.done" and "transcript.text.done") if MistralTranscriptionStreamEventType(eventType) == MistralTranscriptionStreamEventDone || eventType == "transcript.text.done" { response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) // Ensure response type is set to Done response.Type = schemas.TranscriptionStreamResponseTypeDone } @@ -582,56 +598,56 @@ func (provider *MistralProvider) processTranscriptionStreamEvent( } // BatchCreate is not supported by Mistral provider. -func (provider *MistralProvider) BatchCreate(_ context.Context, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *MistralProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) } // BatchList is not supported by Mistral provider. -func (provider *MistralProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *MistralProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) } // BatchRetrieve is not supported by Mistral provider. -func (provider *MistralProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *MistralProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) } // BatchCancel is not supported by Mistral provider. -func (provider *MistralProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *MistralProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) } // BatchResults is not supported by Mistral provider. -func (provider *MistralProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *MistralProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) } // FileUpload is not supported by Mistral provider. -func (provider *MistralProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *MistralProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) } // FileList is not supported by Mistral provider. -func (provider *MistralProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *MistralProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) } // FileRetrieve is not supported by Mistral provider. -func (provider *MistralProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *MistralProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) } // FileDelete is not supported by Mistral provider. -func (provider *MistralProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *MistralProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) } // FileContent is not supported by Mistral provider. -func (provider *MistralProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *MistralProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) } // CountTokens is not supported by the Mistral provider. -func (provider *MistralProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *MistralProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) } diff --git a/core/providers/mistral/transcription_test.go b/core/providers/mistral/transcription_test.go index d1a2dbb68a..dded6da926 100644 --- a/core/providers/mistral/transcription_test.go +++ b/core/providers/mistral/transcription_test.go @@ -568,7 +568,7 @@ func TestTranscriptionWithMockServer(t *testing.T) { } // Make request - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second) defer cancel() resp, err := provider.Transcription(ctx, schemas.Key{Value: "test-api-key"}, request) @@ -599,7 +599,7 @@ func TestTranscriptionNilInput(t *testing.T) { }, }, &testLogger{}) - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) tests := []struct { name string @@ -762,12 +762,12 @@ func TestTranscriptionStreamWithMockServer(t *testing.T) { } // Create post hook runner (no-op for tests) - postHookRunner := func(ctx *context.Context, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { return response, err } // Make streaming request - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second) defer cancel() streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, schemas.Key{Value: "test-api-key"}, request) @@ -805,11 +805,11 @@ func TestTranscriptionStreamNilInput(t *testing.T) { }, &testLogger{}) // Create post hook runner (no-op for tests) - postHookRunner := func(ctx *context.Context, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { return response, err } - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) tests := []struct { name string @@ -1243,11 +1243,11 @@ func TestTranscriptionStreamEdgeCases(t *testing.T) { } // Create post hook runner - postHookRunner := func(ctx *context.Context, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { return response, err } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 10*time.Second) defer cancel() streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, schemas.Key{Value: "test-key"}, request) @@ -1311,12 +1311,12 @@ func TestTranscriptionStreamContextCancellation(t *testing.T) { }, } - postHookRunner := func(ctx *context.Context, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { return response, err } // Create context with short timeout - ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 500*time.Millisecond) defer cancel() streamChan, err := provider.TranscriptionStream(ctx, postHookRunner, schemas.Key{Value: "test-key"}, request) @@ -1495,7 +1495,7 @@ func TestMistralTranscriptionIntegration(t *testing.T) { // Note: Mistral may reject this minimal WAV file - this tests error handling too audioData := createMinimalAudioFile() - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 60*time.Second) defer cancel() request := &schemas.BifrostTranscriptionRequest{ @@ -1553,7 +1553,7 @@ func TestMistralTranscriptionStreamIntegration(t *testing.T) { // Create a minimal but valid audio file for testing audioData := createMinimalAudioFile() - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), 60*time.Second) defer cancel() request := &schemas.BifrostTranscriptionRequest{ @@ -1567,7 +1567,7 @@ func TestMistralTranscriptionStreamIntegration(t *testing.T) { } // Create post hook runner (no-op for tests) - postHookRunner := func(ctx *context.Context, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + postHookRunner := func(ctx *schemas.BifrostContext, response *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { return response, err } diff --git a/core/providers/nebius/nebius.go b/core/providers/nebius/nebius.go index 9a72779b83..45ff1183c9 100644 --- a/core/providers/nebius/nebius.go +++ b/core/providers/nebius/nebius.go @@ -2,7 +2,6 @@ package nebius import ( - "context" "fmt" "strings" "time" @@ -49,6 +48,7 @@ func NewNebiusProvider(config *schemas.ProviderConfig, logger schemas.Logger) (* logger: logger, client: client, networkConfig: config.NetworkConfig, + sendBackRawRequest: config.SendBackRawRequest, sendBackRawResponse: config.SendBackRawResponse, }, nil } @@ -59,7 +59,7 @@ func (provider *NebiusProvider) GetProviderKey() schemas.ModelProvider { } // ListModels performs a list models request to Nebius's API. -func (provider *NebiusProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { return openai.HandleOpenAIListModelsRequest( ctx, provider.client, @@ -77,7 +77,7 @@ func (provider *NebiusProvider) ListModels(ctx context.Context, keys []schemas.K // TextCompletion performs a text completion request to Nebius's API. // It formats the request, sends it to Nebius, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *NebiusProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionRequest( ctx, provider.client, @@ -88,6 +88,7 @@ func (provider *NebiusProvider) TextCompletion(ctx context.Context, key schemas. provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), + nil, provider.logger, ) } @@ -95,7 +96,7 @@ func (provider *NebiusProvider) TextCompletion(ctx context.Context, key schemas. // TextCompletionStream performs a streaming text completion request to Nebius's API. // It formats the request, sends it to Nebius, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *NebiusProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *NebiusProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { var authHeader map[string]string if key.Value != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value} @@ -111,6 +112,7 @@ func (provider *NebiusProvider) TextCompletionStream(ctx context.Context, postHo providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, postHookRunner, nil, provider.logger, @@ -118,7 +120,7 @@ func (provider *NebiusProvider) TextCompletionStream(ctx context.Context, postHo } // ChatCompletion performs a chat completion request to the Nebius API. -func (provider *NebiusProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { path := providerUtils.GetPathFromContext(ctx, "/v1/chat/completions") // Append query parameter if present @@ -140,6 +142,7 @@ func (provider *NebiusProvider) ChatCompletion(ctx context.Context, key schemas. providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, provider.logger, ) } @@ -148,7 +151,7 @@ func (provider *NebiusProvider) ChatCompletion(ctx context.Context, key schemas. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Nebius's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *NebiusProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *NebiusProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { var authHeader map[string]string if key.Value != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value} @@ -169,11 +172,12 @@ func (provider *NebiusProvider) ChatCompletionStream(ctx context.Context, postHo nil, nil, nil, + nil, provider.logger, ) } -func (provider *NebiusProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { return nil, err @@ -188,8 +192,8 @@ func (provider *NebiusProvider) Responses(ctx context.Context, key schemas.Key, } // ResponsesStream performs a streaming responses request to the Nebius API. -func (provider *NebiusProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) +func (provider *NebiusProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, @@ -201,7 +205,7 @@ func (provider *NebiusProvider) ResponsesStream(ctx context.Context, postHookRun // Embedding generates embeddings for the given input text(s). // The input can be either a single string or a slice of strings for batch embedding. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. -func (provider *NebiusProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { return openai.HandleOpenAIEmbeddingRequest( ctx, provider.client, @@ -216,76 +220,76 @@ func (provider *NebiusProvider) Embedding(ctx context.Context, key schemas.Key, } // Speech is not supported by the Nebius provider. -func (provider *NebiusProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) } // SpeechStream is not supported by the Nebius provider. -func (provider *NebiusProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *NebiusProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } // Transcription is not supported by the Nebius provider. -func (provider *NebiusProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey()) } // TranscriptionStream is not supported by the Nebius provider. -func (provider *NebiusProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *NebiusProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } // BatchCreate is not supported by Nebius provider. -func (provider *NebiusProvider) BatchCreate(_ context.Context, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) } // BatchList is not supported by Nebius provider. -func (provider *NebiusProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) } // BatchRetrieve is not supported by Nebius provider. -func (provider *NebiusProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) } // BatchCancel is not supported by Nebius provider. -func (provider *NebiusProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) } // BatchResults is not supported by Nebius provider. -func (provider *NebiusProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) } // FileUpload is not supported by Nebius provider. -func (provider *NebiusProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) } // FileList is not supported by Nebius provider. -func (provider *NebiusProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) } // FileRetrieve is not supported by Nebius provider. -func (provider *NebiusProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) } // FileDelete is not supported by Nebius provider. -func (provider *NebiusProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) } // FileContent is not supported by Nebius provider. -func (provider *NebiusProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) } // CountTokens is not supported by Nebius provider. -func (provider *NebiusProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *NebiusProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) } diff --git a/core/providers/ollama/ollama.go b/core/providers/ollama/ollama.go index 75ad2639e8..d6950bc78c 100644 --- a/core/providers/ollama/ollama.go +++ b/core/providers/ollama/ollama.go @@ -3,7 +3,6 @@ package ollama import ( - "context" "fmt" "strings" "time" @@ -67,7 +66,7 @@ func (provider *OllamaProvider) GetProviderKey() schemas.ModelProvider { } // ListModels performs a list models request to Ollama's API. -func (provider *OllamaProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if provider.networkConfig.BaseURL == "" { return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey()) } @@ -86,7 +85,7 @@ func (provider *OllamaProvider) ListModels(ctx context.Context, keys []schemas.K } // TextCompletion performs a text completion request to the Ollama API. -func (provider *OllamaProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionRequest( ctx, provider.client, @@ -97,6 +96,7 @@ func (provider *OllamaProvider) TextCompletion(ctx context.Context, key schemas. provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), + nil, provider.logger, ) } @@ -104,7 +104,7 @@ func (provider *OllamaProvider) TextCompletion(ctx context.Context, key schemas. // TextCompletionStream performs a streaming text completion request to Ollama's API. // It formats the request, sends it to Ollama, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *OllamaProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OllamaProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionStreaming( ctx, provider.client, @@ -115,6 +115,7 @@ func (provider *OllamaProvider) TextCompletionStream(ctx context.Context, postHo providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, postHookRunner, nil, provider.logger, @@ -122,7 +123,7 @@ func (provider *OllamaProvider) TextCompletionStream(ctx context.Context, postHo } // ChatCompletion performs a chat completion request to the Ollama API. -func (provider *OllamaProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, @@ -133,6 +134,7 @@ func (provider *OllamaProvider) ChatCompletion(ctx context.Context, key schemas. providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, provider.logger, ) } @@ -141,7 +143,7 @@ func (provider *OllamaProvider) ChatCompletion(ctx context.Context, key schemas. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Ollama's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *OllamaProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OllamaProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, @@ -157,12 +159,13 @@ func (provider *OllamaProvider) ChatCompletionStream(ctx context.Context, postHo nil, nil, nil, + nil, provider.logger, ) } // Responses performs a responses request to the Ollama API. -func (provider *OllamaProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { return nil, err @@ -177,8 +180,8 @@ func (provider *OllamaProvider) Responses(ctx context.Context, key schemas.Key, } // ResponsesStream performs a streaming responses request to the Ollama API. -func (provider *OllamaProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) +func (provider *OllamaProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, @@ -188,7 +191,7 @@ func (provider *OllamaProvider) ResponsesStream(ctx context.Context, postHookRun } // Embedding performs an embedding request to the Ollama API. -func (provider *OllamaProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { return openai.HandleOpenAIEmbeddingRequest( ctx, provider.client, @@ -204,75 +207,75 @@ func (provider *OllamaProvider) Embedding(ctx context.Context, key schemas.Key, } // Speech is not supported by the Ollama provider. -func (provider *OllamaProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) } // SpeechStream is not supported by the Ollama provider. -func (provider *OllamaProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OllamaProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } // Transcription is not supported by the Ollama provider. -func (provider *OllamaProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey()) } // TranscriptionStream is not supported by the Ollama provider. -func (provider *OllamaProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OllamaProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } // BatchCreate is not supported by Ollama provider. -func (provider *OllamaProvider) BatchCreate(_ context.Context, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) } // BatchList is not supported by Ollama provider. -func (provider *OllamaProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) } // BatchRetrieve is not supported by Ollama provider. -func (provider *OllamaProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) } // BatchCancel is not supported by Ollama provider. -func (provider *OllamaProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) } // BatchResults is not supported by Ollama provider. -func (provider *OllamaProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) } // FileUpload is not supported by Ollama provider. -func (provider *OllamaProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) } // FileList is not supported by Ollama provider. -func (provider *OllamaProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) } // FileRetrieve is not supported by Ollama provider. -func (provider *OllamaProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) } // FileDelete is not supported by Ollama provider. -func (provider *OllamaProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) } // FileContent is not supported by Ollama provider. -func (provider *OllamaProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) } -func (provider *OllamaProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) } diff --git a/core/providers/openai/chat.go b/core/providers/openai/chat.go index 291e5fc9a7..99a6d66112 100644 --- a/core/providers/openai/chat.go +++ b/core/providers/openai/chat.go @@ -1,6 +1,8 @@ package openai import ( + "strings" + "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -41,6 +43,10 @@ func ToOpenAIChatRequest(bifrostReq *schemas.BifrostChatRequest) *OpenAIChatRequ switch bifrostReq.Provider { case schemas.OpenAI: return openaiReq + case schemas.XAI: + openaiReq.filterOpenAISpecificParameters() + openaiReq.applyXAICompatibility(bifrostReq.Model) + return openaiReq case schemas.Gemini: openaiReq.filterOpenAISpecificParameters() // Removing extra parameters that are not supported by Gemini @@ -117,3 +123,27 @@ func (request *OpenAIChatRequest) applyMistralCompatibility() { request.ToolChoice.ChatToolChoiceStruct = nil } } + +// applyXAICompatibility applies xAI-specific transformations to the request +func (request *OpenAIChatRequest) applyXAICompatibility(model string) { + // Only apply filters if this is a grok reasoning model + if !schemas.IsGrokReasoningModel(model) { + return + } + + request.ChatParameters.PresencePenalty = nil + + // Only non-mini grok-3 models support frequency_penalty and stop + // grok-3-mini only supports reasoning_effort in reasoning mode + if !strings.Contains(model, "grok-3") || strings.Contains(model, "grok-3-mini") { + request.ChatParameters.FrequencyPenalty = nil + request.ChatParameters.Stop = nil + } + + // Only grok-3-mini supports reasoning_effort + if request.ChatParameters.Reasoning != nil && + !strings.Contains(model, "grok-3-mini") { + // Clear reasoning_effort for non-grok-3-mini models + request.ChatParameters.Reasoning.Effort = nil + } +} diff --git a/core/providers/openai/chat_test.go b/core/providers/openai/chat_test.go new file mode 100644 index 0000000000..8fe4ef3565 --- /dev/null +++ b/core/providers/openai/chat_test.go @@ -0,0 +1,317 @@ +package openai + +import ( + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestApplyXAICompatibility(t *testing.T) { + tests := []struct { + name string + model string + request *OpenAIChatRequest + validate func(t *testing.T, req *OpenAIChatRequest) + }{ + { + name: "grok-3: preserves frequency_penalty and stop, clears presence_penalty and reasoning_effort", + model: "grok-3", + request: &OpenAIChatRequest{ + Model: "grok-3", + Messages: []OpenAIMessage{}, + ChatParameters: schemas.ChatParameters{ + FrequencyPenalty: schemas.Ptr(0.5), + PresencePenalty: schemas.Ptr(0.3), + Stop: []string{"STOP"}, + Reasoning: &schemas.ChatReasoning{ + Effort: schemas.Ptr("high"), + }, + }, + }, + validate: func(t *testing.T, req *OpenAIChatRequest) { + // frequency_penalty should be preserved + if req.FrequencyPenalty == nil || *req.FrequencyPenalty != 0.5 { + t.Errorf("Expected FrequencyPenalty to be preserved at 0.5, got %v", req.FrequencyPenalty) + } + + // stop should be preserved + if len(req.Stop) != 1 || req.Stop[0] != "STOP" { + t.Errorf("Expected Stop to be preserved as ['STOP'], got %v", req.Stop) + } + + // presence_penalty should be cleared + if req.PresencePenalty != nil { + t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty) + } + + // reasoning_effort should be cleared for non-mini grok-3 + if req.Reasoning == nil { + t.Fatal("Expected Reasoning to remain non-nil") + } + if req.Reasoning.Effort != nil { + t.Errorf("Expected Reasoning.Effort to be cleared (nil) for grok-3, got %v", *req.Reasoning.Effort) + } + }, + }, + { + name: "grok-3-mini: clears all penalties and stop, preserves reasoning_effort", + model: "grok-3-mini", + request: &OpenAIChatRequest{ + Model: "grok-3-mini", + Messages: []OpenAIMessage{}, + ChatParameters: schemas.ChatParameters{ + FrequencyPenalty: schemas.Ptr(0.5), + PresencePenalty: schemas.Ptr(0.3), + Stop: []string{"STOP"}, + Reasoning: &schemas.ChatReasoning{ + Effort: schemas.Ptr("medium"), + }, + }, + }, + validate: func(t *testing.T, req *OpenAIChatRequest) { + // presence_penalty should be cleared + if req.PresencePenalty != nil { + t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty) + } + + // frequency_penalty should be cleared for grok-3-mini + if req.FrequencyPenalty != nil { + t.Errorf("Expected FrequencyPenalty to be cleared (nil) for grok-3-mini, got %v", *req.FrequencyPenalty) + } + + // stop should be cleared for grok-3-mini + if req.Stop != nil { + t.Errorf("Expected Stop to be cleared (nil) for grok-3-mini, got %v", req.Stop) + } + + // reasoning_effort should be preserved for grok-3-mini + if req.Reasoning == nil || req.Reasoning.Effort == nil { + t.Fatal("Expected Reasoning.Effort to be preserved for grok-3-mini") + } + if *req.Reasoning.Effort != "medium" { + t.Errorf("Expected Reasoning.Effort to be 'medium', got %v", *req.Reasoning.Effort) + } + }, + }, + { + name: "grok-4: clears all penalties, stop, and reasoning_effort", + model: "grok-4", + request: &OpenAIChatRequest{ + Model: "grok-4", + Messages: []OpenAIMessage{}, + ChatParameters: schemas.ChatParameters{ + FrequencyPenalty: schemas.Ptr(0.5), + PresencePenalty: schemas.Ptr(0.3), + Stop: []string{"STOP"}, + Reasoning: &schemas.ChatReasoning{ + Effort: schemas.Ptr("high"), + }, + }, + }, + validate: func(t *testing.T, req *OpenAIChatRequest) { + // presence_penalty should be cleared + if req.PresencePenalty != nil { + t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty) + } + + // frequency_penalty should be cleared for grok-4 + if req.FrequencyPenalty != nil { + t.Errorf("Expected FrequencyPenalty to be cleared (nil) for grok-4, got %v", *req.FrequencyPenalty) + } + + // stop should be cleared for grok-4 + if req.Stop != nil { + t.Errorf("Expected Stop to be cleared (nil) for grok-4, got %v", req.Stop) + } + + // reasoning_effort should be cleared for grok-4 + if req.Reasoning == nil { + t.Fatal("Expected Reasoning to remain non-nil") + } + if req.Reasoning.Effort != nil { + t.Errorf("Expected Reasoning.Effort to be cleared (nil) for grok-4, got %v", *req.Reasoning.Effort) + } + }, + }, + { + name: "grok-4-fast-reasoning: clears all penalties, stop, and reasoning_effort", + model: "grok-4-fast-reasoning", + request: &OpenAIChatRequest{ + Model: "grok-4-fast-reasoning", + Messages: []OpenAIMessage{}, + ChatParameters: schemas.ChatParameters{ + FrequencyPenalty: schemas.Ptr(0.5), + PresencePenalty: schemas.Ptr(0.3), + Stop: []string{"STOP", "END"}, + Reasoning: &schemas.ChatReasoning{ + Effort: schemas.Ptr("high"), + }, + }, + }, + validate: func(t *testing.T, req *OpenAIChatRequest) { + // presence_penalty should be cleared + if req.PresencePenalty != nil { + t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty) + } + + // frequency_penalty should be cleared + if req.FrequencyPenalty != nil { + t.Errorf("Expected FrequencyPenalty to be cleared (nil), got %v", *req.FrequencyPenalty) + } + + // stop should be cleared + if req.Stop != nil { + t.Errorf("Expected Stop to be cleared (nil), got %v", req.Stop) + } + + // reasoning_effort should be cleared + if req.Reasoning == nil { + t.Fatal("Expected Reasoning to remain non-nil") + } + if req.Reasoning.Effort != nil { + t.Errorf("Expected Reasoning.Effort to be cleared (nil), got %v", *req.Reasoning.Effort) + } + }, + }, + { + name: "grok-code-fast-1: clears all penalties, stop, and reasoning_effort", + model: "grok-code-fast-1", + request: &OpenAIChatRequest{ + Model: "grok-code-fast-1", + Messages: []OpenAIMessage{}, + ChatParameters: schemas.ChatParameters{ + FrequencyPenalty: schemas.Ptr(0.2), + PresencePenalty: schemas.Ptr(0.1), + Stop: []string{"END"}, + Reasoning: &schemas.ChatReasoning{ + Effort: schemas.Ptr("low"), + }, + }, + }, + validate: func(t *testing.T, req *OpenAIChatRequest) { + // presence_penalty should be cleared + if req.PresencePenalty != nil { + t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty) + } + + // frequency_penalty should be cleared + if req.FrequencyPenalty != nil { + t.Errorf("Expected FrequencyPenalty to be cleared (nil), got %v", *req.FrequencyPenalty) + } + + // stop should be cleared + if req.Stop != nil { + t.Errorf("Expected Stop to be cleared (nil), got %v", req.Stop) + } + + // reasoning_effort should be cleared + if req.Reasoning == nil { + t.Fatal("Expected Reasoning to remain non-nil") + } + if req.Reasoning.Effort != nil { + t.Errorf("Expected Reasoning.Effort to be cleared (nil), got %v", *req.Reasoning.Effort) + } + }, + }, + { + name: "non-reasoning grok model: no changes applied", + model: "grok-2-latest", + request: &OpenAIChatRequest{ + Model: "grok-2-latest", + Messages: []OpenAIMessage{}, + ChatParameters: schemas.ChatParameters{ + FrequencyPenalty: schemas.Ptr(0.5), + PresencePenalty: schemas.Ptr(0.3), + Stop: []string{"STOP"}, + Reasoning: &schemas.ChatReasoning{ + Effort: schemas.Ptr("high"), + }, + }, + }, + validate: func(t *testing.T, req *OpenAIChatRequest) { + // All parameters should be preserved for non-reasoning models + if req.FrequencyPenalty == nil || *req.FrequencyPenalty != 0.5 { + t.Errorf("Expected FrequencyPenalty to be preserved at 0.5, got %v", req.FrequencyPenalty) + } + + if req.PresencePenalty == nil || *req.PresencePenalty != 0.3 { + t.Errorf("Expected PresencePenalty to be preserved at 0.3, got %v", req.PresencePenalty) + } + + if len(req.Stop) != 1 || req.Stop[0] != "STOP" { + t.Errorf("Expected Stop to be preserved as ['STOP'], got %v", req.Stop) + } + + if req.Reasoning == nil || req.Reasoning.Effort == nil { + t.Fatal("Expected Reasoning.Effort to be preserved") + } + if *req.Reasoning.Effort != "high" { + t.Errorf("Expected Reasoning.Effort to be 'high', got %v", *req.Reasoning.Effort) + } + }, + }, + { + name: "grok-3: handles nil reasoning gracefully", + model: "grok-3", + request: &OpenAIChatRequest{ + Model: "grok-3", + Messages: []OpenAIMessage{}, + ChatParameters: schemas.ChatParameters{ + FrequencyPenalty: schemas.Ptr(0.5), + PresencePenalty: schemas.Ptr(0.3), + Stop: []string{"STOP"}, + Reasoning: nil, + }, + }, + validate: func(t *testing.T, req *OpenAIChatRequest) { + // Should handle nil reasoning without panicking + if req.Reasoning != nil { + t.Errorf("Expected Reasoning to remain nil, got %v", req.Reasoning) + } + + // Other parameters should still be processed + if req.PresencePenalty != nil { + t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty) + } + + if req.FrequencyPenalty == nil || *req.FrequencyPenalty != 0.5 { + t.Errorf("Expected FrequencyPenalty to be preserved at 0.5, got %v", req.FrequencyPenalty) + } + }, + }, + { + name: "grok-3: preserves other parameters like temperature", + model: "grok-3", + request: &OpenAIChatRequest{ + Model: "grok-3", + Messages: []OpenAIMessage{}, + ChatParameters: schemas.ChatParameters{ + Temperature: schemas.Ptr(0.8), + TopP: schemas.Ptr(0.9), + FrequencyPenalty: schemas.Ptr(0.5), + PresencePenalty: schemas.Ptr(0.3), + }, + }, + validate: func(t *testing.T, req *OpenAIChatRequest) { + // Unrelated parameters should be preserved + if req.Temperature == nil || *req.Temperature != 0.8 { + t.Errorf("Expected Temperature to be preserved at 0.8, got %v", req.Temperature) + } + + if req.TopP == nil || *req.TopP != 0.9 { + t.Errorf("Expected TopP to be preserved at 0.9, got %v", req.TopP) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Apply the compatibility function + tt.request.applyXAICompatibility(tt.model) + + // Validate the results + tt.validate(t, tt.request) + }) + } +} diff --git a/core/providers/openai/errors.go b/core/providers/openai/errors.go index ab4e813da4..c5efbe328c 100644 --- a/core/providers/openai/errors.go +++ b/core/providers/openai/errors.go @@ -6,6 +6,9 @@ import ( "github.com/valyala/fasthttp" ) +// ErrorConverter is a function that converts provider-specific error responses to BifrostError. +type ErrorConverter func(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError + // ParseOpenAIError parses OpenAI error responses. func ParseOpenAIError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { var errorResp schemas.BifrostError diff --git a/core/providers/openai/models.go b/core/providers/openai/models.go index cd76eaab54..ff2f112e2c 100644 --- a/core/providers/openai/models.go +++ b/core/providers/openai/models.go @@ -6,6 +6,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) +// ToBifrostListModelsResponse converts an OpenAI list models response to a Bifrost list models response func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string) *schemas.BifrostListModelsResponse { if response == nil { return nil @@ -31,16 +32,14 @@ func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKe return bifrostResponse } +// ToOpenAIListModelsResponse converts a Bifrost list models response to an OpenAI list models response func ToOpenAIListModelsResponse(response *schemas.BifrostListModelsResponse) *OpenAIListModelsResponse { - if response == nil { return nil } - openaiResponse := &OpenAIListModelsResponse{ Data: make([]OpenAIModel, 0, len(response.Data)), } - for _, model := range response.Data { openaiModel := OpenAIModel{ ID: model.ID, @@ -56,6 +55,5 @@ func ToOpenAIListModelsResponse(response *schemas.BifrostListModelsResponse) *Op openaiResponse.Data = append(openaiResponse.Data, openaiModel) } - return openaiResponse } diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index 86947f12a9..f4b2e61777 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -75,11 +75,11 @@ func (provider *OpenAIProvider) GetProviderKey() schemas.ModelProvider { } // buildRequestURL constructs the full request URL using the provider's configuration. -func (provider *OpenAIProvider) buildRequestURL(ctx context.Context, defaultPath string, requestType schemas.RequestType) string { +func (provider *OpenAIProvider) buildRequestURL(ctx *schemas.BifrostContext, defaultPath string, requestType schemas.RequestType) string { return provider.networkConfig.BaseURL + providerUtils.GetRequestPath(ctx, defaultPath, provider.customProviderConfig, requestType) } -func (provider *OpenAIProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ListModelsRequest); err != nil { return nil, err } @@ -115,7 +115,7 @@ func (provider *OpenAIProvider) ListModels(ctx context.Context, keys []schemas.K // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. func listModelsByKey( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, key schemas.Key, @@ -185,7 +185,7 @@ func listModelsByKey( // HandleOpenAIListModelsRequest handles a list models request to OpenAI's API. func HandleOpenAIListModelsRequest( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, request *schemas.BifrostListModelsRequest, url string, @@ -199,7 +199,7 @@ func HandleOpenAIListModelsRequest( if len(keys) == 0 { return listModelsByKey(ctx, client, url, schemas.Key{}, extraHeaders, providerName, sendBackRawRequest, sendBackRawResponse) } - listModelsByKeyWrapper := func(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + listModelsByKeyWrapper := func(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { return listModelsByKey(ctx, client, url, key, extraHeaders, providerName, sendBackRawRequest, sendBackRawResponse) } return providerUtils.HandleMultipleListModelsRequests( @@ -213,7 +213,7 @@ func HandleOpenAIListModelsRequest( // TextCompletion is not supported by the OpenAI provider. // Returns an error indicating that text completion is not available. -func (provider *OpenAIProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.TextCompletionRequest); err != nil { return nil, err } @@ -227,13 +227,14 @@ func (provider *OpenAIProvider) TextCompletion(ctx context.Context, key schemas. provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), + nil, provider.logger, ) } // HandleOpenAITextCompletionRequest handles a text completion request to OpenAI's API. func HandleOpenAITextCompletionRequest( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, request *schemas.BifrostTextCompletionRequest, @@ -242,6 +243,7 @@ func HandleOpenAITextCompletionRequest( providerName schemas.ModelProvider, sendBackRawRequest bool, sendBackRawResponse bool, + customErrorConverter ErrorConverter, logger schemas.Logger, ) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { // Create request @@ -280,6 +282,9 @@ func HandleOpenAITextCompletionRequest( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { + if customErrorConverter != nil { + return nil, customErrorConverter(resp, schemas.TextCompletionRequest, providerName, request.Model) + } return nil, ParseOpenAIError(resp, schemas.TextCompletionRequest, providerName, request.Model) } @@ -316,7 +321,7 @@ func HandleOpenAITextCompletionRequest( // TextCompletionStream performs a streaming text completion request to OpenAI's API. // It formats the request, sends it to OpenAI, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *OpenAIProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OpenAIProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.TextCompletionStreamRequest); err != nil { return nil, err } @@ -334,6 +339,7 @@ func (provider *OpenAIProvider) TextCompletionStream(ctx context.Context, postHo providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, postHookRunner, nil, provider.logger, @@ -343,7 +349,7 @@ func (provider *OpenAIProvider) TextCompletionStream(ctx context.Context, postHo // HandleOpenAITextCompletionStreaming handles text completion streaming for OpenAI-compatible APIs. // This shared function reduces code duplication between providers that use the same SSE format. func HandleOpenAITextCompletionStreaming( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, request *schemas.BifrostTextCompletionRequest, @@ -352,6 +358,7 @@ func HandleOpenAITextCompletionStreaming( sendBackRawRequest bool, sendBackRawResponse bool, providerName schemas.ModelProvider, + customErrorConverter ErrorConverter, postHookRunner schemas.PostHookRunner, postResponseConverter func(*schemas.BifrostTextCompletionResponse) *schemas.BifrostTextCompletionResponse, logger schemas.Logger, @@ -428,6 +435,9 @@ func HandleOpenAITextCompletionStreaming( // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) + if customErrorConverter != nil { + return nil, customErrorConverter(resp, schemas.TextCompletionStreamRequest, providerName, request.Model) + } return nil, ParseOpenAIError(resp, schemas.TextCompletionStreamRequest, providerName, request.Model) } @@ -436,8 +446,18 @@ func HandleOpenAITextCompletionStreaming( // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) @@ -452,13 +472,10 @@ func HandleOpenAITextCompletionStreaming( lastChunkTime := startTime for scanner.Scan() { - // Check if context is done before processing - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } - line := scanner.Text() // Skip empty lines and comments @@ -495,7 +512,7 @@ func HandleOpenAITextCompletionStreaming( ModelRequested: request.Model, RequestType: schemas.TextCompletionStreamRequest, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, logger) return } @@ -587,21 +604,31 @@ func HandleOpenAITextCompletionStreaming( // Handle scanner errors first if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, logger) - } else { - response := providerUtils.CreateBifrostTextCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.TextCompletionStreamRequest, providerName, request.Model) - if postResponseConverter != nil { - response = postResponseConverter(response) - } - // Set raw request if enabled - if sendBackRawRequest { - providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) + return + } + + response := providerUtils.CreateBifrostTextCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.TextCompletionStreamRequest, providerName, request.Model) + if postResponseConverter != nil { + response = postResponseConverter(response) + if response == nil { + logger.Warn("postResponseConverter returned nil; leaving chunk unmodified") + return } - response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(response, nil, nil, nil, nil), responseChan) } + // Set raw request if enabled + if sendBackRawRequest { + providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) + } + response.ExtraFields.Latency = time.Since(startTime).Milliseconds() + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(response, nil, nil, nil, nil), responseChan) }() return responseChan, nil @@ -610,7 +637,7 @@ func HandleOpenAITextCompletionStreaming( // ChatCompletion performs a chat completion request to the OpenAI API. // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { // Check if chat completion is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { return nil, err @@ -626,13 +653,14 @@ func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, key schemas. providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, provider.logger, ) } // HandleOpenAIChatCompletionRequest handles a chat completion request to OpenAI's API. func HandleOpenAIChatCompletionRequest( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, request *schemas.BifrostChatRequest, @@ -641,6 +669,7 @@ func HandleOpenAIChatCompletionRequest( sendBackRawRequest bool, sendBackRawResponse bool, providerName schemas.ModelProvider, + customErrorConverter ErrorConverter, logger schemas.Logger, ) (*schemas.BifrostChatResponse, *schemas.BifrostError) { // Create request @@ -680,6 +709,9 @@ func HandleOpenAIChatCompletionRequest( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + if customErrorConverter != nil { + return nil, customErrorConverter(resp, schemas.ChatCompletionRequest, providerName, request.Model) + } return nil, ParseOpenAIError(resp, schemas.ChatCompletionRequest, providerName, request.Model) } @@ -717,7 +749,7 @@ func HandleOpenAIChatCompletionRequest( // ChatCompletionStream handles streaming for OpenAI chat completions. // It formats messages, prepares request body, and uses shared streaming logic. // Returns a channel for streaming responses and any error that occurred. -func (provider *OpenAIProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OpenAIProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if chat completion stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err @@ -741,6 +773,7 @@ func (provider *OpenAIProvider) ChatCompletionStream(ctx context.Context, postHo nil, nil, nil, + nil, provider.logger, ) } @@ -748,7 +781,7 @@ func (provider *OpenAIProvider) ChatCompletionStream(ctx context.Context, postHo // HandleOpenAIChatCompletionStreaming handles streaming for OpenAI-compatible APIs. // This shared function reduces code duplication between providers that use the same SSE format. func HandleOpenAIChatCompletionStreaming( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, request *schemas.BifrostChatRequest, @@ -759,6 +792,7 @@ func HandleOpenAIChatCompletionStreaming( providerName schemas.ModelProvider, postHookRunner schemas.PostHookRunner, customRequestConverter func(*schemas.BifrostChatRequest) (any, error), + customErrorConverter ErrorConverter, postRequestConverter func(*OpenAIChatRequest) *OpenAIChatRequest, postResponseConverter func(*schemas.BifrostChatResponse) *schemas.BifrostChatResponse, logger schemas.Logger, @@ -854,16 +888,35 @@ func HandleOpenAIChatCompletionStreaming( // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) + if customErrorConverter != nil { + return nil, customErrorConverter(resp, schemas.ChatCompletionStreamRequest, providerName, request.Model) + } return nil, ParseOpenAIError(resp, schemas.ChatCompletionStreamRequest, providerName, request.Model) } // Create response channel responseChan := make(chan *schemas.BifrostStream, schemas.DefaultStreamBufferSize) + // Determine request type for cleanup + streamRequestType := schemas.ChatCompletionStreamRequest + if isResponsesToChatCompletionsFallback { + streamRequestType = schemas.ResponsesStreamRequest + } + // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, streamRequestType, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, streamRequestType, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) @@ -879,13 +932,10 @@ func HandleOpenAIChatCompletionStreaming( var messageID string for scanner.Scan() { - // Check if context is done before processing - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } - line := scanner.Text() // Skip empty lines and comments @@ -920,9 +970,9 @@ func HandleOpenAIChatCompletionStreaming( bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ Provider: providerName, ModelRequested: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, + RequestType: streamRequestType, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, logger) return } @@ -944,7 +994,7 @@ func HandleOpenAIChatCompletionStreaming( IsBifrostError: false, Error: &schemas.ErrorField{}, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, + RequestType: streamRequestType, Provider: providerName, ModelRequested: request.Model, }, @@ -960,12 +1010,12 @@ func HandleOpenAIChatCompletionStreaming( bifrostErr.Error.Code = response.Code } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) return } - response.ExtraFields.RequestType = schemas.ResponsesStreamRequest + response.ExtraFields.RequestType = streamRequestType response.ExtraFields.Provider = providerName response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = response.SequenceNumber @@ -980,7 +1030,7 @@ func HandleOpenAIChatCompletionStreaming( providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil), responseChan) return } @@ -1077,10 +1127,18 @@ func HandleOpenAIChatCompletionStreaming( // Handle scanner errors first if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, logger) - } else if !isResponsesToChatCompletionsFallback { - response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, request.Model) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, streamRequestType, providerName, request.Model, logger) + return + } + + if !isResponsesToChatCompletionsFallback { + response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, streamRequestType, providerName, request.Model) if postResponseConverter != nil { response = postResponseConverter(response) } @@ -1089,7 +1147,7 @@ func HandleOpenAIChatCompletionStreaming( providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil), responseChan) } }() @@ -1098,7 +1156,7 @@ func HandleOpenAIChatCompletionStreaming( } // Responses performs a responses request to the OpenAI API. -func (provider *OpenAIProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { // Check if chat completion is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { return nil, err @@ -1114,6 +1172,7 @@ func (provider *OpenAIProvider) Responses(ctx context.Context, key schemas.Key, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, provider.logger, ) } @@ -1129,6 +1188,7 @@ func HandleOpenAIResponsesRequest( sendBackRawRequest bool, sendBackRawResponse bool, providerName schemas.ModelProvider, + customErrorConverter ErrorConverter, logger schemas.Logger, ) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { // Create request @@ -1169,6 +1229,9 @@ func HandleOpenAIResponsesRequest( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + if customErrorConverter != nil { + return nil, customErrorConverter(resp, schemas.ResponsesRequest, providerName, request.Model) + } return nil, ParseOpenAIError(resp, schemas.ResponsesRequest, providerName, request.Model) } @@ -1204,7 +1267,7 @@ func HandleOpenAIResponsesRequest( } // ResponsesStream performs a streaming responses request to the OpenAI API. -func (provider *OpenAIProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OpenAIProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if chat completion stream is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { return nil, err @@ -1227,6 +1290,7 @@ func (provider *OpenAIProvider) ResponsesStream(ctx context.Context, postHookRun postHookRunner, nil, nil, + nil, provider.logger, ) } @@ -1234,7 +1298,7 @@ func (provider *OpenAIProvider) ResponsesStream(ctx context.Context, postHookRun // HandleOpenAIResponsesStreaming handles streaming for OpenAI-compatible APIs. // This shared function reduces code duplication between providers that use the same SSE format. func HandleOpenAIResponsesStreaming( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, request *schemas.BifrostResponsesRequest, @@ -1244,6 +1308,7 @@ func HandleOpenAIResponsesStreaming( sendBackRawResponse bool, providerName schemas.ModelProvider, postHookRunner schemas.PostHookRunner, + customErrorConverter ErrorConverter, postRequestConverter func(*OpenAIResponsesRequest) *OpenAIResponsesRequest, postResponseConverter func(*schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse, logger schemas.Logger, @@ -1321,6 +1386,9 @@ func HandleOpenAIResponsesStreaming( // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) + if customErrorConverter != nil { + return nil, customErrorConverter(resp, schemas.ResponsesStreamRequest, providerName, request.Model) + } return nil, ParseOpenAIError(resp, schemas.ResponsesStreamRequest, providerName, request.Model) } @@ -1329,8 +1397,18 @@ func HandleOpenAIResponsesStreaming( // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) buf := make([]byte, 0, 1024*1024) @@ -1340,13 +1418,10 @@ func HandleOpenAIResponsesStreaming( lastChunkTime := startTime for scanner.Scan() { - // Check if context is done before processing - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } - line := scanner.Text() // Skip empty lines, comments, and event lines @@ -1418,7 +1493,7 @@ func HandleOpenAIResponsesStreaming( bifrostErr.Error.Code = response.Code } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) return } @@ -1434,7 +1509,7 @@ func HandleOpenAIResponsesStreaming( providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, &response, nil, nil), responseChan) return } @@ -1446,6 +1521,11 @@ func HandleOpenAIResponsesStreaming( } // Handle scanner errors first if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, logger) } @@ -1457,7 +1537,7 @@ func HandleOpenAIResponsesStreaming( // Embedding generates embeddings for the given input text(s). // The input can be either a single string or a slice of strings for batch embedding. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. -func (provider *OpenAIProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { // Check if embedding is allowed for this provider if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.EmbeddingRequest); err != nil { return nil, err @@ -1481,7 +1561,7 @@ func (provider *OpenAIProvider) Embedding(ctx context.Context, key schemas.Key, // HandleOpenAIEmbeddingRequest handles embedding requests for OpenAI-compatible APIs. // This shared function reduces code duplication between providers that use the same embedding request format. func HandleOpenAIEmbeddingRequest( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, request *schemas.BifrostEmbeddingRequest, @@ -1567,7 +1647,7 @@ func HandleOpenAIEmbeddingRequest( // Speech handles non-streaming speech synthesis requests. // It formats the request body, makes the API call, and returns the response. // Returns the response and any error that occurred. -func (provider *OpenAIProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.SpeechRequest); err != nil { return nil, err } @@ -1588,7 +1668,7 @@ func (provider *OpenAIProvider) Speech(ctx context.Context, key schemas.Key, req // HandleOpenAISpeechRequest handles speech requests for OpenAI-compatible APIs. // This shared function reduces code duplication between providers that use the same speech request format. func HandleOpenAISpeechRequest( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, request *schemas.BifrostSpeechRequest, @@ -1666,7 +1746,7 @@ func HandleOpenAISpeechRequest( // SpeechStream handles streaming for speech synthesis. // It formats the request body, creates HTTP request, and uses shared streaming logic. // Returns a channel for streaming responses and any error that occurred. -func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OpenAIProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.SpeechStreamRequest); err != nil { return nil, err } @@ -1702,7 +1782,7 @@ func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner // HandleOpenAISpeechStreamRequest handles speech stream requests for OpenAI-compatible APIs. // This shared function reduces code duplication between providers that use the same speech stream request format. func HandleOpenAISpeechStreamRequest( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, request *schemas.BifrostSpeechRequest, @@ -1797,8 +1877,18 @@ func HandleOpenAISpeechStreamRequest( // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) chunkIndex := -1 @@ -1807,11 +1897,9 @@ func HandleOpenAISpeechStreamRequest( lastChunkTime := startTime for scanner.Scan() { - // Check if context is done before processing - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } line := scanner.Text() @@ -1850,7 +1938,7 @@ func HandleOpenAISpeechStreamRequest( ModelRequested: request.Model, RequestType: schemas.SpeechStreamRequest, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, logger) return } @@ -1891,7 +1979,7 @@ func HandleOpenAISpeechStreamRequest( if sendBackRawRequest { providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody) } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil), responseChan) return } @@ -1901,6 +1989,11 @@ func HandleOpenAISpeechStreamRequest( // Handle scanner errors if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, logger) } @@ -1912,7 +2005,7 @@ func HandleOpenAISpeechStreamRequest( // Transcription handles non-streaming transcription requests. // It creates a multipart form, adds fields, makes the API call, and returns the response. // Returns the response and any error that occurred. -func (provider *OpenAIProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.TranscriptionRequest); err != nil { return nil, err } @@ -1931,7 +2024,7 @@ func (provider *OpenAIProvider) Transcription(ctx context.Context, key schemas.K } func HandleOpenAITranscriptionRequest( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, request *schemas.BifrostTranscriptionRequest, @@ -2019,6 +2112,8 @@ func HandleOpenAITranscriptionRequest( return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) } + //TODO: add HandleProviderResponse here + // Parse raw response for RawResponse field var rawResponse interface{} if sendBackRawResponse { @@ -2042,7 +2137,7 @@ func HandleOpenAITranscriptionRequest( } // TranscriptionStream performs a streaming transcription request to the OpenAI API. -func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OpenAIProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.TranscriptionStreamRequest); err != nil { return nil, err } @@ -2071,7 +2166,7 @@ func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHoo // HandleOpenAITranscriptionStreamRequest handles transcription stream requests for OpenAI-compatible APIs. // This shared function reduces code duplication between providers that use the same transcription stream request format. func HandleOpenAITranscriptionStreamRequest( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, request *schemas.BifrostTranscriptionRequest, @@ -2164,8 +2259,18 @@ func HandleOpenAITranscriptionStreamRequest( // Start streaming in a goroutine go func() { - defer close(responseChan) + defer func() { + if ctx.Err() == context.Canceled { + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + } else if ctx.Err() == context.DeadlineExceeded { + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + } + close(responseChan) + }() defer providerUtils.ReleaseStreamingResponse(resp) + // Setup cancellation handler to close body stream on ctx cancellation + stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger) + defer stopCancellation() scanner := bufio.NewScanner(resp.BodyStream()) chunkIndex := -1 @@ -2174,13 +2279,11 @@ func HandleOpenAITranscriptionStreamRequest( lastChunkTime := startTime for scanner.Scan() { - // Check if context is done before processing - select { - case <-ctx.Done(): + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { return - default: } - + line := scanner.Text() // Skip empty lines and comments @@ -2216,7 +2319,7 @@ func HandleOpenAITranscriptionStreamRequest( ModelRequested: request.Model, RequestType: schemas.TranscriptionStreamRequest, } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, logger) return } @@ -2253,7 +2356,7 @@ func HandleOpenAITranscriptionStreamRequest( if response.Usage != nil { response.ExtraFields.Latency = time.Since(startTime).Milliseconds() - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, &response), responseChan) return } @@ -2263,6 +2366,11 @@ func HandleOpenAITranscriptionStreamRequest( // Handle scanner errors if err := scanner.Err(); err != nil { + // If context was cancelled/timed out, let defer handle it + if ctx.Err() != nil { + return + } + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, logger) } @@ -2272,7 +2380,7 @@ func HandleOpenAITranscriptionStreamRequest( } // CountTokens performs a count tokens request to the OpenAI API. -func (provider *OpenAIProvider) CountTokens(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.CountTokensRequest); err != nil { return nil, err } @@ -2293,7 +2401,7 @@ func (provider *OpenAIProvider) CountTokens(ctx context.Context, key schemas.Key // HandleOpenAICountTokensRequest handles a count tokens request to OpenAI's API. func HandleOpenAICountTokensRequest( - ctx context.Context, + ctx *schemas.BifrostContext, client *fasthttp.Client, url string, request *schemas.BifrostResponsesRequest, @@ -2376,7 +2484,7 @@ func HandleOpenAICountTokensRequest( } // FileUpload uploads a file to OpenAI. -func (provider *OpenAIProvider) FileUpload(ctx context.Context, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.FileUploadRequest); err != nil { return nil, err } @@ -2465,7 +2573,7 @@ func (provider *OpenAIProvider) FileUpload(ctx context.Context, key schemas.Key, // FileList lists files using serial pagination across keys. // Exhausts all pages from one key before moving to the next. -func (provider *OpenAIProvider) FileList(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.FileListRequest); err != nil { return nil, err } @@ -2594,7 +2702,7 @@ func (provider *OpenAIProvider) FileList(ctx context.Context, keys []schemas.Key } // FileRetrieve retrieves file metadata from OpenAI by trying each key until found. -func (provider *OpenAIProvider) FileRetrieve(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.FileRetrieveRequest); err != nil { return nil, err } @@ -2669,7 +2777,7 @@ func (provider *OpenAIProvider) FileRetrieve(ctx context.Context, keys []schemas } // FileDelete deletes a file from OpenAI by trying each key until successful. -func (provider *OpenAIProvider) FileDelete(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.FileDeleteRequest); err != nil { return nil, err } @@ -2763,7 +2871,7 @@ func (provider *OpenAIProvider) FileDelete(ctx context.Context, keys []schemas.K } // FileContent downloads file content from OpenAI by trying each key until found. -func (provider *OpenAIProvider) FileContent(ctx context.Context, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.FileContentRequest); err != nil { return nil, err } @@ -2841,7 +2949,7 @@ func (provider *OpenAIProvider) FileContent(ctx context.Context, keys []schemas. } // BatchCreate creates a new batch job. -func (provider *OpenAIProvider) BatchCreate(ctx context.Context, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.BatchCreateRequest); err != nil { return nil, err } @@ -2946,7 +3054,7 @@ func (provider *OpenAIProvider) BatchCreate(ctx context.Context, key schemas.Key // BatchList lists batch jobs using serial pagination across keys. // Exhausts all pages from one key before moving to the next. -func (provider *OpenAIProvider) BatchList(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.BatchListRequest); err != nil { return nil, err } @@ -3060,7 +3168,7 @@ func (provider *OpenAIProvider) BatchList(ctx context.Context, keys []schemas.Ke } // BatchRetrieve retrieves a specific batch job by trying each key until found. -func (provider *OpenAIProvider) BatchRetrieve(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.BatchRetrieveRequest); err != nil { return nil, err } @@ -3135,7 +3243,7 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx context.Context, keys []schema } // BatchCancel cancels a batch job by trying each key until successful. -func (provider *OpenAIProvider) BatchCancel(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.BatchCancelRequest); err != nil { return nil, err } @@ -3239,7 +3347,7 @@ func (provider *OpenAIProvider) BatchCancel(ctx context.Context, keys []schemas. // BatchResults retrieves batch results by trying each key until successful. // Note: For OpenAI, batch results are obtained by downloading the output_file_id. // This method returns the file content parsed as batch results. -func (provider *OpenAIProvider) BatchResults(ctx context.Context, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.BatchResultsRequest); err != nil { return nil, err } diff --git a/core/providers/openai/responses.go b/core/providers/openai/responses.go index d1e10cfb55..edcdac3cf7 100644 --- a/core/providers/openai/responses.go +++ b/core/providers/openai/responses.go @@ -152,6 +152,15 @@ func ToOpenAIResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *Open // Clear max_tokens since OpenAI doesn't use it req.ResponsesParameters.Reasoning.MaxTokens = nil } + + // Handle xAI-specific parameter filtering + // Only grok-3-mini supports reasoning_effort + if bifrostReq.Provider == schemas.XAI && + schemas.IsGrokReasoningModel(bifrostReq.Model) && + !strings.Contains(bifrostReq.Model, "grok-3-mini") { + // Clear reasoning_effort for non-grok-3-mini xAI reasoning models + req.ResponsesParameters.Reasoning.Effort = nil + } } // Filter out tools that OpenAI doesn't support diff --git a/core/providers/openai/types.go b/core/providers/openai/types.go index 4b42b33776..916efc1e92 100644 --- a/core/providers/openai/types.go +++ b/core/providers/openai/types.go @@ -525,7 +525,7 @@ func (r *OpenAITranscriptionRequest) IsStreamingRequested() bool { return r.Stream != nil && *r.Stream } -// MODEL TYPES +// OpenAIModel represents an OpenAI model type OpenAIModel struct { ID string `json:"id"` Object string `json:"object"` @@ -537,6 +537,7 @@ type OpenAIModel struct { ContextWindow *int `json:"context_window,omitempty"` } +// OpenAIListModelsResponse represents an OpenAI list models response type OpenAIListModelsResponse struct { Object string `json:"object"` Data []OpenAIModel `json:"data"` diff --git a/core/providers/openrouter/openrouter.go b/core/providers/openrouter/openrouter.go index 4504987f1b..bd0a3dd056 100644 --- a/core/providers/openrouter/openrouter.go +++ b/core/providers/openrouter/openrouter.go @@ -2,7 +2,6 @@ package openrouter import ( - "context" "fmt" "net/http" "strings" @@ -62,7 +61,7 @@ func (provider *OpenRouterProvider) GetProviderKey() schemas.ModelProvider { // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. -func (provider *OpenRouterProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Create request @@ -124,7 +123,7 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx context.Context, key sch // ListModels performs a list models request to OpenRouter's API. // Requests are made concurrently for improved performance. -func (provider *OpenRouterProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { return providerUtils.HandleMultipleListModelsRequests( ctx, keys, @@ -135,7 +134,7 @@ func (provider *OpenRouterProvider) ListModels(ctx context.Context, keys []schem } // TextCompletion performs a text completion request to the OpenRouter API. -func (provider *OpenRouterProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionRequest( ctx, provider.client, @@ -146,6 +145,7 @@ func (provider *OpenRouterProvider) TextCompletion(ctx context.Context, key sche provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), + nil, provider.logger, ) } @@ -153,7 +153,7 @@ func (provider *OpenRouterProvider) TextCompletion(ctx context.Context, key sche // TextCompletionStream performs a streaming text completion request to OpenRouter's API. // It formats the request, sends it to OpenRouter, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *OpenRouterProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OpenRouterProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { var authHeader map[string]string if key.Value != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value} @@ -168,6 +168,7 @@ func (provider *OpenRouterProvider) TextCompletionStream(ctx context.Context, po providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, postHookRunner, nil, provider.logger, @@ -175,7 +176,7 @@ func (provider *OpenRouterProvider) TextCompletionStream(ctx context.Context, po } // ChatCompletion performs a chat completion request to the OpenRouter API. -func (provider *OpenRouterProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, @@ -186,6 +187,7 @@ func (provider *OpenRouterProvider) ChatCompletion(ctx context.Context, key sche providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, provider.logger, ) } @@ -194,7 +196,7 @@ func (provider *OpenRouterProvider) ChatCompletion(ctx context.Context, key sche // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses OpenRouter's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *OpenRouterProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OpenRouterProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { var authHeader map[string]string if key.Value != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value} @@ -214,12 +216,13 @@ func (provider *OpenRouterProvider) ChatCompletionStream(ctx context.Context, po nil, nil, nil, + nil, provider.logger, ) } // Responses performs a responses request to the OpenRouter API. -func (provider *OpenRouterProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { return openai.HandleOpenAIResponsesRequest( ctx, provider.client, @@ -230,12 +233,13 @@ func (provider *OpenRouterProvider) Responses(ctx context.Context, key schemas.K providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, provider.logger, ) } // ResponsesStream performs a streaming responses request to the OpenRouter API. -func (provider *OpenRouterProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OpenRouterProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { var authHeader map[string]string if key.Value != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value} @@ -253,86 +257,87 @@ func (provider *OpenRouterProvider) ResponsesStream(ctx context.Context, postHoo postHookRunner, nil, nil, + nil, provider.logger, ) } // Embedding is not supported by the OpenRouter provider. -func (provider *OpenRouterProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } // Speech is not supported by the OpenRouter provider. -func (provider *OpenRouterProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) } // SpeechStream is not supported by the OpenRouter provider. -func (provider *OpenRouterProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OpenRouterProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } // Transcription is not supported by the OpenRouter provider. -func (provider *OpenRouterProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey()) } // TranscriptionStream is not supported by the OpenRouter provider. -func (provider *OpenRouterProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OpenRouterProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } // BatchCreate is not supported by OpenRouter provider. -func (provider *OpenRouterProvider) BatchCreate(_ context.Context, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) } // BatchList is not supported by OpenRouter provider. -func (provider *OpenRouterProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) } // BatchRetrieve is not supported by OpenRouter provider. -func (provider *OpenRouterProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) } // BatchCancel is not supported by OpenRouter provider. -func (provider *OpenRouterProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) } // BatchResults is not supported by OpenRouter provider. -func (provider *OpenRouterProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) } // FileUpload is not supported by OpenRouter provider. -func (provider *OpenRouterProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) } // FileList is not supported by OpenRouter provider. -func (provider *OpenRouterProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) } // FileRetrieve is not supported by OpenRouter provider. -func (provider *OpenRouterProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) } // FileDelete is not supported by OpenRouter provider. -func (provider *OpenRouterProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) } // FileContent is not supported by OpenRouter provider. -func (provider *OpenRouterProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) } // CountTokens is not supported by the OpenRouter provider. -func (provider *OpenRouterProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) } diff --git a/core/providers/parasail/parasail.go b/core/providers/parasail/parasail.go index 4b51cf06d2..a4c32d81a7 100644 --- a/core/providers/parasail/parasail.go +++ b/core/providers/parasail/parasail.go @@ -3,7 +3,6 @@ package parasail import ( - "context" "strings" "time" @@ -60,7 +59,7 @@ func (provider *ParasailProvider) GetProviderKey() schemas.ModelProvider { } // ListModels performs a list models request to Parasail's API. -func (provider *ParasailProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { return openai.HandleOpenAIListModelsRequest( ctx, provider.client, @@ -76,19 +75,19 @@ func (provider *ParasailProvider) ListModels(ctx context.Context, keys []schemas } // TextCompletion is not supported by the Parasail provider. -func (provider *ParasailProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey()) } // TextCompletionStream performs a streaming text completion request to Parasail's API. // It formats the request, sends it to Parasail, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *ParasailProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *ParasailProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } // ChatCompletion performs a chat completion request to the Parasail API. -func (provider *ParasailProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, @@ -99,6 +98,7 @@ func (provider *ParasailProvider) ChatCompletion(ctx context.Context, key schema providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, provider.logger, ) } @@ -107,7 +107,7 @@ func (provider *ParasailProvider) ChatCompletion(ctx context.Context, key schema // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Parasail's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *ParasailProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *ParasailProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { var authHeader map[string]string if key.Value != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value} @@ -127,12 +127,13 @@ func (provider *ParasailProvider) ChatCompletionStream(ctx context.Context, post nil, nil, nil, + nil, provider.logger, ) } // Responses performs a responses request to the Parasail API. -func (provider *ParasailProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { return nil, err @@ -147,8 +148,8 @@ func (provider *ParasailProvider) Responses(ctx context.Context, key schemas.Key } // ResponsesStream performs a streaming responses request to the Parasail API. -func (provider *ParasailProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) +func (provider *ParasailProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, @@ -158,81 +159,81 @@ func (provider *ParasailProvider) ResponsesStream(ctx context.Context, postHookR } // Embedding is not supported by the Parasail provider. -func (provider *ParasailProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } // Speech is not supported by the Parasail provider. -func (provider *ParasailProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) } // SpeechStream is not supported by the Parasail provider. -func (provider *ParasailProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *ParasailProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } // Transcription is not supported by the Parasail provider. -func (provider *ParasailProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey()) } // TranscriptionStream is not supported by the Parasail provider. -func (provider *ParasailProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *ParasailProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } // FileUpload is not supported by Parasail provider. -func (provider *ParasailProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) } // FileList is not supported by Parasail provider. -func (provider *ParasailProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) } // FileRetrieve is not supported by Parasail provider. -func (provider *ParasailProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) } // FileDelete is not supported by Parasail provider. -func (provider *ParasailProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) } // FileContent is not supported by Parasail provider. -func (provider *ParasailProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) } // BatchCreate is not supported by Parasail provider. -func (provider *ParasailProvider) BatchCreate(_ context.Context, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) } // BatchList is not supported by Parasail provider. -func (provider *ParasailProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) } // BatchRetrieve is not supported by Parasail provider. -func (provider *ParasailProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) } // BatchCancel is not supported by Parasail provider. -func (provider *ParasailProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) } // BatchResults is not supported by Parasail provider. -func (provider *ParasailProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) } // CountTokens is not supported by the Parasail provider. -func (provider *ParasailProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) } diff --git a/core/providers/perplexity/perplexity.go b/core/providers/perplexity/perplexity.go index c4441a04d7..4e8b7d681e 100644 --- a/core/providers/perplexity/perplexity.go +++ b/core/providers/perplexity/perplexity.go @@ -3,7 +3,6 @@ package perplexity import ( - "context" "fmt" "net/http" "strings" @@ -64,7 +63,7 @@ func (provider *PerplexityProvider) GetProviderKey() schemas.ModelProvider { // completeRequest sends a request to Perplexity's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. -func (provider *PerplexityProvider) completeRequest(ctx context.Context, jsonData []byte, url string, key string, model string) ([]byte, time.Duration, *schemas.BifrostError) { +func (provider *PerplexityProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, model string) ([]byte, time.Duration, *schemas.BifrostError) { // Create the request with the JSON body req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -108,24 +107,24 @@ func (provider *PerplexityProvider) completeRequest(ctx context.Context, jsonDat } // ListModels performs a list models request to Perplexity's API. -func (provider *PerplexityProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.ListModelsRequest, provider.GetProviderKey()) } // TextCompletion is not supported by the Perplexity provider. -func (provider *PerplexityProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey()) } // TextCompletionStream performs a streaming text completion request to Perplexity's API. // It formats the request, sends it to Perplexity, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *PerplexityProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *PerplexityProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } // ChatCompletion performs a chat completion request to the Perplexity API. -func (provider *PerplexityProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { // Convert to Perplexity chat completion request jsonBody, err := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -172,7 +171,7 @@ func (provider *PerplexityProvider) ChatCompletion(ctx context.Context, key sche // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Perplexity's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *PerplexityProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *PerplexityProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { var authHeader map[string]string if key.Value != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value} @@ -197,12 +196,13 @@ func (provider *PerplexityProvider) ChatCompletionStream(ctx context.Context, po customRequestConverter, nil, nil, + nil, provider.logger, ) } // Responses performs a responses request to the Perplexity API. -func (provider *PerplexityProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { return nil, err @@ -217,8 +217,8 @@ func (provider *PerplexityProvider) Responses(ctx context.Context, key schemas.K } // ResponsesStream performs a streaming responses request to the Perplexity API. -func (provider *PerplexityProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) +func (provider *PerplexityProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, @@ -228,81 +228,81 @@ func (provider *PerplexityProvider) ResponsesStream(ctx context.Context, postHoo } // Embedding is not supported by the Perplexity provider. -func (provider *PerplexityProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } // Speech is not supported by the Perplexity provider. -func (provider *PerplexityProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) } // SpeechStream is not supported by the Perplexity provider. -func (provider *PerplexityProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *PerplexityProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } // Transcription is not supported by the Perplexity provider. -func (provider *PerplexityProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey()) } // TranscriptionStream is not supported by the Perplexity provider. -func (provider *PerplexityProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *PerplexityProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } // BatchCreate is not supported by Perplexity provider. -func (provider *PerplexityProvider) BatchCreate(_ context.Context, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) } // BatchList is not supported by Perplexity provider. -func (provider *PerplexityProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) } // BatchRetrieve is not supported by Perplexity provider. -func (provider *PerplexityProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) } // BatchCancel is not supported by Perplexity provider. -func (provider *PerplexityProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) } // BatchResults is not supported by Perplexity provider. -func (provider *PerplexityProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) } // FileUpload is not supported by Perplexity provider. -func (provider *PerplexityProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) } // FileList is not supported by Perplexity provider. -func (provider *PerplexityProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) } // FileRetrieve is not supported by Perplexity provider. -func (provider *PerplexityProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) } // FileDelete is not supported by Perplexity provider. -func (provider *PerplexityProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) } // FileContent is not supported by Perplexity provider. -func (provider *PerplexityProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) } // CountTokens is not supported by the Perplexity provider. -func (provider *PerplexityProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *PerplexityProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) } diff --git a/core/providers/sgl/sgl.go b/core/providers/sgl/sgl.go index 85db38e4ab..5b7af3c42f 100644 --- a/core/providers/sgl/sgl.go +++ b/core/providers/sgl/sgl.go @@ -3,7 +3,6 @@ package sgl import ( - "context" "fmt" "strings" "time" @@ -67,7 +66,7 @@ func (provider *SGLProvider) GetProviderKey() schemas.ModelProvider { } // ListModels performs a list models request to SGL's API. -func (provider *SGLProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *SGLProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { return openai.HandleOpenAIListModelsRequest( ctx, provider.client, @@ -83,7 +82,7 @@ func (provider *SGLProvider) ListModels(ctx context.Context, keys []schemas.Key, } // TextCompletion is not supported by the SGL provider. -func (provider *SGLProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *SGLProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionRequest( ctx, provider.client, @@ -94,6 +93,7 @@ func (provider *SGLProvider) TextCompletion(ctx context.Context, key schemas.Key provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), + nil, provider.logger, ) } @@ -101,7 +101,7 @@ func (provider *SGLProvider) TextCompletion(ctx context.Context, key schemas.Key // TextCompletionStream performs a streaming text completion request to SGL's API. // It formats the request, sends it to SGL, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *SGLProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *SGLProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionStreaming( ctx, provider.client, @@ -112,6 +112,7 @@ func (provider *SGLProvider) TextCompletionStream(ctx context.Context, postHookR providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, postHookRunner, nil, provider.logger, @@ -119,7 +120,7 @@ func (provider *SGLProvider) TextCompletionStream(ctx context.Context, postHookR } // ChatCompletion performs a chat completion request to the SGL API. -func (provider *SGLProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *SGLProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, @@ -130,6 +131,7 @@ func (provider *SGLProvider) ChatCompletion(ctx context.Context, key schemas.Key providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + nil, provider.logger, ) } @@ -138,7 +140,7 @@ func (provider *SGLProvider) ChatCompletion(ctx context.Context, key schemas.Key // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses SGL's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *SGLProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *SGLProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Use shared OpenAI-compatible streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, @@ -154,12 +156,13 @@ func (provider *SGLProvider) ChatCompletionStream(ctx context.Context, postHookR nil, nil, nil, + nil, provider.logger, ) } // Responses performs a responses request to the SGL API. -func (provider *SGLProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *SGLProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { return nil, err @@ -174,8 +177,8 @@ func (provider *SGLProvider) Responses(ctx context.Context, key schemas.Key, req } // ResponsesStream performs a streaming responses request to the SGL API. -func (provider *SGLProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) +func (provider *SGLProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, @@ -185,7 +188,7 @@ func (provider *SGLProvider) ResponsesStream(ctx context.Context, postHookRunner } // Embedding is not supported by the SGL provider. -func (provider *SGLProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *SGLProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { return openai.HandleOpenAIEmbeddingRequest( ctx, provider.client, @@ -201,76 +204,76 @@ func (provider *SGLProvider) Embedding(ctx context.Context, key schemas.Key, req } // Speech is not supported by the SGL provider. -func (provider *SGLProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *SGLProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) } // SpeechStream is not supported by the SGL provider. -func (provider *SGLProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *SGLProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } // Transcription is not supported by the SGL provider. -func (provider *SGLProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *SGLProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey()) } // TranscriptionStream is not supported by the SGL provider. -func (provider *SGLProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *SGLProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } // FileUpload is not supported by SGL provider. -func (provider *SGLProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *SGLProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) } // FileList is not supported by SGL provider. -func (provider *SGLProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *SGLProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) } // FileRetrieve is not supported by SGL provider. -func (provider *SGLProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *SGLProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) } // FileDelete is not supported by SGL provider. -func (provider *SGLProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *SGLProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) } // FileContent is not supported by SGL provider. -func (provider *SGLProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *SGLProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) } // BatchCreate is not supported by SGL provider. -func (provider *SGLProvider) BatchCreate(_ context.Context, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *SGLProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) } // BatchList is not supported by SGL provider. -func (provider *SGLProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *SGLProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) } // BatchRetrieve is not supported by SGL provider. -func (provider *SGLProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *SGLProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) } // BatchCancel is not supported by SGL provider. -func (provider *SGLProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *SGLProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) } // BatchResults is not supported by SGL provider. -func (provider *SGLProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *SGLProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) } // CountTokens is not supported by the SGL provider. -func (provider *SGLProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *SGLProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) } diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index 91a48fb451..c714ef1e8b 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -832,7 +832,7 @@ func ShouldSendBackRawResponse(ctx context.Context, defaultSendBackRawResponse b } // SendCreatedEventResponsesChunk sends a ResponsesStreamResponseTypeCreated event. -func SendCreatedEventResponsesChunk(ctx context.Context, postHookRunner schemas.PostHookRunner, provider schemas.ModelProvider, model string, startTime time.Time, responseChan chan *schemas.BifrostStream) { +func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, provider schemas.ModelProvider, model string, startTime time.Time, responseChan chan *schemas.BifrostStream) { firstChunk := &schemas.BifrostResponsesStreamResponse{ Type: schemas.ResponsesStreamResponseTypeCreated, SequenceNumber: 0, @@ -853,7 +853,7 @@ func SendCreatedEventResponsesChunk(ctx context.Context, postHookRunner schemas. } // SendInProgressEventResponsesChunk sends a ResponsesStreamResponseTypeInProgress event -func SendInProgressEventResponsesChunk(ctx context.Context, postHookRunner schemas.PostHookRunner, provider schemas.ModelProvider, model string, startTime time.Time, responseChan chan *schemas.BifrostStream) { +func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, provider schemas.ModelProvider, model string, startTime time.Time, responseChan chan *schemas.BifrostStream) { chunk := &schemas.BifrostResponsesStreamResponse{ Type: schemas.ResponsesStreamResponseTypeInProgress, SequenceNumber: 1, @@ -877,16 +877,30 @@ func SendInProgressEventResponsesChunk(ctx context.Context, postHookRunner schem // This utility reduces code duplication across streaming implementations by encapsulating // the common pattern of running post hooks, handling errors, and sending responses with // proper context cancellation handling. +// It also completes the deferred LLM span when the final chunk is sent (StreamEndIndicator is true). func ProcessAndSendResponse( - ctx context.Context, + ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, response *schemas.BifrostResponse, responseChan chan *schemas.BifrostStream, ) { - // Run post hooks on the response - processedResponse, processedError := postHookRunner(&ctx, response, nil) + // Accumulate chunk for tracing (common for all providers) + if tracer, ok := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer); ok && tracer != nil { + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { + tracer.AddStreamingChunk(traceID, response) + } + } + + // Run post hooks on the response (note: accumulated chunks above contain pre-hook data) + processedResponse, processedError := postHookRunner(ctx, response, nil) if HandleStreamControlSkip(processedError) { + // Even if skipping, complete the deferred span if this is the final chunk + if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { + if final, ok := isFinalChunk.(bool); ok && final { + completeDeferredSpan(ctx, processedResponse, processedError) + } + } return } @@ -907,23 +921,37 @@ func ProcessAndSendResponse( case <-ctx.Done(): return } + + // Check if this is the final chunk and complete deferred span with post-processed data + if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { + if final, ok := isFinalChunk.(bool); ok && final { + completeDeferredSpan(ctx, processedResponse, processedError) + } + } } // ProcessAndSendBifrostError handles post-hook processing and sends the bifrost error to the channel. // This utility reduces code duplication across streaming implementations by encapsulating // the common pattern of running post hooks, handling errors, and sending responses with // proper context cancellation handling. +// It also completes the deferred LLM span when the final chunk is sent (StreamEndIndicator is true). func ProcessAndSendBifrostError( - ctx context.Context, + ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, bifrostErr *schemas.BifrostError, responseChan chan *schemas.BifrostStream, logger schemas.Logger, ) { - // Send scanner error through channel - processedResponse, processedError := postHookRunner(&ctx, nil, bifrostErr) + // Run post hooks first so span reflects post-processed data + processedResponse, processedError := postHookRunner(ctx, nil, bifrostErr) if HandleStreamControlSkip(processedError) { + // Even if skipping, complete the deferred span if this is the final chunk + if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { + if final, ok := isFinalChunk.(bool); ok && final { + completeDeferredSpan(ctx, processedResponse, processedError) + } + } return } @@ -943,6 +971,120 @@ func ProcessAndSendBifrostError( case responseChan <- streamResponse: case <-ctx.Done(): } + + // Check if this is the final chunk and complete deferred span with post-processed data + if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { + if final, ok := isFinalChunk.(bool); ok && final { + completeDeferredSpan(ctx, processedResponse, processedError) + } + } +} + +// SetupStreamCancellation spawns a goroutine that closes the body stream when +// the context is cancelled or deadline exceeded, unblocking any blocked Read/Scan operations. +// Returns a cleanup function that MUST be called when streaming is done to +// prevent the goroutine from closing the stream during normal operation. +// Works with both fasthttp's BodyStream() (io.Reader) and net/http's resp.Body (io.ReadCloser). +func SetupStreamCancellation(ctx context.Context, bodyStream io.Reader, logger schemas.Logger) (cleanup func()) { + done := make(chan struct{}) + + go func() { + select { + case <-ctx.Done(): + // Context cancelled or deadline exceeded - close the body stream to unblock reads + if closer, ok := bodyStream.(io.Closer); ok { + if err := closer.Close(); err != nil && logger != nil { + logger.Debug(fmt.Sprintf("Error closing body stream on context done: %v", err)) + } + } + case <-done: + // Normal completion - do nothing + } + }() + + return func() { close(done) } +} + +// HandleStreamCancellation should be called when a streaming goroutine exits +// due to context cancellation. It ensures proper cleanup by: +// 1. Checking if StreamEndIndicator was already set (to avoid duplicate handling) +// 2. Setting StreamEndIndicator to true +// 3. Sending a cancellation error through PostHook chain +// +// This is critical for the logging plugin to update log status from "processing" to "error" +// when a client disconnects mid-stream. +func HandleStreamCancellation( + ctx *schemas.BifrostContext, + postHookRunner schemas.PostHookRunner, + responseChan chan *schemas.BifrostStream, + provider schemas.ModelProvider, + model string, + requestType schemas.RequestType, + logger schemas.Logger, +) { + // Check if already handled (StreamEndIndicator already set) + if indicator := ctx.GetAndSetValue(schemas.BifrostContextKeyStreamEndIndicator, true); indicator != nil { + if set, ok := indicator.(bool); ok && set { + return // Already handled + } + } + // Create cancellation error + cancelErr := &schemas.BifrostError{ + StatusCode: schemas.Ptr(499), // Client Closed Request + Error: &schemas.ErrorField{ + Message: "Request cancelled: client disconnected", + Type: schemas.Ptr(schemas.RequestCancelled), + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + Provider: provider, + ModelRequested: model, + RequestType: requestType, + }, + } + + // Send through PostHook chain - this updates the log to "error" status + ProcessAndSendBifrostError(ctx, postHookRunner, cancelErr, responseChan, logger) +} + +// HandleStreamTimeout should be called when a streaming goroutine exits +// due to context deadline exceeded. It ensures proper cleanup by: +// 1. Checking if StreamEndIndicator was already set (to avoid duplicate handling) +// 2. Setting StreamEndIndicator to true +// 3. Sending a timeout error through PostHook chain +// +// This is critical for the logging plugin to update log status from "processing" to "error" +// when a request times out mid-stream. +func HandleStreamTimeout( + ctx *schemas.BifrostContext, + postHookRunner schemas.PostHookRunner, + responseChan chan *schemas.BifrostStream, + provider schemas.ModelProvider, + model string, + requestType schemas.RequestType, + logger schemas.Logger, +) { + // Check if already handled (StreamEndIndicator already set) + if indicator := ctx.GetAndSetValue(schemas.BifrostContextKeyStreamEndIndicator, true); indicator != nil { + if set, ok := indicator.(bool); ok && set { + return // Already handled + } + } + // Create timeout error + timeoutErr := &schemas.BifrostError{ + StatusCode: schemas.Ptr(504), // Gateway Timeout + Error: &schemas.ErrorField{ + Message: "Request timed out: deadline exceeded", + Type: schemas.Ptr(schemas.RequestTimedOut), + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + Provider: provider, + ModelRequested: model, + RequestType: requestType, + }, + } + + // Send through PostHook chain - this updates the log to "error" status + ProcessAndSendBifrostError(ctx, postHookRunner, timeoutErr, responseChan, logger) } // ProcessAndSendError handles post-hook processing and sends the error to the channel. @@ -950,7 +1092,7 @@ func ProcessAndSendBifrostError( // the common pattern of running post hooks, handling errors, and sending responses with // proper context cancellation handling. func ProcessAndSendError( - ctx context.Context, + ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, err error, responseChan chan *schemas.BifrostStream, @@ -973,7 +1115,7 @@ func ProcessAndSendError( ModelRequested: model, }, } - processedResponse, processedError := postHookRunner(&ctx, nil, bifrostError) + processedResponse, processedError := postHookRunner(ctx, nil, bifrostError) if HandleStreamControlSkip(processedError) { return @@ -1113,6 +1255,7 @@ func ReleaseStreamingResponse(resp *fasthttp.Response) { // Drain any remaining data from the body stream before releasing // This prevents "whitespace in header" errors when the response is reused if resp.BodyStream() != nil { + // Drain the body stream io.Copy(io.Discard, resp.BodyStream()) } fasthttp.ReleaseResponse(resp) @@ -1246,10 +1389,10 @@ func extractSuccessfulListModelsResponses( // It launches concurrent requests for all keys and waits for all goroutines to complete. // It returns the aggregated response or an error if the request fails. func HandleMultipleListModelsRequests( - ctx context.Context, + ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest, - listModelsByKey func(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError), + listModelsByKey func(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError), logger schemas.Logger, ) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { startTime := time.Now() @@ -1385,3 +1528,94 @@ func GetBudgetTokensFromReasoningEffort( return budget, nil } + +// completeDeferredSpan completes the deferred LLM span for streaming requests. +// This is called when the final chunk is processed (when StreamEndIndicator is true). +// It retrieves the deferred span handle from TraceStore using the trace ID from context, +// populates response attributes from accumulated chunks, and ends the span. +func completeDeferredSpan(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) { + if ctx == nil { + return + } + + // Get the trace ID from context (this IS available in the provider's goroutine) + traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string) + if !ok || traceID == "" { + return + } + + // Get the tracer from context + tracerVal := ctx.Value(schemas.BifrostContextKeyTracer) + if tracerVal == nil { + return + } + tracer, ok := tracerVal.(schemas.Tracer) + if !ok || tracer == nil { + return + } + + // Get the deferred span handle from TraceStore using trace ID + handle := tracer.GetDeferredSpanHandle(traceID) + if handle == nil { + return + } + + // Set total latency from the final chunk + if result != nil { + extraFields := result.GetExtraFields() + if extraFields.Latency > 0 { + tracer.SetAttribute(handle, "gen_ai.response.total_latency_ms", extraFields.Latency) + } + } + + // Get accumulated response with full data (content, tool calls, reasoning, etc.) + // This builds a complete BifrostResponse from all the streaming chunks + accumulatedResp, ttftMs, chunkCount := tracer.GetAccumulatedChunks(traceID) + if accumulatedResp != nil { + // Use accumulated response for attributes (includes full content, tool calls, etc.) + tracer.PopulateLLMResponseAttributes(handle, accumulatedResp, err) + + // Set Time to First Token (TTFT) attribute + if ttftMs > 0 { + tracer.SetAttribute(handle, schemas.AttrTimeToFirstToken, ttftMs) + } + + // Set total chunks attribute + if chunkCount > 0 { + tracer.SetAttribute(handle, schemas.AttrTotalChunks, chunkCount) + } + } else if result != nil { + // Fall back to final chunk if no accumulated data (shouldn't happen normally) + tracer.PopulateLLMResponseAttributes(handle, result, err) + } + + // Finalize aggregated post-hook spans before ending the LLM span + // This creates one span per plugin with average execution time + // We need to set the llm.call span ID in context so post-hook spans become its children + if finalizer, ok := ctx.Value(schemas.BifrostContextKeyPostHookSpanFinalizer).(func(context.Context)); ok && finalizer != nil { + // Get the deferred span ID (the llm.call span) to set as parent for post-hook spans + spanID := tracer.GetDeferredSpanID(traceID) + if spanID != "" { + finalizerCtx := context.WithValue(ctx, schemas.BifrostContextKeySpanID, spanID) + finalizer(finalizerCtx) + } else { + finalizer(ctx) + } + } + + // End span with appropriate status + if err != nil { + if err.Error != nil { + tracer.SetAttribute(handle, "error", err.Error.Message) + } + if err.StatusCode != nil { + tracer.SetAttribute(handle, "status_code", *err.StatusCode) + } + tracer.EndSpan(handle, schemas.SpanStatusError, "streaming request failed") + } else { + tracer.EndSpan(handle, schemas.SpanStatusOk, "") + } + + // Clear the deferred span from TraceStore + tracer.ClearDeferredSpan(traceID) +} diff --git a/core/providers/vertex/models.go b/core/providers/vertex/models.go index fc79107498..ce788e73fe 100644 --- a/core/providers/vertex/models.go +++ b/core/providers/vertex/models.go @@ -5,8 +5,37 @@ import ( "strings" "github.com/maximhq/bifrost/core/schemas" + "golang.org/x/text/cases" + "golang.org/x/text/language" ) +// formatDeploymentName converts a deployment alias into a human-readable name. +// It splits the alias by "-" or "_", capitalizes each word, and joins them with spaces. +// Example: "gemini-pro" → "Gemini Pro", "claude_3_opus" → "Claude 3 Opus" +func formatDeploymentName(alias string) string { + caser := cases.Title(language.English) + + // Try splitting by hyphen first, then underscore + var parts []string + if strings.Contains(alias, "-") { + parts = strings.Split(alias, "-") + } else if strings.Contains(alias, "_") { + parts = strings.Split(alias, "_") + } else { + // No delimiter found, just capitalize the whole string + return caser.String(strings.ToLower(alias)) + } + + // Capitalize each part + for i, part := range parts { + if part != "" { + parts[i] = caser.String(strings.ToLower(part)) + } + } + + return strings.Join(parts, " ") +} + // findDeploymentMatch finds a matching deployment value in the deployments map. // Returns the deployment value and alias if found, empty strings otherwise. func findDeploymentMatch(deployments map[string]string, customModelID string) (deploymentValue, alias string) { @@ -23,6 +52,21 @@ func findDeploymentMatch(deployments map[string]string, customModelID string) (d return "", "" } +// ToBifrostListModelsResponse converts a Vertex AI list models response to Bifrost's format. +// It processes both custom models (from the API response) and non-custom models (from deployments and allowedModels). +// +// Custom models are those with digit-only deployment values, extracted from the API response. +// Non-custom models are those with non-digit characters in their deployment values or model names. +// +// The function performs three passes: +// 1. First pass: Process all models from the Vertex AI API response (custom models) +// 2. Second pass: Add non-custom models from deployments that aren't already in the list +// 3. Third pass: Add non-custom models from allowedModels that aren't in deployments or already added +// +// Filtering logic: +// - If allowedModels is empty, all models are allowed +// - If allowedModels is non-empty, only models/deployments with keys in allowedModels are included +// - Deployments map is used to match model IDs to aliases and filter accordingly func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string) *schemas.BifrostListModelsResponse { if response == nil { return nil @@ -31,6 +75,11 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod bifrostResponse := &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0, len(response.Models)), } + + // Track which model IDs have been added to avoid duplicates + addedModelIDs := make(map[string]bool) + + // First pass: Process all models from the Vertex AI API response (custom models) for _, model := range response.Models { if len(model.DeployedModels) == 0 { continue @@ -92,8 +141,82 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod modelEntry.Deployment = schemas.Ptr(deploymentValue) } bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + addedModelIDs[modelEntry.ID] = true } } + + // Second pass: Add non-custom models from deployments + // Non-custom models are identified by having non-digit characters in their deployment values + for alias, deploymentValue := range deployments { + // Skip if deployment value contains only digits (custom model, already processed) + if schemas.IsAllDigitsASCII(deploymentValue) { + continue + } + + // Check if this deployment alias is allowed + if len(allowedModels) > 0 { + // If allowedModels is non-empty, only include if alias is in the list + if !slices.Contains(allowedModels, alias) { + continue + } + } + + // Check if model already exists in the list + modelID := string(schemas.Vertex) + "/" + alias + if addedModelIDs[modelID] { + continue + } + + // Create model entry for non-custom model + modelName := formatDeploymentName(alias) + modelEntry := schemas.Model{ + ID: modelID, + Name: schemas.Ptr(modelName), + Description: nil, // No description available for non-custom models + Created: nil, // No creation time available for non-custom models + Deployment: schemas.Ptr(deploymentValue), + } + + bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + addedModelIDs[modelID] = true + } + + // Third pass: Add non-custom models from allowedModels that aren't in deployments + // This handles cases where a model is specified in allowedModels but not explicitly mapped in deployments + if len(allowedModels) > 0 { + for _, allowedModel := range allowedModels { + // Skip if model is all digits (custom model ID) + if schemas.IsAllDigitsASCII(allowedModel) { + continue + } + + // Skip if model is already in deployments (already processed in second pass) + if _, existsInDeployments := deployments[allowedModel]; existsInDeployments { + continue + } + + // Check if model already exists in the list + modelID := string(schemas.Vertex) + "/" + allowedModel + if addedModelIDs[modelID] { + continue + } + + // Create model entry for allowed model + // Use the model name itself as the deployment value + modelName := formatDeploymentName(allowedModel) + modelEntry := schemas.Model{ + ID: modelID, + Name: schemas.Ptr(modelName), + Description: nil, // No description available for models from allowedModels + Created: nil, // No creation time available for models from allowedModels + Deployment: schemas.Ptr(allowedModel), + } + + bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + addedModelIDs[modelID] = true + } + } + bifrostResponse.NextPageToken = response.NextPageToken return bifrostResponse diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index fad8786228..56f1dc1846 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -121,7 +121,7 @@ func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider { // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. // Handles pagination automatically by following nextPageToken until all models are retrieved. -func (provider *VertexProvider) listModelsByKey(ctx context.Context, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() if key.VertexKeyConfig == nil { @@ -241,7 +241,7 @@ func (provider *VertexProvider) listModelsByKey(ctx context.Context, key schemas // ListModels performs a list models request to Vertex's API. // Requests are made concurrently for improved performance. -func (provider *VertexProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *VertexProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { finalResponse, bifrostErr := providerUtils.HandleMultipleListModelsRequests( ctx, keys, @@ -258,21 +258,21 @@ func (provider *VertexProvider) ListModels(ctx context.Context, keys []schemas.K // TextCompletion is not supported by the Vertex provider. // Returns an error indicating that text completion is not available. -func (provider *VertexProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *VertexProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey()) } // TextCompletionStream performs a streaming text completion request to Vertex's API. // It formats the request, sends it to Vertex, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *VertexProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *VertexProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } // ChatCompletion performs a chat completion request to the Vertex API. // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() if key.VertexKeyConfig == nil { @@ -562,7 +562,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas. // ChatCompletionStream performs a streaming chat completion request to the Vertex API. // It supports both OpenAI-style streaming (for non-Claude models) and Anthropic-style streaming (for Claude models). // Returns a channel of BifrostResponse objects for streaming results or an error if the request fails. -func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { providerName := provider.GetProviderKey() if key.VertexKeyConfig == nil { return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) @@ -812,6 +812,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo providerName, postHookRunner, nil, + nil, postRequestConverter, postResponseConverter, provider.logger, @@ -820,7 +821,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo } // Responses performs a responses request to the Vertex API. -func (provider *VertexProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() if key.VertexKeyConfig == nil { @@ -1075,7 +1076,7 @@ func (provider *VertexProvider) Responses(ctx context.Context, key schemas.Key, } // ResponsesStream performs a streaming responses request to the Vertex API. -func (provider *VertexProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { providerName := provider.GetProviderKey() if key.VertexKeyConfig == nil { @@ -1252,7 +1253,7 @@ func (provider *VertexProvider) ResponsesStream(ctx context.Context, postHookRun provider.logger, ) } else { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) + ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true) return provider.ChatCompletionStream( ctx, postHookRunner, @@ -1265,7 +1266,7 @@ func (provider *VertexProvider) ResponsesStream(ctx context.Context, postHookRun // Embedding generates embeddings for the given input text(s) using Vertex AI. // All Vertex AI embedding models use the same response format regardless of the model type. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. -func (provider *VertexProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() if key.VertexKeyConfig == nil { @@ -1397,22 +1398,22 @@ func (provider *VertexProvider) Embedding(ctx context.Context, key schemas.Key, } // Speech is not supported by the Vertex provider. -func (provider *VertexProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *VertexProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) } // SpeechStream is not supported by the Vertex provider. -func (provider *VertexProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *VertexProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } // Transcription is not supported by the Vertex provider. -func (provider *VertexProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *VertexProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey()) } // TranscriptionStream is not supported by the Vertex provider. -func (provider *VertexProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *VertexProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } @@ -1448,59 +1449,59 @@ func (provider *VertexProvider) getModelDeployment(key schemas.Key, model string } // BatchCreate is not supported by Vertex AI provider. -func (provider *VertexProvider) BatchCreate(ctx context.Context, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *VertexProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) } // BatchList is not supported by Vertex AI provider. -func (provider *VertexProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *VertexProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) } // BatchRetrieve is not supported by Vertex AI provider. -func (provider *VertexProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *VertexProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) } // BatchCancel is not supported by Vertex AI provider. -func (provider *VertexProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *VertexProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) } // BatchResults is not supported by Vertex AI provider. -func (provider *VertexProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *VertexProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) } // FileUpload is not yet implemented for Vertex AI provider. // Vertex AI uses Google Cloud Storage (GCS) for batch input/output files. -func (provider *VertexProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *VertexProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) } // FileList is not yet implemented for Vertex AI provider. -func (provider *VertexProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *VertexProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) } // FileRetrieve is not yet implemented for Vertex AI provider. -func (provider *VertexProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *VertexProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) } // FileDelete is not yet implemented for Vertex AI provider. -func (provider *VertexProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *VertexProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) } // FileContent is not yet implemented for Vertex AI provider. -func (provider *VertexProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *VertexProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) } // CountTokens counts the number of tokens in the provided content using Vertex AI's countTokens endpoint. // Supports Gemini models with both text and image content. -func (provider *VertexProvider) CountTokens(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() if key.VertexKeyConfig == nil { diff --git a/core/providers/xai/errors.go b/core/providers/xai/errors.go new file mode 100644 index 0000000000..c621a5e7b8 --- /dev/null +++ b/core/providers/xai/errors.go @@ -0,0 +1,71 @@ +package xai + +import ( + "github.com/bytedance/sonic" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" + "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) + +// XAIErrorResponse represents xAI's error response format +type XAIErrorResponse struct { + Code string `json:"code"` + Error string `json:"error"` +} + +// ParseXAIError parses xAI-specific error responses. +// xAI returns errors in format: {"code": "...", "error": "..."} +// Unlike OpenAI which uses: {"error": {"message": "...", "type": "...", "code": "..."}} +func ParseXAIError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { + statusCode := resp.StatusCode() + + // Decode body + decodedBody, err := providerUtils.CheckAndDecodeBody(resp) + if err != nil { + return &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: &statusCode, + Error: &schemas.ErrorField{ + Message: err.Error(), + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + Provider: providerName, + ModelRequested: model, + RequestType: requestType, + }, + } + } + + // Try to parse xAI error format + var xaiErr XAIErrorResponse + if err := sonic.Unmarshal(decodedBody, &xaiErr); err == nil && xaiErr.Error != "" { + code := xaiErr.Code + return &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: &statusCode, + Error: &schemas.ErrorField{ + Code: &code, + Message: xaiErr.Error, + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + Provider: providerName, + ModelRequested: model, + RequestType: requestType, + }, + } + } + + // Fallback: couldn't parse as xAI format, return raw body + return &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: &statusCode, + Error: &schemas.ErrorField{ + Message: string(decodedBody), + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + Provider: providerName, + ModelRequested: model, + RequestType: requestType, + }, + } +} diff --git a/core/providers/xai/xai.go b/core/providers/xai/xai.go index 4b27fdab69..6ff94ac424 100644 --- a/core/providers/xai/xai.go +++ b/core/providers/xai/xai.go @@ -3,7 +3,6 @@ package xai import ( - "context" "strings" "time" @@ -60,7 +59,7 @@ func (provider *XAIProvider) GetProviderKey() schemas.ModelProvider { } // ListModels performs a list models request to xAI's API. -func (provider *XAIProvider) ListModels(ctx context.Context, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { +func (provider *XAIProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if provider.networkConfig.BaseURL == "" { return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey()) } @@ -79,7 +78,7 @@ func (provider *XAIProvider) ListModels(ctx context.Context, keys []schemas.Key, } // TextCompletion performs a text completion request to the xAI API. -func (provider *XAIProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { +func (provider *XAIProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionRequest( ctx, provider.client, @@ -90,6 +89,7 @@ func (provider *XAIProvider) TextCompletion(ctx context.Context, key schemas.Key provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), + ParseXAIError, provider.logger, ) } @@ -97,7 +97,7 @@ func (provider *XAIProvider) TextCompletion(ctx context.Context, key schemas.Key // TextCompletionStream performs a streaming text completion request to xAI's API. // It formats the request, sends it to xAI, and processes the response. // Returns a channel of BifrostStream objects or an error if the request fails. -func (provider *XAIProvider) TextCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *XAIProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionStreaming( ctx, provider.client, @@ -108,6 +108,7 @@ func (provider *XAIProvider) TextCompletionStream(ctx context.Context, postHookR providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + ParseXAIError, postHookRunner, nil, provider.logger, @@ -115,7 +116,7 @@ func (provider *XAIProvider) TextCompletionStream(ctx context.Context, postHookR } // ChatCompletion performs a chat completion request to the xAI API. -func (provider *XAIProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { +func (provider *XAIProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, @@ -126,6 +127,7 @@ func (provider *XAIProvider) ChatCompletion(ctx context.Context, key schemas.Key providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + ParseXAIError, provider.logger, ) } @@ -134,7 +136,7 @@ func (provider *XAIProvider) ChatCompletion(ctx context.Context, key schemas.Key // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses xAI's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *XAIProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *XAIProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { var authHeader map[string]string if key.Value != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value} @@ -152,6 +154,7 @@ func (provider *XAIProvider) ChatCompletionStream(ctx context.Context, postHookR schemas.XAI, postHookRunner, nil, + ParseXAIError, nil, nil, provider.logger, @@ -159,7 +162,7 @@ func (provider *XAIProvider) ChatCompletionStream(ctx context.Context, postHookR } // Responses performs a responses request to the xAI API. -func (provider *XAIProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { +func (provider *XAIProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { return openai.HandleOpenAIResponsesRequest( ctx, provider.client, @@ -170,12 +173,13 @@ func (provider *XAIProvider) Responses(ctx context.Context, key schemas.Key, req providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), + ParseXAIError, provider.logger, ) } // ResponsesStream performs a streaming responses request to the xAI API. -func (provider *XAIProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *XAIProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { var authHeader map[string]string if key.Value != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value} @@ -191,6 +195,7 @@ func (provider *XAIProvider) ResponsesStream(ctx context.Context, postHookRunner providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), postHookRunner, + ParseXAIError, nil, nil, provider.logger, @@ -198,80 +203,80 @@ func (provider *XAIProvider) ResponsesStream(ctx context.Context, postHookRunner } // Embedding is not supported by the xAI provider. -func (provider *XAIProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { +func (provider *XAIProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey()) } // Speech is not supported by the xAI provider. -func (provider *XAIProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { +func (provider *XAIProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey()) } // SpeechStream is not supported by the xAI provider. -func (provider *XAIProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *XAIProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey()) } // Transcription is not supported by the xAI provider. -func (provider *XAIProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { +func (provider *XAIProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey()) } // TranscriptionStream is not supported by the xAI provider. -func (provider *XAIProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *XAIProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey()) } // BatchCreate is not supported by xAI provider. -func (provider *XAIProvider) BatchCreate(_ context.Context, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { +func (provider *XAIProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) } // BatchList is not supported by xAI provider. -func (provider *XAIProvider) BatchList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { +func (provider *XAIProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey()) } // BatchRetrieve is not supported by xAI provider. -func (provider *XAIProvider) BatchRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { +func (provider *XAIProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey()) } // BatchCancel is not supported by xAI provider. -func (provider *XAIProvider) BatchCancel(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { +func (provider *XAIProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey()) } // BatchResults is not supported by xAI provider. -func (provider *XAIProvider) BatchResults(_ context.Context, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { +func (provider *XAIProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey()) } // FileUpload is not supported by xAI provider. -func (provider *XAIProvider) FileUpload(_ context.Context, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { +func (provider *XAIProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey()) } // FileList is not supported by xAI provider. -func (provider *XAIProvider) FileList(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { +func (provider *XAIProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey()) } // FileRetrieve is not supported by xAI provider. -func (provider *XAIProvider) FileRetrieve(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { +func (provider *XAIProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey()) } // FileDelete is not supported by xAI provider. -func (provider *XAIProvider) FileDelete(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { +func (provider *XAIProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey()) } // FileContent is not supported by xAI provider. -func (provider *XAIProvider) FileContent(_ context.Context, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { +func (provider *XAIProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey()) } -func (provider *XAIProvider) CountTokens(_ context.Context, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { +func (provider *XAIProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey()) } diff --git a/core/schemas/account.go b/core/schemas/account.go index 23473378ce..70a685a325 100644 --- a/core/schemas/account.go +++ b/core/schemas/account.go @@ -88,7 +88,7 @@ type Account interface { // The context can carry data from any source that sets values before the Bifrost request, // including but not limited to plugin pre-hooks, application logic, or any in app middleware sharing the context. // This enables dynamic key selection based on any context values present during the request. - GetKeysForProvider(ctx *context.Context, providerKey ModelProvider) ([]Key, error) + GetKeysForProvider(ctx context.Context, providerKey ModelProvider) ([]Key, error) // GetConfigForProvider returns the configuration for a specific provider. // This includes network settings, authentication details, and other provider-specific diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 6471e9ae9f..dd1e669a9c 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -2,18 +2,15 @@ package schemas import ( - "context" "encoding/json" "errors" - - "github.com/bytedance/sonic" ) const ( DefaultInitialPoolSize = 5000 ) -type KeySelector func(ctx *context.Context, keys []Key, providerKey ModelProvider, model string) (Key, error) +type KeySelector func(ctx *BifrostContext, keys []Key, providerKey ModelProvider, model string) (Key, error) // BifrostConfig represents the configuration for initializing a Bifrost instance. // It contains the necessary components for setting up the system including account details, @@ -22,6 +19,7 @@ type BifrostConfig struct { Account Account Plugins []Plugin Logger Logger + Tracer Tracer // Tracer for distributed tracing (nil = NoOpTracer) InitialPoolSize int // Initial pool size for sync pools in Bifrost. Higher values will reduce memory allocations but will increase memory usage. DropExcessRequests bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. MCPConfig *MCPConfig // MCP (Model Context Protocol) configuration for tool integration @@ -125,11 +123,11 @@ const ( BifrostContextKeyRequestID BifrostContextKey = "request-id" // string BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct - BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost)) - BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost)) - BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost)) - BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost)) 0 for primary, 1 for first fallback, etc. - BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost) + BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost - DO NOT SET THIS MANUALLY)) 0 for primary, 1 for first fallback, etc. + BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider) BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string][]string BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string @@ -137,9 +135,19 @@ const ( BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool BifrostContextKeyIntegrationType BifrostContextKey = "bifrost-integration-type" // integration used in gateway (e.g. openai, anthropic, bedrock, etc.) - BifrostContextKeyIsResponsesToChatCompletionFallback BifrostContextKey = "bifrost-is-responses-to-chat-completion-fallback" // bool (set by bifrost) + BifrostContextKeyIsResponsesToChatCompletionFallback BifrostContextKey = "bifrost-is-responses-to-chat-completion-fallback" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostMCPAgentOriginalRequestID BifrostContextKey = "bifrost-mcp-agent-original-request-id" // string (to store the original request ID for MCP agent mode) BifrostContextKeyStructuredOutputToolName BifrostContextKey = "bifrost-structured-output-tool-name" // string (to store the name of the structured output tool (set by bifrost)) BifrostContextKeyUserAgent BifrostContextKey = "bifrost-user-agent" // string (set by bifrost) + BifrostContextKeyTraceID BifrostContextKey = "bifrost-trace-id" // string (trace ID for distributed tracing - set by tracing middleware) + BifrostContextKeySpanID BifrostContextKey = "bifrost-span-id" // string (current span ID for child span creation - set by tracer) + BifrostContextKeyParentSpanID BifrostContextKey = "bifrost-parent-span-id" // string (parent span ID from W3C traceparent header - set by tracing middleware) + BifrostContextKeyStreamStartTime BifrostContextKey = "bifrost-stream-start-time" // time.Time (start time for streaming TTFT calculation - set by bifrost) + BifrostContextKeyTracer BifrostContextKey = "bifrost-tracer" // Tracer (tracer instance for completing deferred spans - set by bifrost) + BifrostContextKeyDeferTraceCompletion BifrostContextKey = "bifrost-defer-trace-completion" // bool (signals trace completion should be deferred for streaming - set by streaming handlers) + BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func() (callback to complete trace after streaming - set by tracing middleware) + BifrostContextKeyPostHookSpanFinalizer BifrostContextKey = "bifrost-posthook-span-finalizer" // func(context.Context) (callback to finalize post-hook spans after streaming - set by bifrost) + BifrostContextKeyAccumulatorID BifrostContextKey = "bifrost-accumulator-id" // string (ID for streaming accumulator lookup - set by tracer for accumulator operations) ) // NOTE: for custom plugin implementation dealing with streaming short circuit, @@ -440,6 +448,7 @@ type BifrostCacheDebug struct { const ( RequestCancelled = "request_cancelled" + RequestTimedOut = "request_timed_out" ) // BifrostStream represents a stream of responses from the Bifrost system. @@ -457,17 +466,17 @@ type BifrostStream struct { // This ensures that only the non-nil embedded struct is marshaled, func (bs BifrostStream) MarshalJSON() ([]byte, error) { if bs.BifrostTextCompletionResponse != nil { - return sonic.Marshal(bs.BifrostTextCompletionResponse) + return Marshal(bs.BifrostTextCompletionResponse) } else if bs.BifrostChatResponse != nil { - return sonic.Marshal(bs.BifrostChatResponse) + return Marshal(bs.BifrostChatResponse) } else if bs.BifrostResponsesStreamResponse != nil { - return sonic.Marshal(bs.BifrostResponsesStreamResponse) + return Marshal(bs.BifrostResponsesStreamResponse) } else if bs.BifrostSpeechStreamResponse != nil { - return sonic.Marshal(bs.BifrostSpeechStreamResponse) + return Marshal(bs.BifrostSpeechStreamResponse) } else if bs.BifrostTranscriptionStreamResponse != nil { - return sonic.Marshal(bs.BifrostTranscriptionStreamResponse) + return Marshal(bs.BifrostTranscriptionStreamResponse) } else if bs.BifrostError != nil { - return sonic.Marshal(bs.BifrostError) + return Marshal(bs.BifrostError) } // Return empty object if both are nil (shouldn't happen in practice) return []byte("{}"), nil @@ -487,7 +496,7 @@ type BifrostError struct { Error *ErrorField `json:"error"` AllowFallbacks *bool `json:"-"` // Optional: Controls fallback behavior (nil = true by default) StreamControl *StreamControl `json:"-"` // Optional: Controls stream behavior - ExtraFields BifrostErrorExtraFields `json:"extra_fields,omitempty"` + ExtraFields BifrostErrorExtraFields `json:"extra_fields"` } // StreamControl represents stream control options. diff --git a/core/schemas/chatcompletions.go b/core/schemas/chatcompletions.go index 4f9091b965..5aae9d6d2f 100644 --- a/core/schemas/chatcompletions.go +++ b/core/schemas/chatcompletions.go @@ -4,8 +4,6 @@ import ( "bytes" "fmt" "sort" - - "github.com/bytedance/sonic" ) // BifrostChatRequest is the request struct for chat completion requests @@ -200,7 +198,7 @@ func (cp *ChatParameters) UnmarshalJSON(data []byte) error { aux.Alias = (*Alias)(cp) // Single unmarshal - if err := sonic.Unmarshal(data, &aux); err != nil { + if err := Unmarshal(data, &aux); err != nil { return err } @@ -288,11 +286,11 @@ type ToolFunctionParameters struct { func (t *ToolFunctionParameters) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a JSON string (xAI format) var jsonStr string - if err := sonic.Unmarshal(data, &jsonStr); err == nil { + if err := Unmarshal(data, &jsonStr); err == nil { // It's a string, so parse the string as JSON type Alias ToolFunctionParameters var temp Alias - if err := sonic.Unmarshal([]byte(jsonStr), &temp); err != nil { + if err := Unmarshal([]byte(jsonStr), &temp); err != nil { return fmt.Errorf("failed to unmarshal parameters string: %w", err) } *t = ToolFunctionParameters(temp) @@ -302,7 +300,7 @@ func (t *ToolFunctionParameters) UnmarshalJSON(data []byte) error { // Otherwise, unmarshal as a normal JSON object type Alias ToolFunctionParameters var temp Alias - if err := sonic.Unmarshal(data, &temp); err != nil { + if err := Unmarshal(data, &temp); err != nil { return err } *t = ToolFunctionParameters(temp) @@ -370,7 +368,7 @@ func (om OrderedMap) MarshalJSON() ([]byte, error) { } // key - keyBytes, err := sonic.Marshal(k) + keyBytes, err := Marshal(k) if err != nil { return nil, err } @@ -378,7 +376,7 @@ func (om OrderedMap) MarshalJSON() ([]byte, error) { buf.WriteByte(':') // value - valBytes, err := sonic.Marshal(norm[k]) + valBytes, err := Marshal(norm[k]) if err != nil { return nil, err } @@ -443,13 +441,13 @@ func (ctc ChatToolChoice) MarshalJSON() ([]byte, error) { } if ctc.ChatToolChoiceStr != nil { - return sonic.Marshal(ctc.ChatToolChoiceStr) + return Marshal(ctc.ChatToolChoiceStr) } if ctc.ChatToolChoiceStruct != nil { - return sonic.Marshal(ctc.ChatToolChoiceStruct) + return Marshal(ctc.ChatToolChoiceStruct) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ChatMessageContent. @@ -458,7 +456,7 @@ func (ctc ChatToolChoice) MarshalJSON() ([]byte, error) { func (ctc *ChatToolChoice) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var toolChoiceStr string - if err := sonic.Unmarshal(data, &toolChoiceStr); err == nil { + if err := Unmarshal(data, &toolChoiceStr); err == nil { ctc.ChatToolChoiceStr = &toolChoiceStr ctc.ChatToolChoiceStruct = nil return nil @@ -466,7 +464,7 @@ func (ctc *ChatToolChoice) UnmarshalJSON(data []byte) error { // Try to unmarshal as a direct array of ContentBlock var chatToolChoice ChatToolChoiceStruct - if err := sonic.Unmarshal(data, &chatToolChoice); err == nil { + if err := Unmarshal(data, &chatToolChoice); err == nil { ctc.ChatToolChoiceStr = nil ctc.ChatToolChoiceStruct = &chatToolChoice return nil @@ -523,7 +521,7 @@ type ChatMessage struct { // UnmarshalJSON implements custom JSON unmarshalling for ChatMessage. // This is needed because ChatAssistantMessage has a custom UnmarshalJSON method, -// which interferes with sonic's handling of other fields in ChatMessage. +// which interferes with the JSON library's handling of other fields in ChatMessage. func (cm *ChatMessage) UnmarshalJSON(data []byte) error { // Unmarshal the base fields directly type baseFields struct { @@ -532,7 +530,7 @@ func (cm *ChatMessage) UnmarshalJSON(data []byte) error { Content *ChatMessageContent `json:"content,omitempty"` } var base baseFields - if err := sonic.Unmarshal(data, &base); err != nil { + if err := Unmarshal(data, &base); err != nil { return err } cm.Name = base.Name @@ -542,7 +540,7 @@ func (cm *ChatMessage) UnmarshalJSON(data []byte) error { // Unmarshal ChatToolMessage fields type toolMsgAlias ChatToolMessage var toolMsg toolMsgAlias - if err := sonic.Unmarshal(data, &toolMsg); err != nil { + if err := Unmarshal(data, &toolMsg); err != nil { return err } if toolMsg.ToolCallID != nil { @@ -551,7 +549,7 @@ func (cm *ChatMessage) UnmarshalJSON(data []byte) error { // Unmarshal ChatAssistantMessage (which has its own custom unmarshaller) var assistantMsg ChatAssistantMessage - if err := sonic.Unmarshal(data, &assistantMsg); err != nil { + if err := Unmarshal(data, &assistantMsg); err != nil { return err } // Only set if any field is populated @@ -579,13 +577,13 @@ func (mc ChatMessageContent) MarshalJSON() ([]byte, error) { } if mc.ContentStr != nil { - return sonic.Marshal(*mc.ContentStr) + return Marshal(*mc.ContentStr) } if mc.ContentBlocks != nil { - return sonic.Marshal(mc.ContentBlocks) + return Marshal(mc.ContentBlocks) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ChatMessageContent. @@ -601,7 +599,7 @@ func (mc *ChatMessageContent) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { + if err := Unmarshal(data, &stringContent); err == nil { mc.ContentStr = &stringContent mc.ContentBlocks = nil return nil @@ -609,7 +607,7 @@ func (mc *ChatMessageContent) UnmarshalJSON(data []byte) error { // Try to unmarshal as a direct array of ContentBlock var arrayContent []ChatContentBlock - if err := sonic.Unmarshal(data, &arrayContent); err == nil { + if err := Unmarshal(data, &arrayContent); err == nil { mc.ContentBlocks = arrayContent mc.ContentStr = nil return nil @@ -709,7 +707,7 @@ func (cm *ChatAssistantMessage) UnmarshalJSON(data []byte) error { ReasoningContent *string `json:"reasoning_content,omitempty"` // xAI uses this field name } - if err := sonic.Unmarshal(data, &aux); err != nil { + if err := Unmarshal(data, &aux); err != nil { return err } @@ -856,7 +854,7 @@ func (d *ChatStreamResponseChoiceDelta) UnmarshalJSON(data []byte) error { ReasoningContent *string `json:"reasoning_content,omitempty"` // xAI uses this field name } - if err := sonic.Unmarshal(data, &aux); err != nil { + if err := Unmarshal(data, &aux); err != nil { return err } @@ -945,7 +943,7 @@ type BifrostCost struct { func (bc *BifrostCost) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct float var costFloat float64 - if err := sonic.Unmarshal(data, &costFloat); err == nil { + if err := Unmarshal(data, &costFloat); err == nil { bc.TotalCost = costFloat return nil } @@ -954,7 +952,7 @@ func (bc *BifrostCost) UnmarshalJSON(data []byte) error { // Use a type alias to avoid infinite recursion type Alias BifrostCost var costStruct Alias - if err := sonic.Unmarshal(data, &costStruct); err == nil { + if err := Unmarshal(data, &costStruct); err == nil { *bc = BifrostCost(costStruct) return nil } diff --git a/core/schemas/context.go b/core/schemas/context.go index b84267d39e..c39afd8713 100644 --- a/core/schemas/context.go +++ b/core/schemas/context.go @@ -4,6 +4,7 @@ import ( "context" "slices" "sync" + "sync/atomic" "time" ) @@ -22,24 +23,26 @@ var reservedKeys = []any{ BifrostContextKeySkipKeySelection, BifrostContextKeyExtraHeaders, BifrostContextKeyURLPath, + BifrostContextKeyDeferTraceCompletion, } // BifrostContext is a custom context.Context implementation that tracks user-set values. // It supports deadlines, can be derived from other contexts, and provides layered // value inheritance when derived from another BifrostContext. type BifrostContext struct { - parent context.Context - deadline time.Time - hasDeadline bool - done chan struct{} - doneOnce sync.Once - err error - errMu sync.RWMutex - userValues map[any]any - valuesMu sync.RWMutex -} - -// NewBifrostContext creates a new PluginContext with the given parent context and deadline. + parent context.Context + deadline time.Time + hasDeadline bool + done chan struct{} + doneOnce sync.Once + err error + errMu sync.RWMutex + userValues map[any]any + valuesMu sync.RWMutex + blockRestrictedWrites atomic.Bool +} + +// NewBifrostContext creates a new BifrostContext with the given parent context and deadline. // If the deadline is zero, no deadline is set on this context (though the parent may have one). // The context will be cancelled when the deadline expires or when the parent context is cancelled. func NewBifrostContext(parent context.Context, deadline time.Time) *BifrostContext { @@ -47,29 +50,65 @@ func NewBifrostContext(parent context.Context, deadline time.Time) *BifrostConte parent = context.Background() } ctx := &BifrostContext{ - parent: parent, - deadline: deadline, - hasDeadline: !deadline.IsZero(), - done: make(chan struct{}), - userValues: make(map[any]any), + parent: parent, + deadline: deadline, + hasDeadline: !deadline.IsZero(), + done: make(chan struct{}), + userValues: make(map[any]any), + blockRestrictedWrites: atomic.Bool{}, } + ctx.blockRestrictedWrites.Store(false) // Only start goroutine if there's something to watch: // - If we have a deadline, we need the timer - // - If parent can be cancelled (Done() != nil), we need to propagate cancellation - if ctx.hasDeadline || parent.Done() != nil { + // - If parent can be cancelled (Done() != nil) AND is not a non-cancelling context + // - If parent has a deadline, we need a timer (parent may not properly cancel via Done()) + _, parentHasDeadline := parent.Deadline() + parentCanCancel := parent.Done() != nil && !isNonCancellingContext(parent) + if ctx.hasDeadline || parentCanCancel || parentHasDeadline { go ctx.watchCancellation() } return ctx } -// NewBifrostContextWithTimeout creates a new PluginContext with a timeout duration. -// This is a convenience wrapper around NewPluginContext. +// NewBifrostContextWithValue creates a new BifrostContext with the given value set. +func NewBifrostContextWithValue(parent context.Context, deadline time.Time, key any, value any) *BifrostContext { + ctx := NewBifrostContext(parent, deadline) + ctx.SetValue(key, value) + return ctx +} + +// NewBifrostContextWithTimeout creates a new BifrostContext with a timeout duration. +// This is a convenience wrapper around NewBifrostContext. // Returns the context and a cancel function that should be called to release resources. func NewBifrostContextWithTimeout(parent context.Context, timeout time.Duration) (*BifrostContext, context.CancelFunc) { ctx := NewBifrostContext(parent, time.Now().Add(timeout)) return ctx, func() { ctx.Cancel() } } +// NewBifrostContextWithCancel creates a new BifrostContext with a cancel function. +// This is a convenience wrapper around NewBifrostContext. +// Returns the context and a cancel function that should be called to release resources. +func NewBifrostContextWithCancel(parent context.Context) (*BifrostContext, context.CancelFunc) { + ctx := NewBifrostContext(parent, NoDeadline) + return ctx, func() { ctx.Cancel() } +} + +// WithValue returns a new context with the given value set. +func (bc *BifrostContext) WithValue(key any, value any) *BifrostContext { + bc.SetValue(key, value) + return bc +} + +// BlockRestrictedWrites returns true if restricted writes are blocked. +func (bc *BifrostContext) BlockRestrictedWrites() { + bc.blockRestrictedWrites.Store(true) +} + +// UnblockRestrictedWrites unblocks restricted writes. +func (bc *BifrostContext) UnblockRestrictedWrites() { + bc.blockRestrictedWrites.Store(false) +} + // Cancel cancels the context, closing the Done channel and setting the error to context.Canceled. func (bc *BifrostContext) Cancel() { bc.cancel(context.Canceled) @@ -78,8 +117,12 @@ func (bc *BifrostContext) Cancel() { // watchCancellation monitors for deadline expiration and parent cancellation. func (bc *BifrostContext) watchCancellation() { var timer <-chan time.Time - if bc.hasDeadline { - duration := time.Until(bc.deadline) + + // Use effective deadline (considers both own and parent deadlines) + // This handles cases where parent has a deadline but doesn't properly + // cancel via Done() (e.g., fasthttp.RequestCtx) + if effectiveDeadline, hasDeadline := bc.Deadline(); hasDeadline { + duration := time.Until(effectiveDeadline) if duration <= 0 { // Deadline already passed bc.cancel(context.DeadlineExceeded) @@ -90,6 +133,18 @@ func (bc *BifrostContext) watchCancellation() { timer = t.C } + // Don't watch parent.Done() for contexts known to never close it + // (e.g., fasthttp.RequestCtx pools contexts and never cancels them) + if isNonCancellingContext(bc.parent) { + select { + case <-timer: + bc.cancel(context.DeadlineExceeded) + case <-bc.done: + // Already cancelled + } + return + } + select { case <-bc.parent.Done(): bc.cancel(bc.parent.Err()) @@ -164,13 +219,33 @@ func (bc *BifrostContext) Value(key any) any { // This is thread-safe and can be called concurrently. func (bc *BifrostContext) SetValue(key, value any) { // Check if the key is a reserved key - if slices.Contains(reservedKeys, key) { - // we silently drop writes for these reserved keys + if bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key) { + // we silently drop writes for these reserved keys return } bc.valuesMu.Lock() defer bc.valuesMu.Unlock() + if bc.userValues == nil { + bc.userValues = make(map[any]any) + } + bc.userValues[key] = value +} + +// GetAndSetValue gets a value from the internal userValues map and sets it +func (bc *BifrostContext) GetAndSetValue(key any, value any) any { + bc.valuesMu.Lock() + defer bc.valuesMu.Unlock() + // Check if the key is a reserved key + if bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key) { + // we silently drop writes for these reserved keys + return bc.userValues[key] + } + if bc.userValues == nil { + bc.userValues = make(map[any]any) + } + oldValue := bc.userValues[key] bc.userValues[key] = value + return oldValue } // GetUserValues returns a copy of all user-set values in this context. diff --git a/core/schemas/context_native.go b/core/schemas/context_native.go new file mode 100644 index 0000000000..cf1a764ec4 --- /dev/null +++ b/core/schemas/context_native.go @@ -0,0 +1,12 @@ +//go:build !tinygo && !wasm + +package schemas + +import "github.com/valyala/fasthttp" + +// isNonCancellingContext returns true if the context is known to have +// a Done() channel that never closes (e.g., fasthttp.RequestCtx). +func isNonCancellingContext(parent any) bool { + _, ok := parent.(*fasthttp.RequestCtx) + return ok +} diff --git a/core/schemas/context_wasm.go b/core/schemas/context_wasm.go new file mode 100644 index 0000000000..981bb598c4 --- /dev/null +++ b/core/schemas/context_wasm.go @@ -0,0 +1,10 @@ +//go:build tinygo || wasm + +package schemas + +// isNonCancellingContext returns true if the context is known to have +// a Done() channel that never closes. In wasm builds, fasthttp is not +// available, so this always returns false. +func isNonCancellingContext(parent any) bool { + return false +} diff --git a/core/schemas/embedding.go b/core/schemas/embedding.go index 73f0d2664c..e1fbbe9f3d 100644 --- a/core/schemas/embedding.go +++ b/core/schemas/embedding.go @@ -2,8 +2,6 @@ package schemas import ( "fmt" - - "github.com/bytedance/sonic" ) type BifrostEmbeddingRequest struct { @@ -58,16 +56,16 @@ func (e *EmbeddingInput) MarshalJSON() ([]byte, error) { } if e.Text != nil { - return sonic.Marshal(*e.Text) + return Marshal(*e.Text) } if e.Texts != nil { - return sonic.Marshal(e.Texts) + return Marshal(e.Texts) } if e.Embedding != nil { - return sonic.Marshal(e.Embedding) + return Marshal(e.Embedding) } if e.Embeddings != nil { - return sonic.Marshal(e.Embeddings) + return Marshal(e.Embeddings) } return nil, fmt.Errorf("invalid embedding input") @@ -80,25 +78,25 @@ func (e *EmbeddingInput) UnmarshalJSON(data []byte) error { e.Embeddings = nil // Try string var s string - if err := sonic.Unmarshal(data, &s); err == nil { + if err := Unmarshal(data, &s); err == nil { e.Text = &s return nil } // Try []string var ss []string - if err := sonic.Unmarshal(data, &ss); err == nil { + if err := Unmarshal(data, &ss); err == nil { e.Texts = ss return nil } // Try []int var i []int - if err := sonic.Unmarshal(data, &i); err == nil { + if err := Unmarshal(data, &i); err == nil { e.Embedding = i return nil } // Try [][]int var i2 [][]int - if err := sonic.Unmarshal(data, &i2); err == nil { + if err := Unmarshal(data, &i2); err == nil { e.Embeddings = i2 return nil } @@ -129,13 +127,13 @@ type EmbeddingStruct struct { func (be EmbeddingStruct) MarshalJSON() ([]byte, error) { if be.EmbeddingStr != nil { - return sonic.Marshal(be.EmbeddingStr) + return Marshal(be.EmbeddingStr) } if be.EmbeddingArray != nil { - return sonic.Marshal(be.EmbeddingArray) + return Marshal(be.EmbeddingArray) } if be.Embedding2DArray != nil { - return sonic.Marshal(be.Embedding2DArray) + return Marshal(be.Embedding2DArray) } return nil, fmt.Errorf("no embedding found") } @@ -143,21 +141,21 @@ func (be EmbeddingStruct) MarshalJSON() ([]byte, error) { func (be *EmbeddingStruct) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { + if err := Unmarshal(data, &stringContent); err == nil { be.EmbeddingStr = &stringContent return nil } // Try to unmarshal as a direct array of float32 var arrayContent []float32 - if err := sonic.Unmarshal(data, &arrayContent); err == nil { + if err := Unmarshal(data, &arrayContent); err == nil { be.EmbeddingArray = arrayContent return nil } // Try to unmarshal as a direct 2D array of float32 var arrayContent2D [][]float32 - if err := sonic.Unmarshal(data, &arrayContent2D); err == nil { + if err := Unmarshal(data, &arrayContent2D); err == nil { be.Embedding2DArray = arrayContent2D return nil } diff --git a/core/schemas/json_native.go b/core/schemas/json_native.go new file mode 100644 index 0000000000..3d21c91444 --- /dev/null +++ b/core/schemas/json_native.go @@ -0,0 +1,20 @@ +//go:build !tinygo && !wasm + +package schemas + +import "github.com/bytedance/sonic" + +// Marshal encodes v to JSON bytes using the high-performance sonic library. +func Marshal(v interface{}) ([]byte, error) { + return sonic.Marshal(v) +} + +// MarshalString encodes v to a JSON string using sonic. +func MarshalString(v interface{}) (string, error) { + return sonic.MarshalString(v) +} + +// Unmarshal decodes JSON data into v using sonic. +func Unmarshal(data []byte, v interface{}) error { + return sonic.Unmarshal(data, v) +} diff --git a/core/schemas/json_wasm.go b/core/schemas/json_wasm.go new file mode 100644 index 0000000000..f04c328d2f --- /dev/null +++ b/core/schemas/json_wasm.go @@ -0,0 +1,24 @@ +//go:build tinygo || wasm + +package schemas + +import "encoding/json" + +// Marshal encodes v to JSON bytes using the standard library. +func Marshal(v interface{}) ([]byte, error) { + return json.Marshal(v) +} + +// MarshalString encodes v to a JSON string using the standard library. +func MarshalString(v interface{}) (string, error) { + data, err := json.Marshal(v) + if err != nil { + return "", err + } + return string(data), nil +} + +// Unmarshal decodes JSON data into v using the standard library. +func Unmarshal(data []byte, v interface{}) error { + return json.Unmarshal(data, v) +} diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go index e26409e122..3c19880a6d 100644 --- a/core/schemas/mcp.go +++ b/core/schemas/mcp.go @@ -1,32 +1,72 @@ +//go:build !tinygo && !wasm + // Package schemas defines the core schemas and types used by the Bifrost system. package schemas -// MCPServerInstance represents an MCP server instance for InProcess connections. -// This should be a *github.com/mark3labs/mcp-go/server.MCPServer instance. -// We use interface{} to avoid creating a dependency on the mcp-go package in schemas. -type MCPServerInstance interface{} +import ( + "context" + "time" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/server" +) // MCPConfig represents the configuration for MCP integration in Bifrost. // It enables tool auto-discovery and execution from local and external MCP servers. type MCPConfig struct { - ClientConfigs []MCPClientConfig `json:"client_configs,omitempty"` // Per-client execution configurations + ClientConfigs []MCPClientConfig `json:"client_configs,omitempty"` // Per-client execution configurations + ToolManagerConfig *MCPToolManagerConfig `json:"tool_manager_config,omitempty"` // MCP tool manager configuration + + // Function to fetch a new request ID for each tool call result message in agent mode, + // this is used to ensure that the tool call result messages are unique and can be tracked in plugins or by the user. + // This id is attached to ctx.Value(schemas.BifrostContextKeyRequestID) in the agent mode. + // If not provider, same request ID is used for all tool call result messages without any overrides. + FetchNewRequestIDFunc func(ctx *BifrostContext) string `json:"-"` } +type MCPToolManagerConfig struct { + ToolExecutionTimeout time.Duration `json:"tool_execution_timeout"` + MaxAgentDepth int `json:"max_agent_depth"` + CodeModeBindingLevel CodeModeBindingLevel `json:"code_mode_binding_level,omitempty"` // How tools are exposed in VFS: "server" or "tool" +} + +const ( + DefaultMaxAgentDepth = 10 + DefaultToolExecutionTimeout = 30 * time.Second +) + +// CodeModeBindingLevel defines how tools are exposed in the VFS for code execution +type CodeModeBindingLevel string + +const ( + CodeModeBindingLevelServer CodeModeBindingLevel = "server" + CodeModeBindingLevelTool CodeModeBindingLevel = "tool" +) + // MCPClientConfig defines tool filtering for an MCP client. type MCPClientConfig struct { ID string `json:"id"` // Client ID Name string `json:"name"` // Client name + IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) ConnectionString *string `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) Headers map[string]string `json:"headers,omitempty"` // Headers to send with the request - InProcessServer MCPServerInstance `json:"-"` // MCP server instance for in-process connections (Go package only) + InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only) ToolsToExecute []string `json:"tools_to_execute,omitempty"` // Include-only list. // ToolsToExecute semantics: // - ["*"] => all tools are included // - [] => no tools are included (deny-by-default) // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => include only the specified tools + ToolsToAutoExecute []string `json:"tools_to_auto_execute,omitempty"` // Auto-execute list. + // ToolsToAutoExecute semantics: + // - ["*"] => all tools are auto-executed + // - [] => no tools are auto-executed (deny-by-default) + // - nil/omitted => treated as [] (no tools) + // - ["tool1", "tool2"] => auto-execute only the specified tools + // Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped. + ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) } // MCPConnectionType defines the communication protocol for MCP connections @@ -54,9 +94,28 @@ const ( MCPConnectionStateError MCPConnectionState = "error" // Client is in an error state, and cannot be used ) +// MCPClientState represents a connected MCP client with its configuration and tools. +// It is used internally by the MCP manager to track the state of a connected MCP client. +type MCPClientState struct { + Name string // Unique name for this client + Conn *client.Client // Active MCP client connection + ExecutionConfig MCPClientConfig // Tool filtering settings + ToolMap map[string]ChatTool // Available tools mapped by name + ConnectionInfo MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management + CancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized) + State MCPConnectionState // Connection state (connected, disconnected, error) +} + +// MCPClientConnectionInfo stores metadata about how a client is connected. +type MCPClientConnectionInfo struct { + Type MCPConnectionType `json:"type"` // Connection type (HTTP, STDIO, SSE, or InProcess) + ConnectionURL *string `json:"connection_url,omitempty"` // HTTP/SSE endpoint URL (for HTTP/SSE connections) + StdioCommandString *string `json:"stdio_command_string,omitempty"` // Command string for display (for STDIO connections) +} + // MCPClient represents a connected MCP client with its configuration and tools, // and connection information, after it has been initialized. -// It is returned by GetMCPClients() method. +// It is returned by GetMCPClients() method in bifrost. type MCPClient struct { Config MCPClientConfig `json:"config"` // Tool filtering settings Tools []ChatToolFunction `json:"tools"` // Available tools diff --git a/core/schemas/mcp_wasm.go b/core/schemas/mcp_wasm.go new file mode 100644 index 0000000000..1a34e39b26 --- /dev/null +++ b/core/schemas/mcp_wasm.go @@ -0,0 +1,7 @@ +//go:build tinygo || wasm + +package schemas + +// MCPConfig is a stub for WASM builds. +// MCP functionality is not available in WASM plugins. +type MCPConfig struct{} diff --git a/core/schemas/models.go b/core/schemas/models.go index ab1e627559..4226e6e31e 100644 --- a/core/schemas/models.go +++ b/core/schemas/models.go @@ -3,8 +3,6 @@ package schemas import ( "encoding/base64" "fmt" - - "github.com/bytedance/sonic" ) // DefaultPageSize is the default page size for listing models @@ -182,7 +180,7 @@ func encodePaginationCursor(offset int, lastID string) (string, error) { LastID: lastID, } - jsonData, err := sonic.Marshal(cursor) + jsonData, err := Marshal(cursor) if err != nil { return "", fmt.Errorf("failed to marshal pagination cursor: %w", err) } @@ -206,7 +204,7 @@ func decodePaginationCursor(token string) paginationCursor { } var cursor paginationCursor - if err := sonic.Unmarshal(decoded, &cursor); err != nil { + if err := Unmarshal(decoded, &cursor); err != nil { return paginationCursor{} } diff --git a/core/schemas/mux.go b/core/schemas/mux.go index de8392900d..09addb7a9d 100644 --- a/core/schemas/mux.go +++ b/core/schemas/mux.go @@ -113,6 +113,130 @@ func (rt *ResponsesTool) ToChatTool() *ChatTool { return ct } +// ToChatAssistantMessageToolCall converts a ResponsesToolMessage to ChatAssistantMessageToolCall format. +// This is useful for executing Responses API tool calls using the Chat API tool executor. +// +// Returns: +// - *ChatAssistantMessageToolCall: The converted tool call in Chat API format +// +// Example: +// +// responsesToolMsg := &ResponsesToolMessage{ +// CallID: Ptr("call-123"), +// Name: Ptr("calculate"), +// Arguments: Ptr("{\"x\": 10, \"y\": 20}"), +// } +// chatToolCall := responsesToolMsg.ToChatAssistantMessageToolCall() +func (rtm *ResponsesToolMessage) ToChatAssistantMessageToolCall() *ChatAssistantMessageToolCall { + if rtm == nil { + return nil + } + + toolCall := &ChatAssistantMessageToolCall{ + ID: rtm.CallID, + Type: Ptr("function"), + Function: ChatAssistantMessageToolCallFunction{ + Name: rtm.Name, + Arguments: "{}", // Default to empty JSON object for valid JSON unmarshaling + }, + } + + // Extract arguments string + if rtm.Arguments != nil { + toolCall.Function.Arguments = *rtm.Arguments + } + + return toolCall +} + +// ToResponsesToolMessage converts a ChatToolMessage (tool execution result) to ResponsesToolMessage format. +// This creates a function_call_output message suitable for the Responses API. +// +// Returns: +// - *ResponsesMessage: A ResponsesMessage with type=function_call_output containing the tool result +// +// Example: +// +// chatToolMsg := &ChatMessage{ +// Role: ChatMessageRoleTool, +// ChatToolMessage: &ChatToolMessage{ +// ToolCallID: Ptr("call-123"), +// }, +// Content: &ChatMessageContent{ +// ContentStr: Ptr("Result: 30"), +// }, +// } +// responsesMsg := chatToolMsg.ToResponsesToolMessage() +func (cm *ChatMessage) ToResponsesToolMessage() *ResponsesMessage { + if cm == nil || cm.ChatToolMessage == nil { + return nil + } + + msgType := ResponsesMessageTypeFunctionCallOutput + + respMsg := &ResponsesMessage{ + Type: &msgType, + ResponsesToolMessage: &ResponsesToolMessage{ + CallID: cm.ChatToolMessage.ToolCallID, + }, + } + + // Extract output from content + if cm.Content != nil { + if cm.Content.ContentStr != nil { + output := *cm.Content.ContentStr + respMsg.ResponsesToolMessage.Output = &ResponsesToolMessageOutputStruct{ + ResponsesToolCallOutputStr: &output, + } + } else if len(cm.Content.ContentBlocks) > 0 { + // For structured content blocks, convert to ResponsesMessageContentBlock + respBlocks := make([]ResponsesMessageContentBlock, len(cm.Content.ContentBlocks)) + for i, block := range cm.Content.ContentBlocks { + respBlocks[i] = ResponsesMessageContentBlock{ + Type: ResponsesMessageContentBlockType(block.Type), + Text: block.Text, + CacheControl: block.CacheControl, + } + + // Map image + if block.ImageURLStruct != nil { + respBlocks[i].ResponsesInputMessageContentBlockImage = &ResponsesInputMessageContentBlockImage{ + ImageURL: &block.ImageURLStruct.URL, + Detail: block.ImageURLStruct.Detail, + } + } + + // Map file + if block.File != nil { + respBlocks[i].FileID = block.File.FileID + respBlocks[i].ResponsesInputMessageContentBlockFile = &ResponsesInputMessageContentBlockFile{ + FileData: block.File.FileData, + Filename: block.File.Filename, + FileType: block.File.FileType, + } + } + + // Map audio + if block.InputAudio != nil { + format := "" + if block.InputAudio.Format != nil { + format = *block.InputAudio.Format + } + respBlocks[i].Audio = &ResponsesInputMessageContentBlockAudio{ + Data: block.InputAudio.Data, + Format: format, + } + } + } + respMsg.ResponsesToolMessage.Output = &ResponsesToolMessageOutputStruct{ + ResponsesFunctionToolCallOutputBlocks: respBlocks, + } + } + } + + return respMsg +} + // ============================================================================= // TOOL CHOICE CONVERSION METHODS // ============================================================================= @@ -324,14 +448,17 @@ func (cm *ChatMessage) ToResponsesMessages() []ResponsesMessage { role = ResponsesInputMessageRoleSystem case ChatMessageRoleTool: messageType = ResponsesMessageTypeFunctionCallOutput - role = ResponsesInputMessageRoleUser // Tool messages are typically user role in responses + role = "" // tool call output messages don't include a role field case ChatMessageRoleDeveloper: role = ResponsesInputMessageRoleDeveloper } rm := ResponsesMessage{ Type: &messageType, - Role: &role, + } + + if role != "" { + rm.Role = &role } // Handle refusal content specifically - use content blocks with ResponsesOutputMessageContentRefusal @@ -347,7 +474,10 @@ func (cm *ChatMessage) ToResponsesMessages() []ResponsesMessage { } } else if cm.Content != nil && cm.Content.ContentStr != nil { // Convert regular string content (if input message then ContentStr, else ContentBlocks) - if cm.Role == ChatMessageRoleAssistant { + // Skip setting content for function_call_output - content should only be in output field + if messageType == ResponsesMessageTypeFunctionCallOutput { + // Don't set content for function_call_output - it will be set in ResponsesToolMessage.Output + } else if cm.Role == ChatMessageRoleAssistant { rm.Content = &ResponsesMessageContent{ ContentBlocks: []ResponsesMessageContentBlock{ {Type: ResponsesOutputMessageContentTypeText, Text: cm.Content.ContentStr}, @@ -360,57 +490,62 @@ func (cm *ChatMessage) ToResponsesMessages() []ResponsesMessage { } } else if cm.Content != nil && cm.Content.ContentBlocks != nil { // Convert content blocks - responseBlocks := make([]ResponsesMessageContentBlock, len(cm.Content.ContentBlocks)) - for i, block := range cm.Content.ContentBlocks { - blockType := ResponsesMessageContentBlockType(block.Type) - - switch block.Type { - case ChatContentBlockTypeText: - if cm.Role == ChatMessageRoleAssistant { - blockType = ResponsesOutputMessageContentTypeText - } else { - blockType = ResponsesInputMessageContentBlockTypeText + // Skip setting content blocks for function_call_output + if messageType == ResponsesMessageTypeFunctionCallOutput { + // Don't set content for function_call_output - it will be set in ResponsesToolMessage.Output + } else { + responseBlocks := make([]ResponsesMessageContentBlock, len(cm.Content.ContentBlocks)) + for i, block := range cm.Content.ContentBlocks { + blockType := ResponsesMessageContentBlockType(block.Type) + + switch block.Type { + case ChatContentBlockTypeText: + if cm.Role == ChatMessageRoleAssistant { + blockType = ResponsesOutputMessageContentTypeText + } else { + blockType = ResponsesInputMessageContentBlockTypeText + } + case ChatContentBlockTypeImage: + blockType = ResponsesInputMessageContentBlockTypeImage + case ChatContentBlockTypeFile: + blockType = ResponsesInputMessageContentBlockTypeFile + case ChatContentBlockTypeInputAudio: + blockType = ResponsesInputMessageContentBlockTypeAudio } - case ChatContentBlockTypeImage: - blockType = ResponsesInputMessageContentBlockTypeImage - case ChatContentBlockTypeFile: - blockType = ResponsesInputMessageContentBlockTypeFile - case ChatContentBlockTypeInputAudio: - blockType = ResponsesInputMessageContentBlockTypeAudio - } - responseBlocks[i] = ResponsesMessageContentBlock{ - Type: blockType, - Text: block.Text, - } - - // Convert specific block types - if block.ImageURLStruct != nil { - responseBlocks[i].ResponsesInputMessageContentBlockImage = &ResponsesInputMessageContentBlockImage{ - ImageURL: &block.ImageURLStruct.URL, - Detail: block.ImageURLStruct.Detail, + responseBlocks[i] = ResponsesMessageContentBlock{ + Type: blockType, + Text: block.Text, } - } - if block.File != nil { - responseBlocks[i].ResponsesInputMessageContentBlockFile = &ResponsesInputMessageContentBlockFile{ - FileData: block.File.FileData, - Filename: block.File.Filename, + + // Convert specific block types + if block.ImageURLStruct != nil { + responseBlocks[i].ResponsesInputMessageContentBlockImage = &ResponsesInputMessageContentBlockImage{ + ImageURL: &block.ImageURLStruct.URL, + Detail: block.ImageURLStruct.Detail, + } } - responseBlocks[i].FileID = block.File.FileID - } - if block.InputAudio != nil { - format := "" - if block.InputAudio.Format != nil { - format = *block.InputAudio.Format + if block.File != nil { + responseBlocks[i].ResponsesInputMessageContentBlockFile = &ResponsesInputMessageContentBlockFile{ + FileData: block.File.FileData, + Filename: block.File.Filename, + } + responseBlocks[i].FileID = block.File.FileID } - responseBlocks[i].Audio = &ResponsesInputMessageContentBlockAudio{ - Data: block.InputAudio.Data, - Format: format, + if block.InputAudio != nil { + format := "" + if block.InputAudio.Format != nil { + format = *block.InputAudio.Format + } + responseBlocks[i].Audio = &ResponsesInputMessageContentBlockAudio{ + Data: block.InputAudio.Data, + Format: format, + } } } - } - rm.Content = &ResponsesMessageContent{ - ContentBlocks: responseBlocks, + rm.Content = &ResponsesMessageContent{ + ContentBlocks: responseBlocks, + } } } @@ -422,9 +557,56 @@ func (cm *ChatMessage) ToResponsesMessages() []ResponsesMessage { } // If tool output content exists, add it to function_call_output - if rm.Content != nil && rm.Content.ContentStr != nil && *rm.Content.ContentStr != "" { - rm.ResponsesToolMessage.Output = &ResponsesToolMessageOutputStruct{ - ResponsesToolCallOutputStr: rm.Content.ContentStr, + // For function_call_output, get content from cm.Content since rm.Content is not set + if messageType == ResponsesMessageTypeFunctionCallOutput && cm.Content != nil { + // Prefer ContentStr if present + if cm.Content.ContentStr != nil && *cm.Content.ContentStr != "" { + rm.ResponsesToolMessage.Output = &ResponsesToolMessageOutputStruct{ + ResponsesToolCallOutputStr: cm.Content.ContentStr, + } + } else if len(cm.Content.ContentBlocks) > 0 { + // For structured content blocks, convert to ResponsesMessageContentBlock + respBlocks := make([]ResponsesMessageContentBlock, len(cm.Content.ContentBlocks)) + for i, block := range cm.Content.ContentBlocks { + respBlocks[i] = ResponsesMessageContentBlock{ + Type: ResponsesMessageContentBlockType(block.Type), + Text: block.Text, + CacheControl: block.CacheControl, + } + + // Map image + if block.ImageURLStruct != nil { + respBlocks[i].ResponsesInputMessageContentBlockImage = &ResponsesInputMessageContentBlockImage{ + ImageURL: &block.ImageURLStruct.URL, + Detail: block.ImageURLStruct.Detail, + } + } + + // Map file + if block.File != nil { + respBlocks[i].FileID = block.File.FileID + respBlocks[i].ResponsesInputMessageContentBlockFile = &ResponsesInputMessageContentBlockFile{ + FileData: block.File.FileData, + Filename: block.File.Filename, + FileType: block.File.FileType, + } + } + + // Map audio + if block.InputAudio != nil { + format := "" + if block.InputAudio.Format != nil { + format = *block.InputAudio.Format + } + respBlocks[i].Audio = &ResponsesInputMessageContentBlockAudio{ + Data: block.InputAudio.Data, + Format: format, + } + } + } + rm.ResponsesToolMessage.Output = &ResponsesToolMessageOutputStruct{ + ResponsesFunctionToolCallOutputBlocks: respBlocks, + } } } } diff --git a/core/schemas/plugin.go b/core/schemas/plugin.go index 2b7d5be930..1bd794b870 100644 --- a/core/schemas/plugin.go +++ b/core/schemas/plugin.go @@ -1,13 +1,11 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas -// PluginShortCircuit represents a plugin's decision to short-circuit the normal flow. -// It can contain either a response (success short-circuit), a stream (streaming short-circuit), or an error (error short-circuit). -type PluginShortCircuit struct { - Response *BifrostResponse // If set, short-circuit with this response (skips provider call) - Stream chan *BifrostStream // If set, short-circuit with this stream (skips provider call) - Error *BifrostError // If set, short-circuit with this error (can set AllowFallbacks field) -} +import ( + "context" + "strings" + "sync" +) // PluginStatus constants const ( @@ -27,6 +25,93 @@ type PluginStatus struct { Logs []string `json:"logs"` } +// HTTPRequest is a serializable representation of an HTTP request. +// Used for plugin HTTP transport interception (supports both native .so and WASM plugins). +// This type is pooled for allocation control - use AcquireHTTPRequest and ReleaseHTTPRequest. +type HTTPRequest struct { + Method string `json:"method"` + Path string `json:"path"` + Headers map[string]string `json:"headers"` + Query map[string]string `json:"query"` + Body []byte `json:"body"` +} + +// CaseInsensitiveHeaderLookup looks up a header key in a case-insensitive manner +func (req *HTTPRequest) CaseInsensitiveHeaderLookup(key string) string { + return caseInsensitiveLookup(req.Headers, key) +} + +// CaseInsensitiveQueryLookup looks up a query key in a case-insensitive manner +func (req *HTTPRequest) CaseInsensitiveQueryLookup(key string) string { + return caseInsensitiveLookup(req.Query, key) +} + +// caseInsensitiveLookup looks up a key in a case-insensitive manner for a map of strings +// Returns the value if found, otherwise an empty string +func caseInsensitiveLookup(data map[string]string, key string) string { + if data == nil || key == "" { + return "" + } + // exact match + if v, ok := data[key]; ok { + return v + } + // lower key checks + lowerKey := strings.ToLower(key) + if v, ok := data[lowerKey]; ok { + return v + } + // case-insensitive iteration + for k, v := range data { + if strings.EqualFold(k, key) { + return v + } + } + return "" +} + +// HTTPResponse is a serializable representation of an HTTP response. +// Used for short-circuit responses in plugin HTTP transport interception. +type HTTPResponse struct { + StatusCode int `json:"status_code"` + Headers map[string]string `json:"headers"` + Body []byte `json:"body"` +} + +// httpRequestPool is the pool for HTTPRequest objects to reduce allocations. +var httpRequestPool = sync.Pool{ + New: func() any { + return &HTTPRequest{ + Headers: make(map[string]string, 16), + Query: make(map[string]string, 8), + } + }, +} + +// AcquireHTTPRequest gets an HTTPRequest from the pool. +// The returned HTTPRequest is ready to use with pre-allocated maps. +// Call ReleaseHTTPRequest when done to return it to the pool. +func AcquireHTTPRequest() *HTTPRequest { + return httpRequestPool.Get().(*HTTPRequest) +} + +// ReleaseHTTPRequest returns an HTTPRequest to the pool. +// The HTTPRequest is reset before being returned to the pool. +// Do not use the HTTPRequest after calling this function. +func ReleaseHTTPRequest(req *HTTPRequest) { + if req == nil { + return + } + // Clear the maps + clear(req.Headers) + clear(req.Query) + // Reset fields + req.Method = "" + req.Path = "" + req.Body = nil + httpRequestPool.Put(req) +} + // Plugin defines the interface for Bifrost plugins. // Plugins can intercept and modify requests and responses at different stages // of the processing pipeline. @@ -35,7 +120,7 @@ type PluginStatus struct { // PostHooks are executed in the reverse order of PreHooks. // // Execution order: -// 1. TransportInterceptor (HTTP transport only, modifies raw headers/body before entering Bifrost core) +// 1. HTTPTransportIntercept (HTTP transport only, modifies raw headers/body before entering Bifrost core) // 2. PreHook (executed in registration order) // 3. Provider call // 4. PostHook (executed in reverse order of PreHooks) @@ -62,11 +147,18 @@ type Plugin interface { // GetName returns the name of the plugin. GetName() string - // TransportInterceptor is called at the HTTP transport layer before requests enter Bifrost core. - // It allows plugins to modify raw HTTP headers and body before transformation into BifrostRequest. + // HTTPTransportIntercept is called at the HTTP transport layer before requests enter Bifrost core. + // It receives a serializable HTTPRequest and allows plugins to modify it in-place. // Only invoked when using HTTP transport (bifrost-http), not when using Bifrost as a Go SDK directly. - // Returns modified headers, modified body, and any error that occurred during interception. - TransportInterceptor(ctx *BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) + // Works with both native .so plugins and WASM plugins due to serializable types. + // + // Return values: + // - (nil, nil): Continue to next plugin/handler, request modifications are applied + // - (*HTTPResponse, nil): Short-circuit with this response, skip remaining plugins and provider call + // - (nil, error): Short-circuit with error response + // + // Return nil for both values if the plugin doesn't need HTTP transport interception. + HTTPTransportIntercept(ctx *BifrostContext, req *HTTPRequest) (*HTTPResponse, error) // PreHook is called before a request is processed by a provider. // It allows plugins to modify the request before it is sent to the provider. @@ -95,3 +187,33 @@ type PluginConfig struct { Version *int16 `json:"version,omitempty"` Config any `json:"config,omitempty"` } + +// ObservabilityPlugin is an interface for plugins that receive completed traces +// for forwarding to observability backends (e.g., OTEL collectors, Datadog, etc.) +// +// ObservabilityPlugins are called asynchronously after the HTTP response has been +// written to the wire, ensuring they don't add latency to the client response. +// +// Plugins implementing this interface will: +// 1. Continue to work as regular plugins via PreHook/PostHook +// 2. Additionally receive completed traces via the Inject method +// +// Example backends: OpenTelemetry collectors, Datadog, Jaeger, Maxim, etc. +// +// Note: Go type assertion (plugin.(ObservabilityPlugin)) is used to identify +// plugins implementing this interface - no marker method is needed. +type ObservabilityPlugin interface { + Plugin + + // Inject receives a completed trace for forwarding to observability backends. + // This method is called asynchronously after the response has been written to the client. + // The trace contains all spans that were added during request processing. + // + // Implementations should: + // - Convert the trace to their backend's format + // - Send the trace to the backend (can be async) + // - Handle errors gracefully (log and continue) + // + // The context passed is a fresh background context, not the request context. + Inject(ctx context.Context, trace *Trace) error +} diff --git a/core/schemas/plugin_native.go b/core/schemas/plugin_native.go new file mode 100644 index 0000000000..ccd68b8388 --- /dev/null +++ b/core/schemas/plugin_native.go @@ -0,0 +1,21 @@ +//go:build !tinygo && !wasm + +package schemas + +import ( + "github.com/valyala/fasthttp" +) + +// BifrostHTTPMiddleware is a middleware function for the Bifrost HTTP transport. +// It follows the standard pattern: receives the next handler and returns a new handler. +// Used internally for CORS, Auth, Tracing middleware. Plugins use HTTPTransportIntercept instead. +type BifrostHTTPMiddleware func(next fasthttp.RequestHandler) fasthttp.RequestHandler + + +// PluginShortCircuit represents a plugin's decision to short-circuit the normal flow. +// It can contain either a response (success short-circuit), a stream (streaming short-circuit), or an error (error short-circuit). +type PluginShortCircuit struct { + Response *BifrostResponse // If set, short-circuit with this response (skips provider call) + Stream chan *BifrostStream // If set, short-circuit with this stream (skips provider call) + Error *BifrostError // If set, short-circuit with this error (can set AllowFallbacks field) +} \ No newline at end of file diff --git a/core/schemas/plugin_wasm.go b/core/schemas/plugin_wasm.go new file mode 100644 index 0000000000..04fc06e710 --- /dev/null +++ b/core/schemas/plugin_wasm.go @@ -0,0 +1,11 @@ +//go:build tinygo || wasm + +package schemas + +// PluginShortCircuit represents a plugin's decision to short-circuit the normal flow. +// It can contain either a response (success short-circuit), an error (error short-circuit). +// Streams are not supported in WASM plugins. +type PluginShortCircuit struct { + Response *BifrostResponse // If set, short-circuit with this response (skips provider call) + Error *BifrostError // If set, short-circuit with this error (can set AllowFallbacks field) +} diff --git a/core/schemas/provider.go b/core/schemas/provider.go index 0d5527addf..5de9644093 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -2,7 +2,6 @@ package schemas import ( - "context" "encoding/json" "maps" "time" @@ -312,56 +311,56 @@ func (config *ProviderConfig) CheckAndSetDefaults() { } } -type PostHookRunner func(ctx *context.Context, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError) +type PostHookRunner func(ctx *BifrostContext, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError) // Provider defines the interface for AI model providers. type Provider interface { // GetProviderKey returns the provider's identifier GetProviderKey() ModelProvider // ListModels performs a list models request - ListModels(ctx context.Context, keys []Key, request *BifrostListModelsRequest) (*BifrostListModelsResponse, *BifrostError) + ListModels(ctx *BifrostContext, keys []Key, request *BifrostListModelsRequest) (*BifrostListModelsResponse, *BifrostError) // TextCompletion performs a text completion request - TextCompletion(ctx context.Context, key Key, request *BifrostTextCompletionRequest) (*BifrostTextCompletionResponse, *BifrostError) + TextCompletion(ctx *BifrostContext, key Key, request *BifrostTextCompletionRequest) (*BifrostTextCompletionResponse, *BifrostError) // TextCompletionStream performs a text completion stream request - TextCompletionStream(ctx context.Context, postHookRunner PostHookRunner, key Key, request *BifrostTextCompletionRequest) (chan *BifrostStream, *BifrostError) + TextCompletionStream(ctx *BifrostContext, postHookRunner PostHookRunner, key Key, request *BifrostTextCompletionRequest) (chan *BifrostStream, *BifrostError) // ChatCompletion performs a chat completion request - ChatCompletion(ctx context.Context, key Key, request *BifrostChatRequest) (*BifrostChatResponse, *BifrostError) + ChatCompletion(ctx *BifrostContext, key Key, request *BifrostChatRequest) (*BifrostChatResponse, *BifrostError) // ChatCompletionStream performs a chat completion stream request - ChatCompletionStream(ctx context.Context, postHookRunner PostHookRunner, key Key, request *BifrostChatRequest) (chan *BifrostStream, *BifrostError) + ChatCompletionStream(ctx *BifrostContext, postHookRunner PostHookRunner, key Key, request *BifrostChatRequest) (chan *BifrostStream, *BifrostError) // Responses performs a completion request using the Responses API (uses chat completion request internally for non-openai providers) - Responses(ctx context.Context, key Key, request *BifrostResponsesRequest) (*BifrostResponsesResponse, *BifrostError) + Responses(ctx *BifrostContext, key Key, request *BifrostResponsesRequest) (*BifrostResponsesResponse, *BifrostError) // ResponsesStream performs a completion request using the Responses API stream (uses chat completion stream request internally for non-openai providers) - ResponsesStream(ctx context.Context, postHookRunner PostHookRunner, key Key, request *BifrostResponsesRequest) (chan *BifrostStream, *BifrostError) + ResponsesStream(ctx *BifrostContext, postHookRunner PostHookRunner, key Key, request *BifrostResponsesRequest) (chan *BifrostStream, *BifrostError) // CountTokens performs a count tokens request - CountTokens(ctx context.Context, key Key, request *BifrostResponsesRequest) (*BifrostCountTokensResponse, *BifrostError) + CountTokens(ctx *BifrostContext, key Key, request *BifrostResponsesRequest) (*BifrostCountTokensResponse, *BifrostError) // Embedding performs an embedding request - Embedding(ctx context.Context, key Key, request *BifrostEmbeddingRequest) (*BifrostEmbeddingResponse, *BifrostError) + Embedding(ctx *BifrostContext, key Key, request *BifrostEmbeddingRequest) (*BifrostEmbeddingResponse, *BifrostError) // Speech performs a text to speech request - Speech(ctx context.Context, key Key, request *BifrostSpeechRequest) (*BifrostSpeechResponse, *BifrostError) + Speech(ctx *BifrostContext, key Key, request *BifrostSpeechRequest) (*BifrostSpeechResponse, *BifrostError) // SpeechStream performs a text to speech stream request - SpeechStream(ctx context.Context, postHookRunner PostHookRunner, key Key, request *BifrostSpeechRequest) (chan *BifrostStream, *BifrostError) + SpeechStream(ctx *BifrostContext, postHookRunner PostHookRunner, key Key, request *BifrostSpeechRequest) (chan *BifrostStream, *BifrostError) // Transcription performs a transcription request - Transcription(ctx context.Context, key Key, request *BifrostTranscriptionRequest) (*BifrostTranscriptionResponse, *BifrostError) + Transcription(ctx *BifrostContext, key Key, request *BifrostTranscriptionRequest) (*BifrostTranscriptionResponse, *BifrostError) // TranscriptionStream performs a transcription stream request - TranscriptionStream(ctx context.Context, postHookRunner PostHookRunner, key Key, request *BifrostTranscriptionRequest) (chan *BifrostStream, *BifrostError) + TranscriptionStream(ctx *BifrostContext, postHookRunner PostHookRunner, key Key, request *BifrostTranscriptionRequest) (chan *BifrostStream, *BifrostError) // BatchCreate creates a new batch job for asynchronous processing - BatchCreate(ctx context.Context, key Key, request *BifrostBatchCreateRequest) (*BifrostBatchCreateResponse, *BifrostError) + BatchCreate(ctx *BifrostContext, key Key, request *BifrostBatchCreateRequest) (*BifrostBatchCreateResponse, *BifrostError) // BatchList lists batch jobs - BatchList(ctx context.Context, keys []Key, request *BifrostBatchListRequest) (*BifrostBatchListResponse, *BifrostError) + BatchList(ctx *BifrostContext, keys []Key, request *BifrostBatchListRequest) (*BifrostBatchListResponse, *BifrostError) // BatchRetrieve retrieves a specific batch job - BatchRetrieve(ctx context.Context, keys []Key, request *BifrostBatchRetrieveRequest) (*BifrostBatchRetrieveResponse, *BifrostError) + BatchRetrieve(ctx *BifrostContext, keys []Key, request *BifrostBatchRetrieveRequest) (*BifrostBatchRetrieveResponse, *BifrostError) // BatchCancel cancels a batch job - BatchCancel(ctx context.Context, keys []Key, request *BifrostBatchCancelRequest) (*BifrostBatchCancelResponse, *BifrostError) + BatchCancel(ctx *BifrostContext, keys []Key, request *BifrostBatchCancelRequest) (*BifrostBatchCancelResponse, *BifrostError) // BatchResults retrieves results from a completed batch job - BatchResults(ctx context.Context, keys []Key, request *BifrostBatchResultsRequest) (*BifrostBatchResultsResponse, *BifrostError) + BatchResults(ctx *BifrostContext, keys []Key, request *BifrostBatchResultsRequest) (*BifrostBatchResultsResponse, *BifrostError) // FileUpload uploads a file to the provider - FileUpload(ctx context.Context, key Key, request *BifrostFileUploadRequest) (*BifrostFileUploadResponse, *BifrostError) + FileUpload(ctx *BifrostContext, key Key, request *BifrostFileUploadRequest) (*BifrostFileUploadResponse, *BifrostError) // FileList lists files from the provider - FileList(ctx context.Context, keys []Key, request *BifrostFileListRequest) (*BifrostFileListResponse, *BifrostError) + FileList(ctx *BifrostContext, keys []Key, request *BifrostFileListRequest) (*BifrostFileListResponse, *BifrostError) // FileRetrieve retrieves file metadata from the provider - FileRetrieve(ctx context.Context, keys []Key, request *BifrostFileRetrieveRequest) (*BifrostFileRetrieveResponse, *BifrostError) + FileRetrieve(ctx *BifrostContext, keys []Key, request *BifrostFileRetrieveRequest) (*BifrostFileRetrieveResponse, *BifrostError) // FileDelete deletes a file from the provider - FileDelete(ctx context.Context, keys []Key, request *BifrostFileDeleteRequest) (*BifrostFileDeleteResponse, *BifrostError) + FileDelete(ctx *BifrostContext, keys []Key, request *BifrostFileDeleteRequest) (*BifrostFileDeleteResponse, *BifrostError) // FileContent downloads file content from the provider - FileContent(ctx context.Context, keys []Key, request *BifrostFileContentRequest) (*BifrostFileContentResponse, *BifrostError) + FileContent(ctx *BifrostContext, keys []Key, request *BifrostFileContentRequest) (*BifrostFileContentResponse, *BifrostError) } diff --git a/core/schemas/responses.go b/core/schemas/responses.go index 526717720c..566018970e 100644 --- a/core/schemas/responses.go +++ b/core/schemas/responses.go @@ -2,8 +2,6 @@ package schemas import ( "fmt" - - "github.com/bytedance/sonic" ) // ============================================================================= @@ -56,7 +54,7 @@ type BifrostResponsesResponse struct { MaxToolCalls *int `json:"max_tool_calls,omitempty"` Metadata *map[string]any `json:"metadata,omitempty"` Model string `json:"model"` - Output []ResponsesMessage `json:"output,omitempty"` + Output []ResponsesMessage `json:"output"` ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` PreviousResponseID *string `json:"previous_response_id,omitempty"` Prompt *ResponsesPrompt `json:"prompt,omitempty"` // Reference to a prompt template and variables @@ -155,13 +153,13 @@ func (rc ResponsesResponseConversation) MarshalJSON() ([]byte, error) { } if rc.ResponsesResponseConversationStr != nil { - return sonic.Marshal(*rc.ResponsesResponseConversationStr) + return Marshal(*rc.ResponsesResponseConversationStr) } if rc.ResponsesResponseConversationStruct != nil { - return sonic.Marshal(rc.ResponsesResponseConversationStruct) + return Marshal(rc.ResponsesResponseConversationStruct) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ResponsesMessageContent. @@ -170,14 +168,14 @@ func (rc ResponsesResponseConversation) MarshalJSON() ([]byte, error) { func (rc *ResponsesResponseConversation) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { + if err := Unmarshal(data, &stringContent); err == nil { rc.ResponsesResponseConversationStr = &stringContent return nil } // Try to unmarshal as a direct array of ContentBlock var structContent ResponsesResponseConversationStruct - if err := sonic.Unmarshal(data, &structContent); err == nil { + if err := Unmarshal(data, &structContent); err == nil { rc.ResponsesResponseConversationStruct = &structContent return nil } @@ -199,13 +197,13 @@ func (rc ResponsesResponseInstructions) MarshalJSON() ([]byte, error) { } if rc.ResponsesResponseInstructionsStr != nil { - return sonic.Marshal(*rc.ResponsesResponseInstructionsStr) + return Marshal(*rc.ResponsesResponseInstructionsStr) } if rc.ResponsesResponseInstructionsArray != nil { - return sonic.Marshal(rc.ResponsesResponseInstructionsArray) + return Marshal(rc.ResponsesResponseInstructionsArray) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ResponsesMessageContent. @@ -214,14 +212,14 @@ func (rc ResponsesResponseInstructions) MarshalJSON() ([]byte, error) { func (rc *ResponsesResponseInstructions) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { + if err := Unmarshal(data, &stringContent); err == nil { rc.ResponsesResponseInstructionsStr = &stringContent return nil } // Try to unmarshal as a direct array of ContentBlock var arrayContent []ResponsesMessage - if err := sonic.Unmarshal(data, &arrayContent); err == nil { + if err := Unmarshal(data, &arrayContent); err == nil { rc.ResponsesResponseInstructionsArray = arrayContent return nil } @@ -359,13 +357,13 @@ func (rc ResponsesMessageContent) MarshalJSON() ([]byte, error) { } if rc.ContentStr != nil { - return sonic.Marshal(*rc.ContentStr) + return Marshal(*rc.ContentStr) } if rc.ContentBlocks != nil { - return sonic.Marshal(rc.ContentBlocks) + return Marshal(rc.ContentBlocks) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ResponsesMessageContent. @@ -374,14 +372,14 @@ func (rc ResponsesMessageContent) MarshalJSON() ([]byte, error) { func (rc *ResponsesMessageContent) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { + if err := Unmarshal(data, &stringContent); err == nil { rc.ContentStr = &stringContent return nil } // Try to unmarshal as a direct array of ContentBlock var arrayContent []ResponsesMessageContentBlock - if err := sonic.Unmarshal(data, &arrayContent); err == nil { + if err := Unmarshal(data, &arrayContent); err == nil { rc.ContentBlocks = arrayContent return nil } @@ -501,38 +499,38 @@ type ResponsesToolMessageActionStruct struct { func (action ResponsesToolMessageActionStruct) MarshalJSON() ([]byte, error) { if action.ResponsesComputerToolCallAction != nil { - return sonic.Marshal(action.ResponsesComputerToolCallAction) + return Marshal(action.ResponsesComputerToolCallAction) } if action.ResponsesWebSearchToolCallAction != nil { - return sonic.Marshal(action.ResponsesWebSearchToolCallAction) + return Marshal(action.ResponsesWebSearchToolCallAction) } if action.ResponsesLocalShellToolCallAction != nil { - return sonic.Marshal(action.ResponsesLocalShellToolCallAction) + return Marshal(action.ResponsesLocalShellToolCallAction) } if action.ResponsesMCPApprovalRequestAction != nil { - return sonic.Marshal(action.ResponsesMCPApprovalRequestAction) + return Marshal(action.ResponsesMCPApprovalRequestAction) } return nil, fmt.Errorf("responses tool message action struct is neither a computer tool call action nor a web search tool call action nor a local shell tool call action nor a mcp approval request action") } func (action *ResponsesToolMessageActionStruct) UnmarshalJSON(data []byte) error { var computerToolCallAction ResponsesComputerToolCallAction - if err := sonic.Unmarshal(data, &computerToolCallAction); err == nil { + if err := Unmarshal(data, &computerToolCallAction); err == nil { action.ResponsesComputerToolCallAction = &computerToolCallAction return nil } var webSearchToolCallAction ResponsesWebSearchToolCallAction - if err := sonic.Unmarshal(data, &webSearchToolCallAction); err == nil { + if err := Unmarshal(data, &webSearchToolCallAction); err == nil { action.ResponsesWebSearchToolCallAction = &webSearchToolCallAction return nil } var localShellToolCallAction ResponsesLocalShellToolCallAction - if err := sonic.Unmarshal(data, &localShellToolCallAction); err == nil { + if err := Unmarshal(data, &localShellToolCallAction); err == nil { action.ResponsesLocalShellToolCallAction = &localShellToolCallAction return nil } var mcpApprovalRequestAction ResponsesMCPApprovalRequestAction - if err := sonic.Unmarshal(data, &mcpApprovalRequestAction); err == nil { + if err := Unmarshal(data, &mcpApprovalRequestAction); err == nil { action.ResponsesMCPApprovalRequestAction = &mcpApprovalRequestAction return nil } @@ -547,29 +545,29 @@ type ResponsesToolMessageOutputStruct struct { func (output ResponsesToolMessageOutputStruct) MarshalJSON() ([]byte, error) { if output.ResponsesToolCallOutputStr != nil { - return sonic.Marshal(*output.ResponsesToolCallOutputStr) + return Marshal(*output.ResponsesToolCallOutputStr) } if output.ResponsesFunctionToolCallOutputBlocks != nil { - return sonic.Marshal(output.ResponsesFunctionToolCallOutputBlocks) + return Marshal(output.ResponsesFunctionToolCallOutputBlocks) } if output.ResponsesComputerToolCallOutput != nil { - return sonic.Marshal(output.ResponsesComputerToolCallOutput) + return Marshal(output.ResponsesComputerToolCallOutput) } return nil, fmt.Errorf("responses tool message output struct is neither a string nor an array of responses message content blocks nor a computer tool call output data") } func (output *ResponsesToolMessageOutputStruct) UnmarshalJSON(data []byte) error { var str string - if err := sonic.Unmarshal(data, &str); err == nil { + if err := Unmarshal(data, &str); err == nil { output.ResponsesToolCallOutputStr = &str return nil } var array []ResponsesMessageContentBlock - if err := sonic.Unmarshal(data, &array); err == nil { + if err := Unmarshal(data, &array); err == nil { output.ResponsesFunctionToolCallOutputBlocks = array return nil } var computerToolCallOutput ResponsesComputerToolCallOutputData - if err := sonic.Unmarshal(data, &computerToolCallOutput); err == nil { + if err := Unmarshal(data, &computerToolCallOutput); err == nil { output.ResponsesComputerToolCallOutput = &computerToolCallOutput return nil } @@ -685,13 +683,13 @@ func (rf ResponsesFunctionToolCallOutput) MarshalJSON() ([]byte, error) { } if rf.ResponsesFunctionToolCallOutputStr != nil { - return sonic.Marshal(*rf.ResponsesFunctionToolCallOutputStr) + return Marshal(*rf.ResponsesFunctionToolCallOutputStr) } if rf.ResponsesFunctionToolCallOutputBlocks != nil { - return sonic.Marshal(rf.ResponsesFunctionToolCallOutputBlocks) + return Marshal(rf.ResponsesFunctionToolCallOutputBlocks) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ResponsesFunctionToolCallOutput. @@ -700,7 +698,7 @@ func (rf ResponsesFunctionToolCallOutput) MarshalJSON() ([]byte, error) { func (rf *ResponsesFunctionToolCallOutput) UnmarshalJSON(data []byte) error { // Parse as generic object to check if it contains content-like fields var genericObj map[string]interface{} - if err := sonic.Unmarshal(data, &genericObj); err != nil { + if err := Unmarshal(data, &genericObj); err != nil { return err } @@ -720,14 +718,14 @@ func (rf *ResponsesFunctionToolCallOutput) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { + if err := Unmarshal(data, &stringContent); err == nil { rf.ResponsesFunctionToolCallOutputStr = &stringContent return nil } // Try to unmarshal as a direct array of ContentBlock var arrayContent []ResponsesMessageContentBlock - if err := sonic.Unmarshal(data, &arrayContent); err == nil { + if err := Unmarshal(data, &arrayContent); err == nil { rf.ResponsesFunctionToolCallOutputBlocks = arrayContent return nil } @@ -794,10 +792,10 @@ func (o ResponsesCodeInterpreterOutput) MarshalJSON() ([]byte, error) { // Marshal whichever one is present if o.ResponsesCodeInterpreterOutputLogs != nil { - return sonic.Marshal(o.ResponsesCodeInterpreterOutputLogs) + return Marshal(o.ResponsesCodeInterpreterOutputLogs) } if o.ResponsesCodeInterpreterOutputImage != nil { - return sonic.Marshal(o.ResponsesCodeInterpreterOutputImage) + return Marshal(o.ResponsesCodeInterpreterOutputImage) } // Return null if neither is set @@ -815,7 +813,7 @@ func (o *ResponsesCodeInterpreterOutput) UnmarshalJSON(data []byte) error { var typeStruct struct { Type string `json:"type"` } - if err := sonic.Unmarshal(data, &typeStruct); err != nil { + if err := Unmarshal(data, &typeStruct); err != nil { return fmt.Errorf("failed to read type field: %w", err) } @@ -823,7 +821,7 @@ func (o *ResponsesCodeInterpreterOutput) UnmarshalJSON(data []byte) error { switch typeStruct.Type { case "logs": var logs ResponsesCodeInterpreterOutputLogs - if err := sonic.Unmarshal(data, &logs); err != nil { + if err := Unmarshal(data, &logs); err != nil { return fmt.Errorf("failed to unmarshal logs output: %w", err) } o.ResponsesCodeInterpreterOutputLogs = &logs @@ -832,7 +830,7 @@ func (o *ResponsesCodeInterpreterOutput) UnmarshalJSON(data []byte) error { case "image": var image ResponsesCodeInterpreterOutputImage - if err := sonic.Unmarshal(data, &image); err != nil { + if err := Unmarshal(data, &image); err != nil { return fmt.Errorf("failed to unmarshal image output: %w", err) } o.ResponsesCodeInterpreterOutputImage = &image @@ -982,13 +980,13 @@ func (tc ResponsesToolChoice) MarshalJSON() ([]byte, error) { } if tc.ResponsesToolChoiceStr != nil { - return sonic.Marshal(tc.ResponsesToolChoiceStr) + return Marshal(tc.ResponsesToolChoiceStr) } if tc.ResponsesToolChoiceStruct != nil { - return sonic.Marshal(tc.ResponsesToolChoiceStruct) + return Marshal(tc.ResponsesToolChoiceStruct) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ChatMessageContent. @@ -997,14 +995,14 @@ func (tc ResponsesToolChoice) MarshalJSON() ([]byte, error) { func (tc *ResponsesToolChoice) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var toolChoiceStr string - if err := sonic.Unmarshal(data, &toolChoiceStr); err == nil { + if err := Unmarshal(data, &toolChoiceStr); err == nil { tc.ResponsesToolChoiceStr = &toolChoiceStr return nil } // Try to unmarshal as a direct array of ContentBlock var responsesToolChoiceStruct ResponsesToolChoiceStruct - if err := sonic.Unmarshal(data, &responsesToolChoiceStruct); err == nil { + if err := Unmarshal(data, &responsesToolChoiceStruct); err == nil { tc.ResponsesToolChoiceStruct = &responsesToolChoiceStruct return nil } @@ -1115,14 +1113,14 @@ func (f *ResponsesToolFileSearchFilter) MarshalJSON() ([]byte, error) { return nil, fmt.Errorf("unknown filter type: %s", f.Type) } - return sonic.Marshal(result) + return Marshal(result) } // UnmarshalJSON implements custom JSON unmarshaling for ResponsesToolFileSearchFilter func (f *ResponsesToolFileSearchFilter) UnmarshalJSON(data []byte) error { // First, unmarshal into a map to inspect the type field var raw map[string]interface{} - if err := sonic.Unmarshal(data, &raw); err != nil { + if err := Unmarshal(data, &raw); err != nil { return fmt.Errorf("failed to unmarshal filter JSON: %w", err) } @@ -1147,7 +1145,7 @@ func (f *ResponsesToolFileSearchFilter) UnmarshalJSON(data []byte) error { f.ResponsesToolFileSearchCompoundFilter = nil // Unmarshal into the comparison filter - if err := sonic.Unmarshal(data, f.ResponsesToolFileSearchComparisonFilter); err != nil { + if err := Unmarshal(data, f.ResponsesToolFileSearchComparisonFilter); err != nil { return fmt.Errorf("failed to unmarshal comparison filter: %w", err) } @@ -1165,7 +1163,7 @@ func (f *ResponsesToolFileSearchFilter) UnmarshalJSON(data []byte) error { f.ResponsesToolFileSearchComparisonFilter = nil // Unmarshal into the compound filter - if err := sonic.Unmarshal(data, f.ResponsesToolFileSearchCompoundFilter); err != nil { + if err := Unmarshal(data, f.ResponsesToolFileSearchCompoundFilter); err != nil { return fmt.Errorf("failed to unmarshal compound filter: %w", err) } @@ -1273,7 +1271,7 @@ func (as ResponsesToolMCPAllowedToolsApprovalSetting) MarshalJSON() ([]byte, err } if as.Setting != nil { - return sonic.Marshal(*as.Setting) + return Marshal(*as.Setting) } if as.Always != nil || as.Never != nil { // Marshal as an object with always/never fields @@ -1284,17 +1282,17 @@ func (as ResponsesToolMCPAllowedToolsApprovalSetting) MarshalJSON() ([]byte, err if as.Never != nil { obj["never"] = as.Never } - return sonic.Marshal(obj) + return Marshal(obj) } // If all are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for ResponsesToolMCPAllowedToolsApprovalSetting func (as *ResponsesToolMCPAllowedToolsApprovalSetting) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var settingStr string - if err := sonic.Unmarshal(data, &settingStr); err == nil { + if err := Unmarshal(data, &settingStr); err == nil { as.Setting = &settingStr return nil } @@ -1304,7 +1302,7 @@ func (as *ResponsesToolMCPAllowedToolsApprovalSetting) UnmarshalJSON(data []byte Always *ResponsesToolMCPAllowedToolsApprovalFilter `json:"always,omitempty"` Never *ResponsesToolMCPAllowedToolsApprovalFilter `json:"never,omitempty"` } - if err := sonic.Unmarshal(data, &obj); err == nil { + if err := Unmarshal(data, &obj); err == nil { as.Always = obj.Always as.Never = obj.Never return nil diff --git a/core/schemas/speech.go b/core/schemas/speech.go index 6dcf4fec86..afea821177 100644 --- a/core/schemas/speech.go +++ b/core/schemas/speech.go @@ -2,8 +2,6 @@ package schemas import ( "fmt" - - "github.com/bytedance/sonic" ) type BifrostSpeechRequest struct { @@ -81,13 +79,13 @@ func (vi *SpeechVoiceInput) MarshalJSON() ([]byte, error) { } if vi.Voice != nil { - return sonic.Marshal(*vi.Voice) + return Marshal(*vi.Voice) } if len(vi.MultiVoiceConfig) > 0 { - return sonic.Marshal(vi.MultiVoiceConfig) + return Marshal(vi.MultiVoiceConfig) } // If both are nil, return null - return sonic.Marshal(nil) + return Marshal(nil) } // UnmarshalJSON implements custom JSON unmarshalling for SpeechVoiceInput. @@ -100,14 +98,14 @@ func (vi *SpeechVoiceInput) UnmarshalJSON(data []byte) error { // First, try to unmarshal as a direct string var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { + if err := Unmarshal(data, &stringContent); err == nil { vi.Voice = &stringContent return nil } // Try to unmarshal as an array of VoiceConfig objects var voiceConfigs []VoiceConfig - if err := sonic.Unmarshal(data, &voiceConfigs); err == nil { + if err := Unmarshal(data, &voiceConfigs); err == nil { // Validate each VoiceConfig and build a new slice deterministically validConfigs := make([]VoiceConfig, 0, len(voiceConfigs)) for _, config := range voiceConfigs { diff --git a/core/schemas/textcompletions.go b/core/schemas/textcompletions.go index c65f0db2f9..071673b51a 100644 --- a/core/schemas/textcompletions.go +++ b/core/schemas/textcompletions.go @@ -2,8 +2,6 @@ package schemas import ( "fmt" - - "github.com/bytedance/sonic" ) // BifrostTextCompletionRequest is the request struct for text completion requests @@ -96,20 +94,20 @@ func (t *TextCompletionInput) MarshalJSON() ([]byte, error) { return nil, fmt.Errorf("text completion input must set exactly one of: prompt_str or prompt_array") } if t.PromptStr != nil { - return sonic.Marshal(*t.PromptStr) + return Marshal(*t.PromptStr) } - return sonic.Marshal(t.PromptArray) + return Marshal(t.PromptArray) } func (t *TextCompletionInput) UnmarshalJSON(data []byte) error { var prompt string - if err := sonic.Unmarshal(data, &prompt); err == nil { + if err := Unmarshal(data, &prompt); err == nil { t.PromptStr = &prompt t.PromptArray = nil return nil } var promptArray []string - if err := sonic.Unmarshal(data, &promptArray); err == nil { + if err := Unmarshal(data, &promptArray); err == nil { t.PromptStr = nil t.PromptArray = promptArray return nil diff --git a/core/schemas/trace.go b/core/schemas/trace.go new file mode 100644 index 0000000000..f6de7bfcfd --- /dev/null +++ b/core/schemas/trace.go @@ -0,0 +1,335 @@ +// Package schemas defines the core schemas and types used by the Bifrost system. +package schemas + +import ( + "sync" + "time" +) + +// Trace represents a distributed trace that captures the full lifecycle of a request +type Trace struct { + TraceID string // Unique identifier for this trace + ParentID string // Parent trace ID from incoming W3C traceparent header + RootSpan *Span // The root span of this trace + Spans []*Span // All spans in this trace + StartTime time.Time // When the trace started + EndTime time.Time // When the trace completed + Attributes map[string]any // Additional attributes for the trace + mu sync.Mutex // Mutex for thread-safe span operations +} + +// AddSpan adds a span to the trace in a thread-safe manner +func (t *Trace) AddSpan(span *Span) { + t.mu.Lock() + defer t.mu.Unlock() + t.Spans = append(t.Spans, span) +} + +// GetSpan retrieves a span by ID +func (t *Trace) GetSpan(spanID string) *Span { + t.mu.Lock() + defer t.mu.Unlock() + for _, span := range t.Spans { + if span.SpanID == spanID { + return span + } + } + return nil +} + +// Reset clears the trace for reuse from pool +func (t *Trace) Reset() { + t.TraceID = "" + t.ParentID = "" + t.RootSpan = nil + t.Spans = t.Spans[:0] + t.StartTime = time.Time{} + t.EndTime = time.Time{} + t.Attributes = nil +} + +// Span represents a single operation within a trace +type Span struct { + SpanID string // Unique identifier for this span + ParentID string // Parent span ID (empty for root span) + TraceID string // The trace this span belongs to + Name string // Name of the operation + Kind SpanKind // Type of span (LLM call, plugin, etc.) + StartTime time.Time // When the span started + EndTime time.Time // When the span completed + Status SpanStatus // Status of the operation + StatusMsg string // Optional status message (for errors) + Attributes map[string]any // Additional attributes for the span + Events []SpanEvent // Events that occurred during the span + mu sync.Mutex // Mutex for thread-safe attribute operations +} + +// SetAttribute sets an attribute on the span in a thread-safe manner +func (s *Span) SetAttribute(key string, value any) { + s.mu.Lock() + defer s.mu.Unlock() + if s.Attributes == nil { + s.Attributes = make(map[string]any) + } + s.Attributes[key] = value +} + +// AddEvent adds an event to the span in a thread-safe manner +func (s *Span) AddEvent(event SpanEvent) { + s.mu.Lock() + defer s.mu.Unlock() + s.Events = append(s.Events, event) +} + +// End marks the span as complete with the given status +func (s *Span) End(status SpanStatus, statusMsg string) { + s.mu.Lock() + defer s.mu.Unlock() + s.EndTime = time.Now() + s.Status = status + s.StatusMsg = statusMsg +} + +// Reset clears the span for reuse from pool +func (s *Span) Reset() { + s.SpanID = "" + s.ParentID = "" + s.TraceID = "" + s.Name = "" + s.Kind = SpanKindUnspecified + s.StartTime = time.Time{} + s.EndTime = time.Time{} + s.Status = SpanStatusUnset + s.StatusMsg = "" + s.Attributes = nil + s.Events = s.Events[:0] +} + +// SpanEvent represents a time-stamped event within a span +type SpanEvent struct { + Name string // Name of the event + Timestamp time.Time // When the event occurred + Attributes map[string]any // Additional attributes for the event +} + +// SpanKind represents the type of operation a span represents +// These are LLM-specific kinds designed for AI gateway observability +type SpanKind string + +const ( + // SpanKindUnspecified is the default span kind + SpanKindUnspecified SpanKind = "" + // SpanKindLLMCall represents a call to an LLM provider + SpanKindLLMCall SpanKind = "llm.call" + // SpanKindPlugin represents plugin execution (PreHook/PostHook) + SpanKindPlugin SpanKind = "plugin" + // SpanKindMCPTool represents an MCP tool invocation + SpanKindMCPTool SpanKind = "mcp.tool" + // SpanKindRetry represents a retry attempt + SpanKindRetry SpanKind = "retry" + // SpanKindFallback represents a fallback to another provider + SpanKindFallback SpanKind = "fallback" + // SpanKindHTTPRequest represents the root HTTP request span + SpanKindHTTPRequest SpanKind = "http.request" + // SpanKindEmbedding represents an embedding request + SpanKindEmbedding SpanKind = "embedding" + // SpanKindSpeech represents a text-to-speech request + SpanKindSpeech SpanKind = "speech" + // SpanKindTranscription represents a speech-to-text request + SpanKindTranscription SpanKind = "transcription" + // SpanKindInternal represents internal operations (key selection, etc.) + SpanKindInternal SpanKind = "internal" +) + +// SpanStatus represents the status of a span's operation +type SpanStatus string + +const ( + // SpanStatusUnset indicates status has not been set + SpanStatusUnset SpanStatus = "unset" + // SpanStatusOk indicates the operation completed successfully + SpanStatusOk SpanStatus = "ok" + // SpanStatusError indicates the operation failed + SpanStatusError SpanStatus = "error" +) + +// LLM Attribute Keys (gen_ai.* namespace) +// These follow the OpenTelemetry semantic conventions for GenAI +// and are compatible with both OTEL and Datadog backends. +const ( + // Provider and Model Attributes + AttrProviderName = "gen_ai.provider.name" + AttrRequestModel = "gen_ai.request.model" + + // Request Parameter Attributes + AttrMaxTokens = "gen_ai.request.max_tokens" + AttrTemperature = "gen_ai.request.temperature" + AttrTopP = "gen_ai.request.top_p" + AttrStopSequences = "gen_ai.request.stop_sequences" + AttrPresencePenalty = "gen_ai.request.presence_penalty" + AttrFrequencyPenalty = "gen_ai.request.frequency_penalty" + AttrParallelToolCall = "gen_ai.request.parallel_tool_calls" + AttrRequestUser = "gen_ai.request.user" + AttrBestOf = "gen_ai.request.best_of" + AttrEcho = "gen_ai.request.echo" + AttrLogitBias = "gen_ai.request.logit_bias" + AttrLogProbs = "gen_ai.request.logprobs" + AttrN = "gen_ai.request.n" + AttrSeed = "gen_ai.request.seed" + AttrSuffix = "gen_ai.request.suffix" + AttrDimensions = "gen_ai.request.dimensions" + AttrEncodingFormat = "gen_ai.request.encoding_format" + AttrLanguage = "gen_ai.request.language" + AttrPrompt = "gen_ai.request.prompt" + AttrResponseFormat = "gen_ai.request.response_format" + AttrFormat = "gen_ai.request.format" + AttrVoice = "gen_ai.request.voice" + AttrMultiVoiceConfig = "gen_ai.request.multi_voice_config" + AttrInstructions = "gen_ai.request.instructions" + AttrSpeed = "gen_ai.request.speed" + AttrMessageCount = "gen_ai.request.message_count" + + // Response Attributes + AttrResponseID = "gen_ai.response.id" + AttrResponseModel = "gen_ai.response.model" + AttrFinishReason = "gen_ai.response.finish_reason" + AttrSystemFprint = "gen_ai.response.system_fingerprint" + AttrServiceTier = "gen_ai.response.service_tier" + AttrCreated = "gen_ai.response.created" + AttrObject = "gen_ai.response.object" + AttrTimeToFirstToken = "gen_ai.response.time_to_first_token" + AttrTotalChunks = "gen_ai.response.total_chunks" + + // Plugin Attributes (for aggregated streaming post-hook spans) + AttrPluginInvocations = "plugin.invocation_count" + AttrPluginAvgDurationMs = "plugin.avg_duration_ms" + AttrPluginTotalDurationMs = "plugin.total_duration_ms" + AttrPluginErrorCount = "plugin.error_count" + + // Usage Attributes + AttrPromptTokens = "gen_ai.usage.prompt_tokens" + AttrCompletionTokens = "gen_ai.usage.completion_tokens" + AttrTotalTokens = "gen_ai.usage.total_tokens" + AttrInputTokens = "gen_ai.usage.input_tokens" + AttrOutputTokens = "gen_ai.usage.output_tokens" + AttrUsageCost = "gen_ai.usage.cost" + + // Error Attributes + AttrError = "gen_ai.error" + AttrErrorType = "gen_ai.error.type" + AttrErrorCode = "gen_ai.error.code" + + // Input/Output Attributes + AttrInputText = "gen_ai.input.text" + AttrInputMessages = "gen_ai.input.messages" + AttrInputSpeech = "gen_ai.input.speech" + AttrInputEmbedding = "gen_ai.input.embedding" + AttrOutputMessages = "gen_ai.output.messages" + + // Bifrost Context Attributes + AttrVirtualKeyID = "gen_ai.virtual_key_id" + AttrVirtualKeyName = "gen_ai.virtual_key_name" + AttrSelectedKeyID = "gen_ai.selected_key_id" + AttrSelectedKeyName = "gen_ai.selected_key_name" + AttrTeamID = "gen_ai.team_id" + AttrTeamName = "gen_ai.team_name" + AttrCustomerID = "gen_ai.customer_id" + AttrCustomerName = "gen_ai.customer_name" + AttrNumberOfRetries = "gen_ai.number_of_retries" + AttrFallbackIndex = "gen_ai.fallback_index" + + // Responses API Request Attributes + AttrPromptCacheKey = "gen_ai.request.prompt_cache_key" + AttrReasoningEffort = "gen_ai.request.reasoning_effort" + AttrReasoningSummary = "gen_ai.request.reasoning_summary" + AttrReasoningGenSummary = "gen_ai.request.reasoning_generate_summary" + AttrSafetyIdentifier = "gen_ai.request.safety_identifier" + AttrStore = "gen_ai.request.store" + AttrTextVerbosity = "gen_ai.request.text_verbosity" + AttrTextFormatType = "gen_ai.request.text_format_type" + AttrTopLogProbs = "gen_ai.request.top_logprobs" + AttrToolChoiceType = "gen_ai.request.tool_choice_type" + AttrToolChoiceName = "gen_ai.request.tool_choice_name" + AttrTools = "gen_ai.request.tools" + AttrTruncation = "gen_ai.request.truncation" + + // Responses API Response Attributes + AttrRespInclude = "gen_ai.responses.include" + AttrRespMaxOutputTokens = "gen_ai.responses.max_output_tokens" + AttrRespMaxToolCalls = "gen_ai.responses.max_tool_calls" + AttrRespMetadata = "gen_ai.responses.metadata" + AttrRespPreviousRespID = "gen_ai.responses.previous_response_id" + AttrRespPromptCacheKey = "gen_ai.responses.prompt_cache_key" + AttrRespReasoningText = "gen_ai.responses.reasoning" + AttrRespReasoningEffort = "gen_ai.responses.reasoning_effort" + AttrRespReasoningGenSum = "gen_ai.responses.reasoning_generate_summary" + AttrRespSafetyIdentifier = "gen_ai.responses.safety_identifier" + AttrRespStore = "gen_ai.responses.store" + AttrRespTemperature = "gen_ai.responses.temperature" + AttrRespTextVerbosity = "gen_ai.responses.text_verbosity" + AttrRespTextFormatType = "gen_ai.responses.text_format_type" + AttrRespTopLogProbs = "gen_ai.responses.top_logprobs" + AttrRespTopP = "gen_ai.responses.top_p" + AttrRespToolChoiceType = "gen_ai.responses.tool_choice_type" + AttrRespToolChoiceName = "gen_ai.responses.tool_choice_name" + AttrRespTruncation = "gen_ai.responses.truncation" + AttrRespTools = "gen_ai.responses.tools" + + // Batch Operation Attributes + AttrBatchID = "gen_ai.batch.id" + AttrBatchStatus = "gen_ai.batch.status" + AttrBatchObject = "gen_ai.batch.object" + AttrBatchEndpoint = "gen_ai.batch.endpoint" + AttrBatchInputFileID = "gen_ai.batch.input_file_id" + AttrBatchOutputFileID = "gen_ai.batch.output_file_id" + AttrBatchErrorFileID = "gen_ai.batch.error_file_id" + AttrBatchCompletionWin = "gen_ai.batch.completion_window" + AttrBatchCreatedAt = "gen_ai.batch.created_at" + AttrBatchExpiresAt = "gen_ai.batch.expires_at" + AttrBatchRequestsCount = "gen_ai.batch.requests_count" + AttrBatchDataCount = "gen_ai.batch.data_count" + AttrBatchResultsCount = "gen_ai.batch.results_count" + AttrBatchHasMore = "gen_ai.batch.has_more" + AttrBatchMetadata = "gen_ai.batch.metadata" + AttrBatchLimit = "gen_ai.batch.limit" + AttrBatchAfter = "gen_ai.batch.after" + AttrBatchBeforeID = "gen_ai.batch.before_id" + AttrBatchAfterID = "gen_ai.batch.after_id" + AttrBatchPageToken = "gen_ai.batch.page_token" + AttrBatchPageSize = "gen_ai.batch.page_size" + AttrBatchCountTotal = "gen_ai.batch.request_counts.total" + AttrBatchCountCompleted = "gen_ai.batch.request_counts.completed" + AttrBatchCountFailed = "gen_ai.batch.request_counts.failed" + AttrBatchFirstID = "gen_ai.batch.first_id" + AttrBatchLastID = "gen_ai.batch.last_id" + AttrBatchInProgressAt = "gen_ai.batch.in_progress_at" + AttrBatchFinalizingAt = "gen_ai.batch.finalizing_at" + AttrBatchCompletedAt = "gen_ai.batch.completed_at" + AttrBatchFailedAt = "gen_ai.batch.failed_at" + AttrBatchExpiredAt = "gen_ai.batch.expired_at" + AttrBatchCancellingAt = "gen_ai.batch.cancelling_at" + AttrBatchCancelledAt = "gen_ai.batch.cancelled_at" + AttrBatchNextCursor = "gen_ai.batch.next_cursor" + + // Transcription Response Attributes + AttrInputTokenDetailsText = "gen_ai.usage.input_token_details.text_tokens" + AttrInputTokenDetailsAudio = "gen_ai.usage.input_token_details.audio_tokens" + + // File Operation Attributes + AttrFileID = "gen_ai.file.id" + AttrFileObject = "gen_ai.file.object" + AttrFileFilename = "gen_ai.file.filename" + AttrFilePurpose = "gen_ai.file.purpose" + AttrFileBytes = "gen_ai.file.bytes" + AttrFileCreatedAt = "gen_ai.file.created_at" + AttrFileStatus = "gen_ai.file.status" + AttrFileStorageBackend = "gen_ai.file.storage_backend" + AttrFileDataCount = "gen_ai.file.data_count" + AttrFileHasMore = "gen_ai.file.has_more" + AttrFileDeleted = "gen_ai.file.deleted" + AttrFileContentType = "gen_ai.file.content_type" + AttrFileContentBytes = "gen_ai.file.content_bytes" + AttrFileLimit = "gen_ai.file.limit" + AttrFileAfter = "gen_ai.file.after" + AttrFileOrder = "gen_ai.file.order" +) diff --git a/core/schemas/tracer.go b/core/schemas/tracer.go new file mode 100644 index 0000000000..5c5132faaf --- /dev/null +++ b/core/schemas/tracer.go @@ -0,0 +1,186 @@ +// Package schemas defines the core schemas and types used by the Bifrost system. +package schemas + +import ( + "context" + "time" +) + +// SpanHandle is an opaque handle to a span, implementation-specific. +// Different Tracer implementations can use their own concrete types. +type SpanHandle interface{} + +// StreamAccumulatorResult contains the accumulated data from streaming chunks. +// This is the return type for tracer's streaming accumulation methods. +type StreamAccumulatorResult struct { + RequestID string // Request ID + Model string // Model used + Provider ModelProvider // Provider used + Status string // Status of the stream + Latency int64 // Latency in milliseconds + TimeToFirstToken int64 // Time to first token in milliseconds + OutputMessage *ChatMessage // Accumulated output message + OutputMessages []ResponsesMessage // For responses API + TokenUsage *BifrostLLMUsage // Token usage + Cost *float64 // Cost in dollars + ErrorDetails *BifrostError // Error details if any + AudioOutput *BifrostSpeechResponse // For speech streaming + TranscriptionOutput *BifrostTranscriptionResponse // For transcription streaming + FinishReason *string // Finish reason + RawResponse *string // Raw response + RawRequest interface{} // Raw request +} + +// Tracer defines the interface for distributed tracing in Bifrost. +// Implementations can be injected via BifrostConfig to enable automatic instrumentation. +// The interface is designed to be minimal and implementation-agnostic. +type Tracer interface { + // CreateTrace creates a new trace with optional parent ID and returns the trace ID. + // The parentID can be extracted from W3C traceparent headers for distributed tracing. + CreateTrace(parentID string) string + + // EndTrace completes a trace and returns the trace data for observation/export. + // After this call, the trace is removed from active tracking and returned for cleanup. + // Returns nil if trace not found. + EndTrace(traceID string) *Trace + + // StartSpan creates a new span as a child of the current span in context. + // Returns updated context with new span and a handle for the span. + // The context should be used for subsequent operations to maintain span hierarchy. + StartSpan(ctx context.Context, name string, kind SpanKind) (context.Context, SpanHandle) + + // EndSpan completes a span with status and optional message. + // Should be called when the operation represented by the span is complete. + EndSpan(handle SpanHandle, status SpanStatus, statusMsg string) + + // SetAttribute sets an attribute on the span. + // Attributes provide additional context about the operation. + SetAttribute(handle SpanHandle, key string, value any) + + // AddEvent adds a timestamped event to the span. + // Events represent discrete occurrences during the span's lifetime. + AddEvent(handle SpanHandle, name string, attrs map[string]any) + + // PopulateLLMRequestAttributes populates all LLM-specific request attributes on the span. + // This includes model parameters, input messages, temperature, max tokens, etc. + PopulateLLMRequestAttributes(handle SpanHandle, req *BifrostRequest) + + // PopulateLLMResponseAttributes populates all LLM-specific response attributes on the span. + // This includes output messages, tokens, usage stats, and error information if present. + PopulateLLMResponseAttributes(handle SpanHandle, resp *BifrostResponse, err *BifrostError) + + // StoreDeferredSpan stores a span handle for later completion (used for streaming requests). + // The span handle is stored keyed by trace ID so it can be retrieved when the stream completes. + StoreDeferredSpan(traceID string, handle SpanHandle) + + // GetDeferredSpanHandle retrieves a deferred span handle by trace ID. + // Returns nil if no deferred span exists for the given trace ID. + GetDeferredSpanHandle(traceID string) SpanHandle + + // ClearDeferredSpan removes the deferred span handle for a trace ID. + // Should be called after the deferred span has been completed. + ClearDeferredSpan(traceID string) + + // GetDeferredSpanID returns the span ID for the deferred span. + // Returns empty string if no deferred span exists. + GetDeferredSpanID(traceID string) string + + // AddStreamingChunk accumulates a streaming chunk for the deferred span. + // Pass the full BifrostResponse to capture content, tool calls, reasoning, etc. + // This is called for each streaming chunk to build up the complete response. + AddStreamingChunk(traceID string, response *BifrostResponse) + + // GetAccumulatedChunks returns the accumulated BifrostResponse, TTFT, and chunk count for a deferred span. + // Returns the built response (with content, tool calls, etc.), time-to-first-token in ms, and total chunk count. + // Returns nil, 0, 0 if no accumulated data exists. + GetAccumulatedChunks(traceID string) (response *BifrostResponse, ttftMs int64, chunkCount int) + + // CreateStreamAccumulator creates a new stream accumulator for the given trace ID. + // This should be called at the start of a streaming request. + CreateStreamAccumulator(traceID string, startTime time.Time) + + // CleanupStreamAccumulator removes the stream accumulator for the given trace ID. + // This should be called after the streaming request is complete. + CleanupStreamAccumulator(traceID string) + + // ProcessStreamingChunk processes a streaming chunk and accumulates it. + // Returns the accumulated result. IsFinal will be true when the stream is complete. + // This method is used by plugins to access accumulated streaming data. + // The ctx parameter must contain the stream end indicator for proper final chunk detection. + ProcessStreamingChunk(ctx *BifrostContext, traceID string, result *BifrostResponse, err *BifrostError) *StreamAccumulatorResult + + // Stop releases resources associated with the tracer. + // Should be called during shutdown to stop background goroutines. + Stop() +} + +// NoOpTracer is a tracer that does nothing (default when tracing disabled). +// It satisfies the Tracer interface but performs no actual tracing operations. +type NoOpTracer struct{} + +// CreateTrace returns an empty string (no trace created). +func (n *NoOpTracer) CreateTrace(_ string) string { return "" } + +// EndTrace returns nil (no trace to end). +func (n *NoOpTracer) EndTrace(_ string) *Trace { return nil } + +// StartSpan returns the context unchanged and a nil handle. +func (n *NoOpTracer) StartSpan(ctx context.Context, _ string, _ SpanKind) (context.Context, SpanHandle) { + return ctx, nil +} + +// EndSpan does nothing. +func (n *NoOpTracer) EndSpan(_ SpanHandle, _ SpanStatus, _ string) {} + +// SetAttribute does nothing. +func (n *NoOpTracer) SetAttribute(_ SpanHandle, _ string, _ any) {} + +// AddEvent does nothing. +func (n *NoOpTracer) AddEvent(_ SpanHandle, _ string, _ map[string]any) {} + +// PopulateLLMRequestAttributes does nothing. +func (n *NoOpTracer) PopulateLLMRequestAttributes(_ SpanHandle, _ *BifrostRequest) {} + +// PopulateLLMResponseAttributes does nothing. +func (n *NoOpTracer) PopulateLLMResponseAttributes(_ SpanHandle, _ *BifrostResponse, _ *BifrostError) { +} + +// StoreDeferredSpan does nothing. +func (n *NoOpTracer) StoreDeferredSpan(_ string, _ SpanHandle) {} + +// GetDeferredSpanHandle returns nil. +func (n *NoOpTracer) GetDeferredSpanHandle(_ string) SpanHandle { return nil } + +// ClearDeferredSpan does nothing. +func (n *NoOpTracer) ClearDeferredSpan(_ string) {} + +// GetDeferredSpanID returns empty string. +func (n *NoOpTracer) GetDeferredSpanID(_ string) string { return "" } + +// AddStreamingChunk does nothing. +func (n *NoOpTracer) AddStreamingChunk(_ string, _ *BifrostResponse) {} + +// GetAccumulatedChunks returns nil, 0, 0. +func (n *NoOpTracer) GetAccumulatedChunks(_ string) (*BifrostResponse, int64, int) { return nil, 0, 0 } + +// CreateStreamAccumulator does nothing. +func (n *NoOpTracer) CreateStreamAccumulator(_ string, _ time.Time) {} + +// CleanupStreamAccumulator does nothing. +func (n *NoOpTracer) CleanupStreamAccumulator(_ string) {} + +// ProcessStreamingChunk returns nil. +func (n *NoOpTracer) ProcessStreamingChunk(_ *BifrostContext, _ string, _ *BifrostResponse, _ *BifrostError) *StreamAccumulatorResult { + return nil +} + +// Stop does nothing. +func (n *NoOpTracer) Stop() {} + +// DefaultTracer returns a no-op tracer for use when tracing is disabled. +func DefaultTracer() Tracer { + return &NoOpTracer{} +} + +// Ensure NoOpTracer implements Tracer at compile time +var _ Tracer = (*NoOpTracer)(nil) diff --git a/core/schemas/utils.go b/core/schemas/utils.go index 66ec247f97..9bfe96bd6e 100644 --- a/core/schemas/utils.go +++ b/core/schemas/utils.go @@ -7,8 +7,6 @@ import ( "regexp" "strconv" "strings" - - "github.com/bytedance/sonic" ) // Ptr creates a pointer to any value. @@ -267,7 +265,7 @@ func JsonifyInput(input interface{}) string { if input == nil { return "{}" } - jsonString, err := sonic.MarshalString(input) + jsonString, err := MarshalString(input) if err != nil { return "{}" } @@ -524,6 +522,30 @@ func SafeExtractFromMap(m map[string]interface{}, key string) (interface{}, bool return value, exists } +// SafeExtractStringMap safely extracts a map[string]string from an interface{} with type checking. +// Handles both direct map[string]string and JSON-deserialized map[string]interface{} cases. +func SafeExtractStringMap(value interface{}) (map[string]string, bool) { + if value == nil { + return nil, false + } + switch v := value.(type) { + case map[string]string: + return v, true + case map[string]interface{}: + result := make(map[string]string, len(v)) + for key, val := range v { + if str, ok := SafeExtractString(val); ok { + result[key] = str + } else { + return nil, false + } + } + return result, true + default: + return nil, false + } +} + func SafeExtractOrderedMap(value interface{}) (OrderedMap, bool) { if value == nil { return OrderedMap{}, false @@ -727,6 +749,106 @@ func deepCopyChatContentBlock(original ChatContentBlock) ChatContentBlock { return copy } +// DeepCopyChatTool creates a deep copy of a ChatTool +// to prevent shared data mutation between different plugin accumulators +func DeepCopyChatTool(original ChatTool) ChatTool { + copyTool := ChatTool{ + Type: original.Type, + } + + // Deep copy Function if present + if original.Function != nil { + copyTool.Function = &ChatToolFunction{ + Name: original.Function.Name, + } + + if original.Function.Description != nil { + copyDescription := *original.Function.Description + copyTool.Function.Description = ©Description + } + + if original.Function.Parameters != nil { + copyParams := &ToolFunctionParameters{ + Type: original.Function.Parameters.Type, + } + + if original.Function.Parameters.Description != nil { + copyParamDesc := *original.Function.Parameters.Description + copyParams.Description = ©ParamDesc + } + + if original.Function.Parameters.Required != nil { + copyParams.Required = make([]string, len(original.Function.Parameters.Required)) + copy(copyParams.Required, original.Function.Parameters.Required) + } + + if original.Function.Parameters.Properties != nil { + // Deep copy the map + copyProps := make(map[string]interface{}, len(*original.Function.Parameters.Properties)) + for k, v := range *original.Function.Parameters.Properties { + copyProps[k] = DeepCopy(v) + } + orderedProps := OrderedMap(copyProps) + copyParams.Properties = &orderedProps + } + + if original.Function.Parameters.Enum != nil { + copyParams.Enum = make([]string, len(original.Function.Parameters.Enum)) + copy(copyParams.Enum, original.Function.Parameters.Enum) + } + + if original.Function.Parameters.AdditionalProperties != nil { + copyAdditionalProps := *original.Function.Parameters.AdditionalProperties + copyParams.AdditionalProperties = ©AdditionalProps + } + + copyTool.Function.Parameters = copyParams + } + + if original.Function.Strict != nil { + copyStrict := *original.Function.Strict + copyTool.Function.Strict = ©Strict + } + } + + // Deep copy Custom if present + if original.Custom != nil { + copyTool.Custom = &ChatToolCustom{} + + if original.Custom.Format != nil { + copyFormat := &ChatToolCustomFormat{ + Type: original.Custom.Format.Type, + } + + if original.Custom.Format.Grammar != nil { + copyGrammar := &ChatToolCustomGrammarFormat{ + Definition: original.Custom.Format.Grammar.Definition, + Syntax: original.Custom.Format.Grammar.Syntax, + } + copyFormat.Grammar = copyGrammar + } + + copyTool.Custom.Format = copyFormat + } + } + + // Deep copy CacheControl if present + if original.CacheControl != nil { + copyCacheControl := &CacheControl{ + Type: original.CacheControl.Type, + } + + if original.CacheControl.TTL != nil { + copyTTL := *original.CacheControl.TTL + copyCacheControl.TTL = ©TTL + } + + copyTool.CacheControl = copyCacheControl + } + + return copyTool +} + // DeepCopyResponsesMessage creates a deep copy of a ResponsesMessage // to prevent shared data mutation between different plugin accumulators func DeepCopyResponsesMessage(original ResponsesMessage) ResponsesMessage { @@ -1058,6 +1180,31 @@ func IsGeminiModel(model string) bool { return strings.Contains(model, "gemini") } +// List of grok reasoning models +var grokReasoningModels = []string{ + "grok-3", + "grok-3-mini", + "grok-4", + "grok-4-fast-reasoning", + "grok-4-1-fast-reasoning", + "grok-code-fast-1", +} + +// IsGrokReasoningModel checks if the given model is a grok reasoning model +func IsGrokReasoningModel(model string) bool { + // Check if the model matches any of the reasoning models + for _, reasoningModel := range grokReasoningModels { + if strings.Contains(model, reasoningModel) { + // Make sure it's not a non-reasoning variant. Safety check for variants + if strings.Contains(model, "non-reasoning") { + return false + } + return true + } + } + return false +} + // Precompiled regexes for different kinds of version suffixes. var ( // Anthropic-style date: 20250514 diff --git a/core/utils.go b/core/utils.go index e38df6a279..a5a60cfe08 100644 --- a/core/utils.go +++ b/core/utils.go @@ -198,6 +198,18 @@ func IsStreamRequestType(reqType schemas.RequestType) bool { return reqType == schemas.TextCompletionStreamRequest || reqType == schemas.ChatCompletionStreamRequest || reqType == schemas.ResponsesStreamRequest || reqType == schemas.SpeechStreamRequest || reqType == schemas.TranscriptionStreamRequest } +func GetTracerFromContext(ctx *schemas.BifrostContext) (schemas.Tracer, string, error) { + tracer, ok := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer) + if !ok || tracer == nil { + return nil, "", fmt.Errorf("tracer not found in context") + } + traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string) + if !ok || traceID == "" { + return nil, "", fmt.Errorf("traceID not found in context") + } + return tracer, traceID, nil +} + // isBatchRequestType returns true if the given request type is a batch API operation. func isBatchRequestType(reqType schemas.RequestType) bool { return reqType == schemas.BatchCreateRequest || reqType == schemas.BatchListRequest || reqType == schemas.BatchRetrieveRequest || reqType == schemas.BatchCancelRequest || reqType == schemas.BatchResultsRequest @@ -396,3 +408,8 @@ func isPrivateIP(ip net.IP) bool { } return false } + +// sanitizeSpanName sanitizes a span name to remove capital letters and spaces to make it a valid span name +func sanitizeSpanName(name string) string { + return strings.ToLower(strings.ReplaceAll(name, " ", "-")) +} diff --git a/core/version b/core/version index f4872e765a..06c7347f09 100644 --- a/core/version +++ b/core/version @@ -1 +1 @@ -1.2.49 \ No newline at end of file +1.3.8 \ No newline at end of file diff --git a/docs/changelogs/v1.4.0-prerelease8.mdx b/docs/changelogs/v1.4.0-prerelease8.mdx new file mode 100644 index 0000000000..a7abec74d2 --- /dev/null +++ b/docs/changelogs/v1.4.0-prerelease8.mdx @@ -0,0 +1,66 @@ +--- +title: "v1.4.0-prerelease8" +description: "v1.4.0-prerelease8 changelog - 2026-01-09" +--- + + + ```bash + npx -y @maximhq/bifrost --transport-version v1.4.0-prerelease8 + ``` + + + ```bash + docker pull maximhq/bifrost:v1.4.0-prerelease8 + docker run -p 8080:8080 maximhq/bifrost:v1.4.0-prerelease8 + ``` + + + + +- fix: vertex list models enhanced to support values from deployments +- fix: header keys are now converted to lowercase for better consistency in plugin usage +- fix: gemini system message conversion and added support for using instructions parameter as a fallback when no system message + + + +- fix: vertex list models enhanced to support values from deployments +- fix: gemini system message conversion and added support for using instructions parameter as a fallback when no system message + + + +- chore: updated core version to 1.3.7 + + + +- chore: updated core version to 1.3.7 and framework version to 1.2.7 + + + +- chore: updated core version to 1.3.7 and framework version to 1.2.7 + + + +- chore: updated core version to 1.3.7 and framework version to 1.2.7 + + + +- chore: updated core version to 1.3.7 and framework version to 1.2.7 + + + +- chore: updated core version to 1.3.7 and framework version to 1.2.7 + + + +- chore: updated core version to 1.3.7 and framework version to 1.2.7 + + + +- chore: updated core version to 1.3.7 and framework version to 1.2.7 + + + +- feat: adds support for external Prometheus registry +- chore: updated core version to 1.3.7 and framework version to 1.2.7 + + diff --git a/docs/changelogs/v1.4.0-prerelease9.mdx b/docs/changelogs/v1.4.0-prerelease9.mdx new file mode 100644 index 0000000000..ba2d066b40 --- /dev/null +++ b/docs/changelogs/v1.4.0-prerelease9.mdx @@ -0,0 +1,62 @@ +--- +title: "v1.4.0-prerelease9" +description: "v1.4.0-prerelease9 changelog - 2026-01-11" +--- + + + ```bash + npx -y @maximhq/bifrost --transport-version v1.4.0-prerelease9 + ``` + + + ```bash + docker pull maximhq/bifrost:v1.4.0-prerelease9 + docker run -p 8080:8080 maximhq/bifrost:v1.4.0-prerelease9 + ``` + + + + +- fix: handles client disconnects and server timeouts gracefully for streaming responses + + + +- fix: adds timeout and connection disconnect handling for streaming responses + + + +- chore: updated core version to 1.3.8 + + + +- chore: updated core version to 1.3.8 and framework version to 1.2.8 + + + +- chore: updated core version to 1.3.8 and framework version to 1.2.8 + + + +- chore: updated core version to 1.3.8 and framework version to 1.2.8 + + + +- chore: updated core version to 1.3.8 and framework version to 1.2.8 + + + +- chore: updated core version to 1.3.8 and framework version to 1.2.8 + + + +- chore: updated core version to 1.3.8 and framework version to 1.2.8 + + + +- chore: updated core version to 1.3.8 and framework version to 1.2.8 + + + +- chore: updated core version to 1.3.8 and framework version to 1.2.8 + + diff --git a/docs/docs.json b/docs/docs.json index df8b2e79ff..64191b783d 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -358,6 +358,8 @@ "tab": "Changelogs", "icon": "bolt", "pages": [ + "changelogs/v1.4.0-prerelease9", + "changelogs/v1.4.0-prerelease8", "changelogs/v1.4.0-prerelease7", "changelogs/v1.4.0-prerelease6", "changelogs/v1.4.0-prerelease5", diff --git a/docs/plugins/migration-guide.mdx b/docs/plugins/migration-guide.mdx index b2ceab6d67..f6f0a4e6c4 100644 --- a/docs/plugins/migration-guide.mdx +++ b/docs/plugins/migration-guide.mdx @@ -77,7 +77,7 @@ func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[s // Modify req in-place. Return (*HTTPResponse, nil) to short-circuit. func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { // Add custom header (in-place modification) - req.Headers["X-Custom-Header"] = "value" + req.Headers["x-custom-header"] = "value" // Modify body (in-place modification) var body map[string]any @@ -128,13 +128,13 @@ func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPReques **v1.3.x:** ```go -headers["Authorization"] = "Bearer " + token +headers["authorization"] = "Bearer " + token return headers, body, nil ``` **v1.4.x+:** ```go -req.Headers["Authorization"] = "Bearer " + token +req.Headers["authorization"] = "Bearer " + token return nil, nil ``` @@ -147,7 +147,7 @@ apiKey := headers["X-API-Key"] **v1.4.x+:** ```go -apiKey := req.Headers["X-API-Key"] +apiKey := req.Headers["x-api-key"] ``` ### Conditional Processing @@ -155,7 +155,7 @@ apiKey := req.Headers["X-API-Key"] **v1.3.x:** ```go func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - if headers["X-Skip-Processing"] == "true" { + if headers["x-skip-processing"] == "true" { return headers, body, nil } // Process... @@ -166,7 +166,7 @@ func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[s **v1.4.x+:** ```go func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { - if req.Headers["X-Skip-Processing"] == "true" { + if req.Headers["x-skip-processing"] == "true" { return nil, nil // Continue without modification } // Process... @@ -179,7 +179,7 @@ func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPReques **v1.3.x:** ```go func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - if headers["X-API-Key"] == "" { + if headers["x-api-key"] == "" { return nil, nil, fmt.Errorf("missing API key") } return headers, body, nil @@ -189,7 +189,7 @@ func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[s **v1.4.x+:** ```go func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { - if req.Headers["X-API-Key"] == "" { + if req.Headers["x-api-key"] == "" { // Return a custom response to short-circuit return &schemas.HTTPResponse{ StatusCode: 401, @@ -284,10 +284,10 @@ Make sure you're modifying `req.Headers` directly: ```go // Set header (case-sensitive keys) -req.Headers["X-Custom-Header"] = "value" +req.Headers["x-custom-header"] = "value" // Read header -value := req.Headers["X-Custom-Header"] +value := req.Headers["x-custom-header"] ``` ## Need Help? diff --git a/docs/plugins/writing-go-plugin.mdx b/docs/plugins/writing-go-plugin.mdx index ff4c12a434..a5e13a9e7f 100644 --- a/docs/plugins/writing-go-plugin.mdx +++ b/docs/plugins/writing-go-plugin.mdx @@ -91,8 +91,14 @@ func GetName() string { // Only called when using HTTP transport (bifrost-http) func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { fmt.Println("HTTPTransportIntercept called") - // Modify request in-place (headers, body, query params) - req.Headers["X-Custom-Header"] = "custom-value" + + // Read headers using case-insensitive helper (recommended) + contentType := req.CaseInsensitiveHeaderLookup("Content-Type") + fmt.Printf("Content-Type: %s\n", contentType) + + // Modify request in-place (use lowercase for direct map access) + req.Headers["x-custom-header"] = "custom-value" + // Return nil to continue, or return &schemas.HTTPResponse{} to short-circuit return nil, nil } @@ -223,6 +229,25 @@ Key points: - Return `(*HTTPResponse, nil)` to short-circuit with response - Return `(nil, error)` to short-circuit with error + +**Header and Query Parameter Lookups**: Use the case-insensitive helper methods for reading headers and query parameters: + +```go +// āœ… Correct - use helper methods for case-insensitive lookup +contentType := req.CaseInsensitiveHeaderLookup("Content-Type") +apiKey := req.CaseInsensitiveQueryLookup("api_key") + +// Also works with any casing +contentType := req.CaseInsensitiveHeaderLookup("content-type") +contentType := req.CaseInsensitiveHeaderLookup("CONTENT-TYPE") + +// For setting headers, use direct map access +req.Headers["X-Custom-Header"] = "value" +``` + +The helper methods (`CaseInsensitiveHeaderLookup` and `CaseInsensitiveQueryLookup`) ensure your plugin works correctly regardless of how the client sends header/query parameter names. + + This function is **only called** when using `bifrost-http`. It's **not invoked** when using Bifrost as a Go SDK. diff --git a/docs/plugins/writing-plugin.mdx b/docs/plugins/writing-plugin.mdx deleted file mode 100644 index aa4c9e58be..0000000000 --- a/docs/plugins/writing-plugin.mdx +++ /dev/null @@ -1,845 +0,0 @@ ---- -title: "Writing Plugins" -description: "Step-by-step guide to creating custom plugins for Bifrost using the hello-world example" -icon: "code" ---- - -## Overview - -This guide walks you through creating a custom plugin for Bifrost using our [hello-world example](https://github.com/maximhq/bifrost/tree/main/examples/plugins/hello-world) as a reference. You'll learn how to structure your plugin, implement required functions, build the shared object, and integrate it with Bifrost. - -## Prerequisites - -Before you start, ensure you have: - -- **Go 1.25.5** installed (must match Bifrost's Go version) -- **Linux or macOS** (Go plugins are not supported on Windows) -- **Bifrost** installed and configured -- Basic understanding of Go programming - -Make sure your go.mod has the go version pinned to 1.25.5 - -## Project Structure - -A minimal plugin project should have the following structure: - -``` -hello-world/ -ā”œā”€ā”€ main.go # Plugin implementation -ā”œā”€ā”€ go.mod # Go module definition -ā”œā”€ā”€ go.sum # Dependency checksums -ā”œā”€ā”€ Makefile # Build automation -└── .gitignore # Git ignore patterns -``` - -## Step 1: Initialize Your Plugin Project - -Create a new directory and initialize a Go module: - -```bash -mkdir my-plugin -cd my-plugin -go mod init github.com/yourusername/my-plugin -``` - -Add Bifrost as a dependency: - -```bash -go get github.com/maximhq/bifrost/core@latest -``` - -Your `go.mod` should look like this: - -```go -module github.com/yourusername/my-plugin - -go 1.25.5 - -require github.com/maximhq/bifrost/core v1.2.38 -``` - -## Step 2: Implement the Plugin Interface - -Create `main.go` with the required plugin functions. Here's the complete hello-world example: - - - -```go -package main - -import ( - "fmt" - - "github.com/maximhq/bifrost/core/schemas" - "github.com/valyala/fasthttp" -) - -// Init is called when the plugin is loaded -// config contains the plugin configuration from config.json -func Init(config any) error { - fmt.Println("Init called") - // Initialize your plugin here (database connections, API clients, etc.) - return nil -} - -// GetName returns the plugin's unique identifier -func GetName() string { - return "Hello World Plugin" -} - -// HTTPTransportMiddleware returns a middleware for HTTP transport -// Only called when using HTTP transport (bifrost-http) -func HTTPTransportMiddleware() schemas.BifrostHTTPMiddleware { - return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { - return func(ctx *fasthttp.RequestCtx) { - fmt.Println("HTTPTransportMiddleware called") - // Modify request headers/body via ctx.Request before calling next - // Call next handler in the chain - next(ctx) - // Can also modify response via ctx.Response after next returns - } - } -} - -// PreHook is called before the request is sent to the provider -// This is where you can modify requests or short-circuit the flow -func PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { - fmt.Println("PreHook called") - // Modify the request or return a short-circuit to skip provider call - return req, nil, nil -} - -// PostHook is called after receiving a response from the provider -// This is where you can modify responses or handle errors -func PostHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - fmt.Println("PostHook called") - // Modify the response or error before returning to caller - return resp, bifrostErr, nil -} - -// Cleanup is called when Bifrost shuts down -func Cleanup() error { - fmt.Println("Cleanup called") - // Clean up resources (close connections, flush buffers, etc.) - return nil -} -``` - - -```go -package main - -import ( - "fmt" - - "github.com/maximhq/bifrost/core/schemas" -) - -// Init is called when the plugin is loaded -// config contains the plugin configuration from config.json -func Init(config any) error { - fmt.Println("Init called") - // Initialize your plugin here (database connections, API clients, etc.) - return nil -} - -// GetName returns the plugin's unique identifier -func GetName() string { - return "Hello World Plugin" -} - -// TransportInterceptor modifies raw HTTP headers and body -// Only called when using HTTP transport (bifrost-http) -func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - fmt.Println("TransportInterceptor called") - // Modify headers or body before they enter Bifrost core - return headers, body, nil -} - -// PreHook is called before the request is sent to the provider -// This is where you can modify requests or short-circuit the flow -func PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { - fmt.Println("PreHook called") - // Modify the request or return a short-circuit to skip provider call - return req, nil, nil -} - -// PostHook is called after receiving a response from the provider -// This is where you can modify responses or handle errors -func PostHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - fmt.Println("PostHook called") - // Modify the response or error before returning to caller - return resp, bifrostErr, nil -} - -// Cleanup is called when Bifrost shuts down -func Cleanup() error { - fmt.Println("Cleanup called") - // Clean up resources (close connections, flush buffers, etc.) - return nil -} -``` - - - -### Understanding Each Function - -#### `Init(config any) error` - -Called once when the plugin is loaded. Use this to: -- Parse plugin configuration -- Initialize database connections -- Set up API clients -- Validate required environment variables - -```go -func Init(config any) error { - // Parse configuration - cfg, ok := config.(map[string]interface{}) - if !ok { - return fmt.Errorf("invalid config format") - } - - apiKey := cfg["api_key"].(string) - // Initialize your resources - return nil -} -``` - -#### `GetName() string` - -Returns a unique identifier for your plugin. This name appears in logs and status reports. - - - -#### `HTTPTransportMiddleware()` - -**HTTP transport only.** Returns a middleware that wraps the HTTP request handler chain. Use this to: -- Intercept and modify requests before they enter Bifrost core -- Intercept and modify responses before they're returned to clients -- Implement authentication or logging at the transport layer -- Access raw `*fasthttp.RequestCtx` for full HTTP control - -The middleware pattern requires calling `next(ctx)` to pass control to subsequent handlers. - - -This function is **only called** when using `bifrost-http`. It's **not invoked** when using Bifrost as a Go SDK. - - - -#### `TransportInterceptor(...)` - -**HTTP transport only.** Called before requests enter Bifrost core. Use this to: -- Add or modify HTTP headers -- Transform request body -- Implement authentication at the transport layer - - -This function is **only called** when using `bifrost-http`. It's **not invoked** when using Bifrost as a Go SDK. - - - - -#### `PreHook(...)` - -Called before each provider request. Use this to: -- Modify request parameters -- Add logging or monitoring -- Implement caching (check cache, return cached response) -- Apply governance rules (rate limiting, budget checks) -- **Short-circuit** to skip provider calls - -**Short-Circuiting Example:** - -```go -func PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { - // Return cached response without calling provider - if cachedResponse := checkCache(req) { - return req, &schemas.PluginShortCircuit{ - Response: cachedResponse, - }, nil - } - return req, nil, nil -} -``` - -#### `PostHook(...)` - -Called after provider responses (or short-circuits). Use this to: -- Transform responses -- Log response data -- Store responses in cache -- Handle errors or implement fallback logic -- Add custom metadata - -**Response Transformation Example:** - -```go -func PostHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - if resp != nil && resp.ChatResponse != nil { - // Add custom metadata - resp.ChatResponse.ExtraFields.RawResponse = map[string]interface{}{ - "plugin_processed": true, - "timestamp": time.Now().Unix(), - } - } - return resp, bifrostErr, nil -} -``` - -#### `Cleanup() error` - -Called on Bifrost shutdown. Use this to: -- Close database connections -- Flush buffers -- Save state -- Release resources - -## Step 3: Create a Makefile - -Create a `Makefile` to automate building your plugin: - -```makefile -.PHONY: all build clean install help - -PLUGIN_NAME = my-plugin -OUTPUT_DIR = build - -# Platform detection -UNAME_S := $(shell uname -s) -ifeq ($(UNAME_S),Linux) - PLUGIN_EXT = .so - PLATFORM = linux -endif -ifeq ($(UNAME_S),Darwin) - PLUGIN_EXT = .so - PLATFORM = darwin -endif - -# Architecture detection -UNAME_M := $(shell uname -m) -ifeq ($(UNAME_M),x86_64) - ARCH = amd64 -endif -ifeq ($(UNAME_M),arm64) - ARCH = arm64 -endif - -OUTPUT = $(OUTPUT_DIR)/$(PLUGIN_NAME)$(PLUGIN_EXT) - -build: ## Build the plugin for current platform - @echo "Building plugin for $(PLATFORM)/$(ARCH)..." - @mkdir -p $(OUTPUT_DIR) - go build -buildmode=plugin -o $(OUTPUT) main.go - @echo "Plugin built successfully: $(OUTPUT)" - -clean: ## Remove build artifacts - @rm -rf $(OUTPUT_DIR) - -install: build ## Build and install to Bifrost plugins directory - @mkdir -p ~/.bifrost/plugins - @cp $(OUTPUT) ~/.bifrost/plugins/ - @echo "Plugin installed to ~/.bifrost/plugins/" -``` - -## Step 4: Build Your Plugin - -Build the plugin using the Makefile: - -```bash -make build -``` - -This creates `build/my-plugin.so` in your project directory. - -For production, you may need to build for specific platforms: - -```bash -# Build for Linux AMD64 -GOOS=linux GOARCH=amd64 go build -buildmode=plugin -o my-plugin-linux-amd64.so main.go - -# Build for Linux ARM64 -GOOS=linux GOARCH=arm64 go build -buildmode=plugin -o my-plugin-linux-arm64.so main.go - -# Build for macOS ARM64 (M1/M2) -GOOS=darwin GOARCH=arm64 go build -buildmode=plugin -o my-plugin-darwin-arm64.so main.go -``` - - -**Cross-compilation doesn't work for plugins!** You must build on the target platform. If you need a Linux plugin, build it on a Linux machine or use Docker. - - -## Step 5: Configure Bifrost to Load Your Plugin - -Add your plugin to Bifrost's `config.json`: - -```json -{ - "plugins": [ - { - "enabled": true, - "name": "my-plugin", - "path": "/path/to/my-plugin.so", - "version": 1, - "config": { - "api_key": "your-api-key", - "custom_setting": "value" - } - } - ] -} -``` - -### Plugin Configuration Options - -- `enabled` - Set to `true` to load the plugin -- `name` - Plugin identifier (used in logs) -- `path` - Absolute or relative path to the `.so` file -- `config` - Plugin-specific configuration passed to `Init()` -- `version` - (Optional) Plugin version number (default: 1). Increment this value to force a reload of the plugin and database update when Bifrost restarts. Useful when you want to ensure config changes take effect without manually clearing plugin state. - -## Step 6: Test Your Plugin - -Start Bifrost and verify your plugin loads: - -```bash -./bifrost-http -``` - -You should see output like: - -``` -Init called -[INFO] Plugin loaded: Hello World Plugin -``` - -Make a test request: - -```bash -curl -X POST http://localhost:8080/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "openai/gpt-4o-mini", - "messages": [{"role": "user", "content": "Hello!"}] - }' -``` - -Check the logs for plugin hook calls: - - - -``` -HTTPTransportMiddleware called -PreHook called -PostHook called -``` - - -``` -TransportInterceptor called -PreHook called -PostHook called -``` - - - -## Advanced Plugin Patterns - -### Stateful Plugins - -For plugins that need to maintain state across requests: - -```go -package main - -import ( - "sync" - "github.com/maximhq/bifrost/core/schemas" -) - -var ( - requestCount int64 - mu sync.Mutex -) - -func PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { - mu.Lock() - requestCount++ - count := requestCount - mu.Unlock() - - // Use count for rate limiting, metrics, etc. - return req, nil, nil -} -``` - -### Error Handling with Fallbacks - -Control whether Bifrost should try fallback providers: - -```go -func PostHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - if bifrostErr != nil { - // Allow fallbacks for rate limit errors - if bifrostErr.Error.Type != nil && *bifrostErr.Error.Type == "rate_limit" { - allowFallbacks := true - bifrostErr.AllowFallbacks = &allowFallbacks - } else { - // Don't try fallbacks for auth errors - allowFallbacks := false - bifrostErr.AllowFallbacks = &allowFallbacks - } - } - return resp, bifrostErr, nil -} -``` - -### Caching Plugin Example - -```go -var cache sync.Map - -func PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { - // Generate cache key from request - key := generateCacheKey(req) - - // Check cache - if cached, ok := cache.Load(key); ok { - return req, &schemas.PluginShortCircuit{ - Response: cached.(*schemas.BifrostResponse), - }, nil - } - - return req, nil, nil -} - -func PostHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - if resp != nil && bifrostErr == nil { - // Store in cache - key := generateCacheKeyFromResponse(resp) - cache.Store(key, resp) - } - return resp, bifrostErr, nil -} -``` - -## Troubleshooting - -### Plugin Fails to Load - -**Error**: `plugin: not a plugin file` - -**Solution**: Ensure you built with `-buildmode=plugin`: -```bash -go build -buildmode=plugin -o plugin.so main.go -``` - -### Version Mismatch Errors - -**Error**: `plugin was built with a different version of package` - -**Why this happens**: Go's plugin system requires **exact version matching** for: -- The Go compiler version -- **All shared packages** (especially `github.com/maximhq/bifrost/core`) -- **Transitive dependencies** (packages that your dependencies depend on) - -This is more strict than typical Go builds. Even if only one transitive dependency differs by a patch version, the plugin will fail to load. - -**Solution**: Ensure your plugin is built with the exact same versions as Bifrost. - -**Step 1: Diagnose the mismatch** - -Use `go version -m` to inspect the build info of both your plugin and the Bifrost binary: - -```bash -# Check what versions your plugin was built with: -$ go version -m my-plugin.so -my-plugin.so: go1.25.5 - dep github.com/maximhq/bifrost/core v1.3.50 - dep github.com/valyala/fasthttp v1.51.0 - -# Check what versions Bifrost was built with: -$ go version -m bifrost-http -bifrost-http: go1.25.5 - dep github.com/maximhq/bifrost/core v1.3.54 # <-- MISMATCH! - dep github.com/valyala/fasthttp v1.55.0 # <-- MISMATCH! -``` - -Notice that even though the Go version matches (`go1.25.5`), the **package versions** are different — this causes the error. - -**Step 2: Update your plugin dependencies** - -```bash -# Update to match Bifrost's core version -go get github.com/maximhq/bifrost/core@v1.3.54 -go mod tidy - -# Rebuild the plugin -go build -buildmode=plugin -o my-plugin.so main.go -``` - -**Step 3: Verify the fix** - -```bash -# Confirm versions now match -$ go version -m my-plugin.so | grep bifrost - dep github.com/maximhq/bifrost/core v1.3.54 # Now matches! -``` - - -**Pro tip**: Pin exact versions in your `go.mod` and keep your plugin's dependencies in sync with the Bifrost version you're deploying. Consider building both Bifrost and your plugins in the same CI pipeline to guarantee version alignment. - - -### Platform/Architecture Mismatch - -**Error**: `cannot load plugin built for GOOS=linux on darwin` - -**Solution**: Build on the target platform or use the correct GOOS/GOARCH for your system. - -### Function Not Found - -**Error**: `plugin: symbol Init not found` - -**Solution**: Ensure all required functions are exported (start with capital letter) and have the correct signature. - -## Source Code Reference - -The complete hello-world example is available in the Bifrost repository: - -- **Full Example**: [examples/plugins/hello-world](https://github.com/maximhq/bifrost/tree/main/examples/plugins/hello-world) -- **main.go**: [Plugin implementation](https://github.com/maximhq/bifrost/blob/main/examples/plugins/hello-world/main.go) -- **Makefile**: [Build configuration](https://github.com/maximhq/bifrost/blob/main/examples/plugins/hello-world/Makefile) -- **go.mod**: [Dependencies](https://github.com/maximhq/bifrost/blob/main/examples/plugins/hello-world/go.mod) - -## Real-World Plugin Examples - -Explore production-ready plugins in the Bifrost repository: - -- **[Mocker Plugin](https://github.com/maximhq/bifrost/tree/main/plugins/mocker)** - Mock responses for testing -- **[Logging Plugin](https://github.com/maximhq/bifrost/tree/main/plugins/logging)** - Advanced request/response logging -- **[Semantic Cache Plugin](https://github.com/maximhq/bifrost/tree/main/plugins/semanticcache)** - Cache based on semantic similarity -- **[Governance Plugin](https://github.com/maximhq/bifrost/tree/main/plugins/governance)** - Rate limiting and budget controls -- **[JSON Parser Plugin](https://github.com/maximhq/bifrost/tree/main/plugins/jsonparser)** - Parse and validate JSON responses - -## Frequently Asked Questions - -### Do I need to rebuild my plugin when upgrading Bifrost? - -**Yes, absolutely.** Plugins must be compiled against the exact same version of `github.com/maximhq/bifrost/core` that Bifrost is using. This is a fundamental requirement of Go's plugin system. - -When you upgrade Bifrost, you must: -1. Update your plugin's `go.mod` to use the matching core version -2. Rebuild the plugin with the same Go version -3. Redeploy the plugin alongside the new Bifrost version - -**Example:** - -If upgrading from Bifrost v1.2.17 to v1.3.0: - -```bash -# Update your plugin dependency -go get github.com/maximhq/bifrost/core@v1.3.0 -go mod tidy - -# Rebuild the plugin -go build -buildmode=plugin -o my-plugin.so main.go -``` - - -**Version mismatch will cause runtime errors!** If your plugin is compiled with v1.2.17 but Bifrost is running v1.3.0, the plugin will fail to load with cryptic errors about package versions. - - -### Should plugin builds be part of my deployment pipeline? - -**Yes, strongly recommended.** Your plugin build and deployment should be tightly coupled with your Bifrost deployment. - -**Recommended CI/CD Workflow:** - -```yaml -# Example GitHub Actions workflow -name: Deploy Bifrost with Plugins - -on: - push: - branches: [main] - -jobs: - deploy: - runs-on: ubuntu-latest - steps: - # 1. Checkout code - - uses: actions/checkout@v3 - - # 2. Setup Go - - uses: actions/setup-go@v4 - with: - go-version: '1.25.5' - - # 3. Build Bifrost - - name: Build Bifrost - run: | - cd transports/bifrost-http - go build -o bifrost-http - - # 4. Build ALL plugins with matching version - - name: Build Plugins - run: | - cd plugins/my-plugin - # Ensure plugin uses same core version as Bifrost - go get github.com/maximhq/bifrost/core@${{ env.BIFROST_VERSION }} - go mod tidy - go build -buildmode=plugin -o my-plugin.so main.go - - # 5. Bundle everything together - - name: Create deployment bundle - run: | - mkdir -p deploy/plugins - cp transports/bifrost-http/bifrost-http deploy/ - cp plugins/my-plugin/my-plugin.so deploy/plugins/ - cp config.json deploy/ - - # 6. Deploy bundle to your infrastructure - - name: Deploy to Production - run: | - # Upload to S3, copy to servers, deploy to K8s, etc. - ./deploy.sh -``` - -**Key Principles:** - -1. **Version Lock** - Pin your plugin dependencies to specific Bifrost versions -2. **Atomic Deployment** - Deploy Bifrost and plugins together as a single unit -3. **Build Verification** - Test plugin loading as part of CI -4. **Rollback Strategy** - Keep previous plugin versions for rollbacks - -### How do I handle plugin versioning in production? - -Organize your plugin deployments by version: - -``` -/opt/bifrost/ -ā”œā”€ā”€ v1.3.0/ -│ ā”œā”€ā”€ bifrost-http -│ └── plugins/ -│ ā”œā”€ā”€ my-plugin.so -│ └── cache-plugin.so -ā”œā”€ā”€ v1.2.17/ -│ ā”œā”€ā”€ bifrost-http -│ └── plugins/ -│ ā”œā”€ā”€ my-plugin.so -│ └── cache-plugin.so -└── current -> v1.3.0/ # Symlink to active version -``` - -This allows easy rollbacks: - -```bash -# Rollback to previous version -ln -sfn /opt/bifrost/v1.2.17 /opt/bifrost/current -systemctl restart bifrost -``` - -### Can I use different plugin versions for different Bifrost instances? - -**No.** Each plugin must match the exact core version of the Bifrost instance loading it. If you're running multiple Bifrost versions (e.g., staging vs production), you need separate plugin builds for each version. - -``` -staging/ - bifrost-http (v1.3.0) - plugins/ - my-plugin-v1.3.0.so - -production/ - bifrost-http (v1.2.17) - plugins/ - my-plugin-v1.2.17.so -``` - -### What happens if I forget to rebuild a plugin? - -You'll see errors like: - -``` -plugin: symbol Init not found in plugin github.com/you/plugin -plugin was built with a different version of package github.com/maximhq/bifrost/core -``` - -**Solution:** Rebuild the plugin with the correct core version. See the [Version Mismatch Errors](#version-mismatch-errors) troubleshooting section for detailed diagnosis steps using `go version -m`. - -### How do I test plugins before production deployment? - -**Multi-stage testing approach:** - -1. **Unit Tests** - Test plugin logic in isolation - ```go - func TestPreHook(t *testing.T) { - req := &schemas.BifrostRequest{...} - modifiedReq, shortCircuit, err := PreHook(&ctx, req) - assert.NoError(t, err) - assert.Nil(t, shortCircuit) - } - ``` - -2. **Integration Tests** - Load plugin in test Bifrost instance - ```bash - # Start test Bifrost with plugin - ./bifrost-http --config test-config.json - - # Run test requests - curl -X POST http://localhost:8080/v1/chat/completions ... - ``` - -3. **Staging Environment** - Deploy to staging with production-like load - -4. **Canary Deployment** - Gradually roll out to production - -### Can I hot-reload plugins without restarting Bifrost? - -**Yes!** Bifrost supports hot-reloading plugins at runtime. You can update plugin configurations or reload plugin code without restarting the entire Bifrost instance. - -### How do I debug plugin loading issues? - -**Enable verbose logging:** - -```json -{ - "log_level": "debug", - "plugins": [ - { - "enabled": true, - "name": "my-plugin", - "path": "./plugins/my-plugin.so", - "config": {} - } - ] -} -``` - -**Check plugin symbols:** - -```bash -# List symbols exported by plugin -go tool nm my-plugin.so | grep -E 'Init|GetName|PreHook' -``` - -**Verify Go version:** - -```bash -# Check Go version used to build plugin -go version -m my-plugin.so -``` - -**Common debugging steps:** - -1. Verify file exists and has correct permissions -2. Check Go version matches Bifrost -3. Confirm core package version matches -4. Ensure all required symbols are exported -5. Review Bifrost logs for detailed error messages - -## Need Help? - -- **Discord Community**: [Join our Discord](https://getmax.im/bifrost-discord) -- **GitHub Issues**: [Report bugs or request features](https://github.com/maximhq/bifrost/issues) -- **Documentation**: [Browse all docs](/) - diff --git a/docs/plugins/writing-wasm-plugin.mdx b/docs/plugins/writing-wasm-plugin.mdx index bad5ef9ffe..885bc11c3c 100644 --- a/docs/plugins/writing-wasm-plugin.mdx +++ b/docs/plugins/writing-wasm-plugin.mdx @@ -751,6 +751,12 @@ Output: `build/plugin.wasm` ### http_intercept + +**Header and Query Parameter Handling**: Headers and query parameters in `request.headers` and `request.query` preserve the original casing sent by the client. When looking up headers/query params, you should perform case-insensitive comparisons in your WASM plugin code to handle various casing (e.g., `Content-Type`, `content-type`, `CONTENT-TYPE`). + +For Go native plugins, use the built-in `CaseInsensitiveHeaderLookup()` and `CaseInsensitiveQueryLookup()` helper methods. + + **Input:** ```json { @@ -760,7 +766,7 @@ Output: `build/plugin.wasm` "request": { "method": "POST", "path": "/v1/chat/completions", - "headers": { "Content-Type": "application/json" }, + "headers": { "content-type": "application/json" }, "query": {}, "body": "" } diff --git a/docs/providers/supported-providers/vertex.mdx b/docs/providers/supported-providers/vertex.mdx index b2340490b1..f5ec59cb47 100644 --- a/docs/providers/supported-providers/vertex.mdx +++ b/docs/providers/supported-providers/vertex.mdx @@ -337,6 +337,111 @@ Lists models available in the specified project and region with metadata and dep } ``` +## Custom vs Non-Custom Models + + +**Important**: Vertex AI's List Models API **only returns custom fine-tuned models** that have been deployed to your project. It does NOT return standard foundation models (Gemini, Claude, etc.). + + +To provide a complete model listing experience, Bifrost performs **multi-pass model discovery**: + +### Three-Pass Model Discovery + +1. **First Pass - Custom Models from API Response** + - Queries Vertex AI's List Models API + - Returns only custom fine-tuned models deployed to your project + - Custom models are identified by having deployment values that contain only digits + - Example: `"deployment": "1234567890"` + +2. **Second Pass - Non-Custom Models from Deployments** + - Adds standard foundation models from your `deployments` configuration + - Non-custom models have alphanumeric deployment values (e.g., `gemini-pro`, `claude-3-5-sonnet`) + - Filters by `allowedModels` if specified + - Example: `"deployment": "gemini-2.0-flash"` + +3. **Third Pass - Allowed Models Not in Deployments** + - Adds models specified in `allowedModels` that weren't in the `deployments` map + - Ensures all explicitly allowed models appear in the list + - Uses the model name itself as the deployment value + - Skips digit-only model IDs (reserved for custom models) + +### Model Filtering Logic + +- **If `allowedModels` is empty**: All models from all three passes are included +- **If `allowedModels` is non-empty**: Only models/deployments with keys in `allowedModels` are included +- **Duplicate Prevention**: Each model ID is tracked to prevent duplicates across passes + +### Model Name Formatting + +Non-custom models from deployments and allowed models are automatically formatted for display: + +- `gemini-pro` → "Gemini Pro" +- `claude-3-5-sonnet` → "Claude 3 5 Sonnet" +- `gemini_2_flash` → "Gemini 2 Flash" + +Formatting uses title case and converts hyphens/underscores to spaces. + +### Example Configuration + + + + +```json +{ + "vertex_key_config": { + "project_id": "my-project", + "region": "us-central1", + "deployments": { + "my-gemini-ft": "1234567890", + "my-claude-ft": "9876543210" + } + } +} +``` + +This returns only your custom fine-tuned models from the API. + + + + +```json +{ + "vertex_key_config": { + "project_id": "my-project", + "region": "us-central1", + "deployments": { + "gemini-2.0-flash": "gemini-2.0-flash", + "claude-3-5-sonnet": "claude-3-5-sonnet-v2@20241022" + } + } +} +``` + +This returns both custom models AND foundation models from deployments. + + + + +```json +{ + "vertex_key_config": { + "project_id": "my-project", + "region": "us-central1", + "deployments": { + "gemini-2.0-flash": "gemini-2.0-flash", + "claude-3-5-sonnet": "claude-3-5-sonnet-v2@20241022", + "gemini-1.5-pro": "gemini-1.5-pro" + }, + "allowedModels": ["gemini-2.0-flash", "claude-3-5-sonnet"] + } +} +``` + +Only returns `gemini-2.0-flash` and `claude-3-5-sonnet`, excluding `gemini-1.5-pro`. + + + + ### Pagination Model listing is paginated automatically. If more than 100 models exist, `next_page_token` will be present. Bifrost handles pagination internally. @@ -387,6 +492,14 @@ Model listing is paginated automatically. If more than 100 models exist, `next_p **Code**: `embedding.go:84-87` + +**Severity**: High +**Behavior**: Vertex AI's List Models API only returns custom fine-tuned models, NOT foundation models +**Impact**: Bifrost performs three-pass discovery to include foundation models from deployments and allowedModels configuration +**Why**: This is a Vertex AI API limitation - foundation models must be explicitly configured +**Code**: `models.go:76-217` + + --- ## Configuration diff --git a/examples/configs/withvirtualkeys/config.json b/examples/configs/withvirtualkeys/config.json index f007ead383..45e69adfed 100644 --- a/examples/configs/withvirtualkeys/config.json +++ b/examples/configs/withvirtualkeys/config.json @@ -64,7 +64,7 @@ "weight": 0.5 } ], - "value": "sk-bf-vk-prod-assistant-us-01" + "value": "env.BIFROST_VK_PROD_ASSISTANT_US_01" }, { "id": "sk-bf-vk-prod-assistant-eu-01", diff --git a/examples/plugins/hello-world-wasm-go/Makefile b/examples/plugins/hello-world-wasm-go/Makefile new file mode 100644 index 0000000000..713f911bd3 --- /dev/null +++ b/examples/plugins/hello-world-wasm-go/Makefile @@ -0,0 +1,74 @@ +.PHONY: all build clean help check-tinygo + +# Colors +COLOR_RESET = \033[0m +COLOR_INFO = \033[36m +COLOR_SUCCESS = \033[32m +COLOR_WARNING = \033[33m +COLOR_ERROR = \033[31m +COLOR_BOLD = \033[1m + +# Plugin configuration +PLUGIN_NAME = hello-world +OUTPUT_DIR = build +OUTPUT = $(OUTPUT_DIR)/$(PLUGIN_NAME).wasm + +# TinyGo build flags +TINYGO_TARGET = wasi +TINYGO_SCHEDULER = none + +help: ## Show this help message + @echo '$(COLOR_BOLD)Hello World WASM Plugin$(COLOR_RESET)' + @echo '' + @echo '$(COLOR_BOLD)Usage:$(COLOR_RESET) make [target]' + @echo '' + @echo '$(COLOR_BOLD)Prerequisites:$(COLOR_RESET)' + @echo ' - TinyGo (https://tinygo.org/getting-started/install/)' + @echo ' macOS: brew install tinygo' + @echo ' Linux: See TinyGo installation docs' + @echo '' + @echo '$(COLOR_BOLD)Available targets:$(COLOR_RESET)' + @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf " $(COLOR_INFO)%-15s$(COLOR_RESET) %s\n", $$1, $$2}' $(MAKEFILE_LIST) + +check-tinygo: ## Check if TinyGo is installed + @which tinygo > /dev/null 2>&1 || (echo "$(COLOR_ERROR)Error: TinyGo is not installed$(COLOR_RESET)"; \ + echo "$(COLOR_INFO)Install TinyGo:$(COLOR_RESET)"; \ + echo " macOS: brew install tinygo"; \ + echo " Linux: See https://tinygo.org/getting-started/install/"; \ + exit 1) + @echo "$(COLOR_SUCCESS)āœ“ TinyGo found: $$(tinygo version)$(COLOR_RESET)" + +build: check-tinygo ## Build the WASM plugin + @mkdir -p $(OUTPUT_DIR) + @echo "$(COLOR_INFO)Building WASM plugin...$(COLOR_RESET)" + GOWORK=off tinygo build -o $(OUTPUT) -target=$(TINYGO_TARGET) -scheduler=$(TINYGO_SCHEDULER) . + @echo "$(COLOR_SUCCESS)āœ“ Plugin built successfully: $(OUTPUT)$(COLOR_RESET)" + @ls -lh $(OUTPUT) | awk '{print " Size: " $$5}' + +build-optimized: check-tinygo ## Build the WASM plugin with size optimizations + @mkdir -p $(OUTPUT_DIR) + @echo "$(COLOR_INFO)Building optimized WASM plugin...$(COLOR_RESET)" + GOWORK=off tinygo build -o $(OUTPUT) -target=$(TINYGO_TARGET) -scheduler=$(TINYGO_SCHEDULER) -no-debug -gc=leaking . + @echo "$(COLOR_SUCCESS)āœ“ Optimized plugin built: $(OUTPUT)$(COLOR_RESET)" + @ls -lh $(OUTPUT) | awk '{print " Size: " $$5}' + +clean: ## Remove build artifacts + @echo "$(COLOR_INFO)Cleaning build artifacts...$(COLOR_RESET)" + @rm -rf $(OUTPUT_DIR) + @echo "$(COLOR_SUCCESS)āœ“ Clean complete$(COLOR_RESET)" + +info: ## Show build information + @echo "$(COLOR_BOLD)Build Configuration$(COLOR_RESET)" + @echo " Plugin Name: $(PLUGIN_NAME)" + @echo " Output: $(OUTPUT)" + @echo " Target: $(TINYGO_TARGET)" + @echo " Scheduler: $(TINYGO_SCHEDULER)" + @echo "" + @if [ -f "$(OUTPUT)" ]; then \ + echo "$(COLOR_SUCCESS)Plugin exists:$(COLOR_RESET)"; \ + ls -lh $(OUTPUT) | awk '{print " " $$9 " (" $$5 ")"}'; \ + else \ + echo "$(COLOR_WARNING)Plugin not built yet$(COLOR_RESET)"; \ + fi + +.DEFAULT_GOAL := help diff --git a/examples/plugins/hello-world-wasm-go/README.md b/examples/plugins/hello-world-wasm-go/README.md new file mode 100644 index 0000000000..3b12233419 --- /dev/null +++ b/examples/plugins/hello-world-wasm-go/README.md @@ -0,0 +1,170 @@ +# Hello World WASM Plugin + +A minimal example of a Bifrost plugin written in Go and compiled to WebAssembly using TinyGo. + +## Prerequisites + +### TinyGo Installation + +TinyGo is required to compile Go code to WebAssembly with a small binary size. + +**macOS:** +```bash +brew install tinygo +``` + +**Linux (Ubuntu/Debian):** +```bash +wget https://github.com/tinygo-org/tinygo/releases/download/v0.32.0/tinygo_0.32.0_amd64.deb +sudo dpkg -i tinygo_0.32.0_amd64.deb +``` + +**Other platforms:** +See [TinyGo Installation Guide](https://tinygo.org/getting-started/install/) + +## Building + +```bash +# Build the WASM plugin +make build + +# Build with size optimizations +make build-optimized + +# Clean build artifacts +make clean +``` + +The compiled plugin will be at `build/hello-world.wasm`. + +## Plugin Structure + +WASM plugins must export the following functions: + +| Export | Signature | Description | +|--------|-----------|-------------| +| `plugin_malloc` | `(size: u32) -> u32` | Allocate memory for host to write data (or `malloc` for non-TinyGo) | +| `plugin_free` | `(ptr: u32)` | Free allocated memory (optional, or `free` for non-TinyGo) | +| `get_name` | `() -> u64` | Returns packed ptr+len of plugin name | +| `http_transport_intercept` | `(ctx_ptr, ctx_len, req_ptr, req_len: u32) -> u64` | HTTP transport intercept | +| `pre_hook` | `(ctx_ptr, ctx_len, req_ptr, req_len: u32) -> u64` | Pre-request hook | +| `post_hook` | `(ctx_ptr, ctx_len, resp_ptr, resp_len, err_ptr, err_len: u32) -> u64` | Post-response hook | +| `cleanup` | `() -> i32` | Cleanup resources (0 = success) | +| `init` | `(config_ptr, config_len: u32) -> i32` | Initialize with config (optional) | + +### Return Value Format + +Functions returning data use a packed `u64` format: +- Upper 32 bits: pointer to data in WASM memory +- Lower 32 bits: length of data + +### Data Exchange + +All complex data is exchanged as JSON: + +**HTTPTransportIntercept Input:** +- `ctx`: `{"request_id": "..."}` (context info) +- `req`: HTTP request JSON +```json +{ + "method": "POST", + "path": "/v1/chat/completions", + "headers": {"Content-Type": "application/json"}, + "query": {}, + "body": "base64-encoded-body" +} +``` + +**HTTPTransportIntercept Output:** +```json +{ + "response": null, + "error": "" +} +``` +To short-circuit, return a response: +```json +{ + "response": { + "status_code": 401, + "headers": {"Content-Type": "application/json"}, + "body": "base64-encoded-body" + }, + "error": "" +} +``` + +**PreHook Input:** +- `ctx`: `{"request_id": "..."}` (context info) +- `req`: Bifrost request JSON + +**PreHook Output:** +```json +{ + "request": { ... }, + "short_circuit": null, + "error": "" +} +``` + +**PostHook Input:** +- `ctx`: Context JSON +- `resp`: Bifrost response JSON +- `err`: Bifrost error JSON (or null) + +**PostHook Output:** +```json +{ + "response": { ... }, + "bifrost_error": null, + "error": "" +} +``` + +## Usage with Bifrost + +Configure the plugin in your Bifrost config: + +```json +{ + "plugins": [ + { + "path": "/path/to/hello-world.wasm", + "name": "hello-world-wasm", + "enabled": true + } + ] +} +``` + +Or load from URL: + +```json +{ + "plugins": [ + { + "path": "https://example.com/plugins/hello-world.wasm", + "name": "hello-world-wasm", + "enabled": true + } + ] +} +``` + +## Limitations + +WASM plugins have some limitations compared to native `.so` plugins: + +1. **Performance**: JSON serialization/deserialization adds overhead compared to native plugins. + +2. **Memory**: WASM modules have a linear memory model with limited addressing. + +3. **TinyGo Constraints**: Some Go standard library features are not available in TinyGo. + +## Benefits + +1. **Cross-platform**: Single `.wasm` binary runs on any OS/architecture +2. **Security**: WASM provides sandboxed execution +3. **No CGO**: Pure Go compilation, no C dependencies needed on the host +4. **Portability**: Easy to distribute and deploy +5. **Full feature parity**: HTTP transport intercept, PreHook, and PostHook all supported \ No newline at end of file diff --git a/examples/plugins/hello-world-wasm-go/go.mod b/examples/plugins/hello-world-wasm-go/go.mod new file mode 100644 index 0000000000..64a44e2780 --- /dev/null +++ b/examples/plugins/hello-world-wasm-go/go.mod @@ -0,0 +1,32 @@ +module github.com/maximhq/bifrost/examples/plugins/hello-world-wasm + +go 1.25.5 + +require github.com/maximhq/bifrost/core v0.0.0-00010101000000-000000000000 + +replace github.com/maximhq/bifrost/core => ../../../core + +require ( + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.14.2 // indirect + github.com/bytedance/sonic/loader v0.4.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/klauspost/compress v1.18.2 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.1 // indirect + github.com/mark3labs/mcp-go v0.43.2 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.68.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/arch v0.23.0 // indirect + golang.org/x/sys v0.39.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/examples/plugins/hello-world-wasm-go/go.sum b/examples/plugins/hello-world-wasm-go/go.sum new file mode 100644 index 0000000000..fee36f9db2 --- /dev/null +++ b/examples/plugins/hello-world-wasm-go/go.sum @@ -0,0 +1,78 @@ +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPIIE= +github.com/bytedance/sonic v1.14.2/go.mod h1:T80iDELeHiHKSc0C9tubFygiuXoGzrkjKzX2quAx980= +github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2NYzevs+o= +github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= +github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= +github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= +github.com/maximhq/bifrost/core v1.3.3 h1:r2llMAfzIHeSxwY2L55UaSOsY17JSg5zYcqF2JtaRVY= +github.com/maximhq/bifrost/core v1.3.3/go.mod h1:abKQRnJQPZz8/UMxCcbuNHEyq19Db+IX4KlGJdlLY8E= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.68.0 h1:v12Nx16iepr8r9ySOwqI+5RBJ/DqTxhOy1HrHoDFnok= +github.com/valyala/fasthttp v1.68.0/go.mod h1:5EXiRfYQAoiO/khu4oU9VISC/eVY6JqmSpPJoHCKsz4= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/arch v0.23.0 h1:lKF64A2jF6Zd8L0knGltUnegD62JMFBiCPBmQpToHhg= +golang.org/x/arch v0.23.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/plugins/hello-world-wasm-go/main.go b/examples/plugins/hello-world-wasm-go/main.go new file mode 100644 index 0000000000..3f0027d013 --- /dev/null +++ b/examples/plugins/hello-world-wasm-go/main.go @@ -0,0 +1,180 @@ +// Package main provides a hello-world WASM plugin example for Bifrost. +// This plugin demonstrates the basic structure and exports required for WASM plugins. +// +// Build with TinyGo: +// +// tinygo build -o build/hello-world.wasm -target=wasi -scheduler=none main.go +package main + +import ( + "encoding/json" +) + +// ============================================================================ +// Plugin Exports +// ============================================================================ + +//export get_name +func get_name() uint64 { + return writeBytes([]byte("Hello World WASM Plugin")) +} + +//export init +func init_plugin(configPtr, configLen uint32) int32 { + println("WASM Plugin: Init called") + if configLen > 0 { + configData := readInput(configPtr, configLen) + println("WASM Plugin: Config received:", string(configData)) + } + return 0 +} + +//export http_intercept +func http_intercept(inputPtr, inputLen uint32) uint64 { + println("WASM Plugin: http_intercept called") + + inputData := readInput(inputPtr, inputLen) + if inputData == nil { + return writeError("no input data") + } + + // Parse input + var input HTTPInterceptInput + if err := json.Unmarshal(inputData, &input); err != nil { + println("WASM Plugin: parse error:", err.Error()) + return writeError("parse error: " + err.Error()) + } + + // Log parsed data + println("WASM Plugin: HTTP", input.Request.Method, input.Request.Path) + if ct, ok := input.Request.Headers["content-type"]; ok { + println("WASM Plugin: Content-Type:", ct) + } + input.Context["from-http"] = "123" + // Return pass-through + output := HTTPInterceptOutput{ + Context: input.Context, + Request: input.Request, + HasResponse: false, + Error: "", + } + + data, _ := json.Marshal(output) + return writeBytes(data) +} + +//export pre_hook +func pre_hook(inputPtr, inputLen uint32) uint64 { + println("WASM Plugin: pre_hook called") + + inputData := readInput(inputPtr, inputLen) + if inputData == nil { + return writePreHookError("no input data") + } + + println("WASM Plugin: Pre-hook input:", string(inputData)) + + // Parse input + var input PreHookInput + if err := json.Unmarshal(inputData, &input); err != nil { + println("WASM Plugin: parse error:", err.Error()) + return writePreHookError("parse error: " + err.Error()) + } + + // Print existing context + for k, v := range input.Context { + println("WASM Plugin: Context", k, "=", v) + } + + input.Context["from-pre-hook"] = "789" + + // Return with custom context value + output := PreHookOutput{ + Context: input.Context, + Request: input.Request, + HasShortCircuit: false, + Error: "", + } + + data, _ := json.Marshal(output) + return writeBytes(data) +} + +//export post_hook +func post_hook(inputPtr, inputLen uint32) uint64 { + println("WASM Plugin: post_hook called") + + inputData := readInput(inputPtr, inputLen) + if inputData == nil { + return writePostHookError("no input data") + } + + // Parse input + var input PostHookInput + if err := json.Unmarshal(inputData, &input); err != nil { + println("WASM Plugin: parse error:", err.Error()) + return writePostHookError("parse error: " + err.Error()) + } + + println("WASM Plugin: Post-hook input:", string(inputData)) + // Print existing context + for k, v := range input.Context { + println("WASM Plugin: Context", k, "=", v) + } + + // Parse response for logging + + if processed, ok := input.Context["wasm_plugin_processed"].(bool); ok && processed { + println("WASM Plugin: Pre-hook context value present") + } + + input.Context["from-post-hook"] = "456" + // Return pass-through + output := PostHookOutput{ + Context: input.Context, + Response: input.Response, + Error: input.Error, + HasError: false, + HookError: "", + } + + data, _ := json.Marshal(output) + return writeBytes(data) +} + +//export cleanup +func cleanup() int32 { + println("WASM Plugin: Cleanup called") + return 0 +} + +// Helper functions for error responses +func writeError(msg string) uint64 { + output := HTTPInterceptOutput{HasResponse: false, Error: msg} + data, _ := json.Marshal(output) + return writeBytes(data) +} + +func writePreHookError(msg string) uint64 { + output := PreHookOutput{ + Context: map[string]interface{}{}, + Request: nil, + HasShortCircuit: false, + Error: msg, + } + data, _ := json.Marshal(output) + return writeBytes(data) +} + +func writePostHookError(msg string) uint64 { + output := PostHookOutput{ + Context: map[string]interface{}{}, + Response: nil, + HasError: false, + HookError: msg, + } + data, _ := json.Marshal(output) + return writeBytes(data) +} + +func main() {} diff --git a/examples/plugins/hello-world-wasm-go/memory.go b/examples/plugins/hello-world-wasm-go/memory.go new file mode 100644 index 0000000000..63b8fa3af1 --- /dev/null +++ b/examples/plugins/hello-world-wasm-go/memory.go @@ -0,0 +1,89 @@ +package main + +import "unsafe" + +// ============================================================================ +// Memory Management +// ============================================================================ + +// heapSize is the fixed size of the pre-allocated heap. +// This must be large enough to handle all allocations during the plugin lifetime. +// The heap is never reallocated to ensure all pointers remain valid. +const heapSize = 4 * 1024 * 1024 // 4MB fixed heap + +// heapBase is a fixed-size buffer that is never reallocated. +// All allocations come from this buffer to ensure pointer stability. +var heapBase []byte + +// heapOffset tracks the next available position in heapBase. +var heapOffset uint32 = 0 + +// heapBasePtr caches the base pointer of heapBase for efficient offset-to-pointer conversion. +var heapBasePtr uintptr + +func init() { + // Pre-allocate the fixed heap once at startup. + // This ensures heapBase is never reallocated after pointers are handed out. + heapBase = make([]byte, heapSize) + heapBasePtr = uintptr(unsafe.Pointer(&heapBase[0])) +} + +//export plugin_malloc +func plugin_malloc(size uint32) uint32 { + if size == 0 { + return 0 + } + // Align to 8-byte boundary + alignedSize := (size + 7) &^ 7 + // Check if we have enough space (no reallocation allowed) + if heapOffset+alignedSize > uint32(len(heapBase)) { + // Allocation failure - heap exhausted + // Return 0 to indicate failure rather than reallocating + return 0 + } + // Return pointer to the allocated region + ptr := uint32(heapBasePtr + uintptr(heapOffset)) + heapOffset += alignedSize + return ptr +} + +//export plugin_free +func plugin_free(ptr uint32) { + // No-op: we use a simple bump allocator without individual frees. + // Memory is reclaimed when the plugin is unloaded. +} + +// plugin_reset resets the heap allocator, allowing memory to be reused. +// This should only be called when no allocated memory is in use. +// +//export plugin_reset +func plugin_reset() { + heapOffset = 0 +} + +func packResult(ptr uint32, length uint32) uint64 { + return (uint64(ptr) << 32) | uint64(length) +} + +func writeBytes(data []byte) uint64 { + if len(data) == 0 { + return 0 + } + // Allocate from the stable heap + ptr := plugin_malloc(uint32(len(data))) + if ptr == 0 { + // Allocation failed + return 0 + } + // Copy data into the allocated region + offset := ptr - uint32(heapBasePtr) + copy(heapBase[offset:offset+uint32(len(data))], data) + return packResult(ptr, uint32(len(data))) +} + +func readInput(ptr, length uint32) []byte { + if length == 0 { + return nil + } + return unsafe.Slice((*byte)(unsafe.Pointer(uintptr(ptr))), length) +} diff --git a/examples/plugins/hello-world-wasm-go/types.go b/examples/plugins/hello-world-wasm-go/types.go new file mode 100644 index 0000000000..6333b802e6 --- /dev/null +++ b/examples/plugins/hello-world-wasm-go/types.go @@ -0,0 +1,54 @@ +package main + +import "github.com/maximhq/bifrost/core/schemas" + +// ============================================================================ +// Input/Output Structs +// ============================================================================ + +// HTTPInterceptInput is the input for http_intercept +type HTTPInterceptInput struct { + Context map[string]interface{} `json:"context"` + Request *schemas.HTTPRequest `json:"request,omitempty"` +} + +// HTTPInterceptOutput is the output for http_intercept +type HTTPInterceptOutput struct { + Context map[string]interface{} `json:"context"` + Request *schemas.HTTPRequest `json:"request,omitempty"` + Response *schemas.HTTPResponse `json:"response,omitempty"` + HasResponse bool `json:"has_response"` + Error string `json:"error"` +} + +// PreHookInput is the input for pre_hook +type PreHookInput struct { + Context map[string]interface{} `json:"context"` + Request *schemas.BifrostRequest `json:"request,omitempty"` // Keep raw for pass-through +} + +// PreHookOutput is the output for pre_hook +type PreHookOutput struct { + Context map[string]interface{} `json:"context"` + Request *schemas.BifrostRequest `json:"request,omitempty"` + ShortCircuit *schemas.PluginShortCircuit `json:"short_circuit,omitempty"` + HasShortCircuit bool `json:"has_short_circuit"` + Error string `json:"error"` +} + +// PostHookInput is the input for post_hook +type PostHookInput struct { + Context map[string]interface{} `json:"context"` + Response *schemas.BifrostResponse `json:"response,omitempty"` + Error *schemas.BifrostError `json:"error,omitempty"` + HasError bool `json:"has_error"` +} + +// PostHookOutput is the output for post_hook +type PostHookOutput struct { + Context map[string]interface{} `json:"context"` + Response *schemas.BifrostResponse `json:"response,omitempty"` + Error *schemas.BifrostError `json:"error,omitempty"` + HasError bool `json:"has_error"` + HookError string `json:"hook_error"` +} diff --git a/examples/plugins/hello-world-wasm-rust/Cargo.lock b/examples/plugins/hello-world-wasm-rust/Cargo.lock new file mode 100644 index 0000000000..1d3f17998b --- /dev/null +++ b/examples/plugins/hello-world-wasm-rust/Cargo.lock @@ -0,0 +1,107 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "hello-world-wasm-rust" +version = "0.1.0" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "proc-macro2" +version = "1.0.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "535d180e0ecab6268a3e718bb9fd44db66bbbc256257165fc699dadf70d16fe7" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc74d9a594b72ae6656596548f56f667211f8a97b3d4c3d467150794690dc40a" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "syn" +version = "2.0.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "zmij" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fc5a66a20078bf1251bde995aa2fdcc4b800c70b5d92dd2c62abc5c60f679f8" diff --git a/examples/plugins/hello-world-wasm-rust/Cargo.toml b/examples/plugins/hello-world-wasm-rust/Cargo.toml new file mode 100644 index 0000000000..1b97bd30ab --- /dev/null +++ b/examples/plugins/hello-world-wasm-rust/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "hello-world-wasm-rust" +version = "0.1.0" +edition = "2021" +description = "A minimal Bifrost WASM plugin example in Rust" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" + +[profile.release] +opt-level = "s" +lto = true +strip = true +panic = "abort" diff --git a/examples/plugins/hello-world-wasm-rust/Makefile b/examples/plugins/hello-world-wasm-rust/Makefile new file mode 100644 index 0000000000..152dd8a39d --- /dev/null +++ b/examples/plugins/hello-world-wasm-rust/Makefile @@ -0,0 +1,80 @@ +.PHONY: all build build-optimized clean help check-rust + +# Colors +COLOR_RESET = \033[0m +COLOR_INFO = \033[36m +COLOR_SUCCESS = \033[32m +COLOR_WARNING = \033[33m +COLOR_ERROR = \033[31m +COLOR_BOLD = \033[1m + +# Plugin configuration +PLUGIN_NAME = hello-world +OUTPUT_DIR = build +OUTPUT = $(OUTPUT_DIR)/$(PLUGIN_NAME).wasm +TARGET = wasm32-unknown-unknown + +help: ## Show this help message + @echo '$(COLOR_BOLD)Hello World WASM Plugin (Rust)$(COLOR_RESET)' + @echo '' + @echo '$(COLOR_BOLD)Usage:$(COLOR_RESET) make [target]' + @echo '' + @echo '$(COLOR_BOLD)Prerequisites:$(COLOR_RESET)' + @echo ' - Rust with wasm32-unknown-unknown target' + @echo ' rustup target add wasm32-unknown-unknown' + @echo '' + @echo '$(COLOR_BOLD)Available targets:$(COLOR_RESET)' + @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf " $(COLOR_INFO)%-15s$(COLOR_RESET) %s\n", $$1, $$2}' $(MAKEFILE_LIST) + +check-rust: ## Check if Rust and WASM target are installed + @which cargo > /dev/null 2>&1 || (echo "$(COLOR_ERROR)Error: Rust/Cargo is not installed$(COLOR_RESET)"; \ + echo "$(COLOR_INFO)Install Rust: https://rustup.rs/$(COLOR_RESET)"; \ + exit 1) + @rustup target list --installed | grep -q $(TARGET) || (echo "$(COLOR_ERROR)Error: WASM target not installed$(COLOR_RESET)"; \ + echo "$(COLOR_INFO)Install with: rustup target add $(TARGET)$(COLOR_RESET)"; \ + exit 1) + @echo "$(COLOR_SUCCESS)āœ“ Rust found: $$(rustc --version)$(COLOR_RESET)" + @echo "$(COLOR_SUCCESS)āœ“ WASM target: $(TARGET)$(COLOR_RESET)" + +build: check-rust ## Build the WASM plugin + @mkdir -p $(OUTPUT_DIR) + @echo "$(COLOR_INFO)Building WASM plugin...$(COLOR_RESET)" + cargo build --release --target $(TARGET) + @cp target/$(TARGET)/release/hello_world_wasm_rust.wasm $(OUTPUT) + @echo "$(COLOR_SUCCESS)āœ“ Plugin built successfully: $(OUTPUT)$(COLOR_RESET)" + @ls -lh $(OUTPUT) | awk '{print " Size: " $$5}' + +build-optimized: check-rust ## Build with wasm-opt optimization (requires wasm-opt) + @mkdir -p $(OUTPUT_DIR) + @echo "$(COLOR_INFO)Building optimized WASM plugin...$(COLOR_RESET)" + cargo build --release --target $(TARGET) + @cp target/$(TARGET)/release/hello_world_wasm_rust.wasm $(OUTPUT) + @if which wasm-opt > /dev/null 2>&1; then \ + echo "$(COLOR_INFO)Running wasm-opt...$(COLOR_RESET)"; \ + wasm-opt -Os -o $(OUTPUT) $(OUTPUT); \ + else \ + echo "$(COLOR_WARNING)wasm-opt not found, skipping optimization$(COLOR_RESET)"; \ + fi + @echo "$(COLOR_SUCCESS)āœ“ Plugin built: $(OUTPUT)$(COLOR_RESET)" + @ls -lh $(OUTPUT) | awk '{print " Size: " $$5}' + +clean: ## Remove build artifacts + @echo "$(COLOR_INFO)Cleaning build artifacts...$(COLOR_RESET)" + @cargo clean + @rm -rf $(OUTPUT_DIR) + @echo "$(COLOR_SUCCESS)āœ“ Clean complete$(COLOR_RESET)" + +info: ## Show build information + @echo "$(COLOR_BOLD)Build Configuration$(COLOR_RESET)" + @echo " Plugin Name: $(PLUGIN_NAME)" + @echo " Output: $(OUTPUT)" + @echo " Target: $(TARGET)" + @echo "" + @if [ -f "$(OUTPUT)" ]; then \ + echo "$(COLOR_SUCCESS)Plugin exists:$(COLOR_RESET)"; \ + ls -lh $(OUTPUT) | awk '{print " " $$9 " (" $$5 ")"}'; \ + else \ + echo "$(COLOR_WARNING)Plugin not built yet$(COLOR_RESET)"; \ + fi + +.DEFAULT_GOAL := help diff --git a/examples/plugins/hello-world-wasm-rust/README.md b/examples/plugins/hello-world-wasm-rust/README.md new file mode 100644 index 0000000000..625794c422 --- /dev/null +++ b/examples/plugins/hello-world-wasm-rust/README.md @@ -0,0 +1,528 @@ +# Bifrost WASM Plugin (Rust) + +A comprehensive example of a Bifrost plugin written in Rust and compiled to WebAssembly. This plugin demonstrates proper structure definitions with serde, JSON parsing, context handling, and request/response modification patterns. + +## Prerequisites + +### Rust Installation + +Install Rust from [rustup.rs](https://rustup.rs/) and add the WASM target: + +```bash +# Install Rust (if not already installed) +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh + +# Add WASM target +rustup target add wasm32-unknown-unknown +``` + +### Optional: wasm-opt + +For smaller binaries, install `wasm-opt` from [binaryen](https://github.com/WebAssembly/binaryen): + +```bash +# macOS +brew install binaryen + +# Linux +apt install binaryen +``` + +## Building + +```bash +# Build the WASM plugin +make build + +# Build with wasm-opt optimization +make build-optimized + +# Clean build artifacts +make clean +``` + +The compiled plugin will be at `build/hello-world.wasm`. + +## File Structure + +``` +src/ +ā”œā”€ā”€ lib.rs # Plugin implementation (hooks) +ā”œā”€ā”€ memory.rs # Memory management utilities +└── types.rs # Type definitions (mirrors Go SDK) +``` + +## Plugin Structure + +WASM plugins must export the following functions: + +| Export | Signature | Description | +|--------|-----------|-------------| +| `malloc` | `(size: u32) -> u32` | Allocate memory for host to write data | +| `free` | `(ptr: u32, size: u32)` | Free allocated memory | +| `get_name` | `() -> u64` | Returns packed ptr+len of plugin name | +| `init` | `(config_ptr, config_len: u32) -> i32` | Initialize with config (optional) | +| `http_intercept` | `(input_ptr, input_len: u32) -> u64` | HTTP transport intercept | +| `pre_hook` | `(input_ptr, input_len: u32) -> u64` | Pre-request hook | +| `post_hook` | `(input_ptr, input_len: u32) -> u64` | Post-response hook | +| `cleanup` | `() -> i32` | Cleanup resources (0 = success) | + +### Return Value Format + +Functions returning data use a packed `u64` format: +- Upper 32 bits: pointer to data in WASM memory +- Lower 32 bits: length of data + +## Data Structures + +This plugin uses `serde` with derive macros for JSON serialization. All structures mirror the Go SDK types: + +### Context + +```rust +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostContext { + pub request_id: Option, + + // Custom values via HashMap + #[serde(flatten)] + pub values: HashMap, +} + +impl BifrostContext { + pub fn set_value(&mut self, key: &str, value: impl Into); + pub fn get_string(&self, key: &str) -> Option<&str>; + pub fn get_bool(&self, key: &str) -> Option; +} +``` + +### HTTP Transport Types + +```rust +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HTTPRequest { + pub method: String, + pub path: String, + pub headers: HashMap, + pub query: HashMap, + pub body: String, // base64 encoded +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HTTPResponse { + pub status_code: i32, + pub headers: HashMap, + pub body: String, // base64 encoded +} +``` + +### Chat Completion Types + +```rust +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ChatMessageRole { + User, + Assistant, + System, + Tool, + Developer, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ChatMessageContent { + Text(String), + Blocks(Vec), +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatMessage { + pub role: ChatMessageRole, + pub content: Option, + pub name: Option, + pub tool_call_id: Option, + pub tool_calls: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatParameters { + pub temperature: Option, + pub max_completion_tokens: Option, + pub top_p: Option, + pub frequency_penalty: Option, + pub presence_penalty: Option, + pub stop: Option>, + pub tools: Option>, + + #[serde(flatten)] + pub extra: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostChatRequest { + pub provider: String, + pub model: String, + pub input: Vec, + pub params: Option, + pub fallbacks: Option>, +} +``` + +### Response Types + +```rust +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct LLMUsage { + pub prompt_tokens: i32, + pub completion_tokens: i32, + pub total_tokens: i32, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ResponseChoice { + pub index: i32, + pub message: Option, + pub delta: Option, + pub finish_reason: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostChatResponse { + pub id: String, + pub model: String, + pub choices: Vec, + pub usage: Option, + pub created: Option, + pub object: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostResponse { + pub chat_response: Option, +} +``` + +### Error Types + +```rust +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ErrorField { + pub message: String, + #[serde(rename = "type")] + pub error_type: Option, + pub code: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostError { + pub error: ErrorField, + pub status_code: Option, + pub allow_fallbacks: Option, +} + +impl BifrostError { + pub fn new(message: &str) -> Self; + pub fn with_type(self, error_type: &str) -> Self; + pub fn with_code(self, code: &str) -> Self; + pub fn with_status(self, status: i32) -> Self; +} +``` + +### Short Circuit + +```rust +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PluginShortCircuit { + pub response: Option, + pub error: Option, +} +``` + +## Hook Input/Output Structures + +### http_intercept + +**Input:** +```json +{ + "context": { "request_id": "abc-123" }, + "request": { + "method": "POST", + "path": "/v1/chat/completions", + "headers": { "Content-Type": "application/json" }, + "query": {}, + "body": "" + } +} +``` + +**Output:** +```json +{ + "context": { "request_id": "abc-123" }, + "request": {}, + "response": { "status_code": 200, "headers": {}, "body": "" }, + "has_response": false, + "error": "" +} +``` + +### pre_hook + +**Input:** +```json +{ + "context": { "request_id": "abc-123" }, + "request": { + "provider": "openai", + "model": "gpt-4", + "input": [{ "role": "user", "content": "Hello" }], + "params": { "temperature": 0.7 } + } +} +``` + +**Output:** +```json +{ + "context": { "request_id": "abc-123", "plugin_processed": true }, + "request": {}, + "short_circuit": { + "response": { "chat_response": { ... } } + }, + "has_short_circuit": false, + "error": "" +} +``` + +### post_hook + +**Input:** +```json +{ + "context": { "request_id": "abc-123", "plugin_processed": true }, + "response": { + "chat_response": { + "id": "chatcmpl-123", + "model": "gpt-4", + "choices": [{ "index": 0, "message": { "role": "assistant", "content": "Hi!" } }], + "usage": { "prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15 } + } + }, + "error": {}, + "has_error": false +} +``` + +**Output:** +```json +{ + "context": { "request_id": "abc-123", "post_hook_completed": true }, + "response": {}, + "error": {}, + "has_error": false, + "hook_error": "" +} +``` + +## Usage Examples + +### Modifying Context + +```rust +#[no_mangle] +pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 { + let input_str = read_string(input_ptr, input_len); + let input: PreHookInput = serde_json::from_str(&input_str).unwrap(); + + let mut output = PreHookOutput { + context: input.context.clone(), + ..Default::default() + }; + + // Add custom values to context + output.context.set_value("plugin_processed", serde_json::json!(true)); + output.context.set_value("plugin_name", serde_json::json!("my-rust-plugin")); + + write_string(&serde_json::to_string(&output).unwrap()) +} +``` + +### Short-Circuit with Mock Response + +```rust +#[no_mangle] +pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 { + let input_str = read_string(input_ptr, input_len); + let input: PreHookInput = serde_json::from_str(&input_str).unwrap(); + + let (provider, model) = input.get_provider_model(); + + // Check if this should be mocked + if model == "mock-model" { + let mut output = PreHookOutput { + context: input.context.clone(), + has_short_circuit: true, + ..Default::default() + }; + + // Build mock response + let mock_response = BifrostResponse { + chat_response: Some(BifrostChatResponse { + id: format!("mock-{}", input.context.request_id.unwrap_or_default()), + model: "mock-model".to_string(), + choices: vec![ResponseChoice { + index: 0, + message: Some(ChatMessage { + role: ChatMessageRole::Assistant, + content: Some(ChatMessageContent::Text( + "This is a mock response!".to_string() + )), + ..Default::default() + }), + finish_reason: Some("stop".to_string()), + ..Default::default() + }], + usage: Some(LLMUsage { + prompt_tokens: 10, + completion_tokens: 15, + total_tokens: 25, + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + + output.short_circuit = Some(PluginShortCircuit { + response: Some(mock_response), + error: None, + }); + + return write_string(&serde_json::to_string(&output).unwrap()); + } + + // Pass through + let output = PreHookOutput { + context: input.context, + ..Default::default() + }; + write_string(&serde_json::to_string(&output).unwrap()) +} +``` + +### Short-Circuit with Error + +```rust +#[no_mangle] +pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 { + let input_str = read_string(input_ptr, input_len); + let input: PreHookInput = serde_json::from_str(&input_str).unwrap(); + + // Check rate limit (example) + if should_rate_limit(&input.context) { + let mut output = PreHookOutput { + context: input.context.clone(), + has_short_circuit: true, + ..Default::default() + }; + + output.short_circuit = Some(PluginShortCircuit { + response: None, + error: Some( + BifrostError::new("Rate limit exceeded") + .with_type("rate_limit") + .with_code("429") + .with_status(429) + ), + }); + + return write_string(&serde_json::to_string(&output).unwrap()); + } + + // Pass through + let output = PreHookOutput { + context: input.context, + ..Default::default() + }; + write_string(&serde_json::to_string(&output).unwrap()) +} +``` + +### Modifying Responses in post_hook + +```rust +#[no_mangle] +pub extern "C" fn post_hook(input_ptr: u32, input_len: u32) -> u64 { + let input_str = read_string(input_ptr, input_len); + let input: PostHookInput = serde_json::from_str(&input_str).unwrap(); + + let mut output = PostHookOutput { + context: input.context.clone(), + ..Default::default() + }; + + // Handle errors + if input.has_error { + output.has_error = true; + output.error = input.error.clone(); + + // Optionally modify the error + if let Some(mut error) = input.parse_error() { + error.error.message = format!("{} (via rust plugin)", error.error.message); + output.error = serde_json::to_value(&error).unwrap_or_default(); + } + + return write_string(&serde_json::to_string(&output).unwrap()); + } + + // Pass through or modify response + if let Some(mut response) = input.parse_response() { + if let Some(ref mut chat) = response.chat_response { + // Add a marker to the model name + chat.model = format!("{} (via rust-wasm)", chat.model); + } + output.response = serde_json::to_value(&response).unwrap_or_default(); + } + + write_string(&serde_json::to_string(&output).unwrap()) +} +``` + +## Usage with Bifrost + +Configure the plugin in your Bifrost config: + +```json +{ + "plugins": [ + { + "path": "/path/to/hello-world.wasm", + "name": "hello-world-wasm-rust", + "enabled": true, + "config": { + "custom_option": "value" + } + } + ] +} +``` + +## Testing + +The plugin includes unit tests that can be run with: + +```bash +cargo test +``` + +## Benefits + +1. **Performance**: Rust compiles to highly optimized WASM +2. **Safety**: Memory safety without garbage collection +3. **Small binaries**: Rust WASM binaries are typically very small +4. **Cross-platform**: Single `.wasm` binary runs on any OS/architecture +5. **Security**: WASM provides sandboxed execution +6. **Type Safety**: Strongly typed structures with serde derive macros +7. **Excellent JSON**: serde_json provides robust JSON handling diff --git a/examples/plugins/hello-world-wasm-rust/src/lib.rs b/examples/plugins/hello-world-wasm-rust/src/lib.rs new file mode 100644 index 0000000000..b856b8acd0 --- /dev/null +++ b/examples/plugins/hello-world-wasm-rust/src/lib.rs @@ -0,0 +1,292 @@ +//! Bifrost WASM Plugin for Rust +//! +//! This plugin demonstrates the proper structure for parsing inputs, +//! building responses, and handling context - similar to Go plugin patterns. +//! +//! Build with: cargo build --release --target wasm32-unknown-unknown + +mod memory; +mod types; + +use memory::{read_string, write_string}; +use types::*; + +// Global configuration storage +static mut PLUGIN_CONFIG: Option = None; + +// ============================================================================= +// Exported Plugin Functions +// ============================================================================= + +/// Return the plugin name +#[no_mangle] +pub extern "C" fn get_name() -> u64 { + write_string("hello-world-wasm-rust") +} + +/// Initialize the plugin with config +/// Returns 0 on success, non-zero on error +#[no_mangle] +pub extern "C" fn init(config_ptr: u32, config_len: u32) -> i32 { + let config_str = read_string(config_ptr, config_len); + + // Parse configuration + let config: PluginConfig = if config_str.is_empty() { + PluginConfig::default() + } else { + match serde_json::from_str(&config_str) { + Ok(c) => c, + Err(_) => return 1, // Config parse error + } + }; + + // Store configuration + unsafe { + PLUGIN_CONFIG = Some(config); + } + + 0 // Success +} + +/// HTTP transport intercept +/// Called at the HTTP layer before request enters Bifrost core. +/// Can modify headers, query params, or short-circuit with a response. +#[no_mangle] +pub extern "C" fn http_intercept(input_ptr: u32, input_len: u32) -> u64 { + let input_str = read_string(input_ptr, input_len); + + // Parse input + let input: HTTPInterceptInput = match serde_json::from_str(&input_str) { + Ok(i) => i, + Err(e) => { + // Include context around the error position for debugging + let error_context = if let Some(col) = extract_column(&e.to_string()) { + let start = col.saturating_sub(50); + let end = (col + 50).min(input_str.len()); + format!(" | context: ...{}...", &input_str[start..end]) + } else { + String::new() + }; + let output = HTTPInterceptOutput { + error: format!("Failed to parse input: {}{}", e, error_context), + ..Default::default() + }; + return write_string(&serde_json::to_string(&output).unwrap_or_default()); + } + }; + + + // Add context value like Go plugin does + let mut context = input.context; + context.set_value("from-http", serde_json::json!("123")); + + // Create output with context and request preserved (pass-through) + // Serialize request to Value to ensure proper JSON structure + let request_value = serde_json::to_value(&input.request).ok(); + + let output = HTTPInterceptOutput { + context: input.context, + request: input.request, + has_response: false, + ..Default::default() + }; + + // Pass through + write_string(&serde_json::to_string(&output).unwrap_or_default()) +} + +/// Pre-request hook +/// Called before request is sent to the provider. +/// Can modify the request or short-circuit with a response/error. +#[no_mangle] +pub extern "C" fn pre_hook(input_ptr: u32, input_len: u32) -> u64 { + let input_str = read_string(input_ptr, input_len); + + // Parse input + let input: PreHookInput = match serde_json::from_str(&input_str) { + Ok(i) => i, + Err(e) => { + let output = PreHookOutput { + error: format!("Failed to parse input: {}", e), + ..Default::default() + }; + return write_string(&serde_json::to_string(&output).unwrap_or_default()); + } + }; + + // Create output with context preserved + let mut output = PreHookOutput { + context: input.context.clone(), + request: input.request.clone(), + has_short_circuit: false, + ..Default::default() + }; + + // Get provider and model for potential modifications + let (_provider, model) = input.get_provider_model(); + + // Example: Short-circuit with mock response for specific model + // Uncomment to test: + /* + if model == "mock-model" { + output.has_short_circuit = true; + + let mock_response = BifrostResponse { + chat_response: Some(BifrostChatResponse { + id: format!("mock-{}", input.context.request_id.unwrap_or_default()), + model: "mock-model".to_string(), + choices: vec![ResponseChoice { + index: 0, + message: Some(ChatMessage { + role: ChatMessageRole::Assistant, + content: Some(ChatMessageContent::Text( + "This is a mock response from the Rust WASM plugin!".to_string() + )), + ..Default::default() + }), + finish_reason: Some("stop".to_string()), + ..Default::default() + }], + usage: Some(LLMUsage { + prompt_tokens: 10, + completion_tokens: 15, + total_tokens: 25, + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + + output.short_circuit = Some(PluginShortCircuit { + response: Some(mock_response), + error: None, + }); + + return write_string(&serde_json::to_string(&output).unwrap_or_default()); + } + */ + + // Example: Short-circuit with rate limit error + // Uncomment to test: + /* + if should_rate_limit(&input.context) { + output.has_short_circuit = true; + output.short_circuit = Some(PluginShortCircuit { + response: None, + error: Some( + BifrostError::new("Rate limit exceeded") + .with_type("rate_limit") + .with_code("429") + .with_status(429) + ), + }); + return write_string(&serde_json::to_string(&output).unwrap_or_default()); + } + */ + + // Silence unused variable warning in example code + let _ = model; + + // Pass through - empty request means use original + write_string(&serde_json::to_string(&output).unwrap_or_default()) +} + +/// Post-response hook +/// Called after response is received from provider. +/// Can modify the response or error. +#[no_mangle] +pub extern "C" fn post_hook(input_ptr: u32, input_len: u32) -> u64 { + let input_str = read_string(input_ptr, input_len); + + // Parse input + let input: PostHookInput = match serde_json::from_str(&input_str) { + Ok(i) => i, + Err(e) => { + let output = PostHookOutput { + hook_error: format!("Failed to parse input: {}", e), + ..Default::default() + }; + return write_string(&serde_json::to_string(&output).unwrap_or_default()); + } + }; + + // Add context value like Go plugin does + let mut context = input.context.clone(); + context.set_value("from-post-hook", serde_json::json!("456")); + + // Create output with context and response/error preserved (pass-through) + // This matches Go plugin behavior exactly + let output = PostHookOutput { + context, + response: Some(input.response.clone()), + error: Some(input.error.clone()), + has_error: input.has_error, + hook_error: String::new(), + }; + + // Example: Modify error message when has_error is true + // Uncomment to test: + /* + if input.has_error { + if let Some(mut error) = input.parse_error() { + error.error.message = format!("{} (processed by Rust WASM plugin)", error.error.message); + let mut output = output; + output.error = Some(serde_json::to_value(&error).unwrap_or_default()); + return write_string(&serde_json::to_string(&output).unwrap_or_default()); + } + } + */ + + // Example: Modify response + // Uncomment to test: + /* + if let Some(mut response) = input.parse_response() { + // Add custom metadata, modify model name, etc. + if let Some(ref mut chat) = response.chat_response { + // Add a marker to the model name + chat.model = format!("{} (via rust-wasm)", chat.model); + } + let mut output = output; + output.response = Some(serde_json::to_value(&response).unwrap_or_default()); + return write_string(&serde_json::to_string(&output).unwrap_or_default()); + } + */ + + write_string(&serde_json::to_string(&output).unwrap_or_default()) +} + +/// Cleanup resources +/// Called when plugin is being unloaded. +/// Returns 0 on success, non-zero on error +#[no_mangle] +pub extern "C" fn cleanup() -> i32 { + // Clear stored configuration + unsafe { + PLUGIN_CONFIG = None; + } + + 0 // Success +} + +// ============================================================================= +// Helper Functions +// ============================================================================= + +/// Extract column number from serde error message for debugging +fn extract_column(error_msg: &str) -> Option { + // Error format: "... at line X column Y" + if let Some(idx) = error_msg.rfind("column ") { + let col_str = &error_msg[idx + 7..]; + col_str.split_whitespace().next()?.parse().ok() + } else { + None + } +} + +/// Example rate limit check function +#[allow(dead_code)] +fn should_rate_limit(_context: &BifrostContext) -> bool { + // Implement your rate limiting logic here + false +} diff --git a/examples/plugins/hello-world-wasm-rust/src/memory.rs b/examples/plugins/hello-world-wasm-rust/src/memory.rs new file mode 100644 index 0000000000..bab6fecc1d --- /dev/null +++ b/examples/plugins/hello-world-wasm-rust/src/memory.rs @@ -0,0 +1,70 @@ +//! Memory management utilities for WASM plugins. +//! Handles allocation, deallocation, and string read/write operations. + +use std::alloc::{alloc, dealloc, Layout}; +use std::slice; + +/// Pack a pointer and length into a single u64 +/// Upper 32 bits: pointer, Lower 32 bits: length +pub fn pack_result(ptr: u32, len: u32) -> u64 { + ((ptr as u64) << 32) | (len as u64) +} + +/// Write a string to WASM memory and return packed pointer+length +pub fn write_string(s: &str) -> u64 { + if s.is_empty() { + return 0; + } + let bytes = s.as_bytes(); + let ptr = unsafe { malloc(bytes.len() as u32) }; + if ptr == 0 { + return 0; + } + unsafe { + std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr as *mut u8, bytes.len()); + } + pack_result(ptr, bytes.len() as u32) +} + +/// Read a string from WASM memory given pointer and length +pub fn read_string(ptr: u32, len: u32) -> String { + if len == 0 { + return String::new(); + } + let bytes = unsafe { slice::from_raw_parts(ptr as *const u8, len as usize) }; + String::from_utf8_lossy(bytes).into_owned() +} + +/// Allocate memory for the host to write data +/// +/// # Safety +/// This function is marked as safe but performs unsafe operations internally. +/// It is intended to be called from WASM host. +#[no_mangle] +pub extern "C" fn malloc(size: u32) -> u32 { + if size == 0 { + return 0; + } + let layout = match Layout::from_size_align(size as usize, 1) { + Ok(l) => l, + Err(_) => return 0, + }; + unsafe { alloc(layout) as u32 } +} + +/// Free allocated memory +/// +/// # Safety +/// This function is marked as safe but performs unsafe operations internally. +/// It is intended to be called from WASM host. +#[no_mangle] +pub extern "C" fn free(ptr: u32, size: u32) { + if ptr == 0 || size == 0 { + return; + } + let layout = match Layout::from_size_align(size as usize, 1) { + Ok(l) => l, + Err(_) => return, + }; + unsafe { dealloc(ptr as *mut u8, layout) } +} diff --git a/examples/plugins/hello-world-wasm-rust/src/types.rs b/examples/plugins/hello-world-wasm-rust/src/types.rs new file mode 100644 index 0000000000..6bb9bfc234 --- /dev/null +++ b/examples/plugins/hello-world-wasm-rust/src/types.rs @@ -0,0 +1,834 @@ +//! Type definitions for Bifrost WASM plugins. +//! These structures mirror the Go SDK types for interoperability. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +// ============================================================================= +// Nullable Deserializers +// ============================================================================= + +/// Helper module for deserializing fields that may be null in JSON. +/// Go's JSON encoder outputs `null` for nil slices/maps, but Rust's serde +/// with `#[serde(default)]` only handles missing fields, not explicit nulls. +mod nullable { + use serde::{Deserialize, Deserializer}; + use std::collections::HashMap; + + /// Deserialize a string that may be null, converting null to empty string. + pub fn string<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Option::::deserialize(deserializer).map(|opt| opt.unwrap_or_default()) + } + + /// Deserialize a HashMap that may be null or contain null values. + /// Handles both `null` (entire map is null) and `{"key": null}` (value is null). + pub fn string_map<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + // First deserialize as Option>> to handle null values + let opt_map: Option>> = Option::deserialize(deserializer)?; + + match opt_map { + None => Ok(HashMap::new()), + Some(map) => { + // Filter out null values and unwrap the rest + Ok(map + .into_iter() + .filter_map(|(k, v)| v.map(|val| (k, val))) + .collect()) + } + } + } + + /// Deserialize an i32 that may be null, converting null to 0. + pub fn i32_field<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Option::::deserialize(deserializer).map(|opt| opt.unwrap_or_default()) + } + + /// Deserialize an HTTPRequest that may be null, converting null to default. + pub fn http_request<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Option::::deserialize(deserializer).map(|opt| opt.unwrap_or_default()) + } + + /// Deserialize a BifrostContext that may be null, converting null to default. + pub fn context<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Option::::deserialize(deserializer).map(|opt| opt.unwrap_or_default()) + } +} + +// ============================================================================= +// Context Structure +// ============================================================================= + +/// BifrostContext holds request-scoped values passed between hooks. +/// This is a dynamic map (map[string]any in Go) that can hold any JSON values. +/// Common keys include: +/// - request_id: Unique identifier for the request +/// - Custom plugin values can be added and will be persisted across hooks +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(transparent)] +pub struct BifrostContext(pub HashMap); + +impl BifrostContext { + pub fn new() -> Self { + Self(HashMap::new()) + } + + /// Set a custom value in the context + pub fn set_value(&mut self, key: &str, value: impl Into) { + self.0.insert(key.to_string(), value.into()); + } + + /// Get a value from the context + pub fn get(&self, key: &str) -> Option<&serde_json::Value> { + self.0.get(key) + } + + /// Get a string value from the context + pub fn get_string(&self, key: &str) -> Option<&str> { + self.0.get(key).and_then(|v| v.as_str()) + } + + /// Get a boolean value from the context + pub fn get_bool(&self, key: &str) -> Option { + self.0.get(key).and_then(|v| v.as_bool()) + } + + /// Get an i64 value from the context + pub fn get_i64(&self, key: &str) -> Option { + self.0.get(key).and_then(|v| v.as_i64()) + } + + /// Check if a key exists in the context + pub fn contains_key(&self, key: &str) -> bool { + self.0.contains_key(key) + } + + /// Remove a value from the context + pub fn remove(&mut self, key: &str) -> Option { + self.0.remove(key) + } + + /// Get the underlying HashMap for iteration + pub fn inner(&self) -> &HashMap { + &self.0 + } + + /// Get mutable access to the underlying HashMap + pub fn inner_mut(&mut self) -> &mut HashMap { + &mut self.0 + } +} + +// ============================================================================= +// HTTP Transport Structures +// ============================================================================= + +/// HTTPRequest represents an incoming HTTP request at the transport layer. +/// Body is base64-encoded. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HTTPRequest { + #[serde(default, deserialize_with = "nullable::string")] + pub method: String, + + #[serde(default, deserialize_with = "nullable::string")] + pub path: String, + + #[serde(default, deserialize_with = "nullable::string_map")] + pub headers: HashMap, + + #[serde(default, deserialize_with = "nullable::string_map")] + pub query: HashMap, + + /// Base64-encoded request body + #[serde(default, deserialize_with = "nullable::string")] + pub body: String, +} + +/// HTTPResponse represents an HTTP response to return. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HTTPResponse { + #[serde(default, deserialize_with = "nullable::i32_field")] + pub status_code: i32, + + #[serde(default, deserialize_with = "nullable::string_map")] + pub headers: HashMap, + + /// Base64-encoded response body + #[serde(default, deserialize_with = "nullable::string")] + pub body: String, +} + +/// HTTPInterceptInput is the input for http_intercept hook. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HTTPInterceptInput { + #[serde(default, deserialize_with = "nullable::context")] + pub context: BifrostContext, + + #[serde(default, deserialize_with = "nullable::http_request")] + pub request: HTTPRequest, +} + +/// HTTPInterceptOutput is the output for http_intercept hook. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HTTPInterceptOutput { + pub context: BifrostContext, + + #[serde(skip_serializing_if = "Option::is_none")] + pub request: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub response: Option, + + #[serde(default)] + pub has_response: bool, + + #[serde(default)] + pub error: String, +} + +// ============================================================================= +// Chat Completion Structures (BifrostRequest) +// ============================================================================= + +/// ChatMessageRole represents the role of a message sender. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ChatMessageRole { + User, + Assistant, + System, + Tool, + Developer, +} + +impl Default for ChatMessageRole { + fn default() -> Self { + ChatMessageRole::User + } +} + +/// ChatMessageContent can be either a string or an array of content blocks. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ChatMessageContent { + Text(String), + Blocks(Vec), +} + +impl Default for ChatMessageContent { + fn default() -> Self { + ChatMessageContent::Text(String::new()) + } +} + +/// ChatContentBlock represents a content block in a message. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatContentBlock { + #[serde(rename = "type")] + pub block_type: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub image_url: Option, +} + +/// ImageUrl represents an image URL in a content block. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ImageUrl { + pub url: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub detail: Option, +} + +/// ChatMessage represents a message in the conversation. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatMessage { + #[serde(default)] + pub role: ChatMessageRole, + + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, +} + +/// ToolCall represents a tool call made by the assistant. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ToolCall { + #[serde(default)] + pub id: Option, + + #[serde(rename = "type", default)] + pub call_type: Option, + + #[serde(default)] + pub function: ToolCallFunction, +} + +/// ToolCallFunction represents the function being called. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ToolCallFunction { + #[serde(default)] + pub name: Option, + + #[serde(default)] + pub arguments: String, +} + +/// ChatParameters contains optional parameters for chat completion. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatParameters { + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub max_completion_tokens: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + /// Catch-all for additional parameters + #[serde(flatten)] + pub extra: HashMap, +} + +/// ChatTool represents a tool definition. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatTool { + #[serde(rename = "type")] + pub tool_type: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub function: Option, +} + +/// ChatToolFunction represents a function definition. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatToolFunction { + pub name: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option, +} + +/// BifrostChatRequest represents a chat completion request. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostChatRequest { + #[serde(default)] + pub provider: String, + + #[serde(default)] + pub model: String, + + #[serde(default)] + pub input: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub fallbacks: Option>, +} + +/// Fallback represents a fallback provider/model. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct Fallback { + pub provider: String, + pub model: String, +} + +/// BifrostRequest is the unified request structure. +/// Only one of the request types should be present. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostRequest { + #[serde(skip_serializing_if = "Option::is_none")] + pub chat_request: Option, + + // Add other request types as needed + #[serde(flatten)] + pub extra: HashMap, +} + +impl BifrostRequest { + /// Get provider and model from the request + pub fn get_provider_model(&self) -> (String, String) { + if let Some(ref chat) = self.chat_request { + return (chat.provider.clone(), chat.model.clone()); + } + (String::new(), String::new()) + } +} + +// ============================================================================= +// Response Structures (BifrostResponse) +// ============================================================================= + +/// LLMUsage contains token usage information. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct LLMUsage { + #[serde(default)] + pub prompt_tokens: i32, + + #[serde(default)] + pub completion_tokens: i32, + + #[serde(default)] + pub total_tokens: i32, + + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_tokens_details: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_tokens_details: Option, +} + +/// ResponseChoice represents a single completion choice. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ResponseChoice { + #[serde(default)] + pub index: i32, + + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub delta: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, +} + +/// BifrostChatResponse represents a chat completion response. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostChatResponse { + #[serde(default)] + pub id: String, + + #[serde(default)] + pub model: String, + + #[serde(default)] + pub choices: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub created: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub object: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, +} + +/// BifrostResponse is the unified response structure. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostResponse { + #[serde(skip_serializing_if = "Option::is_none")] + pub chat_response: Option, + + #[serde(flatten)] + pub extra: HashMap, +} + +// ============================================================================= +// Error Structure +// ============================================================================= + +/// ErrorField contains the error details. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ErrorField { + #[serde(default)] + pub message: String, + + #[serde(skip_serializing_if = "Option::is_none", rename = "type")] + pub error_type: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub param: Option, +} + +/// BifrostError represents an error response. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BifrostError { + #[serde(default)] + pub error: ErrorField, + + #[serde(skip_serializing_if = "Option::is_none")] + pub status_code: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub allow_fallbacks: Option, +} + +impl BifrostError { + /// Create a new error with a message + pub fn new(message: &str) -> Self { + Self { + error: ErrorField { + message: message.to_string(), + ..Default::default() + }, + ..Default::default() + } + } + + /// Set the error type + pub fn with_type(mut self, error_type: &str) -> Self { + self.error.error_type = Some(error_type.to_string()); + self + } + + /// Set the error code + pub fn with_code(mut self, code: &str) -> Self { + self.error.code = Some(code.to_string()); + self + } + + /// Set the status code + pub fn with_status(mut self, status: i32) -> Self { + self.status_code = Some(status); + self + } +} + +// ============================================================================= +// Short Circuit Structure +// ============================================================================= + +/// PluginShortCircuit allows plugins to short-circuit the request flow. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PluginShortCircuit { + #[serde(skip_serializing_if = "Option::is_none")] + pub response: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +// ============================================================================= +// Hook Input/Output Structures +// ============================================================================= + +/// PreHookInput is the input for pre_hook. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PreHookInput { + #[serde(default)] + pub context: BifrostContext, + + #[serde(default)] + pub request: serde_json::Value, +} + +impl PreHookInput { + /// Parse the request as a BifrostRequest + pub fn parse_request(&self) -> Option { + serde_json::from_value(self.request.clone()).ok() + } + + /// Get provider and model from the request + pub fn get_provider_model(&self) -> (String, String) { + if let Some(req) = self.parse_request() { + return req.get_provider_model(); + } + // Try direct access for simpler structures + let provider = self.request.get("provider") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let model = self.request.get("model") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + (provider, model) + } +} + +/// PreHookOutput is the output for pre_hook. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PreHookOutput { + pub context: BifrostContext, + + #[serde(skip_serializing_if = "Option::is_none")] + pub request: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub short_circuit: Option, + + #[serde(default)] + pub has_short_circuit: bool, + + #[serde(default)] + pub error: String, +} + +/// PostHookInput is the input for post_hook. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PostHookInput { + #[serde(default)] + pub context: BifrostContext, + + #[serde(default)] + pub response: serde_json::Value, + + #[serde(default)] + pub error: serde_json::Value, + + #[serde(default)] + pub has_error: bool, +} + +impl PostHookInput { + /// Parse the response as a BifrostResponse + pub fn parse_response(&self) -> Option { + serde_json::from_value(self.response.clone()).ok() + } + + /// Parse the error as a BifrostError + pub fn parse_error(&self) -> Option { + if self.has_error { + serde_json::from_value(self.error.clone()).ok() + } else { + None + } + } +} + +/// PostHookOutput is the output for post_hook. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PostHookOutput { + pub context: BifrostContext, + + #[serde(skip_serializing_if = "Option::is_none")] + pub response: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + + #[serde(default)] + pub has_error: bool, + + #[serde(default)] + pub hook_error: String, +} + +// ============================================================================= +// Plugin Configuration +// ============================================================================= + +/// Plugin configuration (customize as needed) +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PluginConfig { + #[serde(flatten)] + pub values: HashMap, +} + +// ============================================================================= +// Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_context_serialization() { + let mut ctx = BifrostContext::new(); + ctx.set_value("request_id", "test-123"); + ctx.set_value("custom_key", "custom_value"); + ctx.set_value("is_enabled", true); + ctx.set_value("count", 42); + + let json = serde_json::to_string(&ctx).unwrap(); + assert!(json.contains("request_id")); + assert!(json.contains("custom_key")); + assert!(json.contains("is_enabled")); + assert!(json.contains("count")); + } + + #[test] + fn test_context_deserialization() { + let json = r#"{"request_id": "test-123", "custom_key": "custom_value", "is_enabled": true}"#; + let ctx: BifrostContext = serde_json::from_str(json).unwrap(); + + assert_eq!(ctx.get_string("request_id"), Some("test-123")); + assert_eq!(ctx.get_string("custom_key"), Some("custom_value")); + assert_eq!(ctx.get_bool("is_enabled"), Some(true)); + } + + #[test] + fn test_context_methods() { + let mut ctx = BifrostContext::new(); + ctx.set_value("key1", "value1"); + ctx.set_value("enabled", true); + ctx.set_value("count", 42); + + assert_eq!(ctx.get_string("key1"), Some("value1")); + assert_eq!(ctx.get_bool("enabled"), Some(true)); + assert_eq!(ctx.get_i64("count"), Some(42)); + assert!(ctx.contains_key("key1")); + assert!(!ctx.contains_key("nonexistent")); + + ctx.remove("key1"); + assert!(!ctx.contains_key("key1")); + } + + #[test] + fn test_chat_message() { + let msg = ChatMessage { + role: ChatMessageRole::User, + content: Some(ChatMessageContent::Text("Hello!".to_string())), + ..Default::default() + }; + + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains("user")); + assert!(json.contains("Hello!")); + } + + #[test] + fn test_bifrost_error() { + let error = BifrostError::new("Test error") + .with_type("test_type") + .with_code("500") + .with_status(500); + + let json = serde_json::to_string(&error).unwrap(); + assert!(json.contains("Test error")); + assert!(json.contains("test_type")); + } + + #[test] + fn test_pre_hook_input_parsing() { + let json = r#"{ + "context": {"request_id": "test-123", "custom": "value"}, + "request": {"provider": "openai", "model": "gpt-4"} + }"#; + + let input: PreHookInput = serde_json::from_str(json).unwrap(); + assert_eq!(input.context.get_string("request_id"), Some("test-123")); + assert_eq!(input.context.get_string("custom"), Some("value")); + + let (provider, model) = input.get_provider_model(); + assert_eq!(provider, "openai"); + assert_eq!(model, "gpt-4"); + } + + #[test] + fn test_http_request_with_null_fields() { + // Simulates Go sending null for nil []byte and nil maps + let json = r#"{ + "method": "POST", + "path": "/v1/chat/completions", + "headers": null, + "query": null, + "body": null + }"#; + + let req: HTTPRequest = serde_json::from_str(json).unwrap(); + assert_eq!(req.method, "POST"); + assert_eq!(req.path, "/v1/chat/completions"); + assert!(req.headers.is_empty()); + assert!(req.query.is_empty()); + assert_eq!(req.body, ""); + } + + #[test] + fn test_http_request_with_missing_fields() { + // Test that missing fields also work (default behavior) + let json = r#"{ + "method": "GET", + "path": "/health" + }"#; + + let req: HTTPRequest = serde_json::from_str(json).unwrap(); + assert_eq!(req.method, "GET"); + assert_eq!(req.path, "/health"); + assert!(req.headers.is_empty()); + assert!(req.query.is_empty()); + assert_eq!(req.body, ""); + } + + #[test] + fn test_http_intercept_input_with_nulls() { + // Simulates a full HTTP intercept input with null body from Go + let json = r#"{ + "context": {"request_id": "abc-123"}, + "request": { + "method": "POST", + "path": "/v1/chat/completions", + "headers": {"content-type": "application/json"}, + "query": {}, + "body": null + } + }"#; + + let input: HTTPInterceptInput = serde_json::from_str(json).unwrap(); + assert_eq!(input.context.get_string("request_id"), Some("abc-123")); + assert_eq!(input.request.method, "POST"); + assert_eq!(input.request.path, "/v1/chat/completions"); + assert_eq!(input.request.headers.get("content-type"), Some(&"application/json".to_string())); + assert_eq!(input.request.body, ""); + } + + #[test] + fn test_http_response_with_null_fields() { + let json = r#"{ + "status_code": null, + "headers": null, + "body": null + }"#; + + let resp: HTTPResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.status_code, 0); + assert!(resp.headers.is_empty()); + assert_eq!(resp.body, ""); + } +} diff --git a/examples/plugins/hello-world-wasm-typescript/Makefile b/examples/plugins/hello-world-wasm-typescript/Makefile new file mode 100644 index 0000000000..bb4c2e1a7a --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/Makefile @@ -0,0 +1,70 @@ +.PHONY: all build build-debug clean help install check-node + +# Colors +COLOR_RESET = \033[0m +COLOR_INFO = \033[36m +COLOR_SUCCESS = \033[32m +COLOR_WARNING = \033[33m +COLOR_ERROR = \033[31m +COLOR_BOLD = \033[1m + +# Plugin configuration +PLUGIN_NAME = hello-world +OUTPUT_DIR = build +OUTPUT = $(OUTPUT_DIR)/$(PLUGIN_NAME).wasm + +help: ## Show this help message + @echo '$(COLOR_BOLD)Hello World WASM Plugin (TypeScript/AssemblyScript)$(COLOR_RESET)' + @echo '' + @echo '$(COLOR_BOLD)Usage:$(COLOR_RESET) make [target]' + @echo '' + @echo '$(COLOR_BOLD)Prerequisites:$(COLOR_RESET)' + @echo ' - Node.js (https://nodejs.org/)' + @echo ' - npm install (to install AssemblyScript)' + @echo '' + @echo '$(COLOR_BOLD)Available targets:$(COLOR_RESET)' + @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf " $(COLOR_INFO)%-15s$(COLOR_RESET) %s\n", $$1, $$2}' $(MAKEFILE_LIST) + +check-node: ## Check if Node.js is installed + @which node > /dev/null 2>&1 || (echo "$(COLOR_ERROR)Error: Node.js is not installed$(COLOR_RESET)"; \ + echo "$(COLOR_INFO)Install Node.js: https://nodejs.org/$(COLOR_RESET)"; \ + exit 1) + @echo "$(COLOR_SUCCESS)āœ“ Node.js found: $$(node --version)$(COLOR_RESET)" + +install: check-node ## Install dependencies + @echo "$(COLOR_INFO)Installing dependencies...$(COLOR_RESET)" + npm install + @echo "$(COLOR_SUCCESS)āœ“ Dependencies installed$(COLOR_RESET)" + +build: install ## Build the WASM plugin + @mkdir -p $(OUTPUT_DIR) + @echo "$(COLOR_INFO)Building WASM plugin...$(COLOR_RESET)" + npm run build + @echo "$(COLOR_SUCCESS)āœ“ Plugin built successfully: $(OUTPUT)$(COLOR_RESET)" + @ls -lh $(OUTPUT) | awk '{print " Size: " $$5}' + +build-debug: install ## Build with debug info + @mkdir -p $(OUTPUT_DIR) + @echo "$(COLOR_INFO)Building WASM plugin (debug)...$(COLOR_RESET)" + npm run build:debug + @echo "$(COLOR_SUCCESS)āœ“ Debug plugin built: $(OUTPUT)$(COLOR_RESET)" + @ls -lh $(OUTPUT) | awk '{print " Size: " $$5}' + +clean: ## Remove build artifacts + @echo "$(COLOR_INFO)Cleaning build artifacts...$(COLOR_RESET)" + @rm -rf $(OUTPUT_DIR) node_modules + @echo "$(COLOR_SUCCESS)āœ“ Clean complete$(COLOR_RESET)" + +info: ## Show build information + @echo "$(COLOR_BOLD)Build Configuration$(COLOR_RESET)" + @echo " Plugin Name: $(PLUGIN_NAME)" + @echo " Output: $(OUTPUT)" + @echo "" + @if [ -f "$(OUTPUT)" ]; then \ + echo "$(COLOR_SUCCESS)Plugin exists:$(COLOR_RESET)"; \ + ls -lh $(OUTPUT) | awk '{print " " $$9 " (" $$5 ")"}'; \ + else \ + echo "$(COLOR_WARNING)Plugin not built yet$(COLOR_RESET)"; \ + fi + +.DEFAULT_GOAL := help diff --git a/examples/plugins/hello-world-wasm-typescript/README.md b/examples/plugins/hello-world-wasm-typescript/README.md new file mode 100644 index 0000000000..d573350193 --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/README.md @@ -0,0 +1,453 @@ +# Bifrost WASM Plugin (TypeScript/AssemblyScript) + +A comprehensive example of a Bifrost plugin written in TypeScript and compiled to WebAssembly using AssemblyScript. This plugin demonstrates proper structure definitions, JSON parsing, context handling, and request/response modification patterns. + +## Prerequisites + +### Node.js Installation + +Node.js is required to run AssemblyScript: + +**macOS:** +```bash +brew install node +``` + +**Linux:** +```bash +curl -fsSL https://deb.nodesource.com/setup_20.x | sudo -E bash - +sudo apt install -y nodejs +``` + +**Other platforms:** +See [Node.js Downloads](https://nodejs.org/en/download/) + +## Building + +```bash +# Install dependencies and build +make build + +# Build with debug info +make build-debug + +# Clean build artifacts +make clean +``` + +The compiled plugin will be at `build/hello-world.wasm`. + +## File Structure + +``` +assembly/ +ā”œā”€ā”€ index.ts # Plugin implementation (hooks) +ā”œā”€ā”€ memory.ts # Memory management utilities +ā”œā”€ā”€ types.ts # Type definitions (mirrors Go SDK) +└── tsconfig.json # AssemblyScript config +``` + +## Plugin Structure + +WASM plugins must export the following functions: + +| Export | Signature | Description | +|--------|-----------|-------------| +| `malloc` | `(size: u32) -> u32` | Allocate memory for host to write data | +| `free` | `(ptr: u32)` | Free allocated memory | +| `get_name` | `() -> u64` | Returns packed ptr+len of plugin name | +| `init` | `(config_ptr, config_len: u32) -> i32` | Initialize with config (optional) | +| `http_intercept` | `(input_ptr, input_len: u32) -> u64` | HTTP transport intercept | +| `pre_hook` | `(input_ptr, input_len: u32) -> u64` | Pre-request hook | +| `post_hook` | `(input_ptr, input_len: u32) -> u64` | Post-response hook | +| `cleanup` | `() -> i32` | Cleanup resources (0 = success) | + +### Return Value Format + +Functions returning data use a packed `u64` format: +- Upper 32 bits: pointer to data in WASM memory +- Lower 32 bits: length of data + +## Data Structures + +This plugin uses `json-as` with `@json` decorators for automatic JSON serialization. All structures mirror the Go SDK types: + +### Context + +```typescript +@json +class BifrostContext { + request_id: string = '' // Unique request identifier + plugin_processed: string = '' // Custom plugin values + plugin_name: string = '' +} +``` + +### HTTP Transport Types + +```typescript +@json +class HTTPRequest { + method: string = '' // GET, POST, etc. + path: string = '' // /v1/chat/completions + body: string = '' // base64 encoded +} + +@json +class HTTPResponse { + status_code: i32 = 200 // HTTP status code + body: string = '' // base64 encoded +} +``` + +### Chat Completion Types + +```typescript +@json +class ChatMessage { + role: string = '' // "user", "assistant", "system", "tool" + content: string = '' + name: string = '' + tool_call_id: string = '' +} + +@json +class ChatParameters { + temperature: f64 = 0 + max_completion_tokens: i32 = 0 + top_p: f64 = 0 +} + +@json +class BifrostChatRequest { + provider: string = '' // "openai", "anthropic", etc. + model: string = '' // "gpt-4", "claude-3", etc. + input: ChatMessage[] = [] + params: ChatParameters = new ChatParameters() +} +``` + +### Response Types + +```typescript +@json +class LLMUsage { + prompt_tokens: i32 = 0 + completion_tokens: i32 = 0 + total_tokens: i32 = 0 +} + +@json +class ResponseChoice { + index: i32 = 0 + message: ChatMessage = new ChatMessage() + finish_reason: string = 'stop' // "stop", "length", "tool_calls" +} + +@json +class BifrostChatResponse { + id: string = '' + model: string = '' + choices: ResponseChoice[] = [] + usage: LLMUsage = new LLMUsage() +} +``` + +### Error Types + +```typescript +@json +class ErrorField { + message: string = '' + type: string = '' // "rate_limit", "auth_error", etc. + code: string = '' // "429", "401", etc. +} + +@json +class BifrostError { + error: ErrorField = new ErrorField() + status_code: i32 = 0 +} +``` + +### Short Circuit + +```typescript +@json +class PluginShortCircuit { + response: BifrostResponse | null = null // Success short-circuit + error: BifrostError | null = null // Error short-circuit +} +``` + +## Hook Input/Output Structures + +### http_intercept + +**Input:** +```json +{ + "context": { "request_id": "abc-123" }, + "request": { + "method": "POST", + "path": "/v1/chat/completions", + "headers": { "Content-Type": "application/json" }, + "query": {}, + "body": "" + } +} +``` + +**Output:** +```json +{ + "context": { "request_id": "abc-123", "custom_key": "value" }, + "request": {}, + "response": { "status_code": 200, "headers": {}, "body": "" }, + "has_response": false, + "error": "" +} +``` + +### pre_hook + +**Input:** +```json +{ + "context": { "request_id": "abc-123" }, + "request": { + "provider": "openai", + "model": "gpt-4", + "input": [{ "role": "user", "content": "Hello" }], + "params": { "temperature": 0.7 } + } +} +``` + +**Output:** +```json +{ + "context": { "request_id": "abc-123", "plugin_processed": "true" }, + "request": {}, + "short_circuit": { + "response": { "chat_response": { ... } } + }, + "has_short_circuit": false, + "error": "" +} +``` + +### post_hook + +**Input:** +```json +{ + "context": { "request_id": "abc-123", "plugin_processed": "true" }, + "response": { + "chat_response": { + "id": "chatcmpl-123", + "model": "gpt-4", + "choices": [{ "index": 0, "message": { "role": "assistant", "content": "Hi!" } }], + "usage": { "prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15 } + } + }, + "error": {}, + "has_error": false +} +``` + +**Output:** +```json +{ + "context": { "request_id": "abc-123", "post_hook_completed": "true" }, + "response": {}, + "error": {}, + "has_error": false, + "hook_error": "" +} +``` + +## Usage Examples + +### Modifying Context + +```typescript +import { JSON } from 'json-as' + +export function pre_hook(inputPtr: u32, inputLen: u32): u64 { + const inputJson = readString(inputPtr, inputLen) + const input = JSON.parse(inputJson) + + const output = new PreHookOutput() + output.context = input.context + + // Add custom values to context + output.context.plugin_processed = 'true' + output.context.plugin_name = 'my-plugin' + + return writeString(JSON.stringify(output)) +} +``` + +### Short-Circuit with Mock Response + +```typescript +import { JSON } from 'json-as' + +export function pre_hook(inputPtr: u32, inputLen: u32): u64 { + const inputJson = readString(inputPtr, inputLen) + const input = JSON.parse(inputJson) + + // Check if this should be mocked + const model = input.request.model + if (model === 'mock-model') { + const output = new PreHookOutput() + output.context = input.context + output.has_short_circuit = true + output.short_circuit = new PluginShortCircuit() + + // Build mock response + const mockResponse = new BifrostResponse() + mockResponse.chat_response = new BifrostChatResponse() + mockResponse.chat_response!.id = 'mock-' + input.context.request_id + mockResponse.chat_response!.model = 'mock-model' + + const choice = new ResponseChoice() + choice.message.role = 'assistant' + choice.message.content = 'This is a mock response!' + mockResponse.chat_response!.choices.push(choice) + + mockResponse.chat_response!.usage.prompt_tokens = 10 + mockResponse.chat_response!.usage.completion_tokens = 15 + mockResponse.chat_response!.usage.total_tokens = 25 + + output.short_circuit!.response = mockResponse + return writeString(JSON.stringify(output)) + } + + // Pass through + const output = new PreHookOutput() + output.context = input.context + return writeString(JSON.stringify(output)) +} +``` + +### Short-Circuit with Error + +```typescript +import { JSON } from 'json-as' + +export function pre_hook(inputPtr: u32, inputLen: u32): u64 { + const inputJson = readString(inputPtr, inputLen) + const input = JSON.parse(inputJson) + + // Check rate limit (example) + if (shouldRateLimit(input.context.request_id)) { + const output = new PreHookOutput() + output.context = input.context + output.has_short_circuit = true + output.short_circuit = new PluginShortCircuit() + + const error = new BifrostError() + error.error.message = 'Rate limit exceeded' + error.error.type = 'rate_limit' + error.error.code = '429' + error.status_code = 429 + + output.short_circuit!.error = error + return writeString(JSON.stringify(output)) + } + + // Pass through + const output = new PreHookOutput() + output.context = input.context + return writeString(JSON.stringify(output)) +} +``` + +### Modifying Responses in post_hook + +```typescript +import { JSON } from 'json-as' + +export function post_hook(inputPtr: u32, inputLen: u32): u64 { + const inputJson = readString(inputPtr, inputLen) + const input = JSON.parse(inputJson) + + const output = new PostHookOutput() + output.context = input.context + + // Handle errors + if (input.has_error && input.error !== null) { + output.has_error = true + output.error = input.error + // Could modify error here if needed + return writeString(JSON.stringify(output)) + } + + // Modify response + if (input.response !== null && input.response!.chat_response !== null) { + output.response = input.response + // Could add logging, metrics, or modify response here + } + + return writeString(JSON.stringify(output)) +} +``` + +## Usage with Bifrost + +Configure the plugin in your Bifrost config: + +```json +{ + "plugins": [ + { + "path": "/path/to/hello-world.wasm", + "name": "hello-world-wasm-typescript", + "enabled": true, + "config": { + "custom_option": "value" + } + } + ] +} +``` + +## AssemblyScript Notes + +AssemblyScript is similar to TypeScript but with some differences: + +1. **Types are required**: All variables must have explicit types +2. **No closures**: Functions cannot capture variables from outer scope +3. **Limited stdlib**: Not all JavaScript/TypeScript features are available +4. **Strict null handling**: Null checks are required +5. **JSON via json-as**: Uses the `json-as` package with `@json` decorators for serialization + +This plugin uses `json-as` for JSON parsing/serialization: + +```typescript +import { JSON } from 'json-as' + +@json +class MyClass { + name: string = '' + value: i32 = 0 +} + +// Parse JSON +const obj = JSON.parse('{"name":"test","value":42}') + +// Stringify to JSON +const json = JSON.stringify(obj) +``` + +See [AssemblyScript Documentation](https://www.assemblyscript.org/introduction.html) and [json-as Documentation](https://github.com/JairusSW/as-json) for more details. + +## Benefits + +1. **Familiar syntax**: TypeScript-like syntax for JS/TS developers +2. **Cross-platform**: Single `.wasm` binary runs on any OS/architecture +3. **Security**: WASM provides sandboxed execution +4. **Type Safety**: Strongly typed structures catch errors at compile time +5. **npm ecosystem**: Can use npm for dependency management diff --git a/examples/plugins/hello-world-wasm-typescript/assembly/index.ts b/examples/plugins/hello-world-wasm-typescript/assembly/index.ts new file mode 100644 index 0000000000..0533924ef6 --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/assembly/index.ts @@ -0,0 +1,107 @@ +/** + * Bifrost WASM Plugin for TypeScript/AssemblyScript + * + * This plugin uses json-as for safe JSON parsing with @json decorators. + * + * Build with: npm run build + */ + +import { JSON } from 'json-as' +import { free as _free, malloc as _malloc, readString, writeString } from './memory' +import { + HTTPInterceptInput, + HTTPInterceptOutput, + PreHookInput, + PreHookOutput, + PostHookInput, + PostHookOutput +} from './types' + +// ============================================================================= +// Re-export memory functions for WASM host +// ============================================================================= + +export function malloc(size: u32): u32 { + return _malloc(size) +} + +export function free(ptr: u32): void { + _free(ptr) +} + +// ============================================================================= +// Plugin Configuration +// ============================================================================= + +let pluginConfig: string = '' + +// ============================================================================= +// Exported Plugin Functions +// ============================================================================= + +export function get_name(): u64 { + return writeString('hello-world-wasm-typescript') +} + +export function init(configPtr: u32, configLen: u32): i32 { + pluginConfig = readString(configPtr, configLen) + return 0 +} + +/** + * HTTP transport intercept + * Pass through the request with added context value + */ +export function http_intercept(inputPtr: u32, inputLen: u32): u64 { + const inputJson = readString(inputPtr, inputLen) + const input = JSON.parse(inputJson) + + const output = new HTTPInterceptOutput() + output.context = input.context + output.context.set('from-http', '123') + output.request = input.request + + return writeString(JSON.stringify(output)) +} + +/** + * Pre-request hook + * Pass through the request with added context value + */ +export function pre_hook(inputPtr: u32, inputLen: u32): u64 { + const inputJson = readString(inputPtr, inputLen) + const input = JSON.parse(inputJson) + + const output = new PreHookOutput() + output.context = input.context + output.context.set('from-pre-hook', '789') + output.request = input.request + + return writeString(JSON.stringify(output)) +} + +/** + * Post-response hook + * Pass through the response/error with added context value + */ +export function post_hook(inputPtr: u32, inputLen: u32): u64 { + const inputJson = readString(inputPtr, inputLen) + const input = JSON.parse(inputJson) + + const output = new PostHookOutput() + output.context = input.context + output.context.set('from-post-hook', '456') + output.response = input.response + output.error = input.error + output.has_error = input.has_error + + return writeString(JSON.stringify(output)) +} + +/** + * Cleanup resources + */ +export function cleanup(): i32 { + pluginConfig = '' + return 0 +} diff --git a/examples/plugins/hello-world-wasm-typescript/assembly/memory.ts b/examples/plugins/hello-world-wasm-typescript/assembly/memory.ts new file mode 100644 index 0000000000..fcfb425e68 --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/assembly/memory.ts @@ -0,0 +1,45 @@ +/** + * Memory management utilities for WASM plugins. + * Handles allocation, deallocation, and string read/write operations. + */ + +// Pack a pointer and length into a single u64 +// Upper 32 bits: pointer, Lower 32 bits: length +export function packResult(ptr: u32, len: u32): u64 { + return (u64(ptr) << 32) | u64(len) +} + +// Write a string to memory and return packed pointer+length +export function writeString(s: string): u64 { + if (s.length === 0) { + return 0 + } + const encoded = String.UTF8.encode(s) + const ptr = changetype(encoded) + return packResult(ptr, encoded.byteLength) +} + +// Read a string from memory given pointer and length +export function readString(ptr: u32, len: u32): string { + if (len === 0) { + return '' + } + const buffer = new ArrayBuffer(len) + memory.copy(changetype(buffer), ptr, len) + return String.UTF8.decode(buffer) +} + +// Allocate memory for the host to write data +export function malloc(size: u32): u32 { + if (size === 0) { + return 0 + } + const buffer = new ArrayBuffer(size) + return changetype(buffer) +} + +// Free allocated memory (handled by AssemblyScript runtime) +export function free(_ptr: u32): void { + // AssemblyScript handles garbage collection + // This is provided for API compatibility +} diff --git a/examples/plugins/hello-world-wasm-typescript/assembly/tsconfig.json b/examples/plugins/hello-world-wasm-typescript/assembly/tsconfig.json new file mode 100644 index 0000000000..798b474eab --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/assembly/tsconfig.json @@ -0,0 +1,4 @@ +{ + "extends": "assemblyscript/std/assembly.json", + "include": ["./**/*.ts"] +} diff --git a/examples/plugins/hello-world-wasm-typescript/assembly/types.ts b/examples/plugins/hello-world-wasm-typescript/assembly/types.ts new file mode 100644 index 0000000000..719b405de5 --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/assembly/types.ts @@ -0,0 +1,130 @@ +/** + * Type definitions for Bifrost WASM plugins. + * + * Uses json-as library with @json decorators for safe JSON parsing. + * These types mirror the Go SDK types for interoperability. + */ + +import { JSON } from 'json-as' + +// ============================================================================= +// HTTP Transport Input/Output Types +// ============================================================================= + +/** + * BifrostContext holds request-scoped values passed between hooks. + * Common keys include: + * - request_id: Unique identifier for the request + * - Custom plugin values can be added and will be persisted across hooks + */ +@json +export class BifrostContext { + request_id: string = '' + + // Custom values for plugin use (add more as needed) + plugin_processed: string = '' + plugin_name: string = '' + post_hook_completed: string = '' +} + +// ============================================================================= +// HTTP Transport Structures +// ============================================================================= + +/** + * HTTPRequest represents an incoming HTTP request at the transport layer. + * Body is base64-encoded. + */ +@json +export class HTTPRequest { + method: string = '' + path: string = '' + body: string = '' // base64 encoded + headers: Map = new Map() + query: Map = new Map() +} + +/** + * HTTPResponse represents an HTTP response to return. + */ +@json +export class HTTPResponse { + status_code: i32 = 200 + body: string = '' // base64 encoded + headers: Map = new Map() +} + +/** + * HTTPInterceptInput is the input for http_intercept hook. + * Context is a dynamic object (JSON.Obj) since Go sends map[string]interface{}. + * Request is kept as JSON.Raw to pass through without full parsing. + */ +@json +export class HTTPInterceptInput { + context: JSON.Obj = new JSON.Obj() + request: JSON.Raw = new JSON.Raw('null') +} + +/** + * HTTPInterceptOutput is the output for http_intercept hook. + */ +@json +export class HTTPInterceptOutput { + context: JSON.Obj = new JSON.Obj() + request: JSON.Raw = new JSON.Raw('null') + response: JSON.Raw = new JSON.Raw('null') + has_response: bool = false + error: string = '' +} + +// ============================================================================= +// Pre-Hook Input/Output Types +// ============================================================================= + +/** + * PreHookInput is the input for pre_hook. + */ +@json +export class PreHookInput { + context: JSON.Obj = new JSON.Obj() + request: JSON.Raw = new JSON.Raw('null') +} + +/** + * PreHookOutput is the output for pre_hook. + */ +@json +export class PreHookOutput { + context: JSON.Obj = new JSON.Obj() + request: JSON.Raw = new JSON.Raw('null') + short_circuit: JSON.Raw = new JSON.Raw('null') + has_short_circuit: bool = false + error: string = '' +} + +// ============================================================================= +// Post-Hook Input/Output Types +// ============================================================================= + +/** + * PostHookInput is the input for post_hook. + */ +@json +export class PostHookInput { + context: JSON.Obj = new JSON.Obj() + response: JSON.Raw = new JSON.Raw('null') + error: JSON.Raw = new JSON.Raw('null') + has_error: bool = false +} + +/** + * PostHookOutput is the output for post_hook. + */ +@json +export class PostHookOutput { + context: JSON.Obj = new JSON.Obj() + response: JSON.Raw = new JSON.Raw('null') + error: JSON.Raw = new JSON.Raw('null') + has_error: bool = false + hook_error: string = '' +} diff --git a/examples/plugins/hello-world-wasm-typescript/package-lock.json b/examples/plugins/hello-world-wasm-typescript/package-lock.json new file mode 100644 index 0000000000..b66ee621e1 --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/package-lock.json @@ -0,0 +1,65 @@ +{ + "name": "hello-world-wasm-typescript", + "version": "0.1.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "hello-world-wasm-typescript", + "version": "0.1.0", + "dependencies": { + "json-as": "^1.0.0" + }, + "devDependencies": { + "assemblyscript": "^0.27.29" + } + }, + "node_modules/assemblyscript": { + "version": "0.27.37", + "resolved": "https://registry.npmjs.org/assemblyscript/-/assemblyscript-0.27.37.tgz", + "integrity": "sha512-YtY5k3PiV3SyUQ6gRlR2OCn8dcVRwkpiG/k2T5buoL2ymH/Z/YbaYWbk/f9mO2HTgEtGWjPiAQrIuvA7G/63Gg==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "binaryen": "116.0.0-nightly.20240114", + "long": "^5.2.4" + }, + "bin": { + "asc": "bin/asc.js", + "asinit": "bin/asinit.js" + }, + "engines": { + "node": ">=18", + "npm": ">=10" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/assemblyscript" + } + }, + "node_modules/binaryen": { + "version": "116.0.0-nightly.20240114", + "resolved": "https://registry.npmjs.org/binaryen/-/binaryen-116.0.0-nightly.20240114.tgz", + "integrity": "sha512-0GZrojJnuhoe+hiwji7QFaL3tBlJoA+KFUN7ouYSDGZLSo9CKM8swQX8n/UcbR0d1VuZKU+nhogNzv423JEu5A==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "wasm-opt": "bin/wasm-opt", + "wasm2js": "bin/wasm2js" + } + }, + "node_modules/json-as": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/json-as/-/json-as-1.2.3.tgz", + "integrity": "sha512-yvRkR0Lv8597jHbsf+e93fo+pQctbsiDl7HGuBl71GqKhNT9KtyqtNzal7L7nEIfUq1NNkdACaT1O5D8KtX2zw==", + "license": "MIT" + }, + "node_modules/long": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/long/-/long-5.3.2.tgz", + "integrity": "sha512-mNAgZ1GmyNhD7AuqnTG3/VQ26o760+ZYBPKjPvugO8+nLbYfX6TVpJPseBvopbdY+qpZ/lKUnmEc1LeZYS3QAA==", + "dev": true, + "license": "Apache-2.0" + } + } +} diff --git a/examples/plugins/hello-world-wasm-typescript/package.json b/examples/plugins/hello-world-wasm-typescript/package.json new file mode 100644 index 0000000000..0500ff9d11 --- /dev/null +++ b/examples/plugins/hello-world-wasm-typescript/package.json @@ -0,0 +1,15 @@ +{ + "name": "hello-world-wasm-typescript", + "version": "0.1.0", + "description": "A Bifrost WASM plugin example in TypeScript (AssemblyScript)", + "scripts": { + "build": "asc assembly/index.ts --outFile build/hello-world.wasm --optimize --runtime stub --use abort=", + "build:debug": "asc assembly/index.ts --outFile build/hello-world.wasm --debug --runtime stub --use abort=" + }, + "dependencies": { + "json-as": "^1.0.0" + }, + "devDependencies": { + "assemblyscript": "^0.27.29" + } +} diff --git a/examples/plugins/hello-world/go.mod b/examples/plugins/hello-world/go.mod index 90338cd925..9b77e8116c 100644 --- a/examples/plugins/hello-world/go.mod +++ b/examples/plugins/hello-world/go.mod @@ -2,15 +2,31 @@ module github.com/maximhq/bifrost/examples/plugins/hello-world go 1.25.5 -require github.com/maximhq/bifrost/core v1.2.49 +require github.com/maximhq/bifrost/core v1.3.8 require ( + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic v1.14.2 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect + github.com/google/go-cmp v0.7.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.9.1 // indirect + github.com/mark3labs/mcp-go v0.43.2 // indirect + github.com/spf13/cast v1.10.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.68.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect golang.org/x/arch v0.23.0 // indirect golang.org/x/sys v0.39.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/examples/plugins/hello-world/go.sum b/examples/plugins/hello-world/go.sum index bb5b851623..ca6cb21a1f 100644 --- a/examples/plugins/hello-world/go.sum +++ b/examples/plugins/hello-world/go.sum @@ -1,3 +1,9 @@ +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPIIE= @@ -10,13 +16,38 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= -github.com/maximhq/bifrost/core v1.2.49 h1:fk6l6r3kVBlpN73wYXmgtV6O4bhedOjSO4LAEz/7leg= -github.com/maximhq/bifrost/core v1.2.49/go.mod h1:z7nOx15e91ktZGi+pZHq+uhShlEK+fM4UyYUpP6oHAw= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= +github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= +github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= +github.com/maximhq/bifrost/core v1.3.8 h1:xtwB9+HeTzYz5IKHkpUtupzBd0A5yl1avdLJGjsOKPI= +github.com/maximhq/bifrost/core v1.3.8/go.mod h1:abKQRnJQPZz8/UMxCcbuNHEyq19Db+IX4KlGJdlLY8E= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -29,11 +60,23 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.68.0 h1:v12Nx16iepr8r9ySOwqI+5RBJ/DqTxhOy1HrHoDFnok= +github.com/valyala/fasthttp v1.68.0/go.mod h1:5EXiRfYQAoiO/khu4oU9VISC/eVY6JqmSpPJoHCKsz4= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/arch v0.23.0 h1:lKF64A2jF6Zd8L0knGltUnegD62JMFBiCPBmQpToHhg= golang.org/x/arch v0.23.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/plugins/hello-world/main.go b/examples/plugins/hello-world/main.go index fb9c7d7694..692eb4d8ba 100644 --- a/examples/plugins/hello-world/main.go +++ b/examples/plugins/hello-world/main.go @@ -15,10 +15,14 @@ func GetName() string { return "Hello World Plugin" } -func TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - fmt.Println("TransportInterceptor called") +func HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + fmt.Println("HTTPTransportIntercept called") + // Modify request in-place + req.Headers["x-hello-world-plugin"] = "transport-interceptor-value" + // Store value in context for PreHook/PostHook ctx.SetValue(schemas.BifrostContextKey("hello-world-plugin-transport-interceptor"), "transport-interceptor-value") - return headers, body, nil + // Return nil to continue processing, or return &schemas.HTTPResponse{} to short-circuit + return nil, nil } func PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index 08352f811a..e5d7b13b78 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "encoding/json" "sort" + "strconv" "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" @@ -46,6 +47,9 @@ type ClientConfig struct { AllowedOrigins []string `json:"allowed_origins,omitempty"` // Additional allowed origins for CORS and WebSocket (localhost is always allowed) MaxRequestBodySizeMB int `json:"max_request_body_size_mb"` // The maximum request body size in MB EnableLiteLLMFallbacks bool `json:"enable_litellm_fallbacks"` // Enable litellm-specific fallbacks for text completion for Groq + MCPAgentDepth int `json:"mcp_agent_depth"` // The maximum depth for MCP agent mode tool execution + MCPToolExecutionTimeout int `json:"mcp_tool_execution_timeout"` // The timeout for individual tool execution in seconds + MCPCodeModeBindingLevel string `json:"mcp_code_mode_binding_level"` // Code mode binding level: "server" or "tool" HeaderFilterConfig *tables.GlobalHeaderFilterConfig `json:"header_filter_config,omitempty"` // Global header filtering configuration for x-bf-eh-* headers ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) } @@ -98,6 +102,24 @@ func (c *ClientConfig) GenerateClientConfigHash() (string, error) { hash.Write([]byte("enableLiteLLMFallbacks:false")) } + if c.MCPAgentDepth > 0 { + hash.Write([]byte("mcpAgentDepth:" + strconv.Itoa(c.MCPAgentDepth))) + } else { + hash.Write([]byte("mcpAgentDepth:0")) + } + + if c.MCPToolExecutionTimeout > 0 { + hash.Write([]byte("mcpToolExecutionTimeout:" + strconv.Itoa(c.MCPToolExecutionTimeout))) + } else { + hash.Write([]byte("mcpToolExecutionTimeout:0")) + } + + if c.MCPCodeModeBindingLevel != "" { + hash.Write([]byte("mcpCodeModeBindingLevel:" + c.MCPCodeModeBindingLevel)) + } else { + hash.Write([]byte("mcpCodeModeBindingLevel:server")) + } + // Hash integer fields data, err := sonic.Marshal(c.InitialPoolSize) if err != nil { diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 8f0773b370..3646165c6a 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -3,7 +3,10 @@ package configstore import ( "context" "fmt" + "log" "strconv" + "strings" + "unicode" "github.com/google/uuid" bifrost "github.com/maximhq/bifrost/core" @@ -78,6 +81,12 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationMissingProviderColumnInKeyTable(ctx, db); err != nil { return err } + if err := migrationAddToolsToAutoExecuteJSONColumn(ctx, db); err != nil { + return err + } + if err := migrationAddIsCodeModeClientColumn(ctx, db); err != nil { + return err + } if err := migrationAddLogRetentionDaysColumn(ctx, db); err != nil { return err } @@ -87,6 +96,15 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddBatchAndCachePricingColumns(ctx, db); err != nil { return err } + if err := migrationAddMCPAgentDepthAndMCPToolExecutionTimeoutColumns(ctx, db); err != nil { + return err + } + if err := migrationAddMCPCodeModeBindingLevelColumn(ctx, db); err != nil { + return err + } + if err := migrationNormalizeMCPClientNames(ctx, db); err != nil { + return err + } if err := migrationMoveKeysToProviderConfig(ctx, db); err != nil { return err } @@ -1078,6 +1096,74 @@ func migrationMissingProviderColumnInKeyTable(ctx context.Context, db *gorm.DB) return nil } +// migrationAddToolsToAutoExecuteJSONColumn adds the tools_to_auto_execute_json column to the mcp_client table +func migrationAddToolsToAutoExecuteJSONColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_tools_to_auto_execute_json_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableMCPClient{}, "tools_to_auto_execute_json") { + if err := migrator.AddColumn(&tables.TableMCPClient{}, "tools_to_auto_execute_json"); err != nil { + return err + } + // Initialize existing rows with empty array + if err := tx.Exec("UPDATE config_mcp_clients SET tools_to_auto_execute_json = '[]' WHERE tools_to_auto_execute_json IS NULL OR tools_to_auto_execute_json = ''").Error; err != nil { + return fmt.Errorf("failed to initialize tools_to_auto_execute_json: %w", err) + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if err := migrator.DropColumn(&tables.TableMCPClient{}, "tools_to_auto_execute_json"); err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} + +// migrationAddIsCodeModeClientColumn adds the is_code_mode_client column to the config_mcp_clients table +func migrationAddIsCodeModeClientColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_is_code_mode_client_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableMCPClient{}, "is_code_mode_client") { + if err := migrator.AddColumn(&tables.TableMCPClient{}, "is_code_mode_client"); err != nil { + return err + } + // Initialize existing rows with false (default value) + if err := tx.Exec("UPDATE config_mcp_clients SET is_code_mode_client = false WHERE is_code_mode_client IS NULL").Error; err != nil { + return fmt.Errorf("failed to initialize is_code_mode_client: %w", err) + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if err := migrator.DropColumn(&tables.TableMCPClient{}, "is_code_mode_client"); err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} + // migrationAddLogRetentionDaysColumn adds the log_retention_days column to the client config table func migrationAddLogRetentionDaysColumn(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ @@ -1201,6 +1287,207 @@ func migrationAddBatchAndCachePricingColumns(ctx context.Context, db *gorm.DB) e return m.Migrate() } +func migrationAddMCPAgentDepthAndMCPToolExecutionTimeoutColumns(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_mcp_agent_depth_and_mcp_tool_execution_timeout_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableClientConfig{}, "mcp_agent_depth") { + if err := migrator.AddColumn(&tables.TableClientConfig{}, "mcp_agent_depth"); err != nil { + return err + } + } + if !migrator.HasColumn(&tables.TableClientConfig{}, "mcp_tool_execution_timeout") { + if err := migrator.AddColumn(&tables.TableClientConfig{}, "mcp_tool_execution_timeout"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if err := migrator.DropColumn(&tables.TableClientConfig{}, "mcp_agent_depth"); err != nil { + return err + } + if err := migrator.DropColumn(&tables.TableClientConfig{}, "mcp_tool_execution_timeout"); err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} + +// migrationAddMCPCodeModeBindingLevelColumn adds the mcp_code_mode_binding_level column to the client config table. +// This column stores the code mode binding level preference (server or tool). +func migrationAddMCPCodeModeBindingLevelColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_mcp_code_mode_binding_level_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migratorInstance := tx.Migrator() + if !migratorInstance.HasColumn(&tables.TableClientConfig{}, "mcp_code_mode_binding_level") { + if err := migratorInstance.AddColumn(&tables.TableClientConfig{}, "mcp_code_mode_binding_level"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migratorInstance := tx.Migrator() + if err := migratorInstance.DropColumn(&tables.TableClientConfig{}, "mcp_code_mode_binding_level"); err != nil { + return err + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running db migration: %s", err.Error()) + } + return nil +} + +// normalizeMCPClientName normalizes an MCP client name by: +// 1. Replacing hyphens and spaces with underscores +// 2. Removing leading digits +// 3. Using a default name if the result is empty +func normalizeMCPClientName(name string) string { + // Replace hyphens and spaces with underscores + normalized := strings.ReplaceAll(name, "-", "_") + normalized = strings.ReplaceAll(normalized, " ", "_") + + // Remove leading digits + normalized = strings.TrimLeftFunc(normalized, func(r rune) bool { + return unicode.IsDigit(r) + }) + + // If name becomes empty after normalization, use a default name + if normalized == "" { + normalized = "mcp_client" + } + + return normalized +} + +// migrationNormalizeMCPClientNames normalizes MCP client names by: +// 1. Replacing hyphens and spaces with underscores +// 2. Removing leading digits +// 3. Adding number suffix if name already exists +func migrationNormalizeMCPClientNames(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "normalize_mcp_client_names", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + + // Fetch all MCP clients + var mcpClients []tables.TableMCPClient + if err := tx.Find(&mcpClients).Error; err != nil { + return fmt.Errorf("failed to fetch MCP clients: %w", err) + } + + // Track assigned names in memory to avoid transaction visibility issues + // and ensure we see all updates made during this migration + assignedNames := make(map[string]bool) + + // Helper function to find a unique name + findUniqueName := func(baseName string, originalName string, excludeID uint, tx *gorm.DB, assignedNames map[string]bool) (string, error) { + // First check if base name is already assigned in this migration + if !assignedNames[baseName] { + // Also check database for existing names (excluding current client) + var existing tables.TableMCPClient + err := tx.Where("name = ? AND id != ?", baseName, excludeID).First(&existing).Error + if err == gorm.ErrRecordNotFound { + // Name is available + assignedNames[baseName] = true + // Log normalization even when no collision + if originalName != baseName { + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, baseName) + } + return baseName, nil + } else if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + } + + // Name exists (either assigned in this migration or in database), try with number suffix starting from 2 + // (base name is conceptually "1", so collisions start from "2") + suffix := 2 + const maxSuffix = 1000 // Safety limit to prevent infinite loops + for { + if suffix > maxSuffix { + return "", fmt.Errorf("could not find unique name after %d attempts for base name: %s", maxSuffix, baseName) + } + candidateName := baseName + strconv.Itoa(suffix) + + // Check both in-memory map and database + if !assignedNames[candidateName] { + var existing tables.TableMCPClient + err := tx.Where("name = ? AND id != ?", candidateName, excludeID).First(&existing).Error + if err == gorm.ErrRecordNotFound { + // Found available name - log the transformation + assignedNames[candidateName] = true + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, candidateName) + return candidateName, nil + } else if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + } + suffix++ + } + } + + // Process each client + for _, client := range mcpClients { + originalName := client.Name + needsUpdate := false + + // Check if name needs normalization + if strings.Contains(originalName, "-") || strings.Contains(originalName, " ") { + needsUpdate = true + } else if len(originalName) > 0 && unicode.IsDigit(rune(originalName[0])) { + needsUpdate = true + } + + if needsUpdate { + // Normalize the name + normalizedName := normalizeMCPClientName(originalName) + + // Find a unique name (pass assignedNames map to track names in this migration) + uniqueName, err := findUniqueName(normalizedName, originalName, client.ID, tx, assignedNames) + if err != nil { + return fmt.Errorf("failed to find unique name for client %d (original: %s): %w", client.ID, originalName, err) + } + + // Update the client name + if err := tx.Model(&client).Update("name", uniqueName).Error; err != nil { + return fmt.Errorf("failed to update MCP client %d name from %s to %s: %w", client.ID, originalName, uniqueName, err) + } + } + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + // Rollback is not possible as we don't store the original names + // This migration is one-way + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while running MCP client name normalization migration: %s", err.Error()) + } + return nil +} + // migrationMoveKeysToProviderConfig migrates keys from virtual key level to provider config level func migrationMoveKeysToProviderConfig(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ diff --git a/framework/configstore/migrations_test.go b/framework/configstore/migrations_test.go new file mode 100644 index 0000000000..cada594b38 --- /dev/null +++ b/framework/configstore/migrations_test.go @@ -0,0 +1,539 @@ +package configstore + +import ( + "bytes" + "context" + "fmt" + "log" + "os" + "strconv" + "strings" + "testing" + "time" + + "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +// setupTestDB creates an in-memory SQLite database for testing +func setupTestDB(t *testing.T) *gorm.DB { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + require.NoError(t, err, "Failed to create test database") + + // Create the MCP clients table + err = db.AutoMigrate(&tables.TableMCPClient{}) + require.NoError(t, err, "Failed to migrate test database") + + return db +} + +// captureLogOutput captures log output during a function execution +func captureLogOutput(fn func()) string { + var buf bytes.Buffer + log.SetOutput(&buf) + defer log.SetOutput(os.Stderr) + + fn() + return buf.String() +} + +func TestNormalizeName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "hyphen to underscore", + input: "my-tool", + expected: "my_tool", + }, + { + name: "space to underscore", + input: "my tool", + expected: "my_tool", + }, + { + name: "multiple hyphens", + input: "my-super-tool", + expected: "my_super_tool", + }, + { + name: "multiple spaces", + input: "my super tool", + expected: "my_super_tool", + }, + { + name: "leading digits removed", + input: "123tool", + expected: "tool", + }, + { + name: "leading digits with hyphen", + input: "123my-tool", + expected: "my_tool", + }, + { + name: "empty after normalization", + input: "123", + expected: "mcp_client", + }, + { + name: "no change needed", + input: "my_tool", + expected: "my_tool", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + normalized := normalizeMCPClientName(tt.input) + assert.Equal(t, tt.expected, normalized, "normalizeMCPClientName should produce expected output") + }) + } +} + +func TestFindUniqueName_NoCollision(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Create a test client with a unique name + client := &tables.TableMCPClient{ + Name: "existing_client", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err := db.WithContext(ctx).Create(client).Error + require.NoError(t, err) + + // Test findUniqueName with a different base name (no collision) + logOutput := captureLogOutput(func() { + uniqueName, err := findUniqueNameForTest("new_client", "new_client", 999, db.WithContext(ctx)) + require.NoError(t, err) + assert.Equal(t, "new_client", uniqueName, "Should return base name when no collision") + }) + + // Should not log anything when there's no collision + assert.Empty(t, logOutput, "Should not log when name is available without suffix") +} + +func TestFindUniqueName_WithCollision(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Create existing clients that will cause collisions + // First client with base name + client1 := &tables.TableMCPClient{ + Name: "my_tool", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err := db.WithContext(ctx).Create(client1).Error + require.NoError(t, err) + + // Second client with first suffix + client2 := &tables.TableMCPClient{ + Name: "my_tool1", + ClientID: "client-2", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err = db.WithContext(ctx).Create(client2).Error + require.NoError(t, err) + + // Test findUniqueName with collision - should find "my_tool2" + // excludeID is set to a non-existent ID (999) so all existing clients are considered + var uniqueName string + logOutput := captureLogOutput(func() { + uniqueName, err = findUniqueNameForTest("my_tool", "my-tool", 999, db.WithContext(ctx)) + }) + + require.NoError(t, err) + assert.Equal(t, "my_tool2", uniqueName, "Should return name with suffix when collision occurs") + assert.Contains(t, logOutput, "MCP Client Name Normalized: 'my-tool' -> 'my_tool2'", "Should log the transformation") +} + +func TestFindUniqueName_MultipleCollisions(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Create existing clients that will cause multiple collisions + client1 := &tables.TableMCPClient{ + Name: "test_tool", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err := db.WithContext(ctx).Create(client1).Error + require.NoError(t, err) + + client2 := &tables.TableMCPClient{ + Name: "test_tool1", + ClientID: "client-2", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err = db.WithContext(ctx).Create(client2).Error + require.NoError(t, err) + + client3 := &tables.TableMCPClient{ + Name: "test_tool2", + ClientID: "client-3", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err = db.WithContext(ctx).Create(client3).Error + require.NoError(t, err) + + // Test findUniqueName with multiple collisions - should find "test_tool3" + var uniqueName string + logOutput := captureLogOutput(func() { + uniqueName, err = findUniqueNameForTest("test_tool", "test tool", 999, db.WithContext(ctx)) + }) + + require.NoError(t, err) + assert.Equal(t, "test_tool3", uniqueName, "Should return name with correct suffix after multiple collisions") + assert.Contains(t, logOutput, "MCP Client Name Normalized: 'test tool' -> 'test_tool3'", "Should log the transformation") +} + +func TestFindUniqueName_NormalizationAndCollision(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Create existing client with normalized name + client := &tables.TableMCPClient{ + Name: "my_tool", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err := db.WithContext(ctx).Create(client).Error + require.NoError(t, err) + + // Test that "my-tool" normalizes to "my_tool" and then collides, requiring suffix + var uniqueName string + logOutput := captureLogOutput(func() { + uniqueName, err = findUniqueNameForTest("my_tool", "my-tool", 999, db.WithContext(ctx)) + }) + + require.NoError(t, err) + assert.Equal(t, "my_tool2", uniqueName, "Should handle normalization and collision") + assert.Contains(t, logOutput, "MCP Client Name Normalized: 'my-tool' -> 'my_tool2'", "Should log the full transformation") +} + +func TestFindUniqueName_MultipleNormalizationsToSameBase(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // Test case: 3 entries that normalize to the same base name: + // "mcp client" -> "mcp_client" + // "mcp-client" -> "mcp_client" (collision, becomes "mcp_client2") + // "1mcp-client" -> "mcp_client" (collision, becomes "mcp_client3") + // Note: In the actual migration, names are processed sequentially and each checks + // against all previously created names. To simulate this, we need to create clients + // with the original names first, then normalize them in sequence. + + // Helper function to normalize (same logic as in migrations.go) + normalizeName := func(name string) string { + normalized := strings.ReplaceAll(name, "-", "_") + normalized = strings.ReplaceAll(normalized, " ", "_") + normalized = strings.TrimLeftFunc(normalized, func(r rune) bool { + return r >= '0' && r <= '9' + }) + if normalized == "" { + normalized = "mcp_client" + } + return normalized + } + + // Create three clients with original names (simulating pre-migration state) + clients := []*tables.TableMCPClient{ + { + Name: "mcp client", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + { + Name: "mcp-client", + ClientID: "client-2", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + { + Name: "1mcp-client", + ClientID: "client-3", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + } + + for _, client := range clients { + err := db.WithContext(ctx).Create(client).Error + require.NoError(t, err) + } + + // Now simulate the migration: process each client sequentially + // First: "mcp client" -> "mcp_client" (no collision) + client1 := clients[0] + normalizedName1 := normalizeName(client1.Name) + var uniqueName1 string + var err error + logOutput1 := captureLogOutput(func() { + uniqueName1, err = findUniqueNameForTest(normalizedName1, client1.Name, client1.ID, db.WithContext(ctx)) + }) + require.NoError(t, err) + assert.Equal(t, "mcp_client", uniqueName1, "First normalization should use base name") + assert.Empty(t, logOutput1, "Should not log when name is available without suffix") + + // Update first client + err = db.WithContext(ctx).Model(client1).Update("name", uniqueName1).Error + require.NoError(t, err) + + // Second: "mcp-client" -> "mcp_client" (collision with "mcp_client", becomes "mcp_client2") + // Note: We need to check that "mcp_client" exists (from client1), so it should skip to "mcp_client2" + client2 := clients[1] + normalizedName2 := normalizeName(client2.Name) + var uniqueName2 string + logOutput2 := captureLogOutput(func() { + uniqueName2, err = findUniqueNameForTest(normalizedName2, client2.Name, client2.ID, db.WithContext(ctx)) + }) + require.NoError(t, err) + // With the updated implementation, suffixes start from 2 when base name exists + // So "mcp-client" normalizes to "mcp_client" which collides, becomes "mcp_client2" + assert.Equal(t, "mcp_client2", uniqueName2, "Second normalization should get suffix 2 (skipping 1)") + assert.Contains(t, logOutput2, "MCP Client Name Normalized: 'mcp-client' -> 'mcp_client2'", "Should log the transformation") + + // Update second client + err = db.WithContext(ctx).Model(client2).Update("name", uniqueName2).Error + require.NoError(t, err) + + // Third: "1mcp-client" -> "mcp_client" (collision with "mcp_client" and "mcp_client2", becomes "mcp_client3") + client3 := clients[2] + normalizedName3 := normalizeName(client3.Name) + var uniqueName3 string + logOutput3 := captureLogOutput(func() { + uniqueName3, err = findUniqueNameForTest(normalizedName3, client3.Name, client3.ID, db.WithContext(ctx)) + }) + require.NoError(t, err) + // Third normalization finds "mcp_client" and "mcp_client2" exist, so becomes "mcp_client3" + assert.Equal(t, "mcp_client3", uniqueName3, "Third normalization should get suffix 3") + assert.Contains(t, logOutput3, "MCP Client Name Normalized: '1mcp-client' -> 'mcp_client3'", "Should log the transformation") + + // Update third client + err = db.WithContext(ctx).Model(client3).Update("name", uniqueName3).Error + require.NoError(t, err) + + // Final verification: all three should exist with correct names + var finalClients []tables.TableMCPClient + err = db.WithContext(ctx).Find(&finalClients).Error + require.NoError(t, err) + assert.Len(t, finalClients, 3, "Should have all 3 clients") + + names := make([]string, len(finalClients)) + for i, c := range finalClients { + names[i] = c.Name + } + assert.Contains(t, names, "mcp_client", "Should contain mcp_client") + assert.Contains(t, names, "mcp_client2", "Should contain mcp_client2") + assert.Contains(t, names, "mcp_client3", "Should contain mcp_client3") +} + +func TestFindUniqueName_MigrationScenarioWithInMemoryTracking(t *testing.T) { + db := setupTestDB(t) + ctx := context.Background() + + // This test simulates the exact migration scenario where clients are processed in a loop + // and we need to track assigned names in memory to avoid transaction visibility issues + + // Create three clients with original names (simulating pre-migration state) + clients := []*tables.TableMCPClient{ + { + Name: "mcp client", + ClientID: "client-1", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + { + Name: "mcp-client", + ClientID: "client-2", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + { + Name: "1mcp-client", + ClientID: "client-3", + ConnectionType: "stdio", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + } + + for _, client := range clients { + err := db.WithContext(ctx).Create(client).Error + require.NoError(t, err) + } + + // Simulate the migration: process clients in a loop with in-memory tracking + assignedNames := make(map[string]bool) + normalizeName := func(name string) string { + normalized := strings.ReplaceAll(name, "-", "_") + normalized = strings.ReplaceAll(normalized, " ", "_") + normalized = strings.TrimLeftFunc(normalized, func(r rune) bool { + return r >= '0' && r <= '9' + }) + if normalized == "" { + normalized = "mcp_client" + } + return normalized + } + + var logOutputs []string + for _, client := range clients { + originalName := client.Name + needsUpdate := strings.Contains(originalName, "-") || strings.Contains(originalName, " ") || + (len(originalName) > 0 && originalName[0] >= '0' && originalName[0] <= '9') + + if needsUpdate { + normalizedName := normalizeName(originalName) + uniqueName, err := findUniqueNameForTestWithTracking(normalizedName, originalName, client.ID, db.WithContext(ctx), assignedNames) + require.NoError(t, err) + + // Capture log output + logOutput := captureLogOutput(func() { + // Log if name changed + if originalName != uniqueName { + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, uniqueName) + } + }) + if logOutput != "" { + logOutputs = append(logOutputs, logOutput) + } + + // Update client + err = db.WithContext(ctx).Model(client).Update("name", uniqueName).Error + require.NoError(t, err) + } + } + + // Verify all three clients have correct names + var finalClients []tables.TableMCPClient + err := db.WithContext(ctx).Find(&finalClients).Error + require.NoError(t, err) + assert.Len(t, finalClients, 3, "Should have all 3 clients") + + names := make([]string, len(finalClients)) + for i, c := range finalClients { + names[i] = c.Name + } + assert.Contains(t, names, "mcp_client", "Should contain mcp_client") + assert.Contains(t, names, "mcp_client2", "Should contain mcp_client2") + assert.Contains(t, names, "mcp_client3", "Should contain mcp_client3") + + // Verify logging: should log all three transformations + allLogs := strings.Join(logOutputs, "") + assert.Contains(t, allLogs, "MCP Client Name Normalized: 'mcp client' -> 'mcp_client'", "Should log first normalization") + assert.Contains(t, allLogs, "MCP Client Name Normalized: 'mcp-client' -> 'mcp_client2'", "Should log second normalization") + assert.Contains(t, allLogs, "MCP Client Name Normalized: '1mcp-client' -> 'mcp_client3'", "Should log third normalization") +} + +// findUniqueNameForTestWithTracking is a test helper that tracks assigned names in memory +func findUniqueNameForTestWithTracking(baseName string, originalName string, excludeID uint, tx *gorm.DB, assignedNames map[string]bool) (string, error) { + // First check if base name is already assigned in this migration + if !assignedNames[baseName] { + // Also check database for existing names (excluding current client) + var count int64 + err := tx.Model(&tables.TableMCPClient{}).Where("name = ? AND id != ?", baseName, excludeID).Count(&count).Error + if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + if count == 0 { + // Name is available + assignedNames[baseName] = true + // Log normalization even when no collision + if originalName != baseName { + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, baseName) + } + return baseName, nil + } + } + + // Name exists (either assigned in this migration or in database), try with number suffix starting from 2 + suffix := 2 + const maxSuffix = 1000 + for { + if suffix > maxSuffix { + return "", fmt.Errorf("could not find unique name after %d attempts for base name: %s", maxSuffix, baseName) + } + candidateName := baseName + strconv.Itoa(suffix) + + // Check both in-memory map and database + if !assignedNames[candidateName] { + var count int64 + err := tx.Model(&tables.TableMCPClient{}).Where("name = ? AND id != ?", candidateName, excludeID).Count(&count).Error + if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + if count == 0 { + // Found available name + assignedNames[candidateName] = true + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, candidateName) + return candidateName, nil + } + } + suffix++ + } +} + +// findUniqueNameForTest is a test helper that extracts the findUniqueName logic +// This mirrors the implementation in migrations.go for testing +func findUniqueNameForTest(baseName string, originalName string, excludeID uint, tx *gorm.DB) (string, error) { + // First, try the base name + var count int64 + err := tx.Model(&tables.TableMCPClient{}).Where("name = ? AND id != ?", baseName, excludeID).Count(&count).Error + if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + if count == 0 { + // Name is available + return baseName, nil + } + + // Name exists, try with number suffix starting from 2 + // (base name is conceptually "1", so collisions start from "2") + suffix := 2 + const maxSuffix = 1000 // Safety limit to prevent infinite loops + for { + if suffix > maxSuffix { + return "", fmt.Errorf("could not find unique name after %d attempts for base name: %s", maxSuffix, baseName) + } + candidateName := baseName + strconv.Itoa(suffix) + err := tx.Model(&tables.TableMCPClient{}).Where("name = ? AND id != ?", candidateName, excludeID).Count(&count).Error + if err != nil { + return "", fmt.Errorf("failed to check name availability: %w", err) + } + if count == 0 { + // Found available name - log the transformation + log.Printf("MCP Client Name Normalized: '%s' -> '%s'", originalName, candidateName) + return candidateName, nil + } + suffix++ + } +} diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index bb4444a73a..60846eb5ae 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "strings" + "time" "github.com/bytedance/sonic" bifrost "github.com/maximhq/bifrost/core" @@ -50,7 +51,11 @@ func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientC AllowedOrigins: config.AllowedOrigins, MaxRequestBodySizeMB: config.MaxRequestBodySizeMB, EnableLiteLLMFallbacks: config.EnableLiteLLMFallbacks, + MCPAgentDepth: config.MCPAgentDepth, + MCPToolExecutionTimeout: config.MCPToolExecutionTimeout, + MCPCodeModeBindingLevel: config.MCPCodeModeBindingLevel, HeaderFilterConfig: config.HeaderFilterConfig, + ConfigHash: config.ConfigHash, } // Delete existing client config and create new one in a transaction return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { @@ -168,7 +173,6 @@ func (s *RDBConfigStore) UpdateFrameworkConfig(ctx context.Context, config *tabl } return tx.Create(config).Error }) - } // GetFrameworkConfig retrieves the framework configuration from the database. @@ -205,7 +209,11 @@ func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, er AllowedOrigins: dbConfig.AllowedOrigins, MaxRequestBodySizeMB: dbConfig.MaxRequestBodySizeMB, EnableLiteLLMFallbacks: dbConfig.EnableLiteLLMFallbacks, + MCPAgentDepth: dbConfig.MCPAgentDepth, + MCPToolExecutionTimeout: dbConfig.MCPToolExecutionTimeout, + MCPCodeModeBindingLevel: dbConfig.MCPCodeModeBindingLevel, HeaderFilterConfig: dbConfig.HeaderFilterConfig, + ConfigHash: dbConfig.ConfigHash, }, nil } @@ -609,7 +617,7 @@ func (s *RDBConfigStore) DeleteProvider(ctx context.Context, provider schemas.Mo return err } - // Delete the provider (keys will be deleted due to CASCADE constraint) + // Delete the provider first (keys will be deleted due to CASCADE constraint) if err := txDB.WithContext(ctx).Delete(&dbProvider).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound @@ -624,9 +632,6 @@ func (s *RDBConfigStore) DeleteProvider(ctx context.Context, provider schemas.Mo func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.ModelProvider]ProviderConfig, error) { var dbProviders []tables.TableProvider if err := s.db.WithContext(ctx).Preload("Keys").Find(&dbProviders).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrNotFound - } return nil, err } if len(dbProviders) == 0 { @@ -774,17 +779,40 @@ func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, } clientConfigs[i] = schemas.MCPClientConfig{ - ID: dbClient.ClientID, - Name: dbClient.Name, - ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), - ConnectionString: processedConnectionString, - StdioConfig: dbClient.StdioConfig, - ToolsToExecute: dbClient.ToolsToExecute, - Headers: processedHeaders, + ID: dbClient.ClientID, + Name: dbClient.Name, + IsCodeModeClient: dbClient.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), + ConnectionString: processedConnectionString, + StdioConfig: dbClient.StdioConfig, + ToolsToExecute: dbClient.ToolsToExecute, + ToolsToAutoExecute: dbClient.ToolsToAutoExecute, + Headers: processedHeaders, + } + } + var clientConfig tables.TableClientConfig + if err := s.db.WithContext(ctx).First(&clientConfig).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + // Return MCP config with default ToolManagerConfig if no client config exists + // This will never happen, but just in case. + return &schemas.MCPConfig{ + ClientConfigs: clientConfigs, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + ToolExecutionTimeout: 30 * time.Second, // default from TableClientConfig + MaxAgentDepth: 10, // default from TableClientConfig + }, + }, nil } + return nil, err + } + toolManagerConfig := schemas.MCPToolManagerConfig{ + ToolExecutionTimeout: time.Duration(clientConfig.MCPToolExecutionTimeout) * time.Second, + MaxAgentDepth: clientConfig.MCPAgentDepth, + CodeModeBindingLevel: schemas.CodeModeBindingLevel(clientConfig.MCPCodeModeBindingLevel), } return &schemas.MCPConfig{ - ClientConfigs: clientConfigs, + ClientConfigs: clientConfigs, + ToolManagerConfig: &toolManagerConfig, }, nil } @@ -810,17 +838,20 @@ func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig } // Substitute environment variables back to their original form - substituteMCPClientEnvVars(&clientConfigCopy, envKeys) + // For create operations, no existing headers to restore from + substituteMCPClientEnvVars(&clientConfigCopy, envKeys, nil) // Create new client dbClient := tables.TableMCPClient{ - ClientID: clientConfigCopy.ID, - Name: clientConfigCopy.Name, - ConnectionType: string(clientConfigCopy.ConnectionType), - ConnectionString: clientConfigCopy.ConnectionString, - StdioConfig: clientConfigCopy.StdioConfig, - ToolsToExecute: clientConfigCopy.ToolsToExecute, - Headers: clientConfigCopy.Headers, + ClientID: clientConfigCopy.ID, + Name: clientConfigCopy.Name, + IsCodeModeClient: clientConfigCopy.IsCodeModeClient, + ConnectionType: string(clientConfigCopy.ConnectionType), + ConnectionString: clientConfigCopy.ConnectionString, + StdioConfig: clientConfigCopy.StdioConfig, + ToolsToExecute: clientConfigCopy.ToolsToExecute, + ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute, + Headers: clientConfigCopy.Headers, } if err := tx.WithContext(ctx).Create(&dbClient).Error; err != nil { @@ -849,17 +880,20 @@ func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, c } // Substitute environment variables back to their original form - substituteMCPClientEnvVars(&clientConfigCopy, envKeys) + // Pass existing headers to restore redacted plain values + substituteMCPClientEnvVars(&clientConfigCopy, envKeys, existingClient.Headers) // Update existing client existingClient.Name = clientConfigCopy.Name - existingClient.ConnectionType = string(clientConfigCopy.ConnectionType) - existingClient.ConnectionString = clientConfigCopy.ConnectionString - existingClient.StdioConfig = clientConfigCopy.StdioConfig + existingClient.IsCodeModeClient = clientConfigCopy.IsCodeModeClient existingClient.ToolsToExecute = clientConfigCopy.ToolsToExecute + existingClient.ToolsToAutoExecute = clientConfigCopy.ToolsToAutoExecute existingClient.Headers = clientConfigCopy.Headers - if err := tx.WithContext(ctx).Updates(&existingClient).Error; err != nil { + // Use Select to explicitly include IsCodeModeClient even when it's false (zero value) + // GORM's Updates() skips zero values by default, so we need to explicitly select fields + // Using struct field names - GORM will convert them to column names automatically + if err := tx.WithContext(ctx).Select("name", "is_code_mode_client", "tools_to_execute_json", "tools_to_auto_execute_json", "headers_json", "updated_at").Updates(&existingClient).Error; err != nil { return s.parseGormError(err) } return nil @@ -968,9 +1002,6 @@ func (s *RDBConfigStore) UpdateLogsStoreConfig(ctx context.Context, config *logs func (s *RDBConfigStore) GetEnvKeys(ctx context.Context) (map[string][]EnvKeyInfo, error) { var dbEnvKeys []tables.TableEnvKey if err := s.db.WithContext(ctx).Find(&dbEnvKeys).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrNotFound - } return nil, err } envKeys := make(map[string][]EnvKeyInfo) @@ -1405,7 +1436,80 @@ func (s *RDBConfigStore) GetAllRedactedKeys(ctx context.Context, ids []string) ( // DeleteVirtualKey deletes a virtual key from the database. func (s *RDBConfigStore) DeleteVirtualKey(ctx context.Context, id string) error { - return s.db.WithContext(ctx).Delete(&tables.TableVirtualKey{}, "id = ?", id).Error + if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var virtualKey tables.TableVirtualKey + if err := tx.WithContext(ctx).Preload("ProviderConfigs").First(&virtualKey, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + + // Collect budget and rate limit IDs from provider configs before deletion + var providerConfigBudgetIDs []string + var providerConfigRateLimitIDs []string + for _, pc := range virtualKey.ProviderConfigs { + // Delete the keys join table entries + if err := tx.WithContext(ctx).Exec("DELETE FROM governance_virtual_key_provider_config_keys WHERE table_virtual_key_provider_config_id = ?", pc.ID).Error; err != nil { + return err + } + // Collect budget and rate limit IDs for deletion after provider config + if pc.BudgetID != nil { + providerConfigBudgetIDs = append(providerConfigBudgetIDs, *pc.BudgetID) + } + if pc.RateLimitID != nil { + providerConfigRateLimitIDs = append(providerConfigRateLimitIDs, *pc.RateLimitID) + } + } + + // Delete all provider configs associated with the virtual key first + if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "virtual_key_id = ?", id).Error; err != nil { + return err + } + // Now delete the collected budgets and rate limits + for _, budgetID := range providerConfigBudgetIDs { + if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", budgetID).Error; err != nil { + return err + } + } + for _, rateLimitID := range providerConfigRateLimitIDs { + if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", rateLimitID).Error; err != nil { + return err + } + } + // Delete all MCP configs associated with the virtual key + if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKeyMCPConfig{}, "virtual_key_id = ?", id).Error; err != nil { + return err + } + // Delete the budget associated with the virtual key + budgetID := virtualKey.BudgetID + rateLimitID := virtualKey.RateLimitID + // Delete the virtual key + if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKey{}, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + if budgetID != nil { + if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { + return err + } + } + // Delete the rate limit associated with the virtual key + if rateLimitID != nil { + if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil { + return err + } + } + return nil + }); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + return nil } // GetVirtualKeyProviderConfigs retrieves all virtual key provider configs from the database. @@ -1581,7 +1685,34 @@ func (s *RDBConfigStore) DeleteVirtualKeyProviderConfig(ctx context.Context, id } else { txDB = s.db } - return txDB.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "id = ?", id).Error + // First fetch the provider config to get budget and rate limit IDs + var providerConfig tables.TableVirtualKeyProviderConfig + if err := txDB.WithContext(ctx).First(&providerConfig, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + // Store the budget and rate limit IDs before deleting + budgetID := providerConfig.BudgetID + rateLimitID := providerConfig.RateLimitID + // Delete the provider config first + if err := txDB.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "id = ?", id).Error; err != nil { + return err + } + // Delete the budget if it exists + if budgetID != nil { + if err := txDB.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { + return err + } + } + // Delete the rate limit if it exists + if rateLimitID != nil { + if err := txDB.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil { + return err + } + } + return nil } // GetVirtualKeyMCPConfigs retrieves all virtual key MCP configs from the database. @@ -1652,9 +1783,6 @@ func (s *RDBConfigStore) GetTeams(ctx context.Context, customerID string) ([]tab } var teams []tables.TableTeam if err := query.Find(&teams).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrNotFound - } return nil, err } return teams, nil @@ -1702,16 +1830,47 @@ func (s *RDBConfigStore) UpdateTeam(ctx context.Context, team *tables.TableTeam, // DeleteTeam deletes a team from the database. func (s *RDBConfigStore) DeleteTeam(ctx context.Context, id string) error { - return s.db.WithContext(ctx).Delete(&tables.TableTeam{}, "id = ?", id).Error + if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var team tables.TableTeam + if err := tx.WithContext(ctx).Preload("Budget").First(&team, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + // Set team_id to null for all virtual keys associated with the team + if err := tx.WithContext(ctx).Model(&tables.TableVirtualKey{}).Where("team_id = ?", id).Update("team_id", nil).Error; err != nil { + return err + } + // Store the budget ID before deleting the team + budgetID := team.BudgetID + // Delete the team first + if err := tx.WithContext(ctx).Delete(&tables.TableTeam{}, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + // Delete the team's budget if it exists + if budgetID != nil { + if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { + return err + } + } + return nil + }); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + return nil } // GetCustomers retrieves all customers from the database. func (s *RDBConfigStore) GetCustomers(ctx context.Context) ([]tables.TableCustomer, error) { var customers []tables.TableCustomer if err := s.db.WithContext(ctx).Preload("Teams").Preload("Budget").Find(&customers).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrNotFound - } return nil, err } return customers, nil @@ -1759,7 +1918,54 @@ func (s *RDBConfigStore) UpdateCustomer(ctx context.Context, customer *tables.Ta // DeleteCustomer deletes a customer from the database. func (s *RDBConfigStore) DeleteCustomer(ctx context.Context, id string) error { - return s.db.WithContext(ctx).Delete(&tables.TableCustomer{}, "id = ?", id).Error + if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var customer tables.TableCustomer + if err := tx.WithContext(ctx).Preload("Budget").First(&customer, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + // Set customer_id to null for all virtual keys associated with the customer + if err := tx.WithContext(ctx).Model(&tables.TableVirtualKey{}).Where("customer_id = ?", id).Update("customer_id", nil).Error; err != nil { + return err + } + // Set customer_id to null for all teams associated with the customer + if err := tx.WithContext(ctx).Model(&tables.TableTeam{}).Where("customer_id = ?", id).Update("customer_id", nil).Error; err != nil { + return err + } + // Store the budget ID before deleting the customer + budgetID := customer.BudgetID + // Delete the customer first + if err := tx.WithContext(ctx).Delete(&tables.TableCustomer{}, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + // Delete the customer's budget if it exists + if budgetID != nil { + if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { + return err + } + } + return nil + }); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + return nil +} + +// GetRateLimits retrieves all rate limits from the database. +func (s *RDBConfigStore) GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error) { + var rateLimits []tables.TableRateLimit + if err := s.db.WithContext(ctx).Find(&rateLimits).Error; err != nil { + return nil, err + } + return rateLimits, nil } // GetRateLimit retrieves a specific rate limit from the database. @@ -1822,9 +2028,6 @@ func (s *RDBConfigStore) UpdateRateLimits(ctx context.Context, rateLimits []*tab func (s *RDBConfigStore) GetBudgets(ctx context.Context) ([]tables.TableBudget, error) { var budgets []tables.TableBudget if err := s.db.WithContext(ctx).Find(&budgets).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrNotFound - } return nil, err } return budgets, nil @@ -2159,6 +2362,33 @@ func (s *RDBConfigStore) ExecuteTransaction(ctx context.Context, fn func(tx *gor return s.db.WithContext(ctx).Transaction(fn) } +// RetryOnNotFound retries a function up to 3 times with 1-second delays if it returns ErrNotFound +func (s *RDBConfigStore) RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error) { + var lastErr error + for attempt := range maxRetries { + result, err := fn(ctx) + if err == nil { + return result, nil + } + if !errors.Is(err, ErrNotFound) && !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, err + } + + lastErr = err + + // Don't wait after the last attempt + if attempt < maxRetries-1 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(retryDelay): + // Continue to next retry + } + } + } + return nil, lastErr +} + // doesTableExist checks if a table exists in the database. func (s *RDBConfigStore) doesTableExist(ctx context.Context, tableName string) bool { return s.db.WithContext(ctx).Migrator().HasTable(tableName) diff --git a/framework/configstore/store.go b/framework/configstore/store.go index 781e72014f..7a5393e2b4 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -4,6 +4,7 @@ package configstore import ( "context" "fmt" + "time" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore/tables" @@ -100,6 +101,7 @@ type ConfigStore interface { DeleteCustomer(ctx context.Context, id string) error // Rate limit CRUD + GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error) GetRateLimit(ctx context.Context, id string) (*tables.TableRateLimit, error) CreateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error UpdateRateLimit(ctx context.Context, rateLimit *tables.TableRateLimit, tx ...*gorm.DB) error @@ -147,6 +149,9 @@ type ConfigStore interface { // Generic transaction manager ExecuteTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error + // Not found retry wrapper + RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error) + // DB returns the underlying database connection. DB() *gorm.DB diff --git a/framework/configstore/tables/budget.go b/framework/configstore/tables/budget.go index 1744363b28..543d0cb577 100644 --- a/framework/configstore/tables/budget.go +++ b/framework/configstore/tables/budget.go @@ -21,17 +21,20 @@ type TableBudget struct { CreatedAt time.Time `gorm:"index;not null" json:"created_at"` UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + + // Virtual fields for runtime use (not stored in DB) + LastDBUsage float64 `gorm:"-" json:"-"` } // TableName sets the table name for each model func (TableBudget) TableName() string { return "governance_budgets" } // BeforeSave hook for Budget to validate reset duration format and max limit -func (b *TableBudget) BeforeSave(tx *gorm.DB) error { +func (b *TableBudget) BeforeSave(tx *gorm.DB) error { // Validate that ResetDuration is in correct format (e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y") if d, err := ParseDuration(b.ResetDuration); err != nil { return fmt.Errorf("invalid reset duration format: %s", b.ResetDuration) - }else if d <= 0 { + } else if d <= 0 { return fmt.Errorf("reset duration must be > 0: %s", b.ResetDuration) } // Validate that MaxLimit is not negative (budgets should be positive) @@ -41,3 +44,9 @@ func (b *TableBudget) BeforeSave(tx *gorm.DB) error { return nil } + +// AfterFind hook for Budget to set the LastDBUsage virtual field +func (b *TableBudget) AfterFind(tx *gorm.DB) error { + b.LastDBUsage = b.CurrentUsage + return nil +} \ No newline at end of file diff --git a/framework/configstore/tables/clientconfig.go b/framework/configstore/tables/clientconfig.go index 8ecf45acd6..138024b21c 100644 --- a/framework/configstore/tables/clientconfig.go +++ b/framework/configstore/tables/clientconfig.go @@ -16,12 +16,16 @@ type TableClientConfig struct { HeaderFilterConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized GlobalHeaderFilterConfig InitialPoolSize int `gorm:"default:300" json:"initial_pool_size"` EnableLogging bool `gorm:"" json:"enable_logging"` - DisableContentLogging bool `gorm:"default:false" json:"disable_content_logging"` // DisableContentLogging controls whether sensitive content (inputs, outputs, embeddings, etc.) is logged - LogRetentionDays int `gorm:"default:365" json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day) + DisableContentLogging bool `gorm:"default:false" json:"disable_content_logging"` // DisableContentLogging controls whether sensitive content (inputs, outputs, embeddings, etc.) is logged + LogRetentionDays int `gorm:"default:365" json:"log_retention_days" validate:"min=1"` // Number of days to retain logs (minimum 1 day) EnableGovernance bool `gorm:"" json:"enable_governance"` EnforceGovernanceHeader bool `gorm:"" json:"enforce_governance_header"` AllowDirectKeys bool `gorm:"" json:"allow_direct_keys"` MaxRequestBodySizeMB int `gorm:"default:100" json:"max_request_body_size_mb"` + MCPAgentDepth int `gorm:"default:10" json:"mcp_agent_depth"` + MCPToolExecutionTimeout int `gorm:"default:30" json:"mcp_tool_execution_timeout"` // Timeout for individual tool execution in seconds (default: 30) + MCPCodeModeBindingLevel string `gorm:"default:server" json:"mcp_code_mode_binding_level"` // How tools are exposed in VFS: "server" or "tool" + // LiteLLM fallback flag EnableLiteLLMFallbacks bool `gorm:"column:enable_litellm_fallbacks;default:false" json:"enable_litellm_fallbacks"` diff --git a/framework/configstore/tables/mcp.go b/framework/configstore/tables/mcp.go index f5c2381a6a..687c60355a 100644 --- a/framework/configstore/tables/mcp.go +++ b/framework/configstore/tables/mcp.go @@ -10,14 +10,16 @@ import ( // TableMCPClient represents an MCP client configuration in the database type TableMCPClient struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present. - ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"` - Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"` - ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType - ConnectionString *string `gorm:"type:text" json:"connection_string,omitempty"` - StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig - ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string - HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present. + ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"` + Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"` + IsCodeModeClient bool `gorm:"default:false" json:"is_code_mode_client"` // Whether the client is a code mode client + ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType + ConnectionString *string `gorm:"type:text" json:"connection_string,omitempty"` + StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig + ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + ToolsToAutoExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string // Config hash is used to detect the changes synced from config.json file // Every time we sync the config.json file, we will update the config hash @@ -27,9 +29,10 @@ type TableMCPClient struct { UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` // Virtual fields for runtime use (not stored in DB) - StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"` - ToolsToExecute []string `gorm:"-" json:"tools_to_execute"` - Headers map[string]string `gorm:"-" json:"headers"` + StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"` + ToolsToExecute []string `gorm:"-" json:"tools_to_execute"` + ToolsToAutoExecute []string `gorm:"-" json:"tools_to_auto_execute"` + Headers map[string]string `gorm:"-" json:"headers"` } // TableName sets the table name for each model @@ -57,6 +60,16 @@ func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error { c.ToolsToExecuteJSON = "[]" } + if c.ToolsToAutoExecute != nil { + data, err := json.Marshal(c.ToolsToAutoExecute) + if err != nil { + return err + } + c.ToolsToAutoExecuteJSON = string(data) + } else { + c.ToolsToAutoExecuteJSON = "[]" + } + if c.Headers != nil { data, err := json.Marshal(c.Headers) if err != nil { @@ -66,7 +79,6 @@ func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error { } else { c.HeadersJSON = "{}" } - return nil } @@ -86,6 +98,12 @@ func (c *TableMCPClient) AfterFind(tx *gorm.DB) error { } } + if c.ToolsToAutoExecuteJSON != "" { + if err := json.Unmarshal([]byte(c.ToolsToAutoExecuteJSON), &c.ToolsToAutoExecute); err != nil { + return err + } + } + if c.HeadersJSON != "" { if err := json.Unmarshal([]byte(c.HeadersJSON), &c.Headers); err != nil { return err diff --git a/framework/configstore/tables/ratelimit.go b/framework/configstore/tables/ratelimit.go index 7147e7b89f..1a46c690e3 100644 --- a/framework/configstore/tables/ratelimit.go +++ b/framework/configstore/tables/ratelimit.go @@ -29,6 +29,10 @@ type TableRateLimit struct { CreatedAt time.Time `gorm:"index;not null" json:"created_at"` UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + + // Virtual fields for runtime use (not stored in DB) + LastDBTokenUsage int64 `gorm:"-" json:"-"` + LastDBRequestUsage int64 `gorm:"-" json:"-"` } // TableName sets the table name for each model @@ -75,3 +79,10 @@ func (rl *TableRateLimit) BeforeSave(tx *gorm.DB) error { return nil } + +// AfterFind hook for RateLimit to set the LastDBTokenUsage and LastDBRequestUsage virtual fields +func (rl *TableRateLimit) AfterFind(tx *gorm.DB) error { + rl.LastDBTokenUsage = rl.TokenCurrentUsage + rl.LastDBRequestUsage = rl.RequestCurrentUsage + return nil +} \ No newline at end of file diff --git a/framework/configstore/utils.go b/framework/configstore/utils.go index 78f0e133eb..24f1f636cc 100644 --- a/framework/configstore/utils.go +++ b/framework/configstore/utils.go @@ -183,32 +183,59 @@ func substituteMCPEnvVars(config *schemas.MCPConfig, envKeys map[string][]EnvKey } // substituteMCPClientEnvVars replaces resolved environment variable values with their original env.VAR_NAME references for a single MCP client config -func substituteMCPClientEnvVars(clientConfig *schemas.MCPClientConfig, envKeys map[string][]EnvKeyInfo) { +// If existingHeaders is provided, it will restore redacted plain header values from the existing headers before substitution +func substituteMCPClientEnvVars(clientConfig *schemas.MCPClientConfig, envKeys map[string][]EnvKeyInfo, existingHeaders map[string]string) { + // First, restore redacted plain header values from existing headers if provided + // This handles the case where UI sends redacted headers that aren't env vars + if existingHeaders != nil && clientConfig.Headers != nil { + for header, value := range clientConfig.Headers { + // Check if the value is redacted (contains **** pattern) and not an env var + if strings.Contains(value, "****") && !strings.HasPrefix(value, "env.") { + // If header exists in existing headers and wasn't an env var, restore it + if oldHeaderValue, exists := existingHeaders[header]; exists { + if !strings.HasPrefix(oldHeaderValue, "env.") { + clientConfig.Headers[header] = oldHeaderValue + } + } + } + } + } + // Find the environment variable for this client's connection string and headers for envVar, keyInfos := range envKeys { for _, keyInfo := range keyInfos { // For MCP connection strings if keyInfo.KeyType == "connection_string" { - // Extract client name from config path like "mcp.client_configs.clientName.connection_string" + // Extract client ID from config path like "mcp.client_configs.clientID.connection_string" pathParts := strings.Split(keyInfo.ConfigPath, ".") if len(pathParts) >= 3 && pathParts[0] == "mcp" && pathParts[1] == "client_configs" { - clientName := pathParts[2] - // If this environment variable is for the current client - if clientName == clientConfig.Name && clientConfig.ConnectionString != nil { + clientID := pathParts[2] + // If this environment variable is for the current client (match by ID) + if clientID == clientConfig.ID && clientConfig.ConnectionString != nil { clientConfig.ConnectionString = &[]string{fmt.Sprintf("env.%s", envVar)}[0] } } } // For MCP headers if keyInfo.KeyType == "mcp_header" { - // Extract client name and header name from config path like "mcp.client_configs.clientName.headers.headerName" + // Extract client ID and header name from config path like "mcp.client_configs.clientID.headers.headerName" pathParts := strings.Split(keyInfo.ConfigPath, ".") if len(pathParts) >= 5 && pathParts[0] == "mcp" && pathParts[1] == "client_configs" && pathParts[3] == "headers" { - clientName := pathParts[2] + clientID := pathParts[2] headerName := pathParts[4] - // If this environment variable is for the current client - if clientName == clientConfig.Name && clientConfig.Headers != nil { - clientConfig.Headers[headerName] = fmt.Sprintf("env.%s", envVar) + // If this environment variable is for the current client (match by ID) + if clientID == clientConfig.ID && clientConfig.Headers != nil { + if headerValue, exists := clientConfig.Headers[headerName]; exists { + // If it's already in env.VAR format, update to use the correct env var + if strings.HasPrefix(headerValue, "env.") { + clientConfig.Headers[headerName] = fmt.Sprintf("env.%s", envVar) + } else if strings.Contains(headerValue, "****") { + // If it's redacted (contains ****), restore to env.VAR format + // This handles the case where UI sends redacted headers back for env vars + clientConfig.Headers[headerName] = fmt.Sprintf("env.%s", envVar) + } + // If it's a plain value (not env. and not redacted), leave it as-is + } } } } diff --git a/framework/go.mod b/framework/go.mod index e8fb900874..68852cd826 100644 --- a/framework/go.mod +++ b/framework/go.mod @@ -4,7 +4,7 @@ go 1.25.5 require ( github.com/google/uuid v1.6.0 - github.com/maximhq/bifrost/core v1.2.49 + github.com/maximhq/bifrost/core v1.3.8 github.com/qdrant/go-client v1.16.2 github.com/redis/go-redis/v9 v9.17.2 github.com/stretchr/testify v1.11.1 @@ -28,6 +28,9 @@ require ( github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 // indirect github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect + github.com/clarkmcc/go-typescript v0.7.0 // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/swag/cmdutils v0.25.4 // indirect @@ -41,8 +44,10 @@ require ( github.com/go-openapi/swag/stringutils v0.25.4 // indirect github.com/go-openapi/swag/typeutils v0.25.4 // indirect github.com/go-openapi/swag/yamlutils v0.25.4 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.7.6 // indirect diff --git a/framework/go.sum b/framework/go.sum index 6ac5a353e0..7ce6efcfe9 100644 --- a/framework/go.sum +++ b/framework/go.sum @@ -12,6 +12,8 @@ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= +github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= @@ -68,6 +70,8 @@ github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2N github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -77,6 +81,10 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -132,6 +140,8 @@ github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6 github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= github.com/go-openapi/validate v0.25.1 h1:sSACUI6Jcnbo5IWqbYHgjibrhhmt3vR6lCzKZnmAgBw= github.com/go-openapi/validate v0.25.1/go.mod h1:RMVyVFYte0gbSTaZ0N4KmTn6u/kClvAFp+mAVfS/DQc= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -141,6 +151,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= @@ -184,8 +196,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.2.49 h1:fk6l6r3kVBlpN73wYXmgtV6O4bhedOjSO4LAEz/7leg= -github.com/maximhq/bifrost/core v1.2.49/go.mod h1:z7nOx15e91ktZGi+pZHq+uhShlEK+fM4UyYUpP6oHAw= +github.com/maximhq/bifrost/core v1.3.8 h1:xtwB9+HeTzYz5IKHkpUtupzBd0A5yl1avdLJGjsOKPI= +github.com/maximhq/bifrost/core v1.3.8/go.mod h1:abKQRnJQPZz8/UMxCcbuNHEyq19Db+IX4KlGJdlLY8E= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= @@ -281,6 +293,8 @@ google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index ad3997e1fd..64bba62ffd 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -29,6 +29,8 @@ type ModelCatalog struct { pricingSyncInterval time.Duration pricingMu sync.RWMutex + shouldSyncPricingFunc ShouldSyncPricingFunc + // In-memory cache for fast access - direct map for O(1) lookups pricingData map[string]configstoreTables.TableModelPricing mu sync.RWMutex @@ -76,8 +78,14 @@ type PricingEntry struct { OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"` } +// ShouldSyncPricingFunc is a function that determines if pricing data should be synced +// It returns a boolean indicating if syncing is needed +// It is completely optional and can be nil if not needed +// syncPricing function will be called if this function returns true +type ShouldSyncPricingFunc func(ctx context.Context) bool + // Init initializes the pricing manager -func Init(ctx context.Context, config *Config, configStore configstore.ConfigStore, logger schemas.Logger) (*ModelCatalog, error) { +func Init(ctx context.Context, config *Config, configStore configstore.ConfigStore, shouldSyncPricingFunc ShouldSyncPricingFunc, logger schemas.Logger) (*ModelCatalog, error) { // Initialize pricing URL and sync interval pricingURL := DefaultPricingURL if config.PricingURL != nil { @@ -89,13 +97,14 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto } mc := &ModelCatalog{ - pricingURL: pricingURL, - pricingSyncInterval: pricingSyncInterval, - configStore: configStore, - logger: logger, - pricingData: make(map[string]configstoreTables.TableModelPricing), - modelPool: make(map[schemas.ModelProvider][]string), - done: make(chan struct{}), + pricingURL: pricingURL, + pricingSyncInterval: pricingSyncInterval, + configStore: configStore, + logger: logger, + pricingData: make(map[string]configstoreTables.TableModelPricing), + modelPool: make(map[schemas.ModelProvider][]string), + done: make(chan struct{}), + shouldSyncPricingFunc: shouldSyncPricingFunc, } logger.Info("initializing pricing manager...") diff --git a/framework/modelcatalog/sync.go b/framework/modelcatalog/sync.go index 7f81cbae67..ac94d7e73a 100644 --- a/framework/modelcatalog/sync.go +++ b/framework/modelcatalog/sync.go @@ -58,6 +58,13 @@ func (mc *ModelCatalog) shouldSyncPricing(ctx context.Context) (bool, string) { func (mc *ModelCatalog) syncPricing(ctx context.Context) error { mc.logger.Debug("starting pricing data synchronization for governance") + if mc.shouldSyncPricingFunc != nil { + if !mc.shouldSyncPricingFunc(ctx) { + mc.logger.Debug("pricing sync cancelled by custom function") + return nil + } + } + // Load pricing data from URL pricingData, err := mc.loadPricingFromURL(ctx) if err != nil { diff --git a/framework/plugins/dynamicplugin.go b/framework/plugins/dynamicplugin.go deleted file mode 100644 index 7e8a561cef..0000000000 --- a/framework/plugins/dynamicplugin.go +++ /dev/null @@ -1,174 +0,0 @@ -package plugins - -import ( - "fmt" - "os" - "plugin" - "strings" - "time" - - "github.com/maximhq/bifrost/core/schemas" - "github.com/valyala/fasthttp" -) - -// DynamicPlugin is the interface for a dynamic plugin -type DynamicPlugin struct { - Enabled bool - Path string - - Config any - - filename string - plugin *plugin.Plugin - - getName func() string - transportInterceptor func(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) - preHook func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) - postHook func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) - cleanup func() error -} - -// GetName returns the name of the plugin -func (dp *DynamicPlugin) GetName() string { - return dp.getName() -} - -// TransportInterceptor is not used for dynamic plugins -func (dp *DynamicPlugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return dp.transportInterceptor(ctx, url, headers, body) -} - -// PreHook is not used for dynamic plugins -func (dp *DynamicPlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { - return dp.preHook(ctx, req) -} - -// PostHook is not used for dynamic plugins -func (dp *DynamicPlugin) PostHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - return dp.postHook(ctx, resp, bifrostErr) -} - -// Cleanup is not used for dynamic plugins -func (dp *DynamicPlugin) Cleanup() error { - return dp.cleanup() -} - -// loadDynamicPlugin loads a dynamic plugin from a path -func loadDynamicPlugin(path string, config any) (schemas.Plugin, error) { - dp := &DynamicPlugin{ - Path: path, - } - // Checking if path is URL or file path - if strings.HasPrefix(dp.Path, "http") { - // Download the file - req := fasthttp.AcquireRequest() - defer fasthttp.ReleaseRequest(req) - response := fasthttp.AcquireResponse() - defer fasthttp.ReleaseResponse(response) - - req.SetRequestURI(dp.Path) - req.Header.SetMethod(fasthttp.MethodGet) - req.Header.Set("Accept", "application/octet-stream") - req.Header.Set("Accept-Encoding", "gzip") - req.Header.Set("Accept-Language", "en-US,en;q=0.9") - err := fasthttp.DoTimeout(req, response, 120*time.Second) - if err != nil { - return nil, err - } - if response.StatusCode() != fasthttp.StatusOK { - return nil, fmt.Errorf("failed to download plugin: %d", response.StatusCode()) - } - // Create a unique temporary file for the plugin - tempFile, err := os.CreateTemp(os.TempDir(), "bifrost-plugin-*.so") - if err != nil { - return nil, fmt.Errorf("failed to create temporary file: %w", err) - } - tempPath := tempFile.Name() - // Write the downloaded body to the temporary file - _, err = tempFile.Write(response.Body()) - if err != nil { - tempFile.Close() - os.Remove(tempPath) - return nil, fmt.Errorf("failed to write plugin to temporary file: %w", err) - } - // Close the file - err = tempFile.Close() - if err != nil { - os.Remove(tempPath) - return nil, fmt.Errorf("failed to close temporary file: %w", err) - } - // Set file permissions to be executable - err = os.Chmod(tempPath, 0755) - if err != nil { - os.Remove(tempPath) - return nil, fmt.Errorf("failed to set executable permissions on plugin: %w", err) - } - dp.Path = tempPath - } - plugin, err := plugin.Open(dp.Path) - if err != nil { - return nil, err - } - ok := false - // Looking up for optional Init method - initSym, err := plugin.Lookup("Init") - if err != nil { - if strings.Contains(err.Error(), "symbol Init not found") { - initSym = nil - } else { - return nil, err - } - } - if initSym != nil { - initFunc, ok := initSym.(func(config any) error) - if !ok { - return nil, fmt.Errorf("failed to cast Init to func(config any) error") - } - err := initFunc(config) - if err != nil { - return nil, err - } - } - // Looking up for GetName method - getNameSym, err := plugin.Lookup("GetName") - if err != nil { - return nil, err - } - if dp.getName, ok = getNameSym.(func() string); !ok { - return nil, fmt.Errorf("failed to cast GetName to func() string") - } - // Looking up for TransportInterceptor method - transportInterceptorSym, err := plugin.Lookup("TransportInterceptor") - if err != nil { - return nil, err - } - if dp.transportInterceptor, ok = transportInterceptorSym.(func(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error)); !ok { - return nil, fmt.Errorf("failed to cast TransportInterceptor to func(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error)") - } - // Looking up for PreHook method - preHookSym, err := plugin.Lookup("PreHook") - if err != nil { - return nil, err - } - if dp.preHook, ok = preHookSym.(func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error)); !ok { - return nil, fmt.Errorf("failed to cast PreHook to func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error)") - } - // Looking up for PostHook method - postHookSym, err := plugin.Lookup("PostHook") - if err != nil { - return nil, err - } - if dp.postHook, ok = postHookSym.(func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)); !ok { - return nil, fmt.Errorf("failed to cast PostHook to func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)") - } - // Looking up for Cleanup method - cleanupSym, err := plugin.Lookup("Cleanup") - if err != nil { - return nil, err - } - if dp.cleanup, ok = cleanupSym.(func() error); !ok { - return nil, fmt.Errorf("failed to cast Cleanup to func() error") - } - dp.plugin = plugin - return dp, nil -} diff --git a/framework/plugins/loader.go b/framework/plugins/loader.go new file mode 100644 index 0000000000..1ad11a2eba --- /dev/null +++ b/framework/plugins/loader.go @@ -0,0 +1,8 @@ +package plugins + +import "github.com/maximhq/bifrost/core/schemas" + +// PluginLoader is the contract for a plugin loader +type PluginLoader interface { + LoadDynamicPlugin(path string, config any) (schemas.Plugin, error) +} diff --git a/framework/plugins/main.go b/framework/plugins/main.go index dde83b4ccf..ee1a6dc07e 100644 --- a/framework/plugins/main.go +++ b/framework/plugins/main.go @@ -14,11 +14,12 @@ type DynamicPluginConfig struct { // Config is the configuration for the plugins framework type Config struct { + Plugins []DynamicPluginConfig `json:"plugins"` } // LoadPlugins loads the plugins from the config -func LoadPlugins(config *Config) ([]schemas.Plugin, error) { +func LoadPlugins(loader PluginLoader, config *Config) ([]schemas.Plugin, error) { plugins := []schemas.Plugin{} if config == nil { return plugins, nil @@ -27,7 +28,7 @@ func LoadPlugins(config *Config) ([]schemas.Plugin, error) { if !dp.Enabled { continue } - plugin, err := loadDynamicPlugin(dp.Path, dp.Config) + plugin, err := loader.LoadDynamicPlugin(dp.Path, dp.Config) if err != nil { return nil, err } diff --git a/framework/plugins/soloader.go b/framework/plugins/soloader.go new file mode 100644 index 0000000000..9738fe057d --- /dev/null +++ b/framework/plugins/soloader.go @@ -0,0 +1,94 @@ +package plugins + +import ( + "fmt" + "plugin" + "strings" + + "github.com/maximhq/bifrost/core/schemas" +) + +// SharedObjectPluginLoader is the loader for shared object plugins +type SharedObjectPluginLoader struct{} + +// LoadDynamicPlugin loads a dynamic plugin from a shared object file +func (l *SharedObjectPluginLoader) LoadDynamicPlugin(path string, config any) (schemas.Plugin, error) { + dp := &DynamicPlugin{ + Path: path, + } + // Checking if path is URL or file path + if strings.HasPrefix(dp.Path, "http") { + // Download the file + tempPath, err := DownloadPlugin(dp.Path, ".so") + if err != nil { + return nil, err + } + dp.Path = tempPath + } + plugin, err := plugin.Open(dp.Path) + if err != nil { + return nil, err + } + ok := false + // Looking up for optional Init method + initSym, err := plugin.Lookup("Init") + if err != nil { + if strings.Contains(err.Error(), "symbol Init not found") { + initSym = nil + } else { + return nil, err + } + } + if initSym != nil { + initFunc, ok := initSym.(func(config any) error) + if !ok { + return nil, fmt.Errorf("failed to cast Init to func(config any) error") + } + err := initFunc(config) + if err != nil { + return nil, err + } + } + // Looking up for GetName method + getNameSym, err := plugin.Lookup("GetName") + if err != nil { + return nil, err + } + if dp.getName, ok = getNameSym.(func() string); !ok { + return nil, fmt.Errorf("failed to cast GetName to func() string") + } + // Looking up for HTTPTransportIntercept method + httpTransportInterceptSym, err := plugin.Lookup("HTTPTransportIntercept") + if err != nil { + return nil, err + } + if dp.httpTransportIntercept, ok = httpTransportInterceptSym.(func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error)); !ok { + return nil, fmt.Errorf("failed to cast HTTPTransportIntercept to func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error)") + } + // Looking up for PreHook method + preHookSym, err := plugin.Lookup("PreHook") + if err != nil { + return nil, err + } + if dp.preHook, ok = preHookSym.(func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error)); !ok { + return nil, fmt.Errorf("failed to cast PreHook to func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error)") + } + // Looking up for PostHook method + postHookSym, err := plugin.Lookup("PostHook") + if err != nil { + return nil, err + } + if dp.postHook, ok = postHookSym.(func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)); !ok { + return nil, fmt.Errorf("failed to cast PostHook to func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error)") + } + // Looking up for Cleanup method + cleanupSym, err := plugin.Lookup("Cleanup") + if err != nil { + return nil, err + } + if dp.cleanup, ok = cleanupSym.(func() error); !ok { + return nil, fmt.Errorf("failed to cast Cleanup to func() error") + } + dp.plugin = plugin + return dp, nil +} diff --git a/framework/plugins/soplugin.go b/framework/plugins/soplugin.go new file mode 100644 index 0000000000..c909181cb2 --- /dev/null +++ b/framework/plugins/soplugin.go @@ -0,0 +1,52 @@ +package plugins + +import ( + "plugin" + + "github.com/maximhq/bifrost/core/schemas" +) + +// DynamicPlugin is the interface for a dynamic plugin +type DynamicPlugin struct { + Enabled bool + Path string + + Config any + + filename string + plugin *plugin.Plugin + + getName func() string + httpTransportIntercept func(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) + preHook func(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) + postHook func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) + cleanup func() error +} + +// GetName returns the name of the plugin +func (dp *DynamicPlugin) GetName() string { + return dp.getName() +} + +// HTTPTransportIntercept intercepts HTTP requests at the transport layer for this plugin +func (dp *DynamicPlugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + if dp.httpTransportIntercept == nil { + return nil, nil + } + return dp.httpTransportIntercept(ctx, req) +} + +// PreHook is not used for dynamic plugins +func (dp *DynamicPlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + return dp.preHook(ctx, req) +} + +// PostHook is not used for dynamic plugins +func (dp *DynamicPlugin) PostHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + return dp.postHook(ctx, resp, bifrostErr) +} + +// Cleanup is not used for dynamic plugins +func (dp *DynamicPlugin) Cleanup() error { + return dp.cleanup() +} diff --git a/framework/plugins/dynamicplugin_test.go b/framework/plugins/soplugin_test.go similarity index 87% rename from framework/plugins/dynamicplugin_test.go rename to framework/plugins/soplugin_test.go index bfba8d567b..6e4821918f 100644 --- a/framework/plugins/dynamicplugin_test.go +++ b/framework/plugins/soplugin_test.go @@ -38,7 +38,8 @@ func TestDynamicPluginLifecycle(t *testing.T) { }, } - plugins, err := LoadPlugins(config) + loader := &SharedObjectPluginLoader{} + plugins, err := LoadPlugins(loader, config) require.NoError(t, err, "Failed to load plugins") require.Len(t, plugins, 1, "Expected exactly one plugin to be loaded") @@ -50,26 +51,31 @@ func TestDynamicPluginLifecycle(t *testing.T) { assert.Equal(t, "Hello World Plugin", name, "Plugin name should match") }) - // Test TransportInterceptor - t.Run("TransportInterceptor", func(t *testing.T) { + // Test HTTPTransportIntercept + t.Run("HTTPTransportIntercept", func(t *testing.T) { ctx := context.Background() - url := "http://example.com/api" - headers := map[string]string{ - "Content-Type": "application/json", - "Authorization": "Bearer token123", - } - body := map[string]any{ - "model": "gpt-4", - "messages": []map[string]string{ - {"role": "user", "content": "Hello"}, - }, - } pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second) defer cancel() - modifiedHeaders, modifiedBody, err := plugin.TransportInterceptor(pluginCtx, url, headers, body) - require.NoError(t, err, "TransportInterceptor should not return error") - assert.Equal(t, headers, modifiedHeaders, "Headers should be unchanged") - assert.Equal(t, body, modifiedBody, "Body should be unchanged") + + // Create a test HTTP request + req := &schemas.HTTPRequest{ + Method: "POST", + Path: "/api", + Headers: map[string]string{ + "Content-Type": "application/json", + "Authorization": "Bearer token123", + }, + Query: map[string]string{}, + Body: []byte(`{"test": "data"}`), + } + + // Call HTTPTransportIntercept + resp, err := plugin.HTTPTransportIntercept(pluginCtx, req) + require.NoError(t, err, "HTTPTransportIntercept should not return error") + assert.Nil(t, resp, "HTTPTransportIntercept should return nil response to continue") + + // Verify headers were modified (hello-world plugin adds a header) + assert.Equal(t, "transport-interceptor-value", req.Headers["x-hello-world-plugin"], "Plugin should have added custom header") }) // Test PreHook @@ -172,7 +178,8 @@ func TestLoadPlugins_DisabledPlugin(t *testing.T) { }, } - plugins, err := LoadPlugins(config) + loader := &SharedObjectPluginLoader{} + plugins, err := LoadPlugins(loader, config) require.NoError(t, err, "LoadPlugins should not error for disabled plugins") assert.Len(t, plugins, 0, "No plugins should be loaded when all are disabled") } @@ -199,7 +206,8 @@ func TestLoadPlugins_MultiplePlugins(t *testing.T) { }, } - plugins, err := LoadPlugins(config) + loader := &SharedObjectPluginLoader{} + plugins, err := LoadPlugins(loader, config) require.NoError(t, err, "LoadPlugins should succeed for multiple plugins") assert.Len(t, plugins, 2, "Two plugins should be loaded") @@ -221,7 +229,8 @@ func TestLoadPlugins_InvalidPath(t *testing.T) { }, } - plugins, err := LoadPlugins(config) + loader := &SharedObjectPluginLoader{} + plugins, err := LoadPlugins(loader, config) assert.Error(t, err, "LoadPlugins should return error for invalid path") assert.Nil(t, plugins, "No plugins should be loaded on error") } @@ -231,8 +240,8 @@ func TestLoadPlugins_EmptyConfig(t *testing.T) { config := &Config{ Plugins: []DynamicPluginConfig{}, } - - plugins, err := LoadPlugins(config) + loader := &SharedObjectPluginLoader{} + plugins, err := LoadPlugins(loader, config) require.NoError(t, err, "LoadPlugins should succeed with empty config") assert.Len(t, plugins, 0, "No plugins should be loaded with empty config") } @@ -242,7 +251,8 @@ func TestDynamicPlugin_ContextPropagation(t *testing.T) { pluginPath := buildHelloWorldPlugin(t) defer cleanupHelloWorldPlugin(t) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(t, err, "Failed to load plugin") // Create a context with a value @@ -277,7 +287,8 @@ func TestDynamicPlugin_ConcurrentCalls(t *testing.T) { pluginPath := buildHelloWorldPlugin(t) defer cleanupHelloWorldPlugin(t) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(t, err, "Failed to load plugin") // Run multiple goroutines calling plugin methods @@ -384,7 +395,8 @@ func TestLoadDynamicPlugin_DirectCall(t *testing.T) { pluginPath := buildHelloWorldPlugin(t) defer cleanupHelloWorldPlugin(t) - plugin, err := loadDynamicPlugin(pluginPath, map[string]interface{}{ + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, map[string]interface{}{ "test": "config", }) require.NoError(t, err, "loadDynamicPlugin should succeed") @@ -401,7 +413,8 @@ func TestDynamicPlugin_NilConfig(t *testing.T) { pluginPath := buildHelloWorldPlugin(t) defer cleanupHelloWorldPlugin(t) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(t, err, "loadDynamicPlugin should succeed with nil config") assert.NotNil(t, plugin, "Plugin should not be nil") @@ -415,7 +428,8 @@ func TestDynamicPlugin_ShortCircuitNil(t *testing.T) { pluginPath := buildHelloWorldPlugin(t) defer cleanupHelloWorldPlugin(t) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(t, err, "Failed to load plugin") ctx := context.Background() @@ -440,7 +454,8 @@ func BenchmarkDynamicPlugin_PreHook(b *testing.B) { pluginPath := buildHelloWorldPluginForBenchmark(b) defer cleanupHelloWorldPluginForBenchmark(b) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(b, err, "Failed to load plugin") ctx := context.Background() @@ -465,7 +480,8 @@ func BenchmarkDynamicPlugin_PostHook(b *testing.B) { pluginPath := buildHelloWorldPluginForBenchmark(b) defer cleanupHelloWorldPluginForBenchmark(b) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(b, err, "Failed to load plugin") ctx := context.Background() @@ -488,7 +504,8 @@ func BenchmarkDynamicPlugin_GetName(b *testing.B) { pluginPath := buildHelloWorldPluginForBenchmark(b) defer cleanupHelloWorldPluginForBenchmark(b) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(b, err, "Failed to load plugin") b.ResetTimer() @@ -555,7 +572,8 @@ func TestDynamicPlugin_GetNameNotEmpty(t *testing.T) { pluginPath := buildHelloWorldPlugin(t) defer cleanupHelloWorldPlugin(t) - plugin, err := loadDynamicPlugin(pluginPath, nil) + loader := &SharedObjectPluginLoader{} + plugin, err := loader.LoadDynamicPlugin(pluginPath, nil) require.NoError(t, err, "Failed to load plugin") name := plugin.GetName() diff --git a/framework/plugins/utils.go b/framework/plugins/utils.go new file mode 100644 index 0000000000..f9f4465c5c --- /dev/null +++ b/framework/plugins/utils.go @@ -0,0 +1,72 @@ +package plugins + +import ( + "fmt" + "os" + "time" + + "github.com/valyala/fasthttp" +) + +// DownloadPlugin downloads a plugin from a URL and returns the local file path +func DownloadPlugin(url string, extension string) (string, error) { + req := fasthttp.AcquireRequest() + defer fasthttp.ReleaseRequest(req) + response := fasthttp.AcquireResponse() + defer fasthttp.ReleaseResponse(response) + + req.SetRequestURI(url) + req.Header.SetMethod(fasthttp.MethodGet) + req.Header.Set("Accept", "application/octet-stream") + req.Header.Set("Accept-Encoding", "gzip") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + + err := fasthttp.DoTimeout(req, response, 120*time.Second) + if err != nil { + return "", err + } + + if response.StatusCode() != fasthttp.StatusOK { + return "", fmt.Errorf("failed to download plugin: %d", response.StatusCode()) + } + + // Decompress the response body if it was gzip/deflate compressed + // BodyUncompressed handles both gzip and deflate encodings based on Content-Encoding header + body, err := response.BodyUncompressed() + if err != nil { + return "", fmt.Errorf("failed to decompress response body: %w", err) + } + + // Create a unique temporary file for the plugin + tempFile, err := os.CreateTemp(os.TempDir(), "bifrost-plugin-*"+extension) + if err != nil { + return "", fmt.Errorf("failed to create temporary file: %w", err) + } + tempPath := tempFile.Name() + + // Write the downloaded body to the temporary file + _, err = tempFile.Write(body) + if err != nil { + tempFile.Close() + os.Remove(tempPath) + return "", fmt.Errorf("failed to write plugin to temporary file: %w", err) + } + + // Close the file + err = tempFile.Close() + if err != nil { + os.Remove(tempPath) + return "", fmt.Errorf("failed to close temporary file: %w", err) + } + + // Set file permissions to be executable (for .so files) + if extension == ".so" { + err = os.Chmod(tempPath, 0755) + if err != nil { + os.Remove(tempPath) + return "", fmt.Errorf("failed to set executable permissions on plugin: %w", err) + } + } + + return tempPath, nil +} diff --git a/framework/streaming/accumulator.go b/framework/streaming/accumulator.go index b19bfde0c0..d8604e1b70 100644 --- a/framework/streaming/accumulator.go +++ b/framework/streaming/accumulator.go @@ -10,6 +10,15 @@ import ( "github.com/maximhq/bifrost/framework/modelcatalog" ) +// getAccumulatorID extracts the ID for accumulator lookup from context. +// Returns the value of BifrostContextKeyAccumulatorID. +func getAccumulatorID(ctx *schemas.BifrostContext) (string, bool) { + if id, ok := ctx.Value(schemas.BifrostContextKeyAccumulatorID).(string); ok && id != "" { + return id, true + } + return "", false +} + // Accumulator manages accumulation of streaming chunks type Accumulator struct { logger schemas.Logger @@ -101,14 +110,28 @@ func (a *Accumulator) putResponsesStreamChunk(chunk *ResponsesStreamChunk) { a.responsesStreamChunkPool.Put(chunk) } -// CreateStreamAccumulator creates a new stream accumulator for a request +// createStreamAccumulator creates a new stream accumulator for a request +// StartTimestamp is set to current time if not provided via CreateStreamAccumulator func (a *Accumulator) createStreamAccumulator(requestID string) *StreamAccumulator { + now := time.Now() sc := &StreamAccumulator{ - RequestID: requestID, - ChatStreamChunks: make([]*ChatStreamChunk, 0), - ResponsesStreamChunks: make([]*ResponsesStreamChunk, 0), - IsComplete: false, - Timestamp: time.Now(), + RequestID: requestID, + ChatStreamChunks: make([]*ChatStreamChunk, 0), + ResponsesStreamChunks: make([]*ResponsesStreamChunk, 0), + TranscriptionStreamChunks: make([]*TranscriptionStreamChunk, 0), + AudioStreamChunks: make([]*AudioStreamChunk, 0), + ChatChunksSeen: make(map[int]struct{}), + ResponsesChunksSeen: make(map[int]struct{}), + TranscriptionChunksSeen: make(map[int]struct{}), + AudioChunksSeen: make(map[int]struct{}), + MaxChatChunkIndex: -1, + MaxResponsesChunkIndex: -1, + MaxTranscriptionChunkIndex: -1, + MaxAudioChunkIndex: -1, + IsComplete: false, + mu: sync.Mutex{}, + Timestamp: now, + StartTimestamp: now, // Set default StartTimestamp for proper TTFT/latency calculation } a.streamAccumulators.Store(requestID, sc) return sc @@ -132,8 +155,19 @@ func (a *Accumulator) addChatStreamChunk(requestID string, chunk *ChatStreamChun if accumulator.StartTimestamp.IsZero() { accumulator.StartTimestamp = chunk.Timestamp } - // Add chunk to the list (chunks arrive in order) - accumulator.ChatStreamChunks = append(accumulator.ChatStreamChunks, chunk) + // Track first chunk timestamp for TTFT calculation + if accumulator.FirstChunkTimestamp.IsZero() { + accumulator.FirstChunkTimestamp = chunk.Timestamp + } + // De-dup check - only add if not seen (handles out-of-order arrival and multiple plugins) + if _, seen := accumulator.ChatChunksSeen[chunk.ChunkIndex]; !seen { + accumulator.ChatChunksSeen[chunk.ChunkIndex] = struct{}{} + accumulator.ChatStreamChunks = append(accumulator.ChatStreamChunks, chunk) + // Track max index for metadata extraction + if chunk.ChunkIndex > accumulator.MaxChatChunkIndex { + accumulator.MaxChatChunkIndex = chunk.ChunkIndex + } + } // Check if this is the final chunk // Set FinalTimestamp when either FinishReason is present or token usage exists // This handles both normal completion chunks and usage-only last chunks @@ -152,8 +186,18 @@ func (a *Accumulator) addTranscriptionStreamChunk(requestID string, chunk *Trans if accumulator.StartTimestamp.IsZero() { accumulator.StartTimestamp = chunk.Timestamp } - // Add chunk to the list (chunks arrive in order) - accumulator.TranscriptionStreamChunks = append(accumulator.TranscriptionStreamChunks, chunk) + // Track first chunk timestamp for TTFT calculation + if accumulator.FirstChunkTimestamp.IsZero() { + accumulator.FirstChunkTimestamp = chunk.Timestamp + } + if _, seen := accumulator.TranscriptionChunksSeen[chunk.ChunkIndex]; !seen { + accumulator.TranscriptionChunksSeen[chunk.ChunkIndex] = struct{}{} + accumulator.TranscriptionStreamChunks = append(accumulator.TranscriptionStreamChunks, chunk) + // Track max index for metadata extraction + if chunk.ChunkIndex > accumulator.MaxTranscriptionChunkIndex { + accumulator.MaxTranscriptionChunkIndex = chunk.ChunkIndex + } + } // Check if this is the final chunk // Set FinalTimestamp when either FinishReason is present or token usage exists // This handles both normal completion chunks and usage-only last chunks @@ -172,8 +216,18 @@ func (a *Accumulator) addAudioStreamChunk(requestID string, chunk *AudioStreamCh if accumulator.StartTimestamp.IsZero() { accumulator.StartTimestamp = chunk.Timestamp } - // Add chunk to the list (chunks arrive in order) - accumulator.AudioStreamChunks = append(accumulator.AudioStreamChunks, chunk) + // Track first chunk timestamp for TTFT calculation + if accumulator.FirstChunkTimestamp.IsZero() { + accumulator.FirstChunkTimestamp = chunk.Timestamp + } + if _, seen := accumulator.AudioChunksSeen[chunk.ChunkIndex]; !seen { + accumulator.AudioChunksSeen[chunk.ChunkIndex] = struct{}{} + accumulator.AudioStreamChunks = append(accumulator.AudioStreamChunks, chunk) + // Track max index for metadata extraction + if chunk.ChunkIndex > accumulator.MaxAudioChunkIndex { + accumulator.MaxAudioChunkIndex = chunk.ChunkIndex + } + } // Check if this is the final chunk // Set FinalTimestamp when either FinishReason is present or token usage exists // This handles both normal completion chunks and usage-only last chunks @@ -192,8 +246,18 @@ func (a *Accumulator) addResponsesStreamChunk(requestID string, chunk *Responses if accumulator.StartTimestamp.IsZero() { accumulator.StartTimestamp = chunk.Timestamp } - // Add chunk to the list (chunks arrive in order) - accumulator.ResponsesStreamChunks = append(accumulator.ResponsesStreamChunks, chunk) + // Track first chunk timestamp for TTFT calculation + if accumulator.FirstChunkTimestamp.IsZero() { + accumulator.FirstChunkTimestamp = chunk.Timestamp + } + if _, seen := accumulator.ResponsesChunksSeen[chunk.ChunkIndex]; !seen { + accumulator.ResponsesChunksSeen[chunk.ChunkIndex] = struct{}{} + accumulator.ResponsesStreamChunks = append(accumulator.ResponsesStreamChunks, chunk) + // Track max index for metadata extraction + if chunk.ChunkIndex > accumulator.MaxResponsesChunkIndex { + accumulator.MaxResponsesChunkIndex = chunk.ChunkIndex + } + } // Check if this is the final chunk // Set FinalTimestamp when either FinishReason is present or token usage exists // This handles both normal completion chunks and usage-only last chunks @@ -363,8 +427,11 @@ func (a *Accumulator) Cleanup() { } // CreateStreamAccumulator creates a new stream accumulator for a request +// It increments the reference counter atomically for concurrent access tracking func (a *Accumulator) CreateStreamAccumulator(requestID string, startTimestamp time.Time) *StreamAccumulator { sc := a.getOrCreateStreamAccumulator(requestID) + // Atomically increment reference counter + sc.refCount.Add(1) // Lock before writing to StartTimestamp sc.mu.Lock() sc.StartTimestamp = startTimestamp @@ -372,16 +439,25 @@ func (a *Accumulator) CreateStreamAccumulator(requestID string, startTimestamp t return sc } -// CleanupStreamAccumulator cleans up the stream accumulator for a request +// CleanupStreamAccumulator decrements the reference counter for a stream accumulator. +// The accumulator is only cleaned up when the reference counter reaches 0. +// This function is idempotent - calling it after cleanup has already happened is safe. func (a *Accumulator) CleanupStreamAccumulator(requestID string) error { acc, exists := a.streamAccumulators.Load(requestID) if !exists { - return fmt.Errorf("accumulator not found for request ID: %s", requestID) + // Accumulator already cleaned up - this is expected when multiple callers + // (e.g., completeDeferredSpan and HTTP middleware) both call cleanup + return nil } if accumulator, ok := acc.(*StreamAccumulator); ok { - accumulator.mu.Lock() - defer accumulator.mu.Unlock() - a.cleanupStreamAccumulator(requestID) + // Atomically decrement reference counter + newCount := accumulator.refCount.Add(-1) + // Only cleanup when reference counter reaches 0 + if newCount <= 0 { + accumulator.mu.Lock() + defer accumulator.mu.Unlock() + a.cleanupStreamAccumulator(requestID) + } } return nil } diff --git a/framework/streaming/audio.go b/framework/streaming/audio.go index 5123d1edeb..ffc4a166cd 100644 --- a/framework/streaming/audio.go +++ b/framework/streaming/audio.go @@ -28,26 +28,30 @@ func (a *Accumulator) processAccumulatedAudioStreamingChunks(requestID string, b accumulator := a.getOrCreateStreamAccumulator(requestID) // Lock the accumulator accumulator.mu.Lock() - defer func() { - if isFinalChunk { - // Cleanup BEFORE unlocking to prevent other goroutines from accessing chunks being returned to pool - a.cleanupStreamAccumulator(requestID) - } - accumulator.mu.Unlock() - }() + defer accumulator.mu.Unlock() + // Note: Cleanup is handled by CleanupStreamAccumulator when refcount reaches 0 + // This is called from completeDeferredSpan after streaming ends + + // Calculate Time to First Token (TTFT) in milliseconds + var ttft int64 + if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() { + ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6 + } + data := &AccumulatedData{ - RequestID: requestID, - Status: "success", - Stream: true, - StartTimestamp: accumulator.StartTimestamp, - EndTimestamp: accumulator.FinalTimestamp, - Latency: 0, - OutputMessage: nil, - ToolCalls: nil, - ErrorDetails: nil, - TokenUsage: nil, - CacheDebug: nil, - Cost: nil, + RequestID: requestID, + Status: "success", + Stream: true, + StartTimestamp: accumulator.StartTimestamp, + EndTimestamp: accumulator.FinalTimestamp, + Latency: 0, + TimeToFirstToken: ttft, + OutputMessage: nil, + ToolCalls: nil, + ErrorDetails: nil, + TokenUsage: nil, + CacheDebug: nil, + Cost: nil, } completeMessage := a.buildCompleteMessageFromAudioStreamChunks(accumulator.AudioStreamChunks) if !isFinalChunk { @@ -66,9 +70,8 @@ func (a *Accumulator) processAccumulatedAudioStreamingChunks(requestID string, b data.EndTimestamp = accumulator.FinalTimestamp data.AudioOutput = completeMessage data.ErrorDetails = bifrostErr - // Update token usage from final chunk if available - if len(accumulator.AudioStreamChunks) > 0 { - lastChunk := accumulator.AudioStreamChunks[len(accumulator.AudioStreamChunks)-1] + // Update metadata from the chunk with highest index (contains TokenUsage, Cost, CacheDebug) + if lastChunk := accumulator.getLastAudioChunk(); lastChunk != nil { if lastChunk.TokenUsage != nil { data.TokenUsage = &schemas.BifrostLLMUsage{ PromptTokens: lastChunk.TokenUsage.InputTokens, @@ -76,17 +79,9 @@ func (a *Accumulator) processAccumulatedAudioStreamingChunks(requestID string, b TotalTokens: lastChunk.TokenUsage.TotalTokens, } } - } - // Update cost from final chunk if available - if len(accumulator.AudioStreamChunks) > 0 { - lastChunk := accumulator.AudioStreamChunks[len(accumulator.AudioStreamChunks)-1] if lastChunk.Cost != nil { data.Cost = lastChunk.Cost } - } - // Update semantic cache debug from final chunk if available - if len(accumulator.AudioStreamChunks) > 0 { - lastChunk := accumulator.AudioStreamChunks[len(accumulator.AudioStreamChunks)-1] if lastChunk.SemanticCacheDebug != nil { data.CacheDebug = lastChunk.SemanticCacheDebug } @@ -112,11 +107,11 @@ func (a *Accumulator) processAccumulatedAudioStreamingChunks(requestID string, b // processAudioStreamingResponse processes a audio streaming response func (a *Accumulator) processAudioStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) { - // Extract request ID from context - requestID, ok := (*ctx).Value(schemas.BifrostContextKeyRequestID).(string) + // Extract accumulator ID from context + requestID, ok := getAccumulatorID(ctx) if !ok || requestID == "" { // Log error but don't fail the request - return nil, fmt.Errorf("request-id not found in context or is empty") + return nil, fmt.Errorf("accumulator-id not found in context or is empty") } _, provider, model := bifrost.GetResponseFields(result, bifrostErr) isFinalChunk := bifrost.IsFinalChunk(ctx) @@ -152,36 +147,35 @@ func (a *Accumulator) processAudioStreamingResponse(ctx *schemas.BifrostContext, if addErr := a.addAudioStreamChunk(requestID, chunk, isFinalChunk); addErr != nil { return nil, fmt.Errorf("failed to add stream chunk for request %s: %w", requestID, addErr) } + // Always return data on final chunk - multiple plugins may need the result if isFinalChunk { - shouldProcess := false + // Get the accumulator and mark as complete (idempotent) accumulator := a.getOrCreateStreamAccumulator(requestID) accumulator.mu.Lock() - shouldProcess = !accumulator.IsComplete - if shouldProcess { + if !accumulator.IsComplete { accumulator.IsComplete = true } accumulator.mu.Unlock() - if shouldProcess { - data, processErr := a.processAccumulatedAudioStreamingChunks(requestID, bifrostErr, isFinalChunk) - if processErr != nil { - a.logger.Error("failed to process accumulated chunks for request %s: %v", requestID, processErr) - return nil, processErr - } - var rawRequest interface{} - if result != nil && result.SpeechStreamResponse != nil && result.SpeechStreamResponse.ExtraFields.RawRequest != nil { - rawRequest = result.SpeechStreamResponse.ExtraFields.RawRequest - } - return &ProcessedStreamResponse{ - Type: StreamResponseTypeFinal, - RequestID: requestID, - StreamType: StreamTypeAudio, - Model: model, - Provider: provider, - Data: data, - RawRequest: &rawRequest, - }, nil + + // Always process and return data on final chunk + // Multiple plugins can call this - the processing is idempotent + data, processErr := a.processAccumulatedAudioStreamingChunks(requestID, bifrostErr, isFinalChunk) + if processErr != nil { + a.logger.Error("failed to process accumulated chunks for request %s: %v", requestID, processErr) + return nil, processErr + } + var rawRequest interface{} + if result != nil && result.SpeechStreamResponse != nil && result.SpeechStreamResponse.ExtraFields.RawRequest != nil { + rawRequest = result.SpeechStreamResponse.ExtraFields.RawRequest } - return nil, nil + return &ProcessedStreamResponse{ + RequestID: requestID, + StreamType: StreamTypeAudio, + Model: model, + Provider: provider, + Data: data, + RawRequest: &rawRequest, + }, nil } data, processErr := a.processAccumulatedAudioStreamingChunks(requestID, bifrostErr, isFinalChunk) if processErr != nil { @@ -189,7 +183,6 @@ func (a *Accumulator) processAudioStreamingResponse(ctx *schemas.BifrostContext, return nil, processErr } return &ProcessedStreamResponse{ - Type: StreamResponseTypeDelta, RequestID: requestID, StreamType: StreamTypeAudio, Model: model, diff --git a/framework/streaming/chat.go b/framework/streaming/chat.go index 36f3b97176..cae3b35757 100644 --- a/framework/streaming/chat.go +++ b/framework/streaming/chat.go @@ -9,6 +9,105 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) +// deepCopyChatStreamDelta creates a deep copy of ChatStreamResponseChoiceDelta +// to prevent shared data mutation between different chunks +func deepCopyChatStreamDelta(original *schemas.ChatStreamResponseChoiceDelta) *schemas.ChatStreamResponseChoiceDelta { + if original == nil { + return nil + } + + copy := &schemas.ChatStreamResponseChoiceDelta{} + + if original.Role != nil { + copyRole := *original.Role + copy.Role = ©Role + } + + if original.Content != nil { + copyContent := *original.Content + copy.Content = ©Content + } + + if original.Refusal != nil { + copyRefusal := *original.Refusal + copy.Refusal = ©Refusal + } + + if original.Reasoning != nil { + copyReasoning := *original.Reasoning + copy.Reasoning = ©Reasoning + } + + // Deep copy ReasoningDetails slice + if original.ReasoningDetails != nil { + copy.ReasoningDetails = make([]schemas.ChatReasoningDetails, len(original.ReasoningDetails)) + for i, rd := range original.ReasoningDetails { + copyRd := schemas.ChatReasoningDetails{ + Index: rd.Index, + Type: rd.Type, + } + if rd.ID != nil { + copyID := *rd.ID + copyRd.ID = ©ID + } + if rd.Text != nil { + copyText := *rd.Text + copyRd.Text = ©Text + } + if rd.Signature != nil { + copySig := *rd.Signature + copyRd.Signature = ©Sig + } + if rd.Summary != nil { + copySummary := *rd.Summary + copyRd.Summary = ©Summary + } + if rd.Data != nil { + copyData := *rd.Data + copyRd.Data = ©Data + } + copy.ReasoningDetails[i] = copyRd + } + } + + // Deep copy ToolCalls slice + if original.ToolCalls != nil { + copy.ToolCalls = make([]schemas.ChatAssistantMessageToolCall, len(original.ToolCalls)) + for i, tc := range original.ToolCalls { + copyTc := schemas.ChatAssistantMessageToolCall{ + Index: tc.Index, + Function: tc.Function, // struct value, safe to copy directly + } + if tc.ID != nil { + copyID := *tc.ID + copyTc.ID = ©ID + } + if tc.Type != nil { + copyType := *tc.Type + copyTc.Type = ©Type + } + // Deep copy Function's Name pointer + if tc.Function.Name != nil { + copyName := *tc.Function.Name + copyTc.Function.Name = ©Name + } + copy.ToolCalls[i] = copyTc + } + } + + // Deep copy Audio + if original.Audio != nil { + copy.Audio = &schemas.ChatAudioMessageAudio{ + ID: original.Audio.ID, + Data: original.Audio.Data, + ExpiresAt: original.Audio.ExpiresAt, + Transcript: original.Audio.Transcript, + } + } + + return copy +} + // buildCompleteMessageFromChunks builds a complete message from accumulated chunks func (a *Accumulator) buildCompleteMessageFromChatStreamChunks(chunks []*ChatStreamChunk) *schemas.ChatMessage { completeMessage := &schemas.ChatMessage{ @@ -18,6 +117,7 @@ func (a *Accumulator) buildCompleteMessageFromChatStreamChunks(chunks []*ChatStr sort.Slice(chunks, func(i, j int) bool { return chunks[i].ChunkIndex < chunks[j].ChunkIndex }) + for _, chunk := range chunks { if chunk.Delta == nil { continue @@ -26,80 +126,121 @@ func (a *Accumulator) buildCompleteMessageFromChatStreamChunks(chunks []*ChatStr if chunk.Delta.Role != nil { completeMessage.Role = schemas.ChatMessageRole(*chunk.Delta.Role) } - // Append content + // Append content delta if chunk.Delta.Content != nil && *chunk.Delta.Content != "" { a.appendContentToMessage(completeMessage, *chunk.Delta.Content) } - // Handle refusal + // Handle refusal delta if chunk.Delta.Refusal != nil && *chunk.Delta.Refusal != "" { if completeMessage.ChatAssistantMessage == nil { completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{} } if completeMessage.ChatAssistantMessage.Refusal == nil { - completeMessage.ChatAssistantMessage.Refusal = bifrost.Ptr(*chunk.Delta.Refusal) + // Deep copy on first assignment + refusalCopy := *chunk.Delta.Refusal + completeMessage.ChatAssistantMessage.Refusal = &refusalCopy } else { *completeMessage.ChatAssistantMessage.Refusal += *chunk.Delta.Refusal } } - // Handle reasoning + // Handle reasoning delta if chunk.Delta.Reasoning != nil && *chunk.Delta.Reasoning != "" { if completeMessage.ChatAssistantMessage == nil { completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{} } if completeMessage.ChatAssistantMessage.Reasoning == nil { - completeMessage.ChatAssistantMessage.Reasoning = bifrost.Ptr(*chunk.Delta.Reasoning) + // Deep copy on first assignment + reasoningCopy := *chunk.Delta.Reasoning + completeMessage.ChatAssistantMessage.Reasoning = &reasoningCopy } else { *completeMessage.ChatAssistantMessage.Reasoning += *chunk.Delta.Reasoning } } - // Handle reasoning details + // Handle reasoning details delta if len(chunk.Delta.ReasoningDetails) > 0 { if completeMessage.ChatAssistantMessage == nil { completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{} } - // Check if the reasoning detail already exists on that index, if so, update it else add it to the list - for _, reasoningDetail := range chunk.Delta.ReasoningDetails { + // Accumulate reasoning details by index + for _, rd := range chunk.Delta.ReasoningDetails { found := false for i := range completeMessage.ChatAssistantMessage.ReasoningDetails { - existingReasoningDetail := &completeMessage.ChatAssistantMessage.ReasoningDetails[i] - if existingReasoningDetail.Index == reasoningDetail.Index { - // Update text - accumulate if both exist - if reasoningDetail.Text != nil { - if existingReasoningDetail.Text == nil { - existingReasoningDetail.Text = reasoningDetail.Text + existingRd := &completeMessage.ChatAssistantMessage.ReasoningDetails[i] + if existingRd.Index == rd.Index { + // Found matching index - accumulate text delta + if rd.Text != nil && *rd.Text != "" { + if existingRd.Text == nil { + // Deep copy on first assignment + textCopy := *rd.Text + existingRd.Text = &textCopy } else { - *existingReasoningDetail.Text += *reasoningDetail.Text + *existingRd.Text += *rd.Text } } - // Update signature - overwrite (signatures are typically final) - if reasoningDetail.Signature != nil { - existingReasoningDetail.Signature = reasoningDetail.Signature - } - // Update other fields if present - if reasoningDetail.Summary != nil { - if existingReasoningDetail.Summary == nil { - existingReasoningDetail.Summary = reasoningDetail.Summary + // Accumulate summary delta + if rd.Summary != nil && *rd.Summary != "" { + if existingRd.Summary == nil { + summaryCopy := *rd.Summary + existingRd.Summary = &summaryCopy } else { - *existingReasoningDetail.Summary += *reasoningDetail.Summary + *existingRd.Summary += *rd.Summary } } - if reasoningDetail.Data != nil { - if existingReasoningDetail.Data == nil { - existingReasoningDetail.Data = reasoningDetail.Data + // Accumulate data delta + if rd.Data != nil && *rd.Data != "" { + if existingRd.Data == nil { + dataCopy := *rd.Data + existingRd.Data = &dataCopy } else { - *existingReasoningDetail.Data += *reasoningDetail.Data + *existingRd.Data += *rd.Data } } - if reasoningDetail.Type != "" { - existingReasoningDetail.Type = reasoningDetail.Type + // Overwrite signature (typically sent once at the end) + if rd.Signature != nil { + sigCopy := *rd.Signature + existingRd.Signature = &sigCopy + } + // Update type if present + if rd.Type != "" { + existingRd.Type = rd.Type + } + // Update ID if present + if rd.ID != nil { + idCopy := *rd.ID + existingRd.ID = &idCopy } found = true break } } - // If not found, add it to the list + // If not found, add new entry with deep copied values if !found { - completeMessage.ChatAssistantMessage.ReasoningDetails = append(completeMessage.ChatAssistantMessage.ReasoningDetails, reasoningDetail) + newRd := schemas.ChatReasoningDetails{ + Index: rd.Index, + Type: rd.Type, + } + if rd.ID != nil { + idCopy := *rd.ID + newRd.ID = &idCopy + } + if rd.Text != nil { + textCopy := *rd.Text + newRd.Text = &textCopy + } + if rd.Signature != nil { + sigCopy := *rd.Signature + newRd.Signature = &sigCopy + } + if rd.Summary != nil { + summaryCopy := *rd.Summary + newRd.Summary = &summaryCopy + } + if rd.Data != nil { + dataCopy := *rd.Data + newRd.Data = &dataCopy + } + completeMessage.ChatAssistantMessage.ReasoningDetails = append( + completeMessage.ChatAssistantMessage.ReasoningDetails, newRd) } } } @@ -109,7 +250,7 @@ func (a *Accumulator) buildCompleteMessageFromChatStreamChunks(chunks []*ChatStr completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{} } if completeMessage.ChatAssistantMessage.Audio == nil { - // First chunk with audio - initialize + // First chunk with audio - initialize with copies completeMessage.ChatAssistantMessage.Audio = &schemas.ChatAudioMessageAudio{ ID: chunk.Delta.Audio.ID, Data: chunk.Delta.Audio.Data, @@ -138,6 +279,7 @@ func (a *Accumulator) buildCompleteMessageFromChatStreamChunks(chunks []*ChatStr a.accumulateToolCallsInMessage(completeMessage, chunk.Delta.ToolCalls) } } + return completeMessage } @@ -146,27 +288,31 @@ func (a *Accumulator) processAccumulatedChatStreamingChunks(requestID string, re accumulator := a.getOrCreateStreamAccumulator(requestID) // Lock the accumulator accumulator.mu.Lock() - defer func() { - if isFinalChunk { - // Cleanup BEFORE unlocking to prevent other goroutines from accessing chunks being returned to pool - a.cleanupStreamAccumulator(requestID) - } - accumulator.mu.Unlock() - }() + defer accumulator.mu.Unlock() + // Note: Cleanup is handled by CleanupStreamAccumulator when refcount reaches 0 + // This is called from completeDeferredSpan after streaming ends + + // Calculate Time to First Token (TTFT) in milliseconds + var ttft int64 + if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() { + ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6 + } + // Initialize accumulated data data := &AccumulatedData{ - RequestID: requestID, - Status: "success", - Stream: true, - StartTimestamp: accumulator.StartTimestamp, - EndTimestamp: accumulator.FinalTimestamp, - Latency: 0, - OutputMessage: nil, - ToolCalls: nil, - ErrorDetails: nil, - TokenUsage: nil, - CacheDebug: nil, - Cost: nil, + RequestID: requestID, + Status: "success", + Stream: true, + StartTimestamp: accumulator.StartTimestamp, + EndTimestamp: accumulator.FinalTimestamp, + Latency: 0, + TimeToFirstToken: ttft, + OutputMessage: nil, + ToolCalls: nil, + ErrorDetails: nil, + TokenUsage: nil, + CacheDebug: nil, + Cost: nil, } // Build complete message from accumulated chunks completeMessage := a.buildCompleteMessageFromChatStreamChunks(accumulator.ChatStreamChunks) @@ -190,20 +336,14 @@ func (a *Accumulator) processAccumulatedChatStreamingChunks(requestID string, re data.ToolCalls = data.OutputMessage.ChatAssistantMessage.ToolCalls } data.ErrorDetails = respErr - // Update token usage from final chunk if available - if len(accumulator.ChatStreamChunks) > 0 { - lastChunk := accumulator.ChatStreamChunks[len(accumulator.ChatStreamChunks)-1] + // Update metadata from the chunk with highest index (contains TokenUsage, Cost, FinishReason) + if lastChunk := accumulator.getLastChatChunk(); lastChunk != nil { if lastChunk.TokenUsage != nil { data.TokenUsage = lastChunk.TokenUsage } - // Handle cache debug if lastChunk.SemanticCacheDebug != nil { data.CacheDebug = lastChunk.SemanticCacheDebug } - } - // Update cost from final chunk if available - if len(accumulator.ChatStreamChunks) > 0 { - lastChunk := accumulator.ChatStreamChunks[len(accumulator.ChatStreamChunks)-1] if lastChunk.Cost != nil { data.Cost = lastChunk.Cost } @@ -231,11 +371,11 @@ func (a *Accumulator) processAccumulatedChatStreamingChunks(requestID string, re // processChatStreamingResponse processes a chat streaming response func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) { a.logger.Debug("[streaming] processing chat streaming response") - // Extract request ID from context - requestID, ok := (*ctx).Value(schemas.BifrostContextKeyRequestID).(string) + // Extract accumulator ID from context + requestID, ok := getAccumulatorID(ctx) if !ok || requestID == "" { // Log error but don't fail the request - return nil, fmt.Errorf("request-id not found in context or is empty") + return nil, fmt.Errorf("accumulator-id not found in context or is empty") } requestType, provider, model := bifrost.GetResponseFields(result, bifrostErr) @@ -280,9 +420,8 @@ func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, if len(result.ChatResponse.Choices) > 0 { choice := result.ChatResponse.Choices[0] if choice.ChatStreamResponseChoice != nil { - // Shallow-copy struct and deep-copy slices to avoid aliasing - copied := choice.ChatStreamResponseChoice.Delta - chunk.Delta = copied + // Deep copy delta to prevent shared data mutation between chunks + chunk.Delta = deepCopyChatStreamDelta(choice.ChatStreamResponseChoice.Delta) chunk.FinishReason = choice.FinishReason } } @@ -305,40 +444,38 @@ func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, if addErr := a.addChatStreamChunk(requestID, chunk, isFinalChunk); addErr != nil { return nil, fmt.Errorf("failed to add stream chunk for request %s: %w", requestID, addErr) } - // If this is the final chunk, process accumulated chunks asynchronously - // Use the IsComplete flag to prevent duplicate processing + // If this is the final chunk, process accumulated chunks + // Always return data on final chunk - multiple plugins may need the result if isFinalChunk { - shouldProcess := false - // Get the accumulator to check if processing has already been triggered + // Get the accumulator and mark as complete (idempotent) accumulator := a.getOrCreateStreamAccumulator(requestID) accumulator.mu.Lock() - shouldProcess = !accumulator.IsComplete - // Mark as complete when we're about to process - if shouldProcess { + if !accumulator.IsComplete { accumulator.IsComplete = true } accumulator.mu.Unlock() - if shouldProcess { - data, processErr := a.processAccumulatedChatStreamingChunks(requestID, bifrostErr, isFinalChunk) - if processErr != nil { - a.logger.Error("failed to process accumulated chunks for request %s: %v", requestID, processErr) - return nil, processErr - } - var rawRequest interface{} - if result != nil && result.ChatResponse != nil && result.ChatResponse.ExtraFields.RawRequest != nil { - rawRequest = result.ChatResponse.ExtraFields.RawRequest - } - return &ProcessedStreamResponse{ - Type: StreamResponseTypeFinal, - RequestID: requestID, - StreamType: streamType, - Provider: provider, - Model: model, - Data: data, - RawRequest: &rawRequest, - }, nil + + // Always process and return data on final chunk + // Multiple plugins can call this - the processing is idempotent + data, processErr := a.processAccumulatedChatStreamingChunks(requestID, bifrostErr, isFinalChunk) + if processErr != nil { + a.logger.Error("failed to process accumulated chunks for request %s: %v", requestID, processErr) + return nil, processErr + } + var rawRequest interface{} + if result != nil && result.ChatResponse != nil && result.ChatResponse.ExtraFields.RawRequest != nil { + rawRequest = result.ChatResponse.ExtraFields.RawRequest + } else if result != nil && result.TextCompletionResponse != nil && result.TextCompletionResponse.ExtraFields.RawRequest != nil { + rawRequest = result.TextCompletionResponse.ExtraFields.RawRequest } - return nil, nil + return &ProcessedStreamResponse{ + RequestID: requestID, + StreamType: streamType, + Provider: provider, + Model: model, + Data: data, + RawRequest: &rawRequest, + }, nil } // This is going to be a delta response data, processErr := a.processAccumulatedChatStreamingChunks(requestID, bifrostErr, isFinalChunk) @@ -348,7 +485,6 @@ func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, } // This is not the final chunk, so we will send back the delta return &ProcessedStreamResponse{ - Type: StreamResponseTypeDelta, RequestID: requestID, StreamType: streamType, Provider: provider, diff --git a/framework/streaming/responses.go b/framework/streaming/responses.go index ca418573c8..e6eb9c9cb8 100644 --- a/framework/streaming/responses.go +++ b/framework/streaming/responses.go @@ -191,6 +191,28 @@ func deepCopyResponsesMessage(original schemas.ResponsesMessage) schemas.Respons } } + // Deep copy ResponsesReasoning if present + if original.ResponsesReasoning != nil { + copy.ResponsesReasoning = &schemas.ResponsesReasoning{} + + // Deep copy Summary slice + if original.ResponsesReasoning.Summary != nil { + copy.ResponsesReasoning.Summary = make([]schemas.ResponsesReasoningSummary, len(original.ResponsesReasoning.Summary)) + for i, summary := range original.ResponsesReasoning.Summary { + copy.ResponsesReasoning.Summary[i] = schemas.ResponsesReasoningSummary{ + Type: summary.Type, + Text: summary.Text, + } + } + } + + // Deep copy EncryptedContent if present + if original.ResponsesReasoning.EncryptedContent != nil { + copyEncrypted := *original.ResponsesReasoning.EncryptedContent + copy.ResponsesReasoning.EncryptedContent = ©Encrypted + } + } + if original.ResponsesToolMessage != nil { copy.ResponsesToolMessage = &schemas.ResponsesToolMessage{} @@ -444,12 +466,14 @@ func (a *Accumulator) buildCompleteMessageFromResponsesStreamChunks(chunks []*Re switch resp.Type { case schemas.ResponsesStreamResponseTypeOutputItemAdded: // Always append new items - this fixes multiple function calls issue + // Deep copy to prevent shared pointer mutation when deltas are appended if resp.Item != nil { - messages = append(messages, *resp.Item) + messages = append(messages, deepCopyResponsesMessage(*resp.Item)) } case schemas.ResponsesStreamResponseTypeContentPartAdded: // Add content part to the most recent message, create message if none exists + // Deep copy to prevent shared pointer mutation if resp.Part != nil { if len(messages) == 0 { messages = append(messages, createNewMessage()) @@ -462,7 +486,7 @@ func (a *Accumulator) buildCompleteMessageFromResponsesStreamChunks(chunks []*Re if lastMsg.Content.ContentBlocks == nil { lastMsg.Content.ContentBlocks = make([]schemas.ResponsesMessageContentBlock, 0) } - lastMsg.Content.ContentBlocks = append(lastMsg.Content.ContentBlocks, *resp.Part) + lastMsg.Content.ContentBlocks = append(lastMsg.Content.ContentBlocks, deepCopyResponsesMessageContentBlock(*resp.Part)) } case schemas.ResponsesStreamResponseTypeOutputTextDelta: @@ -487,8 +511,9 @@ func (a *Accumulator) buildCompleteMessageFromResponsesStreamChunks(chunks []*Re if len(messages) == 0 { messages = append(messages, createNewMessage()) } + // Deep copy to prevent shared pointer mutation when arguments are appended if resp.Item != nil { - messages = append(messages, *resp.Item) + messages = append(messages, deepCopyResponsesMessage(*resp.Item)) } // Append arguments to the most recent message if resp.Delta != nil && len(messages) > 0 { @@ -510,8 +535,14 @@ func (a *Accumulator) buildCompleteMessageFromResponsesStreamChunks(chunks []*Re // If no message found, create a new reasoning message if targetMessage == nil { + // Deep copy ItemID to prevent shared pointer mutation + var copyID *string + if resp.ItemID != nil { + id := *resp.ItemID + copyID = &id + } newMessage := schemas.ResponsesMessage{ - ID: resp.ItemID, + ID: copyID, Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), ResponsesReasoning: &schemas.ResponsesReasoning{ @@ -731,28 +762,31 @@ func (a *Accumulator) processAccumulatedResponsesStreamingChunks(requestID strin accumulator := a.getOrCreateStreamAccumulator(requestID) // Lock the accumulator accumulator.mu.Lock() - defer func() { - if isFinalChunk { - // Cleanup BEFORE unlocking to prevent other goroutines from accessing chunks being returned to pool - a.cleanupStreamAccumulator(requestID) - } - accumulator.mu.Unlock() - }() + defer accumulator.mu.Unlock() + // Note: Cleanup is handled by CleanupStreamAccumulator when refcount reaches 0 + // This is called from completeDeferredSpan after streaming ends + + // Calculate Time to First Token (TTFT) in milliseconds + var ttft int64 + if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() { + ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6 + } // Initialize accumulated data data := &AccumulatedData{ - RequestID: requestID, - Status: "success", - Stream: true, - StartTimestamp: accumulator.StartTimestamp, - EndTimestamp: accumulator.FinalTimestamp, - Latency: 0, - OutputMessages: nil, - ToolCalls: nil, - ErrorDetails: respErr, - TokenUsage: nil, - CacheDebug: nil, - Cost: nil, + RequestID: requestID, + Status: "success", + Stream: true, + StartTimestamp: accumulator.StartTimestamp, + EndTimestamp: accumulator.FinalTimestamp, + Latency: 0, + TimeToFirstToken: ttft, + OutputMessages: nil, + ToolCalls: nil, + ErrorDetails: respErr, + TokenUsage: nil, + CacheDebug: nil, + Cost: nil, } // Build complete messages from accumulated chunks @@ -780,21 +814,14 @@ func (a *Accumulator) processAccumulatedResponsesStreamingChunks(requestID strin data.ErrorDetails = respErr - // Update token usage from final chunk if available - if len(accumulator.ResponsesStreamChunks) > 0 { - lastChunk := accumulator.ResponsesStreamChunks[len(accumulator.ResponsesStreamChunks)-1] + // Update metadata from the chunk with highest index (contains TokenUsage, Cost, FinishReason) + if lastChunk := accumulator.getLastResponsesChunk(); lastChunk != nil { if lastChunk.TokenUsage != nil { data.TokenUsage = lastChunk.TokenUsage } - // Handle cache debug if lastChunk.SemanticCacheDebug != nil { data.CacheDebug = lastChunk.SemanticCacheDebug } - } - - // Update cost from final chunk if available - if len(accumulator.ResponsesStreamChunks) > 0 { - lastChunk := accumulator.ResponsesStreamChunks[len(accumulator.ResponsesStreamChunks)-1] if lastChunk.Cost != nil { data.Cost = lastChunk.Cost } @@ -825,10 +852,10 @@ func (a *Accumulator) processAccumulatedResponsesStreamingChunks(requestID strin func (a *Accumulator) processResponsesStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) { a.logger.Debug("[streaming] processing responses streaming response") - // Extract request ID from context - requestID, ok := (*ctx).Value(schemas.BifrostContextKeyRequestID).(string) + // Extract accumulator ID from context + requestID, ok := getAccumulatorID(ctx) if !ok || requestID == "" { - return nil, fmt.Errorf("request-id not found in context or is empty") + return nil, fmt.Errorf("accumulator-id not found in context or is empty") } _, provider, model := bifrost.GetResponseFields(result, bifrostErr) @@ -865,80 +892,73 @@ func (a *Accumulator) processResponsesStreamingResponse(ctx *schemas.BifrostCont if result != nil && result.ResponsesStreamResponse != nil && result.ResponsesStreamResponse.ExtraFields.RawRequest != nil { rawRequest = result.ResponsesStreamResponse.ExtraFields.RawRequest } - shouldProcess := false - // Get the accumulator to check if processing has already been triggered + // Get the accumulator and mark as complete (idempotent) accumulator := a.getOrCreateStreamAccumulator(requestID) accumulator.mu.Lock() - shouldProcess = !accumulator.IsComplete - // Mark as complete when we're about to process - if shouldProcess { + if !accumulator.IsComplete { accumulator.IsComplete = true } accumulator.mu.Unlock() - if shouldProcess { - accumulatedData, processErr := a.processAccumulatedResponsesStreamingChunks(requestID, bifrostErr, isFinalChunk) - if processErr != nil { - a.logger.Error("failed to process accumulated responses chunks for request %s: %v", requestID, processErr) - return nil, processErr - } + // Always process and return data on final chunk + // Multiple plugins can call this - the processing is idempotent + accumulatedData, processErr := a.processAccumulatedResponsesStreamingChunks(requestID, bifrostErr, isFinalChunk) + if processErr != nil { + a.logger.Error("failed to process accumulated responses chunks for request %s: %v", requestID, processErr) + return nil, processErr + } - // For OpenAI, the final chunk contains the complete response - // Extract the complete response and return it - if result != nil && result.ResponsesStreamResponse != nil { - // Build the complete response from the final chunk - data := &AccumulatedData{ - RequestID: requestID, - Status: "success", - Stream: true, - StartTimestamp: startTimestamp, - EndTimestamp: endTimestamp, - Latency: result.GetExtraFields().Latency, - ErrorDetails: bifrostErr, - RawResponse: accumulatedData.RawResponse, - } + // For OpenAI, the final chunk contains the complete response + // Extract the complete response and return it + if result != nil && result.ResponsesStreamResponse != nil { + // Build the complete response from the final chunk + data := &AccumulatedData{ + RequestID: requestID, + Status: "success", + Stream: true, + StartTimestamp: startTimestamp, + EndTimestamp: endTimestamp, + Latency: result.GetExtraFields().Latency, + ErrorDetails: bifrostErr, + RawResponse: accumulatedData.RawResponse, + } - if bifrostErr != nil { - data.Status = "error" - } + if bifrostErr != nil { + data.Status = "error" + } - // Extract the complete response from the stream response - if result.ResponsesStreamResponse.Response != nil { - data.OutputMessages = result.ResponsesStreamResponse.Response.Output - if result.ResponsesStreamResponse.Response.Usage != nil { - // Convert ResponsesResponseUsage to schemas.LLMUsage - data.TokenUsage = &schemas.BifrostLLMUsage{ - PromptTokens: result.ResponsesStreamResponse.Response.Usage.InputTokens, - CompletionTokens: result.ResponsesStreamResponse.Response.Usage.OutputTokens, - TotalTokens: result.ResponsesStreamResponse.Response.Usage.TotalTokens, - } + // Extract the complete response from the stream response + if result.ResponsesStreamResponse.Response != nil { + data.OutputMessages = result.ResponsesStreamResponse.Response.Output + if result.ResponsesStreamResponse.Response.Usage != nil { + // Convert ResponsesResponseUsage to schemas.LLMUsage + data.TokenUsage = &schemas.BifrostLLMUsage{ + PromptTokens: result.ResponsesStreamResponse.Response.Usage.InputTokens, + CompletionTokens: result.ResponsesStreamResponse.Response.Usage.OutputTokens, + TotalTokens: result.ResponsesStreamResponse.Response.Usage.TotalTokens, } } + } - if a.pricingManager != nil { - cost := a.pricingManager.CalculateCostWithCacheDebug(result) - data.Cost = bifrost.Ptr(cost) - } - - return &ProcessedStreamResponse{ - Type: StreamResponseTypeFinal, - RequestID: requestID, - StreamType: StreamTypeResponses, - Provider: provider, - Model: model, - Data: data, - RawRequest: &rawRequest, - }, nil - } else { - return nil, nil + if a.pricingManager != nil { + cost := a.pricingManager.CalculateCostWithCacheDebug(result) + data.Cost = bifrost.Ptr(cost) } + + return &ProcessedStreamResponse{ + RequestID: requestID, + StreamType: StreamTypeResponses, + Provider: provider, + Model: model, + Data: data, + RawRequest: &rawRequest, + }, nil } return nil, nil } // For non-final chunks from OpenAI, just pass through return &ProcessedStreamResponse{ - Type: StreamResponseTypeDelta, RequestID: requestID, StreamType: StreamTypeResponses, Provider: provider, @@ -985,45 +1005,40 @@ func (a *Accumulator) processResponsesStreamingResponse(ctx *schemas.BifrostCont } // If this is the final chunk, process accumulated chunks + // Always return data on final chunk - multiple plugins may need the result if isFinalChunk { - shouldProcess := false - // Get the accumulator to check if processing has already been triggered + // Get the accumulator and mark as complete (idempotent) accumulator := a.getOrCreateStreamAccumulator(requestID) accumulator.mu.Lock() - shouldProcess = !accumulator.IsComplete - // Mark as complete when we're about to process - if shouldProcess { + if !accumulator.IsComplete { accumulator.IsComplete = true } accumulator.mu.Unlock() - if shouldProcess { - data, processErr := a.processAccumulatedResponsesStreamingChunks(requestID, bifrostErr, isFinalChunk) - if processErr != nil { - a.logger.Error("failed to process accumulated responses chunks for request %s: %v", requestID, processErr) - return nil, processErr - } - - var rawRequest interface{} - if result != nil && result.ResponsesStreamResponse != nil && result.ResponsesStreamResponse.ExtraFields.RawRequest != nil { - rawRequest = result.ResponsesStreamResponse.ExtraFields.RawRequest - } + // Always process and return data on final chunk + // Multiple plugins can call this - the processing is idempotent + data, processErr := a.processAccumulatedResponsesStreamingChunks(requestID, bifrostErr, isFinalChunk) + if processErr != nil { + a.logger.Error("failed to process accumulated responses chunks for request %s: %v", requestID, processErr) + return nil, processErr + } - return &ProcessedStreamResponse{ - Type: StreamResponseTypeFinal, - RequestID: requestID, - StreamType: StreamTypeResponses, - Provider: provider, - Model: model, - Data: data, - RawRequest: &rawRequest, - }, nil + var rawRequest interface{} + if result != nil && result.ResponsesStreamResponse != nil && result.ResponsesStreamResponse.ExtraFields.RawRequest != nil { + rawRequest = result.ResponsesStreamResponse.ExtraFields.RawRequest } - return nil, nil + + return &ProcessedStreamResponse{ + RequestID: requestID, + StreamType: StreamTypeResponses, + Provider: provider, + Model: model, + Data: data, + RawRequest: &rawRequest, + }, nil } return &ProcessedStreamResponse{ - Type: StreamResponseTypeDelta, RequestID: requestID, StreamType: StreamTypeResponses, Provider: provider, diff --git a/framework/streaming/transcription.go b/framework/streaming/transcription.go index 314f4cdb55..1794e54e2f 100644 --- a/framework/streaming/transcription.go +++ b/framework/streaming/transcription.go @@ -34,26 +34,30 @@ func (a *Accumulator) processAccumulatedTranscriptionStreamingChunks(requestID s accumulator := a.getOrCreateStreamAccumulator(requestID) // Lock the accumulator accumulator.mu.Lock() - defer func() { - if isFinalChunk { - // Cleanup BEFORE unlocking to prevent other goroutines from accessing chunks being returned to pool - a.cleanupStreamAccumulator(requestID) - } - accumulator.mu.Unlock() - }() + defer accumulator.mu.Unlock() + // Note: Cleanup is handled by CleanupStreamAccumulator when refcount reaches 0 + // This is called from completeDeferredSpan after streaming ends + + // Calculate Time to First Token (TTFT) in milliseconds + var ttft int64 + if !accumulator.StartTimestamp.IsZero() && !accumulator.FirstChunkTimestamp.IsZero() { + ttft = accumulator.FirstChunkTimestamp.Sub(accumulator.StartTimestamp).Nanoseconds() / 1e6 + } + data := &AccumulatedData{ - RequestID: requestID, - Status: "success", - Stream: true, - StartTimestamp: accumulator.StartTimestamp, - EndTimestamp: accumulator.FinalTimestamp, - Latency: 0, - OutputMessage: nil, - ToolCalls: nil, - ErrorDetails: nil, - TokenUsage: nil, - CacheDebug: nil, - Cost: nil, + RequestID: requestID, + Status: "success", + Stream: true, + StartTimestamp: accumulator.StartTimestamp, + EndTimestamp: accumulator.FinalTimestamp, + Latency: 0, + TimeToFirstToken: ttft, + OutputMessage: nil, + ToolCalls: nil, + ErrorDetails: nil, + TokenUsage: nil, + CacheDebug: nil, + Cost: nil, } // Build complete message from accumulated chunks completeMessage := a.buildCompleteMessageFromTranscriptionStreamChunks(accumulator.TranscriptionStreamChunks) @@ -73,9 +77,8 @@ func (a *Accumulator) processAccumulatedTranscriptionStreamingChunks(requestID s data.EndTimestamp = accumulator.FinalTimestamp data.TranscriptionOutput = completeMessage data.ErrorDetails = bifrostErr - // Update token usage from final chunk if available - if len(accumulator.TranscriptionStreamChunks) > 0 { - lastChunk := accumulator.TranscriptionStreamChunks[len(accumulator.TranscriptionStreamChunks)-1] + // Update metadata from the chunk with highest index (contains TokenUsage, Cost, CacheDebug) + if lastChunk := accumulator.getLastTranscriptionChunk(); lastChunk != nil { if lastChunk.TokenUsage != nil { data.TokenUsage = &schemas.BifrostLLMUsage{} if lastChunk.TokenUsage.InputTokens != nil { @@ -88,17 +91,9 @@ func (a *Accumulator) processAccumulatedTranscriptionStreamingChunks(requestID s data.TokenUsage.TotalTokens = *lastChunk.TokenUsage.TotalTokens } } - } - // Update cost from final chunk if available - if len(accumulator.TranscriptionStreamChunks) > 0 { - lastChunk := accumulator.TranscriptionStreamChunks[len(accumulator.TranscriptionStreamChunks)-1] if lastChunk.Cost != nil { data.Cost = lastChunk.Cost } - } - // Update semantic cache debug from final chunk if available - if len(accumulator.TranscriptionStreamChunks) > 0 { - lastChunk := accumulator.TranscriptionStreamChunks[len(accumulator.TranscriptionStreamChunks)-1] if lastChunk.SemanticCacheDebug != nil { data.CacheDebug = lastChunk.SemanticCacheDebug } @@ -124,11 +119,11 @@ func (a *Accumulator) processAccumulatedTranscriptionStreamingChunks(requestID s // processTranscriptionStreamingResponse processes a transcription streaming response func (a *Accumulator) processTranscriptionStreamingResponse(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*ProcessedStreamResponse, error) { - // Extract request ID from context - requestID, ok := (*ctx).Value(schemas.BifrostContextKeyRequestID).(string) + // Extract accumulator ID from context + requestID, ok := getAccumulatorID(ctx) if !ok || requestID == "" { // Log error but don't fail the request - return nil, fmt.Errorf("request-id not found in context or is empty") + return nil, fmt.Errorf("accumulator-id not found in context or is empty") } _, provider, model := bifrost.GetResponseFields(result, bifrostErr) isFinalChunk := bifrost.IsFinalChunk(ctx) @@ -171,36 +166,35 @@ func (a *Accumulator) processTranscriptionStreamingResponse(ctx *schemas.Bifrost if addErr := a.addTranscriptionStreamChunk(requestID, chunk, isFinalChunk); addErr != nil { return nil, fmt.Errorf("failed to add stream chunk for request %s: %w", requestID, addErr) } + // Always return data on final chunk - multiple plugins may need the result if isFinalChunk { - shouldProcess := false + // Get the accumulator and mark as complete (idempotent) accumulator := a.getOrCreateStreamAccumulator(requestID) accumulator.mu.Lock() - shouldProcess = !accumulator.IsComplete - if shouldProcess { + if !accumulator.IsComplete { accumulator.IsComplete = true } accumulator.mu.Unlock() - if shouldProcess { - data, processErr := a.processAccumulatedTranscriptionStreamingChunks(requestID, bifrostErr, isFinalChunk) - if processErr != nil { - a.logger.Error("failed to process accumulated chunks for request %s: %v", requestID, processErr) - return nil, processErr - } - var rawRequest interface{} - if result != nil && result.TranscriptionStreamResponse != nil && result.TranscriptionStreamResponse.ExtraFields.RawRequest != nil { - rawRequest = result.TranscriptionStreamResponse.ExtraFields.RawRequest - } - return &ProcessedStreamResponse{ - Type: StreamResponseTypeFinal, - RequestID: requestID, - StreamType: StreamTypeTranscription, - Provider: provider, - Model: model, - Data: data, - RawRequest: &rawRequest, - }, nil + + // Always process and return data on final chunk + // Multiple plugins can call this - the processing is idempotent + data, processErr := a.processAccumulatedTranscriptionStreamingChunks(requestID, bifrostErr, isFinalChunk) + if processErr != nil { + a.logger.Error("failed to process accumulated chunks for request %s: %v", requestID, processErr) + return nil, processErr + } + var rawRequest interface{} + if result != nil && result.TranscriptionStreamResponse != nil && result.TranscriptionStreamResponse.ExtraFields.RawRequest != nil { + rawRequest = result.TranscriptionStreamResponse.ExtraFields.RawRequest } - return nil, nil + return &ProcessedStreamResponse{ + RequestID: requestID, + StreamType: StreamTypeTranscription, + Provider: provider, + Model: model, + Data: data, + RawRequest: &rawRequest, + }, nil } data, processErr := a.processAccumulatedTranscriptionStreamingChunks(requestID, bifrostErr, isFinalChunk) if processErr != nil { @@ -208,7 +202,6 @@ func (a *Accumulator) processTranscriptionStreamingResponse(ctx *schemas.Bifrost return nil, processErr } return &ProcessedStreamResponse{ - Type: StreamResponseTypeDelta, RequestID: requestID, StreamType: StreamTypeTranscription, Provider: provider, diff --git a/framework/streaming/types.go b/framework/streaming/types.go index 29bb62dfc3..ad0f9f4891 100644 --- a/framework/streaming/types.go +++ b/framework/streaming/types.go @@ -2,6 +2,7 @@ package streaming import ( "sync" + "sync/atomic" "time" schemas "github.com/maximhq/bifrost/core/schemas" @@ -17,13 +18,6 @@ const ( StreamTypeResponses StreamType = "responses" ) -type StreamResponseType string - -const ( - StreamResponseTypeDelta StreamResponseType = "delta" - StreamResponseTypeFinal StreamResponseType = "final" -) - // AccumulatedData contains the accumulated data for a stream type AccumulatedData struct { RequestID string @@ -31,6 +25,7 @@ type AccumulatedData struct { Status string Stream bool Latency int64 // in milliseconds + TimeToFirstToken int64 // Time to first token in milliseconds (streaming only) StartTimestamp time.Time EndTimestamp time.Time OutputMessage *schemas.ChatMessage @@ -102,19 +97,93 @@ type ResponsesStreamChunk struct { type StreamAccumulator struct { RequestID string StartTimestamp time.Time + FirstChunkTimestamp time.Time // Timestamp when the first chunk was received (for TTFT calculation) ChatStreamChunks []*ChatStreamChunk ResponsesStreamChunks []*ResponsesStreamChunk TranscriptionStreamChunks []*TranscriptionStreamChunk AudioStreamChunks []*AudioStreamChunk - IsComplete bool - FinalTimestamp time.Time - mu sync.Mutex - Timestamp time.Time + + // De-dup maps to prevent chunk loss on out-of-order arrival + ChatChunksSeen map[int]struct{} + ResponsesChunksSeen map[int]struct{} + TranscriptionChunksSeen map[int]struct{} + AudioChunksSeen map[int]struct{} + + // Track highest ChunkIndex for metadata extraction (TokenUsage, Cost, FinishReason) + MaxChatChunkIndex int + MaxResponsesChunkIndex int + MaxTranscriptionChunkIndex int + MaxAudioChunkIndex int + + IsComplete bool + FinalTimestamp time.Time + mu sync.Mutex + Timestamp time.Time + refCount atomic.Int64 +} + +// getLastChatChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost) +func (sa *StreamAccumulator) getLastChatChunk() *ChatStreamChunk { + sa.mu.Lock() + defer sa.mu.Unlock() + if sa.MaxChatChunkIndex < 0 { + return nil + } + for _, chunk := range sa.ChatStreamChunks { + if chunk.ChunkIndex == sa.MaxChatChunkIndex { + return chunk + } + } + return nil +} + +// getLastResponsesChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost) +func (sa *StreamAccumulator) getLastResponsesChunk() *ResponsesStreamChunk { + sa.mu.Lock() + defer sa.mu.Unlock() + if sa.MaxResponsesChunkIndex < 0 { + return nil + } + for _, chunk := range sa.ResponsesStreamChunks { + if chunk.ChunkIndex == sa.MaxResponsesChunkIndex { + return chunk + } + } + return nil +} + +// getLastTranscriptionChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost) +func (sa *StreamAccumulator) getLastTranscriptionChunk() *TranscriptionStreamChunk { + sa.mu.Lock() + defer sa.mu.Unlock() + if sa.MaxTranscriptionChunkIndex < 0 { + return nil + } + for _, chunk := range sa.TranscriptionStreamChunks { + if chunk.ChunkIndex == sa.MaxTranscriptionChunkIndex { + return chunk + } + } + return nil +} + +// getLastAudioChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost) +func (sa *StreamAccumulator) getLastAudioChunk() *AudioStreamChunk { + sa.mu.Lock() + defer sa.mu.Unlock() + if sa.MaxAudioChunkIndex < 0 { + return nil + } + for _, chunk := range sa.AudioStreamChunks { + if chunk.ChunkIndex == sa.MaxAudioChunkIndex { + return chunk + } + } + return nil } // ProcessedStreamResponse represents a processed streaming response type ProcessedStreamResponse struct { - Type StreamResponseType RequestID string StreamType StreamType Provider schemas.ModelProvider @@ -159,7 +228,23 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { if p.RawRequest != nil { resp.TextCompletionResponse.ExtraFields.RawRequest = p.RawRequest } + if p.Data.RawResponse != nil { + resp.TextCompletionResponse.ExtraFields.RawResponse = *p.Data.RawResponse + } + if p.Data.CacheDebug != nil { + resp.TextCompletionResponse.ExtraFields.CacheDebug = p.Data.CacheDebug + } case StreamTypeChat: + var message *schemas.ChatMessage + if p.Data.OutputMessage != nil { + message = &schemas.ChatMessage{ + Role: p.Data.OutputMessage.Role, + Content: p.Data.OutputMessage.Content, + ChatAssistantMessage: p.Data.OutputMessage.ChatAssistantMessage, + ChatToolMessage: p.Data.OutputMessage.ChatToolMessage, + Name: p.Data.OutputMessage.Name, + } + } chatResp := &schemas.BifrostChatResponse{ ID: p.RequestID, Object: "chat.completion", @@ -169,38 +254,14 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { { Index: 0, FinishReason: p.Data.FinishReason, + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: message, + }, }, }, Usage: p.Data.TokenUsage, } - // Get reference to the choice in the slice so we can modify it - choice := &chatResp.Choices[0] - - if p.Data.OutputMessage.Content.ContentStr != nil { - choice.ChatNonStreamResponseChoice = &schemas.ChatNonStreamResponseChoice{ - Message: &schemas.ChatMessage{ - Role: schemas.ChatMessageRoleAssistant, - Content: &schemas.ChatMessageContent{ - ContentStr: p.Data.OutputMessage.Content.ContentStr, - }, - }, - } - } - if p.Data.OutputMessage.ChatAssistantMessage != nil { - if choice.ChatNonStreamResponseChoice == nil { - choice.ChatNonStreamResponseChoice = &schemas.ChatNonStreamResponseChoice{ - Message: &schemas.ChatMessage{ - Role: schemas.ChatMessageRoleAssistant, - ChatAssistantMessage: p.Data.OutputMessage.ChatAssistantMessage, - }, - } - } else { - // If we already have a message, we need to add the ChatAssistantMessage to it - choice.ChatNonStreamResponseChoice.Message.ChatAssistantMessage = p.Data.OutputMessage.ChatAssistantMessage - } - } - resp.ChatResponse = chatResp resp.ChatResponse.ExtraFields = schemas.BifrostResponseExtraFields{ RequestType: schemas.ChatCompletionRequest, @@ -211,6 +272,12 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { if p.RawRequest != nil { resp.ChatResponse.ExtraFields.RawRequest = p.RawRequest } + if p.Data.RawResponse != nil { + resp.ChatResponse.ExtraFields.RawResponse = *p.Data.RawResponse + } + if p.Data.CacheDebug != nil { + resp.ChatResponse.ExtraFields.CacheDebug = p.Data.CacheDebug + } case StreamTypeResponses: responsesResp := &schemas.BifrostResponsesResponse{} @@ -229,6 +296,12 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { if p.RawRequest != nil { responsesResp.ExtraFields.RawRequest = p.RawRequest } + if p.Data.RawResponse != nil { + responsesResp.ExtraFields.RawResponse = *p.Data.RawResponse + } + if p.Data.CacheDebug != nil { + responsesResp.ExtraFields.CacheDebug = p.Data.CacheDebug + } resp.ResponsesResponse = responsesResp case StreamTypeAudio: speechResp := p.Data.AudioOutput @@ -245,6 +318,12 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { if p.RawRequest != nil { resp.SpeechResponse.ExtraFields.RawRequest = p.RawRequest } + if p.Data.RawResponse != nil { + resp.SpeechResponse.ExtraFields.RawResponse = *p.Data.RawResponse + } + if p.Data.CacheDebug != nil { + resp.SpeechResponse.ExtraFields.CacheDebug = p.Data.CacheDebug + } case StreamTypeTranscription: transcriptionResp := p.Data.TranscriptionOutput if transcriptionResp == nil { @@ -260,6 +339,12 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { if p.RawRequest != nil { resp.TranscriptionResponse.ExtraFields.RawRequest = p.RawRequest } + if p.Data.RawResponse != nil { + resp.TranscriptionResponse.ExtraFields.RawResponse = *p.Data.RawResponse + } + if p.Data.CacheDebug != nil { + resp.TranscriptionResponse.ExtraFields.CacheDebug = p.Data.CacheDebug + } } return resp } diff --git a/framework/tracing/helpers.go b/framework/tracing/helpers.go new file mode 100644 index 0000000000..edb9a73106 --- /dev/null +++ b/framework/tracing/helpers.go @@ -0,0 +1,83 @@ +// Package tracing provides distributed tracing infrastructure for Bifrost +package tracing + +import ( + "context" + + "github.com/maximhq/bifrost/core/schemas" +) + +// GetTraceID retrieves the trace ID from the context +func GetTraceID(ctx context.Context) string { + if ctx == nil { + return "" + } + traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string) + if !ok { + return "" + } + return traceID +} + +// GetTrace retrieves the current trace from context using the store +func GetTrace(ctx context.Context, store *TraceStore) *schemas.Trace { + traceID := GetTraceID(ctx) + if traceID == "" { + return nil + } + return store.GetTrace(traceID) +} + +// AddSpan adds a new span to the current trace +func AddSpan(ctx context.Context, store *TraceStore, name string, kind schemas.SpanKind) *schemas.Span { + traceID := GetTraceID(ctx) + if traceID == "" { + return nil + } + return store.StartSpan(traceID, name, kind) +} + +// AddChildSpan adds a new child span to the current trace under a specific parent +func AddChildSpan(ctx context.Context, store *TraceStore, parentSpanID, name string, kind schemas.SpanKind) *schemas.Span { + traceID := GetTraceID(ctx) + if traceID == "" { + return nil + } + return store.StartChildSpan(traceID, parentSpanID, name, kind) +} + +// EndSpan completes a span with the given status +func EndSpan(ctx context.Context, store *TraceStore, spanID string, status schemas.SpanStatus, statusMsg string, attrs map[string]any) { + traceID := GetTraceID(ctx) + if traceID == "" { + return + } + store.EndSpan(traceID, spanID, status, statusMsg, attrs) +} + +// SetSpanAttribute sets an attribute on a span +func SetSpanAttribute(ctx context.Context, store *TraceStore, spanID, key string, value any) { + trace := GetTrace(ctx, store) + if trace == nil { + return + } + span := trace.GetSpan(spanID) + if span == nil { + return + } + span.SetAttribute(key, value) +} + +// AddSpanEvent adds an event to a span +func AddSpanEvent(ctx context.Context, store *TraceStore, spanID string, event schemas.SpanEvent) { + trace := GetTrace(ctx, store) + if trace == nil { + return + } + span := trace.GetSpan(spanID) + if span == nil { + return + } + span.AddEvent(event) +} + diff --git a/framework/tracing/llmspan.go b/framework/tracing/llmspan.go new file mode 100644 index 0000000000..fe3958ed22 --- /dev/null +++ b/framework/tracing/llmspan.go @@ -0,0 +1,1343 @@ +// Package tracing provides distributed tracing utilities for Bifrost. +package tracing + +import ( + "fmt" + "strings" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/modelcatalog" +) + +// PopulateRequestAttributes extracts common request attributes from a BifrostRequest. +// This is the main entry point for populating request attributes on a span. +func PopulateRequestAttributes(req *schemas.BifrostRequest) map[string]any { + attrs := make(map[string]any) + if req == nil { + return attrs + } + + provider, model, _ := req.GetRequestFields() + attrs[schemas.AttrProviderName] = string(provider) + attrs[schemas.AttrRequestModel] = model + + switch req.RequestType { + case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: + PopulateChatRequestAttributes(req.ChatRequest, attrs) + case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest: + PopulateTextCompletionRequestAttributes(req.TextCompletionRequest, attrs) + case schemas.EmbeddingRequest: + PopulateEmbeddingRequestAttributes(req.EmbeddingRequest, attrs) + case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: + PopulateTranscriptionRequestAttributes(req.TranscriptionRequest, attrs) + case schemas.SpeechRequest, schemas.SpeechStreamRequest: + PopulateSpeechRequestAttributes(req.SpeechRequest, attrs) + case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + PopulateResponsesRequestAttributes(req.ResponsesRequest, attrs) + case schemas.BatchCreateRequest: + PopulateBatchCreateRequestAttributes(req.BatchCreateRequest, attrs) + case schemas.BatchListRequest: + PopulateBatchListRequestAttributes(req.BatchListRequest, attrs) + case schemas.BatchRetrieveRequest: + PopulateBatchRetrieveRequestAttributes(req.BatchRetrieveRequest, attrs) + case schemas.BatchCancelRequest: + PopulateBatchCancelRequestAttributes(req.BatchCancelRequest, attrs) + case schemas.BatchResultsRequest: + PopulateBatchResultsRequestAttributes(req.BatchResultsRequest, attrs) + case schemas.FileUploadRequest: + PopulateFileUploadRequestAttributes(req.FileUploadRequest, attrs) + case schemas.FileListRequest: + PopulateFileListRequestAttributes(req.FileListRequest, attrs) + case schemas.FileRetrieveRequest: + PopulateFileRetrieveRequestAttributes(req.FileRetrieveRequest, attrs) + case schemas.FileDeleteRequest: + PopulateFileDeleteRequestAttributes(req.FileDeleteRequest, attrs) + case schemas.FileContentRequest: + PopulateFileContentRequestAttributes(req.FileContentRequest, attrs) + } + + return attrs +} + +// PopulateResponseAttributes extracts common response attributes from a BifrostResponse. +// This is the main entry point for populating response attributes on a span. +func PopulateResponseAttributes(resp *schemas.BifrostResponse) map[string]any { + attrs := make(map[string]any) + if resp == nil { + return attrs + } + + switch { + case resp.ChatResponse != nil: + PopulateChatResponseAttributes(resp.ChatResponse, attrs) + case resp.TextCompletionResponse != nil: + PopulateTextCompletionResponseAttributes(resp.TextCompletionResponse, attrs) + case resp.EmbeddingResponse != nil: + PopulateEmbeddingResponseAttributes(resp.EmbeddingResponse, attrs) + case resp.TranscriptionResponse != nil: + PopulateTranscriptionResponseAttributes(resp.TranscriptionResponse, attrs) + case resp.SpeechResponse != nil: + PopulateSpeechResponseAttributes(resp.SpeechResponse, attrs) + case resp.ResponsesResponse != nil: + PopulateResponsesResponseAttributes(resp.ResponsesResponse, attrs) + case resp.BatchCreateResponse != nil: + PopulateBatchCreateResponseAttributes(resp.BatchCreateResponse, attrs) + case resp.BatchListResponse != nil: + PopulateBatchListResponseAttributes(resp.BatchListResponse, attrs) + case resp.BatchRetrieveResponse != nil: + PopulateBatchRetrieveResponseAttributes(resp.BatchRetrieveResponse, attrs) + case resp.BatchCancelResponse != nil: + PopulateBatchCancelResponseAttributes(resp.BatchCancelResponse, attrs) + case resp.BatchResultsResponse != nil: + PopulateBatchResultsResponseAttributes(resp.BatchResultsResponse, attrs) + case resp.FileUploadResponse != nil: + PopulateFileUploadResponseAttributes(resp.FileUploadResponse, attrs) + case resp.FileListResponse != nil: + PopulateFileListResponseAttributes(resp.FileListResponse, attrs) + case resp.FileRetrieveResponse != nil: + PopulateFileRetrieveResponseAttributes(resp.FileRetrieveResponse, attrs) + case resp.FileDeleteResponse != nil: + PopulateFileDeleteResponseAttributes(resp.FileDeleteResponse, attrs) + case resp.FileContentResponse != nil: + PopulateFileContentResponseAttributes(resp.FileContentResponse, attrs) + } + + return attrs +} + +// PopulateErrorAttributes extracts error attributes from a BifrostError. +func PopulateErrorAttributes(err *schemas.BifrostError) map[string]any { + attrs := make(map[string]any) + if err == nil || err.Error == nil { + return attrs + } + + attrs[schemas.AttrError] = err.Error.Message + if err.Error.Type != nil { + attrs[schemas.AttrErrorType] = *err.Error.Type + } + if err.Error.Code != nil { + attrs[schemas.AttrErrorCode] = *err.Error.Code + } + + return attrs +} + +// PopulateContextAttributes extracts context-related attributes (virtual keys, retries, etc.) +func PopulateContextAttributes( + attrs map[string]any, + virtualKeyID, virtualKeyName string, + selectedKeyID, selectedKeyName string, + teamID, teamName string, + customerID, customerName string, + numberOfRetries, fallbackIndex int, +) { + if virtualKeyID != "" { + attrs[schemas.AttrVirtualKeyID] = virtualKeyID + attrs[schemas.AttrVirtualKeyName] = virtualKeyName + } + if selectedKeyID != "" { + attrs[schemas.AttrSelectedKeyID] = selectedKeyID + attrs[schemas.AttrSelectedKeyName] = selectedKeyName + } + if teamID != "" { + attrs[schemas.AttrTeamID] = teamID + attrs[schemas.AttrTeamName] = teamName + } + if customerID != "" { + attrs[schemas.AttrCustomerID] = customerID + attrs[schemas.AttrCustomerName] = customerName + } + attrs[schemas.AttrNumberOfRetries] = numberOfRetries + attrs[schemas.AttrFallbackIndex] = fallbackIndex +} + +// =============================================== +// Chat Completion Request/Response +// =============================================== + +// PopulateChatRequestAttributes extracts chat completion request attributes. +func PopulateChatRequestAttributes(req *schemas.BifrostChatRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.Params != nil { + if req.Params.MaxCompletionTokens != nil { + attrs[schemas.AttrMaxTokens] = *req.Params.MaxCompletionTokens + } + if req.Params.Temperature != nil { + attrs[schemas.AttrTemperature] = *req.Params.Temperature + } + if req.Params.TopP != nil { + attrs[schemas.AttrTopP] = *req.Params.TopP + } + if req.Params.Stop != nil { + attrs[schemas.AttrStopSequences] = strings.Join(req.Params.Stop, ",") + } + if req.Params.PresencePenalty != nil { + attrs[schemas.AttrPresencePenalty] = *req.Params.PresencePenalty + } + if req.Params.FrequencyPenalty != nil { + attrs[schemas.AttrFrequencyPenalty] = *req.Params.FrequencyPenalty + } + if req.Params.ParallelToolCalls != nil { + attrs[schemas.AttrParallelToolCall] = *req.Params.ParallelToolCalls + } + if req.Params.User != nil { + attrs[schemas.AttrRequestUser] = *req.Params.User + } + // ExtraParams + for k, v := range req.Params.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } + } + + // Extract input messages + if req.Input != nil { + attrs[schemas.AttrMessageCount] = len(req.Input) + messages := extractChatMessages(req.Input) + if len(messages) > 0 { + attrs[schemas.AttrInputMessages] = messages + } + } +} + +// PopulateChatResponseAttributes extracts chat completion response attributes. +func PopulateChatResponseAttributes(resp *schemas.BifrostChatResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrResponseID] = resp.ID + attrs[schemas.AttrResponseModel] = resp.Model + if resp.Object != "" { + attrs[schemas.AttrObject] = resp.Object + } + if resp.SystemFingerprint != "" { + attrs[schemas.AttrSystemFprint] = resp.SystemFingerprint + } + attrs[schemas.AttrCreated] = resp.Created + if resp.ServiceTier != nil { + attrs[schemas.AttrServiceTier] = *resp.ServiceTier + } + + // Extract output messages + outputMessages := extractChatResponseMessages(resp) + if len(outputMessages) > 0 { + attrs[schemas.AttrOutputMessages] = outputMessages + } + + // Extract finish reason from first choice + if len(resp.Choices) > 0 && resp.Choices[0].FinishReason != nil { + attrs[schemas.AttrFinishReason] = *resp.Choices[0].FinishReason + } + + // Usage + if resp.Usage != nil { + attrs[schemas.AttrPromptTokens] = resp.Usage.PromptTokens + attrs[schemas.AttrCompletionTokens] = resp.Usage.CompletionTokens + attrs[schemas.AttrTotalTokens] = resp.Usage.TotalTokens + } +} + +// =============================================== +// Text Completion Request/Response +// =============================================== + +// PopulateTextCompletionRequestAttributes extracts text completion request attributes. +func PopulateTextCompletionRequestAttributes(req *schemas.BifrostTextCompletionRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.Params != nil { + if req.Params.MaxTokens != nil { + attrs[schemas.AttrMaxTokens] = *req.Params.MaxTokens + } + if req.Params.Temperature != nil { + attrs[schemas.AttrTemperature] = *req.Params.Temperature + } + if req.Params.TopP != nil { + attrs[schemas.AttrTopP] = *req.Params.TopP + } + if req.Params.Stop != nil { + attrs[schemas.AttrStopSequences] = strings.Join(req.Params.Stop, ",") + } + if req.Params.PresencePenalty != nil { + attrs[schemas.AttrPresencePenalty] = *req.Params.PresencePenalty + } + if req.Params.FrequencyPenalty != nil { + attrs[schemas.AttrFrequencyPenalty] = *req.Params.FrequencyPenalty + } + if req.Params.BestOf != nil { + attrs[schemas.AttrBestOf] = *req.Params.BestOf + } + if req.Params.Echo != nil { + attrs[schemas.AttrEcho] = *req.Params.Echo + } + if req.Params.LogitBias != nil { + attrs[schemas.AttrLogitBias] = fmt.Sprintf("%v", req.Params.LogitBias) + } + if req.Params.LogProbs != nil { + attrs[schemas.AttrLogProbs] = *req.Params.LogProbs + } + if req.Params.N != nil { + attrs[schemas.AttrN] = *req.Params.N + } + if req.Params.Seed != nil { + attrs[schemas.AttrSeed] = *req.Params.Seed + } + if req.Params.Suffix != nil { + attrs[schemas.AttrSuffix] = *req.Params.Suffix + } + if req.Params.User != nil { + attrs[schemas.AttrRequestUser] = *req.Params.User + } + // ExtraParams + for k, v := range req.Params.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } + } + + // Extract input text + if req.Input != nil { + if req.Input.PromptStr != nil { + attrs[schemas.AttrInputText] = *req.Input.PromptStr + } else if req.Input.PromptArray != nil { + attrs[schemas.AttrInputText] = strings.Join(req.Input.PromptArray, ",") + } + } +} + +// PopulateTextCompletionResponseAttributes extracts text completion response attributes. +func PopulateTextCompletionResponseAttributes(resp *schemas.BifrostTextCompletionResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrResponseID] = resp.ID + attrs[schemas.AttrResponseModel] = resp.Model + if resp.Object != "" { + attrs[schemas.AttrObject] = resp.Object + } + if resp.SystemFingerprint != "" { + attrs[schemas.AttrSystemFprint] = resp.SystemFingerprint + } + + // Extract output text + var outputs []string + for _, choice := range resp.Choices { + if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil { + outputs = append(outputs, *choice.TextCompletionResponseChoice.Text) + } + } + if len(outputs) > 0 { + attrs[schemas.AttrOutputMessages] = outputs + } + + // Usage + if resp.Usage != nil { + attrs[schemas.AttrPromptTokens] = resp.Usage.PromptTokens + attrs[schemas.AttrCompletionTokens] = resp.Usage.CompletionTokens + attrs[schemas.AttrTotalTokens] = resp.Usage.TotalTokens + } +} + +// =============================================== +// Embedding Request/Response +// =============================================== + +// PopulateEmbeddingRequestAttributes extracts embedding request attributes. +func PopulateEmbeddingRequestAttributes(req *schemas.BifrostEmbeddingRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.Params != nil { + if req.Params.Dimensions != nil { + attrs[schemas.AttrDimensions] = *req.Params.Dimensions + } + if req.Params.EncodingFormat != nil { + attrs[schemas.AttrEncodingFormat] = *req.Params.EncodingFormat + } + // ExtraParams + for k, v := range req.Params.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } + } + + // Extract input + if req.Input != nil { + if req.Input.Text != nil { + attrs[schemas.AttrInputText] = *req.Input.Text + } else if req.Input.Texts != nil { + attrs[schemas.AttrInputText] = strings.Join(req.Input.Texts, ",") + } else if req.Input.Embedding != nil { + embedding := make([]string, len(req.Input.Embedding)) + for i, v := range req.Input.Embedding { + // Use a float‑safe representation; adjust precision as needed. + embedding[i] = fmt.Sprintf("%v", v) + } + attrs[schemas.AttrInputEmbedding] = strings.Join(embedding, ",") + } + } +} + +// PopulateEmbeddingResponseAttributes extracts embedding response attributes. +func PopulateEmbeddingResponseAttributes(resp *schemas.BifrostEmbeddingResponse, attrs map[string]any) { + if resp == nil { + return + } + // Usage + if resp.Usage != nil { + attrs[schemas.AttrPromptTokens] = resp.Usage.PromptTokens + attrs[schemas.AttrCompletionTokens] = resp.Usage.CompletionTokens + attrs[schemas.AttrTotalTokens] = resp.Usage.TotalTokens + } +} + +// =============================================== +// Transcription Request/Response +// =============================================== + +// PopulateTranscriptionRequestAttributes extracts transcription request attributes. +func PopulateTranscriptionRequestAttributes(req *schemas.BifrostTranscriptionRequest, attrs map[string]any) { + if req == nil || req.Params == nil { + return + } + + if req.Params.Language != nil { + attrs[schemas.AttrLanguage] = *req.Params.Language + } + if req.Params.Prompt != nil { + attrs[schemas.AttrPrompt] = *req.Params.Prompt + } + if req.Params.ResponseFormat != nil { + attrs[schemas.AttrResponseFormat] = *req.Params.ResponseFormat + } + if req.Params.Format != nil { + attrs[schemas.AttrFormat] = *req.Params.Format + } +} + +// PopulateTranscriptionResponseAttributes extracts transcription response attributes. +func PopulateTranscriptionResponseAttributes(resp *schemas.BifrostTranscriptionResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrOutputMessages] = resp.Text + + // Usage + if resp.Usage != nil { + if resp.Usage.InputTokens != nil { + attrs[schemas.AttrInputTokens] = *resp.Usage.InputTokens + } + if resp.Usage.OutputTokens != nil { + attrs[schemas.AttrOutputTokens] = *resp.Usage.OutputTokens + } + if resp.Usage.TotalTokens != nil { + attrs[schemas.AttrTotalTokens] = *resp.Usage.TotalTokens + } + if resp.Usage.InputTokenDetails != nil { + attrs[schemas.AttrInputTokenDetailsText] = resp.Usage.InputTokenDetails.TextTokens + attrs[schemas.AttrInputTokenDetailsAudio] = resp.Usage.InputTokenDetails.AudioTokens + } + } +} + +// =============================================== +// Speech Request/Response +// =============================================== + +// PopulateSpeechRequestAttributes extracts speech request attributes. +func PopulateSpeechRequestAttributes(req *schemas.BifrostSpeechRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.Params != nil { + if req.Params.VoiceConfig != nil { + if req.Params.VoiceConfig.Voice != nil { + attrs[schemas.AttrVoice] = *req.Params.VoiceConfig.Voice + } + if len(req.Params.VoiceConfig.MultiVoiceConfig) > 0 { + voices := make([]string, len(req.Params.VoiceConfig.MultiVoiceConfig)) + for i, vc := range req.Params.VoiceConfig.MultiVoiceConfig { + voices[i] = vc.Voice + } + attrs[schemas.AttrMultiVoiceConfig] = strings.Join(voices, ",") + } + } + if req.Params.Instructions != "" { + attrs[schemas.AttrInstructions] = req.Params.Instructions + } + if req.Params.ResponseFormat != "" { + attrs[schemas.AttrResponseFormat] = req.Params.ResponseFormat + } + if req.Params.Speed != nil { + attrs[schemas.AttrSpeed] = *req.Params.Speed + } + } + + if req.Input != nil && req.Input.Input != "" { + attrs[schemas.AttrInputSpeech] = req.Input.Input + } +} + +// PopulateSpeechResponseAttributes extracts speech response attributes. +func PopulateSpeechResponseAttributes(resp *schemas.BifrostSpeechResponse, attrs map[string]any) { + if resp == nil { + return + } + + // Usage + if resp.Usage != nil { + attrs[schemas.AttrInputTokens] = resp.Usage.InputTokens + attrs[schemas.AttrOutputTokens] = resp.Usage.OutputTokens + attrs[schemas.AttrTotalTokens] = resp.Usage.TotalTokens + } +} + +// =============================================== +// Responses API Request/Response +// =============================================== + +// PopulateResponsesRequestAttributes extracts responses API request attributes. +func PopulateResponsesRequestAttributes(req *schemas.BifrostResponsesRequest, attrs map[string]any) { + if req == nil || req.Params == nil { + return + } + + if req.Params.ParallelToolCalls != nil { + attrs[schemas.AttrParallelToolCall] = *req.Params.ParallelToolCalls + } + if req.Params.PromptCacheKey != nil { + attrs[schemas.AttrPromptCacheKey] = *req.Params.PromptCacheKey + } + if req.Params.Reasoning != nil { + if req.Params.Reasoning.Effort != nil { + attrs[schemas.AttrReasoningEffort] = *req.Params.Reasoning.Effort + } + if req.Params.Reasoning.Summary != nil { + attrs[schemas.AttrReasoningSummary] = *req.Params.Reasoning.Summary + } + if req.Params.Reasoning.GenerateSummary != nil { + attrs[schemas.AttrReasoningGenSummary] = *req.Params.Reasoning.GenerateSummary + } + } + if req.Params.SafetyIdentifier != nil { + attrs[schemas.AttrSafetyIdentifier] = *req.Params.SafetyIdentifier + } + if req.Params.ServiceTier != nil { + attrs[schemas.AttrServiceTier] = *req.Params.ServiceTier + } + if req.Params.Store != nil { + attrs[schemas.AttrStore] = *req.Params.Store + } + if req.Params.Temperature != nil { + attrs[schemas.AttrTemperature] = *req.Params.Temperature + } + if req.Params.Text != nil { + if req.Params.Text.Verbosity != nil { + attrs[schemas.AttrTextVerbosity] = *req.Params.Text.Verbosity + } + if req.Params.Text.Format != nil { + attrs[schemas.AttrTextFormatType] = req.Params.Text.Format.Type + } + } + if req.Params.TopLogProbs != nil { + attrs[schemas.AttrTopLogProbs] = *req.Params.TopLogProbs + } + if req.Params.TopP != nil { + attrs[schemas.AttrTopP] = *req.Params.TopP + } + if req.Params.ToolChoice != nil { + if req.Params.ToolChoice.ResponsesToolChoiceStr != nil && *req.Params.ToolChoice.ResponsesToolChoiceStr != "" { + attrs[schemas.AttrToolChoiceType] = *req.Params.ToolChoice.ResponsesToolChoiceStr + } + if req.Params.ToolChoice.ResponsesToolChoiceStruct != nil && req.Params.ToolChoice.ResponsesToolChoiceStruct.Name != nil { + attrs[schemas.AttrToolChoiceName] = *req.Params.ToolChoice.ResponsesToolChoiceStruct.Name + } + } + if req.Params.Tools != nil { + tools := make([]string, len(req.Params.Tools)) + for i, tool := range req.Params.Tools { + tools[i] = string(tool.Type) + } + attrs[schemas.AttrTools] = strings.Join(tools, ",") + } + if req.Params.Truncation != nil { + attrs[schemas.AttrTruncation] = *req.Params.Truncation + } + // ExtraParams + for k, v := range req.Params.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateResponsesResponseAttributes extracts responses API response attributes. +func PopulateResponsesResponseAttributes(resp *schemas.BifrostResponsesResponse, attrs map[string]any) { + if resp == nil { + return + } + + if resp.ID != nil && *resp.ID != "" { + attrs[schemas.AttrResponseID] = *resp.ID + } + if resp.Model != "" { + attrs[schemas.AttrResponseModel] = resp.Model + } + if resp.ServiceTier != nil { + attrs[schemas.AttrServiceTier] = *resp.ServiceTier + } + + // Extract output messages (includes reasoning) + outputMessages := extractResponsesOutputMessages(resp) + if len(outputMessages) > 0 { + attrs[schemas.AttrOutputMessages] = outputMessages + } + + // Additional response fields + if resp.Include != nil { + attrs[schemas.AttrRespInclude] = strings.Join(resp.Include, ",") + } + if resp.MaxOutputTokens != nil { + attrs[schemas.AttrRespMaxOutputTokens] = *resp.MaxOutputTokens + } + if resp.MaxToolCalls != nil { + attrs[schemas.AttrRespMaxToolCalls] = *resp.MaxToolCalls + } + if resp.Metadata != nil { + attrs[schemas.AttrRespMetadata] = fmt.Sprintf("%v", resp.Metadata) + } + if resp.PreviousResponseID != nil { + attrs[schemas.AttrRespPreviousRespID] = *resp.PreviousResponseID + } + if resp.PromptCacheKey != nil { + attrs[schemas.AttrRespPromptCacheKey] = *resp.PromptCacheKey + } + if resp.Reasoning != nil { + if resp.Reasoning.Summary != nil { + attrs[schemas.AttrRespReasoningText] = *resp.Reasoning.Summary + } + if resp.Reasoning.Effort != nil { + attrs[schemas.AttrRespReasoningEffort] = *resp.Reasoning.Effort + } + if resp.Reasoning.GenerateSummary != nil { + attrs[schemas.AttrRespReasoningGenSum] = *resp.Reasoning.GenerateSummary + } + } + if resp.SafetyIdentifier != nil { + attrs[schemas.AttrRespSafetyIdentifier] = *resp.SafetyIdentifier + } + if resp.Store != nil { + attrs[schemas.AttrRespStore] = *resp.Store + } + if resp.Temperature != nil { + attrs[schemas.AttrRespTemperature] = *resp.Temperature + } + if resp.Text != nil { + if resp.Text.Verbosity != nil { + attrs[schemas.AttrRespTextVerbosity] = *resp.Text.Verbosity + } + if resp.Text.Format != nil { + attrs[schemas.AttrRespTextFormatType] = resp.Text.Format.Type + } + } + if resp.TopLogProbs != nil { + attrs[schemas.AttrRespTopLogProbs] = *resp.TopLogProbs + } + if resp.TopP != nil { + attrs[schemas.AttrRespTopP] = *resp.TopP + } + if resp.ToolChoice != nil { + if resp.ToolChoice.ResponsesToolChoiceStr != nil { + attrs[schemas.AttrRespToolChoiceType] = *resp.ToolChoice.ResponsesToolChoiceStr + } + if resp.ToolChoice.ResponsesToolChoiceStruct != nil && resp.ToolChoice.ResponsesToolChoiceStruct.Name != nil { + attrs[schemas.AttrRespToolChoiceName] = *resp.ToolChoice.ResponsesToolChoiceStruct.Name + } + } + if resp.Truncation != nil { + attrs[schemas.AttrRespTruncation] = *resp.Truncation + } + if resp.Tools != nil { + tools := make([]string, len(resp.Tools)) + for i, tool := range resp.Tools { + tools[i] = string(tool.Type) + } + attrs[schemas.AttrRespTools] = strings.Join(tools, ",") + } + + // Usage + if resp.Usage != nil { + attrs[schemas.AttrInputTokens] = resp.Usage.InputTokens + attrs[schemas.AttrOutputTokens] = resp.Usage.OutputTokens + attrs[schemas.AttrTotalTokens] = resp.Usage.TotalTokens + } +} + +// =============================================== +// Batch Operations Request/Response +// =============================================== + +// PopulateBatchCreateRequestAttributes extracts batch create request attributes. +func PopulateBatchCreateRequestAttributes(req *schemas.BifrostBatchCreateRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.InputFileID != "" { + attrs[schemas.AttrBatchInputFileID] = req.InputFileID + } + if req.Endpoint != "" { + attrs[schemas.AttrBatchEndpoint] = string(req.Endpoint) + } + if req.CompletionWindow != "" { + attrs[schemas.AttrBatchCompletionWin] = req.CompletionWindow + } + if len(req.Requests) > 0 { + attrs[schemas.AttrBatchRequestsCount] = len(req.Requests) + } + if len(req.Metadata) > 0 { + attrs[schemas.AttrBatchMetadata] = fmt.Sprintf("%v", req.Metadata) + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateBatchListRequestAttributes extracts batch list request attributes. +func PopulateBatchListRequestAttributes(req *schemas.BifrostBatchListRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.Limit > 0 { + attrs[schemas.AttrBatchLimit] = req.Limit + } + if req.After != nil { + attrs[schemas.AttrBatchAfter] = *req.After + } + if req.BeforeID != nil { + attrs[schemas.AttrBatchBeforeID] = *req.BeforeID + } + if req.AfterID != nil { + attrs[schemas.AttrBatchAfterID] = *req.AfterID + } + if req.PageToken != nil { + attrs[schemas.AttrBatchPageToken] = *req.PageToken + } + if req.PageSize > 0 { + attrs[schemas.AttrBatchPageSize] = req.PageSize + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateBatchRetrieveRequestAttributes extracts batch retrieve request attributes. +func PopulateBatchRetrieveRequestAttributes(req *schemas.BifrostBatchRetrieveRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.BatchID != "" { + attrs[schemas.AttrBatchID] = req.BatchID + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateBatchCancelRequestAttributes extracts batch cancel request attributes. +func PopulateBatchCancelRequestAttributes(req *schemas.BifrostBatchCancelRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.BatchID != "" { + attrs[schemas.AttrBatchID] = req.BatchID + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateBatchResultsRequestAttributes extracts batch results request attributes. +func PopulateBatchResultsRequestAttributes(req *schemas.BifrostBatchResultsRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.BatchID != "" { + attrs[schemas.AttrBatchID] = req.BatchID + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateBatchCreateResponseAttributes extracts batch create response attributes. +func PopulateBatchCreateResponseAttributes(resp *schemas.BifrostBatchCreateResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrBatchID] = resp.ID + attrs[schemas.AttrBatchStatus] = string(resp.Status) + if resp.Object != "" { + attrs[schemas.AttrBatchObject] = resp.Object + } + if resp.Endpoint != "" { + attrs[schemas.AttrBatchEndpoint] = resp.Endpoint + } + if resp.InputFileID != "" { + attrs[schemas.AttrBatchInputFileID] = resp.InputFileID + } + if resp.CompletionWindow != "" { + attrs[schemas.AttrBatchCompletionWin] = resp.CompletionWindow + } + if resp.CreatedAt != 0 { + attrs[schemas.AttrBatchCreatedAt] = resp.CreatedAt + } + if resp.ExpiresAt != nil { + attrs[schemas.AttrBatchExpiresAt] = *resp.ExpiresAt + } + if resp.OutputFileID != nil { + attrs[schemas.AttrBatchOutputFileID] = *resp.OutputFileID + } + if resp.ErrorFileID != nil { + attrs[schemas.AttrBatchErrorFileID] = *resp.ErrorFileID + } + attrs[schemas.AttrBatchCountTotal] = resp.RequestCounts.Total + attrs[schemas.AttrBatchCountCompleted] = resp.RequestCounts.Completed + attrs[schemas.AttrBatchCountFailed] = resp.RequestCounts.Failed +} + +// PopulateBatchListResponseAttributes extracts batch list response attributes. +func PopulateBatchListResponseAttributes(resp *schemas.BifrostBatchListResponse, attrs map[string]any) { + if resp == nil { + return + } + + if resp.Object != "" { + attrs[schemas.AttrBatchObject] = resp.Object + } + attrs[schemas.AttrBatchDataCount] = len(resp.Data) + attrs[schemas.AttrBatchHasMore] = resp.HasMore + if resp.FirstID != nil { + attrs[schemas.AttrBatchFirstID] = *resp.FirstID + } + if resp.LastID != nil { + attrs[schemas.AttrBatchLastID] = *resp.LastID + } +} + +// PopulateBatchRetrieveResponseAttributes extracts batch retrieve response attributes. +func PopulateBatchRetrieveResponseAttributes(resp *schemas.BifrostBatchRetrieveResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrBatchID] = resp.ID + attrs[schemas.AttrBatchStatus] = string(resp.Status) + if resp.Object != "" { + attrs[schemas.AttrBatchObject] = resp.Object + } + if resp.Endpoint != "" { + attrs[schemas.AttrBatchEndpoint] = resp.Endpoint + } + if resp.InputFileID != "" { + attrs[schemas.AttrBatchInputFileID] = resp.InputFileID + } + if resp.CompletionWindow != "" { + attrs[schemas.AttrBatchCompletionWin] = resp.CompletionWindow + } + if resp.CreatedAt != 0 { + attrs[schemas.AttrBatchCreatedAt] = resp.CreatedAt + } + if resp.ExpiresAt != nil { + attrs[schemas.AttrBatchExpiresAt] = *resp.ExpiresAt + } + if resp.InProgressAt != nil { + attrs[schemas.AttrBatchInProgressAt] = *resp.InProgressAt + } + if resp.FinalizingAt != nil { + attrs[schemas.AttrBatchFinalizingAt] = *resp.FinalizingAt + } + if resp.CompletedAt != nil { + attrs[schemas.AttrBatchCompletedAt] = *resp.CompletedAt + } + if resp.FailedAt != nil { + attrs[schemas.AttrBatchFailedAt] = *resp.FailedAt + } + if resp.ExpiredAt != nil { + attrs[schemas.AttrBatchExpiredAt] = *resp.ExpiredAt + } + if resp.CancellingAt != nil { + attrs[schemas.AttrBatchCancellingAt] = *resp.CancellingAt + } + if resp.CancelledAt != nil { + attrs[schemas.AttrBatchCancelledAt] = *resp.CancelledAt + } + if resp.OutputFileID != nil { + attrs[schemas.AttrBatchOutputFileID] = *resp.OutputFileID + } + if resp.ErrorFileID != nil { + attrs[schemas.AttrBatchErrorFileID] = *resp.ErrorFileID + } + attrs[schemas.AttrBatchCountTotal] = resp.RequestCounts.Total + attrs[schemas.AttrBatchCountCompleted] = resp.RequestCounts.Completed + attrs[schemas.AttrBatchCountFailed] = resp.RequestCounts.Failed +} + +// PopulateBatchCancelResponseAttributes extracts batch cancel response attributes. +func PopulateBatchCancelResponseAttributes(resp *schemas.BifrostBatchCancelResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrBatchID] = resp.ID + attrs[schemas.AttrBatchStatus] = string(resp.Status) + if resp.Object != "" { + attrs[schemas.AttrBatchObject] = resp.Object + } + if resp.CancellingAt != nil { + attrs[schemas.AttrBatchCancellingAt] = *resp.CancellingAt + } + if resp.CancelledAt != nil { + attrs[schemas.AttrBatchCancelledAt] = *resp.CancelledAt + } + attrs[schemas.AttrBatchCountTotal] = resp.RequestCounts.Total + attrs[schemas.AttrBatchCountCompleted] = resp.RequestCounts.Completed + attrs[schemas.AttrBatchCountFailed] = resp.RequestCounts.Failed +} + +// PopulateBatchResultsResponseAttributes extracts batch results response attributes. +func PopulateBatchResultsResponseAttributes(resp *schemas.BifrostBatchResultsResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrBatchID] = resp.BatchID + attrs[schemas.AttrBatchResultsCount] = len(resp.Results) + attrs[schemas.AttrBatchHasMore] = resp.HasMore + if resp.NextCursor != nil { + attrs[schemas.AttrBatchNextCursor] = *resp.NextCursor + } +} + +// =============================================== +// File Operations Request/Response +// =============================================== + +// PopulateFileUploadRequestAttributes extracts file upload request attributes. +func PopulateFileUploadRequestAttributes(req *schemas.BifrostFileUploadRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.Filename != "" { + attrs[schemas.AttrFileFilename] = req.Filename + } + if req.Purpose != "" { + attrs[schemas.AttrFilePurpose] = string(req.Purpose) + } + if len(req.File) > 0 { + attrs[schemas.AttrFileBytes] = len(req.File) + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateFileListRequestAttributes extracts file list request attributes. +func PopulateFileListRequestAttributes(req *schemas.BifrostFileListRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.Purpose != "" { + attrs[schemas.AttrFilePurpose] = string(req.Purpose) + } + if req.Limit > 0 { + attrs[schemas.AttrFileLimit] = req.Limit + } + if req.After != nil { + attrs[schemas.AttrFileAfter] = *req.After + } + if req.Order != nil { + attrs[schemas.AttrFileOrder] = *req.Order + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateFileRetrieveRequestAttributes extracts file retrieve request attributes. +func PopulateFileRetrieveRequestAttributes(req *schemas.BifrostFileRetrieveRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.FileID != "" { + attrs[schemas.AttrFileID] = req.FileID + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateFileDeleteRequestAttributes extracts file delete request attributes. +func PopulateFileDeleteRequestAttributes(req *schemas.BifrostFileDeleteRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.FileID != "" { + attrs[schemas.AttrFileID] = req.FileID + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateFileContentRequestAttributes extracts file content request attributes. +func PopulateFileContentRequestAttributes(req *schemas.BifrostFileContentRequest, attrs map[string]any) { + if req == nil { + return + } + + if req.FileID != "" { + attrs[schemas.AttrFileID] = req.FileID + } + // ExtraParams + for k, v := range req.ExtraParams { + attrs[k] = fmt.Sprintf("%v", v) + } +} + +// PopulateFileUploadResponseAttributes extracts file upload response attributes. +func PopulateFileUploadResponseAttributes(resp *schemas.BifrostFileUploadResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrFileID] = resp.ID + if resp.Object != "" { + attrs[schemas.AttrFileObject] = resp.Object + } + attrs[schemas.AttrFileBytes] = resp.Bytes + attrs[schemas.AttrFileCreatedAt] = resp.CreatedAt + attrs[schemas.AttrFileFilename] = resp.Filename + attrs[schemas.AttrFilePurpose] = string(resp.Purpose) + if resp.Status != "" { + attrs[schemas.AttrFileStatus] = string(resp.Status) + } + if resp.StorageBackend != "" { + attrs[schemas.AttrFileStorageBackend] = string(resp.StorageBackend) + } +} + +// PopulateFileListResponseAttributes extracts file list response attributes. +func PopulateFileListResponseAttributes(resp *schemas.BifrostFileListResponse, attrs map[string]any) { + if resp == nil { + return + } + + if resp.Object != "" { + attrs[schemas.AttrFileObject] = resp.Object + } + attrs[schemas.AttrFileDataCount] = len(resp.Data) + attrs[schemas.AttrFileHasMore] = resp.HasMore +} + +// PopulateFileRetrieveResponseAttributes extracts file retrieve response attributes. +func PopulateFileRetrieveResponseAttributes(resp *schemas.BifrostFileRetrieveResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrFileID] = resp.ID + if resp.Object != "" { + attrs[schemas.AttrFileObject] = resp.Object + } + attrs[schemas.AttrFileBytes] = resp.Bytes + attrs[schemas.AttrFileCreatedAt] = resp.CreatedAt + attrs[schemas.AttrFileFilename] = resp.Filename + attrs[schemas.AttrFilePurpose] = string(resp.Purpose) + if resp.Status != "" { + attrs[schemas.AttrFileStatus] = string(resp.Status) + } + if resp.StorageBackend != "" { + attrs[schemas.AttrFileStorageBackend] = string(resp.StorageBackend) + } +} + +// PopulateFileDeleteResponseAttributes extracts file delete response attributes. +func PopulateFileDeleteResponseAttributes(resp *schemas.BifrostFileDeleteResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrFileID] = resp.ID + if resp.Object != "" { + attrs[schemas.AttrFileObject] = resp.Object + } + attrs[schemas.AttrFileDeleted] = resp.Deleted +} + +// PopulateFileContentResponseAttributes extracts file content response attributes. +func PopulateFileContentResponseAttributes(resp *schemas.BifrostFileContentResponse, attrs map[string]any) { + if resp == nil { + return + } + + attrs[schemas.AttrFileID] = resp.FileID + if resp.ContentType != "" { + attrs[schemas.AttrFileContentType] = resp.ContentType + } + if len(resp.Content) > 0 { + attrs[schemas.AttrFileContentBytes] = len(resp.Content) + } +} + +// =============================================== +// Helper functions for extracting messages +// =============================================== + +// MessageSummary represents a summarized chat message for tracing +type MessageSummary struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCallSummary `json:"tool_calls,omitempty"` + Reasoning string `json:"reasoning,omitempty"` + ReasoningDetails []ReasoningDetailSummary `json:"reasoning_details,omitempty"` + Audio *AudioSummary `json:"audio,omitempty"` + Refusal string `json:"refusal,omitempty"` +} + +// ToolCallSummary represents a summarized tool call for tracing +type ToolCallSummary struct { + ID string `json:"id"` + Type string `json:"type"` + Name string `json:"name"` + Args string `json:"args,omitempty"` +} + +// ReasoningDetailSummary represents a summarized reasoning detail for tracing +type ReasoningDetailSummary struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + +// AudioSummary represents summarized audio data for tracing +type AudioSummary struct { + ID string `json:"id,omitempty"` + Transcript string `json:"transcript,omitempty"` +} + +// extractChatMessages extracts chat messages into a slice of MessageSummary +func extractChatMessages(messages []schemas.ChatMessage) []MessageSummary { + result := make([]MessageSummary, 0, len(messages)) + for _, msg := range messages { + summary := extractMessageSummary(&msg) + result = append(result, summary) + } + return result +} + +// extractChatResponseMessages extracts output messages from chat response +func extractChatResponseMessages(resp *schemas.BifrostChatResponse) []MessageSummary { + if resp == nil { + return nil + } + + result := make([]MessageSummary, 0, len(resp.Choices)) + for _, choice := range resp.Choices { + if choice.ChatNonStreamResponseChoice == nil || choice.ChatNonStreamResponseChoice.Message == nil { + continue + } + msg := choice.ChatNonStreamResponseChoice.Message + summary := extractMessageSummary(msg) + result = append(result, summary) + } + return result +} + +// extractMessageSummary extracts a full MessageSummary from a ChatMessage +func extractMessageSummary(msg *schemas.ChatMessage) MessageSummary { + if msg == nil { + return MessageSummary{} + } + + summary := MessageSummary{ + Role: string(schemas.ChatMessageRoleAssistant), + Content: extractMessageContent(msg.Content), + } + + if msg.Role != "" { + summary.Role = string(msg.Role) + } + + // Extract assistant-specific fields + if msg.ChatAssistantMessage != nil { + am := msg.ChatAssistantMessage + + // Extract refusal + if am.Refusal != nil && *am.Refusal != "" { + summary.Refusal = *am.Refusal + } + + // Extract reasoning + if am.Reasoning != nil && *am.Reasoning != "" { + summary.Reasoning = *am.Reasoning + } + + // Extract reasoning details + if len(am.ReasoningDetails) > 0 { + summary.ReasoningDetails = make([]ReasoningDetailSummary, 0, len(am.ReasoningDetails)) + for _, rd := range am.ReasoningDetails { + detail := ReasoningDetailSummary{ + Type: string(rd.Type), + } + if rd.Text != nil { + detail.Text = *rd.Text + } + summary.ReasoningDetails = append(summary.ReasoningDetails, detail) + } + } + + // Extract audio + if am.Audio != nil { + summary.Audio = &AudioSummary{ + ID: am.Audio.ID, + Transcript: am.Audio.Transcript, + } + } + + // Extract tool calls + if len(am.ToolCalls) > 0 { + summary.ToolCalls = make([]ToolCallSummary, 0, len(am.ToolCalls)) + for _, tc := range am.ToolCalls { + toolCall := ToolCallSummary{ + Type: "function", + } + if tc.ID != nil { + toolCall.ID = *tc.ID + } + if tc.Type != nil { + toolCall.Type = *tc.Type + } + if tc.Function.Name != nil { + toolCall.Name = *tc.Function.Name + } + toolCall.Args = tc.Function.Arguments + summary.ToolCalls = append(summary.ToolCalls, toolCall) + } + } + } + + return summary +} + +// ResponsesMessageSummary extends MessageSummary with reasoning +type ResponsesMessageSummary struct { + Role string `json:"role"` + Content string `json:"content"` + Reasoning string `json:"reasoning,omitempty"` +} + +// extractResponsesOutputMessages extracts output messages from responses API +func extractResponsesOutputMessages(resp *schemas.BifrostResponsesResponse) []ResponsesMessageSummary { + if resp == nil { + return nil + } + + result := make([]ResponsesMessageSummary, 0, len(resp.Output)) + for _, msg := range resp.Output { + if msg.Role == nil { + continue + } + content := "" + if msg.Content != nil { + if msg.Content.ContentStr != nil && *msg.Content.ContentStr != "" { + content = *msg.Content.ContentStr + } else if msg.Content.ContentBlocks != nil { + for _, block := range msg.Content.ContentBlocks { + if block.Text != nil { + content += *block.Text + } + } + } + } + // Extract reasoning text + reasoning := "" + if msg.ResponsesReasoning != nil && msg.ResponsesReasoning.Summary != nil { + for _, block := range msg.ResponsesReasoning.Summary { + if block.Text != "" { + reasoning += block.Text + } + } + } + result = append(result, ResponsesMessageSummary{ + Role: string(*msg.Role), + Content: content, + Reasoning: reasoning, + }) + } + return result +} + +// extractMessageContent extracts text content from ChatMessageContent +func extractMessageContent(content *schemas.ChatMessageContent) string { + if content == nil { + return "" + } + + if content.ContentStr != nil { + return *content.ContentStr + } + + if content.ContentBlocks != nil { + var builder strings.Builder + for _, block := range content.ContentBlocks { + if block.Text != nil { + builder.WriteString(*block.Text) + } + } + return builder.String() + } + + return "" +} + +// =============================================== +// Cost Calculation +// =============================================== + +// PopulateCostAttribute calculates and adds the cost attribute for a response. +// The pricingManager is optional; if nil, no cost attribute is added. +func PopulateCostAttribute( + resp *schemas.BifrostResponse, + pricingManager *modelcatalog.ModelCatalog, + attrs map[string]any, +) { + if pricingManager == nil || resp == nil { + return + } + cost := pricingManager.CalculateCostWithCacheDebug(resp) + attrs[schemas.AttrUsageCost] = cost +} diff --git a/framework/tracing/propagation.go b/framework/tracing/propagation.go new file mode 100644 index 0000000000..6e8c5ae935 --- /dev/null +++ b/framework/tracing/propagation.go @@ -0,0 +1,211 @@ +// Package tracing provides distributed tracing infrastructure for Bifrost +package tracing + +import ( + "strings" + + "github.com/valyala/fasthttp" +) + +// normalizeTraceID normalizes a trace ID to W3C-compliant format. +// Strips hyphens and ensures 32 lowercase hex characters. +// Returns empty string if input cannot be normalized to a valid trace ID. +func normalizeTraceID(traceID string) string { + // Remove hyphens (handles UUID format) + normalized := strings.ReplaceAll(traceID, "-", "") + normalized = strings.ToLower(normalized) + + // Validate length - must be exactly 32 hex chars + if len(normalized) != 32 { + return "" + } + + // Validate hex characters + if !isHex(normalized) { + return "" + } + + return normalized +} + +// normalizeSpanID normalizes a span ID to W3C-compliant format. +// Strips hyphens and ensures 16 lowercase hex characters. +// If input is longer (e.g., UUID format), takes first 16 hex chars. +// Returns empty string if input cannot be normalized to a valid span ID. +func normalizeSpanID(spanID string) string { + // Remove hyphens (handles UUID format) + normalized := strings.ReplaceAll(spanID, "-", "") + normalized = strings.ToLower(normalized) + + // If longer than 16 chars, truncate (e.g., full UUID -> first 16 hex chars) + if len(normalized) > 16 { + normalized = normalized[:16] + } + + // Validate length - must be exactly 16 hex chars + if len(normalized) != 16 { + return "" + } + + // Validate hex characters + if !isHex(normalized) { + return "" + } + + return normalized +} + +// W3C Trace Context header names +const ( + TraceParentHeader = "traceparent" + TraceStateHeader = "tracestate" +) + +// W3CTraceContext holds parsed W3C trace context values +type W3CTraceContext struct { + TraceID string // 32 hex characters + ParentID string // 16 hex characters (span ID of parent) + TraceFlags string // 2 hex characters + TraceState string // Optional vendor-specific trace state +} + +// ExtractParentID extracts the trace ID from W3C traceparent header. +// This returns the trace ID (32 hex chars) which should be used to continue +// the distributed trace from the upstream service. +// Returns empty string if header is not present or invalid. +func ExtractParentID(header *fasthttp.RequestHeader) string { + traceParent := string(header.Peek(TraceParentHeader)) + if traceParent == "" { + return "" + } + ctx := ParseTraceparent(traceParent) + if ctx == nil { + return "" + } + return ctx.TraceID +} + +// ExtractTraceParentSpanID extracts the parent span ID from W3C traceparent header. +// This returns the span ID (16 hex chars) of the upstream service's span that +// initiated this request. This should be set as the ParentID of the root span +// in the receiving service to establish the parent-child relationship. +// Returns empty string if header is not present or invalid. +func ExtractTraceParentSpanID(header *fasthttp.RequestHeader) string { + traceParent := string(header.Peek(TraceParentHeader)) + if traceParent == "" { + return "" + } + ctx := ParseTraceparent(traceParent) + if ctx == nil { + return "" + } + return ctx.ParentID +} + +// ExtractTraceContext extracts full W3C trace context from headers +func ExtractTraceContext(header *fasthttp.RequestHeader) *W3CTraceContext { + traceparent := string(header.Peek(TraceParentHeader)) + if traceparent == "" { + return nil + } + + ctx := ParseTraceparent(traceparent) + if ctx == nil { + return nil + } + + // Also extract tracestate if present + ctx.TraceState = string(header.Peek(TraceStateHeader)) + + return ctx +} + +// ParseTraceparent parses a W3C traceparent header value +// Format: version-traceid-parentid-traceflags +// Example: 00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01 +func ParseTraceparent(traceparent string) *W3CTraceContext { + parts := strings.Split(traceparent, "-") + if len(parts) != 4 { + return nil + } + + version := parts[0] + traceID := parts[1] + parentID := parts[2] + traceFlags := parts[3] + + // Validate version (only 00 is currently supported) + if version != "00" { + return nil + } + + // Validate trace ID (32 hex characters) + if len(traceID) != 32 || !isHex(traceID) { + return nil + } + + // Validate parent ID (16 hex characters) + if len(parentID) != 16 || !isHex(parentID) { + return nil + } + + // Validate trace flags (2 hex characters) + if len(traceFlags) != 2 || !isHex(traceFlags) { + return nil + } + + return &W3CTraceContext{ + TraceID: traceID, + ParentID: parentID, + TraceFlags: traceFlags, + } +} + +// FormatTraceparent formats a W3C traceparent header value. +// It normalizes trace ID and span ID to W3C-compliant format: +// - trace ID: 32 lowercase hex characters +// - span ID: 16 lowercase hex characters +// Returns empty string if IDs cannot be normalized to valid format. +func FormatTraceparent(traceID, spanID, traceFlags string) string { + normalizedTraceID := normalizeTraceID(traceID) + normalizedSpanID := normalizeSpanID(spanID) + + if normalizedTraceID == "" || normalizedSpanID == "" { + return "" + } + + // Normalize and validate traceFlags + traceFlags = strings.ToLower(traceFlags) + if len(traceFlags) != 2 || !isHex(traceFlags) { + traceFlags = "00" // Default: not sampled + } + + return "00-" + normalizedTraceID + "-" + normalizedSpanID + "-" + traceFlags +} + +// InjectTraceContext injects W3C trace context headers into outgoing request +func InjectTraceContext(header *fasthttp.RequestHeader, traceID, spanID, traceFlags, traceState string) { + if traceID == "" || spanID == "" { + return + } + + traceparent := FormatTraceparent(traceID, spanID, traceFlags) + if traceparent == "" { + return // IDs could not be normalized to valid W3C format + } + header.Set(TraceParentHeader, traceparent) + + if traceState != "" { + header.Set(TraceStateHeader, traceState) + } +} + +// isHex checks if a string contains only hexadecimal characters +func isHex(s string) bool { + for _, c := range s { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) { + return false + } + } + return true +} diff --git a/framework/tracing/propagation_test.go b/framework/tracing/propagation_test.go new file mode 100644 index 0000000000..af2eaa756c --- /dev/null +++ b/framework/tracing/propagation_test.go @@ -0,0 +1,356 @@ +package tracing + +import ( + "testing" + + "github.com/valyala/fasthttp" +) + +func TestParseTraceparent_ValidHeader(t *testing.T) { + // Example from W3C spec and the user's actual Datadog headers + tests := []struct { + name string + traceparent string + wantTraceID string + wantParent string + wantFlags string + }{ + { + name: "valid traceparent from Datadog", + traceparent: "00-69538b980000000079943934f90c1d40-aad09d1659b4c7e3-01", + wantTraceID: "69538b980000000079943934f90c1d40", + wantParent: "aad09d1659b4c7e3", + wantFlags: "01", + }, + { + name: "valid traceparent with sampled flag", + traceparent: "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01", + wantTraceID: "0af7651916cd43dd8448eb211c80319c", + wantParent: "b7ad6b7169203331", + wantFlags: "01", + }, + { + name: "valid traceparent not sampled", + traceparent: "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-00", + wantTraceID: "0af7651916cd43dd8448eb211c80319c", + wantParent: "b7ad6b7169203331", + wantFlags: "00", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := ParseTraceparent(tt.traceparent) + if ctx == nil { + t.Fatalf("ParseTraceparent() returned nil for valid header") + } + if ctx.TraceID != tt.wantTraceID { + t.Errorf("TraceID = %q, want %q", ctx.TraceID, tt.wantTraceID) + } + if ctx.ParentID != tt.wantParent { + t.Errorf("ParentID = %q, want %q", ctx.ParentID, tt.wantParent) + } + if ctx.TraceFlags != tt.wantFlags { + t.Errorf("TraceFlags = %q, want %q", ctx.TraceFlags, tt.wantFlags) + } + }) + } +} + +func TestParseTraceparent_InvalidVersion(t *testing.T) { + // Only version 00 is supported + tests := []struct { + name string + traceparent string + }{ + { + name: "version 01", + traceparent: "01-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01", + }, + { + name: "version ff", + traceparent: "ff-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := ParseTraceparent(tt.traceparent) + if ctx != nil { + t.Errorf("ParseTraceparent() should return nil for unsupported version") + } + }) + } +} + +func TestParseTraceparent_InvalidTraceID(t *testing.T) { + tests := []struct { + name string + traceparent string + }{ + { + name: "trace ID too short", + traceparent: "00-0af7651916cd43dd8448eb211c8031-b7ad6b7169203331-01", + }, + { + name: "trace ID too long", + traceparent: "00-0af7651916cd43dd8448eb211c80319c00-b7ad6b7169203331-01", + }, + { + name: "trace ID with invalid chars", + traceparent: "00-0af7651916cd43dd8448eb211c80319z-b7ad6b7169203331-01", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := ParseTraceparent(tt.traceparent) + if ctx != nil { + t.Errorf("ParseTraceparent() should return nil for invalid trace ID") + } + }) + } +} + +func TestParseTraceparent_InvalidParentID(t *testing.T) { + tests := []struct { + name string + traceparent string + }{ + { + name: "parent ID too short", + traceparent: "00-0af7651916cd43dd8448eb211c80319c-b7ad6b71692033-01", + }, + { + name: "parent ID too long", + traceparent: "00-0af7651916cd43dd8448eb211c80319c-b7ad6b716920333100-01", + }, + { + name: "parent ID with invalid chars", + traceparent: "00-0af7651916cd43dd8448eb211c80319c-b7ad6b716920333z-01", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := ParseTraceparent(tt.traceparent) + if ctx != nil { + t.Errorf("ParseTraceparent() should return nil for invalid parent ID") + } + }) + } +} + +func TestParseTraceparent_MalformedHeader(t *testing.T) { + tests := []struct { + name string + traceparent string + }{ + { + name: "empty string", + traceparent: "", + }, + { + name: "missing parts", + traceparent: "00-0af7651916cd43dd8448eb211c80319c", + }, + { + name: "too many parts", + traceparent: "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01-extra", + }, + { + name: "wrong delimiter", + traceparent: "00_0af7651916cd43dd8448eb211c80319c_b7ad6b7169203331_01", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := ParseTraceparent(tt.traceparent) + if ctx != nil { + t.Errorf("ParseTraceparent() should return nil for malformed header") + } + }) + } +} + +func TestExtractParentID_ReturnsTraceID(t *testing.T) { + header := &fasthttp.RequestHeader{} + header.Set(TraceParentHeader, "00-69538b980000000079943934f90c1d40-aad09d1659b4c7e3-01") + + traceID := ExtractParentID(header) + if traceID != "69538b980000000079943934f90c1d40" { + t.Errorf("ExtractParentID() = %q, want %q", traceID, "69538b980000000079943934f90c1d40") + } +} + +func TestExtractParentID_EmptyHeader(t *testing.T) { + header := &fasthttp.RequestHeader{} + + traceID := ExtractParentID(header) + if traceID != "" { + t.Errorf("ExtractParentID() = %q, want empty string", traceID) + } +} + +func TestExtractTraceParentSpanID_ReturnsParentSpanID(t *testing.T) { + header := &fasthttp.RequestHeader{} + header.Set(TraceParentHeader, "00-69538b980000000079943934f90c1d40-aad09d1659b4c7e3-01") + + parentSpanID := ExtractTraceParentSpanID(header) + if parentSpanID != "aad09d1659b4c7e3" { + t.Errorf("ExtractTraceParentSpanID() = %q, want %q", parentSpanID, "aad09d1659b4c7e3") + } +} + +func TestExtractTraceParentSpanID_EmptyHeader(t *testing.T) { + header := &fasthttp.RequestHeader{} + + parentSpanID := ExtractTraceParentSpanID(header) + if parentSpanID != "" { + t.Errorf("ExtractTraceParentSpanID() = %q, want empty string", parentSpanID) + } +} + +func TestExtractTraceParentSpanID_InvalidHeader(t *testing.T) { + header := &fasthttp.RequestHeader{} + header.Set(TraceParentHeader, "invalid-header") + + parentSpanID := ExtractTraceParentSpanID(header) + if parentSpanID != "" { + t.Errorf("ExtractTraceParentSpanID() = %q, want empty string for invalid header", parentSpanID) + } +} + +func TestFormatTraceparent_NormalizesIDs(t *testing.T) { + tests := []struct { + name string + traceID string + spanID string + traceFlags string + want string + }{ + { + name: "already normalized", + traceID: "69538b980000000079943934f90c1d40", + spanID: "aad09d1659b4c7e3", + traceFlags: "01", + want: "00-69538b980000000079943934f90c1d40-aad09d1659b4c7e3-01", + }, + { + name: "uppercase to lowercase", + traceID: "69538B980000000079943934F90C1D40", + spanID: "AAD09D1659B4C7E3", + traceFlags: "01", + want: "00-69538b980000000079943934f90c1d40-aad09d1659b4c7e3-01", + }, + { + name: "UUID format trace ID", + traceID: "69538b98-0000-0000-7994-3934f90c1d40", + spanID: "aad09d1659b4c7e3", + traceFlags: "01", + want: "00-69538b980000000079943934f90c1d40-aad09d1659b4c7e3-01", + }, + { + name: "default trace flags when invalid", + traceID: "69538b980000000079943934f90c1d40", + spanID: "aad09d1659b4c7e3", + traceFlags: "xyz", + want: "00-69538b980000000079943934f90c1d40-aad09d1659b4c7e3-00", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FormatTraceparent(tt.traceID, tt.spanID, tt.traceFlags) + if got != tt.want { + t.Errorf("FormatTraceparent() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestFormatTraceparent_InvalidIDs(t *testing.T) { + tests := []struct { + name string + traceID string + spanID string + }{ + { + name: "empty trace ID", + traceID: "", + spanID: "aad09d1659b4c7e3", + }, + { + name: "empty span ID", + traceID: "69538b980000000079943934f90c1d40", + spanID: "", + }, + { + name: "invalid trace ID length", + traceID: "69538b98", + spanID: "aad09d1659b4c7e3", + }, + { + name: "invalid span ID length", + traceID: "69538b980000000079943934f90c1d40", + spanID: "aad09d16", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FormatTraceparent(tt.traceID, tt.spanID, "01") + if got != "" { + t.Errorf("FormatTraceparent() = %q, want empty string for invalid IDs", got) + } + }) + } +} + +func TestExtractTraceContext_WithTraceState(t *testing.T) { + header := &fasthttp.RequestHeader{} + header.Set(TraceParentHeader, "00-69538b980000000079943934f90c1d40-aad09d1659b4c7e3-01") + header.Set(TraceStateHeader, "dd=p:aad09d1659b4c7e3;s:1;t.dm:-1;t.tid:69538b9800000000") + + ctx := ExtractTraceContext(header) + if ctx == nil { + t.Fatal("ExtractTraceContext() returned nil") + } + if ctx.TraceID != "69538b980000000079943934f90c1d40" { + t.Errorf("TraceID = %q, want %q", ctx.TraceID, "69538b980000000079943934f90c1d40") + } + if ctx.ParentID != "aad09d1659b4c7e3" { + t.Errorf("ParentID = %q, want %q", ctx.ParentID, "aad09d1659b4c7e3") + } + if ctx.TraceState != "dd=p:aad09d1659b4c7e3;s:1;t.dm:-1;t.tid:69538b9800000000" { + t.Errorf("TraceState = %q, want Datadog tracestate", ctx.TraceState) + } +} + +func TestInjectTraceContext(t *testing.T) { + header := &fasthttp.RequestHeader{} + + InjectTraceContext(header, "69538b980000000079943934f90c1d40", "aad09d1659b4c7e3", "01", "dd=s:1") + + traceparent := string(header.Peek(TraceParentHeader)) + if traceparent != "00-69538b980000000079943934f90c1d40-aad09d1659b4c7e3-01" { + t.Errorf("traceparent = %q, want formatted header", traceparent) + } + + tracestate := string(header.Peek(TraceStateHeader)) + if tracestate != "dd=s:1" { + t.Errorf("tracestate = %q, want %q", tracestate, "dd=s:1") + } +} + +func TestInjectTraceContext_EmptyIDs(t *testing.T) { + header := &fasthttp.RequestHeader{} + + InjectTraceContext(header, "", "aad09d1659b4c7e3", "01", "") + + traceparent := string(header.Peek(TraceParentHeader)) + if traceparent != "" { + t.Errorf("traceparent should not be set for empty trace ID") + } +} diff --git a/framework/tracing/store.go b/framework/tracing/store.go new file mode 100644 index 0000000000..e815dd69d3 --- /dev/null +++ b/framework/tracing/store.go @@ -0,0 +1,403 @@ +// Package tracing provides distributed tracing infrastructure for Bifrost +package tracing + +import ( + "encoding/hex" + "sync" + "time" + + "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" +) + +// DeferredSpanInfo stores information about a deferred span for streaming requests +type DeferredSpanInfo struct { + SpanID string + StartTime time.Time + Tracer schemas.Tracer // Reference to tracer for completing the span + RequestID string // Request ID for accumulator lookup + FirstChunkTime time.Time // Timestamp of first chunk (for TTFT calculation) + AccumulatedChunks []*schemas.BifrostResponse // Accumulated streaming chunks + mu sync.Mutex // Mutex for thread-safe chunk accumulation +} + +// TraceStore manages traces with thread-safe access and object pooling +type TraceStore struct { + traces sync.Map // map[traceID]*schemas.Trace - thread-safe concurrent access + deferredSpans sync.Map // map[traceID]*DeferredSpanInfo - deferred spans for streaming requests + tracePool sync.Pool // Reuse Trace objects to reduce allocations + spanPool sync.Pool // Reuse Span objects to reduce allocations + logger schemas.Logger + + ttl time.Duration + cleanupTicker *time.Ticker + stopCleanup chan struct{} + cleanupWg sync.WaitGroup + stopOnce sync.Once // Ensures Stop() cleanup runs only once +} + +// NewTraceStore creates a new TraceStore with the given TTL for cleanup +func NewTraceStore(ttl time.Duration, logger schemas.Logger) *TraceStore { + store := &TraceStore{ + ttl: ttl, + logger: logger, + tracePool: sync.Pool{ + New: func() any { + return &schemas.Trace{ + Spans: make([]*schemas.Span, 0, 16), // Pre-allocate capacity + Attributes: make(map[string]any), + } + }, + }, + spanPool: sync.Pool{ + New: func() any { + return &schemas.Span{ + Attributes: make(map[string]any), + Events: make([]schemas.SpanEvent, 0, 4), // Pre-allocate capacity + } + }, + }, + stopCleanup: make(chan struct{}), + } + + // Start background cleanup goroutine + store.startCleanup() + + return store +} + +// CreateTrace creates a new trace and stores it, returns trace ID only. +// The inheritedTraceID parameter is the trace ID from an incoming W3C traceparent header. +// If provided, this trace will use that ID to continue the distributed trace. +// If empty, a new trace ID will be generated. +// Note: The parent span ID (for linking to upstream spans) is handled separately +// via context in StartSpan, not stored on the trace itself. +func (s *TraceStore) CreateTrace(inheritedTraceID string) string { + trace := s.tracePool.Get().(*schemas.Trace) + // Reset and initialize the trace + if inheritedTraceID != "" { + trace.TraceID = inheritedTraceID + } else { + trace.TraceID = generateTraceID() + } + // Note: trace.ParentID is intentionally not set here. + // Parent-child relationships are between spans, not traces. + // The root span's ParentID is set in StartSpan from context. + trace.ParentID = "" + trace.StartTime = time.Now() + trace.EndTime = time.Time{} + trace.RootSpan = nil + + // Reset slices but keep capacity + if trace.Spans != nil { + trace.Spans = trace.Spans[:0] + } else { + trace.Spans = make([]*schemas.Span, 0, 16) + } + + // Reset attributes + if trace.Attributes == nil { + trace.Attributes = make(map[string]any) + } else { + clear(trace.Attributes) + } + + s.traces.Store(trace.TraceID, trace) + return trace.TraceID +} + +// GetTrace retrieves a trace by ID +func (s *TraceStore) GetTrace(traceID string) *schemas.Trace { + if val, ok := s.traces.Load(traceID); ok { + return val.(*schemas.Trace) + } + return nil +} + +// CompleteTrace marks the trace as complete, removes it from store, and returns it for flushing +func (s *TraceStore) CompleteTrace(traceID string) *schemas.Trace { + // Clear any deferred span for this trace + s.deferredSpans.Delete(traceID) + + if val, ok := s.traces.LoadAndDelete(traceID); ok { + trace := val.(*schemas.Trace) + trace.EndTime = time.Now() + return trace + } + return nil +} + +// StoreDeferredSpan stores a span ID for later completion (used for streaming requests) +func (s *TraceStore) StoreDeferredSpan(traceID, spanID string) { + s.deferredSpans.Store(traceID, &DeferredSpanInfo{ + SpanID: spanID, + StartTime: time.Now(), + }) +} + +// GetDeferredSpan retrieves the deferred span info for a trace ID +func (s *TraceStore) GetDeferredSpan(traceID string) *DeferredSpanInfo { + if val, ok := s.deferredSpans.Load(traceID); ok { + return val.(*DeferredSpanInfo) + } + return nil +} + +// ClearDeferredSpan removes the deferred span info for a trace ID +func (s *TraceStore) ClearDeferredSpan(traceID string) { + s.deferredSpans.Delete(traceID) +} + +// AppendStreamingChunk adds a streaming chunk to the deferred span's accumulated data +func (s *TraceStore) AppendStreamingChunk(traceID string, chunk *schemas.BifrostResponse) { + if chunk == nil { + return + } + info := s.GetDeferredSpan(traceID) + if info == nil { + return + } + info.mu.Lock() + defer info.mu.Unlock() + + // Track first chunk time for TTFT calculation + if info.FirstChunkTime.IsZero() { + info.FirstChunkTime = time.Now() + } + + // Append chunk to accumulated list + info.AccumulatedChunks = append(info.AccumulatedChunks, chunk) +} + +// GetAccumulatedData returns the accumulated chunks and TTFT for a deferred span +func (s *TraceStore) GetAccumulatedData(traceID string) ([]*schemas.BifrostResponse, int64) { + info := s.GetDeferredSpan(traceID) + if info == nil { + return nil, 0 + } + info.mu.Lock() + defer info.mu.Unlock() + + // Calculate TTFT in milliseconds + var ttftMs int64 + if !info.StartTime.IsZero() && !info.FirstChunkTime.IsZero() { + ttftMs = info.FirstChunkTime.Sub(info.StartTime).Milliseconds() + } + + return info.AccumulatedChunks, ttftMs +} + +// ReleaseTrace returns the trace and its spans to the pools for reuse +func (s *TraceStore) ReleaseTrace(trace *schemas.Trace) { + if trace == nil { + return + } + + // Return all spans to the pool + for _, span := range trace.Spans { + s.releaseSpan(span) + } + + // Reset the trace + trace.Reset() + + // Return trace to pool + s.tracePool.Put(trace) +} + +// StartSpan creates a new span and adds it to the trace +func (s *TraceStore) StartSpan(traceID, name string, kind schemas.SpanKind) *schemas.Span { + trace := s.GetTrace(traceID) + if trace == nil { + return nil + } + + span := s.spanPool.Get().(*schemas.Span) + + // Reset and initialize the span + span.SpanID = generateSpanID() + span.TraceID = traceID + span.Name = name + span.Kind = kind + span.StartTime = time.Now() + span.EndTime = time.Time{} + span.Status = schemas.SpanStatusUnset + span.StatusMsg = "" + + // Reset slices but keep capacity + if span.Events != nil { + span.Events = span.Events[:0] + } else { + span.Events = make([]schemas.SpanEvent, 0, 4) + } + + // Reset attributes + if span.Attributes == nil { + span.Attributes = make(map[string]any) + } else { + clear(span.Attributes) + } + + // Set parent ID to root span if it exists, otherwise this is root + if trace.RootSpan != nil { + span.ParentID = trace.RootSpan.SpanID + } else { + span.ParentID = "" + trace.RootSpan = span + } + + // Add span to trace + trace.AddSpan(span) + + return span +} + +// StartChildSpan creates a new span as a child of the specified parent span +func (s *TraceStore) StartChildSpan(traceID, parentSpanID, name string, kind schemas.SpanKind) *schemas.Span { + trace := s.GetTrace(traceID) + if trace == nil { + return nil + } + + span := s.spanPool.Get().(*schemas.Span) + + // Reset and initialize the span + span.SpanID = generateSpanID() + span.ParentID = parentSpanID + span.TraceID = traceID + span.Name = name + span.Kind = kind + span.StartTime = time.Now() + span.EndTime = time.Time{} + span.Status = schemas.SpanStatusUnset + span.StatusMsg = "" + + // Reset slices but keep capacity + if span.Events != nil { + span.Events = span.Events[:0] + } else { + span.Events = make([]schemas.SpanEvent, 0, 4) + } + + // Reset attributes + if span.Attributes == nil { + span.Attributes = make(map[string]any) + } else { + clear(span.Attributes) + } + + // Set as root span if this is the first span in the trace. + // This can happen when the span has an external parent (from W3C traceparent) + // but is the first span within this service's trace. + if trace.RootSpan == nil { + trace.RootSpan = span + } + + // Add span to trace + trace.AddSpan(span) + + return span +} + +// EndSpan marks a span as complete with the given status and attributes +func (s *TraceStore) EndSpan(traceID, spanID string, status schemas.SpanStatus, statusMsg string, attrs map[string]any) { + trace := s.GetTrace(traceID) + if trace == nil { + return + } + + span := trace.GetSpan(spanID) + if span == nil { + return + } + + span.End(status, statusMsg) + + // Add any final attributes + for k, v := range attrs { + span.SetAttribute(k, v) + } +} + +// releaseSpan returns a span to the pool +func (s *TraceStore) releaseSpan(span *schemas.Span) { + if span == nil { + return + } + span.Reset() + s.spanPool.Put(span) +} + +// startCleanup starts the background cleanup goroutine +func (s *TraceStore) startCleanup() { + if s.ttl <= 0 { + return + } + + // Cleanup interval is TTL / 2 + cleanupInterval := s.ttl / 2 + if cleanupInterval < time.Minute { + cleanupInterval = time.Minute + } + + s.cleanupTicker = time.NewTicker(cleanupInterval) + s.cleanupWg.Add(1) + + go func() { + defer s.cleanupWg.Done() + for { + select { + case <-s.cleanupTicker.C: + s.cleanupOldTraces() + case <-s.stopCleanup: + return + } + } + }() +} + +// cleanupOldTraces removes traces that have exceeded the TTL +func (s *TraceStore) cleanupOldTraces() { + cutoff := time.Now().Add(-s.ttl) + count := 0 + + s.traces.Range(func(key, value any) bool { + trace := value.(*schemas.Trace) + if trace.StartTime.Before(cutoff) { + if deleted, ok := s.traces.LoadAndDelete(key); ok { + s.ReleaseTrace(deleted.(*schemas.Trace)) + count++ + } + } + return true + }) + + if count > 0 && s.logger != nil { + s.logger.Debug("tracing: cleaned up %d orphaned traces", count) + } +} + +// Stop stops the cleanup goroutine and releases resources +func (s *TraceStore) Stop() { + s.stopOnce.Do(func() { + if s.cleanupTicker != nil { + s.cleanupTicker.Stop() + } + close(s.stopCleanup) + s.cleanupWg.Wait() + }) +} + +// generateTraceID generates a W3C-compliant trace ID. +// Returns 32 lowercase hex characters (128-bit UUID without hyphens). +func generateTraceID() string { + u := uuid.New() + return hex.EncodeToString(u[:]) +} + +// generateSpanID generates a W3C-compliant span ID. +// Returns 16 lowercase hex characters (first 64 bits of a UUID). +func generateSpanID() string { + u := uuid.New() + return hex.EncodeToString(u[:8]) +} diff --git a/framework/tracing/store_test.go b/framework/tracing/store_test.go new file mode 100644 index 0000000000..6548d12446 --- /dev/null +++ b/framework/tracing/store_test.go @@ -0,0 +1,252 @@ +package tracing + +import ( + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestCreateTrace_WithInheritedTraceID(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + // Use a trace ID from an incoming W3C traceparent header + inheritedTraceID := "69538b980000000079943934f90c1d40" + + traceID := store.CreateTrace(inheritedTraceID) + + if traceID != inheritedTraceID { + t.Errorf("CreateTrace() returned %q, want inherited trace ID %q", traceID, inheritedTraceID) + } + + trace := store.GetTrace(traceID) + if trace == nil { + t.Fatal("GetTrace() returned nil") + } + + if trace.TraceID != inheritedTraceID { + t.Errorf("trace.TraceID = %q, want %q", trace.TraceID, inheritedTraceID) + } + + // ParentID should be empty - we no longer set it incorrectly to the trace ID + if trace.ParentID != "" { + t.Errorf("trace.ParentID = %q, want empty string (parent span ID is set on spans, not traces)", trace.ParentID) + } +} + +func TestCreateTrace_GeneratesNewTraceID(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + traceID := store.CreateTrace("") + + if traceID == "" { + t.Error("CreateTrace() returned empty trace ID") + } + + // Generated trace ID should be 32 hex characters + if len(traceID) != 32 { + t.Errorf("Generated trace ID length = %d, want 32", len(traceID)) + } + + // Verify it's valid hex + if !isHex(traceID) { + t.Errorf("Generated trace ID %q is not valid hex", traceID) + } + + trace := store.GetTrace(traceID) + if trace == nil { + t.Fatal("GetTrace() returned nil") + } + + if trace.ParentID != "" { + t.Errorf("trace.ParentID = %q, want empty string", trace.ParentID) + } +} + +func TestStartSpan_RootSpanHasNoParent(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + traceID := store.CreateTrace("") + + span := store.StartSpan(traceID, "root-operation", schemas.SpanKindHTTPRequest) + if span == nil { + t.Fatal("StartSpan() returned nil") + } + + // Root span should have no parent when there's no incoming trace context + if span.ParentID != "" { + t.Errorf("root span.ParentID = %q, want empty string", span.ParentID) + } + + if span.TraceID != traceID { + t.Errorf("span.TraceID = %q, want %q", span.TraceID, traceID) + } + + // Verify it's set as root span + trace := store.GetTrace(traceID) + if trace.RootSpan != span { + t.Error("StartSpan() did not set trace.RootSpan") + } +} + +func TestStartSpan_SecondSpanHasRootAsParent(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + traceID := store.CreateTrace("") + + rootSpan := store.StartSpan(traceID, "root-operation", schemas.SpanKindHTTPRequest) + if rootSpan == nil { + t.Fatal("StartSpan() returned nil for root span") + } + + // Second span created with StartSpan should have root as parent + secondSpan := store.StartSpan(traceID, "second-operation", schemas.SpanKindLLMCall) + if secondSpan == nil { + t.Fatal("StartSpan() returned nil for second span") + } + + if secondSpan.ParentID != rootSpan.SpanID { + t.Errorf("second span.ParentID = %q, want root span ID %q", secondSpan.ParentID, rootSpan.SpanID) + } +} + +func TestStartChildSpan_HasCorrectParent(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + traceID := store.CreateTrace("") + + rootSpan := store.StartSpan(traceID, "root-operation", schemas.SpanKindHTTPRequest) + if rootSpan == nil { + t.Fatal("StartSpan() returned nil for root span") + } + + // Create a child span with explicit parent + childSpan := store.StartChildSpan(traceID, rootSpan.SpanID, "child-operation", schemas.SpanKindLLMCall) + if childSpan == nil { + t.Fatal("StartChildSpan() returned nil") + } + + if childSpan.ParentID != rootSpan.SpanID { + t.Errorf("child span.ParentID = %q, want %q", childSpan.ParentID, rootSpan.SpanID) + } + + if childSpan.TraceID != traceID { + t.Errorf("child span.TraceID = %q, want %q", childSpan.TraceID, traceID) + } +} + +func TestStartChildSpan_WithExternalParentSpanID(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + // Simulating an incoming request with W3C traceparent header + inheritedTraceID := "69538b980000000079943934f90c1d40" + externalParentSpanID := "aad09d1659b4c7e3" // Parent span ID from upstream service + + traceID := store.CreateTrace(inheritedTraceID) + + // Create root span as child of external parent span + // This is what should happen when processing an incoming distributed trace + rootSpan := store.StartChildSpan(traceID, externalParentSpanID, "bifrost-request", schemas.SpanKindHTTPRequest) + if rootSpan == nil { + t.Fatal("StartChildSpan() returned nil") + } + + // Root span should have the external parent span ID + if rootSpan.ParentID != externalParentSpanID { + t.Errorf("root span.ParentID = %q, want external parent %q", rootSpan.ParentID, externalParentSpanID) + } + + if rootSpan.TraceID != inheritedTraceID { + t.Errorf("root span.TraceID = %q, want inherited trace ID %q", rootSpan.TraceID, inheritedTraceID) + } +} + +func TestGetTrace_NotFound(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + trace := store.GetTrace("nonexistent-trace-id") + if trace != nil { + t.Error("GetTrace() should return nil for nonexistent trace") + } +} + +func TestCompleteTrace_ReturnsAndRemoves(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + traceID := store.CreateTrace("") + store.StartSpan(traceID, "operation", schemas.SpanKindHTTPRequest) + + trace := store.CompleteTrace(traceID) + if trace == nil { + t.Fatal("CompleteTrace() returned nil") + } + + if trace.TraceID != traceID { + t.Errorf("trace.TraceID = %q, want %q", trace.TraceID, traceID) + } + + if trace.EndTime.IsZero() { + t.Error("trace.EndTime should be set") + } + + // Trace should be removed from store + if store.GetTrace(traceID) != nil { + t.Error("Trace should be removed from store after CompleteTrace()") + } +} + +func TestEndSpan_SetsStatusAndTime(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + traceID := store.CreateTrace("") + span := store.StartSpan(traceID, "operation", schemas.SpanKindHTTPRequest) + + store.EndSpan(traceID, span.SpanID, schemas.SpanStatusOk, "success", map[string]any{ + "custom.attr": "value", + }) + + if span.Status != schemas.SpanStatusOk { + t.Errorf("span.Status = %v, want SpanStatusOk", span.Status) + } + + if span.EndTime.IsZero() { + t.Error("span.EndTime should be set") + } + + if span.Attributes["custom.attr"] != "value" { + t.Error("EndSpan() should set custom attributes") + } +} + +func TestGenerateTraceID_Format(t *testing.T) { + id := generateTraceID() + + if len(id) != 32 { + t.Errorf("generateTraceID() length = %d, want 32", len(id)) + } + + if !isHex(id) { + t.Errorf("generateTraceID() = %q, not valid hex", id) + } +} + +func TestGenerateSpanID_Format(t *testing.T) { + id := generateSpanID() + + if len(id) != 16 { + t.Errorf("generateSpanID() length = %d, want 16", len(id)) + } + + if !isHex(id) { + t.Errorf("generateSpanID() = %q, not valid hex", id) + } +} diff --git a/framework/tracing/tracer.go b/framework/tracing/tracer.go new file mode 100644 index 0000000000..d12072b634 --- /dev/null +++ b/framework/tracing/tracer.go @@ -0,0 +1,564 @@ +// Package tracing provides distributed tracing infrastructure for Bifrost +package tracing + +import ( + "context" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/modelcatalog" + "github.com/maximhq/bifrost/framework/streaming" +) + +// Tracer implements schemas.Tracer using TraceStore. +// It provides the bridge between the core Tracer interface and the +// framework's TraceStore implementation. +// It also embeds a streaming.Accumulator for centralized streaming chunk accumulation. +type Tracer struct { + store *TraceStore + accumulator *streaming.Accumulator +} + +// NewTracer creates a new Tracer wrapping the given TraceStore. +// The accumulator is embedded for centralized streaming chunk accumulation. +func NewTracer(store *TraceStore, pricingManager *modelcatalog.ModelCatalog, logger schemas.Logger) *Tracer { + return &Tracer{ + store: store, + accumulator: streaming.NewAccumulator(pricingManager, logger), + } +} + +// CreateTrace creates a new trace with optional parent ID and returns the trace ID. +func (t *Tracer) CreateTrace(parentID string) string { + return t.store.CreateTrace(parentID) +} + +// EndTrace completes a trace and returns the trace data for observation/export. +// The returned trace should be released after use by calling ReleaseTrace. +func (t *Tracer) EndTrace(traceID string) *schemas.Trace { + trace := t.store.CompleteTrace(traceID) + if trace == nil { + return nil + } + // Note: Caller is responsible for releasing the trace after plugin processing + // by calling ReleaseTrace on the store or letting GC handle it + return trace +} + +// ReleaseTrace returns the trace to the pool for reuse. +// Should be called after EndTrace when the trace data is no longer needed. +func (t *Tracer) ReleaseTrace(trace *schemas.Trace) { + t.store.ReleaseTrace(trace) +} + +// spanHandle is the concrete implementation of schemas.SpanHandle for Tracer. +// It contains the trace and span IDs needed to reference the span in the store. +type spanHandle struct { + traceID string + spanID string +} + +// StartSpan creates a new span as a child of the current span in context. +// It reads the trace ID and parent span ID from context, creates the span, +// and returns an updated context with the new span ID. +// +// Parent span resolution order: +// 1. BifrostContextKeySpanID - existing span in this service (for child spans) +// 2. BifrostContextKeyParentSpanID - incoming parent from W3C traceparent (for root spans) +// 3. No parent - creates a root span with no parent +func (t *Tracer) StartSpan(ctx context.Context, name string, kind schemas.SpanKind) (context.Context, schemas.SpanHandle) { + traceID := GetTraceID(ctx) + if traceID == "" { + return ctx, nil + } + + // Get parent span ID from context - first check for existing span in this service + parentSpanID, _ := ctx.Value(schemas.BifrostContextKeySpanID).(string) + + // If no existing span, check for incoming parent span ID from W3C traceparent header + // This links the root span of this service to the upstream service's span + if parentSpanID == "" { + parentSpanID, _ = ctx.Value(schemas.BifrostContextKeyParentSpanID).(string) + } + + var span *schemas.Span + if parentSpanID != "" { + span = t.store.StartChildSpan(traceID, parentSpanID, name, kind) + } else { + span = t.store.StartSpan(traceID, name, kind) + } + if span == nil { + return ctx, nil + } + // Update context with new span ID + newCtx := context.WithValue(ctx, schemas.BifrostContextKeySpanID, span.SpanID) + return newCtx, &spanHandle{traceID: traceID, spanID: span.SpanID} +} + +// EndSpan completes a span with the given status and message. +func (t *Tracer) EndSpan(handle schemas.SpanHandle, status schemas.SpanStatus, statusMsg string) { + h, ok := handle.(*spanHandle) + if !ok || h == nil { + return + } + t.store.EndSpan(h.traceID, h.spanID, status, statusMsg, nil) +} + +// SetAttribute sets an attribute on the span identified by the handle. +func (t *Tracer) SetAttribute(handle schemas.SpanHandle, key string, value any) { + h, ok := handle.(*spanHandle) + if !ok || h == nil { + return + } + trace := t.store.GetTrace(h.traceID) + if trace == nil { + return + } + span := trace.GetSpan(h.spanID) + if span != nil { + span.SetAttribute(key, value) + } +} + +// AddEvent adds a timestamped event to the span identified by the handle. +func (t *Tracer) AddEvent(handle schemas.SpanHandle, name string, attrs map[string]any) { + h, ok := handle.(*spanHandle) + if !ok || h == nil { + return + } + trace := t.store.GetTrace(h.traceID) + if trace == nil { + return + } + span := trace.GetSpan(h.spanID) + if span != nil { + span.AddEvent(schemas.SpanEvent{ + Name: name, + Timestamp: time.Now(), + Attributes: attrs, + }) + } +} + +// PopulateLLMRequestAttributes populates all LLM-specific request attributes on the span. +func (t *Tracer) PopulateLLMRequestAttributes(handle schemas.SpanHandle, req *schemas.BifrostRequest) { + h, ok := handle.(*spanHandle) + if !ok || h == nil || req == nil { + return + } + trace := t.store.GetTrace(h.traceID) + if trace == nil { + return + } + span := trace.GetSpan(h.spanID) + if span == nil { + return + } + + for k, v := range PopulateRequestAttributes(req) { + span.SetAttribute(k, v) + } +} + +// PopulateLLMResponseAttributes populates all LLM-specific response attributes on the span. +func (t *Tracer) PopulateLLMResponseAttributes(handle schemas.SpanHandle, resp *schemas.BifrostResponse, err *schemas.BifrostError) { + h, ok := handle.(*spanHandle) + if !ok || h == nil { + return + } + trace := t.store.GetTrace(h.traceID) + if trace == nil { + return + } + span := trace.GetSpan(h.spanID) + if span == nil { + return + } + for k, v := range PopulateResponseAttributes(resp) { + span.SetAttribute(k, v) + } + for k, v := range PopulateErrorAttributes(err) { + span.SetAttribute(k, v) + } +} + +// StoreDeferredSpan stores a span handle for later completion (used for streaming requests). +// The span handle is stored keyed by trace ID so it can be retrieved when the stream completes. +func (t *Tracer) StoreDeferredSpan(traceID string, handle schemas.SpanHandle) { + h, ok := handle.(*spanHandle) + if !ok || h == nil { + return + } + t.store.StoreDeferredSpan(traceID, h.spanID) +} + +// GetDeferredSpanHandle retrieves a deferred span handle by trace ID. +// Returns nil if no deferred span exists for the given trace ID. +func (t *Tracer) GetDeferredSpanHandle(traceID string) schemas.SpanHandle { + info := t.store.GetDeferredSpan(traceID) + if info == nil { + return nil + } + return &spanHandle{traceID: traceID, spanID: info.SpanID} +} + +// ClearDeferredSpan removes the deferred span handle for a trace ID. +// Should be called after the deferred span has been completed. +func (t *Tracer) ClearDeferredSpan(traceID string) { + t.store.ClearDeferredSpan(traceID) +} + +// GetDeferredSpanID returns the span ID for the deferred span. +// Returns empty string if no deferred span exists. +func (t *Tracer) GetDeferredSpanID(traceID string) string { + info := t.store.GetDeferredSpan(traceID) + if info == nil { + return "" + } + return info.SpanID +} + +// AddStreamingChunk accumulates a streaming chunk for the deferred span. +// This stores the full BifrostResponse chunk for later reconstruction. +// Note: This method still uses the store for backward compatibility with existing code. +// For new code, prefer using ProcessStreamingChunk which uses the embedded accumulator. +func (t *Tracer) AddStreamingChunk(traceID string, response *schemas.BifrostResponse) { + if traceID == "" || response == nil { + return + } + t.store.AppendStreamingChunk(traceID, response) +} + +// GetAccumulatedChunks returns the accumulated BifrostResponse, TTFT, and chunk count for a deferred span. +// It reconstructs a complete response from all accumulated streaming chunks. +// Note: This method still uses the store for backward compatibility with existing code. +// For new code, prefer using ProcessStreamingChunk which uses the embedded accumulator. +func (t *Tracer) GetAccumulatedChunks(traceID string) (*schemas.BifrostResponse, int64, int) { + chunks, ttftMs := t.store.GetAccumulatedData(traceID) + if len(chunks) == 0 { + return nil, 0, 0 + } + + // Build complete response from accumulated chunks + return buildCompleteResponseFromChunks(chunks), ttftMs, len(chunks) +} + +// buildCompleteResponseFromChunks reconstructs a complete BifrostResponse from streaming chunks. +// This accumulates content, tool calls, reasoning, audio, and other fields. +// Note: This is kept for backward compatibility with existing code that uses AddStreamingChunk/GetAccumulatedChunks. +func buildCompleteResponseFromChunks(chunks []*schemas.BifrostResponse) *schemas.BifrostResponse { + if len(chunks) == 0 { + return nil + } + + // Use the last chunk as a base (it typically has final usage stats, finish reason, etc.) + lastChunk := chunks[len(chunks)-1] + if lastChunk.ChatResponse == nil { + return nil + } + + result := &schemas.BifrostResponse{ + ChatResponse: &schemas.BifrostChatResponse{ + ID: lastChunk.ChatResponse.ID, + Object: lastChunk.ChatResponse.Object, + Model: lastChunk.ChatResponse.Model, + Created: lastChunk.ChatResponse.Created, + Usage: lastChunk.ChatResponse.Usage, + ExtraFields: lastChunk.ChatResponse.ExtraFields, + Choices: make([]schemas.BifrostResponseChoice, 0), + }, + } + + // Track accumulated content per choice index + type choiceAccumulator struct { + content string + refusal string + reasoning string + reasoningDetails []schemas.ChatReasoningDetails + toolCalls map[int]schemas.ChatAssistantMessageToolCall // keyed by tool call index + audio *schemas.ChatAudioMessageAudio + role schemas.ChatMessageRole + finishReason *string + } + + choiceMap := make(map[int]*choiceAccumulator) + + // Process chunks in order + for _, chunk := range chunks { + if chunk.ChatResponse == nil { + continue + } + for _, choice := range chunk.ChatResponse.Choices { + if choice.ChatStreamResponseChoice == nil || choice.ChatStreamResponseChoice.Delta == nil { + continue + } + delta := choice.ChatStreamResponseChoice.Delta + idx := choice.Index + + // Get or create accumulator for this choice + acc, ok := choiceMap[idx] + if !ok { + acc = &choiceAccumulator{ + role: schemas.ChatMessageRoleAssistant, + toolCalls: make(map[int]schemas.ChatAssistantMessageToolCall), + } + choiceMap[idx] = acc + } + + // Accumulate content + if delta.Content != nil { + acc.content += *delta.Content + } + + // Role (usually in first chunk) + if delta.Role != nil { + acc.role = schemas.ChatMessageRole(*delta.Role) + } + + // Refusal + if delta.Refusal != nil { + acc.refusal += *delta.Refusal + } + + // Reasoning + if delta.Reasoning != nil { + acc.reasoning += *delta.Reasoning + } + + // Reasoning details (merge by index) + for _, rd := range delta.ReasoningDetails { + found := false + for i := range acc.reasoningDetails { + if acc.reasoningDetails[i].Index == rd.Index { + // Accumulate text + if rd.Text != nil { + if acc.reasoningDetails[i].Text == nil { + acc.reasoningDetails[i].Text = rd.Text + } else { + newText := *acc.reasoningDetails[i].Text + *rd.Text + acc.reasoningDetails[i].Text = &newText + } + } + // Update type if present + if rd.Type != "" { + acc.reasoningDetails[i].Type = rd.Type + } + found = true + break + } + } + if !found { + acc.reasoningDetails = append(acc.reasoningDetails, rd) + } + } + + // Audio + if delta.Audio != nil { + if acc.audio == nil { + acc.audio = &schemas.ChatAudioMessageAudio{ + ID: delta.Audio.ID, + Data: delta.Audio.Data, + ExpiresAt: delta.Audio.ExpiresAt, + Transcript: delta.Audio.Transcript, + } + } else { + acc.audio.Data += delta.Audio.Data + acc.audio.Transcript += delta.Audio.Transcript + if delta.Audio.ID != "" { + acc.audio.ID = delta.Audio.ID + } + if delta.Audio.ExpiresAt != 0 { + acc.audio.ExpiresAt = delta.Audio.ExpiresAt + } + } + } + + // Tool calls (merge by index) + for _, tc := range delta.ToolCalls { + tcIdx := int(tc.Index) + existing, ok := acc.toolCalls[tcIdx] + if !ok { + // New tool call + acc.toolCalls[tcIdx] = tc + } else { + // Merge: accumulate arguments, update other fields + if tc.ID != nil { + existing.ID = tc.ID + } + if tc.Type != nil { + existing.Type = tc.Type + } + if tc.Function.Name != nil { + existing.Function.Name = tc.Function.Name + } + existing.Function.Arguments += tc.Function.Arguments + acc.toolCalls[tcIdx] = existing + } + } + + // Finish reason (from BifrostResponseChoice, not ChatStreamResponseChoice) + if choice.FinishReason != nil { + acc.finishReason = choice.FinishReason + } + } + } + + // Build final choices from accumulated data + // Sort choice indices for deterministic output + choiceIndices := make([]int, 0, len(choiceMap)) + for idx := range choiceMap { + choiceIndices = append(choiceIndices, idx) + } + + for _, idx := range choiceIndices { + accum := choiceMap[idx] + + // Build message + msg := &schemas.ChatMessage{ + Role: accum.role, + } + + // Set content + if accum.content != "" { + msg.Content = &schemas.ChatMessageContent{ + ContentStr: &accum.content, + } + } + + // Build assistant message fields + if accum.refusal != "" || accum.reasoning != "" || len(accum.reasoningDetails) > 0 || + accum.audio != nil || len(accum.toolCalls) > 0 { + msg.ChatAssistantMessage = &schemas.ChatAssistantMessage{} + + if accum.refusal != "" { + msg.ChatAssistantMessage.Refusal = &accum.refusal + } + if accum.reasoning != "" { + msg.ChatAssistantMessage.Reasoning = &accum.reasoning + } + if len(accum.reasoningDetails) > 0 { + msg.ChatAssistantMessage.ReasoningDetails = accum.reasoningDetails + } + if accum.audio != nil { + msg.ChatAssistantMessage.Audio = accum.audio + } + if len(accum.toolCalls) > 0 { + // Sort tool calls by index + tcIndices := make([]int, 0, len(accum.toolCalls)) + for tcIdx := range accum.toolCalls { + tcIndices = append(tcIndices, tcIdx) + } + toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0, len(accum.toolCalls)) + for _, tcIdx := range tcIndices { + toolCalls = append(toolCalls, accum.toolCalls[tcIdx]) + } + msg.ChatAssistantMessage.ToolCalls = toolCalls + } + } + + // Build choice + choice := schemas.BifrostResponseChoice{ + Index: idx, + FinishReason: accum.finishReason, + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: msg, + }, + } + result.ChatResponse.Choices = append(result.ChatResponse.Choices, choice) + } + + return result +} + +// CreateStreamAccumulator creates a new stream accumulator for the given trace ID. +// This should be called at the start of a streaming request. +func (t *Tracer) CreateStreamAccumulator(traceID string, startTime time.Time) { + if traceID == "" || t.accumulator == nil { + return + } + t.accumulator.CreateStreamAccumulator(traceID, startTime) +} + +// CleanupStreamAccumulator removes the stream accumulator for the given trace ID. +// This should be called after the streaming request is complete. +func (t *Tracer) CleanupStreamAccumulator(traceID string) { + if traceID == "" || t.accumulator == nil { + if t.store != nil && t.store.logger != nil { + t.store.logger.Error("traceID or accumulator is nil in CleanupStreamAccumulator") + } + return + } + if err := t.accumulator.CleanupStreamAccumulator(traceID); err != nil { + if t.store != nil && t.store.logger != nil { + t.store.logger.Error("error in CleanupStreamAccumulator: %v", err) + } + } +} + +// ProcessStreamingChunk processes a streaming chunk and accumulates it. +// Returns the accumulated result. IsFinal will be true when the stream is complete. +// This method is used by plugins to access accumulated streaming data. +// The ctx parameter must contain the stream end indicator for proper final chunk detection. +func (t *Tracer) ProcessStreamingChunk(ctx *schemas.BifrostContext, traceID string, result *schemas.BifrostResponse, err *schemas.BifrostError) *schemas.StreamAccumulatorResult { + if traceID == "" || t.accumulator == nil || ctx == nil { + return nil + } + + // Create a new context for accumulator that sets the traceID as the accumulator lookup ID. + // This inherits from the original context (preserves stream end indicator). + accumCtx := schemas.NewBifrostContext(ctx, time.Time{}) + accumCtx.SetValue(schemas.BifrostContextKeyAccumulatorID, traceID) + + processedResp, processErr := t.accumulator.ProcessStreamingResponse(accumCtx, result, err) + if processErr != nil || processedResp == nil { + return nil + } + + // Convert ProcessedStreamResponse to StreamAccumulatorResult + accResult := &schemas.StreamAccumulatorResult{ + RequestID: processedResp.RequestID, + Model: processedResp.Model, + Provider: processedResp.Provider, + } + + if processedResp.Data != nil { + accResult.Status = processedResp.Data.Status + accResult.Latency = processedResp.Data.Latency + accResult.TimeToFirstToken = processedResp.Data.TimeToFirstToken + accResult.OutputMessage = processedResp.Data.OutputMessage + accResult.OutputMessages = processedResp.Data.OutputMessages + accResult.TokenUsage = processedResp.Data.TokenUsage + accResult.Cost = processedResp.Data.Cost + accResult.ErrorDetails = processedResp.Data.ErrorDetails + accResult.AudioOutput = processedResp.Data.AudioOutput + accResult.TranscriptionOutput = processedResp.Data.TranscriptionOutput + accResult.FinishReason = processedResp.Data.FinishReason + accResult.RawResponse = processedResp.Data.RawResponse + } + + if processedResp.RawRequest != nil { + accResult.RawRequest = *processedResp.RawRequest + } + + return accResult +} + +// GetAccumulator returns the embedded streaming accumulator. +// This is useful for plugins that need direct access to accumulator methods. +func (t *Tracer) GetAccumulator() *streaming.Accumulator { + return t.accumulator +} + +// Stop stops the tracer and releases its resources. +// This stops the internal TraceStore's cleanup goroutine. +func (t *Tracer) Stop() { + if t.store != nil { + t.store.Stop() + } + if t.accumulator != nil { + t.accumulator.Cleanup() + } +} + +// Ensure Tracer implements schemas.Tracer at compile time +var _ schemas.Tracer = (*Tracer)(nil) diff --git a/framework/tracing/tracer_test.go b/framework/tracing/tracer_test.go new file mode 100644 index 0000000000..372e075829 --- /dev/null +++ b/framework/tracing/tracer_test.go @@ -0,0 +1,388 @@ +package tracing + +import ( + "context" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestTracer_StartSpan_RootSpanWithW3CParent(t *testing.T) { + // This is the key test: verifies that when an incoming request has a W3C traceparent header, + // the root span in Bifrost correctly links to the upstream service's span. + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + tracer := NewTracer(store, nil, nil) + defer tracer.Stop() + + // Simulate incoming W3C traceparent: 00-{traceID}-{parentSpanID}-01 + inheritedTraceID := "69538b980000000079943934f90c1d40" + externalParentSpanID := "aad09d1659b4c7e3" + + // Create trace with inherited trace ID + traceID := tracer.CreateTrace(inheritedTraceID) + if traceID != inheritedTraceID { + t.Errorf("CreateTrace() = %q, want inherited trace ID %q", traceID, inheritedTraceID) + } + + // Set up context with trace ID and parent span ID (as middleware would do) + ctx := context.WithValue(context.Background(), schemas.BifrostContextKeyTraceID, traceID) + ctx = context.WithValue(ctx, schemas.BifrostContextKeyParentSpanID, externalParentSpanID) + + // Create root span - this should link to the external parent + newCtx, handle := tracer.StartSpan(ctx, "bifrost-http-request", schemas.SpanKindHTTPRequest) + if handle == nil { + t.Fatal("StartSpan() returned nil handle") + } + + // Verify the span was created with correct parent + trace := store.GetTrace(traceID) + if trace == nil { + t.Fatal("Trace not found in store") + } + + if trace.RootSpan == nil { + t.Fatal("Root span not set on trace") + } + + // THE CRITICAL CHECK: Root span should have the external parent span ID + if trace.RootSpan.ParentID != externalParentSpanID { + t.Errorf("Root span ParentID = %q, want external parent span ID %q", trace.RootSpan.ParentID, externalParentSpanID) + } + + // Verify trace ID is preserved + if trace.RootSpan.TraceID != inheritedTraceID { + t.Errorf("Root span TraceID = %q, want %q", trace.RootSpan.TraceID, inheritedTraceID) + } + + // Verify context has span ID for child span creation + spanID, ok := newCtx.Value(schemas.BifrostContextKeySpanID).(string) + if !ok || spanID == "" { + t.Error("Context should have span ID after StartSpan()") + } + + if spanID != trace.RootSpan.SpanID { + t.Errorf("Context span ID = %q, want %q", spanID, trace.RootSpan.SpanID) + } +} + +func TestTracer_StartSpan_RootSpanWithoutW3CParent(t *testing.T) { + // When there's no incoming W3C context, root span should have no parent + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + tracer := NewTracer(store, nil, nil) + defer tracer.Stop() + + // Create new trace (no inherited trace ID) + traceID := tracer.CreateTrace("") + + // Set up context with only trace ID (no parent span ID) + ctx := context.WithValue(context.Background(), schemas.BifrostContextKeyTraceID, traceID) + + // Create root span + _, handle := tracer.StartSpan(ctx, "local-request", schemas.SpanKindHTTPRequest) + if handle == nil { + t.Fatal("StartSpan() returned nil handle") + } + + trace := store.GetTrace(traceID) + if trace == nil { + t.Fatal("Trace not found in store") + } + + // Root span should have no parent + if trace.RootSpan.ParentID != "" { + t.Errorf("Root span ParentID = %q, want empty string (no W3C parent)", trace.RootSpan.ParentID) + } +} + +func TestTracer_StartSpan_ChildSpanLinking(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + tracer := NewTracer(store, nil, nil) + defer tracer.Stop() + + inheritedTraceID := "69538b980000000079943934f90c1d40" + externalParentSpanID := "aad09d1659b4c7e3" + + traceID := tracer.CreateTrace(inheritedTraceID) + + // Set up context with W3C parent span ID + ctx := context.WithValue(context.Background(), schemas.BifrostContextKeyTraceID, traceID) + ctx = context.WithValue(ctx, schemas.BifrostContextKeyParentSpanID, externalParentSpanID) + + // Create root span + rootCtx, rootHandle := tracer.StartSpan(ctx, "http-request", schemas.SpanKindHTTPRequest) + if rootHandle == nil { + t.Fatal("StartSpan() returned nil handle for root span") + } + + // Create child span using the context from root span + childCtx, childHandle := tracer.StartSpan(rootCtx, "llm-call", schemas.SpanKindLLMCall) + if childHandle == nil { + t.Fatal("StartSpan() returned nil handle for child span") + } + + trace := store.GetTrace(traceID) + + // Find the child span + var childSpan *schemas.Span + for _, span := range trace.Spans { + if span.Name == "llm-call" { + childSpan = span + break + } + } + + if childSpan == nil { + t.Fatal("Child span not found in trace") + } + + // Child span should have root span as parent (not the external parent) + if childSpan.ParentID != trace.RootSpan.SpanID { + t.Errorf("Child span ParentID = %q, want root span ID %q", childSpan.ParentID, trace.RootSpan.SpanID) + } + + // Create grandchild span + _, grandchildHandle := tracer.StartSpan(childCtx, "plugin-call", schemas.SpanKindPlugin) + if grandchildHandle == nil { + t.Fatal("StartSpan() returned nil handle for grandchild span") + } + + // Find the grandchild span + var grandchildSpan *schemas.Span + for _, span := range trace.Spans { + if span.Name == "plugin-call" { + grandchildSpan = span + break + } + } + + if grandchildSpan == nil { + t.Fatal("Grandchild span not found in trace") + } + + // Grandchild should have child as parent + if grandchildSpan.ParentID != childSpan.SpanID { + t.Errorf("Grandchild span ParentID = %q, want child span ID %q", grandchildSpan.ParentID, childSpan.SpanID) + } +} + +func TestTracer_StartSpan_NoTraceID(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + tracer := NewTracer(store, nil, nil) + defer tracer.Stop() + + // Context without trace ID + ctx := context.Background() + + newCtx, handle := tracer.StartSpan(ctx, "operation", schemas.SpanKindHTTPRequest) + if handle != nil { + t.Error("StartSpan() should return nil handle when no trace ID in context") + } + + // Context should be unchanged + if newCtx != ctx { + t.Error("Context should be unchanged when StartSpan() fails") + } +} + +func TestTracer_EndTrace_ReturnsTraceData(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + tracer := NewTracer(store, nil, nil) + defer tracer.Stop() + + inheritedTraceID := "69538b980000000079943934f90c1d40" + externalParentSpanID := "aad09d1659b4c7e3" + + traceID := tracer.CreateTrace(inheritedTraceID) + + ctx := context.WithValue(context.Background(), schemas.BifrostContextKeyTraceID, traceID) + ctx = context.WithValue(ctx, schemas.BifrostContextKeyParentSpanID, externalParentSpanID) + + _, rootHandle := tracer.StartSpan(ctx, "http-request", schemas.SpanKindHTTPRequest) + tracer.EndSpan(rootHandle, schemas.SpanStatusOk, "") + + trace := tracer.EndTrace(traceID) + if trace == nil { + t.Fatal("EndTrace() returned nil") + } + + if trace.TraceID != inheritedTraceID { + t.Errorf("trace.TraceID = %q, want %q", trace.TraceID, inheritedTraceID) + } + + if len(trace.Spans) != 1 { + t.Errorf("len(trace.Spans) = %d, want 1", len(trace.Spans)) + } + + // Root span should still have external parent + if trace.RootSpan.ParentID != externalParentSpanID { + t.Errorf("Root span ParentID = %q, want %q", trace.RootSpan.ParentID, externalParentSpanID) + } +} + +func TestTracer_SetAttribute(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + tracer := NewTracer(store, nil, nil) + defer tracer.Stop() + + traceID := tracer.CreateTrace("") + ctx := context.WithValue(context.Background(), schemas.BifrostContextKeyTraceID, traceID) + + _, handle := tracer.StartSpan(ctx, "operation", schemas.SpanKindHTTPRequest) + + tracer.SetAttribute(handle, "http.method", "POST") + tracer.SetAttribute(handle, "http.status_code", 200) + + trace := store.GetTrace(traceID) + span := trace.RootSpan + + if span.Attributes["http.method"] != "POST" { + t.Errorf("span attribute http.method = %v, want POST", span.Attributes["http.method"]) + } + + if span.Attributes["http.status_code"] != 200 { + t.Errorf("span attribute http.status_code = %v, want 200", span.Attributes["http.status_code"]) + } +} + +func TestTracer_AddEvent(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + tracer := NewTracer(store, nil, nil) + defer tracer.Stop() + + traceID := tracer.CreateTrace("") + ctx := context.WithValue(context.Background(), schemas.BifrostContextKeyTraceID, traceID) + + _, handle := tracer.StartSpan(ctx, "operation", schemas.SpanKindHTTPRequest) + + tracer.AddEvent(handle, "request.received", map[string]any{ + "size": 1024, + }) + + trace := store.GetTrace(traceID) + span := trace.RootSpan + + if len(span.Events) != 1 { + t.Fatalf("len(span.Events) = %d, want 1", len(span.Events)) + } + + if span.Events[0].Name != "request.received" { + t.Errorf("event name = %q, want request.received", span.Events[0].Name) + } + + if span.Events[0].Attributes["size"] != 1024 { + t.Errorf("event attribute size = %v, want 1024", span.Events[0].Attributes["size"]) + } +} + +// TestIntegration_FullDistributedTraceFlow tests the complete flow of receiving +// a distributed trace from an upstream service and properly linking spans. +func TestIntegration_FullDistributedTraceFlow(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + tracer := NewTracer(store, nil, nil) + defer tracer.Stop() + + // Simulating headers from user's actual Datadog request: + // traceparent: 00-69538b980000000079943934f90c1d40-aad09d1659b4c7e3-01 + inheritedTraceID := "69538b980000000079943934f90c1d40" + externalParentSpanID := "aad09d1659b4c7e3" + + // Step 1: Middleware extracts trace context and creates trace + traceID := tracer.CreateTrace(inheritedTraceID) + + // Step 2: Middleware sets up context (simulating what TracingMiddleware does) + ctx := context.WithValue(context.Background(), schemas.BifrostContextKeyTraceID, traceID) + ctx = context.WithValue(ctx, schemas.BifrostContextKeyParentSpanID, externalParentSpanID) + + // Step 3: Middleware creates root span + httpCtx, httpHandle := tracer.StartSpan(ctx, "/v1/chat/completions", schemas.SpanKindHTTPRequest) + tracer.SetAttribute(httpHandle, "http.method", "POST") + + // Step 4: Bifrost creates LLM call span + llmCtx, llmHandle := tracer.StartSpan(httpCtx, "openai.chat.completions", schemas.SpanKindLLMCall) + tracer.SetAttribute(llmHandle, "llm.model", "gpt-4") + tracer.SetAttribute(llmHandle, "llm.provider", "openai") + + // Step 5: Plugin creates its own span + _, pluginHandle := tracer.StartSpan(llmCtx, "governance-plugin", schemas.SpanKindPlugin) + tracer.SetAttribute(pluginHandle, "plugin.name", "governance") + + // Step 6: Complete spans (in reverse order) + tracer.EndSpan(pluginHandle, schemas.SpanStatusOk, "") + tracer.EndSpan(llmHandle, schemas.SpanStatusOk, "") + tracer.EndSpan(httpHandle, schemas.SpanStatusOk, "") + + // Step 7: Complete trace + trace := tracer.EndTrace(traceID) + + // Verify the trace structure for Datadog + if trace.TraceID != inheritedTraceID { + t.Errorf("Trace ID should match inherited ID from Datadog: got %q, want %q", trace.TraceID, inheritedTraceID) + } + + // Find spans by name + var httpSpan, llmSpan, pluginSpan *schemas.Span + for _, span := range trace.Spans { + switch span.Name { + case "/v1/chat/completions": + httpSpan = span + case "openai.chat.completions": + llmSpan = span + case "governance-plugin": + pluginSpan = span + } + } + + if httpSpan == nil || llmSpan == nil || pluginSpan == nil { + t.Fatal("Not all spans found in trace") + } + + // Verify span hierarchy for Datadog linking: + // External Parent (aad09d1659b4c7e3) -> HTTP Span -> LLM Span -> Plugin Span + + // HTTP span should link to Datadog's parent span + if httpSpan.ParentID != externalParentSpanID { + t.Errorf("HTTP span should link to Datadog parent: got ParentID %q, want %q", + httpSpan.ParentID, externalParentSpanID) + } + + // LLM span should be child of HTTP span + if llmSpan.ParentID != httpSpan.SpanID { + t.Errorf("LLM span should be child of HTTP span: got ParentID %q, want %q", + llmSpan.ParentID, httpSpan.SpanID) + } + + // Plugin span should be child of LLM span + if pluginSpan.ParentID != llmSpan.SpanID { + t.Errorf("Plugin span should be child of LLM span: got ParentID %q, want %q", + pluginSpan.ParentID, llmSpan.SpanID) + } + + // All spans should have the same trace ID + if httpSpan.TraceID != inheritedTraceID || llmSpan.TraceID != inheritedTraceID || pluginSpan.TraceID != inheritedTraceID { + t.Error("All spans should have the inherited trace ID") + } + + t.Logf("Trace structure (for Datadog):") + t.Logf(" Trace ID: %s", trace.TraceID) + t.Logf(" External Parent Span: %s (from Datadog)", externalParentSpanID) + t.Logf(" -> HTTP Span: %s (ParentID: %s)", httpSpan.SpanID, httpSpan.ParentID) + t.Logf(" -> LLM Span: %s (ParentID: %s)", llmSpan.SpanID, llmSpan.ParentID) + t.Logf(" -> Plugin Span: %s (ParentID: %s)", pluginSpan.SpanID, pluginSpan.ParentID) +} diff --git a/framework/vectorstore/weaviate.go b/framework/vectorstore/weaviate.go index 4c8d3ec017..d65a297fe6 100644 --- a/framework/vectorstore/weaviate.go +++ b/framework/vectorstore/weaviate.go @@ -348,6 +348,17 @@ func (s *WeaviateStore) Delete(ctx context.Context, className string, id string) } func (s *WeaviateStore) DeleteAll(ctx context.Context, className string, queries []Query) ([]DeleteResult, error) { + // Check if class exists first to avoid 500 errors from Weaviate + exists, err := s.client.Schema().ClassExistenceChecker(). + WithClassName(className). + Do(ctx) + if err != nil { + return nil, fmt.Errorf("failed to check class existence: %w", err) + } + if !exists { + return []DeleteResult{}, nil // Class doesn't exist, nothing to delete + } + where := buildWeaviateFilter(queries) res, err := s.client.Batch().ObjectsBatchDeleter(). diff --git a/framework/version b/framework/version index f85711d68a..5975b143a0 100644 --- a/framework/version +++ b/framework/version @@ -1 +1 @@ -1.1.61 \ No newline at end of file +1.2.8 \ No newline at end of file diff --git a/plugins/governance/advancedscenarios_test.go b/plugins/governance/advancedscenarios_test.go new file mode 100644 index 0000000000..4181b86d9f --- /dev/null +++ b/plugins/governance/advancedscenarios_test.go @@ -0,0 +1,1681 @@ +package governance + +import ( + "testing" + "time" +) + +// ============================================================================ +// SCENARIO 1: VK Switching Teams After Budget Exhaustion +// ============================================================================ + +// TestVKSwitchTeamAfterBudgetExhaustion verifies that after exhausting one team's budget, +// switching the VK to another team allows requests to pass +func TestVKSwitchTeamAfterBudgetExhaustion(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create Team 1 with small budget + team1Name := "test-team1-switch-" + generateRandomID() + team1Budget := 0.01 // $0.01 + createTeam1Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: team1Name, + Budget: &BudgetRequest{ + MaxLimit: team1Budget, + ResetDuration: "1h", + }, + }, + }) + + if createTeam1Resp.StatusCode != 200 { + t.Fatalf("Failed to create team1: status %d", createTeam1Resp.StatusCode) + } + + team1ID := ExtractIDFromResponse(t, createTeam1Resp, "id") + testData.AddTeam(team1ID) + + // Create Team 2 with higher budget + team2Name := "test-team2-switch-" + generateRandomID() + team2Budget := 10.0 // $10 + createTeam2Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: team2Name, + Budget: &BudgetRequest{ + MaxLimit: team2Budget, + ResetDuration: "1h", + }, + }, + }) + + if createTeam2Resp.StatusCode != 200 { + t.Fatalf("Failed to create team2: status %d", createTeam2Resp.StatusCode) + } + + team2ID := ExtractIDFromResponse(t, createTeam2Resp, "id") + testData.AddTeam(team2ID) + + t.Logf("Created Team1 (budget: $%.2f) and Team2 (budget: $%.2f)", team1Budget, team2Budget) + + // Create VK assigned to Team 1 + vkName := "test-vk-team-switch-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &team1ID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK assigned to Team1") + + // Exhaust Team1's budget + consumedBudget := 0.0 + requestNum := 1 + + for requestNum <= 150 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Hello how are you?"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") { + t.Logf("Team1 budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget) + break + } else { + t.Fatalf("Request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + } + } + } + + requestNum++ + + if consumedBudget >= team1Budget { + // Make one more request to trigger rejection + continue + } + } + + if consumedBudget < team1Budget { + t.Fatalf("Could not exhaust Team1 budget") + } + + // Now switch VK to Team2 + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + TeamID: &team2ID, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to switch VK to Team2: status %d", updateResp.StatusCode) + } + + t.Logf("Switched VK from Team1 to Team2") + + // Wait for in-memory update + time.Sleep(500 * time.Millisecond) + + // Request should now succeed with Team2's budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Request after switching to Team2"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Fatalf("Request should succeed after switching to Team2 with available budget, got status %d", resp.StatusCode) + } + + t.Logf("VK switch team after budget exhaustion verified āœ“") +} + +// ============================================================================ +// SCENARIO 2: VK Switching Customers After Budget Exhaustion +// ============================================================================ + +// TestVKSwitchCustomerAfterBudgetExhaustion verifies that after exhausting one customer's budget, +// switching the VK to another customer allows requests to pass +func TestVKSwitchCustomerAfterBudgetExhaustion(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create Customer 1 with small budget + customer1Name := "test-customer1-switch-" + generateRandomID() + customer1Budget := 0.01 // $0.01 + createCustomer1Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customer1Name, + Budget: &BudgetRequest{ + MaxLimit: customer1Budget, + ResetDuration: "1h", + }, + }, + }) + + if createCustomer1Resp.StatusCode != 200 { + t.Fatalf("Failed to create customer1: status %d", createCustomer1Resp.StatusCode) + } + + customer1ID := ExtractIDFromResponse(t, createCustomer1Resp, "id") + testData.AddCustomer(customer1ID) + + // Create Customer 2 with higher budget + customer2Name := "test-customer2-switch-" + generateRandomID() + customer2Budget := 10.0 // $10 + createCustomer2Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customer2Name, + Budget: &BudgetRequest{ + MaxLimit: customer2Budget, + ResetDuration: "1h", + }, + }, + }) + + if createCustomer2Resp.StatusCode != 200 { + t.Fatalf("Failed to create customer2: status %d", createCustomer2Resp.StatusCode) + } + + customer2ID := ExtractIDFromResponse(t, createCustomer2Resp, "id") + testData.AddCustomer(customer2ID) + + t.Logf("Created Customer1 (budget: $%.2f) and Customer2 (budget: $%.2f)", customer1Budget, customer2Budget) + + // Create VK assigned directly to Customer 1 + vkName := "test-vk-customer-switch-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + CustomerID: &customer1ID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK assigned to Customer1") + + // Exhaust Customer1's budget + consumedBudget := 0.0 + requestNum := 1 + + for requestNum <= 150 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Hello how are you?"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") { + t.Logf("Customer1 budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget) + break + } else { + t.Fatalf("Request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + } + } + } + + requestNum++ + + if consumedBudget >= customer1Budget { + continue + } + } + + if consumedBudget < customer1Budget { + t.Fatalf("Could not exhaust Customer1 budget") + } + + // Now switch VK to Customer2 + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + CustomerID: &customer2ID, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to switch VK to Customer2: status %d", updateResp.StatusCode) + } + + t.Logf("Switched VK from Customer1 to Customer2") + + time.Sleep(500 * time.Millisecond) + + // Request should now succeed with Customer2's budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Request after switching to Customer2"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Fatalf("Request should succeed after switching to Customer2 with available budget, got status %d", resp.StatusCode) + } + + t.Logf("VK switch customer after budget exhaustion verified āœ“") +} + +// ============================================================================ +// SCENARIO 3: Hierarchical Chain VK->Team->Customer Budget Switching +// ============================================================================ + +// TestHierarchicalChainBudgetSwitch verifies switching the entire hierarchy +func TestHierarchicalChainBudgetSwitch(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create Customer 1 with small budget + customer1Name := "test-customer1-hierarchy-" + generateRandomID() + createCustomer1Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customer1Name, + Budget: &BudgetRequest{ + MaxLimit: 0.01, // $0.01 - most restrictive + ResetDuration: "1h", + }, + }, + }) + + if createCustomer1Resp.StatusCode != 200 { + t.Fatalf("Failed to create customer1: status %d", createCustomer1Resp.StatusCode) + } + + customer1ID := ExtractIDFromResponse(t, createCustomer1Resp, "id") + testData.AddCustomer(customer1ID) + + // Create Team 1 under Customer 1 + team1Name := "test-team1-hierarchy-" + generateRandomID() + createTeam1Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: team1Name, + CustomerID: &customer1ID, + Budget: &BudgetRequest{ + MaxLimit: 100.0, // High budget - customer is limiting + ResetDuration: "1h", + }, + }, + }) + + if createTeam1Resp.StatusCode != 200 { + t.Fatalf("Failed to create team1: status %d", createTeam1Resp.StatusCode) + } + + team1ID := ExtractIDFromResponse(t, createTeam1Resp, "id") + testData.AddTeam(team1ID) + + // Create Customer 2 with higher budget + customer2Name := "test-customer2-hierarchy-" + generateRandomID() + createCustomer2Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customer2Name, + Budget: &BudgetRequest{ + MaxLimit: 100.0, // High budget + ResetDuration: "1h", + }, + }, + }) + + if createCustomer2Resp.StatusCode != 200 { + t.Fatalf("Failed to create customer2: status %d", createCustomer2Resp.StatusCode) + } + + customer2ID := ExtractIDFromResponse(t, createCustomer2Resp, "id") + testData.AddCustomer(customer2ID) + + // Create Team 2 under Customer 2 + team2Name := "test-team2-hierarchy-" + generateRandomID() + createTeam2Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: team2Name, + CustomerID: &customer2ID, + Budget: &BudgetRequest{ + MaxLimit: 100.0, // High budget + ResetDuration: "1h", + }, + }, + }) + + if createTeam2Resp.StatusCode != 200 { + t.Fatalf("Failed to create team2: status %d", createTeam2Resp.StatusCode) + } + + team2ID := ExtractIDFromResponse(t, createTeam2Resp, "id") + testData.AddTeam(team2ID) + + t.Logf("Created hierarchy: Customer1(low budget)->Team1 and Customer2(high budget)->Team2") + + // Create VK assigned to Team 1 + vkName := "test-vk-hierarchy-switch-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &team1ID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + // Exhaust Customer1's budget (which is limiting Team1) + consumedBudget := 0.0 + requestNum := 1 + budgetExhausted := false + + for requestNum <= 150 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Hello how are you?"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") { + budgetExhausted = true + t.Logf("Customer1 budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget) + break + } else { + t.Fatalf("Request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + } + } + } + + requestNum++ + } + + if !budgetExhausted { + t.Fatalf("Budget should have been exhausted within 150 requests, but no budget rejection was observed (consumed: $%.6f)", consumedBudget) + } + + // Switch VK to Team2 (under Customer2) + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + TeamID: &team2ID, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to switch VK to Team2: status %d", updateResp.StatusCode) + } + + t.Logf("Switched VK from Team1(Customer1) to Team2(Customer2)") + + time.Sleep(500 * time.Millisecond) + + // Request should now succeed + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Request after switching hierarchy"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Fatalf("Request should succeed after switching hierarchy, got status %d", resp.StatusCode) + } + + t.Logf("Hierarchical chain budget switch verified āœ“") +} + +// ============================================================================ +// SCENARIO 4: VK Budget Update After Exhaustion +// ============================================================================ + +// TestVKBudgetUpdateAfterExhaustion verifies that updating VK budget after exhaustion allows requests +func TestVKBudgetUpdateAfterExhaustion(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with small budget + vkName := "test-vk-budget-update-" + generateRandomID() + initialBudget := 0.01 // $0.01 + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with budget: $%.2f", initialBudget) + + // Exhaust VK budget + consumedBudget := 0.0 + requestNum := 1 + sawBudgetRejection := false + + for requestNum <= 150 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Hello how are you?"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") { + sawBudgetRejection = true + t.Logf("VK budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget) + break + } else { + t.Fatalf("Request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + } + } + } + + requestNum++ + } + + if !sawBudgetRejection { + t.Fatalf("No budget rejection observed; consumed budget: $%.6f", consumedBudget) + } + + // Update VK budget to a higher value + newBudget := 10.0 + resetDuration := "1h" + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newBudget, + ResetDuration: &resetDuration, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update VK budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated VK budget from $%.2f to $%.2f", initialBudget, newBudget) + + time.Sleep(500 * time.Millisecond) + + // Request should now succeed + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Request after budget update"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Fatalf("Request should succeed after budget update, got status %d", resp.StatusCode) + } + + t.Logf("VK budget update after exhaustion verified āœ“") +} + +// ============================================================================ +// SCENARIO 5: Team Budget Update After Exhaustion +// ============================================================================ + +// TestTeamBudgetUpdateAfterExhaustion verifies that updating team budget after exhaustion allows requests +func TestTeamBudgetUpdateAfterExhaustion(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create team with small budget + teamName := "test-team-budget-update-" + generateRandomID() + initialBudget := 0.01 // $0.01 + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: "1h", + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + // Create VK under team + vkName := "test-vk-team-budget-update-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &teamID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created team with budget: $%.2f", initialBudget) + + // Exhaust team budget + consumedBudget := 0.0 + requestNum := 1 + sawBudgetRejection := false + + for requestNum <= 150 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Hello how are you?"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") { + sawBudgetRejection = true + t.Logf("Team budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget) + break + } else { + t.Fatalf("Request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + } + } + } + + requestNum++ + } + + if !sawBudgetRejection { + t.Fatalf("No budget rejection observed; consumed budget: $%.6f", consumedBudget) + } + + // Update team budget + newBudget := 10.0 + resetDuration := "1h" + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/teams/" + teamID, + Body: UpdateTeamRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newBudget, + ResetDuration: &resetDuration, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update team budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated team budget from $%.2f to $%.2f", initialBudget, newBudget) + + time.Sleep(500 * time.Millisecond) + + // Request should now succeed + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Request after team budget update"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Fatalf("Request should succeed after team budget update, got status %d", resp.StatusCode) + } + + t.Logf("Team budget update after exhaustion verified āœ“") +} + +// ============================================================================ +// SCENARIO 6: Customer Budget Update After Exhaustion +// ============================================================================ + +// TestCustomerBudgetUpdateAfterExhaustion verifies that updating customer budget after exhaustion allows requests +func TestCustomerBudgetUpdateAfterExhaustion(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create customer with small budget + customerName := "test-customer-budget-update-" + generateRandomID() + initialBudget := 0.01 // $0.01 + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: "1h", + }, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + testData.AddCustomer(customerID) + + // Create team under customer + teamName := "test-team-customer-update-" + generateRandomID() + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + CustomerID: &customerID, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + // Create VK under team + vkName := "test-vk-customer-budget-update-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &teamID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created customer with budget: $%.2f", initialBudget) + + // Exhaust customer budget + consumedBudget := 0.0 + requestNum := 1 + sawBudgetRejection := false + + for requestNum <= 150 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Hello how are you?"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") { + sawBudgetRejection = true + t.Logf("Customer budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget) + break + } else { + t.Fatalf("Request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + } + } + } + + requestNum++ + } + + if !sawBudgetRejection { + t.Fatalf("No budget rejection observed; consumed budget: $%.6f", consumedBudget) + } + + // Update customer budget + newBudget := 10.0 + resetDuration := "1h" + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/customers/" + customerID, + Body: UpdateCustomerRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newBudget, + ResetDuration: &resetDuration, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update customer budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated customer budget from $%.2f to $%.2f", initialBudget, newBudget) + + time.Sleep(500 * time.Millisecond) + + // Request should now succeed + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Request after customer budget update"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Fatalf("Request should succeed after customer budget update, got status %d", resp.StatusCode) + } + + t.Logf("Customer budget update after exhaustion verified āœ“") +} + +// ============================================================================ +// SCENARIO 7: Provider Config Budget Update After Exhaustion +// ============================================================================ + +// TestProviderConfigBudgetUpdateAfterExhaustion verifies that updating provider config budget after exhaustion allows requests +func TestProviderConfigBudgetUpdateAfterExhaustion(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with provider config budget + vkName := "test-vk-provider-budget-update-" + generateRandomID() + initialBudget := 0.01 // $0.01 + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: "1h", + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with provider config budget: $%.2f", initialBudget) + + // Get provider config ID + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + providerConfigs := vkData["provider_configs"].([]interface{}) + providerConfig := providerConfigs[0].(map[string]interface{}) + providerConfigID := uint(providerConfig["id"].(float64)) + + // Exhaust provider config budget + consumedBudget := 0.0 + requestNum := 1 + sawBudgetRejection := false + + for requestNum <= 150 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Hello how are you?"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") { + sawBudgetRejection = true + t.Logf("Provider config budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget) + break + } else { + t.Fatalf("Request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + } + } + } + + requestNum++ + } + + if !sawBudgetRejection { + t.Fatalf("No budget rejection observed; consumed budget: $%.6f", consumedBudget) + } + + // Update provider config budget + newBudget := 10.0 + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + ProviderConfigs: []ProviderConfigRequest{ + { + ID: &providerConfigID, + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: newBudget, + ResetDuration: "1h", + }, + }, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update provider config budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated provider config budget from $%.2f to $%.2f", initialBudget, newBudget) + + time.Sleep(500 * time.Millisecond) + + // Request should now succeed + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: "Request after provider config budget update"}, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Fatalf("Request should succeed after provider config budget update, got status %d", resp.StatusCode) + } + + t.Logf("Provider config budget update after exhaustion verified āœ“") +} + +// ============================================================================ +// SCENARIO 8: VK Deletion Cascade +// ============================================================================ + +// TestVKDeletionCascadeComplete verifies deleting VK removes provider configs, budgets, and rate limits from memory +func TestVKDeletionCascadeComplete(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with budget, rate limit, and provider configs + vkName := "test-vk-deletion-cascade-" + generateRandomID() + tokenLimit := int64(10000) + tokenResetDuration := "1h" + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: 10.0, + ResetDuration: "1h", + }, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: 5.0, + ResetDuration: "1h", + }, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + // Don't add to testData since we'll delete manually + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with budget, rate limit, and provider config") + + // Get initial state from in-memory store + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + + getRateLimitsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap1 := getRateLimitsResp1.Body["rate_limits"].(map[string]interface{}) + + // Verify VK exists + _, vkExists := virtualKeysMap1[vkValue] + if !vkExists { + t.Fatalf("VK not found in in-memory store") + } + + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + vkBudgetID := vkData1["budget_id"].(string) + vkRateLimitID := vkData1["rate_limit_id"].(string) + providerConfigs := vkData1["provider_configs"].([]interface{}) + pc := providerConfigs[0].(map[string]interface{}) + pcBudgetID := pc["budget_id"].(string) + pcRateLimitID := pc["rate_limit_id"].(string) + + // Verify all resources exist in memory + _, vkBudgetExists := budgetsMap1[vkBudgetID] + _, vkRateLimitExists := rateLimitsMap1[vkRateLimitID] + _, pcBudgetExists := budgetsMap1[pcBudgetID] + _, pcRateLimitExists := rateLimitsMap1[pcRateLimitID] + + if !vkBudgetExists || !vkRateLimitExists || !pcBudgetExists || !pcRateLimitExists { + t.Fatalf("Not all resources found in memory before deletion") + } + + t.Logf("All resources exist in memory before deletion āœ“") + + // Delete VK + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/virtual-keys/" + vkID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete VK: status %d", deleteResp.StatusCode) + } + + t.Logf("VK deleted") + + time.Sleep(500 * time.Millisecond) + + // Verify VK and all related resources are removed from memory + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + + getRateLimitsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap2 := getRateLimitsResp2.Body["rate_limits"].(map[string]interface{}) + + // VK should be gone + _, vkStillExists := virtualKeysMap2[vkValue] + if vkStillExists { + t.Fatalf("VK still exists in memory after deletion") + } + + // Budgets should be gone + _, vkBudgetStillExists := budgetsMap2[vkBudgetID] + _, pcBudgetStillExists := budgetsMap2[pcBudgetID] + if vkBudgetStillExists || pcBudgetStillExists { + t.Fatalf("Budgets should be cascade-deleted: VK budget exists=%v, PC budget exists=%v", + vkBudgetStillExists, pcBudgetStillExists) + } + + // Rate limits should be gone + _, vkRateLimitStillExists := rateLimitsMap2[vkRateLimitID] + _, pcRateLimitStillExists := rateLimitsMap2[pcRateLimitID] + if vkRateLimitStillExists || pcRateLimitStillExists { + t.Logf("Note: Rate limits may still exist in memory (orphaned) - this is acceptable") + } + + t.Logf("VK removed from memory after deletion āœ“") + t.Logf("VK deletion cascade verified āœ“") +} + +// ============================================================================ +// SCENARIO 9: Team/Customer Deletion Should Delete Budget +// ============================================================================ + +// TestTeamDeletionDeletesBudget verifies that deleting a team also deletes its budget from memory +func TestTeamDeletionDeletesBudget(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create team with budget + teamName := "test-team-delete-budget-" + generateRandomID() + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: 100.0, + ResetDuration: "1h", + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + // Don't add to testData since we'll delete manually + + t.Logf("Created team with budget") + + // Get budget ID from in-memory store + getTeamsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/teams?from_memory=true", + }) + + teamsMap1 := getTeamsResp1.Body["teams"].(map[string]interface{}) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + + teamData1 := teamsMap1[teamID].(map[string]interface{}) + budgetID := teamData1["budget_id"].(string) + + _, budgetExists := budgetsMap1[budgetID] + if !budgetExists { + t.Fatalf("Budget not found in memory before deletion") + } + + t.Logf("Team and budget exist in memory āœ“") + + // Delete team + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/teams/" + teamID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete team: status %d", deleteResp.StatusCode) + } + + t.Logf("Team deleted") + + time.Sleep(500 * time.Millisecond) + + // Verify team and budget are removed from memory + getTeamsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/teams?from_memory=true", + }) + + teamsMap2 := getTeamsResp2.Body["teams"].(map[string]interface{}) + + _, teamStillExists := teamsMap2[teamID] + if teamStillExists { + t.Fatalf("Team still exists in memory after deletion") + } + + t.Logf("Team removed from memory āœ“") + + // Verify budget is also removed from memory + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + if getBudgetsResp2.StatusCode != 200 { + t.Fatalf("Failed to get budgets from memory: status %d", getBudgetsResp2.StatusCode) + } + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + + _, budgetStillExists := budgetsMap2[budgetID] + if budgetStillExists { + t.Fatalf("Budget %s still exists in memory after team deletion", budgetID) + } + + t.Logf("Budget removed from memory āœ“") + t.Logf("Team deletion with budget verified āœ“") +} + +// TestCustomerDeletionDeletesBudget verifies that deleting a customer also deletes its budget from memory +func TestCustomerDeletionDeletesBudget(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create customer with budget + customerName := "test-customer-delete-budget-" + generateRandomID() + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: 100.0, + ResetDuration: "1h", + }, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + // Don't add to testData since we'll delete manually + + t.Logf("Created customer with budget") + + // Get budget ID from in-memory store + getCustomersResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/customers?from_memory=true", + }) + + customersMap1 := getCustomersResp1.Body["customers"].(map[string]interface{}) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + + customerData1 := customersMap1[customerID].(map[string]interface{}) + budgetID := customerData1["budget_id"].(string) + + _, budgetExists := budgetsMap1[budgetID] + if !budgetExists { + t.Fatalf("Budget not found in memory before deletion") + } + + t.Logf("Customer and budget exist in memory āœ“") + + // Delete customer + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/customers/" + customerID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete customer: status %d", deleteResp.StatusCode) + } + + t.Logf("Customer deleted") + + time.Sleep(500 * time.Millisecond) + + // Verify customer is removed from memory + getCustomersResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/customers?from_memory=true", + }) + + customersMap2 := getCustomersResp2.Body["customers"].(map[string]interface{}) + + _, customerStillExists := customersMap2[customerID] + if customerStillExists { + t.Fatalf("Customer still exists in memory after deletion") + } + + t.Logf("Customer removed from memory āœ“") + + // Verify budget is also removed from memory + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + if getBudgetsResp2.StatusCode != 200 { + t.Fatalf("Failed to get budgets from memory: status %d", getBudgetsResp2.StatusCode) + } + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + + _, budgetStillExists := budgetsMap2[budgetID] + if budgetStillExists { + t.Fatalf("Budget still exists in memory after customer deletion") + } + + t.Logf("Budget removed from memory āœ“") + t.Logf("Customer deletion with budget verified āœ“") +} + +// ============================================================================ +// SCENARIO 10: Team/Customer Deletion Sets VK entity_id = nil +// ============================================================================ + +// TestTeamDeletionSetsVKTeamIDToNil verifies that deleting a team sets team_id=nil on associated VKs +func TestTeamDeletionSetsVKTeamIDToNil(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create team + teamName := "test-team-vk-nil-" + generateRandomID() + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + // Don't add to testData since we'll delete manually + + // Create VK assigned to team + vkName := "test-vk-team-nil-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &teamID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created team and VK assigned to it") + + // Verify VK has team_id set + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + + teamIDFromVK1, hasTeamID := vkData1["team_id"].(string) + if !hasTeamID || teamIDFromVK1 != teamID { + t.Fatalf("VK team_id not set correctly before team deletion") + } + + t.Logf("VK has team_id=%s āœ“", teamID) + + // Delete team + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/teams/" + teamID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete team: status %d", deleteResp.StatusCode) + } + + t.Logf("Team deleted") + + time.Sleep(500 * time.Millisecond) + + // Verify VK still exists but team_id is nil + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + + vkData2, vkStillExists := virtualKeysMap2[vkValue].(map[string]interface{}) + if !vkStillExists { + t.Fatalf("VK should still exist after team deletion") + } + + teamIDFromVK2, hasTeamID2 := vkData2["team_id"].(string) + if hasTeamID2 && teamIDFromVK2 != "" { + t.Fatalf("VK team_id should be nil after team deletion, got: %s", teamIDFromVK2) + } + + t.Logf("VK team_id is now nil āœ“") + t.Logf("Team deletion sets VK team_id to nil verified āœ“") +} + +// TestCustomerDeletionSetsVKCustomerIDToNil verifies that deleting a customer sets customer_id=nil on associated VKs +func TestCustomerDeletionSetsVKCustomerIDToNil(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create customer + customerName := "test-customer-vk-nil-" + generateRandomID() + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + // Don't add to testData since we'll delete manually + + // Create VK assigned directly to customer + vkName := "test-vk-customer-nil-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + CustomerID: &customerID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created customer and VK assigned to it") + + // Verify VK has customer_id set + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + + customerIDFromVK1, hasCustomerID := vkData1["customer_id"].(string) + if !hasCustomerID || customerIDFromVK1 != customerID { + t.Fatalf("VK customer_id not set correctly before customer deletion") + } + + t.Logf("VK has customer_id=%s āœ“", customerID) + + // Delete customer + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/customers/" + customerID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete customer: status %d", deleteResp.StatusCode) + } + + t.Logf("Customer deleted") + + time.Sleep(500 * time.Millisecond) + + // Verify VK still exists but customer_id is nil + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + + vkData2, vkStillExists := virtualKeysMap2[vkValue].(map[string]interface{}) + if !vkStillExists { + t.Fatalf("VK should still exist after customer deletion") + } + + customerIDFromVK2, hasCustomerID2 := vkData2["customer_id"].(string) + if hasCustomerID2 && customerIDFromVK2 != "" { + t.Fatalf("VK customer_id should be nil after customer deletion, got: %s", customerIDFromVK2) + } + + t.Logf("VK customer_id is now nil āœ“") + t.Logf("Customer deletion sets VK customer_id to nil verified āœ“") +} diff --git a/plugins/governance/configupdatesync_test.go b/plugins/governance/configupdatesync_test.go new file mode 100644 index 0000000000..a252c7d6e9 --- /dev/null +++ b/plugins/governance/configupdatesync_test.go @@ -0,0 +1,1123 @@ +package governance + +import ( + "testing" + "time" +) + +// ============================================================================ +// VK-LEVEL RATE LIMIT UPDATE SYNC +// ============================================================================ + +// TestVKRateLimitUpdateSyncToMemory tests that VK rate limit updates sync to in-memory store +// and that usage resets to 0 when new max limit < current usage +func TestVKRateLimitUpdateSyncToMemory(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with initial rate limit + vkName := "test-vk-rate-update-" + generateRandomID() + initialTokenLimit := int64(10000) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &initialTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with initial token limit: %d", initialTokenLimit) + + // Get initial in-memory state + getVKResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData1 := getVKResp1.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + rateLimitID1, _ := vkData1["rate_limit_id"].(string) + + getRateLimitsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap1 := getRateLimitsResp1.Body["rate_limits"].(map[string]interface{}) + rateLimit1 := rateLimitsMap1[rateLimitID1].(map[string]interface{}) + + initialTokenMaxLimit, _ := rateLimit1["token_max_limit"].(float64) + initialTokenUsage, _ := rateLimit1["token_current_usage"].(float64) + + if int64(initialTokenMaxLimit) != initialTokenLimit { + t.Fatalf("Initial token max limit not correct: expected %d, got %d", initialTokenLimit, int64(initialTokenMaxLimit)) + } + + t.Logf("Initial state in memory: TokenMaxLimit=%d, TokenCurrentUsage=%d", int64(initialTokenMaxLimit), int64(initialTokenUsage)) + + // Make a request to consume some tokens + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request to consume tokens.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to consume tokens") + } + + // Wait for async update + time.Sleep(500 * time.Millisecond) + + // Get state with usage + getVKResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData2 := getVKResp2.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + rateLimitID2, _ := vkData2["rate_limit_id"].(string) + + getRateLimitsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap2 := getRateLimitsResp2.Body["rate_limits"].(map[string]interface{}) + rateLimit2 := rateLimitsMap2[rateLimitID2].(map[string]interface{}) + + tokenUsageBeforeUpdate, _ := rateLimit2["token_current_usage"].(float64) + t.Logf("Token usage after request: %d", int64(tokenUsageBeforeUpdate)) + + if tokenUsageBeforeUpdate <= 0 { + t.Skip("No tokens consumed - cannot test usage reset") + } + + // NOW UPDATE: set new limit LOWER than current usage to trigger reset + // Usage reset only happens when new max limit <= current usage + newLowerLimit := int64(tokenUsageBeforeUpdate / 2) // Set to half of current usage to ensure it's lower + if newLowerLimit <= 0 { + newLowerLimit = int64(tokenUsageBeforeUpdate / 10) // Fallback to 10% if too small + } + if newLowerLimit <= 0 { + newLowerLimit = 1 // Minimum of 1 + } + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &newLowerLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update VK rate limit: status %d", updateResp.StatusCode) + } + + t.Logf("Updated token limit from %d to %d (new limit %d <= current usage %d)", initialTokenLimit, newLowerLimit, newLowerLimit, int64(tokenUsageBeforeUpdate)) + + // Wait for update to sync + time.Sleep(500 * time.Millisecond) + + // Verify update in in-memory store + getVKResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData3 := getVKResp3.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + rateLimitID3, _ := vkData3["rate_limit_id"].(string) + + getRateLimitsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap3 := getRateLimitsResp3.Body["rate_limits"].(map[string]interface{}) + rateLimit3 := rateLimitsMap3[rateLimitID3].(map[string]interface{}) + + newTokenMaxLimit, _ := rateLimit3["token_max_limit"].(float64) + tokenUsageAfterUpdate, _ := rateLimit3["token_current_usage"].(float64) + + // Verify new max limit is reflected + if int64(newTokenMaxLimit) != newLowerLimit { + t.Fatalf("Token max limit not updated in memory: expected %d, got %d", newLowerLimit, int64(newTokenMaxLimit)) + } + + t.Logf("āœ“ Token max limit updated in memory: %d", int64(newTokenMaxLimit)) + + // Verify usage reset to 0 (since new max limit <= current usage) + if tokenUsageAfterUpdate > 0.001 { + t.Fatalf("Token usage should reset to 0 when new limit (%d) <= current usage (%d), but got %d", newLowerLimit, int64(tokenUsageBeforeUpdate), int64(tokenUsageAfterUpdate)) + } + + t.Logf("āœ“ Token usage correctly reset to 0 (new limit: %d <= old usage: %d)", int64(newTokenMaxLimit), int64(tokenUsageBeforeUpdate)) + + // Test UPDATE with higher limit (usage should NOT reset) + newerHigherLimit := int64(50000) + updateResp2 := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &newerHigherLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if updateResp2.StatusCode != 200 { + t.Fatalf("Failed to update VK rate limit second time: status %d", updateResp2.StatusCode) + } + + time.Sleep(500 * time.Millisecond) + + getVKResp4 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData4 := getVKResp4.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + rateLimitID4, _ := vkData4["rate_limit_id"].(string) + + getRateLimitsResp4 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap4 := getRateLimitsResp4.Body["rate_limits"].(map[string]interface{}) + rateLimit4 := rateLimitsMap4[rateLimitID4].(map[string]interface{}) + + newerTokenMaxLimit, _ := rateLimit4["token_max_limit"].(float64) + tokenUsageAfterSecondUpdate, _ := rateLimit4["token_current_usage"].(float64) + + // Verify new higher limit is reflected + if int64(newerTokenMaxLimit) != newerHigherLimit { + t.Fatalf("Token max limit not updated to higher value: expected %d, got %d", newerHigherLimit, int64(newerTokenMaxLimit)) + } + + t.Logf("āœ“ Token max limit updated to higher value: %d", int64(newerTokenMaxLimit)) + + // Since usage is 0 and new limit is higher, usage stays 0 + if tokenUsageAfterSecondUpdate != 0 { + t.Logf("Note: Token usage is %d (expected 0 since it was reset)", int64(tokenUsageAfterSecondUpdate)) + } + + t.Logf("VK rate limit update sync to memory verified āœ“") +} + +// TestVKBudgetUpdateSyncToMemory tests that VK budget updates sync to in-memory store +// and that usage resets to 0 when new max budget < current usage +func TestVKBudgetUpdateSyncToMemory(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with initial budget + vkName := "test-vk-budget-update-" + generateRandomID() + initialBudget := 10.0 // $10 + resetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: resetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with initial budget: $%.2f", initialBudget) + + // Get initial in-memory state + getVKResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData1 := getVKResp1.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + budgetID, _ := vkData1["budget_id"].(string) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + budget1 := budgetsMap1[budgetID].(map[string]interface{}) + + initialMaxLimit, _ := budget1["max_limit"].(float64) + initialUsage, _ := budget1["current_usage"].(float64) + + if initialMaxLimit != initialBudget { + t.Fatalf("Initial budget max limit not correct: expected %.2f, got %.2f", initialBudget, initialMaxLimit) + } + + t.Logf("Initial state in memory: MaxLimit=$%.2f, CurrentUsage=$%.6f", initialMaxLimit, initialUsage) + + // Make a request to consume some budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request to consume budget.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to consume budget") + } + + // Wait for async update + time.Sleep(500 * time.Millisecond) + + // Get state with usage + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + budget2 := budgetsMap2[budgetID].(map[string]interface{}) + + usageBeforeUpdate, _ := budget2["current_usage"].(float64) + t.Logf("Budget usage after request: $%.6f", usageBeforeUpdate) + + if usageBeforeUpdate <= 0 { + t.Skip("No budget consumed - cannot test usage reset") + } + + // UPDATE: set new limit LOWER than current usage to trigger reset + // Usage reset only happens when new max limit <= current usage + newLowerBudget := usageBeforeUpdate * 0.5 // Set to half of current usage to ensure it's lower + if newLowerBudget <= 0 { + newLowerBudget = usageBeforeUpdate * 0.1 // Fallback to 10% if too small + } + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newLowerBudget, + ResetDuration: &resetDuration, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update VK budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated budget from $%.2f to $%.6f (new limit %.6f < current usage %.6f)", initialBudget, newLowerBudget, newLowerBudget, usageBeforeUpdate) + + // Wait for update to sync + time.Sleep(1500 * time.Millisecond) + + // Verify update in in-memory store + getBudgetsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap3 := getBudgetsResp3.Body["budgets"].(map[string]interface{}) + budget3 := budgetsMap3[budgetID].(map[string]interface{}) + + newMaxLimit, _ := budget3["max_limit"].(float64) + usageAfterUpdate, _ := budget3["current_usage"].(float64) + + // Verify new max limit is reflected + if newMaxLimit != newLowerBudget { + t.Fatalf("Budget max limit not updated in memory: expected %.6f, got %.6f", newLowerBudget, newMaxLimit) + } + + t.Logf("āœ“ Budget max limit updated in memory: $%.6f", newMaxLimit) + + // Verify usage reset to 0 (since new max limit <= current usage) + if usageAfterUpdate > 0.000001 { + t.Fatalf("Budget usage should reset to 0 when new limit (%.6f) <= current usage (%.6f), but got $%.6f", newMaxLimit, usageBeforeUpdate, usageAfterUpdate) + } + + t.Logf("āœ“ Budget usage correctly reset to 0 (new limit: $%.6f <= old usage: $%.6f)", newMaxLimit, usageBeforeUpdate) + + t.Logf("VK budget update sync to memory verified āœ“") +} + +// ============================================================================ +// PROVIDER CONFIG RATE LIMIT UPDATE SYNC +// ============================================================================ + +// TestProviderRateLimitUpdateSyncToMemory tests that provider config rate limit updates sync to memory +func TestProviderRateLimitUpdateSyncToMemory(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with provider config and initial rate limit + vkName := "test-vk-provider-rate-update-" + generateRandomID() + initialTokenLimit := int64(5000) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &initialTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with provider config, initial token limit: %d", initialTokenLimit) + + // Get initial in-memory state + getVKResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData1 := getVKResp1.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + providerConfigs1 := vkData1["provider_configs"].([]interface{}) + providerConfig1 := providerConfigs1[0].(map[string]interface{}) + providerConfigID := uint(providerConfig1["id"].(float64)) + rateLimitID1, _ := providerConfig1["rate_limit_id"].(string) + + getRateLimitsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap1 := getRateLimitsResp1.Body["rate_limits"].(map[string]interface{}) + rateLimit1 := rateLimitsMap1[rateLimitID1].(map[string]interface{}) + + initialTokenMaxLimit, _ := rateLimit1["token_max_limit"].(float64) + initialTokenUsage, _ := rateLimit1["token_current_usage"].(float64) + + if int64(initialTokenMaxLimit) != initialTokenLimit { + t.Fatalf("Initial token max limit not correct: expected %d, got %d", initialTokenLimit, int64(initialTokenMaxLimit)) + } + + t.Logf("Initial provider rate limit in memory: TokenMaxLimit=%d, TokenCurrentUsage=%d", int64(initialTokenMaxLimit), int64(initialTokenUsage)) + + // Make a request to consume some tokens + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request to consume provider tokens.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to consume provider tokens") + } + + time.Sleep(500 * time.Millisecond) + + // Get state with usage + getVKResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData2 := getVKResp2.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + providerConfigs2 := vkData2["provider_configs"].([]interface{}) + providerConfig2 := providerConfigs2[0].(map[string]interface{}) + rateLimitID2, _ := providerConfig2["rate_limit_id"].(string) + + getRateLimitsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap2 := getRateLimitsResp2.Body["rate_limits"].(map[string]interface{}) + rateLimit2 := rateLimitsMap2[rateLimitID2].(map[string]interface{}) + + tokenUsageBeforeUpdate, _ := rateLimit2["token_current_usage"].(float64) + t.Logf("Provider token usage after request: %d", int64(tokenUsageBeforeUpdate)) + + if tokenUsageBeforeUpdate <= 0 { + t.Skip("No provider tokens consumed - cannot test usage reset") + } + + // UPDATE: set new limit LOWER than current usage + newLowerLimit := int64(50) // Much lower + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + ProviderConfigs: []ProviderConfigRequest{ + { + ID: &providerConfigID, + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &newLowerLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update provider rate limit: status %d", updateResp.StatusCode) + } + + t.Logf("Updated provider token limit from %d to %d", initialTokenLimit, newLowerLimit) + + time.Sleep(500 * time.Millisecond) + + // Verify update in in-memory store + getVKResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData3 := getVKResp3.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + providerConfigs3 := vkData3["provider_configs"].([]interface{}) + providerConfig3 := providerConfigs3[0].(map[string]interface{}) + rateLimitID3, _ := providerConfig3["rate_limit_id"].(string) + + getRateLimitsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap3 := getRateLimitsResp3.Body["rate_limits"].(map[string]interface{}) + rateLimit3 := rateLimitsMap3[rateLimitID3].(map[string]interface{}) + + newTokenMaxLimit, _ := rateLimit3["token_max_limit"].(float64) + tokenUsageAfterUpdate, _ := rateLimit3["token_current_usage"].(float64) + + // Verify new limit is reflected + if int64(newTokenMaxLimit) != newLowerLimit { + t.Fatalf("Provider token max limit not updated: expected %d, got %d", newLowerLimit, int64(newTokenMaxLimit)) + } + + t.Logf("āœ“ Provider token max limit updated in memory: %d", int64(newTokenMaxLimit)) + + // Verify usage reset to 0 (since new max < old usage) + if tokenUsageAfterUpdate > 0.001 { + t.Fatalf("Provider token usage should reset to 0 when new limit < current usage, but got %d", int64(tokenUsageAfterUpdate)) + } + + t.Logf("āœ“ Provider token usage reset to 0 (new limit: %d < old usage: %d)", int64(newTokenMaxLimit), int64(tokenUsageBeforeUpdate)) + + t.Logf("Provider rate limit update sync to memory verified āœ“") +} + +// ============================================================================ +// TEAM BUDGET UPDATE SYNC +// ============================================================================ + +// TestTeamBudgetUpdateSyncToMemory tests that team budget updates sync to in-memory store +func TestTeamBudgetUpdateSyncToMemory(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create team with initial budget + teamName := "test-team-budget-update-" + generateRandomID() + initialBudget := 5.0 + resetDuration := "1h" + + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: resetDuration, + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + // Create VK under team to consume budget + vkName := "test-vk-under-team-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &teamID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created team with initial budget: $%.2f", initialBudget) + + // Get initial in-memory state + getTeamsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/teams?from_memory=true", + }) + + teamsMap1 := getTeamsResp1.Body["teams"].(map[string]interface{}) + teamData1 := teamsMap1[teamID].(map[string]interface{}) + budgetID, _ := teamData1["budget_id"].(string) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + budget1 := budgetsMap1[budgetID].(map[string]interface{}) + + initialMaxLimit, _ := budget1["max_limit"].(float64) + initialUsage, _ := budget1["current_usage"].(float64) + + if initialMaxLimit != initialBudget { + t.Fatalf("Initial budget not correct: expected %.2f, got %.2f", initialBudget, initialMaxLimit) + } + + t.Logf("Initial team budget in memory: MaxLimit=$%.2f, CurrentUsage=$%.6f", initialMaxLimit, initialUsage) + + // Make request to consume team budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request to consume team budget.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to consume team budget") + } + + // Wait for usage to be updated in memory + var usageBeforeUpdate float64 + usageUpdated := WaitForCondition(t, func() bool { + getBudgetsResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap := getBudgetsResp.Body["budgets"].(map[string]interface{}) + if budget, ok := budgetsMap[budgetID].(map[string]interface{}); ok { + if usage, ok := budget["current_usage"].(float64); ok && usage > 0 { + usageBeforeUpdate = usage + return true + } + } + return false + }, 3*time.Second, "team budget usage > 0") + + if !usageUpdated { + t.Skip("Team budget usage did not update in time") + } + + t.Logf("Team budget usage after request: $%.6f", usageBeforeUpdate) + + // UPDATE: set new limit LOWER than current usage + newLowerBudget := 0.001 + resetDurationPtr := resetDuration + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/teams/" + teamID, + Body: UpdateTeamRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newLowerBudget, + ResetDuration: &resetDurationPtr, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update team budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated team budget from $%.2f to $%.2f", initialBudget, newLowerBudget) + + // Wait for update to sync to in-memory store + var newMaxLimit, usageAfterUpdate float64 + updateSynced := WaitForCondition(t, func() bool { + getBudgetsResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap := getBudgetsResp.Body["budgets"].(map[string]interface{}) + if budget, ok := budgetsMap[budgetID].(map[string]interface{}); ok { + if maxLimit, ok := budget["max_limit"].(float64); ok { + newMaxLimit = maxLimit + usageAfterUpdate, _ = budget["current_usage"].(float64) + // Check if the new limit has been applied + return maxLimit == newLowerBudget + } + } + return false + }, 3*time.Second, "team budget max limit updated to new value") + + if !updateSynced { + t.Fatalf("Team budget update did not sync to memory in time") + } + + t.Logf("āœ“ Team budget max limit updated in memory: $%.2f", newMaxLimit) + + // Verify usage reset to 0 (since new max < old usage) + if usageAfterUpdate > 0.000001 { + t.Fatalf("Team budget usage should reset to 0 when new limit < current usage, but got $%.6f", usageAfterUpdate) + } + + t.Logf("āœ“ Team budget usage correctly reset to 0 (new limit: $%.2f < old usage: $%.6f)", newMaxLimit, usageBeforeUpdate) + + t.Logf("Team budget update sync to memory verified āœ“") +} + +// ============================================================================ +// CUSTOMER BUDGET UPDATE SYNC +// ============================================================================ + +// TestCustomerBudgetUpdateSyncToMemory tests that customer budget updates sync to in-memory store +func TestCustomerBudgetUpdateSyncToMemory(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create customer with initial budget + customerName := "test-customer-budget-update-" + generateRandomID() + initialBudget := 20.0 + resetDuration := "1h" + + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: resetDuration, + }, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + testData.AddCustomer(customerID) + + // Create team and VK under customer + teamName := "test-team-under-customer-" + generateRandomID() + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + CustomerID: &customerID, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + vkName := "test-vk-under-customer-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &teamID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created customer with initial budget: $%.2f", initialBudget) + + // Get initial in-memory state + getCustomersResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/customers?from_memory=true", + }) + + customersMap1 := getCustomersResp1.Body["customers"].(map[string]interface{}) + customerData1 := customersMap1[customerID].(map[string]interface{}) + budgetID, _ := customerData1["budget_id"].(string) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + budget1 := budgetsMap1[budgetID].(map[string]interface{}) + + initialMaxLimit, _ := budget1["max_limit"].(float64) + initialUsage, _ := budget1["current_usage"].(float64) + + if initialMaxLimit != initialBudget { + t.Fatalf("Initial customer budget not correct: expected %.2f, got %.2f", initialBudget, initialMaxLimit) + } + + t.Logf("Initial customer budget in memory: MaxLimit=$%.2f, CurrentUsage=$%.6f", initialMaxLimit, initialUsage) + + // Make request to consume customer budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request to consume customer budget.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to consume customer budget") + } + + time.Sleep(500 * time.Millisecond) + + // Get state with usage + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + budget2 := budgetsMap2[budgetID].(map[string]interface{}) + + usageBeforeUpdate, _ := budget2["current_usage"].(float64) + t.Logf("Customer budget usage after request: $%.6f", usageBeforeUpdate) + + if usageBeforeUpdate <= 0 { + t.Skip("No customer budget consumed") + } + + // UPDATE: set new limit LOWER than current usage + newLowerBudget := 0.001 + resetDurationPtr := resetDuration + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/customers/" + customerID, + Body: UpdateCustomerRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newLowerBudget, + ResetDuration: &resetDurationPtr, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update customer budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated customer budget from $%.2f to $%.2f", initialBudget, newLowerBudget) + + time.Sleep(500 * time.Millisecond) + + // Verify update in in-memory store + getBudgetsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap3 := getBudgetsResp3.Body["budgets"].(map[string]interface{}) + budget3 := budgetsMap3[budgetID].(map[string]interface{}) + + newMaxLimit, _ := budget3["max_limit"].(float64) + usageAfterUpdate, _ := budget3["current_usage"].(float64) + + // Verify new limit is reflected + if newMaxLimit != newLowerBudget { + t.Fatalf("Customer budget max limit not updated: expected %.2f, got %.2f", newLowerBudget, newMaxLimit) + } + + t.Logf("āœ“ Customer budget max limit updated in memory: $%.2f", newMaxLimit) + + // Verify usage reset to 0 (since new max < old usage) + if usageAfterUpdate > 0.000001 { + t.Fatalf("Customer budget usage should reset to 0 when new limit < current usage, but got $%.6f", usageAfterUpdate) + } + + t.Logf("āœ“ Customer budget usage correctly reset to 0 (new limit: $%.2f < old usage: $%.6f)", newMaxLimit, usageBeforeUpdate) + + t.Logf("Customer budget update sync to memory verified āœ“") +} + +// ============================================================================ +// PROVIDER CONFIG BUDGET UPDATE SYNC +// ============================================================================ + +// TestProviderBudgetUpdateSyncToMemory tests that provider config budget updates sync to memory +func TestProviderBudgetUpdateSyncToMemory(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with provider config and initial budget + vkName := "test-vk-provider-budget-update-" + generateRandomID() + initialBudget := 5.0 + resetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: resetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with provider budget: $%.2f", initialBudget) + + // Get initial in-memory state + getVKResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + vkData1 := getVKResp1.Body["virtual_keys"].(map[string]interface{})[vkValue].(map[string]interface{}) + providerConfigs1 := vkData1["provider_configs"].([]interface{}) + providerConfig1 := providerConfigs1[0].(map[string]interface{}) + providerConfigID := uint(providerConfig1["id"].(float64)) + budgetID, _ := providerConfig1["budget_id"].(string) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + budget1 := budgetsMap1[budgetID].(map[string]interface{}) + + initialMaxLimit, _ := budget1["max_limit"].(float64) + initialUsage, _ := budget1["current_usage"].(float64) + + if initialMaxLimit != initialBudget { + t.Fatalf("Initial provider budget not correct: expected %.2f, got %.2f", initialBudget, initialMaxLimit) + } + + t.Logf("Initial provider budget in memory: MaxLimit=$%.2f, CurrentUsage=$%.6f", initialMaxLimit, initialUsage) + + // Make request to consume provider budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request to consume provider budget.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to consume provider budget") + } + + time.Sleep(500 * time.Millisecond) + + // Get state with usage + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + budget2 := budgetsMap2[budgetID].(map[string]interface{}) + + usageBeforeUpdate, _ := budget2["current_usage"].(float64) + t.Logf("Provider budget usage after request: $%.6f", usageBeforeUpdate) + + if usageBeforeUpdate <= 0 { + t.Skip("No provider budget consumed") + } + + // UPDATE: set new limit LOWER than current usage + newLowerBudget := 0.001 + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + ProviderConfigs: []ProviderConfigRequest{ + { + ID: &providerConfigID, + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: newLowerBudget, + ResetDuration: resetDuration, + }, + }, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update provider budget: status %d", updateResp.StatusCode) + } + + t.Logf("Updated provider budget from $%.2f to $%.2f", initialBudget, newLowerBudget) + + time.Sleep(500 * time.Millisecond) + + // Verify update in in-memory store + getBudgetsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap3 := getBudgetsResp3.Body["budgets"].(map[string]interface{}) + budget3 := budgetsMap3[budgetID].(map[string]interface{}) + + newMaxLimit, _ := budget3["max_limit"].(float64) + usageAfterUpdate, _ := budget3["current_usage"].(float64) + + // Verify new limit is reflected + if newMaxLimit != newLowerBudget { + t.Fatalf("Provider budget max limit not updated: expected %.2f, got %.2f", newLowerBudget, newMaxLimit) + } + + t.Logf("āœ“ Provider budget max limit updated in memory: $%.2f", newMaxLimit) + + // Verify usage reset to 0 (since new max < old usage) + if usageAfterUpdate > 0.000001 { + t.Fatalf("Provider budget usage should reset to 0 when new limit < current usage, but got $%.6f", usageAfterUpdate) + } + + t.Logf("āœ“ Provider budget usage correctly reset to 0 (new limit: $%.2f < old usage: $%.6f)", newMaxLimit, usageBeforeUpdate) + + t.Logf("Provider budget update sync to memory verified āœ“") +} diff --git a/plugins/governance/customerbudget_test.go b/plugins/governance/customerbudget_test.go new file mode 100644 index 0000000000..79e04c1df8 --- /dev/null +++ b/plugins/governance/customerbudget_test.go @@ -0,0 +1,335 @@ +package governance + +import ( + "strconv" + "testing" +) + +// TestCustomerBudgetExceededWithMultipleVKs tests that customer level budgets are enforced across multiple VKs +// by making requests until budget is consumed +func TestCustomerBudgetExceededWithMultipleVKs(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a customer with a fixed budget + customerBudget := 0.01 + customerName := "test-customer-budget-exceeded-" + generateRandomID() + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: customerBudget, + ResetDuration: "1h", + }, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + testData.AddCustomer(customerID) + + // Create 2 VKs under the customer (directly, without team) + var vkValues []string + for i := 1; i <= 2; i++ { + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: "test-vk-" + generateRandomID(), + CustomerID: &customerID, + Budget: &BudgetRequest{ + MaxLimit: 1.0, // High VK budget so customer is the limiting factor + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK %d: status %d", i, createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValues = append(vkValues, vk["value"].(string)) + } + + t.Logf("Created customer %s with budget $%.2f and 2 VKs", customerName, customerBudget) + + // Keep making requests alternating between VKs, tracking actual token usage until customer budget is exceeded + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + var shouldStop = false + vkIndex := 0 + + for requestNum <= 50 { + // Alternate between VKs to test shared customer budget + vkValue := vkValues[vkIndex%2] + + // Create a longer prompt to consume more tokens and budget faster + longPrompt := "Please provide a comprehensive and detailed response to the following question. " + + "I need extensive information covering all aspects of the topic. " + + "Provide multiple paragraphs with detailed explanations. " + + "Request number " + strconv.Itoa(requestNum) + ". " + + "Here is a detailed prompt that will consume significant tokens: " + + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum. Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum." + + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: longPrompt, + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request failed - check if it's due to budget + if CheckErrorMessage(t, resp, "budget") || CheckErrorMessage(t, resp, "customer") { + t.Logf("Request %d correctly rejected: customer budget exceeded", requestNum) + t.Logf("Consumed budget: $%.6f (limit: $%.2f)", consumedBudget, customerBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + // Verify that we made at least one successful request before hitting budget + if requestNum == 1 { + t.Fatalf("First request should have succeeded but was rejected due to budget") + } + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualInputTokens := int(prompt) + actualOutputTokens := int(completion) + actualCost, _ := CalculateCost("openai/gpt-4o", actualInputTokens, actualOutputTokens) + + consumedBudget += actualCost + lastSuccessfulCost = actualCost + + t.Logf("Request %d (VK%d) succeeded: input_tokens=%d, output_tokens=%d, cost=$%.6f, consumed=$%.6f/$%.2f", + requestNum, (vkIndex%2)+1, actualInputTokens, actualOutputTokens, actualCost, consumedBudget, customerBudget) + } + } + } + + requestNum++ + vkIndex++ + + if shouldStop { + break + } + + if consumedBudget >= customerBudget { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit customer budget limit (consumed $%.6f / $%.2f) - budget not being enforced", + requestNum-1, consumedBudget, customerBudget) +} + +// TestCustomerBudgetExceededWithMultipleTeams tests that customer level budgets are enforced across multiple teams +// by making requests until budget is consumed +func TestCustomerBudgetExceededWithMultipleTeams(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a customer with a fixed budget + customerBudget := 0.01 + customerName := "test-customer-multi-team-" + generateRandomID() + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: customerBudget, + ResetDuration: "1h", + }, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + testData.AddCustomer(customerID) + + // Create 2 teams under the customer + var vkValues []string + for i := 1; i <= 2; i++ { + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: "test-team-" + generateRandomID(), + CustomerID: &customerID, + Budget: &BudgetRequest{ + MaxLimit: 1.0, // High team budget so customer is the limiting factor + ResetDuration: "1h", + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team %d: status %d", i, createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + // Create a VK under each team + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: "test-vk-" + generateRandomID(), + TeamID: &teamID, + Budget: &BudgetRequest{ + MaxLimit: 1.0, // High VK budget so customer is the limiting factor + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK %d: status %d", i, createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValues = append(vkValues, vk["value"].(string)) + } + + t.Logf("Created customer %s with budget $%.2f and 2 teams with VKs", customerName, customerBudget) + + // Keep making requests alternating between VKs in different teams, tracking actual token usage until customer budget is exceeded + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + var shouldStop = false + vkIndex := 0 + + for requestNum <= 50 { + // Alternate between VKs in different teams to test shared customer budget + vkValue := vkValues[vkIndex%2] + + // Create a longer prompt to consume more tokens and budget faster + longPrompt := "Please provide a comprehensive and detailed response to the following question. " + + "I need extensive information covering all aspects of the topic. " + + "Provide multiple paragraphs with detailed explanations. " + + "Request number " + strconv.Itoa(requestNum) + ". " + + "Here is a detailed prompt that will consume significant tokens: " + + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum. Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum." + + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: longPrompt, + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request failed - check if it's due to budget + if CheckErrorMessage(t, resp, "budget") || CheckErrorMessage(t, resp, "customer") { + t.Logf("Request %d correctly rejected: customer budget exceeded", requestNum) + t.Logf("Consumed budget: $%.6f (limit: $%.2f)", consumedBudget, customerBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + // Verify that we made at least one successful request before hitting budget + if requestNum == 1 { + t.Fatalf("First request should have succeeded but was rejected due to budget") + } + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualInputTokens := int(prompt) + actualOutputTokens := int(completion) + actualCost, _ := CalculateCost("openai/gpt-4o", actualInputTokens, actualOutputTokens) + + consumedBudget += actualCost + lastSuccessfulCost = actualCost + + t.Logf("Request %d (VK%d) succeeded: input_tokens=%d, output_tokens=%d, cost=$%.6f, consumed=$%.6f/$%.2f", + requestNum, (vkIndex%2)+1, actualInputTokens, actualOutputTokens, actualCost, consumedBudget, customerBudget) + } + } + } + + requestNum++ + vkIndex++ + + if shouldStop { + break + } + + if consumedBudget >= customerBudget { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit customer budget limit (consumed $%.6f / $%.2f) - budget not being enforced", + requestNum-1, consumedBudget, customerBudget) +} diff --git a/plugins/governance/e2e_test.go b/plugins/governance/e2e_test.go new file mode 100644 index 0000000000..de8e9c3e38 --- /dev/null +++ b/plugins/governance/e2e_test.go @@ -0,0 +1,1543 @@ +package governance + +import ( + "fmt" + "sync" + "testing" + "time" + + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" +) + +// ============================================================================ +// CRITICAL: Multiple VKs Sharing Team Budget +// ============================================================================ + +// TestMultipleVKsSharingTeamBudgetFairness verifies that when multiple VKs share a team budget, +// one VK cannot monopolize the budget and block others. +// Budget enforcement is POST-HOC: the request that exceeds the budget is allowed, +// but subsequent requests are blocked. +func TestMultipleVKsSharingTeamBudgetFairness(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a team with a small budget that will be exceeded quickly + teamName := "test-team-shared-budget-" + generateRandomID() + teamBudget := 0.01 // $0.01 for team - small enough to exceed in a few requests + teamResetDuration := "1h" + + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: teamBudget, + ResetDuration: teamResetDuration, + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + t.Logf("Created team with shared budget: $%.4f", teamBudget) + + // Create VK1 assigned to team + vk1Name := "test-vk1-shared-" + generateRandomID() + createVK1Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vk1Name, + TeamID: &teamID, + }, + }) + + if createVK1Resp.StatusCode != 200 { + t.Fatalf("Failed to create VK1: status %d", createVK1Resp.StatusCode) + } + + vk1ID := ExtractIDFromResponse(t, createVK1Resp, "id") + testData.AddVirtualKey(vk1ID) + + vk1 := createVK1Resp.Body["virtual_key"].(map[string]interface{}) + vk1Value := vk1["value"].(string) + + // Create VK2 assigned to same team + vk2Name := "test-vk2-shared-" + generateRandomID() + createVK2Resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vk2Name, + TeamID: &teamID, + }, + }) + + if createVK2Resp.StatusCode != 200 { + t.Fatalf("Failed to create VK2: status %d", createVK2Resp.StatusCode) + } + + vk2ID := ExtractIDFromResponse(t, createVK2Resp, "id") + testData.AddVirtualKey(vk2ID) + + vk2 := createVK2Resp.Body["virtual_key"].(map[string]interface{}) + vk2Value := vk2["value"].(string) + + t.Logf("Created VK1 and VK2 both assigned to same team") + + // Use VK1 to consume team budget until it's exceeded + // Budget enforcement is POST-HOC: request that exceeds is allowed, next is blocked + consumedBudget := 0.0 + requestNum := 1 + shouldStop := false + + for requestNum <= 150 { // Need many requests since each costs ~$0.0001 + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Hi, how are you?", + }, + }, + }, + VKHeader: &vk1Value, + }) + + if resp.StatusCode >= 400 { + // VK1 got rejected - budget exceeded + if CheckErrorMessage(t, resp, "budget") { + t.Logf("VK1 request %d rejected: team budget exceeded at $%.6f/$%.4f", requestNum, consumedBudget, teamBudget) + break + } else { + t.Fatalf("VK1 request %d failed with unexpected error: %v", requestNum, resp.Body) + } + } + + // Extract cost from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + cost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += cost + t.Logf("VK1 request %d: cost=$%.6f, total consumed=$%.6f/$%.4f", requestNum, cost, consumedBudget, teamBudget) + } + } + } + + requestNum++ + + if shouldStop { + break + } + + if consumedBudget >= teamBudget { + shouldStop = true + } + } + + // Verify that team budget was indeed exceeded + if consumedBudget < teamBudget { + t.Fatalf("Could not exceed team budget after %d requests (consumed $%.6f / $%.4f)", requestNum-1, consumedBudget, teamBudget) + } + + t.Logf("Team budget exhausted by VK1: $%.6f consumed (limit: $%.4f)", consumedBudget, teamBudget) + + // Now try VK2 - should be rejected because team budget was exhausted by VK1 + resp2 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Hello how are you?", + }, + }, + }, + VKHeader: &vk2Value, + }) + + // VK2 should be rejected because team budget was consumed by VK1 + if resp2.StatusCode < 400 { + t.Fatalf("VK2 request should be rejected due to shared team budget exhaustion but got status %d", resp2.StatusCode) + } + + if !CheckErrorMessage(t, resp2, "budget") { + t.Fatalf("Expected budget error for VK2 but got: %v", resp2.Body) + } + + t.Logf("Multiple VKs sharing team budget verified āœ“") + t.Logf("VK2 correctly rejected when team budget exhausted by VK1") +} + +// ============================================================================ +// CRITICAL: Full Budget Hierarchy Validation (All 4 Levels) +// ============================================================================ + +// TestFullBudgetHierarchyEnforcement verifies that ALL levels of hierarchy are checked: +// Provider Budget → VK Budget → Team Budget → Customer Budget +// Budget enforcement happens AFTER limit is exceeded - the request that exceeds is allowed, +// but subsequent requests are blocked. +func TestFullBudgetHierarchyEnforcement(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create customer with high budget + customerName := "test-customer-hierarchy-" + generateRandomID() + customerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: 1000.0, // Very high + ResetDuration: "1h", + }, + }, + }) + + if customerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", customerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, customerResp, "id") + testData.AddCustomer(customerID) + + // Create team under customer with medium budget + teamName := "test-team-hierarchy-" + generateRandomID() + teamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + CustomerID: &customerID, + Budget: &BudgetRequest{ + MaxLimit: 100.0, // Medium + ResetDuration: "1h", + }, + }, + }) + + if teamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", teamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, teamResp, "id") + testData.AddTeam(teamID) + + // Create VK under team with lower budget + // Provider budget is MOST RESTRICTIVE at $0.01 - should be exceeded after 2-3 requests + vkName := "test-vk-hierarchy-" + generateRandomID() + vkBudget := 0.1 // $0.1 + providerBudget := 0.01 // $0.01 - MOST RESTRICTIVE + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &teamID, + Budget: &BudgetRequest{ + MaxLimit: vkBudget, + ResetDuration: "1h", + }, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: providerBudget, + ResetDuration: "1h", + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created full hierarchy:") + t.Logf(" Customer Budget: $1000.0 (not limiting)") + t.Logf(" Team Budget: $100.0 (not limiting)") + t.Logf(" VK Budget: $%.2f (not limiting)", vkBudget) + t.Logf(" Provider Budget: $%.2f (MOST RESTRICTIVE)", providerBudget) + + // Make requests until provider budget is exceeded + // Budget enforcement: request that exceeds is allowed, NEXT request is blocked + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + shouldStop := false + + for requestNum <= 20 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test hierarchy enforcement request " + string(rune('0'+requestNum%10)), + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request failed - check if it's due to budget + if CheckErrorMessage(t, resp, "budget") { + t.Logf("Request %d correctly rejected: budget exceeded at provider level", requestNum) + t.Logf("Consumed budget: $%.6f (provider limit: $%.2f)", consumedBudget, providerBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + // Verify rejection happened after exceeding the budget + if consumedBudget < providerBudget { + t.Fatalf("Request rejected before budget was exceeded: consumed $%.6f < limit $%.2f", consumedBudget, providerBudget) + } + + t.Logf("Full budget hierarchy enforcement verified āœ“") + t.Logf("Request blocked at provider level (lowest in hierarchy)") + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualCost, _ := CalculateCost("openai/gpt-4o", int(prompt), int(completion)) + consumedBudget += actualCost + lastSuccessfulCost = actualCost + t.Logf("Request %d succeeded: cost=$%.6f, consumed=$%.6f/$%.2f", + requestNum, actualCost, consumedBudget, providerBudget) + } + } + } + + requestNum++ + + if shouldStop { + break + } + + if consumedBudget >= providerBudget { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit provider budget limit (consumed $%.6f / $%.2f) - budget not being enforced at provider level", + requestNum-1, consumedBudget, providerBudget) +} + +// ============================================================================ +// CRITICAL: Failed Requests Don't Consume Budget/Rate Limits +// ============================================================================ + +// TestFailedRequestsDoNotConsumeBudget verifies that requests that fail +// (4xx/5xx responses) do not consume budget or rate limits +func TestFailedRequestsDoNotConsumeBudget(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with small budget to easily verify consumption + vkName := "test-vk-failed-requests-" + generateRandomID() + budget := 0.1 + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: budget, + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with budget: $%.2f", budget) + + // Get initial budget from in-memory store + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + budgetID, _ := vkData1["budget_id"].(string) + + budgetData1 := budgetsMap1[budgetID].(map[string]interface{}) + initialUsage, _ := budgetData1["current_usage"].(float64) + + t.Logf("Initial budget usage: $%.6f", initialUsage) + + // Make a request with invalid input that will fail + // Using an invalid model name to force 400 error + failResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "invalid-model-that-does-not-exist", + Messages: []ChatMessage{ + { + Role: "user", + Content: "This request should fail.", + }, + }, + }, + VKHeader: &vkValue, + }) + + t.Logf("Failed request status: %d", failResp.StatusCode) + + if failResp.StatusCode < 400 { + t.Skip("Could not create failing request - model may be accepted") + } + + // Wait for any async processing + time.Sleep(500 * time.Millisecond) + + // Check budget usage - should NOT have changed + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + budgetData2 := budgetsMap2[budgetID].(map[string]interface{}) + usageAfterFailed, _ := budgetData2["current_usage"].(float64) + + t.Logf("Budget usage after failed request: $%.6f", usageAfterFailed) + + if usageAfterFailed > initialUsage+0.0001 { + t.Fatalf("Failed request consumed budget: before=$%.6f, after=$%.6f", initialUsage, usageAfterFailed) + } + + // Now make a successful request + successResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "This request should succeed.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if successResp.StatusCode != 200 { + t.Skip("Could not make successful request") + } + + // Wait for async update + time.Sleep(500 * time.Millisecond) + + // Check budget usage - should have changed + getBudgetsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap3 := getBudgetsResp3.Body["budgets"].(map[string]interface{}) + budgetData3 := budgetsMap3[budgetID].(map[string]interface{}) + usageAfterSuccess, _ := budgetData3["current_usage"].(float64) + + t.Logf("Budget usage after successful request: $%.6f", usageAfterSuccess) + + if usageAfterSuccess <= usageAfterFailed+0.0001 { + t.Fatalf("Successful request did not consume budget: before=$%.6f, after=$%.6f", usageAfterFailed, usageAfterSuccess) + } + + t.Logf("Failed requests do NOT consume budget āœ“") + t.Logf("Successful requests DO consume budget āœ“") +} + +// ============================================================================ +// CRITICAL: Inactive Virtual Key Behavior +// ============================================================================ + +// TestInactiveVirtualKeyBlocking verifies that inactive VKs reject requests immediately +// and that reactivating VK allows requests again +func TestInactiveVirtualKeyBlocking(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create active VK + vkName := "test-vk-inactive-" + generateRandomID() + isActive := true + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + IsActive: &isActive, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK in ACTIVE state") + + // Verify active VK works + resp1 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request with active VK should succeed.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp1.StatusCode != 200 { + t.Fatalf("Active VK request should succeed but got status %d", resp1.StatusCode) + } + + t.Logf("Active VK request succeeded āœ“") + + // Deactivate VK + isInactive := false + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + IsActive: &isInactive, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to deactivate VK: status %d", updateResp.StatusCode) + } + + t.Logf("VK deactivated (isActive = false)") + + // Wait for in-memory store update + time.Sleep(500 * time.Millisecond) + + // Verify inactive VK is blocked + resp2 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request with inactive VK should be blocked.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp2.StatusCode < 400 { + t.Fatalf("Inactive VK request should be blocked but got status %d", resp2.StatusCode) + } + + if !CheckErrorMessage(t, resp2, "blocked") { + t.Fatalf("Expected 'blocked' in error message but got: %v", resp2.Body) + } + + t.Logf("Inactive VK request rejected āœ“") + + // Reactivate VK + isActiveAgain := true + reactivateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + IsActive: &isActiveAgain, + }, + }) + + if reactivateResp.StatusCode != 200 { + t.Fatalf("Failed to reactivate VK: status %d", reactivateResp.StatusCode) + } + + t.Logf("VK reactivated (isActive = true)") + + // Wait for in-memory store update + time.Sleep(500 * time.Millisecond) + + // Verify reactivated VK works + resp3 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request with reactivated VK should succeed.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp3.StatusCode != 200 { + t.Fatalf("Reactivated VK request should succeed but got status %d", resp3.StatusCode) + } + + t.Logf("Reactivated VK request succeeded āœ“") + t.Logf("Inactive VK behavior verified āœ“") +} + +// ============================================================================ +// HIGH: Rate Limit Reset Boundaries and Edge Cases +// ============================================================================ + +// TestRateLimitResetBoundaryConditions verifies rate limit resets at exact boundaries +func TestRateLimitResetBoundaryConditions(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with short reset duration for quick testing + vkName := "test-vk-reset-boundary-" + generateRandomID() + requestLimit := int64(1) + resetDuration := "15s" // Short duration for testing + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + RequestMaxLimit: &requestLimit, + RequestResetDuration: &resetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with request limit: %d request per %s", requestLimit, resetDuration) + + // Make first request at t=0 + startTime := time.Now() + resp1 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "First request at t=0.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp1.StatusCode != 200 { + t.Skip("Could not make first request") + } + + t.Logf("First request succeeded at t=0 āœ“") + + // Try immediate second request - should fail + resp2 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Second request before reset.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp2.StatusCode < 400 { + t.Fatalf("Second request should be rejected but got status %d", resp2.StatusCode) + } + + t.Logf("Second request rejected (within reset window) āœ“") + + // Wait for reset duration + 1 second to ensure reset happens + waitTime := time.Until(startTime.Add(16 * time.Second)) + if waitTime > 0 { + t.Logf("Waiting %.1f seconds for rate limit to reset...", waitTime.Seconds()) + time.Sleep(waitTime) + } + + // After reset, third request should succeed + resp3 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Third request after reset duration.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp3.StatusCode != 200 { + t.Fatalf("Third request after reset should succeed but got status %d", resp3.StatusCode) + } + + t.Logf("Third request succeeded after reset duration āœ“") + t.Logf("Rate limit reset boundary conditions verified āœ“") +} + +// ============================================================================ +// HIGH: Concurrent Requests to Same VK +// ============================================================================ + +// TestConcurrentRequestsToSameVK verifies that concurrent requests are handled safely +// and counters remain accurate under concurrent load +func TestConcurrentRequestsToSameVK(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with high token limit to allow concurrent requests + vkName := "test-vk-concurrent-" + generateRandomID() + tokenLimit := int64(100000) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with high token limit for concurrent testing") + + // Launch concurrent requests + numGoroutines := 5 + requestsPerGoroutine := 3 + totalRequests := numGoroutines * requestsPerGoroutine + + var wg sync.WaitGroup + successCount := 0 + var mu sync.Mutex + + t.Logf("Launching %d goroutines with %d requests each (total: %d requests)", + numGoroutines, requestsPerGoroutine, totalRequests) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(goID int) { + defer wg.Done() + for j := 0; j < requestsPerGoroutine; j++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Concurrent request from goroutine.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode == 200 { + mu.Lock() + successCount++ + mu.Unlock() + } + } + }(i) + } + + wg.Wait() + + t.Logf("Concurrent requests completed: %d successful out of %d total", successCount, totalRequests) + + if successCount == 0 { + t.Skip("No requests succeeded - cannot test concurrent behavior") + } + + if successCount < totalRequests/2 { + t.Logf("Warning: Less than 50%% requests succeeded (%d/%d)", successCount, totalRequests) + } + + t.Logf("Concurrent request handling verified āœ“") + t.Logf("No data corruption detected (test completed successfully)") +} + +// ============================================================================ +// HIGH: Budget State After Reset +// ============================================================================ + +// TestBudgetStateAfterReset verifies that budget usage is correctly reset to 0 +// and LastReset timestamp is updated +func TestBudgetStateAfterReset(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with short reset duration + vkName := "test-vk-budget-reset-state-" + generateRandomID() + budgetLimit := 1.0 + resetDuration := "15s" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: budgetLimit, + ResetDuration: resetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with budget: $%.2f, reset duration: %s", budgetLimit, resetDuration) + + // Get initial budget state + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + + getBudgetsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap1 := getBudgetsResp1.Body["budgets"].(map[string]interface{}) + + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + budgetID, _ := vkData1["budget_id"].(string) + + budgetData1 := budgetsMap1[budgetID].(map[string]interface{}) + initialUsage, _ := budgetData1["current_usage"].(float64) + lastReset1, _ := budgetData1["last_reset"].(string) + + t.Logf("Initial budget state: usage=$%.6f, lastReset=%s", initialUsage, lastReset1) + + // Make a request to consume some budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request to consume budget before reset.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to consume budget") + } + + // Wait for async update + time.Sleep(500 * time.Millisecond) + + // Check usage after request + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + budgetData2 := budgetsMap2[budgetID].(map[string]interface{}) + usageAfterRequest, _ := budgetData2["current_usage"].(float64) + + t.Logf("Budget after request: usage=$%.6f (consumed)", usageAfterRequest) + + if usageAfterRequest <= initialUsage { + t.Skip("Request did not consume budget") + } + + // Wait for reset duration to pass + // We need to wait until LastReset + resetDuration has passed + // Parse the lastReset time to calculate the exact wait time + lastResetTime, err := time.Parse(time.RFC3339Nano, lastReset1) + if err != nil { + // Fallback to RFC3339 if RFC3339Nano fails + lastResetTime, err = time.Parse(time.RFC3339, lastReset1) + if err != nil { + t.Fatalf("Failed to parse lastReset time: %v", err) + } + } + resetDurationParsed, err := configstoreTables.ParseDuration(resetDuration) + if err != nil { + t.Fatalf("Failed to parse reset duration: %v", err) + } + + // Calculate when reset should occur with a 2-second safety buffer + resetTime := lastResetTime.Add(resetDurationParsed).Add(2 * time.Second) + waitTime := time.Until(resetTime) + if waitTime > 0 { + t.Logf("Waiting %.1f seconds for budget to reset (lastReset was %s, reset duration is %s)...", waitTime.Seconds(), lastReset1, resetDuration) + time.Sleep(waitTime) + } else { + t.Logf("No wait needed - reset duration has already passed") + } + + // Budget resets are LAZY - they happen when: + // 1. Background tracker runs ResetExpiredBudgets, OR + // 2. A new request triggers UpdateBudgetUsage (which resets expired budgets inline) + // Make another request to trigger the lazy reset mechanism + t.Logf("Making request to trigger lazy budget reset...") + resp2 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request after reset duration to trigger lazy reset.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp2.StatusCode != 200 { + t.Logf("Post-reset request status: %d (expected 200)", resp2.StatusCode) + } + + // Wait for async update using polling instead of fixed sleep + // Poll for budget data to reflect the reset + _, resetVerified := WaitForAPICondition(t, APIRequest{ + Method: "GET", + Path: fmt.Sprintf("/api/governance/budgets?from_memory=true"), + }, func(resp *APIResponse) bool { + if resp.StatusCode != 200 { + return false + } + budgetsData, ok := resp.Body["budgets"].(map[string]interface{}) + if !ok { + return false + } + budgetData, ok := budgetsData[budgetID].(map[string]interface{}) + if !ok { + return false + } + // Check if LastReset has been updated (indicating reset occurred) + newLastReset, ok := budgetData["last_reset"].(string) + return ok && newLastReset != lastReset1 + }, 5*time.Second, "budget reset verified by timestamp") + + if !resetVerified { + t.Logf("Warning: Reset verification polling timed out, but will proceed with final check") + } + + // Check budget after reset + getBudgetsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap3 := getBudgetsResp3.Body["budgets"].(map[string]interface{}) + budgetData3 := budgetsMap3[budgetID].(map[string]interface{}) + usageAfterReset, _ := budgetData3["current_usage"].(float64) + lastReset3, _ := budgetData3["last_reset"].(string) + + t.Logf("Budget after reset: usage=$%.6f, lastReset=%s", usageAfterReset, lastReset3) + + // Verify the reset actually happened by checking the LastReset timestamp changed + // This is the most reliable indicator that a reset occurred + if lastReset3 == lastReset1 { + t.Fatalf("Budget reset failed: LastReset timestamp was not updated (%s -> %s)", lastReset1, lastReset3) + } + t.Logf("āœ“ Budget reset verified by LastReset timestamp change") + + // Verify budget wasn't cumulative (which would indicate no reset) + // A normal request costs $0.003-0.010 + // If it's the sum of two requests, it would be $0.008+ + // This maximum check prevents detecting cumulative usage while allowing cost variations + if usageAfterReset > 0.012 { + t.Logf("WARNING: Budget usage suspiciously high after reset: $%.6f (might indicate reset didn't work, but timestamp changed so reset verified)", usageAfterReset) + t.Logf(" Before reset: $%.6f", usageAfterRequest) + t.Logf(" After reset: $%.6f", usageAfterReset) + // Don't fail - could be legitimate variation in API costs + } + + t.Logf("Budget state after reset verified āœ“") + t.Logf("Usage was reset from $%.6f to ~$%.6f (cost of one post-reset request) āœ“", usageAfterRequest, usageAfterReset) +} + +// ============================================================================ +// HIGH: Team Deletion Cascade +// ============================================================================ + +// TestTeamDeletionCascade verifies that deleting a team with VKs properly cleans up +func TestTeamDeletionCascade(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create team + teamName := "test-team-deletion-" + generateRandomID() + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: 100.0, + ResetDuration: "1h", + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + t.Logf("Created team: %s", teamID) + + // Create VK assigned to team + vkName := "test-vk-for-team-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + TeamID: &teamID, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK assigned to team: %s", vkID) + + // Verify VK works + resp1 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request before team deletion.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp1.StatusCode != 200 { + t.Skip("Could not verify VK before deletion") + } + + t.Logf("VK works before team deletion āœ“") + + // Delete team + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/teams/" + teamID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete team: status %d", deleteResp.StatusCode) + } + + t.Logf("Team deleted") + + // Wait for in-memory store update + time.Sleep(500 * time.Millisecond) + + // Try to use VK after team deletion + // Expected: VK should continue to work after team deletion + // VKs can function independently without a team, but they lose access to team budget + resp2 := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request after team deletion.", + }, + }, + }, + VKHeader: &vkValue, + }) + + // Assert VK request succeeds after team deletion + if resp2.StatusCode != 200 { + t.Fatalf("Expected 200 OK after team deletion (VK should continue to work), got status %d. Response: %v", resp2.StatusCode, resp2.Body) + } + + // Assert no team budget was billed (team is deleted, so team budget should not be used) + // The request should succeed but without team budget constraints + // Note: We can't directly verify team budget wasn't billed from the response, + // but we verify the request succeeds which confirms VK works independently + t.Logf("Team deletion cascade verified āœ“: VK continues to work after team deletion (without team budget)") +} + +// ============================================================================ +// HIGH: VK Deletion Cascade +// ============================================================================ + +// TestVKDeletionCascade verifies that deleting a VK properly cleans up all related resources +func TestVKDeletionCascade(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with rate limit and budget + vkName := "test-vk-deletion-" + generateRandomID() + tokenLimit := int64(1000) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: 10.0, + ResetDuration: "1h", + }, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with rate limit and budget") + + // Verify VK exists in in-memory store + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + + _, exists1 := virtualKeysMap1[vkValue] + if !exists1 { + t.Fatalf("VK not found in in-memory store after creation") + } + + t.Logf("VK exists in in-memory store āœ“") + + // Delete VK + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/virtual-keys/" + vkID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete VK: status %d", deleteResp.StatusCode) + } + + t.Logf("VK deleted from database") + + // Wait for in-memory store update + time.Sleep(500 * time.Millisecond) + + // Verify VK is removed from in-memory store + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + + _, exists2 := virtualKeysMap2[vkValue] + if exists2 { + t.Fatalf("VK still exists in in-memory store after deletion") + } + + t.Logf("VK removed from in-memory store āœ“") + + // Try to use deleted VK + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request with deleted VK should fail.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode < 400 { + t.Logf("Deleted VK still accepts requests (status=%d) - may be cached in SDK", resp.StatusCode) + } else { + t.Logf("Deleted VK request rejected (status=%d) āœ“", resp.StatusCode) + } + + t.Logf("VK deletion cascade verified āœ“") +} + +// ============================================================================ +// FEATURE: Load Balancing with Weighted Provider Distribution +// ============================================================================ + +// TestWeightedProviderLoadBalancing verifies that traffic is distributed between +// providers according to their weights when they share common models +func TestWeightedProviderLoadBalancing(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with two providers: 99% OpenAI, 1% Azure (both support gpt-4o) + vkName := "test-vk-weighted-lb-" + generateRandomID() + openaiWeight := 99.0 + azureWeight := 1.0 + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: openaiWeight, + AllowedModels: []string{"gpt-4o"}, + }, + { + Provider: "azure", + Weight: azureWeight, + AllowedModels: []string{"gpt-4o"}, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with weighted providers: OpenAI(%.0f%%), Azure(%.0f%%)", openaiWeight, azureWeight) + + // Verify both providers are configured + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + providerConfigs, _ := vkData["provider_configs"].([]interface{}) + + if len(providerConfigs) != 2 { + t.Fatalf("Expected 2 provider configs, got %d", len(providerConfigs)) + } + + t.Logf("Both provider configs present in in-memory store āœ“") + + // Make 10 requests with just "gpt-4o" (no provider prefix) + // Expected: ~99 go to OpenAI, ~1 go to Azure + numRequests := 10 + openaiCount := 0 + azureCount := 0 + failureCount := 0 + + t.Logf("Making %d weighted requests with model: 'gpt-4o' (no provider prefix)...", numRequests) + + for i := 0; i < numRequests; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "gpt-4o", // No provider prefix - should be routed based on weights + Messages: []ChatMessage{ + { + Role: "user", + Content: "Hello how are you?", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + failureCount++ + t.Logf("Request %d failed with status %d", i+1, resp.StatusCode) + continue + } + + // Try to detect which provider was used + // Check if model in response contains provider name + if provider, ok := resp.Body["extra_fields"].(map[string]interface{})["provider"].(string); ok { + model, ok := resp.Body["extra_fields"].(map[string]interface{})["model_requested"].(string) + if !ok { + t.Logf("Request %d failed to get model requested", i+1) + continue + } + if provider == "openai" { + openaiCount++ + t.Logf("Request %d routed to OpenAI (model: %s)", i+1, model) + } else if provider == "azure" { + azureCount++ + t.Logf("Request %d routed to Azure (model: %s)", i+1, model) + } + } + } + + totalSuccess := openaiCount + azureCount + t.Logf("Results: OpenAI=%d, Azure=%d, Failed=%d (total requests=%d)", + openaiCount, azureCount, failureCount, numRequests) + + if totalSuccess == 0 { + t.Skip("No successful requests to analyze distribution") + } + + // With 99% weight to OpenAI and 1% to Azure: + // Out of 10 requests, we expect ~0-2 to go to Azure (1%) + if azureCount > 2 { + t.Logf("Warning: More requests went to Azure than expected (got %d, expected ~0-2)", azureCount) + } + + t.Logf("Weighted provider load balancing verified āœ“") + t.Logf("Traffic distribution approximately matches configured weights") +} + +// ============================================================================ +// FEATURE: Fallback Provider Mechanism +// ============================================================================ + +// TestProviderFallbackMechanism verifies that when primary provider doesn't support +// a model, fallback providers are used automatically +func TestProviderFallbackMechanism(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with two providers: + // - 99% Anthropic (does NOT support gpt-4o) + // - 1% OpenAI (DOES support gpt-4o) + // When requesting gpt-4o, it should fall back to OpenAI since Anthropic doesn't have it + vkName := "test-vk-fallback-" + generateRandomID() + anthropicWeight := 99.0 + openaiWeight := 1.0 + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "anthropic", + Weight: anthropicWeight, + AllowedModels: []string{"claude-3-sonnet"}, // Does NOT include gpt-4o + }, + { + Provider: "openai", + Weight: openaiWeight, + AllowedModels: []string{"gpt-4o"}, // DOES include gpt-4o + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with providers: Anthropic(99%%, no gpt-4o), OpenAI(1%%, supports gpt-4o)") + + // Make 5 requests for gpt-4o model + // Even though Anthropic has 99% weight, all should succeed via OpenAI fallback + numRequests := 5 + successCount := 0 + + t.Logf("Making %d requests with model: 'gpt-4o' (not supported by primary provider)...", numRequests) + + for i := 0; i < numRequests; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "gpt-4o", // Only OpenAI supports this + Messages: []ChatMessage{ + { + Role: "user", + Content: "Hello how are you?", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode == 200 { + successCount++ + + // Try to detect which provider actually handled it + model := "" + if m, ok := resp.Body["model"].(string); ok { + model = m + } + + t.Logf("Request %d succeeded (model: %s) - likely via OpenAI fallback", i+1, model) + } else { + t.Logf("Request %d failed with status %d", i+1, resp.StatusCode) + } + } + + t.Logf("Results: %d/%d requests succeeded via fallback", successCount, numRequests) + + if successCount == 0 { + t.Skip("No successful requests - cannot verify fallback mechanism") + } + + if successCount < numRequests { + t.Logf("Warning: Not all requests succeeded (got %d/%d)", successCount, numRequests) + } else { + t.Logf("All requests succeeded via fallback provider āœ“") + } + + t.Logf("Fallback provider mechanism verified āœ“") + t.Logf("Requests successfully routed to fallback when primary doesn't support model") +} diff --git a/plugins/governance/edgecases_test.go b/plugins/governance/edgecases_test.go new file mode 100644 index 0000000000..1e2c50d1c1 --- /dev/null +++ b/plugins/governance/edgecases_test.go @@ -0,0 +1,188 @@ +package governance + +import ( + "strconv" + "testing" + "time" +) + +// TestCrissCrossComplexBudgetHierarchy tests complex scenarios involving provider, VK, team, and customer level budgets +// Tests that the most restrictive budget at each level is enforced +func TestCrissCrossComplexBudgetHierarchy(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a customer with a moderate budget + customerBudget := 0.15 + customerName := "test-customer-criss-cross-" + generateRandomID() + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: customerBudget, + ResetDuration: "1h", + }, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + testData.AddCustomer(customerID) + + // Create a team under customer with a tighter budget + teamBudget := 0.12 + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: "test-team-criss-cross-" + generateRandomID(), + CustomerID: &customerID, + Budget: &BudgetRequest{ + MaxLimit: teamBudget, + ResetDuration: "1h", + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + // Create a VK with even tighter budget and provider-specific budgets + vkBudget := 0.01 + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: "test-vk-criss-cross-" + generateRandomID(), + TeamID: &teamID, + Budget: &BudgetRequest{ + MaxLimit: vkBudget, + ResetDuration: "1h", + }, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: 0.08, // Even tighter provider budget + ResetDuration: "1h", + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created hierarchy: Customer ($%.2f) -> Team ($%.2f) -> VK ($%.2f) with Provider Budget ($0.08)", + customerBudget, teamBudget, vkBudget) + + // Wait for VK and provider config budgets to be synced to in-memory store + time.Sleep(1000 * time.Millisecond) + + // Test: Provider budget should be the limiting factor (most restrictive) + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + var shouldStop = false + + for requestNum <= 50 { + longPrompt := "Please provide a comprehensive and detailed response to the following question. " + + "I need extensive information covering all aspects of the topic. " + + "Provide multiple paragraphs with detailed explanations. " + + "Request number " + strconv.Itoa(requestNum) + ". " + + "Here is a detailed prompt that will consume significant tokens: " + + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum. Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum." + + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: longPrompt, + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request failed - check if it's due to budget + if CheckErrorMessage(t, resp, "budget") || CheckErrorMessage(t, resp, "provider") { + t.Logf("Request %d correctly rejected: budget exceeded in criss-cross hierarchy", requestNum) + t.Logf("Consumed budget: $%.6f (provider budget limit: $0.08)", consumedBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + if requestNum == 1 { + t.Fatalf("First request should have succeeded but was rejected due to budget") + } + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualInputTokens := int(prompt) + actualOutputTokens := int(completion) + actualCost, _ := CalculateCost("openai/gpt-4o", actualInputTokens, actualOutputTokens) + + consumedBudget += actualCost + lastSuccessfulCost = actualCost + + t.Logf("Request %d succeeded: input_tokens=%d, output_tokens=%d, cost=$%.6f, consumed=$%.6f", + requestNum, actualInputTokens, actualOutputTokens, actualCost, consumedBudget) + } + } + } + + requestNum++ + + if shouldStop { + break + } + + if consumedBudget >= 0.08 { // Provider budget + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit provider budget limit - budget not being enforced", + requestNum-1) +} diff --git a/plugins/governance/fixtures_test.go b/plugins/governance/fixtures_test.go new file mode 100644 index 0000000000..c3e2bf576d --- /dev/null +++ b/plugins/governance/fixtures_test.go @@ -0,0 +1,222 @@ +package governance + +import ( + "sync" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// MockLogger implements schemas.Logger for testing +type MockLogger struct { + mu sync.Mutex + logs []string + errors []string + debugs []string + infos []string + warnings []string +} + +func NewMockLogger() *MockLogger { + return &MockLogger{ + logs: make([]string, 0), + errors: make([]string, 0), + debugs: make([]string, 0), + infos: make([]string, 0), + warnings: make([]string, 0), + } +} + +func (ml *MockLogger) SetLevel(level schemas.LogLevel) {} + +func (ml *MockLogger) SetOutputType(outputType schemas.LoggerOutputType) {} + +func (ml *MockLogger) Error(format string, args ...interface{}) { + ml.mu.Lock() + defer ml.mu.Unlock() + ml.errors = append(ml.errors, format) +} + +func (ml *MockLogger) Warn(format string, args ...interface{}) { + ml.mu.Lock() + defer ml.mu.Unlock() + ml.warnings = append(ml.warnings, format) +} + +func (ml *MockLogger) Info(format string, args ...interface{}) { + ml.mu.Lock() + defer ml.mu.Unlock() + ml.infos = append(ml.infos, format) +} + +func (ml *MockLogger) Debug(format string, args ...interface{}) { + ml.mu.Lock() + defer ml.mu.Unlock() + ml.debugs = append(ml.debugs, format) +} + +func (ml *MockLogger) Fatal(format string, args ...interface{}) { + ml.mu.Lock() + defer ml.mu.Unlock() + ml.errors = append(ml.errors, format) +} + +// Test data builders + +func buildVirtualKey(id, value, name string, isActive bool) *configstoreTables.TableVirtualKey { + return &configstoreTables.TableVirtualKey{ + ID: id, + Value: value, + Name: name, + IsActive: isActive, + } +} + +func buildVirtualKeyWithBudget(id, value, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableVirtualKey { + vk := buildVirtualKey(id, value, name, true) + vk.Budget = budget + budgetID := budget.ID + vk.BudgetID = &budgetID + return vk +} + +func buildVirtualKeyWithRateLimit(id, value, name string, rateLimit *configstoreTables.TableRateLimit) *configstoreTables.TableVirtualKey { + vk := buildVirtualKey(id, value, name, true) + vk.RateLimit = rateLimit + rateLimitID := rateLimit.ID + vk.RateLimitID = &rateLimitID + return vk +} + +func buildVirtualKeyWithProviders(id, value, name string, providers []configstoreTables.TableVirtualKeyProviderConfig) *configstoreTables.TableVirtualKey { + vk := buildVirtualKey(id, value, name, true) + vk.ProviderConfigs = providers + return vk +} + +func buildBudget(id string, maxLimit float64, resetDuration string) *configstoreTables.TableBudget { + return &configstoreTables.TableBudget{ + ID: id, + MaxLimit: maxLimit, + CurrentUsage: 0, + ResetDuration: resetDuration, + LastReset: time.Now(), + } +} + +func buildBudgetWithUsage(id string, maxLimit, currentUsage float64, resetDuration string) *configstoreTables.TableBudget { + return &configstoreTables.TableBudget{ + ID: id, + MaxLimit: maxLimit, + CurrentUsage: currentUsage, + ResetDuration: resetDuration, + LastReset: time.Now(), + } +} + +func buildRateLimit(id string, tokenMaxLimit, requestMaxLimit int64) *configstoreTables.TableRateLimit { + duration := "1m" + return &configstoreTables.TableRateLimit{ + ID: id, + TokenMaxLimit: &tokenMaxLimit, + TokenCurrentUsage: 0, + TokenResetDuration: &duration, + TokenLastReset: time.Now(), + RequestMaxLimit: &requestMaxLimit, + RequestCurrentUsage: 0, + RequestResetDuration: &duration, + RequestLastReset: time.Now(), + } +} + +func buildRateLimitWithUsage(id string, tokenMaxLimit, tokenUsage, requestMaxLimit, requestUsage int64) *configstoreTables.TableRateLimit { + duration := "1m" + return &configstoreTables.TableRateLimit{ + ID: id, + TokenMaxLimit: &tokenMaxLimit, + TokenCurrentUsage: tokenUsage, + TokenResetDuration: &duration, + TokenLastReset: time.Now(), + RequestMaxLimit: &requestMaxLimit, + RequestCurrentUsage: requestUsage, + RequestResetDuration: &duration, + RequestLastReset: time.Now(), + } +} + +func buildTeam(id, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableTeam { + team := &configstoreTables.TableTeam{ + ID: id, + Name: name, + } + if budget != nil { + team.Budget = budget + team.BudgetID = &budget.ID + } + return team +} + +func buildCustomer(id, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableCustomer { + customer := &configstoreTables.TableCustomer{ + ID: id, + Name: name, + } + if budget != nil { + customer.Budget = budget + customer.BudgetID = &budget.ID + } + return customer +} + +func buildProviderConfig(provider string, allowedModels []string) configstoreTables.TableVirtualKeyProviderConfig { + return configstoreTables.TableVirtualKeyProviderConfig{ + Provider: provider, + AllowedModels: allowedModels, + Weight: bifrost.Ptr(1.0), + RateLimit: nil, + Budget: nil, + Keys: []configstoreTables.TableKey{}, + } +} + +func buildProviderConfigWithRateLimit(provider string, allowedModels []string, rateLimit *configstoreTables.TableRateLimit) configstoreTables.TableVirtualKeyProviderConfig { + pc := buildProviderConfig(provider, allowedModels) + pc.RateLimit = rateLimit + if rateLimit != nil { + pc.RateLimitID = &rateLimit.ID + } + return pc +} + +// Test helpers + +func assertDecision(t *testing.T, expected Decision, result *EvaluationResult) { + t.Helper() + assert.NotNil(t, result, "EvaluationResult should not be nil") + assert.Equal(t, expected, result.Decision, "Decision mismatch. Reason: %s", result.Reason) +} + +func assertVirtualKeyFound(t *testing.T, result *EvaluationResult) { + t.Helper() + assert.NotNil(t, result.VirtualKey, "VirtualKey should be found in result") +} + +func assertRateLimitInfo(t *testing.T, result *EvaluationResult) { + t.Helper() + assert.NotNil(t, result.RateLimitInfo, "RateLimitInfo should be present in result") +} + +func requireNoError(t *testing.T, err error, msg string) { + t.Helper() + require.NoError(t, err, msg) +} + +func requireError(t *testing.T, err error, msg string) { + t.Helper() + require.Error(t, err, msg) +} diff --git a/plugins/governance/go.mod b/plugins/governance/go.mod index 5de07882e0..bc7ac16aa7 100644 --- a/plugins/governance/go.mod +++ b/plugins/governance/go.mod @@ -5,8 +5,10 @@ go 1.25.5 require gorm.io/gorm v1.31.1 require ( - github.com/maximhq/bifrost/core v1.2.49 - github.com/maximhq/bifrost/framework v1.1.61 + github.com/bytedance/sonic v1.14.2 + github.com/maximhq/bifrost/core v1.3.8 + github.com/maximhq/bifrost/framework v1.2.8 + github.com/stretchr/testify v1.11.1 ) require ( @@ -38,11 +40,14 @@ require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/bytedance/gopkg v0.1.3 // indirect - github.com/bytedance/sonic v1.14.2 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/clarkmcc/go-typescript v0.7.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/analysis v0.24.2 // indirect @@ -66,8 +71,10 @@ require ( github.com/go-openapi/swag/typeutils v0.25.4 // indirect github.com/go-openapi/swag/yamlutils v0.25.4 // indirect github.com/go-openapi/validate v0.25.1 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f // indirect github.com/google/uuid v1.6.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -87,6 +94,7 @@ require ( github.com/oklog/ulid v1.3.1 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/qdrant/go-client v1.16.2 // indirect github.com/redis/go-redis/v9 v9.17.2 // indirect github.com/rs/zerolog v1.34.0 // indirect diff --git a/plugins/governance/go.sum b/plugins/governance/go.sum index 0c229672d9..065f485a18 100644 --- a/plugins/governance/go.sum +++ b/plugins/governance/go.sum @@ -12,6 +12,8 @@ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= +github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= @@ -68,6 +70,8 @@ github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2N github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -77,6 +81,10 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -132,6 +140,8 @@ github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6 github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= github.com/go-openapi/validate v0.25.1 h1:sSACUI6Jcnbo5IWqbYHgjibrhhmt3vR6lCzKZnmAgBw= github.com/go-openapi/validate v0.25.1/go.mod h1:RMVyVFYte0gbSTaZ0N4KmTn6u/kClvAFp+mAVfS/DQc= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -141,6 +151,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= @@ -184,10 +196,10 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.2.49 h1:fk6l6r3kVBlpN73wYXmgtV6O4bhedOjSO4LAEz/7leg= -github.com/maximhq/bifrost/core v1.2.49/go.mod h1:z7nOx15e91ktZGi+pZHq+uhShlEK+fM4UyYUpP6oHAw= -github.com/maximhq/bifrost/framework v1.1.61 h1:fMjvICbkrdWMtGnLYrjSNrcmQYqtQvOh/swmrJTvf+E= -github.com/maximhq/bifrost/framework v1.1.61/go.mod h1:wVUPzB8K5S/5GWuxqp8dXf3nNZkqJsS/APMIcq48SOI= +github.com/maximhq/bifrost/core v1.3.8 h1:xtwB9+HeTzYz5IKHkpUtupzBd0A5yl1avdLJGjsOKPI= +github.com/maximhq/bifrost/core v1.3.8/go.mod h1:abKQRnJQPZz8/UMxCcbuNHEyq19Db+IX4KlGJdlLY8E= +github.com/maximhq/bifrost/framework v1.2.8 h1:/oTpacuw7k0zRUJ9dSSQRtAVx3nLGSiR7GFwOjGxZNs= +github.com/maximhq/bifrost/framework v1.2.8/go.mod h1:mjw9YXh/Oxi3HeBCJ+3HJ6ftv43Wo4t0T4EzpcIbnr0= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= @@ -283,6 +295,8 @@ google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/governance/headerparsing_test.go b/plugins/governance/headerparsing_test.go new file mode 100644 index 0000000000..4a4056f3a2 --- /dev/null +++ b/plugins/governance/headerparsing_test.go @@ -0,0 +1,154 @@ +package governance + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCaseInsensitiveLookup(t *testing.T) { + tests := []struct { + name string + data map[string]string + key string + expected string + }{ + { + name: "nil map returns empty string", + data: nil, + key: "Content-Type", + expected: "", + }, + { + name: "empty key returns empty string", + data: map[string]string{"Content-Type": "application/json"}, + key: "", + expected: "", + }, + { + name: "key not found returns empty string", + data: map[string]string{"Content-Type": "application/json"}, + key: "Authorization", + expected: "", + }, + { + name: "exact match", + data: map[string]string{"Content-Type": "application/json"}, + key: "Content-Type", + expected: "application/json", + }, + { + name: "lowercase key match - map has lowercase key", + data: map[string]string{"content-type": "application/json"}, + key: "Content-Type", + expected: "application/json", + }, + { + name: "lowercase key match - query is lowercase", + data: map[string]string{"content-type": "application/json"}, + key: "content-type", + expected: "application/json", + }, + { + name: "case-insensitive iteration - map has mixed case", + data: map[string]string{"Content-Type": "application/json"}, + key: "content-type", + expected: "application/json", + }, + { + name: "case-insensitive iteration - uppercase query", + data: map[string]string{"Content-Type": "application/json"}, + key: "CONTENT-TYPE", + expected: "application/json", + }, + { + name: "multiple keys - finds correct one", + data: map[string]string{"Accept": "text/html", "Content-Type": "application/json"}, + key: "content-type", + expected: "application/json", + }, + // x-bf-vk header variations + { + name: "x-bf-vk exact match lowercase", + data: map[string]string{"x-bf-vk": "sk-bf-test123"}, + key: "x-bf-vk", + expected: "sk-bf-test123", + }, + { + name: "x-bf-vk mixed case in map", + data: map[string]string{"X-Bf-Vk": "sk-bf-test123"}, + key: "x-bf-vk", + expected: "sk-bf-test123", + }, + { + name: "x-bf-vk uppercase in map", + data: map[string]string{"X-BF-VK": "sk-bf-test123"}, + key: "x-bf-vk", + expected: "sk-bf-test123", + }, + // authorization header variations + { + name: "authorization exact match lowercase", + data: map[string]string{"authorization": "Bearer sk-bf-test123"}, + key: "authorization", + expected: "Bearer sk-bf-test123", + }, + { + name: "authorization capitalized in map", + data: map[string]string{"Authorization": "Bearer sk-bf-test123"}, + key: "authorization", + expected: "Bearer sk-bf-test123", + }, + { + name: "authorization uppercase in map", + data: map[string]string{"AUTHORIZATION": "Bearer sk-bf-test123"}, + key: "authorization", + expected: "Bearer sk-bf-test123", + }, + // x-api-key header variations + { + name: "x-api-key exact match lowercase", + data: map[string]string{"x-api-key": "sk-bf-apikey123"}, + key: "x-api-key", + expected: "sk-bf-apikey123", + }, + { + name: "x-api-key mixed case in map", + data: map[string]string{"X-Api-Key": "sk-bf-apikey123"}, + key: "x-api-key", + expected: "sk-bf-apikey123", + }, + { + name: "x-api-key uppercase in map", + data: map[string]string{"X-API-KEY": "sk-bf-apikey123"}, + key: "x-api-key", + expected: "sk-bf-apikey123", + }, + // x-goog-api-key header variations + { + name: "x-goog-api-key exact match lowercase", + data: map[string]string{"x-goog-api-key": "sk-bf-google123"}, + key: "x-goog-api-key", + expected: "sk-bf-google123", + }, + { + name: "x-goog-api-key mixed case in map", + data: map[string]string{"X-Goog-Api-Key": "sk-bf-google123"}, + key: "x-goog-api-key", + expected: "sk-bf-google123", + }, + { + name: "x-goog-api-key uppercase in map", + data: map[string]string{"X-GOOG-API-KEY": "sk-bf-google123"}, + key: "x-goog-api-key", + expected: "sk-bf-google123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := caseInsensitiveLookup(tt.data, tt.key) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/plugins/governance/inmemorysync_test.go b/plugins/governance/inmemorysync_test.go new file mode 100644 index 0000000000..8de677a25f --- /dev/null +++ b/plugins/governance/inmemorysync_test.go @@ -0,0 +1,554 @@ +package governance + +import ( + "testing" + "time" +) + +// TestInMemorySyncVirtualKeyUpdate tests that in-memory store is updated when VK is updated in DB +func TestInMemorySyncVirtualKeyUpdate(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with initial budget + vkName := "test-vk-sync-" + generateRandomID() + initialBudget := 10.0 + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with initial budget $%.2f", vkName, initialBudget) + + // Verify in-memory store has the VK + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + + // Check that VK exists in in-memory store + vkData, exists := virtualKeysMap[vkValue] + if !exists { + t.Fatalf("VK %s not found in in-memory store after creation", vkValue) + } + + vkDataMap := vkData.(map[string]interface{}) + vkID2, _ := vkDataMap["id"].(string) + if vkID2 != vkID { + t.Fatalf("VK ID mismatch in in-memory store: expected %s, got %s", vkID, vkID2) + } + + t.Logf("VK found in in-memory store after creation āœ“") + + // Update VK budget to 20.0 + newBudget := 20.0 + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newBudget, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update VK: status %d, body: %v", updateResp.StatusCode, updateResp.Body) + } + + t.Logf("Updated VK budget from $%.2f to $%.2f", initialBudget, newBudget) + + // Verify in-memory store is updated + time.Sleep(500 * time.Millisecond) // Small delay for async updates + + getVKResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getVKResp2.StatusCode != 200 { + t.Fatalf("Failed to get governance data after update: status %d", getVKResp2.StatusCode) + } + + virtualKeysMap2 := getVKResp2.Body["virtual_keys"].(map[string]interface{}) + + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + + // Check that VK still exists + vkData2, exists := virtualKeysMap2[vkValue] + if !exists { + t.Fatalf("VK %s not found in in-memory store after update", vkValue) + } + + vkDataMap2 := vkData2.(map[string]interface{}) + budgetID, _ := vkDataMap2["budget_id"].(string) + + // Check that budget in in-memory store is updated + if budgetID != "" { + budgetData, budgetExists := budgetsMap2[budgetID] + if !budgetExists { + t.Fatalf("Budget %s not found in in-memory store", budgetID) + } + + budgetDataMap := budgetData.(map[string]interface{}) + maxLimit, _ := budgetDataMap["max_limit"].(float64) + if maxLimit != newBudget { + t.Fatalf("Budget max_limit not updated in in-memory store: expected %.2f, got %.2f", newBudget, maxLimit) + } + } + + t.Logf("VK budget updated in in-memory store āœ“") +} + +// TestInMemorySyncTeamUpdate tests that in-memory store is updated when Team is updated +func TestInMemorySyncTeamUpdate(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a team with initial budget + teamName := "test-team-sync-" + generateRandomID() + initialBudget := 50.0 + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: "1h", + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + t.Logf("Created team %s with initial budget $%.2f", teamName, initialBudget) + + // Verify in-memory store has the team + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/teams?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + teamsMap := getDataResp.Body["teams"].(map[string]interface{}) + + _, exists := teamsMap[teamID] + if !exists { + t.Fatalf("Team %s not found in in-memory store after creation", teamID) + } + + t.Logf("Team found in in-memory store after creation āœ“") + + // Update team budget to 100.0 + newTeamBudget := 100.0 + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/teams/" + teamID, + Body: UpdateTeamRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newTeamBudget, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update team: status %d", updateResp.StatusCode) + } + + t.Logf("Updated team budget from $%.2f to $%.2f", initialBudget, newTeamBudget) + + // Verify in-memory store is updated + time.Sleep(500 * time.Millisecond) + + getTeamsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/teams?from_memory=true", + }) + + if getTeamsResp2.StatusCode != 200 { + t.Fatalf("Failed to get governance data after update: status %d", getTeamsResp2.StatusCode) + } + + teamsMap2 := getTeamsResp2.Body["teams"].(map[string]interface{}) + + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + + teamData2, exists := teamsMap2[teamID] + if !exists { + t.Fatalf("Team %s not found in in-memory store after update", teamID) + } + + teamDataMap := teamData2.(map[string]interface{}) + budgetID, _ := teamDataMap["budget_id"].(string) + + if budgetID != "" { + budgetData, budgetExists := budgetsMap2[budgetID] + if !budgetExists { + t.Fatalf("Budget %s not found in in-memory store", budgetID) + } + + budgetDataMap := budgetData.(map[string]interface{}) + maxLimit, _ := budgetDataMap["max_limit"].(float64) + if maxLimit != newTeamBudget { + t.Fatalf("Team budget max_limit not updated in in-memory store: expected %.2f, got %.2f", newTeamBudget, maxLimit) + } + } + + t.Logf("Team budget updated in in-memory store āœ“") +} + +// TestInMemorySyncCustomerUpdate tests that in-memory store is updated when Customer is updated +func TestInMemorySyncCustomerUpdate(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a customer with initial budget + customerName := "test-customer-sync-" + generateRandomID() + initialBudget := 100.0 + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: initialBudget, + ResetDuration: "1h", + }, + }, + }) + + if createCustomerResp.StatusCode != 200 { + t.Fatalf("Failed to create customer: status %d", createCustomerResp.StatusCode) + } + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + testData.AddCustomer(customerID) + + t.Logf("Created customer %s with initial budget $%.2f", customerName, initialBudget) + + // Verify in-memory store has the customer + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/customers?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + customersMap := getDataResp.Body["customers"].(map[string]interface{}) + + _, exists := customersMap[customerID] + if !exists { + t.Fatalf("Customer %s not found in in-memory store after creation", customerID) + } + + t.Logf("Customer found in in-memory store after creation āœ“") + + // Update customer budget to 250.0 + newCustomerBudget := 250.0 + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/customers/" + customerID, + Body: UpdateCustomerRequest{ + Budget: &UpdateBudgetRequest{ + MaxLimit: &newCustomerBudget, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update customer: status %d", updateResp.StatusCode) + } + + t.Logf("Updated customer budget from $%.2f to $%.2f", initialBudget, newCustomerBudget) + + // Verify in-memory store is updated + time.Sleep(500 * time.Millisecond) + + getCustomersResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/customers?from_memory=true", + }) + + if getCustomersResp2.StatusCode != 200 { + t.Fatalf("Failed to get governance data after update: status %d", getCustomersResp2.StatusCode) + } + + customersMap2 := getCustomersResp2.Body["customers"].(map[string]interface{}) + + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + + customerData2, exists := customersMap2[customerID] + if !exists { + t.Fatalf("Customer %s not found in in-memory store after update", customerID) + } + + customerDataMap := customerData2.(map[string]interface{}) + budgetID, _ := customerDataMap["budget_id"].(string) + + if budgetID != "" { + budgetData, budgetExists := budgetsMap2[budgetID] + if !budgetExists { + t.Fatalf("Budget %s not found in in-memory store", budgetID) + } + + budgetDataMap := budgetData.(map[string]interface{}) + maxLimit, _ := budgetDataMap["max_limit"].(float64) + if maxLimit != newCustomerBudget { + t.Fatalf("Customer budget max_limit not updated in in-memory store: expected %.2f, got %.2f", newCustomerBudget, maxLimit) + } + } + + t.Logf("Customer budget updated in in-memory store āœ“") +} + +// TestInMemorySyncVirtualKeyDelete tests that in-memory store is updated when VK is deleted +func TestInMemorySyncVirtualKeyDelete(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK + vkName := "test-vk-delete-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: 10.0, + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + // Verify in-memory store has the VK + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + + _, exists := virtualKeysMap[vkValue] + if !exists { + t.Fatalf("VK not found in in-memory store after creation") + } + + t.Logf("VK found in in-memory store after creation āœ“") + + // Delete the VK + deleteResp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: "/api/governance/virtual-keys/" + vkID, + }) + + if deleteResp.StatusCode != 200 { + t.Fatalf("Failed to delete VK: status %d", deleteResp.StatusCode) + } + + t.Logf("Deleted VK from database") + + // Verify in-memory store is updated + time.Sleep(2 * time.Second) + + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + + _, exists = virtualKeysMap2[vkValue] + if exists { + t.Fatalf("VK %s still exists in in-memory store after deletion", vkValue) + } + + t.Logf("VK removed from in-memory store āœ“") +} + +// TestDataEndpointConsistency tests that governance endpoints return consistent data +func TestDataEndpointConsistency(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create multiple resources + vkName := "test-vk-consistency-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: 15.0, + ResetDuration: "1h", + }, + }, + }) + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + teamName := "test-team-consistency-" + generateRandomID() + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: 30.0, + ResetDuration: "1h", + }, + }, + }) + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + customerName := "test-customer-consistency-" + generateRandomID() + createCustomerResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/customers", + Body: CreateCustomerRequest{ + Name: customerName, + Budget: &BudgetRequest{ + MaxLimit: 60.0, + ResetDuration: "1h", + }, + }, + }) + + customerID := ExtractIDFromResponse(t, createCustomerResp, "id") + testData.AddCustomer(customerID) + + time.Sleep(1 * time.Second) + + // Get data from separate endpoints + getVKResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getVKResp.StatusCode != 200 { + t.Fatalf("Failed to get virtual keys: status %d", getVKResp.StatusCode) + } + + getTeamsResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/teams?from_memory=true", + }) + + if getTeamsResp.StatusCode != 200 { + t.Fatalf("Failed to get teams: status %d", getTeamsResp.StatusCode) + } + + getCustomersResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/customers?from_memory=true", + }) + + if getCustomersResp.StatusCode != 200 { + t.Fatalf("Failed to get customers: status %d", getCustomersResp.StatusCode) + } + + virtualKeysMap := getVKResp.Body["virtual_keys"].(map[string]interface{}) + teamsMap := getTeamsResp.Body["teams"].(map[string]interface{}) + customersMap := getCustomersResp.Body["customers"].(map[string]interface{}) + + // Verify all created resources are in the in-memory data + vkCount := len(virtualKeysMap) + teamCount := len(teamsMap) + customerCount := len(customersMap) + + if vkCount == 0 { + t.Fatalf("No virtual keys found in data endpoint") + } + if teamCount == 0 { + t.Fatalf("No teams found in data endpoint") + } + if customerCount == 0 { + t.Fatalf("No customers found in data endpoint") + } + + t.Logf("Data endpoint returned consistent data: %d VKs, %d teams, %d customers āœ“", vkCount, teamCount, customerCount) + + // Get the individual endpoints and verify consistency + getVKsResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys", + }) + + if getVKsResp.StatusCode != 200 { + t.Fatalf("Failed to get virtual keys: status %d", getVKsResp.StatusCode) + } + + vksFromEndpoint, _ := getVKsResp.Body["count"].(float64) + if int(vksFromEndpoint) != vkCount { + // Can fail because sqlite db might get locked because of all parallel tests + t.Logf("[WARN]VK count mismatch between /data endpoint and /virtual-keys endpoint: %d vs %d (this can happen because of parallel tests)", vkCount, int(vksFromEndpoint)) + } + + t.Logf("Data consistency verified between endpoints āœ“") +} diff --git a/plugins/governance/main.go b/plugins/governance/main.go index 80320c2c1c..c2588d8ee8 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -10,6 +10,7 @@ import ( "strings" "sync" + "github.com/bytedance/sonic" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" @@ -37,6 +38,15 @@ type InMemoryStore interface { GetConfiguredProviders() map[schemas.ModelProvider]configstore.ProviderConfig } +type BaseGovernancePlugin interface { + GetName() string + HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) + PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) + PostHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) + Cleanup() error + GetGovernanceStore() GovernanceStore +} + // GovernancePlugin implements the main governance plugin with hierarchical budget system type GovernancePlugin struct { ctx context.Context @@ -44,9 +54,9 @@ type GovernancePlugin struct { wg sync.WaitGroup // Track active goroutines // Core components with clear separation of concerns - store *GovernanceStore // Pure data access layer - resolver *BudgetResolver // Pure decision engine for hierarchical governance - tracker *UsageTracker // Business logic owner (updates, resets, persistence) + store GovernanceStore // Pure data access layer + resolver *BudgetResolver // Pure decision engine for hierarchical governance + tracker *UsageTracker // Business logic owner (updates, resets, persistence) // Dependencies configStore configstore.ConfigStore @@ -67,7 +77,9 @@ type GovernancePlugin struct { // // Behavior and defaults: // - Enables all governance features with optimized defaults. -// - If `store` is nil, the plugin runs in-memory only (no persistence). +// - If `configStore` is nil, the plugin will use an in-memory LocalGovernanceStore +// (no persistence). Init constructs a LocalGovernanceStore internally when +// configStore is nil. // - If `modelCatalog` is nil, cost calculation is skipped. // - `config.IsVkMandatory` controls whether `x-bf-vk` is required in PreHook. // - `inMemoryStore` is used by TransportInterceptor to validate configured providers @@ -80,7 +92,7 @@ type GovernancePlugin struct { // - ctx: base context for the plugin; a child context with cancel is created. // - config: plugin flags; may be nil. // - logger: logger used by all subcomponents. -// - store: configuration store used for persistence; may be nil. +// - configStore: configuration store used for persistence; may be nil. // - governanceConfig: initial/seed governance configuration for the store. // - modelCatalog: optional model catalog to compute request cost. // - inMemoryStore: provider registry used for routing/validation in transports. @@ -91,17 +103,21 @@ type GovernancePlugin struct { // // Side effects: // - Logs warnings when optional dependencies are missing. -// - May perform startup resets via the usage tracker when `store` is non-nil. +// - May perform startup resets via the usage tracker when `configStore` is non-nil. +// +// Alternative entry point: +// - Use InitFromStore to inject a custom GovernanceStore implementation instead +// of constructing a LocalGovernanceStore internally. func Init( ctx context.Context, config *Config, logger schemas.Logger, - store configstore.ConfigStore, + configStore configstore.ConfigStore, governanceConfig *configstore.GovernanceConfig, modelCatalog *modelcatalog.ModelCatalog, inMemoryStore InMemoryStore, ) (*GovernancePlugin, error) { - if store == nil { + if configStore == nil { logger.Warn("governance plugin requires config store to persist data, running in memory only mode") } if modelCatalog == nil { @@ -114,7 +130,7 @@ func Init( isVkMandatory = config.IsVkMandatory } - governanceStore, err := NewGovernanceStore(ctx, logger, store, governanceConfig) + governanceStore, err := NewLocalGovernanceStore(ctx, logger, configStore, governanceConfig) if err != nil { return nil, fmt.Errorf("failed to initialize governance store: %w", err) } @@ -123,10 +139,10 @@ func Init( resolver := NewBudgetResolver(governanceStore, logger) // 3. Tracker (business logic owner, depends on store and resolver) - tracker := NewUsageTracker(ctx, governanceStore, resolver, store, logger) + tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger) // 4. Perform startup reset check for any expired limits from downtime - if store != nil { + if configStore != nil { if err := tracker.PerformStartupResets(ctx); err != nil { logger.Warn("startup reset failed: %v", err) // Continue initialization even if startup reset fails (non-critical) @@ -139,7 +155,7 @@ func Init( store: governanceStore, resolver: resolver, tracker: tracker, - configStore: store, + configStore: configStore, modelCatalog: modelCatalog, logger: logger, isVkMandatory: isVkMandatory, @@ -148,69 +164,142 @@ func Init( return plugin, nil } +// InitFromStore initializes and returns a governance plugin instance with a custom store. +// +// This constructor allows providing a custom GovernanceStore implementation instead of +// creating a new LocalGovernanceStore. Use this when you need to: +// - Inject a custom store implementation for testing +// - Use a pre-configured store instance +// - Integrate with non-standard storage backends +// +// Parameters are the same as Init, except governanceConfig is replaced by governanceStore. +// The governanceStore must not be nil, or an error is returned. +// +// See Init documentation for details on other parameters and behavior. +func InitFromStore( + ctx context.Context, + config *Config, + logger schemas.Logger, + governanceStore GovernanceStore, + configStore configstore.ConfigStore, + modelCatalog *modelcatalog.ModelCatalog, + inMemoryStore InMemoryStore, +) (*GovernancePlugin, error) { + if configStore == nil { + logger.Warn("governance plugin requires config store to persist data, running in memory only mode") + } + if modelCatalog == nil { + logger.Warn("governance plugin requires model catalog to calculate cost, all cost calculations will be skipped.") + } + if governanceStore == nil { + return nil, fmt.Errorf("governance store is nil") + } + // Handle nil config - use safe default for IsVkMandatory + var isVkMandatory *bool + if config != nil { + isVkMandatory = config.IsVkMandatory + } + resolver := NewBudgetResolver(governanceStore, logger) + tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger) + // Perform startup reset check for any expired limits from downtime + if configStore != nil { + if err := tracker.PerformStartupResets(ctx); err != nil { + logger.Warn("startup reset failed: %v", err) + // Continue initialization even if startup reset fails (non-critical) + } + } + ctx, cancelFunc := context.WithCancel(ctx) + plugin := &GovernancePlugin{ + ctx: ctx, + cancelFunc: cancelFunc, + store: governanceStore, + resolver: resolver, + tracker: tracker, + configStore: configStore, + modelCatalog: modelCatalog, + logger: logger, + inMemoryStore: inMemoryStore, + isVkMandatory: isVkMandatory, + } + return plugin, nil +} + // GetName returns the name of the plugin func (p *GovernancePlugin) GetName() string { return PluginName } -// TransportInterceptor intercepts requests before they are processed (governance decision point) -// Parameters: -// - ctx: The Bifrost context -// - url: The URL of the request -// - headers: The request headers -// - body: The request body -// -// Returns: -// - map[string]string: The updated request headers -// - map[string]any: The updated request body -// - error: Any error that occurred during processing -func (p *GovernancePlugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { +func parseVirtualKeyFromHTTPRequest(req *schemas.HTTPRequest) *string { var virtualKeyValue string - var err error - - for header, value := range headers { - headerStr := strings.ToLower(header) - if headerStr == string(schemas.BifrostContextKeyVirtualKey) { - virtualKeyValue = string(value) - break - } - if headerStr == "authorization" { - valueStr := string(value) - // Only accept Bearer token format: "Bearer ..." - if strings.HasPrefix(strings.ToLower(valueStr), "bearer ") { - authHeaderValue := strings.TrimSpace(valueStr[7:]) // Remove "Bearer " prefix - if authHeaderValue != "" && strings.HasPrefix(strings.ToLower(authHeaderValue), VirtualKeyPrefix) { - virtualKeyValue = authHeaderValue - break - } + vkHeader := req.CaseInsensitiveHeaderLookup("x-bf-vk") + if vkHeader != "" { + return bifrost.Ptr(vkHeader) + } + authHeader := req.CaseInsensitiveHeaderLookup("authorization") + if authHeader != "" { + if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { + authHeaderValue := strings.TrimSpace(authHeader[7:]) // Remove "Bearer " prefix + if authHeaderValue != "" && strings.HasPrefix(strings.ToLower(authHeaderValue), VirtualKeyPrefix) { + virtualKeyValue = authHeaderValue } } - if (headerStr == "x-api-key" || headerStr == "x-goog-api-key") && strings.HasPrefix(strings.ToLower(string(value)), VirtualKeyPrefix) { - virtualKeyValue = string(value) - break - } } - if virtualKeyValue == "" { - return headers, body, nil + if virtualKeyValue != "" { + return bifrost.Ptr(virtualKeyValue) + } + xAPIKey := req.CaseInsensitiveHeaderLookup("x-api-key") + if xAPIKey != "" && strings.HasPrefix(strings.ToLower(xAPIKey), VirtualKeyPrefix) { + return bifrost.Ptr(xAPIKey) + } + // Checking x-goog-api-key header + xGoogleAPIKey := req.CaseInsensitiveHeaderLookup("x-goog-api-key") + if xGoogleAPIKey != "" && strings.HasPrefix(strings.ToLower(xGoogleAPIKey), VirtualKeyPrefix) { + return bifrost.Ptr(xGoogleAPIKey) } + return nil +} - virtualKey, ok := p.store.GetVirtualKey(virtualKeyValue) +// HTTPTransportIntercept intercepts requests before they are processed (governance decision point) +// It modifies the request in-place and returns nil to continue, or an HTTPResponse to short-circuit. +func (p *GovernancePlugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + virtualKeyValue := parseVirtualKeyFromHTTPRequest(req) + if virtualKeyValue == nil { + return nil, nil + } + // Get the virtual key from the store + virtualKey, ok := p.store.GetVirtualKey(*virtualKeyValue) if !ok || virtualKey == nil || !virtualKey.IsActive { - return headers, body, nil + return nil, nil } - - - body, err = p.loadBalanceProvider(body, virtualKey) + headers, err := p.addMCPIncludeTools(nil, virtualKey) if err != nil { - return headers, body, err + p.logger.Error("failed to add MCP include tools: %v", err) + return nil, nil } - - headers, err = p.addMCPIncludeTools(headers, virtualKey) + for header, value := range headers { + req.Headers[header] = value + } + if len(req.Body) == 0 { + return nil, nil + } + var payload map[string]any + err = sonic.Unmarshal(req.Body, &payload) if err != nil { - return headers, body, err + p.logger.Error("failed to unmarshal request body to check for virtual key: %v", err) + return nil, nil } - - return headers, body, nil + payload, err = p.loadBalanceProvider(payload, virtualKey) + if err != nil { + p.logger.Error("failed to load balance provider: %v", err) + return nil, nil + } + body, err := sonic.Marshal(payload) + if err != nil { + p.logger.Error("failed to marshal request body to check for virtual key: %v", err) + return nil, nil + } + req.Body = body + return nil, nil } // loadBalanceProvider loads balances the provider for the request @@ -386,7 +475,7 @@ func (p *GovernancePlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.Bif Type: bifrost.Ptr("virtual_key_required"), StatusCode: bifrost.Ptr(401), Error: &schemas.ErrorField{ - Message: "x-bf-vk header is missing and is mandatory.", + Message: "virtual key is missing in headers and is mandatory.", }, }, }, nil @@ -410,7 +499,7 @@ func (p *GovernancePlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.Bif if result.Decision != DecisionAllow { if ctx != nil { - if _, ok := (*ctx).Value(governanceRejectedContextKey).(bool); !ok { + if _, ok := ctx.Value(governanceRejectedContextKey).(bool); !ok { ctx.SetValue(governanceRejectedContextKey, true) } } @@ -502,7 +591,7 @@ func (p *GovernancePlugin) PostHook(ctx *schemas.BifrostContext, result *schemas isCacheRead = b } } - if val := (*ctx).Value(governanceIsBatchContextKey); val != nil { + if val := ctx.Value(governanceIsBatchContextKey); val != nil { if b, ok := val.(bool); ok { isBatch = b } @@ -597,6 +686,6 @@ func (p *GovernancePlugin) postHookWorker(result *schemas.BifrostResponse, provi } // GetGovernanceStore returns the governance store -func (p *GovernancePlugin) GetGovernanceStore() *GovernanceStore { +func (p *GovernancePlugin) GetGovernanceStore() GovernanceStore { return p.store } diff --git a/plugins/governance/providerbudget_test.go b/plugins/governance/providerbudget_test.go new file mode 100644 index 0000000000..a4096d7649 --- /dev/null +++ b/plugins/governance/providerbudget_test.go @@ -0,0 +1,236 @@ +package governance + +import ( + "strconv" + "testing" +) + +// TestProviderBudgetExceeded tests provider-specific budgets within a VK by making requests until budget is consumed +func TestProviderBudgetExceeded(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with different budgets for different providers + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: "test-vk-provider-budget-" + generateRandomID(), + Budget: &BudgetRequest{ + MaxLimit: 1.0, // High overall budget + ResetDuration: "1h", + }, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: 0.01, // Specific OpenAI budget + ResetDuration: "1h", + }, + }, + { + Provider: "anthropic", + Weight: 1.0, + Budget: &BudgetRequest{ + MaxLimit: 0.01, // Specific Anthropic budget + ResetDuration: "1h", + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with OpenAI budget $0.01 and Anthropic budget $0.01") + + // Test OpenAI provider budget exceeded + t.Run("OpenAIProviderBudgetExceeded", func(t *testing.T) { + providerBudget := 0.01 + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + var shouldStop = false + + for requestNum <= 50 { + longPrompt := "Please provide a comprehensive and detailed response to the following question. " + + "I need extensive information covering all aspects of the topic. " + + "Provide multiple paragraphs with detailed explanations. " + + "Request number " + strconv.Itoa(requestNum) + ". " + + "Here is a detailed prompt that will consume significant tokens: " + + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum. Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum." + + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: longPrompt, + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") || CheckErrorMessage(t, resp, "provider") { + t.Logf("Request %d correctly rejected: OpenAI provider budget exceeded", requestNum) + t.Logf("Consumed budget: $%.6f (limit: $%.2f)", consumedBudget, providerBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + if requestNum == 1 { + t.Fatalf("First request should have succeeded but was rejected due to budget") + } + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualInputTokens := int(prompt) + actualOutputTokens := int(completion) + actualCost, _ := CalculateCost("openai/gpt-4o", actualInputTokens, actualOutputTokens) + + consumedBudget += actualCost + lastSuccessfulCost = actualCost + + t.Logf("Request %d succeeded: input_tokens=%d, output_tokens=%d, cost=$%.6f, consumed=$%.6f/$%.2f", + requestNum, actualInputTokens, actualOutputTokens, actualCost, consumedBudget, providerBudget) + } + } + } + + requestNum++ + + if shouldStop { + break + } + + if consumedBudget >= providerBudget { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit provider budget limit (consumed $%.6f / $%.2f) - budget not being enforced", + requestNum-1, consumedBudget, providerBudget) + }) + + // Test Anthropic provider budget exceeded + t.Run("AnthropicProviderBudgetExceeded", func(t *testing.T) { + providerBudget := 0.01 + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + var shouldStop = false + + for requestNum <= 50 { + longPrompt := "Please provide a comprehensive and detailed response to the following question. " + + "I need extensive information covering all aspects of the topic. " + + "Provide multiple paragraphs with detailed explanations. " + + "Request number " + strconv.Itoa(requestNum) + ". " + + "Here is a detailed prompt that will consume significant tokens: " + + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum. Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum." + + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "anthropic/claude-3-7-sonnet-20250219", + Messages: []ChatMessage{ + { + Role: "user", + Content: longPrompt, + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "budget") || CheckErrorMessage(t, resp, "provider") { + t.Logf("Request %d correctly rejected: Anthropic provider budget exceeded", requestNum) + t.Logf("Consumed budget: $%.6f (limit: $%.2f)", consumedBudget, providerBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + if requestNum == 1 { + t.Fatalf("First request should have succeeded but was rejected due to budget") + } + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualInputTokens := int(prompt) + actualOutputTokens := int(completion) + actualCost, _ := CalculateCost("anthropic/claude-3-7-sonnet-20250219", actualInputTokens, actualOutputTokens) + + consumedBudget += actualCost + lastSuccessfulCost = actualCost + + t.Logf("Request %d succeeded: input_tokens=%d, output_tokens=%d, cost=$%.6f, consumed=$%.6f/$%.2f", + requestNum, actualInputTokens, actualOutputTokens, actualCost, consumedBudget, providerBudget) + } + } + } + + requestNum++ + + if shouldStop { + break + } + + if consumedBudget >= providerBudget { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit provider budget limit (consumed $%.6f / $%.2f) - budget not being enforced", + requestNum-1, consumedBudget, providerBudget) + }) +} diff --git a/plugins/governance/ratelimit_test.go b/plugins/governance/ratelimit_test.go new file mode 100644 index 0000000000..8a5b4c8159 --- /dev/null +++ b/plugins/governance/ratelimit_test.go @@ -0,0 +1,991 @@ +package governance + +import ( + "testing" + "time" +) + +// TestVirtualKeyTokenRateLimit tests that VK-level token rate limits are enforced +func TestVirtualKeyTokenRateLimit(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a very restrictive token rate limit + vkName := "test-vk-token-limit-" + generateRandomID() + tokenLimit := int64(500) // Only 500 tokens per hour + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with token limit: %d tokens per %s", vkName, tokenLimit, tokenResetDuration) + + // Make requests until we hit the token limit + successCount := 0 + for i := 0; i < 10; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Short test request " + string(rune('0'+i)) + " for token limit.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "token") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected due to token rate limit", i+1) + return // Test passed - hit the token limit + } else { + t.Logf("Request %d failed with unexpected error: %v", i+1, resp.Body) + } + } else if resp.StatusCode == 200 { + successCount++ + t.Logf("Request %d succeeded (tokens within limit)", i+1) + } + } + + if successCount > 0 { + t.Logf("Made %d successful requests before hitting token limit āœ“", successCount) + } else { + t.Skip("Could not make requests to test token limit") + } +} + +// TestVirtualKeyRequestRateLimit tests that VK-level request rate limits are enforced +func TestVirtualKeyRequestRateLimit(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a very restrictive request rate limit + vkName := "test-vk-request-limit-" + generateRandomID() + requestLimit := int64(3) // Only 3 requests per minute + requestResetDuration := "1m" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + RequestMaxLimit: &requestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with request limit: %d requests per %s", vkName, requestLimit, requestResetDuration) + + // Make requests until we hit the request limit + successCount := 0 + for i := 0; i < 5; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request number " + string(rune('0'+i)) + ".", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "request") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected due to request rate limit", i+1) + return // Test passed + } else { + t.Logf("Request %d failed with different error", i+1) + } + } else if resp.StatusCode == 200 { + successCount++ + t.Logf("Request %d succeeded (count: %d/%d)", i+1, successCount, requestLimit) + } + } + + if successCount > 0 { + t.Logf("Made %d successful requests before hitting request limit āœ“", successCount) + } else { + t.Skip("Could not make requests to test request limit") + } +} + +// TestProviderConfigTokenRateLimit tests that provider-level token rate limits are enforced +func TestProviderConfigTokenRateLimit(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a provider config that has a token rate limit + vkName := "test-vk-provider-token-limit-" + generateRandomID() + providerTokenLimit := int64(300) // Limited tokens per provider + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &providerTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with provider token limit: %d tokens per %s", vkName, providerTokenLimit, tokenResetDuration) + + // Make requests to openai until we hit provider token limit + successCount := 0 + for i := 0; i < 10; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Provider token limit test " + string(rune('0'+i)) + ".", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "token") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected due to provider token limit", i+1) + return // Test passed + } else { + t.Logf("Request %d failed with different error", i+1) + } + } else if resp.StatusCode == 200 { + successCount++ + t.Logf("Request %d succeeded", i+1) + } + } + + if successCount > 0 { + t.Logf("Made %d successful requests with provider token limit āœ“", successCount) + } else { + t.Skip("Could not make requests to test provider token limit") + } +} + +// TestProviderConfigRequestRateLimit tests that provider-level request rate limits are enforced +func TestProviderConfigRequestRateLimit(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a provider config that has a request rate limit + vkName := "test-vk-provider-request-limit-" + generateRandomID() + providerRequestLimit := int64(2) // Only 2 requests per minute for this provider + requestResetDuration := "1m" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + RequestMaxLimit: &providerRequestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with provider request limit: %d requests per %s", vkName, providerRequestLimit, requestResetDuration) + + // Make requests to openai until we hit provider request limit + successCount := 0 + for i := 0; i < 5; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Provider request limit test " + string(rune('0'+i)) + ".", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + if CheckErrorMessage(t, resp, "request") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected due to provider request limit", i+1) + return // Test passed + } else { + t.Logf("Request %d failed with different error", i+1) + } + } else if resp.StatusCode == 200 { + successCount++ + t.Logf("Request %d succeeded (count: %d/%d)", i+1, successCount, providerRequestLimit) + } + } + + if successCount > 0 { + t.Logf("Made %d successful requests with provider request limit āœ“", successCount) + } else { + t.Skip("Could not make requests to test provider request limit") + } +} + +// TestMultipleProvidersSeparateRateLimits tests that different providers have independent rate limits +func TestMultipleProvidersSeparateRateLimits(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with multiple providers, each with their own rate limits + vkName := "test-vk-multi-provider-limits-" + generateRandomID() + openaiLimit := int64(100) + anthropicLimit := int64(50) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &openaiLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + { + Provider: "anthropic", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &anthropicLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with separate rate limits per provider", vkName) + + // Verify both providers are allowed + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + + providerConfigs, _ := vkData["provider_configs"].([]interface{}) + if len(providerConfigs) != 2 { + t.Fatalf("Expected 2 provider configs, got %d", len(providerConfigs)) + } + + t.Logf("VK has %d provider configs with separate rate limits āœ“", len(providerConfigs)) +} + +// TestProviderAndVKRateLimitTogether tests that both provider and VK rate limits are enforced together +func TestProviderAndVKRateLimitTogether(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with both VK-level and provider-level rate limits + vkName := "test-vk-both-limits-" + generateRandomID() + vkTokenLimit := int64(1000) + vkTokenResetDuration := "1h" + providerTokenLimit := int64(300) + providerTokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &vkTokenLimit, + TokenResetDuration: &vkTokenResetDuration, + }, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &providerTokenLimit, + TokenResetDuration: &providerTokenResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with VK limit (%d tokens) and provider limit (%d tokens)", vkName, vkTokenLimit, providerTokenLimit) + + // Verify the VK has both limits configured + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + + // Check VK has rate limit + vkRateLimitID, _ := vkData["rate_limit_id"].(string) + if vkRateLimitID == "" { + t.Fatalf("VK rate limit ID not found") + } + + // Check provider config exists + providerConfigs, _ := vkData["provider_configs"].([]interface{}) + if len(providerConfigs) == 0 { + t.Fatalf("No provider configs found") + } + + t.Logf("VK has both VK-level rate limit and provider-level rate limit configured āœ“") +} + +// TestRateLimitInMemorySync tests that rate limit changes sync to in-memory store +func TestRateLimitInMemorySync(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a token rate limit + vkName := "test-vk-rate-limit-sync-" + generateRandomID() + initialTokenLimit := int64(1000) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &initialTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with rate limit: %d tokens", vkName, initialTokenLimit) + + // Get initial rate limit from in-memory store + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + rateLimitID, _ := vkData["rate_limit_id"].(string) + + if rateLimitID == "" { + t.Fatalf("Rate limit ID not found in VK") + } + + // Update the rate limit + newTokenLimit := int64(5000) + updateResp := MakeRequest(t, APIRequest{ + Method: "PUT", + Path: "/api/governance/virtual-keys/" + vkID, + Body: UpdateVirtualKeyRequest{ + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &newTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if updateResp.StatusCode != 200 { + t.Fatalf("Failed to update VK rate limit: status %d", updateResp.StatusCode) + } + + t.Logf("Updated rate limit from %d to %d tokens", initialTokenLimit, newTokenLimit) + + // Verify rate limit is updated in in-memory store + time.Sleep(500 * time.Millisecond) + + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp2.StatusCode != 200 { + t.Fatalf("Failed to get governance data after update: status %d", getDataResp2.StatusCode) + } + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + vkData2 := virtualKeysMap2[vkValue].(map[string]interface{}) + + // Verify VK still has rate limit configured + rateLimitID2, _ := vkData2["rate_limit_id"].(string) + if rateLimitID2 == "" { + t.Fatalf("Rate limit ID removed after update") + } + + // Verify it's the same rate limit (ID should match) + if rateLimitID2 != rateLimitID { + t.Fatalf("Rate limit ID changed after update: was %s, now %s", rateLimitID, rateLimitID2) + } + + // Verify rate limit content - check the actual values in the main RateLimits map + getRateLimitsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap2 := getRateLimitsResp2.Body["rate_limits"].(map[string]interface{}) + rateLimit2, ok := rateLimitsMap2[rateLimitID2].(map[string]interface{}) + if !ok { + t.Fatalf("Rate limit not found in RateLimits map") + } + + // Check TokenMaxLimit was updated + tokenMaxLimit, ok := rateLimit2["token_max_limit"].(float64) + if !ok { + t.Fatalf("Token max limit not found in rate limit") + } + if int64(tokenMaxLimit) != newTokenLimit { + t.Fatalf("Token max limit not updated: expected %d but got %d", newTokenLimit, int64(tokenMaxLimit)) + } + t.Logf("Token max limit correctly updated to %d āœ“", int64(tokenMaxLimit)) + + // Check TokenResetDuration persists + resetDuration, ok := rateLimit2["token_reset_duration"].(string) + if !ok { + t.Fatalf("Token reset duration not found in rate limit") + } + if resetDuration != tokenResetDuration { + t.Fatalf("Token reset duration changed: expected %s but got %s", tokenResetDuration, resetDuration) + } + t.Logf("Token reset duration persisted: %s āœ“", resetDuration) + + // Check usage counters exist + if tokenCurrentUsage, ok := rateLimit2["token_current_usage"].(float64); ok { + t.Logf("Token current usage in memory: %d", int64(tokenCurrentUsage)) + } + + t.Logf("Rate limit in-memory sync verified āœ“") + t.Logf("VK rate limit ID persisted: %s", rateLimitID2) +} + +// TestRateLimitTokenAndRequestTogether tests that both token and request limits work together +func TestRateLimitTokenAndRequestTogether(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with both token and request limits + vkName := "test-vk-token-and-request-" + generateRandomID() + tokenLimit := int64(5000) + tokenResetDuration := "1h" + requestLimit := int64(100) + requestResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + RequestMaxLimit: &requestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with token limit (%d) and request limit (%d)", vkName, tokenLimit, requestLimit) + + // Make a few requests and verify both limits are being tracked + successCount := 0 + for i := 0; i < 3; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request for token and request limits " + string(rune('0'+i)) + ".", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode == 200 { + successCount++ + t.Logf("Request %d succeeded", i+1) + } else if resp.StatusCode >= 400 { + t.Logf("Request %d failed with status %d", i+1, resp.StatusCode) + break + } + } + + if successCount > 0 { + t.Logf("Made %d successful requests with both token and request limits āœ“", successCount) + } else { + t.Skip("Could not make requests to test combined limits") + } +} + +// TestRateLimitUsageTrackedInMemory tests that VK-level rate limit usage is tracked in in-memory store +func TestRateLimitUsageTrackedInMemory(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with both token and request rate limits + vkName := "test-vk-usage-tracking-" + generateRandomID() + tokenLimit := int64(100000) + tokenResetDuration := "1h" + requestLimit := int64(100) + requestResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + RequestMaxLimit: &requestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with rate limits for usage tracking", vkName) + + // Get initial state - rate limit usage should be 0 + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + rateLimitID1, _ := vkData1["rate_limit_id"].(string) + + initialTokenUsage := 0.0 + initialRequestUsage := 0.0 + + // Check initial rate limit usage (should be 0) from main RateLimits map + getRateLimitsResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap1 := getRateLimitsResp1.Body["rate_limits"].(map[string]interface{}) + rateLimit1, ok := rateLimitsMap1[rateLimitID1].(map[string]interface{}) + if !ok { + t.Fatalf("Rate limit not found in RateLimits map") + } + + if tokenUsage, ok := rateLimit1["token_current_usage"].(float64); ok { + initialTokenUsage = tokenUsage + t.Logf("Initial token usage: %d", int64(initialTokenUsage)) + } + if requestUsage, ok := rateLimit1["request_current_usage"].(float64); ok { + initialRequestUsage = requestUsage + t.Logf("Initial request usage: %d", int64(initialRequestUsage)) + } + + // Make a request to use some tokens and increment request count + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request for usage tracking.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to test usage tracking") + } + + // Wait for async update to in-memory store + time.Sleep(500 * time.Millisecond) + + // Get updated state - rate limit usage should have increased + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + vkData2 := virtualKeysMap2[vkValue].(map[string]interface{}) + rateLimitID2, _ := vkData2["rate_limit_id"].(string) + + // Get rate limit from main RateLimits map + getRateLimitsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap2 := getRateLimitsResp2.Body["rate_limits"].(map[string]interface{}) + rateLimit2, ok := rateLimitsMap2[rateLimitID2].(map[string]interface{}) + if !ok { + t.Fatalf("Rate limit not found in RateLimits map after request") + } + + // Check that token usage increased + tokenUsage2, ok := rateLimit2["token_current_usage"].(float64) + if !ok { + t.Fatalf("Token current usage not found in rate limit") + } + + if tokenUsage2 <= initialTokenUsage { + t.Logf("Warning: Token usage did not increase (before: %d, after: %d)", int64(initialTokenUsage), int64(tokenUsage2)) + } else { + t.Logf("Token usage increased from %d to %d āœ“", int64(initialTokenUsage), int64(tokenUsage2)) + } + + // Check that request usage increased + requestUsage2, ok := rateLimit2["request_current_usage"].(float64) + if !ok { + t.Fatalf("Request current usage not found in rate limit") + } + + if requestUsage2 <= initialRequestUsage { + t.Logf("Warning: Request usage did not increase (before: %d, after: %d)", int64(initialRequestUsage), int64(requestUsage2)) + } else { + t.Logf("Request usage increased from %d to %d āœ“", int64(initialRequestUsage), int64(requestUsage2)) + } + + // Verify rate limit still has the configured max limits + tokenMaxLimit, ok := rateLimit2["token_max_limit"].(float64) + if ok && int64(tokenMaxLimit) != tokenLimit { + t.Fatalf("Token max limit changed: expected %d but got %d", tokenLimit, int64(tokenMaxLimit)) + } + + requestMaxLimit, ok := rateLimit2["request_max_limit"].(float64) + if ok && int64(requestMaxLimit) != requestLimit { + t.Fatalf("Request max limit changed: expected %d but got %d", requestLimit, int64(requestMaxLimit)) + } + + t.Logf("VK-level rate limit usage properly tracked in in-memory store āœ“") + t.Logf("Token usage: %d/%d, Request usage: %d/%d", + int64(tokenUsage2), tokenLimit, int64(requestUsage2), requestLimit) +} + +// TestProviderLevelRateLimitUsageTracking tests that provider-level rate limits are separately tracked +func TestProviderLevelRateLimitUsageTracking(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with multiple providers, each with their own rate limits + vkName := "test-vk-provider-usage-" + generateRandomID() + openaiTokenLimit := int64(50000) + anthropicTokenLimit := int64(30000) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &openaiTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + { + Provider: "anthropic", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &anthropicTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with per-provider rate limits", vkName) + + // Get initial state - provider rate limit usage should be 0 + getDataResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap1 := getDataResp1.Body["virtual_keys"].(map[string]interface{}) + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + + providerConfigs1, ok := vkData1["provider_configs"].([]interface{}) + if !ok { + t.Fatalf("Provider configs not found in VK data") + } + + if len(providerConfigs1) != 2 { + t.Fatalf("Expected 2 provider configs, got %d", len(providerConfigs1)) + } + + t.Logf("VK has %d provider configs with separate rate limits", len(providerConfigs1)) + + // Make a request with openai model to use openai provider's rate limit + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request for provider rate limit tracking.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not make request to test provider rate limit tracking") + } + + // Wait for async update + time.Sleep(500 * time.Millisecond) + + // Get updated state - openai provider rate limit usage should have increased + getDataResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap2 := getDataResp2.Body["virtual_keys"].(map[string]interface{}) + vkData2 := virtualKeysMap2[vkValue].(map[string]interface{}) + + providerConfigs2, ok := vkData2["provider_configs"].([]interface{}) + if !ok { + t.Fatalf("Provider configs not found in VK data after request") + } + + // Check each provider config for rate limit updates + var openaiUsage, anthropicUsage float64 + var openaiMaxLimit, anthropicMaxLimit float64 + + // Get rate limits from main RateLimits map + getRateLimitsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/rate-limits?from_memory=true", + }) + + rateLimitsMap2 := getRateLimitsResp2.Body["rate_limits"].(map[string]interface{}) + + for i, providerConfig := range providerConfigs2 { + config, ok := providerConfig.(map[string]interface{}) + if !ok { + continue + } + + provider, ok := config["provider"].(string) + if !ok { + continue + } + + rateLimitID, ok := config["rate_limit_id"].(string) + if !ok { + t.Logf("Provider %s: No rate limit ID found", provider) + continue + } + + rateLimit, ok := rateLimitsMap2[rateLimitID].(map[string]interface{}) + if !ok { + t.Logf("Provider %s: No rate limit found in RateLimits map", provider) + continue + } + + tokenUsage, _ := rateLimit["token_current_usage"].(float64) + tokenMaxLimit, _ := rateLimit["token_max_limit"].(float64) + + if provider == "openai" { + openaiUsage = tokenUsage + openaiMaxLimit = tokenMaxLimit + t.Logf("Provider %d (openai): Token usage: %d/%d", i, int64(tokenUsage), int64(tokenMaxLimit)) + } else if provider == "anthropic" { + anthropicUsage = tokenUsage + anthropicMaxLimit = tokenMaxLimit + t.Logf("Provider %d (anthropic): Token usage: %d/%d", i, int64(tokenUsage), int64(tokenMaxLimit)) + } + } + + // Verify provider limits are independent + if openaiMaxLimit != float64(openaiTokenLimit) { + t.Logf("Warning: OpenAI max limit changed: expected %d but got %d", openaiTokenLimit, int64(openaiMaxLimit)) + } + + if anthropicMaxLimit != float64(anthropicTokenLimit) { + t.Logf("Warning: Anthropic max limit changed: expected %d but got %d", anthropicTokenLimit, int64(anthropicMaxLimit)) + } + + t.Logf("Provider-level rate limits properly tracked separately in in-memory store āœ“") + t.Logf("OpenAI usage: %d, Anthropic usage: %d (separate limits)", int64(openaiUsage), int64(anthropicUsage)) +} diff --git a/plugins/governance/ratelimitenforcement_test.go b/plugins/governance/ratelimitenforcement_test.go new file mode 100644 index 0000000000..859c2f63fb --- /dev/null +++ b/plugins/governance/ratelimitenforcement_test.go @@ -0,0 +1,615 @@ +package governance + +import ( + "testing" + "time" +) + +// TestVirtualKeyTokenRateLimitEnforcement verifies VK token rate limits actually reject requests +// Rate limit enforcement is POST-HOC: the request that exceeds the limit is ALLOWED, +// but subsequent requests are BLOCKED. +func TestVirtualKeyTokenRateLimitEnforcement(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a VERY restrictive token rate limit + vkName := "test-vk-strict-token-limit-" + generateRandomID() + tokenLimit := int64(100) // Only 100 tokens max + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with strict token limit: %d tokens per %s", tokenLimit, tokenResetDuration) + + // Verify rate limit is in in-memory store + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + rateLimitID, _ := vkData["rate_limit_id"].(string) + + if rateLimitID == "" { + t.Fatalf("Rate limit not configured on VK") + } + + t.Logf("Rate limit ID %s configured on VK āœ“", rateLimitID) + + // Make requests until token limit is exceeded + // Rate limit enforcement is POST-HOC: request that exceeds is allowed, next is blocked + consumedTokens := int64(0) + requestNum := 1 + shouldStop := false + + for requestNum <= 20 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Hello how are you?", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request rejected - check if it's due to rate limit + if resp.StatusCode == 429 || CheckErrorMessage(t, resp, "token") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected: token limit exceeded at %d/%d", requestNum, consumedTokens, tokenLimit) + + // Verify rejection happened after exceeding the limit + if consumedTokens < tokenLimit { + t.Fatalf("Request rejected before token limit was exceeded: consumed %d < limit %d", consumedTokens, tokenLimit) + } + + t.Logf("Token rate limit enforcement verified āœ“") + t.Logf("Request blocked after token limit exceeded") + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not rate limit): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract token usage + var tokensUsed int64 + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if total, ok := usage["total_tokens"].(float64); ok { + tokensUsed = int64(total) + } + } + + consumedTokens += tokensUsed + t.Logf("Request %d succeeded: tokens=%d, consumed=%d/%d", requestNum, tokensUsed, consumedTokens, tokenLimit) + + requestNum++ + + if shouldStop { + break + } + + if consumedTokens >= tokenLimit { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit token rate limit (consumed %d / %d) - rate limit not being enforced", + requestNum-1, consumedTokens, tokenLimit) +} + +// TestVirtualKeyRequestRateLimitEnforcement verifies VK request rate limits actually reject requests +func TestVirtualKeyRequestRateLimitEnforcement(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a very restrictive request rate limit + vkName := "test-vk-strict-request-limit-" + generateRandomID() + requestLimit := int64(1) // Only 1 request allowed + requestResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + RequestMaxLimit: &requestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with request limit: %d request per %s", requestLimit, requestResetDuration) + + // Make requests until request limit is exceeded + requestCount := int64(0) + requestNum := 1 + + for requestNum <= 10 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request to test request rate limit.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request rejected - check if it's due to rate limit + if resp.StatusCode == 429 || CheckErrorMessage(t, resp, "request") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected: request limit exceeded at %d/%d", requestNum, requestCount, requestLimit) + + // Verify rejection happened after exceeding the limit + if requestCount < requestLimit { + t.Fatalf("Request rejected before request limit was exceeded: count %d < limit %d", requestCount, requestLimit) + } + + t.Logf("Request rate limit enforcement verified āœ“") + t.Logf("Request blocked after request limit exceeded") + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not rate limit): %v", requestNum, resp.Body) + } + } + + // Request succeeded - increment count + requestCount++ + t.Logf("Request %d succeeded: count=%d/%d", requestNum, requestCount, requestLimit) + + requestNum++ + } + + t.Fatalf("Made %d requests but never hit request rate limit (count %d / %d) - rate limit not being enforced", + requestNum-1, requestCount, requestLimit) +} + +// TestProviderConfigTokenRateLimitEnforcement verifies provider-level token limits reject requests +func TestProviderConfigTokenRateLimitEnforcement(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with provider-level token rate limit + vkName := "test-vk-provider-strict-token-" + generateRandomID() + providerTokenLimit := int64(100) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &providerTokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with provider token limit: %d tokens", providerTokenLimit) + + // Verify provider config rate limit is set + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + providerConfigs, _ := vkData["provider_configs"].([]interface{}) + + if len(providerConfigs) == 0 { + t.Fatalf("Provider config not found") + } + + t.Logf("Provider config rate limit configured āœ“") + + // Make requests until provider token limit is exceeded + // Rate limit enforcement is POST-HOC: request that exceeds is allowed, next is blocked + consumedTokens := int64(0) + requestNum := 1 + shouldStop := false + + for requestNum <= 20 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request to openai to test provider token limit.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request rejected - check if it's due to rate limit + if resp.StatusCode == 429 || CheckErrorMessage(t, resp, "token") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected: provider token limit exceeded at %d/%d", requestNum, consumedTokens, providerTokenLimit) + + // Verify rejection happened after exceeding the limit + if consumedTokens < providerTokenLimit { + t.Fatalf("Request rejected before provider token limit was exceeded: consumed %d < limit %d", consumedTokens, providerTokenLimit) + } + + t.Logf("Provider token rate limit enforcement verified āœ“") + t.Logf("Request blocked after provider token limit exceeded") + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not rate limit): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract token usage + var tokensUsed int64 + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if total, ok := usage["total_tokens"].(float64); ok { + tokensUsed = int64(total) + } + } + + consumedTokens += tokensUsed + t.Logf("Request %d succeeded: tokens=%d, consumed=%d/%d", requestNum, tokensUsed, consumedTokens, providerTokenLimit) + + requestNum++ + + if shouldStop { + break + } + + if consumedTokens >= providerTokenLimit { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit provider token rate limit (consumed %d / %d) - rate limit not being enforced", + requestNum-1, consumedTokens, providerTokenLimit) +} + +// TestProviderConfigRequestRateLimitEnforcement verifies provider-level request limits +func TestProviderConfigRequestRateLimitEnforcement(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with provider-level request rate limit + vkName := "test-vk-provider-strict-request-" + generateRandomID() + providerRequestLimit := int64(1) // Only 1 request allowed + requestResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + RequestMaxLimit: &providerRequestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with provider request limit: %d request", providerRequestLimit) + + // Make requests until provider request limit is exceeded + requestCount := int64(0) + requestNum := 1 + + for requestNum <= 10 { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request to test provider request rate limit.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request rejected - check if it's due to rate limit + if resp.StatusCode == 429 || CheckErrorMessage(t, resp, "request") || CheckErrorMessage(t, resp, "rate") { + t.Logf("Request %d correctly rejected: provider request limit exceeded at %d/%d", requestNum, requestCount, providerRequestLimit) + + // Verify rejection happened after exceeding the limit + if requestCount < providerRequestLimit { + t.Fatalf("Request rejected before provider request limit was exceeded: count %d < limit %d", requestCount, providerRequestLimit) + } + + t.Logf("Provider request rate limit enforcement verified āœ“") + t.Logf("Request blocked after provider request limit exceeded") + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not rate limit): %v", requestNum, resp.Body) + } + } + + // Request succeeded - increment count + requestCount++ + t.Logf("Request %d succeeded: count=%d/%d", requestNum, requestCount, providerRequestLimit) + + requestNum++ + } + + t.Fatalf("Made %d requests but never hit provider request rate limit (count %d / %d) - rate limit not being enforced", + requestNum-1, requestCount, providerRequestLimit) +} + +// TestProviderAndVKRateLimitBothEnforced verifies both provider and VK limits are enforced +func TestProviderAndVKRateLimitBothEnforced(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with both VK and provider request limits + vkName := "test-vk-both-enforced-" + generateRandomID() + vkRequestLimit := int64(5) + providerRequestLimit := int64(2) // More restrictive + requestResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + RequestMaxLimit: &vkRequestLimit, + RequestResetDuration: &requestResetDuration, + }, + ProviderConfigs: []ProviderConfigRequest{ + { + Provider: "openai", + Weight: 1.0, + RateLimit: &CreateRateLimitRequest{ + RequestMaxLimit: &providerRequestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK with VK limit (%d) and provider limit (%d requests)", vkRequestLimit, providerRequestLimit) + + // Make requests - provider limit (2) is more restrictive than VK limit (5) + // So we should hit provider limit first + successCount := 0 + for i := 0; i < 5; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Request " + string(rune('0'+i)) + " to test both limits.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode == 200 { + successCount++ + t.Logf("Request %d succeeded (count: %d)", i+1, successCount) + } else if resp.StatusCode >= 400 { + t.Logf("Request %d rejected with status %d", i+1, resp.StatusCode) + if successCount < int(providerRequestLimit) { + t.Fatalf("Request rejected before provider limit (%d): %v", providerRequestLimit, resp.Body) + } + // Expected - hit provider limit first + return + } + } + + if successCount > 0 { + if successCount >= 5 { + t.Fatalf("Made all %d requests without hitting rate limit (provider limit was %d) - rate limit not enforced", + successCount, providerRequestLimit) + } + t.Logf("Both VK and provider rate limits are configured and enforced āœ“") + } else { + t.Skip("Could not test - all requests failed") + } +} + +// TestRateLimitInMemoryUsageTracking verifies usage counters are tracked in in-memory store +func TestRateLimitInMemoryUsageTracking(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create VK with rate limit + vkName := "test-vk-usage-tracking-" + generateRandomID() + tokenLimit := int64(10000) + tokenResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK for usage tracking test") + + // Make a request + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test for usage tracking.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Skip("Could not execute request for usage tracking test") + } + + // Get usage from response + var tokensUsed int + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if total, ok := usage["total_tokens"].(float64); ok { + tokensUsed = int(total) + } + } + + if tokensUsed == 0 { + t.Skip("Could not extract token usage from response") + } + + t.Logf("Request used %d tokens", tokensUsed) + + // Wait for async update + time.Sleep(1 * time.Second) + + // Verify rate limit usage is tracked in in-memory store + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap, ok := getDataResp.Body["virtual_keys"].(map[string]interface{}) + if !ok || virtualKeysMap == nil { + t.Fatalf("Virtual keys field missing or not a map in get response") + } + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + rateLimitID, _ := vkData["rate_limit_id"].(string) + + if rateLimitID != "" { + t.Logf("Rate limit %s is configured and tracking usage āœ“", rateLimitID) + } else { + t.Logf("Rate limit is configured āœ“") + } +} diff --git a/plugins/governance/resolver.go b/plugins/governance/resolver.go index 4518a1ba48..e37f92a971 100644 --- a/plugins/governance/resolver.go +++ b/plugins/governance/resolver.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "slices" - "strings" "time" "github.com/maximhq/bifrost/core/schemas" @@ -63,12 +62,12 @@ type UsageInfo struct { // BudgetResolver provides decision logic for the new hierarchical governance system type BudgetResolver struct { - store *GovernanceStore + store GovernanceStore logger schemas.Logger } // NewBudgetResolver creates a new budget-based governance resolver -func NewBudgetResolver(store *GovernanceStore, logger schemas.Logger) *BudgetResolver { +func NewBudgetResolver(store GovernanceStore, logger schemas.Logger) *BudgetResolver { return &BudgetResolver{ store: store, logger: logger, @@ -127,13 +126,13 @@ func (r *BudgetResolver) EvaluateRequest(ctx *schemas.BifrostContext, evaluation } } - // 4. Check rate limits (Provider level first, then VK level) - if rateLimitResult := r.checkRateLimits(vk, string(evaluationRequest.Provider)); rateLimitResult != nil { + // 4. Check rate limits hierarchy (Provider level first, then VK level) + if rateLimitResult := r.checkRateLimitHierarchy(ctx, vk, string(evaluationRequest.Provider), evaluationRequest.Model, evaluationRequest.RequestID); rateLimitResult != nil { return rateLimitResult } // 5. Check budget hierarchy (VK → Team → Customer) - if budgetResult := r.checkBudgetHierarchy(ctx, vk, evaluationRequest.Provider); budgetResult != nil { + if budgetResult := r.checkBudgetHierarchy(ctx, vk, evaluationRequest); budgetResult != nil { return budgetResult } @@ -192,77 +191,25 @@ func (r *BudgetResolver) isProviderAllowed(vk *configstoreTables.TableVirtualKey return false } -// checkRateLimits checks provider-level rate limits first, then VK rate limits using flexible approach -func (r *BudgetResolver) checkRateLimits(vk *configstoreTables.TableVirtualKey, provider string) *EvaluationResult { - // First check provider-level rate limits - if providerRateLimitResult := r.checkProviderRateLimits(vk, provider); providerRateLimitResult != nil { - return providerRateLimitResult - } - - // Then check VK-level rate limits - if vk.RateLimit == nil { - return nil // No VK rate limits defined - } - - return r.checkSingleRateLimit(vk.RateLimit, "virtual key", vk) -} - -// checkProviderRateLimits checks rate limits for a specific provider config -func (r *BudgetResolver) checkProviderRateLimits(vk *configstoreTables.TableVirtualKey, provider string) *EvaluationResult { - if vk.ProviderConfigs == nil { - return nil // No provider configs defined - } - - // Find the specific provider config - for _, pc := range vk.ProviderConfigs { - if pc.Provider == provider && pc.RateLimit != nil { - return r.checkSingleRateLimit(pc.RateLimit, fmt.Sprintf("provider '%s'", provider), vk) - } - } - - return nil // No rate limits for this provider -} - -// checkSingleRateLimit checks a single rate limit and returns evaluation result if violated -func (r *BudgetResolver) checkSingleRateLimit(rateLimit *configstoreTables.TableRateLimit, rateLimitName string, vk *configstoreTables.TableVirtualKey) *EvaluationResult { - var violations []string - - // Token limits - if rateLimit.TokenMaxLimit != nil && rateLimit.TokenCurrentUsage >= *rateLimit.TokenMaxLimit { - duration := "unknown" - if rateLimit.TokenResetDuration != nil { - duration = *rateLimit.TokenResetDuration - } - violations = append(violations, fmt.Sprintf("token limit exceeded (%d/%d, resets every %s)", - rateLimit.TokenCurrentUsage, *rateLimit.TokenMaxLimit, duration)) - } - - // Request limits - if rateLimit.RequestMaxLimit != nil && rateLimit.RequestCurrentUsage >= *rateLimit.RequestMaxLimit { - duration := "unknown" - if rateLimit.RequestResetDuration != nil { - duration = *rateLimit.RequestResetDuration - } - violations = append(violations, fmt.Sprintf("request limit exceeded (%d/%d, resets every %s)", - rateLimit.RequestCurrentUsage, *rateLimit.RequestMaxLimit, duration)) - } - - if len(violations) > 0 { - // Determine specific violation type - decision := DecisionRateLimited - if len(violations) == 1 { - if strings.Contains(violations[0], "token") { - decision = DecisionTokenLimited - } else if strings.Contains(violations[0], "request") { - decision = DecisionRequestLimited +// checkRateLimitHierarchy checks provider-level rate limits first, then VK rate limits using flexible approach +func (r *BudgetResolver) checkRateLimitHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider string, model string, requestID string) *EvaluationResult { + if decision, err := r.store.CheckRateLimit(ctx, vk, schemas.ModelProvider(provider), model, requestID, nil, nil); err != nil { + // Check provider-level first (matching check order), then VK-level + var rateLimitInfo *configstoreTables.TableRateLimit + for _, pc := range vk.ProviderConfigs { + if pc.Provider == provider && pc.RateLimit != nil { + rateLimitInfo = pc.RateLimit + break } } - + if rateLimitInfo == nil && vk.RateLimit != nil { + rateLimitInfo = vk.RateLimit + } return &EvaluationResult{ Decision: decision, - Reason: fmt.Sprintf("%s rate limits exceeded: %v", rateLimitName, violations), + Reason: fmt.Sprintf("Rate limit check failed: %s", err.Error()), VirtualKey: vk, - RateLimitInfo: rateLimit, + RateLimitInfo: rateLimitInfo, } } @@ -270,14 +217,14 @@ func (r *BudgetResolver) checkSingleRateLimit(rateLimit *configstoreTables.Table } // checkBudgetHierarchy checks the budget hierarchy atomically (VK → Team → Customer) -func (r *BudgetResolver) checkBudgetHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) *EvaluationResult { +func (r *BudgetResolver) checkBudgetHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest) *EvaluationResult { // Use atomic budget checking to prevent race conditions - if err := r.store.CheckBudget(ctx, vk, provider); err != nil { - r.logger.Debug(fmt.Sprintf("Atomic budget check failed for VK %s: %s", vk.ID, err.Error())) + if err := r.store.CheckBudget(ctx, vk, request, nil); err != nil { + r.logger.Debug(fmt.Sprintf("Atomic budget exceeded for VK %s: %s", vk.ID, err.Error())) return &EvaluationResult{ Decision: DecisionBudgetExceeded, - Reason: fmt.Sprintf("Budget check failed: %s", err.Error()), + Reason: fmt.Sprintf("Budget exceeded: %s", err.Error()), VirtualKey: vk, } } diff --git a/plugins/governance/resolver_test.go b/plugins/governance/resolver_test.go new file mode 100644 index 0000000000..1fb6e78b98 --- /dev/null +++ b/plugins/governance/resolver_test.go @@ -0,0 +1,552 @@ +package governance + +import ( + "context" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestBudgetResolver_EvaluateRequest_AllowedRequest tests happy path +func TestBudgetResolver_EvaluateRequest_AllowedRequest(t *testing.T) { + logger := NewMockLogger() + vk := buildVirtualKey("vk1", "sk-bf-test", "Test VK", true) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + RequestID: "req-123", + }) + + assertDecision(t, DecisionAllow, result) + assertVirtualKeyFound(t, result) +} + +// TestBudgetResolver_EvaluateRequest_VirtualKeyNotFound tests missing VK +func TestBudgetResolver_EvaluateRequest_VirtualKeyNotFound(t *testing.T) { + logger := NewMockLogger() + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-nonexistent", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionVirtualKeyNotFound, result) +} + +// TestBudgetResolver_EvaluateRequest_VirtualKeyBlocked tests inactive VK +func TestBudgetResolver_EvaluateRequest_VirtualKeyBlocked(t *testing.T) { + logger := NewMockLogger() + vk := buildVirtualKey("vk1", "sk-bf-test", "Test VK", false) // Inactive + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionVirtualKeyBlocked, result) +} + +// TestBudgetResolver_EvaluateRequest_ProviderBlocked tests provider filtering +func TestBudgetResolver_EvaluateRequest_ProviderBlocked(t *testing.T) { + logger := NewMockLogger() + + // VK with only Anthropic allowed + providerConfigs := []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("anthropic", []string{"claude-3-sonnet"}), + } + vk := buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test VK", providerConfigs) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + // Try to use OpenAI (not allowed) + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionProviderBlocked, result) + assertVirtualKeyFound(t, result) +} + +// TestBudgetResolver_EvaluateRequest_ModelBlocked tests model filtering +func TestBudgetResolver_EvaluateRequest_ModelBlocked(t *testing.T) { + logger := NewMockLogger() + + // VK with specific models allowed + providerConfigs := []configstoreTables.TableVirtualKeyProviderConfig{ + { + Provider: "openai", + AllowedModels: []string{"gpt-4", "gpt-4-turbo"}, // Only these models + Weight: bifrost.Ptr(1.0), + RateLimit: nil, + Budget: nil, + Keys: []configstoreTables.TableKey{}, + }, + } + vk := buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test VK", providerConfigs) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + // Try to use gpt-4o-mini (not in allowed list) + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + }) + + assertDecision(t, DecisionModelBlocked, result) +} + +// TestBudgetResolver_EvaluateRequest_RateLimitExceeded_TokenLimit tests token limit +func TestBudgetResolver_EvaluateRequest_RateLimitExceeded_TokenLimit(t *testing.T) { + logger := NewMockLogger() + + // VK with rate limit already at max + rateLimit := buildRateLimitWithUsage("rl1", 10000, 10000, 1000, 0) // Tokens at max + vk := buildVirtualKeyWithRateLimit("vk1", "sk-bf-test", "Test VK", rateLimit) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*rateLimit}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionTokenLimited, result) + assertRateLimitInfo(t, result) +} + +// TestBudgetResolver_EvaluateRequest_RateLimitExceeded_RequestLimit tests request limit +func TestBudgetResolver_EvaluateRequest_RateLimitExceeded_RequestLimit(t *testing.T) { + logger := NewMockLogger() + + // VK with request limit already at max + rateLimit := buildRateLimitWithUsage("rl1", 10000, 0, 100, 100) // Requests at max + vk := buildVirtualKeyWithRateLimit("vk1", "sk-bf-test", "Test VK", rateLimit) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*rateLimit}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionRequestLimited, result) +} + +// TestBudgetResolver_EvaluateRequest_RateLimitExpired tests rate limit reset +func TestBudgetResolver_EvaluateRequest_RateLimitExpired(t *testing.T) { + logger := NewMockLogger() + + // VK with rate limit that's expired (should be treated as reset) + duration := "1m" + rateLimit := &configstoreTables.TableRateLimit{ + ID: "rl1", + TokenMaxLimit: ptrInt64(10000), + TokenCurrentUsage: 10000, // At limit + TokenResetDuration: &duration, + TokenLastReset: time.Now().Add(-2 * time.Minute), // Expired + RequestMaxLimit: ptrInt64(1000), + RequestCurrentUsage: 0, + RequestResetDuration: &duration, + RequestLastReset: time.Now(), + } + vk := buildVirtualKeyWithRateLimit("vk1", "sk-bf-test", "Test VK", rateLimit) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*rateLimit}, + }) + require.NoError(t, err) + + // Reset expired rate limits (simulating ticker behavior) + expiredRateLimits := store.ResetExpiredRateLimitsInMemory(context.Background()) + err = store.ResetExpiredRateLimits(context.Background(), expiredRateLimits) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + // Should allow because rate limit was expired and has been reset + assertDecision(t, DecisionAllow, result) +} + +// TestBudgetResolver_EvaluateRequest_BudgetExceeded tests budget violation +func TestBudgetResolver_EvaluateRequest_BudgetExceeded(t *testing.T) { + logger := NewMockLogger() + + budget := buildBudgetWithUsage("budget1", 100.0, 100.0, "1d") // At limit + vk := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", budget) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*budget}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionBudgetExceeded, result) +} + +// TestBudgetResolver_EvaluateRequest_BudgetExpired tests expired budget (should be treated as reset) +func TestBudgetResolver_EvaluateRequest_BudgetExpired(t *testing.T) { + logger := NewMockLogger() + + budget := &configstoreTables.TableBudget{ + ID: "budget1", + MaxLimit: 100.0, + CurrentUsage: 100.0, // At limit + ResetDuration: "1d", + LastReset: time.Now().Add(-48 * time.Hour), // Expired + } + vk := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", budget) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*budget}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + // Should allow because budget is expired (will be reset) + assertDecision(t, DecisionAllow, result) +} + +// TestBudgetResolver_EvaluateRequest_MultiLevelBudgetHierarchy tests hierarchy checking +func TestBudgetResolver_EvaluateRequest_MultiLevelBudgetHierarchy(t *testing.T) { + logger := NewMockLogger() + + vkBudget := buildBudgetWithUsage("vk-budget", 100.0, 50.0, "1d") + teamBudget := buildBudgetWithUsage("team-budget", 500.0, 200.0, "1d") + customerBudget := buildBudgetWithUsage("customer-budget", 1000.0, 400.0, "1d") + + team := buildTeam("team1", "Team 1", teamBudget) + customer := buildCustomer("customer1", "Customer 1", customerBudget) + team.CustomerID = &customer.ID + team.Customer = customer + + vk := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", vkBudget) + vk.TeamID = &team.ID + vk.Team = team + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*vkBudget, *teamBudget, *customerBudget}, + Teams: []configstoreTables.TableTeam{*team}, + Customers: []configstoreTables.TableCustomer{*customer}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + // Test: All under limit should pass + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + assertDecision(t, DecisionAllow, result) + + // Test: VK budget exceeds should fail + // Get the governance data to update the budget directly + governanceData := store.GetGovernanceData() + vkBudgetToUpdate := governanceData.Budgets["vk-budget"] + if vkBudgetToUpdate != nil { + vkBudgetToUpdate.CurrentUsage = 100.0 + store.budgets.Store("vk-budget", vkBudgetToUpdate) + } + result = resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + assertDecision(t, DecisionBudgetExceeded, result) +} + +// TestBudgetResolver_EvaluateRequest_ProviderLevelRateLimit tests provider-specific rate limits +func TestBudgetResolver_EvaluateRequest_ProviderLevelRateLimit(t *testing.T) { + logger := NewMockLogger() + + // Provider with rate limit at max + providerRL := buildRateLimitWithUsage("provider-rl", 5000, 5000, 500, 0) + providerConfig := buildProviderConfigWithRateLimit("openai", []string{"gpt-4"}, providerRL) + vk := buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test VK", []configstoreTables.TableVirtualKeyProviderConfig{providerConfig}) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*providerRL}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionTokenLimited, result) + assertRateLimitInfo(t, result) +} + +// TestBudgetResolver_CheckRateLimits_BothExceeded tests token and request limits simultaneously +func TestBudgetResolver_CheckRateLimits_BothExceeded(t *testing.T) { + logger := NewMockLogger() + + // Rate limit with both token and request at max + rateLimit := buildRateLimitWithUsage("rl1", 1000, 1000, 100, 100) + vk := buildVirtualKeyWithRateLimit("vk1", "sk-bf-test", "Test VK", rateLimit) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*rateLimit}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assertDecision(t, DecisionRateLimited, result) + assert.Contains(t, result.Reason, "rate limit") +} + +// TestBudgetResolver_IsProviderAllowed tests provider filtering logic +func TestBudgetResolver_IsProviderAllowed(t *testing.T) { + logger := NewMockLogger() + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + + tests := []struct { + name string + vk *configstoreTables.TableVirtualKey + provider schemas.ModelProvider + shouldBeAllowed bool + }{ + { + name: "No provider configs (all allowed)", + vk: buildVirtualKey("vk1", "sk-bf-test", "Test", true), + provider: schemas.OpenAI, + shouldBeAllowed: true, + }, + { + name: "Provider in allowlist", + vk: buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"gpt-4"}), + }), + provider: schemas.OpenAI, + shouldBeAllowed: true, + }, + { + name: "Provider not in allowlist", + vk: buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("anthropic", []string{"claude-3-sonnet"}), + }), + provider: schemas.OpenAI, + shouldBeAllowed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + allowed := resolver.isProviderAllowed(tt.vk, tt.provider) + assert.Equal(t, tt.shouldBeAllowed, allowed) + }) + } +} + +// TestBudgetResolver_IsModelAllowed tests model filtering logic +func TestBudgetResolver_IsModelAllowed(t *testing.T) { + logger := NewMockLogger() + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + + tests := []struct { + name string + vk *configstoreTables.TableVirtualKey + provider schemas.ModelProvider + model string + shouldBeAllowed bool + }{ + { + name: "No provider configs (all models allowed)", + vk: buildVirtualKey("vk1", "sk-bf-test", "Test", true), + provider: schemas.OpenAI, + model: "gpt-4", + shouldBeAllowed: true, + }, + { + name: "Empty allowed models (all models allowed)", + vk: buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{}), // Empty = all allowed + }), + provider: schemas.OpenAI, + model: "gpt-4", + shouldBeAllowed: true, + }, + { + name: "Model in allowlist", + vk: buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"gpt-4", "gpt-4-turbo"}), + }), + provider: schemas.OpenAI, + model: "gpt-4", + shouldBeAllowed: true, + }, + { + name: "Model not in allowlist", + vk: buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"gpt-4", "gpt-4-turbo"}), + }), + provider: schemas.OpenAI, + model: "gpt-4o-mini", + shouldBeAllowed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + allowed := resolver.isModelAllowed(tt.vk, tt.provider, tt.model) + assert.Equal(t, tt.shouldBeAllowed, allowed) + }) + } +} + +// TestBudgetResolver_ContextPopulation tests context values are set correctly +func TestBudgetResolver_ContextPopulation(t *testing.T) { + logger := NewMockLogger() + vk := buildVirtualKey("vk1", "sk-bf-test", "Test VK", true) + customer := buildCustomer("cust1", "Customer 1", nil) + team := buildTeam("team1", "Team 1", nil) + team.CustomerID = &customer.ID + team.Customer = customer + vk.TeamID = &team.ID + vk.Team = team + vk.CustomerID = &customer.ID + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Teams: []configstoreTables.TableTeam{*team}, + Customers: []configstoreTables.TableCustomer{*customer}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateRequest(ctx, &EvaluationRequest{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + }) + + assert.Equal(t, DecisionAllow, result.Decision) + + // Check context was populated + vkID, _ := ctx.Value(schemas.BifrostContextKey("bf-governance-virtual-key-id")).(string) + teamID, _ := ctx.Value(schemas.BifrostContextKey("bf-governance-team-id")).(string) + customerID, _ := ctx.Value(schemas.BifrostContextKey("bf-governance-customer-id")).(string) + + assert.Equal(t, "vk1", vkID) + assert.Equal(t, "team1", teamID) + assert.Equal(t, "cust1", customerID) +} diff --git a/plugins/governance/store.go b/plugins/governance/store.go index f82240fb09..c3ef36fe31 100644 --- a/plugins/governance/store.go +++ b/plugins/governance/store.go @@ -4,6 +4,7 @@ package governance import ( "context" "fmt" + "strings" "sync" "time" @@ -11,16 +12,16 @@ import ( "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "gorm.io/gorm" - "gorm.io/gorm/clause" ) -// GovernanceStore provides in-memory cache for governance data with fast, non-blocking access -type GovernanceStore struct { +// LocalGovernanceStore provides in-memory cache for governance data with fast, non-blocking access +type LocalGovernanceStore struct { // Core data maps using sync.Map for lock-free reads virtualKeys sync.Map // string -> *VirtualKey (VK value -> VirtualKey with preloaded relationships) teams sync.Map // string -> *Team (Team ID -> Team) customers sync.Map // string -> *Customer (Customer ID -> Customer) budgets sync.Map // string -> *Budget (Budget ID -> Budget) + rateLimits sync.Map // string -> *RateLimit (RateLimit ID -> RateLimit) // Config store for refresh operations configStore configstore.ConfigStore @@ -29,9 +30,50 @@ type GovernanceStore struct { logger schemas.Logger } -// NewGovernanceStore creates a new in-memory governance store -func NewGovernanceStore(ctx context.Context, logger schemas.Logger, configStore configstore.ConfigStore, governanceConfig *configstore.GovernanceConfig) (*GovernanceStore, error) { - store := &GovernanceStore{ +type GovernanceData struct { + VirtualKeys map[string]*configstoreTables.TableVirtualKey `json:"virtual_keys"` + Teams map[string]*configstoreTables.TableTeam `json:"teams"` + Customers map[string]*configstoreTables.TableCustomer `json:"customers"` + Budgets map[string]*configstoreTables.TableBudget `json:"budgets"` + RateLimits map[string]*configstoreTables.TableRateLimit `json:"rate_limits"` +} + +// GovernanceStore defines the interface for governance data access and policy evaluation. +// +// Error semantics contract: +// - CheckRateLimit and CheckBudget return a non-nil error to indicate a governance/policy +// violation (not an infrastructure/operational failure). +// - Callers must treat any non-nil error from these methods as an explicit denial/violation +// decision rather than a retryable infrastructure error. +// - This contract ensures consistent behavior across implementations (e.g., in-memory, +// DB-backed) and prevents retry loops on policy violations. +type GovernanceStore interface { + GetGovernanceData() *GovernanceData + GetVirtualKey(vkValue string) (*configstoreTables.TableVirtualKey, bool) + CheckBudget(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest, baselines map[string]float64) error + CheckRateLimit(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, model string, requestID string, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) + UpdateBudgetUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, cost float64) error + UpdateRateLimitUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error + ResetExpiredRateLimitsInMemory(ctx context.Context) []*configstoreTables.TableRateLimit + ResetExpiredBudgetsInMemory(ctx context.Context) []*configstoreTables.TableBudget + ResetExpiredRateLimits(ctx context.Context, resetRateLimits []*configstoreTables.TableRateLimit) error + ResetExpiredBudgets(ctx context.Context, resetBudgets []*configstoreTables.TableBudget) error + DumpRateLimits(ctx context.Context, tokenBaselines map[string]int64, requestBaselines map[string]int64) error + DumpBudgets(ctx context.Context, baselines map[string]float64) error + CreateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey) + UpdateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey, budgetBaselines map[string]float64, rateLimitTokensBaselines map[string]int64, rateLimitRequestsBaselines map[string]int64) + DeleteVirtualKeyInMemory(vkID string) + CreateTeamInMemory(team *configstoreTables.TableTeam) + UpdateTeamInMemory(team *configstoreTables.TableTeam, budgetBaselines map[string]float64) + DeleteTeamInMemory(teamID string) + CreateCustomerInMemory(customer *configstoreTables.TableCustomer) + UpdateCustomerInMemory(customer *configstoreTables.TableCustomer, budgetBaselines map[string]float64) + DeleteCustomerInMemory(customerID string) +} + +// NewLocalGovernanceStore creates a new in-memory governance store +func NewLocalGovernanceStore(ctx context.Context, logger schemas.Logger, configStore configstore.ConfigStore, governanceConfig *configstore.GovernanceConfig) (*LocalGovernanceStore, error) { + store := &LocalGovernanceStore{ configStore: configStore, logger: logger, } @@ -51,8 +93,63 @@ func NewGovernanceStore(ctx context.Context, logger schemas.Logger, configStore return store, nil } +func (gs *LocalGovernanceStore) GetGovernanceData() *GovernanceData { + virtualKeys := make(map[string]*configstoreTables.TableVirtualKey) + gs.virtualKeys.Range(func(key, value interface{}) bool { + vk, ok := value.(*configstoreTables.TableVirtualKey) + if !ok || vk == nil { + return true // continue + } + virtualKeys[key.(string)] = vk + return true // continue iteration + }) + teams := make(map[string]*configstoreTables.TableTeam) + gs.teams.Range(func(key, value interface{}) bool { + team, ok := value.(*configstoreTables.TableTeam) + if !ok || team == nil { + return true // continue + } + teams[key.(string)] = team + return true // continue iteration + }) + customers := make(map[string]*configstoreTables.TableCustomer) + gs.customers.Range(func(key, value interface{}) bool { + customer, ok := value.(*configstoreTables.TableCustomer) + if !ok || customer == nil { + return true // continue + } + customers[key.(string)] = customer + return true // continue iteration + }) + budgets := make(map[string]*configstoreTables.TableBudget) + gs.budgets.Range(func(key, value interface{}) bool { + budget, ok := value.(*configstoreTables.TableBudget) + if !ok || budget == nil { + return true // continue + } + budgets[key.(string)] = budget + return true // continue iteration + }) + rateLimits := make(map[string]*configstoreTables.TableRateLimit) + gs.rateLimits.Range(func(key, value interface{}) bool { + rateLimit, ok := value.(*configstoreTables.TableRateLimit) + if !ok || rateLimit == nil { + return true // continue + } + rateLimits[key.(string)] = rateLimit + return true // continue iteration + }) + return &GovernanceData{ + VirtualKeys: virtualKeys, + Teams: teams, + Customers: customers, + Budgets: budgets, + RateLimits: rateLimits, + } +} + // GetVirtualKey retrieves a virtual key by its value (lock-free) with all relationships preloaded -func (gs *GovernanceStore) GetVirtualKey(vkValue string) (*configstoreTables.TableVirtualKey, bool) { +func (gs *LocalGovernanceStore) GetVirtualKey(vkValue string) (*configstoreTables.TableVirtualKey, bool) { value, exists := gs.virtualKeys.Load(vkValue) if !exists || value == nil { return nil, false @@ -65,308 +162,582 @@ func (gs *GovernanceStore) GetVirtualKey(vkValue string) (*configstoreTables.Tab return vk, true } -// GetAllBudgets returns all budgets (for background reset operations) -func (gs *GovernanceStore) GetAllBudgets() map[string]*configstoreTables.TableBudget { - result := make(map[string]*configstoreTables.TableBudget) - gs.budgets.Range(func(key, value interface{}) bool { - // Type-safe conversion - keyStr, keyOk := key.(string) - budget, budgetOk := value.(*configstoreTables.TableBudget) - - if keyOk && budgetOk && budget != nil { - result[keyStr] = budget - } - return true // continue iteration - }) - return result -} - // CheckBudget performs budget checking using in-memory store data (lock-free for high performance) -func (gs *GovernanceStore) CheckBudget(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) error { +func (gs *LocalGovernanceStore) CheckBudget(ctx context.Context, vk *configstoreTables.TableVirtualKey, request *EvaluationRequest, baselines map[string]float64) error { if vk == nil { return fmt.Errorf("virtual key cannot be nil") } + // This is to prevent nil pointer dereference + if baselines == nil { + baselines = map[string]float64{} + } + // Use helper to collect budgets and their names (lock-free) - budgetsToCheck, budgetNames := gs.collectBudgetsFromHierarchy(ctx, vk, provider) + budgetsToCheck, budgetNames := gs.collectBudgetsFromHierarchy(vk, request.Provider) + + gs.logger.Debug("LocalStore CheckBudget: Received %d baselines from remote nodes", len(baselines)) + for budgetID, baseline := range baselines { + gs.logger.Debug(" - Baseline for budget %s: %.4f", budgetID, baseline) + } // Check each budget in hierarchy order using in-memory data for i, budget := range budgetsToCheck { // Check if budget needs reset (in-memory check) if budget.ResetDuration != "" { if duration, err := configstoreTables.ParseDuration(budget.ResetDuration); err == nil { - if time.Since(budget.LastReset).Round(time.Millisecond) >= duration { + if time.Since(budget.LastReset) >= duration { // Budget expired but hasn't been reset yet - treat as reset // Note: actual reset will happen in post-hook via AtomicBudgetUpdate + gs.logger.Debug("LocalStore CheckBudget: Budget %s (%s) expired, skipping check", budget.ID, budgetNames[i]) continue // Skip budget check for expired budgets } } } - // Check if current usage exceeds budget limit - if budget.CurrentUsage > budget.MaxLimit { - return fmt.Errorf("%s budget exceeded: %.4f > %.4f dollars", - budgetNames[i], budget.CurrentUsage, budget.MaxLimit) + baseline, exists := baselines[budget.ID] + if !exists { + baseline = 0 + } + + gs.logger.Debug("LocalStore CheckBudget: Checking %s budget %s: local=%.4f, remote=%.4f, total=%.4f, limit=%.4f", + budgetNames[i], budget.ID, budget.CurrentUsage, baseline, budget.CurrentUsage+baseline, budget.MaxLimit) + + // Check if current usage (local + remote baseline) exceeds budget limit + if budget.CurrentUsage+baseline >= budget.MaxLimit { + gs.logger.Debug("LocalStore CheckBudget: Budget %s EXCEEDED", budget.ID) + return fmt.Errorf("%s budget exceeded: %.4f >= %.4f dollars", + budgetNames[i], budget.CurrentUsage+baseline, budget.MaxLimit) } } + gs.logger.Debug("LocalStore CheckBudget: All budgets passed") + return nil } -// UpdateBudget performs atomic budget updates across the hierarchy (both in memory and in database) -func (gs *GovernanceStore) UpdateBudget(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, cost float64) error { - if vk == nil { - return fmt.Errorf("virtual key cannot be nil") - } +// CheckRateLimit checks a single rate limit and returns evaluation result if violated (true if violated, false if not) +func (gs *LocalGovernanceStore) CheckRateLimit(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, model string, requestID string, tokensBaselines map[string]int64, requestsBaselines map[string]int64) (Decision, error) { + var violations []string - // Collect budget IDs using fast in-memory lookup instead of DB queries - budgetIDs := gs.collectBudgetIDsFromMemory(ctx, vk, provider) + // Collect rate limits and their names from the hierarchy + rateLimits, rateLimitNames := gs.collectRateLimitsFromHierarchy(vk, provider) - if gs.configStore == nil { - for _, budgetID := range budgetIDs { - // Update in-memory cache for next read (lock-free) - if cachedBudgetValue, exists := gs.budgets.Load(budgetID); exists && cachedBudgetValue != nil { - if cachedBudget, ok := cachedBudgetValue.(*configstoreTables.TableBudget); ok && cachedBudget != nil { - clone := *cachedBudget - clone.CurrentUsage += cost - gs.budgets.Store(budgetID, &clone) + // This is to prevent nil pointer dereference + if tokensBaselines == nil { + tokensBaselines = map[string]int64{} + } + if requestsBaselines == nil { + requestsBaselines = map[string]int64{} + } + + for i, rateLimit := range rateLimits { + // Determine token and request expiration independently + tokenExpired := false + requestExpired := false + + // Check if token reset duration is expired + if rateLimit.TokenResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { + if time.Since(rateLimit.TokenLastReset) >= duration { + // Token rate limit expired but hasn't been reset yet - skip token checks + // Note: actual reset will happen in post-hook via AtomicRateLimitUpdate + tokenExpired = true } } } - return nil - } - - return gs.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { - // budgetIDs already collected from in-memory data - no need to duplicate - - // Update each budget atomically - for _, budgetID := range budgetIDs { - var budget configstoreTables.TableBudget - if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).First(&budget, "id = ?", budgetID).Error; err != nil { - return fmt.Errorf("failed to lock budget %s: %w", budgetID, err) + // Check if request reset duration is expired + if rateLimit.RequestResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { + if time.Since(rateLimit.RequestLastReset) >= duration { + // Request rate limit expired but hasn't been reset yet - skip request checks + // Note: actual reset will happen in post-hook via AtomicRateLimitUpdate + requestExpired = true + } } + } - // Check if budget needs reset - if err := gs.resetBudgetIfNeeded(ctx, tx, &budget); err != nil { - return fmt.Errorf("failed to reset budget: %w", err) + tokensBaseline, exists := tokensBaselines[rateLimit.ID] + if !exists { + tokensBaseline = 0 + } + requestsBaseline, exists := requestsBaselines[rateLimit.ID] + if !exists { + requestsBaseline = 0 + } + + // Token limits - check if total usage (local + remote baseline) exceeds limit + // Only check if token limit is not expired + if !tokenExpired && rateLimit.TokenMaxLimit != nil && rateLimit.TokenCurrentUsage+tokensBaseline >= *rateLimit.TokenMaxLimit { + duration := "unknown" + if rateLimit.TokenResetDuration != nil { + duration = *rateLimit.TokenResetDuration } + violations = append(violations, fmt.Sprintf("token limit exceeded (%d/%d, resets every %s)", + rateLimit.TokenCurrentUsage+tokensBaseline, *rateLimit.TokenMaxLimit, duration)) + } - // Update usage - budget.CurrentUsage += cost - if err := gs.configStore.UpdateBudget(ctx, &budget, tx); err != nil { - return fmt.Errorf("failed to save budget %s: %w", budgetID, err) + // Request limits - check if total usage (local + remote baseline) exceeds limit + // Only check if request limit is not expired + if !requestExpired && rateLimit.RequestMaxLimit != nil && rateLimit.RequestCurrentUsage+requestsBaseline >= *rateLimit.RequestMaxLimit { + duration := "unknown" + if rateLimit.RequestResetDuration != nil { + duration = *rateLimit.RequestResetDuration } + violations = append(violations, fmt.Sprintf("request limit exceeded (%d/%d, resets every %s)", + rateLimit.RequestCurrentUsage+requestsBaseline, *rateLimit.RequestMaxLimit, duration)) + } - // Update in-memory cache for next read (lock-free) - if cachedBudgetValue, exists := gs.budgets.Load(budgetID); exists && cachedBudgetValue != nil { - if cachedBudget, ok := cachedBudgetValue.(*configstoreTables.TableBudget); ok && cachedBudget != nil { - clone := *cachedBudget - clone.CurrentUsage += cost - clone.LastReset = budget.LastReset - gs.budgets.Store(budgetID, &clone) + if len(violations) > 0 { + // Determine specific violation type + decision := DecisionRateLimited // Default to general rate limited decision + if len(violations) == 1 { + if strings.Contains(violations[0], "token") { + decision = DecisionTokenLimited // More specific violation type + } else if strings.Contains(violations[0], "request") { + decision = DecisionRequestLimited // More specific violation type } } + msg := strings.Join(violations, "; ") + return decision, fmt.Errorf("rate limit violated for %s: %s", rateLimitNames[i], msg) } + } - return nil - }) + return DecisionAllow, nil // No rate limit violations } -// UpdateRateLimitUsage updates rate limit counters for both provider-level and VK-level rate limits (lock-free) -func (gs *GovernanceStore) UpdateRateLimitUsage(ctx context.Context, vkValue string, provider string, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error { - if vkValue == "" { - return fmt.Errorf("virtual key value cannot be empty") +// UpdateBudgetUsageInMemory performs atomic budget updates across the hierarchy (both in memory and in database) +func (gs *LocalGovernanceStore) UpdateBudgetUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, cost float64) error { + if vk == nil { + return fmt.Errorf("virtual key cannot be nil") } - vkValue_, exists := gs.virtualKeys.Load(vkValue) - if !exists || vkValue_ == nil { - return fmt.Errorf("virtual key not found: %s", vkValue) + // Collect budget IDs using fast in-memory lookup instead of DB queries + budgetIDs := gs.collectBudgetIDsFromMemory(ctx, vk, provider) + now := time.Now() + for _, budgetID := range budgetIDs { + // Update in-memory cache for next read (lock-free) + if cachedBudgetValue, exists := gs.budgets.Load(budgetID); exists && cachedBudgetValue != nil { + if cachedBudget, ok := cachedBudgetValue.(*configstoreTables.TableBudget); ok && cachedBudget != nil { + // Clone FIRST to avoid race conditions + clone := *cachedBudget + oldUsage := clone.CurrentUsage + + // Check if budget needs reset (in-memory check) - operate on clone + if clone.ResetDuration != "" { + if duration, err := configstoreTables.ParseDuration(clone.ResetDuration); err == nil { + if now.Sub(clone.LastReset) >= duration { + clone.CurrentUsage = 0 + clone.LastReset = now + gs.logger.Debug("UpdateBudgetUsage: Budget %s was reset (expired, duration: %v)", budgetID, duration) + } + } + } + + // Update the clone + clone.CurrentUsage += cost + gs.budgets.Store(budgetID, &clone) + gs.logger.Debug("UpdateBudgetUsage: Updated budget %s: %.4f -> %.4f (added %.4f)", + budgetID, oldUsage, clone.CurrentUsage, cost) + } + } else { + gs.logger.Warn("UpdateBudgetUsage: Budget %s not found in local store", budgetID) + } } + return nil +} - vk, ok := vkValue_.(*configstoreTables.TableVirtualKey) - if !ok || vk == nil { - return fmt.Errorf("invalid virtual key type for: %s", vkValue) +// UpdateRateLimitUsageInMemory updates rate limit counters for both provider-level and VK-level rate limits (lock-free) +func (gs *LocalGovernanceStore) UpdateRateLimitUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error { + if vk == nil { + return fmt.Errorf("virtual key cannot be nil") } - var rateLimitsToUpdate []*configstoreTables.TableRateLimit + // Collect rate limit IDs using fast in-memory lookup instead of DB queries + rateLimitIDs := gs.collectRateLimitIDsFromMemory(vk, provider) + now := time.Now() - // First, update provider-level rate limits if they exist - if provider != "" && vk.ProviderConfigs != nil { - for _, pc := range vk.ProviderConfigs { - if pc.Provider == provider && pc.RateLimit != nil { - if gs.updateSingleRateLimit(pc.RateLimit, tokensUsed, shouldUpdateTokens, shouldUpdateRequests) { - rateLimitsToUpdate = append(rateLimitsToUpdate, pc.RateLimit) + for _, rateLimitID := range rateLimitIDs { + // Update in-memory cache for next read (lock-free) + if cachedRateLimitValue, exists := gs.rateLimits.Load(rateLimitID); exists && cachedRateLimitValue != nil { + if cachedRateLimit, ok := cachedRateLimitValue.(*configstoreTables.TableRateLimit); ok && cachedRateLimit != nil { + // Clone FIRST to avoid race conditions + clone := *cachedRateLimit + + // Check if rate limit needs reset (in-memory check) - operate on clone + if clone.TokenResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*clone.TokenResetDuration); err == nil { + if now.Sub(clone.TokenLastReset) >= duration { + clone.TokenCurrentUsage = 0 + clone.TokenLastReset = now + } + } + } + if clone.RequestResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*clone.RequestResetDuration); err == nil { + if now.Sub(clone.RequestLastReset) >= duration { + clone.RequestCurrentUsage = 0 + clone.RequestLastReset = now + } + } + } + + // Update the clone + if shouldUpdateTokens { + clone.TokenCurrentUsage += tokensUsed } - break + if shouldUpdateRequests { + clone.RequestCurrentUsage += 1 + } + gs.rateLimits.Store(rateLimitID, &clone) } } } + return nil +} - // Then, update VK-level rate limits if they exist - if vk.RateLimit != nil { - if gs.updateSingleRateLimit(vk.RateLimit, tokensUsed, shouldUpdateTokens, shouldUpdateRequests) { - rateLimitsToUpdate = append(rateLimitsToUpdate, vk.RateLimit) +// ResetExpiredBudgetsInMemory checks and resets budgets that have exceeded their reset duration (lock-free) +func (gs *LocalGovernanceStore) ResetExpiredBudgetsInMemory(ctx context.Context) []*configstoreTables.TableBudget { + now := time.Now() + var resetBudgets []*configstoreTables.TableBudget + + gs.budgets.Range(func(key, value interface{}) bool { + // Type-safe conversion + budget, ok := value.(*configstoreTables.TableBudget) + if !ok || budget == nil { + return true // continue } - } - // Save all updated rate limits to database - if len(rateLimitsToUpdate) > 0 && gs.configStore != nil { - if err := gs.configStore.UpdateRateLimits(ctx, rateLimitsToUpdate); err != nil { - return fmt.Errorf("failed to update rate limit usage: %w", err) + duration, err := configstoreTables.ParseDuration(budget.ResetDuration) + if err != nil { + gs.logger.Error("invalid budget reset duration %s: %v", budget.ResetDuration, err) + return true // continue } - } - return nil + if now.Sub(budget.LastReset) >= duration { + // Create a copy to avoid data race (sync.Map is concurrent-safe for reads/writes but not mutations) + copiedBudget := *budget + oldUsage := copiedBudget.CurrentUsage + copiedBudget.CurrentUsage = 0 + copiedBudget.LastReset = now + copiedBudget.LastDBUsage = 0 + + // Atomically replace the entry using the original key + gs.budgets.Store(key, &copiedBudget) + resetBudgets = append(resetBudgets, &copiedBudget) + + // Update all VKs, teams, customers, and provider configs that reference this budget + gs.updateBudgetReferences(&copiedBudget) + + gs.logger.Debug(fmt.Sprintf("Reset budget %s (was %.2f, reset to 0)", + copiedBudget.ID, oldUsage)) + } + return true // continue + }) + + return resetBudgets } -// updateSingleRateLimit updates a single rate limit's counters and returns true if any changes were made -func (gs *GovernanceStore) updateSingleRateLimit(rateLimit *configstoreTables.TableRateLimit, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) bool { +// ResetExpiredRateLimitsInMemory performs background reset of expired rate limits for both provider-level and VK-level (lock-free) +func (gs *LocalGovernanceStore) ResetExpiredRateLimitsInMemory(ctx context.Context) []*configstoreTables.TableRateLimit { now := time.Now() - updated := false + var resetRateLimits []*configstoreTables.TableRateLimit - // Check and reset token counter if needed - if rateLimit.TokenResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { - if now.Sub(rateLimit.TokenLastReset) >= duration { - rateLimit.TokenCurrentUsage = 0 - rateLimit.TokenLastReset = now - updated = true - } + gs.rateLimits.Range(func(key, value interface{}) bool { + // Type-safe conversion + rateLimit, ok := value.(*configstoreTables.TableRateLimit) + if !ok || rateLimit == nil { + return true // continue } - } - // Check and reset request counter if needed - if rateLimit.RequestResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { - if now.Sub(rateLimit.RequestLastReset) >= duration { - rateLimit.RequestCurrentUsage = 0 - rateLimit.RequestLastReset = now - updated = true + needsReset := false + // Check if token reset is needed + if rateLimit.TokenResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { + if now.Sub(rateLimit.TokenLastReset) >= duration { + needsReset = true + } + } + } + // Check if request reset is needed + if rateLimit.RequestResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { + if now.Sub(rateLimit.RequestLastReset) >= duration { + needsReset = true + } } } - } - // Update usage counters based on flags - if shouldUpdateTokens && tokensUsed > 0 { - rateLimit.TokenCurrentUsage += tokensUsed - updated = true - } + if needsReset { + // Create a copy to avoid data race (sync.Map is concurrent-safe for reads/writes but not mutations) + copiedRateLimit := *rateLimit + + // Reset token limits if expired + if copiedRateLimit.TokenResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*copiedRateLimit.TokenResetDuration); err == nil { + if now.Sub(copiedRateLimit.TokenLastReset) >= duration { + copiedRateLimit.TokenCurrentUsage = 0 + copiedRateLimit.TokenLastReset = now + copiedRateLimit.LastDBTokenUsage = 0 + } + } + } + // Reset request limits if expired + if copiedRateLimit.RequestResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*copiedRateLimit.RequestResetDuration); err == nil { + if now.Sub(copiedRateLimit.RequestLastReset) >= duration { + copiedRateLimit.RequestCurrentUsage = 0 + copiedRateLimit.RequestLastReset = now + copiedRateLimit.LastDBRequestUsage = 0 + } + } + } - if shouldUpdateRequests { - rateLimit.RequestCurrentUsage += 1 - updated = true - } + // Atomically replace the entry using the original key + gs.rateLimits.Store(key, &copiedRateLimit) + resetRateLimits = append(resetRateLimits, &copiedRateLimit) - return updated -} + // Update all VKs and provider configs that reference this rate limit + gs.updateRateLimitReferences(&copiedRateLimit) + } + return true // continue + }) -// checkAndResetSingleRateLimit checks and resets a single rate limit's counters if expired -func (gs *GovernanceStore) checkAndResetSingleRateLimit(ctx context.Context, rateLimit *configstoreTables.TableRateLimit, now time.Time) bool { - updated := false + return resetRateLimits +} - // Check and reset token counter if needed - if rateLimit.TokenResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration); err == nil { - if now.Sub(rateLimit.TokenLastReset).Round(time.Millisecond) >= duration { - rateLimit.TokenCurrentUsage = 0 - rateLimit.TokenLastReset = now - updated = true +// ResetExpiredBudgets checks and resets budgets that have exceeded their reset duration in database +func (gs *LocalGovernanceStore) ResetExpiredBudgets(ctx context.Context, resetBudgets []*configstoreTables.TableBudget) error { + // Persist to database if any resets occurred using direct UPDATE to avoid overwriting config fields + if len(resetBudgets) > 0 && gs.configStore != nil { + if err := gs.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { + for _, budget := range resetBudgets { + // Direct UPDATE only resets current_usage and last_reset + // This prevents overwriting max_limit or reset_duration that may have been changed by other nodes/requests + result := tx.WithContext(ctx). + Session(&gorm.Session{SkipHooks: true}). + Model(&configstoreTables.TableBudget{}). + Where("id = ?", budget.ID). + Updates(map[string]interface{}{ + "current_usage": budget.CurrentUsage, + "last_reset": budget.LastReset, + }) + + if result.Error != nil { + return fmt.Errorf("failed to reset budget %s: %w", budget.ID, result.Error) + } } + return nil + }); err != nil { + return fmt.Errorf("failed to persist budget resets to database: %w", err) } } - // Check and reset request counter if needed - if rateLimit.RequestResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration); err == nil { - if now.Sub(rateLimit.RequestLastReset).Round(time.Millisecond) >= duration { - rateLimit.RequestCurrentUsage = 0 - rateLimit.RequestLastReset = now - updated = true + return nil +} + +// ResetExpiredRateLimits performs background reset of expired rate limits for both provider-level and VK-level in database +func (gs *LocalGovernanceStore) ResetExpiredRateLimits(ctx context.Context, resetRateLimits []*configstoreTables.TableRateLimit) error { + if len(resetRateLimits) > 0 && gs.configStore != nil { + if err := gs.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { + for _, rateLimit := range resetRateLimits { + // Build update map with only the fields that were reset + updates := make(map[string]interface{}) + + // Check which fields were reset by comparing with current values + if rateLimit.TokenCurrentUsage == 0 && rateLimit.TokenResetDuration != nil { + updates["token_current_usage"] = 0 + updates["token_last_reset"] = rateLimit.TokenLastReset + } + if rateLimit.RequestCurrentUsage == 0 && rateLimit.RequestResetDuration != nil { + updates["request_current_usage"] = 0 + updates["request_last_reset"] = rateLimit.RequestLastReset + } + + if len(updates) > 0 { + // Direct UPDATE only resets usage and last_reset fields + // This prevents overwriting max_limit or reset_duration that may have been changed by other nodes/requests + result := tx.WithContext(ctx). + Session(&gorm.Session{SkipHooks: true}). + Model(&configstoreTables.TableRateLimit{}). + Where("id = ?", rateLimit.ID). + Updates(updates) + + if result.Error != nil { + return fmt.Errorf("failed to reset rate limit %s: %w", rateLimit.ID, result.Error) + } + } } + return nil + }); err != nil { + return fmt.Errorf("failed to persist rate limit resets to database: %w", err) } } - - return updated + return nil } -// ResetExpiredRateLimits performs background reset of expired rate limits for both provider-level and VK-level (lock-free) -func (gs *GovernanceStore) ResetExpiredRateLimits(ctx context.Context) error { - now := time.Now() - var resetRateLimits []*configstoreTables.TableRateLimit +// DumpRateLimits dumps all rate limits to the database +func (gs *LocalGovernanceStore) DumpRateLimits(ctx context.Context, tokenBaselines map[string]int64, requestBaselines map[string]int64) error { + if gs.configStore == nil { + return nil + } + // This is to prevent nil pointer dereference + if tokenBaselines == nil { + tokenBaselines = map[string]int64{} + } + if requestBaselines == nil { + requestBaselines = map[string]int64{} + } + + // Collect unique rate limit IDs from virtual keys + rateLimitIDs := make(map[string]bool) gs.virtualKeys.Range(func(key, value interface{}) bool { - // Type-safe conversion vk, ok := value.(*configstoreTables.TableVirtualKey) if !ok || vk == nil { return true // continue } - - // Check provider-level rate limits + if vk.RateLimitID != nil { + rateLimitIDs[*vk.RateLimitID] = true + } if vk.ProviderConfigs != nil { for _, pc := range vk.ProviderConfigs { - if pc.RateLimit != nil { - if gs.checkAndResetSingleRateLimit(ctx, pc.RateLimit, now) { - resetRateLimits = append(resetRateLimits, pc.RateLimit) - } + if pc.RateLimitID != nil { + rateLimitIDs[*pc.RateLimitID] = true } } } - - // Check VK-level rate limits - if vk.RateLimit != nil { - if gs.checkAndResetSingleRateLimit(ctx, vk.RateLimit, now) { - resetRateLimits = append(resetRateLimits, vk.RateLimit) - } - } - return true // continue }) - // Persist reset rate limits to database - if len(resetRateLimits) > 0 && gs.configStore != nil { - if err := gs.configStore.UpdateRateLimits(ctx, resetRateLimits); err != nil { - return fmt.Errorf("failed to persist rate limit resets to database: %w", err) + // Prepare rate limit usage updates with baselines + type rateLimitUpdate struct { + ID string + TokenCurrentUsage int64 + RequestCurrentUsage int64 + } + var rateLimitUpdates []rateLimitUpdate + for rateLimitID := range rateLimitIDs { + if rateLimitValue, exists := gs.rateLimits.Load(rateLimitID); exists && rateLimitValue != nil { + if rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit); ok && rateLimit != nil { + update := rateLimitUpdate{ + ID: rateLimit.ID, + TokenCurrentUsage: rateLimit.TokenCurrentUsage, + RequestCurrentUsage: rateLimit.RequestCurrentUsage, + } + if tokenBaseline, exists := tokenBaselines[rateLimit.ID]; exists { + update.TokenCurrentUsage += tokenBaseline + } + if requestBaseline, exists := requestBaselines[rateLimit.ID]; exists { + update.RequestCurrentUsage += requestBaseline + } + rateLimitUpdates = append(rateLimitUpdates, update) + } } } + // Save all updated rate limits to database using direct UPDATE to avoid overwriting config fields + if len(rateLimitUpdates) > 0 && gs.configStore != nil { + if err := gs.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { + for _, update := range rateLimitUpdates { + // Direct UPDATE only updates usage fields + // This prevents overwriting max_limit or reset_duration that may have been changed by other nodes/requests + result := tx.WithContext(ctx). + Session(&gorm.Session{SkipHooks: true}). + Model(&configstoreTables.TableRateLimit{}). + Where("id = ?", update.ID). + Updates(map[string]interface{}{ + "token_current_usage": update.TokenCurrentUsage, + "request_current_usage": update.RequestCurrentUsage, + }) + + if result.Error != nil { + return fmt.Errorf("failed to dump rate limit %s: %w", update.ID, result.Error) + } + } + return nil + }); err != nil { + // Check if error is a deadlock (SQLSTATE 40P01 for PostgreSQL, 1213 for MySQL) + errStr := err.Error() + isDeadlock := strings.Contains(errStr, "deadlock") || + strings.Contains(errStr, "40P01") || + strings.Contains(errStr, "1213") + + if isDeadlock { + // Deadlock means another node is updating the same rows - this is fine! + // Our usage data will be synced via gossip and written in the next dump cycle + gs.logger.Debug("Rate limit dump encountered deadlock (another node is updating) - will retry next cycle") + return nil // Not a real error in multi-node setup + } + return fmt.Errorf("failed to dump rate limits to database: %w", err) + } + } return nil } -// ResetExpiredBudgets checks and resets budgets that have exceeded their reset duration (lock-free) -func (gs *GovernanceStore) ResetExpiredBudgets(ctx context.Context) error { - now := time.Now() - var resetBudgets []*configstoreTables.TableBudget +// DumpBudgets dumps all budgets to the database +func (gs *LocalGovernanceStore) DumpBudgets(ctx context.Context, baselines map[string]float64) error { + if gs.configStore == nil { + return nil + } + + // This is to prevent nil pointer dereference + if baselines == nil { + baselines = map[string]float64{} + } + + budgets := make(map[string]*configstoreTables.TableBudget) gs.budgets.Range(func(key, value interface{}) bool { // Type-safe conversion - budget, ok := value.(*configstoreTables.TableBudget) - if !ok || budget == nil { - return true // continue - } + keyStr, keyOk := key.(string) + budget, budgetOk := value.(*configstoreTables.TableBudget) - duration, err := configstoreTables.ParseDuration(budget.ResetDuration) - if err != nil { - gs.logger.Error("invalid budget reset duration %s: %w", budget.ResetDuration, err) - return true // continue + if keyOk && budgetOk && budget != nil { + budgets[keyStr] = budget // Store budget by ID } + return true // continue iteration + }) - if now.Sub(budget.LastReset) >= duration { - oldUsage := budget.CurrentUsage - budget.CurrentUsage = 0 - budget.LastReset = now - resetBudgets = append(resetBudgets, budget) + if len(budgets) > 0 && gs.configStore != nil { + if err := gs.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { + // Update each budget atomically using direct UPDATE to avoid deadlocks + // (SELECT + Save pattern causes deadlocks when multiple instances run concurrently) + for _, inMemoryBudget := range budgets { + // Calculate the new usage value + newUsage := inMemoryBudget.CurrentUsage + if baseline, exists := baselines[inMemoryBudget.ID]; exists { + newUsage += baseline + } - gs.logger.Debug(fmt.Sprintf("Reset budget %s (was %.2f, reset to 0)", - budget.ID, oldUsage)) - } - return true // continue - }) + // Direct UPDATE avoids read-then-write lock escalation that causes deadlocks + // Use Session with SkipHooks to avoid triggering BeforeSave hook validation + result := tx.WithContext(ctx). + Session(&gorm.Session{SkipHooks: true}). + Model(&configstoreTables.TableBudget{}). + Where("id = ?", inMemoryBudget.ID). + Update("current_usage", newUsage) - // Persist to database if any resets occurred - if len(resetBudgets) > 0 && gs.configStore != nil { - if err := gs.configStore.UpdateBudgets(ctx, resetBudgets); err != nil { - return fmt.Errorf("failed to persist budget resets to database: %w", err) + if result.Error != nil { + return fmt.Errorf("failed to update budget %s: %w", inMemoryBudget.ID, result.Error) + } + } + return nil + }); err != nil { + // Check if error is a deadlock (SQLSTATE 40P01 for PostgreSQL, 1213 for MySQL) + errStr := err.Error() + isDeadlock := strings.Contains(errStr, "deadlock") || + strings.Contains(errStr, "40P01") || + strings.Contains(errStr, "1213") + + if isDeadlock { + // Deadlock means another node is updating the same rows - this is fine! + // Our usage data will be synced via gossip and written in the next dump cycle + gs.logger.Debug("Budget dump encountered deadlock (another node is updating) - will retry next cycle") + return nil // Not a real error in multi-node setup + } + return fmt.Errorf("failed to dump budgets to database: %w", err) } } @@ -376,7 +747,7 @@ func (gs *GovernanceStore) ResetExpiredBudgets(ctx context.Context) error { // DATABASE METHODS // loadFromDatabase loads all governance data from the database into memory -func (gs *GovernanceStore) loadFromDatabase(ctx context.Context) error { +func (gs *LocalGovernanceStore) loadFromDatabase(ctx context.Context) error { // Load customers with their budgets customers, err := gs.configStore.GetCustomers(ctx) if err != nil { @@ -401,14 +772,20 @@ func (gs *GovernanceStore) loadFromDatabase(ctx context.Context) error { return fmt.Errorf("failed to load budgets: %w", err) } + // Load rate limits + rateLimits, err := gs.configStore.GetRateLimits(ctx) + if err != nil { + return fmt.Errorf("failed to load rate limits: %w", err) + } + // Rebuild in-memory structures (lock-free) - gs.rebuildInMemoryStructures(ctx, customers, teams, virtualKeys, budgets) + gs.rebuildInMemoryStructures(ctx, customers, teams, virtualKeys, budgets, rateLimits) return nil } // loadFromConfigMemory loads all governance data from the config's memory into store's memory -func (gs *GovernanceStore) loadFromConfigMemory(ctx context.Context, config *configstore.GovernanceConfig) error { +func (gs *LocalGovernanceStore) loadFromConfigMemory(ctx context.Context, config *configstore.GovernanceConfig) error { if config == nil { return fmt.Errorf("governance config is nil") } @@ -456,22 +833,50 @@ func (gs *GovernanceStore) loadFromConfigMemory(ctx context.Context, config *con } } + // Populate provider config relationships with budgets and rate limits + if vk.ProviderConfigs != nil { + for j := range vk.ProviderConfigs { + pc := &vk.ProviderConfigs[j] + + // Populate budget + if pc.BudgetID != nil { + for k := range budgets { + if budgets[k].ID == *pc.BudgetID { + pc.Budget = &budgets[k] + break + } + } + } + + // Populate rate limit + if pc.RateLimitID != nil { + for k := range rateLimits { + if rateLimits[k].ID == *pc.RateLimitID { + pc.RateLimit = &rateLimits[k] + break + } + } + } + } + } + virtualKeys[i] = *vk } // Rebuild in-memory structures (lock-free) - gs.rebuildInMemoryStructures(ctx, customers, teams, virtualKeys, budgets) + gs.rebuildInMemoryStructures(ctx, customers, teams, virtualKeys, budgets, rateLimits) return nil } // rebuildInMemoryStructures rebuilds all in-memory data structures (lock-free) -func (gs *GovernanceStore) rebuildInMemoryStructures(ctx context.Context, customers []configstoreTables.TableCustomer, teams []configstoreTables.TableTeam, virtualKeys []configstoreTables.TableVirtualKey, budgets []configstoreTables.TableBudget) { +func (gs *LocalGovernanceStore) rebuildInMemoryStructures(ctx context.Context, customers []configstoreTables.TableCustomer, teams []configstoreTables.TableTeam, virtualKeys []configstoreTables.TableVirtualKey, budgets []configstoreTables.TableBudget, rateLimits []configstoreTables.TableRateLimit) { // Clear existing data by creating new sync.Maps gs.virtualKeys = sync.Map{} gs.teams = sync.Map{} gs.customers = sync.Map{} gs.budgets = sync.Map{} + gs.rateLimits = sync.Map{} // Build customers map for i := range customers { @@ -491,6 +896,12 @@ func (gs *GovernanceStore) rebuildInMemoryStructures(ctx context.Context, custom gs.budgets.Store(budget.ID, budget) } + // Build rate limits map + for i := range rateLimits { + rateLimit := &rateLimits[i] + gs.rateLimits.Store(rateLimit.ID, rateLimit) + } + // Build virtual keys map and track active VKs for i := range virtualKeys { vk := &virtualKeys[i] @@ -500,8 +911,40 @@ func (gs *GovernanceStore) rebuildInMemoryStructures(ctx context.Context, custom // UTILITY FUNCTIONS +// collectRateLimitsFromHierarchy collects rate limits and their metadata from the hierarchy (Provider Configs → VK) +func (gs *LocalGovernanceStore) collectRateLimitsFromHierarchy(vk *configstoreTables.TableVirtualKey, requestedProvider schemas.ModelProvider) ([]*configstoreTables.TableRateLimit, []string) { + if vk == nil { + return nil, nil + } + + var rateLimits []*configstoreTables.TableRateLimit + var rateLimitNames []string + + for _, pc := range vk.ProviderConfigs { + if pc.RateLimitID != nil && pc.Provider == string(requestedProvider) { + if rateLimitValue, exists := gs.rateLimits.Load(*pc.RateLimitID); exists && rateLimitValue != nil { + if rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit); ok && rateLimit != nil { + rateLimits = append(rateLimits, rateLimit) + rateLimitNames = append(rateLimitNames, pc.Provider) + } + } + } + } + + if vk.RateLimitID != nil { + if rateLimitValue, exists := gs.rateLimits.Load(*vk.RateLimitID); exists && rateLimitValue != nil { + if rateLimit, ok := rateLimitValue.(*configstoreTables.TableRateLimit); ok && rateLimit != nil { + rateLimits = append(rateLimits, rateLimit) + rateLimitNames = append(rateLimitNames, "VK") + } + } + } + + return rateLimits, rateLimitNames +} + // collectBudgetsFromHierarchy collects budgets and their metadata from the hierarchy (Provider Configs → VK → Team → Customer) -func (gs *GovernanceStore) collectBudgetsFromHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) ([]*configstoreTables.TableBudget, []string) { +func (gs *LocalGovernanceStore) collectBudgetsFromHierarchy(vk *configstoreTables.TableVirtualKey, requestedProvider schemas.ModelProvider) ([]*configstoreTables.TableBudget, []string) { if vk == nil { return nil, nil } @@ -511,7 +954,7 @@ func (gs *GovernanceStore) collectBudgetsFromHierarchy(ctx context.Context, vk * // Collect all budgets in hierarchy order using lock-free sync.Map access (Provider Configs → VK → Team → Customer) for _, pc := range vk.ProviderConfigs { - if pc.BudgetID != nil && pc.Provider == string(provider) { + if pc.BudgetID != nil && pc.Provider == string(requestedProvider) { if budgetValue, exists := gs.budgets.Load(*pc.BudgetID); exists && budgetValue != nil { if budget, ok := budgetValue.(*configstoreTables.TableBudget); ok && budget != nil { budgets = append(budgets, budget) @@ -580,8 +1023,8 @@ func (gs *GovernanceStore) collectBudgetsFromHierarchy(ctx context.Context, vk * } // collectBudgetIDsFromMemory collects budget IDs from in-memory store data (lock-free) -func (gs *GovernanceStore) collectBudgetIDsFromMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) []string { - budgets, _ := gs.collectBudgetsFromHierarchy(ctx, vk, provider) +func (gs *LocalGovernanceStore) collectBudgetIDsFromMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) []string { + budgets, _ := gs.collectBudgetsFromHierarchy(vk, provider) budgetIDs := make([]string, len(budgets)) for i, budget := range budgets { @@ -591,49 +1034,195 @@ func (gs *GovernanceStore) collectBudgetIDsFromMemory(ctx context.Context, vk *c return budgetIDs } -// resetBudgetIfNeeded checks and resets budget within a transaction -func (gs *GovernanceStore) resetBudgetIfNeeded(ctx context.Context, tx *gorm.DB, budget *configstoreTables.TableBudget) error { - duration, err := configstoreTables.ParseDuration(budget.ResetDuration) - if err != nil { - return fmt.Errorf("invalid reset duration %s: %w", budget.ResetDuration, err) - } - - now := time.Now() - if now.Sub(budget.LastReset) >= duration { - budget.CurrentUsage = 0 - budget.LastReset = now +// collectRateLimitIDsFromMemory collects rate limit IDs from in-memory store data (lock-free) +func (gs *LocalGovernanceStore) collectRateLimitIDsFromMemory(vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) []string { + rateLimits, _ := gs.collectRateLimitsFromHierarchy(vk, provider) - if gs.configStore != nil { - // Save reset to database - if err := gs.configStore.UpdateBudget(ctx, budget, tx); err != nil { - return fmt.Errorf("failed to save budget reset: %w", err) - } - } + rateLimitIDs := make([]string, len(rateLimits)) + for i, rateLimit := range rateLimits { + rateLimitIDs[i] = rateLimit.ID } - return nil + return rateLimitIDs } // PUBLIC API METHODS // CreateVirtualKeyInMemory adds a new virtual key to the in-memory store (lock-free) -func (gs *GovernanceStore) CreateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey) { // with rateLimit preloaded +func (gs *LocalGovernanceStore) CreateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey) { if vk == nil { return // Nothing to create } + + // Create associated budget if exists + if vk.Budget != nil { + gs.budgets.Store(vk.Budget.ID, vk.Budget) + } + + // Create associated rate limit if exists + if vk.RateLimit != nil { + gs.rateLimits.Store(vk.RateLimit.ID, vk.RateLimit) + } + + // Create provider config budgets and rate limits if they exist + if vk.ProviderConfigs != nil { + for _, pc := range vk.ProviderConfigs { + if pc.Budget != nil { + gs.budgets.Store(pc.Budget.ID, pc.Budget) + } + if pc.RateLimit != nil { + gs.rateLimits.Store(pc.RateLimit.ID, pc.RateLimit) + } + } + } + gs.virtualKeys.Store(vk.Value, vk) } // UpdateVirtualKeyInMemory updates an existing virtual key in the in-memory store (lock-free) -func (gs *GovernanceStore) UpdateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey) { // with rateLimit preloaded +func (gs *LocalGovernanceStore) UpdateVirtualKeyInMemory(vk *configstoreTables.TableVirtualKey, budgetBaselines map[string]float64, rateLimitTokensBaselines map[string]int64, rateLimitRequestsBaselines map[string]int64) { if vk == nil { return // Nothing to update } - gs.virtualKeys.Store(vk.Value, vk) + if budgetBaselines == nil { + budgetBaselines = make(map[string]float64) + } + if rateLimitTokensBaselines == nil { + rateLimitTokensBaselines = make(map[string]int64) + } + if rateLimitRequestsBaselines == nil { + rateLimitRequestsBaselines = make(map[string]int64) + } + // Do not update the current usage of the rate limit, as it will be updated by the usage tracker. + // But update if max limit or reset duration changes. + if existingVKValue, exists := gs.virtualKeys.Load(vk.Value); exists && existingVKValue != nil { + existingVK, ok := existingVKValue.(*configstoreTables.TableVirtualKey) + if !ok || existingVK == nil { + return // Nothing to update + } + // Create clone to avoid modifying the original + clone := *vk + // Update Budget using checkAndUpdateBudget logic (preserve usage unless currentUsage+baseline > newMaxLimit) + if clone.Budget != nil { + // Get existing budget from gs.budgets (NOT from VK.Budget which may be stale) + var existingBudget *configstoreTables.TableBudget + if existingBudgetValue, exists := gs.budgets.Load(clone.Budget.ID); exists && existingBudgetValue != nil { + if eb, ok := existingBudgetValue.(*configstoreTables.TableBudget); ok && eb != nil { + existingBudget = eb + } + } + budgetBaseline, exists := budgetBaselines[clone.Budget.ID] + if !exists { + budgetBaseline = 0.0 + } + clone.Budget = checkAndUpdateBudget(clone.Budget, existingBudget, budgetBaseline) + // Update the budget in the main budgets sync.Map + if clone.Budget != nil { + gs.budgets.Store(clone.Budget.ID, clone.Budget) + } + } else if existingVK.Budget != nil { + // Budget was removed from the virtual key, delete it from memory + gs.budgets.Delete(existingVK.Budget.ID) + } + if clone.RateLimit != nil { + // Get existing rate limit from gs.rateLimits (NOT from VK.RateLimit which may be stale) + var existingRateLimit *configstoreTables.TableRateLimit + if existingRateLimitValue, exists := gs.rateLimits.Load(clone.RateLimit.ID); exists && existingRateLimitValue != nil { + if erl, ok := existingRateLimitValue.(*configstoreTables.TableRateLimit); ok && erl != nil { + existingRateLimit = erl + } + } + tokenBaseline, exists := rateLimitTokensBaselines[clone.RateLimit.ID] + if !exists { + tokenBaseline = 0 + } + requestBaseline, exists := rateLimitRequestsBaselines[clone.RateLimit.ID] + if !exists { + requestBaseline = 0 + } + clone.RateLimit = checkAndUpdateRateLimit(clone.RateLimit, existingRateLimit, tokenBaseline, requestBaseline) + // Update the rate limit in the main rateLimits sync.Map + if clone.RateLimit != nil { + gs.rateLimits.Store(clone.RateLimit.ID, clone.RateLimit) + } + } else if existingVK.RateLimit != nil { + // Rate limit was removed from the virtual key, delete it from memory + gs.rateLimits.Delete(existingVK.RateLimit.ID) + } + if clone.ProviderConfigs != nil { + // Create a map of existing provider configs by ID for fast lookup + existingProviderConfigs := make(map[uint]configstoreTables.TableVirtualKeyProviderConfig) + if existingVK.ProviderConfigs != nil { + for _, existingPC := range existingVK.ProviderConfigs { + existingProviderConfigs[existingPC.ID] = existingPC + } + } + + // Process each new/updated provider config + for i, pc := range clone.ProviderConfigs { + if pc.RateLimit != nil { + // Get existing rate limit from gs.rateLimits (NOT from provider config which may be stale) + var existingProviderRateLimit *configstoreTables.TableRateLimit + if existingRateLimitValue, exists := gs.rateLimits.Load(pc.RateLimit.ID); exists && existingRateLimitValue != nil { + if erl, ok := existingRateLimitValue.(*configstoreTables.TableRateLimit); ok && erl != nil { + existingProviderRateLimit = erl + } + } + tokenBaseline, exists := rateLimitTokensBaselines[pc.RateLimit.ID] + if !exists { + tokenBaseline = 0 + } + requestBaseline, exists := rateLimitRequestsBaselines[pc.RateLimit.ID] + if !exists { + requestBaseline = 0 + } + clone.ProviderConfigs[i].RateLimit = checkAndUpdateRateLimit(pc.RateLimit, existingProviderRateLimit, tokenBaseline, requestBaseline) + // Also update the rate limit in the main rateLimits sync.Map + if clone.ProviderConfigs[i].RateLimit != nil { + gs.rateLimits.Store(clone.ProviderConfigs[i].RateLimit.ID, clone.ProviderConfigs[i].RateLimit) + } + } else { + // Rate limit was removed from provider config, delete it from memory if it existed + if existingPC, exists := existingProviderConfigs[pc.ID]; exists && existingPC.RateLimit != nil { + gs.rateLimits.Delete(existingPC.RateLimit.ID) + clone.ProviderConfigs[i].RateLimit = nil + } + } + // Update Budget for provider config (preserve usage unless currentUsage+baseline > newMaxLimit) + if pc.Budget != nil { + // Get existing budget from gs.budgets (NOT from provider config which may be stale) + var existingProviderBudget *configstoreTables.TableBudget + if existingBudgetValue, exists := gs.budgets.Load(pc.Budget.ID); exists && existingBudgetValue != nil { + if eb, ok := existingBudgetValue.(*configstoreTables.TableBudget); ok && eb != nil { + existingProviderBudget = eb + } + } + budgetBaseline, exists := budgetBaselines[pc.Budget.ID] + if !exists { + budgetBaseline = 0.0 + } + clone.ProviderConfigs[i].Budget = checkAndUpdateBudget(pc.Budget, existingProviderBudget, budgetBaseline) + // Also update the budget in the main budgets sync.Map + if clone.ProviderConfigs[i].Budget != nil { + gs.budgets.Store(clone.ProviderConfigs[i].Budget.ID, clone.ProviderConfigs[i].Budget) + } + } else { + // Budget was removed from provider config, delete it from memory if it existed + if existingPC, exists := existingProviderConfigs[pc.ID]; exists && existingPC.Budget != nil { + gs.budgets.Delete(existingPC.Budget.ID) + clone.ProviderConfigs[i].Budget = nil + } + } + } + } + gs.virtualKeys.Store(vk.Value, &clone) + } else { + gs.CreateVirtualKeyInMemory(vk) + } } // DeleteVirtualKeyInMemory removes a virtual key from the in-memory store -func (gs *GovernanceStore) DeleteVirtualKeyInMemory(vkID string) { +func (gs *LocalGovernanceStore) DeleteVirtualKeyInMemory(vkID string) { if vkID == "" { return // Nothing to delete } @@ -647,6 +1236,28 @@ func (gs *GovernanceStore) DeleteVirtualKeyInMemory(vkID string) { } if vk.ID == vkID { + // Delete associated budget if exists + if vk.BudgetID != nil { + gs.budgets.Delete(*vk.BudgetID) + } + + // Delete associated rate limit if exists + if vk.RateLimitID != nil { + gs.rateLimits.Delete(*vk.RateLimitID) + } + + // Delete provider config budgets and rate limits + if vk.ProviderConfigs != nil { + for _, pc := range vk.ProviderConfigs { + if pc.BudgetID != nil { + gs.budgets.Delete(*pc.BudgetID) + } + if pc.RateLimitID != nil { + gs.rateLimits.Delete(*pc.RateLimitID) + } + } + } + gs.virtualKeys.Delete(key) return false // stop iteration } @@ -655,74 +1266,403 @@ func (gs *GovernanceStore) DeleteVirtualKeyInMemory(vkID string) { } // CreateTeamInMemory adds a new team to the in-memory store (lock-free) -func (gs *GovernanceStore) CreateTeamInMemory(team *configstoreTables.TableTeam) { +func (gs *LocalGovernanceStore) CreateTeamInMemory(team *configstoreTables.TableTeam) { if team == nil { return // Nothing to create } + + // Create associated budget if exists + if team.Budget != nil { + gs.budgets.Store(team.Budget.ID, team.Budget) + } + gs.teams.Store(team.ID, team) } // UpdateTeamInMemory updates an existing team in the in-memory store (lock-free) -func (gs *GovernanceStore) UpdateTeamInMemory(team *configstoreTables.TableTeam) { +func (gs *LocalGovernanceStore) UpdateTeamInMemory(team *configstoreTables.TableTeam, budgetBaselines map[string]float64) { if team == nil { return // Nothing to update } - gs.teams.Store(team.ID, team) + if budgetBaselines == nil { + budgetBaselines = make(map[string]float64) + } + + // Check if there's an existing team to get current budget state + if existingTeamValue, exists := gs.teams.Load(team.ID); exists && existingTeamValue != nil { + existingTeam, ok := existingTeamValue.(*configstoreTables.TableTeam) + if !ok || existingTeam == nil { + return // Nothing to update + } + // Create clone to avoid modifying the original + clone := *team + + // Handle budget updates with consistent logic + if clone.Budget != nil { + // Get existing budget from gs.budgets (NOT from Team.Budget which may be stale) + var existingBudget *configstoreTables.TableBudget + if existingBudgetValue, exists := gs.budgets.Load(clone.Budget.ID); exists && existingBudgetValue != nil { + if eb, ok := existingBudgetValue.(*configstoreTables.TableBudget); ok && eb != nil { + existingBudget = eb + } + } + budgetBaseline, exists := budgetBaselines[clone.Budget.ID] + if !exists { + budgetBaseline = 0.0 + } + clone.Budget = checkAndUpdateBudget(clone.Budget, existingBudget, budgetBaseline) + // Update the budget in the main budgets sync.Map + if clone.Budget != nil { + gs.budgets.Store(clone.Budget.ID, clone.Budget) + } + } else if existingTeam.Budget != nil { + // Budget was removed from the team, delete it from memory + gs.budgets.Delete(existingTeam.Budget.ID) + } + + gs.teams.Store(team.ID, &clone) + } else { + gs.CreateTeamInMemory(team) + } } // DeleteTeamInMemory removes a team from the in-memory store (lock-free) -func (gs *GovernanceStore) DeleteTeamInMemory(teamID string) { +func (gs *LocalGovernanceStore) DeleteTeamInMemory(teamID string) { if teamID == "" { return // Nothing to delete } + + // Get team to check for associated budget + if teamValue, exists := gs.teams.Load(teamID); exists && teamValue != nil { + if team, ok := teamValue.(*configstoreTables.TableTeam); ok && team != nil { + // Delete associated budget if exists + if team.BudgetID != nil { + gs.budgets.Delete(*team.BudgetID) + } + } + } + + // Set team_id to null for all virtual keys associated with the team + // Iterate through all VKs since team.VirtualKeys may not be populated + gs.virtualKeys.Range(func(key, value interface{}) bool { + vk, ok := value.(*configstoreTables.TableVirtualKey) + if !ok || vk == nil { + return true // continue + } + if vk.TeamID != nil && *vk.TeamID == teamID { + clone := *vk + clone.TeamID = nil + clone.Team = nil + gs.virtualKeys.Store(key, &clone) + } + return true // continue iteration + }) + gs.teams.Delete(teamID) } // CreateCustomerInMemory adds a new customer to the in-memory store (lock-free) -func (gs *GovernanceStore) CreateCustomerInMemory(customer *configstoreTables.TableCustomer) { +func (gs *LocalGovernanceStore) CreateCustomerInMemory(customer *configstoreTables.TableCustomer) { if customer == nil { return // Nothing to create } + + // Create associated budget if exists + if customer.Budget != nil { + gs.budgets.Store(customer.Budget.ID, customer.Budget) + } + gs.customers.Store(customer.ID, customer) } // UpdateCustomerInMemory updates an existing customer in the in-memory store (lock-free) -func (gs *GovernanceStore) UpdateCustomerInMemory(customer *configstoreTables.TableCustomer) { +func (gs *LocalGovernanceStore) UpdateCustomerInMemory(customer *configstoreTables.TableCustomer, budgetBaselines map[string]float64) { if customer == nil { return // Nothing to update } - gs.customers.Store(customer.ID, customer) + if budgetBaselines == nil { + budgetBaselines = make(map[string]float64) + } + + // Check if there's an existing customer to get current budget state + if existingCustomerValue, exists := gs.customers.Load(customer.ID); exists && existingCustomerValue != nil { + existingCustomer, ok := existingCustomerValue.(*configstoreTables.TableCustomer) + if !ok || existingCustomer == nil { + return // Nothing to update + } + // Create clone to avoid modifying the original + clone := *customer + + // Handle budget updates with consistent logic + if clone.Budget != nil { + // Get existing budget from gs.budgets (NOT from Customer.Budget which may be stale) + var existingBudget *configstoreTables.TableBudget + if existingBudgetValue, exists := gs.budgets.Load(clone.Budget.ID); exists && existingBudgetValue != nil { + if eb, ok := existingBudgetValue.(*configstoreTables.TableBudget); ok && eb != nil { + existingBudget = eb + } + } + budgetBaseline, exists := budgetBaselines[clone.Budget.ID] + if !exists { + budgetBaseline = 0.0 + } + clone.Budget = checkAndUpdateBudget(clone.Budget, existingBudget, budgetBaseline) + // Update the budget in the main budgets sync.Map + if clone.Budget != nil { + gs.budgets.Store(clone.Budget.ID, clone.Budget) + } + } else if existingCustomer.Budget != nil { + // Budget was removed from the customer, delete it from memory + gs.budgets.Delete(existingCustomer.Budget.ID) + } + + gs.customers.Store(customer.ID, &clone) + } else { + gs.CreateCustomerInMemory(customer) + } } // DeleteCustomerInMemory removes a customer from the in-memory store (lock-free) -func (gs *GovernanceStore) DeleteCustomerInMemory(customerID string) { +func (gs *LocalGovernanceStore) DeleteCustomerInMemory(customerID string) { if customerID == "" { return // Nothing to delete } + + // Get customer to check for associated budget + if customerValue, exists := gs.customers.Load(customerID); exists && customerValue != nil { + if customer, ok := customerValue.(*configstoreTables.TableCustomer); ok && customer != nil { + // Delete associated budget if exists + if customer.BudgetID != nil { + gs.budgets.Delete(*customer.BudgetID) + } + } + } + + // Set customer_id to null for all virtual keys associated with the customer + // Iterate through all VKs since customer.VirtualKeys may not be populated + gs.virtualKeys.Range(func(key, value interface{}) bool { + vk, ok := value.(*configstoreTables.TableVirtualKey) + if !ok || vk == nil { + return true // continue + } + if vk.CustomerID != nil && *vk.CustomerID == customerID { + clone := *vk + clone.CustomerID = nil + clone.Customer = nil + gs.virtualKeys.Store(key, &clone) + } + return true // continue iteration + }) + + // Set customer_id to null for all teams associated with the customer + // Iterate through all teams since customer.Teams may not be populated + gs.teams.Range(func(key, value interface{}) bool { + team, ok := value.(*configstoreTables.TableTeam) + if !ok || team == nil { + return true // continue + } + if team.CustomerID != nil && *team.CustomerID == customerID { + clone := *team + clone.CustomerID = nil + clone.Customer = nil + gs.teams.Store(key, &clone) + } + return true // continue iteration + }) + gs.customers.Delete(customerID) } -// CreateBudgetInMemory adds a new budget to the in-memory store (lock-free) -func (gs *GovernanceStore) CreateBudgetInMemory(budget *configstoreTables.TableBudget) { - if budget == nil { - return // Nothing to create - } - gs.budgets.Store(budget.ID, budget) +// Helper functions + +// updateBudgetReferences updates all VKs, teams, customers, and provider configs that reference a reset budget +func (gs *LocalGovernanceStore) updateBudgetReferences(resetBudget *configstoreTables.TableBudget) { + budgetID := resetBudget.ID + // Update VKs that reference this budget + gs.virtualKeys.Range(func(key, value interface{}) bool { + vk, ok := value.(*configstoreTables.TableVirtualKey) + if !ok || vk == nil { + return true // continue + } + needsUpdate := false + clone := *vk + + // Check VK-level budget + if vk.BudgetID != nil && *vk.BudgetID == budgetID { + clone.Budget = resetBudget + needsUpdate = true + } + + // Check provider config budgets + if vk.ProviderConfigs != nil { + for i, pc := range clone.ProviderConfigs { + if pc.BudgetID != nil && *pc.BudgetID == budgetID { + clone.ProviderConfigs[i].Budget = resetBudget + needsUpdate = true + } + } + } + + if needsUpdate { + gs.virtualKeys.Store(key, &clone) + } + return true // continue + }) + + // Update teams that reference this budget + gs.teams.Range(func(key, value interface{}) bool { + team, ok := value.(*configstoreTables.TableTeam) + if !ok || team == nil { + return true // continue + } + if team.BudgetID != nil && *team.BudgetID == budgetID { + clone := *team + clone.Budget = resetBudget + gs.teams.Store(key, &clone) + } + return true // continue + }) + + // Update customers that reference this budget + gs.customers.Range(func(key, value interface{}) bool { + customer, ok := value.(*configstoreTables.TableCustomer) + if !ok || customer == nil { + return true // continue + } + if customer.BudgetID != nil && *customer.BudgetID == budgetID { + clone := *customer + clone.Budget = resetBudget + gs.customers.Store(key, &clone) + } + return true // continue + }) +} + +// updateRateLimitReferences updates all VKs and provider configs that reference a reset rate limit +func (gs *LocalGovernanceStore) updateRateLimitReferences(resetRateLimit *configstoreTables.TableRateLimit) { + rateLimitID := resetRateLimit.ID + // Update VKs that reference this rate limit + gs.virtualKeys.Range(func(key, value interface{}) bool { + vk, ok := value.(*configstoreTables.TableVirtualKey) + if !ok || vk == nil { + return true // continue + } + needsUpdate := false + clone := *vk + + // Check VK-level rate limit + if vk.RateLimitID != nil && *vk.RateLimitID == rateLimitID { + clone.RateLimit = resetRateLimit + needsUpdate = true + } + + // Check provider config rate limits + if vk.ProviderConfigs != nil { + for i, pc := range clone.ProviderConfigs { + if pc.RateLimitID != nil && *pc.RateLimitID == rateLimitID { + clone.ProviderConfigs[i].RateLimit = resetRateLimit + needsUpdate = true + } + } + } + + if needsUpdate { + gs.virtualKeys.Store(key, &clone) + } + return true // continue + }) } -// UpdateBudgetInMemory updates a specific budget in the in-memory cache (lock-free) -func (gs *GovernanceStore) UpdateBudgetInMemory(budget *configstoreTables.TableBudget) error { - if budget == nil { - return fmt.Errorf("budget cannot be nil") +// checkAndUpdateBudget checks and updates a budget with usage reset logic +// If currentUsage+baseline >= newMaxLimit, reset usage to 0 +// Otherwise preserve existing usage and accept reset duration and max limit changes +func checkAndUpdateBudget(budgetToUpdate *configstoreTables.TableBudget, existingBudget *configstoreTables.TableBudget, baseline float64) *configstoreTables.TableBudget { + // Create clone to avoid modifying the original + clone := *budgetToUpdate + if existingBudget == nil { + // New budget, return as-is + return budgetToUpdate } - gs.budgets.Store(budget.ID, budget) - return nil + + // Check if reset duration or max limit changed + resetDurationChanged := budgetToUpdate.ResetDuration != existingBudget.ResetDuration + maxLimitChanged := budgetToUpdate.MaxLimit != existingBudget.MaxLimit + + if resetDurationChanged || maxLimitChanged { + // If currentUsage + baseline >= new max limit, reset usage to 0 + // This handles the case where new max limit is lower than or equal to current usage + if existingBudget.CurrentUsage+baseline >= budgetToUpdate.MaxLimit { + clone.CurrentUsage = 0 + } else { + // Otherwise, preserve the existing usage from memory (which may have been updated) + clone.CurrentUsage = existingBudget.CurrentUsage + // Preserve LastDBUsage baseline to prevent multi-node baseline corruption + clone.LastDBUsage = existingBudget.LastDBUsage + } + } else { + // No changes to max limit or reset duration, preserve existing usage + clone.CurrentUsage = existingBudget.CurrentUsage + // Preserve LastDBUsage baseline to prevent multi-node baseline corruption + clone.LastDBUsage = existingBudget.LastDBUsage + } + + return &clone } -// DeleteBudgetInMemory removes a budget from the in-memory store (lock-free) -func (gs *GovernanceStore) DeleteBudgetInMemory(budgetID string) { - if budgetID == "" { - return // Nothing to delete +// checkAndUpdateRateLimit checks and updates a rate limit with usage reset logic +// If currentUsage+baseline > newMaxLimit, reset usage to 0 +// Otherwise preserve existing usage and accept reset duration and max limit changes +func checkAndUpdateRateLimit(rateLimitToUpdate *configstoreTables.TableRateLimit, existingRateLimit *configstoreTables.TableRateLimit, tokenBaseline int64, requestBaseline int64) *configstoreTables.TableRateLimit { + // Create clone to avoid modifying the original + clone := *rateLimitToUpdate + if existingRateLimit == nil { + // New rate limit, return as-is + return rateLimitToUpdate + } + + // Check if token settings changed + tokenMaxLimitChanged := !equalPtr(existingRateLimit.TokenMaxLimit, rateLimitToUpdate.TokenMaxLimit) + tokenResetDurationChanged := !equalPtr(existingRateLimit.TokenResetDuration, rateLimitToUpdate.TokenResetDuration) + + // Check if request settings changed + requestMaxLimitChanged := !equalPtr(existingRateLimit.RequestMaxLimit, rateLimitToUpdate.RequestMaxLimit) + requestResetDurationChanged := !equalPtr(existingRateLimit.RequestResetDuration, rateLimitToUpdate.RequestResetDuration) + + if tokenMaxLimitChanged || tokenResetDurationChanged { + // If currentUsage + baseline >= new max limit, reset usage to 0 + // This handles the case where new max limit is lower than or equal to current usage + if rateLimitToUpdate.TokenMaxLimit != nil && existingRateLimit.TokenCurrentUsage+tokenBaseline >= *rateLimitToUpdate.TokenMaxLimit { + clone.TokenCurrentUsage = 0 + } else { + // Otherwise, preserve the existing usage + clone.TokenCurrentUsage = existingRateLimit.TokenCurrentUsage + // Preserve LastDBTokenUsage baseline to prevent multi-node baseline corruption + clone.LastDBTokenUsage = existingRateLimit.LastDBTokenUsage + } + } else { + // No changes to max limit or reset duration, preserve existing usage + clone.TokenCurrentUsage = existingRateLimit.TokenCurrentUsage + // Preserve LastDBTokenUsage baseline to prevent multi-node baseline corruption + clone.LastDBTokenUsage = existingRateLimit.LastDBTokenUsage } - gs.budgets.Delete(budgetID) + + if requestMaxLimitChanged || requestResetDurationChanged { + // If currentUsage + baseline >= new max limit, reset usage to 0 + // This handles the case where new max limit is lower than or equal to current usage + if rateLimitToUpdate.RequestMaxLimit != nil && existingRateLimit.RequestCurrentUsage+requestBaseline >= *rateLimitToUpdate.RequestMaxLimit { + clone.RequestCurrentUsage = 0 + } else { + // Otherwise, preserve the existing usage + clone.RequestCurrentUsage = existingRateLimit.RequestCurrentUsage + // Preserve LastDBRequestUsage baseline to prevent multi-node baseline corruption + clone.LastDBRequestUsage = existingRateLimit.LastDBRequestUsage + } + } else { + // No changes to max limit or reset duration, preserve existing usage + clone.RequestCurrentUsage = existingRateLimit.RequestCurrentUsage + // Preserve LastDBRequestUsage baseline to prevent multi-node baseline corruption + clone.LastDBRequestUsage = existingRateLimit.LastDBRequestUsage + } + + return &clone } diff --git a/plugins/governance/store_test.go b/plugins/governance/store_test.go new file mode 100644 index 0000000000..0793df5419 --- /dev/null +++ b/plugins/governance/store_test.go @@ -0,0 +1,351 @@ +package governance + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestGovernanceStore_GetVirtualKey tests lock-free VK retrieval +func TestGovernanceStore_GetVirtualKey(t *testing.T) { + logger := NewMockLogger() + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{ + *buildVirtualKey("vk1", "sk-bf-test1", "Test VK 1", true), + *buildVirtualKey("vk2", "sk-bf-test2", "Test VK 2", false), + }, + }) + require.NoError(t, err) + + tests := []struct { + name string + vkValue string + wantNil bool + wantID string + }{ + { + name: "Found active VK", + vkValue: "sk-bf-test1", + wantNil: false, + wantID: "vk1", + }, + { + name: "Found inactive VK", + vkValue: "sk-bf-test2", + wantNil: false, + wantID: "vk2", + }, + { + name: "VK not found", + vkValue: "sk-bf-nonexistent", + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vk, exists := store.GetVirtualKey(tt.vkValue) + if tt.wantNil { + assert.False(t, exists) + assert.Nil(t, vk) + } else { + assert.True(t, exists) + assert.NotNil(t, vk) + assert.Equal(t, tt.wantID, vk.ID) + } + }) + } +} + +// TestGovernanceStore_ConcurrentReads tests lock-free concurrent reads +func TestGovernanceStore_ConcurrentReads(t *testing.T) { + logger := NewMockLogger() + vk := buildVirtualKey("vk1", "sk-bf-test", "Test VK", true) + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + }) + require.NoError(t, err) + + // Launch 100 concurrent readers + var wg sync.WaitGroup + readCount := atomic.Int64{} + errorCount := atomic.Int64{} + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + vk, exists := store.GetVirtualKey("sk-bf-test") + if !exists || vk == nil { + errorCount.Add(1) + return + } + readCount.Add(1) + } + }() + } + + wg.Wait() + + assert.Equal(t, int64(10000), readCount.Load(), "Expected 10000 successful reads") + assert.Equal(t, int64(0), errorCount.Load(), "Expected 0 errors") +} + +// TestGovernanceStore_CheckBudget_SingleBudget tests budget validation with single budget +func TestGovernanceStore_CheckBudget_SingleBudget(t *testing.T) { + logger := NewMockLogger() + budget := buildBudgetWithUsage("budget1", 100.0, 50.0, "1d") + vk := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", budget) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*budget}, + }) + require.NoError(t, err) + + // Retrieve VK with budget + vk, _ = store.GetVirtualKey("sk-bf-test") + + tests := []struct { + name string + usage float64 + maxLimit float64 + shouldErr bool + }{ + { + name: "Usage below limit", + usage: 50.0, + maxLimit: 100.0, + shouldErr: false, + }, + { + name: "Usage at limit (should fail)", + usage: 100.0, + maxLimit: 100.0, + shouldErr: true, + }, + { + name: "Usage exceeds limit", + usage: 150.0, + maxLimit: 100.0, + shouldErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create new budget with test usage + testBudget := buildBudgetWithUsage("budget1", tt.maxLimit, tt.usage, "1d") + testVK := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", testBudget) + testStore, _ := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*testVK}, + Budgets: []configstoreTables.TableBudget{*testBudget}, + }) + + testVK, _ = testStore.GetVirtualKey("sk-bf-test") + err := testStore.CheckBudget(context.Background(), testVK, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + if tt.shouldErr { + assert.Error(t, err, "Expected error for usage check") + } else { + assert.NoError(t, err, "Expected no error for usage check") + } + }) + } +} + +// TestGovernanceStore_CheckBudget_HierarchyValidation tests multi-level budget hierarchy +func TestGovernanceStore_CheckBudget_HierarchyValidation(t *testing.T) { + logger := NewMockLogger() + + // Create budgets at different levels + vkBudget := buildBudgetWithUsage("vk-budget", 100.0, 50.0, "1d") + teamBudget := buildBudgetWithUsage("team-budget", 500.0, 200.0, "1d") + customerBudget := buildBudgetWithUsage("customer-budget", 1000.0, 400.0, "1d") + + // Build hierarchy + team := buildTeam("team1", "Team 1", teamBudget) + customer := buildCustomer("customer1", "Customer 1", customerBudget) + team.CustomerID = &customer.ID + team.Customer = customer + + vk := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", vkBudget) + vk.TeamID = &team.ID + vk.Team = team + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*vkBudget, *teamBudget, *customerBudget}, + Teams: []configstoreTables.TableTeam{*team}, + Customers: []configstoreTables.TableCustomer{*customer}, + }) + require.NoError(t, err) + + vk, _ = store.GetVirtualKey("sk-bf-test") + + // Test: All budgets under limit should pass + err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + assert.NoError(t, err, "Should pass when all budgets are under limit") + + // Test: If VK budget exceeds limit, should fail + // Update the budget directly in the budgets map (since UpdateVirtualKeyInMemory preserves usage) + if vk.BudgetID != nil { + if budgetValue, exists := store.budgets.Load(*vk.BudgetID); exists && budgetValue != nil { + if budget, ok := budgetValue.(*configstoreTables.TableBudget); ok && budget != nil { + budget.CurrentUsage = 100.0 + store.budgets.Store(*vk.BudgetID, budget) + } + } + } + err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + assert.Error(t, err, "Should fail when VK budget exceeds limit") +} + +// TestGovernanceStore_UpdateRateLimitUsage_TokensAndRequests tests atomic rate limit usage updates +func TestGovernanceStore_UpdateRateLimitUsage_TokensAndRequests(t *testing.T) { + logger := NewMockLogger() + + rateLimit := buildRateLimitWithUsage("rl1", 10000, 0, 1000, 0) + vk := buildVirtualKeyWithRateLimit("vk1", "sk-bf-test", "Test VK", rateLimit) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*rateLimit}, + }) + require.NoError(t, err) + + // Test updating tokens + err = store.UpdateRateLimitUsageInMemory(context.Background(), vk, schemas.OpenAI, 500, true, false) + assert.NoError(t, err, "Rate limit update should succeed") + + // Retrieve the updated rate limit from the main RateLimits map + governanceData := store.GetGovernanceData() + updatedRateLimit, exists := governanceData.RateLimits["rl1"] + require.True(t, exists, "Rate limit should exist") + require.NotNil(t, updatedRateLimit) + + assert.Equal(t, int64(500), updatedRateLimit.TokenCurrentUsage, "Token usage should be updated") + assert.Equal(t, int64(0), updatedRateLimit.RequestCurrentUsage, "Request usage should not change") + + // Test updating requests + err = store.UpdateRateLimitUsageInMemory(context.Background(), vk, schemas.OpenAI, 0, false, true) + assert.NoError(t, err, "Rate limit update should succeed") + + // Retrieve the updated rate limit again + governanceData = store.GetGovernanceData() + updatedRateLimit, exists = governanceData.RateLimits["rl1"] + require.True(t, exists, "Rate limit should exist") + require.NotNil(t, updatedRateLimit) + + assert.Equal(t, int64(500), updatedRateLimit.TokenCurrentUsage, "Token usage should not change") + assert.Equal(t, int64(1), updatedRateLimit.RequestCurrentUsage, "Request usage should be incremented") +} + +// TestGovernanceStore_ResetExpiredRateLimits tests rate limit reset +func TestGovernanceStore_ResetExpiredRateLimits(t *testing.T) { + logger := NewMockLogger() + + // Create rate limit that's already expired + duration := "1m" + rateLimit := &configstoreTables.TableRateLimit{ + ID: "rl1", + TokenMaxLimit: ptrInt64(10000), + TokenCurrentUsage: 5000, + TokenResetDuration: &duration, + TokenLastReset: time.Now().Add(-2 * time.Minute), // Expired + RequestMaxLimit: ptrInt64(1000), + RequestCurrentUsage: 500, + RequestResetDuration: &duration, + RequestLastReset: time.Now().Add(-2 * time.Minute), // Expired + } + + vk := buildVirtualKeyWithRateLimit("vk1", "sk-bf-test", "Test VK", rateLimit) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*rateLimit}, + }) + require.NoError(t, err) + + // Reset expired rate limits + expiredRateLimits := store.ResetExpiredRateLimitsInMemory(context.Background()) + err = store.ResetExpiredRateLimits(context.Background(), expiredRateLimits) + assert.NoError(t, err, "Reset should succeed") + + // Retrieve the updated VK to check rate limit changes + updatedVK, _ := store.GetVirtualKey("sk-bf-test") + require.NotNil(t, updatedVK) + require.NotNil(t, updatedVK.RateLimit) + + assert.Equal(t, int64(0), updatedVK.RateLimit.TokenCurrentUsage, "Token usage should be reset") + assert.Equal(t, int64(0), updatedVK.RateLimit.RequestCurrentUsage, "Request usage should be reset") +} + +// TestGovernanceStore_ResetExpiredBudgets tests budget reset +func TestGovernanceStore_ResetExpiredBudgets(t *testing.T) { + logger := NewMockLogger() + + // Create budget that's already expired + budget := &configstoreTables.TableBudget{ + ID: "budget1", + MaxLimit: 100.0, + CurrentUsage: 75.0, + ResetDuration: "1d", + LastReset: time.Now().Add(-48 * time.Hour), // Expired + } + + vk := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", budget) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*budget}, + }) + require.NoError(t, err) + + // Reset expired budgets + expiredBudgets := store.ResetExpiredBudgetsInMemory(context.Background()) + err = store.ResetExpiredBudgets(context.Background(), expiredBudgets) + assert.NoError(t, err, "Reset should succeed") + + // Retrieve the updated VK to check budget changes + updatedVK, _ := store.GetVirtualKey("sk-bf-test") + require.NotNil(t, updatedVK) + require.NotNil(t, updatedVK.Budget) + + assert.Equal(t, 0.0, updatedVK.Budget.CurrentUsage, "Budget usage should be reset") +} + +// TestGovernanceStore_GetAllBudgets tests retrieving all budgets +func TestGovernanceStore_GetAllBudgets(t *testing.T) { + logger := NewMockLogger() + + budgets := []configstoreTables.TableBudget{ + *buildBudget("budget1", 100.0, "1d"), + *buildBudget("budget2", 500.0, "1d"), + *buildBudget("budget3", 1000.0, "1d"), + } + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + Budgets: budgets, + }) + require.NoError(t, err) + + allBudgets := store.GetGovernanceData().Budgets + assert.Equal(t, 3, len(allBudgets), "Should have 3 budgets") + assert.NotNil(t, allBudgets["budget1"]) + assert.NotNil(t, allBudgets["budget2"]) + assert.NotNil(t, allBudgets["budget3"]) +} + +// Utility functions for tests +func ptrInt64(i int64) *int64 { + return &i +} diff --git a/plugins/governance/teambudget_test.go b/plugins/governance/teambudget_test.go new file mode 100644 index 0000000000..1323d056ae --- /dev/null +++ b/plugins/governance/teambudget_test.go @@ -0,0 +1,160 @@ +package governance + +import ( + "strconv" + "testing" +) + +// TestTeamBudgetExceededWithMultipleVKs tests that team level budgets are enforced across multiple VKs +// by making requests until budget is consumed +func TestTeamBudgetExceededWithMultipleVKs(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a team with a fixed budget + teamBudget := 0.01 + teamName := "test-team-budget-exceeded-" + generateRandomID() + createTeamResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/teams", + Body: CreateTeamRequest{ + Name: teamName, + Budget: &BudgetRequest{ + MaxLimit: teamBudget, + ResetDuration: "1h", + }, + }, + }) + + if createTeamResp.StatusCode != 200 { + t.Fatalf("Failed to create team: status %d", createTeamResp.StatusCode) + } + + teamID := ExtractIDFromResponse(t, createTeamResp, "id") + testData.AddTeam(teamID) + + // Create 2 VKs under the team + var vkValues []string + for i := 1; i <= 2; i++ { + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: "test-vk-" + generateRandomID(), + TeamID: &teamID, + Budget: &BudgetRequest{ + MaxLimit: 1.0, // High VK budget so team is the limiting factor + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK %d: status %d", i, createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValues = append(vkValues, vk["value"].(string)) + } + + t.Logf("Created team %s with budget $%.2f and 2 VKs", teamName, teamBudget) + + // Keep making requests alternating between VKs, tracking actual token usage until team budget is exceeded + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + var shouldStop = false + vkIndex := 0 + + for requestNum <= 50 { + // Alternate between VKs to test shared team budget + vkValue := vkValues[vkIndex%2] + + // Create a longer prompt to consume more tokens and budget faster + longPrompt := "Please provide a comprehensive and detailed response to the following question. " + + "I need extensive information covering all aspects of the topic. " + + "Provide multiple paragraphs with detailed explanations. " + + "Request number " + strconv.Itoa(requestNum) + ". " + + "Here is a detailed prompt that will consume significant tokens: " + + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum. Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum." + + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: longPrompt, + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request failed - check if it's due to budget + if CheckErrorMessage(t, resp, "budget") || CheckErrorMessage(t, resp, "team") { + t.Logf("Request %d correctly rejected: team budget exceeded", requestNum) + t.Logf("Consumed budget: $%.6f (limit: $%.2f)", consumedBudget, teamBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + // Verify that we made at least one successful request before hitting budget + if requestNum == 1 { + t.Fatalf("First request should have succeeded but was rejected due to budget") + } + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualInputTokens := int(prompt) + actualOutputTokens := int(completion) + actualCost, _ := CalculateCost("openai/gpt-4o", actualInputTokens, actualOutputTokens) + + consumedBudget += actualCost + lastSuccessfulCost = actualCost + + t.Logf("Request %d (VK%d) succeeded: input_tokens=%d, output_tokens=%d, cost=$%.6f, consumed=$%.6f/$%.2f", + requestNum, (vkIndex%2)+1, actualInputTokens, actualOutputTokens, actualCost, consumedBudget, teamBudget) + } + } + } + + requestNum++ + vkIndex++ + + if shouldStop { + break + } + + if consumedBudget >= teamBudget { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit team budget limit (consumed $%.6f / $%.2f) - budget not being enforced", + requestNum-1, consumedBudget, teamBudget) +} diff --git a/plugins/governance/test_utils.go b/plugins/governance/test_utils.go new file mode 100644 index 0000000000..3b9bf35274 --- /dev/null +++ b/plugins/governance/test_utils.go @@ -0,0 +1,424 @@ +package governance + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "math/rand" + "net/http" + "strings" + "testing" + "time" +) + +// ModelCost defines the cost structure for a model +type ModelCost struct { + Provider string + InputCostPerToken float64 + OutputCostPerToken float64 + MaxInputTokens int + MaxOutputTokens int +} + +// TestModels defines all models used for testing +var TestModels = map[string]ModelCost{ + "openai/gpt-4o": { + Provider: "openai", + InputCostPerToken: 0.0000025, + OutputCostPerToken: 0.00001, + MaxInputTokens: 128000, + MaxOutputTokens: 16384, + }, + "anthropic/claude-3-7-sonnet-20250219": { + Provider: "anthropic", + InputCostPerToken: 0.000003, + OutputCostPerToken: 0.000015, + MaxInputTokens: 200000, + MaxOutputTokens: 128000, + }, + "anthropic/claude-4-opus-20250514": { + Provider: "anthropic", + InputCostPerToken: 0.000015, + OutputCostPerToken: 0.000075, + MaxInputTokens: 200000, + MaxOutputTokens: 32000, + }, + "openrouter/anthropic/claude-3.7-sonnet": { + Provider: "openrouter", + InputCostPerToken: 0.000003, + OutputCostPerToken: 0.000015, + MaxInputTokens: 200000, + MaxOutputTokens: 128000, + }, + "openrouter/openai/gpt-4o": { + Provider: "openrouter", + InputCostPerToken: 0.0000025, + OutputCostPerToken: 0.00001, + MaxInputTokens: 128000, + MaxOutputTokens: 4096, + }, +} + +// CalculateCost calculates the cost based on input and output tokens +func CalculateCost(model string, inputTokens, outputTokens int) (float64, error) { + modelInfo, ok := TestModels[model] + if !ok { + return 0, fmt.Errorf("unknown model: %s", model) + } + + inputCost := float64(inputTokens) * modelInfo.InputCostPerToken + outputCost := float64(outputTokens) * modelInfo.OutputCostPerToken + return inputCost + outputCost, nil +} + +// APIRequest represents a request to the Bifrost API +type APIRequest struct { + Method string + Path string + Body interface{} + VKHeader *string +} + +// APIResponse represents a response from the Bifrost API +type APIResponse struct { + StatusCode int + Body map[string]interface{} + RawBody []byte +} + +// MakeRequest makes an HTTP request to the Bifrost API +func MakeRequest(t *testing.T, req APIRequest) *APIResponse { + client := &http.Client{} + url := fmt.Sprintf("http://localhost:8080%s", req.Path) + + var body io.Reader + if req.Body != nil { + bodyBytes, err := json.Marshal(req.Body) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + body = bytes.NewReader(bodyBytes) + } + + httpReq, err := http.NewRequest(req.Method, url, body) + if err != nil { + t.Fatalf("Failed to create HTTP request: %v", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + + // Add virtual key header if provided + if req.VKHeader != nil { + httpReq.Header.Set("x-bf-vk", *req.VKHeader) + } + + resp, err := client.Do(httpReq) + if err != nil { + t.Fatalf("Failed to execute HTTP request: %v", err) + } + defer resp.Body.Close() + + rawBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + var responseBody map[string]interface{} + if len(rawBody) > 0 { + err = json.Unmarshal(rawBody, &responseBody) + if err != nil { + // If unmarshaling fails, store the raw response + responseBody = map[string]interface{}{"raw": string(rawBody)} + } + } + + return &APIResponse{ + StatusCode: resp.StatusCode, + Body: responseBody, + RawBody: rawBody, + } +} + +// generateRandomID generates a random ID for test resources +func generateRandomID() string { + rand.Seed(time.Now().UnixNano()) + const letters = "abcdefghijklmnopqrstuvwxyz0123456789" + b := make([]byte, 8) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +} + +// CreateVirtualKeyRequest represents a request to create a virtual key +type CreateVirtualKeyRequest struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + IsActive *bool `json:"is_active,omitempty"` + TeamID *string `json:"team_id,omitempty"` + CustomerID *string `json:"customer_id,omitempty"` + Budget *BudgetRequest `json:"budget,omitempty"` + RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` + ProviderConfigs []ProviderConfigRequest `json:"provider_configs,omitempty"` +} + +// ProviderConfigRequest represents a provider configuration for a virtual key +type ProviderConfigRequest struct { + ID *uint `json:"id,omitempty"` + Provider string `json:"provider"` + Weight float64 `json:"weight,omitempty"` + AllowedModels []string `json:"allowed_models,omitempty"` + Budget *BudgetRequest `json:"budget,omitempty"` + RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` +} + +// BudgetRequest represents a budget request +type BudgetRequest struct { + MaxLimit float64 `json:"max_limit"` + ResetDuration string `json:"reset_duration"` +} + +// CreateTeamRequest represents a request to create a team +type CreateTeamRequest struct { + Name string `json:"name"` + CustomerID *string `json:"customer_id,omitempty"` + Budget *BudgetRequest `json:"budget,omitempty"` +} + +// CreateCustomerRequest represents a request to create a customer +type CreateCustomerRequest struct { + Name string `json:"name"` + Budget *BudgetRequest `json:"budget,omitempty"` +} + +// UpdateBudgetRequest represents a request to update a budget +type UpdateBudgetRequest struct { + MaxLimit *float64 `json:"max_limit,omitempty"` + ResetDuration *string `json:"reset_duration,omitempty"` +} + +// CreateRateLimitRequest represents a request to create a rate limit +type CreateRateLimitRequest struct { + TokenMaxLimit *int64 `json:"token_max_limit,omitempty"` + TokenResetDuration *string `json:"token_reset_duration,omitempty"` + RequestMaxLimit *int64 `json:"request_max_limit,omitempty"` + RequestResetDuration *string `json:"request_reset_duration,omitempty"` +} + +// UpdateVirtualKeyRequest represents a request to update a virtual key +type UpdateVirtualKeyRequest struct { + Name *string `json:"name,omitempty"` + TeamID *string `json:"team_id,omitempty"` + CustomerID *string `json:"customer_id,omitempty"` + Budget *UpdateBudgetRequest `json:"budget,omitempty"` + RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` + IsActive *bool `json:"is_active,omitempty"` + ProviderConfigs []ProviderConfigRequest `json:"provider_configs,omitempty"` +} + +// UpdateTeamRequest represents a request to update a team +type UpdateTeamRequest struct { + Name *string `json:"name,omitempty"` + Budget *UpdateBudgetRequest `json:"budget,omitempty"` +} + +// UpdateCustomerRequest represents a request to update a customer +type UpdateCustomerRequest struct { + Name *string `json:"name,omitempty"` + Budget *UpdateBudgetRequest `json:"budget,omitempty"` +} + +// ChatCompletionRequest represents an OpenAI-compatible chat completion request +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Temperature *float64 `json:"temperature,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + TopP *float64 `json:"top_p,omitempty"` +} + +// ChatMessage represents a chat message in OpenAI format +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ExtractIDFromResponse extracts the ID from a creation response +func ExtractIDFromResponse(t *testing.T, resp *APIResponse, keyPath string) string { + if resp.StatusCode >= 400 { + t.Fatalf("Request failed with status %d: %v", resp.StatusCode, resp.Body) + } + + // Navigate through the response to find the ID + data := resp.Body + parts := []string{"virtual_key", "team", "customer"} + for _, part := range parts { + if val, ok := data[part]; ok { + if nested, ok := val.(map[string]interface{}); ok { + if id, ok := nested["id"].(string); ok { + return id + } + } + } + } + + t.Fatalf("Could not extract ID from response: %v", resp.Body) + return "" +} + +// CheckErrorMessage checks if the response error contains expected text +// Returns true if error found, false otherwise. Asserts fail if status is not >= 400. +func CheckErrorMessage(t *testing.T, resp *APIResponse, expectedText string) bool { + if resp.StatusCode < 400 { + t.Fatalf("Expected error response but got status %d. Response: %v", resp.StatusCode, resp.Body) + } + + // Check in various fields where errors might appear + if msg, ok := resp.Body["message"].(string); ok && contains(msg, expectedText) { + return true + } + + if err, ok := resp.Body["error"].(string); ok && contains(err, expectedText) { + return true + } + + // Check raw body as fallback + if contains(string(resp.RawBody), expectedText) { + return true + } + + return false +} + +// contains checks if a string contains a substring (case-insensitive) +func contains(haystack, needle string) bool { + return strings.Contains(strings.ToLower(haystack), strings.ToLower(needle)) +} + +// GlobalTestData stores IDs of created resources for cleanup +type GlobalTestData struct { + VirtualKeys []string + Teams []string + Customers []string +} + +// NewGlobalTestData creates a new test data holder +func NewGlobalTestData() *GlobalTestData { + return &GlobalTestData{ + VirtualKeys: make([]string, 0), + Teams: make([]string, 0), + Customers: make([]string, 0), + } +} + +// AddVirtualKey adds a virtual key ID to the test data +func (g *GlobalTestData) AddVirtualKey(id string) { + g.VirtualKeys = append(g.VirtualKeys, id) +} + +// AddTeam adds a team ID to the test data +func (g *GlobalTestData) AddTeam(id string) { + g.Teams = append(g.Teams, id) +} + +// AddCustomer adds a customer ID to the test data +func (g *GlobalTestData) AddCustomer(id string) { + g.Customers = append(g.Customers, id) +} + +// Cleanup deletes all created resources +func (g *GlobalTestData) Cleanup(t *testing.T) { + // Delete virtual keys + for _, vkID := range g.VirtualKeys { + resp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: fmt.Sprintf("/api/governance/virtual-keys/%s", vkID), + }) + if resp.StatusCode >= 400 && resp.StatusCode != 404 { + t.Logf("Warning: failed to delete virtual key %s: status %d", vkID, resp.StatusCode) + } + } + + // Delete teams + for _, teamID := range g.Teams { + resp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: fmt.Sprintf("/api/governance/teams/%s", teamID), + }) + if resp.StatusCode >= 400 && resp.StatusCode != 404 { + t.Logf("Warning: failed to delete team %s: status %d", teamID, resp.StatusCode) + } + } + + // Delete customers + for _, customerID := range g.Customers { + resp := MakeRequest(t, APIRequest{ + Method: "DELETE", + Path: fmt.Sprintf("/api/governance/customers/%s", customerID), + }) + if resp.StatusCode >= 400 && resp.StatusCode != 404 { + t.Logf("Warning: failed to delete customer %s: status %d", customerID, resp.StatusCode) + } + } + + t.Logf("Cleanup completed: deleted %d VKs, %d teams, %d customers", + len(g.VirtualKeys), len(g.Teams), len(g.Customers)) +} + +// WaitForCondition polls a condition function until it returns true or times out +// Useful for waiting for async updates to propagate to in-memory store +func WaitForCondition(t *testing.T, checkFunc func() bool, timeout time.Duration, description string) bool { + deadline := time.Now().Add(timeout) + attempt := 0 + + for time.Now().Before(deadline) { + attempt++ + if checkFunc() { + if attempt > 1 { + t.Logf("Condition '%s' met after %d attempts", description, attempt) + } + return true + } + + // Progressive backoff: start with 50ms, max 500ms + sleepDuration := time.Duration(50*attempt) * time.Millisecond + if sleepDuration > 500*time.Millisecond { + sleepDuration = 500 * time.Millisecond + } + time.Sleep(sleepDuration) + } + + t.Logf("Timeout waiting for condition '%s' after %d attempts (%.1fs)", description, attempt, timeout.Seconds()) + return false +} + +// WaitForAPICondition makes repeated API requests until a condition is satisfied or times out +// Useful for verifying async updates in API responses +func WaitForAPICondition(t *testing.T, req APIRequest, condition func(*APIResponse) bool, timeout time.Duration, description string) (*APIResponse, bool) { + deadline := time.Now().Add(timeout) + attempt := 0 + var lastResp *APIResponse + + for time.Now().Before(deadline) { + attempt++ + lastResp = MakeRequest(t, req) + + if condition(lastResp) { + if attempt > 1 { + t.Logf("API condition '%s' met after %d attempts", description, attempt) + } + return lastResp, true + } + + // Progressive backoff: start with 100ms, max 500ms + sleepDuration := time.Duration(100*attempt) * time.Millisecond + if sleepDuration > 500*time.Millisecond { + sleepDuration = 500 * time.Millisecond + } + time.Sleep(sleepDuration) + } + + t.Logf("Timeout waiting for API condition '%s' after %d attempts (%.1fs)", description, attempt, timeout.Seconds()) + return lastResp, false +} diff --git a/plugins/governance/tracker.go b/plugins/governance/tracker.go index 67c0831044..1a10622a51 100644 --- a/plugins/governance/tracker.go +++ b/plugins/governance/tracker.go @@ -10,6 +10,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "gorm.io/gorm" ) // UsageUpdate contains data for VK-level usage tracking @@ -30,7 +31,7 @@ type UsageUpdate struct { // UsageTracker manages VK-level usage tracking and budget management type UsageTracker struct { - store *GovernanceStore + store GovernanceStore resolver *BudgetResolver configStore configstore.ConfigStore logger schemas.Logger @@ -43,8 +44,12 @@ type UsageTracker struct { wg sync.WaitGroup } +const ( + workerInterval = 10 * time.Second +) + // NewUsageTracker creates a new usage tracker for the hierarchical budget system -func NewUsageTracker(ctx context.Context, store *GovernanceStore, resolver *BudgetResolver, configStore configstore.ConfigStore, logger schemas.Logger) *UsageTracker { +func NewUsageTracker(ctx context.Context, store GovernanceStore, resolver *BudgetResolver, configStore configstore.ConfigStore, logger schemas.Logger) *UsageTracker { tracker := &UsageTracker{ store: store, resolver: resolver, @@ -57,7 +62,6 @@ func NewUsageTracker(ctx context.Context, store *GovernanceStore, resolver *Budg tracker.trackerCtx, tracker.trackerCancel = context.WithCancel(context.Background()) tracker.startWorkers(tracker.trackerCtx) - tracker.logger.Info("usage tracker initialized for hierarchical budget system") return tracker } @@ -66,7 +70,6 @@ func (t *UsageTracker) UpdateUsage(ctx context.Context, update *UsageUpdate) { // Get virtual key vk, exists := t.store.GetVirtualKey(update.VirtualKey) if !exists { - t.logger.Debug(fmt.Sprintf("Virtual key not found: %s", update.VirtualKey)) return } @@ -83,29 +86,25 @@ func (t *UsageTracker) UpdateUsage(ctx context.Context, update *UsageUpdate) { // Update rate limit usage (both provider-level and VK-level) if applicable if vk.RateLimit != nil || len(vk.ProviderConfigs) > 0 { - if err := t.store.UpdateRateLimitUsage(ctx, update.VirtualKey, string(update.Provider), update.TokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil { + if err := t.store.UpdateRateLimitUsageInMemory(ctx, vk, update.Provider, update.TokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil { t.logger.Error("failed to update rate limit usage for VK %s: %v", vk.ID, err) } } // Update budget usage in hierarchy (VK → Team → Customer) only if we have usage data if shouldUpdateBudget && update.Cost > 0 { - t.updateBudgetHierarchy(ctx, vk, update) - } -} - -// updateBudgetHierarchy updates budget usage atomically in the VK → Team → Customer hierarchy -func (t *UsageTracker) updateBudgetHierarchy(ctx context.Context, vk *configstoreTables.TableVirtualKey, update *UsageUpdate) { - // Use atomic budget update to prevent race conditions and ensure consistency - if err := t.store.UpdateBudget(ctx, vk, update.Provider, update.Cost); err != nil { - t.logger.Error("failed to update budget hierarchy atomically for VK %s: %v", vk.ID, err) + t.logger.Debug("updating budget usage for VK %s", vk.ID) + // Use atomic budget update to prevent race conditions and ensure consistency + if err := t.store.UpdateBudgetUsageInMemory(ctx, vk, update.Provider, update.Cost); err != nil { + t.logger.Error("failed to update budget hierarchy atomically for VK %s: %v", vk.ID, err) + } } } // startWorkers starts all background workers for business logic func (t *UsageTracker) startWorkers(ctx context.Context) { // Counter reset manager (business logic) - t.resetTicker = time.NewTicker(1 * time.Minute) + t.resetTicker = time.NewTicker(workerInterval) t.wg.Add(1) go t.resetWorker(ctx) } @@ -128,14 +127,24 @@ func (t *UsageTracker) resetWorker(ctx context.Context) { // resetExpiredCounters manages periodic resets of usage counters AND budgets using flexible durations func (t *UsageTracker) resetExpiredCounters(ctx context.Context) { // ==== PART 1: Reset Rate Limits ==== - if err := t.store.ResetExpiredRateLimits(ctx); err != nil { + resetRateLimits := t.store.ResetExpiredRateLimitsInMemory(ctx) + if err := t.store.ResetExpiredRateLimits(ctx, resetRateLimits); err != nil { t.logger.Error("failed to reset expired rate limits: %v", err) } // ==== PART 2: Reset Budgets ==== - if err := t.store.ResetExpiredBudgets(ctx); err != nil { + resetBudgets := t.store.ResetExpiredBudgetsInMemory(ctx) + if err := t.store.ResetExpiredBudgets(ctx, resetBudgets); err != nil { t.logger.Error("failed to reset expired budgets: %v", err) } + + // ==== PART 3: Dump all rate limits and budgets to database ==== + if err := t.store.DumpRateLimits(ctx, nil, nil); err != nil { + t.logger.Error("failed to dump rate limits to database: %v", err) + } + if err := t.store.DumpBudgets(ctx, nil); err != nil { + t.logger.Error("failed to dump budgets to database: %v", err) + } } // Public methods for monitoring and admin operations @@ -147,7 +156,7 @@ func (t *UsageTracker) PerformStartupResets(ctx context.Context) error { return nil } - t.logger.Info("performing startup reset check for expired rate limits and budgets") + t.logger.Debug("performing startup reset check for expired rate limits and budgets") now := time.Now() var resetRateLimits []*configstoreTables.TableRateLimit @@ -210,16 +219,38 @@ func (t *UsageTracker) PerformStartupResets(ctx context.Context) error { } // DB reset is also handled by this function - if err := t.store.ResetExpiredBudgets(ctx); err != nil { + resetBudgets := t.store.ResetExpiredBudgetsInMemory(ctx) + if err := t.store.ResetExpiredBudgets(ctx, resetBudgets); err != nil { errs = append(errs, fmt.Sprintf("failed to reset expired budgets: %s", err.Error())) } // ==== PERSIST RESETS TO DATABASE ==== - if t.configStore != nil { - if len(resetRateLimits) > 0 { - if err := t.configStore.UpdateRateLimits(ctx, resetRateLimits); err != nil { - errs = append(errs, fmt.Sprintf("failed to persist rate limit resets: %s", err.Error())) + // Use selective updates to avoid overwriting config fields (max_limit, reset_duration) + if t.configStore != nil && len(resetRateLimits) > 0 { + if err := t.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { + for _, rateLimit := range resetRateLimits { + // Build update map with only the fields that were reset + updates := make(map[string]interface{}) + updates["token_current_usage"] = rateLimit.TokenCurrentUsage + updates["token_last_reset"] = rateLimit.TokenLastReset + updates["request_current_usage"] = rateLimit.RequestCurrentUsage + updates["request_last_reset"] = rateLimit.RequestLastReset + + // Direct UPDATE only resets usage and last_reset fields + // This prevents overwriting max_limit or reset_duration that may have been changed during startup + result := tx.WithContext(ctx). + Session(&gorm.Session{SkipHooks: true}). + Model(&configstoreTables.TableRateLimit{}). + Where("id = ?", rateLimit.ID). + Updates(updates) + + if result.Error != nil { + return fmt.Errorf("failed to reset rate limit %s: %w", rateLimit.ID, result.Error) + } } + return nil + }); err != nil { + errs = append(errs, fmt.Sprintf("failed to persist rate limit resets: %s", err.Error())) } } if len(errs) > 0 { diff --git a/plugins/governance/tracker_test.go b/plugins/governance/tracker_test.go new file mode 100644 index 0000000000..76f7e37a9c --- /dev/null +++ b/plugins/governance/tracker_test.go @@ -0,0 +1,166 @@ +package governance + +import ( + "context" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestUsageTracker_UpdateUsage_FailedRequest tests usage tracking for a failed request +func TestUsageTracker_UpdateUsage_FailedRequest(t *testing.T) { + logger := NewMockLogger() + + budget := buildBudgetWithUsage("budget1", 1000.0, 0.0, "1d") + vk := buildVirtualKeyWithBudget("vk1", "sk-bf-test", "Test VK", budget) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*budget}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + tracker := NewUsageTracker(context.Background(), store, resolver, nil, logger) + defer tracker.Cleanup() + + update := &UsageUpdate{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + Success: false, // Failed request + TokensUsed: 100, + Cost: 25.5, + RequestID: "req-123", + } + + tracker.UpdateUsage(context.Background(), update) + + // Give time for async processing + time.Sleep(200 * time.Millisecond) + + // Verify budget was NOT updated - retrieve from store + budgets := store.GetGovernanceData().Budgets + updatedBudget, exists := budgets["budget1"] + require.True(t, exists) + require.NotNil(t, updatedBudget) + + assert.Equal(t, 0.0, updatedBudget.CurrentUsage, "Failed request should not update budget") +} + +// TestUsageTracker_UpdateUsage_VirtualKeyNotFound tests handling of missing VK +func TestUsageTracker_UpdateUsage_VirtualKeyNotFound(t *testing.T) { + logger := NewMockLogger() + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + tracker := NewUsageTracker(context.Background(), store, resolver, nil, logger) + defer tracker.Cleanup() + + update := &UsageUpdate{ + VirtualKey: "sk-bf-nonexistent", + Provider: schemas.OpenAI, + Model: "gpt-4", + Success: true, + TokensUsed: 100, + Cost: 25.5, + } + + // Should not panic or error + tracker.UpdateUsage(context.Background(), update) + + time.Sleep(100 * time.Millisecond) + // Just verify it doesn't crash + assert.True(t, true) +} + +// TestUsageTracker_UpdateUsage_StreamingOptimization tests streaming request handling +func TestUsageTracker_UpdateUsage_StreamingOptimization(t *testing.T) { + logger := NewMockLogger() + + rateLimit := buildRateLimitWithUsage("rl1", 10000, 0, 1000, 0) + vk := buildVirtualKeyWithRateLimit("vk1", "sk-bf-test", "Test VK", rateLimit) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + RateLimits: []configstoreTables.TableRateLimit{*rateLimit}, + }) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + tracker := NewUsageTracker(context.Background(), store, resolver, nil, logger) + defer tracker.Cleanup() + + // First streaming chunk (not final, has usage data) + update1 := &UsageUpdate{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + Success: true, + TokensUsed: 50, + Cost: 0.0, // No cost on non-final chunks + RequestID: "req-123", + IsStreaming: true, + IsFinalChunk: false, + HasUsageData: true, + } + + tracker.UpdateUsage(context.Background(), update1) + time.Sleep(200 * time.Millisecond) + + // Retrieve the updated rate limit from the main RateLimits map + governanceData := store.GetGovernanceData() + updatedRateLimit, exists := governanceData.RateLimits["rl1"] + require.True(t, exists, "Rate limit should exist") + require.NotNil(t, updatedRateLimit) + + // Tokens should be updated but not requests (not final chunk) + assert.Equal(t, int64(50), updatedRateLimit.TokenCurrentUsage, "Tokens should be updated on non-final chunk") + + // Final chunk + update2 := &UsageUpdate{ + VirtualKey: "sk-bf-test", + Provider: schemas.OpenAI, + Model: "gpt-4", + Success: true, + TokensUsed: 0, // Already counted + Cost: 12.5, + RequestID: "req-123", + IsStreaming: true, + IsFinalChunk: true, + HasUsageData: true, + } + + tracker.UpdateUsage(context.Background(), update2) + time.Sleep(200 * time.Millisecond) + + // Retrieve the updated rate limit again + governanceData = store.GetGovernanceData() + updatedRateLimit, exists = governanceData.RateLimits["rl1"] + require.True(t, exists, "Rate limit should exist") + require.NotNil(t, updatedRateLimit) + + // Request counter should be updated on final chunk + assert.Equal(t, int64(1), updatedRateLimit.RequestCurrentUsage, "Request should be incremented on final chunk") +} + +// TestUsageTracker_Cleanup tests cleanup of the usage tracker +func TestUsageTracker_Cleanup(t *testing.T) { + logger := NewMockLogger() + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, logger) + tracker := NewUsageTracker(context.Background(), store, resolver, nil, logger) + + // Should cleanup without error + err = tracker.Cleanup() + assert.NoError(t, err, "Cleanup should succeed") +} diff --git a/plugins/governance/usagetracking_test.go b/plugins/governance/usagetracking_test.go new file mode 100644 index 0000000000..8564a1e68b --- /dev/null +++ b/plugins/governance/usagetracking_test.go @@ -0,0 +1,571 @@ +package governance + +import ( + "testing" + "time" +) + +// TestUsageTrackingRateLimitReset tests that rate limit resets happen correctly on ticker +func TestUsageTrackingRateLimitReset(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a rate limit that resets every 30 seconds + vkName := "test-vk-rate-limit-reset-" + generateRandomID() + tokenLimit := int64(10000) // 10k token limit + tokenResetDuration := "30s" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with rate limit: %d tokens reset every %s", vkName, tokenLimit, tokenResetDuration) + + // Get initial rate limit data from data endpoint + getVKResp1 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getVKResp1.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getVKResp1.StatusCode) + } + + virtualKeysMap1 := getVKResp1.Body["virtual_keys"].(map[string]interface{}) + vkData1 := virtualKeysMap1[vkValue].(map[string]interface{}) + rateLimitID, _ := vkData1["rate_limit_id"].(string) + if rateLimitID == "" { + t.Fatalf("Rate limit ID not found for VK") + } + + t.Logf("Rate limit ID: %s", rateLimitID) + + // Make a request to consume tokens + // Cost should be approximately: 5000 * 0.0000025 + 100 * 0.00001 = 0.013-0.014 dollars + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "This is a test prompt to consume tokens for rate limit testing.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Logf("Request failed with status %d (may be due to other limits), body: %v", resp.StatusCode, resp.Body) + t.Skip("Could not execute request to test rate limit reset") + } + + // Extract token count from response + var tokensUsed int + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if totalTokens, ok := usage["total_tokens"].(float64); ok { + tokensUsed = int(totalTokens) + } + } + + if tokensUsed == 0 { + t.Logf("No token usage in response, cannot verify rate limit reset") + t.Skip("Could not extract token usage from response") + } + + t.Logf("Request consumed %d tokens", tokensUsed) + + // Get rate limit data after request + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + // Rate limit counter should have been updated + t.Logf("Rate limit should be tracking usage in in-memory store") + + // Wait for more than 30 seconds for the rate limit to reset + t.Logf("Waiting 35 seconds for rate limit ticker to reset...") + time.Sleep(35 * time.Second) + + // Get rate limit data after reset + getDataResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp3.StatusCode != 200 { + t.Fatalf("Failed to get governance data after reset wait: status %d", getDataResp3.StatusCode) + } + + // Verify rate limit has been reset (usage should be 0 or close to it) + t.Logf("Rate limit reset should have occurred after 30s timeout āœ“") +} + +// TestUsageTrackingBudgetReset tests that budget resets happen correctly on ticker +func TestUsageTrackingBudgetReset(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a budget that resets every 30 seconds + vkName := "test-vk-budget-reset-" + generateRandomID() + budgetLimit := 1.0 // $1 budget + resetDuration := "30s" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: budgetLimit, + ResetDuration: resetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with budget: $%.2f reset every %s", vkName, budgetLimit, resetDuration) + + // Get initial budget data + getVKResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap := getVKResp.Body["virtual_keys"].(map[string]interface{}) + + getBudgetsResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap := getBudgetsResp.Body["budgets"].(map[string]interface{}) + + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + budgetID, _ := vkData["budget_id"].(string) + if budgetID == "" { + t.Fatalf("Budget ID not found for VK") + } + + budgetData := budgetsMap[budgetID].(map[string]interface{}) + initialUsage, _ := budgetData["current_usage"].(float64) + + t.Logf("Initial budget usage: $%.6f", initialUsage) + + // Make a request to consume budget + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test prompt for budget reset testing.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Logf("Request failed with status %d, body: %v", resp.StatusCode, resp.Body) + t.Skip("Could not execute request to test budget reset") + } + + // Get updated budget usage + time.Sleep(500 * time.Millisecond) + + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + budgetData2 := budgetsMap2[budgetID].(map[string]interface{}) + usageAfterRequest, _ := budgetData2["current_usage"].(float64) + + t.Logf("Budget usage after request: $%.6f", usageAfterRequest) + + // Wait for budget reset + t.Logf("Waiting 35 seconds for budget ticker to reset...") + time.Sleep(35 * time.Second) + + // Get budget data after reset + getDataResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp3.StatusCode != 200 { + t.Fatalf("Failed to get governance data after reset wait: status %d", getDataResp3.StatusCode) + } + + getBudgetsResp3 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap3 := getBudgetsResp3.Body["budgets"].(map[string]interface{}) + budgetData3 := budgetsMap3[budgetID].(map[string]interface{}) + usageAfterReset, _ := budgetData3["current_usage"].(float64) + + // Budget should be reset (close to 0) + if usageAfterReset > 0.001 { + t.Fatalf("Budget not reset after 30s timeout: usage is $%.6f (should be ~0)", usageAfterReset) + } + + t.Logf("Budget reset correctly after 30s timeout āœ“") +} + +// TestInMemoryUsageUpdateOnRequest tests that in-memory usage counters are updated on request +func TestInMemoryUsageUpdateOnRequest(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with no limits (to ensure request succeeds) + vkName := "test-vk-usage-update-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s for usage tracking test", vkName) + + // Make a request to consume tokens + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Short test prompt for usage tracking.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Logf("Request failed with status %d", resp.StatusCode) + t.Skip("Could not execute request to test usage tracking") + } + + // Extract token usage from response + var tokensUsed int + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if totalTokens, ok := usage["total_tokens"].(float64); ok { + tokensUsed = int(totalTokens) + } + } + + if tokensUsed == 0 { + t.Logf("No token usage in response") + t.Skip("Could not extract token usage from response") + } + + t.Logf("Request consumed %d tokens", tokensUsed) + + // Give time for async update + time.Sleep(1 * time.Second) + + // Check in-memory store for updated rate limit usage + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + + // Rate limit should exist and be updated + rateLimitID, _ := vkData["rate_limit_id"].(string) + if rateLimitID != "" { + t.Logf("Rate limit tracking is enabled for VK āœ“") + } else { + t.Logf("No rate limit on VK (optional)") + } + + t.Logf("In-memory usage tracking verified āœ“") +} + +// TestResetTickerBothBudgetAndRateLimit tests that ticker resets both budget and rate limit together +func TestResetTickerBothBudgetAndRateLimit(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with both budget and rate limit that reset every 30 seconds + vkName := "test-vk-both-reset-" + generateRandomID() + budgetLimit := 2.0 + budgetResetDuration := "30s" + tokenLimit := int64(50000) + tokenResetDuration := "30s" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: budgetLimit, + ResetDuration: budgetResetDuration, + }, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with budget and rate limit both resetting every 30s", vkName) + + // Make requests to consume both budget and tokens + for i := 0; i < 3; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Test request " + string(rune('0'+i)) + " for reset ticker test.", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode != 200 { + t.Logf("Request %d failed with status %d", i+1, resp.StatusCode) + break + } + t.Logf("Request %d succeeded", i+1) + } + + // Get usage before reset + getVKResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + virtualKeysMap := getVKResp.Body["virtual_keys"].(map[string]interface{}) + + getBudgetsResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap := getBudgetsResp.Body["budgets"].(map[string]interface{}) + + vkData := virtualKeysMap[vkValue].(map[string]interface{}) + budgetID, _ := vkData["budget_id"].(string) + + var usageBeforeReset float64 + if budgetID != "" { + budgetData := budgetsMap[budgetID].(map[string]interface{}) + usageBeforeReset, _ = budgetData["current_usage"].(float64) + } + + t.Logf("Budget usage before reset: $%.6f", usageBeforeReset) + + // Wait for reset + t.Logf("Waiting 35 seconds for reset ticker...") + time.Sleep(35 * time.Second) + + // Get usage after reset + getBudgetsResp2 := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/budgets?from_memory=true", + }) + + budgetsMap2 := getBudgetsResp2.Body["budgets"].(map[string]interface{}) + + var usageAfterReset float64 + if budgetID != "" { + budgetData2 := budgetsMap2[budgetID].(map[string]interface{}) + usageAfterReset, _ = budgetData2["current_usage"].(float64) + } + + t.Logf("Budget usage after reset: $%.6f", usageAfterReset) + + if usageBeforeReset > 0 && usageAfterReset >= usageBeforeReset { + t.Fatalf("Budget not reset properly: before=$%.6f, after=$%.6f (expected reset to ~0)", usageBeforeReset, usageAfterReset) + } + + t.Logf("Both budget and rate limit reset on ticker āœ“") +} + +// TestDataPersistenceAcrossRequests tests that budget and rate limit data persists correctly +func TestDataPersistenceAcrossRequests(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with both budget and rate limit + vkName := "test-vk-persistence-" + generateRandomID() + budgetLimit := 5.0 + budgetResetDuration := "1h" + tokenLimit := int64(100000) + tokenResetDuration := "1h" + requestLimit := int64(100) + requestResetDuration := "1h" + + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: budgetLimit, + ResetDuration: budgetResetDuration, + }, + RateLimit: &CreateRateLimitRequest{ + TokenMaxLimit: &tokenLimit, + TokenResetDuration: &tokenResetDuration, + RequestMaxLimit: &requestLimit, + RequestResetDuration: &requestResetDuration, + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s for persistence testing", vkName) + + // Make multiple requests and verify data persists + successCount := 0 + for i := 0; i < 2; i++ { + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: "Persistence test request " + string(rune('0'+i)) + ".", + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode == 200 { + successCount++ + } else { + t.Logf("Request %d failed with status %d", i+1, resp.StatusCode) + } + } + + if successCount == 0 { + t.Skip("Could not make requests to test persistence") + } + + t.Logf("Made %d successful requests", successCount) + + // Verify data persists in in-memory store + getDataResp := MakeRequest(t, APIRequest{ + Method: "GET", + Path: "/api/governance/virtual-keys?from_memory=true", + }) + + if getDataResp.StatusCode != 200 { + t.Fatalf("Failed to get governance data: status %d", getDataResp.StatusCode) + } + + virtualKeysMap := getDataResp.Body["virtual_keys"].(map[string]interface{}) + + vkData, exists := virtualKeysMap[vkValue] + if !exists { + t.Fatalf("VK not found in in-memory store after requests") + } + + vkDataMap := vkData.(map[string]interface{}) + budgetID, _ := vkDataMap["budget_id"].(string) + rateLimitID, _ := vkDataMap["rate_limit_id"].(string) + + if budgetID == "" { + t.Fatalf("Budget ID not found for VK") + } + if rateLimitID == "" { + t.Fatalf("Rate limit ID not found for VK") + } + + t.Logf("VK data persists correctly in in-memory store āœ“") +} diff --git a/plugins/governance/utils.go b/plugins/governance/utils.go index c3d39ad3fc..7b625645f9 100644 --- a/plugins/governance/utils.go +++ b/plugins/governance/utils.go @@ -15,6 +15,15 @@ func getStringFromContext(ctx context.Context, key any) string { return "" } +// equalPtr compares two pointers of comparable type for value equality +// Returns true if both are nil or both are non-nil with equal values +func equalPtr[T comparable](a, b *T) bool { + if a == nil || b == nil { + return a == b + } + return *a == *b +} + // getWeight safely dereferences a *float64 weight pointer, returning 1.0 as default if nil. // This allows distinguishing between "not set" (nil -> 1.0) and "explicitly set to 0" (0.0). func getWeight(w *float64) float64 { @@ -22,4 +31,4 @@ func getWeight(w *float64) float64 { return 1.0 } return *w -} \ No newline at end of file +} diff --git a/plugins/governance/version b/plugins/governance/version index 439f0a32d7..5596554988 100644 --- a/plugins/governance/version +++ b/plugins/governance/version @@ -1 +1 @@ -1.3.62 \ No newline at end of file +1.4.9 \ No newline at end of file diff --git a/plugins/governance/vkbudget_test.go b/plugins/governance/vkbudget_test.go new file mode 100644 index 0000000000..0ddce49520 --- /dev/null +++ b/plugins/governance/vkbudget_test.go @@ -0,0 +1,131 @@ +package governance + +import ( + "strconv" + "testing" +) + +// TestVKBudgetExceeded tests that VK level budgets are enforced by making requests until budget is consumed +func TestVKBudgetExceeded(t *testing.T) { + t.Parallel() + testData := NewGlobalTestData() + defer testData.Cleanup(t) + + // Create a VK with a fixed budget + vkBudget := 0.01 + vkName := "test-vk-budget-exceeded-" + generateRandomID() + createVKResp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/api/governance/virtual-keys", + Body: CreateVirtualKeyRequest{ + Name: vkName, + Budget: &BudgetRequest{ + MaxLimit: vkBudget, + ResetDuration: "1h", + }, + }, + }) + + if createVKResp.StatusCode != 200 { + t.Fatalf("Failed to create VK: status %d", createVKResp.StatusCode) + } + + vkID := ExtractIDFromResponse(t, createVKResp, "id") + testData.AddVirtualKey(vkID) + + vk := createVKResp.Body["virtual_key"].(map[string]interface{}) + vkValue := vk["value"].(string) + + t.Logf("Created VK %s with budget $%.2f", vkName, vkBudget) + + // Keep making requests, tracking actual token usage from responses, until budget is exceeded + consumedBudget := 0.0 + requestNum := 1 + var lastSuccessfulCost float64 + + var shouldStop = false + + for requestNum <= 50 { + // Create a longer prompt to consume more tokens and budget faster + longPrompt := "Please provide a comprehensive and detailed response to the following question. " + + "I need extensive information covering all aspects of the topic. " + + "Provide multiple paragraphs with detailed explanations. " + + "Request number " + strconv.Itoa(requestNum) + ". " + + "Here is a detailed prompt that will consume significant tokens: " + + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum. Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + + "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris. " + + "Nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit. " + + "In voluptate velit esse cillum dolore eu fugiat nulla pariatur. " + + "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt. " + + "Mollit anim id est laborum." + + resp := MakeRequest(t, APIRequest{ + Method: "POST", + Path: "/v1/chat/completions", + Body: ChatCompletionRequest{ + Model: "openai/gpt-4o", + Messages: []ChatMessage{ + { + Role: "user", + Content: longPrompt, + }, + }, + }, + VKHeader: &vkValue, + }) + + if resp.StatusCode >= 400 { + // Request failed - check if it's due to budget + if CheckErrorMessage(t, resp, "budget") { + t.Logf("Request %d correctly rejected: budget exceeded", requestNum) + t.Logf("Consumed budget: $%.6f (limit: $%.2f)", consumedBudget, vkBudget) + t.Logf("Last successful request cost: $%.6f", lastSuccessfulCost) + + // Verify that we made at least one successful request before hitting budget + if requestNum == 1 { + t.Fatalf("First request should have succeeded but was rejected due to budget") + } + return // Test passed + } else { + t.Fatalf("Request %d failed with unexpected error (not budget): %v", requestNum, resp.Body) + } + } + + // Request succeeded - extract actual token usage from response + if usage, ok := resp.Body["usage"].(map[string]interface{}); ok { + if prompt, ok := usage["prompt_tokens"].(float64); ok { + if completion, ok := usage["completion_tokens"].(float64); ok { + actualInputTokens := int(prompt) + actualOutputTokens := int(completion) + actualCost, _ := CalculateCost("openai/gpt-4o", actualInputTokens, actualOutputTokens) + + consumedBudget += actualCost + lastSuccessfulCost = actualCost + + t.Logf("Request %d succeeded: input_tokens=%d, output_tokens=%d, cost=$%.6f, consumed=$%.6f/$%.2f", + requestNum, actualInputTokens, actualOutputTokens, actualCost, consumedBudget, vkBudget) + } + } + } + + requestNum++ + + if shouldStop { + break + } + + if consumedBudget >= vkBudget { + shouldStop = true + } + } + + t.Fatalf("Made %d requests but never hit budget limit (consumed $%.6f / $%.2f) - budget not being enforced", + requestNum-1, consumedBudget, vkBudget) +} diff --git a/plugins/jsonparser/go.mod b/plugins/jsonparser/go.mod index 7bf9c0ac02..a0e99a11e7 100644 --- a/plugins/jsonparser/go.mod +++ b/plugins/jsonparser/go.mod @@ -2,7 +2,7 @@ module github.com/maximhq/bifrost/plugins/jsonparser go 1.25.5 -require github.com/maximhq/bifrost/core v1.2.49 +require github.com/maximhq/bifrost/core v1.3.8 require ( cloud.google.com/go/compute/metadata v0.9.0 // indirect @@ -35,8 +35,13 @@ require ( github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic v1.14.2 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect + github.com/clarkmcc/go-typescript v0.7.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f // indirect github.com/google/uuid v1.6.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/klauspost/compress v1.18.2 // indirect diff --git a/plugins/jsonparser/go.sum b/plugins/jsonparser/go.sum index 1d6c4a7a62..9c7a631bed 100644 --- a/plugins/jsonparser/go.sum +++ b/plugins/jsonparser/go.sum @@ -12,6 +12,8 @@ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= +github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= @@ -62,6 +64,8 @@ github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPII github.com/bytedance/sonic v1.14.2/go.mod h1:T80iDELeHiHKSc0C9tubFygiuXoGzrkjKzX2quAx980= github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2NYzevs+o= github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -69,13 +73,21 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= @@ -105,8 +117,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/maximhq/bifrost/core v1.2.49 h1:fk6l6r3kVBlpN73wYXmgtV6O4bhedOjSO4LAEz/7leg= -github.com/maximhq/bifrost/core v1.2.49/go.mod h1:z7nOx15e91ktZGi+pZHq+uhShlEK+fM4UyYUpP6oHAw= +github.com/maximhq/bifrost/core v1.3.8 h1:xtwB9+HeTzYz5IKHkpUtupzBd0A5yl1avdLJGjsOKPI= +github.com/maximhq/bifrost/core v1.3.8/go.mod h1:abKQRnJQPZz8/UMxCcbuNHEyq19Db+IX4KlGJdlLY8E= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -161,6 +173,8 @@ golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/jsonparser/main.go b/plugins/jsonparser/main.go index 7281790e98..5ce59be483 100644 --- a/plugins/jsonparser/main.go +++ b/plugins/jsonparser/main.go @@ -83,24 +83,16 @@ func (p *JsonParserPlugin) GetName() string { return PluginName } -// TransportInterceptor is not used for this plugin -// Parameters: -// - ctx: The Bifrost context -// - url: The URL of the request -// - headers: The request headers -// - body: The request body -// Returns: -// - map[string]string: The updated request headers -// - map[string]any: The updated request body -// - error: Any error that occurred during processing -func (p *JsonParserPlugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return headers, body, nil +// HTTPTransportIntercept is not used for this plugin +func (p *JsonParserPlugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil } // PreHook is not used for this plugin as we only process responses // Parameters: // - ctx: The Bifrost context // - req: The Bifrost request +// // Returns: // - *schemas.BifrostRequest: The processed request // - *schemas.PluginShortCircuit: The plugin short circuit if the request is not allowed @@ -114,6 +106,7 @@ func (p *JsonParserPlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.Bif // - ctx: The Bifrost context // - result: The Bifrost response to be processed // - err: The Bifrost error to be processed +// // Returns: // - *schemas.BifrostResponse: The processed response // - *schemas.BifrostError: The processed error diff --git a/plugins/jsonparser/plugin_test.go b/plugins/jsonparser/plugin_test.go index 9fdc0bfee5..52c26dde78 100644 --- a/plugins/jsonparser/plugin_test.go +++ b/plugins/jsonparser/plugin_test.go @@ -23,7 +23,7 @@ func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvide // GetKeysForProvider returns a mock API key configuration for testing. // Uses the OPENAI_API_KEY environment variable for authentication. -func (baseAccount *BaseAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { +func (baseAccount *BaseAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { return []schemas.Key{ { Value: os.Getenv("OPENAI_API_KEY"), @@ -52,7 +52,7 @@ func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelPr // Required environment variables: // - OPENAI_API_KEY: Your OpenAI API key for the test request func TestJsonParserPluginEndToEnd(t *testing.T) { - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) // Check if OpenAI API key is set if os.Getenv("OPENAI_API_KEY") == "" { t.Skip("OPENAI_API_KEY is not set, skipping end-to-end test") @@ -202,7 +202,7 @@ func TestJsonParserPluginPerRequest(t *testing.T) { } // Create context with plugin enabled - newContext := context.WithValue(ctx, EnableStreamingJSONParser, true) + newContext := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline).WithValue(EnableStreamingJSONParser, true) // Make the streaming request responseChan, bifrostErr := client.ChatCompletionStreamRequest(newContext, request) diff --git a/plugins/jsonparser/version b/plugins/jsonparser/version index 439f0a32d7..721b9931f4 100644 --- a/plugins/jsonparser/version +++ b/plugins/jsonparser/version @@ -1 +1 @@ -1.3.62 \ No newline at end of file +1.4.8 \ No newline at end of file diff --git a/plugins/logging/go.mod b/plugins/logging/go.mod index 8ef15d1eac..2da17bf6f8 100644 --- a/plugins/logging/go.mod +++ b/plugins/logging/go.mod @@ -4,8 +4,8 @@ go 1.25.5 require ( github.com/bytedance/sonic v1.14.2 - github.com/maximhq/bifrost/core v1.2.49 - github.com/maximhq/bifrost/framework v1.1.61 + github.com/maximhq/bifrost/core v1.3.8 + github.com/maximhq/bifrost/framework v1.2.8 ) require ( @@ -39,8 +39,11 @@ require ( github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/clarkmcc/go-typescript v0.7.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/analysis v0.24.2 // indirect @@ -64,8 +67,10 @@ require ( github.com/go-openapi/swag/typeutils v0.25.4 // indirect github.com/go-openapi/swag/yamlutils v0.25.4 // indirect github.com/go-openapi/validate v0.25.1 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f // indirect github.com/google/uuid v1.6.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect diff --git a/plugins/logging/go.sum b/plugins/logging/go.sum index 0c229672d9..065f485a18 100644 --- a/plugins/logging/go.sum +++ b/plugins/logging/go.sum @@ -12,6 +12,8 @@ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= +github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= @@ -68,6 +70,8 @@ github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2N github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -77,6 +81,10 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -132,6 +140,8 @@ github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6 github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= github.com/go-openapi/validate v0.25.1 h1:sSACUI6Jcnbo5IWqbYHgjibrhhmt3vR6lCzKZnmAgBw= github.com/go-openapi/validate v0.25.1/go.mod h1:RMVyVFYte0gbSTaZ0N4KmTn6u/kClvAFp+mAVfS/DQc= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -141,6 +151,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= @@ -184,10 +196,10 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.2.49 h1:fk6l6r3kVBlpN73wYXmgtV6O4bhedOjSO4LAEz/7leg= -github.com/maximhq/bifrost/core v1.2.49/go.mod h1:z7nOx15e91ktZGi+pZHq+uhShlEK+fM4UyYUpP6oHAw= -github.com/maximhq/bifrost/framework v1.1.61 h1:fMjvICbkrdWMtGnLYrjSNrcmQYqtQvOh/swmrJTvf+E= -github.com/maximhq/bifrost/framework v1.1.61/go.mod h1:wVUPzB8K5S/5GWuxqp8dXf3nNZkqJsS/APMIcq48SOI= +github.com/maximhq/bifrost/core v1.3.8 h1:xtwB9+HeTzYz5IKHkpUtupzBd0A5yl1avdLJGjsOKPI= +github.com/maximhq/bifrost/core v1.3.8/go.mod h1:abKQRnJQPZz8/UMxCcbuNHEyq19Db+IX4KlGJdlLY8E= +github.com/maximhq/bifrost/framework v1.2.8 h1:/oTpacuw7k0zRUJ9dSSQRtAVx3nLGSiR7GFwOjGxZNs= +github.com/maximhq/bifrost/framework v1.2.8/go.mod h1:mjw9YXh/Oxi3HeBCJ+3HJ6ftv43Wo4t0T4EzpcIbnr0= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= @@ -283,6 +295,8 @@ google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/logging/main.go b/plugins/logging/main.go index 614c450635..8678a07172 100644 --- a/plugins/logging/main.go +++ b/plugins/logging/main.go @@ -105,10 +105,9 @@ type LoggerPlugin struct { logger schemas.Logger logCallback LogCallback droppedRequests atomic.Int64 - cleanupTicker *time.Ticker // Ticker for cleaning up old processing logs - logMsgPool sync.Pool // Pool for reusing LogMessage structs - updateDataPool sync.Pool // Pool for reusing UpdateLogData structs - accumulator *streaming.Accumulator // Accumulator for streaming chunks + cleanupTicker *time.Ticker // Ticker for cleaning up old processing logs + logMsgPool sync.Pool // Pool for reusing LogMessage structs + updateDataPool sync.Pool // Pool for reusing UpdateLogData structs } // Init creates new logger plugin with given log store @@ -140,7 +139,6 @@ func Init(ctx context.Context, config *Config, logger schemas.Logger, logsStore return &UpdateLogData{} }, }, - accumulator: streaming.NewAccumulator(pricingManager, logger), } // Prewarm the pools for better performance at startup @@ -174,7 +172,8 @@ func (p *LoggerPlugin) cleanupWorker() { func (p *LoggerPlugin) cleanupOldProcessingLogs() { // Calculate timestamp for 30 minutes ago in UTC to match log entry timestamps thirtyMinutesAgo := time.Now().UTC().Add(-1 * 30 * time.Minute) - p.logger.Debug("cleaning up old processing logs before %s", thirtyMinutesAgo) // Delete processing logs older than 30 minutes using the store + p.logger.Debug("cleaning up old processing logs before %s", thirtyMinutesAgo) + // Delete processing logs older than 30 minutes using the store if err := p.store.Flush(p.ctx, thirtyMinutesAgo); err != nil { p.logger.Warn("failed to cleanup old processing logs: %v", err) } @@ -192,19 +191,9 @@ func (p *LoggerPlugin) GetName() string { return PluginName } -// TransportInterceptor is not used for this plugin -// Parameters: -// - ctx: The Bifrost context -// - url: The URL of the request -// - headers: The request headers -// - body: The request body -// -// Returns: -// - map[string]string: The updated request headers -// - map[string]any: The updated request body -// - error: Any error that occurred during processing -func (p *LoggerPlugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return headers, body, nil +// HTTPTransportIntercept is not used for this plugin +func (p *LoggerPlugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil } // PreHook is called before a request is processed - FULLY ASYNC, NO DATABASE I/O @@ -233,9 +222,12 @@ func (p *LoggerPlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.Bifrost createdTimestamp := time.Now().UTC() - // If request type is streaming we create a stream accumulator + // If request type is streaming we create a stream accumulator via the tracer if bifrost.IsStreamRequestType(req.RequestType) { - p.accumulator.CreateStreamAccumulator(requestID, createdTimestamp) + tracer, traceID, err := bifrost.GetTracerFromContext(ctx) + if err == nil && tracer != nil && traceID != "" { + tracer.CreateStreamAccumulator(traceID, createdTimestamp) + } } provider, model, _ := req.GetRequestFields() @@ -272,7 +264,7 @@ func (p *LoggerPlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.Bifrost initialData.SpeechInput = req.SpeechRequest.Input case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: initialData.Params = req.TranscriptionRequest.Params - initialData.TranscriptionInput = req.TranscriptionRequest.Input + initialData.TranscriptionInput = req.TranscriptionRequest.Input } } @@ -367,8 +359,20 @@ func (p *LoggerPlugin) PostHook(ctx *schemas.BifrostContext, result *schemas.Bif virtualKeyName := getStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-virtual-key-name")) numberOfRetries := getIntFromContext(ctx, schemas.BifrostContextKeyNumberOfRetries) + requestType, _, _ := bifrost.GetResponseFields(result, bifrostErr) + + var tracer schemas.Tracer + var traceID string + if bifrost.IsStreamRequestType(requestType) { + var err error + tracer, traceID, err = bifrost.GetTracerFromContext(ctx) + if err != nil { + p.logger.Warn("failed to get traceID/tracer from context of logging plugin posthook: %v", err) + return result, bifrostErr, nil + } + } + go func() { - requestType, _, _ := bifrost.GetResponseFields(result, bifrostErr) // Queue the log update message (non-blocking) - use same pattern for both streaming and regular logMsg := p.getLogMessage() logMsg.RequestID = requestID @@ -387,10 +391,7 @@ func (p *LoggerPlugin) PostHook(ctx *schemas.BifrostContext, result *schemas.Bif // If response is nil, and there is an error, we update log with error if result == nil && bifrostErr != nil { - // If request type is streaming, then we trigger cleanup as well - if bifrost.IsStreamRequestType(requestType) { - p.accumulator.CleanupStreamAccumulator(requestID) - } + // Note: Stream accumulator cleanup is handled by the tracing middleware logMsg.Operation = LogOperationUpdate logMsg.UpdateData = &UpdateLogData{ Status: "error", @@ -429,10 +430,20 @@ func (p *LoggerPlugin) PostHook(ctx *schemas.BifrostContext, result *schemas.Bif if bifrost.IsStreamRequestType(requestType) { p.logger.Debug("[logging] processing streaming response") - streamResponse, err := p.accumulator.ProcessStreamingResponse(ctx, result, bifrostErr) - if err != nil { - p.logger.Debug("failed to process streaming response: %v", err) - } else if streamResponse != nil && streamResponse.Type == streaming.StreamResponseTypeFinal { + // Process streaming response via tracer's central accumulator + var streamResponse *streaming.ProcessedStreamResponse + if tracer != nil && traceID != "" { + accResult := tracer.ProcessStreamingChunk(ctx, traceID, result, bifrostErr) + if accResult != nil { + streamResponse = convertToProcessedStreamResponse(accResult, requestType) + } + } else { + p.logger.Debug("tracer or traceID not available in streaming path for request %s, skipping stream processing", logMsg.RequestID) + } + + if streamResponse == nil { + p.logger.Debug("failed to process streaming response: tracer or traceID not available") + } else if bifrost.IsFinalChunk(ctx) { // Prepare final log data logMsg.Operation = LogOperationStreamUpdate logMsg.StreamResponse = streamResponse @@ -447,7 +458,7 @@ func (p *LoggerPlugin) PostHook(ctx *schemas.BifrostContext, result *schemas.Bif logMsg.NumberOfRetries, logMsg.SemanticCacheDebug, logMsg.StreamResponse, - streamResponse.Type == streaming.StreamResponseTypeFinal, + bifrost.IsFinalChunk(ctx), ) }) if processingErr != nil { @@ -463,6 +474,11 @@ func (p *LoggerPlugin) PostHook(ctx *schemas.BifrostContext, result *schemas.Bif } p.mu.Unlock() } + // Note: Stream accumulator cleanup is handled by the tracer + if tracer != nil && traceID != "" { + p.logger.Debug("cleaning up stream accumulator for trace ID: %s in logging plugin posthook", traceID) + tracer.CleanupStreamAccumulator(traceID) + } } } else { // Handle regular response @@ -617,7 +633,7 @@ func (p *LoggerPlugin) Cleanup() error { close(p.done) // Wait for the background worker to finish processing remaining items p.wg.Wait() - p.accumulator.Cleanup() + // Note: Accumulator cleanup is handled by the tracer, not the logging plugin // GORM handles connection cleanup automatically return nil } diff --git a/plugins/logging/utils.go b/plugins/logging/utils.go index 93859bc7fe..2a085ae0f4 100644 --- a/plugins/logging/utils.go +++ b/plugins/logging/utils.go @@ -10,6 +10,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/framework/streaming" ) // KeyPair represents an ID-Name pair for keys @@ -224,3 +225,67 @@ func getIntFromContext(ctx context.Context, key any) int { } return 0 } + +// convertToProcessedStreamResponse converts a StreamAccumulatorResult to ProcessedStreamResponse +// for use with the logging plugin's streaming log update functionality. +func convertToProcessedStreamResponse(result *schemas.StreamAccumulatorResult, requestType schemas.RequestType) *streaming.ProcessedStreamResponse { + if result == nil { + return nil + } + + // Determine stream type from request type + var streamType streaming.StreamType + switch requestType { + case schemas.TextCompletionStreamRequest: + streamType = streaming.StreamTypeText + case schemas.ChatCompletionStreamRequest: + streamType = streaming.StreamTypeChat + case schemas.ResponsesStreamRequest: + streamType = streaming.StreamTypeResponses + case schemas.SpeechStreamRequest: + streamType = streaming.StreamTypeAudio + case schemas.TranscriptionStreamRequest: + streamType = streaming.StreamTypeTranscription + default: + streamType = streaming.StreamTypeChat + } + + // Build accumulated data + data := &streaming.AccumulatedData{ + RequestID: result.RequestID, + Model: result.Model, + Status: result.Status, + Stream: true, + Latency: result.Latency, + TimeToFirstToken: result.TimeToFirstToken, + OutputMessage: result.OutputMessage, + OutputMessages: result.OutputMessages, + ErrorDetails: result.ErrorDetails, + TokenUsage: result.TokenUsage, + Cost: result.Cost, + AudioOutput: result.AudioOutput, + TranscriptionOutput: result.TranscriptionOutput, + FinishReason: result.FinishReason, + RawResponse: result.RawResponse, + } + + // Handle tool calls if present + if result.OutputMessage != nil && result.OutputMessage.ChatAssistantMessage != nil { + data.ToolCalls = result.OutputMessage.ChatAssistantMessage.ToolCalls + } + + resp := &streaming.ProcessedStreamResponse{ + RequestID: result.RequestID, + StreamType: streamType, + Provider: result.Provider, + Model: result.Model, + Data: data, + } + + if result.RawRequest != nil { + rawReq := result.RawRequest + resp.RawRequest = &rawReq + } + + return resp +} diff --git a/plugins/logging/version b/plugins/logging/version index 439f0a32d7..721b9931f4 100644 --- a/plugins/logging/version +++ b/plugins/logging/version @@ -1 +1 @@ -1.3.62 \ No newline at end of file +1.4.8 \ No newline at end of file diff --git a/plugins/maxim/go.mod b/plugins/maxim/go.mod index e7b8d81606..7811bd3f62 100644 --- a/plugins/maxim/go.mod +++ b/plugins/maxim/go.mod @@ -3,9 +3,9 @@ module github.com/maximhq/bifrost/plugins/maxim go 1.25.5 require ( - github.com/maximhq/bifrost/core v1.2.49 - github.com/maximhq/bifrost/framework v1.1.61 - github.com/maximhq/maxim-go v0.1.15 + github.com/maximhq/bifrost/core v1.3.8 + github.com/maximhq/bifrost/framework v1.2.8 + github.com/maximhq/maxim-go v0.1.14 ) require github.com/google/uuid v1.6.0 @@ -42,8 +42,11 @@ require ( github.com/bytedance/sonic v1.14.2 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/clarkmcc/go-typescript v0.7.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/analysis v0.24.2 // indirect @@ -67,8 +70,10 @@ require ( github.com/go-openapi/swag/typeutils v0.25.4 // indirect github.com/go-openapi/swag/yamlutils v0.25.4 // indirect github.com/go-openapi/validate v0.25.1 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect diff --git a/plugins/maxim/go.sum b/plugins/maxim/go.sum index d7e48688e3..fed4236597 100644 --- a/plugins/maxim/go.sum +++ b/plugins/maxim/go.sum @@ -12,6 +12,8 @@ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= +github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= @@ -68,6 +70,8 @@ github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2N github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -77,6 +81,10 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -132,6 +140,8 @@ github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6 github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= github.com/go-openapi/validate v0.25.1 h1:sSACUI6Jcnbo5IWqbYHgjibrhhmt3vR6lCzKZnmAgBw= github.com/go-openapi/validate v0.25.1/go.mod h1:RMVyVFYte0gbSTaZ0N4KmTn6u/kClvAFp+mAVfS/DQc= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -141,6 +151,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= @@ -184,12 +196,12 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.2.49 h1:fk6l6r3kVBlpN73wYXmgtV6O4bhedOjSO4LAEz/7leg= -github.com/maximhq/bifrost/core v1.2.49/go.mod h1:z7nOx15e91ktZGi+pZHq+uhShlEK+fM4UyYUpP6oHAw= -github.com/maximhq/bifrost/framework v1.1.61 h1:fMjvICbkrdWMtGnLYrjSNrcmQYqtQvOh/swmrJTvf+E= -github.com/maximhq/bifrost/framework v1.1.61/go.mod h1:wVUPzB8K5S/5GWuxqp8dXf3nNZkqJsS/APMIcq48SOI= -github.com/maximhq/maxim-go v0.1.15 h1:PCoS5B/0QB3VqwqpgDgCHSTaYPVVKp/mFpb7iZ09XM0= -github.com/maximhq/maxim-go v0.1.15/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw= +github.com/maximhq/bifrost/core v1.3.8 h1:xtwB9+HeTzYz5IKHkpUtupzBd0A5yl1avdLJGjsOKPI= +github.com/maximhq/bifrost/core v1.3.8/go.mod h1:abKQRnJQPZz8/UMxCcbuNHEyq19Db+IX4KlGJdlLY8E= +github.com/maximhq/bifrost/framework v1.2.8 h1:/oTpacuw7k0zRUJ9dSSQRtAVx3nLGSiR7GFwOjGxZNs= +github.com/maximhq/bifrost/framework v1.2.8/go.mod h1:mjw9YXh/Oxi3HeBCJ+3HJ6ftv43Wo4t0T4EzpcIbnr0= +github.com/maximhq/maxim-go v0.1.14 h1:NQgpf3aRoD2Kq1GAqeSrLn3rQresn1H6mPP3JJ85qhA= +github.com/maximhq/maxim-go v0.1.14/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= @@ -285,6 +297,8 @@ google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/maxim/main.go b/plugins/maxim/main.go index 59e8f51119..93a05230cc 100644 --- a/plugins/maxim/main.go +++ b/plugins/maxim/main.go @@ -46,7 +46,6 @@ type Plugin struct { defaultLogRepoID string loggers map[string]*logging.Logger loggerMutex *sync.RWMutex - accumulator *streaming.Accumulator logger schemas.Logger } @@ -74,7 +73,6 @@ func Init(config *Config, logger schemas.Logger) (schemas.Plugin, error) { defaultLogRepoID: config.LogRepoID, loggers: make(map[string]*logging.Logger), loggerMutex: &sync.RWMutex{}, - accumulator: streaming.NewAccumulator(nil, logger), logger: logger, } @@ -103,6 +101,43 @@ const ( LogRepoIDKey schemas.BifrostContextKey = "log-repo-id" ) +// convertAccResultToProcessedStreamResponse converts StreamAccumulatorResult to ProcessedStreamResponse +func convertAccResultToProcessedStreamResponse(accResult *schemas.StreamAccumulatorResult) *streaming.ProcessedStreamResponse { + if accResult == nil { + return nil + } + // Determine StreamType based on the response content + streamType := streaming.StreamTypeChat + if accResult.AudioOutput != nil { + streamType = streaming.StreamTypeAudio + } else if accResult.TranscriptionOutput != nil { + streamType = streaming.StreamTypeTranscription + } else if len(accResult.OutputMessages) > 0 { + streamType = streaming.StreamTypeResponses + } + return &streaming.ProcessedStreamResponse{ + RequestID: accResult.RequestID, + StreamType: streamType, + Model: accResult.Model, + Provider: accResult.Provider, + Data: &streaming.AccumulatedData{ + Status: accResult.Status, + Latency: accResult.Latency, + TimeToFirstToken: accResult.TimeToFirstToken, + OutputMessage: accResult.OutputMessage, + OutputMessages: accResult.OutputMessages, + TokenUsage: accResult.TokenUsage, + Cost: accResult.Cost, + ErrorDetails: accResult.ErrorDetails, + AudioOutput: accResult.AudioOutput, + TranscriptionOutput: accResult.TranscriptionOutput, + FinishReason: accResult.FinishReason, + RawResponse: accResult.RawResponse, + }, + RawRequest: &accResult.RawRequest, + } +} + // The plugin provides request/response tracing functionality by integrating with Maxim's logging system. // It supports both chat completion and text completion requests, tracking the entire lifecycle of each request // including inputs, parameters, and responses. @@ -121,9 +156,9 @@ func (plugin *Plugin) GetName() string { return PluginName } -// TransportInterceptor is not used for this plugin -func (plugin *Plugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return headers, body, nil +// HTTPTransportIntercept is not used for this plugin +func (plugin *Plugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil } // getEffectiveLogRepoID determines which single log repo ID to use based on priority: @@ -213,25 +248,25 @@ func (plugin *Plugin) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostR // Check if context already has traceID and generationID if ctx != nil { - if existingGenerationID, ok := (*ctx).Value(GenerationIDKey).(string); ok && existingGenerationID != "" { + if existingGenerationID, ok := ctx.Value(GenerationIDKey).(string); ok && existingGenerationID != "" { // If generationID exists, return early return req, nil, nil } - if existingTraceID, ok := (*ctx).Value(TraceIDKey).(string); ok && existingTraceID != "" { + if existingTraceID, ok := ctx.Value(TraceIDKey).(string); ok && existingTraceID != "" { // If traceID exists, and no generationID, create a new generation on the trace traceID = existingTraceID } - if existingSessionID, ok := (*ctx).Value(SessionIDKey).(string); ok && existingSessionID != "" { + if existingSessionID, ok := ctx.Value(SessionIDKey).(string); ok && existingSessionID != "" { sessionID = existingSessionID } - if existingTraceName, ok := (*ctx).Value(TraceNameKey).(string); ok && existingTraceName != "" { + if existingTraceName, ok := ctx.Value(TraceNameKey).(string); ok && existingTraceName != "" { traceName = existingTraceName } - if existingGenerationName, ok := (*ctx).Value(GenerationNameKey).(string); ok && existingGenerationName != "" { + if existingGenerationName, ok := ctx.Value(GenerationNameKey).(string); ok && existingGenerationName != "" { generationName = existingGenerationName } } @@ -390,7 +425,6 @@ func (plugin *Plugin) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostR // Add generation to the effective log repository logger.AddGenerationToTrace(traceID, &generationConfig) - var requestID string if ctx != nil { if _, ok := ctx.Value(TraceIDKey).(string); !ok { ctx.SetValue(TraceIDKey, traceID) @@ -404,10 +438,14 @@ func (plugin *Plugin) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostR requestID = uuid.New().String() plugin.logger.Warn("%s request ID missing in PreHook, using fallback: %s", PluginLoggerPrefix, requestID) } - } - if bifrost.IsStreamRequestType(req.RequestType) { - plugin.accumulator.CreateStreamAccumulator(requestID, time.Now()) + // If streaming, create accumulator via central tracer using traceID + if bifrost.IsStreamRequestType(req.RequestType) { + tracer, bifrostTraceID, err := bifrost.GetTracerFromContext(ctx) + if err == nil && tracer != nil && bifrostTraceID != "" { + tracer.CreateStreamAccumulator(bifrostTraceID, time.Now()) + } + } } return req, nil, nil @@ -445,20 +483,28 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, result *schemas.Bifr return result, bifrostErr, nil } + // Capture context values BEFORE goroutine to avoid race conditions + // when the same context is reused across multiple requests + generationID, hasGenerationID := ctx.Value(GenerationIDKey).(string) + traceID, hasTraceID := ctx.Value(TraceIDKey).(string) + tags, hasTags := ctx.Value(TagsKey).(map[string]string) + go func() { requestType, _, model := bifrost.GetResponseFields(result, bifrostErr) var streamResponse *streaming.ProcessedStreamResponse - var err error if bifrost.IsStreamRequestType(requestType) { - streamResponse, err = plugin.accumulator.ProcessStreamingResponse(ctx, result, bifrostErr) - if err != nil { - plugin.logger.Error("%s failed to process streaming response: %v", PluginLoggerPrefix, err) - return + // Use central tracer's accumulator + tracer, bifrostTraceID, err := bifrost.GetTracerFromContext(ctx) + if err == nil && tracer != nil && bifrostTraceID != "" { + accResult := tracer.ProcessStreamingChunk(ctx, bifrostTraceID, result, bifrostErr) + if accResult != nil { + streamResponse = convertAccResultToProcessedStreamResponse(accResult) + } } - // Return the result if it is a delta response - if streamResponse == nil || streamResponse.Type == streaming.StreamResponseTypeDelta { + // Return if no stream response or it's a delta response + if streamResponse == nil || !bifrost.IsFinalChunk(ctx) { return } } @@ -467,8 +513,7 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, result *schemas.Bifr if err != nil { return } - generationID, ok := (*ctx).Value(GenerationIDKey).(string) - if ok { + if hasGenerationID { if bifrostErr != nil { // Safely extract message from nested error message := "" @@ -491,7 +536,11 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, result *schemas.Bifr logger.SetGenerationError(generationID, &genErr) if bifrost.IsStreamRequestType(requestType) { - plugin.accumulator.CleanupStreamAccumulator(requestID) + // Cleanup via central tracer + tracer, bifrostTraceID, err := bifrost.GetTracerFromContext(ctx) + if err == nil && tracer != nil && bifrostTraceID != "" { + tracer.CleanupStreamAccumulator(bifrostTraceID) + } } } else if result != nil { switch requestType { @@ -514,21 +563,23 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, result *schemas.Bifr logger.AddResultToGeneration(generationID, result.ResponsesResponse) } } - if streamResponse != nil && streamResponse.Type == streaming.StreamResponseTypeFinal { - plugin.accumulator.CleanupStreamAccumulator(requestID) + if streamResponse != nil && bifrost.IsFinalChunk(ctx) { + // Cleanup via central tracer + tracer, bifrostTraceID, err := bifrost.GetTracerFromContext(ctx) + if err == nil && tracer != nil && bifrostTraceID != "" { + tracer.CleanupStreamAccumulator(bifrostTraceID) + } } } logger.EndGeneration(generationID) } - traceID, ok := (*ctx).Value(TraceIDKey).(string) - if ok { + if hasTraceID { logger.EndTrace(traceID) } // add tags to the generation and trace - tags, ok := (*ctx).Value(TagsKey).(map[string]string) - if ok { + if hasTags { for key, value := range tags { if generationID != "" { logger.AddTagToGeneration(generationID, key, value) @@ -547,9 +598,6 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, result *schemas.Bifr } func (plugin *Plugin) Cleanup() error { - if plugin.accumulator != nil { - plugin.accumulator.Cleanup() - } // Flush all loggers plugin.loggerMutex.RLock() for _, logger := range plugin.loggers { diff --git a/plugins/maxim/plugin_test.go b/plugins/maxim/plugin_test.go index 9a80e360ad..a7416ce713 100644 --- a/plugins/maxim/plugin_test.go +++ b/plugins/maxim/plugin_test.go @@ -54,7 +54,7 @@ func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvide // GetKeysForProvider returns a mock API key configuration for testing. // Uses the OPENAI_API_KEY environment variable for authentication. -func (baseAccount *BaseAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { +func (baseAccount *BaseAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { return []schemas.Key{ { Value: os.Getenv("OPENAI_API_KEY"), @@ -104,7 +104,7 @@ func TestMaximLoggerPlugin(t *testing.T) { } // Make a test chat completion request - _, bifrostErr := client.ChatCompletionRequest(context.Background(), &schemas.BifrostChatRequest{ + _, bifrostErr := client.ChatCompletionRequest(schemas.NewBifrostContext(context.Background(), schemas.NoDeadline), &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", Input: []schemas.ChatMessage{ diff --git a/plugins/maxim/version b/plugins/maxim/version index 4eb3ee9026..fa5512aeca 100644 --- a/plugins/maxim/version +++ b/plugins/maxim/version @@ -1 +1 @@ -1.4.63 \ No newline at end of file +1.5.8 \ No newline at end of file diff --git a/plugins/mocker/go.mod b/plugins/mocker/go.mod index 2c62d76470..19e8bb3dc9 100644 --- a/plugins/mocker/go.mod +++ b/plugins/mocker/go.mod @@ -4,7 +4,7 @@ go 1.25.5 require ( github.com/jaswdr/faker/v2 v2.8.0 - github.com/maximhq/bifrost/core v1.2.49 + github.com/maximhq/bifrost/core v1.3.8 ) require ( @@ -38,8 +38,13 @@ require ( github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic v1.14.2 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect + github.com/clarkmcc/go-typescript v0.7.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f // indirect github.com/google/uuid v1.6.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/klauspost/compress v1.18.2 // indirect diff --git a/plugins/mocker/go.sum b/plugins/mocker/go.sum index e21e94e343..0706b21e5f 100644 --- a/plugins/mocker/go.sum +++ b/plugins/mocker/go.sum @@ -12,6 +12,8 @@ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= +github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= @@ -62,6 +64,8 @@ github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPII github.com/bytedance/sonic v1.14.2/go.mod h1:T80iDELeHiHKSc0C9tubFygiuXoGzrkjKzX2quAx980= github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2NYzevs+o= github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -69,13 +73,21 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= @@ -107,8 +119,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/maximhq/bifrost/core v1.2.49 h1:fk6l6r3kVBlpN73wYXmgtV6O4bhedOjSO4LAEz/7leg= -github.com/maximhq/bifrost/core v1.2.49/go.mod h1:z7nOx15e91ktZGi+pZHq+uhShlEK+fM4UyYUpP6oHAw= +github.com/maximhq/bifrost/core v1.3.8 h1:xtwB9+HeTzYz5IKHkpUtupzBd0A5yl1avdLJGjsOKPI= +github.com/maximhq/bifrost/core v1.3.8/go.mod h1:abKQRnJQPZz8/UMxCcbuNHEyq19Db+IX4KlGJdlLY8E= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -163,6 +175,8 @@ golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/mocker/main.go b/plugins/mocker/main.go index d15dfacdba..dbfc183803 100644 --- a/plugins/mocker/main.go +++ b/plugins/mocker/main.go @@ -478,9 +478,9 @@ func (p *MockerPlugin) GetName() string { return PluginName } -// TransportInterceptor is not used for this plugin -func (p *MockerPlugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return headers, body, nil +// HTTPTransportIntercept is not used for this plugin +func (p *MockerPlugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil } // PreHook intercepts requests and applies mocking rules based on configuration diff --git a/plugins/mocker/plugin_test.go b/plugins/mocker/plugin_test.go index 7fde1e7819..5ccc515b46 100644 --- a/plugins/mocker/plugin_test.go +++ b/plugins/mocker/plugin_test.go @@ -21,7 +21,7 @@ func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvide // GetKeysForProvider returns a dummy API key configuration for testing. // Since we're testing the mocker plugin, these keys should never be used // as the plugin intercepts requests before they reach the actual providers. -func (baseAccount *BaseAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { +func (baseAccount *BaseAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { return []schemas.Key{ { Value: "dummy-api-key-for-testing", // Dummy key @@ -52,7 +52,7 @@ func TestMockerPlugin_GetName(t *testing.T) { // TestMockerPlugin_Disabled tests that disabled plugin doesn't interfere func TestMockerPlugin_Disabled(t *testing.T) { - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) config := MockerConfig{ Enabled: false, } @@ -95,7 +95,7 @@ func TestMockerPlugin_Disabled(t *testing.T) { // TestMockerPlugin_DefaultMockRule tests the default catch-all rule func TestMockerPlugin_DefaultMockRule(t *testing.T) { - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) config := MockerConfig{ Enabled: true, // No rules provided, should create default rule } @@ -147,7 +147,7 @@ func TestMockerPlugin_DefaultMockRule(t *testing.T) { // TestMockerPlugin_CustomSuccessRule tests custom success response func TestMockerPlugin_CustomSuccessRule(t *testing.T) { - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) config := MockerConfig{ Enabled: true, Rules: []MockRule{ @@ -226,7 +226,7 @@ func TestMockerPlugin_CustomSuccessRule(t *testing.T) { // TestMockerPlugin_ErrorResponse tests error response generation func TestMockerPlugin_ErrorResponse(t *testing.T) { - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) allowFallbacks := false config := MockerConfig{ Enabled: true, @@ -296,7 +296,7 @@ func TestMockerPlugin_ErrorResponse(t *testing.T) { // TestMockerPlugin_MessageTemplate tests template variable substitution func TestMockerPlugin_MessageTemplate(t *testing.T) { - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) config := MockerConfig{ Enabled: true, Rules: []MockRule{ @@ -366,7 +366,7 @@ func TestMockerPlugin_MessageTemplate(t *testing.T) { // TestMockerPlugin_Statistics tests plugin statistics tracking func TestMockerPlugin_Statistics(t *testing.T) { - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) config := MockerConfig{ Enabled: true, Rules: []MockRule{ diff --git a/plugins/mocker/version b/plugins/mocker/version index 0629a74033..721b9931f4 100644 --- a/plugins/mocker/version +++ b/plugins/mocker/version @@ -1 +1 @@ -1.3.60 \ No newline at end of file +1.4.8 \ No newline at end of file diff --git a/plugins/otel/converter.go b/plugins/otel/converter.go index d113b2bf54..9decd7c01f 100644 --- a/plugins/otel/converter.go +++ b/plugins/otel/converter.go @@ -4,10 +4,8 @@ import ( "encoding/hex" "fmt" "strings" - "time" "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/framework/modelcatalog" commonpb "go.opentelemetry.io/proto/otlp/common/v1" resourcepb "go.opentelemetry.io/proto/otlp/resource/v1" tracepb "go.opentelemetry.io/proto/otlp/trace/v1" @@ -71,1021 +69,217 @@ func hexToBytes(hexStr string, length int) []byte { return bytes } -// getSpeechRequestParams handles the speech request -func getSpeechRequestParams(req *schemas.BifrostSpeechRequest) []*KeyValue { - params := []*KeyValue{} - if req.Params != nil { - if req.Params.VoiceConfig != nil { - if req.Params.VoiceConfig.Voice != nil { - params = append(params, kvStr("gen_ai.request.voice", *req.Params.VoiceConfig.Voice)) - } - if len(req.Params.VoiceConfig.MultiVoiceConfig) > 0 { - multiVoiceConfigParams := []*KeyValue{} - for _, voiceConfig := range req.Params.VoiceConfig.MultiVoiceConfig { - multiVoiceConfigParams = append(multiVoiceConfigParams, kvStr("gen_ai.request.voice", voiceConfig.Voice)) - } - params = append(params, kvAny("gen_ai.request.multi_voice_config", arrValue(listValue(multiVoiceConfigParams...)))) - } - } - params = append(params, kvStr("gen_ai.request.instructions", req.Params.Instructions)) - params = append(params, kvStr("gen_ai.request.response_format", req.Params.ResponseFormat)) - if req.Params.Speed != nil { - params = append(params, kvDbl("gen_ai.request.speed", *req.Params.Speed)) - } - } - if req.Input != nil { - params = append(params, kvStr("gen_ai.input.speech", req.Input.Input)) - } - return params -} - -// getEmbeddingRequestParams handles the embedding request -func getEmbeddingRequestParams(req *schemas.BifrostEmbeddingRequest) []*KeyValue { - params := []*KeyValue{} - if req.Params != nil { - if req.Params.Dimensions != nil { - params = append(params, kvInt("gen_ai.request.dimensions", int64(*req.Params.Dimensions))) - } - if req.Params.ExtraParams != nil { - for k, v := range req.Params.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - if req.Params.EncodingFormat != nil { - params = append(params, kvStr("gen_ai.request.encoding_format", *req.Params.EncodingFormat)) - } - } - if req.Input.Text != nil { - params = append(params, kvStr("gen_ai.input.text", *req.Input.Text)) - } - if req.Input.Texts != nil { - params = append(params, kvStr("gen_ai.input.text", strings.Join(req.Input.Texts, ","))) +// convertTraceToResourceSpan converts a Bifrost trace to OTEL ResourceSpan +func (p *OtelPlugin) convertTraceToResourceSpan(trace *schemas.Trace) *ResourceSpan { + otelSpans := make([]*Span, 0, len(trace.Spans)) + for _, span := range trace.Spans { + otelSpans = append(otelSpans, p.convertSpanToOTELSpan(trace.TraceID, span)) } - if req.Input.Embedding != nil { - embedding := make([]string, len(req.Input.Embedding)) - for i, v := range req.Input.Embedding { - embedding[i] = fmt.Sprintf("%d", v) - } - params = append(params, kvStr("gen_ai.input.embedding", strings.Join(embedding, ","))) - } - return params -} - -// getTextCompletionRequestParams handles the text completion request -func getTextCompletionRequestParams(req *schemas.BifrostTextCompletionRequest) []*KeyValue { - params := []*KeyValue{} - if req.Params != nil { - if req.Params.MaxTokens != nil { - params = append(params, kvInt("gen_ai.request.max_tokens", int64(*req.Params.MaxTokens))) - } - if req.Params.Temperature != nil { - params = append(params, kvDbl("gen_ai.request.temperature", *req.Params.Temperature)) - } - if req.Params.TopP != nil { - params = append(params, kvDbl("gen_ai.request.top_p", *req.Params.TopP)) - } - if req.Params.Stop != nil { - params = append(params, kvStr("gen_ai.request.stop_sequences", strings.Join(req.Params.Stop, ","))) - } - if req.Params.PresencePenalty != nil { - params = append(params, kvDbl("gen_ai.request.presence_penalty", *req.Params.PresencePenalty)) - } - if req.Params.FrequencyPenalty != nil { - params = append(params, kvDbl("gen_ai.request.frequency_penalty", *req.Params.FrequencyPenalty)) - } - if req.Params.BestOf != nil { - params = append(params, kvInt("gen_ai.request.best_of", int64(*req.Params.BestOf))) - } - if req.Params.Echo != nil { - params = append(params, kvBool("gen_ai.request.echo", *req.Params.Echo)) - } - if req.Params.LogitBias != nil { - params = append(params, kvStr("gen_ai.request.logit_bias", fmt.Sprintf("%v", req.Params.LogitBias))) - } - if req.Params.LogProbs != nil { - params = append(params, kvInt("gen_ai.request.logprobs", int64(*req.Params.LogProbs))) - } - if req.Params.N != nil { - params = append(params, kvInt("gen_ai.request.n", int64(*req.Params.N))) - } - if req.Params.Seed != nil { - params = append(params, kvInt("gen_ai.request.seed", int64(*req.Params.Seed))) - } - if req.Params.Suffix != nil { - params = append(params, kvStr("gen_ai.request.suffix", *req.Params.Suffix)) - } - if req.Params.User != nil { - params = append(params, kvStr("gen_ai.request.user", *req.Params.User)) - } - if req.Params.ExtraParams != nil { - for k, v := range req.Params.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - } - if req.Input.PromptStr != nil { - params = append(params, kvStr("gen_ai.input.text", *req.Input.PromptStr)) - } - if req.Input.PromptArray != nil { - params = append(params, kvStr("gen_ai.input.text", strings.Join(req.Input.PromptArray, ","))) - } - return params -} - -// getChatRequestParams handles the chat completion request -func getChatRequestParams(req *schemas.BifrostChatRequest) []*KeyValue { - params := []*KeyValue{} - if req.Params != nil { - if req.Params.MaxCompletionTokens != nil { - params = append(params, kvInt("gen_ai.request.max_tokens", int64(*req.Params.MaxCompletionTokens))) - } - if req.Params.Temperature != nil { - params = append(params, kvDbl("gen_ai.request.temperature", *req.Params.Temperature)) - } - if req.Params.TopP != nil { - params = append(params, kvDbl("gen_ai.request.top_p", *req.Params.TopP)) - } - if req.Params.Stop != nil { - params = append(params, kvStr("gen_ai.request.stop_sequences", strings.Join(req.Params.Stop, ","))) - } - if req.Params.PresencePenalty != nil { - params = append(params, kvDbl("gen_ai.request.presence_penalty", *req.Params.PresencePenalty)) - } - if req.Params.FrequencyPenalty != nil { - params = append(params, kvDbl("gen_ai.request.frequency_penalty", *req.Params.FrequencyPenalty)) - } - if req.Params.ParallelToolCalls != nil { - params = append(params, kvBool("gen_ai.request.parallel_tool_calls", *req.Params.ParallelToolCalls)) - } - if req.Params.User != nil { - params = append(params, kvStr("gen_ai.request.user", *req.Params.User)) - } - if req.Params.ExtraParams != nil { - for k, v := range req.Params.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - } - // Handling chat completion - if req.Input != nil { - messages := []*AnyValue{} - for _, message := range req.Input { - if message.Content == nil { - continue - } - switch message.Role { - case schemas.ChatMessageRoleUser: - kvs := []*KeyValue{kvStr("role", "user")} - if message.Content.ContentStr != nil { - kvs = append(kvs, kvStr("content", *message.Content.ContentStr)) - } - messages = append(messages, listValue(kvs...)) - case schemas.ChatMessageRoleAssistant: - kvs := []*KeyValue{kvStr("role", "assistant")} - if message.Content.ContentStr != nil { - kvs = append(kvs, kvStr("content", *message.Content.ContentStr)) - } - messages = append(messages, listValue(kvs...)) - case schemas.ChatMessageRoleSystem: - kvs := []*KeyValue{kvStr("role", "system")} - if message.Content.ContentStr != nil { - kvs = append(kvs, kvStr("content", *message.Content.ContentStr)) - } - messages = append(messages, listValue(kvs...)) - case schemas.ChatMessageRoleTool: - kvs := []*KeyValue{kvStr("role", "tool")} - if message.Content.ContentStr != nil { - kvs = append(kvs, kvStr("content", *message.Content.ContentStr)) - } - messages = append(messages, listValue(kvs...)) - case schemas.ChatMessageRoleDeveloper: - kvs := []*KeyValue{kvStr("role", "developer")} - if message.Content.ContentStr != nil { - kvs = append(kvs, kvStr("content", *message.Content.ContentStr)) - } - messages = append(messages, listValue(kvs...)) - } - } - params = append(params, kvAny("gen_ai.input.messages", arrValue(messages...))) - } - return params -} -// getTranscriptionRequestParams handles the transcription request -func getTranscriptionRequestParams(req *schemas.BifrostTranscriptionRequest) []*KeyValue { - params := []*KeyValue{} - if req.Params != nil { - if req.Params.Language != nil { - params = append(params, kvStr("gen_ai.request.language", *req.Params.Language)) - } - if req.Params.Prompt != nil { - params = append(params, kvStr("gen_ai.request.prompt", *req.Params.Prompt)) - } - if req.Params.ResponseFormat != nil { - params = append(params, kvStr("gen_ai.request.response_format", *req.Params.ResponseFormat)) - } - if req.Params.Format != nil { - params = append(params, kvStr("gen_ai.request.format", *req.Params.Format)) - } - } - return params -} - -// getResponsesRequestParams handles the responses request -func getResponsesRequestParams(req *schemas.BifrostResponsesRequest) []*KeyValue { - params := []*KeyValue{} - if req.Params != nil { - if req.Params.ParallelToolCalls != nil { - params = append(params, kvBool("gen_ai.request.parallel_tool_calls", *req.Params.ParallelToolCalls)) - } - if req.Params.PromptCacheKey != nil { - params = append(params, kvStr("gen_ai.request.prompt_cache_key", *req.Params.PromptCacheKey)) - } - if req.Params.Reasoning != nil { - if req.Params.Reasoning.Effort != nil { - params = append(params, kvStr("gen_ai.request.reasoning_effort", *req.Params.Reasoning.Effort)) - } - if req.Params.Reasoning.Summary != nil { - params = append(params, kvStr("gen_ai.request.reasoning_summary", *req.Params.Reasoning.Summary)) - } - if req.Params.Reasoning.GenerateSummary != nil { - params = append(params, kvStr("gen_ai.request.reasoning_generate_summary", *req.Params.Reasoning.GenerateSummary)) - } - } - if req.Params.SafetyIdentifier != nil { - params = append(params, kvStr("gen_ai.request.safety_identifier", *req.Params.SafetyIdentifier)) - } - if req.Params.ServiceTier != nil { - params = append(params, kvStr("gen_ai.request.service_tier", *req.Params.ServiceTier)) - } - if req.Params.Store != nil { - params = append(params, kvBool("gen_ai.request.store", *req.Params.Store)) - } - if req.Params.Temperature != nil { - params = append(params, kvDbl("gen_ai.request.temperature", *req.Params.Temperature)) - } - if req.Params.Text != nil { - if req.Params.Text.Verbosity != nil { - params = append(params, kvStr("gen_ai.request.text", *req.Params.Text.Verbosity)) - } - if req.Params.Text.Format != nil { - params = append(params, kvStr("gen_ai.request.text_format_type", req.Params.Text.Format.Type)) - } - - } - if req.Params.TopLogProbs != nil { - params = append(params, kvInt("gen_ai.request.top_logprobs", int64(*req.Params.TopLogProbs))) - } - if req.Params.TopP != nil { - params = append(params, kvDbl("gen_ai.request.top_p", *req.Params.TopP)) - } - if req.Params.ToolChoice != nil { - if req.Params.ToolChoice.ResponsesToolChoiceStr != nil && *req.Params.ToolChoice.ResponsesToolChoiceStr != "" { - params = append(params, kvStr("gen_ai.request.tool_choice_type", *req.Params.ToolChoice.ResponsesToolChoiceStr)) - } - if req.Params.ToolChoice.ResponsesToolChoiceStruct != nil && req.Params.ToolChoice.ResponsesToolChoiceStruct.Name != nil { - params = append(params, kvStr("gen_ai.request.tool_choice_name", *req.Params.ToolChoice.ResponsesToolChoiceStruct.Name)) - } - - } - if req.Params.Tools != nil { - tools := make([]string, len(req.Params.Tools)) - for i, tool := range req.Params.Tools { - tools[i] = string(tool.Type) - } - params = append(params, kvStr("gen_ai.request.tools", strings.Join(tools, ","))) - } - if req.Params.Truncation != nil { - params = append(params, kvStr("gen_ai.request.truncation", *req.Params.Truncation)) - } - if req.Params.ExtraParams != nil { - for k, v := range req.Params.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - } - return params -} - -// getFileUploadRequestParams handles the file upload request -func getFileUploadRequestParams(req *schemas.BifrostFileUploadRequest) []*KeyValue { - params := []*KeyValue{} - if req.Filename != "" { - params = append(params, kvStr("gen_ai.file.filename", req.Filename)) - } - if req.Purpose != "" { - params = append(params, kvStr("gen_ai.file.purpose", string(req.Purpose))) - } - if len(req.File) > 0 { - params = append(params, kvInt("gen_ai.file.bytes", int64(len(req.File)))) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - return params -} - -// getFileListRequestParams handles the file list request -func getFileListRequestParams(req *schemas.BifrostFileListRequest) []*KeyValue { - params := []*KeyValue{} - if req.Purpose != "" { - params = append(params, kvStr("gen_ai.file.purpose", string(req.Purpose))) - } - if req.Limit > 0 { - params = append(params, kvInt("gen_ai.file.limit", int64(req.Limit))) - } - if req.After != nil { - params = append(params, kvStr("gen_ai.file.after", *req.After)) - } - if req.Order != nil { - params = append(params, kvStr("gen_ai.file.order", *req.Order)) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } + return &ResourceSpan{ + Resource: &resourcepb.Resource{ + Attributes: p.getResourceAttributes(), + }, + ScopeSpans: []*ScopeSpan{{ + Scope: p.getInstrumentationScope(), + Spans: otelSpans, + }}, } - return params } -// getFileRetrieveRequestParams handles the file retrieve request -func getFileRetrieveRequestParams(req *schemas.BifrostFileRetrieveRequest) []*KeyValue { - params := []*KeyValue{} - if req.FileID != "" { - params = append(params, kvStr("gen_ai.file.file_id", req.FileID)) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } +// convertSpanToOTELSpan converts a single Bifrost span to OTEL format +func (p *OtelPlugin) convertSpanToOTELSpan(traceID string, span *schemas.Span) *Span { + otelSpan := &Span{ + TraceId: hexToBytes(traceID, 16), + SpanId: hexToBytes(span.SpanID, 8), + Name: span.Name, + Kind: convertSpanKind(span.Kind), + StartTimeUnixNano: uint64(span.StartTime.UnixNano()), + EndTimeUnixNano: uint64(span.EndTime.UnixNano()), + Attributes: convertAttributesToKeyValues(span.Attributes), + Status: convertSpanStatus(span.Status, span.StatusMsg), + Events: convertSpanEvents(span.Events), } - return params -} -// getFileDeleteRequestParams handles the file delete request -func getFileDeleteRequestParams(req *schemas.BifrostFileDeleteRequest) []*KeyValue { - params := []*KeyValue{} - if req.FileID != "" { - params = append(params, kvStr("gen_ai.file.file_id", req.FileID)) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } + // Set parent span ID if present + if span.ParentID != "" { + otelSpan.ParentSpanId = hexToBytes(span.ParentID, 8) } - return params -} -// getFileContentRequestParams handles the file content request -func getFileContentRequestParams(req *schemas.BifrostFileContentRequest) []*KeyValue { - params := []*KeyValue{} - if req.FileID != "" { - params = append(params, kvStr("gen_ai.file.file_id", req.FileID)) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - return params + return otelSpan } -// getBatchCreateRequestParams handles the batch create request -func getBatchCreateRequestParams(req *schemas.BifrostBatchCreateRequest) []*KeyValue { - params := []*KeyValue{} - if req.InputFileID != "" { - params = append(params, kvStr("gen_ai.batch.input_file_id", req.InputFileID)) - } - if req.Endpoint != "" { - params = append(params, kvStr("gen_ai.batch.endpoint", string(req.Endpoint))) - } - if req.CompletionWindow != "" { - params = append(params, kvStr("gen_ai.batch.completion_window", req.CompletionWindow)) - } - if len(req.Requests) > 0 { - params = append(params, kvInt("gen_ai.batch.requests_count", int64(len(req.Requests)))) - } - if len(req.Metadata) > 0 { - params = append(params, kvStr("gen_ai.batch.metadata", fmt.Sprintf("%v", req.Metadata))) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - return params +// getResourceAttributes returns the resource attributes for the OTEL span +func (p *OtelPlugin) getResourceAttributes() []*KeyValue { + attrs := []*KeyValue{ + kvStr("service.name", p.serviceName), + kvStr("service.version", p.bifrostVersion), + kvStr("telemetry.sdk.name", "bifrost"), + kvStr("telemetry.sdk.language", "go"), + } + // Add environment attributes + attrs = append(attrs, p.attributesFromEnvironment...) + return attrs } -// getBatchListRequestParams handles the batch list request -func getBatchListRequestParams(req *schemas.BifrostBatchListRequest) []*KeyValue { - params := []*KeyValue{} - if req.Limit > 0 { - params = append(params, kvInt("gen_ai.batch.limit", int64(req.Limit))) - } - if req.After != nil { - params = append(params, kvStr("gen_ai.batch.after", *req.After)) - } - if req.BeforeID != nil { - params = append(params, kvStr("gen_ai.batch.before_id", *req.BeforeID)) +// getInstrumentationScope returns the instrumentation scope for OTEL +func (p *OtelPlugin) getInstrumentationScope() *commonpb.InstrumentationScope { + return &commonpb.InstrumentationScope{ + Name: p.serviceName, + Version: p.bifrostVersion, } - if req.AfterID != nil { - params = append(params, kvStr("gen_ai.batch.after_id", *req.AfterID)) - } - if req.PageToken != nil { - params = append(params, kvStr("gen_ai.batch.page_token", *req.PageToken)) - } - if req.PageSize > 0 { - params = append(params, kvInt("gen_ai.batch.page_size", int64(req.PageSize))) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - return params } -// getBatchRetrieveRequestParams handles the batch retrieve request -func getBatchRetrieveRequestParams(req *schemas.BifrostBatchRetrieveRequest) []*KeyValue { - params := []*KeyValue{} - if req.BatchID != "" { - params = append(params, kvStr("gen_ai.batch.batch_id", req.BatchID)) +// convertAttributesToKeyValues converts map[string]any to OTEL KeyValue slice +func convertAttributesToKeyValues(attrs map[string]any) []*KeyValue { + if attrs == nil { + return nil } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) + kvs := make([]*KeyValue, 0, len(attrs)) + for k, v := range attrs { + kv := anyToKeyValue(k, v) + if kv != nil { + kvs = append(kvs, kv) } } - return params + return kvs } -// getBatchCancelRequestParams handles the batch cancel request -func getBatchCancelRequestParams(req *schemas.BifrostBatchCancelRequest) []*KeyValue { - params := []*KeyValue{} - if req.BatchID != "" { - params = append(params, kvStr("gen_ai.batch.batch_id", req.BatchID)) - } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } +// anyToKeyValue converts any Go value to OTEL KeyValue +func anyToKeyValue(key string, value any) *KeyValue { + if value == nil { + return nil + } + switch v := value.(type) { + case string: + if v == "" { + return nil + } + return kvStr(key, v) + case int: + return kvInt(key, int64(v)) + case int32: + return kvInt(key, int64(v)) + case int64: + return kvInt(key, v) + case uint: + return kvInt(key, int64(v)) + case uint32: + return kvInt(key, int64(v)) + case uint64: + return kvInt(key, int64(v)) + case float32: + return kvDbl(key, float64(v)) + case float64: + return kvDbl(key, v) + case bool: + return kvBool(key, v) + case []string: + if len(v) == 0 { + return nil + } + vals := make([]*AnyValue, len(v)) + for i, s := range v { + vals[i] = &AnyValue{Value: &StringValue{StringValue: s}} + } + return kvAny(key, arrValue(vals...)) + case []int: + if len(v) == 0 { + return nil + } + vals := make([]*AnyValue, len(v)) + for i, n := range v { + vals[i] = &AnyValue{Value: &IntValue{IntValue: int64(n)}} + } + return kvAny(key, arrValue(vals...)) + case []int64: + if len(v) == 0 { + return nil + } + vals := make([]*AnyValue, len(v)) + for i, n := range v { + vals[i] = &AnyValue{Value: &IntValue{IntValue: n}} + } + return kvAny(key, arrValue(vals...)) + case []float64: + if len(v) == 0 { + return nil + } + vals := make([]*AnyValue, len(v)) + for i, n := range v { + vals[i] = &AnyValue{Value: &DoubleValue{DoubleValue: n}} + } + return kvAny(key, arrValue(vals...)) + case map[string]any: + if len(v) == 0 { + return nil + } + kvList := make([]*KeyValue, 0, len(v)) + for k, val := range v { + kv := anyToKeyValue(k, val) + if kv != nil { + kvList = append(kvList, kv) + } + } + return kvAny(key, listValue(kvList...)) + default: + // For any other type, convert to string + return kvStr(key, fmt.Sprintf("%v", v)) } - return params } -// getBatchResultsRequestParams handles the batch results request -func getBatchResultsRequestParams(req *schemas.BifrostBatchResultsRequest) []*KeyValue { - params := []*KeyValue{} - if req.BatchID != "" { - params = append(params, kvStr("gen_ai.batch.batch_id", req.BatchID)) +// convertSpanKind maps Bifrost SpanKind to OTEL SpanKind +func convertSpanKind(kind schemas.SpanKind) tracepb.Span_SpanKind { + switch kind { + case schemas.SpanKindLLMCall: + return tracepb.Span_SPAN_KIND_CLIENT + case schemas.SpanKindHTTPRequest: + return tracepb.Span_SPAN_KIND_SERVER + case schemas.SpanKindPlugin: + return tracepb.Span_SPAN_KIND_INTERNAL + case schemas.SpanKindInternal: + return tracepb.Span_SPAN_KIND_INTERNAL + case schemas.SpanKindRetry: + return tracepb.Span_SPAN_KIND_INTERNAL + case schemas.SpanKindFallback: + return tracepb.Span_SPAN_KIND_INTERNAL + case schemas.SpanKindMCPTool: + return tracepb.Span_SPAN_KIND_CLIENT + case schemas.SpanKindEmbedding: + return tracepb.Span_SPAN_KIND_CLIENT + case schemas.SpanKindSpeech: + return tracepb.Span_SPAN_KIND_CLIENT + case schemas.SpanKindTranscription: + return tracepb.Span_SPAN_KIND_CLIENT + default: + return tracepb.Span_SPAN_KIND_UNSPECIFIED } - if req.ExtraParams != nil { - for k, v := range req.ExtraParams { - params = append(params, kvStr(k, fmt.Sprintf("%v", v))) - } - } - return params } -// createResourceSpan creates a new resource span for a Bifrost request -func (p *OtelPlugin) createResourceSpan(traceID, spanID string, timestamp time.Time, req *schemas.BifrostRequest) *ResourceSpan { - provider, model, _ := req.GetRequestFields() - - // preparing parameters - params := []*KeyValue{} - spanName := "span" - params = append(params, kvStr("gen_ai.provider.name", string(provider))) - params = append(params, kvStr("gen_ai.request.model", model)) - // Preparing parameters - switch req.RequestType { - case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest: - spanName = "gen_ai.text" - params = append(params, getTextCompletionRequestParams(req.TextCompletionRequest)...) - case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: - spanName = "gen_ai.chat" - params = append(params, getChatRequestParams(req.ChatRequest)...) - case schemas.EmbeddingRequest: - spanName = "gen_ai.embedding" - params = append(params, getEmbeddingRequestParams(req.EmbeddingRequest)...) - case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: - spanName = "gen_ai.transcription" - params = append(params, getTranscriptionRequestParams(req.TranscriptionRequest)...) - case schemas.SpeechRequest, schemas.SpeechStreamRequest: - spanName = "gen_ai.speech" - params = append(params, getSpeechRequestParams(req.SpeechRequest)...) - case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: - spanName = "gen_ai.responses" - params = append(params, getResponsesRequestParams(req.ResponsesRequest)...) - case schemas.BatchCreateRequest: - spanName = "gen_ai.batch.create" - params = append(params, getBatchCreateRequestParams(req.BatchCreateRequest)...) - case schemas.BatchListRequest: - spanName = "gen_ai.batch.list" - params = append(params, getBatchListRequestParams(req.BatchListRequest)...) - case schemas.BatchRetrieveRequest: - spanName = "gen_ai.batch.retrieve" - params = append(params, getBatchRetrieveRequestParams(req.BatchRetrieveRequest)...) - case schemas.BatchCancelRequest: - spanName = "gen_ai.batch.cancel" - params = append(params, getBatchCancelRequestParams(req.BatchCancelRequest)...) - case schemas.BatchResultsRequest: - spanName = "gen_ai.batch.results" - params = append(params, getBatchResultsRequestParams(req.BatchResultsRequest)...) - case schemas.FileUploadRequest: - spanName = "gen_ai.file.upload" - params = append(params, getFileUploadRequestParams(req.FileUploadRequest)...) - case schemas.FileListRequest: - spanName = "gen_ai.file.list" - params = append(params, getFileListRequestParams(req.FileListRequest)...) - case schemas.FileRetrieveRequest: - spanName = "gen_ai.file.retrieve" - params = append(params, getFileRetrieveRequestParams(req.FileRetrieveRequest)...) - case schemas.FileDeleteRequest: - spanName = "gen_ai.file.delete" - params = append(params, getFileDeleteRequestParams(req.FileDeleteRequest)...) - case schemas.FileContentRequest: - spanName = "gen_ai.file.content" - params = append(params, getFileContentRequestParams(req.FileContentRequest)...) - } - attributes := append(p.attributesFromEnvironment, kvStr("service.name", p.serviceName), kvStr("service.version", p.bifrostVersion)) - // Preparing final resource span - return &ResourceSpan{ - Resource: &resourcepb.Resource{ - Attributes: attributes, - }, - ScopeSpans: []*ScopeSpan{ - { - Scope: &commonpb.InstrumentationScope{ - Name: "bifrost-otel-plugin", - }, - Spans: []*Span{ - { - TraceId: hexToBytes(traceID, 16), - SpanId: hexToBytes(spanID, 8), - Kind: tracepb.Span_SPAN_KIND_SERVER, - StartTimeUnixNano: uint64(timestamp.UnixNano()), - EndTimeUnixNano: uint64(timestamp.UnixNano()), - Name: spanName, - Attributes: params, - }, - }, - }, - }, +// convertSpanStatus maps Bifrost SpanStatus to OTEL Status +func convertSpanStatus(status schemas.SpanStatus, msg string) *tracepb.Status { + switch status { + case schemas.SpanStatusOk: + return &tracepb.Status{Code: tracepb.Status_STATUS_CODE_OK} + case schemas.SpanStatusError: + return &tracepb.Status{Code: tracepb.Status_STATUS_CODE_ERROR, Message: msg} + default: + return &tracepb.Status{Code: tracepb.Status_STATUS_CODE_UNSET} } } -// completeResourceSpan completes a resource span for a Bifrost response -func completeResourceSpan( - span *ResourceSpan, - timestamp time.Time, - resp *schemas.BifrostResponse, - bifrostErr *schemas.BifrostError, - pricingManager *modelcatalog.ModelCatalog, - virtualKeyID string, - virtualKeyName string, - selectedKeyID string, - selectedKeyName string, - numberOfRetries int, - fallbackIndex int, - teamID string, - teamName string, - customerID string, - customerName string, -) *ResourceSpan { - params := []*KeyValue{} - - if resp != nil { - switch { // Accumulator wont return stream type responses - case resp.TextCompletionResponse != nil: - params = append(params, kvStr("gen_ai.text.id", resp.TextCompletionResponse.ID)) - params = append(params, kvStr("gen_ai.text.model", resp.TextCompletionResponse.Model)) - params = append(params, kvStr("gen_ai.text.object", resp.TextCompletionResponse.Object)) - params = append(params, kvStr("gen_ai.text.system_fingerprint", resp.TextCompletionResponse.SystemFingerprint)) - outputMessages := []*AnyValue{} - for _, choice := range resp.TextCompletionResponse.Choices { - if choice.TextCompletionResponseChoice == nil { - continue - } - kvs := []*KeyValue{kvStr("role", string(schemas.ChatMessageRoleAssistant))} - if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil { - kvs = append(kvs, kvStr("content", *choice.TextCompletionResponseChoice.Text)) - } - outputMessages = append(outputMessages, listValue(kvs...)) - } - params = append(params, kvAny("gen_ai.text.output_messages", arrValue(outputMessages...))) - if resp.TextCompletionResponse.Usage != nil { - params = append(params, kvInt("gen_ai.usage.prompt_tokens", int64(resp.TextCompletionResponse.Usage.PromptTokens))) - params = append(params, kvInt("gen_ai.usage.completion_tokens", int64(resp.TextCompletionResponse.Usage.CompletionTokens))) - params = append(params, kvInt("gen_ai.usage.total_tokens", int64(resp.TextCompletionResponse.Usage.TotalTokens))) - } - // Computing cost - if pricingManager != nil { - cost := pricingManager.CalculateCostWithCacheDebug(resp) - params = append(params, kvDbl("gen_ai.usage.cost", cost)) - } - case resp.ChatResponse != nil: - params = append(params, kvStr("gen_ai.chat.id", resp.ChatResponse.ID)) - params = append(params, kvStr("gen_ai.chat.model", resp.ChatResponse.Model)) - params = append(params, kvStr("gen_ai.chat.object", resp.ChatResponse.Object)) - params = append(params, kvStr("gen_ai.chat.system_fingerprint", resp.ChatResponse.SystemFingerprint)) - params = append(params, kvStr("gen_ai.chat.created", fmt.Sprintf("%d", resp.ChatResponse.Created))) - if resp.ChatResponse.ServiceTier != nil { - params = append(params, kvStr("gen_ai.chat.service_tier", *resp.ChatResponse.ServiceTier)) - } - outputMessages := []*AnyValue{} - for _, choice := range resp.ChatResponse.Choices { - var role string - if choice.ChatNonStreamResponseChoice != nil && choice.ChatNonStreamResponseChoice.Message != nil && choice.ChatNonStreamResponseChoice.Message.Role != "" { - role = string(choice.ChatNonStreamResponseChoice.Message.Role) - } else { - role = string(schemas.ChatMessageRoleAssistant) - } - kvs := []*KeyValue{kvStr("role", role)} - - if choice.ChatNonStreamResponseChoice != nil && - choice.ChatNonStreamResponseChoice.Message != nil && - choice.ChatNonStreamResponseChoice.Message.Content != nil { - if choice.ChatNonStreamResponseChoice.Message.Content.ContentStr != nil { - kvs = append(kvs, kvStr("content", *choice.ChatNonStreamResponseChoice.Message.Content.ContentStr)) - } else if choice.ChatNonStreamResponseChoice.Message.Content.ContentBlocks != nil { - blockText := "" - for _, block := range choice.ChatNonStreamResponseChoice.Message.Content.ContentBlocks { - if block.Text != nil { - blockText += *block.Text - } - } - kvs = append(kvs, kvStr("content", blockText)) - } - } - outputMessages = append(outputMessages, listValue(kvs...)) - } - params = append(params, kvAny("gen_ai.chat.output_messages", arrValue(outputMessages...))) - if resp.ChatResponse.Usage != nil { - params = append(params, kvInt("gen_ai.usage.prompt_tokens", int64(resp.ChatResponse.Usage.PromptTokens))) - params = append(params, kvInt("gen_ai.usage.completion_tokens", int64(resp.ChatResponse.Usage.CompletionTokens))) - params = append(params, kvInt("gen_ai.usage.total_tokens", int64(resp.ChatResponse.Usage.TotalTokens))) - } - // Computing cost - if pricingManager != nil { - cost := pricingManager.CalculateCostWithCacheDebug(resp) - params = append(params, kvDbl("gen_ai.usage.cost", cost)) - } - case resp.ResponsesResponse != nil: - outputMessages := []*AnyValue{} - for _, message := range resp.ResponsesResponse.Output { - if message.Role == nil { - continue - } - kvs := []*KeyValue{kvStr("role", string(*message.Role))} - if message.Content != nil { - if message.Content.ContentStr != nil && *message.Content.ContentStr != "" { - kvs = append(kvs, kvStr("content", *message.Content.ContentStr)) - } else if message.Content.ContentBlocks != nil { - blockText := "" - for _, block := range message.Content.ContentBlocks { - if block.Text != nil { - blockText += *block.Text - } - } - kvs = append(kvs, kvStr("content", blockText)) - } - } - if message.ResponsesReasoning != nil && message.ResponsesReasoning.Summary != nil { - reasoningText := "" - for _, block := range message.ResponsesReasoning.Summary { - if block.Text != "" { - reasoningText += block.Text - } - } - kvs = append(kvs, kvStr("reasoning", reasoningText)) - } - outputMessages = append(outputMessages, listValue(kvs...)) - - } - params = append(params, kvAny("gen_ai.responses.output_messages", arrValue(outputMessages...))) - - responsesResponse := resp.ResponsesResponse - if responsesResponse.Include != nil { - params = append(params, kvStr("gen_ai.responses.include", strings.Join(responsesResponse.Include, ","))) - } - if responsesResponse.MaxOutputTokens != nil { - params = append(params, kvInt("gen_ai.responses.max_output_tokens", int64(*responsesResponse.MaxOutputTokens))) - } - if responsesResponse.MaxToolCalls != nil { - params = append(params, kvInt("gen_ai.responses.max_tool_calls", int64(*responsesResponse.MaxToolCalls))) - } - if responsesResponse.Metadata != nil { - params = append(params, kvStr("gen_ai.responses.metadata", fmt.Sprintf("%v", responsesResponse.Metadata))) - } - if responsesResponse.PreviousResponseID != nil { - params = append(params, kvStr("gen_ai.responses.previous_response_id", *responsesResponse.PreviousResponseID)) - } - if responsesResponse.PromptCacheKey != nil { - params = append(params, kvStr("gen_ai.responses.prompt_cache_key", *responsesResponse.PromptCacheKey)) - } - if responsesResponse.Reasoning != nil { - if responsesResponse.Reasoning.Summary != nil { - params = append(params, kvStr("gen_ai.responses.reasoning", *responsesResponse.Reasoning.Summary)) - } - if responsesResponse.Reasoning.Effort != nil { - params = append(params, kvStr("gen_ai.responses.reasoning_effort", *responsesResponse.Reasoning.Effort)) - } - if responsesResponse.Reasoning.GenerateSummary != nil { - params = append(params, kvStr("gen_ai.responses.reasoning_generate_summary", *responsesResponse.Reasoning.GenerateSummary)) - } - } - if responsesResponse.SafetyIdentifier != nil { - params = append(params, kvStr("gen_ai.responses.safety_identifier", *responsesResponse.SafetyIdentifier)) - } - if responsesResponse.ServiceTier != nil { - params = append(params, kvStr("gen_ai.responses.service_tier", *responsesResponse.ServiceTier)) - } - if responsesResponse.Store != nil { - params = append(params, kvBool("gen_ai.responses.store", *responsesResponse.Store)) - } - if responsesResponse.Temperature != nil { - params = append(params, kvDbl("gen_ai.responses.temperature", *responsesResponse.Temperature)) - } - if responsesResponse.Text != nil { - if responsesResponse.Text.Verbosity != nil { - params = append(params, kvStr("gen_ai.responses.text", *responsesResponse.Text.Verbosity)) - } - if responsesResponse.Text.Format != nil { - params = append(params, kvStr("gen_ai.responses.text_format_type", responsesResponse.Text.Format.Type)) - } - } - if responsesResponse.TopLogProbs != nil { - params = append(params, kvInt("gen_ai.responses.top_logprobs", int64(*responsesResponse.TopLogProbs))) - } - if responsesResponse.TopP != nil { - params = append(params, kvDbl("gen_ai.responses.top_p", *responsesResponse.TopP)) - } - if responsesResponse.ToolChoice != nil { - if responsesResponse.ToolChoice.ResponsesToolChoiceStruct != nil && responsesResponse.ToolChoice.ResponsesToolChoiceStr != nil { - params = append(params, kvStr("gen_ai.responses.tool_choice_type", *responsesResponse.ToolChoice.ResponsesToolChoiceStr)) - } - if responsesResponse.ToolChoice.ResponsesToolChoiceStruct != nil && responsesResponse.ToolChoice.ResponsesToolChoiceStruct.Name != nil { - params = append(params, kvStr("gen_ai.responses.tool_choice_name", *responsesResponse.ToolChoice.ResponsesToolChoiceStruct.Name)) - } - } - if responsesResponse.Truncation != nil { - params = append(params, kvStr("gen_ai.responses.truncation", *responsesResponse.Truncation)) - } - if responsesResponse.Tools != nil { - tools := make([]string, len(responsesResponse.Tools)) - for i, tool := range responsesResponse.Tools { - tools[i] = string(tool.Type) - } - params = append(params, kvStr("gen_ai.responses.tools", strings.Join(tools, ","))) - } - case resp.EmbeddingResponse != nil: - if resp.EmbeddingResponse.Usage != nil { - params = append(params, kvInt("gen_ai.usage.prompt_tokens", int64(resp.EmbeddingResponse.Usage.PromptTokens))) - params = append(params, kvInt("gen_ai.usage.completion_tokens", int64(resp.EmbeddingResponse.Usage.CompletionTokens))) - params = append(params, kvInt("gen_ai.usage.total_tokens", int64(resp.EmbeddingResponse.Usage.TotalTokens))) - } - case resp.SpeechResponse != nil: - if resp.SpeechResponse.Usage != nil { - params = append(params, kvInt("gen_ai.usage.input_tokens", int64(resp.SpeechResponse.Usage.InputTokens))) - params = append(params, kvInt("gen_ai.usage.output_tokens", int64(resp.SpeechResponse.Usage.OutputTokens))) - params = append(params, kvInt("gen_ai.usage.total_tokens", int64(resp.SpeechResponse.Usage.TotalTokens))) - } - case resp.TranscriptionResponse != nil: - outputMessages := []*AnyValue{} - kvs := []*KeyValue{kvStr("text", resp.TranscriptionResponse.Text)} - outputMessages = append(outputMessages, listValue(kvs...)) - params = append(params, kvAny("gen_ai.transcribe.output_messages", arrValue(outputMessages...))) - if resp.TranscriptionResponse.Usage != nil { - if resp.TranscriptionResponse.Usage.InputTokens != nil { - params = append(params, kvInt("gen_ai.usage.input_tokens", int64(*resp.TranscriptionResponse.Usage.InputTokens))) - } - if resp.TranscriptionResponse.Usage.OutputTokens != nil { - params = append(params, kvInt("gen_ai.usage.completion_tokens", int64(*resp.TranscriptionResponse.Usage.OutputTokens))) - } - if resp.TranscriptionResponse.Usage.TotalTokens != nil { - params = append(params, kvInt("gen_ai.usage.total_tokens", int64(*resp.TranscriptionResponse.Usage.TotalTokens))) - } - if resp.TranscriptionResponse.Usage.InputTokenDetails != nil { - params = append(params, kvInt("gen_ai.usage.input_token_details.text_tokens", int64(resp.TranscriptionResponse.Usage.InputTokenDetails.TextTokens))) - params = append(params, kvInt("gen_ai.usage.input_token_details.audio_tokens", int64(resp.TranscriptionResponse.Usage.InputTokenDetails.AudioTokens))) - } - } - case resp.BatchCreateResponse != nil: - params = append(params, kvStr("gen_ai.batch.id", resp.BatchCreateResponse.ID)) - params = append(params, kvStr("gen_ai.batch.status", string(resp.BatchCreateResponse.Status))) - if resp.BatchCreateResponse.Object != "" { - params = append(params, kvStr("gen_ai.batch.object", resp.BatchCreateResponse.Object)) - } - if resp.BatchCreateResponse.Endpoint != "" { - params = append(params, kvStr("gen_ai.batch.endpoint", resp.BatchCreateResponse.Endpoint)) - } - if resp.BatchCreateResponse.InputFileID != "" { - params = append(params, kvStr("gen_ai.batch.input_file_id", resp.BatchCreateResponse.InputFileID)) - } - if resp.BatchCreateResponse.CompletionWindow != "" { - params = append(params, kvStr("gen_ai.batch.completion_window", resp.BatchCreateResponse.CompletionWindow)) - } - if resp.BatchCreateResponse.CreatedAt != 0 { - params = append(params, kvInt("gen_ai.batch.created_at", resp.BatchCreateResponse.CreatedAt)) - } - if resp.BatchCreateResponse.ExpiresAt != nil { - params = append(params, kvInt("gen_ai.batch.expires_at", *resp.BatchCreateResponse.ExpiresAt)) - } - if resp.BatchCreateResponse.OutputFileID != nil { - params = append(params, kvStr("gen_ai.batch.output_file_id", *resp.BatchCreateResponse.OutputFileID)) - } - if resp.BatchCreateResponse.ErrorFileID != nil { - params = append(params, kvStr("gen_ai.batch.error_file_id", *resp.BatchCreateResponse.ErrorFileID)) - } - params = append(params, kvInt("gen_ai.batch.request_counts.total", int64(resp.BatchCreateResponse.RequestCounts.Total))) - params = append(params, kvInt("gen_ai.batch.request_counts.completed", int64(resp.BatchCreateResponse.RequestCounts.Completed))) - params = append(params, kvInt("gen_ai.batch.request_counts.failed", int64(resp.BatchCreateResponse.RequestCounts.Failed))) - case resp.BatchListResponse != nil: - if resp.BatchListResponse.Object != "" { - params = append(params, kvStr("gen_ai.batch.object", resp.BatchListResponse.Object)) - } - params = append(params, kvInt("gen_ai.batch.data_count", int64(len(resp.BatchListResponse.Data)))) - params = append(params, kvBool("gen_ai.batch.has_more", resp.BatchListResponse.HasMore)) - if resp.BatchListResponse.FirstID != nil { - params = append(params, kvStr("gen_ai.batch.first_id", *resp.BatchListResponse.FirstID)) - } - if resp.BatchListResponse.LastID != nil { - params = append(params, kvStr("gen_ai.batch.last_id", *resp.BatchListResponse.LastID)) - } - case resp.BatchRetrieveResponse != nil: - params = append(params, kvStr("gen_ai.batch.id", resp.BatchRetrieveResponse.ID)) - params = append(params, kvStr("gen_ai.batch.status", string(resp.BatchRetrieveResponse.Status))) - if resp.BatchRetrieveResponse.Object != "" { - params = append(params, kvStr("gen_ai.batch.object", resp.BatchRetrieveResponse.Object)) - } - if resp.BatchRetrieveResponse.Endpoint != "" { - params = append(params, kvStr("gen_ai.batch.endpoint", resp.BatchRetrieveResponse.Endpoint)) - } - if resp.BatchRetrieveResponse.InputFileID != "" { - params = append(params, kvStr("gen_ai.batch.input_file_id", resp.BatchRetrieveResponse.InputFileID)) - } - if resp.BatchRetrieveResponse.CompletionWindow != "" { - params = append(params, kvStr("gen_ai.batch.completion_window", resp.BatchRetrieveResponse.CompletionWindow)) - } - if resp.BatchRetrieveResponse.CreatedAt != 0 { - params = append(params, kvInt("gen_ai.batch.created_at", resp.BatchRetrieveResponse.CreatedAt)) - } - if resp.BatchRetrieveResponse.ExpiresAt != nil { - params = append(params, kvInt("gen_ai.batch.expires_at", *resp.BatchRetrieveResponse.ExpiresAt)) - } - if resp.BatchRetrieveResponse.InProgressAt != nil { - params = append(params, kvInt("gen_ai.batch.in_progress_at", *resp.BatchRetrieveResponse.InProgressAt)) - } - if resp.BatchRetrieveResponse.FinalizingAt != nil { - params = append(params, kvInt("gen_ai.batch.finalizing_at", *resp.BatchRetrieveResponse.FinalizingAt)) - } - if resp.BatchRetrieveResponse.CompletedAt != nil { - params = append(params, kvInt("gen_ai.batch.completed_at", *resp.BatchRetrieveResponse.CompletedAt)) - } - if resp.BatchRetrieveResponse.FailedAt != nil { - params = append(params, kvInt("gen_ai.batch.failed_at", *resp.BatchRetrieveResponse.FailedAt)) - } - if resp.BatchRetrieveResponse.ExpiredAt != nil { - params = append(params, kvInt("gen_ai.batch.expired_at", *resp.BatchRetrieveResponse.ExpiredAt)) - } - if resp.BatchRetrieveResponse.CancellingAt != nil { - params = append(params, kvInt("gen_ai.batch.cancelling_at", *resp.BatchRetrieveResponse.CancellingAt)) - } - if resp.BatchRetrieveResponse.CancelledAt != nil { - params = append(params, kvInt("gen_ai.batch.cancelled_at", *resp.BatchRetrieveResponse.CancelledAt)) - } - if resp.BatchRetrieveResponse.OutputFileID != nil { - params = append(params, kvStr("gen_ai.batch.output_file_id", *resp.BatchRetrieveResponse.OutputFileID)) - } - if resp.BatchRetrieveResponse.ErrorFileID != nil { - params = append(params, kvStr("gen_ai.batch.error_file_id", *resp.BatchRetrieveResponse.ErrorFileID)) - } - params = append(params, kvInt("gen_ai.batch.request_counts.total", int64(resp.BatchRetrieveResponse.RequestCounts.Total))) - params = append(params, kvInt("gen_ai.batch.request_counts.completed", int64(resp.BatchRetrieveResponse.RequestCounts.Completed))) - params = append(params, kvInt("gen_ai.batch.request_counts.failed", int64(resp.BatchRetrieveResponse.RequestCounts.Failed))) - case resp.BatchCancelResponse != nil: - params = append(params, kvStr("gen_ai.batch.id", resp.BatchCancelResponse.ID)) - params = append(params, kvStr("gen_ai.batch.status", string(resp.BatchCancelResponse.Status))) - if resp.BatchCancelResponse.Object != "" { - params = append(params, kvStr("gen_ai.batch.object", resp.BatchCancelResponse.Object)) - } - if resp.BatchCancelResponse.CancellingAt != nil { - params = append(params, kvInt("gen_ai.batch.cancelling_at", *resp.BatchCancelResponse.CancellingAt)) - } - if resp.BatchCancelResponse.CancelledAt != nil { - params = append(params, kvInt("gen_ai.batch.cancelled_at", *resp.BatchCancelResponse.CancelledAt)) - } - params = append(params, kvInt("gen_ai.batch.request_counts.total", int64(resp.BatchCancelResponse.RequestCounts.Total))) - params = append(params, kvInt("gen_ai.batch.request_counts.completed", int64(resp.BatchCancelResponse.RequestCounts.Completed))) - params = append(params, kvInt("gen_ai.batch.request_counts.failed", int64(resp.BatchCancelResponse.RequestCounts.Failed))) - case resp.BatchResultsResponse != nil: - params = append(params, kvStr("gen_ai.batch.batch_id", resp.BatchResultsResponse.BatchID)) - params = append(params, kvInt("gen_ai.batch.results_count", int64(len(resp.BatchResultsResponse.Results)))) - params = append(params, kvBool("gen_ai.batch.has_more", resp.BatchResultsResponse.HasMore)) - if resp.BatchResultsResponse.NextCursor != nil { - params = append(params, kvStr("gen_ai.batch.next_cursor", *resp.BatchResultsResponse.NextCursor)) - } - case resp.FileUploadResponse != nil: - params = append(params, kvStr("gen_ai.file.id", resp.FileUploadResponse.ID)) - if resp.FileUploadResponse.Object != "" { - params = append(params, kvStr("gen_ai.file.object", resp.FileUploadResponse.Object)) - } - params = append(params, kvInt("gen_ai.file.bytes", resp.FileUploadResponse.Bytes)) - params = append(params, kvInt("gen_ai.file.created_at", resp.FileUploadResponse.CreatedAt)) - params = append(params, kvStr("gen_ai.file.filename", resp.FileUploadResponse.Filename)) - params = append(params, kvStr("gen_ai.file.purpose", string(resp.FileUploadResponse.Purpose))) - if resp.FileUploadResponse.Status != "" { - params = append(params, kvStr("gen_ai.file.status", string(resp.FileUploadResponse.Status))) - } - if resp.FileUploadResponse.StorageBackend != "" { - params = append(params, kvStr("gen_ai.file.storage_backend", string(resp.FileUploadResponse.StorageBackend))) - } - case resp.FileListResponse != nil: - if resp.FileListResponse.Object != "" { - params = append(params, kvStr("gen_ai.file.object", resp.FileListResponse.Object)) - } - params = append(params, kvInt("gen_ai.file.data_count", int64(len(resp.FileListResponse.Data)))) - params = append(params, kvBool("gen_ai.file.has_more", resp.FileListResponse.HasMore)) - case resp.FileRetrieveResponse != nil: - params = append(params, kvStr("gen_ai.file.id", resp.FileRetrieveResponse.ID)) - if resp.FileRetrieveResponse.Object != "" { - params = append(params, kvStr("gen_ai.file.object", resp.FileRetrieveResponse.Object)) - } - params = append(params, kvInt("gen_ai.file.bytes", resp.FileRetrieveResponse.Bytes)) - params = append(params, kvInt("gen_ai.file.created_at", resp.FileRetrieveResponse.CreatedAt)) - params = append(params, kvStr("gen_ai.file.filename", resp.FileRetrieveResponse.Filename)) - params = append(params, kvStr("gen_ai.file.purpose", string(resp.FileRetrieveResponse.Purpose))) - if resp.FileRetrieveResponse.Status != "" { - params = append(params, kvStr("gen_ai.file.status", string(resp.FileRetrieveResponse.Status))) - } - if resp.FileRetrieveResponse.StorageBackend != "" { - params = append(params, kvStr("gen_ai.file.storage_backend", string(resp.FileRetrieveResponse.StorageBackend))) - } - case resp.FileDeleteResponse != nil: - params = append(params, kvStr("gen_ai.file.id", resp.FileDeleteResponse.ID)) - if resp.FileDeleteResponse.Object != "" { - params = append(params, kvStr("gen_ai.file.object", resp.FileDeleteResponse.Object)) - } - params = append(params, kvBool("gen_ai.file.deleted", resp.FileDeleteResponse.Deleted)) - case resp.FileContentResponse != nil: - params = append(params, kvStr("gen_ai.file.file_id", resp.FileContentResponse.FileID)) - if resp.FileContentResponse.ContentType != "" { - params = append(params, kvStr("gen_ai.file.content_type", resp.FileContentResponse.ContentType)) - } - if len(resp.FileContentResponse.Content) > 0 { - params = append(params, kvInt("gen_ai.file.content_bytes", int64(len(resp.FileContentResponse.Content)))) - } - } +// convertSpanEvents converts Bifrost span events to OTEL events +func convertSpanEvents(events []schemas.SpanEvent) []*Event { + if len(events) == 0 { + return nil } - - // This is a fallback for worst case scenario where latency is not available - status := tracepb.Status_STATUS_CODE_OK - if bifrostErr != nil { - status = tracepb.Status_STATUS_CODE_ERROR - if bifrostErr.Error != nil { - if bifrostErr.Error.Type != nil { - params = append(params, kvStr("gen_ai.error.type", *bifrostErr.Error.Type)) - } - if bifrostErr.Error.Code != nil { - params = append(params, kvStr("gen_ai.error.code", *bifrostErr.Error.Code)) - } + otelEvents := make([]*Event, len(events)) + for i, event := range events { + otelEvents[i] = &Event{ + TimeUnixNano: uint64(event.Timestamp.UnixNano()), + Name: event.Name, + Attributes: convertAttributesToKeyValues(event.Attributes), } - params = append(params, kvStr("gen_ai.error", bifrostErr.Error.Message)) - } - // Adding request metadata to the span for backward compatibility - if virtualKeyID != "" { - params = append(params, kvStr("gen_ai.virtual_key_id", virtualKeyID)) - params = append(params, kvStr("gen_ai.virtual_key_name", virtualKeyName)) - } - if selectedKeyID != "" { - params = append(params, kvStr("gen_ai.selected_key_id", selectedKeyID)) - params = append(params, kvStr("gen_ai.selected_key_name", selectedKeyName)) - } - if teamID != "" { - params = append(params, kvStr("gen_ai.team_id", teamID)) - params = append(params, kvStr("gen_ai.team_name", teamName)) - } - if customerID != "" { - params = append(params, kvStr("gen_ai.customer_id", customerID)) - params = append(params, kvStr("gen_ai.customer_name", customerName)) } - params = append(params, kvInt("gen_ai.number_of_retries", int64(numberOfRetries))) - params = append(params, kvInt("gen_ai.fallback_index", int64(fallbackIndex))) - span.ScopeSpans[0].Spans[0].Attributes = append(span.ScopeSpans[0].Spans[0].Attributes, params...) - span.ScopeSpans[0].Spans[0].Status = &tracepb.Status{Code: status} - span.ScopeSpans[0].Spans[0].EndTimeUnixNano = uint64(timestamp.UnixNano()) - // Attaching virtual keys as resource attributes as well - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("virtual_key_id", virtualKeyID)) - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("virtual_key_name", virtualKeyName)) - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("selected_key_id", selectedKeyID)) - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("selected_key_name", selectedKeyName)) - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("team_id", teamID)) - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("team_name", teamName)) - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("customer_id", customerID)) - span.Resource.Attributes = append(span.Resource.Attributes, kvStr("customer_name", customerName)) - span.Resource.Attributes = append(span.Resource.Attributes, kvInt("number_of_retries", int64(numberOfRetries))) - span.Resource.Attributes = append(span.Resource.Attributes, kvInt("fallback_index", int64(fallbackIndex))) - return span + return otelEvents } diff --git a/plugins/otel/go.mod b/plugins/otel/go.mod index f5063443e1..583c5fb9e6 100644 --- a/plugins/otel/go.mod +++ b/plugins/otel/go.mod @@ -3,8 +3,8 @@ module github.com/maximhq/bifrost/plugins/otel go 1.25.5 require ( - github.com/maximhq/bifrost/core v1.2.49 - github.com/maximhq/bifrost/framework v1.1.61 + github.com/maximhq/bifrost/core v1.3.8 + github.com/maximhq/bifrost/framework v1.2.8 google.golang.org/grpc v1.77.0 google.golang.org/protobuf v1.36.11 ) @@ -40,7 +40,10 @@ require ( github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/clarkmcc/go-typescript v0.7.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/analysis v0.24.2 // indirect @@ -64,8 +67,10 @@ require ( github.com/go-openapi/swag/typeutils v0.25.4 // indirect github.com/go-openapi/swag/yamlutils v0.25.4 // indirect github.com/go-openapi/validate v0.25.1 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f // indirect github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect github.com/invopop/jsonschema v0.13.0 // indirect diff --git a/plugins/otel/go.sum b/plugins/otel/go.sum index acad845470..e82a26fd8d 100644 --- a/plugins/otel/go.sum +++ b/plugins/otel/go.sum @@ -12,6 +12,8 @@ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= +github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= @@ -68,6 +70,8 @@ github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2N github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -77,6 +81,10 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -132,6 +140,8 @@ github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6 github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= github.com/go-openapi/validate v0.25.1 h1:sSACUI6Jcnbo5IWqbYHgjibrhhmt3vR6lCzKZnmAgBw= github.com/go-openapi/validate v0.25.1/go.mod h1:RMVyVFYte0gbSTaZ0N4KmTn6u/kClvAFp+mAVfS/DQc= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -141,6 +151,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= @@ -186,10 +198,10 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.2.49 h1:fk6l6r3kVBlpN73wYXmgtV6O4bhedOjSO4LAEz/7leg= -github.com/maximhq/bifrost/core v1.2.49/go.mod h1:z7nOx15e91ktZGi+pZHq+uhShlEK+fM4UyYUpP6oHAw= -github.com/maximhq/bifrost/framework v1.1.61 h1:fMjvICbkrdWMtGnLYrjSNrcmQYqtQvOh/swmrJTvf+E= -github.com/maximhq/bifrost/framework v1.1.61/go.mod h1:wVUPzB8K5S/5GWuxqp8dXf3nNZkqJsS/APMIcq48SOI= +github.com/maximhq/bifrost/core v1.3.8 h1:xtwB9+HeTzYz5IKHkpUtupzBd0A5yl1avdLJGjsOKPI= +github.com/maximhq/bifrost/core v1.3.8/go.mod h1:abKQRnJQPZz8/UMxCcbuNHEyq19Db+IX4KlGJdlLY8E= +github.com/maximhq/bifrost/framework v1.2.8 h1:/oTpacuw7k0zRUJ9dSSQRtAVx3nLGSiR7GFwOjGxZNs= +github.com/maximhq/bifrost/framework v1.2.8/go.mod h1:mjw9YXh/Oxi3HeBCJ+3HJ6ftv43Wo4t0T4EzpcIbnr0= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= @@ -289,6 +301,8 @@ google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/otel/main.go b/plugins/otel/main.go index 40d7ec2431..cca58c13ac 100644 --- a/plugins/otel/main.go +++ b/plugins/otel/main.go @@ -6,29 +6,16 @@ import ( "fmt" "os" "strings" - "sync" - "time" "github.com/bytedance/sonic" - bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/modelcatalog" - "github.com/maximhq/bifrost/framework/streaming" commonpb "go.opentelemetry.io/proto/otlp/common/v1" ) // logger is the logger for the OTEL plugin var logger schemas.Logger -// ContextKey is a custom type for context keys to prevent collisions -type ContextKey string - -// Context keys for otel plugin -const ( - TraceIDKey ContextKey = "plugin-otel-trace-id" - SpanIDKey ContextKey = "plugin-otel-span-id" -) - // OTELResponseAttributesEnvKey is the environment variable key for the OTEL resource attributes // We check if this is present in the environment variables and if so, we will use it to set the attributes for all spans at the resource level const OTELResponseAttributesEnvKey = "OTEL_RESOURCE_ATTRIBUTES" @@ -65,7 +52,9 @@ type Config struct { TLSCACert string `json:"tls_ca_cert"` } -// OtelPlugin is the plugin for OpenTelemetry +// OtelPlugin is the plugin for OpenTelemetry. +// It implements the ObservabilityPlugin interface to receive completed traces +// from the tracing middleware and forward them to an OTEL collector. type OtelPlugin struct { ctx context.Context cancel context.CancelFunc @@ -80,14 +69,9 @@ type OtelPlugin struct { attributesFromEnvironment []*commonpb.KeyValue - ongoingSpans *TTLSyncMap - client OtelClient pricingManager *modelcatalog.ModelCatalog - accumulator *streaming.Accumulator // Accumulator for streaming chunks - - emitWg sync.WaitGroup // Track in-flight emissions } // Init function for the OTEL plugin @@ -100,7 +84,7 @@ func Init(ctx context.Context, config *Config, _logger schemas.Logger, pricingMa logger.Warn("otel plugin requires model catalog to calculate cost, all cost calculations will be skipped.") } var err error - // If headers are present , and any of them start with env., we will replace the value with the environment variable + // If headers are present, and any of them start with env., we will replace the value with the environment variable if config.Headers != nil { for key, value := range config.Headers { if newValue, ok := strings.CutPrefix(value, "env."); ok { @@ -132,11 +116,8 @@ func Init(ctx context.Context, config *Config, _logger schemas.Logger, pricingMa url: config.CollectorURL, traceType: config.TraceType, headers: config.Headers, - ongoingSpans: NewTTLSyncMap(20*time.Minute, 1*time.Minute), protocol: config.Protocol, pricingManager: pricingManager, - accumulator: streaming.NewAccumulator(pricingManager, logger), - emitWg: sync.WaitGroup{}, bifrostVersion: bifrostVersion, attributesFromEnvironment: attributesFromEnvironment, } @@ -164,9 +145,9 @@ func (p *OtelPlugin) GetName() string { return PluginName } -// TransportInterceptor is not used for this plugin -func (p *OtelPlugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return headers, body, nil +// HTTPTransportIntercept is not used for this plugin +func (p *OtelPlugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil } // ValidateConfig function for the OTEL plugin @@ -205,139 +186,53 @@ func (p *OtelPlugin) ValidateConfig(config any) (*Config, error) { return &otelConfig, nil } -// PreHook function for the OTEL plugin -func (p *OtelPlugin) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { - if p.client == nil { - logger.Warn("otel client is not initialized") - return req, nil, nil - } - traceIDValue := ctx.Value(schemas.BifrostContextKeyRequestID) - if traceIDValue == nil { - logger.Warn("trace id not found in context") - return req, nil, nil - } - traceID, ok := traceIDValue.(string) - if !ok { - logger.Warn("trace id not found in context") - return req, nil, nil - } - spanID := fmt.Sprintf("%s-root-span", traceID) - createdTimestamp := time.Now() - if bifrost.IsStreamRequestType(req.RequestType) { - p.accumulator.CreateStreamAccumulator(traceID, createdTimestamp) - } - p.ongoingSpans.Set(traceID, p.createResourceSpan(traceID, spanID, time.Now(), req)) +// PreHook is a no-op - tracing is handled via the Inject method. +// The OTEL plugin receives completed traces from TracingMiddleware. +func (p *OtelPlugin) PreHook(_ *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { return req, nil, nil } -// PostHook function for the OTEL plugin -func (p *OtelPlugin) PostHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - traceIDValue := ctx.Value(schemas.BifrostContextKeyRequestID) - if traceIDValue == nil { - logger.Warn("trace id not found in context") - return resp, bifrostErr, nil +// PostHook is a no-op - tracing is handled via the Inject method. +// The OTEL plugin receives completed traces from TracingMiddleware. +func (p *OtelPlugin) PostHook(_ *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + return resp, bifrostErr, nil +} + +// Inject receives a completed trace and sends it to the OTEL collector. +// Implements schemas.ObservabilityPlugin interface. +// This method is called asynchronously by TracingMiddleware after the response +// has been written to the client. +func (p *OtelPlugin) Inject(ctx context.Context, trace *schemas.Trace) error { + if trace == nil { + return nil } - traceID, ok := traceIDValue.(string) - if !ok { - logger.Warn("trace id not found in context") - return resp, bifrostErr, nil + if p.client == nil { + logger.Warn("otel client is not initialized") + return nil } - virtualKeyID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-virtual-key-id")) - virtualKeyName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-virtual-key-name")) - - selectedKeyID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeySelectedKeyID) - selectedKeyName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeySelectedKeyName) + // Convert schemas.Trace to OTEL ResourceSpan + resourceSpan := p.convertTraceToResourceSpan(trace) - numberOfRetries := bifrost.GetIntFromContext(ctx, schemas.BifrostContextKeyNumberOfRetries) - fallbackIndex := bifrost.GetIntFromContext(ctx, schemas.BifrostContextKeyFallbackIndex) - - teamID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-team-id")) - teamName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-team-name")) - customerID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-customer-id")) - customerName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-customer-name")) + // Emit to collector + if err := p.client.Emit(ctx, []*ResourceSpan{resourceSpan}); err != nil { + logger.Error("failed to emit trace %s: %v", trace.TraceID, err) + return err + } - // Track every PostHook emission, stream and non-stream. - p.emitWg.Add(1) - go func() { - defer p.emitWg.Done() - span, ok := p.ongoingSpans.Get(traceID) - if !ok { - logger.Warn("span not found in ongoing spans") - return - } - requestType, _, _ := bifrost.GetResponseFields(resp, bifrostErr) - if span, ok := span.(*ResourceSpan); ok { - // We handle streaming responses differently, we will use the accumulator to process the response and then emit the final response - if bifrost.IsStreamRequestType(requestType) { - streamResponse, err := p.accumulator.ProcessStreamingResponse(ctx, resp, bifrostErr) - if err != nil { - logger.Debug("failed to process streaming response: %v", err) - } - if streamResponse != nil && streamResponse.Type == streaming.StreamResponseTypeFinal { - defer p.ongoingSpans.Delete(traceID) - if err := p.client.Emit(p.ctx, []*ResourceSpan{completeResourceSpan( - span, - time.Now(), - streamResponse.ToBifrostResponse(), - bifrostErr, - p.pricingManager, - virtualKeyID, - virtualKeyName, - selectedKeyID, - selectedKeyName, - numberOfRetries, - fallbackIndex, - teamID, - teamName, - customerID, - customerName, - )}); err != nil { - logger.Error("failed to emit response span for request %s: %v", traceID, err) - } - } - return - } - defer p.ongoingSpans.Delete(traceID) - rs := completeResourceSpan( - span, - time.Now(), - resp, - bifrostErr, - p.pricingManager, - virtualKeyID, - virtualKeyName, - selectedKeyID, - selectedKeyName, - numberOfRetries, - fallbackIndex, - teamID, - teamName, - customerID, - customerName, - ) - if err := p.client.Emit(p.ctx, []*ResourceSpan{rs}); err != nil { - logger.Error("failed to emit response span for request %s: %v", traceID, err) - } - } - }() - return resp, bifrostErr, nil + return nil } // Cleanup function for the OTEL plugin func (p *OtelPlugin) Cleanup() error { - p.emitWg.Wait() if p.cancel != nil { p.cancel() } - if p.ongoingSpans != nil { - p.ongoingSpans.Stop() - } - if p.accumulator != nil { - p.accumulator.Cleanup() - } if p.client != nil { return p.client.Close() } return nil } + +// Compile-time check that OtelPlugin implements ObservabilityPlugin +var _ schemas.ObservabilityPlugin = (*OtelPlugin)(nil) diff --git a/plugins/otel/ttlsyncmap.go b/plugins/otel/ttlsyncmap.go deleted file mode 100644 index d54999d1b3..0000000000 --- a/plugins/otel/ttlsyncmap.go +++ /dev/null @@ -1,184 +0,0 @@ -package otel - -import ( - "sync" - "time" -) - -// TTLSyncMap is a thread-safe map with automatic cleanup of expired entries -type TTLSyncMap struct { - data sync.Map - ttl time.Duration - cleanupTicker *time.Ticker - stopCleanup chan struct{} - cleanupWg sync.WaitGroup - stopOnce sync.Once -} - -// entry stores the value along with its expiration time -type entry struct { - value interface{} - expiresAt time.Time -} - -// NewTTLSyncMap creates a new TTL sync map with the specified TTL and cleanup interval -// ttl: time to live for each entry -// cleanupInterval: how often to check for expired entries (should be <= ttl) -func NewTTLSyncMap(ttl time.Duration, cleanupInterval time.Duration) *TTLSyncMap { - if ttl <= 0 { - ttl = time.Minute - } - if cleanupInterval <= 0 { - cleanupInterval = ttl / 2 - if cleanupInterval <= 0 { - cleanupInterval = time.Minute - } - } - - m := &TTLSyncMap{ - ttl: ttl, - cleanupTicker: time.NewTicker(cleanupInterval), - stopCleanup: make(chan struct{}), - } - - // Start the cleanup goroutine - m.cleanupWg.Add(1) - go m.startCleanup() - - return m -} - -// Set stores a key-value pair with TTL -func (m *TTLSyncMap) Set(key, value interface{}) { - m.data.Store(key, &entry{ - value: value, - expiresAt: time.Now().Add(m.ttl), - }) -} - -// Get retrieves a value by key, returns (value, true) if found and not expired, -// (nil, false) otherwise -func (m *TTLSyncMap) Get(key interface{}) (interface{}, bool) { - val, ok := m.data.Load(key) - if !ok { - return nil, false - } - - e := val.(*entry) - if time.Now().After(e.expiresAt) { - // Entry has expired, delete it - m.data.Delete(key) - return nil, false - } - - return e.value, true -} - -// Delete removes a key-value pair from the map -func (m *TTLSyncMap) Delete(key interface{}) { - m.data.Delete(key) -} - -// Refresh updates the expiration time of an existing entry -func (m *TTLSyncMap) Refresh(key interface{}) bool { - val, ok := m.data.Load(key) - if !ok { - return false - } - e, _ := val.(*entry) - if e == nil || time.Now().After(e.expiresAt) { - m.data.Delete(key) - return false - } - m.data.Store(key, &entry{ - value: e.value, - expiresAt: time.Now().Add(m.ttl), - }) - return true -} - -// GetOrSet retrieves a value by key if it exists and is not expired, -// otherwise sets the new value and returns it -func (m *TTLSyncMap) GetOrSet(key, value interface{}) (actual interface{}, loaded bool) { - actual, loaded = m.Get(key) - if !loaded { - m.Set(key, value) - actual = value - } - return actual, loaded -} - -// Range calls f sequentially for each key and value present in the map. -// If f returns false, range stops the iteration. -// Only non-expired entries are included. -func (m *TTLSyncMap) Range(f func(key, value interface{}) bool) { - now := time.Now() - m.data.Range(func(key, val interface{}) bool { - e := val.(*entry) - if now.After(e.expiresAt) { - // Skip expired entry and delete it - m.data.Delete(key) - return true - } - return f(key, e.value) - }) -} - -// Len returns the number of non-expired entries in the map -func (m *TTLSyncMap) Len() int { - count := 0 - m.Range(func(_, _ interface{}) bool { - count++ - return true - }) - return count -} - -// startCleanup runs in a background goroutine to periodically remove expired entries -func (m *TTLSyncMap) startCleanup() { - defer m.cleanupWg.Done() - - for { - select { - case <-m.cleanupTicker.C: - m.cleanup() - case <-m.stopCleanup: - return - } - } -} - -// cleanup removes all expired entries from the map -func (m *TTLSyncMap) cleanup() { - now := time.Now() - m.data.Range(func(key, val interface{}) bool { - e := val.(*entry) - if now.After(e.expiresAt) { - m.data.Delete(key) - } - return true - }) - if m.Len() > 10000 { - logger.Warn("[otel] map cleanup done. current size: %d entries", m.Len()) - } else { - logger.Debug("[otel] map cleanup done. current size: %d entries", m.Len()) - } -} - -// Stop stops the cleanup goroutine and releases resources -// Call this when you're done with the map to prevent goroutine leaks -func (m *TTLSyncMap) Stop() { - m.stopOnce.Do(func() { - close(m.stopCleanup) - m.cleanupTicker.Stop() - m.cleanupWg.Wait() - }) -} - -// Clear removes all entries from the map -func (m *TTLSyncMap) Clear() { - m.data.Range(func(key, _ interface{}) bool { - m.data.Delete(key) - return true - }) -} diff --git a/plugins/otel/version b/plugins/otel/version index 260e057ecf..db15278970 100644 --- a/plugins/otel/version +++ b/plugins/otel/version @@ -1 +1 @@ -1.0.61 \ No newline at end of file +1.1.8 \ No newline at end of file diff --git a/plugins/semanticcache/go.mod b/plugins/semanticcache/go.mod index 11fc667f6b..95fffcdebc 100644 --- a/plugins/semanticcache/go.mod +++ b/plugins/semanticcache/go.mod @@ -5,9 +5,9 @@ go 1.25.5 require ( github.com/cespare/xxhash/v2 v2.3.0 github.com/google/uuid v1.6.0 - github.com/maximhq/bifrost/core v1.2.49 - github.com/maximhq/bifrost/framework v1.1.61 - github.com/maximhq/bifrost/plugins/mocker v1.3.40 + github.com/maximhq/bifrost/core v1.3.8 + github.com/maximhq/bifrost/framework v1.2.8 + github.com/maximhq/bifrost/plugins/mocker v1.4.4 ) require ( @@ -41,8 +41,11 @@ require ( github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic v1.14.2 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect + github.com/clarkmcc/go-typescript v0.7.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/analysis v0.24.2 // indirect @@ -66,8 +69,10 @@ require ( github.com/go-openapi/swag/typeutils v0.25.4 // indirect github.com/go-openapi/swag/yamlutils v0.25.4 // indirect github.com/go-openapi/validate v0.25.1 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect diff --git a/plugins/semanticcache/go.sum b/plugins/semanticcache/go.sum index 65954a3afb..6277aab61d 100644 --- a/plugins/semanticcache/go.sum +++ b/plugins/semanticcache/go.sum @@ -12,6 +12,8 @@ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= +github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= @@ -68,6 +70,8 @@ github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2N github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -77,6 +81,10 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -132,6 +140,8 @@ github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6 github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= github.com/go-openapi/validate v0.25.1 h1:sSACUI6Jcnbo5IWqbYHgjibrhhmt3vR6lCzKZnmAgBw= github.com/go-openapi/validate v0.25.1/go.mod h1:RMVyVFYte0gbSTaZ0N4KmTn6u/kClvAFp+mAVfS/DQc= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -141,6 +151,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= @@ -186,12 +198,12 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.2.49 h1:fk6l6r3kVBlpN73wYXmgtV6O4bhedOjSO4LAEz/7leg= -github.com/maximhq/bifrost/core v1.2.49/go.mod h1:z7nOx15e91ktZGi+pZHq+uhShlEK+fM4UyYUpP6oHAw= -github.com/maximhq/bifrost/framework v1.1.61 h1:fMjvICbkrdWMtGnLYrjSNrcmQYqtQvOh/swmrJTvf+E= -github.com/maximhq/bifrost/framework v1.1.61/go.mod h1:wVUPzB8K5S/5GWuxqp8dXf3nNZkqJsS/APMIcq48SOI= -github.com/maximhq/bifrost/plugins/mocker v1.3.40 h1:42/NppC7Vlwsjnjd2GRu/Bf3D/Jpo0JW0RDYwamIOMk= -github.com/maximhq/bifrost/plugins/mocker v1.3.40/go.mod h1:6B4dtTixMsdGtot/6U7nEIzJxFsM8j6ZYCmk7qSHNr8= +github.com/maximhq/bifrost/core v1.3.8 h1:xtwB9+HeTzYz5IKHkpUtupzBd0A5yl1avdLJGjsOKPI= +github.com/maximhq/bifrost/core v1.3.8/go.mod h1:abKQRnJQPZz8/UMxCcbuNHEyq19Db+IX4KlGJdlLY8E= +github.com/maximhq/bifrost/framework v1.2.8 h1:/oTpacuw7k0zRUJ9dSSQRtAVx3nLGSiR7GFwOjGxZNs= +github.com/maximhq/bifrost/framework v1.2.8/go.mod h1:mjw9YXh/Oxi3HeBCJ+3HJ6ftv43Wo4t0T4EzpcIbnr0= +github.com/maximhq/bifrost/plugins/mocker v1.4.4 h1:hLmqonf8IFtNBCHQ+R40yCNX5rCTRkqYw0+hU5L5zlg= +github.com/maximhq/bifrost/plugins/mocker v1.4.4/go.mod h1:U9ytiBZHQDRGn9nOlfjb08wood4AtiqzsD+dmsFugAY= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= @@ -287,6 +299,8 @@ google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/semanticcache/main.go b/plugins/semanticcache/main.go index 0da4592ffa..1153a5ac7a 100644 --- a/plugins/semanticcache/main.go +++ b/plugins/semanticcache/main.go @@ -206,7 +206,7 @@ func (pa *PluginAccount) GetConfiguredProviders() ([]schemas.ModelProvider, erro return []schemas.ModelProvider{pa.provider}, nil } -func (pa *PluginAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { +func (pa *PluginAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { return pa.keys, nil } @@ -335,9 +335,9 @@ func (plugin *Plugin) GetName() string { return PluginName } -// TransportInterceptor is not used for this plugin -func (plugin *Plugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return headers, body, nil +// HTTPTransportIntercept is not used for this plugin +func (plugin *Plugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil } // PreHook is called before a request is processed by Bifrost. @@ -354,12 +354,11 @@ func (plugin *Plugin) TransportInterceptor(ctx *schemas.BifrostContext, url stri // - error: Any error that occurred during cache lookup func (plugin *Plugin) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { provider, model, _ := req.GetRequestFields() - // Get the cache key from the context var cacheKey string var ok bool - cacheKey, ok = (*ctx).Value(CacheKey).(string) + cacheKey, ok = ctx.Value(CacheKey).(string) if !ok || cacheKey == "" { plugin.logger.Debug(PluginLoggerPrefix + " No cache key found in context, continuing without caching") return req, nil, nil @@ -377,10 +376,10 @@ func (plugin *Plugin) PreHook(ctx *schemas.BifrostContext, req *schemas.BifrostR ctx.SetValue(requestIDKey, requestID) ctx.SetValue(requestModelKey, model) ctx.SetValue(requestProviderKey, provider) - + performDirectSearch, performSemanticSearch := true, true - if (*ctx).Value(CacheTypeKey) != nil { - cacheTypeVal, ok := (*ctx).Value(CacheTypeKey).(CacheType) + if ctx.Value(CacheTypeKey) != nil { + cacheTypeVal, ok := ctx.Value(CacheTypeKey).(CacheType) if !ok { plugin.logger.Warn(PluginLoggerPrefix + " Cache type is not a CacheType, using all available cache types") } else { @@ -450,7 +449,7 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, res *schemas.Bifrost return res, bifrostErr, nil } - isCacheHit := (*ctx).Value(isCacheHitKey) + isCacheHit := ctx.Value(isCacheHitKey) if isCacheHit != nil { isCacheHitValue, ok := isCacheHit.(bool) if ok && isCacheHitValue { @@ -459,7 +458,7 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, res *schemas.Bifrost } // Check if caching is explicitly disabled - noStore := (*ctx).Value(CacheNoStoreKey) + noStore := ctx.Value(CacheNoStoreKey) if noStore != nil { noStoreValue, ok := noStore.(bool) if ok && noStoreValue { @@ -469,13 +468,13 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, res *schemas.Bifrost } // Get the cache key from context - cacheKey, ok := (*ctx).Value(CacheKey).(string) + cacheKey, ok := ctx.Value(CacheKey).(string) if !ok { return res, nil, nil } // Get the request ID from context - requestID, ok := (*ctx).Value(requestIDKey).(string) + requestID, ok := ctx.Value(requestIDKey).(string) if !ok { return res, nil, nil } @@ -485,8 +484,8 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, res *schemas.Bifrost var shouldStoreEmbeddings = true var shouldStoreHash = true - if (*ctx).Value(CacheTypeKey) != nil { - cacheTypeVal, ok := (*ctx).Value(CacheTypeKey).(CacheType) + if ctx.Value(CacheTypeKey) != nil { + cacheTypeVal, ok := ctx.Value(CacheTypeKey).(CacheType) if ok { if cacheTypeVal == CacheTypeDirect { // For direct-only caching, skip embedding operations entirely @@ -501,7 +500,7 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, res *schemas.Bifrost if shouldStoreHash { // Get the hash from context - hash, ok = (*ctx).Value(requestHashKey).(string) + hash, ok = ctx.Value(requestHashKey).(string) if !ok { plugin.logger.Warn(PluginLoggerPrefix + " Hash is not a string. Continuing without caching") return res, nil, nil @@ -513,7 +512,7 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, res *schemas.Bifrost // Get embedding from context if available and needed if shouldStoreEmbeddings && requestType != schemas.EmbeddingRequest && requestType != schemas.TranscriptionRequest { - embeddingValue := (*ctx).Value(requestEmbeddingKey) + embeddingValue := ctx.Value(requestEmbeddingKey) if embeddingValue != nil { embedding, ok = embeddingValue.([]float32) if !ok { @@ -526,14 +525,14 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, res *schemas.Bifrost } // Get the provider from context - provider, ok := (*ctx).Value(requestProviderKey).(schemas.ModelProvider) + provider, ok := ctx.Value(requestProviderKey).(schemas.ModelProvider) if !ok { plugin.logger.Warn(PluginLoggerPrefix + " Provider is not a schemas.ModelProvider, continuing without caching") return res, nil, nil } // Get the model from context - model, ok := (*ctx).Value(requestModelKey).(string) + model, ok := ctx.Value(requestModelKey).(string) if !ok { plugin.logger.Warn(PluginLoggerPrefix + " Model is not a string, continuing without caching") return res, nil, nil @@ -542,7 +541,7 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, res *schemas.Bifrost isFinalChunk := bifrost.IsFinalChunk(ctx) // Get the input tokens from context (can be nil if not set) - inputTokens, ok := (*ctx).Value(requestEmbeddingTokensKey).(int) + inputTokens, ok := ctx.Value(requestEmbeddingTokensKey).(int) if ok { isStreamRequest := bifrost.IsStreamRequestType(requestType) @@ -559,7 +558,7 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, res *schemas.Bifrost cacheTTL := plugin.config.TTL - ttlValue := (*ctx).Value(CacheTTLKey) + ttlValue := ctx.Value(CacheTTLKey) if ttlValue != nil { // Get the request TTL from the context ttl, ok := ttlValue.(time.Duration) @@ -570,6 +569,10 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, res *schemas.Bifrost } } + // Get metadata from context BEFORE goroutine to avoid race conditions + // when the same context is reused across multiple requests + paramsHash, _ := ctx.Value(requestParamsHashKey).(string) + // Cache everything in a unified VectorEntry asynchronously to avoid blocking the response plugin.waitGroup.Add(1) go func() { @@ -578,9 +581,6 @@ func (plugin *Plugin) PostHook(ctx *schemas.BifrostContext, res *schemas.Bifrost cacheCtx, cancel := context.WithTimeout(context.Background(), CacheSetTimeout) defer cancel() - // Get metadata from context - paramsHash, _ := (*ctx).Value(requestParamsHashKey).(string) - // Build unified metadata with provider, model, and all params unifiedMetadata := plugin.buildUnifiedMetadata(provider, model, paramsHash, hash, cacheKey, cacheTTL) diff --git a/plugins/semanticcache/plugin_cache_type_test.go b/plugins/semanticcache/plugin_cache_type_test.go index 603e00726c..bf9b797518 100644 --- a/plugins/semanticcache/plugin_cache_type_test.go +++ b/plugins/semanticcache/plugin_cache_type_test.go @@ -1,7 +1,6 @@ package semanticcache import ( - "context" "testing" "time" @@ -148,7 +147,7 @@ func TestCacheTypeInvalidValue(t *testing.T) { // Create context with invalid cache type ctx := CreateContextWithCacheKey("test-invalid-cache-type") - ctx = context.WithValue(ctx, CacheTypeKey, "invalid_type") + ctx = ctx.WithValue(CacheTypeKey, "invalid_type") testRequest := CreateBasicChatRequest("Test invalid cache type", 0.7, 50) diff --git a/plugins/semanticcache/plugin_core_test.go b/plugins/semanticcache/plugin_core_test.go index 044d8327b1..0ae5d6908c 100644 --- a/plugins/semanticcache/plugin_core_test.go +++ b/plugins/semanticcache/plugin_core_test.go @@ -348,6 +348,7 @@ func TestCacheConfiguration(t *testing.T) { config: &Config{ Provider: schemas.OpenAI, EmbeddingModel: "text-embedding-3-small", + Dimension: 1536, Threshold: 0.95, // Very high threshold Keys: []schemas.Key{ {Value: os.Getenv("OPENAI_API_KEY"), Models: []string{}, Weight: 1.0}, @@ -360,6 +361,7 @@ func TestCacheConfiguration(t *testing.T) { config: &Config{ Provider: schemas.OpenAI, EmbeddingModel: "text-embedding-3-small", + Dimension: 1536, Threshold: 0.1, // Very low threshold Keys: []schemas.Key{ {Value: os.Getenv("OPENAI_API_KEY"), Models: []string{}, Weight: 1.0}, @@ -372,6 +374,7 @@ func TestCacheConfiguration(t *testing.T) { config: &Config{ Provider: schemas.OpenAI, EmbeddingModel: "text-embedding-3-small", + Dimension: 1536, Threshold: 0.8, TTL: 1 * time.Hour, // Custom TTL Keys: []schemas.Key{ diff --git a/plugins/semanticcache/plugin_cross_cache_test.go b/plugins/semanticcache/plugin_cross_cache_test.go index 931f6c8d92..bb313b1a3f 100644 --- a/plugins/semanticcache/plugin_cross_cache_test.go +++ b/plugins/semanticcache/plugin_cross_cache_test.go @@ -1,7 +1,6 @@ package semanticcache import ( - "context" "testing" "github.com/maximhq/bifrost/core/schemas" @@ -298,7 +297,7 @@ func TestCacheTypeErrorHandling(t *testing.T) { // Test invalid cache type (should fallback to default) ctx1 := CreateContextWithCacheKey("test-cache-error-handling") - ctx1 = context.WithValue(ctx1, CacheTypeKey, "invalid_cache_type") + ctx1 = ctx1.WithValue(CacheTypeKey, "invalid_cache_type") t.Log("Testing invalid cache type (should fallback to default behavior)...") response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) @@ -311,7 +310,7 @@ func TestCacheTypeErrorHandling(t *testing.T) { // Test nil cache type (should use default) ctx2 := CreateContextWithCacheKey("test-cache-error-handling") - ctx2 = context.WithValue(ctx2, CacheTypeKey, nil) + ctx2 = ctx2.WithValue(CacheTypeKey, nil) t.Log("Testing nil cache type (should use default behavior)...") response2, err2 := setup.Client.ChatCompletionRequest(ctx2, testRequest) diff --git a/plugins/semanticcache/plugin_edge_cases_test.go b/plugins/semanticcache/plugin_edge_cases_test.go index 8cce0fbc6f..e941db1a1f 100644 --- a/plugins/semanticcache/plugin_edge_cases_test.go +++ b/plugins/semanticcache/plugin_edge_cases_test.go @@ -14,7 +14,6 @@ func TestParameterVariations(t *testing.T) { setup := NewTestSetup(t) defer setup.Cleanup() - ctx := CreateContextWithCacheKey("param-variations-test") basePrompt := "What is the capital of France?" tests := []struct { @@ -45,6 +44,9 @@ func TestParameterVariations(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + // Create a fresh context for each subtest to avoid context pollution + ctx := CreateContextWithCacheKey("param-variations-test") + // Clear cache for this subtest clearTestKeysWithStore(t, setup.Store) @@ -221,8 +223,6 @@ func TestContentVariations(t *testing.T) { setup := NewTestSetup(t) defer setup.Cleanup() - ctx := CreateContextWithCacheKey("content-variations-test") - tests := []struct { name string request *schemas.BifrostChatRequest @@ -349,6 +349,9 @@ func TestContentVariations(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Logf("Testing content variation: %s", tt.name) + // Create a fresh context for each subtest to avoid context pollution + ctx := CreateContextWithCacheKey("content-variations-test") + // Make first request _, err1 := setup.Client.ChatCompletionRequest(ctx, tt.request) if err1 != nil { @@ -376,8 +379,6 @@ func TestBoundaryParameterValues(t *testing.T) { setup := NewTestSetup(t) defer setup.Cleanup() - ctx := CreateContextWithCacheKey("boundary-params-test") - tests := []struct { name string request *schemas.BifrostChatRequest @@ -453,6 +454,9 @@ func TestBoundaryParameterValues(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Logf("Testing boundary parameters: %s", tt.name) + // Create a fresh context for each subtest to avoid context pollution + ctx := CreateContextWithCacheKey("boundary-params-test") + _, err := setup.Client.ChatCompletionRequest(ctx, tt.request) if err != nil { t.Logf("āš ļø %s request failed (may be expected): %v", tt.name, err) @@ -470,8 +474,6 @@ func TestSemanticSimilarityEdgeCases(t *testing.T) { setup.Config.Threshold = 0.9 - ctx := CreateContextWithCacheKey("semantic-edge-test") - // Test case: Similar questions with different wording similarTests := []struct { prompt1 string @@ -507,6 +509,9 @@ func TestSemanticSimilarityEdgeCases(t *testing.T) { for i, test := range similarTests { t.Run(test.description, func(t *testing.T) { + // Create a fresh context for each subtest to avoid context pollution + ctx := CreateContextWithCacheKey("semantic-edge-test") + // Clear cache for this subtest clearTestKeysWithStore(t, setup.Store) @@ -575,7 +580,7 @@ func TestErrorHandlingEdgeCases(t *testing.T) { // Test without cache key (should not crash and bypass cache) t.Run("Request without cache key", func(t *testing.T) { - ctxNoKey := context.Background() // No cache key + ctxNoKey := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) response, err := setup.Client.ChatCompletionRequest(ctxNoKey, testRequest) if err != nil { @@ -600,7 +605,7 @@ func TestErrorHandlingEdgeCases(t *testing.T) { WaitForCache() // Now test with invalid key type - should bypass cache - ctxInvalidKey := context.WithValue(context.Background(), CacheKey, 12345) // Wrong type (int instead of string) + ctxInvalidKey := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline).WithValue(CacheKey, 12345) response, err := setup.Client.ChatCompletionRequest(ctxInvalidKey, testRequest) if err != nil { diff --git a/plugins/semanticcache/plugin_no_store_test.go b/plugins/semanticcache/plugin_no_store_test.go index d48791986f..4840795bd9 100644 --- a/plugins/semanticcache/plugin_no_store_test.go +++ b/plugins/semanticcache/plugin_no_store_test.go @@ -1,7 +1,6 @@ package semanticcache import ( - "context" "testing" "github.com/maximhq/bifrost/core/schemas" @@ -172,8 +171,8 @@ func TestCacheNoStoreWithCacheTypes(t *testing.T) { // Test no-store with direct cache type ctx1 := CreateContextWithCacheKey("test-no-store-cache-types") - ctx1 = context.WithValue(ctx1, CacheNoStoreKey, true) - ctx1 = context.WithValue(ctx1, CacheTypeKey, CacheTypeDirect) + ctx1 = ctx1.WithValue(CacheNoStoreKey, true) + ctx1 = ctx1.WithValue(CacheTypeKey, CacheTypeDirect) t.Log("Testing no-store with CacheTypeKey=direct...") response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) @@ -193,8 +192,8 @@ func TestCacheNoStoreWithCacheTypes(t *testing.T) { // Test no-store with semantic cache type ctx2 := CreateContextWithCacheKey("test-no-store-cache-types") - ctx2 = context.WithValue(ctx2, CacheNoStoreKey, true) - ctx2 = context.WithValue(ctx2, CacheTypeKey, CacheTypeSemantic) + ctx2 = ctx2.WithValue(CacheNoStoreKey, true) + ctx2 = ctx2.WithValue(CacheTypeKey, CacheTypeSemantic) t.Log("Testing no-store with CacheTypeKey=semantic...") response3, err3 := setup.Client.ChatCompletionRequest(ctx2, testRequest) @@ -224,7 +223,7 @@ func TestCacheNoStoreErrorHandling(t *testing.T) { // Test with invalid no-store value (non-boolean) ctx1 := CreateContextWithCacheKey("test-no-store-errors") - ctx1 = context.WithValue(ctx1, CacheNoStoreKey, "invalid") + ctx1 = ctx1.WithValue(CacheNoStoreKey, "invalid") t.Log("Testing no-store with invalid value (should cache normally)...") response1, err1 := setup.Client.ChatCompletionRequest(ctx1, testRequest) @@ -248,7 +247,7 @@ func TestCacheNoStoreErrorHandling(t *testing.T) { // Test with nil value (should cache normally) ctx2 := CreateContextWithCacheKey("test-no-store-nil") - ctx2 = context.WithValue(ctx2, CacheNoStoreKey, nil) + ctx2 = ctx2.WithValue(CacheNoStoreKey, nil) t.Log("Testing no-store with nil value (should cache normally)...") response3, err3 := setup.Client.ChatCompletionRequest(ctx2, testRequest) diff --git a/plugins/semanticcache/search.go b/plugins/semanticcache/search.go index 1872e36245..05bc6cceb3 100644 --- a/plugins/semanticcache/search.go +++ b/plugins/semanticcache/search.go @@ -103,7 +103,7 @@ func (plugin *Plugin) performSemanticSearch(ctx *schemas.BifrostContext, req *sc cacheThreshold := plugin.config.Threshold - thresholdValue := (*ctx).Value(CacheThresholdKey) + thresholdValue := ctx.Value(CacheThresholdKey) if thresholdValue != nil { threshold, ok := thresholdValue.(float64) if !ok { @@ -297,7 +297,7 @@ func (plugin *Plugin) buildStreamingResponseFromResult(ctx *schemas.BifrostConte // Mark cache-hit once to avoid concurrent ctx writes ctx.SetValue(isCacheHitKey, true) ctx.SetValue(cacheHitTypeKey, cacheType) - + // Create stream channel streamChan := make(chan *schemas.BifrostStream) diff --git a/plugins/semanticcache/stream.go b/plugins/semanticcache/stream.go index bd9f19e05b..385b6e114f 100644 --- a/plugins/semanticcache/stream.go +++ b/plugins/semanticcache/stream.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "sort" + "sync" "time" ) @@ -19,6 +20,7 @@ func (plugin *Plugin) createStreamAccumulator(requestID string, embedding []floa Embedding: embedding, Metadata: metadata, TTL: ttl, + mu: sync.Mutex{}, } plugin.streamAccumulators.Store(requestID, accumulator) diff --git a/plugins/semanticcache/test_utils.go b/plugins/semanticcache/test_utils.go index a83f4162fc..519469b119 100644 --- a/plugins/semanticcache/test_utils.go +++ b/plugins/semanticcache/test_utils.go @@ -85,7 +85,7 @@ func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvide return []schemas.ModelProvider{schemas.OpenAI}, nil } -func (baseAccount *BaseAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { +func (baseAccount *BaseAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { return []schemas.Key{ { Value: os.Getenv("OPENAI_API_KEY"), @@ -297,7 +297,7 @@ func getMockRules() []mocker.MockRule { } // getMockedBifrostClient creates a Bifrost client with a mocker plugin for testing -func getMockedBifrostClient(t *testing.T, ctx context.Context, logger schemas.Logger, semanticCachePlugin schemas.Plugin) *bifrost.Bifrost { +func getMockedBifrostClient(t *testing.T, ctx *schemas.BifrostContext, logger schemas.Logger, semanticCachePlugin schemas.Plugin) *bifrost.Bifrost { mockerCfg := mocker.MockerConfig{ Enabled: true, Rules: getMockRules(), @@ -335,6 +335,7 @@ func NewTestSetup(t *testing.T) *TestSetup { return NewTestSetupWithConfig(t, &Config{ Provider: schemas.OpenAI, EmbeddingModel: "text-embedding-3-small", + Dimension: 1536, Threshold: 0.8, CleanUpOnShutdown: true, Keys: []schemas.Key{ @@ -349,7 +350,7 @@ func NewTestSetup(t *testing.T) *TestSetup { // NewTestSetupWithConfig creates a new test setup with custom configuration func NewTestSetupWithConfig(t *testing.T, config *Config) *TestSetup { - ctx := context.Background() + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) logger := bifrost.NewDefaultLogger(schemas.LogLevelDebug) // Keep Weaviate for embeddings, as mocker only affects chat completions @@ -362,7 +363,7 @@ func NewTestSetupWithConfig(t *testing.T, config *Config) *TestSetup { t.Fatalf("Vector store not available or failed to connect: %v", err) } - plugin, err := Init(context.Background(), config, logger, store) + plugin, err := Init(schemas.NewBifrostContext(context.Background(), schemas.NoDeadline), config, logger, store) if err != nil { t.Fatalf("Failed to initialize plugin: %v", err) } @@ -542,32 +543,28 @@ func CreateStreamingResponsesRequest(content string, temperature float64, maxTok } // CreateContextWithCacheKey creates a context with the test cache key -func CreateContextWithCacheKey(value string) context.Context { - return context.WithValue(context.Background(), CacheKey, value) +func CreateContextWithCacheKey(value string) *schemas.BifrostContext { + return schemas.NewBifrostContextWithValue(context.Background(), schemas.NoDeadline, CacheKey, value) } // CreateContextWithCacheKeyAndType creates a context with cache key and cache type -func CreateContextWithCacheKeyAndType(value string, cacheType CacheType) context.Context { - ctx := context.WithValue(context.Background(), CacheKey, value) - return context.WithValue(ctx, CacheTypeKey, cacheType) +func CreateContextWithCacheKeyAndType(value string, cacheType CacheType) *schemas.BifrostContext { + return schemas.NewBifrostContextWithValue(context.Background(), schemas.NoDeadline, CacheKey, value).WithValue(CacheTypeKey, cacheType) } // CreateContextWithCacheKeyAndTTL creates a context with cache key and custom TTL -func CreateContextWithCacheKeyAndTTL(value string, ttl time.Duration) context.Context { - ctx := context.WithValue(context.Background(), CacheKey, value) - return context.WithValue(ctx, CacheTTLKey, ttl) +func CreateContextWithCacheKeyAndTTL(value string, ttl time.Duration) *schemas.BifrostContext { + return schemas.NewBifrostContextWithValue(context.Background(), schemas.NoDeadline, CacheKey, value).WithValue(CacheTTLKey, ttl) } // CreateContextWithCacheKeyAndThreshold creates a context with cache key and custom threshold -func CreateContextWithCacheKeyAndThreshold(value string, threshold float64) context.Context { - ctx := context.WithValue(context.Background(), CacheKey, value) - return context.WithValue(ctx, CacheThresholdKey, threshold) +func CreateContextWithCacheKeyAndThreshold(value string, threshold float64) *schemas.BifrostContext { + return schemas.NewBifrostContext(context.Background(), schemas.NoDeadline).WithValue(CacheKey, value).WithValue(CacheThresholdKey, threshold) } // CreateContextWithCacheKeyAndNoStore creates a context with cache key and no-store flag -func CreateContextWithCacheKeyAndNoStore(value string, noStore bool) context.Context { - ctx := context.WithValue(context.Background(), CacheKey, value) - return context.WithValue(ctx, CacheNoStoreKey, noStore) +func CreateContextWithCacheKeyAndNoStore(value string, noStore bool) *schemas.BifrostContext { + return schemas.NewBifrostContext(context.Background(), schemas.NoDeadline).WithValue(CacheKey, value).WithValue(CacheNoStoreKey, noStore) } // CreateTestSetupWithConversationThreshold creates a test setup with custom conversation history threshold @@ -575,6 +572,7 @@ func CreateTestSetupWithConversationThreshold(t *testing.T, threshold int) *Test config := &Config{ Provider: schemas.OpenAI, EmbeddingModel: "text-embedding-3-small", + Dimension: 1536, CleanUpOnShutdown: true, Threshold: 0.8, ConversationHistoryThreshold: threshold, @@ -595,6 +593,7 @@ func CreateTestSetupWithExcludeSystemPrompt(t *testing.T, excludeSystem bool) *T config := &Config{ Provider: schemas.OpenAI, EmbeddingModel: "text-embedding-3-small", + Dimension: 1536, CleanUpOnShutdown: true, Threshold: 0.8, ExcludeSystemPrompt: &excludeSystem, @@ -615,6 +614,7 @@ func CreateTestSetupWithThresholdAndExcludeSystem(t *testing.T, threshold int, e config := &Config{ Provider: schemas.OpenAI, EmbeddingModel: "text-embedding-3-small", + Dimension: 1536, CleanUpOnShutdown: true, Threshold: 0.8, ConversationHistoryThreshold: threshold, diff --git a/plugins/semanticcache/utils.go b/plugins/semanticcache/utils.go index 08b37e2d1b..8e93a53b4b 100644 --- a/plugins/semanticcache/utils.go +++ b/plugins/semanticcache/utils.go @@ -20,7 +20,7 @@ func normalizeText(text string) string { } // generateEmbedding generates an embedding for the given text using the configured provider. -func (plugin *Plugin) generateEmbedding(ctx context.Context, text string) ([]float32, int, error) { +func (plugin *Plugin) generateEmbedding(ctx *schemas.BifrostContext, text string) ([]float32, int, error) { // Create embedding request embeddingReq := &schemas.BifrostEmbeddingRequest{ Provider: plugin.config.Provider, diff --git a/plugins/semanticcache/version b/plugins/semanticcache/version index ee96f2fc42..721b9931f4 100644 --- a/plugins/semanticcache/version +++ b/plugins/semanticcache/version @@ -1 +1 @@ -1.3.61 \ No newline at end of file +1.4.8 \ No newline at end of file diff --git a/plugins/telemetry/go.mod b/plugins/telemetry/go.mod index 0263a86070..bdb1a0845b 100644 --- a/plugins/telemetry/go.mod +++ b/plugins/telemetry/go.mod @@ -3,8 +3,8 @@ module github.com/maximhq/bifrost/plugins/telemetry go 1.25.5 require ( - github.com/maximhq/bifrost/core v1.2.49 - github.com/maximhq/bifrost/framework v1.1.61 + github.com/maximhq/bifrost/core v1.3.8 + github.com/maximhq/bifrost/framework v1.2.8 github.com/prometheus/client_golang v1.23.0 github.com/valyala/fasthttp v1.68.0 ) @@ -42,8 +42,11 @@ require ( github.com/bytedance/sonic v1.14.2 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/clarkmcc/go-typescript v0.7.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/analysis v0.24.2 // indirect @@ -67,8 +70,10 @@ require ( github.com/go-openapi/swag/typeutils v0.25.4 // indirect github.com/go-openapi/swag/yamlutils v0.25.4 // indirect github.com/go-openapi/validate v0.25.1 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f // indirect github.com/google/uuid v1.6.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect diff --git a/plugins/telemetry/go.sum b/plugins/telemetry/go.sum index f3506dda2d..ec1e0da55e 100644 --- a/plugins/telemetry/go.sum +++ b/plugins/telemetry/go.sum @@ -12,6 +12,8 @@ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= +github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= @@ -70,6 +72,8 @@ github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2N github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -79,6 +83,10 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -134,6 +142,8 @@ github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6 github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= github.com/go-openapi/validate v0.25.1 h1:sSACUI6Jcnbo5IWqbYHgjibrhhmt3vR6lCzKZnmAgBw= github.com/go-openapi/validate v0.25.1/go.mod h1:RMVyVFYte0gbSTaZ0N4KmTn6u/kClvAFp+mAVfS/DQc= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -143,6 +153,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= @@ -186,10 +198,10 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.2.49 h1:fk6l6r3kVBlpN73wYXmgtV6O4bhedOjSO4LAEz/7leg= -github.com/maximhq/bifrost/core v1.2.49/go.mod h1:z7nOx15e91ktZGi+pZHq+uhShlEK+fM4UyYUpP6oHAw= -github.com/maximhq/bifrost/framework v1.1.61 h1:fMjvICbkrdWMtGnLYrjSNrcmQYqtQvOh/swmrJTvf+E= -github.com/maximhq/bifrost/framework v1.1.61/go.mod h1:wVUPzB8K5S/5GWuxqp8dXf3nNZkqJsS/APMIcq48SOI= +github.com/maximhq/bifrost/core v1.3.8 h1:xtwB9+HeTzYz5IKHkpUtupzBd0A5yl1avdLJGjsOKPI= +github.com/maximhq/bifrost/core v1.3.8/go.mod h1:abKQRnJQPZz8/UMxCcbuNHEyq19Db+IX4KlGJdlLY8E= +github.com/maximhq/bifrost/framework v1.2.8 h1:/oTpacuw7k0zRUJ9dSSQRtAVx3nLGSiR7GFwOjGxZNs= +github.com/maximhq/bifrost/framework v1.2.8/go.mod h1:mjw9YXh/Oxi3HeBCJ+3HJ6ftv43Wo4t0T4EzpcIbnr0= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= @@ -297,6 +309,8 @@ google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/telemetry/main.go b/plugins/telemetry/main.go index e227bac279..b1b4898c5a 100644 --- a/plugins/telemetry/main.go +++ b/plugins/telemetry/main.go @@ -63,6 +63,7 @@ type PrometheusPlugin struct { type Config struct { CustomLabels []string `json:"custom_labels"` + Registry *prometheus.Registry } // Init creates a new PrometheusPlugin with initialized metrics. @@ -75,7 +76,11 @@ func Init(config *Config, pricingManager *modelcatalog.ModelCatalog, logger sche logger.Warn("telemetry plugin requires model catalog to calculate cost, all cost calculations will be skipped.") } - registry := prometheus.NewRegistry() + registry := config.Registry + // If config has no registry, create a new one + if registry == nil { + registry = prometheus.NewRegistry() + } // Create collectors and store references for cleanup goCollector := collectors.NewGoCollector() @@ -276,9 +281,9 @@ func (p *PrometheusPlugin) GetName() string { return PluginName } -// TransportInterceptor is not used for this plugin -func (p *PrometheusPlugin) TransportInterceptor(ctx *schemas.BifrostContext, url string, headers map[string]string, body map[string]any) (map[string]string, map[string]any, error) { - return headers, body, nil +// HTTPTransportIntercept is not used for this plugin +func (p *PrometheusPlugin) HTTPTransportIntercept(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil } // PreHook records the start time of the request in the context. @@ -315,45 +320,46 @@ func (p *PrometheusPlugin) PostHook(ctx *schemas.BifrostContext, result *schemas customerID := getStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-customer-id")) customerName := getStringFromContext(ctx, schemas.BifrostContextKey("bf-governance-customer-name")) - // Calculate cost and record metrics in a separate goroutine to avoid blocking the main thread - go func() { - labelValues := map[string]string{ - "provider": string(provider), - "model": model, - "method": string(requestType), - "virtual_key_id": virtualKeyID, - "virtual_key_name": virtualKeyName, - "selected_key_id": selectedKeyID, - "selected_key_name": selectedKeyName, - "number_of_retries": strconv.Itoa(numberOfRetries), - "fallback_index": strconv.Itoa(fallbackIndex), - "team_id": teamID, - "team_name": teamName, - "customer_id": customerID, - "customer_name": customerName, - } + // Extract ALL context values BEFORE spawning the goroutine. + labelValues := map[string]string{ + "provider": string(provider), + "model": model, + "method": string(requestType), + "virtual_key_id": virtualKeyID, + "virtual_key_name": virtualKeyName, + "selected_key_id": selectedKeyID, + "selected_key_name": selectedKeyName, + "number_of_retries": strconv.Itoa(numberOfRetries), + "fallback_index": strconv.Itoa(fallbackIndex), + "team_id": teamID, + "team_name": teamName, + "customer_id": customerID, + "customer_name": customerName, + } - // Get all prometheus labels from context - for _, key := range p.customLabels { - if value := (*ctx).Value(schemas.BifrostContextKey(key)); value != nil { - if strValue, ok := value.(string); ok { - labelValues[key] = strValue - } + // Get all custom prometheus labels from context BEFORE the goroutine + for _, key := range p.customLabels { + if value := ctx.Value(schemas.BifrostContextKey(key)); value != nil { + if strValue, ok := value.(string); ok { + labelValues[key] = strValue } } + } + + // Get label values in the correct order (cache_type will be handled separately for cache hits) + promLabelValues := getPrometheusLabelValues(append(p.defaultBifrostLabels, p.customLabels...), labelValues) - // Get label values in the correct order (cache_type will be handled separately for cache hits) - promLabelValues := getPrometheusLabelValues(append(p.defaultBifrostLabels, p.customLabels...), labelValues) + // Extract stream end indicator BEFORE the goroutine + streamEndIndicatorValue := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator) + isFinalChunk, hasFinalChunkIndicator := streamEndIndicatorValue.(bool) + // Calculate cost and record metrics in a separate goroutine to avoid blocking the main thread + go func() { // For streaming requests, handle per-token metrics for intermediate chunks if bifrost.IsStreamRequestType(requestType) { - // Determine if this is the final chunk - streamEndIndicatorValue := (*ctx).Value(schemas.BifrostContextKeyStreamEndIndicator) - isFinalChunk, ok := streamEndIndicatorValue.(bool) - // For intermediate chunks, record per-token metrics and exit. // The final chunk will fall through to record full request metrics. - if !ok || !isFinalChunk { + if !hasFinalChunkIndicator || !isFinalChunk { // Record metrics for the first token if result != nil { extraFields := result.GetExtraFields() @@ -464,7 +470,7 @@ func (p *PrometheusPlugin) PostHook(ctx *schemas.BifrostContext, result *schemas return result, bifrostErr, nil } -// PrometheusMiddleware wraps a FastHTTP handler to collect Prometheus metrics. +// HTTPMiddleware wraps a FastHTTP handler to collect Prometheus metrics. // It tracks: // - Total number of requests // - Request duration diff --git a/plugins/telemetry/version b/plugins/telemetry/version index ee96f2fc42..5596554988 100644 --- a/plugins/telemetry/version +++ b/plugins/telemetry/version @@ -1 +1 @@ -1.3.61 \ No newline at end of file +1.4.9 \ No newline at end of file diff --git a/tests/core-mcp/README.md b/tests/core-mcp/README.md new file mode 100644 index 0000000000..b2e6745de2 --- /dev/null +++ b/tests/core-mcp/README.md @@ -0,0 +1,230 @@ +# MCP Test Suite + +This directory contains comprehensive tests for the MCP (Model Context Protocol) functionality in Bifrost, covering code mode and non-code mode clients, auto-execute and non-auto-execute tools, and their various combinations. + +## Overview + +The test suite is organized into multiple test files covering different aspects of MCP: + +1. **Client Configuration Tests** (`client_config_test.go`) + - Single and multiple code mode clients + - Single and multiple non-code mode clients + - Mixed code mode + non-code mode clients + - Client connection states + - Client configuration updates + +2. **Tool Execution Tests** (`tool_execution_test.go`) + - Non-code mode tool execution (direct) + - Code mode tool execution (`executeToolCode`) + - Code mode calling code mode client tools + - Code mode calling multiple servers + - `listToolFiles` and `readToolFile` functionality + +3. **Auto-Execute Configuration Tests** (`auto_execute_config_test.go`) + - Tools in `ToolsToExecute` but not in `ToolsToAutoExecute` + - Tools in both lists (auto-execute) + - Tools in `ToolsToAutoExecute` but not in `ToolsToExecute` (should be skipped) + - Wildcard configurations + - Empty and nil configurations + - Mixed auto-execute configurations + +4. **Code Mode Auto-Execute Validation Tests** (`codemode_auto_execute_test.go`) + - `executeToolCode` with code calling only auto-execute tools + - `executeToolCode` with code calling non-auto-execute tools + - `executeToolCode` with code calling mixed auto/non-auto tools + - `executeToolCode` with no tool calls + - `executeToolCode` with `listToolFiles`/`readToolFile` calls + +5. **Agent Mode Tests** (`agent_mode_test.go`) + - Agent mode configuration validation + - Max depth configuration + - Note: Full agent mode flow testing requires LLM integration (see `integration_test.go`) + +6. **Edge Cases & Error Handling** (`edge_cases_test.go`) + - Code mode client calling non-code mode client tool (runtime error) + - Tool not in `ToolsToExecute` (should not be available) + - Tool execution timeout + - Tool execution error propagation + - Empty code execution + - Code with syntax errors + - Code with TypeScript compilation errors + - Code with runtime errors + - Code calling tools with invalid arguments + - Code mode tools always auto-executable + +7. **Integration Tests** (`integration_test.go`) + - Full workflow: `listToolFiles` → `readToolFile` → `executeToolCode` + - Multiple code mode clients with different auto-execute configs + - Tool filtering with code mode + - Code mode and non-code mode tools in same request + - Complex code execution scenarios + - Error handling in code execution + +8. **Basic MCP Connection Tests** (`mcp_connection_test.go`) + - MCP manager initialization + - Local tool registration + - Tool discovery and execution + - Multiple servers + - Tool execution timeout and errors + +## MCP Architecture + +### Client Types + +- **Code Mode Clients** (`IsCodeModeClient=true`): + - Enable code mode tools: `listToolFiles`, `readToolFile`, `executeToolCode` + - Tools accessible via TypeScript code execution in sandboxed VM + - Only code mode clients appear in `listToolFiles` output + +- **Non-Code Mode Clients** (`IsCodeModeClient=false`): + - Tools exposed directly as function-calling tools + - Cannot be called from `executeToolCode` code + +### Tool Execution Modes + +- **Auto-Execute Tools** (`ToolsToAutoExecute`): + - Automatically executed in agent mode without user approval + - Must also be in `ToolsToExecute` list + - For `executeToolCode`: validates all tool calls within code against auto-execute list + +- **Non-Auto-Execute Tools**: + - Require explicit user approval in agent mode + - Agent loop stops and returns these tools for user decision + +### Agent Mode Behavior + +When agent mode receives tool calls: + +- **All auto-execute tools**: Executes all tools, makes new LLM call, continues loop +- **All non-auto-execute tools**: Stops immediately, returns tool calls in `tool_calls` field +- **Mixed scenario** (e.g., 3 auto-execute, 2 non-auto-execute): + - Executes all auto-executable tools (3 in example) + - Adds executed tool results to message content (formatted as JSON) + - Includes non-auto-executable tool calls (2 in example) in `tool_calls` field + - Sets `finish_reason` to "stop" (not "tool_calls") to prevent loop continuation + - Returns immediately without making another LLM call + +Agent mode respects `maxAgentDepth` limit and returns an error if exceeded. + +## Test Structure + +### Setup Files + +- `setup.go` - Test setup utilities for initializing Bifrost and configuring clients + - `setupTestBifrost()` - Basic Bifrost instance + - `setupTestBifrostWithCodeMode()` - Bifrost with code mode enabled + - `setupTestBifrostWithMCPConfig()` - Bifrost with custom MCP config + - `setupCodeModeClient()` - Helper to create code mode client config + - `setupNonCodeModeClient()` - Helper to create non-code mode client config + - `setupClientWithAutoExecute()` - Helper to create client with auto-execute config + - `registerTestTools()` - Registers test tools (echo, add, multiply, etc.) + +- `fixtures.go` - Sample TypeScript code snippets and expected results + - Basic expressions and tool calls + - Auto-execute validation scenarios + - Mixed client scenarios + - Edge case scenarios + +- `utils.go` - Test helper functions for assertions and validation + - `createToolCall()` - Creates tool call messages + - `assertExecutionResult()` - Validates execution results + - `assertAgentModeResponse()` - Validates agent mode response structure + - `extractExecutedToolResults()` - Extracts executed tool results from agent mode response + - `canAutoExecuteTool()` - Checks if a tool can be auto-executed + - `createMCPClientConfig()` - Creates MCP client configs + +## Running Tests + +### Run all tests: +```bash +cd tests/core-mcp +go test -v ./... +``` + +### Run specific test file: +```bash +go test -v -run TestClientConfig ./... +``` + +### Run specific test: +```bash +go test -v -run TestSingleCodeModeClient +``` + +### Run with coverage: +```bash +go test -v -cover ./... +``` + +### Run tests by category: +```bash +# Client configuration tests +go test -v -run "^Test.*Client.*" ./... + +# Tool execution tests +go test -v -run "^Test.*Tool.*" ./... + +# Auto-execute tests +go test -v -run "^Test.*Auto.*" ./... + +# Edge case tests +go test -v -run "^Test.*Error|^Test.*Timeout|^Test.*Empty" ./... + +# Integration tests +go test -v -run "^Test.*Workflow|^Test.*Integration" ./... +``` + +## Test Tools + +The test suite registers several test tools: + +1. **echo** - Simple echo that returns input +2. **add** - Adds two numbers +3. **multiply** - Multiplies two numbers +4. **get_data** - Returns structured data (object/array) +5. **error_tool** - Tool that always returns an error +6. **slow_tool** - Tool that takes time to execute +7. **complex_args_tool** - Tool that accepts complex nested arguments + +## Key Test Scenarios + +### Scenario 1: Mixed Auto-Execute and Non-Auto-Execute Tools (Critical) + +When agent mode receives 5 tool calls: 3 auto-execute, 2 non-auto-execute: +- Agent executes the 3 auto-execute tools +- Adds their results to message content (JSON formatted) +- Includes the 2 non-auto-execute tool calls in `tool_calls` field +- Sets `finish_reason` to "stop" +- Stops immediately (no further LLM call) +- Response structure validated correctly + +### Scenario 2: Code Mode Client + Auto-Execute Tools + +- Setup: Code mode client with tools configured for auto-execute +- Test: `executeToolCode` with code calling these tools should auto-execute in agent mode + +### Scenario 3: Mixed Client Types + +- Setup: One code mode client + one non-code mode client +- Test: Code mode tools only see code mode client, non-code mode tools available separately + +### Scenario 4: Auto-Execute Validation in Code + +- Setup: Code mode client with mixed auto-execute config +- Test: `executeToolCode` validates all tool calls in code against auto-execute list + +### Scenario 5: Code Mode Tools Always Auto-Execute + +- Setup: Code mode enabled +- Test: `listToolFiles` and `readToolFile` always auto-execute regardless of config + +## Notes + +- All tests use a timeout context to prevent hanging +- Tests are designed to be independent and can run in parallel +- The test suite uses the `bifrostInternal` server for local tool registration +- Code mode tests verify that TypeScript code is transpiled and executes correctly in the sandboxed goja VM +- TypeScript compilation errors are caught and reported with helpful hints +- Async/await syntax is automatically transpiled to Promise chains compatible with goja +- Error handling tests verify that helpful error hints are provided for both runtime and TypeScript compilation errors +- Agent mode tests verify the critical mixed auto-execute/non-auto-execute scenario where some tools are executed and others are returned for user approval diff --git a/tests/core-mcp/agent_mode_test.go b/tests/core-mcp/agent_mode_test.go new file mode 100644 index 0000000000..7788841fbc --- /dev/null +++ b/tests/core-mcp/agent_mode_test.go @@ -0,0 +1,77 @@ +package mcp + +import ( + "context" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Note: Full agent mode testing requires integration with LLM calls. +// These tests verify the configuration and tool execution aspects that can be tested directly. +// For full agent mode flow testing, see integration_test.go + +// TestAgentModeConfiguration tests the configuration aspects of agent mode +// Full agent mode flow testing requires LLM integration (see integration_test.go) +func TestAgentModeConfiguration(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Test configuration: echo auto-execute, add non-auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // Only echo is auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + + // Verify configuration + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") + assert.False(t, canAutoExecuteTool("add", bifrostClient.Config), "add should not be auto-executable") + assert.False(t, canAutoExecuteTool("multiply", bifrostClient.Config), "multiply should not be auto-executable") +} + +func TestAgentModeMaxDepthConfiguration(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Create Bifrost with max depth of 2 + mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + MaxAgentDepth: 2, + ToolExecutionTimeout: 30 * time.Second, + }, + FetchNewRequestIDFunc: func(ctx *schemas.BifrostContext) string { + return "test-request-id" + }, + } + b, err := setupTestBifrostWithMCPConfig(ctx, mcpConfig) + require.NoError(t, err) + + // Verify max depth is configured + clients, err := b.GetMCPClients() + require.NoError(t, err) + assert.NotNil(t, clients, "Should have clients") +} diff --git a/tests/core-mcp/auto_execute_config_test.go b/tests/core-mcp/auto_execute_config_test.go new file mode 100644 index 0000000000..ec4946380e --- /dev/null +++ b/tests/core-mcp/auto_execute_config_test.go @@ -0,0 +1,322 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestToolInToolsToExecuteButNotInToolsToAutoExecute(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure echo in ToolsToExecute but not in ToolsToAutoExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo"}, + ToolsToAutoExecute: []string{}, // Empty - no auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "echo") + assert.Empty(t, bifrostClient.Config.ToolsToAutoExecute) + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable") +} + +func TestToolInBothToolsToExecuteAndToolsToAutoExecute(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure echo in both lists + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo"}, + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "echo") + assert.Contains(t, bifrostClient.Config.ToolsToAutoExecute, "echo") + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") +} + +func TestToolInToolsToAutoExecuteButNotInToolsToExecute(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure echo in ToolsToAutoExecute but not in ToolsToExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"add"}, // echo not in this list + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + // echo should not be auto-executable because it's not in ToolsToExecute + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable (not in ToolsToExecute)") +} + +func TestWildcardInToolsToAutoExecute(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure wildcard in ToolsToAutoExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToAutoExecute, "*") + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable with wildcard") + assert.True(t, canAutoExecuteTool("add", bifrostClient.Config), "add should be auto-executable with wildcard") +} + +func TestEmptyToolsToAutoExecute(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure empty ToolsToAutoExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{}, // Empty - no auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Empty(t, bifrostClient.Config.ToolsToAutoExecute) + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable") +} + +func TestNilToolsToAutoExecute(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure nil ToolsToAutoExecute (omitted) + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + // ToolsToAutoExecute omitted (nil) + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + // nil should be treated as empty + if bifrostClient.Config.ToolsToAutoExecute == nil { + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable (nil treated as empty)") + } else { + assert.Empty(t, bifrostClient.Config.ToolsToAutoExecute) + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable") + } +} + +func TestMultipleToolsWithMixedAutoExecuteConfigs(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure mixed: echo auto-execute, add non-auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo", "add", "multiply"}, + ToolsToAutoExecute: []string{"echo", "multiply"}, // add not in auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") + assert.False(t, canAutoExecuteTool("add", bifrostClient.Config), "add should not be auto-executable") + assert.True(t, canAutoExecuteTool("multiply", bifrostClient.Config), "multiply should be auto-executable") +} + +func TestToolsToExecuteEmptyList(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure empty ToolsToExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{}, // Empty - no tools allowed + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Empty(t, bifrostClient.Config.ToolsToExecute) + // Even with wildcard in ToolsToAutoExecute, tools not in ToolsToExecute should not be auto-executable + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable (not in ToolsToExecute)") +} + +func TestToolsToExecuteNil(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure nil ToolsToExecute (omitted) + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + // ToolsToExecute omitted (nil) + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + // nil ToolsToExecute should be treated as empty + if bifrostClient.Config.ToolsToExecute == nil { + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable (nil ToolsToExecute treated as empty)") + } else { + assert.Empty(t, bifrostClient.Config.ToolsToExecute) + assert.False(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should not be auto-executable") + } +} diff --git a/tests/core-mcp/client_config_test.go b/tests/core-mcp/client_config_test.go new file mode 100644 index 0000000000..3c06666a2a --- /dev/null +++ b/tests/core-mcp/client_config_test.go @@ -0,0 +1,346 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSingleCodeModeClient(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + // Find bifrostInternal client + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient, "bifrostInternal client should exist") + assert.True(t, bifrostClient.Config.IsCodeModeClient, "bifrostInternal should be code mode client") + assert.Equal(t, schemas.MCPConnectionStateConnected, bifrostClient.State) +} + +func TestSingleNonCodeModeClient(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Note: For in-process clients, we need to register tools first + err = registerTestTools(b) + require.NoError(t, err) + + // Update bifrostInternal to be non-code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.False(t, bifrostClient.Config.IsCodeModeClient, "bifrostInternal should be non-code mode client") +} + +func TestMultipleCodeModeClients(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + codeModeCount := 0 + for _, client := range clients { + if client.Config.IsCodeModeClient { + codeModeCount++ + } + } + + assert.GreaterOrEqual(t, codeModeCount, 1, "Should have at least one code mode client") +} + +func TestMultipleNonCodeModeClients(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to non-code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + nonCodeModeCount := 0 + for _, client := range clients { + if !client.Config.IsCodeModeClient { + nonCodeModeCount++ + } + } + + assert.GreaterOrEqual(t, nonCodeModeCount, 1, "Should have at least one non-code mode client") +} + +func TestMixedCodeModeAndNonCodeModeClients(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + codeModeCount := 0 + + for _, client := range clients { + if client.Config.IsCodeModeClient { + codeModeCount++ + } + } + + // At minimum, we should have bifrostInternal as code mode + assert.GreaterOrEqual(t, codeModeCount, 1, "Should have at least one code mode client") +} + +func TestClientConnectionStates(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + // All clients should be connected + for _, client := range clients { + assert.Equal(t, schemas.MCPConnectionStateConnected, client.State, "Client %s should be connected", client.Config.ID) + } +} + +func TestClientWithNoTools(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Don't register any tools - bifrostInternal client should still exist but with no tools + clients, err := b.GetMCPClients() + require.NoError(t, err) + + // bifrostInternal client is created when MCP is initialized, but won't have tools until registered + // This test verifies the client exists even without tools + assert.NotNil(t, clients, "Clients list should exist") + + // Find bifrostInternal client + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient, "bifrostInternal client should exist") + assert.Empty(t, bifrostClient.Tools, "bifrostInternal client should have no tools") +} + +func TestClientWithEmptyToolLists(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set ToolsToExecute to empty list + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Equal(t, []string{}, bifrostClient.Config.ToolsToExecute, "ToolsToExecute should be empty") +} + +func TestClientConfigUpdate(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Initially, bifrostInternal should not be code mode (default) + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + initialIsCodeMode := bifrostClient.Config.IsCodeModeClient + + // Update to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + }) + require.NoError(t, err) + + // Verify update + clients, err = b.GetMCPClients() + require.NoError(t, err) + + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.NotEqual(t, initialIsCodeMode, bifrostClient.Config.IsCodeModeClient, "IsCodeModeClient should have changed") + assert.True(t, bifrostClient.Config.IsCodeModeClient, "Should now be code mode") +} + +func TestClientWithToolsToExecuteWildcard(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set ToolsToExecute to wildcard + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"*"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "*", "Should contain wildcard") +} + +func TestClientWithSpecificToolsToExecute(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set ToolsToExecute to specific tools + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo", "add"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "echo") + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "add") + assert.Len(t, bifrostClient.Config.ToolsToExecute, 2) +} diff --git a/tests/core-mcp/codemode_auto_execute_test.go b/tests/core-mcp/codemode_auto_execute_test.go new file mode 100644 index 0000000000..0be7177917 --- /dev/null +++ b/tests/core-mcp/codemode_auto_execute_test.go @@ -0,0 +1,233 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" +) + +func TestExecuteToolCodeWithAutoExecuteTool(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Configure echo as auto-execute - preserve existing config + clients, err := b.GetMCPClients() + require.NoError(t, err) + var currentConfig *schemas.MCPClientConfig + for _, client := range clients { + if client.Config.ID == "bifrostInternal" { + currentConfig = &client.Config + break + } + } + require.NotNil(t, currentConfig) + + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ID: currentConfig.ID, + Name: currentConfig.Name, + ConnectionType: currentConfig.ConnectionType, + IsCodeModeClient: currentConfig.IsCodeModeClient, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + // Test executeToolCode with code calling auto-execute tool + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithAutoExecuteTool, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestExecuteToolCodeWithNonAutoExecuteTool(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Configure multiply as non-auto-execute - preserve existing config + clients, err := b.GetMCPClients() + require.NoError(t, err) + var currentConfig *schemas.MCPClientConfig + for _, client := range clients { + if client.Config.ID == "bifrostInternal" { + currentConfig = &client.Config + break + } + } + require.NotNil(t, currentConfig) + + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ID: currentConfig.ID, + Name: currentConfig.Name, + ConnectionType: currentConfig.ConnectionType, + IsCodeModeClient: currentConfig.IsCodeModeClient, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // multiply not in auto-execute + }) + require.NoError(t, err) + + // Test executeToolCode with code calling non-auto-execute tool + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithNonAutoExecuteTool, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestExecuteToolCodeWithMixedAutoExecute(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Configure echo as auto-execute, multiply as non-auto-execute - preserve existing config + clients, err := b.GetMCPClients() + require.NoError(t, err) + var currentConfig *schemas.MCPClientConfig + for _, client := range clients { + if client.Config.ID == "bifrostInternal" { + currentConfig = &client.Config + break + } + } + require.NotNil(t, currentConfig) + + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ID: currentConfig.ID, + Name: currentConfig.Name, + ConnectionType: currentConfig.ConnectionType, + IsCodeModeClient: currentConfig.IsCodeModeClient, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // multiply not in auto-execute + }) + require.NoError(t, err) + + // Test executeToolCode with code calling mixed tools + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithMixedAutoExecute, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestExecuteToolCodeWithNoToolCalls(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode with no tool calls + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithNoToolCalls, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestExecuteToolCodeWithListToolFiles(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // listToolFiles should always be auto-executable + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithListToolFiles, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // listToolFiles and readToolFile are code mode meta-tools and cannot be called from within executeToolCode + // They're only available as direct tool calls, not from within code execution + // So this will fail with a runtime error + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestExecuteToolCodeWithReadToolFile(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // readToolFile should always be auto-executable + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithReadToolFile, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // listToolFiles and readToolFile are code mode meta-tools and cannot be called from within executeToolCode + // They're only available as direct tool calls, not from within code execution + // So this will fail with a runtime error + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestExecuteToolCodeWithUndefinedServer(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode with undefined server + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithUndefinedServer, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + // Should fail with runtime error + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestExecuteToolCodeWithUndefinedTool(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode with undefined tool + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeWithUndefinedTool, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + // Should fail with runtime error + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "runtime") +} diff --git a/tests/core-mcp/edge_cases_test.go b/tests/core-mcp/edge_cases_test.go new file mode 100644 index 0000000000..aa4292536f --- /dev/null +++ b/tests/core-mcp/edge_cases_test.go @@ -0,0 +1,299 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCodeModeClientCallingNonCodeModeClientTool(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code trying to call non-code mode client tool + // This should fail at runtime since non-code mode clients aren't available in code execution + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeCallingNonCodeModeTool, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + // Should fail with runtime error - tool call succeeds but code execution fails + requireNoBifrostError(t, bifrostErr, "Tool call should succeed") + require.NotNil(t, result, "Result should be present") + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestNonCodeModeClientToolCalledFromExecuteToolCode(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Code mode can only call code mode client tools + // Non-code mode tools are not available in executeToolCode context + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": `const result = await NonExistentClient.tool({}); return result`, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + // Should fail with runtime error - tool call succeeds but code execution fails + requireNoBifrostError(t, bifrostErr, "Tool call should succeed") + require.NotNil(t, result, "Result should be present") + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestToolNotInToolsToExecute(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure only echo in ToolsToExecute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ToolsToExecute: []string{"echo"}, // add not in list + }) + require.NoError(t, err) + + // Try to execute add tool (not in ToolsToExecute) + addCall := createToolCall("add", map[string]interface{}{ + "a": float64(1), + "b": float64(2), + }) + _, bifrostErr := b.ExecuteChatMCPTool(ctx, addCall) + + // Should fail - tool not available + assert.NotNil(t, bifrostErr, "Should fail when tool not in ToolsToExecute") +} + +func TestToolExecutionTimeoutEdgeCase(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Test slow tool with timeout + slowCall := createToolCall("slow_tool", map[string]interface{}{ + "delay_ms": float64(100), + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, slowCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "Completed", "Should complete execution") +} + +func TestToolExecutionErrorPropagation(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Test error tool + errorCall := createToolCall("error_tool", map[string]interface{}{}) + result, bifrostErr := b.ExecuteChatMCPTool(ctx, errorCall) + + // Tool execution should succeed (no bifrostErr), but result should contain error message + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "Error:", "Result should contain error message") + assert.Contains(t, responseText, "this tool always fails", "Result should contain the error text") +} + +func TestEmptyCodeExecution(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.EmptyCode, + }) + + _, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + // Empty code should return an error + require.NotNil(t, bifrostErr, "Empty code should return an error") + assert.Contains(t, bifrostErr.Error.Message, "code parameter is required", "Error should mention code parameter") +} + +func TestCodeWithSyntaxErrors(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.SyntaxError, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // Syntax errors are caught during JavaScript execution (runtime), not TypeScript compilation + // The error will be a runtime SyntaxError + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestCodeWithTypeScriptCompilationErrors(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Invalid TypeScript code + invalidCode := `const x: string = 123; return x` + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": invalidCode, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // TypeScript type errors might not be caught - the code might execute successfully + // This is acceptable behavior if type checking is disabled + // Just verify the execution completed (either with error or success) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) +} + +func TestCodeWithRuntimeErrors(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.RuntimeError, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + // Should fail with runtime error + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "runtime") +} + +func TestCodeCallingToolsWithInvalidArguments(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Code calling tool with invalid arguments + invalidArgsCode := `const result = await BifrostClient.echo({invalid: "arg"}); return result` + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": invalidArgsCode, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + // Should fail - tool expects "message" parameter + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "") +} + +func TestCodeModeToolsAlwaysAutoExecutable(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{}, // Empty - no auto-execute configured + }) + require.NoError(t, err) + + // listToolFiles and readToolFile should always be auto-executable + // This is tested in integration tests that verify agent mode behavior + // For now, verify they can be executed directly + listCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteChatMCPTool(ctx, listCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) +} + +func TestCommentsOnlyCode(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CommentsOnly, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // Comments-only code should execute (return null) + assertExecutionResult(t, result, true, nil, "") +} + +func TestUndefinedVariableError(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.UndefinedVariable, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + // Should fail with runtime error + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assertExecutionResult(t, result, false, nil, "runtime") +} diff --git a/tests/core-mcp/fixtures.go b/tests/core-mcp/fixtures.go new file mode 100644 index 0000000000..fe8b5a82e5 --- /dev/null +++ b/tests/core-mcp/fixtures.go @@ -0,0 +1,311 @@ +package mcp + +// CodeFixtures contains sample TypeScript code snippets for testing +var CodeFixtures = struct { + // Basic expressions + SimpleExpression string + SimpleString string + VariableAssignment string + ConsoleLogging string + ExplicitReturn string + AutoReturnExpression string + + // MCP tool calls + SingleToolCall string + ToolCallWithPromise string + ToolCallChain string + ToolCallErrorHandling string + MultipleServerToolCalls string + ToolCallWithComplexArgs string + + // Import/Export + ImportStatement string + ExportStatement string + MultipleImportExport string + ImportExportWithComments string + + // Expression analysis + FunctionCallExpression string + PromiseChainExpression string + ObjectLiteralExpression string + AssignmentStatement string + ControlFlowStatement string + TopLevelReturn string + + // Error cases + UndefinedVariable string + UndefinedServer string + UndefinedTool string + SyntaxError string + RuntimeError string + + // Edge cases + NestedPromiseChains string + PromiseErrorHandling string + ComplexDataStructures string + MultiLineExpression string + EmptyCode string + CommentsOnly string + FunctionDefinition string + + // Environment tests + AsyncAwaitTest string + EnvironmentTest string + + // Long code test + LongCodeExecution string + + // Auto-execute validation tests + CodeWithAutoExecuteTool string + CodeWithNonAutoExecuteTool string + CodeWithMixedAutoExecute string + CodeWithMultipleClients string + CodeWithNoToolCalls string + CodeWithListToolFiles string + CodeWithReadToolFile string + + // Mixed client scenarios + CodeCallingCodeModeTool string + CodeCallingNonCodeModeTool string + CodeCallingMultipleServers string + CodeWithUndefinedServer string + CodeWithUndefinedTool string + + // Agent mode scenarios + CodeForAgentModeAutoExecute string + CodeForAgentModeNonAutoExecute string +}{ + SimpleExpression: `return 1 + 1`, + SimpleString: `return "hello"`, + VariableAssignment: `var x = 5; return x`, + ConsoleLogging: `console.log("test"); return "logged"`, + ExplicitReturn: `return 42`, + AutoReturnExpression: `return 2 + 2`, // Note: Now requires explicit return + + SingleToolCall: `const result = await BifrostClient.echo({message: "hello"}); return result`, + ToolCallWithPromise: `const result = await BifrostClient.echo({message: "test"}); console.log(result); return result`, + ToolCallChain: `const result1 = await BifrostClient.add({a: 1, b: 2}); const result2 = await BifrostClient.multiply({a: result1, b: 3}); return result2`, + ToolCallErrorHandling: `try { await BifrostClient.error_tool({}); } catch (err) { console.error(err); return "handled"; }`, + MultipleServerToolCalls: `const r1 = await BifrostClient.echo({message: "test"}); const r2 = await BifrostClient.add({a: 1, b: 2}); return r2`, + ToolCallWithComplexArgs: `return await BifrostClient.complex_args_tool({data: {nested: {value: 42}}})`, + + ImportStatement: `import { something } from "module"; return 1 + 1`, + ExportStatement: `export const x = 5; return x`, + MultipleImportExport: `import a from "a"; import b from "b"; export const c = 1; return 2 + 2`, + ImportExportWithComments: `// comment\nimport x from "x";\n// another comment\nreturn 2 + 2`, + + FunctionCallExpression: `return Math.max(1, 2)`, // Note: Now requires explicit return + PromiseChainExpression: `return Promise.resolve(1).then(x => x + 1)`, // Note: Now requires explicit return + ObjectLiteralExpression: `return {a: 1, b: 2}`, // Note: Now requires explicit return + AssignmentStatement: `var x = 5`, // Assignment statements don't return values + ControlFlowStatement: `if (true) { return 1; } else { return 2; }`, // Note: Now requires explicit return + TopLevelReturn: `return 42`, + + UndefinedVariable: `return undefinedVar`, // Will cause runtime error + UndefinedServer: `return nonexistentServer.tool({})`, // Will cause runtime error + UndefinedTool: `return BifrostClient.nonexistentTool({})`, // Will cause runtime error + SyntaxError: `var x = `, // Syntax error - no return needed + RuntimeError: `return null.someProperty`, // Will cause runtime error + + NestedPromiseChains: `return Promise.resolve(1).then(x => Promise.resolve(x + 1).then(y => y + 1))`, // Note: Now requires explicit return + PromiseErrorHandling: `return Promise.reject("error").catch(err => "handled")`, // Note: Now requires explicit return + ComplexDataStructures: `return [{a: 1}, {b: 2}].map(x => x.a || x.b)`, // Note: Now requires explicit return + MultiLineExpression: `const result = await BifrostClient.echo({message: "test"});\n return result`, // Note: Now requires explicit return + EmptyCode: ``, + CommentsOnly: `// comment\n/* another */`, + FunctionDefinition: `function test() { return 1; } return test()`, // Note: Now requires explicit return for function call + + AsyncAwaitTest: `async function test() { const result = await Promise.resolve(1); return result; } return test()`, + EnvironmentTest: `return __MCP_ENV__.serverKeys`, + + LongCodeExecution: `// Long and complex code execution test with extensive operations\n` + + `(async function() {\n` + + ` var results = [];\n` + + ` var sum = 0;\n` + + ` var processedData = [];\n` + + ` var executionLog = [];\n` + + ` \n` + + ` // Initialize execution context\n` + + ` var context = {\n` + + ` startTime: Date.now(),\n` + + ` steps: 0,\n` + + ` errors: [],\n` + + ` warnings: []\n` + + ` };\n` + + ` \n` + + ` try {\n` + + ` // Step 1: Initial echo call\n` + + ` const result1 = await BifrostClient.echo({message: "step1"});\n` + + ` console.log("Step 1 completed:", result1);\n` + + ` results.push(result1);\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 1, action: "echo", result: result1});\n` + + ` \n` + + ` // Step 2: Add operation\n` + + ` const result2 = await BifrostClient.add({a: 10, b: 20});\n` + + ` console.log("Step 2 completed:", result2);\n` + + ` results.push(result2);\n` + + ` sum += result2;\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 2, action: "add", result: result2, sum: sum});\n` + + ` \n` + + ` // Conditional logic based on result\n` + + ` let result3;\n` + + ` if (result2 > 25) {\n` + + ` console.log("Result is greater than 25, proceeding with multiplication");\n` + + ` result3 = await BifrostClient.multiply({a: result2, b: 2});\n` + + ` } else {\n` + + ` console.log("Result is less than or equal to 25, using add again");\n` + + ` result3 = await BifrostClient.add({a: result2, b: 5});\n` + + ` }\n` + + ` console.log("Step 3 completed:", result3);\n` + + ` results.push(result3);\n` + + ` sum += result3;\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 3, action: "math", result: result3, sum: sum});\n` + + ` \n` + + ` // Step 4: Echo call\n` + + ` const result4 = await BifrostClient.echo({message: "step4"});\n` + + ` console.log("Step 4 completed:", result4);\n` + + ` results.push(result4);\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 4, action: "echo", result: result4});\n` + + ` \n` + + ` // Complex loop with nested operations\n` + + ` for (var i = 0; i < 20; i++) {\n` + + ` sum += i;\n` + + ` if (i % 3 === 0) {\n` + + ` processedData.push({\n` + + ` index: i,\n` + + ` value: i * 2,\n` + + ` isMultipleOfThree: true\n` + + ` });\n` + + ` } else if (i % 2 === 0) {\n` + + ` processedData.push({\n` + + ` index: i,\n` + + ` value: i * 1.5,\n` + + ` isEven: true\n` + + ` });\n` + + ` } else {\n` + + ` processedData.push({\n` + + ` index: i,\n` + + ` value: i,\n` + + ` isOdd: true\n` + + ` });\n` + + ` }\n` + + ` }\n` + + ` \n` + + ` console.log("Processed", processedData.length, "data items");\n` + + ` \n` + + ` // Step 5: Get data\n` + + ` const result5 = await BifrostClient.get_data({key: "test"});\n` + + ` console.log("Step 5 completed:", result5);\n` + + ` results.push(result5);\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 5, action: "get_data", result: result5});\n` + + ` \n` + + ` // Nested data processing\n` + + ` var nestedResults = [];\n` + + ` for (var j = 0; j < results.length; j++) {\n` + + ` var item = results[j];\n` + + ` nestedResults.push({\n` + + ` original: item,\n` + + ` processed: typeof item === "string" ? item.toUpperCase() : item * 1.1,\n` + + ` index: j,\n` + + ` metadata: {\n` + + ` type: typeof item,\n` + + ` isString: typeof item === "string",\n` + + ` isNumber: typeof item === "number"\n` + + ` }\n` + + ` });\n` + + ` }\n` + + ` \n` + + ` // Step 6: Final echo call\n` + + ` const result6 = await BifrostClient.echo({message: "final_step"});\n` + + ` console.log("Step 6 completed:", result6);\n` + + ` results.push(result6);\n` + + ` context.steps++;\n` + + ` executionLog.push({step: 6, action: "echo", result: result6});\n` + + ` \n` + + ` // Calculate statistics\n` + + ` var stats = {\n` + + ` totalResults: results.length,\n` + + ` numericSum: sum,\n` + + ` average: sum / results.length,\n` + + ` processedItems: processedData.length,\n` + + ` executionSteps: context.steps\n` + + ` };\n` + + ` \n` + + ` // Create comprehensive final data structure\n` + + ` var finalData = {\n` + + ` results: results,\n` + + ` processedData: processedData,\n` + + ` executionLog: executionLog,\n` + + ` statistics: stats,\n` + + ` context: {\n` + + ` steps: context.steps,\n` + + ` executionTime: Date.now() - context.startTime,\n` + + ` errors: context.errors,\n` + + ` warnings: context.warnings\n` + + ` },\n` + + ` metadata: {\n` + + ` executed: true,\n` + + ` completed: true,\n` + + ` totalOperations: context.steps,\n` + + ` dataProcessed: processedData.length,\n` + + ` finalSum: sum,\n` + + ` resultCount: results.length\n` + + ` }\n` + + ` };\n` + + ` \n` + + ` console.log("Final statistics:", JSON.stringify(stats));\n` + + ` console.log("Execution completed successfully with", context.steps, "steps");\n` + + ` console.log("Processed", processedData.length, "data items");\n` + + ` console.log("Final sum:", sum);\n` + + ` \n` + + ` return finalData;\n` + + ` } catch (error) {\n` + + ` console.error("Error in long execution:", error);\n` + + ` context.errors.push(error.toString());\n` + + ` return {\n` + + ` error: error.toString(),\n` + + ` context: context,\n` + + ` partialResults: results,\n` + + ` partialSum: sum\n` + + ` };\n` + + ` }\n` + + `})()`, + + // Auto-execute validation tests + CodeWithAutoExecuteTool: `const result = await BifrostClient.echo({message: "auto-execute"}); return result`, + CodeWithNonAutoExecuteTool: `const result = await BifrostClient.multiply({a: 2, b: 3}); return result`, + CodeWithMixedAutoExecute: `const r1 = await BifrostClient.echo({message: "auto"}); const r2 = await BifrostClient.multiply({a: 2, b: 3}); return r2`, + CodeWithMultipleClients: `const r1 = await BifrostClient.echo({message: "test"}); const r2 = await Server2.add({a: 1, b: 2}); return r2`, + CodeWithNoToolCalls: `return 42`, + CodeWithListToolFiles: `const files = await BifrostClient.listToolFiles({}); return files`, + CodeWithReadToolFile: `const content = await BifrostClient.readToolFile({fileName: "BifrostClient.d.ts"}); return content`, + + // Mixed client scenarios + CodeCallingCodeModeTool: `const result = await BifrostClient.echo({message: "test"}); return result`, + CodeCallingNonCodeModeTool: `const result = await NonCodeModeClient.someTool({}); return result`, + CodeCallingMultipleServers: `const r1 = await BifrostClient.echo({message: "test"}); const r2 = await Server2.add({a: 1, b: 2}); return {r1, r2}`, + CodeWithUndefinedServer: `const result = await UndefinedServer.tool({}); return result`, + CodeWithUndefinedTool: `const result = await BifrostClient.undefinedTool({}); return result`, + + // Agent mode scenarios + CodeForAgentModeAutoExecute: `const result = await BifrostClient.echo({message: "agent-auto"}); return result`, + CodeForAgentModeNonAutoExecute: `const result = await BifrostClient.multiply({a: 5, b: 6}); return result`, +} + +// ExpectedResults contains expected results for validation +var ExpectedResults = struct { + SimpleExpressionResult interface{} + EchoResult string + AddResult float64 + MultiplyResult float64 +}{ + SimpleExpressionResult: float64(2), + EchoResult: "hello", + AddResult: float64(3), + MultiplyResult: float64(6), +} diff --git a/tests/core-mcp/go.mod b/tests/core-mcp/go.mod new file mode 100644 index 0000000000..fc1f162481 --- /dev/null +++ b/tests/core-mcp/go.mod @@ -0,0 +1,76 @@ +module github.com/maximhq/bifrost/tests/core-mcp + +go 1.25.5 + +replace github.com/maximhq/bifrost/core => ../../core + +require ( + github.com/maximhq/bifrost/core v0.0.0-00010101000000-000000000000 + github.com/stretchr/testify v1.11.1 +) + +require ( + cloud.google.com/go/compute/metadata v0.9.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.41.0 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 // indirect + github.com/aws/aws-sdk-go-v2/config v1.32.6 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.19.6 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16 // indirect + github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.8 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 // indirect + github.com/aws/smithy-go v1.24.0 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.14.2 // indirect + github.com/bytedance/sonic/loader v0.4.0 // indirect + github.com/clarkmcc/go-typescript v0.7.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/klauspost/compress v1.18.2 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/mailru/easyjson v0.9.1 // indirect + github.com/mark3labs/mcp-go v0.43.2 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.68.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/arch v0.23.0 // indirect + golang.org/x/crypto v0.46.0 // indirect + golang.org/x/net v0.48.0 // indirect + golang.org/x/oauth2 v0.34.0 // indirect + golang.org/x/sys v0.39.0 // indirect + golang.org/x/text v0.32.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/tests/core-mcp/go.sum b/tests/core-mcp/go.sum new file mode 100644 index 0000000000..0717d36512 --- /dev/null +++ b/tests/core-mcp/go.sum @@ -0,0 +1,178 @@ +cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= +github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= +github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 h1:489krEF9xIGkOaaX3CE/Be2uWjiXrkCH6gUX+bZA/BU= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4/go.mod h1:IOAPF6oT9KCsceNTvvYMNHy0+kMF8akOjeDvPENWxp4= +github.com/aws/aws-sdk-go-v2/config v1.32.6 h1:hFLBGUKjmLAekvi1evLi5hVvFQtSo3GYwi+Bx4lpJf8= +github.com/aws/aws-sdk-go-v2/config v1.32.6/go.mod h1:lcUL/gcd8WyjCrMnxez5OXkO3/rwcNmvfno62tnXNcI= +github.com/aws/aws-sdk-go-v2/credentials v1.19.6 h1:F9vWao2TwjV2MyiyVS+duza0NIRtAslgLUM0vTA1ZaE= +github.com/aws/aws-sdk-go-v2/credentials v1.19.6/go.mod h1:SgHzKjEVsdQr6Opor0ihgWtkWdfRAIwxYzSJ8O85VHY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 h1:80+uETIWS1BqjnN9uJ0dBUaETh+P1XwFy5vwHwK5r9k= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16/go.mod h1:wOOsYuxYuB/7FlnVtzeBYRcjSRtQpAW0hCP7tIULMwo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 h1:rgGwPzb82iBYSvHMHXc8h9mRoOUBZIGFgKb9qniaZZc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16/go.mod h1:L/UxsGeKpGoIj6DxfhOWHWQ/kGKcd4I1VncE4++IyKA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 h1:1jtGzuV7c82xnqOVfx2F0xmJcOw5374L7N6juGW6x6U= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16/go.mod h1:M2E5OQf+XLe+SZGmmpaI2yy+J326aFf6/+54PoxSANc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16 h1:CjMzUs78RDDv4ROu3JnJn/Ig1r6ZD7/T2DXLLRpejic= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16/go.mod h1:uVW4OLBqbJXSHJYA9svT9BluSvvwbzLQ2Crf6UPzR3c= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7 h1:DIBqIrJ7hv+e4CmIk2z3pyKT+3B6qVMgRsawHiR3qso= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7/go.mod h1:vLm00xmBke75UmpNvOcZQ/Q30ZFjbczeLFqGx5urmGo= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 h1:oHjJHeUy0ImIV0bsrX0X91GkV5nJAyv1l1CC9lnO0TI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16/go.mod h1:iRSNGgOYmiYwSCXxXaKb9HfOEj40+oTKn8pTxMlYkRM= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16 h1:NSbvS17MlI2lurYgXnCOLvCFX38sBW4eiVER7+kkgsU= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16/go.mod h1:SwT8Tmqd4sA6G1qaGdzWCJN99bUmPGHfRwwq3G5Qb+A= +github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0 h1:SWTxh/EcUCDVqi/0s26V6pVUq0BBG7kx0tDTmF/hCgA= +github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0/go.mod h1:79S2BdqCJpScXZA2y+cpZuocWsjGjJINyXnOsf5DTz8= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 h1:HpI7aMmJ+mm1wkSHIA2t5EaFFv5EFYXePW30p1EIrbQ= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4/go.mod h1:C5RdGMYGlfM0gYq/tifqgn4EbyX99V15P2V3R+VHbQU= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.8 h1:aM/Q24rIlS3bRAhTyFurowU8A0SMyGDtEOY/l/s/1Uw= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.8/go.mod h1:+fWt2UHSb4kS7Pu8y+BMBvJF0EWx+4H0hzNwtDNRTrg= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 h1:AHDr0DaHIAo8c9t1emrzAlVDFp+iMMKnPdYy6XO4MCE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12/go.mod h1:GQ73XawFFiWxyWXMHWfhiomvP3tXtdNar/fi8z18sx0= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX1s+lFTg4+4DOy70= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5/go.mod h1:iW40X4QBmUxdP+fZNOpfmkdMZqsovezbAeO+Ubiv2pk= +github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= +github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPIIE= +github.com/bytedance/sonic v1.14.2/go.mod h1:T80iDELeHiHKSc0C9tubFygiuXoGzrkjKzX2quAx980= +github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2NYzevs+o= +github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= +github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRSVop0eemFmo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= +github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= +github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= +github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.68.0 h1:v12Nx16iepr8r9ySOwqI+5RBJ/DqTxhOy1HrHoDFnok= +github.com/valyala/fasthttp v1.68.0/go.mod h1:5EXiRfYQAoiO/khu4oU9VISC/eVY6JqmSpPJoHCKsz4= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/arch v0.23.0 h1:lKF64A2jF6Zd8L0knGltUnegD62JMFBiCPBmQpToHhg= +golang.org/x/arch v0.23.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= +golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= +golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/tests/core-mcp/integration_test.go b/tests/core-mcp/integration_test.go new file mode 100644 index 0000000000..191b4bf268 --- /dev/null +++ b/tests/core-mcp/integration_test.go @@ -0,0 +1,229 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFullWorkflowListToolFilesReadToolFileExecuteToolCode(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Step 1: List tool files + listCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteChatMCPTool(ctx, listCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "BifrostClient.d.ts", "Should list BifrostClient") + + // Step 2: Read tool file + readCall := createToolCall("readToolFile", map[string]interface{}{ + "fileName": "BifrostClient.d.ts", + }) + result, bifrostErr = b.ExecuteChatMCPTool(ctx, readCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText = *result.Content.ContentStr + assert.Contains(t, responseText, "interface", "Should contain interface definitions") + assert.Contains(t, responseText, "echo", "Should contain echo tool") + + // Step 3: Execute code using the discovered tools + executeCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeCallingCodeModeTool, + }) + result, bifrostErr = b.ExecuteChatMCPTool(ctx, executeCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestMultipleCodeModeClientsWithDifferentAutoExecuteConfigs(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure bifrostInternal with mixed auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo", "add"}, // multiply not auto-execute + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config)) + assert.True(t, canAutoExecuteTool("add", bifrostClient.Config)) + assert.False(t, canAutoExecuteTool("multiply", bifrostClient.Config)) +} + +func TestToolFilteringWithCodeMode(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure specific tools only + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + ToolsToExecute: []string{"echo", "add"}, // Only these tools available + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "echo") + assert.Contains(t, bifrostClient.Config.ToolsToExecute, "add") + assert.NotContains(t, bifrostClient.Config.ToolsToExecute, "multiply") +} + +func TestCodeModeAndNonCodeModeToolsInSameRequest(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: true, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Code mode tools should be available + listCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteChatMCPTool(ctx, listCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + + // Verify direct tools are not exposed for code-mode clients + // Code mode clients expose tools via executeToolCode, not as direct tool calls + echoCall := createToolCall("echo", map[string]interface{}{ + "message": "test", + }) + _, bifrostErr = b.ExecuteChatMCPTool(ctx, echoCall) + require.NotNil(t, bifrostErr, "Direct tool call should fail for code-mode client") + assert.Contains(t, bifrostErr.Error.Message, "not available", "Error should indicate tool is not available") +} + +func TestComplexCodeExecutionWithMultipleToolCalls(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test complex code with multiple tool calls + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.ToolCallChain, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestCodeExecutionWithErrorHandling(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code with error handling + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.ToolCallErrorHandling, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") + assertResultContains(t, result, "handled") +} + +func TestCodeExecutionWithAsyncAwait(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test async/await syntax + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.AsyncAwaitTest, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestLongCodeExecution(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test long and complex code execution + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.LongCodeExecution, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + assertExecutionResult(t, result, true, nil, "") +} diff --git a/tests/core-mcp/mcp_connection_test.go b/tests/core-mcp/mcp_connection_test.go new file mode 100644 index 0000000000..137eed520b --- /dev/null +++ b/tests/core-mcp/mcp_connection_test.go @@ -0,0 +1,299 @@ +package mcp + +import ( + "context" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMCPManagerInitialization(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + require.NotNil(t, b) + + // Verify MCP is configured + clients, err := b.GetMCPClients() + require.NoError(t, err) + assert.NotNil(t, clients) +} + +func TestLocalToolRegistration(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Verify tools are available + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + // Find the bifrostInternal client + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient, "bifrostInternal client should exist") + assert.Equal(t, schemas.MCPConnectionStateConnected, bifrostClient.State) + + // Verify tools are registered + toolNames := make(map[string]bool) + for _, tool := range bifrostClient.Tools { + toolNames[tool.Name] = true + } + + assert.True(t, toolNames["echo"], "echo tool should be registered") + assert.True(t, toolNames["add"], "add tool should be registered") + assert.True(t, toolNames["multiply"], "multiply tool should be registered") +} + +func TestToolDiscovery(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Use CodeMode since we're testing CodeMode tools (listToolFiles, readToolFile) + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test listToolFiles + listToolCall := createResponsesToolCall("listToolFiles", schemas.OrderedMap{}) + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, listToolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "servers/", "Should list servers") + assert.Contains(t, responseText, "BifrostClient.d.ts", "Should list BifrostClient server") + + // Test readToolFile + readToolCall := createResponsesToolCall("readToolFile", schemas.OrderedMap{ + "fileName": "BifrostClient.d.ts", + }) + result, bifrostErr = b.ExecuteResponsesMCPTool(ctx, readToolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText = *result.Content.ContentStr + assert.Contains(t, responseText, "interface", "Should contain TypeScript interface declarations") + assert.Contains(t, responseText, "echo", "Should contain echo tool definition") + assert.Contains(t, responseText, "EchoInput", "Should contain echo input interface") +} + +func TestToolExecution(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Test echo tool + echoCall := createResponsesToolCall("echo", schemas.OrderedMap{ + "message": "test message", + }) + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Equal(t, "test message", responseText) + + // Test add tool + addCall := createResponsesToolCall("add", schemas.OrderedMap{ + "a": schemas.Ptr(5), + "b": schemas.Ptr(3), + }) + result, bifrostErr = b.ExecuteResponsesMCPTool(ctx, addCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText = *result.Content.ContentStr + assert.Equal(t, "8", responseText) + + // Test multiply tool + multiplyCall := createResponsesToolCall("multiply", schemas.OrderedMap{ + "a": schemas.Ptr(4), + "b": schemas.Ptr(7), + }) + result, bifrostErr = b.ExecuteResponsesMCPTool(ctx, multiplyCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText = *result.Content.ContentStr + assert.Equal(t, "28", responseText) +} + +func TestMultipleServers(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Use CodeMode since we're testing CodeMode tools (listToolFiles) + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Verify we have at least one server + clients, err := b.GetMCPClients() + require.NoError(t, err) + require.NotEmpty(t, clients) + + // Test listToolFiles with multiple servers + listToolCall := createResponsesToolCall("listToolFiles", schemas.OrderedMap{}) + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, listToolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "BifrostClient.d.ts", "Should list BifrostClient server") +} + +// TestExternalMCPConnection tests connection to external MCP server +// This test requires external MCP credentials to be provided via environment variables +// or test configuration. For now, it's a placeholder that can be enabled when credentials are available. +func TestExternalMCPConnection(t *testing.T) { + t.Skip("Skipping external MCP connection test - requires credentials") + + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + _, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Example: Connect to external MCP server + // Uncomment and configure when credentials are available + /* + connectionString := os.Getenv("EXTERNAL_MCP_CONNECTION_STRING") + if connectionString == "" { + t.Skip("EXTERNAL_MCP_CONNECTION_STRING not set") + } + + err = connectExternalMCP(b, "external-server", "external-1", "http", connectionString) + require.NoError(t, err) + + // Verify connection + clients := b.GetMCPClients() + found := false + for _, client := range clients { + if client.Config.ID == "external-1" { + found = true + assert.Equal(t, schemas.MCPConnectionStateConnected, client.State) + break + } + } + assert.True(t, found, "External client should be connected") + */ +} + +func TestToolExecutionTimeout(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Test slow tool with short timeout + slowCall := createResponsesToolCall("slow_tool", schemas.OrderedMap{ + "delay_ms": schemas.Ptr(100), + }) + + start := time.Now() + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, slowCall) + duration := time.Since(start) + + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + assert.GreaterOrEqual(t, duration, 100*time.Millisecond, "Should take at least 100ms") +} + +func TestToolExecutionError(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Test error tool - tool execution succeeds but result contains error message + errorCall := createResponsesToolCall("error_tool", schemas.OrderedMap{}) + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, errorCall) + + // Tool execution should succeed (no bifrostErr), but result should contain error message + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "Error:", "Result should contain error message") + assert.Contains(t, responseText, "this tool always fails", "Result should contain the error text") +} + +func TestComplexArgsTool(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Register test tools + err = registerTestTools(b) + require.NoError(t, err) + + // Test complex args tool + complexCall := createResponsesToolCall("complex_args_tool", schemas.OrderedMap{ + "data": map[string]interface{}{ + "nested": map[string]interface{}{ + "value": float64(42), + "array": []interface{}{1, 2, 3}, + }, + }, + }) + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, complexCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "Received data", "Should process complex args") + assert.Contains(t, responseText, "42", "Should contain nested value") +} diff --git a/tests/core-mcp/responses_test.go b/tests/core-mcp/responses_test.go new file mode 100644 index 0000000000..8590b25ddd --- /dev/null +++ b/tests/core-mcp/responses_test.go @@ -0,0 +1,466 @@ +package mcp + +import ( + "context" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestResponsesNonCodeModeToolExecution tests direct tool execution via Responses API +func TestResponsesNonCodeModeToolExecution(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to non-code mode and ensure tools are available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, // Allow all tools + }) + require.NoError(t, err) + + // Execute tool directly to verify it works + echoCall := &schemas.ResponsesToolMessage{ + Name: schemas.Ptr("echo"), + Arguments: schemas.Ptr("{\"message\": \"test message\"}"), + } + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test message", responseText, "Echo tool should return the input message") +} + +// TestResponsesCodeModeToolExecution tests code mode tool execution via Responses API +func TestResponsesCodeModeToolExecution(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode directly to verify code mode works + toolCall := &schemas.ResponsesToolMessage{ + Name: schemas.Ptr("executeToolCode"), + Arguments: schemas.Ptr("{\"code\": \"console.log('test');\"}"), + } + + + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertResponsesExecutionResult(t, result, true, nil, "") + assertResponsesResultContains(t, result, "completed successfully") +} + +// TestResponsesAgentModeWithAutoExecuteTools tests agent mode configuration with auto-executable tools +func TestResponsesAgentModeWithAutoExecuteTools(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithMCPConfig(ctx, &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + MaxAgentDepth: 10, + ToolExecutionTimeout: 30 * time.Second, + }, + FetchNewRequestIDFunc: func(ctx *schemas.BifrostContext) string { + return "test-request-id" + }, + }) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure bifrostInternal with echo as auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // Only echo is auto-execute + }) + require.NoError(t, err) + + // Verify configuration + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") + assert.False(t, canAutoExecuteTool("multiply", bifrostClient.Config), "multiply should not be auto-executable") + + // Verify echo tool can be executed directly + echoCall := &schemas.ResponsesToolMessage{ + Name: schemas.Ptr("echo"), + Arguments: schemas.Ptr("{\"message\": \"test message\"}"), + } + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test message", responseText, "Echo tool should return the input message") +} + +// TestResponsesAgentModeWithNonAutoExecuteTools tests agent mode configuration with non-auto-executable tools +func TestResponsesAgentModeWithNonAutoExecuteTools(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithMCPConfig(ctx, &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + MaxAgentDepth: 10, + ToolExecutionTimeout: 30 * time.Second, + }, + FetchNewRequestIDFunc: func(ctx *schemas.BifrostContext) string { + return "test-request-id" + }, + }) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure bifrostInternal with multiply NOT in auto-execute + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"echo"}, // multiply is NOT auto-execute + }) + require.NoError(t, err) + + // Verify configuration + clients, err := b.GetMCPClients() + require.NoError(t, err) + + var bifrostClient *schemas.MCPClient + for i := range clients { + if clients[i].Config.ID == "bifrostInternal" { + bifrostClient = &clients[i] + break + } + } + + require.NotNil(t, bifrostClient) + assert.True(t, canAutoExecuteTool("echo", bifrostClient.Config), "echo should be auto-executable") + assert.False(t, canAutoExecuteTool("multiply", bifrostClient.Config), "multiply should not be auto-executable") + + // Verify multiply tool can still be executed directly (just not auto-executed) + multiplyCall := &schemas.ResponsesToolMessage{ + Name: schemas.Ptr("multiply"), + Arguments: schemas.Ptr("{\"a\": 2, \"b\": 3}"), + } + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, multiplyCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "6", responseText, "Multiply tool should return correct result") +} + +// TestResponsesAgentModeMaxDepth tests agent mode max depth configuration via Responses API +func TestResponsesAgentModeMaxDepth(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Create Bifrost with max depth of 2 + mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + MaxAgentDepth: 2, + ToolExecutionTimeout: 30 * time.Second, + }, + FetchNewRequestIDFunc: func(ctx *schemas.BifrostContext) string { + return "test-request-id" + }, + } + b, err := setupTestBifrostWithMCPConfig(ctx, mcpConfig) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure all tools as available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Verify tools still work with max depth configured + echoCall := &schemas.ResponsesToolMessage{ + Name: schemas.Ptr("echo"), + Arguments: schemas.Ptr("{\"message\": \"test\"}"), + } + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test", responseText, "Echo tool should work with max depth configured") +} + +// TestResponsesToolExecutionTimeout tests tool execution timeout via Responses API +func TestResponsesToolExecutionTimeout(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + // Create Bifrost with short timeout + mcpConfig := &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + MaxAgentDepth: 10, + ToolExecutionTimeout: 100 * time.Millisecond, // Very short timeout + }, + FetchNewRequestIDFunc: func(ctx *schemas.BifrostContext) string { + return "test-request-id" + }, + } + b, err := setupTestBifrostWithMCPConfig(ctx, mcpConfig) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure slow_tool + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + ToolsToAutoExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Create a Responses request that will trigger a slow tool + req := &schemas.BifrostResponsesRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + Input: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Call slow_tool with delay 500ms"), + }, + }, + }, + Params: &schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{ + { + Name: schemas.Ptr("slow_tool"), + Description: schemas.Ptr("A tool that takes time to execute"), + }, + }, + }, + } + + // Execute the request - should handle timeout gracefully + _, bifrostErr := b.ResponsesRequest(ctx, req) + // Timeout errors are acceptable in this test + if bifrostErr != nil { + assert.Contains(t, bifrost.GetErrorMessage(bifrostErr), "timeout", "Should contain timeout error") + } +} + +// TestResponsesMultipleToolCalls tests multiple tool calls via Responses API +func TestResponsesMultipleToolCalls(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure all tools as available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Test echo tool + echoCall := &schemas.ResponsesToolMessage{ + Name: schemas.Ptr("echo"), + Arguments: schemas.Ptr("{\"message\": \"test\"}"), + } + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test", responseText, "Echo tool should return correct result") + + // Test add tool + addCall := createResponsesToolCall("add", schemas.OrderedMap{ + "a": schemas.Ptr(5), + "b": schemas.Ptr(3), + }) + result, bifrostErr = b.ExecuteResponsesMCPTool(ctx, addCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText = *result.Content.ContentStr + assert.Equal(t, "8", responseText, "Add tool should return correct result") +} + +// TestResponsesCodeModeWithCodeExecution tests code mode with code execution via Responses API +func TestResponsesCodeModeWithCodeExecution(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code calling code mode client tools + toolCall := createResponsesToolCall("executeToolCode", schemas.OrderedMap{ + "code": CodeFixtures.CodeCallingCodeModeTool, + }) + + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertResponsesExecutionResult(t, result, true, nil, "") + assertResponsesResultContains(t, result, "test") +} + +// TestResponsesToolFiltering tests tool filtering via Responses API +func TestResponsesToolFiltering(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure specific tools only + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"echo", "add"}, // Only these tools available + ToolsToAutoExecute: []string{"echo"}, + }) + require.NoError(t, err) + + // Verify allowed tools work + echoCall := createResponsesToolCall("echo", schemas.OrderedMap{ + "message": "test", + }) + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "test", responseText, "Echo tool should work") + + addCall := createResponsesToolCall("add", schemas.OrderedMap{ + "a": schemas.Ptr(1), + "b": schemas.Ptr(2), + }) + result, bifrostErr = b.ExecuteResponsesMCPTool(ctx, addCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText = *result.Content.ContentStr + assert.Equal(t, "3", responseText, "Add tool should work") + + // Verify multiply tool is NOT available (should fail) + multiplyCall := createResponsesToolCall("multiply", schemas.OrderedMap{ + "a": float64(2), + "b": float64(3), + }) + result, bifrostErr = b.ExecuteResponsesMCPTool(ctx, multiplyCall) + // Should fail because multiply is not in ToolsToExecute + assert.NotNil(t, bifrostErr, "Multiply tool should fail when not in ToolsToExecute") +} + +// TestResponsesComplexWorkflow tests a complex workflow via Responses API +func TestResponsesComplexWorkflow(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Configure all tools as available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, + }) + require.NoError(t, err) + + // Test echo tool + echoCall := createResponsesToolCall("echo", schemas.OrderedMap{ + "message": "hello", + }) + result, bifrostErr := b.ExecuteResponsesMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText := *result.Content.ContentStr + assert.Equal(t, "hello", responseText, "Echo tool should return correct result") + + // Test add tool + addCall := createResponsesToolCall("add", schemas.OrderedMap{ + "a": schemas.Ptr(5), + "b": schemas.Ptr(3), + }) + result, bifrostErr = b.ExecuteResponsesMCPTool(ctx, addCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText = *result.Content.ContentStr + assert.Equal(t, "8", responseText, "Add tool should return correct result") + + // Test multiply tool with result from add + multiplyCall := createResponsesToolCall("multiply", schemas.OrderedMap{ + "a": schemas.Ptr(8), // Result from add + "b": schemas.Ptr(2), + }) + result, bifrostErr = b.ExecuteResponsesMCPTool(ctx, multiplyCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + responseText = *result.Content.ContentStr + assert.Equal(t, "16", responseText, "Multiply tool should return correct result") +} diff --git a/tests/core-mcp/setup.go b/tests/core-mcp/setup.go new file mode 100644 index 0000000000..5e957ed618 --- /dev/null +++ b/tests/core-mcp/setup.go @@ -0,0 +1,401 @@ +package mcp + +import ( + "fmt" + "os" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// TestTimeout defines the maximum duration for MCP tests +const TestTimeout = 10 * time.Minute + +// TestAccount is a minimal account implementation for testing +type TestAccount struct{} + +func (a *TestAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { + return []schemas.ModelProvider{schemas.OpenAI}, nil +} + +func (a *TestAccount) GetKeysForProvider(ctx *schemas.BifrostContext, providerKey schemas.ModelProvider) ([]schemas.Key, error) { + return []schemas.Key{ + { + Value: os.Getenv("OPENAI_API_KEY"), + Models: []string{}, + Weight: 1.0, + }, + }, nil +} + +func (a *TestAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) { + return &schemas.ProviderConfig{ + NetworkConfig: schemas.DefaultNetworkConfig, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil +} + +// setupTestBifrost initializes and returns a Bifrost instance for testing +// This creates a basic Bifrost instance without any MCP clients configured +func setupTestBifrost(ctx *schemas.BifrostContext) (*bifrost.Bifrost, error) { + return setupTestBifrostWithMCPConfig(ctx, &schemas.MCPConfig{ + ClientConfigs: []schemas.MCPClientConfig{}, + ToolManagerConfig: &schemas.MCPToolManagerConfig{ + MaxAgentDepth: 10, + ToolExecutionTimeout: 30 * time.Second, + }, + FetchNewRequestIDFunc: func(ctx *schemas.BifrostContext) string { + return "test-request-id" + }, + }) +} + +// setupTestBifrostWithCodeMode initializes and returns a Bifrost instance for testing with CodeMode +// This sets up bifrostInternal client as a code mode client +// Note: Tools must be registered first to create the bifrostInternal client +func setupTestBifrostWithCodeMode(ctx *schemas.BifrostContext) (*bifrost.Bifrost, error) { + b, err := setupTestBifrost(ctx) + if err != nil { + return nil, err + } + + // Register tools first to create the bifrostInternal client + err = registerTestTools(b) + if err != nil { + return nil, fmt.Errorf("failed to register test tools: %w", err) + } + + // Get current client config to preserve existing settings + clients, err := b.GetMCPClients() + if err != nil { + return nil, fmt.Errorf("failed to get MCP clients: %w", err) + } + + var currentConfig *schemas.MCPClientConfig + for _, client := range clients { + if client.Config.ID == "bifrostInternal" { + currentConfig = &client.Config + break + } + } + + if currentConfig == nil { + return nil, fmt.Errorf("bifrostInternal client not found") + } + + // Set bifrostInternal client to code mode and ensure tools are available + // Preserve existing ToolsToExecute if set, otherwise use wildcard + toolsToExecute := currentConfig.ToolsToExecute + if len(toolsToExecute) == 0 { + toolsToExecute = []string{"*"} + } + + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + ID: currentConfig.ID, + Name: currentConfig.Name, + ConnectionType: currentConfig.ConnectionType, + IsCodeModeClient: true, + ToolsToExecute: toolsToExecute, + ToolsToAutoExecute: currentConfig.ToolsToAutoExecute, + }) + if err != nil { + return nil, fmt.Errorf("failed to set bifrostInternal client to code mode: %w", err) + } + + return b, nil +} + +// setupTestBifrostWithMCPConfig initializes Bifrost with custom MCP config +func setupTestBifrostWithMCPConfig(ctx *schemas.BifrostContext, mcpConfig *schemas.MCPConfig) (*bifrost.Bifrost, error) { + account := &TestAccount{} + + // Ensure FetchNewRequestIDFunc is set if not provided + // This is required for the tools handler to be fully setup + if mcpConfig.FetchNewRequestIDFunc == nil { + mcpConfig.FetchNewRequestIDFunc = func(ctx *schemas.BifrostContext) string { + return "test-request-id" + } + } + + if mcpConfig.ToolManagerConfig == nil { + mcpConfig.ToolManagerConfig = &schemas.MCPToolManagerConfig{ + MaxAgentDepth: schemas.DefaultMaxAgentDepth, + ToolExecutionTimeout: schemas.DefaultToolExecutionTimeout, + } + } + + b, err := bifrost.Init(ctx, schemas.BifrostConfig{ + Account: account, + Plugins: nil, + Logger: bifrost.NewDefaultLogger(schemas.LogLevelDebug), + MCPConfig: mcpConfig, + }) + if err != nil { + return nil, fmt.Errorf("failed to initialize Bifrost: %w", err) + } + + return b, nil +} + +// registerTestTools registers simple test tools for testing +func registerTestTools(b *bifrost.Bifrost) error { + // Echo tool + echoSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "echo", + Description: schemas.Ptr("Echoes back the input message"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &schemas.OrderedMap{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to echo", + }, + }, + Required: []string{"message"}, + }, + }, + } + if err := b.RegisterMCPTool("echo", "Echoes back the input message", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + message, ok := argsMap["message"].(string) + if !ok { + return "", fmt.Errorf("message field is required") + } + return message, nil + }, echoSchema); err != nil { + return fmt.Errorf("failed to register echo tool: %w", err) + } + + // Add tool + addSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "add", + Description: schemas.Ptr("Adds two numbers"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &schemas.OrderedMap{ + "a": map[string]interface{}{ + "type": "number", + "description": "First number", + }, + "b": map[string]interface{}{ + "type": "number", + "description": "Second number", + }, + }, + Required: []string{"a", "b"}, + }, + }, + } + if err := b.RegisterMCPTool("add", "Adds two numbers", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + a, ok := argsMap["a"].(float64) + if !ok { + return "", fmt.Errorf("a field is required") + } + bVal, ok := argsMap["b"].(float64) + if !ok { + return "", fmt.Errorf("b field is required") + } + return fmt.Sprintf("%.0f", a+bVal), nil + }, addSchema); err != nil { + return fmt.Errorf("failed to register add tool: %w", err) + } + + // Multiply tool + multiplySchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "multiply", + Description: schemas.Ptr("Multiplies two numbers"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &schemas.OrderedMap{ + "a": map[string]interface{}{ + "type": "number", + "description": "First number", + }, + "b": map[string]interface{}{ + "type": "number", + "description": "Second number", + }, + }, + Required: []string{"a", "b"}, + }, + }, + } + if err := b.RegisterMCPTool("multiply", "Multiplies two numbers", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + a, ok := argsMap["a"].(float64) + if !ok { + return "", fmt.Errorf("a field is required") + } + bVal, ok := argsMap["b"].(float64) + if !ok { + return "", fmt.Errorf("b field is required") + } + return fmt.Sprintf("%.0f", a*bVal), nil + }, multiplySchema); err != nil { + return fmt.Errorf("failed to register multiply tool: %w", err) + } + + // GetData tool - returns structured data + getDataSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_data", + Description: schemas.Ptr("Returns structured data"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &schemas.OrderedMap{}, + Required: []string{}, + }, + }, + } + if err := b.RegisterMCPTool("get_data", "Returns structured data", func(args any) (string, error) { + return `{"items": [{"id": 1, "name": "test"}, {"id": 2, "name": "example"}]}`, nil + }, getDataSchema); err != nil { + return fmt.Errorf("failed to register get_data tool: %w", err) + } + + // ErrorTool - always returns an error + errorToolSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "error_tool", + Description: schemas.Ptr("A tool that always returns an error"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &schemas.OrderedMap{}, + Required: []string{}, + }, + }, + } + if err := b.RegisterMCPTool("error_tool", "A tool that always returns an error", func(args any) (string, error) { + return "", fmt.Errorf("this tool always fails") + }, errorToolSchema); err != nil { + return fmt.Errorf("failed to register error_tool: %w", err) + } + + // SlowTool - takes time to execute + slowToolSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "slow_tool", + Description: schemas.Ptr("A tool that takes time to execute"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &schemas.OrderedMap{ + "delay_ms": map[string]interface{}{ + "type": "number", + "description": "Delay in milliseconds", + }, + }, + Required: []string{"delay_ms"}, + }, + }, + } + if err := b.RegisterMCPTool("slow_tool", "A tool that takes time to execute", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + delayMs, ok := argsMap["delay_ms"].(float64) + if !ok { + return "", fmt.Errorf("delay_ms field is required") + } + time.Sleep(time.Duration(delayMs) * time.Millisecond) + return fmt.Sprintf("Completed after %v ms", delayMs), nil + }, slowToolSchema); err != nil { + return fmt.Errorf("failed to register slow_tool: %w", err) + } + + // ComplexArgsTool - accepts complex nested arguments + complexArgsSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "complex_args_tool", + Description: schemas.Ptr("A tool that accepts complex nested arguments"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &schemas.OrderedMap{ + "data": map[string]interface{}{ + "type": "object", + "description": "Complex nested data", + }, + }, + Required: []string{"data"}, + }, + }, + } + if err := b.RegisterMCPTool("complex_args_tool", "A tool that accepts complex nested arguments", func(args any) (string, error) { + argsMap, ok := args.(map[string]interface{}) + if !ok { + return "", fmt.Errorf("invalid args type") + } + data, ok := argsMap["data"] + if !ok { + return "", fmt.Errorf("data field is required") + } + return fmt.Sprintf("Received data: %v", data), nil + }, complexArgsSchema); err != nil { + return fmt.Errorf("failed to register complex_args_tool: %w", err) + } + + return nil +} + +// connectExternalMCP connects to an external MCP server +// This is a helper function that can be used when external MCP credentials are provided +func connectExternalMCP(b *bifrost.Bifrost, name, id, connectionType, connectionString string) error { + var clientConfig schemas.MCPClientConfig + + switch connectionType { + case "http": + clientConfig = schemas.MCPClientConfig{ + ID: id, + Name: name, + ConnectionType: schemas.MCPConnectionTypeHTTP, + ConnectionString: schemas.Ptr(connectionString), + } + case "sse": + clientConfig = schemas.MCPClientConfig{ + ID: id, + Name: name, + ConnectionType: schemas.MCPConnectionTypeSSE, + ConnectionString: schemas.Ptr(connectionString), + } + default: + return fmt.Errorf("unsupported connection type: %s", connectionType) + } + + clients, err := b.GetMCPClients() + if err != nil { + return fmt.Errorf("failed to get MCP clients: %w", err) + } + for _, client := range clients { + if client.Config.ID == id { + // Client already exists + return nil + } + } + + if err := b.AddMCPClient(clientConfig); err != nil { + return fmt.Errorf("failed to add external MCP client: %w", err) + } + + return nil +} diff --git a/tests/core-mcp/tool_execution_test.go b/tests/core-mcp/tool_execution_test.go new file mode 100644 index 0000000000..991d9fe464 --- /dev/null +++ b/tests/core-mcp/tool_execution_test.go @@ -0,0 +1,246 @@ +package mcp + +import ( + "context" + "strings" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNonCodeModeToolExecution(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to non-code mode and ensure tools are available + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + ToolsToExecute: []string{"*"}, // Allow all tools + }) + require.NoError(t, err) + + // Test direct tool execution + echoCall := createToolCall("echo", schemas.OrderedMap{ + "message": "test message", + }) + result, bifrostErr := b.ExecuteChatMCPTool(ctx, echoCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Equal(t, "test message", responseText) +} + +func TestCodeModeToolExecution(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test executeToolCode + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.SimpleExpression, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertExecutionResult(t, result, true, nil, "") + assertResultContains(t, result, "completed successfully") +} + +func TestCodeModeCallingCodeModeClientTools(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code calling code mode client tools + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.CodeCallingCodeModeTool, + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertExecutionResult(t, result, true, nil, "") + assertResultContains(t, result, "test") +} + +func TestCodeModeCallingMultipleCodeModeClients(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + // Test code calling tools from multiple code mode clients + // Since we only have bifrostInternal, we'll test calling multiple tools from the same client + toolCall := createToolCall("executeToolCode", map[string]interface{}{ + "code": CodeFixtures.MultipleServerToolCalls, // This calls echo and add from BifrostClient + }) + + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + assertExecutionResult(t, result, true, nil, "") +} + +func TestListToolFilesWithNoClients(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + // Don't register tools or set code mode - should have no code mode clients + toolCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + // listToolFiles should still work but return empty/no servers message + if bifrostErr == nil && result != nil { + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "No servers", "Should indicate no servers") + } +} + +func TestListToolFilesWithOnlyNonCodeModeClients(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrost(ctx) + require.NoError(t, err) + + err = registerTestTools(b) + require.NoError(t, err) + + // Set bifrostInternal to non-code mode + err = b.EditMCPClient("bifrostInternal", schemas.MCPClientConfig{ + IsCodeModeClient: false, + }) + require.NoError(t, err) + + // listToolFiles should not be available when no code mode clients exist + // But if it is called, it should return empty + toolCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + if bifrostErr == nil && result != nil { + responseText := *result.Content.ContentStr + // Should indicate no servers or empty list + assert.True(t, + len(responseText) == 0 || + strings.Contains(responseText, "No servers") || strings.Contains(responseText, "servers/"), + "Should return empty or no servers message") + } +} + +func TestListToolFilesWithCodeModeClients(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("listToolFiles", map[string]interface{}{}) + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "servers/", "Should list servers") + assert.Contains(t, responseText, "BifrostClient.d.ts", "Should list BifrostClient server") +} + +func TestReadToolFileForNonExistentClient(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("readToolFile", map[string]interface{}{ + "fileName": "NonExistentClient.d.ts", + }) + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "No server found", "Should indicate server not found") +} + +func TestReadToolFileForCodeModeClient(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("readToolFile", map[string]interface{}{ + "fileName": "BifrostClient.d.ts", + }) + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, "interface", "Should contain TypeScript interface declarations") + assert.Contains(t, responseText, "echo", "Should contain echo tool definition") +} + +func TestReadToolFileWithLineRange(t *testing.T) { + ctx, cancel := schemas.NewBifrostContextWithTimeout(context.Background(), TestTimeout) + defer cancel() + + b, err := setupTestBifrostWithCodeMode(ctx) + require.NoError(t, err) + // Tools are already registered in setupTestBifrostWithCodeMode + + toolCall := createToolCall("readToolFile", map[string]interface{}{ + "fileName": "BifrostClient.d.ts", + "startLine": float64(1), + "endLine": float64(10), + }) + result, bifrostErr := b.ExecuteChatMCPTool(ctx, toolCall) + requireNoBifrostError(t, bifrostErr) + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.NotEmpty(t, responseText, "Should return content") +} diff --git a/tests/core-mcp/utils.go b/tests/core-mcp/utils.go new file mode 100644 index 0000000000..dd6a0e1681 --- /dev/null +++ b/tests/core-mcp/utils.go @@ -0,0 +1,150 @@ +package mcp + +import ( + "encoding/json" + "fmt" + "slices" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// createToolCall creates a tool call message for testing +func createToolCall(toolName string, arguments schemas.OrderedMap) schemas.ChatAssistantMessageToolCall { + argsJSON, _ := json.Marshal(arguments) + argsStr := string(argsJSON) + id := fmt.Sprintf("test-tool-call-%d", len(argsStr)) + toolType := "function" + + return schemas.ChatAssistantMessageToolCall{ + ID: &id, + Type: &toolType, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &toolName, + Arguments: argsStr, + }, + } +} + +// createResponsesToolCall creates a tool call message for testing +func createResponsesToolCall(toolName string, arguments schemas.OrderedMap) *schemas.ResponsesToolMessage { + argsJSON, _ := json.Marshal(arguments) + argsStr := string(argsJSON) + id := fmt.Sprintf("test-tool-call-%d", len(argsStr)) + + return &schemas.ResponsesToolMessage{ + CallID: &id, + Name: &toolName, + Arguments: &argsStr, + } +} + +// assertResponsesExecutionResult validates execution results +func assertResponsesExecutionResult(t *testing.T, result *schemas.ResponsesMessage, expectedSuccess bool, expectedLogs []string, expectedErrorKind string) { + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + + if expectedSuccess { + // Success case - should not contain error indicators (but allow console.error output) + assert.NotContains(t, responseText, "Execution runtime error", "Response should not contain execution runtime error for successful execution") + assert.NotContains(t, responseText, "Execution typescript error", "Response should not contain execution typescript error for successful execution") + assert.NotContains(t, responseText, "Error:", "Response should not contain Error: prefix for successful execution") + } else { + // Error case - should contain error information + assert.Contains(t, responseText, "error", "Response should contain error for failed execution") + + if expectedErrorKind != "" { + assert.Contains(t, responseText, expectedErrorKind, "Response should contain expected error kind") + } + } +} + +// assertExecutionResult validates execution results +func assertExecutionResult(t *testing.T, result *schemas.ChatMessage, expectedSuccess bool, expectedLogs []string, expectedErrorKind string) { + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + + if expectedSuccess { + // Success case - should not contain error indicators (but allow console.error output) + assert.NotContains(t, responseText, "Execution runtime error", "Response should not contain execution runtime error for successful execution") + assert.NotContains(t, responseText, "Execution typescript error", "Response should not contain execution typescript error for successful execution") + assert.NotContains(t, responseText, "Error:", "Response should not contain Error: prefix for successful execution") + + // Check logs if expected + if len(expectedLogs) > 0 { + for _, expectedLog := range expectedLogs { + assert.Contains(t, responseText, expectedLog, "Response should contain expected log") + } + } + } else { + // Error case - should contain error information + assert.Contains(t, responseText, "error", "Response should contain error for failed execution") + + if expectedErrorKind != "" { + assert.Contains(t, responseText, expectedErrorKind, "Response should contain expected error kind") + } + } +} + +// assertResultContains validates that the result contains specific text +func assertResultContains(t *testing.T, result *schemas.ChatMessage, expectedText string) { + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, expectedText, "Response should contain expected text") +} + +// assertResponsesResultContains validates that the result contains specific text +func assertResponsesResultContains(t *testing.T, result *schemas.ResponsesMessage, expectedText string) { + require.NotNil(t, result) + require.NotNil(t, result.Content) + require.NotNil(t, result.Content.ContentStr) + + responseText := *result.Content.ContentStr + assert.Contains(t, responseText, expectedText, "Response should contain expected text") +} + +// requireNoBifrostError asserts that bifrostErr is nil, using GetErrorMessage for better error reporting +func requireNoBifrostError(t *testing.T, bifrostErr *schemas.BifrostError, msgAndArgs ...interface{}) { + if bifrostErr != nil { + errorMsg := bifrost.GetErrorMessage(bifrostErr) + if len(msgAndArgs) > 0 { + require.Fail(t, fmt.Sprintf("Expected no error but got: %s", errorMsg), msgAndArgs...) + } else { + require.Fail(t, fmt.Sprintf("Expected no error but got: %s", errorMsg)) + } + } +} + +// canAutoExecuteTool checks if a tool can be auto-executed based on client config +func canAutoExecuteTool(toolName string, config schemas.MCPClientConfig) bool { + // First check if tool is in ToolsToExecute + if config.ToolsToExecute != nil { + if len(config.ToolsToExecute) == 0 { + return false // Empty list means no tools allowed + } + if !slices.Contains(config.ToolsToExecute, "*") && !slices.Contains(config.ToolsToExecute, toolName) { + return false // Tool not in allowed list + } + } else { + return false // nil means no tools allowed + } + + // Then check if tool is in ToolsToAutoExecute + if len(config.ToolsToAutoExecute) == 0 { + return false // No auto-execute tools configured + } + + return slices.Contains(config.ToolsToAutoExecute, "*") || slices.Contains(config.ToolsToAutoExecute, toolName) +} diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index 0441537f0c..ac02a2cb68 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -11,7 +11,7 @@ services: - http environment: - CLUSTER_HOSTNAME=weaviate - - CLUSTER_ADVERTISE_ADDR=172.38.0.12 + - CLUSTER_ADVERTISE_ADDR=172.28.0.12 - CLUSTER_GOSSIP_BIND_PORT=7946 - CLUSTER_DATA_BIND_PORT=7947 - DISABLE_TELEMETRY=true @@ -26,7 +26,7 @@ services: - weaviate_data:/var/lib/weaviate networks: bifrost_network: - ipv4_address: 172.38.0.12 + ipv4_address: 172.28.0.12 # Redis Stack instance for vector store tests redis-stack: @@ -39,7 +39,7 @@ services: - redis_data:/data networks: bifrost_network: - ipv4_address: 172.38.0.13 + ipv4_address: 172.28.0.13 healthcheck: test: ["CMD", "redis-cli", "ping"] interval: 30s @@ -55,15 +55,15 @@ services: - qdrant_data:/qdrant/storage networks: bifrost_network: - ipv4_address: 172.38.0.14 + ipv4_address: 172.28.0.14 networks: bifrost_network: driver: bridge ipam: config: - - subnet: 172.38.0.0/16 - gateway: 172.38.0.1 + - subnet: 172.28.0.0/16 + gateway: 172.28.0.1 volumes: weaviate_data: diff --git a/tests/integrations/.python-version b/tests/integrations/python/.python-version similarity index 100% rename from tests/integrations/.python-version rename to tests/integrations/python/.python-version diff --git a/tests/integrations/README.md b/tests/integrations/python/README.md similarity index 100% rename from tests/integrations/README.md rename to tests/integrations/python/README.md diff --git a/tests/integrations/config.json b/tests/integrations/python/config.json similarity index 98% rename from tests/integrations/config.json rename to tests/integrations/python/config.json index 7a7e437805..2e2730b924 100644 --- a/tests/integrations/config.json +++ b/tests/integrations/python/config.json @@ -139,7 +139,7 @@ "access_key": "env.AWS_ACCESS_KEY_ID", "secret_key": "env.AWS_SECRET_ACCESS_KEY", "region": "env.AWS_REGION", - "arn": "env.AWS_BEDROCK_ROLE_ARN" + "arn": "env.AWS_ARN" }, "weight": 1, "use_for_batch_api": true @@ -180,7 +180,7 @@ "*" ], "enable_logging": true, - "enable_governance": true, + "enable_governance": false, "enforce_governance_header": false, "allow_direct_keys": false, "max_request_body_size_mb": 100, diff --git a/tests/integrations/config.yml b/tests/integrations/python/config.yml similarity index 96% rename from tests/integrations/config.yml rename to tests/integrations/python/config.yml index b813cac758..2d2658cd77 100644 --- a/tests/integrations/config.yml +++ b/tests/integrations/python/config.yml @@ -61,6 +61,13 @@ providers: - "gpt-4o" - "gpt-3.5-turbo" + xai: + chat: "grok-4-0709" + vision: "grok-2-vision-1212" + tools: "grok-4-0709" + streaming: "grok-4-0709" + thinking: "grok-3-mini" + anthropic: chat: "claude-sonnet-4-5-20250929" vision: "claude-3-7-sonnet-20250219" @@ -164,6 +171,7 @@ provider_api_keys: vertex: "VERTEX_API_KEY" bedrock: "AWS_ACCESS_KEY_ID" cohere: "COHERE_API_KEY" + xai: "XAI_API_KEY" # Provider test scenarios - which tests each provider supports provider_scenarios: @@ -206,6 +214,27 @@ provider_scenarios: file_delete: true file_content: true count_tokens: true + + xai: + simple_chat: true + multi_turn_conversation: true + streaming: true + tool_calls: true + multiple_tool_calls: true + end2end_tool_calling: true + automatic_function_calling: true + image_url: true + image_base64: false + file_input: false + multiple_images: false + thinking: true + list_models: true + responses: true + responses_image: true + text_completion: false + langchain_structured_output: true + pydantic_structured_output: true + pydanticai_streaming: true anthropic: simple_chat: true @@ -724,7 +753,7 @@ integration_settings: bedrock: region: "${AWS_REGION:-us-west-2}" s3_bucket: "${AWS_S3_BUCKET:-}" - batch_role_arn: "${AWS_BEDROCK_ROLE_ARN:-}" + batch_role_arn: "${AWS_ARN:-}" output_s3_prefix: "${AWS_OUTPUT_S3_PREFIX:-bifrost-batch-output/}" # Environment-specific overrides diff --git a/tests/integrations/dummy-gcp-credentials.json b/tests/integrations/python/dummy-gcp-credentials.json similarity index 100% rename from tests/integrations/dummy-gcp-credentials.json rename to tests/integrations/python/dummy-gcp-credentials.json diff --git a/tests/integrations/pyproject.toml b/tests/integrations/python/pyproject.toml similarity index 100% rename from tests/integrations/pyproject.toml rename to tests/integrations/python/pyproject.toml diff --git a/tests/integrations/run_all_tests.py b/tests/integrations/python/run_all_tests.py similarity index 100% rename from tests/integrations/run_all_tests.py rename to tests/integrations/python/run_all_tests.py diff --git a/tests/integrations/run_integration_tests.py b/tests/integrations/python/run_integration_tests.py similarity index 100% rename from tests/integrations/run_integration_tests.py rename to tests/integrations/python/run_integration_tests.py diff --git a/tests/integrations/tests/__init__.py b/tests/integrations/python/tests/__init__.py similarity index 100% rename from tests/integrations/tests/__init__.py rename to tests/integrations/python/tests/__init__.py diff --git a/tests/integrations/tests/conftest.py b/tests/integrations/python/tests/conftest.py similarity index 100% rename from tests/integrations/tests/conftest.py rename to tests/integrations/python/tests/conftest.py diff --git a/tests/integrations/tests/test_anthropic.py b/tests/integrations/python/tests/test_anthropic.py similarity index 99% rename from tests/integrations/tests/test_anthropic.py rename to tests/integrations/python/tests/test_anthropic.py index 9d4525fa96..3dc1a55671 100644 --- a/tests/integrations/tests/test_anthropic.py +++ b/tests/integrations/python/tests/test_anthropic.py @@ -789,7 +789,7 @@ def test_16_extended_thinking_streaming(self, anthropic_client, test_config, pro # Stream with thinking enabled - use thinking-capable model stream = anthropic_client.messages.create( model=format_provider_model(provider, model), - max_tokens=4000, # Reduced to prevent token limit errors for smaller context window models + max_tokens=3000, thinking={ "type": "enabled", "budget_tokens": 2000, # Reduced to prevent token limit errors @@ -836,7 +836,8 @@ def test_16_extended_thinking_streaming(self, anthropic_client, test_config, pro text_parts.append(str(event.delta.text)) # Safety check - if chunk_count > 1000: + print("chunk_count", chunk_count) + if chunk_count > 5000: break # Combine collected content diff --git a/tests/integrations/tests/test_bedrock.py b/tests/integrations/python/tests/test_bedrock.py similarity index 100% rename from tests/integrations/tests/test_bedrock.py rename to tests/integrations/python/tests/test_bedrock.py diff --git a/tests/integrations/tests/test_google.py b/tests/integrations/python/tests/test_google.py similarity index 96% rename from tests/integrations/tests/test_google.py rename to tests/integrations/python/tests/test_google.py index 9c53aad6e8..526a513bcd 100644 --- a/tests/integrations/tests/test_google.py +++ b/tests/integrations/python/tests/test_google.py @@ -751,6 +751,84 @@ def test_11_integration_specific_features(self, google_client, test_config): assert_valid_chat_response(response3) + + @pytest.mark.parametrize("provider,model", get_cross_provider_params_for_scenario("simple_chat")) + def test_11a_system_instruction(self, google_client, test_config, provider, model): + if provider == "_no_providers_" or model == "_no_model_": + pytest.skip("No providers configured for this scenario") + """Test Case 11a: System instruction (cross-provider)""" + from google.genai import types + + # Test 1: System instruction with word count constraint + response = google_client.models.generate_content( + model=format_provider_model(provider, model), + contents="What is 2 + 2?", + config=types.GenerateContentConfig( + system_instruction="You are a helpful assistant that always responds in exactly 5 words or fewer.", + max_output_tokens=300, + ), + ) + + assert_valid_chat_response(response) + assert response.text is not None + assert len(response.text) > 0 + + # Verify response respects the constraint AND contains correct answer + word_count = len(response.text.split()) + content_lower = response.text.lower() + + # Should be short (respecting the 5 word limit with small tolerance) + assert word_count <= 8, ( + f"Expected ≤8 words (system instruction: ≤5 words), got {word_count} words: {response.text}" + ) + + # Should contain the correct answer + has_answer = any(ans in content_lower for ans in ["4", "four", "quatre"]) + assert has_answer, ( + f"Response should contain the answer '4' or 'four'. Got: {response.text}" + ) + + print(f"āœ“ Word limit test passed: {response.text} ({word_count} words)") + + # Test 2: System instruction for translation (English to French) + response2 = google_client.models.generate_content( + model=format_provider_model(provider, model), + contents="Hello, how are you?", + config=types.GenerateContentConfig( + system_instruction=[ + "You are a language translator.", + "Your mission is to translate text from English to French.", + "Only output the French translation, nothing else.", + ], + max_output_tokens=300, + ), + ) + + assert_valid_chat_response(response2) + assert response2.text is not None + assert len(response2.text) > 0 + + content_lower = response2.text.lower() + + # Check for French translation keywords + french_keywords = ["bonjour", "salut", "comment", "allez", "vous", "Ƨa", "va"] + has_french = any(keyword in content_lower for keyword in french_keywords) + + # Check for common English words that shouldn't appear in pure French translation + english_words = ["hello", "how", "are", "you"] + has_english = any(word in content_lower for word in english_words) + + # Should have French keywords AND not have English words (pure translation) + assert has_french, ( + f"Response should contain French keywords. Got: {response2.text}" + ) + assert not has_english, ( + f"Response should not contain English words (should be pure French translation). Got: {response2.text}" + ) + + print(f"āœ“ Translation test passed: {response2.text}") + print(f"āœ“ System instruction test completed for provider {provider}") + @skip_if_no_api_key("google") def test_12_error_handling_invalid_roles(self, google_client, test_config): """Test Case 12: Error handling for invalid roles""" diff --git a/tests/integrations/tests/test_langchain.py b/tests/integrations/python/tests/test_langchain.py similarity index 100% rename from tests/integrations/tests/test_langchain.py rename to tests/integrations/python/tests/test_langchain.py diff --git a/tests/integrations/tests/test_litellm.py b/tests/integrations/python/tests/test_litellm.py similarity index 100% rename from tests/integrations/tests/test_litellm.py rename to tests/integrations/python/tests/test_litellm.py diff --git a/tests/integrations/tests/test_openai.py b/tests/integrations/python/tests/test_openai.py similarity index 99% rename from tests/integrations/tests/test_openai.py rename to tests/integrations/python/tests/test_openai.py index e262ebdc12..342f1e15a2 100644 --- a/tests/integrations/tests/test_openai.py +++ b/tests/integrations/python/tests/test_openai.py @@ -273,9 +273,8 @@ def test_01_simple_chat(self, test_config, provider, model, vk_enabled): response = client.chat.completions.create( model=format_provider_model(provider, model), messages=SIMPLE_CHAT_MESSAGES, - max_tokens=100, + max_tokens=100, ) - assert_valid_chat_response(response) assert response.choices[0].message.content is not None assert len(response.choices[0].message.content) > 0 @@ -582,6 +581,7 @@ def test_13_streaming(self, test_config, provider, model, vk_enabled): messages=STREAMING_CHAT_MESSAGES, max_tokens=200, stream=True, + extra_body={"reasoning": {"effort": "high"}} ) content, chunk_count, tool_calls_detected = collect_streaming_content( diff --git a/tests/integrations/tests/test_pydanticai.py b/tests/integrations/python/tests/test_pydanticai.py similarity index 100% rename from tests/integrations/tests/test_pydanticai.py rename to tests/integrations/python/tests/test_pydanticai.py diff --git a/tests/integrations/tests/utils/__init__.py b/tests/integrations/python/tests/utils/__init__.py similarity index 100% rename from tests/integrations/tests/utils/__init__.py rename to tests/integrations/python/tests/utils/__init__.py diff --git a/tests/integrations/tests/utils/common.py b/tests/integrations/python/tests/utils/common.py similarity index 99% rename from tests/integrations/tests/utils/common.py rename to tests/integrations/python/tests/utils/common.py index 6ff7ca1ced..e5a3f65f00 100644 --- a/tests/integrations/tests/utils/common.py +++ b/tests/integrations/python/tests/utils/common.py @@ -1813,6 +1813,7 @@ def get_api_key(integration: str) -> str: "bedrock": "AWS_ACCESS_KEY_ID", # Bedrock uses AWS credentials "cohere": "COHERE_API_KEY", "vertex": "VERTEX_API_KEY", + "xai": "XAI_API_KEY", } env_var = key_map.get(integration.lower()) @@ -2363,7 +2364,7 @@ def get_bedrock_s3_config() -> Dict[str, Optional[str]]: """ return { "s3_bucket": os.environ.get("AWS_S3_BUCKET"), - "role_arn": os.environ.get("AWS_BEDROCK_ROLE_ARN"), + "role_arn": os.environ.get("AWS_ARN"), "output_s3_prefix": os.environ.get("AWS_OUTPUT_S3_PREFIX", "bifrost-batch-output/"), "region": os.environ.get("AWS_REGION", "us-west-2"), } diff --git a/tests/integrations/tests/utils/config_loader.py b/tests/integrations/python/tests/utils/config_loader.py similarity index 100% rename from tests/integrations/tests/utils/config_loader.py rename to tests/integrations/python/tests/utils/config_loader.py diff --git a/tests/integrations/tests/utils/models.py b/tests/integrations/python/tests/utils/models.py similarity index 100% rename from tests/integrations/tests/utils/models.py rename to tests/integrations/python/tests/utils/models.py diff --git a/tests/integrations/tests/utils/parametrize.py b/tests/integrations/python/tests/utils/parametrize.py similarity index 100% rename from tests/integrations/tests/utils/parametrize.py rename to tests/integrations/python/tests/utils/parametrize.py diff --git a/tests/integrations/uv.lock b/tests/integrations/python/uv.lock similarity index 100% rename from tests/integrations/uv.lock rename to tests/integrations/python/uv.lock diff --git a/tests/integrations/typescript/README.md b/tests/integrations/typescript/README.md new file mode 100644 index 0000000000..b116abb4bc --- /dev/null +++ b/tests/integrations/typescript/README.md @@ -0,0 +1,340 @@ +# Bifrost TypeScript Integration Tests + +TypeScript/JavaScript integration test suite for testing AI providers through Bifrost proxy. This test suite uses Vitest and provides comprehensive coverage across multiple AI SDKs. + +## Quick Start + +```bash +# 1. Install dependencies +cd bifrost/tests/integrations/typescript +npm install + +# 2. Set environment variables +export BIFROST_BASE_URL="http://localhost:8080" +export OPENAI_API_KEY="your-key" +export ANTHROPIC_API_KEY="your-key" +export GEMINI_API_KEY="your-key" + +# 3. Run tests +npm test # All tests +npm test -- tests/test-openai.test.ts # Specific SDK +npm test -- -t "Simple Chat" # By pattern +``` + +## Architecture Overview + +The TypeScript integration tests use the same centralized configuration as the Python tests, routing all AI requests through Bifrost: + +```text +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ Test Client │───▶│ Bifrost Gateway │───▶│ AI Provider │ +│ (TypeScript) │ │ localhost:8080 │ │ (OpenAI, etc.) │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +## Supported SDKs + +| SDK | Package | Features | +|-----|---------|----------| +| **OpenAI** | `openai` | Chat, Streaming, Tools, Vision, Speech, Embeddings | +| **Anthropic** | `@anthropic-ai/sdk` | Chat, Streaming, Tools, Vision, Thinking | +| **Google GenAI** | `@google/generative-ai` | Chat, Streaming, Tools, Vision, Embeddings | +| **LangChain.js** | `@langchain/*` | Chat, Streaming, Tools, Structured Output | + +## Test Scenarios + +Each SDK test file covers these scenarios where supported: + +### Core Chat +1. **Simple Chat** - Basic single-message conversations +2. **Multi-turn Conversation** - Context retention across messages +3. **Streaming Chat** - Real-time streaming responses + +### Tool Calling +4. **Single Tool Call** - Basic function calling +5. **Multiple Tool Calls** - Multiple tools in single request +6. **End-to-End Tool Calling** - Complete workflow with results + +### Vision +7. **Image URL** - Image analysis from URLs +8. **Image Base64** - Image analysis from base64 data +9. **Multiple Images** - Multi-image comparison + +### Advanced Features +10. **Speech Synthesis** - Text-to-speech (OpenAI) +11. **Transcription** - Speech-to-text (OpenAI) +12. **Embeddings** - Text-to-vector conversion +13. **Structured Output** - Schema-based responses +14. **Thinking/Reasoning** - Extended reasoning modes + +## Directory Structure + +```text +typescript/ +ā”œā”€ā”€ package.json # Dependencies and scripts +ā”œā”€ā”€ tsconfig.json # TypeScript configuration +ā”œā”€ā”€ vitest.config.ts # Vitest test configuration +ā”œā”€ā”€ config.yml # Shared config (mirrors ../python/config.yml) +ā”œā”€ā”€ README.md # This file +ā”œā”€ā”€ src/ +│ └── utils/ +│ ā”œā”€ā”€ config-loader.ts # Configuration loading +│ ā”œā”€ā”€ common.ts # Test data and assertions +│ ā”œā”€ā”€ parametrize.ts # Cross-provider utilities +│ └── index.ts # Barrel export +└── tests/ + ā”œā”€ā”€ setup.ts # Global test setup + ā”œā”€ā”€ test-openai.test.ts # OpenAI SDK tests + ā”œā”€ā”€ test-anthropic.test.ts # Anthropic SDK tests + ā”œā”€ā”€ test-google.test.ts # Google GenAI tests + └── test-langchain.test.ts # LangChain.js tests +``` + +## Configuration + +### Shared Configuration + +The TypeScript tests share configuration with Python tests. The `config.yml` file mirrors the Python test configuration to ensure consistency: + +```bash +# Both test suites use the same configuration format +tests/integrations/typescript/config.yml # TypeScript tests +tests/integrations/python/config.yml # Python tests +``` + +This ensures consistent: +- Provider model configurations +- Scenario capability mappings +- API settings (timeouts, retries) +- Virtual key settings + +### Environment Variables + +**Required:** +```bash +export BIFROST_BASE_URL="http://localhost:8080" +``` + +**Provider API Keys (at least one required):** +```bash +export OPENAI_API_KEY="sk-..." +export ANTHROPIC_API_KEY="sk-ant-..." +export GEMINI_API_KEY="AIza..." +``` + +**Optional:** +```bash +export AWS_ACCESS_KEY_ID="..." # For Bedrock +export AWS_SECRET_ACCESS_KEY="..." +export COHERE_API_KEY="..." +``` + +## Running Tests + +### Using npm scripts + +```bash +# Run all tests +npm test + +# Run tests with verbose output +npm test -- --reporter=verbose + +# Run tests in watch mode +npm run test:watch + +# Run with coverage +npm run test:coverage + +# Run with UI +npm run test:ui +``` + +### Filtering tests + +```bash +# Run specific test file +npm test -- tests/test-openai.test.ts + +# Run tests matching pattern +npm test -- -t "Simple Chat" +npm test -- -t "Tool" +npm test -- -t "Streaming" + +# Run tests for specific provider +npm test -- tests/test-anthropic.test.ts -t "Streaming" +``` + +### Using Makefile + +From the repository root: + +```bash +# Run TypeScript integration tests +make test-integrations LANG=ts + +# Run specific SDK tests +make test-integrations LANG=ts INTEGRATION=openai + +# Run with pattern +make test-integrations LANG=ts PATTERN="tool" + +# Verbose output +make test-integrations LANG=ts VERBOSE=1 +``` + +## Cross-Provider Testing + +The OpenAI test file supports cross-provider testing through Bifrost's model name routing. By formatting the model name as `provider/model`, Bifrost routes the request to the appropriate provider: + +```typescript +import { formatProviderModel } from '../src/utils' + +const client = new OpenAI({ + baseURL: 'http://localhost:8080/openai', + apiKey: 'your-api-key', +}) + +// Route to Anthropic using the model name format +const response = await client.chat.completions.create({ + model: formatProviderModel('anthropic', 'claude-sonnet-4-20250514'), + // Results in: "anthropic/claude-sonnet-4-20250514" + messages: [{ role: 'user', content: 'Hello' }], +}) + +// Route to Bedrock +const bedrockResponse = await client.chat.completions.create({ + model: formatProviderModel('bedrock', 'global.anthropic.claude-sonnet-4-20250514-v1:0'), + // Results in: "bedrock/global.anthropic.claude-sonnet-4-20250514-v1:0" + messages: [{ role: 'user', content: 'Hello' }], +}) +``` + +This allows testing any provider using the OpenAI SDK format while Bifrost handles the routing based on the model name prefix. + +## Writing New Tests + +### Basic Test Structure + +```typescript +import { describe, it, expect } from 'vitest' +import OpenAI from 'openai' +import { getIntegrationUrl, getProviderModel } from '../src/utils' + +describe('My Feature Tests', () => { + it('should do something', async () => { + const client = new OpenAI({ + baseURL: getIntegrationUrl('openai'), + apiKey: process.env.OPENAI_API_KEY, + }) + + const response = await client.chat.completions.create({ + model: getProviderModel('openai', 'chat'), + messages: [{ role: 'user', content: 'Hello' }], + }) + + expect(response.choices[0].message.content).toBeDefined() + }) +}) +``` + +### Using Test Utilities + +```typescript +import { + SIMPLE_CHAT_MESSAGES, + WEATHER_TOOL, + assertValidChatResponse, + assertHasToolCalls, + convertToOpenAITools, +} from '../src/utils' + +// Use predefined test messages +const response = await client.chat.completions.create({ + model, + messages: SIMPLE_CHAT_MESSAGES, +}) + +// Use assertion helpers +assertValidChatResponse(response) +assertHasToolCalls(response, 1) + +// Use tool conversion utilities +const tools = convertToOpenAITools([WEATHER_TOOL]) +``` + +### Cross-Provider Parametrization + +```typescript +import { getCrossProviderParamsWithVkForScenario } from '../src/utils' + +describe('Cross-Provider Tests', () => { + const testCases = getCrossProviderParamsWithVkForScenario('simple_chat') + + it.each(testCases)( + 'should work - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }) => { + // Test implementation + } + ) +}) +``` + +## Troubleshooting + +### Common Issues + +**1. Connection Refused** +```text +Error: connect ECONNREFUSED 127.0.0.1:8080 +``` +Solution: Ensure Bifrost is running on the expected port. + +**2. API Key Not Set** +```text +Error: OPENAI_API_KEY environment variable not set +``` +Solution: Set the required environment variables. + +**3. Timeout Errors** +```text +Error: Timeout of 300000ms exceeded +``` +Solution: Check network connectivity and Bifrost logs. + +### Debug Mode + +```bash +# Run with debug output +DEBUG=* npm test -- tests/test-openai.test.ts + +# Check Bifrost logs +tail -f /tmp/bifrost-test.log +``` + +## Integration with Python Tests + +The TypeScript and Python test suites share: +- **Configuration** (`config.yml`) - Same provider/model settings +- **Test Scenarios** - Same test categories and assertions +- **Makefile Integration** - Unified `test-integrations` command + +To run both: +```bash +# Python tests +make test-integrations-py + +# TypeScript tests +make test-integrations-ts + +# Both +make test-integrations-py && make test-integrations-ts +``` + +## Contributing + +1. Follow the existing test structure +2. Use the shared utilities from `src/utils/` +3. Add tests for all applicable scenarios +4. Ensure tests pass locally before submitting +5. Update this README if adding new SDKs or features diff --git a/tests/integrations/typescript/config.json b/tests/integrations/typescript/config.json new file mode 100644 index 0000000000..cab65df51d --- /dev/null +++ b/tests/integrations/typescript/config.json @@ -0,0 +1,201 @@ +{ + "$schema": "https://www.getbifrost.ai/schema", + "providers": { + "openai": { + "keys": [ + { + "name": "OpenAI API Key", + "value": "env.OPENAI_API_KEY", + "weight": 1, + "use_for_batch_api": true + } + ], + "network_config": { + "default_request_timeout_in_seconds": 300 + } + }, + "anthropic": { + "keys": [ + { + "name": "Anthropic API Key", + "value": "env.ANTHROPIC_API_KEY", + "weight": 1, + "use_for_batch_api": true + } + ], + "network_config": { + "default_request_timeout_in_seconds": 300 + } + }, + "gemini": { + "keys": [ + { + "name": "Gemini API Key", + "value": "env.GEMINI_API_KEY", + "weight": 1, + "use_for_batch_api": true + } + ], + "network_config": { + "default_request_timeout_in_seconds": 300 + } + }, + "vertex": { + "keys": [ + { + "name": "Vertex API Key", + "vertex_key_config": { + "project_id": "env.GOOGLE_PROJECT_ID", + "region": "env.GOOGLE_LOCATION" + }, + "weight": 1 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 300 + } + }, + "mistral": { + "keys": [ + { + "name": "Mistral API Key", + "value": "env.MISTRAL_API_KEY", + "weight": 1 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 300 + } + }, + "cohere": { + "keys": [ + { + "name": "Cohere API Key", + "value": "env.COHERE_API_KEY", + "weight": 1 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 300 + } + }, + "groq": { + "keys": [ + { + "name": "Groq API Key", + "value": "env.GROQ_API_KEY", + "weight": 1 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 300 + } + }, + "perplexity": { + "keys": [ + { + "name": "Perplexity API Key", + "value": "env.PERPLEXITY_API_KEY", + "weight": 1 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 300 + } + }, + "cerebras": { + "keys": [ + { + "name": "Cerebras API Key", + "value": "env.CEREBRAS_API_KEY", + "weight": 1 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 300 + } + }, + "openrouter": { + "keys": [ + { + "name": "OpenRouter API Key", + "value": "env.OPENROUTER_API_KEY", + "weight": 1 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 300 + } + }, + "azure": { + "keys": [ + { + "name": "Azure OpenAI API Key", + "value": "env.AZURE_OPENAI_API_KEY", + "azure_key_config": { + "endpoint": "env.AZURE_OPENAI_ENDPOINT", + "api_version": "env.AZURE_OPENAI_API_VERSION" + }, + "weight": 1 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 300 + } + }, + "bedrock": { + "keys": [ + { + "name": "Bedrock API Key", + "bedrock_key_config": { + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY", + "region": "env.AWS_REGION", + "arn": "env.AWS_ARN" + }, + "weight": 1, + "use_for_batch_api": true + } + ], + "network_config": { + "default_request_timeout_in_seconds": 300 + } + } + }, + "config_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "../../tests/integrations/typescript/config.db" + } + }, + "logs_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "../../tests/integrations/typescript/logs.db" + } + }, + "governance": { + "virtual_keys": [ + { + "id": "vk-test", + "value": "sk-bf-test-key", + "is_active": true + } + ] + }, + "client": { + "drop_excess_requests": false, + "initial_pool_size": 300, + "allowed_origins": [ + "*" + ], + "enable_logging": true, + "enable_governance": false, + "enforce_governance_header": false, + "allow_direct_keys": false, + "max_request_body_size_mb": 100, + "enable_litellm_fallbacks": false + } +} \ No newline at end of file diff --git a/tests/integrations/typescript/config.yml b/tests/integrations/typescript/config.yml new file mode 120000 index 0000000000..4ae3243f3a --- /dev/null +++ b/tests/integrations/typescript/config.yml @@ -0,0 +1 @@ +../python/config.yml \ No newline at end of file diff --git a/tests/integrations/typescript/package-lock.json b/tests/integrations/typescript/package-lock.json new file mode 100644 index 0000000000..735782a242 --- /dev/null +++ b/tests/integrations/typescript/package-lock.json @@ -0,0 +1,6332 @@ +{ + "name": "bifrost-integration-tests-typescript", + "version": "0.1.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "bifrost-integration-tests-typescript", + "version": "0.1.0", + "dependencies": { + "@anthropic-ai/sdk": "^0.71.2", + "@aws-sdk/client-bedrock": "^3.966.0", + "@aws-sdk/client-bedrock-runtime": "^3.965.0", + "@google/generative-ai": "^0.24.1", + "@langchain/anthropic": "^0.3.0", + "@langchain/core": "^0.3.0", + "@langchain/google-genai": "^0.1.0", + "@langchain/openai": "^0.3.0", + "openai": "^6.15.0", + "yaml": "^2.6.0", + "zod": "^3.24.0" + }, + "devDependencies": { + "@types/node": "^22.10.0", + "@typescript-eslint/eslint-plugin": "^8.0.0", + "@typescript-eslint/parser": "^8.0.0", + "@vitest/coverage-v8": "^2.1.0", + "@vitest/ui": "^2.1.0", + "dotenv": "^16.4.0", + "eslint": "^9.0.0", + "typescript": "^5.7.0", + "vitest": "^2.1.0" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@ampproject/remapping": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.3.0.tgz", + "integrity": "sha512-30iZtAPgz+LTIYoeivqYo853f02jBYSd5uGnGpkFV0M3xOt9aN73erkgYAmZU43x4VfqcnLxW9Kpg3R5LC4YYw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@jridgewell/gen-mapping": "^0.3.5", + "@jridgewell/trace-mapping": "^0.3.24" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@anthropic-ai/sdk": { + "version": "0.71.2", + "resolved": "https://registry.npmjs.org/@anthropic-ai/sdk/-/sdk-0.71.2.tgz", + "integrity": "sha512-TGNDEUuEstk/DKu0/TflXAEt+p+p/WhTlFzEnoosvbaDU2LTjm42igSdlL0VijrKpWejtOKxX0b8A7uc+XiSAQ==", + "license": "MIT", + "dependencies": { + "json-schema-to-ts": "^3.1.1" + }, + "bin": { + "anthropic-ai-sdk": "bin/cli" + }, + "peerDependencies": { + "zod": "^3.25.0 || ^4.0.0" + }, + "peerDependenciesMeta": { + "zod": { + "optional": true + } + } + }, + "node_modules/@aws-crypto/crc32": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/@aws-crypto/crc32/-/crc32-5.2.0.tgz", + "integrity": "sha512-nLbCWqQNgUiwwtFsen1AdzAtvuLRsQS8rYgMuxCrdKf9kOssamGLuPwyTY9wyYblNr9+1XM8v6zoDTPPSIeANg==", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/util": "^5.2.0", + "@aws-sdk/types": "^3.222.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/@aws-crypto/sha256-browser": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/@aws-crypto/sha256-browser/-/sha256-browser-5.2.0.tgz", + "integrity": "sha512-AXfN/lGotSQwu6HNcEsIASo7kWXZ5HYWvfOmSNKDsEqC4OashTp8alTmaz+F7TC2L083SFv5RdB+qU3Vs1kZqw==", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-js": "^5.2.0", + "@aws-crypto/supports-web-crypto": "^5.2.0", + "@aws-crypto/util": "^5.2.0", + "@aws-sdk/types": "^3.222.0", + "@aws-sdk/util-locate-window": "^3.0.0", + "@smithy/util-utf8": "^2.0.0", + "tslib": "^2.6.2" + } + }, + "node_modules/@aws-crypto/sha256-browser/node_modules/@smithy/is-array-buffer": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/@smithy/is-array-buffer/-/is-array-buffer-2.2.0.tgz", + "integrity": "sha512-GGP3O9QFD24uGeAXYUjwSTXARoqpZykHadOmA8G5vfJPK0/DC67qa//0qvqrJzL1xc8WQWX7/yc7fwudjPHPhA==", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-crypto/sha256-browser/node_modules/@smithy/util-buffer-from": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/@smithy/util-buffer-from/-/util-buffer-from-2.2.0.tgz", + "integrity": "sha512-IJdWBbTcMQ6DA0gdNhh/BwrLkDR+ADW5Kr1aZmd4k3DIF6ezMV4R2NIAmT08wQJ3yUK82thHWmC/TnK/wpMMIA==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/is-array-buffer": "^2.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-crypto/sha256-browser/node_modules/@smithy/util-utf8": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/@smithy/util-utf8/-/util-utf8-2.3.0.tgz", + "integrity": "sha512-R8Rdn8Hy72KKcebgLiv8jQcQkXoLMOGGv5uI1/k0l+snqkOzQ1R0ChUBCxWMlBsFMekWjq0wRudIweFs7sKT5A==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/util-buffer-from": "^2.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-crypto/sha256-js": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/@aws-crypto/sha256-js/-/sha256-js-5.2.0.tgz", + "integrity": "sha512-FFQQyu7edu4ufvIZ+OadFpHHOt+eSTBaYaki44c+akjg7qZg9oOQeLlk77F6tSYqjDAFClrHJk9tMf0HdVyOvA==", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/util": "^5.2.0", + "@aws-sdk/types": "^3.222.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/@aws-crypto/supports-web-crypto": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/@aws-crypto/supports-web-crypto/-/supports-web-crypto-5.2.0.tgz", + "integrity": "sha512-iAvUotm021kM33eCdNfwIN//F77/IADDSs58i+MDaOqFrVjZo9bAal0NK7HurRuWLLpF1iLX7gbWrjHjeo+YFg==", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + } + }, + "node_modules/@aws-crypto/util": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/@aws-crypto/util/-/util-5.2.0.tgz", + "integrity": "sha512-4RkU9EsI6ZpBve5fseQlGNUWKMa1RLPQ1dnjnQoe07ldfIzcsGb5hC5W0Dm7u423KWzawlrpbjXBrXCEv9zazQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.222.0", + "@smithy/util-utf8": "^2.0.0", + "tslib": "^2.6.2" + } + }, + "node_modules/@aws-crypto/util/node_modules/@smithy/is-array-buffer": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/@smithy/is-array-buffer/-/is-array-buffer-2.2.0.tgz", + "integrity": "sha512-GGP3O9QFD24uGeAXYUjwSTXARoqpZykHadOmA8G5vfJPK0/DC67qa//0qvqrJzL1xc8WQWX7/yc7fwudjPHPhA==", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-crypto/util/node_modules/@smithy/util-buffer-from": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/@smithy/util-buffer-from/-/util-buffer-from-2.2.0.tgz", + "integrity": "sha512-IJdWBbTcMQ6DA0gdNhh/BwrLkDR+ADW5Kr1aZmd4k3DIF6ezMV4R2NIAmT08wQJ3yUK82thHWmC/TnK/wpMMIA==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/is-array-buffer": "^2.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-crypto/util/node_modules/@smithy/util-utf8": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/@smithy/util-utf8/-/util-utf8-2.3.0.tgz", + "integrity": "sha512-R8Rdn8Hy72KKcebgLiv8jQcQkXoLMOGGv5uI1/k0l+snqkOzQ1R0ChUBCxWMlBsFMekWjq0wRudIweFs7sKT5A==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/util-buffer-from": "^2.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/client-bedrock/-/client-bedrock-3.966.0.tgz", + "integrity": "sha512-fk3CL7v0JeHzIB3i7qzo8au6zfMIibNw8avxdJRXW14pINRO6nLd/l75xqs/IXKYmv0h7lnFpMHVdekfKa6nIQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-browser": "5.2.0", + "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "3.966.0", + "@aws-sdk/credential-provider-node": "3.966.0", + "@aws-sdk/middleware-host-header": "3.965.0", + "@aws-sdk/middleware-logger": "3.965.0", + "@aws-sdk/middleware-recursion-detection": "3.965.0", + "@aws-sdk/middleware-user-agent": "3.966.0", + "@aws-sdk/region-config-resolver": "3.965.0", + "@aws-sdk/token-providers": "3.966.0", + "@aws-sdk/types": "3.965.0", + "@aws-sdk/util-endpoints": "3.965.0", + "@aws-sdk/util-user-agent-browser": "3.965.0", + "@aws-sdk/util-user-agent-node": "3.966.0", + "@smithy/config-resolver": "^4.4.5", + "@smithy/core": "^3.20.1", + "@smithy/fetch-http-handler": "^5.3.8", + "@smithy/hash-node": "^4.2.7", + "@smithy/invalid-dependency": "^4.2.7", + "@smithy/middleware-content-length": "^4.2.7", + "@smithy/middleware-endpoint": "^4.4.2", + "@smithy/middleware-retry": "^4.4.18", + "@smithy/middleware-serde": "^4.2.8", + "@smithy/middleware-stack": "^4.2.7", + "@smithy/node-config-provider": "^4.3.7", + "@smithy/node-http-handler": "^4.4.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/smithy-client": "^4.10.3", + "@smithy/types": "^4.11.0", + "@smithy/url-parser": "^4.2.7", + "@smithy/util-base64": "^4.3.0", + "@smithy/util-body-length-browser": "^4.2.0", + "@smithy/util-body-length-node": "^4.2.1", + "@smithy/util-defaults-mode-browser": "^4.3.17", + "@smithy/util-defaults-mode-node": "^4.2.20", + "@smithy/util-endpoints": "^3.2.7", + "@smithy/util-middleware": "^4.2.7", + "@smithy/util-retry": "^4.2.7", + "@smithy/util-utf8": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock-runtime": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/client-bedrock-runtime/-/client-bedrock-runtime-3.965.0.tgz", + "integrity": "sha512-ccx3IJcSYNrkj3lAojip2Esjd6YSbrfEvJmvunNkcciexJsEaykDQExN+RSxIcaSvqVXkfqoSbxapI62fOUOfg==", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-browser": "5.2.0", + "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "3.965.0", + "@aws-sdk/credential-provider-node": "3.965.0", + "@aws-sdk/eventstream-handler-node": "3.965.0", + "@aws-sdk/middleware-eventstream": "3.965.0", + "@aws-sdk/middleware-host-header": "3.965.0", + "@aws-sdk/middleware-logger": "3.965.0", + "@aws-sdk/middleware-recursion-detection": "3.965.0", + "@aws-sdk/middleware-user-agent": "3.965.0", + "@aws-sdk/middleware-websocket": "3.965.0", + "@aws-sdk/region-config-resolver": "3.965.0", + "@aws-sdk/token-providers": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@aws-sdk/util-endpoints": "3.965.0", + "@aws-sdk/util-user-agent-browser": "3.965.0", + "@aws-sdk/util-user-agent-node": "3.965.0", + "@smithy/config-resolver": "^4.4.5", + "@smithy/core": "^3.20.0", + "@smithy/eventstream-serde-browser": "^4.2.7", + "@smithy/eventstream-serde-config-resolver": "^4.3.7", + "@smithy/eventstream-serde-node": "^4.2.7", + "@smithy/fetch-http-handler": "^5.3.8", + "@smithy/hash-node": "^4.2.7", + "@smithy/invalid-dependency": "^4.2.7", + "@smithy/middleware-content-length": "^4.2.7", + "@smithy/middleware-endpoint": "^4.4.1", + "@smithy/middleware-retry": "^4.4.17", + "@smithy/middleware-serde": "^4.2.8", + "@smithy/middleware-stack": "^4.2.7", + "@smithy/node-config-provider": "^4.3.7", + "@smithy/node-http-handler": "^4.4.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/smithy-client": "^4.10.2", + "@smithy/types": "^4.11.0", + "@smithy/url-parser": "^4.2.7", + "@smithy/util-base64": "^4.3.0", + "@smithy/util-body-length-browser": "^4.2.0", + "@smithy/util-body-length-node": "^4.2.1", + "@smithy/util-defaults-mode-browser": "^4.3.16", + "@smithy/util-defaults-mode-node": "^4.2.19", + "@smithy/util-endpoints": "^3.2.7", + "@smithy/util-middleware": "^4.2.7", + "@smithy/util-retry": "^4.2.7", + "@smithy/util-stream": "^4.5.8", + "@smithy/util-utf8": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock/node_modules/@aws-sdk/client-sso": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/client-sso/-/client-sso-3.966.0.tgz", + "integrity": "sha512-hQZDQgqRJclALDo9wK+bb5O+VpO8JcjImp52w9KPSz9XveNRgE9AYfklRJd8qT2Bwhxe6IbnqYEino2wqUMA1w==", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-browser": "5.2.0", + "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "3.966.0", + "@aws-sdk/middleware-host-header": "3.965.0", + "@aws-sdk/middleware-logger": "3.965.0", + "@aws-sdk/middleware-recursion-detection": "3.965.0", + "@aws-sdk/middleware-user-agent": "3.966.0", + "@aws-sdk/region-config-resolver": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@aws-sdk/util-endpoints": "3.965.0", + "@aws-sdk/util-user-agent-browser": "3.965.0", + "@aws-sdk/util-user-agent-node": "3.966.0", + "@smithy/config-resolver": "^4.4.5", + "@smithy/core": "^3.20.1", + "@smithy/fetch-http-handler": "^5.3.8", + "@smithy/hash-node": "^4.2.7", + "@smithy/invalid-dependency": "^4.2.7", + "@smithy/middleware-content-length": "^4.2.7", + "@smithy/middleware-endpoint": "^4.4.2", + "@smithy/middleware-retry": "^4.4.18", + "@smithy/middleware-serde": "^4.2.8", + "@smithy/middleware-stack": "^4.2.7", + "@smithy/node-config-provider": "^4.3.7", + "@smithy/node-http-handler": "^4.4.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/smithy-client": "^4.10.3", + "@smithy/types": "^4.11.0", + "@smithy/url-parser": "^4.2.7", + "@smithy/util-base64": "^4.3.0", + "@smithy/util-body-length-browser": "^4.2.0", + "@smithy/util-body-length-node": "^4.2.1", + "@smithy/util-defaults-mode-browser": "^4.3.17", + "@smithy/util-defaults-mode-node": "^4.2.20", + "@smithy/util-endpoints": "^3.2.7", + "@smithy/util-middleware": "^4.2.7", + "@smithy/util-retry": "^4.2.7", + "@smithy/util-utf8": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock/node_modules/@aws-sdk/core": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/core/-/core-3.966.0.tgz", + "integrity": "sha512-QaRVBHD1prdrFXIeFAY/1w4b4S0EFyo/ytzU+rCklEjMRT7DKGXGoHXTWLGz+HD7ovlS5u+9cf8a/LeSOEMzww==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "3.965.0", + "@aws-sdk/xml-builder": "3.965.0", + "@smithy/core": "^3.20.1", + "@smithy/node-config-provider": "^4.3.7", + "@smithy/property-provider": "^4.2.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/signature-v4": "^5.3.7", + "@smithy/smithy-client": "^4.10.3", + "@smithy/types": "^4.11.0", + "@smithy/util-base64": "^4.3.0", + "@smithy/util-middleware": "^4.2.7", + "@smithy/util-utf8": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock/node_modules/@aws-sdk/credential-provider-env": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-env/-/credential-provider-env-3.966.0.tgz", + "integrity": "sha512-sxVKc9PY0SH7jgN/8WxhbKQ7MWDIgaJv1AoAKJkhJ+GM5r09G5Vb2Vl8ALYpsy+r8b+iYpq5dGJj8k2VqxoQMg==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.966.0", + "@aws-sdk/types": "3.965.0", + "@smithy/property-provider": "^4.2.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock/node_modules/@aws-sdk/credential-provider-http": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-http/-/credential-provider-http-3.966.0.tgz", + "integrity": "sha512-VTJDP1jOibVtc5pn5TNE12rhqOO/n10IjkoJi8fFp9BMfmh3iqo70Ppvphz/Pe/R9LcK5Z3h0Z4EB9IXDR6kag==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.966.0", + "@aws-sdk/types": "3.965.0", + "@smithy/fetch-http-handler": "^5.3.8", + "@smithy/node-http-handler": "^4.4.7", + "@smithy/property-provider": "^4.2.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/smithy-client": "^4.10.3", + "@smithy/types": "^4.11.0", + "@smithy/util-stream": "^4.5.8", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock/node_modules/@aws-sdk/credential-provider-ini": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-ini/-/credential-provider-ini-3.966.0.tgz", + "integrity": "sha512-4oQKkYMCUx0mffKuH8LQag1M4Fo5daKVmsLAnjrIqKh91xmCrcWlAFNMgeEYvI1Yy125XeNSaFMfir6oNc2ODA==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.966.0", + "@aws-sdk/credential-provider-env": "3.966.0", + "@aws-sdk/credential-provider-http": "3.966.0", + "@aws-sdk/credential-provider-login": "3.966.0", + "@aws-sdk/credential-provider-process": "3.966.0", + "@aws-sdk/credential-provider-sso": "3.966.0", + "@aws-sdk/credential-provider-web-identity": "3.966.0", + "@aws-sdk/nested-clients": "3.966.0", + "@aws-sdk/types": "3.965.0", + "@smithy/credential-provider-imds": "^4.2.7", + "@smithy/property-provider": "^4.2.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock/node_modules/@aws-sdk/credential-provider-login": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-login/-/credential-provider-login-3.966.0.tgz", + "integrity": "sha512-wD1KlqLyh23Xfns/ZAPxebwXixoJJCuDbeJHFrLDpP4D4h3vA2S8nSFgBSFR15q9FhgRfHleClycf6g5K4Ww6w==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.966.0", + "@aws-sdk/nested-clients": "3.966.0", + "@aws-sdk/types": "3.965.0", + "@smithy/property-provider": "^4.2.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock/node_modules/@aws-sdk/credential-provider-node": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-node/-/credential-provider-node-3.966.0.tgz", + "integrity": "sha512-7QCOERGddMw7QbjE+LSAFgwOBpPv4px2ty0GCK7ZiPJGsni2EYmM4TtYnQb9u1WNHmHqIPWMbZR0pKDbyRyHlQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/credential-provider-env": "3.966.0", + "@aws-sdk/credential-provider-http": "3.966.0", + "@aws-sdk/credential-provider-ini": "3.966.0", + "@aws-sdk/credential-provider-process": "3.966.0", + "@aws-sdk/credential-provider-sso": "3.966.0", + "@aws-sdk/credential-provider-web-identity": "3.966.0", + "@aws-sdk/types": "3.965.0", + "@smithy/credential-provider-imds": "^4.2.7", + "@smithy/property-provider": "^4.2.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock/node_modules/@aws-sdk/credential-provider-process": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-process/-/credential-provider-process-3.966.0.tgz", + "integrity": "sha512-q5kCo+xHXisNbbPAh/DiCd+LZX4wdby77t7GLk0b2U0/mrel4lgy6o79CApe+0emakpOS1nPZS7voXA7vGPz4w==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.966.0", + "@aws-sdk/types": "3.965.0", + "@smithy/property-provider": "^4.2.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock/node_modules/@aws-sdk/credential-provider-sso": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-sso/-/credential-provider-sso-3.966.0.tgz", + "integrity": "sha512-Rv5aEfbpqsQZzxpX2x+FbSyVFOE3Dngome+exNA8jGzc00rrMZEUnm3J3yAsLp/I2l7wnTfI0r2zMe+T9/nZAQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/client-sso": "3.966.0", + "@aws-sdk/core": "3.966.0", + "@aws-sdk/token-providers": "3.966.0", + "@aws-sdk/types": "3.965.0", + "@smithy/property-provider": "^4.2.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock/node_modules/@aws-sdk/credential-provider-web-identity": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-web-identity/-/credential-provider-web-identity-3.966.0.tgz", + "integrity": "sha512-Yv1lc9iic9xg3ywMmIAeXN1YwuvfcClLVdiF2y71LqUgIOupW8B8my84XJr6pmOQuKzZa++c2znNhC9lGsbKyw==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.966.0", + "@aws-sdk/nested-clients": "3.966.0", + "@aws-sdk/types": "3.965.0", + "@smithy/property-provider": "^4.2.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock/node_modules/@aws-sdk/middleware-user-agent": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/middleware-user-agent/-/middleware-user-agent-3.966.0.tgz", + "integrity": "sha512-MvGoy0vhMluVpSB5GaGJbYLqwbZfZjwEZhneDHdPhgCgQqmCtugnYIIjpUw7kKqWGsmaMQmNEgSFf1zYYmwOyg==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.966.0", + "@aws-sdk/types": "3.965.0", + "@aws-sdk/util-endpoints": "3.965.0", + "@smithy/core": "^3.20.1", + "@smithy/protocol-http": "^5.3.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock/node_modules/@aws-sdk/nested-clients": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/nested-clients/-/nested-clients-3.966.0.tgz", + "integrity": "sha512-FRzAWwLNoKiaEWbYhnpnfartIdOgiaBLnPcd3uG1Io+vvxQUeRPhQIy4EfKnT3AuA+g7gzSCjMG2JKoJOplDtQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-browser": "5.2.0", + "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "3.966.0", + "@aws-sdk/middleware-host-header": "3.965.0", + "@aws-sdk/middleware-logger": "3.965.0", + "@aws-sdk/middleware-recursion-detection": "3.965.0", + "@aws-sdk/middleware-user-agent": "3.966.0", + "@aws-sdk/region-config-resolver": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@aws-sdk/util-endpoints": "3.965.0", + "@aws-sdk/util-user-agent-browser": "3.965.0", + "@aws-sdk/util-user-agent-node": "3.966.0", + "@smithy/config-resolver": "^4.4.5", + "@smithy/core": "^3.20.1", + "@smithy/fetch-http-handler": "^5.3.8", + "@smithy/hash-node": "^4.2.7", + "@smithy/invalid-dependency": "^4.2.7", + "@smithy/middleware-content-length": "^4.2.7", + "@smithy/middleware-endpoint": "^4.4.2", + "@smithy/middleware-retry": "^4.4.18", + "@smithy/middleware-serde": "^4.2.8", + "@smithy/middleware-stack": "^4.2.7", + "@smithy/node-config-provider": "^4.3.7", + "@smithy/node-http-handler": "^4.4.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/smithy-client": "^4.10.3", + "@smithy/types": "^4.11.0", + "@smithy/url-parser": "^4.2.7", + "@smithy/util-base64": "^4.3.0", + "@smithy/util-body-length-browser": "^4.2.0", + "@smithy/util-body-length-node": "^4.2.1", + "@smithy/util-defaults-mode-browser": "^4.3.17", + "@smithy/util-defaults-mode-node": "^4.2.20", + "@smithy/util-endpoints": "^3.2.7", + "@smithy/util-middleware": "^4.2.7", + "@smithy/util-retry": "^4.2.7", + "@smithy/util-utf8": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock/node_modules/@aws-sdk/token-providers": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/token-providers/-/token-providers-3.966.0.tgz", + "integrity": "sha512-8k5cBTicTGYJHhKaweO4gL4fud1KDnLS5fByT6/Xbiu59AxYM4E/h3ds+3jxDMnniCE3gIWpEnyfM9khtmw2lA==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.966.0", + "@aws-sdk/nested-clients": "3.966.0", + "@aws-sdk/types": "3.965.0", + "@smithy/property-provider": "^4.2.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock/node_modules/@aws-sdk/util-user-agent-node": { + "version": "3.966.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/util-user-agent-node/-/util-user-agent-node-3.966.0.tgz", + "integrity": "sha512-vPPe8V0GLj+jVS5EqFz2NUBgWH35favqxliUOvhp8xBdNRkEjiZm5TqitVtFlxS4RrLY3HOndrWbrP5ejbwl1Q==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/middleware-user-agent": "3.966.0", + "@aws-sdk/types": "3.965.0", + "@smithy/node-config-provider": "^4.3.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + }, + "peerDependencies": { + "aws-crt": ">=1.0.0" + }, + "peerDependenciesMeta": { + "aws-crt": { + "optional": true + } + } + }, + "node_modules/@aws-sdk/client-sso": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/client-sso/-/client-sso-3.965.0.tgz", + "integrity": "sha512-iv2tr+n4aZ+nPUFFvG00hISPuEd4DU+1/Q8rPAYKXsM+vEPJ2nAnP5duUOa2fbOLIUCRxX3dcQaQaghVHDHzQw==", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-browser": "5.2.0", + "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "3.965.0", + "@aws-sdk/middleware-host-header": "3.965.0", + "@aws-sdk/middleware-logger": "3.965.0", + "@aws-sdk/middleware-recursion-detection": "3.965.0", + "@aws-sdk/middleware-user-agent": "3.965.0", + "@aws-sdk/region-config-resolver": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@aws-sdk/util-endpoints": "3.965.0", + "@aws-sdk/util-user-agent-browser": "3.965.0", + "@aws-sdk/util-user-agent-node": "3.965.0", + "@smithy/config-resolver": "^4.4.5", + "@smithy/core": "^3.20.0", + "@smithy/fetch-http-handler": "^5.3.8", + "@smithy/hash-node": "^4.2.7", + "@smithy/invalid-dependency": "^4.2.7", + "@smithy/middleware-content-length": "^4.2.7", + "@smithy/middleware-endpoint": "^4.4.1", + "@smithy/middleware-retry": "^4.4.17", + "@smithy/middleware-serde": "^4.2.8", + "@smithy/middleware-stack": "^4.2.7", + "@smithy/node-config-provider": "^4.3.7", + "@smithy/node-http-handler": "^4.4.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/smithy-client": "^4.10.2", + "@smithy/types": "^4.11.0", + "@smithy/url-parser": "^4.2.7", + "@smithy/util-base64": "^4.3.0", + "@smithy/util-body-length-browser": "^4.2.0", + "@smithy/util-body-length-node": "^4.2.1", + "@smithy/util-defaults-mode-browser": "^4.3.16", + "@smithy/util-defaults-mode-node": "^4.2.19", + "@smithy/util-endpoints": "^3.2.7", + "@smithy/util-middleware": "^4.2.7", + "@smithy/util-retry": "^4.2.7", + "@smithy/util-utf8": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/core": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/core/-/core-3.965.0.tgz", + "integrity": "sha512-aq9BhQxdHit8UUJ9C0im9TtuKeK0pT6NXmNJxMTCFeStI7GG7ImIsSislg3BZTIifVg1P6VLdzMyz9de85iutQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "3.965.0", + "@aws-sdk/xml-builder": "3.965.0", + "@smithy/core": "^3.20.0", + "@smithy/node-config-provider": "^4.3.7", + "@smithy/property-provider": "^4.2.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/signature-v4": "^5.3.7", + "@smithy/smithy-client": "^4.10.2", + "@smithy/types": "^4.11.0", + "@smithy/util-base64": "^4.3.0", + "@smithy/util-middleware": "^4.2.7", + "@smithy/util-utf8": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-env": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-env/-/credential-provider-env-3.965.0.tgz", + "integrity": "sha512-mdGnaIjMxTIjsb70dEj3VsWPWpoq1V5MWzBSfJq2H8zgMBXjn6d5/qHP8HMf53l9PrsgqzMpXGv3Av549A2x1g==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@smithy/property-provider": "^4.2.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-http": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-http/-/credential-provider-http-3.965.0.tgz", + "integrity": "sha512-YuGQel9EgA/z25oeLM+GYYQS750+8AESvr7ZEmVnRPL0sg+K3DmGqdv+9gFjFd0UkLjTlC/jtbP2cuY6UcPiHQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@smithy/fetch-http-handler": "^5.3.8", + "@smithy/node-http-handler": "^4.4.7", + "@smithy/property-provider": "^4.2.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/smithy-client": "^4.10.2", + "@smithy/types": "^4.11.0", + "@smithy/util-stream": "^4.5.8", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-ini": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-ini/-/credential-provider-ini-3.965.0.tgz", + "integrity": "sha512-xRo72Prer5s0xYVSCxCymVIRSqrVlevK5cmU0GWq9yJtaBNpnx02jwdJg80t/Ni7pgbkQyFWRMcq38c1tc6M/w==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.965.0", + "@aws-sdk/credential-provider-env": "3.965.0", + "@aws-sdk/credential-provider-http": "3.965.0", + "@aws-sdk/credential-provider-login": "3.965.0", + "@aws-sdk/credential-provider-process": "3.965.0", + "@aws-sdk/credential-provider-sso": "3.965.0", + "@aws-sdk/credential-provider-web-identity": "3.965.0", + "@aws-sdk/nested-clients": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@smithy/credential-provider-imds": "^4.2.7", + "@smithy/property-provider": "^4.2.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-login": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-login/-/credential-provider-login-3.965.0.tgz", + "integrity": "sha512-43/H8Qku8LHyugbhLo8kjD+eauhybCeVkmrnvWl8bXNHJP7xi1jCdtBQJKKJqiIHZws4MOEwkji8kFdAVRCe6g==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.965.0", + "@aws-sdk/nested-clients": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@smithy/property-provider": "^4.2.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-node": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-node/-/credential-provider-node-3.965.0.tgz", + "integrity": "sha512-cRxmMHF+Zh2lkkkEVduKl+8OQdtg/DhYA69+/7SPSQURlgyjFQGlRQ58B7q8abuNlrGT3sV+UzeOylZpJbV61Q==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/credential-provider-env": "3.965.0", + "@aws-sdk/credential-provider-http": "3.965.0", + "@aws-sdk/credential-provider-ini": "3.965.0", + "@aws-sdk/credential-provider-process": "3.965.0", + "@aws-sdk/credential-provider-sso": "3.965.0", + "@aws-sdk/credential-provider-web-identity": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@smithy/credential-provider-imds": "^4.2.7", + "@smithy/property-provider": "^4.2.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-process": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-process/-/credential-provider-process-3.965.0.tgz", + "integrity": "sha512-gmkPmdiR0yxnTzLPDb7rwrDhGuCUjtgnj8qWP+m0gSz/W43rR4jRPVEf6DUX2iC+ImQhxo3NFhuB3V42Kzo3TQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@smithy/property-provider": "^4.2.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-sso": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-sso/-/credential-provider-sso-3.965.0.tgz", + "integrity": "sha512-N01AYvtCqG3Wo/s/LvYt19ity18/FqggiXT+elAs3X9Om/Wfx+hw9G+i7jaDmy+/xewmv8AdQ2SK5Q30dXw/Fw==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/client-sso": "3.965.0", + "@aws-sdk/core": "3.965.0", + "@aws-sdk/token-providers": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@smithy/property-provider": "^4.2.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-web-identity": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-web-identity/-/credential-provider-web-identity-3.965.0.tgz", + "integrity": "sha512-T4gMZ2JzXnfxe1oTD+EDGLSxFfk1+WkLZdiHXEMZp8bFI1swP/3YyDFXI+Ib9Uq1JhnAmrCXtOnkicKEhDkdhQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.965.0", + "@aws-sdk/nested-clients": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@smithy/property-provider": "^4.2.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/eventstream-handler-node": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/eventstream-handler-node/-/eventstream-handler-node-3.965.0.tgz", + "integrity": "sha512-QriACiXP+/x2xXw8u849BxID+zSUbh/7Gt0Zfaxeye0mIKVeSTid5776rXfrM8wcYhbVXWWZhKd1Du7oPuFwsg==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "3.965.0", + "@smithy/eventstream-codec": "^4.2.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/middleware-eventstream": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/middleware-eventstream/-/middleware-eventstream-3.965.0.tgz", + "integrity": "sha512-YVNOPbc3r+gETUY6ufnJYsgIRMaBfoGRM9GzPb+gwtidCPd0BEpLjmZNIVGYawMrGc2kAdlV1kjBzAvmYaMINw==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "3.965.0", + "@smithy/protocol-http": "^5.3.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/middleware-host-header": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/middleware-host-header/-/middleware-host-header-3.965.0.tgz", + "integrity": "sha512-SfpSYqoPOAmdb3DBsnNsZ0vix+1VAtkUkzXM79JL3R5IfacpyKE2zytOgVAQx/FjhhlpSTwuXd+LRhUEVb3MaA==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "3.965.0", + "@smithy/protocol-http": "^5.3.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/middleware-logger": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/middleware-logger/-/middleware-logger-3.965.0.tgz", + "integrity": "sha512-gjUvJRZT1bUABKewnvkj51LAynFrfz2h5DYAg5/2F4Utx6UOGByTSr9Rq8JCLbURvvzAbCtcMkkIJRxw+8Zuzw==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "3.965.0", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/middleware-recursion-detection": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/middleware-recursion-detection/-/middleware-recursion-detection-3.965.0.tgz", + "integrity": "sha512-6dvD+18Ni14KCRu+tfEoNxq1sIGVp9tvoZDZ7aMvpnA7mDXuRLrOjRQ/TAZqXwr9ENKVGyxcPl0cRK8jk1YWjA==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "3.965.0", + "@aws/lambda-invoke-store": "^0.2.2", + "@smithy/protocol-http": "^5.3.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/middleware-user-agent": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/middleware-user-agent/-/middleware-user-agent-3.965.0.tgz", + "integrity": "sha512-RBEYVGgu/WeAt+H/qLrGc+t8LqAUkbyvh3wBfTiuAD+uBcWsKnvnB1iSBX75FearC0fmoxzXRUc0PMxMdqpjJQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@aws-sdk/util-endpoints": "3.965.0", + "@smithy/core": "^3.20.0", + "@smithy/protocol-http": "^5.3.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/middleware-websocket": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/middleware-websocket/-/middleware-websocket-3.965.0.tgz", + "integrity": "sha512-BGU92StrWF0EJj8jX5EFvRkX9z4/CVIZfON0nWow8gb5ouKwz47o1rO9CP/k2b3F6g134/0XqwXvrUgIWfjJeA==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "3.965.0", + "@aws-sdk/util-format-url": "3.965.0", + "@smithy/eventstream-codec": "^4.2.7", + "@smithy/eventstream-serde-browser": "^4.2.7", + "@smithy/fetch-http-handler": "^5.3.8", + "@smithy/protocol-http": "^5.3.7", + "@smithy/signature-v4": "^5.3.7", + "@smithy/types": "^4.11.0", + "@smithy/util-hex-encoding": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">= 14.0.0" + } + }, + "node_modules/@aws-sdk/nested-clients": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/nested-clients/-/nested-clients-3.965.0.tgz", + "integrity": "sha512-muNVUjUEU+/KLFrLzQ8PMXyw4+a/MP6t4GIvwLtyx/kH0rpSy5s0YmqacMXheuIe6F/5QT8uksXGNAQenitkGQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-browser": "5.2.0", + "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "3.965.0", + "@aws-sdk/middleware-host-header": "3.965.0", + "@aws-sdk/middleware-logger": "3.965.0", + "@aws-sdk/middleware-recursion-detection": "3.965.0", + "@aws-sdk/middleware-user-agent": "3.965.0", + "@aws-sdk/region-config-resolver": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@aws-sdk/util-endpoints": "3.965.0", + "@aws-sdk/util-user-agent-browser": "3.965.0", + "@aws-sdk/util-user-agent-node": "3.965.0", + "@smithy/config-resolver": "^4.4.5", + "@smithy/core": "^3.20.0", + "@smithy/fetch-http-handler": "^5.3.8", + "@smithy/hash-node": "^4.2.7", + "@smithy/invalid-dependency": "^4.2.7", + "@smithy/middleware-content-length": "^4.2.7", + "@smithy/middleware-endpoint": "^4.4.1", + "@smithy/middleware-retry": "^4.4.17", + "@smithy/middleware-serde": "^4.2.8", + "@smithy/middleware-stack": "^4.2.7", + "@smithy/node-config-provider": "^4.3.7", + "@smithy/node-http-handler": "^4.4.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/smithy-client": "^4.10.2", + "@smithy/types": "^4.11.0", + "@smithy/url-parser": "^4.2.7", + "@smithy/util-base64": "^4.3.0", + "@smithy/util-body-length-browser": "^4.2.0", + "@smithy/util-body-length-node": "^4.2.1", + "@smithy/util-defaults-mode-browser": "^4.3.16", + "@smithy/util-defaults-mode-node": "^4.2.19", + "@smithy/util-endpoints": "^3.2.7", + "@smithy/util-middleware": "^4.2.7", + "@smithy/util-retry": "^4.2.7", + "@smithy/util-utf8": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/region-config-resolver": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/region-config-resolver/-/region-config-resolver-3.965.0.tgz", + "integrity": "sha512-RoMhu9ly2B0coxn8ctXosPP2WmDD0MkQlZGLjoYHQUOCBmty5qmCxOqBmBDa6wbWbB8xKtMQ/4VXloQOgzjHXg==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "3.965.0", + "@smithy/config-resolver": "^4.4.5", + "@smithy/node-config-provider": "^4.3.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/token-providers": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/token-providers/-/token-providers-3.965.0.tgz", + "integrity": "sha512-aR0qxg0b8flkXJVE+CM1gzo7uJ57md50z2eyCwofC0QIz5Y0P7/7vvb9/dmUQt6eT9XRN5iRcUqq2IVxVDvJOw==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "3.965.0", + "@aws-sdk/nested-clients": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@smithy/property-provider": "^4.2.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/types": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/types/-/types-3.965.0.tgz", + "integrity": "sha512-jvodoJdMavvg8faN7co58vVJRO5MVep4JFPRzUNCzpJ98BDqWDk/ad045aMJcmxkLzYLS2UAnUmqjJ/tUPNlzQ==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/util-endpoints": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/util-endpoints/-/util-endpoints-3.965.0.tgz", + "integrity": "sha512-WqSCB0XIsGUwZWvrYkuoofi2vzoVHqyeJ2kN+WyoOsxPLTiQSBIoqm/01R/qJvoxwK/gOOF7su9i84Vw2NQQpQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "3.965.0", + "@smithy/types": "^4.11.0", + "@smithy/url-parser": "^4.2.7", + "@smithy/util-endpoints": "^3.2.7", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/util-format-url": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/util-format-url/-/util-format-url-3.965.0.tgz", + "integrity": "sha512-KiplV4xYGXdNCcz5eRP8WfAejT5EkE2gQxC4IY6WsuxYprzQKsnGaAzEQ+giR5GgQLIRBkPaWT0xHEYkMiCQ1Q==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "3.965.0", + "@smithy/querystring-builder": "^4.2.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/util-locate-window": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/util-locate-window/-/util-locate-window-3.965.0.tgz", + "integrity": "sha512-9LJFand4bIoOjOF4x3wx0UZYiFZRo4oUauxQSiEX2dVg+5qeBOJSjp2SeWykIE6+6frCZ5wvWm2fGLK8D32aJw==", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/util-user-agent-browser": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/util-user-agent-browser/-/util-user-agent-browser-3.965.0.tgz", + "integrity": "sha512-Xiza/zMntQGpkd2dETQeAK8So1pg5+STTzpcdGWxj5q0jGO5ayjqT/q1Q7BrsX5KIr6PvRkl9/V7lLCv04wGjQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "3.965.0", + "@smithy/types": "^4.11.0", + "bowser": "^2.11.0", + "tslib": "^2.6.2" + } + }, + "node_modules/@aws-sdk/util-user-agent-node": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/util-user-agent-node/-/util-user-agent-node-3.965.0.tgz", + "integrity": "sha512-kokIHUfNT3/P55E4fUJJrFHuuA9BbjFKUIxoLrd3UaRfdafT0ScRfg2eaZie6arf60EuhlUIZH0yALxttMEjxQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/middleware-user-agent": "3.965.0", + "@aws-sdk/types": "3.965.0", + "@smithy/node-config-provider": "^4.3.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + }, + "peerDependencies": { + "aws-crt": ">=1.0.0" + }, + "peerDependenciesMeta": { + "aws-crt": { + "optional": true + } + } + }, + "node_modules/@aws-sdk/xml-builder": { + "version": "3.965.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/xml-builder/-/xml-builder-3.965.0.tgz", + "integrity": "sha512-Tcod25/BTupraQwtb+Q+GX8bmEZfxIFjjJ/AvkhUZsZlkPeVluzq1uu3Oeqf145DCdMjzLIN6vab5MrykbDP+g==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.11.0", + "fast-xml-parser": "5.2.5", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@aws-sdk/xml-builder/node_modules/fast-xml-parser": { + "version": "5.2.5", + "resolved": "https://registry.npmjs.org/fast-xml-parser/-/fast-xml-parser-5.2.5.tgz", + "integrity": "sha512-pfX9uG9Ki0yekDHx2SiuRIyFdyAr1kMIMitPvb0YBo8SUfKvia7w7FIyd/l6av85pFYRhZscS75MwMnbvY+hcQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/NaturalIntelligence" + } + ], + "license": "MIT", + "dependencies": { + "strnum": "^2.1.0" + }, + "bin": { + "fxparser": "src/cli/cli.js" + } + }, + "node_modules/@aws-sdk/xml-builder/node_modules/strnum": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/strnum/-/strnum-2.1.2.tgz", + "integrity": "sha512-l63NF9y/cLROq/yqKXSLtcMeeyOfnSQlfMSlzFt/K73oIaD8DGaQWd7Z34X9GPiKqP5rbSh84Hl4bOlLcjiSrQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/NaturalIntelligence" + } + ], + "license": "MIT" + }, + "node_modules/@aws/lambda-invoke-store": { + "version": "0.2.3", + "resolved": "https://registry.npmjs.org/@aws/lambda-invoke-store/-/lambda-invoke-store-0.2.3.tgz", + "integrity": "sha512-oLvsaPMTBejkkmHhjf09xTgk71mOqyr/409NKhRIL08If7AhVfUsJhVsx386uJaqNd42v9kWamQ9lFbkoC2dYw==", + "license": "Apache-2.0", + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@babel/helper-string-parser": { + "version": "7.27.1", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.27.1.tgz", + "integrity": "sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-validator-identifier": { + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.28.5.tgz", + "integrity": "sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/parser": { + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.28.5.tgz", + "integrity": "sha512-KKBU1VGYR7ORr3At5HAtUQ+TV3SzRCXmA/8OdDZiLDBIZxVyzXuztPjfLd3BV1PRAQGCMWWSHYhL0F8d5uHBDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/types": "^7.28.5" + }, + "bin": { + "parser": "bin/babel-parser.js" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@babel/runtime": { + "version": "7.28.4", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.28.4.tgz", + "integrity": "sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ==", + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/types": { + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.28.5.tgz", + "integrity": "sha512-qQ5m48eI/MFLQ5PxQj4PFaprjyCTLI37ElWMmNs0K8Lk3dVeOdNpB3ks8jc7yM5CDmVC73eMVk/trk3fgmrUpA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/helper-string-parser": "^7.27.1", + "@babel/helper-validator-identifier": "^7.28.5" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@bcoe/v8-coverage": { + "version": "0.2.3", + "resolved": "https://registry.npmjs.org/@bcoe/v8-coverage/-/v8-coverage-0.2.3.tgz", + "integrity": "sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==", + "dev": true, + "license": "MIT" + }, + "node_modules/@cfworker/json-schema": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/@cfworker/json-schema/-/json-schema-4.1.1.tgz", + "integrity": "sha512-gAmrUZSGtKc3AiBL71iNWxDsyUC5uMaKKGdvzYsBoTW/xi42JQHl7eKV2OYzCUqvc+D2RCcf7EXY2iCyFIk6og==", + "license": "MIT" + }, + "node_modules/@esbuild/aix-ppc64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.21.5.tgz", + "integrity": "sha512-1SDgH6ZSPTlggy1yI6+Dbkiz8xzpHJEVAlF/AM1tHPLsf5STom9rwtjE4hKAF20FfXXNTFqEYXyJNWh1GiZedQ==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "aix" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/android-arm": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.21.5.tgz", + "integrity": "sha512-vCPvzSjpPHEi1siZdlvAlsPxXl7WbOVUBBAowWug4rJHb68Ox8KualB+1ocNvT5fjv6wpkX6o/iEpbDrf68zcg==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/android-arm64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.21.5.tgz", + "integrity": "sha512-c0uX9VAUBQ7dTDCjq+wdyGLowMdtR/GoC2U5IYk/7D1H1JYC0qseD7+11iMP2mRLN9RcCMRcjC4YMclCzGwS/A==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/android-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.21.5.tgz", + "integrity": "sha512-D7aPRUUNHRBwHxzxRvp856rjUHRFW1SdQATKXH2hqA0kAZb1hKmi02OpYRacl0TxIGz/ZmXWlbZgjwWYaCakTA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/darwin-arm64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.21.5.tgz", + "integrity": "sha512-DwqXqZyuk5AiWWf3UfLiRDJ5EDd49zg6O9wclZ7kUMv2WRFr4HKjXp/5t8JZ11QbQfUS6/cRCKGwYhtNAY88kQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/darwin-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.21.5.tgz", + "integrity": "sha512-se/JjF8NlmKVG4kNIuyWMV/22ZaerB+qaSi5MdrXtd6R08kvs2qCN4C09miupktDitvh8jRFflwGFBQcxZRjbw==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/freebsd-arm64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.21.5.tgz", + "integrity": "sha512-5JcRxxRDUJLX8JXp/wcBCy3pENnCgBR9bN6JsY4OmhfUtIHe3ZW0mawA7+RDAcMLrMIZaf03NlQiX9DGyB8h4g==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/freebsd-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.21.5.tgz", + "integrity": "sha512-J95kNBj1zkbMXtHVH29bBriQygMXqoVQOQYA+ISs0/2l3T9/kj42ow2mpqerRBxDJnmkUDCaQT/dfNXWX/ZZCQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-arm": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.21.5.tgz", + "integrity": "sha512-bPb5AHZtbeNGjCKVZ9UGqGwo8EUu4cLq68E95A53KlxAPRmUyYv2D6F0uUI65XisGOL1hBP5mTronbgo+0bFcA==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-arm64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.21.5.tgz", + "integrity": "sha512-ibKvmyYzKsBeX8d8I7MH/TMfWDXBF3db4qM6sy+7re0YXya+K1cem3on9XgdT2EQGMu4hQyZhan7TeQ8XkGp4Q==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-ia32": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.21.5.tgz", + "integrity": "sha512-YvjXDqLRqPDl2dvRODYmmhz4rPeVKYvppfGYKSNGdyZkA01046pLWyRKKI3ax8fbJoK5QbxblURkwK/MWY18Tg==", + "cpu": [ + "ia32" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-loong64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.21.5.tgz", + "integrity": "sha512-uHf1BmMG8qEvzdrzAqg2SIG/02+4/DHB6a9Kbya0XDvwDEKCoC8ZRWI5JJvNdUjtciBGFQ5PuBlpEOXQj+JQSg==", + "cpu": [ + "loong64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-mips64el": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.21.5.tgz", + "integrity": "sha512-IajOmO+KJK23bj52dFSNCMsz1QP1DqM6cwLUv3W1QwyxkyIWecfafnI555fvSGqEKwjMXVLokcV5ygHW5b3Jbg==", + "cpu": [ + "mips64el" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-ppc64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.21.5.tgz", + "integrity": "sha512-1hHV/Z4OEfMwpLO8rp7CvlhBDnjsC3CttJXIhBi+5Aj5r+MBvy4egg7wCbe//hSsT+RvDAG7s81tAvpL2XAE4w==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-riscv64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.21.5.tgz", + "integrity": "sha512-2HdXDMd9GMgTGrPWnJzP2ALSokE/0O5HhTUvWIbD3YdjME8JwvSCnNGBnTThKGEB91OZhzrJ4qIIxk/SBmyDDA==", + "cpu": [ + "riscv64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-s390x": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.21.5.tgz", + "integrity": "sha512-zus5sxzqBJD3eXxwvjN1yQkRepANgxE9lgOW2qLnmr8ikMTphkjgXu1HR01K4FJg8h1kEEDAqDcZQtbrRnB41A==", + "cpu": [ + "s390x" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.21.5.tgz", + "integrity": "sha512-1rYdTpyv03iycF1+BhzrzQJCdOuAOtaqHTWJZCWvijKD2N5Xu0TtVC8/+1faWqcP9iBCWOmjmhoH94dH82BxPQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/netbsd-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.21.5.tgz", + "integrity": "sha512-Woi2MXzXjMULccIwMnLciyZH4nCIMpWQAs049KEeMvOcNADVxo0UBIQPfSmxB3CWKedngg7sWZdLvLczpe0tLg==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/openbsd-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.21.5.tgz", + "integrity": "sha512-HLNNw99xsvx12lFBUwoT8EVCsSvRNDVxNpjZ7bPn947b8gJPzeHWyNVhFsaerc0n3TsbOINvRP2byTZ5LKezow==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/sunos-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.21.5.tgz", + "integrity": "sha512-6+gjmFpfy0BHU5Tpptkuh8+uw3mnrvgs+dSPQXQOv3ekbordwnzTVEb4qnIvQcYXq6gzkyTnoZ9dZG+D4garKg==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "sunos" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/win32-arm64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.21.5.tgz", + "integrity": "sha512-Z0gOTd75VvXqyq7nsl93zwahcTROgqvuAcYDUr+vOv8uHhNSKROyU961kgtCD1e95IqPKSQKH7tBTslnS3tA8A==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/win32-ia32": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.21.5.tgz", + "integrity": "sha512-SWXFF1CL2RVNMaVs+BBClwtfZSvDgtL//G/smwAc5oVK/UPu2Gu9tIaRgFmYFFKrmg3SyAjSrElf0TiJ1v8fYA==", + "cpu": [ + "ia32" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/win32-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.21.5.tgz", + "integrity": "sha512-tQd/1efJuzPC6rCFwEvLtci/xNFcTZknmXs98FYDfGE4wP9ClFV98nyKrzJKVPMhdDnjzLhdUyMX4PsQAPjwIw==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@eslint-community/eslint-utils": { + "version": "4.9.1", + "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.9.1.tgz", + "integrity": "sha512-phrYmNiYppR7znFEdqgfWHXR6NCkZEK7hwWDHZUjit/2/U0r6XvkDl0SYnoM51Hq7FhCGdLDT6zxCCOY1hexsQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "eslint-visitor-keys": "^3.4.3" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + }, + "peerDependencies": { + "eslint": "^6.0.0 || ^7.0.0 || >=8.0.0" + } + }, + "node_modules/@eslint-community/regexpp": { + "version": "4.12.2", + "resolved": "https://registry.npmjs.org/@eslint-community/regexpp/-/regexpp-4.12.2.tgz", + "integrity": "sha512-EriSTlt5OC9/7SXkRSCAhfSxxoSUgBm33OH+IkwbdpgoqsSsUg7y3uh+IICI/Qg4BBWr3U2i39RpmycbxMq4ew==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^12.0.0 || ^14.0.0 || >=16.0.0" + } + }, + "node_modules/@eslint/config-array": { + "version": "0.21.1", + "resolved": "https://registry.npmjs.org/@eslint/config-array/-/config-array-0.21.1.tgz", + "integrity": "sha512-aw1gNayWpdI/jSYVgzN5pL0cfzU02GT3NBpeT/DXbx1/1x7ZKxFPd9bwrzygx/qiwIQiJ1sw/zD8qY/kRvlGHA==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@eslint/object-schema": "^2.1.7", + "debug": "^4.3.1", + "minimatch": "^3.1.2" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/config-array/node_modules/brace-expansion": { + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/@eslint/config-array/node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/@eslint/config-helpers": { + "version": "0.4.2", + "resolved": "https://registry.npmjs.org/@eslint/config-helpers/-/config-helpers-0.4.2.tgz", + "integrity": "sha512-gBrxN88gOIf3R7ja5K9slwNayVcZgK6SOUORm2uBzTeIEfeVaIhOpCtTox3P6R7o2jLFwLFTLnC7kU/RGcYEgw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@eslint/core": "^0.17.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/core": { + "version": "0.17.0", + "resolved": "https://registry.npmjs.org/@eslint/core/-/core-0.17.0.tgz", + "integrity": "sha512-yL/sLrpmtDaFEiUj1osRP4TI2MDz1AddJL+jZ7KSqvBuliN4xqYY54IfdN8qD8Toa6g1iloph1fxQNkjOxrrpQ==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@types/json-schema": "^7.0.15" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/eslintrc": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-3.3.3.tgz", + "integrity": "sha512-Kr+LPIUVKz2qkx1HAMH8q1q6azbqBAsXJUxBl/ODDuVPX45Z9DfwB8tPjTi6nNZ8BuM3nbJxC5zCAg5elnBUTQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ajv": "^6.12.4", + "debug": "^4.3.2", + "espree": "^10.0.1", + "globals": "^14.0.0", + "ignore": "^5.2.0", + "import-fresh": "^3.2.1", + "js-yaml": "^4.1.1", + "minimatch": "^3.1.2", + "strip-json-comments": "^3.1.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/@eslint/eslintrc/node_modules/brace-expansion": { + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/@eslint/eslintrc/node_modules/ignore": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz", + "integrity": "sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 4" + } + }, + "node_modules/@eslint/eslintrc/node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/@eslint/js": { + "version": "9.39.2", + "resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.39.2.tgz", + "integrity": "sha512-q1mjIoW1VX4IvSocvM/vbTiveKC4k9eLrajNEuSsmjymSDEbpGddtpfOoN7YGAqBK3NG+uqo8ia4PDTt8buCYA==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://eslint.org/donate" + } + }, + "node_modules/@eslint/object-schema": { + "version": "2.1.7", + "resolved": "https://registry.npmjs.org/@eslint/object-schema/-/object-schema-2.1.7.tgz", + "integrity": "sha512-VtAOaymWVfZcmZbp6E2mympDIHvyjXs/12LqWYjVw6qjrfF+VK+fyG33kChz3nnK+SU5/NeHOqrTEHS8sXO3OA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@eslint/plugin-kit": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/@eslint/plugin-kit/-/plugin-kit-0.4.1.tgz", + "integrity": "sha512-43/qtrDUokr7LJqoF2c3+RInu/t4zfrpYdoSDfYyhg52rwLV6TnOvdG4fXm7IkSB3wErkcmJS9iEhjVtOSEjjA==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@eslint/core": "^0.17.0", + "levn": "^0.4.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + } + }, + "node_modules/@google/generative-ai": { + "version": "0.24.1", + "resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.24.1.tgz", + "integrity": "sha512-MqO+MLfM6kjxcKoy0p1wRzG3b4ZZXtPI+z2IE26UogS2Cm/XHO+7gGRBh6gcJsOiIVoH93UwKvW4HdgiOZCy9Q==", + "license": "Apache-2.0", + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@humanfs/core": { + "version": "0.19.1", + "resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz", + "integrity": "sha512-5DyQ4+1JEUzejeK1JGICcideyfUbGixgS9jNgex5nqkW+cY7WZhxBigmieN5Qnw9ZosSNVC9KQKyb+GUaGyKUA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18.18.0" + } + }, + "node_modules/@humanfs/node": { + "version": "0.16.7", + "resolved": "https://registry.npmjs.org/@humanfs/node/-/node-0.16.7.tgz", + "integrity": "sha512-/zUx+yOsIrG4Y43Eh2peDeKCxlRt/gET6aHfaKpuq267qXdYDFViVHfMaLyygZOnl0kGWxFIgsBy8QFuTLUXEQ==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@humanfs/core": "^0.19.1", + "@humanwhocodes/retry": "^0.4.0" + }, + "engines": { + "node": ">=18.18.0" + } + }, + "node_modules/@humanwhocodes/module-importer": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@humanwhocodes/module-importer/-/module-importer-1.0.1.tgz", + "integrity": "sha512-bxveV4V8v5Yb4ncFTT3rPSgZBOpCkjfK0y4oVVVJwIuDVBRMDXrPyXRL988i5ap9m9bnyEEjWfm5WkBmtffLfA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=12.22" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/nzakas" + } + }, + "node_modules/@humanwhocodes/retry": { + "version": "0.4.3", + "resolved": "https://registry.npmjs.org/@humanwhocodes/retry/-/retry-0.4.3.tgz", + "integrity": "sha512-bV0Tgo9K4hfPCek+aMAn81RppFKv2ySDQeMoSZuvTASywNTnVJCArCZE2FWqpvIatKu7VMRLWlR1EazvVhDyhQ==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18.18" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/nzakas" + } + }, + "node_modules/@isaacs/cliui": { + "version": "8.0.2", + "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz", + "integrity": "sha512-O8jcjabXaleOG9DQ0+ARXWZBTfnP4WNAqzuiJK7ll44AmxGKv/J2M4TPjxjY3znBCfvBXFzucm1twdyFybFqEA==", + "dev": true, + "license": "ISC", + "dependencies": { + "string-width": "^5.1.2", + "string-width-cjs": "npm:string-width@^4.2.0", + "strip-ansi": "^7.0.1", + "strip-ansi-cjs": "npm:strip-ansi@^6.0.1", + "wrap-ansi": "^8.1.0", + "wrap-ansi-cjs": "npm:wrap-ansi@^7.0.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/@istanbuljs/schema": { + "version": "0.1.3", + "resolved": "https://registry.npmjs.org/@istanbuljs/schema/-/schema-0.1.3.tgz", + "integrity": "sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/@jridgewell/gen-mapping": { + "version": "0.3.13", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz", + "integrity": "sha512-2kkt/7niJ6MgEPxF0bYdQ6etZaA+fQvDcLKckhy1yIQOzaoKjBBjSj63/aLVjYE3qhRt5dvM+uUyfCg6UKCBbA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/sourcemap-codec": "^1.5.0", + "@jridgewell/trace-mapping": "^0.3.24" + } + }, + "node_modules/@jridgewell/resolve-uri": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", + "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/sourcemap-codec": { + "version": "1.5.5", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.5.tgz", + "integrity": "sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==", + "dev": true, + "license": "MIT" + }, + "node_modules/@jridgewell/trace-mapping": { + "version": "0.3.31", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.31.tgz", + "integrity": "sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/resolve-uri": "^3.1.0", + "@jridgewell/sourcemap-codec": "^1.4.14" + } + }, + "node_modules/@langchain/anthropic": { + "version": "0.3.34", + "resolved": "https://registry.npmjs.org/@langchain/anthropic/-/anthropic-0.3.34.tgz", + "integrity": "sha512-8bOW1A2VHRCjbzdYElrjxutKNs9NSIxYRGtR+OJWVzluMqoKKh2NmmFrpPizEyqCUEG2tTq5xt6XA1lwfqMJRA==", + "license": "MIT", + "dependencies": { + "@anthropic-ai/sdk": "^0.65.0", + "fast-xml-parser": "^4.4.1" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@langchain/core": ">=0.3.58 <0.4.0" + } + }, + "node_modules/@langchain/anthropic/node_modules/@anthropic-ai/sdk": { + "version": "0.65.0", + "resolved": "https://registry.npmjs.org/@anthropic-ai/sdk/-/sdk-0.65.0.tgz", + "integrity": "sha512-zIdPOcrCVEI8t3Di40nH4z9EoeyGZfXbYSvWdDLsB/KkaSYMnEgC7gmcgWu83g2NTn1ZTpbMvpdttWDGGIk6zw==", + "license": "MIT", + "dependencies": { + "json-schema-to-ts": "^3.1.1" + }, + "bin": { + "anthropic-ai-sdk": "bin/cli" + }, + "peerDependencies": { + "zod": "^3.25.0 || ^4.0.0" + }, + "peerDependenciesMeta": { + "zod": { + "optional": true + } + } + }, + "node_modules/@langchain/core": { + "version": "0.3.80", + "resolved": "https://registry.npmjs.org/@langchain/core/-/core-0.3.80.tgz", + "integrity": "sha512-vcJDV2vk1AlCwSh3aBm/urQ1ZrlXFFBocv11bz/NBUfLWD5/UDNMzwPdaAd2dKvNmTWa9FM2lirLU3+JCf4cRA==", + "license": "MIT", + "dependencies": { + "@cfworker/json-schema": "^4.0.2", + "ansi-styles": "^5.0.0", + "camelcase": "6", + "decamelize": "1.2.0", + "js-tiktoken": "^1.0.12", + "langsmith": "^0.3.67", + "mustache": "^4.2.0", + "p-queue": "^6.6.2", + "p-retry": "4", + "uuid": "^10.0.0", + "zod": "^3.25.32", + "zod-to-json-schema": "^3.22.3" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@langchain/google-genai": { + "version": "0.1.12", + "resolved": "https://registry.npmjs.org/@langchain/google-genai/-/google-genai-0.1.12.tgz", + "integrity": "sha512-0Ea0E2g63ejCuormVxbuoyJQ5BYN53i2/fb6WP8bMKzyh+y43R13V8JqOtr3e/GmgNyv3ou/VeaZjx7KAvu/0g==", + "license": "MIT", + "dependencies": { + "@google/generative-ai": "^0.24.0", + "zod-to-json-schema": "^3.22.4" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@langchain/core": ">=0.3.17 <0.4.0" + } + }, + "node_modules/@langchain/openai": { + "version": "0.3.17", + "resolved": "https://registry.npmjs.org/@langchain/openai/-/openai-0.3.17.tgz", + "integrity": "sha512-uw4po32OKptVjq+CYHrumgbfh4NuD7LqyE+ZgqY9I/LrLc6bHLMc+sisHmI17vgek0K/yqtarI0alPJbzrwyag==", + "license": "MIT", + "dependencies": { + "js-tiktoken": "^1.0.12", + "openai": "^4.77.0", + "zod": "^3.22.4", + "zod-to-json-schema": "^3.22.3" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@langchain/core": ">=0.3.29 <0.4.0" + } + }, + "node_modules/@langchain/openai/node_modules/@types/node": { + "version": "18.19.130", + "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.130.tgz", + "integrity": "sha512-GRaXQx6jGfL8sKfaIDD6OupbIHBr9jv7Jnaml9tB7l4v068PAOXqfcujMMo5PhbIs6ggR1XODELqahT2R8v0fg==", + "license": "MIT", + "dependencies": { + "undici-types": "~5.26.4" + } + }, + "node_modules/@langchain/openai/node_modules/openai": { + "version": "4.104.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-4.104.0.tgz", + "integrity": "sha512-p99EFNsA/yX6UhVO93f5kJsDRLAg+CTA2RBqdHK4RtK8u5IJw32Hyb2dTGKbnnFmnuoBv5r7Z2CURI9sGZpSuA==", + "license": "Apache-2.0", + "dependencies": { + "@types/node": "^18.11.18", + "@types/node-fetch": "^2.6.4", + "abort-controller": "^3.0.0", + "agentkeepalive": "^4.2.1", + "form-data-encoder": "1.7.2", + "formdata-node": "^4.3.2", + "node-fetch": "^2.6.7" + }, + "bin": { + "openai": "bin/cli" + }, + "peerDependencies": { + "ws": "^8.18.0", + "zod": "^3.23.8" + }, + "peerDependenciesMeta": { + "ws": { + "optional": true + }, + "zod": { + "optional": true + } + } + }, + "node_modules/@langchain/openai/node_modules/undici-types": { + "version": "5.26.5", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", + "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==", + "license": "MIT" + }, + "node_modules/@pkgjs/parseargs": { + "version": "0.11.0", + "resolved": "https://registry.npmjs.org/@pkgjs/parseargs/-/parseargs-0.11.0.tgz", + "integrity": "sha512-+1VkjdD0QBLPodGrJUeqarH8VAIvQODIbwh9XpP5Syisf7YoQgsJKPNFoqqLQlu+VQ/tVSshMR6loPMn8U+dPg==", + "dev": true, + "license": "MIT", + "optional": true, + "engines": { + "node": ">=14" + } + }, + "node_modules/@polka/url": { + "version": "1.0.0-next.29", + "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.29.tgz", + "integrity": "sha512-wwQAWhWSuHaag8c4q/KN/vCoeOJYshAIvMQwD4GpSb3OiZklFfvAgmj0VCBBImRpuF/aFgIRzllXlVX93Jevww==", + "dev": true, + "license": "MIT" + }, + "node_modules/@rollup/rollup-android-arm-eabi": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.55.1.tgz", + "integrity": "sha512-9R0DM/ykwfGIlNu6+2U09ga0WXeZ9MRC2Ter8jnz8415VbuIykVuc6bhdrbORFZANDmTDvq26mJrEVTl8TdnDg==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ] + }, + "node_modules/@rollup/rollup-android-arm64": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.55.1.tgz", + "integrity": "sha512-eFZCb1YUqhTysgW3sj/55du5cG57S7UTNtdMjCW7LwVcj3dTTcowCsC8p7uBdzKsZYa8J7IDE8lhMI+HX1vQvg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ] + }, + "node_modules/@rollup/rollup-darwin-arm64": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.55.1.tgz", + "integrity": "sha512-p3grE2PHcQm2e8PSGZdzIhCKbMCw/xi9XvMPErPhwO17vxtvCN5FEA2mSLgmKlCjHGMQTP6phuQTYWUnKewwGg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@rollup/rollup-darwin-x64": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.55.1.tgz", + "integrity": "sha512-rDUjG25C9qoTm+e02Esi+aqTKSBYwVTaoS1wxcN47/Luqef57Vgp96xNANwt5npq9GDxsH7kXxNkJVEsWEOEaQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@rollup/rollup-freebsd-arm64": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.55.1.tgz", + "integrity": "sha512-+JiU7Jbp5cdxekIgdte0jfcu5oqw4GCKr6i3PJTlXTCU5H5Fvtkpbs4XJHRmWNXF+hKmn4v7ogI5OQPaupJgOg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ] + }, + "node_modules/@rollup/rollup-freebsd-x64": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.55.1.tgz", + "integrity": "sha512-V5xC1tOVWtLLmr3YUk2f6EJK4qksksOYiz/TCsFHu/R+woubcLWdC9nZQmwjOAbmExBIVKsm1/wKmEy4z4u4Bw==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ] + }, + "node_modules/@rollup/rollup-linux-arm-gnueabihf": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.55.1.tgz", + "integrity": "sha512-Rn3n+FUk2J5VWx+ywrG/HGPTD9jXNbicRtTM11e/uorplArnXZYsVifnPPqNNP5BsO3roI4n8332ukpY/zN7rQ==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm-musleabihf": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.55.1.tgz", + "integrity": "sha512-grPNWydeKtc1aEdrJDWk4opD7nFtQbMmV7769hiAaYyUKCT1faPRm2av8CX1YJsZ4TLAZcg9gTR1KvEzoLjXkg==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm64-gnu": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.55.1.tgz", + "integrity": "sha512-a59mwd1k6x8tXKcUxSyISiquLwB5pX+fJW9TkWU46lCqD/GRDe9uDN31jrMmVP3feI3mhAdvcCClhV8V5MhJFQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm64-musl": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.55.1.tgz", + "integrity": "sha512-puS1MEgWX5GsHSoiAsF0TYrpomdvkaXm0CofIMG5uVkP6IBV+ZO9xhC5YEN49nsgYo1DuuMquF9+7EDBVYu4uA==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-loong64-gnu": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.55.1.tgz", + "integrity": "sha512-r3Wv40in+lTsULSb6nnoudVbARdOwb2u5fpeoOAZjFLznp6tDU8kd+GTHmJoqZ9lt6/Sys33KdIHUaQihFcu7g==", + "cpu": [ + "loong64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-loong64-musl": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.55.1.tgz", + "integrity": "sha512-MR8c0+UxAlB22Fq4R+aQSPBayvYa3+9DrwG/i1TKQXFYEaoW3B5b/rkSRIypcZDdWjWnpcvxbNaAJDcSbJU3Lw==", + "cpu": [ + "loong64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-ppc64-gnu": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.55.1.tgz", + "integrity": "sha512-3KhoECe1BRlSYpMTeVrD4sh2Pw2xgt4jzNSZIIPLFEsnQn9gAnZagW9+VqDqAHgm1Xc77LzJOo2LdigS5qZ+gw==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-ppc64-musl": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.55.1.tgz", + "integrity": "sha512-ziR1OuZx0vdYZZ30vueNZTg73alF59DicYrPViG0NEgDVN8/Jl87zkAPu4u6VjZST2llgEUjaiNl9JM6HH1Vdw==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-riscv64-gnu": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.55.1.tgz", + "integrity": "sha512-uW0Y12ih2XJRERZ4jAfKamTyIHVMPQnTZcQjme2HMVDAHY4amf5u414OqNYC+x+LzRdRcnIG1YodLrrtA8xsxw==", + "cpu": [ + "riscv64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-riscv64-musl": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.55.1.tgz", + "integrity": "sha512-u9yZ0jUkOED1BFrqu3BwMQoixvGHGZ+JhJNkNKY/hyoEgOwlqKb62qu+7UjbPSHYjiVy8kKJHvXKv5coH4wDeg==", + "cpu": [ + "riscv64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-s390x-gnu": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.55.1.tgz", + "integrity": "sha512-/0PenBCmqM4ZUd0190j7J0UsQ/1nsi735iPRakO8iPciE7BQ495Y6msPzaOmvx0/pn+eJVVlZrNrSh4WSYLxNg==", + "cpu": [ + "s390x" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-x64-gnu": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.55.1.tgz", + "integrity": "sha512-a8G4wiQxQG2BAvo+gU6XrReRRqj+pLS2NGXKm8io19goR+K8lw269eTrPkSdDTALwMmJp4th2Uh0D8J9bEV1vg==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-x64-musl": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.55.1.tgz", + "integrity": "sha512-bD+zjpFrMpP/hqkfEcnjXWHMw5BIghGisOKPj+2NaNDuVT+8Ds4mPf3XcPHuat1tz89WRL+1wbcxKY3WSbiT7w==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-openbsd-x64": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.55.1.tgz", + "integrity": "sha512-eLXw0dOiqE4QmvikfQ6yjgkg/xDM+MdU9YJuP4ySTibXU0oAvnEWXt7UDJmD4UkYialMfOGFPJnIHSe/kdzPxg==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ] + }, + "node_modules/@rollup/rollup-openharmony-arm64": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.55.1.tgz", + "integrity": "sha512-xzm44KgEP11te3S2HCSyYf5zIzWmx3n8HDCc7EE59+lTcswEWNpvMLfd9uJvVX8LCg9QWG67Xt75AuHn4vgsXw==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openharmony" + ] + }, + "node_modules/@rollup/rollup-win32-arm64-msvc": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.55.1.tgz", + "integrity": "sha512-yR6Bl3tMC/gBok5cz/Qi0xYnVbIxGx5Fcf/ca0eB6/6JwOY+SRUcJfI0OpeTpPls7f194as62thCt/2BjxYN8g==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-ia32-msvc": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.55.1.tgz", + "integrity": "sha512-3fZBidchE0eY0oFZBnekYCfg+5wAB0mbpCBuofh5mZuzIU/4jIVkbESmd2dOsFNS78b53CYv3OAtwqkZZmU5nA==", + "cpu": [ + "ia32" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-x64-gnu": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.55.1.tgz", + "integrity": "sha512-xGGY5pXj69IxKb4yv/POoocPy/qmEGhimy/FoTpTSVju3FYXUQQMFCaZZXJVidsmGxRioZAwpThl/4zX41gRKg==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-x64-msvc": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.55.1.tgz", + "integrity": "sha512-SPEpaL6DX4rmcXtnhdrQYgzQ5W2uW3SCJch88lB2zImhJRhIIK44fkUrgIV/Q8yUNfw5oyZ5vkeQsZLhCb06lw==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@smithy/abort-controller": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/abort-controller/-/abort-controller-4.2.7.tgz", + "integrity": "sha512-rzMY6CaKx2qxrbYbqjXWS0plqEy7LOdKHS0bg4ixJ6aoGDPNUcLWk/FRNuCILh7GKLG9TFUXYYeQQldMBBwuyw==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/config-resolver": { + "version": "4.4.5", + "resolved": "https://registry.npmjs.org/@smithy/config-resolver/-/config-resolver-4.4.5.tgz", + "integrity": "sha512-HAGoUAFYsUkoSckuKbCPayECeMim8pOu+yLy1zOxt1sifzEbrsRpYa+mKcMdiHKMeiqOibyPG0sFJnmaV/OGEg==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/node-config-provider": "^4.3.7", + "@smithy/types": "^4.11.0", + "@smithy/util-config-provider": "^4.2.0", + "@smithy/util-endpoints": "^3.2.7", + "@smithy/util-middleware": "^4.2.7", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/core": { + "version": "3.20.1", + "resolved": "https://registry.npmjs.org/@smithy/core/-/core-3.20.1.tgz", + "integrity": "sha512-wOboSEdQ85dbKAJ0zL+wQ6b0HTSBRhtGa0PYKysQXkRg+vK0tdCRRVruiFM2QMprkOQwSYOnwF4og96PAaEGag==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/middleware-serde": "^4.2.8", + "@smithy/protocol-http": "^5.3.7", + "@smithy/types": "^4.11.0", + "@smithy/util-base64": "^4.3.0", + "@smithy/util-body-length-browser": "^4.2.0", + "@smithy/util-middleware": "^4.2.7", + "@smithy/util-stream": "^4.5.8", + "@smithy/util-utf8": "^4.2.0", + "@smithy/uuid": "^1.1.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/credential-provider-imds": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/credential-provider-imds/-/credential-provider-imds-4.2.7.tgz", + "integrity": "sha512-CmduWdCiILCRNbQWFR0OcZlUPVtyE49Sr8yYL0rZQ4D/wKxiNzBNS/YHemvnbkIWj623fplgkexUd/c9CAKdoA==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/node-config-provider": "^4.3.7", + "@smithy/property-provider": "^4.2.7", + "@smithy/types": "^4.11.0", + "@smithy/url-parser": "^4.2.7", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/eventstream-codec": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/eventstream-codec/-/eventstream-codec-4.2.7.tgz", + "integrity": "sha512-DrpkEoM3j9cBBWhufqBwnbbn+3nf1N9FP6xuVJ+e220jbactKuQgaZwjwP5CP1t+O94brm2JgVMD2atMGX3xIQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/crc32": "5.2.0", + "@smithy/types": "^4.11.0", + "@smithy/util-hex-encoding": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/eventstream-serde-browser": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-browser/-/eventstream-serde-browser-4.2.7.tgz", + "integrity": "sha512-ujzPk8seYoDBmABDE5YqlhQZAXLOrtxtJLrbhHMKjBoG5b4dK4i6/mEU+6/7yXIAkqOO8sJ6YxZl+h0QQ1IJ7g==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/eventstream-serde-universal": "^4.2.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/eventstream-serde-config-resolver": { + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-config-resolver/-/eventstream-serde-config-resolver-4.3.7.tgz", + "integrity": "sha512-x7BtAiIPSaNaWuzm24Q/mtSkv+BrISO/fmheiJ39PKRNH3RmH2Hph/bUKSOBOBC9unqfIYDhKTHwpyZycLGPVQ==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/eventstream-serde-node": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-node/-/eventstream-serde-node-4.2.7.tgz", + "integrity": "sha512-roySCtHC5+pQq5lK4be1fZ/WR6s/AxnPaLfCODIPArtN2du8s5Ot4mKVK3pPtijL/L654ws592JHJ1PbZFF6+A==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/eventstream-serde-universal": "^4.2.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/eventstream-serde-universal": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-universal/-/eventstream-serde-universal-4.2.7.tgz", + "integrity": "sha512-QVD+g3+icFkThoy4r8wVFZMsIP08taHVKjE6Jpmz8h5CgX/kk6pTODq5cht0OMtcapUx+xrPzUTQdA+TmO0m1g==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/eventstream-codec": "^4.2.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/fetch-http-handler": { + "version": "5.3.8", + "resolved": "https://registry.npmjs.org/@smithy/fetch-http-handler/-/fetch-http-handler-5.3.8.tgz", + "integrity": "sha512-h/Fi+o7mti4n8wx1SR6UHWLaakwHRx29sizvp8OOm7iqwKGFneT06GCSFhml6Bha5BT6ot5pj3CYZnCHhGC2Rg==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/protocol-http": "^5.3.7", + "@smithy/querystring-builder": "^4.2.7", + "@smithy/types": "^4.11.0", + "@smithy/util-base64": "^4.3.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/hash-node": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/hash-node/-/hash-node-4.2.7.tgz", + "integrity": "sha512-PU/JWLTBCV1c8FtB8tEFnY4eV1tSfBc7bDBADHfn1K+uRbPgSJ9jnJp0hyjiFN2PMdPzxsf1Fdu0eo9fJ760Xw==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.11.0", + "@smithy/util-buffer-from": "^4.2.0", + "@smithy/util-utf8": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/invalid-dependency": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/invalid-dependency/-/invalid-dependency-4.2.7.tgz", + "integrity": "sha512-ncvgCr9a15nPlkhIUx3CU4d7E7WEuVJOV7fS7nnK2hLtPK9tYRBkMHQbhXU1VvvKeBm/O0x26OEoBq+ngFpOEQ==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/is-array-buffer": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/@smithy/is-array-buffer/-/is-array-buffer-4.2.0.tgz", + "integrity": "sha512-DZZZBvC7sjcYh4MazJSGiWMI2L7E0oCiRHREDzIxi/M2LY79/21iXt6aPLHge82wi5LsuRF5A06Ds3+0mlh6CQ==", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/middleware-content-length": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/middleware-content-length/-/middleware-content-length-4.2.7.tgz", + "integrity": "sha512-GszfBfCcvt7kIbJ41LuNa5f0wvQCHhnGx/aDaZJCCT05Ld6x6U2s0xsc/0mBFONBZjQJp2U/0uSJ178OXOwbhg==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/protocol-http": "^5.3.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/middleware-endpoint": { + "version": "4.4.2", + "resolved": "https://registry.npmjs.org/@smithy/middleware-endpoint/-/middleware-endpoint-4.4.2.tgz", + "integrity": "sha512-mqpAdux0BNmZu/SqkFhQEnod4fX23xxTvU2LUpmKp0JpSI+kPYCiHJMmzREr8yxbNxKL2/DU1UZm9i++ayU+2g==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/core": "^3.20.1", + "@smithy/middleware-serde": "^4.2.8", + "@smithy/node-config-provider": "^4.3.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "@smithy/url-parser": "^4.2.7", + "@smithy/util-middleware": "^4.2.7", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/middleware-retry": { + "version": "4.4.18", + "resolved": "https://registry.npmjs.org/@smithy/middleware-retry/-/middleware-retry-4.4.18.tgz", + "integrity": "sha512-E5hulijA59nBk/zvcwVMaS7FG7Y4l6hWA9vrW018r+8kiZef4/ETQaPI4oY+3zsy9f6KqDv3c4VKtO4DwwgpCg==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/node-config-provider": "^4.3.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/service-error-classification": "^4.2.7", + "@smithy/smithy-client": "^4.10.3", + "@smithy/types": "^4.11.0", + "@smithy/util-middleware": "^4.2.7", + "@smithy/util-retry": "^4.2.7", + "@smithy/uuid": "^1.1.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/middleware-serde": { + "version": "4.2.8", + "resolved": "https://registry.npmjs.org/@smithy/middleware-serde/-/middleware-serde-4.2.8.tgz", + "integrity": "sha512-8rDGYen5m5+NV9eHv9ry0sqm2gI6W7mc1VSFMtn6Igo25S507/HaOX9LTHAS2/J32VXD0xSzrY0H5FJtOMS4/w==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/protocol-http": "^5.3.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/middleware-stack": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/middleware-stack/-/middleware-stack-4.2.7.tgz", + "integrity": "sha512-bsOT0rJ+HHlZd9crHoS37mt8qRRN/h9jRve1SXUhVbkRzu0QaNYZp1i1jha4n098tsvROjcwfLlfvcFuJSXEsw==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/node-config-provider": { + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/@smithy/node-config-provider/-/node-config-provider-4.3.7.tgz", + "integrity": "sha512-7r58wq8sdOcrwWe+klL9y3bc4GW1gnlfnFOuL7CXa7UzfhzhxKuzNdtqgzmTV+53lEp9NXh5hY/S4UgjLOzPfw==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/property-provider": "^4.2.7", + "@smithy/shared-ini-file-loader": "^4.4.2", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/node-http-handler": { + "version": "4.4.7", + "resolved": "https://registry.npmjs.org/@smithy/node-http-handler/-/node-http-handler-4.4.7.tgz", + "integrity": "sha512-NELpdmBOO6EpZtWgQiHjoShs1kmweaiNuETUpuup+cmm/xJYjT4eUjfhrXRP4jCOaAsS3c3yPsP3B+K+/fyPCQ==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/abort-controller": "^4.2.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/querystring-builder": "^4.2.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/property-provider": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/property-provider/-/property-provider-4.2.7.tgz", + "integrity": "sha512-jmNYKe9MGGPoSl/D7JDDs1C8b3dC8f/w78LbaVfoTtWy4xAd5dfjaFG9c9PWPihY4ggMQNQSMtzU77CNgAJwmA==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/protocol-http": { + "version": "5.3.7", + "resolved": "https://registry.npmjs.org/@smithy/protocol-http/-/protocol-http-5.3.7.tgz", + "integrity": "sha512-1r07pb994I20dD/c2seaZhoCuNYm0rWrvBxhCQ70brNh11M5Ml2ew6qJVo0lclB3jMIXirD4s2XRXRe7QEi0xA==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/querystring-builder": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/querystring-builder/-/querystring-builder-4.2.7.tgz", + "integrity": "sha512-eKONSywHZxK4tBxe2lXEysh8wbBdvDWiA+RIuaxZSgCMmA0zMgoDpGLJhnyj+c0leOQprVnXOmcB4m+W9Rw7sg==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.11.0", + "@smithy/util-uri-escape": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/querystring-parser": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/querystring-parser/-/querystring-parser-4.2.7.tgz", + "integrity": "sha512-3X5ZvzUHmlSTHAXFlswrS6EGt8fMSIxX/c3Rm1Pni3+wYWB6cjGocmRIoqcQF9nU5OgGmL0u7l9m44tSUpfj9w==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/service-error-classification": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/service-error-classification/-/service-error-classification-4.2.7.tgz", + "integrity": "sha512-YB7oCbukqEb2Dlh3340/8g8vNGbs/QsNNRms+gv3N2AtZz9/1vSBx6/6tpwQpZMEJFs7Uq8h4mmOn48ZZ72MkA==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.11.0" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/shared-ini-file-loader": { + "version": "4.4.2", + "resolved": "https://registry.npmjs.org/@smithy/shared-ini-file-loader/-/shared-ini-file-loader-4.4.2.tgz", + "integrity": "sha512-M7iUUff/KwfNunmrgtqBfvZSzh3bmFgv/j/t1Y1dQ+8dNo34br1cqVEqy6v0mYEgi0DkGO7Xig0AnuOaEGVlcg==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/signature-v4": { + "version": "5.3.7", + "resolved": "https://registry.npmjs.org/@smithy/signature-v4/-/signature-v4-5.3.7.tgz", + "integrity": "sha512-9oNUlqBlFZFOSdxgImA6X5GFuzE7V2H7VG/7E70cdLhidFbdtvxxt81EHgykGK5vq5D3FafH//X+Oy31j3CKOg==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/is-array-buffer": "^4.2.0", + "@smithy/protocol-http": "^5.3.7", + "@smithy/types": "^4.11.0", + "@smithy/util-hex-encoding": "^4.2.0", + "@smithy/util-middleware": "^4.2.7", + "@smithy/util-uri-escape": "^4.2.0", + "@smithy/util-utf8": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/smithy-client": { + "version": "4.10.3", + "resolved": "https://registry.npmjs.org/@smithy/smithy-client/-/smithy-client-4.10.3.tgz", + "integrity": "sha512-EfECiO/0fAfb590LBnUe7rI5ux7XfquQ8LBzTe7gxw0j9QW/q8UT/EHWHlxV/+jhQ3+Ssga9uUYXCQgImGMbNg==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/core": "^3.20.1", + "@smithy/middleware-endpoint": "^4.4.2", + "@smithy/middleware-stack": "^4.2.7", + "@smithy/protocol-http": "^5.3.7", + "@smithy/types": "^4.11.0", + "@smithy/util-stream": "^4.5.8", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/types": { + "version": "4.11.0", + "resolved": "https://registry.npmjs.org/@smithy/types/-/types-4.11.0.tgz", + "integrity": "sha512-mlrmL0DRDVe3mNrjTcVcZEgkFmufITfUAPBEA+AHYiIeYyJebso/He1qLbP3PssRe22KUzLRpQSdBPbXdgZ2VA==", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/url-parser": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/url-parser/-/url-parser-4.2.7.tgz", + "integrity": "sha512-/RLtVsRV4uY3qPWhBDsjwahAtt3x2IsMGnP5W1b2VZIe+qgCqkLxI1UOHDZp1Q1QSOrdOR32MF3Ph2JfWT1VHg==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/querystring-parser": "^4.2.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-base64": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/@smithy/util-base64/-/util-base64-4.3.0.tgz", + "integrity": "sha512-GkXZ59JfyxsIwNTWFnjmFEI8kZpRNIBfxKjv09+nkAWPt/4aGaEWMM04m4sxgNVWkbt2MdSvE3KF/PfX4nFedQ==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/util-buffer-from": "^4.2.0", + "@smithy/util-utf8": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-body-length-browser": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/@smithy/util-body-length-browser/-/util-body-length-browser-4.2.0.tgz", + "integrity": "sha512-Fkoh/I76szMKJnBXWPdFkQJl2r9SjPt3cMzLdOB6eJ4Pnpas8hVoWPYemX/peO0yrrvldgCUVJqOAjUrOLjbxg==", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-body-length-node": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/@smithy/util-body-length-node/-/util-body-length-node-4.2.1.tgz", + "integrity": "sha512-h53dz/pISVrVrfxV1iqXlx5pRg3V2YWFcSQyPyXZRrZoZj4R4DeWRDo1a7dd3CPTcFi3kE+98tuNyD2axyZReA==", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-buffer-from": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/@smithy/util-buffer-from/-/util-buffer-from-4.2.0.tgz", + "integrity": "sha512-kAY9hTKulTNevM2nlRtxAG2FQ3B2OR6QIrPY3zE5LqJy1oxzmgBGsHLWTcNhWXKchgA0WHW+mZkQrng/pgcCew==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/is-array-buffer": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-config-provider": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/@smithy/util-config-provider/-/util-config-provider-4.2.0.tgz", + "integrity": "sha512-YEjpl6XJ36FTKmD+kRJJWYvrHeUvm5ykaUS5xK+6oXffQPHeEM4/nXlZPe+Wu0lsgRUcNZiliYNh/y7q9c2y6Q==", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-defaults-mode-browser": { + "version": "4.3.17", + "resolved": "https://registry.npmjs.org/@smithy/util-defaults-mode-browser/-/util-defaults-mode-browser-4.3.17.tgz", + "integrity": "sha512-dwN4GmivYF1QphnP3xJESXKtHvkkvKHSZI8GrSKMVoENVSKW2cFPRYC4ZgstYjUHdR3zwaDkIaTDIp26JuY7Cw==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/property-provider": "^4.2.7", + "@smithy/smithy-client": "^4.10.3", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-defaults-mode-node": { + "version": "4.2.20", + "resolved": "https://registry.npmjs.org/@smithy/util-defaults-mode-node/-/util-defaults-mode-node-4.2.20.tgz", + "integrity": "sha512-VD/I4AEhF1lpB3B//pmOIMBNLMrtdMXwy9yCOfa2QkJGDr63vH3RqPbSAKzoGMov3iryCxTXCxSsyGmEB8PDpg==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/config-resolver": "^4.4.5", + "@smithy/credential-provider-imds": "^4.2.7", + "@smithy/node-config-provider": "^4.3.7", + "@smithy/property-provider": "^4.2.7", + "@smithy/smithy-client": "^4.10.3", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-endpoints": { + "version": "3.2.7", + "resolved": "https://registry.npmjs.org/@smithy/util-endpoints/-/util-endpoints-3.2.7.tgz", + "integrity": "sha512-s4ILhyAvVqhMDYREeTS68R43B1V5aenV5q/V1QpRQJkCXib5BPRo4s7uNdzGtIKxaPHCfU/8YkvPAEvTpxgspg==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/node-config-provider": "^4.3.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-hex-encoding": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/@smithy/util-hex-encoding/-/util-hex-encoding-4.2.0.tgz", + "integrity": "sha512-CCQBwJIvXMLKxVbO88IukazJD9a4kQ9ZN7/UMGBjBcJYvatpWk+9g870El4cB8/EJxfe+k+y0GmR9CAzkF+Nbw==", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-middleware": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/util-middleware/-/util-middleware-4.2.7.tgz", + "integrity": "sha512-i1IkpbOae6NvIKsEeLLM9/2q4X+M90KV3oCFgWQI4q0Qz+yUZvsr+gZPdAEAtFhWQhAHpTsJO8DRJPuwVyln+w==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-retry": { + "version": "4.2.7", + "resolved": "https://registry.npmjs.org/@smithy/util-retry/-/util-retry-4.2.7.tgz", + "integrity": "sha512-SvDdsQyF5CIASa4EYVT02LukPHVzAgUA4kMAuZ97QJc2BpAqZfA4PINB8/KOoCXEw9tsuv/jQjMeaHFvxdLNGg==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/service-error-classification": "^4.2.7", + "@smithy/types": "^4.11.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-stream": { + "version": "4.5.8", + "resolved": "https://registry.npmjs.org/@smithy/util-stream/-/util-stream-4.5.8.tgz", + "integrity": "sha512-ZnnBhTapjM0YPGUSmOs0Mcg/Gg87k503qG4zU2v/+Js2Gu+daKOJMeqcQns8ajepY8tgzzfYxl6kQyZKml6O2w==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/fetch-http-handler": "^5.3.8", + "@smithy/node-http-handler": "^4.4.7", + "@smithy/types": "^4.11.0", + "@smithy/util-base64": "^4.3.0", + "@smithy/util-buffer-from": "^4.2.0", + "@smithy/util-hex-encoding": "^4.2.0", + "@smithy/util-utf8": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-uri-escape": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/@smithy/util-uri-escape/-/util-uri-escape-4.2.0.tgz", + "integrity": "sha512-igZpCKV9+E/Mzrpq6YacdTQ0qTiLm85gD6N/IrmyDvQFA4UnU3d5g3m8tMT/6zG/vVkWSU+VxeUyGonL62DuxA==", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-utf8": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/@smithy/util-utf8/-/util-utf8-4.2.0.tgz", + "integrity": "sha512-zBPfuzoI8xyBtR2P6WQj63Rz8i3AmfAaJLuNG8dWsfvPe8lO4aCPYLn879mEgHndZH1zQ2oXmG8O1GGzzaoZiw==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/util-buffer-from": "^4.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/uuid": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@smithy/uuid/-/uuid-1.1.0.tgz", + "integrity": "sha512-4aUIteuyxtBUhVdiQqcDhKFitwfd9hqoSDYY2KRXiWtgoWJ9Bmise+KfEPDiVHWeJepvF8xJO9/9+WDIciMFFw==", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@types/estree": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", + "integrity": "sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/json-schema": { + "version": "7.0.15", + "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", + "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/node": { + "version": "22.19.3", + "resolved": "https://registry.npmjs.org/@types/node/-/node-22.19.3.tgz", + "integrity": "sha512-1N9SBnWYOJTrNZCdh/yJE+t910Y128BoyY+zBLWhL3r0TYzlTmFdXrPwHL9DyFZmlEXNQQolTZh3KHV31QDhyA==", + "license": "MIT", + "dependencies": { + "undici-types": "~6.21.0" + } + }, + "node_modules/@types/node-fetch": { + "version": "2.6.13", + "resolved": "https://registry.npmjs.org/@types/node-fetch/-/node-fetch-2.6.13.tgz", + "integrity": "sha512-QGpRVpzSaUs30JBSGPjOg4Uveu384erbHBoT1zeONvyCfwQxIkUshLAOqN/k9EjGviPRmWTTe6aH2qySWKTVSw==", + "license": "MIT", + "dependencies": { + "@types/node": "*", + "form-data": "^4.0.4" + } + }, + "node_modules/@types/retry": { + "version": "0.12.0", + "resolved": "https://registry.npmjs.org/@types/retry/-/retry-0.12.0.tgz", + "integrity": "sha512-wWKOClTTiizcZhXnPY4wikVAwmdYHp8q6DmC+EJUzAMsycb7HB32Kh9RN4+0gExjmPmZSAQjgURXIGATPegAvA==", + "license": "MIT" + }, + "node_modules/@types/uuid": { + "version": "10.0.0", + "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-10.0.0.tgz", + "integrity": "sha512-7gqG38EyHgyP1S+7+xomFtL+ZNHcKv6DwNaCZmJmo1vgMugyF3TCnXVg4t1uk89mLNwnLtnY3TpOpCOyp1/xHQ==", + "license": "MIT" + }, + "node_modules/@typescript-eslint/eslint-plugin": { + "version": "8.52.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.52.0.tgz", + "integrity": "sha512-okqtOgqu2qmZJ5iN4TWlgfF171dZmx2FzdOv2K/ixL2LZWDStL8+JgQerI2sa8eAEfoydG9+0V96m7V+P8yE1Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/regexpp": "^4.12.2", + "@typescript-eslint/scope-manager": "8.52.0", + "@typescript-eslint/type-utils": "8.52.0", + "@typescript-eslint/utils": "8.52.0", + "@typescript-eslint/visitor-keys": "8.52.0", + "ignore": "^7.0.5", + "natural-compare": "^1.4.0", + "ts-api-utils": "^2.4.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "@typescript-eslint/parser": "^8.52.0", + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/@typescript-eslint/parser": { + "version": "8.52.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-8.52.0.tgz", + "integrity": "sha512-iIACsx8pxRnguSYhHiMn2PvhvfpopO9FXHyn1mG5txZIsAaB6F0KwbFnUQN3KCiG3Jcuad/Cao2FAs1Wp7vAyg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/scope-manager": "8.52.0", + "@typescript-eslint/types": "8.52.0", + "@typescript-eslint/typescript-estree": "8.52.0", + "@typescript-eslint/visitor-keys": "8.52.0", + "debug": "^4.4.3" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/@typescript-eslint/project-service": { + "version": "8.52.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/project-service/-/project-service-8.52.0.tgz", + "integrity": "sha512-xD0MfdSdEmeFa3OmVqonHi+Cciab96ls1UhIF/qX/O/gPu5KXD0bY9lu33jj04fjzrXHcuvjBcBC+D3SNSadaw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/tsconfig-utils": "^8.52.0", + "@typescript-eslint/types": "^8.52.0", + "debug": "^4.4.3" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/@typescript-eslint/scope-manager": { + "version": "8.52.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-8.52.0.tgz", + "integrity": "sha512-ixxqmmCcc1Nf8S0mS0TkJ/3LKcC8mruYJPOU6Ia2F/zUUR4pApW7LzrpU3JmtePbRUTes9bEqRc1Gg4iyRnDzA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.52.0", + "@typescript-eslint/visitor-keys": "8.52.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/tsconfig-utils": { + "version": "8.52.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/tsconfig-utils/-/tsconfig-utils-8.52.0.tgz", + "integrity": "sha512-jl+8fzr/SdzdxWJznq5nvoI7qn2tNYV/ZBAEcaFMVXf+K6jmXvAFrgo/+5rxgnL152f//pDEAYAhhBAZGrVfwg==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/@typescript-eslint/type-utils": { + "version": "8.52.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-8.52.0.tgz", + "integrity": "sha512-JD3wKBRWglYRQkAtsyGz1AewDu3mTc7NtRjR/ceTyGoPqmdS5oCdx/oZMWD5Zuqmo6/MpsYs0wp6axNt88/2EQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.52.0", + "@typescript-eslint/typescript-estree": "8.52.0", + "@typescript-eslint/utils": "8.52.0", + "debug": "^4.4.3", + "ts-api-utils": "^2.4.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/@typescript-eslint/types": { + "version": "8.52.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-8.52.0.tgz", + "integrity": "sha512-LWQV1V4q9V4cT4H5JCIx3481iIFxH1UkVk+ZkGGAV1ZGcjGI9IoFOfg3O6ywz8QqCDEp7Inlg6kovMofsNRaGg==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/typescript-estree": { + "version": "8.52.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-8.52.0.tgz", + "integrity": "sha512-XP3LClsCc0FsTK5/frGjolyADTh3QmsLp6nKd476xNI9CsSsLnmn4f0jrzNoAulmxlmNIpeXuHYeEQv61Q6qeQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/project-service": "8.52.0", + "@typescript-eslint/tsconfig-utils": "8.52.0", + "@typescript-eslint/types": "8.52.0", + "@typescript-eslint/visitor-keys": "8.52.0", + "debug": "^4.4.3", + "minimatch": "^9.0.5", + "semver": "^7.7.3", + "tinyglobby": "^0.2.15", + "ts-api-utils": "^2.4.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/@typescript-eslint/utils": { + "version": "8.52.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-8.52.0.tgz", + "integrity": "sha512-wYndVMWkweqHpEpwPhwqE2lnD2DxC6WVLupU/DOt/0/v+/+iQbbzO3jOHjmBMnhu0DgLULvOaU4h4pwHYi2oRQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.9.1", + "@typescript-eslint/scope-manager": "8.52.0", + "@typescript-eslint/types": "8.52.0", + "@typescript-eslint/typescript-estree": "8.52.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/@typescript-eslint/visitor-keys": { + "version": "8.52.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-8.52.0.tgz", + "integrity": "sha512-ink3/Zofus34nmBsPjow63FP5M7IGff0RKAgqR6+CFpdk22M7aLwC9gOcLGYqr7MczLPzZVERW9hRog3O4n1sQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.52.0", + "eslint-visitor-keys": "^4.2.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/visitor-keys/node_modules/eslint-visitor-keys": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-4.2.1.tgz", + "integrity": "sha512-Uhdk5sfqcee/9H/rCOJikYz67o0a2Tw2hGRPOG2Y1R2dg7brRe1uG0yaNQDHu+TO/uQPF/5eCapvYSmHUjt7JQ==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/@vitest/coverage-v8": { + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/coverage-v8/-/coverage-v8-2.1.9.tgz", + "integrity": "sha512-Z2cOr0ksM00MpEfyVE8KXIYPEcBFxdbLSs56L8PO0QQMxt/6bDj45uQfxoc96v05KW3clk7vvgP0qfDit9DmfQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@ampproject/remapping": "^2.3.0", + "@bcoe/v8-coverage": "^0.2.3", + "debug": "^4.3.7", + "istanbul-lib-coverage": "^3.2.2", + "istanbul-lib-report": "^3.0.1", + "istanbul-lib-source-maps": "^5.0.6", + "istanbul-reports": "^3.1.7", + "magic-string": "^0.30.12", + "magicast": "^0.3.5", + "std-env": "^3.8.0", + "test-exclude": "^7.0.1", + "tinyrainbow": "^1.2.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "@vitest/browser": "2.1.9", + "vitest": "2.1.9" + }, + "peerDependenciesMeta": { + "@vitest/browser": { + "optional": true + } + } + }, + "node_modules/@vitest/expect": { + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-2.1.9.tgz", + "integrity": "sha512-UJCIkTBenHeKT1TTlKMJWy1laZewsRIzYighyYiJKZreqtdxSos/S1t+ktRMQWu2CKqaarrkeszJx1cgC5tGZw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/spy": "2.1.9", + "@vitest/utils": "2.1.9", + "chai": "^5.1.2", + "tinyrainbow": "^1.2.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/mocker": { + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/mocker/-/mocker-2.1.9.tgz", + "integrity": "sha512-tVL6uJgoUdi6icpxmdrn5YNo3g3Dxv+IHJBr0GXHaEdTcw3F+cPKnsXFhli6nO+f/6SDKPHEK1UN+k+TQv0Ehg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/spy": "2.1.9", + "estree-walker": "^3.0.3", + "magic-string": "^0.30.12" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "msw": "^2.4.9", + "vite": "^5.0.0" + }, + "peerDependenciesMeta": { + "msw": { + "optional": true + }, + "vite": { + "optional": true + } + } + }, + "node_modules/@vitest/pretty-format": { + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/pretty-format/-/pretty-format-2.1.9.tgz", + "integrity": "sha512-KhRIdGV2U9HOUzxfiHmY8IFHTdqtOhIzCpd8WRdJiE7D/HUcZVD0EgQCVjm+Q9gkUXWgBvMmTtZgIG48wq7sOQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "tinyrainbow": "^1.2.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/runner": { + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-2.1.9.tgz", + "integrity": "sha512-ZXSSqTFIrzduD63btIfEyOmNcBmQvgOVsPNPe0jYtESiXkhd8u2erDLnMxmGrDCwHCCHE7hxwRDCT3pt0esT4g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/utils": "2.1.9", + "pathe": "^1.1.2" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/snapshot": { + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-2.1.9.tgz", + "integrity": "sha512-oBO82rEjsxLNJincVhLhaxxZdEtV0EFHMK5Kmx5sJ6H9L183dHECjiefOAdnqpIgT5eZwT04PoggUnW88vOBNQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "2.1.9", + "magic-string": "^0.30.12", + "pathe": "^1.1.2" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/spy": { + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-2.1.9.tgz", + "integrity": "sha512-E1B35FwzXXTs9FHNK6bDszs7mtydNi5MIfUWpceJ8Xbfb1gBMscAnwLbEu+B44ed6W3XjL9/ehLPHR1fkf1KLQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "tinyspy": "^3.0.2" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/ui": { + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/ui/-/ui-2.1.9.tgz", + "integrity": "sha512-izzd2zmnk8Nl5ECYkW27328RbQ1nKvkm6Bb5DAaz1Gk59EbLkiCMa6OLT0NoaAYTjOFS6N+SMYW1nh4/9ljPiw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/utils": "2.1.9", + "fflate": "^0.8.2", + "flatted": "^3.3.1", + "pathe": "^1.1.2", + "sirv": "^3.0.0", + "tinyglobby": "^0.2.10", + "tinyrainbow": "^1.2.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "vitest": "2.1.9" + } + }, + "node_modules/@vitest/utils": { + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-2.1.9.tgz", + "integrity": "sha512-v0psaMSkNJ3A2NMrUEHFRzJtDPFn+/VWZ5WxImB21T9fjucJRmS7xCS3ppEnARb9y11OAzaD+P2Ps+b+BGX5iQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "2.1.9", + "loupe": "^3.1.2", + "tinyrainbow": "^1.2.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/abort-controller": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/abort-controller/-/abort-controller-3.0.0.tgz", + "integrity": "sha512-h8lQ8tacZYnR3vNQTgibj+tODHI5/+l06Au2Pcriv/Gmet0eaj4TwWH41sO9wnHDiQsEj19q0drzdWdeAHtweg==", + "license": "MIT", + "dependencies": { + "event-target-shim": "^5.0.0" + }, + "engines": { + "node": ">=6.5" + } + }, + "node_modules/acorn": { + "version": "8.15.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", + "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", + "dev": true, + "license": "MIT", + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-jsx": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", + "integrity": "sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" + } + }, + "node_modules/agentkeepalive": { + "version": "4.6.0", + "resolved": "https://registry.npmjs.org/agentkeepalive/-/agentkeepalive-4.6.0.tgz", + "integrity": "sha512-kja8j7PjmncONqaTsB8fQ+wE2mSU2DJ9D4XKoJ5PFWIdRMa6SLSN1ff4mOr4jCbfRSsxR4keIiySJU0N9T5hIQ==", + "license": "MIT", + "dependencies": { + "humanize-ms": "^1.2.1" + }, + "engines": { + "node": ">= 8.0.0" + } + }, + "node_modules/ajv": { + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ansi-regex": { + "version": "6.2.2", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.2.2.tgz", + "integrity": "sha512-Bq3SmSpyFHaWjPk8If9yc6svM8c56dB5BAtW4Qbw5jHTwwXXcTLoRMkpDJp6VL0XzlWaCHTXrkFURMYmD0sLqg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/ansi-regex?sponsor=1" + } + }, + "node_modules/ansi-styles": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-5.2.0.tgz", + "integrity": "sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==", + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/argparse": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", + "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", + "dev": true, + "license": "Python-2.0" + }, + "node_modules/assertion-error": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/assertion-error/-/assertion-error-2.0.1.tgz", + "integrity": "sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + } + }, + "node_modules/asynckit": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", + "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==", + "license": "MIT" + }, + "node_modules/balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "dev": true, + "license": "MIT" + }, + "node_modules/base64-js": { + "version": "1.5.1", + "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz", + "integrity": "sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, + "node_modules/bowser": { + "version": "2.13.1", + "resolved": "https://registry.npmjs.org/bowser/-/bowser-2.13.1.tgz", + "integrity": "sha512-OHawaAbjwx6rqICCKgSG0SAnT05bzd7ppyKLVUITZpANBaaMFBAsaNkto3LoQ31tyFP5kNujE8Cdx85G9VzOkw==", + "license": "MIT" + }, + "node_modules/brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/cac": { + "version": "6.7.14", + "resolved": "https://registry.npmjs.org/cac/-/cac-6.7.14.tgz", + "integrity": "sha512-b6Ilus+c3RrdDk+JhLKUAQfzzgLEPy6wcXqS7f/xe1EETvsDP6GORG7SFuOs6cID5YkqchW/LXZbX5bc8j7ZcQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/call-bind-apply-helpers": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", + "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/callsites": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz", + "integrity": "sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/camelcase": { + "version": "6.3.0", + "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-6.3.0.tgz", + "integrity": "sha512-Gmy6FhYlCY7uOElZUSbxo2UCDH8owEk996gkbrpsgGtrJLM3J7jGxl9Ic7Qwwj4ivOE5AWZWRMecDdF7hqGjFA==", + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/chai": { + "version": "5.3.3", + "resolved": "https://registry.npmjs.org/chai/-/chai-5.3.3.tgz", + "integrity": "sha512-4zNhdJD/iOjSH0A05ea+Ke6MU5mmpQcbQsSOkgdaUMJ9zTlDTD/GYlwohmIE2u0gaxHYiVHEn1Fw9mZ/ktJWgw==", + "dev": true, + "license": "MIT", + "dependencies": { + "assertion-error": "^2.0.1", + "check-error": "^2.1.1", + "deep-eql": "^5.0.1", + "loupe": "^3.1.0", + "pathval": "^2.0.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/chalk": { + "version": "4.1.2", + "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.1.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/chalk/node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "license": "MIT", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/check-error": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/check-error/-/check-error-2.1.3.tgz", + "integrity": "sha512-PAJdDJusoxnwm1VwW07VWwUN1sl7smmC3OKggvndJFadxxDRyFJBX/ggnu/KE4kQAB7a3Dp8f/YXC1FlUprWmA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 16" + } + }, + "node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "license": "MIT", + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "license": "MIT" + }, + "node_modules/combined-stream": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz", + "integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==", + "license": "MIT", + "dependencies": { + "delayed-stream": "~1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/concat-map": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", + "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", + "dev": true, + "license": "MIT" + }, + "node_modules/console-table-printer": { + "version": "2.15.0", + "resolved": "https://registry.npmjs.org/console-table-printer/-/console-table-printer-2.15.0.tgz", + "integrity": "sha512-SrhBq4hYVjLCkBVOWaTzceJalvn5K1Zq5aQA6wXC/cYjI3frKWNPEMK3sZsJfNNQApvCQmgBcc13ZKmFj8qExw==", + "license": "MIT", + "dependencies": { + "simple-wcswidth": "^1.1.2" + } + }, + "node_modules/cross-spawn": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "dev": true, + "license": "MIT", + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/debug": { + "version": "4.4.3", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.3.tgz", + "integrity": "sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==", + "dev": true, + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/decamelize": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/decamelize/-/decamelize-1.2.0.tgz", + "integrity": "sha512-z2S+W9X73hAUUki+N+9Za2lBlun89zigOyGrsax+KUQ6wKW4ZoWpEYBkGhQjwAjjDCkWxhY0VKEhk8wzY7F5cA==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/deep-eql": { + "version": "5.0.2", + "resolved": "https://registry.npmjs.org/deep-eql/-/deep-eql-5.0.2.tgz", + "integrity": "sha512-h5k/5U50IJJFpzfL6nO9jaaumfjO/f2NjK/oYB2Djzm4p9L+3T9qWpZqZ2hAbLPuuYq9wrU08WQyBTL5GbPk5Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/deep-is": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.4.tgz", + "integrity": "sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/delayed-stream": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz", + "integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==", + "license": "MIT", + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/dotenv": { + "version": "16.6.1", + "resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.6.1.tgz", + "integrity": "sha512-uBq4egWHTcTt33a72vpSG0z3HnPuIl6NqYcTrKEg2azoEyl2hpW0zqlxysq2pK9HlDIHyHyakeYaYnSAwd8bow==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://dotenvx.com" + } + }, + "node_modules/dunder-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", + "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/eastasianwidth": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/eastasianwidth/-/eastasianwidth-0.2.0.tgz", + "integrity": "sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA==", + "dev": true, + "license": "MIT" + }, + "node_modules/emoji-regex": { + "version": "9.2.2", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz", + "integrity": "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==", + "dev": true, + "license": "MIT" + }, + "node_modules/es-define-property": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz", + "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-module-lexer": { + "version": "1.7.0", + "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.7.0.tgz", + "integrity": "sha512-jEQoCwk8hyb2AZziIOLhDqpm5+2ww5uIE6lkO/6jcOCusfk6LhMHpXXfBLXTZ7Ydyt0j4VoUQv6uGNYbdW+kBA==", + "dev": true, + "license": "MIT" + }, + "node_modules/es-object-atoms": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", + "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-set-tostringtag": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz", + "integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.6", + "has-tostringtag": "^1.0.2", + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/esbuild": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.21.5.tgz", + "integrity": "sha512-mg3OPMV4hXywwpoDxu3Qda5xCKQi+vCTZq8S9J/EpkhB2HzKXq4SNFZE3+NK93JYxc8VMSep+lOUSC/RVKaBqw==", + "dev": true, + "hasInstallScript": true, + "license": "MIT", + "bin": { + "esbuild": "bin/esbuild" + }, + "engines": { + "node": ">=12" + }, + "optionalDependencies": { + "@esbuild/aix-ppc64": "0.21.5", + "@esbuild/android-arm": "0.21.5", + "@esbuild/android-arm64": "0.21.5", + "@esbuild/android-x64": "0.21.5", + "@esbuild/darwin-arm64": "0.21.5", + "@esbuild/darwin-x64": "0.21.5", + "@esbuild/freebsd-arm64": "0.21.5", + "@esbuild/freebsd-x64": "0.21.5", + "@esbuild/linux-arm": "0.21.5", + "@esbuild/linux-arm64": "0.21.5", + "@esbuild/linux-ia32": "0.21.5", + "@esbuild/linux-loong64": "0.21.5", + "@esbuild/linux-mips64el": "0.21.5", + "@esbuild/linux-ppc64": "0.21.5", + "@esbuild/linux-riscv64": "0.21.5", + "@esbuild/linux-s390x": "0.21.5", + "@esbuild/linux-x64": "0.21.5", + "@esbuild/netbsd-x64": "0.21.5", + "@esbuild/openbsd-x64": "0.21.5", + "@esbuild/sunos-x64": "0.21.5", + "@esbuild/win32-arm64": "0.21.5", + "@esbuild/win32-ia32": "0.21.5", + "@esbuild/win32-x64": "0.21.5" + } + }, + "node_modules/escape-string-regexp": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", + "integrity": "sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/eslint": { + "version": "9.39.2", + "resolved": "https://registry.npmjs.org/eslint/-/eslint-9.39.2.tgz", + "integrity": "sha512-LEyamqS7W5HB3ujJyvi0HQK/dtVINZvd5mAAp9eT5S/ujByGjiZLCzPcHVzuXbpJDJF/cxwHlfceVUDZ2lnSTw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.8.0", + "@eslint-community/regexpp": "^4.12.1", + "@eslint/config-array": "^0.21.1", + "@eslint/config-helpers": "^0.4.2", + "@eslint/core": "^0.17.0", + "@eslint/eslintrc": "^3.3.1", + "@eslint/js": "9.39.2", + "@eslint/plugin-kit": "^0.4.1", + "@humanfs/node": "^0.16.6", + "@humanwhocodes/module-importer": "^1.0.1", + "@humanwhocodes/retry": "^0.4.2", + "@types/estree": "^1.0.6", + "ajv": "^6.12.4", + "chalk": "^4.0.0", + "cross-spawn": "^7.0.6", + "debug": "^4.3.2", + "escape-string-regexp": "^4.0.0", + "eslint-scope": "^8.4.0", + "eslint-visitor-keys": "^4.2.1", + "espree": "^10.4.0", + "esquery": "^1.5.0", + "esutils": "^2.0.2", + "fast-deep-equal": "^3.1.3", + "file-entry-cache": "^8.0.0", + "find-up": "^5.0.0", + "glob-parent": "^6.0.2", + "ignore": "^5.2.0", + "imurmurhash": "^0.1.4", + "is-glob": "^4.0.0", + "json-stable-stringify-without-jsonify": "^1.0.1", + "lodash.merge": "^4.6.2", + "minimatch": "^3.1.2", + "natural-compare": "^1.4.0", + "optionator": "^0.9.3" + }, + "bin": { + "eslint": "bin/eslint.js" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://eslint.org/donate" + }, + "peerDependencies": { + "jiti": "*" + }, + "peerDependenciesMeta": { + "jiti": { + "optional": true + } + } + }, + "node_modules/eslint-scope": { + "version": "8.4.0", + "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-8.4.0.tgz", + "integrity": "sha512-sNXOfKCn74rt8RICKMvJS7XKV/Xk9kA7DyJr8mJik3S7Cwgy3qlkkmyS2uQB3jiJg6VNdZd/pDBJu0nvG2NlTg==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "esrecurse": "^4.3.0", + "estraverse": "^5.2.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint-visitor-keys": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.4.3.tgz", + "integrity": "sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint/node_modules/brace-expansion": { + "version": "1.1.12", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", + "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0", + "concat-map": "0.0.1" + } + }, + "node_modules/eslint/node_modules/eslint-visitor-keys": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-4.2.1.tgz", + "integrity": "sha512-Uhdk5sfqcee/9H/rCOJikYz67o0a2Tw2hGRPOG2Y1R2dg7brRe1uG0yaNQDHu+TO/uQPF/5eCapvYSmHUjt7JQ==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint/node_modules/ignore": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz", + "integrity": "sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 4" + } + }, + "node_modules/eslint/node_modules/minimatch": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^1.1.7" + }, + "engines": { + "node": "*" + } + }, + "node_modules/espree": { + "version": "10.4.0", + "resolved": "https://registry.npmjs.org/espree/-/espree-10.4.0.tgz", + "integrity": "sha512-j6PAQ2uUr79PZhBjP5C5fhl8e39FmRnOjsD5lGnWrFU8i2G776tBK7+nP8KuQUTTyAZUwfQqXAgrVH5MbH9CYQ==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "acorn": "^8.15.0", + "acorn-jsx": "^5.3.2", + "eslint-visitor-keys": "^4.2.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/espree/node_modules/eslint-visitor-keys": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-4.2.1.tgz", + "integrity": "sha512-Uhdk5sfqcee/9H/rCOJikYz67o0a2Tw2hGRPOG2Y1R2dg7brRe1uG0yaNQDHu+TO/uQPF/5eCapvYSmHUjt7JQ==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/esquery": { + "version": "1.7.0", + "resolved": "https://registry.npmjs.org/esquery/-/esquery-1.7.0.tgz", + "integrity": "sha512-Ap6G0WQwcU/LHsvLwON1fAQX9Zp0A2Y6Y/cJBl9r/JbW90Zyg4/zbG6zzKa2OTALELarYHmKu0GhpM5EO+7T0g==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "estraverse": "^5.1.0" + }, + "engines": { + "node": ">=0.10" + } + }, + "node_modules/esrecurse": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", + "integrity": "sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "estraverse": "^5.2.0" + }, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/estraverse": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", + "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=4.0" + } + }, + "node_modules/estree-walker": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-3.0.3.tgz", + "integrity": "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "^1.0.0" + } + }, + "node_modules/esutils": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", + "integrity": "sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/event-target-shim": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/event-target-shim/-/event-target-shim-5.0.1.tgz", + "integrity": "sha512-i/2XbnSz/uxRCU6+NdVJgKWDTM427+MqYbkQzD321DuCQJUqOuJKIA0IM2+W2xtYHdKOmZ4dR6fExsd4SXL+WQ==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/eventemitter3": { + "version": "4.0.7", + "resolved": "https://registry.npmjs.org/eventemitter3/-/eventemitter3-4.0.7.tgz", + "integrity": "sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==", + "license": "MIT" + }, + "node_modules/expect-type": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/expect-type/-/expect-type-1.3.0.tgz", + "integrity": "sha512-knvyeauYhqjOYvQ66MznSMs83wmHrCycNEN6Ao+2AeYEfxUIkuiVxdEa1qlGEPK+We3n0THiDciYSsCcgW/DoA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=12.0.0" + } + }, + "node_modules/fast-deep-equal": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", + "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", + "dev": true, + "license": "MIT" + }, + "node_modules/fast-json-stable-stringify": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", + "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", + "dev": true, + "license": "MIT" + }, + "node_modules/fast-levenshtein": { + "version": "2.0.6", + "resolved": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz", + "integrity": "sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==", + "dev": true, + "license": "MIT" + }, + "node_modules/fast-xml-parser": { + "version": "4.5.3", + "resolved": "https://registry.npmjs.org/fast-xml-parser/-/fast-xml-parser-4.5.3.tgz", + "integrity": "sha512-RKihhV+SHsIUGXObeVy9AXiBbFwkVk7Syp8XgwN5U3JV416+Gwp/GO9i0JYKmikykgz/UHRrrV4ROuZEo/T0ig==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/NaturalIntelligence" + } + ], + "license": "MIT", + "dependencies": { + "strnum": "^1.1.1" + }, + "bin": { + "fxparser": "src/cli/cli.js" + } + }, + "node_modules/fdir": { + "version": "6.5.0", + "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.5.0.tgz", + "integrity": "sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12.0.0" + }, + "peerDependencies": { + "picomatch": "^3 || ^4" + }, + "peerDependenciesMeta": { + "picomatch": { + "optional": true + } + } + }, + "node_modules/fflate": { + "version": "0.8.2", + "resolved": "https://registry.npmjs.org/fflate/-/fflate-0.8.2.tgz", + "integrity": "sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==", + "dev": true, + "license": "MIT" + }, + "node_modules/file-entry-cache": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-8.0.0.tgz", + "integrity": "sha512-XXTUwCvisa5oacNGRP9SfNtYBNAMi+RPwBFmblZEF7N7swHYQS6/Zfk7SRwx4D5j3CH211YNRco1DEMNVfZCnQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "flat-cache": "^4.0.0" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/find-up": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/find-up/-/find-up-5.0.0.tgz", + "integrity": "sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng==", + "dev": true, + "license": "MIT", + "dependencies": { + "locate-path": "^6.0.0", + "path-exists": "^4.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/flat-cache": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/flat-cache/-/flat-cache-4.0.1.tgz", + "integrity": "sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==", + "dev": true, + "license": "MIT", + "dependencies": { + "flatted": "^3.2.9", + "keyv": "^4.5.4" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/flatted": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz", + "integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==", + "dev": true, + "license": "ISC" + }, + "node_modules/foreground-child": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/foreground-child/-/foreground-child-3.3.1.tgz", + "integrity": "sha512-gIXjKqtFuWEgzFRJA9WCQeSJLZDjgJUOMCMzxtvFq/37KojM1BFGufqsCy0r4qSQmYLsZYMeyRqzIWOMup03sw==", + "dev": true, + "license": "ISC", + "dependencies": { + "cross-spawn": "^7.0.6", + "signal-exit": "^4.0.1" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/form-data": { + "version": "4.0.5", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.5.tgz", + "integrity": "sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==", + "license": "MIT", + "dependencies": { + "asynckit": "^0.4.0", + "combined-stream": "^1.0.8", + "es-set-tostringtag": "^2.1.0", + "hasown": "^2.0.2", + "mime-types": "^2.1.12" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/form-data-encoder": { + "version": "1.7.2", + "resolved": "https://registry.npmjs.org/form-data-encoder/-/form-data-encoder-1.7.2.tgz", + "integrity": "sha512-qfqtYan3rxrnCk1VYaA4H+Ms9xdpPqvLZa6xmMgFvhO32x7/3J/ExcTd6qpxM0vH2GdMI+poehyBZvqfMTto8A==", + "license": "MIT" + }, + "node_modules/formdata-node": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/formdata-node/-/formdata-node-4.4.1.tgz", + "integrity": "sha512-0iirZp3uVDjVGt9p49aTaqjk84TrglENEDuqfdlZQ1roC9CWlPk6Avf8EEnZNcAqPonwkG35x4n3ww/1THYAeQ==", + "license": "MIT", + "dependencies": { + "node-domexception": "1.0.0", + "web-streams-polyfill": "4.0.0-beta.3" + }, + "engines": { + "node": ">= 12.20" + } + }, + "node_modules/fsevents": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", + "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "dev": true, + "hasInstallScript": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/function-bind": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-intrinsic": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", + "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "function-bind": "^1.1.2", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", + "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", + "license": "MIT", + "dependencies": { + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/glob": { + "version": "10.5.0", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", + "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", + "dev": true, + "license": "ISC", + "dependencies": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "bin": { + "glob": "dist/esm/bin.mjs" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/glob-parent": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", + "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", + "dev": true, + "license": "ISC", + "dependencies": { + "is-glob": "^4.0.3" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/globals": { + "version": "14.0.0", + "resolved": "https://registry.npmjs.org/globals/-/globals-14.0.0.tgz", + "integrity": "sha512-oahGvuMGQlPw/ivIYBjVSrWAfWLBeku5tpPE2fOPLi+WHffIWbuh2tCjhyQhTBPMf5E9jDEH4FOmTYgYwbKwtQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/gopd": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz", + "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/has-symbols": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", + "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-tostringtag": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz", + "integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==", + "license": "MIT", + "dependencies": { + "has-symbols": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "license": "MIT", + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/html-escaper": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/html-escaper/-/html-escaper-2.0.2.tgz", + "integrity": "sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg==", + "dev": true, + "license": "MIT" + }, + "node_modules/humanize-ms": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/humanize-ms/-/humanize-ms-1.2.1.tgz", + "integrity": "sha512-Fl70vYtsAFb/C06PTS9dZBo7ihau+Tu/DNCk/OyHhea07S+aeMWpFFkUaXRa8fI+ScZbEI8dfSxwY7gxZ9SAVQ==", + "license": "MIT", + "dependencies": { + "ms": "^2.0.0" + } + }, + "node_modules/ignore": { + "version": "7.0.5", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-7.0.5.tgz", + "integrity": "sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 4" + } + }, + "node_modules/import-fresh": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.1.tgz", + "integrity": "sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "parent-module": "^1.0.0", + "resolve-from": "^4.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/imurmurhash": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", + "integrity": "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.8.19" + } + }, + "node_modules/is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-fullwidth-code-point": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", + "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/is-glob": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-extglob": "^2.1.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "dev": true, + "license": "ISC" + }, + "node_modules/istanbul-lib-coverage": { + "version": "3.2.2", + "resolved": "https://registry.npmjs.org/istanbul-lib-coverage/-/istanbul-lib-coverage-3.2.2.tgz", + "integrity": "sha512-O8dpsF+r0WV/8MNRKfnmrtCWhuKjxrq2w+jpzBL5UZKTi2LeVWnWOmWRxFlesJONmc+wLAGvKQZEOanko0LFTg==", + "dev": true, + "license": "BSD-3-Clause", + "engines": { + "node": ">=8" + } + }, + "node_modules/istanbul-lib-report": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/istanbul-lib-report/-/istanbul-lib-report-3.0.1.tgz", + "integrity": "sha512-GCfE1mtsHGOELCU8e/Z7YWzpmybrx/+dSTfLrvY8qRmaY6zXTKWn6WQIjaAFw069icm6GVMNkgu0NzI4iPZUNw==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "istanbul-lib-coverage": "^3.0.0", + "make-dir": "^4.0.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/istanbul-lib-source-maps": { + "version": "5.0.6", + "resolved": "https://registry.npmjs.org/istanbul-lib-source-maps/-/istanbul-lib-source-maps-5.0.6.tgz", + "integrity": "sha512-yg2d+Em4KizZC5niWhQaIomgf5WlL4vOOjZ5xGCmF8SnPE/mDWWXgvRExdcpCgh9lLRRa1/fSYp2ymmbJ1pI+A==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "@jridgewell/trace-mapping": "^0.3.23", + "debug": "^4.1.1", + "istanbul-lib-coverage": "^3.0.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/istanbul-reports": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/istanbul-reports/-/istanbul-reports-3.2.0.tgz", + "integrity": "sha512-HGYWWS/ehqTV3xN10i23tkPkpH46MLCIMFNCaaKNavAXTF1RkqxawEPtnjnGZ6XKSInBKkiOA5BKS+aZiY3AvA==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "html-escaper": "^2.0.0", + "istanbul-lib-report": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/jackspeak": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/jackspeak/-/jackspeak-3.4.3.tgz", + "integrity": "sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw==", + "dev": true, + "license": "BlueOak-1.0.0", + "dependencies": { + "@isaacs/cliui": "^8.0.2" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + }, + "optionalDependencies": { + "@pkgjs/parseargs": "^0.11.0" + } + }, + "node_modules/js-tiktoken": { + "version": "1.0.21", + "resolved": "https://registry.npmjs.org/js-tiktoken/-/js-tiktoken-1.0.21.tgz", + "integrity": "sha512-biOj/6M5qdgx5TKjDnFT1ymSpM5tbd3ylwDtrQvFQSu0Z7bBYko2dF+W/aUkXUPuk6IVpRxk/3Q2sHOzGlS36g==", + "license": "MIT", + "dependencies": { + "base64-js": "^1.5.1" + } + }, + "node_modules/js-yaml": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", + "dev": true, + "license": "MIT", + "dependencies": { + "argparse": "^2.0.1" + }, + "bin": { + "js-yaml": "bin/js-yaml.js" + } + }, + "node_modules/json-buffer": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/json-buffer/-/json-buffer-3.0.1.tgz", + "integrity": "sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/json-schema-to-ts": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/json-schema-to-ts/-/json-schema-to-ts-3.1.1.tgz", + "integrity": "sha512-+DWg8jCJG2TEnpy7kOm/7/AxaYoaRbjVB4LFZLySZlWn8exGs3A4OLJR966cVvU26N7X9TWxl+Jsw7dzAqKT6g==", + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.18.3", + "ts-algebra": "^2.0.0" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/json-schema-traverse": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", + "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "dev": true, + "license": "MIT" + }, + "node_modules/json-stable-stringify-without-jsonify": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/json-stable-stringify-without-jsonify/-/json-stable-stringify-without-jsonify-1.0.1.tgz", + "integrity": "sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw==", + "dev": true, + "license": "MIT" + }, + "node_modules/keyv": { + "version": "4.5.4", + "resolved": "https://registry.npmjs.org/keyv/-/keyv-4.5.4.tgz", + "integrity": "sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==", + "dev": true, + "license": "MIT", + "dependencies": { + "json-buffer": "3.0.1" + } + }, + "node_modules/langsmith": { + "version": "0.3.87", + "resolved": "https://registry.npmjs.org/langsmith/-/langsmith-0.3.87.tgz", + "integrity": "sha512-XXR1+9INH8YX96FKWc5tie0QixWz6tOqAsAKfcJyPkE0xPep+NDz0IQLR32q4bn10QK3LqD2HN6T3n6z1YLW7Q==", + "license": "MIT", + "dependencies": { + "@types/uuid": "^10.0.0", + "chalk": "^4.1.2", + "console-table-printer": "^2.12.1", + "p-queue": "^6.6.2", + "semver": "^7.6.3", + "uuid": "^10.0.0" + }, + "peerDependencies": { + "@opentelemetry/api": "*", + "@opentelemetry/exporter-trace-otlp-proto": "*", + "@opentelemetry/sdk-trace-base": "*", + "openai": "*" + }, + "peerDependenciesMeta": { + "@opentelemetry/api": { + "optional": true + }, + "@opentelemetry/exporter-trace-otlp-proto": { + "optional": true + }, + "@opentelemetry/sdk-trace-base": { + "optional": true + }, + "openai": { + "optional": true + } + } + }, + "node_modules/levn": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/levn/-/levn-0.4.1.tgz", + "integrity": "sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "prelude-ls": "^1.2.1", + "type-check": "~0.4.0" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/locate-path": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz", + "integrity": "sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==", + "dev": true, + "license": "MIT", + "dependencies": { + "p-locate": "^5.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/lodash.merge": { + "version": "4.6.2", + "resolved": "https://registry.npmjs.org/lodash.merge/-/lodash.merge-4.6.2.tgz", + "integrity": "sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/loupe": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/loupe/-/loupe-3.2.1.tgz", + "integrity": "sha512-CdzqowRJCeLU72bHvWqwRBBlLcMEtIvGrlvef74kMnV2AolS9Y8xUv1I0U/MNAWMhBlKIoyuEgoJ0t/bbwHbLQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/lru-cache": { + "version": "10.4.3", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-10.4.3.tgz", + "integrity": "sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ==", + "dev": true, + "license": "ISC" + }, + "node_modules/magic-string": { + "version": "0.30.21", + "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.21.tgz", + "integrity": "sha512-vd2F4YUyEXKGcLHoq+TEyCjxueSeHnFxyyjNp80yg0XV4vUhnDer/lvvlqM/arB5bXQN5K2/3oinyCRyx8T2CQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/sourcemap-codec": "^1.5.5" + } + }, + "node_modules/magicast": { + "version": "0.3.5", + "resolved": "https://registry.npmjs.org/magicast/-/magicast-0.3.5.tgz", + "integrity": "sha512-L0WhttDl+2BOsybvEOLK7fW3UA0OQ0IQ2d6Zl2x/a6vVRs3bAY0ECOSHHeL5jD+SbOpOCUEi0y1DgHEn9Qn1AQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/parser": "^7.25.4", + "@babel/types": "^7.25.4", + "source-map-js": "^1.2.0" + } + }, + "node_modules/make-dir": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/make-dir/-/make-dir-4.0.0.tgz", + "integrity": "sha512-hXdUTZYIVOt1Ex//jAQi+wTZZpUpwBj/0QsOzqegb3rGMMeJiSEu5xLHnYfBrRV4RH2+OCSOO95Is/7x1WJ4bw==", + "dev": true, + "license": "MIT", + "dependencies": { + "semver": "^7.5.3" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/math-intrinsics": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", + "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/mime-db": { + "version": "1.52.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", + "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mime-types": { + "version": "2.1.35", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", + "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", + "license": "MIT", + "dependencies": { + "mime-db": "1.52.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/minipass": { + "version": "7.1.2", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.2.tgz", + "integrity": "sha512-qOOzS1cBTWYF4BH8fVePDBOO9iptMnGUEZwNc/cMWnTV2nVLZ7VoNWEPHkYczZA0pdoA7dl6e7FL659nX9S2aw==", + "dev": true, + "license": "ISC", + "engines": { + "node": ">=16 || 14 >=14.17" + } + }, + "node_modules/mrmime": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/mrmime/-/mrmime-2.0.1.tgz", + "integrity": "sha512-Y3wQdFg2Va6etvQ5I82yUhGdsKrcYox6p7FfL1LbK2J4V01F9TGlepTIhnK24t7koZibmg82KGglhA1XK5IsLQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + } + }, + "node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "license": "MIT" + }, + "node_modules/mustache": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/mustache/-/mustache-4.2.0.tgz", + "integrity": "sha512-71ippSywq5Yb7/tVYyGbkBggbU8H3u5Rz56fH60jGFgr8uHwxs+aSKeqmluIVzM0m0kB7xQjKS6qPfd0b2ZoqQ==", + "license": "MIT", + "bin": { + "mustache": "bin/mustache" + } + }, + "node_modules/nanoid": { + "version": "3.3.11", + "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.11.tgz", + "integrity": "sha512-N8SpfPUnUp1bK+PMYW8qSWdl9U+wwNWI4QKxOYDy9JAro3WMX7p2OeVRF9v+347pnakNevPmiHhNmZ2HbFA76w==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "bin": { + "nanoid": "bin/nanoid.cjs" + }, + "engines": { + "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" + } + }, + "node_modules/natural-compare": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz", + "integrity": "sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==", + "dev": true, + "license": "MIT" + }, + "node_modules/node-domexception": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/node-domexception/-/node-domexception-1.0.0.tgz", + "integrity": "sha512-/jKZoMpw0F8GRwl4/eLROPA3cfcXtLApP0QzLmUT/HuPCZWyB7IY9ZrMeKw2O/nFIqPQB3PVM9aYm0F312AXDQ==", + "deprecated": "Use your platform's native DOMException instead", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/jimmywarting" + }, + { + "type": "github", + "url": "https://paypal.me/jimmywarting" + } + ], + "license": "MIT", + "engines": { + "node": ">=10.5.0" + } + }, + "node_modules/node-fetch": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.7.0.tgz", + "integrity": "sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==", + "license": "MIT", + "dependencies": { + "whatwg-url": "^5.0.0" + }, + "engines": { + "node": "4.x || >=6.0.0" + }, + "peerDependencies": { + "encoding": "^0.1.0" + }, + "peerDependenciesMeta": { + "encoding": { + "optional": true + } + } + }, + "node_modules/openai": { + "version": "6.15.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-6.15.0.tgz", + "integrity": "sha512-F1Lvs5BoVvmZtzkUEVyh8mDQPPFolq4F+xdsx/DO8Hee8YF3IGAlZqUIsF+DVGhqf4aU0a3bTghsxB6OIsRy1g==", + "license": "Apache-2.0", + "bin": { + "openai": "bin/cli" + }, + "peerDependencies": { + "ws": "^8.18.0", + "zod": "^3.25 || ^4.0" + }, + "peerDependenciesMeta": { + "ws": { + "optional": true + }, + "zod": { + "optional": true + } + } + }, + "node_modules/optionator": { + "version": "0.9.4", + "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz", + "integrity": "sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==", + "dev": true, + "license": "MIT", + "dependencies": { + "deep-is": "^0.1.3", + "fast-levenshtein": "^2.0.6", + "levn": "^0.4.1", + "prelude-ls": "^1.2.1", + "type-check": "^0.4.0", + "word-wrap": "^1.2.5" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/p-finally": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/p-finally/-/p-finally-1.0.0.tgz", + "integrity": "sha512-LICb2p9CB7FS+0eR1oqWnHhp0FljGLZCWBE9aix0Uye9W8LTQPwMTYVGWQWIw9RdQiDg4+epXQODwIYJtSJaow==", + "license": "MIT", + "engines": { + "node": ">=4" + } + }, + "node_modules/p-limit": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", + "integrity": "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "yocto-queue": "^0.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-locate": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/p-locate/-/p-locate-5.0.0.tgz", + "integrity": "sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==", + "dev": true, + "license": "MIT", + "dependencies": { + "p-limit": "^3.0.2" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-queue": { + "version": "6.6.2", + "resolved": "https://registry.npmjs.org/p-queue/-/p-queue-6.6.2.tgz", + "integrity": "sha512-RwFpb72c/BhQLEXIZ5K2e+AhgNVmIejGlTgiB9MzZ0e93GRvqZ7uSi0dvRF7/XIXDeNkra2fNHBxTyPDGySpjQ==", + "license": "MIT", + "dependencies": { + "eventemitter3": "^4.0.4", + "p-timeout": "^3.2.0" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-retry": { + "version": "4.6.2", + "resolved": "https://registry.npmjs.org/p-retry/-/p-retry-4.6.2.tgz", + "integrity": "sha512-312Id396EbJdvRONlngUx0NydfrIQ5lsYu0znKVUzVvArzEIt08V1qhtyESbGVd1FGX7UKtiFp5uwKZdM8wIuQ==", + "license": "MIT", + "dependencies": { + "@types/retry": "0.12.0", + "retry": "^0.13.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/p-timeout": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/p-timeout/-/p-timeout-3.2.0.tgz", + "integrity": "sha512-rhIwUycgwwKcP9yTOOFK/AKsAopjjCakVqLHePO3CC6Mir1Z99xT+R63jZxAT5lFZLa2inS5h+ZS2GvR99/FBg==", + "license": "MIT", + "dependencies": { + "p-finally": "^1.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/package-json-from-dist": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/package-json-from-dist/-/package-json-from-dist-1.0.1.tgz", + "integrity": "sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==", + "dev": true, + "license": "BlueOak-1.0.0" + }, + "node_modules/parent-module": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", + "integrity": "sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==", + "dev": true, + "license": "MIT", + "dependencies": { + "callsites": "^3.0.0" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/path-exists": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", + "integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/path-scurry": { + "version": "1.11.1", + "resolved": "https://registry.npmjs.org/path-scurry/-/path-scurry-1.11.1.tgz", + "integrity": "sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA==", + "dev": true, + "license": "BlueOak-1.0.0", + "dependencies": { + "lru-cache": "^10.2.0", + "minipass": "^5.0.0 || ^6.0.2 || ^7.0.0" + }, + "engines": { + "node": ">=16 || 14 >=14.18" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/pathe": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/pathe/-/pathe-1.1.2.tgz", + "integrity": "sha512-whLdWMYL2TwI08hn8/ZqAbrVemu0LNaNNJZX73O6qaIdCTfXutsLhMkjdENX0qhsQ9uIimo4/aQOmXkoon2nDQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/pathval": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/pathval/-/pathval-2.0.1.tgz", + "integrity": "sha512-//nshmD55c46FuFw26xV/xFAaB5HF9Xdap7HJBBnrKdAd6/GxDBaNA1870O79+9ueg61cZLSVc+OaFlfmObYVQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 14.16" + } + }, + "node_modules/picocolors": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", + "dev": true, + "license": "ISC" + }, + "node_modules/picomatch": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", + "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/postcss": { + "version": "8.5.6", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.6.tgz", + "integrity": "sha512-3Ybi1tAuwAP9s0r1UQ2J4n5Y0G05bJkpUIO0/bI9MhwmD70S5aTWbXGBwxHrelT+XM1k6dM0pk+SwNkpTRN7Pg==", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/postcss" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "nanoid": "^3.3.11", + "picocolors": "^1.1.1", + "source-map-js": "^1.2.1" + }, + "engines": { + "node": "^10 || ^12 || >=14" + } + }, + "node_modules/prelude-ls": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz", + "integrity": "sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/punycode": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", + "integrity": "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/resolve-from": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", + "integrity": "sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=4" + } + }, + "node_modules/retry": { + "version": "0.13.1", + "resolved": "https://registry.npmjs.org/retry/-/retry-0.13.1.tgz", + "integrity": "sha512-XQBQ3I8W1Cge0Seh+6gjj03LbmRFWuoszgK9ooCpwYIrhhoO80pfq4cUkU5DkknwfOfFteRwlZ56PYOGYyFWdg==", + "license": "MIT", + "engines": { + "node": ">= 4" + } + }, + "node_modules/rollup": { + "version": "4.55.1", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.55.1.tgz", + "integrity": "sha512-wDv/Ht1BNHB4upNbK74s9usvl7hObDnvVzknxqY/E/O3X6rW1U1rV1aENEfJ54eFZDTNo7zv1f5N4edCluH7+A==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "1.0.8" + }, + "bin": { + "rollup": "dist/bin/rollup" + }, + "engines": { + "node": ">=18.0.0", + "npm": ">=8.0.0" + }, + "optionalDependencies": { + "@rollup/rollup-android-arm-eabi": "4.55.1", + "@rollup/rollup-android-arm64": "4.55.1", + "@rollup/rollup-darwin-arm64": "4.55.1", + "@rollup/rollup-darwin-x64": "4.55.1", + "@rollup/rollup-freebsd-arm64": "4.55.1", + "@rollup/rollup-freebsd-x64": "4.55.1", + "@rollup/rollup-linux-arm-gnueabihf": "4.55.1", + "@rollup/rollup-linux-arm-musleabihf": "4.55.1", + "@rollup/rollup-linux-arm64-gnu": "4.55.1", + "@rollup/rollup-linux-arm64-musl": "4.55.1", + "@rollup/rollup-linux-loong64-gnu": "4.55.1", + "@rollup/rollup-linux-loong64-musl": "4.55.1", + "@rollup/rollup-linux-ppc64-gnu": "4.55.1", + "@rollup/rollup-linux-ppc64-musl": "4.55.1", + "@rollup/rollup-linux-riscv64-gnu": "4.55.1", + "@rollup/rollup-linux-riscv64-musl": "4.55.1", + "@rollup/rollup-linux-s390x-gnu": "4.55.1", + "@rollup/rollup-linux-x64-gnu": "4.55.1", + "@rollup/rollup-linux-x64-musl": "4.55.1", + "@rollup/rollup-openbsd-x64": "4.55.1", + "@rollup/rollup-openharmony-arm64": "4.55.1", + "@rollup/rollup-win32-arm64-msvc": "4.55.1", + "@rollup/rollup-win32-ia32-msvc": "4.55.1", + "@rollup/rollup-win32-x64-gnu": "4.55.1", + "@rollup/rollup-win32-x64-msvc": "4.55.1", + "fsevents": "~2.3.2" + } + }, + "node_modules/semver": { + "version": "7.7.3", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.3.tgz", + "integrity": "sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==", + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "dev": true, + "license": "MIT", + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/siginfo": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/siginfo/-/siginfo-2.0.0.tgz", + "integrity": "sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==", + "dev": true, + "license": "ISC" + }, + "node_modules/signal-exit": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz", + "integrity": "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==", + "dev": true, + "license": "ISC", + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/simple-wcswidth": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/simple-wcswidth/-/simple-wcswidth-1.1.2.tgz", + "integrity": "sha512-j7piyCjAeTDSjzTSQ7DokZtMNwNlEAyxqSZeCS+CXH7fJ4jx3FuJ/mTW3mE+6JLs4VJBbcll0Kjn+KXI5t21Iw==", + "license": "MIT" + }, + "node_modules/sirv": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/sirv/-/sirv-3.0.2.tgz", + "integrity": "sha512-2wcC/oGxHis/BoHkkPwldgiPSYcpZK3JU28WoMVv55yHJgcZ8rlXvuG9iZggz+sU1d4bRgIGASwyWqjxu3FM0g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@polka/url": "^1.0.0-next.24", + "mrmime": "^2.0.0", + "totalist": "^3.0.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/source-map-js": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", + "integrity": "sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==", + "dev": true, + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/stackback": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/stackback/-/stackback-0.0.2.tgz", + "integrity": "sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==", + "dev": true, + "license": "MIT" + }, + "node_modules/std-env": { + "version": "3.10.0", + "resolved": "https://registry.npmjs.org/std-env/-/std-env-3.10.0.tgz", + "integrity": "sha512-5GS12FdOZNliM5mAOxFRg7Ir0pWz8MdpYm6AY6VPkGpbA7ZzmbzNcBJQ0GPvvyWgcY7QAhCgf9Uy89I03faLkg==", + "dev": true, + "license": "MIT" + }, + "node_modules/string-width": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-5.1.2.tgz", + "integrity": "sha512-HnLOCR3vjcY8beoNLtcjZ5/nxn2afmME6lhrDrebokqMap+XbeW8n9TXpPDOqdGK5qcI3oT0GKTW6wC7EMiVqA==", + "dev": true, + "license": "MIT", + "dependencies": { + "eastasianwidth": "^0.2.0", + "emoji-regex": "^9.2.2", + "strip-ansi": "^7.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/string-width-cjs": { + "name": "string-width", + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "dev": true, + "license": "MIT", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/string-width-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/string-width-cjs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "dev": true, + "license": "MIT" + }, + "node_modules/string-width-cjs/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-ansi": { + "version": "7.1.2", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.1.2.tgz", + "integrity": "sha512-gmBGslpoQJtgnMAvOVqGZpEz9dyoKTCzy2nfz/n8aIFhN/jCE/rCmcxabB6jOOHV+0WNnylOxaxBQPSvcWklhA==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^6.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/strip-ansi?sponsor=1" + } + }, + "node_modules/strip-ansi-cjs": { + "name": "strip-ansi", + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-ansi-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-json-comments": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-3.1.1.tgz", + "integrity": "sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/strnum": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/strnum/-/strnum-1.1.2.tgz", + "integrity": "sha512-vrN+B7DBIoTTZjnPNewwhx6cBA/H+IS7rfW68n7XxC1y7uoiGQBxaKzqucGUgavX15dJgiGztLJ8vxuEzwqBdA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/NaturalIntelligence" + } + ], + "license": "MIT" + }, + "node_modules/supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "license": "MIT", + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/test-exclude": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/test-exclude/-/test-exclude-7.0.1.tgz", + "integrity": "sha512-pFYqmTw68LXVjeWJMST4+borgQP2AyMNbg1BpZh9LbyhUeNkeaPF9gzfPGUAnSMV3qPYdWUwDIjjCLiSDOl7vg==", + "dev": true, + "license": "ISC", + "dependencies": { + "@istanbuljs/schema": "^0.1.2", + "glob": "^10.4.1", + "minimatch": "^9.0.4" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/tinybench": { + "version": "2.9.0", + "resolved": "https://registry.npmjs.org/tinybench/-/tinybench-2.9.0.tgz", + "integrity": "sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg==", + "dev": true, + "license": "MIT" + }, + "node_modules/tinyexec": { + "version": "0.3.2", + "resolved": "https://registry.npmjs.org/tinyexec/-/tinyexec-0.3.2.tgz", + "integrity": "sha512-KQQR9yN7R5+OSwaK0XQoj22pwHoTlgYqmUscPYoknOoWCWfj/5/ABTMRi69FrKU5ffPVh5QcFikpWJI/P1ocHA==", + "dev": true, + "license": "MIT" + }, + "node_modules/tinyglobby": { + "version": "0.2.15", + "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz", + "integrity": "sha512-j2Zq4NyQYG5XMST4cbs02Ak8iJUdxRM0XI5QyxXuZOzKOINmWurp3smXu3y5wDcJrptwpSjgXHzIQxR0omXljQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "fdir": "^6.5.0", + "picomatch": "^4.0.3" + }, + "engines": { + "node": ">=12.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/SuperchupuDev" + } + }, + "node_modules/tinypool": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/tinypool/-/tinypool-1.1.1.tgz", + "integrity": "sha512-Zba82s87IFq9A9XmjiX5uZA/ARWDrB03OHlq+Vw1fSdt0I+4/Kutwy8BP4Y/y/aORMo61FQ0vIb5j44vSo5Pkg==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.0.0 || >=20.0.0" + } + }, + "node_modules/tinyrainbow": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/tinyrainbow/-/tinyrainbow-1.2.0.tgz", + "integrity": "sha512-weEDEq7Z5eTHPDh4xjX789+fHfF+P8boiFB+0vbWzpbnbsEr/GRaohi/uMKxg8RZMXnl1ItAi/IUHWMsjDV7kQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/tinyspy": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/tinyspy/-/tinyspy-3.0.2.tgz", + "integrity": "sha512-n1cw8k1k0x4pgA2+9XrOkFydTerNcJ1zWCO5Nn9scWHTD+5tp8dghT2x1uduQePZTZgd3Tupf+x9BxJjeJi77Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/totalist": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/totalist/-/totalist-3.0.1.tgz", + "integrity": "sha512-sf4i37nQ2LBx4m3wB74y+ubopq6W/dIzXg0FDGjsYnZHVa1Da8FH853wlL2gtUhg+xJXjfk3kUZS3BRoQeoQBQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/tr46": { + "version": "0.0.3", + "resolved": "https://registry.npmjs.org/tr46/-/tr46-0.0.3.tgz", + "integrity": "sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==", + "license": "MIT" + }, + "node_modules/ts-algebra": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/ts-algebra/-/ts-algebra-2.0.0.tgz", + "integrity": "sha512-FPAhNPFMrkwz76P7cdjdmiShwMynZYN6SgOujD1urY4oNm80Ou9oMdmbR45LotcKOXoy7wSmHkRFE6Mxbrhefw==", + "license": "MIT" + }, + "node_modules/ts-api-utils": { + "version": "2.4.0", + "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-2.4.0.tgz", + "integrity": "sha512-3TaVTaAv2gTiMB35i3FiGJaRfwb3Pyn/j3m/bfAvGe8FB7CF6u+LMYqYlDh7reQf7UNvoTvdfAqHGmPGOSsPmA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18.12" + }, + "peerDependencies": { + "typescript": ">=4.8.4" + } + }, + "node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, + "node_modules/type-check": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", + "integrity": "sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==", + "dev": true, + "license": "MIT", + "dependencies": { + "prelude-ls": "^1.2.1" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/typescript": { + "version": "5.9.3", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.9.3.tgz", + "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/undici-types": { + "version": "6.21.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.21.0.tgz", + "integrity": "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==", + "license": "MIT" + }, + "node_modules/uri-js": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", + "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "punycode": "^2.1.0" + } + }, + "node_modules/uuid": { + "version": "10.0.0", + "resolved": "https://registry.npmjs.org/uuid/-/uuid-10.0.0.tgz", + "integrity": "sha512-8XkAphELsDnEGrDxUOHB3RGvXz6TeuYSGEZBOjtTtPm2lwhGBjLgOzLHB63IUWfBpNucQjND6d3AOudO+H3RWQ==", + "funding": [ + "https://github.com/sponsors/broofa", + "https://github.com/sponsors/ctavan" + ], + "license": "MIT", + "bin": { + "uuid": "dist/bin/uuid" + } + }, + "node_modules/vite": { + "version": "5.4.21", + "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.21.tgz", + "integrity": "sha512-o5a9xKjbtuhY6Bi5S3+HvbRERmouabWbyUcpXXUA1u+GNUKoROi9byOJ8M0nHbHYHkYICiMlqxkg1KkYmm25Sw==", + "dev": true, + "license": "MIT", + "dependencies": { + "esbuild": "^0.21.3", + "postcss": "^8.4.43", + "rollup": "^4.20.0" + }, + "bin": { + "vite": "bin/vite.js" + }, + "engines": { + "node": "^18.0.0 || >=20.0.0" + }, + "funding": { + "url": "https://github.com/vitejs/vite?sponsor=1" + }, + "optionalDependencies": { + "fsevents": "~2.3.3" + }, + "peerDependencies": { + "@types/node": "^18.0.0 || >=20.0.0", + "less": "*", + "lightningcss": "^1.21.0", + "sass": "*", + "sass-embedded": "*", + "stylus": "*", + "sugarss": "*", + "terser": "^5.4.0" + }, + "peerDependenciesMeta": { + "@types/node": { + "optional": true + }, + "less": { + "optional": true + }, + "lightningcss": { + "optional": true + }, + "sass": { + "optional": true + }, + "sass-embedded": { + "optional": true + }, + "stylus": { + "optional": true + }, + "sugarss": { + "optional": true + }, + "terser": { + "optional": true + } + } + }, + "node_modules/vite-node": { + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/vite-node/-/vite-node-2.1.9.tgz", + "integrity": "sha512-AM9aQ/IPrW/6ENLQg3AGY4K1N2TGZdR5e4gu/MmmR2xR3Ll1+dib+nook92g4TV3PXVyeyxdWwtaCAiUL0hMxA==", + "dev": true, + "license": "MIT", + "dependencies": { + "cac": "^6.7.14", + "debug": "^4.3.7", + "es-module-lexer": "^1.5.4", + "pathe": "^1.1.2", + "vite": "^5.0.0" + }, + "bin": { + "vite-node": "vite-node.mjs" + }, + "engines": { + "node": "^18.0.0 || >=20.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/vitest": { + "version": "2.1.9", + "resolved": "https://registry.npmjs.org/vitest/-/vitest-2.1.9.tgz", + "integrity": "sha512-MSmPM9REYqDGBI8439mA4mWhV5sKmDlBKWIYbA3lRb2PTHACE0mgKwA8yQ2xq9vxDTuk4iPrECBAEW2aoFXY0Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/expect": "2.1.9", + "@vitest/mocker": "2.1.9", + "@vitest/pretty-format": "^2.1.9", + "@vitest/runner": "2.1.9", + "@vitest/snapshot": "2.1.9", + "@vitest/spy": "2.1.9", + "@vitest/utils": "2.1.9", + "chai": "^5.1.2", + "debug": "^4.3.7", + "expect-type": "^1.1.0", + "magic-string": "^0.30.12", + "pathe": "^1.1.2", + "std-env": "^3.8.0", + "tinybench": "^2.9.0", + "tinyexec": "^0.3.1", + "tinypool": "^1.0.1", + "tinyrainbow": "^1.2.0", + "vite": "^5.0.0", + "vite-node": "2.1.9", + "why-is-node-running": "^2.3.0" + }, + "bin": { + "vitest": "vitest.mjs" + }, + "engines": { + "node": "^18.0.0 || >=20.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "@edge-runtime/vm": "*", + "@types/node": "^18.0.0 || >=20.0.0", + "@vitest/browser": "2.1.9", + "@vitest/ui": "2.1.9", + "happy-dom": "*", + "jsdom": "*" + }, + "peerDependenciesMeta": { + "@edge-runtime/vm": { + "optional": true + }, + "@types/node": { + "optional": true + }, + "@vitest/browser": { + "optional": true + }, + "@vitest/ui": { + "optional": true + }, + "happy-dom": { + "optional": true + }, + "jsdom": { + "optional": true + } + } + }, + "node_modules/web-streams-polyfill": { + "version": "4.0.0-beta.3", + "resolved": "https://registry.npmjs.org/web-streams-polyfill/-/web-streams-polyfill-4.0.0-beta.3.tgz", + "integrity": "sha512-QW95TCTaHmsYfHDybGMwO5IJIM93I/6vTRk+daHTWFPhwh+C8Cg7j7XyKrwrj8Ib6vYXe0ocYNrmzY4xAAN6ug==", + "license": "MIT", + "engines": { + "node": ">= 14" + } + }, + "node_modules/webidl-conversions": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-3.0.1.tgz", + "integrity": "sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==", + "license": "BSD-2-Clause" + }, + "node_modules/whatwg-url": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/whatwg-url/-/whatwg-url-5.0.0.tgz", + "integrity": "sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw==", + "license": "MIT", + "dependencies": { + "tr46": "~0.0.3", + "webidl-conversions": "^3.0.0" + } + }, + "node_modules/which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "dev": true, + "license": "ISC", + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/why-is-node-running": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/why-is-node-running/-/why-is-node-running-2.3.0.tgz", + "integrity": "sha512-hUrmaWBdVDcxvYqnyh09zunKzROWjbZTiNy8dBEjkS7ehEDQibXJ7XvlmtbwuTclUiIyN+CyXQD4Vmko8fNm8w==", + "dev": true, + "license": "MIT", + "dependencies": { + "siginfo": "^2.0.0", + "stackback": "0.0.2" + }, + "bin": { + "why-is-node-running": "cli.js" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/word-wrap": { + "version": "1.2.5", + "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.5.tgz", + "integrity": "sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/wrap-ansi": { + "version": "8.1.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-8.1.0.tgz", + "integrity": "sha512-si7QWI6zUMq56bESFvagtmzMdGOtoxfR+Sez11Mobfc7tm+VkUckk9bW2UeffTGVUbOksxmSw0AA2gs8g71NCQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-styles": "^6.1.0", + "string-width": "^5.0.1", + "strip-ansi": "^7.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/wrap-ansi-cjs": { + "name": "wrap-ansi", + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz", + "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.0.0", + "string-width": "^4.1.0", + "strip-ansi": "^6.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "dev": true, + "license": "MIT", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "dev": true, + "license": "MIT" + }, + "node_modules/wrap-ansi-cjs/node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "dev": true, + "license": "MIT", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/wrap-ansi/node_modules/ansi-styles": { + "version": "6.2.3", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-6.2.3.tgz", + "integrity": "sha512-4Dj6M28JB+oAH8kFkTLUo+a2jwOFkuqb3yucU0CANcRRUbxS0cP0nZYCGjcc3BNXwRIsUVmDGgzawme7zvJHvg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/yaml": { + "version": "2.8.2", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.2.tgz", + "integrity": "sha512-mplynKqc1C2hTVYxd0PU2xQAc22TI1vShAYGksCCfxbn/dFwnHTNi1bvYsBTkhdUNtGIf5xNOg938rrSSYvS9A==", + "license": "ISC", + "bin": { + "yaml": "bin.mjs" + }, + "engines": { + "node": ">= 14.6" + }, + "funding": { + "url": "https://github.com/sponsors/eemeli" + } + }, + "node_modules/yocto-queue": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", + "integrity": "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/zod": { + "version": "3.25.76", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz", + "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + }, + "node_modules/zod-to-json-schema": { + "version": "3.25.1", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.25.1.tgz", + "integrity": "sha512-pM/SU9d3YAggzi6MtR4h7ruuQlqKtad8e9S0fmxcMi+ueAK5Korys/aWcV9LIIHTVbj01NdzxcnXSN+O74ZIVA==", + "license": "ISC", + "peerDependencies": { + "zod": "^3.25 || ^4" + } + } + } +} diff --git a/tests/integrations/typescript/package.json b/tests/integrations/typescript/package.json new file mode 100644 index 0000000000..50edcf192e --- /dev/null +++ b/tests/integrations/typescript/package.json @@ -0,0 +1,41 @@ +{ + "name": "bifrost-integration-tests-typescript", + "version": "0.1.0", + "description": "TypeScript integration tests for Bifrost AI gateway", + "type": "module", + "scripts": { + "test": "vitest run", + "test:watch": "vitest", + "test:coverage": "vitest run --coverage", + "test:ui": "vitest --ui", + "typecheck": "tsc --noEmit", + "lint": "eslint src tests --ext .ts" + }, + "devDependencies": { + "@types/node": "^22.10.0", + "@typescript-eslint/eslint-plugin": "^8.0.0", + "@typescript-eslint/parser": "^8.0.0", + "@vitest/coverage-v8": "^2.1.0", + "@vitest/ui": "^2.1.0", + "dotenv": "^16.4.0", + "eslint": "^9.0.0", + "typescript": "^5.7.0", + "vitest": "^2.1.0" + }, + "dependencies": { + "@anthropic-ai/sdk": "^0.71.2", + "@aws-sdk/client-bedrock": "^3.966.0", + "@aws-sdk/client-bedrock-runtime": "^3.965.0", + "@google/generative-ai": "^0.24.1", + "@langchain/anthropic": "^0.3.0", + "@langchain/core": "^0.3.0", + "@langchain/google-genai": "^0.1.0", + "@langchain/openai": "^0.3.0", + "openai": "^6.15.0", + "yaml": "^2.6.0", + "zod": "^3.24.0" + }, + "engines": { + "node": ">=20.0.0" + } +} diff --git a/tests/integrations/typescript/src/utils/common.ts b/tests/integrations/typescript/src/utils/common.ts new file mode 100644 index 0000000000..186e676b39 --- /dev/null +++ b/tests/integrations/typescript/src/utils/common.ts @@ -0,0 +1,949 @@ +/** + * Common utilities and test data for all integration tests. + * This module contains shared functions, test data, and assertions + * that can be used across all integration-specific test files. + */ + +import { expect } from 'vitest' + +// ============================================================================ +// Test Configuration +// ============================================================================ + +export interface Config { + timeout: number + maxRetries: number + debug: boolean +} + +export const defaultConfig: Config = { + timeout: 30, + maxRetries: 3, + debug: false, +} + +// ============================================================================ +// Image Test Data +// ============================================================================ + +export const IMAGE_URL = + 'https://pub-cdead89c2f004d8f963fd34010c479d0.r2.dev/Gfp-wisconsin-madison-the-nature-boardwalk.jpg' +export const IMAGE_URL_SECONDARY = 'https://goo.gle/instrument-img' + +// Small test image as base64 (1x1 pixel red PNG) +export const BASE64_IMAGE = + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==' + +// Base64 PDF test data +export const FILE_DATA_BASE64 = + 'JVBERi0xLjcKCjEgMCBvYmogICUgZW50cnkgcG9pbnQKPDwKICAvVHlwZSAvQ2F0YWxvZwogIC' + + '9QYWdlcyAyIDAgUgo+PgplbmRvYmoKCjIgMCBvYmoKPDwKICAvVHlwZSAvUGFnZXwKICAvTWV' + + 'kaWFCb3ggWyAwIDAgMjAwIDIwMCBdCiAgL0NvdW50IDEKICAvS2lkcyBbIDMgMCBSIF0KPj4K' + + 'ZW5kb2JqCgozIDAgb2JqCjw8CiAgL1R5cGUgL1BhZ2UKICAvUGFyZW50IDIgMCBSCiAgL1Jlc' + + '291cmNlcyA8PAogICAgL0ZvbnQgPDwKICAgICAgL0YxIDQgMCBSCj4+CiAgPj4KICAvQ29udG' + + 'VudHMgNSAwIFIKPj4KZW5kb2JqCgo0IDAgb2JqCjw8CiAgL1R5cGUgL0ZvbnQKICAvU3VidHl' + + 'wZSAvVHlwZTEKICAvQmFzZUZvbnQgL1RpbWVzLVJvbWFuCj4+CmVuZG9iagoKNSAwIG9iago8' + + 'PAogIC9MZW5ndGggNDQKPj4Kc3RyZWFtCkJUCjcwIDUwIFRECi9GMSAxMiBUZgooSGVsbG8gV' + + '29ybGQhKSBUagpFVAplbmRzdHJlYW0KZW5kb2JqCgp4cmVmCjAgNgowMDAwMDAwMDAwIDY1NT' + + 'M1IGYgCjAwMDAwMDAwMTAgMDAwMDAgbiAKMDAwMDAwMDA2MCAwMDAwMCBuIAowMDAwMDAwMTU' + + '3IDAwMDAwIG4gCjAwMDAwMDAyNTUgMDAwMDAgbiAKMDAwMDAwMDM1MyAwMDAwMCBuIAp0cmFp' + + 'bGVyCjw8CiAgL1NpemUgNgogIC9Sb290IDEgMCBSCj4+CnN0YXJ0eHJlZgo0NDkKJSVFT0YK' + +// ============================================================================ +// Common Test Messages +// ============================================================================ + +export interface ChatMessage { + role: 'user' | 'assistant' | 'system' + content: string | ContentPart[] +} + +export interface ContentPart { + type: 'text' | 'image_url' + text?: string + image_url?: { + url: string + } +} + +export const SIMPLE_CHAT_MESSAGES: ChatMessage[] = [ + { role: 'user', content: 'Hello! How are you today?' }, +] + +export const MULTI_TURN_MESSAGES: ChatMessage[] = [ + { role: 'user', content: "What's the capital of France?" }, + { role: 'assistant', content: 'The capital of France is Paris.' }, + { role: 'user', content: "What's the population of that city?" }, +] + +export const STREAMING_CHAT_MESSAGES: ChatMessage[] = [ + { role: 'user', content: 'Count from 1 to 5, one number per line.' }, +] + +export const INVALID_ROLE_MESSAGES = [{ role: 'invalid_role', content: 'This should fail' }] + +// ============================================================================ +// Tool Definitions +// ============================================================================ + +export interface ToolDefinition { + name: string + description: string + parameters: { + type: 'object' + properties: Record + required?: string[] + } +} + +export const WEATHER_TOOL: ToolDefinition = { + name: 'get_weather', + description: 'Get the current weather for a location', + parameters: { + type: 'object', + properties: { + location: { + type: 'string', + description: 'The city and state, e.g. San Francisco, CA', + }, + unit: { + type: 'string', + enum: ['celsius', 'fahrenheit'], + description: 'The temperature unit', + }, + }, + required: ['location'], + }, +} + +export const CALCULATOR_TOOL: ToolDefinition = { + name: 'calculate', + description: 'Perform basic mathematical calculations', + parameters: { + type: 'object', + properties: { + expression: { + type: 'string', + description: "Mathematical expression to evaluate, e.g. '2 + 2'", + }, + }, + required: ['expression'], + }, +} + +export const SEARCH_TOOL: ToolDefinition = { + name: 'search_web', + description: 'Search the web for information', + parameters: { + type: 'object', + properties: { + query: { + type: 'string', + description: 'Search query', + }, + }, + required: ['query'], + }, +} + +export const ALL_TOOLS: ToolDefinition[] = [WEATHER_TOOL, CALCULATOR_TOOL, SEARCH_TOOL] + +// ============================================================================ +// Tool Call Test Messages +// ============================================================================ + +export const SINGLE_TOOL_CALL_MESSAGES: ChatMessage[] = [ + { role: 'user', content: "What's the weather like in New York?" }, +] + +export const MULTIPLE_TOOL_CALL_MESSAGES: ChatMessage[] = [ + { + role: 'user', + content: "What's the weather in New York and also calculate 25 * 4?", + }, +] + +export const STREAMING_TOOL_CALL_MESSAGES: ChatMessage[] = [ + { role: 'user', content: "What's the weather like in San Francisco?" }, +] + +// ============================================================================ +// Image Test Messages +// ============================================================================ + +export const IMAGE_URL_MESSAGES: ChatMessage[] = [ + { + role: 'user', + content: [ + { type: 'text', text: 'What do you see in this image? Describe it briefly.' }, + { type: 'image_url', image_url: { url: IMAGE_URL } }, + ], + }, +] + +export const IMAGE_BASE64_MESSAGES: ChatMessage[] = [ + { + role: 'user', + content: [ + { type: 'text', text: 'What color is this image?' }, + { type: 'image_url', image_url: { url: `data:image/png;base64,${BASE64_IMAGE}` } }, + ], + }, +] + +export const MULTIPLE_IMAGES_MESSAGES: ChatMessage[] = [ + { + role: 'user', + content: [ + { type: 'text', text: 'Compare these two images. What do you see?' }, + { type: 'image_url', image_url: { url: IMAGE_URL } }, + { type: 'image_url', image_url: { url: `data:image/png;base64,${BASE64_IMAGE}` } }, + ], + }, +] + +// ============================================================================ +// Complex End-to-End Test Messages +// ============================================================================ + +export const COMPLEX_E2E_MESSAGES: ChatMessage[] = [ + { role: 'user', content: "What's the weather in Paris and calculate 100 / 5?" }, +] + +// ============================================================================ +// Speech and Transcription Test Data +// ============================================================================ + +export const SPEECH_TEST_INPUT = 'Hello, this is a test of speech synthesis through Bifrost.' + +export const SPEECH_VOICES = ['alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'] as const +export type SpeechVoice = (typeof SPEECH_VOICES)[number] + +export const AUDIO_FORMATS = ['mp3', 'opus', 'aac', 'flac', 'wav', 'pcm'] as const +export type AudioFormat = (typeof AUDIO_FORMATS)[number] + +// ============================================================================ +// Embeddings Test Data +// ============================================================================ + +export const EMBEDDINGS_SINGLE_TEXT = 'The quick brown fox jumps over the lazy dog.' + +export const EMBEDDINGS_MULTIPLE_TEXTS = [ + 'The quick brown fox jumps over the lazy dog.', + 'A fast auburn canine leaps above a sleepy hound.', + 'Machine learning is transforming technology.', +] + +export const EMBEDDINGS_SIMILAR_TEXTS = [ + 'The cat sat on the mat.', + 'A feline rested on the rug.', +] + +export const EMBEDDINGS_DIFFERENT_TEXTS = [ + 'The weather is sunny today.', + 'Quantum physics explores subatomic particles.', +] + +export const EMBEDDINGS_LONG_TEXT = + 'This is a longer piece of text that contains multiple sentences. ' + + 'It is designed to test how embedding models handle longer inputs. ' + + 'The text continues with more content to ensure adequate length for testing purposes. ' + + 'We want to verify that the embedding generation works correctly with extended text.' + +// ============================================================================ +// Responses API Test Data +// ============================================================================ + +export const RESPONSES_SIMPLE_TEXT_INPUT = 'What is the capital of France?' + +export const RESPONSES_TEXT_WITH_SYSTEM = { + system: 'You are a helpful geography assistant.', + user: 'What is the capital of France?', +} + +export const RESPONSES_IMAGE_INPUT = { + text: 'What do you see in this image?', + imageUrl: IMAGE_URL, +} + +export const RESPONSES_TOOL_CALL_INPUT = "What's the weather like in London?" + +export const RESPONSES_STREAMING_INPUT = 'Count from 1 to 5.' + +export const RESPONSES_REASONING_INPUT = 'Explain step by step how to solve: What is 15% of 80?' + +// ============================================================================ +// Text Completion Test Data +// ============================================================================ + +export const TEXT_COMPLETION_SIMPLE_PROMPT = 'Once upon a time, in a land far away,' + +export const TEXT_COMPLETION_STREAMING_PROMPT = 'The quick brown fox' + +// ============================================================================ +// Input Tokens Test Data +// ============================================================================ + +export const INPUT_TOKENS_SIMPLE_TEXT = 'Hello, how are you?' + +export const INPUT_TOKENS_LONG_TEXT = + 'This is a longer piece of text that should result in more tokens being counted. ' + + 'It contains multiple sentences and various words to ensure accurate token counting.' + +export const INPUT_TOKENS_WITH_SYSTEM = { + system: 'You are a helpful assistant.', + user: 'What is 2 + 2?', +} + +// ============================================================================ +// Keyword Lists for Response Validation +// ============================================================================ + +export const WEATHER_KEYWORDS = [ + 'weather', + 'temperature', + 'degrees', + 'sunny', + 'cloudy', + 'rain', + 'forecast', + 'warm', + 'cold', + 'humid', +] + +export const LOCATION_KEYWORDS = ['new york', 'ny', 'nyc', 'city', 'manhattan'] + +export const COMPARISON_KEYWORDS = ['both', 'compare', 'similar', 'different', 'first', 'second'] + +// ============================================================================ +// API Key Utilities +// ============================================================================ + +const API_KEY_MAP: Record = { + openai: 'OPENAI_API_KEY', + anthropic: 'ANTHROPIC_API_KEY', + google: 'GEMINI_API_KEY', + gemini: 'GEMINI_API_KEY', + litellm: 'LITELLM_API_KEY', + bedrock: 'AWS_ACCESS_KEY_ID', + cohere: 'COHERE_API_KEY', + xai: 'XAI_API_KEY', +} + +export function getApiKey(integration: string): string { + const envVar = API_KEY_MAP[integration.toLowerCase()] + if (!envVar) { + throw new Error(`Unknown integration: ${integration}`) + } + + const apiKey = process.env[envVar] + if (!apiKey) { + throw new Error(`${envVar} environment variable not set`) + } + + return apiKey +} + +export function hasApiKey(integration: string): boolean { + try { + getApiKey(integration) + return true + } catch { + return false + } +} + +export function skipIfNoApiKey(integration: string): void { + if (!hasApiKey(integration)) { + const envVar = API_KEY_MAP[integration.toLowerCase()] + throw new Error(`Skipping: ${envVar} not set`) + } +} + +// ============================================================================ +// Response Content Extraction +// ============================================================================ + +export function getContentString(content: unknown): string { + if (typeof content === 'string') { + return content + } + + if (Array.isArray(content)) { + return content + .map((item) => { + if (typeof item === 'string') return item + if (item?.text) return item.text + if (item?.content) return item.content + return '' + }) + .join('') + } + + if (content && typeof content === 'object') { + const obj = content as Record + if (obj.text) return String(obj.text) + if (obj.content) return getContentString(obj.content) + } + + return '' +} + +// ============================================================================ +// Tool Call Extraction +// ============================================================================ + +export interface ExtractedToolCall { + name: string + arguments: Record +} + +export function extractToolCalls(response: unknown): ExtractedToolCall[] { + const toolCalls: ExtractedToolCall[] = [] + + // Handle OpenAI-style response + if (response && typeof response === 'object') { + const obj = response as Record + + // OpenAI format: response.choices[0].message.tool_calls + if (Array.isArray(obj.choices)) { + const choice = obj.choices[0] as Record + const message = choice?.message as Record + const calls = message?.tool_calls as Array> + + if (Array.isArray(calls)) { + for (const call of calls) { + const fn = call.function as Record + if (fn?.name) { + let args: Record = {} + try { + args = + typeof fn.arguments === 'string' + ? JSON.parse(fn.arguments) + : (fn.arguments as Record) || {} + } catch { + // Keep empty args + } + toolCalls.push({ + name: String(fn.name), + arguments: args, + }) + } + } + } + } + + // Direct tool_calls property + if (Array.isArray(obj.tool_calls)) { + for (const call of obj.tool_calls) { + const callObj = call as Record + const fn = callObj.function as Record + if (fn?.name) { + let args: Record = {} + try { + args = + typeof fn.arguments === 'string' + ? JSON.parse(fn.arguments) + : (fn.arguments as Record) || {} + } catch { + // Keep empty args + } + toolCalls.push({ + name: String(fn.name), + arguments: args, + }) + } + } + } + } + + return toolCalls +} + +// ============================================================================ +// Mock Tool Responses +// ============================================================================ + +export function mockToolResponse(toolName: string, args: Record): string { + switch (toolName) { + case 'get_weather': + return JSON.stringify({ + temperature: 72, + condition: 'sunny', + location: args.location || 'Unknown', + unit: args.unit || 'fahrenheit', + }) + case 'calculate': + try { + // Safe evaluation for simple math + const expr = String(args.expression || '0') + // Guardrails against pathological model output + if (expr.length > 200) { + return JSON.stringify({ error: 'Expression too long', expression: args.expression }) + } + const sanitized = expr.replace(/[^0-9+\-*/().% ]/g, '') + // Reject empty expressions after sanitization + if (sanitized.trim() === '') { + return JSON.stringify({ error: 'Empty expression', expression: args.expression }) + } + const result = Function(`"use strict"; return (${sanitized})`)() + return JSON.stringify({ result, expression: args.expression }) + } catch { + return JSON.stringify({ error: 'Invalid expression', expression: args.expression }) + } + case 'search_web': + return JSON.stringify({ + results: [ + { title: 'Sample Result 1', url: 'https://example.com/1' }, + { title: 'Sample Result 2', url: 'https://example.com/2' }, + ], + query: args.query, + }) + default: + return JSON.stringify({ status: 'ok', tool: toolName, args }) + } +} + +// ============================================================================ +// Assertion Helpers +// ============================================================================ + +export function assertValidChatResponse(response: unknown): void { + expect(response).toBeDefined() + expect(response).not.toBeNull() + + const obj = response as Record + + // OpenAI-style response + if (obj.choices) { + expect(Array.isArray(obj.choices)).toBe(true) + expect(obj.choices.length).toBeGreaterThan(0) + + const choice = (obj.choices as Array>)[0] + expect(choice.message).toBeDefined() + + const message = choice.message as Record + const content = getContentString(message.content) + + // Allow empty content if there are tool calls + if (!message.tool_calls) { + expect(content.length).toBeGreaterThan(0) + } + } + // Direct content response + else if (obj.content !== undefined) { + const content = getContentString(obj.content) + if (!obj.tool_calls) { + expect(content.length).toBeGreaterThan(0) + } + } +} + +export function assertHasToolCalls(response: unknown, expectedCount?: number): void { + const toolCalls = extractToolCalls(response) + expect(toolCalls.length).toBeGreaterThan(0) + + if (expectedCount !== undefined) { + expect(toolCalls.length).toBe(expectedCount) + } +} + +export function assertValidImageResponse(response: unknown): void { + assertValidChatResponse(response) + + const obj = response as Record + let content = '' + + if (obj.choices) { + const choice = (obj.choices as Array>)[0] + const message = choice.message as Record + content = getContentString(message.content) + } else if (obj.content !== undefined) { + content = getContentString(obj.content) + } + + // Image analysis responses should have meaningful content + expect(content.length).toBeGreaterThan(10) +} + +export function assertValidEmbeddingResponse(response: unknown): void { + expect(response).toBeDefined() + const obj = response as Record + + // OpenAI-style embedding response + if (obj.data) { + expect(Array.isArray(obj.data)).toBe(true) + expect((obj.data as unknown[]).length).toBeGreaterThan(0) + + const embedding = (obj.data as Array>)[0] + expect(embedding.embedding).toBeDefined() + expect(Array.isArray(embedding.embedding)).toBe(true) + expect((embedding.embedding as number[]).length).toBeGreaterThan(0) + } + // Direct embedding array + else if (obj.embedding) { + expect(Array.isArray(obj.embedding)).toBe(true) + expect((obj.embedding as number[]).length).toBeGreaterThan(0) + } +} + +export function assertValidEmbeddingsBatchResponse(response: unknown, expectedCount: number): void { + expect(response).toBeDefined() + const obj = response as Record + + if (obj.data) { + expect(Array.isArray(obj.data)).toBe(true) + expect((obj.data as unknown[]).length).toBe(expectedCount) + } +} + +export function assertValidSpeechResponse(response: unknown): void { + expect(response).toBeDefined() + + // Response should be audio data (ArrayBuffer or similar) + if (response instanceof ArrayBuffer) { + expect(response.byteLength).toBeGreaterThan(0) + } else if (response instanceof Uint8Array) { + expect(response.length).toBeGreaterThan(0) + } else if (response && typeof response === 'object') { + const obj = response as Record + // Check for content property (some SDKs wrap it) + if (obj.content) { + expect(obj.content).toBeDefined() + } + } +} + +export function assertValidTranscriptionResponse(response: unknown): void { + expect(response).toBeDefined() + const obj = response as Record + + // Should have transcribed text + if (typeof obj.text === 'string') { + expect(obj.text.length).toBeGreaterThan(0) + } else if (typeof response === 'string') { + expect(response.length).toBeGreaterThan(0) + } +} + +export function assertValidBatchResponse(response: unknown): void { + expect(response).toBeDefined() + const obj = response as Record + + expect(obj.id).toBeDefined() + expect(typeof obj.id).toBe('string') +} + +export function assertValidBatchListResponse(response: unknown): void { + expect(response).toBeDefined() + const obj = response as Record + + expect(obj.data).toBeDefined() + expect(Array.isArray(obj.data)).toBe(true) +} + +export function assertValidFileResponse(response: unknown): void { + expect(response).toBeDefined() + const obj = response as Record + + expect(obj.id).toBeDefined() + expect(typeof obj.id).toBe('string') +} + +export function assertValidFileListResponse(response: unknown): void { + expect(response).toBeDefined() + const obj = response as Record + + expect(obj.data).toBeDefined() + expect(Array.isArray(obj.data)).toBe(true) +} + +export function assertValidFileDeleteResponse(response: unknown): void { + expect(response).toBeDefined() + const obj = response as Record + + expect(obj.deleted).toBe(true) +} + +export function assertValidInputTokensResponse(response: unknown): void { + expect(response).toBeDefined() + const obj = response as Record + + // Should have token count information + if (typeof obj.total_tokens === 'number') { + expect(obj.total_tokens).toBeGreaterThan(0) + } else if (typeof obj.input_tokens === 'number') { + expect(obj.input_tokens).toBeGreaterThan(0) + } +} + +export function assertValidResponsesResponse(response: unknown): void { + expect(response).toBeDefined() + const obj = response as Record + + // Responses API should have output + if (obj.output) { + expect(obj.output).toBeDefined() + } else if (obj.choices) { + expect(Array.isArray(obj.choices)).toBe(true) + expect((obj.choices as unknown[]).length).toBeGreaterThan(0) + } +} + +export function assertValidTextCompletionResponse(response: unknown): void { + expect(response).toBeDefined() + const obj = response as Record + + if (obj.choices) { + expect(Array.isArray(obj.choices)).toBe(true) + const choice = (obj.choices as Array>)[0] + expect(choice.text).toBeDefined() + } +} + +export function assertErrorPropagation(error: unknown): void { + expect(error).toBeDefined() + + if (error instanceof Error) { + expect(error.message).toBeDefined() + expect(error.message.length).toBeGreaterThan(0) + } +} + +export function assertValidErrorResponse(error: unknown): void { + assertErrorPropagation(error) +} + +// ============================================================================ +// Streaming Utilities +// ============================================================================ + +export async function collectStreamingContent(stream: AsyncIterable): Promise { + let content = '' + + for await (const chunk of stream) { + const chunkObj = chunk as Record + + // OpenAI-style streaming + if (chunkObj.choices) { + const choice = (chunkObj.choices as Array>)[0] + const delta = choice?.delta as Record + if (delta?.content) { + content += String(delta.content) + } + } + // Direct content delta + else if (chunkObj.delta) { + const delta = chunkObj.delta as Record + if (delta.content) { + content += String(delta.content) + } + } + // Text chunk + else if (chunkObj.text) { + content += String(chunkObj.text) + } + } + + return content +} + +export async function collectStreamingTranscriptionContent(stream: AsyncIterable): Promise { + let content = '' + + for await (const chunk of stream) { + if (typeof chunk === 'string') { + content += chunk + } else { + const chunkObj = chunk as Record + if (chunkObj.text) { + content += String(chunkObj.text) + } + } + } + + return content +} + +export async function collectTextCompletionStreamingContent(stream: AsyncIterable): Promise { + let content = '' + + for await (const chunk of stream) { + const chunkObj = chunk as Record + + if (chunkObj.choices) { + const choice = (chunkObj.choices as Array>)[0] + if (choice?.text) { + content += String(choice.text) + } + } + } + + return content +} + +export async function collectResponsesStreamingContent(stream: AsyncIterable): Promise { + let content = '' + + for await (const chunk of stream) { + const chunkObj = chunk as Record + + if (chunkObj.output) { + content += String(chunkObj.output) + } else if (chunkObj.delta) { + const delta = chunkObj.delta as Record + if (delta.content) { + content += String(delta.content) + } + } + } + + return content +} + +// ============================================================================ +// Cosine Similarity for Embeddings +// ============================================================================ + +export function calculateCosineSimilarity(a: number[], b: number[]): number { + if (a.length !== b.length) { + throw new Error('Vectors must have the same length') + } + + let dotProduct = 0 + let normA = 0 + let normB = 0 + + for (let i = 0; i < a.length; i++) { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + + // Handle zero vectors to avoid division by zero + if (normA === 0 || normB === 0) { + return 0 + } + + return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)) +} + +// ============================================================================ +// Audio Generation for Testing +// ============================================================================ + +export function generateTestAudio(durationMs: number = 1000, sampleRate: number = 16000): Buffer { + // Generate a simple sine wave audio for testing + const numSamples = Math.floor((durationMs / 1000) * sampleRate) + const frequency = 440 // A4 note + + // Create WAV header + const headerSize = 44 + const dataSize = numSamples * 2 // 16-bit samples + const fileSize = headerSize + dataSize - 8 + + const buffer = Buffer.alloc(headerSize + dataSize) + + // RIFF header + buffer.write('RIFF', 0) + buffer.writeUInt32LE(fileSize, 4) + buffer.write('WAVE', 8) + + // fmt chunk + buffer.write('fmt ', 12) + buffer.writeUInt32LE(16, 16) // chunk size + buffer.writeUInt16LE(1, 20) // audio format (PCM) + buffer.writeUInt16LE(1, 22) // num channels + buffer.writeUInt32LE(sampleRate, 24) // sample rate + buffer.writeUInt32LE(sampleRate * 2, 28) // byte rate + buffer.writeUInt16LE(2, 32) // block align + buffer.writeUInt16LE(16, 34) // bits per sample + + // data chunk + buffer.write('data', 36) + buffer.writeUInt32LE(dataSize, 40) + + // Generate sine wave samples + for (let i = 0; i < numSamples; i++) { + const t = i / sampleRate + const sample = Math.sin(2 * Math.PI * frequency * t) * 32767 * 0.5 + buffer.writeInt16LE(Math.round(sample), headerSize + i * 2) + } + + return buffer +} + +// ============================================================================ +// Provider Voice Mapping +// ============================================================================ + +const PROVIDER_VOICES: Record = { + openai: ['alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'], + google: ['Puck', 'Charon', 'Kore', 'Fenrir', 'Aoede'], +} + +export function getProviderVoices(provider: string): string[] { + return PROVIDER_VOICES[provider] || PROVIDER_VOICES.openai +} + +export function getProviderVoice(provider: string, index: number = 0): string { + const voices = getProviderVoices(provider) + return voices[index % voices.length] +} + +// ============================================================================ +// OpenAI Tool Format Conversion +// ============================================================================ + +export interface OpenAITool { + type: 'function' + function: ToolDefinition +} + +export function convertToOpenAITools(tools: ToolDefinition[]): OpenAITool[] { + return tools.map((tool) => ({ + type: 'function', + function: tool, + })) +} + +// ============================================================================ +// Responses API Tool Format Conversion +// ============================================================================ + +export function convertToResponsesTools(tools: ToolDefinition[]): OpenAITool[] { + return convertToOpenAITools(tools) +} + +// ============================================================================ +// Batch API Utilities +// ============================================================================ + +export interface BatchRequest { + custom_id: string + method: string + url: string + body: Record +} + +export function createBatchJsonlContent(requests: BatchRequest[]): string { + return requests.map((r) => JSON.stringify(r)).join('\n') +} + +export function createBatchInlineRequests( + model: string, + messages: ChatMessage[][], + idPrefix: string = 'req' +): BatchRequest[] { + return messages.map((msgs, index) => ({ + custom_id: `${idPrefix}-${index}`, + method: 'POST', + url: '/v1/chat/completions', + body: { + model, + messages: msgs, + max_tokens: 100, + }, + })) +} diff --git a/tests/integrations/typescript/src/utils/config-loader.ts b/tests/integrations/typescript/src/utils/config-loader.ts new file mode 100644 index 0000000000..915826ec27 --- /dev/null +++ b/tests/integrations/typescript/src/utils/config-loader.ts @@ -0,0 +1,476 @@ +/** + * Configuration loader for Bifrost integration tests. + * + * This module loads configuration from config.yml and provides utilities + * for constructing integration URLs through the Bifrost gateway. + */ + +import { readFileSync, existsSync } from 'fs' +import { resolve, dirname } from 'path' +import { fileURLToPath } from 'url' +import { parse as parseYaml } from 'yaml' + +// Get __dirname equivalent for ES modules +const __filename = fileURLToPath(import.meta.url) +const __dirname = dirname(__filename) + +// Integration to provider mapping +// Maps integration names to their underlying provider configurations +export const INTEGRATION_TO_PROVIDER_MAP: Record = { + openai: 'openai', + anthropic: 'anthropic', + google: 'gemini', // Google integration uses Gemini provider + litellm: 'openai', // LiteLLM defaults to OpenAI + langchain: 'openai', // LangChain defaults to OpenAI + pydanticai: 'openai', // Pydantic AI defaults to OpenAI + bedrock: 'bedrock', // Bedrock defaults to Amazon provider +} + +export interface BifrostConfig { + base_url: string + endpoints: Record +} + +export interface ApiConfig { + timeout: number + max_retries: number + retry_delay: number +} + +export interface TestSettings { + max_tokens: Record + timeouts: Record + retries: { + max_attempts: number + delay: number + } +} + +export interface ProviderScenarios { + [scenario: string]: boolean +} + +export interface RawConfig { + bifrost: BifrostConfig + api: ApiConfig + providers: Record> + provider_api_keys: Record + provider_scenarios: Record + scenario_capabilities: Record + model_capabilities: Record> + test_settings: TestSettings + integration_settings: Record> + environments: Record> + logging: Record + virtual_key?: { + enabled: boolean + value: string + } +} + +class ConfigLoader { + private config: RawConfig | null = null + private configPath: string + + constructor(configPath?: string) { + if (configPath) { + this.configPath = configPath + } else { + // Look for config.yml in project root (symlinked from python) + this.configPath = resolve(__dirname, '../../config.yml') + } + this.loadConfig() + } + + private loadConfig(): void { + if (!existsSync(this.configPath)) { + throw new Error(`Configuration file not found: ${this.configPath}`) + } + + const rawContent = readFileSync(this.configPath, 'utf-8') + let rawConfig: unknown + try { + rawConfig = parseYaml(rawContent) + } catch (e) { + throw new Error(`Failed to parse YAML config at ${this.configPath}: ${String(e)}`) + } + if (rawConfig == null || typeof rawConfig !== 'object') { + throw new Error(`Invalid YAML config at ${this.configPath}: expected a top-level object`) + } + + // Expand environment variables + this.config = this.expandEnvVars(rawConfig) as RawConfig + } + + private expandEnvVars(obj: unknown): unknown { + if (typeof obj === 'object' && obj !== null) { + if (Array.isArray(obj)) { + return obj.map((item) => this.expandEnvVars(item)) + } + const result: Record = {} + for (const [key, value] of Object.entries(obj)) { + result[key] = this.expandEnvVars(value) + } + return result + } + + if (typeof obj === 'string') { + // Handle ${VAR:-default} syntax + return obj.replace(/\$\{([^}]+)\}/g, (_, varExpr: string) => { + if (varExpr.includes(':-')) { + const [varName, defaultValue] = varExpr.split(':-') + return process.env[varName] || defaultValue + } + return process.env[varExpr] || '' + }) + } + + return obj + } + + getIntegrationUrl(integration: string): string { + if (!this.config) throw new Error('Config not loaded') + + const bifrostConfig = this.config.bifrost + const baseUrl = bifrostConfig.base_url + const endpoint = bifrostConfig.endpoints[integration] + + if (!endpoint) { + throw new Error(`No endpoint configured for integration: ${integration}`) + } + + // Normalize URL to avoid double slashes + const base = baseUrl.replace(/\/+$/, '') + const ep = String(endpoint).replace(/^\/+/, '') + return `${base}/${ep}` + } + + getBifrostConfig(): BifrostConfig { + if (!this.config) throw new Error('Config not loaded') + return this.config.bifrost + } + + getModel(integration: string, modelType: string = 'chat'): string { + // Map integration to provider + const provider = INTEGRATION_TO_PROVIDER_MAP[integration] + if (!provider) { + throw new Error( + `Unknown integration: ${integration}. Valid integrations: ${Object.keys(INTEGRATION_TO_PROVIDER_MAP).join(', ')}` + ) + } + + // Get model from provider configuration + return this.getProviderModel(provider, modelType) + } + + getModelAlternatives(integration: string): string[] { + const provider = INTEGRATION_TO_PROVIDER_MAP[integration] + if (!provider || !this.config?.providers?.[provider]) { + return [] + } + + const alternatives = this.config.providers[provider].alternatives + return Array.isArray(alternatives) ? alternatives : [] + } + + getModelCapabilities(model: string): Record { + if (!this.config) throw new Error('Config not loaded') + + return ( + this.config.model_capabilities[model] || { + chat: true, + tools: false, + vision: false, + max_tokens: 4096, + context_window: 4096, + } + ) + } + + supportsCapability(model: string, capability: string): boolean { + const caps = this.getModelCapabilities(model) + return caps[capability] === true + } + + getApiConfig(): ApiConfig { + if (!this.config) throw new Error('Config not loaded') + return this.config.api + } + + getTestSettings(): TestSettings { + if (!this.config) throw new Error('Config not loaded') + return this.config.test_settings + } + + getIntegrationSettings(integration: string): Record { + if (!this.config) throw new Error('Config not loaded') + return this.config.integration_settings[integration] || {} + } + + getEnvironmentConfig(environment?: string): Record { + if (!this.config) throw new Error('Config not loaded') + const env = environment || process.env.TEST_ENV || 'development' + return this.config.environments[env] || {} + } + + getLoggingConfig(): Record { + if (!this.config) throw new Error('Config not loaded') + return this.config.logging + } + + listIntegrations(): string[] { + return Object.keys(INTEGRATION_TO_PROVIDER_MAP) + } + + listModels(integration?: string): Record { + if (!this.config) throw new Error('Config not loaded') + + if (integration) { + const provider = INTEGRATION_TO_PROVIDER_MAP[integration] + if (!provider) { + throw new Error(`Unknown integration: ${integration}`) + } + + if (!this.config.providers?.[provider]) { + throw new Error(`No provider configuration for: ${provider}`) + } + + return { [integration]: this.config.providers[provider] } + } + + // Return all providers mapped to their integration names + const result: Record = {} + for (const [integ, provider] of Object.entries(INTEGRATION_TO_PROVIDER_MAP)) { + if (this.config.providers?.[provider]) { + result[integ] = this.config.providers[provider] + } + } + + return result + } + + validateConfig(): boolean { + if (!this.config) throw new Error('Config not loaded') + + const requiredSections = ['bifrost', 'providers', 'api', 'test_settings'] + + for (const section of requiredSections) { + if (!(section in this.config)) { + throw new Error(`Missing required configuration section: ${section}`) + } + } + + // Validate Bifrost configuration + const bifrost = this.config.bifrost + if (!bifrost.base_url || !bifrost.endpoints) { + throw new Error('Bifrost configuration missing base_url or endpoints') + } + + // Validate that all integrations map to valid providers + for (const [integration, provider] of Object.entries(INTEGRATION_TO_PROVIDER_MAP)) { + if (!this.config.providers[provider]) { + throw new Error( + `Integration '${integration}' maps to provider '${provider}' which is not configured in providers section` + ) + } + } + + return true + } + + printConfigSummary(): void { + if (!this.config) throw new Error('Config not loaded') + + console.log('šŸ”§ BIFROST INTEGRATION TEST CONFIGURATION (TypeScript)') + console.log('='.repeat(80)) + + // Bifrost configuration + const bifrost = this.getBifrostConfig() + console.log('\nšŸŒ‰ BIFROST GATEWAY:') + console.log(` Base URL: ${bifrost.base_url}`) + console.log(' Endpoints:') + for (const [integration, endpoint] of Object.entries(bifrost.endpoints)) { + const fullUrl = `${bifrost.base_url.replace(/\/$/, '')}/${endpoint}` + console.log(` ${integration}: ${fullUrl}`) + } + + // Model configurations + console.log('\nšŸ¤– MODEL CONFIGURATIONS (via providers):') + for (const [integration, provider] of Object.entries(INTEGRATION_TO_PROVIDER_MAP)) { + if (this.config.providers?.[provider]) { + const models = this.config.providers[provider] + console.log(` ${integration.toUpperCase()} → ${provider}:`) + console.log(` Chat: ${models.chat || 'N/A'}`) + console.log(` Vision: ${models.vision || 'N/A'}`) + console.log(` Tools: ${models.tools || 'N/A'}`) + const alternatives = models.alternatives + console.log(` Alternatives: ${Array.isArray(alternatives) ? alternatives.length : 0} models`) + } + } + + // API settings + const apiConfig = this.getApiConfig() + console.log('\nāš™ļø API SETTINGS:') + console.log(` Timeout: ${apiConfig.timeout}s`) + console.log(` Max Retries: ${apiConfig.max_retries}`) + console.log(` Retry Delay: ${apiConfig.retry_delay}s`) + + console.log(`\nāœ… Configuration loaded successfully from: ${this.configPath}`) + } + + getProviderModel(provider: string, capability: string = 'chat'): string { + if (!this.config?.providers) { + return '' + } + + const providerModels = this.config.providers[provider] + if (!providerModels) { + return '' + } + + const model = providerModels[capability] + return typeof model === 'string' ? model : '' + } + + getProviderApiKeyEnv(provider: string): string { + if (!this.config?.provider_api_keys) { + return '' + } + return this.config.provider_api_keys[provider] || '' + } + + isProviderAvailable(provider: string): boolean { + const envVar = this.getProviderApiKeyEnv(provider) + if (!envVar) { + return false + } + + const apiKey = process.env[envVar] + return apiKey !== undefined && apiKey.trim() !== '' + } + + getAvailableProviders(): string[] { + if (!this.config?.providers) { + return [] + } + + const available: string[] = [] + for (const provider of Object.keys(this.config.providers)) { + if (this.isProviderAvailable(provider)) { + available.push(provider) + } + } + + return available + } + + providerSupportsScenario(provider: string, scenario: string): boolean { + if (!this.config?.provider_scenarios?.[provider]) { + return false + } + + return this.config.provider_scenarios[provider][scenario] === true + } + + getProvidersForScenario(scenario: string): string[] { + const availableProviders = this.getAvailableProviders() + const providers: string[] = [] + + for (const provider of availableProviders) { + if (this.providerSupportsScenario(provider, scenario)) { + providers.push(provider) + } + } + + return providers + } + + getScenarioCapability(scenario: string): string { + if (!this.config?.scenario_capabilities) { + return 'chat' + } + + return this.config.scenario_capabilities[scenario] || 'chat' + } + + getVirtualKey(): string { + if (!this.config?.virtual_key?.enabled) { + return '' + } + return this.config.virtual_key.value || '' + } + + isVirtualKeyConfigured(): boolean { + const vk = this.getVirtualKey() + return vk.trim() !== '' + } +} + +// Global configuration instance +let configLoader: ConfigLoader | null = null + +export function getConfig(): ConfigLoader { + if (!configLoader) { + configLoader = new ConfigLoader() + } + return configLoader +} + +export function getIntegrationUrl(integration: string): string { + return getConfig().getIntegrationUrl(integration) +} + +export function getModel(integration: string, modelType: string = 'chat'): string { + return getConfig().getModel(integration, modelType) +} + +export function getModelCapabilities(model: string): Record { + return getConfig().getModelCapabilities(model) +} + +export function supportsCapability(model: string, capability: string): boolean { + return getConfig().supportsCapability(model, capability) +} + +export function getProviderModel(provider: string, capability: string = 'chat'): string { + return getConfig().getProviderModel(provider, capability) +} + +export function isProviderAvailable(provider: string): boolean { + return getConfig().isProviderAvailable(provider) +} + +export function getAvailableProviders(): string[] { + return getConfig().getAvailableProviders() +} + +export function providerSupportsScenario(provider: string, scenario: string): boolean { + return getConfig().providerSupportsScenario(provider, scenario) +} + +export function getProvidersForScenario(scenario: string): string[] { + return getConfig().getProvidersForScenario(scenario) +} + +export function getVirtualKey(): string { + return getConfig().getVirtualKey() +} + +export function isVirtualKeyConfigured(): boolean { + return getConfig().isVirtualKeyConfigured() +} + +export function getApiConfig(): ApiConfig { + return getConfig().getApiConfig() +} + +export function getTestSettings(): TestSettings { + return getConfig().getTestSettings() +} + +export function getIntegrationSettings(integration: string): Record { + return getConfig().getIntegrationSettings(integration) +} + +// Export class for direct use if needed +export { ConfigLoader } diff --git a/tests/integrations/typescript/src/utils/index.ts b/tests/integrations/typescript/src/utils/index.ts new file mode 100644 index 0000000000..968e72424e --- /dev/null +++ b/tests/integrations/typescript/src/utils/index.ts @@ -0,0 +1,12 @@ +/** + * Barrel export for all utility modules + */ + +// Config loader +export * from './config-loader' + +// Common test utilities +export * from './common' + +// Parametrization utilities +export * from './parametrize' diff --git a/tests/integrations/typescript/src/utils/parametrize.ts b/tests/integrations/typescript/src/utils/parametrize.ts new file mode 100644 index 0000000000..d3c063f981 --- /dev/null +++ b/tests/integrations/typescript/src/utils/parametrize.ts @@ -0,0 +1,202 @@ +/** + * Parametrization utilities for cross-provider testing. + * + * This module provides utilities for testing across multiple AI providers + * with automatic scenario-based filtering. + */ + +import { getConfig } from './config-loader' + +export interface ProviderModelParam { + provider: string + model: string +} + +export interface ProviderModelVkParam extends ProviderModelParam { + vkEnabled: boolean +} + +/** + * Get cross-provider parameters for a specific scenario. + * + * @param scenario - Test scenario name + * @param includeProviders - Optional list of providers to include + * @param excludeProviders - Optional list of providers to exclude + * @returns Array of [provider, model] tuples for test parametrization + */ +export function getCrossProviderParamsForScenario( + scenario: string, + includeProviders?: string[], + excludeProviders?: string[] +): ProviderModelParam[] { + const config = getConfig() + + // Get providers that support this scenario + let providers = config.getProvidersForScenario(scenario) + + // Apply include filter + if (includeProviders && includeProviders.length > 0) { + providers = providers.filter((p) => includeProviders.includes(p)) + } + + // Apply exclude filter + if (excludeProviders && excludeProviders.length > 0) { + providers = providers.filter((p) => !excludeProviders.includes(p)) + } + + // Generate { provider, model } objects + // Automatically maps: scenario → capability → model + const params: ProviderModelParam[] = [] + + for (const provider of providers.sort()) { + // Map scenario to capability, then get model + const capability = config.getScenarioCapability(scenario) + const model = config.getProviderModel(provider, capability) + + // Only add if provider has a model for this scenario's capability + if (model) { + params.push({ provider, model }) + } + } + + // If no providers available, return a dummy tuple to avoid test errors + // The test will be skipped with appropriate message + if (params.length === 0) { + params.push({ provider: '_no_providers_', model: '_no_model_' }) + } + + return params +} + +/** + * Get cross-provider parameters with virtual key flag for test parametrization. + * + * When virtual key is configured, each provider/model combo is tested twice: + * once without VK (vkEnabled=false) and once with VK (vkEnabled=true). + * + * @param scenario - Test scenario name + * @param includeProviders - Optional list of providers to include + * @param excludeProviders - Optional list of providers to exclude + * @returns Array of { provider, model, vkEnabled } objects + */ +export function getCrossProviderParamsWithVkForScenario( + scenario: string, + includeProviders?: string[], + excludeProviders?: string[] +): ProviderModelVkParam[] { + const config = getConfig() + + // Get base params without VK + const baseParams = getCrossProviderParamsForScenario(scenario, includeProviders, excludeProviders) + + // Handle the dummy tuple case + if (baseParams.length === 1 && baseParams[0].provider === '_no_providers_') { + return [{ provider: '_no_providers_', model: '_no_model_', vkEnabled: false }] + } + + // Build params list with VK flag + const params: ProviderModelVkParam[] = [] + const vkConfigured = config.isVirtualKeyConfigured() + + for (const { provider, model } of baseParams) { + // Always add the non-VK variant + params.push({ provider, model, vkEnabled: false }) + + // Add VK variant only if VK is configured + if (vkConfigured) { + params.push({ provider, model, vkEnabled: true }) + } + } + + return params +} + +/** + * Format test ID for virtual key parameterized tests. + * + * @param provider - Provider name + * @param model - Model name + * @param vkEnabled - Whether VK is enabled + * @returns Formatted test ID string + */ +export function formatVkTestId(provider: string, model: string, vkEnabled: boolean): string { + const vkSuffix = vkEnabled ? 'with_vk' : 'no_vk' + return `${provider}-${model}-${vkSuffix}` +} + +/** + * Format provider and model into the standard "provider/model" format. + * + * @param provider - Provider name + * @param model - Model name + * @returns Formatted string "provider/model" + */ +export function formatProviderModel(provider: string, model: string): string { + return `${provider}/${model}` +} + +/** + * Helper to check if test should be skipped due to no providers. + */ +export function shouldSkipNoProviders(params: ProviderModelParam | ProviderModelVkParam): boolean { + return params.provider === '_no_providers_' +} + +/** + * Get test cases for Vitest's describe.each or it.each. + * + * Returns an array suitable for use with Vitest's parametrization. + * + * @example + * ```typescript + * const testCases = getTestCasesForScenario('simple_chat') + * describe.each(testCases)('Simple Chat - $provider', ({ provider, model }) => { + * it('should complete a simple chat', async () => { + * // test implementation + * }) + * }) + * ``` + */ +export function getTestCasesForScenario( + scenario: string, + includeProviders?: string[], + excludeProviders?: string[] +): ProviderModelParam[] { + return getCrossProviderParamsForScenario(scenario, includeProviders, excludeProviders) +} + +/** + * Get test cases with VK variants for Vitest's describe.each or it.each. + * + * @example + * ```typescript + * const testCases = getTestCasesWithVkForScenario('simple_chat') + * describe.each(testCases)('Simple Chat - $provider (VK: $vkEnabled)', ({ provider, model, vkEnabled }) => { + * it('should complete a simple chat', async () => { + * // test implementation + * }) + * }) + * ``` + */ +export function getTestCasesWithVkForScenario( + scenario: string, + includeProviders?: string[], + excludeProviders?: string[] +): ProviderModelVkParam[] { + return getCrossProviderParamsWithVkForScenario(scenario, includeProviders, excludeProviders) +} + +/** + * Create a test name with provider and model info. + */ +export function createTestName(baseName: string, provider: string, model: string): string { + return `${baseName} [${provider}/${model}]` +} + +/** + * Create a test name with provider, model, and VK info. + */ +export function createTestNameWithVk(baseName: string, provider: string, model: string, vkEnabled: boolean): string { + const vkSuffix = vkEnabled ? ' (with VK)' : '' + return `${baseName} [${provider}/${model}]${vkSuffix}` +} diff --git a/tests/integrations/typescript/tests/setup.ts b/tests/integrations/typescript/tests/setup.ts new file mode 100644 index 0000000000..db969620ce --- /dev/null +++ b/tests/integrations/typescript/tests/setup.ts @@ -0,0 +1,59 @@ +/** + * Global test setup for Vitest + * + * This file is loaded before all tests run. + * It sets up environment variables and global configuration. + */ + +import { config } from 'dotenv' +import { resolve, dirname } from 'path' +import { fileURLToPath } from 'url' + +// ES module compatibility - __dirname is not available in ESM +const __filename = fileURLToPath(import.meta.url) +const __dirname = dirname(__filename) + +// Load environment variables from .env file in project root +config({ path: resolve(__dirname, '../.env') }) + +// Also try loading from workspace root +config({ path: resolve(__dirname, '../../../../.env') }) + +// Set default environment variables if not present +if (!process.env.BIFROST_BASE_URL) { + process.env.BIFROST_BASE_URL = 'http://localhost:8080' +} + +// Log test environment info +console.log('\n🧪 Bifrost TypeScript Integration Tests') +console.log('='.repeat(50)) +console.log(`šŸ“ Bifrost URL: ${process.env.BIFROST_BASE_URL}`) +console.log(`šŸ• Started at: ${new Date().toISOString()}`) + +// Check for available API keys +const apiKeys = { + OpenAI: !!process.env.OPENAI_API_KEY, + Anthropic: !!process.env.ANTHROPIC_API_KEY, + Google: !!process.env.GEMINI_API_KEY, + Bedrock: !!process.env.AWS_ACCESS_KEY_ID, + Cohere: !!process.env.COHERE_API_KEY, +} + +console.log('\nšŸ”‘ Available API Keys:') +for (const [provider, available] of Object.entries(apiKeys)) { + const status = available ? 'āœ…' : 'āŒ' + console.log(` ${status} ${provider}`) +} +console.log('='.repeat(50) + '\n') + +// Global test timeout (can be overridden per test) +// This is set in vitest.config.ts but documented here +// Default: 300000ms (5 minutes) for integration tests + +// Export for use in tests if needed +export const testEnvironment = { + bifrostUrl: process.env.BIFROST_BASE_URL, + availableProviders: Object.entries(apiKeys) + .filter(([, available]) => available) + .map(([provider]) => provider.toLowerCase()), +} diff --git a/tests/integrations/typescript/tests/test-anthropic.test.ts b/tests/integrations/typescript/tests/test-anthropic.test.ts new file mode 100644 index 0000000000..611b4e00cf --- /dev/null +++ b/tests/integrations/typescript/tests/test-anthropic.test.ts @@ -0,0 +1,1628 @@ +/** + * Anthropic Integration Tests + * + * This test suite uses the Anthropic SDK to test Claude models through Bifrost. + * Tests cover chat, streaming, tool calling, vision, files, batch, and advanced capabilities. + * + * Test Scenarios: + * + * Chat Completions: + * 1. Simple chat + * 2. Multi-turn conversation + * 3. Streaming chat + * + * Tool Calling: + * 4. Single tool call + * 5. Multiple tool calls + * 6. End-to-end tool calling + * + * Vision/Image: + * 7. Image URL analysis + * 8. Image Base64 analysis + * 9. Multiple images analysis + * + * Extended Thinking: + * 10. Thinking/Extended reasoning + * 11. Extended thinking streaming + * + * Token Counting: + * 12. Count tokens (basic) + * 13. Count tokens - with system message + * 14. Count tokens - long text + * + * Prompt Caching: + * 15. Prompt caching - system message + * 16. Prompt caching - messages + * 17. Prompt caching - tools + * + * Document Input: + * 18. Document input - PDF Base64 + * 19. Document input - plain text + * + * Models: + * 20. List models + * + * Files API: + * 21. File upload + * 22. File list + * 23. File delete + * 24. File content download + * + * Batch API: + * 25. Batch create (inline requests) + * 26. Batch list + * 27. Batch retrieve + * 28. Batch cancel + * 29. Batch results + * 30. Batch end-to-end workflow + */ + +import Anthropic from '@anthropic-ai/sdk' +import { beforeAll, describe, expect, it } from 'vitest' + +import { + getIntegrationUrl, + getProviderModel, + isProviderAvailable, +} from '../src/utils/config-loader' + +import { + BASE64_IMAGE, + CALCULATOR_TOOL, + getApiKey, + hasApiKey, + IMAGE_URL, + mockToolResponse, + MULTI_TURN_MESSAGES, + MULTIPLE_TOOL_CALL_MESSAGES, + SIMPLE_CHAT_MESSAGES, + SINGLE_TOOL_CALL_MESSAGES, + STREAMING_CHAT_MESSAGES, + WEATHER_TOOL, + type ChatMessage, + type ToolDefinition, +} from '../src/utils/common' + +// Type for content blocks that include beta features +type ContentBlockParamWithBeta = Anthropic.ContentBlock | { type: 'text'; text: string; cache_control?: { type: 'ephemeral' } } | { type: 'document'; source: { type: 'base64'; media_type: string; data: string } } + +// ============================================================================ +// Helper Functions +// ============================================================================ + +function getAnthropicClient(): Anthropic { + const baseUrl = getIntegrationUrl('anthropic') + const apiKey = hasApiKey('anthropic') ? getApiKey('anthropic') : 'dummy-key' + + return new Anthropic({ + baseURL: baseUrl, + apiKey, + timeout: 300000, // 5 minutes + maxRetries: 3, + }) +} + +function convertToAnthropicMessages( + messages: ChatMessage[] +): Anthropic.MessageParam[] { + return messages.map((msg) => { + if (msg.role === 'assistant') { + return { + role: 'assistant' as const, + content: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content), + } + } + + if (typeof msg.content === 'string') { + return { + role: 'user' as const, + content: msg.content, + } + } + + // Handle multimodal content + const parts: Anthropic.ContentBlock[] = msg.content.map((part) => { + if (part.type === 'text') { + return { type: 'text' as const, text: part.text! } + } + + // Handle image content + const imageUrl = part.image_url!.url + if (imageUrl.startsWith('data:')) { + // Base64 image + const matches = imageUrl.match(/^data:([^;]+);base64,(.+)$/) + if (matches) { + return { + type: 'image' as const, + source: { + type: 'base64' as const, + media_type: matches[1] as 'image/jpeg' | 'image/png' | 'image/gif' | 'image/webp', + data: matches[2], + }, + } + } + } + + // URL image - Anthropic supports URL source type directly (beta feature) + return { + type: 'image' as const, + source: { + type: 'url' as const, + url: imageUrl, + }, + } as unknown as Anthropic.ContentBlock + }) as Anthropic.ContentBlock[] + + return { + role: 'user' as const, + content: parts, + } + }) +} + +function convertToAnthropicTools(tools: ToolDefinition[]): Anthropic.Tool[] { + return tools.map((tool) => ({ + name: tool.name, + description: tool.description, + input_schema: { + type: 'object' as const, + properties: tool.parameters.properties, + required: tool.parameters.required, + }, + })) +} + +function extractAnthropicToolCalls( + response: Anthropic.Message +): Array<{ name: string; arguments: Record; id: string }> { + const toolCalls: Array<{ name: string; arguments: Record; id: string }> = [] + + for (const block of response.content) { + if (block.type === 'tool_use') { + toolCalls.push({ + name: block.name, + arguments: block.input as Record, + id: block.id, + }) + } + } + + return toolCalls +} + +function getContentString(response: Anthropic.Message): string { + let content = '' + for (const block of response.content) { + if (block.type === 'text') { + content += block.text + } + } + return content +} + +// ============================================================================ +// Test Suite +// ============================================================================ + +describe('Anthropic SDK Integration Tests', () => { + const skipTests = !isProviderAvailable('anthropic') + + beforeAll(() => { + if (skipTests) { + console.log('āš ļø Skipping Anthropic tests: ANTHROPIC_API_KEY not set') + } + }) + + // ============================================================================ + // Simple Chat Tests + // ============================================================================ + + describe('Simple Chat', () => { + it('should complete a simple chat', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + const response = await client.messages.create({ + model, + max_tokens: 100, + messages: convertToAnthropicMessages(SIMPLE_CHAT_MESSAGES), + }) + + expect(response).toBeDefined() + expect(response.content).toBeDefined() + expect(response.content.length).toBeGreaterThan(0) + + const content = getContentString(response) + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… Simple chat passed for anthropic/${model}`) + }) + }) + + // ============================================================================ + // Multi-turn Conversation Tests + // ============================================================================ + + describe('Multi-turn Conversation', () => { + it('should handle multi-turn conversation', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + const response = await client.messages.create({ + model, + max_tokens: 100, + messages: convertToAnthropicMessages(MULTI_TURN_MESSAGES), + }) + + expect(response).toBeDefined() + const content = getContentString(response) + expect(content.toLowerCase()).toMatch(/paris|population|million|people/i) + console.log(`āœ… Multi-turn conversation passed for anthropic/${model}`) + }) + }) + + // ============================================================================ + // Streaming Tests + // ============================================================================ + + describe('Streaming Chat', () => { + it('should stream chat response', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + const stream = client.messages.stream({ + model, + max_tokens: 100, + messages: convertToAnthropicMessages(STREAMING_CHAT_MESSAGES), + }) + + let content = '' + for await (const event of stream) { + if (event.type === 'content_block_delta' && event.delta.type === 'text_delta') { + content += event.delta.text + } + } + + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… Streaming chat passed for anthropic/${model}`) + }) + }) + + // ============================================================================ + // Streaming Client Disconnect Tests + // ============================================================================ + + describe('Streaming Chat - Client Disconnect', () => { + it('should handle client disconnect mid-stream', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + const abortController = new AbortController() + + // Request a longer response to ensure we have time to abort mid-stream + const stream = client.messages.stream({ + model, + max_tokens: 1000, + messages: [ + { + role: 'user', + content: 'Write a detailed essay about the history of computing, including at least 10 paragraphs.', + }, + ], + }, { + signal: abortController.signal, + }) + + let chunkCount = 0 + let content = '' + let wasAborted = false + + try { + for await (const event of stream) { + chunkCount++ + if (event.type === 'content_block_delta' && event.delta.type === 'text_delta') { + content += event.delta.text + } + + // Abort after receiving a few chunks + if (chunkCount >= 5) { + abortController.abort() + } + } + } catch (error) { + wasAborted = true + expect(error).toBeDefined() + // The error should be an AbortError or contain abort-related message + const errorMessage = error instanceof Error ? error.message.toLowerCase() : String(error).toLowerCase() + const isAbortError = errorMessage.includes('abort') || + errorMessage.includes('cancel') || + error instanceof DOMException || + (error as { name?: string })?.name === 'AbortError' + expect(isAbortError).toBe(true) + } + + // Verify we received some content before aborting + expect(chunkCount).toBeGreaterThanOrEqual(5) + expect(content.length).toBeGreaterThan(0) + expect(wasAborted).toBe(true) + console.log(`āœ… Streaming client disconnect passed for anthropic/${model} (${chunkCount} chunks before abort)`) + }) + }) + + // ============================================================================ + // Tool Calling Tests + // ============================================================================ + + describe('Single Tool Call', () => { + it('should make a single tool call', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'tools') + + const response = await client.messages.create({ + model, + max_tokens: 100, + messages: convertToAnthropicMessages(SINGLE_TOOL_CALL_MESSAGES), + tools: convertToAnthropicTools([WEATHER_TOOL]), + }) + + const toolCalls = extractAnthropicToolCalls(response) + expect(toolCalls.length).toBe(1) + expect(toolCalls[0].name).toBe('get_weather') + console.log(`āœ… Single tool call passed for anthropic/${model}`) + }) + }) + + describe('Multiple Tool Calls', () => { + it('should make multiple tool calls', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'tools') + + const response = await client.messages.create({ + model, + max_tokens: 150, + messages: convertToAnthropicMessages(MULTIPLE_TOOL_CALL_MESSAGES), + tools: convertToAnthropicTools([WEATHER_TOOL, CALCULATOR_TOOL]), + }) + + const toolCalls = extractAnthropicToolCalls(response) + expect(toolCalls.length).toBeGreaterThanOrEqual(1) + + const toolNames = toolCalls.map((tc) => tc.name) + expect(toolNames.some((name) => name === 'get_weather' || name === 'calculate')).toBe(true) + console.log(`āœ… Multiple tool calls passed for anthropic/${model}`) + }) + }) + + describe('End-to-End Tool Calling', () => { + it('should complete end-to-end tool calling', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'tools') + + // Step 1: Initial request with tools + const response1 = await client.messages.create({ + model, + max_tokens: 100, + messages: convertToAnthropicMessages(SINGLE_TOOL_CALL_MESSAGES), + tools: convertToAnthropicTools([WEATHER_TOOL]), + }) + + const toolCalls = extractAnthropicToolCalls(response1) + expect(toolCalls.length).toBeGreaterThan(0) + + // Step 2: Execute tool and get result + const toolResult = mockToolResponse(toolCalls[0].name, toolCalls[0].arguments) + + // Step 3: Send tool result back + const messages: Anthropic.MessageParam[] = [ + ...convertToAnthropicMessages(SINGLE_TOOL_CALL_MESSAGES), + { + role: 'assistant', + content: response1.content, + }, + { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: toolCalls[0].id, + content: toolResult, + }, + ], + }, + ] + + const response2 = await client.messages.create({ + model, + max_tokens: 200, + messages, + tools: convertToAnthropicTools([WEATHER_TOOL]), + }) + + expect(response2).toBeDefined() + const content = getContentString(response2) + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… End-to-end tool calling passed for anthropic/${model}`) + }) + }) + + // ============================================================================ + // Image/Vision Tests + // ============================================================================ + + describe('Image URL', () => { + it('should analyze image from URL', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'vision') + + // Use type assertion for URL-based image source (beta feature) + const response = await client.messages.create({ + model, + max_tokens: 200, + messages: [ + { + role: 'user', + content: [ + { type: 'text', text: 'What do you see in this image? Describe it briefly.' }, + { + type: 'image', + source: { + type: 'url', + url: IMAGE_URL, + }, + }, + ], + }, + ], + } as never) + + expect(response).toBeDefined() + const content = getContentString(response) + expect(content.length).toBeGreaterThan(10) + console.log(`āœ… Image URL analysis passed for anthropic/${model}`) + }) + }) + + describe('Image Base64', () => { + it('should analyze image from Base64', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'vision') + + const response = await client.messages.create({ + model, + max_tokens: 200, + messages: [ + { + role: 'user', + content: [ + { type: 'text', text: 'What color is this image?' }, + { + type: 'image', + source: { + type: 'base64', + media_type: 'image/png', + data: BASE64_IMAGE, + }, + }, + ], + }, + ], + }) + + expect(response).toBeDefined() + const content = getContentString(response) + expect(content.length).toBeGreaterThan(10) + console.log(`āœ… Image Base64 analysis passed for anthropic/${model}`) + }) + }) + + describe('Multiple Images', () => { + it('should analyze multiple images', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'vision') + + // Use type assertion for URL-based image source (beta feature) + const response = await client.messages.create({ + model, + max_tokens: 300, + messages: [ + { + role: 'user', + content: [ + { type: 'text', text: 'Compare these two images. What do you see?' }, + { + type: 'image', + source: { + type: 'url', + url: IMAGE_URL, + }, + }, + { + type: 'image', + source: { + type: 'base64', + media_type: 'image/png', + data: BASE64_IMAGE, + }, + }, + ], + }, + ], + } as never) + + expect(response).toBeDefined() + const content = getContentString(response) + expect(content.length).toBeGreaterThan(10) + console.log(`āœ… Multiple images analysis passed for anthropic/${model}`) + }) + }) + + // ============================================================================ + // Thinking/Extended Reasoning Tests + // ============================================================================ + + describe('Thinking/Extended Reasoning', () => { + it('should support extended thinking', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'thinking') + + // Skip if no thinking model available + if (!model) { + console.log('āš ļø Skipping thinking test: No thinking model configured') + return + } + + try { + // Use type assertion for beta thinking feature + const response = await client.messages.create({ + model, + max_tokens: 8000, + thinking: { + type: 'enabled', + budget_tokens: 5000, + }, + messages: [ + { + role: 'user', + content: 'What is 15% of 80? Show your reasoning step by step.', + }, + ], + } as never) + + expect(response).toBeDefined() + expect(response.content).toBeDefined() + + // Check for thinking blocks (beta feature) + const hasThinking = response.content.some((block: { type: string }) => block.type === 'thinking') + const content = getContentString(response) + + // Either should have thinking blocks or text content + expect(hasThinking || content.length > 0).toBe(true) + console.log(`āœ… Thinking/Extended reasoning passed for anthropic/${model}`) + } catch (error) { + // Some models may not support thinking + console.log(`āš ļø Thinking test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + // ============================================================================ + // Count Tokens Tests + // ============================================================================ + + describe('Count Tokens', () => { + it('should return token usage in response', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + const response = await client.messages.create({ + model, + max_tokens: 50, + messages: [{ role: 'user', content: 'Say hello' }], + }) + + expect(response.usage).toBeDefined() + expect(response.usage.input_tokens).toBeGreaterThan(0) + expect(response.usage.output_tokens).toBeGreaterThan(0) + console.log(`āœ… Count tokens passed for anthropic/${model}`) + }) + }) + + // ============================================================================ + // Prompt Caching Tests + // ============================================================================ + + describe('Prompt Caching - System Message', () => { + it('should support prompt caching with system message', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + // Create a large context for caching + const largeContext = 'This is a legal document for analysis. '.repeat(100) + + // First request - should create cache (use type assertion for beta cache_control) + const response1 = await client.messages.create({ + model, + max_tokens: 100, + system: [ + { type: 'text', text: 'You are an AI assistant tasked with analyzing legal documents.' }, + { type: 'text', text: largeContext, cache_control: { type: 'ephemeral' } }, + ], + messages: [ + { role: 'user', content: 'What are the key elements of contract formation?' }, + ], + } as never) + + expect(response1).toBeDefined() + expect(response1.usage).toBeDefined() + + // Second request - should hit cache + const response2 = await client.messages.create({ + model, + max_tokens: 100, + system: [ + { type: 'text', text: 'You are an AI assistant tasked with analyzing legal documents.' }, + { type: 'text', text: largeContext, cache_control: { type: 'ephemeral' } }, + ], + messages: [ + { role: 'user', content: 'Explain the purpose of force majeure clauses.' }, + ], + } as never) + + expect(response2).toBeDefined() + expect(response2.usage).toBeDefined() + + console.log(`āœ… Prompt caching (system message) passed for anthropic/${model}`) + }) + }) + + describe('Prompt Caching - Messages', () => { + it('should support prompt caching with messages', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + // Create a large context for caching + const largeContext = 'This is a legal document for analysis. '.repeat(100) + + // First request - should create cache (use type assertion for beta cache_control) + const response1 = await client.messages.create({ + model, + max_tokens: 100, + messages: [ + { + role: 'user', + content: [ + { type: 'text', text: 'Here is a large legal document to analyze:' }, + { type: 'text', text: largeContext, cache_control: { type: 'ephemeral' } }, + { type: 'text', text: 'What are the main indemnification principles?' }, + ], + }, + ], + } as never) + + expect(response1).toBeDefined() + expect(response1.usage).toBeDefined() + + // Second request with same cached content + const response2 = await client.messages.create({ + model, + max_tokens: 100, + messages: [ + { + role: 'user', + content: [ + { type: 'text', text: 'Here is a large legal document to analyze:' }, + { type: 'text', text: largeContext, cache_control: { type: 'ephemeral' } }, + { type: 'text', text: 'Summarize the dispute resolution methods.' }, + ], + }, + ], + } as never) + + expect(response2).toBeDefined() + expect(response2.usage).toBeDefined() + + console.log(`āœ… Prompt caching (messages) passed for anthropic/${model}`) + }) + }) + + describe('Prompt Caching - Tools', () => { + it('should support prompt caching with tools', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'tools') + + // Create multiple tools for caching + const tools = convertToAnthropicTools([WEATHER_TOOL, CALCULATOR_TOOL]) + // Add cache control to the last tool (use type assertion for beta feature) + const cachedTools = tools.map((tool, index) => + index === tools.length - 1 + ? { ...tool, cache_control: { type: 'ephemeral' as const } } + : tool + ) + + // First request - should create cache (use type assertion for beta cache_control) + const response1 = await client.messages.create({ + model, + max_tokens: 100, + tools: cachedTools, + messages: [ + { role: 'user', content: "What's the weather in Boston?" }, + ], + } as never) + + expect(response1).toBeDefined() + expect(response1.usage).toBeDefined() + + // Second request - should hit cache + const response2 = await client.messages.create({ + model, + max_tokens: 100, + tools: cachedTools, + messages: [ + { role: 'user', content: 'Calculate 42 * 17' }, + ], + } as never) + + expect(response2).toBeDefined() + expect(response2.usage).toBeDefined() + + console.log(`āœ… Prompt caching (tools) passed for anthropic/${model}`) + }) + }) + + // ============================================================================ + // Document Input Tests + // ============================================================================ + + describe('Document Input - PDF Base64', () => { + it('should handle PDF document input', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'file') + + // Sample PDF base64 (minimal PDF with "Hello World") + const pdfBase64 = + 'JVBERi0xLjcKCjEgMCBvYmogICUgZW50cnkgcG9pbnQKPDwKICAvVHlwZSAvQ2F0YWxvZwogIC' + + '9QYWdlcyAyIDAgUgo+PgplbmRvYmoKCjIgMCBvYmoKPDwKICAvVHlwZSAvUGFnZXwKICAvTWV' + + 'kaWFCb3ggWyAwIDAgMjAwIDIwMCBdCiAgL0NvdW50IDEKICAvS2lkcyBbIDMgMCBSIF0KPj4K' + + 'ZW5kb2JqCgozIDAgb2JqCjw8CiAgL1R5cGUgL1BhZ2UKICAvUGFyZW50IDIgMCBSCiAgL1Jlc' + + '291cmNlcyA8PAogICAgL0ZvbnQgPDwKICAgICAgL0YxIDQgMCBSCj4+CiAgPj4KICAvQ29udG' + + 'VudHMgNSAwIFIKPj4KZW5kb2JqCgo0IDAgb2JqCjw8CiAgL1R5cGUgL0ZvbnQKICAvU3VidHl' + + 'wZSAvVHlwZTEKICAvQmFzZUZvbnQgL1RpbWVzLVJvbWFuCj4+CmVuZG9iagoKNSAwIG9iago8' + + 'PAogIC9MZW5ndGggNDQKPj4Kc3RyZWFtCkJUCjcwIDUwIFRECi9GMSAxMiBUZgooSGVsbG8gV' + + '29ybGQhKSBUagpFVAplbmRzdHJlYW0KZW5kb2JqCgp4cmVmCjAgNgowMDAwMDAwMDAwIDY1NT' + + 'M1IGYgCjAwMDAwMDAwMTAgMDAwMDAgbiAKMDAwMDAwMDA2MCAwMDAwMCBuIAowMDAwMDAwMTU' + + '3IDAwMDAwIG4gCjAwMDAwMDAyNTUgMDAwMDAgbiAKMDAwMDAwMDM1MyAwMDAwMCBuIAp0cmFp' + + 'bGVyCjw8CiAgL1NpemUgNgogIC9Sb290IDEgMCBSCj4+CnN0YXJ0eHJlZgo0NDkKJSVFT0YK' + + try { + // Use type assertion for beta document feature + const response = await client.messages.create({ + model, + max_tokens: 200, + messages: [ + { + role: 'user', + content: [ + { type: 'text', text: 'What does this PDF document contain?' }, + { + type: 'document', + source: { + type: 'base64', + media_type: 'application/pdf', + data: pdfBase64, + }, + }, + ], + }, + ], + } as never) + + expect(response).toBeDefined() + const content = getContentString(response) + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… Document input (PDF Base64) passed for anthropic/${model}`) + } catch (error) { + // Document input may not be supported on all models + console.log(`āš ļø Document input test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + // ============================================================================ + // List Models Tests + // ============================================================================ + + describe('List Models', () => { + it('should list available models', async () => { + if (skipTests) return + + const client = getAnthropicClient() + + try { + // Use type assertion for beta models API + const response = await (client as unknown as { models: { list: () => Promise<{ data: unknown[] }> } }).models.list() + + expect(response).toBeDefined() + expect(response.data).toBeDefined() + expect(Array.isArray(response.data)).toBe(true) + expect(response.data.length).toBeGreaterThan(0) + + console.log(`āœ… List models passed for anthropic - ${response.data.length} models`) + } catch (error) { + // List models may not be available on all versions + console.log(`āš ļø List models test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + // ============================================================================ + // Extended Thinking Streaming Tests + // ============================================================================ + + describe('Extended Thinking Streaming', () => { + it('should stream extended thinking response', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'thinking') + + if (!model) { + console.log('āš ļø Skipping thinking streaming test: No thinking model configured') + return + } + + try { + // Use type assertion for beta thinking feature + const stream = client.messages.stream({ + model, + max_tokens: 3000, + thinking: { + type: 'enabled', + budget_tokens: 2000, + }, + messages: [ + { + role: 'user', + content: 'Alice, Bob, and Carol went to dinner. The total bill was $90. If they split it equally, how much does each person owe? Show your reasoning.', + }, + ], + } as never) + + let thinkingContent = '' + let textContent = '' + let chunkCount = 0 + let hasThinkingDelta = false + + for await (const event of stream) { + chunkCount++ + + if (event.type === 'content_block_start') { + const block = event.content_block as { type: string } + if (block?.type === 'thinking') { + hasThinkingDelta = true + } + } + + if (event.type === 'content_block_delta') { + const delta = event.delta as { type: string; thinking?: string; text?: string } + if (delta.type === 'thinking_delta' && delta.thinking) { + thinkingContent += delta.thinking + } else if (delta.type === 'text_delta' && delta.text) { + textContent += delta.text + } + } + + if (chunkCount > 5000) break + } + + expect(chunkCount).toBeGreaterThan(0) + expect(hasThinkingDelta || thinkingContent.length > 0).toBe(true) + console.log(`āœ… Extended thinking streaming passed for anthropic/${model} (${chunkCount} chunks)`) + } catch (error) { + console.log(`āš ļø Extended thinking streaming test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Extended Thinking Streaming - Client Disconnect', () => { + it('should handle client disconnect mid-stream during extended thinking', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'thinking') + + if (!model) { + console.log('āš ļø Skipping thinking streaming disconnect test: No thinking model configured') + return + } + + const abortController = new AbortController() + + try { + // Use type assertion for beta thinking feature + const stream = client.messages.stream({ + model, + max_tokens: 5000, + thinking: { + type: 'enabled', + budget_tokens: 3000, + }, + messages: [ + { + role: 'user', + content: 'Solve this complex problem step by step: A train leaves Station A at 8:00 AM traveling at 60 mph. Another train leaves Station B, 300 miles away, at 9:00 AM traveling toward Station A at 80 mph. At what time will they meet? Show all your detailed reasoning.', + }, + ], + } as never, { + signal: abortController.signal, + }) + + let chunkCount = 0 + let wasAborted = false + + try { + for await (const event of stream) { + chunkCount++ + + // Abort after receiving a few chunks + if (chunkCount >= 10) { + abortController.abort() + } + } + } catch (error) { + wasAborted = true + expect(error).toBeDefined() + const errorMessage = error instanceof Error ? error.message.toLowerCase() : String(error).toLowerCase() + const isAbortError = errorMessage.includes('abort') || + errorMessage.includes('cancel') || + error instanceof DOMException || + (error as { name?: string })?.name === 'AbortError' + expect(isAbortError).toBe(true) + } + + expect(chunkCount).toBeGreaterThanOrEqual(10) + expect(wasAborted).toBe(true) + console.log(`āœ… Extended thinking streaming client disconnect passed for anthropic/${model} (${chunkCount} chunks before abort)`) + } catch (error) { + console.log(`āš ļø Extended thinking streaming disconnect test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + // ============================================================================ + // Files API Tests + // ============================================================================ + + describe('Files API - Upload', () => { + it('should upload a file', async () => { + if (skipTests) return + + const client = getAnthropicClient() + + try { + const beta = (client as unknown as { beta: { files: { upload: (params: { file: [string, Uint8Array, string] }) => Promise<{ id: string }> } } }).beta + + const testContent = new TextEncoder().encode('This is a test file for Files API integration testing.') + const response = await beta.files.upload({ + file: ['test_upload.txt', testContent, 'text/plain'], + }) + + expect(response).toBeDefined() + expect(response.id).toBeDefined() + expect(response.id.length).toBeGreaterThan(0) + + console.log(`āœ… Files API upload passed for anthropic - File ID: ${response.id}`) + + // Clean up + try { + const betaFiles = (client as unknown as { beta: { files: { delete: (id: string) => Promise } } }).beta + await betaFiles.files.delete(response.id) + } catch (e) { + console.log(`Warning: Failed to clean up file: ${e}`) + } + } catch (error) { + console.log(`āš ļø Files API upload test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Files API - List', () => { + it('should list files', async () => { + if (skipTests) return + + const client = getAnthropicClient() + + try { + const beta = (client as unknown as { + beta: { + files: { + upload: (params: { file: [string, Uint8Array, string] }) => Promise<{ id: string }> + list: () => Promise<{ data: Array<{ id: string }> }> + delete: (id: string) => Promise + } + } + }).beta + + // Upload a file first + const testContent = new TextEncoder().encode('Test file for listing') + const uploadedFile = await beta.files.upload({ + file: ['test_list.txt', testContent, 'text/plain'], + }) + + try { + const response = await beta.files.list() + + expect(response).toBeDefined() + expect(response.data).toBeDefined() + expect(Array.isArray(response.data)).toBe(true) + + const fileIds = response.data.map((f) => f.id) + expect(fileIds).toContain(uploadedFile.id) + + console.log(`āœ… Files API list passed for anthropic - ${response.data.length} files`) + } finally { + try { + await beta.files.delete(uploadedFile.id) + } catch (e) { + console.log(`Warning: Failed to clean up file: ${e}`) + } + } + } catch (error) { + console.log(`āš ļø Files API list test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Files API - Delete', () => { + it('should delete a file', async () => { + if (skipTests) return + + const client = getAnthropicClient() + + try { + const beta = (client as unknown as { + beta: { + files: { + upload: (params: { file: [string, Uint8Array, string] }) => Promise<{ id: string }> + delete: (id: string) => Promise + retrieve: (id: string) => Promise + } + } + }).beta + + // Upload a file first + const testContent = new TextEncoder().encode('Test file for deletion') + const uploadedFile = await beta.files.upload({ + file: ['test_delete.txt', testContent, 'text/plain'], + }) + + // Delete the file + const response = await beta.files.delete(uploadedFile.id) + expect(response).toBeDefined() + + console.log(`āœ… Files API delete passed for anthropic - Deleted file ${uploadedFile.id}`) + + // Verify file is gone + try { + await beta.files.retrieve(uploadedFile.id) + // Should not get here + } catch (e) { + // Expected - file should not be found + expect(e).toBeDefined() + } + } catch (error) { + console.log(`āš ļø Files API delete test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Files API - Content', () => { + it('should download file content', async () => { + if (skipTests) return + + const client = getAnthropicClient() + + try { + const beta = (client as unknown as { + beta: { + files: { + upload: (params: { file: [string, Uint8Array, string] }) => Promise<{ id: string }> + download: (id: string) => Promise<{ text: () => string }> + delete: (id: string) => Promise + } + } + }).beta + + const originalContent = 'Test file content for download' + const testContent = new TextEncoder().encode(originalContent) + const uploadedFile = await beta.files.upload({ + file: ['test_content.txt', testContent, 'text/plain'], + }) + + try { + const response = await beta.files.download(uploadedFile.id) + expect(response).toBeDefined() + + const downloadedContent = response.text() + expect(downloadedContent).toBe(originalContent) + + console.log(`āœ… Files API content passed for anthropic - Downloaded ${downloadedContent.length} bytes`) + } catch (downloadError) { + // Some providers don't allow downloading uploaded files + console.log(`āš ļø Files API download not supported: ${downloadError instanceof Error ? downloadError.message : 'Unknown error'}`) + } finally { + try { + await beta.files.delete(uploadedFile.id) + } catch (e) { + console.log(`Warning: Failed to clean up file: ${e}`) + } + } + } catch (error) { + console.log(`āš ļø Files API content test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + // ============================================================================ + // Batch API Tests + // ============================================================================ + + describe('Batch API - Create Inline', () => { + it('should create a batch job with inline requests', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + try { + const beta = (client as unknown as { + beta: { + messages: { + batches: { + create: (params: { requests: Array<{ custom_id: string; params: unknown }> }) => Promise<{ id: string; processing_status: string }> + cancel: (id: string) => Promise + } + } + } + }).beta + + const batchRequests = [ + { + custom_id: 'request-1', + params: { + model, + max_tokens: 50, + messages: [{ role: 'user', content: 'Say hello' }], + }, + }, + { + custom_id: 'request-2', + params: { + model, + max_tokens: 50, + messages: [{ role: 'user', content: 'Say goodbye' }], + }, + }, + ] + + const batch = await beta.messages.batches.create({ requests: batchRequests }) + + expect(batch).toBeDefined() + expect(batch.id).toBeDefined() + expect(batch.processing_status).toBeDefined() + + console.log(`āœ… Batch API create inline passed for anthropic - Batch ID: ${batch.id}, Status: ${batch.processing_status}`) + + // Clean up + try { + await beta.messages.batches.cancel(batch.id) + } catch (e) { + console.log(`Info: Could not cancel batch: ${e}`) + } + } catch (error) { + console.log(`āš ļø Batch API create inline test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Batch API - List', () => { + it('should list batch jobs', async () => { + if (skipTests) return + + const client = getAnthropicClient() + + try { + const beta = (client as unknown as { + beta: { + messages: { + batches: { + list: (params: { limit: number }) => Promise<{ data: Array<{ id: string }> }> + } + } + } + }).beta + + const response = await beta.messages.batches.list({ limit: 10 }) + + expect(response).toBeDefined() + expect(response.data).toBeDefined() + expect(Array.isArray(response.data)).toBe(true) + + console.log(`āœ… Batch API list passed for anthropic - ${response.data.length} batches`) + } catch (error) { + console.log(`āš ļø Batch API list test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Batch API - Retrieve', () => { + it('should retrieve batch status by ID', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + try { + const beta = (client as unknown as { + beta: { + messages: { + batches: { + create: (params: { requests: Array<{ custom_id: string; params: unknown }> }) => Promise<{ id: string; processing_status: string }> + retrieve: (id: string) => Promise<{ id: string; processing_status: string }> + cancel: (id: string) => Promise + } + } + } + }).beta + + // Create a batch first + const batchRequests = [{ + custom_id: 'request-1', + params: { + model, + max_tokens: 50, + messages: [{ role: 'user', content: 'Say hello' }], + }, + }] + + const batch = await beta.messages.batches.create({ requests: batchRequests }) + + try { + const retrieved = await beta.messages.batches.retrieve(batch.id) + + expect(retrieved).toBeDefined() + expect(retrieved.id).toBe(batch.id) + expect(retrieved.processing_status).toBeDefined() + + console.log(`āœ… Batch API retrieve passed for anthropic - Batch ID: ${retrieved.id}, Status: ${retrieved.processing_status}`) + } finally { + try { + await beta.messages.batches.cancel(batch.id) + } catch (e) { + console.log(`Info: Could not cancel batch: ${e}`) + } + } + } catch (error) { + console.log(`āš ļø Batch API retrieve test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Batch API - Cancel', () => { + it('should cancel a batch job', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + try { + const beta = (client as unknown as { + beta: { + messages: { + batches: { + create: (params: { requests: Array<{ custom_id: string; params: unknown }> }) => Promise<{ id: string; processing_status: string }> + cancel: (id: string) => Promise<{ id: string; processing_status: string }> + } + } + } + }).beta + + // Create a batch first + const batchRequests = [{ + custom_id: 'request-1', + params: { + model, + max_tokens: 50, + messages: [{ role: 'user', content: 'Say hello' }], + }, + }] + + const batch = await beta.messages.batches.create({ requests: batchRequests }) + + // Cancel the batch + const cancelled = await beta.messages.batches.cancel(batch.id) + + expect(cancelled).toBeDefined() + expect(cancelled.id).toBe(batch.id) + expect(['canceling', 'ended']).toContain(cancelled.processing_status) + + console.log(`āœ… Batch API cancel passed for anthropic - Batch ID: ${cancelled.id}, Status: ${cancelled.processing_status}`) + } catch (error) { + console.log(`āš ļø Batch API cancel test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Batch API - Results', () => { + it('should retrieve batch results', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + try { + const beta = (client as unknown as { + beta: { + messages: { + batches: { + create: (params: { requests: Array<{ custom_id: string; params: unknown }> }) => Promise<{ id: string; processing_status: string }> + results: (id: string) => AsyncIterable<{ custom_id: string }> + cancel: (id: string) => Promise + } + } + } + }).beta + + // Create a batch first + const batchRequests = [{ + custom_id: 'request-1', + params: { + model, + max_tokens: 50, + messages: [{ role: 'user', content: 'Say hello' }], + }, + }] + + const batch = await beta.messages.batches.create({ requests: batchRequests }) + + try { + const results = beta.messages.batches.results(batch.id) + + let resultCount = 0 + for await (const result of results) { + resultCount++ + expect(result.custom_id).toBeDefined() + } + + console.log(`āœ… Batch API results passed for anthropic - ${resultCount} results`) + } catch (resultsError) { + // Results might not be ready yet + console.log(`āš ļø Batch results not ready: ${resultsError instanceof Error ? resultsError.message : 'Unknown error'}`) + } finally { + try { + await beta.messages.batches.cancel(batch.id) + } catch (e) { + console.log(`Info: Could not cancel batch: ${e}`) + } + } + } catch (error) { + console.log(`āš ļø Batch API results test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Batch API - E2E', () => { + it('should complete end-to-end batch workflow', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + try { + const beta = (client as unknown as { + beta: { + messages: { + batches: { + create: (params: { requests: Array<{ custom_id: string; params: unknown }> }) => Promise<{ id: string; processing_status: string }> + retrieve: (id: string) => Promise<{ id: string; processing_status: string; request_counts?: { processing: number; succeeded: number; errored: number } }> + list: (params: { limit: number }) => Promise<{ data: Array<{ id: string }> }> + cancel: (id: string) => Promise + } + } + } + }).beta + + // Step 1: Create batch + console.log('Step 1: Creating batch...') + const batchRequests = [ + { + custom_id: 'e2e-request-1', + params: { + model, + max_tokens: 50, + messages: [{ role: 'user', content: 'Say hello' }], + }, + }, + { + custom_id: 'e2e-request-2', + params: { + model, + max_tokens: 50, + messages: [{ role: 'user', content: 'Say goodbye' }], + }, + }, + ] + + const batch = await beta.messages.batches.create({ requests: batchRequests }) + expect(batch.id).toBeDefined() + console.log(` Created batch: ${batch.id}, status: ${batch.processing_status}`) + + try { + // Step 2: Poll batch status + console.log('Step 2: Polling batch status...') + for (let i = 0; i < 3; i++) { + const retrieved = await beta.messages.batches.retrieve(batch.id) + console.log(` Poll ${i + 1}: status = ${retrieved.processing_status}`) + + if (retrieved.processing_status === 'ended') { + break + } + + await new Promise((resolve) => setTimeout(resolve, 2000)) + } + + // Step 3: Verify batch in list + console.log('Step 3: Verifying batch in list...') + const listResponse = await beta.messages.batches.list({ limit: 20 }) + const batchIds = listResponse.data.map((b) => b.id) + expect(batchIds).toContain(batch.id) + + console.log(`āœ… Batch API E2E passed for anthropic - Batch ID: ${batch.id}`) + } finally { + try { + await beta.messages.batches.cancel(batch.id) + } catch (e) { + console.log(`Info: Could not cancel batch: ${e}`) + } + } + } catch (error) { + console.log(`āš ļø Batch API E2E test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + // ============================================================================ + // Additional Input Tokens Tests + // ============================================================================ + + describe('Count Tokens - With System Message', () => { + it('should return token usage with system message', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + const response = await client.messages.create({ + model, + max_tokens: 50, + system: 'You are a helpful assistant.', + messages: [{ role: 'user', content: 'What is 2 + 2?' }], + }) + + expect(response.usage).toBeDefined() + expect(response.usage.input_tokens).toBeGreaterThan(0) + expect(response.usage.output_tokens).toBeGreaterThan(0) + console.log(`āœ… Count tokens with system message passed for anthropic/${model} (input: ${response.usage.input_tokens}, output: ${response.usage.output_tokens})`) + }) + }) + + describe('Count Tokens - Long Text', () => { + it('should return token usage for long text', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'chat') + + const longText = 'This is a longer piece of text that should result in more tokens being counted. '.repeat(10) + + const response = await client.messages.create({ + model, + max_tokens: 50, + messages: [{ role: 'user', content: longText }], + }) + + expect(response.usage).toBeDefined() + expect(response.usage.input_tokens).toBeGreaterThan(50) + expect(response.usage.output_tokens).toBeGreaterThan(0) + console.log(`āœ… Count tokens long text passed for anthropic/${model} (input: ${response.usage.input_tokens}, output: ${response.usage.output_tokens})`) + }) + }) + + // ============================================================================ + // Document Text Input Tests + // ============================================================================ + + describe('Document Input - Plain Text', () => { + it('should handle plain text document input', async () => { + if (skipTests) return + + const client = getAnthropicClient() + const model = getProviderModel('anthropic', 'file') + + const textDocument = ` + DOCUMENT TITLE: Test Agreement + + Section 1: Introduction + This document is a test agreement for integration testing purposes. + + Section 2: Terms + The parties agree to test the API functionality. + + Section 3: Conclusion + This concludes the test document. + ` + + const textBase64 = Buffer.from(textDocument).toString('base64') + + try { + // Use type assertion for beta document feature + const response = await client.messages.create({ + model, + max_tokens: 200, + messages: [ + { + role: 'user', + content: [ + { type: 'text', text: 'Summarize the sections in this document.' }, + { + type: 'document', + source: { + type: 'base64', + media_type: 'text/plain', + data: textBase64, + }, + }, + ], + }, + ], + } as never) + + expect(response).toBeDefined() + const content = getContentString(response) + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… Document input (plain text) passed for anthropic/${model}`) + } catch (error) { + console.log(`āš ļø Document input (plain text) test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) +}) diff --git a/tests/integrations/typescript/tests/test-bedrock.test.ts b/tests/integrations/typescript/tests/test-bedrock.test.ts new file mode 100644 index 0000000000..b82d04c7b4 --- /dev/null +++ b/tests/integrations/typescript/tests/test-bedrock.test.ts @@ -0,0 +1,662 @@ +/** + * Bedrock Integration Tests - Cross-Provider Support + * + * This test suite uses the AWS SDK (v3) to test against multiple AI providers through Bifrost. + * Tests automatically run against all available providers with proper capability filtering. + * All requests include the x-model-provider header to route to the appropriate provider. + * + * Test Scenarios: + * 1. Simple chat (converse) + * 2. Multi-turn conversation (converse) + * 3. Streaming chat (converse-stream) + * 4. Single tool call (converse) + * 5. Multiple tool calls (converse) + * 6. End-to-end tool calling (converse) + * 7. Image analysis (converse) + * 8. System message handling (converse) + */ + +import { + BedrockRuntimeClient, + ConverseCommand, + ConverseStreamCommand, + type ContentBlock, + type Message, + type Tool, + type ToolConfiguration, + type ToolResultContentBlock, + type ToolUseBlock, +} from '@aws-sdk/client-bedrock-runtime' +import { describe, expect, it } from 'vitest' + +import { + getConfig, + getIntegrationUrl, + getProviderModel, +} from '../src/utils/config-loader' + +import { + BASE64_IMAGE, + CALCULATOR_TOOL, + LOCATION_KEYWORDS, + MULTI_TURN_MESSAGES, + MULTIPLE_TOOL_CALL_MESSAGES, + SIMPLE_CHAT_MESSAGES, + WEATHER_KEYWORDS, + WEATHER_TOOL, + mockToolResponse, + type ChatMessage, + type ToolDefinition, +} from '../src/utils/common' + +import { + formatProviderModel, + getCrossProviderParamsWithVkForScenario, + shouldSkipNoProviders, + type ProviderModelVkParam, +} from '../src/utils/parametrize' + +// ============================================================================ +// Helper Functions +// ============================================================================ + +function getBedrockRuntimeClient(): BedrockRuntimeClient { + const baseUrl = getIntegrationUrl('bedrock') + const config = getConfig() + const integrationSettings = config.getIntegrationSettings('bedrock') + const region = (integrationSettings.region as string) || 'us-west-2' + + return new BedrockRuntimeClient({ + region, + endpoint: baseUrl, + credentials: { + accessKeyId: process.env.AWS_ACCESS_KEY_ID || '', + secretAccessKey: process.env.AWS_SECRET_ACCESS_KEY || '', + }, + requestHandler: { + requestTimeout: 300000, // 5 minutes + } as never, + }) +} + +function convertToBedrockMessages(messages: ChatMessage[]): Message[] { + const bedrockMessages: Message[] = [] + + for (const msg of messages) { + if (msg.role === 'system') { + continue + } + + const content: ContentBlock[] = [] + + if (Array.isArray(msg.content)) { + for (const item of msg.content) { + if (item.type === 'text') { + content.push({ text: item.text }) + } else if (item.type === 'image_url' && item.image_url) { + const url = item.image_url.url + if (url.startsWith('data:image')) { + const [header, data] = url.split(',') + const mediaType = header.split(';')[0].split(':')[1] + const format = mediaType.split('/')[1] as 'png' | 'jpeg' | 'gif' | 'webp' + const imageBytes = Buffer.from(data, 'base64') + content.push({ + image: { + format, + source: { bytes: imageBytes }, + }, + }) + } + } + } + } else { + content.push({ text: msg.content }) + } + + const role = msg.role === 'user' ? 'user' : 'assistant' + bedrockMessages.push({ role, content }) + } + + return bedrockMessages +} + +function convertToBedrockTools(tools: ToolDefinition[]): ToolConfiguration { + const bedrockTools: Tool[] = tools.map((tool) => ({ + toolSpec: { + name: tool.name, + description: tool.description, + inputSchema: { json: tool.parameters }, + }, + })) + + return { tools: bedrockTools } +} + +function extractSystemMessages(messages: ChatMessage[]): { text: string }[] { + return messages + .filter((msg) => msg.role === 'system') + .map((msg) => ({ text: msg.content as string })) +} + +function extractToolCalls(response: { output?: { message?: Message } }): Array<{ + id: string + name: string + arguments: Record +}> { + const toolCalls: Array<{ + id: string + name: string + arguments: Record + }> = [] + + const message = response.output?.message + if (!message?.content) return toolCalls + + for (const item of message.content) { + if ('toolUse' in item && item.toolUse) { + const toolUse = item.toolUse as ToolUseBlock + toolCalls.push({ + id: toolUse.toolUseId || '', + name: toolUse.name || '', + arguments: (toolUse.input as Record) || {}, + }) + } + } + + return toolCalls +} + +function assertValidChatResponse(response: { output?: { message?: Message } }): void { + expect(response).toBeDefined() + expect(response.output).toBeDefined() + expect(response.output?.message).toBeDefined() + expect(response.output?.message?.content).toBeDefined() + expect(response.output?.message?.content?.length).toBeGreaterThan(0) +} + +function assertHasToolCalls( + response: { output?: { message?: Message } }, + expectedCount?: number +): void { + const toolCalls = extractToolCalls(response) + expect(toolCalls.length).toBeGreaterThan(0) + if (expectedCount !== undefined) { + expect(toolCalls.length).toBe(expectedCount) + } +} + +function getTextContent(response: { output?: { message?: Message } }): string { + const message = response.output?.message + if (!message?.content) return '' + + for (const item of message.content) { + if ('text' in item && item.text) { + return item.text + } + } + return '' +} + +// ============================================================================ +// Test Suite +// ============================================================================ + +describe('Bedrock SDK Integration Tests', () => { + // ============================================================================ + // Simple Chat Tests + // ============================================================================ + + describe('Simple Chat', () => { + const testCases = getCrossProviderParamsWithVkForScenario('simple_chat', ['bedrock']) + + it.each(testCases)( + 'should complete a simple chat - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for simple_chat') + return + } + + const client = getBedrockRuntimeClient() + const messages = convertToBedrockMessages(SIMPLE_CHAT_MESSAGES) + const modelId = formatProviderModel(provider, model) + + const command = new ConverseCommand({ + modelId, + messages, + inferenceConfig: { maxTokens: 100 }, + }) + + const response = await client.send(command) + assertValidChatResponse(response) + + const textContent = getTextContent(response) + expect(textContent.length).toBeGreaterThan(0) + console.log(`āœ… Simple chat passed for ${modelId}`) + } + ) + }) + + // ============================================================================ + // Multi-turn Conversation Tests + // ============================================================================ + + describe('Multi-turn Conversation', () => { + const testCases = getCrossProviderParamsWithVkForScenario('multi_turn_conversation', ['bedrock']) + + it.each(testCases)( + 'should handle multi-turn conversation - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for multi_turn_conversation') + return + } + + const client = getBedrockRuntimeClient() + const messages = convertToBedrockMessages(MULTI_TURN_MESSAGES) + const modelId = formatProviderModel(provider, model) + + const command = new ConverseCommand({ + modelId, + messages, + inferenceConfig: { maxTokens: 150 }, + }) + + const response = await client.send(command) + assertValidChatResponse(response) + + const textContent = getTextContent(response).toLowerCase() + const populationKeywords = ['population', 'million', 'people', 'inhabitants', 'resident'] + expect(populationKeywords.some((word) => textContent.includes(word))).toBe(true) + console.log(`āœ… Multi-turn conversation passed for ${modelId}`) + } + ) + }) + + // ============================================================================ + // Streaming Tests + // ============================================================================ + + describe('Streaming Chat', () => { + const testCases = getCrossProviderParamsWithVkForScenario('streaming', ['bedrock']) + + it.each(testCases)( + 'should stream chat response - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for streaming') + return + } + + const client = getBedrockRuntimeClient() + const messages = convertToBedrockMessages([ + { role: 'user', content: 'Say hello in exactly 3 words.' }, + ]) + const modelId = formatProviderModel(provider, model) + + const command = new ConverseStreamCommand({ + modelId, + messages, + inferenceConfig: { maxTokens: 100 }, + }) + + const response = await client.send(command) + const chunks: string[] = [] + + if (response.stream) { + for await (const event of response.stream) { + if (event.contentBlockDelta) { + const delta = event.contentBlockDelta.delta + if (delta && 'text' in delta && delta.text) { + chunks.push(delta.text) + } + } + } + } + + const combinedText = chunks.join('') + expect(combinedText.length).toBeGreaterThan(0) + console.log(`āœ… Streaming chat passed for ${modelId}`) + } + ) + }) + + // ============================================================================ + // Streaming Client Disconnect Tests + // ============================================================================ + + describe('Streaming Chat - Client Disconnect', () => { + const testCases = getCrossProviderParamsWithVkForScenario('streaming', ['bedrock']) + + it.each(testCases)( + 'should handle client disconnect mid-stream - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for streaming') + return + } + + const client = getBedrockRuntimeClient() + const abortController = new AbortController() + + // Request a longer response to ensure we have time to abort mid-stream + const messages = convertToBedrockMessages([ + { role: 'user', content: 'Write a detailed essay about the history of computing, including at least 10 paragraphs.' }, + ]) + const modelId = formatProviderModel(provider, model) + + const command = new ConverseStreamCommand({ + modelId, + messages, + inferenceConfig: { maxTokens: 1000 }, + }) + + const response = await client.send(command, { + abortSignal: abortController.signal, + }) + + let chunkCount = 0 + let content = '' + let wasAborted = false + + try { + if (response.stream) { + for await (const event of response.stream) { + chunkCount++ + if (event.contentBlockDelta) { + const delta = event.contentBlockDelta.delta + if (delta && 'text' in delta && delta.text) { + content += delta.text + } + } + + // Abort after receiving a few chunks + if (chunkCount >= 5) { + abortController.abort() + } + } + } + } catch (error) { + wasAborted = true + expect(error).toBeDefined() + // The error should be an AbortError or contain abort-related message + const errorMessage = error instanceof Error ? error.message.toLowerCase() : String(error).toLowerCase() + const errorName = (error as { name?: string })?.name?.toLowerCase() || '' + const isAbortError = errorMessage.includes('abort') || + errorMessage.includes('cancel') || + errorName.includes('abort') || + error instanceof DOMException || + (error as { name?: string })?.name === 'AbortError' + expect(isAbortError).toBe(true) + } + + // Verify we received some content before aborting + expect(chunkCount).toBeGreaterThanOrEqual(5) + expect(content.length).toBeGreaterThan(0) + expect(wasAborted).toBe(true) + console.log(`āœ… Streaming client disconnect passed for ${modelId} (${chunkCount} chunks before abort)`) + } + ) + }) + + // ============================================================================ + // Tool Calling Tests + // ============================================================================ + + describe('Single Tool Call', () => { + const testCases = getCrossProviderParamsWithVkForScenario('tool_calls', ['bedrock']) + + it.each(testCases)( + 'should make a single tool call - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for tool_calls') + return + } + + const client = getBedrockRuntimeClient() + const toolModel = getProviderModel(provider, 'tools') + const modelId = formatProviderModel(provider, toolModel || model) + + const messages = convertToBedrockMessages([ + { role: 'user', content: "What's the weather in Boston?" }, + ]) + const toolConfig = convertToBedrockTools([WEATHER_TOOL]) + toolConfig.toolChoice = { any: {} } + + const command = new ConverseCommand({ + modelId, + messages, + toolConfig, + inferenceConfig: { maxTokens: 500 }, + }) + + const response = await client.send(command) + assertHasToolCalls(response, 1) + + const toolCalls = extractToolCalls(response) + expect(toolCalls[0].name).toBe('get_weather') + console.log(`āœ… Single tool call passed for ${modelId}`) + } + ) + }) + + describe('Multiple Tool Calls', () => { + const testCases = getCrossProviderParamsWithVkForScenario('multiple_tool_calls', ['bedrock']) + + it.each(testCases)( + 'should make multiple tool calls - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for multiple_tool_calls') + return + } + + const client = getBedrockRuntimeClient() + const toolModel = getProviderModel(provider, 'tools') + const modelId = formatProviderModel(provider, toolModel || model) + + const messages = convertToBedrockMessages(MULTIPLE_TOOL_CALL_MESSAGES) + const toolConfig = convertToBedrockTools([WEATHER_TOOL, CALCULATOR_TOOL]) + toolConfig.toolChoice = { any: {} } + + const command = new ConverseCommand({ + modelId, + messages, + toolConfig, + inferenceConfig: { maxTokens: 200 }, + }) + + const response = await client.send(command) + const toolCalls = extractToolCalls(response) + expect(toolCalls.length).toBeGreaterThanOrEqual(1) + + const toolNames = toolCalls.map((tc) => tc.name) + const expectedTools = ['get_weather', 'calculate'] + expect(toolNames.some((name) => expectedTools.includes(name))).toBe(true) + console.log(`āœ… Multiple tool calls passed for ${modelId}`) + } + ) + }) + + describe('End-to-End Tool Calling', () => { + const testCases = getCrossProviderParamsWithVkForScenario('end2end_tool_calling', ['bedrock']) + + it.each(testCases)( + 'should complete end-to-end tool calling - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for end2end_tool_calling') + return + } + + const client = getBedrockRuntimeClient() + const toolModel = getProviderModel(provider, 'tools') + const modelId = formatProviderModel(provider, toolModel || model) + + // Step 1: Initial request + let messages = convertToBedrockMessages([ + { role: 'user', content: "What's the weather in San Francisco?" }, + ]) + const toolConfig = convertToBedrockTools([WEATHER_TOOL]) + toolConfig.toolChoice = { any: {} } + + const command1 = new ConverseCommand({ + modelId, + messages, + toolConfig, + inferenceConfig: { maxTokens: 500 }, + }) + + const response1 = await client.send(command1) + assertHasToolCalls(response1, 1) + + const toolCalls = extractToolCalls(response1) + expect(toolCalls[0].name).toBe('get_weather') + + // Step 2: Append assistant response and tool result + const assistantMessage = response1.output?.message + if (assistantMessage) { + messages = [...messages, assistantMessage] + } + + const toolCall = toolCalls[0] + const toolResponseText = mockToolResponse(toolCall.name, toolCall.arguments) + + const toolResultContent: ToolResultContentBlock[] = [{ text: toolResponseText }] + messages.push({ + role: 'user', + content: [ + { + toolResult: { + toolUseId: toolCall.id, + content: toolResultContent, + status: 'success', + }, + }, + ], + }) + + // Step 3: Final request with tool results + const command2 = new ConverseCommand({ + modelId, + messages, + toolConfig, + inferenceConfig: { maxTokens: 500 }, + }) + + const response2 = await client.send(command2) + assertValidChatResponse(response2) + + const finalText = getTextContent(response2).toLowerCase() + const weatherLocationKeywords = [...WEATHER_KEYWORDS, ...LOCATION_KEYWORDS, 'san francisco', 'sf'] + expect(weatherLocationKeywords.some((word) => finalText.includes(word))).toBe(true) + console.log(`āœ… End-to-end tool calling passed for ${modelId}`) + } + ) + }) + + // ============================================================================ + // Image Analysis Tests + // ============================================================================ + + describe('Image Base64', () => { + const testCases = getCrossProviderParamsWithVkForScenario('image_base64', ['bedrock']) + + it.each(testCases)( + 'should analyze image from Base64 - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for image_base64') + return + } + + const client = getBedrockRuntimeClient() + const visionModel = getProviderModel(provider, 'vision') + const modelId = formatProviderModel(provider, visionModel || model) + + const messages = convertToBedrockMessages([ + { + role: 'user', + content: [ + { + type: 'text', + text: 'What do you see in this image? Describe what you see.', + }, + { + type: 'image_url', + image_url: { url: `data:image/png;base64,${BASE64_IMAGE}` }, + }, + ], + }, + ]) + + const command = new ConverseCommand({ + modelId, + messages, + inferenceConfig: { maxTokens: 500 }, + }) + + const response = await client.send(command) + assertValidChatResponse(response) + + const textContent = getTextContent(response).toLowerCase() + const imageKeywords = [ + 'image', 'picture', 'photo', 'see', 'visual', 'show', + 'appear', 'color', 'scene', 'pixel', 'red', 'square', + ] + const hasImageReference = imageKeywords.some((keyword) => textContent.includes(keyword)) + expect(hasImageReference || textContent.length > 5).toBe(true) + console.log(`āœ… Image Base64 analysis passed for ${modelId}`) + } + ) + }) + + // ============================================================================ + // System Message Tests + // ============================================================================ + + describe('System Message', () => { + const testCases = getCrossProviderParamsWithVkForScenario('simple_chat', ['bedrock']) + + it.each(testCases)( + 'should handle system message - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for simple_chat') + return + } + + const client = getBedrockRuntimeClient() + const modelId = formatProviderModel(provider, model) + + const messagesWithSystem: ChatMessage[] = [ + { role: 'system', content: 'You are a helpful assistant that always responds in exactly 5 words.' }, + { role: 'user', content: 'Hello, how are you?' }, + ] + + const systemMessages = extractSystemMessages(messagesWithSystem) + const bedrockMessages = convertToBedrockMessages(messagesWithSystem) + + const command = new ConverseCommand({ + modelId, + messages: bedrockMessages, + system: systemMessages, + inferenceConfig: { maxTokens: 50 }, + }) + + const response = await client.send(command) + assertValidChatResponse(response) + + const textContent = getTextContent(response) + expect(textContent.length).toBeGreaterThan(0) + + // Check if response is approximately 5 words (allow some flexibility) + const wordCount = textContent.split(/\s+/).length + expect(wordCount).toBeGreaterThanOrEqual(3) + expect(wordCount).toBeLessThanOrEqual(10) + console.log(`āœ… System message handling passed for ${modelId}`) + } + ) + }) +}) diff --git a/tests/integrations/typescript/tests/test-google.test.ts b/tests/integrations/typescript/tests/test-google.test.ts new file mode 100644 index 0000000000..aeff1f5e7f --- /dev/null +++ b/tests/integrations/typescript/tests/test-google.test.ts @@ -0,0 +1,748 @@ +/** + * Google GenAI Integration Tests + * + * This test suite uses the Google Generative AI SDK to test Gemini models. + * Note: The @google/generative-ai SDK does not support custom base URL configuration, + * so these tests validate the SDK directly against Google's API rather than routing + * through Bifrost. To test Google models through Bifrost, use the OpenAI SDK with + * model name routing (e.g., model: "gemini/gemini-1.5-pro") or the LangChain tests. + * + * Tests cover chat, streaming, tool calling, and vision capabilities. + * + * Test Scenarios: + * 1. Simple chat + * 2. Multi-turn conversation + * 3. Streaming chat + * 4. Single tool call + * 5. Multiple tool calls + * 6. End-to-end tool calling + * 7. Image Base64 + * 8. Embeddings + * 9. Count tokens + */ + +import { describe, it, expect, beforeAll } from 'vitest' +import { + GoogleGenerativeAI, + GenerativeModel, + Content, + Part, + FunctionDeclaration, + Tool, + SchemaType, +} from '@google/generative-ai' + +// Explicit type mapping for tool parameters to avoid invalid enum values from toUpperCase() +const TYPE_MAP: Record = { + string: SchemaType.STRING, + number: SchemaType.NUMBER, + integer: SchemaType.INTEGER, + boolean: SchemaType.BOOLEAN, + array: SchemaType.ARRAY, + object: SchemaType.OBJECT, +} + +import { + getIntegrationUrl, + getProviderModel, + isProviderAvailable, + getConfig, +} from '../src/utils/config-loader' + +import { + SIMPLE_CHAT_MESSAGES, + MULTI_TURN_MESSAGES, + STREAMING_CHAT_MESSAGES, + SINGLE_TOOL_CALL_MESSAGES, + MULTIPLE_TOOL_CALL_MESSAGES, + BASE64_IMAGE, + WEATHER_TOOL, + CALCULATOR_TOOL, + EMBEDDINGS_SINGLE_TEXT, + EMBEDDINGS_MULTIPLE_TEXTS, + getApiKey, + hasApiKey, + mockToolResponse, + type ChatMessage, + type ToolDefinition, +} from '../src/utils/common' + +// ============================================================================ +// Helper Functions +// ============================================================================ + +function getGoogleClient(): GoogleGenerativeAI { + // Note: The @google/generative-ai SDK does not support custom base URL configuration. + // Unlike OpenAI and Anthropic SDKs, requests cannot be routed through Bifrost directly. + // These tests validate the Google GenAI SDK directly against Google's API. + // To test Google models through Bifrost, use the OpenAI SDK with model name routing + // (e.g., model: "gemini/gemini-1.5-pro") or the LangChain tests. + const apiKey = hasApiKey('gemini') ? getApiKey('gemini') : 'dummy-key' + return new GoogleGenerativeAI(apiKey) +} + +function getGenerativeModel(modelName?: string): GenerativeModel { + const client = getGoogleClient() + const model = modelName || getProviderModel('gemini', 'chat') + return client.getGenerativeModel({ model }) +} + +function convertToGoogleContent(messages: ChatMessage[]): Content[] { + return messages.map((msg) => { + const role = msg.role === 'assistant' ? 'model' : 'user' + + if (typeof msg.content === 'string') { + return { + role, + parts: [{ text: msg.content }], + } + } + + // Handle multimodal content + const parts: Part[] = msg.content.map((part) => { + if (part.type === 'text') { + return { text: part.text! } + } + + // Handle image content + const imageUrl = part.image_url!.url + if (imageUrl.startsWith('data:')) { + // Extract base64 data and mime type + const matches = imageUrl.match(/^data:([^;]+);base64,(.+)$/) + if (matches) { + return { + inlineData: { + mimeType: matches[1], + data: matches[2], + }, + } + } + } + + // URL images - Google expects inline data, so we'd need to fetch + // For now, return a text placeholder + return { text: `[Image: ${imageUrl}]` } + }) + + return { role, parts } + }) +} + +function convertToGoogleTools(tools: ToolDefinition[]): Tool[] { + const functionDeclarations: FunctionDeclaration[] = tools.map((tool) => ({ + name: tool.name, + description: tool.description, + parameters: { + type: SchemaType.OBJECT, + properties: Object.fromEntries( + Object.entries(tool.parameters.properties).map(([key, value]) => [ + key, + { + type: TYPE_MAP[value.type] || SchemaType.STRING, + description: value.description, + ...(value.enum ? { enum: value.enum } : {}), + }, + ]) + ), + required: tool.parameters.required || [], + }, + })) + + return [{ functionDeclarations }] +} + +interface GoogleToolCall { + name: string + arguments: Record +} + +function extractGoogleToolCalls(response: { response: { candidates?: Array<{ content?: { parts?: Part[] } }> } }): GoogleToolCall[] { + const toolCalls: GoogleToolCall[] = [] + + const candidates = response.response.candidates || [] + for (const candidate of candidates) { + const parts = candidate.content?.parts || [] + for (const part of parts) { + if ('functionCall' in part && part.functionCall) { + toolCalls.push({ + name: part.functionCall.name, + arguments: part.functionCall.args as Record, + }) + } + } + } + + return toolCalls +} + +function getResponseText(response: { response: { text: () => string } }): string { + try { + return response.response.text() + } catch { + return '' + } +} + +// ============================================================================ +// Test Suite +// ============================================================================ + +describe('Google GenAI SDK Integration Tests', () => { + const skipTests = !isProviderAvailable('gemini') + + beforeAll(() => { + if (skipTests) { + console.log('āš ļø Skipping Google GenAI tests: GEMINI_API_KEY not set') + } + }) + + // ============================================================================ + // Simple Chat Tests + // ============================================================================ + + describe('Simple Chat', () => { + it('should complete a simple chat', async () => { + if (skipTests) return + + const model = getGenerativeModel() + const modelName = getProviderModel('gemini', 'chat') + + const result = await model.generateContent(SIMPLE_CHAT_MESSAGES[0].content as string) + + expect(result).toBeDefined() + const text = getResponseText(result) + expect(text.length).toBeGreaterThan(0) + console.log(`āœ… Simple chat passed for google/${modelName}`) + }) + }) + + // ============================================================================ + // Multi-turn Conversation Tests + // ============================================================================ + + describe('Multi-turn Conversation', () => { + it('should handle multi-turn conversation', async () => { + if (skipTests) return + + const model = getGenerativeModel() + const modelName = getProviderModel('gemini', 'chat') + + const chat = model.startChat({ + history: convertToGoogleContent(MULTI_TURN_MESSAGES.slice(0, -1)), + }) + + const result = await chat.sendMessage(MULTI_TURN_MESSAGES[MULTI_TURN_MESSAGES.length - 1].content as string) + + expect(result).toBeDefined() + const text = getResponseText(result) + expect(text.toLowerCase()).toMatch(/paris|population|million|people/i) + console.log(`āœ… Multi-turn conversation passed for google/${modelName}`) + }) + }) + + // ============================================================================ + // Streaming Tests + // ============================================================================ + + describe('Streaming Chat', () => { + it('should stream chat response', async () => { + if (skipTests) return + + const model = getGenerativeModel() + const modelName = getProviderModel('gemini', 'chat') + + const result = await model.generateContentStream(STREAMING_CHAT_MESSAGES[0].content as string) + + let content = '' + for await (const chunk of result.stream) { + const text = chunk.text() + if (text) { + content += text + } + } + + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… Streaming chat passed for google/${modelName}`) + }) + }) + + // ============================================================================ + // Tool Calling Tests + // ============================================================================ + + describe('Single Tool Call', () => { + it('should make a single tool call', async () => { + if (skipTests) return + + const toolModel = getProviderModel('gemini', 'tools') + const model = getGenerativeModel(toolModel) + + const result = await model.generateContent({ + contents: convertToGoogleContent(SINGLE_TOOL_CALL_MESSAGES), + tools: convertToGoogleTools([WEATHER_TOOL]), + }) + + const toolCalls = extractGoogleToolCalls(result) + expect(toolCalls.length).toBe(1) + expect(toolCalls[0].name).toBe('get_weather') + console.log(`āœ… Single tool call passed for google/${toolModel}`) + }) + }) + + describe('Multiple Tool Calls', () => { + it('should make multiple tool calls', async () => { + if (skipTests) return + + const toolModel = getProviderModel('gemini', 'tools') + const model = getGenerativeModel(toolModel) + + const result = await model.generateContent({ + contents: convertToGoogleContent(MULTIPLE_TOOL_CALL_MESSAGES), + tools: convertToGoogleTools([WEATHER_TOOL, CALCULATOR_TOOL]), + }) + + const toolCalls = extractGoogleToolCalls(result) + expect(toolCalls.length).toBeGreaterThanOrEqual(1) + + const toolNames = toolCalls.map((tc) => tc.name) + expect(toolNames.some((name) => name === 'get_weather' || name === 'calculate')).toBe(true) + console.log(`āœ… Multiple tool calls passed for google/${toolModel}`) + }) + }) + + describe('End-to-End Tool Calling', () => { + it('should complete end-to-end tool calling', async () => { + if (skipTests) return + + const toolModel = getProviderModel('gemini', 'tools') + const model = getGenerativeModel(toolModel) + + // Step 1: Initial request with tools + const chat = model.startChat({ + tools: convertToGoogleTools([WEATHER_TOOL]), + }) + + const result1 = await chat.sendMessage(SINGLE_TOOL_CALL_MESSAGES[0].content as string) + const toolCalls = extractGoogleToolCalls(result1) + + expect(toolCalls.length).toBeGreaterThan(0) + + // Step 2: Execute tool and get result + const toolResult = mockToolResponse(toolCalls[0].name, toolCalls[0].arguments) + + // Step 3: Send tool result back + const result2 = await chat.sendMessage([ + { + functionResponse: { + name: toolCalls[0].name, + response: JSON.parse(toolResult), + }, + }, + ]) + + expect(result2).toBeDefined() + const text = getResponseText(result2) + expect(text.length).toBeGreaterThan(0) + console.log(`āœ… End-to-end tool calling passed for google/${toolModel}`) + }) + }) + + // ============================================================================ + // Image/Vision Tests + // ============================================================================ + + describe('Image Base64', () => { + it('should analyze image from Base64', async () => { + if (skipTests) return + + const visionModel = getProviderModel('gemini', 'vision') + const model = getGenerativeModel(visionModel) + + const result = await model.generateContent([ + { text: 'What color is this image?' }, + { + inlineData: { + mimeType: 'image/png', + data: BASE64_IMAGE, + }, + }, + ]) + + expect(result).toBeDefined() + const text = getResponseText(result) + expect(text.length).toBeGreaterThan(10) + console.log(`āœ… Image Base64 analysis passed for google/${visionModel}`) + }) + }) + + // ============================================================================ + // Embeddings Tests + // ============================================================================ + + describe('Embeddings - Single Text', () => { + it('should generate single text embedding', async () => { + if (skipTests) return + + const client = getGoogleClient() + const embeddingsModel = getProviderModel('gemini', 'embeddings') + + // Skip if no embeddings model available + if (!embeddingsModel) { + console.log('āš ļø Skipping embeddings test: No embeddings model configured') + return + } + + const model = client.getGenerativeModel({ model: embeddingsModel }) + + const result = await model.embedContent(EMBEDDINGS_SINGLE_TEXT) + + expect(result).toBeDefined() + expect(result.embedding).toBeDefined() + expect(result.embedding.values).toBeDefined() + expect(result.embedding.values.length).toBeGreaterThan(0) + console.log(`āœ… Single text embedding passed for google/${embeddingsModel}`) + }) + }) + + describe('Embeddings - Batch', () => { + it('should generate batch embeddings', async () => { + if (skipTests) return + + const client = getGoogleClient() + const embeddingsModel = getProviderModel('gemini', 'embeddings') + + // Skip if no embeddings model available + if (!embeddingsModel) { + console.log('āš ļø Skipping embeddings test: No embeddings model configured') + return + } + + const model = client.getGenerativeModel({ model: embeddingsModel }) + + const result = await model.batchEmbedContents({ + requests: EMBEDDINGS_MULTIPLE_TEXTS.map((text) => ({ content: { parts: [{ text }], role: 'user' } })), + }) + + expect(result).toBeDefined() + expect(result.embeddings).toBeDefined() + expect(result.embeddings.length).toBe(EMBEDDINGS_MULTIPLE_TEXTS.length) + console.log(`āœ… Batch embeddings passed for google/${embeddingsModel}`) + }) + }) + + // ============================================================================ + // Count Tokens Tests + // ============================================================================ + + describe('Count Tokens', () => { + it('should count tokens', async () => { + if (skipTests) return + + const model = getGenerativeModel() + const modelName = getProviderModel('gemini', 'chat') + + const result = await model.countTokens('Hello, how are you today?') + + expect(result).toBeDefined() + expect(result.totalTokens).toBeGreaterThan(0) + console.log(`āœ… Count tokens passed for google/${modelName} (${result.totalTokens} tokens)`) + }) + }) + + // ============================================================================ + // Thinking/Extended Reasoning Tests + // ============================================================================ + + describe('Thinking/Extended Reasoning', () => { + it('should support extended thinking', async () => { + if (skipTests) return + + const thinkingModel = getProviderModel('gemini', 'thinking') + + // Skip if no thinking model available + if (!thinkingModel) { + console.log('āš ļø Skipping thinking test: No thinking model configured') + return + } + + const model = getGenerativeModel(thinkingModel) + + try { + const result = await model.generateContent({ + contents: [ + { + role: 'user', + parts: [{ text: 'What is 15% of 80? Show your reasoning step by step.' }], + }, + ], + generationConfig: { + // Google Gemini uses different config for reasoning + maxOutputTokens: 2048, + }, + }) + + expect(result).toBeDefined() + const text = getResponseText(result) + expect(text.length).toBeGreaterThan(0) + console.log(`āœ… Thinking/Extended reasoning passed for google/${thinkingModel}`) + } catch (error) { + console.log(`āš ļø Thinking test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + // ============================================================================ + // Audio Transcription Tests + // ============================================================================ + + describe('Audio Transcription', () => { + it('should transcribe audio content', async () => { + if (skipTests) return + + const transcriptionModel = getProviderModel('gemini', 'transcription') + + // Skip if no transcription model available + if (!transcriptionModel) { + console.log('āš ļø Skipping transcription test: No transcription model configured') + return + } + + const model = getGenerativeModel(transcriptionModel) + + // Generate a minimal audio WAV buffer for testing + const sampleRate = 16000 + const duration = 0.5 // 0.5 seconds + const numSamples = Math.floor(sampleRate * duration) + const frequency = 440 // A4 note + + // Create WAV header + const headerSize = 44 + const dataSize = numSamples * 2 + const buffer = new ArrayBuffer(headerSize + dataSize) + const view = new DataView(buffer) + + // RIFF header + const encoder = new TextEncoder() + new Uint8Array(buffer, 0, 4).set(encoder.encode('RIFF')) + view.setUint32(4, headerSize + dataSize - 8, true) + new Uint8Array(buffer, 8, 4).set(encoder.encode('WAVE')) + + // fmt chunk + new Uint8Array(buffer, 12, 4).set(encoder.encode('fmt ')) + view.setUint32(16, 16, true) + view.setUint16(20, 1, true) + view.setUint16(22, 1, true) + view.setUint32(24, sampleRate, true) + view.setUint32(28, sampleRate * 2, true) + view.setUint16(32, 2, true) + view.setUint16(34, 16, true) + + // data chunk + new Uint8Array(buffer, 36, 4).set(encoder.encode('data')) + view.setUint32(40, dataSize, true) + + // Generate sine wave + for (let i = 0; i < numSamples; i++) { + const t = i / sampleRate + const sample = Math.sin(2 * Math.PI * frequency * t) * 32767 * 0.5 + view.setInt16(headerSize + i * 2, Math.round(sample), true) + } + + const audioBase64 = btoa(String.fromCharCode(...new Uint8Array(buffer))) + + try { + const result = await model.generateContent([ + { text: 'Please transcribe this audio.' }, + { + inlineData: { + mimeType: 'audio/wav', + data: audioBase64, + }, + }, + ]) + + expect(result).toBeDefined() + // Note: A sine wave may not produce meaningful transcription + console.log(`āœ… Audio transcription passed for google/${transcriptionModel}`) + } catch (error) { + console.log(`āš ļø Transcription test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + // ============================================================================ + // Speech Synthesis Tests + // ============================================================================ + + describe('Speech Synthesis', () => { + it('should synthesize speech', async () => { + if (skipTests) return + + const speechModel = getProviderModel('gemini', 'speech') + + // Skip if no speech model available + if (!speechModel) { + console.log('āš ļø Skipping speech synthesis test: No speech model configured') + return + } + + // Google Gemini TTS requires specific API usage + // This test verifies the model is accessible + try { + const model = getGenerativeModel(speechModel) + + const result = await model.generateContent({ + contents: [ + { + role: 'user', + parts: [{ text: 'Hello, this is a test of speech synthesis.' }], + }, + ], + generationConfig: { + // TTS specific configuration + responseModalities: ['AUDIO'], + speechConfig: { + voiceConfig: { + prebuiltVoiceConfig: { + voiceName: 'Puck', + }, + }, + }, + } as never, + }) + + expect(result).toBeDefined() + console.log(`āœ… Speech synthesis passed for google/${speechModel}`) + } catch (error) { + console.log(`āš ļø Speech synthesis test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + // ============================================================================ + // Document/PDF Input Tests + // ============================================================================ + + describe('Document Input - PDF', () => { + it('should handle PDF document input', async () => { + if (skipTests) return + + const fileModel = getProviderModel('gemini', 'file') + + // Skip if no file model available + if (!fileModel) { + console.log('āš ļø Skipping document input test: No file model configured') + return + } + + const model = getGenerativeModel(fileModel) + + // Sample PDF base64 (minimal PDF with "Hello World") + const pdfBase64 = + 'JVBERi0xLjcKCjEgMCBvYmogICUgZW50cnkgcG9pbnQKPDwKICAvVHlwZSAvQ2F0YWxvZwogIC' + + '9QYWdlcyAyIDAgUgo+PgplbmRvYmoKCjIgMCBvYmoKPDwKICAvVHlwZSAvUGFnZXwKICAvTWV' + + 'kaWFCb3ggWyAwIDAgMjAwIDIwMCBdCiAgL0NvdW50IDEKICAvS2lkcyBbIDMgMCBSIF0KPj4K' + + 'ZW5kb2JqCgozIDAgb2JqCjw8CiAgL1R5cGUgL1BhZ2UKICAvUGFyZW50IDIgMCBSCiAgL1Jlc' + + '291cmNlcyA8PAogICAgL0ZvbnQgPDwKICAgICAgL0YxIDQgMCBSCj4+CiAgPj4KICAvQ29udG' + + 'VudHMgNSAwIFIKPj4KZW5kb2JqCgo0IDAgb2JqCjw8CiAgL1R5cGUgL0ZvbnQKICAvU3VidHl' + + 'wZSAvVHlwZTEKICAvQmFzZUZvbnQgL1RpbWVzLVJvbWFuCj4+CmVuZG9iagoKNSAwIG9iago8' + + 'PAogIC9MZW5ndGggNDQKPj4Kc3RyZWFtCkJUCjcwIDUwIFRECi9GMSAxMiBUZgooSGVsbG8gV' + + '29ybGQhKSBUagpFVAplbmRzdHJlYW0KZW5kb2JqCgp4cmVmCjAgNgowMDAwMDAwMDAwIDY1NT' + + 'M1IGYgCjAwMDAwMDAwMTAgMDAwMDAgbiAKMDAwMDAwMDA2MCAwMDAwMCBuIAowMDAwMDAwMTU' + + '3IDAwMDAwIG4gCjAwMDAwMDAyNTUgMDAwMDAgbiAKMDAwMDAwMDM1MyAwMDAwMCBuIAp0cmFp' + + 'bGVyCjw8CiAgL1NpemUgNgogIC9Sb290IDEgMCBSCj4+CnN0YXJ0eHJlZgo0NDkKJSVFT0YK' + + try { + const result = await model.generateContent([ + { text: 'What does this PDF document contain?' }, + { + inlineData: { + mimeType: 'application/pdf', + data: pdfBase64, + }, + }, + ]) + + expect(result).toBeDefined() + const text = getResponseText(result) + expect(text.length).toBeGreaterThan(0) + console.log(`āœ… Document input (PDF) passed for google/${fileModel}`) + } catch (error) { + console.log(`āš ļø Document input test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + // ============================================================================ + // System Instruction Tests + // ============================================================================ + + describe('System Instruction', () => { + it('should respect system instructions', async () => { + if (skipTests) return + + const model = getGenerativeModel() + const modelName = getProviderModel('gemini', 'chat') + + const client = getGoogleClient() + const systemModel = client.getGenerativeModel({ + model: modelName, + systemInstruction: 'You are a helpful assistant that always responds in exactly 5 words.', + }) + + try { + const result = await systemModel.generateContent('Hello, how are you?') + + expect(result).toBeDefined() + const text = getResponseText(result) + expect(text.length).toBeGreaterThan(0) + + // Check if response is approximately 5 words + const wordCount = text.trim().split(/\s+/).length + expect(wordCount).toBeGreaterThanOrEqual(3) + expect(wordCount).toBeLessThanOrEqual(10) + console.log(`āœ… System instruction passed for google/${modelName}`) + } catch (error) { + console.log(`āš ļø System instruction test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + // ============================================================================ + // Structured Output Tests + // ============================================================================ + + describe('Structured Output', () => { + it('should generate structured output with JSON schema', async () => { + if (skipTests) return + + const model = getGenerativeModel() + const modelName = getProviderModel('gemini', 'chat') + + try { + const result = await model.generateContent({ + contents: [ + { + role: 'user', + parts: [{ text: 'Give me a recipe for chocolate chip cookies as JSON with name, ingredients (array), and instructions (array).' }], + }, + ], + generationConfig: { + responseMimeType: 'application/json', + }, + }) + + expect(result).toBeDefined() + const text = getResponseText(result) + expect(text.length).toBeGreaterThan(0) + + // Try to parse as JSON + const parsed = JSON.parse(text) + expect(parsed).toBeDefined() + console.log(`āœ… Structured output passed for google/${modelName}`) + } catch (error) { + console.log(`āš ļø Structured output test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) +}) diff --git a/tests/integrations/typescript/tests/test-langchain.test.ts b/tests/integrations/typescript/tests/test-langchain.test.ts new file mode 100644 index 0000000000..3ff6fdd369 --- /dev/null +++ b/tests/integrations/typescript/tests/test-langchain.test.ts @@ -0,0 +1,864 @@ +/** + * LangChain.js Integration Tests + * + * This test suite uses LangChain.js to test multiple AI providers through Bifrost. + * Tests cover chat, streaming, tool calling, and structured output capabilities. + * + * Providers tested: + * - OpenAI (via @langchain/openai) + * - Anthropic (via @langchain/anthropic) + * - Google GenAI (via @langchain/google-genai) + * + * Test Scenarios: + * 1. Simple chat + * 2. Multi-turn conversation + * 3. Streaming chat + * 4. Tool calling + * 5. Structured output + */ + +import { describe, it, expect, beforeAll } from 'vitest' +import { ChatOpenAI } from '@langchain/openai' +import { ChatAnthropic } from '@langchain/anthropic' +import { ChatGoogleGenerativeAI } from '@langchain/google-genai' +import { HumanMessage, AIMessage, SystemMessage, BaseMessage } from '@langchain/core/messages' +import { DynamicStructuredTool } from '@langchain/core/tools' +import { z } from 'zod' + +import { + getIntegrationUrl, + getProviderModel, + isProviderAvailable, +} from '../src/utils/config-loader' + +import { + SIMPLE_CHAT_MESSAGES, + MULTI_TURN_MESSAGES, + STREAMING_CHAT_MESSAGES, + SINGLE_TOOL_CALL_MESSAGES, + getApiKey, + hasApiKey, + mockToolResponse, + type ChatMessage, +} from '../src/utils/common' + +// ============================================================================ +// Helper Functions +// ============================================================================ + +type LangChainModel = ChatOpenAI | ChatAnthropic | ChatGoogleGenerativeAI + +function getLangChainOpenAI(): ChatOpenAI { + const baseUrl = getIntegrationUrl('openai') + const apiKey = hasApiKey('openai') ? getApiKey('openai') : 'dummy-key' + const model = getProviderModel('openai', 'chat') + + return new ChatOpenAI({ + modelName: model, + openAIApiKey: apiKey, + configuration: { + baseURL: baseUrl, + }, + maxTokens: 100, + timeout: 300000, + maxRetries: 3, + }) +} + +function getLangChainAnthropic(): ChatAnthropic { + const baseUrl = getIntegrationUrl('anthropic') + const apiKey = hasApiKey('anthropic') ? getApiKey('anthropic') : 'dummy-key' + const model = getProviderModel('anthropic', 'chat') + + return new ChatAnthropic({ + modelName: model, + anthropicApiKey: apiKey, + anthropicApiUrl: baseUrl, + maxTokens: 100, + maxRetries: 3, + }) +} + +function getLangChainGoogle(): ChatGoogleGenerativeAI { + // Use 'gemini' consistently for both API key and model lookup + const apiKey = hasApiKey('gemini') ? getApiKey('gemini') : 'dummy-key' + const model = getProviderModel('gemini', 'chat') + + return new ChatGoogleGenerativeAI({ + modelName: model, + apiKey, + maxOutputTokens: 100, + maxRetries: 3, + }) +} + +function convertToLangChainMessages(messages: ChatMessage[]): BaseMessage[] { + return messages.map((msg) => { + const content = typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content) + + switch (msg.role) { + case 'system': + return new SystemMessage(content) + case 'assistant': + return new AIMessage(content) + case 'user': + default: + return new HumanMessage(content) + } + }) +} + +// Weather tool using Zod schema +const weatherTool = new DynamicStructuredTool({ + name: 'get_weather', + description: 'Get the current weather for a location', + schema: z.object({ + location: z.string().describe('The city and state, e.g. San Francisco, CA'), + unit: z.enum(['celsius', 'fahrenheit']).optional().describe('The temperature unit'), + }), + func: async ({ location, unit }) => { + return mockToolResponse('get_weather', { location, unit }) + }, +}) + +// Calculator tool using Zod schema +const calculatorTool = new DynamicStructuredTool({ + name: 'calculate', + description: 'Perform basic mathematical calculations', + schema: z.object({ + expression: z.string().describe("Mathematical expression to evaluate, e.g. '2 + 2'"), + }), + func: async ({ expression }) => { + return mockToolResponse('calculate', { expression }) + }, +}) + +// ============================================================================ +// Test Suite +// ============================================================================ + +describe('LangChain.js Integration Tests', () => { + // ============================================================================ + // OpenAI via LangChain + // ============================================================================ + + describe('LangChain OpenAI', () => { + const skipTests = !isProviderAvailable('openai') + + beforeAll(() => { + if (skipTests) { + console.log('āš ļø Skipping LangChain OpenAI tests: OPENAI_API_KEY not set') + } + }) + + describe('Simple Chat', () => { + it('should complete a simple chat', async () => { + if (skipTests) return + + const model = getLangChainOpenAI() + const messages = convertToLangChainMessages(SIMPLE_CHAT_MESSAGES) + + const response = await model.invoke(messages) + + expect(response).toBeDefined() + expect(response.content).toBeDefined() + const content = typeof response.content === 'string' ? response.content : JSON.stringify(response.content) + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… LangChain OpenAI simple chat passed`) + }) + }) + + describe('Multi-turn Conversation', () => { + it('should handle multi-turn conversation', async () => { + if (skipTests) return + + const model = getLangChainOpenAI() + const messages = convertToLangChainMessages(MULTI_TURN_MESSAGES) + + const response = await model.invoke(messages) + + expect(response).toBeDefined() + const content = typeof response.content === 'string' ? response.content : JSON.stringify(response.content) + expect(content.toLowerCase()).toMatch(/paris|population|million|people/i) + console.log(`āœ… LangChain OpenAI multi-turn conversation passed`) + }) + }) + + describe('Streaming Chat', () => { + it('should stream chat response', async () => { + if (skipTests) return + + const model = getLangChainOpenAI() + const messages = convertToLangChainMessages(STREAMING_CHAT_MESSAGES) + + const stream = await model.stream(messages) + + let content = '' + for await (const chunk of stream) { + if (chunk.content) { + content += typeof chunk.content === 'string' ? chunk.content : JSON.stringify(chunk.content) + } + } + + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… LangChain OpenAI streaming chat passed`) + }) + }) + + describe('Streaming Chat - Client Disconnect', () => { + it('should handle client disconnect mid-stream', async () => { + if (skipTests) return + + const baseUrl = getIntegrationUrl('openai') + const apiKey = hasApiKey('openai') ? getApiKey('openai') : 'dummy-key' + const modelName = getProviderModel('openai', 'chat') + + // Create model with longer max tokens for a longer response + const model = new ChatOpenAI({ + modelName, + openAIApiKey: apiKey, + configuration: { + baseURL: baseUrl, + }, + maxTokens: 1000, + timeout: 300000, + }) + + const abortController = new AbortController() + const messages = convertToLangChainMessages([ + { role: 'user', content: 'Write a detailed essay about the history of computing, including at least 10 paragraphs.' }, + ]) + + const stream = await model.stream(messages, { + signal: abortController.signal, + }) + + let chunkCount = 0 + let content = '' + let wasAborted = false + + try { + for await (const chunk of stream) { + chunkCount++ + if (chunk.content) { + content += typeof chunk.content === 'string' ? chunk.content : JSON.stringify(chunk.content) + } + + // Abort after receiving a few chunks + if (chunkCount >= 3) { + abortController.abort() + } + } + } catch (error) { + wasAborted = true + expect(error).toBeDefined() + const errorMessage = error instanceof Error ? error.message.toLowerCase() : String(error).toLowerCase() + const isAbortError = errorMessage.includes('abort') || + errorMessage.includes('cancel') || + error instanceof DOMException || + (error as { name?: string })?.name === 'AbortError' + expect(isAbortError).toBe(true) + } + + expect(chunkCount).toBeGreaterThanOrEqual(3) + expect(content.length).toBeGreaterThan(0) + expect(wasAborted).toBe(true) + console.log(`āœ… LangChain OpenAI streaming client disconnect passed (${chunkCount} chunks before abort)`) + }) + }) + + describe('Tool Calling', () => { + it('should make tool calls', async () => { + if (skipTests) return + + const model = getLangChainOpenAI() + const modelWithTools = model.bindTools([weatherTool]) + const messages = convertToLangChainMessages(SINGLE_TOOL_CALL_MESSAGES) + + const response = await modelWithTools.invoke(messages) + + expect(response).toBeDefined() + expect(response.tool_calls).toBeDefined() + expect(response.tool_calls!.length).toBeGreaterThan(0) + expect(response.tool_calls![0].name).toBe('get_weather') + console.log(`āœ… LangChain OpenAI tool calling passed`) + }) + }) + + describe('Structured Output', () => { + it('should generate structured output', async () => { + if (skipTests) return + + const model = getLangChainOpenAI() + + const ResponseSchema = z.object({ + answer: z.string().describe('The answer to the question'), + confidence: z.number().min(0).max(1).describe('Confidence score'), + }) + + const structuredModel = model.withStructuredOutput(ResponseSchema) + + const response = await structuredModel.invoke('What is 2 + 2?') + + expect(response).toBeDefined() + expect(response.answer).toBeDefined() + expect(typeof response.confidence).toBe('number') + console.log(`āœ… LangChain OpenAI structured output passed`) + }) + }) + }) + + // ============================================================================ + // Anthropic via LangChain + // ============================================================================ + + describe('LangChain Anthropic', () => { + const skipTests = !isProviderAvailable('anthropic') + + beforeAll(() => { + if (skipTests) { + console.log('āš ļø Skipping LangChain Anthropic tests: ANTHROPIC_API_KEY not set') + } + }) + + describe('Simple Chat', () => { + it('should complete a simple chat', async () => { + if (skipTests) return + + const model = getLangChainAnthropic() + const messages = convertToLangChainMessages(SIMPLE_CHAT_MESSAGES) + + const response = await model.invoke(messages) + + expect(response).toBeDefined() + expect(response.content).toBeDefined() + const content = typeof response.content === 'string' ? response.content : JSON.stringify(response.content) + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… LangChain Anthropic simple chat passed`) + }) + }) + + describe('Multi-turn Conversation', () => { + it('should handle multi-turn conversation', async () => { + if (skipTests) return + + const model = getLangChainAnthropic() + const messages = convertToLangChainMessages(MULTI_TURN_MESSAGES) + + const response = await model.invoke(messages) + + expect(response).toBeDefined() + const content = typeof response.content === 'string' ? response.content : JSON.stringify(response.content) + expect(content.toLowerCase()).toMatch(/paris|population|million|people/i) + console.log(`āœ… LangChain Anthropic multi-turn conversation passed`) + }) + }) + + describe('Streaming Chat', () => { + it('should stream chat response', async () => { + if (skipTests) return + + const model = getLangChainAnthropic() + const messages = convertToLangChainMessages(STREAMING_CHAT_MESSAGES) + + const stream = await model.stream(messages) + + let content = '' + for await (const chunk of stream) { + if (chunk.content) { + content += typeof chunk.content === 'string' ? chunk.content : JSON.stringify(chunk.content) + } + } + + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… LangChain Anthropic streaming chat passed`) + }) + }) + + describe('Streaming Chat - Client Disconnect', () => { + it('should handle client disconnect mid-stream', async () => { + if (skipTests) return + + const baseUrl = getIntegrationUrl('anthropic') + const apiKey = hasApiKey('anthropic') ? getApiKey('anthropic') : 'dummy-key' + const modelName = getProviderModel('anthropic', 'chat') + + // Create model with longer max tokens for a longer response + const model = new ChatAnthropic({ + modelName, + anthropicApiKey: apiKey, + anthropicApiUrl: baseUrl, + maxTokens: 1000, + maxRetries: 3, + }) + + const abortController = new AbortController() + const messages = convertToLangChainMessages([ + { role: 'user', content: 'Write a detailed essay about the history of computing, including at least 10 paragraphs.' }, + ]) + + const stream = await model.stream(messages, { + signal: abortController.signal, + }) + + let chunkCount = 0 + let content = '' + let wasAborted = false + + try { + for await (const chunk of stream) { + chunkCount++ + if (chunk.content) { + content += typeof chunk.content === 'string' ? chunk.content : JSON.stringify(chunk.content) + } + + // Abort after receiving a few chunks + if (chunkCount >= 5) { + abortController.abort() + } + } + } catch (error) { + wasAborted = true + expect(error).toBeDefined() + const errorMessage = error instanceof Error ? error.message.toLowerCase() : String(error).toLowerCase() + const isAbortError = errorMessage.includes('abort') || + errorMessage.includes('cancel') || + error instanceof DOMException || + (error as { name?: string })?.name === 'AbortError' + expect(isAbortError).toBe(true) + } + + expect(chunkCount).toBeGreaterThanOrEqual(5) + expect(content.length).toBeGreaterThan(0) + expect(wasAborted).toBe(true) + console.log(`āœ… LangChain Anthropic streaming client disconnect passed (${chunkCount} chunks before abort)`) + }) + }) + + describe('Tool Calling', () => { + it('should make tool calls', async () => { + if (skipTests) return + + const model = getLangChainAnthropic() + const modelWithTools = model.bindTools([weatherTool]) + const messages = convertToLangChainMessages(SINGLE_TOOL_CALL_MESSAGES) + + const response = await modelWithTools.invoke(messages) + + expect(response).toBeDefined() + expect(response.tool_calls).toBeDefined() + expect(response.tool_calls!.length).toBeGreaterThan(0) + expect(response.tool_calls![0].name).toBe('get_weather') + console.log(`āœ… LangChain Anthropic tool calling passed`) + }) + }) + }) + + // ============================================================================ + // Google via LangChain + // ============================================================================ + + describe('LangChain Google GenAI', () => { + const skipTests = !isProviderAvailable('gemini') + + beforeAll(() => { + if (skipTests) { + console.log('āš ļø Skipping LangChain Google GenAI tests: GEMINI_API_KEY not set') + } + }) + + describe('Simple Chat', () => { + it('should complete a simple chat', async () => { + if (skipTests) return + + const model = getLangChainGoogle() + const messages = convertToLangChainMessages(SIMPLE_CHAT_MESSAGES) + + const response = await model.invoke(messages) + + expect(response).toBeDefined() + expect(response.content).toBeDefined() + const content = typeof response.content === 'string' ? response.content : JSON.stringify(response.content) + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… LangChain Google GenAI simple chat passed`) + }) + }) + + describe('Multi-turn Conversation', () => { + it('should handle multi-turn conversation', async () => { + if (skipTests) return + + const model = getLangChainGoogle() + const messages = convertToLangChainMessages(MULTI_TURN_MESSAGES) + + const response = await model.invoke(messages) + + expect(response).toBeDefined() + const content = typeof response.content === 'string' ? response.content : JSON.stringify(response.content) + expect(content.toLowerCase()).toMatch(/paris|population|million|people/i) + console.log(`āœ… LangChain Google GenAI multi-turn conversation passed`) + }) + }) + + describe('Streaming Chat', () => { + it('should stream chat response', async () => { + if (skipTests) return + + const model = getLangChainGoogle() + const messages = convertToLangChainMessages(STREAMING_CHAT_MESSAGES) + + const stream = await model.stream(messages) + + let content = '' + for await (const chunk of stream) { + if (chunk.content) { + content += typeof chunk.content === 'string' ? chunk.content : JSON.stringify(chunk.content) + } + } + + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… LangChain Google GenAI streaming chat passed`) + }) + }) + + describe('Tool Calling', () => { + it('should make tool calls', async () => { + if (skipTests) return + + const model = getLangChainGoogle() + const modelWithTools = model.bindTools([weatherTool]) + const messages = convertToLangChainMessages(SINGLE_TOOL_CALL_MESSAGES) + + const response = await modelWithTools.invoke(messages) + + expect(response).toBeDefined() + expect(response.tool_calls).toBeDefined() + expect(response.tool_calls!.length).toBeGreaterThan(0) + expect(response.tool_calls![0].name).toBe('get_weather') + console.log(`āœ… LangChain Google GenAI tool calling passed`) + }) + }) + + describe('Structured Output', () => { + it('should generate structured output', async () => { + if (skipTests) return + + const model = getLangChainGoogle() + + const ResponseSchema = z.object({ + answer: z.string().describe('The answer to the question'), + confidence: z.number().min(0).max(1).describe('Confidence score'), + }) + + try { + const structuredModel = model.withStructuredOutput(ResponseSchema) + const response = await structuredModel.invoke('What is 2 + 2?') + + expect(response).toBeDefined() + expect(response.answer).toBeDefined() + expect(typeof response.confidence).toBe('number') + console.log(`āœ… LangChain Google GenAI structured output passed`) + } catch (error) { + console.log(`āš ļø LangChain Google GenAI structured output test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + }) + + // ============================================================================ + // Cross-Provider Token Counting Tests + // ============================================================================ + + describe('Token Counting', () => { + describe('OpenAI Token Counting', () => { + const skipTests = !isProviderAvailable('openai') + + it('should return token usage in response', async () => { + if (skipTests) return + + const model = getLangChainOpenAI() + const messages = convertToLangChainMessages(SIMPLE_CHAT_MESSAGES) + + const response = await model.invoke(messages) + + expect(response).toBeDefined() + // LangChain includes usage info in response_metadata + if (response.response_metadata) { + const usage = response.response_metadata.usage || response.response_metadata.tokenUsage + if (usage) { + expect(usage.prompt_tokens || usage.promptTokens).toBeGreaterThan(0) + expect(usage.completion_tokens || usage.completionTokens).toBeGreaterThan(0) + } + } + console.log(`āœ… LangChain OpenAI token counting passed`) + }) + }) + + describe('Anthropic Token Counting', () => { + const skipTests = !isProviderAvailable('anthropic') + + it('should return token usage in response', async () => { + if (skipTests) return + + const model = getLangChainAnthropic() + const messages = convertToLangChainMessages(SIMPLE_CHAT_MESSAGES) + + const response = await model.invoke(messages) + + expect(response).toBeDefined() + // Anthropic includes usage info in usage_metadata + if (response.usage_metadata) { + expect(response.usage_metadata.input_tokens).toBeGreaterThan(0) + expect(response.usage_metadata.output_tokens).toBeGreaterThan(0) + } + console.log(`āœ… LangChain Anthropic token counting passed`) + }) + }) + + describe('Google GenAI Token Counting', () => { + const skipTests = !isProviderAvailable('gemini') + + it('should return token usage in response', async () => { + if (skipTests) return + + const model = getLangChainGoogle() + const messages = convertToLangChainMessages(SIMPLE_CHAT_MESSAGES) + + const response = await model.invoke(messages) + + expect(response).toBeDefined() + // Google includes usage info in response_metadata + if (response.response_metadata) { + const usage = response.response_metadata.usage + if (usage) { + expect(usage.promptTokenCount || usage.prompt_tokens).toBeGreaterThan(0) + } + } + console.log(`āœ… LangChain Google GenAI token counting passed`) + }) + }) + }) + + // ============================================================================ + // Cross-Provider Structured Output Tests + // ============================================================================ + + describe('Comprehensive Structured Output', () => { + // Complex schema for testing + const RecipeSchema = z.object({ + name: z.string().describe('Name of the recipe'), + ingredients: z.array(z.object({ + item: z.string().describe('Ingredient name'), + amount: z.string().describe('Amount needed'), + })).describe('List of ingredients'), + steps: z.array(z.string()).describe('Cooking steps'), + prepTime: z.number().describe('Preparation time in minutes'), + cookTime: z.number().describe('Cooking time in minutes'), + }) + + describe('OpenAI Complex Structured Output', () => { + const skipTests = !isProviderAvailable('openai') + + it('should generate complex structured output', async () => { + if (skipTests) return + + const model = getLangChainOpenAI() + const structuredModel = model.withStructuredOutput(RecipeSchema) + + const response = await structuredModel.invoke('Give me a simple recipe for scrambled eggs') + + expect(response).toBeDefined() + expect(response.name).toBeDefined() + expect(Array.isArray(response.ingredients)).toBe(true) + expect(Array.isArray(response.steps)).toBe(true) + expect(typeof response.prepTime).toBe('number') + expect(typeof response.cookTime).toBe('number') + console.log(`āœ… LangChain OpenAI complex structured output passed`) + }) + }) + + describe('Anthropic Complex Structured Output', () => { + const skipTests = !isProviderAvailable('anthropic') + + it('should generate complex structured output', async () => { + if (skipTests) return + + const model = getLangChainAnthropic() + + try { + const structuredModel = model.withStructuredOutput(RecipeSchema) + const response = await structuredModel.invoke('Give me a simple recipe for scrambled eggs') + + expect(response).toBeDefined() + expect(response.name).toBeDefined() + expect(Array.isArray(response.ingredients)).toBe(true) + expect(Array.isArray(response.steps)).toBe(true) + expect(typeof response.prepTime).toBe('number') + expect(typeof response.cookTime).toBe('number') + console.log(`āœ… LangChain Anthropic complex structured output passed`) + } catch (error) { + console.log(`āš ļø LangChain Anthropic complex structured output test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + }) + + // ============================================================================ + // Extended Thinking Tests + // ============================================================================ + + describe('Thinking/Extended Reasoning', () => { + describe('OpenAI Thinking', () => { + const skipTests = !isProviderAvailable('openai') + + it('should support extended reasoning', async () => { + if (skipTests) return + + const thinkingModel = getProviderModel('openai', 'thinking') + + // Skip if no thinking model available + if (!thinkingModel) { + console.log('āš ļø Skipping OpenAI thinking test: No thinking model configured') + return + } + + const baseUrl = getIntegrationUrl('openai') + const apiKey = hasApiKey('openai') ? getApiKey('openai') : 'dummy-key' + + const model = new ChatOpenAI({ + modelName: thinkingModel, + openAIApiKey: apiKey, + configuration: { + baseURL: baseUrl, + }, + maxTokens: 2000, + timeout: 300000, + }) + + try { + const response = await model.invoke([ + new HumanMessage('What is 15% of 80? Think through this step by step.'), + ]) + + expect(response).toBeDefined() + const content = typeof response.content === 'string' ? response.content : JSON.stringify(response.content) + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… LangChain OpenAI thinking passed`) + } catch (error) { + console.log(`āš ļø LangChain OpenAI thinking test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Anthropic Thinking', () => { + const skipTests = !isProviderAvailable('anthropic') + + it('should support extended reasoning', async () => { + if (skipTests) return + + const thinkingModel = getProviderModel('anthropic', 'thinking') + + // Skip if no thinking model available + if (!thinkingModel) { + console.log('āš ļø Skipping Anthropic thinking test: No thinking model configured') + return + } + + const baseUrl = getIntegrationUrl('anthropic') + const apiKey = hasApiKey('anthropic') ? getApiKey('anthropic') : 'dummy-key' + + const model = new ChatAnthropic({ + modelName: thinkingModel, + anthropicApiKey: apiKey, + anthropicApiUrl: baseUrl, + maxTokens: 8000, + maxRetries: 3, + }) + + try { + // Anthropic thinking requires specific configuration + const response = await model.invoke([ + new HumanMessage('What is 15% of 80? Think through this step by step.'), + ]) + + expect(response).toBeDefined() + const content = typeof response.content === 'string' ? response.content : JSON.stringify(response.content) + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… LangChain Anthropic thinking passed`) + } catch (error) { + console.log(`āš ļø LangChain Anthropic thinking test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + + describe('Google GenAI Thinking', () => { + const skipTests = !isProviderAvailable('gemini') + + it('should support extended reasoning', async () => { + if (skipTests) return + + const thinkingModel = getProviderModel('gemini', 'thinking') + + // Skip if no thinking model available + if (!thinkingModel) { + console.log('āš ļø Skipping Google GenAI thinking test: No thinking model configured') + return + } + + const apiKey = hasApiKey('gemini') ? getApiKey('gemini') : 'dummy-key' + + const model = new ChatGoogleGenerativeAI({ + modelName: thinkingModel, + apiKey, + maxOutputTokens: 2048, + }) + + try { + const response = await model.invoke([ + new HumanMessage('What is 15% of 80? Think through this step by step.'), + ]) + + expect(response).toBeDefined() + const content = typeof response.content === 'string' ? response.content : JSON.stringify(response.content) + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… LangChain Google GenAI thinking passed`) + } catch (error) { + console.log(`āš ļø LangChain Google GenAI thinking test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + }) + }) + }) + + // ============================================================================ + // Streaming Tool Calls Tests + // ============================================================================ + + describe('Streaming Tool Calls', () => { + describe('OpenAI Streaming Tool Calls', () => { + const skipTests = !isProviderAvailable('openai') + + it('should stream tool calls', async () => { + if (skipTests) return + + const model = getLangChainOpenAI() + const modelWithTools = model.bindTools([weatherTool, calculatorTool]) + const messages = convertToLangChainMessages(SINGLE_TOOL_CALL_MESSAGES) + + const stream = await modelWithTools.stream(messages) + + let hasToolCall = false + for await (const chunk of stream) { + if (chunk.tool_calls && chunk.tool_calls.length > 0) { + hasToolCall = true + } + if (chunk.tool_call_chunks && chunk.tool_call_chunks.length > 0) { + hasToolCall = true + } + } + + // Tool calls might not always stream, but the stream should complete + console.log(`āœ… LangChain OpenAI streaming tool calls passed (tool call detected: ${hasToolCall})`) + }) + }) + }) +}) diff --git a/tests/integrations/typescript/tests/test-openai.test.ts b/tests/integrations/typescript/tests/test-openai.test.ts new file mode 100644 index 0000000000..72bd7cabaa --- /dev/null +++ b/tests/integrations/typescript/tests/test-openai.test.ts @@ -0,0 +1,1776 @@ +/** + * OpenAI Integration Tests - Cross-Provider Support + * + * This test suite uses the OpenAI SDK to test against multiple AI providers through Bifrost. + * Tests automatically run against all available providers with proper capability filtering. + * + * Test Scenarios: + * + * Chat Completions: + * 1. Simple chat + * 2. Multi-turn conversation + * 3. Streaming chat + * + * Tool Calling: + * 4. Single tool call + * 5. Multiple tool calls + * 6. End-to-end tool calling + * + * Vision/Image: + * 7. Image URL analysis + * 8. Image Base64 analysis + * 9. Multiple images analysis + * + * Audio: + * 10. Speech synthesis (TTS) + * 11. Audio transcription + * + * Embeddings: + * 12. Single text embedding + * 13. Batch embeddings + * 14. Embedding similarity analysis + * + * Models & Tokens: + * 15. List models + * 16. Count tokens + * + * Files API: + * 17. File upload + * 18. File list + * 19. File retrieve + * 20. File delete + * 21. File content download + * + * Batch API: + * 22. Batch create + * 23. Batch list + * 24. Batch retrieve + * 25. Batch cancel + * + * Responses API: + * 26. Responses - simple text + * 27. Responses - with system message + * 28. Responses - with image + * 29. Responses - with tools + * 30. Responses - streaming + * 31. Responses - streaming with tools + * 32. Responses - reasoning + * + * Input Tokens API: + * 33. Input tokens - simple text + * 34. Input tokens - with system message + * 35. Input tokens - long text + */ + +import OpenAI from 'openai' +import { describe, expect, it } from 'vitest' + +import { + getIntegrationUrl, + getProviderModel, + getVirtualKey +} from '../src/utils/config-loader' + +import { + CALCULATOR_TOOL, + EMBEDDINGS_MULTIPLE_TEXTS, + EMBEDDINGS_SIMILAR_TEXTS, + EMBEDDINGS_SINGLE_TEXT, + IMAGE_BASE64_MESSAGES, + IMAGE_URL_MESSAGES, + MULTIPLE_IMAGES_MESSAGES, + MULTIPLE_TOOL_CALL_MESSAGES, + MULTI_TURN_MESSAGES, + RESPONSES_IMAGE_INPUT, + RESPONSES_REASONING_INPUT, + RESPONSES_SIMPLE_TEXT_INPUT, + RESPONSES_STREAMING_INPUT, + RESPONSES_TEXT_WITH_SYSTEM, + RESPONSES_TOOL_CALL_INPUT, + SIMPLE_CHAT_MESSAGES, + SINGLE_TOOL_CALL_MESSAGES, + SPEECH_TEST_INPUT, + STREAMING_CHAT_MESSAGES, + WEATHER_TOOL, + assertHasToolCalls, + assertValidChatResponse, + assertValidEmbeddingResponse, + assertValidEmbeddingsBatchResponse, + assertValidImageResponse, + assertValidSpeechResponse, + calculateCosineSimilarity, + collectStreamingContent, + convertToOpenAITools, + convertToResponsesTools, + extractToolCalls, + generateTestAudio, + getApiKey, + getProviderVoice, + hasApiKey, + mockToolResponse, + type ChatMessage, + type ExtractedToolCall +} from '../src/utils/common' + +import { + formatProviderModel, + getCrossProviderParamsWithVkForScenario, + shouldSkipNoProviders, + type ProviderModelVkParam, +} from '../src/utils/parametrize' + +// ============================================================================ +// Helper Functions +// ============================================================================ + +function getProviderOpenAIClient(provider: string, vkEnabled: boolean = false): OpenAI { + const baseUrl = getIntegrationUrl('openai') + const apiKey = hasApiKey('openai') ? getApiKey('openai') : 'dummy-key' + + const defaultHeaders: Record = {} + + if (vkEnabled) { + const vk = getVirtualKey() + if (vk) { + defaultHeaders['x-bf-vk'] = vk + } + } + + return new OpenAI({ + baseURL: baseUrl, + apiKey, + defaultHeaders: Object.keys(defaultHeaders).length > 0 ? defaultHeaders : undefined, + timeout: 300000, // 5 minutes + maxRetries: 3, + }) +} + +function convertMessages(messages: ChatMessage[]): OpenAI.Chat.ChatCompletionMessageParam[] { + return messages.map((msg) => { + if (typeof msg.content === 'string') { + return { + role: msg.role, + content: msg.content, + } as OpenAI.Chat.ChatCompletionMessageParam + } + + // Handle multimodal content + const parts: OpenAI.Chat.ChatCompletionContentPart[] = msg.content.map((part) => { + if (part.type === 'text') { + return { type: 'text', text: part.text! } + } + return { + type: 'image_url', + image_url: { url: part.image_url!.url }, + } + }) + + return { + role: msg.role, + content: parts, + } as OpenAI.Chat.ChatCompletionMessageParam + }) +} + +// ============================================================================ +// Test Suite +// ============================================================================ + +describe('OpenAI SDK Integration Tests', () => { + // ============================================================================ + // Simple Chat Tests + // ============================================================================ + + describe('Simple Chat', () => { + const testCases = getCrossProviderParamsWithVkForScenario('simple_chat') + + it.each(testCases)( + 'should complete a simple chat - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for simple_chat') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const response = await client.chat.completions.create({ + model: formatProviderModel(provider, model), + messages: convertMessages(SIMPLE_CHAT_MESSAGES), + max_tokens: 100, + }) + + assertValidChatResponse(response) + console.log(`āœ… Simple chat passed for ${formatProviderModel(provider, model)}`) + } + ) + }) + + // ============================================================================ + // Multi-turn Conversation Tests + // ============================================================================ + + describe('Multi-turn Conversation', () => { + const testCases = getCrossProviderParamsWithVkForScenario('multi_turn_conversation') + + it.each(testCases)( + 'should handle multi-turn conversation - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for multi_turn_conversation') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const response = await client.chat.completions.create({ + model: formatProviderModel(provider, model), + messages: convertMessages(MULTI_TURN_MESSAGES), + max_tokens: 100, + }) + + assertValidChatResponse(response) + + // Verify context is maintained + const content = response.choices[0]?.message?.content || '' + expect(content.toLowerCase()).toMatch(/paris|population|million|people/i) + console.log(`āœ… Multi-turn conversation passed for ${formatProviderModel(provider, model)}`) + } + ) + }) + + // ============================================================================ + // Streaming Tests + // ============================================================================ + + describe('Streaming Chat', () => { + const testCases = getCrossProviderParamsWithVkForScenario('streaming') + + it.each(testCases)( + 'should stream chat response - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for streaming') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const stream = await client.chat.completions.create({ + model: formatProviderModel(provider, model), + messages: convertMessages(STREAMING_CHAT_MESSAGES), + max_tokens: 100, + stream: true, + }) + + const content = await collectStreamingContent(stream) + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… Streaming chat passed for ${formatProviderModel(provider, model)}`) + } + ) + }) + + // ============================================================================ + // Streaming Client Disconnect Tests + // ============================================================================ + + describe('Streaming Chat - Client Disconnect', () => { + const testCases = getCrossProviderParamsWithVkForScenario('streaming') + + it.each(testCases)( + 'should handle client disconnect mid-stream - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for streaming') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const abortController = new AbortController() + + // Request a longer response to ensure we have time to abort mid-stream + const stream = await client.chat.completions.create({ + model: formatProviderModel(provider, model), + messages: [{ role: 'user', content: 'Write a detailed essay about the history of computing, including at least 10 paragraphs.' }], + max_tokens: 1000, + stream: true, + }, { + signal: abortController.signal, + }) + + let chunkCount = 0 + let content = '' + let wasAborted = false + + try { + for await (const chunk of stream) { + chunkCount++ + const delta = chunk.choices[0]?.delta?.content || '' + content += delta + + // Abort after receiving a few chunks + if (chunkCount >= 3) { + abortController.abort() + } + } + } catch (error) { + // Expect an abort error + wasAborted = true + expect(error).toBeDefined() + // The error should be an AbortError or contain abort-related message + const errorMessage = error instanceof Error ? error.message.toLowerCase() : String(error).toLowerCase() + const isAbortError = errorMessage.includes('abort') || + errorMessage.includes('cancel') || + error instanceof DOMException || + (error as { name?: string })?.name === 'AbortError' + expect(isAbortError).toBe(true) + } + + // Verify we received some content before aborting + expect(chunkCount).toBeGreaterThanOrEqual(3) + expect(content.length).toBeGreaterThan(0) + expect(wasAborted).toBe(true) + console.log(`āœ… Streaming client disconnect passed for ${formatProviderModel(provider, model)} (${chunkCount} chunks before abort)`) + } + ) + }) + + // ============================================================================ + // Tool Calling Tests + // ============================================================================ + + describe('Single Tool Call', () => { + const testCases = getCrossProviderParamsWithVkForScenario('tool_calls') + + it.each(testCases)( + 'should make a single tool call - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for tool_calls') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const toolModel = getProviderModel(provider, 'tools') + + const response = await client.chat.completions.create({ + model: formatProviderModel(provider, toolModel || model), + messages: convertMessages(SINGLE_TOOL_CALL_MESSAGES), + tools: convertToOpenAITools([WEATHER_TOOL]), + max_tokens: 100, + }) + + assertHasToolCalls(response, 1) + const toolCalls = extractToolCalls(response) + expect(toolCalls[0].name).toBe('get_weather') + console.log(`āœ… Single tool call passed for ${formatProviderModel(provider, model)}`) + } + ) + }) + + describe('Multiple Tool Calls', () => { + const testCases = getCrossProviderParamsWithVkForScenario('multiple_tool_calls') + + it.each(testCases)( + 'should make multiple tool calls - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for multiple_tool_calls') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const toolModel = getProviderModel(provider, 'tools') + + const response = await client.chat.completions.create({ + model: formatProviderModel(provider, toolModel || model), + messages: convertMessages(MULTIPLE_TOOL_CALL_MESSAGES), + tools: convertToOpenAITools([WEATHER_TOOL, CALCULATOR_TOOL]), + max_tokens: 150, + }) + + const toolCalls = extractToolCalls(response) + expect(toolCalls.length).toBeGreaterThanOrEqual(1) + + const toolNames = toolCalls.map((tc: ExtractedToolCall) => tc.name) + expect(toolNames.some((name: string) => name === 'get_weather' || name === 'calculate')).toBe(true) + console.log(`āœ… Multiple tool calls passed for ${formatProviderModel(provider, model)}`) + } + ) + }) + + describe('End-to-End Tool Calling', () => { + const testCases = getCrossProviderParamsWithVkForScenario('end2end_tool_calling') + + it.each(testCases)( + 'should complete end-to-end tool calling - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for end2end_tool_calling') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const toolModel = getProviderModel(provider, 'tools') + + // Step 1: Initial request with tools + const response1 = await client.chat.completions.create({ + model: formatProviderModel(provider, toolModel || model), + messages: convertMessages(SINGLE_TOOL_CALL_MESSAGES), + tools: convertToOpenAITools([WEATHER_TOOL]), + max_tokens: 100, + }) + + const toolCalls = extractToolCalls(response1) + expect(toolCalls.length).toBeGreaterThan(0) + + // Step 2: Execute tool and get result + const toolResult = mockToolResponse(toolCalls[0].name, toolCalls[0].arguments) + + // Step 3: Send tool result back + const messages: OpenAI.Chat.ChatCompletionMessageParam[] = [ + ...convertMessages(SINGLE_TOOL_CALL_MESSAGES), + response1.choices[0].message as OpenAI.Chat.ChatCompletionMessageParam, + { + role: 'tool', + tool_call_id: response1.choices[0].message.tool_calls![0].id, + content: toolResult, + }, + ] + + const response2 = await client.chat.completions.create({ + model: formatProviderModel(provider, toolModel || model), + messages, + max_tokens: 200, + }) + + assertValidChatResponse(response2) + console.log(`āœ… End-to-end tool calling passed for ${formatProviderModel(provider, model)}`) + } + ) + }) + + // ============================================================================ + // Image/Vision Tests + // ============================================================================ + + describe('Image URL', () => { + const testCases = getCrossProviderParamsWithVkForScenario('image_url') + + it.each(testCases)( + 'should analyze image from URL - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for image_url') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const visionModel = getProviderModel(provider, 'vision') + + const response = await client.chat.completions.create({ + model: formatProviderModel(provider, visionModel || model), + messages: convertMessages(IMAGE_URL_MESSAGES), + max_tokens: 200, + }) + + assertValidImageResponse(response) + console.log(`āœ… Image URL analysis passed for ${formatProviderModel(provider, model)}`) + } + ) + }) + + describe('Image Base64', () => { + const testCases = getCrossProviderParamsWithVkForScenario('image_base64') + + it.each(testCases)( + 'should analyze image from Base64 - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for image_base64') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const visionModel = getProviderModel(provider, 'vision') + + const response = await client.chat.completions.create({ + model: formatProviderModel(provider, visionModel || model), + messages: convertMessages(IMAGE_BASE64_MESSAGES), + max_tokens: 200, + }) + + assertValidImageResponse(response) + console.log(`āœ… Image Base64 analysis passed for ${formatProviderModel(provider, model)}`) + } + ) + }) + + describe('Multiple Images', () => { + const testCases = getCrossProviderParamsWithVkForScenario('multiple_images') + + it.each(testCases)( + 'should analyze multiple images - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for multiple_images') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const visionModel = getProviderModel(provider, 'vision') + + const response = await client.chat.completions.create({ + model: formatProviderModel(provider, visionModel || model), + messages: convertMessages(MULTIPLE_IMAGES_MESSAGES), + max_tokens: 300, + }) + + assertValidImageResponse(response) + console.log(`āœ… Multiple images analysis passed for ${formatProviderModel(provider, model)}`) + } + ) + }) + + // ============================================================================ + // Speech Synthesis Tests (OpenAI only) + // ============================================================================ + + describe('Speech Synthesis', () => { + const testCases = getCrossProviderParamsWithVkForScenario('speech_synthesis') + + it.each(testCases)( + 'should synthesize speech - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for speech_synthesis') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const speechModel = getProviderModel(provider, 'speech') + const voice = getProviderVoice(provider) + + const response = await client.audio.speech.create({ + model: formatProviderModel(provider, speechModel || 'tts-1'), + voice: voice as 'alloy' | 'echo' | 'fable' | 'onyx' | 'nova' | 'shimmer', + input: SPEECH_TEST_INPUT, + }) + + const buffer = await response.arrayBuffer() + assertValidSpeechResponse(buffer) + expect(buffer.byteLength).toBeGreaterThan(1000) + console.log(`āœ… Speech synthesis passed for ${formatProviderModel(provider, model)}`) + } + ) + }) + + // ============================================================================ + // Transcription Tests (OpenAI only) + // ============================================================================ + + describe('Audio Transcription', () => { + const testCases = getCrossProviderParamsWithVkForScenario('transcription') + + it.each(testCases)( + 'should transcribe audio - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for transcription') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const transcriptionModel = getProviderModel(provider, 'transcription') + + // Generate test audio + const audioBuffer = generateTestAudio(1000) + const audioFile = new File([audioBuffer], 'test.wav', { type: 'audio/wav' }) + + const response = await client.audio.transcriptions.create({ + model: formatProviderModel(provider, transcriptionModel || 'whisper-1'), + file: audioFile, + language: 'en', + }) + + expect(response).toBeDefined() + // Note: Generated sine wave may not produce meaningful transcription + console.log(`āœ… Audio transcription passed for ${formatProviderModel(provider, model)}`) + } + ) + }) + + // ============================================================================ + // Embeddings Tests + // ============================================================================ + + describe('Embeddings - Single Text', () => { + const testCases = getCrossProviderParamsWithVkForScenario('embeddings') + + it.each(testCases)( + 'should generate single text embedding - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for embeddings') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const embeddingsModel = getProviderModel(provider, 'embeddings') + + const response = await client.embeddings.create({ + model: formatProviderModel(provider, embeddingsModel || 'text-embedding-3-small'), + input: EMBEDDINGS_SINGLE_TEXT, + }) + + assertValidEmbeddingResponse(response) + console.log(`āœ… Single text embedding passed for ${formatProviderModel(provider, model)}`) + } + ) + }) + + describe('Embeddings - Batch', () => { + const testCases = getCrossProviderParamsWithVkForScenario('embeddings') + + it.each(testCases)( + 'should generate batch embeddings - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for embeddings') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const embeddingsModel = getProviderModel(provider, 'embeddings') + + const response = await client.embeddings.create({ + model: formatProviderModel(provider, embeddingsModel || 'text-embedding-3-small'), + input: EMBEDDINGS_MULTIPLE_TEXTS, + }) + + assertValidEmbeddingsBatchResponse(response, EMBEDDINGS_MULTIPLE_TEXTS.length) + console.log(`āœ… Batch embeddings passed for ${formatProviderModel(provider, model)}`) + } + ) + }) + + describe('Embeddings - Similarity Analysis', () => { + const testCases = getCrossProviderParamsWithVkForScenario('embeddings') + + it.each(testCases)( + 'should compute similar embeddings - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for embeddings') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const embeddingsModel = getProviderModel(provider, 'embeddings') + + const response = await client.embeddings.create({ + model: formatProviderModel(provider, embeddingsModel || 'text-embedding-3-small'), + input: EMBEDDINGS_SIMILAR_TEXTS, + }) + + assertValidEmbeddingsBatchResponse(response, 2) + + // Calculate similarity between similar texts + const emb1 = response.data[0].embedding + const emb2 = response.data[1].embedding + const similarity = calculateCosineSimilarity(emb1, emb2) + + // Similar texts should have high cosine similarity (> 0.7) + expect(similarity).toBeGreaterThan(0.7) + console.log(`āœ… Embedding similarity analysis passed for ${formatProviderModel(provider, model)} (similarity: ${similarity.toFixed(3)})`) + } + ) + }) + + // ============================================================================ + // List Models Tests + // ============================================================================ + + describe('List Models', () => { + const testCases = getCrossProviderParamsWithVkForScenario('list_models') + + it.each(testCases)( + 'should list available models - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for list_models') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const models = await client.models.list() + + expect(models).toBeDefined() + expect(models.data).toBeDefined() + expect(Array.isArray(models.data)).toBe(true) + console.log(`āœ… List models passed for ${formatProviderModel(provider, model)} (${models.data.length} models)`) + } + ) + }) + + // ============================================================================ + // Count Tokens Tests + // ============================================================================ + + describe('Count Tokens', () => { + const testCases = getCrossProviderParamsWithVkForScenario('count_tokens') + + it.each(testCases)( + 'should count tokens - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for count_tokens') + return + } + + // Token counting typically requires a direct API call + // This is a placeholder that verifies the setup works + const client = getProviderOpenAIClient(provider, vkEnabled) + + // Use a simple chat completion to verify connectivity + const response = await client.chat.completions.create({ + model: formatProviderModel(provider, model), + messages: [{ role: 'user', content: 'Say hello' }], + max_tokens: 10, + }) + + expect(response.usage).toBeDefined() + if (response.usage) { + expect(response.usage.prompt_tokens).toBeGreaterThan(0) + expect(response.usage.completion_tokens).toBeGreaterThan(0) + expect(response.usage.total_tokens).toBeGreaterThan(0) + } + console.log(`āœ… Count tokens passed for ${formatProviderModel(provider, model)}`) + } + ) + }) + + // ============================================================================ + // Files API Tests + // ============================================================================ + + describe('Files API - Upload', () => { + const testCases = getCrossProviderParamsWithVkForScenario('file_upload') + + it.each(testCases)( + 'should upload a file - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for file_upload') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + // Create JSONL content for batch processing + const jsonlContent = JSON.stringify({ + custom_id: 'request-1', + method: 'POST', + url: '/v1/chat/completions', + body: { + model: formatProviderModel(provider, model), + messages: [{ role: 'user', content: 'Hello' }], + max_tokens: 10, + }, + }) + + // Create a File object from the content + const file = new File([jsonlContent], 'batch_input.jsonl', { + type: 'application/jsonl', + }) + + let uploadedFileId: string | null = null + + try { + const response = await client.files.create({ + file, + purpose: 'batch', + }) + + expect(response).toBeDefined() + expect(response.id).toBeDefined() + expect(typeof response.id).toBe('string') + uploadedFileId = response.id + + console.log(`āœ… File upload passed for ${formatProviderModel(provider, model)} - File ID: ${response.id}`) + } finally { + // Clean up + if (uploadedFileId) { + try { + await client.files.del(uploadedFileId) + } catch (e) { + console.log(`Warning: Failed to clean up file: ${e}`) + } + } + } + } + ) + }) + + describe('Files API - List', () => { + const testCases = getCrossProviderParamsWithVkForScenario('file_list') + + it.each(testCases)( + 'should list files - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for file_list') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + // First upload a file to ensure we have at least one + const jsonlContent = JSON.stringify({ + custom_id: 'request-1', + method: 'POST', + url: '/v1/chat/completions', + body: { + model: formatProviderModel(provider, model), + messages: [{ role: 'user', content: 'Hello' }], + max_tokens: 10, + }, + }) + + const file = new File([jsonlContent], 'test_list.jsonl', { + type: 'application/jsonl', + }) + + let uploadedFileId: string | null = null + + try { + const uploadedFile = await client.files.create({ + file, + purpose: 'batch', + }) + uploadedFileId = uploadedFile.id + + // List files + const response = await client.files.list() + + expect(response).toBeDefined() + expect(response.data).toBeDefined() + expect(Array.isArray(response.data)).toBe(true) + expect(response.data.length).toBeGreaterThan(0) + + // Check that our uploaded file is in the list + const fileIds = response.data.map((f) => f.id) + expect(fileIds).toContain(uploadedFileId) + + console.log(`āœ… File list passed for ${formatProviderModel(provider, model)} - ${response.data.length} files`) + } finally { + // Clean up + if (uploadedFileId) { + try { + await client.files.del(uploadedFileId) + } catch (e) { + console.log(`Warning: Failed to clean up file: ${e}`) + } + } + } + } + ) + }) + + describe('Files API - Retrieve', () => { + const testCases = getCrossProviderParamsWithVkForScenario('file_retrieve') + + it.each(testCases)( + 'should retrieve file metadata - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for file_retrieve') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + // First upload a file + const jsonlContent = JSON.stringify({ + custom_id: 'request-1', + method: 'POST', + url: '/v1/chat/completions', + body: { + model: formatProviderModel(provider, model), + messages: [{ role: 'user', content: 'Hello' }], + max_tokens: 10, + }, + }) + + const file = new File([jsonlContent], 'test_retrieve.jsonl', { + type: 'application/jsonl', + }) + + let uploadedFileId: string | null = null + + try { + const uploadedFile = await client.files.create({ + file, + purpose: 'batch', + }) + uploadedFileId = uploadedFile.id + + // Retrieve file metadata + const response = await client.files.retrieve(uploadedFileId) + + expect(response).toBeDefined() + expect(response.id).toBe(uploadedFileId) + expect(response.filename).toBe('test_retrieve.jsonl') + expect(response.purpose).toBe('batch') + + console.log(`āœ… File retrieve passed for ${formatProviderModel(provider, model)} - File ID: ${response.id}`) + } finally { + // Clean up + if (uploadedFileId) { + try { + await client.files.del(uploadedFileId) + } catch (e) { + console.log(`Warning: Failed to clean up file: ${e}`) + } + } + } + } + ) + }) + + describe('Files API - Delete', () => { + const testCases = getCrossProviderParamsWithVkForScenario('file_delete') + + it.each(testCases)( + 'should delete a file - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for file_delete') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + // First upload a file + const jsonlContent = JSON.stringify({ + custom_id: 'request-1', + method: 'POST', + url: '/v1/chat/completions', + body: { + model: formatProviderModel(provider, model), + messages: [{ role: 'user', content: 'Hello' }], + max_tokens: 10, + }, + }) + + const file = new File([jsonlContent], 'test_delete.jsonl', { + type: 'application/jsonl', + }) + + const uploadedFile = await client.files.create({ + file, + purpose: 'batch', + }) + + // Delete the file + const response = await client.files.del(uploadedFile.id) + + expect(response).toBeDefined() + expect(response.deleted).toBe(true) + expect(response.id).toBe(uploadedFile.id) + + console.log(`āœ… File delete passed for ${formatProviderModel(provider, model)} - File ID: ${response.id}`) + } + ) + }) + + describe('Files API - Content', () => { + const testCases = getCrossProviderParamsWithVkForScenario('file_content') + + it.each(testCases)( + 'should retrieve file content - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for file_content') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + // Upload a file with known content + const jsonlContent = JSON.stringify({ + custom_id: 'request-1', + method: 'POST', + url: '/v1/chat/completions', + body: { + model: formatProviderModel(provider, model), + messages: [{ role: 'user', content: 'Hello' }], + max_tokens: 10, + }, + }) + + const file = new File([jsonlContent], 'test_content.jsonl', { + type: 'application/jsonl', + }) + + let uploadedFileId: string | null = null + + try { + const uploadedFile = await client.files.create({ + file, + purpose: 'batch', + }) + uploadedFileId = uploadedFile.id + + // Retrieve file content + const response = await client.files.content(uploadedFileId) + + expect(response).toBeDefined() + // Response should be the file content + const content = await response.text() + expect(content).toBe(jsonlContent) + + console.log(`āœ… File content passed for ${formatProviderModel(provider, model)}`) + } finally { + // Clean up + if (uploadedFileId) { + try { + await client.files.del(uploadedFileId) + } catch (e) { + console.log(`Warning: Failed to clean up file: ${e}`) + } + } + } + } + ) + }) + + // ============================================================================ + // Batch API Tests + // ============================================================================ + + describe('Batch API - Create', () => { + const testCases = getCrossProviderParamsWithVkForScenario('batch_file_upload') + + it.each(testCases)( + 'should create a batch job - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for batch_file_upload') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + // Create JSONL content for batch processing + const requests = [ + { + custom_id: 'request-1', + method: 'POST', + url: '/v1/chat/completions', + body: { + model: formatProviderModel(provider, model), + messages: [{ role: 'user', content: 'Say hello' }], + max_tokens: 10, + }, + }, + { + custom_id: 'request-2', + method: 'POST', + url: '/v1/chat/completions', + body: { + model: formatProviderModel(provider, model), + messages: [{ role: 'user', content: 'Say goodbye' }], + max_tokens: 10, + }, + }, + ] + + const jsonlContent = requests.map((r) => JSON.stringify(r)).join('\n') + + // Upload the file + const file = new File([jsonlContent], 'batch_input.jsonl', { + type: 'application/jsonl', + }) + + let uploadedFileId: string | null = null + let batchId: string | null = null + + try { + const uploadedFile = await client.files.create({ + file, + purpose: 'batch', + }) + uploadedFileId = uploadedFile.id + + // Create batch job + const batch = await client.batches.create({ + input_file_id: uploadedFileId, + endpoint: '/v1/chat/completions', + completion_window: '24h', + }) + + expect(batch).toBeDefined() + expect(batch.id).toBeDefined() + expect(typeof batch.id).toBe('string') + batchId = batch.id + + console.log(`āœ… Batch create passed for ${formatProviderModel(provider, model)} - Batch ID: ${batch.id}`) + } finally { + // Clean up batch + if (batchId) { + try { + await client.batches.cancel(batchId) + } catch (e) { + console.log(`Warning: Failed to cancel batch: ${e}`) + } + } + // Clean up file + if (uploadedFileId) { + try { + await client.files.del(uploadedFileId) + } catch (e) { + console.log(`Warning: Failed to clean up file: ${e}`) + } + } + } + } + ) + }) + + describe('Batch API - List', () => { + const testCases = getCrossProviderParamsWithVkForScenario('batch_list') + + it.each(testCases)( + 'should list batch jobs - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for batch_list') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + // List batches + const response = await client.batches.list({ limit: 10 }) + + expect(response).toBeDefined() + expect(response.data).toBeDefined() + expect(Array.isArray(response.data)).toBe(true) + + console.log(`āœ… Batch list passed for ${formatProviderModel(provider, model)} - ${response.data.length} batches`) + } + ) + }) + + describe('Batch API - Retrieve', () => { + const testCases = getCrossProviderParamsWithVkForScenario('batch_retrieve') + + it.each(testCases)( + 'should retrieve batch job status - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for batch_retrieve') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + // First, list batches to get a batch ID + const listResponse = await client.batches.list({ limit: 10 }) + + if (listResponse.data.length === 0) { + console.log('Skipping: No batches available to retrieve') + return + } + + const batchId = listResponse.data[0].id + + // Retrieve batch + const response = await client.batches.retrieve(batchId) + + expect(response).toBeDefined() + expect(response.id).toBe(batchId) + expect(response.status).toBeDefined() + + console.log(`āœ… Batch retrieve passed for ${formatProviderModel(provider, model)} - Batch ID: ${response.id}, Status: ${response.status}`) + } + ) + }) + + describe('Batch API - Cancel', () => { + const testCases = getCrossProviderParamsWithVkForScenario('batch_cancel') + + it.each(testCases)( + 'should cancel a batch job - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for batch_cancel') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + // Create JSONL content for batch processing + const requests = [ + { + custom_id: 'request-1', + method: 'POST', + url: '/v1/chat/completions', + body: { + model: formatProviderModel(provider, model), + messages: [{ role: 'user', content: 'Say hello' }], + max_tokens: 10, + }, + }, + ] + + const jsonlContent = requests.map((r) => JSON.stringify(r)).join('\n') + + // Upload the file + const file = new File([jsonlContent], 'batch_cancel_input.jsonl', { + type: 'application/jsonl', + }) + + let uploadedFileId: string | null = null + + try { + const uploadedFile = await client.files.create({ + file, + purpose: 'batch', + }) + uploadedFileId = uploadedFile.id + + // Create batch job + const batch = await client.batches.create({ + input_file_id: uploadedFileId, + endpoint: '/v1/chat/completions', + completion_window: '24h', + }) + + // Cancel the batch + const response = await client.batches.cancel(batch.id) + + expect(response).toBeDefined() + expect(response.id).toBe(batch.id) + expect(['cancelling', 'cancelled', 'completed', 'failed']).toContain(response.status) + + console.log(`āœ… Batch cancel passed for ${formatProviderModel(provider, model)} - Batch ID: ${response.id}, Status: ${response.status}`) + } finally { + // Clean up file + if (uploadedFileId) { + try { + await client.files.del(uploadedFileId) + } catch (e) { + console.log(`Warning: Failed to clean up file: ${e}`) + } + } + } + } + ) + }) + + // ============================================================================ + // Responses API Tests + // ============================================================================ + + describe('Responses API - Simple Text', () => { + const testCases = getCrossProviderParamsWithVkForScenario('responses') + + it.each(testCases)( + 'should create a response with simple text - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for responses') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + // Use type assertion for beta responses API + const responses = (client as unknown as { responses: { create: (params: unknown) => Promise } }).responses + + try { + const response = await responses.create({ + model: formatProviderModel(provider, model), + input: RESPONSES_SIMPLE_TEXT_INPUT, + max_output_tokens: 1000, + }) as { output?: Array<{ content?: Array<{ text?: string }> }> } + + expect(response).toBeDefined() + expect(response.output).toBeDefined() + expect(response.output!.length).toBeGreaterThan(0) + + // Extract content + let content = '' + for (const item of response.output || []) { + if (item.content) { + for (const block of item.content) { + if (block.text) { + content += block.text + } + } + } + } + + expect(content.length).toBeGreaterThan(20) + expect(content.toLowerCase()).toMatch(/paris|france|capital/i) + console.log(`āœ… Responses API simple text passed for ${formatProviderModel(provider, model)}`) + } catch (error) { + console.log(`āš ļø Responses API test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + } + ) + }) + + describe('Responses API - With System Message', () => { + const testCases = getCrossProviderParamsWithVkForScenario('responses') + + it.each(testCases)( + 'should create a response with system message - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for responses') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const responses = (client as unknown as { responses: { create: (params: unknown) => Promise } }).responses + + try { + const response = await responses.create({ + model: formatProviderModel(provider, model), + input: RESPONSES_TEXT_WITH_SYSTEM, + max_output_tokens: 1000, + }) as { output?: Array<{ content?: Array<{ text?: string }> }> } + + expect(response).toBeDefined() + expect(response.output).toBeDefined() + + // Extract content + let content = '' + for (const item of response.output || []) { + if (item.content) { + for (const block of item.content) { + if (block.text) { + content += block.text + } + } + } + } + + expect(content.length).toBeGreaterThan(20) + console.log(`āœ… Responses API with system message passed for ${formatProviderModel(provider, model)}`) + } catch (error) { + console.log(`āš ļø Responses API with system message test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + } + ) + }) + + describe('Responses API - With Image', () => { + const testCases = getCrossProviderParamsWithVkForScenario('responses_image') + + it.each(testCases)( + 'should create a response with image input - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for responses_image') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const visionModel = getProviderModel(provider, 'vision') + const responses = (client as unknown as { responses: { create: (params: unknown) => Promise } }).responses + + try { + const response = await responses.create({ + model: formatProviderModel(provider, visionModel || model), + input: [ + { type: 'input_text', text: RESPONSES_IMAGE_INPUT.text }, + { type: 'input_image', image_url: RESPONSES_IMAGE_INPUT.imageUrl }, + ], + max_output_tokens: 1000, + }) as { output?: Array<{ content?: Array<{ text?: string }> }> } + + expect(response).toBeDefined() + expect(response.output).toBeDefined() + + // Extract content + let content = '' + for (const item of response.output || []) { + if (item.content) { + for (const block of item.content) { + if (block.text) { + content += block.text + } + } + } + } + + expect(content.length).toBeGreaterThan(20) + console.log(`āœ… Responses API with image passed for ${formatProviderModel(provider, model)}`) + } catch (error) { + console.log(`āš ļø Responses API with image test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + } + ) + }) + + describe('Responses API - With Tools', () => { + const testCases = getCrossProviderParamsWithVkForScenario('responses') + + it.each(testCases)( + 'should create a response with tools - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for responses') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const toolModel = getProviderModel(provider, 'tools') + const responses = (client as unknown as { responses: { create: (params: unknown) => Promise } }).responses + const tools = convertToResponsesTools([WEATHER_TOOL]) + + try { + const response = await responses.create({ + model: formatProviderModel(provider, toolModel || model), + input: RESPONSES_TOOL_CALL_INPUT, + tools, + max_output_tokens: 150, + }) as { output?: Array<{ type?: string; name?: string; arguments?: string }> } + + expect(response).toBeDefined() + expect(response.output).toBeDefined() + + // Check for function call in output + const hasFunctionCall = (response.output || []).some( + (item) => item.type === 'function_call' || item.name === 'get_weather' + ) + + expect(hasFunctionCall).toBe(true) + console.log(`āœ… Responses API with tools passed for ${formatProviderModel(provider, model)}`) + } catch (error) { + console.log(`āš ļø Responses API with tools test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + } + ) + }) + + describe('Responses API - Streaming', () => { + const testCases = getCrossProviderParamsWithVkForScenario('responses') + + it.each(testCases)( + 'should stream a response - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for responses') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const responses = (client as unknown as { responses: { create: (params: unknown) => Promise> } }).responses + + try { + const stream = await responses.create({ + model: formatProviderModel(provider, model), + input: RESPONSES_STREAMING_INPUT, + max_output_tokens: 1000, + stream: true, + }) + + let content = '' + let chunkCount = 0 + + for await (const event of stream as AsyncIterable<{ type?: string; delta?: { text?: string } }>) { + chunkCount++ + if (event.type === 'content_part.delta' || event.type === 'response.output_text.delta') { + if (event.delta?.text) { + content += event.delta.text + } + } + } + + expect(chunkCount).toBeGreaterThan(1) + expect(content.length).toBeGreaterThan(0) + console.log(`āœ… Responses API streaming passed for ${formatProviderModel(provider, model)} (${chunkCount} chunks)`) + } catch (error) { + console.log(`āš ļø Responses API streaming test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + } + ) + }) + + describe('Responses API - Streaming Client Disconnect', () => { + const testCases = getCrossProviderParamsWithVkForScenario('responses') + + it.each(testCases)( + 'should handle client disconnect mid-stream - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for responses') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const abortController = new AbortController() + const responses = (client as unknown as { + responses: { + create: (params: unknown, options?: { signal?: AbortSignal }) => Promise> + } + }).responses + + try { + const stream = await responses.create({ + model: formatProviderModel(provider, model), + input: 'Write a detailed essay about the history of artificial intelligence, including at least 10 paragraphs covering different eras and breakthroughs.', + max_output_tokens: 2000, + stream: true, + }, { + signal: abortController.signal, + }) + + let chunkCount = 0 + let content = '' + let wasAborted = false + + try { + for await (const event of stream as AsyncIterable<{ type?: string; delta?: { text?: string } }>) { + chunkCount++ + if (event.type === 'content_part.delta' || event.type === 'response.output_text.delta') { + if (event.delta?.text) { + content += event.delta.text + } + } + + // Abort after receiving a few chunks + if (chunkCount >= 3) { + abortController.abort() + } + } + } catch (error) { + wasAborted = true + expect(error).toBeDefined() + const errorMessage = error instanceof Error ? error.message.toLowerCase() : String(error).toLowerCase() + const isAbortError = errorMessage.includes('abort') || + errorMessage.includes('cancel') || + error instanceof DOMException || + (error as { name?: string })?.name === 'AbortError' + expect(isAbortError).toBe(true) + } + + expect(chunkCount).toBeGreaterThanOrEqual(3) + expect(wasAborted).toBe(true) + console.log(`āœ… Responses API streaming client disconnect passed for ${formatProviderModel(provider, model)} (${chunkCount} chunks before abort)`) + } catch (error) { + console.log(`āš ļø Responses API streaming client disconnect test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + } + ) + }) + + describe('Responses API - Streaming With Tools', () => { + const testCases = getCrossProviderParamsWithVkForScenario('responses') + + it.each(testCases)( + 'should stream a response with tools - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for responses') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const toolModel = getProviderModel(provider, 'tools') + const responses = (client as unknown as { responses: { create: (params: unknown) => Promise> } }).responses + const tools = convertToResponsesTools([WEATHER_TOOL]) + + try { + const stream = await responses.create({ + model: formatProviderModel(provider, toolModel || model), + input: [ + { type: 'input_text', text: "What's the weather in San Francisco?" }, + ], + tools, + max_output_tokens: 150, + stream: true, + }) + + let chunkCount = 0 + let hasToolCall = false + + for await (const event of stream as AsyncIterable<{ type?: string }>) { + chunkCount++ + if (event.type === 'response.function_call_arguments.delta' || event.type === 'function_call') { + hasToolCall = true + } + } + + expect(chunkCount).toBeGreaterThan(1) + console.log(`āœ… Responses API streaming with tools passed for ${formatProviderModel(provider, model)} (tool call: ${hasToolCall})`) + } catch (error) { + console.log(`āš ļø Responses API streaming with tools test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + } + ) + }) + + describe('Responses API - Reasoning', () => { + const testCases = getCrossProviderParamsWithVkForScenario('thinking') + + it.each(testCases)( + 'should create a response with reasoning - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for thinking') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + const thinkingModel = getProviderModel(provider, 'thinking') + const responses = (client as unknown as { responses: { create: (params: unknown) => Promise } }).responses + + try { + const response = await responses.create({ + model: formatProviderModel(provider, thinkingModel || model), + input: RESPONSES_REASONING_INPUT, + max_output_tokens: 1200, + reasoning: { + effort: 'high', + summary: 'auto', + }, + }) as { output?: Array<{ type?: string; content?: Array<{ text?: string }>; summary?: Array<{ text?: string }> }> } + + expect(response).toBeDefined() + expect(response.output).toBeDefined() + + // Extract content from output or summary + let content = '' + for (const item of response.output || []) { + if (item.content) { + for (const block of item.content) { + if (block.text) { + content += block.text + } + } + } + if (item.summary) { + for (const block of item.summary) { + if (block.text) { + content += block.text + } + } + } + } + + expect(content.length).toBeGreaterThan(30) + console.log(`āœ… Responses API reasoning passed for ${formatProviderModel(provider, model)}`) + } catch (error) { + console.log(`āš ļø Responses API reasoning test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + } + ) + }) + + // ============================================================================ + // Input Tokens API Tests + // ============================================================================ + + describe('Input Tokens - Simple Text', () => { + const testCases = getCrossProviderParamsWithVkForScenario('count_tokens') + + it.each(testCases)( + 'should count input tokens for simple text - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for count_tokens') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + // Try to use the responses.input_tokens.count API if available + try { + const responses = (client as unknown as { responses: { input_tokens: { count: (params: unknown) => Promise<{ total_tokens: number }> } } }).responses + + const response = await responses.input_tokens.count({ + model: formatProviderModel(provider, model), + input: 'Hello, how are you?', + }) + + expect(response).toBeDefined() + expect(response.total_tokens).toBeGreaterThan(0) + console.log(`āœ… Input tokens count passed for ${formatProviderModel(provider, model)} (${response.total_tokens} tokens)`) + } catch (error) { + console.log(`āš ļø Input tokens count test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + } + ) + }) + + describe('Input Tokens - With System Message', () => { + const testCases = getCrossProviderParamsWithVkForScenario('count_tokens') + + it.each(testCases)( + 'should count input tokens with system message - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for count_tokens') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + try { + const responses = (client as unknown as { responses: { input_tokens: { count: (params: unknown) => Promise<{ total_tokens: number }> } } }).responses + + const response = await responses.input_tokens.count({ + model: formatProviderModel(provider, model), + input: { + system: 'You are a helpful assistant.', + user: 'What is 2 + 2?', + }, + }) + + expect(response).toBeDefined() + expect(response.total_tokens).toBeGreaterThan(0) + console.log(`āœ… Input tokens with system message passed for ${formatProviderModel(provider, model)} (${response.total_tokens} tokens)`) + } catch (error) { + console.log(`āš ļø Input tokens with system message test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + } + ) + }) + + describe('Input Tokens - Long Text', () => { + const testCases = getCrossProviderParamsWithVkForScenario('count_tokens') + + it.each(testCases)( + 'should count input tokens for long text - $provider (VK: $vkEnabled)', + async ({ provider, model, vkEnabled }: ProviderModelVkParam) => { + if (shouldSkipNoProviders({ provider, model, vkEnabled })) { + console.log('Skipping: No providers available for count_tokens') + return + } + + const client = getProviderOpenAIClient(provider, vkEnabled) + + try { + const responses = (client as unknown as { responses: { input_tokens: { count: (params: unknown) => Promise<{ total_tokens: number }> } } }).responses + + const longText = 'This is a longer piece of text that should result in more tokens being counted. ' + + 'It contains multiple sentences and various words to ensure accurate token counting.' + + const response = await responses.input_tokens.count({ + model: formatProviderModel(provider, model), + input: longText, + }) + + expect(response).toBeDefined() + expect(response.total_tokens).toBeGreaterThan(10) + console.log(`āœ… Input tokens long text passed for ${formatProviderModel(provider, model)} (${response.total_tokens} tokens)`) + } catch (error) { + console.log(`āš ļø Input tokens long text test skipped: ${error instanceof Error ? error.message : 'Unknown error'}`) + } + } + ) + }) +}) diff --git a/tests/integrations/typescript/tsconfig.json b/tests/integrations/typescript/tsconfig.json new file mode 100644 index 0000000000..1d160f23a0 --- /dev/null +++ b/tests/integrations/typescript/tsconfig.json @@ -0,0 +1,32 @@ +{ + "compilerOptions": { + "target": "ES2022", + "module": "ESNext", + "moduleResolution": "bundler", + "lib": ["ES2022"], + "esModuleInterop": true, + "allowSyntheticDefaultImports": true, + "strict": true, + "skipLibCheck": true, + "declaration": true, + "declarationMap": true, + "sourceMap": true, + "outDir": "./dist", + "rootDir": ".", + "baseUrl": ".", + "paths": { + "@/*": ["./src/*"] + }, + "types": ["node", "vitest/globals"], + "resolveJsonModule": true, + "isolatedModules": true, + "noEmit": true, + "forceConsistentCasingInFileNames": true, + "noUnusedLocals": false, + "noUnusedParameters": false, + "noImplicitReturns": true, + "noFallthroughCasesInSwitch": true + }, + "include": ["src/**/*", "tests/**/*", "vitest.config.ts"], + "exclude": ["node_modules", "dist"] +} diff --git a/tests/integrations/typescript/vitest.config.ts b/tests/integrations/typescript/vitest.config.ts new file mode 100644 index 0000000000..43521673fd --- /dev/null +++ b/tests/integrations/typescript/vitest.config.ts @@ -0,0 +1,54 @@ +import { resolve } from 'path' +import { defineConfig } from 'vitest/config' + +export default defineConfig({ + test: { + // Test discovery + include: ['tests/**/*.test.ts'], + exclude: ['node_modules', 'dist'], + + // Global test settings + globals: true, + environment: 'node', + + // Timeout settings (5 minutes per test, matching Python) + testTimeout: 300000, + hookTimeout: 60000, + + // Run tests sequentially to avoid API rate limiting + pool: 'forks', + poolOptions: { + forks: { + singleFork: true, + }, + }, + + // Reporter configuration + reporters: ['verbose'], + + // Setup files + setupFiles: ['./tests/setup.ts'], + + // Retry flaky tests (matching Python pytest-rerunfailures) + retry: 2, + + // Coverage configuration + coverage: { + provider: 'v8', + reporter: ['text', 'html', 'json'], + include: ['src/**/*.ts'], + exclude: ['node_modules', 'dist', 'tests'], + }, + + // Environment variables + env: { + NODE_ENV: 'test', + }, + }, + + resolve: { + alias: { + '@': resolve(__dirname, './src'), + }, + }, +}) diff --git a/transports/bifrost-http/handlers/cache.go b/transports/bifrost-http/handlers/cache.go index a91d04aa5b..df3ef62d69 100644 --- a/transports/bifrost-http/handlers/cache.go +++ b/transports/bifrost-http/handlers/cache.go @@ -23,7 +23,7 @@ func NewCacheHandler(plugin schemas.Plugin) *CacheHandler { } } -func (h *CacheHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *CacheHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.DELETE("/api/cache/clear/{requestId}", lib.ChainMiddlewares(h.clearCache, middlewares...)) r.DELETE("/api/cache/clear-by-key/{cacheKey}", lib.ChainMiddlewares(h.clearCacheByKey, middlewares...)) } diff --git a/transports/bifrost-http/handlers/config.go b/transports/bifrost-http/handlers/config.go index 26ad2b6a36..0659253214 100644 --- a/transports/bifrost-http/handlers/config.go +++ b/transports/bifrost-http/handlers/config.go @@ -13,6 +13,7 @@ import ( "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/network" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework" "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" @@ -45,6 +46,7 @@ type ConfigManager interface { ReloadPricingManager(ctx context.Context) error ForceReloadPricing(ctx context.Context) error UpdateDropExcessRequests(ctx context.Context, value bool) + UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any) error ReloadProxyConfig(ctx context.Context, config *configstoreTables.GlobalProxyConfig) error ReloadHeaderFilterConfig(ctx context.Context, config *configstoreTables.GlobalHeaderFilterConfig) error @@ -68,7 +70,7 @@ func NewConfigHandler(configManager ConfigManager, store *lib.Config) *ConfigHan // RegisterRoutes registers the configuration-related routes. // It adds the `PUT /api/config` endpoint. -func (h *ConfigHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *ConfigHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.GET("/api/config", lib.ChainMiddlewares(h.getConfig, middlewares...)) r.PUT("/api/config", lib.ChainMiddlewares(h.updateConfig, middlewares...)) r.GET("/api/version", lib.ChainMiddlewares(h.getVersion, middlewares...)) @@ -241,6 +243,49 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { updatedConfig.DropExcessRequests = payload.ClientConfig.DropExcessRequests } + if payload.ClientConfig.MCPCodeModeBindingLevel != "" { + if payload.ClientConfig.MCPCodeModeBindingLevel != string(schemas.CodeModeBindingLevelServer) && payload.ClientConfig.MCPCodeModeBindingLevel != string(schemas.CodeModeBindingLevelTool) { + logger.Warn("mcp_code_mode_binding_level must be 'server' or 'tool'") + SendError(ctx, fasthttp.StatusBadRequest, "mcp_code_mode_binding_level must be 'server' or 'tool'") + return + } + } + + shouldReloadMCPToolManagerConfig := false + + if payload.ClientConfig.MCPAgentDepth != currentConfig.MCPAgentDepth { + if payload.ClientConfig.MCPAgentDepth <= 0 { + logger.Warn("mcp_agent_depth must be greater than 0") + SendError(ctx, fasthttp.StatusBadRequest, "mcp_agent_depth must be greater than 0") + return + } + updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth + shouldReloadMCPToolManagerConfig = true + } + + if payload.ClientConfig.MCPToolExecutionTimeout != currentConfig.MCPToolExecutionTimeout { + if payload.ClientConfig.MCPToolExecutionTimeout <= 0 { + logger.Warn("mcp_tool_execution_timeout must be greater than 0") + SendError(ctx, fasthttp.StatusBadRequest, "mcp_tool_execution_timeout must be greater than 0") + return + } + updatedConfig.MCPToolExecutionTimeout = payload.ClientConfig.MCPToolExecutionTimeout + shouldReloadMCPToolManagerConfig = true + } + + if payload.ClientConfig.MCPCodeModeBindingLevel != "" && payload.ClientConfig.MCPCodeModeBindingLevel != currentConfig.MCPCodeModeBindingLevel { + updatedConfig.MCPCodeModeBindingLevel = payload.ClientConfig.MCPCodeModeBindingLevel + shouldReloadMCPToolManagerConfig = true + } + + if shouldReloadMCPToolManagerConfig { + if err := h.configManager.UpdateMCPToolManagerConfig(ctx, updatedConfig.MCPAgentDepth, updatedConfig.MCPToolExecutionTimeout, updatedConfig.MCPCodeModeBindingLevel); err != nil { + logger.Warn(fmt.Sprintf("failed to update mcp tool manager config: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update mcp tool manager config: %v", err)) + return + } + } + if !slices.Equal(payload.ClientConfig.PrometheusLabels, currentConfig.PrometheusLabels) { updatedConfig.PrometheusLabels = payload.ClientConfig.PrometheusLabels shouldReloadTelemetryPlugin = true @@ -281,6 +326,12 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { updatedConfig.MaxRequestBodySizeMB = payload.ClientConfig.MaxRequestBodySizeMB updatedConfig.EnableLiteLLMFallbacks = payload.ClientConfig.EnableLiteLLMFallbacks + updatedConfig.MCPAgentDepth = payload.ClientConfig.MCPAgentDepth + updatedConfig.MCPToolExecutionTimeout = payload.ClientConfig.MCPToolExecutionTimeout + // Only update MCPCodeModeBindingLevel if payload is non-empty to avoid clearing stored value + if payload.ClientConfig.MCPCodeModeBindingLevel != "" { + updatedConfig.MCPCodeModeBindingLevel = payload.ClientConfig.MCPCodeModeBindingLevel + } // Handle HeaderFilterConfig changes if !headerFilterConfigEqual(payload.ClientConfig.HeaderFilterConfig, currentConfig.HeaderFilterConfig) { diff --git a/transports/bifrost-http/handlers/devpprof.go b/transports/bifrost-http/handlers/devpprof.go new file mode 100644 index 0000000000..fd5db1ed6c --- /dev/null +++ b/transports/bifrost-http/handlers/devpprof.go @@ -0,0 +1,717 @@ +package handlers + +import ( + "bytes" + "os" + "regexp" + "runtime" + "runtime/pprof" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/fasthttp/router" + "github.com/google/pprof/profile" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +const ( + // Collection interval for metrics + metricsCollectionInterval = 10 * time.Second + // Number of data points to keep (5 minutes / 10 seconds = 30 points) + historySize = 30 + // Top allocations to return + topAllocationsCount = 5 +) + +// MemoryStats represents memory statistics at a point in time +type MemoryStats struct { + Alloc uint64 `json:"alloc"` + TotalAlloc uint64 `json:"total_alloc"` + HeapInuse uint64 `json:"heap_inuse"` + HeapObjects uint64 `json:"heap_objects"` + Sys uint64 `json:"sys"` +} + +// CPUStats represents CPU statistics +type CPUStats struct { + UsagePercent float64 `json:"usage_percent"` + UserTime float64 `json:"user_time"` + SystemTime float64 `json:"system_time"` +} + +// RuntimeStats represents runtime statistics +type RuntimeStats struct { + NumGoroutine int `json:"num_goroutine"` + NumGC uint32 `json:"num_gc"` + GCPauseNs uint64 `json:"gc_pause_ns"` + NumCPU int `json:"num_cpu"` + GOMAXPROCS int `json:"gomaxprocs"` +} + +// AllocationInfo represents a single allocation site +type AllocationInfo struct { + Function string `json:"function"` + File string `json:"file"` + Line int `json:"line"` + Bytes int64 `json:"bytes"` + Count int64 `json:"count"` +} + +// GoroutineGroup represents a group of goroutines with the same stack trace +type GoroutineGroup struct { + Count int `json:"count"` + State string `json:"state"` + WaitReason string `json:"wait_reason,omitempty"` + WaitMinutes int `json:"wait_minutes,omitempty"` // Parsed wait time in minutes + TopFunc string `json:"top_func"` + Stack []string `json:"stack"` + Category string `json:"category"` // "background", "per-request", "unknown" +} + +// GoroutineProfile represents the goroutine profile response +type GoroutineProfile struct { + Timestamp string `json:"timestamp"` + TotalGoroutines int `json:"total_goroutines"` + Groups []GoroutineGroup `json:"groups"` + Summary GoroutineSummary `json:"summary"` + RawProfile string `json:"raw_profile,omitempty"` +} + +// GoroutineSummary provides a quick overview of goroutine health +type GoroutineSummary struct { + Background int `json:"background"` // Expected long-running goroutines + PerRequest int `json:"per_request"` // Goroutines that should complete with requests + LongWaiting int `json:"long_waiting"` // Goroutines waiting > 1 minute (potential leaks) + PotentiallyStuck int `json:"potentially_stuck"` // Per-request goroutines waiting > 1 minute +} + +// HistoryPoint represents a single point in the metrics history +type HistoryPoint struct { + Timestamp string `json:"timestamp"` + Alloc uint64 `json:"alloc"` + HeapInuse uint64 `json:"heap_inuse"` + Goroutines int `json:"goroutines"` + GCPauseNs uint64 `json:"gc_pause_ns"` + CPUPercent float64 `json:"cpu_percent"` +} + +// PprofData represents the complete pprof response +type PprofData struct { + Timestamp string `json:"timestamp"` + Memory MemoryStats `json:"memory"` + CPU CPUStats `json:"cpu"` + Runtime RuntimeStats `json:"runtime"` + TopAllocations []AllocationInfo `json:"top_allocations"` + History []HistoryPoint `json:"history"` +} + +// cpuSample holds a CPU time sample for calculating usage +type cpuSample struct { + timestamp time.Time + userTime time.Duration + systemTime time.Duration +} + +// MetricsCollector collects and stores runtime metrics +type MetricsCollector struct { + mu sync.RWMutex + history []HistoryPoint + stopCh chan struct{} + started bool + lastCPUSample cpuSample + currentCPU CPUStats +} + +// DevPprofHandler handles development profiling endpoints +type DevPprofHandler struct { + collector *MetricsCollector +} + +// Global collector instance +var globalCollector *MetricsCollector +var collectorOnce sync.Once + +// IsDevMode checks if dev mode is enabled via environment variable +func IsDevMode() bool { + return os.Getenv("BIFROST_UI_DEV") == "true" +} + +// getOrCreateCollector returns the global metrics collector, creating it if needed +func getOrCreateCollector() *MetricsCollector { + collectorOnce.Do(func() { + globalCollector = &MetricsCollector{ + history: make([]HistoryPoint, 0, historySize), + stopCh: make(chan struct{}), + } + }) + return globalCollector +} + +// NewDevPprofHandler creates a new dev pprof handler +func NewDevPprofHandler() *DevPprofHandler { + return &DevPprofHandler{ + collector: getOrCreateCollector(), + } +} + +// Start begins the background metrics collection +func (c *MetricsCollector) Start() { + c.mu.Lock() + if c.started { + c.mu.Unlock() + return + } + c.stopCh = make(chan struct{}) + c.started = true + c.mu.Unlock() + + go c.collectLoop() +} + +// Stop stops the background metrics collection +func (c *MetricsCollector) Stop() { + c.mu.Lock() + defer c.mu.Unlock() + if !c.started { + return + } + close(c.stopCh) + c.stopCh = nil + c.started = false +} + +func (c *MetricsCollector) collectLoop() { + // Initialize CPU sample + c.lastCPUSample = getCPUSample() + + // Wait a bit before first collection to get accurate CPU reading + time.Sleep(100 * time.Millisecond) + + // Collect immediately on start + c.collect() + + ticker := time.NewTicker(metricsCollectionInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + c.collect() + case <-c.stopCh: + return + } + } +} + +// calculateCPUUsage calculates CPU usage percentage between two samples +func calculateCPUUsage(prev, curr cpuSample, numCPU int) CPUStats { + elapsed := curr.timestamp.Sub(prev.timestamp) + if elapsed <= 0 { + return CPUStats{} + } + + userDelta := curr.userTime - prev.userTime + systemDelta := curr.systemTime - prev.systemTime + totalCPUTime := userDelta + systemDelta + + // Calculate percentage: (CPU time used / wall time) * 100 + // Normalized by number of CPUs to get 0-100% range + cpuPercent := (float64(totalCPUTime) / float64(elapsed)) * 100.0 + + // Cap at 100% * numCPU (in case of measurement errors) + maxPercent := float64(numCPU) * 100.0 + if cpuPercent > maxPercent { + cpuPercent = maxPercent + } + + return CPUStats{ + UsagePercent: cpuPercent, + UserTime: userDelta.Seconds(), + SystemTime: systemDelta.Seconds(), + } +} + +func (c *MetricsCollector) collect() { + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + // Get current CPU sample and calculate usage + currentSample := getCPUSample() + cpuStats := calculateCPUUsage(c.lastCPUSample, currentSample, runtime.NumCPU()) + c.lastCPUSample = currentSample + + point := HistoryPoint{ + Timestamp: time.Now().Format(time.RFC3339), + Alloc: memStats.Alloc, + HeapInuse: memStats.HeapInuse, + Goroutines: runtime.NumGoroutine(), + GCPauseNs: memStats.PauseNs[(memStats.NumGC+255)%256], + CPUPercent: cpuStats.UsagePercent, + } + + c.mu.Lock() + defer c.mu.Unlock() + + // Store current CPU stats for API response + c.currentCPU = cpuStats + + // Append to history, maintaining ring buffer behavior + if len(c.history) >= historySize { + // Shift left by one and append + copy(c.history, c.history[1:]) + c.history[len(c.history)-1] = point + } else { + c.history = append(c.history, point) + } +} + +func (c *MetricsCollector) getHistory() []HistoryPoint { + c.mu.RLock() + defer c.mu.RUnlock() + + // Return a copy to avoid race conditions + result := make([]HistoryPoint, len(c.history)) + copy(result, c.history) + return result +} + +func (c *MetricsCollector) getCPUStats() CPUStats { + c.mu.RLock() + defer c.mu.RUnlock() + return c.currentCPU +} + +// getTopAllocations analyzes heap profile to find top allocation sites +func getTopAllocations() []AllocationInfo { + // Write heap profile to buffer + var buf bytes.Buffer + if err := pprof.WriteHeapProfile(&buf); err != nil { + return []AllocationInfo{} + } + + // Parse the protobuf profile + p, err := profile.Parse(&buf) + if err != nil { + return []AllocationInfo{} + } + + // Find the indices for alloc_objects and alloc_space sample types + var allocObjectsIdx, allocSpaceIdx int + for i, st := range p.SampleType { + switch st.Type { + case "alloc_objects": + allocObjectsIdx = i + case "alloc_space": + allocSpaceIdx = i + } + } + + // Aggregate allocations by function (top of stack = allocation site) + allocMap := make(map[string]*AllocationInfo) + + for _, sample := range p.Sample { + if len(sample.Location) == 0 { + continue + } + loc := sample.Location[0] // Top of stack = allocation site + if len(loc.Line) == 0 { + continue + } + line := loc.Line[0] + fn := line.Function + if fn == nil { + continue + } + + // Skip allocations from the profiler itself + if isProfilerFunction(fn.Name, fn.Filename) { + continue + } + + key := fn.Name + if existing, ok := allocMap[key]; ok { + existing.Bytes += sample.Value[allocSpaceIdx] + existing.Count += sample.Value[allocObjectsIdx] + } else { + allocMap[key] = &AllocationInfo{ + Function: fn.Name, + File: fn.Filename, + Line: int(line.Line), + Bytes: sample.Value[allocSpaceIdx], + Count: sample.Value[allocObjectsIdx], + } + } + } + + // Convert map to slice + allocations := make([]AllocationInfo, 0, len(allocMap)) + for _, alloc := range allocMap { + allocations = append(allocations, *alloc) + } + + // Sort by bytes descending + sort.Slice(allocations, func(i, j int) bool { + return allocations[i].Bytes > allocations[j].Bytes + }) + + // Return top N allocations + if len(allocations) > topAllocationsCount { + allocations = allocations[:topAllocationsCount] + } + + return allocations +} + +// RegisterRoutes registers the dev pprof routes +func (h *DevPprofHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { + // Start the collector when routes are registered + h.collector.Start() + + r.GET("/api/dev/pprof", lib.ChainMiddlewares(h.getPprof, middlewares...)) + r.GET("/api/dev/pprof/goroutines", lib.ChainMiddlewares(h.getGoroutines, middlewares...)) +} + +// getPprof handles GET /api/dev/pprof +func (h *DevPprofHandler) getPprof(ctx *fasthttp.RequestCtx) { + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + data := PprofData{ + Timestamp: time.Now().Format(time.RFC3339), + Memory: MemoryStats{ + Alloc: memStats.Alloc, + TotalAlloc: memStats.TotalAlloc, + HeapInuse: memStats.HeapInuse, + HeapObjects: memStats.HeapObjects, + Sys: memStats.Sys, + }, + CPU: h.collector.getCPUStats(), + Runtime: RuntimeStats{ + NumGoroutine: runtime.NumGoroutine(), + NumGC: memStats.NumGC, + GCPauseNs: memStats.PauseNs[(memStats.NumGC+255)%256], + NumCPU: runtime.NumCPU(), + GOMAXPROCS: runtime.GOMAXPROCS(0), + }, + TopAllocations: getTopAllocations(), + History: h.collector.getHistory(), + } + + SendJSON(ctx, data) +} + +// getGoroutines handles GET /api/dev/pprof/goroutines +// Returns goroutine stack traces grouped by stack signature +func (h *DevPprofHandler) getGoroutines(ctx *fasthttp.RequestCtx) { + // Check if raw output is requested + includeRaw := string(ctx.QueryArgs().Peek("raw")) == "true" + + // Get goroutine profile + var buf bytes.Buffer + if err := pprof.Lookup("goroutine").WriteTo(&buf, 2); err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + SendJSON(ctx, map[string]string{"error": "failed to get goroutine profile"}) + return + } + + rawProfile := buf.String() + allGroups := parseGoroutineProfile(rawProfile) + + // Filter out profiler goroutines and calculate summary + groups := make([]GoroutineGroup, 0, len(allGroups)) + summary := GoroutineSummary{} + profilerGoroutineCount := 0 + + for i := range allGroups { + categorizeGoroutine(&allGroups[i]) + + // Skip profiler's own goroutines + if isProfilerGoroutine(&allGroups[i]) { + profilerGoroutineCount += allGroups[i].Count + continue + } + + groups = append(groups, allGroups[i]) + + switch allGroups[i].Category { + case "background": + summary.Background += allGroups[i].Count + case "per-request": + summary.PerRequest += allGroups[i].Count + } + + if allGroups[i].WaitMinutes >= 1 { + summary.LongWaiting += allGroups[i].Count + if allGroups[i].Category == "per-request" { + summary.PotentiallyStuck += allGroups[i].Count + } + } + } + + // Sort: potentially stuck first, then by wait time, then by count + sort.Slice(groups, func(i, j int) bool { + // Potentially stuck (per-request + long wait) first + iStuck := groups[i].Category == "per-request" && groups[i].WaitMinutes >= 1 + jStuck := groups[j].Category == "per-request" && groups[j].WaitMinutes >= 1 + if iStuck != jStuck { + return iStuck + } + // Then by wait time + if groups[i].WaitMinutes != groups[j].WaitMinutes { + return groups[i].WaitMinutes > groups[j].WaitMinutes + } + // Then by count + return groups[i].Count > groups[j].Count + }) + + // Calculate app goroutines (total minus profiler goroutines) + // Calculate total goroutines from profile snapshot + totalFromProfile := 0 + for _, g := range groups { + totalFromProfile += g.Count + } + + response := GoroutineProfile{ + Timestamp: time.Now().Format(time.RFC3339), + TotalGoroutines: totalFromProfile, + Groups: groups, + Summary: summary, + } + + if includeRaw { + response.RawProfile = rawProfile + } + + SendJSON(ctx, response) +} + +// categorizeGoroutine determines if a goroutine is a background worker or per-request +func categorizeGoroutine(g *GoroutineGroup) { + // Parse wait time from wait reason (e.g., "5 minutes" -> 5) + g.WaitMinutes = parseWaitMinutes(g.WaitReason) + + stackStr := strings.Join(g.Stack, " ") + + // Background goroutines - expected to run forever + backgroundPatterns := []string{ + "requestWorker", // Provider queue workers + "collectLoop", // Metrics collector + "cleanupWorker", // Various cleanup workers + "startAccumulatorMapCleanup", // Stream accumulator cleanup + "cleanupOldTraces", // Trace store cleanup + "startCleanup", // Generic cleanup + "monitorLoop", // Health monitor + "StartHeartbeat", // WebSocket heartbeat + "time.Sleep", // Ticker-based workers + "runtime.gopark", // Runtime parking (often tickers) + "sync.(*Cond).Wait", // Condition variable waits + "net/http.(*persistConn)", // HTTP connection pool + "internal/poll.runtime_pollWait", // Network polling + } + + for _, pattern := range backgroundPatterns { + if strings.Contains(stackStr, pattern) { + g.Category = "background" + return + } + } + + // Per-request goroutines - should complete when request ends + perRequestPatterns := []string{ + "PreHook", + "PostHook", + "completeAndFlushTrace", + "ProcessAndSend", + "handleProvider", + "Inject", // Observability plugin inject + "insertInitialLogEntry", // Logging + "updateLogEntry", // Logging + "updateStreamingLogEntry", + "retryOnNotFound", + "BroadcastLogUpdate", + } + + for _, pattern := range perRequestPatterns { + if strings.Contains(stackStr, pattern) { + g.Category = "per-request" + return + } + } + + g.Category = "unknown" +} + +// parseWaitMinutes extracts wait time in minutes from wait reason string +func parseWaitMinutes(waitReason string) int { + if waitReason == "" { + return 0 + } + + // Match patterns like "5 minutes", "1 minute", "30 seconds", "2 hours" + minuteRegex := regexp.MustCompile(`(\d+)\s*minute`) + if matches := minuteRegex.FindStringSubmatch(waitReason); len(matches) >= 2 { + if mins, err := strconv.Atoi(matches[1]); err == nil { + return mins + } + } + + hourRegex := regexp.MustCompile(`(\d+)\s*hour`) + if matches := hourRegex.FindStringSubmatch(waitReason); len(matches) >= 2 { + if hours, err := strconv.Atoi(matches[1]); err == nil { + return hours * 60 + } + } + + secondRegex := regexp.MustCompile(`(\d+)\s*second`) + if matches := secondRegex.FindStringSubmatch(waitReason); len(matches) >= 2 { + if secs, err := strconv.Atoi(matches[1]); err == nil { + return secs / 60 // Convert to minutes, will be 0 for < 60 seconds + } + } + + return 0 +} + +// parseGoroutineProfile parses the text output of pprof goroutine profile +// and groups goroutines by their stack trace +func parseGoroutineProfile(profile string) []GoroutineGroup { + // Regex to match goroutine header: "goroutine N [state, wait reason]:" + // Examples: + // goroutine 1 [running]: + // goroutine 42 [select, 5 minutes]: + // goroutine 100 [chan receive]: + headerRegex := regexp.MustCompile(`goroutine \d+ \[([^\]]+)\]:`) + + // Split by "goroutine " to get individual goroutine blocks + blocks := strings.Split(profile, "goroutine ") + + // Map to group goroutines by stack signature + groupMap := make(map[string]*GoroutineGroup) + + for _, block := range blocks { + block = strings.TrimSpace(block) + if block == "" { + continue + } + + // Re-add "goroutine " prefix for regex matching + fullBlock := "goroutine " + block + + // Extract state from header + matches := headerRegex.FindStringSubmatch(fullBlock) + if len(matches) < 2 { + continue + } + + stateInfo := matches[1] + state := stateInfo + waitReason := "" + + // Parse state and wait reason (e.g., "select, 5 minutes" -> state="select", waitReason="5 minutes") + if idx := strings.Index(stateInfo, ","); idx != -1 { + state = strings.TrimSpace(stateInfo[:idx]) + waitReason = strings.TrimSpace(stateInfo[idx+1:]) + } + + // Get stack trace (everything after the header line) + lines := strings.Split(block, "\n") + if len(lines) < 2 { + continue + } + + // Extract stack frames (skip the header line which is lines[0]) + var stackLines []string + var topFunc string + for i := 1; i < len(lines); i++ { + line := strings.TrimSpace(lines[i]) + if line == "" { + continue + } + stackLines = append(stackLines, line) + + // First function line (not a file:line) is the top function + if topFunc == "" && !strings.HasPrefix(line, "/") && !strings.Contains(line, ".go:") { + topFunc = line + } + } + + if len(stackLines) == 0 { + continue + } + + // Create a signature from the stack (top 10 frames for grouping) + maxFrames := 10 + if len(stackLines) < maxFrames { + maxFrames = len(stackLines) + } + signature := state + "|" + strings.Join(stackLines[:maxFrames], "|") + + // Group by signature + if existing, ok := groupMap[signature]; ok { + existing.Count++ + } else { + groupMap[signature] = &GoroutineGroup{ + Count: 1, + State: state, + WaitReason: waitReason, + TopFunc: topFunc, + Stack: stackLines, + } + } + } + + // Convert map to slice + groups := make([]GoroutineGroup, 0, len(groupMap)) + for _, group := range groupMap { + groups = append(groups, *group) + } + + return groups +} + +// profilerPatterns contains patterns to identify profiler-related code +var profilerPatterns = []string{ + "devpprof", + "pprof.WriteHeapProfile", + "pprof.Lookup", + "profile.Parse", + "MetricsCollector", + "collectLoop", + "getTopAllocations", + "parseGoroutineProfile", + "getGoroutines", + "getCPUSample", +} + +// isProfilerFunction checks if a function belongs to the profiler itself +func isProfilerFunction(funcName, fileName string) bool { + for _, pattern := range profilerPatterns { + if strings.Contains(funcName, pattern) || strings.Contains(fileName, pattern) { + return true + } + } + return false +} + +// isProfilerGoroutine checks if a goroutine belongs to the profiler +func isProfilerGoroutine(g *GoroutineGroup) bool { + stackStr := strings.Join(g.Stack, " ") + for _, pattern := range profilerPatterns { + if strings.Contains(stackStr, pattern) { + return true + } + } + return false +} + +// Cleanup stops the metrics collector +func (h *DevPprofHandler) Cleanup() { + if h.collector != nil { + h.collector.Stop() + } +} diff --git a/transports/bifrost-http/handlers/devpprof_unix.go b/transports/bifrost-http/handlers/devpprof_unix.go new file mode 100644 index 0000000000..5c9b72b2f4 --- /dev/null +++ b/transports/bifrost-http/handlers/devpprof_unix.go @@ -0,0 +1,27 @@ +//go:build !windows +// +build !windows + +package handlers + +import ( + "syscall" + "time" +) + +// getCPUSample gets the current CPU time sample using syscall +func getCPUSample() cpuSample { + var rusage syscall.Rusage + if err := syscall.Getrusage(syscall.RUSAGE_SELF, &rusage); err != nil { + return cpuSample{timestamp: time.Now()} + } + + userTime := time.Duration(rusage.Utime.Sec)*time.Second + time.Duration(rusage.Utime.Usec)*time.Microsecond + systemTime := time.Duration(rusage.Stime.Sec)*time.Second + time.Duration(rusage.Stime.Usec)*time.Microsecond + + return cpuSample{ + timestamp: time.Now(), + userTime: userTime, + systemTime: systemTime, + } +} + diff --git a/transports/bifrost-http/handlers/devpprof_windows.go b/transports/bifrost-http/handlers/devpprof_windows.go new file mode 100644 index 0000000000..1e8a805c44 --- /dev/null +++ b/transports/bifrost-http/handlers/devpprof_windows.go @@ -0,0 +1,13 @@ +//go:build windows +// +build windows + +package handlers + +import "time" + +// getCPUSample returns a zeroed CPU sample on Windows +// Windows does not support syscall.Getrusage +func getCPUSample() cpuSample { + return cpuSample{timestamp: time.Now()} +} + diff --git a/transports/bifrost-http/handlers/governance.go b/transports/bifrost-http/handlers/governance.go index 4ab20b7614..2cf2c06ef7 100644 --- a/transports/bifrost-http/handlers/governance.go +++ b/transports/bifrost-http/handlers/governance.go @@ -12,6 +12,7 @@ import ( "github.com/fasthttp/router" "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/plugins/governance" @@ -22,6 +23,7 @@ import ( // GovernanceManager is the interface for the governance manager type GovernanceManager interface { + GetGovernanceData() *governance.GovernanceData ReloadVirtualKey(ctx context.Context, id string) (*configstoreTables.TableVirtualKey, error) RemoveVirtualKey(ctx context.Context, id string) error ReloadTeam(ctx context.Context, id string) (*configstoreTables.TableTeam, error) @@ -38,6 +40,9 @@ type GovernanceHandler struct { // NewGovernanceHandler creates a new governance handler instance func NewGovernanceHandler(manager GovernanceManager, configStore configstore.ConfigStore) (*GovernanceHandler, error) { + if manager == nil { + return nil, fmt.Errorf("governance manager is required") + } if configStore == nil { return nil, fmt.Errorf("config store is required") } @@ -150,7 +155,7 @@ type UpdateCustomerRequest struct { } // RegisterRoutes registers all governance-related routes for the new hierarchical system -func (h *GovernanceHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *GovernanceHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // Virtual Key CRUD operations r.GET("/api/governance/virtual-keys", lib.ChainMiddlewares(h.getVirtualKeys, middlewares...)) r.POST("/api/governance/virtual-keys", lib.ChainMiddlewares(h.createVirtualKey, middlewares...)) @@ -171,12 +176,30 @@ func (h *GovernanceHandler) RegisterRoutes(r *router.Router, middlewares ...lib. r.GET("/api/governance/customers/{customer_id}", lib.ChainMiddlewares(h.getCustomer, middlewares...)) r.PUT("/api/governance/customers/{customer_id}", lib.ChainMiddlewares(h.updateCustomer, middlewares...)) r.DELETE("/api/governance/customers/{customer_id}", lib.ChainMiddlewares(h.deleteCustomer, middlewares...)) + + // Budget and Rate Limit GET operations + r.GET("/api/governance/budgets", lib.ChainMiddlewares(h.getBudgets, middlewares...)) + r.GET("/api/governance/rate-limits", lib.ChainMiddlewares(h.getRateLimits, middlewares...)) } // Virtual Key CRUD Operations // getVirtualKeys handles GET /api/governance/virtual-keys - Get all virtual keys with relationships func (h *GovernanceHandler) getVirtualKeys(ctx *fasthttp.RequestCtx) { + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + SendJSON(ctx, map[string]interface{}{ + "virtual_keys": data.VirtualKeys, + "count": len(data.VirtualKeys), + }) + return + } // Preload all relationships for complete information virtualKeys, err := h.configStore.GetVirtualKeys(ctx) if err != nil { @@ -285,29 +308,29 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { } } - // Get keys for this provider config if specified - var keys []configstoreTables.TableKey - if len(pc.KeyIDs) > 0 { - var err error - keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) - if err != nil { - return fmt.Errorf("failed to get keys by IDs for provider %s: %w", pc.Provider, err) - } - if len(keys) != len(pc.KeyIDs) { - return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) + // Get keys for this provider config if specified + var keys []configstoreTables.TableKey + if len(pc.KeyIDs) > 0 { + var err error + keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) + if err != nil { + return fmt.Errorf("failed to get keys by IDs for provider %s: %w", pc.Provider, err) + } + if len(keys) != len(pc.KeyIDs) { + return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) + } } - } - providerConfig := &configstoreTables.TableVirtualKeyProviderConfig{ - VirtualKeyID: vk.ID, - Provider: pc.Provider, - Weight: &pc.Weight, - AllowedModels: pc.AllowedModels, - Keys: keys, - } + providerConfig := &configstoreTables.TableVirtualKeyProviderConfig{ + VirtualKeyID: vk.ID, + Provider: pc.Provider, + Weight: &pc.Weight, + AllowedModels: pc.AllowedModels, + Keys: keys, + } - // Create budget for provider config if provided - if pc.Budget != nil { + // Create budget for provider config if provided + if pc.Budget != nil { budget := configstoreTables.TableBudget{ ID: uuid.NewString(), MaxLimit: pc.Budget.MaxLimit, @@ -397,6 +420,25 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { // getVirtualKey handles GET /api/governance/virtual-keys/{vk_id} - Get a specific virtual key func (h *GovernanceHandler) getVirtualKey(ctx *fasthttp.RequestCtx) { vkID := ctx.UserValue("vk_id").(string) + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + for _, vk := range data.VirtualKeys { + if vk.ID == vkID { + SendJSON(ctx, map[string]interface{}{ + "virtual_key": vk, + }) + return + } + } + SendError(ctx, 404, "Virtual key not found") + return + } vk, err := h.configStore.GetVirtualKey(ctx, vkID) if err != nil { if errors.Is(err, configstore.ErrNotFound) { @@ -885,6 +927,7 @@ func (h *GovernanceHandler) deleteVirtualKey(ctx *fasthttp.RequestCtx) { SendError(ctx, 404, "Virtual key not found") return } + logger.Error("failed to delete virtual key: %v", err) SendError(ctx, 500, "Failed to delete virtual key") return } @@ -898,6 +941,33 @@ func (h *GovernanceHandler) deleteVirtualKey(ctx *fasthttp.RequestCtx) { // getTeams handles GET /api/governance/teams - Get all teams func (h *GovernanceHandler) getTeams(ctx *fasthttp.RequestCtx) { customerID := string(ctx.QueryArgs().Peek("customer_id")) + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + if customerID != "" { + teams := make(map[string]*configstoreTables.TableTeam) + for _, team := range data.Teams { + if team.CustomerID != nil && *team.CustomerID == customerID { + teams[team.ID] = team + } + } + SendJSON(ctx, map[string]interface{}{ + "teams": teams, + "count": len(teams), + }) + } else { + SendJSON(ctx, map[string]interface{}{ + "teams": data.Teams, + "count": len(data.Teams), + }) + } + return + } // Preload relationships for complete information teams, err := h.configStore.GetTeams(ctx, customerID) if err != nil { @@ -980,6 +1050,24 @@ func (h *GovernanceHandler) createTeam(ctx *fasthttp.RequestCtx) { // getTeam handles GET /api/governance/teams/{team_id} - Get a specific team func (h *GovernanceHandler) getTeam(ctx *fasthttp.RequestCtx) { teamID := ctx.UserValue("team_id").(string) + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + team, ok := data.Teams[teamID] + if !ok { + SendError(ctx, 404, "Team not found") + return + } + SendJSON(ctx, map[string]interface{}{ + "team": team, + }) + return + } team, err := h.configStore.GetTeam(ctx, teamID) if err != nil { if errors.Is(err, configstore.ErrNotFound) { @@ -1112,6 +1200,20 @@ func (h *GovernanceHandler) deleteTeam(ctx *fasthttp.RequestCtx) { // getCustomers handles GET /api/governance/customers - Get all customers func (h *GovernanceHandler) getCustomers(ctx *fasthttp.RequestCtx) { + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + SendJSON(ctx, map[string]interface{}{ + "customers": data.Customers, + "count": len(data.Customers), + }) + return + } customers, err := h.configStore.GetCustomers(ctx) if err != nil { logger.Error("failed to retrieve customers: %v", err) @@ -1190,6 +1292,24 @@ func (h *GovernanceHandler) createCustomer(ctx *fasthttp.RequestCtx) { // getCustomer handles GET /api/governance/customers/{customer_id} - Get a specific customer func (h *GovernanceHandler) getCustomer(ctx *fasthttp.RequestCtx) { customerID := ctx.UserValue("customer_id").(string) + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + customer, ok := data.Customers[customerID] + if !ok { + SendError(ctx, 404, "Customer not found") + return + } + SendJSON(ctx, map[string]interface{}{ + "customer": customer, + }) + return + } customer, err := h.configStore.GetCustomer(ctx, customerID) if err != nil { if errors.Is(err, configstore.ErrNotFound) { @@ -1316,6 +1436,64 @@ func (h *GovernanceHandler) deleteCustomer(ctx *fasthttp.RequestCtx) { }) } +// Budget and Rate Limit GET operations + +// getBudgets handles GET /api/governance/budgets - Get all budgets +func (h *GovernanceHandler) getBudgets(ctx *fasthttp.RequestCtx) { + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + SendJSON(ctx, map[string]interface{}{ + "budgets": data.Budgets, + "count": len(data.Budgets), + }) + return + } + budgets, err := h.configStore.GetBudgets(ctx) + if err != nil { + logger.Error("failed to retrieve budgets: %v", err) + SendError(ctx, 500, "failed to retrieve budgets") + return + } + SendJSON(ctx, map[string]interface{}{ + "budgets": budgets, + "count": len(budgets), + }) +} + +// getRateLimits handles GET /api/governance/rate-limits - Get all rate limits +func (h *GovernanceHandler) getRateLimits(ctx *fasthttp.RequestCtx) { + // Check if "from_memory" query parameter is set to true + fromMemory := string(ctx.QueryArgs().Peek("from_memory")) == "true" + if fromMemory { + data := h.governanceManager.GetGovernanceData() + if data == nil { + SendError(ctx, 500, "Governance data is not available") + return + } + SendJSON(ctx, map[string]interface{}{ + "rate_limits": data.RateLimits, + "count": len(data.RateLimits), + }) + return + } + rateLimits, err := h.configStore.GetRateLimits(ctx) + if err != nil { + logger.Error("failed to retrieve rate limits: %v", err) + SendError(ctx, 500, "failed to retrieve rate limits") + return + } + SendJSON(ctx, map[string]interface{}{ + "rate_limits": rateLimits, + "count": len(rateLimits), + }) +} + // validateRateLimit validates the rate limit func validateRateLimit(rateLimit *configstoreTables.TableRateLimit) error { if rateLimit.TokenMaxLimit != nil && (*rateLimit.TokenMaxLimit < 0 || *rateLimit.TokenMaxLimit == 0) { diff --git a/transports/bifrost-http/handlers/health.go b/transports/bifrost-http/handlers/health.go index 9c4103354c..315315ccf4 100644 --- a/transports/bifrost-http/handlers/health.go +++ b/transports/bifrost-http/handlers/health.go @@ -6,6 +6,7 @@ import ( "time" "github.com/fasthttp/router" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) @@ -23,7 +24,7 @@ func NewHealthHandler(config *lib.Config) *HealthHandler { } // RegisterRoutes registers the health-related routes. -func (h *HealthHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *HealthHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.GET("/health", lib.ChainMiddlewares(h.getHealth, middlewares...)) } diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index bae59aab48..f920375ef7 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -412,7 +412,7 @@ const ( ) // RegisterRoutes registers all completion-related routes -func (h *CompletionHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *CompletionHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // Model endpoints r.GET("/v1/models", lib.ChainMiddlewares(h.listModels, middlewares...)) @@ -485,9 +485,9 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) { // If provider is empty, list all models from all providers if provider == "" { - resp, bifrostErr = h.client.ListAllModels(*bifrostCtx, bifrostListModelsReq) + resp, bifrostErr = h.client.ListAllModels(bifrostCtx, bifrostListModelsReq) } else { - resp, bifrostErr = h.client.ListModelsRequest(*bifrostCtx, bifrostListModelsReq) + resp, bifrostErr = h.client.ListModelsRequest(bifrostCtx, bifrostListModelsReq) } if bifrostErr != nil { @@ -506,14 +506,14 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) { } if pricingEntry != nil && modelEntry.Pricing == nil { pricing := &schemas.Pricing{ - Prompt: bifrost.Ptr(fmt.Sprintf("%f", pricingEntry.InputCostPerToken)), - Completion: bifrost.Ptr(fmt.Sprintf("%f", pricingEntry.OutputCostPerToken)), + Prompt: bifrost.Ptr(fmt.Sprintf("%.10f", pricingEntry.InputCostPerToken)), + Completion: bifrost.Ptr(fmt.Sprintf("%.10f", pricingEntry.OutputCostPerToken)), } if pricingEntry.InputCostPerImage != nil { - pricing.Image = bifrost.Ptr(fmt.Sprintf("%f", *pricingEntry.InputCostPerImage)) + pricing.Image = bifrost.Ptr(fmt.Sprintf("%.10f", *pricingEntry.InputCostPerImage)) } if pricingEntry.CacheReadInputTokenCost != nil { - pricing.InputCacheRead = bifrost.Ptr(fmt.Sprintf("%f", *pricingEntry.CacheReadInputTokenCost)) + pricing.InputCacheRead = bifrost.Ptr(fmt.Sprintf("%.10f", *pricingEntry.CacheReadInputTokenCost)) } resp.Data[i].Pricing = pricing } @@ -585,7 +585,7 @@ func (h *CompletionHandler) textCompletion(ctx *fasthttp.RequestCtx) { // This is a known issue of valyala/fasthttp. And will be fixed here once it is fixed upstream. defer cancel() // Ensure cleanup on function exit - resp, bifrostErr := h.client.TextCompletionRequest(*bifrostCtx, bifrostTextReq) + resp, bifrostErr := h.client.TextCompletionRequest(bifrostCtx, bifrostTextReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -671,15 +671,13 @@ func (h *CompletionHandler) chatCompletion(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context") return } - if req.Stream != nil && *req.Stream { h.handleStreamingChatCompletion(ctx, bifrostChatReq, bifrostCtx, cancel) return } - defer cancel() // Ensure cleanup on function exit - - resp, bifrostErr := h.client.ChatCompletionRequest(*bifrostCtx, bifrostChatReq) + // Complete the request + resp, bifrostErr := h.client.ChatCompletionRequest(bifrostCtx, bifrostChatReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -761,7 +759,7 @@ func (h *CompletionHandler) responses(ctx *fasthttp.RequestCtx) { defer cancel() // Ensure cleanup on function exit - resp, bifrostErr := h.client.ResponsesRequest(*bifrostCtx, bifrostResponsesReq) + resp, bifrostErr := h.client.ResponsesRequest(bifrostCtx, bifrostResponsesReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -827,7 +825,7 @@ func (h *CompletionHandler) embeddings(ctx *fasthttp.RequestCtx) { return } - resp, bifrostErr := h.client.EmbeddingRequest(*bifrostCtx, bifrostEmbeddingReq) + resp, bifrostErr := h.client.EmbeddingRequest(bifrostCtx, bifrostEmbeddingReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -909,7 +907,7 @@ func (h *CompletionHandler) speech(ctx *fasthttp.RequestCtx) { defer cancel() // Ensure cleanup on function exit - resp, bifrostErr := h.client.SpeechRequest(*bifrostCtx, bifrostSpeechReq) + resp, bifrostErr := h.client.SpeechRequest(bifrostCtx, bifrostSpeechReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -1044,7 +1042,7 @@ func (h *CompletionHandler) transcription(ctx *fasthttp.RequestCtx) { defer cancel() // Ensure cleanup on function exit // Make transcription request - resp, bifrostErr := h.client.TranscriptionRequest(*bifrostCtx, bifrostTranscriptionReq) + resp, bifrostErr := h.client.TranscriptionRequest(bifrostCtx, bifrostTranscriptionReq) // Handle response if bifrostErr != nil { @@ -1121,7 +1119,7 @@ func (h *CompletionHandler) countTokens(ctx *fasthttp.RequestCtx) { defer cancel() // Ensure cleanup on function exit // Make count tokens request - response, bifrostErr := h.client.CountTokensRequest(*bifrostCtx, bifrostReq) + response, bifrostErr := h.client.CountTokensRequest(bifrostCtx, bifrostReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -1132,65 +1130,60 @@ func (h *CompletionHandler) countTokens(ctx *fasthttp.RequestCtx) { } // handleStreamingTextCompletion handles streaming text completion requests using Server-Sent Events (SSE) -func (h *CompletionHandler) handleStreamingTextCompletion(ctx *fasthttp.RequestCtx, req *schemas.BifrostTextCompletionRequest, bifrostCtx *context.Context, cancel context.CancelFunc) { +func (h *CompletionHandler) handleStreamingTextCompletion(ctx *fasthttp.RequestCtx, req *schemas.BifrostTextCompletionRequest, bifrostCtx *schemas.BifrostContext, cancel context.CancelFunc) { // Use the cancellable context from ConvertToBifrostContext // See router.go for detailed explanation of why we need a cancellable context - streamCtx := *bifrostCtx getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return h.client.TextCompletionStreamRequest(streamCtx, req) + return h.client.TextCompletionStreamRequest(bifrostCtx, req) } h.handleStreamingResponse(ctx, getStream, cancel) } // handleStreamingChatCompletion handles streaming chat completion requests using Server-Sent Events (SSE) -func (h *CompletionHandler) handleStreamingChatCompletion(ctx *fasthttp.RequestCtx, req *schemas.BifrostChatRequest, bifrostCtx *context.Context, cancel context.CancelFunc) { +func (h *CompletionHandler) handleStreamingChatCompletion(ctx *fasthttp.RequestCtx, req *schemas.BifrostChatRequest, bifrostCtx *schemas.BifrostContext, cancel context.CancelFunc) { // Use the cancellable context from ConvertToBifrostContext // See router.go for detailed explanation of why we need a cancellable context - streamCtx := *bifrostCtx getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return h.client.ChatCompletionStreamRequest(streamCtx, req) + return h.client.ChatCompletionStreamRequest(bifrostCtx, req) } h.handleStreamingResponse(ctx, getStream, cancel) } // handleStreamingResponses handles streaming responses requests using Server-Sent Events (SSE) -func (h *CompletionHandler) handleStreamingResponses(ctx *fasthttp.RequestCtx, req *schemas.BifrostResponsesRequest, bifrostCtx *context.Context, cancel context.CancelFunc) { +func (h *CompletionHandler) handleStreamingResponses(ctx *fasthttp.RequestCtx, req *schemas.BifrostResponsesRequest, bifrostCtx *schemas.BifrostContext, cancel context.CancelFunc) { // Use the cancellable context from ConvertToBifrostContext // See router.go for detailed explanation of why we need a cancellable context - streamCtx := *bifrostCtx getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return h.client.ResponsesStreamRequest(streamCtx, req) + return h.client.ResponsesStreamRequest(bifrostCtx, req) } h.handleStreamingResponse(ctx, getStream, cancel) } // handleStreamingSpeech handles streaming speech requests using Server-Sent Events (SSE) -func (h *CompletionHandler) handleStreamingSpeech(ctx *fasthttp.RequestCtx, req *schemas.BifrostSpeechRequest, bifrostCtx *context.Context, cancel context.CancelFunc) { +func (h *CompletionHandler) handleStreamingSpeech(ctx *fasthttp.RequestCtx, req *schemas.BifrostSpeechRequest, bifrostCtx *schemas.BifrostContext, cancel context.CancelFunc) { // Use the cancellable context from ConvertToBifrostContext // See router.go for detailed explanation of why we need a cancellable context - streamCtx := *bifrostCtx getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return h.client.SpeechStreamRequest(streamCtx, req) + return h.client.SpeechStreamRequest(bifrostCtx, req) } h.handleStreamingResponse(ctx, getStream, cancel) } // handleStreamingTranscriptionRequest handles streaming transcription requests using Server-Sent Events (SSE) -func (h *CompletionHandler) handleStreamingTranscriptionRequest(ctx *fasthttp.RequestCtx, req *schemas.BifrostTranscriptionRequest, bifrostCtx *context.Context, cancel context.CancelFunc) { +func (h *CompletionHandler) handleStreamingTranscriptionRequest(ctx *fasthttp.RequestCtx, req *schemas.BifrostTranscriptionRequest, bifrostCtx *schemas.BifrostContext, cancel context.CancelFunc) { // Use the cancellable context from ConvertToBifrostContext // See router.go for detailed explanation of why we need a cancellable context - streamCtx := *bifrostCtx getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return h.client.TranscriptionStreamRequest(streamCtx, req) + return h.client.TranscriptionStreamRequest(bifrostCtx, req) } h.handleStreamingResponse(ctx, getStream, cancel) @@ -1205,7 +1198,6 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, ge ctx.SetContentType("text/event-stream") ctx.Response.Header.Set("Cache-Control", "no-cache") ctx.Response.Header.Set("Connection", "keep-alive") - ctx.Response.Header.Set("Access-Control-Allow-Origin", "*") // Get the streaming channel stream, bifrostErr := getStream() @@ -1216,11 +1208,25 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, ge return } + // Signal to tracing middleware that trace completion should be deferred + // The streaming callback will complete the trace after the stream ends + ctx.SetUserValue(schemas.BifrostContextKeyDeferTraceCompletion, true) + + // Get the trace completer function for use in the streaming callback + traceCompleter, _ := ctx.UserValue(schemas.BifrostContextKeyTraceCompleter).(func()) + var includeEventType bool // Use streaming response writer ctx.Response.SetBodyStreamWriter(func(w *bufio.Writer) { - defer w.Flush() + defer func() { + w.Flush() + // Complete the trace after streaming finishes + // This ensures all spans (including llm.call) are properly ended before the trace is sent to OTEL + if traceCompleter != nil { + traceCompleter() + } + }() // Process streaming responses for chunk := range stream { @@ -1287,6 +1293,7 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, ge } // Note: OpenAI responses API doesn't use [DONE] marker, it ends when the stream closes // Stream completed normally, Bifrost handles cleanup internally + cancel() }) } @@ -1419,7 +1426,7 @@ func (h *CompletionHandler) batchCreate(ctx *fasthttp.RequestCtx) { return } - resp, bifrostErr := h.client.BatchCreateRequest(*bifrostCtx, bifrostBatchReq) + resp, bifrostErr := h.client.BatchCreateRequest(bifrostCtx, bifrostBatchReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -1472,7 +1479,7 @@ func (h *CompletionHandler) batchList(ctx *fasthttp.RequestCtx) { return } - resp, bifrostErr := h.client.BatchListRequest(*bifrostCtx, bifrostBatchReq) + resp, bifrostErr := h.client.BatchListRequest(bifrostCtx, bifrostBatchReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -1511,7 +1518,7 @@ func (h *CompletionHandler) batchRetrieve(ctx *fasthttp.RequestCtx) { return } - resp, bifrostErr := h.client.BatchRetrieveRequest(*bifrostCtx, bifrostBatchReq) + resp, bifrostErr := h.client.BatchRetrieveRequest(bifrostCtx, bifrostBatchReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -1550,7 +1557,7 @@ func (h *CompletionHandler) batchCancel(ctx *fasthttp.RequestCtx) { return } - resp, bifrostErr := h.client.BatchCancelRequest(*bifrostCtx, bifrostBatchReq) + resp, bifrostErr := h.client.BatchCancelRequest(bifrostCtx, bifrostBatchReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -1589,7 +1596,7 @@ func (h *CompletionHandler) batchResults(ctx *fasthttp.RequestCtx) { return } - resp, bifrostErr := h.client.BatchResultsRequest(*bifrostCtx, bifrostBatchReq) + resp, bifrostErr := h.client.BatchResultsRequest(bifrostCtx, bifrostBatchReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -1671,7 +1678,7 @@ func (h *CompletionHandler) fileUpload(ctx *fasthttp.RequestCtx) { return } - resp, bifrostErr := h.client.FileUploadRequest(*bifrostCtx, bifrostFileReq) + resp, bifrostErr := h.client.FileUploadRequest(bifrostCtx, bifrostFileReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -1730,7 +1737,7 @@ func (h *CompletionHandler) fileList(ctx *fasthttp.RequestCtx) { return } - resp, bifrostErr := h.client.FileListRequest(*bifrostCtx, bifrostFileReq) + resp, bifrostErr := h.client.FileListRequest(bifrostCtx, bifrostFileReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -1769,7 +1776,7 @@ func (h *CompletionHandler) fileRetrieve(ctx *fasthttp.RequestCtx) { return } - resp, bifrostErr := h.client.FileRetrieveRequest(*bifrostCtx, bifrostFileReq) + resp, bifrostErr := h.client.FileRetrieveRequest(bifrostCtx, bifrostFileReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -1808,7 +1815,7 @@ func (h *CompletionHandler) fileDelete(ctx *fasthttp.RequestCtx) { return } - resp, bifrostErr := h.client.FileDeleteRequest(*bifrostCtx, bifrostFileReq) + resp, bifrostErr := h.client.FileDeleteRequest(bifrostCtx, bifrostFileReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return @@ -1847,7 +1854,7 @@ func (h *CompletionHandler) fileContent(ctx *fasthttp.RequestCtx) { return } - resp, bifrostErr := h.client.FileContentRequest(*bifrostCtx, bifrostFileReq) + resp, bifrostErr := h.client.FileContentRequest(bifrostCtx, bifrostFileReq) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return diff --git a/transports/bifrost-http/handlers/integrations.go b/transports/bifrost-http/handlers/integrations.go index 60625475ad..de5b9fbcf4 100644 --- a/transports/bifrost-http/handlers/integrations.go +++ b/transports/bifrost-http/handlers/integrations.go @@ -5,6 +5,7 @@ package handlers import ( "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/transports/bifrost-http/integrations" "github.com/maximhq/bifrost/transports/bifrost-http/lib" ) @@ -33,7 +34,7 @@ func NewIntegrationHandler(client *bifrost.Bifrost, handlerStore lib.HandlerStor } // RegisterRoutes registers all integration routes for AI provider compatibility endpoints -func (h *IntegrationHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *IntegrationHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // Register routes for each integration extension for _, extension := range h.extensions { extension.RegisterRoutes(r, middlewares...) diff --git a/transports/bifrost-http/handlers/logging.go b/transports/bifrost-http/handlers/logging.go index cde2df1712..7fe7f15744 100644 --- a/transports/bifrost-http/handlers/logging.go +++ b/transports/bifrost-http/handlers/logging.go @@ -39,7 +39,7 @@ func NewLoggingHandler(logManager logging.LogManager, redactedKeysManager Redact } // RegisterRoutes registers all logging-related routes -func (h *LoggingHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *LoggingHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // Log retrieval with filtering, search, and pagination r.GET("/api/logs", lib.ChainMiddlewares(h.getLogs, middlewares...)) r.GET("/api/logs/stats", lib.ChainMiddlewares(h.getLogsStats, middlewares...)) diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go index 320b058e4b..21ed477755 100644 --- a/transports/bifrost-http/handlers/mcp.go +++ b/transports/bifrost-http/handlers/mcp.go @@ -8,6 +8,7 @@ import ( "fmt" "slices" "sort" + "strings" "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" @@ -39,7 +40,7 @@ func NewMCPHandler(mcpManager MCPManager, client *bifrost.Bifrost, store *lib.Co } // RegisterRoutes registers all MCP-related routes -func (h *MCPHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *MCPHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // MCP tool execution endpoint r.POST("/v1/mcp/tool/execute", lib.ChainMiddlewares(h.executeTool, middlewares...)) r.GET("/api/mcp/clients", lib.ChainMiddlewares(h.getMCPClients, middlewares...)) @@ -51,6 +52,21 @@ func (h *MCPHandler) RegisterRoutes(r *router.Router, middlewares ...lib.Bifrost // executeTool handles POST /v1/mcp/tool/execute - Execute MCP tool func (h *MCPHandler) executeTool(ctx *fasthttp.RequestCtx) { + // Check format query parameter + format := strings.ToLower(string(ctx.QueryArgs().Peek("format"))) + switch format { + case "chat", "": + h.executeChatMCPTool(ctx) + case "responses": + h.executeResponsesMCPTool(ctx) + default: + SendError(ctx, fasthttp.StatusBadRequest, "Invalid format value, must be 'chat' or 'responses'") + return + } +} + +// executeChatMCPTool handles POST /v1/mcp/tool/execute?format=chat - Execute MCP tool +func (h *MCPHandler) executeChatMCPTool(ctx *fasthttp.RequestCtx) { var req schemas.ChatAssistantMessageToolCall if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err)) @@ -72,14 +88,47 @@ func (h *MCPHandler) executeTool(ctx *fasthttp.RequestCtx) { } // Execute MCP tool - resp, bifrostErr := h.client.ExecuteMCPTool(*bifrostCtx, req) + toolMessage, bifrostErr := h.client.ExecuteChatMCPTool(bifrostCtx, req) if bifrostErr != nil { SendBifrostError(ctx, bifrostErr) return } // Send successful response - SendJSON(ctx, resp) + SendJSON(ctx, toolMessage) +} + +// executeResponsesMCPTool handles POST /v1/mcp/tool/execute?format=responses - Execute MCP tool +func (h *MCPHandler) executeResponsesMCPTool(ctx *fasthttp.RequestCtx) { + var req schemas.ResponsesToolMessage + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err)) + return + } + + // Validate required fields + if req.Name == nil || *req.Name == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Tool function name is required") + return + } + + // Convert context + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.store.GetHeaderFilterConfig()) + defer cancel() // Ensure cleanup on function exit + if bifrostCtx == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context") + return + } + + // Execute MCP tool + toolMessage, bifrostErr := h.client.ExecuteResponsesMCPTool(bifrostCtx, &req) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + + // Send successful response + SendJSON(ctx, toolMessage) } // getMCPClients handles GET /api/mcp/clients - Get all MCP clients @@ -119,7 +168,7 @@ func (h *MCPHandler) getMCPClients(ctx *fasthttp.RequestCtx) { clients = append(clients, schemas.MCPClient{ Config: h.store.RedactMCPClientConfig(connectedClient.Config), Tools: sortedTools, - State: connectedClient.State, + State: connectedClient.State, // Use the state from MCPClientState }) } else { // Client is in config but not connected, mark as errored @@ -189,13 +238,28 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid tools_to_execute: %v", err)) return } + + // Auto-clear tools_to_auto_execute if tools_to_execute is empty + // If no tools are allowed to execute, no tools can be auto-executed + if len(req.ToolsToExecute) == 0 { + req.ToolsToAutoExecute = []string{} + } + + if err := validateToolsToAutoExecute(req.ToolsToAutoExecute, req.ToolsToExecute); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid tools_to_auto_execute: %v", err)) + return + } + if err := validateMCPClientName(req.Name); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid client name: %v", err)) + return + } if err := h.mcpManager.AddMCPClient(ctx, req); err != nil { - SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to add MCP client: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to connect MCP client: %v", err)) return } SendJSON(ctx, map[string]any{ "status": "success", - "message": "MCP client added successfully", + "message": "MCP client connected successfully", }) } @@ -219,6 +283,24 @@ func (h *MCPHandler) editMCPClient(ctx *fasthttp.RequestCtx) { return } + // Auto-clear tools_to_auto_execute if tools_to_execute is empty + // If no tools are allowed to execute, no tools can be auto-executed + if len(req.ToolsToExecute) == 0 { + req.ToolsToAutoExecute = []string{} + } + + // Validate tools_to_auto_execute + if err := validateToolsToAutoExecute(req.ToolsToAutoExecute, req.ToolsToExecute); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid tools_to_auto_execute: %v", err)) + return + } + + // Validate client name + if err := validateMCPClientName(req.Name); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid client name: %v", err)) + return + } + if err := h.mcpManager.EditMCPClient(ctx, id, req); err != nil { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to edit MCP client: %v", err)) return @@ -280,3 +362,69 @@ func validateToolsToExecute(toolsToExecute []string) error { return nil } + +func validateToolsToAutoExecute(toolsToAutoExecute []string, toolsToExecute []string) error { + if len(toolsToAutoExecute) > 0 { + // Check if wildcard "*" is combined with other tool names + hasWildcard := slices.Contains(toolsToAutoExecute, "*") + if hasWildcard && len(toolsToAutoExecute) > 1 { + return fmt.Errorf("wildcard '*' cannot be combined with other tool names") + } + + // Check for duplicate entries + seen := make(map[string]bool) + for _, tool := range toolsToAutoExecute { + if seen[tool] { + return fmt.Errorf("duplicate tool name '%s'", tool) + } + seen[tool] = true + } + + // Check that all tools in ToolsToAutoExecute are also in ToolsToExecute + // Create a set of allowed tools from ToolsToExecute + allowedTools := make(map[string]bool) + hasWildcardInExecute := slices.Contains(toolsToExecute, "*") + if hasWildcardInExecute { + // If "*" is in ToolsToExecute, all tools are allowed + return nil + } + for _, tool := range toolsToExecute { + allowedTools[tool] = true + } + + // Validate each tool in ToolsToAutoExecute + for _, tool := range toolsToAutoExecute { + if tool == "*" { + // Wildcard is allowed if "*" is in ToolsToExecute + if !hasWildcardInExecute { + return fmt.Errorf("tool '%s' in tools_to_auto_execute is not in tools_to_execute", tool) + } + } else if !allowedTools[tool] { + return fmt.Errorf("tool '%s' in tools_to_auto_execute is not in tools_to_execute", tool) + } + } + } + + return nil +} + +func validateMCPClientName(name string) error { + if strings.TrimSpace(name) == "" { + return fmt.Errorf("client name is required") + } + for _, r := range name { + if r > 127 { // non-ASCII + return fmt.Errorf("name must contain only ASCII characters") + } + } + if strings.Contains(name, "-") { + return fmt.Errorf("client name cannot contain hyphens") + } + if strings.Contains(name, " ") { + return fmt.Errorf("client name cannot contain spaces") + } + if len(name) > 0 && name[0] >= '0' && name[0] <= '9' { + return fmt.Errorf("client name cannot start with a number") + } + return nil +} diff --git a/transports/bifrost-http/handlers/mcpserver.go b/transports/bifrost-http/handlers/mcpserver.go new file mode 100644 index 0000000000..d131830e79 --- /dev/null +++ b/transports/bifrost-http/handlers/mcpserver.go @@ -0,0 +1,392 @@ +// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. +// This file contains MCP (Model Context Protocol) server implementation for HTTP streaming. +package handlers + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "slices" + "strings" + "sync" + + "github.com/fasthttp/router" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/maximhq/bifrost/plugins/governance" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// MCPToolExecutor interface defines the method needed for executing MCP tools +type MCPToolManager interface { + GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool + ExecuteChatMCPTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError) + ExecuteResponsesMCPTool(ctx context.Context, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, *schemas.BifrostError) +} + +// MCPServerHandler manages HTTP requests for MCP server operations +// It implements the MCP protocol over HTTP streaming (SSE) for MCP clients +type MCPServerHandler struct { + toolManager MCPToolManager + globalMCPServer *server.MCPServer + vkMCPServers map[string]*server.MCPServer // Map of vk value -> mcp server + config *lib.Config + mu sync.RWMutex +} + +// NewMCPServerHandler creates a new MCP server handler instance +func NewMCPServerHandler(ctx context.Context, config *lib.Config, toolManager MCPToolManager) (*MCPServerHandler, error) { + if config == nil { + return nil, fmt.Errorf("config is required") + } + if toolManager == nil { + return nil, fmt.Errorf("tool manager is required") + } + + // Create MCP server instance using mcp-go + globalMCPServer := server.NewMCPServer( + "global", + version, + server.WithToolCapabilities(true), + ) + + handler := &MCPServerHandler{ + toolManager: toolManager, + globalMCPServer: globalMCPServer, + config: config, + vkMCPServers: make(map[string]*server.MCPServer), + } + + if err := handler.SyncAllMCPServers(ctx); err != nil { + return nil, fmt.Errorf("failed to sync all MCP servers: %w", err) + } + + return handler, nil +} + +// RegisterRoutes registers the MCP server route +func (h *MCPServerHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { + // MCP server endpoint - supports both POST (JSON-RPC) and GET (SSE) + r.POST("/mcp", lib.ChainMiddlewares(h.handleMCPServer, middlewares...)) + r.GET("/mcp", lib.ChainMiddlewares(h.handleMCPServerSSE, middlewares...)) +} + +// handleMCPServer handles POST requests for MCP JSON-RPC 2.0 messages +func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) { + mcpServer, err := h.getMCPServerForRequest(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusUnauthorized, err.Error()) + return + } + + // Convert context + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderFilterConfig()) + defer cancel() + + // Use mcp-go server to handle the request + // HandleMessage processes JSON-RPC messages and returns appropriate responses + response := mcpServer.HandleMessage(bifrostCtx, ctx.PostBody()) + + // Check if response is nil (notification - no response needed) + if response == nil { + ctx.SetStatusCode(fasthttp.StatusOK) + return + } + + // Marshal and send response + responseJSON, err := json.Marshal(response) + if err != nil { + logger.Warn(fmt.Sprintf("Failed to marshal MCP response: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to encode response: %v", err)) + return + } + + ctx.SetContentType("application/json") + ctx.SetBody(responseJSON) +} + +// handleMCPServerSSE handles GET requests for MCP Server-Sent Events streaming +func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) { + _, err := h.getMCPServerForRequest(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusUnauthorized, err.Error()) + return + } + + // Set SSE headers + ctx.SetContentType("text/event-stream") + ctx.Response.Header.Set("Cache-Control", "no-cache") + ctx.Response.Header.Set("Connection", "keep-alive") + + // Convert context + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderFilterConfig()) + + // Use streaming response writer + ctx.Response.SetBodyStreamWriter(func(w *bufio.Writer) { + defer func() { + cancel() + _ = w.Flush() + }() + + // Send initial connection message + initMessage := map[string]interface{}{ + "jsonrpc": "2.0", + "method": "connection/opened", + } + if initJSON, err := json.Marshal(initMessage); err == nil { + fmt.Fprintf(w, "data: %s\n\n", initJSON) + w.Flush() + } + + // Wait for context cancellation (client disconnect or server-side cancel) + <-(*bifrostCtx).Done() + }) +} + +// Sync methods for MCP servers + +func (h *MCPServerHandler) SyncAllMCPServers(ctx context.Context) error { + h.mu.Lock() + defer h.mu.Unlock() + availableTools := h.toolManager.GetAvailableMCPTools(ctx) + h.syncServer(h.globalMCPServer, availableTools) + logger.Debug("Synced global MCP server with %d tools", len(availableTools)) + + // initialize vkMCPServers map + if h.config.ConfigStore != nil { + virtualKeys, err := h.config.ConfigStore.GetVirtualKeys(ctx) + if err != nil { + return fmt.Errorf("failed to get virtual keys: %w", err) + } + h.vkMCPServers = make(map[string]*server.MCPServer) + for i := range virtualKeys { + vk := &virtualKeys[i] + h.vkMCPServers[vk.Value] = server.NewMCPServer( + vk.Name, + version, + server.WithToolCapabilities(true), + ) + availableTools := h.fetchToolsForVK(vk) + h.syncServer(h.vkMCPServers[vk.Value], availableTools) + logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(availableTools)) + } + } + return nil +} + +func (h *MCPServerHandler) SyncVKMCPServer(vk *tables.TableVirtualKey) { + h.mu.Lock() + defer h.mu.Unlock() + vkServer, ok := h.vkMCPServers[vk.Value] + if !ok { + // Add new server + vkServer = server.NewMCPServer( + vk.Name, + version, + server.WithToolCapabilities(true), + ) + h.vkMCPServers[vk.Value] = vkServer + } + availableTools := h.fetchToolsForVK(vk) + h.syncServer(vkServer, availableTools) + h.vkMCPServers[vk.Value] = vkServer + logger.Debug("Synced MCP server for virtual key '%s' with %d tools", vk.Name, len(availableTools)) +} + +func (h *MCPServerHandler) DeleteVKMCPServer(vkValue string) { + h.mu.Lock() + defer h.mu.Unlock() + delete(h.vkMCPServers, vkValue) +} + +func (h *MCPServerHandler) syncServer(server *server.MCPServer, availableTools []schemas.ChatTool) { + // Clear existing tools + toolMap := server.ListTools() + for toolName, _ := range toolMap { + server.DeleteTools(toolName) + } + + // Register tools from all connected clients + for _, tool := range availableTools { + // Only process function tools (skip custom tools) + if tool.Function == nil { + continue + } + + // Capture tool name for closure + toolName := tool.Function.Name + + handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Convert to Bifrost tool call format + toolCallType := "function" + toolCallID := fmt.Sprintf("mcp-%s", toolName) + argsJSON, jsonErr := json.Marshal(request.GetArguments()) + if jsonErr != nil { + return mcp.NewToolResultError(fmt.Sprintf("Failed to marshal tool arguments: %v", jsonErr)), nil + } + toolCall := schemas.ChatAssistantMessageToolCall{ + ID: &toolCallID, + Type: &toolCallType, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &toolName, + Arguments: string(argsJSON), + }, + } + + // Execute the tool via tool executor + toolMessage, err := h.toolManager.ExecuteChatMCPTool(ctx, toolCall) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Tool execution failed: %v", bifrost.GetErrorMessage(err))), nil + } + + // Extract content from tool message + var resultText string + if toolMessage != nil && toolMessage.Content != nil { + // Handle ContentStr (string content) + if toolMessage.Content.ContentStr != nil { + resultText = *toolMessage.Content.ContentStr + } else if toolMessage.Content.ContentBlocks != nil { + // Handle ContentBlocks (structured content) + for _, block := range toolMessage.Content.ContentBlocks { + if block.Type == schemas.ChatContentBlockTypeText && block.Text != nil { + resultText += *block.Text + } + } + } + } + + // Return result using mcp-go helper + return mcp.NewToolResultText(resultText), nil + } + + // Convert description from *string to string + description := "" + if tool.Function.Description != nil { + description = *tool.Function.Description + } + + // Convert Parameters to mcp.ToolInputSchema + var inputSchema mcp.ToolInputSchema + if tool.Function.Parameters != nil { + inputSchema.Type = tool.Function.Parameters.Type + if tool.Function.Parameters.Properties != nil { + // Convert *map[string]interface{} to map[string]any + props := make(map[string]any) + for k, v := range *tool.Function.Parameters.Properties { + props[k] = v + } + inputSchema.Properties = props + } + if tool.Function.Parameters.Required != nil { + inputSchema.Required = tool.Function.Parameters.Required + } + } else { + // Default to empty object schema if no parameters + inputSchema.Type = "object" + inputSchema.Properties = make(map[string]any) + } + + // Register tool with the server + server.AddTool(mcp.Tool{ + Name: toolName, + Description: description, + InputSchema: inputSchema, + }, handler) + } +} + +// fetchToolsForVK fetches the tools for a given virtual key value. +// vkValue is the virtual key value for the server, if empty, all tools will be fetched for global mcp server. +// Returns a map of tool name to tool. +func (h *MCPServerHandler) fetchToolsForVK(vk *tables.TableVirtualKey) []schemas.ChatTool { + ctx := context.Background() + + if len(vk.MCPConfigs) > 0 { + executeOnlyTools := make([]string, 0) + for _, vkMcpConfig := range vk.MCPConfigs { + if len(vkMcpConfig.ToolsToExecute) == 0 { + // No tools specified in virtual key config - skip this client entirely + continue + } + + // Handle wildcard in virtual key config - allow all tools from this client + if slices.Contains(vkMcpConfig.ToolsToExecute, "*") { + // Virtual key uses wildcard - use client-specific wildcard + executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s/*", vkMcpConfig.MCPClient.Name)) + continue + } + + for _, tool := range vkMcpConfig.ToolsToExecute { + if tool != "" { + // Add the tool - client config filtering will be handled by mcp.go + executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s/%s", vkMcpConfig.MCPClient.Name, tool)) + } + } + } + + // Set even when empty to exclude tools when no tools are present in the virtual key config + ctx = context.WithValue(ctx, schemas.BifrostContextKey("mcp-include-tools"), executeOnlyTools) + } + + return h.toolManager.GetAvailableMCPTools(ctx) +} + +// Utility methods + +func (h *MCPServerHandler) getMCPServerForRequest(ctx *fasthttp.RequestCtx) (*server.MCPServer, error) { + h.mu.RLock() + defer h.mu.RUnlock() + + h.config.Mu.RLock() + enforceVK := h.config.ClientConfig.EnforceGovernanceHeader + h.config.Mu.RUnlock() + + vk := getVKFromRequest(ctx) + + // Return global MCP server if not enforcing virtual key header and no virtual key is provided + if !enforceVK && vk == "" { + return h.globalMCPServer, nil + } + + // Check if virtual key is provided + if vk == "" { + return nil, fmt.Errorf("virtual key header is required to access MCP server.") + } + + // Check if vk exists in the map + vkServer, ok := h.vkMCPServers[vk] + if !ok { + return nil, fmt.Errorf("virtual key not found.") + } + + return vkServer, nil +} + +func getVKFromRequest(ctx *fasthttp.RequestCtx) string { + if value := strings.TrimSpace(string(ctx.Request.Header.Peek(string(schemas.BifrostContextKeyVirtualKey)))); value != "" { + return value + } + + authHeader := strings.TrimSpace(string(ctx.Request.Header.Peek("Authorization"))) + if authHeader != "" { + if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { + token := strings.TrimSpace(authHeader[7:]) + if token != "" && strings.HasPrefix(strings.ToLower(token), governance.VirtualKeyPrefix) { + return token + } + } + } + + if apiKey := strings.TrimSpace(string(ctx.Request.Header.Peek("x-api-key"))); apiKey != "" { + if strings.HasPrefix(strings.ToLower(apiKey), governance.VirtualKeyPrefix) { + return apiKey + } + } + + return "" +} diff --git a/transports/bifrost-http/handlers/middlewares.go b/transports/bifrost-http/handlers/middlewares.go index adb09ca8e0..6e4f5062c2 100644 --- a/transports/bifrost-http/handlers/middlewares.go +++ b/transports/bifrost-http/handlers/middlewares.go @@ -3,7 +3,6 @@ package handlers import ( "context" "encoding/base64" - "encoding/json" "fmt" "slices" "strings" @@ -13,13 +12,13 @@ import ( "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" "github.com/maximhq/bifrost/framework/encrypt" - "github.com/maximhq/bifrost/plugins/governance" + "github.com/maximhq/bifrost/framework/tracing" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) // CorsMiddleware handles CORS headers for localhost and configured allowed origins -func CorsMiddleware(config *lib.Config) lib.BifrostHTTPMiddleware { +func CorsMiddleware(config *lib.Config) schemas.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { logger.Debug("CorsMiddleware: %s", string(ctx.Path())) @@ -47,101 +46,108 @@ func CorsMiddleware(config *lib.Config) lib.BifrostHTTPMiddleware { } } -// TransportInterceptorMiddleware collects all plugin interceptors and calls them one by one -func TransportInterceptorMiddleware(config *lib.Config) lib.BifrostHTTPMiddleware { +// TransportInterceptorMiddleware runs all plugin HTTP transport interceptors. +// It converts the fasthttp request to a serializable HTTPRequest, runs all plugin interceptors, +// and applies any modifications back to the fasthttp context. +func TransportInterceptorMiddleware(config *lib.Config) schemas.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { - // Get plugins from config - lock-free read plugins := config.GetLoadedPlugins() if len(plugins) == 0 { next(ctx) return } - // If governance plugin is not loaded, skip interception - hasGovernance := false - for _, p := range plugins { - if p.GetName() == governance.PluginName { - hasGovernance = true - break + // Get or create BifrostContext from fasthttp context + bifrostCtx := getBifrostContextFromFastHTTP(ctx) + // Acquire pooled request + req := schemas.AcquireHTTPRequest() + defer schemas.ReleaseHTTPRequest(req) + fasthttpToHTTPRequest(ctx, req) + // Run plugin interceptors + for _, plugin := range plugins { + resp, err := plugin.HTTPTransportIntercept(bifrostCtx, req) + if err != nil { + // Short-circuit with error + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBodyString(err.Error()) + return + } + if resp != nil { + // Short-circuit with response + applyHTTPResponseToCtx(ctx, resp) + return } + // If we got here, the plugin may have modified req in-place } - if !hasGovernance { - next(ctx) - return + // Apply modifications back to fasthttp context + applyHTTPRequestToCtx(ctx, req) + // Adding user values + for key, value := range bifrostCtx.GetUserValues() { + ctx.SetUserValue(key, value) } + next(ctx) + } + } +} - // Parse headers - headers := make(map[string]string) - originalHeaderNames := make([]string, 0, 16) - ctx.Request.Header.All()(func(key, value []byte) bool { - name := string(key) - headers[name] = string(value) - originalHeaderNames = append(originalHeaderNames, name) +// getBifrostContextFromFastHTTP gets or creates a BifrostContext from fasthttp context. +func getBifrostContextFromFastHTTP(ctx *fasthttp.RequestCtx) *schemas.BifrostContext { + return schemas.NewBifrostContext(ctx, schemas.NoDeadline) +} - return true - }) - requestBody := make(map[string]any) - // Only read body if Content-Type is JSON to avoid consuming multipart/form-data streams - contentType := string(ctx.Request.Header.Peek("Content-Type")) - isJSONRequest := strings.HasPrefix(contentType, "application/json") - - // Only run interceptors for JSON requests - if isJSONRequest { - bodyBytes := ctx.Request.Body() - if len(bodyBytes) > 0 { - if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { - // If body is not valid JSON, log warning and continue without interception - logger.Warn(fmt.Sprintf("[transportInterceptor]: Failed to unmarshal request body: %v, skipping interceptor", err)) - next(ctx) - return - } - } - for _, plugin := range plugins { - // Call TransportInterceptor on all plugins - pluginCtx, cancel := schemas.NewBifrostContextWithTimeout(ctx, 10*time.Second) - modifiedHeaders, modifiedBody, err := plugin.TransportInterceptor(pluginCtx, string(ctx.Request.URI().RequestURI()), headers, requestBody) - cancel() - if err != nil { - logger.Warn(fmt.Sprintf("TransportInterceptor: Plugin '%s' returned error: %v", plugin.GetName(), err)) - // Continue with unmodified headers/body - continue - } - // Update headers and body with modifications - if modifiedHeaders != nil { - headers = modifiedHeaders - } - if modifiedBody != nil { - requestBody = modifiedBody - } - // Capturing plugin ctx values and putting them in the request context - for k, v := range pluginCtx.GetUserValues() { - ctx.SetUserValue(k, v) - } - } +// fasthttpToHTTPRequest populates a pooled HTTPRequest from fasthttp context. +func fasthttpToHTTPRequest(ctx *fasthttp.RequestCtx, req *schemas.HTTPRequest) { + req.Method = string(ctx.Method()) + req.Path = string(ctx.Path()) - // Marshal the body back to JSON - updatedBody, err := json.Marshal(requestBody) - if err != nil { - SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("TransportInterceptor: Failed to marshal request body: %v", err)) - return - } - ctx.Request.SetBody(updatedBody) + // Copy headers + for key, value := range ctx.Request.Header.All() { + req.Headers[string(key)] = string(value) + } - // Remove headers that were present originally but removed by plugins - for _, name := range originalHeaderNames { - if _, exists := headers[name]; !exists { - ctx.Request.Header.Del(name) - } - } + // Copy query params + for key, value := range ctx.Request.URI().QueryArgs().All() { + req.Query[string(key)] = string(value) + } - // Set modified headers back on the request - for key, value := range headers { - ctx.Request.Header.Set(key, value) - } - } + // Copy body + body := ctx.Request.Body() + if len(body) > 0 { + req.Body = make([]byte, len(body)) + copy(req.Body, body) + } +} - next(ctx) - } +// applyHTTPRequestToCtx applies modifications from HTTPRequest back to fasthttp context. +func applyHTTPRequestToCtx(ctx *fasthttp.RequestCtx, req *schemas.HTTPRequest) { + // If path/method is different, throw error + if req.Method != string(ctx.Method()) || req.Path != string(ctx.Path()) { + logger.Error("request method/path mismatch: %s %s != %s %s", req.Method, req.Path, string(ctx.Method()), string(ctx.Path())) + SendError(ctx, fasthttp.StatusConflict, "request method/path was modified by a plugin, this is not allowed") + return + } + // Apply headers + for key, value := range req.Headers { + ctx.Request.Header.Set(key, value) + } + // Apply query params + for key, value := range req.Query { + ctx.Request.URI().QueryArgs().Set(key, value) + } + // Apply body if set + if req.Body != nil { + ctx.Request.SetBody(req.Body) + } +} + +// applyHTTPResponseToCtx writes a short-circuit response to fasthttp context. +func applyHTTPResponseToCtx(ctx *fasthttp.RequestCtx, resp *schemas.HTTPResponse) { + ctx.SetStatusCode(resp.StatusCode) + for key, value := range resp.Headers { + ctx.Response.Header.Set(key, value) + } + if resp.Body != nil { + ctx.SetBody(resp.Body) } } @@ -183,7 +189,7 @@ func (m *AuthMiddleware) UpdateAuthConfig(authConfig *configstore.AuthConfig) { } // InferenceMiddleware is for inference requests if authConfig is set, it will skip authentication if disableAuthOnInference is true. -func (m *AuthMiddleware) InferenceMiddleware() lib.BifrostHTTPMiddleware { +func (m *AuthMiddleware) InferenceMiddleware() schemas.BifrostHTTPMiddleware { return m.middleware(func(authConfig *configstore.AuthConfig, url string) bool { return authConfig.DisableAuthOnInference }) @@ -197,7 +203,7 @@ func (m *AuthMiddleware) InferenceMiddleware() lib.BifrostHTTPMiddleware { // // Basic auth may be acceptable for limited use cases, while Bearer and WebSocket flows provide // session-based authentication suitable for production environments. -func (m *AuthMiddleware) APIMiddleware() lib.BifrostHTTPMiddleware { +func (m *AuthMiddleware) APIMiddleware() schemas.BifrostHTTPMiddleware { whitelistedRoutes := []string{ "/api/session/is-auth-enabled", "/api/session/login", @@ -209,7 +215,7 @@ func (m *AuthMiddleware) APIMiddleware() lib.BifrostHTTPMiddleware { }) } -func (m *AuthMiddleware) middleware(shouldSkip func(*configstore.AuthConfig, string) bool) lib.BifrostHTTPMiddleware { +func (m *AuthMiddleware) middleware(shouldSkip func(*configstore.AuthConfig, string) bool) schemas.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { authConfig := m.authConfig.Load() @@ -301,3 +307,151 @@ func (m *AuthMiddleware) middleware(shouldSkip func(*configstore.AuthConfig, str } } } + +// TracingMiddleware creates distributed traces for requests and forwards completed traces +// to observability plugins after the response has been written. +// +// The middleware: +// 1. Extracts parent trace ID from incoming W3C traceparent header (if present) +// 2. Creates a new trace in the store (only the lightweight trace ID is stored in context) +// 3. Calls the next handler to process the request +// 4. After response is written, asynchronously completes the trace and forwards it to observability plugins +// +// This middleware should be placed early in the middleware chain to capture the full request lifecycle. +type TracingMiddleware struct { + tracer atomic.Pointer[tracing.Tracer] + obsPlugins atomic.Pointer[[]schemas.ObservabilityPlugin] +} + +// NewTracingMiddleware creates a new tracing middleware +func NewTracingMiddleware(tracer *tracing.Tracer, obsPlugins []schemas.ObservabilityPlugin) *TracingMiddleware { + tm := &TracingMiddleware{ + tracer: atomic.Pointer[tracing.Tracer]{}, + obsPlugins: atomic.Pointer[[]schemas.ObservabilityPlugin]{}, + } + tm.tracer.Store(tracer) + tm.obsPlugins.Store(&obsPlugins) + return tm +} + +// SetObservabilityPlugins sets the observability plugins for the tracing middleware +func (m *TracingMiddleware) SetObservabilityPlugins(obsPlugins []schemas.ObservabilityPlugin) { + m.obsPlugins.Store(&obsPlugins) +} + +// SetTracer sets the tracer for the tracing middleware +func (m *TracingMiddleware) SetTracer(tracer *tracing.Tracer) { + m.tracer.Store(tracer) +} + +// Middleware returns the middleware function that creates distributed traces for requests and forwards completed traces +func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { + return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + // Skip if store is nil + if m.tracer.Load() == nil { + next(ctx) + return + } + // Extract trace ID from W3C traceparent header (if present) + // This is the 32-char trace ID that links all spans in a distributed trace + inheritedTraceID := tracing.ExtractParentID(&ctx.Request.Header) + // Create trace in store - only ID returned (trace data stays in store) + traceID := m.tracer.Load().CreateTrace(inheritedTraceID) + // Only trace ID goes into context (lightweight, no bloat) + ctx.SetUserValue(schemas.BifrostContextKeyTraceID, traceID) + + // Extract parent span ID from W3C traceparent header (if present) + // This is the 16-char span ID from the upstream service that should be + // set as the ParentID of our root span for proper trace linking in Datadog/etc. + parentSpanID := tracing.ExtractTraceParentSpanID(&ctx.Request.Header) + if parentSpanID != "" { + ctx.SetUserValue(schemas.BifrostContextKeyParentSpanID, parentSpanID) + } + + // Store a trace completion callback for streaming handlers to use + ctx.SetUserValue(schemas.BifrostContextKeyTraceCompleter, func() { + m.completeAndFlushTrace(traceID) + }) + // Create root span for the HTTP request + spanCtx, rootSpan := m.tracer.Load().StartSpan(ctx, string(ctx.RequestURI()), schemas.SpanKindHTTPRequest) + if rootSpan != nil { + m.tracer.Load().SetAttribute(rootSpan, "http.method", string(ctx.Method())) + m.tracer.Load().SetAttribute(rootSpan, "http.url", string(ctx.RequestURI())) + m.tracer.Load().SetAttribute(rootSpan, "http.user_agent", string(ctx.Request.Header.UserAgent())) + // Set root span ID in context for child span creation + if spanID, ok := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); ok { + ctx.SetUserValue(schemas.BifrostContextKeySpanID, spanID) + } + } + defer func() { + // Record response status on the root span + if rootSpan != nil { + m.tracer.Load().SetAttribute(rootSpan, "http.status_code", ctx.Response.StatusCode()) + if ctx.Response.StatusCode() >= 400 { + m.tracer.Load().EndSpan(rootSpan, schemas.SpanStatusError, fmt.Sprintf("HTTP %d", ctx.Response.StatusCode())) + } else { + m.tracer.Load().EndSpan(rootSpan, schemas.SpanStatusOk, "") + } + } + // Check if trace completion is deferred (for streaming requests) + // If deferred, the streaming handler will complete the trace after stream ends + if deferred, ok := ctx.UserValue(schemas.BifrostContextKeyDeferTraceCompletion).(bool); ok && deferred { + return + } + // After response written - async flush + m.completeAndFlushTrace(traceID) + }() + + next(ctx) + } + } +} + +// completeAndFlushTrace completes the trace and forwards it to observability plugins. +// This is called either by the middleware defer (for non-streaming) or by streaming handlers. +func (m *TracingMiddleware) completeAndFlushTrace(traceID string) { + go func() { + // Clean up the stream accumulator for this trace + + // Get completed trace from store + completedTrace := m.tracer.Load().EndTrace(traceID) + if completedTrace == nil { + return + } + // Forward to all observability plugins + for _, plugin := range *m.obsPlugins.Load() { + if plugin == nil { + continue + } + // Call inject with a background context (request context is done) + if err := plugin.Inject(context.Background(), completedTrace); err != nil { + logger.Warn("observability plugin %s failed to inject trace: %v", plugin.GetName(), err) + } + } + // Return trace to pool for reuse + m.tracer.Load().ReleaseTrace(completedTrace) + }() +} + +// GetTracer returns the tracer instance for use by streaming handlers +func (m *TracingMiddleware) GetTracer() *tracing.Tracer { + return m.tracer.Load() +} + +// GetObservabilityPlugins filters and returns only observability plugins from a list of plugins. +// Uses Go type assertion to identify plugins implementing the ObservabilityPlugin interface. +func GetObservabilityPlugins(plugins []schemas.Plugin) []schemas.ObservabilityPlugin { + if len(plugins) == 0 { + return nil + } + + obsPlugins := make([]schemas.ObservabilityPlugin, 0) + for _, plugin := range plugins { + if obsPlugin, ok := plugin.(schemas.ObservabilityPlugin); ok { + obsPlugins = append(obsPlugins, obsPlugin) + } + } + + return obsPlugins +} diff --git a/transports/bifrost-http/handlers/middlewares_test.go b/transports/bifrost-http/handlers/middlewares_test.go index e4e14c74ed..2feb346f9b 100644 --- a/transports/bifrost-http/handlers/middlewares_test.go +++ b/transports/bifrost-http/handlers/middlewares_test.go @@ -12,12 +12,12 @@ import ( // mockLogger is a mock implementation of schemas.Logger for testing type mockLogger struct{} -func (m *mockLogger) Debug(format string, args ...any) {} -func (m *mockLogger) Info(format string, args ...any) {} -func (m *mockLogger) Warn(format string, args ...any) {} -func (m *mockLogger) Error(format string, args ...any) {} -func (m *mockLogger) Fatal(format string, args ...any) {} -func (m *mockLogger) SetLevel(level schemas.LogLevel) {} +func (m *mockLogger) Debug(format string, args ...any) {} +func (m *mockLogger) Info(format string, args ...any) {} +func (m *mockLogger) Warn(format string, args ...any) {} +func (m *mockLogger) Error(format string, args ...any) {} +func (m *mockLogger) Fatal(format string, args ...any) {} +func (m *mockLogger) SetLevel(level schemas.LogLevel) {} func (m *mockLogger) SetOutputType(outputType schemas.LoggerOutputType) {} // TestCorsMiddleware_LocalhostOrigins tests that localhost origins are always allowed @@ -305,7 +305,7 @@ func TestChainMiddlewares_SingleMiddleware(t *testing.T) { middlewareCalled := false handlerCalled := false - middleware := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { middlewareCalled = true next(ctx) @@ -332,21 +332,21 @@ func TestChainMiddlewares_MultipleMiddlewares(t *testing.T) { ctx := &fasthttp.RequestCtx{} executionOrder := []int{} - middleware1 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware1 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 1) next(ctx) } }) - middleware2 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware2 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 2) next(ctx) } }) - middleware3 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware3 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 3) next(ctx) @@ -378,7 +378,7 @@ func TestChainMiddlewares_MultipleMiddlewares(t *testing.T) { func TestChainMiddlewares_MiddlewareCanModifyContext(t *testing.T) { ctx := &fasthttp.RequestCtx{} - middleware := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { ctx.SetUserValue("test-key", "test-value") next(ctx) @@ -405,7 +405,7 @@ func TestChainMiddlewares_ShortCircuit(t *testing.T) { executionOrder := []int{} // First middleware - writes response and short-circuits by not calling next - middleware1 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware1 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 1) ctx.SetStatusCode(fasthttp.StatusUnauthorized) @@ -415,7 +415,7 @@ func TestChainMiddlewares_ShortCircuit(t *testing.T) { }) // Second middleware - should NOT execute when middleware1 short-circuits - middleware2 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware2 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 2) next(ctx) @@ -423,7 +423,7 @@ func TestChainMiddlewares_ShortCircuit(t *testing.T) { }) // Third middleware - should NOT execute when middleware1 short-circuits - middleware3 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware3 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 3) next(ctx) @@ -469,7 +469,7 @@ func TestChainMiddlewares_ShortCircuitMiddlePosition(t *testing.T) { executionOrder := []int{} // First middleware - executes and calls next - middleware1 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware1 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 1) next(ctx) @@ -477,7 +477,7 @@ func TestChainMiddlewares_ShortCircuitMiddlePosition(t *testing.T) { }) // Second middleware - writes response and short-circuits - middleware2 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware2 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 2) ctx.SetStatusCode(fasthttp.StatusUnauthorized) @@ -487,7 +487,7 @@ func TestChainMiddlewares_ShortCircuitMiddlePosition(t *testing.T) { }) // Third middleware - should NOT execute - middleware3 := lib.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { + middleware3 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { executionOrder = append(executionOrder, 3) next(ctx) @@ -744,3 +744,86 @@ func TestAuthMiddleware_UpdateAuthConfig_EnabledToDisabled(t *testing.T) { t.Error("Second request should pass after auth is disabled") } } + +// TestFasthttpToHTTPRequest tests the conversion from fasthttp context to HTTPRequest +func TestFasthttpToHTTPRequest(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + + // Set up test data + ctx.Request.Header.SetMethod("POST") + // Query params include: integers, floats, booleans, timestamps, and strings with special chars + ctx.Request.SetRequestURI("/api/v1/test?limit=100&offset=50&min_cost=12.50&max_latency=1500.75&missing_cost_only=true&start_time=2023-01-15T10:30:00Z&content_search=test+query&special=%2B%26%3D%3F") + ctx.Request.Header.Set("Content-Type", "application/json") + ctx.Request.Header.Set("Authorization", "Bearer token123") + ctx.Request.Header.Set("X-Request-Id", "12345") + ctx.Request.Header.Set("X-Custom-Header", "value-with-dashes") + ctx.Request.SetBodyString(`{"key": "value", "number": 42, "nested": {"bool": true}}`) + + // Acquire HTTPRequest from pool + req := schemas.AcquireHTTPRequest() + defer schemas.ReleaseHTTPRequest(req) + + // Call the function + fasthttpToHTTPRequest(ctx, req) + + // Verify Method + if req.Method != "POST" { + t.Errorf("Expected Method to be 'POST', got '%s'", req.Method) + } + + // Verify Path (without query params) + if req.Path != "/api/v1/test" { + t.Errorf("Expected Path to be '/api/v1/test', got '%s'", req.Path) + } + + // Verify Headers + expectedHeaders := map[string]string{ + "Content-Type": "application/json", + "Authorization": "Bearer token123", + "X-Request-Id": "12345", + "X-Custom-Header": "value-with-dashes", + } + for key, expectedValue := range expectedHeaders { + if actualValue, exists := req.Headers[key]; !exists { + t.Errorf("Expected header '%s' to exist", key) + } else if actualValue != expectedValue { + t.Errorf("Expected header '%s' to be '%s', got '%s'", key, expectedValue, actualValue) + } + } + + // Verify Query params + expectedQuery := map[string]string{ + "limit": "100", // integer + "offset": "50", // integer + "min_cost": "12.50", // float + "max_latency": "1500.75", // float + "missing_cost_only": "true", // boolean + "start_time": "2023-01-15T10:30:00Z", // timestamp + "content_search": "test query", // string with space (decoded) + "special": "+&=?", // special characters (decoded) + } + for key, expectedValue := range expectedQuery { + if actualValue, exists := req.Query[key]; !exists { + t.Errorf("Expected query param '%s' to exist", key) + } else if actualValue != expectedValue { + t.Errorf("Expected query param '%s' to be '%s', got '%s'", key, expectedValue, actualValue) + } + } + + // Verify Body (JSON with various types) + expectedBody := `{"key": "value", "number": 42, "nested": {"bool": true}}` + if string(req.Body) != expectedBody { + t.Errorf("Expected Body to be '%s', got '%s'", expectedBody, string(req.Body)) + } + + // Verify body is a copy, not a reference + originalBody := ctx.Request.Body() + if len(req.Body) > 0 && len(originalBody) > 0 { + // Modify the HTTPRequest body + req.Body[0] = 'X' + // Original should remain unchanged + if originalBody[0] == 'X' { + t.Error("Body should be a copy, not a reference to the original") + } + } +} diff --git a/transports/bifrost-http/handlers/plugins.go b/transports/bifrost-http/handlers/plugins.go index b15cb328f0..6939fe5d69 100644 --- a/transports/bifrost-http/handlers/plugins.go +++ b/transports/bifrost-http/handlers/plugins.go @@ -50,7 +50,7 @@ type UpdatePluginRequest struct { } // RegisterRoutes registers the routes for the PluginsHandler -func (h *PluginsHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *PluginsHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.GET("/api/plugins", lib.ChainMiddlewares(h.getPlugins, middlewares...)) r.GET("/api/plugins/{name}", lib.ChainMiddlewares(h.getPlugin, middlewares...)) r.POST("/api/plugins", lib.ChainMiddlewares(h.createPlugin, middlewares...)) diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go index 63811098da..76d6c2f2c3 100644 --- a/transports/bifrost-http/handlers/providers.go +++ b/transports/bifrost-http/handlers/providers.go @@ -77,7 +77,7 @@ type ErrorResponse struct { } // RegisterRoutes registers all provider management routes -func (h *ProviderHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *ProviderHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // Provider CRUD operations r.GET("/api/providers", lib.ChainMiddlewares(h.listProviders, middlewares...)) r.GET("/api/providers/{provider}", lib.ChainMiddlewares(h.getProvider, middlewares...)) diff --git a/transports/bifrost-http/handlers/session.go b/transports/bifrost-http/handlers/session.go index 646c28fd3c..e72ed52c6e 100644 --- a/transports/bifrost-http/handlers/session.go +++ b/transports/bifrost-http/handlers/session.go @@ -8,6 +8,7 @@ import ( "github.com/fasthttp/router" "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/encrypt" @@ -28,7 +29,7 @@ func NewSessionHandler(configStore configstore.ConfigStore) *SessionHandler { } // RegisterRoutes registers the session-related routes -func (h *SessionHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *SessionHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.POST("/api/session/login", lib.ChainMiddlewares(h.login, middlewares...)) r.POST("/api/session/logout", lib.ChainMiddlewares(h.logout, middlewares...)) r.GET("/api/session/is-auth-enabled", lib.ChainMiddlewares(h.isAuthEnabled, middlewares...)) diff --git a/transports/bifrost-http/handlers/ui.go b/transports/bifrost-http/handlers/ui.go index cd42ad7dcb..e0872e249c 100644 --- a/transports/bifrost-http/handlers/ui.go +++ b/transports/bifrost-http/handlers/ui.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/fasthttp/router" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) @@ -25,7 +26,7 @@ func NewUIHandler(uiContent embed.FS) *UIHandler { } // RegisterRoutes registers the UI routes with the provided router. -func (h *UIHandler) RegisterRoutes(router *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *UIHandler) RegisterRoutes(router *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { router.GET("/", lib.ChainMiddlewares(h.serveDashboard, middlewares...)) router.GET("/{filepath:*}", lib.ChainMiddlewares(h.serveDashboard, middlewares...)) } diff --git a/transports/bifrost-http/handlers/websocket.go b/transports/bifrost-http/handlers/websocket.go index eb4b05f5aa..9fab07a00e 100644 --- a/transports/bifrost-http/handlers/websocket.go +++ b/transports/bifrost-http/handlers/websocket.go @@ -11,6 +11,7 @@ import ( "github.com/fasthttp/router" "github.com/fasthttp/websocket" + "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/logstore" "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/transports/bifrost-http/lib" @@ -47,7 +48,7 @@ func NewWebSocketHandler(ctx context.Context, logManager logging.LogManager, all } // RegisterRoutes registers all WebSocket-related routes -func (h *WebSocketHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (h *WebSocketHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { r.GET("/ws", lib.ChainMiddlewares(h.connectStream, middlewares...)) } diff --git a/transports/bifrost-http/integrations/anthropic.go b/transports/bifrost-http/integrations/anthropic.go index e3d05a342d..ba706f7c25 100644 --- a/transports/bifrost-http/integrations/anthropic.go +++ b/transports/bifrost-http/integrations/anthropic.go @@ -1,7 +1,6 @@ package integrations import ( - "context" "errors" "fmt" "io" @@ -31,7 +30,7 @@ func createAnthropicCompleteRouteConfig(pathPrefix string) RouteConfig { GetRequestTypeInstance: func() interface{} { return &anthropic.AnthropicTextRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if anthropicReq, ok := req.(*anthropic.AnthropicTextRequest); ok { return &schemas.BifrostRequest{ TextCompletionRequest: anthropicReq.ToBifrostTextCompletionRequest(), @@ -39,7 +38,7 @@ func createAnthropicCompleteRouteConfig(pathPrefix string) RouteConfig { } return nil, errors.New("invalid request type") }, - TextResponseConverter: func(ctx *context.Context, resp *schemas.BifrostTextCompletionResponse) (interface{}, error) { + TextResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTextCompletionResponse) (interface{}, error) { if shouldUsePassthrough(ctx, resp.ExtraFields.Provider, resp.ExtraFields.ModelRequested, resp.ExtraFields.ModelDeployment) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil @@ -47,7 +46,7 @@ func createAnthropicCompleteRouteConfig(pathPrefix string) RouteConfig { } return anthropic.ToAnthropicTextCompletionResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicChatCompletionError(err) }, } @@ -67,15 +66,15 @@ func createAnthropicMessagesRouteConfig(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func() interface{} { return &anthropic.AnthropicMessageRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if anthropicReq, ok := req.(*anthropic.AnthropicMessageRequest); ok { return &schemas.BifrostRequest{ - ResponsesRequest: anthropicReq.ToBifrostResponsesRequest(*ctx), + ResponsesRequest: anthropicReq.ToBifrostResponsesRequest(ctx), }, nil } return nil, errors.New("invalid request type") }, - ResponsesResponseConverter: func(ctx *context.Context, resp *schemas.BifrostResponsesResponse) (interface{}, error) { + ResponsesResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesResponse) (interface{}, error) { if isClaudeModel(resp.ExtraFields.ModelRequested, resp.ExtraFields.ModelDeployment, string(resp.ExtraFields.Provider)) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil @@ -83,11 +82,11 @@ func createAnthropicMessagesRouteConfig(pathPrefix string) []RouteConfig { } return anthropic.ToAnthropicResponsesResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicChatCompletionError(err) }, StreamConfig: &StreamConfig{ - ResponsesStreamResponseConverter: func(ctx *context.Context, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) { + ResponsesStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) { if shouldUsePassthrough(ctx, resp.ExtraFields.Provider, resp.ExtraFields.ModelRequested, resp.ExtraFields.ModelDeployment) { if resp.ExtraFields.RawResponse != nil { raw, ok := resp.ExtraFields.RawResponse.(string) @@ -101,7 +100,7 @@ func createAnthropicMessagesRouteConfig(pathPrefix string) []RouteConfig { } return "", nil, nil } - anthropicResponse := anthropic.ToAnthropicResponsesStreamResponse(*ctx, resp) + anthropicResponse := anthropic.ToAnthropicResponsesStreamResponse(ctx, resp) // Can happen for openai lifecycle events if len(anthropicResponse) == 0 { return "", nil, nil @@ -123,7 +122,7 @@ func createAnthropicMessagesRouteConfig(pathPrefix string) []RouteConfig { } } }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicResponsesStreamError(err) }, }, @@ -149,7 +148,7 @@ func CreateAnthropicListModelsRouteConfigs(pathPrefix string, handlerStore lib.H GetRequestTypeInstance: func() interface{} { return &schemas.BifrostListModelsRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if listModelsReq, ok := req.(*schemas.BifrostListModelsRequest); ok { return &schemas.BifrostRequest{ ListModelsRequest: listModelsReq, @@ -157,10 +156,10 @@ func CreateAnthropicListModelsRouteConfigs(pathPrefix string, handlerStore lib.H } return nil, errors.New("invalid request type") }, - ListModelsResponseConverter: func(ctx *context.Context, resp *schemas.BifrostListModelsResponse) (interface{}, error) { + ListModelsResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostListModelsResponse) (interface{}, error) { return anthropic.ToAnthropicListModelsResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicChatCompletionError(err) }, PreCallback: extractAnthropicListModelsParams, @@ -171,7 +170,7 @@ func CreateAnthropicListModelsRouteConfigs(pathPrefix string, handlerStore lib.H // checkAnthropicPassthrough pre-callback checks if the request is for a claude model. // If it is, it attaches the raw request body for direct use by the provider. // It also checks for anthropic oauth headers and sets the bifrost context. -func checkAnthropicPassthrough(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func checkAnthropicPassthrough(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { var provider schemas.ModelProvider var model string @@ -204,30 +203,30 @@ func checkAnthropicPassthrough(ctx *fasthttp.RequestCtx, bifrostCtx *context.Con if len(userAgent) > 0 { // Check if it's claude code if strings.Contains(userAgent[0], "claude-cli") { - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyUserAgent, "claude-cli") + bifrostCtx.SetValue(schemas.BifrostContextKeyUserAgent, "claude-cli") } } } // Check if anthropic oauth headers are present if shouldUsePassthrough(bifrostCtx, provider, model, "") { - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyUseRawRequestBody, true) + bifrostCtx.SetValue(schemas.BifrostContextKeyUseRawRequestBody, true) if !isAnthropicAPIKeyAuth(ctx) && (provider == schemas.Anthropic || provider == "") { url := extractExactPath(ctx) if !strings.HasPrefix(url, "/") { url = "/" + url } - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyExtraHeaders, headers) - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyURLPath, url) - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeySkipKeySelection, true) + bifrostCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, headers) + bifrostCtx.SetValue(schemas.BifrostContextKeyURLPath, url) + bifrostCtx.SetValue(schemas.BifrostContextKeySkipKeySelection, true) } } return nil } -func shouldUsePassthrough(ctx *context.Context, provider schemas.ModelProvider, model string, deployment string) bool { +func shouldUsePassthrough(ctx *schemas.BifrostContext, provider schemas.ModelProvider, model string, deployment string) bool { isClaudeCode := false - if userAgent, ok := (*ctx).Value(schemas.BifrostContextKeyUserAgent).(string); ok { + if userAgent, ok := ctx.Value(schemas.BifrostContextKeyUserAgent).(string); ok { if strings.Contains(userAgent, "claude-cli") { isClaudeCode = true } @@ -243,7 +242,7 @@ func isClaudeModel(model, deployment, provider string) bool { } // extractAnthropicListModelsParams extracts query parameters for list models request -func extractAnthropicListModelsParams(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func extractAnthropicListModelsParams(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { if listModelsReq, ok := req.(*schemas.BifrostListModelsRequest); ok { // Set provider to Anthropic listModelsReq.Provider = schemas.Anthropic @@ -286,18 +285,18 @@ func CreateAnthropicCountTokensRouteConfigs(pathPrefix string, handlerStore lib. GetRequestTypeInstance: func() interface{} { return &anthropic.AnthropicMessageRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if anthropicReq, ok := req.(*anthropic.AnthropicMessageRequest); ok { return &schemas.BifrostRequest{ - CountTokensRequest: anthropicReq.ToBifrostResponsesRequest(*ctx), + CountTokensRequest: anthropicReq.ToBifrostResponsesRequest(ctx), }, nil } return nil, errors.New("invalid request type for Anthropic count tokens") }, - CountTokensResponseConverter: func(ctx *context.Context, resp *schemas.BifrostCountTokensResponse) (interface{}, error) { + CountTokensResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostCountTokensResponse) (interface{}, error) { return anthropic.ToAnthropicCountTokensResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicChatCompletionError(err) }, }, @@ -315,13 +314,13 @@ func CreateAnthropicBatchRouteConfigs(pathPrefix string, handlerStore lib.Handle GetRequestTypeInstance: func() any { return &anthropic.AnthropicBatchCreateRequest{} }, - BatchRequestConverter: func(ctx *context.Context, req any) (*BatchRequest, error) { + BatchRequestConverter: func(ctx *schemas.BifrostContext, req any) (*BatchRequest, error) { if anthropicReq, ok := req.(*anthropic.AnthropicBatchCreateRequest); ok { // Convert Anthropic batch request items to Bifrost format isNonAnthropicProvider := false var provider schemas.ModelProvider var ok bool - if provider, ok = (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider); ok && provider != schemas.Anthropic { + if provider, ok = ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider); ok && provider != schemas.Anthropic { isNonAnthropicProvider = true } var model *string @@ -370,13 +369,13 @@ func CreateAnthropicBatchRouteConfigs(pathPrefix string, handlerStore lib.Handle } return nil, errors.New("invalid batch create request type") }, - BatchCreateResponseConverter: func(ctx *context.Context, resp *schemas.BifrostBatchCreateResponse) (interface{}, error) { + BatchCreateResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchCreateResponse) (interface{}, error) { if resp.ExtraFields.Provider == schemas.Gemini { resp.ID = strings.Replace(resp.ID, "batches/", "batches-", 1) } return anthropic.ToAnthropicBatchCreateResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicChatCompletionError(err) }, PreCallback: extractAnthropicBatchCreateParams, @@ -390,9 +389,9 @@ func CreateAnthropicBatchRouteConfigs(pathPrefix string, handlerStore lib.Handle GetRequestTypeInstance: func() interface{} { return &anthropic.AnthropicBatchListRequest{} }, - BatchRequestConverter: func(ctx *context.Context, req interface{}) (*BatchRequest, error) { + BatchRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*BatchRequest, error) { if listReq, ok := req.(*anthropic.AnthropicBatchListRequest); ok { - provider, ok := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider, ok := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) if !ok { return nil, errors.New("provider not found in context") } @@ -407,7 +406,7 @@ func CreateAnthropicBatchRouteConfigs(pathPrefix string, handlerStore lib.Handle } return nil, errors.New("invalid batch list request type") }, - BatchListResponseConverter: func(ctx *context.Context, resp *schemas.BifrostBatchListResponse) (interface{}, error) { + BatchListResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchListResponse) (interface{}, error) { if resp.ExtraFields.RawResponse != nil && resp.ExtraFields.Provider == schemas.Anthropic { return resp.ExtraFields.RawResponse, nil } @@ -418,7 +417,7 @@ func CreateAnthropicBatchRouteConfigs(pathPrefix string, handlerStore lib.Handle } return anthropic.ToAnthropicBatchListResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicChatCompletionError(err) }, PreCallback: extractAnthropicBatchListQueryParams, @@ -432,9 +431,9 @@ func CreateAnthropicBatchRouteConfigs(pathPrefix string, handlerStore lib.Handle GetRequestTypeInstance: func() interface{} { return &anthropic.AnthropicBatchRetrieveRequest{} }, - BatchRequestConverter: func(ctx *context.Context, req interface{}) (*BatchRequest, error) { + BatchRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*BatchRequest, error) { if retrieveReq, ok := req.(*anthropic.AnthropicBatchRetrieveRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) if provider == schemas.Gemini { retrieveReq.BatchID = strings.Replace(retrieveReq.BatchID, "batches-", "batches/", 1) } @@ -448,13 +447,13 @@ func CreateAnthropicBatchRouteConfigs(pathPrefix string, handlerStore lib.Handle } return nil, errors.New("invalid batch retrieve request type") }, - BatchRetrieveResponseConverter: func(ctx *context.Context, resp *schemas.BifrostBatchRetrieveResponse) (interface{}, error) { + BatchRetrieveResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchRetrieveResponse) (interface{}, error) { if resp.ExtraFields.Provider == schemas.Gemini { resp.ID = strings.Replace(resp.ID, "batches/", "batches-", 1) } return anthropic.ToAnthropicBatchRetrieveResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicChatCompletionError(err) }, PreCallback: extractAnthropicBatchIDFromPath, @@ -468,9 +467,9 @@ func CreateAnthropicBatchRouteConfigs(pathPrefix string, handlerStore lib.Handle GetRequestTypeInstance: func() any { return &anthropic.AnthropicBatchCancelRequest{} }, - BatchRequestConverter: func(ctx *context.Context, req interface{}) (*BatchRequest, error) { + BatchRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*BatchRequest, error) { if cancelReq, ok := req.(*anthropic.AnthropicBatchCancelRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) if provider == schemas.Gemini { cancelReq.BatchID = strings.Replace(cancelReq.BatchID, "batches-", "batches/", 1) } @@ -484,13 +483,13 @@ func CreateAnthropicBatchRouteConfigs(pathPrefix string, handlerStore lib.Handle } return nil, errors.New("invalid batch cancel request type") }, - BatchCancelResponseConverter: func(ctx *context.Context, resp *schemas.BifrostBatchCancelResponse) (interface{}, error) { + BatchCancelResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchCancelResponse) (interface{}, error) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } return anthropic.ToAnthropicBatchCancelResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicChatCompletionError(err) }, PreCallback: extractAnthropicBatchIDFromPath, @@ -504,9 +503,9 @@ func CreateAnthropicBatchRouteConfigs(pathPrefix string, handlerStore lib.Handle GetRequestTypeInstance: func() interface{} { return &anthropic.AnthropicBatchResultsRequest{} }, - BatchRequestConverter: func(ctx *context.Context, req interface{}) (*BatchRequest, error) { + BatchRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*BatchRequest, error) { if resultsReq, ok := req.(*anthropic.AnthropicBatchResultsRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) if provider == schemas.Gemini { resultsReq.BatchID = strings.Replace(resultsReq.BatchID, "batches-", "batches/", 1) } @@ -520,13 +519,13 @@ func CreateAnthropicBatchRouteConfigs(pathPrefix string, handlerStore lib.Handle } return nil, errors.New("invalid batch results request type") }, - BatchResultsResponseConverter: func(ctx *context.Context, resp *schemas.BifrostBatchResultsResponse) (interface{}, error) { + BatchResultsResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchResultsResponse) (interface{}, error) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicChatCompletionError(err) }, PreCallback: extractAnthropicBatchIDFromPath, @@ -536,26 +535,26 @@ func CreateAnthropicBatchRouteConfigs(pathPrefix string, handlerStore lib.Handle } // extractAnthropicBatchCreateParams extracts provider from header for batch create requests -func extractAnthropicBatchCreateParams(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func extractAnthropicBatchCreateParams(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Extract provider from header, default to Anthropic provider := string(ctx.Request.Header.Peek("x-model-provider")) if provider == "" { provider = string(schemas.Anthropic) } // Store provider in context for batch create converter to use - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) return nil } // extractAnthropicBatchListQueryParams extracts provider from header and query parameters for Anthropic batch list requests -func extractAnthropicBatchListQueryParams(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func extractAnthropicBatchListQueryParams(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { if listReq, ok := req.(*anthropic.AnthropicBatchListRequest); ok { // Extract provider from header, default to Anthropic provider := string(ctx.Request.Header.Peek("x-model-provider")) if provider == "" { provider = string(schemas.Anthropic) } - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) // Printing all query parameters // Extract limit from query parameters if limitStr := string(ctx.QueryArgs().Peek("page_size")); limitStr != "" { @@ -574,13 +573,13 @@ func extractAnthropicBatchListQueryParams(ctx *fasthttp.RequestCtx, bifrostCtx * } // extractAnthropicBatchIDFromPath extracts provider from header and batch_id from path parameters -func extractAnthropicBatchIDFromPath(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func extractAnthropicBatchIDFromPath(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Extract provider from header, default to Anthropic provider := string(ctx.Request.Header.Peek("x-model-provider")) if provider == "" { provider = string(schemas.Anthropic) } - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) batchID := ctx.UserValue("batch_id") if batchID == nil { return errors.New("batch_id is required") @@ -601,17 +600,17 @@ func extractAnthropicBatchIDFromPath(ctx *fasthttp.RequestCtx, bifrostCtx *conte } // extractAnthropicFileUploadParams extracts provider from header for file upload requests -func extractAnthropicFileUploadParams(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func extractAnthropicFileUploadParams(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { provider := string(ctx.Request.Header.Peek("x-model-provider")) if provider == "" { provider = string(schemas.Anthropic) } - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) return nil } // extractAnthropicFileListQueryParams extracts provider from header and query parameters for Anthropic file list requests -func extractAnthropicFileListQueryParams(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func extractAnthropicFileListQueryParams(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { if listReq, ok := req.(*anthropic.AnthropicFileListRequest); ok { // Extract provider from header, default to Anthropic provider := string(ctx.Request.Header.Peek("x-model-provider")) @@ -619,7 +618,7 @@ func extractAnthropicFileListQueryParams(ctx *fasthttp.RequestCtx, bifrostCtx *c provider = string(schemas.Anthropic) } - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) // Extract limit from query parameters if limitStr := string(ctx.QueryArgs().Peek("limit")); limitStr != "" { @@ -641,13 +640,13 @@ func extractAnthropicFileListQueryParams(ctx *fasthttp.RequestCtx, bifrostCtx *c } // extractAnthropicFileIDFromPath extracts provider from header and file_id from path parameters -func extractAnthropicFileIDFromPath(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func extractAnthropicFileIDFromPath(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Extract provider from header, default to Anthropic provider := string(ctx.Request.Header.Peek("x-model-provider")) if provider == "" { provider = string(schemas.Anthropic) } - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) fileID := ctx.UserValue("file_id") if fileID == nil { return errors.New("file_id is required") @@ -727,10 +726,10 @@ func CreateAnthropicFilesRouteConfigs(pathPrefix string, handlerStore lib.Handle uploadReq.Filename = fileHeader.Filename return nil }, - FileRequestConverter: func(ctx *context.Context, req any) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req any) (*FileRequest, error) { if uploadReq, ok := req.(*anthropic.AnthropicFileUploadRequest); ok { // Here if provider is OpenAI and purpose is empty then we override it with "batch" - provider, ok := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider, ok := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) if !ok { return nil, errors.New("provider not found in context") } @@ -749,7 +748,7 @@ func CreateAnthropicFilesRouteConfigs(pathPrefix string, handlerStore lib.Handle } return nil, errors.New("invalid file upload request type") }, - FileUploadResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileUploadResponse) (interface{}, error) { + FileUploadResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileUploadResponse) (interface{}, error) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } @@ -759,7 +758,7 @@ func CreateAnthropicFilesRouteConfigs(pathPrefix string, handlerStore lib.Handle } return anthropic.ToAnthropicFileUploadResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicChatCompletionError(err) }, PreCallback: extractAnthropicFileUploadParams, @@ -773,9 +772,9 @@ func CreateAnthropicFilesRouteConfigs(pathPrefix string, handlerStore lib.Handle GetRequestTypeInstance: func() interface{} { return &anthropic.AnthropicFileListRequest{} }, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if listReq, ok := req.(*anthropic.AnthropicFileListRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) return &FileRequest{ Type: schemas.FileListRequest, ListRequest: &schemas.BifrostFileListRequest{ @@ -788,7 +787,7 @@ func CreateAnthropicFilesRouteConfigs(pathPrefix string, handlerStore lib.Handle } return nil, errors.New("invalid file list request type") }, - FileListResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileListResponse) (interface{}, error) { + FileListResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileListResponse) (interface{}, error) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } @@ -800,7 +799,7 @@ func CreateAnthropicFilesRouteConfigs(pathPrefix string, handlerStore lib.Handle } return anthropic.ToAnthropicFileListResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicChatCompletionError(err) }, PreCallback: extractAnthropicFileListQueryParams, @@ -814,9 +813,9 @@ func CreateAnthropicFilesRouteConfigs(pathPrefix string, handlerStore lib.Handle GetRequestTypeInstance: func() interface{} { return &anthropic.AnthropicFileRetrieveRequest{} }, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if retrieveReq, ok := req.(*anthropic.AnthropicFileRetrieveRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) // Handle file id conversion for Gemini if provider == schemas.Gemini { retrieveReq.FileID = strings.Replace(retrieveReq.FileID, "files-", "files/", 1) @@ -831,13 +830,13 @@ func CreateAnthropicFilesRouteConfigs(pathPrefix string, handlerStore lib.Handle } return nil, errors.New("invalid file retrieve request type") }, - FileRetrieveResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileRetrieveResponse) (interface{}, error) { + FileRetrieveResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileRetrieveResponse) (interface{}, error) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } return anthropic.ToAnthropicFileRetrieveResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicChatCompletionError(err) }, PreCallback: extractAnthropicFileIDFromPath, @@ -851,9 +850,9 @@ func CreateAnthropicFilesRouteConfigs(pathPrefix string, handlerStore lib.Handle GetRequestTypeInstance: func() interface{} { return &anthropic.AnthropicFileDeleteRequest{} }, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if deleteReq, ok := req.(*anthropic.AnthropicFileDeleteRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) if provider == schemas.Gemini { // Here we will convert fileId to replace files/ with files- deleteReq.FileID = strings.Replace(deleteReq.FileID, "files-", "files/", 1) @@ -868,13 +867,13 @@ func CreateAnthropicFilesRouteConfigs(pathPrefix string, handlerStore lib.Handle } return nil, errors.New("invalid file delete request type") }, - FileDeleteResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileDeleteResponse) (interface{}, error) { + FileDeleteResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileDeleteResponse) (interface{}, error) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } return anthropic.ToAnthropicFileDeleteResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return anthropic.ToAnthropicChatCompletionError(err) }, PreCallback: extractAnthropicFileIDFromPath, diff --git a/transports/bifrost-http/integrations/bedrock.go b/transports/bifrost-http/integrations/bedrock.go index 952e1e64db..c5af8eed8a 100644 --- a/transports/bifrost-http/integrations/bedrock.go +++ b/transports/bifrost-http/integrations/bedrock.go @@ -1,7 +1,6 @@ package integrations import ( - "context" "errors" "fmt" "net/url" @@ -39,7 +38,7 @@ func createBedrockConverseRouteConfig(pathPrefix string, handlerStore lib.Handle GetRequestTypeInstance: func() interface{} { return &bedrock.BedrockConverseRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if bedrockReq, ok := req.(*bedrock.BedrockConverseRequest); ok { bifrostReq, err := bedrockReq.ToBifrostResponsesRequest(ctx) if err != nil { @@ -51,10 +50,10 @@ func createBedrockConverseRouteConfig(pathPrefix string, handlerStore lib.Handle } return nil, errors.New("invalid request type") }, - ResponsesResponseConverter: func(ctx *context.Context, resp *schemas.BifrostResponsesResponse) (interface{}, error) { + ResponsesResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesResponse) (interface{}, error) { return bedrock.ToBedrockConverseResponse(resp) }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return bedrock.ToBedrockError(err) }, PreCallback: bedrockPreCallback(handlerStore), @@ -71,7 +70,7 @@ func createBedrockConverseStreamRouteConfig(pathPrefix string, handlerStore lib. GetRequestTypeInstance: func() interface{} { return &bedrock.BedrockConverseRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if bedrockReq, ok := req.(*bedrock.BedrockConverseRequest); ok { // Mark as streaming request bedrockReq.Stream = true @@ -85,11 +84,11 @@ func createBedrockConverseStreamRouteConfig(pathPrefix string, handlerStore lib. } return nil, errors.New("invalid request type") }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return bedrock.ToBedrockError(err) }, StreamConfig: &StreamConfig{ - ResponsesStreamResponseConverter: func(ctx *context.Context, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) { + ResponsesStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) { bedrockEvent, err := bedrock.ToBedrockConverseStreamResponse(resp) if err != nil { return "", nil, err @@ -116,7 +115,7 @@ func createBedrockInvokeWithResponseStreamRouteConfig(pathPrefix string, handler GetRequestTypeInstance: func() interface{} { return &bedrock.BedrockTextCompletionRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if bedrockReq, ok := req.(*bedrock.BedrockTextCompletionRequest); ok { // Mark as streaming request bedrockReq.Stream = true @@ -126,11 +125,11 @@ func createBedrockInvokeWithResponseStreamRouteConfig(pathPrefix string, handler } return nil, errors.New("invalid request type") }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return bedrock.ToBedrockError(err) }, StreamConfig: &StreamConfig{ - TextStreamResponseConverter: func(ctx *context.Context, resp *schemas.BifrostTextCompletionResponse) (string, interface{}, error) { + TextStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTextCompletionResponse) (string, interface{}, error) { if resp == nil { return "", nil, nil } @@ -161,7 +160,7 @@ func createBedrockInvokeRouteConfig(pathPrefix string, handlerStore lib.HandlerS GetRequestTypeInstance: func() interface{} { return &bedrock.BedrockTextCompletionRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if bedrockReq, ok := req.(*bedrock.BedrockTextCompletionRequest); ok { return &schemas.BifrostRequest{ TextCompletionRequest: bedrockReq.ToBifrostTextCompletionRequest(), @@ -169,10 +168,10 @@ func createBedrockInvokeRouteConfig(pathPrefix string, handlerStore lib.HandlerS } return nil, errors.New("invalid request type") }, - TextResponseConverter: func(ctx *context.Context, resp *schemas.BifrostTextCompletionResponse) (interface{}, error) { + TextResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTextCompletionResponse) (interface{}, error) { return bedrock.ToBedrockTextCompletionResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return bedrock.ToBedrockError(err) }, PreCallback: bedrockPreCallback(handlerStore), @@ -201,9 +200,9 @@ func createBedrockBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerS GetRequestTypeInstance: func() interface{} { return &bedrock.BedrockBatchJobRequest{} }, - BatchRequestConverter: func(ctx *context.Context, req interface{}) (*BatchRequest, error) { + BatchRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*BatchRequest, error) { if bedrockReq, ok := req.(*bedrock.BedrockBatchJobRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) // Convert Bedrock batch request to Bifrost format // For Bedrock: use S3 URIs directly @@ -269,7 +268,7 @@ func createBedrockBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerS } return nil, errors.New("invalid batch create request type") }, - BatchCreateResponseConverter: func(ctx *context.Context, resp *schemas.BifrostBatchCreateResponse) (interface{}, error) { + BatchCreateResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchCreateResponse) (interface{}, error) { // Only return raw response for native Bedrock calls // For cross-provider routing, always convert to Bedrock format if resp.ExtraFields.RawResponse != nil && resp.ExtraFields.Provider == schemas.Bedrock { @@ -277,16 +276,16 @@ func createBedrockBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerS } return bedrock.ToBedrockBatchJobResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return bedrock.ToBedrockError(err) }, - PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Extract provider from header for cross-provider routing provider := string(ctx.Request.Header.Peek("x-model-provider")) if provider != "" { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) } else { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.Bedrock) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.Bedrock) } return bedrockBatchPreCallback(handlerStore)(ctx, bifrostCtx, req) }, @@ -300,9 +299,9 @@ func createBedrockBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerS GetRequestTypeInstance: func() interface{} { return &bedrock.BedrockBatchListRequest{} }, - BatchRequestConverter: func(ctx *context.Context, req interface{}) (*BatchRequest, error) { + BatchRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*BatchRequest, error) { if bedrockReq, ok := req.(*bedrock.BedrockBatchListRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) bifrostReq := bedrock.ToBifrostBatchListRequest(bedrockReq, provider) return &BatchRequest{ Type: schemas.BatchListRequest, @@ -311,7 +310,7 @@ func createBedrockBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerS } return nil, errors.New("invalid batch list request type") }, - BatchListResponseConverter: func(ctx *context.Context, resp *schemas.BifrostBatchListResponse) (interface{}, error) { + BatchListResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchListResponse) (interface{}, error) { // Only return raw response for native Bedrock calls // For cross-provider routing, always convert to Bedrock format if resp.ExtraFields.RawResponse != nil && resp.ExtraFields.Provider == schemas.Bedrock { @@ -319,16 +318,16 @@ func createBedrockBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerS } return bedrock.ToBedrockBatchJobListResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return bedrock.ToBedrockError(err) }, - PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Extract provider from header for cross-provider routing provider := string(ctx.Request.Header.Peek("x-model-provider")) if provider != "" { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) } else { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.Bedrock) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.Bedrock) } return extractBedrockBatchListQueryParams(handlerStore)(ctx, bifrostCtx, req) }, @@ -342,9 +341,9 @@ func createBedrockBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerS GetRequestTypeInstance: func() interface{} { return &bedrock.BedrockBatchRetrieveRequest{} }, - BatchRequestConverter: func(ctx *context.Context, req interface{}) (*BatchRequest, error) { + BatchRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*BatchRequest, error) { if bedrockReq, ok := req.(*bedrock.BedrockBatchRetrieveRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) bifrostReq := bedrock.ToBifrostBatchRetrieveRequest(bedrockReq, provider) return &BatchRequest{ Type: schemas.BatchRetrieveRequest, @@ -353,7 +352,7 @@ func createBedrockBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerS } return nil, errors.New("invalid batch retrieve request type") }, - BatchRetrieveResponseConverter: func(ctx *context.Context, resp *schemas.BifrostBatchRetrieveResponse) (interface{}, error) { + BatchRetrieveResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchRetrieveResponse) (interface{}, error) { // Only return raw response for native Bedrock calls // For cross-provider routing, always convert to Bedrock format if resp.ExtraFields.RawResponse != nil && resp.ExtraFields.Provider == schemas.Bedrock { @@ -361,16 +360,16 @@ func createBedrockBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerS } return bedrock.ToBedrockBatchJobRetrieveResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return bedrock.ToBedrockError(err) }, - PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Extract provider from header for cross-provider routing provider := string(ctx.Request.Header.Peek("x-model-provider")) if provider != "" { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) } else { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.Bedrock) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.Bedrock) } return extractBedrockJobArnFromPath(handlerStore)(ctx, bifrostCtx, req) }, @@ -384,9 +383,9 @@ func createBedrockBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerS GetRequestTypeInstance: func() interface{} { return &bedrock.BedrockBatchCancelRequest{} }, - BatchRequestConverter: func(ctx *context.Context, req interface{}) (*BatchRequest, error) { + BatchRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*BatchRequest, error) { if bedrockReq, ok := req.(*bedrock.BedrockBatchCancelRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) bifrostReq := bedrock.ToBifrostBatchCancelRequest(bedrockReq, provider) return &BatchRequest{ Type: schemas.BatchCancelRequest, @@ -395,7 +394,7 @@ func createBedrockBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerS } return nil, errors.New("invalid batch cancel request type") }, - BatchCancelResponseConverter: func(ctx *context.Context, resp *schemas.BifrostBatchCancelResponse) (interface{}, error) { + BatchCancelResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchCancelResponse) (interface{}, error) { // Only return raw response for native Bedrock calls // For cross-provider routing, always convert to Bedrock format if resp.ExtraFields.RawResponse != nil && resp.ExtraFields.Provider == schemas.Bedrock { @@ -403,16 +402,16 @@ func createBedrockBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerS } return bedrock.ToBedrockBatchCancelResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return bedrock.ToBedrockError(err) }, - PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Extract provider from header for cross-provider routing provider := string(ctx.Request.Header.Peek("x-model-provider")) if provider != "" { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) } else { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.Bedrock) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.Bedrock) } return extractBedrockJobArnFromPath(handlerStore)(ctx, bifrostCtx, req) }, @@ -421,8 +420,8 @@ func createBedrockBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerS } // bedrockBatchPreCallback returns a pre-callback for Bedrock batch create requests -func bedrockBatchPreCallback(handlerStore lib.HandlerStore) func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { - return func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func bedrockBatchPreCallback(handlerStore lib.HandlerStore) func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { + return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Handle direct key authentication if allowed if !handlerStore.ShouldAllowDirectKeys() { return nil @@ -446,7 +445,7 @@ func bedrockBatchPreCallback(handlerStore lib.HandlerStore) func(ctx *fasthttp.R if region != "" { key.BedrockKeyConfig.Region = ®ion } - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyDirectKey, key) + bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key) return nil } @@ -469,7 +468,7 @@ func bedrockBatchPreCallback(handlerStore lib.HandlerStore) func(ctx *fasthttp.R key.BedrockKeyConfig.SessionToken = &sessionToken } - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyDirectKey, key) + bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key) } return nil @@ -478,7 +477,7 @@ func bedrockBatchPreCallback(handlerStore lib.HandlerStore) func(ctx *fasthttp.R // extractBedrockBatchListQueryParams extracts query parameters for Bedrock batch list requests func extractBedrockBatchListQueryParams(handlerStore lib.HandlerStore) PreRequestCallback { - return func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Handle authentication if err := bedrockBatchPreCallback(handlerStore)(ctx, bifrostCtx, req); err != nil { return err @@ -530,7 +529,7 @@ func parseS3URI(uri string) (bucket, key string) { // extractBedrockJobArnFromPath extracts job_arn from path parameters for Bedrock func extractBedrockJobArnFromPath(handlerStore lib.HandlerStore) PreRequestCallback { - return func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Handle authentication if err := bedrockBatchPreCallback(handlerStore)(ctx, bifrostCtx, req); err != nil { return err @@ -594,9 +593,9 @@ func createBedrockFilesRouteConfigs(pathPrefix string, handlerStore lib.HandlerS return &bedrock.BedrockFileUploadRequest{} }, RequestParser: parseS3PutObjectRequest, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if uploadReq, ok := req.(*bedrock.BedrockFileUploadRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) prefix := "" if uploadReq.Key != "" { keyComponents := strings.Split(uploadReq.Key, "/") @@ -620,19 +619,19 @@ func createBedrockFilesRouteConfigs(pathPrefix string, handlerStore lib.HandlerS } return nil, errors.New("invalid file upload request type") }, - FileUploadResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileUploadResponse) (interface{}, error) { + FileUploadResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileUploadResponse) (interface{}, error) { // S3 PutObject returns empty body with ETag header return nil, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return bedrock.ToS3ErrorXML("InternalError", err.Error.Message, "", "") }, - PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { provider := string(ctx.Request.Header.Peek("x-model-provider")) if provider != "" { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) } else { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.Bedrock) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.Bedrock) } return nil }, @@ -647,9 +646,9 @@ func createBedrockFilesRouteConfigs(pathPrefix string, handlerStore lib.HandlerS GetRequestTypeInstance: func() interface{} { return &bedrock.BedrockFileContentRequest{} }, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if contentReq, ok := req.(*bedrock.BedrockFileContentRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) return &FileRequest{ Type: schemas.FileContentRequest, ContentRequest: &schemas.BifrostFileContentRequest{ @@ -660,11 +659,11 @@ func createBedrockFilesRouteConfigs(pathPrefix string, handlerStore lib.HandlerS } return nil, errors.New("invalid file content request type") }, - FileContentResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileContentResponse) (interface{}, error) { + FileContentResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileContentResponse) (interface{}, error) { // Return raw content return resp.Content, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return bedrock.ToS3ErrorXML("NoSuchKey", err.Error.Message, "", "") }, PreCallback: extractS3BucketKeyFromPath(handlerStore, "content"), @@ -679,9 +678,9 @@ func createBedrockFilesRouteConfigs(pathPrefix string, handlerStore lib.HandlerS GetRequestTypeInstance: func() interface{} { return &bedrock.BedrockFileRetrieveRequest{} }, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if retrieveReq, ok := req.(*bedrock.BedrockFileRetrieveRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) return &FileRequest{ Type: schemas.FileRetrieveRequest, RetrieveRequest: &schemas.BifrostFileRetrieveRequest{ @@ -698,19 +697,19 @@ func createBedrockFilesRouteConfigs(pathPrefix string, handlerStore lib.HandlerS } return nil, errors.New("invalid file retrieve request type") }, - FileRetrieveResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileRetrieveResponse) (interface{}, error) { + FileRetrieveResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileRetrieveResponse) (interface{}, error) { // HEAD returns empty body, headers set in PostCallback return nil, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return nil // HEAD returns no body on error }, - PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { provider := string(ctx.Request.Header.Peek("x-model-provider")) if provider != "" { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) } else { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.Bedrock) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.Bedrock) } return extractS3BucketKeyFromPath(handlerStore, "retrieve")(ctx, bifrostCtx, req) }, @@ -725,9 +724,9 @@ func createBedrockFilesRouteConfigs(pathPrefix string, handlerStore lib.HandlerS GetRequestTypeInstance: func() interface{} { return &bedrock.BedrockFileDeleteRequest{} }, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if deleteReq, ok := req.(*bedrock.BedrockFileDeleteRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) return &FileRequest{ Type: schemas.FileDeleteRequest, DeleteRequest: &schemas.BifrostFileDeleteRequest{ @@ -744,19 +743,19 @@ func createBedrockFilesRouteConfigs(pathPrefix string, handlerStore lib.HandlerS } return nil, errors.New("invalid file delete request type") }, - FileDeleteResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileDeleteResponse) (interface{}, error) { + FileDeleteResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileDeleteResponse) (interface{}, error) { // S3 DeleteObject returns empty body return nil, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return bedrock.ToS3ErrorXML("InternalError", err.Error.Message, "", "") }, - PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { provider := string(ctx.Request.Header.Peek("x-model-provider")) if provider != "" { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) } else { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.Bedrock) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.Bedrock) } return extractS3BucketKeyFromPath(handlerStore, "delete")(ctx, bifrostCtx, req) }, @@ -771,9 +770,9 @@ func createBedrockFilesRouteConfigs(pathPrefix string, handlerStore lib.HandlerS GetRequestTypeInstance: func() interface{} { return &bedrock.BedrockFileListRequest{} }, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if listReq, ok := req.(*bedrock.BedrockFileListRequest); ok { - provider := (*ctx).Value(bifrostContextKeyProvider).(schemas.ModelProvider) + provider := ctx.Value(bifrostContextKeyProvider).(schemas.ModelProvider) return &FileRequest{ Type: schemas.FileListRequest, ListRequest: &schemas.BifrostFileListRequest{ @@ -789,7 +788,7 @@ func createBedrockFilesRouteConfigs(pathPrefix string, handlerStore lib.HandlerS } return nil, errors.New("invalid file list request type") }, - FileListResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileListResponse) (interface{}, error) { + FileListResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileListResponse) (interface{}, error) { // Use raw S3 XML response directly if available (passthrough from core provider) if resp.ExtraFields.RawResponse != nil { if rawBytes, ok := resp.ExtraFields.RawResponse.([]byte); ok { @@ -800,26 +799,26 @@ func createBedrockFilesRouteConfigs(pathPrefix string, handlerStore lib.HandlerS bucket := "" prefix := "" maxKeys := 1000 - if b := (*ctx).Value(s3ContextKeyBucket); b != nil { + if b := ctx.Value(s3ContextKeyBucket); b != nil { bucket = b.(string) } - if p := (*ctx).Value(s3ContextKeyPrefix); p != nil { + if p := ctx.Value(s3ContextKeyPrefix); p != nil { prefix = p.(string) } - if m := (*ctx).Value(s3ContextKeyMaxKeys); m != nil { + if m := ctx.Value(s3ContextKeyMaxKeys); m != nil { maxKeys = m.(int) } return bedrock.ToS3ListObjectsV2XML(resp, bucket, prefix, maxKeys), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return bedrock.ToS3ErrorXML("InternalError", err.Error.Message, "", "") }, - PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { provider := string(ctx.Request.Header.Peek("x-model-provider")) if provider != "" { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) } else { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.Bedrock) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.Bedrock) } return extractS3ListObjectsV2Params(handlerStore)(ctx, bifrostCtx, req) }, @@ -868,7 +867,7 @@ func parseS3PutObjectRequest(ctx *fasthttp.RequestCtx, req interface{}) error { // extractS3BucketKeyFromPath extracts bucket and key from path for S3 operations func extractS3BucketKeyFromPath(handlerStore lib.HandlerStore, opType string) PreRequestCallback { - return func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Handle authentication first if err := bedrockBatchPreCallback(handlerStore)(ctx, bifrostCtx, req); err != nil { return err @@ -883,9 +882,9 @@ func extractS3BucketKeyFromPath(handlerStore lib.HandlerStore, opType string) Pr provider := string(ctx.Request.Header.Peek("x-model-provider")) if provider != "" { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.ModelProvider(provider)) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.ModelProvider(provider)) } else { - *bifrostCtx = context.WithValue(*bifrostCtx, bifrostContextKeyProvider, schemas.Bedrock) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.Bedrock) } bucketStr := bucket.(string) @@ -923,7 +922,7 @@ func extractS3BucketKeyFromPath(handlerStore lib.HandlerStore, opType string) Pr // extractS3ListObjectsV2Params extracts query params for S3 ListObjectsV2 func extractS3ListObjectsV2Params(handlerStore lib.HandlerStore) PreRequestCallback { - return func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Handle authentication first if err := bedrockBatchPreCallback(handlerStore)(ctx, bifrostCtx, req); err != nil { return err @@ -948,9 +947,9 @@ func extractS3ListObjectsV2Params(handlerStore lib.HandlerStore) PreRequestCallb } // Store in context for response formatting - *bifrostCtx = context.WithValue(*bifrostCtx, s3ContextKeyBucket, bucketStr) - *bifrostCtx = context.WithValue(*bifrostCtx, s3ContextKeyPrefix, prefix) - *bifrostCtx = context.WithValue(*bifrostCtx, s3ContextKeyMaxKeys, maxKeys) + bifrostCtx.SetValue(s3ContextKeyBucket, bucketStr) + bifrostCtx.SetValue(s3ContextKeyPrefix, prefix) + bifrostCtx.SetValue(s3ContextKeyMaxKeys, maxKeys) if listReq, ok := req.(*bedrock.BedrockFileListRequest); ok { listReq.MaxKeys = maxKeys @@ -1015,8 +1014,8 @@ func s3ListObjectsV2PostCallback(ctx *fasthttp.RequestCtx, req interface{}, resp } // bedrockPreCallback returns a pre-callback that extracts model ID and handles direct authentication -func bedrockPreCallback(handlerStore lib.HandlerStore) func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { - return func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func bedrockPreCallback(handlerStore lib.HandlerStore) func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { + return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Extract modelId from path parameter modelIDVal := ctx.UserValue("modelId") if modelIDVal == nil { @@ -1085,7 +1084,7 @@ func bedrockPreCallback(handlerStore lib.HandlerStore) func(ctx *fasthttp.Reques if region != "" { key.BedrockKeyConfig.Region = ®ion } - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyDirectKey, key) + bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key) return nil } else if accessKey != "" && secretKey != "" { // Case 2: AWS Credentials Authentication @@ -1109,7 +1108,7 @@ func bedrockPreCallback(handlerStore lib.HandlerStore) func(ctx *fasthttp.Reques key.BedrockKeyConfig.SessionToken = &sessionToken } - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyDirectKey, key) + bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key) } return nil diff --git a/transports/bifrost-http/integrations/bedrock_test.go b/transports/bifrost-http/integrations/bedrock_test.go index 7894eb4206..84dc88f1a4 100644 --- a/transports/bifrost-http/integrations/bedrock_test.go +++ b/transports/bifrost-http/integrations/bedrock_test.go @@ -517,7 +517,7 @@ func Test_extractBedrockBatchListQueryParams(t *testing.T) { callback := extractBedrockBatchListQueryParams(handlerStore) bifrostCtx := createTestBifrostContext() - err := callback(ctx, &bifrostCtx, req) + err := callback(ctx, bifrostCtx, req) assert.NoError(t, err) assert.Equal(t, tt.wantMaxResults, req.MaxResults) @@ -591,7 +591,7 @@ func Test_extractBedrockJobArnFromPath(t *testing.T) { callback := extractBedrockJobArnFromPath(handlerStore) bifrostCtx := createTestBifrostContextWithProvider(tt.provider) - err := callback(ctx, &bifrostCtx, req) + err := callback(ctx, bifrostCtx, req) if tt.wantErr { assert.Error(t, err) @@ -664,7 +664,7 @@ func Test_extractS3ListObjectsV2Params(t *testing.T) { callback := extractS3ListObjectsV2Params(handlerStore) bifrostCtx := createTestBifrostContext() - err := callback(ctx, &bifrostCtx, req) + err := callback(ctx, bifrostCtx, req) if tt.wantErr { assert.Error(t, err) @@ -752,7 +752,7 @@ func Test_extractS3BucketKeyFromPath(t *testing.T) { req = &bedrock.BedrockFileDeleteRequest{} } - err := callback(ctx, &bifrostCtx, req) + err := callback(ctx, bifrostCtx, req) if tt.wantErr { assert.Error(t, err) @@ -774,10 +774,14 @@ func Test_extractS3BucketKeyFromPath(t *testing.T) { // Helper functions for creating test contexts -func createTestBifrostContext() context.Context { - return context.WithValue(context.Background(), bifrostContextKeyProvider, schemas.Bedrock) +func createTestBifrostContext() *schemas.BifrostContext { + bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.Bedrock) + return bifrostCtx } -func createTestBifrostContextWithProvider(provider schemas.ModelProvider) context.Context { - return context.WithValue(context.Background(), bifrostContextKeyProvider, provider) +func createTestBifrostContextWithProvider(provider schemas.ModelProvider) *schemas.BifrostContext { + bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + bifrostCtx.SetValue(bifrostContextKeyProvider, provider) + return bifrostCtx } diff --git a/transports/bifrost-http/integrations/cohere.go b/transports/bifrost-http/integrations/cohere.go index a9d0213e99..f7b66b0ef5 100644 --- a/transports/bifrost-http/integrations/cohere.go +++ b/transports/bifrost-http/integrations/cohere.go @@ -1,7 +1,6 @@ package integrations import ( - "context" "errors" bifrost "github.com/maximhq/bifrost/core" @@ -34,7 +33,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func() interface{} { return &cohere.CohereChatRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if cohereReq, ok := req.(*cohere.CohereChatRequest); ok { return &schemas.BifrostRequest{ ChatRequest: cohereReq.ToBifrostChatRequest(), @@ -42,7 +41,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { } return nil, errors.New("invalid request type") }, - ChatResponseConverter: func(ctx *context.Context, resp *schemas.BifrostChatResponse) (interface{}, error) { + ChatResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostChatResponse) (interface{}, error) { if resp.ExtraFields.Provider == schemas.Cohere { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil @@ -50,11 +49,11 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, StreamConfig: &StreamConfig{ - ChatStreamResponseConverter: func(ctx *context.Context, resp *schemas.BifrostChatResponse) (string, interface{}, error) { + ChatStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostChatResponse) (string, interface{}, error) { if resp.ExtraFields.Provider == schemas.Cohere { if resp.ExtraFields.RawResponse != nil { return "", resp.ExtraFields.RawResponse, nil @@ -62,7 +61,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { } return "", resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, }, @@ -75,7 +74,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func() interface{} { return &cohere.CohereEmbeddingRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if cohereReq, ok := req.(*cohere.CohereEmbeddingRequest); ok { return &schemas.BifrostRequest{ EmbeddingRequest: cohereReq.ToBifrostEmbeddingRequest(), @@ -83,7 +82,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { } return nil, errors.New("invalid embedding request type") }, - EmbeddingResponseConverter: func(ctx *context.Context, resp *schemas.BifrostEmbeddingResponse) (interface{}, error) { + EmbeddingResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostEmbeddingResponse) (interface{}, error) { if resp.ExtraFields.Provider == schemas.Cohere { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil @@ -91,7 +90,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, }) @@ -103,7 +102,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func() interface{} { return &cohere.CohereCountTokensRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if cohereReq, ok := req.(*cohere.CohereCountTokensRequest); ok { return &schemas.BifrostRequest{ CountTokensRequest: cohereReq.ToBifrostResponsesRequest(), @@ -111,7 +110,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { } return nil, errors.New("invalid count tokens request type") }, - CountTokensResponseConverter: func(ctx *context.Context, resp *schemas.BifrostCountTokensResponse) (interface{}, error) { + CountTokensResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostCountTokensResponse) (interface{}, error) { if resp.ExtraFields.Provider == schemas.Cohere { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil @@ -119,7 +118,7 @@ func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig { } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, }) diff --git a/transports/bifrost-http/integrations/genai.go b/transports/bifrost-http/integrations/genai.go index 86f9be81a2..38d03344e6 100644 --- a/transports/bifrost-http/integrations/genai.go +++ b/transports/bifrost-http/integrations/genai.go @@ -1,7 +1,6 @@ package integrations import ( - "context" "errors" "fmt" "io" @@ -34,7 +33,7 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func() interface{} { return &gemini.GeminiGenerationRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if geminiReq, ok := req.(*gemini.GeminiGenerationRequest); ok { if geminiReq.IsCountTokens { return &schemas.BifrostRequest{ @@ -60,26 +59,26 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { } return nil, errors.New("invalid request type") }, - EmbeddingResponseConverter: func(ctx *context.Context, resp *schemas.BifrostEmbeddingResponse) (interface{}, error) { + EmbeddingResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostEmbeddingResponse) (interface{}, error) { return gemini.ToGeminiEmbeddingResponse(resp), nil }, - ResponsesResponseConverter: func(ctx *context.Context, resp *schemas.BifrostResponsesResponse) (interface{}, error) { + ResponsesResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesResponse) (interface{}, error) { return gemini.ToGeminiResponsesResponse(resp), nil }, - SpeechResponseConverter: func(ctx *context.Context, resp *schemas.BifrostSpeechResponse) (interface{}, error) { + SpeechResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostSpeechResponse) (interface{}, error) { return gemini.ToGeminiSpeechResponse(resp), nil }, - TranscriptionResponseConverter: func(ctx *context.Context, resp *schemas.BifrostTranscriptionResponse) (interface{}, error) { + TranscriptionResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTranscriptionResponse) (interface{}, error) { return gemini.ToGeminiTranscriptionResponse(resp), nil }, - CountTokensResponseConverter: func(ctx *context.Context, resp *schemas.BifrostCountTokensResponse) (interface{}, error) { + CountTokensResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostCountTokensResponse) (interface{}, error) { return gemini.ToGeminiCountTokensResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return gemini.ToGeminiError(err) }, StreamConfig: &StreamConfig{ - ResponsesStreamResponseConverter: func(ctx *context.Context, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) { + ResponsesStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) { geminiResponse := gemini.ToGeminiResponsesStreamResponse(resp) // Skip lifecycle events with no Gemini equivalent if geminiResponse == nil { @@ -87,7 +86,7 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { } return "", geminiResponse, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return gemini.ToGeminiError(err) }, }, @@ -101,7 +100,7 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { GetRequestTypeInstance: func() interface{} { return &schemas.BifrostListModelsRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if listModelsReq, ok := req.(*schemas.BifrostListModelsRequest); ok { return &schemas.BifrostRequest{ ListModelsRequest: listModelsReq, @@ -109,10 +108,10 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { } return nil, errors.New("invalid request type") }, - ListModelsResponseConverter: func(ctx *context.Context, resp *schemas.BifrostListModelsResponse) (interface{}, error) { + ListModelsResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostListModelsResponse) (interface{}, error) { return gemini.ToGeminiListModelsResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return gemini.ToGeminiError(err) }, PreCallback: extractGeminiListModelsParams, @@ -134,7 +133,7 @@ func CreateGenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerStor return &schemas.BifrostFileUploadRequest{} }, RequestParser: parseGeminiFileUploadRequest, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if uploadReq, ok := req.(*schemas.BifrostFileUploadRequest); ok { uploadReq.Provider = schemas.Gemini return &FileRequest{ @@ -144,13 +143,13 @@ func CreateGenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerStor } return nil, errors.New("invalid file upload request type") }, - FileUploadResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileUploadResponse) (interface{}, error) { + FileUploadResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileUploadResponse) (interface{}, error) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } return gemini.ToGeminiFileUploadResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return gemini.ToGeminiError(err) }, }) @@ -163,7 +162,7 @@ func CreateGenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerStor GetRequestTypeInstance: func() interface{} { return &schemas.BifrostFileListRequest{} }, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if listReq, ok := req.(*schemas.BifrostFileListRequest); ok { listReq.Provider = schemas.Gemini return &FileRequest{ @@ -173,13 +172,13 @@ func CreateGenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerStor } return nil, errors.New("invalid file list request type") }, - FileListResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileListResponse) (interface{}, error) { + FileListResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileListResponse) (interface{}, error) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } return gemini.ToGeminiFileListResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return gemini.ToGeminiError(err) }, PreCallback: extractGeminiFileListQueryParams, @@ -193,7 +192,7 @@ func CreateGenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerStor GetRequestTypeInstance: func() interface{} { return &schemas.BifrostFileRetrieveRequest{} }, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if retrieveReq, ok := req.(*schemas.BifrostFileRetrieveRequest); ok { retrieveReq.Provider = schemas.Gemini return &FileRequest{ @@ -203,13 +202,13 @@ func CreateGenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerStor } return nil, errors.New("invalid file retrieve request type") }, - FileRetrieveResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileRetrieveResponse) (interface{}, error) { + FileRetrieveResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileRetrieveResponse) (interface{}, error) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } return gemini.ToGeminiFileRetrieveResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return gemini.ToGeminiError(err) }, PreCallback: extractGeminiFileIDFromPath, @@ -223,7 +222,7 @@ func CreateGenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerStor GetRequestTypeInstance: func() interface{} { return &schemas.BifrostFileDeleteRequest{} }, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if deleteReq, ok := req.(*schemas.BifrostFileDeleteRequest); ok { deleteReq.Provider = schemas.Gemini return &FileRequest{ @@ -233,13 +232,13 @@ func CreateGenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerStor } return nil, errors.New("invalid file delete request type") }, - FileDeleteResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileDeleteResponse) (interface{}, error) { + FileDeleteResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileDeleteResponse) (interface{}, error) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } return map[string]interface{}{}, nil // Gemini returns empty response on delete }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return gemini.ToGeminiError(err) }, PreCallback: extractGeminiFileIDFromPath, @@ -293,7 +292,7 @@ func parseGeminiFileUploadRequest(ctx *fasthttp.RequestCtx, req interface{}) err } // extractGeminiFileListQueryParams extracts query parameters for Gemini file list requests -func extractGeminiFileListQueryParams(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func extractGeminiFileListQueryParams(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { if listReq, ok := req.(*schemas.BifrostFileListRequest); ok { listReq.Provider = schemas.Gemini @@ -314,7 +313,7 @@ func extractGeminiFileListQueryParams(ctx *fasthttp.RequestCtx, bifrostCtx *cont } // extractGeminiFileIDFromPath extracts file_id from path parameters for Gemini -func extractGeminiFileIDFromPath(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func extractGeminiFileIDFromPath(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { fileID := ctx.UserValue("file_id") if fileID == nil { return errors.New("file_id is required") @@ -354,7 +353,7 @@ var embeddingPaths = []string{ } // extractAndSetModelFromURL extracts model from URL and sets it in the request -func extractAndSetModelFromURL(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func extractAndSetModelFromURL(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { model := ctx.UserValue("model") if model == nil { return fmt.Errorf("model parameter is required") @@ -482,7 +481,7 @@ func isAudioMimeType(mimeType string) bool { } // extractGeminiListModelsParams extracts query parameters for list models request -func extractGeminiListModelsParams(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func extractGeminiListModelsParams(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { if listModelsReq, ok := req.(*schemas.BifrostListModelsRequest); ok { // Set provider to Gemini listModelsReq.Provider = schemas.Gemini diff --git a/transports/bifrost-http/integrations/openai.go b/transports/bifrost-http/integrations/openai.go index f243e7616d..99bedcdf7f 100644 --- a/transports/bifrost-http/integrations/openai.go +++ b/transports/bifrost-http/integrations/openai.go @@ -1,7 +1,6 @@ package integrations import ( - "context" "encoding/base64" "encoding/json" "errors" @@ -36,8 +35,8 @@ type OpenAIRouter struct { *GenericRouter } -func AzureEndpointPreHook(handlerStore lib.HandlerStore) func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { - return func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { +func AzureEndpointPreHook(handlerStore lib.HandlerStore) func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { + return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { azureKey := ctx.Request.Header.Peek("authorization") deploymentEndpoint := ctx.Request.Header.Peek("x-bf-azure-endpoint") deploymentID := ctx.UserValue("deployment-id") @@ -114,7 +113,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func() interface{} { return &openai.OpenAITextCompletionRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if openaiReq, ok := req.(*openai.OpenAITextCompletionRequest); ok { return &schemas.BifrostRequest{ TextCompletionRequest: openaiReq.ToBifrostTextCompletionRequest(), @@ -122,7 +121,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return nil, errors.New("invalid request type") }, - TextResponseConverter: func(ctx *context.Context, resp *schemas.BifrostTextCompletionResponse) (interface{}, error) { + TextResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTextCompletionResponse) (interface{}, error) { if resp.ExtraFields.Provider == schemas.OpenAI { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil @@ -130,11 +129,11 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, StreamConfig: &StreamConfig{ - TextStreamResponseConverter: func(ctx *context.Context, resp *schemas.BifrostTextCompletionResponse) (string, interface{}, error) { + TextStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTextCompletionResponse) (string, interface{}, error) { if resp.ExtraFields.Provider == schemas.OpenAI { if resp.ExtraFields.RawResponse != nil { return "", resp.ExtraFields.RawResponse, nil @@ -142,7 +141,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return "", resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, }, @@ -163,7 +162,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func() interface{} { return &openai.OpenAIChatRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if openaiReq, ok := req.(*openai.OpenAIChatRequest); ok { return &schemas.BifrostRequest{ ChatRequest: openaiReq.ToBifrostChatRequest(), @@ -171,7 +170,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return nil, errors.New("invalid request type") }, - ChatResponseConverter: func(ctx *context.Context, resp *schemas.BifrostChatResponse) (interface{}, error) { + ChatResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostChatResponse) (interface{}, error) { if resp.ExtraFields.Provider == schemas.OpenAI { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil @@ -179,11 +178,11 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, StreamConfig: &StreamConfig{ - ChatStreamResponseConverter: func(ctx *context.Context, resp *schemas.BifrostChatResponse) (string, interface{}, error) { + ChatStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostChatResponse) (string, interface{}, error) { if resp.ExtraFields.Provider == schemas.OpenAI { if resp.ExtraFields.RawResponse != nil { return "", resp.ExtraFields.RawResponse, nil @@ -191,7 +190,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return "", resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, }, @@ -212,7 +211,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func() interface{} { return &openai.OpenAIResponsesRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if openaiReq, ok := req.(*openai.OpenAIResponsesRequest); ok { return &schemas.BifrostRequest{ ResponsesRequest: openaiReq.ToBifrostResponsesRequest(), @@ -221,7 +220,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return nil, errors.New("invalid request type") }, - ResponsesResponseConverter: func(ctx *context.Context, resp *schemas.BifrostResponsesResponse) (interface{}, error) { + ResponsesResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesResponse) (interface{}, error) { if resp.ExtraFields.Provider == schemas.OpenAI { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil @@ -229,11 +228,11 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, StreamConfig: &StreamConfig{ - ResponsesStreamResponseConverter: func(ctx *context.Context, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) { + ResponsesStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) { if resp.ExtraFields.Provider == schemas.OpenAI { if resp.ExtraFields.RawResponse != nil { return string(resp.Type), resp.ExtraFields.RawResponse, nil @@ -241,7 +240,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return string(resp.Type), resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, }, @@ -262,7 +261,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func() interface{} { return &openai.OpenAIResponsesRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if openaiReq, ok := req.(*openai.OpenAIResponsesRequest); ok { return &schemas.BifrostRequest{ CountTokensRequest: openaiReq.ToBifrostResponsesRequest(), @@ -270,7 +269,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return nil, errors.New("invalid request type for input tokens") }, - CountTokensResponseConverter: func(ctx *context.Context, resp *schemas.BifrostCountTokensResponse) (interface{}, error) { + CountTokensResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostCountTokensResponse) (interface{}, error) { if resp.ExtraFields.Provider == schemas.OpenAI { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil @@ -278,7 +277,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, PreCallback: AzureEndpointPreHook(handlerStore), @@ -298,7 +297,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func() interface{} { return &openai.OpenAIEmbeddingRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if embeddingReq, ok := req.(*openai.OpenAIEmbeddingRequest); ok { return &schemas.BifrostRequest{ EmbeddingRequest: embeddingReq.ToBifrostEmbeddingRequest(), @@ -306,7 +305,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return nil, errors.New("invalid embedding request type") }, - EmbeddingResponseConverter: func(ctx *context.Context, resp *schemas.BifrostEmbeddingResponse) (interface{}, error) { + EmbeddingResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostEmbeddingResponse) (interface{}, error) { if resp.ExtraFields.Provider == schemas.OpenAI { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil @@ -314,7 +313,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, PreCallback: AzureEndpointPreHook(handlerStore), @@ -334,7 +333,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) GetRequestTypeInstance: func() interface{} { return &openai.OpenAISpeechRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if speechReq, ok := req.(*openai.OpenAISpeechRequest); ok { return &schemas.BifrostRequest{ SpeechRequest: speechReq.ToBifrostSpeechRequest(), @@ -342,11 +341,11 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return nil, errors.New("invalid speech request type") }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, StreamConfig: &StreamConfig{ - SpeechStreamResponseConverter: func(ctx *context.Context, resp *schemas.BifrostSpeechStreamResponse) (string, interface{}, error) { + SpeechStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostSpeechStreamResponse) (string, interface{}, error) { if resp.ExtraFields.Provider == schemas.OpenAI { if resp.ExtraFields.RawResponse != nil { return "", resp.ExtraFields.RawResponse, nil @@ -354,7 +353,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return "", resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, }, @@ -376,7 +375,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) return &openai.OpenAITranscriptionRequest{} }, RequestParser: parseTranscriptionMultipartRequest, // Handle multipart form parsing - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if transcriptionReq, ok := req.(*openai.OpenAITranscriptionRequest); ok { return &schemas.BifrostRequest{ TranscriptionRequest: transcriptionReq.ToBifrostTranscriptionRequest(), @@ -384,7 +383,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return nil, errors.New("invalid transcription request type") }, - TranscriptionResponseConverter: func(ctx *context.Context, resp *schemas.BifrostTranscriptionResponse) (interface{}, error) { + TranscriptionResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTranscriptionResponse) (interface{}, error) { if resp.ExtraFields.Provider == schemas.OpenAI { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil @@ -392,11 +391,11 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, StreamConfig: &StreamConfig{ - TranscriptionStreamResponseConverter: func(ctx *context.Context, resp *schemas.BifrostTranscriptionStreamResponse) (string, interface{}, error) { + TranscriptionStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTranscriptionStreamResponse) (string, interface{}, error) { if resp.ExtraFields.Provider == schemas.OpenAI { if resp.ExtraFields.RawResponse != nil { return "", resp.ExtraFields.RawResponse, nil @@ -404,7 +403,7 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) } return "", resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, }, @@ -431,7 +430,7 @@ func CreateOpenAIListModelsRouteConfigs(pathPrefix string, handlerStore lib.Hand GetRequestTypeInstance: func() interface{} { return &schemas.BifrostListModelsRequest{} }, - RequestConverter: func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) { + RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if listModelsReq, ok := req.(*schemas.BifrostListModelsRequest); ok { return &schemas.BifrostRequest{ ListModelsRequest: listModelsReq, @@ -439,10 +438,10 @@ func CreateOpenAIListModelsRouteConfigs(pathPrefix string, handlerStore lib.Hand } return nil, errors.New("invalid request type") }, - ListModelsResponseConverter: func(ctx *context.Context, resp *schemas.BifrostListModelsResponse) (interface{}, error) { + ListModelsResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostListModelsResponse) (interface{}, error) { return openai.ToOpenAIListModelsResponse(resp), nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, PreCallback: setQueryParamsAndAzureEndpointPreHook(handlerStore), @@ -457,7 +456,7 @@ func CreateOpenAIListModelsRouteConfigs(pathPrefix string, handlerStore lib.Hand func setQueryParamsAndAzureEndpointPreHook(handlerStore lib.HandlerStore) PreRequestCallback { azureHook := AzureEndpointPreHook(handlerStore) - return func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // First run the Azure endpoint pre-hook if needed if azureHook != nil { if err := azureHook(ctx, bifrostCtx, req); err != nil { @@ -495,7 +494,7 @@ func CreateOpenAIBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerSt GetRequestTypeInstance: func() interface{} { return &schemas.BifrostBatchCreateRequest{} }, - BatchRequestConverter: func(ctx *context.Context, req interface{}) (*BatchRequest, error) { + BatchRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*BatchRequest, error) { if openaiReq, ok := req.(*schemas.BifrostBatchCreateRequest); ok { switch openaiReq.Provider { case schemas.Gemini: @@ -517,7 +516,7 @@ func CreateOpenAIBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerSt } return nil, errors.New("invalid batch create request type") }, - BatchCreateResponseConverter: func(ctx *context.Context, resp *schemas.BifrostBatchCreateResponse) (interface{}, error) { + BatchCreateResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchCreateResponse) (interface{}, error) { switch resp.ExtraFields.Provider { case schemas.Gemini: resp.ID = strings.Replace(resp.ID, "batches/", "batches-", 1) @@ -528,10 +527,10 @@ func CreateOpenAIBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerSt } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, - PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Provider is parsed from JSON body (extra_body), default to OpenAI if not set if createReq, ok := req.(*schemas.BifrostBatchCreateRequest); ok { if createReq.Provider == "" { @@ -604,7 +603,7 @@ func CreateOpenAIBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerSt GetRequestTypeInstance: func() interface{} { return &schemas.BifrostBatchListRequest{} }, - BatchRequestConverter: func(ctx *context.Context, req interface{}) (*BatchRequest, error) { + BatchRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*BatchRequest, error) { if listReq, ok := req.(*schemas.BifrostBatchListRequest); ok { if listReq.Provider == "" { listReq.Provider = schemas.OpenAI @@ -616,7 +615,7 @@ func CreateOpenAIBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerSt } return nil, errors.New("invalid batch list request type") }, - BatchListResponseConverter: func(ctx *context.Context, resp *schemas.BifrostBatchListResponse) (interface{}, error) { + BatchListResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchListResponse) (interface{}, error) { switch resp.ExtraFields.Provider { case schemas.Gemini: for i, batch := range resp.Data { @@ -631,7 +630,7 @@ func CreateOpenAIBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerSt } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, PreCallback: extractBatchListQueryParams(handlerStore), @@ -650,7 +649,7 @@ func CreateOpenAIBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerSt GetRequestTypeInstance: func() interface{} { return &schemas.BifrostBatchRetrieveRequest{} }, - BatchRequestConverter: func(ctx *context.Context, req interface{}) (*BatchRequest, error) { + BatchRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*BatchRequest, error) { if retrieveReq, ok := req.(*schemas.BifrostBatchRetrieveRequest); ok { if retrieveReq.Provider == "" { retrieveReq.Provider = schemas.OpenAI @@ -671,7 +670,7 @@ func CreateOpenAIBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerSt } return nil, errors.New("invalid batch retrieve request type") }, - BatchRetrieveResponseConverter: func(ctx *context.Context, resp *schemas.BifrostBatchRetrieveResponse) (interface{}, error) { + BatchRetrieveResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchRetrieveResponse) (interface{}, error) { switch resp.ExtraFields.Provider { case schemas.Gemini: resp.ID = strings.Replace(resp.ID, "batches/", "batches-", 1) @@ -682,7 +681,7 @@ func CreateOpenAIBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerSt } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, PreCallback: extractBatchIDFromPath(handlerStore), @@ -701,7 +700,7 @@ func CreateOpenAIBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerSt GetRequestTypeInstance: func() interface{} { return &schemas.BifrostBatchCancelRequest{} }, - BatchRequestConverter: func(ctx *context.Context, req interface{}) (*BatchRequest, error) { + BatchRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*BatchRequest, error) { if cancelReq, ok := req.(*schemas.BifrostBatchCancelRequest); ok { if cancelReq.Provider == "" { cancelReq.Provider = schemas.OpenAI @@ -722,7 +721,7 @@ func CreateOpenAIBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerSt } return nil, errors.New("invalid batch cancel request type") }, - BatchCancelResponseConverter: func(ctx *context.Context, resp *schemas.BifrostBatchCancelResponse) (interface{}, error) { + BatchCancelResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchCancelResponse) (interface{}, error) { switch resp.ExtraFields.Provider { case schemas.Gemini: resp.ID = strings.Replace(resp.ID, "batches/", "batches-", 1) @@ -731,7 +730,7 @@ func CreateOpenAIBatchRouteConfigs(pathPrefix string, handlerStore lib.HandlerSt } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, PreCallback: extractBatchIDFromPath(handlerStore), @@ -757,7 +756,7 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto return &schemas.BifrostFileUploadRequest{} }, RequestParser: parseOpenAIFileUploadMultipartRequest, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if uploadReq, ok := req.(*schemas.BifrostFileUploadRequest); ok { return &FileRequest{ Type: schemas.FileUploadRequest, @@ -766,7 +765,7 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto } return nil, errors.New("invalid file upload request type") }, - FileUploadResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileUploadResponse) (interface{}, error) { + FileUploadResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileUploadResponse) (interface{}, error) { if resp.ExtraFields.RawResponse != nil && resp.ExtraFields.Provider == schemas.OpenAI { return resp.ExtraFields.RawResponse, nil } @@ -780,10 +779,10 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, - PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + PreCallback: func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { // Default to OpenAI if provider not set from extra_body if bifrostReq, ok := req.(*schemas.BifrostFileUploadRequest); ok { if bifrostReq.Provider == "" { @@ -807,7 +806,7 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto GetRequestTypeInstance: func() interface{} { return &schemas.BifrostFileListRequest{} }, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if listReq, ok := req.(*schemas.BifrostFileListRequest); ok { if listReq.Provider == "" { listReq.Provider = schemas.OpenAI @@ -819,7 +818,7 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto } return nil, errors.New("invalid file list request type") }, - FileListResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileListResponse) (interface{}, error) { + FileListResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileListResponse) (interface{}, error) { if resp.ExtraFields.RawResponse != nil && resp.ExtraFields.Provider == schemas.OpenAI { return resp.ExtraFields.RawResponse, nil } @@ -835,7 +834,7 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, PreCallback: extractFileListQueryParams(handlerStore), @@ -854,7 +853,7 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto GetRequestTypeInstance: func() interface{} { return &schemas.BifrostFileRetrieveRequest{} }, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if retrieveReq, ok := req.(*schemas.BifrostFileRetrieveRequest); ok { if retrieveReq.Provider == "" { retrieveReq.Provider = schemas.OpenAI @@ -869,7 +868,7 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto } return nil, errors.New("invalid file content request type") }, - FileRetrieveResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileRetrieveResponse) (interface{}, error) { + FileRetrieveResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileRetrieveResponse) (interface{}, error) { // Raw response is invalid even for OpenAI switch resp.ExtraFields.Provider { case schemas.Gemini: @@ -881,7 +880,7 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, PreCallback: extractFileIDFromPath(handlerStore), @@ -900,7 +899,7 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto GetRequestTypeInstance: func() interface{} { return &schemas.BifrostFileDeleteRequest{} }, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if deleteReq, ok := req.(*schemas.BifrostFileDeleteRequest); ok { if deleteReq.Provider == "" { deleteReq.Provider = schemas.OpenAI @@ -915,7 +914,7 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto } return nil, errors.New("invalid file delete request type") }, - FileDeleteResponseConverter: func(ctx *context.Context, resp *schemas.BifrostFileDeleteResponse) (interface{}, error) { + FileDeleteResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileDeleteResponse) (interface{}, error) { if resp.ExtraFields.RawResponse != nil && resp.ExtraFields.Provider == schemas.OpenAI { return resp.ExtraFields.RawResponse, nil } @@ -929,7 +928,7 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto } return resp, nil }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, PreCallback: extractFileIDFromPath(handlerStore), @@ -948,7 +947,7 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto GetRequestTypeInstance: func() interface{} { return &schemas.BifrostFileContentRequest{} }, - FileRequestConverter: func(ctx *context.Context, req interface{}) (*FileRequest, error) { + FileRequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) { if contentReq, ok := req.(*schemas.BifrostFileContentRequest); ok { if contentReq.Provider == "" { contentReq.Provider = schemas.OpenAI @@ -964,7 +963,7 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto } return nil, errors.New("invalid file content request type") }, - ErrorConverter: func(ctx *context.Context, err *schemas.BifrostError) interface{} { + ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return err }, PreCallback: extractFileIDFromPath(handlerStore), @@ -978,7 +977,7 @@ func CreateOpenAIFileRouteConfigs(pathPrefix string, handlerStore lib.HandlerSto func extractBatchListQueryParams(handlerStore lib.HandlerStore) PreRequestCallback { azureHook := AzureEndpointPreHook(handlerStore) - return func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { if azureHook != nil { if err := azureHook(ctx, bifrostCtx, req); err != nil { return err @@ -1017,7 +1016,7 @@ func extractBatchListQueryParams(handlerStore lib.HandlerStore) PreRequestCallba func extractBatchIDFromPath(handlerStore lib.HandlerStore) PreRequestCallback { azureHook := AzureEndpointPreHook(handlerStore) - return func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { if azureHook != nil { if err := azureHook(ctx, bifrostCtx, req); err != nil { return err @@ -1062,7 +1061,7 @@ func extractBatchIDFromPath(handlerStore lib.HandlerStore) PreRequestCallback { func extractFileListQueryParams(handlerStore lib.HandlerStore) PreRequestCallback { azureHook := AzureEndpointPreHook(handlerStore) - return func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { if azureHook != nil { if err := azureHook(ctx, bifrostCtx, req); err != nil { return err @@ -1141,7 +1140,7 @@ func extractFileListQueryParams(handlerStore lib.HandlerStore) PreRequestCallbac func extractFileIDFromPath(handlerStore lib.HandlerStore) PreRequestCallback { azureHook := AzureEndpointPreHook(handlerStore) - return func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error { + return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { if azureHook != nil { if err := azureHook(ctx, bifrostCtx, req); err != nil { return err diff --git a/transports/bifrost-http/integrations/router.go b/transports/bifrost-http/integrations/router.go index 27cf1dca11..7d86e22e4c 100644 --- a/transports/bifrost-http/integrations/router.go +++ b/transports/bifrost-http/integrations/router.go @@ -69,7 +69,7 @@ import ( // ExtensionRouter defines the interface that all integration routers must implement // to register their routes with the main HTTP router. type ExtensionRouter interface { - RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) + RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) } // StreamingRequest interface for requests that support streaming @@ -98,116 +98,115 @@ type FileRequest struct { } // BatchRequestConverter is a function that converts integration-specific batch requests to Bifrost format. -type BatchRequestConverter func(ctx *context.Context, req interface{}) (*BatchRequest, error) +type BatchRequestConverter func(ctx *schemas.BifrostContext, req interface{}) (*BatchRequest, error) // FileRequestConverter is a function that converts integration-specific file requests to Bifrost format. -type FileRequestConverter func(ctx *context.Context, req interface{}) (*FileRequest, error) +type FileRequestConverter func(ctx *schemas.BifrostContext, req interface{}) (*FileRequest, error) // RequestConverter is a function that converts integration-specific requests to Bifrost format. // It takes the parsed request object and returns a BifrostRequest ready for processing. -type RequestConverter func(ctx *context.Context, req interface{}) (*schemas.BifrostRequest, error) +type RequestConverter func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) // ListModelsResponseConverter is a function that converts BifrostListModelsResponse to integration-specific format. // It takes a BifrostListModelsResponse and returns the format expected by the specific integration. -type ListModelsResponseConverter func(ctx *context.Context, resp *schemas.BifrostListModelsResponse) (interface{}, error) +type ListModelsResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostListModelsResponse) (interface{}, error) // TextResponseConverter is a function that converts BifrostTextCompletionResponse to integration-specific format. // It takes a BifrostTextCompletionResponse and returns the format expected by the specific integration. -type TextResponseConverter func(ctx *context.Context, resp *schemas.BifrostTextCompletionResponse) (interface{}, error) +type TextResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostTextCompletionResponse) (interface{}, error) // ChatResponseConverter is a function that converts BifrostChatResponse to integration-specific format. // It takes a BifrostChatResponse and returns the format expected by the specific integration. -type ChatResponseConverter func(ctx *context.Context, resp *schemas.BifrostChatResponse) (interface{}, error) +type ChatResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostChatResponse) (interface{}, error) // ResponsesResponseConverter is a function that converts BifrostResponsesResponse to integration-specific format. // It takes a BifrostResponsesResponse and returns the format expected by the specific integration. -type ResponsesResponseConverter func(ctx *context.Context, resp *schemas.BifrostResponsesResponse) (interface{}, error) +type ResponsesResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesResponse) (interface{}, error) // EmbeddingResponseConverter is a function that converts BifrostEmbeddingResponse to integration-specific format. // It takes a BifrostEmbeddingResponse and returns the format expected by the specific integration. -type EmbeddingResponseConverter func(ctx *context.Context, resp *schemas.BifrostEmbeddingResponse) (interface{}, error) +type EmbeddingResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostEmbeddingResponse) (interface{}, error) // SpeechResponseConverter is a function that converts BifrostSpeechResponse to integration-specific format. // It takes a BifrostSpeechResponse and returns the format expected by the specific integration. -type SpeechResponseConverter func(ctx *context.Context, resp *schemas.BifrostSpeechResponse) (interface{}, error) +type SpeechResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostSpeechResponse) (interface{}, error) // TranscriptionResponseConverter is a function that converts BifrostTranscriptionResponse to integration-specific format. // It takes a BifrostTranscriptionResponse and returns the format expected by the specific integration. -type TranscriptionResponseConverter func(ctx *context.Context, resp *schemas.BifrostTranscriptionResponse) (interface{}, error) +type TranscriptionResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostTranscriptionResponse) (interface{}, error) // BatchCreateResponseConverter is a function that converts BifrostBatchCreateResponse to integration-specific format. // It takes a BifrostBatchCreateResponse and returns the format expected by the specific integration. -type BatchCreateResponseConverter func(ctx *context.Context, resp *schemas.BifrostBatchCreateResponse) (interface{}, error) +type BatchCreateResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchCreateResponse) (interface{}, error) // BatchListResponseConverter is a function that converts BifrostBatchListResponse to integration-specific format. // It takes a BifrostBatchListResponse and returns the format expected by the specific integration. -type BatchListResponseConverter func(ctx *context.Context, resp *schemas.BifrostBatchListResponse) (interface{}, error) +type BatchListResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchListResponse) (interface{}, error) // BatchRetrieveResponseConverter is a function that converts BifrostBatchRetrieveResponse to integration-specific format. // It takes a BifrostBatchRetrieveResponse and returns the format expected by the specific integration. -type BatchRetrieveResponseConverter func(ctx *context.Context, resp *schemas.BifrostBatchRetrieveResponse) (interface{}, error) +type BatchRetrieveResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchRetrieveResponse) (interface{}, error) // BatchCancelResponseConverter is a function that converts BifrostBatchCancelResponse to integration-specific format. // It takes a BifrostBatchCancelResponse and returns the format expected by the specific integration. -type BatchCancelResponseConverter func(ctx *context.Context, resp *schemas.BifrostBatchCancelResponse) (interface{}, error) +type BatchCancelResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchCancelResponse) (interface{}, error) // BatchResultsResponseConverter is a function that converts BifrostBatchResultsResponse to integration-specific format. // It takes a BifrostBatchResultsResponse and returns the format expected by the specific integration. -type BatchResultsResponseConverter func(ctx *context.Context, resp *schemas.BifrostBatchResultsResponse) (interface{}, error) +type BatchResultsResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostBatchResultsResponse) (interface{}, error) // FileUploadResponseConverter is a function that converts BifrostFileUploadResponse to integration-specific format. // It takes a BifrostFileUploadResponse and returns the format expected by the specific integration. -type FileUploadResponseConverter func(ctx *context.Context, resp *schemas.BifrostFileUploadResponse) (interface{}, error) +type FileUploadResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileUploadResponse) (interface{}, error) // FileListResponseConverter is a function that converts BifrostFileListResponse to integration-specific format. // It takes a BifrostFileListResponse and returns the format expected by the specific integration. -type FileListResponseConverter func(ctx *context.Context, resp *schemas.BifrostFileListResponse) (interface{}, error) +type FileListResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileListResponse) (interface{}, error) // FileRetrieveResponseConverter is a function that converts BifrostFileRetrieveResponse to integration-specific format. // It takes a BifrostFileRetrieveResponse and returns the format expected by the specific integration. -type FileRetrieveResponseConverter func(ctx *context.Context, resp *schemas.BifrostFileRetrieveResponse) (interface{}, error) +type FileRetrieveResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileRetrieveResponse) (interface{}, error) // FileDeleteResponseConverter is a function that converts BifrostFileDeleteResponse to integration-specific format. // It takes a BifrostFileDeleteResponse and returns the format expected by the specific integration. -type FileDeleteResponseConverter func(ctx *context.Context, resp *schemas.BifrostFileDeleteResponse) (interface{}, error) +type FileDeleteResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileDeleteResponse) (interface{}, error) // FileContentResponseConverter is a function that converts BifrostFileContentResponse to integration-specific format. // It takes a BifrostFileContentResponse and returns the format expected by the specific integration. // Note: This may return binary data or a wrapper object depending on the integration. -type FileContentResponseConverter func(ctx *context.Context, resp *schemas.BifrostFileContentResponse) (interface{}, error) +type FileContentResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostFileContentResponse) (interface{}, error) // CountTokensResponseConverter is a function that converts BifrostCountTokensResponse to integration-specific format. // It takes a BifrostCountTokensResponse and returns the format expected by the specific integration. -type CountTokensResponseConverter func(ctx *context.Context, resp *schemas.BifrostCountTokensResponse) (interface{}, error) +type CountTokensResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostCountTokensResponse) (interface{}, error) // TextStreamResponseConverter is a function that converts BifrostTextCompletionResponse to integration-specific streaming format. // It takes a BifrostTextCompletionResponse and returns the event type and the streaming format expected by the specific integration. -type TextStreamResponseConverter func(ctx *context.Context, resp *schemas.BifrostTextCompletionResponse) (string, interface{}, error) +type TextStreamResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostTextCompletionResponse) (string, interface{}, error) // ChatStreamResponseConverter is a function that converts BifrostChatResponse to integration-specific streaming format. // It takes a BifrostChatResponse and returns the event type and the streaming format expected by the specific integration. -type ChatStreamResponseConverter func(ctx *context.Context, resp *schemas.BifrostChatResponse) (string, interface{}, error) +type ChatStreamResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostChatResponse) (string, interface{}, error) // ResponsesStreamResponseConverter is a function that converts BifrostResponsesStreamResponse to integration-specific streaming format. -// It takes a BifrostResponsesStreamResponse and returns the list of event types and the list of streaming formats expected by the specific integration. (so that one bifrost response chunk can be converted to multiple streaming chunks by the specific integration) -// The list of event types and streaming formats are returned in the same order as the bifrost response chunks. -type ResponsesStreamResponseConverter func(ctx *context.Context, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) +// It takes a BifrostResponsesStreamResponse and returns a single event type and payload, which can itself encode one or more SSE events if needed by the integration. +type ResponsesStreamResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) // SpeechStreamResponseConverter is a function that converts BifrostSpeechStreamResponse to integration-specific streaming format. // It takes a BifrostSpeechStreamResponse and returns the event type and the streaming format expected by the specific integration. -type SpeechStreamResponseConverter func(ctx *context.Context, resp *schemas.BifrostSpeechStreamResponse) (string, interface{}, error) +type SpeechStreamResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostSpeechStreamResponse) (string, interface{}, error) // TranscriptionStreamResponseConverter is a function that converts BifrostTranscriptionStreamResponse to integration-specific streaming format. // It takes a BifrostTranscriptionStreamResponse and returns the event type and the streaming format expected by the specific integration. -type TranscriptionStreamResponseConverter func(ctx *context.Context, resp *schemas.BifrostTranscriptionStreamResponse) (string, interface{}, error) +type TranscriptionStreamResponseConverter func(ctx *schemas.BifrostContext, resp *schemas.BifrostTranscriptionStreamResponse) (string, interface{}, error) // ErrorConverter is a function that converts BifrostError to integration-specific format. // It takes a BifrostError and returns the format expected by the specific integration. -type ErrorConverter func(ctx *context.Context, err *schemas.BifrostError) interface{} +type ErrorConverter func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} // StreamErrorConverter is a function that converts BifrostError to integration-specific streaming error format. // It takes a BifrostError and returns the streaming error format expected by the specific integration. -type StreamErrorConverter func(ctx *context.Context, err *schemas.BifrostError) interface{} +type StreamErrorConverter func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} // RequestParser is a function that handles custom request body parsing. // It replaces the default JSON parsing when configured (e.g., for multipart/form-data). @@ -219,7 +218,7 @@ type RequestParser func(ctx *fasthttp.RequestCtx, req interface{}) error // It can be used to modify the request object (e.g., extract model from URL parameters) // or perform validation. If it returns an error, the request processing stops. // It can also modify the bifrost context based on the request context before it is given to Bifrost. -type PreRequestCallback func(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, req interface{}) error +type PreRequestCallback func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error // PostRequestCallback is called after processing the request but before sending the response. // It can be used to modify the response or perform additional logging/metrics. @@ -322,7 +321,7 @@ func NewGenericRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, ro // RegisterRoutes registers all configured routes on the given fasthttp router. // This method implements the ExtensionRouter interface. -func (g *GenericRouter) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) { +func (g *GenericRouter) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { for _, route := range g.routes { // Validate route configuration at startup to fail fast method := strings.ToUpper(route.Method) @@ -393,10 +392,10 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, g.handlerStore.ShouldAllowDirectKeys(), g.handlerStore.GetHeaderFilterConfig()) // Set send back raw response flag for all integration requests - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeySendBackRawResponse, true) + bifrostCtx.SetValue(schemas.BifrostContextKeySendBackRawResponse, true) // Set integration type to context - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyIntegrationType, string(config.Type)) + bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, string(config.Type)) // Parse request body based on configuration if method != fasthttp.MethodGet && method != fasthttp.MethodHead { @@ -432,7 +431,7 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle if ctx.UserValue(string(schemas.BifrostContextKeyDirectKey)) != nil { key, ok := ctx.UserValue(string(schemas.BifrostContextKeyDirectKey)).(schemas.Key) if ok { - *bifrostCtx = context.WithValue(*bifrostCtx, schemas.BifrostContextKeyDirectKey, key) + bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key) } } @@ -504,20 +503,19 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle } // handleNonStreamingRequest handles regular (non-streaming) requests -func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, config RouteConfig, req interface{}, bifrostReq *schemas.BifrostRequest, bifrostCtx *context.Context) { +func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, config RouteConfig, req interface{}, bifrostReq *schemas.BifrostRequest, bifrostCtx *schemas.BifrostContext) { // Use the cancellable context from ConvertToBifrostContext // While we can't detect client disconnects until we try to write, having a cancellable context // allows providers that check ctx.Done() to cancel early if needed. This is less critical than // streaming requests (where we actively detect write errors), but still provides a mechanism // for providers to respect cancellation. - requestCtx := *bifrostCtx - var response interface{} + var err error switch { case bifrostReq.ListModelsRequest != nil: - listModelsResponse, bifrostErr := g.client.ListModelsRequest(requestCtx, bifrostReq.ListModelsRequest) + listModelsResponse, bifrostErr := g.client.ListModelsRequest(bifrostCtx, bifrostReq.ListModelsRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -537,7 +535,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf response, err = config.ListModelsResponseConverter(bifrostCtx, listModelsResponse) case bifrostReq.TextCompletionRequest != nil: - textCompletionResponse, bifrostErr := g.client.TextCompletionRequest(requestCtx, bifrostReq.TextCompletionRequest) + textCompletionResponse, bifrostErr := g.client.TextCompletionRequest(bifrostCtx, bifrostReq.TextCompletionRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -560,7 +558,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // Convert Bifrost response to integration-specific format and send response, err = config.TextResponseConverter(bifrostCtx, textCompletionResponse) case bifrostReq.ChatRequest != nil: - chatResponse, bifrostErr := g.client.ChatCompletionRequest(requestCtx, bifrostReq.ChatRequest) + chatResponse, bifrostErr := g.client.ChatCompletionRequest(bifrostCtx, bifrostReq.ChatRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -583,7 +581,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // Convert Bifrost response to integration-specific format and send response, err = config.ChatResponseConverter(bifrostCtx, chatResponse) case bifrostReq.ResponsesRequest != nil: - responsesResponse, bifrostErr := g.client.ResponsesRequest(requestCtx, bifrostReq.ResponsesRequest) + responsesResponse, bifrostErr := g.client.ResponsesRequest(bifrostCtx, bifrostReq.ResponsesRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -606,7 +604,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // Convert Bifrost response to integration-specific format and send response, err = config.ResponsesResponseConverter(bifrostCtx, responsesResponse) case bifrostReq.EmbeddingRequest != nil: - embeddingResponse, bifrostErr := g.client.EmbeddingRequest(requestCtx, bifrostReq.EmbeddingRequest) + embeddingResponse, bifrostErr := g.client.EmbeddingRequest(bifrostCtx, bifrostReq.EmbeddingRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -629,7 +627,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // Convert Bifrost response to integration-specific format and send response, err = config.EmbeddingResponseConverter(bifrostCtx, embeddingResponse) case bifrostReq.SpeechRequest != nil: - speechResponse, bifrostErr := g.client.SpeechRequest(requestCtx, bifrostReq.SpeechRequest) + speechResponse, bifrostErr := g.client.SpeechRequest(bifrostCtx, bifrostReq.SpeechRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -663,7 +661,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf return } case bifrostReq.TranscriptionRequest != nil: - transcriptionResponse, bifrostErr := g.client.TranscriptionRequest(requestCtx, bifrostReq.TranscriptionRequest) + transcriptionResponse, bifrostErr := g.client.TranscriptionRequest(bifrostCtx, bifrostReq.TranscriptionRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -686,7 +684,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // Convert Bifrost response to integration-specific format and send response, err = config.TranscriptionResponseConverter(bifrostCtx, transcriptionResponse) case bifrostReq.CountTokensRequest != nil: - countTokensResponse, bifrostErr := g.client.CountTokensRequest(requestCtx, bifrostReq.CountTokensRequest) + countTokensResponse, bifrostErr := g.client.CountTokensRequest(bifrostCtx, bifrostReq.CountTokensRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -726,9 +724,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf } // handleBatchRequest handles batch API requests (create, list, retrieve, cancel, results) -func (g *GenericRouter) handleBatchRequest(ctx *fasthttp.RequestCtx, config RouteConfig, req interface{}, batchReq *BatchRequest, bifrostCtx *context.Context) { - requestCtx := *bifrostCtx - +func (g *GenericRouter) handleBatchRequest(ctx *fasthttp.RequestCtx, config RouteConfig, req interface{}, batchReq *BatchRequest, bifrostCtx *schemas.BifrostContext) { var response interface{} var err error @@ -738,7 +734,7 @@ func (g *GenericRouter) handleBatchRequest(ctx *fasthttp.RequestCtx, config Rout g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "Invalid batch create request")) return } - batchResponse, bifrostErr := g.client.BatchCreateRequest(requestCtx, batchReq.CreateRequest) + batchResponse, bifrostErr := g.client.BatchCreateRequest(bifrostCtx, batchReq.CreateRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -760,7 +756,7 @@ func (g *GenericRouter) handleBatchRequest(ctx *fasthttp.RequestCtx, config Rout g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "Invalid batch list request")) return } - batchResponse, bifrostErr := g.client.BatchListRequest(requestCtx, batchReq.ListRequest) + batchResponse, bifrostErr := g.client.BatchListRequest(bifrostCtx, batchReq.ListRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -782,7 +778,7 @@ func (g *GenericRouter) handleBatchRequest(ctx *fasthttp.RequestCtx, config Rout g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "Invalid batch retrieve request")) return } - batchResponse, bifrostErr := g.client.BatchRetrieveRequest(requestCtx, batchReq.RetrieveRequest) + batchResponse, bifrostErr := g.client.BatchRetrieveRequest(bifrostCtx, batchReq.RetrieveRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -804,7 +800,7 @@ func (g *GenericRouter) handleBatchRequest(ctx *fasthttp.RequestCtx, config Rout g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "Invalid batch cancel request")) return } - batchResponse, bifrostErr := g.client.BatchCancelRequest(requestCtx, batchReq.CancelRequest) + batchResponse, bifrostErr := g.client.BatchCancelRequest(bifrostCtx, batchReq.CancelRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -826,7 +822,7 @@ func (g *GenericRouter) handleBatchRequest(ctx *fasthttp.RequestCtx, config Rout g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "Invalid batch results request")) return } - batchResponse, bifrostErr := g.client.BatchResultsRequest(requestCtx, batchReq.ResultsRequest) + batchResponse, bifrostErr := g.client.BatchResultsRequest(bifrostCtx, batchReq.ResultsRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -857,8 +853,8 @@ func (g *GenericRouter) handleBatchRequest(ctx *fasthttp.RequestCtx, config Rout } // handleFileRequest handles file API requests (upload, list, retrieve, delete, content) -func (g *GenericRouter) handleFileRequest(ctx *fasthttp.RequestCtx, config RouteConfig, req interface{}, fileReq *FileRequest, bifrostCtx *context.Context) { - requestCtx := *bifrostCtx +func (g *GenericRouter) handleFileRequest(ctx *fasthttp.RequestCtx, config RouteConfig, req interface{}, fileReq *FileRequest, bifrostCtx *schemas.BifrostContext) { + var response interface{} var err error @@ -869,7 +865,7 @@ func (g *GenericRouter) handleFileRequest(ctx *fasthttp.RequestCtx, config Route g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "Invalid file upload request")) return } - fileResponse, bifrostErr := g.client.FileUploadRequest(requestCtx, fileReq.UploadRequest) + fileResponse, bifrostErr := g.client.FileUploadRequest(bifrostCtx, fileReq.UploadRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -891,7 +887,7 @@ func (g *GenericRouter) handleFileRequest(ctx *fasthttp.RequestCtx, config Route g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "Invalid file list request")) return } - fileResponse, bifrostErr := g.client.FileListRequest(requestCtx, fileReq.ListRequest) + fileResponse, bifrostErr := g.client.FileListRequest(bifrostCtx, fileReq.ListRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -922,7 +918,7 @@ func (g *GenericRouter) handleFileRequest(ctx *fasthttp.RequestCtx, config Route g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "Invalid file retrieve request")) return } - fileResponse, bifrostErr := g.client.FileRetrieveRequest(requestCtx, fileReq.RetrieveRequest) + fileResponse, bifrostErr := g.client.FileRetrieveRequest(bifrostCtx, fileReq.RetrieveRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -944,7 +940,7 @@ func (g *GenericRouter) handleFileRequest(ctx *fasthttp.RequestCtx, config Route g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "Invalid file delete request")) return } - fileResponse, bifrostErr := g.client.FileDeleteRequest(requestCtx, fileReq.DeleteRequest) + fileResponse, bifrostErr := g.client.FileDeleteRequest(bifrostCtx, fileReq.DeleteRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -966,7 +962,7 @@ func (g *GenericRouter) handleFileRequest(ctx *fasthttp.RequestCtx, config Route g.sendError(ctx, bifrostCtx, config.ErrorConverter, newBifrostError(nil, "Invalid file content request")) return } - fileResponse, bifrostErr := g.client.FileContentRequest(requestCtx, fileReq.ContentRequest) + fileResponse, bifrostErr := g.client.FileContentRequest(bifrostCtx, fileReq.ContentRequest) if bifrostErr != nil { g.sendError(ctx, bifrostCtx, config.ErrorConverter, bifrostErr) return @@ -1019,7 +1015,7 @@ func (g *GenericRouter) handleFileRequest(ctx *fasthttp.RequestCtx, config Route } // handleStreamingRequest handles streaming requests using Server-Sent Events (SSE) -func (g *GenericRouter) handleStreamingRequest(ctx *fasthttp.RequestCtx, config RouteConfig, bifrostReq *schemas.BifrostRequest, bifrostCtx *context.Context, cancel context.CancelFunc) { +func (g *GenericRouter) handleStreamingRequest(ctx *fasthttp.RequestCtx, config RouteConfig, bifrostReq *schemas.BifrostRequest, bifrostCtx *schemas.BifrostContext, cancel context.CancelFunc) { // Set headers based on route type if config.Type == RouteConfigTypeBedrock { // AWS Event Stream headers for Bedrock @@ -1040,22 +1036,20 @@ func (g *GenericRouter) handleStreamingRequest(ctx *fasthttp.RequestCtx, config // That keeps goroutines and upstream tokens alive long after the SSE writer has exited. // // We now get a cancellable context from ConvertToBifrostContext so we can cancel the upstream stream immediately when the client disconnects. - streamCtx := *bifrostCtx - var stream chan *schemas.BifrostStream var bifrostErr *schemas.BifrostError // Handle different request types if bifrostReq.TextCompletionRequest != nil { - stream, bifrostErr = g.client.TextCompletionStreamRequest(streamCtx, bifrostReq.TextCompletionRequest) + stream, bifrostErr = g.client.TextCompletionStreamRequest(bifrostCtx, bifrostReq.TextCompletionRequest) } else if bifrostReq.ChatRequest != nil { - stream, bifrostErr = g.client.ChatCompletionStreamRequest(streamCtx, bifrostReq.ChatRequest) + stream, bifrostErr = g.client.ChatCompletionStreamRequest(bifrostCtx, bifrostReq.ChatRequest) } else if bifrostReq.ResponsesRequest != nil { - stream, bifrostErr = g.client.ResponsesStreamRequest(streamCtx, bifrostReq.ResponsesRequest) + stream, bifrostErr = g.client.ResponsesStreamRequest(bifrostCtx, bifrostReq.ResponsesRequest) } else if bifrostReq.SpeechRequest != nil { - stream, bifrostErr = g.client.SpeechStreamRequest(streamCtx, bifrostReq.SpeechRequest) + stream, bifrostErr = g.client.SpeechStreamRequest(bifrostCtx, bifrostReq.SpeechRequest) } else if bifrostReq.TranscriptionRequest != nil { - stream, bifrostErr = g.client.TranscriptionStreamRequest(streamCtx, bifrostReq.TranscriptionRequest) + stream, bifrostErr = g.client.TranscriptionStreamRequest(bifrostCtx, bifrostReq.TranscriptionRequest) } // Get the streaming channel from Bifrost @@ -1134,10 +1128,24 @@ func (g *GenericRouter) handleStreamingRequest(ctx *fasthttp.RequestCtx, config // The cancel function is called ONLY when client disconnects are detected via write errors. // Bifrost handles cleanup internally for normal completion and errors, so we only cancel // upstream streams when write errors indicate the client has disconnected. -func (g *GenericRouter) handleStreaming(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, config RouteConfig, streamChan chan *schemas.BifrostStream, cancel context.CancelFunc) { +func (g *GenericRouter) handleStreaming(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, config RouteConfig, streamChan chan *schemas.BifrostStream, cancel context.CancelFunc) { + // Signal to tracing middleware that trace completion should be deferred + // The streaming callback will complete the trace after the stream ends + ctx.SetUserValue(schemas.BifrostContextKeyDeferTraceCompletion, true) + + // Get the trace completer function for use in the streaming callback + traceCompleter, _ := ctx.UserValue(schemas.BifrostContextKeyTraceCompleter).(func()) + // Use streaming response writer ctx.Response.SetBodyStreamWriter(func(w *bufio.Writer) { - defer w.Flush() + defer func() { + w.Flush() + // Complete the trace after streaming finishes + // This ensures all spans (including llm.call) are properly ended before the trace is sent to OTEL + if traceCompleter != nil { + traceCompleter() + } + }() // Create encoder for AWS Event Stream if needed var eventStreamEncoder *eventstream.Encoder diff --git a/transports/bifrost-http/integrations/utils.go b/transports/bifrost-http/integrations/utils.go index 67c2e54e04..c8f207bb0e 100644 --- a/transports/bifrost-http/integrations/utils.go +++ b/transports/bifrost-http/integrations/utils.go @@ -2,7 +2,6 @@ package integrations import ( "bytes" - "context" "fmt" "log" "reflect" @@ -136,7 +135,7 @@ func extractExactPath(ctx *fasthttp.RequestCtx) string { } // sendStreamError sends an error in streaming format using the stream error converter if available -func (g *GenericRouter) sendStreamError(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, config RouteConfig, bifrostErr *schemas.BifrostError) { +func (g *GenericRouter) sendStreamError(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, config RouteConfig, bifrostErr *schemas.BifrostError) { var errorResponse interface{} // Use stream error converter if available, otherwise fallback to regular error converter @@ -162,7 +161,7 @@ func (g *GenericRouter) sendStreamError(ctx *fasthttp.RequestCtx, bifrostCtx *co // sendError sends an error response with the appropriate status code and JSON body. // It handles different error types (string, error interface, or arbitrary objects). -func (g *GenericRouter) sendError(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, errorConverter ErrorConverter, bifrostErr *schemas.BifrostError) { +func (g *GenericRouter) sendError(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, errorConverter ErrorConverter, bifrostErr *schemas.BifrostError) { if bifrostErr.StatusCode != nil { ctx.SetStatusCode(*bifrostErr.StatusCode) } else { @@ -185,7 +184,7 @@ func (g *GenericRouter) sendError(ctx *fasthttp.RequestCtx, bifrostCtx *context. } // sendSuccess sends a successful response with HTTP 200 status and JSON body. -func (g *GenericRouter) sendSuccess(ctx *fasthttp.RequestCtx, bifrostCtx *context.Context, errorConverter ErrorConverter, response interface{}) { +func (g *GenericRouter) sendSuccess(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, errorConverter ErrorConverter, response interface{}) { ctx.SetStatusCode(fasthttp.StatusOK) ctx.SetContentType("application/json") diff --git a/transports/bifrost-http/lib/account.go b/transports/bifrost-http/lib/account.go index b37f756b86..05ec593f59 100644 --- a/transports/bifrost-http/lib/account.go +++ b/transports/bifrost-http/lib/account.go @@ -36,7 +36,7 @@ func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvide // GetKeysForProvider returns the API keys configured for a specific provider. // Keys are already processed (environment variables resolved) by the store. // Implements the Account interface. -func (baseAccount *BaseAccount) GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { +func (baseAccount *BaseAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) { if baseAccount.store == nil { return nil, fmt.Errorf("store not initialized") } @@ -49,7 +49,7 @@ func (baseAccount *BaseAccount) GetKeysForProvider(ctx *context.Context, provide keys := config.Keys if baseAccount.store.ClientConfig.EnableGovernance { - if v := (*ctx).Value(schemas.BifrostContextKey("bf-governance-include-only-keys")); v != nil { + if v := ctx.Value(schemas.BifrostContextKey("bf-governance-include-only-keys")); v != nil { if includeOnlyKeys, ok := v.([]string); ok { if len(includeOnlyKeys) == 0 { // header present but empty means "no keys allowed" diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 2af17a70ea..455e662f98 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -25,6 +25,7 @@ import ( "github.com/maximhq/bifrost/framework/encrypt" "github.com/maximhq/bifrost/framework/logstore" "github.com/maximhq/bifrost/framework/modelcatalog" + plugins "github.com/maximhq/bifrost/framework/plugins" "github.com/maximhq/bifrost/framework/vectorstore" "github.com/maximhq/bifrost/plugins/semanticcache" "gorm.io/gorm" @@ -46,6 +47,11 @@ const ( MaxRetryBackoff = 1000000 * time.Millisecond // Maximum retry backoff: 1000000ms (1000 seconds) ) +const ( + DBLookupMaxRetries = 5 + DBLookupDelay = 1 * time.Second +) + // getWeight safely dereferences a *float64 weight pointer, returning 1.0 as default if nil. // This allows distinguishing between "not set" (nil -> 1.0) and "explicitly set to 0" (0.0). func getWeight(w *float64) float64 { @@ -227,7 +233,8 @@ type Config struct { EnvKeys map[string][]configstore.EnvKeyInfo // Plugin configs - atomic for lock-free reads with CAS updates - Plugins atomic.Pointer[[]schemas.Plugin] + Plugins atomic.Pointer[[]schemas.Plugin] + PluginLoader plugins.PluginLoader // Plugin configs from config file/database PluginConfigs []*schemas.PluginConfig @@ -247,6 +254,9 @@ var DefaultClientConfig = configstore.ClientConfig{ AllowDirectKeys: false, AllowedOrigins: []string{"*"}, MaxRequestBodySizeMB: 100, + MCPAgentDepth: 10, + MCPToolExecutionTimeout: 30, + MCPCodeModeBindingLevel: string(schemas.CodeModeBindingLevelServer), EnableLiteLLMFallbacks: false, } @@ -465,7 +475,7 @@ func initStoresFromFile(ctx context.Context, config *Config, configData *ConfigD return nil } -// loadClientConfigFromFile loads and merges client config from file with store +// loadClientConfigFromFile loads and merges client config from file with store using hash-based reconciliation func loadClientConfigFromFile(ctx context.Context, config *Config, configData *ConfigData) { var clientConfig *configstore.ClientConfig var err error @@ -477,33 +487,30 @@ func loadClientConfigFromFile(ctx context.Context, config *Config, configData *C } } - if clientConfig != nil { - config.ClientConfig = *clientConfig - // For backward compatibility, handle cases where max request body size is not set - if config.ClientConfig.MaxRequestBodySizeMB == 0 { - config.ClientConfig.MaxRequestBodySizeMB = DefaultClientConfig.MaxRequestBodySizeMB - } - - // Merge with config file if present - if configData.Client != nil { - mergeClientConfig(&config.ClientConfig, configData.Client) - // Update store with merged config - if config.ConfigStore != nil { - logger.Debug("updating merged client config in store") - if err = config.ConfigStore.UpdateClientConfig(ctx, &config.ClientConfig); err != nil { - logger.Warn("failed to update merged client config: %v", err) - } - } - } - } else { + // Case 1: No config in DB - use file config (or defaults) + if clientConfig == nil { logger.Debug("client config not found in store, using config file") if configData.Client != nil { config.ClientConfig = *configData.Client if config.ClientConfig.MaxRequestBodySizeMB == 0 { config.ClientConfig.MaxRequestBodySizeMB = DefaultClientConfig.MaxRequestBodySizeMB } + // Generate hash for the file config + fileHash, hashErr := configData.Client.GenerateClientConfigHash() + if hashErr != nil { + logger.Warn("failed to generate client config hash: %v", hashErr) + } else { + config.ClientConfig.ConfigHash = fileHash + } } else { config.ClientConfig = DefaultClientConfig + // Generate hash for default config + defaultHash, hashErr := config.ClientConfig.GenerateClientConfigHash() + if hashErr != nil { + logger.Warn("failed to generate default client config hash: %v", hashErr) + } else { + config.ClientConfig.ConfigHash = defaultHash + } } if config.ConfigStore != nil { logger.Debug("updating client config in store") @@ -511,6 +518,48 @@ func loadClientConfigFromFile(ctx context.Context, config *Config, configData *C logger.Warn("failed to update client config: %v", err) } } + return + } + + // Case 2: Config exists in DB + config.ClientConfig = *clientConfig + // For backward compatibility, handle cases where max request body size is not set + if config.ClientConfig.MaxRequestBodySizeMB == 0 { + config.ClientConfig.MaxRequestBodySizeMB = DefaultClientConfig.MaxRequestBodySizeMB + } + + // Case 2a: No file config - use DB config as-is + if configData.Client == nil { + logger.Debug("no client config in file, using DB config") + return + } + + // Case 2b: Both DB and file config exist - use hash-based reconciliation + fileHash, hashErr := configData.Client.GenerateClientConfigHash() + if hashErr != nil { + logger.Warn("failed to generate client config hash from file: %v", hashErr) + return + } + + if clientConfig.ConfigHash != fileHash { + // Hash mismatch - config.json was changed, sync from file + logger.Debug("client config hash mismatch, syncing from config file") + config.ClientConfig = *configData.Client + config.ClientConfig.ConfigHash = fileHash + // Apply defaults for zero values + if config.ClientConfig.MaxRequestBodySizeMB == 0 { + config.ClientConfig.MaxRequestBodySizeMB = DefaultClientConfig.MaxRequestBodySizeMB + } + // Update store with file config + if config.ConfigStore != nil { + logger.Debug("updating client config in store from file") + if err = config.ConfigStore.UpdateClientConfig(ctx, &config.ClientConfig); err != nil { + logger.Warn("failed to update client config: %v", err) + } + } + } else { + // Hash matches - keep DB config (preserves UI changes) + logger.Debug("client config hash matches, keeping DB config") } } @@ -869,7 +918,6 @@ func loadMCPConfigFromFile(ctx context.Context, config *Config, configData *Conf if mcpConfig != nil { config.MCPConfig = mcpConfig - // Merge with config file if present if configData.MCP != nil && len(configData.MCP.ClientConfigs) > 0 { mergeMCPConfig(ctx, config, configData, mcpConfig) @@ -1117,9 +1165,24 @@ func mergeGovernanceConfig(ctx context.Context, config *Config, configData *Conf if existingVirtualKey.ConfigHash != fileVKHash { logger.Debug("config hash mismatch for virtual key %s, syncing from config file", newVirtualKey.ID) configData.Governance.VirtualKeys[i].ConfigHash = fileVKHash + processedValue, envVar, err := config.processEnvValue(configData.Governance.VirtualKeys[i].Value) + if err != nil { + logger.Warn("failed to process env var for virtual key %s: %v", configData.Governance.VirtualKeys[i].ID, err) + continue + } + configData.Governance.VirtualKeys[i].Value = processedValue virtualKeysToUpdate = append(virtualKeysToUpdate, configData.Governance.VirtualKeys[i]) governanceConfig.VirtualKeys[j] = configData.Governance.VirtualKeys[i] - } else { + if envVar != "" { + config.EnvKeys[envVar] = append(config.EnvKeys[envVar], configstore.EnvKeyInfo{ + EnvVar: envVar, + Provider: "", + KeyType: "virtual_key", + ConfigPath: fmt.Sprintf("governance.virtual_keys[%s].value", configData.Governance.VirtualKeys[i].ID), + KeyID: "", + }) + } + } else { logger.Debug("config hash matches for virtual key %s, keeping DB config", newVirtualKey.ID) } break @@ -1564,7 +1627,11 @@ func initFrameworkConfigFromFile(ctx context.Context, config *Config, configData Pricing: pricingConfig, } - pricingManager, err := modelcatalog.Init(ctx, pricingConfig, config.ConfigStore, logger) + var pricingManager *modelcatalog.ModelCatalog + var err error + + // Use default modelcatalog initialization when no enterprise overrides are provided + pricingManager, err = modelcatalog.Init(ctx, pricingConfig, config.ConfigStore, nil, logger) if err != nil { logger.Warn("failed to initialize pricing manager: %v", err) } @@ -1939,7 +2006,9 @@ func initDefaultFrameworkConfig(ctx context.Context, config *Config) error { } // Initialize pricing manager - pricingManager, err := modelcatalog.Init(ctx, pricingConfig, config.ConfigStore, logger) + var pricingManager *modelcatalog.ModelCatalog + // Use default modelcatalog initialization when no enterprise overrides are provided + pricingManager, err = modelcatalog.Init(ctx, pricingConfig, config.ConfigStore, nil, logger) if err != nil { logger.Warn("failed to initialize pricing manager: %v", err) } @@ -2888,7 +2957,7 @@ func (c *Config) AddMCPClient(ctx context.Context, clientConfig schemas.MCPClien if err := c.client.AddMCPClient(c.MCPConfig.ClientConfigs[len(c.MCPConfig.ClientConfigs)-1]); err != nil { c.MCPConfig.ClientConfigs = c.MCPConfig.ClientConfigs[:len(c.MCPConfig.ClientConfigs)-1] c.cleanupEnvKeys("", clientConfig.ID, newEnvKeys) - return fmt.Errorf("failed to add MCP client: %w", err) + return fmt.Errorf("failed to connect MCP client: %w", err) } if c.ConfigStore != nil { @@ -3037,8 +3106,10 @@ func (c *Config) EditMCPClient(ctx context.Context, id string, updatedConfig sch // Update the in-memory config with the processed values c.MCPConfig.ClientConfigs[configIndex].Name = processedConfig.Name + c.MCPConfig.ClientConfigs[configIndex].IsCodeModeClient = processedConfig.IsCodeModeClient c.MCPConfig.ClientConfigs[configIndex].Headers = processedConfig.Headers c.MCPConfig.ClientConfigs[configIndex].ToolsToExecute = processedConfig.ToolsToExecute + c.MCPConfig.ClientConfigs[configIndex].ToolsToAutoExecute = processedConfig.ToolsToAutoExecute // Check if client is registered in Bifrost (can be not registered if client initialization failed) if clients, err := c.client.GetMCPClients(); err == nil && len(clients) > 0 { @@ -3073,12 +3144,14 @@ func (c *Config) EditMCPClient(ctx context.Context, id string, updatedConfig sch func (c *Config) RedactMCPClientConfig(config schemas.MCPClientConfig) schemas.MCPClientConfig { // Create a copy with basic fields configCopy := schemas.MCPClientConfig{ - ID: config.ID, - Name: config.Name, - ConnectionType: config.ConnectionType, - ConnectionString: config.ConnectionString, - StdioConfig: config.StdioConfig, - ToolsToExecute: append([]string{}, config.ToolsToExecute...), + ID: config.ID, + Name: config.Name, + IsCodeModeClient: config.IsCodeModeClient, + ConnectionType: config.ConnectionType, + ConnectionString: config.ConnectionString, + StdioConfig: config.StdioConfig, + ToolsToExecute: append([]string{}, config.ToolsToExecute...), + ToolsToAutoExecute: append([]string{}, config.ToolsToAutoExecute...), } // Handle connection string if present diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index e0b98f3700..8ca8a8d44c 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -133,6 +133,7 @@ import ( "os" "path/filepath" "testing" + "time" "github.com/google/uuid" "github.com/maximhq/bifrost/core/schemas" @@ -190,6 +191,10 @@ func (m *MockConfigStore) RunMigration(ctx context.Context, migration *migrator. return nil } +func (m *MockConfigStore) RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error) { + return fn(ctx) +} + // Client config func (m *MockConfigStore) UpdateClientConfig(ctx context.Context, config *configstore.ClientConfig) error { m.clientConfig = config @@ -309,6 +314,10 @@ func (m *MockConfigStore) GetRateLimit(ctx context.Context, id string) (*tables. return nil, nil } +func (m *MockConfigStore) GetRateLimits(ctx context.Context) ([]tables.TableRateLimit, error) { + return []tables.TableRateLimit{}, nil +} + func (m *MockConfigStore) CreateCustomer(ctx context.Context, customer *tables.TableCustomer, tx ...*gorm.DB) error { if m.governanceConfig == nil { m.governanceConfig = &configstore.GovernanceConfig{} @@ -615,12 +624,12 @@ func makeMCPClientConfig(id, name string) schemas.MCPClientConfig { // testLogger is a minimal logger implementation for testing type testLogger struct{} -func (l *testLogger) Debug(msg string, args ...any) {} -func (l *testLogger) Info(msg string, args ...any) {} -func (l *testLogger) Warn(msg string, args ...any) {} -func (l *testLogger) Error(msg string, args ...any) {} -func (l *testLogger) Fatal(msg string, args ...any) {} -func (l *testLogger) SetLevel(level schemas.LogLevel) {} +func (l *testLogger) Debug(msg string, args ...any) {} +func (l *testLogger) Info(msg string, args ...any) {} +func (l *testLogger) Warn(msg string, args ...any) {} +func (l *testLogger) Error(msg string, args ...any) {} +func (l *testLogger) Fatal(msg string, args ...any) {} +func (l *testLogger) SetLevel(level schemas.LogLevel) {} func (l *testLogger) SetOutputType(outputType schemas.LoggerOutputType) {} // initTestLogger initializes the global logger for SQLite integration tests @@ -1828,7 +1837,7 @@ func TestProviderHashComparison_DifferentHash(t *testing.T) { Weight: dbKey.Weight, }) fileKeyHash, _ := configstore.GenerateKeyHash(fileKey) - if dbKeyHash == fileKeyHash || fileKey.Name == dbKey.Name { + if dbKeyHash == fileKeyHash || fileKey.Name == dbKey.Name { found = true break } @@ -2189,13 +2198,13 @@ func TestProviderHashComparison_OptionalFieldsPresence(t *testing.T) { // All hashes should be unique hashes := map[string]string{ - "no_optional": hashNoOptional, - "with_network": hashWithNetwork, - "with_proxy": hashWithProxy, - "with_conc": hashWithConcurrency, - "with_custom": hashWithCustom, - "with_raw": hashWithRawResponse, - "all_fields": hashAllFields, + "no_optional": hashNoOptional, + "with_network": hashWithNetwork, + "with_proxy": hashWithProxy, + "with_conc": hashWithConcurrency, + "with_custom": hashWithCustom, + "with_raw": hashWithRawResponse, + "all_fields": hashAllFields, } seen := make(map[string]string) @@ -3001,9 +3010,9 @@ func TestProviderHashComparison_ProviderChangedKeysUnchanged(t *testing.T) { sameKey := schemas.Key{ ID: "key-1", Name: "openai-key", - Value: "sk-original-123", // SAME + Value: "sk-original-123", // SAME Models: []string{"gpt-4", "gpt-3.5-turbo"}, // SAME - Weight: 1.5, // SAME + Weight: 1.5, // SAME } sameKeyHash, _ := configstore.GenerateKeyHash(sameKey) @@ -3039,10 +3048,10 @@ func TestProviderHashComparison_ProviderChangedKeysUnchanged(t *testing.T) { // - Keep existing keys from DB (they weren't changed in file) updatedConfig := configstore.ProviderConfig{ - Keys: dbConfig.Keys, // Keep original keys from DB - NetworkConfig: fileConfig.NetworkConfig, // Update from file - SendBackRawResponse: fileConfig.SendBackRawResponse, // Update from file - ConfigHash: fileProviderHash, // New provider hash + Keys: dbConfig.Keys, // Keep original keys from DB + NetworkConfig: fileConfig.NetworkConfig, // Update from file + SendBackRawResponse: fileConfig.SendBackRawResponse, // Update from file + ConfigHash: fileProviderHash, // New provider hash } // Verify keys are preserved (same values as DB) @@ -3100,9 +3109,9 @@ func TestProviderHashComparison_KeysChangedProviderUnchanged(t *testing.T) { changedKey := schemas.Key{ ID: "key-1", Name: "openai-key", - Value: "sk-new-456", // CHANGED! - Models: []string{"gpt-4", "gpt-3.5-turbo", "o1"}, // CHANGED! - Weight: 2.0, // CHANGED! + Value: "sk-new-456", // CHANGED! + Models: []string{"gpt-4", "gpt-3.5-turbo", "o1"}, // CHANGED! + Weight: 2.0, // CHANGED! } changedKeyHash, _ := configstore.GenerateKeyHash(changedKey) @@ -3138,10 +3147,10 @@ func TestProviderHashComparison_KeysChangedProviderUnchanged(t *testing.T) { // - Update keys from file (they were changed) updatedConfig := configstore.ProviderConfig{ - Keys: fileConfig.Keys, // Update keys from file - NetworkConfig: dbConfig.NetworkConfig, // Keep from DB - SendBackRawResponse: dbConfig.SendBackRawResponse, // Keep from DB - ConfigHash: dbProviderHash, // Provider hash unchanged + Keys: fileConfig.Keys, // Update keys from file + NetworkConfig: dbConfig.NetworkConfig, // Keep from DB + SendBackRawResponse: dbConfig.SendBackRawResponse, // Keep from DB + ConfigHash: dbProviderHash, // Provider hash unchanged } // Verify provider config is preserved @@ -3200,9 +3209,9 @@ func TestProviderHashComparison_BothChangedIndependently(t *testing.T) { changedKey := schemas.Key{ ID: "key-1", Name: "openai-key", - Value: "sk-new-456", // CHANGED + Value: "sk-new-456", // CHANGED Models: []string{"gpt-4", "o1"}, // CHANGED - Weight: 2.0, // CHANGED + Weight: 2.0, // CHANGED } changedKeyHash, _ := configstore.GenerateKeyHash(changedKey) @@ -3302,7 +3311,7 @@ func TestProviderHashComparison_NeitherChanged(t *testing.T) { // === Verify: Both hashes match === if dbProviderHash != fileProviderHash { - t.Errorf("Expected provider hash to be SAME, got DB=%s File=%s", + t.Errorf("Expected provider hash to be SAME, got DB=%s File=%s", dbProviderHash[:16], fileProviderHash[:16]) } else { t.Log("āœ“ Provider hash unchanged") @@ -3352,9 +3361,9 @@ func TestKeyLevelSync_ProviderHashMatch_SingleKeyChanged(t *testing.T) { fileKey := schemas.Key{ ID: "key-1", Name: "openai-key", - Value: "sk-new-value", // CHANGED + Value: "sk-new-value", // CHANGED Models: []string{"gpt-4", "gpt-4-turbo"}, // CHANGED - Weight: 2.0, // CHANGED + Weight: 2.0, // CHANGED } fileKeyHash, _ := configstore.GenerateKeyHash(fileKey) @@ -3465,9 +3474,9 @@ func TestKeyLevelSync_ProviderHashMatch_NewKeyInFile(t *testing.T) { fileKey1 := schemas.Key{ ID: "key-1", Name: "openai-key-1", - Value: "sk-key-1", // SAME + Value: "sk-key-1", // SAME Models: []string{"gpt-4"}, // SAME - Weight: 1.0, // SAME + Weight: 1.0, // SAME } newFileKey := schemas.Key{ ID: "key-2", @@ -3594,9 +3603,9 @@ func TestKeyLevelSync_ProviderHashMatch_KeyOnlyInDB(t *testing.T) { fileKey1 := schemas.Key{ ID: "key-1", Name: "openai-key-1", - Value: "sk-key-1", // SAME + Value: "sk-key-1", // SAME Models: []string{"gpt-4"}, // SAME - Weight: 1.0, // SAME + Weight: 1.0, // SAME } fileConfig := configstore.ProviderConfig{ @@ -3719,16 +3728,16 @@ func TestKeyLevelSync_ProviderHashMatch_MixedScenario(t *testing.T) { fileUnchangedKey := schemas.Key{ ID: "key-unchanged", Name: "unchanged-key", - Value: "sk-unchanged", // SAME + Value: "sk-unchanged", // SAME Models: []string{"gpt-4"}, // SAME - Weight: 1.0, // SAME + Weight: 1.0, // SAME } fileChangedKey := schemas.Key{ ID: "key-changed", Name: "changed-key", - Value: "sk-NEW-value", // CHANGED + Value: "sk-NEW-value", // CHANGED Models: []string{"gpt-4", "gpt-4-turbo"}, // CHANGED - Weight: 2.0, // CHANGED + Weight: 2.0, // CHANGED } newFileKey := schemas.Key{ ID: "key-new", @@ -4862,8 +4871,8 @@ func TestProviderHashComparison_AzureProviderFullLifecycle(t *testing.T) { Endpoint: "https://new-azure.openai.azure.com", // Changed! APIVersion: stringPtr("2024-10-21"), // Changed! Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - "gpt-4o": "gpt-4o-deployment", // Added! + "gpt-4": "gpt-4-deployment", + "gpt-4o": "gpt-4o-deployment", // Added! }, }, }, @@ -5081,7 +5090,7 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: "AKIAIOSFODNN7EXAMPLE", SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", - Region: stringPtr("us-west-2"), // Changed! + Region: stringPtr("us-west-2"), // Changed! ARN: stringPtr("arn:aws:bedrock:us-west-2:123456789012:inference-profile/my-profile"), // Added! Deployments: map[string]string{ "claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", @@ -5092,7 +5101,7 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { }, NetworkConfig: &schemas.NetworkConfig{ BaseURL: "https://bedrock-runtime.us-west-2.amazonaws.com", // Changed! - MaxRetries: 5, // Changed! + MaxRetries: 5, // Changed! }, SendBackRawResponse: true, // Changed! } @@ -5518,9 +5527,9 @@ func TestProviderHashComparison_BedrockDBValuePreservedWhenHashMatches(t *testin Value: "", Weight: 1, BedrockKeyConfig: &schemas.BedrockKeyConfig{ - AccessKey: "AKIAIOSFODNN7EXAMPLE", // Different! - SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", // Different! - Region: stringPtr("us-east-1"), // Same + AccessKey: "AKIAIOSFODNN7EXAMPLE", // Different! + SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", // Different! + Region: stringPtr("us-east-1"), // Same Deployments: map[string]string{ "claude-3": "anthropic.claude-3-sonnet-20240229-v1:0", // Same }, @@ -5529,7 +5538,7 @@ func TestProviderHashComparison_BedrockDBValuePreservedWhenHashMatches(t *testin }, NetworkConfig: &schemas.NetworkConfig{ BaseURL: "https://bedrock-runtime.us-east-1.amazonaws.com", // Same - MaxRetries: 3, // Same + MaxRetries: 3, // Same }, SendBackRawResponse: false, // Same } @@ -5610,7 +5619,7 @@ func TestProviderHashComparison_AzureConfigChangedInFile(t *testing.T) { Weight: 1, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: "https://NEW-azure.openai.azure.com", // Changed! - APIVersion: stringPtr("2024-10-21"), // Changed! + APIVersion: stringPtr("2024-10-21"), // Changed! Deployments: map[string]string{ "gpt-4o": "gpt-4o-deployment", // Added! }, @@ -5701,7 +5710,7 @@ func TestProviderHashComparison_BedrockConfigChangedInFile(t *testing.T) { BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: "AKIAIOSFODNN7EXAMPLE", SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", - Region: stringPtr("us-west-2"), // Changed! + Region: stringPtr("us-west-2"), // Changed! ARN: stringPtr("arn:aws:bedrock:us-west-2:123456789012:inference-profile/new-profile"), // Added! Deployments: map[string]string{ "claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0", // Added! @@ -5711,7 +5720,7 @@ func TestProviderHashComparison_BedrockConfigChangedInFile(t *testing.T) { }, NetworkConfig: &schemas.NetworkConfig{ BaseURL: "https://bedrock-runtime.us-west-2.amazonaws.com", // Changed! - MaxRetries: 5, // Changed! + MaxRetries: 5, // Changed! }, SendBackRawResponse: true, // Changed! } @@ -11464,7 +11473,7 @@ func TestGenerateMCPClientHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch after GORM round-trip for StdioConfig\nBefore save: %s\nAfter load: %s\nStdioConfig populated: %v", + t.Errorf("Hash mismatch after GORM round-trip for StdioConfig\nBefore save: %s\nAfter load: %s\nStdioConfig populated: %v", hashBeforeSave, hashAfterLoad, mcpFromDB.StdioConfig != nil) } }) @@ -11495,7 +11504,7 @@ func TestGenerateMCPClientHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch after GORM round-trip for ToolsToExecute\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch after GORM round-trip for ToolsToExecute\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11529,7 +11538,7 @@ func TestGenerateMCPClientHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch after GORM round-trip for Headers\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch after GORM round-trip for Headers\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11561,7 +11570,7 @@ func TestGenerateMCPClientHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := configstore.GenerateMCPClientHash(mcpFromDB) if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch after GORM round-trip for all fields\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch after GORM round-trip for all fields\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11600,7 +11609,7 @@ func TestGenerateMCPClientHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch when using Find() (migration pattern)\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch when using Find() (migration pattern)\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11646,7 +11655,7 @@ func TestGeneratePluginHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch after GORM round-trip for plugin Config\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch after GORM round-trip for plugin Config\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11676,7 +11685,7 @@ func TestGeneratePluginHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := configstore.GeneratePluginHash(pluginFromDB) if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for nested config\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for nested config\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11699,7 +11708,7 @@ func TestGeneratePluginHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := configstore.GeneratePluginHash(pluginFromDB) if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for empty config\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for empty config\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11740,7 +11749,7 @@ func TestGenerateTeamHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for Profile\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for Profile\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11767,7 +11776,7 @@ func TestGenerateTeamHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := configstore.GenerateTeamHash(teamFromDB) if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for Config\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for Config\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11794,7 +11803,7 @@ func TestGenerateTeamHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := configstore.GenerateTeamHash(teamFromDB) if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for Claims\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for Claims\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11823,7 +11832,7 @@ func TestGenerateTeamHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := configstore.GenerateTeamHash(teamFromDB) if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for all fields\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for all fields\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11875,7 +11884,7 @@ func TestGenerateProviderHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for NetworkConfig\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for NetworkConfig\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11915,7 +11924,7 @@ func TestGenerateProviderHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for ConcurrencyAndBufferSize\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for ConcurrencyAndBufferSize\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11951,7 +11960,7 @@ func TestGenerateProviderHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := providerConfigFromDB.GenerateConfigHash("openai") if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for ProxyConfig\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for ProxyConfig\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -11987,7 +11996,7 @@ func TestGenerateProviderHash_RuntimeVsMigrationParity(t *testing.T) { hashAfterLoad, _ := providerConfigFromDB.GenerateConfigHash("custom") if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for CustomProviderConfig\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for CustomProviderConfig\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -12048,7 +12057,7 @@ func TestGenerateKeyHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for Models\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for Models\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -12098,7 +12107,7 @@ func TestGenerateKeyHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for AzureKeyConfig\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for AzureKeyConfig\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -12199,7 +12208,7 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for PrometheusLabels\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for PrometheusLabels\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) @@ -12247,7 +12256,7 @@ func TestGenerateClientConfigHash_RuntimeVsMigrationParity(t *testing.T) { } if hashBeforeSave != hashAfterLoad { - t.Errorf("Hash mismatch for AllowedOrigins\nBefore save: %s\nAfter load: %s", + t.Errorf("Hash mismatch for AllowedOrigins\nBefore save: %s\nAfter load: %s", hashBeforeSave, hashAfterLoad) } }) diff --git a/transports/bifrost-http/lib/ctx.go b/transports/bifrost-http/lib/ctx.go index 3e572ea4aa..b396755c38 100644 --- a/transports/bifrost-http/lib/ctx.go +++ b/transports/bifrost-http/lib/ctx.go @@ -73,25 +73,25 @@ import ( // Example Usage: // // fastCtx := &fasthttp.RequestCtx{...} -// bifrostCtx, cancel := ConvertToBifrostContext(fastCtx, true) +// bifrostCtx, cancel := ConvertToBifrostContext(fastCtx, true, nil) // defer cancel() // Ensure cleanup -// // bifrostCtx now contains any prometheus and maxim header values +// // bifrostCtx now contains propagated header values including Prometheus metrics, +// // Maxim tracing data, MCP filters, governance keys, API keys, cache settings, and extra headers -func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, headerFilterConfig *configstoreTables.GlobalHeaderFilterConfig) (*context.Context, context.CancelFunc) { +func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, headerFilterConfig *configstoreTables.GlobalHeaderFilterConfig) (*schemas.BifrostContext, context.CancelFunc) { // Create cancellable context for all requests // This enables proper cleanup when clients disconnect or requests are cancelled - baseCtx := context.Background() - bifrostCtx, cancel := context.WithCancel(baseCtx) + bifrostCtx, cancel := schemas.NewBifrostContextWithCancel(ctx) // First, check if x-request-id header exists requestID := string(ctx.Request.Header.Peek("x-request-id")) if requestID == "" { requestID = uuid.New().String() } - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKeyRequestID, requestID) + bifrostCtx.SetValue(schemas.BifrostContextKeyRequestID, requestID) // Populating all user values from the request context ctx.VisitUserValuesAll(func(key, value any) { - bifrostCtx = context.WithValue(bifrostCtx, key, value) + bifrostCtx.SetValue(key, value) }) // Initialize tags map for collecting maxim tags maximTags := make(map[string]string) @@ -159,24 +159,24 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, hea ctx.Request.Header.All()(func(key, value []byte) bool { keyStr := strings.ToLower(string(key)) if labelName, ok := strings.CutPrefix(keyStr, "x-bf-prom-"); ok { - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKey(labelName), string(value)) + bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value)) return true } // Checking for maxim headers if labelName, ok := strings.CutPrefix(keyStr, "x-bf-maxim-"); ok { switch labelName { case string(maxim.GenerationIDKey): - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKey(labelName), string(value)) + bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value)) case string(maxim.TraceIDKey): - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKey(labelName), string(value)) + bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value)) case string(maxim.SessionIDKey): - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKey(labelName), string(value)) + bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value)) case string(maxim.TraceNameKey): - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKey(labelName), string(value)) + bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value)) case string(maxim.GenerationNameKey): - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKey(labelName), string(value)) + bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value)) case string(maxim.LogRepoIDKey): - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKey(labelName), string(value)) + bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value)) default: // apart from these all headers starting with x-bf-maxim- are keys for tags // collect them in the maximTags map @@ -201,13 +201,13 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, hea } } } - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKey("mcp-"+labelName), parsedValues) + bifrostCtx.SetValue(schemas.BifrostContextKey("mcp-"+labelName), parsedValues) return true } } // Handle virtual key header (x-bf-vk, authorization, x-api-key, x-goog-api-key headers) if keyStr == string(schemas.BifrostContextKeyVirtualKey) { - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKeyVirtualKey, string(value)) + bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, string(value)) return true } if keyStr == "authorization" { @@ -216,28 +216,28 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, hea if strings.HasPrefix(strings.ToLower(valueStr), "bearer ") { authHeaderValue := strings.TrimSpace(valueStr[7:]) // Remove "Bearer " prefix if authHeaderValue != "" && strings.HasPrefix(strings.ToLower(authHeaderValue), governance.VirtualKeyPrefix) { - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKeyVirtualKey, authHeaderValue) + bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, authHeaderValue) return true } } } if keyStr == "x-api-key" && strings.HasPrefix(strings.ToLower(string(value)), governance.VirtualKeyPrefix) { - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKeyVirtualKey, string(value)) + bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, string(value)) return true } if keyStr == "x-goog-api-key" && strings.HasPrefix(strings.ToLower(string(value)), governance.VirtualKeyPrefix) { - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKeyVirtualKey, string(value)) + bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, string(value)) return true } if keyStr == "x-bf-api-key" { if keyName := strings.TrimSpace(string(value)); keyName != "" { - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKeyAPIKeyName, keyName) + bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyName, keyName) } return true } // Handle cache key header (x-bf-cache-key) if keyStr == "x-bf-cache-key" { - bifrostCtx = context.WithValue(bifrostCtx, semanticcache.CacheKey, string(value)) + bifrostCtx.SetValue(semanticcache.CacheKey, string(value)) return true } // Handle cache TTL header (x-bf-cache-ttl) @@ -256,7 +256,7 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, hea } if err == nil { - bifrostCtx = context.WithValue(bifrostCtx, semanticcache.CacheTTLKey, ttlDuration) + bifrostCtx.SetValue(semanticcache.CacheTTLKey, ttlDuration) } // If both parsing attempts fail, we silently ignore the header and use default TTL return true @@ -271,20 +271,20 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, hea } else if threshold > 1.0 { threshold = 1.0 } - bifrostCtx = context.WithValue(bifrostCtx, semanticcache.CacheThresholdKey, threshold) + bifrostCtx.SetValue(semanticcache.CacheThresholdKey, threshold) } // If parsing fails, silently ignore the header (no context value set) return true } // Cache type header if keyStr == "x-bf-cache-type" { - bifrostCtx = context.WithValue(bifrostCtx, semanticcache.CacheTypeKey, semanticcache.CacheType(string(value))) + bifrostCtx.SetValue(semanticcache.CacheTypeKey, semanticcache.CacheType(string(value))) return true } // Cache no store header if keyStr == "x-bf-cache-no-store" { if valueStr := string(value); valueStr == "true" { - bifrostCtx = context.WithValue(bifrostCtx, semanticcache.CacheNoStoreKey, true) + bifrostCtx.SetValue(semanticcache.CacheNoStoreKey, true) } return true } @@ -339,7 +339,7 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, hea // Send back raw response header if keyStr == "x-bf-send-back-raw-response" { if valueStr := string(value); valueStr == "true" { - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKeySendBackRawResponse, true) + bifrostCtx.SetValue(schemas.BifrostContextKeySendBackRawResponse, true) } return true } @@ -348,12 +348,12 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, hea // Store the collected maxim tags in the context if len(maximTags) > 0 { - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKey(maxim.TagsKey), maximTags) + bifrostCtx.SetValue(schemas.BifrostContextKey(maxim.TagsKey), maximTags) } // Store collected extra headers in the context if any were found if len(extraHeaders) > 0 { - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKeyExtraHeaders, extraHeaders) + bifrostCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders) } if allowDirectKeys { @@ -397,13 +397,13 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, hea Models: []string{}, // Empty models list - will be validated by provider Weight: 1.0, // Default weight } - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKeyDirectKey, key) + bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key) } } // Adding fallback context if ctx.UserValue(schemas.BifrostContextKey("x-litellm-fallback")) != nil { - bifrostCtx = context.WithValue(bifrostCtx, schemas.BifrostContextKey("x-litellm-fallback"), "true") + bifrostCtx.SetValue(schemas.BifrostContextKey("x-litellm-fallback"), "true") } - return &bifrostCtx, cancel + return bifrostCtx, cancel } diff --git a/transports/bifrost-http/lib/lib.go b/transports/bifrost-http/lib/lib.go index 4669aca217..6562d1c72f 100644 --- a/transports/bifrost-http/lib/lib.go +++ b/transports/bifrost-http/lib/lib.go @@ -1,7 +1,11 @@ package lib import ( + "context" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + "github.com/maximhq/bifrost/framework/modelcatalog" ) var logger schemas.Logger @@ -10,3 +14,9 @@ var logger schemas.Logger func SetLogger(l schemas.Logger) { logger = l } + +type EnterpriseOverrides interface { + GetGovernancePluginName() string + LoadGovernancePlugin(ctx context.Context, config *Config) (schemas.Plugin, error) + LoadPricingManager(ctx context.Context, pricingConfig *modelcatalog.Config, configStore configstore.ConfigStore) (*modelcatalog.ModelCatalog, error) +} diff --git a/transports/bifrost-http/lib/middleware.go b/transports/bifrost-http/lib/middleware.go index c1657c6aa1..6ff0346441 100644 --- a/transports/bifrost-http/lib/middleware.go +++ b/transports/bifrost-http/lib/middleware.go @@ -1,15 +1,14 @@ package lib -import "github.com/valyala/fasthttp" - -// BifrostHTTPMiddleware is a middleware function for the Bifrost HTTP transport -// It follows the standard pattern: receives the next handler and returns a new handler -type BifrostHTTPMiddleware func(next fasthttp.RequestHandler) fasthttp.RequestHandler +import ( + "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" +) // ChainMiddlewares chains multiple middlewares together // Middlewares are applied in order: the first middleware wraps the second, etc. // This allows earlier middlewares to short-circuit by not calling next(ctx) -func ChainMiddlewares(handler fasthttp.RequestHandler, middlewares ...BifrostHTTPMiddleware) fasthttp.RequestHandler { +func ChainMiddlewares(handler fasthttp.RequestHandler, middlewares ...schemas.BifrostHTTPMiddleware) fasthttp.RequestHandler { // If no middlewares, return the original handler if len(middlewares) == 0 { return handler diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 5fbb86d230..074718ef73 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -18,12 +18,15 @@ import ( "github.com/bytedance/sonic" "github.com/fasthttp/router" + "github.com/google/uuid" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/framework/modelcatalog" dynamicPlugins "github.com/maximhq/bifrost/framework/plugins" + "github.com/maximhq/bifrost/framework/tracing" "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/plugins/maxim" @@ -65,25 +68,22 @@ type ServerCallbacks interface { ReloadProxyConfig(ctx context.Context, config *tables.GlobalProxyConfig) error ReloadHeaderFilterConfig(ctx context.Context, config *tables.GlobalHeaderFilterConfig) error UpdateDropExcessRequests(ctx context.Context, value bool) + UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error ReloadTeam(ctx context.Context, id string) (*tables.TableTeam, error) RemoveTeam(ctx context.Context, id string) error ReloadCustomer(ctx context.Context, id string) (*tables.TableCustomer, error) RemoveCustomer(ctx context.Context, id string) error ReloadVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error) RemoveVirtualKey(ctx context.Context, id string) error + GetGovernanceData() *governance.GovernanceData AddMCPClient(ctx context.Context, clientConfig schemas.MCPClientConfig) error RemoveMCPClient(ctx context.Context, id string) error EditMCPClient(ctx context.Context, id string, updatedConfig schemas.MCPClientConfig) error } -var ( - BifrostContextKeyBudgetIDs schemas.BifrostContextKey = "budget_ids" - BifrostContextKeyBudgetID schemas.BifrostContextKey = "budget_id" -) - // BifrostHTTPServer represents a HTTP server instance. type BifrostHTTPServer struct { - ctx context.Context + ctx *schemas.BifrostContext cancel context.CancelFunc Version string @@ -101,6 +101,8 @@ type BifrostHTTPServer struct { pluginStatusMutex sync.RWMutex pluginStatus []schemas.PluginStatus + tracingMiddleware *handlers.TracingMiddleware + Client *bifrost.Bifrost Config *lib.Config @@ -108,6 +110,8 @@ type BifrostHTTPServer struct { Router *router.Router WebSocketHandler *handlers.WebSocketHandler LogsCleaner *logstore.LogsCleaner + MCPServerHandler *handlers.MCPServerHandler + devPprofHandler *handlers.DevPprofHandler AuthMiddleware *handlers.AuthMiddleware } @@ -200,14 +204,14 @@ func MarshalPluginConfig[T any](source any) (*T, error) { } type GovernanceInMemoryStore struct { - config *lib.Config + Config *lib.Config } func (s *GovernanceInMemoryStore) GetConfiguredProviders() map[schemas.ModelProvider]configstore.ProviderConfig { // Use read lock for thread-safe access - no need to copy on hot path - s.config.Mu.RLock() - defer s.config.Mu.RUnlock() - return s.config.Providers + s.Config.Mu.RLock() + defer s.Config.Mu.RUnlock() + return s.Config.Providers } // LoadPlugin loads a plugin by name and returns it as type T. @@ -216,7 +220,7 @@ func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, path *string if path != nil { logger.Info("loading dynamic plugin %s from path %s", name, *path) // Load dynamic plugin - plugins, err := dynamicPlugins.LoadPlugins(&dynamicPlugins.Config{ + plugins, err := dynamicPlugins.LoadPlugins(bifrostConfig.PluginLoader, &dynamicPlugins.Config{ Plugins: []dynamicPlugins.DynamicPluginConfig{ { Path: *path, @@ -268,7 +272,7 @@ func LoadPlugin[T schemas.Plugin](ctx context.Context, name string, path *string return zero, fmt.Errorf("failed to marshal governance plugin config: %v", err) } inMemoryStore := &GovernanceInMemoryStore{ - config: bifrostConfig, + Config: bifrostConfig, } plugin, err := governance.Init(ctx, governanceConfig, logger, bifrostConfig.ConfigStore, bifrostConfig.GovernanceConfig, bifrostConfig.PricingManager, inMemoryStore) if err != nil { @@ -374,10 +378,9 @@ func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, []s }) } // Initializing governance plugin - var governancePlugin *governance.GovernancePlugin if config.ClientConfig.EnableGovernance { // Initialize governance plugin - governancePlugin, err = LoadPlugin[*governance.GovernancePlugin](ctx, governance.PluginName, nil, &governance.Config{ + governancePlugin, err := LoadPlugin[*governance.GovernancePlugin](ctx, governance.PluginName, nil, &governance.Config{ IsVkMandatory: &config.ClientConfig.EnforceGovernanceHeader, }, config) if err != nil { @@ -387,7 +390,7 @@ func LoadPlugins(ctx context.Context, config *lib.Config) ([]schemas.Plugin, []s Status: schemas.PluginStatusError, Logs: []string{fmt.Sprintf("error initializing governance plugin %v", err)}, }) - } else { + } else if governancePlugin != nil { plugins = append(plugins, governancePlugin) pluginStatus = append(pluginStatus, schemas.PluginStatus{ Name: governance.PluginName, @@ -460,60 +463,112 @@ func FindPluginByName[T schemas.Plugin](plugins []schemas.Plugin, name string) ( // AddMCPClient adds a new MCP client to the in-memory store func (s *BifrostHTTPServer) AddMCPClient(ctx context.Context, clientConfig schemas.MCPClientConfig) error { - return s.Config.AddMCPClient(ctx, clientConfig) + if err := s.Config.AddMCPClient(ctx, clientConfig); err != nil { + return err + } + if err := s.MCPServerHandler.SyncAllMCPServers(ctx); err != nil { + logger.Warn("failed to sync MCP servers after adding client: %v", err) + } + return nil +} + +// EditMCPClient edits an MCP client in the in-memory store +func (s *BifrostHTTPServer) EditMCPClient(ctx context.Context, id string, updatedConfig schemas.MCPClientConfig) error { + if err := s.Config.EditMCPClient(ctx, id, updatedConfig); err != nil { + return err + } + if err := s.MCPServerHandler.SyncAllMCPServers(ctx); err != nil { + logger.Warn("failed to sync MCP servers after editing client: %v", err) + } + return nil } // RemoveMCPClient removes an MCP client from the in-memory store func (s *BifrostHTTPServer) RemoveMCPClient(ctx context.Context, id string) error { - return s.Config.RemoveMCPClient(ctx, id) + if err := s.Config.RemoveMCPClient(ctx, id); err != nil { + return err + } + if err := s.MCPServerHandler.SyncAllMCPServers(ctx); err != nil { + logger.Warn("failed to sync MCP servers after removing client: %v", err) + } + return nil } -// EditMCPClient edits an MCP client in the in-memory store -func (s *BifrostHTTPServer) EditMCPClient(ctx context.Context, id string, updatedConfig schemas.MCPClientConfig) error { - return s.Config.EditMCPClient(ctx, id, updatedConfig) +// ExecuteChatMCPTool executes an MCP tool call and returns the result as a chat message. +func (s *BifrostHTTPServer) ExecuteChatMCPTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError) { + bifrostCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return s.Client.ExecuteChatMCPTool(bifrostCtx, toolCall) +} + +// ExecuteResponsesMCPTool executes an MCP tool call and returns the result as a responses message. +func (s *BifrostHTTPServer) ExecuteResponsesMCPTool(ctx context.Context, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, *schemas.BifrostError) { + bifrostCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return s.Client.ExecuteResponsesMCPTool(bifrostCtx, toolCall) +} + +func (s *BifrostHTTPServer) GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool { + return s.Client.GetAvailableMCPTools(ctx) +} + +// getGovernancePlugin safely retrieves the governance plugin with proper locking. +// It acquires a read lock, finds the plugin, releases the lock, performs type assertion, +// and returns the BaseGovernancePlugin implementation or an error. +func (s *BifrostHTTPServer) getGovernancePlugin() (governance.BaseGovernancePlugin, error) { + s.PluginsMutex.RLock() + plugin, err := FindPluginByName[schemas.Plugin](s.Plugins, governance.PluginName) + s.PluginsMutex.RUnlock() + if err != nil { + return nil, err + } + if plugin == nil { + return nil, fmt.Errorf("governance plugin not found") + } + governancePlugin, ok := plugin.(governance.BaseGovernancePlugin) + if !ok { + return nil, fmt.Errorf("governance plugin does not implement BaseGovernancePlugin") + } + return governancePlugin, nil } // ReloadVirtualKey reloads a virtual key from the in-memory store func (s *BifrostHTTPServer) ReloadVirtualKey(ctx context.Context, id string) (*tables.TableVirtualKey, error) { // Load relationships for response - preloadedVk, err := s.Config.ConfigStore.GetVirtualKey(ctx, id) + preloadedVk, err := s.Config.ConfigStore.RetryOnNotFound(ctx, func(ctx context.Context) (any, error) { + preloadedVk, err := s.Config.ConfigStore.GetVirtualKey(ctx, id) + if err != nil { + return nil, err + } + return preloadedVk, nil + }, lib.DBLookupMaxRetries, lib.DBLookupDelay) if err != nil { - logger.Error("failed to load relationships for created VK: %v", err) + logger.Error("failed to load virtual key: %v", err) return nil, err } - governancePlugin, err := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) - if err != nil { - return nil, err + if preloadedVk == nil { + logger.Error("virtual key not found") + return nil, fmt.Errorf("virtual key not found") } - if governancePlugin == nil { - return nil, fmt.Errorf("governance plugin not found") + // Type assertion (should never happen) + virtualKey, ok := preloadedVk.(*tables.TableVirtualKey) + if !ok { + logger.Error("virtual key type assertion failed") + return nil, fmt.Errorf("virtual key type assertion failed") } - // Add to in-memory store - governancePlugin.GetGovernanceStore().UpdateVirtualKeyInMemory(preloadedVk) - // If budget was created, add it to in-memory store - if preloadedVk.BudgetID != nil && preloadedVk.Budget != nil { - governancePlugin.GetGovernanceStore().UpdateBudgetInMemory(preloadedVk.Budget) - } - // Add provider-level budgets to in-memory store - if preloadedVk.ProviderConfigs != nil { - for _, pc := range preloadedVk.ProviderConfigs { - if pc.BudgetID != nil && pc.Budget != nil { - governancePlugin.GetGovernanceStore().UpdateBudgetInMemory(pc.Budget) - } - } + governancePlugin, err := s.getGovernancePlugin() + if err != nil { + return nil, err } - return preloadedVk, nil + governancePlugin.GetGovernanceStore().UpdateVirtualKeyInMemory(virtualKey, nil, nil, nil) + s.MCPServerHandler.SyncVKMCPServer(virtualKey) + return virtualKey, nil } // RemoveVirtualKey removes a virtual key from the in-memory store func (s *BifrostHTTPServer) RemoveVirtualKey(ctx context.Context, id string) error { - governancePlugin, err := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) + governancePlugin, err := s.getGovernancePlugin() if err != nil { return err } - if governancePlugin == nil { - return fmt.Errorf("governance plugin not found") - } preloadedVk, err := s.Config.ConfigStore.GetVirtualKey(ctx, id) if err != nil { if !errors.Is(err, configstore.ErrNotFound) { @@ -523,26 +578,10 @@ func (s *BifrostHTTPServer) RemoveVirtualKey(ctx context.Context, id string) err if preloadedVk == nil { // This could be broadcast message from other server, so we will just clean up in-memory store governancePlugin.GetGovernanceStore().DeleteVirtualKeyInMemory(id) - if budgetIDs, ok := ctx.Value(BifrostContextKeyBudgetIDs).([]string); ok { - for _, budgetID := range budgetIDs { - governancePlugin.GetGovernanceStore().DeleteBudgetInMemory(budgetID) - } - } return nil } governancePlugin.GetGovernanceStore().DeleteVirtualKeyInMemory(id) - // If budget was created, delete it from in-memory store - if preloadedVk.BudgetID != nil && preloadedVk.Budget != nil { - governancePlugin.GetGovernanceStore().DeleteBudgetInMemory(*preloadedVk.BudgetID) - } - // Delete provider-level budgets from in-memory store - if preloadedVk.ProviderConfigs != nil { - for _, pc := range preloadedVk.ProviderConfigs { - if pc.BudgetID != nil && pc.Budget != nil { - governancePlugin.GetGovernanceStore().DeleteBudgetInMemory(*pc.BudgetID) - } - } - } + s.MCPServerHandler.DeleteVKMCPServer(preloadedVk.Value) return nil } @@ -554,31 +593,21 @@ func (s *BifrostHTTPServer) ReloadTeam(ctx context.Context, id string) (*tables. logger.Error("failed to load relationships for created team: %v", err) return nil, err } - governancePlugin, err := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) + governancePlugin, err := s.getGovernancePlugin() if err != nil { return nil, err } - if governancePlugin == nil { - return nil, fmt.Errorf("governance plugin not found") - } // Add to in-memory store - governancePlugin.GetGovernanceStore().UpdateTeamInMemory(preloadedTeam) - // If budget was created, add it to in-memory store - if preloadedTeam.BudgetID != nil && preloadedTeam.Budget != nil { - governancePlugin.GetGovernanceStore().UpdateBudgetInMemory(preloadedTeam.Budget) - } + governancePlugin.GetGovernanceStore().UpdateTeamInMemory(preloadedTeam, nil) return preloadedTeam, nil } // RemoveTeam removes a team from the in-memory store func (s *BifrostHTTPServer) RemoveTeam(ctx context.Context, id string) error { - governancePlugin, err := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) + governancePlugin, err := s.getGovernancePlugin() if err != nil { return err } - if governancePlugin == nil { - return fmt.Errorf("governance plugin not found") - } preloadedTeam, err := s.Config.ConfigStore.GetTeam(ctx, id) if err != nil { if !errors.Is(err, configstore.ErrNotFound) { @@ -588,16 +617,9 @@ func (s *BifrostHTTPServer) RemoveTeam(ctx context.Context, id string) error { if preloadedTeam == nil { // At-least deleting from in-memory store to avoid conflicts governancePlugin.GetGovernanceStore().DeleteTeamInMemory(id) - if budgetID, ok := ctx.Value(BifrostContextKeyBudgetID).(string); ok { - governancePlugin.GetGovernanceStore().DeleteBudgetInMemory(budgetID) - } return nil } governancePlugin.GetGovernanceStore().DeleteTeamInMemory(id) - // If budget was created, delete it from in-memory store - if preloadedTeam.BudgetID != nil && preloadedTeam.Budget != nil { - governancePlugin.GetGovernanceStore().DeleteBudgetInMemory(*preloadedTeam.BudgetID) - } return nil } @@ -607,31 +629,21 @@ func (s *BifrostHTTPServer) ReloadCustomer(ctx context.Context, id string) (*tab if err != nil { return nil, err } - governancePlugin, err := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) + governancePlugin, err := s.getGovernancePlugin() if err != nil { return nil, err } - if governancePlugin == nil { - return nil, fmt.Errorf("governance plugin not found") - } // Add to in-memory store - governancePlugin.GetGovernanceStore().UpdateCustomerInMemory(preloadedCustomer) - // If budget was created, add it to in-memory store - if preloadedCustomer.BudgetID != nil && preloadedCustomer.Budget != nil { - governancePlugin.GetGovernanceStore().UpdateBudgetInMemory(preloadedCustomer.Budget) - } + governancePlugin.GetGovernanceStore().UpdateCustomerInMemory(preloadedCustomer, nil) return preloadedCustomer, nil } // RemoveCustomer removes a customer from the in-memory store func (s *BifrostHTTPServer) RemoveCustomer(ctx context.Context, id string) error { - governancePlugin, err := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) + governancePlugin, err := s.getGovernancePlugin() if err != nil { return err } - if governancePlugin == nil { - return fmt.Errorf("governance plugin not found") - } preloadedCustomer, err := s.Config.ConfigStore.GetCustomer(ctx, id) if err != nil { if !errors.Is(err, configstore.ErrNotFound) { @@ -641,15 +653,23 @@ func (s *BifrostHTTPServer) RemoveCustomer(ctx context.Context, id string) error if preloadedCustomer == nil { // At-least deleting from in-memory store to avoid conflicts governancePlugin.GetGovernanceStore().DeleteCustomerInMemory(id) - if budgetID, ok := ctx.Value(BifrostContextKeyBudgetID).(string); ok { - governancePlugin.GetGovernanceStore().DeleteBudgetInMemory(budgetID) - } return nil } governancePlugin.GetGovernanceStore().DeleteCustomerInMemory(id) - // If budget was created, delete it from in-memory store - if preloadedCustomer.BudgetID != nil && preloadedCustomer.Budget != nil { - governancePlugin.GetGovernanceStore().DeleteBudgetInMemory(*preloadedCustomer.BudgetID) + return nil +} + +// GetGovernanceData returns the governance data +func (s *BifrostHTTPServer) GetGovernanceData() *governance.GovernanceData { + s.PluginsMutex.RLock() + governancePlugin, err := FindPluginByName[schemas.Plugin](s.Plugins, governance.PluginName) + s.PluginsMutex.RUnlock() + if err != nil { + return nil + } + // Check if GetGovernanceStore method is implemented + if governancePlugin, ok := governancePlugin.(governance.BaseGovernancePlugin); ok { + return governancePlugin.GetGovernanceStore().GetGovernanceData() } return nil } @@ -718,6 +738,14 @@ func (s *BifrostHTTPServer) UpdateDropExcessRequests(ctx context.Context, value s.Client.UpdateDropExcessRequests(value) } +// UpdateMCPToolManagerConfig updates the MCP tool manager config +func (s *BifrostHTTPServer) UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error { + if s.Config == nil { + return fmt.Errorf("config not found") + } + return s.Client.UpdateToolManagerConfig(maxAgentDepth, toolExecutionTimeoutInSeconds, codeModeBindingLevel) +} + // UpdatePluginStatus updates the status of a plugin func (s *BifrostHTTPServer) UpdatePluginStatus(name string, status string, logs []string) error { s.pluginStatusMutex.Lock() @@ -824,7 +852,26 @@ func (s *BifrostHTTPServer) ReloadPlugin(ctx context.Context, name string, path s.UpdatePluginStatus(name, schemas.PluginStatusError, []string{fmt.Sprintf("error loading plugin %s: %v", name, err)}) return err } - return s.SyncLoadedPlugin(ctx, name, newPlugin) + err = s.SyncLoadedPlugin(ctx, name, newPlugin) + if err != nil { + return err + } + // Here if its observability plugin, we need to reload it + if _, ok := newPlugin.(schemas.ObservabilityPlugin); ok { + // We will re-collect the observability plugins from the plugins list + observabilityPlugins := []schemas.ObservabilityPlugin{} + s.PluginsMutex.RLock() + defer s.PluginsMutex.RUnlock() + for _, plugin := range s.Plugins { + if observabilityPlugin, ok := plugin.(schemas.ObservabilityPlugin); ok { + observabilityPlugins = append(observabilityPlugins, observabilityPlugin) + } + } + if len(observabilityPlugins) > 0 { + s.tracingMiddleware.SetObservabilityPlugins(observabilityPlugins) + } + } + return nil } // ReloadPricingManager reloads the pricing manager @@ -882,7 +929,9 @@ func (s *BifrostHTTPServer) RefetchModelsForProvider(ctx context.Context, provid if s.Client == nil { return fmt.Errorf("bifrost client not found") } - allModels, err := s.Client.ListModelsRequest(ctx, &schemas.BifrostListModelsRequest{ + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + defer bfCtx.Cancel() + allModels, err := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ Provider: provider, }) if err != nil { @@ -914,6 +963,8 @@ func (s *BifrostHTTPServer) GetModelsForProvider(provider schemas.ModelProvider) // RemovePlugin removes a plugin from the server. // Uses atomic CompareAndSwap with retry loop to handle concurrent updates safely. func (s *BifrostHTTPServer) RemovePlugin(ctx context.Context, name string) error { + // Get plugin + plugin, _ := FindPluginByName[schemas.Plugin](s.Plugins, name) if err := s.Client.RemovePlugin(name); err != nil { return err } @@ -924,6 +975,7 @@ func (s *BifrostHTTPServer) RemovePlugin(ctx context.Context, name string) error // Plugin is being deleted - remove the status entry completely s.DeletePluginStatus(name) } + // CAS retry loop (matching bifrost.go pattern) for { oldPlugins := s.Config.Plugins.Load() @@ -941,16 +993,30 @@ func (s *BifrostHTTPServer) RemovePlugin(ctx context.Context, name string) error // Atomic compare-and-swap if s.Config.Plugins.CompareAndSwap(oldPlugins, &newPlugins) { s.PluginsMutex.Lock() - defer s.PluginsMutex.Unlock() s.Plugins = newPlugins // Keep BifrostHTTPServer.Plugins in sync - return nil + s.PluginsMutex.Unlock() + break } // Retry on contention (extremely rare for plugin updates) } + // Here if its observability plugin, we need to reload it + if _, ok := plugin.(schemas.ObservabilityPlugin); ok { + // We will re-collect the observability plugins from the plugins list + observabilityPlugins := []schemas.ObservabilityPlugin{} + s.PluginsMutex.RLock() + defer s.PluginsMutex.RUnlock() + for _, plugin := range s.Plugins { + if observabilityPlugin, ok := plugin.(schemas.ObservabilityPlugin); ok { + observabilityPlugins = append(observabilityPlugins, observabilityPlugin) + } + } + s.tracingMiddleware.SetObservabilityPlugins(observabilityPlugins) + } + return nil } // RegisterInferenceRoutes initializes the routes for the inference handler -func (s *BifrostHTTPServer) RegisterInferenceRoutes(ctx context.Context, middlewares ...lib.BifrostHTTPMiddleware) error { +func (s *BifrostHTTPServer) RegisterInferenceRoutes(ctx context.Context, middlewares ...schemas.BifrostHTTPMiddleware) error { inferenceHandler := handlers.NewInferenceHandler(s.Client, s.Config) integrationHandler := handlers.NewIntegrationHandler(s.Client, s.Config) @@ -960,7 +1026,7 @@ func (s *BifrostHTTPServer) RegisterInferenceRoutes(ctx context.Context, middlew } // RegisterAPIRoutes initializes the routes for the Bifrost HTTP server. -func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks ServerCallbacks, middlewares ...lib.BifrostHTTPMiddleware) error { +func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks ServerCallbacks, middlewares ...schemas.BifrostHTTPMiddleware) error { var err error // Initializing plugin specific handlers var loggingHandler *handlers.LoggingHandler @@ -969,7 +1035,7 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser loggingHandler = handlers.NewLoggingHandler(loggerPlugin.GetPluginLogManager(), s) } var governanceHandler *handlers.GovernanceHandler - governancePlugin, _ := FindPluginByName[*governance.GovernancePlugin](s.Plugins, governance.PluginName) + governancePlugin, _ := FindPluginByName[schemas.Plugin](s.Plugins, governance.PluginName) if governancePlugin != nil { governanceHandler, err = handlers.NewGovernanceHandler(callbacks, s.Config.ConfigStore) if err != nil { @@ -997,6 +1063,11 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser healthHandler := handlers.NewHealthHandler(s.Config) providerHandler := handlers.NewProviderHandler(callbacks, s.Config, s.Client) mcpHandler := handlers.NewMCPHandler(callbacks, s.Client, s.Config) + mcpServerHandler, err := handlers.NewMCPServerHandler(ctx, s.Config, s) + if err != nil { + return fmt.Errorf("failed to initialize mcp server handler: %v", err) + } + s.MCPServerHandler = mcpServerHandler configHandler := handlers.NewConfigHandler(callbacks, s.Config) pluginsHandler := handlers.NewPluginsHandler(callbacks, s.Config.ConfigStore) sessionHandler := handlers.NewSessionHandler(s.Config.ConfigStore) @@ -1004,6 +1075,7 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser healthHandler.RegisterRoutes(s.Router, middlewares...) providerHandler.RegisterRoutes(s.Router, middlewares...) mcpHandler.RegisterRoutes(s.Router, middlewares...) + mcpServerHandler.RegisterRoutes(s.Router, middlewares...) configHandler.RegisterRoutes(s.Router, middlewares...) if pluginsHandler != nil { pluginsHandler.RegisterRoutes(s.Router, middlewares...) @@ -1023,6 +1095,12 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser if s.WebSocketHandler != nil { s.WebSocketHandler.RegisterRoutes(s.Router, middlewares...) } + // Register dev pprof handler only in dev mode + if handlers.IsDevMode() { + logger.Info("dev mode enabled, registering pprof endpoints") + s.devPprofHandler = handlers.NewDevPprofHandler() + s.devPprofHandler.RegisterRoutes(s.Router, middlewares...) + } // Add Prometheus /metrics endpoint prometheusPlugin, err := FindPluginByName[*telemetry.PrometheusPlugin](s.Plugins, telemetry.PluginName) if err == nil && prometheusPlugin.GetRegistry() != nil { @@ -1040,7 +1118,7 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser } // RegisterUIRoutes registers the UI handler with the specified router -func (s *BifrostHTTPServer) RegisterUIRoutes(middlewares ...lib.BifrostHTTPMiddleware) { +func (s *BifrostHTTPServer) RegisterUIRoutes(middlewares ...schemas.BifrostHTTPMiddleware) { // WARNING: This UI handler needs to be registered after all the other handlers handlers.NewUIHandler(s.UIContent).RegisterRoutes(s.Router, middlewares...) } @@ -1071,9 +1149,17 @@ func (s *BifrostHTTPServer) GetAllRedactedVirtualKeys(ctx context.Context, ids [ return virtualKeys } +func (s *BifrostHTTPServer) LoadPricingManager(ctx context.Context, pricingConfig *modelcatalog.Config, configStore configstore.ConfigStore) (*modelcatalog.ModelCatalog, error) { + pricingManager, err := modelcatalog.Init(ctx, pricingConfig, configStore, nil, logger) + if err != nil { + return nil, fmt.Errorf("failed to initialize pricing manager: %v", err) + } + return pricingManager, nil +} + // PrepareCommonMiddlewares gets the common middlewares for the Bifrost HTTP server -func (s *BifrostHTTPServer) PrepareCommonMiddlewares() []lib.BifrostHTTPMiddleware { - commonMiddlewares := []lib.BifrostHTTPMiddleware{} +func (s *BifrostHTTPServer) PrepareCommonMiddlewares() []schemas.BifrostHTTPMiddleware { + commonMiddlewares := []schemas.BifrostHTTPMiddleware{} // Preparing middlewares // Initializing prometheus plugin prometheusPlugin, err := FindPluginByName[*telemetry.PrometheusPlugin](s.Plugins, telemetry.PluginName) @@ -1098,7 +1184,7 @@ func (s *BifrostHTTPServer) PrepareCommonMiddlewares() []lib.BifrostHTTPMiddlewa // - GET /metrics: For Prometheus metrics func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { var err error - s.ctx, s.cancel = context.WithCancel(ctx) + s.ctx, s.cancel = schemas.NewBifrostContextWithCancel(ctx) handlers.SetVersion(s.Version) configDir := GetDefaultConfigDir(s.AppDir) s.pluginStatusMutex = sync.RWMutex{} @@ -1112,6 +1198,8 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to load config %v", err) } + // Initializing plugin loader + s.Config.PluginLoader = &dynamicPlugins.SharedObjectPluginLoader{} // Initialize log retention cleaner if log store is configured if s.Config.LogsStore != nil { // If log retention days remains 0, then we wont be initializing the log retention cleaner @@ -1151,6 +1239,12 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to load plugins %v", err) } + mcpConfig := s.Config.MCPConfig + if mcpConfig != nil { + mcpConfig.FetchNewRequestIDFunc = func(ctx *schemas.BifrostContext) string { + return uuid.New().String() + } + } // Initialize bifrost client // Create account backed by the high-performance store (all processing is done in LoadFromDatabase) // The account interface now benefits from ultra-fast config access times via in-memory storage @@ -1160,7 +1254,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { InitialPoolSize: s.Config.ClientConfig.InitialPoolSize, DropExcessRequests: s.Config.ClientConfig.DropExcessRequests, Plugins: s.Plugins, - MCPConfig: s.Config.MCPConfig, + MCPConfig: mcpConfig, Logger: logger, }) if err != nil { @@ -1169,7 +1263,7 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { logger.Info("bifrost client initialized") // List all models and add to model catalog logger.Info("listing all models and adding to model catalog") - modelData, listModelsErr := s.Client.ListAllModels(ctx, nil) + modelData, listModelsErr := s.Client.ListAllModels(s.ctx, nil) if listModelsErr != nil { if listModelsErr.Error != nil { logger.Error("failed to list all models: %s", listModelsErr.Error.Message) @@ -1208,7 +1302,23 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { inferenceMiddlewares = append(inferenceMiddlewares, s.AuthMiddleware.InferenceMiddleware()) } // Registering inference middlewares - inferenceMiddlewares = append([]lib.BifrostHTTPMiddleware{handlers.TransportInterceptorMiddleware(s.Config)}, inferenceMiddlewares...) + inferenceMiddlewares = append([]schemas.BifrostHTTPMiddleware{handlers.TransportInterceptorMiddleware(s.Config)}, inferenceMiddlewares...) + // Curating observability plugins + observabilityPlugins := []schemas.ObservabilityPlugin{} + for _, plugin := range s.Plugins { + if observabilityPlugin, ok := plugin.(schemas.ObservabilityPlugin); ok { + observabilityPlugins = append(observabilityPlugins, observabilityPlugin) + } + } + // This enables the central streaming accumulator for both use cases + // Initializing tracer with embedded streaming accumulator + traceStore := tracing.NewTraceStore(60*time.Minute, logger) + tracer := tracing.NewTracer(traceStore, s.Config.PricingManager, logger) + s.Client.SetTracer(tracer) + // Always add tracing middleware when tracer is enabled - it creates traces and sets traceID in context + // The observability plugins are optional (can be empty if only logging is enabled) + s.tracingMiddleware = handlers.NewTracingMiddleware(tracer, observabilityPlugins) + inferenceMiddlewares = append([]schemas.BifrostHTTPMiddleware{s.tracingMiddleware.Middleware()}, inferenceMiddlewares...) err = s.RegisterInferenceRoutes(s.ctx, inferenceMiddlewares...) if err != nil { return fmt.Errorf("failed to initialize inference routes: %v", err) @@ -1227,6 +1337,12 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { // Start starts the HTTP server at the specified host and port // Also watches signals and errors func (s *BifrostHTTPServer) Start() error { + // Printing plugin status in a table + s.pluginStatusMutex.RLock() + for _, pluginStatus := range s.pluginStatus { + logger.Info("plugin status: %s - %s", pluginStatus.Name, pluginStatus.Status) + } + s.pluginStatusMutex.RUnlock() // Create channels for signal and error handling sigChan := make(chan os.Signal, 1) errChan := make(chan error, 1) @@ -1280,6 +1396,10 @@ func (s *BifrostHTTPServer) Start() error { logger.Info("stopping log retention cleaner...") s.LogsCleaner.StopCleanupRoutine() } + if s.devPprofHandler != nil { + logger.Info("stopping dev pprof handler...") + s.devPprofHandler.Cleanup() + } if s.Config != nil && s.Config.LogsStore != nil { s.Config.LogsStore.Close(shutdownCtx) } diff --git a/transports/config.schema.json b/transports/config.schema.json index 91bb67b3cb..8b68a44c4d 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -416,6 +416,9 @@ "$ref": "#/$defs/mcp_client_config" }, "description": "MCP client configurations" + }, + "tool_manager_config": { + "$ref": "#/$defs/mcp_tool_manager_config" } }, "additionalProperties": false @@ -1781,6 +1784,23 @@ } ] }, + "mcp_tool_manager_config": { + "type": "object", + "properties": { + "tool_execution_timeout": { + "type": "integer", + "description": "Tool execution timeout in seconds", + "minimum": 1, + "default": 30 + }, + "max_agent_depth": { + "type": "integer", + "description": "Max agent depth", + "minimum": 1, + "default": 10 + } + } + }, "weaviate_config": { "type": "object", "description": "Weaviate configuration for vector store", diff --git a/transports/go.mod b/transports/go.mod index eb472a06b9..2b86bf9e19 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -7,15 +7,17 @@ require ( github.com/bytedance/sonic v1.14.2 github.com/fasthttp/router v1.5.4 github.com/fasthttp/websocket v1.5.12 + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f github.com/google/uuid v1.6.0 - github.com/maximhq/bifrost/core v1.2.49 - github.com/maximhq/bifrost/framework v1.1.61 - github.com/maximhq/bifrost/plugins/governance v1.3.62 - github.com/maximhq/bifrost/plugins/logging v1.3.62 - github.com/maximhq/bifrost/plugins/maxim v1.4.63 - github.com/maximhq/bifrost/plugins/otel v1.0.61 - github.com/maximhq/bifrost/plugins/semanticcache v1.3.61 - github.com/maximhq/bifrost/plugins/telemetry v1.3.61 + github.com/mark3labs/mcp-go v0.43.2 + github.com/maximhq/bifrost/core v1.3.8 + github.com/maximhq/bifrost/framework v1.2.8 + github.com/maximhq/bifrost/plugins/governance v1.4.9 + github.com/maximhq/bifrost/plugins/logging v1.4.8 + github.com/maximhq/bifrost/plugins/maxim v1.5.8 + github.com/maximhq/bifrost/plugins/otel v1.1.8 + github.com/maximhq/bifrost/plugins/semanticcache v1.4.8 + github.com/maximhq/bifrost/plugins/telemetry v1.4.9 github.com/prometheus/client_golang v1.23.0 github.com/stretchr/testify v1.11.1 github.com/valyala/fasthttp v1.68.0 @@ -54,9 +56,12 @@ require ( github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/clarkmcc/go-typescript v0.7.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/analysis v0.24.2 // indirect @@ -80,6 +85,7 @@ require ( github.com/go-openapi/swag/typeutils v0.25.4 // indirect github.com/go-openapi/swag/yamlutils v0.25.4 // indirect github.com/go-openapi/validate v0.25.1 // indirect + github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect @@ -95,12 +101,11 @@ require ( github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/mailru/easyjson v0.9.1 // indirect - github.com/mark3labs/mcp-go v0.43.2 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-sqlite3 v1.14.32 // indirect - github.com/maximhq/bifrost/plugins/mocker v1.3.60 // indirect - github.com/maximhq/maxim-go v0.1.15 // indirect + github.com/maximhq/bifrost/plugins/mocker v1.4.8 // indirect + github.com/maximhq/maxim-go v0.1.14 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect diff --git a/transports/go.sum b/transports/go.sum index 19dbd1bff7..46c352aeb9 100644 --- a/transports/go.sum +++ b/transports/go.sum @@ -12,6 +12,8 @@ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJ github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/Masterminds/semver/v3 v3.3.1 h1:QtNSWtVZ3nBfk8mAOu/B6v7FMJ+NHTIgUPi7rj+4nv4= +github.com/Masterminds/semver/v3 v3.3.1/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= @@ -70,6 +72,8 @@ github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2N github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/clarkmcc/go-typescript v0.7.0 h1:3nVeaPYyTCWjX6Lf8GoEOTxME2bM5tLuWmwhSZ86uxg= +github.com/clarkmcc/go-typescript v0.7.0/go.mod h1:IZ/nzoVeydAmyfX7l6Jmp8lJDOEnae3jffoXwP4UyYg= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= @@ -79,6 +83,10 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7 h1:jxmXU5V9tXxJnydU5v/m9SG8TRUa/Z7IXODBpMs/P+U= +github.com/dop251/goja v0.0.0-20251103141225-af2ceb9156d7/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= github.com/fasthttp/router v1.5.4 h1:oxdThbBwQgsDIYZ3wR1IavsNl6ZS9WdjKukeMikOnC8= github.com/fasthttp/router v1.5.4/go.mod h1:3/hysWq6cky7dTfzaaEPZGdptwjwx0qzTgFCKEWRjgc= github.com/fasthttp/websocket v1.5.12 h1:e4RGPpWW2HTbL3zV0Y/t7g0ub294LkiuXXUuTOUInlE= @@ -138,6 +146,8 @@ github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6 github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= github.com/go-openapi/validate v0.25.1 h1:sSACUI6Jcnbo5IWqbYHgjibrhhmt3vR6lCzKZnmAgBw= github.com/go-openapi/validate v0.25.1/go.mod h1:RMVyVFYte0gbSTaZ0N4KmTn6u/kClvAFp+mAVfS/DQc= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= +github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -147,6 +157,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU= @@ -194,26 +206,26 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.2.49 h1:fk6l6r3kVBlpN73wYXmgtV6O4bhedOjSO4LAEz/7leg= -github.com/maximhq/bifrost/core v1.2.49/go.mod h1:z7nOx15e91ktZGi+pZHq+uhShlEK+fM4UyYUpP6oHAw= -github.com/maximhq/bifrost/framework v1.1.61 h1:fMjvICbkrdWMtGnLYrjSNrcmQYqtQvOh/swmrJTvf+E= -github.com/maximhq/bifrost/framework v1.1.61/go.mod h1:wVUPzB8K5S/5GWuxqp8dXf3nNZkqJsS/APMIcq48SOI= -github.com/maximhq/bifrost/plugins/governance v1.3.62 h1:HKHtj1HxzmBn6Lan/HWvmC/Cnne4BVSwHn9kqDem4eM= -github.com/maximhq/bifrost/plugins/governance v1.3.62/go.mod h1:1poufEWNh0gmOGKUkjMTGCECUYhf0YBGnvmJAAwyq14= -github.com/maximhq/bifrost/plugins/logging v1.3.62 h1:pTnj3DudsUKzSaHfZxFAHk/Yz5CrKk1MdPVIwKPHHCI= -github.com/maximhq/bifrost/plugins/logging v1.3.62/go.mod h1:RIPJyB6Oft51to5zXc1xDJEAZWCuIts6lG1LeyDZTvw= -github.com/maximhq/bifrost/plugins/maxim v1.4.63 h1:550Q7MEwiKdvSnCzBw+kWQTo9vfDJG+mJaSUdNogoSE= -github.com/maximhq/bifrost/plugins/maxim v1.4.63/go.mod h1:N1l/ggh8Ys+ySkEyHOAz/ieRWi3FTgnbjb8S5q87tBo= -github.com/maximhq/bifrost/plugins/mocker v1.3.60 h1:iyAgQgvLh8KR9DOeN/gfi9Ie9yq8fWDpigTzaJOPuoQ= -github.com/maximhq/bifrost/plugins/mocker v1.3.60/go.mod h1:pWahxfU/fNCN3rLT431T0T5zkZs4QLuJToLGHH0nnQA= -github.com/maximhq/bifrost/plugins/otel v1.0.61 h1:hww7sTf6t5AgBLzdz17Jq4yZisTVmbMOEHB7800P3Cc= -github.com/maximhq/bifrost/plugins/otel v1.0.61/go.mod h1:Yf7+rd1tf7fiIzftyIxna2KFBHRbOMt14VPqiHNXdQ0= -github.com/maximhq/bifrost/plugins/semanticcache v1.3.61 h1:oz3g1EWf7Cexq9v47p/Pav4ILtar4tVJTT5722MrQkg= -github.com/maximhq/bifrost/plugins/semanticcache v1.3.61/go.mod h1:+h+NJA5cInuRRxZ3lFF2KOA8HEp2D3MakG03ByOKEX0= -github.com/maximhq/bifrost/plugins/telemetry v1.3.61 h1:3RZntuUOoAXrwTJ8r6YDd99pYlzm1t2LZcsPE88J22Y= -github.com/maximhq/bifrost/plugins/telemetry v1.3.61/go.mod h1:LTAR6ow7SjIGSRlVmcuKxUPHKycOO4dLmxJlhHHo9C8= -github.com/maximhq/maxim-go v0.1.15 h1:PCoS5B/0QB3VqwqpgDgCHSTaYPVVKp/mFpb7iZ09XM0= -github.com/maximhq/maxim-go v0.1.15/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw= +github.com/maximhq/bifrost/core v1.3.8 h1:xtwB9+HeTzYz5IKHkpUtupzBd0A5yl1avdLJGjsOKPI= +github.com/maximhq/bifrost/core v1.3.8/go.mod h1:abKQRnJQPZz8/UMxCcbuNHEyq19Db+IX4KlGJdlLY8E= +github.com/maximhq/bifrost/framework v1.2.8 h1:/oTpacuw7k0zRUJ9dSSQRtAVx3nLGSiR7GFwOjGxZNs= +github.com/maximhq/bifrost/framework v1.2.8/go.mod h1:mjw9YXh/Oxi3HeBCJ+3HJ6ftv43Wo4t0T4EzpcIbnr0= +github.com/maximhq/bifrost/plugins/governance v1.4.9 h1:xjL5X5Ueraisl70sc0SWb/Ws/DsMLsgnUWWljjy+JSQ= +github.com/maximhq/bifrost/plugins/governance v1.4.9/go.mod h1:81hyb2O7X66Rpv/brM1pNVLp1hxtF++DaW6PWaTdL2Q= +github.com/maximhq/bifrost/plugins/logging v1.4.8 h1:qAByLKr+HNtJhgkI28a41ufXuUJi+DFDccYh4G9Xou8= +github.com/maximhq/bifrost/plugins/logging v1.4.8/go.mod h1:FyhognaaekK+mwykMpd4AdZ+iXOuQR78+u5Jgsw+OHo= +github.com/maximhq/bifrost/plugins/maxim v1.5.8 h1:ZIke9WLETYMVKpkKWl/m5F7VCEVIyGM2h6eqhrrPSyM= +github.com/maximhq/bifrost/plugins/maxim v1.5.8/go.mod h1:JxLzdHvmuGYasTxKRdpVbEXIINCNq8RWHY5agbcVX78= +github.com/maximhq/bifrost/plugins/mocker v1.4.8 h1:wh5JgUBLzLOjlTpRZD2MO2tl8fNAvhlWd8V5i6Ot+8Q= +github.com/maximhq/bifrost/plugins/mocker v1.4.8/go.mod h1:ScVor8GWDOytkJ0U7nZ1HkrZ/xQbwkLtWSU0m1vpdSs= +github.com/maximhq/bifrost/plugins/otel v1.1.8 h1:QvS3MnMXJITcnqe9cEPff31ZS78B5bDC/jXzdjRYFvk= +github.com/maximhq/bifrost/plugins/otel v1.1.8/go.mod h1:jFnfaIypMZsaGt0RJMM4Gr1gAsHiklj3knge9+sN2C8= +github.com/maximhq/bifrost/plugins/semanticcache v1.4.8 h1:uysgne4FAsFdzv8yTjvzRoOmJLU6pPLwCjyISgxY9E0= +github.com/maximhq/bifrost/plugins/semanticcache v1.4.8/go.mod h1:9Fz57x9d6k8i+pioI6ut7AqgrzLxBfiH+MroGNWd2pc= +github.com/maximhq/bifrost/plugins/telemetry v1.4.9 h1:EfZJiZR/sutyMO8RpVJzf0Vs4MCF47cnHaRDhx+VMxM= +github.com/maximhq/bifrost/plugins/telemetry v1.4.9/go.mod h1:5dyMWDZjU15zk0PHnGXII81mKLSTvRdzuIbRj7IYtp0= +github.com/maximhq/maxim-go v0.1.14 h1:NQgpf3aRoD2Kq1GAqeSrLn3rQresn1H6mPP3JJ85qhA= +github.com/maximhq/maxim-go v0.1.14/go.mod h1:0+UTWM7UZwNNE5VnljLtr/vpRGtYP8r/2q9WDwlLWFw= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= @@ -327,6 +339,8 @@ google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/transports/version b/transports/version index 022360fe76..dac8c45cbe 100644 --- a/transports/version +++ b/transports/version @@ -1 +1 @@ -1.3.63 \ No newline at end of file +1.4.0-prerelease9 \ No newline at end of file diff --git a/ui/app/_fallbacks/enterprise/components/api-keys/APIKeysView.tsx b/ui/app/_fallbacks/enterprise/components/api-keys/apiKeysIndexView.tsx similarity index 93% rename from ui/app/_fallbacks/enterprise/components/api-keys/APIKeysView.tsx rename to ui/app/_fallbacks/enterprise/components/api-keys/apiKeysIndexView.tsx index 36ef7bf5a7..7929d9bea0 100644 --- a/ui/app/_fallbacks/enterprise/components/api-keys/APIKeysView.tsx +++ b/ui/app/_fallbacks/enterprise/components/api-keys/apiKeysIndexView.tsx @@ -5,13 +5,12 @@ import { Button } from "@/components/ui/button"; import { useGetCoreConfigQuery } from "@/lib/store"; import { Copy, InfoIcon, KeyRound } from "lucide-react"; import Link from "next/link"; -import { useMemo, useState } from "react"; +import { useMemo } from "react"; import { toast } from "sonner"; import ContactUsView from "../views/contactUsView"; export default function APIKeysView() { const { data: bifrostConfig, isLoading } = useGetCoreConfigQuery({ fromDB: true }); - const [isTokenVisible, setIsTokenVisible] = useState(false); const isAuthConfigure = useMemo(() => { return bifrostConfig?.auth_config?.is_enabled; }, [bifrostConfig]); @@ -32,11 +31,6 @@ curl --location 'http://localhost:8080/v1/chat/completions' ] }'`; - const maskToken = (token: string, revealed: boolean) => { - if (revealed) return token; - return token.substring(0, 8) + "•".repeat(Math.max(0, token.length - 8)); - }; - const copyToClipboard = (text: string) => { navigator.clipboard.writeText(text); toast.success("Copied to clipboard"); diff --git a/ui/app/clientLayout.tsx b/ui/app/clientLayout.tsx index 363e16d65d..3d070526c8 100644 --- a/ui/app/clientLayout.tsx +++ b/ui/app/clientLayout.tsx @@ -10,11 +10,18 @@ import { WebSocketProvider } from "@/hooks/useWebSocket"; import { getErrorMessage, ReduxProvider, useGetCoreConfigQuery } from "@/lib/store"; import { BifrostConfig } from "@/lib/types/config"; import { RbacProvider } from "@enterprise/lib/contexts/rbacContext"; +import dynamic from "next/dynamic"; import { usePathname } from "next/navigation"; import { NuqsAdapter } from "nuqs/adapters/next/app"; -import { Suspense, useEffect } from "react"; +import { useEffect } from "react"; import { toast, Toaster } from "sonner"; +// Dynamic import - only loaded in development, completely excluded from prod bundle +const DevProfiler = dynamic( + () => import("@/components/devProfiler").then(mod => ({ default: mod.DevProfiler })), + { ssr: false } +); + function AppContent({ children }: { children: React.ReactNode }) { const { data: bifrostConfig, error, isLoading } = useGetCoreConfigQuery({}); @@ -28,7 +35,7 @@ function AppContent({ children }: { children: React.ReactNode }) { -
+
{isLoading ? : {children}}
@@ -58,6 +65,7 @@ export function ClientLayout({ children }: { children: React.ReactNode }) { {children} + {process.env.NODE_ENV === 'development' && } diff --git a/ui/app/workspace/config/api-keys/page.tsx b/ui/app/workspace/config/api-keys/page.tsx index c2a0feeaac..3c8ce98af7 100644 --- a/ui/app/workspace/config/api-keys/page.tsx +++ b/ui/app/workspace/config/api-keys/page.tsx @@ -1,6 +1,6 @@ "use client" -import APIKeysView from "@enterprise/components/api-keys/APIKeysView" +import APIKeysView from "@enterprise/components/api-keys/apiKeysIndexView" export default function APIKeysPage() { return ( diff --git a/ui/app/workspace/config/logging/page.tsx b/ui/app/workspace/config/logging/page.tsx index 4ccaddb310..a285754aa6 100644 --- a/ui/app/workspace/config/logging/page.tsx +++ b/ui/app/workspace/config/logging/page.tsx @@ -1,12 +1,11 @@ -"use client" +"use client"; -import LoggingView from "../views/loggingView" +import LoggingView from "../views/loggingView"; export default function LoggingPage() { - return ( -
- -
- ) + return ( +
+ +
+ ); } - diff --git a/ui/app/workspace/config/mcp-gateway/page.tsx b/ui/app/workspace/config/mcp-gateway/page.tsx new file mode 100644 index 0000000000..47f6865a9a --- /dev/null +++ b/ui/app/workspace/config/mcp-gateway/page.tsx @@ -0,0 +1,11 @@ +"use client"; + +import MCPGatewayView from "../views/mcpView"; + +export default function MCPGatewayPage() { + return ( +
+ +
+ ); +} diff --git a/ui/app/workspace/config/views/clientSettingsView.tsx b/ui/app/workspace/config/views/clientSettingsView.tsx index da1bff11dd..40f054f541 100644 --- a/ui/app/workspace/config/views/clientSettingsView.tsx +++ b/ui/app/workspace/config/views/clientSettingsView.tsx @@ -26,6 +26,9 @@ const defaultConfig: CoreConfig = { max_request_body_size_mb: 100, enable_litellm_fallbacks: false, log_retention_days: 365, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, + mcp_code_mode_binding_level: "server", header_filter_config: DefaultGlobalHeaderFilterConfig, }; @@ -119,6 +122,10 @@ export default function ClientSettingsView() { } try { + if (!bifrostConfig) { + toast.error("Configuration not loaded. Please refresh and try again."); + return; + } // Clean up empty strings from header filter config const cleanedConfig = { ...localConfig, diff --git a/ui/app/workspace/config/views/governanceView.tsx b/ui/app/workspace/config/views/governanceView.tsx index 6121b66926..e95441986e 100644 --- a/ui/app/workspace/config/views/governanceView.tsx +++ b/ui/app/workspace/config/views/governanceView.tsx @@ -21,6 +21,9 @@ const defaultConfig: CoreConfig = { allowed_origins: [], max_request_body_size_mb: 100, enable_litellm_fallbacks: false, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, + mcp_code_mode_binding_level: "server", }; export default function GovernanceView() { @@ -61,7 +64,7 @@ export default function GovernanceView() { }, [bifrostConfig, localConfig, updateCoreConfig]); return ( -
+

Governance

@@ -101,5 +104,3 @@ export default function GovernanceView() { const RestartWarning = () => { return
Need to restart Bifrost to apply changes.
; }; - - diff --git a/ui/app/workspace/config/views/loggingView.tsx b/ui/app/workspace/config/views/loggingView.tsx index df5beebf7c..2f6f9ed2c2 100644 --- a/ui/app/workspace/config/views/loggingView.tsx +++ b/ui/app/workspace/config/views/loggingView.tsx @@ -23,6 +23,9 @@ const defaultConfig: CoreConfig = { allowed_origins: [], max_request_body_size_mb: 100, enable_litellm_fallbacks: false, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, + mcp_code_mode_binding_level: "server", }; export default function LoggingView() { @@ -126,7 +129,8 @@ export default function LoggingView() { Disable Content Logging

- When enabled, only usage metadata (latency, cost, token count, etc.) will be logged. Request/response content will not be stored. + When enabled, only usage metadata (latency, cost, token count, etc.) will be logged. Request/response content will not be + stored.

{ return
Need to restart Bifrost to apply changes.
; }; - - diff --git a/ui/app/workspace/config/views/mcpView.tsx b/ui/app/workspace/config/views/mcpView.tsx new file mode 100644 index 0000000000..9fa664c1a3 --- /dev/null +++ b/ui/app/workspace/config/views/mcpView.tsx @@ -0,0 +1,219 @@ +"use client"; + +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; +import { getErrorMessage, useGetCoreConfigQuery, useUpdateCoreConfigMutation } from "@/lib/store"; +import { CoreConfig } from "@/lib/types/config"; +import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; +import { useCallback, useEffect, useMemo, useState } from "react"; +import { toast } from "sonner"; + +const defaultConfig: CoreConfig = { + drop_excess_requests: false, + initial_pool_size: 1000, + prometheus_labels: [], + enable_logging: true, + enable_governance: true, + enforce_governance_header: false, + allow_direct_keys: false, + allowed_origins: [], + max_request_body_size_mb: 100, + enable_litellm_fallbacks: false, + disable_content_logging: false, + log_retention_days: 365, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, + mcp_code_mode_binding_level: "server", +}; + +export default function MCPView() { + const hasSettingsUpdateAccess = useRbac(RbacResource.Settings, RbacOperation.Update); + const { data: bifrostConfig } = useGetCoreConfigQuery({ fromDB: true }); + const config = bifrostConfig?.client_config; + const [updateCoreConfig, { isLoading }] = useUpdateCoreConfigMutation(); + const [localConfig, setLocalConfig] = useState(defaultConfig); + + const [localValues, setLocalValues] = useState<{ + mcp_agent_depth: string; + mcp_tool_execution_timeout: string; + mcp_code_mode_binding_level: string; + }>({ + mcp_agent_depth: "10", + mcp_tool_execution_timeout: "30", + mcp_code_mode_binding_level: "server", + }); + + useEffect(() => { + if (bifrostConfig && config) { + setLocalConfig(config); + setLocalValues({ + mcp_agent_depth: config?.mcp_agent_depth?.toString() || "10", + mcp_tool_execution_timeout: config?.mcp_tool_execution_timeout?.toString() || "30", + mcp_code_mode_binding_level: config?.mcp_code_mode_binding_level || "server", + }); + } + }, [config, bifrostConfig]); + + const hasChanges = useMemo(() => { + if (!config) return false; + return ( + localConfig.mcp_agent_depth !== config.mcp_agent_depth || + localConfig.mcp_tool_execution_timeout !== config.mcp_tool_execution_timeout || + localConfig.mcp_code_mode_binding_level !== (config.mcp_code_mode_binding_level || "server") + ); + }, [config, localConfig]); + + const handleAgentDepthChange = useCallback((value: string) => { + setLocalValues((prev) => ({ ...prev, mcp_agent_depth: value })); + const numValue = Number.parseInt(value); + if (!isNaN(numValue) && numValue > 0) { + setLocalConfig((prev) => ({ ...prev, mcp_agent_depth: numValue })); + } + }, []); + + const handleToolExecutionTimeoutChange = useCallback((value: string) => { + setLocalValues((prev) => ({ ...prev, mcp_tool_execution_timeout: value })); + const numValue = Number.parseInt(value); + if (!isNaN(numValue) && numValue > 0) { + setLocalConfig((prev) => ({ ...prev, mcp_tool_execution_timeout: numValue })); + } + }, []); + + const handleCodeModeBindingLevelChange = useCallback((value: string) => { + setLocalValues((prev) => ({ ...prev, mcp_code_mode_binding_level: value })); + if (value === "server" || value === "tool") { + setLocalConfig((prev) => ({ ...prev, mcp_code_mode_binding_level: value })); + } + }, []); + + const handleSave = useCallback(async () => { + try { + const agentDepth = Number.parseInt(localValues.mcp_agent_depth); + const toolTimeout = Number.parseInt(localValues.mcp_tool_execution_timeout); + + if (isNaN(agentDepth) || agentDepth <= 0) { + toast.error("Max agent depth must be a positive number."); + return; + } + + if (isNaN(toolTimeout) || toolTimeout <= 0) { + toast.error("Tool execution timeout must be a positive number."); + return; + } + + if (!bifrostConfig) { + toast.error("Configuration not loaded. Please refresh and try again."); + return; + } + await updateCoreConfig({ ...bifrostConfig, client_config: localConfig }).unwrap(); + toast.success("MCP settings updated successfully."); + } catch (error) { + toast.error(getErrorMessage(error)); + } + }, [bifrostConfig, localConfig, localValues, updateCoreConfig]); + + return ( +
+
+
+

MCP Settings

+

Configure MCP (Model Context Protocol) agent and tool settings.

+
+ +
+
+ {/* Max Agent Depth */} +
+
+ +

Maximum depth for MCP agent execution.

+
+ handleAgentDepthChange(e.target.value)} + min="1" + /> +
+ + {/* Tool Execution Timeout */} +
+
+ +

Maximum time in seconds for tool execution.

+
+ handleToolExecutionTimeoutChange(e.target.value)} + min="1" + /> +
+ + {/* Code Mode Binding Level */} +
+
+ +

+ How tools are exposed in the VFS: server-level (all tools per server) or tool-level (individual tools). +

+
+ + + {/* Visual Example */} +
+

VFS Structure:

+ + {localValues.mcp_code_mode_binding_level === "server" ? ( +
+
+
servers/
+
ā”œā”€ calculator.d.ts
+
ā”œā”€ youtube.d.ts
+
└─ weather.d.ts
+
+

All tools per server in a single .d.ts file

+
+ ) : ( +
+
+
servers/
+
ā”œā”€ calculator/
+
ā”œā”€ add.d.ts
+
└─ subtract.d.ts
+
ā”œā”€ youtube/
+
ā”œā”€ GET_CHANNELS.d.ts
+
└─ SEARCH_VIDEOS.d.ts
+
└─ weather/
+
└─ get_forecast.d.ts
+
+

Individual .d.ts file for each tool

+
+ )} +
+
+
+
+ ); +} diff --git a/ui/app/workspace/config/views/observabilityView.tsx b/ui/app/workspace/config/views/observabilityView.tsx index 9f680b933e..126da2a5d2 100644 --- a/ui/app/workspace/config/views/observabilityView.tsx +++ b/ui/app/workspace/config/views/observabilityView.tsx @@ -24,6 +24,9 @@ const defaultConfig: CoreConfig = { enable_litellm_fallbacks: false, disable_content_logging: false, log_retention_days: 365, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, + mcp_code_mode_binding_level: "server", }; export default function ObservabilityView() { diff --git a/ui/app/workspace/config/views/performanceTuningView.tsx b/ui/app/workspace/config/views/performanceTuningView.tsx index 2778c6e077..e6b6055537 100644 --- a/ui/app/workspace/config/views/performanceTuningView.tsx +++ b/ui/app/workspace/config/views/performanceTuningView.tsx @@ -23,6 +23,9 @@ const defaultConfig: CoreConfig = { enable_litellm_fallbacks: false, disable_content_logging: false, log_retention_days: 365, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, + mcp_code_mode_binding_level: "server", }; export default function PerformanceTuningView() { @@ -91,7 +94,11 @@ export default function PerformanceTuningView() { return; } - await updateCoreConfig({ ...bifrostConfig!, client_config: localConfig }).unwrap(); + if (!bifrostConfig) { + toast.error("Configuration not loaded. Please refresh and try again."); + return; + } + await updateCoreConfig({ ...bifrostConfig, client_config: localConfig }).unwrap(); toast.success("Performance settings updated successfully."); } catch (error) { toast.error(getErrorMessage(error)); @@ -99,7 +106,7 @@ export default function PerformanceTuningView() { }, [bifrostConfig, localConfig, localValues, updateCoreConfig]); return ( -
+

Performance Tuning

diff --git a/ui/app/workspace/config/views/pluginsForm.tsx b/ui/app/workspace/config/views/pluginsForm.tsx index 1e7d48abd5..2f65444c6f 100644 --- a/ui/app/workspace/config/views/pluginsForm.tsx +++ b/ui/app/workspace/config/views/pluginsForm.tsx @@ -238,8 +238,18 @@ export default function PluginsForm({ isVectorStoreEnabled }: PluginsFormProps) id="ttl" type="number" min="1" - value={cacheConfig.ttl_seconds} - onChange={(e) => updateCacheConfigLocal({ ttl_seconds: parseInt(e.target.value) || 300 })} + value={cacheConfig.ttl_seconds === undefined || Number.isNaN(cacheConfig.ttl_seconds) ? '' : cacheConfig.ttl_seconds} + onChange={(e) => { + const value = e.target.value + if (value === '') { + updateCacheConfigLocal({ ttl_seconds: undefined }) + return + } + const parsed = parseInt(value) + if (!Number.isNaN(parsed)) { + updateCacheConfigLocal({ ttl_seconds: parsed }) + } + }} />
@@ -250,8 +260,18 @@ export default function PluginsForm({ isVectorStoreEnabled }: PluginsFormProps) min="0" max="1" step="0.01" - value={cacheConfig.threshold} - onChange={(e) => updateCacheConfigLocal({ threshold: parseFloat(e.target.value) || 0.8 })} + value={cacheConfig.threshold === undefined || Number.isNaN(cacheConfig.threshold) ? '' : cacheConfig.threshold} + onChange={(e) => { + const value = e.target.value + if (value === '') { + updateCacheConfigLocal({ threshold: undefined }) + return + } + const parsed = parseFloat(value) + if (!Number.isNaN(parsed)) { + updateCacheConfigLocal({ threshold: parsed }) + } + }} />
@@ -260,8 +280,18 @@ export default function PluginsForm({ isVectorStoreEnabled }: PluginsFormProps) id="dimension" type="number" min="0" - value={cacheConfig.dimension} - onChange={(e) => updateCacheConfigLocal({ dimension: parseInt(e.target.value) || 0 })} + value={cacheConfig.dimension === undefined || Number.isNaN(cacheConfig.dimension) ? '' : cacheConfig.dimension} + onChange={(e) => { + const value = e.target.value + if (value === '') { + updateCacheConfigLocal({ dimension: undefined }) + return + } + const parsed = parseInt(value) + if (!Number.isNaN(parsed)) { + updateCacheConfigLocal({ dimension: parsed }) + } + }} />
diff --git a/ui/app/workspace/config/views/securityView.tsx b/ui/app/workspace/config/views/securityView.tsx index 5200af1362..f261d50965 100644 --- a/ui/app/workspace/config/views/securityView.tsx +++ b/ui/app/workspace/config/views/securityView.tsx @@ -31,6 +31,9 @@ const defaultConfig: CoreConfig = { max_request_body_size_mb: 100, enable_litellm_fallbacks: false, log_retention_days: 365, + mcp_agent_depth: 10, + mcp_tool_execution_timeout: 30, + mcp_code_mode_binding_level: "server", }; export default function SecurityView() { diff --git a/ui/app/workspace/logs/page.tsx b/ui/app/workspace/logs/page.tsx index 8468fcf801..c538cb82c7 100644 --- a/ui/app/workspace/logs/page.tsx +++ b/ui/app/workspace/logs/page.tsx @@ -72,7 +72,7 @@ export default function LogsPage() { content_search: parseAsString.withDefault(""), start_time: parseAsInteger.withDefault(DEFAULT_START_TIME), end_time: parseAsInteger.withDefault(DEFAULT_END_TIME), - limit: parseAsInteger.withDefault(50), + limit: parseAsInteger.withDefault(25), // Default fallback, actual value calculated based on table height offset: parseAsInteger.withDefault(0), sort_by: parseAsString.withDefault("timestamp"), order: parseAsString.withDefault("desc"), diff --git a/ui/app/workspace/logs/views/columns.tsx b/ui/app/workspace/logs/views/columns.tsx index 3fd03281e2..3c34bbc581 100644 --- a/ui/app/workspace/logs/views/columns.tsx +++ b/ui/app/workspace/logs/views/columns.tsx @@ -26,7 +26,8 @@ function getMessage(log?: LogEntry) { } return lastTextContentBlock; } else if (log?.responses_input_history && log.responses_input_history.length > 0) { - let lastMessageContent = log.responses_input_history[log.responses_input_history.length - 1].content; + let lastMessage = log.responses_input_history[log.responses_input_history.length - 1]; + let lastMessageContent = lastMessage.content; if (typeof lastMessageContent === "string") { return lastMessageContent; } @@ -36,7 +37,18 @@ function getMessage(log?: LogEntry) { lastTextContentBlock = block.text; } } - return lastTextContentBlock; + // If no content found in content field, check output field for Responses API + if (!lastTextContentBlock && lastMessage.output) { + // Handle output field - it could be a string, an array of content blocks, or a computer tool call output data + if (typeof lastMessage.output === "string") { + return lastMessage.output; + } else if (Array.isArray(lastMessage.output)) { + return lastMessage.output.map((block) => block.text).join("\n"); + } else if (lastMessage.output.type && lastMessage.output.type === "computer_screenshot") { + return lastMessage.output.image_url; + } + } + return lastTextContentBlock ?? ""; } else if (log?.speech_input) { return log.speech_input.input; } else if (log?.transcription_input) { diff --git a/ui/app/workspace/logs/views/logChatMessageView.tsx b/ui/app/workspace/logs/views/logChatMessageView.tsx index 5b0c821cf5..f99484ddf8 100644 --- a/ui/app/workspace/logs/views/logChatMessageView.tsx +++ b/ui/app/workspace/logs/views/logChatMessageView.tsx @@ -111,14 +111,14 @@ export default function LogChatMessageView({ message, audioFormat }: LogChatMess options={{ scrollBeyondLastLine: false, collapsibleBlocks: true, lineNumbers: "off", alwaysConsumeMouseWheel: false }} /> ) : ( -
{message.refusal}
+
{message.refusal}
)}
)} {/* Handle content */} {message.content && ( -
+
{typeof message.content === "string" ? ( <> {isJson(message.content) ? ( @@ -133,7 +133,7 @@ export default function LogChatMessageView({ message, audioFormat }: LogChatMess options={{ scrollBeyondLastLine: false, collapsibleBlocks: true, lineNumbers: "off", alwaysConsumeMouseWheel: false }} /> ) : ( -
{message.content}
+
{message.content}
)} ) : ( diff --git a/ui/app/workspace/logs/views/logDetailsSheet.tsx b/ui/app/workspace/logs/views/logDetailsSheet.tsx index 57f673e5f5..fb05bcc0f7 100644 --- a/ui/app/workspace/logs/views/logDetailsSheet.tsx +++ b/ui/app/workspace/logs/views/logDetailsSheet.tsx @@ -1,14 +1,33 @@ "use client"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger, +} from "@/components/ui/alertDialog"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from "@/components/ui/dropdownMenu"; import { DottedSeparator } from "@/components/ui/separator"; import { Sheet, SheetContent, SheetHeader, SheetTitle } from "@/components/ui/sheet"; import { ProviderIconType, RenderProviderIcon } from "@/lib/constants/icons"; import { RequestTypeColors, RequestTypeLabels, Status, StatusColors } from "@/lib/constants/logs"; import { LogEntry } from "@/lib/types/logs"; -import { DollarSign, FileText, Timer, Trash2 } from "lucide-react"; +import { Clipboard, DollarSign, FileText, MoreVertical, Timer, Trash2 } from "lucide-react"; import moment from "moment"; +import { toast } from "sonner"; import { CodeEditor } from "./codeEditor"; import LogChatMessageView from "./logChatMessageView"; import LogEntryDetailsView from "./logEntryDetailsView"; @@ -34,13 +53,131 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet } catch (ignored) {} } + const copyRequestBody = async () => { + try { + // Check if request is for responses, chat, speech, text completion, or embedding (exclude transcriptions) + const object = log.object?.toLowerCase() || ""; + const isChat = object === "chat_completion" || object === "chat_completion_stream"; + const isResponses = object === "responses" || object === "responses_stream"; + const isSpeech = object === "speech" || object === "speech_stream"; + const isTextCompletion = object === "text_completion" || object === "text_completion_stream"; + const isEmbedding = object === "embedding"; + const isTranscription = object === "transcription" || object === "transcription_stream"; + + // Skip if transcription + if (isTranscription) { + toast.error("Copy request body is not available for transcription requests"); + return; + } + + // Skip if not a supported request type + if (!isChat && !isResponses && !isSpeech && !isTextCompletion && !isEmbedding) { + toast.error("Copy request body is only available for chat, responses, speech, text completion, and embedding requests"); + return; + } + + // Helper function to extract text content from ChatMessage + const extractTextFromMessage = (message: any): string => { + if (!message || !message.content) { + return ""; + } + if (typeof message.content === "string") { + return message.content; + } + if (Array.isArray(message.content)) { + return message.content + .filter((block: any) => block && block.type === "text" && block.text) + .map((block: any) => block.text || "") + .join(""); + } + return ""; + }; + + // Helper function to extract texts from ChatMessage content blocks (for embeddings) + const extractTextsFromMessage = (message: any): string[] => { + if (!message || !message.content) { + return []; + } + if (typeof message.content === "string") { + return message.content ? [message.content] : []; + } + if (Array.isArray(message.content)) { + return message.content.filter((block: any) => block && block.type === "text" && block.text).map((block: any) => block.text); + } + return []; + }; + + // Build request body following OpenAI schema + const requestBody: any = { + model: log.provider && log.model ? `${log.provider}/${log.model}` : log.model || "", + }; + + // Add messages/input/prompt based on request type + if (isChat && log.input_history && log.input_history.length > 0) { + requestBody.messages = log.input_history; + } else if (isResponses && log.responses_input_history && log.responses_input_history.length > 0) { + requestBody.input = log.responses_input_history; + } else if (isSpeech && log.speech_input) { + requestBody.input = log.speech_input.input; + } else if (isTextCompletion && log.input_history && log.input_history.length > 0) { + // For text completions, extract prompt from input_history + const firstMessage = log.input_history[0]; + const prompt = extractTextFromMessage(firstMessage); + if (prompt) { + requestBody.prompt = prompt; + } + } else if (isEmbedding && log.input_history && log.input_history.length > 0) { + // For embeddings, extract all texts from input_history + const texts: string[] = []; + for (const message of log.input_history) { + const messageTexts = extractTextsFromMessage(message); + texts.push(...messageTexts); + } + if (texts.length > 0) { + // Use single string if only one text, otherwise use array + requestBody.input = texts.length === 1 ? texts[0] : texts; + } + } + + // Add params (excluding tools and instructions as they're handled separately in OpenAI schema) + if (log.params) { + const paramsCopy = { ...log.params }; + // Remove tools and instructions from params as they're typically top-level in OpenAI schema + // Keep all other params (temperature, max_tokens, voice, etc.) + delete paramsCopy.tools; + delete paramsCopy.instructions; + + // Merge remaining params into request body + Object.assign(requestBody, paramsCopy); + } + + // Add tools if they exist (for chat and responses) - OpenAI schema has tools at top level + if ((isChat || isResponses) && log.params?.tools && Array.isArray(log.params.tools) && log.params.tools.length > 0) { + requestBody.tools = log.params.tools; + } + + // Add instructions if they exist (for responses) - OpenAI schema has instructions at top level + if (isResponses && log.params?.instructions) { + requestBody.instructions = log.params.instructions; + } + + const requestBodyJson = JSON.stringify(requestBody, null, 2); + navigator.clipboard.writeText(requestBodyJson).then(() => { + toast.success("Request body copied to clipboard"); + }).catch((error) => { + toast.error("Failed to copy request body"); + }); + } catch (error) { + toast.error("Failed to copy request body"); + } + }; // Extract audio format from request params // Format can be in params.audio?.format or params.extra_params?.audio?.format const audioFormat = (log.params as any)?.audio?.format || (log.params as any)?.extra_params?.audio?.format || undefined; return ( - +
@@ -50,16 +187,45 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet
- + + + + + + + + + Copy request body + + + + + + Delete log + + + + + + + Are you sure you want to delete this log? + This action cannot be undone. This will permanently delete the log entry. + + + Cancel + { + handleDelete(log); + onOpenChange(false); + }} + > + Delete + + + +
@@ -338,7 +504,7 @@ export function LogDetailSheet({ log, open, onOpenChange, handleDelete }: LogDet
{toolsParameter && (
-
Tools
+
Tools ({log.params?.tools?.length || 0})
([{ id: pagination.sort_by, desc: pagination.order === "desc" }]); + const tableContainerRef = useRef(null); + const calculatedPageSize = useTablePageSize(tableContainerRef); + + // Refs to avoid stale closures in the page size effect + const paginationRef = useRef(pagination); + const onPaginationChangeRef = useRef(onPaginationChange); + paginationRef.current = pagination; + onPaginationChangeRef.current = onPaginationChange; + + // Update pagination limit when calculated page size increases (don't reduce on size reduction) + useEffect(() => { + if (calculatedPageSize && calculatedPageSize > paginationRef.current.limit) { + onPaginationChangeRef.current({ + ...paginationRef.current, + limit: calculatedPageSize, + offset: 0, // Reset to first page when page size changes + }); + } + }, [calculatedPageSize]); const handleSortingChange = (updaterOrValue: SortingState | ((old: SortingState) => SortingState)) => { const newSorting = typeof updaterOrValue === "function" ? updaterOrValue(sorting) : updaterOrValue; @@ -86,8 +106,8 @@ export function LogsDataTable({ return (
-
- +
+
{table.getHeaderGroups().map((headerGroup) => ( diff --git a/ui/app/workspace/mcp-clients/layout.tsx b/ui/app/workspace/mcp-gateway/layout.tsx similarity index 100% rename from ui/app/workspace/mcp-clients/layout.tsx rename to ui/app/workspace/mcp-gateway/layout.tsx diff --git a/ui/app/workspace/mcp-clients/page.tsx b/ui/app/workspace/mcp-gateway/page.tsx similarity index 91% rename from ui/app/workspace/mcp-clients/page.tsx rename to ui/app/workspace/mcp-gateway/page.tsx index f046b88df6..98083d1b5e 100644 --- a/ui/app/workspace/mcp-clients/page.tsx +++ b/ui/app/workspace/mcp-gateway/page.tsx @@ -1,6 +1,6 @@ "use client"; -import MCPClientsTable from "@/app/workspace/mcp-clients/views/mcpClientsTable"; +import MCPClientsTable from "@/app/workspace/mcp-gateway/views/mcpClientsTable"; import FullPageLoader from "@/components/fullPageLoader"; import { useToast } from "@/hooks/use-toast"; import { getErrorMessage, useGetMCPClientsQuery } from "@/lib/store"; diff --git a/ui/app/workspace/mcp-clients/views/mcpClientForm.tsx b/ui/app/workspace/mcp-gateway/views/mcpClientForm.tsx similarity index 85% rename from ui/app/workspace/mcp-clients/views/mcpClientForm.tsx rename to ui/app/workspace/mcp-gateway/views/mcpClientForm.tsx index 7b0aee2ad3..2014bc2174 100644 --- a/ui/app/workspace/mcp-clients/views/mcpClientForm.tsx +++ b/ui/app/workspace/mcp-gateway/views/mcpClientForm.tsx @@ -6,6 +6,7 @@ import { HeadersTable } from "@/components/ui/headersTable"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; +import { Switch } from "@/components/ui/switch"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; import { useToast } from "@/hooks/use-toast"; import { getErrorMessage, useCreateMCPClientMutation } from "@/lib/store"; @@ -30,6 +31,7 @@ const emptyStdioConfig: MCPStdioConfig = { const emptyForm: CreateMCPClientRequest = { name: "", + is_code_mode_client: false, connection_type: "http", connection_string: "", stdio_config: emptyStdioConfig, @@ -57,7 +59,10 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { } }, [open]); - const handleChange = (field: keyof CreateMCPClientRequest, value: string | string[] | MCPConnectionType | MCPStdioConfig | undefined) => { + const handleChange = ( + field: keyof CreateMCPClientRequest, + value: string | string[] | boolean | MCPConnectionType | MCPStdioConfig | undefined, + ) => { setForm((prev) => ({ ...prev, [field]: value })); }; @@ -95,10 +100,13 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { const validator = new Validator([ // Name validation - Validator.required(form.name?.trim(), "Client name is required"), - Validator.pattern(form.name || "", /^[a-zA-Z0-9-_]+$/, "Client name can only contain letters, numbers, hyphens and underscores"), - Validator.minLength(form.name || "", 3, "Client name must be at least 3 characters"), - Validator.maxLength(form.name || "", 50, "Client name cannot exceed 50 characters"), + Validator.required(form.name?.trim(), "Server name is required"), + Validator.pattern(form.name || "", /^[a-zA-Z0-9_]+$/, "Server name can only contain letters, numbers, and underscores"), + Validator.custom(!(form.name || "").includes("-"), "Server name cannot contain hyphens"), + Validator.custom(!(form.name || "").includes(" "), "Server name cannot contain spaces"), + Validator.custom((form.name || "").length === 0 || !/^[0-9]/.test(form.name || ""), "Server name cannot start with a number"), + Validator.minLength(form.name || "", 3, "Server name must be at least 3 characters"), + Validator.maxLength(form.name || "", 50, "Server name cannot exceed 50 characters"), // Connection type specific validation ...(form.connection_type === "http" || form.connection_type === "sse" @@ -156,7 +164,7 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { setIsLoading(false); toast({ title: "Success", - description: "Client created", + description: "Server created", }); onSaved(); onClose(); @@ -170,7 +178,7 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { - New MCP Client + New MCP Server
@@ -178,7 +186,7 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { ) => handleChange("name", e.target.value)} - placeholder="Client name" + placeholder="Server name" maxLength={50} />
@@ -197,6 +205,15 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => {
+
+ + handleChange("is_code_mode_client", checked)} + /> +
+ {(form.connection_type === "http" || form.connection_type === "sse") && ( <>
@@ -275,7 +292,11 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { - diff --git a/ui/app/workspace/mcp-clients/views/mcpClientSheet.tsx b/ui/app/workspace/mcp-gateway/views/mcpClientSheet.tsx similarity index 62% rename from ui/app/workspace/mcp-clients/views/mcpClientSheet.tsx rename to ui/app/workspace/mcp-gateway/views/mcpClientSheet.tsx index 618e45e88a..52dcadb9ca 100644 --- a/ui/app/workspace/mcp-clients/views/mcpClientSheet.tsx +++ b/ui/app/workspace/mcp-gateway/views/mcpClientSheet.tsx @@ -8,6 +8,7 @@ import { HeadersTable } from "@/components/ui/headersTable"; import { Input } from "@/components/ui/input"; import { Sheet, SheetContent, SheetDescription, SheetHeader, SheetTitle } from "@/components/ui/sheet"; import { Switch } from "@/components/ui/switch"; +import { TriStateCheckbox } from "@/components/ui/tristateCheckbox"; import { useToast } from "@/hooks/use-toast"; import { MCP_STATUS_COLORS } from "@/lib/constants/config"; import { getErrorMessage, useUpdateMCPClientMutation } from "@/lib/store"; @@ -34,8 +35,10 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: mode: "onBlur", defaultValues: { name: mcpClient.config.name, + is_code_mode_client: mcpClient.config.is_code_mode_client || false, headers: mcpClient.config.headers, tools_to_execute: mcpClient.config.tools_to_execute || [], + tools_to_auto_execute: mcpClient.config.tools_to_auto_execute || [], }, }); @@ -43,8 +46,10 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: useEffect(() => { form.reset({ name: mcpClient.config.name, + is_code_mode_client: mcpClient.config.is_code_mode_client || false, headers: mcpClient.config.headers, tools_to_execute: mcpClient.config.tools_to_execute || [], + tools_to_auto_execute: mcpClient.config.tools_to_auto_execute || [], }); }, [form, mcpClient]); @@ -54,8 +59,10 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: id: mcpClient.config.id, data: { name: data.name, + is_code_mode_client: data.is_code_mode_client, headers: data.headers, tools_to_execute: data.tools_to_execute, + tools_to_auto_execute: data.tools_to_auto_execute, }, }).unwrap(); @@ -106,6 +113,69 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: } form.setValue("tools_to_execute", newTools, { shouldDirty: true }); + + // If tool is being removed from tools_to_execute, also remove it from tools_to_auto_execute + if (!checked) { + const currentAutoExecute = form.getValues("tools_to_auto_execute") || []; + if (currentAutoExecute.includes(toolName) || currentAutoExecute.includes("*")) { + const newAutoExecute = currentAutoExecute.filter((tool) => tool !== toolName); + // If we had "*" and removed a tool, we need to recalculate + if (currentAutoExecute.includes("*")) { + // If all tools mode, keep "*" only if tool is still in tools_to_execute + if (newTools.includes("*")) { + form.setValue("tools_to_auto_execute", ["*"], { shouldDirty: true }); + } else { + // Switch to explicit list - when in wildcard mode, all remaining tools should be auto-execute + form.setValue("tools_to_auto_execute", newTools, { shouldDirty: true }); + } + } else { + form.setValue("tools_to_auto_execute", newAutoExecute, { shouldDirty: true }); + } + } + } + }; + + const handleAutoExecuteToggle = (toolName: string, checked: boolean) => { + const currentAutoExecute = form.getValues("tools_to_auto_execute") || []; + const currentTools = form.getValues("tools_to_execute") || []; + const allToolNames = mcpClient.tools?.map((tool) => tool.name) || []; + + // Check if we're in "all tools" mode (wildcard) + const isAllToolsMode = currentTools.includes("*"); + const isAllAutoExecuteMode = currentAutoExecute.includes("*"); + + let newAutoExecute: string[]; + + if (isAllAutoExecuteMode) { + if (checked) { + // Already all selected, keep wildcard + newAutoExecute = ["*"]; + } else { + // Unchecking a tool when all are selected - switch to explicit list without this tool + if (isAllToolsMode) { + newAutoExecute = allToolNames.filter((name) => name !== toolName); + } else { + newAutoExecute = currentTools.filter((name) => name !== toolName); + } + } + } else { + // We're in explicit tool selection mode + if (checked) { + // Add tool to selection + newAutoExecute = currentAutoExecute.includes(toolName) ? currentAutoExecute : [...currentAutoExecute, toolName]; + + // If we now have all allowed tools selected, switch to wildcard mode + const allowedTools = isAllToolsMode ? allToolNames : currentTools; + if (newAutoExecute.length === allowedTools.length && allowedTools.every((tool) => newAutoExecute.includes(tool))) { + newAutoExecute = ["*"]; + } + } else { + // Remove tool from selection + newAutoExecute = currentAutoExecute.filter((tool) => tool !== toolName); + } + } + + form.setValue("tools_to_auto_execute", newAutoExecute, { shouldDirty: true }); }; return ( @@ -113,14 +183,14 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }:
- -
+ +
{mcpClient.config.name} {mcpClient.state} - MCP client configuration and available tools + MCP server configuration and available tools
- Manage clients that can connect to the MCP Tools endpoint. + Manage servers that can connect to the MCP Tools endpoint.
@@ -144,7 +144,10 @@ export default function MCPClientsTable({ mcpClients }: MCPClientsTableProps) { Name Connection Type + Code Mode Connection Info + Enabled Tools + Auto-execute Tools State @@ -152,56 +155,101 @@ export default function MCPClientsTable({ mcpClients }: MCPClientsTableProps) { {clients.length === 0 && ( - + No clients found. )} - {clients.map((c: MCPClient) => ( - handleRowClick(c)}> - {c.config.name} - {getConnectionTypeDisplay(c.config.connection_type)} - {getConnectionDisplay(c)} - - {c.state} - - e.stopPropagation()}> - - - - - - - - - Remove MCP Client - - Are you sure you want to remove MCP client {c.config.name}? You will need to reconnect the client to continue - using it. - - - - Cancel - handleDelete(c)}>Delete - - - - - - ))} + + + {c.state == "connected" ? ( + <> + {autoExecuteToolsCount}/{c.tools?.length} + + ) : ( + "-" + )} + + + {c.state} + + e.stopPropagation()}> + + + + + + + + + Remove MCP Server + + Are you sure you want to remove MCP server {c.config.name}? You will need to reconnect the server to continue + using it. + + + + Cancel + handleDelete(c)}>Delete + + + + + + ); + })}
diff --git a/ui/app/workspace/observability/fragments/maximFormFragment.tsx b/ui/app/workspace/observability/fragments/maximFormFragment.tsx index 719fbc0e40..845fa4936a 100644 --- a/ui/app/workspace/observability/fragments/maximFormFragment.tsx +++ b/ui/app/workspace/observability/fragments/maximFormFragment.tsx @@ -6,6 +6,7 @@ import { Input } from "@/components/ui/input"; import { Switch } from "@/components/ui/switch"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; import { maximFormSchema, type MaximFormSchema } from "@/lib/types/schemas"; +import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; import { zodResolver } from "@hookform/resolvers/zod"; import { Eye, EyeOff } from "lucide-react"; import { useEffect, useState } from "react"; @@ -22,6 +23,7 @@ interface MaximFormFragmentProps { } export function MaximFormFragment({ initialConfig, onSave, isLoading = false }: MaximFormFragmentProps) { + const hasMaximAccess = useRbac(RbacResource.Observability, RbacOperation.Update); const [showApiKey, setShowApiKey] = useState(false); const [isSaving, setIsSaving] = useState(false); @@ -67,13 +69,14 @@ export function MaximFormFragment({ initialConfig, onSave, isLoading = false }: API Key
- + @@ -91,7 +94,7 @@ export function MaximFormFragment({ initialConfig, onSave, isLoading = false }: Log Repository ID (Optional) - + @@ -108,7 +111,7 @@ export function MaximFormFragment({ initialConfig, onSave, isLoading = false }: render={({ field }) => ( Enabled - + )} /> @@ -125,14 +128,14 @@ export function MaximFormFragment({ initialConfig, onSave, isLoading = false }: }, }); }} - disabled={isLoading || !form.formState.isDirty} + disabled={!hasMaximAccess || isLoading || !form.formState.isDirty} > Reset - diff --git a/ui/app/workspace/observability/fragments/otelFormFragment.tsx b/ui/app/workspace/observability/fragments/otelFormFragment.tsx index efe16dc957..074e0d772d 100644 --- a/ui/app/workspace/observability/fragments/otelFormFragment.tsx +++ b/ui/app/workspace/observability/fragments/otelFormFragment.tsx @@ -8,6 +8,7 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@ import { Switch } from "@/components/ui/switch"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; import { otelFormSchema, type OtelFormSchema } from "@/lib/types/schemas"; +import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; import { zodResolver } from "@hookform/resolvers/zod"; import { useEffect, useState } from "react"; import { useForm, type Resolver } from "react-hook-form"; @@ -26,6 +27,7 @@ interface OtelFormFragmentProps { } export function OtelFormFragment({ currentConfig: initialConfig, onSave, isLoading = false }: OtelFormFragmentProps) { + const hasOtelAccess = useRbac(RbacResource.Observability, RbacOperation.Update); const [isSaving, setIsSaving] = useState(false); const form = useForm({ resolver: zodResolver(otelFormSchema) as Resolver, @@ -92,7 +94,7 @@ export function OtelFormFragment({ currentConfig: initialConfig, onSave, isLoadi Service Name If kept empty, the service name will be set to "bifrost" - + @@ -114,6 +116,7 @@ export function OtelFormFragment({ currentConfig: initialConfig, onSave, isLoadi ? "https://otel-collector.example.com:4318/v1/traces" : "otel-collector.example.com:4317" } + disabled={!hasOtelAccess} {...field} /> @@ -127,7 +130,7 @@ export function OtelFormFragment({ currentConfig: initialConfig, onSave, isLoadi render={({ field }) => ( - + @@ -140,7 +143,7 @@ export function OtelFormFragment({ currentConfig: initialConfig, onSave, isLoadi render={({ field }) => ( Format - @@ -170,7 +173,7 @@ export function OtelFormFragment({ currentConfig: initialConfig, onSave, isLoadi render={({ field }) => ( Protocol - @@ -205,7 +208,7 @@ export function OtelFormFragment({ currentConfig: initialConfig, onSave, isLoadi render={({ field }) => ( Enabled - + )} /> @@ -219,14 +222,14 @@ export function OtelFormFragment({ currentConfig: initialConfig, onSave, isLoadi otel_config: undefined, }); }} - disabled={isLoading || !form.formState.isDirty} + disabled={!hasOtelAccess || isLoading || !form.formState.isDirty} > Reset - diff --git a/ui/app/workspace/providers/fragments/allowedRequestsFields.tsx b/ui/app/workspace/providers/fragments/allowedRequestsFields.tsx index 6fe617e7a9..44c9c8d6b5 100644 --- a/ui/app/workspace/providers/fragments/allowedRequestsFields.tsx +++ b/ui/app/workspace/providers/fragments/allowedRequestsFields.tsx @@ -2,19 +2,20 @@ import { FormControl, FormField, FormItem, FormLabel } from "@/components/ui/form"; import { Input } from "@/components/ui/input"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Switch } from "@/components/ui/switch"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; -import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { BaseProvider, RequestType } from "@/lib/types/config"; import { isRequestTypeDisabled } from "@/lib/utils/validation"; -import { Control, useFormContext } from "react-hook-form"; import { Settings2 } from "lucide-react"; import { useEffect, useMemo } from "react"; +import { Control, useFormContext } from "react-hook-form"; interface AllowedRequestsFieldsProps { control: Control; namePrefix?: string; providerType?: BaseProvider; + disabled?: boolean; } // Provider-specific endpoint paths @@ -73,7 +74,7 @@ const REQUEST_TYPES: Array<{ key: RequestType; label: string }> = [ { key: "count_tokens", label: "Count Tokens" }, ]; -export function AllowedRequestsFields({ control, namePrefix = "allowed_requests", providerType }: AllowedRequestsFieldsProps) { +export function AllowedRequestsFields({ control, namePrefix = "allowed_requests", providerType, disabled = false }: AllowedRequestsFieldsProps) { const leftColumn = REQUEST_TYPES.slice(0, 6); const rightColumn = REQUEST_TYPES.slice(6); const { getValues, setValue } = useFormContext(); @@ -106,7 +107,7 @@ export function AllowedRequestsFields({ control, namePrefix = "allowed_requests"
{/* Settings icon for path override - only show when enabled */} - {allowedField.value && !isDisabled && !isPathOverrideDisabled && ( + {allowedField.value && !isDisabled && !isPathOverrideDisabled && !disabled && ( ) : ( - + )}
diff --git a/ui/app/workspace/providers/fragments/apiStructureFormFragment.tsx b/ui/app/workspace/providers/fragments/apiStructureFormFragment.tsx index 45aee698e9..db6e3cae37 100644 --- a/ui/app/workspace/providers/fragments/apiStructureFormFragment.tsx +++ b/ui/app/workspace/providers/fragments/apiStructureFormFragment.tsx @@ -133,7 +133,7 @@ export function ApiStructureFormFragment({ provider }: Props) {

Whether the custom provider requires a key

- +
)} @@ -142,7 +142,7 @@ export function ApiStructureFormFragment({ provider }: Props) {
{/* Allowed Requests Configuration */} - + {/* Form Actions */}
diff --git a/ui/app/workspace/providers/fragments/networkFormFragment.tsx b/ui/app/workspace/providers/fragments/networkFormFragment.tsx index e8515a7bac..1f9dde58fe 100644 --- a/ui/app/workspace/providers/fragments/networkFormFragment.tsx +++ b/ui/app/workspace/providers/fragments/networkFormFragment.tsx @@ -147,6 +147,7 @@ export function NetworkFormFragment({ provider }: NetworkFormFragmentProps) { placeholder={isCustomProvider ? "https://api.your-provider.com" : "https://api.example.com"} {...field} value={field.value || ""} + disabled={!hasUpdateProviderAccess} /> @@ -164,14 +165,18 @@ export function NetworkFormFragment({ provider }: NetworkFormFragmentProps) { { - if (isNaN(Number(e.target.value))) { - if (e.target.value.trim() === "") { - field.onChange(0); - } - return; + const value = e.target.value + if (value === '') { + field.onChange(undefined) + return + } + const parsed = Number(value) + if (!Number.isNaN(parsed)) { + field.onChange(parsed) } - field.onChange(Number(e.target.value)); }} /> @@ -187,7 +192,23 @@ export function NetworkFormFragment({ provider }: NetworkFormFragmentProps) { Max Retries - field.onChange(Number(e.target.value))} /> + { + const value = e.target.value + if (value === '') { + field.onChange(undefined) + return + } + const parsed = Number(value) + if (!Number.isNaN(parsed)) { + field.onChange(parsed) + } + }} + /> @@ -202,7 +223,23 @@ export function NetworkFormFragment({ provider }: NetworkFormFragmentProps) { Initial Backoff (ms) - field.onChange(Number(e.target.value))} /> + { + const value = e.target.value + if (value === '') { + field.onChange(undefined) + return + } + const parsed = Number(value) + if (!Number.isNaN(parsed)) { + field.onChange(parsed) + } + }} + /> @@ -215,7 +252,23 @@ export function NetworkFormFragment({ provider }: NetworkFormFragmentProps) { Max Backoff (ms) - field.onChange(Number(e.target.value))} /> + { + const value = e.target.value + if (value === '') { + field.onChange(undefined) + return + } + const parsed = Number(value) + if (!Number.isNaN(parsed)) { + field.onChange(parsed) + } + }} + /> @@ -234,6 +287,7 @@ export function NetworkFormFragment({ provider }: NetworkFormFragmentProps) { keyPlaceholder="Header name" valuePlaceholder="Header value" label="Extra Headers" + disabled={!hasUpdateProviderAccess} /> diff --git a/ui/app/workspace/providers/fragments/performanceFormFragment.tsx b/ui/app/workspace/providers/fragments/performanceFormFragment.tsx index 639e7c9265..29b2625000 100644 --- a/ui/app/workspace/providers/fragments/performanceFormFragment.tsx +++ b/ui/app/workspace/providers/fragments/performanceFormFragment.tsx @@ -95,7 +95,19 @@ export function PerformanceFormFragment({ provider }: PerformanceFormFragmentPro type="number" placeholder="10" {...field} - onChange={(e) => field.onChange(Number.parseInt(e.target.value) || 0)} + value={field.value === undefined || Number.isNaN(field.value) ? '' : field.value} + disabled={!hasUpdateProviderAccess} + onChange={(e) => { + const value = e.target.value + if (value === '') { + field.onChange(undefined) + return + } + const parsed = Number.parseInt(value) + if (!Number.isNaN(parsed)) { + field.onChange(parsed) + } + }} /> @@ -115,7 +127,19 @@ export function PerformanceFormFragment({ provider }: PerformanceFormFragmentPro type="number" placeholder="10" {...field} - onChange={(e) => field.onChange(Number.parseInt(e.target.value) || 0)} + value={field.value === undefined || Number.isNaN(field.value) ? '' : field.value} + disabled={!hasUpdateProviderAccess} + onChange={(e) => { + const value = e.target.value + if (value === '') { + field.onChange(undefined) + return + } + const parsed = Number.parseInt(value) + if (!Number.isNaN(parsed)) { + field.onChange(parsed) + } + }} /> @@ -142,6 +166,7 @@ export function PerformanceFormFragment({ provider }: PerformanceFormFragmentPro { field.onChange(checked); form.trigger("send_back_raw_request"); @@ -172,6 +197,7 @@ export function PerformanceFormFragment({ provider }: PerformanceFormFragmentPro { field.onChange(checked); form.trigger("send_back_raw_response"); diff --git a/ui/app/workspace/providers/fragments/proxyFormFragment.tsx b/ui/app/workspace/providers/fragments/proxyFormFragment.tsx index a0f3ac75ea..f0969d2ef0 100644 --- a/ui/app/workspace/providers/fragments/proxyFormFragment.tsx +++ b/ui/app/workspace/providers/fragments/proxyFormFragment.tsx @@ -91,7 +91,7 @@ export function ProxyFormFragment({ provider }: ProxyFormFragmentProps) { render={({ field }) => ( Proxy Type - @@ -122,7 +122,7 @@ export function ProxyFormFragment({ provider }: ProxyFormFragmentProps) { Proxy URL - + @@ -136,7 +136,7 @@ export function ProxyFormFragment({ provider }: ProxyFormFragmentProps) { Username - + @@ -149,7 +149,7 @@ export function ProxyFormFragment({ provider }: ProxyFormFragmentProps) { Password - + @@ -169,6 +169,7 @@ export function ProxyFormFragment({ provider }: ProxyFormFragmentProps) { rows={6} {...field} value={field.value || ""} + disabled={!hasUpdateProviderAccess} /> diff --git a/ui/app/workspace/providers/views/modelProviderKeysTableView.tsx b/ui/app/workspace/providers/views/modelProviderKeysTableView.tsx index accfd7fb49..b4fd6e59fb 100644 --- a/ui/app/workspace/providers/views/modelProviderKeysTableView.tsx +++ b/ui/app/workspace/providers/views/modelProviderKeysTableView.tsx @@ -13,8 +13,8 @@ import { import { Button } from "@/components/ui/button"; import { CardHeader, CardTitle } from "@/components/ui/card"; import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger } from "@/components/ui/dropdownMenu"; -import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "@/components/ui/table"; import { Switch } from "@/components/ui/switch"; +import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "@/components/ui/table"; import { getErrorMessage, useUpdateProviderMutation } from "@/lib/store"; import { ModelProvider } from "@/lib/types/config"; import { cn } from "@/lib/utils"; @@ -31,6 +31,7 @@ interface Props { export default function ModelProviderKeysTableView({ provider, className }: Props) { const hasUpdateProviderAccess = useRbac(RbacResource.ModelProvider, RbacOperation.Update); + const hasDeleteProviderAccess = useRbac(RbacResource.ModelProvider, RbacOperation.Delete); const [updateProvider, { isLoading: isUpdatingProvider }] = useUpdateProviderMutation(); const [showAddNewKeyDialog, setShowAddNewKeyDialog] = useState<{ show: boolean; keyIndex: number } | undefined>(undefined); const [showDeleteKeyDialog, setShowDeleteKeyDialog] = useState<{ show: boolean; keyIndex: number } | undefined>(undefined); @@ -53,7 +54,7 @@ export default function ModelProviderKeysTableView({ provider, className }: Prop Cancel { updateProvider({ ...provider, @@ -168,15 +169,16 @@ export default function ModelProviderKeysTableView({ provider, className }: Prop > Edit - - { - setShowDeleteKeyDialog({ show: true, keyIndex: index }); - }} - > - - Delete - + + { + setShowDeleteKeyDialog({ show: true, keyIndex: index }); + }} + disabled={!hasDeleteProviderAccess} + > + + Delete +
diff --git a/ui/app/workspace/user-groups/views/customerDialog.tsx b/ui/app/workspace/user-groups/views/customerDialog.tsx index bbdc5c586e..61007f0027 100644 --- a/ui/app/workspace/user-groups/views/customerDialog.tsx +++ b/ui/app/workspace/user-groups/views/customerDialog.tsx @@ -180,7 +180,7 @@ export default function CustomerDialog({ customer, onSave, onCancel }: CustomerD label="Maximum Spend (USD)" value={formData.budgetMaxLimit?.toString() || ""} selectValue={formData.budgetResetDuration} - onChangeNumber={(value) => updateField("budgetMaxLimit", parseFloat(value) || 0)} + onChangeNumber={(value) => updateField("budgetMaxLimit", value === '' ? undefined : parseFloat(value))} onChangeSelect={(value) => updateField("budgetResetDuration", value)} options={resetDurationOptions} /> diff --git a/ui/app/workspace/user-groups/views/teamDialog.tsx b/ui/app/workspace/user-groups/views/teamDialog.tsx index 812c24cb21..e0230f6538 100644 --- a/ui/app/workspace/user-groups/views/teamDialog.tsx +++ b/ui/app/workspace/user-groups/views/teamDialog.tsx @@ -208,7 +208,7 @@ export default function TeamDialog({ team, customers, onSave, onCancel }: TeamDi label="Maximum Spend (USD)" value={formData.budgetMaxLimit?.toString() || ""} selectValue={formData.budgetResetDuration} - onChangeNumber={(value) => updateField("budgetMaxLimit", parseFloat(value) || 0)} + onChangeNumber={(value) => updateField("budgetMaxLimit", value === '' ? undefined : parseFloat(value))} onChangeSelect={(value) => updateField("budgetResetDuration", value)} options={resetDurationOptions} /> diff --git a/ui/components/devProfiler.tsx b/ui/components/devProfiler.tsx new file mode 100644 index 0000000000..5d32ce1350 --- /dev/null +++ b/ui/components/devProfiler.tsx @@ -0,0 +1,771 @@ +'use client' + +import { useGetDevGoroutinesQuery, useGetDevPprofQuery } from '@/lib/store' +import type { GoroutineGroup } from '@/lib/store/apis/devApi' +import { isDevelopmentMode } from '@/lib/utils/port' +import { Activity, AlertTriangle, ChevronDown, ChevronRight, ChevronUp, Cpu, EyeOff, HardDrive, RotateCcw, TrendingUp, X } from 'lucide-react' +import React, { useCallback, useEffect, useMemo, useState } from 'react' +import { + Area, + AreaChart, + CartesianGrid, + ResponsiveContainer, + Tooltip, + XAxis, + YAxis, +} from 'recharts' + +// Format bytes to human-readable string +function formatBytes (bytes: number): string { + if (bytes === 0) return '0 B' + const k = 1024 + const sizes = ['B', 'KB', 'MB', 'GB'] + const i = Math.floor(Math.log(bytes) / Math.log(k)) + return `${(bytes / Math.pow(k, i)).toFixed(1)} ${sizes[i]}` +} + +// Format nanoseconds to human-readable string +function formatNs (ns: number): string { + if (ns < 1000) return `${ns}ns` + if (ns < 1000000) return `${(ns / 1000).toFixed(1)}µs` + if (ns < 1000000000) return `${(ns / 1000000).toFixed(1)}ms` + return `${(ns / 1000000000).toFixed(2)}s` +} + +// Format timestamp to HH:MM:SS +function formatTime (timestamp: string): string { + const date = new Date(timestamp) + return date.toLocaleTimeString('en-US', { + hour12: false, + hour: '2-digit', + minute: '2-digit', + second: '2-digit', + }) +} + +// Truncate function name for display +function truncateFunction (fn: string): string { + const parts = fn.split('/') + const last = parts[parts.length - 1] + if (last.length > 40) { + return '...' + last.slice(-37) + } + return last +} + +// Get category badge color +function getCategoryColor (category: string): string { + switch (category) { + case 'per-request': + return 'text-amber-400 bg-amber-400/10' + case 'background': + return 'text-blue-400 bg-blue-400/10' + default: + return 'text-zinc-400 bg-zinc-400/10' + } +} + +// Extract file path from stack (first line containing .go:) +function getStackFilePath (stack: string[]): string { + for (const line of stack) { + // Match file path like "/path/to/file.go:123" and extract just the path + const match = line.match(/^\s*([^\s]+\.go):\d+/) + if (match) { + return match[1] + } + } + return '' +} + +// Generate a stable ID for a goroutine group +function getGoroutineId (g: GoroutineGroup): string { + return `${g.top_func}::${g.state}::${g.count}::${g.wait_minutes ?? 0}` +} + +// localStorage key for skipped goroutine file paths +const SKIPPED_GOROUTINE_FILES_KEY = 'devProfiler.skippedGoroutineFiles' + +// Load skipped goroutine file paths from localStorage +function loadSkippedGoroutineFiles (): Set { + if (typeof window === 'undefined') return new Set() + try { + const stored = localStorage.getItem(SKIPPED_GOROUTINE_FILES_KEY) + return stored ? new Set(JSON.parse(stored)) : new Set() + } catch { + return new Set() + } +} + +// Save skipped goroutine file paths to localStorage +function saveSkippedGoroutineFiles (skipped: Set): void { + if (typeof window === 'undefined') return + try { + localStorage.setItem(SKIPPED_GOROUTINE_FILES_KEY, JSON.stringify([...skipped])) + } catch { + // Ignore storage errors + } +} + +// Goroutine Health Section subcomponent +interface GoroutineHealthProps { + goroutineData: { + summary: { + background: number + per_request: number + long_waiting: number + potentially_stuck: number + } + total_goroutines: number + } | undefined + goroutineHealth: 'healthy' | 'warning' | 'critical' + goroutineTrend: { + isGrowing: boolean + growthPercent: number + avg: number + } | null + problemGoroutines: GoroutineGroup[] + expandedGoroutines: Set + toggleGoroutineExpand: (id: string) => void + skippedGoroutines: Set + onSkipGoroutine: (topFunc: string) => void + onClearSkipped: () => void +} + +function GoroutineHealthSection ({ + goroutineData, + goroutineHealth, + goroutineTrend, + problemGoroutines, + expandedGoroutines, + toggleGoroutineExpand, + skippedGoroutines, + onSkipGoroutine, + onClearSkipped, +}: GoroutineHealthProps): React.ReactNode { + if (!goroutineData) return null + + const { summary, total_goroutines } = goroutineData + + return ( +
+ {/* Header with health status */} +
+
+ + Goroutine Health +
+
+ {goroutineTrend?.isGrowing && ( + + + +{goroutineTrend.growthPercent.toFixed(0)}% + + )} + {goroutineHealth === 'critical' && ( + + + Stuck + + )} + {goroutineHealth === 'warning' && ( + + + Long Wait + + )} + {goroutineHealth === 'healthy' && ( + + Healthy + + )} +
+
+ + {/* Summary stats */} +
+
+ Total + {total_goroutines} +
+
+ Background + {summary.background} +
+
+ Per-Request + {summary.per_request} +
+
+ Stuck + 0 ? 'text-red-400' : 'text-zinc-500'}`}> + {summary.potentially_stuck} + +
+
+ + {/* Problem goroutines list */} + {(problemGoroutines.length > 0 || skippedGoroutines.size > 0) && ( +
+
+ Potential Leaks + {skippedGoroutines.size > 0 && ( + + )} +
+ {problemGoroutines.map((g) => { + const gid = getGoroutineId(g) + return ( +
+
toggleGoroutineExpand(gid)} + onKeyDown={(e) => { + if (e.key === 'Enter' || e.key === ' ') { + e.preventDefault() + toggleGoroutineExpand(gid) + } + }} + className="flex w-full cursor-pointer flex-col gap-1 px-2 py-1.5 pr-8 text-left hover:bg-zinc-700/50" + > +
+ {expandedGoroutines.has(gid) ? ( + + ) : ( + + )} + + {truncateFunction(g.top_func)} + +
+
+ + {g.category} + + {g.count}x + {g.wait_minutes != null && ( + {g.wait_minutes}m waiting + )} +
+
+ + {expandedGoroutines.has(gid) && ( +
+
+ State: {g.state} + {g.wait_reason && ( + Wait: {g.wait_reason} + )} +
+
+ {g.stack.slice(0, 10).map((line, j) => ( +
+ {line} +
+ ))} + {g.stack.length > 10 && ( +
+ ... {g.stack.length - 10} more frames +
+ )} +
+
+ )} +
+ )})} + + {problemGoroutines.length === 0 && skippedGoroutines.size > 0 && ( +
+ All potential leaks hidden +
+ )} + {problemGoroutines.length === 0 && skippedGoroutines.size === 0 && (summary.long_waiting > 0 || summary.potentially_stuck > 0) && ( +
+ {summary.long_waiting > 0 && summary.potentially_stuck > 0 + ? `${summary.long_waiting} long-waiting and ${summary.potentially_stuck} stuck goroutines (background workers filtered)` + : summary.long_waiting > 0 + ? `${summary.long_waiting} long-waiting goroutines (background workers filtered)` + : `${summary.potentially_stuck} stuck goroutines (background workers filtered)`} +
+ )} +
+ )} + + {/* No problems message */} + {problemGoroutines.length === 0 && summary.long_waiting === 0 && summary.potentially_stuck === 0 && ( +
+ No goroutine leaks detected +
+ )} +
+ ) +} + +export function DevProfiler (): React.ReactNode { + const [isVisible, setIsVisible] = useState(true) + const [isExpanded, setIsExpanded] = useState(true) + const [isDismissed, setIsDismissed] = useState(false) + const [expandedGoroutines, setExpandedGoroutines] = useState>(new Set()) + const [skippedGoroutines, setSkippedGoroutines] = useState>(() => loadSkippedGoroutineFiles()) + + // Sync skipped goroutines to localStorage + useEffect(() => { + saveSkippedGoroutineFiles(skippedGoroutines) + }, [skippedGoroutines]) + + // Only fetch in development mode and when not dismissed + const shouldFetch = isDevelopmentMode() && !isDismissed + + const { data, isLoading, error } = useGetDevPprofQuery(undefined, { + pollingInterval: shouldFetch ? 10000 : 0, // Poll every 10 seconds + skip: !shouldFetch, + }) + + const { data: goroutineData } = useGetDevGoroutinesQuery(undefined, { + pollingInterval: shouldFetch ? 10000 : 0, // Poll every 10 seconds + skip: !shouldFetch, + }) + + // Memoize chart data transformation + const memoryChartData = useMemo(() => { + if (!data?.history) return [] + return data.history.map((point) => ({ + time: formatTime(point.timestamp), + alloc: point.alloc / (1024 * 1024), // Convert to MB + heapInuse: point.heap_inuse / (1024 * 1024), + })) + }, [data?.history]) + + const cpuChartData = useMemo(() => { + if (!data?.history) return [] + return data.history.map((point) => ({ + time: formatTime(point.timestamp), + cpuPercent: point.cpu_percent, + goroutines: point.goroutines, + })) + }, [data?.history]) + + // Detect goroutine count trend (growing = potential leak) + const goroutineTrend = useMemo(() => { + if (!data?.history || data.history.length < 5 || !data?.runtime) return null + const recent = data.history.slice(-5) + const avg = recent.reduce((sum, p) => sum + p.goroutines, 0) / recent.length + const current = data.runtime.num_goroutine + const isGrowing = current > avg * 1.1 // 10% above average + const growthPercent = avg > 0 ? ((current - avg) / avg) * 100 : 0 + return { isGrowing, growthPercent, avg } + }, [data?.history, data?.runtime?.num_goroutine]) + + // Filter problem goroutines (stuck or long-waiting, excluding expected background workers and skipped) + const problemGoroutines = useMemo(() => { + if (!goroutineData?.groups) return [] + return goroutineData.groups + .filter((g) => { + if (!g.wait_minutes || g.wait_minutes < 1) return false + if (g.category === 'background') return false + const filePath = getStackFilePath(g.stack) + if (filePath && skippedGoroutines.has(filePath)) return false + return true + }) + .slice(0, 5) + }, [goroutineData?.groups, skippedGoroutines]) + + // Get goroutine health status + const goroutineHealth = useMemo(() => { + if (!goroutineData?.summary) return 'healthy' + const { potentially_stuck, long_waiting } = goroutineData.summary + if (potentially_stuck > 0) return 'critical' + if (long_waiting > 0) return 'warning' + return 'healthy' + }, [goroutineData?.summary]) + + const handleDismiss = useCallback(() => { + setIsDismissed(true) + }, []) + + const toggleGoroutineExpand = useCallback((id: string) => { + setExpandedGoroutines((prev) => { + const next = new Set(prev) + if (next.has(id)) { + next.delete(id) + } else { + next.add(id) + } + return next + }) + }, []) + + const handleSkipGoroutine = useCallback((filePath: string) => { + setSkippedGoroutines((prev) => { + const next = new Set(prev) + next.add(filePath) + return next + }) + }, []) + + const handleClearSkipped = useCallback(() => { + setSkippedGoroutines(new Set()) + }, []) + + const handleToggleExpand = useCallback(() => { + setIsExpanded((prev) => !prev) + }, []) + + const handleToggleVisible = useCallback(() => { + setIsVisible((prev) => !prev) + }, []) + + // Don't render in production mode or if dismissed + if (!isDevelopmentMode() || isDismissed) { + return null + } + + // Minimized state - just show a small button + if (!isVisible) { + return ( + + ) + } + + return ( +
+ {/* Header */} +
+
+ Dev Profiler + {isLoading && ( + + )} +
+
+ + + +
+
+ + {Boolean(error) && ( +
+ Failed to load profiling data +
+ )} + + {isExpanded && data && ( +
+ {/* Current Stats */} +
+
+ CPU Usage + + {data.cpu.usage_percent.toFixed(1)}% + +
+
+ Heap Alloc + + {formatBytes(data.memory.alloc)} + +
+
+ Heap In-Use + + {formatBytes(data.memory.heap_inuse)} + +
+
+ System + + {formatBytes(data.memory.sys)} + +
+
+ Goroutines + + {data.runtime.num_goroutine} + +
+
+ GC Pause + + {formatNs(data.runtime.gc_pause_ns)} + +
+
+ + {/* CPU Chart */} +
+
+ + CPU Usage (last 5 min) +
+
+ + + + + + + + + + + + + + + `${Number(v).toFixed(0)}%`} + width={35} + domain={[0, 'auto']} + /> + + + + + + +
+
+ + + CPU % + + + + Goroutines + +
+
+ + {/* Memory Chart */} +
+
+ + Memory (last 5 min) +
+
+ + + + + + + + + + + + + + + `${Number(v).toFixed(0)}MB`} + width={45} + /> + + + + + +
+
+ + + Alloc + + + + Heap In-Use + +
+
+ + {/* Top Allocations */} +
+
+ + Top Allocations +
+
+ {(data.top_allocations ?? []).map((alloc, i) => ( +
+
+ + {truncateFunction(alloc.function)} + + + {alloc.file}:{alloc.line} + +
+
+ + {formatBytes(alloc.bytes)} + + + {alloc.count.toLocaleString()} allocs + +
+
+ ))} +
+
+ + {/* Goroutine Health */} + + + {/* Footer with info */} +
+ CPUs: {data.runtime.num_cpu} | GOMAXPROCS: {data.runtime.gomaxprocs} | + GC: {data.runtime.num_gc} | Objects: {data.memory.heap_objects.toLocaleString()} +
+
+ )} + + {/* Collapsed state */} + {!isExpanded && data && ( +
+ + CPU: {data.cpu.usage_percent.toFixed(1)}% + + + Heap: {formatBytes(data.memory.heap_inuse)} + + + Goroutines: {data.runtime.num_goroutine} + +
+ )} +
+ ) +} diff --git a/ui/components/sidebar.tsx b/ui/components/sidebar.tsx index 10bf98bea5..cb2703f8b3 100644 --- a/ui/components/sidebar.tsx +++ b/ui/components/sidebar.tsx @@ -19,6 +19,7 @@ import { Layers, LogOut, Logs, + PanelLeftClose, Puzzle, ScrollText, Settings, @@ -29,7 +30,7 @@ import { User, UserRoundCheck, Users, - Zap, + Zap } from "lucide-react"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; @@ -46,6 +47,7 @@ import { SidebarMenuSub, SidebarMenuSubButton, SidebarMenuSubItem, + useSidebar, } from "@/components/ui/sidebar"; import { useWebSocket } from "@/hooks/useWebSocket"; import { IS_ENTERPRISE, TRIAL_EXPIRY } from "@/lib/constants/config"; @@ -160,6 +162,8 @@ const SidebarItemView = ({ onToggle, pathname, router, + isSidebarCollapsed, + expandSidebar, }: { item: SidebarItem; isActive: boolean; @@ -170,6 +174,8 @@ const SidebarItemView = ({ onToggle?: () => void; pathname: string; router: ReturnType; + isSidebarCollapsed: boolean; + expandSidebar: () => void; }) => { const hasSubItems = "subItems" in item && item.subItems && item.subItems.length > 0; const isAnySubItemActive = @@ -179,9 +185,18 @@ const SidebarItemView = ({ }); const handleClick = (e: React.MouseEvent) => { - if (hasSubItems && onToggle && item.hasAccess) { + if (hasSubItems && item.hasAccess) { e.preventDefault(); - onToggle(); + // If sidebar is collapsed, expand it first then toggle the submenu + if (isSidebarCollapsed) { + expandSidebar(); + // Small delay to allow sidebar to expand before toggling submenu + setTimeout(() => { + if (onToggle) onToggle(); + }, 100); + } else if (onToggle) { + onToggle(); + } } }; @@ -205,7 +220,8 @@ const SidebarItemView = ({ return (
- - {item.title} + + + {item.title} + {item.tag && ( - + {item.tag} )}
- {hasSubItems && } + {hasSubItems && ( + + )} {!hasSubItems && item.url === "/logs" && isWebSocketConnected && (
)} - {isExternal && } + {isExternal && }
{hasSubItems && isExpanded && ( @@ -240,7 +264,7 @@ const SidebarItemView = ({ return ( - -
- + + + {/* Expanded state: horizontal layout */} +
+ Bifrost + +
+ {/* Collapsed state: vertical layout */} +
+ Bifrost
@@ -729,19 +775,21 @@ export default function AppSidebar() { isExpanded={expandedItems.has(item.title)} onToggle={() => toggleItem(item.title)} pathname={pathname} - router={router} + router={router} + isSidebarCollapsed={sidebarState === "collapsed"} + expandSidebar={() => toggleSidebar()} /> ); })} -
-
+
+
-
+
{externalLinks.map((item, index) => (
-
+
{version ?? ""}
{trialDaysRemaining !== null && (
diff --git a/ui/components/ui/modelMultiselect.tsx b/ui/components/ui/modelMultiselect.tsx index f2ae3f4435..421c8382bb 100644 --- a/ui/components/ui/modelMultiselect.tsx +++ b/ui/components/ui/modelMultiselect.tsx @@ -152,7 +152,7 @@ export function ModelMultiselect({ noResultsFoundPlaceholder="No models found" emptyResultPlaceholder={provider ? "Start typing to search models..." : "Please select a provider first"} views={{ - dropdownIndicator: () => <>, + dropdownIndicator: () => <>, multiValue: (multiValueProps: MultiValueProps) => { return (
, SwitchProps>( - ({ className, size = "default", ...props }, ref) => ( + ({ className, size = "md", ...props }, ref) => ( , > void; + + /** Optional label to render to the right of the checkbox */ + label?: React.ReactNode; + + /** Optional disabled state */ + disabled?: boolean; + + /** Extra tailwind classes for the wrapper */ + className?: string; + + /** Accessible name for icon-only checkbox (e.g. when label is rendered elsewhere) */ + ariaLabel?: string; +} + +export const TriStateCheckbox: React.FC = ({ + allIds, + selectedIds, + onChange, + label, + disabled = false, + className = "", + ariaLabel, +}) => { + const state: TriState = useMemo(() => { + if (!allIds.length) return "none"; + + const selectedSet = new Set(selectedIds); + const selectedCount = allIds.filter((id) => selectedSet.has(id)).length; + + if (selectedCount === 0) return "none"; + if (selectedCount === allIds.length) return "all"; + return "some"; + }, [allIds, selectedIds]); + + const handleClick = () => { + if (disabled) return; + + let nextSelected: string[]; + + switch (state) { + case "all": + // clear all + nextSelected = []; + break; + case "some": + case "none": + default: + // select all + nextSelected = [...allIds]; + break; + } + + onChange(nextSelected); + }; + + const ariaChecked: boolean | "mixed" = state === "all" ? true : state === "none" ? false : "mixed"; + + const isChecked = state === "all"; + const isIndeterminate = state === "some"; + + return ( + + ); +}; diff --git a/ui/hooks/useTablePageSize.ts b/ui/hooks/useTablePageSize.ts new file mode 100644 index 0000000000..6ea4e383e9 --- /dev/null +++ b/ui/hooks/useTablePageSize.ts @@ -0,0 +1,67 @@ +"use client" + +import { RefObject, useCallback, useEffect, useState } from "react" + +const ROW_HEIGHT = 48 // h-12 = 3rem = 48px +const HEADER_HEIGHT = 44 // approximate table header height +const STATUS_ROW_HEIGHT = 48 // the "Listening for logs..." row (h-12) +const MIN_PAGE_SIZE = 5 // minimum items per page + +interface UseTablePageSizeOptions { + debounceMs?: number +} + +export function useTablePageSize ( + containerRef: RefObject, + options: UseTablePageSizeOptions = {} +): number | null { + const { debounceMs = 150 } = options + const [pageSize, setPageSize] = useState(null) + + const calculatePageSize = useCallback((height: number): number => { + const availableHeight = height - HEADER_HEIGHT - STATUS_ROW_HEIGHT + const calculated = Math.floor(availableHeight / ROW_HEIGHT) + return Math.max(calculated, MIN_PAGE_SIZE) + }, []) + + useEffect(() => { + const element = containerRef.current + if (!element) return + + let timeoutId: ReturnType | null = null + + const handleResize = (entries: ResizeObserverEntry[]) => { + const entry = entries[0] + if (!entry) return + + const height = entry.contentRect.height + + if (timeoutId) { + clearTimeout(timeoutId) + } + + timeoutId = setTimeout(() => { + const newPageSize = calculatePageSize(height) + setPageSize(newPageSize) + }, debounceMs) + } + + const resizeObserver = new ResizeObserver(handleResize) + resizeObserver.observe(element) + + // Calculate initial size immediately + const initialHeight = element.getBoundingClientRect().height + if (initialHeight > 0) { + setPageSize(calculatePageSize(initialHeight)) + } + + return () => { + if (timeoutId) { + clearTimeout(timeoutId) + } + resizeObserver.disconnect() + } + }, [containerRef, calculatePageSize, debounceMs]) + + return pageSize +} diff --git a/ui/lib/constants/logs.ts b/ui/lib/constants/logs.ts index 5cb84d994c..5d609fb340 100644 --- a/ui/lib/constants/logs.ts +++ b/ui/lib/constants/logs.ts @@ -128,8 +128,6 @@ export const RequestTypeLabels = { file_retrieve: "File Retrieve", file_delete: "File Delete", file_content: "File Content", - - } as const; export const RequestTypeColors = { @@ -164,7 +162,7 @@ export const RequestTypeColors = { batch_retrieve: "bg-red-100 text-red-800", batch_cancel: "bg-yellow-100 text-yellow-800", batch_results: "bg-purple-100 text-purple-800", - + file_upload: "bg-pink-100 text-pink-800", file_list: "bg-lime-100 text-lime-800", file_retrieve: "bg-orange-100 text-orange-800", diff --git a/ui/lib/store/apis/baseApi.ts b/ui/lib/store/apis/baseApi.ts index 3b6c6880ff..a44f1d2159 100644 --- a/ui/lib/store/apis/baseApi.ts +++ b/ui/lib/store/apis/baseApi.ts @@ -166,6 +166,7 @@ export const baseApi = createApi({ "Resources", "Operations", "Permissions", + "APIKeys", ], endpoints: () => ({}), }); diff --git a/ui/lib/store/apis/devApi.ts b/ui/lib/store/apis/devApi.ts new file mode 100644 index 0000000000..e8cc281dd3 --- /dev/null +++ b/ui/lib/store/apis/devApi.ts @@ -0,0 +1,107 @@ +import { baseApi } from './baseApi' + +// Memory statistics at a point in time +export interface MemoryStats { + alloc: number + total_alloc: number + heap_inuse: number + heap_objects: number + sys: number +} + +// CPU statistics +export interface CPUStats { + usage_percent: number + user_time: number + system_time: number +} + +// Runtime statistics +export interface RuntimeStats { + num_goroutine: number + num_gc: number + gc_pause_ns: number + num_cpu: number + gomaxprocs: number +} + +// Allocation info for top allocations +export interface AllocationInfo { + function: string + file: string + line: number + bytes: number + count: number +} + +// Single point in the metrics history +export interface HistoryPoint { + timestamp: string + alloc: number + heap_inuse: number + goroutines: number + gc_pause_ns: number + cpu_percent: number +} + +// Complete pprof data response +export interface PprofData { + timestamp: string + memory: MemoryStats + cpu: CPUStats + runtime: RuntimeStats + top_allocations: AllocationInfo[] + history: HistoryPoint[] +} + +// Goroutine group representing goroutines with same stack trace +export interface GoroutineGroup { + count: number + state: string + wait_reason?: string + wait_minutes?: number + top_func: string + stack: string[] + category: 'background' | 'per-request' | 'unknown' +} + +// Goroutine health summary +export interface GoroutineSummary { + background: number + per_request: number + long_waiting: number + potentially_stuck: number +} + +// Goroutine profile response +export interface GoroutineProfile { + timestamp: string + total_goroutines: number + groups: GoroutineGroup[] + summary: GoroutineSummary +} + +export const devApi = baseApi.injectEndpoints({ + endpoints: (builder) => ({ + // Get dev pprof data - polls every 10 seconds + getDevPprof: builder.query({ + query: () => ({ + url: '/dev/pprof', + }), + }), + // Get goroutine profile for leak detection + getDevGoroutines: builder.query({ + query: () => ({ + url: '/dev/pprof/goroutines', + }), + }), + }), +}) + +export const { + useGetDevPprofQuery, + useLazyGetDevPprofQuery, + useGetDevGoroutinesQuery, + useLazyGetDevGoroutinesQuery, +} = devApi + diff --git a/ui/lib/store/apis/index.ts b/ui/lib/store/apis/index.ts index 99fc3eb569..3946f3318f 100644 --- a/ui/lib/store/apis/index.ts +++ b/ui/lib/store/apis/index.ts @@ -3,6 +3,7 @@ export { baseApi, clearAuthStorage, getErrorMessage, setAuthToken } from "./base // API slices and hooks export * from "./configApi"; +export * from "./devApi"; export * from "./governanceApi"; export * from "./logsApi"; export * from "./mcpApi"; diff --git a/ui/lib/types/config.ts b/ui/lib/types/config.ts index 54b183837b..5732f8160b 100644 --- a/ui/lib/types/config.ts +++ b/ui/lib/types/config.ts @@ -351,6 +351,9 @@ export interface CoreConfig { allowed_origins: string[]; max_request_body_size_mb: number; enable_litellm_fallbacks: boolean; + mcp_agent_depth: number; + mcp_tool_execution_timeout: number; + mcp_code_mode_binding_level?: string; header_filter_config?: GlobalHeaderFilterConfig; } diff --git a/ui/lib/types/logs.ts b/ui/lib/types/logs.ts index 35d9a0b6de..024979ce1b 100644 --- a/ui/lib/types/logs.ts +++ b/ui/lib/types/logs.ts @@ -451,6 +451,13 @@ export interface ResponsesMessage { encrypted_content?: string; // Additional tool-specific fields [key: string]: any; + output?: string | ResponsesMessageContentBlock[] | ResponsesComputerToolCallOutputData; +} + +export interface ResponsesComputerToolCallOutputData { + type: "computer_screenshot"; + file_id?: string; + image_url?: string; } // Stream options for responses diff --git a/ui/lib/types/mcp.ts b/ui/lib/types/mcp.ts index 7b7f4f8fb8..9faee1dbb7 100644 --- a/ui/lib/types/mcp.ts +++ b/ui/lib/types/mcp.ts @@ -13,10 +13,12 @@ export interface MCPStdioConfig { export interface MCPClientConfig { id: string; name: string; + is_code_mode_client?: boolean; connection_type: MCPConnectionType; connection_string?: string; stdio_config?: MCPStdioConfig; tools_to_execute?: string[]; + tools_to_auto_execute?: string[]; headers?: Record; } @@ -28,15 +30,19 @@ export interface MCPClient { export interface CreateMCPClientRequest { name: string; + is_code_mode_client?: boolean; connection_type: MCPConnectionType; connection_string?: string; stdio_config?: MCPStdioConfig; tools_to_execute?: string[]; + tools_to_auto_execute?: string[]; headers?: Record; } export interface UpdateMCPClientRequest { name?: string; + is_code_mode_client?: boolean; headers?: Record; tools_to_execute?: string[]; + tools_to_auto_execute?: string[]; } diff --git a/ui/lib/types/schemas.ts b/ui/lib/types/schemas.ts index ed15e0d67c..d28779413c 100644 --- a/ui/lib/types/schemas.ts +++ b/ui/lib/types/schemas.ts @@ -94,11 +94,11 @@ export const s3BucketConfigSchema = z.object({ bucket_name: z.string().min(1, "Bucket name is required"), prefix: z.string().optional(), is_default: z.boolean().optional(), -}) +}); export const batchS3ConfigSchema = z.object({ buckets: z.array(s3BucketConfigSchema).optional(), -}) +}); // Bedrock key config schema export const bedrockKeyConfigSchema = z @@ -464,6 +464,9 @@ export const coreConfigSchema = z.object({ allow_direct_keys: z.boolean().default(false), allowed_origins: z.array(z.string()).default(["*"]), max_request_body_size_mb: z.number().min(1).default(100), + mcp_agent_depth: z.number().min(1).default(10), + mcp_tool_execution_timeout: z.number().min(1).default(30), + mcp_code_mode_binding_level: z.enum(["server", "tool"]).default("server"), }); // Bifrost config schema @@ -602,7 +605,13 @@ export const maximFormSchema = z.object({ // MCP Client update schema export const mcpClientUpdateSchema = z.object({ - name: z.string().min(1, "Name is required"), + is_code_mode_client: z.boolean().optional(), + name: z + .string() + .min(1, "Name is required") + .refine((val) => !val.includes("-"), { message: "Client name cannot contain hyphens" }) + .refine((val) => !val.includes(" "), { message: "Client name cannot contain spaces" }) + .refine((val) => !/^[0-9]/.test(val), { message: "Client name cannot start with a number" }), headers: z.record(z.string(), z.string()).optional(), tools_to_execute: z .array(z.string()) @@ -622,10 +631,28 @@ export const mcpClientUpdateSchema = z.object({ }, { message: "Duplicate tool names are not allowed" }, ), + tools_to_auto_execute: z + .array(z.string()) + .optional() + .refine( + (tools) => { + if (!tools || tools.length === 0) return true; + const hasWildcard = tools.includes("*"); + return !hasWildcard || tools.length === 1; + }, + { message: "Wildcard '*' cannot be combined with other tool names" }, + ) + .refine( + (tools) => { + if (!tools) return true; + return tools.length === new Set(tools).size; + }, + { message: "Duplicate tool names are not allowed" }, + ), }); // Global proxy type schema -export const globalProxyTypeSchema = z.enum(['http', 'socks5', 'tcp']); +export const globalProxyTypeSchema = z.enum(["http", "socks5", "tcp"]); // Global proxy configuration schema export const globalProxyConfigSchema = z @@ -652,8 +679,8 @@ export const globalProxyConfigSchema = z return true; }, { - message: 'Proxy URL is required when proxy is enabled', - path: ['url'], + message: "Proxy URL is required when proxy is enabled", + path: ["url"], }, ) .refine( @@ -670,8 +697,8 @@ export const globalProxyConfigSchema = z return true; }, { - message: 'Must be a valid URL (e.g., http://proxy.example.com:8080)', - path: ['url'], + message: "Must be a valid URL (e.g., http://proxy.example.com:8080)", + path: ["url"], }, ); diff --git a/ui/lib/utils/validation.ts b/ui/lib/utils/validation.ts index 7b77cb3780..aed9c326ec 100644 --- a/ui/lib/utils/validation.ts +++ b/ui/lib/utils/validation.ts @@ -371,7 +371,11 @@ function isValidWildcardOrigin(origin: string): boolean { * @returns Object with validation result and invalid origins */ export function validateOrigins(origins: string[]): { isValid: boolean; invalidOrigins: string[] } { - const invalidOrigins = origins?.filter((origin) => !isValidOrigin(origin)) || []; + if (!origins || origins.length === 0) { + return { isValid: true, invalidOrigins: [] }; + } + + const invalidOrigins = origins.filter((origin) => !isValidOrigin(origin)); return { isValid: invalidOrigins.length === 0, diff --git a/ui/public/bifrost-icon-dark.png b/ui/public/bifrost-icon-dark.png new file mode 100644 index 0000000000..583ffe4fc4 Binary files /dev/null and b/ui/public/bifrost-icon-dark.png differ diff --git a/ui/public/bifrost-icon.png b/ui/public/bifrost-icon.png new file mode 100644 index 0000000000..32ff249bd3 Binary files /dev/null and b/ui/public/bifrost-icon.png differ