Skip to content

Commit d731663

Browse files
authored
Merge pull request #36 from Marcocanc/shell-exec
Add `inShell` parameter to for SSH command execution
2 parents bbe2e65 + a461b4e commit d731663

File tree

2 files changed

+43
-24
lines changed

2 files changed

+43
-24
lines changed

Sources/Citadel/Algorithms/AES.swift

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ enum CitadelError: Error {
1717
case unauthorized
1818
case commandOutputTooLarge
1919
case channelCreationFailed
20+
case channelFailure
2021
}
2122

2223
public final class AES128CTR: NIOSSHTransportProtection {

Sources/Citadel/TTY/Client/TTY.swift

+42-24
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ final class CollectingExecCommandHelper {
3131
self.stderr = allocator.buffer(capacity: 4096)
3232
}
3333

34-
public func onOutput(_ output: ExecCommandHandler.Output) {
34+
public func onOutput(_ channel: Channel, _ output: ExecCommandHandler.Output) {
3535
switch output {
3636
case .stderr(let byteBuffer) where mergeStreams:
3737
fallthrough
@@ -65,6 +65,8 @@ final class CollectingExecCommandHelper {
6565
case .eof(.none):
6666
stdoutPromise?.succeed(stdout)
6767
stderrPromise?.succeed(stderr)
68+
case .channelSuccess:
69+
()
6870
}
6971
}
7072
}
@@ -86,6 +88,8 @@ public struct ExecCommandStream {
8688
case .eof(let error):
8789
stdout.finish(throwing: error)
8890
stderr.finish(throwing: error)
91+
case .channelSuccess:
92+
()
8993
}
9094
}
9195
}
@@ -96,8 +100,10 @@ public enum ExecCommandOutput {
96100
case stderr(ByteBuffer)
97101
}
98102

103+
99104
final class ExecCommandHandler: ChannelDuplexHandler {
100105
enum Output {
106+
case channelSuccess
101107
case stdout(ByteBuffer)
102108
case stderr(ByteBuffer)
103109
case eof(Error?)
@@ -109,9 +115,9 @@ final class ExecCommandHandler: ChannelDuplexHandler {
109115
typealias OutboundOut = SSHChannelData
110116

111117
let logger: Logger
112-
let onOutput: (Output) -> ()
118+
let onOutput: (Channel, Output) -> ()
113119

114-
init(logger: Logger, onOutput: @escaping (Output) -> ()) {
120+
init(logger: Logger, onOutput: @escaping (Channel, Output) -> ()) {
115121
self.logger = logger
116122
self.onOutput = onOutput
117123
}
@@ -124,6 +130,10 @@ final class ExecCommandHandler: ChannelDuplexHandler {
124130

125131
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
126132
switch event {
133+
case is NIOSSH.ChannelSuccessEvent:
134+
onOutput(context.channel, .channelSuccess)
135+
case is NIOSSH.ChannelFailureEvent:
136+
onOutput(context.channel, .eof(CitadelError.channelFailure))
127137
case is SSHChannelRequestEvent.ExitStatus:
128138
()
129139
default:
@@ -132,35 +142,30 @@ final class ExecCommandHandler: ChannelDuplexHandler {
132142
}
133143

134144
func handlerRemoved(context: ChannelHandlerContext) {
135-
onOutput(.eof(nil))
145+
onOutput(context.channel, .eof(nil))
136146
}
137147

138148
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
139149
let data = self.unwrapInboundIn(data)
140150

141151
guard case .byteBuffer(let buffer) = data.data else {
142152
logger.error("Unable to process channelData for executed command. Data was not a ByteBuffer")
143-
return onOutput(.eof(SSHExecError.invalidData))
153+
return onOutput(context.channel, .eof(SSHExecError.invalidData))
144154
}
145155

146156
switch data.type {
147157
case .channel:
148-
onOutput(.stdout(buffer))
158+
onOutput(context.channel, .stdout(buffer))
149159
case .stdErr:
150-
onOutput(.stderr(buffer))
160+
onOutput(context.channel, .stderr(buffer))
151161
default:
152162
// We don't know this std channel
153163
()
154164
}
155165
}
156166

157167
func errorCaught(context: ChannelHandlerContext, error: Error) {
158-
onOutput(.eof(error))
159-
}
160-
161-
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
162-
let data = self.unwrapOutboundIn(data)
163-
context.write(self.wrapOutboundOut(SSHChannelData(type: .channel, data: .byteBuffer(data))), promise: promise)
168+
onOutput(context.channel, .eof(error))
164169
}
165170
}
166171

@@ -222,20 +227,30 @@ extension SSHClient {
222227
/// Executes a command on the remote server. This will return the output stream of the command. If the command fails, the error will be thrown.
223228
/// - Parameters:
224229
/// - command: The command to execute.
225-
public func executeCommandStream(_ command: String) async throws -> AsyncThrowingStream<ExecCommandOutput, Error> {
230+
/// - inShell: Whether to request the remote server to start a shell before executing the command.
231+
public func executeCommandStream(_ command: String, inShell: Bool = false) async throws -> AsyncThrowingStream<ExecCommandOutput, Error> {
226232
var streamContinuation: AsyncThrowingStream<ExecCommandOutput, Error>.Continuation!
227233
let stream = AsyncThrowingStream<ExecCommandOutput, Error> { continuation in
228234
streamContinuation = continuation
229235
}
230236

231-
let handler = ExecCommandHandler(logger: logger) { output in
237+
var hasReceivedChannelSuccess = false
238+
239+
let handler = ExecCommandHandler(logger: logger) { channel, output in
232240
switch output {
233241
case .stdout(let stdout):
234242
streamContinuation.yield(.stdout(stdout))
235243
case .stderr(let stderr):
236244
streamContinuation.yield(.stderr(stderr))
237245
case .eof(let error):
238246
streamContinuation.finish(throwing: error)
247+
case .channelSuccess:
248+
if inShell, !hasReceivedChannelSuccess {
249+
let commandData = SSHChannelData(type: .channel,
250+
data: .byteBuffer(ByteBuffer(string: command + ";exit\n")))
251+
channel.writeAndFlush(commandData, promise: nil)
252+
hasReceivedChannelSuccess = true
253+
}
239254
}
240255
}
241256

@@ -252,21 +267,24 @@ extension SSHClient {
252267
return createChannel.futureResult
253268
}.get()
254269

255-
// We need to exec a thing.
256-
let execRequest = SSHChannelRequestEvent.ExecRequest(
257-
command: command,
258-
wantReply: true
259-
)
260-
261-
try await channel.triggerUserOutboundEvent(execRequest)
270+
if inShell {
271+
try await channel.triggerUserOutboundEvent(SSHChannelRequestEvent.ShellRequest(
272+
wantReply: true
273+
))
274+
} else {
275+
try await channel.triggerUserOutboundEvent(SSHChannelRequestEvent.ExecRequest(
276+
command: command,
277+
wantReply: true
278+
))
279+
}
262280

263281
return stream
264282
}
265283

266284
/// Executes a command on the remote server. This will return the pair of streams stdout and stderr of the command. If the command fails, the error will be thrown.
267285
/// - Parameters:
268286
/// - command: The command to execute.
269-
public func executeCommandPair(_ command: String) async throws -> ExecCommandStream {
287+
public func executeCommandPair(_ command: String, inShell: Bool = false) async throws -> ExecCommandStream {
270288
var stdoutContinuation: AsyncThrowingStream<ByteBuffer, Error>.Continuation!
271289
var stderrContinuation: AsyncThrowingStream<ByteBuffer, Error>.Continuation!
272290
let stdout = AsyncThrowingStream<ByteBuffer, Error> { continuation in
@@ -284,7 +302,7 @@ extension SSHClient {
284302

285303
Task {
286304
do {
287-
let stream = try await self.executeCommandStream(command)
305+
let stream = try await executeCommandStream(command, inShell: inShell)
288306
for try await chunk in stream {
289307
switch chunk {
290308
case .stdout(let buffer):

0 commit comments

Comments
 (0)