From 972718d4bed3264ec493f4c992a6b8257e4933a4 Mon Sep 17 00:00:00 2001 From: Chris Dillard Date: Sat, 16 Dec 2023 10:23:56 -0700 Subject: [PATCH 1/6] Feat/assistants openai api (#2) * feat: Assistants API * Rem accidentally committed Dev team * Demoapp syntax fix for ImageCreationView * assistant paging, modify, fix * demo: enhancement: Handle local message replacement, README update * clean, runRetrieveSteps implemented, SupportedFileTypes implemented * Handle run retrieve steps * Assistant README, add run retrieve steps * display run retrieve steps in updating fashion for code_interpreter --- Demo/App/APIProvidedView.swift | 10 + Demo/App/ContentView.swift | 29 ++- Demo/Demo.xcodeproj/project.pbxproj | 6 +- Demo/DemoChat/Sources/AssistantStore.swift | 125 +++++++++ Demo/DemoChat/Sources/ChatStore.swift | 245 ++++++++++++++++-- Demo/DemoChat/Sources/Models/Assistant.swift | 33 +++ .../Sources/Models/Conversation.swift | 11 +- Demo/DemoChat/Sources/Models/Message.swift | 3 + Demo/DemoChat/Sources/SupportedFileType.swift | 92 +++++++ .../UI/AssistantModalContentView.swift | 112 ++++++++ .../Sources/UI/AssistantsListView.swift | 42 +++ Demo/DemoChat/Sources/UI/AssistantsView.swift | 196 ++++++++++++++ Demo/DemoChat/Sources/UI/ChatView.swift | 92 ++++--- Demo/DemoChat/Sources/UI/DetailView.swift | 74 ++++-- Demo/DemoChat/Sources/UI/DocumentPicker.swift | 41 +++ Demo/DemoChat/Sources/UI/ListView.swift | 26 +- .../Sources/UI/ModerationChatView.swift | 4 +- README.md | 129 +++++++++ Sources/OpenAI/OpenAI.swift | 93 ++++++- Sources/OpenAI/Private/JSONRequest.swift | 3 + .../Public/Models/AssistantsQuery.swift | 56 ++++ .../Public/Models/AssistantsResult.swift | 41 +++ Sources/OpenAI/Public/Models/FilesQuery.swift | 42 +++ .../OpenAI/Public/Models/FilesResult.swift | 15 ++ .../Public/Models/RunRetrieveQuery.swift | 15 ++ .../Public/Models/RunRetrieveResult.swift | 13 + .../Models/RunRetrieveStepsResult.swift | 55 ++++ Sources/OpenAI/Public/Models/RunsQuery.swift | 22 ++ Sources/OpenAI/Public/Models/RunsResult.swift | 13 + .../Public/Models/ThreadAddMessageQuery.swift | 24 ++ .../Models/ThreadAddMessagesResult.swift | 13 + .../Public/Models/ThreadsMessagesResult.swift | 65 +++++ .../OpenAI/Public/Models/ThreadsQuery.swift | 20 ++ .../OpenAI/Public/Models/ThreadsResult.swift | 13 + .../Protocols/OpenAIProtocol+Async.swift | 145 +++++++++++ .../Protocols/OpenAIProtocol+Combine.swift | 38 +++ .../Public/Protocols/OpenAIProtocol.swift | 170 ++++++++++++ Tests/OpenAITests/OpenAITests.swift | 129 ++++++++- Tests/OpenAITests/OpenAITestsCombine.swift | 51 ++++ 39 files changed, 2200 insertions(+), 106 deletions(-) create mode 100644 Demo/DemoChat/Sources/AssistantStore.swift create mode 100644 Demo/DemoChat/Sources/Models/Assistant.swift create mode 100644 Demo/DemoChat/Sources/SupportedFileType.swift create mode 100644 Demo/DemoChat/Sources/UI/AssistantModalContentView.swift create mode 100644 Demo/DemoChat/Sources/UI/AssistantsListView.swift create mode 100644 Demo/DemoChat/Sources/UI/AssistantsView.swift create mode 100644 Demo/DemoChat/Sources/UI/DocumentPicker.swift create mode 100644 Sources/OpenAI/Public/Models/AssistantsQuery.swift create mode 100644 Sources/OpenAI/Public/Models/AssistantsResult.swift create mode 100644 Sources/OpenAI/Public/Models/FilesQuery.swift create mode 100644 Sources/OpenAI/Public/Models/FilesResult.swift create mode 100644 Sources/OpenAI/Public/Models/RunRetrieveQuery.swift create mode 100644 Sources/OpenAI/Public/Models/RunRetrieveResult.swift create mode 100644 Sources/OpenAI/Public/Models/RunRetrieveStepsResult.swift create mode 100644 Sources/OpenAI/Public/Models/RunsQuery.swift create mode 100644 Sources/OpenAI/Public/Models/RunsResult.swift create mode 100644 Sources/OpenAI/Public/Models/ThreadAddMessageQuery.swift create mode 100644 Sources/OpenAI/Public/Models/ThreadAddMessagesResult.swift create mode 100644 Sources/OpenAI/Public/Models/ThreadsMessagesResult.swift create mode 100644 Sources/OpenAI/Public/Models/ThreadsQuery.swift create mode 100644 Sources/OpenAI/Public/Models/ThreadsResult.swift diff --git a/Demo/App/APIProvidedView.swift b/Demo/App/APIProvidedView.swift index 9771e1fb..c9362a21 100644 --- a/Demo/App/APIProvidedView.swift +++ b/Demo/App/APIProvidedView.swift @@ -13,7 +13,9 @@ struct APIProvidedView: View { @Binding var apiKey: String @StateObject var chatStore: ChatStore @StateObject var imageStore: ImageStore + @StateObject var assistantStore: AssistantStore @StateObject var miscStore: MiscStore + @State var isShowingAPIConfigModal: Bool = true @Environment(\.idProviderValue) var idProvider @@ -35,6 +37,12 @@ struct APIProvidedView: View { openAIClient: OpenAI(apiToken: apiKey.wrappedValue) ) ) + self._assistantStore = StateObject( + wrappedValue: AssistantStore( + openAIClient: OpenAI(apiToken: apiKey.wrappedValue), + idProvider: idProvider + ) + ) self._miscStore = StateObject( wrappedValue: MiscStore( openAIClient: OpenAI(apiToken: apiKey.wrappedValue) @@ -46,12 +54,14 @@ struct APIProvidedView: View { ContentView( chatStore: chatStore, imageStore: imageStore, + assistantStore: assistantStore, miscStore: miscStore ) .onChange(of: apiKey) { newApiKey in let client = OpenAI(apiToken: newApiKey) chatStore.openAIClient = client imageStore.openAIClient = client + assistantStore.openAIClient = client miscStore.openAIClient = client } } diff --git a/Demo/App/ContentView.swift b/Demo/App/ContentView.swift index 2826e6bc..091951c7 100644 --- a/Demo/App/ContentView.swift +++ b/Demo/App/ContentView.swift @@ -12,26 +12,38 @@ import SwiftUI struct ContentView: View { @ObservedObject var chatStore: ChatStore @ObservedObject var imageStore: ImageStore + @ObservedObject var assistantStore: AssistantStore @ObservedObject var miscStore: MiscStore + @State private var selectedTab = 0 @Environment(\.idProviderValue) var idProvider var body: some View { TabView(selection: $selectedTab) { ChatView( - store: chatStore + store: chatStore, + assistantStore: assistantStore ) .tabItem { Label("Chats", systemImage: "message") } .tag(0) + AssistantsView( + store: chatStore, + assistantStore: assistantStore + ) + .tabItem { + Label("Assistants", systemImage: "eyeglasses") + } + .tag(1) + TranscribeView( ) .tabItem { Label("Transcribe", systemImage: "mic") } - .tag(1) + .tag(2) ImageView( store: imageStore @@ -39,26 +51,19 @@ struct ContentView: View { .tabItem { Label("Image", systemImage: "photo") } - .tag(2) - + .tag(3) + MiscView( store: miscStore ) .tabItem { Label("Misc", systemImage: "ellipsis") } - .tag(3) + .tag(4) } } } -struct ChatsView: View { - var body: some View { - Text("Chats") - .font(.largeTitle) - } -} - struct TranscribeView: View { var body: some View { Text("Transcribe: TBD") diff --git a/Demo/Demo.xcodeproj/project.pbxproj b/Demo/Demo.xcodeproj/project.pbxproj index edde7d8d..528156b6 100644 --- a/Demo/Demo.xcodeproj/project.pbxproj +++ b/Demo/Demo.xcodeproj/project.pbxproj @@ -234,6 +234,7 @@ GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 16.0; MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; MTL_FAST_MATH = YES; ONLY_ACTIVE_ARCH = YES; @@ -286,6 +287,7 @@ GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 16.0; MTL_ENABLE_DEBUG_INFO = NO; MTL_FAST_MATH = YES; SWIFT_COMPILATION_MODE = wholemodule; @@ -315,7 +317,7 @@ "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault; INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; - IPHONEOS_DEPLOYMENT_TARGET = 16.4; + IPHONEOS_DEPLOYMENT_TARGET = 16.0; LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks"; "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks"; MACOSX_DEPLOYMENT_TARGET = 13.3; @@ -354,7 +356,7 @@ "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault; INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; - IPHONEOS_DEPLOYMENT_TARGET = 16.4; + IPHONEOS_DEPLOYMENT_TARGET = 16.0; LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks"; "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks"; MACOSX_DEPLOYMENT_TARGET = 13.3; diff --git a/Demo/DemoChat/Sources/AssistantStore.swift b/Demo/DemoChat/Sources/AssistantStore.swift new file mode 100644 index 00000000..1393fefb --- /dev/null +++ b/Demo/DemoChat/Sources/AssistantStore.swift @@ -0,0 +1,125 @@ +// +// ChatStore.swift +// DemoChat +// +// Created by Sihao Lu on 3/25/23. +// + +import Foundation +import Combine +import OpenAI + +public final class AssistantStore: ObservableObject { + public var openAIClient: OpenAIProtocol + let idProvider: () -> String + @Published var selectedAssistantId: String? + + @Published var availableAssistants: [Assistant] = [] + + public init( + openAIClient: OpenAIProtocol, + idProvider: @escaping () -> String + ) { + self.openAIClient = openAIClient + self.idProvider = idProvider + } + + // MARK: Models + + @MainActor + func createAssistant(name: String, description: String, instructions: String, codeInterpreter: Bool, retrievel: Bool, fileIds: [String]? = nil) async -> String? { + do { + let tools = createToolsArray(codeInterpreter: codeInterpreter, retrieval: retrievel) + let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools:tools, fileIds: fileIds) + let response = try await openAIClient.assistants(query: query, method: "POST", after: nil) + + // Refresh assistants with one just created (or modified) + let _ = await getAssistants() + + // Returns assistantId + return response.id + + } catch { + // TODO: Better error handling + print(error.localizedDescription) + } + return nil + } + + @MainActor + func modifyAssistant(asstId: String, name: String, description: String, instructions: String, codeInterpreter: Bool, retrievel: Bool, fileIds: [String]? = nil) async -> String? { + do { + let tools = createToolsArray(codeInterpreter: codeInterpreter, retrieval: retrievel) + let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools:tools, fileIds: fileIds) + let response = try await openAIClient.assistantModify(query: query, asstId: asstId) + + // Returns assistantId + return response.id + + } catch { + // TODO: Better error handling + print(error.localizedDescription) + } + return nil + } + + @MainActor + func getAssistants(limit: Int = 20, after: String? = nil) async -> [Assistant] { + do { + let response = try await openAIClient.assistants(query: nil, method: "GET", after: after) + + var assistants = [Assistant]() + for result in response.data ?? [] { + let codeInterpreter = result.tools?.filter { $0.toolType == "code_interpreter" }.first != nil + let retrieval = result.tools?.filter { $0.toolType == "retrieval" }.first != nil + let fileIds = result.fileIds ?? [] + + assistants.append(Assistant(id: result.id, name: result.name, description: result.description, instructions: result.instructions, codeInterpreter: codeInterpreter, retrieval: retrieval, fileIds: fileIds)) + } + if after == nil { + availableAssistants = assistants + } + else { + availableAssistants = availableAssistants + assistants + } + return assistants + + } catch { + // TODO: Better error handling + print(error.localizedDescription) + } + return [] + } + + func selectAssistant(_ assistantId: String?) { + selectedAssistantId = assistantId + } + + @MainActor + func uploadFile(url: URL) async -> FilesResult? { + do { + + let mimeType = url.mimeType() + + let fileData = try Data(contentsOf: url) + + let result = try await openAIClient.files(query: FilesQuery(purpose: "assistants", file: fileData, fileName: url.lastPathComponent, contentType: mimeType)) + return result + } + catch { + print("error = \(error)") + return nil + } + } + + func createToolsArray(codeInterpreter: Bool, retrieval: Bool) -> [Tool] { + var tools = [Tool]() + if codeInterpreter { + tools.append(Tool(toolType: "code_interpreter")) + } + if retrieval { + tools.append(Tool(toolType: "retrieval")) + } + return tools + } +} diff --git a/Demo/DemoChat/Sources/ChatStore.swift b/Demo/DemoChat/Sources/ChatStore.swift index 51ee6b11..99f62696 100644 --- a/Demo/DemoChat/Sources/ChatStore.swift +++ b/Demo/DemoChat/Sources/ChatStore.swift @@ -8,6 +8,7 @@ import Foundation import Combine import OpenAI +import SwiftUI public final class ChatStore: ObservableObject { public var openAIClient: OpenAIProtocol @@ -17,6 +18,15 @@ public final class ChatStore: ObservableObject { @Published var conversationErrors: [Conversation.ID: Error] = [:] @Published var selectedConversationID: Conversation.ID? + // Used for assistants API state. + private var timer: Timer? + private var timeInterval: TimeInterval = 1.0 + private var currentRunId: String? + private var currentThreadId: String? + private var currentConversationId: String? + + @Published var isSendingMessage = false + var selectedConversation: Conversation? { selectedConversationID.flatMap { id in conversations.first { $0.id == id } @@ -39,19 +49,19 @@ public final class ChatStore: ObservableObject { } // MARK: - Events - func createConversation() { - let conversation = Conversation(id: idProvider(), messages: []) + func createConversation(type: ConversationType = .normal, assistantId: String? = nil) { + let conversation = Conversation(id: idProvider(), messages: [], type: type, assistantId: assistantId) conversations.append(conversation) } - + func selectConversation(_ conversationId: Conversation.ID?) { selectedConversationID = conversationId } - + func deleteConversation(_ conversationId: Conversation.ID) { conversations.removeAll(where: { $0.id == conversationId }) } - + @MainActor func sendMessage( _ message: Message, @@ -61,14 +71,69 @@ public final class ChatStore: ObservableObject { guard let conversationIndex = conversations.firstIndex(where: { $0.id == conversationId }) else { return } - conversations[conversationIndex].messages.append(message) - await completeChat( - conversationId: conversationId, - model: model - ) + switch conversations[conversationIndex].type { + case .normal: + conversations[conversationIndex].messages.append(message) + + await completeChat( + conversationId: conversationId, + model: model + ) + // For assistant case we send chats to thread and then poll, polling will receive sent chat + new assistant messages. + case .assistant: + + // First message in an assistant thread. + if conversations[conversationIndex].messages.count == 0 { + + var localMessage = message + localMessage.isLocal = true + conversations[conversationIndex].messages.append(localMessage) + + do { + let threadsQuery = ThreadsQuery(messages: [Chat(role: message.role, content: message.content)]) + let threadsResult = try await openAIClient.threads(query: threadsQuery) + + guard let currentAssistantId = conversations[conversationIndex].assistantId else { return print("No assistant selected.")} + + let runsQuery = RunsQuery(assistantId: currentAssistantId) + let runsResult = try await openAIClient.runs(threadId: threadsResult.id, query: runsQuery) + + // check in on the run every time the poller gets hit. + startPolling(conversationId: conversationId, runId: runsResult.id, threadId: threadsResult.id) + } + catch { + print("error: \(error) creating thread w/ message") + } + } + // Subsequent messages on the assistant thread. + else { + + var localMessage = message + localMessage.isLocal = true + conversations[conversationIndex].messages.append(localMessage) + + do { + guard let currentThreadId else { return print("No thread to add message to.")} + + let _ = try await openAIClient.threadsAddMessage(threadId: currentThreadId, + query: ThreadAddMessageQuery(role: message.role.rawValue, content: message.content)) + + guard let currentAssistantId = conversations[conversationIndex].assistantId else { return print("No assistant selected.")} + + let runsQuery = RunsQuery(assistantId: currentAssistantId) + let runsResult = try await openAIClient.runs(threadId: currentThreadId, query: runsQuery) + + // check in on the run every time the poller gets hit. + startPolling(conversationId: conversationId, runId: runsResult.id, threadId: currentThreadId) + } + catch { + print("error: \(error) adding to thread w/ message") + } + } + } } - + @MainActor func completeChat( conversationId: Conversation.ID, @@ -77,7 +142,7 @@ public final class ChatStore: ObservableObject { guard let conversation = conversations.first(where: { $0.id == conversationId }) else { return } - + conversationErrors[conversationId] = nil do { @@ -89,16 +154,16 @@ public final class ChatStore: ObservableObject { name: "getWeatherData", description: "Get the current weather in a given location", parameters: .init( - type: .object, - properties: [ - "location": .init(type: .string, description: "The city and state, e.g. San Francisco, CA") - ], - required: ["location"] + type: .object, + properties: [ + "location": .init(type: .string, description: "The city and state, e.g. San Francisco, CA") + ], + required: ["location"] ) ) let functions = [weatherFunction] - + let chatsStream: AsyncThrowingStream = openAIClient.chatsStream( query: ChatQuery( model: model, @@ -117,10 +182,10 @@ public final class ChatStore: ObservableObject { // Function calls are also streamed, so we need to accumulate. if let functionCallDelta = choice.delta.functionCall { if let nameDelta = functionCallDelta.name { - functionCallName += nameDelta + functionCallName += nameDelta } if let argumentsDelta = functionCallDelta.arguments { - functionCallArguments += argumentsDelta + functionCallArguments += argumentsDelta } } var messageText = choice.delta.content ?? "" @@ -153,4 +218,144 @@ public final class ChatStore: ObservableObject { conversationErrors[conversationId] = error } } + + // Start Polling section + func startPolling(conversationId: Conversation.ID, runId: String, threadId: String) { + currentRunId = runId + currentThreadId = threadId + currentConversationId = conversationId + isSendingMessage = true + timer = Timer.scheduledTimer(withTimeInterval: timeInterval, repeats: true) { [weak self] _ in + DispatchQueue.main.async { + self?.timerFired() + } + } + } + + func stopPolling() { + isSendingMessage = false + timer?.invalidate() + timer = nil + } + + private func timerFired() { + Task { + let result = try await openAIClient.runRetrieve(threadId: currentThreadId ?? "", runId: currentRunId ?? "") + + // TESTING RETRIEVAL OF RUN STEPS + handleRunRetrieveSteps() + + switch result.status { + // Get threadsMesages. + case "completed": + handleCompleted() + break + case "failed": + // Handle more gracefully with a popup dialog or failure indicator + await MainActor.run { + self.stopPolling() + } + break + default: + // Handle additional statuses "requires_action", "queued" ?, "expired", "cancelled" + // https://platform.openai.com/docs/assistants/how-it-works/runs-and-run-steps + break + } + } + } + // END Polling section + + // This function is called when a thread is marked "completed" by the run status API. + private func handleCompleted() { + guard let conversationIndex = conversations.firstIndex(where: { $0.id == currentConversationId }) else { + return + } + Task { + await MainActor.run { + self.stopPolling() + } + // Once a thread is marked "completed" by the status API, we can retrieve the threads messages, including a pagins cursor representing the last message we received. + var before: String? + if let lastNonLocalMessage = self.conversations[conversationIndex].messages.last(where: { $0.isLocal == false }) { + before = lastNonLocalMessage.id + } + + let result = try await openAIClient.threadsMessages(threadId: currentThreadId ?? "", before: before) + + for item in result.data.reversed() { + let role = item.role + for innerItem in item.content { + let message = Message( + id: item.id, + role: Chat.Role(rawValue: role) ?? .user, + content: innerItem.text?.value ?? "", + createdAt: Date(), + isLocal: false // Messages from the server are not local + ) + await MainActor.run { + // Check if this message from the API matches a local message + if let localMessageIndex = self.conversations[conversationIndex].messages.firstIndex(where: { $0.isLocal == true }) { + + // Replace the local message with the API message + self.conversations[conversationIndex].messages[localMessageIndex] = message + } else { + // This is a new message from the server, append it + self.conversations[conversationIndex].messages.append(message) + } + } + } + } + } + } + + // The run retrieval steps are fetched in a separate task. This request is fetched, checking for new run steps, each time the run is fetched. + private func handleRunRetrieveSteps() { + Task { + guard let conversationIndex = conversations.firstIndex(where: { $0.id == currentConversationId }) else { + return + } + var before: String? +// if let lastRunStepMessage = self.conversations[conversationIndex].messages.last(where: { $0.isRunStep == true }) { +// before = lastRunStepMessage.id +// } + + let stepsResult = try await openAIClient.runRetrieveSteps(threadId: currentThreadId ?? "", runId: currentRunId ?? "", before: before) + + for item in stepsResult.data.reversed() { + let toolCalls = item.stepDetails.toolCalls?.reversed() ?? [] + + for step in toolCalls { + // TODO: Depending on the type of tool tha is used we can add additional information here + // ie: if its a retrieval: add file information, code_interpreter: add inputs and outputs info, or function: add arguemts and additional info. + let msgContent: String + switch step.type { + case "retrieval": + msgContent = "RUN STEP: \(step.type)" + + case "code_interpreter": + msgContent = "code_interpreter\ninput:\n\(step.code?.input ?? "")\noutputs: \(step.code?.outputs?.first?.logs ?? "")" + + default: + msgContent = "RUN STEP: \(step.type)" + + } + let runStepMessage = Message( + id: step.id, + role: .assistant, + content: msgContent, + createdAt: Date(), + isRunStep: true + ) + await MainActor.run { + if let localMessageIndex = self.conversations[conversationIndex].messages.firstIndex(where: { $0.isRunStep == true && $0.id == step.id }) { + self.conversations[conversationIndex].messages[localMessageIndex] = runStepMessage + } + else { + self.conversations[conversationIndex].messages.append(runStepMessage) + } + } + } + } + } + } } diff --git a/Demo/DemoChat/Sources/Models/Assistant.swift b/Demo/DemoChat/Sources/Models/Assistant.swift new file mode 100644 index 00000000..f1f8dfee --- /dev/null +++ b/Demo/DemoChat/Sources/Models/Assistant.swift @@ -0,0 +1,33 @@ +// +// Conversation.swift +// DemoChat +// +// Created by Sihao Lu on 3/25/23. +// + +import Foundation + +struct Assistant: Hashable { + init(id: String, name: String, description: String? = nil, instructions: String? = nil, codeInterpreter: Bool, retrieval: Bool, fileIds: [String]? = nil) { + self.id = id + self.name = name + self.description = description + self.instructions = instructions + self.codeInterpreter = codeInterpreter + self.retrieval = retrieval + self.fileIds = fileIds + } + + typealias ID = String + + let id: String + let name: String + let description: String? + let instructions: String? + let fileIds: [String]? + var codeInterpreter: Bool + var retrieval: Bool +} + + +extension Assistant: Equatable, Identifiable {} diff --git a/Demo/DemoChat/Sources/Models/Conversation.swift b/Demo/DemoChat/Sources/Models/Conversation.swift index 7d6f82b8..b1c3ab71 100644 --- a/Demo/DemoChat/Sources/Models/Conversation.swift +++ b/Demo/DemoChat/Sources/Models/Conversation.swift @@ -8,15 +8,24 @@ import Foundation struct Conversation { - init(id: String, messages: [Message] = []) { + init(id: String, messages: [Message] = [], type: ConversationType = .normal, assistantId: String? = nil) { self.id = id self.messages = messages + self.type = type + self.assistantId = assistantId } typealias ID = String let id: String var messages: [Message] + var type: ConversationType + var assistantId: String? +} + +enum ConversationType { + case normal + case assistant } extension Conversation: Equatable, Identifiable {} diff --git a/Demo/DemoChat/Sources/Models/Message.swift b/Demo/DemoChat/Sources/Models/Message.swift index afea9099..bfbc7b9b 100644 --- a/Demo/DemoChat/Sources/Models/Message.swift +++ b/Demo/DemoChat/Sources/Models/Message.swift @@ -13,6 +13,9 @@ struct Message { var role: Chat.Role var content: String var createdAt: Date + + var isLocal: Bool? + var isRunStep: Bool? } extension Message: Equatable, Codable, Hashable, Identifiable {} diff --git a/Demo/DemoChat/Sources/SupportedFileType.swift b/Demo/DemoChat/Sources/SupportedFileType.swift new file mode 100644 index 00000000..dc604cc9 --- /dev/null +++ b/Demo/DemoChat/Sources/SupportedFileType.swift @@ -0,0 +1,92 @@ +// +// SupportedFileType.swift +// +// +// Created by Chris Dillard on 12/8/23. +// + +import Foundation +import UniformTypeIdentifiers + +struct SupportedFileType { + let fileFormat: String + let mimeType: String + let isCodeInterpreterSupported: Bool + let isRetrievalSupported: Bool +} + +let supportedFileTypes: [SupportedFileType] = [ + SupportedFileType(fileFormat: "c", mimeType: "text/x-c", + isCodeInterpreterSupported: true, isRetrievalSupported: true), + SupportedFileType(fileFormat: "cpp", mimeType: "text/x-c++", + isCodeInterpreterSupported: true, isRetrievalSupported: true), + SupportedFileType(fileFormat: "csv", mimeType: "application/csv", + isCodeInterpreterSupported: true, isRetrievalSupported: false), + SupportedFileType(fileFormat: "docx", mimeType: "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + isCodeInterpreterSupported: true, isRetrievalSupported: true), + SupportedFileType(fileFormat: "html", mimeType: "text/html", + isCodeInterpreterSupported: true, isRetrievalSupported: true), + SupportedFileType(fileFormat: "java", mimeType: "text/x-java", + isCodeInterpreterSupported: true, isRetrievalSupported: true), + SupportedFileType(fileFormat: "json", mimeType: "application/json", + isCodeInterpreterSupported: true, isRetrievalSupported: true), + SupportedFileType(fileFormat: "md", mimeType: "text/markdown", + isCodeInterpreterSupported: true, isRetrievalSupported: true), + SupportedFileType(fileFormat: "pdf", mimeType: "application/pdf", + isCodeInterpreterSupported: true, isRetrievalSupported: true), + SupportedFileType(fileFormat: "php", mimeType: "text/x-php", + isCodeInterpreterSupported: true, isRetrievalSupported: true), + SupportedFileType(fileFormat: "pptx", mimeType: "application/vnd.openxmlformats-officedocument.presentationml.presentation", + isCodeInterpreterSupported: true, isRetrievalSupported: true), + SupportedFileType(fileFormat: "py", mimeType: "text/x-python", + isCodeInterpreterSupported: true, isRetrievalSupported: true), + SupportedFileType(fileFormat: "rb", mimeType: "text/x-ruby", + isCodeInterpreterSupported: true, isRetrievalSupported: true), + SupportedFileType(fileFormat: "tex", mimeType: "text/x-tex", + isCodeInterpreterSupported: true, isRetrievalSupported: true), + SupportedFileType(fileFormat: "txt", mimeType: "text/plain", + isCodeInterpreterSupported: true, isRetrievalSupported: true), + SupportedFileType(fileFormat: "css", mimeType: "text/css", + isCodeInterpreterSupported: true, isRetrievalSupported: false), + SupportedFileType(fileFormat: "jpeg", mimeType: "image/jpeg", + isCodeInterpreterSupported: true, isRetrievalSupported: false), + SupportedFileType(fileFormat: "jpg", mimeType: "image/jpeg", + isCodeInterpreterSupported: true, isRetrievalSupported: false), + SupportedFileType(fileFormat: "js", mimeType: "text/javascript", + isCodeInterpreterSupported: true, isRetrievalSupported: false), + SupportedFileType(fileFormat: "gif", mimeType: "image/gif", + isCodeInterpreterSupported: true, isRetrievalSupported: false), + SupportedFileType(fileFormat: "png", mimeType: "image/png", + isCodeInterpreterSupported: true, isRetrievalSupported: false), + SupportedFileType(fileFormat: "tar", mimeType: "application/x-tar", + isCodeInterpreterSupported: true, isRetrievalSupported: false), + SupportedFileType(fileFormat: "ts", mimeType: "application/typescript", + isCodeInterpreterSupported: true, isRetrievalSupported: false), + SupportedFileType(fileFormat: "xlsx", mimeType: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + isCodeInterpreterSupported: true, isRetrievalSupported: false), + SupportedFileType(fileFormat: "xml", mimeType: "application/xml", // or \"text/xml\" + isCodeInterpreterSupported: true, isRetrievalSupported: false), + SupportedFileType(fileFormat: "zip", mimeType: "application/zip", + isCodeInterpreterSupported: true, isRetrievalSupported: false) +] + +func supportedUITypes() -> [UTType] { + var supportedTypes: [UTType] = [] + + for supportedFileType in supportedFileTypes { + if let newType = UTType(filenameExtension: supportedFileType.fileFormat) { + supportedTypes += [newType] + } + } + + return supportedTypes +} + +extension URL { + func mimeType() -> String { + guard let utType = UTType(filenameExtension: self.pathExtension) else { + return "application/octet-stream" // Default type if unknown + } + return utType.preferredMIMEType ?? "application/octet-stream" + } +} diff --git a/Demo/DemoChat/Sources/UI/AssistantModalContentView.swift b/Demo/DemoChat/Sources/UI/AssistantModalContentView.swift new file mode 100644 index 00000000..efde0536 --- /dev/null +++ b/Demo/DemoChat/Sources/UI/AssistantModalContentView.swift @@ -0,0 +1,112 @@ +// +// AssistantModalContentView.swift +// +// +// Created by Chris Dillard on 11/9/23. +// + +import SwiftUI + +struct AssistantModalContentView: View { + enum Mode { + case modify + case create + } + + @Binding var name: String + @Binding var description: String + @Binding var customInstructions: String + + @Binding var codeInterpreter: Bool + @Binding var retrieval: Bool + @Binding var fileIds: [String] + @Binding var isUploading: Bool + + var modify: Bool + + @Environment(\.dismiss) var dismiss + + @Binding var isPickerPresented: Bool + // If a file has been selected for uploading and is currently in progress, this is set. + @Binding var selectedFileURL: URL? + + var onCommit: () -> Void + var onFileUpload: () -> Void + + var body: some View { + NavigationView { + Form { + Section(header: Text("Name")) { + TextField("Name", text: $name) + } + Section(header: Text("Description")) { + TextEditor(text: $description) + .frame(minHeight: 50) + } + Section(header: Text("Custom Instructions")) { + TextEditor(text: $customInstructions) + .frame(minHeight: 100) + } + + Toggle(isOn: $codeInterpreter, label: { + Text("Code interpreter") + }) + + Toggle(isOn: $retrieval, label: { + Text("Retrieval") + }) + + if !fileIds.isEmpty { + ForEach(fileIds, id: \.self) { fileId in + HStack { + // File Id of each file added to the assistant. + Text("File: \(fileId)") + Spacer() + // Button to remove fileId from the list of fileIds to be used when create or modify assistant. + Button(action: { + // Add action to remove the file from the list + if let index = fileIds.firstIndex(of: fileId) { + fileIds.remove(at: index) + } + }) { + Image(systemName: "xmark.circle.fill") // X button + .foregroundColor(.red) + } + } + } + } + + if let selectedFileURL { + HStack { + Text("File: \(selectedFileURL.lastPathComponent)") + + Button("Remove") { + self.selectedFileURL = nil + } + } + } + else { + Button("Upload File") { + isPickerPresented = true + } + .sheet(isPresented: $isPickerPresented) { + DocumentPicker { url in + selectedFileURL = url + onFileUpload() + } + } + } + } + .navigationTitle("\(modify ? "Edit" : "Enter") Assistant Details") + .navigationBarItems( + leading: Button("Cancel") { + dismiss() + }, + trailing: Button("OK") { + onCommit() + dismiss() + } + ) + } + } +} diff --git a/Demo/DemoChat/Sources/UI/AssistantsListView.swift b/Demo/DemoChat/Sources/UI/AssistantsListView.swift new file mode 100644 index 00000000..16377649 --- /dev/null +++ b/Demo/DemoChat/Sources/UI/AssistantsListView.swift @@ -0,0 +1,42 @@ +// +// ListView.swift +// DemoChat +// +// Created by Sihao Lu on 3/25/23. +// + +import SwiftUI + +struct AssistantsListView: View { + @Binding var assistants: [Assistant] + @Binding var selectedAssistantId: String? + var onLoadMoreAssistants: () -> Void + @Binding var isLoadingMore: Bool + + var body: some View { + VStack { + List( + $assistants, + editActions: [.delete], + selection: $selectedAssistantId + ) { $assistant in + Text( + assistant.name + ) + .lineLimit(2) + .onAppear { + if assistant.id == assistants.last?.id { + onLoadMoreAssistants() + } + } + } + + + if isLoadingMore { + ProgressView() + .padding() + } + } + .navigationTitle("Assistants") + } +} diff --git a/Demo/DemoChat/Sources/UI/AssistantsView.swift b/Demo/DemoChat/Sources/UI/AssistantsView.swift new file mode 100644 index 00000000..d2668fdd --- /dev/null +++ b/Demo/DemoChat/Sources/UI/AssistantsView.swift @@ -0,0 +1,196 @@ +// +// ChatView.swift +// DemoChat +// +// Created by Sihao Lu on 3/25/23. +// + +import Combine +import SwiftUI + +public struct AssistantsView: View { + @ObservedObject var store: ChatStore + @ObservedObject var assistantStore: AssistantStore + + @Environment(\.dateProviderValue) var dateProvider + @Environment(\.idProviderValue) var idProvider + + // state to select file + @State private var isPickerPresented: Bool = false + @State private var fileURL: URL? + + // state to modify assistant + @State private var name: String = "" + @State private var description: String = "" + @State private var customInstructions: String = "" + @State private var fileIds: [String] = [] + + @State private var codeInterpreter: Bool = false + @State private var retrieval: Bool = false + @State var isLoadingMore = false + @State private var isModalPresented = false + @State private var isUploading = false + + //If a file is selected via the document picker, this is set. + @State var selectedFileURL: URL? + @State var uploadedFileId: String? + + @State var mode: AssistantModalContentView.Mode = .create + + public init(store: ChatStore, assistantStore: AssistantStore) { + self.store = store + self.assistantStore = assistantStore + } + + public var body: some View { + ZStack { + NavigationSplitView { + AssistantsListView( + assistants: $assistantStore.availableAssistants, selectedAssistantId: Binding( + get: { + assistantStore.selectedAssistantId + + }, set: { newId in + guard newId != nil else { return } + + selectAssistant(newId: newId) + }), onLoadMoreAssistants: { + loadMoreAssistants() + }, isLoadingMore: $isLoadingMore + ) + .toolbar { + ToolbarItem( + placement: .primaryAction + ) { + Menu { + Button("Get Assistants") { + Task { + let _ = await assistantStore.getAssistants() + } + } + Button("Create Assistant") { + mode = .create + isModalPresented = true + } + } label: { + Image(systemName: "plus") + } + + .buttonStyle(.borderedProminent) + } + } + } detail: { + + } + .sheet(isPresented: $isModalPresented, onDismiss: { + resetAssistantCreator() + }, content: { + AssistantModalContentView(name: $name, description: $description, customInstructions: $customInstructions, + codeInterpreter: $codeInterpreter, retrieval: $retrieval, fileIds: $fileIds, + isUploading: $isUploading, modify: mode == .modify, isPickerPresented: $isPickerPresented, selectedFileURL: $selectedFileURL) { + Task { + await handleOKTap() + } + } onFileUpload: { + Task { + guard let selectedFileURL else { return } + + isUploading = true + let file = await assistantStore.uploadFile(url: selectedFileURL) + uploadedFileId = file?.id + isUploading = false + + if uploadedFileId == nil { + print("Failed to upload") + self.selectedFileURL = nil + } + else { + // if successful upload , we can show it. + if let uploadedFileId = uploadedFileId { + self.selectedFileURL = nil + + fileIds += [uploadedFileId] + + print("Successful upload!") + } + } + } + } + }) + } + } + + private func handleOKTap() async { + + var mergedFileIds = [String]() + + mergedFileIds += fileIds + + let asstId: String? + + switch mode { + // Create new Assistant and start a new conversation with it. + case .create: + asstId = await assistantStore.createAssistant(name: name, description: description, instructions: customInstructions, codeInterpreter: codeInterpreter, retrievel: retrieval, fileIds: mergedFileIds.isEmpty ? nil : mergedFileIds) + // Modify existing Assistant and start new conversation with it. + case .modify: + guard let selectedAssistantId = assistantStore.selectedAssistantId else { return print("Cannot modify assistant, not selected.") } + + asstId = await assistantStore.modifyAssistant(asstId: selectedAssistantId, name: name, description: description, instructions: customInstructions, codeInterpreter: codeInterpreter, retrievel: retrieval, fileIds: mergedFileIds.isEmpty ? nil : mergedFileIds) + } + + // Reset Assistant Creator after attempted creation or modification. + resetAssistantCreator() + + guard let asstId else { + print("failed to create Assistant.") + return + } + + // Create new local conversation to represent new thread. + store.createConversation(type: .assistant, assistantId: asstId) + } + + private func loadMoreAssistants() { + guard !isLoadingMore else { return } + + isLoadingMore = true + let lastAssistantId = assistantStore.availableAssistants.last?.id ?? "" + + Task { + // Fetch more assistants and append to the list + let _ = await assistantStore.getAssistants(after: lastAssistantId) + isLoadingMore = false + } + } + + private func resetAssistantCreator() { + // Reset state for Assistant creator. + name = "" + description = "" + customInstructions = "" + + codeInterpreter = false + retrieval = false + selectedFileURL = nil + uploadedFileId = nil + fileIds = [] + } + + private func selectAssistant(newId: String?) { + assistantStore.selectAssistant(newId) + + let selectedAssistant = assistantStore.availableAssistants.filter { $0.id == assistantStore.selectedAssistantId }.first + + name = selectedAssistant?.name ?? "" + description = selectedAssistant?.description ?? "" + customInstructions = selectedAssistant?.instructions ?? "" + codeInterpreter = selectedAssistant?.codeInterpreter ?? false + retrieval = selectedAssistant?.retrieval ?? false + fileIds = selectedAssistant?.fileIds ?? [] + + mode = .modify + isModalPresented = true + + } +} diff --git a/Demo/DemoChat/Sources/UI/ChatView.swift b/Demo/DemoChat/Sources/UI/ChatView.swift index 1b872c21..1812ed26 100644 --- a/Demo/DemoChat/Sources/UI/ChatView.swift +++ b/Demo/DemoChat/Sources/UI/ChatView.swift @@ -10,57 +10,63 @@ import SwiftUI public struct ChatView: View { @ObservedObject var store: ChatStore - + @ObservedObject var assistantStore: AssistantStore + @Environment(\.dateProviderValue) var dateProvider @Environment(\.idProviderValue) var idProvider - public init(store: ChatStore) { + public init(store: ChatStore, assistantStore: AssistantStore) { self.store = store + self.assistantStore = assistantStore } - + public var body: some View { - NavigationSplitView { - ListView( - conversations: $store.conversations, - selectedConversationId: Binding( - get: { - store.selectedConversationID - }, set: { newId in - store.selectConversation(newId) - }) - ) - .toolbar { - ToolbarItem( - placement: .primaryAction - ) { - Button(action: { - store.createConversation() - }) { - Image(systemName: "plus") - } - .buttonStyle(.borderedProminent) - } - } - } detail: { - if let conversation = store.selectedConversation { - DetailView( - conversation: conversation, - error: store.conversationErrors[conversation.id], - sendMessage: { message, selectedModel in - Task { - await store.sendMessage( - Message( - id: idProvider(), - role: .user, - content: message, - createdAt: dateProvider() - ), - conversationId: conversation.id, - model: selectedModel - ) + ZStack { + NavigationSplitView { + ListView( + conversations: $store.conversations, + selectedConversationId: Binding( + get: { + store.selectedConversationID + }, set: { newId in + store.selectConversation(newId) + }) + ) + .toolbar { + ToolbarItem( + placement: .primaryAction + ) { + Menu { + Button("Create Chat") { + store.createConversation() + } + } label: { + Image(systemName: "plus") } + .buttonStyle(.borderedProminent) } - ) + } + } detail: { + if let conversation = store.selectedConversation { + DetailView( + availableAssistants: assistantStore.availableAssistants, conversation: conversation, + error: store.conversationErrors[conversation.id], + sendMessage: { message, selectedModel in + Task { + await store.sendMessage( + Message( + id: idProvider(), + role: .user, + content: message, + createdAt: dateProvider() + ), + conversationId: conversation.id, + model: selectedModel + ) + } + }, isSendingMessage: $store.isSendingMessage + ) + } } } } diff --git a/Demo/DemoChat/Sources/UI/DetailView.swift b/Demo/DemoChat/Sources/UI/DetailView.swift index 9e2a07e9..ee7b076f 100644 --- a/Demo/DemoChat/Sources/UI/DetailView.swift +++ b/Demo/DemoChat/Sources/UI/DetailView.swift @@ -18,6 +18,7 @@ struct DetailView: View { @FocusState private var isFocused: Bool @State private var showsModelSelectionSheet = false @State private var selectedChatModel: Model = .gpt4_0613 + var availableAssistants: [Assistant] private let availableChatModels: [Model] = [.gpt3_5Turbo0613, .gpt4_0613] @@ -25,6 +26,8 @@ struct DetailView: View { let error: Error? let sendMessage: (String, Model) -> Void + @Binding var isSendingMessage: Bool + private var fillColor: Color { #if os(iOS) return Color(uiColor: UIColor.systemBackground) @@ -51,6 +54,10 @@ struct DetailView: View { } .listRowSeparator(.hidden) } + // Tapping on the message bubble area should dismiss the keyboard. + .onTapGesture { + self.hideKeyboard() + } .listStyle(.plain) .animation(.default, value: conversation.messages) // .onChange(of: conversation) { newValue in @@ -65,11 +72,11 @@ struct DetailView: View { inputBar(scrollViewProxy: scrollViewProxy) } - .navigationTitle("Chat") + .navigationTitle(conversation.type == .assistant ? "Assistant: \(currentAssistantName())" : "Chat") .safeAreaInset(edge: .top) { HStack { Text( - "Model: \(selectedChatModel)" + "Model: \(conversation.type == .assistant ? Model.gpt4_1106_preview : selectedChatModel)" ) .font(.caption) .foregroundColor(.secondary) @@ -79,11 +86,28 @@ struct DetailView: View { .padding(.vertical, 8) } .toolbar { - ToolbarItem(placement: .navigationBarTrailing) { - Button(action: { - showsModelSelectionSheet.toggle() - }) { - Image(systemName: "cpu") + if conversation.type == .assistant { + ToolbarItem(placement: .navigationBarTrailing) { + + Menu { + ForEach(availableAssistants, id: \.self) { item in + Button(item.name) { + print("Select assistant") + //selectedItem = item + } + } + } label: { + Image(systemName: "eyeglasses") + } + } + } + if conversation.type == .normal { + ToolbarItem(placement: .navigationBarTrailing) { + Button(action: { + showsModelSelectionSheet.toggle() + }) { + Image(systemName: "cpu") + } } } } @@ -165,18 +189,24 @@ struct DetailView: View { } .padding(.leading) - Button(action: { - withAnimation { - tapSendMessage(scrollViewProxy: scrollViewProxy) + if isSendingMessage { + ProgressView() + .progressViewStyle(CircularProgressViewStyle()) + .padding(.trailing) + } else { + Button(action: { + withAnimation { + tapSendMessage(scrollViewProxy: scrollViewProxy) + } + }) { + Image(systemName: "paperplane") + .resizable() + .aspectRatio(contentMode: .fit) + .frame(width: 24, height: 24) + .padding(.trailing) } - }) { - Image(systemName: "paperplane") - .resizable() - .aspectRatio(contentMode: .fit) - .frame(width: 24, height: 24) - .padding(.trailing) + .disabled(inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty) } - .disabled(inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty) } .padding(.bottom) } @@ -196,6 +226,13 @@ struct DetailView: View { // scrollViewProxy.scrollTo(lastMessage.id, anchor: .bottom) // } } + + func currentAssistantName() -> String { + availableAssistants.filter { conversation.assistantId == $0.id }.first?.name ?? "" + } + func hideKeyboard() { + UIApplication.shared.sendAction(#selector(UIResponder.resignFirstResponder), to: nil, from: nil, for: nil) + } } struct ChatBubble: View { @@ -261,6 +298,7 @@ struct ChatBubble: View { struct DetailView_Previews: PreviewProvider { static var previews: some View { DetailView( + availableAssistants: [], conversation: Conversation( id: "1", messages: [ @@ -277,7 +315,7 @@ struct DetailView_Previews: PreviewProvider { ] ), error: nil, - sendMessage: { _, _ in } + sendMessage: { _, _ in }, isSendingMessage: Binding.constant(false) ) } } diff --git a/Demo/DemoChat/Sources/UI/DocumentPicker.swift b/Demo/DemoChat/Sources/UI/DocumentPicker.swift new file mode 100644 index 00000000..3c960235 --- /dev/null +++ b/Demo/DemoChat/Sources/UI/DocumentPicker.swift @@ -0,0 +1,41 @@ +// +// DocumentPicker.swift +// +// +// Created by Chris Dillard on 11/10/23. +// + +import SwiftUI +import UniformTypeIdentifiers + +struct DocumentPicker: UIViewControllerRepresentable { + var callback: (URL) -> Void + + func makeUIViewController(context: Context) -> UIDocumentPickerViewController { + let pickerViewController = UIDocumentPickerViewController(forOpeningContentTypes: supportedUITypes(), asCopy: true) + pickerViewController.allowsMultipleSelection = false + pickerViewController.shouldShowFileExtensions = true + + pickerViewController.delegate = context.coordinator + return pickerViewController + } + + func updateUIViewController(_ uiViewController: UIDocumentPickerViewController, context: Context) {} + + func makeCoordinator() -> Coordinator { + return Coordinator(self) + } + + class Coordinator: NSObject, UIDocumentPickerDelegate { + var parent: DocumentPicker + + init(_ parent: DocumentPicker) { + self.parent = parent + } + + func documentPicker(_ controller: UIDocumentPickerViewController, didPickDocumentsAt urls: [URL]) { + guard let url = urls.first else { return } + parent.callback(url) + } + } +} diff --git a/Demo/DemoChat/Sources/UI/ListView.swift b/Demo/DemoChat/Sources/UI/ListView.swift index bfbdfc56..d8be5585 100644 --- a/Demo/DemoChat/Sources/UI/ListView.swift +++ b/Demo/DemoChat/Sources/UI/ListView.swift @@ -17,10 +17,28 @@ struct ListView: View { editActions: [.delete], selection: $selectedConversationId ) { $conversation in - Text( - conversation.messages.last?.content ?? "New Conversation" - ) - .lineLimit(2) + if let convoContent = conversation.messages.last?.content { + Text( + convoContent + ) + .lineLimit(2) + } + else { + if conversation.type == .assistant { + Text( + "New Assistant" + ) + .lineLimit(2) + } + else { + Text( + "New Conversation" + ) + .lineLimit(2) + } + } + + } .navigationTitle("Conversations") } diff --git a/Demo/DemoChat/Sources/UI/ModerationChatView.swift b/Demo/DemoChat/Sources/UI/ModerationChatView.swift index 41658845..ec66425e 100644 --- a/Demo/DemoChat/Sources/UI/ModerationChatView.swift +++ b/Demo/DemoChat/Sources/UI/ModerationChatView.swift @@ -19,7 +19,7 @@ public struct ModerationChatView: View { public var body: some View { DetailView( - conversation: store.moderationConversation, + availableAssistants: [], conversation: store.moderationConversation, error: store.moderationConversationError, sendMessage: { message, _ in Task { @@ -32,7 +32,7 @@ public struct ModerationChatView: View { ) ) } - } + }, isSendingMessage: Binding.constant(false) ) } } diff --git a/README.md b/README.md index 7d9564d4..13545443 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,21 @@ This repository contains Swift community-maintained implementation over [OpenAI] - [Moderations](#moderations) - [Utilities](#utilities) - [Combine Extensions](#combine-extensions) + - [Assistants (Beta)](#assistants) + - [Create Assistant](#create-assistant) + - [Modify Assistant](#modify-assistant) + - [List Assistants](#list-assistants) + - [Threads](#threads) + - [Create Thread](#create-thread) + - [Get Threads Messages](#get-threads-messages) + - [Add Message to Thread](#add-message-to-thread) + - [Runs](#runs) + - [Create Run](#create-run) + - [Retrieve Run](#retrieve-run) + - [Retrieve Run Steps](#retrieve-run-steps) + + - [Files](#files) + - [Upload File](#upload-file) - [Example Project](#example-project) - [Contribution Guidelines](#contribution-guidelines) - [Links](#links) @@ -998,6 +1013,120 @@ func audioTranscriptions(query: AudioTranscriptionQuery) -> AnyPublisher AnyPublisher ``` +### Assistants + +Review [Assistants Documentation](https://platform.openai.com/docs/api-reference/assistants) for more info. + +#### Create Assistant + +Example: Create Assistant +``` +let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools: tools, fileIds: fileIds) +openAI.assistants(query: query) { result in + //Handle response here +} +``` + +#### Modify Assistant + +Example: Modify Assistant +``` +let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools: tools, fileIds: fileIds) +openAI.assistantModify(query: query, asstId: "asst_1234") { result in + //Handle response here +} +``` + +#### List Assistants + +Example: List Assistants +``` +openAI.assistants(query: nil, method: "GET") { result in + //Handle response here +} +``` + +#### Threads + +Review [Threads Documentation](https://platform.openai.com/docs/api-reference/threads) for more info. + +##### Create Thread + +Example: Create Thread +``` +let threadsQuery = ThreadsQuery(messages: [Chat(role: message.role, content: message.content)]) +openAI.threads(query: threadsQuery) { result in + //Handle response here +} +``` + +##### Get Threads Messages + +Review [Messages Documentation](https://platform.openai.com/docs/api-reference/messages) for more info. + +Example: Get Threads Messages +``` +openAI.threadsMessages(threadId: currentThreadId, before: nil) { result in + //Handle response here +} +``` + +##### Add Message to Thread + +Example: Add Message to Thread +``` +let query = ThreadAddMessageQuery(role: message.role.rawValue, content: message.content) +openAI.threadsAddMessage(threadId: currentThreadId, query: query) { result in + //Handle response here +} +``` + +#### Runs + +Review [Runs Documentation](https://platform.openai.com/docs/api-reference/runs) for more info. + +##### Create Run + +Example: Create Run +``` +let runsQuery = RunsQuery(assistantId: currentAssistantId) +openAI.runs(threadId: threadsResult.id, query: runsQuery) { result in + //Handle response here +} +``` + +##### Retrieve Run + +Example: Retrieve Run +``` +openAI.runRetrieve(threadId: currentThreadId, runId: currentRunId) { result in + //Handle response here +} +``` + +##### Retrieve Run Steps + +Example: Retrieve Run Steps +``` +openAI.runRetrieveSteps(threadId: currentThreadId, runId: currentRunId, before: nil) { result in + //Handle response here +} +``` + +#### Files + +Review [Files Documentation](https://platform.openai.com/docs/api-reference/files) for more info. + +##### Upload file + +Example: Upload file +``` +let query = FilesQuery(purpose: "assistants", file: fileData, fileName: url.lastPathComponent, contentType: "application/pdf") +openAI.files(query: query) { result in + //Handle response here +} +``` + ## Example Project You can find example iOS application in [Demo](/Demo) folder. diff --git a/Sources/OpenAI/OpenAI.swift b/Sources/OpenAI/OpenAI.swift index 3dcad3c9..f9eaf606 100644 --- a/Sources/OpenAI/OpenAI.swift +++ b/Sources/OpenAI/OpenAI.swift @@ -55,7 +55,45 @@ final public class OpenAI: OpenAIProtocol { public convenience init(configuration: Configuration, session: URLSession = URLSession.shared) { self.init(configuration: configuration, session: session as URLSessionProtocol) } - + + // UPDATES FROM 11-06-23 + public func threadsAddMessage(threadId: String, query: ThreadAddMessageQuery, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: query, url: buildRunsURL(path: .threadsMessages, threadId: threadId)), completion: completion) + } + + public func threadsMessages(threadId: String, before: String?, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: nil, url: buildRunsURL(path: .threadsMessages, threadId: threadId, before: before), method: "GET"), completion: completion) + } + + public func runRetrieve(threadId: String, runId: String, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: nil, url: buildRunRetrieveURL(path: .runRetrieve, threadId: threadId, runId: runId, before: nil), method: "GET"), completion: completion) + } + + public func runRetrieveSteps(threadId: String, runId: String, before: String?, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: nil, url: buildRunRetrieveURL(path: .runRetrieveSteps, threadId: threadId, runId: runId, before: before), method: "GET"), completion: completion) + } + + public func runs(threadId: String, query: RunsQuery, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: query, url: buildRunsURL(path: .runs, threadId: threadId)), completion: completion) + } + + public func threads(query: ThreadsQuery, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: query, url: buildURL(path: .threads)), completion: completion) + } + + public func assistants(query: AssistantsQuery?, method: String, after: String?, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: query, url: buildURL(path: .assistants, after: after), method: method), completion: completion) + } + + public func assistantModify(query: AssistantsQuery, asstId: String, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: query, url: buildAssistantURL(path: .assistantsModify, assistantId: asstId)), completion: completion) + } + + public func files(query: FilesQuery, completion: @escaping (Result) -> Void) { + performRequest(request: MultipartFormDataRequest(body: query, url: buildURL(path: .files)), completion: completion) + } + // END UPDATES FROM 11-06-23 + public func completions(query: CompletionsQuery, completion: @escaping (Result) -> Void) { performRequest(request: JSONRequest(body: query, url: buildURL(path: .completions)), completion: completion) } @@ -137,6 +175,9 @@ extension OpenAI { var apiError: Error? = nil do { + + let errorText = String(data: data, encoding: .utf8) + let decoded = try JSONDecoder().decode(ResultType.self, from: data) completion(.success(decoded)) } catch { @@ -168,6 +209,7 @@ extension OpenAI { onResult(.success(object)) } session.onProcessingError = {_, error in + print("OpenAI API error = \(error.localizedDescription)") onResult(.failure(error)) } session.onComplete = { [weak self] object, error in @@ -218,18 +260,63 @@ extension OpenAI { extension OpenAI { - func buildURL(path: String) -> URL { + func buildURL(path: String, after: String? = nil) -> URL { var components = URLComponents() components.scheme = "https" components.host = configuration.host components.path = path + if let after { + components.queryItems = [URLQueryItem(name: "after", value: after)] + } + return components.url! + } + + func buildRunsURL(path: String, threadId: String, before: String? = nil) -> URL { + var components = URLComponents() + components.scheme = "https" + components.host = configuration.host + components.path = path.replacingOccurrences(of: "THREAD_ID", with: threadId) + if let before { + components.queryItems = [URLQueryItem(name: "before", value: before)] + } + return components.url! + } + + func buildRunRetrieveURL(path: String, threadId: String, runId: String, before: String? = nil) -> URL { + var components = URLComponents() + components.scheme = "https" + components.host = configuration.host + components.path = path.replacingOccurrences(of: "THREAD_ID", with: threadId) + .replacingOccurrences(of: "RUN_ID", with: runId) + if let before { + components.queryItems = [URLQueryItem(name: "before", value: before)] + } + return components.url! + } + + func buildAssistantURL(path: String, assistantId: String) -> URL { + var components = URLComponents() + components.scheme = "https" + components.host = configuration.host + components.path = path.replacingOccurrences(of: "ASST_ID", with: assistantId) + return components.url! } } typealias APIPath = String extension APIPath { - + // 1106 + static let assistants = "/v1/assistants" + static let assistantsModify = "/v1/assistants/ASST_ID" + static let threads = "/v1/threads" + static let runs = "/v1/threads/THREAD_ID/runs" + static let runRetrieve = "/v1/threads/THREAD_ID/runs/RUN_ID" + static let runRetrieveSteps = "/v1/threads/THREAD_ID/runs/RUN_ID/steps" + static let threadsMessages = "/v1/threads/THREAD_ID/messages" + static let files = "/v1/files" + // 1106 end + static let completions = "/v1/completions" static let embeddings = "/v1/embeddings" static let chats = "/v1/chat/completions" diff --git a/Sources/OpenAI/Private/JSONRequest.swift b/Sources/OpenAI/Private/JSONRequest.swift index 526f95c9..afacfc0f 100644 --- a/Sources/OpenAI/Private/JSONRequest.swift +++ b/Sources/OpenAI/Private/JSONRequest.swift @@ -29,6 +29,9 @@ extension JSONRequest: URLRequestBuildable { var request = URLRequest(url: url, timeoutInterval: timeoutInterval) request.setValue("application/json", forHTTPHeaderField: "Content-Type") request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization") + // TODO: ONLY PASS IF ASSISTANTS API + request.setValue("assistants=v1", forHTTPHeaderField: "OpenAI-Beta") + if let organizationIdentifier { request.setValue(organizationIdentifier, forHTTPHeaderField: "OpenAI-Organization") } diff --git a/Sources/OpenAI/Public/Models/AssistantsQuery.swift b/Sources/OpenAI/Public/Models/AssistantsQuery.swift new file mode 100644 index 00000000..64c22cfe --- /dev/null +++ b/Sources/OpenAI/Public/Models/AssistantsQuery.swift @@ -0,0 +1,56 @@ +// +// AssistantsQuery.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct AssistantsQuery: Codable { + + public let model: Model + + public let name: String + + public let description: String + + public let instructions: String + + public let tools: [Tool]? + + public let fileIds: [String]? + + enum CodingKeys: String, CodingKey { + case model + case name + case description + case instructions + case tools + case fileIds = "file_ids" + } + + public init(model: Model, name: String, description: String, instructions: String, tools: [Tool], fileIds: [String]? = nil) { + self.model = model + self.name = name + + self.description = description + self.instructions = instructions + + self.tools = tools + self.fileIds = fileIds + } +} + +public struct Tool: Codable, Equatable { + public let toolType: String + + enum CodingKeys: String, CodingKey { + case toolType = "type" + } + + public init(toolType: String) { + self.toolType = toolType + } + +} diff --git a/Sources/OpenAI/Public/Models/AssistantsResult.swift b/Sources/OpenAI/Public/Models/AssistantsResult.swift new file mode 100644 index 00000000..39ad2a5e --- /dev/null +++ b/Sources/OpenAI/Public/Models/AssistantsResult.swift @@ -0,0 +1,41 @@ +// +// AssistantsResult.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct AssistantsResult: Codable, Equatable { + + public let id: String? + + public let data: [AssistantContent]? + public let tools: [Tool]? + + enum CodingKeys: String, CodingKey { + case data + case id + case tools + } + + public struct AssistantContent: Codable, Equatable { + + public let id: String + public let name: String + public let description: String? + public let instructions: String? + public let tools: [Tool]? + public let fileIds: [String]? + + enum CodingKeys: String, CodingKey { + case id + case name + case description + case instructions + case tools + case fileIds = "file_ids" + } + } +} diff --git a/Sources/OpenAI/Public/Models/FilesQuery.swift b/Sources/OpenAI/Public/Models/FilesQuery.swift new file mode 100644 index 00000000..09d883dc --- /dev/null +++ b/Sources/OpenAI/Public/Models/FilesQuery.swift @@ -0,0 +1,42 @@ +// +// FilesQuery.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct FilesQuery: Codable { + + public let purpose: String + + public let file: Data + public let fileName: String + + public let contentType: String + + enum CodingKeys: String, CodingKey { + case purpose + case file + case fileName + case contentType + } + + public init(purpose: String, file: Data, fileName: String, contentType: String) { + self.purpose = purpose + self.file = file + self.fileName = fileName + self.contentType = contentType + } +} + +extension FilesQuery: MultipartFormDataBodyEncodable { + func encode(boundary: String) -> Data { + let bodyBuilder = MultipartFormDataBodyBuilder(boundary: boundary, entries: [ + .string(paramName: "purpose", value: purpose), + .file(paramName: "file", fileName: fileName, fileData: file, contentType: contentType), + ]) + return bodyBuilder.build() + } +} diff --git a/Sources/OpenAI/Public/Models/FilesResult.swift b/Sources/OpenAI/Public/Models/FilesResult.swift new file mode 100644 index 00000000..0799d8c4 --- /dev/null +++ b/Sources/OpenAI/Public/Models/FilesResult.swift @@ -0,0 +1,15 @@ +// +// FilesResult.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct FilesResult: Codable, Equatable { + + public let id: String + public let name: String + +} diff --git a/Sources/OpenAI/Public/Models/RunRetrieveQuery.swift b/Sources/OpenAI/Public/Models/RunRetrieveQuery.swift new file mode 100644 index 00000000..eef7e5d0 --- /dev/null +++ b/Sources/OpenAI/Public/Models/RunRetrieveQuery.swift @@ -0,0 +1,15 @@ +// +// RunRetrieveQuery.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct RunRetrieveQuery: Equatable, Codable { + + public init() { + + } +} diff --git a/Sources/OpenAI/Public/Models/RunRetrieveResult.swift b/Sources/OpenAI/Public/Models/RunRetrieveResult.swift new file mode 100644 index 00000000..5e377b9a --- /dev/null +++ b/Sources/OpenAI/Public/Models/RunRetrieveResult.swift @@ -0,0 +1,13 @@ +// +// RunsResult.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct RunRetreiveResult: Codable, Equatable { + + public let status: String +} diff --git a/Sources/OpenAI/Public/Models/RunRetrieveStepsResult.swift b/Sources/OpenAI/Public/Models/RunRetrieveStepsResult.swift new file mode 100644 index 00000000..5bbd47bb --- /dev/null +++ b/Sources/OpenAI/Public/Models/RunRetrieveStepsResult.swift @@ -0,0 +1,55 @@ +// +// RunRetreiveStepsResult.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct RunRetreiveStepsResult: Codable, Equatable { + + public struct StepDetailsTopLevel: Codable, Equatable { + public let id: String + public let stepDetails: StepDetailsSecondLevel + + enum CodingKeys: String, CodingKey { + case id + case stepDetails = "step_details" + } + + public struct StepDetailsSecondLevel: Codable, Equatable { + + public let toolCalls: [ToolCall]? + + enum CodingKeys: String, CodingKey { + case toolCalls = "tool_calls" + } + + public struct ToolCall: Codable, Equatable { + public let id: String + public let type: String + public let code: CodeToolCall? + + enum CodingKeys: String, CodingKey { + case id + case type + case code = "code_interpreter" + } + + public struct CodeToolCall: Codable, Equatable { + public let input: String + public let outputs: [CodeToolCallOutput]? + + public struct CodeToolCallOutput: Codable, Equatable { + public let type: String + public let logs: String? + + } + } + } + } + } + + public let data: [StepDetailsTopLevel] +} diff --git a/Sources/OpenAI/Public/Models/RunsQuery.swift b/Sources/OpenAI/Public/Models/RunsQuery.swift new file mode 100644 index 00000000..b9ba2ad6 --- /dev/null +++ b/Sources/OpenAI/Public/Models/RunsQuery.swift @@ -0,0 +1,22 @@ +// +// AssistantsQuery.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct RunsQuery: Codable { + + public let assistantId: String + + enum CodingKeys: String, CodingKey { + case assistantId = "assistant_id" + } + + public init(assistantId: String) { + + self.assistantId = assistantId + } +} diff --git a/Sources/OpenAI/Public/Models/RunsResult.swift b/Sources/OpenAI/Public/Models/RunsResult.swift new file mode 100644 index 00000000..858f15f5 --- /dev/null +++ b/Sources/OpenAI/Public/Models/RunsResult.swift @@ -0,0 +1,13 @@ +// +// RunsResult.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct RunsResult: Codable, Equatable { + + public let id: String +} diff --git a/Sources/OpenAI/Public/Models/ThreadAddMessageQuery.swift b/Sources/OpenAI/Public/Models/ThreadAddMessageQuery.swift new file mode 100644 index 00000000..153851e2 --- /dev/null +++ b/Sources/OpenAI/Public/Models/ThreadAddMessageQuery.swift @@ -0,0 +1,24 @@ +// +// ThreadAddMessageQuery.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct ThreadAddMessageQuery: Equatable, Codable { + public let role: String + public let content: String + + enum CodingKeys: String, CodingKey { + case role + case content + + } + + public init(role: String, content: String) { + self.role = role + self.content = content + } +} diff --git a/Sources/OpenAI/Public/Models/ThreadAddMessagesResult.swift b/Sources/OpenAI/Public/Models/ThreadAddMessagesResult.swift new file mode 100644 index 00000000..f39736ef --- /dev/null +++ b/Sources/OpenAI/Public/Models/ThreadAddMessagesResult.swift @@ -0,0 +1,13 @@ +// +// ThreadsMessagesResult.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct ThreadAddMessageResult: Codable, Equatable { + public let id: String + +} diff --git a/Sources/OpenAI/Public/Models/ThreadsMessagesResult.swift b/Sources/OpenAI/Public/Models/ThreadsMessagesResult.swift new file mode 100644 index 00000000..975f897f --- /dev/null +++ b/Sources/OpenAI/Public/Models/ThreadsMessagesResult.swift @@ -0,0 +1,65 @@ +// +// ThreadsMessagesResult.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct ThreadsMessagesResult: Codable, Equatable { + + public struct ThreadsMessage: Codable, Equatable { + + public struct ThreadsMessageContent: Codable, Equatable { + + public struct ThreadsMessageContentText: Codable, Equatable { + + public let value: String? + + enum CodingKeys: String, CodingKey { + case value + } + } + + public struct ImageFileContentText: Codable, Equatable { + + public let fildId: String + + enum CodingKeys: String, CodingKey { + case fildId = "file_id" + } + } + + public let type: String + public let text: ThreadsMessageContentText? + public let imageFile: ThreadsMessageContentText? + + enum CodingKeys: String, CodingKey { + case type + case text + case imageFile = "image_file" + } + } + + public let id: String + + public let role: String + + public let content: [ThreadsMessageContent] + + enum CodingKeys: String, CodingKey { + case id + case content + case role + } + } + + + public let data: [ThreadsMessage] + + enum CodingKeys: String, CodingKey { + case data + } + +} diff --git a/Sources/OpenAI/Public/Models/ThreadsQuery.swift b/Sources/OpenAI/Public/Models/ThreadsQuery.swift new file mode 100644 index 00000000..1f27849f --- /dev/null +++ b/Sources/OpenAI/Public/Models/ThreadsQuery.swift @@ -0,0 +1,20 @@ +// +// ThreadsQuery.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct ThreadsQuery: Equatable, Codable { + public let messages: [Chat] + + enum CodingKeys: String, CodingKey { + case messages + } + + public init(messages: [Chat]) { + self.messages = messages + } +} diff --git a/Sources/OpenAI/Public/Models/ThreadsResult.swift b/Sources/OpenAI/Public/Models/ThreadsResult.swift new file mode 100644 index 00000000..def6031c --- /dev/null +++ b/Sources/OpenAI/Public/Models/ThreadsResult.swift @@ -0,0 +1,13 @@ +// +// AssistantsResult.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct ThreadsResult: Codable, Equatable { + + public let id: String +} diff --git a/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Async.swift b/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Async.swift index b515a234..7428e34f 100644 --- a/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Async.swift +++ b/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Async.swift @@ -228,4 +228,149 @@ public extension OpenAIProtocol { } } } + + // 1106 + func assistants( + query: AssistantsQuery?, + method: String, + after: String? + ) async throws -> AssistantsResult { + try await withCheckedThrowingContinuation { continuation in + assistants(query: query, method: method, after: after) { result in + switch result { + case let .success(success): + return continuation.resume(returning: success) + case let .failure(failure): + return continuation.resume(throwing: failure) + } + } + } + } + + func assistantModify( + query: AssistantsQuery, + asstId: String + ) async throws -> AssistantsResult { + try await withCheckedThrowingContinuation { continuation in + assistantModify(query: query, asstId: asstId) { result in + switch result { + case let .success(success): + return continuation.resume(returning: success) + case let .failure(failure): + return continuation.resume(throwing: failure) + } + } + } + } + + func threads( + query: ThreadsQuery + ) async throws -> ThreadsResult { + try await withCheckedThrowingContinuation { continuation in + threads(query: query) { result in + switch result { + case let .success(success): + return continuation.resume(returning: success) + case let .failure(failure): + return continuation.resume(throwing: failure) + } + } + } + } + + func runs( + threadId: String, + query: RunsQuery + ) async throws -> RunsResult { + try await withCheckedThrowingContinuation { continuation in + runs(threadId: threadId, query: query) { result in + switch result { + case let .success(success): + return continuation.resume(returning: success) + case let .failure(failure): + return continuation.resume(throwing: failure) + } + } + } + } + + func runRetrieve( + threadId: String, + runId: String + ) async throws -> RunRetreiveResult { + try await withCheckedThrowingContinuation { continuation in + runRetrieve(threadId: threadId, runId: runId) { result in + switch result { + case let .success(success): + return continuation.resume(returning: success) + case let .failure(failure): + return continuation.resume(throwing: failure) + } + } + } + } + + func runRetrieveSteps( + threadId: String, + runId: String, + before: String? + ) async throws -> RunRetreiveStepsResult { + try await withCheckedThrowingContinuation { continuation in + runRetrieveSteps(threadId: threadId, runId: runId, before: before) { result in + switch result { + case let .success(success): + return continuation.resume(returning: success) + case let .failure(failure): + return continuation.resume(throwing: failure) + } + } + } + } + + func threadsMessages( + threadId: String, + before: String? + ) async throws -> ThreadsMessagesResult { + try await withCheckedThrowingContinuation { continuation in + threadsMessages(threadId: threadId, before: before) { result in + switch result { + case let .success(success): + return continuation.resume(returning: success) + case let .failure(failure): + return continuation.resume(throwing: failure) + } + } + } + } + + func threadsAddMessage( + threadId: String, + query: ThreadAddMessageQuery + ) async throws -> ThreadAddMessageResult { + try await withCheckedThrowingContinuation { continuation in + threadsAddMessage(threadId: threadId, query: query) { result in + switch result { + case let .success(success): + return continuation.resume(returning: success) + case let .failure(failure): + return continuation.resume(throwing: failure) + } + } + } + } + func files( + query: FilesQuery + ) async throws -> FilesResult { + try await withCheckedThrowingContinuation { continuation in + files(query: query) { result in + switch result { + case let .success(success): + return continuation.resume(returning: success) + case let .failure(failure): + return continuation.resume(throwing: failure) + } + } + } + } + // 1106 end } diff --git a/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Combine.swift b/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Combine.swift index da8b7dfb..0853a654 100644 --- a/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Combine.swift +++ b/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Combine.swift @@ -126,6 +126,44 @@ public extension OpenAIProtocol { } .eraseToAnyPublisher() } + + // 1106 + func assistants(query: AssistantsQuery?, method: String, after: String?) -> AnyPublisher { + Future { + assistants(query: query, method: method, after: after, completion: $0) + } + .eraseToAnyPublisher() + } + + func threads(query: ThreadsQuery) -> AnyPublisher { + Future { + threads(query: query, completion: $0) + } + .eraseToAnyPublisher() + } + + func runs(threadId: String, query: RunsQuery) -> AnyPublisher { + Future { + runs(threadId: threadId, query: query, completion: $0) + } + .eraseToAnyPublisher() + } + + func runRetrieve(threadId: String, runId: String) -> AnyPublisher { + Future { + + runRetrieve(threadId: threadId, runId: runId, completion: $0) + } + .eraseToAnyPublisher() + } + + func threadsMessages(threadId: String, before: String?) -> AnyPublisher { + Future { + threadsMessages(threadId: threadId, before: before, completion: $0) + } + .eraseToAnyPublisher() + } + // 1106 end } #endif diff --git a/Sources/OpenAI/Public/Protocols/OpenAIProtocol.swift b/Sources/OpenAI/Public/Protocols/OpenAIProtocol.swift index caf97090..f178bd2b 100644 --- a/Sources/OpenAI/Public/Protocols/OpenAIProtocol.swift +++ b/Sources/OpenAI/Public/Protocols/OpenAIProtocol.swift @@ -247,4 +247,174 @@ public protocol OpenAIProtocol { Returns a `Result` of type `AudioTranslationResult` if successful, or an `Error` if an error occurs. **/ func audioTranslations(query: AudioTranslationQuery, completion: @escaping (Result) -> Void) + + /// + // The following functions represent new functionality added to OpenAI Beta on 11-06-23 + /// + /// + /** + This function sends a assistants query to the OpenAI API and creates an assistant. The Assistants API in this usage enables you to create an assistant. + + Example: Create Assistant + ``` + let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools: tools, fileIds: fileIds) + openAI.assistants(query: query) { result in + //Handle response here + } + ``` + + Example: List Assistants + ``` + openAI.assistants(query: nil, method: "GET") { result in + //Handle response here + } + ``` + + - Parameter query: The `AssistantsQuery?` instance, containing the information required for the assistant request. Passing nil is used for GET form of request. + - Parameter method: The method to use with the HTTP request. Supports POST (default) and GET. + - Parameter completion: The completion handler to be executed upon completion of the assistant request. + Returns a `Result` of type `AssistantsResult` if successful, or an `Error` if an error occurs. + **/ + func assistants(query: AssistantsQuery?, method: String, after: String?, completion: @escaping (Result) -> Void) + + /** + This function sends a assistants query to the OpenAI API and modifies an assistant. The Assistants API in this usage enables you to modify an assistant. + + Example: Modify Assistant + ``` + let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools: tools, fileIds: fileIds) + openAI.assistantModify(query: query, asstId: "asst_1234") { result in + //Handle response here + } + ``` + + - Parameter query: The `AssistantsQuery` instance, containing the information required for the assistant request. + - Parameter asstId: The assistant id for the assistant to modify. + - Parameter completion: The completion handler to be executed upon completion of the assistant request. + Returns a `Result` of type `AssistantsResult` if successful, or an `Error` if an error occurs. + **/ + func assistantModify(query: AssistantsQuery, asstId: String, completion: @escaping (Result) -> Void) + + /** + This function sends a threads query to the OpenAI API and creates a thread. The Threads API in this usage enables you to create a thread. + + Example: Create Thread + ``` + let threadsQuery = ThreadsQuery(messages: [Chat(role: message.role, content: message.content)]) + openAI.threads(query: threadsQuery) { result in + //Handle response here + } + + ``` + - Parameter query: The `ThreadsQuery` instance, containing the information required for the threads request. + - Parameter completion: The completion handler to be executed upon completion of the threads request. + Returns a `Result` of type `ThreadsResult` if successful, or an `Error` if an error occurs. + **/ + func threads(query: ThreadsQuery, completion: @escaping (Result) -> Void) + + /** + This function sends a runs query to the OpenAI API and creates a run. The Runs API in this usage enables you to create a run. + + Example: Create Run + ``` + let runsQuery = RunsQuery(assistantId: currentAssistantId) + openAI.runs(threadId: threadsResult.id, query: runsQuery) { result in + //Handle response here + } + ``` + + - Parameter threadId: The thread id for the thread to run. + - Parameter query: The `RunsQuery` instance, containing the information required for the runs request. + - Parameter completion: The completion handler to be executed upon completion of the runs request. + Returns a `Result` of type `RunsResult` if successful, or an `Error` if an error occurs. + **/ + func runs(threadId: String, query: RunsQuery, completion: @escaping (Result) -> Void) + + /** + This function sends a thread id and run id to the OpenAI API and retrieves a run. The Runs API in this usage enables you to retrieve a run. + + Example: Retrieve Run + ``` + openAI.runRetrieve(threadId: currentThreadId, runId: currentRunId) { result in + //Handle response here + } + ``` + - Parameter threadId: The thread id for the thread to run. + - Parameter runId: The run id for the run to retrieve. + - Parameter completion: The completion handler to be executed upon completion of the runRetrieve request. + Returns a `Result` of type `RunRetreiveResult` if successful, or an `Error` if an error occurs. + **/ + func runRetrieve(threadId: String, runId: String, completion: @escaping (Result) -> Void) + + /** + This function sends a thread id and run id to the OpenAI API and retrieves a list of run steps. The Runs API in this usage enables you to retrieve a runs run steps. + + Example: Retrieve Run Steps + ``` + openAI.runRetrieveSteps(threadId: currentThreadId, runId: currentRunId) { result in + //Handle response here + } + ``` + - Parameter threadId: The thread id for the thread to run. + - Parameter runId: The run id for the run to retrieve. + - Parameter before: String?: The message id for the run step that defines your place in the list of run steps. Pass nil to get all. + - Parameter completion: The completion handler to be executed upon completion of the runRetrieve request. + Returns a `Result` of type `RunRetreiveStepsResult` if successful, or an `Error` if an error occurs. + **/ + func runRetrieveSteps(threadId: String, runId: String, before: String?, completion: @escaping (Result) -> Void) + + + /** + This function sends a thread id and run id to the OpenAI API and retrieves a threads messages. + The Thread API in this usage enables you to retrieve a threads messages. + + Example: Get Threads Messages + ``` + openAI.threadsMessages(threadId: currentThreadId, before: nil) { result in + //Handle response here + } + ``` + + - Parameter threadId: The thread id for the thread to run. + - Parameter before: String?: The message id for the message that defines your place in the list of messages. Pass nil to get all. + - Parameter completion: The completion handler to be executed upon completion of the runRetrieve request. + Returns a `Result` of type `ThreadsMessagesResult` if successful, or an `Error` if an error occurs. + **/ + func threadsMessages(threadId: String, before: String?, completion: @escaping (Result) -> Void) + + /** + This function sends a thread id and message contents to the OpenAI API and returns a run. + + Example: Add Message to Thread + ``` + let query = ThreadAddMessageQuery(role: message.role.rawValue, content: message.content) + openAI.threadsAddMessage(threadId: currentThreadId, query: query) { result in + //Handle response here + } + ``` + + - Parameter threadId: The thread id for the thread to run. + - Parameter query: The `ThreadAddMessageQuery` instance, containing the information required for the threads request. + - Parameter completion: The completion handler to be executed upon completion of the runRetrieve request. + Returns a `Result` of type `ThreadAddMessageResult` if successful, or an `Error` if an error occurs. + **/ + func threadsAddMessage(threadId: String, query: ThreadAddMessageQuery, completion: @escaping (Result) -> Void) + + /** + This function sends a purpose string, file contents, and fileName contents to the OpenAI API and returns a file id result. + + Example: Upload file + ``` + let query = FilesQuery(purpose: "assistants", file: fileData, fileName: url.lastPathComponent, contentType: "application/pdf") + openAI.files(query: query) { result in + //Handle response here + } + ``` + - Parameter query: The `FilesQuery` instance, containing the information required for the files request. + - Parameter completion: The completion handler to be executed upon completion of the files request. + Returns a `Result` of type `FilesResult` if successful, or an `Error` if an error occurs. + **/ + func files(query: FilesQuery, completion: @escaping (Result) -> Void) + + // END new functionality added to OpenAI Beta on 11-06-23 end } diff --git a/Tests/OpenAITests/OpenAITests.swift b/Tests/OpenAITests/OpenAITests.swift index a66e9a45..a31373b8 100644 --- a/Tests/OpenAITests/OpenAITests.swift +++ b/Tests/OpenAITests/OpenAITests.swift @@ -33,7 +33,7 @@ class OpenAITests: XCTestCase { let result = try await openAI.completions(query: query) XCTAssertEqual(result, expectedResult) } - + func testCompletionsAPIError() async throws { let query = CompletionsQuery(model: .textDavinci_003, prompt: "What is 42?", temperature: 0, maxTokens: 100, topP: 1, frequencyPenalty: 0, presencePenalty: 0, stop: ["\\n"]) let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") @@ -357,6 +357,133 @@ class OpenAITests: XCTestCase { let completionsURL = openAI.buildURL(path: .completions) XCTAssertEqual(completionsURL, URL(string: "https://my.host.com/v1/completions")) } + + // 1106 + func testAssistantQuery() async throws { + let query = AssistantsQuery(model: .gpt4_1106_preview, name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: []) + let expectedResult = AssistantsResult(id: "asst_1234", data: [AssistantsResult.AssistantContent(id: "asst_9876", name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: nil, fileIds: nil)], tools: []) + try self.stub(result: expectedResult) + + let result = try await openAI.assistants(query: query, method: "POST", after: nil) + XCTAssertEqual(result, expectedResult) + } + + func testAssistantQueryError() async throws { + let query = AssistantsQuery(model: .gpt4_1106_preview, name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: []) + + let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") + self.stub(error: inError) + + let apiError: APIError = try await XCTExpectError { try await openAI.assistants(query: query, method: "POST", after: nil) } + XCTAssertEqual(inError, apiError) + } + + func testListAssistantQuery() async throws { + let expectedResult = AssistantsResult(id: nil, data: [AssistantsResult.AssistantContent(id: "asst_9876", name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: nil, fileIds: nil)], tools: nil) + try self.stub(result: expectedResult) + + let result = try await openAI.assistants(query: nil, method: "GET", after: nil) + XCTAssertEqual(result, expectedResult) + } + + func testListAssistantQueryError() async throws { + let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") + self.stub(error: inError) + + let apiError: APIError = try await XCTExpectError { try await openAI.assistants(query: nil, method: "GET", after: nil) } + XCTAssertEqual(inError, apiError) + } + + func testThreadsQuery() async throws { + let query = ThreadsQuery(messages: [Chat(role: .user, content: "Hello, What is AI?")]) + let expectedResult = ThreadsResult(id: "thread_1234") + try self.stub(result: expectedResult) + + let result = try await openAI.threads(query: query) + XCTAssertEqual(result, expectedResult) + } + + func testThreadsQueryError() async throws { + let query = ThreadsQuery(messages: [Chat(role: .user, content: "Hello, What is AI?")]) + + let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") + self.stub(error: inError) + + let apiError: APIError = try await XCTExpectError { try await openAI.threads(query: query) } + XCTAssertEqual(inError, apiError) + } + + func testRunsQuery() async throws { + let query = RunsQuery(assistantId: "asst_7654321") + let expectedResult = RunsResult(id: "run_1234") + try self.stub(result: expectedResult) + + let result = try await openAI.runs(threadId: "thread_1234", query: query) + XCTAssertEqual(result, expectedResult) + } + + func testRunsQueryError() async throws { + let query = RunsQuery(assistantId: "asst_7654321") + let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") + self.stub(error: inError) + + let apiError: APIError = try await XCTExpectError { try await openAI.runs(threadId: "thread_1234", query: query) } + XCTAssertEqual(inError, apiError) + } + + func testRunRetrieveQuery() async throws { + let expectedResult = RunRetreiveResult(status: "in_progress") + try self.stub(result: expectedResult) + + let result = try await openAI.runRetrieve(threadId: "thread_1234", runId: "run_1234") + XCTAssertEqual(result, expectedResult) + } + + func testRunRetrieveQueryError() async throws { + let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") + self.stub(error: inError) + + let apiError: APIError = try await XCTExpectError { try await openAI.runRetrieve(threadId: "thread_1234", runId: "run_1234") } + XCTAssertEqual(inError, apiError) + } + + func testThreadsMessageQuery() async throws { + let expectedResult = ThreadsMessagesResult(data: [ThreadsMessagesResult.ThreadsMessage(id: "thread_1234", role: Chat.Role.user.rawValue, content: [ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent(type: "text", text: ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent.ThreadsMessageContentText(value: "Hello, What is AI?"))])]) + try self.stub(result: expectedResult) + + let result = try await openAI.threadsMessages(threadId: "thread_1234", before: nil) + XCTAssertEqual(result, expectedResult) + } + + func testThreadsMessageQueryError() async throws { + let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") + self.stub(error: inError) + + let apiError: APIError = try await XCTExpectError { try await openAI.threadsMessages(threadId: "thread_1234", before: nil) } + XCTAssertEqual(inError, apiError) + } + + func testCustomRunsURLBuilt() { + let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", host: "my.host.com", timeoutInterval: 14) + let openAI = OpenAI(configuration: configuration, session: self.urlSession) + let completionsURL = openAI.buildRunsURL(path: .runs, threadId: "thread_4321") + XCTAssertEqual(completionsURL, URL(string: "https://my.host.com/v1/threads/thread_4321/runs")) + } + + func testCustomRunsRetrieveURLBuilt() { + let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", host: "my.host.com", timeoutInterval: 14) + let openAI = OpenAI(configuration: configuration, session: self.urlSession) + let completionsURL = openAI.buildRunRetrieveURL(path: .runRetrieve, threadId: "thread_4321", runId: "run_1234") + XCTAssertEqual(completionsURL, URL(string: "https://my.host.com/v1/threads/thread_4321/runs/run_1234")) + } + + func testCustomRunRetrieveStepsURLBuilt() { + let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", host: "my.host.com", timeoutInterval: 14) + let openAI = OpenAI(configuration: configuration, session: self.urlSession) + let completionsURL = openAI.buildRunRetrieveURL(path: .runRetrieveSteps, threadId: "thread_4321", runId: "run_1234") + XCTAssertEqual(completionsURL, URL(string: "https://my.host.com/v1/threads/thread_4321/runs/run_1234/steps")) + } + // 1106 end } @available(tvOS 13.0, *) diff --git a/Tests/OpenAITests/OpenAITestsCombine.swift b/Tests/OpenAITests/OpenAITestsCombine.swift index e2b58458..421b1612 100644 --- a/Tests/OpenAITests/OpenAITestsCombine.swift +++ b/Tests/OpenAITests/OpenAITestsCombine.swift @@ -123,6 +123,57 @@ final class OpenAITestsCombine: XCTestCase { let result = try awaitPublisher(openAI.audioTranslations(query: query)) XCTAssertEqual(result, transcriptionResult) } + + // 1106 + func testAssistantQuery() throws { + let query = AssistantsQuery(model: .gpt4_1106_preview, name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: []) + let expectedResult = AssistantsResult(id: "asst_1234", data: [AssistantsResult.AssistantContent(id: "asst_9876", name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: nil, fileIds: nil)], tools: []) + try self.stub(result: expectedResult) + + let result = try awaitPublisher(openAI.assistants(query: query, method: "POST", after: nil)) + XCTAssertEqual(result, expectedResult) + } + + func testThreadsQuery() throws { + let query = ThreadsQuery(messages: [Chat(role: .user, content: "Hello, What is AI?")]) + let expectedResult = ThreadsResult(id: "thread_1234") + + try self.stub(result: expectedResult) + let result = try awaitPublisher(openAI.threads(query: query)) + + XCTAssertEqual(result, expectedResult) + } + + func testRunsQuery() throws { + let query = RunsQuery(assistantId: "asst_7654321") + let expectedResult = RunsResult(id: "run_1234") + + try self.stub(result: expectedResult) + let result = try awaitPublisher(openAI.runs(threadId: "thread_1234", query: query)) + + XCTAssertEqual(result, expectedResult) + } + + func testRunRetrieveQuery() throws { + let expectedResult = RunRetreiveResult(status: "in_progress") + try self.stub(result: expectedResult) + + let result = try awaitPublisher(openAI.runRetrieve(threadId: "thread_1234", runId: "run_1234")) + + XCTAssertEqual(result, expectedResult) + } + + func testThreadsMessageQuery() throws { + let expectedResult = ThreadsMessagesResult(data: [ThreadsMessagesResult.ThreadsMessage(id: "thread_1234", role: Chat.Role.user.rawValue, content: [ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent(type: "text", text: ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent.ThreadsMessageContentText(value: "Hello, What is AI?"))])]) + try self.stub(result: expectedResult) + + let result = try awaitPublisher(openAI.threadsMessages(threadId: "thread_1234", before: nil)) + + XCTAssertEqual(result, expectedResult) + } + // 1106 end + + } @available(tvOS 13.0, *) From 724b2a144747efe4ea166fc3239afc41a0d6e7fe Mon Sep 17 00:00:00 2001 From: Chris Dillard Date: Mon, 18 Dec 2023 12:52:29 -0700 Subject: [PATCH 2/6] Fix tests --- Tests/OpenAITests/OpenAITests.swift | 2 +- Tests/OpenAITests/OpenAITestsCombine.swift | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Tests/OpenAITests/OpenAITests.swift b/Tests/OpenAITests/OpenAITests.swift index a31373b8..11bb3f29 100644 --- a/Tests/OpenAITests/OpenAITests.swift +++ b/Tests/OpenAITests/OpenAITests.swift @@ -448,7 +448,7 @@ class OpenAITests: XCTestCase { } func testThreadsMessageQuery() async throws { - let expectedResult = ThreadsMessagesResult(data: [ThreadsMessagesResult.ThreadsMessage(id: "thread_1234", role: Chat.Role.user.rawValue, content: [ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent(type: "text", text: ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent.ThreadsMessageContentText(value: "Hello, What is AI?"))])]) + let expectedResult = ThreadsMessagesResult(data: [ThreadsMessagesResult.ThreadsMessage(id: "thread_1234", role: Chat.Role.user.rawValue, content: [ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent(type: "text", text: ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent.ThreadsMessageContentText(value: "Hello, What is AI?"), imageFile: nil)])]) try self.stub(result: expectedResult) let result = try await openAI.threadsMessages(threadId: "thread_1234", before: nil) diff --git a/Tests/OpenAITests/OpenAITestsCombine.swift b/Tests/OpenAITests/OpenAITestsCombine.swift index 421b1612..a306ea00 100644 --- a/Tests/OpenAITests/OpenAITestsCombine.swift +++ b/Tests/OpenAITests/OpenAITestsCombine.swift @@ -164,7 +164,7 @@ final class OpenAITestsCombine: XCTestCase { } func testThreadsMessageQuery() throws { - let expectedResult = ThreadsMessagesResult(data: [ThreadsMessagesResult.ThreadsMessage(id: "thread_1234", role: Chat.Role.user.rawValue, content: [ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent(type: "text", text: ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent.ThreadsMessageContentText(value: "Hello, What is AI?"))])]) + let expectedResult = ThreadsMessagesResult(data: [ThreadsMessagesResult.ThreadsMessage(id: "thread_1234", role: Chat.Role.user.rawValue, content: [ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent(type: "text", text: ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent.ThreadsMessageContentText(value: "Hello, What is AI?"), imageFile: nil)])]) try self.stub(result: expectedResult) let result = try awaitPublisher(openAI.threadsMessages(threadId: "thread_1234", before: nil)) From a3d06b1d733fe4e51bbf76783dd7da103b52cb25 Mon Sep 17 00:00:00 2001 From: Chris Dillard Date: Mon, 18 Dec 2023 14:03:24 -0700 Subject: [PATCH 3/6] Remove added debugging code --- Sources/OpenAI/OpenAI.swift | 4 ---- 1 file changed, 4 deletions(-) diff --git a/Sources/OpenAI/OpenAI.swift b/Sources/OpenAI/OpenAI.swift index f9eaf606..b1beaa7d 100644 --- a/Sources/OpenAI/OpenAI.swift +++ b/Sources/OpenAI/OpenAI.swift @@ -175,9 +175,6 @@ extension OpenAI { var apiError: Error? = nil do { - - let errorText = String(data: data, encoding: .utf8) - let decoded = try JSONDecoder().decode(ResultType.self, from: data) completion(.success(decoded)) } catch { @@ -209,7 +206,6 @@ extension OpenAI { onResult(.success(object)) } session.onProcessingError = {_, error in - print("OpenAI API error = \(error.localizedDescription)") onResult(.failure(error)) } session.onComplete = { [weak self] object, error in From b789eed49bd2b9aad7b33b6c0dfb2cf60ad869f7 Mon Sep 17 00:00:00 2001 From: Chris Dillard Date: Wed, 20 Dec 2023 16:59:57 -0700 Subject: [PATCH 4/6] Fix filesResult --- Sources/OpenAI/Public/Models/FilesResult.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/OpenAI/Public/Models/FilesResult.swift b/Sources/OpenAI/Public/Models/FilesResult.swift index 0799d8c4..d5012d46 100644 --- a/Sources/OpenAI/Public/Models/FilesResult.swift +++ b/Sources/OpenAI/Public/Models/FilesResult.swift @@ -10,6 +10,6 @@ import Foundation public struct FilesResult: Codable, Equatable { public let id: String - public let name: String + public let name: String? } From a0d1126869423d5dbc057449418ff406410b1517 Mon Sep 17 00:00:00 2001 From: Brent Whitman Date: Tue, 30 Jan 2024 09:39:31 -0800 Subject: [PATCH 5/6] Add support for Assistants API features: function calls; create & run thread in a single request; submit tool outputs. --- Demo/DemoChat/Sources/AssistantStore.swift | 37 +- Demo/DemoChat/Sources/ChatStore.swift | 125 ++++--- Demo/DemoChat/Sources/Models/Assistant.swift | 13 +- .../UI/AssistantModalContentView.swift | 130 +++++-- .../Sources/UI/AssistantsListView.swift | 13 +- Demo/DemoChat/Sources/UI/AssistantsView.swift | 137 ++++---- Demo/DemoChat/Sources/UI/FunctionView.swift | 72 ++++ README.md | 109 +++--- Sources/OpenAI/OpenAI.swift | 40 ++- .../Public/Models/AssistantResult.swift | 26 ++ .../Public/Models/AssistantsQuery.swift | 31 +- .../Public/Models/AssistantsResult.swift | 32 +- Sources/OpenAI/Public/Models/ChatQuery.swift | 198 ++++++++--- Sources/OpenAI/Public/Models/ChatResult.swift | 2 +- .../Public/Models/FunctionDeclaration.swift | 25 ++ .../OpenAI/Public/Models/ImagesQuery.swift | 2 +- .../OpenAI/Public/Models/MessageQuery.swift | 26 ++ Sources/OpenAI/Public/Models/RunResult.swift | 55 +++ .../Public/Models/RunRetrieveResult.swift | 13 - .../Models/RunRetrieveStepsResult.swift | 31 +- .../Public/Models/RunToolOutputsQuery.swift | 35 ++ Sources/OpenAI/Public/Models/RunsResult.swift | 13 - .../Public/Models/ThreadAddMessageQuery.swift | 24 -- .../OpenAI/Public/Models/ThreadRunQuery.swift | 39 +++ .../Public/Models/ThreadsMessagesResult.swift | 13 +- .../OpenAI/Public/Models/ThreadsQuery.swift | 4 +- Sources/OpenAI/Public/Models/Tool.swift | 64 ++++ .../Protocols/OpenAIProtocol+Async.swift | 71 +++- .../Protocols/OpenAIProtocol+Combine.swift | 65 +++- .../Public/Protocols/OpenAIProtocol.swift | 84 +++-- Tests/OpenAITests/OpenAITests.swift | 177 ++++++++-- Tests/OpenAITests/OpenAITestsCombine.swift | 90 ++++- Tests/OpenAITests/OpenAITestsDecoder.swift | 318 +++++++++++++++--- 33 files changed, 1608 insertions(+), 506 deletions(-) create mode 100644 Demo/DemoChat/Sources/UI/FunctionView.swift create mode 100644 Sources/OpenAI/Public/Models/AssistantResult.swift create mode 100644 Sources/OpenAI/Public/Models/FunctionDeclaration.swift create mode 100644 Sources/OpenAI/Public/Models/MessageQuery.swift create mode 100644 Sources/OpenAI/Public/Models/RunResult.swift delete mode 100644 Sources/OpenAI/Public/Models/RunRetrieveResult.swift create mode 100644 Sources/OpenAI/Public/Models/RunToolOutputsQuery.swift delete mode 100644 Sources/OpenAI/Public/Models/RunsResult.swift delete mode 100644 Sources/OpenAI/Public/Models/ThreadAddMessageQuery.swift create mode 100644 Sources/OpenAI/Public/Models/ThreadRunQuery.swift create mode 100644 Sources/OpenAI/Public/Models/Tool.swift diff --git a/Demo/DemoChat/Sources/AssistantStore.swift b/Demo/DemoChat/Sources/AssistantStore.swift index 1393fefb..34e597ac 100644 --- a/Demo/DemoChat/Sources/AssistantStore.swift +++ b/Demo/DemoChat/Sources/AssistantStore.swift @@ -27,11 +27,11 @@ public final class AssistantStore: ObservableObject { // MARK: Models @MainActor - func createAssistant(name: String, description: String, instructions: String, codeInterpreter: Bool, retrievel: Bool, fileIds: [String]? = nil) async -> String? { + func createAssistant(name: String, description: String, instructions: String, codeInterpreter: Bool, retrieval: Bool, functions: [FunctionDeclaration], fileIds: [String]? = nil) async -> String? { do { - let tools = createToolsArray(codeInterpreter: codeInterpreter, retrieval: retrievel) + let tools = createToolsArray(codeInterpreter: codeInterpreter, retrieval: retrieval, functions: functions) let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools:tools, fileIds: fileIds) - let response = try await openAIClient.assistants(query: query, method: "POST", after: nil) + let response = try await openAIClient.assistantCreate(query: query) // Refresh assistants with one just created (or modified) let _ = await getAssistants() @@ -47,11 +47,11 @@ public final class AssistantStore: ObservableObject { } @MainActor - func modifyAssistant(asstId: String, name: String, description: String, instructions: String, codeInterpreter: Bool, retrievel: Bool, fileIds: [String]? = nil) async -> String? { + func modifyAssistant(asstId: String, name: String, description: String, instructions: String, codeInterpreter: Bool, retrieval: Bool, functions: [FunctionDeclaration], fileIds: [String]? = nil) async -> String? { do { - let tools = createToolsArray(codeInterpreter: codeInterpreter, retrieval: retrievel) + let tools = createToolsArray(codeInterpreter: codeInterpreter, retrieval: retrieval, functions: functions) let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools:tools, fileIds: fileIds) - let response = try await openAIClient.assistantModify(query: query, asstId: asstId) + let response = try await openAIClient.assistantModify(query: query, assistantId: asstId) // Returns assistantId return response.id @@ -66,15 +66,24 @@ public final class AssistantStore: ObservableObject { @MainActor func getAssistants(limit: Int = 20, after: String? = nil) async -> [Assistant] { do { - let response = try await openAIClient.assistants(query: nil, method: "GET", after: after) + let response = try await openAIClient.assistants(after: after) var assistants = [Assistant]() for result in response.data ?? [] { - let codeInterpreter = result.tools?.filter { $0.toolType == "code_interpreter" }.first != nil - let retrieval = result.tools?.filter { $0.toolType == "retrieval" }.first != nil + let tools = result.tools ?? [] + let codeInterpreter = tools.contains { $0 == .codeInterpreter } + let retrieval = tools.contains { $0 == .retrieval } + let functions = tools.compactMap { + switch $0 { + case let .function(declaration): + return declaration + default: + return nil + } + } let fileIds = result.fileIds ?? [] - assistants.append(Assistant(id: result.id, name: result.name, description: result.description, instructions: result.instructions, codeInterpreter: codeInterpreter, retrieval: retrieval, fileIds: fileIds)) + assistants.append(Assistant(id: result.id, name: result.name ?? "", description: result.description, instructions: result.instructions, codeInterpreter: codeInterpreter, retrieval: retrieval, fileIds: fileIds, functions: functions)) } if after == nil { availableAssistants = assistants @@ -112,14 +121,14 @@ public final class AssistantStore: ObservableObject { } } - func createToolsArray(codeInterpreter: Bool, retrieval: Bool) -> [Tool] { + func createToolsArray(codeInterpreter: Bool, retrieval: Bool, functions: [FunctionDeclaration]) -> [Tool] { var tools = [Tool]() if codeInterpreter { - tools.append(Tool(toolType: "code_interpreter")) + tools.append(.codeInterpreter) } if retrieval { - tools.append(Tool(toolType: "retrieval")) + tools.append(.retrieval) } - return tools + return tools + functions.map { .function($0) } } } diff --git a/Demo/DemoChat/Sources/ChatStore.swift b/Demo/DemoChat/Sources/ChatStore.swift index 99f62696..113f9b8b 100644 --- a/Demo/DemoChat/Sources/ChatStore.swift +++ b/Demo/DemoChat/Sources/ChatStore.swift @@ -91,7 +91,7 @@ public final class ChatStore: ObservableObject { conversations[conversationIndex].messages.append(localMessage) do { - let threadsQuery = ThreadsQuery(messages: [Chat(role: message.role, content: message.content)]) + let threadsQuery = ThreadsQuery(messages: [MessageQuery(role: .user, content: message.content)]) let threadsResult = try await openAIClient.threads(query: threadsQuery) guard let currentAssistantId = conversations[conversationIndex].assistantId else { return print("No assistant selected.")} @@ -117,7 +117,7 @@ public final class ChatStore: ObservableObject { guard let currentThreadId else { return print("No thread to add message to.")} let _ = try await openAIClient.threadsAddMessage(threadId: currentThreadId, - query: ThreadAddMessageQuery(role: message.role.rawValue, content: message.content)) + query: MessageQuery(role: message.role, content: message.content)) guard let currentAssistantId = conversations[conversationIndex].assistantId else { return print("No assistant selected.")} @@ -150,7 +150,7 @@ public final class ChatStore: ObservableObject { return } - let weatherFunction = ChatFunctionDeclaration( + let weatherFunction = FunctionDeclaration( name: "getWeatherData", description: "Get the current weather in a given location", parameters: .init( @@ -243,19 +243,19 @@ public final class ChatStore: ObservableObject { let result = try await openAIClient.runRetrieve(threadId: currentThreadId ?? "", runId: currentRunId ?? "") // TESTING RETRIEVAL OF RUN STEPS - handleRunRetrieveSteps() + try await handleRunRetrieveSteps() switch result.status { // Get threadsMesages. - case "completed": + case .completed: handleCompleted() - break - case "failed": + case .failed: // Handle more gracefully with a popup dialog or failure indicator await MainActor.run { self.stopPolling() } - break + case .requiresAction: + try await handleRequiresAction(result) default: // Handle additional statuses "requires_action", "queued" ?, "expired", "cancelled" // https://platform.openai.com/docs/assistants/how-it-works/runs-and-run-steps @@ -287,7 +287,7 @@ public final class ChatStore: ObservableObject { for innerItem in item.content { let message = Message( id: item.id, - role: Chat.Role(rawValue: role) ?? .user, + role: role, content: innerItem.text?.value ?? "", createdAt: Date(), isLocal: false // Messages from the server are not local @@ -308,54 +308,89 @@ public final class ChatStore: ObservableObject { } } + // Store the function call as a message and submit tool outputs with a simple done message. + private func handleRequiresAction(_ result: RunResult) async throws { + guard let currentThreadId, let currentRunId else { + return + } + + guard let toolCalls = result.requiredAction?.submitToolOutputs.toolCalls else { + return + } + + var toolOutputs = [RunToolOutputsQuery.ToolOutput]() + + for toolCall in toolCalls { + let msgContent = "function\nname: \(toolCall.function.name ?? "")\nargs: \(toolCall.function.arguments ?? "{}")" + + let runStepMessage = Message( + id: toolCall.id, + role: .assistant, + content: msgContent, + createdAt: Date(), + isRunStep: true + ) + await addOrUpdateRunStepMessage(runStepMessage) + + // Just return a generic "Done" output for now + toolOutputs.append(.init(toolCallId: toolCall.id, output: "Done")) + } + + let query = RunToolOutputsQuery(toolOutputs: toolOutputs) + _ = try await openAIClient.runSubmitToolOutputs(threadId: currentThreadId, runId: currentRunId, query: query) + } + // The run retrieval steps are fetched in a separate task. This request is fetched, checking for new run steps, each time the run is fetched. - private func handleRunRetrieveSteps() { - Task { - guard let conversationIndex = conversations.firstIndex(where: { $0.id == currentConversationId }) else { - return - } - var before: String? + private func handleRunRetrieveSteps() async throws { + var before: String? // if let lastRunStepMessage = self.conversations[conversationIndex].messages.last(where: { $0.isRunStep == true }) { // before = lastRunStepMessage.id // } - let stepsResult = try await openAIClient.runRetrieveSteps(threadId: currentThreadId ?? "", runId: currentRunId ?? "", before: before) + let stepsResult = try await openAIClient.runRetrieveSteps(threadId: currentThreadId ?? "", runId: currentRunId ?? "", before: before) - for item in stepsResult.data.reversed() { - let toolCalls = item.stepDetails.toolCalls?.reversed() ?? [] + for item in stepsResult.data.reversed() { + let toolCalls = item.stepDetails.toolCalls?.reversed() ?? [] - for step in toolCalls { - // TODO: Depending on the type of tool tha is used we can add additional information here - // ie: if its a retrieval: add file information, code_interpreter: add inputs and outputs info, or function: add arguemts and additional info. - let msgContent: String - switch step.type { - case "retrieval": - msgContent = "RUN STEP: \(step.type)" + for step in toolCalls { + // TODO: Depending on the type of tool tha is used we can add additional information here + // ie: if its a retrieval: add file information, code_interpreter: add inputs and outputs info, or function: add arguemts and additional info. + let msgContent: String + switch step.type { + case .retrieval: + msgContent = "RUN STEP: \(step.type)" - case "code_interpreter": - msgContent = "code_interpreter\ninput:\n\(step.code?.input ?? "")\noutputs: \(step.code?.outputs?.first?.logs ?? "")" + case .codeInterpreter: + let code = step.codeInterpreter + msgContent = "code_interpreter\ninput:\n\(code?.input ?? "")\noutputs: \(code?.outputs?.first?.logs ?? "")" - default: - msgContent = "RUN STEP: \(step.type)" + case .function: + msgContent = "function\nname: \(step.function?.name ?? "")\nargs: \(step.function?.arguments ?? "{}")" - } - let runStepMessage = Message( - id: step.id, - role: .assistant, - content: msgContent, - createdAt: Date(), - isRunStep: true - ) - await MainActor.run { - if let localMessageIndex = self.conversations[conversationIndex].messages.firstIndex(where: { $0.isRunStep == true && $0.id == step.id }) { - self.conversations[conversationIndex].messages[localMessageIndex] = runStepMessage - } - else { - self.conversations[conversationIndex].messages.append(runStepMessage) - } - } } + let runStepMessage = Message( + id: step.id, + role: .assistant, + content: msgContent, + createdAt: Date(), + isRunStep: true + ) + await addOrUpdateRunStepMessage(runStepMessage) } } } + + @MainActor + private func addOrUpdateRunStepMessage(_ message: Message) async { + guard let conversationIndex = conversations.firstIndex(where: { $0.id == currentConversationId }) else { + return + } + + if let localMessageIndex = conversations[conversationIndex].messages.firstIndex(where: { $0.isRunStep == true && $0.id == message.id }) { + conversations[conversationIndex].messages[localMessageIndex] = message + } + else { + conversations[conversationIndex].messages.append(message) + } + } } diff --git a/Demo/DemoChat/Sources/Models/Assistant.swift b/Demo/DemoChat/Sources/Models/Assistant.swift index f1f8dfee..eb76ad74 100644 --- a/Demo/DemoChat/Sources/Models/Assistant.swift +++ b/Demo/DemoChat/Sources/Models/Assistant.swift @@ -6,9 +6,10 @@ // import Foundation +import OpenAI struct Assistant: Hashable { - init(id: String, name: String, description: String? = nil, instructions: String? = nil, codeInterpreter: Bool, retrieval: Bool, fileIds: [String]? = nil) { + init(id: String, name: String, description: String? = nil, instructions: String? = nil, codeInterpreter: Bool, retrieval: Bool, fileIds: [String]? = nil, functions: [FunctionDeclaration] = []) { self.id = id self.name = name self.description = description @@ -16,6 +17,7 @@ struct Assistant: Hashable { self.codeInterpreter = codeInterpreter self.retrieval = retrieval self.fileIds = fileIds + self.functions = functions } typealias ID = String @@ -27,7 +29,16 @@ struct Assistant: Hashable { let fileIds: [String]? var codeInterpreter: Bool var retrieval: Bool + var functions: [FunctionDeclaration] } extension Assistant: Equatable, Identifiable {} + +extension FunctionDeclaration: Hashable { + public func hash(into hasher: inout Hasher) { + hasher.combine(name) + hasher.combine(description) + hasher.combine(parameters) + } +} diff --git a/Demo/DemoChat/Sources/UI/AssistantModalContentView.swift b/Demo/DemoChat/Sources/UI/AssistantModalContentView.swift index efde0536..508be862 100644 --- a/Demo/DemoChat/Sources/UI/AssistantModalContentView.swift +++ b/Demo/DemoChat/Sources/UI/AssistantModalContentView.swift @@ -6,6 +6,7 @@ // import SwiftUI +import OpenAI struct AssistantModalContentView: View { enum Mode { @@ -19,8 +20,11 @@ struct AssistantModalContentView: View { @Binding var codeInterpreter: Bool @Binding var retrieval: Bool + @Binding var functions: [FunctionDeclaration] @Binding var fileIds: [String] @Binding var isUploading: Bool + @State var isFunctionModalPresented = false + @State var newFunction: FunctionDeclaration? var modify: Bool @@ -34,28 +38,73 @@ struct AssistantModalContentView: View { var onFileUpload: () -> Void var body: some View { - NavigationView { - Form { - Section(header: Text("Name")) { - TextField("Name", text: $name) - } - Section(header: Text("Description")) { - TextEditor(text: $description) - .frame(minHeight: 50) - } - Section(header: Text("Custom Instructions")) { - TextEditor(text: $customInstructions) - .frame(minHeight: 100) - } - + if modify { + form + } else { + NavigationStack { + form + } + } + } + + @ViewBuilder + private var form: some View { + Form { + Section("Name") { + TextField("Name", text: $name) + } + Section("Description") { + TextEditor(text: $description) + .frame(minHeight: 50) + } + Section("Custom Instructions") { + TextEditor(text: $customInstructions) + .frame(minHeight: 100) + } + + Section("Tools") { Toggle(isOn: $codeInterpreter, label: { Text("Code interpreter") }) - + Toggle(isOn: $retrieval, label: { Text("Retrieval") }) + } + + Section("Functions") { + if !functions.isEmpty { + ForEach(functions, id: \.name) { function in + HStack { + VStack(alignment: .leading) { + Text(function.name).fontWeight(.semibold) + if let description = function.description { + Text(description) + .font(.caption) + } + if let parameters = function.parameterJSON { + Text(parameters) + .font(.caption2) + } + } + Spacer() + Button { + if let index = functions.firstIndex(of: function) { + functions.remove(at: index) + } + } label: { + Image(systemName: "xmark.circle.fill") // X button + .foregroundColor(.red) + } + } + } + } + Button("Create Function") { + isFunctionModalPresented = true + } + } + Section("Files") { if !fileIds.isEmpty { ForEach(fileIds, id: \.self) { fileId in HStack { @@ -63,23 +112,23 @@ struct AssistantModalContentView: View { Text("File: \(fileId)") Spacer() // Button to remove fileId from the list of fileIds to be used when create or modify assistant. - Button(action: { + Button { // Add action to remove the file from the list if let index = fileIds.firstIndex(of: fileId) { fileIds.remove(at: index) } - }) { + } label: { Image(systemName: "xmark.circle.fill") // X button .foregroundColor(.red) } } } } - + if let selectedFileURL { HStack { Text("File: \(selectedFileURL.lastPathComponent)") - + Button("Remove") { self.selectedFileURL = nil } @@ -97,16 +146,45 @@ struct AssistantModalContentView: View { } } } - .navigationTitle("\(modify ? "Edit" : "Enter") Assistant Details") - .navigationBarItems( - leading: Button("Cancel") { - dismiss() - }, - trailing: Button("OK") { + } + .navigationTitle("\(modify ? "Edit" : "Enter") Assistant Details") + .toolbar { + if !modify { + ToolbarItem(placement: .cancellationAction) { + Button("Cancel") { + dismiss() + } + } + } + ToolbarItemGroup(placement: .primaryAction) { + Button("Save") { onCommit() dismiss() } - ) + } + } + .sheet(isPresented: $isFunctionModalPresented) { + if let newFunction { + functions.append(newFunction) + self.newFunction = nil + } + } content: { + FunctionView(name: "", description: "", parameters: "", function: $newFunction) + } + } +} + +extension FunctionDeclaration { + var parameterJSON: String? { + guard let parameters else { + return nil + } + + do { + let parameterData = try JSONEncoder().encode(parameters) + return String(data: parameterData, encoding: .utf8) + } catch { + return nil } } } diff --git a/Demo/DemoChat/Sources/UI/AssistantsListView.swift b/Demo/DemoChat/Sources/UI/AssistantsListView.swift index 16377649..3cd59f12 100644 --- a/Demo/DemoChat/Sources/UI/AssistantsListView.swift +++ b/Demo/DemoChat/Sources/UI/AssistantsListView.swift @@ -20,10 +20,15 @@ struct AssistantsListView: View { editActions: [.delete], selection: $selectedAssistantId ) { $assistant in - Text( - assistant.name - ) - .lineLimit(2) + HStack { + Text(assistant.name) + .lineLimit(2) + Spacer() + if assistant.id == selectedAssistantId { + Image(systemName: "checkmark.circle.fill") + .foregroundColor(.accentColor) + } + } .onAppear { if assistant.id == assistants.last?.id { onLoadMoreAssistants() diff --git a/Demo/DemoChat/Sources/UI/AssistantsView.swift b/Demo/DemoChat/Sources/UI/AssistantsView.swift index d2668fdd..23211628 100644 --- a/Demo/DemoChat/Sources/UI/AssistantsView.swift +++ b/Demo/DemoChat/Sources/UI/AssistantsView.swift @@ -7,6 +7,7 @@ import Combine import SwiftUI +import OpenAI public struct AssistantsView: View { @ObservedObject var store: ChatStore @@ -27,6 +28,7 @@ public struct AssistantsView: View { @State private var codeInterpreter: Bool = false @State private var retrieval: Bool = false + @State private var functions: [FunctionDeclaration] = [] @State var isLoadingMore = false @State private var isModalPresented = false @State private var isUploading = false @@ -59,67 +61,83 @@ public struct AssistantsView: View { }, isLoadingMore: $isLoadingMore ) .toolbar { - ToolbarItem( - placement: .primaryAction - ) { - Menu { - Button("Get Assistants") { - Task { - let _ = await assistantStore.getAssistants() - } + ToolbarItemGroup(placement: .primaryAction) { + Button { + mode = .create + isModalPresented = true + } label: { + Label("Create Assistant", systemImage: "plus") + } + Button { + guard let asstId = assistantStore.selectedAssistantId else { + return } - Button("Create Assistant") { - mode = .create - isModalPresented = true + + // Create new local conversation to represent new thread. + store.createConversation(type: .assistant, assistantId: asstId) + } label: { + Label("Start Chat", systemImage: "plus.message") + } + .disabled(assistantStore.selectedAssistantId == nil) + Button { + Task { + let _ = await assistantStore.getAssistants() } } label: { - Image(systemName: "plus") + Label("Get Assistants", systemImage: "arrow.triangle.2.circlepath") } - - .buttonStyle(.borderedProminent) } } } detail: { - + if assistantStore.selectedAssistantId != nil { + assistantContentView() + } else { + Text("Select an assistant") + } } - .sheet(isPresented: $isModalPresented, onDismiss: { + .sheet(isPresented: $isModalPresented) { resetAssistantCreator() - }, content: { - AssistantModalContentView(name: $name, description: $description, customInstructions: $customInstructions, - codeInterpreter: $codeInterpreter, retrieval: $retrieval, fileIds: $fileIds, - isUploading: $isUploading, modify: mode == .modify, isPickerPresented: $isPickerPresented, selectedFileURL: $selectedFileURL) { - Task { - await handleOKTap() - } - } onFileUpload: { - Task { - guard let selectedFileURL else { return } - - isUploading = true - let file = await assistantStore.uploadFile(url: selectedFileURL) - uploadedFileId = file?.id - isUploading = false - - if uploadedFileId == nil { - print("Failed to upload") - self.selectedFileURL = nil - } - else { - // if successful upload , we can show it. - if let uploadedFileId = uploadedFileId { - self.selectedFileURL = nil - - fileIds += [uploadedFileId] + } content: { + assistantContentView() + } + } + } - print("Successful upload!") - } - } + @ViewBuilder + private func assistantContentView() -> some View { + AssistantModalContentView(name: $name, description: $description, customInstructions: $customInstructions, + codeInterpreter: $codeInterpreter, retrieval: $retrieval, functions: $functions, fileIds: $fileIds, + isUploading: $isUploading, modify: mode == .modify, isPickerPresented: $isPickerPresented, selectedFileURL: $selectedFileURL) { + Task { + await handleOKTap() + } + } onFileUpload: { + Task { + guard let selectedFileURL else { return } + + isUploading = true + let file = await assistantStore.uploadFile(url: selectedFileURL) + uploadedFileId = file?.id + isUploading = false + + if uploadedFileId == nil { + print("Failed to upload") + self.selectedFileURL = nil + } + else { + // if successful upload , we can show it. + if let uploadedFileId = uploadedFileId { + self.selectedFileURL = nil + + fileIds += [uploadedFileId] + + print("Successful upload!") } } - }) + } } } - + private func handleOKTap() async { var mergedFileIds = [String]() @@ -129,26 +147,26 @@ public struct AssistantsView: View { let asstId: String? switch mode { - // Create new Assistant and start a new conversation with it. + // Create new Assistant and select it case .create: - asstId = await assistantStore.createAssistant(name: name, description: description, instructions: customInstructions, codeInterpreter: codeInterpreter, retrievel: retrieval, fileIds: mergedFileIds.isEmpty ? nil : mergedFileIds) - // Modify existing Assistant and start new conversation with it. + asstId = await assistantStore.createAssistant(name: name, description: description, instructions: customInstructions, codeInterpreter: codeInterpreter, retrieval: retrieval, functions: functions, fileIds: mergedFileIds.isEmpty ? nil : mergedFileIds) + assistantStore.selectedAssistantId = asstId + // Modify existing Assistant case .modify: - guard let selectedAssistantId = assistantStore.selectedAssistantId else { return print("Cannot modify assistant, not selected.") } + guard let selectedAssistantId = assistantStore.selectedAssistantId else { + print("Cannot modify assistant, not selected.") + return + } - asstId = await assistantStore.modifyAssistant(asstId: selectedAssistantId, name: name, description: description, instructions: customInstructions, codeInterpreter: codeInterpreter, retrievel: retrieval, fileIds: mergedFileIds.isEmpty ? nil : mergedFileIds) + asstId = await assistantStore.modifyAssistant(asstId: selectedAssistantId, name: name, description: description, instructions: customInstructions, codeInterpreter: codeInterpreter, retrieval: retrieval, functions: functions, fileIds: mergedFileIds.isEmpty ? nil : mergedFileIds) } // Reset Assistant Creator after attempted creation or modification. resetAssistantCreator() - guard let asstId else { - print("failed to create Assistant.") - return + if asstId == nil { + print("Failed to modify or create Assistant.") } - - // Create new local conversation to represent new thread. - store.createConversation(type: .assistant, assistantId: asstId) } private func loadMoreAssistants() { @@ -172,6 +190,7 @@ public struct AssistantsView: View { codeInterpreter = false retrieval = false + functions = [] selectedFileURL = nil uploadedFileId = nil fileIds = [] @@ -187,10 +206,10 @@ public struct AssistantsView: View { customInstructions = selectedAssistant?.instructions ?? "" codeInterpreter = selectedAssistant?.codeInterpreter ?? false retrieval = selectedAssistant?.retrieval ?? false + functions = selectedAssistant?.functions ?? [] fileIds = selectedAssistant?.fileIds ?? [] mode = .modify - isModalPresented = true } } diff --git a/Demo/DemoChat/Sources/UI/FunctionView.swift b/Demo/DemoChat/Sources/UI/FunctionView.swift new file mode 100644 index 00000000..97413bda --- /dev/null +++ b/Demo/DemoChat/Sources/UI/FunctionView.swift @@ -0,0 +1,72 @@ +// +// FunctionView.swift +// +// +// Created by Brent Whitman on 2024-01-31. +// + +import SwiftUI +import OpenAI + +struct FunctionView: View { + @Environment(\.dismiss) var dismiss + @State var name: String + @State var description: String + @State var parameters: String + @Binding var function: FunctionDeclaration? + @State var isShowingAlert = false + @State var alertMessage = "" + + var body: some View { + NavigationStack { + Form { + TextField("Name", text: $name) + TextField("Description", text: $description) + TextField("Parameters", text: $parameters) + } + .navigationTitle("Create Function") + .navigationBarTitleDisplayMode(.inline) + .toolbarBackground(.visible, for: .navigationBar) + .toolbar { + ToolbarItem(placement: .cancellationAction) { + Button("Cancel") { + dismiss() + } + } + ToolbarItem(placement: .confirmationAction) { + Button("Save") { + let parameters = validateParameters() + guard !isShowingAlert else { + return + } + + function = FunctionDeclaration(name: name, description: description, parameters: parameters) + dismiss() + } + } + } + .alert(isPresented: $isShowingAlert) { + Alert(title: Text("Parameters Error"), message: Text(alertMessage)) + } + } + } + + private func validateParameters() -> JSONSchema? { + guard !parameters.isEmpty, let parametersData = parameters.data(using: .utf8) else { + return nil + } + + do { + let parametersJSON = try JSONDecoder().decode(JSONSchema.self, from: parametersData) + return parametersJSON + } catch { + alertMessage = error.localizedDescription + isShowingAlert = true + return nil + } + } +} + +#Preview { + FunctionView(name: "print", description: "Prints text to the console", parameters: "{\"type\": \"string\"}", function: .constant(nil)) +} diff --git a/README.md b/README.md index 13545443..cbadd537 100644 --- a/README.md +++ b/README.md @@ -41,13 +41,14 @@ This repository contains Swift community-maintained implementation over [OpenAI] - [List Assistants](#list-assistants) - [Threads](#threads) - [Create Thread](#create-thread) + - [Create and Run Thread](#create-and-run-thread) - [Get Threads Messages](#get-threads-messages) - [Add Message to Thread](#add-message-to-thread) - [Runs](#runs) - [Create Run](#create-run) - [Retrieve Run](#retrieve-run) - [Retrieve Run Steps](#retrieve-run-steps) - + - [Submit Tool Outputs for Run](#submit-tool-outputs-for-run) - [Files](#files) - [Upload File](#upload-file) - [Example Project](#example-project) @@ -225,31 +226,35 @@ Using the OpenAI Chat API, you can build your own applications with `gpt-3.5-tur **Request** ```swift - struct ChatQuery: Codable { - /// ID of the model to use. Currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported. - public let model: Model - /// The messages to generate chat completions for - public let messages: [Chat] - /// A list of functions the model may generate JSON inputs for. - public let functions: [ChatFunctionDeclaration]? - /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and We generally recommend altering this or top_p but not both. - public let temperature: Double? - /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. - public let topP: Double? - /// How many chat completion choices to generate for each input message. - public let n: Int? - /// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence. - public let stop: [String]? - /// The maximum number of tokens to generate in the completion. - public let maxTokens: Int? - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. - public let presencePenalty: Double? - /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. - public let frequencyPenalty: Double? - ///Modify the likelihood of specified tokens appearing in the completion. - public let logitBias: [String:Int]? - /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. - public let user: String? +struct ChatQuery: Codable { + /// ID of the model to use. + public let model: Model + /// An object specifying the format that the model must output. + public let responseFormat: ResponseFormat? + /// The messages to generate chat completions for + public let messages: [Message] + /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. + public let tools: [Tool]? + /// Controls how the model responds to tool calls. "none" means the model does not call a function, and responds to the end-user. "auto" means the model can pick between and end-user or calling a function. Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. "none" is the default when no functions are present. "auto" is the default if functions are present. + public let toolChoice: ToolChoice? + /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and We generally recommend altering this or top_p but not both. + public let temperature: Double? + /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + public let topP: Double? + /// How many chat completion choices to generate for each input message. + public let n: Int? + /// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence. + public let stop: [String]? + /// The maximum number of tokens to generate in the completion. + public let maxTokens: Int? + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + public let presencePenalty: Double? + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + public let frequencyPenalty: Double? + /// Modify the likelihood of specified tokens appearing in the completion. + public let logitBias: [String:Int]? + /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + public let user: String? } ``` @@ -347,7 +352,7 @@ for try await result in openAI.chatsStream(query: query) { let openAI = OpenAI(apiToken: "...") // Declare functions which GPT-3 might decide to call. let functions = [ - ChatFunctionDeclaration( + FunctionDeclaration( name: "get_current_weather", description: "Get the current weather in a given location", parameters: @@ -366,7 +371,7 @@ let query = ChatQuery( messages: [ Chat(role: .user, content: "What's the weather like in Boston?") ], - functions: functions + tools: functions.map { Tool.function($0) } ) let result = try await openAI.chats(query: query) ``` @@ -383,10 +388,16 @@ Result will be (serialized as JSON here for readability): "index": 0, "message": { "role": "assistant", - "function_call": { - "name": "get_current_weather", - "arguments": "{\n \"location\": \"Boston, MA\"\n}" - } + "tool_calls": [ + { + "id": "call-0", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "{\n \"location\": \"Boston, MA\"\n}" + } + } + ] }, "finish_reason": "function_call" } @@ -1022,7 +1033,7 @@ Review [Assistants Documentation](https://platform.openai.com/docs/api-reference Example: Create Assistant ``` let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools: tools, fileIds: fileIds) -openAI.assistants(query: query) { result in +openAI.assistantCreate(query: query) { result in //Handle response here } ``` @@ -1032,7 +1043,7 @@ openAI.assistants(query: query) { result in Example: Modify Assistant ``` let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools: tools, fileIds: fileIds) -openAI.assistantModify(query: query, asstId: "asst_1234") { result in +openAI.assistantModify(query: query, assistantId: "asst_1234") { result in //Handle response here } ``` @@ -1041,7 +1052,7 @@ openAI.assistantModify(query: query, asstId: "asst_1234") { result in Example: List Assistants ``` -openAI.assistants(query: nil, method: "GET") { result in +openAI.assistants() { result in //Handle response here } ``` @@ -1060,13 +1071,24 @@ openAI.threads(query: threadsQuery) { result in } ``` +##### Create and Run Thread + +Example: Create and Run Thread +``` +let threadsQuery = ThreadQuery(messages: [Chat(role: message.role, content: message.content)]) +let threadRunQuery = ThreadRunQuery(assistantId: "asst_1234" thread: threadsQuery) +openAI.threadRun(query: threadRunQuery) { result in + //Handle response here +} +``` + ##### Get Threads Messages Review [Messages Documentation](https://platform.openai.com/docs/api-reference/messages) for more info. Example: Get Threads Messages ``` -openAI.threadsMessages(threadId: currentThreadId, before: nil) { result in +openAI.threadsMessages(threadId: currentThreadId) { result in //Handle response here } ``` @@ -1075,7 +1097,7 @@ openAI.threadsMessages(threadId: currentThreadId, before: nil) { result in Example: Add Message to Thread ``` -let query = ThreadAddMessageQuery(role: message.role.rawValue, content: message.content) +let query = MessageQuery(role: message.role.rawValue, content: message.content) openAI.threadsAddMessage(threadId: currentThreadId, query: query) { result in //Handle response here } @@ -1108,7 +1130,18 @@ openAI.runRetrieve(threadId: currentThreadId, runId: currentRunId) { result in Example: Retrieve Run Steps ``` -openAI.runRetrieveSteps(threadId: currentThreadId, runId: currentRunId, before: nil) { result in +openAI.runRetrieveSteps(threadId: currentThreadId, runId: currentRunId) { result in + //Handle response here +} +``` + +##### Submit Tool Outputs for Run + +Example: Submit Tool Outputs for Run +``` +let output = RunToolOutputsQuery.ToolOutput(toolCallId: "call123", output: "Success") +let query = RunToolOutputsQuery(toolOutputs: [output]) +openAI.runSubmitToolOutputs(threadId: currentThreadId, runId: currentRunId, query: query) { result in //Handle response here } ``` diff --git a/Sources/OpenAI/OpenAI.swift b/Sources/OpenAI/OpenAI.swift index b1beaa7d..ad0ec1c9 100644 --- a/Sources/OpenAI/OpenAI.swift +++ b/Sources/OpenAI/OpenAI.swift @@ -57,36 +57,48 @@ final public class OpenAI: OpenAIProtocol { } // UPDATES FROM 11-06-23 - public func threadsAddMessage(threadId: String, query: ThreadAddMessageQuery, completion: @escaping (Result) -> Void) { + public func threadsAddMessage(threadId: String, query: MessageQuery, completion: @escaping (Result) -> Void) { performRequest(request: JSONRequest(body: query, url: buildRunsURL(path: .threadsMessages, threadId: threadId)), completion: completion) } - public func threadsMessages(threadId: String, before: String?, completion: @escaping (Result) -> Void) { + public func threadsMessages(threadId: String, before: String? = nil, completion: @escaping (Result) -> Void) { performRequest(request: JSONRequest(body: nil, url: buildRunsURL(path: .threadsMessages, threadId: threadId, before: before), method: "GET"), completion: completion) } - public func runRetrieve(threadId: String, runId: String, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: nil, url: buildRunRetrieveURL(path: .runRetrieve, threadId: threadId, runId: runId, before: nil), method: "GET"), completion: completion) + public func runRetrieve(threadId: String, runId: String, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: nil, url: buildRunRetrieveURL(path: .runRetrieve, threadId: threadId, runId: runId), method: "GET"), completion: completion) } - public func runRetrieveSteps(threadId: String, runId: String, before: String?, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: nil, url: buildRunRetrieveURL(path: .runRetrieveSteps, threadId: threadId, runId: runId, before: before), method: "GET"), completion: completion) + public func runRetrieveSteps(threadId: String, runId: String, before: String? = nil, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: nil, url: buildRunRetrieveURL(path: .runRetrieveSteps, threadId: threadId, runId: runId, before: before), method: "GET"), completion: completion) + } + + public func runSubmitToolOutputs(threadId: String, runId: String, query: RunToolOutputsQuery, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: query, url: buildURL(path: .runSubmitToolOutputs(threadId: threadId, runId: runId)), method: "POST"), completion: completion) } - public func runs(threadId: String, query: RunsQuery, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildRunsURL(path: .runs, threadId: threadId)), completion: completion) + public func runs(threadId: String, query: RunsQuery, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: query, url: buildRunsURL(path: .runs, threadId: threadId)), completion: completion) } public func threads(query: ThreadsQuery, completion: @escaping (Result) -> Void) { performRequest(request: JSONRequest(body: query, url: buildURL(path: .threads)), completion: completion) } + + public func threadRun(query: ThreadRunQuery, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: query, url: buildURL(path: .threadRun)), completion: completion) + } - public func assistants(query: AssistantsQuery?, method: String, after: String?, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildURL(path: .assistants, after: after), method: method), completion: completion) + public func assistants(after: String? = nil, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(url: buildURL(path: .assistants, after: after), method: "GET"), completion: completion) } - public func assistantModify(query: AssistantsQuery, asstId: String, completion: @escaping (Result) -> Void) { - performRequest(request: JSONRequest(body: query, url: buildAssistantURL(path: .assistantsModify, assistantId: asstId)), completion: completion) + public func assistantCreate(query: AssistantsQuery, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: query, url: buildURL(path: .assistants), method: "POST"), completion: completion) + } + + public func assistantModify(query: AssistantsQuery, assistantId: String, completion: @escaping (Result) -> Void) { + performRequest(request: JSONRequest(body: query, url: buildAssistantURL(path: .assistantsModify, assistantId: assistantId), method: "POST"), completion: completion) } public func files(query: FilesQuery, completion: @escaping (Result) -> Void) { @@ -306,9 +318,13 @@ extension APIPath { static let assistants = "/v1/assistants" static let assistantsModify = "/v1/assistants/ASST_ID" static let threads = "/v1/threads" + static let threadRun = "/v1/threads/runs" static let runs = "/v1/threads/THREAD_ID/runs" static let runRetrieve = "/v1/threads/THREAD_ID/runs/RUN_ID" static let runRetrieveSteps = "/v1/threads/THREAD_ID/runs/RUN_ID/steps" + static func runSubmitToolOutputs(threadId: String, runId: String) -> String { + "/v1/threads/\(threadId)/runs/\(runId)/submit_tool_outputs" + } static let threadsMessages = "/v1/threads/THREAD_ID/messages" static let files = "/v1/files" // 1106 end diff --git a/Sources/OpenAI/Public/Models/AssistantResult.swift b/Sources/OpenAI/Public/Models/AssistantResult.swift new file mode 100644 index 00000000..0aa66cf2 --- /dev/null +++ b/Sources/OpenAI/Public/Models/AssistantResult.swift @@ -0,0 +1,26 @@ +// +// AssistantResult.swift +// +// +// Created by Brent Whitman on 2024-01-29. +// + +import Foundation + +public struct AssistantResult: Codable, Equatable { + public let id: String + public let name: String? + public let description: String? + public let instructions: String? + public let tools: [Tool]? + public let fileIds: [String]? + + enum CodingKeys: String, CodingKey { + case id + case name + case description + case instructions + case tools + case fileIds = "file_ids" + } +} diff --git a/Sources/OpenAI/Public/Models/AssistantsQuery.swift b/Sources/OpenAI/Public/Models/AssistantsQuery.swift index 64c22cfe..b45afe89 100644 --- a/Sources/OpenAI/Public/Models/AssistantsQuery.swift +++ b/Sources/OpenAI/Public/Models/AssistantsQuery.swift @@ -7,18 +7,12 @@ import Foundation -public struct AssistantsQuery: Codable { - +public struct AssistantsQuery: Codable, Equatable { public let model: Model - - public let name: String - - public let description: String - - public let instructions: String - + public let name: String? + public let description: String? + public let instructions: String? public let tools: [Tool]? - public let fileIds: [String]? enum CodingKeys: String, CodingKey { @@ -30,27 +24,12 @@ public struct AssistantsQuery: Codable { case fileIds = "file_ids" } - public init(model: Model, name: String, description: String, instructions: String, tools: [Tool], fileIds: [String]? = nil) { + public init(model: Model, name: String?, description: String?, instructions: String?, tools: [Tool]?, fileIds: [String]? = nil) { self.model = model self.name = name - self.description = description self.instructions = instructions - self.tools = tools self.fileIds = fileIds } } - -public struct Tool: Codable, Equatable { - public let toolType: String - - enum CodingKeys: String, CodingKey { - case toolType = "type" - } - - public init(toolType: String) { - self.toolType = toolType - } - -} diff --git a/Sources/OpenAI/Public/Models/AssistantsResult.swift b/Sources/OpenAI/Public/Models/AssistantsResult.swift index 39ad2a5e..cef22fd8 100644 --- a/Sources/OpenAI/Public/Models/AssistantsResult.swift +++ b/Sources/OpenAI/Public/Models/AssistantsResult.swift @@ -9,33 +9,15 @@ import Foundation public struct AssistantsResult: Codable, Equatable { - public let id: String? - - public let data: [AssistantContent]? - public let tools: [Tool]? + public let data: [AssistantResult]? + public let firstId: String? + public let lastId: String? + public let hasMore: Bool enum CodingKeys: String, CodingKey { case data - case id - case tools - } - - public struct AssistantContent: Codable, Equatable { - - public let id: String - public let name: String - public let description: String? - public let instructions: String? - public let tools: [Tool]? - public let fileIds: [String]? - - enum CodingKeys: String, CodingKey { - case id - case name - case description - case instructions - case tools - case fileIds = "file_ids" - } + case firstId = "first_id" + case lastId = "last_id" + case hasMore = "has_more" } } diff --git a/Sources/OpenAI/Public/Models/ChatQuery.swift b/Sources/OpenAI/Public/Models/ChatQuery.swift index 58be8f16..6c2b4fe7 100644 --- a/Sources/OpenAI/Public/Models/ChatQuery.swift +++ b/Sources/OpenAI/Public/Models/ChatQuery.swift @@ -20,13 +20,88 @@ public struct ResponseFormat: Codable, Equatable { } } +public enum Message: Codable, Equatable { + case system(content: String, name: String? = nil) + case user(content: String, name: String? = nil) + case assistant(content: String? = nil, name: String? = nil, toolCalls: [ChatToolCall]? = nil) + case tool(content: String, toolCallId: String) + + enum CodingKeys: String, CodingKey { + case role + case content + case name + case toolCalls = "tool_calls" + case toolCallId = "tool_call_id" + } + + var roleKey: String { + switch self { + case .system: + return "system" + case .user: + return "user" + case .assistant: + return "assistant" + case .tool: + return "tool" + } + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let roleString = try container.decode(String.self, forKey: .role) + + switch roleString { + case "system": + let content = try container.decode(String.self, forKey: .content) + let name = try container.decodeIfPresent(String.self, forKey: .name) + self = .system(content: content, name: name) + case "user": + let content = try container.decode(String.self, forKey: .content) + let name = try container.decodeIfPresent(String.self, forKey: .name) + self = .user(content: content, name: name) + case "assistant": + let content = try container.decodeIfPresent(String.self, forKey: .content) + let name = try container.decodeIfPresent(String.self, forKey: .name) + let toolCalls = try container.decodeIfPresent([ChatToolCall].self, forKey: .toolCalls) + self = .assistant(content: content, name: name, toolCalls: toolCalls) + case "tool": + let content = try container.decode(String.self, forKey: .content) + let toolCallId = try container.decode(String.self, forKey: .toolCallId) + self = .tool(content: content, toolCallId: toolCallId) + default: + throw DecodingError.dataCorruptedError(forKey: .role, in: container, debugDescription: "Invalid message role") + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(roleKey, forKey: .role) + switch self { + case let .system(content, name): + try container.encode(content, forKey: .content) + try container.encodeIfPresent(name, forKey: .name) + case let .user(content, name): + try container.encode(content, forKey: .content) + try container.encodeIfPresent(name, forKey: .name) + case let .assistant(content, name, toolCalls): + try container.encodeIfPresent(content, forKey: .content) + try container.encodeIfPresent(name, forKey: .name) + try container.encodeIfPresent(toolCalls, forKey: .toolCalls) + case let .tool(content, toolCallId): + try container.encode(content, forKey: .content) + try container.encode(toolCallId, forKey: .toolCallId) + } + } +} + public struct Chat: Codable, Equatable { public let role: Role /// The contents of the message. `content` is required for all messages except assistant messages with function calls. public let content: String? /// The name of the author of this message. `name` is required if role is `function`, and it should be the name of the function whose response is in the `content`. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters. public let name: String? - public let functionCall: ChatFunctionCall? + public let toolCalls: [ChatToolCall]? public enum Role: String, Codable, Equatable { case system @@ -39,14 +114,14 @@ public struct Chat: Codable, Equatable { case role case content case name - case functionCall = "function_call" + case toolCalls = "tool_calls" } - public init(role: Role, content: String? = nil, name: String? = nil, functionCall: ChatFunctionCall? = nil) { + public init(role: Role, content: String? = nil, name: String? = nil, toolCalls: [ChatToolCall]? = nil) { self.role = role self.content = content self.name = name - self.functionCall = functionCall + self.toolCalls = toolCalls } public func encode(to encoder: Encoder) throws { @@ -57,25 +132,41 @@ public struct Chat: Codable, Equatable { try container.encode(name, forKey: .name) } - if let functionCall = functionCall { - try container.encode(functionCall, forKey: .functionCall) + if let toolCalls = toolCalls { + try container.encode(toolCalls, forKey: .toolCalls) } // Should add 'nil' to 'content' property for function calling response // See https://openai.com/blog/function-calling-and-other-api-updates - if content != nil || (role == .assistant && functionCall != nil) { + if content != nil || (role == .assistant && toolCalls != nil) { try container.encode(content, forKey: .content) } } } +public struct ChatToolCall: Codable, Equatable { + public enum ToolType: String, Codable, Equatable { + case function + } + + public let id: String + public let type: ToolType + public let function: ChatFunctionCall + + public init(id: String, type: ToolType = .function, function: ChatFunctionCall) { + self.id = id + self.type = type + self.function = function + } +} + public struct ChatFunctionCall: Codable, Equatable { /// The name of the function to call. - public let name: String? + public let name: String /// The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. - public let arguments: String? + public let arguments: String - public init(name: String?, arguments: String?) { + public init(name: String, arguments: String) { self.name = name self.arguments = arguments } @@ -83,7 +174,7 @@ public struct ChatFunctionCall: Codable, Equatable { /// See the [guide](/docs/guides/gpt/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format. -public struct JSONSchema: Codable, Equatable { +public struct JSONSchema: Codable, Hashable { public let type: JSONType public let properties: [String: Property]? public let required: [String]? @@ -100,7 +191,7 @@ public struct JSONSchema: Codable, Equatable { case multipleOf, minimum, maximum } - public struct Property: Codable, Equatable { + public struct Property: Codable, Hashable { public let type: JSONType public let description: String? public let format: String? @@ -151,7 +242,7 @@ public struct JSONSchema: Codable, Equatable { case `null` = "null" } - public struct Items: Codable, Equatable { + public struct Items: Codable, Hashable { public let type: JSONType public let properties: [String: Property]? public let pattern: String? @@ -198,23 +289,6 @@ public struct JSONSchema: Codable, Equatable { } } -public struct ChatFunctionDeclaration: Codable, Equatable { - /// The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. - public let name: String - - /// The description of what the function does. - public let description: String - - /// The parameters the functions accepts, described as a JSON Schema object. - public let parameters: JSONSchema - - public init(name: String, description: String, parameters: JSONSchema) { - self.name = name - self.description = description - self.parameters = parameters - } -} - public struct ChatQueryFunctionCall: Codable, Equatable { /// The name of the function to call. public let name: String? @@ -223,16 +297,16 @@ public struct ChatQueryFunctionCall: Codable, Equatable { } public struct ChatQuery: Equatable, Codable, Streamable { - /// ID of the model to use. Currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported. + /// ID of the model to use. public let model: Model /// An object specifying the format that the model must output. public let responseFormat: ResponseFormat? /// The messages to generate chat completions for - public let messages: [Chat] - /// A list of functions the model may generate JSON inputs for. - public let functions: [ChatFunctionDeclaration]? - /// Controls how the model responds to function calls. "none" means the model does not call a function, and responds to the end-user. "auto" means the model can pick between and end-user or calling a function. Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. "none" is the default when no functions are present. "auto" is the default if functions are present. - public let functionCall: FunctionCall? + public let messages: [Message] + /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. + public let tools: [Tool]? + /// Controls how the model responds to tool calls. "none" means the model does not call a function, and responds to the end-user. "auto" means the model can pick between and end-user or calling a function. Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. "none" is the default when no functions are present. "auto" is the default if functions are present. + public let toolChoice: ToolChoice? /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and We generally recommend altering this or top_p but not both. public let temperature: Double? /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. @@ -254,15 +328,46 @@ public struct ChatQuery: Equatable, Codable, Streamable { var stream: Bool = false - public enum FunctionCall: Codable, Equatable { + public enum ToolChoice: Codable, Equatable { + struct ToolFunction: Codable, Equatable { + let name: String + } + case none case auto case function(String) enum CodingKeys: String, CodingKey { - case none = "none" - case auto = "auto" - case function = "name" + case none + case auto + case type + case function + } + + public init(from decoder: Decoder) throws { + do { + let container = try decoder.singleValueContainer() + let type = try container.decode(String.self) + switch type { + case CodingKeys.none.rawValue: + self = .none + case CodingKeys.auto.rawValue: + self = .auto + default: + throw DecodingError.dataCorruptedError(in: container, debugDescription: "Invalid tool choice") + } + } catch { + let container = try decoder.container(keyedBy: CodingKeys.self) + let type = try container.decode(String.self, forKey: .type) + + switch type { + case CodingKeys.function.rawValue: + let function = try container.decode(ToolFunction.self, forKey: .function) + self = .function(function.name) + default: + throw DecodingError.dataCorruptedError(forKey: .type, in: container, debugDescription: "Invalid type") + } + } } public func encode(to encoder: Encoder) throws { @@ -275,7 +380,8 @@ public struct ChatQuery: Equatable, Codable, Streamable { try container.encode(CodingKeys.auto.rawValue) case .function(let name): var container = encoder.container(keyedBy: CodingKeys.self) - try container.encode(name, forKey: .function) + try container.encode(CodingKeys.function.rawValue, forKey: .type) + try container.encode(ToolFunction(name: name), forKey: .function) } } } @@ -283,8 +389,8 @@ public struct ChatQuery: Equatable, Codable, Streamable { enum CodingKeys: String, CodingKey { case model case messages - case functions - case functionCall = "function_call" + case tools + case toolChoice = "tool_choice" case temperature case topP = "top_p" case n @@ -298,11 +404,11 @@ public struct ChatQuery: Equatable, Codable, Streamable { case responseFormat = "response_format" } - public init(model: Model, messages: [Chat], responseFormat: ResponseFormat? = nil, functions: [ChatFunctionDeclaration]? = nil, functionCall: FunctionCall? = nil, temperature: Double? = nil, topP: Double? = nil, n: Int? = nil, stop: [String]? = nil, maxTokens: Int? = nil, presencePenalty: Double? = nil, frequencyPenalty: Double? = nil, logitBias: [String : Int]? = nil, user: String? = nil, stream: Bool = false) { + public init(model: Model, messages: [Message], responseFormat: ResponseFormat? = nil, tools: [Tool]? = nil, toolChoice: ToolChoice? = nil, temperature: Double? = nil, topP: Double? = nil, n: Int? = nil, stop: [String]? = nil, maxTokens: Int? = nil, presencePenalty: Double? = nil, frequencyPenalty: Double? = nil, logitBias: [String : Int]? = nil, user: String? = nil, stream: Bool = false) { self.model = model self.messages = messages - self.functions = functions - self.functionCall = functionCall + self.tools = tools + self.toolChoice = toolChoice self.temperature = temperature self.topP = topP self.n = n diff --git a/Sources/OpenAI/Public/Models/ChatResult.swift b/Sources/OpenAI/Public/Models/ChatResult.swift index f1a80a0c..ca78618c 100644 --- a/Sources/OpenAI/Public/Models/ChatResult.swift +++ b/Sources/OpenAI/Public/Models/ChatResult.swift @@ -13,7 +13,7 @@ public struct ChatResult: Codable, Equatable { public let index: Int /// Exists only if it is a complete message. - public let message: Chat + public let message: Message /// Exists only if it is a complete message. public let finishReason: String? diff --git a/Sources/OpenAI/Public/Models/FunctionDeclaration.swift b/Sources/OpenAI/Public/Models/FunctionDeclaration.swift new file mode 100644 index 00000000..66e70164 --- /dev/null +++ b/Sources/OpenAI/Public/Models/FunctionDeclaration.swift @@ -0,0 +1,25 @@ +// +// FunctionDeclaration.swift +// +// +// Created by Brent Whitman on 2024-01-29. +// + +import Foundation + +public struct FunctionDeclaration: Codable, Equatable { + /// The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + public let name: String + + /// The description of what the function does. + public let description: String? + + /// The parameters the functions accepts, described as a JSON Schema object. + public let parameters: JSONSchema? + + public init(name: String, description: String?, parameters: JSONSchema?) { + self.name = name + self.description = description + self.parameters = parameters + } +} diff --git a/Sources/OpenAI/Public/Models/ImagesQuery.swift b/Sources/OpenAI/Public/Models/ImagesQuery.swift index 6f9bd788..55b5821d 100644 --- a/Sources/OpenAI/Public/Models/ImagesQuery.swift +++ b/Sources/OpenAI/Public/Models/ImagesQuery.swift @@ -13,7 +13,7 @@ public enum ImageResponseFormat: String, Codable, Equatable { case b64_json } -public struct ImagesQuery: Codable { +public struct ImagesQuery: Codable, Equatable { public typealias ResponseFormat = ImageResponseFormat /// A text description of the desired image(s). The maximum length is 1000 characters. diff --git a/Sources/OpenAI/Public/Models/MessageQuery.swift b/Sources/OpenAI/Public/Models/MessageQuery.swift new file mode 100644 index 00000000..b3641dad --- /dev/null +++ b/Sources/OpenAI/Public/Models/MessageQuery.swift @@ -0,0 +1,26 @@ +// +// MessageQuery.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct MessageQuery: Equatable, Codable { + public let role: Chat.Role + public let content: String + public let fileIds: [String]? + + enum CodingKeys: String, CodingKey { + case role + case content + case fileIds = "file_ids" + } + + public init(role: Chat.Role, content: String, fileIds: [String]? = nil) { + self.role = role + self.content = content + self.fileIds = fileIds + } +} diff --git a/Sources/OpenAI/Public/Models/RunResult.swift b/Sources/OpenAI/Public/Models/RunResult.swift new file mode 100644 index 00000000..05009beb --- /dev/null +++ b/Sources/OpenAI/Public/Models/RunResult.swift @@ -0,0 +1,55 @@ +// +// RunResult.swift +// +// +// Created by Chris Dillard on 11/07/2023. +// + +import Foundation + +public struct RunResult: Codable, Equatable { + public enum Status: String, Codable { + case queued + case inProgress = "in_progress" + case requiresAction = "requires_action" + case cancelling + case cancelled + case failed + case completed + case expired + } + + public struct RequiredAction: Codable, Equatable { + public let submitToolOutputs: SubmitToolOutputs + + enum CodingKeys: String, CodingKey { + case submitToolOutputs = "submit_tool_outputs" + } + } + + public struct SubmitToolOutputs: Codable, Equatable { + public let toolCalls: [ToolCall] + + enum CodingKeys: String, CodingKey { + case toolCalls = "tool_calls" + } + } + + public struct ToolCall: Codable, Equatable { + public let id: String + public let type: String + public let function: ChatFunctionCall + } + + enum CodingKeys: String, CodingKey { + case id + case threadId = "thread_id" + case status + case requiredAction = "required_action" + } + + public let id: String + public let threadId: String + public let status: Status + public let requiredAction: RequiredAction? +} diff --git a/Sources/OpenAI/Public/Models/RunRetrieveResult.swift b/Sources/OpenAI/Public/Models/RunRetrieveResult.swift deleted file mode 100644 index 5e377b9a..00000000 --- a/Sources/OpenAI/Public/Models/RunRetrieveResult.swift +++ /dev/null @@ -1,13 +0,0 @@ -// -// RunsResult.swift -// -// -// Created by Chris Dillard on 11/07/2023. -// - -import Foundation - -public struct RunRetreiveResult: Codable, Equatable { - - public let status: String -} diff --git a/Sources/OpenAI/Public/Models/RunRetrieveStepsResult.swift b/Sources/OpenAI/Public/Models/RunRetrieveStepsResult.swift index 5bbd47bb..6ce851ef 100644 --- a/Sources/OpenAI/Public/Models/RunRetrieveStepsResult.swift +++ b/Sources/OpenAI/Public/Models/RunRetrieveStepsResult.swift @@ -1,5 +1,5 @@ // -// RunRetreiveStepsResult.swift +// RunRetrieveStepsResult.swift // // // Created by Chris Dillard on 11/07/2023. @@ -7,7 +7,7 @@ import Foundation -public struct RunRetreiveStepsResult: Codable, Equatable { +public struct RunRetrieveStepsResult: Codable, Equatable { public struct StepDetailsTopLevel: Codable, Equatable { public let id: String @@ -27,26 +27,39 @@ public struct RunRetreiveStepsResult: Codable, Equatable { } public struct ToolCall: Codable, Equatable { + public enum ToolType: String, Codable { + case codeInterpreter = "code_interpreter" + case function + case retrieval + } + public let id: String - public let type: String - public let code: CodeToolCall? + public let type: ToolType + public let codeInterpreter: CodeInterpreterCall? + public let function: FunctionCall? enum CodingKeys: String, CodingKey { case id case type - case code = "code_interpreter" + case codeInterpreter = "code_interpreter" + case function } - public struct CodeToolCall: Codable, Equatable { + public struct CodeInterpreterCall: Codable, Equatable { public let input: String - public let outputs: [CodeToolCallOutput]? + public let outputs: [CodeInterpreterCallOutput]? - public struct CodeToolCallOutput: Codable, Equatable { + public struct CodeInterpreterCallOutput: Codable, Equatable { public let type: String public let logs: String? - } } + + public struct FunctionCall: Codable, Equatable { + public let name: String + public let arguments: String + public let output: String? + } } } } diff --git a/Sources/OpenAI/Public/Models/RunToolOutputsQuery.swift b/Sources/OpenAI/Public/Models/RunToolOutputsQuery.swift new file mode 100644 index 00000000..7fc42779 --- /dev/null +++ b/Sources/OpenAI/Public/Models/RunToolOutputsQuery.swift @@ -0,0 +1,35 @@ +// +// RunToolOutputsQuery.swift +// +// +// Created by Brent Whitman on 2024-01-29. +// + +import Foundation + +public struct RunToolOutputsQuery: Codable, Equatable { + public struct ToolOutput: Codable, Equatable { + public let toolCallId: String? + public let output: String? + + enum CodingKeys: String, CodingKey { + case toolCallId = "tool_call_id" + case output + } + + public init(toolCallId: String?, output: String?) { + self.toolCallId = toolCallId + self.output = output + } + } + + public let toolOutputs: [ToolOutput] + + enum CodingKeys: String, CodingKey { + case toolOutputs = "tool_outputs" + } + + public init(toolOutputs: [ToolOutput]) { + self.toolOutputs = toolOutputs + } +} diff --git a/Sources/OpenAI/Public/Models/RunsResult.swift b/Sources/OpenAI/Public/Models/RunsResult.swift deleted file mode 100644 index 858f15f5..00000000 --- a/Sources/OpenAI/Public/Models/RunsResult.swift +++ /dev/null @@ -1,13 +0,0 @@ -// -// RunsResult.swift -// -// -// Created by Chris Dillard on 11/07/2023. -// - -import Foundation - -public struct RunsResult: Codable, Equatable { - - public let id: String -} diff --git a/Sources/OpenAI/Public/Models/ThreadAddMessageQuery.swift b/Sources/OpenAI/Public/Models/ThreadAddMessageQuery.swift deleted file mode 100644 index 153851e2..00000000 --- a/Sources/OpenAI/Public/Models/ThreadAddMessageQuery.swift +++ /dev/null @@ -1,24 +0,0 @@ -// -// ThreadAddMessageQuery.swift -// -// -// Created by Chris Dillard on 11/07/2023. -// - -import Foundation - -public struct ThreadAddMessageQuery: Equatable, Codable { - public let role: String - public let content: String - - enum CodingKeys: String, CodingKey { - case role - case content - - } - - public init(role: String, content: String) { - self.role = role - self.content = content - } -} diff --git a/Sources/OpenAI/Public/Models/ThreadRunQuery.swift b/Sources/OpenAI/Public/Models/ThreadRunQuery.swift new file mode 100644 index 00000000..24eadaf8 --- /dev/null +++ b/Sources/OpenAI/Public/Models/ThreadRunQuery.swift @@ -0,0 +1,39 @@ +// +// ThreadsRunsQuery.swift +// +// +// Created by Brent Whitman on 2024-01-29. +// + +import Foundation + +public struct ThreadRunQuery: Equatable, Codable { + + public let assistantId: String + public let thread: ThreadsQuery + public let model: Model? + public let instructions: String? + public let tools: [Tool]? + + enum CodingKeys: String, CodingKey { + case assistantId = "assistant_id" + case thread + case model + case instructions + case tools + } + + public init( + assistantId: String, + thread: ThreadsQuery, + model: Model? = nil, + instructions: String? = nil, + tools: [Tool]? = nil + ) { + self.assistantId = assistantId + self.thread = thread + self.model = model + self.instructions = instructions + self.tools = tools + } +} diff --git a/Sources/OpenAI/Public/Models/ThreadsMessagesResult.swift b/Sources/OpenAI/Public/Models/ThreadsMessagesResult.swift index 975f897f..a08c79df 100644 --- a/Sources/OpenAI/Public/Models/ThreadsMessagesResult.swift +++ b/Sources/OpenAI/Public/Models/ThreadsMessagesResult.swift @@ -30,8 +30,13 @@ public struct ThreadsMessagesResult: Codable, Equatable { case fildId = "file_id" } } + + public enum ContentType: String, Codable { + case text + case imageFile = "image_file" + } - public let type: String + public let type: ContentType public let text: ThreadsMessageContentText? public let imageFile: ThreadsMessageContentText? @@ -43,9 +48,7 @@ public struct ThreadsMessagesResult: Codable, Equatable { } public let id: String - - public let role: String - + public let role: Chat.Role public let content: [ThreadsMessageContent] enum CodingKeys: String, CodingKey { @@ -55,11 +58,9 @@ public struct ThreadsMessagesResult: Codable, Equatable { } } - public let data: [ThreadsMessage] enum CodingKeys: String, CodingKey { case data } - } diff --git a/Sources/OpenAI/Public/Models/ThreadsQuery.swift b/Sources/OpenAI/Public/Models/ThreadsQuery.swift index 1f27849f..c9b2b446 100644 --- a/Sources/OpenAI/Public/Models/ThreadsQuery.swift +++ b/Sources/OpenAI/Public/Models/ThreadsQuery.swift @@ -8,13 +8,13 @@ import Foundation public struct ThreadsQuery: Equatable, Codable { - public let messages: [Chat] + public let messages: [MessageQuery] enum CodingKeys: String, CodingKey { case messages } - public init(messages: [Chat]) { + public init(messages: [MessageQuery]) { self.messages = messages } } diff --git a/Sources/OpenAI/Public/Models/Tool.swift b/Sources/OpenAI/Public/Models/Tool.swift new file mode 100644 index 00000000..e1843c63 --- /dev/null +++ b/Sources/OpenAI/Public/Models/Tool.swift @@ -0,0 +1,64 @@ +// +// Tool.swift +// +// +// Created by Brent Whitman on 2024-01-29. +// + +import Foundation + +/// The type of tool +/// +/// Refer to the [documentation](https://platform.openai.com/docs/assistants/tools/tools-beta) for more information on tools. +public enum Tool: Codable, Equatable { + /// Code Interpreter allows the Assistants API to write and run Python code in a sandboxed execution environment. + case codeInterpreter + /// Function calling allows you to describe functions to the Assistants and have it intelligently return the functions that need to be called along with their arguments. + case function(FunctionDeclaration) + /// Retrieval augments the Assistant with knowledge from outside its model, such as proprietary product information or documents provided by your users. + case retrieval + + enum CodingKeys: String, CodingKey { + case type + case function + } + + fileprivate var rawValue: String { + switch self { + case .codeInterpreter: + return "code_interpreter" + case .function: + return "function" + case .retrieval: + return "retrieval" + } + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let toolTypeString = try container.decode(String.self, forKey: .type) + + switch toolTypeString { + case "code_interpreter": + self = .codeInterpreter + case "function": + let functionDeclaration = try container.decode(FunctionDeclaration.self, forKey: .function) + self = .function(functionDeclaration) + case "retrieval": + self = .retrieval + default: + throw DecodingError.dataCorruptedError(forKey: .type, in: container, debugDescription: "Invalid tool type") + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(rawValue, forKey: .type) + switch self { + case let .function(declaration): + try container.encode(declaration, forKey: .function) + default: + break + } + } +} diff --git a/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Async.swift b/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Async.swift index 7428e34f..3e2b6078 100644 --- a/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Async.swift +++ b/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Async.swift @@ -231,12 +231,25 @@ public extension OpenAIProtocol { // 1106 func assistants( - query: AssistantsQuery?, - method: String, - after: String? + after: String? = nil ) async throws -> AssistantsResult { try await withCheckedThrowingContinuation { continuation in - assistants(query: query, method: method, after: after) { result in + assistants(after: after) { result in + switch result { + case let .success(success): + return continuation.resume(returning: success) + case let .failure(failure): + return continuation.resume(throwing: failure) + } + } + } + } + + func assistantCreate( + query: AssistantsQuery + ) async throws -> AssistantResult { + try await withCheckedThrowingContinuation { continuation in + assistantCreate(query: query) { result in switch result { case let .success(success): return continuation.resume(returning: success) @@ -249,10 +262,10 @@ public extension OpenAIProtocol { func assistantModify( query: AssistantsQuery, - asstId: String - ) async throws -> AssistantsResult { + assistantId: String + ) async throws -> AssistantResult { try await withCheckedThrowingContinuation { continuation in - assistantModify(query: query, asstId: asstId) { result in + assistantModify(query: query, assistantId: assistantId) { result in switch result { case let .success(success): return continuation.resume(returning: success) @@ -278,10 +291,25 @@ public extension OpenAIProtocol { } } + func threadRun( + query: ThreadRunQuery + ) async throws -> RunResult { + try await withCheckedThrowingContinuation { continuation in + threadRun(query: query) { result in + switch result { + case let .success(success): + return continuation.resume(returning: success) + case let .failure(failure): + return continuation.resume(throwing: failure) + } + } + } + } + func runs( threadId: String, query: RunsQuery - ) async throws -> RunsResult { + ) async throws -> RunResult { try await withCheckedThrowingContinuation { continuation in runs(threadId: threadId, query: query) { result in switch result { @@ -297,7 +325,7 @@ public extension OpenAIProtocol { func runRetrieve( threadId: String, runId: String - ) async throws -> RunRetreiveResult { + ) async throws -> RunResult { try await withCheckedThrowingContinuation { continuation in runRetrieve(threadId: threadId, runId: runId) { result in switch result { @@ -313,8 +341,8 @@ public extension OpenAIProtocol { func runRetrieveSteps( threadId: String, runId: String, - before: String? - ) async throws -> RunRetreiveStepsResult { + before: String? = nil + ) async throws -> RunRetrieveStepsResult { try await withCheckedThrowingContinuation { continuation in runRetrieveSteps(threadId: threadId, runId: runId, before: before) { result in switch result { @@ -327,9 +355,26 @@ public extension OpenAIProtocol { } } + func runSubmitToolOutputs( + threadId: String, + runId: String, + query: RunToolOutputsQuery + ) async throws -> RunResult { + try await withCheckedThrowingContinuation { continuation in + runSubmitToolOutputs(threadId: threadId, runId: runId, query: query) { result in + switch result { + case let .success(success): + return continuation.resume(returning: success) + case let .failure(failure): + return continuation.resume(throwing: failure) + } + } + } + } + func threadsMessages( threadId: String, - before: String? + before: String? = nil ) async throws -> ThreadsMessagesResult { try await withCheckedThrowingContinuation { continuation in threadsMessages(threadId: threadId, before: before) { result in @@ -345,7 +390,7 @@ public extension OpenAIProtocol { func threadsAddMessage( threadId: String, - query: ThreadAddMessageQuery + query: MessageQuery ) async throws -> ThreadAddMessageResult { try await withCheckedThrowingContinuation { continuation in threadsAddMessage(threadId: threadId, query: query) { result in diff --git a/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Combine.swift b/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Combine.swift index 0853a654..9300fe90 100644 --- a/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Combine.swift +++ b/Sources/OpenAI/Public/Protocols/OpenAIProtocol+Combine.swift @@ -128,9 +128,23 @@ public extension OpenAIProtocol { } // 1106 - func assistants(query: AssistantsQuery?, method: String, after: String?) -> AnyPublisher { + func assistants(after: String? = nil) -> AnyPublisher { Future { - assistants(query: query, method: method, after: after, completion: $0) + assistants(after: after, completion: $0) + } + .eraseToAnyPublisher() + } + + func assistantCreate(query: AssistantsQuery) -> AnyPublisher { + Future { + assistantCreate(query: query, completion: $0) + } + .eraseToAnyPublisher() + } + + func assistantModify(query: AssistantsQuery, assistantId: String) -> AnyPublisher { + Future { + assistantModify(query: query, assistantId: assistantId, completion: $0) } .eraseToAnyPublisher() } @@ -142,27 +156,62 @@ public extension OpenAIProtocol { .eraseToAnyPublisher() } - func runs(threadId: String, query: RunsQuery) -> AnyPublisher { - Future { - runs(threadId: threadId, query: query, completion: $0) + func threadRun(query: ThreadRunQuery) -> AnyPublisher { + Future { + threadRun(query: query, completion: $0) } .eraseToAnyPublisher() } - func runRetrieve(threadId: String, runId: String) -> AnyPublisher { - Future { + func runs(threadId: String, query: RunsQuery) -> AnyPublisher { + Future { + runs(threadId: threadId, query: query, completion: $0) + } + .eraseToAnyPublisher() + } + func runRetrieve(threadId: String, runId: String) -> AnyPublisher { + Future { runRetrieve(threadId: threadId, runId: runId, completion: $0) } .eraseToAnyPublisher() } + + func runRetrieveSteps(threadId: String, runId: String, before: String? = nil) -> AnyPublisher { + Future { + runRetrieveSteps(threadId: threadId, runId: runId, before: before, completion: $0) + } + .eraseToAnyPublisher() + } + + func runSubmitToolOutputs(threadId: String, runId: String, query: RunToolOutputsQuery) -> AnyPublisher { + Future { + runSubmitToolOutputs(threadId: threadId, runId: runId, query: query, completion: $0) + } + .eraseToAnyPublisher() + } - func threadsMessages(threadId: String, before: String?) -> AnyPublisher { + func threadsMessages(threadId: String, before: String? = nil) -> AnyPublisher { Future { threadsMessages(threadId: threadId, before: before, completion: $0) } .eraseToAnyPublisher() } + + func threadsAddMessage(threadId: String, query: MessageQuery) -> AnyPublisher { + Future { + threadsAddMessage(threadId: threadId, query: query, completion: $0) + } + .eraseToAnyPublisher() + } + + func files(query: FilesQuery) -> AnyPublisher { + Future { + files(query: query, completion: $0) + } + .eraseToAnyPublisher() + } + // 1106 end } diff --git a/Sources/OpenAI/Public/Protocols/OpenAIProtocol.swift b/Sources/OpenAI/Public/Protocols/OpenAIProtocol.swift index f178bd2b..9ef652cd 100644 --- a/Sources/OpenAI/Public/Protocols/OpenAIProtocol.swift +++ b/Sources/OpenAI/Public/Protocols/OpenAIProtocol.swift @@ -251,31 +251,38 @@ public protocol OpenAIProtocol { /// // The following functions represent new functionality added to OpenAI Beta on 11-06-23 /// - /// + /** - This function sends a assistants query to the OpenAI API and creates an assistant. The Assistants API in this usage enables you to create an assistant. + This function sends a assistants query to the OpenAI API to list assistants that have been created. - Example: Create Assistant + Example: List Assistants ``` - let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools: tools, fileIds: fileIds) - openAI.assistants(query: query) { result in + openAI.assistants() { result in //Handle response here } ``` - Example: List Assistants + - Parameter after: A cursor for use in pagination. after is an object ID that defines your place in the list. + - Parameter completion: The completion handler to be executed upon completion of the assistant request. + Returns a `Result` of type `AssistantsResult` if successful, or an `Error` if an error occurs. + **/ + func assistants(after: String?, completion: @escaping (Result) -> Void) + + /** + This function sends an assistants query to the OpenAI API and creates an assistant. + ``` - openAI.assistants(query: nil, method: "GET") { result in + let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools: tools, fileIds: fileIds) + openAI.createAssistant(query: query) { result in //Handle response here } ``` - - Parameter query: The `AssistantsQuery?` instance, containing the information required for the assistant request. Passing nil is used for GET form of request. - - Parameter method: The method to use with the HTTP request. Supports POST (default) and GET. + - Parameter query: The `AssistantsQuery` instance, containing the information required for the assistant request. - Parameter completion: The completion handler to be executed upon completion of the assistant request. - Returns a `Result` of type `AssistantsResult` if successful, or an `Error` if an error occurs. + Returns a `Result` of type `AssistantResult` if successful, or an `Error` if an error occurs. **/ - func assistants(query: AssistantsQuery?, method: String, after: String?, completion: @escaping (Result) -> Void) + func assistantCreate(query: AssistantsQuery, completion: @escaping (Result) -> Void) /** This function sends a assistants query to the OpenAI API and modifies an assistant. The Assistants API in this usage enables you to modify an assistant. @@ -283,17 +290,17 @@ public protocol OpenAIProtocol { Example: Modify Assistant ``` let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools: tools, fileIds: fileIds) - openAI.assistantModify(query: query, asstId: "asst_1234") { result in + openAI.assistantModify(query: query, assistantId: "asst_1234") { result in //Handle response here } ``` - Parameter query: The `AssistantsQuery` instance, containing the information required for the assistant request. - - Parameter asstId: The assistant id for the assistant to modify. + - Parameter assistantId: The assistant id for the assistant to modify. - Parameter completion: The completion handler to be executed upon completion of the assistant request. - Returns a `Result` of type `AssistantsResult` if successful, or an `Error` if an error occurs. + Returns a `Result` of type `AssistantResult` if successful, or an `Error` if an error occurs. **/ - func assistantModify(query: AssistantsQuery, asstId: String, completion: @escaping (Result) -> Void) + func assistantModify(query: AssistantsQuery, assistantId: String, completion: @escaping (Result) -> Void) /** This function sends a threads query to the OpenAI API and creates a thread. The Threads API in this usage enables you to create a thread. @@ -312,6 +319,23 @@ public protocol OpenAIProtocol { **/ func threads(query: ThreadsQuery, completion: @escaping (Result) -> Void) + /** + This function sends a threads query to the OpenAI API that creates and runs a thread in a single request. + + Example: Create and Run Thread + ``` + let threadsQuery = ThreadQuery(messages: [Chat(role: message.role, content: message.content)]) + let threadRunQuery = ThreadRunQuery(assistantId: "asst_1234" thread: threadsQuery) + openAI.threadRun(query: threadRunQuery) { result in + //Handle response here + } + ``` + - Parameter query: The `ThreadRunQuery` instance, containing the information required for the request. + - Parameter completion: The completion handler to be executed upon completion of the threads request. + Returns a `Result` of type `RunResult` if successful, or an `Error` if an error occurs. + **/ + func threadRun(query: ThreadRunQuery, completion: @escaping (Result) -> Void) + /** This function sends a runs query to the OpenAI API and creates a run. The Runs API in this usage enables you to create a run. @@ -326,9 +350,9 @@ public protocol OpenAIProtocol { - Parameter threadId: The thread id for the thread to run. - Parameter query: The `RunsQuery` instance, containing the information required for the runs request. - Parameter completion: The completion handler to be executed upon completion of the runs request. - Returns a `Result` of type `RunsResult` if successful, or an `Error` if an error occurs. + Returns a `Result` of type `RunResult` if successful, or an `Error` if an error occurs. **/ - func runs(threadId: String, query: RunsQuery, completion: @escaping (Result) -> Void) + func runs(threadId: String, query: RunsQuery, completion: @escaping (Result) -> Void) /** This function sends a thread id and run id to the OpenAI API and retrieves a run. The Runs API in this usage enables you to retrieve a run. @@ -342,9 +366,9 @@ public protocol OpenAIProtocol { - Parameter threadId: The thread id for the thread to run. - Parameter runId: The run id for the run to retrieve. - Parameter completion: The completion handler to be executed upon completion of the runRetrieve request. - Returns a `Result` of type `RunRetreiveResult` if successful, or an `Error` if an error occurs. + Returns a `Result` of type `RunRetrieveResult` if successful, or an `Error` if an error occurs. **/ - func runRetrieve(threadId: String, runId: String, completion: @escaping (Result) -> Void) + func runRetrieve(threadId: String, runId: String, completion: @escaping (Result) -> Void) /** This function sends a thread id and run id to the OpenAI API and retrieves a list of run steps. The Runs API in this usage enables you to retrieve a runs run steps. @@ -359,10 +383,20 @@ public protocol OpenAIProtocol { - Parameter runId: The run id for the run to retrieve. - Parameter before: String?: The message id for the run step that defines your place in the list of run steps. Pass nil to get all. - Parameter completion: The completion handler to be executed upon completion of the runRetrieve request. - Returns a `Result` of type `RunRetreiveStepsResult` if successful, or an `Error` if an error occurs. + Returns a `Result` of type `RunRetrieveStepsResult` if successful, or an `Error` if an error occurs. **/ - func runRetrieveSteps(threadId: String, runId: String, before: String?, completion: @escaping (Result) -> Void) + func runRetrieveSteps(threadId: String, runId: String, before: String?, completion: @escaping (Result) -> Void) + /** + This function submits tool outputs for a run to the OpenAI API. It should be submitted when a run is in status `required_action` and `required_action.type` is `submit_tool_outputs` + + - Parameter threadId: The thread id for the thread which needs tool outputs. + - Parameter runId: The run id for the run which needs tool outputs. + - Parameter query: An object containing the tool outputs, populated based on the results of the requested function call + - Parameter completion: The completion handler to be executed upon completion of the runSubmitToolOutputs request. + Returns a `Result` of type `RunResult` if successful, or an `Error` if an error occurs. + */ + func runSubmitToolOutputs(threadId: String, runId: String, query: RunToolOutputsQuery, completion: @escaping (Result) -> Void) /** This function sends a thread id and run id to the OpenAI API and retrieves a threads messages. @@ -370,7 +404,7 @@ public protocol OpenAIProtocol { Example: Get Threads Messages ``` - openAI.threadsMessages(threadId: currentThreadId, before: nil) { result in + openAI.threadsMessages(threadId: currentThreadId) { result in //Handle response here } ``` @@ -387,18 +421,18 @@ public protocol OpenAIProtocol { Example: Add Message to Thread ``` - let query = ThreadAddMessageQuery(role: message.role.rawValue, content: message.content) + let query = MessageQuery(role: message.role.rawValue, content: message.content) openAI.threadsAddMessage(threadId: currentThreadId, query: query) { result in //Handle response here } ``` - Parameter threadId: The thread id for the thread to run. - - Parameter query: The `ThreadAddMessageQuery` instance, containing the information required for the threads request. + - Parameter query: The `MessageQuery` instance, containing the information required for the threads request. - Parameter completion: The completion handler to be executed upon completion of the runRetrieve request. Returns a `Result` of type `ThreadAddMessageResult` if successful, or an `Error` if an error occurs. **/ - func threadsAddMessage(threadId: String, query: ThreadAddMessageQuery, completion: @escaping (Result) -> Void) + func threadsAddMessage(threadId: String, query: MessageQuery, completion: @escaping (Result) -> Void) /** This function sends a purpose string, file contents, and fileName contents to the OpenAI API and returns a file id result. diff --git a/Tests/OpenAITests/OpenAITests.swift b/Tests/OpenAITests/OpenAITests.swift index 11bb3f29..b9697b13 100644 --- a/Tests/OpenAITests/OpenAITests.swift +++ b/Tests/OpenAITests/OpenAITests.swift @@ -102,13 +102,13 @@ class OpenAITests: XCTestCase { func testChats() async throws { let query = ChatQuery(model: .gpt4, messages: [ - .init(role: .system, content: "You are Librarian-GPT. You know everything about the books."), - .init(role: .user, content: "Who wrote Harry Potter?") + .system(content: "You are Librarian-GPT. You know everything about the books."), + .user(content: "Who wrote Harry Potter?") ]) let chatResult = ChatResult(id: "id-12312", object: "foo", created: 100, model: .gpt3_5Turbo, choices: [ - .init(index: 0, message: .init(role: .system, content: "bar"), finishReason: "baz"), - .init(index: 0, message: .init(role: .user, content: "bar1"), finishReason: "baz1"), - .init(index: 0, message: .init(role: .assistant, content: "bar2"), finishReason: "baz2") + .init(index: 0, message: .system(content: "bar"), finishReason: "baz"), + .init(index: 0, message: .user(content: "bar1"), finishReason: "baz1"), + .init(index: 0, message: .assistant(content: "bar2"), finishReason: "baz2") ], usage: .init(promptTokens: 100, completionTokens: 200, totalTokens: 300)) try self.stub(result: chatResult) @@ -118,19 +118,21 @@ class OpenAITests: XCTestCase { func testChatsFunction() async throws { let query = ChatQuery(model: .gpt3_5Turbo_1106, messages: [ - .init(role: .system, content: "You are Weather-GPT. You know everything about the weather."), - .init(role: .user, content: "What's the weather like in Boston?"), - ], functions: [ - .init(name: "get_current_weather", description: "Get the current weather in a given location", parameters: .init(type: .object, properties: [ - "location": .init(type: .string, description: "The city and state, e.g. San Francisco, CA"), - "unit": .init(type: .string, enumValues: ["celsius", "fahrenheit"]) - ], required: ["location"])) - ], functionCall: .auto) + .system(content: "You are Weather-GPT. You know everything about the weather."), + .user(content: "What's the weather like in Boston?"), + ], tools: [ + .function( + .init(name: "get_current_weather", description: "Get the current weather in a given location", parameters: .init(type: .object, properties: [ + "location": .init(type: .string, description: "The city and state, e.g. San Francisco, CA"), + "unit": .init(type: .string, enumValues: ["celsius", "fahrenheit"]) + ], required: ["location"])) + ) + ], toolChoice: .auto) let chatResult = ChatResult(id: "id-12312", object: "foo", created: 100, model: .gpt3_5Turbo, choices: [ - .init(index: 0, message: .init(role: .system, content: "bar"), finishReason: "baz"), - .init(index: 0, message: .init(role: .user, content: "bar1"), finishReason: "baz1"), - .init(index: 0, message: .init(role: .assistant, content: "bar2"), finishReason: "baz2") + .init(index: 0, message: .system(content: "bar"), finishReason: "baz"), + .init(index: 0, message: .user(content: "bar1"), finishReason: "baz1"), + .init(index: 0, message: .assistant(content: "bar2"), finishReason: "baz2") ], usage: .init(promptTokens: 100, completionTokens: 200, totalTokens: 300)) try self.stub(result: chatResult) @@ -140,8 +142,8 @@ class OpenAITests: XCTestCase { func testChatsError() async throws { let query = ChatQuery(model: .gpt4, messages: [ - .init(role: .system, content: "You are Librarian-GPT. You know everything about the books."), - .init(role: .user, content: "Who wrote Harry Potter?") + .system(content: "You are Librarian-GPT. You know everything about the books."), + .user(content: "Who wrote Harry Potter?") ]) let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") self.stub(error: inError) @@ -359,30 +361,31 @@ class OpenAITests: XCTestCase { } // 1106 - func testAssistantQuery() async throws { + func testAssistantCreateQuery() async throws { let query = AssistantsQuery(model: .gpt4_1106_preview, name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: []) - let expectedResult = AssistantsResult(id: "asst_1234", data: [AssistantsResult.AssistantContent(id: "asst_9876", name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: nil, fileIds: nil)], tools: []) + let expectedResult = AssistantResult(id: "asst_9876", name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: nil, fileIds: nil) try self.stub(result: expectedResult) - let result = try await openAI.assistants(query: query, method: "POST", after: nil) + let result = try await openAI.assistantCreate(query: query) XCTAssertEqual(result, expectedResult) } - func testAssistantQueryError() async throws { + func testAssistantCreateQueryError() async throws { let query = AssistantsQuery(model: .gpt4_1106_preview, name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: []) let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") self.stub(error: inError) - let apiError: APIError = try await XCTExpectError { try await openAI.assistants(query: query, method: "POST", after: nil) } + let apiError: APIError = try await XCTExpectError { try await openAI.assistantCreate(query: query) } XCTAssertEqual(inError, apiError) } func testListAssistantQuery() async throws { - let expectedResult = AssistantsResult(id: nil, data: [AssistantsResult.AssistantContent(id: "asst_9876", name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: nil, fileIds: nil)], tools: nil) + let expectedAssistant = AssistantResult(id: "asst_9876", name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: nil, fileIds: nil) + let expectedResult = AssistantsResult(data: [expectedAssistant], firstId: expectedAssistant.id, lastId: expectedAssistant.id, hasMore: false) try self.stub(result: expectedResult) - let result = try await openAI.assistants(query: nil, method: "GET", after: nil) + let result = try await openAI.assistants() XCTAssertEqual(result, expectedResult) } @@ -390,12 +393,30 @@ class OpenAITests: XCTestCase { let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") self.stub(error: inError) - let apiError: APIError = try await XCTExpectError { try await openAI.assistants(query: nil, method: "GET", after: nil) } + let apiError: APIError = try await XCTExpectError { try await openAI.assistants() } + XCTAssertEqual(inError, apiError) + } + + func testAssistantModifyQuery() async throws { + let query = AssistantsQuery(model: .gpt4_1106_preview, name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: []) + let expectedResult = AssistantResult(id: "asst_9876", name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: nil, fileIds: nil) + try self.stub(result: expectedResult) + + let result = try await openAI.assistantModify(query: query, assistantId: "asst_9876") + XCTAssertEqual(result, expectedResult) + } + + func testAssistantModifyQueryError() async throws { + let query = AssistantsQuery(model: .gpt4_1106_preview, name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: []) + let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") + self.stub(error: inError) + + let apiError: APIError = try await XCTExpectError { try await openAI.assistantModify(query: query, assistantId: "asst_9876") } XCTAssertEqual(inError, apiError) } func testThreadsQuery() async throws { - let query = ThreadsQuery(messages: [Chat(role: .user, content: "Hello, What is AI?")]) + let query = ThreadsQuery(messages: [MessageQuery(role: .user, content: "Hello, What is AI?")]) let expectedResult = ThreadsResult(id: "thread_1234") try self.stub(result: expectedResult) @@ -404,7 +425,7 @@ class OpenAITests: XCTestCase { } func testThreadsQueryError() async throws { - let query = ThreadsQuery(messages: [Chat(role: .user, content: "Hello, What is AI?")]) + let query = ThreadsQuery(messages: [MessageQuery(role: .user, content: "Hello, What is AI?")]) let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") self.stub(error: inError) @@ -412,10 +433,28 @@ class OpenAITests: XCTestCase { let apiError: APIError = try await XCTExpectError { try await openAI.threads(query: query) } XCTAssertEqual(inError, apiError) } + + func testThreadRunQuery() async throws { + let query = ThreadRunQuery(assistantId: "asst_7654321", thread: .init(messages: [.init(role: .user, content: "Hello, What is AI?")])) + let expectedResult = RunResult(id: "run_1234", threadId: "thread_1234", status: .completed, requiredAction: nil) + try self.stub(result: expectedResult) + + let result = try await openAI.threadRun(query: query) + XCTAssertEqual(result, expectedResult) + } + func testThreadRunQueryError() async throws { + let query = ThreadRunQuery(assistantId: "asst_7654321", thread: .init(messages: [.init(role: .user, content: "Hello, What is AI?")])) + let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") + self.stub(error: inError) + + let apiError: APIError = try await XCTExpectError { try await openAI.threadRun(query: query) } + XCTAssertEqual(inError, apiError) + } + func testRunsQuery() async throws { let query = RunsQuery(assistantId: "asst_7654321") - let expectedResult = RunsResult(id: "run_1234") + let expectedResult = RunResult(id: "run_1234", threadId: "thread_1234", status: .completed, requiredAction: nil) try self.stub(result: expectedResult) let result = try await openAI.runs(threadId: "thread_1234", query: query) @@ -432,7 +471,7 @@ class OpenAITests: XCTestCase { } func testRunRetrieveQuery() async throws { - let expectedResult = RunRetreiveResult(status: "in_progress") + let expectedResult = RunResult(id: "run_1234", threadId: "thread_1234", status: .inProgress, requiredAction: nil) try self.stub(result: expectedResult) let result = try await openAI.runRetrieve(threadId: "thread_1234", runId: "run_1234") @@ -446,12 +485,64 @@ class OpenAITests: XCTestCase { let apiError: APIError = try await XCTExpectError { try await openAI.runRetrieve(threadId: "thread_1234", runId: "run_1234") } XCTAssertEqual(inError, apiError) } + + func testRunRetrieveStepsQuery() async throws { + let expectedResult = RunRetrieveStepsResult(data: [.init(id: "step_1234", stepDetails: .init(toolCalls: [.init(id: "tool_456", type: .retrieval, codeInterpreter: nil, function: nil)]))]) + try self.stub(result: expectedResult) + + let result = try await openAI.runRetrieveSteps(threadId: "thread_1234", runId: "run_1234") + XCTAssertEqual(result, expectedResult) + } + + func testRunRetreiveStepsQueryError() async throws { + let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") + self.stub(error: inError) + + let apiError: APIError = try await XCTExpectError { try await openAI.runRetrieveSteps(threadId: "thread_1234", runId: "run_1234") } + XCTAssertEqual(inError, apiError) + } + + func testRunSubmitToolOutputsQuery() async throws { + let query = RunToolOutputsQuery(toolOutputs: [.init(toolCallId: "call_123", output: "Success")]) + let expectedResult = RunResult(id: "run_123", threadId: "thread_456", status: .inProgress, requiredAction: nil) + try self.stub(result: expectedResult) + + let result = try await openAI.runSubmitToolOutputs(threadId: "thread_456", runId: "run_123", query: query) + XCTAssertEqual(result, expectedResult) + } + + func testRunSubmitToolOutputsQueryError() async throws { + let query = RunToolOutputsQuery(toolOutputs: [.init(toolCallId: "call_123", output: "Success")]) + let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") + self.stub(error: inError) + + let apiError: APIError = try await XCTExpectError { try await openAI.runSubmitToolOutputs(threadId: "thread_456", runId: "run_123", query: query) } + XCTAssertEqual(inError, apiError) + } + + func testThreadAddMessageQuery() async throws { + let query = MessageQuery(role: .user, content: "Hello, What is AI?", fileIds: ["file_123"]) + let expectedResult = ThreadAddMessageResult(id: "message_1234") + try self.stub(result: expectedResult) + + let result = try await openAI.threadsAddMessage(threadId: "thread_1234", query: query) + XCTAssertEqual(result, expectedResult) + } + + func testThreadAddMessageQueryError() async throws { + let query = MessageQuery(role: .user, content: "Hello, What is AI?", fileIds: ["file_123"]) + let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") + self.stub(error: inError) + + let apiError: APIError = try await XCTExpectError { try await openAI.threadsAddMessage(threadId: "thread_1234", query: query) } + XCTAssertEqual(inError, apiError) + } func testThreadsMessageQuery() async throws { - let expectedResult = ThreadsMessagesResult(data: [ThreadsMessagesResult.ThreadsMessage(id: "thread_1234", role: Chat.Role.user.rawValue, content: [ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent(type: "text", text: ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent.ThreadsMessageContentText(value: "Hello, What is AI?"), imageFile: nil)])]) + let expectedResult = ThreadsMessagesResult(data: [ThreadsMessagesResult.ThreadsMessage(id: "thread_1234", role: .user, content: [.init(type: .text, text: .init(value: "Hello, What is AI?"), imageFile: nil)])]) try self.stub(result: expectedResult) - let result = try await openAI.threadsMessages(threadId: "thread_1234", before: nil) + let result = try await openAI.threadsMessages(threadId: "thread_1234") XCTAssertEqual(result, expectedResult) } @@ -459,10 +550,30 @@ class OpenAITests: XCTestCase { let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") self.stub(error: inError) - let apiError: APIError = try await XCTExpectError { try await openAI.threadsMessages(threadId: "thread_1234", before: nil) } + let apiError: APIError = try await XCTExpectError { try await openAI.threadsMessages(threadId: "thread_1234") } XCTAssertEqual(inError, apiError) } + func testFilesQuery() async throws { + let data = try XCTUnwrap("{\"test\":\"data\"}".data(using: .utf8)) + let query = FilesQuery(purpose: "assistant", file: data, fileName: "test.json", contentType: "application/json") + let expectedResult = FilesResult(id: "file_1234", name: "test.json") + try self.stub(result: expectedResult) + + let result = try await openAI.files(query: query) + XCTAssertEqual(result, expectedResult) + } + + func testFilesQueryError() async throws { + let data = try XCTUnwrap("{\"test\":\"data\"}".data(using: .utf8)) + let query = FilesQuery(purpose: "assistant", file: data, fileName: "test.json", contentType: "application/json") + let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100") + self.stub(error: inError) + + let apiError: APIError = try await XCTExpectError { try await openAI.files(query: query) } + XCTAssertEqual(inError, apiError) + } + func testCustomRunsURLBuilt() { let configuration = OpenAI.Configuration(token: "foo", organizationIdentifier: "bar", host: "my.host.com", timeoutInterval: 14) let openAI = OpenAI(configuration: configuration, session: self.urlSession) diff --git a/Tests/OpenAITests/OpenAITestsCombine.swift b/Tests/OpenAITests/OpenAITestsCombine.swift index a306ea00..1f438f5f 100644 --- a/Tests/OpenAITests/OpenAITestsCombine.swift +++ b/Tests/OpenAITests/OpenAITestsCombine.swift @@ -38,13 +38,13 @@ final class OpenAITestsCombine: XCTestCase { func testChats() throws { let query = ChatQuery(model: .gpt4, messages: [ - .init(role: .system, content: "You are Librarian-GPT. You know everything about the books."), - .init(role: .user, content: "Who wrote Harry Potter?") + .system(content: "You are Librarian-GPT. You know everything about the books."), + .user(content: "Who wrote Harry Potter?") ]) let chatResult = ChatResult(id: "id-12312", object: "foo", created: 100, model: .gpt3_5Turbo, choices: [ - .init(index: 0, message: .init(role: .system, content: "bar"), finishReason: "baz"), - .init(index: 0, message: .init(role: .user, content: "bar1"), finishReason: "baz1"), - .init(index: 0, message: .init(role: .assistant, content: "bar2"), finishReason: "baz2") + .init(index: 0, message: .system(content: "bar"), finishReason: "baz"), + .init(index: 0, message: .user(content: "bar1"), finishReason: "baz1"), + .init(index: 0, message: .assistant(content: "bar2"), finishReason: "baz2") ], usage: .init(promptTokens: 100, completionTokens: 200, totalTokens: 300)) try self.stub(result: chatResult) let result = try awaitPublisher(openAI.chats(query: query)) @@ -125,17 +125,35 @@ final class OpenAITestsCombine: XCTestCase { } // 1106 - func testAssistantQuery() throws { + func testAssistantsQuery() throws { + let expectedAssistant = AssistantResult(id: "asst_9876", name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: nil, fileIds: nil) + let expectedResult = AssistantsResult(data: [expectedAssistant], firstId: expectedAssistant.id, lastId: expectedAssistant.id, hasMore: false) + try self.stub(result: expectedResult) + + let result = try awaitPublisher(openAI.assistants()) + XCTAssertEqual(result, expectedResult) + } + + func testAssistantCreateQuery() throws { let query = AssistantsQuery(model: .gpt4_1106_preview, name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: []) - let expectedResult = AssistantsResult(id: "asst_1234", data: [AssistantsResult.AssistantContent(id: "asst_9876", name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: nil, fileIds: nil)], tools: []) + let expectedResult = AssistantResult(id: "asst_9876", name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: nil, fileIds: nil) try self.stub(result: expectedResult) - let result = try awaitPublisher(openAI.assistants(query: query, method: "POST", after: nil)) + let result = try awaitPublisher(openAI.assistantCreate(query: query)) + XCTAssertEqual(result, expectedResult) + } + + func testAssistantModifyQuery() throws { + let query = AssistantsQuery(model: .gpt4_1106_preview, name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: []) + let expectedResult = AssistantResult(id: "asst_9876", name: "My New Assistant", description: "Assistant Description", instructions: "You are a helpful assistant.", tools: nil, fileIds: nil) + try self.stub(result: expectedResult) + + let result = try awaitPublisher(openAI.assistantModify(query: query, assistantId: "asst_9876")) XCTAssertEqual(result, expectedResult) } func testThreadsQuery() throws { - let query = ThreadsQuery(messages: [Chat(role: .user, content: "Hello, What is AI?")]) + let query = ThreadsQuery(messages: [MessageQuery(role: .user, content: "Hello, What is AI?")]) let expectedResult = ThreadsResult(id: "thread_1234") try self.stub(result: expectedResult) @@ -143,10 +161,19 @@ final class OpenAITestsCombine: XCTestCase { XCTAssertEqual(result, expectedResult) } + + func testThreadRunQuery() throws { + let query = ThreadRunQuery(assistantId: "asst_7654321", thread: .init(messages: [.init(role: .user, content: "Hello, What is AI?")])) + let expectedResult = RunResult(id: "run_1234", threadId: "thread_1234", status: .completed, requiredAction: nil) + try self.stub(result: expectedResult) + + let result = try awaitPublisher(openAI.threadRun(query: query)) + XCTAssertEqual(result, expectedResult) + } func testRunsQuery() throws { let query = RunsQuery(assistantId: "asst_7654321") - let expectedResult = RunsResult(id: "run_1234") + let expectedResult = RunResult(id: "run_1234", threadId: "thread_1234", status: .inProgress, requiredAction: nil) try self.stub(result: expectedResult) let result = try awaitPublisher(openAI.runs(threadId: "thread_1234", query: query)) @@ -155,25 +182,60 @@ final class OpenAITestsCombine: XCTestCase { } func testRunRetrieveQuery() throws { - let expectedResult = RunRetreiveResult(status: "in_progress") + let expectedResult = RunResult(id: "run_1234", threadId: "thread_1234", status: .inProgress, requiredAction: nil) try self.stub(result: expectedResult) let result = try awaitPublisher(openAI.runRetrieve(threadId: "thread_1234", runId: "run_1234")) XCTAssertEqual(result, expectedResult) } + + func testRunRetrieveStepsQuery() throws { + let expectedResult = RunRetrieveStepsResult(data: [.init(id: "step_1234", stepDetails: .init(toolCalls: [.init(id: "tool_456", type: .retrieval, codeInterpreter: nil, function: nil)]))]) + try self.stub(result: expectedResult) + + let result = try awaitPublisher(openAI.runRetrieveSteps(threadId: "thread_1234", runId: "run_1234")) + XCTAssertEqual(result, expectedResult) + } + func testRunSubmitToolOutputsQuery() throws { + let query = RunToolOutputsQuery(toolOutputs: [.init(toolCallId: "call_123", output: "Success")]) + let expectedResult = RunResult(id: "run_123", threadId: "thread_456", status: .inProgress, requiredAction: nil) + try self.stub(result: expectedResult) + + let result = try awaitPublisher(openAI.runSubmitToolOutputs(threadId: "thread_456", runId: "run_123", query: query)) + XCTAssertEqual(result, expectedResult) + } + + func testThreadAddMessageQuery() throws { + let query = MessageQuery(role: .user, content: "Hello, What is AI?", fileIds: ["file_123"]) + let expectedResult = ThreadAddMessageResult(id: "message_1234") + try self.stub(result: expectedResult) + + let result = try awaitPublisher(openAI.threadsAddMessage(threadId: "thread_1234", query: query)) + XCTAssertEqual(result, expectedResult) + } + func testThreadsMessageQuery() throws { - let expectedResult = ThreadsMessagesResult(data: [ThreadsMessagesResult.ThreadsMessage(id: "thread_1234", role: Chat.Role.user.rawValue, content: [ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent(type: "text", text: ThreadsMessagesResult.ThreadsMessage.ThreadsMessageContent.ThreadsMessageContentText(value: "Hello, What is AI?"), imageFile: nil)])]) + let expectedResult = ThreadsMessagesResult(data: [.init(id: "thread_1234", role: .user, content: [.init(type: .text, text: .init(value: "Hello, What is AI?"), imageFile: nil)])]) try self.stub(result: expectedResult) - let result = try awaitPublisher(openAI.threadsMessages(threadId: "thread_1234", before: nil)) + let result = try awaitPublisher(openAI.threadsMessages(threadId: "thread_1234")) XCTAssertEqual(result, expectedResult) } + + func testFilesQuery() throws { + let data = try XCTUnwrap("{\"test\":\"data\"}".data(using: .utf8)) + let query = FilesQuery(purpose: "assistant", file: data, fileName: "test.json", contentType: "application/json") + let expectedResult = FilesResult(id: "file_1234", name: "test.json") + try self.stub(result: expectedResult) + + let result = try awaitPublisher(openAI.files(query: query)) + XCTAssertEqual(result, expectedResult) + } // 1106 end - } @available(tvOS 13.0, *) diff --git a/Tests/OpenAITests/OpenAITestsDecoder.swift b/Tests/OpenAITests/OpenAITestsDecoder.swift index 70b611cf..670460c1 100644 --- a/Tests/OpenAITests/OpenAITestsDecoder.swift +++ b/Tests/OpenAITests/OpenAITestsDecoder.swift @@ -17,10 +17,17 @@ class OpenAITestsDecoder: XCTestCase { super.setUp() } - private func decode(_ jsonString: String, _ expectedValue: T) throws { + private func decode(_ jsonString: String, _ expectedValue: T, file: StaticString = #filePath, line: UInt = #line) throws { let data = jsonString.data(using: .utf8)! let decoded = try JSONDecoder().decode(T.self, from: data) - XCTAssertEqual(decoded, expectedValue) + XCTAssertEqual(decoded, expectedValue, file: file, line: line) + } + + private func encode(_ expectedValue: T, _ jsonString: String, file: StaticString = #filePath, line: UInt = #line) throws { + // To compare serialized JSONs we first convert them both into NSDictionary which are comparable (unlike native swift dictionaries) + let expectedValueAsDict = try jsonDataAsNSDictionary(JSONEncoder().encode(expectedValue)) + let jsonStringAsDict = try jsonDataAsNSDictionary(jsonString.data(using: .utf8)!) + XCTAssertEqual(jsonStringAsDict, expectedValueAsDict, file: file, line: line) } func jsonDataAsNSDictionary(_ data: Data) throws -> NSDictionary { @@ -106,7 +113,7 @@ class OpenAITestsDecoder: XCTestCase { """ let expectedValue = ChatResult(id: "chatcmpl-123", object: "chat.completion", created: 1677652288, model: .gpt4, choices: [ - .init(index: 0, message: Chat(role: .assistant, content: "Hello, world!"), finishReason: "stop") + .init(index: 0, message: .assistant(content: "Hello, world!"), finishReason: "stop") ], usage: .init(promptTokens: 9, completionTokens: 12, totalTokens: 21)) try decode(data, expectedValue) } @@ -134,71 +141,86 @@ class OpenAITestsDecoder: XCTestCase { } """ - // To compare serialized JSONs we first convert them both into NSDictionary which are comparable (unline native swift dictionaries) - let imageQueryAsDict = try jsonDataAsNSDictionary(JSONEncoder().encode(imageQuery)) - let expectedValueAsDict = try jsonDataAsNSDictionary(expectedValue.data(using: .utf8)!) - - XCTAssertEqual(imageQueryAsDict, expectedValueAsDict) + try encode(imageQuery, expectedValue) } func testChatQueryWithFunctionCall() async throws { let chatQuery = ChatQuery( model: .gpt3_5Turbo, messages: [ - Chat(role: .user, content: "What's the weather like in Boston?") + .user(content: "What's the weather like in Boston?") ], responseFormat: .init(type: .jsonObject), - functions: [ - ChatFunctionDeclaration( - name: "get_current_weather", - description: "Get the current weather in a given location", - parameters: - JSONSchema( - type: .object, - properties: [ - "location": .init(type: .string, description: "The city and state, e.g. San Francisco, CA"), - "unit": .init(type: .string, enumValues: ["celsius", "fahrenheit"]) - ], - required: ["location"] - ) + tools: [ + .function( + FunctionDeclaration( + name: "get_current_weather", + description: "Get the current weather in a given location", + parameters: + JSONSchema( + type: .object, + properties: [ + "location": .init(type: .string, description: "The city and state, e.g. San Francisco, CA"), + "unit": .init(type: .string, enumValues: ["celsius", "fahrenheit"]) + ], + required: ["location"] + ) + ) ) - ] + ], + toolChoice: .function("get_current_weather") ) let expectedValue = """ { "model": "gpt-3.5-turbo", "messages": [ - { "role": "user", "content": "What's the weather like in Boston?" } + { + "role": "user", + "content": "What's the weather like in Boston?" + } ], "response_format": { "type": "json_object" - }, - "functions": [ + }, + "tools": [ { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": [ + "celsius", + "fahrenheit" + ] + } }, - "unit": { "type": "string", "enum": ["celsius", "fahrenheit"] } - }, - "required": ["location"] + "required": [ + "location" + ] + } } } ], + "tool_choice": { + "type": "function", + "function": { + "name": "get_current_weather" + } + }, "stream": false } """ - // To compare serialized JSONs we first convert them both into NSDictionary which are comparable (unline native swift dictionaries) - let chatQueryAsDict = try jsonDataAsNSDictionary(JSONEncoder().encode(chatQuery)) - let expectedValueAsDict = try jsonDataAsNSDictionary(expectedValue.data(using: .utf8)!) - - XCTAssertEqual(chatQueryAsDict, expectedValueAsDict) + try encode(chatQuery, expectedValue) } func testChatCompletionWithFunctionCall() async throws { @@ -214,9 +236,16 @@ class OpenAITestsDecoder: XCTestCase { "message": { "role": "assistant", "content": null, - "function_call": { - "name": "get_current_weather" - } + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "{}" + } + } + ] }, "finish_reason": "function_call" } @@ -235,9 +264,8 @@ class OpenAITestsDecoder: XCTestCase { created: 1677652288, model: .gpt3_5Turbo, choices: [ - .init(index: 0, message: - Chat(role: .assistant, - functionCall: ChatFunctionCall(name: "get_current_weather", arguments: nil)), + .init(index: 0, + message: .assistant(content: nil, toolCalls: [ChatToolCall(id: "call_abc123", function: ChatFunctionCall(name: "get_current_weather", arguments: "{}"))]), finishReason: "function_call") ], usage: .init(promptTokens: 82, completionTokens: 18, totalTokens: 100)) @@ -403,4 +431,198 @@ class OpenAITestsDecoder: XCTestCase { let expectedValue = AudioTranslationResult(text: "Hello, world!") try decode(data, expectedValue) } + + func testAssistantResult() async throws { + let data = """ + { + "id": "asst_abc123", + "object": "assistant", + "created_at": 1698984975, + "name": "Math Tutor", + "description": null, + "model": "gpt-4", + "instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", + "tools": [ + { + "type": "code_interpreter" + } + ], + "file_ids": [], + "metadata": {} + } + """ + + let expectedValue = AssistantResult(id: "asst_abc123", name: "Math Tutor", description: nil, instructions: "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", tools: [.codeInterpreter], fileIds: []) + try decode(data, expectedValue) + } + + func testAssistantsQuery() async throws { + let assistantsQuery = AssistantsQuery( + model: .gpt4, + name: "Math Tutor", + description: nil, + instructions: "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", + tools: [.codeInterpreter], + fileIds: nil + ) + + let expectedValue = """ + { + "instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", + "name": "Math Tutor", + "tools": [ + {"type": "code_interpreter"} + ], + "model": "gpt-4" + } + """ + + try encode(assistantsQuery, expectedValue) + } + + func testAssistantsResult() async throws { + let data = """ + { + "object": "list", + "data": [ + { + "id": "asst_abc123", + "object": "assistant", + "created_at": 1698982736, + "name": "Coding Tutor", + "description": null, + "model": "gpt-4", + "instructions": "You are a helpful assistant designed to make me better at coding!", + "tools": [], + "file_ids": [], + "metadata": {} + }, + { + "id": "asst_abc456", + "object": "assistant", + "created_at": 1698982718, + "name": "My Assistant", + "description": null, + "model": "gpt-4", + "instructions": "You are a helpful assistant designed to teach me about AI!", + "tools": [], + "file_ids": [], + "metadata": {} + } + ], + "first_id": "asst_abc123", + "last_id": "asst_abc789", + "has_more": false + } + """ + + let expectedValue = AssistantsResult( + data: [ + .init(id: "asst_abc123", name: "Coding Tutor", description: nil, instructions: "You are a helpful assistant designed to make me better at coding!", tools: [], fileIds: []), + .init(id: "asst_abc456", name: "My Assistant", description: nil, instructions: "You are a helpful assistant designed to teach me about AI!", tools: [], fileIds: []), + ], + firstId: "asst_abc123", + lastId: "asst_abc789", + hasMore: false + ) + + try decode(data, expectedValue) + } + + func testMessageQuery() async throws { + let messageQuery = MessageQuery( + role: .user, + content: "How does AI work? Explain it in simple terms.", + fileIds: ["file_abc123"] + ) + + let expectedValue = """ + { + "role": "user", + "content": "How does AI work? Explain it in simple terms.", + "file_ids": ["file_abc123"] + } + """ + + try encode(messageQuery, expectedValue) + } + + func testRunResult() async throws { + let data = """ + { + "id": "run_1a", + "thread_id": "thread_2b", + "status": "requires_action", + "required_action": { + "type": "submit_tool_outputs", + "submit_tool_outputs": { + "tool_calls": [ + { + "id": "tool_abc890", + "type": "function", + "function": { + "name": "print", + "arguments": "{\\"text\\": \\"hello\\"}" + } + } + ] + } + } + } + """ + + let expectedValue = RunResult( + id: "run_1a", + threadId: "thread_2b", + status: .requiresAction, + requiredAction: .init( + submitToolOutputs: .init(toolCalls: [.init(id: "tool_abc890", type: "function", function: .init(name: "print", arguments: "{\"text\": \"hello\"}"))]) + ) + ) + + try decode(data, expectedValue) + } + + func testRunToolOutputsQuery() async throws { + let runToolOutputsQuery = RunToolOutputsQuery( + toolOutputs: [ + .init(toolCallId: "call_abc0", output: "success") + ] + ) + + let expectedValue = """ + { + "tool_outputs": [ + { + "tool_call_id": "call_abc0", + "output": "success" + } + ] + } + """ + + try encode(runToolOutputsQuery, expectedValue) + } + + func testThreadRunQuery() async throws { + let threadRunQuery = ThreadRunQuery( + assistantId: "asst_abc123", + thread: .init( + messages: [.init(role: .user, content: "Explain deep learning to a 5 year old.")] + ) + ) + + let expectedValue = """ + { + "assistant_id": "asst_abc123", + "thread": { + "messages": [ + {"role": "user", "content": "Explain deep learning to a 5 year old."} + ] + } + } + """ + + try encode(threadRunQuery, expectedValue) + } } From a67a3ff189e7e8938f003b9591e5cd6a5c4595c7 Mon Sep 17 00:00:00 2001 From: Chris Dillard Date: Wed, 21 Feb 2024 17:54:45 -0700 Subject: [PATCH 6/6] Only run testModerationsIterable on iOS 16+ --- Tests/OpenAITests/OpenAITests.swift | 1 + 1 file changed, 1 insertion(+) diff --git a/Tests/OpenAITests/OpenAITests.swift b/Tests/OpenAITests/OpenAITests.swift index 99080984..36b16706 100644 --- a/Tests/OpenAITests/OpenAITests.swift +++ b/Tests/OpenAITests/OpenAITests.swift @@ -255,6 +255,7 @@ class OpenAITests: XCTestCase { XCTAssertEqual(result, moderationsResult) } + @available(iOS 16.0, *) func testModerationsIterable() { let categories = ModerationsResult.Moderation.Categories(harassment: false, harassmentThreatening: false, hate: false, hateThreatening: false, selfHarm: false, selfHarmIntent: false, selfHarmInstructions: false, sexual: false, sexualMinors: false, violence: false, violenceGraphic: false) Mirror(reflecting: categories).children.enumerated().forEach { index, element in