-
Notifications
You must be signed in to change notification settings - Fork 357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add model support config fetching from model repo #216
Conversation
- New github runner image does not include visionOS, so to prevent downloading for all platforms this will specify the platform from the test matrix
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome! 🚀
@@ -956,8 +956,7 @@ struct ContentView: View { | |||
|
|||
localModels = WhisperKit.formatModelFiles(localModels) | |||
for model in localModels { | |||
if !availableModels.contains(model), | |||
!disabledModels.contains(model) | |||
if !availableModels.contains(model) | |||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cuddle curly braces here and LOC 972-981
or consider replacing the for loops with filter:
availableModels.append(contentsOf: localModels.filter { !availableModels.contains($0) })
let currentFrameLength = startIndex - seekClipStart | ||
if startIndex >= currentFrameLength, startIndex < 0 { | ||
let currentFrameLength = audioArray.count | ||
if startIndex < 0 || startIndex >= currentFrameLength { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: don't need currentFrameLength
and use guard startIndex >= 0 && startIndex < audioArray.count else {
@@ -95,7 +95,7 @@ public extension AudioProcessing { | |||
static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? { | |||
let currentFrameLength = audioArray.count | |||
|
|||
if startIndex >= currentFrameLength, startIndex < 0 { | |||
if startIndex < 0 || startIndex >= currentFrameLength { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
return support.models | ||
} | ||
} | ||
return defaultSupport.models |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add log that defaultSupport was used
Sources/WhisperKit/Core/Models.swift
Outdated
public struct ModelSupportConfig: Codable { | ||
public let repoName: String | ||
public let repoVersion: String | ||
public var deviceSupport: [DeviceSupport] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: plural deviceSupports
since it stores an array of DeviceSupport
Sources/WhisperKit/Core/Models.swift
Outdated
} | ||
|
||
// Add fallback device support that don't overlap with remote | ||
for fallbackSupport in fallback { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: use filter
mergedSupports.append(contentsOf: fallback.filter { !$0.identifiers.contains(where: remoteIdentifiers.contains) })
public static let fallbackModelSupportConfig: ModelSupportConfig = { | ||
var config = ModelSupportConfig( | ||
repoName: "whisperkit-coreml", | ||
repoVersion: "0.2", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would we need to identify that it's a fallback config? maybe append -local to version here?
modelSupportConfig = try decoder.decode(ModelSupportConfig.self, from: jsonData) | ||
} catch { | ||
// Allow this to fail gracefully as it uses fallback config by default | ||
Logging.debug(error) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
log as error and add some text that this error happened when fetching model support config
This PR adds capability to fetch supported models for a specific device based on the config.json file in the model repo on HF. Config will update over time as new models and devices come out. Need for this arose with the latest whisper turbo model release.
Changes:
WhisperKit.recommendedRemoteModels()
to fetch the config.json file and return aModelSupport
objectNew optional parameters for synchronousThis proved difficult to conform to swift concurrency as a static function for macOS 14.6 and below, so removed for now.WhisperKit.recommendedModels()
function to try to fetch the remote config, and a timeout for the requestConstants.fallbackModelSupportConfig
for when remote config is unavailable in offline environmentsWhisperKit.deviceName()
to use a similar device identifier for mac as it does for iOS etc