Skip to content
4 changes: 2 additions & 2 deletions benches/benches/bevy_ecs/entity_cloning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ fn bench_clone_hierarchy<B: Bundle + Default + GetTypeRegistration>(

hierarchy_level.clear();

for parent_id in current_hierarchy_level {
for parent in current_hierarchy_level {
for _ in 0..children {
let child_id = world.spawn((B::default(), ChildOf(parent_id))).id();
let child_id = world.spawn((B::default(), ChildOf { parent })).id();
hierarchy_level.push(child_id);
}
}
Expand Down
133 changes: 90 additions & 43 deletions crates/bevy_ecs/macros/src/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use syn::{
punctuated::Punctuated,
spanned::Spanned,
token::{Comma, Paren},
Data, DataStruct, DeriveInput, ExprClosure, ExprPath, Fields, Ident, Index, LitStr, Member,
Path, Result, Token, Visibility,
Data, DataStruct, DeriveInput, ExprClosure, ExprPath, Field, Fields, Ident, Index, LitStr,
Member, Path, Result, Token, Visibility,
};

pub fn derive_event(input: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -260,6 +260,15 @@ fn visit_entities(data: &Data, bevy_ecs_path: &Path, is_relationship: bool) -> T
Data::Struct(DataStruct { ref fields, .. }) => {
let mut visited_fields = Vec::new();
let mut visited_indices = Vec::new();

if is_relationship {
if let Some(field) = relationship_field(fields, fields.span()).ok().flatten() {
match field.ident {
Some(ref ident) => visited_fields.push(ident.clone()),
None => visited_indices.push(Index::from(0)),
}
}
}
match fields {
Fields::Named(fields) => {
for field in &fields.named {
Expand All @@ -276,9 +285,7 @@ fn visit_entities(data: &Data, bevy_ecs_path: &Path, is_relationship: bool) -> T
}
Fields::Unnamed(fields) => {
for (index, field) in fields.unnamed.iter().enumerate() {
if index == 0 && is_relationship {
visited_indices.push(Index::from(0));
} else if field
if field
.attrs
.iter()
.any(|a| a.meta.path().is_ident(ENTITIES_ATTR))
Expand All @@ -289,7 +296,6 @@ fn visit_entities(data: &Data, bevy_ecs_path: &Path, is_relationship: bool) -> T
}
Fields::Unit => {}
}

if visited_fields.is_empty() && visited_indices.is_empty() {
TokenStream2::new()
} else {
Expand Down Expand Up @@ -651,25 +657,29 @@ fn derive_relationship(
let Some(relationship) = &attrs.relationship else {
return Ok(None);
};
const RELATIONSHIP_FORMAT_MESSAGE: &str = "Relationship derives must be a tuple struct with the only element being an EntityTargets type (ex: ChildOf(Entity))";
if let Data::Struct(DataStruct {
fields: Fields::Unnamed(unnamed_fields),
let Data::Struct(DataStruct {
fields,
struct_token,
..
}) = &ast.data
{
if unnamed_fields.unnamed.len() != 1 {
return Err(syn::Error::new(ast.span(), RELATIONSHIP_FORMAT_MESSAGE));
}
if unnamed_fields.unnamed.first().is_none() {
return Err(syn::Error::new(
struct_token.span(),
RELATIONSHIP_FORMAT_MESSAGE,
));
}
} else {
return Err(syn::Error::new(ast.span(), RELATIONSHIP_FORMAT_MESSAGE));
else {
return Err(syn::Error::new(
ast.span(),
"Relationship can only be derived for structs.",
));
};
let field = relationship_field(fields, struct_token.span())?;

let relationship_member: Member = match field {
Some(field) => field.ident.clone().map_or(Member::from(0), Member::Named),
None => return Err(syn::Error::new(
fields.span(),
"Relationship can only be derived for structs with a single unnamed field or for structs where one field is annotated with #[relationship].",
)),
};
let members = fields
.members()
.filter(|member| member != &relationship_member);

let struct_name = &ast.ident;
let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl();
Expand All @@ -682,12 +692,15 @@ fn derive_relationship(

#[inline(always)]
fn get(&self) -> #bevy_ecs_path::entity::Entity {
self.0
self.#relationship_member
}

#[inline]
fn from(entity: #bevy_ecs_path::entity::Entity) -> Self {
Self(entity)
Self {
#(#members: core::default::Default::default(),),*
#relationship_member: entity
}
}
}
}))
Expand All @@ -702,31 +715,37 @@ fn derive_relationship_target(
return Ok(None);
};

const RELATIONSHIP_TARGET_FORMAT_MESSAGE: &str = "RelationshipTarget derives must be a tuple struct with the first element being a private RelationshipSourceCollection (ex: Children(Vec<Entity>))";
let collection = if let Data::Struct(DataStruct {
fields: Fields::Unnamed(unnamed_fields),
let Data::Struct(DataStruct {
fields,
struct_token,
..
}) = &ast.data
{
if let Some(first) = unnamed_fields.unnamed.first() {
if first.vis != Visibility::Inherited {
return Err(syn::Error::new(first.span(), "The collection in RelationshipTarget must be private to prevent users from directly mutating it, which could invalidate the correctness of relationships."));
}
first.ty.clone()
} else {
return Err(syn::Error::new(
struct_token.span(),
RELATIONSHIP_TARGET_FORMAT_MESSAGE,
));
}
} else {
else {
return Err(syn::Error::new(
ast.span(),
RELATIONSHIP_TARGET_FORMAT_MESSAGE,
"RelationshipTarget can only be derived for structs.",
));
};
let field = relationship_field(fields, struct_token.span())?;

let Some(field) = field else {
return Err(syn::Error::new(
fields.span(),
"RelationshipTarget can only be derived for structs with a single private unnamed field or for structs where one field is annotated with #[relationship] and is private.",
));
};

if field.vis != Visibility::Inherited {
return Err(syn::Error::new(field.span(), "The collection in RelationshipTarget must be private to prevent users from directly mutating it, which could invalidate the correctness of relationships."));
}
let collection = &field.ty;

let relationship_member = field.ident.clone().map_or(Member::from(0), Member::Named);

let members = fields
.members()
.filter(|member| member != &relationship_member);

let relationship = &relationship_target.relationship;
let struct_name = &ast.ident;
let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl();
Expand All @@ -739,18 +758,46 @@ fn derive_relationship_target(

#[inline]
fn collection(&self) -> &Self::Collection {
&self.0
&self.#relationship_member
}

#[inline]
fn collection_mut_risky(&mut self) -> &mut Self::Collection {
&mut self.0
&mut self.#relationship_member
}

#[inline]
fn from_collection_risky(collection: Self::Collection) -> Self {
Self(collection)
Self {
#(#members: core::default::Default::default(),),*
#relationship_member: collection
}
}
}
}))
}

/// Returns the field with the `#[relationship]` attribute, the only field if unnamed,
/// or the only field in a [`Fields::Named`] with one field, otherwise None.
fn relationship_field(fields: &Fields, span: Span) -> Result<Option<&Field>> {
let field = match fields {
Fields::Named(fields) if fields.named.len() == 1 => fields.named.first(),
Fields::Named(fields) => fields.named.iter().find(|field| {
field
.attrs
.iter()
.any(|attr| attr.path().is_ident("relationship"))
}),
Fields::Unnamed(fields) => fields
.unnamed
.len()
.eq(&1)
.then(|| fields.unnamed.first())
.flatten(),
Fields::Unit => return Err(syn::Error::new(
span,
"Relationship and RelationshipTarget can only be derived for named or unnamed structs, not unit structs.",
)),
};
Ok(field)
}
6 changes: 3 additions & 3 deletions crates/bevy_ecs/src/entity/clone_entities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1263,9 +1263,9 @@ mod tests {
fn recursive_clone() {
let mut world = World::new();
let root = world.spawn_empty().id();
let child1 = world.spawn(ChildOf(root)).id();
let grandchild = world.spawn(ChildOf(child1)).id();
let child2 = world.spawn(ChildOf(root)).id();
let child1 = world.spawn(ChildOf { parent: root }).id();
let grandchild = world.spawn(ChildOf { parent: child1 }).id();
let child2 = world.spawn(ChildOf { parent: root }).id();

let clone_root = world.spawn_empty().id();
EntityCloner::build(&mut world)
Expand Down
45 changes: 25 additions & 20 deletions crates/bevy_ecs/src/hierarchy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ use log::warn;
/// # use bevy_ecs::prelude::*;
/// # let mut world = World::new();
/// let root = world.spawn_empty().id();
/// let child1 = world.spawn(ChildOf(root)).id();
/// let child2 = world.spawn(ChildOf(root)).id();
/// let grandchild = world.spawn(ChildOf(child1)).id();
/// let child1 = world.spawn(ChildOf {parent: root}).id();
/// let child2 = world.spawn(ChildOf {parent: root}).id();
/// let grandchild = world.spawn(ChildOf {parent: child1}).id();
///
/// assert_eq!(&**world.entity(root).get::<Children>().unwrap(), &[child1, child2]);
/// assert_eq!(&**world.entity(child1).get::<Children>().unwrap(), &[grandchild]);
Expand Down Expand Up @@ -94,12 +94,15 @@ use log::warn;
)]
#[relationship(relationship_target = Children)]
#[doc(alias = "IsChild", alias = "Parent")]
pub struct ChildOf(pub Entity);
pub struct ChildOf {
/// The parent entity of this child entity.
pub parent: Entity,
}

impl ChildOf {
/// Returns the parent entity, which is the "target" of this relationship.
pub fn get(&self) -> Entity {
self.0
self.parent
}
}

Expand All @@ -108,7 +111,7 @@ impl Deref for ChildOf {

#[inline]
fn deref(&self) -> &Self::Target {
&self.0
&self.parent
}
}

Expand All @@ -119,7 +122,9 @@ impl Deref for ChildOf {
impl FromWorld for ChildOf {
#[inline(always)]
fn from_world(_world: &mut World) -> Self {
ChildOf(Entity::PLACEHOLDER)
ChildOf {
parent: Entity::PLACEHOLDER,
}
}
}

Expand Down Expand Up @@ -198,7 +203,7 @@ impl<'w> EntityWorldMut<'w> {
pub fn with_child(&mut self, bundle: impl Bundle) -> &mut Self {
let id = self.id();
self.world_scope(|world| {
world.spawn((bundle, ChildOf(id)));
world.spawn((bundle, ChildOf { parent: id }));
});
self
}
Expand All @@ -213,7 +218,7 @@ impl<'w> EntityWorldMut<'w> {
/// Inserts the [`ChildOf`] component with the given `parent` entity, if it exists.
#[deprecated(since = "0.16.0", note = "Use entity_mut.insert(ChildOf(entity))")]
pub fn set_parent(&mut self, parent: Entity) -> &mut Self {
self.insert(ChildOf(parent));
self.insert(ChildOf { parent });
self
}
}
Expand Down Expand Up @@ -245,7 +250,7 @@ impl<'a> EntityCommands<'a> {
/// [`with_children`]: EntityCommands::with_children
pub fn with_child(&mut self, bundle: impl Bundle) -> &mut Self {
let id = self.id();
self.commands.spawn((bundle, ChildOf(id)));
self.commands.spawn((bundle, ChildOf { parent: id }));
self
}

Expand All @@ -259,7 +264,7 @@ impl<'a> EntityCommands<'a> {
/// Inserts the [`ChildOf`] component with the given `parent` entity, if it exists.
#[deprecated(since = "0.16.0", note = "Use entity_commands.insert(ChildOf(entity))")]
pub fn set_parent(&mut self, parent: Entity) -> &mut Self {
self.insert(ChildOf(parent));
self.insert(ChildOf { parent });
self
}
}
Expand Down Expand Up @@ -375,9 +380,9 @@ mod tests {
fn hierarchy() {
let mut world = World::new();
let root = world.spawn_empty().id();
let child1 = world.spawn(ChildOf(root)).id();
let grandchild = world.spawn(ChildOf(child1)).id();
let child2 = world.spawn(ChildOf(root)).id();
let child1 = world.spawn(ChildOf { parent: root }).id();
let grandchild = world.spawn(ChildOf { parent: child1 }).id();
let child2 = world.spawn(ChildOf { parent: root }).id();

// Spawn
let hierarchy = get_hierarchy(&world, root);
Expand All @@ -398,7 +403,7 @@ mod tests {
assert_eq!(hierarchy, Node::new_with(root, vec![Node::new(child2)]));

// Insert
world.entity_mut(child1).insert(ChildOf(root));
world.entity_mut(child1).insert(ChildOf { parent: root });
let hierarchy = get_hierarchy(&world, root);
assert_eq!(
hierarchy,
Expand Down Expand Up @@ -457,7 +462,7 @@ mod tests {
fn self_parenting_invalid() {
let mut world = World::new();
let id = world.spawn_empty().id();
world.entity_mut(id).insert(ChildOf(id));
world.entity_mut(id).insert(ChildOf { parent: id });
assert!(
world.entity(id).get::<ChildOf>().is_none(),
"invalid ChildOf relationships should self-remove"
Expand All @@ -469,7 +474,7 @@ mod tests {
let mut world = World::new();
let parent = world.spawn_empty().id();
world.entity_mut(parent).despawn();
let id = world.spawn(ChildOf(parent)).id();
let id = world.spawn(ChildOf { parent }).id();
assert!(
world.entity(id).get::<ChildOf>().is_none(),
"invalid ChildOf relationships should self-remove"
Expand All @@ -480,10 +485,10 @@ mod tests {
fn reinsert_same_parent() {
let mut world = World::new();
let parent = world.spawn_empty().id();
let id = world.spawn(ChildOf(parent)).id();
world.entity_mut(id).insert(ChildOf(parent));
let id = world.spawn(ChildOf { parent }).id();
world.entity_mut(id).insert(ChildOf { parent });
assert_eq!(
Some(&ChildOf(parent)),
Some(&ChildOf { parent }),
world.entity(id).get::<ChildOf>(),
"ChildOf should still be there"
);
Expand Down
Loading