diff --git a/.github/workflows/scripts/run-migration-tests.sh b/.github/workflows/scripts/run-migration-tests.sh index 817c73c1c3..d4bd7e89f8 100755 --- a/.github/workflows/scripts/run-migration-tests.sh +++ b/.github/workflows/scripts/run-migration-tests.sh @@ -473,10 +473,11 @@ VALUES (1, 'migration-test-hash-abc123def456', $now, $now) ON CONFLICT DO NOTHING; -- governance_budgets (reset_duration is a string like "1d", "1h", etc.) -INSERT INTO governance_budgets (id, max_limit, current_usage, reset_duration, last_reset, config_hash, calendar_aligned, created_at, updated_at) +-- NOTE: calendar_aligned excluded - it was added in prerelease1, dropped in prerelease2, re-added in prerelease4 +INSERT INTO governance_budgets (id, max_limit, current_usage, reset_duration, last_reset, config_hash, created_at, updated_at) VALUES - ('budget-migration-test-1', 1000.00, 100.00, '1d', $now, 'budget-hash-001', false, $now, $now), - ('budget-migration-test-2', 5000.00, 250.00, '7d', $now, 'budget-hash-002', false, $now, $now) + ('budget-migration-test-1', 1000.00, 100.00, '1d', $now, 'budget-hash-001', $now, $now), + ('budget-migration-test-2', 5000.00, 250.00, '7d', $now, 'budget-hash-002', $now, $now) ON CONFLICT DO NOTHING; -- governance_rate_limits (flexible duration format with token_* and request_* columns) @@ -493,11 +494,12 @@ VALUES ('customer-migration-test-2', 'Migration Test Customer Two', NULL, NULL, 'customer-hash-002', $now, $now) ON CONFLICT DO NOTHING; --- governance_teams (with customer_id, budget_id, rate_limit_id, profile, config, claims, config_hash) -INSERT INTO governance_teams (id, name, customer_id, budget_id, rate_limit_id, profile, config, claims, config_hash, created_at, updated_at) +-- governance_teams (with customer_id, rate_limit_id, profile, config, claims, config_hash) +-- NOTE: budget_id excluded - it was dropped from governance_teams in prerelease4 (team budgets moved to governance_budgets.team_id) +INSERT INTO governance_teams (id, name, 
customer_id, rate_limit_id, profile, config, claims, config_hash, created_at, updated_at) VALUES - ('team-migration-test-1', 'Migration Test Team Alpha', 'customer-migration-test-1', 'budget-migration-test-2', 'ratelimit-migration-test-2', '{"role": "admin"}', '{"setting": "value"}', '{"claim1": "val1"}', 'team-hash-001', $now, $now), - ('team-migration-test-2', 'Migration Test Team Beta', NULL, NULL, NULL, NULL, NULL, NULL, 'team-hash-002', $now, $now) + ('team-migration-test-1', 'Migration Test Team Alpha', 'customer-migration-test-1', 'ratelimit-migration-test-2', '{"role": "admin"}', '{"setting": "value"}', '{"claim1": "val1"}', 'team-hash-001', $now, $now), + ('team-migration-test-2', 'Migration Test Team Beta', NULL, NULL, NULL, NULL, NULL, 'team-hash-002', $now, $now) ON CONFLICT DO NOTHING; -- config_providers (with all JSON config fields and governance fields including budget_id, rate_limit_id) @@ -620,18 +622,20 @@ ON CONFLICT DO NOTHING; -- config_mcp_clients INSERT is generated dynamically after this heredoc -- to handle older schemas that may not have newer columns (tool_pricing_json, auth_type, etc.) 
--- governance_virtual_keys (with all columns including description, is_active, team_id, customer_id, budget_id, rate_limit_id, config_hash) -INSERT INTO governance_virtual_keys (id, name, description, value, is_active, team_id, customer_id, budget_id, rate_limit_id, config_hash, created_at, updated_at) +-- governance_virtual_keys (with all columns including description, is_active, team_id, customer_id, rate_limit_id, config_hash) +-- NOTE: budget_id excluded - dropped from governance_virtual_keys in prerelease2 (ownership moved to governance_budgets.virtual_key_id) +INSERT INTO governance_virtual_keys (id, name, description, value, is_active, team_id, customer_id, rate_limit_id, config_hash, created_at, updated_at) VALUES - ('vk-migration-test-1', 'Migration Test Virtual Key 1', 'Test virtual key for migration', 'vk-migration-fake-value-001', true, 'team-migration-test-1', NULL, 'budget-migration-test-1', 'ratelimit-migration-test-1', 'vk-hash-001', $now, $now), - ('vk-migration-test-2', 'Migration Test Virtual Key 2', 'Another test virtual key', 'vk-migration-fake-value-002', true, NULL, 'customer-migration-test-2', NULL, NULL, 'vk-hash-002', $now, $now) + ('vk-migration-test-1', 'Migration Test Virtual Key 1', 'Test virtual key for migration', 'vk-migration-fake-value-001', true, 'team-migration-test-1', NULL, 'ratelimit-migration-test-1', 'vk-hash-001', $now, $now), + ('vk-migration-test-2', 'Migration Test Virtual Key 2', 'Another test virtual key', 'vk-migration-fake-value-002', true, NULL, 'customer-migration-test-2', NULL, 'vk-hash-002', $now, $now) ON CONFLICT DO NOTHING; -- governance_virtual_key_provider_configs (references virtual_keys - with all columns) -INSERT INTO governance_virtual_key_provider_configs (virtual_key_id, provider, weight, allowed_models, budget_id, rate_limit_id) +-- NOTE: budget_id excluded - dropped from governance_virtual_key_provider_configs in prerelease2 +INSERT INTO governance_virtual_key_provider_configs (virtual_key_id, 
provider, weight, allowed_models, rate_limit_id) VALUES - ('vk-migration-test-1', 'openai', 0.7, '["gpt-4"]', NULL, NULL), - ('vk-migration-test-2', 'anthropic', 0.3, '[]', 'budget-migration-test-2', 'ratelimit-migration-test-2') + ('vk-migration-test-1', 'openai', 0.7, '["gpt-4"]', NULL), + ('vk-migration-test-2', 'anthropic', 0.3, '[]', 'ratelimit-migration-test-2') ON CONFLICT DO NOTHING; -- governance_virtual_key_provider_config_keys (join table for provider configs and keys) @@ -722,6 +726,7 @@ append_dynamic_mcp_clients_insert() { generate_mcp_clients_insert_postgres "$now" "$faker_sql" generate_async_jobs_insert_postgres "$now" "$future" "$faker_sql" generate_prompt_repo_tables_insert_postgres "$now" "$faker_sql" + generate_per_user_oauth_tables_insert_postgres "$now" "$faker_sql" generate_model_parameters_insert_postgres "$now" "$faker_sql" generate_routing_targets_insert_postgres "$now" "$faker_sql" generate_pricing_overrides_insert_postgres "$now" "$faker_sql" @@ -733,6 +738,7 @@ append_dynamic_mcp_clients_insert() { generate_mcp_clients_insert_sqlite "$now" "$faker_sql" "$config_db" generate_async_jobs_insert_sqlite "$now" "$future" "$faker_sql" generate_prompt_repo_tables_insert_sqlite "$now" "$faker_sql" "$config_db" + generate_per_user_oauth_tables_insert_sqlite "$now" "$faker_sql" "$config_db" generate_model_parameters_insert_sqlite "$now" "$faker_sql" "$config_db" generate_routing_targets_insert_sqlite "$now" "$faker_sql" "$config_db" generate_pricing_overrides_insert_sqlite "$now" "$faker_sql" "$config_db" @@ -1422,6 +1428,205 @@ append_dynamic_columns_postgres() { echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 1;" >> "$output_file" echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 2;" >> "$output_file" fi + + # ------------------------------------------------------------------------- + # Columns dropped in v1.5.0 migrations - coverage for old schemas that 
still have them + # ------------------------------------------------------------------------- + + # governance_virtual_keys.budget_id (dropped in v1.5.0-prerelease2 via migrationAddMultiBudgetTables) + # Static INSERT no longer includes this column; set NULL so old-schema validators pass + if column_exists_postgres "governance_virtual_keys" "budget_id"; then + echo "UPDATE governance_virtual_keys SET budget_id = NULL WHERE id = 'vk-migration-test-1';" >> "$output_file" + echo "UPDATE governance_virtual_keys SET budget_id = NULL WHERE id = 'vk-migration-test-2';" >> "$output_file" + fi + + # governance_virtual_key_provider_configs.budget_id (dropped in v1.5.0-prerelease2) + if column_exists_postgres "governance_virtual_key_provider_configs" "budget_id"; then + echo "UPDATE governance_virtual_key_provider_configs SET budget_id = NULL WHERE virtual_key_id = 'vk-migration-test-1';" >> "$output_file" + echo "UPDATE governance_virtual_key_provider_configs SET budget_id = NULL WHERE virtual_key_id = 'vk-migration-test-2';" >> "$output_file" + fi + + # governance_teams.budget_id (dropped in v1.5.0-prerelease4 via migrationAddTeamBudgetsToBudgetsTable) + if column_exists_postgres "governance_teams" "budget_id"; then + echo "UPDATE governance_teams SET budget_id = NULL WHERE id = 'team-migration-test-1';" >> "$output_file" + fi + + # ------------------------------------------------------------------------- + # v1.5.0-prerelease2 columns - config store tables + # ------------------------------------------------------------------------- + + # config_keys.aliases_json (added in v1.5.0-prerelease2 via migrationDropDeploymentColumnsAndAddAliases) + if column_exists_postgres "config_keys" "aliases_json"; then + echo "UPDATE config_keys SET aliases_json = NULL WHERE name = 'migration-test-key-openai';" >> "$output_file" + echo "UPDATE config_keys SET aliases_json = NULL WHERE name = 'migration-test-key-anthropic';" >> "$output_file" + fi + + # 
config_keys.replicate_use_deployments_endpoint (added in v1.5.0-prerelease2) + if column_exists_postgres "config_keys" "replicate_use_deployments_endpoint"; then + echo "UPDATE config_keys SET replicate_use_deployments_endpoint = false WHERE name = 'migration-test-key-openai';" >> "$output_file" + echo "UPDATE config_keys SET replicate_use_deployments_endpoint = false WHERE name = 'migration-test-key-anthropic';" >> "$output_file" + fi + + # config_client.routing_chain_max_depth (added in v1.5.0-prerelease2) + if column_exists_postgres "config_client" "routing_chain_max_depth"; then + echo "UPDATE config_client SET routing_chain_max_depth = 10 WHERE id = 1;" >> "$output_file" + fi + + # config_client compat columns (added in v1.5.0-prerelease2 via migrationReplaceEnableLiteLLMWithCompatColumns) + # NOTE: also present in static INSERT, but that INSERT may fail on old schemas; UPDATE covers all cases + if column_exists_postgres "config_client" "compat_convert_text_to_chat"; then + echo "UPDATE config_client SET compat_convert_text_to_chat = false WHERE id = 1;" >> "$output_file" + fi + if column_exists_postgres "config_client" "compat_convert_chat_to_responses"; then + echo "UPDATE config_client SET compat_convert_chat_to_responses = false WHERE id = 1;" >> "$output_file" + fi + if column_exists_postgres "config_client" "compat_should_drop_params"; then + echo "UPDATE config_client SET compat_should_drop_params = false WHERE id = 1;" >> "$output_file" + fi + if column_exists_postgres "config_client" "compat_should_convert_params"; then + echo "UPDATE config_client SET compat_should_convert_params = false WHERE id = 1;" >> "$output_file" + fi + + # governance_virtual_keys.calendar_aligned (added in v1.5.0-prerelease2 via migrationAddMultiBudgetTables) + if column_exists_postgres "governance_virtual_keys" "calendar_aligned"; then + echo "UPDATE governance_virtual_keys SET calendar_aligned = false WHERE id = 'vk-migration-test-1';" >> "$output_file" + echo "UPDATE 
governance_virtual_keys SET calendar_aligned = false WHERE id = 'vk-migration-test-2';" >> "$output_file" + fi + + # governance_budgets.virtual_key_id (added in v1.5.0-prerelease2 via migrationAddMultiBudgetTables) + if column_exists_postgres "governance_budgets" "virtual_key_id"; then + echo "UPDATE governance_budgets SET virtual_key_id = NULL WHERE id = 'budget-migration-test-1';" >> "$output_file" + echo "UPDATE governance_budgets SET virtual_key_id = NULL WHERE id = 'budget-migration-test-2';" >> "$output_file" + fi + + # governance_budgets.provider_config_id (added in v1.5.0-prerelease2 via migrationAddMultiBudgetTables) + if column_exists_postgres "governance_budgets" "provider_config_id"; then + echo "UPDATE governance_budgets SET provider_config_id = NULL WHERE id = 'budget-migration-test-1';" >> "$output_file" + echo "UPDATE governance_budgets SET provider_config_id = NULL WHERE id = 'budget-migration-test-2';" >> "$output_file" + fi + + # routing_rules.chain_rule (added in v1.5.0-prerelease2) + if column_exists_postgres "routing_rules" "chain_rule"; then + echo "UPDATE routing_rules SET chain_rule = false WHERE id = 'rule-migration-test-1';" >> "$output_file" + echo "UPDATE routing_rules SET chain_rule = false WHERE id = 'rule-migration-test-2';" >> "$output_file" + fi + + # governance_virtual_key_provider_configs.allow_all_keys (added in v1.5.0-prerelease2) + # vk-migration-test-1 has a key in the join table (restricted) -> allow_all_keys=false + # vk-migration-test-2 has no key rows (old empty=allow-all semantics) -> allow_all_keys=true + if column_exists_postgres "governance_virtual_key_provider_configs" "allow_all_keys"; then + echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = false WHERE virtual_key_id = 'vk-migration-test-1';" >> "$output_file" + echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = true WHERE virtual_key_id = 'vk-migration-test-2';" >> "$output_file" + fi + + # 
------------------------------------------------------------------------- + # v1.5.0-prerelease2 columns - log store tables + # ------------------------------------------------------------------------- + + # logs.alias (added in v1.5.0-prerelease2 via migrationAddAliasColumn) + if column_exists_postgres "logs" "alias"; then + echo "UPDATE logs SET alias = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET alias = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET alias = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi + + # logs.has_object (added in v1.5.0-prerelease2 via migrationAddHasObjectColumn) + if column_exists_postgres "logs" "has_object"; then + echo "UPDATE logs SET has_object = false WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET has_object = false WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET has_object = false WHERE id = 'log-migration-test-003';" >> "$output_file" + fi + + # logs governance context columns (added in v1.5.0-prerelease2 via migrationAddGovernanceContextColumns) + for ctx_col in user_id team_id team_name customer_id customer_name business_unit_id business_unit_name; do + if column_exists_postgres "logs" "$ctx_col"; then + echo "UPDATE logs SET $ctx_col = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET $ctx_col = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET $ctx_col = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi + done + + # ------------------------------------------------------------------------- + # v1.5.0-prerelease4 columns - config store tables + # ------------------------------------------------------------------------- + + # governance_budgets.team_id (added in v1.5.0-prerelease4 via migrationAddTeamBudgetsToBudgetsTable) + if column_exists_postgres "governance_budgets" "team_id"; then + echo 
"UPDATE governance_budgets SET team_id = NULL WHERE id = 'budget-migration-test-1';" >> "$output_file" + echo "UPDATE governance_budgets SET team_id = NULL WHERE id = 'budget-migration-test-2';" >> "$output_file" + fi + + # governance_budgets.calendar_aligned (re-added in v1.5.0-prerelease4 via migrateCalendarAlignedToBudgetsAndRateLimitsTable) + # NOTE: was present in prerelease1, dropped in prerelease2, re-added in prerelease4 + if column_exists_postgres "governance_budgets" "calendar_aligned"; then + echo "UPDATE governance_budgets SET calendar_aligned = false WHERE id = 'budget-migration-test-1';" >> "$output_file" + echo "UPDATE governance_budgets SET calendar_aligned = false WHERE id = 'budget-migration-test-2';" >> "$output_file" + fi + + # governance_rate_limits.calendar_aligned (added in v1.5.0-prerelease4 via migrateCalendarAlignedToBudgetsAndRateLimitsTable) + if column_exists_postgres "governance_rate_limits" "calendar_aligned"; then + echo "UPDATE governance_rate_limits SET calendar_aligned = false WHERE id = 'ratelimit-migration-test-1';" >> "$output_file" + echo "UPDATE governance_rate_limits SET calendar_aligned = false WHERE id = 'ratelimit-migration-test-2';" >> "$output_file" + fi + + # governance_model_pricing OCR pricing columns (added in v1.5.0-prerelease4 via migrationAddOCRPricingColumns) + if column_exists_postgres "governance_model_pricing" "ocr_cost_per_page"; then + echo "UPDATE governance_model_pricing SET ocr_cost_per_page = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET ocr_cost_per_page = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_postgres "governance_model_pricing" "annotation_cost_per_page"; then + echo "UPDATE governance_model_pricing SET annotation_cost_per_page = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET annotation_cost_per_page = NULL WHERE id = 2;" >> "$output_file" + fi + + # 
------------------------------------------------------------------------- + # v1.5.0-prerelease4 columns - log store tables + # ------------------------------------------------------------------------- + + # logs.user_name (added in v1.5.0-prerelease4 via migrationAddUserNameColumn) + if column_exists_postgres "logs" "user_name"; then + echo "UPDATE logs SET user_name = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET user_name = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET user_name = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi + + # logs.attempt_trail (added in v1.5.0-prerelease4 via migrationAddAttemptTrailColumn) + if column_exists_postgres "logs" "attempt_trail"; then + echo "UPDATE logs SET attempt_trail = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET attempt_trail = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET attempt_trail = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi + + # logs.selected_prompt_* columns (added in v1.5.0-prerelease4 via migrationAddSelectedPromptColumns) + for sp_col in selected_prompt_name selected_prompt_version selected_prompt_id; do + if column_exists_postgres "logs" "$sp_col"; then + echo "UPDATE logs SET $sp_col = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET $sp_col = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET $sp_col = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi + done + + # logs.ocr_input (added in v1.5.0-prerelease4 via migrationAddOCRInputColumn) + if column_exists_postgres "logs" "ocr_input"; then + echo "UPDATE logs SET ocr_input = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET ocr_input = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET ocr_input = '' WHERE id = 
'log-migration-test-003';" >> "$output_file" + fi + + # ------------------------------------------------------------------------- + # v1.5.0-prerelease2 columns - prompt repo tables + # ------------------------------------------------------------------------- + + # prompt_versions.variables_json (added in v1.5.0-prerelease2 via migrationAddPromptVariablesColumns) + if column_exists_postgres "prompt_versions" "variables_json"; then + echo "UPDATE prompt_versions SET variables_json = '{}' WHERE prompt_id = 'prompt-migration-test-001';" >> "$output_file" + fi + + # prompt_sessions.variables_json (added in v1.5.0-prerelease2 via migrationAddPromptVariablesColumns) + if column_exists_postgres "prompt_sessions" "variables_json"; then + echo "UPDATE prompt_sessions SET variables_json = '{}' WHERE prompt_id = 'prompt-migration-test-001';" >> "$output_file" + echo "UPDATE prompt_sessions SET variables_json = '{}' WHERE prompt_id = 'prompt-migration-test-002';" >> "$output_file" + fi } # Append dynamic column UPDATEs for columns that may not exist in older schemas (SQLite) @@ -2093,6 +2298,186 @@ append_dynamic_columns_sqlite() { echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 2;" >> "$output_file" fi fi + + # ------------------------------------------------------------------------- + # Columns dropped in v1.5.0 migrations - coverage for old schemas that still have them + # ------------------------------------------------------------------------- + + if [ -f "$config_db" ]; then + # governance_virtual_keys.budget_id (dropped in v1.5.0-prerelease2 via migrationAddMultiBudgetTables) + if column_exists_sqlite "$config_db" "governance_virtual_keys" "budget_id"; then + echo "UPDATE governance_virtual_keys SET budget_id = NULL WHERE id = 'vk-migration-test-1';" >> "$output_file" + echo "UPDATE governance_virtual_keys SET budget_id = NULL WHERE id = 'vk-migration-test-2';" >> "$output_file" + fi + + # 
governance_virtual_key_provider_configs.budget_id (dropped in v1.5.0-prerelease2) + if column_exists_sqlite "$config_db" "governance_virtual_key_provider_configs" "budget_id"; then + echo "UPDATE governance_virtual_key_provider_configs SET budget_id = NULL WHERE virtual_key_id = 'vk-migration-test-1';" >> "$output_file" + echo "UPDATE governance_virtual_key_provider_configs SET budget_id = NULL WHERE virtual_key_id = 'vk-migration-test-2';" >> "$output_file" + fi + + # governance_teams.budget_id (dropped in v1.5.0-prerelease4 via migrationAddTeamBudgetsToBudgetsTable) + if column_exists_sqlite "$config_db" "governance_teams" "budget_id"; then + echo "UPDATE governance_teams SET budget_id = NULL WHERE id = 'team-migration-test-1';" >> "$output_file" + fi + fi + + # ------------------------------------------------------------------------- + # v1.5.0-prerelease2 columns - config store tables + # ------------------------------------------------------------------------- + + if [ -f "$config_db" ]; then + # config_keys.aliases_json (added in v1.5.0-prerelease2) + if column_exists_sqlite "$config_db" "config_keys" "aliases_json"; then + echo "UPDATE config_keys SET aliases_json = NULL WHERE name = 'migration-test-key-openai';" >> "$output_file" + echo "UPDATE config_keys SET aliases_json = NULL WHERE name = 'migration-test-key-anthropic';" >> "$output_file" + fi + + # config_keys.replicate_use_deployments_endpoint (added in v1.5.0-prerelease2) + if column_exists_sqlite "$config_db" "config_keys" "replicate_use_deployments_endpoint"; then + echo "UPDATE config_keys SET replicate_use_deployments_endpoint = 0 WHERE name = 'migration-test-key-openai';" >> "$output_file" + echo "UPDATE config_keys SET replicate_use_deployments_endpoint = 0 WHERE name = 'migration-test-key-anthropic';" >> "$output_file" + fi + + # config_client.routing_chain_max_depth (added in v1.5.0-prerelease2) + if column_exists_sqlite "$config_db" "config_client" "routing_chain_max_depth"; then + echo 
"UPDATE config_client SET routing_chain_max_depth = 10 WHERE id = 1;" >> "$output_file" + fi + + # config_client compat columns (added in v1.5.0-prerelease2) + if column_exists_sqlite "$config_db" "config_client" "compat_convert_text_to_chat"; then + echo "UPDATE config_client SET compat_convert_text_to_chat = 0 WHERE id = 1;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "config_client" "compat_convert_chat_to_responses"; then + echo "UPDATE config_client SET compat_convert_chat_to_responses = 0 WHERE id = 1;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "config_client" "compat_should_drop_params"; then + echo "UPDATE config_client SET compat_should_drop_params = 0 WHERE id = 1;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "config_client" "compat_should_convert_params"; then + echo "UPDATE config_client SET compat_should_convert_params = 0 WHERE id = 1;" >> "$output_file" + fi + + # governance_virtual_keys.calendar_aligned (added in v1.5.0-prerelease2) + if column_exists_sqlite "$config_db" "governance_virtual_keys" "calendar_aligned"; then + echo "UPDATE governance_virtual_keys SET calendar_aligned = 0 WHERE id = 'vk-migration-test-1';" >> "$output_file" + echo "UPDATE governance_virtual_keys SET calendar_aligned = 0 WHERE id = 'vk-migration-test-2';" >> "$output_file" + fi + + # governance_budgets.virtual_key_id, provider_config_id (added in v1.5.0-prerelease2) + if column_exists_sqlite "$config_db" "governance_budgets" "virtual_key_id"; then + echo "UPDATE governance_budgets SET virtual_key_id = NULL WHERE id = 'budget-migration-test-1';" >> "$output_file" + echo "UPDATE governance_budgets SET virtual_key_id = NULL WHERE id = 'budget-migration-test-2';" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_budgets" "provider_config_id"; then + echo "UPDATE governance_budgets SET provider_config_id = NULL WHERE id = 'budget-migration-test-1';" >> "$output_file" + echo "UPDATE 
governance_budgets SET provider_config_id = NULL WHERE id = 'budget-migration-test-2';" >> "$output_file" + fi + + # routing_rules.chain_rule (added in v1.5.0-prerelease2) + if column_exists_sqlite "$config_db" "routing_rules" "chain_rule"; then + echo "UPDATE routing_rules SET chain_rule = 0 WHERE id = 'rule-migration-test-1';" >> "$output_file" + echo "UPDATE routing_rules SET chain_rule = 0 WHERE id = 'rule-migration-test-2';" >> "$output_file" + fi + + # governance_virtual_key_provider_configs.allow_all_keys (added in v1.5.0-prerelease2) + if column_exists_sqlite "$config_db" "governance_virtual_key_provider_configs" "allow_all_keys"; then + echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = 0 WHERE virtual_key_id = 'vk-migration-test-1';" >> "$output_file" + echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = 1 WHERE virtual_key_id = 'vk-migration-test-2';" >> "$output_file" + fi + + # ------------------------------------------------------------------------- + # v1.5.0-prerelease4 columns - config store tables + # ------------------------------------------------------------------------- + + # governance_budgets.team_id (added in v1.5.0-prerelease4) + if column_exists_sqlite "$config_db" "governance_budgets" "team_id"; then + echo "UPDATE governance_budgets SET team_id = NULL WHERE id = 'budget-migration-test-1';" >> "$output_file" + echo "UPDATE governance_budgets SET team_id = NULL WHERE id = 'budget-migration-test-2';" >> "$output_file" + fi + + # governance_budgets.calendar_aligned (re-added in v1.5.0-prerelease4) + if column_exists_sqlite "$config_db" "governance_budgets" "calendar_aligned"; then + echo "UPDATE governance_budgets SET calendar_aligned = 0 WHERE id = 'budget-migration-test-1';" >> "$output_file" + echo "UPDATE governance_budgets SET calendar_aligned = 0 WHERE id = 'budget-migration-test-2';" >> "$output_file" + fi + + # governance_rate_limits.calendar_aligned (added in v1.5.0-prerelease4) + if 
column_exists_sqlite "$config_db" "governance_rate_limits" "calendar_aligned"; then + echo "UPDATE governance_rate_limits SET calendar_aligned = 0 WHERE id = 'ratelimit-migration-test-1';" >> "$output_file" + echo "UPDATE governance_rate_limits SET calendar_aligned = 0 WHERE id = 'ratelimit-migration-test-2';" >> "$output_file" + fi + + # governance_model_pricing OCR pricing columns (added in v1.5.0-prerelease4 via migrationAddOCRPricingColumns) + if column_exists_sqlite "$config_db" "governance_model_pricing" "ocr_cost_per_page"; then + echo "UPDATE governance_model_pricing SET ocr_cost_per_page = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET ocr_cost_per_page = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "annotation_cost_per_page"; then + echo "UPDATE governance_model_pricing SET annotation_cost_per_page = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET annotation_cost_per_page = NULL WHERE id = 2;" >> "$output_file" + fi + fi + + # ------------------------------------------------------------------------- + # v1.5.0-prerelease2 columns - log store tables (emitted unconditionally; fail silently on config_db) + # ------------------------------------------------------------------------- + + # logs.alias (added in v1.5.0-prerelease2) + echo "UPDATE logs SET alias = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET alias = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET alias = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + + # logs.has_object (added in v1.5.0-prerelease2) + echo "UPDATE logs SET has_object = 0 WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET has_object = 0 WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET has_object = 0 WHERE id = 'log-migration-test-003';" >> "$output_file" + 
+ # logs governance context columns (added in v1.5.0-prerelease2) + for ctx_col in user_id team_id team_name customer_id customer_name business_unit_id business_unit_name; do + echo "UPDATE logs SET $ctx_col = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET $ctx_col = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET $ctx_col = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + done + + # ------------------------------------------------------------------------- + # v1.5.0-prerelease4 columns - log store tables (emitted unconditionally; fail silently on config_db) + # ------------------------------------------------------------------------- + + # logs.user_name (added in v1.5.0-prerelease4) + echo "UPDATE logs SET user_name = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET user_name = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET user_name = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + + # logs.attempt_trail (added in v1.5.0-prerelease4) + echo "UPDATE logs SET attempt_trail = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET attempt_trail = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET attempt_trail = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + + # logs.selected_prompt_* columns (added in v1.5.0-prerelease4) + for sp_col in selected_prompt_name selected_prompt_version selected_prompt_id; do + echo "UPDATE logs SET $sp_col = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET $sp_col = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET $sp_col = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + done + + # logs.ocr_input (added in v1.5.0-prerelease4) + echo "UPDATE logs SET ocr_input = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + 
echo "UPDATE logs SET ocr_input = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET ocr_input = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + + if [ -f "$config_db" ]; then + # prompt_versions.variables_json (added in v1.5.0-prerelease2 via migrationAddPromptVariablesColumns) + if column_exists_sqlite "$config_db" "prompt_versions" "variables_json"; then + echo "UPDATE prompt_versions SET variables_json = '{}' WHERE prompt_id = 'prompt-migration-test-001';" >> "$output_file" + fi + + # prompt_sessions.variables_json (added in v1.5.0-prerelease2 via migrationAddPromptVariablesColumns) + if column_exists_sqlite "$config_db" "prompt_sessions" "variables_json"; then + echo "UPDATE prompt_sessions SET variables_json = '{}' WHERE prompt_id = 'prompt-migration-test-001';" >> "$output_file" + echo "UPDATE prompt_sessions SET variables_json = '{}' WHERE prompt_id = 'prompt-migration-test-002';" >> "$output_file" + fi + fi } # ============================================================================ @@ -2192,6 +2577,18 @@ generate_mcp_clients_insert_postgres() { vals="$vals, false" fi + # config_mcp_clients.discovered_tools_json (added in v1.5.0-prerelease2) + if column_exists_postgres "config_mcp_clients" "discovered_tools_json"; then + cols="$cols, discovered_tools_json" + vals="$vals, '{}'" + fi + + # config_mcp_clients.tool_name_mapping_json (added in v1.5.0-prerelease2) + if column_exists_postgres "config_mcp_clients" "tool_name_mapping_json"; then + cols="$cols, tool_name_mapping_json" + vals="$vals, '{}'" + fi + # Append the dynamic INSERT to the output file echo "" >> "$output_file" echo "-- config_mcp_clients (MCP server configurations - dynamically generated based on schema)" >> "$output_file" @@ -2425,6 +2822,18 @@ generate_mcp_clients_insert_sqlite() { vals="$vals, 0" fi + # config_mcp_clients.discovered_tools_json (added in v1.5.0-prerelease2) + if column_exists_sqlite "$config_db" "config_mcp_clients" 
"discovered_tools_json"; then + cols="$cols, discovered_tools_json" + vals="$vals, '{}'" + fi + + # config_mcp_clients.tool_name_mapping_json (added in v1.5.0-prerelease2) + if column_exists_sqlite "$config_db" "config_mcp_clients" "tool_name_mapping_json"; then + cols="$cols, tool_name_mapping_json" + vals="$vals, '{}'" + fi + # Append the dynamic INSERT to the output file echo "" >> "$output_file" echo "-- config_mcp_clients (MCP server configurations - dynamically generated based on schema)" >> "$output_file" @@ -2580,6 +2989,106 @@ generate_prompt_repo_tables_insert_sqlite() { echo "INSERT INTO prompt_session_messages (prompt_id, session_id, order_index, message_json) SELECT 'prompt-migration-test-001', id, 0, '{\"role\":\"user\",\"content\":\"Test message in session\"}' FROM prompt_sessions WHERE prompt_id = 'prompt-migration-test-001' LIMIT 1 ON CONFLICT DO NOTHING;" >> "$output_file" } +# Generate per-user OAuth tables INSERTs for PostgreSQL +# These tables were added in v1.5.0-prerelease4 via migrationAddPerUserOAuthTables +generate_per_user_oauth_tables_insert_postgres() { + local now="$1" + local output_file="$2" + + # Check if the tables exist (added in v1.5.0-prerelease4) + if ! 
column_exists_postgres "oauth_per_user_clients" "id"; then + return + fi + + echo "" >> "$output_file" + echo "-- ============================================================================" >> "$output_file" + echo "-- Per-User OAuth Tables (added in v1.5.0-prerelease4, dynamically generated)" >> "$output_file" + echo "-- ============================================================================" >> "$output_file" + + # oauth_per_user_clients (no FK dependencies) + echo "" >> "$output_file" + echo "-- oauth_per_user_clients (registered OAuth clients for per-user flows)" >> "$output_file" + echo "INSERT INTO oauth_per_user_clients (id, client_id, client_name, redirect_uris, grant_types, created_at, updated_at) VALUES ('per-user-oauth-client-001', 'client-id-migration-test-001', 'Migration Test Client', '[\"http://localhost:3000/callback\"]', '[\"authorization_code\"]', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + + # oauth_per_user_sessions (client_id is a string field, no FK constraint enforced by DB) + echo "" >> "$output_file" + echo "-- oauth_per_user_sessions (Bifrost-issued sessions for authenticated MCP connections)" >> "$output_file" + echo "INSERT INTO oauth_per_user_sessions (id, access_token, access_token_hash, refresh_token, refresh_token_hash, client_id, virtual_key_id, user_id, expires_at, encryption_status, created_at, updated_at) VALUES ('per-user-oauth-session-001', 'migration-test-access-token-001', 'a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3', '', 'a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae4', 'client-id-migration-test-001', 'vk-migration-test-1', NULL, $now + INTERVAL '1 hour', 'plain_text', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + + # oauth_per_user_codes (references per_user_sessions.id as session_id, no enforced FK) + echo "" >> "$output_file" + echo "-- oauth_per_user_codes (short-lived authorization codes)" >> "$output_file" + echo "INSERT INTO oauth_per_user_codes 
(id, code, code_hash, client_id, redirect_uri, code_challenge, scopes, session_id, expires_at, used, created_at) VALUES ('per-user-oauth-code-001', 'migration-test-code-001', 'a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae5', 'client-id-migration-test-001', 'http://localhost:3000/callback', 'migration-test-challenge-001', '[\"openid\"]', 'per-user-oauth-session-001', $now + INTERVAL '5 minutes', false, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + + # oauth_per_user_pending_flows (no enforced FK) + echo "" >> "$output_file" + echo "-- oauth_per_user_pending_flows (pending OAuth flows awaiting consent)" >> "$output_file" + echo "INSERT INTO oauth_per_user_pending_flows (id, client_id, redirect_uri, code_challenge, state, virtual_key_id, user_id, browser_secret_hash, expires_at, created_at, updated_at) VALUES ('per-user-oauth-flow-001', 'client-id-migration-test-001', 'http://localhost:3000/callback', 'migration-test-challenge-002', 'migration-test-state-001', NULL, NULL, 'a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae6', $now + INTERVAL '15 minutes', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + + # oauth_user_sessions (per-user OAuth flow tracking - no enforced FK) + echo "" >> "$output_file" + echo "-- oauth_user_sessions (pending per-user OAuth flows)" >> "$output_file" + echo "INSERT INTO oauth_user_sessions (id, mcp_client_id, oauth_config_id, state, redirect_uri, code_verifier, session_token, session_token_hash, gateway_session_id, virtual_key_id, user_id, status, encryption_status, expires_at, created_at, updated_at) VALUES ('oauth-user-session-001', 'mcp-migration-test-001', 'oauth-config-migration-001', 'migration-test-state-002', 'http://localhost:3000/callback', 'migration-test-verifier-001', 'migration-test-session-token-001', 'a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae7', 'gateway-session-001', 'vk-migration-test-1', NULL, 'authorized', 'plain_text', $now + INTERVAL '15 minutes', 
$now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + + # oauth_user_tokens (stores per-user access/refresh tokens - no enforced FK) + echo "" >> "$output_file" + echo "-- oauth_user_tokens (per-user OAuth credentials)" >> "$output_file" + echo "INSERT INTO oauth_user_tokens (id, session_token, session_token_hash, virtual_key_id, user_id, mcp_client_id, oauth_config_id, access_token, refresh_token, token_type, expires_at, scopes, last_refreshed_at, encryption_status, created_at, updated_at) VALUES ('oauth-user-token-001', 'migration-test-session-token-001', 'a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae7', 'vk-migration-test-1', NULL, 'mcp-migration-test-001', 'oauth-config-migration-001', 'migration-test-user-access-token-001', '', 'Bearer', $now + INTERVAL '1 hour', '[\"openid\"]', NULL, 'plain_text', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" +} + +# Generate per-user OAuth tables INSERTs for SQLite +# These tables were added in v1.5.0-prerelease4 via migrationAddPerUserOAuthTables +generate_per_user_oauth_tables_insert_sqlite() { + local now="$1" + local output_file="$2" + local config_db="$3" + + if [ ! 
-f "$config_db" ]; then + return + fi + + local table_exists + table_exists=$(sqlite3 "$config_db" "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='oauth_per_user_clients';" 2>/dev/null || echo "0") + if [ "$table_exists" != "1" ]; then + return + fi + + echo "" >> "$output_file" + echo "-- ============================================================================" >> "$output_file" + echo "-- Per-User OAuth Tables (added in v1.5.0-prerelease4, dynamically generated)" >> "$output_file" + echo "-- ============================================================================" >> "$output_file" + + # oauth_per_user_clients + echo "" >> "$output_file" + echo "-- oauth_per_user_clients" >> "$output_file" + echo "INSERT INTO oauth_per_user_clients (id, client_id, client_name, redirect_uris, grant_types, created_at, updated_at) VALUES ('per-user-oauth-client-001', 'client-id-migration-test-001', 'Migration Test Client', '[\"http://localhost:3000/callback\"]', '[\"authorization_code\"]', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + + # oauth_per_user_sessions + echo "" >> "$output_file" + echo "-- oauth_per_user_sessions" >> "$output_file" + echo "INSERT INTO oauth_per_user_sessions (id, access_token, access_token_hash, refresh_token, refresh_token_hash, client_id, virtual_key_id, user_id, expires_at, encryption_status, created_at, updated_at) VALUES ('per-user-oauth-session-001', 'migration-test-access-token-001', 'a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae3', '', 'a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae4', 'client-id-migration-test-001', 'vk-migration-test-1', NULL, datetime('now', '+1 hour'), 'plain_text', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + + # oauth_per_user_codes + echo "" >> "$output_file" + echo "-- oauth_per_user_codes" >> "$output_file" + echo "INSERT INTO oauth_per_user_codes (id, code, code_hash, client_id, redirect_uri, code_challenge, scopes, session_id, expires_at, 
used, created_at) VALUES ('per-user-oauth-code-001', 'migration-test-code-001', 'a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae5', 'client-id-migration-test-001', 'http://localhost:3000/callback', 'migration-test-challenge-001', '[\"openid\"]', 'per-user-oauth-session-001', datetime('now', '+5 minutes'), 0, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + + # oauth_per_user_pending_flows + echo "" >> "$output_file" + echo "-- oauth_per_user_pending_flows" >> "$output_file" + echo "INSERT INTO oauth_per_user_pending_flows (id, client_id, redirect_uri, code_challenge, state, virtual_key_id, user_id, browser_secret_hash, expires_at, created_at, updated_at) VALUES ('per-user-oauth-flow-001', 'client-id-migration-test-001', 'http://localhost:3000/callback', 'migration-test-challenge-002', 'migration-test-state-001', NULL, NULL, 'a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae6', datetime('now', '+15 minutes'), $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + + # oauth_user_sessions + echo "" >> "$output_file" + echo "-- oauth_user_sessions" >> "$output_file" + echo "INSERT INTO oauth_user_sessions (id, mcp_client_id, oauth_config_id, state, redirect_uri, code_verifier, session_token, session_token_hash, gateway_session_id, virtual_key_id, user_id, status, encryption_status, expires_at, created_at, updated_at) VALUES ('oauth-user-session-001', 'mcp-migration-test-001', 'oauth-config-migration-001', 'migration-test-state-002', 'http://localhost:3000/callback', 'migration-test-verifier-001', 'migration-test-session-token-001', 'a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae7', 'gateway-session-001', 'vk-migration-test-1', NULL, 'authorized', 'plain_text', datetime('now', '+15 minutes'), $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + + # oauth_user_tokens + echo "" >> "$output_file" + echo "-- oauth_user_tokens" >> "$output_file" + echo "INSERT INTO oauth_user_tokens (id, session_token, 
session_token_hash, virtual_key_id, user_id, mcp_client_id, oauth_config_id, access_token, refresh_token, token_type, expires_at, scopes, last_refreshed_at, encryption_status, created_at, updated_at) VALUES ('oauth-user-token-001', 'migration-test-session-token-001', 'a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27ae7', 'vk-migration-test-1', NULL, 'mcp-migration-test-001', 'oauth-config-migration-001', 'migration-test-user-access-token-001', '', 'Bearer', datetime('now', '+1 hour'), '[\"openid\"]', NULL, 'plain_text', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" +} + # Generate governance_model_parameters INSERT for PostgreSQL # This table stores model parameters/capabilities data synced from external API generate_model_parameters_insert_postgres() { @@ -2992,10 +3501,13 @@ compare_postgres_snapshots() { if [ "$table" = "governance_virtual_keys" ] || [ "$table" = "governance_virtual_key_provider_configs" ]; then dropped_columns="$dropped_columns budget_id" fi - # calendar_aligned (dropped from governance_budgets in add_multi_budget_tables - moved to governance_virtual_keys.calendar_aligned) - if [ "$table" = "governance_budgets" ]; then - dropped_columns="$dropped_columns calendar_aligned" + # budget_id (dropped from governance_teams in migrationAddTeamBudgetsToBudgetsTable v1.5.0-prerelease4 - + # ownership moved to governance_budgets.team_id) + if [ "$table" = "governance_teams" ]; then + dropped_columns="$dropped_columns budget_id" fi + # calendar_aligned was dropped from governance_budgets in prerelease2 (add_multi_budget_tables) but + # re-added in prerelease4 (migrateCalendarAlignedToBudgetsAndRateLimitsTable) - no longer dropped # enable_litellm_fallbacks (dropped from config_client in latest cut - behavior moved elsewhere) if [ "$table" = "config_client" ]; then dropped_columns="$dropped_columns enable_litellm_fallbacks" @@ -3061,9 +3573,10 @@ compare_postgres_snapshots() { for col in "${before_col_array[@]}"; do # Skip columns that 
are expected to change # virtual_key_id, provider_config_id: only ignore on governance_budgets (new FK columns from multi-budget migration) + # team_id: only ignore on governance_budgets (backfilled from governance_teams.budget_id in prerelease4 migration) local table_ignore_columns="$ignore_columns" if [ "$table" = "governance_budgets" ]; then - table_ignore_columns="$table_ignore_columns virtual_key_id provider_config_id" + table_ignore_columns="$table_ignore_columns virtual_key_id provider_config_id team_id" fi if [[ " $table_ignore_columns " == *" $col "* ]]; then col_idx=$((col_idx + 1)) diff --git a/core/bifrost.go b/core/bifrost.go index cf2a095a67..c114ea0abc 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -6846,6 +6846,9 @@ func (bifrost *Bifrost) getAllSupportedKeys(ctx *schemas.BifrostContext, provide if ctx != nil { key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key) if ok { + if err := validateKey(baseProviderType, &key); err != nil { + return nil, fmt.Errorf("invalid direct key for provider %v: %w", baseProviderType, err) + } // If a direct key is specified, return it as a single-element slice return []schemas.Key{key}, nil } @@ -6893,6 +6896,9 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, p if ctx != nil { key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key) if ok { + if err := validateKey(baseProviderType, &key); err != nil { + return nil, fmt.Errorf("invalid direct key for provider %v: %w", baseProviderType, err) + } // If a direct key is specified, return it as a single-element slice return []schemas.Key{key}, nil } @@ -6981,6 +6987,9 @@ func (bifrost *Bifrost) selectKeyFromProviderForModelWithPool(ctx *schemas.Bifro // DirectKey: caller supplied a key directly — no pool, no rotation. 
if ctx != nil { if key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key); ok { + if err := validateKey(baseProviderType, &key); err != nil { + return nil, false, fmt.Errorf("invalid direct key for provider %v: %w", baseProviderType, err) + } return []schemas.Key{key}, false, nil } } diff --git a/core/changelog.md b/core/changelog.md index e69de29bb2..5e00f8e847 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -0,0 +1,2 @@ +- fix: usage of per-user OAuth servers in codemode +- fix: adds validation on direct api keys diff --git a/core/mcp/codemode.go b/core/mcp/codemode.go index fa11e52d0b..a641289be0 100644 --- a/core/mcp/codemode.go +++ b/core/mcp/codemode.go @@ -72,6 +72,9 @@ type CodeModeDependencies struct { // LogMutex protects concurrent access to logs during code execution LogMutex *sync.Mutex + + // OAuth2Provider handles per-user OAuth token lookup and flow initiation + OAuth2Provider schemas.OAuth2Provider } // DefaultCodeModeConfig returns the default configuration for CodeMode. diff --git a/core/mcp/codemode/starlark/executecode.go b/core/mcp/codemode/starlark/executecode.go index d2d9435764..b5e5834e7f 100644 --- a/core/mcp/codemode/starlark/executecode.go +++ b/core/mcp/codemode/starlark/executecode.go @@ -544,7 +544,30 @@ func (s *StarlarkCodeMode) callMCPTool(ctx *schemas.BifrostContext, clientName, toolCtx, cancel := context.WithTimeout(nestedCtx, toolExecutionTimeout) defer cancel() - toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) + var toolResponse *mcp.CallToolResult + var callErr error + + if client.ExecutionConfig.AuthType == schemas.MCPAuthTypePerUserOauth { + accessToken, err := utils.ResolvePerUserOAuthToken(nestedCtx, client, s.oauth2Provider) + if err != nil { + return nil, err + } + + if client.Conn == nil { + // Per-user OAuth with no persistent connection — use a temporary connection. + // Assign to outer toolResponse/callErr so the shared logging + post-hooks path runs. 
+ toolResponse, callErr = codemcp.ExecuteToolWithUserToken(toolCtx, client.ExecutionConfig, toolNameToCall, args, accessToken, s.logger) + if callErr != nil && toolCtx.Err() == context.DeadlineExceeded { + callErr = fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName) + } + } else { + callRequest.Header = utils.BuildPerUserOAuthHeaders(callRequest.Header, accessToken) + toolResponse, callErr = client.Conn.CallTool(toolCtx, callRequest) + } + } else { + toolResponse, callErr = client.Conn.CallTool(toolCtx, callRequest) + } + latency := time.Since(startTime).Milliseconds() var mcpResp *schemas.BifrostMCPResponse diff --git a/core/mcp/codemode/starlark/starlark.go b/core/mcp/codemode/starlark/starlark.go index 348655b983..8b74c3fb07 100644 --- a/core/mcp/codemode/starlark/starlark.go +++ b/core/mcp/codemode/starlark/starlark.go @@ -27,6 +27,7 @@ type StarlarkCodeMode struct { pluginPipelineProvider func() mcp.PluginPipeline releasePluginPipeline func(pipeline mcp.PluginPipeline) fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string + oauth2Provider schemas.OAuth2Provider // Logger for this instance logger schemas.Logger @@ -86,6 +87,7 @@ func (s *StarlarkCodeMode) SetDependencies(deps *mcp.CodeModeDependencies) { s.pluginPipelineProvider = deps.PluginPipelineProvider s.releasePluginPipeline = deps.ReleasePluginPipeline s.fetchNewRequestIDFunc = deps.FetchNewRequestIDFunc + s.oauth2Provider = deps.OAuth2Provider } } diff --git a/core/mcp/toolmanager.go b/core/mcp/toolmanager.go index 8972bc1787..ecfad70a1b 100644 --- a/core/mcp/toolmanager.go +++ b/core/mcp/toolmanager.go @@ -5,7 +5,6 @@ package mcp import ( "context" "encoding/json" - "errors" "fmt" "net/http" "strings" @@ -184,6 +183,7 @@ func (m *ToolsManager) GetCodeModeDependencies() *CodeModeDependencies { PluginPipelineProvider: m.pluginPipelineProvider, ReleasePluginPipeline: m.releasePluginPipeline, FetchNewRequestIDFunc: m.fetchNewRequestIDFunc, + OAuth2Provider: 
m.oauth2Provider, } } @@ -686,55 +686,9 @@ func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall // Handle per-user OAuth: inject user-specific Authorization header if client.ExecutionConfig.AuthType == schemas.MCPAuthTypePerUserOauth { - if m.oauth2Provider == nil { - return nil, "", "", fmt.Errorf("per-user OAuth requires an OAuth2Provider but none is configured") - } - virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string) - userID, _ := ctx.Value(schemas.BifrostContextKeyUserID).(string) - sessionToken, _ := ctx.Value(schemas.BifrostContextKeyMCPUserSession).(string) - - // Optional X-Bf-User-Id header overrides user identity; if absent, falls back to virtual key - if mcpUserID, _ := ctx.Value(schemas.BifrostContextKeyMCPUserID).(string); mcpUserID != "" { - userID = mcpUserID - } - - // Try identity-based token lookup first (works even without session token) - accessToken, err := m.oauth2Provider.GetUserAccessTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, client.ExecutionConfig.ID) - if err != nil && !errors.Is(err, schemas.ErrOAuth2TokenNotFound) { - // Had session but token lookup failed with a real error (not just "not found") — return error - return nil, "", "", fmt.Errorf("failed to get user access token for MCP server %s: %w", client.ExecutionConfig.Name, err) - } + accessToken, err := utils.ResolvePerUserOAuthToken(ctx, client, m.oauth2Provider) if err != nil { - // No token found — user hasn't authenticated with this MCP server yet. - // In LLM gateway mode with no identity, we can't track who this user is, - // so an OAuth flow would produce an orphaned token. Return a clear error instead. 
- isMCPGateway, _ := ctx.Value(schemas.BifrostContextKeyIsMCPGateway).(bool) - if !isMCPGateway && userID == "" && virtualKeyID == "" { - return nil, "", "", fmt.Errorf( - "per-user OAuth for %s requires a user identity: include X-Bf-User-Id or a Virtual Key in your request so the token can be linked to you", - client.ExecutionConfig.Name, - ) - } - - // Initiate OAuth flow to get a proper authorize URL with session tracking. - if client.ExecutionConfig.OauthConfigID == nil || *client.ExecutionConfig.OauthConfigID == "" { - return nil, "", "", fmt.Errorf("per-user OAuth requires an OAuth config but MCP client %s has none", client.ExecutionConfig.Name) - } - redirectURI := buildRedirectURIFromContext(ctx) - if redirectURI == "" { - return nil, "", "", fmt.Errorf("per-user OAuth requires a redirect URI but none is available in context") - } - flowInitiation, sessionID, flowErr := m.oauth2Provider.InitiateUserOAuthFlow(ctx, *client.ExecutionConfig.OauthConfigID, client.ExecutionConfig.ID, redirectURI) - if flowErr != nil { - return nil, "", "", fmt.Errorf("failed to initiate per-user OAuth flow for %s: %w", client.ExecutionConfig.Name, flowErr) - } - return nil, "", "", &schemas.MCPUserOAuthRequiredError{ - MCPClientID: client.ExecutionConfig.ID, - MCPClientName: client.ExecutionConfig.Name, - AuthorizeURL: flowInitiation.AuthorizeURL, - SessionID: sessionID, - Message: fmt.Sprintf("Authentication required for %s. 
Please visit the authorize URL to connect your account.", client.ExecutionConfig.Name), - } + return nil, "", "", err } if client.Conn == nil { @@ -743,7 +697,7 @@ func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) defer cancel() - toolResponse, callErr := executeToolWithUserToken(toolCtx, client.ExecutionConfig, originalMCPToolName, arguments, accessToken, m.logger) + toolResponse, callErr := ExecuteToolWithUserToken(toolCtx, client.ExecutionConfig, originalMCPToolName, arguments, accessToken, m.logger) if callErr != nil { if toolCtx.Err() == context.DeadlineExceeded { return nil, "", "", fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName) @@ -755,15 +709,7 @@ func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall return createToolResponseMessage(*toolCall, responseText), client.ExecutionConfig.Name, sanitizedToolName, nil } - // Persistent connection exists — use per-call headers - headers := make(http.Header) - if client.ExecutionConfig.Headers != nil { - for key, value := range client.ExecutionConfig.Headers { - headers.Add(key, value.GetValue()) - } - } - headers.Set("Authorization", "Bearer "+accessToken) - callRequest.Header = headers + callRequest.Header = utils.BuildPerUserOAuthHeaders(callRequest.Header, accessToken) } else if client.ExecutionConfig.Headers != nil { headers := make(http.Header) for key, value := range client.ExecutionConfig.Headers { @@ -911,7 +857,7 @@ func (m *ToolsManager) UpdateConfig(config *schemas.MCPToolManagerConfig) { // Returns: // - *mcp.CallToolResult: tool execution result // - error: any error during connection or execution -func executeToolWithUserToken(ctx context.Context, config *schemas.MCPClientConfig, toolName string, arguments map[string]interface{}, accessToken string, logger schemas.Logger) (*mcp.CallToolResult, error) { +func ExecuteToolWithUserToken(ctx 
context.Context, config *schemas.MCPClientConfig, toolName string, arguments map[string]interface{}, accessToken string, logger schemas.Logger) (*mcp.CallToolResult, error) { if config.ConnectionString == nil || config.ConnectionString.GetValue() == "" { return nil, fmt.Errorf("connection URL is required for per-user OAuth tool execution") } @@ -964,15 +910,6 @@ func executeToolWithUserToken(ctx context.Context, config *schemas.MCPClientConf return tempClient.CallTool(ctx, callRequest) } -// buildRedirectURIFromContext extracts the OAuth redirect URI from context. -// The URI is set by the HTTP middleware from the request's host. -func buildRedirectURIFromContext(ctx *schemas.BifrostContext) string { - if uri, ok := ctx.Value(schemas.BifrostContextKeyOAuthRedirectURI).(string); ok && uri != "" { - return uri - } - // Fallback — should not happen if middleware is configured correctly - return "" -} // GetCodeModeBindingLevel returns the current code mode binding level. // This method is safe to call concurrently from multiple goroutines. diff --git a/core/mcp/utils/utils.go b/core/mcp/utils/utils.go index 500792a09f..eaf4d8f892 100644 --- a/core/mcp/utils/utils.go +++ b/core/mcp/utils/utils.go @@ -1,11 +1,82 @@ package utils import ( + "errors" + "fmt" "net/http" "github.com/maximhq/bifrost/core/schemas" ) +// ResolvePerUserOAuthToken looks up the per-user OAuth access token for the given client. +// If no token exists yet, it initiates an OAuth flow and returns an MCPUserOAuthRequiredError. 
+func ResolvePerUserOAuthToken(ctx *schemas.BifrostContext, client *schemas.MCPClientState, oauth2Provider schemas.OAuth2Provider) (string, error) { + if oauth2Provider == nil { + return "", fmt.Errorf("per-user OAuth requires an OAuth2Provider but none is configured") + } + + virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string) + userID, _ := ctx.Value(schemas.BifrostContextKeyUserID).(string) + sessionToken, _ := ctx.Value(schemas.BifrostContextKeyMCPUserSession).(string) + + // Optional X-Bf-User-Id header overrides user identity; if absent, falls back to virtual key + if mcpUserID, _ := ctx.Value(schemas.BifrostContextKeyMCPUserID).(string); mcpUserID != "" { + userID = mcpUserID + } + + accessToken, err := oauth2Provider.GetUserAccessTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, client.ExecutionConfig.ID) + if err != nil && !errors.Is(err, schemas.ErrOAuth2TokenNotFound) { + return "", fmt.Errorf("failed to get user access token for MCP server %s: %w", client.ExecutionConfig.Name, err) + } + if err != nil { + // In LLM gateway mode with no identity, an OAuth flow would produce an orphaned token. 
+ isMCPGateway, _ := ctx.Value(schemas.BifrostContextKeyIsMCPGateway).(bool) + if !isMCPGateway && userID == "" && virtualKeyID == "" { + return "", fmt.Errorf( + "per-user OAuth for %s requires a user identity: include X-Bf-User-Id or a Virtual Key in your request so the token can be linked to you", + client.ExecutionConfig.Name, + ) + } + + if client.ExecutionConfig.OauthConfigID == nil || *client.ExecutionConfig.OauthConfigID == "" { + return "", fmt.Errorf("per-user OAuth requires an OAuth config but MCP client %s has none", client.ExecutionConfig.Name) + } + redirectURI := BuildRedirectURIFromContext(ctx) + if redirectURI == "" { + return "", fmt.Errorf("per-user OAuth requires a redirect URI but none is available in context") + } + flowInitiation, sessionID, flowErr := oauth2Provider.InitiateUserOAuthFlow(ctx, *client.ExecutionConfig.OauthConfigID, client.ExecutionConfig.ID, redirectURI) + if flowErr != nil { + return "", fmt.Errorf("failed to initiate per-user OAuth flow for %s: %w", client.ExecutionConfig.Name, flowErr) + } + return "", &schemas.MCPUserOAuthRequiredError{ + MCPClientID: client.ExecutionConfig.ID, + MCPClientName: client.ExecutionConfig.Name, + AuthorizeURL: flowInitiation.AuthorizeURL, + SessionID: sessionID, + Message: fmt.Sprintf("Authentication required for %s. Please visit the authorize URL to connect your account.", client.ExecutionConfig.Name), + } + } + + return accessToken, nil +} + +// BuildPerUserOAuthHeaders clones the provided headers and adds the Bearer token, +// preserving any request-scoped extra headers already present. +// A nil headers value is treated as empty: http.Header.Clone returns nil for a nil +// receiver, and calling Set on a nil map would panic. +func BuildPerUserOAuthHeaders(headers http.Header, accessToken string) http.Header { + h := headers.Clone() + if h == nil { + h = make(http.Header, 1) + } + h.Set("Authorization", "Bearer "+accessToken) + return h +} + +// BuildRedirectURIFromContext extracts the OAuth redirect URI from context. 
+func BuildRedirectURIFromContext(ctx *schemas.BifrostContext) string { + if uri, ok := ctx.Value(schemas.BifrostContextKeyOAuthRedirectURI).(string); ok && uri != "" { + return uri + } + return "" +} + // GetHeadersForToolExecution sets additional headers for tool execution. // It returns the headers for the tool execution. func GetHeadersForToolExecution(ctx *schemas.BifrostContext, client *schemas.MCPClientState) http.Header { diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index 5d93f54b93..a47d526527 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -138,22 +138,6 @@ func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider { return providerUtils.GetProviderName(schemas.Bedrock, provider.customProviderConfig) } -// ensureBedrockKeyConfig ensures key.BedrockKeyConfig is non-nil. When the key -// uses API key authentication (key.Value is set) but has no Bedrock-specific -// config, a minimal default is created so the request URL can be constructed -// (region defaults to us-east-1). Returns false only when there is truly no -// way to authenticate (no API key AND no bedrock config). -func ensureBedrockKeyConfig(key *schemas.Key) bool { - if key.BedrockKeyConfig != nil { - return true - } - if key.Value.GetValue() != "" { - key.BedrockKeyConfig = &schemas.BedrockKeyConfig{} - return true - } - return false -} - // isStreamTransportError reports whether err is a transport-level connection // failure that occurred while reading the EventStream body — as opposed to a // semantic error (JSON parse failure, AWS exception event, etc.). 
diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index b932dc1aee..c20b343d74 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -250,8 +250,8 @@ const ( BifrostContextKeyRequestHeaders BifrostContextKey = "bifrost-request-headers" // map[string]string (all request headers with lowercased keys) BifrostContextKeySkipListModelsGovernanceFiltering BifrostContextKey = "bifrost-skip-list-models-governance-filtering" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeySCIMClaims BifrostContextKey = "scim_claims" - BifrostContextKeyUserID BifrostContextKey = "bifrost-user-id" // string (to store the user ID (set by enterprise auth middleware - DO NOT SET THIS MANUALLY)) - BifrostContextKeyUserName BifrostContextKey = "bifrost-user-name" // string (to store the user name (set by enterprise auth middleware - DO NOT SET THIS MANUALLY)) + BifrostContextKeyUserID BifrostContextKey = "bifrost-user-id" // string (to store the user ID (set by enterprise auth middleware - DO NOT SET THIS MANUALLY)) + BifrostContextKeyUserName BifrostContextKey = "bifrost-user-name" // string (to store the user name (set by enterprise auth middleware - DO NOT SET THIS MANUALLY)) BifrostContextKeyTargetUserID BifrostContextKey = "target_user_id" BifrostContextKeyIsAzureUserAgent BifrostContextKey = "bifrost-is-azure-user-agent" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - whether the request is an Azure user agent (only used in gateway) BifrostContextKeyVideoOutputRequested BifrostContextKey = "bifrost-video-output-requested" @@ -1056,6 +1056,11 @@ func (r *BifrostResponse) PopulateExtraFields(requestType RequestType, provider r.ContainerFileDeleteResponse.ExtraFields.Provider = provider r.ContainerFileDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested r.ContainerFileDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.OCRResponse != nil: + r.OCRResponse.ExtraFields.RequestType = requestType + 
r.OCRResponse.ExtraFields.Provider = provider + r.OCRResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.OCRResponse.ExtraFields.ResolvedModelUsed = resolvedModel case r.PassthroughResponse != nil: r.PassthroughResponse.ExtraFields.RequestType = requestType r.PassthroughResponse.ExtraFields.Provider = provider diff --git a/core/utils.go b/core/utils.go index 0e1b45fefd..8778602c7e 100644 --- a/core/utils.go +++ b/core/utils.go @@ -144,13 +144,8 @@ func validateKey(providerKey schemas.ModelProvider, key *schemas.Key) error { return fmt.Errorf("azure_key_config.endpoint is required") } case schemas.Bedrock: - // Key is valid if either: - // 1. BedrockKeyConfig is provided - // 2. Value is provided and is not empty + // BedrockKeyConfig is optional — an empty config is valid for IRSA / ambient credential auth. if key.BedrockKeyConfig == nil { - if key.Value.GetValue() == "" { - return fmt.Errorf("either value in key or bedrock_key_config is required") - } key.BedrockKeyConfig = &schemas.BedrockKeyConfig{} } case schemas.Vertex: diff --git a/docs/features/async-inference.mdx b/docs/features/async-inference.mdx index 23c9fcab1f..3d0428d5cf 100644 --- a/docs/features/async-inference.mdx +++ b/docs/features/async-inference.mdx @@ -50,6 +50,7 @@ Streaming is not supported on async endpoints. 
| Image generations | `/v1/async/images/generations` | `/v1/async/images/generations/{job_id}` | | Image edits | `/v1/async/images/edits` | `/v1/async/images/edits/{job_id}` | | Image variations | `/v1/async/images/variations` | `/v1/async/images/variations/{job_id}` | +| OCR | `/v1/async/ocr` | `/v1/async/ocr/{job_id}` | | Rerank | `/v1/async/rerank` | `/v1/async/rerank/{job_id}` | ## Submitting a Request diff --git a/framework/changelog.md b/framework/changelog.md index e69de29bb2..553b78d3f7 100644 --- a/framework/changelog.md +++ b/framework/changelog.md @@ -0,0 +1 @@ +- fix: adds support for OCR request pricing diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 303a3cc8c3..d749937224 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -420,6 +420,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddTeamBudgetsToBudgetsTable(ctx, db); err != nil { return err } + if err := migrationAddOCRPricingColumns(ctx, db); err != nil { + return err + } return nil } @@ -6713,51 +6716,77 @@ func migrateCalendarAlignedToBudgetsAndRateLimitsTable(ctx context.Context, db * } } // Prefill calendar_aligned for existing budgets and rate_limits attached to virtual keys. - // GORM v2: Preload must precede the Find finisher, otherwise it's a no-op on the executed query. - var virtualKeys []tables.TableVirtualKey - if err := tx.Preload("Budgets").Find(&virtualKeys).Error; err != nil { - return fmt.Errorf("failed to load virtual keys: %w", err) - } - for i := range virtualKeys { - // Preserve the legacy per-VK semantic: only copy calendar_aligned=true to - // the VK's budgets and rate_limit when the source VK itself was aligned. - // Hardcoding true would change reset behavior for tenants whose VKs were - // never calendar-aligned. - if !virtualKeys[i].CalendarAligned { - continue - } - // Ratelimit updates. 
A stale rate_limit_id is skipped — the FK is intentionally - // not DB-enforced for TableVirtualKey — but the VK's budgets are still migrated. - if virtualKeys[i].RateLimitID != nil { - var rateLimit tables.TableRateLimit - err := tx.First(&rateLimit, virtualKeys[i].RateLimitID).Error - switch { - case err == gorm.ErrRecordNotFound: - // Skip only the rate-limit update; fall through to the budget loop. - case err != nil: - return fmt.Errorf("failed to load rate limit for virtual key %s: %w", virtualKeys[i].ID, err) - default: - rateLimit.CalendarAligned = true - if err := tx.Save(&rateLimit).Error; err != nil { - return fmt.Errorf("failed to save rate limit for virtual key %s: %w", virtualKeys[i].ID, err) - } + // Use subquery-based raw SQL (compatible with both PostgreSQL and SQLite) to avoid + // "cached plan must not change result type" (SQLSTATE 0A000): earlier migrations in + // the same run added columns to these tables, invalidating pgx's prepared-statement cache. + if err := tx.Exec(` + UPDATE governance_rate_limits + SET calendar_aligned = true + WHERE id IN ( + SELECT rate_limit_id FROM governance_virtual_keys + WHERE calendar_aligned = true AND rate_limit_id IS NOT NULL + ) + `).Error; err != nil { + return fmt.Errorf("failed to propagate calendar_aligned to rate limits: %w", err) + } + if err := tx.Exec(` + UPDATE governance_budgets + SET calendar_aligned = true + WHERE virtual_key_id IN ( + SELECT id FROM governance_virtual_keys WHERE calendar_aligned = true + ) + `).Error; err != nil { + return fmt.Errorf("failed to propagate calendar_aligned to budgets: %w", err) + } + log.Printf("[Migration] Prefilled calendar_aligned field for existing budgets and rate limits") + return nil + }, + Rollback: func(tx *gorm.DB) error { return nil }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running migrate_calendar_aligned migration: %s", err.Error()) + } + return nil +} + +func migrationAddOCRPricingColumns(ctx context.Context, db 
*gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_ocr_pricing_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + columns := []string{ + "ocr_cost_per_page", + "annotation_cost_per_page", + } + for _, field := range columns { + if !mg.HasColumn(&tables.TableModelPricing{}, field) { + if err := mg.AddColumn(&tables.TableModelPricing{}, field); err != nil { + return fmt.Errorf("failed to add column %s: %w", field, err) } } - // Budgets update - for j := range virtualKeys[i].Budgets { - virtualKeys[i].Budgets[j].CalendarAligned = true - if err := tx.Save(&virtualKeys[i].Budgets[j]).Error; err != nil { - return fmt.Errorf("failed to save budget for virtual key %s: %w", virtualKeys[i].ID, err) + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + columns := []string{ + "ocr_cost_per_page", + "annotation_cost_per_page", + } + for _, field := range columns { + if mg.HasColumn(&tables.TableModelPricing{}, field) { + if err := mg.DropColumn(&tables.TableModelPricing{}, field); err != nil { + return fmt.Errorf("failed to drop column %s: %w", field, err) } } } - log.Printf("[Migration] Prefilled calendar_aligned field for existing budgets and rate limits") return nil }, - Rollback: func(tx *gorm.DB) error { return nil }, }}) if err := m.Migrate(); err != nil { - return fmt.Errorf("error running migrate_calendar_aligned migration: %s", err.Error()) + return fmt.Errorf("error running add_ocr_pricing_columns migration: %s", err.Error()) } return nil } diff --git a/framework/configstore/tables/modelpricing.go b/framework/configstore/tables/modelpricing.go index 28ed16aa37..fddb9b3ebc 100644 --- a/framework/configstore/tables/modelpricing.go +++ b/framework/configstore/tables/modelpricing.go @@ -31,8 +31,8 @@ type TableModelPricing struct { InputCostPerAudioPerSecondAbove128kTokens *float64 
`gorm:"default:null;column:input_cost_per_audio_per_second_above_128k_tokens" json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"` OutputCostPerTokenAbove128kTokens *float64 `gorm:"default:null;column:output_cost_per_token_above_128k_tokens" json:"output_cost_per_token_above_128k_tokens,omitempty"` // Costs - 200k Tier - InputCostPerTokenAbove200kTokens *float64 `gorm:"default:null;column:input_cost_per_token_above_200k_tokens" json:"input_cost_per_token_above_200k_tokens,omitempty"` - InputCostPerTokenAbove200kTokensPriority *float64 `gorm:"default:null;column:input_cost_per_token_above_200k_tokens_priority" json:"input_cost_per_token_above_200k_tokens_priority,omitempty"` + InputCostPerTokenAbove200kTokens *float64 `gorm:"default:null;column:input_cost_per_token_above_200k_tokens" json:"input_cost_per_token_above_200k_tokens,omitempty"` + InputCostPerTokenAbove200kTokensPriority *float64 `gorm:"default:null;column:input_cost_per_token_above_200k_tokens_priority" json:"input_cost_per_token_above_200k_tokens_priority,omitempty"` OutputCostPerTokenAbove200kTokens *float64 `gorm:"default:null;column:output_cost_per_token_above_200k_tokens" json:"output_cost_per_token_above_200k_tokens,omitempty"` OutputCostPerTokenAbove200kTokensPriority *float64 `gorm:"default:null;column:output_cost_per_token_above_200k_tokens_priority" json:"output_cost_per_token_above_200k_tokens_priority,omitempty"` // Costs - 272k Tier @@ -87,6 +87,10 @@ type TableModelPricing struct { // Costs - Other SearchContextCostPerQuery *float64 `gorm:"default:null;column:search_context_cost_per_query" json:"search_context_cost_per_query,omitempty"` CodeInterpreterCostPerSession *float64 `gorm:"default:null;column:code_interpreter_cost_per_session" json:"code_interpreter_cost_per_session,omitempty"` + + // Costs - OCR + OCRCostPerPage *float64 `gorm:"default:null;column:ocr_cost_per_page" json:"ocr_cost_per_page,omitempty"` + AnnotationCostPerPage *float64 
`gorm:"default:null;column:annotation_cost_per_page" json:"annotation_cost_per_page,omitempty"` } // TableName sets the table name for each model diff --git a/framework/logstore/asyncjob.go b/framework/logstore/asyncjob.go index 8923420d2d..f10ddfd894 100644 --- a/framework/logstore/asyncjob.go +++ b/framework/logstore/asyncjob.go @@ -60,7 +60,7 @@ func (e *AsyncJobExecutor) RetrieveJob(ctx context.Context, jobID string, vkValu if errors.Is(err, ErrNotFound) { return nil, fmt.Errorf("job not found or expired") } - return nil, fmt.Errorf("failed to retrieve async job: %w", err) + return nil, fmt.Errorf("%w: %w", ErrJobInternal, err) } if job.VirtualKeyID != nil { if vkValue == nil { diff --git a/framework/logstore/errors.go b/framework/logstore/errors.go index 650d767d33..e55db2f489 100644 --- a/framework/logstore/errors.go +++ b/framework/logstore/errors.go @@ -3,5 +3,6 @@ package logstore import "fmt" var ( - ErrNotFound = fmt.Errorf("log not found") + ErrNotFound = fmt.Errorf("log not found") + ErrJobInternal = fmt.Errorf("internal job store error") ) diff --git a/framework/logstore/migrations.go b/framework/logstore/migrations.go index 9993e2f9df..957d8a1a14 100644 --- a/framework/logstore/migrations.go +++ b/framework/logstore/migrations.go @@ -239,6 +239,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddUserNameColumn(ctx, db); err != nil { return err } + if err := migrationAddOCRInputColumn(ctx, db); err != nil { + return err + } return nil } @@ -2579,3 +2582,36 @@ func migrationAddSelectedPromptColumns(ctx context.Context, db *gorm.DB) error { } return nil } + +// migrationAddOCRInputColumn adds the ocr_input column to the logs table. 
+func migrationAddOCRInputColumn(ctx context.Context, db *gorm.DB) error { + opts := *migrator.DefaultOptions + opts.UseTransaction = true + m := migrator.New(db, &opts, []*migrator.Migration{{ + ID: "logs_add_ocr_input_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + if !mig.HasColumn(&Log{}, "ocr_input") { + if err := mig.AddColumn(&Log{}, "ocr_input"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + if mig.HasColumn(&Log{}, "ocr_input") { + if err := mig.DropColumn(&Log{}, "ocr_input"); err != nil { + return err + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error while adding ocr_input column: %s", err.Error()) + } + return nil +} diff --git a/framework/logstore/payload.go b/framework/logstore/payload.go index 7d2287b6a2..2efefae494 100644 --- a/framework/logstore/payload.go +++ b/framework/logstore/payload.go @@ -20,6 +20,8 @@ var payloadFields = []string{ "responses_output", "embedding_output", "rerank_output", + "ocr_input", + "ocr_output", "params", "tools", "tool_calls", @@ -58,6 +60,8 @@ func ExtractPayload(l *Log) map[string]string { m["responses_output"] = l.ResponsesOutput m["embedding_output"] = l.EmbeddingOutput m["rerank_output"] = l.RerankOutput + m["ocr_input"] = l.OCRInput + m["ocr_output"] = l.OCROutput m["params"] = l.Params m["tools"] = l.Tools m["tool_calls"] = l.ToolCalls @@ -100,6 +104,8 @@ func ClearPayload(l *Log) { l.ResponsesOutput = "" l.EmbeddingOutput = "" l.RerankOutput = "" + l.OCRInput = "" + l.OCROutput = "" l.Params = "" l.Tools = "" l.ToolCalls = "" @@ -134,6 +140,8 @@ func ClearPayload(l *Log) { l.ResponsesOutputParsed = nil l.EmbeddingOutputParsed = nil l.RerankOutputParsed = nil + l.OCRInputParsed = nil + l.OCROutputParsed = nil l.ParamsParsed = nil l.ToolsParsed = nil l.ToolCallsParsed = nil @@ -183,6 +191,12 @@ func MergePayloadFromJSON(l 
*Log, data []byte) error { if v, ok := m["rerank_output"]; ok && v != "" { l.RerankOutput = v } + if v, ok := m["ocr_input"]; ok && v != "" { + l.OCRInput = v + } + if v, ok := m["ocr_output"]; ok && v != "" { + l.OCROutput = v + } if v, ok := m["params"]; ok && v != "" { l.Params = v } @@ -504,6 +518,12 @@ func clearPayloadField(l *Log, name string) { case "rerank_output": l.RerankOutput = "" l.RerankOutputParsed = nil + case "ocr_input": + l.OCRInput = "" + l.OCRInputParsed = nil + case "ocr_output": + l.OCROutput = "" + l.OCROutputParsed = nil case "params": l.Params = "" l.ParamsParsed = nil diff --git a/framework/logstore/tables.go b/framework/logstore/tables.go index 1d7446597e..b72398ed3f 100644 --- a/framework/logstore/tables.go +++ b/framework/logstore/tables.go @@ -144,6 +144,7 @@ type Log struct { ToolCalls string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ToolCall (For backward compatibility, tool calls are now in the content) SpeechInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.SpeechInput TranscriptionInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.TranscriptionInput + OCRInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.OCRDocument ImageGenerationInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ImageGenerationInput ImageEditInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ImageEditInput ImageVariationInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ImageVariationInput @@ -200,6 +201,7 @@ type Log struct { ErrorDetailsParsed *schemas.BifrostError `gorm:"-" json:"error_details,omitempty"` SpeechInputParsed *schemas.SpeechInput `gorm:"-" json:"speech_input,omitempty"` TranscriptionInputParsed *schemas.TranscriptionInput `gorm:"-" json:"transcription_input,omitempty"` + OCRInputParsed *schemas.OCRDocument `gorm:"-" json:"ocr_input,omitempty"` ImageGenerationInputParsed *schemas.ImageGenerationInput `gorm:"-" 
json:"image_generation_input,omitempty"` ImageEditInputParsed *schemas.ImageEditInput `gorm:"-" json:"image_edit_input,omitempty"` ImageVariationInputParsed *schemas.ImageVariationInput `gorm:"-" json:"image_variation_input,omitempty"` @@ -337,6 +339,14 @@ func (l *Log) SerializeFields() error { } } + if l.OCRInputParsed != nil { + if data, err := sonic.Marshal(l.OCRInputParsed); err != nil { + return err + } else { + l.OCRInput = string(data) + } + } + if l.ImageGenerationInputParsed != nil { if data, err := sonic.Marshal(l.ImageGenerationInputParsed); err != nil { return err @@ -676,6 +686,12 @@ func (l *Log) DeserializeFields() error { } } + if l.OCRInput != "" { + if err := sonic.Unmarshal([]byte(l.OCRInput), &l.OCRInputParsed); err != nil { + l.OCRInputParsed = nil + } + } + if l.ImageGenerationInput != "" { if err := sonic.Unmarshal([]byte(l.ImageGenerationInput), &l.ImageGenerationInputParsed); err != nil { // Log error but don't fail the operation - initialize as nil diff --git a/framework/modelcatalog/pricing.go b/framework/modelcatalog/pricing.go index e1a961e713..5ca32ecb57 100644 --- a/framework/modelcatalog/pricing.go +++ b/framework/modelcatalog/pricing.go @@ -151,6 +151,10 @@ type PricingOptions struct { // See UnmarshalJSON below for the custom decoding logic. SearchContextCostPerQuery *float64 `json:"search_context_cost_per_query,omitempty"` CodeInterpreterCostPerSession *float64 `json:"code_interpreter_cost_per_session,omitempty"` + + // Costs - OCR + OCRCostPerPage *float64 `json:"ocr_cost_per_page,omitempty"` + AnnotationCostPerPage *float64 `json:"annotation_cost_per_page,omitempty"` } // serviceTier captures the OpenAI service_tier value from a response. @@ -171,6 +175,8 @@ type costInput struct { imageSize string // e.g. 
"1024x1024", used for per-pixel pricing imageQuality string // "low", "medium", "high", "auto" (gpt-image-1.5); empty = use base rate videoSeconds *int + ocrProcessedPages *int + ocrIsAnnotated *bool tier serviceTier } @@ -191,6 +197,7 @@ func (mc *ModelCatalog) GetPricingEntryForModel(model string, provider schemas.M schemas.ImageEditRequest, schemas.ImageVariationRequest, schemas.VideoGenerationRequest, + schemas.OCRRequest, } { key := makeKey(model, string(provider), normalizeRequestType(mode)) pricing, ok := mc.pricingData[key] @@ -280,7 +287,7 @@ func (mc *ModelCatalog) calculateBaseCost(result *schemas.BifrostResponse, scope } // If no usage data at all, nothing to price - if input.usage == nil && input.audioSeconds == nil && input.audioTokenDetails == nil && input.imageUsage == nil && input.videoSeconds == nil && input.audioTextInputChars == 0 { + if input.usage == nil && input.audioSeconds == nil && input.audioTokenDetails == nil && input.imageUsage == nil && input.videoSeconds == nil && input.audioTextInputChars == 0 && input.ocrProcessedPages == nil { return 0 } @@ -309,6 +316,8 @@ func (mc *ModelCatalog) calculateBaseCost(result *schemas.BifrostResponse, scope return computeImageCost(pricing, input.imageUsage, input.imageSize, input.imageQuality, input.tier) case schemas.VideoGenerationRequest, schemas.VideoRemixRequest: return computeVideoCost(pricing, input.usage, input.videoSeconds, input.tier) + case schemas.OCRRequest: + return computeOCRCost(pricing, input.ocrProcessedPages, input.ocrIsAnnotated) default: return 0 } @@ -384,6 +393,15 @@ func extractCostInput(result *schemas.BifrostResponse) costInput { if err == nil { input.videoSeconds = &seconds } + + case result.OCRResponse != nil: + pages := len(result.OCRResponse.Pages) + if result.OCRResponse.UsageInfo != nil && result.OCRResponse.UsageInfo.PagesProcessed > 0 { + pages = result.OCRResponse.UsageInfo.PagesProcessed + } + input.ocrProcessedPages = &pages + isAnnotated := 
result.OCRResponse.DocumentAnnotation != nil && *result.OCRResponse.DocumentAnnotation != "" + input.ocrIsAnnotated = &isAnnotated } return input @@ -779,6 +797,23 @@ func computeVideoCost(pricing *configstoreTables.TableModelPricing, usage *schem return inputCost + outputCost } +// computeOCRCost handles OCR requests, billing per page processed. +// ocr_cost_per_page covers base processing; annotation_cost_per_page is added when set. +func computeOCRCost(pricing *configstoreTables.TableModelPricing, ocrProcessedPages *int, ocrIsAnnotated *bool) float64 { + if ocrProcessedPages == nil { + return 0 + } + pages := float64(*ocrProcessedPages) + cost := 0.0 + if pricing.OCRCostPerPage != nil { + cost += pages * *pricing.OCRCostPerPage + } + if ocrIsAnnotated != nil && *ocrIsAnnotated && pricing.AnnotationCostPerPage != nil { + cost += pages * *pricing.AnnotationCostPerPage + } + return cost +} + // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- diff --git a/framework/modelcatalog/pricing_overrides.go b/framework/modelcatalog/pricing_overrides.go index 4908a859ff..baecf51347 100644 --- a/framework/modelcatalog/pricing_overrides.go +++ b/framework/modelcatalog/pricing_overrides.go @@ -448,6 +448,8 @@ func patchPricing(pricing configstoreTables.TableModelPricing, override PricingO {dst: &patched.OutputCostPerImageMediumQuality, src: override.OutputCostPerImageMediumQuality}, {dst: &patched.OutputCostPerImageHighQuality, src: override.OutputCostPerImageHighQuality}, {dst: &patched.OutputCostPerImageAutoQuality, src: override.OutputCostPerImageAutoQuality}, + {dst: &patched.OCRCostPerPage, src: override.OCRCostPerPage}, + {dst: &patched.AnnotationCostPerPage, src: override.AnnotationCostPerPage}, } { if field.src != nil { *field.dst = field.src diff --git a/framework/modelcatalog/utils.go b/framework/modelcatalog/utils.go index cf18c5b919..aba65a9678 100644 
--- a/framework/modelcatalog/utils.go +++ b/framework/modelcatalog/utils.go @@ -95,6 +95,8 @@ func normalizeRequestType(reqType schemas.RequestType) string { baseType = "image_edit" case schemas.VideoGenerationRequest, schemas.VideoRemixRequest: baseType = "video_generation" + case schemas.OCRRequest: + baseType = "ocr" } return baseType @@ -225,6 +227,10 @@ func convertPricingDataToTableModelPricing(modelKey string, entry PricingEntry) // Costs - Other SearchContextCostPerQuery: entry.SearchContextCostPerQuery, CodeInterpreterCostPerSession: entry.CodeInterpreterCostPerSession, + + // Costs - OCR + OCRCostPerPage: entry.OCRCostPerPage, + AnnotationCostPerPage: entry.AnnotationCostPerPage, } } @@ -305,6 +311,10 @@ func convertTableModelPricingToPricingData(pricing *configstoreTables.TableModel // Costs - Other SearchContextCostPerQuery: pricing.SearchContextCostPerQuery, CodeInterpreterCostPerSession: pricing.CodeInterpreterCostPerSession, + + // Costs - OCR + OCRCostPerPage: pricing.OCRCostPerPage, + AnnotationCostPerPage: pricing.AnnotationCostPerPage, } return &PricingEntry{ BaseModel: pricing.BaseModel, diff --git a/plugins/logging/changelog.md b/plugins/logging/changelog.md index e69de29bb2..be3f1fae3d 100644 --- a/plugins/logging/changelog.md +++ b/plugins/logging/changelog.md @@ -0,0 +1 @@ +- fix: adds support for logging OCR requests diff --git a/plugins/logging/main.go b/plugins/logging/main.go index 19d385374a..0b7d0162a8 100644 --- a/plugins/logging/main.go +++ b/plugins/logging/main.go @@ -214,6 +214,7 @@ type InitialLogData struct { Params any SpeechInput *schemas.SpeechInput TranscriptionInput *schemas.TranscriptionInput + OCRInput *schemas.OCRDocument ImageGenerationInput *schemas.ImageGenerationInput ImageEditInput *schemas.ImageEditInput ImageVariationInput *schemas.ImageVariationInput @@ -489,6 +490,7 @@ func (p *LoggerPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr initialData.Params = req.RerankRequest.Params case 
schemas.OCRRequest: initialData.Params = req.OCRRequest.Params + initialData.OCRInput = &req.OCRRequest.Document case schemas.SpeechRequest, schemas.SpeechStreamRequest: initialData.Params = req.SpeechRequest.Params initialData.SpeechInput = req.SpeechRequest.Input diff --git a/plugins/logging/operations.go b/plugins/logging/operations.go index a89e5afe09..9333abc112 100644 --- a/plugins/logging/operations.go +++ b/plugins/logging/operations.go @@ -43,6 +43,7 @@ func (p *LoggerPlugin) insertInitialLogEntry( ToolsParsed: data.Tools, SpeechInputParsed: data.SpeechInput, TranscriptionInputParsed: data.TranscriptionInput, + OCRInputParsed: data.OCRInput, ImageGenerationInputParsed: data.ImageGenerationInput, ImageEditInputParsed: data.ImageEditInput, ImageVariationInputParsed: data.ImageVariationInput, diff --git a/plugins/logging/utils.go b/plugins/logging/utils.go index bc074f14b5..4d1abbbde5 100644 --- a/plugins/logging/utils.go +++ b/plugins/logging/utils.go @@ -545,29 +545,6 @@ func (p *LoggerPlugin) extractInputHistory(request *schemas.BifrostRequest) ([]s }, }, []schemas.ResponsesMessage{} } - if request.OCRRequest != nil { - var docRef string - if request.OCRRequest.Document.DocumentURL != nil { - docRef = *request.OCRRequest.Document.DocumentURL - } else if request.OCRRequest.Document.ImageURL != nil { - docRef = *request.OCRRequest.Document.ImageURL - } - // Strip query parameters to avoid logging sensitive tokens (e.g., pre-signed URLs) - if idx := strings.Index(docRef, "?"); idx != -1 { - docRef = docRef[:idx] - } - if docRef == "" { - return []schemas.ChatMessage{}, []schemas.ResponsesMessage{} - } - return []schemas.ChatMessage{ - { - Role: schemas.ChatMessageRoleUser, - Content: &schemas.ChatMessageContent{ - ContentStr: &docRef, - }, - }, - }, []schemas.ResponsesMessage{} - } if request.CountTokensRequest != nil && len(request.CountTokensRequest.Input) > 0 { return []schemas.ChatMessage{}, request.CountTokensRequest.Input } diff --git 
a/plugins/logging/writer.go b/plugins/logging/writer.go index 93a20d2827..fe34bc4a8f 100644 --- a/plugins/logging/writer.go +++ b/plugins/logging/writer.go @@ -316,9 +316,11 @@ func buildCompleteLogEntryFromPending(pending *PendingLogData) *logstore.Log { ToolsParsed: pending.InitialData.Tools, SpeechInputParsed: pending.InitialData.SpeechInput, TranscriptionInputParsed: pending.InitialData.TranscriptionInput, + OCRInputParsed: pending.InitialData.OCRInput, ImageGenerationInputParsed: pending.InitialData.ImageGenerationInput, ImageEditInputParsed: pending.InitialData.ImageEditInput, ImageVariationInputParsed: pending.InitialData.ImageVariationInput, + VideoGenerationInputParsed: pending.InitialData.VideoGenerationInput, PassthroughRequestBody: pending.InitialData.PassthroughRequestBody, } if pending.ParentRequestID != "" { diff --git a/tests/async/README.md b/tests/async/README.md new file mode 100644 index 0000000000..2cdf377019 --- /dev/null +++ b/tests/async/README.md @@ -0,0 +1,63 @@ +# Async Inference E2E Tests + +End-to-end tests for Bifrost's async inference feature (`/v1/async/*` endpoints and integration route headers). + +## Running + +```bash +go test ./... -timeout 300s +``` + +With virtual keys (enables VK-scoped auth tests): + +```bash +BIFROST_VK=sk-bf-... BIFROST_ALT_VK=sk-bf-... go test ./... 
-timeout 300s +``` + +## Environment Variables + +| Variable | Default | Description | +|---|---|---| +| `BIFROST_BASE_URL` | `http://localhost:8080` | Bifrost gateway URL | +| `BIFROST_VK` | — | Primary virtual key; enables VK-mode tests | +| `BIFROST_ALT_VK` | — | Second virtual key; enables cross-VK auth tests | +| `BIFROST_POLL_TIMEOUT` | `30s` | Max time to wait for a job to reach terminal state | +| `BIFROST_POLL_INTERVAL` | `500ms` | Polling cadence | +| `BIFROST_INTEGRATION_PATH` | `/openai/v1/responses` | Override integration route path | +| `BIFROST_INTEGRATION_MODEL` | `openai/gpt-4o-mini` | Override model for integration route tests | +| `ASYNC_*_MODEL` | see below | Override model per endpoint (e.g. `ASYNC_CHAT_COMPLETION_MODEL`) | + +### Model overrides + +| Variable | Default | +|---|---| +| `ASYNC_TEXT_COMPLETION_MODEL` | `openai/gpt-3.5-turbo-instruct` | +| `ASYNC_CHAT_COMPLETION_MODEL` | `openai/gpt-4o-mini` | +| `ASYNC_RESPONSES_MODEL` | `openai/gpt-4o-mini` | +| `ASYNC_EMBEDDINGS_MODEL` | `openai/text-embedding-3-small` | +| `ASYNC_SPEECH_MODEL` | `openai/tts-1` | +| `ASYNC_TRANSCRIPTION_MODEL` | `openai/whisper-1` | +| `ASYNC_IMAGE_GEN_MODEL` | `openai/dall-e-3` | +| `ASYNC_IMAGE_EDIT_MODEL` | `openai/dall-e-2` | +| `ASYNC_IMAGE_VARIATION_MODEL` | `openai/dall-e-2` | +| `ASYNC_RERANK_MODEL` | `cohere/rerank-english-v3.0` | +| `ASYNC_OCR_MODEL` | `mistral/mistral-ocr-latest` | +| `ASYNC_OCR_IMAGE_URL` | carpenter-ant CDN URL | + +## Test files + +| File | What it covers | +|---|---| +| `submit_test.go` | All 11 endpoints return 202, well-formed job envelope, immediate poll status | +| `lifecycle_test.go` | Jobs reach terminal state, 404 for non-existent/wrong-type, result shape | +| `auth_test.go` | VK scoping, cross-VK isolation, all three auth header formats | +| `ttl_test.go` | Default/custom/invalid TTL, TTL expiry → 404 | +| `validation_test.go` | Stream rejection, malformed JSON, missing required fields, wrong HTTP method | +| 
`integration_route_test.go` | `x-bf-async` / `x-bf-async-id` headers on `/openai/v1/responses` | + +## Notes + +- Tests skip gracefully when the gateway is unreachable (`/health` check at startup). +- Most tests run in two modes: **global** (no VK) and **with_vk** (when `BIFROST_VK` is set). +- Integration route tests use the Responses API path — `AsyncChatResponseConverter` is not implemented on any route; only `AsyncResponsesResponseConverter` is wired up. +- `BIFROST_ALT_VK` is only required for cross-VK isolation tests (`TestAuth_VKScoped_DifferentKey_Returns404`, `TestIntegration_VKScope_DifferentKey_Returns4xx`). diff --git a/tests/async/auth_test.go b/tests/async/auth_test.go new file mode 100644 index 0000000000..1d2e1811cf --- /dev/null +++ b/tests/async/auth_test.go @@ -0,0 +1,176 @@ +package async + +import ( + "net/http" + "strings" + "testing" +) + +// Auth tests cover every combination of VK presence at submit and poll time. +// All tests use chat_completions as a representative endpoint. + +// assertPollSuccess fails the test unless the poll returned a success code (200 or 202). +func assertPollSuccess(t *testing.T, code int, body []byte) { + t.Helper() + if code != http.StatusOK && code != http.StatusAccepted { + t.Fatalf("expected 200/202, got %d: %s", code, body) + } +} + +// TestAuth_Submit_InvalidVK_Returns400 verifies that submitting with a VK value +// unknown to the governance store fails at submit time with 400. +// Requires BIFROST_VK to be set, which proves VK governance is active on the server. 
+func TestAuth_Submit_InvalidVK_Returns400(t *testing.T) { + if cfg.VK == "" { + t.Skip("BIFROST_VK not set — governance may not be active") + } + ec := chatCompletionCase() + code, _, body := submitCase(t, ec, vkHeaders("sk-bf-nonexistent-key-for-auth-test")) + if code != http.StatusBadRequest { + t.Errorf("expected 400 for unknown VK on submit, got %d: %s", code, body) + } +} + +// TestAuth_VKScoped_SameKey_Succeeds submits with a VK and polls with the same VK. +func TestAuth_VKScoped_SameKey_Succeeds(t *testing.T) { + if cfg.VK == "" { + t.Skip("BIFROST_VK not set") + } + ec := chatCompletionCase() + _, submitted, body := submitCase(t, ec, vkHeaders(cfg.VK)) + if submitted.ID == "" { + t.Fatalf("submit returned no job id: %s", body) + } + + pollPath := jobPollPath(ec.pollBase, submitted.ID) + code, _, body := pollOnce(t, pollPath, vkHeaders(cfg.VK)) + assertPollSuccess(t, code, body) +} + +// TestAuth_VKScoped_DifferentKey_Returns404 submits with VK1 and polls with VK2. +// The gateway must return 404 because the VK IDs will not match. +func TestAuth_VKScoped_DifferentKey_Returns404(t *testing.T) { + if cfg.VK == "" || cfg.AltVK == "" { + t.Skip("both BIFROST_VK and BIFROST_ALT_VK must be set") + } + ec := chatCompletionCase() + _, submitted, body := submitCase(t, ec, vkHeaders(cfg.VK)) + if submitted.ID == "" { + t.Fatalf("submit returned no job id: %s", body) + } + + pollPath := jobPollPath(ec.pollBase, submitted.ID) + code, _, _ := pollOnce(t, pollPath, vkHeaders(cfg.AltVK)) + if code != http.StatusNotFound { + t.Errorf("expected 404 when polling with a different VK, got %d", code) + } +} + +// TestAuth_VKScoped_MissingKeyOnPoll_Returns404 submits with a VK and polls +// without one. The job stores a VirtualKeyID so the gateway requires a VK on poll. 
+func TestAuth_VKScoped_MissingKeyOnPoll_Returns404(t *testing.T) { + if cfg.VK == "" { + t.Skip("BIFROST_VK not set") + } + ec := chatCompletionCase() + _, submitted, body := submitCase(t, ec, vkHeaders(cfg.VK)) + if submitted.ID == "" { + t.Fatalf("submit returned no job id: %s", body) + } + + pollPath := jobPollPath(ec.pollBase, submitted.ID) + code, _, _ := pollOnce(t, pollPath, nil) + if code != http.StatusNotFound { + t.Errorf("expected 404 when polling a VK-scoped job without a VK, got %d", code) + } +} + +// TestAuth_PublicJob_AnonymousPoll_Succeeds submits without a VK (VirtualKeyID = nil) +// and polls without a VK. The VK check is skipped for public jobs. +func TestAuth_PublicJob_AnonymousPoll_Succeeds(t *testing.T) { + ec := chatCompletionCase() + _, submitted, body := submitCase(t, ec, nil) + if submitted.ID == "" { + t.Fatalf("submit returned no job id: %s", body) + } + + pollPath := jobPollPath(ec.pollBase, submitted.ID) + code, _, body := pollOnce(t, pollPath, nil) + assertPollSuccess(t, code, body) +} + +// TestAuth_PublicJob_VKPoll_Succeeds submits without a VK and polls with one. +// Per docs: "Jobs created without a virtual key are not virtual-key scoped, so they +// can be polled by any caller that passes your gateway auth/middleware checks." +func TestAuth_PublicJob_VKPoll_Succeeds(t *testing.T) { + if cfg.VK == "" { + t.Skip("BIFROST_VK not set") + } + ec := chatCompletionCase() + _, submitted, body := submitCase(t, ec, nil) + if submitted.ID == "" { + t.Fatalf("submit returned no job id: %s", body) + } + + pollPath := jobPollPath(ec.pollBase, submitted.ID) + code, _, body := pollOnce(t, pollPath, vkHeaders(cfg.VK)) + assertPollSuccess(t, code, body) +} + +// vkPrefixed returns true when vk begins with the governance virtual-key prefix "sk-bf-". +// Only keys with this prefix are recognised by the Authorization, x-api-key, and +// x-goog-api-key header paths in ConvertToBifrostContext. 
+func vkPrefixed(vk string) bool { + return strings.HasPrefix(strings.ToLower(vk), "sk-bf-") +} + +// TestAuth_BearerVK_SameKey_Succeeds submits with "Authorization: Bearer " and +// polls with the same header. Verifies the Bearer token path in ConvertToBifrostContext. +func TestAuth_BearerVK_SameKey_Succeeds(t *testing.T) { + if cfg.VK == "" || !vkPrefixed(cfg.VK) { + t.Skip("BIFROST_VK not set or does not start with sk-bf- prefix") + } + ec := chatCompletionCase() + headers := map[string]string{"Authorization": "Bearer " + cfg.VK} + _, submitted, body := submitCase(t, ec, headers) + if submitted.ID == "" { + t.Fatalf("submit returned no job id: %s", body) + } + pollPath := jobPollPath(ec.pollBase, submitted.ID) + code, _, body := pollOnce(t, pollPath, headers) + assertPollSuccess(t, code, body) +} + +// TestAuth_ApiKeyVK_SameKey_Succeeds submits with "x-api-key: " and polls with +// the same header. Verifies the x-api-key path in ConvertToBifrostContext. +func TestAuth_ApiKeyVK_SameKey_Succeeds(t *testing.T) { + if cfg.VK == "" || !vkPrefixed(cfg.VK) { + t.Skip("BIFROST_VK not set or does not start with sk-bf- prefix") + } + ec := chatCompletionCase() + headers := map[string]string{"x-api-key": cfg.VK} + _, submitted, body := submitCase(t, ec, headers) + if submitted.ID == "" { + t.Fatalf("submit returned no job id: %s", body) + } + pollPath := jobPollPath(ec.pollBase, submitted.ID) + code, _, body := pollOnce(t, pollPath, headers) + assertPollSuccess(t, code, body) +} + +// TestAuth_GoogApiKeyVK_SameKey_Succeeds submits with "x-goog-api-key: " and polls +// with the same header. Verifies the x-goog-api-key path in ConvertToBifrostContext. 
+func TestAuth_GoogApiKeyVK_SameKey_Succeeds(t *testing.T) { + if cfg.VK == "" || !vkPrefixed(cfg.VK) { + t.Skip("BIFROST_VK not set or does not start with sk-bf- prefix") + } + ec := chatCompletionCase() + headers := map[string]string{"x-goog-api-key": cfg.VK} + _, submitted, body := submitCase(t, ec, headers) + if submitted.ID == "" { + t.Fatalf("submit returned no job id: %s", body) + } + pollPath := jobPollPath(ec.pollBase, submitted.ID) + code, _, body := pollOnce(t, pollPath, headers) + assertPollSuccess(t, code, body) +} diff --git a/tests/async/fixtures_test.go b/tests/async/fixtures_test.go new file mode 100644 index 0000000000..21db030673 --- /dev/null +++ b/tests/async/fixtures_test.go @@ -0,0 +1,220 @@ +package async + +import ( + "bytes" + "image" + "image/color" + "image/png" + "os" + "path/filepath" +) + +// endpointCase describes a single async endpoint and the request payload to send. +type endpointCase struct { + name string + submitPath string // POST target, e.g. /v1/async/chat/completions + pollBase string // GET base; job ID is appended as /{job_id} + body map[string]any + multipart *multipartCase +} + +// multipartCase holds fields and named files for a multipart/form-data submission. +type multipartCase struct { + fields map[string]string + files map[string]fileFixture +} + +type fileFixture struct { + filename string + data []byte +} + +// defaultModels maps each ASYNC_*_MODEL env key to its default model string. 
var defaultModels = map[string]string{
	"ASYNC_TEXT_COMPLETION_MODEL": "openai/gpt-3.5-turbo-instruct",
	"ASYNC_CHAT_COMPLETION_MODEL": "openai/gpt-4o-mini",
	"ASYNC_RESPONSES_MODEL":       "openai/gpt-4o-mini",
	"ASYNC_EMBEDDINGS_MODEL":      "openai/text-embedding-3-small",
	"ASYNC_SPEECH_MODEL":          "openai/tts-1",
	"ASYNC_TRANSCRIPTION_MODEL":   "openai/whisper-1",
	"ASYNC_IMAGE_GEN_MODEL":       "openai/dall-e-3",
	"ASYNC_IMAGE_EDIT_MODEL":      "openai/dall-e-2",
	"ASYNC_IMAGE_VARIATION_MODEL": "openai/dall-e-2",
	"ASYNC_RERANK_MODEL":          "cohere/rerank-english-v3.0",
	"ASYNC_OCR_MODEL":             "mistral/mistral-ocr-latest",
}

// modelFor resolves the model string for envKey. A non-empty environment
// variable of that name wins; otherwise the entry in defaultModels is used,
// which is the empty string for keys the table does not contain.
func modelFor(envKey string) string {
	override, isSet := os.LookupEnv(envKey)
	if isSet && override != "" {
		return override
	}
	return defaultModels[envKey]
}

// endpointCases returns the full set of async endpoint fixtures, one per supported endpoint.
// Override any model via the corresponding ASYNC_*_MODEL environment variable.
+func endpointCases() []endpointCase { + return []endpointCase{ + { + name: "text_completions", + submitPath: "/v1/async/completions", + pollBase: "/v1/async/completions", + body: map[string]any{ + "model": modelFor("ASYNC_TEXT_COMPLETION_MODEL"), + "prompt": "Say hello in one word.", + "max_tokens": 10, + }, + }, + { + name: "chat_completions", + submitPath: "/v1/async/chat/completions", + pollBase: "/v1/async/chat/completions", + body: map[string]any{ + "model": modelFor("ASYNC_CHAT_COMPLETION_MODEL"), + "messages": []map[string]any{ + {"role": "user", "content": "Say hello in one word."}, + }, + "max_tokens": 10, + }, + }, + { + name: "responses", + submitPath: "/v1/async/responses", + pollBase: "/v1/async/responses", + body: map[string]any{ + "model": modelFor("ASYNC_RESPONSES_MODEL"), + "input": "Say hello in one word.", + }, + }, + { + name: "embeddings", + submitPath: "/v1/async/embeddings", + pollBase: "/v1/async/embeddings", + body: map[string]any{ + "model": modelFor("ASYNC_EMBEDDINGS_MODEL"), + "input": "Hello world", + }, + }, + { + name: "speech", + submitPath: "/v1/async/audio/speech", + pollBase: "/v1/async/audio/speech", + body: map[string]any{ + "model": modelFor("ASYNC_SPEECH_MODEL"), + "input": "Hello", + "voice": "alloy", + }, + }, + { + name: "transcriptions", + submitPath: "/v1/async/audio/transcriptions", + pollBase: "/v1/async/audio/transcriptions", + multipart: &multipartCase{ + fields: map[string]string{ + "model": modelFor("ASYNC_TRANSCRIPTION_MODEL"), + }, + files: map[string]fileFixture{ + "file": {filename: "sample.mp3", data: sampleAudio()}, + }, + }, + }, + { + name: "image_generations", + submitPath: "/v1/async/images/generations", + pollBase: "/v1/async/images/generations", + body: map[string]any{ + "model": modelFor("ASYNC_IMAGE_GEN_MODEL"), + "prompt": "A simple red circle on a white background", + "n": 1, + "size": "1024x1024", + }, + }, + { + name: "image_edits", + submitPath: "/v1/async/images/edits", + pollBase: 
"/v1/async/images/edits", + multipart: &multipartCase{ + fields: map[string]string{ + "model": modelFor("ASYNC_IMAGE_EDIT_MODEL"), + "prompt": "Make it blue", + "n": "1", + "size": "256x256", + }, + files: map[string]fileFixture{ + "image": {filename: "image.png", data: samplePNG()}, + }, + }, + }, + { + name: "image_variations", + submitPath: "/v1/async/images/variations", + pollBase: "/v1/async/images/variations", + multipart: &multipartCase{ + fields: map[string]string{ + "model": modelFor("ASYNC_IMAGE_VARIATION_MODEL"), + "n": "1", + "size": "256x256", + }, + files: map[string]fileFixture{ + "image": {filename: "image.png", data: samplePNG()}, + }, + }, + }, + { + name: "rerank", + submitPath: "/v1/async/rerank", + pollBase: "/v1/async/rerank", + body: map[string]any{ + "model": modelFor("ASYNC_RERANK_MODEL"), + "query": "What is the capital of France?", + "documents": []map[string]any{ + {"text": "Paris is the capital of France."}, + {"text": "London is the capital of the United Kingdom."}, + {"text": "Berlin is the capital of Germany."}, + }, + }, + }, + { + name: "ocr", + submitPath: "/v1/async/ocr", + pollBase: "/v1/async/ocr", + body: map[string]any{ + "model": modelFor("ASYNC_OCR_MODEL"), + "document": map[string]any{ + "type": "image_url", + "image_url": envOr("ASYNC_OCR_IMAGE_URL", "https://pestworldcdn-dcf2a8gbggazaghf.z01.azurefd.net/media/561791/carpenter-ant4.jpg"), + }, + }, + }, + } +} + +// sampleAudio reads core/internal/llmtests/scenarios/media/sample.mp3. +// go test sets the working directory to the package source directory, so the +// relative path is stable without runtime.Caller (which breaks under -trimpath). 
+func sampleAudio() []byte { + mediaPath := filepath.Join("..", "..", "core", "internal", "llmtests", "scenarios", "media", "sample.mp3") + data, err := os.ReadFile(mediaPath) + if err != nil { + panic("sampleAudio: cannot read " + mediaPath + ": " + err.Error()) + } + return data +} + +// samplePNG generates a 256x256 white RGBA PNG for image edit / variation fixtures. +// DALL-E 2 requires images with an alpha channel (RGBA PNG). +func samplePNG() []byte { + img := image.NewRGBA(image.Rect(0, 0, 256, 256)) + white := color.RGBA{R: 255, G: 255, B: 255, A: 255} + for y := range 256 { + for x := range 256 { + img.Set(x, y, white) + } + } + var buf bytes.Buffer + if err := png.Encode(&buf, img); err != nil { + panic("samplePNG: encode failed: " + err.Error()) + } + return buf.Bytes() +} diff --git a/tests/async/go.mod b/tests/async/go.mod new file mode 100644 index 0000000000..1e7b31bec7 --- /dev/null +++ b/tests/async/go.mod @@ -0,0 +1,3 @@ +module github.com/maximhq/bifrost/tests/async + +go 1.26.2 diff --git a/tests/async/helpers_test.go b/tests/async/helpers_test.go new file mode 100644 index 0000000000..acd829bb64 --- /dev/null +++ b/tests/async/helpers_test.go @@ -0,0 +1,276 @@ +package async + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "testing" + "time" +) + +const ( + defaultBaseURL = "http://localhost:8080" + defaultPollTimeout = 30 * time.Second + defaultPollInterval = 500 * time.Millisecond +) + +// httpClient is used for all test HTTP calls; the 15s timeout prevents CI hangs. +var httpClient = &http.Client{Timeout: 15 * time.Second} + +// cfg holds e2e configuration sourced from environment variables at startup. 
+var cfg = struct { + BaseURL string + VK string // BIFROST_VK — primary virtual key + AltVK string // BIFROST_ALT_VK — a second, different virtual key for auth tests + PollTimeout time.Duration + PollInterval time.Duration +}{ + BaseURL: envOr("BIFROST_BASE_URL", defaultBaseURL), + VK: os.Getenv("BIFROST_VK"), + AltVK: os.Getenv("BIFROST_ALT_VK"), + PollTimeout: parseDuration(os.Getenv("BIFROST_POLL_TIMEOUT"), defaultPollTimeout), + PollInterval: parseDuration(os.Getenv("BIFROST_POLL_INTERVAL"), defaultPollInterval), +} + +func envOr(key, fallback string) string { + if v := os.Getenv(key); v != "" { + return v + } + return fallback +} + +func parseDuration(s string, fallback time.Duration) time.Duration { + if s == "" { + return fallback + } + d, err := time.ParseDuration(s) + if err != nil { + return fallback + } + return d +} + +// testMode describes one execution round for the core test suites. +type testMode struct { + name string + headers map[string]string // headers to attach to every submit and poll call +} + +// testModes returns the rounds every core test must execute. +// When BIFROST_VK is unset, only the global (no-VK) round runs. +func testModes() []testMode { + modes := []testMode{ + {name: "global", headers: nil}, + } + if cfg.VK != "" { + modes = append(modes, testMode{name: "with_vk", headers: vkHeaders(cfg.VK)}) + } + return modes +} + +// --- Response types --- + +// AsyncJobResponse mirrors the gateway's JSON envelope for async job responses. 
+type AsyncJobResponse struct { + ID string `json:"id"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` + CompletedAt *time.Time `json:"completed_at"` + ExpiresAt *time.Time `json:"expires_at"` + StatusCode int `json:"status_code"` + Result json.RawMessage `json:"result"` + Error json.RawMessage `json:"error"` +} + +func (j AsyncJobResponse) isTerminal() bool { + return j.Status == "completed" || j.Status == "failed" +} + +// --- HTTP helpers --- + +// submitJSON POSTs a JSON body and returns the HTTP status code, decoded response, and raw body. +func submitJSON(t *testing.T, path string, body any, headers map[string]string) (int, AsyncJobResponse, []byte) { + t.Helper() + raw, err := json.Marshal(body) + if err != nil { + t.Fatalf("submitJSON: marshal: %v", err) + } + req, err := http.NewRequest(http.MethodPost, cfg.BaseURL+path, bytes.NewReader(raw)) + if err != nil { + t.Fatalf("submitJSON: new request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + for k, v := range headers { + req.Header.Set(k, v) + } + return doRequest(t, req) +} + +// submitRaw POSTs arbitrary bytes — used for malformed-JSON validation tests. +func submitRaw(t *testing.T, path string, raw []byte, contentType string, headers map[string]string) (int, []byte) { + t.Helper() + req, err := http.NewRequest(http.MethodPost, cfg.BaseURL+path, bytes.NewReader(raw)) + if err != nil { + t.Fatalf("submitRaw: new request: %v", err) + } + req.Header.Set("Content-Type", contentType) + for k, v := range headers { + req.Header.Set(k, v) + } + code, _, body := doRequest(t, req) + return code, body +} + +// submitMultipart POSTs a multipart/form-data body. 
+func submitMultipart(t *testing.T, path string, mp *multipartCase, headers map[string]string) (int, AsyncJobResponse, []byte) { + t.Helper() + var buf bytes.Buffer + w := multipart.NewWriter(&buf) + for k, v := range mp.fields { + if err := w.WriteField(k, v); err != nil { + t.Fatalf("submitMultipart: write field %q: %v", k, err) + } + } + for fieldName, ff := range mp.files { + fw, err := w.CreateFormFile(fieldName, ff.filename) + if err != nil { + t.Fatalf("submitMultipart: create form file %q: %v", fieldName, err) + } + if _, err := fw.Write(ff.data); err != nil { + t.Fatalf("submitMultipart: write file %q: %v", fieldName, err) + } + } + if err := w.Close(); err != nil { + t.Fatalf("submitMultipart: close writer: %v", err) + } + + req, err := http.NewRequest(http.MethodPost, cfg.BaseURL+path, &buf) + if err != nil { + t.Fatalf("submitMultipart: new request: %v", err) + } + req.Header.Set("Content-Type", w.FormDataContentType()) + for k, v := range headers { + req.Header.Set(k, v) + } + return doRequest(t, req) +} + +// submitCase dispatches to submitJSON or submitMultipart based on the fixture type. +func submitCase(t *testing.T, ec endpointCase, headers map[string]string) (int, AsyncJobResponse, []byte) { + t.Helper() + if ec.multipart != nil { + return submitMultipart(t, ec.submitPath, ec.multipart, headers) + } + return submitJSON(t, ec.submitPath, ec.body, headers) +} + +// pollOnce performs a single GET and returns HTTP status, decoded response, and raw body. +func pollOnce(t *testing.T, pollPath string, headers map[string]string) (int, AsyncJobResponse, []byte) { + t.Helper() + req, err := http.NewRequest(http.MethodGet, cfg.BaseURL+pollPath, nil) + if err != nil { + t.Fatalf("pollOnce: new request: %v", err) + } + for k, v := range headers { + req.Header.Set(k, v) + } + return doRequest(t, req) +} + +// pollUntilTerminal polls every cfg.PollInterval until the job is completed/failed or cfg.PollTimeout elapses. 
+func pollUntilTerminal(t *testing.T, pollPath string, headers map[string]string) (int, AsyncJobResponse) { + t.Helper() + deadline := time.Now().Add(cfg.PollTimeout) + for time.Now().Before(deadline) { + code, job, _ := pollOnce(t, pollPath, headers) + if job.isTerminal() { + return code, job + } + if code != http.StatusAccepted { + t.Fatalf("unexpected HTTP %d while polling %s (status=%q)", code, pollPath, job.Status) + } + time.Sleep(cfg.PollInterval) + } + t.Fatalf("timed out after %s waiting for terminal status on %s", cfg.PollTimeout, pollPath) + return 0, AsyncJobResponse{} +} + +// --- Path / header helpers --- + +// jobPollPath builds the GET path for a job: /pollBase/{jobID}. +func jobPollPath(base, jobID string) string { + return base + "/" + jobID +} + +// vkHeaders returns a header map carrying the given virtual key. +// Returns nil when vk is empty so callers can safely pass it to submitCase. +func vkHeaders(vk string) map[string]string { + if vk == "" { + return nil + } + return map[string]string{"x-bf-vk": vk} +} + +// withTTLHeader copies headers and appends x-bf-async-job-result-ttl. +func withTTLHeader(headers map[string]string, ttlSeconds int) map[string]string { + out := make(map[string]string, len(headers)+1) + for k, v := range headers { + out[k] = v + } + out["x-bf-async-job-result-ttl"] = fmt.Sprintf("%d", ttlSeconds) + return out +} + +// withRawHeader copies headers and appends a single key/value pair. +func withRawHeader(headers map[string]string, key, value string) map[string]string { + out := make(map[string]string, len(headers)+1) + for k, v := range headers { + out[k] = v + } + out[key] = value + return out +} + +// doRequest executes an HTTP request and returns (statusCode, decoded AsyncJobResponse, rawBody). 
+func doRequest(t *testing.T, req *http.Request) (int, AsyncJobResponse, []byte) { + t.Helper() + resp, err := httpClient.Do(req) + if err != nil { + t.Fatalf("HTTP %s %s failed: %v", req.Method, req.URL, err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read response body: %v", err) + } + var job AsyncJobResponse + _ = json.Unmarshal(body, &job) + return resp.StatusCode, job, body +} + +// chatCompletionCase returns the chat_completions fixture — used as a representative +// endpoint in auth and TTL tests where endpoint variety is not the focus. +func chatCompletionCase() endpointCase { + for _, ec := range endpointCases() { + if ec.name == "chat_completions" { + return ec + } + } + panic("chatCompletionCase: fixture not found") +} + +// TestMain checks that the Bifrost gateway is reachable before running any tests. +// Set BIFROST_BASE_URL to override the default http://localhost:8080. +func TestMain(m *testing.M) { + resp, err := httpClient.Get(cfg.BaseURL + "/health") + if err != nil || resp.StatusCode >= 500 { + fmt.Printf("SKIP: Bifrost gateway not reachable at %s (err=%v)\n", cfg.BaseURL, err) + os.Exit(0) + } + resp.Body.Close() + os.Exit(m.Run()) +} diff --git a/tests/async/integration_route_test.go b/tests/async/integration_route_test.go new file mode 100644 index 0000000000..27d41687a7 --- /dev/null +++ b/tests/async/integration_route_test.go @@ -0,0 +1,174 @@ +package async + +import ( + "encoding/json" + "maps" + "net/http" + "os" + "strings" + "testing" +) + +// Integration route tests verify that x-bf-async and x-bf-async-id headers work on +// provider integration routes. 
These routes apply a provider-specific response converter, +// so the envelope differs from /v1/async/* endpoints: +// +// Submit (x-bf-async: true) → HTTP 200 (not 202) +// Retrieve (x-bf-async-id: ) → HTTP 200 for any job state +// +// Optional env: +// +// BIFROST_INTEGRATION_PATH — override the default /openai/v1/responses +// BIFROST_INTEGRATION_MODEL — model string; defaults to ASYNC_RESPONSES_MODEL default +// +// Note: only routes with AsyncResponsesResponseConverter support x-bf-async. +// AsyncChatResponseConverter is not implemented on any route — the Responses API +// path (/openai/v1/responses) is the only integration route that supports async. +func integrationPath() string { + return envOr("BIFROST_INTEGRATION_PATH", "/openai/v1/responses") +} + +func integrationModel() string { + if v := os.Getenv("BIFROST_INTEGRATION_MODEL"); v != "" { + return v + } + return modelFor("ASYNC_RESPONSES_MODEL") +} + +// assert4xx fails the test unless code is a 4xx client error, catching 5xx regressions. +func assert4xx(t *testing.T, code int, body []byte) { + t.Helper() + if code < 400 || code >= 500 { + t.Fatalf("expected 4xx, got %d: %s", code, body) + } +} + +// integrationJobID extracts the job UUID from an integration route response body. +// All integration converters preserve the async job ID in the top-level "id" field. +func integrationJobID(t *testing.T, body []byte) string { + t.Helper() + var m map[string]any + if err := json.Unmarshal(body, &m); err != nil { + return "" + } + if id, ok := m["id"].(string); ok { + return id + } + return "" +} + +// pollIntegration POSTs to an integration path with x-bf-async-id header to retrieve a job. +// Integration routes use the same POST method for both submit and retrieve. 
+func pollIntegration(t *testing.T, path, jobID string, headers map[string]string) (int, []byte) { + t.Helper() + h := make(map[string]string, len(headers)+1) + maps.Copy(h, headers) + h["x-bf-async-id"] = jobID + code, body := submitRaw(t, path, []byte("{}"), "application/json", h) + return code, body +} + +// integrationSubmitBody returns a minimal Responses API body for the integration path. +func integrationSubmitBody() map[string]any { + return map[string]any{ + "model": integrationModel(), + "input": "Say hello in one word.", + } +} + +// TestIntegration_AsyncCreate_Returns200WithJobID submits a chat request via an integration +// route with x-bf-async header and confirms the response is 200 OK with a job UUID. +// Integration routes return 200 (not 202) because the response passes through the +// provider-specific converter before being sent. +func TestIntegration_AsyncCreate_Returns200WithJobID(t *testing.T) { + for _, mode := range testModes() { + t.Run(mode.name, func(t *testing.T) { + headers := withRawHeader(mode.headers, "x-bf-async", "true") + code, _, body := submitJSON(t, integrationPath(), integrationSubmitBody(), headers) + if code != http.StatusOK { + t.Fatalf("expected 200 from integration async submit, got %d: %s", code, body) + } + jobID := integrationJobID(t, body) + if jobID == "" { + t.Fatalf("no job id in integration route response: %s", body) + } + parts := strings.Split(jobID, "-") + if len(parts) != 5 || len(parts[0]) != 8 || len(parts[1]) != 4 || + len(parts[2]) != 4 || len(parts[3]) != 4 || len(parts[4]) != 12 { + t.Errorf("id %q does not look like a UUID", jobID) + } + }) + } +} + +// TestIntegration_AsyncRetrieve_Returns200 submits an async job on an integration route +// and polls it via x-bf-async-id header, confirming retrieve also returns 200 OK. 
+func TestIntegration_AsyncRetrieve_Returns200(t *testing.T) { + for _, mode := range testModes() { + t.Run(mode.name, func(t *testing.T) { + headers := withRawHeader(mode.headers, "x-bf-async", "true") + code, _, body := submitJSON(t, integrationPath(), integrationSubmitBody(), headers) + if code != http.StatusOK { + t.Fatalf("submit failed with %d: %s", code, body) + } + jobID := integrationJobID(t, body) + if jobID == "" { + t.Fatalf("no job id in submit response: %s", body) + } + + pollCode, pollBody := pollIntegration(t, integrationPath(), jobID, mode.headers) + if pollCode != http.StatusOK { + t.Errorf("expected 200 on integration retrieve, got %d: %s", pollCode, pollBody) + } + }) + } +} + +// TestIntegration_AsyncRetrieve_NonExistentJob_Returns4xx polls an integration route with +// a fake job ID and confirms a non-success status code is returned. +func TestIntegration_AsyncRetrieve_NonExistentJob_Returns4xx(t *testing.T) { + const fakeID = "00000000-0000-0000-0000-000000000000" + for _, mode := range testModes() { + t.Run(mode.name, func(t *testing.T) { + code, body := pollIntegration(t, integrationPath(), fakeID, mode.headers) + assert4xx(t, code, body) + }) + } +} + +// TestIntegration_AsyncCreate_StreamRejected confirms that submitting a streaming request +// via x-bf-async is rejected — streaming and async are mutually exclusive. +func TestIntegration_AsyncCreate_StreamRejected(t *testing.T) { + streamBody := map[string]any{ + "model": integrationModel(), + "input": "Hello", + "stream": true, + } + for _, mode := range testModes() { + t.Run(mode.name, func(t *testing.T) { + headers := withRawHeader(mode.headers, "x-bf-async", "true") + code, _, body := submitJSON(t, integrationPath(), streamBody, headers) + assert4xx(t, code, body) + }) + } +} + +// TestIntegration_VKScope_DifferentKey_Returns4xx submits an async job on an integration +// route with VK1 and retrieves with VK2, confirming VK isolation works on integration routes. 
+func TestIntegration_VKScope_DifferentKey_Returns4xx(t *testing.T) { + if cfg.VK == "" || cfg.AltVK == "" { + t.Skip("both BIFROST_VK and BIFROST_ALT_VK must be set") + } + headers := withRawHeader(vkHeaders(cfg.VK), "x-bf-async", "true") + code, _, body := submitJSON(t, integrationPath(), integrationSubmitBody(), headers) + if code != http.StatusOK { + t.Fatalf("submit failed with %d: %s", code, body) + } + jobID := integrationJobID(t, body) + if jobID == "" { + t.Fatalf("no job id in submit response: %s", body) + } + + pollCode, pollBody := pollIntegration(t, integrationPath(), jobID, vkHeaders(cfg.AltVK)) + assert4xx(t, pollCode, pollBody) +} diff --git a/tests/async/lifecycle_test.go b/tests/async/lifecycle_test.go new file mode 100644 index 0000000000..8ee9df094a --- /dev/null +++ b/tests/async/lifecycle_test.go @@ -0,0 +1,180 @@ +package async + +import ( + "encoding/json" + "net/http" + "testing" +) + +// TestLifecycle_AllEndpoints_ReachesTerminalState submits a job for every supported +// endpoint and polls until it reaches completed or failed, then validates the +// terminal response shape. Passes for either outcome — the test asserts the async +// mechanism itself, not model availability. +// Runs in both global and VK modes. 
+func TestLifecycle_AllEndpoints_ReachesTerminalState(t *testing.T) { + for _, mode := range testModes() { + t.Run(mode.name, func(t *testing.T) { + for _, ec := range endpointCases() { + t.Run(ec.name, func(t *testing.T) { + _, submitted, body := submitCase(t, ec, mode.headers) + if submitted.ID == "" { + t.Fatalf("submit returned no job id: %s", body) + } + + pollPath := jobPollPath(ec.pollBase, submitted.ID) + code, job := pollUntilTerminal(t, pollPath, mode.headers) + + if code != http.StatusOK { + t.Errorf("expected 200 for terminal job, got %d", code) + } + if job.ID != submitted.ID { + t.Errorf("polled id %q does not match submitted id %q", job.ID, submitted.ID) + } + if job.CompletedAt == nil { + t.Error("completed_at must be set on a terminal job") + } + if job.ExpiresAt == nil { + t.Error("expires_at must be set on a terminal job") + } + if job.CompletedAt != nil && job.ExpiresAt != nil && !job.ExpiresAt.After(*job.CompletedAt) { + t.Error("expires_at must be after completed_at") + } + + switch job.Status { + case "completed": + if len(job.Result) == 0 || string(job.Result) == "null" { + t.Error("completed job must have a non-null result") + } + case "failed": + if len(job.Error) == 0 || string(job.Error) == "null" { + t.Error("failed job must have a non-null error") + } + if job.StatusCode == 0 { + t.Error("failed job must carry a non-zero status_code") + } + } + }) + } + }) + } +} + +// TestLifecycle_Poll_NonExistentJob_Returns404 confirms that polling a random job ID +// returns 404 regardless of VK mode (job lookup fails before VK check). +// Uses chat_completions as a representative endpoint — all endpoints share the same +// RetrieveJob() path, so repeating across all 11 adds no coverage. 
+func TestLifecycle_Poll_NonExistentJob_Returns404(t *testing.T) { + const fakeID = "00000000-0000-0000-0000-000000000000" + ec := chatCompletionCase() + for _, mode := range testModes() { + t.Run(mode.name, func(t *testing.T) { + pollPath := jobPollPath(ec.pollBase, fakeID) + code, _, _ := pollOnce(t, pollPath, mode.headers) + if code != http.StatusNotFound { + t.Errorf("expected 404 for non-existent job, got %d", code) + } + }) + } +} + +// TestLifecycle_CompletedJobResultShape checks that completed jobs carry the expected +// top-level fields in their result JSON. If a job fails (e.g., no live API key), the +// shape check is skipped for that case — the test asserts structure, not model availability. +func TestLifecycle_CompletedJobResultShape(t *testing.T) { + type shapeCheck struct { + name string + check func(t *testing.T, result json.RawMessage) + } + + shapeChecks := map[string]shapeCheck{ + "chat_completions": { + "choices[]", + func(t *testing.T, result json.RawMessage) { + var r struct { + Choices []json.RawMessage `json:"choices"` + } + if err := json.Unmarshal(result, &r); err != nil { + t.Fatalf("unmarshal choices: %v", err) + } + if len(r.Choices) == 0 { + t.Error("completed chat job must have at least one choice") + } + }, + }, + "embeddings": { + "data[]", + func(t *testing.T, result json.RawMessage) { + var r struct { + Data []json.RawMessage `json:"data"` + } + if err := json.Unmarshal(result, &r); err != nil { + t.Fatalf("unmarshal data: %v", err) + } + if len(r.Data) == 0 { + t.Error("completed embeddings job must have at least one data entry") + } + }, + }, + "rerank": { + "results[]", + func(t *testing.T, result json.RawMessage) { + var r struct { + Results []json.RawMessage `json:"results"` + } + if err := json.Unmarshal(result, &r); err != nil { + t.Fatalf("unmarshal results: %v", err) + } + if len(r.Results) == 0 { + t.Error("completed rerank job must have at least one result") + } + }, + }, + } + + for _, ec := range endpointCases() { + 
sc, ok := shapeChecks[ec.name] + if !ok { + continue + } + t.Run(ec.name+"/"+sc.name, func(t *testing.T) { + _, submitted, body := submitCase(t, ec, nil) + if submitted.ID == "" { + t.Fatalf("submit returned no job id: %s", body) + } + pollPath := jobPollPath(ec.pollBase, submitted.ID) + _, job := pollUntilTerminal(t, pollPath, nil) + if job.Status != "completed" { + t.Skipf("job status=%q (not completed) — shape check skipped", job.Status) + } + sc.check(t, job.Result) + }) + } +} + +// TestLifecycle_Poll_WrongEndpointType_Returns404 submits a job on one endpoint and +// polls it via a different endpoint's path, expecting 404 (type mismatch). +func TestLifecycle_Poll_WrongEndpointType_Returns404(t *testing.T) { + cases := endpointCases() + if len(cases) < 2 { + t.Skip("need at least two endpoint cases") + } + + for _, mode := range testModes() { + t.Run(mode.name, func(t *testing.T) { + // Submit on cases[0], poll via cases[1]'s poll base. + submitter := cases[0] + wrongBase := cases[1].pollBase + + _, submitted, body := submitCase(t, submitter, mode.headers) + if submitted.ID == "" { + t.Fatalf("submit returned no job id: %s", body) + } + + pollPath := jobPollPath(wrongBase, submitted.ID) + code, _, _ := pollOnce(t, pollPath, mode.headers) + if code != http.StatusNotFound { + t.Errorf("expected 404 when polling with wrong endpoint type, got %d", code) + } + }) + } +} diff --git a/tests/async/submit_test.go b/tests/async/submit_test.go new file mode 100644 index 0000000000..ea79654f2d --- /dev/null +++ b/tests/async/submit_test.go @@ -0,0 +1,89 @@ +package async + +import ( + "net/http" + "strings" + "testing" + "time" +) + +// TestSubmit_AllEndpoints_Returns202 verifies that every async endpoint immediately +// returns 202 Accepted with a well-formed job envelope. +// Runs once in global mode (no VK) and once with BIFROST_VK when set. 
+func TestSubmit_AllEndpoints_Returns202(t *testing.T) { + for _, mode := range testModes() { + t.Run(mode.name, func(t *testing.T) { + for _, ec := range endpointCases() { + t.Run(ec.name, func(t *testing.T) { + code, job, body := submitCase(t, ec, mode.headers) + + if code != http.StatusAccepted { + t.Fatalf("expected 202, got %d: %s", code, body) + } + if job.ID == "" { + t.Fatal("response missing id") + } + // UUID format: 8-4-4-4-12 hex groups separated by hyphens. + parts := strings.Split(job.ID, "-") + if len(parts) != 5 || len(parts[0]) != 8 || len(parts[1]) != 4 || + len(parts[2]) != 4 || len(parts[3]) != 4 || len(parts[4]) != 12 { + t.Errorf("id %q does not look like a UUID", job.ID) + } + if job.Status != "pending" { + t.Errorf("expected status=pending, got %q", job.Status) + } + if job.CreatedAt.IsZero() { + t.Error("created_at is zero") + } + if time.Since(job.CreatedAt) > 30*time.Second { + t.Errorf("created_at %v appears stale (>30s ago)", job.CreatedAt) + } + if job.CompletedAt != nil { + t.Error("completed_at must be absent on a freshly submitted job") + } + if job.ExpiresAt != nil { + t.Error("expires_at must be absent on a freshly submitted job") + } + }) + } + }) + } +} + +// TestSubmit_AllEndpoints_PollPathReturnsPending verifies that polling immediately +// after submission yields a non-terminal (pending/processing) or just-completed state +// with the correct HTTP status code for each. +// Runs in both global and VK modes. 
+func TestSubmit_AllEndpoints_PollPathReturnsPending(t *testing.T) { + for _, mode := range testModes() { + t.Run(mode.name, func(t *testing.T) { + for _, ec := range endpointCases() { + t.Run(ec.name, func(t *testing.T) { + submitCode, submitted, body := submitCase(t, ec, mode.headers) + if submitCode != http.StatusAccepted { + t.Fatalf("expected submit 202, got %d: %s", submitCode, body) + } + if submitted.ID == "" { + t.Fatalf("submit returned no job id: %s", body) + } + + pollPath := jobPollPath(ec.pollBase, submitted.ID) + code, polled, _ := pollOnce(t, pollPath, mode.headers) + + switch polled.Status { + case "pending", "processing": + if code != http.StatusAccepted { + t.Errorf("expected 202 for status %q, got %d", polled.Status, code) + } + case "completed", "failed": + if code != http.StatusOK { + t.Errorf("expected 200 for terminal status %q, got %d", polled.Status, code) + } + default: + t.Errorf("unexpected status %q (HTTP %d)", polled.Status, code) + } + }) + } + }) + } +} diff --git a/tests/async/ttl_test.go b/tests/async/ttl_test.go new file mode 100644 index 0000000000..d7818a1917 --- /dev/null +++ b/tests/async/ttl_test.go @@ -0,0 +1,157 @@ +package async + +import ( + "net/http" + "testing" + "time" +) + +// TTL tests use chat_completions as a representative endpoint and run in both +// global and VK modes. They verify that expires_at is set correctly relative to +// completed_at based on the TTL value in effect. + +// TestTTL_DefaultApplied verifies that when no TTL header is sent, expires_at is +// approximately 3600s (one hour) after completed_at. 
+func TestTTL_DefaultApplied(t *testing.T) { + for _, mode := range testModes() { + t.Run(mode.name, func(t *testing.T) { + ec := chatCompletionCase() + _, submitted, body := submitCase(t, ec, mode.headers) + if submitted.ID == "" { + t.Fatalf("submit returned no job id: %s", body) + } + pollPath := jobPollPath(ec.pollBase, submitted.ID) + _, job := pollUntilTerminal(t, pollPath, mode.headers) + assertTTL(t, job, 3600, 60) + }) + } +} + +// TestTTL_CustomHeaderApplied verifies that x-bf-async-job-result-ttl overrides the +// default and expires_at is roughly TTL seconds after completed_at. +func TestTTL_CustomHeaderApplied(t *testing.T) { + const customTTL = 120 + for _, mode := range testModes() { + t.Run(mode.name, func(t *testing.T) { + ec := chatCompletionCase() + headers := withTTLHeader(mode.headers, customTTL) + _, submitted, body := submitCase(t, ec, headers) + if submitted.ID == "" { + t.Fatalf("submit returned no job id: %s", body) + } + pollPath := jobPollPath(ec.pollBase, submitted.ID) + // Poll must use the mode headers, not the TTL headers (TTL is submit-only). + _, job := pollUntilTerminal(t, pollPath, mode.headers) + assertTTL(t, job, customTTL, 30) + }) + } +} + +// TestTTL_InvalidHeader_FallsBackToDefault verifies that a non-numeric TTL header +// is ignored and the server falls back to the default 3600s TTL. 
+func TestTTL_InvalidHeader_FallsBackToDefault(t *testing.T) {
+	for _, m := range testModes() {
+		t.Run(m.name, func(t *testing.T) {
+			tc := chatCompletionCase()
+			// Send a TTL header value that is not a number.
+			badTTL := withRawHeader(m.headers, "x-bf-async-job-result-ttl", "not-a-number")
+			_, created, raw := submitCase(t, tc, badTTL)
+			if created.ID == "" {
+				t.Fatalf("submit returned no job id: %s", raw)
+			}
+			_, finished := pollUntilTerminal(t, jobPollPath(tc.pollBase, created.ID), m.headers)
+			assertTTL(t, finished, 3600, 60)
+		})
+	}
+}
+
+// TestTTL_ZeroHeader_FallsBackToDefault submits with TTL=0, which is treated as invalid
+// (per SubmitJob: if resultTTL <= 0 use default), and expects the 3600s default.
+func TestTTL_ZeroHeader_FallsBackToDefault(t *testing.T) {
+	for _, m := range testModes() {
+		t.Run(m.name, func(t *testing.T) {
+			tc := chatCompletionCase()
+			// The TTL header is attached to the submit request only.
+			_, created, raw := submitCase(t, tc, withTTLHeader(m.headers, 0))
+			if created.ID == "" {
+				t.Fatalf("submit returned no job id: %s", raw)
+			}
+			target := jobPollPath(tc.pollBase, created.ID)
+			_, finished := pollUntilTerminal(t, target, m.headers)
+			assertTTL(t, finished, 3600, 60)
+		})
+	}
+}
+
+// TestTTL_NegativeHeader_FallsBackToDefault submits with a negative TTL value and
+// expects the result to expire after the default 3600s.
+func TestTTL_NegativeHeader_FallsBackToDefault(t *testing.T) {
+	for _, m := range testModes() {
+		t.Run(m.name, func(t *testing.T) {
+			tc := chatCompletionCase()
+			// The TTL header is attached to the submit request only.
+			_, created, raw := submitCase(t, tc, withTTLHeader(m.headers, -1))
+			if created.ID == "" {
+				t.Fatalf("submit returned no job id: %s", raw)
+			}
+			target := jobPollPath(tc.pollBase, created.ID)
+			_, finished := pollUntilTerminal(t, target, m.headers)
+			assertTTL(t, finished, 3600, 60)
+		})
+	}
+}
+
+// TestTTL_ExpiredJob_Returns404 submits a job with a very short TTL, waits for
+// completion, then waits for the TTL to elapse and confirms polling returns 404.
+// Verifies FindAsyncJobByID filters on expires_at > NOW().
+func TestTTL_ExpiredJob_Returns404(t *testing.T) {
+	const ttlSeconds = 10 // seconds — must be larger than BIFROST_POLL_INTERVAL
+	for _, m := range testModes() {
+		t.Run(m.name, func(t *testing.T) {
+			tc := chatCompletionCase()
+			_, created, raw := submitCase(t, tc, withTTLHeader(m.headers, ttlSeconds))
+			if created.ID == "" {
+				t.Fatalf("submit returned no job id: %s", raw)
+			}
+
+			// Let the job reach a terminal state first; the result itself is not inspected.
+			target := jobPollPath(tc.pollBase, created.ID)
+			pollUntilTerminal(t, target, m.headers)
+
+			// Keep polling until the server answers 404; give up 10s past the TTL.
+			giveUpAt := time.Now().Add((ttlSeconds + 10) * time.Second)
+			for {
+				status, _, _ := pollOnce(t, target, m.headers)
+				if status == http.StatusNotFound {
+					return
+				}
+				if time.Now().After(giveUpAt) {
+					t.Fatalf("expected 404 after TTL expiry, last code=%d", status)
+				}
+				time.Sleep(250 * time.Millisecond)
+			}
+		})
+	}
+}
+
+// assertTTL checks that expires_at ≈ completed_at + wantTTLSeconds within toleranceSeconds.
+func assertTTL(t *testing.T, job AsyncJobResponse, wantTTLSeconds, toleranceSeconds int) {
+	t.Helper()
+	// Both timestamps are required; the first missing one aborts the test.
+	switch {
+	case job.CompletedAt == nil:
+		t.Fatal("completed_at is nil, cannot verify TTL")
+	case job.ExpiresAt == nil:
+		t.Fatal("expires_at is nil, cannot verify TTL")
+	}
+
+	var (
+		got   = job.ExpiresAt.Sub(*job.CompletedAt)
+		want  = time.Duration(wantTTLSeconds) * time.Second
+		slack = time.Duration(toleranceSeconds) * time.Second
+	)
+	// Report a mismatch when the observed TTL is off by more than slack in either direction.
+	if drift := got - want; drift > slack || -drift > slack {
+		t.Errorf("TTL mismatch: expires_at - completed_at = %v, want %v ± %v", got, want, slack)
+	}
+}
diff --git a/tests/async/validation_test.go b/tests/async/validation_test.go
new file mode 100644
index 0000000000..db337f52c0
--- /dev/null
+++ b/tests/async/validation_test.go
@@ -0,0 +1,306 @@
+package async
+
+import (
+	"net/http"
+	"testing"
+)
+
+// streamEndpoints lists async endpoints that reject stream=true in the JSON body.
+// Speech uses stream_format instead and is tested separately.
+// image_edits and image_variations are multipart-only endpoints; their stream field
+// is a multipart form value — not a JSON body field — so they are not listed here.
+var streamEndpoints = []struct {
+	name       string         // subtest name
+	submitPath string         // async submit route under test
+	body       map[string]any // JSON payload; always carries "stream": true
+}{
+	{
+		name:       "text_completions",
+		submitPath: "/v1/async/completions",
+		body: map[string]any{
+			"model":  modelFor("ASYNC_TEXT_COMPLETION_MODEL"),
+			"prompt": "Hello",
+			"stream": true,
+		},
+	},
+	{
+		name:       "chat_completions",
+		submitPath: "/v1/async/chat/completions",
+		body: map[string]any{
+			"model":    modelFor("ASYNC_CHAT_COMPLETION_MODEL"),
+			"messages": []map[string]any{{"role": "user", "content": "Hello"}},
+			"stream":   true,
+		},
+	},
+	{
+		name:       "responses",
+		submitPath: "/v1/async/responses",
+		body: map[string]any{
+			"model":  modelFor("ASYNC_RESPONSES_MODEL"),
+			"input":  "Hello",
+			"stream": true,
+		},
+	},
+	{
+		name:       "image_generations",
+		submitPath: "/v1/async/images/generations",
+		body: map[string]any{
+			"model":  modelFor("ASYNC_IMAGE_GEN_MODEL"),
+			"prompt": "A circle",
+			"stream": true,
+		},
+	},
+}
+
+// TestValidation_StreamRejected_Returns400 posts every streamEndpoints body (stream=true)
+// and expects a 400 before any job is created. It runs in both global and VK modes
+// because the stream check happens before VK resolution.
+func TestValidation_StreamRejected_Returns400(t *testing.T) {
+	for _, m := range testModes() {
+		t.Run(m.name, func(t *testing.T) {
+			for _, target := range streamEndpoints {
+				t.Run(target.name, func(t *testing.T) {
+					status, _, raw := submitJSON(t, target.submitPath, target.body, m.headers)
+					if status == http.StatusBadRequest {
+						return // rejected as expected
+					}
+					t.Errorf("expected 400 for stream=true on %s, got %d: %s", target.submitPath, status, raw)
+				})
+			}
+		})
+	}
+}
+
+// TestValidation_SpeechStreamFormatRejected_Returns400 confirms that the speech
+// endpoint rejects stream_format=sse with 400.
+func TestValidation_SpeechStreamFormatRejected_Returns400(t *testing.T) {
+	payload := map[string]any{
+		"model":         modelFor("ASYNC_SPEECH_MODEL"),
+		"input":         "Hello",
+		"voice":         "alloy",
+		"stream_format": "sse", // the value under test
+	}
+	for _, m := range testModes() {
+		t.Run(m.name, func(t *testing.T) {
+			status, _, raw := submitJSON(t, "/v1/async/audio/speech", payload, m.headers)
+			if status != http.StatusBadRequest {
+				t.Errorf("expected 400 for stream_format=sse on speech, got %d: %s", status, raw)
+			}
+		})
+	}
+}
+
+// TestValidation_MalformedJSON_Returns400 posts a syntactically broken JSON body to every
+// JSON-based async endpoint and expects a 400 before a job is created.
+func TestValidation_MalformedJSON_Returns400(t *testing.T) {
+	// Keep only the endpoint cases that carry no multipart spec, i.e. the JSON endpoints.
+	var jsonOnly []endpointCase
+	for _, tc := range endpointCases() {
+		if tc.multipart != nil {
+			continue
+		}
+		jsonOnly = append(jsonOnly, tc)
+	}
+
+	for _, m := range testModes() {
+		t.Run(m.name, func(t *testing.T) {
+			for _, tc := range jsonOnly {
+				t.Run(tc.name, func(t *testing.T) {
+					status, raw := submitRaw(t, tc.submitPath, []byte(`{invalid json`), "application/json", m.headers)
+					if status != http.StatusBadRequest {
+						t.Errorf("expected 400 for malformed JSON on %s, got %d: %s", tc.submitPath, status, raw)
+					}
+				})
+			}
+		})
+	}
+}
+
+// TestValidation_TranscriptionStreamRejected_Returns400 confirms that the transcription
+// endpoint rejects stream=true (sent as a multipart field) with 400.
+func TestValidation_TranscriptionStreamRejected_Returns400(t *testing.T) {
+	form := &multipartCase{
+		fields: map[string]string{
+			"model":  modelFor("ASYNC_TRANSCRIPTION_MODEL"),
+			"stream": "true", // the form value under test
+		},
+		files: map[string]fileFixture{
+			"file": {filename: "sample.mp3", data: sampleAudio()},
+		},
+	}
+	for _, m := range testModes() {
+		t.Run(m.name, func(t *testing.T) {
+			status, _, raw := submitMultipart(t, "/v1/async/audio/transcriptions", form, m.headers)
+			if status != http.StatusBadRequest {
+				t.Errorf("expected 400 for stream=true on transcription, got %d: %s", status, raw)
+			}
+		})
+	}
+}
+
+// TestValidation_MissingModel_Returns400 posts bodies that omit the model field to the
+// JSON endpoints listed in its table and expects each request to be rejected with 400.
+func TestValidation_MissingModel_Returns400(t *testing.T) {
+	cases := []struct {
+		name string
+		path string
+		body map[string]any
+	}{
+		{
+			name: "chat_completions",
+			path: "/v1/async/chat/completions",
+			body: map[string]any{"messages": []map[string]any{{"role": "user", "content": "Hello"}}},
+		},
+		{
+			name: "text_completions",
+			path: "/v1/async/completions",
+			body: map[string]any{"prompt": "Hello"},
+		},
+		{
+			name: "embeddings",
+			path: "/v1/async/embeddings",
+			body: map[string]any{"input": "Hello"},
+		},
+		{
+			name: "responses",
+			path: "/v1/async/responses",
+			body: map[string]any{"input": "Hello"},
+		},
+		{
+			name: "speech",
+			path: "/v1/async/audio/speech",
+			body: map[string]any{"input": "Hello", "voice": "alloy"},
+		},
+		{
+			name: "rerank",
+			path: "/v1/async/rerank",
+			body: map[string]any{
+				"query":     "test",
+				"documents": []map[string]any{{"text": "test document"}},
+			},
+		},
+		{
+			name: "ocr",
+			path: "/v1/async/ocr",
+			body: map[string]any{
+				"document": map[string]any{
+					"type":      "image_url",
+					"image_url": envOr("ASYNC_OCR_IMAGE_URL", "https://pestworldcdn-dcf2a8gbggazaghf.z01.azurefd.net/media/561791/carpenter-ant4.jpg"),
+				},
+			},
+		},
+	}
+	for _, m := range testModes() {
+		t.Run(m.name, func(t *testing.T) {
+			for _, tc := range cases {
+				t.Run(tc.name, func(t *testing.T) {
+					status, _, raw := submitJSON(t, tc.path, tc.body, m.headers)
+					if status != http.StatusBadRequest {
+						t.Errorf("expected 400 for missing model on %s, got %d: %s", tc.path, status, raw)
+					}
+				})
+			}
+		})
+	}
+}
+
+// TestValidation_ImageEditStreamRejected_Returns400 confirms that the image edit endpoint
+// rejects stream=true (sent as a multipart form field) with 400. The form is otherwise
+// complete and valid because stream validation runs after successful form parsing.
+func TestValidation_ImageEditStreamRejected_Returns400(t *testing.T) {
+	form := &multipartCase{
+		fields: map[string]string{
+			"model":  modelFor("ASYNC_IMAGE_EDIT_MODEL"),
+			"prompt": "Make it blue",
+			"stream": "true", // the form value under test
+		},
+		files: map[string]fileFixture{
+			"image": {filename: "image.png", data: samplePNG()},
+		},
+	}
+	for _, m := range testModes() {
+		t.Run(m.name, func(t *testing.T) {
+			status, _, raw := submitMultipart(t, "/v1/async/images/edits", form, m.headers)
+			if status != http.StatusBadRequest {
+				t.Errorf("expected 400 for stream=true on image edits, got %d: %s", status, raw)
+			}
+		})
+	}
+}
+
+// TestValidation_Transcription_MissingFile_Returns400 sends a transcription form without
+// the required audio file and expects a 400 at the multipart parse stage.
+func TestValidation_Transcription_MissingFile_Returns400(t *testing.T) {
+	form := &multipartCase{
+		fields: map[string]string{
+			"model": modelFor("ASYNC_TRANSCRIPTION_MODEL"),
+		},
+		// files is deliberately left unset: no "file" part is sent
+	}
+	for _, m := range testModes() {
+		t.Run(m.name, func(t *testing.T) {
+			status, _, raw := submitMultipart(t, "/v1/async/audio/transcriptions", form, m.headers)
+			if status != http.StatusBadRequest {
+				t.Errorf("expected 400 for missing audio file on transcription, got %d: %s", status, raw)
+			}
+		})
+	}
+}
+
+// TestValidation_ImageEdit_MissingImage_Returns400 verifies that an image edit request
+// without the required image file is rejected with 400.
+func TestValidation_ImageEdit_MissingImage_Returns400(t *testing.T) {
+	form := &multipartCase{
+		fields: map[string]string{
+			"model":  modelFor("ASYNC_IMAGE_EDIT_MODEL"),
+			"prompt": "Make it blue",
+		},
+		// files is deliberately left unset: no "image" part is sent
+	}
+	for _, m := range testModes() {
+		t.Run(m.name, func(t *testing.T) {
+			status, _, raw := submitMultipart(t, "/v1/async/images/edits", form, m.headers)
+			if status != http.StatusBadRequest {
+				t.Errorf("expected 400 for missing image file on image edits, got %d: %s", status, raw)
+			}
+		})
+	}
+}
+
+// TestValidation_ImageVariation_MissingImage_Returns400 sends an image variation form
+// without the required image file and expects it to be rejected with 400.
+func TestValidation_ImageVariation_MissingImage_Returns400(t *testing.T) {
+	form := &multipartCase{
+		fields: map[string]string{
+			"model": modelFor("ASYNC_IMAGE_VARIATION_MODEL"),
+		},
+		// files is deliberately left unset: no "image" part is sent
+	}
+	for _, m := range testModes() {
+		t.Run(m.name, func(t *testing.T) {
+			status, _, raw := submitMultipart(t, "/v1/async/images/variations", form, m.headers)
+			if status != http.StatusBadRequest {
+				t.Errorf("expected 400 for missing image file on image variations, got %d: %s", status, raw)
+			}
+		})
+	}
+}
+
+// TestHTTP_WrongMethod_Rejected verifies that POST on a poll-only path does not return
+// a success status code. The converse (GET on a submit path) is not checked here
+// because the server's UI layer intercepts bare GET requests on /v1/async/* paths
+// before the async router is reached.
+func TestHTTP_WrongMethod_Rejected(t *testing.T) {
+	req, err := http.NewRequest(http.MethodPost, cfg.BaseURL+"/v1/async/chat/completions/00000000-0000-0000-0000-000000000000", nil) // POST against a poll path with a placeholder job id
+	if err != nil {
+		t.Fatalf("build request: %v", err)
+	}
+	resp, err := httpClient.Do(req)
+	if err != nil {
+		t.Fatalf("POST /v1/async/chat/completions/{id} failed: %v", err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusNotFound && resp.StatusCode != http.StatusMethodNotAllowed { // only 404 and 405 are accepted
+		t.Errorf("POST on poll path returned %d, expected 404 or 405", resp.StatusCode)
+	}
+}
diff --git a/transports/bifrost-http/handlers/mcpserver.go b/transports/bifrost-http/handlers/mcpserver.go
index 9db9d674ff..b03488ea6f 100644
--- a/transports/bifrost-http/handlers/mcpserver.go
+++ b/transports/bifrost-http/handlers/mcpserver.go
@@ -95,11 +95,8 @@ func injectMCPSessionIdentity(bifrostCtx *schemas.BifrostContext, session *table
 	if session.AccessToken != "" {
 		bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserSession, session.AccessToken)
 	}
-	if session.VirtualKeyID != nil && *session.VirtualKeyID != "" {
-		bifrostCtx.SetValue(schemas.BifrostContextKeyGovernanceVirtualKeyID, *session.VirtualKeyID)
-		if session.VirtualKey != nil && session.VirtualKey.Name != "" {
-			bifrostCtx.SetValue(schemas.BifrostContextKeyGovernanceVirtualKeyName, session.VirtualKey.Name)
-		}
+	if session.VirtualKeyID != nil && *session.VirtualKeyID != "" && session.VirtualKey != nil && session.VirtualKey.Value != "" {
+		bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, session.VirtualKey.Value)
 	}
 	if session.UserID != nil && *session.UserID != "" {
 		bifrostCtx.SetValue(schemas.BifrostContextKeyUserID, *session.UserID)
diff --git a/transports/bifrost-http/integrations/router.go b/transports/bifrost-http/integrations/router.go
index 9b6d37e84d..cb4372f559 100644
--- a/transports/bifrost-http/integrations/router.go
+++ b/transports/bifrost-http/integrations/router.go
@@ -52,6 +52,7 @@ import (
 	"context"
+	"errors"
 	"fmt"
 	"io"
 	"mime"
 	"mime/multipart"
 	"strconv"
@@ -1458,7 +1459,7 @@ func (g *GenericRouter) handleAsyncCreate(
 	// Reject streaming + async
 	if streamingReq, ok := req.(StreamingRequest); ok && streamingReq.IsStreamingRequested() {
 		g.sendError(ctx, bifrostCtx, config.ErrorConverter,
-			newBifrostError(nil, "streaming is not supported for async requests"))
+			newBifrostErrorWithCode(nil, "streaming is not supported for async requests", fasthttp.StatusBadRequest))
 		return
 	}
 
@@ -1538,8 +1539,13 @@ func (g *GenericRouter) handleAsyncRetrieve(
 
 	job, err := executor.RetrieveJob(bifrostCtx, jobID, vkValue, config.GetHTTPRequestType(ctx))
 	if err != nil {
-		g.sendError(ctx, bifrostCtx, config.ErrorConverter,
-			newBifrostError(err, "job not found or expired"))
+		if errors.Is(err, logstore.ErrJobInternal) {
+			g.sendError(ctx, bifrostCtx, config.ErrorConverter,
+				newBifrostErrorWithCode(err, "failed to retrieve async job", fasthttp.StatusInternalServerError))
+		} else {
+			g.sendError(ctx, bifrostCtx, config.ErrorConverter,
+				newBifrostErrorWithCode(err, "job not found or expired", fasthttp.StatusNotFound))
+		}
 		return
 	}
 
diff --git a/transports/bifrost-http/integrations/utils.go b/transports/bifrost-http/integrations/utils.go
index 8fda18b1cd..db11c61f77 100644
--- a/transports/bifrost-http/integrations/utils.go
+++ b/transports/bifrost-http/integrations/utils.go
@@ -30,6 +30,14 @@ var availableIntegrations = []string{
 	"cohere",
 }
 
+// newBifrostErrorWithCode is like newBifrostError but sets an explicit HTTP status code.
+// The code is stored by pointer in the StatusCode field of the error built by newBifrostError.
+func newBifrostErrorWithCode(err error, message string, statusCode int) *schemas.BifrostError {
+	e := newBifrostError(err, message)
+	e.StatusCode = &statusCode
+	return e
+}
+
 // newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false.
 // This helper function reduces code duplication when handling non-Bifrost errors.
func newBifrostError(err error, message string) *schemas.BifrostError { diff --git a/transports/changelog.md b/transports/changelog.md index dae1f060c1..c31f62172f 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -1 +1,5 @@ - fix: malformed OAuth authorization URL when base URL already contains query parameters +- fix: usage of per-user OAuth servers in codemode +- fix: adds support for OCR requests logging +- fix: adds validation on direct api keys +- fix: adds support for OCR request pricing diff --git a/ui/app/_fallbacks/enterprise/components/mcp-tool-groups/mcpToolGroups.tsx b/ui/app/_fallbacks/enterprise/components/mcp-tool-groups/mcpToolGroups.tsx index 05a95013dd..8e6dfc089d 100644 --- a/ui/app/_fallbacks/enterprise/components/mcp-tool-groups/mcpToolGroups.tsx +++ b/ui/app/_fallbacks/enterprise/components/mcp-tool-groups/mcpToolGroups.tsx @@ -1,16 +1,15 @@ -import { CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { ToolCase } from "lucide-react"; import ContactUsView from "../views/contactUsView"; export default function MCPToolGroups() { return ( <> - - -

MCP tool groups

-
- Configure tool groups for MCP servers to organize and govern tools. -
+
+
+

MCP tool groups

+

Configure tool groups for MCP servers to organize and govern tools.

+
+
+
); diff --git a/ui/app/workspace/config/views/mcpView.tsx b/ui/app/workspace/config/views/mcpView.tsx index e246c8aa39..551555a22b 100644 --- a/ui/app/workspace/config/views/mcpView.tsx +++ b/ui/app/workspace/config/views/mcpView.tsx @@ -147,7 +147,7 @@ export default function MCPView() { return (
@@ -158,7 +158,7 @@ export default function MCPView() {
{/* Max Agent Depth */} -
+
{/* Tool Execution Timeout */} -
+