diff --git a/src/stack-unix/tcpip_stack_socket.ml b/src/stack-unix/tcpip_stack_socket.ml index e0c06ff01..cc001ea06 100644 --- a/src/stack-unix/tcpip_stack_socket.ml +++ b/src/stack-unix/tcpip_stack_socket.ml @@ -19,6 +19,10 @@ open Lwt.Infix let src = Logs.Src.create "tcpip-stack-socket" ~doc:"Platform's native TCP/IP stack" module Log = (val Logs.src_log src : Logs.LOG) +let ignore_canceled = function + | Lwt.Canceled -> Lwt.return_unit + | exn -> raise exn + module V4 = struct module TCPV4 = Tcpv4_socket module UDPV4 = Udpv4_socket @@ -27,6 +31,8 @@ module V4 = struct type t = { udpv4 : UDPV4.t; tcpv4 : TCPV4.t; + stop : unit Lwt.u; + switched_off : unit Lwt.t; } let udpv4 { udpv4; _ } = udpv4 @@ -44,7 +50,7 @@ module V4 = struct UDPV4.get_udpv4_listening_fd t.udpv4 port >>= fun fd -> let buf = Cstruct.create 4096 in let rec loop () = - (* TODO cancellation *) + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt.catch (fun () -> Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> let buf = Cstruct.sub buf 0 len in @@ -59,7 +65,8 @@ module V4 = struct Lwt.return_unit) >>= fun () -> loop () in - loop ()) + Lwt.catch loop ignore_canceled >>= fun () -> + Lwt_unix.close fd) let listen_tcpv4 ?keepalive t ~port callback = if port < 0 || port > 65535 then @@ -73,6 +80,7 @@ module V4 = struct Lwt.async (fun () -> (* TODO cancellation *) let rec loop () = + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt.catch (fun () -> Lwt_unix.accept fd >|= fun (afd, _) -> (match keepalive with @@ -91,17 +99,16 @@ module V4 = struct Lwt.return_unit) >>= fun () -> loop () in - loop ()) + Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd) - let listen _t = - let t, _ = Lwt.task () in - t (* TODO cancellation *) + let listen t = t.switched_off let connect udpv4 tcpv4 = Log.info (fun f -> f "IPv4 socket stack: connect"); - Lwt.return { tcpv4; udpv4 } + let switched_off, stop = Lwt.wait () in + Lwt.return { tcpv4; udpv4; stop; switched_off; } - let disconnect _ = Lwt.return_unit + let disconnect t = Lwt.wakeup_later t.stop () ; Lwt.return_unit end module V6 = struct @@ -112,6 +119,8 @@ module V6 = struct type t = { udp : UDP.t; tcp : TCP.t; + stop : unit Lwt.u; + switched_off : unit Lwt.t; } let udp { udp; _ } = udp @@ -129,7 +138,7 @@ module V6 = struct UDP.get_udpv6_listening_fd t.udp port >>= fun fd -> let buf = Cstruct.create 4096 in let rec loop () = - (* TODO cancellation *) + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt.catch (fun () -> Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> let buf = Cstruct.sub buf 0 len in @@ -144,7 +153,7 @@ module V6 = struct Lwt.return_unit) >>= fun () -> loop () in - loop ()) + Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd) let listen_tcp ?keepalive t ~port callback = if port < 0 || port > 65535 then @@ -159,6 +168,7 @@ module V6 = struct Lwt.async (fun () -> (* TODO cancellation *) let rec loop () = + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt.catch (fun () -> Lwt_unix.accept fd >|= fun (afd, _) -> (match keepalive with @@ -177,17 +187,16 @@ module V6 = struct Lwt.return_unit) >>= fun () -> loop () in - loop ()) + Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd) - let listen _t = - let t, _ = Lwt.task () in - t (* TODO cancellation *) + let listen t = t.switched_off let connect udp tcp = Log.info (fun f -> f "IPv6 socket stack: connect"); - Lwt.return { tcp; udp } + let switched_off, stop = Lwt.wait () in + Lwt.return { tcp; udp; stop; switched_off; } - let disconnect _ = Lwt.return_unit + let disconnect t = Lwt.wakeup_later t.stop () ; Lwt.return_unit end module V4V6 = struct @@ -198,6 +207,8 @@ module V4V6 = struct type t = { udp : UDP.t; tcp : TCP.t; + stop : unit Lwt.u; + switched_off : unit Lwt.t; } let udp { udp; _ } = udp @@ -217,7 +228,7 @@ module V4V6 = struct Lwt.async (fun () -> let buf = Cstruct.create 4096 in let rec loop () = - (* TODO cancellation *) + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt.catch (fun () -> Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> let buf = Cstruct.sub buf 0 len in @@ -232,7 +243,7 @@ module V4V6 = struct Lwt.return_unit) >>= fun () -> loop () in - loop ())) fds) + Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd)) fds) let listen_tcp ?keepalive t ~port callback = if port < 0 || port > 65535 then @@ -269,6 +280,7 @@ module V4V6 = struct Lwt.async (fun () -> (* TODO cancellation *) let rec loop () = + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; Lwt.catch (fun () -> Lwt_unix.accept fd >|= fun (afd, _) -> (match keepalive with @@ -287,15 +299,14 @@ module V4V6 = struct Lwt.return_unit) >>= fun () -> loop () in - loop ())) fds + Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd)) fds - let listen _t = - let t, _ = Lwt.task () in - t (* TODO cancellation *) + let listen t = t.switched_off let connect udp tcp = Log.info (fun f -> f "Dual IPv4 and IPv6 socket stack: connect"); - Lwt.return { tcp; udp } + let switched_off, stop = Lwt.wait () in + Lwt.return { tcp; udp; stop; switched_off; } - let disconnect _ = Lwt.return_unit + let disconnect t = Lwt.wakeup_later t.stop () ; Lwt.return_unit end