From 346ef4975d844ad95f50214d74c6cd85fdc2e5f8 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 18 Jun 2021 11:57:42 +0200 Subject: [PATCH 1/5] gzhttp: Add zstd to transport --- gzhttp/transport.go | 54 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/gzhttp/transport.go b/gzhttp/transport.go index 85abe9d0bd..79686b78d3 100644 --- a/gzhttp/transport.go +++ b/gzhttp/transport.go @@ -10,17 +10,19 @@ import ( "sync" "github.com/klauspost/compress/gzip" + "github.com/klauspost/compress/zstd" ) -// Transport will wrap a transport with a custom gzip handler +// Transport will wrap a transport with a custom handler // that will request gzip and automatically decompress it. // Using this is significantly faster than using the default transport. func Transport(parent http.RoundTripper) http.RoundTripper { - return gzRoundtripper{parent: parent} + return gzRoundtripper{parent: parent, withZstd: true} } type gzRoundtripper struct { - parent http.RoundTripper + parent http.RoundTripper + withZstd bool } func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -41,7 +43,12 @@ func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) { // auto-decoding a portion of a gzipped document will just fail // anyway. See https://golang.org/issue/8923 requestedGzip = true - req.Header.Set("Accept-Encoding", "gzip") + if g.withZstd { + // Swap when we want zstd to default. + req.Header.Set("Accept-Encoding", "gzip,zstd") + } else { + req.Header.Set("Accept-Encoding", "gzip") + } } resp, err := g.parent.RoundTrip(req) if err != nil || !requestedGzip { @@ -54,6 +61,16 @@ func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) { resp.ContentLength = -1 resp.Uncompressed = true } + if g.withZstd { + if asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") { + resp.Body = &zstdReader{body: resp.Body} + resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Length") + resp.ContentLength = -1 + resp.Uncompressed = true + } + } + return resp, nil } @@ -114,3 +131,32 @@ func lower(b byte) byte { } return b } + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +type zstdReader struct { + body io.ReadCloser // underlying HTTP/1 response body framing + zr *zstd.Decoder // lazily-initialized gzip reader + zerr error // any error from zstd.NewReader; sticky +} + +func (zr *zstdReader) Read(p []byte) (n int, err error) { + if zr.zr == nil { + if zr.zerr == nil { + zr.zr, zr.zerr = zstd.NewReader(zr.body, zstd.WithDecoderLowmem(true), zstd.WithDecoderMaxWindow(32<<20)) + } + if zr.zerr != nil { + return 0, zr.zerr + } + } + + return zr.zr.Read(p) +} + +func (zr *zstdReader) Close() error { + if zr.zr != nil { + zr.zr.Close() + zr.zr = nil + } + return zr.body.Close() +} From 08f591d5dda8ec295fb6bc157806575f59a0713f Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Mon, 28 Feb 2022 10:37:37 +0100 Subject: [PATCH 2/5] Add zstd decoder pool. --- gzhttp/transport.go | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/gzhttp/transport.go b/gzhttp/transport.go index 79686b78d3..18cbcad4fa 100644 --- a/gzhttp/transport.go +++ b/gzhttp/transport.go @@ -45,7 +45,7 @@ func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) { requestedGzip = true if g.withZstd { // Swap when we want zstd to default. - req.Header.Set("Accept-Encoding", "gzip,zstd") + req.Header.Set("Accept-Encoding", "zstd,gzip") } else { req.Header.Set("Accept-Encoding", "gzip") } @@ -132,7 +132,10 @@ func lower(b byte) byte { return b } -// gzipReader wraps a response body so it can lazily +// zstdReaderPool pools zstd decoders. +var zstdReaderPool sync.Pool + +// zstdReader wraps a response body so it can lazily // call gzip.NewReader on the first call to Read type zstdReader struct { body io.ReadCloser // underlying HTTP/1 response body framing @@ -143,7 +146,13 @@ type zstdReader struct { func (zr *zstdReader) Read(p []byte) (n int, err error) { if zr.zr == nil { if zr.zerr == nil { - zr.zr, zr.zerr = zstd.NewReader(zr.body, zstd.WithDecoderLowmem(true), zstd.WithDecoderMaxWindow(32<<20)) + reader, ok := zstdReaderPool.Get().(*zstd.Decoder) + if ok { + zr.zerr = reader.Reset(zr.body) + zr.zr = reader + } else { + zr.zr, zr.zerr = zstd.NewReader(zr.body, zstd.WithDecoderLowmem(true), zstd.WithDecoderMaxWindow(32<<20), zstd.WithDecoderConcurrency(1)) + } } if zr.zerr != nil { return 0, zr.zerr @@ -155,7 +164,7 @@ func (zr *zstdReader) Read(p []byte) (n int, err error) { func (zr *zstdReader) Close() error { if zr.zr != nil { - zr.zr.Close() + zstdReaderPool.Put(zr.zr) zr.zr = nil } return zr.body.Close() From 872e8bf04100f85533635737e4ae3416a845c5f2 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Mon, 28 Feb 2022 10:49:11 +0100 Subject: [PATCH 3/5] Clean up early. --- gzhttp/transport.go | 16 ++++++++++++++-- zstd/decoder.go | 1 + 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/gzhttp/transport.go b/gzhttp/transport.go index 18cbcad4fa..1d2570b488 100644 --- a/gzhttp/transport.go +++ b/gzhttp/transport.go @@ -144,6 +144,9 @@ type zstdReader struct { } func (zr *zstdReader) Read(p []byte) (n int, err error) { + if zr.zerr != nil { + return 0, zr.zerr + } if zr.zr == nil { if zr.zerr == nil { reader, ok := zstdReaderPool.Get().(*zstd.Decoder) @@ -158,12 +161,21 @@ func (zr *zstdReader) Read(p []byte) (n int, err error) { return 0, zr.zerr } } - - return zr.zr.Read(p) + n, err = zr.zr.Read(p) + if err != nil { + // Usually this will be io.EOF, + // stash the decoder and keep the error. + zr.zr.Reset(nil) + zstdReaderPool.Put(zr.zr) + zr.zr = nil + zr.zerr = err + } + return } func (zr *zstdReader) Close() error { if zr.zr != nil { + zr.zr.Reset(nil) zstdReaderPool.Put(zr.zr) zr.zr = nil } diff --git a/zstd/decoder.go b/zstd/decoder.go index b6f29a5335..a93dfaf100 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -176,6 +176,7 @@ func (d *Decoder) Reset(r io.Reader) error { d.drainOutput() + d.syncStream.br.r = nil if r == nil { d.current.err = ErrDecoderNilInput if len(d.current.b) > 0 { From 7d7d99fc8c183cb0c70a6fa20a56d85a1f799ecc Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Mon, 28 Feb 2022 11:33:05 +0100 Subject: [PATCH 4/5] Add tests. --- gzhttp/gzip_test.go | 3 + gzhttp/transport_test.go | 137 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 132 insertions(+), 8 deletions(-) diff --git a/gzhttp/gzip_test.go b/gzhttp/gzip_test.go index c879d91866..e2556c2e2f 100644 --- a/gzhttp/gzip_test.go +++ b/gzhttp/gzip_test.go @@ -1133,6 +1133,9 @@ func newTestHandler(body []byte) http.Handler { case "/gzipped": w.Header().Set("Content-Encoding", "gzip") w.Write(body) + case "/zstd": + w.Header().Set("Content-Encoding", "zstd") + w.Write(body) default: w.Write(body) } diff --git a/gzhttp/transport_test.go b/gzhttp/transport_test.go index 4a8271e1a7..0d616e0562 100644 --- a/gzhttp/transport_test.go +++ b/gzhttp/transport_test.go @@ -14,6 +14,7 @@ import ( "testing" "github.com/klauspost/compress/gzip" + "github.com/klauspost/compress/zstd" ) func TestTransport(t *testing.T) { @@ -38,6 +39,30 @@ func TestTransport(t *testing.T) { } } +func TestTransportZstd(t *testing.T) { + bin, err := ioutil.ReadFile("testdata/benchmark.json") + if err != nil { + t.Fatal(err) + } + enc, _ := zstd.NewWriter(nil) + defer enc.Close() + zsBin := enc.EncodeAll(bin, nil) + server := httptest.NewServer(newTestHandler(zsBin)) + + c := http.Client{Transport: Transport(http.DefaultTransport)} + resp, err := c.Get(server.URL + "/zstd") + if err != nil { + t.Fatal(err) + } + got, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, bin) { + t.Errorf("data mismatch") + } +} + func TestTransportInvalid(t *testing.T) { bin, err := ioutil.ReadFile("testdata/benchmark.json") if err != nil { @@ -58,6 +83,26 @@ func TestTransportInvalid(t *testing.T) { } } +func TestTransportZstdInvalid(t *testing.T) { + bin, err := ioutil.ReadFile("testdata/benchmark.json") + if err != nil { + t.Fatal(err) + } + // Do not encode... + server := httptest.NewServer(newTestHandler(bin)) + + c := http.Client{Transport: Transport(http.DefaultTransport)} + resp, err := c.Get(server.URL + "/zstd") + if err != nil { + t.Fatal(err) + } + _, err = ioutil.ReadAll(resp.Body) + if err == nil { + t.Fatal("expected error, got nil") + } + t.Log("expected error:", err) +} + func TestDefaultTransport(t *testing.T) { bin, err := ioutil.ReadFile("testdata/benchmark.json") if err != nil { @@ -82,22 +127,31 @@ func TestDefaultTransport(t *testing.T) { } func BenchmarkTransport(b *testing.B) { - bin, err := ioutil.ReadFile("testdata/benchmark.json") + raw, err := ioutil.ReadFile("testdata/benchmark.json") if err != nil { b.Fatal(err) } - sz := len(bin) + sz := int64(len(raw)) var buf bytes.Buffer zw := gzip.NewWriter(&buf) - zw.Write(bin) + zw.Write(raw) zw.Close() - bin = buf.Bytes() + bin := buf.Bytes() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.Body.Close() w.Header().Set("Content-Encoding", "gzip") w.WriteHeader(http.StatusOK) w.Write(bin) })) + enc, _ := zstd.NewWriter(nil, zstd.WithWindowSize(128<<10), zstd.WithEncoderLevel(zstd.SpeedBestCompression)) + defer enc.Close() + zsBin := enc.EncodeAll(raw, nil) + serverZstd := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Body.Close() + w.Header().Set("Content-Encoding", "zstd") + w.WriteHeader(http.StatusOK) + w.Write(zsBin) + })) b.Run("gzhttp", func(b *testing.B) { c := http.Client{Transport: Transport(http.DefaultTransport)} @@ -109,12 +163,16 @@ func BenchmarkTransport(b *testing.B) { if err != nil { b.Fatal(err) } - _, err = io.Copy(ioutil.Discard, resp.Body) + n, err := io.Copy(ioutil.Discard, resp.Body) if err != nil { b.Fatal(err) } + if n != sz { + b.Fatalf("size mismatch: want %d, got %d", sz, n) + } resp.Body.Close() } + b.ReportMetric(100*float64(len(bin))/float64(len(raw)), "pct") }) b.Run("stdlib", func(b *testing.B) { c := http.Client{Transport: http.DefaultTransport} @@ -126,12 +184,38 @@ func BenchmarkTransport(b *testing.B) { if err != nil { b.Fatal(err) } - _, err = io.Copy(ioutil.Discard, resp.Body) + n, err := io.Copy(ioutil.Discard, resp.Body) if err != nil { b.Fatal(err) } + if n != sz { + b.Fatalf("size mismatch: want %d, got %d", sz, n) + } resp.Body.Close() } + b.ReportMetric(100*float64(len(bin))/float64(len(raw)), "pct") + }) + b.Run("zstd", func(b *testing.B) { + c := http.Client{Transport: Transport(http.DefaultTransport)} + + b.SetBytes(int64(sz)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp, err := c.Get(serverZstd.URL + "/zstd") + if err != nil { + b.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, resp.Body) + if err != nil { + b.Fatal(err) + } + if n != sz { + b.Fatalf("size mismatch: want %d, got %d", sz, n) + } + resp.Body.Close() + } + b.ReportMetric(100*float64(len(zsBin))/float64(len(raw)), "pct") }) b.Run("gzhttp-par", func(b *testing.B) { c := http.Client{ @@ -150,13 +234,17 @@ func BenchmarkTransport(b *testing.B) { if err != nil { b.Fatal(err) } - _, err = io.Copy(ioutil.Discard, resp.Body) + n, err := io.Copy(ioutil.Discard, resp.Body) if err != nil { b.Fatal(err) } + if n != sz { + b.Fatalf("size mismatch: want %d, got %d", sz, n) + } resp.Body.Close() } }) + b.ReportMetric(100*float64(len(bin))/float64(len(raw)), "pct") }) b.Run("stdlib-par", func(b *testing.B) { c := http.Client{Transport: &http.Transport{ @@ -172,12 +260,45 @@ func BenchmarkTransport(b *testing.B) { if err != nil { b.Fatal(err) } - _, err = io.Copy(ioutil.Discard, resp.Body) + n, err := io.Copy(ioutil.Discard, resp.Body) + if err != nil { + b.Fatal(err) + } + if n != sz { + b.Fatalf("size mismatch: want %d, got %d", sz, n) + } + resp.Body.Close() + } + }) + b.ReportMetric(100*float64(len(bin))/float64(len(raw)), "pct") + }) + b.Run("zstd-par", func(b *testing.B) { + c := http.Client{ + Transport: Transport(&http.Transport{ + MaxConnsPerHost: runtime.GOMAXPROCS(0), + MaxIdleConnsPerHost: runtime.GOMAXPROCS(0), + }), + } + + b.SetBytes(int64(sz)) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + resp, err := c.Get(serverZstd.URL + "/zstd") + if err != nil { + b.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, resp.Body) if err != nil { b.Fatal(err) } + if n != sz { + b.Fatalf("size mismatch: want %d, got %d", sz, n) + } resp.Body.Close() } }) + b.ReportMetric(100*float64(len(zsBin))/float64(len(raw)), "pct") }) } From 436be7c65cd845443b4fef2be0cab45eef54e4e1 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Mon, 28 Feb 2022 15:16:47 +0100 Subject: [PATCH 5/5] Add option to enable/disable individual methods. --- gzhttp/transport.go | 74 +++++++++++++++++++++++++------------ gzhttp/transport_test.go | 80 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 23 deletions(-) diff --git a/gzhttp/transport.go b/gzhttp/transport.go index 1d2570b488..a199fbc6e8 100644 --- a/gzhttp/transport.go +++ b/gzhttp/transport.go @@ -7,6 +7,7 @@ package gzhttp import ( "io" "net/http" + "strings" "sync" "github.com/klauspost/compress/gzip" @@ -16,17 +17,48 @@ import ( // Transport will wrap a transport with a custom handler // that will request gzip and automatically decompress it. // Using this is significantly faster than using the default transport. -func Transport(parent http.RoundTripper) http.RoundTripper { - return gzRoundtripper{parent: parent, withZstd: true} +func Transport(parent http.RoundTripper, opts ...transportOption) http.RoundTripper { + g := gzRoundtripper{parent: parent, withZstd: true, withGzip: true} + for _, o := range opts { + o(&g) + } + var ae []string + if g.withZstd { + ae = append(ae, "zstd") + } + if g.withGzip { + ae = append(ae, "gzip") + } + g.acceptEncoding = strings.Join(ae, ",") + return &g +} + +type transportOption func(c *gzRoundtripper) + +// TransportEnableZstd will send Zstandard as a compression option to the server. +// Enabled by default, but may be disabled if future problems arise. +func TransportEnableZstd(b bool) transportOption { + return func(c *gzRoundtripper) { + c.withZstd = b + } +} + +// TransportEnableGzip will send Gzip as a compression option to the server. +// Enabled by default. +func TransportEnableGzip(b bool) transportOption { + return func(c *gzRoundtripper) { + c.withGzip = b + } } type gzRoundtripper struct { - parent http.RoundTripper - withZstd bool + parent http.RoundTripper + acceptEncoding string + withZstd, withGzip bool } -func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) { - var requestedGzip bool +func (g *gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) { + var requestedComp bool if req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { @@ -42,33 +74,29 @@ func (g gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) { // We don't request gzip if the request is for a range, since // auto-decoding a portion of a gzipped document will just fail // anyway. See https://golang.org/issue/8923 - requestedGzip = true - if g.withZstd { - // Swap when we want zstd to default. - req.Header.Set("Accept-Encoding", "zstd,gzip") - } else { - req.Header.Set("Accept-Encoding", "gzip") - } + requestedComp = len(g.acceptEncoding) > 0 + req.Header.Set("Accept-Encoding", g.acceptEncoding) } + resp, err := g.parent.RoundTrip(req) - if err != nil || !requestedGzip { + if err != nil || !requestedComp { return resp, err } - if asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") { + + // Decompress + if g.withGzip && asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") { resp.Body = &gzipReader{body: resp.Body} resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Length") resp.ContentLength = -1 resp.Uncompressed = true } - if g.withZstd { - if asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") { - resp.Body = &zstdReader{body: resp.Body} - resp.Header.Del("Content-Encoding") - resp.Header.Del("Content-Length") - resp.ContentLength = -1 - resp.Uncompressed = true - } + if g.withZstd && asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") { + resp.Body = &zstdReader{body: resp.Body} + resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Length") + resp.ContentLength = -1 + resp.Uncompressed = true } return resp, nil diff --git a/gzhttp/transport_test.go b/gzhttp/transport_test.go index 0d616e0562..884c887ae7 100644 --- a/gzhttp/transport_test.go +++ b/gzhttp/transport_test.go @@ -39,6 +39,61 @@ func TestTransport(t *testing.T) { } } +func TestTransportForced(t *testing.T) { + raw, err := ioutil.ReadFile("testdata/benchmark.json") + if err != nil { + t.Fatal(err) + } + + var buf bytes.Buffer + zw := gzip.NewWriter(&buf) + zw.Write(raw) + zw.Close() + bin := buf.Bytes() + + server := httptest.NewServer(newTestHandler(bin)) + + c := http.Client{Transport: Transport(http.DefaultTransport)} + resp, err := c.Get(server.URL + "/gzipped") + if err != nil { + t.Fatal(err) + } + got, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, raw) { + t.Errorf("data mismatch") + } +} + +func TestTransportForcedDisabled(t *testing.T) { + raw, err := ioutil.ReadFile("testdata/benchmark.json") + if err != nil { + t.Fatal(err) + } + + var buf bytes.Buffer + zw := gzip.NewWriter(&buf) + zw.Write(raw) + zw.Close() + bin := buf.Bytes() + + server := httptest.NewServer(newTestHandler(bin)) + c := http.Client{Transport: Transport(http.DefaultTransport, TransportEnableGzip(false))} + resp, err := c.Get(server.URL + "/gzipped") + if err != nil { + t.Fatal(err) + } + got, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(bin, got) { + t.Errorf("data mismatch") + } +} + func TestTransportZstd(t *testing.T) { bin, err := ioutil.ReadFile("testdata/benchmark.json") if err != nil { @@ -83,6 +138,31 @@ func TestTransportInvalid(t *testing.T) { } } +func TestTransportZstdDisabled(t *testing.T) { + raw, err := ioutil.ReadFile("testdata/benchmark.json") + if err != nil { + t.Fatal(err) + } + + enc, _ := zstd.NewWriter(nil) + defer enc.Close() + zsBin := enc.EncodeAll(raw, nil) + + server := httptest.NewServer(newTestHandler(zsBin)) + c := http.Client{Transport: Transport(http.DefaultTransport, TransportEnableZstd(false))} + resp, err := c.Get(server.URL + "/zstd") + if err != nil { + t.Fatal(err) + } + got, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(zsBin, got) { + t.Errorf("data mismatch") + } +} + func TestTransportZstdInvalid(t *testing.T) { bin, err := ioutil.ReadFile("testdata/benchmark.json") if err != nil {