Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add streaming session and ability to use streaming #57

Merged
merged 19 commits into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions Sources/OpenAI/OpenAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ final public class OpenAI: OpenAIProtocol {
}

private let session: URLSessionProtocol
private var streamingSessions: [NSObject] = []
Krivoblotsky marked this conversation as resolved.
Show resolved Hide resolved

public let configuration: Configuration

Expand Down Expand Up @@ -64,7 +65,11 @@ final public class OpenAI: OpenAIProtocol {
}

public func chats(query: ChatQuery, completion: @escaping (Result<ChatResult, Error>) -> Void) {
    // Both paths target the same endpoint; only the transport differs.
    let request = JSONRequest<ChatResult>(body: query, url: buildURL(path: .chats))
    if query.stream == true {
        // Streaming: `completion` may fire multiple times, once per decoded chunk.
        // NOTE(review): "Steaming" is a typo for "Streaming" in the callee's name;
        // renaming requires touching its definition and all call sites together.
        performSteamingRequest(request: request, completion: completion)
    } else {
        performRequest(request: request, completion: completion)
    }
}

public func edits(query: EditsQuery, completion: @escaping (Result<EditsResult, Error>) -> Void) {
Expand Down Expand Up @@ -127,7 +132,26 @@ extension OpenAI {
task.resume()
} catch {
completion(.failure(error))
return
}
}

/// Builds and starts a streaming (server-sent-events) request.
///
/// - Parameters:
///   - request: Request builder; `build` may throw, in which case `completion`
///     receives the failure synchronously.
///   - completion: Invoked once per decoded chunk (`.success`), and for any
///     decoding or transport error (`.failure`). May be called many times.
func performSteamingRequest<ResultType: Codable>(request: any URLRequestBuildable, completion: @escaping (Result<ResultType, Error>) -> Void) {
    do {
        let request = try request.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval)
        let session = StreamingSession<ResultType>(urlRequest: request)
        session.onReceiveContent = { _, object in
            completion(.success(object))
        }
        session.onProcessingError = { _, error in
            completion(.failure(error))
        }
        session.onComplete = { [weak self] object, error in
            // Stop retaining the finished session (identity comparison suffices).
            self?.streamingSessions.removeAll(where: { $0 === object })
            // Bug fix: transport-level failures (connection drop, timeout) were
            // silently swallowed here and never reached the caller.
            if let error = error {
                completion(.failure(error))
            }
        }
        session.perform()
        // Keep the session alive for the duration of the stream; removed in onComplete.
        streamingSessions.append(session)
    } catch {
        completion(.failure(error))
    }
}
}
Expand Down
76 changes: 76 additions & 0 deletions Sources/OpenAI/Private/StreamingSession.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
//
// StreamingSession.swift
//
//
// Created by Sergii Kryvoblotskyi on 18/04/2023.
//

import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

/// Performs a single server-sent-events request and decodes each `data:`
/// payload into `ResultType` as bytes arrive.
///
/// Callbacks:
/// - `onReceiveContent` — fired once per successfully decoded object.
/// - `onProcessingError` — fired when a chunk cannot be decoded.
/// - `onComplete` — fired when the underlying task finishes, with a transport
///   error or `nil` on normal completion.
final class StreamingSession<ResultType: Codable>: NSObject, Identifiable, URLSessionDelegate, URLSessionDataDelegate {

    enum StreamingError: Error {
        case unknownContent   // received bytes were not valid UTF-8
        case emptyContent     // a callback delivered no `data:` payloads at all
    }

    var onReceiveContent: ((StreamingSession, ResultType) -> Void)?
    var onProcessingError: ((StreamingSession, Error) -> Void)?
    var onComplete: ((StreamingSession, Error?) -> Void)?

    /// Sentinel payload the server sends to mark the end of the stream.
    private let streamingCompletionMarker = "[DONE]"
    private let urlRequest: URLRequest
    // URLSession holds a STRONG reference to its delegate (self) until the
    // session is invalidated — see didCompleteWithError below.
    private lazy var urlSession: URLSession = {
        let session = URLSession(configuration: .default, delegate: self, delegateQueue: nil)
        return session
    }()

    init(urlRequest: URLRequest) {
        self.urlRequest = urlRequest
    }

    /// Starts the streaming request. Intended to be called once per instance.
    func perform() {
        self.urlSession
            .dataTask(with: self.urlRequest)
            .resume()
    }

    func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
        onComplete?(self, error)
        // Bug fix: break the session→delegate retain cycle once the task is done;
        // without this the StreamingSession (and its URLSession) leak.
        urlSession.finishTasksAndInvalidate()
    }

    func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) {
        guard let stringContent = String(data: data, encoding: .utf8) else {
            onProcessingError?(self, StreamingError.unknownContent)
            return
        }
        // NOTE(review): assumes every `data:` payload arrives whole within a
        // single callback; a JSON object split across two URLSession callbacks
        // will fail to decode — TODO: buffer partial chunks across callbacks.
        let jsonObjects = stringContent
            .components(separatedBy: "data:")
            .filter { $0.isEmpty == false }
            .map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }

        guard jsonObjects.isEmpty == false else {
            onProcessingError?(self, StreamingError.emptyContent)
            return
        }
        jsonObjects.forEach { jsonContent in
            // Bug fix: a callback containing only "[DONE]" previously surfaced
            // as an `emptyContent` error; the terminator is a normal end-of-stream
            // signal and is skipped silently.
            guard jsonContent != streamingCompletionMarker else {
                return
            }
            guard let jsonData = jsonContent.data(using: .utf8) else {
                onProcessingError?(self, StreamingError.unknownContent)
                return
            }
            do {
                let decoder = JSONDecoder()
                let object = try decoder.decode(ResultType.self, from: jsonData)
                onReceiveContent?(self, object)
            } catch {
                onProcessingError?(self, error)
            }
        }
    }
}
5 changes: 5 additions & 0 deletions Sources/OpenAI/Private/URLSessionProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@ import FoundationNetworking
/// Abstraction over `URLSession` so tests can substitute a mock transport.
protocol URLSessionProtocol {

    /// Completion-handler variant used for regular (non-streaming) requests.
    func dataTask(with request: URLRequest, completionHandler: @escaping @Sendable (Data?, URLResponse?, Error?) -> Void) -> URLSessionDataTaskProtocol
    /// Delegate-driven variant used by streaming sessions (no completion handler).
    func dataTask(with request: URLRequest) -> URLSessionDataTaskProtocol
}

extension URLSession: URLSessionProtocol {

func dataTask(with request: URLRequest) -> URLSessionDataTaskProtocol {
dataTask(with: request) as URLSessionDataTask
}

func dataTask(with request: URLRequest, completionHandler: @escaping @Sendable (Data?, URLResponse?, Error?) -> Void) -> URLSessionDataTaskProtocol {
dataTask(with: request, completionHandler: completionHandler) as URLSessionDataTask
}
Expand Down
13 changes: 10 additions & 3 deletions Sources/OpenAI/Public/Models/ChatResult.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,20 @@ import Foundation
public struct ChatResult: Codable, Equatable {

public struct Choice: Codable, Equatable {
public struct Delta: Codable, Equatable {
public let content: String?
public let role: Chat.Role?
}

public let index: Int
public let message: Chat
public let finishReason: String
public let message: Chat?
Copy link
Contributor

@DJBen DJBen Apr 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to mix the complete and partial results, can you comment in what condition these properties are nil?

I have moderate preference in separating a complete Choice from a PartialChoice, provided that they have major differences, despite the official OpenAI API mix them together.

// Choice stays the same

public struct PartialChoice: Codable, Equatable {
  struct Delta: Codable, Equatable {
    public let content: String?
    public let role: Chat.Role?
  }
  public let index: Int
  public let delta: Delta
}

As a result, the result can contain enum of either Choice or PartialChoice
What do you think? Appreciated!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've decided to keep ChatResult to avoid breaking backward compatibility, and created "ChatStreamingResult" along with a separate function for streaming. I found this solution more straightforward and clean. What do you think about this?

public let delta: Delta?
public let finishReason: String?

enum CodingKeys: String, CodingKey {
case index
case message
case delta
case finishReason = "finish_reason"
}
}
Expand All @@ -38,7 +45,7 @@ public struct ChatResult: Codable, Equatable {
public let created: TimeInterval
public let model: Model
public let choices: [Choice]
public let usage: Usage
public let usage: Usage?

enum CodingKeys: String, CodingKey {
case id
Expand Down
8 changes: 8 additions & 0 deletions Sources/OpenAI/Public/Protocols/OpenAIProtocol+Async.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ public extension OpenAIProtocol {
}
}

/// Bridges the callback-based `chats(query:completion:)` into an async stream.
/// Each streamed chunk is yielded as a `ChatResult`; a failed result finishes
/// the stream with that error.
///
/// NOTE(review): nothing here calls `continuation.finish()` on successful
/// completion, so after the final chunk the stream appears never to terminate
/// — TODO confirm against the streaming session's completion path.
func chatsStream(
    query: ChatQuery
) -> AsyncThrowingStream<ChatResult, Error> {
    return AsyncThrowingStream { continuation in
        return chats(query: query) { continuation.yield(with: $0) }
    }
}

func edits(
query: EditsQuery
) async throws -> EditsResult {
Expand Down
4 changes: 4 additions & 0 deletions Tests/OpenAITests/Mocks/URLSessionMock.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,8 @@ class URLSessionMock: URLSessionProtocol {
dataTask.completion = completionHandler
return dataTask
}

/// Streaming variant: returns the preconfigured mock task without wiring a
/// completion handler (the stored `completion` is left untouched).
func dataTask(with request: URLRequest) -> URLSessionDataTaskProtocol {
    dataTask
}
}
6 changes: 3 additions & 3 deletions Tests/OpenAITests/OpenAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ class OpenAITests: XCTestCase {
.init(role: .user, content: "Who wrote Harry Potter?")
])
let chatResult = ChatResult(id: "id-12312", object: "foo", created: 100, model: .gpt3_5Turbo, choices: [
.init(index: 0, message: .init(role: .system, content: "bar"), finishReason: "baz"),
.init(index: 0, message: .init(role: .user, content: "bar1"), finishReason: "baz1"),
.init(index: 0, message: .init(role: .assistant, content: "bar2"), finishReason: "baz2")
.init(index: 0, message: .init(role: .system, content: "bar"), delta: nil, finishReason: "baz"),
.init(index: 0, message: .init(role: .user, content: "bar1"), delta: nil, finishReason: "baz1"),
.init(index: 0, message: .init(role: .assistant, content: "bar2"), delta: nil, finishReason: "baz2")
], usage: .init(promptTokens: 100, completionTokens: 200, totalTokens: 300))
try self.stub(result: chatResult)

Expand Down
6 changes: 3 additions & 3 deletions Tests/OpenAITests/OpenAITestsCombine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ final class OpenAITestsCombine: XCTestCase {
.init(role: .user, content: "Who wrote Harry Potter?")
])
let chatResult = ChatResult(id: "id-12312", object: "foo", created: 100, model: .gpt3_5Turbo, choices: [
.init(index: 0, message: .init(role: .system, content: "bar"), finishReason: "baz"),
.init(index: 0, message: .init(role: .user, content: "bar1"), finishReason: "baz1"),
.init(index: 0, message: .init(role: .assistant, content: "bar2"), finishReason: "baz2")
.init(index: 0, message: .init(role: .system, content: "bar"), delta: nil, finishReason: "baz"),
.init(index: 0, message: .init(role: .user, content: "bar1"), delta: nil, finishReason: "baz1"),
.init(index: 0, message: .init(role: .assistant, content: "bar2"), delta: nil, finishReason: "baz2")
], usage: .init(promptTokens: 100, completionTokens: 200, totalTokens: 300))
try self.stub(result: chatResult)
let result = try awaitPublisher(openAI.chats(query: query))
Expand Down
2 changes: 1 addition & 1 deletion Tests/OpenAITests/OpenAITestsDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class OpenAITestsDecoder: XCTestCase {
"""

let expectedValue = ChatResult(id: "chatcmpl-123", object: "chat.completion", created: 1677652288, model: .gpt4, choices: [
.init(index: 0, message: Chat(role: .assistant, content: "Hello, world!"), finishReason: "stop")
.init(index: 0, message: Chat(role: .assistant, content: "Hello, world!"), delta: nil, finishReason: "stop")
], usage: .init(promptTokens: 9, completionTokens: 12, totalTokens: 21))
try decode(data, expectedValue)
}
Expand Down