Skip to content

Commit

Permalink
feat: adds new provider (Cohere) (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinhermawan authored Oct 23, 2024
1 parent dc1a00d commit a5237c7
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 7 deletions.
9 changes: 6 additions & 3 deletions Playground/Playground/ViewModels/AppViewModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@ import Foundation

@Observable
final class AppViewModel {
var openaiAPIKey: String
var cohereAPIKey: String
var groqAPIKey: String
var openaiAPIKey: String

init() {
self.openaiAPIKey = UserDefaults.standard.string(forKey: "openaiAPIKey") ?? ""
self.cohereAPIKey = UserDefaults.standard.string(forKey: "cohereAPIKey") ?? ""
self.groqAPIKey = UserDefaults.standard.string(forKey: "groqAPIKey") ?? ""
self.openaiAPIKey = UserDefaults.standard.string(forKey: "openaiAPIKey") ?? ""
}

func saveAPIKeys() {
UserDefaults.standard.set(openaiAPIKey, forKey: "openaiAPIKey")
UserDefaults.standard.set(cohereAPIKey, forKey: "cohereAPIKey")
UserDefaults.standard.set(groqAPIKey, forKey: "groqAPIKey")
UserDefaults.standard.set(openaiAPIKey, forKey: "openaiAPIKey")
}
}
6 changes: 4 additions & 2 deletions Playground/Playground/Views/AppView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import SwiftUI

enum AIProvider: String, CaseIterable {
case anthropic = "Anthropic"
case cohere = "Cohere"
case google = "Google"
case ollama = "Ollama"
case openai = "OpenAI"
case groq = "Groq (OpenAI-compatible)"
case groq = "OpenAI-Compatible (Groq)"
}

struct AppView: View {
Expand All @@ -25,8 +26,9 @@ struct AppView: View {
NavigationLink(provider.rawValue) {
ModelListView(title: provider.rawValue, provider: provider)
}
.disabled(provider == .openai && viewModel.openaiAPIKey.isEmpty)
.disabled(provider == .cohere && viewModel.cohereAPIKey.isEmpty)
.disabled(provider == .groq && viewModel.groqAPIKey.isEmpty)
.disabled(provider == .openai && viewModel.openaiAPIKey.isEmpty)
}
.navigationTitle("Playground")
.navigationBarTitleDisplayMode(.inline)
Expand Down
2 changes: 2 additions & 0 deletions Playground/Playground/Views/ModelListView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ struct ModelListView: View {
switch provider {
case .anthropic:
models = retriever.anthropic()
case .cohere:
models = try await retriever.cohere(apiKey: viewModel.cohereAPIKey)
case .google:
models = retriever.google()
case .ollama:
Expand Down
8 changes: 6 additions & 2 deletions Playground/Playground/Views/SettingsView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ struct SettingsView: View {

NavigationStack {
Form {
Section("OpenAI API Key") {
Section("Cohere") {
TextField("API Key", text: $viewModelBindable.cohereAPIKey)
}

Section("OpenAI") {
TextField("API Key", text: $viewModelBindable.openaiAPIKey)
}

Section("Groq API Key (OpenAI-compatible)") {
Section("OpenAI-Compatible (Groq)") {
TextField("API Key", text: $viewModelBindable.groqAPIKey)
}
}
Expand Down
31 changes: 31 additions & 0 deletions Sources/AIModelRetriever/AIModelRetriever.swift
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,37 @@ public extension AIModelRetriever {
}
}

// MARK: - Cohere
public extension AIModelRetriever {
/// Retrieves a list of AI models from Cohere.
///
/// - Parameters:
/// - apiKey: The API key for authenticating with the API.
///
/// - Returns: An array of ``AIModel`` that represents Cohere's available models.
///
/// - Throws: An error if the network request fails or if the response cannot be decoded.
func cohere(apiKey: String) async throws -> [AIModel] {
guard let defaultEndpoint = URL(string: "https://api.cohere.com/v1/models?page_size=1000") else { return [] }

let allHeaders = ["Authorization": "Bearer \(apiKey)"]

let request = createRequest(for: defaultEndpoint, with: allHeaders)
let response: CohereResponse = try await performRequest(request)

return response.models.map { AIModel(id: $0.name, name: $0.name) }
}

private struct CohereResponse: Decodable {
let models: [CohereModel]
}

private struct CohereModel: Decodable {
let name: String
}
}

// MARK: - Google
public extension AIModelRetriever {
/// Retrieves a list of AI models from Google.
///
Expand Down
19 changes: 19 additions & 0 deletions Tests/AIModelRetrieverTests/AIModelRetrieverTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,25 @@ final class AIModelRetrieverTests: XCTestCase {
XCTAssertTrue(models.contains { $0.name == "Claude 3.5 Sonnet (Latest)" })
}

func testCohere() async throws {
let mockResponseString = """
{
"models": [
{"name": "test-model-1"},
{"name": "test-model-2"}
]
}
"""

URLProtocolMock.mockData = mockResponseString.data(using: .utf8)

let models = try await retriever.cohere(apiKey: "test-key")

XCTAssertEqual(models.count, 2)
XCTAssertEqual(models[0].id, "test-model-1")
XCTAssertEqual(models[1].name, "test-model-2")
}

func testGoogle() {
let models = retriever.google()

Expand Down

0 comments on commit a5237c7

Please sign in to comment.