-
Notifications
You must be signed in to change notification settings - Fork 357
/
Copy pathRegressionTests.swift
721 lines (622 loc) · 29.2 KB
/
RegressionTests.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
// For licensing see accompanying LICENSE.md file.
// Copyright © 2024 Argmax, Inc. All rights reserved.
import CoreML
import Foundation
import Hub
import UniformTypeIdentifiers
import WhisperKit
import XCTest
#if os(watchOS)
import WatchKit
#endif
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
class RegressionTests: XCTestCase {
var audioFileURLs: [URL]?
var remoteFileURLs: [URL]?
var metadataURL: URL?
var testWERURLs: [URL]?
var modelsToTest: [String] = []
var modelReposToTest: [String] = []
var modelsTested: [String] = []
var modelReposTested: [String] = []
var optionsToTest: [DecodingOptions] = [DecodingOptions()]
struct TestConfig {
let dataset: String
let modelComputeOptions: ModelComputeOptions
var model: String
var modelRepo: String
let decodingOptions: DecodingOptions
}
/// Located on HF https://huggingface.co/datasets/argmaxinc/whisperkit-test-data/tree/main
let datasetRepo = "argmaxinc/whisperkit-test-data"
var datasets = ["librispeech-10mins", "earnings22-10mins"]
let debugDataset = ["earnings22-10mins"]
let debugModels = ["tiny"]
let debugRepos = ["argmaxinc/whisperkit-coreml"]
var computeOptions: [ModelComputeOptions] = [
ModelComputeOptions(audioEncoderCompute: .cpuAndNeuralEngine, textDecoderCompute: .cpuAndNeuralEngine),
]
let defaultDecodingOptions = DecodingOptions(
verbose: true,
task: .transcribe
)
let vadDecodingOptions = DecodingOptions(
verbose: true,
task: .transcribe,
concurrentWorkerCount: 16,
chunkingStrategy: .vad
)
override func setUp() {
super.setUp()
#if canImport(UIApplication)
NotificationCenter.default.addObserver(
self,
selector: #selector(didReceiveMemoryWarning),
name: UIApplication.didReceiveMemoryWarningNotification,
object: nil
)
#endif
}
@objc func didReceiveMemoryWarning() {
Logging.debug("Received memory warning")
// TODO: Record this data in the test results
let maxMemory = SystemMemoryCheckerAdvanced.getMemoryUsage()
Logging.debug("Max memory before warning: \(maxMemory)")
}
class func getModelToken() -> String? {
// Add token here or override
return nil
}
func testEnvConfigurations(defaultModels: [String]? = nil, defaultRepos: [String]? = nil) {
if let modelSizeEnv = ProcessInfo.processInfo.environment["MODEL_NAME"], !modelSizeEnv.isEmpty {
modelsToTest = [modelSizeEnv]
Logging.debug("Model size: \(modelSizeEnv)")
if let repoEnv = ProcessInfo.processInfo.environment["MODEL_REPO"] {
modelReposToTest = [repoEnv]
Logging.debug("Using repo: \(repoEnv)")
}
XCTAssertTrue(modelsToTest.count > 0, "Invalid model size: \(modelSizeEnv)")
if modelSizeEnv == "crash_test" {
fatalError("Crash test triggered")
}
} else {
modelsToTest = defaultModels ?? debugModels
modelReposToTest = defaultRepos ?? debugRepos
Logging.debug("Model size not set by env")
}
}
func testModelPerformanceWithDebugConfig() async throws {
testEnvConfigurations()
// Debug test matrix
datasets = debugDataset
optionsToTest = [vadDecodingOptions]
computeOptions = [computeOptions.first!]
let debugTestMatrix: [TestConfig] = getTestMatrix()
Logging.debug("Running \(debugTestMatrix.count) regression tests for models: \(modelsToTest)")
// Run the tests
try await runRegressionTests(with: debugTestMatrix)
}
func testModelPerformance() async throws {
testEnvConfigurations(defaultModels: WhisperKit.recommendedModels().supported)
// Setup test matrix
optionsToTest = [vadDecodingOptions]
computeOptions = [computeOptions.first!]
let testMatrix: [TestConfig] = getTestMatrix()
Logging.debug("Running \(testMatrix.count) regression tests for models: \(modelsToTest)")
// Run the tests
try await runRegressionTests(with: testMatrix)
}
// MARK: - Test Pipeline
public func runRegressionTests(with testMatrix: [TestConfig]) async throws {
var failureInfo: [String: String] = [:]
var attachments: [String: String] = [:]
let device = getCurrentDevice()
for (i, config) in testMatrix.enumerated() {
do {
Logging.debug("Running test \(i + 1)/\(testMatrix.count) for \(config.model) with \(config.dataset) on \(device) using encoder compute: \(config.modelComputeOptions.audioEncoderCompute.description) and decoder compute: \(config.modelComputeOptions.textDecoderCompute.description)")
let expectation = XCTestExpectation(description: "Download test audio files for \(config.dataset) dataset")
downloadTestData(forDataset: config.dataset) { success in
if success {
expectation.fulfill()
} else {
XCTFail("Downloading audio file for testing failed")
}
}
await fulfillment(of: [expectation], timeout: 300)
let attachmentName = try await testAndMeasureModelPerformance(config: config, device: device)
attachments[config.dataset] = attachmentName
try await Task.sleep(nanoseconds: 1_000_000_000)
} catch {
Logging.debug("Failed to test \(config.model): \(error)")
failureInfo[config.model] = error.localizedDescription
}
}
// Save summary
saveSummary(failureInfo: failureInfo, attachments: attachments)
}
func testAndMeasureModelPerformance(config: TestConfig, device: String) async throws -> String? {
var config = config
var resultJSON: [RegressionStats] = []
let audioFilePaths = try XCTUnwrap(
self.audioFileURLs,
"Audio files not found"
).map { $0.path() }
if WhisperKit.recommendedModels().disabled.contains(where: { $0.range(of: config.model) != nil }) {
throw WhisperError.modelsUnavailable("Skipping model \(config.model), disabled for \(device).")
}
// Create WhisperKit instance with checks for memory usage
let whisperKit = try await createWithMemoryCheck(
testConfig: config,
verbose: true,
logLevel: .debug
)
if let modelFile = whisperKit.modelFolder?.lastPathComponent {
config.model = modelFile
modelsTested.append(modelFile)
modelsTested = Array(Set(modelsTested))
modelReposTested.append(config.modelRepo)
modelReposTested = Array(Set(modelReposTested))
}
for audioFilePath in audioFilePaths {
// Process each audio file
try await processAudioFile(
audioFilePath: audioFilePath,
whisperKit: whisperKit,
config: config,
device: device,
resultJSON: &resultJSON
)
}
do {
let jsonData = try JSONEncoder().encode(resultJSON)
let attachment = XCTAttachment(data: jsonData, uniformTypeIdentifier: UTType.json.identifier)
let attachmentName = "\(device)_\(config.model)_\(Date().formatted(Date.ISO8601FormatStyle().dateSeparator(.dash).timeSeparator(.omitted)))_\(config.dataset)".replacingOccurrences(of: ".", with: "_")
attachment.name = attachmentName + ".json"
attachment.lifetime = .keepAlways
add(attachment)
return attachmentName
} catch {
XCTFail("Failed with error: \(error)")
return nil
}
}
func processAudioFile(
audioFilePath: String,
whisperKit: WhisperKit,
config: TestConfig,
device: String,
resultJSON: inout [RegressionStats]
) async throws {
let startTime = Date()
// Initialize test state
var memoryStats = MemoryStats(
measurements: [], units: "MB",
totalNumberOfMeasurements: 0,
preTranscribeMemory: -1,
postTranscribeMemory: -1
)
var latencyStats = LatencyStats(
measurements: [], units: "Tokens/Sec",
totalNumberOfMeasurements: 0
)
let startTimeStamp = CFAbsoluteTimeGetCurrent()
let testState = TranscriptionTestState(
startTime: startTime,
startTimeStamp: startTimeStamp,
memoryStats: memoryStats,
latencyStats: latencyStats
)
let callback = { (result: TranscriptionProgress) -> Bool in
Task {
await testState.update(with: result)
}
return true
}
memoryStats.preTranscribeMemory = Float(AppMemoryChecker.getMemoryUsed())
var systemMemory: [SystemMemoryUsage] = []
var diskSpace: [DiskSpace] = []
var batteryLevel: [Float] = []
var thermalState: [Int] = []
var timerTimeElapsed: [TimeInterval] = []
// Start your timer
let timerQueue = DispatchQueue(label: "RegressionTimerQueue")
let timer = DispatchSource.makeTimerSource(queue: timerQueue)
timer.schedule(deadline: .now(), repeating: 1.0)
timer.setEventHandler {
systemMemory.append(SystemMemoryCheckerAdvanced.getMemoryUsage())
diskSpace.append(DiskSpaceChecker.getDiskSpace())
batteryLevel.append(BatteryLevelChecker.getBatteryLevel() ?? -1)
thermalState.append(ThermalStateChecker.getThermalState())
timerTimeElapsed.append(Date().timeIntervalSince(startTime))
}
timer.resume()
// Perform transcription
let transcriptionResults = try await whisperKit.transcribe(
audioPath: audioFilePath,
decodeOptions: config.decodingOptions,
callback: callback
)
let tpsThreshold = 4.0
let currentTPS = await testState.getCurrentTPS()
if !(currentTPS != 0 && currentTPS > tpsThreshold) {
XCTFail("Tokens per second below expected for compute unit \(currentTPS), potential CPU fallback")
}
let transcriptionResult = mergeTranscriptionResults(transcriptionResults)
// Store final measurements
let (finalMemoryStats, finalLatencyStats) = await testState.processFinalMeasurements()
memoryStats = finalMemoryStats
latencyStats = finalLatencyStats
memoryStats.postTranscribeMemory = Float(AppMemoryChecker.getMemoryUsed())
let filename = String(audioFilePath.split(separator: "/").last!)
guard let reference = getTranscript(filename: filename) else {
Logging.debug("Reference transcript not found for \(filename)")
return
}
let (wer, diff) = WERUtils.evaluate(
originalTranscript: reference,
generatedTranscript: transcriptionResult.text
)
let modelSizeMB = try? getFolderSize(atUrl: whisperKit.modelFolder)
let testInfo = TestInfo(
device: device,
audioFile: URL(fileURLWithPath: audioFilePath).lastPathComponent,
datasetDir: config.dataset,
datasetRepo: datasetRepo,
model: config.model,
modelRepo: config.modelRepo,
modelSizeMB: modelSizeMB ?? -1,
date: startTime.formatted(Date.ISO8601FormatStyle().dateSeparator(.dash)),
timeElapsedInSeconds: Date().timeIntervalSince(startTime),
timings: transcriptionResult.timings,
prediction: transcriptionResult.text,
reference: reference,
wer: wer,
diff: diff
)
let staticAttributes = StaticAttributes(
encoderCompute: whisperKit.modelCompute.audioEncoderCompute,
decoderCompute: whisperKit.modelCompute.textDecoderCompute,
decodingOptions: config.decodingOptions
)
let systemMeasurements = SystemMeasurements(
systemMemory: systemMemory,
diskSpace: diskSpace,
batteryLevel: batteryLevel,
thermalState: thermalState,
timeElapsed: timerTimeElapsed
)
let json = RegressionStats(
testInfo: testInfo,
memoryStats: memoryStats,
latencyStats: latencyStats,
staticAttributes: staticAttributes,
systemMeasurements: systemMeasurements
)
resultJSON.append(json)
}
// MARK: - Pipeline Tests
func testLargeWER() async {
let texts = await getWERTestData()
let (simpleWER, simpleDiff) = WERUtils.evaluate(originalTranscript: "This is some basic text", generatedTranscript: "This is edited text with some words added replaced and deleted")
XCTAssertEqual(simpleWER, 1.7, accuracy: 0.1, "Expected wer: 1.7 but computed \(simpleWER)")
XCTAssertEqual(simpleDiff.count, 23)
if let originalText = texts.0, let generatedText = texts.1 {
let (werNormalized, _) = WERUtils.evaluate(originalTranscript: originalText, generatedTranscript: generatedText)
XCTAssertEqual(werNormalized, 0.1852080123266564, accuracy: 0.001, "Expected wer: 0.1852080123266564 but computed \(werNormalized)")
let (wer, _) = WERUtils.evaluate(originalTranscript: originalText, generatedTranscript: generatedText, normalizeOriginal: false)
XCTAssertEqual(wer, 0.42448103078024335, accuracy: 0.001, "Expected wer: 0.42448103078024335 but computed \(wer)")
} else {
XCTFail("Fetching WER test data failed.")
}
}
func testNormalizer() {
let normalizer = EnglishTextNormalizer()
let jsonText = "hello\\u2026 this is a test over GH\\u20b5 94 million in fees in H\\u00f8rsholm and Basel grew 10% to one billions, 370 millions"
let swiftText = "hello\u{2026} this is a test over GH\u{20b5} 94 million in fees in H\u{00f8}rsholm and Basel grew 10% to one billions, 370 millions"
let resultJson = normalizer.normalize(text: jsonText)
let resultSwift = normalizer.normalize(text: swiftText)
XCTAssertEqual(resultSwift, "hello . this is a test over gh 94000000 in fees in horsholm and basel grew 10% to 1000000000s 370000000s")
XCTAssertEqual(resultJson, resultSwift)
}
func testHirschberg() {
let s1 = "With a rumble that echoed through the night, thunder crashed overhead, its raw power shaking the earth beneath it, leaving in its wake an exhilarating sense of awe. As rain poured down in torrents, the thunder boomed with a rhythm that seemed to speak a secret language, intertwining nature's symphony with an innovative melody that captivated all who listened."
let s2 = "In the midst of a summer storm, thunder erupted with a booming chorus, shaking the earth beneath our feet and electrifying the air with its powerful presence. The crackling symphony of thunderbolts danced across the darkened sky, illuminating the clouds with an innovative display of nature's raw energy."
let ops = hirschberg(Array(s1.unicodeScalars), Array(s2.unicodeScalars))
XCTAssertEqual(ops.count, 228)
}
func testLevenshtein() {
let s1 = "With a rumble that echoed through the night, thunder crashed overhead, its raw power shaking the earth beneath it, leaving in its wake an exhilarating sense of awe. As rain poured down in torrents, the thunder boomed with a rhythm that seemed to speak a secret language, intertwining nature's symphony with an innovative melody that captivated all who listened."
let s2 = "In the midst of a summer storm, thunder erupted with a booming chorus, shaking the earth beneath our feet and electrifying the air with its powerful presence. The crackling symphony of thunderbolts danced across the darkened sky, illuminating the clouds with an innovative display of nature's raw energy."
let ops = levenshtein(Array(s1.unicodeScalars), Array(s2.unicodeScalars))
XCTAssertEqual(ops.count, 495)
}
func testInMemoryAndDiskUsage() async throws {
// Choose a model to test
let modelToTest = "openai_whisper-tiny"
// Get initial measurements
let initialMemory = AppMemoryChecker.getMemoryUsed()
let initialDiskSpace = DiskSpaceChecker.getDiskSpace()
let initialCacheSize = try getCacheSize()
// Create WhisperKit instance
let whisperKit = try await WhisperKit(WhisperKitConfig(
model: modelToTest,
computeOptions: ModelComputeOptions(audioEncoderCompute: .cpuAndNeuralEngine, textDecoderCompute: .cpuAndNeuralEngine),
verbose: true,
logLevel: .debug,
load: true
))
// Get final measurements
let finalMemory = AppMemoryChecker.getMemoryUsed()
let finalDiskSpace = DiskSpaceChecker.getDiskSpace()
let finalCacheSize = try getCacheSize()
// Calculate differences
let memoryUsed = finalMemory - initialMemory
let diskSpaceUsed = initialDiskSpace.freeSpaceGB! - finalDiskSpace.freeSpaceGB!
let cacheSpaceUsed = finalCacheSize - initialCacheSize
// Log results
Logging.debug("Memory used by \(modelToTest): \(memoryUsed) MB")
Logging.debug("Disk space used by \(modelToTest): \(diskSpaceUsed) MB")
Logging.debug("Cache space used by \(modelToTest): \(cacheSpaceUsed) MB")
// Assert that the measurements are within expected ranges
XCTAssertGreaterThan(memoryUsed, 0, "Model should use some memory")
XCTAssertLessThan(memoryUsed, 1000, "Model should use less than 1GB of memory")
XCTAssertGreaterThanOrEqual(diskSpaceUsed, 0, "Model should use some disk space unless already downloaded")
XCTAssertLessThan(diskSpaceUsed, 5000, "Model should use less than 5GB of disk space")
XCTAssertGreaterThanOrEqual(cacheSpaceUsed, 0, "Cache usage should not be negative")
// Clean up
await whisperKit.unloadModels()
}
// MARK: - Helper Methods
private func downloadTestDataIfNeeded() {
guard audioFileURLs == nil || metadataURL == nil || testWERURLs == nil else { return }
for dataset in datasets {
let expectation = XCTestExpectation(description: "Download test audio files for \(dataset) dataset")
downloadTestData(forDataset: dataset) { success in
if success {
expectation.fulfill()
} else {
XCTFail("Downloading audio file for testing failed")
}
}
wait(for: [expectation], timeout: 300)
}
}
public func getTestMatrix() -> [TestConfig] {
var regressionTestConfigMatrix: [TestConfig] = []
for dataset in datasets {
for computeOption in computeOptions {
for options in optionsToTest {
for repo in modelReposToTest {
for model in modelsToTest {
regressionTestConfigMatrix.append(
TestConfig(
dataset: dataset,
modelComputeOptions: computeOption,
model: model,
modelRepo: repo,
decodingOptions: options
)
)
}
}
}
}
}
return regressionTestConfigMatrix
}
private func downloadTestData(forDataset dataset: String, completion: @escaping (Bool) -> Void) {
Task {
do {
Logging.debug("Available models: \(modelsToTest)")
let testDatasetRepo = Hub.Repo(id: datasetRepo, type: .datasets)
let tempPath = FileManager.default.temporaryDirectory
let downloadBase = tempPath.appending(component: "huggingface")
let hubApi = HubApi(downloadBase: downloadBase)
let repoURL = try await hubApi.snapshot(from: testDatasetRepo, matching: ["\(dataset)/*"]) { progress in
Logging.debug("Downloading \(dataset) dataset: \(progress)")
}.appending(path: dataset)
let downloadedFiles = try FileManager.default.contentsOfDirectory(atPath: repoURL.path())
var audioFileURLs: [URL] = []
for file in downloadedFiles {
if file.hasSuffix(".mp3") {
audioFileURLs.append(repoURL.appending(component: file))
} else if file.hasSuffix(".json") {
self.metadataURL = repoURL.appending(component: file)
}
}
self.audioFileURLs = audioFileURLs
Logging.debug("Downloaded \(audioFileURLs.count) audio files.")
completion(true)
} catch {
XCTFail("Async setup failed with error: \(error)")
completion(false)
}
}
}
private func getTranscript(filename: String) -> String? {
// Ensure we can access and parse the metadata
guard let data = try? Data(contentsOf: self.metadataURL!),
let json = try? JSONSerialization.jsonObject(with: data) as? [[String: Any]]
else {
return nil
}
// Search for the matching audio item
for item in json {
// Check if the current item's audio matches the filename
let audioFileName = filename.split(separator: ".").first!
if let referenceFilename = item["audio"] as? String,
referenceFilename.contains(audioFileName)
{
// If found, return the reference text
return item["text"] as? String
}
}
// If no matching item was found, return nil
return nil
}
private func getWERTestData() async -> (String?, String?) {
do {
let testDataset = Hub.Repo(id: datasetRepo, type: .datasets)
let tempPath = FileManager.default.temporaryDirectory
let downloadBase = tempPath.appending(component: "huggingface")
let hubApi = HubApi(downloadBase: downloadBase)
let testWERRepoURL = try await hubApi.snapshot(from: testDataset, matching: ["*.txt"])
let testWERTextURLs = try FileManager.default.contentsOfDirectory(atPath: testWERRepoURL.path()).filter { $0.hasSuffix(".txt") }
self.testWERURLs = testWERTextURLs.map { testWERRepoURL.appending(component: $0) }
Logging.debug("Downloaded \(testWERTextURLs.count) test WER files.")
let testFileURLs = try XCTUnwrap(
self.testWERURLs,
"Test files for WER verification not found"
)
var generatedText: String?
var originalText: String?
for file in testFileURLs {
switch file.lastPathComponent {
case "test_generated_transcript.txt":
generatedText = try? String(contentsOf: file)
case "test_original_transcript.txt":
originalText = try? String(contentsOf: file)
default:
continue
}
}
return (originalText, generatedText)
} catch {
XCTFail("Fetching test data for WER verification failed: \(error)")
}
return (nil, nil)
}
private func saveSummary(failureInfo: [String: String], attachments: [String: String]) {
let currentDevice = getCurrentDevice()
let osDetails = getOSDetails()
let testReport = TestReport(
deviceModel: currentDevice,
osType: osDetails.osType,
osVersion: osDetails.osVersion,
modelsTested: modelsTested,
modelReposTested: modelReposTested,
failureInfo: failureInfo,
attachments: attachments
)
do {
let iso8601DateTimeString = ISO8601DateFormatter().string(from: Date())
let jsonData = try testReport.jsonData()
let attachment = XCTAttachment(data: jsonData, uniformTypeIdentifier: UTType.json.identifier)
attachment.lifetime = .keepAlways
attachment.name = "\(currentDevice)_summary_\(iso8601DateTimeString).json"
add(attachment)
} catch {
XCTFail("Failed with error: \(error)")
}
}
private func getCurrentDevice() -> String {
var currentDevice = WhisperKit.deviceName()
currentDevice = currentDevice.trimmingCharacters(in: .whitespacesAndNewlines)
currentDevice = currentDevice.replacingOccurrences(of: " ", with: "_")
return currentDevice
}
private func getOSDetails() -> (osType: String, osVersion: String) {
#if os(iOS)
return (UIDevice.current.systemName, UIDevice.current.systemVersion)
#elseif os(macOS)
let version = ProcessInfo.processInfo.operatingSystemVersion
return ("macOS", "\(version.majorVersion).\(version.minorVersion).\(version.patchVersion)")
#elseif os(watchOS)
return ("watchOS", WKInterfaceDevice.current().systemVersion)
#else
return ("Unknown", "Unknown")
#endif
}
/// Helper function to get cache size
private func getCacheSize() throws -> Int64 {
let fileManager = FileManager.default
let cacheURL = fileManager.urls(for: .cachesDirectory, in: .userDomainMask)
let cacheSize = try fileManager.allocatedSizeOfDirectory(at: cacheURL.first!)
return cacheSize / (1024 * 1024) // Convert to MB
}
private func getFolderSize(atUrl folder: URL?) throws -> Double {
guard let folder = folder else {
return -1
}
let fileManager = FileManager.default
let modelSize = try fileManager.allocatedSizeOfDirectory(at: folder)
return Double(modelSize / (1024 * 1024)) // Convert to MB
}
public func initWhisperKitTask(testConfig config: TestConfig, verbose: Bool, logLevel: Logging.LogLevel) -> Task<WhisperKit, Error> {
// Create the initialization task
let initializationTask = Task { () -> WhisperKit in
let whisperKit = try await WhisperKit(WhisperKitConfig(
model: config.model,
modelRepo: config.modelRepo,
modelToken: Self.getModelToken(),
computeOptions: config.modelComputeOptions,
verbose: verbose,
logLevel: logLevel,
prewarm: true,
load: true
))
try Task.checkCancellation()
return whisperKit
}
return initializationTask
}
func createWithMemoryCheck(
testConfig: TestConfig,
verbose: Bool,
logLevel: Logging.LogLevel
) async throws -> WhisperKit {
// Create the initialization task
let initializationTask = initWhisperKitTask(
testConfig: testConfig,
verbose: verbose,
logLevel: logLevel
)
// Start the memory monitoring task
let monitorTask = Task {
while true {
let remainingMemory = SystemMemoryCheckerAdvanced.getMemoryUsage().totalAvailableGB
Logging.debug(remainingMemory, "GB of memory left")
if remainingMemory <= 0.1 { // Cancel with 100MB remaining
Logging.debug("Cancelling due to oom")
// Cancel the initialization task
initializationTask.cancel()
// Throw an error to stop the monitor task
throw WhisperError.modelsUnavailable("Memory limit exceeded during initialization")
}
try await Task.sleep(nanoseconds: 1_000_000_000) // 1 second
}
}
// Create a timeout task
let timeoutTask = Task {
try await Task.sleep(nanoseconds: 300_000_000_000) // 5 minutes
initializationTask.cancel()
monitorTask.cancel()
Logging.debug("Cancelling due to timeout")
throw WhisperError.modelsUnavailable("Initialization timed out")
}
do {
// Use withTaskCancellationHandler to ensure proper cleanup
return try await withTaskCancellationHandler(
operation: {
// Await the initialization task
let whisperKit = try await initializationTask.value
// Cancel the monitor tasks after successful initialization
monitorTask.cancel()
timeoutTask.cancel()
return whisperKit
},
onCancel: {
initializationTask.cancel()
monitorTask.cancel()
timeoutTask.cancel()
}
)
} catch {
initializationTask.cancel()
monitorTask.cancel()
timeoutTask.cancel()
Logging.debug(error)
throw error
}
}
}