Allow protocol defined types for model inputs and outputs (#281)
* Freeze more enums

* Audio input length from CoreML metadata

Add arbitrary length audio

* Backwards compatible generic model io

* Support generic io for model inputs and outputs

* Add speed factor to timing report

* Use actor for early stop checks for better concurrency safety

* Add io type protocol handling and tests

* Formatting

* Fix timestamp token filter logic and tests

* Run unit tests on any branch in PR

* Upload test failure results

---------

Co-authored-by: Andrey Leonov <[email protected]>
Co-authored-by: Eduardo Pacheco <[email protected]>
3 people authored Dec 19, 2024
1 parent 3bc936a commit d191654
Showing 15 changed files with 746 additions and 253 deletions.
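In brief: the per-file diffs below replace hard-coded MLMultiArray signatures with marker protocols so that alternate tensor types can flow through model inputs and outputs. A minimal sketch of that surface, using only declarations that appear in the diffs (the ExampleTensor type is hypothetical):

import CoreML

// Marker protocols from this PR; concrete tensor types opt in by conformance.
public protocol FeatureExtractorOutputType {}
public protocol AudioEncoderOutputType {}
extension MLMultiArray: FeatureExtractorOutputType {}
extension MLMultiArray: AudioEncoderOutputType {}

// A hypothetical non-CoreML tensor can now travel through the pipeline
// without the AudioEncoding or FeatureExtracting protocols changing shape.
struct ExampleTensor: FeatureExtractorOutputType, AudioEncoderOutputType {
    var scalars: [Float]
}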
1 change: 0 additions & 1 deletion .github/workflows/development-tests.yml
@@ -2,7 +2,6 @@ name: Development Tests

on:
pull_request:
branches: ["main"]
pull_request_review:
types: [submitted]
workflow_dispatch:
11 changes: 11 additions & 0 deletions .github/workflows/unit-tests.yml
@@ -75,8 +75,19 @@ jobs:
sleep 15
xcrun simctl list devices
- name: Build and Test - ${{ matrix.run-config['name'] }}
id: test-step
if: ${{ matrix.run-config['condition'] == true }}
continue-on-error: true
run: |
set -o pipefail
xcodebuild clean build-for-testing -scheme whisperkit-Package -destination '${{ matrix.run-config['clean-destination'] }}' | xcpretty
xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination '${{ matrix.run-config['test-destination'] }}'
- name: Upload Test Results
if: failure() && steps.test-step.outcome == 'failure'
uses: actions/upload-artifact@v4
with:
name: test-results-${{ matrix.run-config['name'] }}
path: |
~/Library/Developer/Xcode/DerivedData/**/Logs/Test/*.xcresult
retention-days: 5
9 changes: 6 additions & 3 deletions Sources/WhisperKit/Core/Audio/AudioProcessor.swift
@@ -93,7 +93,7 @@ public extension AudioProcessing {
}

static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? {
guard startIndex >= 0 && startIndex < audioArray.count else {
guard startIndex >= 0, startIndex < audioArray.count else {
Logging.error("startIndex is outside the buffer size")
return nil
}
@@ -178,7 +178,6 @@ public class AudioProcessor: NSObject, AudioProcessing {
}

public var audioBufferCallback: (([Float]) -> Void)?
public var maxBufferLength = WhisperKit.sampleRate * WhisperKit.chunkLength // 30 seconds of audio at 16,000 Hz
public var minBufferLength = Int(Double(WhisperKit.sampleRate) * 0.1) // 0.1 second of audio at 16,000 Hz

// MARK: - Loading and conversion
@@ -229,7 +228,11 @@ public class AudioProcessor: NSObject, AudioProcessing {
guard let buffer = AVAudioPCMBuffer(pcmFormat: audioFile.processingFormat, frameCapacity: frameCount) else {
throw WhisperError.loadAudioFailed("Unable to create audio buffer")
}
try audioFile.read(into: buffer, frameCount: frameCount)
do {
try audioFile.read(into: buffer, frameCount: frameCount)
} catch {
throw WhisperError.loadAudioFailed("Failed to read audio file: \(error)")
}
outputBuffer = buffer
} else {
// Audio needs resampling to 16khz
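The read call above is now wrapped in do/catch so a CoreAudio failure surfaces as a WhisperError carrying the underlying error rather than an opaque OSStatus. The same pattern in isolation, as a sketch (the loadBuffer helper is hypothetical; WhisperError.loadAudioFailed is the case used in the diff, assumed in scope via WhisperKit):

import AVFoundation
import WhisperKit

// Hypothetical standalone helper mirroring the error wrapping above.
func loadBuffer(from audioFile: AVAudioFile, frameCount: AVAudioFrameCount) throws -> AVAudioPCMBuffer {
    guard let buffer = AVAudioPCMBuffer(pcmFormat: audioFile.processingFormat, frameCapacity: frameCount) else {
        throw WhisperError.loadAudioFailed("Unable to create audio buffer")
    }
    do {
        try audioFile.read(into: buffer, frameCount: frameCount)
    } catch {
        // Rethrow with context so callers can tell which stage failed.
        throw WhisperError.loadAudioFailed("Failed to read audio file: \(error)")
    }
    return buffer
}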
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift
@@ -117,7 +117,7 @@ open class VoiceActivityDetector {
}
}

// MARK - Utility
// MARK: - Utility

func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] {
let nonSilentChunks = calculateActiveChunks(in: waveform)
18 changes: 15 additions & 3 deletions Sources/WhisperKit/Core/AudioEncoder.swift
@@ -3,17 +3,22 @@

import CoreML

public protocol AudioEncoderOutputType {}
extension MLMultiArray: AudioEncoderOutputType {}

/// AudioEncoding protocol defines the requirements for an audio encoding implementation.
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
public protocol AudioEncoding {
/// The size of the embedding produced by the encoder.
var embedSize: Int? { get }

/// Encodes the given audio features asynchronously.
/// - Parameter features: The audio features to be encoded.
/// - Returns: An optional `MLMultiArray` containing the encoded features.
func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray?
/// - Returns: An optional tensor containing the encoded features.
func encodeFeatures(_ features: any FeatureExtractorOutputType) async throws -> (any AudioEncoderOutputType)?
}

/// Backwards-compatible AudioEncoder implementation
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
public class AudioEncoder: AudioEncoding, WhisperMLModel {
public var model: MLModel?
@@ -36,8 +41,15 @@ public class AudioEncoder: AudioEncoding, WhisperMLModel {

public init() {}

public func encodeFeatures(_ features: any FeatureExtractorOutputType) async throws -> (any AudioEncoderOutputType)? {
guard let features = features as? MLMultiArray else {
throw WhisperError.audioProcessingFailed("AudioEncoder input must be MLMultiArray")
}

return try await encodeFeatures(features)
}

public func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray? {
// Make sure features is shape MultiArray (Float32 1 × {80,128} × 3000)
guard let model else {
throw WhisperError.modelsUnavailable()
}
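Since encodeFeatures now accepts any FeatureExtractorOutputType and returns any AudioEncoderOutputType, an encoder that never touches CoreML can also conform. A hedged sketch under that assumption, with the protocol and error types assumed importable from WhisperKit; FloatTensor and PassthroughEncoder are hypothetical names, and the input guard mirrors AudioEncoder's downcast above:

import WhisperKit

// Hypothetical tensor type; only the marker conformances are required.
struct FloatTensor: FeatureExtractorOutputType, AudioEncoderOutputType {
    var scalars: [Float]
}

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class PassthroughEncoder: AudioEncoding {
    var embedSize: Int? { nil } // unknown: no CoreML model to inspect

    func encodeFeatures(_ features: any FeatureExtractorOutputType) async throws -> (any AudioEncoderOutputType)? {
        // Validate the type-erased input, as AudioEncoder does above.
        guard let features = features as? FloatTensor else {
            throw WhisperError.audioProcessingFailed("PassthroughEncoder input must be FloatTensor")
        }
        return features // identity "encoding", for illustration only
    }
}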
17 changes: 16 additions & 1 deletion Sources/WhisperKit/Core/FeatureExtractor.swift
@@ -7,9 +7,15 @@ import CoreGraphics
import CoreML
import Foundation

public protocol FeatureExtractorOutputType {}
extension MLMultiArray: FeatureExtractorOutputType {}

public protocol FeatureExtracting {
associatedtype OutputType: FeatureExtractorOutputType

var melCount: Int? { get }
func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray?
var windowSamples: Int? { get }
func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> OutputType?
}

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
Expand All @@ -26,6 +32,14 @@ open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
return shape[1]
}

public var windowSamples: Int? {
guard let inputDescription = model?.modelDescription.inputDescriptionsByName["audio"] else { return nil }
guard inputDescription.type == .multiArray else { return nil }
guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil }
let shape = shapeConstraint.shape.map { $0.intValue }
return shape[0] // The audio input is a 1D array
}

public func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray? {
guard let model else {
throw WhisperError.modelsUnavailable()
Expand All @@ -40,4 +54,5 @@ open class FeatureExtractor: FeatureExtracting, WhisperMLModel {
let output = MelSpectrogramOutput(features: outputFeatures)
return output.melspectrogramFeatures
}

}
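The new windowSamples property reads the expected audio length from the CoreML model's input shape, which is what lets the audio window come from model metadata rather than a hard-coded 30 seconds. A caller might consume it like this (a sketch, assuming import WhisperKit; the fallback constant is the defaultWindowSamples addition shown in Models.swift below):

import WhisperKit

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
func windowLength(for extractor: FeatureExtractor) -> Int {
    // Prefer the length declared by the model itself; fall back to the
    // Whisper default of 480_000 samples (30s at 16kHz) when unavailable.
    extractor.windowSamples ?? Constants.defaultWindowSamples
}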
12 changes: 11 additions & 1 deletion Sources/WhisperKit/Core/Models.swift
@@ -49,6 +49,7 @@ public extension WhisperMLModel {

// MARK: - Whisper Models

@frozen
public enum ModelVariant: CustomStringConvertible, CaseIterable {
case tiny
case tinyEn
@@ -100,6 +101,7 @@ public enum ModelVariant: CustomStringConvertible, CaseIterable {
}
}

@frozen
public enum ModelState: CustomStringConvertible {
case unloading
case unloaded
@@ -282,6 +284,7 @@ public struct AudioChunk {

// MARK: - Decoding

@frozen
public enum DecodingTask: Codable, CustomStringConvertible, CaseIterable {
case transcribe
case translate
@@ -296,7 +299,7 @@ public enum DecodingTask: Codable, CustomStringConvertible, CaseIterable {
}
}

public struct DecodingInputs {
open class DecodingInputs {
public var initialPrompt: [Int]
public var inputIds: MLMultiArray
public var cacheLength: MLMultiArray
@@ -355,6 +358,7 @@ public struct DecodingCache {
}
}

@frozen
public enum ChunkingStrategy: String, Codable, CaseIterable {
case none
case vad
@@ -444,6 +448,7 @@ public struct DecodingResult {
}
}

@frozen
public enum WhisperError: Error, LocalizedError, Equatable {
case tokenizerUnavailable(String = "Tokenizer is unavailable")
case modelsUnavailable(String = "Models are unavailable")
@@ -575,6 +580,7 @@ public struct TranscriptionResult: Codable {
Total Tokens: \(totalTokens)
Tokens per Second: \(String(format: "%.2f", tokensPerSecond)) tok/s
Real Time Factor: \(String(format: "%.3f", rtf))
Speed Factor: \(String(format: "%.3f", 1.0 / rtf))
Fallbacks: \(timings.totalDecodingFallbacks)
""")
}
@@ -647,6 +653,7 @@ public typealias ModelStateCallback = (_ oldState: ModelState?, _ newState: Mode
public typealias TranscriptionStateCallback = (_ state: TranscriptionState) -> Void

/// Represents the different states of the transcription process.
@frozen
public enum TranscriptionState: CustomStringConvertible {
/// The audio is being converted to the required format for transcription
case convertingAudio
@@ -1372,6 +1379,7 @@ extension WhisperTokenizerWrapper {

// MARK: Constants

@frozen
public enum Constants {
enum Logging {
static let subsystem = "com.argmax.whisperkit"
@@ -1502,6 +1510,8 @@ public enum Constants {

public static let defaultAudioReadFrameSize: AVAudioFrameCount = 1_323_000 // 30s of audio at commonly found 44.1khz sample rate

public static let defaultWindowSamples: Int = 480_000 // 30s of audio at 16khz sample rate default for Whisper models

public static let fallbackModelSupportConfig: ModelSupportConfig = {
var config = ModelSupportConfig(
repoName: "whisperkit-coreml-fallback",
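The Speed Factor line added to the timing report is the reciprocal of the real-time factor, which reads more naturally: an RTF of 0.25 means transcription ran at 4x real time. A toy calculation, assuming RTF is processing time divided by audio duration as in the rest of the report:

import Foundation

let processingTime = 7.5  // seconds spent transcribing (example values)
let audioDuration = 30.0  // seconds of input audio
let rtf = processingTime / audioDuration  // 0.250
let speedFactor = 1.0 / rtf               // 4.000x real time
print(String(format: "Real Time Factor: %.3f, Speed Factor: %.3f", rtf, speedFactor))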
3 changes: 2 additions & 1 deletion Sources/WhisperKit/Core/Text/LogitsFilter.swift
@@ -75,8 +75,9 @@ open class TimestampRulesFilter: LogitsFiltering {

public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray {
guard let sampleBegin = sampleBegin(for: tokens),
sampleBegin > tokens.count
sampleBegin <= tokens.count
else {
// Early return if we are still prefilling the prompt
return logits
}

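The corrected guard passes the logits through untouched only while the prompt is still being prefilled, i.e. while fewer than sampleBegin tokens have been decoded; the previous comparison (sampleBegin > tokens.count) had the condition inverted, so the rules fired during prefill and were skipped during decoding. A sketch of the intended predicate (shouldApplyTimestampRules is a hypothetical helper):

// Hypothetical helper mirroring the corrected guard above.
func shouldApplyTimestampRules(sampleBegin: Int, decodedTokenCount: Int) -> Bool {
    // During prompt prefill the decoder has not reached sampleBegin yet,
    // so the timestamp rules must not fire.
    sampleBegin <= decodedTokenCount
}

// shouldApplyTimestampRules(sampleBegin: 3, decodedTokenCount: 2) == false  (still prefilling)
// shouldApplyTimestampRules(sampleBegin: 3, decodedTokenCount: 5) == true   (decoding)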