diff --git a/benches/rust-protobuf/perftest_data_quick.rs b/benches/rust-protobuf/perftest_data_quick.rs index 042d1454..bdd8c948 100644 --- a/benches/rust-protobuf/perftest_data_quick.rs +++ b/benches/rust-protobuf/perftest_data_quick.rs @@ -30,7 +30,7 @@ impl Test1 { impl MessageWrite for Test1 { fn get_size(&self) -> usize { - self.value.as_ref().map_or(0, |m| 1 + sizeof_varint(*m as u64)) + self.value.as_ref().map_or(0, |m| 1 + sizeof_varint(*(m) as u64)) } fn write_message(&self, w: &mut Writer) -> Result<()> { @@ -60,7 +60,7 @@ impl TestRepeatedBool { impl MessageWrite for TestRepeatedBool { fn get_size(&self) -> usize { - self.values.iter().map(|s| 1 + sizeof_varint(*s as u64)).sum::() + self.values.iter().map(|s| 1 + sizeof_varint(*(s) as u64)).sum::() } fn write_message(&self, w: &mut Writer) -> Result<()> { @@ -90,11 +90,11 @@ impl TestRepeatedPackedInt32 { impl MessageWrite for TestRepeatedPackedInt32 { fn get_size(&self) -> usize { - if self.values.is_empty() { 0 } else { 1 + sizeof_len(self.values.iter().map(|s| sizeof_varint(*s as u64)).sum::()) } + if self.values.is_empty() { 0 } else { 1 + sizeof_len(self.values.iter().map(|s| sizeof_varint(*(s) as u64)).sum::()) } } fn write_message(&self, w: &mut Writer) -> Result<()> { - w.write_packed_with_tag(10, &self.values, |w, m| w.write_int32(*m), &|m| sizeof_varint(*m as u64))?; + w.write_packed_with_tag(10, &self.values, |w, m| w.write_int32(*m), &|m| sizeof_varint(*(m) as u64))?; Ok(()) } } @@ -124,9 +124,9 @@ impl TestRepeatedMessages { impl MessageWrite for TestRepeatedMessages { fn get_size(&self) -> usize { - self.messages1.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() - + self.messages2.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() - + self.messages3.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + self.messages1.iter().map(|s| 1 + sizeof_len((s).get_size())).sum::() + + self.messages2.iter().map(|s| 1 + sizeof_len((s).get_size())).sum::() + + self.messages3.iter().map(|s| 1 + sizeof_len((s).get_size())).sum::() } fn write_message(&self, w: &mut Writer) -> Result<()> { @@ -162,9 +162,9 @@ impl TestOptionalMessages { impl MessageWrite for TestOptionalMessages { fn get_size(&self) -> usize { - self.message1.as_ref().map_or(0, |m| 1 + sizeof_len(m.get_size())) - + self.message2.as_ref().map_or(0, |m| 1 + sizeof_len(m.get_size())) - + self.message3.as_ref().map_or(0, |m| 1 + sizeof_len(m.get_size())) + self.message1.as_ref().map_or(0, |m| 1 + sizeof_len((m).get_size())) + + self.message2.as_ref().map_or(0, |m| 1 + sizeof_len((m).get_size())) + + self.message3.as_ref().map_or(0, |m| 1 + sizeof_len((m).get_size())) } fn write_message(&self, w: &mut Writer) -> Result<()> { @@ -200,9 +200,9 @@ impl<'a> TestStrings<'a> { impl<'a> MessageWrite for TestStrings<'a> { fn get_size(&self) -> usize { - self.s1.as_ref().map_or(0, |m| 1 + sizeof_len(m.len())) - + self.s2.as_ref().map_or(0, |m| 1 + sizeof_len(m.len())) - + self.s3.as_ref().map_or(0, |m| 1 + sizeof_len(m.len())) + self.s1.as_ref().map_or(0, |m| 1 + sizeof_len((m).len())) + + self.s2.as_ref().map_or(0, |m| 1 + sizeof_len((m).len())) + + self.s3.as_ref().map_or(0, |m| 1 + sizeof_len((m).len())) } fn write_message(&self, w: &mut Writer) -> Result<()> { @@ -234,7 +234,7 @@ impl<'a> TestBytes<'a> { impl<'a> MessageWrite for TestBytes<'a> { fn get_size(&self) -> usize { - self.b1.as_ref().map_or(0, |m| 1 + sizeof_len(m.len())) + self.b1.as_ref().map_or(0, |m| 1 + sizeof_len((m).len())) } fn write_message(&self, w: &mut Writer) -> Result<()> { @@ -278,14 +278,14 @@ impl<'a> PerftestData<'a> { impl<'a> MessageWrite for PerftestData<'a> { fn get_size(&self) -> usize { - self.test1.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() - + self.test_repeated_bool.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() - + self.test_repeated_messages.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() - + self.test_optional_messages.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() - + self.test_strings.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() - + self.test_repeated_packed_int32.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() - + self.test_small_bytearrays.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() - + self.test_large_bytearrays.iter().map(|s| 1 + sizeof_len(s.get_size())).sum::() + self.test1.iter().map(|s| 1 + sizeof_len((s).get_size())).sum::() + + self.test_repeated_bool.iter().map(|s| 1 + sizeof_len((s).get_size())).sum::() + + self.test_repeated_messages.iter().map(|s| 1 + sizeof_len((s).get_size())).sum::() + + self.test_optional_messages.iter().map(|s| 1 + sizeof_len((s).get_size())).sum::() + + self.test_strings.iter().map(|s| 1 + sizeof_len((s).get_size())).sum::() + + self.test_repeated_packed_int32.iter().map(|s| 1 + sizeof_len((s).get_size())).sum::() + + self.test_small_bytearrays.iter().map(|s| 1 + sizeof_len((s).get_size())).sum::() + + self.test_large_bytearrays.iter().map(|s| 1 + sizeof_len((s).get_size())).sum::() } fn write_message(&self, w: &mut Writer) -> Result<()> { diff --git a/codegen/src/types.rs b/codegen/src/types.rs index 6c8cf76f..dab84f43 100644 --- a/codegen/src/types.rs +++ b/codegen/src/types.rs @@ -169,6 +169,33 @@ impl FieldType { } } + /// Searches for enum corresponding to the current type + fn find_enum<'a, 'b>(&'a self, msgs: &'b [Message], enums: &'b [Enumerator]) -> Option<&'b Enumerator> { + match *self { + FieldType::Enum(ref m) => { + let mut found = match m.rfind('.') { + Some(p) => { + let package = &m[..p]; + let name = &m[(p + 1)..]; + enums.iter().find(|m| m.package == package && m.name == name) + }, + None => enums.iter().find(|m2| m2.name == &m[..]), + }; + + if found.is_none() { + // recursively search into nested messages + for m in msgs { + found = self.find_enum(&m.messages, &m.enums); + if found.is_some() { break; } + } + } + + found + } + _ => None, + } + } + fn has_lifetime(&self, msgs: &[Message]) -> bool { match *self { FieldType::String_ | FieldType::Bytes => true, // Cow @@ -181,7 +208,7 @@ impl FieldType { } } - fn rust_type(&self, msgs: &[Message]) -> String { + fn rust_type(&self, msgs: &[Message], enums: &[Enumerator]) -> String { match *self { FieldType::Int32 | FieldType::Sint32 | FieldType::Sfixed32 => "i32".to_string(), FieldType::Int64 | FieldType::Sint64 | FieldType::Sfixed64 => "i64".to_string(), @@ -192,17 +219,20 @@ impl FieldType { FieldType::String_ => "Cow<'a, str>".to_string(), FieldType::Bytes => "Cow<'a, [u8]>".to_string(), FieldType::Bool => "bool".to_string(), - FieldType::Enum(ref e) => e.replace(".", "::"), + FieldType::Enum(ref e) => match self.find_enum(msgs, enums) { + Some(e) => format!("{}{}", e.get_modules(), e.name), + None => unreachable!(format!("Could not find enum {} in {:?}", e, enums)), + }, FieldType::Message(ref msg) => match self.find_message(msgs) { Some(m) => { let lifetime = if m.has_lifetime(msgs) { "<'a>" } else { "" }; format!("{}{}{}", m.get_modules(), m.name, lifetime) }, - None => unreachable!(format!("Could not find message {}", msg)), + None => unreachable!(format!("Could not find message {} in {:?}", msg, msgs)), }, FieldType::Map(ref t) => { let &(ref key, ref value) = &**t; - format!("HashMap<{}, {}>", key.rust_type(msgs), value.rust_type(msgs)) + format!("HashMap<{}, {}>", key.rust_type(msgs, enums), value.rust_type(msgs, enums)) } } } @@ -319,14 +349,15 @@ impl Field { tag(self.number as u32, &self.typ, self.packed()) } - fn write_definition(&self, w: &mut W, msgs: &[Message]) -> Result<()> { + fn write_definition(&self, w: &mut W, msgs: &[Message], enums: &[Enumerator]) -> Result<()> { write!(w, " pub {}: ", self.name)?; + let rust_type = self.typ.rust_type(msgs, enums); match self.frequency { - Frequency::Optional if self.boxed => writeln!(w, "Option>,", self.typ.rust_type(msgs))?, - Frequency::Optional if self.default.is_some() => writeln!(w, "{},", self.typ.rust_type(msgs))?, - Frequency::Optional => writeln!(w, "Option<{}>,", self.typ.rust_type(msgs))?, - Frequency::Repeated => writeln!(w, "Vec<{}>,", self.typ.rust_type(msgs))?, - Frequency::Required => writeln!(w, "{},", self.typ.rust_type(msgs))?, + Frequency::Optional if self.boxed => writeln!(w, "Option>,", rust_type)?, + Frequency::Optional if self.default.is_some() => writeln!(w, "{},", rust_type)?, + Frequency::Optional => writeln!(w, "Option<{}>,", rust_type)?, + Frequency::Repeated => writeln!(w, "Vec<{}>,", rust_type)?, + Frequency::Required => writeln!(w, "{},", rust_type)?, } Ok(()) } @@ -476,10 +507,10 @@ impl Message { .collect() } - fn write(&self, w: &mut W, msgs: &[Message]) -> Result<()> { + fn write(&self, w: &mut W, msgs: &[Message], enums: &[Enumerator]) -> Result<()> { println!("Writing message {}{}", self.get_modules(), self.name); writeln!(w, "")?; - self.write_definition(w, msgs)?; + self.write_definition(w, msgs, enums)?; writeln!(w, "")?; self.write_impl_message_read(w, msgs)?; writeln!(w, "")?; @@ -497,7 +528,10 @@ impl Message { } writeln!(w, "use super::*;")?; for m in &self.messages { - m.write(w, msgs)?; + m.write(w, msgs, enums)?; + } + for e in &self.enums { + e.write(w)?; } writeln!(w, "")?; writeln!(w, "}}")?; @@ -505,7 +539,7 @@ impl Message { Ok(()) } - fn write_definition(&self, w: &mut W, msgs: &[Message]) -> Result<()> { + fn write_definition(&self, w: &mut W, msgs: &[Message], enums: &[Enumerator]) -> Result<()> { writeln!(w, "#[derive(Debug, Default, PartialEq, Clone)]")?; if self.has_lifetime(msgs) { writeln!(w, "pub struct {}<'a> {{", self.name)?; @@ -513,7 +547,7 @@ impl Message { writeln!(w, "pub struct {} {{", self.name)?; } for f in self.fields.iter().filter(|f| !f.deprecated) { - f.write_definition(w, msgs)?; + f.write_definition(w, msgs, enums)?; } writeln!(w, "}}")?; Ok(()) @@ -600,6 +634,7 @@ impl Message { self.package = package.to_string(); let child_package = format!("{}.{}", package, self.name); for m in &mut self.messages { m.set_package(&child_package); } + for m in &mut self.enums { m.set_package(&child_package); } } } @@ -630,6 +665,28 @@ pub struct Enumerator { } impl Enumerator { + + fn set_package(&mut self, package: &str) { + self.package = package.to_string(); + } + + fn get_modules(&self) -> String { + self.package + .split('.').filter(|p| !p.is_empty()) + .map(|p| format!("mod_{}::", p)) + .collect() + } + + fn write(&self, w: &mut W) -> Result<()> { + println!("Writing enum {}", self.name); + writeln!(w, "")?; + self.write_definition(w)?; + writeln!(w, "")?; + self.write_impl_default(w)?; + writeln!(w, "")?; + self.write_from_i32(w) + } + fn write_definition(&self, w: &mut W) -> Result<()> { writeln!(w, "#[derive(Debug, PartialEq, Eq, Clone, Copy)]")?; writeln!(w, "pub enum {} {{", self.name)?; @@ -689,8 +746,7 @@ impl FileDescriptor { let name = in_file.as_ref().file_name().and_then(|e| e.to_str()).unwrap(); let mut w = BufWriter::new(File::create(out_file)?); - desc.write(&mut w, name)?; - Ok(()) + desc.write(&mut w, name) } /// Opens a proto file, reads it and returns raw parsed data @@ -726,6 +782,9 @@ impl FileDescriptor { for m in &mut self.messages { m.set_package(""); } + for m in &mut self.enums { + m.set_package(""); + } for p in &self.import_paths { let import_path = get_imported_path(&in_file, p); let mut f = FileDescriptor::read_proto(&import_path)?; @@ -739,7 +798,7 @@ impl FileDescriptor { m })); self.enums.extend(f.enums.drain(..).map(|mut e| { - e.package = package.clone(); + e.set_package(&package); e.imported = true; e })); @@ -873,7 +932,7 @@ impl FileDescriptor { fn write_messages(&self, w: &mut W) -> Result<()> { for m in self.messages.iter().filter(|m| !m.imported) { - m.write(w, &self.messages)?; + m.write(w, &self.messages, &self.enums)?; } Ok(()) } diff --git a/examples/codegen/data_types.proto b/examples/codegen/data_types.proto index 9f08a386..3d323550 100644 --- a/examples/codegen/data_types.proto +++ b/examples/codegen/data_types.proto @@ -33,15 +33,21 @@ message FooMessage { optional a.b.ImportedMessage f_imported = 21; optional BazMessage f_baz = 22; optional BazMessage.Nested f_nested = 23; - map f_map = 24; + optional BazMessage.Nested.NestedEnum f_nested_enum = 24; + map f_map = 25; } message BazMessage { message Nested { - message Nested2 { + message NestedMessage { required int32 f_nested = 1; } - required Nested2 f_nested = 1; + enum NestedEnum { + Foo = 0; + Bar = 1; + Baz = 2; + } + required NestedMessage f_nested = 1; } optional Nested nested = 1; } diff --git a/examples/codegen/data_types.rs b/examples/codegen/data_types.rs index 7b2591a7..ed3a5454 100644 --- a/examples/codegen/data_types.rs +++ b/examples/codegen/data_types.rs @@ -89,6 +89,7 @@ pub struct FooMessage<'a> { pub f_imported: Option, pub f_baz: Option, pub f_nested: Option, + pub f_nested_enum: Option, pub f_map: HashMap, i32>, } @@ -120,7 +121,8 @@ impl<'a> FooMessage<'a> { Ok(170) => msg.f_imported = Some(r.read_message(bytes, mod_a::mod_b::ImportedMessage::from_reader)?), Ok(178) => msg.f_baz = Some(r.read_message(bytes, BazMessage::from_reader)?), Ok(186) => msg.f_nested = Some(r.read_message(bytes, mod_BazMessage::Nested::from_reader)?), - Ok(194) => { + Ok(192) => msg.f_nested_enum = Some(r.read_enum(bytes)?), + Ok(202) => { let (key, value) = r.read_map(bytes, |r, bytes| r.read_string(bytes).map(Cow::Borrowed), |r, bytes| r.read_int32(bytes))?; msg.f_map.insert(key, value); } @@ -157,6 +159,7 @@ impl<'a> MessageWrite for FooMessage<'a> { + self.f_imported.as_ref().map_or(0, |m| 2 + sizeof_len((m).get_size())) + self.f_baz.as_ref().map_or(0, |m| 2 + sizeof_len((m).get_size())) + self.f_nested.as_ref().map_or(0, |m| 2 + sizeof_len((m).get_size())) + + self.f_nested_enum.as_ref().map_or(0, |m| 2 + sizeof_varint(*(m) as u64)) + self.f_map.iter().map(|(k, v)| 2 + sizeof_len(2 + sizeof_len((k).len()) + sizeof_varint(*(v) as u64))).sum::() } @@ -184,7 +187,8 @@ impl<'a> MessageWrite for FooMessage<'a> { if let Some(ref s) = self.f_imported { w.write_with_tag(170, |w| w.write_message(s))?; } if let Some(ref s) = self.f_baz { w.write_with_tag(178, |w| w.write_message(s))?; } if let Some(ref s) = self.f_nested { w.write_with_tag(186, |w| w.write_message(s))?; } - for (k, v) in self.f_map.iter() { w.write_with_tag(194, |w| w.write_map(2 + sizeof_len((k).len()) + sizeof_varint(*(v) as u64), 10, |w| w.write_string(&**k), 16, |w| w.write_int32(*v)))?; } + if let Some(ref s) = self.f_nested_enum { w.write_with_tag(192, |w| w.write_enum(*s as i32))?; } + for (k, v) in self.f_map.iter() { w.write_with_tag(202, |w| w.write_map(2 + sizeof_len((k).len()) + sizeof_varint(*(v) as u64), 10, |w| w.write_string(&**k), 16, |w| w.write_int32(*v)))?; } Ok(()) } } @@ -225,7 +229,7 @@ use super::*; #[derive(Debug, Default, PartialEq, Clone)] pub struct Nested { - pub f_nested: mod_BazMessage::mod_Nested::Nested2, + pub f_nested: mod_BazMessage::mod_Nested::NestedMessage, } impl Nested { @@ -233,7 +237,7 @@ impl Nested { let mut msg = Self::default(); while !r.is_eof() { match r.next_tag(bytes) { - Ok(10) => msg.f_nested = r.read_message(bytes, mod_BazMessage::mod_Nested::Nested2::from_reader)?, + Ok(10) => msg.f_nested = r.read_message(bytes, mod_BazMessage::mod_Nested::NestedMessage::from_reader)?, Ok(t) => { r.read_unknown(bytes, t)?; } Err(e) => return Err(e), } @@ -258,11 +262,11 @@ pub mod mod_Nested { use super::*; #[derive(Debug, Default, PartialEq, Clone)] -pub struct Nested2 { +pub struct NestedMessage { pub f_nested: i32, } -impl Nested2 { +impl NestedMessage { pub fn from_reader(r: &mut BytesReader, bytes: &[u8]) -> Result { let mut msg = Self::default(); while !r.is_eof() { @@ -276,7 +280,7 @@ impl Nested2 { } } -impl MessageWrite for Nested2 { +impl MessageWrite for NestedMessage { fn get_size(&self) -> usize { 1 + sizeof_varint(*(&self.f_nested) as u64) } @@ -287,6 +291,30 @@ impl MessageWrite for Nested2 { } } +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum NestedEnum { + Foo = 0, + Bar = 1, + Baz = 2, +} + +impl Default for NestedEnum { + fn default() -> Self { + NestedEnum::Foo + } +} + +impl From for NestedEnum { + fn from(i: i32) -> Self { + match i { + 0 => NestedEnum::Foo, + 1 => NestedEnum::Bar, + 2 => NestedEnum::Baz, + _ => Self::default(), + } + } +} + } } diff --git a/examples/codegen_example.rs b/examples/codegen_example.rs index 811723f8..d73279ba 100644 --- a/examples/codegen_example.rs +++ b/examples/codegen_example.rs @@ -33,9 +33,12 @@ fn main() { // nested messages are encapsulated into a rust module mod_Message f_nested: Some(data_types::mod_BazMessage::Nested { - f_nested: data_types::mod_BazMessage::mod_Nested::Nested2 { f_nested: 2 } + f_nested: data_types::mod_BazMessage::mod_Nested::NestedMessage { f_nested: 2 } }), + // nested enums too + f_nested_enum: Some(data_types::mod_BazMessage::mod_Nested::NestedEnum::Baz), + // a map! f_map: vec![(Cow::Borrowed("foo"), 1), (Cow::Borrowed("bar"), 2)].into_iter().collect(),