From 6927801ceae6aae7ce6de5a8415553a3e69ad182 Mon Sep 17 00:00:00 2001 From: kbullinger Date: Tue, 1 Oct 2024 10:32:29 -0400 Subject: [PATCH 1/4] add support for proto2 groups --- protobuf-codegen/src/gen/field/accessor.rs | 8 +- protobuf-codegen/src/gen/field/elem.rs | 30 +++++-- protobuf-codegen/src/gen/field/mod.rs | 83 ++++++++++++++----- protobuf-codegen/src/gen/field/type_ext.rs | 2 +- protobuf-codegen/src/gen/message.rs | 38 +++++++-- protobuf-codegen/src/gen/rust_types_values.rs | 3 - protobuf/src/coded_input_stream/mod.rs | 76 +++++++++++++++++ protobuf/src/coded_output_stream/mod.rs | 1 + protobuf/src/rt/mod.rs | 6 ++ protobuf/src/rt/unknown_or_group.rs | 20 ++++- protobuf/src/unknown.rs | 16 ++++ protobuf/src/wire_format.rs | 13 ++- 12 files changed, 252 insertions(+), 44 deletions(-) diff --git a/protobuf-codegen/src/gen/field/accessor.rs b/protobuf-codegen/src/gen/field/accessor.rs index 0fa7ec3bf..7baaa8771 100644 --- a/protobuf-codegen/src/gen/field/accessor.rs +++ b/protobuf-codegen/src/gen/field/accessor.rs @@ -155,9 +155,11 @@ impl FieldGen<'_> { type_params: vec!["_".to_owned()], callback_params: self.make_accessor_fns_lambda(), }, - FieldElem::Group => { - unreachable!("no accessor for group field"); - } + FieldElem::Group(m) => AccessorFn { + name: "make_message_field_accessor".to_owned(), + type_params: vec![format!("{}", m.rust_name_relative(&self.file_and_mod()))], + callback_params: self.make_accessor_fns_lambda(), + }, } } diff --git a/protobuf-codegen/src/gen/field/elem.rs b/protobuf-codegen/src/gen/field/elem.rs index fd06b4ced..8609cca65 100644 --- a/protobuf-codegen/src/gen/field/elem.rs +++ b/protobuf-codegen/src/gen/field/elem.rs @@ -74,7 +74,7 @@ pub(crate) enum FieldElem<'a> { Primitive(Type, PrimitiveTypeVariant), Message(FieldElemMessage<'a>), Enum(FieldElemEnum<'a>), - Group, + Group(FieldElemMessage<'a>), } pub(crate) enum HowToGetMessageSize { @@ -86,7 +86,7 @@ impl<'a> FieldElem<'a> { pub(crate) fn proto_type(&self) -> Type { match *self { FieldElem::Primitive(t, ..) => t, - FieldElem::Group => Type::TYPE_GROUP, + FieldElem::Group(_) => Type::TYPE_GROUP, FieldElem::Message(..) => Type::TYPE_MESSAGE, FieldElem::Enum(..) => Type::TYPE_ENUM, } @@ -106,7 +106,7 @@ impl<'a> FieldElem<'a> { RustType::Bytes } FieldElem::Primitive(.., PrimitiveTypeVariant::TokioBytes) => unreachable!(), - FieldElem::Group => RustType::Group, + FieldElem::Group(ref m) => m.rust_type(reference), FieldElem::Message(ref m) => m.rust_type(reference), FieldElem::Enum(ref en) => en.enum_or_unknown_rust_type(reference), } @@ -213,6 +213,18 @@ impl<'a> FieldElem<'a> { protobuf_crate_path(customize), )); } + Type::TYPE_GROUP => { + match how_to_get_message_size { + HowToGetMessageSize::Compute => { + w.write_line(&format!("let len = {}.compute_size();", item_var.value)) + } + HowToGetMessageSize::GetCached => w.write_line(&format!( + "let len = {}.cached_size() as u64;", + item_var.value + )), + } + w.write_line(&format!("{sum_var} += {tag_size} * 2 + len;",)); + } _ => { w.write_line(&format!( "{sum_var} += {};", @@ -232,7 +244,7 @@ impl<'a> FieldElem<'a> { w: &mut CodeWriter, ) { match self.proto_type() { - Type::TYPE_MESSAGE => { + Type::TYPE_MESSAGE | Type::TYPE_GROUP => { let param_type = RustType::Ref(Box::new(self.rust_storage_elem_type(file_and_mod))); w.write_line(&format!( @@ -273,10 +285,7 @@ pub(crate) fn field_elem<'a>( if let RuntimeFieldType::Map(..) = field.field.runtime_field_type() { unreachable!(); } - - if field.field.proto().type_() == Type::TYPE_GROUP { - FieldElem::Group - } else if field.field.proto().has_type_name() { + if field.field.proto().has_type_name() { let message_or_enum = root_scope .find_message_or_enum(&ProtobufAbsPath::from(field.field.proto().type_name())); match (field.field.proto().type_(), message_or_enum) { @@ -285,6 +294,11 @@ pub(crate) fn field_elem<'a>( message: message.clone(), }) } + (Type::TYPE_GROUP, MessageOrEnumWithScope::Message(message)) => { + FieldElem::Group(FieldElemMessage { + message: message.clone(), + }) + } (Type::TYPE_ENUM, MessageOrEnumWithScope::Enum(enum_with_scope)) => { let default_value = if field.field.proto().has_default_value() { enum_with_scope.value_by_name(field.field.proto().default_value()) diff --git a/protobuf-codegen/src/gen/field/mod.rs b/protobuf-codegen/src/gen/field/mod.rs index 344aca2e2..9ab9270b4 100644 --- a/protobuf-codegen/src/gen/field/mod.rs +++ b/protobuf-codegen/src/gen/field/mod.rs @@ -603,26 +603,22 @@ impl<'a> FieldGen<'a> { } pub(crate) fn write_struct_field(&self, w: &mut CodeWriter) { - if self.proto_type == field_descriptor_proto::Type::TYPE_GROUP { - w.comment(&format!("{}: ", &self.rust_name)); - } else { - w.all_documentation(self.info, &self.path); + w.all_documentation(self.info, &self.path); - write_protoc_insertion_point_for_field(w, &self.customize, &self.proto_field.field); - w.field_decl_vis( - Visibility::Public, - &self.rust_name.to_string(), - &self - .full_storage_type( - &self - .proto_field - .message - .scope - .file_and_mod(self.customize.clone()), - ) - .to_code(&self.customize), - ); - } + write_protoc_insertion_point_for_field(w, &self.customize, &self.proto_field.field); + w.field_decl_vis( + Visibility::Public, + &self.rust_name.to_string(), + &self + .full_storage_type( + &self + .proto_field + .message + .scope + .file_and_mod(self.customize.clone()), + ) + .to_code(&self.customize), + ); } fn write_if_let_self_field_is_some(&self, s: &SingularField, w: &mut CodeWriter, cb: F) @@ -1035,6 +1031,7 @@ impl<'a> FieldGen<'a> { self.rust_name, )); } + FieldElem::Group(..) => self.write_merge_delimited_group_case_block(w), _ => { let read_proc = s.elem.read_one_liner(); self.write_self_field_assign_some(w, s, &read_proc); @@ -1042,6 +1039,50 @@ impl<'a> FieldGen<'a> { }) } + fn write_merge_delimited_repeated_group_case_block(&self, w: &mut CodeWriter<'_>) { + let name = &self.rust_name; + let proto_path = protobuf_crate_path(&self.customize); + w.write_line(&format!( + "let end_tag = {proto_path}::rt::set_wire_type_to_end_group(tag);" + )); + let value_type = &self + .elem() + .rust_storage_elem_type( + &self + .proto_field + .message + .scope + .file_and_mod(self.customize.clone()), + ) + .to_code(&self.customize); + + w.write_line(&format!("let mut {name} = {value_type}::default();")); + w.write_line(&format!("{name}.merge_delimited(is, end_tag)?;")); + w.write_line(&format!("self.{name}.push({name});",)); + } + + fn write_merge_delimited_group_case_block(&self, w: &mut CodeWriter<'_>) { + let name = &self.rust_name; + let proto_path = protobuf_crate_path(&self.customize); + w.write_line(&format!( + "let end_tag = {proto_path}::rt::set_wire_type_to_end_group(tag);" + )); + let value_type = &self + .elem() + .rust_storage_elem_type( + &self + .proto_field + .message + .scope + .file_and_mod(self.customize.clone()), + ) + .to_code(&self.customize); + + w.write_line(&format!("let mut {name} = {value_type}::default();")); + w.write_line(&format!("{name}.merge_delimited(is, end_tag)?;")); + w.write_line(&format!("self.{name} = Some({name});",)); + } + // Write `merge_from` part for this repeated field fn write_merge_from_repeated_case_block(&self, w: &mut CodeWriter) { let field = match self.kind { @@ -1057,6 +1098,9 @@ impl<'a> FieldGen<'a> { self.write_merge_from_field_message_string_bytes_repeated(field, w); }) } + FieldElem::Group(_) => w.case_block(&format!("{}", self.tag()), |w| { + self.write_merge_delimited_repeated_group_case_block(w); + }), FieldElem::Enum(..) => { w.case_block( &format!("{}", self.tag_with_wire_type(WireType::Varint)), @@ -1338,6 +1382,7 @@ impl<'a> FieldGen<'a> { match singular.elem { FieldElem::Message(..) => self.write_message_field_get_singular_message(singular, w), + FieldElem::Group(..) => self.write_message_field_get_singular_message(singular, w), FieldElem::Enum(ref en) => { self.write_message_field_get_singular_enum(singular.flag, en, w) } diff --git a/protobuf-codegen/src/gen/field/type_ext.rs b/protobuf-codegen/src/gen/field/type_ext.rs index 464a5024a..ccc4ed2c6 100644 --- a/protobuf-codegen/src/gen/field/type_ext.rs +++ b/protobuf-codegen/src/gen/field/type_ext.rs @@ -37,7 +37,7 @@ impl TypeExt for Type { fn is_copy(&self) -> bool { match self { - Type::TYPE_MESSAGE | Type::TYPE_STRING | Type::TYPE_BYTES => false, + Type::TYPE_GROUP | Type::TYPE_MESSAGE | Type::TYPE_STRING | Type::TYPE_BYTES => false, _ => true, } } diff --git a/protobuf-codegen/src/gen/message.rs b/protobuf-codegen/src/gen/message.rs index 9221ab79c..e2c4a633b 100644 --- a/protobuf-codegen/src/gen/message.rs +++ b/protobuf-codegen/src/gen/message.rs @@ -271,7 +271,7 @@ impl<'a> MessageGen<'a> { self.rust_name() ), |w| { - for f in &self.fields_except_oneof_and_group() { + for f in self.fields_except_oneof() { w.field_entry( &f.rust_name.to_string(), &f.kind @@ -310,6 +310,31 @@ impl<'a> MessageGen<'a> { ); } + fn write_merge_delimited(&self, w: &mut CodeWriter) { + let sig = format!( + "merge_delimited(&mut self, is: &mut {}::CodedInputStream<'_>, end_tag: u32) -> {}::Result<()>", + protobuf_crate_path(&self.customize.for_elem), + protobuf_crate_path(&self.customize.for_elem), + ); + w.comment("read and merge values from the stream, stopping when `end_tag` tag is reached"); + w.pub_fn(&sig, |w| { + w.while_block("let Some(tag) = is.read_raw_tag_or_eof()?", |w| { + w.if_stmt("tag == end_tag", |w| { + w.write_line("return ::std::result::Result::Ok(());"); + }); + w.match_block("tag", |w| { + for f in &self.fields { + f.write_merge_from_field_case_block(w); + } + w.case_block("tag", |w| { + w.write_line(&format!("{}::rt::read_unknown(tag, is, self.special_fields.mut_unknown_fields())?;", protobuf_crate_path(&self.customize.for_elem))); + }); + }); + }); + w.write_line("::std::result::Result::Ok(())"); + }); + } + fn write_compute_size(&self, w: &mut CodeWriter) { // Append sizes of messages in the tree to the specified vector. // First appended element is size of self, and then nested message sizes. @@ -320,7 +345,7 @@ impl<'a> MessageGen<'a> { w.def_fn("compute_size(&self) -> u64", |w| { // To have access to its methods but not polute the name space. w.write_line("let mut my_size = 0;"); - for field in self.fields_except_oneof_and_group() { + for field in &self.fields_except_oneof() { field.write_message_compute_field_size("my_size", w); } self.write_match_each_oneof_variant(w, |w, variant, v| { @@ -338,7 +363,7 @@ impl<'a> MessageGen<'a> { } fn write_field_accessors(&self, w: &mut CodeWriter) { - for f in self.fields_except_group() { + for f in &self.fields { f.write_message_single_field_accessors(w); } } @@ -351,6 +376,9 @@ impl<'a> MessageGen<'a> { self.write_field_accessors(w); + w.write_line(""); + self.write_merge_delimited(w); + if !self.lite_runtime { w.write_line(""); self.write_generated_message_descriptor_data(w); @@ -385,11 +413,11 @@ impl<'a> MessageGen<'a> { w.def_fn(&sig, |w| { w.while_block("let Some(tag) = is.read_raw_tag_or_eof()?", |w| { w.match_block("tag", |w| { - for f in &self.fields_except_group() { + for f in &self.fields { f.write_merge_from_field_case_block(w); } w.case_block("tag", |w| { - w.write_line(&format!("{}::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?;", protobuf_crate_path(&self.customize.for_elem))); + w.write_line(&format!("{}::rt::read_unknown(tag, is, self.special_fields.mut_unknown_fields())?;", protobuf_crate_path(&self.customize.for_elem))); }); }); }); diff --git a/protobuf-codegen/src/gen/rust_types_values.rs b/protobuf-codegen/src/gen/rust_types_values.rs index ee2f730cd..8e2ce87d4 100644 --- a/protobuf-codegen/src/gen/rust_types_values.rs +++ b/protobuf-codegen/src/gen/rust_types_values.rs @@ -58,8 +58,6 @@ pub(crate) enum RustType { Bytes, // chars::Chars Chars, - // group - Group, } impl RustType { @@ -101,7 +99,6 @@ impl RustType { protobuf_crate_path(customize), name ), - RustType::Group => format!(""), RustType::Bytes => format!("::bytes::Bytes"), RustType::Chars => format!("{}::Chars", protobuf_crate_path(customize)), } diff --git a/protobuf/src/coded_input_stream/mod.rs b/protobuf/src/coded_input_stream/mod.rs index a979df19c..4994286ef 100644 --- a/protobuf/src/coded_input_stream/mod.rs +++ b/protobuf/src/coded_input_stream/mod.rs @@ -43,6 +43,7 @@ use crate::wire_format; use crate::wire_format::WireType; use crate::zigzag::decode_zig_zag_32; use crate::zigzag::decode_zig_zag_64; +use crate::CodedOutputStream; use crate::EnumOrUnknown; use crate::Message; use crate::MessageDyn; @@ -521,6 +522,81 @@ impl<'a> CodedInputStream<'a> { Ok(()) } + /// Given the field number of a Group field, read raw encoded bytes until corresponding + /// end tag is found. + pub fn read_group_to_bytes(&mut self, field_number: u32) -> crate::Result> { + let mut buf = vec![]; + let mut os = CodedOutputStream::new(&mut buf); + let end_tag = field_number << 3 | (WireType::EndGroup as u32); + while !self.eof()? { + let tag = self.read_tag()?; + let (field_number, wire_type) = tag.unpack(); + if tag.value() == end_tag { + drop(os); + return Ok(buf); + } + match wire_type { + WireType::Varint => { + let n = self.read_raw_varint64()?; + os.write_raw_varint32(tag.value())?; + os.write_raw_varint64(n)?; + } + WireType::Fixed32 => { + let n = self.read_fixed32()?; + os.write_raw_varint32(tag.value())?; + os.write_fixed32_no_tag(n)?; + } + WireType::Fixed64 => { + let n = self.read_fixed64()?; + os.write_raw_varint32(tag.value())?; + os.write_fixed64_no_tag(n)?; + } + WireType::LengthDelimited => { + let len = self.read_raw_varint32()?; + let bytes = self.read_raw_bytes(len)?; + os.write_raw_varint32(tag.value())?; + os.write_raw_bytes(bytes.as_ref())?; + } + WireType::StartGroup => { + self.incr_recursion()?; + let group_bytes = self.read_group_to_bytes(field_number)?; + self.decr_recursion(); + os.write_raw_varint32(tag.value())?; + os.write_raw_bytes(group_bytes.as_ref())?; + let end_tag = crate::rt::set_wire_type_to_end_group(tag.value()); + os.write_raw_varint32(end_tag)?; + } + WireType::EndGroup => { + return Err(WireError::UnexpectedWireType(wire_type).into()); + } + } + } + Err(WireError::UnexpectedEof.into()) + } + + /// Read `UnknownValue` given field number and wire type + pub fn read_unknown_with_tag_unpacked( + &mut self, + field_number: u32, + wire_type: WireType, + ) -> crate::Result { + match wire_type { + WireType::Varint => self.read_raw_varint64().map(UnknownValue::Varint), + WireType::Fixed64 => self.read_fixed64().map(UnknownValue::Fixed64), + WireType::Fixed32 => self.read_fixed32().map(UnknownValue::Fixed32), + WireType::LengthDelimited => { + let len = self.read_raw_varint32()?; + self.read_raw_bytes(len).map(UnknownValue::LengthDelimited) + } + WireType::StartGroup => self + .read_group_to_bytes(field_number) + .map(UnknownValue::Group), + WireType::EndGroup => { + Err(ProtobufError::WireError(WireError::UnexpectedWireType(wire_type)).into()) + } + } + } + /// Read `UnknownValue` pub fn read_unknown(&mut self, wire_type: WireType) -> crate::Result { match wire_type { diff --git a/protobuf/src/coded_output_stream/mod.rs b/protobuf/src/coded_output_stream/mod.rs index 574d6af1c..dd330546f 100644 --- a/protobuf/src/coded_output_stream/mod.rs +++ b/protobuf/src/coded_output_stream/mod.rs @@ -416,6 +416,7 @@ impl<'a> CodedOutputStream<'a> { UnknownValueRef::Fixed32(fixed32) => self.write_raw_little_endian32(fixed32), UnknownValueRef::Varint(varint) => self.write_raw_varint64(varint), UnknownValueRef::LengthDelimited(bytes) => self.write_bytes_no_tag(bytes), + UnknownValueRef::Group(bytes) => self.write_bytes_no_tag(bytes), } } diff --git a/protobuf/src/rt/mod.rs b/protobuf/src/rt/mod.rs index 90db0eb29..9f75d6275 100644 --- a/protobuf/src/rt/mod.rs +++ b/protobuf/src/rt/mod.rs @@ -36,6 +36,7 @@ pub use singular::sint64_size; pub use singular::string_size; pub use singular::uint32_size; pub use singular::uint64_size; +pub use unknown_or_group::read_unknown; pub use unknown_or_group::read_unknown_or_skip_group; pub use unknown_or_group::skip_field_for_tag; pub use unknown_or_group::unknown_fields_size; @@ -60,3 +61,8 @@ pub(crate) fn compute_raw_varint32_size(value: u32) -> u64 { pub fn tag_size(field_number: u32) -> u64 { encoded_varint64_len((field_number as u64) << 3) as u64 } + +/// Sets the `WireType` of a raw tag to `WireType::EndGroup`. +pub fn set_wire_type_to_end_group(tag: u32) -> u32 { + tag & !crate::wire_format::TAG_TYPE_MASK | (WireType::EndGroup as u32) +} diff --git a/protobuf/src/rt/unknown_or_group.rs b/protobuf/src/rt/unknown_or_group.rs index 33e688247..cf9d35f00 100644 --- a/protobuf/src/rt/unknown_or_group.rs +++ b/protobuf/src/rt/unknown_or_group.rs @@ -21,12 +21,17 @@ fn skip_group(is: &mut CodedInputStream) -> crate::Result<()> { pub fn unknown_fields_size(unknown_fields: &UnknownFields) -> u64 { let mut r = 0; for (number, value) in unknown_fields { - r += tag_size(number); + let tag_size = tag_size(number); + r += tag_size; r += match value { UnknownValueRef::Fixed32(_) => 4, UnknownValueRef::Fixed64(_) => 8, UnknownValueRef::Varint(v) => compute_raw_varint64_size(v), UnknownValueRef::LengthDelimited(v) => bytes_size_no_tag(v), + UnknownValueRef::Group(v) => { + // length of encoded data plus end tag + v.len() as u64 + tag_size + } }; } r @@ -62,6 +67,19 @@ pub fn read_unknown_or_skip_group( read_unknown_or_skip_group_with_tag_unpacked(field_humber, wire_type, is, unknown_fields) } +/// Store a value in unknown. +/// Return error if tag is incorrect. +pub fn read_unknown( + tag: u32, + is: &mut CodedInputStream<'_>, + unknown_fields: &mut UnknownFields, +) -> crate::Result<()> { + let (field_number, wire_type) = Tag::new(tag)?.unpack(); + let unknown = is.read_unknown_with_tag_unpacked(field_number, wire_type)?; + unknown_fields.add_value(tag, unknown); + Ok(()) +} + /// Skip field. pub fn skip_field_for_tag(tag: u32, is: &mut CodedInputStream) -> crate::Result<()> { let (_field_humber, wire_type) = Tag::new(tag)?.unpack(); diff --git a/protobuf/src/unknown.rs b/protobuf/src/unknown.rs index c54b20c4c..486fe7dd3 100644 --- a/protobuf/src/unknown.rs +++ b/protobuf/src/unknown.rs @@ -27,6 +27,8 @@ pub enum UnknownValue { Varint(u64), /// Length-delimited unknown (e. g. `message` or `string`) LengthDelimited(Vec), + /// Tag-delimited unknown + Group(Vec), } impl UnknownValue { @@ -42,6 +44,7 @@ impl UnknownValue { UnknownValue::Fixed64(fixed64) => UnknownValueRef::Fixed64(fixed64), UnknownValue::Varint(varint) => UnknownValueRef::Varint(varint), UnknownValue::LengthDelimited(ref bytes) => UnknownValueRef::LengthDelimited(&bytes), + UnknownValue::Group(ref bytes) => UnknownValueRef::Group(&bytes), } } @@ -99,6 +102,8 @@ pub enum UnknownValueRef<'o> { Varint(u64), /// Length-delimited unknown LengthDelimited(&'o [u8]), + /// Tag-delimited unknown + Group(&'o [u8]), } impl<'o> UnknownValueRef<'o> { @@ -109,6 +114,7 @@ impl<'o> UnknownValueRef<'o> { UnknownValueRef::Fixed64(_) => WireType::Fixed64, UnknownValueRef::Varint(_) => WireType::Varint, UnknownValueRef::LengthDelimited(_) => WireType::LengthDelimited, + UnknownValueRef::Group(_) => WireType::StartGroup, } } @@ -118,6 +124,7 @@ impl<'o> UnknownValueRef<'o> { UnknownValueRef::Fixed64(v) => ReflectValueRef::U64(*v), UnknownValueRef::Varint(v) => ReflectValueRef::U64(*v), UnknownValueRef::LengthDelimited(v) => ReflectValueRef::Bytes(v), + UnknownValueRef::Group(v) => ReflectValueRef::Bytes(v), } } } @@ -135,6 +142,7 @@ pub(crate) struct UnknownValues { pub(crate) varint: Vec, /// Length-delimited unknowns pub(crate) length_delimited: Vec>, + pub(crate) groups: Vec>, } impl UnknownValues { @@ -147,6 +155,7 @@ impl UnknownValues { UnknownValue::LengthDelimited(length_delimited) => { self.length_delimited.push(length_delimited) } + UnknownValue::Group(g) => self.groups.push(g), }; } @@ -157,6 +166,7 @@ impl UnknownValues { fixed64: self.fixed64.iter(), varint: self.varint.iter(), length_delimited: self.length_delimited.iter(), + groups: self.groups.iter(), } } @@ -169,6 +179,8 @@ impl UnknownValues { Some(UnknownValueRef::Varint(*last)) } else if let Some(last) = self.length_delimited.last() { Some(UnknownValueRef::LengthDelimited(last)) + } else if let Some(last) = self.groups.last() { + Some(UnknownValueRef::Group(last)) } else { None } @@ -190,6 +202,7 @@ pub(crate) struct UnknownValuesIter<'o> { fixed64: slice::Iter<'o, u64>, varint: slice::Iter<'o, u64>, length_delimited: slice::Iter<'o, Vec>, + groups: slice::Iter<'o, Vec>, } impl<'o> Iterator for UnknownValuesIter<'o> { @@ -208,6 +221,9 @@ impl<'o> Iterator for UnknownValuesIter<'o> { if let Some(length_delimited) = self.length_delimited.next() { return Some(UnknownValueRef::LengthDelimited(&length_delimited)); } + if let Some(g) = self.groups.next() { + return Some(UnknownValueRef::Group(&g)); + } None } } diff --git a/protobuf/src/wire_format.rs b/protobuf/src/wire_format.rs index 9cbbf5549..2f2fb2262 100644 --- a/protobuf/src/wire_format.rs +++ b/protobuf/src/wire_format.rs @@ -35,9 +35,9 @@ pub enum WireType { Fixed64 = 1, /// Length-delimited field LengthDelimited = 2, - /// Groups are not supported in rust-protobuf + /// Denotes the beginning of a group StartGroup = 3, - /// Groups are not supported in rust-protobuf + /// Denotes the end of a group EndGroup = 4, /// 32-bit field (e. g. `fixed32` or `float`) Fixed32 = 5, @@ -57,6 +57,11 @@ impl WireType { } } + /// Indicates whether the `WireType` is a `StartGroup` or `EndGroup`. + pub fn is_group(&self) -> bool { + matches!(self, WireType::StartGroup | WireType::EndGroup) + } + #[doc(hidden)] pub fn for_type(field_type: field_descriptor_proto::Type) -> WireType { use field_descriptor_proto::Type; @@ -78,7 +83,7 @@ impl WireType { Type::TYPE_STRING => WireType::LengthDelimited, Type::TYPE_BYTES => WireType::LengthDelimited, Type::TYPE_MESSAGE => WireType::LengthDelimited, - Type::TYPE_GROUP => WireType::LengthDelimited, // not true + Type::TYPE_GROUP => WireType::StartGroup, } } } @@ -131,7 +136,7 @@ impl Tag { } /// Get wire type - fn wire_type(self) -> WireType { + pub(crate) fn wire_type(self) -> WireType { self.wire_type } From c44f05312a8956426ae527df28c47cb4f6bc5cd7 Mon Sep 17 00:00:00 2001 From: kbullinger Date: Wed, 2 Oct 2024 17:30:09 -0400 Subject: [PATCH 2/4] add support for group in oneof; add field_number to unknown_fields instead of tag --- protobuf-codegen/src/gen/field/mod.rs | 41 +++++++++++++++++++++++---- protobuf-codegen/src/gen/oneof.rs | 15 +++++++++- protobuf/src/rt/unknown_or_group.rs | 2 +- 3 files changed, 50 insertions(+), 8 deletions(-) diff --git a/protobuf-codegen/src/gen/field/mod.rs b/protobuf-codegen/src/gen/field/mod.rs index 9ab9270b4..e657ddd58 100644 --- a/protobuf-codegen/src/gen/field/mod.rs +++ b/protobuf-codegen/src/gen/field/mod.rs @@ -970,13 +970,42 @@ impl<'a> FieldGen<'a> { } else { typed }; + let variant_path = o.variant_path(&self.proto_field.message.scope.rust_path_to_file()); + if o.elem.proto_type() == Type::TYPE_GROUP { + let proto_path = protobuf_crate_path(&self.customize); + w.write_line(&format!( + "let end_tag = {proto_path}::rt::set_wire_type_to_end_group(tag);" + )); + let value_type = &self + .elem() + .rust_storage_elem_type( + &self + .proto_field + .message + .scope + .file_and_mod(self.customize.clone()), + ) + .to_code(&self.customize); - w.write_line(&format!( - "self.{} = ::std::option::Option::Some({}({}));", - o.oneof_field_name, - o.variant_path(&self.proto_field.message.scope.rust_path_to_file()), - maybe_boxed.value - )); + w.write_line(&format!( + "let mut {} = {}::default();", + o.oneof_field_name, value_type, + )); + w.write_line(&format!( + "{}.merge_delimited(is, end_tag)?;", + o.oneof_field_name, + )); + + w.write_line(&format!( + "self.{} = ::std::option::Option::Some({}({}));", + o.oneof_field_name, variant_path, o.oneof_field_name, + )); + } else { + w.write_line(&format!( + "self.{} = ::std::option::Option::Some({}({}));", + o.oneof_field_name, variant_path, maybe_boxed.value + )); + } }) } diff --git a/protobuf-codegen/src/gen/oneof.rs b/protobuf-codegen/src/gen/oneof.rs index e1bd33901..73b217627 100644 --- a/protobuf-codegen/src/gen/oneof.rs +++ b/protobuf-codegen/src/gen/oneof.rs @@ -223,6 +223,19 @@ impl<'a> OneofGen<'a> { .collect() } + pub fn variants(&'a self) -> impl Iterator> { + self.oneof.variants().into_iter().map(|v| { + let field = self + .message + .fields + .iter() + .filter(|f| f.proto_field.name() == v.field.name()) + .next() + .expect(&format!("field not found by name: {}", v.field.name())); + OneofVariantGen::parse(self, v, field, self.message.root_scope) + }) + } + pub fn full_storage_type(&self) -> RustType { RustType::Option(Box::new(RustType::Oneof( self.type_name_relative( @@ -262,7 +275,7 @@ impl<'a> OneofGen<'a> { } write_protoc_insertion_point_for_oneof(w, &self.customize.for_elem, &self.oneof.oneof); w.pub_enum(&self.oneof.rust_name().ident.to_string(), |w| { - for variant in self.variants_except_group() { + for variant in self.variants() { write_protoc_insertion_point_for_oneof_field( w, &self.customize.for_children, diff --git a/protobuf/src/rt/unknown_or_group.rs b/protobuf/src/rt/unknown_or_group.rs index cf9d35f00..e43ae9721 100644 --- a/protobuf/src/rt/unknown_or_group.rs +++ b/protobuf/src/rt/unknown_or_group.rs @@ -76,7 +76,7 @@ pub fn read_unknown( ) -> crate::Result<()> { let (field_number, wire_type) = Tag::new(tag)?.unpack(); let unknown = is.read_unknown_with_tag_unpacked(field_number, wire_type)?; - unknown_fields.add_value(tag, unknown); + unknown_fields.add_value(field_number, unknown); Ok(()) } From 7919ee023a378a628a8d8c60ed39a817b5106474 Mon Sep 17 00:00:00 2001 From: kbullinger Date: Thu, 3 Oct 2024 12:41:26 -0400 Subject: [PATCH 3/4] include group oneof variants in generated output --- protobuf-codegen/src/gen/message.rs | 2 +- protobuf-codegen/src/gen/oneof.rs | 26 +++++++++++++++----------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/protobuf-codegen/src/gen/message.rs b/protobuf-codegen/src/gen/message.rs index e2c4a633b..8ef950516 100644 --- a/protobuf-codegen/src/gen/message.rs +++ b/protobuf-codegen/src/gen/message.rs @@ -195,7 +195,7 @@ impl<'a> MessageGen<'a> { F: Fn(&mut CodeWriter, &OneofVariantGen, &RustValueTyped), { for oneof in self.oneofs() { - let variants = oneof.variants_except_group(); + let variants = oneof.variants(); if variants.is_empty() { // Special case because // https://github.com/rust-lang/rust/issues/50642 diff --git a/protobuf-codegen/src/gen/oneof.rs b/protobuf-codegen/src/gen/oneof.rs index 73b217627..a6953936e 100644 --- a/protobuf-codegen/src/gen/oneof.rs +++ b/protobuf-codegen/src/gen/oneof.rs @@ -223,17 +223,21 @@ impl<'a> OneofGen<'a> { .collect() } - pub fn variants(&'a self) -> impl Iterator> { - self.oneof.variants().into_iter().map(|v| { - let field = self - .message - .fields - .iter() - .filter(|f| f.proto_field.name() == v.field.name()) - .next() - .expect(&format!("field not found by name: {}", v.field.name())); - OneofVariantGen::parse(self, v, field, self.message.root_scope) - }) + pub fn variants(&'a self) -> Vec> { + self.oneof + .variants() + .into_iter() + .map(|v| { + let field = self + .message + .fields + .iter() + .filter(|f| f.proto_field.name() == v.field.name()) + .next() + .expect(&format!("field not found by name: {}", v.field.name())); + OneofVariantGen::parse(self, v, field, self.message.root_scope) + }) + .collect() } pub fn full_storage_type(&self) -> RustType { From 1bb993ee31a4be56e07d67c2800ed645b1c65bc4 Mon Sep 17 00:00:00 2001 From: kbullinger Date: Thu, 3 Oct 2024 13:15:55 -0400 Subject: [PATCH 4/4] add rust_opt items to protoc invocation --- protobuf/regenerate.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/protobuf/regenerate.sh b/protobuf/regenerate.sh index 4987e8eee..4da70c1b0 100755 --- a/protobuf/regenerate.sh +++ b/protobuf/regenerate.sh @@ -34,7 +34,7 @@ esac "$PROTOC" \ --plugin=protoc-gen-rust="$where_am_i/target/debug/protoc-gen-rust$exe_suffix" \ --rust_out tmp-generated \ - --rust_opt 'inside_protobuf=true gen_mod_rs=false' \ + --rust_opt 'inside_protobuf=true,gen_mod_rs=false,experimental-codegen=enabled,kernel=cpp' \ -I../proto \ ../proto/google/protobuf/*.proto \ ../proto/google/protobuf/compiler/*.proto \