Add model support config fetching from model repo (#216)
* Add model support config fetching from model repo

* Fix audio start index error handling

Co-authored-by: 1amageek <[email protected]>

* Formatting

* Fix CI + watchOS build

- New GitHub runner image does not include visionOS, so to avoid downloading every platform, this specifies the platform from the test matrix

* Fix typo

* Use dispatch group for sync recommendedModels

* Remove sync remote model fetching

* Formatting and cleanup from review

---------

Co-authored-by: 1amageek <[email protected]>
ZachNagengast and 1amageek authored Oct 8, 2024
1 parent c2f1b57 commit bfb1316
Showing 14 changed files with 676 additions and 128 deletions.
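
At a glance, the client-facing change: instead of pulling a flat model list with `WhisperKit.fetchAvailableModels(from:)`, callers now ask for the repo's model support config and receive per-device supported/disabled model lists. A minimal sketch of the new flow, based only on the API visible in this diff (the offline fallback behavior is an assumption inferred from `Constants.fallbackModelSupportConfig` below, and the surrounding app code is hypothetical):

import WhisperKit

// Sketch of the new flow. `recommendedRemoteModels()` fetches the support
// config from the model repo and resolves it for the current device;
// presumably it falls back to the bundled config when offline (assumption).
Task {
    let support = await WhisperKit.recommendedRemoteModels()
    print("Default model: \(support.default)")
    print("Supported here: \(support.supported)")
    print("Disabled here: \(support.disabled)") // computed against all known models
}

The ContentView change below follows the same shape, hopping to the main actor before mutating UI-observed state.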
4 changes: 3 additions & 1 deletion .github/workflows/unit-tests.yml
@@ -60,7 +60,9 @@ jobs:
          run: make download-model MODEL=tiny
      - name: Install and discover destinations
        run: |
-         xcodebuild -downloadAllPlatforms
+         if [[ "${{ matrix.run-config['name'] }}" != "macOS" ]]; then
+           xcodebuild -downloadPlatform ${{ matrix.run-config['name'] }}
+         fi
          echo "Destinations for testing:"
          xcodebuild test-without-building -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -showdestinations
      - name: Boot Simulator and Wait
@@ -33,8 +33,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers.git",
"state" : {
"revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe",
"version" : "0.1.7"
"revision" : "fc6543263e4caed9bf6107466d625cfae9357f08",
"version" : "0.1.8"
}
}
],
25 changes: 13 additions & 12 deletions Examples/WhisperAX/WhisperAX/Views/ContentView.swift
@@ -110,7 +110,6 @@ struct ContentView: View {
MenuItem(name: "Stream", image: "waveform.badge.mic"),
]


private var isStreamMode: Bool {
self.selectedCategoryId == menu.first(where: { $0.name == "Stream" })?.id
}
@@ -202,7 +201,7 @@ struct ContentView: View {
            .toolbar(content: {
                ToolbarItem {
                    Button {
-                       if (!enableEagerDecoding) {
+                       if !enableEagerDecoding {
                            let fullTranscript = formatSegments(confirmedSegments + unconfirmedSegments, withTimestamps: enableTimestamps).joined(separator: "\n")
                            #if os(iOS)
                            UIPasteboard.general.string = fullTranscript
@@ -956,9 +955,7 @@ struct ContentView: View {

        localModels = WhisperKit.formatModelFiles(localModels)
        for model in localModels {
-           if !availableModels.contains(model),
-              !disabledModels.contains(model)
-           {
+           if !availableModels.contains(model) {
                availableModels.append(model)
            }
        }
@@ -967,12 +964,17 @@ struct ContentView: View {
print("Previously selected model: \(selectedModel)")

Task {
let remoteModels = try await WhisperKit.fetchAvailableModels(from: repoName)
for model in remoteModels {
if !availableModels.contains(model),
!disabledModels.contains(model)
{
availableModels.append(model)
let remoteModelSupport = await WhisperKit.recommendedRemoteModels()
await MainActor.run {
for model in remoteModelSupport.supported {
if !availableModels.contains(model) {
availableModels.append(model)
}
}
for model in remoteModelSupport.disabled {
if !disabledModels.contains(model) {
disabledModels.append(model)
}
}
}
}
@@ -1644,7 +1646,6 @@ struct ContentView: View {
            finalizeText()
        }

-
        let mergedResult = mergeTranscriptionResults(eagerResults, confirmedWords: confirmedWords)

        return mergedResult
4 changes: 2 additions & 2 deletions Package.resolved
@@ -14,8 +14,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers.git",
"state" : {
"revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe",
"version" : "0.1.7"
"revision" : "fc6543263e4caed9bf6107466d625cfae9357f08",
"version" : "0.1.8"
}
}
],
2 changes: 1 addition & 1 deletion Package.swift
@@ -20,7 +20,7 @@ let package = Package(
        ),
    ],
    dependencies: [
-       .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.7"),
+       .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.8"),
        .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"),
    ],
    targets: [
4 changes: 2 additions & 2 deletions Sources/WhisperKit/Core/Audio/AudioChunker.swift
@@ -81,8 +81,8 @@ open class VADAudioChunker: AudioChunking {
            // Typically this will be the full audio file, unless seek points are explicitly provided
            var startIndex = seekClipStart
            while startIndex < seekClipEnd - windowPadding {
-               let currentFrameLength = startIndex - seekClipStart
-               if startIndex >= currentFrameLength, startIndex < 0 {
+               let currentFrameLength = audioArray.count
+               guard startIndex >= 0 && startIndex < audioArray.count else {
                    throw WhisperError.audioProcessingFailed("startIndex is outside the buffer size")
                }

2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/Audio/AudioProcessor.swift
@@ -95,7 +95,7 @@ public extension AudioProcessing {
    static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? {
        let currentFrameLength = audioArray.count

-       if startIndex >= currentFrameLength, startIndex < 0 {
+       guard startIndex >= 0 && startIndex < audioArray.count else {
            Logging.error("startIndex is outside the buffer size")
            return nil
        }
5 changes: 2 additions & 3 deletions Sources/WhisperKit/Core/Configurations.swift
@@ -60,8 +60,8 @@ open class WhisperKitConfig {
        prewarm: Bool? = nil,
        load: Bool? = nil,
        download: Bool = true,
-       useBackgroundDownloadSession: Bool = false
-   ) {
+       useBackgroundDownloadSession: Bool = false)
+   {
        self.model = model
        self.downloadBase = downloadBase
        self.modelRepo = modelRepo
@@ -83,7 +83,6 @@ open class WhisperKitConfig {
        }
    }

-
/// Options for how to transcribe an audio file using WhisperKit.
///
/// - Parameters:
245 changes: 245 additions & 0 deletions Sources/WhisperKit/Core/Models.swift
@@ -165,6 +165,114 @@ public struct ModelComputeOptions {
    }
}

public struct ModelSupport: Codable, Equatable {
    public let `default`: String
    public let supported: [String]
    /// Computed on init of ModelSupportConfig
    public var disabled: [String] = []

    private enum CodingKeys: String, CodingKey {
        case `default`, supported
    }
}

public struct DeviceSupport: Codable {
    public let identifiers: [String]
    public var models: ModelSupport
}

public struct ModelSupportConfig: Codable {
    public let repoName: String
    public let repoVersion: String
    public var deviceSupports: [DeviceSupport]
    /// Computed on init
    public private(set) var knownModels: [String]
    public private(set) var defaultSupport: DeviceSupport

    enum CodingKeys: String, CodingKey {
        case repoName = "name"
        case repoVersion = "version"
        case deviceSupports = "device_support"
    }

    public init(from decoder: Swift.Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        let repoName = try container.decode(String.self, forKey: .repoName)
        let repoVersion = try container.decode(String.self, forKey: .repoVersion)
        let deviceSupports = try container.decode([DeviceSupport].self, forKey: .deviceSupports)

        self.init(repoName: repoName, repoVersion: repoVersion, deviceSupports: deviceSupports)
    }

    public init(repoName: String, repoVersion: String, deviceSupports: [DeviceSupport], includeFallback: Bool = true) {
        self.repoName = repoName
        self.repoVersion = repoVersion

        if includeFallback {
            self.deviceSupports = Self.mergeDeviceSupport(remote: deviceSupports, fallback: Constants.fallbackModelSupportConfig.deviceSupports)
            self.knownModels = self.deviceSupports.flatMap { $0.models.supported }.orderedSet
        } else {
            self.deviceSupports = deviceSupports
            self.knownModels = deviceSupports.flatMap { $0.models.supported }.orderedSet
        }

        // Add default device support with all models supported for unknown devices
        self.defaultSupport = DeviceSupport(
            identifiers: [],
            models: ModelSupport(
                default: "openai_whisper-base",
                supported: self.knownModels
            )
        )

        computeDisabledModels()
    }

    @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
    public func modelSupport(for deviceIdentifier: String = WhisperKit.deviceName()) -> ModelSupport {
        for support in deviceSupports {
            if support.identifiers.contains(where: { deviceIdentifier.hasPrefix($0) }) {
                return support.models
            }
        }

        Logging.info("No device support found for \(deviceIdentifier), using default")
        return defaultSupport.models
    }

    private mutating func computeDisabledModels() {
        for i in 0..<deviceSupports.count {
            let disabledModels = Set(knownModels).subtracting(deviceSupports[i].models.supported)
            self.deviceSupports[i].models.disabled = Array(disabledModels)
        }
    }

    private static func mergeDeviceSupport(remote: [DeviceSupport], fallback: [DeviceSupport]) -> [DeviceSupport] {
        var mergedSupports: [DeviceSupport] = []
        let remoteIdentifiers = Set(remote.flatMap { $0.identifiers })

        // Add remote device supports, merging with fallback if identifiers overlap
        for remoteSupport in remote {
            if let fallbackSupport = fallback.first(where: { $0.identifiers.contains(where: remoteSupport.identifiers.contains) }) {
                let mergedModels = ModelSupport(
                    default: remoteSupport.models.default,
                    supported: (remoteSupport.models.supported + fallbackSupport.models.supported).orderedSet
                )
                mergedSupports.append(DeviceSupport(identifiers: remoteSupport.identifiers, models: mergedModels))
            } else {
                mergedSupports.append(remoteSupport)
            }
        }

        // Add fallback device supports that don't overlap with remote
        for fallbackSupport in fallback where !fallbackSupport.identifiers.contains(where: remoteIdentifiers.contains) {
            mergedSupports.append(fallbackSupport)
        }

        return mergedSupports
    }
}
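
To make the decoding side concrete: per the CodingKeys above, the remote config is JSON with top-level name, version, and device_support keys. A hedged sketch of decoding such a payload (the literal below is invented for illustration; the actual file name and contents in the model repo are not shown in this diff):

import Foundation

// Hypothetical config JSON shaped per the CodingKeys above.
let json = Data("""
{
  "name": "example/whisperkit-coreml",
  "version": "0.3",
  "device_support": [
    {
      "identifiers": ["iPhone16"],
      "models": {
        "default": "openai_whisper-base",
        "supported": ["openai_whisper-tiny", "openai_whisper-base"]
      }
    }
  ]
}
""".utf8)

let config = try JSONDecoder().decode(ModelSupportConfig.self, from: json)
if #available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) {
    // Known device prefixes resolve to their ModelSupport; unknown devices
    // get defaultSupport, which allows every known model.
    let support = config.modelSupport(for: "iPhone16,2")
    print(support.default, support.supported, support.disabled)
}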

// MARK: - Chunking

public struct AudioChunk {
@@ -1346,4 +1454,141 @@ public enum Constants {
    public static let defaultLanguageCode: String = "en"

    public static let defaultAudioReadFrameSize: AVAudioFrameCount = 1_323_000 // 30s of audio at commonly found 44.1kHz sample rate

    public static let fallbackModelSupportConfig: ModelSupportConfig = {
        var config = ModelSupportConfig(
            repoName: "whisperkit-coreml-fallback",
            repoVersion: "0.2",
            deviceSupports: [
                DeviceSupport(
                    identifiers: ["iPhone11", "iPhone12", "Watch7", "Watch8"],
                    models: ModelSupport(
                        default: "openai_whisper-tiny",
                        supported: [
                            "openai_whisper-base",
                            "openai_whisper-base.en",
                            "openai_whisper-tiny",
                            "openai_whisper-tiny.en",
                        ]
                    )
                ),
                DeviceSupport(
                    identifiers: ["iPhone13", "iPad13,18", "iPad13,1"],
                    models: ModelSupport(
                        default: "openai_whisper-base",
                        supported: [
                            "openai_whisper-tiny",
                            "openai_whisper-tiny.en",
                            "openai_whisper-base",
                            "openai_whisper-base.en",
                            "openai_whisper-small",
                            "openai_whisper-small.en",
                        ]
                    )
                ),
                DeviceSupport(
                    identifiers: ["iPhone14", "iPhone15", "iPhone16", "iPhone17", "iPad14,1", "iPad14,2"],
                    models: ModelSupport(
                        default: "openai_whisper-base",
                        supported: [
                            "openai_whisper-tiny",
                            "openai_whisper-tiny.en",
                            "openai_whisper-base",
                            "openai_whisper-base.en",
                            "openai_whisper-small",
                            "openai_whisper-small.en",
                            "openai_whisper-large-v2_949MB",
                            "openai_whisper-large-v2_turbo_955MB",
                            "openai_whisper-large-v3_947MB",
                            "openai_whisper-large-v3_turbo_954MB",
                            "distil-whisper_distil-large-v3_594MB",
                            "distil-whisper_distil-large-v3_turbo_600MB",
                            "openai_whisper-large-v3-v20240930_626MB",
                            "openai_whisper-large-v3-v20240930_turbo_632MB",
                        ]
                    )
                ),
                DeviceSupport(
                    identifiers: [
                        "Mac13",
                        "iMac21",
                        "MacBookAir10,1",
                        "MacBookPro17",
                        "MacBookPro18",
                        "Macmini9",
                        "iPad13,16",
                        "iPad13,4",
                        "iPad13,8",
                    ],
                    models: ModelSupport(
                        default: "openai_whisper-large-v3-v20240930",
                        supported: [
                            "openai_whisper-tiny",
                            "openai_whisper-tiny.en",
                            "openai_whisper-base",
                            "openai_whisper-base.en",
                            "openai_whisper-small",
                            "openai_whisper-small.en",
                            "openai_whisper-large-v2",
                            "openai_whisper-large-v2_949MB",
                            "openai_whisper-large-v3",
                            "openai_whisper-large-v3_947MB",
                            "distil-whisper_distil-large-v3",
                            "distil-whisper_distil-large-v3_594MB",
                            "openai_whisper-large-v3-v20240930",
                            "openai_whisper-large-v3-v20240930_626MB",
                        ]
                    )
                ),
                DeviceSupport(
                    identifiers: [
                        "Mac14",
                        "Mac15",
                        "Mac16",
                        "iPad14,3",
                        "iPad14,4",
                        "iPad14,5",
                        "iPad14,6",
                        "iPad14,8",
                        "iPad14,9",
                        "iPad14,10",
                        "iPad14,11",
                        "iPad16",
                    ],
                    models: ModelSupport(
                        default: "openai_whisper-large-v3-v20240930",
                        supported: [
                            "openai_whisper-tiny",
                            "openai_whisper-tiny.en",
                            "openai_whisper-base",
                            "openai_whisper-base.en",
                            "openai_whisper-small",
                            "openai_whisper-small.en",
                            "openai_whisper-large-v2",
                            "openai_whisper-large-v2_949MB",
                            "openai_whisper-large-v2_turbo",
                            "openai_whisper-large-v2_turbo_955MB",
                            "openai_whisper-large-v3",
                            "openai_whisper-large-v3_947MB",
                            "openai_whisper-large-v3_turbo",
                            "openai_whisper-large-v3_turbo_954MB",
                            "distil-whisper_distil-large-v3",
                            "distil-whisper_distil-large-v3_594MB",
                            "distil-whisper_distil-large-v3_turbo",
                            "distil-whisper_distil-large-v3_turbo_600MB",
                            "openai_whisper-large-v3-v20240930",
                            "openai_whisper-large-v3-v20240930_turbo",
                            "openai_whisper-large-v3-v20240930_626MB",
                            "openai_whisper-large-v3-v20240930_turbo_632MB",
                        ]
                    )
                ),
            ],
            includeFallback: false
        )

        return config
    }()

    public static let knownModels: [String] = fallbackModelSupportConfig.deviceSupports.flatMap { $0.models.supported }.orderedSet
}