Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add streaming session and ability to use streaming #57

Merged
merged 19 commits into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 42 additions & 24 deletions Demo/DemoChat/Sources/ChatStore.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,53 +53,71 @@ public final class ChatStore: ObservableObject {
}

@MainActor
func sendMessage(_ message: Message, conversationId: Conversation.ID) async {
func sendMessage(
_ message: Message,
conversationId: Conversation.ID,
model: Model
) async {
guard let conversationIndex = conversations.firstIndex(where: { $0.id == conversationId }) else {
return
}
conversations[conversationIndex].messages.append(message)

await completeChat(conversationId: conversationId)
await completeChat(
conversationId: conversationId,
model: model
)
}

@MainActor
func completeChat(conversationId: Conversation.ID) async {
func completeChat(
conversationId: Conversation.ID,
model: Model
) async {
guard let conversation = conversations.first(where: { $0.id == conversationId }) else {
return
}

conversationErrors[conversationId] = nil

do {
let response = try await openAIClient.chats(
guard let conversationIndex = conversations.firstIndex(where: { $0.id == conversationId }) else {
return
}

let chatsStream: AsyncThrowingStream<ChatStreamResult, Error> = openAIClient.chatsStream(
query: ChatQuery(
model: .gpt3_5Turbo,
model: model,
messages: conversation.messages.map { message in
Chat(role: message.role, content: message.content)
}
)
)

guard let conversationIndex = conversations.firstIndex(where: { $0.id == conversationId }) else {
return
}

let existingMessages = conversations[conversationIndex].messages

for completionMessage in response.choices.map(\.message) {
let message = Message(
id: response.id,
role: completionMessage.role,
content: completionMessage.content,
createdAt: Date(timeIntervalSince1970: TimeInterval(response.created))
)

if existingMessages.contains(message) {
continue

for try await partialChatResult in chatsStream {
for choice in partialChatResult.choices {
let existingMessages = conversations[conversationIndex].messages
let message = Message(
id: partialChatResult.id,
role: choice.delta.role ?? .assistant,
content: choice.delta.content ?? "",
createdAt: Date(timeIntervalSince1970: TimeInterval(partialChatResult.created))
)
if let existingMessageIndex = existingMessages.firstIndex(where: { $0.id == partialChatResult.id }) {
// Meld into previous message
let previousMessage = existingMessages[existingMessageIndex]
let combinedMessage = Message(
id: message.id, // id stays the same for different deltas
role: message.role,
content: previousMessage.content + message.content,
createdAt: message.createdAt
)
conversations[conversationIndex].messages[existingMessageIndex] = combinedMessage
} else {
conversations[conversationIndex].messages.append(message)
}
}
conversations[conversationIndex].messages.append(message)
}

} catch {
conversationErrors[conversationId] = error
}
Expand Down
5 changes: 3 additions & 2 deletions Demo/DemoChat/Sources/UI/ChatView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public struct ChatView: View {
DetailView(
conversation: conversation,
error: store.conversationErrors[conversation.id],
sendMessage: { message in
sendMessage: { message, selectedModel in
Task {
await store.sendMessage(
Message(
Expand All @@ -55,7 +55,8 @@ public struct ChatView: View {
content: message,
createdAt: dateProvider()
),
conversationId: conversation.id
conversationId: conversation.id,
model: selectedModel
)
}
}
Expand Down
56 changes: 53 additions & 3 deletions Demo/DemoChat/Sources/UI/DetailView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,20 @@ import UIKit
#elseif os(macOS)
import AppKit
#endif
import OpenAI
import SwiftUI

struct DetailView: View {
@State var inputText: String = ""
@FocusState private var isFocused: Bool
@State private var showsModelSelectionSheet = false
@State private var selectedChatModel: Model = .gpt3_5Turbo

private let availableChatModels: [Model] = [.gpt3_5Turbo, .gpt4]

let conversation: Conversation
let error: Error?
let sendMessage: (String) -> Void
let sendMessage: (String, Model) -> Void

private var fillColor: Color {
#if os(iOS)
Expand Down Expand Up @@ -61,6 +66,51 @@ struct DetailView: View {
inputBar(scrollViewProxy: scrollViewProxy)
}
.navigationTitle("Chat")
.safeAreaInset(edge: .top) {
HStack {
Text(
"Model: \(selectedChatModel)"
)
.font(.caption)
.foregroundColor(.secondary)
Spacer()
}
.padding(.horizontal, 16)
.padding(.vertical, 8)
}
.toolbar {
ToolbarItem(placement: .navigationBarTrailing) {
Button(action: {
showsModelSelectionSheet.toggle()
}) {
Image(systemName: "cpu")
}
}
}
.confirmationDialog(
"Select model",
isPresented: $showsModelSelectionSheet,
titleVisibility: .visible,
actions: {
ForEach(availableChatModels, id: \.self) { model in
Button {
selectedChatModel = model
} label: {
Text(model)
}
}

Button("Cancel", role: .cancel) {
showsModelSelectionSheet = false
}
},
message: {
Text(
"View https://platform.openai.com/docs/models/overview for details"
)
.font(.caption)
}
)
}
}
}
Expand Down Expand Up @@ -133,7 +183,7 @@ struct DetailView: View {
private func tapSendMessage(
scrollViewProxy: ScrollViewProxy
) {
sendMessage(inputText)
sendMessage(inputText, selectedChatModel)
inputText = ""

// if let lastMessage = conversation.messages.last {
Expand Down Expand Up @@ -206,7 +256,7 @@ struct DetailView_Previews: PreviewProvider {
]
),
error: nil,
sendMessage: { _ in }
sendMessage: { _, _ in }
)
}
}
Expand Down
2 changes: 1 addition & 1 deletion Demo/DemoChat/Sources/UI/ModerationChatView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public struct ModerationChatView: View {
DetailView(
conversation: store.moderationConversation,
error: store.moderationConversationError,
sendMessage: { message in
sendMessage: { message, _ in
Task {
await store.sendModerationMessage(
Message(
Expand Down
78 changes: 76 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ This repository contains Swift community-maintained implementation over [OpenAI]
- [Usage](#usage)
- [Initialization](#initialization)
- [Completions](#completions)
- [Completions Streaming](#completions-streaming)
- [Chats](#chats)
- [Chats Streaming](#chats-streaming)
- [Images](#images)
- [Audio](#audio)
- [Audio Transcriptions](#audio-transcriptions)
Expand Down Expand Up @@ -146,6 +148,43 @@ let result = try await openAI.completions(query: query)
- index : 0
```

#### Completions Streaming

Completions streaming is available by using `completionsStream` function. Tokens will be sent one-by-one.

**Closures**
```swift
openAI.completionsStream(query: query) { partialResult in
switch partialResult {
case .success(let result):
print(result.choices)
case .failure(let error):
//Handle chunk error here
}
} completion: { error in
//Handle streaming error here
}
```

**Combine**

```swift
openAI
.completionsStream(query: query)
.sink { completion in
//Handle completion result here
} receiveValue: { result in
//Handle chunk here
}.store(in: &cancellables)
```

**Structured concurrency**
```swift
for try await result in openAI.completionsStream(query: query) {
//Handle result here
}
```

Review [Completions Documentation](https://platform.openai.com/docs/api-reference/completions) for more info.

### Chats
Expand Down Expand Up @@ -175,8 +214,6 @@ Using the OpenAI Chat API, you can build your own applications with `gpt-3.5-tur
public let topP: Double?
/// How many chat completion choices to generate for each input message.
public let n: Int?
/// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only `server-sent events` as they become available, with the stream terminated by a data: [DONE] message.
public let stream: Bool?
/// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.
public let stop: [String]?
/// The maximum number of tokens to generate in the completion.
Expand Down Expand Up @@ -244,6 +281,43 @@ let result = try await openAI.chats(query: query)
- total_tokens : 49
```

#### Chats Streaming

Chats streaming is available by using `chatStream` function. Tokens will be sent one-by-one.

**Closures**
```swift
openAI.chatsStream(query: query) { partialResult in
switch partialResult {
case .success(let result):
print(result.choices)
case .failure(let error):
//Handle chunk error here
}
} completion: { error in
//Handle streaming error here
}
```

**Combine**

```swift
openAI
.chatsStream(query: query)
.sink { completion in
//Handle completion result here
} receiveValue: { result in
//Handle chunk here
}.store(in: &cancellables)
```

**Structured concurrency**
```swift
for try await result in openAI.chatsStream(query: query) {
//Handle result here
}
```

Review [Chat Documentation](https://platform.openai.com/docs/guides/chat) for more info.

### Images
Expand Down
31 changes: 30 additions & 1 deletion Sources/OpenAI/OpenAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ final public class OpenAI: OpenAIProtocol {
}

private let session: URLSessionProtocol
private var streamingSessions: [NSObject] = []
Krivoblotsky marked this conversation as resolved.
Show resolved Hide resolved

public let configuration: Configuration

Expand All @@ -59,6 +60,10 @@ final public class OpenAI: OpenAIProtocol {
performRequest(request: JSONRequest<CompletionsResult>(body: query, url: buildURL(path: .completions)), completion: completion)
}

public func completionsStream(query: CompletionsQuery, onResult: @escaping (Result<CompletionsResult, Error>) -> Void, completion: ((Error?) -> Void)?) {
performSteamingRequest(request: JSONRequest<CompletionsResult>(body: query.makeStreamable(), url: buildURL(path: .completions)), onResult: onResult, completion: completion)
}

public func images(query: ImagesQuery, completion: @escaping (Result<ImagesResult, Error>) -> Void) {
performRequest(request: JSONRequest<ImagesResult>(body: query, url: buildURL(path: .images)), completion: completion)
}
Expand All @@ -71,6 +76,10 @@ final public class OpenAI: OpenAIProtocol {
performRequest(request: JSONRequest<ChatResult>(body: query, url: buildURL(path: .chats)), completion: completion)
}

public func chatsStream(query: ChatQuery, onResult: @escaping (Result<ChatStreamResult, Error>) -> Void, completion: ((Error?) -> Void)?) {
performSteamingRequest(request: JSONRequest<ChatResult>(body: query.makeStreamable(), url: buildURL(path: .chats)), onResult: onResult, completion: completion)
}

public func edits(query: EditsQuery, completion: @escaping (Result<EditsResult, Error>) -> Void) {
performRequest(request: JSONRequest<EditsResult>(body: query, url: buildURL(path: .edits)), completion: completion)
}
Expand Down Expand Up @@ -131,7 +140,27 @@ extension OpenAI {
task.resume()
} catch {
completion(.failure(error))
return
}
}

func performSteamingRequest<ResultType: Codable>(request: any URLRequestBuildable, onResult: @escaping (Result<ResultType, Error>) -> Void, completion: ((Error?) -> Void)?) {
do {
let request = try request.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval)
let session = StreamingSession<ResultType>(urlRequest: request)
session.onReceiveContent = {_, object in
onResult(.success(object))
}
session.onProcessingError = {_, error in
onResult(.failure(error))
}
session.onComplete = { [weak self] object, error in
self?.streamingSessions.removeAll(where: { $0 == object })
Krivoblotsky marked this conversation as resolved.
Show resolved Hide resolved
completion?(error)
}
session.perform()
streamingSessions.append(session)
} catch {
completion?(error)
}
}
}
Expand Down
Loading