Add public callbacks to help expose internal state a little more #240

Merged
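
For reference, a minimal consumer-side sketch of wiring up the three callbacks added here, patterned after the testCallbacks unit test at the bottom of this diff. The audio path is a placeholder, and `load: false` simply defers model loading until after the callbacks are set, so the model-state transitions are observable:

    import WhisperKit

    // Defer model loading so the state callback can observe the transitions (as in the unit test).
    let whisperKit = try await WhisperKit(WhisperKitConfig(load: false))

    whisperKit.modelStateCallback = { oldState, newState in
        // Fires from the modelState didSet observer on every transition.
        print("Model state: \(String(describing: oldState)) -> \(newState)")
    }

    whisperKit.segmentDiscoveryCallback = { segments in
        // Delivers segments as TranscribeTask discovers them during decoding.
        for segment in segments {
            print("Discovered: \(segment.text)")
        }
    }

    whisperKit.transcriptionStateCallback = { state in
        // Tracks the pipeline: .convertingAudio -> .transcribing -> .finished.
        print("Transcription state: \(state)")
    }

    try await whisperKit.loadModels()
    // Placeholder path; substitute a real audio file.
    _ = try await whisperKit.transcribe(audioPath: "path/to/audio.wav")

Because the callbacks are optional properties that default to nil, existing callers are unaffected; opting in is just a matter of assigning a closure.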
41 changes: 41 additions & 0 deletions Sources/WhisperKit/Core/Models.swift
@@ -629,6 +629,47 @@ public struct TranscriptionProgress {
}
}

// Callbacks to receive state updates during transcription.

/// A callback that provides transcription segments as they are discovered.
/// - Parameters:
/// - segments: An array of `TranscriptionSegment` objects representing the transcribed segments
public typealias SegmentDiscoveryCallback = (_ segments: [TranscriptionSegment]) -> Void

/// A callback that reports changes in the model's state.
/// - Parameters:
/// - oldState: The previous state of the model, if any
/// - newState: The current state of the model
public typealias ModelStateCallback = (_ oldState: ModelState?, _ newState: ModelState) -> Void

/// A callback that reports changes in the transcription process.
/// - Parameter state: The current `TranscriptionState` of the transcription process
public typealias TranscriptionStateCallback = (_ state: TranscriptionState) -> Void

/// Represents the different states of the transcription process.
public enum TranscriptionState: CustomStringConvertible {
/// The audio is being converted to the required format for transcription
case convertingAudio

/// The audio is actively being transcribed to text
case transcribing

/// The transcription process has completed
case finished

/// A human-readable description of the transcription state
public var description: String {
switch self {
case .convertingAudio:
return "Converting Audio"
case .transcribing:
return "Transcribing"
case .finished:
return "Finished"
}
}
}

/// Callback to receive progress updates during transcription.
///
/// - Parameters:
4 changes: 4 additions & 0 deletions Sources/WhisperKit/Core/TranscribeTask.swift
@@ -15,6 +15,8 @@ final class TranscribeTask {
private let textDecoder: any TextDecoding
private let tokenizer: any WhisperTokenizer

public var segmentDiscoveryCallback: SegmentDiscoveryCallback?

init(
currentTimings: TranscriptionTimings,
progress: Progress?,
@@ -230,6 +232,8 @@
}
}

segmentDiscoveryCallback?(currentSegments)

// add them to the `allSegments` list
allSegments.append(contentsOf: currentSegments)
let allCurrentTokens = currentSegments.flatMap { $0.tokens }
35 changes: 31 additions & 4 deletions Sources/WhisperKit/Core/WhisperKit.swift
@@ -13,7 +13,12 @@ import Tokenizers
open class WhisperKit {
/// Models
public private(set) var modelVariant: ModelVariant = .tiny
public private(set) var modelState: ModelState = .unloaded
public private(set) var modelState: ModelState = .unloaded {
didSet {
modelStateCallback?(oldValue, modelState)
}
}

public var modelCompute: ModelComputeOptions
public var tokenizer: WhisperTokenizer?

@@ -42,6 +47,11 @@ open class WhisperKit {
public var tokenizerFolder: URL?
public private(set) var useBackgroundDownloadSession: Bool

/// Callbacks
public var segmentDiscoveryCallback: SegmentDiscoveryCallback?
public var modelStateCallback: ModelStateCallback?
public var transcriptionStateCallback: TranscriptionStateCallback?

public init(_ config: WhisperKitConfig = WhisperKitConfig()) async throws {
modelCompute = config.computeOptions ?? ModelComputeOptions()
audioProcessor = config.audioProcessor ?? AudioProcessor()
@@ -365,7 +375,7 @@ open class WhisperKit {
} else {
currentTimings.decoderLoadTime = CFAbsoluteTimeGetCurrent() - decoderLoadStart
}

Logging.debug("Loaded text decoder in \(String(format: "%.2f", currentTimings.decoderLoadTime))s")
}

@@ -378,13 +388,13 @@
computeUnits: modelCompute.audioEncoderCompute,
prewarmMode: prewarmMode
)

if prewarmMode {
currentTimings.encoderSpecializationTime = CFAbsoluteTimeGetCurrent() - encoderLoadStart
} else {
currentTimings.encoderLoadTime = CFAbsoluteTimeGetCurrent() - encoderLoadStart
}

Logging.debug("Loaded audio encoder in \(String(format: "%.2f", currentTimings.encoderLoadTime))s")
}

@@ -549,6 +559,8 @@ open class WhisperKit {
decodeOptions: DecodingOptions? = nil,
callback: TranscriptionCallback = nil
) async -> [Result<[TranscriptionResult], Swift.Error>] {
transcriptionStateCallback?(.convertingAudio)

// Start timing the audio loading and conversion process
let loadAudioStart = Date()

@@ -561,6 +573,11 @@
currentTimings.audioLoading = loadAndConvertTime
Logging.debug("Total Audio Loading and Converting Time: \(loadAndConvertTime)")

transcriptionStateCallback?(.transcribing)
defer {
transcriptionStateCallback?(.finished)
}

// Transcribe the loaded audio arrays
let transcribeResults = await transcribeWithResults(
audioArrays: audioArrays,
@@ -733,6 +750,8 @@ open class WhisperKit {
decodeOptions: DecodingOptions? = nil,
callback: TranscriptionCallback = nil
) async throws -> [TranscriptionResult] {
transcriptionStateCallback?(.convertingAudio)

// Process input audio file into audio samples
let audioArray = try await withThrowingTaskGroup(of: [Float].self) { group -> [Float] in
let convertAudioStart = Date()
@@ -746,6 +765,12 @@
return try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
}

transcriptionStateCallback?(.transcribing)
defer {
transcriptionStateCallback?(.finished)
}

// Send converted samples to be transcribed
let transcribeResults: [TranscriptionResult] = try await transcribe(
audioArray: audioArray,
decodeOptions: decodeOptions,
@@ -872,6 +897,8 @@ open class WhisperKit {
tokenizer: tokenizer
)

transcribeTask.segmentDiscoveryCallback = self.segmentDiscoveryCallback

let transcribeTaskResult = try await transcribeTask.run(
audioArray: audioArray,
decodeOptions: decodeOptions,
37 changes: 37 additions & 0 deletions Tests/WhisperKitTests/UnitTests.swift
@@ -1067,6 +1067,43 @@ final class UnitTests: XCTestCase {
XCTAssertEqual(result.segments.first?.text, " and so my fellow americans ask not what your country can do for you ask what you can do for your country.")
}

func testCallbacks() async throws {
let config = try WhisperKitConfig(
modelFolder: tinyModelPath(),
verbose: true,
logLevel: .debug,
load: false
)
let whisperKit = try await WhisperKit(config)
let modelStateExpectation = XCTestExpectation(description: "Model state callback expectation")
whisperKit.modelStateCallback = { (oldState: ModelState?, newState: ModelState) in
Logging.debug("Model state: \(newState)")
modelStateExpectation.fulfill()
}

let segmentDiscoveryExpectation = XCTestExpectation(description: "Segment discovery callback expectation")
whisperKit.segmentDiscoveryCallback = { (segments: [TranscriptionSegment]) in
Logging.debug("Segments discovered: \(segments)")
segmentDiscoveryExpectation.fulfill()
}

let transcriptionStateExpectation = XCTestExpectation(description: "Transcription state callback expectation")
whisperKit.transcriptionStateCallback = { (state: TranscriptionState) in
Logging.debug("Transcription state: \(state)")
transcriptionStateExpectation.fulfill()
}

// Run the full pipeline
try await whisperKit.loadModels()
let audioFilePath = try XCTUnwrap(
Bundle.current.path(forResource: "jfk", ofType: "wav"),
"Audio file not found"
)
let _ = try await whisperKit.transcribe(audioPath: audioFilePath)

await fulfillment(of: [modelStateExpectation, segmentDiscoveryExpectation, transcriptionStateExpectation], timeout: 1)
}

// MARK: - Utils Tests

func testFillIndexesWithValue() throws {