From 03a166cc68b42d39839b2707c95709fb9d546de3 Mon Sep 17 00:00:00 2001 From: James J Kalafus Date: Fri, 16 Feb 2024 15:35:10 -0500 Subject: [PATCH] gpt-4-vision-preview support fix and test https://github.com/MacPaw/OpenAI/pull/169 https://github.com/MacPaw/OpenAI/issues/174 --- Sources/OpenAI/Public/Models/ChatQuery.swift | 67 ++++++++++++++----- Sources/OpenAI/Public/Models/ChatResult.swift | 11 +-- Tests/OpenAITests/OpenAITestsDecoder.swift | 46 ++++++++++++- 3 files changed, 100 insertions(+), 24 deletions(-) diff --git a/Sources/OpenAI/Public/Models/ChatQuery.swift b/Sources/OpenAI/Public/Models/ChatQuery.swift index f2f1a98c..c7a88649 100644 --- a/Sources/OpenAI/Public/Models/ChatQuery.swift +++ b/Sources/OpenAI/Public/Models/ChatQuery.swift @@ -115,12 +115,12 @@ public struct ChatQuery: Equatable, Codable, Streamable { case assistant(Self.ChatCompletionAssistantMessageParam) case tool(Self.ChatCompletionToolMessageParam) - public var content: Self.ChatCompletionUserMessageParam.Content? { get { // TODO: String type except for .user + public var content: Self.ChatCompletionUserMessageParam.Content? { get { switch self { case .system(let systemMessage): return Self.ChatCompletionUserMessageParam.Content.string(systemMessage.content) case .user(let userMessage): - return userMessage.content // TODO: Content type + return userMessage.content case .assistant(let assistantMessage): if let content = assistantMessage.content { return Self.ChatCompletionUserMessageParam.Content.string(content) @@ -178,7 +178,6 @@ public struct ChatQuery: Equatable, Codable, Streamable { public init?( role: Role, content: String? = nil, - imageUrl: URL? = nil, name: String? = nil, toolCalls: [Self.ChatCompletionAssistantMessageParam.ChatCompletionMessageToolCallParam]? = nil, toolCallId: String? 
= nil @@ -193,8 +192,6 @@ public struct ChatQuery: Equatable, Codable, Streamable { case .user: if let content { self = .user(.init(content: .init(string: content), name: name)) - } else if let imageUrl { - self = .user(.init(content: .init(chatCompletionContentPartImageParam: .init(imageUrl: .init(url: imageUrl.absoluteString, detail: .auto))), name: name)) } else { return nil } @@ -209,6 +206,20 @@ public struct ChatQuery: Equatable, Codable, Streamable { } } + public init?( + role: Role, + content: [ChatCompletionUserMessageParam.Content.VisionContent], + name: String? = nil + ) { + switch role { + case .user: + self = .user(.init(content: .vision(content), name: name)) + default: + return nil + } + + } + private init?( content: String, role: Role, @@ -330,8 +341,7 @@ public struct ChatQuery: Equatable, Codable, Streamable { public enum Content: Codable, Equatable { case string(String) - case chatCompletionContentPartTextParam(ChatCompletionContentPartTextParam) - case chatCompletionContentPartImageParam(ChatCompletionContentPartImageParam) + case vision([VisionContent]) public var string: String? { get { switch self { @@ -342,6 +352,33 @@ public struct ChatQuery: Equatable, Codable, Streamable { } }} + public init(string: String) { + self = .string(string) + } + + public init(vision: [VisionContent]) { + self = .vision(vision) + } + + public enum CodingKeys: CodingKey { + case string + case vision + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .string(let a0): + try container.encode(a0) + case .vision(let a0): + try container.encode(a0) + } + } + + public enum VisionContent: Codable, Equatable { + case chatCompletionContentPartTextParam(ChatCompletionContentPartTextParam) + case chatCompletionContentPartImageParam(ChatCompletionContentPartImageParam) + public var text: String? 
{ get { switch self { case .chatCompletionContentPartTextParam(let text): @@ -360,10 +397,6 @@ public struct ChatQuery: Equatable, Codable, Streamable { } }} - public init(string: String) { - self = .string(string) - } - public init(chatCompletionContentPartTextParam: ChatCompletionContentPartTextParam) { self = .chatCompletionContentPartTextParam(chatCompletionContentPartTextParam) } @@ -375,8 +408,6 @@ public struct ChatQuery: Equatable, Codable, Streamable { public func encode(to encoder: Encoder) throws { var container = encoder.singleValueContainer() switch self { - case .string(let a0): - try container.encode(a0) case .chatCompletionContentPartTextParam(let a0): try container.encode(a0) case .chatCompletionContentPartImageParam(let a0): @@ -385,7 +416,6 @@ public struct ChatQuery: Equatable, Codable, Streamable { } enum CodingKeys: CodingKey { - case string case chatCompletionContentPartTextParam case chatCompletionContentPartImageParam } @@ -409,7 +439,7 @@ public struct ChatQuery: Equatable, Codable, Streamable { public init(imageUrl: ImageURL) { self.imageUrl = imageUrl - self.type = "imageUrl" + self.type = "image_url" } public struct ImageURL: Codable, Equatable { @@ -424,6 +454,12 @@ public struct ChatQuery: Equatable, Codable, Streamable { self.detail = detail } + public init(url: Data, detail: Detail) { + self.init( + url: "data:image/jpeg;base64,\(url.base64EncodedString())", + detail: detail) + } + public enum Detail: String, Codable, Equatable, CaseIterable { case auto case low @@ -438,6 +474,7 @@ public struct ChatQuery: Equatable, Codable, Streamable { } } } + } internal struct ChatCompletionMessageParam: Codable, Equatable { typealias Role = ChatQuery.ChatCompletionMessageParam.Role diff --git a/Sources/OpenAI/Public/Models/ChatResult.swift b/Sources/OpenAI/Public/Models/ChatResult.swift index 5e42c37c..c2f7c12d 100644 --- a/Sources/OpenAI/Public/Models/ChatResult.swift +++ b/Sources/OpenAI/Public/Models/ChatResult.swift @@ -145,15 +145,10 @@ 
extension ChatQuery.ChatCompletionMessageParam.ChatCompletionUserMessageParam.Co return } catch {} do { - let text = try container.decode(ChatCompletionContentPartTextParam.self) - self = .chatCompletionContentPartTextParam(text) + let vision = try container.decode([VisionContent].self) + self = .vision(vision) return } catch {} - do { - let image = try container.decode(ChatCompletionContentPartImageParam.self) - self = .chatCompletionContentPartImageParam(image) - return - } catch {} - throw DecodingError.typeMismatch(Self.self, .init(codingPath: [Self.CodingKeys.string, CodingKeys.chatCompletionContentPartTextParam, CodingKeys.chatCompletionContentPartImageParam], debugDescription: "Content: expected String, ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam")) + throw DecodingError.typeMismatch(Self.self, .init(codingPath: [Self.CodingKeys.string, Self.CodingKeys.vision], debugDescription: "Content: expected String || Vision")) } } diff --git a/Tests/OpenAITests/OpenAITestsDecoder.swift b/Tests/OpenAITests/OpenAITestsDecoder.swift index d9672c04..740135fd 100644 --- a/Tests/OpenAITests/OpenAITestsDecoder.swift +++ b/Tests/OpenAITests/OpenAITestsDecoder.swift @@ -140,7 +140,51 @@ class OpenAITestsDecoder: XCTestCase { XCTAssertEqual(imageQueryAsDict, expectedValueAsDict) } - + + func testChatQueryWithVision() async throws { + let chatQuery = ChatQuery(messages: [ +// .init(role: .user, content: [ +// .chatCompletionContentPartTextParam(.init(text: "What's in this image?")), +// .chatCompletionContentPartImageParam(.init(imageUrl: .init(url: "https://some.url/image.jpeg", detail: .auto))) +// ])! 
+ .user(.init(content: .vision([ + .chatCompletionContentPartTextParam(.init(text: "What's in this image?")), + .chatCompletionContentPartImageParam(.init(imageUrl: .init(url: "https://some.url/image.jpeg", detail: .auto))) + ]))) + ], model: Model.gpt4_vision_preview, maxTokens: 300) + let expectedValue = """ + { + "model": "gpt-4-vision-preview", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What's in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": "https://some.url/image.jpeg", + "detail": "auto" + } + } + ] + } + ], + "max_tokens": 300, + "stream": false + } + """ + + // To compare serialized JSONs we first convert them both into NSDictionary which are comparable (unlike native swift dictionaries) + let chatQueryAsDict = try jsonDataAsNSDictionary(JSONEncoder().encode(chatQuery)) + let expectedValueAsDict = try jsonDataAsNSDictionary(expectedValue.data(using: .utf8)!) + + XCTAssertEqual(chatQueryAsDict, expectedValueAsDict) + } + func testChatQueryWithFunctionCall() async throws { let chatQuery = ChatQuery( messages: [