Skip to content

Commit

Permalink
Complete the shell support on the LWT layer of awa-ssh
Browse files Browse the repository at this point in the history
  • Loading branch information
dinosaure committed Mar 13, 2023
1 parent 1196665 commit 8aa0773
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
39 changes: 33 additions & 6 deletions lwt/awa_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,28 @@ type sshin_msg = [
]

type channel = {
cmd : string;
cmd : string option;
id : int32;
sshin_mbox : sshin_msg Lwt_mvar.t;
exec_thread : unit Lwt.t;
}

type exec_callback =
string -> (* cmd *)
?cmd:string -> (* cmd *)
(unit -> sshin_msg Lwt.t) -> (* sshin *)
(Cstruct.t -> unit Lwt.t) -> (* sshout *)
(Cstruct.t -> unit Lwt.t) -> (* ssherr *)
unit Lwt.t

type set_env = string -> string -> unit Lwt.t
type set_window = term:string -> w:int32 -> h:int32 -> maxw:int32 -> maxh:int32 -> unit Lwt.t

type t = {
exec_callback : exec_callback; (* callback to run on exec *)
channels : channel list; (* Opened channels *)
nexus_mbox : nexus_msg Lwt_mvar.t;(* Nexus mailbox *)
env : set_env option; (* Environment *)
window : set_window option; (* Window *)
}

let wrapr = function
Expand Down Expand Up @@ -144,6 +149,16 @@ let rec nexus t fd server input_buffer =
| None -> Lwt.return_unit)
>>= fun () ->
nexus t fd server input_buffer
| Some Awa.Server.Set_env (k, v) ->
( match t.env with
| Some set_env -> set_env k v
| None -> Lwt.return_unit ) >>= fun () ->
nexus t fd server input_buffer
| Some Awa.Server.Pty (term, w, h, maxw, maxh, _modes) ->
( match t.window with
| Some set_window -> set_window ~term ~w ~h ~maxw ~maxh
| None -> Lwt.return_unit ) >>= fun () ->
nexus t fd server input_buffer
| Some Awa.Server.Channel_subsystem (id, cmd) (* same as exec *)
| Some Awa.Server.Channel_exec (id, cmd) ->
(* Create an input box *)
Expand All @@ -153,15 +168,27 @@ let rec nexus t fd server input_buffer =
let sshout id buf = Lwt_mvar.put t.nexus_mbox (Sshout (id, buf)) in
let ssherr id buf = Lwt_mvar.put t.nexus_mbox (Ssherr (id, buf)) in
(* Create the execution thread *)
let exec_thread = t.exec_callback cmd sshin (sshout id) (ssherr id) in
let c = { cmd; id; sshin_mbox; exec_thread } in
let exec_thread = t.exec_callback ~cmd sshin (sshout id) (ssherr id) in
let c = { cmd= Some cmd; id; sshin_mbox; exec_thread } in
let t = { t with channels = c :: t.channels } in
nexus t fd server input_buffer
| Some (Awa.Server.Start_shell id) ->
(* Create an input box *)
let sshin_mbox = Lwt_mvar.create_empty () in
(* Create a callback for each mbox *)
let sshin () = Lwt_mvar.take sshin_mbox in
let sshout id buf = Lwt_mvar.put t.nexus_mbox (Sshout (id, buf)) in
let ssherr id buf = Lwt_mvar.put t.nexus_mbox (Ssherr (id, buf)) in
(* Create the execution thread *)
let exec_thread = t.exec_callback sshin (sshout id) (ssherr id) in
let c = { cmd= None; id; sshin_mbox; exec_thread } in
let t = { t with channels = c :: t.channels } in
nexus t fd server input_buffer
| _ -> nexus t fd server input_buffer

let spawn_server server msgs fd exec_callback =
let spawn_server ?env ?window server msgs fd exec_callback =
let t = { exec_callback;
channels = [];
env; window;
nexus_mbox = Lwt_mvar.create_empty () }
in
send_msgs fd server msgs >>= fun server ->
Expand Down
19 changes: 12 additions & 7 deletions test/awa_lwt_server.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,26 @@ let user_db =
let awa = Awa.Auth.make_user "awa" [ key ] in
[ foo; awa ]

let exec addr cmd sshin sshout _ssherror =
let exec addr ?cmd sshin sshout _ssherror =
let rec echo () =
sshin () >>= function
| `Eof -> Lwt.return_unit
| `Data input -> sshout input >>= fun () -> echo ()
in
let ping () = sshout (Cstruct.of_string "pong\n") in
let badcmd () =
let badcmd cmd =
sshout (Cstruct.of_string (Printf.sprintf "Bad command `%s`\n" cmd))
in
Lwt_io.printf "[%s] executing `%s`\n%!" addr cmd >>= fun () ->
(match cmd with "echo" -> echo () | "ping" -> ping () | _ -> badcmd ())
>>= fun () ->
Lwt_io.printf "[%s] execution of `%s` finished\n%!" addr cmd
(* XXX Awa_lwt must close the channel when exec returns ! *)
match cmd with
| None ->
Lwt_io.printf "[%s] impossible to execute a shell\n%!" addr >>= fun () ->
sshout (Cstruct.of_string (Printf.sprintf "No shell available"))
| Some cmd ->
Lwt_io.printf "[%s] executing `%s`\n%!" addr cmd >>= fun () ->
(match cmd with "echo" -> echo () | "ping" -> ping () | _ -> badcmd cmd)
>>= fun () ->
Lwt_io.printf "[%s] execution of `%s` finished\n%!" addr cmd
(* XXX Awa_lwt must close the channel when exec returns ! *)

let serve rsa fd addr =
Lwt_io.printf "[%s] connected\n%!" addr >>= fun () ->
Expand Down

0 comments on commit 8aa0773

Please sign in to comment.