Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add streaming session and ability to use streaming #57

Merged
merged 19 commits into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 62 additions & 25 deletions Demo/DemoChat/Sources/ChatStore.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,53 +53,90 @@ public final class ChatStore: ObservableObject {
}

@MainActor
func sendMessage(
    _ message: Message,
    conversationId: Conversation.ID,
    model: Model
) async {
    // Append the user's message to its conversation, then ask the model
    // for a reply. No-op when the conversation id is unknown.
    guard let conversationIndex = conversations.firstIndex(where: { $0.id == conversationId }) else {
        return
    }
    conversations[conversationIndex].messages.append(message)

    await completeChat(
        conversationId: conversationId,
        model: model
    )
}

@MainActor
func completeChat(conversationId: Conversation.ID) async {
func completeChat(
conversationId: Conversation.ID,
model: Model
) async {
guard let conversation = conversations.first(where: { $0.id == conversationId }) else {
return
}

conversationErrors[conversationId] = nil

do {
let response = try await openAIClient.chats(
guard let conversationIndex = conversations.firstIndex(where: { $0.id == conversationId }) else {
return
}

let chatsStream = openAIClient.chatsStream(
query: ChatQuery(
model: .gpt3_5Turbo,
model: model,
messages: conversation.messages.map { message in
Chat(role: message.role, content: message.content)
}
},
stream: true
)
)

guard let conversationIndex = conversations.firstIndex(where: { $0.id == conversationId }) else {
return
}

let existingMessages = conversations[conversationIndex].messages

for completionMessage in response.choices.map(\.message) {
let message = Message(
id: response.id,
role: completionMessage.role,
content: completionMessage.content,
createdAt: Date(timeIntervalSince1970: TimeInterval(response.created))
)

if existingMessages.contains(message) {
continue

for try await partialChatResult in chatsStream {
for choice in partialChatResult.choices {
let existingMessages = conversations[conversationIndex].messages

let message: Message
if let delta = choice.delta {
message = Message(
id: partialChatResult.id,
role: delta.role ?? .assistant,
content: delta.content ?? "",
createdAt: Date(timeIntervalSince1970: TimeInterval(partialChatResult.created))
)
if let existingMessageIndex = existingMessages.firstIndex(where: { $0.id == partialChatResult.id }) {
// Meld into previous message
let previousMessage = existingMessages[existingMessageIndex]
let combinedMessage = Message(
id: message.id, // id stays the same for different deltas
role: message.role,
content: previousMessage.content + message.content,
createdAt: message.createdAt
)
conversations[conversationIndex].messages[existingMessageIndex] = combinedMessage
} else {
conversations[conversationIndex].messages.append(message)
}
} else {
let choiceMessage = choice.message!

message = Message(
id: partialChatResult.id,
role: choiceMessage.role,
content: choiceMessage.content,
createdAt: Date(timeIntervalSince1970: TimeInterval(partialChatResult.created))
)

if existingMessages.contains(message) {
continue
}
conversations[conversationIndex].messages.append(message)
}
}
conversations[conversationIndex].messages.append(message)
}

} catch {
conversationErrors[conversationId] = error
}
Expand Down
5 changes: 3 additions & 2 deletions Demo/DemoChat/Sources/UI/ConversationView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public struct ChatView: View {
DetailView(
conversation: conversation,
error: store.conversationErrors[conversation.id],
sendMessage: { message in
sendMessage: { message, selectedModel in
Task {
await store.sendMessage(
Message(
Expand All @@ -55,7 +55,8 @@ public struct ChatView: View {
content: message,
createdAt: dateProvider()
),
conversationId: conversation.id
conversationId: conversation.id,
model: selectedModel
)
}
}
Expand Down
56 changes: 53 additions & 3 deletions Demo/DemoChat/Sources/UI/DetailView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,20 @@ import UIKit
#elseif os(macOS)
import AppKit
#endif
import OpenAI
import SwiftUI

struct DetailView: View {
@State var inputText: String = ""
@FocusState private var isFocused: Bool
@State private var showsModelSelectionSheet = false
@State private var selectedChatModel: Model = .gpt3_5Turbo

private let availableChatModels: [Model] = [.gpt3_5Turbo, .gpt4]

let conversation: Conversation
let error: Error?
let sendMessage: (String) -> Void
let sendMessage: (String, Model) -> Void

private var fillColor: Color {
#if os(iOS)
Expand Down Expand Up @@ -61,6 +66,51 @@ struct DetailView: View {
inputBar(scrollViewProxy: scrollViewProxy)
}
.navigationTitle("Chat")
.safeAreaInset(edge: .top) {
HStack {
Text(
"Model: \(selectedChatModel)"
)
.font(.caption)
.foregroundColor(.secondary)
Spacer()
}
.padding(.horizontal, 16)
.padding(.vertical, 8)
}
.toolbar {
ToolbarItem(placement: .navigationBarTrailing) {
Button(action: {
showsModelSelectionSheet.toggle()
}) {
Image(systemName: "cpu")
}
}
}
.confirmationDialog(
"Select model",
isPresented: $showsModelSelectionSheet,
titleVisibility: .visible,
actions: {
ForEach(availableChatModels, id: \.self) { model in
Button {
selectedChatModel = model
} label: {
Text(model)
}
}

Button("Cancel", role: .cancel) {
showsModelSelectionSheet = false
}
},
message: {
Text(
"View https://platform.openai.com/docs/models/overview for details"
)
.font(.caption)
}
)
}
}
}
Expand Down Expand Up @@ -133,7 +183,7 @@ struct DetailView: View {
private func tapSendMessage(
scrollViewProxy: ScrollViewProxy
) {
sendMessage(inputText)
sendMessage(inputText, selectedChatModel)
inputText = ""

// if let lastMessage = conversation.messages.last {
Expand Down Expand Up @@ -206,7 +256,7 @@ struct DetailView_Previews: PreviewProvider {
]
),
error: nil,
sendMessage: { _ in }
sendMessage: { _, _ in }
)
}
}
Expand Down
28 changes: 26 additions & 2 deletions Sources/OpenAI/OpenAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ final public class OpenAI: OpenAIProtocol {
}

private let session: URLSessionProtocol
private var streamingSessions: [NSObject] = []
Krivoblotsky marked this conversation as resolved.
Show resolved Hide resolved

public let configuration: Configuration

Expand Down Expand Up @@ -64,7 +65,11 @@ final public class OpenAI: OpenAIProtocol {
}

/// Performs a chat completion request.
///
/// When `query.stream == true` the request is executed through a
/// `StreamingSession` and `completion` is invoked once per partial result
/// (it may fire multiple times); otherwise a single request/response
/// round trip is performed and `completion` fires exactly once.
public func chats(query: ChatQuery, completion: @escaping (Result<ChatResult, Error>) -> Void) {
    // Build the request once instead of duplicating it in both branches.
    let request = JSONRequest<ChatResult>(body: query, url: buildURL(path: .chats))
    if query.stream == true {
        performSteamingRequest(request: request, completion: completion)
    } else {
        performRequest(request: request, completion: completion)
    }
}

public func edits(query: EditsQuery, completion: @escaping (Result<EditsResult, Error>) -> Void) {
Expand Down Expand Up @@ -127,7 +132,26 @@ extension OpenAI {
task.resume()
} catch {
completion(.failure(error))
return
}
}

/// Executes a server-sent-events request and forwards each decoded partial
/// result to `completion` (which therefore may be called many times).
///
/// The session is retained in `streamingSessions` for the lifetime of the
/// stream and released in `onComplete`.
///
/// NOTE(review): "Steaming" is a typo for "Streaming"; renaming would break
/// existing callers, so the name is kept for source compatibility.
func performSteamingRequest<ResultType: Codable>(request: any URLRequestBuildable, completion: @escaping (Result<ResultType, Error>) -> Void) {
    do {
        let request = try request.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval)
        let session = StreamingSession<ResultType>(urlRequest: request)
        session.onReceiveContent = { _, object in
            completion(.success(object))
        }
        session.onProcessingError = { _, error in
            completion(.failure(error))
        }
        session.onComplete = { [weak self] session, _ in
            // Drop the strong reference so the finished session can deallocate.
            // NOTE(review): this runs on the URLSession delegate queue while
            // `append` below runs on the caller's thread — the array is not
            // synchronized; TODO confirm intended thread-safety.
            self?.streamingSessions.removeAll(where: { $0 == session })
        }
        session.perform()
        streamingSessions.append(session)
    } catch {
        completion(.failure(error))
    }
}
}
Expand Down
76 changes: 76 additions & 0 deletions Sources/OpenAI/Private/StreamingSession.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
//
// StreamingSession.swift
//
//
// Created by Sergii Kryvoblotskyi on 18/04/2023.
//

import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

/// Runs a single server-sent-events (SSE) request and decodes each `data:`
/// payload into `ResultType`, reporting results through the callbacks below.
///
/// Fixes over the previous revision:
/// - A chunk consisting only of the `[DONE]` marker is normal stream
///   termination; it no longer surfaces as `StreamingError.emptyContent`.
/// - `URLSession` strongly retains its delegate (`self`); the session is now
///   invalidated on task completion so the instance can deallocate.
final class StreamingSession<ResultType: Codable>: NSObject, Identifiable, URLSessionDelegate, URLSessionDataDelegate {

    enum StreamingError: Error {
        /// Received bytes were not valid UTF-8 (or could not be re-encoded).
        case unknownContent
        /// A data chunk contained no `data:` payloads at all.
        case emptyContent
    }

    /// Called once per successfully decoded `data:` payload.
    var onReceiveContent: ((StreamingSession, ResultType) -> Void)?
    /// Called when a payload cannot be decoded or a chunk is malformed.
    var onProcessingError: ((StreamingSession, Error) -> Void)?
    /// Called when the underlying task finishes; `error` is nil on clean completion.
    var onComplete: ((StreamingSession, Error?) -> Void)?

    private let streamingCompletionMarker = "[DONE]"
    private let urlRequest: URLRequest
    private lazy var urlSession: URLSession = {
        // URLSession retains its delegate; the cycle is broken in
        // urlSession(_:task:didCompleteWithError:) via finishTasksAndInvalidate().
        let session = URLSession(configuration: .default, delegate: self, delegateQueue: nil)
        return session
    }()

    init(urlRequest: URLRequest) {
        self.urlRequest = urlRequest
    }

    /// Starts the request. Results are delivered via the delegate callbacks.
    func perform() {
        self.urlSession
            .dataTask(with: self.urlRequest)
            .resume()
    }

    func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
        onComplete?(self, error)
        // Release the session's strong reference to self (its delegate).
        session.finishTasksAndInvalidate()
    }

    func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) {
        guard let stringContent = String(data: data, encoding: .utf8) else {
            onProcessingError?(self, StreamingError.unknownContent)
            return
        }
        // NOTE(review): assumes each didReceive delivers whole `data:` events;
        // a JSON object split across chunk boundaries will surface as a
        // decoding error — TODO: buffer incomplete trailing payloads.
        let jsonObjects = stringContent
            .components(separatedBy: "data:")
            .filter { $0.isEmpty == false }
            .map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }

        guard jsonObjects.isEmpty == false else {
            onProcessingError?(self, StreamingError.emptyContent)
            return
        }
        jsonObjects.forEach { jsonContent in
            // "[DONE]" is the server's end-of-stream marker, not an error.
            guard jsonContent != streamingCompletionMarker else {
                return
            }
            guard let jsonData = jsonContent.data(using: .utf8) else {
                onProcessingError?(self, StreamingError.unknownContent)
                return
            }
            do {
                let object = try JSONDecoder().decode(ResultType.self, from: jsonData)
                onReceiveContent?(self, object)
            } catch {
                onProcessingError?(self, error)
            }
        }
    }
}
5 changes: 5 additions & 0 deletions Sources/OpenAI/Private/URLSessionProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@ import FoundationNetworking
/// Abstraction over `URLSession` so networking can be stubbed in tests.
protocol URLSessionProtocol {

    /// Creates a completion-handler-based data task (single request/response).
    func dataTask(with request: URLRequest, completionHandler: @escaping @Sendable (Data?, URLResponse?, Error?) -> Void) -> URLSessionDataTaskProtocol
    /// Creates a delegate-driven data task (used for streaming responses).
    func dataTask(with request: URLRequest) -> URLSessionDataTaskProtocol
}

extension URLSession: URLSessionProtocol {

func dataTask(with request: URLRequest) -> URLSessionDataTaskProtocol {
dataTask(with: request) as URLSessionDataTask
}

func dataTask(with request: URLRequest, completionHandler: @escaping @Sendable (Data?, URLResponse?, Error?) -> Void) -> URLSessionDataTaskProtocol {
dataTask(with: request, completionHandler: completionHandler) as URLSessionDataTask
}
Expand Down
1 change: 1 addition & 0 deletions Sources/OpenAI/Public/Models/ChatQuery.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public struct ChatQuery: Codable {
/// How many chat completion choices to generate for each input message.
public let n: Int?
/// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only `server-sent events` as they become available, with the stream terminated by a data: [DONE] message.
/// If you want to perform the query in a streaming fashion, set this to `true` and use `OpenAI.chatsStream(query:)` method.
public let stream: Bool?
/// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.
public let stop: [String]?
Expand Down
Loading