Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 35 additions & 24 deletions src/stack-unix/tcpip_stack_socket.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ open Lwt.Infix
let src = Logs.Src.create "tcpip-stack-socket" ~doc:"Platform's native TCP/IP stack"
module Log = (val Logs.src_log src : Logs.LOG)

let ignore_canceled = function
| Lwt.Canceled -> Lwt.return_unit
| exn -> raise exn

module V4 = struct
module TCPV4 = Tcpv4_socket
module UDPV4 = Udpv4_socket
Expand All @@ -27,6 +31,8 @@ module V4 = struct
type t = {
udpv4 : UDPV4.t;
tcpv4 : TCPV4.t;
stop : unit Lwt.u;
switched_off : unit Lwt.t;
}

let udpv4 { udpv4; _ } = udpv4
Expand All @@ -44,7 +50,7 @@ module V4 = struct
UDPV4.get_udpv4_listening_fd t.udpv4 port >>= fun fd ->
let buf = Cstruct.create 4096 in
let rec loop () =
(* TODO cancellation *)
if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ;
Lwt.catch (fun () ->
Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) ->
let buf = Cstruct.sub buf 0 len in
Expand All @@ -59,7 +65,8 @@ module V4 = struct
Lwt.return_unit) >>= fun () ->
loop ()
in
loop ())
Lwt.catch loop ignore_canceled >>= fun () ->
Lwt_unix.close fd)

let listen_tcpv4 ?keepalive t ~port callback =
if port < 0 || port > 65535 then
Expand All @@ -73,6 +80,7 @@ module V4 = struct
Lwt.async (fun () ->
(* TODO cancellation *)
let rec loop () =
if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ;
Lwt.catch (fun () ->
Lwt_unix.accept fd >|= fun (afd, _) ->
(match keepalive with
Expand All @@ -91,17 +99,16 @@ module V4 = struct
Lwt.return_unit) >>= fun () ->
loop ()
in
loop ())
Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd)

let listen _t =
let t, _ = Lwt.task () in
t (* TODO cancellation *)
let listen t = t.switched_off

let connect udpv4 tcpv4 =
Log.info (fun f -> f "IPv4 socket stack: connect");
Lwt.return { tcpv4; udpv4 }
let switched_off, stop = Lwt.wait () in
Lwt.return { tcpv4; udpv4; stop; switched_off; }

let disconnect _ = Lwt.return_unit
let disconnect t = Lwt.wakeup_later t.stop () ; Lwt.return_unit
end

module V6 = struct
Expand All @@ -112,6 +119,8 @@ module V6 = struct
type t = {
udp : UDP.t;
tcp : TCP.t;
stop : unit Lwt.u;
switched_off : unit Lwt.t;
}

let udp { udp; _ } = udp
Expand All @@ -129,7 +138,7 @@ module V6 = struct
UDP.get_udpv6_listening_fd t.udp port >>= fun fd ->
let buf = Cstruct.create 4096 in
let rec loop () =
(* TODO cancellation *)
if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ;
Lwt.catch (fun () ->
Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) ->
let buf = Cstruct.sub buf 0 len in
Expand All @@ -144,7 +153,7 @@ module V6 = struct
Lwt.return_unit) >>= fun () ->
loop ()
in
loop ())
Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd)

let listen_tcp ?keepalive t ~port callback =
if port < 0 || port > 65535 then
Expand All @@ -159,6 +168,7 @@ module V6 = struct
Lwt.async (fun () ->
(* TODO cancellation *)
let rec loop () =
if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ;
Lwt.catch (fun () ->
Lwt_unix.accept fd >|= fun (afd, _) ->
(match keepalive with
Expand All @@ -177,17 +187,16 @@ module V6 = struct
Lwt.return_unit) >>= fun () ->
loop ()
in
loop ())
Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd)

let listen _t =
let t, _ = Lwt.task () in
t (* TODO cancellation *)
let listen t = t.switched_off

let connect udp tcp =
Log.info (fun f -> f "IPv6 socket stack: connect");
Lwt.return { tcp; udp }
let switched_off, stop = Lwt.wait () in
Lwt.return { tcp; udp; stop; switched_off; }

let disconnect _ = Lwt.return_unit
let disconnect t = Lwt.wakeup_later t.stop () ; Lwt.return_unit
end

module V4V6 = struct
Expand All @@ -198,6 +207,8 @@ module V4V6 = struct
type t = {
udp : UDP.t;
tcp : TCP.t;
stop : unit Lwt.u;
switched_off : unit Lwt.t;
}

let udp { udp; _ } = udp
Expand All @@ -217,7 +228,7 @@ module V4V6 = struct
Lwt.async (fun () ->
let buf = Cstruct.create 4096 in
let rec loop () =
(* TODO cancellation *)
if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ;
Lwt.catch (fun () ->
Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) ->
let buf = Cstruct.sub buf 0 len in
Expand All @@ -232,7 +243,7 @@ module V4V6 = struct
Lwt.return_unit) >>= fun () ->
loop ()
in
loop ())) fds)
Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd)) fds)

let listen_tcp ?keepalive t ~port callback =
if port < 0 || port > 65535 then
Expand Down Expand Up @@ -269,6 +280,7 @@ module V4V6 = struct
Lwt.async (fun () ->
(* TODO cancellation *)
let rec loop () =
if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ;
Lwt.catch (fun () ->
Lwt_unix.accept fd >|= fun (afd, _) ->
(match keepalive with
Expand All @@ -287,15 +299,14 @@ module V4V6 = struct
Lwt.return_unit) >>= fun () ->
loop ()
in
loop ())) fds
Lwt.catch loop ignore_canceled >>= fun () -> Lwt_unix.close fd)) fds

let listen _t =
let t, _ = Lwt.task () in
t (* TODO cancellation *)
let listen t = t.switched_off

let connect udp tcp =
Log.info (fun f -> f "Dual IPv4 and IPv6 socket stack: connect");
Lwt.return { tcp; udp }
let switched_off, stop = Lwt.wait () in
Lwt.return { tcp; udp; stop; switched_off; }

let disconnect _ = Lwt.return_unit
let disconnect t = Lwt.wakeup_later t.stop () ; Lwt.return_unit
end