diff --git a/go.mod b/go.mod index 26dbf5d706..c65c56d005 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/golang/protobuf go 1.9 -require google.golang.org/protobuf v0.0.0-20190620020611-d888139e7b59 +require google.golang.org/protobuf v0.0.0-20190717230113-f647c82cc3c7 diff --git a/go.sum b/go.sum index 46ffa01ea3..ae4387fee5 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,7 @@ github.com/golang/protobuf v1.2.1-0.20190516215712-ae2eaafab405/go.mod h1:UmP8hh github.com/golang/protobuf v1.2.1-0.20190523175523-a1331f0b4ab4/go.mod h1:G+fNMoyvKWZDB7PCDHF+dXbH9OeE3+JoozCd9V7i66U= github.com/golang/protobuf v1.2.1-0.20190605195750-76c9e09470ba/go.mod h1:S1YIJXvYHGRCG2UmZsOcElkAYfvZLg2sDRr9+Xu8JXU= github.com/golang/protobuf v1.2.1-0.20190617175902-f94016f5239f/go.mod h1:G+HpKX7pYZAVkElkAWZkr08MToW6pTp/vs+E9osFfbg= +github.com/golang/protobuf v1.2.1-0.20190620192300-1ee46dfd80dd/go.mod h1:+CMAsi9jpYf/wAltLUKlg++CWXqxCJyD8iLDbQONsJs= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= google.golang.org/protobuf v0.0.0-20190514172829-e89e6244e0e8/go.mod h1:791zQGC15vDqjpmPRn1uGPu5oHy/Jzw/Q1n5JsgIIcY= @@ -14,5 +15,6 @@ google.golang.org/protobuf v0.0.0-20190516215540-a95b29fbf623/go.mod h1:cWWmz5ls google.golang.org/protobuf v0.0.0-20190522194032-21ade498bd69/go.mod h1:cJytyYi/6qdwy/+gD49hmgHcwD7zhWxE/1KPEslaZ3M= google.golang.org/protobuf v0.0.0-20190605195314-89d49632e5cf/go.mod h1:Btug4TBaP5wNYcb2zGKDTS7WMcaPPLuqEAKfEAZWYbo= google.golang.org/protobuf v0.0.0-20190617175724-bd7b7a9e0c26/go.mod h1:+FOB8T5/Yw4ywwdyeun9/KlDeuwFYBkNQ+kVuwj9C94= -google.golang.org/protobuf v0.0.0-20190620020611-d888139e7b59 h1:8413FO+8BbzBumkamWfo1VRHJyPBKBUeerQodlLbb0g= google.golang.org/protobuf v0.0.0-20190620020611-d888139e7b59/go.mod h1:of3pt14Y+dOxz2tBOHXEoapPpKFC15/0zWhPAddkfsU= +google.golang.org/protobuf v0.0.0-20190717230113-f647c82cc3c7 h1:U6U+Hb+UKNGJB0eMAjUGk0wTmy73kduTIvdsEgA4Gf8= +google.golang.org/protobuf v0.0.0-20190717230113-f647c82cc3c7/go.mod h1:yGm7aNHn9Bp1NIvj6+CVUkcJshu+Usshfd3A+YxEuI8= diff --git a/internal/proto/properties.go b/internal/proto/properties.go index a3d024b6c3..d129248e01 100644 --- a/internal/proto/properties.go +++ b/internal/proto/properties.go @@ -11,6 +11,7 @@ import ( "strings" "sync" + protoV2 "google.golang.org/protobuf/proto" "google.golang.org/protobuf/runtime/protoimpl" ) @@ -251,10 +252,10 @@ func newProperties(t reflect.Type) *StructProperties { p.OrigName = tagOneof } - // Rename unrelated struct fields with the "XXX_" prefix since so much - // user code simply checks for this to exclude special fields. - if tagField == "" && tagOneof == "" && !strings.HasPrefix(p.Name, "XXX_") { - p.Name = "XXX_invalid_" + p.Name + // Rename unexported struct fields with the "XXX_" prefix since so much + // user code simply checks for this to exclude unrelated fields. + if f.PkgPath != "" { + p.Name = "XXX_" + p.Name } prop.Prop = append(prop.Prop, p) @@ -268,6 +269,11 @@ func newProperties(t reflect.Type) *StructProperties { if fn, ok := reflect.PtrTo(t).MethodByName("XXX_OneofWrappers"); ok { oneofWrappers = fn.Func.Call([]reflect.Value{reflect.Zero(fn.Type.In(0))})[0].Interface().([]interface{}) } + if m, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(protoV2.Message); ok { + if m, ok := m.ProtoReflect().(interface{ ProtoMessageInfo() *protoimpl.MessageInfo }); ok { + oneofWrappers = m.ProtoMessageInfo().OneofWrappers + } + } if len(oneofWrappers) > 0 { prop.OneofTypes = make(map[string]*OneofProperties) for _, wrapper := range oneofWrappers { diff --git a/jsonpb/jsonpb.go b/jsonpb/jsonpb.go index 6bcb71946e..0d817d5e17 100644 --- a/jsonpb/jsonpb.go +++ b/jsonpb/jsonpb.go @@ -23,6 +23,8 @@ import ( "time" "github.com/golang/protobuf/proto" + protoV2 "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" stpb "github.com/golang/protobuf/ptypes/struct" ) @@ -131,8 +133,31 @@ func (s int32Slice) Len() int { return len(s) } func (s int32Slice) Less(i, j int) bool { return s[i] < s[j] } func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -type wkt interface { - XXX_WellKnownType() string +func wellKnownType(v interface{}) string { + var s protoreflect.FullName + switch v := v.(type) { + case interface{ XXX_WellKnownType() string }: + return v.XXX_WellKnownType() + case protoreflect.Enum: + s = v.Descriptor().FullName() + case protoreflect.ProtoMessage: + s = v.ProtoReflect().Descriptor().FullName() + } + if s.Parent() == "google.protobuf" { + switch s.Name() { + case "Empty", + "Any", + "BoolValue", + "FloatValue", "DoubleValue", + "Int32Value", "Int64Value", + "UInt32Value", "UInt64Value", + "BytesValue", "StringValue", + "Duration", "Timestamp", + "NullValue", "Struct", "Value", "ListValue": + return string(s.Name()) + } + } + return "" } // marshalObject writes a struct to the Writer. @@ -165,71 +190,69 @@ func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeU s := reflect.ValueOf(v).Elem() // Handle well-known types. - if wkt, ok := v.(wkt); ok { - switch wkt.XXX_WellKnownType() { - case "DoubleValue", "FloatValue", "Int64Value", "UInt64Value", - "Int32Value", "UInt32Value", "BoolValue", "StringValue", "BytesValue": - // "Wrappers use the same representation in JSON - // as the wrapped primitive type, ..." - sprop := proto.GetProperties(s.Type()) - return m.marshalValue(out, sprop.Prop[0], s.Field(0), indent) - case "Any": - // Any is a bit more involved. - return m.marshalAny(out, v, indent) - case "Duration": - // "Generated output always contains 0, 3, 6, or 9 fractional digits, - // depending on required precision." - s, ns := s.Field(0).Int(), s.Field(1).Int() - if ns <= -secondInNanos || ns >= secondInNanos { - return fmt.Errorf("ns out of range (%v, %v)", -secondInNanos, secondInNanos) - } - if (s > 0 && ns < 0) || (s < 0 && ns > 0) { - return errors.New("signs of seconds and nanos do not match") - } - if s < 0 { - ns = -ns - } - x := fmt.Sprintf("%d.%09d", s, ns) - x = strings.TrimSuffix(x, "000") - x = strings.TrimSuffix(x, "000") - x = strings.TrimSuffix(x, ".000") - out.write(`"`) - out.write(x) - out.write(`s"`) - return out.err - case "Struct", "ListValue": - // Let marshalValue handle the `Struct.fields` map or the `ListValue.values` slice. - // TODO: pass the correct Properties if needed. - return m.marshalValue(out, &proto.Properties{}, s.Field(0), indent) - case "Timestamp": - // "RFC 3339, where generated output will always be Z-normalized - // and uses 0, 3, 6 or 9 fractional digits." - s, ns := s.Field(0).Int(), s.Field(1).Int() - if ns < 0 || ns >= secondInNanos { - return fmt.Errorf("ns out of range [0, %v)", secondInNanos) - } - t := time.Unix(s, ns).UTC() - // time.RFC3339Nano isn't exactly right (we need to get 3/6/9 fractional digits). - x := t.Format("2006-01-02T15:04:05.000000000") - x = strings.TrimSuffix(x, "000") - x = strings.TrimSuffix(x, "000") - x = strings.TrimSuffix(x, ".000") - out.write(`"`) - out.write(x) - out.write(`Z"`) - return out.err - case "Value": - // Value has a single oneof. - kind := s.Field(0) - if kind.IsNil() { - // "absence of any variant indicates an error" - return errors.New("nil Value") - } - // oneof -> *T -> T -> T.F - x := kind.Elem().Elem().Field(0) - // TODO: pass the correct Properties if needed. - return m.marshalValue(out, &proto.Properties{}, x, indent) + switch wellKnownType(v) { + case "DoubleValue", "FloatValue", "Int64Value", "UInt64Value", + "Int32Value", "UInt32Value", "BoolValue", "StringValue", "BytesValue": + // "Wrappers use the same representation in JSON + // as the wrapped primitive type, ..." + sprop := proto.GetProperties(s.Type()) + return m.marshalValue(out, sprop.Prop[1], s.Field(1), indent) + case "Any": + // Any is a bit more involved. + return m.marshalAny(out, v, indent) + case "Duration": + // "Generated output always contains 0, 3, 6, or 9 fractional digits, + // depending on required precision." + s, ns := s.Field(1).Int(), s.Field(2).Int() + if ns <= -secondInNanos || ns >= secondInNanos { + return fmt.Errorf("ns out of range (%v, %v)", -secondInNanos, secondInNanos) + } + if (s > 0 && ns < 0) || (s < 0 && ns > 0) { + return errors.New("signs of seconds and nanos do not match") + } + if s < 0 { + ns = -ns + } + x := fmt.Sprintf("%d.%09d", s, ns) + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, ".000") + out.write(`"`) + out.write(x) + out.write(`s"`) + return out.err + case "Struct", "ListValue": + // Let marshalValue handle the `Struct.fields` map or the `ListValue.values` slice. + // TODO: pass the correct Properties if needed. + return m.marshalValue(out, &proto.Properties{}, s.Field(1), indent) + case "Timestamp": + // "RFC 3339, where generated output will always be Z-normalized + // and uses 0, 3, 6 or 9 fractional digits." + s, ns := s.Field(1).Int(), s.Field(2).Int() + if ns < 0 || ns >= secondInNanos { + return fmt.Errorf("ns out of range [0, %v)", secondInNanos) + } + t := time.Unix(s, ns).UTC() + // time.RFC3339Nano isn't exactly right (we need to get 3/6/9 fractional digits). + x := t.Format("2006-01-02T15:04:05.000000000") + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, "000") + x = strings.TrimSuffix(x, ".000") + out.write(`"`) + out.write(x) + out.write(`Z"`) + return out.err + case "Value": + // Value has a single oneof. + kind := s.Field(1) + if kind.IsNil() { + // "absence of any variant indicates an error" + return errors.New("nil Value") } + // oneof -> *T -> T -> T.F + x := kind.Elem().Elem().Field(0) + // TODO: pass the correct Properties if needed. + return m.marshalValue(out, &proto.Properties{}, x, indent) } out.write("{") @@ -247,11 +270,11 @@ func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeU } for i := 0; i < s.NumField(); i++ { - value := s.Field(i) - valueField := s.Type().Field(i) - if strings.HasPrefix(valueField.Name, "XXX_") { + if f := s.Type().Field(i); strings.HasPrefix(f.Name, "XXX_") || f.PkgPath != "" { continue } + value := s.Field(i) + valueField := s.Type().Field(i) // IsNil will panic on most value kinds. switch value.Kind() { @@ -332,7 +355,11 @@ func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeU value := reflect.ValueOf(ext) var prop proto.Properties prop.Parse(desc.Tag) - prop.JSONName = fmt.Sprintf("[%s]", desc.Name) + name := desc.Name + if strings.HasSuffix(name, ".message_set_extension") && isMessageSet(s.Type()) { + name = strings.TrimSuffix(name, ".message_set_extension") + } + prop.JSONName = fmt.Sprintf("[%s]", name) if !firstField { m.writeSep(out) } @@ -366,8 +393,8 @@ func (m *Marshaler) marshalAny(out *errWriter, any proto.Message, indent string) // Otherwise, the value will be converted into a JSON object, // and the "@type" field will be inserted to indicate the actual data type." v := reflect.ValueOf(any).Elem() - turl := v.Field(0).String() - val := v.Field(1).Bytes() + turl := v.Field(1).String() + val := v.Field(2).Bytes() var msg proto.Message var err error @@ -384,7 +411,7 @@ func (m *Marshaler) marshalAny(out *errWriter, any proto.Message, indent string) return err } - if _, ok := msg.(wkt); ok { + if wellKnownType(msg) != "" { out.write("{") if m.Indent != "" { out.write("\n") @@ -489,12 +516,10 @@ func (m *Marshaler) marshalValue(out *errWriter, prop *proto.Properties, v refle // Handle well-known types. // Most are handled up in marshalObject (because 99% are messages). - if wkt, ok := v.Interface().(wkt); ok { - switch wkt.XXX_WellKnownType() { - case "NullValue": - out.write("null") - return out.err - } + switch wellKnownType(v.Interface()) { + case "NullValue": + out.write("null") + return out.err } // Handle enumerations. @@ -695,152 +720,150 @@ func (u *Unmarshaler) unmarshalValue(target reflect.Value, inputValue json.RawMe } // Handle well-known types that are not pointers. - if w, ok := target.Addr().Interface().(wkt); ok { - switch w.XXX_WellKnownType() { - case "DoubleValue", "FloatValue", "Int64Value", "UInt64Value", - "Int32Value", "UInt32Value", "BoolValue", "StringValue", "BytesValue": - return u.unmarshalValue(target.Field(0), inputValue, prop) - case "Any": - // Use json.RawMessage pointer type instead of value to support pre-1.8 version. - // 1.8 changed RawMessage.MarshalJSON from pointer type to value type, see - // https://github.com/golang/go/issues/14493 - var jsonFields map[string]*json.RawMessage - if err := json.Unmarshal(inputValue, &jsonFields); err != nil { - return err - } - - val, ok := jsonFields["@type"] - if !ok || val == nil { - return errors.New("Any JSON doesn't have '@type'") - } + switch wellKnownType(target.Addr().Interface()) { + case "DoubleValue", "FloatValue", "Int64Value", "UInt64Value", + "Int32Value", "UInt32Value", "BoolValue", "StringValue", "BytesValue": + return u.unmarshalValue(target.Field(1), inputValue, prop) + case "Any": + // Use json.RawMessage pointer type instead of value to support pre-1.8 version. + // 1.8 changed RawMessage.MarshalJSON from pointer type to value type, see + // https://github.com/golang/go/issues/14493 + var jsonFields map[string]*json.RawMessage + if err := json.Unmarshal(inputValue, &jsonFields); err != nil { + return err + } - var turl string - if err := json.Unmarshal([]byte(*val), &turl); err != nil { - return fmt.Errorf("can't unmarshal Any's '@type': %q", *val) - } - target.Field(0).SetString(turl) - - var m proto.Message - var err error - if u.AnyResolver != nil { - m, err = u.AnyResolver.Resolve(turl) - } else { - m, err = defaultResolveAny(turl) - } - if err != nil { - return err - } + val, ok := jsonFields["@type"] + if !ok || val == nil { + return errors.New("Any JSON doesn't have '@type'") + } - if _, ok := m.(wkt); ok { - val, ok := jsonFields["value"] - if !ok { - return errors.New("Any JSON doesn't have 'value'") - } + var turl string + if err := json.Unmarshal([]byte(*val), &turl); err != nil { + return fmt.Errorf("can't unmarshal Any's '@type': %q", *val) + } + target.Field(1).SetString(turl) - if err := u.unmarshalValue(reflect.ValueOf(m).Elem(), *val, nil); err != nil { - return fmt.Errorf("can't unmarshal Any nested proto %T: %v", m, err) - } - } else { - delete(jsonFields, "@type") - nestedProto, err := json.Marshal(jsonFields) - if err != nil { - return fmt.Errorf("can't generate JSON for Any's nested proto to be unmarshaled: %v", err) - } + var m proto.Message + var err error + if u.AnyResolver != nil { + m, err = u.AnyResolver.Resolve(turl) + } else { + m, err = defaultResolveAny(turl) + } + if err != nil { + return err + } - if err = u.unmarshalValue(reflect.ValueOf(m).Elem(), nestedProto, nil); err != nil { - return fmt.Errorf("can't unmarshal Any nested proto %T: %v", m, err) - } + if wellKnownType(m) != "" { + val, ok := jsonFields["value"] + if !ok { + return errors.New("Any JSON doesn't have 'value'") } - b, err := proto.Marshal(m) - if err != nil { - return fmt.Errorf("can't marshal proto %T into Any.Value: %v", m, err) + if err := u.unmarshalValue(reflect.ValueOf(m).Elem(), *val, nil); err != nil { + return fmt.Errorf("can't unmarshal Any nested proto %T: %v", m, err) } - target.Field(1).SetBytes(b) - - return nil - case "Duration": - unq, err := unquote(string(inputValue)) + } else { + delete(jsonFields, "@type") + nestedProto, err := json.Marshal(jsonFields) if err != nil { - return err + return fmt.Errorf("can't generate JSON for Any's nested proto to be unmarshaled: %v", err) } - d, err := time.ParseDuration(unq) - if err != nil { - return fmt.Errorf("bad Duration: %v", err) + if err = u.unmarshalValue(reflect.ValueOf(m).Elem(), nestedProto, nil); err != nil { + return fmt.Errorf("can't unmarshal Any nested proto %T: %v", m, err) } + } - ns := d.Nanoseconds() - s := ns / 1e9 - ns %= 1e9 - target.Field(0).SetInt(s) - target.Field(1).SetInt(ns) - return nil - case "Timestamp": - unq, err := unquote(string(inputValue)) - if err != nil { - return err - } + b, err := proto.Marshal(m) + if err != nil { + return fmt.Errorf("can't marshal proto %T into Any.Value: %v", m, err) + } + target.Field(2).SetBytes(b) - t, err := time.Parse(time.RFC3339Nano, unq) - if err != nil { - return fmt.Errorf("bad Timestamp: %v", err) - } + return nil + case "Duration": + unq, err := unquote(string(inputValue)) + if err != nil { + return err + } - target.Field(0).SetInt(t.Unix()) - target.Field(1).SetInt(int64(t.Nanosecond())) - return nil - case "Struct": - var m map[string]json.RawMessage - if err := json.Unmarshal(inputValue, &m); err != nil { - return fmt.Errorf("bad StructValue: %v", err) - } + d, err := time.ParseDuration(unq) + if err != nil { + return fmt.Errorf("bad Duration: %v", err) + } - target.Field(0).Set(reflect.ValueOf(map[string]*stpb.Value{})) - for k, jv := range m { - pv := &stpb.Value{} - if err := u.unmarshalValue(reflect.ValueOf(pv).Elem(), jv, prop); err != nil { - return fmt.Errorf("bad value in StructValue for key %q: %v", k, err) - } - target.Field(0).SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(pv)) - } - return nil - case "ListValue": - var s []json.RawMessage - if err := json.Unmarshal(inputValue, &s); err != nil { - return fmt.Errorf("bad ListValue: %v", err) - } + ns := d.Nanoseconds() + s := ns / 1e9 + ns %= 1e9 + target.Field(1).SetInt(s) + target.Field(2).SetInt(ns) + return nil + case "Timestamp": + unq, err := unquote(string(inputValue)) + if err != nil { + return err + } - target.Field(0).Set(reflect.ValueOf(make([]*stpb.Value, len(s)))) - for i, sv := range s { - if err := u.unmarshalValue(target.Field(0).Index(i), sv, prop); err != nil { - return err - } + t, err := time.Parse(time.RFC3339Nano, unq) + if err != nil { + return fmt.Errorf("bad Timestamp: %v", err) + } + + target.Field(1).SetInt(t.Unix()) + target.Field(2).SetInt(int64(t.Nanosecond())) + return nil + case "Struct": + var m map[string]json.RawMessage + if err := json.Unmarshal(inputValue, &m); err != nil { + return fmt.Errorf("bad StructValue: %v", err) + } + + target.Field(1).Set(reflect.ValueOf(map[string]*stpb.Value{})) + for k, jv := range m { + pv := &stpb.Value{} + if err := u.unmarshalValue(reflect.ValueOf(pv).Elem(), jv, prop); err != nil { + return fmt.Errorf("bad value in StructValue for key %q: %v", k, err) } - return nil - case "Value": - ivStr := string(inputValue) - if ivStr == "null" { - target.Field(0).Set(reflect.ValueOf(&stpb.Value_NullValue{})) - } else if v, err := strconv.ParseFloat(ivStr, 0); err == nil { - target.Field(0).Set(reflect.ValueOf(&stpb.Value_NumberValue{v})) - } else if v, err := unquote(ivStr); err == nil { - target.Field(0).Set(reflect.ValueOf(&stpb.Value_StringValue{v})) - } else if v, err := strconv.ParseBool(ivStr); err == nil { - target.Field(0).Set(reflect.ValueOf(&stpb.Value_BoolValue{v})) - } else if err := json.Unmarshal(inputValue, &[]json.RawMessage{}); err == nil { - lv := &stpb.ListValue{} - target.Field(0).Set(reflect.ValueOf(&stpb.Value_ListValue{lv})) - return u.unmarshalValue(reflect.ValueOf(lv).Elem(), inputValue, prop) - } else if err := json.Unmarshal(inputValue, &map[string]json.RawMessage{}); err == nil { - sv := &stpb.Struct{} - target.Field(0).Set(reflect.ValueOf(&stpb.Value_StructValue{sv})) - return u.unmarshalValue(reflect.ValueOf(sv).Elem(), inputValue, prop) - } else { - return fmt.Errorf("unrecognized type for Value %q", ivStr) + target.Field(1).SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(pv)) + } + return nil + case "ListValue": + var s []json.RawMessage + if err := json.Unmarshal(inputValue, &s); err != nil { + return fmt.Errorf("bad ListValue: %v", err) + } + + target.Field(1).Set(reflect.ValueOf(make([]*stpb.Value, len(s)))) + for i, sv := range s { + if err := u.unmarshalValue(target.Field(1).Index(i), sv, prop); err != nil { + return err } - return nil } + return nil + case "Value": + ivStr := string(inputValue) + if ivStr == "null" { + target.Field(1).Set(reflect.ValueOf(&stpb.Value_NullValue{})) + } else if v, err := strconv.ParseFloat(ivStr, 0); err == nil { + target.Field(1).Set(reflect.ValueOf(&stpb.Value_NumberValue{v})) + } else if v, err := unquote(ivStr); err == nil { + target.Field(1).Set(reflect.ValueOf(&stpb.Value_StringValue{v})) + } else if v, err := strconv.ParseBool(ivStr); err == nil { + target.Field(1).Set(reflect.ValueOf(&stpb.Value_BoolValue{v})) + } else if err := json.Unmarshal(inputValue, &[]json.RawMessage{}); err == nil { + lv := &stpb.ListValue{} + target.Field(1).Set(reflect.ValueOf(&stpb.Value_ListValue{lv})) + return u.unmarshalValue(reflect.ValueOf(lv).Elem(), inputValue, prop) + } else if err := json.Unmarshal(inputValue, &map[string]json.RawMessage{}); err == nil { + sv := &stpb.Struct{} + target.Field(1).Set(reflect.ValueOf(&stpb.Value_StructValue{sv})) + return u.unmarshalValue(reflect.ValueOf(sv).Elem(), inputValue, prop) + } else { + return fmt.Errorf("unrecognized type for Value %q", ivStr) + } + return nil } // Handle enums, which have an underlying type of int32, @@ -899,7 +922,7 @@ func (u *Unmarshaler) unmarshalValue(target reflect.Value, inputValue json.RawMe sprops := proto.GetProperties(targetType) for i := 0; i < target.NumField(); i++ { ft := target.Type().Field(i) - if strings.HasPrefix(ft.Name, "XXX_") { + if f := ft; strings.HasPrefix(f.Name, "XXX_") || f.PkgPath != "" { continue } @@ -944,6 +967,23 @@ func (u *Unmarshaler) unmarshalValue(target reflect.Value, inputValue json.RawMe return err } } + if isMessageSet(target.Type()) { + for _, ext := range proto.RegisteredExtensions(ep) { + name := fmt.Sprintf("[%s]", strings.TrimSuffix(ext.Name, ".message_set_extension")) + raw, ok := jsonFields[name] + if !ok { + continue + } + delete(jsonFields, name) + nv := reflect.New(reflect.TypeOf(ext.ExtensionType).Elem()) + if err := u.unmarshalValue(nv.Elem(), raw, nil); err != nil { + return err + } + if err := proto.SetExtension(ep, ext, nv.Interface()); err != nil { + return err + } + } + } } } if !u.AllowUnknownFields && len(jsonFields) > 0 { @@ -1118,7 +1158,7 @@ func checkRequiredFields(pb proto.Message) error { // When an Any message is being unmarshaled, the code will have invoked proto.Marshal on the // embedded message to store the serialized message in Any.Value field, and that should have // returned an error if a required field is not set. - if _, ok := pb.(wkt); ok { + if wellKnownType(pb) != "" { return nil } @@ -1133,17 +1173,11 @@ func checkRequiredFields(pb proto.Message) error { } for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - sfield := v.Type().Field(i) - - if sfield.PkgPath != "" { - // blank PkgPath means the field is exported; skip if not exported - continue - } - - if strings.HasPrefix(sfield.Name, "XXX_") { + if f := v.Type().Field(i); strings.HasPrefix(f.Name, "XXX_") || f.PkgPath != "" { continue } + field := v.Field(i) + sfield := v.Type().Field(i) // Oneof field is an interface implemented by wrapper structs containing the actual oneof // field, i.e. an interface containing &T{real_value}. @@ -1240,3 +1274,14 @@ func checkRequiredFieldsInValue(v reflect.Value) error { } return nil } + +// isMessageSet determines whether t is a MessageSet message, +// where t must be a named struct type. +func isMessageSet(t reflect.Type) bool { + if m, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(protoV2.Message); ok { + md := m.ProtoReflect().Descriptor() + xmd, ok := md.(interface{ IsMessageSet() bool }) + return ok && xmd.IsMessageSet() + } + return false +} diff --git a/jsonpb/jsonpb_test.go b/jsonpb/jsonpb_test.go index 93ed06d3a1..d58a803749 100644 --- a/jsonpb/jsonpb_test.go +++ b/jsonpb/jsonpb_test.go @@ -506,7 +506,7 @@ func TestMarshaling(t *testing.T) { if err != nil { t.Errorf("%s: marshaling error: %v", tt.desc, err) } else if tt.json != json { - t.Errorf("%s: got [%v] want [%v]", tt.desc, json, tt.json) + t.Errorf("%s:\ngot: %v\nwant: %v", tt.desc, json, tt.json) } } } diff --git a/proto/all_test.go b/proto/all_test.go index ad586f6484..2c8e3aadc0 100644 --- a/proto/all_test.go +++ b/proto/all_test.go @@ -21,6 +21,7 @@ import ( . "github.com/golang/protobuf/proto" pb3 "github.com/golang/protobuf/proto/proto3_proto" . "github.com/golang/protobuf/proto/test_proto" + tpb "google.golang.org/protobuf/types/known/timestamppb" ) var globalO *Buffer @@ -2328,6 +2329,15 @@ func TestRequired(t *testing.T) { } } +func TestUnknownV2(t *testing.T) { + m := new(tpb.Timestamp) + m.ProtoReflect().SetUnknown([]byte("\x92\x4d\x12unknown field 1234")) + got := CompactTextString(m) + if !strings.Contains(got, "unknown field 1234") { + t.Errorf("got %q, want contains %q", got, "unknown field 1234") + } +} + // Benchmarks func testMsg() *GoTest { diff --git a/proto/clone.go b/proto/clone.go index cd1a2e20a4..c753078df4 100644 --- a/proto/clone.go +++ b/proto/clone.go @@ -77,7 +77,7 @@ func mergeStruct(out, in reflect.Value) { sprop := GetProperties(in.Type()) for i := 0; i < in.NumField(); i++ { f := in.Type().Field(i) - if strings.HasPrefix(f.Name, "XXX_") { + if strings.HasPrefix(f.Name, "XXX_") || f.PkgPath != "" { continue } mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i]) @@ -90,13 +90,13 @@ func mergeStruct(out, in reflect.Value) { } } - uf := in.FieldByName("XXX_unrecognized") + uf := unknownFieldsValue(in) if !uf.IsValid() { return } uin := uf.Bytes() if len(uin) > 0 { - out.FieldByName("XXX_unrecognized").SetBytes(append([]byte(nil), uin...)) + unknownFieldsValue(out).SetBytes(append([]byte(nil), uin...)) } } diff --git a/proto/discard.go b/proto/discard.go index c53d4b177c..8e61be3648 100644 --- a/proto/discard.go +++ b/proto/discard.go @@ -127,11 +127,11 @@ func (di *discardInfo) computeDiscardInfo() { for i := 0; i < n; i++ { f := t.Field(i) - if strings.HasPrefix(f.Name, "XXX_") { + if strings.HasPrefix(f.Name, "XXX_") || f.PkgPath != "" { continue } - dfi := discardFieldInfo{field: toField(&f)} + dfi := discardFieldInfo{field: toField(&f, nil)} tf := f.Type // Unwrap tf to get its most basic type. @@ -219,12 +219,19 @@ func (di *discardInfo) computeDiscardInfo() { di.fields = append(di.fields, dfi) } + expFunc := exporterFunc(t) di.unrecognized = invalidField if f, ok := t.FieldByName("XXX_unrecognized"); ok { if f.Type != reflect.TypeOf([]byte{}) { panic("expected XXX_unrecognized to be of type []byte") } - di.unrecognized = toField(&f) + di.unrecognized = toField(&f, nil) + } + if f, ok := t.FieldByName("unknownFields"); ok { + if f.Type != reflect.TypeOf([]byte{}) { + panic("expected unknownFields to be of type []byte") + } + di.unrecognized = toField(&f, expFunc) } atomic.StoreInt32(&di.initialized, 1) @@ -243,7 +250,7 @@ func discardLegacy(m Message) { for i := 0; i < v.NumField(); i++ { f := t.Field(i) - if strings.HasPrefix(f.Name, "XXX_") { + if strings.HasPrefix(f.Name, "XXX_") || f.PkgPath != "" { continue } vf := v.Field(i) @@ -308,9 +315,9 @@ func discardLegacy(m Message) { } } - if vf := v.FieldByName("XXX_unrecognized"); vf.IsValid() { + if vf := unknownFieldsValue(v); vf.IsValid() { if vf.Type() != reflect.TypeOf([]byte{}) { - panic("expected XXX_unrecognized to be of type []byte") + panic("expected unknown fields to be of type []byte") } vf.Set(reflect.ValueOf([]byte(nil))) } diff --git a/proto/equal.go b/proto/equal.go index ceebce79f0..b2a69e15ed 100644 --- a/proto/equal.go +++ b/proto/equal.go @@ -72,7 +72,7 @@ func equalStruct(v1, v2 reflect.Value) bool { sprop := GetProperties(v1.Type()) for i := 0; i < v1.NumField(); i++ { f := v1.Type().Field(i) - if strings.HasPrefix(f.Name, "XXX_") { + if strings.HasPrefix(f.Name, "XXX_") || f.PkgPath != "" { continue } f1, f2 := v1.Field(i), v2.Field(i) @@ -91,8 +91,8 @@ func equalStruct(v1, v2 reflect.Value) bool { } } - if em1 := v1.FieldByName("XXX_InternalExtensions"); em1.IsValid() { - em2 := v2.FieldByName("XXX_InternalExtensions") + if em1 := extensionFieldsValue(v1); em1.IsValid() { + em2 := extensionFieldsValue(v2) m1 := extensionFieldsOf(em1.Addr().Interface()) m2 := extensionFieldsOf(em2.Addr().Interface()) if !equalExtensions(v1.Type(), m1, m2) { @@ -100,23 +100,15 @@ func equalStruct(v1, v2 reflect.Value) bool { } } - if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() { - em2 := v2.FieldByName("XXX_extensions") - m1 := extensionFieldsOf(em1.Addr().Interface()) - m2 := extensionFieldsOf(em2.Addr().Interface()) - if !equalExtensions(v1.Type(), m1, m2) { + if uf := unknownFieldsValue(v1); uf.IsValid() { + u1 := uf.Bytes() + u2 := unknownFieldsValue(v2).Bytes() + if !bytes.Equal(u1, u2) { return false } } - uf := v1.FieldByName("XXX_unrecognized") - if !uf.IsValid() { - return true - } - - u1 := uf.Bytes() - u2 := v2.FieldByName("XXX_unrecognized").Bytes() - return bytes.Equal(u1, u2) + return true } // v1 and v2 are known to have the same type. diff --git a/proto/extensions.go b/proto/extensions.go index a85ee41a57..adf24b9ea0 100644 --- a/proto/extensions.go +++ b/proto/extensions.go @@ -69,11 +69,8 @@ func extendable(p interface{}) (*extensionMap, error) { v := reflect.ValueOf(p) if v.Kind() == reflect.Ptr && !v.IsNil() { v = v.Elem() - if v := v.FieldByName("XXX_InternalExtensions"); v.IsValid() { - return extensionFieldsOf(v.Addr().Interface()), nil - } - if v := v.FieldByName("XXX_extensions"); v.IsValid() { - return extensionFieldsOf(v.Addr().Interface()), nil + if vf := extensionFieldsValue(v); vf.IsValid() { + return extensionFieldsOf(vf.Addr().Interface()), nil } } } @@ -102,7 +99,7 @@ func SetRawExtension(base Message, id int32, b []byte) { if !v.IsValid() || v.Kind() != reflect.Ptr || v.IsNil() || v.Elem().Kind() != reflect.Struct { return } - v = v.Elem().FieldByName("XXX_unrecognized") + v = unknownFieldsValue(v.Elem()) if !v.IsValid() { return } @@ -208,7 +205,7 @@ func HasExtension(pb Message, extension *ExtensionDesc) bool { } // Check whether this field exists in raw form. - unrecognized := reflect.ValueOf(pb).Elem().FieldByName("XXX_unrecognized") + unrecognized := unknownFieldsValue(reflect.ValueOf(pb).Elem()) fnum := protoreflect.FieldNumber(extension.Field) for b := unrecognized.Bytes(); len(b) > 0; { got, _, n := wire.ConsumeField(b) @@ -250,7 +247,7 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) { return nil, err } - unrecognized := reflect.ValueOf(pb).Elem().FieldByName("XXX_unrecognized") + unrecognized := unknownFieldsValue(reflect.ValueOf(pb).Elem()) var out []byte fnum := protoreflect.FieldNumber(extension.Field) for b := unrecognized.Bytes(); len(b) > 0; { @@ -416,7 +413,7 @@ func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) { return true }) - unrecognized := reflect.ValueOf(pb).Elem().FieldByName("XXX_unrecognized") + unrecognized := unknownFieldsValue(reflect.ValueOf(pb).Elem()) if b := unrecognized.Bytes(); len(b) > 0 { fieldNums := make(map[int32]bool) for len(b) > 0 { diff --git a/proto/extensions_test.go b/proto/extensions_test.go index 044aaf4a64..e06812a0f1 100644 --- a/proto/extensions_test.go +++ b/proto/extensions_test.go @@ -268,47 +268,53 @@ func TestGetExtensionDefaults(t *testing.T) { {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE}, } - checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error { - val, err := proto.GetExtension(msg, test.ext) - if err != nil { - if valWant != nil { - return fmt.Errorf("GetExtension(): %s", err) - } - if want := proto.ErrMissingExtension; err != want { - return fmt.Errorf("Unexpected error: got %v, want %v", err, want) + checkVal := func(t *testing.T, name string, test testcase, msg *pb.DefaultsMessage, valWant interface{}) { + t.Run(name, func(t *testing.T) { + val, err := proto.GetExtension(msg, test.ext) + if err != nil { + if valWant != nil { + t.Errorf("GetExtension(): %s", err) + return + } + if want := proto.ErrMissingExtension; err != want { + t.Errorf("Unexpected error: got %v, want %v", err, want) + return + } + return } - return nil - } - - // All proto2 extension values are either a pointer to a value or a slice of values. - ty := reflect.TypeOf(val) - tyWant := reflect.TypeOf(test.ext.ExtensionType) - if got, want := ty, tyWant; got != want { - return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want) - } - tye := ty.Elem() - tyeWant := tyWant.Elem() - if got, want := tye, tyeWant; got != want { - return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want) - } - // Check the name of the type of the value. - // If it is an enum it will be type int32 with the name of the enum. - if got, want := tye.Name(), tye.Name(); got != want { - return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want) - } + // All proto2 extension values are either a pointer to a value or a slice of values. + ty := reflect.TypeOf(val) + tyWant := reflect.TypeOf(test.ext.ExtensionType) + if got, want := ty, tyWant; got != want { + t.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want) + return + } + tye := ty.Elem() + tyeWant := tyWant.Elem() + if got, want := tye, tyeWant; got != want { + t.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want) + return + } - // Check that value is what we expect. - // If we have a pointer in val, get the value it points to. - valExp := val - if ty.Kind() == reflect.Ptr { - valExp = reflect.ValueOf(val).Elem().Interface() - } - if got, want := valExp, valWant; !reflect.DeepEqual(got, want) { - return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want) - } + // Check the name of the type of the value. + // If it is an enum it will be type int32 with the name of the enum. + if got, want := tye.Name(), tye.Name(); got != want { + t.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want) + return + } - return nil + // Check that value is what we expect. + // If we have a pointer in val, get the value it points to. + valExp := val + if ty.Kind() == reflect.Ptr { + valExp = reflect.ValueOf(val).Elem().Interface() + } + if got, want := valExp, valWant; !reflect.DeepEqual(got, want) { + t.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want) + return + } + }) } setTo := func(test testcase) interface{} { @@ -326,27 +332,18 @@ func TestGetExtensionDefaults(t *testing.T) { name := test.ext.Name // Check the initial value. - if err := checkVal(test, msg, test.def); err != nil { - t.Errorf("%s: %v", name, err) - } + checkVal(t, name+"/initial", test, msg, test.def) // Set the per-type value and check value. - name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want) if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil { t.Errorf("%s: SetExtension(): %v", name, err) continue } - if err := checkVal(test, msg, test.want); err != nil { - t.Errorf("%s: %v", name, err) - continue - } + checkVal(t, name+"/set", test, msg, test.want) // Set and check the value. - name += " (cleared)" proto.ClearExtension(msg, test.ext) - if err := checkVal(test, msg, test.def); err != nil { - t.Errorf("%s: %v", name, err) - } + checkVal(t, name+"/cleared", test, msg, test.def) } } diff --git a/proto/lib.go b/proto/lib.go index c8afe8a541..ada7b33fee 100644 --- a/proto/lib.go +++ b/proto/lib.go @@ -15,7 +15,9 @@ import ( "strconv" "sync" + protoV2 "google.golang.org/protobuf/proto" "google.golang.org/protobuf/runtime/protoiface" + "google.golang.org/protobuf/runtime/protoimpl" ) // requiredNotSetError is an error type returned by either Marshal or Unmarshal. @@ -95,6 +97,72 @@ type ( } ) +// oneofWrappers returns a list of oneof wrappers for t, +// which must be a named struct type. +func oneofWrappers(t reflect.Type) []interface{} { + var oos []interface{} + switch m := reflect.Zero(reflect.PtrTo(t)).Interface().(type) { + case oneofFuncsIface: + _, _, _, oos = m.XXX_OneofFuncs() + case oneofWrappersIface: + oos = m.XXX_OneofWrappers() + case protoV2.Message: + if m, ok := m.ProtoReflect().(interface{ ProtoMessageInfo() *protoimpl.MessageInfo }); ok { + oos = m.ProtoMessageInfo().OneofWrappers + } + } + return oos +} + +// unknownFieldsValue retrieves the value for unknown fields from v, +// which must be a name struct type. +func unknownFieldsValue(v reflect.Value) reflect.Value { + if vf := v.FieldByName("XXX_unrecognized"); vf.IsValid() { + return vf + } + if vf := fieldByName(v, "unknownFields"); vf.IsValid() { + return vf + } + return reflect.Value{} +} + +// extensionFieldsValue retrieves the value for extension fields from v, +// which must be a name struct type. +func extensionFieldsValue(v reflect.Value) reflect.Value { + if vf := v.FieldByName("XXX_InternalExtensions"); vf.IsValid() { + return vf + } + if vf := v.FieldByName("XXX_extensions"); vf.IsValid() { + return vf + } + if vf := fieldByName(v, "extensionFields"); vf.IsValid() { + return vf + } + return reflect.Value{} +} + +// exporterFunc retrieves the field exporter function for t, +// which must be a named struct type. +func exporterFunc(t reflect.Type) func(interface{}, int) interface{} { + if m, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(protoV2.Message); ok { + if m, ok := m.ProtoReflect().(interface{ ProtoMessageInfo() *protoimpl.MessageInfo }); ok { + return m.ProtoMessageInfo().Exporter + } + } + return nil +} + +// isMessageSet determines whether t is a MessageSet message, +// where t must be a named struct type. +func isMessageSet(t reflect.Type) bool { + if m, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(protoV2.Message); ok { + md := m.ProtoReflect().Descriptor() + xmd, ok := md.(interface{ IsMessageSet() bool }) + return ok && xmd.IsMessageSet() + } + return false +} + // A Buffer is a buffer manager for marshaling and unmarshaling // protocol buffers. It may be reused between invocations to // reduce memory usage. It is not necessary to use a Buffer; diff --git a/proto/message_set.go b/proto/message_set.go index 8626382a2b..201030d467 100644 --- a/proto/message_set.go +++ b/proto/message_set.go @@ -122,7 +122,7 @@ func unmarshalMessageSet(buf []byte, mi Message, exts interface{}) error { if err := Unmarshal(buf, ms); err != nil { return err } - unrecognized := reflect.ValueOf(mi).Elem().FieldByName("XXX_unrecognized").Addr().Interface().(*[]byte) + unrecognized := unknownFieldsValue(reflect.ValueOf(mi).Elem()).Addr().Interface().(*[]byte) for _, item := range ms.Item { id := protoreflect.FieldNumber(*item.TypeId) diff --git a/proto/pointer_reflect.go b/proto/pointer_reflect.go index 5ee1a0a91d..0b764bc682 100644 --- a/proto/pointer_reflect.go +++ b/proto/pointer_reflect.go @@ -13,6 +13,8 @@ package proto import ( "reflect" "sync" + "unicode" + "unicode/utf8" ) const unsafeAllowed = false @@ -20,21 +22,35 @@ const unsafeAllowed = false // A field identifies a field in a struct, accessible from a pointer. // In this implementation, a field is identified by the sequence of field indices // passed to reflect's FieldByIndex. -type field []int +type field struct { + index int + export exporter +} + +type exporter = func(interface{}, int) interface{} // toField returns a field equivalent to the given reflect field. -func toField(f *reflect.StructField) field { - return f.Index +func toField(f *reflect.StructField, x exporter) field { + if len(f.Index) != 1 { + panic("embedded structs are not supported") + } + if f.PkgPath == "" { + return field{index: f.Index[0]} // field is already exported + } + if x == nil { + panic("exporter must be provided for unexported field: " + f.Name) + } + return field{index: f.Index[0], export: x} } // invalidField is an invalid field identifier. -var invalidField = field(nil) +var invalidField = field{index: -1} // zeroField is a noop when calling pointer.offset. -var zeroField = field([]int{}) +var zeroField = field{index: 0} // IsValid reports whether the field identifier is valid. -func (f field) IsValid() bool { return f != nil } +func (f field) IsValid() bool { return f.index >= 0 } // The pointer type is for the table-driven decoder. // The implementation here uses a reflect.Value of pointer type to @@ -70,7 +86,12 @@ func valToPointer(v reflect.Value) pointer { // offset converts from a pointer to a structure to a pointer to // one of its fields. func (p pointer) offset(f field) pointer { - return pointer{v: p.v.Elem().FieldByIndex(f).Addr()} + if f.export != nil { + if v := reflect.ValueOf(f.export(p.v.Interface(), f.index)); v.IsValid() { + return pointer{v: v} + } + } + return pointer{v: p.v.Elem().Field(f.index).Addr()} } func (p pointer) isNil() bool { @@ -331,3 +352,20 @@ func atomicStoreDiscardInfo(p **discardInfo, v *discardInfo) { } var atomicLock sync.Mutex + +// fieldByName is equivalent to reflect.Value.FieldByName, but is able to +// descend into unexported fields for prop +func fieldByName(v reflect.Value, s string) reflect.Value { + if r, _ := utf8.DecodeRuneInString(s); unicode.IsUpper(r) { + return v.FieldByName(s) + } + t := v.Type() + if x := exporterFunc(t); x != nil { + sf, ok := t.FieldByName(s) + if ok { + vi := x(v.Addr().Interface(), sf.Index[0]) + return reflect.ValueOf(vi).Elem() + } + } + return v.FieldByName(s) +} diff --git a/proto/pointer_unsafe.go b/proto/pointer_unsafe.go index 52f1c92b55..20269f45a8 100644 --- a/proto/pointer_unsafe.go +++ b/proto/pointer_unsafe.go @@ -11,6 +11,8 @@ package proto import ( "reflect" "sync/atomic" + "unicode" + "unicode/utf8" "unsafe" ) @@ -20,8 +22,10 @@ const unsafeAllowed = true // In this implementation, a field is identified by its byte offset from the start of the struct. type field uintptr +type exporter = func(interface{}, int) interface{} + // toField returns a field equivalent to the given reflect field. -func toField(f *reflect.StructField) field { +func toField(f *reflect.StructField, x exporter) field { return field(f.Offset) } @@ -284,3 +288,16 @@ func atomicLoadDiscardInfo(p **discardInfo) *discardInfo { func atomicStoreDiscardInfo(p **discardInfo, v *discardInfo) { atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(p)), unsafe.Pointer(v)) } + +// fieldByName is equivalent to reflect.Value.FieldByName, but is able to +// descend into unexported fields for prop +func fieldByName(v reflect.Value, s string) reflect.Value { + if r, _ := utf8.DecodeRuneInString(s); unicode.IsUpper(r) { + return v.FieldByName(s) + } + sf, ok := v.Type().FieldByName(s) + if !ok { + return reflect.Value{} + } + return reflect.NewAt(sf.Type, unsafe.Pointer(v.UnsafeAddr()+sf.Offset)).Elem() +} diff --git a/proto/properties.go b/proto/properties.go index 6f15aeb839..88f6bb72f1 100644 --- a/proto/properties.go +++ b/proto/properties.go @@ -209,16 +209,9 @@ func newProperties(t reflect.Type) *StructProperties { } // Construct a mapping of oneof field names to properties. - var oneofWrappers []interface{} - if fn, ok := reflect.PtrTo(t).MethodByName("XXX_OneofFuncs"); ok { - oneofWrappers = fn.Func.Call([]reflect.Value{reflect.Zero(fn.Type.In(0))})[3].Interface().([]interface{}) - } - if fn, ok := reflect.PtrTo(t).MethodByName("XXX_OneofWrappers"); ok { - oneofWrappers = fn.Func.Call([]reflect.Value{reflect.Zero(fn.Type.In(0))})[0].Interface().([]interface{}) - } - if len(oneofWrappers) > 0 { + if oneofImplementors := oneofWrappers(t); len(oneofImplementors) > 0 { prop.OneofTypes = make(map[string]*OneofProperties) - for _, wrapper := range oneofWrappers { + for _, wrapper := range oneofImplementors { p := &OneofProperties{ Type: reflect.ValueOf(wrapper).Type(), // *T Prop: new(Properties), diff --git a/proto/table_marshal.go b/proto/table_marshal.go index c3b581de3b..f6ac7f5403 100644 --- a/proto/table_marshal.go +++ b/proto/table_marshal.go @@ -297,37 +297,37 @@ func (u *marshalInfo) computeMarshalInfo() { return } - // get oneof implementers - var oneofImplementers []interface{} - switch m := reflect.Zero(reflect.PtrTo(t)).Interface().(type) { - case oneofFuncsIface: - _, _, _, oneofImplementers = m.XXX_OneofFuncs() - case oneofWrappersIface: - oneofImplementers = m.XXX_OneofWrappers() - } + oneofImplementers := oneofWrappers(t) + u.messageset = isMessageSet(t) + expFunc := exporterFunc(t) n := t.NumField() - // deal with XXX fields first + // deal with XXX and unexported fields first. for i := 0; i < t.NumField(); i++ { f := t.Field(i) - if !strings.HasPrefix(f.Name, "XXX_") { + if !strings.HasPrefix(f.Name, "XXX_") && f.PkgPath == "" { continue } switch f.Name { case "XXX_sizecache": - u.sizecache = toField(&f) + u.sizecache = toField(&f, nil) case "XXX_unrecognized": - u.unrecognized = toField(&f) + u.unrecognized = toField(&f, nil) case "XXX_InternalExtensions": - u.extensions = toField(&f) - u.messageset = f.Tag.Get("protobuf_messageset") == "1" + u.extensions = toField(&f, nil) + if f.Tag.Get("protobuf_messageset") == "1" { + u.messageset = true + } case "XXX_extensions": - u.v1extensions = toField(&f) - case "XXX_NoUnkeyedLiteral": - // nothing to do - default: - panic("unknown XXX field: " + f.Name) + u.v1extensions = toField(&f, nil) + + case "sizeCache": + u.sizecache = toField(&f, expFunc) + case "unknownFields": + u.unrecognized = toField(&f, expFunc) + case "extensionFields": + u.extensions = toField(&f, expFunc) } n-- } @@ -338,7 +338,7 @@ func (u *marshalInfo) computeMarshalInfo() { for i, j := 0, 0; i < t.NumField(); i++ { f := t.Field(i) - if strings.HasPrefix(f.Name, "XXX_") { + if strings.HasPrefix(f.Name, "XXX_") || f.PkgPath != "" { continue } field := &fields[j] @@ -438,7 +438,7 @@ func (fi *marshalFieldInfo) computeMarshalFieldInfo(f *reflect.StructField) { } func (fi *marshalFieldInfo) computeOneofFieldInfo(f *reflect.StructField, oneofImplementers []interface{}) { - fi.field = toField(f) + fi.field = toField(f, nil) fi.wiretag = math.MaxInt32 // Use a large tag number, make oneofs sorted at the end. This tag will not appear on the wire. fi.isPointer = true fi.sizer, fi.marshaler = makeOneOfMarshaler(fi, f) @@ -486,7 +486,7 @@ func wiretype(encoding string) uint64 { // setTag fills up the tag (in wire format) and its size in the info of a field. func (fi *marshalFieldInfo) setTag(f *reflect.StructField, tag int, wt uint64) { - fi.field = toField(f) + fi.field = toField(f, nil) fi.wiretag = uint64(tag)<<3 | wt fi.tagsize = SizeVarint(uint64(tag) << 3) } diff --git a/proto/table_merge.go b/proto/table_merge.go index 3565efbda7..04f9a90db4 100644 --- a/proto/table_merge.go +++ b/proto/table_merge.go @@ -141,11 +141,11 @@ func (mi *mergeInfo) computeMergeInfo() { props := GetProperties(t) for i := 0; i < n; i++ { f := t.Field(i) - if strings.HasPrefix(f.Name, "XXX_") { + if strings.HasPrefix(f.Name, "XXX_") || f.PkgPath != "" { continue } - mfi := mergeFieldInfo{field: toField(&f)} + mfi := mergeFieldInfo{field: toField(&f, nil)} tf := f.Type // As an optimization, we can avoid the merge function call cost @@ -611,12 +611,19 @@ func (mi *mergeInfo) computeMergeInfo() { mi.fields = append(mi.fields, mfi) } + expFunc := exporterFunc(t) mi.unrecognized = invalidField if f, ok := t.FieldByName("XXX_unrecognized"); ok { if f.Type != reflect.TypeOf([]byte{}) { panic("expected XXX_unrecognized to be of type []byte") } - mi.unrecognized = toField(&f) + mi.unrecognized = toField(&f, nil) + } + if f, ok := t.FieldByName("unknownFields"); ok { + if f.Type != reflect.TypeOf([]byte{}) { + panic("expected unknownFields to be of type []byte") + } + mi.unrecognized = toField(&f, expFunc) } atomic.StoreInt32(&mi.initialized, 1) diff --git a/proto/table_unmarshal.go b/proto/table_unmarshal.go index 152ab79fde..64bf90928b 100644 --- a/proto/table_unmarshal.go +++ b/proto/table_unmarshal.go @@ -312,6 +312,10 @@ func (u *unmarshalInfo) computeUnmarshalInfo() { } var oneofFields []oneofField + oneofImplementers := oneofWrappers(t) + u.isMessageSet = isMessageSet(t) + expFunc := exporterFunc(t) + for i := 0; i < n; i++ { f := t.Field(i) if f.Name == "XXX_unrecognized" { @@ -319,7 +323,7 @@ func (u *unmarshalInfo) computeUnmarshalInfo() { if f.Type != reflect.TypeOf(([]byte)(nil)) { panic("bad type for XXX_unrecognized field: " + f.Type.Name()) } - u.unrecognized = toField(&f) + u.unrecognized = toField(&f, nil) continue } if f.Name == "XXX_InternalExtensions" { @@ -327,7 +331,7 @@ func (u *unmarshalInfo) computeUnmarshalInfo() { if f.Type != reflect.TypeOf(XXX_InternalExtensions{}) { panic("bad type for XXX_InternalExtensions field: " + f.Type.Name()) } - u.extensions = toField(&f) + u.extensions = toField(&f, nil) if f.Tag.Get("protobuf_messageset") == "1" { u.isMessageSet = true } @@ -338,16 +342,31 @@ func (u *unmarshalInfo) computeUnmarshalInfo() { if f.Type != reflect.TypeOf((map[int32]Extension)(nil)) { panic("bad type for XXX_extensions field: " + f.Type.Name()) } - u.oldExtensions = toField(&f) + u.oldExtensions = toField(&f, nil) continue } - if f.Name == "XXX_NoUnkeyedLiteral" || f.Name == "XXX_sizecache" { + if f.Name == "unknownFields" { + if f.Type != reflect.TypeOf(([]byte)(nil)) { + panic("bad type for unknownFields field: " + f.Type.Name()) + } + u.unrecognized = toField(&f, expFunc) + continue + } + if f.Name == "extensionFields" { + if f.Type != reflect.TypeOf(XXX_InternalExtensions{}) { + panic("bad type for extensionFields field: " + f.Type.Name()) + } + u.extensions = toField(&f, expFunc) + continue + } + + if strings.HasPrefix(f.Name, "XXX_") || f.PkgPath != "" { continue } oneof := f.Tag.Get("protobuf_oneof") if oneof != "" { - oneofFields = append(oneofFields, oneofField{f.Type, toField(&f)}) + oneofFields = append(oneofFields, oneofField{f.Type, toField(&f, nil)}) // The rest of oneof processing happens below. continue } @@ -384,17 +403,10 @@ func (u *unmarshalInfo) computeUnmarshalInfo() { } // Store the info in the correct slot in the message. - u.setTag(tag, toField(&f), unmarshal, reqMask, name) + u.setTag(tag, toField(&f, nil), unmarshal, reqMask, name) } // Find any types associated with oneof fields. - var oneofImplementers []interface{} - switch m := reflect.Zero(reflect.PtrTo(t)).Interface().(type) { - case oneofFuncsIface: - _, _, _, oneofImplementers = m.XXX_OneofFuncs() - case oneofWrappersIface: - oneofImplementers = m.XXX_OneofWrappers() - } for _, v := range oneofImplementers { tptr := reflect.TypeOf(v) // *Msg_X typ := tptr.Elem() // Msg_X @@ -1846,7 +1858,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler { // Note that this function will be called once for each case in the oneof. func makeUnmarshalOneof(typ, ityp reflect.Type, unmarshal unmarshaler) unmarshaler { sf := typ.Field(0) - field0 := toField(&sf) + field0 := toField(&sf, nil) return func(b []byte, f pointer, w int) ([]byte, error) { // Allocate holder for value. v := reflect.New(typ) diff --git a/proto/text.go b/proto/text.go index ef8735f116..40f7912b0c 100644 --- a/proto/text.go +++ b/proto/text.go @@ -19,6 +19,7 @@ import ( "sort" "strings" + protoV2 "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -165,11 +166,14 @@ func requiresQuotes(u string) bool { // isAny reports whether sv is a google.protobuf.Any message func isAny(sv reflect.Value) bool { - type wkt interface { - XXX_WellKnownType() string + switch m := sv.Addr().Interface().(type) { + case interface{ XXX_WellKnownType() string }: + return m.XXX_WellKnownType() == "Any" + case protoV2.Message: + return m.ProtoReflect().Descriptor().FullName() == "google.protobuf.Any" + default: + return false } - t, ok := sv.Addr().Interface().(wkt) - return ok && t.XXX_WellKnownType() == "Any" } // writeProto3Any writes an expanded google.protobuf.Any message. @@ -236,23 +240,9 @@ func (tm *textMarshaler) writeStruct(w *textWriter, sv reflect.Value) error { for i := 0; i < sv.NumField(); i++ { fv := sv.Field(i) props := sprops.Prop[i] - name := st.Field(i).Name - - if name == "XXX_NoUnkeyedLiteral" { - continue - } - if strings.HasPrefix(name, "XXX_") { - // There are two XXX_ fields: - // XXX_unrecognized []byte - // XXX_extensions map[int32]proto.Extension - // The first is handled here; - // the second is handled at the bottom of this function. - if name == "XXX_unrecognized" && !fv.IsNil() { - if err := writeUnknownStruct(w, fv.Interface().([]byte)); err != nil { - return err - } - } + f := st.Field(i) + if strings.HasPrefix(f.Name, "XXX_") || f.PkgPath != "" { continue } if fv.Kind() == reflect.Ptr && fv.IsNil() { @@ -420,6 +410,12 @@ func (tm *textMarshaler) writeStruct(w *textWriter, sv reflect.Value) error { } } + if fv := unknownFieldsValue(sv); !fv.IsNil() { + if err := writeUnknownStruct(w, fv.Interface().([]byte)); err != nil { + return err + } + } + // Extensions (the XXX_extensions field). pv := sv.Addr() if _, err := extendable(pv.Interface()); err == nil { @@ -682,15 +678,20 @@ func (tm *textMarshaler) writeExtensions(w *textWriter, pv reflect.Value) error return fmt.Errorf("failed getting extension: %v", err) } + name := desc.Name + if strings.HasSuffix(name, ".message_set_extension") && isMessageSet(pv.Type().Elem()) { + name = strings.TrimSuffix(name, ".message_set_extension") + } + // Repeated extensions will appear as a slice. if !isRepeatedExtension(desc) { - if err := tm.writeExtension(w, desc.Name, pb); err != nil { + if err := tm.writeExtension(w, name, pb); err != nil { return err } } else { v := reflect.ValueOf(pb) for i := 0; i < v.Len(); i++ { - if err := tm.writeExtension(w, desc.Name, v.Index(i).Interface()); err != nil { + if err := tm.writeExtension(w, name, v.Index(i).Interface()); err != nil { return err } } diff --git a/proto/text_parser.go b/proto/text_parser.go index 443f62707e..8fb9463a4b 100644 --- a/proto/text_parser.go +++ b/proto/text_parser.go @@ -525,6 +525,10 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error { desc = d break } + if strings.TrimSuffix(d.Name, ".message_set_extension") == extName && isMessageSet(st) { + desc = d + break + } } if desc == nil { return p.errorf("unrecognized extension %q", extName) diff --git a/protoc-gen-go/descriptor/descriptor.pb.go b/protoc-gen-go/descriptor/descriptor.pb.go index 0e07d08624..c2fb216795 100644 --- a/protoc-gen-go/descriptor/descriptor.pb.go +++ b/protoc-gen-go/descriptor/descriptor.pb.go @@ -7,6 +7,7 @@ import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" descriptorpb "google.golang.org/protobuf/types/descriptorpb" + reflect "reflect" sync "sync" ) @@ -199,8 +200,10 @@ func file_github_com_golang_protobuf_protoc_gen_go_descriptor_descriptor_proto_i if File_github_com_golang_protobuf_protoc_gen_go_descriptor_descriptor_proto != nil { return } + type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_github_com_golang_protobuf_protoc_gen_go_descriptor_descriptor_proto_rawDesc, NumEnums: 0, NumMessages: 0, diff --git a/protoc-gen-go/plugin/plugin.pb.go b/protoc-gen-go/plugin/plugin.pb.go index 5d88470c36..28c50931f3 100644 --- a/protoc-gen-go/plugin/plugin.pb.go +++ b/protoc-gen-go/plugin/plugin.pb.go @@ -7,6 +7,7 @@ import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" pluginpb "google.golang.org/protobuf/types/pluginpb" + reflect "reflect" sync "sync" ) @@ -66,8 +67,10 @@ func file_github_com_golang_protobuf_protoc_gen_go_plugin_plugin_proto_init() { if File_github_com_golang_protobuf_protoc_gen_go_plugin_plugin_proto != nil { return } + type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_github_com_golang_protobuf_protoc_gen_go_plugin_plugin_proto_rawDesc, NumEnums: 0, NumMessages: 0, diff --git a/ptypes/any/any.pb.go b/ptypes/any/any.pb.go index d076aa2df9..96dc8e3f9f 100644 --- a/ptypes/any/any.pb.go +++ b/ptypes/any/any.pb.go @@ -7,6 +7,7 @@ import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" anypb "google.golang.org/protobuf/types/known/anypb" + reflect "reflect" sync "sync" ) @@ -61,8 +62,10 @@ func file_github_com_golang_protobuf_ptypes_any_any_proto_init() { if File_github_com_golang_protobuf_ptypes_any_any_proto != nil { return } + type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_github_com_golang_protobuf_ptypes_any_any_proto_rawDesc, NumEnums: 0, NumMessages: 0, diff --git a/ptypes/duration/duration.pb.go b/ptypes/duration/duration.pb.go index 8dc1778321..ea23997333 100644 --- a/ptypes/duration/duration.pb.go +++ b/ptypes/duration/duration.pb.go @@ -7,6 +7,7 @@ import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" durationpb "google.golang.org/protobuf/types/known/durationpb" + reflect "reflect" sync "sync" ) @@ -62,8 +63,10 @@ func file_github_com_golang_protobuf_ptypes_duration_duration_proto_init() { if File_github_com_golang_protobuf_ptypes_duration_duration_proto != nil { return } + type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_github_com_golang_protobuf_ptypes_duration_duration_proto_rawDesc, NumEnums: 0, NumMessages: 0, diff --git a/ptypes/empty/empty.pb.go b/ptypes/empty/empty.pb.go index 890ab4ceff..b8022dc19e 100644 --- a/ptypes/empty/empty.pb.go +++ b/ptypes/empty/empty.pb.go @@ -7,6 +7,7 @@ import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" emptypb "google.golang.org/protobuf/types/known/emptypb" + reflect "reflect" sync "sync" ) @@ -61,8 +62,10 @@ func file_github_com_golang_protobuf_ptypes_empty_empty_proto_init() { if File_github_com_golang_protobuf_ptypes_empty_empty_proto != nil { return } + type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_github_com_golang_protobuf_ptypes_empty_empty_proto_rawDesc, NumEnums: 0, NumMessages: 0, diff --git a/ptypes/struct/struct.pb.go b/ptypes/struct/struct.pb.go index 47e8fd776c..fae4ea1ed5 100644 --- a/ptypes/struct/struct.pb.go +++ b/ptypes/struct/struct.pb.go @@ -7,6 +7,7 @@ import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" structpb "google.golang.org/protobuf/types/known/structpb" + reflect "reflect" sync "sync" ) @@ -77,8 +78,10 @@ func file_github_com_golang_protobuf_ptypes_struct_struct_proto_init() { if File_github_com_golang_protobuf_ptypes_struct_struct_proto != nil { return } + type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_github_com_golang_protobuf_ptypes_struct_struct_proto_rawDesc, NumEnums: 0, NumMessages: 0, diff --git a/ptypes/timestamp/timestamp.pb.go b/ptypes/timestamp/timestamp.pb.go index 32fe5b318c..4100c26384 100644 --- a/ptypes/timestamp/timestamp.pb.go +++ b/ptypes/timestamp/timestamp.pb.go @@ -7,6 +7,7 @@ import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" sync "sync" ) @@ -63,8 +64,10 @@ func file_github_com_golang_protobuf_ptypes_timestamp_timestamp_proto_init() { if File_github_com_golang_protobuf_ptypes_timestamp_timestamp_proto != nil { return } + type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_github_com_golang_protobuf_ptypes_timestamp_timestamp_proto_rawDesc, NumEnums: 0, NumMessages: 0, diff --git a/ptypes/wrappers/wrappers.pb.go b/ptypes/wrappers/wrappers.pb.go index dec0553be4..faeefe7923 100644 --- a/ptypes/wrappers/wrappers.pb.go +++ b/ptypes/wrappers/wrappers.pb.go @@ -7,6 +7,7 @@ import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" wrapperspb "google.golang.org/protobuf/types/known/wrapperspb" + reflect "reflect" sync "sync" ) @@ -70,8 +71,10 @@ func file_github_com_golang_protobuf_ptypes_wrappers_wrappers_proto_init() { if File_github_com_golang_protobuf_ptypes_wrappers_wrappers_proto != nil { return } + type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_github_com_golang_protobuf_ptypes_wrappers_wrappers_proto_rawDesc, NumEnums: 0, NumMessages: 0, diff --git a/test.bash b/test.bash index a7351bfc60..bab98bc155 100755 --- a/test.bash +++ b/test.bash @@ -10,9 +10,17 @@ PASS="\x1b[32mPASS" FAIL="\x1b[31mFAIL" RESET="\x1b[0m" -echo -e "${BOLD}go test${RESET}" -RET_TEST=$((go test ./... && go test -tags use_golang_protobuf_v1 ./...) | egrep -v "^(ok|[?])\s+") -if [[ ! -z "$RET_TEST" ]]; then echo "$RET_TEST"; echo; fi +echo -e "${BOLD}go test -tags proto1_legacy ./...${RESET}" +RET_TEST0=$(go test -tags proto1_legacy ./... | egrep -v "^(ok|[?])\s+") +if [[ ! -z "$RET_TEST0" ]]; then echo "$RET_TEST0"; echo; fi + +echo -e "${BOLD}go test -tags use_golang_protobuf_v1 ./...${RESET}" +RET_TEST1=$(go test -tags use_golang_protobuf_v1 ./... | egrep -v "^(ok|[?])\s+") +if [[ ! -z "$RET_TEST1" ]]; then echo "$RET_TEST1"; echo; fi + +echo -e "${BOLD}go test -tags "use_golang_protobuf_v1 purego" ./...${RESET}" +RET_TEST2=$(go test -tags "use_golang_protobuf_v1 purego" ./... | egrep -v "^(ok|[?])\s+") +if [[ ! -z "$RET_TEST2" ]]; then echo "$RET_TEST2"; echo; fi echo -e "${BOLD}go generate${RESET}" RET_GEN=$(go run ./internal/cmd/generate-alias 2>&1) @@ -30,7 +38,7 @@ echo -e "${BOLD}git ls-files${RESET}" RET_FILES=$(git ls-files --others --exclude-standard 2>&1) if [[ ! -z "$RET_FILES" ]]; then echo "$RET_FILES"; echo; fi -if [[ ! -z "$RET_TEST" ]] || [[ ! -z "$RET_GEN" ]] || [ ! -z "$RET_FMT" ] || [[ ! -z "$RET_DIFF" ]] || [[ ! -z "$RET_FILES" ]]; then +if [[ ! -z "$RET_TEST0" ]] || [[ ! -z "$RET_TEST1" ]] || [[ ! -z "$RET_TEST2" ]] || [[ ! -z "$RET_GEN" ]] || [ ! -z "$RET_FMT" ] || [[ ! -z "$RET_DIFF" ]] || [[ ! -z "$RET_FILES" ]]; then echo -e "${FAIL}${RESET}"; exit 1 else echo -e "${PASS}${RESET}"; exit 0