diff --git a/src/stack-direct/tcpip_stack_direct.ml b/src/stack-direct/tcpip_stack_direct.ml index cce23aade..dd8d63ba3 100644 --- a/src/stack-direct/tcpip_stack_direct.ml +++ b/src/stack-direct/tcpip_stack_direct.ml @@ -53,6 +53,7 @@ module Make tcpv4 : Tcpv4.t; udpv4_listeners: (int, Udpv4.callback) Hashtbl.t; tcpv4_listeners: (int, Tcpv4.listener) Hashtbl.t; + mutable task : unit Lwt.t option; } let pp fmt t = @@ -85,51 +86,57 @@ module Make with Not_found -> None let listen t = - Log.debug (fun f -> f "Establishing or updating listener for stack %a" pp t); - let ethif_listener = Ethernet.input - ~arpv4:(Arpv4.input t.arpv4) - ~ipv4:( - Ipv4.input - ~tcp:(Tcpv4.input t.tcpv4 - ~listeners:(tcpv4_listeners t)) - ~udp:(Udpv4.input t.udpv4 - ~listeners:(udpv4_listeners t)) - ~default:(fun ~proto ~src ~dst buf -> - match proto with - | 1 -> Icmpv4.input t.icmpv4 ~src ~dst buf - | _ -> Lwt.return_unit) - t.ipv4) - ~ipv6:(fun _ -> Lwt.return_unit) - t.ethif - in - Netif.listen t.netif ~header_size:Ethernet_wire.sizeof_ethernet ethif_listener - >>= function - | Error e -> - Log.warn (fun p -> p "%a" Netif.pp_error e) ; - (* XXX: error should be passed to the caller *) - Lwt.return_unit - | Ok _res -> - let nstat = Netif.get_stats_counters t.netif in - let open Mirage_net in - Log.info (fun f -> - f "listening loop of interface %s terminated regularly:@ %Lu bytes \ - (%lu packets) received, %Lu bytes (%lu packets) sent@ " - (Macaddr.to_string (Netif.mac t.netif)) - nstat.rx_bytes nstat.rx_pkts - nstat.tx_bytes nstat.tx_pkts) ; - Lwt.return_unit + Lwt.catch (fun () -> + Log.debug (fun f -> f "Establishing or updating listener for stack %a" pp t); + let ethif_listener = Ethernet.input + ~arpv4:(Arpv4.input t.arpv4) + ~ipv4:( + Ipv4.input + ~tcp:(Tcpv4.input t.tcpv4 + ~listeners:(tcpv4_listeners t)) + ~udp:(Udpv4.input t.udpv4 + ~listeners:(udpv4_listeners t)) + ~default:(fun ~proto ~src ~dst buf -> + match proto with + | 1 -> Icmpv4.input t.icmpv4 ~src ~dst buf + | _ -> Lwt.return_unit) + t.ipv4) + ~ipv6:(fun _ -> Lwt.return_unit) + t.ethif + in + Netif.listen t.netif ~header_size:Ethernet_wire.sizeof_ethernet ethif_listener + >>= function + | Error e -> + Log.warn (fun p -> p "%a" Netif.pp_error e) ; + (* XXX: error should be passed to the caller *) + Lwt.return_unit + | Ok _res -> + let nstat = Netif.get_stats_counters t.netif in + let open Mirage_net in + Log.info (fun f -> + f "listening loop of interface %s terminated regularly:@ %Lu bytes \ + (%lu packets) received, %Lu bytes (%lu packets) sent@ " + (Macaddr.to_string (Netif.mac t.netif)) + nstat.rx_bytes nstat.rx_pkts + nstat.tx_bytes nstat.tx_pkts) ; + Lwt.return_unit) + (function + | Lwt.Canceled -> + Log.info (fun f -> f "listen of %a cancelled" pp t); + Lwt.return_unit + | e -> Lwt.fail e) let connect netif ethif arpv4 ipv4 icmpv4 udpv4 tcpv4 = let udpv4_listeners = Hashtbl.create 7 in let tcpv4_listeners = Hashtbl.create 7 in let t = { netif; ethif; arpv4; ipv4; icmpv4; tcpv4; udpv4; - udpv4_listeners; tcpv4_listeners } in + udpv4_listeners; tcpv4_listeners; task = None } in Log.info (fun f -> f "stack assembled: %a" pp t); - Lwt.async (fun () -> listen t); + Lwt.async (fun () -> let task = listen t in t.task <- Some task; task); Lwt.return t let disconnect t = - (* TODO: kill the listening thread *) - Log.info (fun f -> f "disconnect called (currently a noop): %a" pp t); + Log.info (fun f -> f "disconnect called: %a" pp t); + (match t.task with None -> () | Some task -> Lwt.cancel task); Lwt.return_unit end diff --git a/src/tcp/flow.ml b/src/tcp/flow.ml index 2e79b3420..f74fbe21d 100644 --- a/src/tcp/flow.ml +++ b/src/tcp/flow.ml @@ -71,6 +71,7 @@ struct type t = { ip : Ip.t; + mutable active : bool ; mutable localport : int; channels: (WIRE.t, connection) Hashtbl.t; (* server connections the process of connecting - SYN-ACK sent @@ -537,19 +538,23 @@ struct >>= fun _ -> Lwt.return_unit (* if send fails, who cares *) let input_no_pcb t listeners (parsed, payload) id = - let { sequence; Tcp_packet.ack_number; window; options; syn; fin; rst; ack; _ } = parsed in - match rst, syn, ack with - | true, _, _ -> process_reset t id ~ack ~ack_number - | false, true, true -> - process_synack t id ~ack_number ~sequence ~tx_wnd:window ~options ~syn ~fin - | false, true , false -> process_syn t id ~listeners ~tx_wnd:window - ~ack_number ~sequence ~options ~syn ~fin - | false, false, true -> - let open RXS in - process_ack t id ~pkt:{ header = parsed; payload} - | false, false, false -> - Log.debug (fun f -> f "incoming packet matches no connection table entry and has no useful flags set; dropping it"); + if not t.active then + (* TODO: eventually send an RST? *) Lwt.return_unit + else + let { sequence; Tcp_packet.ack_number; window; options; syn; fin; rst; ack; _ } = parsed in + match rst, syn, ack with + | true, _, _ -> process_reset t id ~ack ~ack_number + | false, true, true -> + process_synack t id ~ack_number ~sequence ~tx_wnd:window ~options ~syn ~fin + | false, true , false -> process_syn t id ~listeners ~tx_wnd:window + ~ack_number ~sequence ~options ~syn ~fin + | false, false, true -> + let open RXS in + process_ack t id ~pkt:{ header = parsed; payload} + | false, false, false -> + Log.debug (fun f -> f "incoming packet matches no connection table entry and has no useful flags set; dropping it"); + Lwt.return_unit (* Main input function for TCP packets *) let input t ~listeners ~src ~dst data = @@ -714,9 +719,12 @@ struct pp_error e Ip.pp_ipaddr daddr dport) let create_connection ?keepalive tcp (daddr, dport) = - connect ?keepalive tcp ~dst:daddr ~dst_port:dport >>= function - | Error e -> log_failure daddr dport e; Lwt.return @@ Error e - | Ok (fl, _) -> Lwt.return (Ok fl) + if not tcp.active then + Lwt.return (Error `Timeout) (* TODO: custom error variant *) + else + connect ?keepalive tcp ~dst:daddr ~dst_port:dport >>= function + | Error e -> log_failure daddr dport e; Lwt.return @@ Error e + | Ok (fl, _) -> Lwt.return (Ok fl) (* Construct the main TCP thread *) let connect ip = @@ -726,7 +734,13 @@ struct let listens = Hashtbl.create 1 in let connects = Hashtbl.create 1 in let channels = Hashtbl.create 7 in - Lwt.return { ip; localport; channels; listens; connects } - - let disconnect _ = Lwt.return_unit + Lwt.return { ip; active = true; localport; channels; listens; connects } + + let disconnect t = + t.active <- false; + let conns = Hashtbl.fold (fun _ (pcb, _) acc -> pcb :: acc) t.channels [] in + Lwt_list.iter_p close conns >|= fun () -> + Hashtbl.reset t.listens; + Hashtbl.reset t.connects + (* TODO: should there be Lwt tasks being cancelled? *) end