Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 51 additions & 11 deletions src/lwt/conduit_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ let failwith fmt = Format.kasprintf (fun err -> Lwt.fail (Failure err)) fmt

let io_of_flow flow =
let open Lwt.Infix in
let mutex = Lwt_mutex.create () in
let ic_closed = ref false and oc_closed = ref false in
let close () =
if !ic_closed && !oc_closed
then
close flow >>= function
Lwt_mutex.with_lock mutex (fun () -> close flow) >>= function
| Ok () -> Lwt.return_unit
| Error err -> failwith "%a" pp_error err
else Lwt.return_unit in
Expand All @@ -29,19 +30,21 @@ let io_of_flow flow =
let oc_close () =
oc_closed := true ;
close () in
let recv buf off len =
let rec rrecv buf off len =
let raw = Cstruct.of_bigarray buf ~off ~len in
recv flow raw >>= function
Lwt_mutex.with_lock mutex (fun () -> recv flow raw) >>= function
| Ok (`Input 0) -> Lwt_unix.yield () >>= fun () -> rrecv buf off len
| Ok (`Input len) -> Lwt.return len
| Ok `End_of_flow -> Lwt.return 0
| Error err -> failwith "%a" pp_error err in
let ic = Lwt_io.make ~close:ic_close ~mode:Lwt_io.input recv in
let send buf off len =
let ic = Lwt_io.make ~close:ic_close ~mode:Lwt_io.input rrecv in
let rec ssend buf off len =
let raw = Cstruct.of_bigarray buf ~off ~len in
send flow raw >>= function
Lwt_mutex.with_lock mutex (fun () -> send flow raw) >>= function
| Ok 0 -> Lwt_unix.yield () >>= fun () -> ssend buf off len
| Ok len -> Lwt.return len
| Error err -> failwith "%a" pp_error err in
let oc = Lwt_io.make ~close:oc_close ~mode:Lwt_io.output send in
let oc = Lwt_io.make ~close:oc_close ~mode:Lwt_io.output ssend in
(ic, oc)

let ( >>? ) = Lwt_result.bind
Expand Down Expand Up @@ -111,9 +114,29 @@ module TCP = struct
socket : Lwt_unix.file_descr;
sockaddr : Lwt_unix.sockaddr;
linger : Bytes.t;
recv_first : bool;
mutable closed : bool;
}

(* XXX(dinosaure): [recv_first] is here to fit into [Lwt_io], from what we know,
* a tuple of [Lwt_io] [in_channel/out_channel] tries to receive first. However,
* such behavior is problematic for HTTP:
* - as a HTTP client, we should send first
* - as a HTTP server, we should recv first
* - with TLS layer [conduit-tls], both work - where
* the handshake can be done by send or recv
*
* For my perspective, [Lwt_io] is not the right way to abstract a [Conduit.flow]
* and we should directly use [Conduit.send]/[Conduit.recv] when we need to use
* them. Because [Lwt_io] tries to receive in any case, we must check (with [Lwt_unix.readable])
* if the socket can be read. In that case and if we want to [recv_first], we start
* to waiting something from our peer. In the other case, we returns [`Input 0]
* which gives an opportunity for the scheduler to send something (so, [send_first]).
*
* Such patch is really close to what LWT/[Lwt_io] does. A problem should be a diff
* on behaviors between [Conduit_lwt] and [mirage-tcpip] + [Conduit_mirage]. The best
* way to delete it is to deprecate [io_of_flow]. *)

let peer { sockaddr; _ } = sockaddr

let sock { socket; _ } = Lwt_unix.getsockname socket
Expand Down Expand Up @@ -161,7 +184,14 @@ module TCP = struct
let rec go () =
let process () =
Lwt_unix.connect socket sockaddr >>= fun () ->
Lwt.return_ok { socket; sockaddr; linger; closed = false } in
Lwt.return_ok
{
socket;
sockaddr;
linger;
closed = false;
recv_first = Lwt_unix.readable socket;
} in
Lwt.catch process @@ function
| Unix.(Unix_error ((EACCES | EPERM), _, _)) ->
Lwt.return_error `Operation_not_permitted
Expand Down Expand Up @@ -220,7 +250,11 @@ module TCP = struct
(if filled + len = 0
then `End_of_flow
else `Input (filled + len))) in
Lwt.catch (fun () -> process 0 raw) @@ function
Lwt.catch (fun () ->
if (not (Lwt_unix.readable t.socket)) && not t.recv_first
then Lwt.return_ok (`Input 0)
else process 0 raw)
@@ function
| Unix.(Unix_error ((EAGAIN | EWOULDBLOCK), _, _)) -> recv t raw
| Unix.(Unix_error (EINTR, _, _)) -> recv t raw
| Unix.(Unix_error (EFAULT, _, _)) -> Lwt.return_error `Bad_address
Expand Down Expand Up @@ -379,8 +413,14 @@ module TCP = struct
let process () =
Lwt_unix.accept service >>= fun (socket, sockaddr) ->
let linger = Bytes.create 0x1000 in
Lwt.return_ok { Protocol.socket; sockaddr; linger; closed = false }
in
Lwt.return_ok
{
Protocol.socket;
sockaddr;
linger;
closed = false;
recv_first = Lwt_unix.readable socket;
} in
Lwt.catch process @@ function
| Unix.(Unix_error ((EAGAIN | EWOULDBLOCK), _, _)) -> accept service
| Unix.(Unix_error (EINTR, _, _)) -> accept service
Expand Down
40 changes: 30 additions & 10 deletions src/tls/conduit_tls.ml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ struct
(* XXX(dinosaure): it seems that decoding TLS inputs can produce
something bigger than expected. For example, decoding 4096 bytes
can produce 4119 byte(s). *)
Log.debug (fun m -> m "|- TLS state: Ok") ;
Log.debug (fun m -> m "|- TLS state: Ok.") ;
queue_wr_opt queue data ;
flow_wr_opt flow resp >>? fun () -> return (Ok (Some tls))

Expand Down Expand Up @@ -151,14 +151,16 @@ struct
flow_wr_opt flow resp >>? fun () ->
if Tls.Engine.handshake_in_progress tls
then (
Log.debug (fun m -> m "<- Read the TLS flow") ;
Log.debug (fun m ->
m "<- Read the TLS flow (while handshake).") ;
Flow.recv flow raw0 >>| reword_error flow_error >>? function
| `End_of_flow ->
Log.warn (fun m ->
m
"Got EOF from underlying connection while \
handshake.") ;
return (Ok None)
| `Input 0 -> return (Ok (Some tls))
| `Input len ->
let uid =
Hashtbl.hash
Expand Down Expand Up @@ -222,21 +224,36 @@ struct
m "<- Connection closed by underlying protocol.") ;
t.tls <- None ;
return (Ok `End_of_flow)
| `Input len ->
let handle =
| `Input 0 ->
t.tls <- Some tls ;
return (Ok (`Input 0))
| `Input len -> (
Log.debug (fun m -> m "<- Got %d byte(s)." len) ;
let handle raw =
if Tls.Engine.handshake_in_progress tls
then handle_handshake tls t.queue t.flow
else handle_tls tls t.queue t.flow in
let uid =
Hashtbl.hash (Cstruct.to_string (Cstruct.sub t.raw 0 len))
in
then handle_handshake tls t.queue t.flow raw
else handle_tls tls t.queue t.flow raw in
let before = Tls.Engine.handshake_in_progress tls in
Log.debug (fun m ->
let uid =
Hashtbl.hash
(Cstruct.to_string (Cstruct.sub t.raw 0 len)) in
m "<~ [%04x] Got %d bytes (handshake in progress: %b)."
uid len
(Tls.Engine.handshake_in_progress tls)) ;
handle (Cstruct.sub t.raw 0 len) >>? fun tls ->
let after =
Option.fold ~none:false
~some:Tls.Engine.handshake_in_progress tls in
t.tls <- tls ;
recv t raw))
match (tls, before, after) with
| Some _, false, false | Some _, true, false ->
return (Ok (`Input 0))
| Some _, false, true (* renegociate *)
| Some _, true, true (* continue handshake *)
| None, _, _ ->
Log.debug (fun m -> m "Retry to receive something.") ;
recv t raw)))
| _ ->
let max = Cstruct.len raw in
let len = min (Ke.length t.queue) max in
Expand All @@ -262,6 +279,9 @@ struct
Log.warn (fun m -> m "[-] Underlying flow already closed.") ;
t.tls <- None ;
return (Error `Closed_by_peer)
| `Input 0 ->
t.tls <- Some tls ;
return (Ok 0)
| `Input len -> (
let res =
handle_handshake tls t.queue t.flow (Cstruct.sub t.raw 0 len)
Expand Down
26 changes: 21 additions & 5 deletions tests/ping-pong/common.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@ module type CONDITION = sig
type 'a t
end

module type IO = sig
include Conduit.IO

val yield : unit -> unit t
end

let ( <.> ) f g x = f (g x)

module Make
(IO : Conduit.IO)
(IO : IO)
(Condition : CONDITION)
(Conduit : S
with type +'a io = 'a IO.t
Expand Down Expand Up @@ -67,6 +73,7 @@ struct
| None -> (
Conduit.recv flow tmp >>? function
| `End_of_flow -> IO.return (Ok `Close)
| `Input 0 -> IO.yield () >>= go
| `Input len ->
Ke.Rke.N.push queue ~blit ~length:Cstruct.len ~off:0 ~len tmp ;
go ()) in
Expand All @@ -76,20 +83,29 @@ struct

let ping = Cstruct.of_string "ping\n"

let send flow raw =
let rec go flow raw =
Conduit.send flow raw >>? function
| 0 -> IO.yield () >>= fun () -> go flow raw
| len ->
let raw = Cstruct.shift raw len in
if Cstruct.len raw = 0 then return (Ok ()) else go flow raw in
go flow raw

let transmission flow =
let queue = Ke.Rke.create ~capacity:0x1000 Bigarray.char in
let rec go () =
getline queue flow >>= function
| Ok `Close | Error _ -> Conduit.close flow
| Ok (`Line "ping") ->
Fmt.epr "[!] received ping.\n%!" ;
Conduit.send flow pong >>? fun _ -> go ()
send flow pong >>? go
| Ok (`Line "pong") ->
Fmt.epr "[!] received pong.\n%!" ;
Conduit.send flow ping >>? fun _ -> go ()
send flow ping >>? go
| Ok (`Line line) ->
Fmt.epr "[!] received %S.\n%!" line ;
Conduit.send flow (Cstruct.of_string (line ^ "\n")) >>? fun _ ->
send flow (Cstruct.of_string (line ^ "\n")) >>? fun () ->
Conduit.close flow in
go () >>= function
| Error err -> Fmt.failwith "%a" Conduit.pp_error err
Expand All @@ -114,7 +130,7 @@ struct
let rec go = function
| [] -> Conduit.close flow
| line :: rest -> (
Conduit.send flow (Cstruct.of_string (line ^ "\n")) >>? fun _ ->
send flow (Cstruct.of_string (line ^ "\n")) >>? fun () ->
getline queue flow >>? function
| `Close -> Conduit.close flow
| `Line "pong" -> go rest
Expand Down
2 changes: 2 additions & 0 deletions tests/ping-pong/with_async.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ include Common.Make
let bind x f = Async.Deferred.bind x ~f

let return = Async.Deferred.return

let yield () = Async.Deferred.return ()
end)
(Async.Condition)
(struct
Expand Down
26 changes: 26 additions & 0 deletions tests/ping-pong/with_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,34 @@ let () = Printexc.record_backtrace true

let () = Ssl.init ()

let reporter ppf =
let report src level ~over k msgf =
let k _ =
over () ;
k () in
let with_metadata header _tags k ppf fmt =
Format.kfprintf k ppf
("%a[%a]: " ^^ fmt ^^ "\n%!")
Logs_fmt.pp_header (level, header)
Fmt.(styled `Magenta string)
(Logs.Src.name src) in
msgf @@ fun ?header ?tags fmt -> with_metadata header tags k ppf fmt in
{ Logs.report }

let () = Fmt_tty.setup_std_outputs ~style_renderer:`Ansi_tty ~utf_8:true ()

let () = Logs.set_reporter (reporter Fmt.stderr)

let () = Logs.set_level ~all:true (Some Logs.Debug)

let failwith fmt = Fmt.kstrf (fun err -> Lwt.fail (Failure err)) fmt

module Lwt = struct
include Lwt

let yield = Lwt_unix.yield
end

include Common.Make (Lwt) (Lwt_condition)
(struct
type 'a condition = 'a Lwt_condition.t
Expand Down