Skip to content

Commit

Permalink
fix: didn't consider json.Marshaler/Unmarshal when handling `json:"…
Browse files Browse the repository at this point in the history
…,string"` tag (#682)

Co-authored-by: liuqiang.06 <liuqiang.06@bytedance.com>
  • Loading branch information
AsterDY and liuq19 authored Aug 6, 2024
1 parent 1a0c001 commit bc420fc
Show file tree
Hide file tree
Showing 9 changed files with 326 additions and 33 deletions.
29 changes: 24 additions & 5 deletions internal/decoder/jitdec/assembler_regabi_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -972,11 +972,13 @@ var (

var (
_F_decodeJsonUnmarshaler obj.Addr
_F_decodeJsonUnmarshalerQuoted obj.Addr
_F_decodeTextUnmarshaler obj.Addr
)

func init() {
_F_decodeJsonUnmarshaler = jit.Func(decodeJsonUnmarshaler)
_F_decodeJsonUnmarshalerQuoted = jit.Func(decodeJsonUnmarshalerQuoted)
_F_decodeTextUnmarshaler = jit.Func(decodeTextUnmarshaler)
}

Expand Down Expand Up @@ -1061,14 +1063,15 @@ var (
_F_skip_number = jit.Imm(int64(native.S_skip_number))
)

func (self *_Assembler) unmarshal_json(t reflect.Type, deref bool) {
func (self *_Assembler) unmarshal_json(t reflect.Type, deref bool, f obj.Addr) {
self.call_sf(_F_skip_one) // CALL_SF skip_one
self.Emit("TESTQ", _AX, _AX) // TESTQ AX, AX
self.Sjmp("JS" , _LB_parsing_error_v) // JS _parse_error_v
self.Emit("MOVQ", _IC, _VAR_ic) // store for mismatche error skip
self.slice_from_r(_AX, 0) // SLICE_R AX, $0
self.Emit("MOVQ" , _DI, _ARG_sv_p) // MOVQ DI, sv.p
self.Emit("MOVQ" , _SI, _ARG_sv_n) // MOVQ SI, sv.n
self.unmarshal_func(t, _F_decodeJsonUnmarshaler, deref) // UNMARSHAL json, ${t}, ${deref}
self.unmarshal_func(t, f, deref) // UNMARSHAL json, ${t}, ${deref}
}

func (self *_Assembler) unmarshal_text(t reflect.Type, deref bool) {
Expand Down Expand Up @@ -1103,7 +1106,15 @@ func (self *_Assembler) unmarshal_func(t reflect.Type, fn obj.Addr, deref bool)
self.Emit("MOVQ" , _ARG_sv_n, _DI) // MOVQ sv.n, DI
self.call_go(fn) // CALL_GO ${fn}
self.Emit("TESTQ", _ET, _ET) // TESTQ ET, ET
self.Sjmp("JNZ" , _LB_error) // JNZ _error
self.Sjmp("JZ" , "_unmarshal_func_end_{n}") // JNZ _error
self.Emit("MOVQ", _I_json_MismatchTypeError, _CX) // MOVQ ET, VAR.et
self.Emit("CMPQ", _ET, _CX) // check if MismatchedError
self.Sjmp("JNE" , _LB_error)
self.Emit("MOVQ", jit.Type(t), _CX) // store current type
self.Emit("MOVQ", _CX, _VAR_et) // store current type
self.Emit("MOVQ", _VAR_ic, _IC) // recover the pos
self.Emit("XORL", _ET, _ET)
self.Link("_unmarshal_func_end_{n}")
}

/** Dynamic Decoding Routine **/
Expand Down Expand Up @@ -1774,11 +1785,19 @@ func (self *_Assembler) _asm_OP_struct_field(p *_Instr) {
}

func (self *_Assembler) _asm_OP_unmarshal(p *_Instr) {
self.unmarshal_json(p.vt(), true)
if iv := p.i64(); iv != 0 {
self.unmarshal_json(p.vt(), true, _F_decodeJsonUnmarshalerQuoted)
} else {
self.unmarshal_json(p.vt(), true, _F_decodeJsonUnmarshaler)
}
}

func (self *_Assembler) _asm_OP_unmarshal_p(p *_Instr) {
self.unmarshal_json(p.vt(), false)
if iv := p.i64(); iv != 0 {
self.unmarshal_json(p.vt(), false, _F_decodeJsonUnmarshalerQuoted)
} else {
self.unmarshal_json(p.vt(), false, _F_decodeJsonUnmarshaler)
}
}

func (self *_Assembler) _asm_OP_unmarshal_text(p *_Instr) {
Expand Down
77 changes: 61 additions & 16 deletions internal/decoder/jitdec/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,13 @@ func newInsVt(op _Op, vt reflect.Type) _Instr {
}
}

func newInsVtI(op _Op, vt reflect.Type, iv int) _Instr {
return _Instr {
u: packOp(op) | rt.PackInt(iv),
p: unsafe.Pointer(rt.UnpackType(vt)),
}
}

func newInsVf(op _Op, vf *caching.FieldMap) _Instr {
return _Instr {
u: packOp(op),
Expand Down Expand Up @@ -452,6 +459,10 @@ func (self *_Program) rtt(op _Op, vt reflect.Type) {
*self = append(*self, newInsVt(op, vt))
}

func (self *_Program) rtti(op _Op, vt reflect.Type, iv int) {
*self = append(*self, newInsVtI(op, vt, iv))
}

func (self *_Program) fmv(op _Op, vf *caching.FieldMap) {
*self = append(*self, newInsVf(op, vf))
}
Expand Down Expand Up @@ -527,35 +538,54 @@ func (self *_Compiler) compile(vt reflect.Type) (ret _Program, err error) {
return
}

func (self *_Compiler) checkMarshaler(p *_Program, vt reflect.Type) bool {
const (
checkMarshalerFlags_quoted = 1
)

func (self *_Compiler) checkMarshaler(p *_Program, vt reflect.Type, flags int, exec bool) bool {
pt := reflect.PtrTo(vt)

/* check for `json.Unmarshaler` with pointer receiver */
if pt.Implements(jsonUnmarshalerType) {
p.rtt(_OP_unmarshal_p, pt)
if exec {
p.add(_OP_lspace)
p.rtti(_OP_unmarshal_p, pt, flags)
}
return true
}

/* check for `json.Unmarshaler` */
if vt.Implements(jsonUnmarshalerType) {
p.add(_OP_lspace)
self.compileUnmarshalJson(p, vt)
if exec {
p.add(_OP_lspace)
self.compileUnmarshalJson(p, vt, flags)
}
return true
}

if flags == checkMarshalerFlags_quoted {
// text marshaler shouldn't be supported for quoted string
return false
}

/* check for `encoding.TextMarshaler` with pointer receiver */
if pt.Implements(encodingTextUnmarshalerType) {
p.add(_OP_lspace)
self.compileUnmarshalTextPtr(p, pt)
if exec {
p.add(_OP_lspace)
self.compileUnmarshalTextPtr(p, pt, flags)
}
return true
}

/* check for `encoding.TextUnmarshaler` */
if vt.Implements(encodingTextUnmarshalerType) {
p.add(_OP_lspace)
self.compileUnmarshalText(p, vt)
if exec {
p.add(_OP_lspace)
self.compileUnmarshalText(p, vt, flags)
}
return true
}

return false
}

Expand All @@ -567,7 +597,7 @@ func (self *_Compiler) compileOne(p *_Program, sp int, vt reflect.Type) {
return
}

if self.checkMarshaler(p, vt) {
if self.checkMarshaler(p, vt, 0, true) {
return
}

Expand Down Expand Up @@ -690,7 +720,7 @@ func (self *_Compiler) compilePtr(p *_Program, sp int, et reflect.Type) {

/* dereference all the way down */
for et.Kind() == reflect.Ptr {
if self.checkMarshaler(p, et) {
if self.checkMarshaler(p, et, 0, true) {
return
}
et = et.Elem()
Expand Down Expand Up @@ -938,7 +968,22 @@ end_of_object:
p.pin(skip)
}

func (self *_Compiler) compileStructFieldStrUnmarshal(p *_Program, vt reflect.Type) {
p.add(_OP_lspace)
n0 := p.pc()
p.add(_OP_is_null)
self.checkMarshaler(p, vt, checkMarshalerFlags_quoted, true)
p.pin(n0)
}

func (self *_Compiler) compileStructFieldStr(p *_Program, sp int, vt reflect.Type) {
// according to std, json.Unmarshaler should be called before stringize
// see https://github.com/bytedance/sonic/issues/670
if self.checkMarshaler(p, vt, checkMarshalerFlags_quoted, false) {
self.compileStructFieldStrUnmarshal(p, vt)
return
}

n1 := -1
ft := vt
sv := false
Expand Down Expand Up @@ -1106,7 +1151,7 @@ func (self *_Compiler) compileUnmarshalEnd(p *_Program, vt reflect.Type, i int)
p.pin(j)
}

func (self *_Compiler) compileUnmarshalJson(p *_Program, vt reflect.Type) {
func (self *_Compiler) compileUnmarshalJson(p *_Program, vt reflect.Type, flags int) {
i := p.pc()
v := _OP_unmarshal
p.add(_OP_is_null)
Expand All @@ -1117,11 +1162,11 @@ func (self *_Compiler) compileUnmarshalJson(p *_Program, vt reflect.Type) {
}

/* call the unmarshaler */
p.rtt(v, vt)
p.rtti(v, vt, flags)
self.compileUnmarshalEnd(p, vt, i)
}

func (self *_Compiler) compileUnmarshalText(p *_Program, vt reflect.Type) {
func (self *_Compiler) compileUnmarshalText(p *_Program, vt reflect.Type, iv int) {
i := p.pc()
v := _OP_unmarshal_text
p.add(_OP_is_null)
Expand All @@ -1134,15 +1179,15 @@ func (self *_Compiler) compileUnmarshalText(p *_Program, vt reflect.Type) {
}

/* call the unmarshaler */
p.rtt(v, vt)
p.rtti(v, vt, iv)
self.compileUnmarshalEnd(p, vt, i)
}

func (self *_Compiler) compileUnmarshalTextPtr(p *_Program, vt reflect.Type) {
func (self *_Compiler) compileUnmarshalTextPtr(p *_Program, vt reflect.Type, iv int) {
i := p.pc()
p.add(_OP_is_null)
p.chr(_OP_match_char, '"')
p.rtt(_OP_unmarshal_text_p, vt)
p.rtti(_OP_unmarshal_text_p, vt, iv)
p.pin(i)
}

Expand Down
1 change: 1 addition & 0 deletions internal/decoder/jitdec/generic_regabi_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ var (
_T_slice = jit.Type(reflect.TypeOf(([]interface{})(nil)))
_T_string = jit.Type(reflect.TypeOf(""))
_T_number = jit.Type(reflect.TypeOf(json.Number("")))
_T_miserr = jit.Type(reflect.TypeOf(MismatchTypeError{}))
_T_float64 = jit.Type(reflect.TypeOf(float64(0)))
)

Expand Down
7 changes: 7 additions & 0 deletions internal/decoder/jitdec/primitives.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ func decodeJsonUnmarshaler(vv interface{}, s string) error {
return vv.(json.Unmarshaler).UnmarshalJSON(rt.Str2Mem(s))
}

func decodeJsonUnmarshalerQuoted(vv interface{}, s string) error {
if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' {
return &MismatchTypeError{}
}
return vv.(json.Unmarshaler).UnmarshalJSON(rt.Str2Mem(s[1:len(s)-1]))
}

func decodeTextUnmarshaler(vv interface{}, s string) error {
return vv.(encoding.TextUnmarshaler).UnmarshalText(rt.Str2Mem(s))
}
39 changes: 38 additions & 1 deletion internal/decoder/optdec/compile_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,43 @@ func (c *compiler) compileIntStringOption(vt reflect.Type) decFunc {
panic("unreachable")
}

func isInteger(vt reflect.Type) bool {
switch vt.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, reflect.Uintptr, reflect.Int: return true
default: return false
}
}

func (c *compiler) assertStringOptTypes(vt reflect.Type) {
if c.depth > _CompileMaxDepth {
panic(*stackOverflow)
}

c.depth += 1
defer func () {
c.depth -= 1
}()

if isInteger(vt) {
return
}

switch vt.Kind() {
case reflect.String, reflect.Bool, reflect.Float32, reflect.Float64:
return
case reflect.Ptr: c.assertStringOptTypes(vt.Elem())
default:
panicForInvalidStrType(vt)
}
}

func (c *compiler) compileFieldStringOption(vt reflect.Type) decFunc {
c.assertStringOptTypes(vt)
unmDec := c.tryCompilePtrUnmarshaler(vt, true)
if unmDec != nil {
return unmDec
}

switch vt.Kind() {
case reflect.String:
if vt == jsonNumberType {
Expand Down Expand Up @@ -80,7 +116,8 @@ func (c *compiler) compileFieldStringOption(vt reflect.Type) decFunc {
deref: c.compileFieldStringOption(vt.Elem()),
}
default:
panic("string options should appliy only to fields of string, floating point, integer, or boolean types.")
panicForInvalidStrType(vt)
return nil
}
}

Expand Down
14 changes: 11 additions & 3 deletions internal/decoder/optdec/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ type compiler struct {
counts int
opts option.CompileOptions
namedPtr bool

}

func newCompiler() *compiler {
Expand Down Expand Up @@ -114,7 +113,7 @@ func (c *compiler) compile(vt reflect.Type) decFunc {
}
}

dec := c.tryCompilePtrUnmarshaler(vt)
dec := c.tryCompilePtrUnmarshaler(vt, false)
if dec != nil {
return dec
}
Expand Down Expand Up @@ -420,22 +419,31 @@ func (c *compiler) compileMapKey(vt reflect.Type) decKey {
}

// maybe vt is a named type, and not a pointer receiver, see issue 379
func (c *compiler) tryCompilePtrUnmarshaler(vt reflect.Type) decFunc {
func (c *compiler) tryCompilePtrUnmarshaler(vt reflect.Type, strOpt bool) decFunc {
pt := reflect.PtrTo(vt)

/* check for `json.Unmarshaler` with pointer receiver */
if pt.Implements(jsonUnmarshalerType) {
return &unmarshalJSONDecoder{
typ: rt.UnpackType(pt),
strOpt: strOpt,
}
}

/* check for `encoding.TextMarshaler` with pointer receiver */
if pt.Implements(encodingTextUnmarshalerType) {
/* TextUnmarshal not support ,strig tag */
if strOpt {
panicForInvalidStrType(vt)
}
return &unmarshalTextDecoder{
typ: rt.UnpackType(pt),
}
}

return nil
}

func panicForInvalidStrType(vt reflect.Type) {
panic(error_type(rt.UnpackType(vt)))
}
Loading

0 comments on commit bc420fc

Please sign in to comment.