diff --git a/protobuf-codegen/src/field.rs b/protobuf-codegen/src/field.rs index 950b62938..bf1fe5599 100644 --- a/protobuf-codegen/src/field.rs +++ b/protobuf-codegen/src/field.rs @@ -455,7 +455,7 @@ impl<'a> FieldGen<'a> { }), } } else if let Some(oneof) = field.oneof() { - FieldKind::Oneof(OneofField::parse(&oneof, field.field, elem)) + FieldKind::Oneof(OneofField::parse(&oneof, &field, elem, root_scope)) } else { let flag = if field.message.scope.file_scope.syntax() == Syntax::PROTO3 && field.field.get_field_type() != FieldDescriptorProto_Type::TYPE_MESSAGE diff --git a/protobuf-codegen/src/lib.rs b/protobuf-codegen/src/lib.rs index 89387d89f..488ee0fce 100644 --- a/protobuf-codegen/src/lib.rs +++ b/protobuf-codegen/src/lib.rs @@ -154,8 +154,11 @@ fn gen_file( w.write_line(""); w.write_line("/// Generated files are compatible only with the same version"); w.write_line("/// of protobuf runtime."); - w.write_line(&format!("const _PROTOBUF_VERSION_CHECK: () = {}::{};", - protobuf_crate_path(&customize), protobuf::VERSION_IDENT)); + w.write_line(&format!( + "const _PROTOBUF_VERSION_CHECK: () = {}::{};", + protobuf_crate_path(&customize), + protobuf::VERSION_IDENT + )); } for message in &scope.get_messages() { diff --git a/protobuf-codegen/src/message.rs b/protobuf-codegen/src/message.rs index 7ef0679f6..ff3f7e1c4 100644 --- a/protobuf-codegen/src/message.rs +++ b/protobuf-codegen/src/message.rs @@ -14,8 +14,8 @@ use serde; /// Message info for codegen pub struct MessageGen<'a> { - message: &'a MessageWithScope<'a>, - root_scope: &'a RootScope<'a>, + pub(crate) message: &'a MessageWithScope<'a>, + pub(crate) root_scope: &'a RootScope<'a>, type_name: String, pub fields: Vec>, pub lite_runtime: bool, diff --git a/protobuf-codegen/src/oneof.rs b/protobuf-codegen/src/oneof.rs index 86c24d8a7..52c0aba6a 100644 --- a/protobuf-codegen/src/oneof.rs +++ b/protobuf-codegen/src/oneof.rs @@ -6,11 +6,12 @@ use field::FieldGen; use message::MessageGen; use protobuf::descriptor::FieldDescriptorProto; use protobuf::descriptor::FieldDescriptorProto_Type; -use protobuf::descriptorx::OneofVariantWithContext; use protobuf::descriptorx::OneofWithContext; use protobuf::descriptorx::WithScope; +use protobuf::descriptorx::{FieldWithContext, OneofVariantWithContext, RootScope}; use rust_types_values::RustType; use serde; +use std::collections::HashSet; use Customize; // oneof one { ... } @@ -23,21 +24,35 @@ pub struct OneofField { } impl OneofField { + // Detecting recursion: if oneof fields contains a self-reference + // or another message which has a reference to self, + // put oneof variant into a box. + fn need_boxed(field: &FieldWithContext, root_scope: &RootScope, owner_name: &str) -> bool { + let mut visited_messages = HashSet::new(); + let mut fields = vec![field.clone()]; + while let Some(field) = fields.pop() { + if field.field.get_field_type() == FieldDescriptorProto_Type::TYPE_MESSAGE { + let message_name = field.field.get_type_name().to_owned(); + if !visited_messages.insert(message_name.clone()) { + continue; + } + if message_name == *owner_name { + return true; + } + let message = root_scope.find_message(&message_name); + fields.extend(message.fields().into_iter().filter(|f| f.is_oneof())); + } + } + false + } + pub fn parse( oneof: &OneofWithContext, - _field: &FieldDescriptorProto, + field: &FieldWithContext, elem: FieldElem, + root_scope: &RootScope, ) -> OneofField { - // detecting recursion - let boxed = if let &FieldElem::Message(ref name, ..) = &elem { - if *name == oneof.message.rust_name() { - true - } else { - false - } - } else { - false - }; + let boxed = OneofField::need_boxed(field, root_scope, &oneof.message.name_absolute()); OneofField { elem: elem, @@ -73,6 +88,8 @@ impl<'a> OneofVariantGen<'a> { oneof: &'a OneofGen<'a>, variant: OneofVariantWithContext<'a>, field: &'a FieldGen, + root_scope: &RootScope, + customize: Customize, ) -> OneofVariantGen<'a> { OneofVariantGen { oneof: oneof, @@ -85,10 +102,11 @@ impl<'a> OneofVariantGen<'a> { ), oneof_field: OneofField::parse( variant.oneof, - variant.field, + &field.proto_field, field.oneof().elem.clone(), + oneof.message.root_scope, ), - customize: field.customize.clone(), + customize, } } @@ -149,7 +167,13 @@ impl<'a> OneofGen<'a> { .expect(&format!("field not found by name: {}", v.field.get_name())); match field.proto_type { FieldDescriptorProto_Type::TYPE_GROUP => None, - _ => Some(OneofVariantGen::parse(self, v, field)), + _ => Some(OneofVariantGen::parse( + self, + v, + field, + self.message.root_scope, + self.customize.clone(), + )), } }) .collect() diff --git a/protobuf-codegen/src/scope.rs b/protobuf-codegen/src/scope.rs index 3dd04d2af..eb9271ea7 100644 --- a/protobuf-codegen/src/scope.rs +++ b/protobuf-codegen/src/scope.rs @@ -42,7 +42,7 @@ impl<'a> RootScope<'a> { } // find message by fully qualified name - pub fn _find_message(&'a self, fqn: &ProtobufAbsolutePath) -> MessageWithScope<'a> { + pub fn find_message(&'a self, fqn: &ProtobufAbsolutePath) -> MessageWithScope<'a> { match self.find_message_or_enum(fqn) { MessageOrEnumWithScope::Message(m) => m, _ => panic!("not a message: {}", fqn), diff --git a/protobuf-test/src/common/v2/test_oneof_recursive_pb.proto b/protobuf-test/src/common/v2/test_oneof_recursive_pb.proto index 32c82eb41..1f6d8c6f0 100644 --- a/protobuf-test/src/common/v2/test_oneof_recursive_pb.proto +++ b/protobuf-test/src/common/v2/test_oneof_recursive_pb.proto @@ -8,3 +8,21 @@ message LinkedList { LinkedList node = 2; } } + +message RecursiveA { + oneof x { + RecursiveB b = 1; + } +} + +message RecursiveB { + oneof x { + RecursiveC c = 1; + } +} + +message RecursiveC { + oneof x { + RecursiveA a = 1; + } +} diff --git a/protobuf/src/descriptorx.rs b/protobuf/src/descriptorx.rs index 08591e19f..063b52cb8 100644 --- a/protobuf/src/descriptorx.rs +++ b/protobuf/src/descriptorx.rs @@ -521,7 +521,8 @@ pub struct FieldWithContext<'a> { } impl<'a> FieldWithContext<'a> { - fn is_oneof(&self) -> bool { + #[doc(hidden)] + pub fn is_oneof(&self) -> bool { self.field.has_oneof_index() }