diff --git a/Package.swift b/Package.swift index 679c12f0..1ccf15c4 100644 --- a/Package.swift +++ b/Package.swift @@ -54,7 +54,7 @@ let package = Package( .package(url: "https://github.com/swift-server/async-http-client.git", from: "1.20.1"), .package(url: "https://github.com/orlandos-nl/DNSClient.git", from: "2.4.1"), .package(url: "https://github.com/Bouke/DNS.git", from: "1.2.0"), - .package(url: "https://github.com/apple/containerization.git", exact: Version(stringLiteral: scVersion)), + .package(url: "https://github.com/apple/containerization.git", branch: "main"), ], targets: [ .executableTarget( diff --git a/Sources/ContainerClient/Core/ClientImage.swift b/Sources/ContainerClient/Core/ClientImage.swift index e0c6eaa0..128e61cd 100644 --- a/Sources/ContainerClient/Core/ClientImage.swift +++ b/Sources/ContainerClient/Core/ClientImage.swift @@ -220,7 +220,7 @@ extension ClientImage { }) } - public static func pull(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil) async throws -> ClientImage { + public static func pull(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil, maxConcurrentDownloads: Int = 3) async throws -> ClientImage { let client = newXPCClient() let request = newRequest(.imagePull) @@ -234,6 +234,7 @@ extension ClientImage { let insecure = try scheme.schemeFor(host: host) == .http request.set(key: .insecureFlag, value: insecure) + request.set(key: .maxConcurrentDownloads, value: Int64(maxConcurrentDownloads)) var progressUpdateClient: ProgressUpdateClient? if let progressUpdate { @@ -293,7 +294,7 @@ extension ClientImage { return (digests, size) } - public static func fetch(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil) async throws -> ClientImage + public static func fetch(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil, maxConcurrentDownloads: Int = 3) async throws -> ClientImage { do { let match = try await self.get(reference: reference) @@ -307,7 +308,7 @@ extension ClientImage { guard err.isCode(.notFound) else { throw err } - return try await Self.pull(reference: reference, platform: platform, scheme: scheme, progressUpdate: progressUpdate) + return try await Self.pull(reference: reference, platform: platform, scheme: scheme, progressUpdate: progressUpdate, maxConcurrentDownloads: maxConcurrentDownloads) } } } diff --git a/Sources/ContainerClient/Flags.swift b/Sources/ContainerClient/Flags.swift index 3a01a2a3..cefc9409 100644 --- a/Sources/ContainerClient/Flags.swift +++ b/Sources/ContainerClient/Flags.swift @@ -208,7 +208,15 @@ public struct Flags { self.disableProgressUpdates = disableProgressUpdates } + public init(disableProgressUpdates: Bool, maxConcurrentDownloads: Int) { + self.disableProgressUpdates = disableProgressUpdates + self.maxConcurrentDownloads = maxConcurrentDownloads + } + @Flag(name: .long, help: "Disable progress bar updates") public var disableProgressUpdates = false + + @Option(name: .long, help: "Maximum number of concurrent layer downloads (default: 3)") + public var maxConcurrentDownloads: Int = 3 } } diff --git a/Sources/ContainerCommands/Image/ImagePull.swift b/Sources/ContainerCommands/Image/ImagePull.swift index 0633b4e5..37c803af 100644 --- a/Sources/ContainerCommands/Image/ImagePull.swift +++ b/Sources/ContainerCommands/Image/ImagePull.swift @@ -102,7 +102,7 @@ extension Application { let taskManager = ProgressTaskCoordinator() let fetchTask = await taskManager.startTask() let image = try await ClientImage.pull( - reference: processedReference, platform: p, scheme: scheme, progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progress.handler) + reference: processedReference, platform: p, scheme: scheme, progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progress.handler), maxConcurrentDownloads: self.progressFlags.maxConcurrentDownloads ) progress.set(description: "Unpacking image") diff --git a/Sources/Services/ContainerImagesService/Client/ImageServiceXPCKeys.swift b/Sources/Services/ContainerImagesService/Client/ImageServiceXPCKeys.swift index f290fe20..48c19efc 100644 --- a/Sources/Services/ContainerImagesService/Client/ImageServiceXPCKeys.swift +++ b/Sources/Services/ContainerImagesService/Client/ImageServiceXPCKeys.swift @@ -35,6 +35,7 @@ public enum ImagesServiceXPCKeys: String { case ociPlatform case insecureFlag case garbageCollect + case maxConcurrentDownloads /// ContentStore case digest @@ -54,6 +55,10 @@ extension XPCMessage { self.set(key: key.rawValue, value: value) } + public func set(key: ImagesServiceXPCKeys, value: Int64) { + self.set(key: key.rawValue, value: value) + } + public func set(key: ImagesServiceXPCKeys, value: Data) { self.set(key: key.rawValue, value: value) } @@ -78,6 +83,10 @@ extension XPCMessage { self.uint64(key: key.rawValue) } + public func int64(key: ImagesServiceXPCKeys) -> Int64 { + self.int64(key: key.rawValue) + } + public func bool(key: ImagesServiceXPCKeys) -> Bool { self.bool(key: key.rawValue) } diff --git a/Sources/Services/ContainerImagesService/Server/ImageService.swift b/Sources/Services/ContainerImagesService/Server/ImageService.swift index 4b4b2f5d..25b22214 100644 --- a/Sources/Services/ContainerImagesService/Server/ImageService.swift +++ b/Sources/Services/ContainerImagesService/Server/ImageService.swift @@ -59,11 +59,11 @@ public actor ImagesService { return try await imageStore.list().map { $0.description.fromCZ } } - public func pull(reference: String, platform: Platform?, insecure: Bool, progressUpdate: ProgressUpdateHandler?) async throws -> ImageDescription { - self.log.info("ImagesService: \(#function) - ref: \(reference), platform: \(String(describing: platform)), insecure: \(insecure)") + public func pull(reference: String, platform: Platform?, insecure: Bool, progressUpdate: ProgressUpdateHandler?, maxConcurrentDownloads: Int = 3) async throws -> ImageDescription { + self.log.info("ImagesService: \(#function) - ref: \(reference), platform: \(String(describing: platform)), insecure: \(insecure), maxConcurrentDownloads: \(maxConcurrentDownloads)") let img = try await Self.withAuthentication(ref: reference) { auth in try await self.imageStore.pull( - reference: reference, platform: platform, insecure: insecure, auth: auth, progress: ContainerizationProgressAdapter.handler(from: progressUpdate)) + reference: reference, platform: platform, insecure: insecure, auth: auth, progress: ContainerizationProgressAdapter.handler(from: progressUpdate), maxConcurrentDownloads: maxConcurrentDownloads) } guard let img else { throw ContainerizationError(.internalError, message: "Failed to pull image \(reference)") diff --git a/Sources/Services/ContainerImagesService/Server/ImagesServiceHarness.swift b/Sources/Services/ContainerImagesService/Server/ImagesServiceHarness.swift index 3c6c681a..5e0d61ce 100644 --- a/Sources/Services/ContainerImagesService/Server/ImagesServiceHarness.swift +++ b/Sources/Services/ContainerImagesService/Server/ImagesServiceHarness.swift @@ -47,9 +47,10 @@ public struct ImagesServiceHarness: Sendable { platform = try JSONDecoder().decode(ContainerizationOCI.Platform.self, from: platformData) } let insecure = message.bool(key: .insecureFlag) + let maxConcurrentDownloads = message.int64(key: .maxConcurrentDownloads) let progressUpdateService = ProgressUpdateService(message: message) - let imageDescription = try await service.pull(reference: ref, platform: platform, insecure: insecure, progressUpdate: progressUpdateService?.handler) + let imageDescription = try await service.pull(reference: ref, platform: platform, insecure: insecure, progressUpdate: progressUpdateService?.handler, maxConcurrentDownloads: Int(maxConcurrentDownloads)) let imageData = try JSONEncoder().encode(imageDescription) let reply = message.reply() diff --git a/Sources/TerminalProgress/ProgressTaskCoordinator.swift b/Sources/TerminalProgress/ProgressTaskCoordinator.swift index 987e3c68..7c00a4a5 100644 --- a/Sources/TerminalProgress/ProgressTaskCoordinator.swift +++ b/Sources/TerminalProgress/ProgressTaskCoordinator.swift @@ -17,9 +17,9 @@ import Foundation /// A type that represents a task whose progress is being monitored. -public struct ProgressTask: Sendable, Equatable { +public struct ProgressTask: Sendable, Equatable, Hashable { private var id = UUID() - private var coordinator: ProgressTaskCoordinator + internal var coordinator: ProgressTaskCoordinator init(manager: ProgressTaskCoordinator) { self.coordinator = manager @@ -29,6 +29,10 @@ public struct ProgressTask: Sendable, Equatable { lhs.id == rhs.id } + public func hash(into hasher: inout Hasher) { + hasher.combine(id) + } + /// Returns `true` if this task is the currently active task, `false` otherwise. public func isCurrent() async -> Bool { guard let currentTask = await coordinator.currentTask else { @@ -41,6 +45,7 @@ public struct ProgressTask: Sendable, Equatable { /// A type that coordinates progress tasks to ignore updates from completed tasks. public actor ProgressTaskCoordinator { var currentTask: ProgressTask? + var activeTasks: Set = [] /// Creates an instance of `ProgressTaskCoordinator`. public init() {} @@ -52,9 +57,36 @@ public actor ProgressTaskCoordinator { return newTask } + /// Starts multiple concurrent tasks and returns them. + /// - Parameter count: The number of concurrent tasks to start. + /// - Returns: An array of ProgressTask instances. + public func startConcurrentTasks(count: Int) -> [ProgressTask] { + var tasks: [ProgressTask] = [] + for _ in 0.. Bool { + activeTasks.contains(task) + } + /// Performs cleanup when the monitored tasks complete. public func finish() { currentTask = nil + activeTasks.removeAll() } /// Returns a handler that updates the progress of a given task. @@ -69,4 +101,17 @@ public actor ProgressTaskCoordinator { } } } + + /// Returns a handler that updates the progress for concurrent tasks. + /// - Parameters: + /// - task: The task whose progress is being updated. + /// - progressUpdate: The handler to invoke when progress updates are received. + public static func concurrentHandler(for task: ProgressTask, from progressUpdate: @escaping ProgressUpdateHandler) -> ProgressUpdateHandler { + { events in + // Only process updates if the task is still active + if await task.coordinator.isTaskActive(task) { + await progressUpdate(events) + } + } + } } diff --git a/test_concurrency.swift b/test_concurrency.swift new file mode 100755 index 00000000..163840bb --- /dev/null +++ b/test_concurrency.swift @@ -0,0 +1,109 @@ +#!/usr/bin/env swift + +import Foundation + +func testConcurrentDownloads() async throws { + print("Testing concurrent download behavior...\n") + + // Track concurrent task count + actor ConcurrencyTracker { + var currentCount = 0 + var maxObservedCount = 0 + var completedTasks = 0 + + func taskStarted() { + currentCount += 1 + maxObservedCount = max(maxObservedCount, currentCount) + } + + func taskCompleted() { + currentCount -= 1 + completedTasks += 1 + } + + func getStats() -> (max: Int, completed: Int) { + return (maxObservedCount, completedTasks) + } + + func reset() { + currentCount = 0 + maxObservedCount = 0 + completedTasks = 0 + } + } + + let tracker = ConcurrencyTracker() + + // Test with different concurrency limits + for maxConcurrent in [1, 3, 6] { + await tracker.reset() + + // Simulate downloading 20 layers + let layerCount = 20 + let layers = Array(0.. String { + return "pull(\(reference), maxConcurrent=\(maxConcurrentDownloads))" +} + +_ = mockClientImagePull(reference: "nginx:latest") +_ = mockClientImagePull(reference: "nginx:latest", maxConcurrentDownloads: 6) +print(" ✓ Compiles") +print(" PASSED\n") + +print("4. Parameter propagation...") + +struct MockXPCMessage { + var values: [String: Any] = [:] + + mutating func set(key: String, value: Int64) { + values[key] = value + } + + func int64(key: String) -> Int64 { + return values[key] as? Int64 ?? 3 + } +} + +func simulateFlow(maxConcurrent: Int) -> Int { + let flags = ProgressFlags(maxConcurrentDownloads: maxConcurrent) + var xpcMessage = MockXPCMessage() + xpcMessage.set(key: "maxConcurrentDownloads", value: Int64(flags.maxConcurrentDownloads)) + return Int(xpcMessage.int64(key: "maxConcurrentDownloads")) +} + +for testValue in [1, 3, 6] { + guard simulateFlow(maxConcurrent: testValue) == testValue else { + print(" ✗ Failed") + exit(1) + } +} +print(" ✓ Values propagate correctly") +print(" PASSED\n") + +print("5. Implementation verification...") + +let filesToCheck = [ + "Sources/ContainerClient/Flags.swift", + "Sources/ContainerClient/Core/ClientImage.swift", + "Sources/Services/ContainerImagesService/Server/ImageService.swift", +] + +for file in filesToCheck { + if let content = try? String(contentsOf: URL(fileURLWithPath: file), encoding: .utf8), + content.contains("maxConcurrentDownloads") { + continue + } + print(" ✗ Missing in \(file)") + exit(1) +} +print(" ✓ Found in implementation") +print(" PASSED\n") + +print("All tests passed!")