From f7bf13e4acaad1a5d08e14f14f7b403d0b257fd8 Mon Sep 17 00:00:00 2001 From: Kevin Hermawan <84965338+kevinhermawan@users.noreply.github.com> Date: Sun, 27 Oct 2024 18:14:56 +0700 Subject: [PATCH] improve: adds better error handling (#5) * improve: adds better error handling * simplify mock --- .../Playground/ViewModels/AppViewModel.swift | 1 + .../Playground/Views/ModelListView.swift | 71 +++++++++++------ README.md | 29 ++++++- .../AIModelRetriever/AIModelRetriever.swift | 76 +++++++++++++++---- .../AIModelRetrieverError.swift | 34 +++++++-- .../Documentation.docc/Documentation.md | 25 +++--- .../AIModelRetrieverTests.swift | 59 +++++++++----- .../URLProtocolMock.swift | 17 +++-- 8 files changed, 226 insertions(+), 86 deletions(-) diff --git a/Playground/Playground/ViewModels/AppViewModel.swift b/Playground/Playground/ViewModels/AppViewModel.swift index f73ca7d..79868bb 100644 --- a/Playground/Playground/ViewModels/AppViewModel.swift +++ b/Playground/Playground/ViewModels/AppViewModel.swift @@ -7,6 +7,7 @@ import Foundation +@MainActor @Observable final class AppViewModel { var cohereAPIKey: String diff --git a/Playground/Playground/Views/ModelListView.swift b/Playground/Playground/Views/ModelListView.swift index b4bb1a1..e9f2bfa 100644 --- a/Playground/Playground/Views/ModelListView.swift +++ b/Playground/Playground/Views/ModelListView.swift @@ -11,8 +11,12 @@ import AIModelRetriever struct ModelListView: View { private let title: String private let provider: AIProvider + private let retriever = AIModelRetriever() @Environment(AppViewModel.self) private var viewModel + + @State private var isFetching: Bool = false + @State private var fetchTask: Task? @State private var models: [AIModel] = [] init(title: String, provider: AIProvider) { @@ -21,36 +25,55 @@ struct ModelListView: View { } var body: some View { - List(models) { model in - VStack(alignment: .leading) { - Text(model.id) - .font(.footnote) - .foregroundStyle(.secondary) - - Text(model.name) + VStack { + if models.isEmpty, isFetching { + VStack(spacing: 16) { + ProgressView() + + Button("Cancel") { + fetchTask?.cancel() + } + } + } else { + List(models) { model in + VStack(alignment: .leading) { + Text(model.id) + .font(.footnote) + .foregroundStyle(.secondary) + + Text(model.name) + } + } } } .navigationTitle(title) .task { - let retriever = AIModelRetriever() + isFetching = true - do { - switch provider { - case .anthropic: - models = retriever.anthropic() - case .cohere: - models = try await retriever.cohere(apiKey: viewModel.cohereAPIKey) - case .google: - models = retriever.google() - case .ollama: - models = try await retriever.ollama() - case .openai: - models = try await retriever.openAI(apiKey: viewModel.openaiAPIKey) - case .groq: - models = try await retriever.openAI(apiKey: viewModel.groqAPIKey, endpoint: URL(string: "https://api.groq.com/openai/v1/models")) + fetchTask = Task { + do { + defer { + self.isFetching = false + self.fetchTask = nil + } + + switch provider { + case .anthropic: + models = retriever.anthropic() + case .cohere: + models = try await retriever.cohere(apiKey: viewModel.cohereAPIKey) + case .google: + models = retriever.google() + case .ollama: + models = try await retriever.ollama() + case .openai: + models = try await retriever.openAI(apiKey: viewModel.openaiAPIKey) + case .groq: + models = try await retriever.openAI(apiKey: viewModel.groqAPIKey, endpoint: URL(string: "https://api.groq.com/openai/v1/models")) + } + } catch { + print(String(describing: error)) } - } catch { - print(String(describing: error)) } } } diff --git a/README.md b/README.md index 2dd0536..e8c26c1 100644 --- a/README.md +++ b/README.md @@ -152,14 +152,39 @@ do { } ``` -## Donations +### Error Handling + +`AIModelRetrieverError` provides structured error handling through the `AIModelRetrieverError` enum. This enum contains three cases that represent different types of errors you might encounter: + +```swift +do { + let models = try await modelRetriever.openai(apiKey: "your-api-key") +} catch let error as LLMChatOpenAIError { + switch error { + case .serverError(let message): + // Handle server-side errors (e.g., invalid API key, rate limits) + print("Server Error: \(message)") + case .networkError(let error): + // Handle network-related errors (e.g., no internet connection) + print("Network Error: \(error.localizedDescription)") + case .badServerResponse: + // Handle invalid server responses + print("Invalid response received from server") + case .cancelled: + // Handle cancelled requests + print("Request cancelled") + } +} +``` + +## Support If you find `AIModelRetriever` helpful and would like to support its development, consider making a donation. Your contribution helps maintain the project and develop new features. - [GitHub Sponsors](https://github.com/sponsors/kevinhermawan) - [Buy Me a Coffee](https://buymeacoffee.com/kevinhermawan) -Your support is greatly appreciated! +Your support is greatly appreciated! ❤️ ## Contributing diff --git a/Sources/AIModelRetriever/AIModelRetriever.swift b/Sources/AIModelRetriever/AIModelRetriever.swift index adf7131..ab339ab 100644 --- a/Sources/AIModelRetriever/AIModelRetriever.swift +++ b/Sources/AIModelRetriever/AIModelRetriever.swift @@ -15,18 +15,33 @@ public struct AIModelRetriever: Sendable { /// Initializes a new instance of ``AIModelRetriever``. public init() {} - private func performRequest(_ request: URLRequest) async throws -> T { - let (data, response) = try await URLSession.shared.data(for: request) - - guard let httpResponse = response as? HTTPURLResponse else { - throw AIModelRetrieverError.badServerResponse - } - - guard 200...299 ~= httpResponse.statusCode else { - throw AIModelRetrieverError.serverError(statusCode: httpResponse.statusCode, error: String(data: data, encoding: .utf8)) + private func performRequest(_ request: URLRequest, errorType: E.Type) async throws -> T { + do { + let (data, response) = try await URLSession.shared.data(for: request) + + if let errorResponse = try? JSONDecoder().decode(E.self, from: data) { + throw AIModelRetrieverError.serverError(errorResponse.errorMessage) + } + + guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else { + throw AIModelRetrieverError.badServerResponse + } + + let models = try JSONDecoder().decode(T.self, from: data) + + return models + } catch let error as AIModelRetrieverError { + throw error + } catch let error as URLError { + switch error.code { + case .cancelled: + throw AIModelRetrieverError.cancelled + default: + throw AIModelRetrieverError.networkError(error) + } + } catch { + throw AIModelRetrieverError.networkError(error) } - - return try JSONDecoder().decode(T.self, from: data) } private func createRequest(for endpoint: URL, with headers: [String: String]? = nil) -> URLRequest { @@ -74,7 +89,7 @@ public extension AIModelRetriever { let allHeaders = ["Authorization": "Bearer \(apiKey)"] let request = createRequest(for: defaultEndpoint, with: allHeaders) - let response: CohereResponse = try await performRequest(request) + let response: CohereResponse = try await performRequest(request, errorType: CohereError.self) return response.models.map { AIModel(id: $0.name, name: $0.name) } } @@ -86,6 +101,12 @@ public extension AIModelRetriever { private struct CohereModel: Decodable { let name: String } + + private struct CohereError: ProviderError { + let message: String + + var errorMessage: String { message } + } } // MARK: - Google @@ -122,7 +143,7 @@ public extension AIModelRetriever { guard let defaultEndpoint = URL(string: "http://localhost:11434/api/tags") else { return [] } let request = createRequest(for: endpoint ?? defaultEndpoint, with: headers) - let response: OllamaResponse = try await performRequest(request) + let response: OllamaResponse = try await performRequest(request, errorType: OllamaError.self) return response.models.map { AIModel(id: $0.model, name: $0.name) } } @@ -135,6 +156,16 @@ public extension AIModelRetriever { let name: String let model: String } + + private struct OllamaError: ProviderError { + let error: Error + + struct Error: Decodable { + let message: String + } + + var errorMessage: String { error.message } + } } // MARK: - OpenAI @@ -156,7 +187,7 @@ public extension AIModelRetriever { allHeaders["Authorization"] = "Bearer \(apiKey)" let request = createRequest(for: endpoint ?? defaultEndpoint, with: allHeaders) - let response: OpenAIResponse = try await performRequest(request) + let response: OpenAIResponse = try await performRequest(request, errorType: OpenAIError.self) return response.data.map { AIModel(id: $0.id, name: $0.id) } } @@ -168,4 +199,21 @@ public extension AIModelRetriever { private struct OpenAIModel: Decodable { let id: String } + + private struct OpenAIError: ProviderError { + let error: Error + + struct Error: Decodable { + let message: String + } + + var errorMessage: String { error.message } + } +} + +// MARK: - Supporting Types +private extension AIModelRetriever { + protocol ProviderError: Decodable { + var errorMessage: String { get } + } } diff --git a/Sources/AIModelRetriever/AIModelRetrieverError.swift b/Sources/AIModelRetriever/AIModelRetrieverError.swift index 51664f1..823ff5d 100644 --- a/Sources/AIModelRetriever/AIModelRetrieverError.swift +++ b/Sources/AIModelRetriever/AIModelRetrieverError.swift @@ -9,13 +9,33 @@ import Foundation /// An enum that represents errors that can occur during AI model retrieval. public enum AIModelRetrieverError: Error, Sendable { - /// Indicates that the server response was not in the expected format. - case badServerResponse + /// A case that represents a server-side error response. + /// + /// - Parameter message: The error message from the server. + case serverError(String) - /// Indicates that the server returned an error. + /// A case that represents a network-related error. /// - /// - Parameters: - /// - statusCode: The HTTP status code returned by the server. - /// - error: An optional string that contains additional error information provided by the server. - case serverError(statusCode: Int, error: String?) + /// - Parameter error: The underlying network error. + case networkError(Error) + + /// A case that represents an invalid server response. + case badServerResponse + + /// A case that represents a request has been canceled. + case cancelled + + /// A localized message that describes the error. + public var errorDescription: String? { + switch self { + case .serverError(let error): + return error + case .networkError(let error): + return error.localizedDescription + case .badServerResponse: + return "Invalid response received from server" + case .cancelled: + return "Request was cancelled" + } + } } diff --git a/Sources/AIModelRetriever/Documentation.docc/Documentation.md b/Sources/AIModelRetriever/Documentation.docc/Documentation.md index f92e440..67cf38a 100644 --- a/Sources/AIModelRetriever/Documentation.docc/Documentation.md +++ b/Sources/AIModelRetriever/Documentation.docc/Documentation.md @@ -103,22 +103,25 @@ do { ### Error Handling -The package uses ``AIModelRetrieverError`` to represent specific errors that may occur. You can catch and handle these errors as follows: +``AIModelRetrieverError`` provides structured error handling through the ``AIModelRetrieverError`` enum. This enum contains three cases that represent different types of errors you might encounter: ```swift -let apiKey = "your-openai-api-key" - do { - let models = try await modelRetriever.openai(apiKey: apiKey) - // Process models -} catch let error as AIModelRetrieverError { + let models = try await modelRetriever.openai(apiKey: "your-api-key") +} catch let error as LLMChatOpenAIError { switch error { + case .serverError(let message): + // Handle server-side errors (e.g., invalid API key, rate limits) + print("Server Error: \(message)") + case .networkError(let error): + // Handle network-related errors (e.g., no internet connection) + print("Network Error: \(error.localizedDescription)") case .badServerResponse: - print("Received an invalid response from the server") - case .serverError(let statusCode, let errorMessage): - print("Server error (status \(statusCode)): \(errorMessage ?? "No error message provided")") + // Handle invalid server responses + print("Invalid response received from server") + case .cancelled: + // Handle cancelled requests + print("Request cancelled") } -} catch { - print("An unexpected error occurred: \(error)") } ``` diff --git a/Tests/AIModelRetrieverTests/AIModelRetrieverTests.swift b/Tests/AIModelRetrieverTests/AIModelRetrieverTests.swift index 4b7e358..0c77c1c 100644 --- a/Tests/AIModelRetrieverTests/AIModelRetrieverTests.swift +++ b/Tests/AIModelRetrieverTests/AIModelRetrieverTests.swift @@ -14,8 +14,8 @@ final class AIModelRetrieverTests: XCTestCase { override func setUp() { super.setUp() - retriever = AIModelRetriever() URLProtocol.registerClass(URLProtocolMock.self) + retriever = AIModelRetriever() } override func tearDown() { @@ -23,7 +23,6 @@ final class AIModelRetrieverTests: XCTestCase { URLProtocol.unregisterClass(URLProtocolMock.self) URLProtocolMock.mockData = nil URLProtocolMock.mockError = nil - URLProtocolMock.mockStatusCode = 200 super.tearDown() } @@ -120,35 +119,53 @@ final class AIModelRetrieverTests: XCTestCase { XCTAssertEqual(models[0].id, "custom-model-1") XCTAssertEqual(models[1].id, "custom-model-2") } - - func testServerError() async { - let errorResponse = """ +} + +// MARK: - Error Handling +extension AIModelRetrieverTests { + func testServerError() async throws { + let mockErrorResponse = """ { - "error": "Invalid API key" + "error": { + "message": "Invalid API key provided" + } } """ - URLProtocolMock.mockData = errorResponse.data(using: .utf8) - URLProtocolMock.mockStatusCode = 401 + URLProtocolMock.mockData = mockErrorResponse.data(using: .utf8) do { - _ = try await retriever.openAI(apiKey: "invalid-key") - XCTFail("Expected an error to be thrown") + let _ = try await retriever.openAI(apiKey: "test-key") + + XCTFail("Expected serverError to be thrown") + } catch let error as AIModelRetrieverError { + switch error { + case .serverError(let message): + XCTAssertEqual(message, "Invalid API key provided") + default: + XCTFail("Expected serverError but got \(error)") + } + } + } + + func testNetworkError() async throws { + URLProtocolMock.mockError = NSError( + domain: NSURLErrorDomain, + code: NSURLErrorNotConnectedToInternet, + userInfo: [NSLocalizedDescriptionKey: "The Internet connection appears to be offline."] + ) + + do { + let _ = try await retriever.openAI(apiKey: "test-key") + + XCTFail("Expected networkError to be thrown") } catch let error as AIModelRetrieverError { switch error { - case .serverError(let statusCode, let errorMessage): - XCTAssertEqual(statusCode, 401) - - if let errorData = errorMessage?.data(using: .utf8), - let jsonObject = try? JSONSerialization.jsonObject(with: errorData, options: []) as? [String: Any], - let errorString = jsonObject["error"] as? String { - XCTAssertEqual(errorString, "Invalid API key") - } + case .networkError(let underlyingError): + XCTAssertEqual((underlyingError as NSError).code, NSURLErrorNotConnectedToInternet) default: - XCTFail("Unexpected error type") + XCTFail("Expected networkError but got \(error)") } - } catch { - XCTFail("Unexpected error type") } } } diff --git a/Tests/AIModelRetrieverTests/URLProtocolMock.swift b/Tests/AIModelRetrieverTests/URLProtocolMock.swift index 54ea778..a37b18b 100644 --- a/Tests/AIModelRetrieverTests/URLProtocolMock.swift +++ b/Tests/AIModelRetrieverTests/URLProtocolMock.swift @@ -10,7 +10,6 @@ import Foundation final class URLProtocolMock: URLProtocol { static var mockData: Data? static var mockError: Error? - static var mockStatusCode: Int = 200 override class func canInit(with request: URLRequest) -> Bool { return true @@ -23,16 +22,20 @@ final class URLProtocolMock: URLProtocol { override func startLoading() { if let error = URLProtocolMock.mockError { client?.urlProtocol(self, didFailWithError: error) - } else { - if let data = URLProtocolMock.mockData { - client?.urlProtocol(self, didLoad: data) - } - - let response = HTTPURLResponse(url: request.url!, statusCode: URLProtocolMock.mockStatusCode, httpVersion: nil, headerFields: nil)! + client?.urlProtocolDidFinishLoading(self) + return + } + + if let data = URLProtocolMock.mockData, let url = request.url, let response = HTTPURLResponse(url: url, statusCode: 200, httpVersion: nil, headerFields: nil) { client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) + client?.urlProtocol(self, didLoad: data) + client?.urlProtocolDidFinishLoading(self) + + return } + client?.urlProtocol(self, didFailWithError: NSError(domain: "No mock data", code: -1, userInfo: nil)) client?.urlProtocolDidFinishLoading(self) }