Add VoiceActivityDetector base class (#199)
* Add VoiceActivityDetector base class

Add base class to allow different VAD implementations

* fix spaces
Andrey Leonov authored Sep 5, 2024
1 parent 59aaa4e commit c03017f
Showing 3 changed files with 126 additions and 79 deletions.
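In short, the commit turns the concrete `EnergyVAD` into a subclass of a new `VoiceActivityDetector` base class, so alternative detectors can be injected into `VADAudioChunker`. As a rough sketch of the new extension point (the `AlwaysOnVAD` name and body below are hypothetical, not part of the commit), a custom detector only needs to override `voiceActivity(in:)`:

```swift
import Foundation

// Hypothetical subclass to illustrate the contract: mark every frame as speech.
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class AlwaysOnVAD: VoiceActivityDetector {
    override func voiceActivity(in waveform: [Float]) -> [Bool] {
        // One Bool per frame of `frameLengthSamples` samples, rounding up
        // so a trailing partial frame still gets a result.
        let frameCount = Int((Double(waveform.count) / Double(frameLengthSamples)).rounded(.up))
        return Array(repeating: true, count: frameCount)
    }
}
```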
5 changes: 3 additions & 2 deletions Sources/WhisperKit/Core/AudioChunker.swift
@@ -46,10 +46,11 @@ public extension AudioChunking
 open class VADAudioChunker: AudioChunking {
     /// prevent hallucinations at the end of the clip by stopping up to 1.0s early
     private let windowPadding: Int
-    private let vad = EnergyVAD()
+    private let vad: VoiceActivityDetector
 
-    init(windowPadding: Int = 16000) {
+    init(windowPadding: Int = 16000, vad: VoiceActivityDetector = EnergyVAD()) {
         self.windowPadding = windowPadding
+        self.vad = vad
     }
 
     private func splitOnMiddleOfLongestSilence(audioArray: [Float], startIndex: Int, endIndex: Int) -> Int {
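With the default argument `vad: VoiceActivityDetector = EnergyVAD()`, existing call sites keep their behavior; injecting a detector becomes a one-line change. A usage sketch, assuming the chunker and detector are visible from the call site:

```swift
// Default behavior, unchanged from before the commit:
let defaultChunker = VADAudioChunker()

// Injecting a tuned detector (0.05 is an arbitrary example threshold):
let sensitiveChunker = VADAudioChunker(vad: EnergyVAD(energyThreshold: 0.05))
```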
58 changes: 58 additions & 0 deletions Sources/WhisperKit/Core/VAD/EnergyVAD.swift
@@ -0,0 +1,58 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright © 2024 Argmax, Inc. All rights reserved.
+
+import Foundation
+
+/// Voice activity detection based on energy threshold
+@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
+final class EnergyVAD: VoiceActivityDetector {
+    var energyThreshold: Float
+
+    /// Initialize a new EnergyVAD instance
+    /// - Parameters:
+    ///   - sampleRate: Audio sample rate
+    ///   - frameLength: Frame length in seconds
+    ///   - frameOverlap: Frame overlap in seconds; this includes an extra `frameOverlap` of audio in each `frameLength` window, which helps catch speech that starts exactly at chunk boundaries
+    ///   - energyThreshold: Minimum energy threshold
+    convenience init(
+        sampleRate: Int = WhisperKit.sampleRate,
+        frameLength: Float = 0.1,
+        frameOverlap: Float = 0.0,
+        energyThreshold: Float = 0.02
+    ) {
+        self.init(
+            sampleRate: sampleRate,
+            // Compute frame length and overlap in number of samples
+            frameLengthSamples: Int(frameLength * Float(sampleRate)),
+            frameOverlapSamples: Int(frameOverlap * Float(sampleRate)),
+            energyThreshold: energyThreshold
+        )
+    }
+
+    required init(
+        sampleRate: Int = 16000,
+        frameLengthSamples: Int,
+        frameOverlapSamples: Int = 0,
+        energyThreshold: Float = 0.02
+    ) {
+        self.energyThreshold = energyThreshold
+        super.init(sampleRate: sampleRate, frameLengthSamples: frameLengthSamples, frameOverlapSamples: frameOverlapSamples)
+    }
+
+    override func voiceActivity(in waveform: [Float]) -> [Bool] {
+        let chunkRatio = Double(waveform.count) / Double(frameLengthSamples)
+
+        // Round up if uneven; the final chunk will not be a full `frameLengthSamples` long
+        let count = Int(chunkRatio.rounded(.up))
+
+        let chunkedVoiceActivity = AudioProcessor.calculateVoiceActivityInChunks(
+            of: waveform,
+            chunkCount: count,
+            frameLengthSamples: frameLengthSamples,
+            frameOverlapSamples: frameOverlapSamples,
+            energyThreshold: energyThreshold
+        )
+
+        return chunkedVoiceActivity
+    }
+}
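The actual frame scoring is delegated to `AudioProcessor.calculateVoiceActivityInChunks`, which is not part of this diff. For intuition, an energy VAD of this shape computes per-frame signal energy and compares it to the threshold; a minimal stand-in might look like the following (an assumed equivalent for illustration, not the `AudioProcessor` implementation):

```swift
import Accelerate

// Assumed equivalent of the per-frame energy check, for illustration only.
func energyVoiceActivity(_ waveform: [Float], frameLengthSamples: Int,
                         frameOverlapSamples: Int, energyThreshold: Float) -> [Bool] {
    let frameCount = Int((Double(waveform.count) / Double(frameLengthSamples)).rounded(.up))
    return (0..<frameCount).map { frame in
        let start = frame * frameLengthSamples
        // Overlap extends each window past its nominal end to catch boundary speech.
        let end = min(start + frameLengthSamples + frameOverlapSamples, waveform.count)
        // Root-mean-square energy of the window
        var energy: Float = 0
        vDSP_rmsqv(Array(waveform[start..<end]), 1, &energy, vDSP_Length(end - start))
        return energy > energyThreshold
    }
}
```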
142 changes: 65 additions & 77 deletions Sources/WhisperKit/Core/{EnergyVAD.swift → VAD/VoiceActivityDetector.swift}
@@ -1,67 +1,47 @@
 // For licensing see accompanying LICENSE.md file.
 // Copyright © 2024 Argmax, Inc. All rights reserved.
 
-import Accelerate
 import Foundation
 
-/// Voice activity detection based on energy threshold
+/// A base class for Voice Activity Detection (VAD), used to identify and separate segments of audio that contain human speech from those that do not.
+/// Subclasses must implement the `voiceActivity(in:)` method to provide specific voice activity detection functionality.
 @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
-final class EnergyVAD {
+class VoiceActivityDetector {
+    /// The sample rate of the audio signal, in samples per second.
     var sampleRate: Int
+
+    /// The length of each frame in samples.
     var frameLengthSamples: Int
+
+    /// The number of samples overlapping between consecutive frames.
     var frameOverlapSamples: Int
-    var energyThreshold: Float
 
-    /// Initialize a new EnergyVAD instance
+    /// Initializes a new `VoiceActivityDetector` instance with the specified parameters.
     /// - Parameters:
-    ///   - sampleRate: Audio sample rate
-    ///   - frameLength: Frame length in seconds
-    ///   - frameOverlap: frame overlap in seconds, this will include `frameOverlap` length audio into the `frameLength` and is helpful to catch audio that starts exactly at chunk boundaries
-    ///   - energyThreshold: minimal energy threshold
-    convenience init(
-        sampleRate: Int = WhisperKit.sampleRate,
-        frameLength: Float = 0.1,
-        frameOverlap: Float = 0.0,
-        energyThreshold: Float = 0.02
-    ) {
-        self.init(
-            sampleRate: sampleRate,
-            // Compute frame length and overlap in number of samples
-            frameLengthSamples: Int(frameLength * Float(sampleRate)),
-            frameOverlapSamples: Int(frameOverlap * Float(sampleRate)),
-            energyThreshold: energyThreshold
-        )
-    }
-
-    required init(
+    ///   - sampleRate: The sample rate of the audio signal in samples per second. Defaults to 16000.
+    ///   - frameLengthSamples: The length of each frame in samples.
+    ///   - frameOverlapSamples: The number of samples overlapping between consecutive frames. Defaults to 0.
+    /// - Note: Subclasses should override the `voiceActivity(in:)` method to provide specific VAD functionality.
+    init(
         sampleRate: Int = 16000,
         frameLengthSamples: Int,
-        frameOverlapSamples: Int = 0,
-        energyThreshold: Float = 0.02
+        frameOverlapSamples: Int = 0
     ) {
         self.sampleRate = sampleRate
         self.frameLengthSamples = frameLengthSamples
         self.frameOverlapSamples = frameOverlapSamples
-        self.energyThreshold = energyThreshold
     }
 
+    /// Analyzes the provided audio waveform to determine which segments contain voice activity.
+    /// - Parameter waveform: An array of `Float` values representing the audio waveform.
+    /// - Returns: An array of `Bool` values where `true` indicates the presence of voice activity and `false` indicates silence.
     func voiceActivity(in waveform: [Float]) -> [Bool] {
-        let chunkRatio = Double(waveform.count) / Double(frameLengthSamples)
-
-        // Round up if uneven, the final chunk will not be a full `frameLengthSamples` long
-        let count = Int(chunkRatio.rounded(.up))
-
-        let chunkedVoiceActivity = AudioProcessor.calculateVoiceActivityInChunks(
-            of: waveform,
-            chunkCount: count,
-            frameLengthSamples: frameLengthSamples,
-            frameOverlapSamples: frameOverlapSamples,
-            energyThreshold: energyThreshold
-        )
-
-        return chunkedVoiceActivity
+        fatalError("`voiceActivity` must be implemented by subclass")
     }
 
+    /// Calculates and returns a list of active audio chunks, each represented by a start and end index.
+    /// - Parameter waveform: An array of `Float` values representing the audio waveform.
+    /// - Returns: An array of tuples where each tuple contains the start and end indices of an active audio chunk.
     func calculateActiveChunks(in waveform: [Float]) -> [(startIndex: Int, endIndex: Int)] {
         let vad: [Bool] = voiceActivity(in: waveform)
         var result = [(startIndex: Int, endIndex: Int)]()
@@ -91,41 +71,9 @@ final class EnergyVAD {
         return result
     }
 
-    func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] {
-        let nonSilentChunks = calculateActiveChunks(in: waveform)
-        var clipTimestamps = [Float]()
-
-        for chunk in nonSilentChunks {
-            let startTimestamp = Float(chunk.startIndex) / Float(sampleRate)
-            let endTimestamp = Float(chunk.endIndex) / Float(sampleRate)
-
-            clipTimestamps.append(contentsOf: [startTimestamp, endTimestamp])
-        }
-
-        return clipTimestamps
-    }
-
-    func calculateNonSilentSeekClips(in waveform: [Float]) -> [(start: Int, end: Int)] {
-        let clipTimestamps = voiceActivityClipTimestamps(in: waveform)
-        let options = DecodingOptions(clipTimestamps: clipTimestamps)
-        let seekClips = prepareSeekClips(contentFrames: waveform.count, decodeOptions: options)
-        return seekClips
-    }
-
-    func calculateSeekTimestamps(in waveform: [Float]) -> [(startTime: Float, endTime: Float)] {
-        let nonSilentChunks = calculateActiveChunks(in: waveform)
-        var seekTimestamps = [(startTime: Float, endTime: Float)]()
-
-        for chunk in nonSilentChunks {
-            let startTimestamp = Float(chunk.startIndex) / Float(sampleRate)
-            let endTimestamp = Float(chunk.endIndex) / Float(sampleRate)
-
-            seekTimestamps.append(contentsOf: [(startTime: startTimestamp, endTime: endTimestamp)])
-        }
-
-        return seekTimestamps
-    }
-
+    /// Converts a voice activity index to the corresponding audio sample index.
+    /// - Parameter index: The voice activity index to convert.
+    /// - Returns: The corresponding audio sample index.
     func voiceActivityIndexToAudioSampleIndex(_ index: Int) -> Int {
         return index * frameLengthSamples
     }
@@ -134,6 +82,9 @@ final class EnergyVAD {
         return Float(voiceActivityIndexToAudioSampleIndex(index)) / Float(sampleRate)
     }
 
+    /// Identifies the longest continuous period of silence within the provided voice activity detection results.
+    /// - Parameter vadResult: An array of `Bool` values representing voice activity detection results.
+    /// - Returns: A tuple containing the start and end indices of the longest silence period, or `nil` if no silence is found.
     func findLongestSilence(in vadResult: [Bool]) -> (startIndex: Int, endIndex: Int)? {
         var longestStartIndex: Int?
         var longestEndIndex: Int?
@@ -165,4 +116,41 @@ final class EnergyVAD {
             return nil
         }
     }
+
+    // MARK: - Utility
+
+    func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] {
+        let nonSilentChunks = calculateActiveChunks(in: waveform)
+        var clipTimestamps = [Float]()
+
+        for chunk in nonSilentChunks {
+            let startTimestamp = Float(chunk.startIndex) / Float(sampleRate)
+            let endTimestamp = Float(chunk.endIndex) / Float(sampleRate)
+
+            clipTimestamps.append(contentsOf: [startTimestamp, endTimestamp])
+        }
+
+        return clipTimestamps
+    }
+
+    func calculateNonSilentSeekClips(in waveform: [Float]) -> [(start: Int, end: Int)] {
+        let clipTimestamps = voiceActivityClipTimestamps(in: waveform)
+        let options = DecodingOptions(clipTimestamps: clipTimestamps)
+        let seekClips = prepareSeekClips(contentFrames: waveform.count, decodeOptions: options)
+        return seekClips
+    }
+
+    func calculateSeekTimestamps(in waveform: [Float]) -> [(startTime: Float, endTime: Float)] {
+        let nonSilentChunks = calculateActiveChunks(in: waveform)
+        var seekTimestamps = [(startTime: Float, endTime: Float)]()
+
+        for chunk in nonSilentChunks {
+            let startTimestamp = Float(chunk.startIndex) / Float(sampleRate)
+            let endTimestamp = Float(chunk.endIndex) / Float(sampleRate)
+
+            seekTimestamps.append(contentsOf: [(startTime: startTimestamp, endTime: endTimestamp)])
+        }
+
+        return seekTimestamps
+    }
 }
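The utility methods make the frame/sample/time bookkeeping concrete: a VAD frame index maps to sample `index * frameLengthSamples`, and a sample maps to time `sample / sampleRate`. A usage sketch with `EnergyVAD`'s defaults (0.1 s frames at 16 kHz, so `frameLengthSamples == 1600`):

```swift
// Usage sketch with default EnergyVAD parameters (16 kHz, 0.1 s frames).
let vad = EnergyVAD()
let waveform: [Float] = Array(repeating: 0, count: 32_000)  // 2 s of silence

// Frame index 25 → sample 25 * 1600 = 40_000 → 40_000 / 16_000 = 2.5 s
let sampleIndex = vad.voiceActivityIndexToAudioSampleIndex(25)
let seconds = Float(sampleIndex) / Float(vad.sampleRate)

// Flat [start, end, start, end, ...] clip timestamps in seconds,
// suitable for DecodingOptions(clipTimestamps:)
let clips = vad.voiceActivityClipTimestamps(in: waveform)
```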
