From 9e8d1629e854efc653c204e184ea0201c3696160 Mon Sep 17 00:00:00 2001 From: dinosaure Date: Mon, 22 Mar 2021 17:20:09 +0100 Subject: [PATCH 1/4] Add cancelation on tcpip.stack-socket --- src/stack-unix/tcpip_stack_socket.ml | 45 +++++++++++++++------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/src/stack-unix/tcpip_stack_socket.ml b/src/stack-unix/tcpip_stack_socket.ml index e0c06ff01..e287f3831 100644 --- a/src/stack-unix/tcpip_stack_socket.ml +++ b/src/stack-unix/tcpip_stack_socket.ml @@ -27,6 +27,8 @@ module V4 = struct type t = { udpv4 : UDPV4.t; tcpv4 : TCPV4.t; + stop : [ `Stopped ] Lwt.u; + switched_off : [ `Stopped ] Lwt.t; } let udpv4 { udpv4; _ } = udpv4 @@ -59,7 +61,7 @@ module V4 = struct Lwt.return_unit) >>= fun () -> loop () in - loop ()) + Lwt.pick [ t.switched_off; loop () ] >>= fun `Stopped -> Lwt_unix.close fd) let listen_tcpv4 ?keepalive t ~port callback = if port < 0 || port > 65535 then @@ -91,17 +93,16 @@ module V4 = struct Lwt.return_unit) >>= fun () -> loop () in - loop ()) + Lwt.pick [ t.switched_off; loop () ] >>= fun `Stopped -> Lwt_unix.close fd) - let listen _t = - let t, _ = Lwt.task () in - t (* TODO cancellation *) + let listen t = t.switched_off >>= fun `Stopped -> Lwt.return_unit let connect udpv4 tcpv4 = Log.info (fun f -> f "IPv4 socket stack: connect"); - Lwt.return { tcpv4; udpv4 } + let switched_off, stop = Lwt.wait () in + Lwt.return { tcpv4; udpv4; stop; switched_off; } - let disconnect _ = Lwt.return_unit + let disconnect t = Lwt.wakeup_later t.stop `Stopped ; Lwt.return_unit end module V6 = struct @@ -112,6 +113,8 @@ module V6 = struct type t = { udp : UDP.t; tcp : TCP.t; + stop : [ `Stopped ] Lwt.u; + switched_off : [ `Stopped ] Lwt.t; } let udp { udp; _ } = udp @@ -144,7 +147,7 @@ module V6 = struct Lwt.return_unit) >>= fun () -> loop () in - loop ()) + Lwt.pick [ t.switched_off; loop () ] >>= fun `Stopped -> Lwt_unix.close fd) let listen_tcp ?keepalive t ~port callback = if port < 0 || port > 65535 then @@ -177,17 +180,16 @@ module V6 = struct Lwt.return_unit) >>= fun () -> loop () in - loop ()) + Lwt.pick [ t.switched_off; loop () ] >>= fun `Stopped -> Lwt_unix.close fd) - let listen _t = - let t, _ = Lwt.task () in - t (* TODO cancellation *) + let listen t = t.switched_off >>= fun `Stopped -> Lwt.return_unit let connect udp tcp = Log.info (fun f -> f "IPv6 socket stack: connect"); - Lwt.return { tcp; udp } + let switched_off, stop = Lwt.wait () in + Lwt.return { tcp; udp; stop; switched_off; } - let disconnect _ = Lwt.return_unit + let disconnect t = Lwt.wakeup_later t.stop `Stopped ; Lwt.return_unit end module V4V6 = struct @@ -198,6 +200,8 @@ module V4V6 = struct type t = { udp : UDP.t; tcp : TCP.t; + stop : [ `Stopped ] Lwt.u; + switched_off : [ `Stopped ] Lwt.t; } let udp { udp; _ } = udp @@ -232,7 +236,7 @@ module V4V6 = struct Lwt.return_unit) >>= fun () -> loop () in - loop ())) fds) + Lwt.pick [ t.switched_off; loop () ] >>= fun `Stopped -> Lwt_unix.close fd)) fds) let listen_tcp ?keepalive t ~port callback = if port < 0 || port > 65535 then @@ -287,15 +291,14 @@ module V4V6 = struct Lwt.return_unit) >>= fun () -> loop () in - loop ())) fds + Lwt.pick [ t.switched_off; loop () ] >>= fun `Stopped -> Lwt_unix.close fd)) fds - let listen _t = - let t, _ = Lwt.task () in - t (* TODO cancellation *) + let listen t = t.switched_off >>= fun `Stopped -> Lwt.return_unit let connect udp tcp = Log.info (fun f -> f "Dual IPv4 and IPv6 socket stack: connect"); - Lwt.return { tcp; udp } + let switched_off, stop = Lwt.wait () in + Lwt.return { tcp; udp; stop; switched_off; } - let disconnect _ = Lwt.return_unit + let disconnect t = Lwt.wakeup_later t.stop `Stopped ; Lwt.return_unit end From 3d0b14fb26e5a6d5cf5a2e4e645c5e949851c4f6 Mon Sep 17 00:00:00 2001 From: dinosaure Date: Wed, 24 Mar 2021 05:27:19 +0100 Subject: [PATCH 2/4] Apply @talex5's comments on a better cancelation --- src/stack-unix/tcpip_stack_socket.ml | 46 +++++++++++++++------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/src/stack-unix/tcpip_stack_socket.ml b/src/stack-unix/tcpip_stack_socket.ml index e287f3831..56e9c1f60 100644 --- a/src/stack-unix/tcpip_stack_socket.ml +++ b/src/stack-unix/tcpip_stack_socket.ml @@ -27,8 +27,8 @@ module V4 = struct type t = { udpv4 : UDPV4.t; tcpv4 : TCPV4.t; - stop : [ `Stopped ] Lwt.u; - switched_off : [ `Stopped ] Lwt.t; + stop : unit Lwt.u; + switched_off : unit Lwt.t; } let udpv4 { udpv4; _ } = udpv4 @@ -46,7 +46,7 @@ module V4 = struct UDPV4.get_udpv4_listening_fd t.udpv4 port >>= fun fd -> let buf = Cstruct.create 4096 in let rec loop () = - (* TODO cancellation *) + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt.catch (fun () -> Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> let buf = Cstruct.sub buf 0 len in @@ -61,7 +61,8 @@ module V4 = struct Lwt.return_unit) >>= fun () -> loop () in - Lwt.pick [ t.switched_off; loop () ] >>= fun `Stopped -> Lwt_unix.close fd) + Lwt.catch loop (fun _ -> Lwt.return_unit) >>= fun () -> + Lwt_unix.close fd) let listen_tcpv4 ?keepalive t ~port callback = if port < 0 || port > 65535 then @@ -76,6 +77,7 @@ module V4 = struct (* TODO cancellation *) let rec loop () = Lwt.catch (fun () -> + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt_unix.accept fd >|= fun (afd, _) -> (match keepalive with | None -> () @@ -93,16 +95,16 @@ module V4 = struct Lwt.return_unit) >>= fun () -> loop () in - Lwt.pick [ t.switched_off; loop () ] >>= fun `Stopped -> Lwt_unix.close fd) + Lwt.catch loop (fun _-> Lwt.return_unit) >>= fun () -> Lwt_unix.close fd) - let listen t = t.switched_off >>= fun `Stopped -> Lwt.return_unit + let listen t = t.switched_off >>= fun () -> Lwt.return_unit let connect udpv4 tcpv4 = Log.info (fun f -> f "IPv4 socket stack: connect"); let switched_off, stop = Lwt.wait () in Lwt.return { tcpv4; udpv4; stop; switched_off; } - let disconnect t = Lwt.wakeup_later t.stop `Stopped ; Lwt.return_unit + let disconnect t = Lwt.wakeup_later t.stop () ; Lwt.return_unit end module V6 = struct @@ -113,8 +115,8 @@ module V6 = struct type t = { udp : UDP.t; tcp : TCP.t; - stop : [ `Stopped ] Lwt.u; - switched_off : [ `Stopped ] Lwt.t; + stop : unit Lwt.u; + switched_off : unit Lwt.t; } let udp { udp; _ } = udp @@ -132,7 +134,7 @@ module V6 = struct UDP.get_udpv6_listening_fd t.udp port >>= fun fd -> let buf = Cstruct.create 4096 in let rec loop () = - (* TODO cancellation *) + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt.catch (fun () -> Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> let buf = Cstruct.sub buf 0 len in @@ -147,7 +149,7 @@ module V6 = struct Lwt.return_unit) >>= fun () -> loop () in - Lwt.pick [ t.switched_off; loop () ] >>= fun `Stopped -> Lwt_unix.close fd) + Lwt.catch loop (fun _ -> Lwt.return_unit) >>= fun () -> Lwt_unix.close fd) let listen_tcp ?keepalive t ~port callback = if port < 0 || port > 65535 then @@ -162,6 +164,7 @@ module V6 = struct Lwt.async (fun () -> (* TODO cancellation *) let rec loop () = + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt.catch (fun () -> Lwt_unix.accept fd >|= fun (afd, _) -> (match keepalive with @@ -180,16 +183,16 @@ module V6 = struct Lwt.return_unit) >>= fun () -> loop () in - Lwt.pick [ t.switched_off; loop () ] >>= fun `Stopped -> Lwt_unix.close fd) + Lwt.catch loop (fun _ -> Lwt.return_unit) >>= fun () -> Lwt_unix.close fd) - let listen t = t.switched_off >>= fun `Stopped -> Lwt.return_unit + let listen t = t.switched_off >>= fun () -> Lwt.return_unit let connect udp tcp = Log.info (fun f -> f "IPv6 socket stack: connect"); let switched_off, stop = Lwt.wait () in Lwt.return { tcp; udp; stop; switched_off; } - let disconnect t = Lwt.wakeup_later t.stop `Stopped ; Lwt.return_unit + let disconnect t = Lwt.wakeup_later t.stop () ; Lwt.return_unit end module V4V6 = struct @@ -200,8 +203,8 @@ module V4V6 = struct type t = { udp : UDP.t; tcp : TCP.t; - stop : [ `Stopped ] Lwt.u; - switched_off : [ `Stopped ] Lwt.t; + stop : unit Lwt.u; + switched_off : unit Lwt.t; } let udp { udp; _ } = udp @@ -221,7 +224,7 @@ module V4V6 = struct Lwt.async (fun () -> let buf = Cstruct.create 4096 in let rec loop () = - (* TODO cancellation *) + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt.catch (fun () -> Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> let buf = Cstruct.sub buf 0 len in @@ -236,7 +239,7 @@ module V4V6 = struct Lwt.return_unit) >>= fun () -> loop () in - Lwt.pick [ t.switched_off; loop () ] >>= fun `Stopped -> Lwt_unix.close fd)) fds) + Lwt.catch loop (fun _ -> Lwt.return_unit) >>= fun () -> Lwt_unix.close fd)) fds) let listen_tcp ?keepalive t ~port callback = if port < 0 || port > 65535 then @@ -273,6 +276,7 @@ module V4V6 = struct Lwt.async (fun () -> (* TODO cancellation *) let rec loop () = + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt.catch (fun () -> Lwt_unix.accept fd >|= fun (afd, _) -> (match keepalive with @@ -291,14 +295,14 @@ module V4V6 = struct Lwt.return_unit) >>= fun () -> loop () in - Lwt.pick [ t.switched_off; loop () ] >>= fun `Stopped -> Lwt_unix.close fd)) fds + Lwt.catch loop (fun _ -> Lwt.return_unit) >>= fun () -> Lwt_unix.close fd)) fds - let listen t = t.switched_off >>= fun `Stopped -> Lwt.return_unit + let listen t = t.switched_off >>= fun () -> Lwt.return_unit let connect udp tcp = Log.info (fun f -> f "Dual IPv4 and IPv6 socket stack: connect"); let switched_off, stop = Lwt.wait () in Lwt.return { tcp; udp; stop; switched_off; } - let disconnect t = Lwt.wakeup_later t.stop `Stopped ; Lwt.return_unit + let disconnect t = Lwt.wakeup_later t.stop () ; Lwt.return_unit end From 4fcb3a66004268d1298b5ef0102cbb28ab5a64fc Mon Sep 17 00:00:00 2001 From: Calascibetta Romain Date: Thu, 25 Mar 2021 16:27:47 +0100 Subject: [PATCH 3/4] Update src/stack-unix/tcpip_stack_socket.ml Co-authored-by: Thomas Leonard --- src/stack-unix/tcpip_stack_socket.ml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stack-unix/tcpip_stack_socket.ml b/src/stack-unix/tcpip_stack_socket.ml index 56e9c1f60..cb2cf380e 100644 --- a/src/stack-unix/tcpip_stack_socket.ml +++ b/src/stack-unix/tcpip_stack_socket.ml @@ -97,7 +97,7 @@ module V4 = struct in Lwt.catch loop (fun _-> Lwt.return_unit) >>= fun () -> Lwt_unix.close fd) - let listen t = t.switched_off >>= fun () -> Lwt.return_unit + let listen t = t.switched_off let connect udpv4 tcpv4 = Log.info (fun f -> f "IPv4 socket stack: connect"); From 509cb2c71b4053f5700e30932650043cfd94d4e9 Mon Sep 17 00:00:00 2001 From: dinosaure Date: Thu, 25 Mar 2021 02:28:07 +0100 Subject: [PATCH 4/4] Apply @talex5's review --- src/stack-unix/tcpip_stack_socket.ml | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/stack-unix/tcpip_stack_socket.ml b/src/stack-unix/tcpip_stack_socket.ml index cb2cf380e..cc001ea06 100644 --- a/src/stack-unix/tcpip_stack_socket.ml +++ b/src/stack-unix/tcpip_stack_socket.ml @@ -19,6 +19,10 @@ open Lwt.Infix let src = Logs.Src.create "tcpip-stack-socket" ~doc:"Platform's native TCP/IP stack" module Log = (val Logs.src_log src : Logs.LOG) +let ignore_canceled = function + | Lwt.Canceled -> Lwt.return_unit + | exn -> raise exn + module V4 = struct module TCPV4 = Tcpv4_socket module UDPV4 = Udpv4_socket @@ -61,7 +65,7 @@ module V4 = struct Lwt.return_unit) >>= fun () -> loop () in - Lwt.catch loop (fun _ -> Lwt.return_unit) >>= fun () -> + Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd) let listen_tcpv4 ?keepalive t ~port callback = @@ -76,8 +80,8 @@ module V4 = struct Lwt.async (fun () -> (* TODO cancellation *) let rec loop () = + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt.catch (fun () -> - if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt_unix.accept fd >|= fun (afd, _) -> (match keepalive with | None -> () @@ -95,7 +99,7 @@ module V4 = struct Lwt.return_unit) >>= fun () -> loop () in - Lwt.catch loop (fun _-> Lwt.return_unit) >>= fun () -> Lwt_unix.close fd) + Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd) let listen t = t.switched_off @@ -149,7 +153,7 @@ module V6 = struct Lwt.return_unit) >>= fun () -> loop () in - Lwt.catch loop (fun _ -> Lwt.return_unit) >>= fun () -> Lwt_unix.close fd) + Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd) let listen_tcp ?keepalive t ~port callback = if port < 0 || port > 65535 then @@ -183,9 +187,9 @@ module V6 = struct Lwt.return_unit) >>= fun () -> loop () in - Lwt.catch loop (fun _ -> Lwt.return_unit) >>= fun () -> Lwt_unix.close fd) + Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd) - let listen t = t.switched_off >>= fun () -> Lwt.return_unit + let listen t = t.switched_off let connect udp tcp = Log.info (fun f -> f "IPv6 socket stack: connect"); @@ -239,7 +243,7 @@ module V4V6 = struct Lwt.return_unit) >>= fun () -> loop () in - Lwt.catch loop (fun _ -> Lwt.return_unit) >>= fun () -> Lwt_unix.close fd)) fds) + Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd)) fds) let listen_tcp ?keepalive t ~port callback = if port < 0 || port > 65535 then @@ -295,9 +299,9 @@ module V4V6 = struct Lwt.return_unit) >>= fun () -> loop () in - Lwt.catch loop (fun _ -> Lwt.return_unit) >>= fun () -> Lwt_unix.close fd)) fds + Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd)) fds - let listen t = t.switched_off >>= fun () -> Lwt.return_unit + let listen t = t.switched_off let connect udp tcp = Log.info (fun f -> f "Dual IPv4 and IPv6 socket stack: connect");