Skip to content

Commit cc15745

Browse files
committed
Add parallel layer downloads support for image pull operations
Implements concurrent layer downloads to improve image pull performance by fetching multiple layers simultaneously instead of sequentially.
1 parent bfc5ca9 commit cc15745

File tree

11 files changed

+279
-15
lines changed

11 files changed

+279
-15
lines changed

Package.resolved

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ let package = Package(
5454
.package(url: "https://github.com/swift-server/async-http-client.git", from: "1.20.1"),
5555
.package(url: "https://github.com/orlandos-nl/DNSClient.git", from: "2.4.1"),
5656
.package(url: "https://github.com/Bouke/DNS.git", from: "1.2.0"),
57-
.package(url: "https://github.com/apple/containerization.git", exact: Version(stringLiteral: scVersion)),
57+
.package(url: "https://github.com/apple/containerization.git", branch: "main"),
5858
],
5959
targets: [
6060
.executableTarget(

Sources/ContainerClient/Core/ClientImage.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ extension ClientImage {
220220
})
221221
}
222222

223-
public static func pull(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil) async throws -> ClientImage {
223+
public static func pull(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil, maxConcurrentDownloads: Int = 3) async throws -> ClientImage {
224224
let client = newXPCClient()
225225
let request = newRequest(.imagePull)
226226

@@ -234,6 +234,7 @@ extension ClientImage {
234234

235235
let insecure = try scheme.schemeFor(host: host) == .http
236236
request.set(key: .insecureFlag, value: insecure)
237+
request.set(key: .maxConcurrentDownloads, value: Int64(maxConcurrentDownloads))
237238

238239
var progressUpdateClient: ProgressUpdateClient?
239240
if let progressUpdate {
@@ -293,7 +294,7 @@ extension ClientImage {
293294
return (digests, size)
294295
}
295296

296-
public static func fetch(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil) async throws -> ClientImage
297+
public static func fetch(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil, maxConcurrentDownloads: Int = 3) async throws -> ClientImage
297298
{
298299
do {
299300
let match = try await self.get(reference: reference)
@@ -307,7 +308,7 @@ extension ClientImage {
307308
guard err.isCode(.notFound) else {
308309
throw err
309310
}
310-
return try await Self.pull(reference: reference, platform: platform, scheme: scheme, progressUpdate: progressUpdate)
311+
return try await Self.pull(reference: reference, platform: platform, scheme: scheme, progressUpdate: progressUpdate, maxConcurrentDownloads: maxConcurrentDownloads)
311312
}
312313
}
313314
}

Sources/ContainerClient/Flags.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,15 @@ public struct Flags {
208208
self.disableProgressUpdates = disableProgressUpdates
209209
}
210210

211+
public init(disableProgressUpdates: Bool, maxConcurrentDownloads: Int) {
212+
self.disableProgressUpdates = disableProgressUpdates
213+
self.maxConcurrentDownloads = maxConcurrentDownloads
214+
}
215+
211216
@Flag(name: .long, help: "Disable progress bar updates")
212217
public var disableProgressUpdates = false
218+
219+
@Option(name: .long, help: "Maximum number of concurrent layer downloads (default: 3)")
220+
public var maxConcurrentDownloads: Int = 3
213221
}
214222
}

Sources/ContainerCommands/Image/ImagePull.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ extension Application {
102102
let taskManager = ProgressTaskCoordinator()
103103
let fetchTask = await taskManager.startTask()
104104
let image = try await ClientImage.pull(
105-
reference: processedReference, platform: p, scheme: scheme, progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progress.handler)
105+
reference: processedReference, platform: p, scheme: scheme, progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progress.handler), maxConcurrentDownloads: self.progressFlags.maxConcurrentDownloads
106106
)
107107

108108
progress.set(description: "Unpacking image")

Sources/Services/ContainerImagesService/Client/ImageServiceXPCKeys.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ public enum ImagesServiceXPCKeys: String {
3535
case ociPlatform
3636
case insecureFlag
3737
case garbageCollect
38+
case maxConcurrentDownloads
3839

3940
/// ContentStore
4041
case digest
@@ -54,6 +55,10 @@ extension XPCMessage {
5455
self.set(key: key.rawValue, value: value)
5556
}
5657

58+
public func set(key: ImagesServiceXPCKeys, value: Int64) {
59+
self.set(key: key.rawValue, value: value)
60+
}
61+
5762
public func set(key: ImagesServiceXPCKeys, value: Data) {
5863
self.set(key: key.rawValue, value: value)
5964
}
@@ -78,6 +83,10 @@ extension XPCMessage {
7883
self.uint64(key: key.rawValue)
7984
}
8085

86+
public func int64(key: ImagesServiceXPCKeys) -> Int64 {
87+
self.int64(key: key.rawValue)
88+
}
89+
8190
public func bool(key: ImagesServiceXPCKeys) -> Bool {
8291
self.bool(key: key.rawValue)
8392
}

Sources/Services/ContainerImagesService/Server/ImageService.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ public actor ImagesService {
5959
return try await imageStore.list().map { $0.description.fromCZ }
6060
}
6161

62-
public func pull(reference: String, platform: Platform?, insecure: Bool, progressUpdate: ProgressUpdateHandler?) async throws -> ImageDescription {
63-
self.log.info("ImagesService: \(#function) - ref: \(reference), platform: \(String(describing: platform)), insecure: \(insecure)")
62+
public func pull(reference: String, platform: Platform?, insecure: Bool, progressUpdate: ProgressUpdateHandler?, maxConcurrentDownloads: Int = 3) async throws -> ImageDescription {
63+
self.log.info("ImagesService: \(#function) - ref: \(reference), platform: \(String(describing: platform)), insecure: \(insecure), maxConcurrentDownloads: \(maxConcurrentDownloads)")
6464
let img = try await Self.withAuthentication(ref: reference) { auth in
6565
try await self.imageStore.pull(
66-
reference: reference, platform: platform, insecure: insecure, auth: auth, progress: ContainerizationProgressAdapter.handler(from: progressUpdate))
66+
reference: reference, platform: platform, insecure: insecure, auth: auth, progress: ContainerizationProgressAdapter.handler(from: progressUpdate), maxConcurrentDownloads: maxConcurrentDownloads)
6767
}
6868
guard let img else {
6969
throw ContainerizationError(.internalError, message: "Failed to pull image \(reference)")

Sources/Services/ContainerImagesService/Server/ImagesServiceHarness.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ public struct ImagesServiceHarness: Sendable {
4747
platform = try JSONDecoder().decode(ContainerizationOCI.Platform.self, from: platformData)
4848
}
4949
let insecure = message.bool(key: .insecureFlag)
50+
let maxConcurrentDownloads = message.int64(key: .maxConcurrentDownloads)
5051

5152
let progressUpdateService = ProgressUpdateService(message: message)
52-
let imageDescription = try await service.pull(reference: ref, platform: platform, insecure: insecure, progressUpdate: progressUpdateService?.handler)
53+
let imageDescription = try await service.pull(reference: ref, platform: platform, insecure: insecure, progressUpdate: progressUpdateService?.handler, maxConcurrentDownloads: Int(maxConcurrentDownloads))
5354

5455
let imageData = try JSONEncoder().encode(imageDescription)
5556
let reply = message.reply()

Sources/TerminalProgress/ProgressTaskCoordinator.swift

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
import Foundation
1818

1919
/// A type that represents a task whose progress is being monitored.
20-
public struct ProgressTask: Sendable, Equatable {
20+
public struct ProgressTask: Sendable, Equatable, Hashable {
2121
private var id = UUID()
22-
private var coordinator: ProgressTaskCoordinator
22+
internal var coordinator: ProgressTaskCoordinator
2323

2424
init(manager: ProgressTaskCoordinator) {
2525
self.coordinator = manager
@@ -29,6 +29,10 @@ public struct ProgressTask: Sendable, Equatable {
2929
lhs.id == rhs.id
3030
}
3131

32+
public func hash(into hasher: inout Hasher) {
33+
hasher.combine(id)
34+
}
35+
3236
/// Returns `true` if this task is the currently active task, `false` otherwise.
3337
public func isCurrent() async -> Bool {
3438
guard let currentTask = await coordinator.currentTask else {
@@ -41,6 +45,7 @@ public struct ProgressTask: Sendable, Equatable {
4145
/// A type that coordinates progress tasks to ignore updates from completed tasks.
4246
public actor ProgressTaskCoordinator {
4347
var currentTask: ProgressTask?
48+
var activeTasks: Set<ProgressTask> = []
4449

4550
/// Creates an instance of `ProgressTaskCoordinator`.
4651
public init() {}
@@ -52,9 +57,36 @@ public actor ProgressTaskCoordinator {
5257
return newTask
5358
}
5459

60+
/// Starts multiple concurrent tasks and returns them.
61+
/// - Parameter count: The number of concurrent tasks to start.
62+
/// - Returns: An array of ProgressTask instances.
63+
public func startConcurrentTasks(count: Int) -> [ProgressTask] {
64+
var tasks: [ProgressTask] = []
65+
for _ in 0..<count {
66+
let task = ProgressTask(manager: self)
67+
tasks.append(task)
68+
activeTasks.insert(task)
69+
}
70+
return tasks
71+
}
72+
73+
/// Marks a specific task as completed and removes it from active tasks.
74+
/// - Parameter task: The task to mark as completed.
75+
public func completeTask(_ task: ProgressTask) {
76+
activeTasks.remove(task)
77+
}
78+
79+
/// Checks if a task is currently active.
80+
/// - Parameter task: The task to check.
81+
/// - Returns: `true` if the task is active, `false` otherwise.
82+
public func isTaskActive(_ task: ProgressTask) -> Bool {
83+
activeTasks.contains(task)
84+
}
85+
5586
/// Performs cleanup when the monitored tasks complete.
5687
public func finish() {
5788
currentTask = nil
89+
activeTasks.removeAll()
5890
}
5991

6092
/// Returns a handler that updates the progress of a given task.
@@ -69,4 +101,17 @@ public actor ProgressTaskCoordinator {
69101
}
70102
}
71103
}
104+
105+
/// Returns a handler that updates the progress for concurrent tasks.
106+
/// - Parameters:
107+
/// - task: The task whose progress is being updated.
108+
/// - progressUpdate: The handler to invoke when progress updates are received.
109+
public static func concurrentHandler(for task: ProgressTask, from progressUpdate: @escaping ProgressUpdateHandler) -> ProgressUpdateHandler {
110+
{ events in
111+
// Only process updates if the task is still active
112+
if await task.coordinator.isTaskActive(task) {
113+
await progressUpdate(events)
114+
}
115+
}
116+
}
72117
}

test_concurrency.swift

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
#!/usr/bin/env swift
2+
3+
import Foundation
4+
5+
func testConcurrentDownloads() async throws {
6+
print("Testing concurrent download behavior...\n")
7+
8+
// Track concurrent task count
9+
actor ConcurrencyTracker {
10+
var currentCount = 0
11+
var maxObservedCount = 0
12+
var completedTasks = 0
13+
14+
func taskStarted() {
15+
currentCount += 1
16+
maxObservedCount = max(maxObservedCount, currentCount)
17+
}
18+
19+
func taskCompleted() {
20+
currentCount -= 1
21+
completedTasks += 1
22+
}
23+
24+
func getStats() -> (max: Int, completed: Int) {
25+
return (maxObservedCount, completedTasks)
26+
}
27+
28+
func reset() {
29+
currentCount = 0
30+
maxObservedCount = 0
31+
completedTasks = 0
32+
}
33+
}
34+
35+
let tracker = ConcurrencyTracker()
36+
37+
// Test with different concurrency limits
38+
for maxConcurrent in [1, 3, 6] {
39+
await tracker.reset()
40+
41+
// Simulate downloading 20 layers
42+
let layerCount = 20
43+
let layers = Array(0..<layerCount)
44+
45+
print("Testing maxConcurrent=\(maxConcurrent) with \(layerCount) layers...")
46+
47+
let startTime = Date()
48+
49+
try await withThrowingTaskGroup(of: Void.self) { group in
50+
var iterator = layers.makeIterator()
51+
52+
// Start initial batch based on maxConcurrent
53+
for _ in 0..<maxConcurrent {
54+
if iterator.next() != nil {
55+
group.addTask {
56+
await tracker.taskStarted()
57+
try await Task.sleep(nanoseconds: 10_000_000)
58+
await tracker.taskCompleted()
59+
}
60+
}
61+
}
62+
for try await _ in group {
63+
if iterator.next() != nil {
64+
group.addTask {
65+
await tracker.taskStarted()
66+
try await Task.sleep(nanoseconds: 10_000_000)
67+
await tracker.taskCompleted()
68+
}
69+
}
70+
}
71+
}
72+
73+
let duration = Date().timeIntervalSince(startTime)
74+
let stats = await tracker.getStats()
75+
76+
print(" ✓ Completed: \(stats.completed)/\(layerCount)")
77+
print(" ✓ Max concurrent: \(stats.max)")
78+
print(" ✓ Duration: \(String(format: "%.3f", duration))s")
79+
80+
guard stats.max <= maxConcurrent + 1 else {
81+
throw TestError.concurrencyLimitExceeded
82+
}
83+
84+
guard stats.completed == layerCount else {
85+
throw TestError.incompleteTasks
86+
}
87+
88+
print(" ✅ PASSED\n")
89+
}
90+
91+
print("All tests passed!")
92+
}
93+
94+
enum TestError: Error {
95+
case concurrencyLimitExceeded
96+
case incompleteTasks
97+
}
98+
99+
Task {
100+
do {
101+
try await testConcurrentDownloads()
102+
exit(0)
103+
} catch {
104+
print("Test failed: \(error)")
105+
exit(1)
106+
}
107+
}
108+
109+
RunLoop.main.run()

0 commit comments

Comments
 (0)