diff --git a/server/client.go b/server/client.go index b7ef2ba60b5..af670bd33b1 100644 --- a/server/client.go +++ b/server/client.go @@ -4345,7 +4345,7 @@ func (c *client) setupResponseServiceImport(acc *Account, si *serviceImport, tra // Will remove a header if present. func removeHeaderIfPresent(hdr []byte, key string) []byte { - start := bytes.Index(hdr, []byte(key+":")) + start := getHeaderKeyIndex(key, hdr) // key can't be first and we want to check that it is preceded by a '\n' if start < 1 || hdr[start-1] != '\n' { return hdr @@ -4463,22 +4463,13 @@ func sliceHeader(key string, hdr []byte) []byte { if len(hdr) == 0 { return nil } - index := bytes.Index(hdr, stringToBytes(key+":")) - hdrLen := len(hdr) - // Check that we have enough characters, this will handle the -1 case of the key not - // being found and will also handle not having enough characters for trailing CRLF. - if index < 2 { - return nil - } - // There should be a terminating CRLF. - if index >= hdrLen-1 || hdr[index-1] != '\n' || hdr[index-2] != '\r' { + index := getHeaderKeyIndex(key, hdr) + if index == -1 { return nil } - // The key should be immediately followed by a : separator. + // Skip over the key and the : separator. index += len(key) + 1 - if index >= hdrLen || hdr[index-1] != ':' { - return nil - } + hdrLen := len(hdr) // Skip over whitespace before the value. for index < hdrLen && hdr[index] == ' ' { index++ @@ -4494,11 +4485,49 @@ func sliceHeader(key string, hdr []byte) []byte { return hdr[start:index:index] } +// getHeaderKeyIndex returns an index into the header slice for the given key. +// Returns -1 if not found. +func getHeaderKeyIndex(key string, hdr []byte) int { + if len(hdr) == 0 { + return -1 + } + bkey := stringToBytes(key) + keyLen, hdrLen := len(key), len(hdr) + var offset int + for { + index := bytes.Index(hdr[offset:], bkey) + // Check that we have enough characters, this will handle the -1 case of the key not + // being found and will also handle not having enough characters for trailing CRLF. + if index < 2 { + return -1 + } + index += offset + // There should be a terminating CRLF. + if index >= hdrLen-1 || hdr[index-1] != '\n' || hdr[index-2] != '\r' { + offset = index + keyLen + continue + } + // The key should be immediately followed by a : separator. + if index+keyLen >= hdrLen { + return -1 + } + if hdr[index+keyLen] != ':' { + offset = index + keyLen + continue + } + return index + } +} + func setHeader(key, val string, hdr []byte) []byte { - prefix := []byte(key + ": ") - start := bytes.Index(hdr, prefix) + start := getHeaderKeyIndex(key, hdr) if start >= 0 { - valStart := start + len(prefix) + valStart := start + len(key) + 1 + // Preserve single whitespace if used. + hdrLen := len(hdr) + if valStart < hdrLen && hdr[valStart] == ' ' { + valStart++ + } valEnd := bytes.Index(hdr[valStart:], []byte("\r")) if valEnd < 0 { return hdr // malformed headers diff --git a/server/client_test.go b/server/client_test.go index 88b6baf0063..03d0fca2e72 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -3195,7 +3195,7 @@ func TestSliceHeader(t *testing.T) { require_True(t, bytes.Equal(sliced, copied)) } -func TestSliceHeaderOrdering(t *testing.T) { +func TestSliceHeaderOrderingPrefix(t *testing.T) { hdr := []byte("NATS/1.0\r\n\r\n") // These headers share the same prefix, the longer subject @@ -3215,6 +3215,105 @@ func TestSliceHeaderOrdering(t *testing.T) { require_True(t, bytes.Equal(sliced, copied)) } +func TestSliceHeaderOrderingSuffix(t *testing.T) { + hdr := []byte("NATS/1.0\r\n\r\n") + + // These headers share the same suffix, the longer subject + // must not invalidate the existence of the shorter one. + hdr = genHeader(hdr, "Previous-Nats-Msg-Id", "user") + hdr = genHeader(hdr, "Nats-Msg-Id", "control") + + sliced := sliceHeader("Nats-Msg-Id", hdr) + copied := getHeader("Nats-Msg-Id", hdr) + + require_NotNil(t, sliced) + require_NotNil(t, copied) + require_True(t, bytes.Equal(sliced, copied)) + require_Equal(t, string(copied), "control") +} + +func TestRemoveHeaderIfPresentOrderingPrefix(t *testing.T) { + hdr := []byte("NATS/1.0\r\n\r\n") + + // These headers share the same prefix, the longer subject + // must not invalidate the existence of the shorter one. + hdr = genHeader(hdr, JSExpectedLastSubjSeqSubj, "foo") + hdr = genHeader(hdr, JSExpectedLastSubjSeq, "24") + + hdr = removeHeaderIfPresent(hdr, JSExpectedLastSubjSeq) + ehdr := genHeader(nil, JSExpectedLastSubjSeqSubj, "foo") + require_True(t, bytes.Equal(hdr, ehdr)) +} + +func TestRemoveHeaderIfPresentOrderingSuffix(t *testing.T) { + hdr := []byte("NATS/1.0\r\n\r\n") + + // These headers share the same suffix, the longer subject + // must not invalidate the existence of the shorter one. + hdr = genHeader(hdr, "Previous-Nats-Msg-Id", "user") + hdr = genHeader(hdr, "Nats-Msg-Id", "control") + + hdr = removeHeaderIfPresent(hdr, "Nats-Msg-Id") + ehdr := genHeader(nil, "Previous-Nats-Msg-Id", "user") + require_True(t, bytes.Equal(hdr, ehdr)) +} + +func TestSetHeaderOrderingPrefix(t *testing.T) { + for _, space := range []bool{true, false} { + title := "Normal" + if !space { + title = "Trimmed" + } + t.Run(title, func(t *testing.T) { + hdr := []byte("NATS/1.0\r\n\r\n") + + // These headers share the same prefix, the longer subject + // must not invalidate the existence of the shorter one. + hdr = genHeader(hdr, JSExpectedLastSubjSeqSubj, "foo") + hdr = genHeader(hdr, JSExpectedLastSubjSeq, "24") + if !space { + hdr = bytes.ReplaceAll(hdr, []byte(" "), nil) + } + + hdr = setHeader(JSExpectedLastSubjSeq, "12", hdr) + ehdr := genHeader(nil, JSExpectedLastSubjSeqSubj, "foo") + ehdr = genHeader(ehdr, JSExpectedLastSubjSeq, "12") + if !space { + ehdr = bytes.ReplaceAll(ehdr, []byte(" "), nil) + } + require_True(t, bytes.Equal(hdr, ehdr)) + }) + } +} + +func TestSetHeaderOrderingSuffix(t *testing.T) { + for _, space := range []bool{true, false} { + title := "Normal" + if !space { + title = "Trimmed" + } + t.Run(title, func(t *testing.T) { + hdr := []byte("NATS/1.0\r\n\r\n") + + // These headers share the same suffix, the longer subject + // must not invalidate the existence of the shorter one. + hdr = genHeader(hdr, "Previous-Nats-Msg-Id", "user") + hdr = genHeader(hdr, "Nats-Msg-Id", "control") + if !space { + hdr = bytes.ReplaceAll(hdr, []byte(" "), nil) + } + + hdr = setHeader("Nats-Msg-Id", "other", hdr) + ehdr := genHeader(nil, "Previous-Nats-Msg-Id", "user") + ehdr = genHeader(ehdr, "Nats-Msg-Id", "other") + if !space { + ehdr = bytes.ReplaceAll(ehdr, []byte(" "), nil) + } + require_True(t, bytes.Equal(hdr, ehdr)) + }) + } +} + func TestInProcessAllowedConnectionType(t *testing.T) { tmpl := ` listen: "127.0.0.1:-1"