diff --git a/clients/macos/vellum-assistant/Features/Settings/InferenceServiceCard.swift b/clients/macos/vellum-assistant/Features/Settings/InferenceServiceCard.swift index 338914202a3..7e69ddf4235 100644 --- a/clients/macos/vellum-assistant/Features/Settings/InferenceServiceCard.swift +++ b/clients/macos/vellum-assistant/Features/Settings/InferenceServiceCard.swift @@ -4,9 +4,9 @@ import VellumAssistantShared /// Card for the inference service with Managed/Your Own mode toggle. /// /// Shows different content based on mode and auth state: -/// - **Managed + logged in**: Model picker, Save button +/// - **Managed + logged in**: Provider picker (managed-capable only), model picker, Save button /// - **Managed + not logged in**: Empty state prompting login -/// - **Your Own**: Provider picker, API key field, model picker, Save + Reset buttons +/// - **Your Own**: Provider picker (all), API key field, model picker, Save + Reset buttons @MainActor struct InferenceServiceCard: View { @ObservedObject var store: SettingsStore @@ -82,26 +82,26 @@ struct InferenceServiceCard: View { let modeChanged = draftMode != store.inferenceMode let hasNewKey = draftMode == "your-own" && !apiKeyText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty let modelChanged = draftModel != initialModel - let effectiveDraftProvider = draftMode == "managed" ? "anthropic" : draftProvider - let providerChanged = effectiveDraftProvider != initialProvider + let providerChanged = draftProvider != initialProvider return modeChanged || hasNewKey || modelChanged || providerChanged } var body: some View { ServiceModeCard( title: "Inference", - subtitle: draftMode == "managed" - ? "Configure which model to use to power your assistant" - : "Configure which LLM provider and model to use to power your assistant", + subtitle: "Configure which LLM provider and model to use to power your assistant", draftMode: $draftMode, managedContent: { if isLoggedIn { - PickerWithInlineSave( - hasChanges: hasChanges, - isSaving: store.apiKeySaving, - onSave: { save() } - ) { - modelPicker + VStack(alignment: .leading, spacing: VSpacing.sm) { + managedProviderPicker + PickerWithInlineSave( + hasChanges: hasChanges, + isSaving: store.apiKeySaving, + onSave: { save() } + ) { + modelPicker + } } } else { managedLoginPrompt @@ -157,13 +157,15 @@ struct InferenceServiceCard: View { // Symmetric case: if the user is authenticated and the mode is // still the default "your-own", switch to "managed" so signed-in // users get managed inference out of the box — but only when the - // provider requires an API key and the user hasn't configured one. - // Providers like Ollama that don't use keys (apiKeyPlaceholder is - // nil) are left alone since the user intentionally set up a local - // provider. + // provider is managed-capable, requires an API key, and the user + // hasn't configured one. Providers like Ollama that don't use keys + // (apiKeyPlaceholder is nil) or non-managed providers (fireworks, + // openrouter) are left alone since the user intentionally set up + // that provider. let providerRequiresKey = store.dynamicProviderApiKeyPlaceholder(draftProvider) != nil let hasLocalKey = APIKeyManager.getKey(for: draftProvider) != nil - if isLoggedIn && draftMode == "your-own" && providerRequiresKey && !hasLocalKey { + let providerIsManagedCapable = store.isManagedCapable(draftProvider) + if isLoggedIn && draftMode == "your-own" && providerIsManagedCapable && providerRequiresKey && !hasLocalKey { draftMode = "managed" store.setInferenceMode("managed") } @@ -190,12 +192,13 @@ struct InferenceServiceCard: View { // mode that onAppear may have temporarily overridden. draftMode = "managed" } else if isAuthenticated && store.inferenceMode == "your-own" { - // When a user signs in and has no BYO key for a key-based - // provider, default to managed. Keyless providers (e.g. Ollama) - // are left in your-own mode. + // When a user signs in and has no BYO key for a managed-capable, + // key-based provider, default to managed. Keyless providers + // (e.g. Ollama) and non-managed providers are left in your-own mode. let requiresKey = store.dynamicProviderApiKeyPlaceholder(draftProvider) != nil let hasLocalKey = APIKeyManager.getKey(for: draftProvider) != nil - if requiresKey && !hasLocalKey { + let isManagedCapable = store.isManagedCapable(draftProvider) + if isManagedCapable && requiresKey && !hasLocalKey { draftMode = "managed" store.setInferenceMode("managed") } @@ -256,11 +259,19 @@ struct InferenceServiceCard: View { } .onChange(of: draftMode) { _, newMode in if newMode == "managed" { - let anthropicModels = store.dynamicProviderModels("anthropic") - let isCurrentModelAnthropic = anthropicModels.contains { $0.id == draftModel } - if !isCurrentModelAnthropic { - let defaultModel = store.dynamicProviderDefaultModel("anthropic") - draftModel = defaultModel.isEmpty ? "claude-opus-4-7" : defaultModel + // When switching to managed mode, fall back to a managed-capable + // provider if the current one does not support managed routing. + if !store.isManagedCapable(draftProvider) { + draftProvider = "anthropic" + } + // Validate the model against the selected managed provider's catalog. + let managedModels = store.dynamicProviderModels(draftProvider) + let isCurrentModelValid = managedModels.contains { $0.id == draftModel } + if !isCurrentModelValid { + let defaultModel = store.dynamicProviderDefaultModel(draftProvider) + draftModel = defaultModel.isEmpty + ? (managedModels.first?.id ?? "") + : defaultModel } } else if newMode == "your-own" { let providerModels = store.dynamicProviderModels(draftProvider) @@ -327,6 +338,22 @@ struct InferenceServiceCard: View { } } + /// Provider picker filtered to managed-capable providers, shown in managed mode. + private var managedProviderPicker: some View { + VStack(alignment: .leading, spacing: VSpacing.sm) { + Text("Provider") + .font(VFont.labelDefault) + .foregroundStyle(VColor.contentSecondary) + VDropdown( + placeholder: "Select a provider\u{2026}", + selection: $draftProvider, + options: store.managedCapableProviders.map { entry in + (label: entry.displayName, value: entry.id) + } + ) + } + } + // MARK: - API Key Field private var apiKeyField: some View { @@ -356,11 +383,10 @@ struct InferenceServiceCard: View { /// Per-provider catalog model dropdown. private var providerModelPicker: some View { - let provider = draftMode == "managed" ? "anthropic" : draftProvider - return VDropdown( + VDropdown( placeholder: "Select a model\u{2026}", selection: $draftModel, - options: store.dynamicProviderModels(provider).map { model in + options: store.dynamicProviderModels(draftProvider).map { model in (label: model.displayName, value: model.id) } ) @@ -390,16 +416,10 @@ struct InferenceServiceCard: View { // changed — switching between managed and your-own implies a // provider change even if the resolved provider ID happens to // match initialProvider (ensures config stays consistent). - let persistProvider = draftMode == "managed" ? "anthropic" : draftProvider - let providerChanged = persistProvider != initialProvider || modeChanged - let pendingProvider = providerChanged ? store.setInferenceProvider(persistProvider) : nil + let providerChanged = draftProvider != initialProvider || modeChanged + let pendingProvider = providerChanged ? store.setInferenceProvider(draftProvider) : nil if providerChanged { - initialProvider = persistProvider - } - // Normalize draftProvider to match what was persisted so hasChanges - // (which compares draftProvider against initialProvider) stays in sync. - if draftProvider != persistProvider { - draftProvider = persistProvider + initialProvider = draftProvider } // Persist API key if entered and in your-own mode. @@ -420,12 +440,12 @@ struct InferenceServiceCard: View { // daemon's read-modify-write cycle for the model doesn't overwrite them. store.selectedModel = draftModel let capturedModel = draftModel - let saveProvider = draftMode == "managed" ? "anthropic" : draftProvider + let capturedProvider = draftProvider let forceSend = modeChanged Task { if let pendingMode { _ = await pendingMode.value } if let pendingProvider { _ = await pendingProvider.value } - store.setModel(capturedModel, provider: saveProvider, force: forceSend) + store.setModel(capturedModel, provider: capturedProvider, force: forceSend) } initialModel = draftModel } diff --git a/clients/macos/vellum-assistant/Features/Settings/SettingsStore.swift b/clients/macos/vellum-assistant/Features/Settings/SettingsStore.swift index 584fdaf8b20..579ef94d84b 100644 --- a/clients/macos/vellum-assistant/Features/Settings/SettingsStore.swift +++ b/clients/macos/vellum-assistant/Features/Settings/SettingsStore.swift @@ -1063,6 +1063,36 @@ public final class SettingsStore: ObservableObject { providerCatalog.first { $0.id == provider }?.apiKeyPlaceholder } + // MARK: - Provider Capability Helpers + + /// Provider IDs that support managed proxy routing (i.e., can be used in managed mode). + /// Mirrors the `MANAGED_PROVIDER_META` table in the backend. + private static let managedCapableProviderIds: Set = ["anthropic", "openai", "gemini"] + + /// Provider IDs that support native web search (inference-provider-native). + /// Anthropic and OpenAI pass `useNativeWebSearch` to their providers; others do not. + private static let nativeWebSearchCapableProviderIds: Set = ["anthropic", "openai"] + + /// Returns the catalog entries for providers that support managed proxy routing. + var managedCapableProviders: [ProviderCatalogEntry] { + providerCatalog.filter { Self.managedCapableProviderIds.contains($0.id) } + } + + /// Returns the catalog entries for providers that support native web search. + var nativeWebSearchCapableProviders: [ProviderCatalogEntry] { + providerCatalog.filter { Self.nativeWebSearchCapableProviderIds.contains($0.id) } + } + + /// Whether a given provider supports managed proxy routing. + func isManagedCapable(_ provider: String) -> Bool { + Self.managedCapableProviderIds.contains(provider) + } + + /// Whether a given provider supports native web search. + func isNativeWebSearchCapable(_ provider: String) -> Bool { + Self.nativeWebSearchCapableProviderIds.contains(provider) + } + // MARK: - Embedding Config Actions func refreshEmbeddingConfig() { diff --git a/clients/macos/vellum-assistantTests/SettingsStoreManagedInferenceSelectionTests.swift b/clients/macos/vellum-assistantTests/SettingsStoreManagedInferenceSelectionTests.swift new file mode 100644 index 00000000000..675d35acee8 --- /dev/null +++ b/clients/macos/vellum-assistantTests/SettingsStoreManagedInferenceSelectionTests.swift @@ -0,0 +1,179 @@ +import XCTest +@testable import VellumAssistantLib +@testable import VellumAssistantShared + +/// Tests for SettingsStore provider capability helpers and managed-mode +/// provider selection behavior. +@MainActor +final class SettingsStoreManagedInferenceSelectionTests: XCTestCase { + + private var store: SettingsStore! + + override func setUp() { + super.setUp() + store = SettingsStore(settingsClient: MockSettingsClient()) + } + + override func tearDown() { + store = nil + super.tearDown() + } + + // MARK: - isManagedCapable + + func testAnthropicIsManagedCapable() { + XCTAssertTrue(store.isManagedCapable("anthropic")) + } + + func testOpenAIIsManagedCapable() { + XCTAssertTrue(store.isManagedCapable("openai")) + } + + func testGeminiIsManagedCapable() { + XCTAssertTrue(store.isManagedCapable("gemini")) + } + + func testOllamaIsNotManagedCapable() { + XCTAssertFalse(store.isManagedCapable("ollama")) + } + + func testFireworksIsNotManagedCapable() { + XCTAssertFalse(store.isManagedCapable("fireworks")) + } + + func testOpenRouterIsNotManagedCapable() { + XCTAssertFalse(store.isManagedCapable("openrouter")) + } + + func testUnknownProviderIsNotManagedCapable() { + XCTAssertFalse(store.isManagedCapable("unknown-provider")) + } + + // MARK: - isNativeWebSearchCapable + + func testAnthropicIsNativeWebSearchCapable() { + XCTAssertTrue(store.isNativeWebSearchCapable("anthropic")) + } + + func testOpenAIIsNativeWebSearchCapable() { + XCTAssertTrue(store.isNativeWebSearchCapable("openai")) + } + + func testGeminiIsNotNativeWebSearchCapable() { + XCTAssertFalse(store.isNativeWebSearchCapable("gemini")) + } + + func testOllamaIsNotNativeWebSearchCapable() { + XCTAssertFalse(store.isNativeWebSearchCapable("ollama")) + } + + // MARK: - managedCapableProviders + + func testManagedCapableProvidersContainsExpectedEntries() { + let ids = store.managedCapableProviders.map(\.id) + XCTAssertTrue(ids.contains("anthropic"), "expected anthropic in managed-capable providers") + XCTAssertTrue(ids.contains("openai"), "expected openai in managed-capable providers") + XCTAssertTrue(ids.contains("gemini"), "expected gemini in managed-capable providers") + } + + func testManagedCapableProvidersExcludesNonManagedEntries() { + let ids = store.managedCapableProviders.map(\.id) + XCTAssertFalse(ids.contains("ollama"), "ollama should not be in managed-capable providers") + XCTAssertFalse(ids.contains("fireworks"), "fireworks should not be in managed-capable providers") + XCTAssertFalse(ids.contains("openrouter"), "openrouter should not be in managed-capable providers") + } + + // MARK: - nativeWebSearchCapableProviders + + func testNativeWebSearchCapableProvidersContainsExpectedEntries() { + let ids = store.nativeWebSearchCapableProviders.map(\.id) + XCTAssertTrue(ids.contains("anthropic"), "expected anthropic in native-web-search-capable providers") + XCTAssertTrue(ids.contains("openai"), "expected openai in native-web-search-capable providers") + } + + func testNativeWebSearchCapableProvidersExcludesOthers() { + let ids = store.nativeWebSearchCapableProviders.map(\.id) + XCTAssertFalse(ids.contains("gemini"), "gemini should not be in native-web-search-capable providers") + XCTAssertFalse(ids.contains("ollama"), "ollama should not be in native-web-search-capable providers") + } + + // MARK: - Managed Provider Persistence + + func testManagedModeCanPersistOpenAIAsProvider() { + let mockClient = MockSettingsClient() + mockClient.patchConfigResponse = true + let testStore = SettingsStore(settingsClient: mockClient) + + // Simulate selecting OpenAI in managed mode + testStore.selectedInferenceProvider = "openai" + testStore.inferenceMode = "managed" + + // Persist the provider selection + _ = testStore.setInferenceProvider("openai") + + // Wait for the async patch to be captured + let predicate = NSPredicate { _, _ in + mockClient.patchConfigCalls.count >= 1 + } + let expectation = XCTNSPredicateExpectation(predicate: predicate, object: nil) + wait(for: [expectation], timeout: 2.0) + + // Verify the patched provider is "openai", not "anthropic" + let providerPatches = mockClient.patchConfigCalls.compactMap { call -> String? in + guard let services = call["services"] as? [String: Any], + let inference = services["inference"] as? [String: Any], + let provider = inference["provider"] as? String else { + return nil + } + return provider + } + XCTAssertTrue(providerPatches.contains("openai"), + "expected openai to be persisted as the inference provider, got: \(providerPatches)") + } + + func testManagedModeCanPersistGeminiAsProvider() { + let mockClient = MockSettingsClient() + mockClient.patchConfigResponse = true + let testStore = SettingsStore(settingsClient: mockClient) + + testStore.selectedInferenceProvider = "gemini" + testStore.inferenceMode = "managed" + _ = testStore.setInferenceProvider("gemini") + + let predicate = NSPredicate { _, _ in + mockClient.patchConfigCalls.count >= 1 + } + let expectation = XCTNSPredicateExpectation(predicate: predicate, object: nil) + wait(for: [expectation], timeout: 2.0) + + let providerPatches = mockClient.patchConfigCalls.compactMap { call -> String? in + guard let services = call["services"] as? [String: Any], + let inference = services["inference"] as? [String: Any], + let provider = inference["provider"] as? String else { + return nil + } + return provider + } + XCTAssertTrue(providerPatches.contains("gemini"), + "expected gemini to be persisted as the inference provider, got: \(providerPatches)") + } + + // MARK: - Model Validation Against Selected Provider + + func testOpenAIModelsAreAvailableForOpenAIProvider() { + let models = store.dynamicProviderModels("openai") + XCTAssertFalse(models.isEmpty, "expected OpenAI to have models in the default catalog") + // Verify these are OpenAI models (not Anthropic) + let modelIds = models.map(\.id) + XCTAssertTrue(modelIds.allSatisfy { !$0.hasPrefix("claude-") }, + "OpenAI models should not contain claude model IDs") + } + + func testAnthropicModelsAreAvailableForAnthropicProvider() { + let models = store.dynamicProviderModels("anthropic") + XCTAssertFalse(models.isEmpty, "expected Anthropic to have models in the default catalog") + let modelIds = models.map(\.id) + XCTAssertTrue(modelIds.allSatisfy { $0.hasPrefix("claude-") }, + "Anthropic models should all be claude models") + } +}