Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Decompress request body when multi Content-Encoding sent on request headers #2555

Merged
merged 11 commits into from
Aug 6, 2023
Merged
67 changes: 47 additions & 20 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,33 +260,60 @@ func (c *Ctx) BaseURL() string {
return c.baseURI
}

// Body contains the raw body submitted in a POST request.
// BodyRaw contains the raw body submitted in a POST request.
// Returned value is only valid within the handler. Do not store any references.
// Make copies or use the Immutable setting instead.
func (c *Ctx) BodyRaw() []byte {
return c.fasthttp.Request.Body()
}

// Body contains the raw body submitted in a POST request.
// This method will decompress the body if the 'Content-Encoding' header is provided.
// It returns the original (or decompressed) body data which is valid only within the handler.
// Don't store direct references to the returned data.
// If you need to keep the body's data later, make a copy or use the Immutable option.
func (c *Ctx) Body() []byte {
var err error
var encoding string
var body []byte
// faster than peek
c.Request().Header.VisitAll(func(key, value []byte) {
if c.app.getString(key) == HeaderContentEncoding {
encoding = c.app.getString(value)
var (
err error
body, originalBody []byte
encodingOrder = []string{"", "", ""}
)

// Split and get the encodings list, in order to attend the
// rule defined at: https://www.rfc-editor.org/rfc/rfc9110#section-8.4-5
encodingOrder = getSplicedStrList(c.Get(HeaderContentEncoding), encodingOrder)
if len(encodingOrder) == 0 {
return c.fasthttp.Request.Body()
}

for index, encoding := range encodingOrder {
switch encoding {
case StrGzip:
body, err = c.fasthttp.Request.BodyGunzip()
case StrBr, StrBrotli:
body, err = c.fasthttp.Request.BodyUnbrotli()
case StrDeflate:
body, err = c.fasthttp.Request.BodyInflate()
default:
return body
}
})

switch encoding {
case StrGzip:
body, err = c.fasthttp.Request.BodyGunzip()
case StrBr, StrBrotli:
body, err = c.fasthttp.Request.BodyUnbrotli()
case StrDeflate:
body, err = c.fasthttp.Request.BodyInflate()
default:
body = c.fasthttp.Request.Body()
if err != nil {
return []byte(err.Error())
}

if index < len(encodingOrder)-1 {
gaby marked this conversation as resolved.
Show resolved Hide resolved
if originalBody == nil {
tempBody := c.fasthttp.Request.Body()
originalBody = make([]byte, len(tempBody))
copy(originalBody, tempBody)
}
c.fasthttp.Request.SetBodyRaw(body)
}
}

if err != nil {
return []byte(err.Error())
if originalBody != nil {
c.fasthttp.Request.SetBodyRaw(originalBody)
}

return body
Expand Down
137 changes: 121 additions & 16 deletions ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"bufio"
"bytes"
"compress/gzip"
"compress/zlib"
"context"
"crypto/tls"
"encoding/xml"
Expand Down Expand Up @@ -323,6 +324,21 @@ func Test_Ctx_Body(t *testing.T) {
utils.AssertEqual(t, []byte("john=doe"), c.Body())
}

func Benchmark_Ctx_Body(b *testing.B) {
const input = "john=doe"

app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)

c.Request().SetBody([]byte(input))
for i := 0; i < b.N; i++ {
_ = c.Body()
}

utils.AssertEqual(b, []byte(input), c.Body())
}

// go test -run Test_Ctx_Body_With_Compression
func Test_Ctx_Body_With_Compression(t *testing.T) {
t.Parallel()
Expand All @@ -344,26 +360,115 @@ func Test_Ctx_Body_With_Compression(t *testing.T) {

// go test -v -run=^$ -bench=Benchmark_Ctx_Body_With_Compression -benchmem -count=4
func Benchmark_Ctx_Body_With_Compression(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set("Content-Encoding", "gzip")
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, err := gz.Write([]byte("john=doe"))
utils.AssertEqual(b, nil, err)
err = gz.Flush()
utils.AssertEqual(b, nil, err)
err = gz.Close()
utils.AssertEqual(b, nil, err)
type compressionTest struct {
contentEncoding string
compressWriter func([]byte) ([]byte, error)
}

encodingErr := errors.New("failed to encoding data")
compressionTests := []compressionTest{
{
contentEncoding: "gzip",
compressWriter: func(data []byte) ([]byte, error) {
var buf bytes.Buffer
writer := gzip.NewWriter(&buf)
if _, err := writer.Write(data); err != nil {
return nil, encodingErr
}
if err := writer.Flush(); err != nil {
return nil, encodingErr
}
if err := writer.Close(); err != nil {
return nil, encodingErr
}
return buf.Bytes(), nil
},
},
{
contentEncoding: "deflate",
compressWriter: func(data []byte) ([]byte, error) {
var buf bytes.Buffer
writer := zlib.NewWriter(&buf)
if _, err := writer.Write(data); err != nil {
return nil, encodingErr
}
if err := writer.Flush(); err != nil {
return nil, encodingErr
}
if err := writer.Close(); err != nil {
return nil, encodingErr
}
return buf.Bytes(), nil
},
},
{
contentEncoding: "gzip,deflate",
compressWriter: func(data []byte) ([]byte, error) {
var (
buf bytes.Buffer
writer interface {
io.WriteCloser
Flush() error
}
err error
)

// deflate
{
writer = zlib.NewWriter(&buf)
if _, err = writer.Write(data); err != nil {
return nil, encodingErr
}
if err = writer.Flush(); err != nil {
return nil, encodingErr
}
if err = writer.Close(); err != nil {
return nil, encodingErr
}
}

c.Request().SetBody(buf.Bytes())
data = make([]byte, buf.Len())
copy(data, buf.Bytes())
buf.Reset()

// gzip
{
writer = gzip.NewWriter(&buf)
if _, err = writer.Write(data); err != nil {
return nil, encodingErr
}
if err = writer.Flush(); err != nil {
return nil, encodingErr
}
if err = writer.Close(); err != nil {
return nil, encodingErr
}
}

for i := 0; i < b.N; i++ {
_ = c.Body()
return buf.Bytes(), nil
},
},
}

utils.AssertEqual(b, []byte("john=doe"), c.Body())
for _, ct := range compressionTests {
b.Run(ct.contentEncoding, func(b *testing.B) {
app := New()
const input = "john=doe"
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)

c.Request().Header.Set("Content-Encoding", ct.contentEncoding)
compressedBody, err := ct.compressWriter([]byte(input))
utils.AssertEqual(b, nil, err)

c.Request().SetBody(compressedBody)
for i := 0; i < b.N; i++ {
_ = c.Body()
}

utils.AssertEqual(b, []byte(input), c.Body())
})
}
}

// go test -run Test_Ctx_BodyParser
Expand Down
35 changes: 35 additions & 0 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,41 @@ func acceptsOfferType(spec, offerType string) bool {
return false
}

// getSplicedStrList function takes a string and a string slice as an argument, divides the string into different
// elements divided by ',' and stores these elements in the string slice.
// It returns the populated string slice as an output.
//
// If the given slice hasn't enough space, it will allocate more and return.
func getSplicedStrList(headerValue string, dst []string) []string {
if headerValue == "" {
return nil
}

var (
index int
character rune
lastElementEndsAt uint8
insertIndex int
)
for index, character = range headerValue + "$" {
if character == ',' || index == len(headerValue) {
if insertIndex >= len(dst) {
oldSlice := dst
dst = make([]string, len(dst)+(len(dst)>>1)+2)
copy(dst, oldSlice)
}
dst[insertIndex] = utils.TrimLeft(headerValue[lastElementEndsAt:index], ' ')
lastElementEndsAt = uint8(index + 1)
insertIndex++
}
}

if len(dst) > insertIndex {
dst = dst[:insertIndex]
}
return dst
}

// getOffer return valid offer for header negotiation
func getOffer(header string, isAccepted func(spec, offer string) bool, offers ...string) string {
if len(offers) == 0 {
Expand Down
42 changes: 42 additions & 0 deletions helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,48 @@ func Benchmark_Utils_GetOffer(b *testing.B) {
}
}

func Test_Utils_GetSplicedStrList(t *testing.T) {
testCases := []struct {
description string
headerValue string
expectedList []string
}{
{
description: "normal case",
headerValue: "gzip, deflate,br",
expectedList: []string{"gzip", "deflate", "br"},
},
{
description: "no matter the value",
headerValue: " gzip,deflate, br, zip",
expectedList: []string{"gzip", "deflate", "br", "zip"},
},
{
description: "headerValue is empty",
headerValue: "",
expectedList: nil,
},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
dst := make([]string, 10)
result := getSplicedStrList(tc.headerValue, dst)
utils.AssertEqual(t, tc.expectedList, result)
})
}
}

func Benchmark_Utils_GetSplicedStrList(b *testing.B) {
destination := make([]string, 5)
result := destination
const input = "deflate, gzip,br,brotli"
for n := 0; n < b.N; n++ {
result = getSplicedStrList(input, destination)
}
utils.AssertEqual(b, []string{"deflate", "gzip", "br", "brotli"}, result)
}

func Test_Utils_SortAcceptedTypes(t *testing.T) {
t.Parallel()
acceptedTypes := []acceptedType{
Expand Down
5 changes: 3 additions & 2 deletions middleware/session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ func Test_Session_Save_Expiration(t *testing.T) {
t.Parallel()

t.Run("save to cookie", func(t *testing.T) {
const sessionDuration = 5 * time.Second
t.Parallel()
// session store
store := New()
Expand All @@ -302,7 +303,7 @@ func Test_Session_Save_Expiration(t *testing.T) {
sess.Set("name", "john")

// expire this session in 5 seconds
sess.SetExpiry(time.Second * 5)
sess.SetExpiry(sessionDuration)

// save session
err = sess.Save()
Expand All @@ -314,7 +315,7 @@ func Test_Session_Save_Expiration(t *testing.T) {
utils.AssertEqual(t, "john", sess.Get("name"))

// just to make sure the session has been expired
time.Sleep(time.Second * 5)
time.Sleep(sessionDuration + (10 * time.Millisecond))

// here you should get a new session
sess, err = store.Get(ctx)
Expand Down
Loading