From 93f2704bf6e50c6db38a4c49b3f985b2f59e18db Mon Sep 17 00:00:00 2001 From: Adrian Hesketh Date: Mon, 15 Mar 2021 20:38:56 +0000 Subject: [PATCH] feat: allow configuration of read and write timeout (#9) --- cmd/main.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cmd/main.go b/cmd/main.go index cb0bcc6..8d72ddf 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -12,6 +12,7 @@ import ( "os/signal" "strings" "syscall" + "time" "github.com/a-h/gemini" ) @@ -73,6 +74,8 @@ func request(args []string) { verboseFlag := cmd.Bool("verbose", false, "Print both headers and body.") headersFlag := cmd.Bool("headers", false, "Print only the headers.") allowBinaryFlag := cmd.Bool("allowBinary", false, "Set to true to enable printing binary to the console.") + readTimeoutFlag := cmd.Duration("readTimeout", time.Second*5, "Set the duration, e.g. 1m or 5s (default is 5 seconds).") + writeTimeoutFlag := cmd.Duration("writeTimeout", time.Second*5, "Set the duration, e.g. 1m or 5s (default is 5 seconds).") helpFlag := cmd.Bool("help", false, "Print help and exit.") err := cmd.Parse(args) if err != nil || *helpFlag { @@ -91,6 +94,8 @@ func request(args []string) { } client := gemini.NewClient() + client.ReadTimeout = *readTimeoutFlag + client.WriteTimeout = *writeTimeoutFlag if *insecureFlag { client.Insecure = true } @@ -164,6 +169,8 @@ func serve(args []string) { domainFlag := cmd.String("domain", "localhost", "The domain to listen on.") pathFlag := cmd.String("path", ".", "Path containing content.") portFlag := cmd.Int("port", 1965, "Address to listen on.") + readTimeoutFlag := cmd.Duration("readTimeout", time.Second*5, "Set the duration, e.g. 1m or 5s (default is 5 seconds).") + writeTimeoutFlag := cmd.Duration("writeTimeout", time.Second*10, "Set the duration, e.g. 1m or 5s (default is 10 seconds).") helpFlag := cmd.Bool("help", false, "Print help and exit.") err := cmd.Parse(args) if err != nil || *helpFlag { @@ -189,6 +196,8 @@ func serve(args []string) { *domainFlag: dh, } server := gemini.NewServer(ctx, fmt.Sprintf(":%d", *portFlag), domainToHandler) + server.ReadTimeout = *readTimeoutFlag + server.WriteTimeout = *writeTimeoutFlag err = server.ListenAndServe() if err != nil { fmt.Printf("error: %v\n", err)