From 680a2e19c77a1ca7559373f383e5cb3422174ab3 Mon Sep 17 00:00:00 2001 From: Mikhail Zabaluev Date: Mon, 29 Apr 2024 20:28:40 +0300 Subject: [PATCH] style(prost-build): Consolidate field data into struct (#1017) * prost-build: consolidate message field data When massaging field data in CodeGenerator::append_message, move it into lists of Field and OneofField structs so that later generation passes can operate on the data with less code duplication. Subsidiary append_* methods are changed to take references to these structs rather than moved data, as generation of lexical tokens does not actually consume any owned data, and we will need more passes over the same field lists for the upcoming builder code. * prost-build: compute field tags in place * prost-build: address comments on reuse of Field Make rust_field into a method computing the name on the fly. In OneofField, make the vector of fields to have Field members. Don't play reference renaming tricks with field.descriptor. --- prost-build/src/code_generator.rs | 214 ++++++++++++++++++------------ 1 file changed, 127 insertions(+), 87 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 89c48f30b..c303a6952 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -49,6 +49,44 @@ fn prost_path(config: &Config) -> &str { config.prost_path.as_deref().unwrap_or("::prost") } +struct Field { + descriptor: FieldDescriptorProto, + path_index: i32, +} + +impl Field { + fn new(descriptor: FieldDescriptorProto, path_index: i32) -> Self { + Self { + descriptor, + path_index, + } + } + + fn rust_name(&self) -> String { + to_snake(self.descriptor.name()) + } +} + +struct OneofField { + descriptor: OneofDescriptorProto, + fields: Vec<Field>, + path_index: i32, +} + +impl OneofField { + fn new(descriptor: OneofDescriptorProto, fields: Vec<Field>, path_index: i32) -> Self { + Self { + descriptor, + fields, + path_index, + } + } + + fn rust_name(&self) -> String { 
+ to_snake(self.descriptor.name()) + } +} + impl<'a> CodeGenerator<'a> { pub fn generate( config: &mut Config, @@ -158,21 +196,33 @@ impl<'a> CodeGenerator<'a> { // Split the fields into a vector of the normal fields, and oneof fields. // Path indexes are preserved so that comments can be retrieved. - type Fields = Vec<(FieldDescriptorProto, usize)>; - type OneofFields = MultiMap<i32, (FieldDescriptorProto, usize)>; - let (fields, mut oneof_fields): (Fields, OneofFields) = message + type OneofFieldsByIndex = MultiMap<i32, Field>; + let (fields, mut oneof_map): (Vec<Field>, OneofFieldsByIndex) = message .field .into_iter() .enumerate() - .partition_map(|(idx, field)| { - if field.proto3_optional.unwrap_or(false) { - Either::Left((field, idx)) - } else if let Some(oneof_index) = field.oneof_index { - Either::Right((oneof_index, (field, idx))) + .partition_map(|(idx, proto)| { + let idx = idx as i32; + if proto.proto3_optional.unwrap_or(false) { + Either::Left(Field::new(proto, idx)) + } else if let Some(oneof_index) = proto.oneof_index { + Either::Right((oneof_index, Field::new(proto, idx))) } else { - Either::Left((field, idx)) + Either::Left(Field::new(proto, idx)) } }); + // Optional fields create a synthetic oneof that we want to skip + let oneof_fields: Vec<OneofField> = message + .oneof_decl + .into_iter() + .enumerate() + .filter_map(move |(idx, proto)| { + let idx = idx as i32; + oneof_map + .remove(&idx) + .map(|fields| OneofField::new(proto, fields, idx)) + }) + .collect(); self.append_doc(&fq_message_name, None); self.append_type_attributes(&fq_message_name); @@ -192,9 +242,10 @@ impl<'a> CodeGenerator<'a> { self.depth += 1; self.path.push(2); - for (field, idx) in fields { - self.path.push(idx as i32); + for field in &fields { + self.path.push(field.path_index); match field + .descriptor .type_name .as_ref() .and_then(|type_name| map_types.get(type_name)) @@ -207,16 +258,9 @@ impl<'a> CodeGenerator<'a> { self.path.pop(); self.path.push(8); - for (idx, oneof) in message.oneof_decl.iter().enumerate() { - let idx = idx as 
i32; - - let fields = match oneof_fields.get_vec(&idx) { - Some(fields) => fields, - None => continue, - }; - - self.path.push(idx); - self.append_oneof_field(&message_name, &fq_message_name, oneof, fields); + for oneof in &oneof_fields { + self.path.push(oneof.path_index); + self.append_oneof_field(&message_name, &fq_message_name, oneof); self.path.pop(); } self.path.pop(); @@ -243,14 +287,8 @@ impl<'a> CodeGenerator<'a> { } self.path.pop(); - for (idx, oneof) in message.oneof_decl.into_iter().enumerate() { - let idx = idx as i32; - // optional fields create a synthetic oneof that we want to skip - let fields = match oneof_fields.remove(&idx) { - Some(fields) => fields, - None => continue, - }; - self.append_oneof(&fq_message_name, oneof, idx, fields); + for oneof in &oneof_fields { + self.append_oneof(&fq_message_name, oneof); } self.pop_mod(); @@ -359,32 +397,32 @@ impl<'a> CodeGenerator<'a> { } } - fn append_field(&mut self, fq_message_name: &str, field: FieldDescriptorProto) { - let type_ = field.r#type(); - let repeated = field.label == Some(Label::Repeated as i32); - let deprecated = self.deprecated(&field); - let optional = self.optional(&field); - let ty = self.resolve_type(&field, fq_message_name); + fn append_field(&mut self, fq_message_name: &str, field: &Field) { + let type_ = field.descriptor.r#type(); + let repeated = field.descriptor.label == Some(Label::Repeated as i32); + let deprecated = self.deprecated(&field.descriptor); + let optional = self.optional(&field.descriptor); + let ty = self.resolve_type(&field.descriptor, fq_message_name); let boxed = !repeated && ((type_ == Type::Message || type_ == Type::Group) && self .message_graph - .is_nested(field.type_name(), fq_message_name)) + .is_nested(field.descriptor.type_name(), fq_message_name)) || (self .config .boxed - .get_first_field(fq_message_name, field.name()) + .get_first_field(fq_message_name, field.descriptor.name()) .is_some()); debug!( " field: {:?}, type: {:?}, boxed: {}", - 
field.name(), + field.descriptor.name(), ty, boxed ); - self.append_doc(fq_message_name, Some(field.name())); + self.append_doc(fq_message_name, Some(field.descriptor.name())); if deprecated { self.push_indent(); @@ -393,21 +431,21 @@ impl<'a> CodeGenerator<'a> { self.push_indent(); self.buf.push_str("#[prost("); - let type_tag = self.field_type_tag(&field); + let type_tag = self.field_type_tag(&field.descriptor); self.buf.push_str(&type_tag); if type_ == Type::Bytes { let bytes_type = self .config .bytes_type - .get_first_field(fq_message_name, field.name()) + .get_first_field(fq_message_name, field.descriptor.name()) .copied() .unwrap_or_default(); self.buf .push_str(&format!("={:?}", bytes_type.annotation())); } - match field.label() { + match field.descriptor.label() { Label::Optional => { if optional { self.buf.push_str(", optional"); @@ -416,8 +454,9 @@ impl<'a> CodeGenerator<'a> { Label::Required => self.buf.push_str(", required"), Label::Repeated => { self.buf.push_str(", repeated"); - if can_pack(&field) + if can_pack(&field.descriptor) && !field + .descriptor .options .as_ref() .map_or(self.syntax == Syntax::Proto3, |options| options.packed()) @@ -431,9 +470,9 @@ impl<'a> CodeGenerator<'a> { self.buf.push_str(", boxed"); } self.buf.push_str(", tag=\""); - self.buf.push_str(&field.number().to_string()); + self.buf.push_str(&field.descriptor.number().to_string()); - if let Some(ref default) = field.default_value { + if let Some(ref default) = field.descriptor.default_value { self.buf.push_str("\", default=\""); if type_ == Type::Bytes { self.buf.push_str("b\\\""); @@ -450,6 +489,7 @@ impl<'a> CodeGenerator<'a> { // the last segment and strip it from the left // side of the default value. 
let enum_type = field + .descriptor .type_name .as_ref() .and_then(|ty| ty.split('.').last()) @@ -464,10 +504,10 @@ impl<'a> CodeGenerator<'a> { } self.buf.push_str("\")]\n"); - self.append_field_attributes(fq_message_name, field.name()); + self.append_field_attributes(fq_message_name, field.descriptor.name()); self.push_indent(); self.buf.push_str("pub "); - self.buf.push_str(&to_snake(field.name())); + self.buf.push_str(&field.rust_name()); self.buf.push_str(": "); let prost_path = prost_path(self.config); @@ -495,7 +535,7 @@ impl<'a> CodeGenerator<'a> { fn append_map_field( &mut self, fq_message_name: &str, - field: FieldDescriptorProto, + field: &Field, key: &FieldDescriptorProto, value: &FieldDescriptorProto, ) { @@ -504,18 +544,18 @@ impl<'a> CodeGenerator<'a> { debug!( " map field: {:?}, key type: {:?}, value type: {:?}", - field.name(), + field.descriptor.name(), key_ty, value_ty ); - self.append_doc(fq_message_name, Some(field.name())); + self.append_doc(fq_message_name, Some(field.descriptor.name())); self.push_indent(); let map_type = self .config .map_type - .get_first_field(fq_message_name, field.name()) + .get_first_field(fq_message_name, field.descriptor.name()) .copied() .unwrap_or_default(); let key_tag = self.field_type_tag(key); @@ -526,13 +566,13 @@ impl<'a> CodeGenerator<'a> { map_type.annotation(), key_tag, value_tag, - field.number() + field.descriptor.number() )); - self.append_field_attributes(fq_message_name, field.name()); + self.append_field_attributes(fq_message_name, field.descriptor.name()); self.push_indent(); self.buf.push_str(&format!( "pub {}: {}<{}, {}>,\n", - to_snake(field.name()), + field.rust_name(), map_type.rust_type(), key_ty, value_ty @@ -543,44 +583,41 @@ impl<'a> CodeGenerator<'a> { &mut self, message_name: &str, fq_message_name: &str, - oneof: &OneofDescriptorProto, - fields: &[(FieldDescriptorProto, usize)], + oneof: &OneofField, ) { - let name = format!( + let type_name = format!( "{}::{}", to_snake(message_name), - 
to_upper_camel(oneof.name()) + to_upper_camel(oneof.descriptor.name()) ); self.append_doc(fq_message_name, None); self.push_indent(); self.buf.push_str(&format!( "#[prost(oneof=\"{}\", tags=\"{}\")]\n", - name, - fields.iter().map(|(field, _)| field.number()).join(", ") + type_name, + oneof + .fields + .iter() + .map(|field| field.descriptor.number()) + .join(", "), )); - self.append_field_attributes(fq_message_name, oneof.name()); + self.append_field_attributes(fq_message_name, oneof.descriptor.name()); self.push_indent(); self.buf.push_str(&format!( "pub {}: ::core::option::Option<{}>,\n", - to_snake(oneof.name()), - name + oneof.rust_name(), + type_name )); } - fn append_oneof( - &mut self, - fq_message_name: &str, - oneof: OneofDescriptorProto, - idx: i32, - fields: Vec<(FieldDescriptorProto, usize)>, - ) { + fn append_oneof(&mut self, fq_message_name: &str, oneof: &OneofField) { self.path.push(8); - self.path.push(idx); + self.path.push(oneof.path_index); self.append_doc(fq_message_name, None); self.path.pop(); self.path.pop(); - let oneof_name = format!("{}.{}", fq_message_name, oneof.name()); + let oneof_name = format!("{}.{}", fq_message_name, oneof.descriptor.name()); self.append_type_attributes(&oneof_name); self.append_enum_attributes(&oneof_name); self.push_indent(); @@ -593,43 +630,43 @@ impl<'a> CodeGenerator<'a> { self.append_skip_debug(fq_message_name); self.push_indent(); self.buf.push_str("pub enum "); - self.buf.push_str(&to_upper_camel(oneof.name())); + self.buf.push_str(&to_upper_camel(oneof.descriptor.name())); self.buf.push_str(" {\n"); self.path.push(2); self.depth += 1; - for (field, idx) in fields { - let type_ = field.r#type(); + for field in &oneof.fields { + let type_ = field.descriptor.r#type(); - self.path.push(idx as i32); - self.append_doc(fq_message_name, Some(field.name())); + self.path.push(field.path_index); + self.append_doc(fq_message_name, Some(field.descriptor.name())); self.path.pop(); self.push_indent(); - let ty_tag = 
self.field_type_tag(&field); + let ty_tag = self.field_type_tag(&field.descriptor); self.buf.push_str(&format!( "#[prost({}, tag=\"{}\")]\n", ty_tag, - field.number() + field.descriptor.number() )); - self.append_field_attributes(&oneof_name, field.name()); + self.append_field_attributes(&oneof_name, field.descriptor.name()); self.push_indent(); - let ty = self.resolve_type(&field, fq_message_name); + let ty = self.resolve_type(&field.descriptor, fq_message_name); let boxed = ((type_ == Type::Message || type_ == Type::Group) && self .message_graph - .is_nested(field.type_name(), fq_message_name)) + .is_nested(field.descriptor.type_name(), fq_message_name)) || (self .config .boxed - .get_first_field(&oneof_name, field.name()) + .get_first_field(&oneof_name, field.descriptor.name()) .is_some()); debug!( " oneof: {:?}, type: {:?}, boxed: {}", - field.name(), + field.descriptor.name(), ty, boxed ); @@ -637,12 +674,15 @@ impl<'a> CodeGenerator<'a> { if boxed { self.buf.push_str(&format!( "{}(::prost::alloc::boxed::Box<{}>),\n", - to_upper_camel(field.name()), + to_upper_camel(field.descriptor.name()), ty )); } else { - self.buf - .push_str(&format!("{}({}),\n", to_upper_camel(field.name()), ty)); + self.buf.push_str(&format!( + "{}({}),\n", + to_upper_camel(field.descriptor.name()), + ty + )); } } self.depth -= 1;