Skip to content

Commit

Permalink
improve: adds better error handling (#5)
Browse files Browse the repository at this point in the history
* improve: adds better error handling

* simplify mock
  • Loading branch information
kevinhermawan authored Oct 27, 2024
1 parent 5b42916 commit f7bf13e
Show file tree
Hide file tree
Showing 8 changed files with 226 additions and 86 deletions.
1 change: 1 addition & 0 deletions Playground/Playground/ViewModels/AppViewModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import Foundation

@MainActor
@Observable
final class AppViewModel {
var cohereAPIKey: String
Expand Down
71 changes: 47 additions & 24 deletions Playground/Playground/Views/ModelListView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ import AIModelRetriever
struct ModelListView: View {
private let title: String
private let provider: AIProvider
private let retriever = AIModelRetriever()

@Environment(AppViewModel.self) private var viewModel

@State private var isFetching: Bool = false
@State private var fetchTask: Task<Void, Never>?
@State private var models: [AIModel] = []

init(title: String, provider: AIProvider) {
Expand All @@ -21,36 +25,55 @@ struct ModelListView: View {
}

var body: some View {
List(models) { model in
VStack(alignment: .leading) {
Text(model.id)
.font(.footnote)
.foregroundStyle(.secondary)

Text(model.name)
VStack {
if models.isEmpty, isFetching {
VStack(spacing: 16) {
ProgressView()

Button("Cancel") {
fetchTask?.cancel()
}
}
} else {
List(models) { model in
VStack(alignment: .leading) {
Text(model.id)
.font(.footnote)
.foregroundStyle(.secondary)

Text(model.name)
}
}
}
}
.navigationTitle(title)
.task {
let retriever = AIModelRetriever()
isFetching = true

do {
switch provider {
case .anthropic:
models = retriever.anthropic()
case .cohere:
models = try await retriever.cohere(apiKey: viewModel.cohereAPIKey)
case .google:
models = retriever.google()
case .ollama:
models = try await retriever.ollama()
case .openai:
models = try await retriever.openAI(apiKey: viewModel.openaiAPIKey)
case .groq:
models = try await retriever.openAI(apiKey: viewModel.groqAPIKey, endpoint: URL(string: "https://api.groq.com/openai/v1/models"))
fetchTask = Task {
do {
defer {
self.isFetching = false
self.fetchTask = nil
}

switch provider {
case .anthropic:
models = retriever.anthropic()
case .cohere:
models = try await retriever.cohere(apiKey: viewModel.cohereAPIKey)
case .google:
models = retriever.google()
case .ollama:
models = try await retriever.ollama()
case .openai:
models = try await retriever.openAI(apiKey: viewModel.openaiAPIKey)
case .groq:
models = try await retriever.openAI(apiKey: viewModel.groqAPIKey, endpoint: URL(string: "https://api.groq.com/openai/v1/models"))
}
} catch {
print(String(describing: error))
}
} catch {
print(String(describing: error))
}
}
}
Expand Down
29 changes: 27 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,39 @@ do {
}
```

## Donations
### Error Handling

`AIModelRetrieverError` provides structured error handling through the `AIModelRetrieverError` enum. This enum contains three cases that represent different types of errors you might encounter:

```swift
do {
let models = try await modelRetriever.openai(apiKey: "your-api-key")
} catch let error as LLMChatOpenAIError {
switch error {
case .serverError(let message):
// Handle server-side errors (e.g., invalid API key, rate limits)
print("Server Error: \(message)")
case .networkError(let error):
// Handle network-related errors (e.g., no internet connection)
print("Network Error: \(error.localizedDescription)")
case .badServerResponse:
// Handle invalid server responses
print("Invalid response received from server")
case .cancelled:
// Handle cancelled requests
print("Request cancelled")
}
}
```

## Support

If you find `AIModelRetriever` helpful and would like to support its development, consider making a donation. Your contribution helps maintain the project and develop new features.

- [GitHub Sponsors](https://github.com/sponsors/kevinhermawan)
- [Buy Me a Coffee](https://buymeacoffee.com/kevinhermawan)

Your support is greatly appreciated!
Your support is greatly appreciated! ❤️

## Contributing

Expand Down
76 changes: 62 additions & 14 deletions Sources/AIModelRetriever/AIModelRetriever.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,33 @@ public struct AIModelRetriever: Sendable {
/// Initializes a new instance of ``AIModelRetriever``.
public init() {}

private func performRequest<T: Decodable>(_ request: URLRequest) async throws -> T {
let (data, response) = try await URLSession.shared.data(for: request)

guard let httpResponse = response as? HTTPURLResponse else {
throw AIModelRetrieverError.badServerResponse
}

guard 200...299 ~= httpResponse.statusCode else {
throw AIModelRetrieverError.serverError(statusCode: httpResponse.statusCode, error: String(data: data, encoding: .utf8))
private func performRequest<T: Decodable, E: ProviderError>(_ request: URLRequest, errorType: E.Type) async throws -> T {
do {
let (data, response) = try await URLSession.shared.data(for: request)

if let errorResponse = try? JSONDecoder().decode(E.self, from: data) {
throw AIModelRetrieverError.serverError(errorResponse.errorMessage)
}

guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else {
throw AIModelRetrieverError.badServerResponse
}

let models = try JSONDecoder().decode(T.self, from: data)

return models
} catch let error as AIModelRetrieverError {
throw error
} catch let error as URLError {
switch error.code {
case .cancelled:
throw AIModelRetrieverError.cancelled
default:
throw AIModelRetrieverError.networkError(error)
}
} catch {
throw AIModelRetrieverError.networkError(error)
}

return try JSONDecoder().decode(T.self, from: data)
}

private func createRequest(for endpoint: URL, with headers: [String: String]? = nil) -> URLRequest {
Expand Down Expand Up @@ -74,7 +89,7 @@ public extension AIModelRetriever {
let allHeaders = ["Authorization": "Bearer \(apiKey)"]

let request = createRequest(for: defaultEndpoint, with: allHeaders)
let response: CohereResponse = try await performRequest(request)
let response: CohereResponse = try await performRequest(request, errorType: CohereError.self)

return response.models.map { AIModel(id: $0.name, name: $0.name) }
}
Expand All @@ -86,6 +101,12 @@ public extension AIModelRetriever {
private struct CohereModel: Decodable {
let name: String
}

private struct CohereError: ProviderError {
let message: String

var errorMessage: String { message }
}
}

// MARK: - Google
Expand Down Expand Up @@ -122,7 +143,7 @@ public extension AIModelRetriever {
guard let defaultEndpoint = URL(string: "http://localhost:11434/api/tags") else { return [] }

let request = createRequest(for: endpoint ?? defaultEndpoint, with: headers)
let response: OllamaResponse = try await performRequest(request)
let response: OllamaResponse = try await performRequest(request, errorType: OllamaError.self)

return response.models.map { AIModel(id: $0.model, name: $0.name) }
}
Expand All @@ -135,6 +156,16 @@ public extension AIModelRetriever {
let name: String
let model: String
}

private struct OllamaError: ProviderError {
let error: Error

struct Error: Decodable {
let message: String
}

var errorMessage: String { error.message }
}
}

// MARK: - OpenAI
Expand All @@ -156,7 +187,7 @@ public extension AIModelRetriever {
allHeaders["Authorization"] = "Bearer \(apiKey)"

let request = createRequest(for: endpoint ?? defaultEndpoint, with: allHeaders)
let response: OpenAIResponse = try await performRequest(request)
let response: OpenAIResponse = try await performRequest(request, errorType: OpenAIError.self)

return response.data.map { AIModel(id: $0.id, name: $0.id) }
}
Expand All @@ -168,4 +199,21 @@ public extension AIModelRetriever {
private struct OpenAIModel: Decodable {
let id: String
}

private struct OpenAIError: ProviderError {
let error: Error

struct Error: Decodable {
let message: String
}

var errorMessage: String { error.message }
}
}

// MARK: - Supporting Types
private extension AIModelRetriever {
protocol ProviderError: Decodable {
var errorMessage: String { get }
}
}
34 changes: 27 additions & 7 deletions Sources/AIModelRetriever/AIModelRetrieverError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,33 @@ import Foundation

/// An enum that represents errors that can occur during AI model retrieval.
public enum AIModelRetrieverError: Error, Sendable {
/// Indicates that the server response was not in the expected format.
case badServerResponse
/// A case that represents a server-side error response.
///
/// - Parameter message: The error message from the server.
case serverError(String)

/// Indicates that the server returned an error.
/// A case that represents a network-related error.
///
/// - Parameters:
/// - statusCode: The HTTP status code returned by the server.
/// - error: An optional string that contains additional error information provided by the server.
case serverError(statusCode: Int, error: String?)
/// - Parameter error: The underlying network error.
case networkError(Error)

/// A case that represents an invalid server response.
case badServerResponse

/// A case that represents a request has been canceled.
case cancelled

/// A localized message that describes the error.
public var errorDescription: String? {
switch self {
case .serverError(let error):
return error
case .networkError(let error):
return error.localizedDescription
case .badServerResponse:
return "Invalid response received from server"
case .cancelled:
return "Request was cancelled"
}
}
}
25 changes: 14 additions & 11 deletions Sources/AIModelRetriever/Documentation.docc/Documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,25 @@ do {

### Error Handling

The package uses ``AIModelRetrieverError`` to represent specific errors that may occur. You can catch and handle these errors as follows:
``AIModelRetrieverError`` provides structured error handling through the ``AIModelRetrieverError`` enum. This enum contains three cases that represent different types of errors you might encounter:

```swift
let apiKey = "your-openai-api-key"

do {
let models = try await modelRetriever.openai(apiKey: apiKey)
// Process models
} catch let error as AIModelRetrieverError {
let models = try await modelRetriever.openai(apiKey: "your-api-key")
} catch let error as LLMChatOpenAIError {
switch error {
case .serverError(let message):
// Handle server-side errors (e.g., invalid API key, rate limits)
print("Server Error: \(message)")
case .networkError(let error):
// Handle network-related errors (e.g., no internet connection)
print("Network Error: \(error.localizedDescription)")
case .badServerResponse:
print("Received an invalid response from the server")
case .serverError(let statusCode, let errorMessage):
print("Server error (status \(statusCode)): \(errorMessage ?? "No error message provided")")
// Handle invalid server responses
print("Invalid response received from server")
case .cancelled:
// Handle cancelled requests
print("Request cancelled")
}
} catch {
print("An unexpected error occurred: \(error)")
}
```
Loading

0 comments on commit f7bf13e

Please sign in to comment.