Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions go/mysql/json/marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"
)

func TestMarshalSQLTo(t *testing.T) {
Expand Down Expand Up @@ -55,3 +57,31 @@ func TestMarshalSQLTo(t *testing.T) {
})
}
}

// TestMarshalSQLValuePreservesControlByte verifies that JSON control bytes
// are preserved when marshaled into SQL.
func TestMarshalSQLValuePreservesControlByte(t *testing.T) {
raw := "Foo Bar" + string([]byte{26}) + "a"

val := NewString(raw)

got, err := MarshalSQLValue(val.MarshalTo(nil))
require.NoError(t, err)

expected := "CAST(JSON_QUOTE(_utf8mb4" + sqltypes.EncodeStringSQL(raw) + ") as JSON)"
require.Equal(t, expected, string(got.Raw()))
}

// TestMarshalSQLValueNormalizesInvalidUTF8 verifies that JSON string
// marshaling normalizes invalid UTF-8 before producing SQL.
func TestMarshalSQLValueNormalizesInvalidUTF8(t *testing.T) {
raw := string([]byte{0xff, 0xfe, 'A'})
val := NewString(raw)

got, err := MarshalSQLValue(val.MarshalTo(nil))
require.NoError(t, err)

normalized := string([]rune(raw))
expected := "CAST(JSON_QUOTE(_utf8mb4" + sqltypes.EncodeStringSQL(normalized) + ") as JSON)"
require.Equal(t, expected, string(got.Raw()))
}
51 changes: 49 additions & 2 deletions go/mysql/json/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"strings"
"time"
"unicode/utf16"
"unicode/utf8"

"vitess.io/vitess/go/mysql/fastparse"

Expand Down Expand Up @@ -311,7 +312,18 @@ func parseObject(s string, c *cache, depth int) (*Value, string, error) {
}
}

const hexDigits = "0123456789abcdef"

// escapeString appends s as a JSON string to dst and returns the result.
//
// The output uses JSON compliant escapes for control bytes and only uses the
// short escape sequences for \b, \f, \n, \r, and \t.
func escapeString(dst []byte, s string) []byte {
// If we have invalid UTF-8, normalize it so JSON output stays valid.
if !utf8.ValidString(s) {
s = string([]rune(s))
}

if !hasSpecialChars(s) {
// Fast path - nothing to escape.
dst = append(dst, '"')
Expand All @@ -320,8 +332,43 @@ func escapeString(dst []byte, s string) []byte {
return dst
}

// Slow path.
return strconv.AppendQuote(dst, s)
dst = append(dst, '"')

// Escape control bytes, quotes, and backslashes.
for i := 0; i < len(s); i++ {
ch := s[i]

switch ch {
case '"', '\\':
dst = append(dst, '\\', ch)
case '\b': // 0x08, backspace
dst = append(dst, '\\', 'b')
case '\f': // 0x0C, form feed
dst = append(dst, '\\', 'f')
case '\n': // 0x0A, line feed
dst = append(dst, '\\', 'n')
case '\r': // 0x0D, carriage return
dst = append(dst, '\\', 'r')
case '\t': // 0x09, horizontal tab
dst = append(dst, '\\', 't')
default:
// Other control characters (0x00-0x1F) use \u00XX escapes.
// We hardcode 00 since we only handle single byte control
// characters, then split the byte into two 4-bit halves;
// ch>>4 extracts the upper bits, and ch&0x0f extracts the lower
// bits. Each then indexes into hexDigits to produce the final
// two hex characters.
if ch < 0x20 {
dst = append(dst, '\\', 'u', '0', '0', hexDigits[ch>>4], hexDigits[ch&0x0f])
continue
}

dst = append(dst, ch)
}
}

dst = append(dst, '"')
return dst
}

func hasSpecialChars(s string) bool {
Expand Down
Loading