diff --git a/.merlin b/.merlin index f68f8af4e4..fc5d741275 100644 --- a/.merlin +++ b/.merlin @@ -20,9 +20,7 @@ PKG re PKG re.emacs PKG stringext PKG fieldslib -PKG pa_fields_conv PKG sexplib -PKG pa_sexp_conv PKG ipaddr PKG ipaddr.unix PKG conduit diff --git a/async/cohttp_async.ml b/async/cohttp_async.ml index 94e5bdd1fd..4cb1d2ca3c 100644 --- a/async/cohttp_async.ml +++ b/async/cohttp_async.ml @@ -165,9 +165,13 @@ module Client = struct let pipe = pipe_of_body Response.read_body_chunk reader in (res, pipe) - let request ?interrupt ?ssl_config ?(body=`Empty) req = + let request ?interrupt ?ssl_config ?host ?(body=`Empty) req = (* Connect to the remote side *) - Net.connect_uri ?interrupt ?ssl_config req.Request.uri + let host = + match host with + | Some t -> t + | None -> Request.uri req in + Net.connect_uri ?interrupt host >>= fun (ic,oc) -> Request.write (fun writer -> Body.write Request.write_body body writer) req oc >>= fun () -> @@ -221,7 +225,7 @@ module Client = struct Request.make_for_client ?headers ~chunked:true meth uri end in - req >>= request ?interrupt ?ssl_config ~body + req >>= request ?interrupt ?ssl_config ~body ~host:uri let get ?interrupt ?ssl_config ?headers uri = call ?interrupt ?ssl_config ?headers ~chunked:false `GET uri diff --git a/async/cohttp_async.mli b/async/cohttp_async.mli index 60c469c6c5..fb226bf17e 100644 --- a/async/cohttp_async.mli +++ b/async/cohttp_async.mli @@ -43,6 +43,7 @@ module Client : sig val request : ?interrupt:unit Deferred.t -> ?ssl_config:Conduit_async.Ssl.config -> + ?host:Uri.t -> ?body:Body.t -> Request.t -> (Response.t * Body.t) Deferred.t diff --git a/lib/request.ml b/lib/request.ml index 704aa8f002..dbd8fca783 100644 --- a/lib/request.ml +++ b/lib/request.ml @@ -19,16 +19,29 @@ open Sexplib.Std type t = { headers: Header.t; meth: Code.meth; - uri: Uri.t; + path: string; version: Code.version; encoding: Transfer.encoding; } [@@deriving fields, sexp] +let fixed_zero = Transfer.Fixed Int64.zero + +let guess_encoding ?(encoding=fixed_zero) headers = + match Header.get_content_range headers with + | Some clen -> Transfer.Fixed clen + | None -> encoding + let make ?(meth=`GET) ?(version=`HTTP_1_1) ?encoding ?headers uri = let headers = match headers with | None -> Header.init () | Some h -> h in + let headers = + Header.add_unless_exists headers "host" + (Uri.host_with_default ~default:"localhost" uri ^ + match Uri.port uri with + | Some p -> ":" ^ string_of_int p + | None -> "") in let headers = (* Add user:password auth to headers from uri * if headers don't already have auth *) @@ -36,20 +49,9 @@ let make ?(meth=`GET) ?(version=`HTTP_1_1) ?encoding ?headers uri = | None, Some user, Some pass -> let auth = `Basic (user, pass) in Header.add_authorization headers auth - | _, _, _ -> headers - in - let encoding = - (* Check for a content-length in the supplied headers first *) - match Header.get_content_range headers with - | Some clen -> Transfer.Fixed clen - | None -> begin - (* Otherwise look for an API-level encoding specification *) - match encoding with - | None -> Transfer.Fixed Int64.zero - | Some e -> e - end - in - { meth; version; headers; uri; encoding } + | _, _, _ -> headers in + let encoding = guess_encoding ?encoding headers in + { meth; version; headers; path=(Uri.path_and_query uri); encoding } let is_keep_alive { version; headers; _ } = not (version = `HTTP_1_0 || @@ -72,6 +74,53 @@ let make_for_client ?headers ?(chunked=true) ?(body_length=Int64.zero) meth uri let pp_hum ppf r = Format.fprintf ppf "%s" (r |> sexp_of_t |> Sexplib.Sexp.to_string_hum) +(* Validate path when reading URI. Implemented for compatibility with old + implementation rather than efficiency *) +let is_valid_uri path meth = + path = "*" || meth = `CONNECT || + (match Uri.scheme (Uri.of_string path) with + | Some _ -> true + | None -> not (String.length path > 0 && path.[0] <> '/')) + +let uri { path ; headers ; meth ; _ } = + match path with + | "*" -> + begin match Header.get headers "host" with + | None -> Uri.of_string "" + | Some host -> + let host_uri = Uri.of_string ("//"^host) in + let uri = Uri.(with_host (of_string "") (host host_uri)) in + Uri.(with_port uri (port host_uri)) + end + | authority when meth = `CONNECT -> Uri.of_string ("//" ^ authority) + | path -> + let uri = Uri.of_string path in + begin match Uri.scheme uri with + | Some _ -> (* we have an absoluteURI *) + Uri.(match path uri with "" -> with_path uri "/" | _ -> uri) + | None -> + let empty = Uri.of_string "" in + let empty_base = Uri.of_string "///" in + let pqs = match Stringext.split ~max:2 path ~on:'?' with + | [] -> empty_base + | [path] -> + Uri.resolve "http" empty_base (Uri.with_path empty path) + | path::qs::_ -> + let path_base = + Uri.resolve "http" empty_base (Uri.with_path empty path) + in + Uri.with_query path_base (Uri.query_of_encoded qs) + in + let uri = match Header.get headers "host" with + | None -> Uri.(with_scheme (with_host pqs None) None) + | Some host -> + let host_uri = Uri.of_string ("//"^host) in + let uri = Uri.with_host pqs (Uri.host host_uri) in + Uri.with_port uri (Uri.port host_uri) + in + uri + end + type tt = t module Make(IO : S.IO) = struct type t = tt @@ -98,62 +147,17 @@ module Make(IO : S.IO) = struct end | None -> return `Eof - let return_request headers meth uri version = - let encoding = Header.get_transfer_encoding headers in - return (`Ok { headers; meth; uri; version; encoding }) - let read ic = parse_request_fst_line ic >>= function | `Eof -> return `Eof | `Invalid reason as r -> return r - | `Ok (meth, "*", version) -> - Header_IO.parse ic >>= fun headers -> - let uri = match Header.get headers "host" with - | None -> Uri.of_string "" - | Some host -> - let host_uri = Uri.of_string ("//"^host) in - let uri = Uri.(with_host (of_string "") (host host_uri)) in - Uri.(with_port uri (port host_uri)) - in - return_request headers meth uri version - | `Ok (`CONNECT as meth, authority, version) -> - Header_IO.parse ic >>= fun headers -> - let uri = Uri.of_string ("//"^authority) in - return_request headers meth uri version - | `Ok (meth, request_uri_s, version) -> - Header_IO.parse ic >>= fun headers -> - let uri = Uri.of_string request_uri_s in - match Uri.scheme uri with - | Some _ -> (* we have an absoluteURI *) - let uri = Uri.( - match path uri with "" -> with_path uri "/" | _ -> uri - ) in - return_request headers meth uri version - | None -> - let len = String.length request_uri_s in - if len > 0 && String.get request_uri_s 0 <> '/' - then return (`Invalid "bad request URI") - else - let empty = Uri.of_string "" in - let empty_base = Uri.of_string "///" in - let pqs = match Stringext.split ~max:2 request_uri_s ~on:'?' with - | [] -> empty_base - | [path] -> - Uri.resolve "http" empty_base (Uri.with_path empty path) - | path::qs::_ -> - let path_base = - Uri.resolve "http" empty_base (Uri.with_path empty path) - in - Uri.with_query path_base (Uri.query_of_encoded qs) - in - let uri = match Header.get headers "host" with - | None -> Uri.(with_scheme (with_host pqs None) None) - | Some host -> - let host_uri = Uri.of_string ("//"^host) in - let uri = Uri.with_host pqs (Uri.host host_uri) in - Uri.with_port uri (Uri.port host_uri) - in - return_request headers meth uri version + | `Ok (meth, path, version) -> + if is_valid_uri path meth then + Header_IO.parse ic >>= fun headers -> + let encoding = Header.get_transfer_encoding headers in + return (`Ok { headers; meth; path; version; encoding }) + else + return (`Invalid "bad request URI") (* Defined for method types in RFC7231 *) let has_body req = @@ -169,14 +173,9 @@ module Make(IO : S.IO) = struct let fst_line = Printf.sprintf "%s %s %s\r\n" (Code.string_of_method req.meth) - (Uri.path_and_query req.uri) + (if req.path = "" then "/" else req.path) (Code.string_of_version req.version) in - let headers = Header.add_unless_exists req.headers "host" - (Uri.host_with_default ~default:"localhost" req.uri ^ - match Uri.port req.uri with - | Some p -> ":" ^ string_of_int p - | None -> "" - ) in + let headers = req.headers in let headers = match has_body req with | `Yes | `Unknown -> Header.add_transfer_encoding headers req.encoding diff --git a/lib/s.mli b/lib/s.mli index f684e654a5..55175cbeeb 100644 --- a/lib/s.mli +++ b/lib/s.mli @@ -84,7 +84,7 @@ module type Request = sig type t = { headers: Header.t; (** HTTP request headers *) meth: Code.meth; (** HTTP request method *) - uri: Uri.t; (** Full HTTP request uri *) + path: string; (** Request path and query *) version: Code.version; (** HTTP version, usually 1.1 *) encoding: Transfer.encoding; (** transfer encoding of this HTTP request *) } [@@deriving fields, sexp] @@ -95,6 +95,8 @@ module type Request = sig (** Return true whether the connection should be reused *) val is_keep_alive : t -> bool + val uri : t -> Uri.t + val make_for_client: ?headers:Header.t -> ?chunked:bool ->