diff --git a/derive/src/trait_bounds.rs b/derive/src/trait_bounds.rs index e988c450..1e79e7d0 100644 --- a/derive/src/trait_bounds.rs +++ b/derive/src/trait_bounds.rs @@ -18,10 +18,14 @@ use syn::{ parse_quote, punctuated::Punctuated, spanned::Spanned, - visit::Visit, + visit::{ + self, + Visit, + }, Generics, Result, Type, + TypePath, WhereClause, }; @@ -112,6 +116,36 @@ fn type_contains_idents(ty: &Type, idents: &[Ident]) -> bool { visitor.result } +/// Checks if the given type or any containing type path starts with the given ident. +fn type_or_sub_type_path_starts_with_ident(ty: &Type, ident: &Ident) -> bool { + // Visits the ast and checks if the a type path starts with the given ident. + struct TypePathStartsWithIdent<'a> { + result: bool, + ident: &'a Ident, + } + + impl<'a, 'ast> Visit<'ast> for TypePathStartsWithIdent<'a> { + fn visit_type_path(&mut self, i: &'ast TypePath) { + if i.qself.is_none() { + if let Some(segment) = i.path.segments.first() { + if &segment.ident == self.ident { + self.result = true; + return + } + } + } + visit::visit_type_path(self, i); + } + } + + let mut visitor = TypePathStartsWithIdent { + result: false, + ident, + }; + visitor.visit_type(ty); + visitor.result +} + /// Returns all types that must be added to the where clause with a boolean /// indicating if the field is [`scale::Compact`] or not. fn collect_types_to_bind( @@ -128,7 +162,7 @@ fn collect_types_to_bind( && // Remove all remaining types that start/contain the input ident // to not have them in the where clause. - !type_contains_idents(&field.ty, &[input_ident.clone()]) + !type_or_sub_type_path_starts_with_ident(&field.ty, &input_ident) }) .map(|f| (f.ty.clone(), super::is_compact(f))) .collect() diff --git a/test_suite/tests/derive.rs b/test_suite/tests/derive.rs index b6a1bdb0..02a90a35 100644 --- a/test_suite/tests/derive.rs +++ b/test_suite/tests/derive.rs @@ -21,6 +21,7 @@ use scale_info::{ prelude::{ boxed::Box, marker::PhantomData, + vec::Vec, }, tuple_meta_type, Path, @@ -234,6 +235,40 @@ fn associated_types_derive_without_bounds() { assert_type!(Assoc, struct_type); } +#[test] +fn associated_types_named_like_the_derived_type_works() { + trait Types { + type Assoc; + } + #[allow(unused)] + #[derive(TypeInfo)] + struct Assoc { + a: Vec, + b: Vec<::Assoc>, + c: T::Assoc, + d: ::Assoc, + } + + #[derive(TypeInfo)] + enum ConcreteTypes {} + impl Types for ConcreteTypes { + type Assoc = bool; + } + + let struct_type = Type::builder() + .path(Path::new("Assoc", "derive")) + .type_params(tuple_meta_type!(ConcreteTypes)) + .composite( + Fields::named() + .field_of::>("a", "Vec") + .field_of::>("b", "Vec<::Assoc>") + .field_of::("c", "T::Assoc") + .field_of::("d", "::Assoc"), + ); + + assert_type!(Assoc, struct_type); +} + #[test] fn scale_compact_types_work_in_structs() { #[allow(unused)]