diff --git a/config.ml b/config.ml index b4f9bf2..22c3aa1 100644 --- a/config.ml +++ b/config.ml @@ -3,27 +3,39 @@ open Mirage let packages = [ - Functoria.package "letsencrypt"; - Functoria.package "uri"; - Functoria.package ~sublibs:["kv"] "chamelon"; - Functoria.package ~sublibs:["ocaml"] "digestif"; - Functoria.package ~sublibs:["ocaml"] "checkseum"; + package "uri"; + package ~sublibs:["ocaml"] "digestif"; + package ~sublibs:["ocaml"] "checkseum"; + package "paf"; + package "paf-le"; + package "paf" ~sublibs:[ "mirage" ]; + package "multipart_form-lwt"; ] -let stack = generic_stackv4v6 default_network -let conduit = conduit_direct ~tls:true stack -let http_srv = cohttp_server conduit -let http_client_imp = cohttp_client (resolver_dns stack) conduit -let block_imp = block_of_file "url-shortener-db.img" - let host = let doc = Key.Arg.info ~doc:"Fully-qualified domain name for the server. Certificates will be requested from Let's Encrypt for this name." ["host"] in Key.(create "host" Arg.(required string doc)) -let keys = List.map Key.abstract [ host ] +let tls = + let doc = Key.Arg.info ~doc:"Bootstrap with a Let's encrypt certificate and an HTTPS server." ["tls"] in + Key.(create "tls" Arg.(opt bool false doc)) + +let port = + let doc = Key.Arg.info ~doc:"Port where the HTTP(S) must listen." ["port"] in + Key.(create "port" Arg.(opt (some int) None doc)) + +let program_block_size = + let doc = Key.Arg.info ~doc:"Program block size." [ "program-block-size" ] in + Key.(create "program_block_size" Arg.(opt int 16 doc)) + +let keys = [ Key.v host; Key.v tls; Key.v port ] let main = - foreign ~packages ~keys "Shortener.Main" (block @-> pclock @-> time @-> http @-> http_client @-> job) + foreign ~packages ~keys "Shortener.Main" (kv_rw @-> pclock @-> time @-> stackv4v6 @-> dns_client @-> job) + +let stack = generic_stackv4v6 default_network +let block = chamelon ~program_block_size (block_of_file "db") +let dns = generic_dns_client stack let () = - register "shortener" [ main $ block_imp $ default_posix_clock $ default_time $ http_srv $ http_client_imp ] + register "shortener" [ main $ block $ default_posix_clock $ default_time $ stack $ dns ] diff --git a/le.ml b/le.ml deleted file mode 100644 index 3be30ee..0000000 --- a/le.ml +++ /dev/null @@ -1,116 +0,0 @@ -module Shim(Cohttp_client : Cohttp_lwt.S.Client) = struct - module Headers = Cohttp.Header - module Response = struct - include Cohttp.Response - let status t = Cohttp.Response.status t |> Cohttp.Code.code_of_status - end - module Body = Cohttp_lwt__.Body - include Cohttp_client - - -end - -module Make - (Time : Mirage_time.S) - (Http_server : Cohttp_mirage.Server.S) - (Http_client : Cohttp_lwt.S.Client) -= struct - module Http_client_shim = Shim(Http_client) - module Acme = Letsencrypt.Client.Make(Http_client_shim) - - let http_port = 80 - let https_port = 443 - - let cn host = X509.[Distinguished_name.(Relative_distinguished_name.singleton (CN host))] - - let csr host key = - X509.Signing_request.create (cn host) key - - let prefix = ".well-known", "acme-challenge" - let tokens = Hashtbl.create 1 - - let solver _host ~prefix:_ ~token ~content = - Hashtbl.replace tokens token content; - Lwt.return (Ok ()) - - (* It's important (more so than normal) that this function terminate, - * because we call it with Lwt.async later *) - let letsencrypt_dispatch request _body = - let path = Uri.path (Cohttp.Request.uri request) in - Logs.debug (fun m -> m "let's encrypt dispatcher %s" path); - (* we expect very particular incoming requests from the LE web client. - * Only if the incoming URI matches the right form should we - * even check to see whether the token's in the store. *) - match Astring.String.cuts ~sep:"/" ~empty:false path with - | [p1; p2; token] when - String.equal p1 (fst prefix) && String.equal p2 (snd prefix) -> begin - (* anyone trying .well-known/acme-challenge/not-the-token gets a 404 *) - match Hashtbl.find_opt tokens token with - | None -> Http_server.respond ~status:`Not_found ~body:`Empty () - | Some data -> - let headers = - Cohttp.Header.init_with "content-type" "application/octet-stream" - in - (* respond to the challenge with the data we have available *) - Http_server.respond ~headers ~status:`OK ~body:(`String data) () - - end - | _ -> - (* TODO: we could refer this to another dispatcher, - * which might know what to do *) - Http_server.respond ~status:`Not_found ~body:`Empty () - - let provision_certificate host ctx = - let open Lwt_result.Infix in - let endpoint = - (* the example code contains a switch here for a production key, - * so we can use Letsencrypt.letsencrypt_production_url - * or the staging one as appropriate. - * We test in prod ;) - *) - Letsencrypt.letsencrypt_production_url - in - - (* email and seed are provided arguments in the example code; - * let's see if we can get by without them *) - - (* the example code does some contortions to inject the seed - * here if it's been provided. We DGAF so just let generate - * handle it. *) - let priv = `RSA (Mirage_crypto_pk.Rsa.generate ~bits:4096 ()) in - match csr host priv with - | Error (`Msg err) -> - Logs.err (fun m -> m "couldn't create signing request for our key: %s" err); - (* The choice to `exit` here is debatable - we could return and serve on HTTP only *) - exit 1 - | Ok csr -> - let http_connection_pk = Mirage_crypto_pk.Rsa.generate ~bits:4096 () in - Acme.initialise ~ctx ~endpoint (`RSA http_connection_pk) >>= fun lets_encrypt -> - let sleep sec = Time.sleep_ns (Duration.of_sec sec) in - let solver = Letsencrypt.Client.http_solver solver in - Acme.sign_certificate ~ctx solver lets_encrypt sleep csr >|= - fun certs -> - `Single (certs, priv) - - let serve cb = - let callback _ = cb - and conn_closed _ = () - in - Http_server.make ~conn_closed ~callback () - - let rec provision host http_server_impl http_client = - let open Lwt.Infix in - Logs.info (fun m -> m "listening on tcp/%d for Let's Encrypt provisioning" http_port); - (* "this should be cancelled once certificates are retrieved", - * says the source material *) - let letsencrypt_http_server = http_server_impl (`TCP http_port) @@ serve letsencrypt_dispatch in - Lwt.dont_wait (fun () -> letsencrypt_http_server) (fun _ex -> ()); - provision_certificate host http_client >>= function - | Error (`Msg s) -> Logs.err (fun f -> f "error provisioning TLS certificate: %s" s); - (* Since the error may be transient, wait a bit and try again *) - Time.sleep_ns (Duration.of_min 15) >>= fun () -> - provision host http_server_impl http_client - | Ok certificates -> - Lwt.return certificates - -end diff --git a/shortener.ml b/shortener.ml index ff587e4..18f649a 100644 --- a/shortener.ml +++ b/shortener.ml @@ -1,66 +1,74 @@ open Lwt.Infix +open Lwt.Syntax module Webapp (Clock : Mirage_clock.PCLOCK) - (KV : Mirage_kv.RW) - (H : Cohttp_mirage.Server.S) = struct + (KV : Mirage_kv.RW) = struct + open Httpaf let reserved = [ "/uptime"; "/new"; "/status"; "/"; "/favicon.ico" ] - let not_found = ( - Cohttp.Response.make ~status:Cohttp.Code.(`Not_found) (), - Cohttp_lwt__.Body.of_string "Not found") - - let ise = ( - Cohttp.Response.make ~status:Cohttp.Code.(`Internal_server_error) (), - Cohttp_lwt__.Body.of_string "Internal server error") - - let bad_request = ( - Cohttp.Response.make ~status:Cohttp.Code.(`Bad_request) (), - Cohttp_lwt__.Body.of_string "Bad request") - - let slash = - let form = "
" + let not_found = + let headers = Headers.of_list [ "connection", "close" ] in + Response.create ~headers `Not_found, None + + let ise = + let headers = Headers.of_list [ "connection", "close" ] in + Response.create ~headers `Internal_server_error, None + + let bad_request = + let headers = Headers.of_list [ "connection", "close" ] in + Response.create ~headers `Bad_request, None + + let slash reqd = + let form = +{html| + + + + + +|html} in - (Cohttp.Response.make ~status:Cohttp.Code.(`OK) (), - Cohttp_lwt__.Body.of_string form) + let headers = Headers.of_list + [ "content-length", string_of_int (String.length form) + ; "content-type", "text/html; charset=utf-8" ] in + Response.create ~headers `OK, Some form let uptime start_time = - let response = Cohttp.Response.make ~status:Cohttp.Code.(`OK) () in let span = Ptime.Span.sub (Ptime.Span.v @@ Pclock.now_d_ps ()) (Ptime.to_span start_time) in - let s = Format.asprintf "%a (since %a)" Ptime.Span.pp span (Ptime.pp_human ()) start_time in - (response, Cohttp_lwt__.Body.of_string s) + let contents = Fmt.str "%a (since %a)" Ptime.Span.pp span (Ptime.pp_human ()) start_time in + let headers = Headers.of_list + [ "content-type", "text/plain" + ; "content-length", string_of_int (String.length contents) ] in + Response.create ~headers `OK, Some contents - let get_reserved start_time path = + let get_reserved start_time path reqd = if String.equal path "/uptime" then uptime start_time else if String.equal path "/favicon.ico" then not_found - else if String.equal path "/" then slash + else if String.equal path "/" then slash reqd else not_found let get_from_database kv path = KV.get kv @@ Mirage_kv.Key.v path >>= function | Error (`Not_found k) -> - let response = Cohttp.Response.make ~status:Cohttp.Code.(`Not_found) () in - let body = Cohttp_lwt__.Body.of_string "Not found" in - Lwt.return (response, body) + Lwt.return not_found | Error e -> Logs.err (fun f -> f "error %a fetching a key from the database" KV.pp_error e); - let response = Cohttp.Response.make ~status:Cohttp.Code.(`Internal_server_error) () in - let body = Cohttp_lwt__.Body.of_string "Internal Server Error" in - Lwt.return (response, body) + Lwt.return ise | Ok data -> - let loc = Cohttp.Header.init_with "Location" data in - let response = Cohttp.Response.make ~status:Cohttp.Code.(`Temporary_redirect) ~headers:loc () in - let body = Cohttp_lwt__.Body.empty in - Lwt.return (response, body) + let headers = Headers.of_list + [ "location", data + ; "content-length", "0" ] in + Lwt.return (Response.create ~headers `Temporary_redirect, None) let validate_uri ~hostname s = let uri = Uri.of_string s |> Uri.canonicalize in @@ -98,9 +106,11 @@ module Webapp Logs.err (fun f -> f "error %a trying to post a url" KV.pp_error e); Lwt.return ise | Ok (Some _) -> - let response = Cohttp.Response.make ~status:Cohttp.Code.(`Conflict) () in - let body = Cohttp_lwt__.Body.of_string "there's already a URL set there. Try choosing another" in - Lwt.return (response, body) + let contents = "There's already a URL set there. Try choosing another." in + let headers = Headers.of_list + [ "content-type", "text/plain" + ; "content-length", string_of_int (String.length contents) ] in + Lwt.return (Response.create ~headers `Conflict, Some contents) | Ok None -> match validate_uri ~hostname url with | None -> @@ -113,76 +123,172 @@ module Webapp Lwt.return ise | Ok () -> Logs.debug (fun f -> f "set a new key"); - let response = Cohttp.Response.make ~status:Cohttp.Code.(`Created) () in - let response_body = Cohttp_lwt__.Body.of_string "Success! Your shortcut has been created." in - Lwt.return (response, response_body) + let contents = "Success! Your shortcut has been created." in + let headers = Headers.of_list + [ "content-type", "text/plain" + ; "content-length", string_of_int (String.length contents) ] in + Lwt.return (Response.create ~headers `Created, Some contents) end - let reply (kv : KV.t) hostname start_time = - let callback _connection request body = - match Cohttp.Request.meth request with - | `GET -> begin - let path = Uri.path @@ Uri.canonicalize @@ Cohttp.Request.uri request in - if List.mem path reserved then Lwt.return @@ get_reserved start_time path - else get_from_database kv path - end - | `POST -> begin - Cohttp_lwt__.Body.to_form body >>= fun form_entries -> - match List.assoc_opt "short_name" form_entries, List.assoc_opt "url" form_entries with - | Some (path::[]), Some (url::[]) when not @@ List.mem path reserved -> maybe_set kv hostname path url - | _, _ -> Lwt.return bad_request - end - | _ -> - let response = Cohttp.Response.make ~status:Cohttp.Code.(`Method_not_allowed) () in - let body = Cohttp_lwt__.Body.empty in - Lwt.return (response, body) - in - H.make ~conn_closed:(fun _ -> ()) ~callback () + let stream_of_body body = + let stream, push = Lwt_stream.create () in + let rec on_eof () = push None + and on_read buf ~off ~len = + push (Some (Bigstringaf.substring buf ~off ~len)) ; + Body.schedule_read body ~on_eof ~on_read in + Body.schedule_read body ~on_eof ~on_read ; + stream + + let identify header = + let open Multipart_form in + let ( >>= ) = Option.bind in + let ( >>| ) x f = Option.map f x in + Header.content_disposition header + >>= Content_disposition.name + >>| String.lowercase_ascii + >>= function + | "url" -> Some `Url + | "short_name" -> Some `Short_name + | _ -> None + + let post kv hostname reqd request = + let headers = request.Request.headers in + match Headers.get headers "content-type" with + | None -> Lwt.return bad_request + | Some str -> + match Multipart_form.Content_type.of_string (str ^ "\r\n") with + | Error (`Msg _err) -> Lwt.return bad_request + | Ok content_type -> + let body = Reqd.request_body reqd in + let stream = stream_of_body body in + let `Parse th, stream = Multipart_form_lwt.stream ~identify + stream content_type in + th >>= fun result -> + Body.close_reader body ; + ( match result with + | Error _ -> Lwt.return ise + | Ok _tree -> + Lwt_stream.to_list stream + >>= Lwt_list.filter_map_p (fun (id, headers, stream) -> + Lwt_stream.to_list stream >|= String.concat "" >>= fun contents -> + match id with + | None -> Lwt.return_none + | Some `Short_name -> Lwt.return_some (`Short_name, contents) + | Some `Url -> Lwt.return_some (`Url, contents)) >>= fun bindings -> + match List.assoc_opt `Short_name bindings, + List.assoc_opt `Url bindings with + | Some path, Some url when not (List.mem path reserved) -> + maybe_set kv hostname path url + | _ -> Lwt.return bad_request ) + + let reply (kv : KV.t) hostname start_time (_ipaddr, _port) reqd = + let res () = + Lwt.catch begin fun () -> + let request = Reqd.request reqd in + let path = + Uri.of_string request.Request.target + |> Uri.canonicalize + |> Uri.path in + let* response, body = match request.Request.meth with + | `GET when List.mem path reserved -> + Lwt.return (get_reserved start_time path reqd) + | `GET -> get_from_database kv path + | `POST -> post kv hostname reqd request + | _ -> + let headers = Headers.of_list [ "connection", "close" ] in + Lwt.return (Response.create ~headers `Method_not_allowed, None) in + let body = Option.value ~default:"" body in + Reqd.respond_with_string reqd response body ; + Lwt.return_unit + end @@ fun exn -> + let res = Printexc.to_string exn in + let headers = Headers.of_list + [ "content-length", string_of_int (String.length res) ] in + let response = Response.create ~headers `Internal_server_error in + Reqd.respond_with_string reqd response res ; + Lwt.return_unit in + Lwt.async res end module Main - (Block : Mirage_block.S) + (Database : Mirage_kv.RW) (Clock : Mirage_clock.PCLOCK) (Time : Mirage_time.S) - (Http : Cohttp_mirage.Server.S) - (Client : Cohttp_lwt.S.Client) + (Stack : Tcpip.Stack.V4V6) + (DNS : Dns_client_mirage.S with type Transport.stack = Stack.t) = struct module Logs_reporter = Mirage_logs.Make(Clock) - module LE = Le.Make(Time)(Http)(Client) - module Database = Kv.Make(Block)(Clock) - module Shortener = Webapp(Clock)(Database)(Http) + module Paf = Paf_mirage.Make(Time)(Stack.TCP) + module LE = LE.Make(Time)(Stack) + module Shortener = Webapp(Clock)(Database) + module Nss = Ca_certs_nss.Make(Pclock) + + let ignore_error_handler _ ?request:_ _ _ = () - let start block pclock _time http_server http_client = + let get_certificates cfg stack dns = + Paf.init ~port:80 (Stack.tcp stack) >>= fun t -> + let service = Paf.http_service ~error_handler:ignore_error_handler + (fun _flow -> LE.request_handler) in + let stop = Lwt_switch.create () in + let `Initialized th0 = Paf.serve ~stop service t in + let th1 = + let gethostbyname dns domain_name = + DNS.gethostbyname dns domain_name >>= function + | Ok ipv4 -> Lwt.return_ok (Ipaddr.V4 ipv4) + | Error _ as err -> Lwt.return err in + let authenticator = Result.get_ok (Nss.authenticator ()) in + LE.provision_certificate + ~production:true cfg + (LE.ctx ~gethostbyname ~authenticator dns stack) >>= fun res -> + Lwt_switch.turn_off stop >>= fun () -> Lwt.return res in + Lwt.both th0 th1 >>= function + | ((), Error (`Msg err)) -> failwith err + | ((), Ok certificates) -> Lwt.return certificates + + let start kv pclock _time stack dns = let open Lwt.Infix in let start_time = Ptime.v @@ Pclock.now_d_ps () in let host = Key_gen.host () in Logs_reporter.(create pclock |> run) @@ fun () -> (* solo5 requires us to use a block size of, at maximum, 512 *) - Database.connect ~program_block_size:16 ~block_size:512 block >>= function + (* Database.connect ~program_block_size:16 ~block_size:512 block >>= function | Error e -> Logs.err (fun f -> f "failed to initialize block-backed key-value store: %a" Database.pp_error e); Lwt.return_unit - | Ok kv -> + | Ok kv -> *) Logs.info (fun f -> f "block-backed key-value store up and running"); - let rec provision () = - LE.provision host http_server http_client >>= fun certificates -> - Logs.info (fun f -> f "got certificates from let's encrypt via acme"); - let tls_cfg = Tls.Config.server ~certificates () in - let tls = `TLS (tls_cfg, `TCP 443) in - let tcp = `TCP 80 in - let https = - Logs.info (fun f -> f "(re-)initialized https listener"); - http_server tls @@ Shortener.reply kv host start_time - in - let http = - Logs.info (fun f -> f "overwriting Let's Encrypt http listener with ours"); - http_server tcp @@ Shortener.reply kv host start_time - in - let expire = Time.sleep_ns @@ Duration.of_day 80 in - Lwt.pick [ - https - ; http - ; expire] >>= fun () -> + match Key_gen.tls () with + | false -> + let request_handler _flow = Shortener.reply kv host start_time in + let port = Option.value ~default:80 (Key_gen.port ()) in + Paf.init ~port (Stack.tcp stack) >>= fun t -> + let service = Paf.http_service ~error_handler:ignore_error_handler + request_handler in + let `Initialized th = Paf.serve service t in + th + | true -> + let cfg = + { LE.certificate_seed= None + ; LE.certificate_key_type= `RSA + ; LE.certificate_key_bits= None + ; LE.email= None + ; LE.account_seed= None + ; LE.account_key_type= `RSA + ; LE.account_key_bits= None + ; LE.hostname= Key_gen.host () + |> Domain_name.of_string_exn + |> Domain_name.host_exn } in + let rec provision () = + get_certificates cfg stack dns >>= fun certificates -> + let tls = Tls.Config.server ~certificates () in + let request_handler _flow = Shortener.reply kv host start_time in + let port = Option.value ~default:443 (Key_gen.port ()) in + Paf.init ~port (Stack.tcp stack) >>= fun service -> + let https = Paf.https_service ~tls ~error_handler:ignore_error_handler + request_handler in + let stop = Lwt_switch.create () in + let `Initialized th0 = Paf.serve ~stop https service in + let expire () = Time.sleep_ns (Duration.of_day 80) >>= fun () -> + Lwt_switch.turn_off stop in + Lwt.pick [ th0; expire () ] >>= provision in provision () - in - provision () end