diff --git a/lwt/tls_lwt.ml b/lwt/tls_lwt.ml index 5de56ed9..645185da 100644 --- a/lwt/tls_lwt.ml +++ b/lwt/tls_lwt.ml @@ -22,18 +22,19 @@ let gettimeofday = Unix.gettimeofday module Lwt_cs = struct - let naked ~name f fd cs = + let naked ~name f fd state cs = Cstruct.(f fd cs.buffer cs.off cs.len) >>= fun res -> - match Lwt_unix.getsockopt_error fd with - | None -> return res - | Some err -> fail @@ Unix.Unix_error (err, name, "") + match (Lwt_unix.getsockopt_error fd, state) with + | (None, _) -> return res + | (Some Unix.EPIPE, `Eof) -> return res + | (Some err, _) -> fail @@ Unix.Unix_error (err, name, "") let write = naked ~name:"Tls_lwt.write" Lwt_bytes.write and read = naked ~name:"Tls_lwt.read" Lwt_bytes.read - let rec write_full fd = function + let rec write_full fd state = function | cs when Cstruct.len cs = 0 -> return_unit - | cs -> write fd cs >>= o (write_full fd) (Cstruct.shift cs) + | cs -> write fd state cs >>= o (write_full fd state) (Cstruct.shift cs) end module Unix = struct @@ -53,7 +54,7 @@ module Unix = struct let (read_t, write_t) = let recording_errors op t cs = Lwt.catch - (fun () -> op t.fd cs) + (fun () -> op t.fd t.state cs) (fun exn -> t.state <- `Error exn ; fail exn)