diff --git a/.env.example b/.env.example index 48f38e6b2e..fad483428b 100644 --- a/.env.example +++ b/.env.example @@ -121,12 +121,32 @@ VITE_SHOW_DEVTOOLS=false PROD=false +# Cloud Provider Configuration (Optional - Can also be configured via UI) +# ⭐ RECOMMENDED: Configure these through Settings > Cloud Providers in the UI for easier management +# +# Azure OpenAI: +# AZURE_OPENAI_ENDPOINT: Your Azure OpenAI resource endpoint +# AZURE_OPENAI_API_KEY: Your Azure OpenAI API key +# AZURE_OPENAI_API_VERSION: API version (e.g., 2024-02-15-preview) +# AZURE_OPENAI_DEPLOYMENT: Your deployment name +# +# AWS Bedrock: +# AWS_ACCESS_KEY_ID: Your AWS IAM Access Key ID +# AWS_SECRET_ACCESS_KEY: Your AWS IAM Secret Access Key +# AWS_REGION: AWS region for Bedrock (e.g., us-east-1, us-west-2) +# AWS_BEDROCK_MODEL_ID: Bedrock model ID (e.g., anthropic.claude-3-sonnet-20240229-v1:0) +# +# Note: UI configuration is preferred as it provides encrypted storage and easier management + # NOTE: All other configuration has been moved to database management! # Run the credentials_setup.sql file in your Supabase SQL editor to set up the credentials table. # Then use the Settings page in the web UI to manage: -# - OPENAI_API_KEY (encrypted) +# - OPENAI_API_KEY (encrypted) - or use Azure OpenAI credentials above +# - AZURE_OPENAI_API_KEY (encrypted) - if using Azure OpenAI # - MODEL_CHOICE # - TRANSPORT settings +# - LLM_PROVIDER (set to "azure-openai" if using Azure OpenAI) +# - EMBEDDING_PROVIDER (set to "azure-openai" if using Azure OpenAI for embeddings) # - RAG strategy flags (USE_CONTEXTUAL_EMBEDDINGS, USE_HYBRID_SEARCH, etc.) # - Crawler settings: # * CRAWL_MAX_CONCURRENT (default: 10) - Max concurrent pages per crawl operation diff --git a/README.md b/README.md index e3353a2e2c..ad2b7902e1 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ This new vision for Archon replaces the old one (the agenteer). Archon used to b - [Docker Desktop](https://www.docker.com/products/docker-desktop/) - [Node.js 18+](https://nodejs.org/) (for hybrid development mode) - [Supabase](https://supabase.com/) account (free tier or local Supabase both work) -- [OpenAI API key](https://platform.openai.com/api-keys) (Gemini and Ollama are supported too!) +- [OpenAI API key](https://platform.openai.com/api-keys) (Azure OpenAI, Gemini, and Ollama are supported too!) - (OPTIONAL) [Make](https://www.gnu.org/software/make/) (see [Installing Make](#installing-make) below) ### Setup Instructions @@ -160,16 +160,16 @@ sudo yum install make 🚀 Quick Command Reference for Make
-| Command | Description | -| ----------------- | ------------------------------------------------------- | +| Command | Description | +| ----------------- | ------------------------------------------------------ | | `make dev` | Start hybrid dev (backend in Docker, frontend local) ⭐ | -| `make dev-docker` | Everything in Docker | -| `make stop` | Stop all services | -| `make test` | Run all tests | -| `make lint` | Run linters | -| `make install` | Install dependencies | -| `make check` | Check environment setup | -| `make clean` | Remove containers and volumes (with confirmation) | +| `make dev-docker` | Everything in Docker | +| `make stop` | Stop all services | +| `make test` | Run all tests | +| `make lint` | Run linters | +| `make install` | Install dependencies | +| `make check` | Check environment setup | +| `make clean` | Remove containers and volumes (with confirmation) | @@ -248,7 +248,7 @@ To upgrade Archon to the latest version: - **Model Context Protocol (MCP)**: Connect any MCP-compatible client (Claude Code, Cursor, even non-AI coding assistants like Claude Desktop) - **MCP Tools**: Comprehensive yet simple set of tools for RAG queries, task management, and project operations -- **Multi-LLM Support**: Works with OpenAI, Ollama, and Google Gemini models +- **Multi-LLM Support**: Works with OpenAI, Azure OpenAI, Ollama, and Google Gemini models - **RAG Strategies**: Hybrid search, contextual embeddings, and result reranking for optimal AI responses - **Real-time Streaming**: Live responses from AI agents with progress tracking diff --git a/archon-ui-main/src/components/settings/APIKeysSection.tsx b/archon-ui-main/src/components/settings/APIKeysSection.tsx index 0d92601448..16bd2e592a 100644 --- a/archon-ui-main/src/components/settings/APIKeysSection.tsx +++ b/archon-ui-main/src/components/settings/APIKeysSection.tsx @@ -41,20 +41,37 @@ export const APIKeysSection = () => { const loadCredentials = async () => { try { setLoading(true); - + // Load all credentials const allCredentials = await credentialsService.getAllCredentials(); - + // Filter to only show API keys (credentials that end with _KEY or _API) + // EXCLUDE cloud provider credentials (they have their own dedicated section) + const cloudProviderKeys = [ + 'AZURE_OPENAI_API_KEY', + 'AZURE_OPENAI_ENDPOINT', + 'AZURE_OPENAI_API_VERSION', + 'AZURE_OPENAI_DEPLOYMENT', + 'AWS_ACCESS_KEY_ID', + 'AWS_SECRET_ACCESS_KEY', + 'AWS_REGION', + 'AWS_BEDROCK_MODEL_ID' + ]; + const apiKeys = allCredentials.filter(cred => { const key = cred.key.toUpperCase(); + // Exclude cloud provider credentials + if (cloudProviderKeys.includes(key)) { + return false; + } + // Include credentials with _KEY, _API, or API_ in the name return key.includes('_KEY') || key.includes('_API') || key.includes('API_'); }); - + // Convert to UI format const uiCredentials = apiKeys.map(cred => { const isEncryptedFromBackend = cred.is_encrypted && cred.value === '[ENCRYPTED]'; - + return { key: cred.key, value: cred.value || '', @@ -68,7 +85,7 @@ export const APIKeysSection = () => { isFromBackend: !cred.isNew, // Mark as from backend unless it's a new credential }; }); - + setCustomCredentials(uiCredentials); } catch (err) { console.error('Failed to load credentials:', err); @@ -90,7 +107,7 @@ export const APIKeysSection = () => { isNew: true, isFromBackend: false // New credentials are not from backend }; - + setCustomCredentials([...customCredentials, newCred]); }; @@ -134,7 +151,7 @@ export const APIKeysSection = () => { const deleteCredential = async 
(index: number) => { const cred = customCredentials[index]; - + if (cred.isNew) { // Just remove from UI if it's not saved yet setCustomCredentials(customCredentials.filter((_, i) => i !== index)); @@ -153,7 +170,7 @@ export const APIKeysSection = () => { const saveAllChanges = async () => { setSaving(true); let hasErrors = false; - + for (const cred of customCredentials) { if (cred.hasChanges || cred.isNew) { if (!cred.key) { @@ -161,7 +178,7 @@ export const APIKeysSection = () => { hasErrors = true; continue; } - + try { if (cred.isNew) { await credentialsService.createCredential({ @@ -200,12 +217,12 @@ export const APIKeysSection = () => { } } } - + if (!hasErrors) { showToast('All changes saved successfully!', 'success'); await loadCredentials(); // Reload to get fresh data } - + setSaving(false); }; @@ -225,179 +242,177 @@ export const APIKeysSection = () => { return ( -
- {/* Description text */} -

- Manage your API keys and credentials for various services used by Archon. -

- - {/* Credentials list */} -
- {/* Header row */} -
-
Key Name
-
Value
-
-
+
+ {/* Description text */} +

+ Manage your API keys and credentials for various services used by Archon. +

+ + {/* Credentials list */} +
+ {/* Header row */} +
+
Key Name
+
Value
+
+
+ + {/* Credential rows */} + {customCredentials.map((cred, index) => ( +
+ {/* Key name column */} +
+ updateCredential(index, 'key', e.target.value)} + placeholder="Enter key name" + className="w-full px-3 py-2 rounded-md bg-white dark:bg-gray-900 border border-gray-300 dark:border-gray-700 text-sm font-mono" + /> +
- {/* Credential rows */} - {customCredentials.map((cred, index) => ( -
- {/* Key name column */} -
+ {/* Value column with encryption toggle */} +
+
updateCredential(index, 'key', e.target.value)} - placeholder="Enter key name" - className="w-full px-3 py-2 rounded-md bg-white dark:bg-gray-900 border border-gray-300 dark:border-gray-700 text-sm font-mono" + type={cred.showValue ? 'text' : 'password'} + value={cred.value} + onChange={(e) => updateCredential(index, 'value', e.target.value)} + placeholder={cred.is_encrypted && !cred.value ? 'Enter new value (encrypted)' : 'Enter value'} + className={`w-full px-3 py-2 pr-20 rounded-md border text-sm ${cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]' + ? 'bg-gray-100 dark:bg-gray-800 border-gray-200 dark:border-gray-600 text-gray-500 dark:text-gray-400' + : 'bg-white dark:bg-gray-900 border-gray-300 dark:border-gray-700' + }`} + title={cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]' + ? 'Click to edit this encrypted credential' + : undefined} /> -
- {/* Value column with encryption toggle */} -
-
- updateCredential(index, 'value', e.target.value)} - placeholder={cred.is_encrypted && !cred.value ? 'Enter new value (encrypted)' : 'Enter value'} - className={`w-full px-3 py-2 pr-20 rounded-md border text-sm ${ - cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]' - ? 'bg-gray-100 dark:bg-gray-800 border-gray-200 dark:border-gray-600 text-gray-500 dark:text-gray-400' - : 'bg-white dark:bg-gray-900 border-gray-300 dark:border-gray-700' - }`} - title={cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]' - ? 'Click to edit this encrypted credential' - : undefined} - /> - - {/* Show/Hide value button */} - - - {/* Encryption toggle */} - + + {/* Encryption toggle */} + -
-
- - {/* Actions column */} -
-
- ))} -
- {/* Add credential button */} -
+ {/* Actions column */} +
+ +
+
+ ))} +
+ + {/* Add credential button */} +
+ +
+ + {/* Save all changes button */} + {hasUnsavedChanges && ( +
+
+ )} - {/* Save all changes button */} - {hasUnsavedChanges && ( -
- - -
- )} - - {/* Security Notice */} -
-
- -
-
-

- Encrypted credentials are masked after saving. Click on a masked credential to edit it - this allows you to change the value and encryption settings. -

-
+ {/* Security Notice */} +
+
+ +
+
+

+ Encrypted credentials are masked after saving. Click on a masked credential to edit it - this allows you to change the value and encryption settings. +

- +
+ ); }; \ No newline at end of file diff --git a/archon-ui-main/src/components/settings/CloudProvidersSection.tsx b/archon-ui-main/src/components/settings/CloudProvidersSection.tsx new file mode 100644 index 0000000000..63f98ad1c1 --- /dev/null +++ b/archon-ui-main/src/components/settings/CloudProvidersSection.tsx @@ -0,0 +1,463 @@ +import { useState, useEffect } from 'react'; +import { Cloud, Lock, Unlock, Eye, EyeOff, Save, Loader, AlertCircle } from 'lucide-react'; +import { Input } from '../ui/Input'; +import { Button } from '../ui/Button'; +import { Card } from '../ui/Card'; +import { credentialsService } from '../../services/credentialsService'; +import { useToast } from '../../features/shared/hooks/useToast'; + +// Cloud provider configurations +const CLOUD_PROVIDERS = { + azure: { + name: 'Azure OpenAI', + icon: ( + + + + ), + color: 'blue', + credentials: [ + { + key: 'AZURE_OPENAI_API_KEY', + label: 'API Key', + placeholder: 'Enter your Azure OpenAI API key', + type: 'password', + encrypted: true, + required: true, + description: 'Your Azure OpenAI resource API key from Azure Portal' + }, + { + key: 'AZURE_OPENAI_ENDPOINT', + label: 'Endpoint', + placeholder: 'https://your-resource.openai.azure.com/', + type: 'text', + encrypted: false, + required: true, + description: 'Your Azure OpenAI resource endpoint URL' + }, + { + key: 'AZURE_OPENAI_API_VERSION', + label: 'API Version', + placeholder: '2024-02-15-preview', + type: 'text', + encrypted: false, + required: true, + description: 'API version (YYYY-MM-DD format)' + }, + { + key: 'AZURE_OPENAI_DEPLOYMENT', + label: 'Deployment Name', + placeholder: 'my-gpt4-deployment', + type: 'text', + encrypted: false, + required: false, + description: 'Default deployment name (optional, can be set per operation)' + } + ] + }, + aws: { + name: 'AWS Bedrock', + icon: ( + + + + ), + color: 'orange', + credentials: [ + { + key: 'AWS_ACCESS_KEY_ID', + label: 'Access Key ID', + placeholder: 'Enter your AWS Access Key ID', + type: 'password', + encrypted: true, + required: true, + description: 'AWS IAM Access Key ID' + }, + { + key: 'AWS_SECRET_ACCESS_KEY', + label: 'Secret Access Key', + placeholder: 'Enter your AWS Secret Access Key', + type: 'password', + encrypted: true, + required: true, + description: 'AWS IAM Secret Access Key' + }, + { + key: 'AWS_REGION', + label: 'Region', + placeholder: 'us-east-1', + type: 'text', + encrypted: false, + required: true, + description: 'AWS region for Bedrock (e.g., us-east-1, us-west-2)' + }, + { + key: 'AWS_BEDROCK_MODEL_ID', + label: 'Model ID', + placeholder: 'anthropic.claude-3-sonnet-20240229-v1:0', + type: 'text', + encrypted: false, + required: false, + description: 'Default Bedrock model ID (optional)' + } + ] + } +}; + +interface CredentialValue { + value: string; + showValue: boolean; + hasChanges: boolean; + originalValue: string; + isFromBackend: boolean; +} + +type ProviderKey = keyof typeof CLOUD_PROVIDERS; + +export const CloudProvidersSection = () => { + const [selectedProvider, setSelectedProvider] = useState('azure'); + const [credentialValues, setCredentialValues] = useState>({}); + const [loading, setLoading] = useState(true); + const [saving, setSaving] = useState(false); + const [hasUnsavedChanges, setHasUnsavedChanges] = useState(false); + + const { showToast } = useToast(); + + useEffect(() => { + loadCredentials(); + }, []); + + useEffect(() => { + // Check if there are unsaved changes + const hasChanges = Object.values(credentialValues).some(cred => cred.hasChanges); + 
setHasUnsavedChanges(hasChanges); + }, [credentialValues]); + + const loadCredentials = async () => { + try { + setLoading(true); + + // Load all cloud provider credentials + const allCredentials = await credentialsService.getAllCredentials(); + + const values: Record = {}; + + // Initialize all credential values from backend + Object.values(CLOUD_PROVIDERS).forEach(provider => { + provider.credentials.forEach(cred => { + const backendCred = allCredentials.find(c => c.key === cred.key); + const isEncryptedFromBackend = backendCred?.is_encrypted && backendCred.value === '[ENCRYPTED]'; + + values[cred.key] = { + value: backendCred?.value || '', + showValue: false, + hasChanges: false, + originalValue: backendCred?.value || '', + isFromBackend: !!backendCred && !backendCred.isNew + }; + }); + }); + + setCredentialValues(values); + } catch (err) { + console.error('Failed to load cloud credentials:', err); + showToast('Failed to load cloud provider credentials', 'error'); + } finally { + setLoading(false); + } + }; + + const updateCredentialValue = (key: string, value: string) => { + setCredentialValues(prev => { + const current = prev[key]; + const updated = { + ...current, + value, + hasChanges: true + }; + + // If editing an encrypted credential from backend, make it editable + if (current.isFromBackend && current.value === '[ENCRYPTED]' && value !== '[ENCRYPTED]') { + updated.isFromBackend = false; + updated.showValue = false; + if (value === '') { + // If they click to edit but haven't entered anything, clear the placeholder + updated.value = ''; + } + } + + return { + ...prev, + [key]: updated + }; + }); + }; + + const toggleValueVisibility = (key: string) => { + const cred = credentialValues[key]; + if (cred.isFromBackend && cred.value === '[ENCRYPTED]') { + showToast('Encrypted credentials cannot be viewed. 
Edit to make changes.', 'warning'); + return; + } + + setCredentialValues(prev => ({ + ...prev, + [key]: { + ...prev[key], + showValue: !prev[key].showValue + } + })); + }; + + const saveChanges = async () => { + setSaving(true); + let hasErrors = false; + const provider = CLOUD_PROVIDERS[selectedProvider]; + + try { + for (const credConfig of provider.credentials) { + const credValue = credentialValues[credConfig.key]; + + if (!credValue.hasChanges) continue; + + // Validate required fields + if (credConfig.required && !credValue.value) { + showToast(`${credConfig.label} is required`, 'error'); + hasErrors = true; + continue; + } + + // Skip if value is still [ENCRYPTED] (unchanged) + if (credValue.value === '[ENCRYPTED]') { + continue; + } + + try { + // Check if credential exists + const existing = await credentialsService.getCredential(credConfig.key).catch(() => null); + + if (existing) { + // Update existing + await credentialsService.updateCredential({ + key: credConfig.key, + value: credValue.value, + is_encrypted: credConfig.encrypted, + category: 'cloud_providers', + description: credConfig.description + }); + } else { + // Create new + await credentialsService.createCredential({ + key: credConfig.key, + value: credValue.value, + is_encrypted: credConfig.encrypted, + category: 'cloud_providers', + description: credConfig.description + }); + } + } catch (err) { + console.error(`Failed to save ${credConfig.key}:`, err); + showToast(`Failed to save ${credConfig.label}`, 'error'); + hasErrors = true; + } + } + + if (!hasErrors) { + showToast(`${provider.name} credentials saved successfully!`, 'success'); + await loadCredentials(); // Reload to get fresh data + } + } finally { + setSaving(false); + } + }; + + const discardChanges = () => { + loadCredentials(); + }; + + if (loading) { + return ( + +
+ +
+
+ ); + } + + const currentProvider = CLOUD_PROVIDERS[selectedProvider]; + + return ( + +
+ {/* Description */} +
+ +
+

+ Configure cloud-based AI services. These providers offer enterprise-grade security, compliance, and regional availability. +

+
+
+ + {/* Provider Selector */} +
+ {(Object.keys(CLOUD_PROVIDERS) as ProviderKey[]).map(providerKey => { + const provider = CLOUD_PROVIDERS[providerKey]; + const isSelected = selectedProvider === providerKey; + + return ( + + ); + })} +
+ + {/* Credentials Form */} +
+ {currentProvider.credentials.map(credConfig => { + const credValue = credentialValues[credConfig.key] || { + value: '', + showValue: false, + hasChanges: false, + originalValue: '', + isFromBackend: false + }; + + const isEncrypted = credConfig.encrypted; + const isFromBackend = credValue.isFromBackend && credValue.value === '[ENCRYPTED]'; + + return ( +
+ + +
+ updateCredentialValue(credConfig.key, e.target.value)} + placeholder={isFromBackend ? 'Click to edit encrypted value' : credConfig.placeholder} + className={` + w-full px-4 py-2 pr-24 rounded-md border text-sm + ${isFromBackend + ? 'bg-gray-100 dark:bg-gray-800 border-gray-200 dark:border-gray-600 text-gray-500 dark:text-gray-400' + : 'bg-white dark:bg-gray-900 border-gray-300 dark:border-gray-700' + } + focus:outline-none focus:ring-2 focus:ring-${currentProvider.color}-500 focus:border-transparent + `} + title={isFromBackend ? 'Click to edit this encrypted credential' : undefined} + /> + + {/* Show/Hide button for password fields */} + {credConfig.type === 'password' && ( + + )} + + {/* Encryption indicator */} + {isEncrypted && ( +
+ +
+ )} +
+ + {credConfig.description && ( +

+ {credConfig.description} +

+ )} +
+ ); + })} +
+ + {/* Save Button */} + {hasUnsavedChanges && ( +
+ + +
+ )} + + {/* Info Box */} +
+ +
+

Cloud Provider Setup:

+
    +
  • Encrypted credentials are masked after saving
  +
  • Set LLM_PROVIDER to 'azure-openai' or 'aws-bedrock' in RAG Settings to use these providers
  +
  • Regional availability and pricing may vary by provider
  +
+
+
+
+
+ ); +}; diff --git a/archon-ui-main/src/pages/SettingsPage.tsx b/archon-ui-main/src/pages/SettingsPage.tsx index 351366161d..f93b83579e 100644 --- a/archon-ui-main/src/pages/SettingsPage.tsx +++ b/archon-ui-main/src/pages/SettingsPage.tsx @@ -12,6 +12,7 @@ import { Bug, Info, Database, + Cloud, } from "lucide-react"; import { motion, AnimatePresence } from "framer-motion"; import { useToast } from "../features/shared/hooks/useToast"; @@ -19,6 +20,7 @@ import { useSettings } from "../contexts/SettingsContext"; import { useStaggeredEntrance } from "../hooks/useStaggeredEntrance"; import { FeaturesSection } from "../components/settings/FeaturesSection"; import { APIKeysSection } from "../components/settings/APIKeysSection"; +import { CloudProvidersSection } from "../components/settings/CloudProvidersSection"; import { RAGSettings } from "../components/settings/RAGSettings"; import { CodeExtractionSettings } from "../components/settings/CodeExtractionSettings"; import { IDEGlobalRules } from "../components/settings/IDEGlobalRules"; @@ -199,6 +201,17 @@ export const SettingsPage = () => { + + + + + Cloud Providers in the UI +INSERT INTO archon_settings (key, encrypted_value, is_encrypted, category, description) VALUES +('AZURE_OPENAI_API_KEY', NULL, true, 'cloud_providers', 'Azure OpenAI API key from Azure Portal. Configure via Settings > Cloud Providers.'), +('AZURE_OPENAI_ENDPOINT', NULL, false, 'cloud_providers', 'Azure OpenAI resource endpoint URL (e.g., https://your-resource.openai.azure.com/)'), +('AZURE_OPENAI_API_VERSION', NULL, false, 'cloud_providers', 'Azure OpenAI API version in YYYY-MM-DD format (e.g., 2024-02-15-preview)'), +('AZURE_OPENAI_DEPLOYMENT', NULL, false, 'cloud_providers', 'Azure OpenAI deployment name (optional, used for default model)'), +('AWS_ACCESS_KEY_ID', NULL, true, 'cloud_providers', 'AWS IAM Access Key ID for Bedrock. Configure via Settings > Cloud Providers.'), +('AWS_SECRET_ACCESS_KEY', NULL, true, 'cloud_providers', 'AWS IAM Secret Access Key for Bedrock. Configure via Settings > Cloud Providers.'), +('AWS_REGION', NULL, false, 'cloud_providers', 'AWS region for Bedrock service (e.g., us-east-1, us-west-2)'), +('AWS_BEDROCK_MODEL_ID', NULL, false, 'cloud_providers', 'Default AWS Bedrock model ID (e.g., anthropic.claude-3-sonnet-20240229-v1:0)') +ON CONFLICT (key) DO NOTHING; + -- Code Extraction Settings Migration -- Adds configurable settings for the code extraction service diff --git a/python/pyproject.toml b/python/pyproject.toml index 128e433290..694d0cbf93 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -45,6 +45,7 @@ server = [ "asyncpg>=0.29.0", # AI/ML libraries "openai==1.71.0", + "boto3>=1.34.0", # AWS SDK for Bedrock # Document processing "pypdf2>=3.0.1", "pdfplumber>=0.11.6", @@ -123,6 +124,7 @@ all = [ "supabase==2.15.1", "asyncpg>=0.29.0", "openai==1.71.0", + "boto3>=1.34.0", "pypdf2>=3.0.1", "pdfplumber>=0.11.6", "python-docx>=1.1.2", diff --git a/python/src/server/adapters/__init__.py b/python/src/server/adapters/__init__.py new file mode 100644 index 0000000000..247e00de25 --- /dev/null +++ b/python/src/server/adapters/__init__.py @@ -0,0 +1,7 @@ +""" +Adapters for external LLM providers that don't have OpenAI-compatible APIs. 
+""" + +from .aws_bedrock_adapter import AWSBedrockClientAdapter + +__all__ = ["AWSBedrockClientAdapter"] diff --git a/python/src/server/adapters/aws_bedrock_adapter.py b/python/src/server/adapters/aws_bedrock_adapter.py new file mode 100644 index 0000000000..a30411da60 --- /dev/null +++ b/python/src/server/adapters/aws_bedrock_adapter.py @@ -0,0 +1,374 @@ +""" +AWS Bedrock Client Adapter + +Provides an OpenAI-compatible interface for AWS Bedrock's Converse API. +This adapter wraps the boto3 bedrock-runtime client to work with our existing +LLM provider infrastructure. +""" + +import asyncio +import json +from typing import Any + +from ..config.logfire_config import get_logger + +logger = get_logger(__name__) + + +class AWSBedrockClientAdapter: + """ + Adapter to make AWS Bedrock Converse API compatible with OpenAI-style async clients. + + This adapter implements the minimum interface needed for our LLM operations, + translating between OpenAI's chat completion format and AWS Bedrock's Converse API. + """ + + def __init__(self, bedrock_client: Any, region: str): + """ + Initialize the AWS Bedrock adapter. + + Args: + bedrock_client: boto3 bedrock-runtime client instance + region: AWS region for the Bedrock service + """ + self.bedrock_client = bedrock_client + self.region = region + self._executor = None # Will be created on first use + + def _get_executor(self): + """Get or create thread pool executor for async operations.""" + if self._executor is None: + import concurrent.futures + + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=4) + return self._executor + + async def aclose(self): + """Close the adapter and cleanup resources.""" + if self._executor is not None: + self._executor.shutdown(wait=True) + self._executor = None + logger.debug("AWS Bedrock adapter closed") + + async def close(self): + """Alias for aclose() for compatibility.""" + await self.aclose() + + @property + def chat(self): + """Return chat completions interface.""" + return self + + @property + def completions(self): + """Return completions interface.""" + return ChatCompletions(self) + + def _openai_to_bedrock_messages(self, openai_messages: list[dict]) -> list[dict]: + """ + Convert OpenAI message format to Bedrock Converse API format. + + OpenAI format: + [{"role": "system", "content": "..."}, + {"role": "user", "content": "..."}, + {"role": "assistant", "content": "..."}] + + Bedrock format: + [{"role": "user", "content": [{"text": "..."}]}, + {"role": "assistant", "content": [{"text": "..."}]}] + + Note: Bedrock handles system prompts separately, not in messages. 
+ """ + bedrock_messages = [] + system_prompt = None + + for msg in openai_messages: + role = msg.get("role", "user") + content = msg.get("content", "") + + if role == "system": + # Bedrock handles system prompts separately + system_prompt = content + continue + + # Convert role (Bedrock uses "user" and "assistant") + bedrock_role = "assistant" if role == "assistant" else "user" + + # Convert content to Bedrock format + if isinstance(content, str): + bedrock_content = [{"text": content}] + elif isinstance(content, list): + # Handle multimodal content if needed + bedrock_content = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + bedrock_content.append({"text": item.get("text", "")}) + # Add support for images if needed in future + elif isinstance(item, str): + bedrock_content.append({"text": item}) + else: + bedrock_content = [{"text": str(content)}] + + bedrock_messages.append({"role": bedrock_role, "content": bedrock_content}) + + return bedrock_messages, system_prompt + + def _bedrock_to_openai_response(self, bedrock_response: dict, model: str) -> dict: + """ + Convert Bedrock Converse API response to OpenAI format. + + Bedrock response: + { + "output": { + "message": { + "role": "assistant", + "content": [{"text": "..."}] + } + }, + "stopReason": "end_turn", + "usage": { + "inputTokens": 10, + "outputTokens": 20, + "totalTokens": 30 + } + } + + OpenAI format: + { + "id": "chatcmpl-xxx", + "object": "chat.completion", + "created": 1234567890, + "model": "model-name", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "..." + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30 + } + } + """ + import time + import uuid + + # Extract message content + output_message = bedrock_response.get("output", {}).get("message", {}) + content_blocks = output_message.get("content", []) + + # Combine text blocks + content = " ".join( + block.get("text", "") for block in content_blocks if "text" in block + ) + + # Map stop reason + stop_reason_map = { + "end_turn": "stop", + "max_tokens": "length", + "stop_sequence": "stop", + "content_filtered": "content_filter", + } + finish_reason = stop_reason_map.get( + bedrock_response.get("stopReason", "end_turn"), "stop" + ) + + # Extract usage + bedrock_usage = bedrock_response.get("usage", {}) + + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:24]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": finish_reason, + } + ], + "usage": { + "prompt_tokens": bedrock_usage.get("inputTokens", 0), + "completion_tokens": bedrock_usage.get("outputTokens", 0), + "total_tokens": bedrock_usage.get("totalTokens", 0), + }, + } + + async def create( + self, + model: str, + messages: list[dict], + temperature: float = 1.0, + max_tokens: int | None = None, + max_completion_tokens: int | None = None, + stream: bool = False, + **kwargs: Any, + ) -> dict: + """ + Create a chat completion using AWS Bedrock Converse API. 
+ + Args: + model: Bedrock model ID (e.g., "anthropic.claude-3-sonnet-20240229-v1:0") + messages: List of messages in OpenAI format + temperature: Sampling temperature (0.0 to 1.0) + max_tokens: Maximum tokens to generate (OpenAI style) + max_completion_tokens: Maximum tokens to generate (new OpenAI style) + stream: Whether to stream the response (not supported yet) + **kwargs: Additional parameters + + Returns: + Chat completion response in OpenAI format + """ + if stream: + raise NotImplementedError( + "Streaming is not yet supported for AWS Bedrock adapter" + ) + + # Convert messages to Bedrock format + bedrock_messages, system_prompt = self._openai_to_bedrock_messages(messages) + + # Determine max tokens (prefer max_completion_tokens if provided) + max_tokens_value = max_completion_tokens or max_tokens or 2048 + + # Build inference configuration + inference_config = { + "temperature": temperature, + "maxTokens": max_tokens_value, + } + + # Add top_p if provided + if "top_p" in kwargs: + inference_config["topP"] = kwargs["top_p"] + + # Build converse request + converse_params = { + "modelId": model, + "messages": bedrock_messages, + "inferenceConfig": inference_config, + } + + # Add system prompt if present + if system_prompt: + converse_params["system"] = [{"text": system_prompt}] + + try: + # Call Bedrock Converse API asynchronously + loop = asyncio.get_event_loop() + executor = self._get_executor() + + bedrock_response = await loop.run_in_executor( + executor, lambda: self.bedrock_client.converse(**converse_params) + ) + + # Convert response to OpenAI format + openai_response = self._bedrock_to_openai_response(bedrock_response, model) + + logger.debug( + f"AWS Bedrock completion successful. Tokens used: {openai_response['usage']['total_tokens']}" + ) + + return openai_response + + except Exception as e: + logger.error(f"Error calling AWS Bedrock Converse API: {e}") + raise + + +class ChatCompletions: + """Chat completions interface wrapper.""" + + def __init__(self, adapter: AWSBedrockClientAdapter): + self.adapter = adapter + + async def create(self, *args, **kwargs): + """Create a chat completion.""" + return await self.adapter.create(*args, **kwargs) + + +# Bedrock embedding adapter for future use +class AWSBedrockEmbeddingAdapter: + """ + Adapter for AWS Bedrock embeddings. + + AWS Bedrock supports embedding models like Amazon Titan Embeddings. + This adapter will be used by the embedding service. + """ + + def __init__(self, bedrock_client: Any, region: str): + """ + Initialize the AWS Bedrock embedding adapter. + + Args: + bedrock_client: boto3 bedrock-runtime client instance + region: AWS region for the Bedrock service + """ + self.bedrock_client = bedrock_client + self.region = region + self._executor = None + + def _get_executor(self): + """Get or create thread pool executor for async operations.""" + if self._executor is None: + import concurrent.futures + + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=4) + return self._executor + + async def aclose(self): + """Close the adapter and cleanup resources.""" + if self._executor is not None: + self._executor.shutdown(wait=True) + self._executor = None + logger.debug("AWS Bedrock embedding adapter closed") + + async def generate_embeddings( + self, texts: list[str], model: str = "amazon.titan-embed-text-v1" + ) -> list[list[float]]: + """ + Generate embeddings using AWS Bedrock. 
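+
+        Note: each text is embedded with its own invoke_model call, since the
+        Titan embedding request body carries a single inputText per request.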
+ + Args: + texts: List of text strings to embed + model: Bedrock embedding model ID + + Returns: + List of embedding vectors + """ + embeddings = [] + + try: + loop = asyncio.get_event_loop() + executor = self._get_executor() + + for text in texts: + # Prepare request body for Titan Embeddings + request_body = json.dumps({"inputText": text}) + + # Call invoke_model asynchronously (bind request_body in lambda) + response = await loop.run_in_executor( + executor, + lambda body=request_body: self.bedrock_client.invoke_model( + modelId=model, + contentType="application/json", + accept="application/json", + body=body, + ), + ) + + # Parse response + response_body = json.loads(response["body"].read()) + embedding = response_body.get("embedding", []) + embeddings.append(embedding) + + logger.debug(f"Generated {len(embeddings)} embeddings using AWS Bedrock") + return embeddings + + except Exception as e: + logger.error(f"Error generating embeddings with AWS Bedrock: {e}") + raise diff --git a/python/src/server/api_routes/knowledge_api.py b/python/src/server/api_routes/knowledge_api.py index 052f75216e..620b7c2350 100644 --- a/python/src/server/api_routes/knowledge_api.py +++ b/python/src/server/api_routes/knowledge_api.py @@ -26,7 +26,11 @@ from ..services.crawling import CrawlingService from ..services.credential_service import credential_service from ..services.embeddings.provider_error_adapters import ProviderErrorFactory -from ..services.knowledge import DatabaseMetricsService, KnowledgeItemService, KnowledgeSummaryService +from ..services.knowledge import ( + DatabaseMetricsService, + KnowledgeItemService, + KnowledgeSummaryService, +) from ..services.search.rag_service import RAGService from ..services.storage import DocumentStorageService from ..utils import get_supabase_client @@ -50,39 +54,50 @@ # # The hardcoded limit of 3 protects the server from being overwhelmed by multiple users # starting crawls at the same time. Each crawl can still process many pages in parallel. -CONCURRENT_CRAWL_LIMIT = 3 # Max simultaneous crawl operations (protects server resources) +CONCURRENT_CRAWL_LIMIT = ( + 3 # Max simultaneous crawl operations (protects server resources) +) crawl_semaphore = asyncio.Semaphore(CONCURRENT_CRAWL_LIMIT) # Track active async crawl tasks for cancellation support active_crawl_tasks: dict[str, asyncio.Task] = {} - - async def _validate_provider_api_key(provider: str = None) -> None: """Validate LLM provider API key before starting operations.""" logger.info("🔑 Starting API key validation...") - + try: # Basic provider validation if not provider: provider = "openai" else: - # Simple provider validation - allowed_providers = {"openai", "ollama", "google", "openrouter", "anthropic", "grok"} + # Simple provider validation - include cloud providers + allowed_providers = { + "openai", + "azure-openai", # Azure OpenAI cloud provider + "aws-bedrock", # AWS Bedrock cloud provider + "ollama", + "google", + "openrouter", + "anthropic", + "grok", + } if provider not in allowed_providers: raise HTTPException( status_code=400, detail={ "error": "Invalid provider name", "message": f"Provider '{provider}' not supported", - "error_type": "validation_error" - } + "error_type": "validation_error", + }, ) # Basic sanitization for logging safe_provider = provider[:20] # Limit length - logger.info(f"🔑 Testing {safe_provider.title()} API key with minimal embedding request...") + logger.info( + f"🔑 Testing {safe_provider.title()} API key with minimal embedding request..." 
+ ) try: # Test API key with minimal embedding request using provider-scoped configuration @@ -117,7 +132,7 @@ async def _validate_provider_api_key(provider: str = None) -> None: "provider": provider, }, ) - + logger.info(f"✅ {provider.title()} API key validation successful") except HTTPException: @@ -127,9 +142,13 @@ async def _validate_provider_api_key(provider: str = None) -> None: except Exception as e: # Sanitize error before logging to prevent sensitive data exposure error_str = str(e) - sanitized_error = ProviderErrorFactory.sanitize_provider_error(error_str, provider or "openai") - logger.error(f"❌ Caught exception during API key validation: {sanitized_error}") - + sanitized_error = ProviderErrorFactory.sanitize_provider_error( + error_str, provider or "openai" + ) + logger.error( + f"❌ Caught exception during API key validation: {sanitized_error}" + ) + # Always fail for any exception during validation - better safe than sorry logger.error("🚨 API key validation failed - blocking crawl operation") raise HTTPException( @@ -138,8 +157,8 @@ async def _validate_provider_api_key(provider: str = None) -> None: "error": "Invalid API key", "message": f"Please verify your {(provider or 'openai').title()} API key in Settings before starting a crawl.", "error_type": "authentication_failed", - "provider": provider or "openai" - } + "provider": provider or "openai", + }, ) from None @@ -183,7 +202,7 @@ class RagQueryRequest(BaseModel): @router.get("/crawl-progress/{progress_id}") async def get_crawl_progress(progress_id: str): """Get crawl progress for polling. - + Returns the current state of a crawl operation. Frontend should poll this endpoint to track crawl progress. """ @@ -193,11 +212,16 @@ async def get_crawl_progress(progress_id: str): # Get progress from the tracker's in-memory storage progress_data = ProgressTracker.get_progress(progress_id) - safe_logfire_info(f"Crawl progress requested | progress_id={progress_id} | found={progress_data is not None}") + safe_logfire_info( + f"Crawl progress requested | progress_id={progress_id} | found={progress_data is not None}" + ) if not progress_data: # Return 404 if no progress exists - this is correct behavior - raise HTTPException(status_code=404, detail={"error": f"No progress found for ID: {progress_id}"}) + raise HTTPException( + status_code=404, + detail={"error": f"No progress found for ID: {progress_id}"}, + ) # Ensure we have the progress_id in the data progress_data["progress_id"] = progress_id @@ -219,7 +243,9 @@ async def get_crawl_progress(progress_id: str): return response_data except Exception as e: - safe_logfire_error(f"Failed to get crawl progress | error={str(e)} | progress_id={progress_id}") + safe_logfire_error( + f"Failed to get crawl progress | error={str(e)} | progress_id={progress_id}" + ) raise HTTPException(status_code=500, detail={"error": str(e)}) @@ -237,7 +263,10 @@ async def get_knowledge_sources(): @router.get("/knowledge-items") async def get_knowledge_items( - page: int = 1, per_page: int = 20, knowledge_type: str | None = None, search: str | None = None + page: int = 1, + per_page: int = 20, + knowledge_type: str | None = None, + search: str | None = None, ): """Get knowledge items with pagination and filtering.""" try: @@ -257,16 +286,19 @@ async def get_knowledge_items( @router.get("/knowledge-items/summary") async def get_knowledge_items_summary( - page: int = 1, per_page: int = 20, knowledge_type: str | None = None, search: str | None = None + page: int = 1, + per_page: int = 20, + knowledge_type: str | 
None = None, + search: str | None = None, ): """ Get lightweight summaries of knowledge items. - + Returns minimal data optimized for frequent polling: - Only counts, no actual document/code content - Basic metadata for display - Efficient batch queries - + Use this endpoint for card displays and frequent polling. """ try: @@ -298,9 +330,13 @@ async def update_knowledge_item(source_id: str, updates: dict): return result else: if "not found" in result.get("error", "").lower(): - raise HTTPException(status_code=404, detail={"error": result.get("error")}) + raise HTTPException( + status_code=404, detail={"error": result.get("error")} + ) else: - raise HTTPException(status_code=500, detail={"error": result.get("error")}) + raise HTTPException( + status_code=500, detail={"error": result.get("error")} + ) except HTTPException: raise @@ -337,15 +373,21 @@ async def delete_knowledge_item(source_id: str): } if result.get("success"): - safe_logfire_info(f"Knowledge item deleted successfully | source_id={source_id}") + safe_logfire_info( + f"Knowledge item deleted successfully | source_id={source_id}" + ) - return {"success": True, "message": f"Successfully deleted knowledge item {source_id}"} + return { + "success": True, + "message": f"Successfully deleted knowledge item {source_id}", + } else: safe_logfire_error( f"Knowledge item deletion failed | source_id={source_id} | error={result.get('error')}" ) raise HTTPException( - status_code=500, detail={"error": result.get("error", "Deletion failed")} + status_code=500, + detail={"error": result.get("error", "Deletion failed")}, ) except Exception as e: @@ -362,28 +404,25 @@ async def delete_knowledge_item(source_id: str): @router.get("/knowledge-items/{source_id}/chunks") async def get_knowledge_item_chunks( - source_id: str, - domain_filter: str | None = None, - limit: int = 20, - offset: int = 0 + source_id: str, domain_filter: str | None = None, limit: int = 20, offset: int = 0 ): """ Get document chunks for a specific knowledge item with pagination. 
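+    Note: limit is silently clamped to the range 1-100 and offset floored at 0.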
- + Args: source_id: The source ID domain_filter: Optional domain filter for URLs limit: Maximum number of chunks to return (default 20, max 100) offset: Number of chunks to skip (for pagination) - + Returns: Paginated chunks with metadata """ try: # Validate pagination parameters limit = min(limit, 100) # Cap at 100 to prevent excessive data transfer - limit = max(limit, 1) # At least 1 - offset = max(offset, 0) # Can't be negative + limit = max(limit, 1) # At least 1 + offset = max(offset, 0) # Can't be negative safe_logfire_info( f"Fetching chunks | source_id={source_id} | domain_filter={domain_filter} | " @@ -468,10 +507,17 @@ async def get_knowledge_item_chunks( for line in lines: line = line.strip() # Skip code blocks, empty lines, and very short lines - if (line and not line.startswith("```") and not line.startswith("Source:") - and len(line) > 15 and len(line) < 80 - and not line.startswith("from ") and not line.startswith("import ") - and "=" not in line and "{" not in line): + if ( + line + and not line.startswith("```") + and not line.startswith("Source:") + and len(line) > 15 + and len(line) < 80 + and not line.startswith("from ") + and not line.startswith("import ") + and "=" not in line + and "{" not in line + ): title = line break @@ -481,17 +527,31 @@ async def get_knowledge_item_chunks( if url: # Extract meaningful part from URL if url.endswith(".txt"): - title = url.split("/")[-1].replace(".txt", "").replace("-", " ").title() + title = ( + url.split("/")[-1] + .replace(".txt", "") + .replace("-", " ") + .title() + ) else: # Get domain and path info parsed = urlparse(url) if parsed.path and parsed.path != "/": - title = parsed.path.strip("/").replace("-", " ").replace("_", " ").title() + title = ( + parsed.path.strip("/") + .replace("-", " ") + .replace("_", " ") + .title() + ) else: title = parsed.netloc.replace("www.", "").title() chunk["title"] = title or "" - chunk["section"] = metadata.get("headers", "").replace(";", " > ") if metadata.get("headers") else None + chunk["section"] = ( + metadata.get("headers", "").replace(";", " > ") + if metadata.get("headers") + else None + ) chunk["source_type"] = metadata.get("source_type") chunk["knowledge_type"] = metadata.get("knowledge_type") @@ -521,26 +581,24 @@ async def get_knowledge_item_chunks( @router.get("/knowledge-items/{source_id}/code-examples") async def get_knowledge_item_code_examples( - source_id: str, - limit: int = 20, - offset: int = 0 + source_id: str, limit: int = 20, offset: int = 0 ): """ Get code examples for a specific knowledge item with pagination. 
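+    Note: limit is silently clamped to the range 1-100 and offset floored at 0.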
- + Args: source_id: The source ID limit: Maximum number of examples to return (default 20, max 100) offset: Number of examples to skip (for pagination) - + Returns: Paginated code examples with metadata """ try: # Validate pagination parameters limit = min(limit, 100) # Cap at 100 to prevent excessive data transfer - limit = max(limit, 1) # At least 1 - offset = max(offset, 0) # Can't be negative + limit = max(limit, 1) # At least 1 + offset = max(offset, 0) # Can't be negative safe_logfire_info( f"Fetching code examples | source_id={source_id} | limit={limit} | offset={offset}" @@ -582,9 +640,13 @@ async def get_knowledge_item_code_examples( metadata = example.get("metadata", {}) or {} # Extract fields to match frontend TypeScript types example["title"] = metadata.get("title") # AI-generated title - example["example_name"] = metadata.get("example_name") # Same as title for compatibility + example["example_name"] = metadata.get( + "example_name" + ) # Same as title for compatibility example["language"] = metadata.get("language") # Programming language - example["file_path"] = metadata.get("file_path") # Original file path if available + example["file_path"] = metadata.get( + "file_path" + ) # Original file path if available # Note: content field is already at top level from database # Note: summary field is already at top level from database @@ -612,14 +674,14 @@ async def get_knowledge_item_code_examples( @router.post("/knowledge-items/{source_id}/refresh") async def refresh_knowledge_item(source_id: str): """Refresh a knowledge item by re-crawling its URL with the same metadata.""" - + # Validate API key before starting expensive refresh operation logger.info("🔍 About to validate API key for refresh...") provider_config = await credential_service.get_active_provider("embedding") provider = provider_config.get("provider", "openai") await _validate_provider_api_key(provider) logger.info("✅ API key validation completed successfully for refresh") - + try: safe_logfire_info(f"Starting knowledge item refresh | source_id={source_id}") @@ -629,7 +691,8 @@ async def refresh_knowledge_item(source_id: str): if not existing_item: raise HTTPException( - status_code=404, detail={"error": f"Knowledge item {source_id} not found"} + status_code=404, + detail={"error": f"Knowledge item {source_id} not found"}, ) # Extract metadata @@ -640,7 +703,8 @@ async def refresh_knowledge_item(source_id: str): url = metadata.get("original_url") or existing_item.get("url") if not url: raise HTTPException( - status_code=400, detail={"error": "Knowledge item does not have a URL to refresh"} + status_code=400, + detail={"error": "Knowledge item does not have a URL to refresh"}, ) knowledge_type = metadata.get("knowledge_type", "technical") tags = metadata.get("tags", []) @@ -651,26 +715,32 @@ async def refresh_knowledge_item(source_id: str): # Initialize progress tracker IMMEDIATELY so it's available for polling from ..utils.progress.progress_tracker import ProgressTracker + tracker = ProgressTracker(progress_id, operation_type="crawl") - await tracker.start({ - "url": url, - "status": "initializing", - "progress": 0, - "log": f"Starting refresh for {url}", - "source_id": source_id, - "operation": "refresh", - "crawl_type": "refresh" - }) + await tracker.start( + { + "url": url, + "status": "initializing", + "progress": 0, + "log": f"Starting refresh for {url}", + "source_id": source_id, + "operation": "refresh", + "crawl_type": "refresh", + } + ) # Get crawler from CrawlerManager - same pattern as 
_perform_crawl_with_progress try: crawler = await get_crawler() if crawler is None: - raise Exception("Crawler not available - initialization may have failed") + raise Exception( + "Crawler not available - initialization may have failed" + ) except Exception as e: safe_logfire_error(f"Failed to get crawler | error={str(e)}") raise HTTPException( - status_code=500, detail={"error": f"Failed to initialize crawler: {str(e)}"} + status_code=500, + detail={"error": f"Failed to initialize crawler: {str(e)}"}, ) # Use the same crawl orchestration as regular crawl @@ -736,7 +806,9 @@ async def crawl_knowledge_item(request: KnowledgeItemRequest): # Basic URL validation if not request.url.startswith(("http://", "https://")): - raise HTTPException(status_code=422, detail="URL must start with http:// or https://") + raise HTTPException( + status_code=422, detail="URL must start with http:// or https://" + ) # Validate API key before starting expensive operation logger.info("🔍 About to validate API key...") @@ -754,6 +826,7 @@ async def crawl_knowledge_item(request: KnowledgeItemRequest): # Initialize progress tracker IMMEDIATELY so it's available for polling from ..utils.progress.progress_tracker import ProgressTracker + tracker = ProgressTracker(progress_id, operation_type="crawl") # Detect crawl type from URL @@ -764,14 +837,16 @@ async def crawl_knowledge_item(request: KnowledgeItemRequest): elif url_str.endswith(".txt"): crawl_type = "llms-txt" if "llms" in url_str.lower() else "text_file" - await tracker.start({ - "url": url_str, - "current_url": url_str, - "crawl_type": crawl_type, - # Don't override status - let tracker.start() set it to "starting" - "progress": 0, - "log": f"Starting crawl for {request.url}" - }) + await tracker.start( + { + "url": url_str, + "current_url": url_str, + "crawl_type": crawl_type, + # Don't override status - let tracker.start() set it to "starting" + "progress": 0, + "log": f"Starting crawl for {request.url}", + } + ) # Start background task - no need to track this wrapper task # The actual crawl task will be stored inside _perform_crawl_with_progress @@ -795,12 +870,14 @@ class Config: success=True, progress_id=progress_id, message="Crawling started", - estimated_duration="3-5 minutes" + estimated_duration="3-5 minutes", ) return response.model_dump(by_alias=True) except Exception as e: - safe_logfire_error(f"Failed to start crawl | error={str(e)} | url={str(request.url)}") + safe_logfire_error( + f"Failed to start crawl | error={str(e)} | url={str(request.url)}" + ) raise HTTPException(status_code=500, detail=str(e)) @@ -822,7 +899,9 @@ async def _perform_crawl_with_progress( try: crawler = await get_crawler() if crawler is None: - raise Exception("Crawler not available - initialization may have failed") + raise Exception( + "Crawler not available - initialization may have failed" + ) except Exception as e: safe_logfire_error(f"Failed to get crawler | error={str(e)}") await tracker.error(f"Failed to initialize crawler: {str(e)}") @@ -853,7 +932,9 @@ async def _perform_crawl_with_progress( f"Stored actual crawl task in active_crawl_tasks | progress_id={progress_id} | task_name={crawl_task.get_name()}" ) else: - safe_logfire_error(f"No task returned from orchestrate_crawl | progress_id={progress_id}") + safe_logfire_error( + f"No task returned from orchestrate_crawl | progress_id={progress_id}" + ) # The orchestration service now runs in background and handles all progress updates safe_logfire_info( @@ -899,14 +980,14 @@ async def upload_document( 
extract_code_examples: bool = Form(True), ): """Upload and process a document with progress tracking.""" - - # Validate API key before starting expensive upload operation + + # Validate API key before starting expensive upload operation logger.info("🔍 About to validate API key for upload...") provider_config = await credential_service.get_active_provider("embedding") provider = provider_config.get("provider", "openai") await _validate_provider_api_key(provider) logger.info("✅ API key validation completed successfully for upload") - + try: # DETAILED LOGGING: Track knowledge_type parameter flow safe_logfire_info( @@ -923,11 +1004,19 @@ async def upload_document( tag_list = [] # Validate tags is a list of strings if not isinstance(tag_list, list): - raise HTTPException(status_code=422, detail={"error": "tags must be a JSON array of strings"}) + raise HTTPException( + status_code=422, + detail={"error": "tags must be a JSON array of strings"}, + ) if not all(isinstance(tag, str) for tag in tag_list): - raise HTTPException(status_code=422, detail={"error": "tags must be a JSON array of strings"}) + raise HTTPException( + status_code=422, + detail={"error": "tags must be a JSON array of strings"}, + ) except json.JSONDecodeError as ex: - raise HTTPException(status_code=422, detail={"error": f"Invalid tags JSON: {str(ex)}"}) + raise HTTPException( + status_code=422, detail={"error": f"Invalid tags JSON: {str(ex)}"} + ) # Read file content immediately to avoid closed file issues file_content = await file.read() @@ -939,18 +1028,27 @@ async def upload_document( # Initialize progress tracker IMMEDIATELY so it's available for polling from ..utils.progress.progress_tracker import ProgressTracker + tracker = ProgressTracker(progress_id, operation_type="upload") - await tracker.start({ - "filename": file.filename, - "status": "initializing", - "progress": 0, - "log": f"Starting upload for {file.filename}" - }) + await tracker.start( + { + "filename": file.filename, + "status": "initializing", + "progress": 0, + "log": f"Starting upload for {file.filename}", + } + ) # Start background task for processing with file content and metadata # Upload tasks can be tracked directly since they don't spawn sub-tasks upload_task = asyncio.create_task( _perform_upload_with_progress( - progress_id, file_content, file_metadata, tag_list, knowledge_type, extract_code_examples, tracker + progress_id, + file_content, + file_metadata, + tag_list, + knowledge_type, + extract_code_examples, + tracker, ) ) # Track the task for cancellation support @@ -982,6 +1080,7 @@ async def _perform_upload_with_progress( tracker: "ProgressTracker", ): """Perform document upload with progress tracking using service layer.""" + # Create cancellation check function for document uploads def check_upload_cancellation(): """Check if upload task has been cancelled.""" @@ -991,6 +1090,7 @@ def check_upload_cancellation(): # Import ProgressMapper to prevent progress from going backwards from ..services.crawling.progress_mapper import ProgressMapper + progress_mapper = ProgressMapper() try: @@ -1002,17 +1102,18 @@ def check_upload_cancellation(): f"Starting document upload with progress tracking | progress_id={progress_id} | filename={filename} | content_type={content_type}" ) - # Extract text from document with progress - use mapper for consistent progress mapped_progress = progress_mapper.map_progress("processing", 50) await tracker.update( status="processing", progress=mapped_progress, - log=f"Extracting text from {filename}" + log=f"Extracting 
text from {filename}", ) try: - extracted_text = extract_text_from_document(file_content, filename, content_type) + extracted_text = extract_text_from_document( + file_content, filename, content_type + ) safe_logfire_info( f"Document text extracted | filename={filename} | extracted_length={len(extracted_text)} | content_type={content_type}" ) @@ -1023,7 +1124,9 @@ def check_upload_cancellation(): return except Exception as ex: # Other exceptions are system errors - log with full traceback - logger.error(f"Failed to extract text from document: {filename}", exc_info=True) + logger.error( + f"Failed to extract text from document: {filename}", exc_info=True + ) await tracker.error(f"Failed to extract text from document: {str(ex)}") return @@ -1047,10 +1150,9 @@ async def document_progress_callback( progress=mapped_percentage, log=message, currentUrl=f"file://{filename}", - **(batch_info or {}) + **(batch_info or {}), ) - # Call the service's upload_document method success, result = await doc_storage_service.upload_document( file_content=extracted_text, @@ -1065,12 +1167,14 @@ async def document_progress_callback( if success: # Complete the upload with 100% progress - await tracker.complete({ - "log": "Document uploaded successfully!", - "chunks_stored": result.get("chunks_stored"), - "code_examples_stored": result.get("code_examples_stored", 0), - "sourceId": result.get("source_id"), - }) + await tracker.complete( + { + "log": "Document uploaded successfully!", + "chunks_stored": result.get("chunks_stored"), + "code_examples_stored": result.get("code_examples_stored", 0), + "sourceId": result.get("source_id"), + } + ) safe_logfire_info( f"Document uploaded successfully | progress_id={progress_id} | source_id={result.get('source_id')} | chunks_stored={result.get('chunks_stored')} | code_examples_stored={result.get('code_examples_stored', 0)}" ) @@ -1089,7 +1193,9 @@ async def document_progress_callback( # Clean up task from registry when done (success or failure) if progress_id in active_crawl_tasks: del active_crawl_tasks[progress_id] - safe_logfire_info(f"Cleaned up upload task from registry | progress_id={progress_id}") + safe_logfire_info( + f"Cleaned up upload task from registry | progress_id={progress_id}" + ) @router.post("/knowledge-items/search") @@ -1123,7 +1229,7 @@ async def perform_rag_query(request: RagQueryRequest): query=request.query, source=request.source, match_count=request.match_count, - return_mode=request.return_mode + return_mode=request.return_mode, ) if success: @@ -1132,7 +1238,8 @@ async def perform_rag_query(request: RagQueryRequest): return result else: raise HTTPException( - status_code=500, detail={"error": result.get("error", "RAG query failed")} + status_code=500, + detail={"error": result.get("error", "RAG query failed")}, ) except HTTPException: raise @@ -1140,7 +1247,9 @@ async def perform_rag_query(request: RagQueryRequest): safe_logfire_error( f"RAG query failed | error={str(e)} | query={request.query[:50]} | source={request.source}" ) - raise HTTPException(status_code=500, detail={"error": f"RAG query failed: {str(e)}"}) + raise HTTPException( + status_code=500, detail={"error": f"RAG query failed: {str(e)}"} + ) @router.post("/rag/code-examples") @@ -1230,12 +1339,15 @@ async def delete_source(source_id: str): f"Source deletion failed | source_id={source_id} | error={result_data.get('error')}" ) raise HTTPException( - status_code=500, detail={"error": result_data.get("error", "Deletion failed")} + status_code=500, + detail={"error": 
result_data.get("error", "Deletion failed")}, ) except HTTPException: raise except Exception as e: - safe_logfire_error(f"Failed to delete source | error={str(e)} | source_id={source_id}") + safe_logfire_error( + f"Failed to delete source | error={str(e)} | source_id={source_id}" + ) raise HTTPException(status_code=500, detail={"error": str(e)}) @@ -1267,7 +1379,7 @@ async def knowledge_health(): "ready": False, "migration_required": True, "message": schema_status["message"], - "migration_instructions": "Open Supabase Dashboard → SQL Editor → Run: migration/add_source_url_display_name.sql" + "migration_instructions": "Open Supabase Dashboard → SQL Editor → Run: migration/add_source_url_display_name.sql", } # Removed health check logging to reduce console noise @@ -1280,13 +1392,14 @@ async def knowledge_health(): return result - @router.post("/knowledge-items/stop/{progress_id}") async def stop_crawl_task(progress_id: str): """Stop a running crawl task.""" try: - from ..services.crawling import get_active_orchestration, unregister_orchestration - + from ..services.crawling import ( + get_active_orchestration, + unregister_orchestration, + ) safe_logfire_info(f"Stop crawl requested | progress_id={progress_id}") @@ -1316,24 +1429,32 @@ async def stop_crawl_task(progress_id: str): if found: try: from ..utils.progress.progress_tracker import ProgressTracker + # Get current progress from existing tracker, default to 0 if not found current_state = ProgressTracker.get_progress(progress_id) - current_progress = current_state.get("progress", 0) if current_state else 0 + current_progress = ( + current_state.get("progress", 0) if current_state else 0 + ) tracker = ProgressTracker(progress_id, operation_type="crawl") await tracker.update( status="cancelled", progress=current_progress, - log="Crawl cancelled by user" + log="Crawl cancelled by user", ) except Exception: # Best effort - don't fail the cancellation if tracker update fails pass if not found: - raise HTTPException(status_code=404, detail={"error": "No active task for given progress_id"}) + raise HTTPException( + status_code=404, + detail={"error": "No active task for given progress_id"}, + ) - safe_logfire_info(f"Successfully stopped crawl task | progress_id={progress_id}") + safe_logfire_info( + f"Successfully stopped crawl task | progress_id={progress_id}" + ) return { "success": True, "message": "Crawl task stopped successfully", diff --git a/python/src/server/api_routes/settings_api.py b/python/src/server/api_routes/settings_api.py index 30de2b9813..ad89124999 100644 --- a/python/src/server/api_routes/settings_api.py +++ b/python/src/server/api_routes/settings_api.py @@ -70,7 +70,9 @@ async def list_credentials(category: str | None = None): for cred in credentials ] except Exception as e: - logfire.error(f"Error listing credentials | category={category} | error={str(e)}") + logfire.error( + f"Error listing credentials | category={category} | error={str(e)}" + ) raise HTTPException(status_code=500, detail={"error": str(e)}) @@ -120,7 +122,9 @@ async def create_credential(request: CredentialRequest): } else: logfire.error(f"Failed to save credential | key={request.key}") - raise HTTPException(status_code=500, detail={"error": "Failed to save credential"}) + raise HTTPException( + status_code=500, detail={"error": "Failed to save credential"} + ) except Exception as e: logfire.error(f"Error creating credential | key={request.key} | error={str(e)}") @@ -149,7 +153,9 @@ async def get_credential(key: str): if value is None: # Check if this is 
an optional setting with a default value
             if key in OPTIONAL_SETTINGS_WITH_DEFAULTS:
-                logfire.info(f"Returning default value for optional setting | key={key}")
+                logfire.info(
+                    f"Returning default value for optional setting | key={key}"
+                )
                 return {
                     "key": key,
                     "value": OPTIONAL_SETTINGS_WITH_DEFAULTS[key],
@@ -159,7 +165,9 @@ async def get_credential(key: str):
                 }

             logfire.warning(f"Credential not found | key={key}")
-            raise HTTPException(status_code=404, detail={"error": f"Credential {key} not found"})
+            raise HTTPException(
+                status_code=404, detail={"error": f"Credential {key} not found"}
+            )

         logfire.info(f"Credential retrieved successfully | key={key}")

@@ -218,7 +226,9 @@ async def update_credential(key: str, request: dict[str, Any]):
                 category = existing.category
             if description is None:
                 description = existing.description
-            logfire.info(f"Updating existing credential | key={key} | category={category}")
+            logfire.info(
+                f"Updating existing credential | key={key} | category={category}"
+            )

         success = await credential_service.set_credential(
             key=key,
@@ -233,10 +243,95 @@ async def update_credential(key: str, request: dict[str, Any]):
                 f"Credential updated successfully | key={key} | is_encrypted={is_encrypted}"
             )

-            return {"success": True, "message": f"Credential {key} updated successfully"}
+            # Auto-switch provider when cloud provider credentials are saved
+            if category == "cloud_providers" and value and value.strip():
+                # Determine which provider to switch to based on the credential key
+                provider_switch_map = {
+                    "AZURE_OPENAI_API_KEY": "azure-openai",
+                    "AZURE_OPENAI_ENDPOINT": "azure-openai",
+                    "AWS_ACCESS_KEY_ID": "aws-bedrock",
+                    "AWS_SECRET_ACCESS_KEY": "aws-bedrock",
+                }
+
+                target_provider = provider_switch_map.get(key)
+
+                if target_provider:
+                    # Check if we have all required credentials for this provider
+                    if target_provider == "azure-openai":
+                        # Check if we have all Azure OpenAI credentials
+                        azure_api_key = await credential_service.get_credential(
+                            "AZURE_OPENAI_API_KEY"
+                        )
+                        azure_endpoint = await credential_service.get_credential(
+                            "AZURE_OPENAI_ENDPOINT"
+                        )
+                        azure_version = await credential_service.get_credential(
+                            "AZURE_OPENAI_API_VERSION"
+                        )
+
+                        if azure_api_key and azure_endpoint and azure_version:
+                            # Switch to Azure OpenAI
+                            await credential_service.set_credential(
+                                key="LLM_PROVIDER",
+                                value="azure-openai",
+                                is_encrypted=False,
+                                category="rag_strategy",
+                                description="LLM provider to use: openai, azure-openai, aws-bedrock, ollama, or google",
+                            )
+                            # Also update EMBEDDING_PROVIDER to match
+                            await credential_service.set_credential(
+                                key="EMBEDDING_PROVIDER",
+                                value="azure-openai",
+                                is_encrypted=False,
+                                category="rag_strategy",
+                                description="Embedding provider to use (if different from LLM_PROVIDER)",
+                            )
+                            logfire.info(
+                                "Auto-switched LLM_PROVIDER and EMBEDDING_PROVIDER to azure-openai after saving Azure credentials"
+                            )
+
+                    elif target_provider == "aws-bedrock":
+                        # Check if we have all AWS Bedrock credentials
+                        aws_access_key = await credential_service.get_credential(
+                            "AWS_ACCESS_KEY_ID"
+                        )
+                        aws_secret_key = await credential_service.get_credential(
+                            "AWS_SECRET_ACCESS_KEY"
+                        )
+                        aws_region = await credential_service.get_credential(
+                            "AWS_REGION"
+                        )
+
+                        if aws_access_key and aws_secret_key and aws_region:
+                            # Switch to AWS Bedrock
+                            await credential_service.set_credential(
+                                key="LLM_PROVIDER",
+                                value="aws-bedrock",
+                                is_encrypted=False,
+                                category="rag_strategy",
+                                description="LLM provider to use: openai, azure-openai, aws-bedrock, ollama, or google",
+                            )
+                            # Also update EMBEDDING_PROVIDER
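# --- Example (editor's sketch) --------------------------------------------
# The auto-switch above boils down to a small decision rule: the saved key
# nominates a provider, and the switch only happens once every required
# credential for that provider is present. Restated as a pure function
# (illustrative, not part of the patch):
PROVIDER_SWITCH_MAP = {
    "AZURE_OPENAI_API_KEY": "azure-openai",
    "AZURE_OPENAI_ENDPOINT": "azure-openai",
    "AWS_ACCESS_KEY_ID": "aws-bedrock",
    "AWS_SECRET_ACCESS_KEY": "aws-bedrock",
}
REQUIRED = {
    "azure-openai": {"AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_VERSION"},
    "aws-bedrock": {"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION"},
}

def provider_to_activate(saved_key: str, present: set[str]) -> str | None:
    target = PROVIDER_SWITCH_MAP.get(saved_key)
    if target and REQUIRED[target] <= present:
        return target
    return None

assert provider_to_activate("AWS_REGION", set()) is None  # region alone never triggers a switch
assert provider_to_activate(
    "AZURE_OPENAI_API_KEY",
    {"AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_VERSION"},
) == "azure-openai"
# ---------------------------------------------------------------------------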
to match + await credential_service.set_credential( + key="EMBEDDING_PROVIDER", + value="aws-bedrock", + is_encrypted=False, + category="rag_strategy", + description="Embedding provider to use (if different from LLM_PROVIDER)", + ) + logfire.info( + "Auto-switched LLM_PROVIDER and EMBEDDING_PROVIDER to aws-bedrock after saving AWS credentials" + ) + + return { + "success": True, + "message": f"Credential {key} updated successfully", + } else: logfire.error(f"Failed to update credential | key={key}") - raise HTTPException(status_code=500, detail={"error": "Failed to update credential"}) + raise HTTPException( + status_code=500, detail={"error": "Failed to update credential"} + ) except Exception as e: logfire.error(f"Error updating credential | key={key} | error={str(e)}") @@ -253,10 +348,15 @@ async def delete_credential(key: str): if success: logfire.info(f"Credential deleted successfully | key={key}") - return {"success": True, "message": f"Credential {key} deleted successfully"} + return { + "success": True, + "message": f"Credential {key} deleted successfully", + } else: logfire.error(f"Failed to delete credential | key={key}") - raise HTTPException(status_code=500, detail={"error": "Failed to delete credential"}) + raise HTTPException( + status_code=500, detail={"error": "Failed to delete credential"} + ) except Exception as e: logfire.error(f"Error deleting credential | key={key} | error={str(e)}") @@ -290,19 +390,27 @@ async def database_metrics(): # Get projects count projects_response = ( - supabase_client.table("archon_projects").select("id", count="exact").execute() + supabase_client.table("archon_projects") + .select("id", count="exact") + .execute() ) tables_info["projects"] = ( projects_response.count if projects_response.count is not None else 0 ) # Get tasks count - tasks_response = supabase_client.table("archon_tasks").select("id", count="exact").execute() - tables_info["tasks"] = tasks_response.count if tasks_response.count is not None else 0 + tasks_response = ( + supabase_client.table("archon_tasks").select("id", count="exact").execute() + ) + tables_info["tasks"] = ( + tasks_response.count if tasks_response.count is not None else 0 + ) # Get crawled pages count pages_response = ( - supabase_client.table("archon_crawled_pages").select("id", count="exact").execute() + supabase_client.table("archon_crawled_pages") + .select("id", count="exact") + .execute() ) tables_info["crawled_pages"] = ( pages_response.count if pages_response.count is not None else 0 @@ -310,7 +418,9 @@ async def database_metrics(): # Get settings count settings_response = ( - supabase_client.table("archon_settings").select("id", count="exact").execute() + supabase_client.table("archon_settings") + .select("id", count="exact") + .execute() ) tables_info["settings"] = ( settings_response.count if settings_response.count is not None else 0 @@ -346,46 +456,52 @@ async def settings_health(): @router.post("/credentials/status-check") async def check_credential_status(request: dict[str, list[str]]): """Check status of API credentials by actually decrypting and validating them. - + This endpoint is specifically for frontend status indicators and returns decrypted credential values for connectivity testing. 
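# --- Example (editor's sketch) --------------------------------------------
# How a frontend status indicator would call the endpoint below; the base
# URL and "/api" prefix are assumptions about router mounting. Each entry in
# the response carries {"key", "value", "has_value"} per the handler.
import httpx

resp = httpx.post(
    "http://localhost:8181/api/credentials/status-check",
    json={"keys": ["OPENAI_API_KEY", "AZURE_OPENAI_API_KEY"]},
)
resp.raise_for_status()
for key, status in resp.json().items():
    print(key, "configured" if status["has_value"] else "missing")
# ---------------------------------------------------------------------------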
""" try: credential_keys = request.get("keys", []) logfire.info(f"Checking status for credentials: {credential_keys}") - + result = {} - + for key in credential_keys: try: # Get decrypted value for status checking - decrypted_value = await credential_service.get_credential(key, decrypt=True) - - if decrypted_value and isinstance(decrypted_value, str) and decrypted_value.strip(): + decrypted_value = await credential_service.get_credential( + key, decrypt=True + ) + + if ( + decrypted_value + and isinstance(decrypted_value, str) + and decrypted_value.strip() + ): result[key] = { "key": key, "value": decrypted_value, - "has_value": True + "has_value": True, } else: - result[key] = { - "key": key, - "value": None, - "has_value": False - } - + result[key] = {"key": key, "value": None, "has_value": False} + except Exception as e: - logfire.warning(f"Failed to get credential for status check: {key} | error={str(e)}") + logfire.warning( + f"Failed to get credential for status check: {key} | error={str(e)}" + ) result[key] = { "key": key, "value": None, "has_value": False, - "error": str(e) + "error": str(e), } - - logfire.info(f"Credential status check completed | checked={len(credential_keys)} | found={len([k for k, v in result.items() if v.get('has_value')])}") + + logfire.info( + f"Credential status check completed | checked={len(credential_keys)} | found={len([k for k, v in result.items() if v.get('has_value')])}" + ) return result - + except Exception as e: logfire.error(f"Error in credential status check | error={str(e)}") raise HTTPException(status_code=500, detail={"error": str(e)}) diff --git a/python/src/server/config/config.py b/python/src/server/config/config.py index d8104bb0ea..709fff5818 100644 --- a/python/src/server/config/config.py +++ b/python/src/server/config/config.py @@ -26,6 +26,16 @@ class EnvironmentConfig: openai_api_key: str | None = None host: str = "0.0.0.0" transport: str = "sse" + # Azure OpenAI configuration + azure_openai_endpoint: str | None = None + azure_openai_api_key: str | None = None + azure_openai_api_version: str | None = None + azure_openai_deployment: str | None = None + # AWS Bedrock configuration + aws_access_key_id: str | None = None + aws_secret_access_key: str | None = None + aws_region: str | None = None + aws_bedrock_model_id: str | None = None @dataclass @@ -66,6 +76,93 @@ def validate_openai_api_key(api_key: str) -> bool: return True +def validate_azure_openai_endpoint(endpoint: str) -> bool: + """Validate Azure OpenAI endpoint format.""" + if not endpoint: + raise ConfigurationError("Azure OpenAI endpoint cannot be empty") + + parsed = urlparse(endpoint) + if parsed.scheme not in ("http", "https"): + raise ConfigurationError("Azure OpenAI endpoint must use HTTP or HTTPS") + + if not parsed.netloc: + raise ConfigurationError("Invalid Azure OpenAI endpoint format") + + # Azure OpenAI endpoints typically contain '.openai.azure.com' + if ".openai.azure.com" not in parsed.netloc: + # Warning but allow non-standard endpoints + pass + + return True + + +def validate_azure_openai_api_version(api_version: str) -> bool: + """Validate Azure OpenAI API version format.""" + if not api_version: + raise ConfigurationError("Azure OpenAI API version cannot be empty") + + # Azure API versions follow YYYY-MM-DD format or YYYY-MM-DD-preview + import re + + pattern = r"^\d{4}-\d{2}-\d{2}(-preview)?$" + if not re.match(pattern, api_version): + raise ConfigurationError( + f"Azure OpenAI API version must follow YYYY-MM-DD or YYYY-MM-DD-preview format, got: 
{api_version}" + ) + + return True + + +def validate_aws_access_key_id(access_key_id: str) -> bool: + """Validate AWS Access Key ID format.""" + if not access_key_id: + raise ConfigurationError("AWS Access Key ID cannot be empty") + + # AWS Access Key IDs are 20 characters long and start with AKIA or ASIA + if not (access_key_id.startswith("AKIA") or access_key_id.startswith("ASIA")): + raise ConfigurationError( + "AWS Access Key ID must start with 'AKIA' (long-term) or 'ASIA' (temporary)" + ) + + if len(access_key_id) != 20: + raise ConfigurationError( + f"AWS Access Key ID must be exactly 20 characters, got {len(access_key_id)}" + ) + + return True + + +def validate_aws_region(region: str) -> bool: + """Validate AWS region format.""" + if not region: + raise ConfigurationError("AWS region cannot be empty") + + # Common AWS regions for Bedrock + valid_regions = { + "us-east-1", + "us-east-2", + "us-west-1", + "us-west-2", + "ap-south-1", + "ap-northeast-1", + "ap-northeast-2", + "ap-southeast-1", + "ap-southeast-2", + "ca-central-1", + "eu-central-1", + "eu-west-1", + "eu-west-2", + "eu-west-3", + "sa-east-1", + } + + if region not in valid_regions: + # Warning but allow non-standard regions + pass + + return True + + def validate_supabase_key(supabase_key: str) -> tuple[bool, str]: """Validate Supabase key type and return validation result. @@ -137,14 +234,18 @@ def validate_supabase_url(url: str) -> bool: # Class C: 192.168.0.0/16 # Also includes link-local (169.254.0.0/16) and loopback # Exclude unspecified address (0.0.0.0) for security - if (ip.is_private or ip.is_loopback or ip.is_link_local) and not ip.is_unspecified: + if ( + ip.is_private or ip.is_loopback or ip.is_link_local + ) and not ip.is_unspecified: return True except ValueError: # hostname is not a valid IP address, could be a domain name pass # If not a local host or private IP, require HTTPS - raise ConfigurationError(f"Supabase URL must use HTTPS for non-local environments (hostname: {hostname})") + raise ConfigurationError( + f"Supabase URL must use HTTPS for non-local environments (hostname: {hostname})" + ) if not parsed.netloc: raise ConfigurationError("Invalid Supabase URL format") @@ -157,6 +258,12 @@ def load_environment_config() -> EnvironmentConfig: # OpenAI API key is optional at startup - can be set via API openai_api_key = os.getenv("OPENAI_API_KEY") + # Azure OpenAI configuration (optional) + azure_openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") + azure_openai_api_key = os.getenv("AZURE_OPENAI_API_KEY") + azure_openai_api_version = os.getenv("AZURE_OPENAI_API_VERSION") + azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") + # Required environment variables for database access supabase_url = os.getenv("SUPABASE_URL") if not supabase_url: @@ -164,11 +271,20 @@ def load_environment_config() -> EnvironmentConfig: supabase_service_key = os.getenv("SUPABASE_SERVICE_KEY") if not supabase_service_key: - raise ConfigurationError("SUPABASE_SERVICE_KEY environment variable is required") + raise ConfigurationError( + "SUPABASE_SERVICE_KEY environment variable is required" + ) # Validate required fields if openai_api_key: validate_openai_api_key(openai_api_key) + + # Validate Azure OpenAI configuration if provided + if azure_openai_endpoint: + validate_azure_openai_endpoint(azure_openai_endpoint) + if azure_openai_api_version: + validate_azure_openai_api_version(azure_openai_api_version) + validate_supabase_url(supabase_url) # Validate Supabase key type @@ -217,7 +333,9 @@ def load_environment_config() 
-> EnvironmentConfig: try: port = int(port_str) except ValueError as e: - raise ConfigurationError(f"PORT must be a valid integer, got: {port_str}") from e + raise ConfigurationError( + f"PORT must be a valid integer, got: {port_str}" + ) from e return EnvironmentConfig( openai_api_key=openai_api_key, @@ -226,6 +344,10 @@ def load_environment_config() -> EnvironmentConfig: host=host, port=port, transport=transport, + azure_openai_endpoint=azure_openai_endpoint, + azure_openai_api_key=azure_openai_api_key, + azure_openai_api_version=azure_openai_api_version, + azure_openai_deployment=azure_openai_deployment, ) diff --git a/python/src/server/services/credential_service.py b/python/src/server/services/credential_service.py index a8aee8491d..faf6e72789 100644 --- a/python/src/server/services/credential_service.py +++ b/python/src/server/services/credential_service.py @@ -36,8 +36,6 @@ class CredentialItem: description: str | None = None - - class CredentialService: """Service for managing application credentials and configuration.""" @@ -71,7 +69,9 @@ def _get_supabase_client(self) -> Client: match = re.match(r"https://([^.]+)\.supabase\.co", url) if match: project_id = match.group(1) - logger.debug(f"Supabase client initialized for project: {project_id}") + logger.debug( + f"Supabase client initialized for project: {project_id}" + ) else: logger.debug("Supabase client initialized successfully") @@ -157,7 +157,9 @@ async def load_all_credentials(self) -> dict[str, Any]: logger.error(f"Error loading credentials: {e}") raise - async def get_credential(self, key: str, default: Any = None, decrypt: bool = True) -> Any: + async def get_credential( + self, key: str, default: Any = None, decrypt: bool = True + ) -> Any: """Get a credential value by key.""" if not self._cache_initialized: await self.load_all_credentials() @@ -244,6 +246,7 @@ async def set_credential( # Also invalidate provider service cache to ensure immediate effect try: from .llm_provider_service import clear_provider_cache + clear_provider_cache() logger.debug("Also cleared LLM provider service cache") except Exception as e: @@ -252,14 +255,23 @@ async def set_credential( # Also invalidate LLM provider service cache for provider config try: from . 
import llm_provider_service + # Clear the provider config caches that depend on RAG settings - cache_keys_to_clear = ["provider_config_llm", "provider_config_embedding", "rag_strategy_settings"] + cache_keys_to_clear = [ + "provider_config_llm", + "provider_config_embedding", + "rag_strategy_settings", + ] for cache_key in cache_keys_to_clear: if cache_key in llm_provider_service._settings_cache: del llm_provider_service._settings_cache[cache_key] - logger.debug(f"Invalidated LLM provider service cache key: {cache_key}") + logger.debug( + f"Invalidated LLM provider service cache key: {cache_key}" + ) except ImportError: - logger.warning("Could not import llm_provider_service to invalidate cache") + logger.warning( + "Could not import llm_provider_service to invalidate cache" + ) except Exception as e: logger.error(f"Error invalidating LLM provider service cache: {e}") @@ -294,6 +306,7 @@ async def delete_credential(self, key: str) -> bool: # Also invalidate provider service cache to ensure immediate effect try: from .llm_provider_service import clear_provider_cache + clear_provider_cache() logger.debug("Also cleared LLM provider service cache") except Exception as e: @@ -302,14 +315,23 @@ async def delete_credential(self, key: str) -> bool: # Also invalidate LLM provider service cache for provider config try: from . import llm_provider_service + # Clear the provider config caches that depend on RAG settings - cache_keys_to_clear = ["provider_config_llm", "provider_config_embedding", "rag_strategy_settings"] + cache_keys_to_clear = [ + "provider_config_llm", + "provider_config_embedding", + "rag_strategy_settings", + ] for cache_key in cache_keys_to_clear: if cache_key in llm_provider_service._settings_cache: del llm_provider_service._settings_cache[cache_key] - logger.debug(f"Invalidated LLM provider service cache key: {cache_key}") + logger.debug( + f"Invalidated LLM provider service cache key: {cache_key}" + ) except ImportError: - logger.warning("Could not import llm_provider_service to invalidate cache") + logger.warning( + "Could not import llm_provider_service to invalidate cache" + ) except Exception as e: logger.error(f"Error invalidating LLM provider service cache: {e}") @@ -341,7 +363,10 @@ async def get_credentials_by_category(self, category: str) -> dict[str, Any]: try: supabase = self._get_supabase_client() result = ( - supabase.table("archon_settings").select("*").eq("category", category).execute() + supabase.table("archon_settings") + .select("*") + .eq("category", category) + .execute() ) credentials = {} @@ -443,24 +468,53 @@ async def get_active_provider(self, service_type: str = "llm") -> dict[str, Any] explicit_embedding_provider = rag_settings.get("EMBEDDING_PROVIDER") # Validate that embedding provider actually supports embeddings - embedding_capable_providers = {"openai", "google", "ollama"} + # Include cloud providers that support embeddings + embedding_capable_providers = { + "openai", + "azure-openai", # Azure OpenAI supports embeddings + "aws-bedrock", # AWS Bedrock supports embeddings (Titan, Cohere) + "google", + "ollama", + } - if (explicit_embedding_provider and - explicit_embedding_provider != "" and - explicit_embedding_provider in embedding_capable_providers): + if ( + explicit_embedding_provider + and explicit_embedding_provider != "" + and explicit_embedding_provider in embedding_capable_providers + ): # Use the explicitly set embedding provider provider = explicit_embedding_provider logger.debug(f"Using explicit embedding provider: '{provider}'") else: 
- # Fall back to OpenAI as default embedding provider for backward compatibility - if explicit_embedding_provider and explicit_embedding_provider not in embedding_capable_providers: - logger.warning(f"Invalid embedding provider '{explicit_embedding_provider}' doesn't support embeddings, defaulting to OpenAI") - provider = "openai" - logger.debug(f"No explicit embedding provider set, defaulting to OpenAI for backward compatibility") + # If no explicit embedding provider, check if LLM_PROVIDER supports embeddings + llm_provider = rag_settings.get("LLM_PROVIDER", "openai") + if llm_provider in embedding_capable_providers: + # Use LLM provider for embeddings if it supports them + provider = llm_provider + logger.debug(f"Using LLM provider for embeddings: '{provider}'") + else: + # Fall back to OpenAI as default embedding provider for backward compatibility + if ( + explicit_embedding_provider + and explicit_embedding_provider + not in embedding_capable_providers + ): + logger.warning( + f"Invalid embedding provider '{explicit_embedding_provider}' doesn't support embeddings, defaulting to OpenAI" + ) + provider = "openai" + logger.debug( + "No valid embedding provider found, defaulting to OpenAI for backward compatibility" + ) else: provider = rag_settings.get("LLM_PROVIDER", "openai") # Ensure provider is a valid string, not a boolean or other type - if not isinstance(provider, str) or provider.lower() in ("true", "false", "none", "null"): + if not isinstance(provider, str) or provider.lower() in ( + "true", + "false", + "none", + "null", + ): provider = "openai" # Get API key for this provider @@ -504,10 +558,12 @@ async def _get_provider_api_key(self, provider: str) -> str | None: """Get API key for a specific provider.""" key_mapping = { "openai": "OPENAI_API_KEY", + "azure-openai": "AZURE_OPENAI_API_KEY", "google": "GOOGLE_API_KEY", "openrouter": "OPENROUTER_API_KEY", "anthropic": "ANTHROPIC_API_KEY", "grok": "GROK_API_KEY", + "aws-bedrock": "AWS_ACCESS_KEY_ID", # AWS uses access key as primary credential "ollama": None, # No API key needed } @@ -519,7 +575,15 @@ async def _get_provider_api_key(self, provider: str) -> str | None: def _get_provider_base_url(self, provider: str, rag_settings: dict) -> str | None: """Get base URL for provider.""" if provider == "ollama": - return rag_settings.get("LLM_BASE_URL", "http://host.docker.internal:11434/v1") + return rag_settings.get( + "LLM_BASE_URL", "http://host.docker.internal:11434/v1" + ) + elif provider == "azure-openai": + # Azure OpenAI endpoint will be handled separately with AsyncAzureOpenAI + return None + elif provider == "aws-bedrock": + # AWS Bedrock region will be handled separately with boto3 + return None elif provider == "google": return "https://generativelanguage.googleapis.com/v1beta/openai/" elif provider == "openrouter": @@ -530,7 +594,9 @@ def _get_provider_base_url(self, provider: str, rag_settings: dict) -> str | Non return "https://api.x.ai/v1" return None # Use default for OpenAI - async def set_active_provider(self, provider: str, service_type: str = "llm") -> bool: + async def set_active_provider( + self, provider: str, service_type: str = "llm" + ) -> bool: """Set the active provider for a service type.""" try: # For now, we'll update the RAG strategy settings @@ -541,7 +607,9 @@ async def set_active_provider(self, provider: str, service_type: str = "llm") -> description=f"Active {service_type} provider", ) except Exception as e: - logger.error(f"Error setting active provider {provider} for {service_type}: {e}") + 
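# --- Example (editor's sketch) --------------------------------------------
# The fallback chain above reduces to: explicit EMBEDDING_PROVIDER if it is
# embedding-capable, else LLM_PROVIDER if capable, else OpenAI. A condensed
# restatement (illustrative only; the real code also logs warnings):
EMBEDDING_CAPABLE = {"openai", "azure-openai", "aws-bedrock", "google", "ollama"}

def resolve_embedding_provider(explicit: str | None, llm_provider: str) -> str:
    if explicit and explicit in EMBEDDING_CAPABLE:
        return explicit
    if llm_provider in EMBEDDING_CAPABLE:
        return llm_provider
    return "openai"

assert resolve_embedding_provider(None, "azure-openai") == "azure-openai"
assert resolve_embedding_provider("google", "aws-bedrock") == "google"
assert resolve_embedding_provider("anthropic", "anthropic") == "openai"
# ---------------------------------------------------------------------------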
logger.error( + f"Error setting active provider {provider} for {service_type}: {e}" + ) return False @@ -555,10 +623,16 @@ async def get_credential(key: str, default: Any = None) -> Any: async def set_credential( - key: str, value: str, is_encrypted: bool = False, category: str = None, description: str = None + key: str, + value: str, + is_encrypted: bool = False, + category: str = None, + description: str = None, ) -> bool: """Convenience function to set a credential.""" - return await credential_service.set_credential(key, value, is_encrypted, category, description) + return await credential_service.set_credential( + key, value, is_encrypted, category, description + ) async def initialize_credentials() -> None: diff --git a/python/src/server/services/embeddings/embedding_service.py b/python/src/server/services/embeddings/embedding_service.py index 87ce390b67..f5bf904c0b 100644 --- a/python/src/server/services/embeddings/embedding_service.py +++ b/python/src/server/services/embeddings/embedding_service.py @@ -35,7 +35,9 @@ class EmbeddingBatchResult: failed_items: list[dict[str, Any]] = field(default_factory=list) success_count: int = 0 failure_count: int = 0 - texts_processed: list[str] = field(default_factory=list) # Successfully processed texts + texts_processed: list[str] = field( + default_factory=list + ) # Successfully processed texts def add_success(self, embedding: list[float], text: str): """Add a successful embedding.""" @@ -83,10 +85,10 @@ async def create_embeddings( class OpenAICompatibleEmbeddingAdapter(EmbeddingProviderAdapter): """Adapter for providers using the OpenAI embeddings API shape.""" - + def __init__(self, client: Any): self._client = client - + async def create_embeddings( self, texts: list[str], @@ -99,7 +101,37 @@ async def create_embeddings( } if dimensions is not None: request_args["dimensions"] = dimensions - + + response = await self._client.embeddings.create(**request_args) + return [item.embedding for item in response.data] + + +class AzureOpenAIEmbeddingAdapter(EmbeddingProviderAdapter): + """Adapter for Azure OpenAI embeddings API.""" + + def __init__(self, client: Any): + self._client = client + + async def create_embeddings( + self, + texts: list[str], + model: str, + dimensions: int | None = None, + ) -> list[list[float]]: + """ + Create embeddings using Azure OpenAI. + + For Azure OpenAI, the 'model' parameter should be the deployment name, + not the model name (e.g., "my-embedding-deployment" not "text-embedding-3-small"). 
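# --- Example (editor's sketch) --------------------------------------------
# What the Azure adapter's call amounts to at the SDK level. Endpoint, key,
# and deployment name are placeholders; note that "model" must be the Azure
# deployment name, exactly as the docstring above warns.
import asyncio
import openai

async def _demo() -> None:
    client = openai.AsyncAzureOpenAI(
        api_key="<azure-api-key>",
        azure_endpoint="https://my-resource.openai.azure.com",
        api_version="2024-02-15-preview",
    )
    response = await client.embeddings.create(
        model="my-embedding-deployment",  # deployment name, not model name
        input=["hello world"],
    )
    print(len(response.data[0].embedding))

asyncio.run(_demo())
# ---------------------------------------------------------------------------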
+ """ + request_args: dict[str, Any] = { + "model": model, # This is the deployment name for Azure + "input": texts, + } + # Azure OpenAI supports dimensions parameter for certain models + if dimensions is not None: + request_args["dimensions"] = dimensions + response = await self._client.embeddings.create(**request_args) return [item.embedding for item in response.data] @@ -121,7 +153,9 @@ async def create_embeddings( async with httpx.AsyncClient(timeout=30.0) as http_client: embeddings = await asyncio.gather( *( - self._fetch_single_embedding(http_client, google_api_key, model, text, dimensions) + self._fetch_single_embedding( + http_client, google_api_key, model, text, dimensions + ) for text in texts ) ) @@ -139,7 +173,9 @@ async def create_embeddings( original_error=error, ) from error except Exception as error: - search_logger.error(f"Error calling Google embedding API: {error}", exc_info=True) + search_logger.error( + f"Error calling Google embedding API: {error}", exc_info=True + ) raise EmbeddingAPIError( f"Google embedding error: {str(error)}", original_error=error ) from error @@ -209,7 +245,9 @@ def _normalize_embedding(self, embedding: list[float]) -> list[float]: normalized = embedding_array / norm return normalized.tolist() else: - search_logger.warning("Zero-norm embedding detected, returning unnormalized") + search_logger.warning( + "Zero-norm embedding detected, returning unnormalized" + ) return embedding except Exception as e: search_logger.error(f"Failed to normalize embedding: {e}") @@ -221,6 +259,8 @@ def _get_embedding_adapter(provider: str, client: Any) -> EmbeddingProviderAdapt provider_name = (provider or "").lower() if provider_name == "google": return GoogleEmbeddingAdapter() + elif provider_name == "azure-openai": + return AzureOpenAIEmbeddingAdapter(client) return OpenAICompatibleEmbeddingAdapter(client) @@ -229,6 +269,7 @@ async def _maybe_await(value: Any) -> Any: return await value if inspect.isawaitable(value) else value + # Provider-aware client factory get_openai_client = get_llm_client @@ -262,7 +303,9 @@ async def create_embedding(text: str, provider: str | None = None) -> list[float f"OpenAI quota exhausted: {error_msg}", text_preview=text ) elif "rate" in error_msg.lower(): - raise EmbeddingRateLimitError(f"Rate limit hit: {error_msg}", text_preview=text) + raise EmbeddingRateLimitError( + f"Rate limit hit: {error_msg}", text_preview=text + ) else: raise EmbeddingAPIError( f"Failed to create embedding: {error_msg}", text_preview=text @@ -286,7 +329,9 @@ async def create_embedding(text: str, provider: str | None = None) -> list[float f"OpenAI quota exhausted: {error_msg}", text_preview=text ) elif "rate_limit" in error_msg.lower(): - raise EmbeddingRateLimitError(f"Rate limit hit: {error_msg}", text_preview=text) + raise EmbeddingRateLimitError( + f"Rate limit hit: {error_msg}", text_preview=text + ) else: raise EmbeddingAPIError( f"Embedding error: {error_msg}", text_preview=text, original_error=e @@ -327,7 +372,8 @@ async def create_embeddings_batch( continue search_logger.error( - f"Invalid text type at index {i}: {type(text)}, value: {text}", exc_info=True + f"Invalid text type at index {i}: {type(text)}, value: {text}", + exc_info=True, ) try: converted = str(text) @@ -347,7 +393,9 @@ async def create_embeddings_batch( threading_service = get_threading_service() with safe_span( - "create_embeddings_batch", text_count=len(texts), total_chars=sum(len(t) for t in texts) + "create_embeddings_batch", + text_count=len(texts), + total_chars=sum(len(t) 
for t in texts), ) as span: try: embedding_config = await _maybe_await( @@ -356,30 +404,45 @@ async def create_embeddings_batch( embedding_provider = provider or embedding_config.get("provider") - if not isinstance(embedding_provider, str) or not embedding_provider.strip(): + if ( + not isinstance(embedding_provider, str) + or not embedding_provider.strip() + ): embedding_provider = "openai" if not embedding_provider: search_logger.error("No embedding provider configured") - raise ValueError("No embedding provider configured. Please set EMBEDDING_PROVIDER environment variable.") + raise ValueError( + "No embedding provider configured. Please set EMBEDDING_PROVIDER environment variable." + ) - search_logger.info(f"Using embedding provider: '{embedding_provider}' (from EMBEDDING_PROVIDER setting)") - async with get_llm_client(provider=embedding_provider, use_embedding_provider=True) as client: + search_logger.info( + f"Using embedding provider: '{embedding_provider}' (from EMBEDDING_PROVIDER setting)" + ) + async with get_llm_client( + provider=embedding_provider, use_embedding_provider=True + ) as client: # Load batch size and dimensions from settings try: rag_settings = await _maybe_await( credential_service.get_credentials_by_category("rag_strategy") ) batch_size = int(rag_settings.get("EMBEDDING_BATCH_SIZE", "100")) - embedding_dimensions = int(rag_settings.get("EMBEDDING_DIMENSIONS", "1536")) + embedding_dimensions = int( + rag_settings.get("EMBEDDING_DIMENSIONS", "1536") + ) except Exception as e: - search_logger.warning(f"Failed to load embedding settings: {e}, using defaults") + search_logger.warning( + f"Failed to load embedding settings: {e}, using defaults" + ) batch_size = 100 embedding_dimensions = 1536 total_tokens_used = 0 adapter = _get_embedding_adapter(embedding_provider, client) - dimensions_to_use = embedding_dimensions if embedding_dimensions > 0 else None + dimensions_to_use = ( + embedding_dimensions if embedding_dimensions > 0 else None + ) for i in range(0, len(texts), batch_size): batch = texts[i : i + batch_size] @@ -393,28 +456,39 @@ async def create_embeddings_batch( # Create rate limit progress callback if we have a progress callback rate_limit_callback = None if progress_callback: + async def rate_limit_callback(data: dict): # Send heartbeat during rate limit wait processed = result.success_count + result.failure_count - message = f"Rate limited: {data.get('message', 'Waiting...')}" - await progress_callback(message, (processed / len(texts)) * 100) + message = ( + f"Rate limited: {data.get('message', 'Waiting...')}" + ) + await progress_callback( + message, (processed / len(texts)) * 100 + ) # Rate limit each batch - async with threading_service.rate_limited_operation(batch_tokens, rate_limit_callback): + async with threading_service.rate_limited_operation( + batch_tokens, rate_limit_callback + ): retry_count = 0 max_retries = 3 while retry_count < max_retries: try: # Create embeddings for this batch - embedding_model = await get_embedding_model(provider=embedding_provider) + embedding_model = await get_embedding_model( + provider=embedding_provider + ) embeddings = await adapter.create_embeddings( batch, embedding_model, dimensions=dimensions_to_use, ) - for text, vector in zip(batch, embeddings, strict=False): + for text, vector in zip( + batch, embeddings, strict=False + ): result.add_success(vector, text) break # Success, exit retry loop @@ -474,7 +548,9 @@ async def rate_limit_callback(data: dict): except Exception as e: # This batch failed - track 
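# --- Example (editor's sketch) --------------------------------------------
# Shape of the per-batch loop above: texts are sliced into batches, each
# batch gets a bounded number of attempts, and one batch's failure never
# aborts the rest. The backoff policy here is generic, not the exact
# rate-limit handling used above.
import asyncio

def chunk(texts: list[str], batch_size: int) -> list[list[str]]:
    return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]

async def embed_with_retry(embed, batch: list[str], max_retries: int = 3):
    for attempt in range(1, max_retries + 1):
        try:
            return await embed(batch)
        except Exception:
            if attempt == max_retries:
                raise  # caller records the failure and moves on
            await asyncio.sleep(2 ** attempt)

assert chunk(["a", "b", "c", "d", "e"], 2) == [["a", "b"], ["c", "d"], ["e"]]
# ---------------------------------------------------------------------------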
failures but continue with next batch - search_logger.error(f"Batch {batch_index} failed: {e}", exc_info=True) + search_logger.error( + f"Batch {batch_index} failed: {e}", exc_info=True + ) for text in batch: if isinstance(e, EmbeddingError): @@ -483,7 +559,8 @@ async def rate_limit_callback(data: dict): result.add_failure( text, EmbeddingAPIError( - f"Failed to create embedding: {str(e)}", original_error=e + f"Failed to create embedding: {str(e)}", + original_error=e, ), batch_index, ) @@ -512,13 +589,18 @@ async def rate_limit_callback(data: dict): except Exception as e: # Catastrophic failure - return what we have span.set_attribute("catastrophic_failure", True) - search_logger.error(f"Catastrophic failure in batch embedding: {e}", exc_info=True) + search_logger.error( + f"Catastrophic failure in batch embedding: {e}", exc_info=True + ) # Mark remaining texts as failed processed_count = result.success_count + result.failure_count for text in texts[processed_count:]: result.add_failure( - text, EmbeddingAPIError(f"Catastrophic failure: {str(e)}", original_error=e) + text, + EmbeddingAPIError( + f"Catastrophic failure: {str(e)}", original_error=e + ), ) return result diff --git a/python/src/server/services/llm_provider_service.py b/python/src/server/services/llm_provider_service.py index 00197926fd..c383d0bc12 100644 --- a/python/src/server/services/llm_provider_service.py +++ b/python/src/server/services/llm_provider_service.py @@ -23,7 +23,16 @@ def _is_valid_provider(provider: str) -> bool: """Basic provider validation.""" if not provider or not isinstance(provider, str): return False - return provider.lower() in {"openai", "ollama", "google", "openrouter", "anthropic", "grok"} + return provider.lower() in { + "openai", + "azure-openai", + "aws-bedrock", + "ollama", + "google", + "openrouter", + "anthropic", + "grok", + } def _sanitize_for_log(text: str) -> str: @@ -31,6 +40,7 @@ def _sanitize_for_log(text: str) -> str: if not text: return "" import re + sanitized = re.sub(r"sk-[a-zA-Z0-9-_]{20,}", "[REDACTED]", text) sanitized = re.sub(r"xai-[a-zA-Z0-9-_]{20,}", "[REDACTED]", sanitized) return sanitized[:100] @@ -39,7 +49,9 @@ def _sanitize_for_log(text: str) -> str: # Secure settings cache with TTL and validation _settings_cache: dict[str, tuple[Any, float, str]] = {} # value, timestamp, checksum _CACHE_TTL_SECONDS = 300 # 5 minutes -_cache_access_log: list[dict] = [] # Track cache access patterns for security monitoring +_cache_access_log: list[dict] = ( + [] +) # Track cache access patterns for security monitoring def _calculate_cache_checksum(value: Any) -> str: @@ -50,13 +62,17 @@ def _calculate_cache_checksum(value: Any) -> str: # Convert value to JSON string for consistent hashing try: value_str = json.dumps(value, sort_keys=True, default=str) - return hashlib.sha256(value_str.encode()).hexdigest()[:16] # First 16 chars for efficiency + return hashlib.sha256(value_str.encode()).hexdigest()[ + :16 + ] # First 16 chars for efficiency except Exception: # Fallback for non-serializable objects return hashlib.sha256(str(value).encode()).hexdigest()[:16] -def _log_cache_access(key: str, action: str, hit: bool = None, security_event: str = None) -> None: +def _log_cache_access( + key: str, action: str, hit: bool = None, security_event: str = None +) -> None: """Log cache access for security monitoring.""" access_entry = { @@ -64,7 +80,7 @@ def _log_cache_access(key: str, action: str, hit: bool = None, security_event: s "key": _sanitize_for_log(key), "action": action, # "get", "set", 
"invalidate", "clear" "hit": hit, # For get operations - "security_event": security_event # "checksum_mismatch", "expired", etc. + "security_event": security_event, # "checksum_mismatch", "expired", etc. } # Keep only last 100 access entries to prevent memory growth @@ -98,17 +114,25 @@ def _get_cached_settings(key: str) -> Any | None: if current_checksum != stored_checksum: # Cache tampering detected, remove entry del _settings_cache[key] - _log_cache_access(key, "get", hit=False, security_event="checksum_mismatch") - logger.error(f"Cache integrity violation detected for key: {_sanitize_for_log(key)}") + _log_cache_access( + key, "get", hit=False, security_event="checksum_mismatch" + ) + logger.error( + f"Cache integrity violation detected for key: {_sanitize_for_log(key)}" + ) return None # Additional validation for provider configurations if "provider_config" in key and isinstance(value, dict): # Basic validation: check required fields - if not value.get("provider") or not _is_valid_provider(value.get("provider")): + if not value.get("provider") or not _is_valid_provider( + value.get("provider") + ): # Invalid configuration in cache, remove it del _settings_cache[key] - _log_cache_access(key, "get", hit=False, security_event="invalid_config") + _log_cache_access( + key, "get", hit=False, security_event="invalid_config" + ) return None _log_cache_access(key, "get", hit=True) @@ -119,7 +143,9 @@ def _get_cached_settings(key: str) -> Any | None: except Exception as e: # Cache access error, log and return None for safety - _log_cache_access(key, "get", hit=False, security_event=f"access_error: {str(e)}") + _log_cache_access( + key, "get", hit=False, security_event=f"access_error: {str(e)}" + ) return None @@ -130,9 +156,13 @@ def _set_cached_settings(key: str, value: Any) -> None: # Validate provider configurations before caching if "provider_config" in key and isinstance(value, dict): # Basic validation: check required fields - if not value.get("provider") or not _is_valid_provider(value.get("provider")): + if not value.get("provider") or not _is_valid_provider( + value.get("provider") + ): _log_cache_access(key, "set", security_event="invalid_config_rejected") - logger.warning(f"Rejected caching of invalid provider config for key: {_sanitize_for_log(key)}") + logger.warning( + f"Rejected caching of invalid provider config for key: {_sanitize_for_log(key)}" + ) return # Calculate integrity checksum @@ -154,7 +184,9 @@ def clear_provider_cache() -> None: cache_size_before = len(_settings_cache) _settings_cache.clear() _log_cache_access("*", "clear") - logger.debug(f"Provider configuration cache cleared ({cache_size_before} entries removed)") + logger.debug( + f"Provider configuration cache cleared ({cache_size_before} entries removed)" + ) def invalidate_provider_cache(provider: str = None) -> None: @@ -171,12 +203,18 @@ def invalidate_provider_cache(provider: str = None) -> None: cache_size_before = len(_settings_cache) _settings_cache.clear() _log_cache_access("*", "invalidate") - logger.debug(f"All provider cache entries invalidated ({cache_size_before} entries)") + logger.debug( + f"All provider cache entries invalidated ({cache_size_before} entries)" + ) else: # Validate provider name before processing if not _is_valid_provider(provider): - _log_cache_access(provider, "invalidate", security_event="invalid_provider_name") - logger.warning(f"Rejected cache invalidation for invalid provider: {_sanitize_for_log(provider)}") + _log_cache_access( + provider, "invalidate", 
security_event="invalid_provider_name" + ) + logger.warning( + f"Rejected cache invalidation for invalid provider: {_sanitize_for_log(provider)}" + ) return # Clear specific provider entries @@ -190,7 +228,9 @@ def invalidate_provider_cache(provider: str = None) -> None: _log_cache_access(key, "invalidate") safe_provider = _sanitize_for_log(provider) - logger.debug(f"Cache entries for provider '{safe_provider}' invalidated: {len(keys_to_remove)} entries removed") + logger.debug( + f"Cache entries for provider '{safe_provider}' invalidated: {len(keys_to_remove)} entries removed" + ) def get_cache_stats() -> dict[str, Any]: @@ -213,13 +253,13 @@ def get_cache_stats() -> dict[str, Any]: "expired_access_attempts": 0, "invalid_config_rejections": 0, "access_errors": 0, - "total_security_events": 0 + "total_security_events": 0, }, "access_patterns": { "recent_cache_hits": 0, "recent_cache_misses": 0, - "hit_rate": 0.0 - } + "hit_rate": 0.0, + }, } # Analyze cache entries @@ -284,13 +324,14 @@ def get_cache_security_report() -> dict[str, Any]: "timestamp": current_time, "analysis_period_hours": 1, "security_events": [], - "recommendations": [] + "recommendations": [], } # Extract security events from last hour recent_threshold = current_time - 3600 security_events = [ - access for access in _cache_access_log + access + for access in _cache_access_log if access["timestamp"] >= recent_threshold and access["security_event"] ] @@ -298,17 +339,33 @@ def get_cache_security_report() -> dict[str, Any]: # Generate recommendations based on security events if len(security_events) > 10: - report["recommendations"].append("High number of security events detected - investigate potential attacks") + report["recommendations"].append( + "High number of security events detected - investigate potential attacks" + ) - integrity_violations = sum(1 for event in security_events if "checksum_mismatch" in event.get("security_event", "")) + integrity_violations = sum( + 1 + for event in security_events + if "checksum_mismatch" in event.get("security_event", "") + ) if integrity_violations > 0: - report["recommendations"].append(f"Cache integrity violations detected ({integrity_violations}) - check for memory corruption or attacks") + report["recommendations"].append( + f"Cache integrity violations detected ({integrity_violations}) - check for memory corruption or attacks" + ) - invalid_configs = sum(1 for event in security_events if "invalid_config" in event.get("security_event", "")) + invalid_configs = sum( + 1 + for event in security_events + if "invalid_config" in event.get("security_event", "") + ) if invalid_configs > 3: - report["recommendations"].append(f"Multiple invalid configuration attempts ({invalid_configs}) - validate data sources") + report["recommendations"].append( + f"Multiple invalid configuration attempts ({invalid_configs}) - validate data sources" + ) return report + + @asynccontextmanager async def get_llm_client( provider: str | None = None, @@ -347,7 +404,9 @@ async def get_llm_client( cache_key = "rag_strategy_settings" rag_settings = _get_cached_settings(cache_key) if rag_settings is None: - rag_settings = await credential_service.get_credentials_by_category("rag_strategy") + rag_settings = await credential_service.get_credentials_by_category( + "rag_strategy" + ) _set_cached_settings(cache_key, rag_settings) logger.debug("Fetched and cached rag_strategy settings") else: @@ -367,7 +426,9 @@ async def get_llm_client( cache_key = f"provider_config_{service_type}" provider_config = 
_get_cached_settings(cache_key) if provider_config is None: - provider_config = await credential_service.get_active_provider(service_type) + provider_config = await credential_service.get_active_provider( + service_type + ) _set_cached_settings(cache_key, provider_config) logger.debug(f"Fetched and cached {service_type} provider config") else: @@ -376,7 +437,9 @@ async def get_llm_client( provider_name = provider_config["provider"] api_key = provider_config["api_key"] # For Ollama, don't use the base_url from config - let _get_optimal_ollama_instance decide - base_url = provider_config["base_url"] if provider_name != "ollama" else None + base_url = ( + provider_config["base_url"] if provider_name != "ollama" else None + ) # Comprehensive provider validation with security checks if not _is_valid_provider(provider_name): @@ -389,14 +452,98 @@ async def get_llm_client( elif len(api_key) > 500: # Reasonable API key length limit raise ValueError("API key length exceeds security limits") # Additional security: check for suspicious patterns - if any(char in api_key for char in ['\n', '\r', '\t', '\0']): + if any(char in api_key for char in ["\n", "\r", "\t", "\0"]): raise ValueError("API key contains invalid characters") # Sanitize provider name for logging - safe_provider_name = _sanitize_for_log(provider_name) if provider_name else "unknown" + safe_provider_name = ( + _sanitize_for_log(provider_name) if provider_name else "unknown" + ) logger.info(f"Creating LLM client for provider: {safe_provider_name}") - if provider_name == "openai": + if provider_name == "azure-openai": + # Azure OpenAI requires specific configuration + azure_endpoint = await credential_service.get_credential( + "AZURE_OPENAI_ENDPOINT" + ) + azure_api_version = await credential_service.get_credential( + "AZURE_OPENAI_API_VERSION" + ) + + if not azure_endpoint: + raise ValueError( + "Azure OpenAI endpoint not configured. Set AZURE_OPENAI_ENDPOINT." + ) + if not api_key: + raise ValueError( + "Azure OpenAI API key not found. Set AZURE_OPENAI_API_KEY." + ) + if not azure_api_version: + # Default to a stable API version if not specified + azure_api_version = "2024-02-15-preview" + logger.warning( + f"Azure OpenAI API version not set, using default: {azure_api_version}" + ) + + client = openai.AsyncAzureOpenAI( + api_key=api_key, + azure_endpoint=azure_endpoint, + api_version=azure_api_version, + ) + logger.info( + f"Azure OpenAI client created successfully with endpoint: {azure_endpoint}" + ) + + elif provider_name == "aws-bedrock": + # AWS Bedrock requires boto3 SDK for Converse API + try: + import boto3 + from botocore.config import Config as BotoConfig + except ImportError as import_error: + raise ValueError( + "AWS Bedrock support requires boto3. Install with: pip install boto3" + ) from import_error + + # Get AWS credentials + aws_secret_key = await credential_service.get_credential( + "AWS_SECRET_ACCESS_KEY" + ) + aws_region = await credential_service.get_credential("AWS_REGION") + + if not api_key: # api_key is AWS_ACCESS_KEY_ID + raise ValueError("AWS Access Key ID not found. Set AWS_ACCESS_KEY_ID.") + if not aws_secret_key: + raise ValueError( + "AWS Secret Access Key not found. Set AWS_SECRET_ACCESS_KEY." 
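# --- Example (editor's sketch) --------------------------------------------
# The raw Converse call the Bedrock path below builds towards, using boto3
# directly. Model ID and region are examples; credentials here are resolved
# from the environment rather than the credential service.
import boto3

bedrock = boto3.client("bedrock-runtime", region_name="us-east-1")
response = bedrock.converse(
    modelId="anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": [{"text": "ping"}]}],
)
print(response["output"]["message"]["content"][0]["text"])
# ---------------------------------------------------------------------------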
+ ) + if not aws_region: + aws_region = "us-east-1" # Default region + logger.warning(f"AWS region not set, using default: {aws_region}") + + # Create boto3 Bedrock runtime client + bedrock_config = BotoConfig( + region_name=aws_region, + signature_version="v4", + retries={"max_attempts": 3, "mode": "adaptive"}, + ) + + bedrock_client = boto3.client( + "bedrock-runtime", + aws_access_key_id=api_key, + aws_secret_access_key=aws_secret_key, + config=bedrock_config, + ) + + # Wrap boto3 client in a compatible interface + # We'll create a custom wrapper that implements the OpenAI-like async interface + from ..adapters.aws_bedrock_adapter import AWSBedrockClientAdapter + + client = AWSBedrockClientAdapter(bedrock_client, aws_region) + logger.info( + f"AWS Bedrock client created successfully in region: {aws_region}" + ) + + elif provider_name == "openai": if api_key: client = openai.AsyncOpenAI(api_key=api_key) logger.info("OpenAI client created successfully") @@ -440,7 +587,9 @@ async def get_llm_client( api_key="ollama", # Required but unused by Ollama base_url=ollama_base_url, ) - logger.info(f"Ollama client created successfully with base URL: {ollama_base_url}") + logger.info( + f"Ollama client created successfully with base URL: {ollama_base_url}" + ) elif provider_name == "google": if not api_key: @@ -448,7 +597,8 @@ async def get_llm_client( client = openai.AsyncOpenAI( api_key=api_key, - base_url=base_url or "https://generativelanguage.googleapis.com/v1beta/openai/", + base_url=base_url + or "https://generativelanguage.googleapis.com/v1beta/openai/", ) logger.info("Google Gemini client created successfully") @@ -474,14 +624,18 @@ async def get_llm_client( elif provider_name == "grok": if not api_key: - raise ValueError("Grok API key not found - set GROK_API_KEY environment variable") + raise ValueError( + "Grok API key not found - set GROK_API_KEY environment variable" + ) # Enhanced Grok API key validation (secure - no key fragments logged) key_format_valid = api_key.startswith("xai-") key_length_valid = len(api_key) >= 20 if not key_format_valid: - logger.warning("Grok API key format validation failed - should start with 'xai-'") + logger.warning( + "Grok API key format validation failed - should start with 'xai-'" + ) if not key_length_valid: logger.warning("Grok API key validation failed - insufficient length") @@ -509,7 +663,9 @@ async def get_llm_client( yield client finally: if client is not None: - safe_provider = _sanitize_for_log(provider_name) if provider_name else "unknown" + safe_provider = ( + _sanitize_for_log(provider_name) if provider_name else "unknown" + ) try: close_method = getattr(client, "aclose", None) @@ -548,24 +704,29 @@ async def get_llm_client( ) - -async def _get_optimal_ollama_instance(instance_type: str | None = None, - use_embedding_provider: bool = False, - base_url_override: str | None = None) -> str: +async def _get_optimal_ollama_instance( + instance_type: str | None = None, + use_embedding_provider: bool = False, + base_url_override: str | None = None, +) -> str: """ Get the optimal Ollama instance URL based on configuration and health status. 
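# --- Example (editor's sketch, HYPOTHETICAL) -------------------------------
# AWSBedrockClientAdapter is imported above but not shown in this diff, so
# its real interface is unknown. Purely as an illustration of the wrapping
# idea, an adapter exposing an OpenAI-style async embeddings surface over
# the synchronous boto3 client could look like this (all names invented):
import asyncio
import json
from types import SimpleNamespace
from typing import Any

class _SketchBedrockAdapter:
    def __init__(self, bedrock_client: Any, region: str):
        self._client = bedrock_client
        self.region = region
        self.embeddings = self  # expose .embeddings.create(...)

    async def create(self, model: str, input: list[str], **_: Any) -> Any:
        # Titan embeddings take one text per invoke_model call; run the
        # blocking boto3 call in a thread so async callers are not blocked.
        async def one(text: str) -> list[float]:
            resp = await asyncio.to_thread(
                self._client.invoke_model,
                modelId=model,
                body=json.dumps({"inputText": text}),
            )
            return json.loads(resp["body"].read())["embedding"]

        vectors = await asyncio.gather(*(one(t) for t in input))
        # Mimic the .data[i].embedding shape the embedding adapters expect.
        return SimpleNamespace(data=[SimpleNamespace(embedding=v) for v in vectors])
# ---------------------------------------------------------------------------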
- + Args: instance_type: Preferred instance type ('chat', 'embedding', 'both', or None) use_embedding_provider: Whether this is for embedding operations base_url_override: Override URL if specified - + Returns: Best available Ollama instance URL """ # If override URL provided, use it directly if base_url_override: - return base_url_override if base_url_override.endswith('/v1') else f"{base_url_override}/v1" + return ( + base_url_override + if base_url_override.endswith("/v1") + else f"{base_url_override}/v1" + ) try: # For now, we don't have multi-instance support, so skip to single instance config @@ -573,25 +734,39 @@ async def _get_optimal_ollama_instance(instance_type: str | None = None, logger.info("Using single instance Ollama configuration") # Get single instance configuration from RAG settings - rag_settings = await credential_service.get_credentials_by_category("rag_strategy") + rag_settings = await credential_service.get_credentials_by_category( + "rag_strategy" + ) # Check if we need embedding provider and have separate embedding URL if use_embedding_provider or instance_type == "embedding": embedding_url = rag_settings.get("OLLAMA_EMBEDDING_URL") if embedding_url: - return embedding_url if embedding_url.endswith('/v1') else f"{embedding_url}/v1" + return ( + embedding_url + if embedding_url.endswith("/v1") + else f"{embedding_url}/v1" + ) # Default to LLM base URL for chat operations - fallback_url = rag_settings.get("LLM_BASE_URL", "http://host.docker.internal:11434") - return fallback_url if fallback_url.endswith('/v1') else f"{fallback_url}/v1" + fallback_url = rag_settings.get( + "LLM_BASE_URL", "http://host.docker.internal:11434" + ) + return fallback_url if fallback_url.endswith("/v1") else f"{fallback_url}/v1" except Exception as e: logger.error(f"Error getting Ollama configuration: {e}") # Final fallback to localhost only if we can't get RAG settings try: - rag_settings = await credential_service.get_credentials_by_category("rag_strategy") - fallback_url = rag_settings.get("LLM_BASE_URL", "http://host.docker.internal:11434") - return fallback_url if fallback_url.endswith('/v1') else f"{fallback_url}/v1" + rag_settings = await credential_service.get_credentials_by_category( + "rag_strategy" + ) + fallback_url = rag_settings.get( + "LLM_BASE_URL", "http://host.docker.internal:11434" + ) + return ( + fallback_url if fallback_url.endswith("/v1") else f"{fallback_url}/v1" + ) except Exception as fallback_error: logger.error(f"Could not retrieve fallback configuration: {fallback_error}") return "http://host.docker.internal:11434/v1" @@ -616,7 +791,9 @@ async def get_embedding_model(provider: str | None = None) -> str: cache_key = "rag_strategy_settings" rag_settings = _get_cached_settings(cache_key) if rag_settings is None: - rag_settings = await credential_service.get_credentials_by_category("rag_strategy") + rag_settings = await credential_service.get_credentials_by_category( + "rag_strategy" + ) _set_cached_settings(cache_key, rag_settings) custom_model = rag_settings.get("EMBEDDING_MODEL", "") else: @@ -624,7 +801,9 @@ async def get_embedding_model(provider: str | None = None) -> str: cache_key = "provider_config_embedding" provider_config = _get_cached_settings(cache_key) if provider_config is None: - provider_config = await credential_service.get_active_provider("embedding") + provider_config = await credential_service.get_active_provider( + "embedding" + ) _set_cached_settings(cache_key, provider_config) provider_name = provider_config["provider"] custom_model = 
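# --- Example (editor's sketch) --------------------------------------------
# The repeated endswith("/v1") checks above are one tiny helper; extracting
# it would remove the duplication (illustrative refactor, not in the patch):
def ensure_v1(base_url: str) -> str:
    """Append the /v1 suffix Ollama's OpenAI-compatible API expects, once."""
    return base_url if base_url.endswith("/v1") else f"{base_url}/v1"

assert ensure_v1("http://host.docker.internal:11434") == "http://host.docker.internal:11434/v1"
assert ensure_v1("http://host.docker.internal:11434/v1") == "http://host.docker.internal:11434/v1"
# ---------------------------------------------------------------------------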
@@ -616,7 +791,9 @@ async def get_embedding_model(provider: str | None = None) -> str:
             cache_key = "rag_strategy_settings"
             rag_settings = _get_cached_settings(cache_key)
             if rag_settings is None:
-                rag_settings = await credential_service.get_credentials_by_category("rag_strategy")
+                rag_settings = await credential_service.get_credentials_by_category(
+                    "rag_strategy"
+                )
                 _set_cached_settings(cache_key, rag_settings)
             custom_model = rag_settings.get("EMBEDDING_MODEL", "")
         else:
@@ -624,7 +801,9 @@ async def get_embedding_model(provider: str | None = None) -> str:
             cache_key = "provider_config_embedding"
             provider_config = _get_cached_settings(cache_key)
             if provider_config is None:
-                provider_config = await credential_service.get_active_provider("embedding")
+                provider_config = await credential_service.get_active_provider(
+                    "embedding"
+                )
                 _set_cached_settings(cache_key, provider_config)
             provider_name = provider_config["provider"]
             custom_model = provider_config["embedding_model"]
@@ -632,21 +811,40 @@ async def get_embedding_model(provider: str | None = None) -> str:
 
     # Comprehensive provider validation for embeddings
     if not _is_valid_provider(provider_name):
         safe_provider = _sanitize_for_log(provider_name)
-        logger.warning(f"Invalid embedding provider: {safe_provider}, falling back to OpenAI")
+        logger.warning(
+            f"Invalid embedding provider: {safe_provider}, falling back to OpenAI"
+        )
         provider_name = "openai"
 
     # Use custom model if specified (with validation)
     if custom_model and len(custom_model.strip()) > 0:
         custom_model = custom_model.strip()
         # Basic model name validation (check length and basic characters)
-        if len(custom_model) <= 100 and not any(char in custom_model for char in ['\n', '\r', '\t', '\0']):
+        if len(custom_model) <= 100 and not any(
+            char in custom_model for char in ["\n", "\r", "\t", "\0"]
+        ):
             return custom_model
         else:
             safe_model = _sanitize_for_log(custom_model)
-            logger.warning(f"Invalid custom embedding model '{safe_model}' for provider '{provider_name}', using default")
+            logger.warning(
+                f"Invalid custom embedding model '{safe_model}' for provider '{provider_name}', using default"
+            )
 
     # Return provider-specific defaults
     if provider_name == "openai":
         return "text-embedding-3-small"
+    elif provider_name == "azure-openai":
+        # For Azure OpenAI, use the deployment name from settings
+        # This is required for Azure OpenAI embeddings
+        deployment = await credential_service.get_credential(
+            "AZURE_OPENAI_DEPLOYMENT"
+        )
+        if deployment:
+            return deployment
+        else:
+            logger.warning(
+                "Azure OpenAI deployment not set, using default model name"
+            )
+            return "text-embedding-3-small"
     elif provider_name == "ollama":
         # Ollama default embedding model
         return "nomic-embed-text"
@@ -665,6 +863,16 @@ async def get_embedding_model(provider: str | None = None) -> str:
         # Grok supports OpenAI and Google embedding models through their API
         # Default to OpenAI's latest for compatibility
         return "text-embedding-3-small"
+    elif provider_name == "aws-bedrock":
+        # AWS Bedrock default embedding model (Amazon Titan Embeddings)
+        bedrock_model_id = await credential_service.get_credential(
+            "AWS_BEDROCK_MODEL_ID"
+        )
+        if bedrock_model_id:
+            return bedrock_model_id
+        else:
+            # Default to Titan Embeddings V1
+            return "amazon.titan-embed-text-v1"
     else:
         # Fallback to OpenAI's model
         return "text-embedding-3-small"
@@ -714,7 +922,7 @@ def is_google_embedding_model(model: str) -> bool:
         "text-embedding-005",
         "text-multilingual-embedding-002",
         "gemini-embedding-001",
-        "multimodalembedding@001"
+        "multimodalembedding@001",
     ]
 
     return any(pattern in model_lower for pattern in google_patterns)
@@ -738,6 +946,10 @@ def is_valid_embedding_model_for_provider(model: str, provider: str) -> bool:
 
     if provider_lower == "openai":
         return is_openai_embedding_model(model)
+    elif provider_lower == "azure-openai":
+        # Azure OpenAI uses deployment names, which can be any string,
+        # so we accept any non-empty string as valid
+        return bool(model and model.strip())
     elif provider_lower == "google":
         return is_google_embedding_model(model)
     elif provider_lower in ["openrouter", "anthropic", "grok"]:
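
Note: the fallback chain in `get_embedding_model` reduces to a small decision table: a validated custom model wins, then Azure's deployment name or Bedrock's configured model ID, then a per-provider default. A standalone sketch of just the default-selection step — synchronous, with the credential lookups replaced by plain parameters, and the function name invented for illustration:

```python
def default_embedding_model(
    provider: str,
    azure_deployment: str | None = None,
    bedrock_model_id: str | None = None,
) -> str:
    """Mirror of the provider-default branch above, minus the credential I/O."""
    if provider == "azure-openai" and azure_deployment:
        return azure_deployment  # Azure embeddings address a deployment, not a model
    if provider == "aws-bedrock":
        return bedrock_model_id or "amazon.titan-embed-text-v1"
    if provider == "ollama":
        return "nomic-embed-text"
    # openai, azure-openai without a deployment, and unknown providers
    return "text-embedding-3-small"


assert default_embedding_model("azure-openai", azure_deployment="embed-prod") == "embed-prod"
assert default_embedding_model("aws-bedrock") == "amazon.titan-embed-text-v1"
assert default_embedding_model("grok") == "text-embedding-3-small"
```
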
"gemini-embedding-001", - "multimodalembedding@001" + "multimodalembedding@001", ] if provider_lower == "openai": return openai_models + elif provider_lower == "azure-openai": + # Azure OpenAI uses deployment names which are user-defined + # Return OpenAI models as reference, but actual deployment names may differ + return openai_models elif provider_lower == "google": return google_models elif provider_lower in ["openrouter", "anthropic", "grok"]: @@ -791,6 +1007,14 @@ def get_supported_embedding_models(provider: str) -> list[str]: return openai_models + google_models elif provider_lower == "ollama": return ["nomic-embed-text", "all-minilm", "mxbai-embed-large"] + elif provider_lower == "aws-bedrock": + # AWS Bedrock embedding models + return [ + "amazon.titan-embed-text-v1", + "amazon.titan-embed-text-v2:0", + "cohere.embed-english-v3", + "cohere.embed-multilingual-v3", + ] else: # For unknown providers, assume OpenAI compatibility return openai_models @@ -931,15 +1155,26 @@ def _is_reasoning_text(text: str) -> bool: # Common reasoning text patterns reasoning_indicators = [ - "okay, let's see", "let me think", "first, i need to", "looking at this", - "step by step", "analyzing", "breaking this down", "considering", - "let me work through", "i should", "thinking about", "examining" + "okay, let's see", + "let me think", + "first, i need to", + "looking at this", + "step by step", + "analyzing", + "breaking this down", + "considering", + "let me work through", + "i should", + "thinking about", + "examining", ] return any(indicator in text_lower for indicator in reasoning_indicators) -def extract_json_from_reasoning(reasoning_text: str, context_code: str = "", language: str = "") -> str: +def extract_json_from_reasoning( + reasoning_text: str, context_code: str = "", language: str = "" +) -> str: """Extract JSON content from reasoning text, with synthesis fallback.""" if not reasoning_text: return "" @@ -948,8 +1183,10 @@ def extract_json_from_reasoning(reasoning_text: str, context_code: str = "", lan import re # Try to find JSON blocks in markdown - json_block_pattern = r'```(?:json)?\s*(\{.*?\})\s*```' - json_matches = re.findall(json_block_pattern, reasoning_text, re.DOTALL | re.IGNORECASE) + json_block_pattern = r"```(?:json)?\s*(\{.*?\})\s*```" + json_matches = re.findall( + json_block_pattern, reasoning_text, re.DOTALL | re.IGNORECASE + ) for match in json_matches: try: @@ -960,14 +1197,16 @@ def extract_json_from_reasoning(reasoning_text: str, context_code: str = "", lan continue # Try to find standalone JSON objects - json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}' + json_pattern = r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}" json_matches = re.findall(json_pattern, reasoning_text, re.DOTALL) for match in json_matches: try: parsed = json.loads(match.strip()) # Ensure it has expected structure - if isinstance(parsed, dict) and any(key in parsed for key in ["example_name", "summary", "name", "title"]): + if isinstance(parsed, dict) and any( + key in parsed for key in ["example_name", "summary", "name", "title"] + ): return match.strip() except json.JSONDecodeError: continue @@ -976,7 +1215,9 @@ def extract_json_from_reasoning(reasoning_text: str, context_code: str = "", lan return synthesize_json_from_reasoning(reasoning_text, context_code, language) -def synthesize_json_from_reasoning(reasoning_text: str, context_code: str = "", language: str = "") -> str: +def synthesize_json_from_reasoning( + reasoning_text: str, context_code: str = "", language: str = "" +) -> str: """Generate JSON 
@@ -976,7 +1215,9 @@ def extract_json_from_reasoning(reasoning_text: str, context_code: str = "", lan
     return synthesize_json_from_reasoning(reasoning_text, context_code, language)
 
 
-def synthesize_json_from_reasoning(reasoning_text: str, context_code: str = "", language: str = "") -> str:
+def synthesize_json_from_reasoning(
+    reasoning_text: str, context_code: str = "", language: str = ""
+) -> str:
     """Generate JSON structure from reasoning text when no JSON is found."""
     if not reasoning_text and not context_code:
         return ""
@@ -991,44 +1232,44 @@ def synthesize_json_from_reasoning(reasoning_text: str, context_code: str = "",
 
     # Common action patterns in reasoning text and code
     action_patterns = [
-        (r'\b(?:parse|parsing|parsed)\b', 'Parse'),
-        (r'\b(?:create|creating|created)\b', 'Create'),
-        (r'\b(?:analyze|analyzing|analyzed)\b', 'Analyze'),
-        (r'\b(?:extract|extracting|extracted)\b', 'Extract'),
-        (r'\b(?:generate|generating|generated)\b', 'Generate'),
-        (r'\b(?:process|processing|processed)\b', 'Process'),
-        (r'\b(?:load|loading|loaded)\b', 'Load'),
-        (r'\b(?:handle|handling|handled)\b', 'Handle'),
-        (r'\b(?:manage|managing|managed)\b', 'Manage'),
-        (r'\b(?:build|building|built)\b', 'Build'),
-        (r'\b(?:define|defining|defined)\b', 'Define'),
-        (r'\b(?:implement|implementing|implemented)\b', 'Implement'),
-        (r'\b(?:fetch|fetching|fetched)\b', 'Fetch'),
-        (r'\b(?:connect|connecting|connected)\b', 'Connect'),
-        (r'\b(?:validate|validating|validated)\b', 'Validate'),
+        (r"\b(?:parse|parsing|parsed)\b", "Parse"),
+        (r"\b(?:create|creating|created)\b", "Create"),
+        (r"\b(?:analyze|analyzing|analyzed)\b", "Analyze"),
+        (r"\b(?:extract|extracting|extracted)\b", "Extract"),
+        (r"\b(?:generate|generating|generated)\b", "Generate"),
+        (r"\b(?:process|processing|processed)\b", "Process"),
+        (r"\b(?:load|loading|loaded)\b", "Load"),
+        (r"\b(?:handle|handling|handled)\b", "Handle"),
+        (r"\b(?:manage|managing|managed)\b", "Manage"),
+        (r"\b(?:build|building|built)\b", "Build"),
+        (r"\b(?:define|defining|defined)\b", "Define"),
+        (r"\b(?:implement|implementing|implemented)\b", "Implement"),
+        (r"\b(?:fetch|fetching|fetched)\b", "Fetch"),
+        (r"\b(?:connect|connecting|connected)\b", "Connect"),
+        (r"\b(?:validate|validating|validated)\b", "Validate"),
     ]
 
     # Technology/concept patterns
     tech_patterns = [
-        (r'\bjson\b', 'JSON'),
-        (r'\bapi\b', 'API'),
-        (r'\bfile\b', 'File'),
-        (r'\bdata\b', 'Data'),
-        (r'\bcode\b', 'Code'),
-        (r'\btext\b', 'Text'),
-        (r'\bcontent\b', 'Content'),
-        (r'\bresponse\b', 'Response'),
-        (r'\brequest\b', 'Request'),
-        (r'\bconfig\b', 'Config'),
-        (r'\bllm\b', 'LLM'),
-        (r'\bmodel\b', 'Model'),
-        (r'\bexample\b', 'Example'),
-        (r'\bcontext\b', 'Context'),
-        (r'\basync\b', 'Async'),
-        (r'\bfunction\b', 'Function'),
-        (r'\bclass\b', 'Class'),
-        (r'\bprint\b', 'Output'),
-        (r'\breturn\b', 'Return'),
+        (r"\bjson\b", "JSON"),
+        (r"\bapi\b", "API"),
+        (r"\bfile\b", "File"),
+        (r"\bdata\b", "Data"),
+        (r"\bcode\b", "Code"),
+        (r"\btext\b", "Text"),
+        (r"\bcontent\b", "Content"),
+        (r"\bresponse\b", "Response"),
+        (r"\brequest\b", "Request"),
+        (r"\bconfig\b", "Config"),
+        (r"\bllm\b", "LLM"),
+        (r"\bmodel\b", "Model"),
+        (r"\bexample\b", "Example"),
+        (r"\bcontext\b", "Context"),
+        (r"\basync\b", "Async"),
+        (r"\bfunction\b", "Function"),
+        (r"\bclass\b", "Class"),
+        (r"\bprint\b", "Output"),
+        (r"\breturn\b", "Return"),
     ]
 
     # Extract actions and technologies from combined text
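
Note: these two pattern tables drive a verb-plus-noun naming heuristic — matched action labels and technology labels are joined into a short `example_name`. A compressed, self-contained illustration with just a few of the pairs (how the full function orders and weights matches is outside the hunks shown, so this is a simplification):

```python
import re

action_patterns = [
    (r"\b(?:parse|parsing|parsed)\b", "Parse"),
    (r"\b(?:create|creating|created)\b", "Create"),
    (r"\b(?:validate|validating|validated)\b", "Validate"),
]
tech_patterns = [
    (r"\bjson\b", "JSON"),
    (r"\bconfig\b", "Config"),
    (r"\bresponse\b", "Response"),
]

text = "first, i need to parse the json config and validate the response".lower()

actions = [label for pattern, label in action_patterns if re.search(pattern, text)]
techs = [label for pattern, label in tech_patterns if re.search(pattern, text)]

# First matched action plus matched technologies, capped here at three words
# for the demo; the real function caps the name at four words.
example_name = " ".join((actions[:1] + techs)[:3])
assert example_name == "Parse JSON Config"
```
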
@@ -1061,8 +1302,12 @@ def synthesize_json_from_reasoning(reasoning_text: str, context_code: str = "",
         example_name = " ".join(example_name_words[:4])
 
     # Generate summary from reasoning content
-    reasoning_lines = reasoning_text.split('\n')
-    meaningful_lines = [line.strip() for line in reasoning_lines if line.strip() and len(line.strip()) > 10]
+    reasoning_lines = reasoning_text.split("\n")
+    meaningful_lines = [
+        line.strip()
+        for line in reasoning_lines
+        if line.strip() and len(line.strip()) > 10
+    ]
 
     if meaningful_lines:
         # Take first meaningful sentence for summary base
@@ -1084,10 +1329,7 @@ def synthesize_json_from_reasoning(reasoning_text: str, context_code: str = "",
         summary = summary[:297] + "..."
 
     # Create JSON structure
-    result = {
-        "example_name": example_name,
-        "summary": summary
-    }
+    result = {"example_name": example_name, "summary": summary}
 
     return json.dumps(result)
 
@@ -1127,19 +1369,23 @@ def prepare_chat_completion_params(model: str, params: dict) -> dict:
     # Remove custom temperature for reasoning models (they only support default temperature=1.0)
     if reasoning_model and "temperature" in updated_params:
         original_temp = updated_params.pop("temperature")
-        logger.debug(f"Removed custom temperature {original_temp} for reasoning model {model} (only supports default temperature=1.0)")
+        logger.debug(
+            f"Removed custom temperature {original_temp} for reasoning model {model} (only supports default temperature=1.0)"
+        )
 
     return updated_params
 
 
-async def get_embedding_model_with_routing(provider: str | None = None, instance_url: str | None = None) -> tuple[str, str]:
+async def get_embedding_model_with_routing(
+    provider: str | None = None, instance_url: str | None = None
+) -> tuple[str, str | None]:
     """
     Get the embedding model with intelligent routing for multi-instance setups.
-    
+
     Args:
         provider: Override provider selection
         instance_url: Specific instance URL to use
-    
+
     Returns:
         Tuple of (model_name, instance_url) for embedding operations
     """
@@ -1149,14 +1395,21 @@ async def get_embedding_model_with_routing(provider: str | None = None, instance
 
     # If specific instance URL provided, use it
     if instance_url:
-        final_url = instance_url if instance_url.endswith('/v1') else f"{instance_url}/v1"
+        final_url = (
+            instance_url if instance_url.endswith("/v1") else f"{instance_url}/v1"
+        )
         return model_name, final_url
 
     # For Ollama provider, use intelligent instance routing
-    if provider == "ollama" or (not provider and (await credential_service.get_credentials_by_category("rag_strategy")).get("LLM_PROVIDER") == "ollama"):
+    if provider == "ollama" or (
+        not provider
+        and (
+            await credential_service.get_credentials_by_category("rag_strategy")
+        ).get("LLM_PROVIDER")
+        == "ollama"
+    ):
         optimal_url = await _get_optimal_ollama_instance(
-            instance_type="embedding",
-            use_embedding_provider=True
+            instance_type="embedding", use_embedding_provider=True
         )
         return model_name, optimal_url
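
Note: the temperature handling in `prepare_chat_completion_params` is easy to misread in diff form — for reasoning models the key is removed outright rather than reset to 1.0. A standalone mimic of just that branch; `is_reasoning_model` is stubbed with a name-prefix guess here because its real detection logic sits outside the hunks shown:

```python
def mimic_prepare_params(model: str, params: dict) -> dict:
    """Stand-in for prepare_chat_completion_params, temperature branch only."""
    updated = dict(params)  # copy, never mutate the caller's dict
    # Stubbed detection: the real is_reasoning_model() lives in this module
    # but is not part of this diff; the prefixes below are an assumption.
    reasoning_model = model.startswith(("o1", "o3"))
    if reasoning_model and "temperature" in updated:
        updated.pop("temperature")  # reasoning models accept only the default 1.0
    return updated


assert "temperature" not in mimic_prepare_params("o1-mini", {"temperature": 0.2})
assert mimic_prepare_params("gpt-4o", {"temperature": 0.2})["temperature"] == 0.2
```

Because the key is dropped rather than overwritten, the provider applies its own default instead of rejecting the request.
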
@@ -1168,14 +1421,16 @@ async def get_embedding_model_with_routing(provider: str | None = None, instance
     return "text-embedding-3-small", None
 
 
-async def validate_provider_instance(provider: str, instance_url: str | None = None) -> dict[str, any]:
+async def validate_provider_instance(
+    provider: str, instance_url: str | None = None
+) -> dict[str, any]:
     """
     Validate a provider instance and return health information.
-    
+
     Args:
         provider: Provider name (openai, ollama, google, etc.)
         instance_url: Instance URL for providers that support multiple instances
-    
+
     Returns:
         Dictionary with validation results and health status
     """
@@ -1188,10 +1443,12 @@ async def validate_provider_instance(provider: str, instance_url: str | None = N
             if not instance_url:
                 instance_url = await _get_optimal_ollama_instance()
                 # Remove /v1 suffix for health checking
-                if instance_url.endswith('/v1'):
+                if instance_url.endswith("/v1"):
                     instance_url = instance_url[:-3]
 
-            health_status = await model_discovery_service.check_instance_health(instance_url)
+            health_status = await model_discovery_service.check_instance_health(
+                instance_url
+            )
 
             return {
                 "provider": provider,
@@ -1200,7 +1457,7 @@ async def validate_provider_instance(provider: str, instance_url: str | None = N
                 "response_time_ms": health_status.response_time_ms,
                 "models_available": health_status.models_available,
                 "error_message": health_status.error_message,
-                "validation_timestamp": time.time()
+                "validation_timestamp": time.time(),
             }
 
         else:
@@ -1212,7 +1469,7 @@ async def validate_provider_instance(provider: str, instance_url: str | None = N
             if provider == "openai":
                 # List models to validate API key
                 models = await client.models.list()
-                model_count = len(models.data) if hasattr(models, 'data') else 0
+                model_count = len(models.data) if hasattr(models, "data") else 0
             elif provider == "google":
                 # For Google, we can't easily list models, just validate client creation
                 model_count = 1  # Assume available if client creation succeeded
@@ -1228,7 +1485,7 @@ async def validate_provider_instance(provider: str, instance_url: str | None = N
                 "response_time_ms": response_time,
                 "models_available": model_count,
                 "error_message": None,
-                "validation_timestamp": time.time()
+                "validation_timestamp": time.time(),
             }
 
     except Exception as e:
@@ -1240,11 +1497,10 @@ async def validate_provider_instance(provider: str, instance_url: str | None = N
             "response_time_ms": None,
             "models_available": 0,
             "error_message": str(e),
-            "validation_timestamp": time.time()
+            "validation_timestamp": time.time(),
         }
 
 
-
 def requires_max_completion_tokens(model_name: str) -> bool:
     """Backward compatible alias for previous API."""
     return is_reasoning_model(model_name)
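
Note: every return path in `validate_provider_instance` — the Ollama health check, the generic client check, and the exception handler — fills the same keys, so callers need no defensive `.get()` chains. A hedged usage sketch; the import path is assumed for illustration and should be adjusted to wherever this service module actually lives:

```python
import asyncio

# Assumed module path, not confirmed by this diff.
from src.server.services.llm_provider_service import validate_provider_instance


async def main() -> None:
    result = await validate_provider_instance("ollama", "http://localhost:11434")
    # All paths populate these keys, per the dictionaries in the hunks above.
    if result["error_message"] is None:
        print(f"{result['provider']}: {result['models_available']} models available")
    else:
        print(f"{result['provider']} check failed: {result['error_message']}")


asyncio.run(main())
```
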