+
+ {/* Credential rows */}
+ {customCredentials.map((cred, index) => (
+
+ {/* Key name column */}
+
+ <input type="text" value={cred.key} onChange={(e) => updateCredential(index, 'key', e.target.value)}
+ placeholder="Enter key name"
+ className="w-full px-3 py-2 rounded-md bg-white dark:bg-gray-900 border border-gray-300 dark:border-gray-700 text-sm font-mono"
+ />
+
- {/* Credential rows */}
- {customCredentials.map((cred, index) => (
-
- {/* Key name column */}
-
+ {/* Value column with encryption toggle */}
+
- onChange={(e) => updateCredential(index, 'key', e.target.value)}
- placeholder="Enter key name"
- className="w-full px-3 py-2 rounded-md bg-white dark:bg-gray-900 border border-gray-300 dark:border-gray-700 text-sm font-mono"
+ <input
+ type={cred.showValue ? 'text' : 'password'}
+ value={cred.value}
+ onChange={(e) => updateCredential(index, 'value', e.target.value)}
+ placeholder={cred.is_encrypted && !cred.value ? 'Enter new value (encrypted)' : 'Enter value'}
+ className={`w-full px-3 py-2 pr-20 rounded-md border text-sm ${cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'
+ ? 'bg-gray-100 dark:bg-gray-800 border-gray-200 dark:border-gray-600 text-gray-500 dark:text-gray-400'
+ : 'bg-white dark:bg-gray-900 border-gray-300 dark:border-gray-700'
+ }`}
+ title={cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'
+ ? 'Click to edit this encrypted credential'
+ : undefined}
+ />
-
- {/* Value column with encryption toggle */}
-
- <input
- type={cred.showValue ? 'text' : 'password'}
- value={cred.value}
- onChange={(e) => updateCredential(index, 'value', e.target.value)}
- placeholder={cred.is_encrypted && !cred.value ? 'Enter new value (encrypted)' : 'Enter value'}
- className={`w-full px-3 py-2 pr-20 rounded-md border text-sm ${
- cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'
- ? 'bg-gray-100 dark:bg-gray-800 border-gray-200 dark:border-gray-600 text-gray-500 dark:text-gray-400'
- : 'bg-white dark:bg-gray-900 border-gray-300 dark:border-gray-700'
- }`}
- title={cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'
- ? 'Click to edit this encrypted credential'
- : undefined}
- />
-
- {/* Show/Hide value button */}
- <button
- onClick={() => toggleValueVisibility(index)}
- disabled={cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'}
- className={`absolute right-10 top-1/2 -translate-y-1/2 p-1.5 rounded transition-colors ${
- cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'
- ? 'cursor-not-allowed opacity-50'
- : 'hover:bg-gray-200 dark:hover:bg-gray-700'
+ {/* Show/Hide value button */}
+ <button
+ onClick={() => toggleValueVisibility(index)}
+ disabled={cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'}
+ className={`absolute right-10 top-1/2 -translate-y-1/2 p-1.5 rounded transition-colors ${cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'
+ ? 'cursor-not-allowed opacity-50'
+ : 'hover:bg-gray-200 dark:hover:bg-gray-700'
+ }`}
- title={
- cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'
- ? 'Edit credential to view and modify'
- : cred.showValue ? 'Hide value' : 'Show value'
- }
- >
- {cred.showValue ? (
-
- ) : (
-
- )}
- </button>
-
- {/* Encryption toggle */}
- <button
- onClick={() => toggleEncryption(index)}
- disabled={cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'}
- className={`
+ title={
+ cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'
+ ? 'Edit credential to view and modify'
+ : cred.showValue ? 'Hide value' : 'Show value'
+ }
+ >
+ {cred.showValue ? (
+
+ ) : (
+
+ )}
+ </button>
+
+ {/* Encryption toggle */}
+ <button
+ onClick={() => toggleEncryption(index)}
+ disabled={cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'}
+ className={`
absolute right-2 top-1/2 -translate-y-1/2 p-1.5 rounded transition-colors
${cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'
- ? 'cursor-not-allowed opacity-50 text-pink-400'
- : cred.is_encrypted
- ? 'text-pink-600 dark:text-pink-400 hover:bg-pink-100 dark:hover:bg-pink-900/20'
- : 'text-gray-400 hover:bg-gray-200 dark:hover:bg-gray-700'
- }
- `}
- title={
- cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'
- ? 'Edit credential to modify encryption'
- : cred.is_encrypted ? 'Encrypted - click to decrypt' : 'Not encrypted - click to encrypt'
+ ? 'cursor-not-allowed opacity-50 text-pink-400'
+ : cred.is_encrypted
+ ? 'text-pink-600 dark:text-pink-400 hover:bg-pink-100 dark:hover:bg-pink-900/20'
+ : 'text-gray-400 hover:bg-gray-200 dark:hover:bg-gray-700'
}
- >
- {cred.is_encrypted ? (
-
- ) : (
-
- )}
-
-
-
-
- {/* Actions column */}
-
- <button
- onClick={() => deleteCredential(index)}
- className="p-1 rounded text-gray-400 hover:text-red-600 transition-colors"
- title="Delete credential"
+ `}
+ title={
+ cred.isFromBackend && cred.is_encrypted && cred.value === '[ENCRYPTED]'
+ ? 'Edit credential to modify encryption'
+ : cred.is_encrypted ? 'Encrypted - click to decrypt' : 'Not encrypted - click to encrypt'
+ }
+ >
-
+ {cred.is_encrypted ? (
+
+ ) : (
+
+ )}
- ))}
-
- {/* Add credential button */}
-
+ {/* Actions column */}
+
+ <button
+ onClick={() => deleteCredential(index)}
+ className="p-1 rounded text-gray-400 hover:text-red-600 transition-colors"
+ title="Delete credential"
+ >
+
+
+
+
+ ))}
+
+
+ {/* Add credential button */}
+
+
+ {/* Save all changes button */}
+ {hasUnsavedChanges && (
+
+
+ Cancel
+
-
- Add Credential
+ {saving ? (
+ <>
+
+ Saving...
+ </>
+ ) : (
+ <>
+
+ Save All Changes
+ </>
+ )}
+ )}
- {/* Save all changes button */}
- {hasUnsavedChanges && (
-
-
- Cancel
-
-
- {saving ? (
- <>
-
- Saving...
- </>
- ) : (
- <>
-
- Save All Changes
- </>
- )}
-
-
- )}
-
- {/* Security Notice */}
-
-
-
-
-
-
- Encrypted credentials are masked after saving. Click on a masked credential to edit it - this allows you to change the value and encryption settings.
-
-
+ {/* Security Notice */}
+
+
+
+
+
+
+ Encrypted credentials are masked after saving. Click on a masked credential to edit it - this allows you to change the value and encryption settings.
+
-
+
+
);
};
\ No newline at end of file
diff --git a/archon-ui-main/src/components/settings/CloudProvidersSection.tsx b/archon-ui-main/src/components/settings/CloudProvidersSection.tsx
new file mode 100644
index 0000000000..63f98ad1c1
--- /dev/null
+++ b/archon-ui-main/src/components/settings/CloudProvidersSection.tsx
@@ -0,0 +1,463 @@
+import { useState, useEffect } from 'react';
+import { Cloud, Lock, Unlock, Eye, EyeOff, Save, Loader, AlertCircle } from 'lucide-react';
+import { Input } from '../ui/Input';
+import { Button } from '../ui/Button';
+import { Card } from '../ui/Card';
+import { credentialsService } from '../../services/credentialsService';
+import { useToast } from '../../features/shared/hooks/useToast';
+
+// Cloud provider configurations
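+// Each credential descriptor below drives one form row: `key` matches the
+// archon_settings row seeded by the migration, `encrypted` is passed through as
+// is_encrypted when saving, and `required` gates validation in saveChanges().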
+const CLOUD_PROVIDERS = {
+ azure: {
+ name: 'Azure OpenAI',
+ icon: (
+
+
+
+ ),
+ color: 'blue',
+ credentials: [
+ {
+ key: 'AZURE_OPENAI_API_KEY',
+ label: 'API Key',
+ placeholder: 'Enter your Azure OpenAI API key',
+ type: 'password',
+ encrypted: true,
+ required: true,
+ description: 'Your Azure OpenAI resource API key from Azure Portal'
+ },
+ {
+ key: 'AZURE_OPENAI_ENDPOINT',
+ label: 'Endpoint',
+ placeholder: 'https://your-resource.openai.azure.com/',
+ type: 'text',
+ encrypted: false,
+ required: true,
+ description: 'Your Azure OpenAI resource endpoint URL'
+ },
+ {
+ key: 'AZURE_OPENAI_API_VERSION',
+ label: 'API Version',
+ placeholder: '2024-02-15-preview',
+ type: 'text',
+ encrypted: false,
+ required: true,
+ description: 'API version (YYYY-MM-DD format)'
+ },
+ {
+ key: 'AZURE_OPENAI_DEPLOYMENT',
+ label: 'Deployment Name',
+ placeholder: 'my-gpt4-deployment',
+ type: 'text',
+ encrypted: false,
+ required: false,
+ description: 'Default deployment name (optional, can be set per operation)'
+ }
+ ]
+ },
+ aws: {
+ name: 'AWS Bedrock',
+ icon: (
+
+
+
+ ),
+ color: 'orange',
+ credentials: [
+ {
+ key: 'AWS_ACCESS_KEY_ID',
+ label: 'Access Key ID',
+ placeholder: 'Enter your AWS Access Key ID',
+ type: 'password',
+ encrypted: true,
+ required: true,
+ description: 'AWS IAM Access Key ID'
+ },
+ {
+ key: 'AWS_SECRET_ACCESS_KEY',
+ label: 'Secret Access Key',
+ placeholder: 'Enter your AWS Secret Access Key',
+ type: 'password',
+ encrypted: true,
+ required: true,
+ description: 'AWS IAM Secret Access Key'
+ },
+ {
+ key: 'AWS_REGION',
+ label: 'Region',
+ placeholder: 'us-east-1',
+ type: 'text',
+ encrypted: false,
+ required: true,
+ description: 'AWS region for Bedrock (e.g., us-east-1, us-west-2)'
+ },
+ {
+ key: 'AWS_BEDROCK_MODEL_ID',
+ label: 'Model ID',
+ placeholder: 'anthropic.claude-3-sonnet-20240229-v1:0',
+ type: 'text',
+ encrypted: false,
+ required: false,
+ description: 'Default Bedrock model ID (optional)'
+ }
+ ]
+ }
+};
+
+interface CredentialValue {
+ value: string;
+ showValue: boolean;
+ hasChanges: boolean;
+ originalValue: string;
+ isFromBackend: boolean;
+}
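+
+// The backend returns the literal '[ENCRYPTED]' for stored secrets; paired with
+// `isFromBackend`, that sentinel marks a credential as masked until the user edits it.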
+
+type ProviderKey = keyof typeof CLOUD_PROVIDERS;
+
+export const CloudProvidersSection = () => {
+ const [selectedProvider, setSelectedProvider] = useState<ProviderKey>('azure');
+ const [credentialValues, setCredentialValues] = useState<Record<string, CredentialValue>>({});
+ const [loading, setLoading] = useState(true);
+ const [saving, setSaving] = useState(false);
+ const [hasUnsavedChanges, setHasUnsavedChanges] = useState(false);
+
+ const { showToast } = useToast();
+
+ useEffect(() => {
+ loadCredentials();
+ }, []);
+
+ useEffect(() => {
+ // Check if there are unsaved changes
+ const hasChanges = Object.values(credentialValues).some(cred => cred.hasChanges);
+ setHasUnsavedChanges(hasChanges);
+ }, [credentialValues]);
+
+ const loadCredentials = async () => {
+ try {
+ setLoading(true);
+
+ // Load all cloud provider credentials
+ const allCredentials = await credentialsService.getAllCredentials();
+
+ const values: Record<string, CredentialValue> = {};
+
+ // Initialize all credential values from backend
+ Object.values(CLOUD_PROVIDERS).forEach(provider => {
+ provider.credentials.forEach(cred => {
+ const backendCred = allCredentials.find(c => c.key === cred.key);
+ const isEncryptedFromBackend = backendCred?.is_encrypted && backendCred.value === '[ENCRYPTED]';
+
+ values[cred.key] = {
+ value: backendCred?.value || '',
+ showValue: false,
+ hasChanges: false,
+ originalValue: backendCred?.value || '',
+ isFromBackend: !!backendCred && !backendCred.isNew
+ };
+ });
+ });
+
+ setCredentialValues(values);
+ } catch (err) {
+ console.error('Failed to load cloud credentials:', err);
+ showToast('Failed to load cloud provider credentials', 'error');
+ } finally {
+ setLoading(false);
+ }
+ };
+
+ const updateCredentialValue = (key: string, value: string) => {
+ setCredentialValues(prev => {
+ const current = prev[key];
+ const updated = {
+ ...current,
+ value,
+ hasChanges: true
+ };
+
+ // If editing an encrypted credential from backend, make it editable
+ if (current.isFromBackend && current.value === '[ENCRYPTED]' && value !== '[ENCRYPTED]') {
+ updated.isFromBackend = false;
+ updated.showValue = false;
+ if (value === '') {
+ // If they click to edit but haven't entered anything, clear the placeholder
+ updated.value = '';
+ }
+ }
+
+ return {
+ ...prev,
+ [key]: updated
+ };
+ });
+ };
+
+ const toggleValueVisibility = (key: string) => {
+ const cred = credentialValues[key];
+ if (cred.isFromBackend && cred.value === '[ENCRYPTED]') {
+ showToast('Encrypted credentials cannot be viewed. Edit to make changes.', 'warning');
+ return;
+ }
+
+ setCredentialValues(prev => ({
+ ...prev,
+ [key]: {
+ ...prev[key],
+ showValue: !prev[key].showValue
+ }
+ }));
+ };
+
+ const saveChanges = async () => {
+ setSaving(true);
+ let hasErrors = false;
+ const provider = CLOUD_PROVIDERS[selectedProvider];
+
+ try {
+ for (const credConfig of provider.credentials) {
+ const credValue = credentialValues[credConfig.key];
+
+ if (!credValue.hasChanges) continue;
+
+ // Validate required fields
+ if (credConfig.required && !credValue.value) {
+ showToast(`${credConfig.label} is required`, 'error');
+ hasErrors = true;
+ continue;
+ }
+
+ // Skip if value is still [ENCRYPTED] (unchanged)
+ if (credValue.value === '[ENCRYPTED]') {
+ continue;
+ }
+
+ try {
+ // Check if credential exists
+ const existing = await credentialsService.getCredential(credConfig.key).catch(() => null);
+
+ if (existing) {
+ // Update existing
+ await credentialsService.updateCredential({
+ key: credConfig.key,
+ value: credValue.value,
+ is_encrypted: credConfig.encrypted,
+ category: 'cloud_providers',
+ description: credConfig.description
+ });
+ } else {
+ // Create new
+ await credentialsService.createCredential({
+ key: credConfig.key,
+ value: credValue.value,
+ is_encrypted: credConfig.encrypted,
+ category: 'cloud_providers',
+ description: credConfig.description
+ });
+ }
+ } catch (err) {
+ console.error(`Failed to save ${credConfig.key}:`, err);
+ showToast(`Failed to save ${credConfig.label}`, 'error');
+ hasErrors = true;
+ }
+ }
+
+ if (!hasErrors) {
+ showToast(`${provider.name} credentials saved successfully!`, 'success');
+ await loadCredentials(); // Reload to get fresh data
+ }
+ } finally {
+ setSaving(false);
+ }
+ };
+
+ const discardChanges = () => {
+ loadCredentials();
+ };
+
+ if (loading) {
+ return (
+
+
+
+
+
+ );
+ }
+
+ const currentProvider = CLOUD_PROVIDERS[selectedProvider];
+
+ return (
+
+
+ {/* Description */}
+
+
+
+
+ Configure cloud-based AI services. These providers offer enterprise-grade security, compliance, and regional availability.
+
+
+
+
+ {/* Provider Selector */}
+
+ {(Object.keys(CLOUD_PROVIDERS) as ProviderKey[]).map(providerKey => {
+ const provider = CLOUD_PROVIDERS[providerKey];
+ const isSelected = selectedProvider === providerKey;
+
+ return (
+ <button
+ key={providerKey}
+ onClick={() => setSelectedProvider(providerKey)}
+ className={`
+ flex-1 flex items-center justify-center gap-2 px-4 py-2.5 rounded-md
+ font-medium text-sm transition-all duration-200
+ ${isSelected
+ ? `bg-white dark:bg-gray-900 text-${provider.color}-600 shadow-sm`
+ : 'text-gray-600 dark:text-gray-400 hover:text-gray-900 dark:hover:text-gray-200'
+ }
+ `}
+ >
+
+ {provider.icon}
+
+ {provider.name}
+
+ );
+ })}
+
+
+ {/* Credentials Form */}
+
+ {currentProvider.credentials.map(credConfig => {
+ const credValue = credentialValues[credConfig.key] || {
+ value: '',
+ showValue: false,
+ hasChanges: false,
+ originalValue: '',
+ isFromBackend: false
+ };
+
+ const isEncrypted = credConfig.encrypted;
+ const isFromBackend = credValue.isFromBackend && credValue.value === '[ENCRYPTED]';
+
+ return (
+
+
+ {credConfig.label}
+ {credConfig.required && <span className="text-red-500">*</span>}
+
+
+
+
+ <input
+ type={credConfig.type === 'password' && !credValue.showValue ? 'password' : 'text'}
+ value={credValue.value}
+ onChange={(e) => updateCredentialValue(credConfig.key, e.target.value)}
+ placeholder={isFromBackend ? 'Click to edit encrypted value' : credConfig.placeholder}
+ className={`
+ w-full px-4 py-2 pr-24 rounded-md border text-sm
+ ${isFromBackend
+ ? 'bg-gray-100 dark:bg-gray-800 border-gray-200 dark:border-gray-600 text-gray-500 dark:text-gray-400'
+ : 'bg-white dark:bg-gray-900 border-gray-300 dark:border-gray-700'
+ }
+ focus:outline-none focus:ring-2 focus:ring-${currentProvider.color}-500 focus:border-transparent
+ `}
+ title={isFromBackend ? 'Click to edit this encrypted credential' : undefined}
+ />
+
+ {/* Show/Hide button for password fields */}
+ {credConfig.type === 'password' && (
+
+ <button
+ onClick={() => toggleValueVisibility(credConfig.key)}
+ disabled={isFromBackend}
+ className={`
+ absolute right-14 top-1/2 -translate-y-1/2 p-1.5 rounded transition-colors
+ ${isFromBackend
+ ? 'cursor-not-allowed opacity-50'
+ : 'hover:bg-gray-200 dark:hover:bg-gray-700'
+ }
+ `}
+ title={isFromBackend ? 'Edit to view' : credValue.showValue ? 'Hide value' : 'Show value'}
+ >
+ {credValue.showValue ? (
+
+ ) : (
+
+ )}
+ </button>
+ )}
+
+ {/* Encryption indicator */}
+ {isEncrypted && (
+
+
+
+ )}
+
+
+ {credConfig.description && (
+
+ {credConfig.description}
+
+ )}
+
+ );
+ })}
+
+
+ {/* Save Button */}
+ {hasUnsavedChanges && (
+
+ <Button onClick={discardChanges}>
+ Cancel
+ </Button>
+ <Button onClick={saveChanges} disabled={saving}>
+ {saving ? (
+ <>
+
+ Saving...
+ </>
+ ) : (
+ <>
+
+ Save Changes
+ </>
+ )}
+ </Button>
+
+ )}
+
+ {/* Info Box */}
+
+
+
+
+ Cloud Provider Setup:
+
+ Encrypted credentials are masked after saving
+ Set LLM_PROVIDER to 'azure-openai' or 'aws-bedrock' in RAG Settings to use these providers
+ Regional availability and pricing may vary by provider
+
+
+
+
+
+ );
+};
diff --git a/archon-ui-main/src/pages/SettingsPage.tsx b/archon-ui-main/src/pages/SettingsPage.tsx
index 351366161d..f93b83579e 100644
--- a/archon-ui-main/src/pages/SettingsPage.tsx
+++ b/archon-ui-main/src/pages/SettingsPage.tsx
@@ -12,6 +12,7 @@ import {
Bug,
Info,
Database,
+ Cloud,
} from "lucide-react";
import { motion, AnimatePresence } from "framer-motion";
import { useToast } from "../features/shared/hooks/useToast";
@@ -19,6 +20,7 @@ import { useSettings } from "../contexts/SettingsContext";
import { useStaggeredEntrance } from "../hooks/useStaggeredEntrance";
import { FeaturesSection } from "../components/settings/FeaturesSection";
import { APIKeysSection } from "../components/settings/APIKeysSection";
+import { CloudProvidersSection } from "../components/settings/CloudProvidersSection";
import { RAGSettings } from "../components/settings/RAGSettings";
import { CodeExtractionSettings } from "../components/settings/CodeExtractionSettings";
import { IDEGlobalRules } from "../components/settings/IDEGlobalRules";
@@ -199,6 +201,17 @@ export const SettingsPage = () => {
+
+
+
+
+
+-- Cloud provider credentials for Azure OpenAI and AWS Bedrock (configured via Settings > Cloud Providers in the UI):
+INSERT INTO archon_settings (key, encrypted_value, is_encrypted, category, description) VALUES
+('AZURE_OPENAI_API_KEY', NULL, true, 'cloud_providers', 'Azure OpenAI API key from Azure Portal. Configure via Settings > Cloud Providers.'),
+('AZURE_OPENAI_ENDPOINT', NULL, false, 'cloud_providers', 'Azure OpenAI resource endpoint URL (e.g., https://your-resource.openai.azure.com/)'),
+('AZURE_OPENAI_API_VERSION', NULL, false, 'cloud_providers', 'Azure OpenAI API version in YYYY-MM-DD format (e.g., 2024-02-15-preview)'),
+('AZURE_OPENAI_DEPLOYMENT', NULL, false, 'cloud_providers', 'Azure OpenAI deployment name (optional, used for default model)'),
+('AWS_ACCESS_KEY_ID', NULL, true, 'cloud_providers', 'AWS IAM Access Key ID for Bedrock. Configure via Settings > Cloud Providers.'),
+('AWS_SECRET_ACCESS_KEY', NULL, true, 'cloud_providers', 'AWS IAM Secret Access Key for Bedrock. Configure via Settings > Cloud Providers.'),
+('AWS_REGION', NULL, false, 'cloud_providers', 'AWS region for Bedrock service (e.g., us-east-1, us-west-2)'),
+('AWS_BEDROCK_MODEL_ID', NULL, false, 'cloud_providers', 'Default AWS Bedrock model ID (e.g., anthropic.claude-3-sonnet-20240229-v1:0)')
+ON CONFLICT (key) DO NOTHING;
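+
+-- Optional sanity check after applying the migration (not required):
+--   SELECT key, is_encrypted FROM archon_settings WHERE category = 'cloud_providers';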
+
-- Code Extraction Settings Migration
-- Adds configurable settings for the code extraction service
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 128e433290..694d0cbf93 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -45,6 +45,7 @@ server = [
"asyncpg>=0.29.0",
# AI/ML libraries
"openai==1.71.0",
+ "boto3>=1.34.0", # AWS SDK for Bedrock
# Document processing
"pypdf2>=3.0.1",
"pdfplumber>=0.11.6",
@@ -123,6 +124,7 @@ all = [
"supabase==2.15.1",
"asyncpg>=0.29.0",
"openai==1.71.0",
+ "boto3>=1.34.0",
"pypdf2>=3.0.1",
"pdfplumber>=0.11.6",
"python-docx>=1.1.2",
diff --git a/python/src/server/adapters/__init__.py b/python/src/server/adapters/__init__.py
new file mode 100644
index 0000000000..247e00de25
--- /dev/null
+++ b/python/src/server/adapters/__init__.py
@@ -0,0 +1,7 @@
+"""
+Adapters for external LLM providers that don't have OpenAI-compatible APIs.
+"""
+
+from .aws_bedrock_adapter import AWSBedrockClientAdapter
+
+__all__ = ["AWSBedrockClientAdapter"]
diff --git a/python/src/server/adapters/aws_bedrock_adapter.py b/python/src/server/adapters/aws_bedrock_adapter.py
new file mode 100644
index 0000000000..a30411da60
--- /dev/null
+++ b/python/src/server/adapters/aws_bedrock_adapter.py
@@ -0,0 +1,374 @@
+"""
+AWS Bedrock Client Adapter
+
+Provides an OpenAI-compatible interface for AWS Bedrock's Converse API.
+This adapter wraps the boto3 bedrock-runtime client to work with our existing
+LLM provider infrastructure.
+"""
+
+import asyncio
+import json
+from typing import Any
+
+from ..config.logfire_config import get_logger
+
+logger = get_logger(__name__)
+
+
+class AWSBedrockClientAdapter:
+ """
+ Adapter to make AWS Bedrock Converse API compatible with OpenAI-style async clients.
+
+ This adapter implements the minimum interface needed for our LLM operations,
+ translating between OpenAI's chat completion format and AWS Bedrock's Converse API.
+ """
+
+ def __init__(self, bedrock_client: Any, region: str):
+ """
+ Initialize the AWS Bedrock adapter.
+
+ Args:
+ bedrock_client: boto3 bedrock-runtime client instance
+ region: AWS region for the Bedrock service
+ """
+ self.bedrock_client = bedrock_client
+ self.region = region
+ self._executor = None # Will be created on first use
+
+ def _get_executor(self):
+ """Get or create thread pool executor for async operations."""
+ if self._executor is None:
+ import concurrent.futures
+
+ self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
+ return self._executor
+
+ async def aclose(self):
+ """Close the adapter and cleanup resources."""
+ if self._executor is not None:
+ self._executor.shutdown(wait=True)
+ self._executor = None
+ logger.debug("AWS Bedrock adapter closed")
+
+ async def close(self):
+ """Alias for aclose() for compatibility."""
+ await self.aclose()
+
+ @property
+ def chat(self):
+ """Return chat completions interface."""
+ return self
+
+ @property
+ def completions(self):
+ """Return completions interface."""
+ return ChatCompletions(self)
+
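+    # Note: `chat` returning self plus the `completions` property mirror the
+    # OpenAI client call shape, so callers can write:
+    #   await adapter.chat.completions.create(model=..., messages=[...])
+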
+ def _openai_to_bedrock_messages(self, openai_messages: list[dict]) -> tuple[list[dict], str | None]:
+ """
+ Convert OpenAI message format to Bedrock Converse API format.
+
+ OpenAI format:
+ [{"role": "system", "content": "..."},
+ {"role": "user", "content": "..."},
+ {"role": "assistant", "content": "..."}]
+
+ Bedrock format:
+ [{"role": "user", "content": [{"text": "..."}]},
+ {"role": "assistant", "content": [{"text": "..."}]}]
+
+ Note: Bedrock handles system prompts separately, not in messages.
+ """
+ bedrock_messages = []
+ system_prompt = None
+
+ for msg in openai_messages:
+ role = msg.get("role", "user")
+ content = msg.get("content", "")
+
+ if role == "system":
+ # Bedrock handles system prompts separately
+ system_prompt = content
+ continue
+
+ # Convert role (Bedrock uses "user" and "assistant")
+ bedrock_role = "assistant" if role == "assistant" else "user"
+
+ # Convert content to Bedrock format
+ if isinstance(content, str):
+ bedrock_content = [{"text": content}]
+ elif isinstance(content, list):
+ # Handle multimodal content if needed
+ bedrock_content = []
+ for item in content:
+ if isinstance(item, dict):
+ if item.get("type") == "text":
+ bedrock_content.append({"text": item.get("text", "")})
+ # Add support for images if needed in future
+ elif isinstance(item, str):
+ bedrock_content.append({"text": item})
+ else:
+ bedrock_content = [{"text": str(content)}]
+
+ bedrock_messages.append({"role": bedrock_role, "content": bedrock_content})
+
+ return bedrock_messages, system_prompt
+
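+    # Worked example of the conversion above:
+    #   [{"role": "system", "content": "Be brief"}, {"role": "user", "content": "Hi"}]
+    #   -> ([{"role": "user", "content": [{"text": "Hi"}]}], "Be brief")
+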
+ def _bedrock_to_openai_response(self, bedrock_response: dict, model: str) -> dict:
+ """
+ Convert Bedrock Converse API response to OpenAI format.
+
+ Bedrock response:
+ {
+ "output": {
+ "message": {
+ "role": "assistant",
+ "content": [{"text": "..."}]
+ }
+ },
+ "stopReason": "end_turn",
+ "usage": {
+ "inputTokens": 10,
+ "outputTokens": 20,
+ "totalTokens": 30
+ }
+ }
+
+ OpenAI format:
+ {
+ "id": "chatcmpl-xxx",
+ "object": "chat.completion",
+ "created": 1234567890,
+ "model": "model-name",
+ "choices": [{
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": "..."
+ },
+ "finish_reason": "stop"
+ }],
+ "usage": {
+ "prompt_tokens": 10,
+ "completion_tokens": 20,
+ "total_tokens": 30
+ }
+ }
+ """
+ import time
+ import uuid
+
+ # Extract message content
+ output_message = bedrock_response.get("output", {}).get("message", {})
+ content_blocks = output_message.get("content", [])
+
+ # Combine text blocks
+ content = " ".join(
+ block.get("text", "") for block in content_blocks if "text" in block
+ )
+
+ # Map stop reason
+ stop_reason_map = {
+ "end_turn": "stop",
+ "max_tokens": "length",
+ "stop_sequence": "stop",
+ "content_filtered": "content_filter",
+ }
+ finish_reason = stop_reason_map.get(
+ bedrock_response.get("stopReason", "end_turn"), "stop"
+ )
+
+ # Extract usage
+ bedrock_usage = bedrock_response.get("usage", {})
+
+ return {
+ "id": f"chatcmpl-{uuid.uuid4().hex[:24]}",
+ "object": "chat.completion",
+ "created": int(time.time()),
+ "model": model,
+ "choices": [
+ {
+ "index": 0,
+ "message": {"role": "assistant", "content": content},
+ "finish_reason": finish_reason,
+ }
+ ],
+ "usage": {
+ "prompt_tokens": bedrock_usage.get("inputTokens", 0),
+ "completion_tokens": bedrock_usage.get("outputTokens", 0),
+ "total_tokens": bedrock_usage.get("totalTokens", 0),
+ },
+ }
+
+ async def create(
+ self,
+ model: str,
+ messages: list[dict],
+ temperature: float = 1.0,
+ max_tokens: int | None = None,
+ max_completion_tokens: int | None = None,
+ stream: bool = False,
+ **kwargs: Any,
+ ) -> dict:
+ """
+ Create a chat completion using AWS Bedrock Converse API.
+
+ Args:
+ model: Bedrock model ID (e.g., "anthropic.claude-3-sonnet-20240229-v1:0")
+ messages: List of messages in OpenAI format
+ temperature: Sampling temperature (0.0 to 1.0)
+ max_tokens: Maximum tokens to generate (OpenAI style)
+ max_completion_tokens: Maximum tokens to generate (new OpenAI style)
+ stream: Whether to stream the response (not supported yet)
+ **kwargs: Additional parameters
+
+ Returns:
+ Chat completion response in OpenAI format
+ """
+ if stream:
+ raise NotImplementedError(
+ "Streaming is not yet supported for AWS Bedrock adapter"
+ )
+
+ # Convert messages to Bedrock format
+ bedrock_messages, system_prompt = self._openai_to_bedrock_messages(messages)
+
+ # Determine max tokens (prefer max_completion_tokens if provided)
+ max_tokens_value = max_completion_tokens or max_tokens or 2048
+
+ # Build inference configuration
+ inference_config = {
+ "temperature": temperature,
+ "maxTokens": max_tokens_value,
+ }
+
+ # Add top_p if provided
+ if "top_p" in kwargs:
+ inference_config["topP"] = kwargs["top_p"]
+
+ # Build converse request
+ converse_params = {
+ "modelId": model,
+ "messages": bedrock_messages,
+ "inferenceConfig": inference_config,
+ }
+
+ # Add system prompt if present
+ if system_prompt:
+ converse_params["system"] = [{"text": system_prompt}]
+
+ try:
+ # Call Bedrock Converse API asynchronously
+ loop = asyncio.get_running_loop()
+ executor = self._get_executor()
+
+ bedrock_response = await loop.run_in_executor(
+ executor, lambda: self.bedrock_client.converse(**converse_params)
+ )
+
+ # Convert response to OpenAI format
+ openai_response = self._bedrock_to_openai_response(bedrock_response, model)
+
+ logger.debug(
+ f"AWS Bedrock completion successful. Tokens used: {openai_response['usage']['total_tokens']}"
+ )
+
+ return openai_response
+
+ except Exception as e:
+ logger.error(f"Error calling AWS Bedrock Converse API: {e}")
+ raise
+
+
+class ChatCompletions:
+ """Chat completions interface wrapper."""
+
+ def __init__(self, adapter: AWSBedrockClientAdapter):
+ self.adapter = adapter
+
+ async def create(self, *args, **kwargs):
+ """Create a chat completion."""
+ return await self.adapter.create(*args, **kwargs)
+
+
+# Bedrock embedding adapter for future use
+class AWSBedrockEmbeddingAdapter:
+ """
+ Adapter for AWS Bedrock embeddings.
+
+ AWS Bedrock supports embedding models like Amazon Titan Embeddings.
+ This adapter will be used by the embedding service.
+ """
+
+ def __init__(self, bedrock_client: Any, region: str):
+ """
+ Initialize the AWS Bedrock embedding adapter.
+
+ Args:
+ bedrock_client: boto3 bedrock-runtime client instance
+ region: AWS region for the Bedrock service
+ """
+ self.bedrock_client = bedrock_client
+ self.region = region
+ self._executor = None
+
+ def _get_executor(self):
+ """Get or create thread pool executor for async operations."""
+ if self._executor is None:
+ import concurrent.futures
+
+ self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
+ return self._executor
+
+ async def aclose(self):
+ """Close the adapter and cleanup resources."""
+ if self._executor is not None:
+ self._executor.shutdown(wait=True)
+ self._executor = None
+ logger.debug("AWS Bedrock embedding adapter closed")
+
+ async def generate_embeddings(
+ self, texts: list[str], model: str = "amazon.titan-embed-text-v1"
+ ) -> list[list[float]]:
+ """
+ Generate embeddings using AWS Bedrock.
+
+ Args:
+ texts: List of text strings to embed
+ model: Bedrock embedding model ID
+
+ Returns:
+ List of embedding vectors
+ """
+ embeddings = []
+
+ try:
+ loop = asyncio.get_running_loop()
+ executor = self._get_executor()
+
+ for text in texts:
+ # Prepare request body for Titan Embeddings
+ request_body = json.dumps({"inputText": text})
+
+ # Call invoke_model asynchronously (bind request_body in lambda)
+ response = await loop.run_in_executor(
+ executor,
+ lambda body=request_body: self.bedrock_client.invoke_model(
+ modelId=model,
+ contentType="application/json",
+ accept="application/json",
+ body=body,
+ ),
+ )
+
+ # Parse response
+ response_body = json.loads(response["body"].read())
+ embedding = response_body.get("embedding", [])
+ embeddings.append(embedding)
+
+ logger.debug(f"Generated {len(embeddings)} embeddings using AWS Bedrock")
+ return embeddings
+
+ except Exception as e:
+ logger.error(f"Error generating embeddings with AWS Bedrock: {e}")
+ raise
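+
+
+# ---------------------------------------------------------------------------
+# Illustrative usage sketch (not exercised by this module; shown for clarity).
+# Assumes boto3 can find valid AWS credentials and that the example model ID is
+# enabled for your Bedrock account.
+#
+#   import asyncio
+#   import boto3
+#
+#   async def demo():
+#       client = boto3.client("bedrock-runtime", region_name="us-east-1")
+#       adapter = AWSBedrockClientAdapter(client, region="us-east-1")
+#       try:
+#           response = await adapter.chat.completions.create(
+#               model="anthropic.claude-3-sonnet-20240229-v1:0",
+#               messages=[{"role": "user", "content": "Say hello in one word."}],
+#               max_tokens=32,
+#           )
+#           print(response["choices"][0]["message"]["content"])
+#       finally:
+#           await adapter.aclose()
+#
+#   asyncio.run(demo())
+# ---------------------------------------------------------------------------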
diff --git a/python/src/server/api_routes/knowledge_api.py b/python/src/server/api_routes/knowledge_api.py
index 052f75216e..620b7c2350 100644
--- a/python/src/server/api_routes/knowledge_api.py
+++ b/python/src/server/api_routes/knowledge_api.py
@@ -26,7 +26,11 @@
from ..services.crawling import CrawlingService
from ..services.credential_service import credential_service
from ..services.embeddings.provider_error_adapters import ProviderErrorFactory
-from ..services.knowledge import DatabaseMetricsService, KnowledgeItemService, KnowledgeSummaryService
+from ..services.knowledge import (
+ DatabaseMetricsService,
+ KnowledgeItemService,
+ KnowledgeSummaryService,
+)
from ..services.search.rag_service import RAGService
from ..services.storage import DocumentStorageService
from ..utils import get_supabase_client
@@ -50,39 +54,50 @@
#
# The hardcoded limit of 3 protects the server from being overwhelmed by multiple users
# starting crawls at the same time. Each crawl can still process many pages in parallel.
-CONCURRENT_CRAWL_LIMIT = 3 # Max simultaneous crawl operations (protects server resources)
+CONCURRENT_CRAWL_LIMIT = (
+ 3 # Max simultaneous crawl operations (protects server resources)
+)
crawl_semaphore = asyncio.Semaphore(CONCURRENT_CRAWL_LIMIT)
# Track active async crawl tasks for cancellation support
active_crawl_tasks: dict[str, asyncio.Task] = {}
-
-
async def _validate_provider_api_key(provider: str = None) -> None:
"""Validate LLM provider API key before starting operations."""
logger.info("🔑 Starting API key validation...")
-
+
try:
# Basic provider validation
if not provider:
provider = "openai"
else:
- # Simple provider validation
- allowed_providers = {"openai", "ollama", "google", "openrouter", "anthropic", "grok"}
+ # Simple provider validation - include cloud providers
+ allowed_providers = {
+ "openai",
+ "azure-openai", # Azure OpenAI cloud provider
+ "aws-bedrock", # AWS Bedrock cloud provider
+ "ollama",
+ "google",
+ "openrouter",
+ "anthropic",
+ "grok",
+ }
if provider not in allowed_providers:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid provider name",
"message": f"Provider '{provider}' not supported",
- "error_type": "validation_error"
- }
+ "error_type": "validation_error",
+ },
)
# Basic sanitization for logging
safe_provider = provider[:20] # Limit length
- logger.info(f"🔑 Testing {safe_provider.title()} API key with minimal embedding request...")
+ logger.info(
+ f"🔑 Testing {safe_provider.title()} API key with minimal embedding request..."
+ )
try:
# Test API key with minimal embedding request using provider-scoped configuration
@@ -117,7 +132,7 @@ async def _validate_provider_api_key(provider: str = None) -> None:
"provider": provider,
},
)
-
+
logger.info(f"✅ {provider.title()} API key validation successful")
except HTTPException:
@@ -127,9 +142,13 @@ async def _validate_provider_api_key(provider: str = None) -> None:
except Exception as e:
# Sanitize error before logging to prevent sensitive data exposure
error_str = str(e)
- sanitized_error = ProviderErrorFactory.sanitize_provider_error(error_str, provider or "openai")
- logger.error(f"❌ Caught exception during API key validation: {sanitized_error}")
-
+ sanitized_error = ProviderErrorFactory.sanitize_provider_error(
+ error_str, provider or "openai"
+ )
+ logger.error(
+ f"❌ Caught exception during API key validation: {sanitized_error}"
+ )
+
# Always fail for any exception during validation - better safe than sorry
logger.error("🚨 API key validation failed - blocking crawl operation")
raise HTTPException(
@@ -138,8 +157,8 @@ async def _validate_provider_api_key(provider: str = None) -> None:
"error": "Invalid API key",
"message": f"Please verify your {(provider or 'openai').title()} API key in Settings before starting a crawl.",
"error_type": "authentication_failed",
- "provider": provider or "openai"
- }
+ "provider": provider or "openai",
+ },
) from None
@@ -183,7 +202,7 @@ class RagQueryRequest(BaseModel):
@router.get("/crawl-progress/{progress_id}")
async def get_crawl_progress(progress_id: str):
"""Get crawl progress for polling.
-
+
Returns the current state of a crawl operation.
Frontend should poll this endpoint to track crawl progress.
"""
@@ -193,11 +212,16 @@ async def get_crawl_progress(progress_id: str):
# Get progress from the tracker's in-memory storage
progress_data = ProgressTracker.get_progress(progress_id)
- safe_logfire_info(f"Crawl progress requested | progress_id={progress_id} | found={progress_data is not None}")
+ safe_logfire_info(
+ f"Crawl progress requested | progress_id={progress_id} | found={progress_data is not None}"
+ )
if not progress_data:
# Return 404 if no progress exists - this is correct behavior
- raise HTTPException(status_code=404, detail={"error": f"No progress found for ID: {progress_id}"})
+ raise HTTPException(
+ status_code=404,
+ detail={"error": f"No progress found for ID: {progress_id}"},
+ )
# Ensure we have the progress_id in the data
progress_data["progress_id"] = progress_id
@@ -219,7 +243,9 @@ async def get_crawl_progress(progress_id: str):
return response_data
except Exception as e:
- safe_logfire_error(f"Failed to get crawl progress | error={str(e)} | progress_id={progress_id}")
+ safe_logfire_error(
+ f"Failed to get crawl progress | error={str(e)} | progress_id={progress_id}"
+ )
raise HTTPException(status_code=500, detail={"error": str(e)})
@@ -237,7 +263,10 @@ async def get_knowledge_sources():
@router.get("/knowledge-items")
async def get_knowledge_items(
- page: int = 1, per_page: int = 20, knowledge_type: str | None = None, search: str | None = None
+ page: int = 1,
+ per_page: int = 20,
+ knowledge_type: str | None = None,
+ search: str | None = None,
):
"""Get knowledge items with pagination and filtering."""
try:
@@ -257,16 +286,19 @@ async def get_knowledge_items(
@router.get("/knowledge-items/summary")
async def get_knowledge_items_summary(
- page: int = 1, per_page: int = 20, knowledge_type: str | None = None, search: str | None = None
+ page: int = 1,
+ per_page: int = 20,
+ knowledge_type: str | None = None,
+ search: str | None = None,
):
"""
Get lightweight summaries of knowledge items.
-
+
Returns minimal data optimized for frequent polling:
- Only counts, no actual document/code content
- Basic metadata for display
- Efficient batch queries
-
+
Use this endpoint for card displays and frequent polling.
"""
try:
@@ -298,9 +330,13 @@ async def update_knowledge_item(source_id: str, updates: dict):
return result
else:
if "not found" in result.get("error", "").lower():
- raise HTTPException(status_code=404, detail={"error": result.get("error")})
+ raise HTTPException(
+ status_code=404, detail={"error": result.get("error")}
+ )
else:
- raise HTTPException(status_code=500, detail={"error": result.get("error")})
+ raise HTTPException(
+ status_code=500, detail={"error": result.get("error")}
+ )
except HTTPException:
raise
@@ -337,15 +373,21 @@ async def delete_knowledge_item(source_id: str):
}
if result.get("success"):
- safe_logfire_info(f"Knowledge item deleted successfully | source_id={source_id}")
+ safe_logfire_info(
+ f"Knowledge item deleted successfully | source_id={source_id}"
+ )
- return {"success": True, "message": f"Successfully deleted knowledge item {source_id}"}
+ return {
+ "success": True,
+ "message": f"Successfully deleted knowledge item {source_id}",
+ }
else:
safe_logfire_error(
f"Knowledge item deletion failed | source_id={source_id} | error={result.get('error')}"
)
raise HTTPException(
- status_code=500, detail={"error": result.get("error", "Deletion failed")}
+ status_code=500,
+ detail={"error": result.get("error", "Deletion failed")},
)
except Exception as e:
@@ -362,28 +404,25 @@ async def delete_knowledge_item(source_id: str):
@router.get("/knowledge-items/{source_id}/chunks")
async def get_knowledge_item_chunks(
- source_id: str,
- domain_filter: str | None = None,
- limit: int = 20,
- offset: int = 0
+ source_id: str, domain_filter: str | None = None, limit: int = 20, offset: int = 0
):
"""
Get document chunks for a specific knowledge item with pagination.
-
+
Args:
source_id: The source ID
domain_filter: Optional domain filter for URLs
limit: Maximum number of chunks to return (default 20, max 100)
offset: Number of chunks to skip (for pagination)
-
+
Returns:
Paginated chunks with metadata
"""
try:
# Validate pagination parameters
limit = min(limit, 100) # Cap at 100 to prevent excessive data transfer
- limit = max(limit, 1) # At least 1
- offset = max(offset, 0) # Can't be negative
+ limit = max(limit, 1) # At least 1
+ offset = max(offset, 0) # Can't be negative
safe_logfire_info(
f"Fetching chunks | source_id={source_id} | domain_filter={domain_filter} | "
@@ -468,10 +507,17 @@ async def get_knowledge_item_chunks(
for line in lines:
line = line.strip()
# Skip code blocks, empty lines, and very short lines
- if (line and not line.startswith("```") and not line.startswith("Source:")
- and len(line) > 15 and len(line) < 80
- and not line.startswith("from ") and not line.startswith("import ")
- and "=" not in line and "{" not in line):
+ if (
+ line
+ and not line.startswith("```")
+ and not line.startswith("Source:")
+ and len(line) > 15
+ and len(line) < 80
+ and not line.startswith("from ")
+ and not line.startswith("import ")
+ and "=" not in line
+ and "{" not in line
+ ):
title = line
break
@@ -481,17 +527,31 @@ async def get_knowledge_item_chunks(
if url:
# Extract meaningful part from URL
if url.endswith(".txt"):
- title = url.split("/")[-1].replace(".txt", "").replace("-", " ").title()
+ title = (
+ url.split("/")[-1]
+ .replace(".txt", "")
+ .replace("-", " ")
+ .title()
+ )
else:
# Get domain and path info
parsed = urlparse(url)
if parsed.path and parsed.path != "/":
- title = parsed.path.strip("/").replace("-", " ").replace("_", " ").title()
+ title = (
+ parsed.path.strip("/")
+ .replace("-", " ")
+ .replace("_", " ")
+ .title()
+ )
else:
title = parsed.netloc.replace("www.", "").title()
chunk["title"] = title or ""
- chunk["section"] = metadata.get("headers", "").replace(";", " > ") if metadata.get("headers") else None
+ chunk["section"] = (
+ metadata.get("headers", "").replace(";", " > ")
+ if metadata.get("headers")
+ else None
+ )
chunk["source_type"] = metadata.get("source_type")
chunk["knowledge_type"] = metadata.get("knowledge_type")
@@ -521,26 +581,24 @@ async def get_knowledge_item_chunks(
@router.get("/knowledge-items/{source_id}/code-examples")
async def get_knowledge_item_code_examples(
- source_id: str,
- limit: int = 20,
- offset: int = 0
+ source_id: str, limit: int = 20, offset: int = 0
):
"""
Get code examples for a specific knowledge item with pagination.
-
+
Args:
source_id: The source ID
limit: Maximum number of examples to return (default 20, max 100)
offset: Number of examples to skip (for pagination)
-
+
Returns:
Paginated code examples with metadata
"""
try:
# Validate pagination parameters
limit = min(limit, 100) # Cap at 100 to prevent excessive data transfer
- limit = max(limit, 1) # At least 1
- offset = max(offset, 0) # Can't be negative
+ limit = max(limit, 1) # At least 1
+ offset = max(offset, 0) # Can't be negative
safe_logfire_info(
f"Fetching code examples | source_id={source_id} | limit={limit} | offset={offset}"
@@ -582,9 +640,13 @@ async def get_knowledge_item_code_examples(
metadata = example.get("metadata", {}) or {}
# Extract fields to match frontend TypeScript types
example["title"] = metadata.get("title") # AI-generated title
- example["example_name"] = metadata.get("example_name") # Same as title for compatibility
+ example["example_name"] = metadata.get(
+ "example_name"
+ ) # Same as title for compatibility
example["language"] = metadata.get("language") # Programming language
- example["file_path"] = metadata.get("file_path") # Original file path if available
+ example["file_path"] = metadata.get(
+ "file_path"
+ ) # Original file path if available
# Note: content field is already at top level from database
# Note: summary field is already at top level from database
@@ -612,14 +674,14 @@ async def get_knowledge_item_code_examples(
@router.post("/knowledge-items/{source_id}/refresh")
async def refresh_knowledge_item(source_id: str):
"""Refresh a knowledge item by re-crawling its URL with the same metadata."""
-
+
# Validate API key before starting expensive refresh operation
logger.info("🔍 About to validate API key for refresh...")
provider_config = await credential_service.get_active_provider("embedding")
provider = provider_config.get("provider", "openai")
await _validate_provider_api_key(provider)
logger.info("✅ API key validation completed successfully for refresh")
-
+
try:
safe_logfire_info(f"Starting knowledge item refresh | source_id={source_id}")
@@ -629,7 +691,8 @@ async def refresh_knowledge_item(source_id: str):
if not existing_item:
raise HTTPException(
- status_code=404, detail={"error": f"Knowledge item {source_id} not found"}
+ status_code=404,
+ detail={"error": f"Knowledge item {source_id} not found"},
)
# Extract metadata
@@ -640,7 +703,8 @@ async def refresh_knowledge_item(source_id: str):
url = metadata.get("original_url") or existing_item.get("url")
if not url:
raise HTTPException(
- status_code=400, detail={"error": "Knowledge item does not have a URL to refresh"}
+ status_code=400,
+ detail={"error": "Knowledge item does not have a URL to refresh"},
)
knowledge_type = metadata.get("knowledge_type", "technical")
tags = metadata.get("tags", [])
@@ -651,26 +715,32 @@ async def refresh_knowledge_item(source_id: str):
# Initialize progress tracker IMMEDIATELY so it's available for polling
from ..utils.progress.progress_tracker import ProgressTracker
+
tracker = ProgressTracker(progress_id, operation_type="crawl")
- await tracker.start({
- "url": url,
- "status": "initializing",
- "progress": 0,
- "log": f"Starting refresh for {url}",
- "source_id": source_id,
- "operation": "refresh",
- "crawl_type": "refresh"
- })
+ await tracker.start(
+ {
+ "url": url,
+ "status": "initializing",
+ "progress": 0,
+ "log": f"Starting refresh for {url}",
+ "source_id": source_id,
+ "operation": "refresh",
+ "crawl_type": "refresh",
+ }
+ )
# Get crawler from CrawlerManager - same pattern as _perform_crawl_with_progress
try:
crawler = await get_crawler()
if crawler is None:
- raise Exception("Crawler not available - initialization may have failed")
+ raise Exception(
+ "Crawler not available - initialization may have failed"
+ )
except Exception as e:
safe_logfire_error(f"Failed to get crawler | error={str(e)}")
raise HTTPException(
- status_code=500, detail={"error": f"Failed to initialize crawler: {str(e)}"}
+ status_code=500,
+ detail={"error": f"Failed to initialize crawler: {str(e)}"},
)
# Use the same crawl orchestration as regular crawl
@@ -736,7 +806,9 @@ async def crawl_knowledge_item(request: KnowledgeItemRequest):
# Basic URL validation
if not request.url.startswith(("http://", "https://")):
- raise HTTPException(status_code=422, detail="URL must start with http:// or https://")
+ raise HTTPException(
+ status_code=422, detail="URL must start with http:// or https://"
+ )
# Validate API key before starting expensive operation
logger.info("🔍 About to validate API key...")
@@ -754,6 +826,7 @@ async def crawl_knowledge_item(request: KnowledgeItemRequest):
# Initialize progress tracker IMMEDIATELY so it's available for polling
from ..utils.progress.progress_tracker import ProgressTracker
+
tracker = ProgressTracker(progress_id, operation_type="crawl")
# Detect crawl type from URL
@@ -764,14 +837,16 @@ async def crawl_knowledge_item(request: KnowledgeItemRequest):
elif url_str.endswith(".txt"):
crawl_type = "llms-txt" if "llms" in url_str.lower() else "text_file"
- await tracker.start({
- "url": url_str,
- "current_url": url_str,
- "crawl_type": crawl_type,
- # Don't override status - let tracker.start() set it to "starting"
- "progress": 0,
- "log": f"Starting crawl for {request.url}"
- })
+ await tracker.start(
+ {
+ "url": url_str,
+ "current_url": url_str,
+ "crawl_type": crawl_type,
+ # Don't override status - let tracker.start() set it to "starting"
+ "progress": 0,
+ "log": f"Starting crawl for {request.url}",
+ }
+ )
# Start background task - no need to track this wrapper task
# The actual crawl task will be stored inside _perform_crawl_with_progress
@@ -795,12 +870,14 @@ class Config:
success=True,
progress_id=progress_id,
message="Crawling started",
- estimated_duration="3-5 minutes"
+ estimated_duration="3-5 minutes",
)
return response.model_dump(by_alias=True)
except Exception as e:
- safe_logfire_error(f"Failed to start crawl | error={str(e)} | url={str(request.url)}")
+ safe_logfire_error(
+ f"Failed to start crawl | error={str(e)} | url={str(request.url)}"
+ )
raise HTTPException(status_code=500, detail=str(e))
@@ -822,7 +899,9 @@ async def _perform_crawl_with_progress(
try:
crawler = await get_crawler()
if crawler is None:
- raise Exception("Crawler not available - initialization may have failed")
+ raise Exception(
+ "Crawler not available - initialization may have failed"
+ )
except Exception as e:
safe_logfire_error(f"Failed to get crawler | error={str(e)}")
await tracker.error(f"Failed to initialize crawler: {str(e)}")
@@ -853,7 +932,9 @@ async def _perform_crawl_with_progress(
f"Stored actual crawl task in active_crawl_tasks | progress_id={progress_id} | task_name={crawl_task.get_name()}"
)
else:
- safe_logfire_error(f"No task returned from orchestrate_crawl | progress_id={progress_id}")
+ safe_logfire_error(
+ f"No task returned from orchestrate_crawl | progress_id={progress_id}"
+ )
# The orchestration service now runs in background and handles all progress updates
safe_logfire_info(
@@ -899,14 +980,14 @@ async def upload_document(
extract_code_examples: bool = Form(True),
):
"""Upload and process a document with progress tracking."""
-
- # Validate API key before starting expensive upload operation
+
+ # Validate API key before starting expensive upload operation
logger.info("🔍 About to validate API key for upload...")
provider_config = await credential_service.get_active_provider("embedding")
provider = provider_config.get("provider", "openai")
await _validate_provider_api_key(provider)
logger.info("✅ API key validation completed successfully for upload")
-
+
try:
# DETAILED LOGGING: Track knowledge_type parameter flow
safe_logfire_info(
@@ -923,11 +1004,19 @@ async def upload_document(
tag_list = []
# Validate tags is a list of strings
if not isinstance(tag_list, list):
- raise HTTPException(status_code=422, detail={"error": "tags must be a JSON array of strings"})
+ raise HTTPException(
+ status_code=422,
+ detail={"error": "tags must be a JSON array of strings"},
+ )
if not all(isinstance(tag, str) for tag in tag_list):
- raise HTTPException(status_code=422, detail={"error": "tags must be a JSON array of strings"})
+ raise HTTPException(
+ status_code=422,
+ detail={"error": "tags must be a JSON array of strings"},
+ )
except json.JSONDecodeError as ex:
- raise HTTPException(status_code=422, detail={"error": f"Invalid tags JSON: {str(ex)}"})
+ raise HTTPException(
+ status_code=422, detail={"error": f"Invalid tags JSON: {str(ex)}"}
+ )
# Read file content immediately to avoid closed file issues
file_content = await file.read()
@@ -939,18 +1028,27 @@ async def upload_document(
# Initialize progress tracker IMMEDIATELY so it's available for polling
from ..utils.progress.progress_tracker import ProgressTracker
+
tracker = ProgressTracker(progress_id, operation_type="upload")
- await tracker.start({
- "filename": file.filename,
- "status": "initializing",
- "progress": 0,
- "log": f"Starting upload for {file.filename}"
- })
+ await tracker.start(
+ {
+ "filename": file.filename,
+ "status": "initializing",
+ "progress": 0,
+ "log": f"Starting upload for {file.filename}",
+ }
+ )
# Start background task for processing with file content and metadata
# Upload tasks can be tracked directly since they don't spawn sub-tasks
upload_task = asyncio.create_task(
_perform_upload_with_progress(
- progress_id, file_content, file_metadata, tag_list, knowledge_type, extract_code_examples, tracker
+ progress_id,
+ file_content,
+ file_metadata,
+ tag_list,
+ knowledge_type,
+ extract_code_examples,
+ tracker,
)
)
# Track the task for cancellation support
@@ -982,6 +1080,7 @@ async def _perform_upload_with_progress(
tracker: "ProgressTracker",
):
"""Perform document upload with progress tracking using service layer."""
+
# Create cancellation check function for document uploads
def check_upload_cancellation():
"""Check if upload task has been cancelled."""
@@ -991,6 +1090,7 @@ def check_upload_cancellation():
# Import ProgressMapper to prevent progress from going backwards
from ..services.crawling.progress_mapper import ProgressMapper
+
progress_mapper = ProgressMapper()
try:
@@ -1002,17 +1102,18 @@ def check_upload_cancellation():
f"Starting document upload with progress tracking | progress_id={progress_id} | filename={filename} | content_type={content_type}"
)
-
# Extract text from document with progress - use mapper for consistent progress
mapped_progress = progress_mapper.map_progress("processing", 50)
await tracker.update(
status="processing",
progress=mapped_progress,
- log=f"Extracting text from {filename}"
+ log=f"Extracting text from {filename}",
)
try:
- extracted_text = extract_text_from_document(file_content, filename, content_type)
+ extracted_text = extract_text_from_document(
+ file_content, filename, content_type
+ )
safe_logfire_info(
f"Document text extracted | filename={filename} | extracted_length={len(extracted_text)} | content_type={content_type}"
)
@@ -1023,7 +1124,9 @@ def check_upload_cancellation():
return
except Exception as ex:
# Other exceptions are system errors - log with full traceback
- logger.error(f"Failed to extract text from document: {filename}", exc_info=True)
+ logger.error(
+ f"Failed to extract text from document: {filename}", exc_info=True
+ )
await tracker.error(f"Failed to extract text from document: {str(ex)}")
return
@@ -1047,10 +1150,9 @@ async def document_progress_callback(
progress=mapped_percentage,
log=message,
currentUrl=f"file://{filename}",
- **(batch_info or {})
+ **(batch_info or {}),
)
-
# Call the service's upload_document method
success, result = await doc_storage_service.upload_document(
file_content=extracted_text,
@@ -1065,12 +1167,14 @@ async def document_progress_callback(
if success:
# Complete the upload with 100% progress
- await tracker.complete({
- "log": "Document uploaded successfully!",
- "chunks_stored": result.get("chunks_stored"),
- "code_examples_stored": result.get("code_examples_stored", 0),
- "sourceId": result.get("source_id"),
- })
+ await tracker.complete(
+ {
+ "log": "Document uploaded successfully!",
+ "chunks_stored": result.get("chunks_stored"),
+ "code_examples_stored": result.get("code_examples_stored", 0),
+ "sourceId": result.get("source_id"),
+ }
+ )
safe_logfire_info(
f"Document uploaded successfully | progress_id={progress_id} | source_id={result.get('source_id')} | chunks_stored={result.get('chunks_stored')} | code_examples_stored={result.get('code_examples_stored', 0)}"
)
@@ -1089,7 +1193,9 @@ async def document_progress_callback(
# Clean up task from registry when done (success or failure)
if progress_id in active_crawl_tasks:
del active_crawl_tasks[progress_id]
- safe_logfire_info(f"Cleaned up upload task from registry | progress_id={progress_id}")
+ safe_logfire_info(
+ f"Cleaned up upload task from registry | progress_id={progress_id}"
+ )
@router.post("/knowledge-items/search")
@@ -1123,7 +1229,7 @@ async def perform_rag_query(request: RagQueryRequest):
query=request.query,
source=request.source,
match_count=request.match_count,
- return_mode=request.return_mode
+ return_mode=request.return_mode,
)
if success:
@@ -1132,7 +1238,8 @@ async def perform_rag_query(request: RagQueryRequest):
return result
else:
raise HTTPException(
- status_code=500, detail={"error": result.get("error", "RAG query failed")}
+ status_code=500,
+ detail={"error": result.get("error", "RAG query failed")},
)
except HTTPException:
raise
@@ -1140,7 +1247,9 @@ async def perform_rag_query(request: RagQueryRequest):
safe_logfire_error(
f"RAG query failed | error={str(e)} | query={request.query[:50]} | source={request.source}"
)
- raise HTTPException(status_code=500, detail={"error": f"RAG query failed: {str(e)}"})
+ raise HTTPException(
+ status_code=500, detail={"error": f"RAG query failed: {str(e)}"}
+ )
@router.post("/rag/code-examples")
@@ -1230,12 +1339,15 @@ async def delete_source(source_id: str):
f"Source deletion failed | source_id={source_id} | error={result_data.get('error')}"
)
raise HTTPException(
- status_code=500, detail={"error": result_data.get("error", "Deletion failed")}
+ status_code=500,
+ detail={"error": result_data.get("error", "Deletion failed")},
)
except HTTPException:
raise
except Exception as e:
- safe_logfire_error(f"Failed to delete source | error={str(e)} | source_id={source_id}")
+ safe_logfire_error(
+ f"Failed to delete source | error={str(e)} | source_id={source_id}"
+ )
raise HTTPException(status_code=500, detail={"error": str(e)})
@@ -1267,7 +1379,7 @@ async def knowledge_health():
"ready": False,
"migration_required": True,
"message": schema_status["message"],
- "migration_instructions": "Open Supabase Dashboard → SQL Editor → Run: migration/add_source_url_display_name.sql"
+ "migration_instructions": "Open Supabase Dashboard → SQL Editor → Run: migration/add_source_url_display_name.sql",
}
# Removed health check logging to reduce console noise
@@ -1280,13 +1392,14 @@ async def knowledge_health():
return result
-
@router.post("/knowledge-items/stop/{progress_id}")
async def stop_crawl_task(progress_id: str):
"""Stop a running crawl task."""
try:
- from ..services.crawling import get_active_orchestration, unregister_orchestration
-
+ from ..services.crawling import (
+ get_active_orchestration,
+ unregister_orchestration,
+ )
safe_logfire_info(f"Stop crawl requested | progress_id={progress_id}")
@@ -1316,24 +1429,32 @@ async def stop_crawl_task(progress_id: str):
if found:
try:
from ..utils.progress.progress_tracker import ProgressTracker
+
# Get current progress from existing tracker, default to 0 if not found
current_state = ProgressTracker.get_progress(progress_id)
- current_progress = current_state.get("progress", 0) if current_state else 0
+ current_progress = (
+ current_state.get("progress", 0) if current_state else 0
+ )
tracker = ProgressTracker(progress_id, operation_type="crawl")
await tracker.update(
status="cancelled",
progress=current_progress,
- log="Crawl cancelled by user"
+ log="Crawl cancelled by user",
)
except Exception:
# Best effort - don't fail the cancellation if tracker update fails
pass
if not found:
- raise HTTPException(status_code=404, detail={"error": "No active task for given progress_id"})
+ raise HTTPException(
+ status_code=404,
+ detail={"error": "No active task for given progress_id"},
+ )
- safe_logfire_info(f"Successfully stopped crawl task | progress_id={progress_id}")
+ safe_logfire_info(
+ f"Successfully stopped crawl task | progress_id={progress_id}"
+ )
return {
"success": True,
"message": "Crawl task stopped successfully",
diff --git a/python/src/server/api_routes/settings_api.py b/python/src/server/api_routes/settings_api.py
index 30de2b9813..ad89124999 100644
--- a/python/src/server/api_routes/settings_api.py
+++ b/python/src/server/api_routes/settings_api.py
@@ -70,7 +70,9 @@ async def list_credentials(category: str | None = None):
for cred in credentials
]
except Exception as e:
- logfire.error(f"Error listing credentials | category={category} | error={str(e)}")
+ logfire.error(
+ f"Error listing credentials | category={category} | error={str(e)}"
+ )
raise HTTPException(status_code=500, detail={"error": str(e)})
@@ -120,7 +122,9 @@ async def create_credential(request: CredentialRequest):
}
else:
logfire.error(f"Failed to save credential | key={request.key}")
- raise HTTPException(status_code=500, detail={"error": "Failed to save credential"})
+ raise HTTPException(
+ status_code=500, detail={"error": "Failed to save credential"}
+ )
except Exception as e:
logfire.error(f"Error creating credential | key={request.key} | error={str(e)}")
@@ -149,7 +153,9 @@ async def get_credential(key: str):
if value is None:
# Check if this is an optional setting with a default value
if key in OPTIONAL_SETTINGS_WITH_DEFAULTS:
- logfire.info(f"Returning default value for optional setting | key={key}")
+ logfire.info(
+ f"Returning default value for optional setting | key={key}"
+ )
return {
"key": key,
"value": OPTIONAL_SETTINGS_WITH_DEFAULTS[key],
@@ -159,7 +165,9 @@ async def get_credential(key: str):
}
logfire.warning(f"Credential not found | key={key}")
- raise HTTPException(status_code=404, detail={"error": f"Credential {key} not found"})
+ raise HTTPException(
+ status_code=404, detail={"error": f"Credential {key} not found"}
+ )
logfire.info(f"Credential retrieved successfully | key={key}")
@@ -218,7 +226,9 @@ async def update_credential(key: str, request: dict[str, Any]):
category = existing.category
if description is None:
description = existing.description
- logfire.info(f"Updating existing credential | key={key} | category={category}")
+ logfire.info(
+ f"Updating existing credential | key={key} | category={category}"
+ )
success = await credential_service.set_credential(
key=key,
@@ -233,10 +243,95 @@ async def update_credential(key: str, request: dict[str, Any]):
f"Credential updated successfully | key={key} | is_encrypted={is_encrypted}"
)
- return {"success": True, "message": f"Credential {key} updated successfully"}
+ # Auto-switch provider when cloud provider credentials are saved
+ if category == "cloud_providers" and value and value.strip():
+ # Determine which provider to switch to based on the credential key
+ provider_switch_map = {
+ "AZURE_OPENAI_API_KEY": "azure-openai",
+ "AZURE_OPENAI_ENDPOINT": "azure-openai",
+ "AWS_ACCESS_KEY_ID": "aws-bedrock",
+ "AWS_SECRET_ACCESS_KEY": "aws-bedrock",
+ }
+
+ target_provider = provider_switch_map.get(key)
+
+ if target_provider:
+ # Check if we have all required credentials for this provider
+ if target_provider == "azure-openai":
+ # Check if we have all Azure OpenAI credentials
+ azure_api_key = await credential_service.get_credential(
+ "AZURE_OPENAI_API_KEY"
+ )
+ azure_endpoint = await credential_service.get_credential(
+ "AZURE_OPENAI_ENDPOINT"
+ )
+ azure_version = await credential_service.get_credential(
+ "AZURE_OPENAI_API_VERSION"
+ )
+
+ if azure_api_key and azure_endpoint and azure_version:
+ # Switch to Azure OpenAI
+ await credential_service.set_credential(
+ key="LLM_PROVIDER",
+ value="azure-openai",
+ is_encrypted=False,
+ category="rag_strategy",
+ description="LLM provider to use: openai, ollama, or google",
+ )
+ # Also update EMBEDDING_PROVIDER to match
+ await credential_service.set_credential(
+ key="EMBEDDING_PROVIDER",
+ value="azure-openai",
+ is_encrypted=False,
+ category="rag_strategy",
+ description="Embedding provider to use (if different from LLM_PROVIDER)",
+ )
+ logfire.info(
+ "Auto-switched LLM_PROVIDER and EMBEDDING_PROVIDER to azure-openai after saving Azure credentials"
+ )
+
+ elif target_provider == "aws-bedrock":
+ # Check if we have all AWS Bedrock credentials
+ aws_access_key = await credential_service.get_credential(
+ "AWS_ACCESS_KEY_ID"
+ )
+ aws_secret_key = await credential_service.get_credential(
+ "AWS_SECRET_ACCESS_KEY"
+ )
+ aws_region = await credential_service.get_credential(
+ "AWS_REGION"
+ )
+
+ if aws_access_key and aws_secret_key and aws_region:
+ # Switch to AWS Bedrock
+ await credential_service.set_credential(
+ key="LLM_PROVIDER",
+ value="aws-bedrock",
+ is_encrypted=False,
+ category="rag_strategy",
+ description="LLM provider to use: openai, ollama, or google",
+ )
+ # Also update EMBEDDING_PROVIDER to match
+ await credential_service.set_credential(
+ key="EMBEDDING_PROVIDER",
+ value="aws-bedrock",
+ is_encrypted=False,
+ category="rag_strategy",
+ description="Embedding provider to use (if different from LLM_PROVIDER)",
+ )
+ logfire.info(
+ "Auto-switched LLM_PROVIDER and EMBEDDING_PROVIDER to aws-bedrock after saving AWS credentials"
+ )
+
+ return {
+ "success": True,
+ "message": f"Credential {key} updated successfully",
+ }
else:
logfire.error(f"Failed to update credential | key={key}")
- raise HTTPException(status_code=500, detail={"error": "Failed to update credential"})
+ raise HTTPException(
+ status_code=500, detail={"error": "Failed to update credential"}
+ )
except Exception as e:
logfire.error(f"Error updating credential | key={key} | error={str(e)}")
@@ -253,10 +348,15 @@ async def delete_credential(key: str):
if success:
logfire.info(f"Credential deleted successfully | key={key}")
- return {"success": True, "message": f"Credential {key} deleted successfully"}
+ return {
+ "success": True,
+ "message": f"Credential {key} deleted successfully",
+ }
else:
logfire.error(f"Failed to delete credential | key={key}")
- raise HTTPException(status_code=500, detail={"error": "Failed to delete credential"})
+ raise HTTPException(
+ status_code=500, detail={"error": "Failed to delete credential"}
+ )
except Exception as e:
logfire.error(f"Error deleting credential | key={key} | error={str(e)}")
@@ -290,19 +390,27 @@ async def database_metrics():
# Get projects count
projects_response = (
- supabase_client.table("archon_projects").select("id", count="exact").execute()
+ supabase_client.table("archon_projects")
+ .select("id", count="exact")
+ .execute()
)
tables_info["projects"] = (
projects_response.count if projects_response.count is not None else 0
)
# Get tasks count
- tasks_response = supabase_client.table("archon_tasks").select("id", count="exact").execute()
- tables_info["tasks"] = tasks_response.count if tasks_response.count is not None else 0
+ tasks_response = (
+ supabase_client.table("archon_tasks").select("id", count="exact").execute()
+ )
+ tables_info["tasks"] = (
+ tasks_response.count if tasks_response.count is not None else 0
+ )
# Get crawled pages count
pages_response = (
- supabase_client.table("archon_crawled_pages").select("id", count="exact").execute()
+ supabase_client.table("archon_crawled_pages")
+ .select("id", count="exact")
+ .execute()
)
tables_info["crawled_pages"] = (
pages_response.count if pages_response.count is not None else 0
@@ -310,7 +418,9 @@ async def database_metrics():
# Get settings count
settings_response = (
- supabase_client.table("archon_settings").select("id", count="exact").execute()
+ supabase_client.table("archon_settings")
+ .select("id", count="exact")
+ .execute()
)
tables_info["settings"] = (
settings_response.count if settings_response.count is not None else 0
@@ -346,46 +456,52 @@ async def settings_health():
@router.post("/credentials/status-check")
async def check_credential_status(request: dict[str, list[str]]):
"""Check status of API credentials by actually decrypting and validating them.
-
+
This endpoint is specifically for frontend status indicators and returns
decrypted credential values for connectivity testing.
"""
try:
credential_keys = request.get("keys", [])
logfire.info(f"Checking status for credentials: {credential_keys}")
-
+
result = {}
-
+
for key in credential_keys:
try:
# Get decrypted value for status checking
- decrypted_value = await credential_service.get_credential(key, decrypt=True)
-
- if decrypted_value and isinstance(decrypted_value, str) and decrypted_value.strip():
+ decrypted_value = await credential_service.get_credential(
+ key, decrypt=True
+ )
+
+ if (
+ decrypted_value
+ and isinstance(decrypted_value, str)
+ and decrypted_value.strip()
+ ):
result[key] = {
"key": key,
"value": decrypted_value,
- "has_value": True
+ "has_value": True,
}
else:
- result[key] = {
- "key": key,
- "value": None,
- "has_value": False
- }
-
+ result[key] = {"key": key, "value": None, "has_value": False}
+
except Exception as e:
- logfire.warning(f"Failed to get credential for status check: {key} | error={str(e)}")
+ logfire.warning(
+ f"Failed to get credential for status check: {key} | error={str(e)}"
+ )
result[key] = {
"key": key,
"value": None,
"has_value": False,
- "error": str(e)
+ "error": str(e),
}
-
- logfire.info(f"Credential status check completed | checked={len(credential_keys)} | found={len([k for k, v in result.items() if v.get('has_value')])}")
+
+ logfire.info(
+ f"Credential status check completed | checked={len(credential_keys)} | found={len([k for k, v in result.items() if v.get('has_value')])}"
+ )
return result
-
+
except Exception as e:
logfire.error(f"Error in credential status check | error={str(e)}")
raise HTTPException(status_code=500, detail={"error": str(e)})
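
The auto-switch rule added to update_credential above is easiest to see in isolation. A standalone sketch follows: a credential save only flips the provider when the saved key is in the switch map and every required credential for that provider is already present. The in-memory dict stands in for credential_service, and the sketch omits the category == "cloud_providers" gate and the EMBEDDING_PROVIDER side effect. Note that, mirroring the diff, keys outside the map (such as AZURE_OPENAI_API_VERSION) never trigger a switch on their own.

PROVIDER_SWITCH_MAP = {
    "AZURE_OPENAI_API_KEY": "azure-openai",
    "AZURE_OPENAI_ENDPOINT": "azure-openai",
    "AWS_ACCESS_KEY_ID": "aws-bedrock",
    "AWS_SECRET_ACCESS_KEY": "aws-bedrock",
}

REQUIRED_KEYS = {
    "azure-openai": [
        "AZURE_OPENAI_API_KEY",
        "AZURE_OPENAI_ENDPOINT",
        "AZURE_OPENAI_API_VERSION",
    ],
    "aws-bedrock": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION"],
}


def provider_after_save(store: dict[str, str], key: str, value: str) -> str | None:
    """Return the provider to auto-switch to after saving key=value, else None."""
    store[key] = value
    target = PROVIDER_SWITCH_MAP.get(key)
    if target and all(store.get(k) for k in REQUIRED_KEYS[target]):
        return target
    return None


store = {
    "AZURE_OPENAI_ENDPOINT": "https://example.openai.azure.com",
    "AZURE_OPENAI_API_VERSION": "2024-02-15-preview",
}
# Saving the last missing Azure credential triggers the switch.
assert provider_after_save(store, "AZURE_OPENAI_API_KEY", "secret") == "azure-openai"
# AWS credentials are still incomplete, so no switch happens.
assert provider_after_save(store, "AWS_ACCESS_KEY_ID", "AKIA" + "A" * 16) is None
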
diff --git a/python/src/server/config/config.py b/python/src/server/config/config.py
index d8104bb0ea..709fff5818 100644
--- a/python/src/server/config/config.py
+++ b/python/src/server/config/config.py
@@ -26,6 +26,16 @@ class EnvironmentConfig:
openai_api_key: str | None = None
host: str = "0.0.0.0"
transport: str = "sse"
+ # Azure OpenAI configuration
+ azure_openai_endpoint: str | None = None
+ azure_openai_api_key: str | None = None
+ azure_openai_api_version: str | None = None
+ azure_openai_deployment: str | None = None
+ # AWS Bedrock configuration
+ aws_access_key_id: str | None = None
+ aws_secret_access_key: str | None = None
+ aws_region: str | None = None
+ aws_bedrock_model_id: str | None = None
@dataclass
@@ -66,6 +76,93 @@ def validate_openai_api_key(api_key: str) -> bool:
return True
+def validate_azure_openai_endpoint(endpoint: str) -> bool:
+ """Validate Azure OpenAI endpoint format."""
+ if not endpoint:
+ raise ConfigurationError("Azure OpenAI endpoint cannot be empty")
+
+ parsed = urlparse(endpoint)
+ if parsed.scheme not in ("http", "https"):
+ raise ConfigurationError("Azure OpenAI endpoint must use HTTP or HTTPS")
+
+ if not parsed.netloc:
+ raise ConfigurationError("Invalid Azure OpenAI endpoint format")
+
+ # Azure OpenAI endpoints typically contain '.openai.azure.com'
+ if ".openai.azure.com" not in parsed.netloc:
+ # Non-standard endpoint (e.g. a sovereign cloud); allow it without raising
+ pass
+
+ return True
+
+
+def validate_azure_openai_api_version(api_version: str) -> bool:
+ """Validate Azure OpenAI API version format."""
+ if not api_version:
+ raise ConfigurationError("Azure OpenAI API version cannot be empty")
+
+ # Azure API versions follow YYYY-MM-DD format or YYYY-MM-DD-preview
+ import re
+
+ pattern = r"^\d{4}-\d{2}-\d{2}(-preview)?$"
+ if not re.match(pattern, api_version):
+ raise ConfigurationError(
+ f"Azure OpenAI API version must follow YYYY-MM-DD or YYYY-MM-DD-preview format, got: {api_version}"
+ )
+
+ return True
+
+
+def validate_aws_access_key_id(access_key_id: str) -> bool:
+ """Validate AWS Access Key ID format."""
+ if not access_key_id:
+ raise ConfigurationError("AWS Access Key ID cannot be empty")
+
+ # AWS Access Key IDs are 20 characters long and start with AKIA or ASIA
+ if not (access_key_id.startswith("AKIA") or access_key_id.startswith("ASIA")):
+ raise ConfigurationError(
+ "AWS Access Key ID must start with 'AKIA' (long-term) or 'ASIA' (temporary)"
+ )
+
+ if len(access_key_id) != 20:
+ raise ConfigurationError(
+ f"AWS Access Key ID must be exactly 20 characters, got {len(access_key_id)}"
+ )
+
+ return True
+
+
+def validate_aws_region(region: str) -> bool:
+ """Validate AWS region format."""
+ if not region:
+ raise ConfigurationError("AWS region cannot be empty")
+
+ # Common AWS regions for Bedrock
+ valid_regions = {
+ "us-east-1",
+ "us-east-2",
+ "us-west-1",
+ "us-west-2",
+ "ap-south-1",
+ "ap-northeast-1",
+ "ap-northeast-2",
+ "ap-southeast-1",
+ "ap-southeast-2",
+ "ca-central-1",
+ "eu-central-1",
+ "eu-west-1",
+ "eu-west-2",
+ "eu-west-3",
+ "sa-east-1",
+ }
+
+ if region not in valid_regions:
+ # Region not in the known Bedrock list; allow it without raising
+ pass
+
+ return True
+
+
def validate_supabase_key(supabase_key: str) -> tuple[bool, str]:
"""Validate Supabase key type and return validation result.
@@ -137,14 +234,18 @@ def validate_supabase_url(url: str) -> bool:
# Class C: 192.168.0.0/16
# Also includes link-local (169.254.0.0/16) and loopback
# Exclude unspecified address (0.0.0.0) for security
- if (ip.is_private or ip.is_loopback or ip.is_link_local) and not ip.is_unspecified:
+ if (
+ ip.is_private or ip.is_loopback or ip.is_link_local
+ ) and not ip.is_unspecified:
return True
except ValueError:
# hostname is not a valid IP address, could be a domain name
pass
# If not a local host or private IP, require HTTPS
- raise ConfigurationError(f"Supabase URL must use HTTPS for non-local environments (hostname: {hostname})")
+ raise ConfigurationError(
+ f"Supabase URL must use HTTPS for non-local environments (hostname: {hostname})"
+ )
if not parsed.netloc:
raise ConfigurationError("Invalid Supabase URL format")
@@ -157,6 +258,12 @@ def load_environment_config() -> EnvironmentConfig:
# OpenAI API key is optional at startup - can be set via API
openai_api_key = os.getenv("OPENAI_API_KEY")
+ # Azure OpenAI configuration (optional)
+ azure_openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
+ azure_openai_api_key = os.getenv("AZURE_OPENAI_API_KEY")
+ azure_openai_api_version = os.getenv("AZURE_OPENAI_API_VERSION")
+ azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
+
# Required environment variables for database access
supabase_url = os.getenv("SUPABASE_URL")
if not supabase_url:
@@ -164,11 +271,20 @@ def load_environment_config() -> EnvironmentConfig:
supabase_service_key = os.getenv("SUPABASE_SERVICE_KEY")
if not supabase_service_key:
- raise ConfigurationError("SUPABASE_SERVICE_KEY environment variable is required")
+ raise ConfigurationError(
+ "SUPABASE_SERVICE_KEY environment variable is required"
+ )
# Validate required fields
if openai_api_key:
validate_openai_api_key(openai_api_key)
+
+ # Validate Azure OpenAI configuration if provided
+ if azure_openai_endpoint:
+ validate_azure_openai_endpoint(azure_openai_endpoint)
+ if azure_openai_api_version:
+ validate_azure_openai_api_version(azure_openai_api_version)
+
validate_supabase_url(supabase_url)
# Validate Supabase key type
@@ -217,7 +333,9 @@ def load_environment_config() -> EnvironmentConfig:
try:
port = int(port_str)
except ValueError as e:
- raise ConfigurationError(f"PORT must be a valid integer, got: {port_str}") from e
+ raise ConfigurationError(
+ f"PORT must be a valid integer, got: {port_str}"
+ ) from e
return EnvironmentConfig(
openai_api_key=openai_api_key,
@@ -226,6 +344,10 @@ def load_environment_config() -> EnvironmentConfig:
host=host,
port=port,
transport=transport,
+ azure_openai_endpoint=azure_openai_endpoint,
+ azure_openai_api_key=azure_openai_api_key,
+ azure_openai_api_version=azure_openai_api_version,
+ azure_openai_deployment=azure_openai_deployment,
)
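
The new validators can be exercised directly. A quick sketch, assuming the import path matches this diff's file layout:

from src.server.config.config import (  # import path is an assumption
    ConfigurationError,
    validate_aws_access_key_id,
    validate_azure_openai_api_version,
)

# Valid inputs return True.
assert validate_azure_openai_api_version("2024-02-15-preview")
assert validate_aws_access_key_id("AKIA" + "Z" * 16)  # 20 chars, AKIA prefix

# Invalid inputs raise ConfigurationError with a descriptive message.
for bad in ("2024-2-15", "AKIATOOSHORT"):
    try:
        if bad.startswith("AKIA"):
            validate_aws_access_key_id(bad)
        else:
            validate_azure_openai_api_version(bad)
    except ConfigurationError as err:
        print(f"rejected {bad!r}: {err}")
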
diff --git a/python/src/server/services/credential_service.py b/python/src/server/services/credential_service.py
index a8aee8491d..faf6e72789 100644
--- a/python/src/server/services/credential_service.py
+++ b/python/src/server/services/credential_service.py
@@ -36,8 +36,6 @@ class CredentialItem:
description: str | None = None
-
-
class CredentialService:
"""Service for managing application credentials and configuration."""
@@ -71,7 +69,9 @@ def _get_supabase_client(self) -> Client:
match = re.match(r"https://([^.]+)\.supabase\.co", url)
if match:
project_id = match.group(1)
- logger.debug(f"Supabase client initialized for project: {project_id}")
+ logger.debug(
+ f"Supabase client initialized for project: {project_id}"
+ )
else:
logger.debug("Supabase client initialized successfully")
@@ -157,7 +157,9 @@ async def load_all_credentials(self) -> dict[str, Any]:
logger.error(f"Error loading credentials: {e}")
raise
- async def get_credential(self, key: str, default: Any = None, decrypt: bool = True) -> Any:
+ async def get_credential(
+ self, key: str, default: Any = None, decrypt: bool = True
+ ) -> Any:
"""Get a credential value by key."""
if not self._cache_initialized:
await self.load_all_credentials()
@@ -244,6 +246,7 @@ async def set_credential(
# Also invalidate provider service cache to ensure immediate effect
try:
from .llm_provider_service import clear_provider_cache
+
clear_provider_cache()
logger.debug("Also cleared LLM provider service cache")
except Exception as e:
@@ -252,14 +255,23 @@ async def set_credential(
# Also invalidate LLM provider service cache for provider config
try:
from . import llm_provider_service
+
# Clear the provider config caches that depend on RAG settings
- cache_keys_to_clear = ["provider_config_llm", "provider_config_embedding", "rag_strategy_settings"]
+ cache_keys_to_clear = [
+ "provider_config_llm",
+ "provider_config_embedding",
+ "rag_strategy_settings",
+ ]
for cache_key in cache_keys_to_clear:
if cache_key in llm_provider_service._settings_cache:
del llm_provider_service._settings_cache[cache_key]
- logger.debug(f"Invalidated LLM provider service cache key: {cache_key}")
+ logger.debug(
+ f"Invalidated LLM provider service cache key: {cache_key}"
+ )
except ImportError:
- logger.warning("Could not import llm_provider_service to invalidate cache")
+ logger.warning(
+ "Could not import llm_provider_service to invalidate cache"
+ )
except Exception as e:
logger.error(f"Error invalidating LLM provider service cache: {e}")
@@ -294,6 +306,7 @@ async def delete_credential(self, key: str) -> bool:
# Also invalidate provider service cache to ensure immediate effect
try:
from .llm_provider_service import clear_provider_cache
+
clear_provider_cache()
logger.debug("Also cleared LLM provider service cache")
except Exception as e:
@@ -302,14 +315,23 @@ async def delete_credential(self, key: str) -> bool:
# Also invalidate LLM provider service cache for provider config
try:
from . import llm_provider_service
+
# Clear the provider config caches that depend on RAG settings
- cache_keys_to_clear = ["provider_config_llm", "provider_config_embedding", "rag_strategy_settings"]
+ cache_keys_to_clear = [
+ "provider_config_llm",
+ "provider_config_embedding",
+ "rag_strategy_settings",
+ ]
for cache_key in cache_keys_to_clear:
if cache_key in llm_provider_service._settings_cache:
del llm_provider_service._settings_cache[cache_key]
- logger.debug(f"Invalidated LLM provider service cache key: {cache_key}")
+ logger.debug(
+ f"Invalidated LLM provider service cache key: {cache_key}"
+ )
except ImportError:
- logger.warning("Could not import llm_provider_service to invalidate cache")
+ logger.warning(
+ "Could not import llm_provider_service to invalidate cache"
+ )
except Exception as e:
logger.error(f"Error invalidating LLM provider service cache: {e}")
@@ -341,7 +363,10 @@ async def get_credentials_by_category(self, category: str) -> dict[str, Any]:
try:
supabase = self._get_supabase_client()
result = (
- supabase.table("archon_settings").select("*").eq("category", category).execute()
+ supabase.table("archon_settings")
+ .select("*")
+ .eq("category", category)
+ .execute()
)
credentials = {}
@@ -443,24 +468,53 @@ async def get_active_provider(self, service_type: str = "llm") -> dict[str, Any]
explicit_embedding_provider = rag_settings.get("EMBEDDING_PROVIDER")
# Validate that embedding provider actually supports embeddings
- embedding_capable_providers = {"openai", "google", "ollama"}
+ # Include cloud providers that support embeddings
+ embedding_capable_providers = {
+ "openai",
+ "azure-openai", # Azure OpenAI supports embeddings
+ "aws-bedrock", # AWS Bedrock supports embeddings (Titan, Cohere)
+ "google",
+ "ollama",
+ }
- if (explicit_embedding_provider and
- explicit_embedding_provider != "" and
- explicit_embedding_provider in embedding_capable_providers):
+ if (
+ explicit_embedding_provider
+ and explicit_embedding_provider != ""
+ and explicit_embedding_provider in embedding_capable_providers
+ ):
# Use the explicitly set embedding provider
provider = explicit_embedding_provider
logger.debug(f"Using explicit embedding provider: '{provider}'")
else:
- # Fall back to OpenAI as default embedding provider for backward compatibility
- if explicit_embedding_provider and explicit_embedding_provider not in embedding_capable_providers:
- logger.warning(f"Invalid embedding provider '{explicit_embedding_provider}' doesn't support embeddings, defaulting to OpenAI")
- provider = "openai"
- logger.debug(f"No explicit embedding provider set, defaulting to OpenAI for backward compatibility")
+ # If no explicit embedding provider, check if LLM_PROVIDER supports embeddings
+ llm_provider = rag_settings.get("LLM_PROVIDER", "openai")
+ if llm_provider in embedding_capable_providers:
+ # Use LLM provider for embeddings if it supports them
+ provider = llm_provider
+ logger.debug(f"Using LLM provider for embeddings: '{provider}'")
+ else:
+ # Fall back to OpenAI as default embedding provider for backward compatibility
+ if (
+ explicit_embedding_provider
+ and explicit_embedding_provider
+ not in embedding_capable_providers
+ ):
+ logger.warning(
+ f"Invalid embedding provider '{explicit_embedding_provider}' doesn't support embeddings, defaulting to OpenAI"
+ )
+ provider = "openai"
+ logger.debug(
+ "No valid embedding provider found, defaulting to OpenAI for backward compatibility"
+ )
else:
provider = rag_settings.get("LLM_PROVIDER", "openai")
# Ensure provider is a valid string, not a boolean or other type
- if not isinstance(provider, str) or provider.lower() in ("true", "false", "none", "null"):
+ if not isinstance(provider, str) or provider.lower() in (
+ "true",
+ "false",
+ "none",
+ "null",
+ ):
provider = "openai"
# Get API key for this provider
@@ -504,10 +558,12 @@ async def _get_provider_api_key(self, provider: str) -> str | None:
"""Get API key for a specific provider."""
key_mapping = {
"openai": "OPENAI_API_KEY",
+ "azure-openai": "AZURE_OPENAI_API_KEY",
"google": "GOOGLE_API_KEY",
"openrouter": "OPENROUTER_API_KEY",
"anthropic": "ANTHROPIC_API_KEY",
"grok": "GROK_API_KEY",
+ "aws-bedrock": "AWS_ACCESS_KEY_ID", # AWS uses access key as primary credential
"ollama": None, # No API key needed
}
@@ -519,7 +575,15 @@ async def _get_provider_api_key(self, provider: str) -> str | None:
def _get_provider_base_url(self, provider: str, rag_settings: dict) -> str | None:
"""Get base URL for provider."""
if provider == "ollama":
- return rag_settings.get("LLM_BASE_URL", "http://host.docker.internal:11434/v1")
+ return rag_settings.get(
+ "LLM_BASE_URL", "http://host.docker.internal:11434/v1"
+ )
+ elif provider == "azure-openai":
+ # Azure OpenAI endpoint will be handled separately with AsyncAzureOpenAI
+ return None
+ elif provider == "aws-bedrock":
+ # AWS Bedrock region will be handled separately with boto3
+ return None
elif provider == "google":
return "https://generativelanguage.googleapis.com/v1beta/openai/"
elif provider == "openrouter":
@@ -530,7 +594,9 @@ def _get_provider_base_url(self, provider: str, rag_settings: dict) -> str | Non
return "https://api.x.ai/v1"
return None # Use default for OpenAI
- async def set_active_provider(self, provider: str, service_type: str = "llm") -> bool:
+ async def set_active_provider(
+ self, provider: str, service_type: str = "llm"
+ ) -> bool:
"""Set the active provider for a service type."""
try:
# For now, we'll update the RAG strategy settings
@@ -541,7 +607,9 @@ async def set_active_provider(self, provider: str, service_type: str = "llm") ->
description=f"Active {service_type} provider",
)
except Exception as e:
- logger.error(f"Error setting active provider {provider} for {service_type}: {e}")
+ logger.error(
+ f"Error setting active provider {provider} for {service_type}: {e}"
+ )
return False
@@ -555,10 +623,16 @@ async def get_credential(key: str, default: Any = None) -> Any:
async def set_credential(
- key: str, value: str, is_encrypted: bool = False, category: str = None, description: str = None
+ key: str,
+ value: str,
+ is_encrypted: bool = False,
+ category: str | None = None,
+ description: str | None = None,
) -> bool:
"""Convenience function to set a credential."""
- return await credential_service.set_credential(key, value, is_encrypted, category, description)
+ return await credential_service.set_credential(
+ key, value, is_encrypted, category, description
+ )
async def initialize_credentials() -> None:
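
Distilled from get_active_provider above, the embedding-provider resolution order is now: explicit EMBEDDING_PROVIDER if embedding-capable, else LLM_PROVIDER if embedding-capable, else "openai". A pure-Python sketch with a plain dict standing in for the stored settings:

EMBEDDING_CAPABLE = {"openai", "azure-openai", "aws-bedrock", "google", "ollama"}


def resolve_embedding_provider(rag_settings: dict[str, str]) -> str:
    explicit = rag_settings.get("EMBEDDING_PROVIDER")
    if explicit and explicit in EMBEDDING_CAPABLE:
        return explicit
    llm_provider = rag_settings.get("LLM_PROVIDER", "openai")
    if llm_provider in EMBEDDING_CAPABLE:
        return llm_provider
    return "openai"


# The LLM provider is reused for embeddings when it supports them.
assert resolve_embedding_provider({"LLM_PROVIDER": "azure-openai"}) == "azure-openai"
# An explicit embedding provider wins over the LLM provider.
assert resolve_embedding_provider(
    {"EMBEDDING_PROVIDER": "google", "LLM_PROVIDER": "aws-bedrock"}
) == "google"
# Providers without embedding support fall back to OpenAI.
assert resolve_embedding_provider({"LLM_PROVIDER": "anthropic"}) == "openai"
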
diff --git a/python/src/server/services/embeddings/embedding_service.py b/python/src/server/services/embeddings/embedding_service.py
index 87ce390b67..f5bf904c0b 100644
--- a/python/src/server/services/embeddings/embedding_service.py
+++ b/python/src/server/services/embeddings/embedding_service.py
@@ -35,7 +35,9 @@ class EmbeddingBatchResult:
failed_items: list[dict[str, Any]] = field(default_factory=list)
success_count: int = 0
failure_count: int = 0
- texts_processed: list[str] = field(default_factory=list) # Successfully processed texts
+ texts_processed: list[str] = field(
+ default_factory=list
+ ) # Successfully processed texts
def add_success(self, embedding: list[float], text: str):
"""Add a successful embedding."""
@@ -83,10 +85,10 @@ async def create_embeddings(
class OpenAICompatibleEmbeddingAdapter(EmbeddingProviderAdapter):
"""Adapter for providers using the OpenAI embeddings API shape."""
-
+
def __init__(self, client: Any):
self._client = client
-
+
async def create_embeddings(
self,
texts: list[str],
@@ -99,7 +101,37 @@ async def create_embeddings(
}
if dimensions is not None:
request_args["dimensions"] = dimensions
-
+
+ response = await self._client.embeddings.create(**request_args)
+ return [item.embedding for item in response.data]
+
+
+class AzureOpenAIEmbeddingAdapter(EmbeddingProviderAdapter):
+ """Adapter for Azure OpenAI embeddings API."""
+
+ def __init__(self, client: Any):
+ self._client = client
+
+ async def create_embeddings(
+ self,
+ texts: list[str],
+ model: str,
+ dimensions: int | None = None,
+ ) -> list[list[float]]:
+ """
+ Create embeddings using Azure OpenAI.
+
+ For Azure OpenAI, the 'model' parameter should be the deployment name,
+ not the model name (e.g., "my-embedding-deployment" not "text-embedding-3-small").
+ """
+ request_args: dict[str, Any] = {
+ "model": model, # This is the deployment name for Azure
+ "input": texts,
+ }
+ # Azure OpenAI supports dimensions parameter for certain models
+ if dimensions is not None:
+ request_args["dimensions"] = dimensions
+
response = await self._client.embeddings.create(**request_args)
return [item.embedding for item in response.data]
@@ -121,7 +153,9 @@ async def create_embeddings(
async with httpx.AsyncClient(timeout=30.0) as http_client:
embeddings = await asyncio.gather(
*(
- self._fetch_single_embedding(http_client, google_api_key, model, text, dimensions)
+ self._fetch_single_embedding(
+ http_client, google_api_key, model, text, dimensions
+ )
for text in texts
)
)
@@ -139,7 +173,9 @@ async def create_embeddings(
original_error=error,
) from error
except Exception as error:
- search_logger.error(f"Error calling Google embedding API: {error}", exc_info=True)
+ search_logger.error(
+ f"Error calling Google embedding API: {error}", exc_info=True
+ )
raise EmbeddingAPIError(
f"Google embedding error: {str(error)}", original_error=error
) from error
@@ -209,7 +245,9 @@ def _normalize_embedding(self, embedding: list[float]) -> list[float]:
normalized = embedding_array / norm
return normalized.tolist()
else:
- search_logger.warning("Zero-norm embedding detected, returning unnormalized")
+ search_logger.warning(
+ "Zero-norm embedding detected, returning unnormalized"
+ )
return embedding
except Exception as e:
search_logger.error(f"Failed to normalize embedding: {e}")
@@ -221,6 +259,8 @@ def _get_embedding_adapter(provider: str, client: Any) -> EmbeddingProviderAdapt
provider_name = (provider or "").lower()
if provider_name == "google":
return GoogleEmbeddingAdapter()
+ elif provider_name == "azure-openai":
+ return AzureOpenAIEmbeddingAdapter(client)
return OpenAICompatibleEmbeddingAdapter(client)
@@ -229,6 +269,7 @@ async def _maybe_await(value: Any) -> Any:
return await value if inspect.isawaitable(value) else value
+
# Provider-aware client factory
get_openai_client = get_llm_client
@@ -262,7 +303,9 @@ async def create_embedding(text: str, provider: str | None = None) -> list[float
f"OpenAI quota exhausted: {error_msg}", text_preview=text
)
elif "rate" in error_msg.lower():
- raise EmbeddingRateLimitError(f"Rate limit hit: {error_msg}", text_preview=text)
+ raise EmbeddingRateLimitError(
+ f"Rate limit hit: {error_msg}", text_preview=text
+ )
else:
raise EmbeddingAPIError(
f"Failed to create embedding: {error_msg}", text_preview=text
@@ -286,7 +329,9 @@ async def create_embedding(text: str, provider: str | None = None) -> list[float
f"OpenAI quota exhausted: {error_msg}", text_preview=text
)
elif "rate_limit" in error_msg.lower():
- raise EmbeddingRateLimitError(f"Rate limit hit: {error_msg}", text_preview=text)
+ raise EmbeddingRateLimitError(
+ f"Rate limit hit: {error_msg}", text_preview=text
+ )
else:
raise EmbeddingAPIError(
f"Embedding error: {error_msg}", text_preview=text, original_error=e
@@ -327,7 +372,8 @@ async def create_embeddings_batch(
continue
search_logger.error(
- f"Invalid text type at index {i}: {type(text)}, value: {text}", exc_info=True
+ f"Invalid text type at index {i}: {type(text)}, value: {text}",
+ exc_info=True,
)
try:
converted = str(text)
@@ -347,7 +393,9 @@ async def create_embeddings_batch(
threading_service = get_threading_service()
with safe_span(
- "create_embeddings_batch", text_count=len(texts), total_chars=sum(len(t) for t in texts)
+ "create_embeddings_batch",
+ text_count=len(texts),
+ total_chars=sum(len(t) for t in texts),
) as span:
try:
embedding_config = await _maybe_await(
@@ -356,30 +404,45 @@ async def create_embeddings_batch(
embedding_provider = provider or embedding_config.get("provider")
- if not isinstance(embedding_provider, str) or not embedding_provider.strip():
+ if (
+ not isinstance(embedding_provider, str)
+ or not embedding_provider.strip()
+ ):
embedding_provider = "openai"
if not embedding_provider:
search_logger.error("No embedding provider configured")
- raise ValueError("No embedding provider configured. Please set EMBEDDING_PROVIDER environment variable.")
+ raise ValueError(
+ "No embedding provider configured. Please set EMBEDDING_PROVIDER environment variable."
+ )
- search_logger.info(f"Using embedding provider: '{embedding_provider}' (from EMBEDDING_PROVIDER setting)")
- async with get_llm_client(provider=embedding_provider, use_embedding_provider=True) as client:
+ search_logger.info(
+ f"Using embedding provider: '{embedding_provider}' (from EMBEDDING_PROVIDER setting)"
+ )
+ async with get_llm_client(
+ provider=embedding_provider, use_embedding_provider=True
+ ) as client:
# Load batch size and dimensions from settings
try:
rag_settings = await _maybe_await(
credential_service.get_credentials_by_category("rag_strategy")
)
batch_size = int(rag_settings.get("EMBEDDING_BATCH_SIZE", "100"))
- embedding_dimensions = int(rag_settings.get("EMBEDDING_DIMENSIONS", "1536"))
+ embedding_dimensions = int(
+ rag_settings.get("EMBEDDING_DIMENSIONS", "1536")
+ )
except Exception as e:
- search_logger.warning(f"Failed to load embedding settings: {e}, using defaults")
+ search_logger.warning(
+ f"Failed to load embedding settings: {e}, using defaults"
+ )
batch_size = 100
embedding_dimensions = 1536
total_tokens_used = 0
adapter = _get_embedding_adapter(embedding_provider, client)
- dimensions_to_use = embedding_dimensions if embedding_dimensions > 0 else None
+ dimensions_to_use = (
+ embedding_dimensions if embedding_dimensions > 0 else None
+ )
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
@@ -393,28 +456,39 @@ async def create_embeddings_batch(
# Create rate limit progress callback if we have a progress callback
rate_limit_callback = None
if progress_callback:
+
async def rate_limit_callback(data: dict):
# Send heartbeat during rate limit wait
processed = result.success_count + result.failure_count
- message = f"Rate limited: {data.get('message', 'Waiting...')}"
- await progress_callback(message, (processed / len(texts)) * 100)
+ message = (
+ f"Rate limited: {data.get('message', 'Waiting...')}"
+ )
+ await progress_callback(
+ message, (processed / len(texts)) * 100
+ )
# Rate limit each batch
- async with threading_service.rate_limited_operation(batch_tokens, rate_limit_callback):
+ async with threading_service.rate_limited_operation(
+ batch_tokens, rate_limit_callback
+ ):
retry_count = 0
max_retries = 3
while retry_count < max_retries:
try:
# Create embeddings for this batch
- embedding_model = await get_embedding_model(provider=embedding_provider)
+ embedding_model = await get_embedding_model(
+ provider=embedding_provider
+ )
embeddings = await adapter.create_embeddings(
batch,
embedding_model,
dimensions=dimensions_to_use,
)
- for text, vector in zip(batch, embeddings, strict=False):
+ for text, vector in zip(
+ batch, embeddings, strict=False
+ ):
result.add_success(vector, text)
break # Success, exit retry loop
@@ -474,7 +548,9 @@ async def rate_limit_callback(data: dict):
except Exception as e:
# This batch failed - track failures but continue with next batch
- search_logger.error(f"Batch {batch_index} failed: {e}", exc_info=True)
+ search_logger.error(
+ f"Batch {batch_index} failed: {e}", exc_info=True
+ )
for text in batch:
if isinstance(e, EmbeddingError):
@@ -483,7 +559,8 @@ async def rate_limit_callback(data: dict):
result.add_failure(
text,
EmbeddingAPIError(
- f"Failed to create embedding: {str(e)}", original_error=e
+ f"Failed to create embedding: {str(e)}",
+ original_error=e,
),
batch_index,
)
@@ -512,13 +589,18 @@ async def rate_limit_callback(data: dict):
except Exception as e:
# Catastrophic failure - return what we have
span.set_attribute("catastrophic_failure", True)
- search_logger.error(f"Catastrophic failure in batch embedding: {e}", exc_info=True)
+ search_logger.error(
+ f"Catastrophic failure in batch embedding: {e}", exc_info=True
+ )
# Mark remaining texts as failed
processed_count = result.success_count + result.failure_count
for text in texts[processed_count:]:
result.add_failure(
- text, EmbeddingAPIError(f"Catastrophic failure: {str(e)}", original_error=e)
+ text,
+ EmbeddingAPIError(
+ f"Catastrophic failure: {str(e)}", original_error=e
+ ),
)
return result
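
Putting the Azure pieces together, a minimal sketch of AsyncAzureOpenAI plus the adapter above; the endpoint, key, and deployment name are placeholders, the adapter's import path is an assumption, and running it requires real Azure credentials.

import asyncio

import openai

from src.server.services.embeddings.embedding_service import AzureOpenAIEmbeddingAdapter


async def main() -> None:
    client = openai.AsyncAzureOpenAI(
        api_key="<AZURE_OPENAI_API_KEY>",
        azure_endpoint="https://example.openai.azure.com",
        api_version="2024-02-15-preview",
    )
    adapter = AzureOpenAIEmbeddingAdapter(client)
    # For Azure, "model" is the deployment name, not the underlying model name.
    vectors = await adapter.create_embeddings(
        ["hello world"], model="my-embedding-deployment", dimensions=1536
    )
    print(len(vectors), len(vectors[0]))


asyncio.run(main())
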
diff --git a/python/src/server/services/llm_provider_service.py b/python/src/server/services/llm_provider_service.py
index 00197926fd..c383d0bc12 100644
--- a/python/src/server/services/llm_provider_service.py
+++ b/python/src/server/services/llm_provider_service.py
@@ -23,7 +23,16 @@ def _is_valid_provider(provider: str) -> bool:
"""Basic provider validation."""
if not provider or not isinstance(provider, str):
return False
- return provider.lower() in {"openai", "ollama", "google", "openrouter", "anthropic", "grok"}
+ return provider.lower() in {
+ "openai",
+ "azure-openai",
+ "aws-bedrock",
+ "ollama",
+ "google",
+ "openrouter",
+ "anthropic",
+ "grok",
+ }
def _sanitize_for_log(text: str) -> str:
@@ -31,6 +40,7 @@ def _sanitize_for_log(text: str) -> str:
if not text:
return ""
import re
+
sanitized = re.sub(r"sk-[a-zA-Z0-9-_]{20,}", "[REDACTED]", text)
sanitized = re.sub(r"xai-[a-zA-Z0-9-_]{20,}", "[REDACTED]", sanitized)
return sanitized[:100]
@@ -39,7 +49,9 @@ def _sanitize_for_log(text: str) -> str:
# Secure settings cache with TTL and validation
_settings_cache: dict[str, tuple[Any, float, str]] = {} # value, timestamp, checksum
_CACHE_TTL_SECONDS = 300 # 5 minutes
-_cache_access_log: list[dict] = [] # Track cache access patterns for security monitoring
+_cache_access_log: list[dict] = (
+ []
+) # Track cache access patterns for security monitoring
def _calculate_cache_checksum(value: Any) -> str:
@@ -50,13 +62,17 @@ def _calculate_cache_checksum(value: Any) -> str:
# Convert value to JSON string for consistent hashing
try:
value_str = json.dumps(value, sort_keys=True, default=str)
- return hashlib.sha256(value_str.encode()).hexdigest()[:16] # First 16 chars for efficiency
+ return hashlib.sha256(value_str.encode()).hexdigest()[
+ :16
+ ] # First 16 chars for efficiency
except Exception:
# Fallback for non-serializable objects
return hashlib.sha256(str(value).encode()).hexdigest()[:16]
-def _log_cache_access(key: str, action: str, hit: bool = None, security_event: str = None) -> None:
+def _log_cache_access(
+ key: str, action: str, hit: bool | None = None, security_event: str | None = None
+) -> None:
"""Log cache access for security monitoring."""
access_entry = {
@@ -64,7 +80,7 @@ def _log_cache_access(key: str, action: str, hit: bool = None, security_event: s
"key": _sanitize_for_log(key),
"action": action, # "get", "set", "invalidate", "clear"
"hit": hit, # For get operations
- "security_event": security_event # "checksum_mismatch", "expired", etc.
+ "security_event": security_event, # "checksum_mismatch", "expired", etc.
}
# Keep only last 100 access entries to prevent memory growth
@@ -98,17 +114,25 @@ def _get_cached_settings(key: str) -> Any | None:
if current_checksum != stored_checksum:
# Cache tampering detected, remove entry
del _settings_cache[key]
- _log_cache_access(key, "get", hit=False, security_event="checksum_mismatch")
- logger.error(f"Cache integrity violation detected for key: {_sanitize_for_log(key)}")
+ _log_cache_access(
+ key, "get", hit=False, security_event="checksum_mismatch"
+ )
+ logger.error(
+ f"Cache integrity violation detected for key: {_sanitize_for_log(key)}"
+ )
return None
# Additional validation for provider configurations
if "provider_config" in key and isinstance(value, dict):
# Basic validation: check required fields
- if not value.get("provider") or not _is_valid_provider(value.get("provider")):
+ if not value.get("provider") or not _is_valid_provider(
+ value.get("provider")
+ ):
# Invalid configuration in cache, remove it
del _settings_cache[key]
- _log_cache_access(key, "get", hit=False, security_event="invalid_config")
+ _log_cache_access(
+ key, "get", hit=False, security_event="invalid_config"
+ )
return None
_log_cache_access(key, "get", hit=True)
@@ -119,7 +143,9 @@ def _get_cached_settings(key: str) -> Any | None:
except Exception as e:
# Cache access error, log and return None for safety
- _log_cache_access(key, "get", hit=False, security_event=f"access_error: {str(e)}")
+ _log_cache_access(
+ key, "get", hit=False, security_event=f"access_error: {str(e)}"
+ )
return None
@@ -130,9 +156,13 @@ def _set_cached_settings(key: str, value: Any) -> None:
# Validate provider configurations before caching
if "provider_config" in key and isinstance(value, dict):
# Basic validation: check required fields
- if not value.get("provider") or not _is_valid_provider(value.get("provider")):
+ if not value.get("provider") or not _is_valid_provider(
+ value.get("provider")
+ ):
_log_cache_access(key, "set", security_event="invalid_config_rejected")
- logger.warning(f"Rejected caching of invalid provider config for key: {_sanitize_for_log(key)}")
+ logger.warning(
+ f"Rejected caching of invalid provider config for key: {_sanitize_for_log(key)}"
+ )
return
# Calculate integrity checksum
@@ -154,7 +184,9 @@ def clear_provider_cache() -> None:
cache_size_before = len(_settings_cache)
_settings_cache.clear()
_log_cache_access("*", "clear")
- logger.debug(f"Provider configuration cache cleared ({cache_size_before} entries removed)")
+ logger.debug(
+ f"Provider configuration cache cleared ({cache_size_before} entries removed)"
+ )
def invalidate_provider_cache(provider: str = None) -> None:
@@ -171,12 +203,18 @@ def invalidate_provider_cache(provider: str = None) -> None:
cache_size_before = len(_settings_cache)
_settings_cache.clear()
_log_cache_access("*", "invalidate")
- logger.debug(f"All provider cache entries invalidated ({cache_size_before} entries)")
+ logger.debug(
+ f"All provider cache entries invalidated ({cache_size_before} entries)"
+ )
else:
# Validate provider name before processing
if not _is_valid_provider(provider):
- _log_cache_access(provider, "invalidate", security_event="invalid_provider_name")
- logger.warning(f"Rejected cache invalidation for invalid provider: {_sanitize_for_log(provider)}")
+ _log_cache_access(
+ provider, "invalidate", security_event="invalid_provider_name"
+ )
+ logger.warning(
+ f"Rejected cache invalidation for invalid provider: {_sanitize_for_log(provider)}"
+ )
return
# Clear specific provider entries
@@ -190,7 +228,9 @@ def invalidate_provider_cache(provider: str = None) -> None:
_log_cache_access(key, "invalidate")
safe_provider = _sanitize_for_log(provider)
- logger.debug(f"Cache entries for provider '{safe_provider}' invalidated: {len(keys_to_remove)} entries removed")
+ logger.debug(
+ f"Cache entries for provider '{safe_provider}' invalidated: {len(keys_to_remove)} entries removed"
+ )
def get_cache_stats() -> dict[str, Any]:
@@ -213,13 +253,13 @@ def get_cache_stats() -> dict[str, Any]:
"expired_access_attempts": 0,
"invalid_config_rejections": 0,
"access_errors": 0,
- "total_security_events": 0
+ "total_security_events": 0,
},
"access_patterns": {
"recent_cache_hits": 0,
"recent_cache_misses": 0,
- "hit_rate": 0.0
- }
+ "hit_rate": 0.0,
+ },
}
# Analyze cache entries
@@ -284,13 +324,14 @@ def get_cache_security_report() -> dict[str, Any]:
"timestamp": current_time,
"analysis_period_hours": 1,
"security_events": [],
- "recommendations": []
+ "recommendations": [],
}
# Extract security events from last hour
recent_threshold = current_time - 3600
security_events = [
- access for access in _cache_access_log
+ access
+ for access in _cache_access_log
if access["timestamp"] >= recent_threshold and access["security_event"]
]
@@ -298,17 +339,33 @@ def get_cache_security_report() -> dict[str, Any]:
# Generate recommendations based on security events
if len(security_events) > 10:
- report["recommendations"].append("High number of security events detected - investigate potential attacks")
+ report["recommendations"].append(
+ "High number of security events detected - investigate potential attacks"
+ )
- integrity_violations = sum(1 for event in security_events if "checksum_mismatch" in event.get("security_event", ""))
+ integrity_violations = sum(
+ 1
+ for event in security_events
+ if "checksum_mismatch" in event.get("security_event", "")
+ )
if integrity_violations > 0:
- report["recommendations"].append(f"Cache integrity violations detected ({integrity_violations}) - check for memory corruption or attacks")
+ report["recommendations"].append(
+ f"Cache integrity violations detected ({integrity_violations}) - check for memory corruption or attacks"
+ )
- invalid_configs = sum(1 for event in security_events if "invalid_config" in event.get("security_event", ""))
+ invalid_configs = sum(
+ 1
+ for event in security_events
+ if "invalid_config" in event.get("security_event", "")
+ )
if invalid_configs > 3:
- report["recommendations"].append(f"Multiple invalid configuration attempts ({invalid_configs}) - validate data sources")
+ report["recommendations"].append(
+ f"Multiple invalid configuration attempts ({invalid_configs}) - validate data sources"
+ )
return report
+
+
@asynccontextmanager
async def get_llm_client(
provider: str | None = None,
@@ -347,7 +404,9 @@ async def get_llm_client(
cache_key = "rag_strategy_settings"
rag_settings = _get_cached_settings(cache_key)
if rag_settings is None:
- rag_settings = await credential_service.get_credentials_by_category("rag_strategy")
+ rag_settings = await credential_service.get_credentials_by_category(
+ "rag_strategy"
+ )
_set_cached_settings(cache_key, rag_settings)
logger.debug("Fetched and cached rag_strategy settings")
else:
@@ -367,7 +426,9 @@ async def get_llm_client(
cache_key = f"provider_config_{service_type}"
provider_config = _get_cached_settings(cache_key)
if provider_config is None:
- provider_config = await credential_service.get_active_provider(service_type)
+ provider_config = await credential_service.get_active_provider(
+ service_type
+ )
_set_cached_settings(cache_key, provider_config)
logger.debug(f"Fetched and cached {service_type} provider config")
else:
@@ -376,7 +437,9 @@ async def get_llm_client(
provider_name = provider_config["provider"]
api_key = provider_config["api_key"]
# For Ollama, don't use the base_url from config - let _get_optimal_ollama_instance decide
- base_url = provider_config["base_url"] if provider_name != "ollama" else None
+ base_url = (
+ provider_config["base_url"] if provider_name != "ollama" else None
+ )
# Comprehensive provider validation with security checks
if not _is_valid_provider(provider_name):
@@ -389,14 +452,98 @@ async def get_llm_client(
elif len(api_key) > 500: # Reasonable API key length limit
raise ValueError("API key length exceeds security limits")
# Additional security: check for suspicious patterns
- if any(char in api_key for char in ['\n', '\r', '\t', '\0']):
+ if any(char in api_key for char in ["\n", "\r", "\t", "\0"]):
raise ValueError("API key contains invalid characters")
# Sanitize provider name for logging
- safe_provider_name = _sanitize_for_log(provider_name) if provider_name else "unknown"
+ safe_provider_name = (
+ _sanitize_for_log(provider_name) if provider_name else "unknown"
+ )
logger.info(f"Creating LLM client for provider: {safe_provider_name}")
- if provider_name == "openai":
+ if provider_name == "azure-openai":
+ # Azure OpenAI requires specific configuration
+ azure_endpoint = await credential_service.get_credential(
+ "AZURE_OPENAI_ENDPOINT"
+ )
+ azure_api_version = await credential_service.get_credential(
+ "AZURE_OPENAI_API_VERSION"
+ )
+
+ if not azure_endpoint:
+ raise ValueError(
+ "Azure OpenAI endpoint not configured. Set AZURE_OPENAI_ENDPOINT."
+ )
+ if not api_key:
+ raise ValueError(
+ "Azure OpenAI API key not found. Set AZURE_OPENAI_API_KEY."
+ )
+ if not azure_api_version:
+ # Default to a stable API version if not specified
+ azure_api_version = "2024-02-15-preview"
+ logger.warning(
+ f"Azure OpenAI API version not set, using default: {azure_api_version}"
+ )
+
+ client = openai.AsyncAzureOpenAI(
+ api_key=api_key,
+ azure_endpoint=azure_endpoint,
+ api_version=azure_api_version,
+ )
+ logger.info(
+ f"Azure OpenAI client created successfully with endpoint: {azure_endpoint}"
+ )
+
+ elif provider_name == "aws-bedrock":
+ # AWS Bedrock requires boto3 SDK for Converse API
+ try:
+ import boto3
+ from botocore.config import Config as BotoConfig
+ except ImportError as import_error:
+ raise ValueError(
+ "AWS Bedrock support requires boto3. Install with: pip install boto3"
+ ) from import_error
+
+ # Get AWS credentials
+ aws_secret_key = await credential_service.get_credential(
+ "AWS_SECRET_ACCESS_KEY"
+ )
+ aws_region = await credential_service.get_credential("AWS_REGION")
+
+ if not api_key: # api_key is AWS_ACCESS_KEY_ID
+ raise ValueError("AWS Access Key ID not found. Set AWS_ACCESS_KEY_ID.")
+ if not aws_secret_key:
+ raise ValueError(
+ "AWS Secret Access Key not found. Set AWS_SECRET_ACCESS_KEY."
+ )
+ if not aws_region:
+ aws_region = "us-east-1" # Default region
+ logger.warning(f"AWS region not set, using default: {aws_region}")
+
+ # Create boto3 Bedrock runtime client
+ bedrock_config = BotoConfig(
+ region_name=aws_region,
+ signature_version="v4",
+ retries={"max_attempts": 3, "mode": "adaptive"},
+ )
+
+ bedrock_client = boto3.client(
+ "bedrock-runtime",
+ aws_access_key_id=api_key,
+ aws_secret_access_key=aws_secret_key,
+ config=bedrock_config,
+ )
+
+ # Wrap boto3 client in a compatible interface
+ # We'll create a custom wrapper that implements the OpenAI-like async interface
+ from ..adapters.aws_bedrock_adapter import AWSBedrockClientAdapter
+
+ client = AWSBedrockClientAdapter(bedrock_client, aws_region)
+ logger.info(
+ f"AWS Bedrock client created successfully in region: {aws_region}"
+ )
+
+ elif provider_name == "openai":
if api_key:
client = openai.AsyncOpenAI(api_key=api_key)
logger.info("OpenAI client created successfully")
@@ -440,7 +587,9 @@ async def get_llm_client(
api_key="ollama", # Required but unused by Ollama
base_url=ollama_base_url,
)
- logger.info(f"Ollama client created successfully with base URL: {ollama_base_url}")
+ logger.info(
+ f"Ollama client created successfully with base URL: {ollama_base_url}"
+ )
elif provider_name == "google":
if not api_key:
@@ -448,7 +597,8 @@ async def get_llm_client(
client = openai.AsyncOpenAI(
api_key=api_key,
- base_url=base_url or "https://generativelanguage.googleapis.com/v1beta/openai/",
+ base_url=base_url
+ or "https://generativelanguage.googleapis.com/v1beta/openai/",
)
logger.info("Google Gemini client created successfully")
@@ -474,14 +624,18 @@ async def get_llm_client(
elif provider_name == "grok":
if not api_key:
- raise ValueError("Grok API key not found - set GROK_API_KEY environment variable")
+ raise ValueError(
+ "Grok API key not found - set GROK_API_KEY environment variable"
+ )
# Enhanced Grok API key validation (secure - no key fragments logged)
key_format_valid = api_key.startswith("xai-")
key_length_valid = len(api_key) >= 20
if not key_format_valid:
- logger.warning("Grok API key format validation failed - should start with 'xai-'")
+ logger.warning(
+ "Grok API key format validation failed - should start with 'xai-'"
+ )
if not key_length_valid:
logger.warning("Grok API key validation failed - insufficient length")
@@ -509,7 +663,9 @@ async def get_llm_client(
yield client
finally:
if client is not None:
- safe_provider = _sanitize_for_log(provider_name) if provider_name else "unknown"
+ safe_provider = (
+ _sanitize_for_log(provider_name) if provider_name else "unknown"
+ )
try:
close_method = getattr(client, "aclose", None)
@@ -548,24 +704,29 @@ async def get_llm_client(
)
-
-async def _get_optimal_ollama_instance(instance_type: str | None = None,
- use_embedding_provider: bool = False,
- base_url_override: str | None = None) -> str:
+async def _get_optimal_ollama_instance(
+ instance_type: str | None = None,
+ use_embedding_provider: bool = False,
+ base_url_override: str | None = None,
+) -> str:
"""
Get the optimal Ollama instance URL based on configuration and health status.
-
+
Args:
instance_type: Preferred instance type ('chat', 'embedding', 'both', or None)
use_embedding_provider: Whether this is for embedding operations
base_url_override: Override URL if specified
-
+
Returns:
Best available Ollama instance URL
"""
# If override URL provided, use it directly
if base_url_override:
- return base_url_override if base_url_override.endswith('/v1') else f"{base_url_override}/v1"
+ return (
+ base_url_override
+ if base_url_override.endswith("/v1")
+ else f"{base_url_override}/v1"
+ )
try:
# For now, we don't have multi-instance support, so skip to single instance config
@@ -573,25 +734,39 @@ async def _get_optimal_ollama_instance(instance_type: str | None = None,
logger.info("Using single instance Ollama configuration")
# Get single instance configuration from RAG settings
- rag_settings = await credential_service.get_credentials_by_category("rag_strategy")
+ rag_settings = await credential_service.get_credentials_by_category(
+ "rag_strategy"
+ )
# Check if we need embedding provider and have separate embedding URL
if use_embedding_provider or instance_type == "embedding":
embedding_url = rag_settings.get("OLLAMA_EMBEDDING_URL")
if embedding_url:
- return embedding_url if embedding_url.endswith('/v1') else f"{embedding_url}/v1"
+ return (
+ embedding_url
+ if embedding_url.endswith("/v1")
+ else f"{embedding_url}/v1"
+ )
# Default to LLM base URL for chat operations
- fallback_url = rag_settings.get("LLM_BASE_URL", "http://host.docker.internal:11434")
- return fallback_url if fallback_url.endswith('/v1') else f"{fallback_url}/v1"
+ fallback_url = rag_settings.get(
+ "LLM_BASE_URL", "http://host.docker.internal:11434"
+ )
+ return fallback_url if fallback_url.endswith("/v1") else f"{fallback_url}/v1"
except Exception as e:
logger.error(f"Error getting Ollama configuration: {e}")
# Final fallback to localhost only if we can't get RAG settings
try:
- rag_settings = await credential_service.get_credentials_by_category("rag_strategy")
- fallback_url = rag_settings.get("LLM_BASE_URL", "http://host.docker.internal:11434")
- return fallback_url if fallback_url.endswith('/v1') else f"{fallback_url}/v1"
+ rag_settings = await credential_service.get_credentials_by_category(
+ "rag_strategy"
+ )
+ fallback_url = rag_settings.get(
+ "LLM_BASE_URL", "http://host.docker.internal:11434"
+ )
+ return (
+ fallback_url if fallback_url.endswith("/v1") else f"{fallback_url}/v1"
+ )
except Exception as fallback_error:
logger.error(f"Could not retrieve fallback configuration: {fallback_error}")
return "http://host.docker.internal:11434/v1"
@@ -616,7 +791,9 @@ async def get_embedding_model(provider: str | None = None) -> str:
cache_key = "rag_strategy_settings"
rag_settings = _get_cached_settings(cache_key)
if rag_settings is None:
- rag_settings = await credential_service.get_credentials_by_category("rag_strategy")
+ rag_settings = await credential_service.get_credentials_by_category(
+ "rag_strategy"
+ )
_set_cached_settings(cache_key, rag_settings)
custom_model = rag_settings.get("EMBEDDING_MODEL", "")
else:
@@ -624,7 +801,9 @@ async def get_embedding_model(provider: str | None = None) -> str:
cache_key = "provider_config_embedding"
provider_config = _get_cached_settings(cache_key)
if provider_config is None:
- provider_config = await credential_service.get_active_provider("embedding")
+ provider_config = await credential_service.get_active_provider(
+ "embedding"
+ )
_set_cached_settings(cache_key, provider_config)
provider_name = provider_config["provider"]
custom_model = provider_config["embedding_model"]
@@ -632,21 +811,40 @@ async def get_embedding_model(provider: str | None = None) -> str:
# Comprehensive provider validation for embeddings
if not _is_valid_provider(provider_name):
safe_provider = _sanitize_for_log(provider_name)
- logger.warning(f"Invalid embedding provider: {safe_provider}, falling back to OpenAI")
+ logger.warning(
+ f"Invalid embedding provider: {safe_provider}, falling back to OpenAI"
+ )
provider_name = "openai"
# Use custom model if specified (with validation)
if custom_model and len(custom_model.strip()) > 0:
custom_model = custom_model.strip()
# Basic model name validation (check length and basic characters)
- if len(custom_model) <= 100 and not any(char in custom_model for char in ['\n', '\r', '\t', '\0']):
+ if len(custom_model) <= 100 and not any(
+ char in custom_model for char in ["\n", "\r", "\t", "\0"]
+ ):
return custom_model
else:
safe_model = _sanitize_for_log(custom_model)
- logger.warning(f"Invalid custom embedding model '{safe_model}' for provider '{provider_name}', using default")
+ logger.warning(
+ f"Invalid custom embedding model '{safe_model}' for provider '{provider_name}', using default"
+ )
# Return provider-specific defaults
if provider_name == "openai":
return "text-embedding-3-small"
+ elif provider_name == "azure-openai":
+ # For Azure OpenAI, use the deployment name from settings
+ # This is required for Azure OpenAI embeddings
+ deployment = await credential_service.get_credential(
+ "AZURE_OPENAI_DEPLOYMENT"
+ )
+ if deployment:
+ return deployment
+ else:
+ logger.warning(
+ "Azure OpenAI deployment not set, using default model name"
+ )
+ return "text-embedding-3-small"
elif provider_name == "ollama":
# Ollama default embedding model
return "nomic-embed-text"
@@ -665,6 +863,16 @@ async def get_embedding_model(provider: str | None = None) -> str:
# Grok supports OpenAI and Google embedding models through their API
# Default to OpenAI's latest for compatibility
return "text-embedding-3-small"
+ elif provider_name == "aws-bedrock":
+ # AWS Bedrock default embedding model (Amazon Titan Embeddings)
+ bedrock_model_id = await credential_service.get_credential(
+ "AWS_BEDROCK_MODEL_ID"
+ )
+ if bedrock_model_id:
+ return bedrock_model_id
+ else:
+ # Default to Titan Embeddings V1
+ return "amazon.titan-embed-text-v1"
else:
# Fallback to OpenAI's model
return "text-embedding-3-small"
@@ -714,7 +922,7 @@ def is_google_embedding_model(model: str) -> bool:
"text-embedding-005",
"text-multilingual-embedding-002",
"gemini-embedding-001",
- "multimodalembedding@001"
+ "multimodalembedding@001",
]
return any(pattern in model_lower for pattern in google_patterns)
@@ -738,6 +946,10 @@ def is_valid_embedding_model_for_provider(model: str, provider: str) -> bool:
if provider_lower == "openai":
return is_openai_embedding_model(model)
+ elif provider_lower == "azure-openai":
+ # Azure OpenAI uses deployment names, which can be any string
+ # so we accept any non-empty string as valid
+ return bool(model and model.strip())
elif provider_lower == "google":
return is_google_embedding_model(model)
elif provider_lower in ["openrouter", "anthropic", "grok"]:
@@ -771,7 +983,7 @@ def get_supported_embedding_models(provider: str) -> list[str]:
openai_models = [
"text-embedding-ada-002",
"text-embedding-3-small",
- "text-embedding-3-large"
+ "text-embedding-3-large",
]
google_models = [
@@ -779,11 +991,15 @@ def get_supported_embedding_models(provider: str) -> list[str]:
"text-embedding-005",
"text-multilingual-embedding-002",
"gemini-embedding-001",
- "multimodalembedding@001"
+ "multimodalembedding@001",
]
if provider_lower == "openai":
return openai_models
+ elif provider_lower == "azure-openai":
+ # Azure OpenAI uses deployment names which are user-defined
+        # Return OpenAI models as a reference, but actual deployment names may differ
+ return openai_models
elif provider_lower == "google":
return google_models
elif provider_lower in ["openrouter", "anthropic", "grok"]:
@@ -791,6 +1007,14 @@ def get_supported_embedding_models(provider: str) -> list[str]:
return openai_models + google_models
elif provider_lower == "ollama":
return ["nomic-embed-text", "all-minilm", "mxbai-embed-large"]
+ elif provider_lower == "aws-bedrock":
+ # AWS Bedrock embedding models
+ return [
+ "amazon.titan-embed-text-v1",
+ "amazon.titan-embed-text-v2:0",
+ "cohere.embed-english-v3",
+ "cohere.embed-multilingual-v3",
+ ]
else:
# For unknown providers, assume OpenAI compatibility
return openai_models
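A hedged sanity check for the two new branches above (assumes both helpers are imported from this module; the deployment name is made up):

# Illustrative checks, not from the patch.
assert is_valid_embedding_model_for_provider("my-embed-deployment", "azure-openai")
assert not is_valid_embedding_model_for_provider("   ", "azure-openai")
assert "amazon.titan-embed-text-v1" in get_supported_embedding_models("aws-bedrock")
# Unknown providers fall back to the OpenAI-compatible list.
assert "text-embedding-3-small" in get_supported_embedding_models("acme-llm")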
@@ -931,15 +1155,26 @@ def _is_reasoning_text(text: str) -> bool:
# Common reasoning text patterns
reasoning_indicators = [
- "okay, let's see", "let me think", "first, i need to", "looking at this",
- "step by step", "analyzing", "breaking this down", "considering",
- "let me work through", "i should", "thinking about", "examining"
+ "okay, let's see",
+ "let me think",
+ "first, i need to",
+ "looking at this",
+ "step by step",
+ "analyzing",
+ "breaking this down",
+ "considering",
+ "let me work through",
+ "i should",
+ "thinking about",
+ "examining",
]
return any(indicator in text_lower for indicator in reasoning_indicators)
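For intuition, the heuristic behaves roughly like this (sketch; assumes the indicator scan shown here is the deciding check):

# Illustrative only, not from the patch.
print(_is_reasoning_text("Okay, let's see. First, I need to parse the file."))  # True
print(_is_reasoning_text('{"example_name": "Demo", "summary": "..."}'))  # False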
-def extract_json_from_reasoning(reasoning_text: str, context_code: str = "", language: str = "") -> str:
+def extract_json_from_reasoning(
+ reasoning_text: str, context_code: str = "", language: str = ""
+) -> str:
"""Extract JSON content from reasoning text, with synthesis fallback."""
if not reasoning_text:
return ""
@@ -948,8 +1183,10 @@ def extract_json_from_reasoning(reasoning_text: str, context_code: str = "", lan
import re
# Try to find JSON blocks in markdown
- json_block_pattern = r'```(?:json)?\s*(\{.*?\})\s*```'
- json_matches = re.findall(json_block_pattern, reasoning_text, re.DOTALL | re.IGNORECASE)
+ json_block_pattern = r"```(?:json)?\s*(\{.*?\})\s*```"
+ json_matches = re.findall(
+ json_block_pattern, reasoning_text, re.DOTALL | re.IGNORECASE
+ )
for match in json_matches:
try:
@@ -960,14 +1197,16 @@ def extract_json_from_reasoning(reasoning_text: str, context_code: str = "", lan
continue
# Try to find standalone JSON objects
- json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
+ json_pattern = r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}"
json_matches = re.findall(json_pattern, reasoning_text, re.DOTALL)
for match in json_matches:
try:
parsed = json.loads(match.strip())
# Ensure it has expected structure
- if isinstance(parsed, dict) and any(key in parsed for key in ["example_name", "summary", "name", "title"]):
+ if isinstance(parsed, dict) and any(
+ key in parsed for key in ["example_name", "summary", "name", "title"]
+ ):
return match.strip()
except json.JSONDecodeError:
continue
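A sketch of the standalone-object path (hypothetical input; the fenced-block pattern finds nothing here, so the fallback regex does the work):

# Illustrative only, not from the patch.
text = (
    "Let me think. The snippet loads a config file, so "
    '{"example_name": "Load Config", "summary": "Reads a JSON config."} fits.'
)
print(extract_json_from_reasoning(text))
# Expected: the embedded object, since it parses and contains "example_name".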
@@ -976,7 +1215,9 @@ def extract_json_from_reasoning(reasoning_text: str, context_code: str = "", lan
return synthesize_json_from_reasoning(reasoning_text, context_code, language)
-def synthesize_json_from_reasoning(reasoning_text: str, context_code: str = "", language: str = "") -> str:
+def synthesize_json_from_reasoning(
+ reasoning_text: str, context_code: str = "", language: str = ""
+) -> str:
"""Generate JSON structure from reasoning text when no JSON is found."""
if not reasoning_text and not context_code:
return ""
@@ -991,44 +1232,44 @@ def synthesize_json_from_reasoning(reasoning_text: str, context_code: str = "",
# Common action patterns in reasoning text and code
action_patterns = [
- (r'\b(?:parse|parsing|parsed)\b', 'Parse'),
- (r'\b(?:create|creating|created)\b', 'Create'),
- (r'\b(?:analyze|analyzing|analyzed)\b', 'Analyze'),
- (r'\b(?:extract|extracting|extracted)\b', 'Extract'),
- (r'\b(?:generate|generating|generated)\b', 'Generate'),
- (r'\b(?:process|processing|processed)\b', 'Process'),
- (r'\b(?:load|loading|loaded)\b', 'Load'),
- (r'\b(?:handle|handling|handled)\b', 'Handle'),
- (r'\b(?:manage|managing|managed)\b', 'Manage'),
- (r'\b(?:build|building|built)\b', 'Build'),
- (r'\b(?:define|defining|defined)\b', 'Define'),
- (r'\b(?:implement|implementing|implemented)\b', 'Implement'),
- (r'\b(?:fetch|fetching|fetched)\b', 'Fetch'),
- (r'\b(?:connect|connecting|connected)\b', 'Connect'),
- (r'\b(?:validate|validating|validated)\b', 'Validate'),
+ (r"\b(?:parse|parsing|parsed)\b", "Parse"),
+ (r"\b(?:create|creating|created)\b", "Create"),
+ (r"\b(?:analyze|analyzing|analyzed)\b", "Analyze"),
+ (r"\b(?:extract|extracting|extracted)\b", "Extract"),
+ (r"\b(?:generate|generating|generated)\b", "Generate"),
+ (r"\b(?:process|processing|processed)\b", "Process"),
+ (r"\b(?:load|loading|loaded)\b", "Load"),
+ (r"\b(?:handle|handling|handled)\b", "Handle"),
+ (r"\b(?:manage|managing|managed)\b", "Manage"),
+ (r"\b(?:build|building|built)\b", "Build"),
+ (r"\b(?:define|defining|defined)\b", "Define"),
+ (r"\b(?:implement|implementing|implemented)\b", "Implement"),
+ (r"\b(?:fetch|fetching|fetched)\b", "Fetch"),
+ (r"\b(?:connect|connecting|connected)\b", "Connect"),
+ (r"\b(?:validate|validating|validated)\b", "Validate"),
]
# Technology/concept patterns
tech_patterns = [
- (r'\bjson\b', 'JSON'),
- (r'\bapi\b', 'API'),
- (r'\bfile\b', 'File'),
- (r'\bdata\b', 'Data'),
- (r'\bcode\b', 'Code'),
- (r'\btext\b', 'Text'),
- (r'\bcontent\b', 'Content'),
- (r'\bresponse\b', 'Response'),
- (r'\brequest\b', 'Request'),
- (r'\bconfig\b', 'Config'),
- (r'\bllm\b', 'LLM'),
- (r'\bmodel\b', 'Model'),
- (r'\bexample\b', 'Example'),
- (r'\bcontext\b', 'Context'),
- (r'\basync\b', 'Async'),
- (r'\bfunction\b', 'Function'),
- (r'\bclass\b', 'Class'),
- (r'\bprint\b', 'Output'),
- (r'\breturn\b', 'Return'),
+ (r"\bjson\b", "JSON"),
+ (r"\bapi\b", "API"),
+ (r"\bfile\b", "File"),
+ (r"\bdata\b", "Data"),
+ (r"\bcode\b", "Code"),
+ (r"\btext\b", "Text"),
+ (r"\bcontent\b", "Content"),
+ (r"\bresponse\b", "Response"),
+ (r"\brequest\b", "Request"),
+ (r"\bconfig\b", "Config"),
+ (r"\bllm\b", "LLM"),
+ (r"\bmodel\b", "Model"),
+ (r"\bexample\b", "Example"),
+ (r"\bcontext\b", "Context"),
+ (r"\basync\b", "Async"),
+ (r"\bfunction\b", "Function"),
+ (r"\bclass\b", "Class"),
+ (r"\bprint\b", "Output"),
+ (r"\breturn\b", "Return"),
]
# Extract actions and technologies from combined text
@@ -1061,8 +1302,12 @@ def synthesize_json_from_reasoning(reasoning_text: str, context_code: str = "",
example_name = " ".join(example_name_words[:4])
# Generate summary from reasoning content
- reasoning_lines = reasoning_text.split('\n')
- meaningful_lines = [line.strip() for line in reasoning_lines if line.strip() and len(line.strip()) > 10]
+ reasoning_lines = reasoning_text.split("\n")
+ meaningful_lines = [
+ line.strip()
+ for line in reasoning_lines
+ if line.strip() and len(line.strip()) > 10
+ ]
if meaningful_lines:
# Take first meaningful sentence for summary base
@@ -1084,10 +1329,7 @@ def synthesize_json_from_reasoning(reasoning_text: str, context_code: str = "",
summary = summary[:297] + "..."
# Create JSON structure
- result = {
- "example_name": example_name,
- "summary": summary
- }
+ result = {"example_name": example_name, "summary": summary}
return json.dumps(result)
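A hypothetical round trip through the synthesis fallback (the exact wording of the generated fields depends on logic outside this hunk, so treat the output as indicative):

# Illustrative only, not from the patch.
raw = synthesize_json_from_reasoning(
    "Okay, let's see. First, I need to parse the JSON response and "
    "extract the relevant fields.",
    context_code="data = json.loads(resp.text)",
    language="python",
)
print(raw)  # e.g. '{"example_name": "Parse JSON", "summary": "..."}'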
@@ -1127,19 +1369,23 @@ def prepare_chat_completion_params(model: str, params: dict) -> dict:
# Remove custom temperature for reasoning models (they only support default temperature=1.0)
if reasoning_model and "temperature" in updated_params:
original_temp = updated_params.pop("temperature")
- logger.debug(f"Removed custom temperature {original_temp} for reasoning model {model} (only supports default temperature=1.0)")
+ logger.debug(
+ f"Removed custom temperature {original_temp} for reasoning model {model} (only supports default temperature=1.0)"
+ )
return updated_params
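An illustrative call (the model name is an assumption about what is_reasoning_model matches; the actual rules live elsewhere in the module):

# Illustrative only, not from the patch.
params = {"temperature": 0.2, "max_tokens": 512}
prepared = prepare_chat_completion_params("o3-mini", params)
print("temperature" in prepared)  # False, if "o3-mini" is flagged as reasoning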
-async def get_embedding_model_with_routing(provider: str | None = None, instance_url: str | None = None) -> tuple[str, str]:
+async def get_embedding_model_with_routing(
+ provider: str | None = None, instance_url: str | None = None
+) -> tuple[str, str | None]:
"""
Get the embedding model with intelligent routing for multi-instance setups.
-
+
Args:
provider: Override provider selection
instance_url: Specific instance URL to use
-
+
Returns:
Tuple of (model_name, instance_url) for embedding operations
"""
@@ -1149,14 +1395,21 @@ async def get_embedding_model_with_routing(provider: str | None = None, instance
# If specific instance URL provided, use it
if instance_url:
- final_url = instance_url if instance_url.endswith('/v1') else f"{instance_url}/v1"
+ final_url = (
+ instance_url if instance_url.endswith("/v1") else f"{instance_url}/v1"
+ )
return model_name, final_url
# For Ollama provider, use intelligent instance routing
- if provider == "ollama" or (not provider and (await credential_service.get_credentials_by_category("rag_strategy")).get("LLM_PROVIDER") == "ollama"):
+ if provider == "ollama" or (
+ not provider
+ and (
+ await credential_service.get_credentials_by_category("rag_strategy")
+ ).get("LLM_PROVIDER")
+ == "ollama"
+ ):
optimal_url = await _get_optimal_ollama_instance(
- instance_type="embedding",
- use_embedding_provider=True
+ instance_type="embedding", use_embedding_provider=True
)
return model_name, optimal_url
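The explicit-instance branch normalizes URLs like so (hypothetical host; call from an async context):

# Illustrative only, not from the patch.
model, url = await get_embedding_model_with_routing(
    instance_url="http://ollama-a:11434"
)
print(url)  # "http://ollama-a:11434/v1" -- /v1 is appended when missing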
@@ -1168,14 +1421,16 @@ async def get_embedding_model_with_routing(provider: str | None = None, instance
return "text-embedding-3-small", None
-async def validate_provider_instance(provider: str, instance_url: str | None = None) -> dict[str, any]:
+async def validate_provider_instance(
+ provider: str, instance_url: str | None = None
+) -> dict[str, Any]:
"""
Validate a provider instance and return health information.
-
+
Args:
provider: Provider name (openai, ollama, google, etc.)
instance_url: Instance URL for providers that support multiple instances
-
+
Returns:
Dictionary with validation results and health status
"""
@@ -1188,10 +1443,12 @@ async def validate_provider_instance(provider: str, instance_url: str | None = N
if not instance_url:
instance_url = await _get_optimal_ollama_instance()
# Remove /v1 suffix for health checking
- if instance_url.endswith('/v1'):
+ if instance_url.endswith("/v1"):
instance_url = instance_url[:-3]
- health_status = await model_discovery_service.check_instance_health(instance_url)
+ health_status = await model_discovery_service.check_instance_health(
+ instance_url
+ )
return {
"provider": provider,
@@ -1200,7 +1457,7 @@ async def validate_provider_instance(provider: str, instance_url: str | None = N
"response_time_ms": health_status.response_time_ms,
"models_available": health_status.models_available,
"error_message": health_status.error_message,
- "validation_timestamp": time.time()
+ "validation_timestamp": time.time(),
}
else:
@@ -1212,7 +1469,7 @@ async def validate_provider_instance(provider: str, instance_url: str | None = N
if provider == "openai":
# List models to validate API key
models = await client.models.list()
- model_count = len(models.data) if hasattr(models, 'data') else 0
+ model_count = len(models.data) if hasattr(models, "data") else 0
elif provider == "google":
# For Google, we can't easily list models, just validate client creation
model_count = 1 # Assume available if client creation succeeded
@@ -1228,7 +1485,7 @@ async def validate_provider_instance(provider: str, instance_url: str | None = N
"response_time_ms": response_time,
"models_available": model_count,
"error_message": None,
- "validation_timestamp": time.time()
+ "validation_timestamp": time.time(),
}
except Exception as e:
@@ -1240,11 +1497,10 @@ async def validate_provider_instance(provider: str, instance_url: str | None = N
"response_time_ms": None,
"models_available": 0,
"error_message": str(e),
- "validation_timestamp": time.time()
+ "validation_timestamp": time.time(),
}
-
def requires_max_completion_tokens(model_name: str) -> bool:
"""Backward compatible alias for previous API."""
return is_reasoning_model(model_name)
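Finally, an end-to-end sketch of the validator and the alias (hypothetical values; run inside an async context):

# Illustrative only, not from the patch.
result = await validate_provider_instance("ollama", "http://localhost:11434/v1")
if result["error_message"]:
    print("validation failed:", result["error_message"])
else:
    print("models available:", result["models_available"])

# The alias keeps older call sites working unchanged:
assert requires_max_completion_tokens("o3-mini") == is_reasoning_model("o3-mini")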