diff --git a/Playground/Playground/ViewModels/AppViewModel.swift b/Playground/Playground/ViewModels/AppViewModel.swift index 67a9b3b..f73ca7d 100644 --- a/Playground/Playground/ViewModels/AppViewModel.swift +++ b/Playground/Playground/ViewModels/AppViewModel.swift @@ -9,16 +9,19 @@ import Foundation @Observable final class AppViewModel { - var openaiAPIKey: String + var cohereAPIKey: String var groqAPIKey: String + var openaiAPIKey: String init() { - self.openaiAPIKey = UserDefaults.standard.string(forKey: "openaiAPIKey") ?? "" + self.cohereAPIKey = UserDefaults.standard.string(forKey: "cohereAPIKey") ?? "" self.groqAPIKey = UserDefaults.standard.string(forKey: "groqAPIKey") ?? "" + self.openaiAPIKey = UserDefaults.standard.string(forKey: "openaiAPIKey") ?? "" } func saveAPIKeys() { - UserDefaults.standard.set(openaiAPIKey, forKey: "openaiAPIKey") + UserDefaults.standard.set(cohereAPIKey, forKey: "cohereAPIKey") UserDefaults.standard.set(groqAPIKey, forKey: "groqAPIKey") + UserDefaults.standard.set(openaiAPIKey, forKey: "openaiAPIKey") } } diff --git a/Playground/Playground/Views/AppView.swift b/Playground/Playground/Views/AppView.swift index 0ee1b80..5c4ff1b 100644 --- a/Playground/Playground/Views/AppView.swift +++ b/Playground/Playground/Views/AppView.swift @@ -9,10 +9,11 @@ import SwiftUI enum AIProvider: String, CaseIterable { case anthropic = "Anthropic" + case cohere = "Cohere" case google = "Google" case ollama = "Ollama" case openai = "OpenAI" - case groq = "Groq (OpenAI-compatible)" + case groq = "OpenAI-Compatible (Groq)" } struct AppView: View { @@ -25,8 +26,9 @@ struct AppView: View { NavigationLink(provider.rawValue) { ModelListView(title: provider.rawValue, provider: provider) } - .disabled(provider == .openai && viewModel.openaiAPIKey.isEmpty) + .disabled(provider == .cohere && viewModel.cohereAPIKey.isEmpty) .disabled(provider == .groq && viewModel.groqAPIKey.isEmpty) + .disabled(provider == .openai && viewModel.openaiAPIKey.isEmpty) } .navigationTitle("Playground") .navigationBarTitleDisplayMode(.inline) diff --git a/Playground/Playground/Views/ModelListView.swift b/Playground/Playground/Views/ModelListView.swift index 9d03442..b4bb1a1 100644 --- a/Playground/Playground/Views/ModelListView.swift +++ b/Playground/Playground/Views/ModelListView.swift @@ -38,6 +38,8 @@ struct ModelListView: View { switch provider { case .anthropic: models = retriever.anthropic() + case .cohere: + models = try await retriever.cohere(apiKey: viewModel.cohereAPIKey) case .google: models = retriever.google() case .ollama: diff --git a/Playground/Playground/Views/SettingsView.swift b/Playground/Playground/Views/SettingsView.swift index 32dc4a6..5553d18 100644 --- a/Playground/Playground/Views/SettingsView.swift +++ b/Playground/Playground/Views/SettingsView.swift @@ -16,11 +16,15 @@ struct SettingsView: View { NavigationStack { Form { - Section("OpenAI API Key") { + Section("Cohere") { + TextField("API Key", text: $viewModelBindable.cohereAPIKey) + } + + Section("OpenAI") { TextField("API Key", text: $viewModelBindable.openaiAPIKey) } - Section("Groq API Key (OpenAI-compatible)") { + Section("OpenAI-Compatible (Groq)") { TextField("API Key", text: $viewModelBindable.groqAPIKey) } } diff --git a/Sources/AIModelRetriever/AIModelRetriever.swift b/Sources/AIModelRetriever/AIModelRetriever.swift index c5d0f53..dd5b798 100644 --- a/Sources/AIModelRetriever/AIModelRetriever.swift +++ b/Sources/AIModelRetriever/AIModelRetriever.swift @@ -58,6 +58,37 @@ public extension AIModelRetriever { } } +// MARK: - Cohere +public extension AIModelRetriever { + /// Retrieves a list of AI models from Cohere. + /// + /// - Parameters: + /// - apiKey: The API key for authenticating with the API. + /// + /// - Returns: An array of ``AIModel`` that represents Cohere's available models. + /// + /// - Throws: An error if the network request fails or if the response cannot be decoded. + func cohere(apiKey: String) async throws -> [AIModel] { + guard let defaultEndpoint = URL(string: "https://api.cohere.com/v1/models?page_size=1000") else { return [] } + + let allHeaders = ["Authorization": "Bearer \(apiKey)"] + + let request = createRequest(for: defaultEndpoint, with: allHeaders) + let response: CohereResponse = try await performRequest(request) + + return response.models.map { AIModel(id: $0.name, name: $0.name) } + } + + private struct CohereResponse: Decodable { + let models: [CohereModel] + } + + private struct CohereModel: Decodable { + let name: String + } +} + +// MARK: - Google public extension AIModelRetriever { /// Retrieves a list of AI models from Google. /// diff --git a/Tests/AIModelRetrieverTests/AIModelRetrieverTests.swift b/Tests/AIModelRetrieverTests/AIModelRetrieverTests.swift index 23e665f..4b7e358 100644 --- a/Tests/AIModelRetrieverTests/AIModelRetrieverTests.swift +++ b/Tests/AIModelRetrieverTests/AIModelRetrieverTests.swift @@ -36,6 +36,25 @@ final class AIModelRetrieverTests: XCTestCase { XCTAssertTrue(models.contains { $0.name == "Claude 3.5 Sonnet (Latest)" }) } + func testCohere() async throws { + let mockResponseString = """ + { + "models": [ + {"name": "test-model-1"}, + {"name": "test-model-2"} + ] + } + """ + + URLProtocolMock.mockData = mockResponseString.data(using: .utf8) + + let models = try await retriever.cohere(apiKey: "test-key") + + XCTAssertEqual(models.count, 2) + XCTAssertEqual(models[0].id, "test-model-1") + XCTAssertEqual(models[1].name, "test-model-2") + } + func testGoogle() { let models = retriever.google()