Skip to content

Commit cf5c42e

Browse files
committed
SSHAlgorithms.all support, Sendability fixes
1 parent 9ebd290 commit cf5c42e

File tree

4 files changed

+107
-52
lines changed

4 files changed

+107
-52
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,8 @@ let client = try await SSHClient.connect(
325325
)
326326
```
327327

328+
You can also use `SSHAlgorithms.all` to enable all supported algorithms.
329+
328330
## TODO
329331

330332
A couple of code is held back until further work in SwiftNIO SSH is completed. We're currently working with Apple to resolve these.

Sources/Citadel/Client.swift

+81-34
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,61 @@
11
import NIO
2+
import CryptoKit
23
import Logging
34
import NIOSSH
45

6+
extension SSHAlgorithms.Modification<NIOSSHTransportProtection.Type> {
7+
func apply(to configuration: inout [any NIOSSHTransportProtection.Type]) {
8+
switch self {
9+
case .add(let algorithms):
10+
configuration.append(contentsOf: algorithms)
11+
12+
for algorithm: any NIOSSHTransportProtection.Type in algorithms {
13+
NIOSSHAlgorithms.register(transportProtectionScheme: algorithm)
14+
}
15+
case .replace(with: let algorithms):
16+
configuration = algorithms
17+
18+
for algorithm in algorithms {
19+
NIOSSHAlgorithms.register(transportProtectionScheme: algorithm)
20+
}
21+
}
22+
}
23+
}
24+
25+
extension SSHAlgorithms.Modification<NIOSSHKeyExchangeAlgorithmProtocol.Type> {
26+
func apply(to configuration: inout [any NIOSSHKeyExchangeAlgorithmProtocol.Type]) {
27+
switch self {
28+
case .add(let algorithms):
29+
configuration.append(contentsOf: algorithms)
30+
31+
for algorithm in algorithms {
32+
NIOSSHAlgorithms.register(keyExchangeAlgorithm: algorithm)
33+
}
34+
case .replace(with: let algorithms):
35+
configuration = algorithms
36+
37+
for algorithm in algorithms {
38+
NIOSSHAlgorithms.register(keyExchangeAlgorithm: algorithm)
39+
}
40+
}
41+
}
42+
}
43+
44+
extension SSHAlgorithms.Modification<(NIOSSHPublicKeyProtocol.Type, NIOSSHSignatureProtocol.Type)>{
45+
func register() {
46+
switch self {
47+
case .add(let algorithms):
48+
for (publicKey, signature) in algorithms {
49+
NIOSSHAlgorithms.register(publicKey: publicKey, signature: signature)
50+
}
51+
case .replace(with: let algorithms):
52+
for (publicKey, signature) in algorithms {
53+
NIOSSHAlgorithms.register(publicKey: publicKey, signature: signature)
54+
}
55+
}
56+
}
57+
}
58+
559
public struct SSHAlgorithms {
660
/// Represents a modification to a list of items.
761
///
@@ -18,47 +72,40 @@ public struct SSHAlgorithms {
1872
/// The enabled KeyExchangeAlgorithms
1973
public var keyExchangeAlgorithms: Modification<NIOSSHKeyExchangeAlgorithmProtocol.Type>?
2074

75+
public var publicKeyAlgorihtms: Modification<(NIOSSHPublicKeyProtocol.Type, NIOSSHSignatureProtocol.Type)>?
76+
2177
func apply(to clientConfiguration: inout SSHClientConfiguration) {
22-
switch transportProtectionSchemes {
23-
case .add(let algorithms):
24-
clientConfiguration.transportProtectionSchemes.append(contentsOf: algorithms)
25-
case .replace(with: let algorithms):
26-
clientConfiguration.transportProtectionSchemes = algorithms
27-
case .none:
28-
()
29-
}
30-
31-
switch keyExchangeAlgorithms {
32-
case .add(let algorithms):
33-
clientConfiguration.keyExchangeAlgorithms.append(contentsOf: algorithms)
34-
case .replace(with: let algorithms):
35-
clientConfiguration.keyExchangeAlgorithms = algorithms
36-
case .none:
37-
()
38-
}
78+
transportProtectionSchemes?.apply(to: &clientConfiguration.transportProtectionSchemes)
79+
keyExchangeAlgorithms?.apply(to: &clientConfiguration.keyExchangeAlgorithms)
80+
publicKeyAlgorihtms?.register()
3981
}
4082

4183
func apply(to serverConfiguration: inout SSHServerConfiguration) {
42-
switch transportProtectionSchemes {
43-
case .add(let algorithms):
44-
serverConfiguration.transportProtectionSchemes.append(contentsOf: algorithms)
45-
case .replace(with: let algorithms):
46-
serverConfiguration.transportProtectionSchemes = algorithms
47-
case .none:
48-
()
49-
}
50-
51-
switch keyExchangeAlgorithms {
52-
case .add(let algorithms):
53-
serverConfiguration.keyExchangeAlgorithms.append(contentsOf: algorithms)
54-
case .replace(with: let algorithms):
55-
serverConfiguration.keyExchangeAlgorithms = algorithms
56-
case .none:
57-
()
58-
}
84+
transportProtectionSchemes?.apply(to: &serverConfiguration.transportProtectionSchemes)
85+
keyExchangeAlgorithms?.apply(to: &serverConfiguration.keyExchangeAlgorithms)
86+
publicKeyAlgorihtms?.register()
5987
}
6088

6189
public init() {}
90+
91+
public static let all: SSHAlgorithms = {
92+
var algorithms = SSHAlgorithms()
93+
94+
algorithms.transportProtectionSchemes = .add([
95+
AES128CTR.self
96+
])
97+
98+
algorithms.keyExchangeAlgorithms = .add([
99+
DiffieHellmanGroup14Sha1.self,
100+
DiffieHellmanGroup14Sha256.self
101+
])
102+
103+
algorithms.publicKeyAlgorihtms = .add([
104+
(Insecure.RSA.PublicKey.self, Insecure.RSA.Signature.self),
105+
])
106+
107+
return algorithms
108+
}()
62109
}
63110

64111
/// Represents an SSH connection.

Sources/Citadel/SFTP/Client/SFTPClient.swift

+23-17
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ import Logging
77
/// The SFTP client does not concern itself with the created SSH subsystem channel.
88
///
99
/// Per specification, SFTP could be used over other transport layers, too.
10-
public final class SFTPClient {
10+
public final class SFTPClient: Sendable {
1111
/// The SSH child channel created for this connection.
1212
fileprivate let channel: Channel
1313

1414
/// A monotonically increasing counter for gneerating request IDs.
15-
private var _nextRequestId = NIOLockedValueBox<UInt32>(0)
15+
private let _nextRequestId = NIOLockedValueBox<UInt32>(0)
1616

1717
private func incrementAndGetNextRequestId() -> UInt32 {
1818
_nextRequestId.withLockedValue { value in
@@ -34,7 +34,7 @@ public final class SFTPClient {
3434
}
3535

3636
fileprivate static func setupChannelHanders(channel: Channel, logger: Logger) -> EventLoopFuture<SFTPClient> {
37-
let responses = SFTPResponses(initialized: channel.eventLoop.makePromise())
37+
let responses = SFTPResponses(sftpVersion: channel.eventLoop.makePromise())
3838

3939
let deserializeHandler = ByteToMessageHandler(SFTPMessageParser())
4040
let serializeHandler = MessageToByteHandler(SFTPMessageSerializer())
@@ -315,7 +315,9 @@ extension SSHClient {
315315

316316
self.session.sshHandler.createChannel(createChannel) { channel, _ in
317317
SFTPClient.setupChannelHanders(channel: channel, logger: logger)
318-
.map(createClient.succeed)
318+
.map { client in
319+
createClient.succeed(client)
320+
}
319321
}
320322

321323
timeoutCheck.futureResult.whenFailure { _ in
@@ -354,7 +356,7 @@ extension SSHClient {
354356
//logger.trace("SFTP OUT: \(initializeMessage.debugRawBytesRepresentation)")
355357

356358
return client.channel.writeAndFlush(initializeMessage).flatMap {
357-
return client.responses.initialized.futureResult
359+
return client.responses.sftpVersion.futureResult
358360
}.flatMapThrowing { serverVersion in
359361
guard serverVersion.version >= .v3 else {
360362
logger.warning("SFTP ERROR: Server version is unrecognized: \(serverVersion.version.rawValue)")
@@ -370,29 +372,33 @@ extension SSHClient {
370372
}
371373

372374
/// A tracker for in-flight SFTP requests. Request IDs are allocated by `SFTPClient`.
373-
final class SFTPResponses {
374-
var isInitialized: Bool = false
375-
let initialized: EventLoopPromise<SFTPMessage.Version>
375+
final class SFTPResponses: @unchecked Sendable {
376+
let _initialized: NIOLockedValueBox<Bool> = NIOLockedValueBox<Bool>(false)
377+
var isInitialized: Bool {
378+
get { _initialized.withLockedValue { $0 } }
379+
set { _initialized.withLockedValue { $0 = newValue } }
380+
}
381+
let sftpVersion: EventLoopPromise<SFTPMessage.Version>
376382
var responses = [UInt32: EventLoopPromise<SFTPResponse>]()
377383

378-
init(initialized: EventLoopPromise<SFTPMessage.Version>) {
379-
self.initialized = initialized
384+
init(sftpVersion: EventLoopPromise<SFTPMessage.Version>) {
385+
self.sftpVersion = sftpVersion
380386

381-
initialized.futureResult.whenSuccess { [unowned self] _ in
382-
self.isInitialized = true
387+
sftpVersion.futureResult.whenSuccess { [weak self] _ in
388+
self?.isInitialized = true
383389
}
384390
}
385391

386-
deinit {
387-
self.close()
388-
}
389-
390392
func close() {
391393
self.isInitialized = false
392-
self.initialized.fail(SFTPError.connectionClosed)
394+
self.sftpVersion.fail(SFTPError.connectionClosed)
393395

394396
for promise in self.responses.values {
395397
promise.fail(SFTPError.connectionClosed)
396398
}
397399
}
400+
401+
deinit {
402+
close()
403+
}
398404
}

Sources/Citadel/SFTP/Client/SFTPClientInboundHandler.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ final class SFTPClientInboundHandler: ChannelInboundHandler {
2222
logger.warning("SFTP ERROR: Server version is unrecognized or incompatible: \(version.version.rawValue)")
2323
context.fireErrorCaught(SFTPError.unsupportedVersion(version.version))
2424
} else {
25-
responses.initialized.succeed(version)
25+
responses.sftpVersion.succeed(version)
2626
}
2727
} else if let response = SFTPResponse(message: message) {
2828
if let promise = responses.responses.removeValue(forKey: response.requestId) {

0 commit comments

Comments
 (0)