Skip to content

Commit

Permalink
fix issue #875 (#909)
Browse files Browse the repository at this point in the history
* 🐞 panic in fs.go #824

* fix issue #875

Signed-off-by: Kirill Danshin <[email protected]>

* improve issue 875

Co-authored-by: Fenny <[email protected]>

* Update header.go

* Update header.go

Co-authored-by: Kirill Danshin <[email protected]>

* remove foldReplacer

* Improve removeNewLines

Start replacing at the first character found, use bytes.Indexbyte to
make the function signature more logical. Both bytes.indexByte and
strings.IndexByte use exactly the same code:
https://github.com/golang/go/blob/0c703b37dffe74d3fffc04347884bb0ee2fba5b3/src/internal/bytealg/indexbyte_amd64.s#L8-L20

Co-authored-by: wernerr <[email protected]>
Co-authored-by: wernerr <[email protected]>
Co-authored-by: Fenny <[email protected]>
Co-authored-by: Erik Dubbelboer <[email protected]>
  • Loading branch information
5 people authored Dec 9, 2020
1 parent ec4aa43 commit d0dfbd4
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 22 deletions.
82 changes: 60 additions & 22 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ import (
"time"
)

const (
rChar = byte('\r')
nChar = byte('\n')
)

// ResponseHeader represents HTTP response header.
//
// It is forbidden copying ResponseHeader instances.
Expand Down Expand Up @@ -1419,7 +1424,7 @@ func bufferSnippet(b []byte) string {

func isOnlyCRLF(b []byte) bool {
for _, ch := range b {
if ch != '\r' && ch != '\n' {
if ch != rChar && ch != nChar {
return false
}
}
Expand Down Expand Up @@ -1731,7 +1736,7 @@ func peekRawHeader(buf, key []byte) []byte {
if n < 0 {
return nil
}
if n > 0 && buf[n-1] != '\n' {
if n > 0 && buf[n-1] != nChar {
return nil
}
n += len(key)
Expand All @@ -1747,22 +1752,22 @@ func peekRawHeader(buf, key []byte) []byte {
}
n++
buf = buf[n:]
n = bytes.IndexByte(buf, '\n')
n = bytes.IndexByte(buf, nChar)
if n < 0 {
return nil
}
if n > 0 && buf[n-1] == '\r' {
if n > 0 && buf[n-1] == rChar {
n--
}
return buf[:n]
}

func readRawHeaders(dst, buf []byte) ([]byte, int, error) {
n := bytes.IndexByte(buf, '\n')
n := bytes.IndexByte(buf, nChar)
if n < 0 {
return dst[:0], 0, errNeedMore
}
if (n == 1 && buf[0] == '\r') || n == 0 {
if (n == 1 && buf[0] == rChar) || n == 0 {
// empty headers
return dst, n + 1, nil
}
Expand All @@ -1772,13 +1777,13 @@ func readRawHeaders(dst, buf []byte) ([]byte, int, error) {
m := n
for {
b = b[m:]
m = bytes.IndexByte(b, '\n')
m = bytes.IndexByte(b, nChar)
if m < 0 {
return dst, 0, errNeedMore
}
m++
n += m
if (m == 2 && b[0] == '\r') || m == 1 {
if (m == 2 && b[0] == rChar) || m == 1 {
dst = append(dst, buf[:n]...)
return dst, n, nil
}
Expand Down Expand Up @@ -2011,12 +2016,12 @@ func (s *headerScanner) next() bool {
s.initialized = true
}
bLen := len(s.b)
if bLen >= 2 && s.b[0] == '\r' && s.b[1] == '\n' {
if bLen >= 2 && s.b[0] == rChar && s.b[1] == nChar {
s.b = s.b[2:]
s.hLen += 2
return false
}
if bLen >= 1 && s.b[0] == '\n' {
if bLen >= 1 && s.b[0] == nChar {
s.b = s.b[1:]
s.hLen++
return false
Expand All @@ -2029,7 +2034,7 @@ func (s *headerScanner) next() bool {
n = bytes.IndexByte(s.b, ':')

// There can't be a \n inside the header name, check for this.
x := bytes.IndexByte(s.b, '\n')
x := bytes.IndexByte(s.b, nChar)
if x < 0 {
// A header name should always at some point be followed by a \n
// even if it's the one that terminates the header block.
Expand Down Expand Up @@ -2062,7 +2067,7 @@ func (s *headerScanner) next() bool {
n = s.nextNewLine
s.nextNewLine = -1
} else {
n = bytes.IndexByte(s.b, '\n')
n = bytes.IndexByte(s.b, nChar)
}
if n < 0 {
s.err = errNeedMore
Expand All @@ -2076,10 +2081,10 @@ func (s *headerScanner) next() bool {
if s.b[n+1] != ' ' && s.b[n+1] != '\t' {
break
}
d := bytes.IndexByte(s.b[n+1:], '\n')
d := bytes.IndexByte(s.b[n+1:], nChar)
if d <= 0 {
break
} else if d == 1 && s.b[n+1] == '\r' {
} else if d == 1 && s.b[n+1] == rChar {
break
}
e := n + d + 1
Expand All @@ -2100,7 +2105,7 @@ func (s *headerScanner) next() bool {
s.hLen += n + 1
s.b = s.b[n+1:]

if n > 0 && s.value[n-1] == '\r' {
if n > 0 && s.value[n-1] == rChar {
n--
}
for n > 0 && s.value[n-1] == ' ' {
Expand Down Expand Up @@ -2156,20 +2161,22 @@ func hasHeaderValue(s, value []byte) bool {
}

func nextLine(b []byte) ([]byte, []byte, error) {
nNext := bytes.IndexByte(b, '\n')
nNext := bytes.IndexByte(b, nChar)
if nNext < 0 {
return nil, nil, errNeedMore
}
n := nNext
if n > 0 && b[n-1] == '\r' {
if n > 0 && b[n-1] == rChar {
n--
}
return b[:n], b[nNext+1:], nil
}

func initHeaderKV(kv *argsKV, key, value string, disableNormalizing bool) {
kv.key = getHeaderKeyBytes(kv, key, disableNormalizing)
// https://tools.ietf.org/html/rfc7230#section-3.2.4
kv.value = append(kv.value[:0], value...)
kv.value = removeNewLines(kv.value)
}

func getHeaderKeyBytes(kv *argsKV, key string, disableNormalizing bool) []byte {
Expand All @@ -2189,9 +2196,9 @@ func normalizeHeaderValue(ov, ob []byte, headerLength int) (nv, nb []byte, nhl i
lineStart := false
for read := 0; read < length; read++ {
c := ov[read]
if c == '\r' || c == '\n' {
if c == rChar || c == nChar {
shrunk++
if c == '\n' {
if c == nChar {
lineStart = true
}
continue
Expand All @@ -2209,13 +2216,13 @@ func normalizeHeaderValue(ov, ob []byte, headerLength int) (nv, nb []byte, nhl i

// Check if we need to skip \r\n or just \n
skip := 0
if ob[write] == '\r' {
if ob[write+1] == '\n' {
if ob[write] == rChar {
if ob[write+1] == nChar {
skip += 2
} else {
skip++
}
} else if ob[write] == '\n' {
} else if ob[write] == nChar {
skip++
}

Expand Down Expand Up @@ -2248,6 +2255,37 @@ func normalizeHeaderKey(b []byte, disableNormalizing bool) {
}
}

// removeNewLines will replace `\r` and `\n` with an empty space
func removeNewLines(raw []byte) []byte {
// check if a `\r` is present and save the position.
// if no `\r` is found, check if a `\n` is present.
foundR := bytes.IndexByte(raw, rChar)
foundN := bytes.IndexByte(raw, nChar)
start := 0

if foundN != -1 {
if foundR > foundN {
start = foundN
} else if foundR != -1 {
start = foundR
}
} else if foundR != -1 {
start = foundR
} else {
return raw
}

for i := start; i < len(raw); i++ {
switch raw[i] {
case rChar, nChar:
raw[i] = ' '
default:
continue
}
}
return raw
}

// AppendNormalizedHeaderKey appends normalized header key (name) to dst
// and returns the resulting dst.
//
Expand Down
32 changes: 32 additions & 0 deletions header_timing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"io"
"strconv"
"testing"

"github.com/valyala/bytebufferpool"
Expand Down Expand Up @@ -146,3 +147,34 @@ func benchmarkNormalizeHeaderKey(b *testing.B, src []byte) {
}
})
}

func BenchmarkRemoveNewLines(b *testing.B) {
type testcase struct {
value string
expectedValue string
}

var testcases = []testcase{
{value: "MaliciousValue", expectedValue: "MaliciousValue"},
{value: "MaliciousValue\r\n", expectedValue: "MaliciousValue "},
{value: "Malicious\nValue", expectedValue: "Malicious Value"},
{value: "Malicious\rValue", expectedValue: "Malicious Value"},
}

for i, tcase := range testcases {
caseName := strconv.FormatInt(int64(i), 10)
b.Run(caseName, func(subB *testing.B) {
subB.ReportAllocs()
var h RequestHeader
for i := 0; i < subB.N; i++ {
h.Set("Test", tcase.value)
}
subB.StopTimer()
actualValue := string(h.Peek("Test"))

if actualValue != tcase.expectedValue {
subB.Errorf("unexpected value, got: %+v", actualValue)
}
})
}
}
48 changes: 48 additions & 0 deletions http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io/ioutil"
"mime/multipart"
"reflect"
"strconv"
"strings"
"testing"
"time"
Expand All @@ -30,6 +31,53 @@ func TestFragmentInURIRequest(t *testing.T) {
}
}

func TestIssue875(t *testing.T) {
type testcase struct {
uri string
expectedRedirect string
expectedLocation string
}

var testcases = []testcase{
{
uri: `http://localhost:3000/?redirect=foo%0d%0aSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`,
expectedRedirect: "foo\r\nSet-Cookie: SESSIONID=MaliciousValue\r\n",
expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue",
},
{
uri: `http://localhost:3000/?redirect=foo%0dSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`,
expectedRedirect: "foo\rSet-Cookie: SESSIONID=MaliciousValue\r\n",
expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue",
},
{
uri: `http://localhost:3000/?redirect=foo%0aSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`,
expectedRedirect: "foo\nSet-Cookie: SESSIONID=MaliciousValue\r\n",
expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue",
},
}

for i, tcase := range testcases {
caseName := strconv.FormatInt(int64(i), 10)
t.Run(caseName, func(subT *testing.T) {
ctx := &RequestCtx{
Request: Request{},
Response: Response{},
}
ctx.Request.SetRequestURI(tcase.uri)

q := string(ctx.QueryArgs().Peek("redirect"))
if q != tcase.expectedRedirect {
subT.Errorf("unexpected redirect query value, got: %+v", q)
}
ctx.Response.Header.Set("Location", q)

if !strings.Contains(ctx.Response.String(), tcase.expectedLocation) {
subT.Errorf("invalid escaping, got\n%s", ctx.Response.String())
}
})
}
}

func TestRequestCopyTo(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit d0dfbd4

Please sign in to comment.