@@ -31,7 +31,7 @@ final class CollectingExecCommandHelper {
31
31
self . stderr = allocator. buffer ( capacity: 4096 )
32
32
}
33
33
34
- public func onOutput( _ output: ExecCommandHandler . Output ) {
34
+ public func onOutput( _ channel : Channel , _ output: ExecCommandHandler . Output ) {
35
35
switch output {
36
36
case . stderr( let byteBuffer) where mergeStreams:
37
37
fallthrough
@@ -65,6 +65,8 @@ final class CollectingExecCommandHelper {
65
65
case . eof( . none) :
66
66
stdoutPromise? . succeed ( stdout)
67
67
stderrPromise? . succeed ( stderr)
68
+ case . channelSuccess:
69
+ ( )
68
70
}
69
71
}
70
72
}
@@ -86,6 +88,8 @@ public struct ExecCommandStream {
86
88
case . eof( let error) :
87
89
stdout. finish ( throwing: error)
88
90
stderr. finish ( throwing: error)
91
+ case . channelSuccess:
92
+ ( )
89
93
}
90
94
}
91
95
}
@@ -96,8 +100,10 @@ public enum ExecCommandOutput {
96
100
case stderr( ByteBuffer )
97
101
}
98
102
103
+
99
104
final class ExecCommandHandler : ChannelDuplexHandler {
100
105
enum Output {
106
+ case channelSuccess
101
107
case stdout( ByteBuffer )
102
108
case stderr( ByteBuffer )
103
109
case eof( Error ? )
@@ -109,9 +115,9 @@ final class ExecCommandHandler: ChannelDuplexHandler {
109
115
typealias OutboundOut = SSHChannelData
110
116
111
117
let logger : Logger
112
- let onOutput : ( Output ) -> ( )
118
+ let onOutput : ( Channel , Output ) -> ( )
113
119
114
- init ( logger: Logger , onOutput: @escaping ( Output ) -> ( ) ) {
120
+ init ( logger: Logger , onOutput: @escaping ( Channel , Output ) -> ( ) ) {
115
121
self . logger = logger
116
122
self . onOutput = onOutput
117
123
}
@@ -124,6 +130,10 @@ final class ExecCommandHandler: ChannelDuplexHandler {
124
130
125
131
func userInboundEventTriggered( context: ChannelHandlerContext , event: Any ) {
126
132
switch event {
133
+ case is NIOSSH . ChannelSuccessEvent :
134
+ onOutput ( context. channel, . channelSuccess)
135
+ case is NIOSSH . ChannelFailureEvent :
136
+ onOutput ( context. channel, . eof( CitadelError . channelFailure) )
127
137
case is SSHChannelRequestEvent . ExitStatus :
128
138
( )
129
139
default :
@@ -132,35 +142,30 @@ final class ExecCommandHandler: ChannelDuplexHandler {
132
142
}
133
143
134
144
func handlerRemoved( context: ChannelHandlerContext ) {
135
- onOutput ( . eof( nil ) )
145
+ onOutput ( context . channel , . eof( nil ) )
136
146
}
137
147
138
148
func channelRead( context: ChannelHandlerContext , data: NIOAny ) {
139
149
let data = self . unwrapInboundIn ( data)
140
150
141
151
guard case . byteBuffer( let buffer) = data. data else {
142
152
logger. error ( " Unable to process channelData for executed command. Data was not a ByteBuffer " )
143
- return onOutput ( . eof( SSHExecError . invalidData) )
153
+ return onOutput ( context . channel , . eof( SSHExecError . invalidData) )
144
154
}
145
155
146
156
switch data. type {
147
157
case . channel:
148
- onOutput ( . stdout( buffer) )
158
+ onOutput ( context . channel , . stdout( buffer) )
149
159
case . stdErr:
150
- onOutput ( . stderr( buffer) )
160
+ onOutput ( context . channel , . stderr( buffer) )
151
161
default :
152
162
// We don't know this std channel
153
163
( )
154
164
}
155
165
}
156
166
157
167
func errorCaught( context: ChannelHandlerContext , error: Error ) {
158
- onOutput ( . eof( error) )
159
- }
160
-
161
- func write( context: ChannelHandlerContext , data: NIOAny , promise: EventLoopPromise < Void > ? ) {
162
- let data = self . unwrapOutboundIn ( data)
163
- context. write ( self . wrapOutboundOut ( SSHChannelData ( type: . channel, data: . byteBuffer( data) ) ) , promise: promise)
168
+ onOutput ( context. channel, . eof( error) )
164
169
}
165
170
}
166
171
@@ -222,20 +227,30 @@ extension SSHClient {
222
227
/// Executes a command on the remote server. This will return the output stream of the command. If the command fails, the error will be thrown.
223
228
/// - Parameters:
224
229
/// - command: The command to execute.
225
- public func executeCommandStream( _ command: String ) async throws -> AsyncThrowingStream < ExecCommandOutput , Error > {
230
+ /// - inShell: Whether to request the remote server to start a shell before executing the command.
231
+ public func executeCommandStream( _ command: String , inShell: Bool = false ) async throws -> AsyncThrowingStream < ExecCommandOutput , Error > {
226
232
var streamContinuation : AsyncThrowingStream < ExecCommandOutput , Error > . Continuation !
227
233
let stream = AsyncThrowingStream < ExecCommandOutput , Error > { continuation in
228
234
streamContinuation = continuation
229
235
}
230
236
231
- let handler = ExecCommandHandler ( logger: logger) { output in
237
+ var hasReceivedChannelSuccess = false
238
+
239
+ let handler = ExecCommandHandler ( logger: logger) { channel, output in
232
240
switch output {
233
241
case . stdout( let stdout) :
234
242
streamContinuation. yield ( . stdout( stdout) )
235
243
case . stderr( let stderr) :
236
244
streamContinuation. yield ( . stderr( stderr) )
237
245
case . eof( let error) :
238
246
streamContinuation. finish ( throwing: error)
247
+ case . channelSuccess:
248
+ if inShell, !hasReceivedChannelSuccess {
249
+ let commandData = SSHChannelData ( type: . channel,
250
+ data: . byteBuffer( ByteBuffer ( string: command + " ;exit \n " ) ) )
251
+ channel. writeAndFlush ( commandData, promise: nil )
252
+ hasReceivedChannelSuccess = true
253
+ }
239
254
}
240
255
}
241
256
@@ -252,21 +267,24 @@ extension SSHClient {
252
267
return createChannel. futureResult
253
268
} . get ( )
254
269
255
- // We need to exec a thing.
256
- let execRequest = SSHChannelRequestEvent . ExecRequest (
257
- command: command,
258
- wantReply: true
259
- )
260
-
261
- try await channel. triggerUserOutboundEvent ( execRequest)
270
+ if inShell {
271
+ try await channel. triggerUserOutboundEvent ( SSHChannelRequestEvent . ShellRequest (
272
+ wantReply: true
273
+ ) )
274
+ } else {
275
+ try await channel. triggerUserOutboundEvent ( SSHChannelRequestEvent . ExecRequest (
276
+ command: command,
277
+ wantReply: true
278
+ ) )
279
+ }
262
280
263
281
return stream
264
282
}
265
283
266
284
/// Executes a command on the remote server. This will return the pair of streams stdout and stderr of the command. If the command fails, the error will be thrown.
267
285
/// - Parameters:
268
286
/// - command: The command to execute.
269
- public func executeCommandPair( _ command: String ) async throws -> ExecCommandStream {
287
+ public func executeCommandPair( _ command: String , inShell : Bool = false ) async throws -> ExecCommandStream {
270
288
var stdoutContinuation : AsyncThrowingStream < ByteBuffer , Error > . Continuation !
271
289
var stderrContinuation : AsyncThrowingStream < ByteBuffer , Error > . Continuation !
272
290
let stdout = AsyncThrowingStream < ByteBuffer , Error > { continuation in
@@ -284,7 +302,7 @@ extension SSHClient {
284
302
285
303
Task {
286
304
do {
287
- let stream = try await self . executeCommandStream ( command)
305
+ let stream = try await executeCommandStream ( command, inShell : inShell )
288
306
for try await chunk in stream {
289
307
switch chunk {
290
308
case . stdout( let buffer) :
0 commit comments