Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add two events Pty/Set_env/Start_shell into the server #53

Merged
merged 3 commits into from
Mar 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions lib/server.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ type event =
| Channel_data of (int32 * Cstruct.t)
| Channel_eof of int32
| Disconnected of string
| Pty of (string * int32 * int32 * int32 * int32 * string)
| Set_env of (string * string)
| Start_shell of int32

type t = {
client_version : string option; (* Without crlf *)
Expand Down Expand Up @@ -253,23 +256,17 @@ let input_channel_request t recp_channel want_reply data =
else
make_noreply t
in
let success t =
if want_reply then
make_reply t (Msg_channel_success recp_channel)
else
make_noreply t
in
let event t event =
if want_reply then
make_reply_with_event t (Msg_channel_success recp_channel) event
else
make_event t event
in
let handle t c = function
| Pty_req _ -> success t
| Pty_req v -> event t (Pty v)
| X11_req _ -> fail t
| Env (_key, _value) -> success t (* TODO implement me *)
| Shell -> fail t
| Env v -> event t (Set_env v)
| Shell -> event t (Start_shell c)
| Exec cmd -> event t (Channel_exec (c, cmd))
| Subsystem cmd -> event t (Channel_subsystem (c, cmd))
| Window_change _ -> fail t
Expand Down
2 changes: 1 addition & 1 deletion lib/wire.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ let get_string buf =
trap_error (fun () ->
let len = Cstruct.BE.get_uint32 buf 0 |> Int32.to_int in
Ssh.guard_sshlen_exn len;
(Cstruct.copy buf 4 len), Cstruct.shift buf (len + 4))
(Cstruct.to_string buf ~off:4 ~len), Cstruct.shift buf (len + 4))

let put_string s t =
let len = String.length s in
Expand Down
38 changes: 33 additions & 5 deletions lwt/awa_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,28 @@ type sshin_msg = [
]

type channel = {
cmd : string;
cmd : string option;
id : int32;
sshin_mbox : sshin_msg Lwt_mvar.t;
exec_thread : unit Lwt.t;
}

type exec_callback =
string -> (* cmd *)
?cmd:string -> (* cmd *)
(unit -> sshin_msg Lwt.t) -> (* sshin *)
(Cstruct.t -> unit Lwt.t) -> (* sshout *)
(Cstruct.t -> unit Lwt.t) -> (* ssherr *)
unit Lwt.t

type set_env = string -> string -> unit Lwt.t
type set_window = term:string -> w:int32 -> h:int32 -> maxw:int32 -> maxh:int32 -> unit Lwt.t

type t = {
exec_callback : exec_callback; (* callback to run on exec *)
channels : channel list; (* Opened channels *)
nexus_mbox : nexus_msg Lwt_mvar.t;(* Nexus mailbox *)
env : set_env option; (* Environment *)
window : set_window option; (* Window *)
}

let wrapr = function
Expand Down Expand Up @@ -144,6 +149,16 @@ let rec nexus t fd server input_buffer =
| None -> Lwt.return_unit)
>>= fun () ->
nexus t fd server input_buffer
| Some Awa.Server.Set_env (k, v) ->
( match t.env with
| Some set_env -> set_env k v
| None -> Lwt.return_unit ) >>= fun () ->
nexus t fd server input_buffer
| Some Awa.Server.Pty (term, w, h, maxw, maxh, _modes) ->
( match t.window with
| Some set_window -> set_window ~term ~w ~h ~maxw ~maxh
| None -> Lwt.return_unit ) >>= fun () ->
nexus t fd server input_buffer
| Some Awa.Server.Channel_subsystem (id, cmd) (* same as exec *)
| Some Awa.Server.Channel_exec (id, cmd) ->
(* Create an input box *)
Expand All @@ -153,14 +168,27 @@ let rec nexus t fd server input_buffer =
let sshout id buf = Lwt_mvar.put t.nexus_mbox (Sshout (id, buf)) in
let ssherr id buf = Lwt_mvar.put t.nexus_mbox (Ssherr (id, buf)) in
(* Create the execution thread *)
let exec_thread = t.exec_callback cmd sshin (sshout id) (ssherr id) in
let c = { cmd; id; sshin_mbox; exec_thread } in
let exec_thread = t.exec_callback ~cmd sshin (sshout id) (ssherr id) in
let c = { cmd= Some cmd; id; sshin_mbox; exec_thread } in
let t = { t with channels = c :: t.channels } in
nexus t fd server input_buffer
| Some (Awa.Server.Start_shell id) ->
(* Create an input box *)
let sshin_mbox = Lwt_mvar.create_empty () in
(* Create a callback for each mbox *)
let sshin () = Lwt_mvar.take sshin_mbox in
let sshout id buf = Lwt_mvar.put t.nexus_mbox (Sshout (id, buf)) in
let ssherr id buf = Lwt_mvar.put t.nexus_mbox (Ssherr (id, buf)) in
(* Create the execution thread *)
let exec_thread = t.exec_callback sshin (sshout id) (ssherr id) in
let c = { cmd= None; id; sshin_mbox; exec_thread } in
let t = { t with channels = c :: t.channels } in
nexus t fd server input_buffer

let spawn_server server msgs fd exec_callback =
let spawn_server ?env ?window server msgs fd exec_callback =
let t = { exec_callback;
channels = [];
env; window;
nexus_mbox = Lwt_mvar.create_empty () }
in
send_msgs fd server msgs >>= fun server ->
Expand Down
55 changes: 37 additions & 18 deletions mirage/awa_mirage.ml
Original file line number Diff line number Diff line change
Expand Up @@ -161,24 +161,26 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
| Sshout of (int32 * Cstruct.t)
| Ssherr of (int32 * Cstruct.t)

type sshin_msg = [
| `Data of Cstruct.t
| `Eof
]

type channel = {
cmd : string;
cmd : string option;
id : int32;
sshin_mbox : sshin_msg Lwt_mvar.t;
sshin_mbox : Cstruct.t Mirage_flow.or_eof Lwt_mvar.t;
exec_thread : unit Lwt.t;
}

type exec_callback =
string -> (* cmd *)
(unit -> sshin_msg Lwt.t) -> (* sshin *)
(Cstruct.t -> unit Lwt.t) -> (* sshout *)
(Cstruct.t -> unit Lwt.t) -> (* ssherr *)
unit Lwt.t
type request =
| Pty_req of { width : int32; height : int32; max_width : int32; max_height : int32; term : string }
| Pty_set of { width : int32; height : int32; max_width : int32; max_height : int32 }
| Set_env of { key : string; value : string }
| Channel of { cmd : string
; ic : unit -> Cstruct.t Mirage_flow.or_eof Lwt.t
; oc : Cstruct.t -> unit Lwt.t
; ec : Cstruct.t -> unit Lwt.t }
| Shell of { ic : unit -> Cstruct.t Mirage_flow.or_eof Lwt.t
; oc : Cstruct.t -> unit Lwt.t
; ec : Cstruct.t -> unit Lwt.t }

type exec_callback = request -> unit Lwt.t

type t = {
exec_callback : exec_callback; (* callback to run on exec *)
Expand Down Expand Up @@ -285,6 +287,12 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
>>= fun server ->
match event with
| None -> nexus t fd server input_buffer (List.append pending_promises [ Lwt_mvar.take t.nexus_mbox ])
| Some Awa.Server.Pty (term, width, height, max_width, max_height, _modes) ->
t.exec_callback (Pty_req { width; height; max_width; max_height; term; }) >>= fun () ->
nexus t fd server input_buffer pending_promises
| Some Awa.Server.Set_env (key, value) ->
t.exec_callback (Set_env { key; value; }) >>= fun () ->
nexus t fd server input_buffer pending_promises
| Some Awa.Server.Disconnected _ ->
Lwt_list.iter_p sshin_eof t.channels
>>= fun () -> Lwt.return t
Expand All @@ -303,12 +311,23 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
(* Create an input box *)
let sshin_mbox = Lwt_mvar.create_empty () in
(* Create a callback for each mbox *)
let sshin () = Lwt_mvar.take sshin_mbox in
let sshout id buf = Lwt_mvar.put t.nexus_mbox (Sshout (id, buf)) in
let ssherr id buf = Lwt_mvar.put t.nexus_mbox (Ssherr (id, buf)) in
let ic () = Lwt_mvar.take sshin_mbox in
let oc id buf = Lwt_mvar.put t.nexus_mbox (Sshout (id, buf)) in
let ec id buf = Lwt_mvar.put t.nexus_mbox (Ssherr (id, buf)) in
(* Create the execution thread *)
let exec_thread = t.exec_callback (Channel { cmd; ic; oc= oc id; ec= ec id; }) in
let c = { cmd= Some cmd; id; sshin_mbox; exec_thread } in
let t = { t with channels = c :: t.channels } in
nexus t fd server input_buffer (List.append pending_promises [ Lwt_mvar.take t.nexus_mbox ])
| Some (Awa.Server.Start_shell id) ->
let sshin_mbox = Lwt_mvar.create_empty () in
(* Create a callback for each mbox *)
let ic () = Lwt_mvar.take sshin_mbox in
let oc id buf = Lwt_mvar.put t.nexus_mbox (Sshout (id, buf)) in
let ec id buf = Lwt_mvar.put t.nexus_mbox (Ssherr (id, buf)) in
(* Create the execution thread *)
let exec_thread = t.exec_callback cmd sshin (sshout id) (ssherr id) in
let c = { cmd; id; sshin_mbox; exec_thread } in
let exec_thread = t.exec_callback (Shell { ic; oc= oc id; ec= ec id; }) in
let c = { cmd= None; id; sshin_mbox; exec_thread } in
let t = { t with channels = c :: t.channels } in
nexus t fd server input_buffer (List.append pending_promises [ Lwt_mvar.take t.nexus_mbox ])

Expand Down
28 changes: 14 additions & 14 deletions mirage/awa_mirage.mli
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,21 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) :
Awa.Hostkey.priv -> Awa.Ssh.channel_request -> FLOW.flow ->
(flow, error) result Lwt.t

type t
type t

type sshin_msg = [
| `Data of Cstruct.t
| `Eof
]
type request =
| Pty_req of { width : int32; height : int32; max_width : int32; max_height : int32; term : string }
| Pty_set of { width : int32; height : int32; max_width : int32; max_height : int32 }
| Set_env of { key : string; value : string }
| Channel of { cmd : string
; ic : unit -> Cstruct.t Mirage_flow.or_eof Lwt.t
; oc : Cstruct.t -> unit Lwt.t
; ec : Cstruct.t -> unit Lwt.t }
| Shell of { ic : unit -> Cstruct.t Mirage_flow.or_eof Lwt.t
; oc : Cstruct.t -> unit Lwt.t
; ec : Cstruct.t -> unit Lwt.t }

type exec_callback =
string -> (* cmd *)
(unit -> sshin_msg Lwt.t) -> (* sshin *)
(Cstruct.t -> unit Lwt.t) -> (* sshout *)
(Cstruct.t -> unit Lwt.t) -> (* ssherr *)
unit Lwt.t
type exec_callback = request -> unit Lwt.t

val spawn_server : ?stop:Lwt_switch.t -> Awa.Server.t -> Awa.Ssh.message list -> F.flow ->
exec_callback -> t Lwt.t
Expand All @@ -62,6 +64,4 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) :
{b NOTE}: Even if the [ssh_channel_handler] is fulfilled, [spawn_server]
continues to handle SSH channels. Only [stop] can really stop the internal
SSH channels handler. *)

end
with module FLOW = F
end with module FLOW = F
19 changes: 12 additions & 7 deletions test/awa_lwt_server.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,26 @@ let user_db =
let awa = Awa.Auth.make_user "awa" [ key ] in
[ foo; awa ]

let exec addr cmd sshin sshout _ssherror =
let exec addr ?cmd sshin sshout _ssherror =
let rec echo () =
sshin () >>= function
| `Eof -> Lwt.return_unit
| `Data input -> sshout input >>= fun () -> echo ()
in
let ping () = sshout (Cstruct.of_string "pong\n") in
let badcmd () =
let badcmd cmd =
sshout (Cstruct.of_string (Printf.sprintf "Bad command `%s`\n" cmd))
in
Lwt_io.printf "[%s] executing `%s`\n%!" addr cmd >>= fun () ->
(match cmd with "echo" -> echo () | "ping" -> ping () | _ -> badcmd ())
>>= fun () ->
Lwt_io.printf "[%s] execution of `%s` finished\n%!" addr cmd
(* XXX Awa_lwt must close the channel when exec returns ! *)
match cmd with
| None ->
Lwt_io.printf "[%s] impossible to execute a shell\n%!" addr >>= fun () ->
sshout (Cstruct.of_string (Printf.sprintf "No shell available"))
| Some cmd ->
Lwt_io.printf "[%s] executing `%s`\n%!" addr cmd >>= fun () ->
(match cmd with "echo" -> echo () | "ping" -> ping () | _ -> badcmd cmd)
>>= fun () ->
Lwt_io.printf "[%s] execution of `%s` finished\n%!" addr cmd
(* XXX Awa_lwt must close the channel when exec returns ! *)

let serve rsa fd addr =
Lwt_io.printf "[%s] connected\n%!" addr >>= fun () ->
Expand Down
5 changes: 3 additions & 2 deletions test/awa_test_server.ml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ let rec serve t cmd =
| Channel_subsystem (id, exec) (* same as exec *)
| Channel_exec (id, exec) ->
printf "channel exec %s\n%!" exec;
match exec with
begin match exec with
| "suicide" ->
let* _ = Driver.disconnect t in
Ok ()
Expand All @@ -104,7 +104,8 @@ let rec serve t cmd =
let* t = Driver.send_channel_data t id (Cstruct.of_string m) in
printf "%s\n%!" m;
let* t = Driver.disconnect t in
serve t cmd
serve t cmd end
| _ -> failwith "Invalid SSH event"

let user_db =
(* User foo auths by passoword *)
Expand Down