From 97f5c60a360039f7a72c716c3e37ea3d95170682 Mon Sep 17 00:00:00 2001 From: Adrian Hesketh Date: Mon, 15 Mar 2021 21:26:16 +0000 Subject: [PATCH] feat: add config for read and write timeouts (#9) --- server.go | 23 ++++++++++++++++------- server_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 7 deletions(-) diff --git a/server.go b/server.go index 61c09bf..12cbba9 100644 --- a/server.go +++ b/server.go @@ -279,6 +279,9 @@ func (srv *Server) handle(dh *DomainHandler, certificate Certificate, conn net.C log.String("handlerType", reflect.TypeOf(dh.Handler).PkgPath()), log.Int64("ms", duration.Milliseconds()), log.Int64("us", int64(duration.Microseconds())), + log.Int64("lenBody", w.WrittenBody), + log.Int("lenHeader", w.WrittenHeader), + log.Int64("len", int64(w.WrittenHeader)+w.WrittenBody), ) } @@ -309,8 +312,10 @@ func (srv *Server) parseRequest(rw io.ReadWriter) (r *Request, ok bool, err erro // Writer passed to Gemini handlers. type Writer struct { - Code string - Writer io.Writer + Code string + Writer io.Writer + WrittenHeader int + WrittenBody int64 } // NewWriter creates a new Gemini writer. @@ -332,7 +337,9 @@ func (gw *Writer) Write(p []byte) (n int, err error) { err = ErrCannotWriteBodyWithoutSuccessCode return } - return gw.Writer.Write(p) + n, err = gw.Writer.Write(p) + gw.WrittenBody += int64(n) + return } func isSuccessCode(code Code) bool { @@ -347,10 +354,13 @@ func (gw *Writer) SetHeader(code Code, meta string) (err error) { return ErrHeaderAlreadyWritten } gw.Code = string(code) - return writeHeaderToWriter(code, meta, gw.Writer) + var n int + n, err = writeHeaderToWriter(code, meta, gw.Writer) + gw.WrittenHeader += n + return } -func writeHeaderToWriter(code Code, meta string, w io.Writer) error { +func writeHeaderToWriter(code Code, meta string, w io.Writer) (n int, err error) { // // Set default meta if required. if meta == "" && isSuccessCode(code) { @@ -359,8 +369,7 @@ func writeHeaderToWriter(code Code, meta string, w io.Writer) error { if len(meta) > 1024 { meta = meta[:1024] } - _, err := w.Write([]byte(string(code) + " " + meta + "\r\n")) - return err + return w.Write([]byte(string(code) + " " + meta + "\r\n")) } // DomainHandler handles incoming requests for the ServerName using the provided KeyPair certificate diff --git a/server_test.go b/server_test.go index cf3a2ec..4d115d7 100644 --- a/server_test.go +++ b/server_test.go @@ -269,3 +269,51 @@ func (rec *Recorder) SetReadDeadline(t time.Time) error { func (rec *Recorder) SetWriteDeadline(t time.Time) error { return nil } + +func TestWriter(t *testing.T) { + var tests = []struct { + name string + write [][]byte + }{ + { + name: "single write", + write: [][]byte{ + {0, 0}, + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + buf := new(bytes.Buffer) + w := NewWriter(buf) + for i := 0; i < len(tt.write); i++ { + n, err := w.Write(tt.write[i]) + if err != nil { + t.Errorf("[%d] unexpected error writing: %v", i, err) + } + if n != len(tt.write[i]) { + t.Errorf("[%d] expected to write %d bytes, wrote %d", i, len(tt.write[i]), n) + } + } + headerAndBody := bytes.SplitN(buf.Bytes(), []byte("\r\n"), 2) + header := headerAndBody[0] + body := headerAndBody[1] + expected := bytes.Join(tt.write, nil) + if !reflect.DeepEqual(body, expected) { + t.Errorf("mismatched body, expected %x, got %x", expected, body) + } + if w.WrittenBody != int64(len(expected)) { + t.Errorf("expected the 'Written' field to be the %d bytes written to the body, but got %d", len(expected), w.WrittenBody) + } + expectedHeader := "20 " + DefaultMIMEType + "\r\n" + if w.WrittenHeader != len(expectedHeader) { + t.Errorf("expected to write header length %d, got %d", len(expectedHeader), w.WrittenHeader) + } + if string(header)+"\r\n" != expectedHeader { + t.Errorf("expected header %q, got %q", expectedHeader, header) + } + }) + } + +}