diff --git a/lwt/awa_lwt.ml b/lwt/awa_lwt.ml index 33035b6..a23665f 100644 --- a/lwt/awa_lwt.ml +++ b/lwt/awa_lwt.ml @@ -29,23 +29,28 @@ type sshin_msg = [ ] type channel = { - cmd : string; + cmd : string option; id : int32; sshin_mbox : sshin_msg Lwt_mvar.t; exec_thread : unit Lwt.t; } type exec_callback = - string -> (* cmd *) + ?cmd:string -> (* cmd *) (unit -> sshin_msg Lwt.t) -> (* sshin *) (Cstruct.t -> unit Lwt.t) -> (* sshout *) (Cstruct.t -> unit Lwt.t) -> (* ssherr *) unit Lwt.t +type set_env = string -> string -> unit Lwt.t +type set_window = term:string -> w:int32 -> h:int32 -> maxw:int32 -> maxh:int32 -> unit Lwt.t + type t = { exec_callback : exec_callback; (* callback to run on exec *) channels : channel list; (* Opened channels *) nexus_mbox : nexus_msg Lwt_mvar.t;(* Nexus mailbox *) + env : set_env option; (* Environment *) + window : set_window option; (* Window *) } let wrapr = function @@ -144,6 +149,16 @@ let rec nexus t fd server input_buffer = | None -> Lwt.return_unit) >>= fun () -> nexus t fd server input_buffer + | Some Awa.Server.Set_env (k, v) -> + ( match t.env with + | Some set_env -> set_env k v + | None -> Lwt.return_unit ) >>= fun () -> + nexus t fd server input_buffer + | Some Awa.Server.Pty (term, w, h, maxw, maxh, _modes) -> + ( match t.window with + | Some set_window -> set_window ~term ~w ~h ~maxw ~maxh + | None -> Lwt.return_unit ) >>= fun () -> + nexus t fd server input_buffer | Some Awa.Server.Channel_subsystem (id, cmd) (* same as exec *) | Some Awa.Server.Channel_exec (id, cmd) -> (* Create an input box *) @@ -153,15 +168,27 @@ let rec nexus t fd server input_buffer = let sshout id buf = Lwt_mvar.put t.nexus_mbox (Sshout (id, buf)) in let ssherr id buf = Lwt_mvar.put t.nexus_mbox (Ssherr (id, buf)) in (* Create the execution thread *) - let exec_thread = t.exec_callback cmd sshin (sshout id) (ssherr id) in - let c = { cmd; id; sshin_mbox; exec_thread } in + let exec_thread = t.exec_callback ~cmd sshin (sshout id) (ssherr id) in + let c = { cmd= Some cmd; id; sshin_mbox; exec_thread } in + let t = { t with channels = c :: t.channels } in + nexus t fd server input_buffer + | Some (Awa.Server.Start_shell id) -> + (* Create an input box *) + let sshin_mbox = Lwt_mvar.create_empty () in + (* Create a callback for each mbox *) + let sshin () = Lwt_mvar.take sshin_mbox in + let sshout id buf = Lwt_mvar.put t.nexus_mbox (Sshout (id, buf)) in + let ssherr id buf = Lwt_mvar.put t.nexus_mbox (Ssherr (id, buf)) in + (* Create the execution thread *) + let exec_thread = t.exec_callback sshin (sshout id) (ssherr id) in + let c = { cmd= None; id; sshin_mbox; exec_thread } in let t = { t with channels = c :: t.channels } in nexus t fd server input_buffer - | _ -> nexus t fd server input_buffer -let spawn_server server msgs fd exec_callback = +let spawn_server ?env ?window server msgs fd exec_callback = let t = { exec_callback; channels = []; + env; window; nexus_mbox = Lwt_mvar.create_empty () } in send_msgs fd server msgs >>= fun server -> diff --git a/test/awa_lwt_server.ml b/test/awa_lwt_server.ml index 71da2f9..6a99aac 100644 --- a/test/awa_lwt_server.ml +++ b/test/awa_lwt_server.ml @@ -27,21 +27,26 @@ let user_db = let awa = Awa.Auth.make_user "awa" [ key ] in [ foo; awa ] -let exec addr cmd sshin sshout _ssherror = +let exec addr ?cmd sshin sshout _ssherror = let rec echo () = sshin () >>= function | `Eof -> Lwt.return_unit | `Data input -> sshout input >>= fun () -> echo () in let ping () = sshout (Cstruct.of_string "pong\n") in - let badcmd () = + let badcmd cmd = sshout (Cstruct.of_string (Printf.sprintf "Bad command `%s`\n" cmd)) in - Lwt_io.printf "[%s] executing `%s`\n%!" addr cmd >>= fun () -> - (match cmd with "echo" -> echo () | "ping" -> ping () | _ -> badcmd ()) - >>= fun () -> - Lwt_io.printf "[%s] execution of `%s` finished\n%!" addr cmd - (* XXX Awa_lwt must close the channel when exec returns ! *) + match cmd with + | None -> + Lwt_io.printf "[%s] impossible to execute a shell\n%!" addr >>= fun () -> + sshout (Cstruct.of_string (Printf.sprintf "No shell available")) + | Some cmd -> + Lwt_io.printf "[%s] executing `%s`\n%!" addr cmd >>= fun () -> + (match cmd with "echo" -> echo () | "ping" -> ping () | _ -> badcmd cmd) + >>= fun () -> + Lwt_io.printf "[%s] execution of `%s` finished\n%!" addr cmd + (* XXX Awa_lwt must close the channel when exec returns ! *) let serve rsa fd addr = Lwt_io.printf "[%s] connected\n%!" addr >>= fun () ->