Skip to content

Commit

Permalink
Make additional initializers, functions, members public for extensibility (#192)
Browse files Browse the repository at this point in the history

* Make additional initializers, functions, members public, for WKPro

* Allows use of functions & members that default to internal access, which
  are inaccessible when the module is imported

* Initializers were Xcode generated: right click class name -> refactor
  -> generate memberwise initializers
   * memberwise initializer defaults to internal, mark as public.

* Formatting

---------

Co-authored-by: ZachNagengast <[email protected]>
  • Loading branch information
bpkeene and ZachNagengast authored Aug 7, 2024
1 parent c93d613 commit 37007ef
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 23 deletions.
5 changes: 3 additions & 2 deletions Sources/WhisperKit/Core/AudioProcessor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public extension AudioProcessing {
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
static func loadAudioAsync(fromPath audioFilePath: String) async throws -> AVAudioPCMBuffer {
return try await Task {
return try AudioProcessor.loadAudio(fromPath: audioFilePath)
try AudioProcessor.loadAudio(fromPath: audioFilePath)
}.value
}

Expand Down Expand Up @@ -305,7 +305,8 @@ public class AudioProcessor: NSObject, AudioProcessing {
try audioFile.read(into: inputBuffer, frameCount: framesToRead)
guard let resampledChunk = resampleAudio(fromBuffer: inputBuffer,
toSampleRate: outputFormat.sampleRate,
channelCount: outputFormat.channelCount) else {
channelCount: outputFormat.channelCount)
else {
Logging.error("Failed to resample audio chunk")
return nil
}
Expand Down
69 changes: 56 additions & 13 deletions Sources/WhisperKit/Core/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -188,16 +188,29 @@ public enum DecodingTask: CustomStringConvertible, CaseIterable {
}

public struct DecodingInputs {
var initialPrompt: [Int]
var inputIds: MLMultiArray
var cacheLength: MLMultiArray
var keyCache: MLMultiArray
var valueCache: MLMultiArray
var alignmentWeights: MLMultiArray
var kvCacheUpdateMask: MLMultiArray
var decoderKeyPaddingMask: MLMultiArray
var prefillKeyCache: MLMultiArray
var prefillValueCache: MLMultiArray
public var initialPrompt: [Int]
public var inputIds: MLMultiArray
public var cacheLength: MLMultiArray
public var keyCache: MLMultiArray
public var valueCache: MLMultiArray
public var alignmentWeights: MLMultiArray
public var kvCacheUpdateMask: MLMultiArray
public var decoderKeyPaddingMask: MLMultiArray
public var prefillKeyCache: MLMultiArray
public var prefillValueCache: MLMultiArray

/// Creates decoding inputs by storing each argument in the stored property of the same name.
///
/// Every parameter is required; none has a default value.
public init(
    initialPrompt: [Int],
    inputIds: MLMultiArray,
    cacheLength: MLMultiArray,
    keyCache: MLMultiArray,
    valueCache: MLMultiArray,
    alignmentWeights: MLMultiArray,
    kvCacheUpdateMask: MLMultiArray,
    decoderKeyPaddingMask: MLMultiArray,
    prefillKeyCache: MLMultiArray,
    prefillValueCache: MLMultiArray
) {
    self.initialPrompt = initialPrompt
    self.inputIds = inputIds
    self.cacheLength = cacheLength
    self.keyCache = keyCache
    self.valueCache = valueCache
    self.alignmentWeights = alignmentWeights
    self.kvCacheUpdateMask = kvCacheUpdateMask
    self.decoderKeyPaddingMask = decoderKeyPaddingMask
    self.prefillKeyCache = prefillKeyCache
    self.prefillValueCache = prefillValueCache
}

func reset(prefilledCacheSize: Int, maxTokenContext: Int) {
// NOTE: Because we have a mask on the kvcache,
Expand All @@ -223,9 +236,14 @@ public struct DecodingInputs {
}

public struct DecodingCache {
var keyCache: MLMultiArray?
var valueCache: MLMultiArray?
var alignmentWeights: MLMultiArray?
public var keyCache: MLMultiArray?
public var valueCache: MLMultiArray?
public var alignmentWeights: MLMultiArray?
/// Creates a decoding cache from optional key, value, and alignment-weight arrays.
///
/// Each argument is stored in the property of the same name and defaults to `nil` when omitted.
public init(
    keyCache: MLMultiArray? = nil,
    valueCache: MLMultiArray? = nil,
    alignmentWeights: MLMultiArray? = nil
) {
    self.keyCache = keyCache
    self.valueCache = valueCache
    self.alignmentWeights = alignmentWeights
}
}

public enum ChunkingStrategy: String, CaseIterable {
Expand Down Expand Up @@ -417,6 +435,21 @@ public struct DecodingResult {
timings: nil,
fallback: nil)
}

/// Creates a decoding result by storing each argument in the stored property of the same name.
///
/// `cache`, `timings`, and `fallback` default to `nil` when omitted; all other parameters are required.
public init(
    language: String,
    languageProbs: [String: Float],
    tokens: [Int],
    tokenLogProbs: [[Int: Float]],
    text: String,
    avgLogProb: Float,
    noSpeechProb: Float,
    temperature: Float,
    compressionRatio: Float,
    cache: DecodingCache? = nil,
    timings: TranscriptionTimings? = nil,
    fallback: DecodingFallback? = nil
) {
    self.language = language
    self.languageProbs = languageProbs
    self.tokens = tokens
    self.tokenLogProbs = tokenLogProbs
    self.text = text
    self.avgLogProb = avgLogProb
    self.noSpeechProb = noSpeechProb
    self.temperature = temperature
    self.compressionRatio = compressionRatio
    self.cache = cache
    self.timings = timings
    self.fallback = fallback
}
}

public enum WhisperError: Error, LocalizedError, Equatable {
Expand Down Expand Up @@ -588,6 +621,16 @@ public struct TranscriptionProgress {
public var avgLogprob: Float?
public var compressionRatio: Float?
public var windowId: Int = 0

/// Creates a progress value by storing each argument in the stored property of the same name.
///
/// `temperature`, `avgLogprob`, and `compressionRatio` default to `nil`, and `windowId` defaults to `0`.
public init(
    timings: TranscriptionTimings,
    text: String,
    tokens: [Int],
    temperature: Float? = nil,
    avgLogprob: Float? = nil,
    compressionRatio: Float? = nil,
    windowId: Int = 0
) {
    self.timings = timings
    self.text = text
    self.tokens = tokens
    self.temperature = temperature
    self.avgLogprob = avgLogprob
    self.compressionRatio = compressionRatio
    self.windowId = windowId
}
}

/// Callback to receive progress updates during transcription.
Expand Down
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/TextDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
break
}
}

// Cleanup the early stop flag after loop completion
if shouldEarlyStop.removeValue(forKey: windowUUID) == nil {
Logging.error("Early stop flag not found for window: \(windowUUID)")
Expand Down
12 changes: 6 additions & 6 deletions Sources/WhisperKit/Core/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,14 @@ public extension WhisperKit {
}
}

extension Float {
public extension Float {
    /// Returns this value rounded to `decimalPlaces` digits after the decimal point.
    ///
    /// The value is multiplied by 10 raised to `decimalPlaces`, rounded with the
    /// standard library's parameterless `rounded()`, and divided back by the same factor.
    /// - Parameter decimalPlaces: Number of fractional digits to keep.
    /// - Returns: The rounded value.
    func rounded(_ decimalPlaces: Int) -> Float {
        let scale = pow(10.0, Float(decimalPlaces))
        let scaledValue = self * scale
        return scaledValue.rounded() / scale
    }
}

extension String {
public extension String {
var normalized: String {
// Trim whitespace and newlines
let trimmedString = self.trimmingCharacters(in: .whitespacesAndNewlines)
Expand All @@ -188,12 +188,12 @@ extension String {
}

extension AVAudioPCMBuffer {
// Appends the contents of another buffer to the current buffer
/// Appends the full contents of another buffer to the current buffer.
///
/// Forwards to `appendContents(of:startingFrame:frameCount:)`, starting at frame 0
/// and covering `buffer.frameLength` frames.
/// - Parameter buffer: The buffer whose frames are appended.
/// - Returns: The `Bool` produced by the forwarded call.
func appendContents(of buffer: AVAudioPCMBuffer) -> Bool {
    let allFrames = buffer.frameLength
    return appendContents(of: buffer, startingFrame: 0, frameCount: allFrames)
}

// Appends a specific range of frames from another buffer to the current buffer
/// Appends a specific range of frames from another buffer to the current buffer
func appendContents(of buffer: AVAudioPCMBuffer, startingFrame: AVAudioFramePosition, frameCount: AVAudioFrameCount) -> Bool {
guard format == buffer.format else {
Logging.debug("Format mismatch")
Expand Down Expand Up @@ -225,7 +225,7 @@ extension AVAudioPCMBuffer {
return true
}

// Convenience initializer to concatenate multiple buffers into one
/// Convenience initializer to concatenate multiple buffers into one
convenience init?(concatenating buffers: [AVAudioPCMBuffer]) {
guard !buffers.isEmpty else {
Logging.debug("Buffers array should not be empty")
Expand All @@ -249,7 +249,7 @@ extension AVAudioPCMBuffer {
}
}

// Computed property to determine the stride for float channel data
/// Stride for float channel data: the format's `mBytesPerFrame` divided by the size of a `Float`.
private var stride: Int {
    let bytesPerFrame = Int(format.streamDescription.pointee.mBytesPerFrame)
    return bytesPerFrame / MemoryLayout<Float>.size
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ open class WhisperKit {
guard textDecoder.isModelMultilingual else {
throw WhisperError.decodingFailed("Language detection not supported for this model")
}

// Tokenizer required for decoding
guard let tokenizer else {
throw WhisperError.tokenizerUnavailable()
Expand Down

0 comments on commit 37007ef

Please sign in to comment.