Skip to content

Commit

Permalink
Improving modularity and code structure (#212)
Browse files Browse the repository at this point in the history
* CI fetch depth 0

* VAD refactoring

* Update logo

* Add WhisperKitConfig

* Open whisperkit methods

* add missing @available

---------

Co-authored-by: BlaiseMuhirwa <[email protected]>
Co-authored-by: ZachNagengast <[email protected]>
  • Loading branch information
3 people authored Oct 2, 2024
1 parent 3cd3ef1 commit c2f1b57
Show file tree
Hide file tree
Showing 22 changed files with 371 additions and 263 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/development-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ jobs:
reviews: ${{ steps.reviews.outputs.state }}
permissions:
pull-requests: read
contents: read
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Check Approvals
id: reviews
env:
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/expo-update.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ jobs:
uses: actions/checkout@v4
with:
repository: seb-sep/whisper-kit-expo
fetch-depth: 0
token: ${{ secrets.COMMITTER_TOKEN }}
ref: main

Expand Down
15 changes: 7 additions & 8 deletions Examples/WhisperAX/WhisperAX/Views/ContentView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -990,14 +990,13 @@ struct ContentView: View {

whisperKit = nil
Task {
whisperKit = try await WhisperKit(
computeOptions: getComputeOptions(),
verbose: true,
logLevel: .debug,
prewarm: false,
load: false,
download: false
)
let config = WhisperKitConfig(computeOptions: getComputeOptions(),
verbose: true,
logLevel: .debug,
prewarm: false,
load: false,
download: false)
whisperKit = try await WhisperKit(config)
guard let whisperKit = whisperKit else {
return
}
Expand Down
14 changes: 7 additions & 7 deletions Examples/WhisperAX/WhisperAXWatchApp/WhisperAXExampleView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -340,13 +340,13 @@ struct WhisperAXWatchView: View {

whisperKit = nil
Task {
whisperKit = try await WhisperKit(
verbose: true,
logLevel: .debug,
prewarm: false,
load: false,
download: false
)
let config = WhisperKitConfig(verbose: true,
logLevel: .debug,
prewarm: false,
load: false,
download: false)

whisperKit = try await WhisperKit(config)
guard let whisperKit = whisperKit else {
return
}
Expand Down
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
<div align="center">

<a href="https://github.com/argmaxinc/WhisperKit#gh-light-mode-only">
<img src="https://github.com/argmaxinc/WhisperKit/assets/1981179/6ac3360b-2f5c-4392-a71a-05c5dda71093" alt="WhisperKit" width="20%" />
<img src="https://github.com/user-attachments/assets/f0699c07-c29f-45b6-a9c6-f6d491b8f791" alt="WhisperKit" width="20%" />
</a>

<a href="https://github.com/argmaxinc/WhisperKit#gh-dark-mode-only">
<img src="https://github.com/argmaxinc/WhisperKit/assets/1981179/a682ce21-80e0-4a98-a99f-836663538a4f" alt="WhisperKit" width="20%" />
<img src="https://github.com/user-attachments/assets/1be5e31c-de42-40ab-9b85-790cb911ed47" alt="WhisperKit" width="20%" />
</a>

# WhisperKit
Expand Down Expand Up @@ -92,13 +92,13 @@ Task {
WhisperKit automatically downloads the recommended model for the device if not specified. You can also select a specific model by passing in the model name:

```swift
let pipe = try? await WhisperKit(model: "large-v3")
let pipe = try? await WhisperKit(WhisperKitConfig(model: "large-v3"))
```

This method also supports glob search, so you can use wildcards to select a model:

```swift
let pipe = try? await WhisperKit(model: "distil*large-v3")
let pipe = try? await WhisperKit(WhisperKitConfig(model: "distil*large-v3"))
```

Note that the model search must return a single model from the source repo, otherwise an error will be thrown.
Expand All @@ -110,7 +110,8 @@ For a list of available models, see our [HuggingFace repo](https://huggingface.c
WhisperKit also comes with the supporting repo [`whisperkittools`](https://github.com/argmaxinc/whisperkittools) which lets you create and deploy your own fine tuned versions of Whisper in CoreML format to HuggingFace. Once generated, they can be loaded by simply changing the repo name to the one used to upload the model:

```swift
let pipe = try? await WhisperKit(model: "large-v3", modelRepo: "username/your-model-repo")
let config = WhisperKitConfig(model: "large-v3", modelRepo: "username/your-model-repo")
let pipe = try? await WhisperKit(config)
```

### Swift CLI
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ open class VADAudioChunker: AudioChunking {
private let windowPadding: Int
private let vad: VoiceActivityDetector

init(windowPadding: Int = 16000, vad: VoiceActivityDetector = EnergyVAD()) {
public init(windowPadding: Int = 16000, vad: VoiceActivityDetector? = nil) {
self.windowPadding = windowPadding
self.vad = vad
self.vad = vad ?? EnergyVAD()
}

private func splitOnMiddleOfLongestSilence(audioArray: [Float], startIndex: Int, endIndex: Int) -> Int {
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@ import Foundation
/// A base class for Voice Activity Detection (VAD), used to identify and separate segments of audio that contain human speech from those that do not.
/// Subclasses must implement the `voiceActivity(in:)` method to provide specific voice activity detection functionality.
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
class VoiceActivityDetector {
open class VoiceActivityDetector {
/// The sample rate of the audio signal, in samples per second.
var sampleRate: Int
public let sampleRate: Int

/// The length of each frame in samples.
var frameLengthSamples: Int
public let frameLengthSamples: Int

// The number of samples overlapping between consecutive frames.
var frameOverlapSamples: Int
/// The number of samples overlapping between consecutive frames.
public let frameOverlapSamples: Int

/// Initializes a new `VoiceActivityDetector` instance with the specified parameters.
/// - Parameters:
/// - sampleRate: The sample rate of the audio signal in samples per second. Defaults to 16000.
/// - frameLengthSamples: The length of each frame in samples.
/// - frameOverlapSamples: The number of samples overlapping between consecutive frames. Defaults to 0.
/// - Note: Subclasses should override the `voiceActivity(in:)` method to provide specific VAD functionality.
init(
public init(
sampleRate: Int = 16000,
frameLengthSamples: Int,
frameOverlapSamples: Int = 0
Expand All @@ -35,14 +35,14 @@ class VoiceActivityDetector {
/// Analyzes the provided audio waveform to determine which segments contain voice activity.
/// - Parameter waveform: An array of `Float` values representing the audio waveform.
/// - Returns: An array of `Bool` values where `true` indicates the presence of voice activity and `false` indicates silence.
func voiceActivity(in waveform: [Float]) -> [Bool] {
open func voiceActivity(in waveform: [Float]) -> [Bool] {
fatalError("`voiceActivity` must be implemented by subclass")
}

/// Calculates and returns a list of active audio chunks, each represented by a start and end index.
/// - Parameter waveform: An array of `Float` values representing the audio waveform.
/// - Returns: An array of tuples where each tuple contains the start and end indices of an active audio chunk.
func calculateActiveChunks(in waveform: [Float]) -> [(startIndex: Int, endIndex: Int)] {
public func calculateActiveChunks(in waveform: [Float]) -> [(startIndex: Int, endIndex: Int)] {
let vad: [Bool] = voiceActivity(in: waveform)
var result = [(startIndex: Int, endIndex: Int)]()

Expand Down Expand Up @@ -74,18 +74,18 @@ class VoiceActivityDetector {
/// Converts a voice activity index to the corresponding audio sample index.
/// - Parameter index: The voice activity index to convert.
/// - Returns: The corresponding audio sample index.
func voiceActivityIndexToAudioSampleIndex(_ index: Int) -> Int {
public func voiceActivityIndexToAudioSampleIndex(_ index: Int) -> Int {
return index * frameLengthSamples
}

func voiceActivityIndexToSeconds(_ index: Int) -> Float {
public func voiceActivityIndexToSeconds(_ index: Int) -> Float {
return Float(voiceActivityIndexToAudioSampleIndex(index)) / Float(sampleRate)
}

/// Identifies the longest continuous period of silence within the provided voice activity detection results.
/// - Parameter vadResult: An array of `Bool` values representing voice activity detection results.
/// - Returns: A tuple containing the start and end indices of the longest silence period, or `nil` if no silence is found.
func findLongestSilence(in vadResult: [Bool]) -> (startIndex: Int, endIndex: Int)? {
public func findLongestSilence(in vadResult: [Bool]) -> (startIndex: Int, endIndex: Int)? {
var longestStartIndex: Int?
var longestEndIndex: Int?
var longestCount = 0
Expand Down
205 changes: 205 additions & 0 deletions Sources/WhisperKit/Core/Configurations.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2024 Argmax, Inc. All rights reserved.

import Foundation

/// Configuration to initialize WhisperKit
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
open class WhisperKitConfig {
/// Name for whisper model to use
public var model: String?
/// Base URL for downloading models
public var downloadBase: URL?
/// Repository for downloading models
public var modelRepo: String?

/// Folder to store models
public var modelFolder: String?
/// Folder to store tokenizers
public var tokenizerFolder: URL?

/// Model compute options, see `ModelComputeOptions`
public var computeOptions: ModelComputeOptions?
/// Audio processor for the model
public var audioProcessor: (any AudioProcessing)?
/// Audio processor for the model
public var featureExtractor: (any FeatureExtracting)?
public var audioEncoder: (any AudioEncoding)?
public var textDecoder: (any TextDecoding)?
public var logitsFilters: [any LogitsFiltering]?
public var segmentSeeker: (any SegmentSeeking)?

/// Enable extra verbosity for logging
public var verbose: Bool
/// Maximum log level
public var logLevel: Logging.LogLevel

/// Enable model prewarming
public var prewarm: Bool?
/// Load models if available
public var load: Bool?
/// Download models if not available
public var download: Bool
/// Use background download session
public var useBackgroundDownloadSession: Bool

public init(model: String? = nil,
downloadBase: URL? = nil,
modelRepo: String? = nil,
modelFolder: String? = nil,
tokenizerFolder: URL? = nil,
computeOptions: ModelComputeOptions? = nil,
audioProcessor: (any AudioProcessing)? = nil,
featureExtractor: (any FeatureExtracting)? = nil,
audioEncoder: (any AudioEncoding)? = nil,
textDecoder: (any TextDecoding)? = nil,
logitsFilters: [any LogitsFiltering]? = nil,
segmentSeeker: (any SegmentSeeking)? = nil,
verbose: Bool = true,
logLevel: Logging.LogLevel = .info,
prewarm: Bool? = nil,
load: Bool? = nil,
download: Bool = true,
useBackgroundDownloadSession: Bool = false
) {
self.model = model
self.downloadBase = downloadBase
self.modelRepo = modelRepo
self.modelFolder = modelFolder
self.tokenizerFolder = tokenizerFolder
self.computeOptions = computeOptions
self.audioProcessor = audioProcessor
self.featureExtractor = featureExtractor
self.audioEncoder = audioEncoder
self.textDecoder = textDecoder
self.logitsFilters = logitsFilters
self.segmentSeeker = segmentSeeker
self.verbose = verbose
self.logLevel = logLevel
self.prewarm = prewarm
self.load = load
self.download = download
self.useBackgroundDownloadSession = useBackgroundDownloadSession
}
}


/// Options for how to transcribe an audio file using WhisperKit.
///
/// - Parameters:
/// - verbose: Whether to display the text being decoded to the console.
/// If true, displays all details; if false, displays minimal details;
/// - task: Whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')
/// - language: Language spoken in the audio
/// - temperature: Temperature to use for sampling.
/// - temperatureIncrementOnFallback: Increment which will be
/// successively added to temperature upon failures according to either `compressionRatioThreshold`
/// or `logProbThreshold`.
/// - temperatureFallbackCount: Number of times to increment temperature on fallback.
/// - sampleLength: The maximum number of tokens to sample.
/// - topK: Number of candidates when sampling with non-zero temperature.
/// - usePrefillPrompt: If true, the prefill tokens will be forced according to task and language settings.
/// - usePrefillCache: If true, the kv cache will be prefilled based on the prefill data mlmodel.
/// - detectLanguage: Use this in conjuntion with `usePrefillPrompt: true` to detect the language of the input audio.
/// - skipSpecialTokens: Whether to skip special tokens in the output.
/// - withoutTimestamps: Whether to include timestamps in the transcription result.
/// - wordTimestamps: Whether to include word-level timestamps in the transcription result.
/// - maxInitialTimestamp: Maximal initial timestamp.
/// - clipTimestamps: Array of timestamps (in seconds) to split the audio into segments for transcription.
/// - promptTokens: Array of token IDs to use as the conditioning prompt for the decoder. These are prepended to the prefill tokens.
/// - prefixTokens: Array of token IDs to use as the initial prefix for the decoder. These are appended to the prefill tokens.
/// - suppressBlank: If true, blank tokens will be suppressed during decoding.
/// - supressTokens: List of token IDs to suppress during decoding.
/// - compressionRatioThreshold: If the compression ratio of the transcription text is above this value, it is too repetitive and treated as failed.
/// - logProbThreshold: If the average log probability over sampled tokens is below this value, treat as failed.
/// - firstTokenLogProbThreshold: If the log probability over the first sampled token is below this value, treat as failed.
/// - noSpeechThreshold: If the no speech probability is higher than this value AND the average log
/// probability over sampled tokens is below `logProbThreshold`, consider the segment as silent.
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
public struct DecodingOptions {
public var verbose: Bool
public var task: DecodingTask
public var language: String?
public var temperature: Float
public var temperatureIncrementOnFallback: Float
public var temperatureFallbackCount: Int
public var sampleLength: Int
public var topK: Int
public var usePrefillPrompt: Bool
public var usePrefillCache: Bool
public var detectLanguage: Bool
public var skipSpecialTokens: Bool
public var withoutTimestamps: Bool
public var wordTimestamps: Bool
public var maxInitialTimestamp: Float?
public var clipTimestamps: [Float]
public var promptTokens: [Int]?
public var prefixTokens: [Int]?
public var suppressBlank: Bool
public var supressTokens: [Int]
public var compressionRatioThreshold: Float?
public var logProbThreshold: Float?
public var firstTokenLogProbThreshold: Float?
public var noSpeechThreshold: Float?
public var concurrentWorkerCount: Int
public var chunkingStrategy: ChunkingStrategy?
public var voiceActivityDetector: VoiceActivityDetector?

public init(
verbose: Bool = false,
task: DecodingTask = .transcribe,
language: String? = nil,
temperature: Float = 0.0,
temperatureIncrementOnFallback: Float = 0.2,
temperatureFallbackCount: Int = 5,
sampleLength: Int = Constants.maxTokenContext,
topK: Int = 5,
usePrefillPrompt: Bool = true,
usePrefillCache: Bool = true,
detectLanguage: Bool? = nil,
skipSpecialTokens: Bool = false,
withoutTimestamps: Bool = false,
wordTimestamps: Bool = false,
maxInitialTimestamp: Float? = nil,
clipTimestamps: [Float] = [],
promptTokens: [Int]? = nil,
prefixTokens: [Int]? = nil,
suppressBlank: Bool = false,
supressTokens: [Int]? = nil,
compressionRatioThreshold: Float? = 2.4,
logProbThreshold: Float? = -1.0,
firstTokenLogProbThreshold: Float? = -1.5,
noSpeechThreshold: Float? = 0.6,
concurrentWorkerCount: Int = 16,
chunkingStrategy: ChunkingStrategy? = nil,
voiceActivityDetector: VoiceActivityDetector? = nil
) {
self.verbose = verbose
self.task = task
self.language = language
self.temperature = temperature
self.temperatureIncrementOnFallback = temperatureIncrementOnFallback
self.temperatureFallbackCount = temperatureFallbackCount
self.sampleLength = sampleLength
self.topK = topK
self.usePrefillPrompt = usePrefillPrompt
self.usePrefillCache = usePrefillCache
self.detectLanguage = detectLanguage ?? !usePrefillPrompt // If prefill is false, detect language by default
self.skipSpecialTokens = skipSpecialTokens
self.withoutTimestamps = withoutTimestamps
self.wordTimestamps = wordTimestamps
self.maxInitialTimestamp = maxInitialTimestamp
self.clipTimestamps = clipTimestamps
self.promptTokens = promptTokens
self.prefixTokens = prefixTokens
self.suppressBlank = suppressBlank
self.supressTokens = supressTokens ?? [] // nonSpeechTokens() // TODO: implement these as default
self.compressionRatioThreshold = compressionRatioThreshold
self.logProbThreshold = logProbThreshold
self.firstTokenLogProbThreshold = firstTokenLogProbThreshold
self.noSpeechThreshold = noSpeechThreshold
self.concurrentWorkerCount = concurrentWorkerCount
self.chunkingStrategy = chunkingStrategy
self.voiceActivityDetector = voiceActivityDetector
}
}
Loading

0 comments on commit c2f1b57

Please sign in to comment.