Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions go/mysql/json/marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"
)

func TestMarshalSQLTo(t *testing.T) {
Expand Down Expand Up @@ -55,3 +57,31 @@ func TestMarshalSQLTo(t *testing.T) {
})
}
}

// TestMarshalSQLValuePreservesControlByte checks that a control byte embedded
// in a JSON string is still present, unmodified, in the SQL produced by
// MarshalSQLValue.
func TestMarshalSQLValuePreservesControlByte(t *testing.T) {
	// Byte 0x1a (decimal 26) is placed between ordinary printable text.
	input := "Foo Bar" + "\x1a" + "a"

	marshaled := NewString(input).MarshalTo(nil)
	sqlValue, err := MarshalSQLValue(marshaled)
	require.NoError(t, err)

	want := "CAST(JSON_QUOTE(_utf8mb4" + sqltypes.EncodeStringSQL(input) + ") as JSON)"
	require.Equal(t, want, string(sqlValue.Raw()))
}

// TestMarshalSQLValueNormalizesInvalidUTF8 checks that a JSON string holding
// invalid UTF-8 is normalized before it is turned into SQL.
func TestMarshalSQLValueNormalizesInvalidUTF8(t *testing.T) {
	// 0xff and 0xfe can never appear in well-formed UTF-8.
	invalid := "\xff\xfeA"

	marshaled := NewString(invalid).MarshalTo(nil)
	sqlValue, err := MarshalSQLValue(marshaled)
	require.NoError(t, err)

	// A string -> []rune -> string round trip replaces every invalid byte
	// with U+FFFD, which is the normalized form the SQL is expected to carry.
	want := "CAST(JSON_QUOTE(_utf8mb4" + sqltypes.EncodeStringSQL(string([]rune(invalid))) + ") as JSON)"
	require.Equal(t, want, string(sqlValue.Raw()))
}
51 changes: 49 additions & 2 deletions go/mysql/json/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"strings"
"time"
"unicode/utf16"
"unicode/utf8"

"vitess.io/vitess/go/mysql/fastparse"

Expand Down Expand Up @@ -312,7 +313,18 @@ func parseObject(s string, c *cache, depth int) (*Value, string, error) {
}
}

// hexDigits maps a 4-bit value (0-15) to its lowercase hexadecimal digit. It
// is indexed with the high and low nibble of a byte to build \u00XX escapes.
const hexDigits = "0123456789abcdef"

// escapeString appends s as a JSON string to dst and returns the result.
//
// The output uses JSON compliant escapes for control bytes and only uses the
// short escape sequences for \b, \f, \n, \r, and \t.
func escapeString(dst []byte, s string) []byte {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if it's warranted here, but we may want to check if s is valid UTF8 with utf8.ValidString, and if it isn't, we should escape invalid UTF8.

It should always be utf8mb4 so this isn't necessary, but could be good defensively.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep I thought about the same thing. I was going to pull in (copy) how the Go stdlib handles it: https://github.com/golang/go/blob/go1.22.0/src/encoding/json/encode.go#L956-L1025, but chose to keep it simpler instead. I'm inclined to agree to be as defensive as possible though.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in 7227058

// If we have invalid UTF-8, normalize it so JSON output stays valid.
if !utf8.ValidString(s) {
s = string([]rune(s))
}

if !hasSpecialChars(s) {
// Fast path - nothing to escape.
dst = append(dst, '"')
Expand All @@ -321,8 +333,43 @@ func escapeString(dst []byte, s string) []byte {
return dst
}

// Slow path.
return strconv.AppendQuote(dst, s)
dst = append(dst, '"')

// Escape control bytes, quotes, and backslashes.
for i := 0; i < len(s); i++ {
ch := s[i]

switch ch {
case '"', '\\':
dst = append(dst, '\\', ch)
case '\b': // 0x08, backspace
dst = append(dst, '\\', 'b')
case '\f': // 0x0C, form feed
dst = append(dst, '\\', 'f')
case '\n': // 0x0A, line feed
dst = append(dst, '\\', 'n')
case '\r': // 0x0D, carriage return
dst = append(dst, '\\', 'r')
case '\t': // 0x09, horizontal tab
dst = append(dst, '\\', 't')
default:
// Other control characters (0x00-0x1F) use \u00XX escapes.
// We hardcode 00 since we only handle single byte control
// characters, then split the byte into two 4-bit halves;
// ch>>4 extracts the upper bits, and ch&0x0f extracts the lower
// bits. Each then indexes into hexDigits to produce the final
// two hex characters.
if ch < 0x20 {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Meaning it's a control character. Worth a comment.

dst = append(dst, '\\', 'u', '0', '0', hexDigits[ch>>4], hexDigits[ch&0x0f])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And here something like:

We then convert the control character to a 6-byte Unicode escape sequence of the form \u00XX. '\\' and
'u' write the literal backslash and u that start the escape sequence. The two '0's hardcode
the high byte to 00 because we are only dealing with a single control byte. ch>>4 takes the
top 4 bits of the byte, so hexDigits[ch>>4] is the first hex digit, and ch&0x0f takes the low 4
bits, so hexDigits[ch&0x0f] is the second hex digit.

continue
}

dst = append(dst, ch)
}
}

dst = append(dst, '"')
return dst
}

func hasSpecialChars(s string) bool {
Expand Down
Loading