diff --git a/README.md b/README.md index 773b4a1..daa9ab0 100644 --- a/README.md +++ b/README.md @@ -71,9 +71,9 @@ The µDNS library is published under the 2 clause BSD license. ## Installation You first need to install [OCaml](https://ocaml.org) (at least 4.04.0) and -[opam](https://opam.ocaml.org), the OCaml package manager (at least 1.2.2) on +[opam](https://opam.ocaml.org), the OCaml package manager (at least 2.0.0) on your machine (you can use opam to install an up-to-date OCaml (`opam switch -4.06.0`)). You may want to follow the [mirage installation +4.07.1`)). You may want to follow the [mirage installation instructions](https://mirage.io/wiki/install) to get `mirage` installed on your computer. @@ -85,4 +85,4 @@ examples at the [unikernel repository](https://github.com/roburio/unikernels). ## Documentation -Is unfortunately only in the code at the moment. +API documentation [is available online](https://roburio.github.io/udns/doc/). diff --git a/app/odns.ml b/app/odns.ml index d539753..1e43316 100644 --- a/app/odns.ml +++ b/app/odns.ml @@ -27,7 +27,9 @@ let pp_zone_tlsa ppf (domain,ttl,(tlsa:Udns_packet.tlsa)) = | n -> loop ((String.sub hex n 56)::acc) (n+56) in loop [] 0) -let do_a ((_,(ns_ip,_)) as nameserver) domains _ = +let do_a nameserver domains _ = + let t = Udns_client_lwt.create ?nameserver () in + let (_, (ns_ip, _)) = Udns_client_lwt.nameserver t in Logs.info (fun m -> m "querying NS %s for A records of %a" (Unix.string_of_inet_addr ns_ip) Fmt.(list ~sep:(unit", ") Domain_name.pp) domains); @@ -35,7 +37,7 @@ let do_a ((_,(ns_ip,_)) as nameserver) domains _ = Lwt_list.iter_p (fun domain -> let open Lwt in Logs.debug (fun m -> m "looking up %a" Domain_name.pp domain); - Udns_client_lwt.(getaddrinfo () ~nameserver Udns_map.A domain) + Udns_client_lwt.(getaddrinfo t Udns_map.A domain) >|= function | Ok (_ttl, addrs) when Udns_map.Ipv4Set.is_empty addrs -> (* handle empty response? *) @@ -53,16 +55,18 @@ let do_a ((_,(ns_ip,_)) as nameserver) domains _ = match Lwt_main.run job with | () -> Ok () (* TODO handle errors *) -let for_all_domains ((_,(ns_ip,_)) as nameserver) ~domains typ f = +let for_all_domains nameserver ~domains typ f = (* [for_all_domains] is a utility function that lets us avoid duplicating this block of code in all the subcommands. We leave {!do_a} simple to provide a more readable example. *) + let t = Udns_client_lwt.create ?nameserver () in + let _, (ns_ip, _) = Udns_client_lwt.nameserver t in Logs.info (fun m -> m "NS: %s" @@ Unix.string_of_inet_addr ns_ip); let open Lwt in match Lwt_main.run (Lwt_list.iter_p (fun domain -> - Udns_client_lwt.getaddrinfo () ~nameserver typ domain + Udns_client_lwt.getaddrinfo t typ domain >|= f domain) domains) with | () -> Ok () (* TODO catch failed jobs *) @@ -95,7 +99,7 @@ let do_txt nameserver domains _ = let do_any nameserver domains _ = for_all_domains nameserver ~domains Udns_map.Any (fun domain -> function - | Ok (rr_list, _domain_names) -> + | Ok rr_list -> List.iter (fun rr -> Logs.app (fun m -> m "%a" Udns_packet.pp_rr rr)) rr_list | Error (`Msg msg) -> @@ -143,8 +147,7 @@ let parse_ns : ('a * (Lwt_unix.inet_addr * int)) Arg.conv = let arg_ns : 'a Term.t = let doc = "IP of nameserver to use" in - Arg.(value & opt parse_ns Udns_client_lwt.default_ns - & info ~docv:"NS-IP" ~doc ["ns"]) + Arg.(value & opt (some parse_ns) None & info ~docv:"NS-IP" ~doc ["ns"]) let parse_domain : Domain_name.t Arg.conv = ( fun name -> diff --git a/client/udns_client.ml b/client/udns_client.ml index e502d15..7250e2a 100644 --- a/client/udns_client.ml +++ b/client/udns_client.ml @@ -69,12 +69,7 @@ let parse_response (type requested) ) >>= fun relevant_map -> begin match (state.key : requested Udns_map.k) with | (Udns_map.Any : requested Udns_map.k) -> - Ok (((resp.answer:Udns_packet.rr list) , - (((Udns_map.of_rrs resp.answer - |> Domain_name.Map.bindings - |> List.map fst - |> Domain_name.Set.of_list) - ) : Domain_name.Set.t)):requested) + Ok (resp.answer:requested) | _ -> begin match Udns_map.find state.key relevant_map with | Some response -> Ok response diff --git a/client/udns_client_flow.ml b/client/udns_client_flow.ml index b6116c6..bb6e948 100644 --- a/client/udns_client_flow.ml +++ b/client/udns_client_flow.ml @@ -4,10 +4,13 @@ module type S = sig type io_addr type ns_addr = ([`TCP | `UDP]) * io_addr type stack + type t - val default_ns : ns_addr + val create : ?nameserver:ns_addr -> stack -> t - val connect : stack -> ns_addr -> (flow,'err) io + val nameserver : t -> ns_addr + + val connect : ?nameserver:ns_addr -> t -> (flow,'err) io val send : flow -> Cstruct.t -> (unit,'b) io val recv : flow -> (Cstruct.t, 'b) io @@ -18,19 +21,19 @@ end module Make = functor (Uflow:S) -> struct - let default_ns = Uflow.default_ns + let create ?nameserver stack = Uflow.create ?nameserver stack + + let nameserver t = Uflow.nameserver t - let getaddrinfo (type requested) stack ?nameserver (query_type:requested Udns_map.k) name + let getaddrinfo (type requested) t ?nameserver (query_type:requested Udns_map.k) name : (requested, [> `Msg of string]) Uflow.io = - let (proto, _) as ns_addr = match nameserver with None -> Uflow.default_ns | Some x -> x in + let proto, _ = match nameserver with None -> Uflow.nameserver t | Some x -> x in let tx, state = - let cs, state = Udns_client.make_query - (match proto with `UDP -> `Udp - | `TCP -> `Tcp) name query_type in - cs, state + Udns_client.make_query + (match proto with `UDP -> `Udp | `TCP -> `Tcp) name query_type in let (>>=), (>>|) = Uflow.(resolve, map) in - Uflow.connect stack ns_addr >>| fun socket -> + Uflow.connect ?nameserver t >>| fun socket -> Logs.debug (fun m -> m "Connected to NS."); Uflow.send socket tx >>| fun () -> (* TODO steal loop logic from lwt *) diff --git a/client/udns_client_flow.mli b/client/udns_client_flow.mli index eee3007..3e0b4ad 100644 --- a/client/udns_client_flow.mli +++ b/client/udns_client_flow.mli @@ -29,12 +29,18 @@ module type S = sig type stack (** A stack with which to connect, e.g. {IPv4.tcpv4}*) - val default_ns : ns_addr + type t + (** The abstract state of a DNS client. *) + + val create : ?nameserver:ns_addr -> stack -> t + (** [create ~nameserver stack] creates the state record of the DNS client. *) + + val nameserver : t -> ns_addr (** The address of a nameserver that is supposed to work with the underlying flow, can be used if the user does not want to bother with configuring their own.*) - val connect : stack -> ns_addr -> (flow,'err) io + val connect : ?nameserver:ns_addr -> t -> (flow,'err) io (** [connect addr] is a new connection ([flow]) to [addr], or an error. *) val send : flow -> Cstruct.t -> (unit,'err) io @@ -53,12 +59,13 @@ end module Make : functor (U : S) -> sig - val default_ns : U.ns_addr - (** The address of a nameserver that is supposed to work with - the underlying flow, can be used if the user does not want to - bother with configuring their own.*) + val create : ?nameserver:U.ns_addr -> U.stack -> U.t + (** [create ~nameserver stack] creates the state of the DNS client. *) + + val nameserver : U.t -> U.ns_addr + (** [nameserver t] returns the default nameserver to be used. *) - val getaddrinfo : U.stack -> ?nameserver:U.ns_addr -> 'response Udns_map.k -> + val getaddrinfo : U.t -> ?nameserver:U.ns_addr -> 'response Udns_map.k -> Domain_name.t -> ('response, 'err) U.io (** [getaddrinfo nameserver query_type name] is the [query_type]-dependent response from [nameserver] regarding [name], or an [Error _] message. @@ -66,7 +73,7 @@ sig result types. *) - val gethostbyname : U.stack -> ?nameserver:U.ns_addr -> Domain_name.t -> + val gethostbyname : U.t -> ?nameserver:U.ns_addr -> Domain_name.t -> (Ipaddr.V4.t, 'err) U.io (** [gethostbyname ~nameserver name] is the IPv4 address of [name] resolved via the [nameserver] specified. diff --git a/lwt/client/udns_client_lwt.ml b/lwt/client/udns_client_lwt.ml index b48b0b4..7e13601 100644 --- a/lwt/client/udns_client_lwt.ml +++ b/lwt/client/udns_client_lwt.ml @@ -3,6 +3,8 @@ Lwt convenience module *) +open Lwt.Infix + module Uflow : Udns_client_flow.S with type flow = Lwt_unix.file_descr and type io_addr = Lwt_unix.inet_addr * int @@ -15,8 +17,12 @@ module Uflow : Udns_client_flow.S type (+'a,+'b) io = ('a,'b) Lwt_result.t constraint 'b = [> `Msg of string] type stack = unit + type t = { nameserver : ns_addr } + + let create ?(nameserver = `TCP, (Unix.inet_addr_of_string "91.239.100.100", 53)) () = + { nameserver } - let default_ns = `TCP, (Unix.inet_addr_of_string "91.239.100.100", 53) + let nameserver { nameserver } = nameserver let send socket tx = let open Lwt in @@ -40,8 +46,8 @@ module Uflow : Udns_client_flow.S let map = Lwt_result.bind let resolve = Lwt_result.bind_result - let connect () (proto, (server,port)) = - let open Lwt in + let connect ?nameserver:ns t = + let (proto, (server, port)) = match ns with None -> nameserver t | Some x -> x in begin match proto with | `UDP -> Lwt_unix.((getprotobyname "udp") >|= fun x -> x.p_proto, diff --git a/mirage/client/udns_mirage_client.ml b/mirage/client/udns_mirage_client.ml index 2ac6569..d872f67 100644 --- a/mirage/client/udns_mirage_client.ml +++ b/mirage/client/udns_mirage_client.ml @@ -7,24 +7,32 @@ module Make (S : Mirage_stack_lwt.V4) = struct module Uflow : Udns_client_flow.S with type flow = S.TCPV4.flow - and type stack = S.tcpv4 + and type stack = S.t and type (+'a,+'b) io = ('a, 'b) Lwt_result.t constraint 'b = [> `Msg of string] and type io_addr = Ipaddr.V4.t * int = struct type flow = S.TCPV4.flow - type stack = S.tcpv4 + type stack = S.t type io_addr = Ipaddr.V4.t * int type ns_addr = [`TCP | `UDP] * io_addr type (+'a,+'b) io = ('a, 'b) Lwt_result.t constraint 'b = [> `Msg of string] + type t = { + nameserver : ns_addr ; + stack : stack ; + } - let default_ns = `TCP, (Ipaddr.V4.of_string_exn "91.239.100.100", 53) + let create ?(nameserver = `TCP, (Ipaddr.V4.of_string_exn "91.239.100.100", 53)) stack = + { nameserver ; stack } + + let nameserver { nameserver ; _ } = nameserver let map = Lwt_result.bind let resolve = Lwt_result.bind_result - let connect stack ((_proto, (ip, port)):ns_addr) = - S.TCPV4.create_connection stack (ip, port) >|= function + let connect ?nameserver:ns t = + let _proto, addr = match ns with None -> nameserver t | Some x -> x in + S.TCPV4.create_connection (S.tcpv4 t.stack) addr >|= function | Error e -> Log.err (fun m -> m "error connecting to nameserver %a" S.TCPV4.pp_error e) ; diff --git a/mirage/client/udns_mirage_client.mli b/mirage/client/udns_mirage_client.mli index 2f1f31e..3cf4900 100644 --- a/mirage/client/udns_mirage_client.mli +++ b/mirage/client/udns_mirage_client.mli @@ -4,7 +4,7 @@ module Make (S : Mirage_stack_lwt.V4) : sig with type flow = S.TCPV4.flow and type io_addr = Ipaddr.V4.t * int and type (+'a, +'b) io = ('a, 'b) Lwt_result.t - and type stack = S.tcpv4 + and type stack = S.t include module type of Udns_client_flow.Make(Uflow) end diff --git a/resolver/udns_resolver_utils.mli b/resolver/udns_resolver_utils.mli index 2d94e9a..bc39f03 100644 --- a/resolver/udns_resolver_utils.mli +++ b/resolver/udns_resolver_utils.mli @@ -1,6 +1,7 @@ (* (c) 2017, 2018 Hannes Mehnert, all rights reserved *) -val scrub : ?mode:[ `Recursive | `Stub ] -> Domain_name.t -> Udns_packet.question -> Udns_packet.header -> Udns_packet.query -> +val scrub : ?mode:[ `Recursive | `Stub ] -> Domain_name.t -> + Udns_packet.question -> Udns_packet.header -> Udns_packet.query -> ((Udns_enum.rr_typ * Domain_name.t * Udns_resolver_entry.rank * Udns_resolver_entry.res) list, Udns_enum.rcode) result diff --git a/src/udns_map.ml b/src/udns_map.ml index 4475e91..fd3731a 100644 --- a/src/udns_map.ml +++ b/src/udns_map.ml @@ -51,7 +51,7 @@ module SshfpSet = Set.Make (struct end) type _ k = - | Any : (Udns_packet.rr list * Domain_name.Set.t) k + | Any : Udns_packet.rr list k | Cname : (int32 * Domain_name.t) k | Mx : (int32 * MxSet.t) k | Ns : (int32 * Domain_name.Set.t) k @@ -89,9 +89,8 @@ module K = struct let pp : type a. Format.formatter -> a t -> a -> unit = fun ppf t v -> match t, v with - | Any, (entries, names) -> - Fmt.pf ppf "any %a %a" Udns_packet.pp_rrs entries - Fmt.(list ~sep:(unit ";@,") Domain_name.pp) (Domain_name.Set.elements names) + | Any, entries -> + Fmt.pf ppf "any %a" Udns_packet.pp_rrs entries | Cname, (ttl, alias) -> Fmt.pf ppf "cname ttl %lu %a" ttl Domain_name.pp alias | Mx, (ttl, mxs) -> Fmt.pf ppf "mx ttl %lu %a" ttl @@ -222,12 +221,11 @@ include Gmap.Make(K) let pp_b ppf (B (k, v)) = K.pp ppf k v let equal_b b b' = match b, b' with - | B (Any, (entries, names)), B (Any, (entries', names')) -> + | B (Any, entries), B (Any, entries') -> List.length entries = List.length entries' && List.for_all (fun e -> List.exists (fun e' -> Udns_packet.rr_equal e e') entries') - entries && - Domain_name.Set.equal names names' + entries | B (Cname, (_, alias)), B (Cname, (_, alias')) -> Domain_name.equal alias alias' | B (Mx, (_, mxs)), B (Mx, (_, mxs')) -> @@ -316,13 +314,13 @@ let to_rdata : b -> int32 * Udns_packet.rdata list = fun (B (k, v)) -> let to_rr : Domain_name.t -> b -> Udns_packet.rr list = fun name b -> match b with - | B (Any, (entries, _)) -> entries + | B (Any, entries) -> entries | _ -> let ttl, rdatas = to_rdata b in List.map (fun rdata -> { Udns_packet.name ; ttl ; rdata }) rdatas let names = function - | B (Any, (_, names)) -> names + | B (Any, rrs) -> Udns_packet.rr_names rrs | B (Mx, (_, mxs)) -> MxSet.fold (fun (_, name) acc -> Domain_name.Set.add name acc) mxs Domain_name.Set.empty diff --git a/src/udns_map.mli b/src/udns_map.mli index 2dfcb31..5a6f5d7 100644 --- a/src/udns_map.mli +++ b/src/udns_map.mli @@ -32,7 +32,7 @@ module SshfpSet : Set.S with type elt = Udns_packet.sshfp (** A set of SSH FP records. *) type _ k = - | Any : (Udns_packet.rr list * Domain_name.Set.t) k + | Any : Udns_packet.rr list k | Cname : (int32 * Domain_name.t) k | Mx : (int32 * MxSet.t) k | Ns : (int32 * Domain_name.Set.t) k diff --git a/src/udns_trie.ml b/src/udns_trie.ml index 7177450..05941cd 100644 --- a/src/udns_trie.ml +++ b/src/udns_trie.ml @@ -54,13 +54,8 @@ let lookup_res name zone ty m = match ty with | Udns_enum.ANY -> let bindings = Udns_map.bindings m in - let rrs = List.(flatten (map (Udns_map.to_rr name) bindings)) - and names = - List.fold_left - (fun acc v -> Domain_name.Set.union acc (Udns_map.names v)) - Domain_name.Set.empty bindings - in - Ok (Udns_map.B (Udns_map.Any, (rrs, names)), to_ns z zmap) + let rrs = List.(flatten (map (Udns_map.to_rr name) bindings)) in + Ok (Udns_map.B (Udns_map.Any, rrs), to_ns z zmap) | _ -> match Udns_map.lookup_rr ty m with | Some v -> Ok (v, to_ns z zmap) | None -> match Udns_map.findb Udns_map.Cname m with diff --git a/unix/client/ohostname.ml b/unix/client/ohostname.ml index 839452f..d63b87a 100644 --- a/unix/client/ohostname.ml +++ b/unix/client/ohostname.ml @@ -1,6 +1,7 @@ let () = + let t = Udns_client_unix.create () in let res = - Udns_client_unix.gethostbyname () (Domain_name.of_string_exn Sys.argv.(1)) in + Udns_client_unix.gethostbyname t (Domain_name.of_string_exn Sys.argv.(1)) in match res with | Ok addr -> Fmt.pr "%a\n" Ipaddr.V4.pp addr | Error (`Msg x) -> Fmt.epr "Failed to resolve: %s\n" x; exit 1 diff --git a/unix/client/udns_client_unix.ml b/unix/client/udns_client_unix.ml index dd8deb0..6da3f0a 100644 --- a/unix/client/udns_client_unix.ml +++ b/unix/client/udns_client_unix.ml @@ -5,30 +5,34 @@ module Uflow : Udns_client_flow.S with type flow = Unix.file_descr - and type io_addr = string * int + and type io_addr = Unix.inet_addr * int and type stack = unit and type (+'a,+'b) io = ('a,[> `Msg of string]as 'b) result = struct - type io_addr = string * int + type io_addr = Unix.inet_addr * int type ns_addr = [`TCP | `UDP] * io_addr type stack = unit type flow = Unix.file_descr + type t = { nameserver : ns_addr } type (+'a,+'b) io = ('a,'b) result constraint 'b = [> `Msg of string] - let default_ns : ns_addr = `TCP, ("91.239.100.100", 53) + let create ?(nameserver = `TCP, (Unix.inet_addr_of_string "91.239.100.100", 53)) () = + { nameserver } + + let nameserver { nameserver } = nameserver let map = Rresult.R.((>>=)) let resolve = (Rresult.R.(>>=)) open Rresult - let connect () ((proto,(server,port)):ns_addr) = + let connect ?nameserver:ns t = + let proto, (server, port) = match ns with None -> nameserver t | Some x -> x in begin match proto with | `UDP -> Ok Unix.((getprotobyname "udp").p_proto) | `TCP -> Ok Unix.((getprotobyname "tcp").p_proto) end >>= fun proto_number -> let socket = Unix.socket PF_INET SOCK_STREAM proto_number in - let server = Unix.inet_addr_of_string server in let addr = Unix.ADDR_INET (server, port) in Unix.connect socket addr ; Ok socket diff --git a/unix/client/udns_client_unix.mli b/unix/client/udns_client_unix.mli index b3341f9..da33d69 100644 --- a/unix/client/udns_client_unix.mli +++ b/unix/client/udns_client_unix.mli @@ -6,7 +6,7 @@ (** A flow module based on blocking I/O on top of the Unix socket API. *) module Uflow : Udns_client_flow.S with type flow = Unix.file_descr - and type io_addr = string * int + and type io_addr = Unix.inet_addr * int and type stack = unit and type (+'a,+'b) io = ('a,'b) result