diff --git a/compress.go b/compress.go index e8345d7..e46a7bf 100644 --- a/compress.go +++ b/compress.go @@ -80,6 +80,7 @@ func CompressHandlerLevel(h http.Handler, level int) http.Handler { switch strings.TrimSpace(enc) { case "gzip": w.Header().Set("Content-Encoding", "gzip") + r.Header.Del("Accept-Encoding") w.Header().Add("Vary", "Accept-Encoding") gw, _ := gzip.NewWriterLevel(w, level) @@ -111,6 +112,7 @@ func CompressHandlerLevel(h http.Handler, level int) http.Handler { break L case "deflate": w.Header().Set("Content-Encoding", "deflate") + r.Header.Del("Accept-Encoding") w.Header().Add("Vary", "Accept-Encoding") fw, _ := flate.NewWriter(w, level) diff --git a/compress_test.go b/compress_test.go index 6f07f44..e42ff54 100644 --- a/compress_test.go +++ b/compress_test.go @@ -49,6 +49,71 @@ func TestCompressHandlerNoCompression(t *testing.T) { } } +func TestAcceptEncodingIsDropped(t *testing.T) { + tCases := []struct { + name, + compression, + expect string + isPresent bool + }{ + { + "accept-encoding-gzip", + "gzip", + "", + false, + }, + { + "accept-encoding-deflate", + "deflate", + "", + false, + }, + { + "accept-encoding-gzip,deflate", + "gzip,deflate", + "", + false, + }, + { + "accept-encoding-gzip,deflate,something", + "gzip,deflate,something", + "", + false, + }, + { + "accept-encoding-unknown", + "unknown", + "unknown", + true, + }, + } + + for _, tCase := range tCases { + ch := CompressHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + acceptEnc := r.Header.Get("Accept-Encoding") + if acceptEnc == "" && tCase.isPresent { + t.Fatalf("%s: expected 'Accept-Encoding' header to be present but was not", tCase.name) + } + if acceptEnc != "" { + if !tCase.isPresent { + t.Fatalf("%s: expected 'Accept-Encoding' header to be dropped but was still present having value %q", tCase.name, acceptEnc) + } + if acceptEnc != tCase.expect { + t.Fatalf("%s: expected 'Accept-Encoding' to be %q but was %q", tCase.name, tCase.expect, acceptEnc) + } + } + })) + + w := httptest.NewRecorder() + ch.ServeHTTP(w, &http.Request{ + Method: "GET", + Header: http.Header{ + "Accept-Encoding": []string{tCase.compression}, + }, + }) + } +} + func TestCompressHandlerGzip(t *testing.T) { w := httptest.NewRecorder() compressedRequest(w, "gzip")