diff --git a/otherlibs/dune-rpc-lwt/src/dune_rpc_lwt.ml b/otherlibs/dune-rpc-lwt/src/dune_rpc_lwt.ml index dc0bdbb1514..e4d17424f32 100644 --- a/otherlibs/dune-rpc-lwt/src/dune_rpc_lwt.ml +++ b/otherlibs/dune-rpc-lwt/src/dune_rpc_lwt.ml @@ -85,10 +85,10 @@ module V1 = struct loop 0 Stack.Empty ;; - let write (_, o) = function - | None -> Lwt_io.close o - | Some csexps -> - Lwt_list.iter_s (fun sexp -> Lwt_io.write o (Csexp.to_string sexp)) csexps + let close (_, o) = Lwt_io.close o + + let write (_, o) csexps = + Lwt_list.iter_s (fun sexp -> Lwt_io.write o (Csexp.to_string sexp)) csexps ;; end) diff --git a/otherlibs/dune-rpc/private/dune_rpc_private.ml b/otherlibs/dune-rpc/private/dune_rpc_private.ml index b1b26dd50f0..82ca9005291 100644 --- a/otherlibs/dune-rpc/private/dune_rpc_private.ml +++ b/otherlibs/dune-rpc/private/dune_rpc_private.ml @@ -139,7 +139,8 @@ module Client = struct (Chan : sig type t - val write : t -> Sexp.t list option -> unit Fiber.t + val close : t -> unit Fiber.t + val write : t -> Sexp.t list -> unit Fiber.t val read : t -> Sexp.t option Fiber.t end) = struct @@ -149,7 +150,8 @@ module Client = struct module Chan = struct type t = { read : unit -> Sexp.t option Fiber.t - ; write : Sexp.t list option -> unit Fiber.t + ; write : Sexp.t list -> unit Fiber.t + ; close : unit -> unit Fiber.t ; closed_read : bool ; mutable closed_write : bool ; disconnected : unit Fiber.Ivar.t @@ -167,22 +169,25 @@ module Client = struct in { read ; write = (fun s -> Chan.write c s) + ; close = (fun () -> Chan.close c) ; closed_read = false ; closed_write = false ; disconnected } ;; + let close t = + let* () = Fiber.return () in + if t.closed_write + then Fiber.return () + else ( + t.closed_write <- true; + t.close ()) + ;; + let write t s = let* () = Fiber.return () in - match s with - | Some _ -> t.write s - | None -> - if t.closed_write - then Fiber.return () - else ( - t.closed_write <- true; - t.write None) + t.write s ;; let read t = @@ -253,7 +258,7 @@ module Client = struct Some x) in Fiber.fork_and_join_unit - (fun () -> Chan.write t.chan None) + (fun () -> Chan.close t.chan) (fun () -> Fiber.parallel_iter ivars ~f:(fun status -> match status with @@ -271,9 +276,8 @@ module Client = struct Code_error.raise message info) ;; - let send conn (packet : Packet.t list option) = - let sexps = Option.map packet ~f:(List.map ~f:(Conv.to_sexp Packet.sexp)) in - Chan.write conn.chan sexps + let send conn (packet : Packet.t list) = + List.map ~f:(Conv.to_sexp Packet.sexp) packet |> Chan.write conn.chan ;; let create ~chan ~initialize ~handler ~on_preemptive_abort = @@ -317,7 +321,7 @@ module Client = struct match prepare_request' conn (id, req) with | Error e -> Fiber.return (`Completed (Error e)) | Ok ivar -> - let* () = send conn (Some [ Request (id, req) ]) in + let* () = send conn [ Request (id, req) ] in Fiber.Ivar.read ivar ;; @@ -400,7 +404,7 @@ module Client = struct let notification (type a) t (stg : a Versioned.notification) (n : a) = let* () = Fiber.return () in - make_notification t stg n (fun call -> send t (Some [ Notification call ])) + make_notification t stg n (fun call -> send t [ Notification call ]) ;; let disconnected t = Fiber.Ivar.read t.chan.disconnected @@ -539,7 +543,7 @@ module Client = struct let* () = Fiber.return () in let pending = List.rev t.pending in t.pending <- []; - send t.client (Some pending) + send t.client pending ;; end @@ -573,7 +577,7 @@ module Client = struct | Request (id, req) -> let* handler = t.handler in let* result = V.Handler.handle_request handler () (id, req) in - send t (Some [ Response (id, result) ]) + send t [ Response (id, result) ] | Response (id, response) -> (match Table.find t.requests id with | Some status -> @@ -737,7 +741,7 @@ module Client = struct in client.handler_initialized <- true; let* () = Fiber.Ivar.fill handler_var handler in - Fiber.finalize (fun () -> f client) ~finally:(fun () -> Chan.write chan None) + Fiber.finalize (fun () -> f client) ~finally:(fun () -> Chan.close chan) in Fiber.fork_and_join_unit (fun () -> read_packets client packets) run ;; diff --git a/otherlibs/dune-rpc/private/dune_rpc_private.mli b/otherlibs/dune-rpc/private/dune_rpc_private.mli index f6eb99ce562..bab7da58d09 100644 --- a/otherlibs/dune-rpc/private/dune_rpc_private.mli +++ b/otherlibs/dune-rpc/private/dune_rpc_private.mli @@ -486,7 +486,8 @@ module Client : sig (Chan : sig type t - val write : t -> Csexp.t list option -> unit Fiber.t + val close : t -> unit Fiber.t + val write : t -> Csexp.t list -> unit Fiber.t val read : t -> Csexp.t option Fiber.t end) : S with type 'a fiber := 'a Fiber.t and type chan := Chan.t end diff --git a/otherlibs/dune-rpc/v1.mli b/otherlibs/dune-rpc/v1.mli index 5cf342176a1..e5aef2db2f1 100644 --- a/otherlibs/dune-rpc/v1.mli +++ b/otherlibs/dune-rpc/v1.mli @@ -477,9 +477,11 @@ module Client : sig (Chan : sig type t - (* [write t x] writes the s-expression when [x] is [Some _], and closes - the session if [x = None] *) - val write : t -> Csexp.t list option -> unit Fiber.t + (* [write t x] writes the s-expression*) + val write : t -> Csexp.t list -> unit Fiber.t + + (* closes the session *) + val close : t -> unit Fiber.t (* [read t] attempts to read from [t]. If an s-expression is read, it is returned as [Some sexp], otherwise [None] is returned and the session diff --git a/src/dune_rpc_client/client.ml b/src/dune_rpc_client/client.ml index 0530b81fc39..e64f5a1cf7e 100644 --- a/src/dune_rpc_client/client.ml +++ b/src/dune_rpc_client/client.ml @@ -7,13 +7,11 @@ include (struct include Csexp_rpc.Session - let write t = function - | None -> close t - | Some packets -> - write t packets - >>| (function - | Ok () -> () - | Error `Closed -> raise Dune_util.Report_error.Already_reported) + let write t packets = + write t packets + >>| function + | Ok () -> () + | Error `Closed -> raise Dune_util.Report_error.Already_reported ;; end) diff --git a/test/expect-tests/dune_rpc/dune_rpc_tests.ml b/test/expect-tests/dune_rpc/dune_rpc_tests.ml index 3fe3f074858..ef81c58b47c 100644 --- a/test/expect-tests/dune_rpc/dune_rpc_tests.ml +++ b/test/expect-tests/dune_rpc/dune_rpc_tests.ml @@ -48,10 +48,7 @@ module Drpc = struct (struct include Chan - let write t = function - | None -> close t - | Some packets -> write t packets >>| Result.ok_exn - ;; + let write t packets = write t packets >>| Result.ok_exn end) module Server = Dune_rpc_server.Make (Chan)