Skip to content

Commit

Permalink
Merge branch 'main' into MacPaw.views_abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
ingvarus-bc committed Feb 7, 2024
2 parents bc5e7e1 + a2c8d17 commit af81a41
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 14 deletions.
4 changes: 2 additions & 2 deletions Demo/DemoChat/Sources/UI/TextToSpeechView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public struct TextToSpeechView: View {

@State private var prompt: String = ""
@State private var voice: AudioSpeechQuery.AudioSpeechVoice = .alloy
@State private var speed: Double = 1
@State private var speed: Double = AudioSpeechQuery.Speed.normal.rawValue
@State private var responseFormat: AudioSpeechQuery.AudioSpeechResponseFormat = .mp3
@State private var showsModelSelectionSheet = false
@State private var selectedSpeechModel: String = Model.tts_1
Expand Down Expand Up @@ -60,7 +60,7 @@ public struct TextToSpeechView: View {
HStack {
Text("Speed: ")
Spacer()
Stepper(value: $speed, in: 0.25...4, step: 0.25) {
Stepper(value: $speed, in: AudioSpeechQuery.Speed.min.rawValue...AudioSpeechQuery.Speed.max.rawValue, step: 0.25) {
HStack {
Spacer()
Text("**\(String(format: "%.2f", speed))**")
Expand Down
2 changes: 1 addition & 1 deletion Sources/OpenAI/OpenAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ final public class OpenAI: OpenAIProtocol {
}

private let session: URLSessionProtocol
private var streamingSessions: [NSObject] = []
private var streamingSessions = ArrayWithThreadSafety<NSObject>()

public let configuration: Configuration

Expand Down
25 changes: 14 additions & 11 deletions Sources/OpenAI/Public/Models/AudioSpeechQuery.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,7 @@ public struct AudioSpeechQuery: Codable, Equatable {
case responseFormat = "response_format"
case speed
}

private enum Constants {
static let normalSpeed = 1.0
static let maxSpeed = 4.0
static let minSpeed = 0.25
}


public init(model: Model, input: String, voice: AudioSpeechVoice, responseFormat: AudioSpeechResponseFormat = .mp3, speed: Double?) {
self.model = AudioSpeechQuery.validateSpeechModel(model)
self.speed = AudioSpeechQuery.normalizeSpeechSpeed(speed)
Expand All @@ -80,13 +74,22 @@ private extension AudioSpeechQuery {
}
return inputModel
}

}

public extension AudioSpeechQuery {

enum Speed: Double {
case normal = 1.0
case max = 4.0
case min = 0.25
}

static func normalizeSpeechSpeed(_ inputSpeed: Double?) -> String {
guard let inputSpeed else { return "\(Constants.normalSpeed)" }
let isSpeedOutOfBounds = inputSpeed >= Constants.maxSpeed && inputSpeed <= Constants.minSpeed
guard let inputSpeed else { return "\(Self.Speed.normal.rawValue)" }
let isSpeedOutOfBounds = inputSpeed <= Self.Speed.min.rawValue || Self.Speed.max.rawValue <= inputSpeed
guard !isSpeedOutOfBounds else {
print("[AudioSpeech] Speed value must be between 0.25 and 4.0. Setting value to closest valid.")
return inputSpeed < Constants.minSpeed ? "\(Constants.minSpeed)" : "\(Constants.maxSpeed)"
return inputSpeed < Self.Speed.min.rawValue ? "\(Self.Speed.min.rawValue)" : "\(Self.Speed.max.rawValue)"
}
return "\(inputSpeed)"
}
Expand Down
26 changes: 26 additions & 0 deletions Sources/OpenAI/Public/Utilities/ArrayWithThreadSafety.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//
// ArrayWithThreadSafety.swift
//
//
// Created by James J Kalafus on 2024-02-01.
//

import Foundation

internal class ArrayWithThreadSafety<Element> {
private var array = [Element]()
private let queue = DispatchQueue(label: "us.kalaf.OpenAI.threadSafeArray", attributes: .concurrent)

@inlinable public func append(_ element: Element) {
queue.async(flags: .barrier) {
self.array.append(element)
}
}

@inlinable public func removeAll(where shouldBeRemoved: @escaping (Element) throws -> Bool) rethrows {
try queue.sync(flags: .barrier) {
try self.array.removeAll(where: shouldBeRemoved)
}
}
}

24 changes: 24 additions & 0 deletions Tests/OpenAITests/OpenAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,30 @@ class OpenAITests: XCTestCase {
XCTAssertEqual(inError, apiError)
}

func testAudioSpeechDoesNotNormalize() async throws {
let query = AudioSpeechQuery(model: .tts_1, input: "Hello, world!", voice: .alloy, responseFormat: .mp3, speed: 2.0)

XCTAssertEqual(query.speed, "\(2.0)")
}

func testAudioSpeechNormalizeNil() async throws {
let query = AudioSpeechQuery(model: .tts_1, input: "Hello, world!", voice: .alloy, responseFormat: .mp3, speed: nil)

XCTAssertEqual(query.speed, "\(1.0)")
}

func testAudioSpeechNormalizeLow() async throws {
let query = AudioSpeechQuery(model: .tts_1, input: "Hello, world!", voice: .alloy, responseFormat: .mp3, speed: 0.0)

XCTAssertEqual(query.speed, "\(0.25)")
}

func testAudioSpeechNormalizeHigh() async throws {
let query = AudioSpeechQuery(model: .tts_1, input: "Hello, world!", voice: .alloy, responseFormat: .mp3, speed: 10.0)

XCTAssertEqual(query.speed, "\(4.0)")
}

func testAudioSpeechError() async throws {
let query = AudioSpeechQuery(model: .tts_1, input: "Hello, world!", voice: .alloy, responseFormat: .mp3, speed: 1.0)
let inError = APIError(message: "foo", type: "bar", param: "baz", code: "100")
Expand Down

0 comments on commit af81a41

Please sign in to comment.