Skip to content

Commit

Permalink
Fix header .Add functions (#1036)
Browse files Browse the repository at this point in the history
These functions should take the headers that are handled differently
into account.
  • Loading branch information
erikdubbelboer authored Jun 1, 2021
1 parent 5bb5cfc commit 6233fbc
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 81 deletions.
6 changes: 2 additions & 4 deletions fasthttpadaptor/adaptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ func TestNewFastHTTPHandler(t *testing.T) {
expectedRequestURI := "/foo/bar?baz=123"
expectedBody := "body 123 foo bar baz"
expectedContentLength := len(expectedBody)
expectedTransferEncoding := "encoding"
expectedHost := "foobar.com"
expectedRemoteAddr := "1.2.3.4:6789"
expectedHeader := map[string]string{
Expand Down Expand Up @@ -56,8 +55,8 @@ func TestNewFastHTTPHandler(t *testing.T) {
if r.ContentLength != int64(expectedContentLength) {
t.Fatalf("unexpected contentLength %d. Expecting %d", r.ContentLength, expectedContentLength)
}
if len(r.TransferEncoding) != 1 || r.TransferEncoding[0] != expectedTransferEncoding {
t.Fatalf("unexpected transferEncoding %q. Expecting %q", r.TransferEncoding, expectedTransferEncoding)
if len(r.TransferEncoding) != 0 {
t.Fatalf("unexpected transferEncoding %q. Expecting []", r.TransferEncoding)
}
if r.Host != expectedHost {
t.Fatalf("unexpected host %q. Expecting %q", r.Host, expectedHost)
Expand Down Expand Up @@ -101,7 +100,6 @@ func TestNewFastHTTPHandler(t *testing.T) {
req.Header.SetMethod(expectedMethod)
req.SetRequestURI(expectedRequestURI)
req.Header.SetHost(expectedHost)
req.Header.Add(fasthttp.HeaderTransferEncoding, expectedTransferEncoding)
req.BodyWriter().Write([]byte(expectedBody)) // nolint:errcheck
for k, v := range expectedHeader {
req.Header.Set(k, v)
Expand Down
212 changes: 148 additions & 64 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -917,37 +917,159 @@ func (h *RequestHeader) del(key []byte) {
h.h = delAllArgsBytes(h.h, key)
}

// setSpecialHeader handles special headers and return true when a header is processed.
func (h *ResponseHeader) setSpecialHeader(key, value []byte) bool {
if len(key) == 0 {
return false
}

switch key[0] | 0x20 {
case 'c':
if caseInsensitiveCompare(strContentType, key) {
h.SetContentTypeBytes(value)
return true
} else if caseInsensitiveCompare(strContentLength, key) {
if contentLength, err := parseContentLength(value); err == nil {
h.contentLength = contentLength
h.contentLengthBytes = append(h.contentLengthBytes[:0], value...)
}
return true
} else if caseInsensitiveCompare(strConnection, key) {
if bytes.Equal(strClose, value) {
h.SetConnectionClose()
} else {
h.ResetConnectionClose()
h.h = setArgBytes(h.h, key, value, argsHasValue)
}
return true
}
case 's':
if caseInsensitiveCompare(strServer, key) {
h.SetServerBytes(value)
return true
} else if caseInsensitiveCompare(strSetCookie, key) {
var kv *argsKV
h.cookies, kv = allocArg(h.cookies)
kv.key = getCookieKey(kv.key, value)
kv.value = append(kv.value[:0], value...)
return true
}
case 't':
if caseInsensitiveCompare(strTransferEncoding, key) {
// Transfer-Encoding is managed automatically.
return true
}
case 'd':
if caseInsensitiveCompare(strDate, key) {
// Date is managed automatically.
return true
}
}

return false
}

// setSpecialHeader handles special headers and return true when a header is processed.
func (h *RequestHeader) setSpecialHeader(key, value []byte) bool {
if len(key) == 0 {
return false
}

switch key[0] | 0x20 {
case 'c':
if caseInsensitiveCompare(strContentType, key) {
h.SetContentTypeBytes(value)
return true
} else if caseInsensitiveCompare(strContentLength, key) {
if contentLength, err := parseContentLength(value); err == nil {
h.contentLength = contentLength
h.contentLengthBytes = append(h.contentLengthBytes[:0], value...)
}
return true
} else if caseInsensitiveCompare(strConnection, key) {
if bytes.Equal(strClose, value) {
h.SetConnectionClose()
} else {
h.ResetConnectionClose()
h.h = setArgBytes(h.h, key, value, argsHasValue)
}
return true
} else if caseInsensitiveCompare(strCookie, key) {
h.collectCookies()
h.cookies = parseRequestCookies(h.cookies, value)
return true
}
case 't':
if caseInsensitiveCompare(strTransferEncoding, key) {
// Transfer-Encoding is managed automatically.
return true
}
case 'h':
if caseInsensitiveCompare(strHost, key) {
h.SetHostBytes(value)
return true
}
case 'u':
if caseInsensitiveCompare(strUserAgent, key) {
h.SetUserAgentBytes(value)
return true
}
}

return false
}

// Add adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use Set for setting a single header for the given key.
//
// the Content-Type, Content-Length, Connection, Server, Set-Cookie,
// Transfer-Encoding and Date headers can only be set once and will
// overwrite the previous value.
func (h *ResponseHeader) Add(key, value string) {
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
h.h = appendArg(h.h, b2s(k), value, argsHasValue)
h.AddBytesKV(s2b(key), s2b(value))
}

// AddBytesK adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesK for setting a single header for the given key.
//
// the Content-Type, Content-Length, Connection, Server, Set-Cookie,
// Transfer-Encoding and Date headers can only be set once and will
// overwrite the previous value.
func (h *ResponseHeader) AddBytesK(key []byte, value string) {
h.Add(b2s(key), value)
h.AddBytesKV(key, s2b(value))
}

// AddBytesV adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesV for setting a single header for the given key.
//
// the Content-Type, Content-Length, Connection, Server, Set-Cookie,
// Transfer-Encoding and Date headers can only be set once and will
// overwrite the previous value.
func (h *ResponseHeader) AddBytesV(key string, value []byte) {
h.Add(key, b2s(value))
h.AddBytesKV(s2b(key), value)
}

// AddBytesKV adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesKV for setting a single header for the given key.
//
// the Content-Type, Content-Length, Connection, Server, Set-Cookie,
// Transfer-Encoding and Date headers can only be set once and will
// overwrite the previous value.
func (h *ResponseHeader) AddBytesKV(key, value []byte) {
h.Add(b2s(key), b2s(value))
if h.setSpecialHeader(key, value) {
return
}

k := getHeaderKeyBytes(&h.bufKV, b2s(key), h.disableNormalizing)
h.h = appendArgBytes(h.h, k, value, argsHasValue)
}

// Set sets the given 'key: value' header.
Expand Down Expand Up @@ -986,35 +1108,11 @@ func (h *ResponseHeader) SetBytesKV(key, value []byte) {
// SetCanonical sets the given 'key: value' header assuming that
// key is in canonical form.
func (h *ResponseHeader) SetCanonical(key, value []byte) {
switch string(key) {
case HeaderContentType:
h.SetContentTypeBytes(value)
case HeaderServer:
h.SetServerBytes(value)
case HeaderSetCookie:
var kv *argsKV
h.cookies, kv = allocArg(h.cookies)
kv.key = getCookieKey(kv.key, value)
kv.value = append(kv.value[:0], value...)
case HeaderContentLength:
if contentLength, err := parseContentLength(value); err == nil {
h.contentLength = contentLength
h.contentLengthBytes = append(h.contentLengthBytes[:0], value...)
}
case HeaderConnection:
if bytes.Equal(strClose, value) {
h.SetConnectionClose()
} else {
h.ResetConnectionClose()
h.h = setArgBytes(h.h, key, value, argsHasValue)
}
case HeaderTransferEncoding:
// Transfer-Encoding is managed automatically.
case HeaderDate:
// Date is managed automatically.
default:
h.h = setArgBytes(h.h, key, value, argsHasValue)
if h.setSpecialHeader(key, value) {
return
}

h.h = setArgBytes(h.h, key, value, argsHasValue)
}

// SetCookie sets the given response cookie.
Expand Down Expand Up @@ -1123,32 +1221,40 @@ func (h *RequestHeader) DelAllCookies() {
// Multiple headers with the same key may be added with this function.
// Use Set for setting a single header for the given key.
func (h *RequestHeader) Add(key, value string) {
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
h.h = appendArg(h.h, b2s(k), value, argsHasValue)
h.AddBytesKV(s2b(key), s2b(value))
}

// AddBytesK adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesK for setting a single header for the given key.
func (h *RequestHeader) AddBytesK(key []byte, value string) {
h.Add(b2s(key), value)
h.AddBytesKV(key, s2b(value))
}

// AddBytesV adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesV for setting a single header for the given key.
func (h *RequestHeader) AddBytesV(key string, value []byte) {
h.Add(key, b2s(value))
h.AddBytesKV(s2b(key), value)
}

// AddBytesKV adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesKV for setting a single header for the given key.
//
// the Content-Type, Content-Length, Connection, Cookie,
// Transfer-Encoding, Host and User-Agent headers can only be set once
// and will overwrite the previous value.
func (h *RequestHeader) AddBytesKV(key, value []byte) {
h.Add(b2s(key), b2s(value))
if h.setSpecialHeader(key, value) {
return
}

k := getHeaderKeyBytes(&h.bufKV, b2s(key), h.disableNormalizing)
h.h = appendArgBytes(h.h, k, value, argsHasValue)
}

// Set sets the given 'key: value' header.
Expand Down Expand Up @@ -1187,33 +1293,11 @@ func (h *RequestHeader) SetBytesKV(key, value []byte) {
// SetCanonical sets the given 'key: value' header assuming that
// key is in canonical form.
func (h *RequestHeader) SetCanonical(key, value []byte) {
switch string(key) {
case HeaderHost:
h.SetHostBytes(value)
case HeaderContentType:
h.SetContentTypeBytes(value)
case HeaderUserAgent:
h.SetUserAgentBytes(value)
case HeaderCookie:
h.collectCookies()
h.cookies = parseRequestCookies(h.cookies, value)
case HeaderContentLength:
if contentLength, err := parseContentLength(value); err == nil {
h.contentLength = contentLength
h.contentLengthBytes = append(h.contentLengthBytes[:0], value...)
}
case HeaderConnection:
if bytes.Equal(strClose, value) {
h.SetConnectionClose()
} else {
h.ResetConnectionClose()
h.h = setArgBytes(h.h, key, value, argsHasValue)
}
case HeaderTransferEncoding:
// Transfer-Encoding is managed automatically.
default:
h.h = setArgBytes(h.h, key, value, argsHasValue)
if h.setSpecialHeader(key, value) {
return
}

h.h = setArgBytes(h.h, key, value, argsHasValue)
}

// Peek returns header value for the given key.
Expand Down
Loading

0 comments on commit 6233fbc

Please sign in to comment.