Skip to content

Commit

Permalink
Add two events Pty/Set_env/Start_shell into the server and export the…
Browse files Browse the repository at this point in the history
…m into the mirage layer
  • Loading branch information
dinosaure committed Mar 4, 2023
1 parent 7608e31 commit 92d7ab5
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 55 deletions.
15 changes: 6 additions & 9 deletions lib/server.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ type event =
| Channel_data of (int32 * Cstruct.t)
| Channel_eof of int32
| Disconnected of string
| Pty of (string * int32 * int32 * int32 * int32 * string)
| Set_env of (string * string)
| Start_shell of int32

type t = {
client_version : string option; (* Without crlf *)
Expand Down Expand Up @@ -253,23 +256,17 @@ let input_channel_request t recp_channel want_reply data =
else
make_noreply t
in
let success t =
if want_reply then
make_reply t (Msg_channel_success recp_channel)
else
make_noreply t
in
let event t event =
if want_reply then
make_reply_with_event t (Msg_channel_success recp_channel) event
else
make_event t event
in
let handle t c = function
| Pty_req _ -> success t
| Pty_req v -> event t (Pty v)
| X11_req _ -> fail t
| Env (_key, _value) -> success t (* TODO implement me *)
| Shell -> fail t
| Env v -> event t (Set_env v)
| Shell -> event t (Start_shell c)
| Exec cmd -> event t (Channel_exec (c, cmd))
| Subsystem cmd -> event t (Channel_subsystem (c, cmd))
| Window_change _ -> fail t
Expand Down
1 change: 1 addition & 0 deletions lwt/awa_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ let rec nexus t fd server input_buffer =
let c = { cmd; id; sshin_mbox; exec_thread } in
let t = { t with channels = c :: t.channels } in
nexus t fd server input_buffer
| _ -> nexus t fd server input_buffer

let spawn_server server msgs fd exec_callback =
let t = { exec_callback;
Expand Down
76 changes: 47 additions & 29 deletions mirage/awa_mirage.ml
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
| `Nothing, `Channel_eof _ -> `Eof
| `Nothing, `Disconnected -> `Eof
| a, `Channel_stderr (id, data) ->
Log.warn (fun m -> m "%ld stderr %s" id (Cstruct.to_string data));
a
Log.warn (fun m -> m "%ld stderr %s" id (Cstruct.to_string data)); a
| a, _ -> a)
`Nothing events
in
Expand Down Expand Up @@ -161,24 +160,26 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
| Sshout of (int32 * Cstruct.t)
| Ssherr of (int32 * Cstruct.t)

type sshin_msg = [
| `Data of Cstruct.t
| `Eof
]

type channel = {
cmd : string;
cmd : string option;
id : int32;
sshin_mbox : sshin_msg Lwt_mvar.t;
sshin_mbox : Cstruct.t Mirage_flow.or_eof Lwt_mvar.t;
exec_thread : unit Lwt.t;
}

type exec_callback =
string -> (* cmd *)
(unit -> sshin_msg Lwt.t) -> (* sshin *)
(Cstruct.t -> unit Lwt.t) -> (* sshout *)
(Cstruct.t -> unit Lwt.t) -> (* ssherr *)
unit Lwt.t
type request =
| Pty_req of { width : int32; height : int32; max_width : int32; max_height : int32; term : string }
| Pty_set of { width : int32; height : int32; max_width : int32; max_height : int32 }
| Set_env of { key : string; value : string }
| Channel of { command : string
; ic : unit -> Cstruct.t Mirage_flow.or_eof Lwt.t
; oc : Cstruct.t -> unit Lwt.t
; ec : Cstruct.t -> unit Lwt.t }
| Shell of { ic : unit -> Cstruct.t Mirage_flow.or_eof Lwt.t
; oc : Cstruct.t -> unit Lwt.t
; ec : Cstruct.t -> unit Lwt.t }

type exec_callback = request -> unit Lwt.t

type t = {
exec_callback : exec_callback; (* callback to run on exec *)
Expand Down Expand Up @@ -285,6 +286,12 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
>>= fun server ->
match event with
| None -> nexus t fd server input_buffer (List.append pending_promises [ Lwt_mvar.take t.nexus_mbox ])
| Some Awa.Server.Pty (term, width, height, max_width, max_height, _modes) ->
t.exec_callback (Pty_req { width; height; max_width; max_height; term; }) >>= fun () ->
nexus t fd server input_buffer pending_promises
| Some Awa.Server.Set_env (key, value) ->
t.exec_callback (Set_env { key; value; }) >>= fun () ->
nexus t fd server input_buffer pending_promises
| Some Awa.Server.Disconnected _ ->
Lwt_list.iter_p sshin_eof t.channels
>>= fun () -> Lwt.return t
Expand All @@ -298,32 +305,43 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
| None -> Lwt.return_unit)
>>= fun () ->
nexus t fd server input_buffer (List.append pending_promises [ Lwt_mvar.take t.nexus_mbox ])
| Some Awa.Server.Channel_subsystem (id, cmd) (* same as exec *)
| Some Awa.Server.Channel_exec (id, cmd) ->
| Some Awa.Server.Channel_subsystem (id, command) (* same as exec *)
| Some Awa.Server.Channel_exec (id, command) ->
(* Create an input box *)
let sshin_mbox = Lwt_mvar.create_empty () in
(* Create a callback for each mbox *)
let sshin () = Lwt_mvar.take sshin_mbox in
let sshout id buf = Lwt_mvar.put t.nexus_mbox (Sshout (id, buf)) in
let ssherr id buf = Lwt_mvar.put t.nexus_mbox (Ssherr (id, buf)) in
let ic () = Lwt_mvar.take sshin_mbox in
let oc id buf = Lwt_mvar.put t.nexus_mbox (Sshout (id, buf)) in
let ec id buf = Lwt_mvar.put t.nexus_mbox (Ssherr (id, buf)) in
(* Create the execution thread *)
let exec_thread = t.exec_callback cmd sshin (sshout id) (ssherr id) in
let c = { cmd; id; sshin_mbox; exec_thread } in
let exec_thread = t.exec_callback (Channel { command; ic; oc= oc id; ec= ec id; }) in
let c = { cmd= Some command; id; sshin_mbox; exec_thread } in
let t = { t with channels = c :: t.channels } in
nexus t fd server input_buffer (List.append pending_promises [ Lwt_mvar.take t.nexus_mbox ])
| Some (Awa.Server.Start_shell id) ->
let sshin_mbox = Lwt_mvar.create_empty () in
(* Create a callback for each mbox *)
let ic () = Lwt_mvar.take sshin_mbox in
let oc id buf = Lwt_mvar.put t.nexus_mbox (Sshout (id, buf)) in
let ec id buf = Lwt_mvar.put t.nexus_mbox (Ssherr (id, buf)) in
(* Create the execution thread *)
let exec_thread = t.exec_callback (Shell { ic; oc= oc id; ec= ec id; }) in
let c = { cmd= None; id; sshin_mbox; exec_thread } in
let t = { t with channels = c :: t.channels } in
nexus t fd server input_buffer (List.append pending_promises [ Lwt_mvar.take t.nexus_mbox ])

let spawn_server ?stop server msgs fd exec_callback =
let switched_off =
let t, u = Lwt.wait () in
Lwt_switch.add_hook stop (fun () ->
Lwt.wakeup_later u Net_eof;
Log.debug (fun m -> m "Turn off the server.") ;
Lwt.return_unit); t in
let t = { exec_callback;
channels = [];
nexus_mbox = Lwt_mvar.create_empty ()
nexus_mbox = Lwt_mvar.create_empty ();
}
in
let open Lwt.Syntax in
let* switched_off =
let thread, u = Lwt.wait () in
Lwt_switch.add_hook_or_exec stop (fun () ->
Lwt.wakeup_later u Net_eof;
Lwt_list.iter_p sshin_eof t.channels) >|= fun () -> thread in
send_msgs fd server msgs >>= fun server ->
(* the ssh communication will start with 'net_read' and can only add a 'Lwt.take' promise when
* one Awa.Server.Channel_{exec,subsystem} is received
Expand Down
28 changes: 14 additions & 14 deletions mirage/awa_mirage.mli
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,21 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) :
Awa.Hostkey.priv -> Awa.Ssh.channel_request -> FLOW.flow ->
(flow, error) result Lwt.t

type t
type t

type sshin_msg = [
| `Data of Cstruct.t
| `Eof
]
type request =
| Pty_req of { width : int32; height : int32; max_width : int32; max_height : int32; term : string }
| Pty_set of { width : int32; height : int32; max_width : int32; max_height : int32 }
| Set_env of { key : string; value : string }
| Channel of { command : string
; ic : unit -> Cstruct.t Mirage_flow.or_eof Lwt.t
; oc : Cstruct.t -> unit Lwt.t
; ec : Cstruct.t -> unit Lwt.t }
| Shell of { ic : unit -> Cstruct.t Mirage_flow.or_eof Lwt.t
; oc : Cstruct.t -> unit Lwt.t
; ec : Cstruct.t -> unit Lwt.t }

type exec_callback =
string -> (* cmd *)
(unit -> sshin_msg Lwt.t) -> (* sshin *)
(Cstruct.t -> unit Lwt.t) -> (* sshout *)
(Cstruct.t -> unit Lwt.t) -> (* ssherr *)
unit Lwt.t
type exec_callback = request -> unit Lwt.t

val spawn_server : ?stop:Lwt_switch.t -> Awa.Server.t -> Awa.Ssh.message list -> F.flow ->
exec_callback -> t Lwt.t
Expand All @@ -62,6 +64,4 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) :
{b NOTE}: Even if the [ssh_channel_handler] is fulfilled, [spawn_server]
continues to handle SSH channels. Only [stop] can really stop the internal
SSH channels handler. *)

end
with module FLOW = F
end with module FLOW = F
2 changes: 1 addition & 1 deletion mirage/dune
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
(name awa_mirage)
(public_name awa-mirage)
(wrapped false)
(libraries awa mirage-flow mirage-clock mirage-time duration lwt mtime logs))
(libraries hxd.core hxd.string awa mirage-flow mirage-clock mirage-time duration lwt mtime))
5 changes: 3 additions & 2 deletions test/awa_test_server.ml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ let rec serve t cmd =
| Channel_subsystem (id, exec) (* same as exec *)
| Channel_exec (id, exec) ->
printf "channel exec %s\n%!" exec;
match exec with
begin match exec with
| "suicide" ->
let* _ = Driver.disconnect t in
Ok ()
Expand All @@ -104,7 +104,8 @@ let rec serve t cmd =
let* t = Driver.send_channel_data t id (Cstruct.of_string m) in
printf "%s\n%!" m;
let* t = Driver.disconnect t in
serve t cmd
serve t cmd end
| _ -> failwith "Invalid SSH event"

let user_db =
(* User foo auths by passoword *)
Expand Down

0 comments on commit 92d7ab5

Please sign in to comment.