Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ extern crate alloc;
extern crate proc_macro;

mod impl_wrapper;
mod trait_bounds;

use alloc::{
string::{
Expand All @@ -34,7 +35,6 @@ use syn::{
Error,
Result,
},
parse_quote,
punctuated::Punctuated,
token::Comma,
Data,
Expand Down Expand Up @@ -66,12 +66,9 @@ fn generate(input: TokenStream2) -> Result<TokenStream2> {
fn generate_type(input: TokenStream2) -> Result<TokenStream2> {
let mut ast: DeriveInput = syn::parse2(input.clone())?;

ast.generics.type_params_mut().for_each(|p| {
p.bounds.push(parse_quote!(::scale_info::TypeInfo));
p.bounds.push(parse_quote!('static));
});

let ident = &ast.ident;
trait_bounds::add(ident, &mut ast.generics, &ast.data)?;

let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
let generic_type_ids = ast.generics.type_params().map(|ty| {
let ty_ident = &ty.ident;
Expand Down
141 changes: 141 additions & 0 deletions derive/src/trait_bounds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Copyright 2019-2020 Parity Technologies (UK) Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use alloc::vec::Vec;
use proc_macro2::Ident;
use syn::{
parse_quote,
punctuated::Punctuated,
spanned::Spanned,
visit::Visit,
Generics,
Result,
Type,
};

/// Adds a `TypeInfo + 'static` bound to all relevant generic types including
/// associated types (e.g. `T::A: TypeInfo`), correctly dealing with
/// self-referential types.
pub fn add(input_ident: &Ident, generics: &mut Generics, data: &syn::Data) -> Result<()> {
let ty_params = generics.type_params_mut().fold(Vec::new(), |mut acc, p| {
p.bounds.push(parse_quote!(::scale_info::TypeInfo));
p.bounds.push(parse_quote!('static));
acc.push(p.ident.clone());
acc
});

if ty_params.is_empty() {
return Ok(())
}

let types = collect_types_to_bind(input_ident, data, &ty_params)?;

if !types.is_empty() {
let where_clause = generics.make_where_clause();

types.into_iter().for_each(|ty| {
where_clause
.predicates
.push(parse_quote!(#ty : ::scale_info::TypeInfo + 'static))
});
}

Ok(())
}

/// Visits the ast and checks if the given type contains one of the given
/// idents.
fn type_contains_idents(ty: &Type, idents: &[Ident]) -> bool {
struct ContainIdents<'a> {
result: bool,
idents: &'a [Ident],
}

impl<'a, 'ast> Visit<'ast> for ContainIdents<'a> {
fn visit_ident(&mut self, i: &'ast Ident) {
if self.idents.iter().any(|id| id == i) {
self.result = true;
}
}
}

let mut visitor = ContainIdents {
result: false,
idents,
};
visitor.visit_type(ty);
visitor.result
}

/// Returns all types that must be added to the where clause with the respective
/// trait bound.
fn collect_types_to_bind(
input_ident: &Ident,
data: &syn::Data,
ty_params: &[Ident],
) -> Result<Vec<Type>> {
let types_from_fields = |fields: &Punctuated<syn::Field, _>| -> Vec<syn::Type> {
fields
.iter()
.filter(|field| {
// Only add a bound if the type uses a generic.
type_contains_idents(&field.ty, &ty_params)
&&
// Remove all remaining types that start/contain the input ident
// to not have them in the where clause.
!type_contains_idents(&field.ty, &[input_ident.clone()])
})
.map(|f| f.ty.clone())
.collect()
};

let types = match *data {
syn::Data::Struct(ref data) => {
match &data.fields {
syn::Fields::Named(syn::FieldsNamed { named: fields, .. })
| syn::Fields::Unnamed(syn::FieldsUnnamed {
unnamed: fields, ..
}) => types_from_fields(fields),
syn::Fields::Unit => Vec::new(),
}
}

syn::Data::Enum(ref data) => {
data.variants
.iter()
.flat_map(|variant| {
match &variant.fields {
syn::Fields::Named(syn::FieldsNamed {
named: fields, ..
})
| syn::Fields::Unnamed(syn::FieldsUnnamed {
unnamed: fields,
..
}) => types_from_fields(fields),
syn::Fields::Unit => Vec::new(),
}
})
.collect()
}

syn::Data::Union(ref data) => {
return Err(syn::Error::new(
data.union_token.span(),
"Union types are not supported.",
))
}
};

Ok(types)
}
2 changes: 2 additions & 0 deletions test_suite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ scale = { package = "parity-scale-codec", version = "1.3", default-features = fa
serde = "1.0"
serde_json = "1.0"
pretty_assertions = "0.6.1"
trybuild = "1"
rustversion = "1"
37 changes: 37 additions & 0 deletions test_suite/tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,40 @@ fn fields_with_type_alias() {

assert_type!(S, ty);
}

#[test]
fn associated_types_derive_without_bounds() {
trait Types {
type A;
}
#[allow(unused)]
#[derive(TypeInfo)]
struct Assoc<T: Types> {
a: T::A,
}

#[derive(TypeInfo)]
enum ConcreteTypes {}
impl Types for ConcreteTypes {
type A = bool;
}

let struct_type = Type::builder()
.path(Path::new("Assoc", "derive"))
.type_params(tuple_meta_type!(ConcreteTypes))
.composite(Fields::named().field_of::<bool>("a", "T::A"));

assert_type!(Assoc<ConcreteTypes>, struct_type);
}

#[rustversion::nightly]
#[test]
fn ui_tests() {
let t = trybuild::TestCases::new();
t.compile_fail("tests/ui/fail_missing_derive.rs");
t.compile_fail("tests/ui/fail_non_static_lifetime.rs");
t.compile_fail("tests/ui/fail_unions.rs");
t.pass("tests/ui/pass_self_referential.rs");
t.pass("tests/ui/pass_basic_generic_type.rs");
t.pass("tests/ui/pass_complex_generic_self_referential_type.rs");
}
18 changes: 18 additions & 0 deletions test_suite/tests/ui/fail_missing_derive.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use scale_info::TypeInfo;

enum PawType<Paw> {
Big(Paw),
Small(Paw),
}
#[derive(TypeInfo)]
struct Cat<Tail, Ear, Paw> {
tail: Tail,
ears: [Ear; 3],
paws: PawType<Paw>,
}

fn assert_type_info<T: TypeInfo + 'static>() {}

fn main() {
assert_type_info::<Cat<bool, u8, u16>>();
}
10 changes: 10 additions & 0 deletions test_suite/tests/ui/fail_missing_derive.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
error[E0277]: the trait bound `PawType<u16>: TypeInfo` is not satisfied
--> $DIR/fail_missing_derive.rs:17:5
|
14 | fn assert_type_info<T: TypeInfo + 'static>() {}
| -------- required by this bound in `assert_type_info`
...
17 | assert_type_info::<Cat<bool, u8, u16>>();
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `TypeInfo` is not implemented for `PawType<u16>`
|
= note: required because of the requirements on the impl of `TypeInfo` for `Cat<bool, u8, u16>`
12 changes: 12 additions & 0 deletions test_suite/tests/ui/fail_non_static_lifetime.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use scale_info::TypeInfo;

#[derive(TypeInfo)]
struct Me<'a> {
me: &'a Me<'a>,
}

fn assert_type_info<T: TypeInfo + 'static>() {}

fn main() {
assert_type_info::<Me>();
}
25 changes: 25 additions & 0 deletions test_suite/tests/ui/fail_non_static_lifetime.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
error[E0477]: the type `Me<'a>` does not fulfill the required lifetime
--> $DIR/fail_non_static_lifetime.rs:3:10
|
3 | #[derive(TypeInfo)]
| ^^^^^^^^ in this macro invocation
|
::: $WORKSPACE/derive/src/lib.rs
|
| pub fn type_info(input: TokenStream) -> TokenStream {
| --------------------------------------------------- in this expansion of `#[derive(TypeInfo)]`
|
= note: type must satisfy the static lifetime

error[E0477]: the type `&'a Me<'a>` does not fulfill the required lifetime
--> $DIR/fail_non_static_lifetime.rs:3:10
|
3 | #[derive(TypeInfo)]
| ^^^^^^^^ in this macro invocation
|
::: $WORKSPACE/derive/src/lib.rs
|
| pub fn type_info(input: TokenStream) -> TokenStream {
| --------------------------------------------------- in this expansion of `#[derive(TypeInfo)]`
|
= note: type must satisfy the static lifetime
14 changes: 14 additions & 0 deletions test_suite/tests/ui/fail_unions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use scale_info::TypeInfo;

#[derive(TypeInfo)]
#[repr(C)]
union Commonwealth {
a: u8,
b: f32,
}

fn assert_type_info<T: TypeInfo + 'static>() {}

fn main() {
assert_type_info::<Commonwealth>();
}
18 changes: 18 additions & 0 deletions test_suite/tests/ui/fail_unions.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
error: Unions not supported
--> $DIR/fail_unions.rs:4:1
|
4 | / #[repr(C)]
5 | | union Commonwealth {
6 | | a: u8,
7 | | b: f32,
8 | | }
| |_^

error[E0277]: the trait bound `Commonwealth: TypeInfo` is not satisfied
--> $DIR/fail_unions.rs:13:24
|
10 | fn assert_type_info<T: TypeInfo + 'static>() {}
| -------- required by this bound in `assert_type_info`
...
13 | assert_type_info::<Commonwealth>();
| ^^^^^^^^^^^^ the trait `TypeInfo` is not implemented for `Commonwealth`
20 changes: 20 additions & 0 deletions test_suite/tests/ui/pass_basic_generic_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use scale_info::TypeInfo;

#[allow(dead_code)]
#[derive(TypeInfo)]
enum PawType<Paw> {
Big(Paw),
Small(Paw),
}
#[derive(TypeInfo)]
struct Cat<Tail, Ear, Paw> {
_tail: Tail,
_ears: [Ear; 3],
_paws: PawType<Paw>,
}

fn assert_type_info<T: TypeInfo + 'static>() {}

fn main() {
assert_type_info::<Cat<bool, u8, u16>>();
}
41 changes: 41 additions & 0 deletions test_suite/tests/ui/pass_complex_generic_self_referential_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use scale_info::TypeInfo;


#[derive(TypeInfo)]
struct Nested<P> {
_pos: P,
}

#[derive(TypeInfo)]
struct Is<N> {
_nested: N,
}

#[derive(TypeInfo)]
struct That<I, S> {
_is: I,
_selfie: S,
}

#[derive(TypeInfo)]
struct Thing<T> {
_that: T,
}

#[derive(TypeInfo)]
struct Other<T> {
_thing: T,
}

#[derive(TypeInfo)]
struct Selfie<Pos> {
_another: Box<Selfie<Pos>>,
_pos: Pos,
_nested: Box<Other<Thing<That<Is<Nested<Pos>>, Selfie<Pos>>>>>,
}

fn assert_type_info<T: TypeInfo + 'static>() {}

fn main() {
assert_type_info::<Selfie<bool>>();
}
12 changes: 12 additions & 0 deletions test_suite/tests/ui/pass_self_referential.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use scale_info::TypeInfo;

#[derive(TypeInfo)]
struct Me {
_me: Box<Me>,
}

fn assert_type_info<T: TypeInfo + 'static>() {}

fn main() {
assert_type_info::<Me>();
}