diff --git a/src/lwt/conduit_lwt.ml b/src/lwt/conduit_lwt.ml index d80262ef..a3168db7 100644 --- a/src/lwt/conduit_lwt.ml +++ b/src/lwt/conduit_lwt.ml @@ -15,11 +15,12 @@ let failwith fmt = Format.kasprintf (fun err -> Lwt.fail (Failure err)) fmt let io_of_flow flow = let open Lwt.Infix in + let mutex = Lwt_mutex.create () in let ic_closed = ref false and oc_closed = ref false in let close () = if !ic_closed && !oc_closed then - close flow >>= function + Lwt_mutex.with_lock mutex (fun () -> close flow) >>= function | Ok () -> Lwt.return_unit | Error err -> failwith "%a" pp_error err else Lwt.return_unit in @@ -29,19 +30,21 @@ let io_of_flow flow = let oc_close () = oc_closed := true ; close () in - let recv buf off len = + let rec rrecv buf off len = let raw = Cstruct.of_bigarray buf ~off ~len in - recv flow raw >>= function + Lwt_mutex.with_lock mutex (fun () -> recv flow raw) >>= function + | Ok (`Input 0) -> Lwt_unix.yield () >>= fun () -> rrecv buf off len | Ok (`Input len) -> Lwt.return len | Ok `End_of_flow -> Lwt.return 0 | Error err -> failwith "%a" pp_error err in - let ic = Lwt_io.make ~close:ic_close ~mode:Lwt_io.input recv in - let send buf off len = + let ic = Lwt_io.make ~close:ic_close ~mode:Lwt_io.input rrecv in + let rec ssend buf off len = let raw = Cstruct.of_bigarray buf ~off ~len in - send flow raw >>= function + Lwt_mutex.with_lock mutex (fun () -> send flow raw) >>= function + | Ok 0 -> Lwt_unix.yield () >>= fun () -> ssend buf off len | Ok len -> Lwt.return len | Error err -> failwith "%a" pp_error err in - let oc = Lwt_io.make ~close:oc_close ~mode:Lwt_io.output send in + let oc = Lwt_io.make ~close:oc_close ~mode:Lwt_io.output ssend in (ic, oc) let ( >>? ) = Lwt_result.bind @@ -111,9 +114,29 @@ module TCP = struct socket : Lwt_unix.file_descr; sockaddr : Lwt_unix.sockaddr; linger : Bytes.t; + recv_first : bool; mutable closed : bool; } + (* XXX(dinosaure): [recv_first] is here to fit into [Lwt_io], from what we know, + * a tuple of [Lwt_io] [in_channel/out_channel] tries to receive first. However, + * such behavior is problematic for HTTP: + * - as a HTTP client, we should send first + * - as a HTTP server, we should recv first + * - with TLS layer [conduit-tls], both work - where + * the handshake can be done by send or recv + * + * For my perspective, [Lwt_io] is not the right way to abstract a [Conduit.flow] + * and we should directly use [Conduit.send]/[Conduit.recv] when we need to use + * them. Because [Lwt_io] tries to receive in any case, we must check (with [Lwt_unix.readable]) + * if the socket can be read. In that case and if we want to [recv_first], we start + * to waiting something from our peer. In the other case, we returns [`Input 0] + * which gives an opportunity for the scheduler to send something (so, [send_first]). + * + * Such patch is really close to what LWT/[Lwt_io] does. A problem should be a diff + * on behaviors between [Conduit_lwt] and [mirage-tcpip] + [Conduit_mirage]. The best + * way to delete it is to deprecate [io_of_flow]. *) + let peer { sockaddr; _ } = sockaddr let sock { socket; _ } = Lwt_unix.getsockname socket @@ -161,7 +184,14 @@ module TCP = struct let rec go () = let process () = Lwt_unix.connect socket sockaddr >>= fun () -> - Lwt.return_ok { socket; sockaddr; linger; closed = false } in + Lwt.return_ok + { + socket; + sockaddr; + linger; + closed = false; + recv_first = Lwt_unix.readable socket; + } in Lwt.catch process @@ function | Unix.(Unix_error ((EACCES | EPERM), _, _)) -> Lwt.return_error `Operation_not_permitted @@ -220,7 +250,11 @@ module TCP = struct (if filled + len = 0 then `End_of_flow else `Input (filled + len))) in - Lwt.catch (fun () -> process 0 raw) @@ function + Lwt.catch (fun () -> + if (not (Lwt_unix.readable t.socket)) && not t.recv_first + then Lwt.return_ok (`Input 0) + else process 0 raw) + @@ function | Unix.(Unix_error ((EAGAIN | EWOULDBLOCK), _, _)) -> recv t raw | Unix.(Unix_error (EINTR, _, _)) -> recv t raw | Unix.(Unix_error (EFAULT, _, _)) -> Lwt.return_error `Bad_address @@ -379,8 +413,14 @@ module TCP = struct let process () = Lwt_unix.accept service >>= fun (socket, sockaddr) -> let linger = Bytes.create 0x1000 in - Lwt.return_ok { Protocol.socket; sockaddr; linger; closed = false } - in + Lwt.return_ok + { + Protocol.socket; + sockaddr; + linger; + closed = false; + recv_first = Lwt_unix.readable socket; + } in Lwt.catch process @@ function | Unix.(Unix_error ((EAGAIN | EWOULDBLOCK), _, _)) -> accept service | Unix.(Unix_error (EINTR, _, _)) -> accept service diff --git a/src/tls/conduit_tls.ml b/src/tls/conduit_tls.ml index 59c4c14a..43529f59 100644 --- a/src/tls/conduit_tls.ml +++ b/src/tls/conduit_tls.ml @@ -119,7 +119,7 @@ struct (* XXX(dinosaure): it seems that decoding TLS inputs can produce something bigger than expected. For example, decoding 4096 bytes can produce 4119 byte(s). *) - Log.debug (fun m -> m "|- TLS state: Ok") ; + Log.debug (fun m -> m "|- TLS state: Ok.") ; queue_wr_opt queue data ; flow_wr_opt flow resp >>? fun () -> return (Ok (Some tls)) @@ -151,7 +151,8 @@ struct flow_wr_opt flow resp >>? fun () -> if Tls.Engine.handshake_in_progress tls then ( - Log.debug (fun m -> m "<- Read the TLS flow") ; + Log.debug (fun m -> + m "<- Read the TLS flow (while handshake).") ; Flow.recv flow raw0 >>| reword_error flow_error >>? function | `End_of_flow -> Log.warn (fun m -> @@ -159,6 +160,7 @@ struct "Got EOF from underlying connection while \ handshake.") ; return (Ok None) + | `Input 0 -> return (Ok (Some tls)) | `Input len -> let uid = Hashtbl.hash @@ -222,21 +224,36 @@ struct m "<- Connection closed by underlying protocol.") ; t.tls <- None ; return (Ok `End_of_flow) - | `Input len -> - let handle = + | `Input 0 -> + t.tls <- Some tls ; + return (Ok (`Input 0)) + | `Input len -> ( + Log.debug (fun m -> m "<- Got %d byte(s)." len) ; + let handle raw = if Tls.Engine.handshake_in_progress tls - then handle_handshake tls t.queue t.flow - else handle_tls tls t.queue t.flow in - let uid = - Hashtbl.hash (Cstruct.to_string (Cstruct.sub t.raw 0 len)) - in + then handle_handshake tls t.queue t.flow raw + else handle_tls tls t.queue t.flow raw in + let before = Tls.Engine.handshake_in_progress tls in Log.debug (fun m -> + let uid = + Hashtbl.hash + (Cstruct.to_string (Cstruct.sub t.raw 0 len)) in m "<~ [%04x] Got %d bytes (handshake in progress: %b)." uid len (Tls.Engine.handshake_in_progress tls)) ; handle (Cstruct.sub t.raw 0 len) >>? fun tls -> + let after = + Option.fold ~none:false + ~some:Tls.Engine.handshake_in_progress tls in t.tls <- tls ; - recv t raw)) + match (tls, before, after) with + | Some _, false, false | Some _, true, false -> + return (Ok (`Input 0)) + | Some _, false, true (* renegociate *) + | Some _, true, true (* continue handshake *) + | None, _, _ -> + Log.debug (fun m -> m "Retry to receive something.") ; + recv t raw))) | _ -> let max = Cstruct.len raw in let len = min (Ke.length t.queue) max in @@ -262,6 +279,9 @@ struct Log.warn (fun m -> m "[-] Underlying flow already closed.") ; t.tls <- None ; return (Error `Closed_by_peer) + | `Input 0 -> + t.tls <- Some tls ; + return (Ok 0) | `Input len -> ( let res = handle_handshake tls t.queue t.flow (Cstruct.sub t.raw 0 len) diff --git a/tests/ping-pong/common.ml b/tests/ping-pong/common.ml index 7f6136e8..d3c8b058 100644 --- a/tests/ping-pong/common.ml +++ b/tests/ping-pong/common.ml @@ -15,10 +15,16 @@ module type CONDITION = sig type 'a t end +module type IO = sig + include Conduit.IO + + val yield : unit -> unit t +end + let ( <.> ) f g x = f (g x) module Make - (IO : Conduit.IO) + (IO : IO) (Condition : CONDITION) (Conduit : S with type +'a io = 'a IO.t @@ -67,6 +73,7 @@ struct | None -> ( Conduit.recv flow tmp >>? function | `End_of_flow -> IO.return (Ok `Close) + | `Input 0 -> IO.yield () >>= go | `Input len -> Ke.Rke.N.push queue ~blit ~length:Cstruct.len ~off:0 ~len tmp ; go ()) in @@ -76,6 +83,15 @@ struct let ping = Cstruct.of_string "ping\n" + let send flow raw = + let rec go flow raw = + Conduit.send flow raw >>? function + | 0 -> IO.yield () >>= fun () -> go flow raw + | len -> + let raw = Cstruct.shift raw len in + if Cstruct.len raw = 0 then return (Ok ()) else go flow raw in + go flow raw + let transmission flow = let queue = Ke.Rke.create ~capacity:0x1000 Bigarray.char in let rec go () = @@ -83,13 +99,13 @@ struct | Ok `Close | Error _ -> Conduit.close flow | Ok (`Line "ping") -> Fmt.epr "[!] received ping.\n%!" ; - Conduit.send flow pong >>? fun _ -> go () + send flow pong >>? go | Ok (`Line "pong") -> Fmt.epr "[!] received pong.\n%!" ; - Conduit.send flow ping >>? fun _ -> go () + send flow ping >>? go | Ok (`Line line) -> Fmt.epr "[!] received %S.\n%!" line ; - Conduit.send flow (Cstruct.of_string (line ^ "\n")) >>? fun _ -> + send flow (Cstruct.of_string (line ^ "\n")) >>? fun () -> Conduit.close flow in go () >>= function | Error err -> Fmt.failwith "%a" Conduit.pp_error err @@ -114,7 +130,7 @@ struct let rec go = function | [] -> Conduit.close flow | line :: rest -> ( - Conduit.send flow (Cstruct.of_string (line ^ "\n")) >>? fun _ -> + send flow (Cstruct.of_string (line ^ "\n")) >>? fun () -> getline queue flow >>? function | `Close -> Conduit.close flow | `Line "pong" -> go rest diff --git a/tests/ping-pong/with_async.ml b/tests/ping-pong/with_async.ml index 323f24e1..a7d400aa 100644 --- a/tests/ping-pong/with_async.ml +++ b/tests/ping-pong/with_async.ml @@ -11,6 +11,8 @@ include Common.Make let bind x f = Async.Deferred.bind x ~f let return = Async.Deferred.return + + let yield () = Async.Deferred.return () end) (Async.Condition) (struct diff --git a/tests/ping-pong/with_lwt.ml b/tests/ping-pong/with_lwt.ml index 1425e723..307ea62e 100644 --- a/tests/ping-pong/with_lwt.ml +++ b/tests/ping-pong/with_lwt.ml @@ -6,8 +6,34 @@ let () = Printexc.record_backtrace true let () = Ssl.init () +let reporter ppf = + let report src level ~over k msgf = + let k _ = + over () ; + k () in + let with_metadata header _tags k ppf fmt = + Format.kfprintf k ppf + ("%a[%a]: " ^^ fmt ^^ "\n%!") + Logs_fmt.pp_header (level, header) + Fmt.(styled `Magenta string) + (Logs.Src.name src) in + msgf @@ fun ?header ?tags fmt -> with_metadata header tags k ppf fmt in + { Logs.report } + +let () = Fmt_tty.setup_std_outputs ~style_renderer:`Ansi_tty ~utf_8:true () + +let () = Logs.set_reporter (reporter Fmt.stderr) + +let () = Logs.set_level ~all:true (Some Logs.Debug) + let failwith fmt = Fmt.kstrf (fun err -> Lwt.fail (Failure err)) fmt +module Lwt = struct + include Lwt + + let yield = Lwt_unix.yield +end + include Common.Make (Lwt) (Lwt_condition) (struct type 'a condition = 'a Lwt_condition.t