From b2d0aee342d9ac1438f14ac37e99b8a01ed027b2 Mon Sep 17 00:00:00 2001 From: SOFe Date: Mon, 11 Dec 2023 00:02:23 +0800 Subject: [PATCH 1/2] feat: add Access::try_get_many_mut --- src/lib.rs | 2 ++ src/storage.rs | 15 ++++++++ src/storage/tree.rs | 72 +++++++++++++++++++++++++++++++++++-- src/storage/vec.rs | 45 +++++++++++++++++++++++ src/system/access/single.rs | 11 ++++++ src/util.rs | 15 +++++++- 6 files changed, 157 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index c93d08caaa..334196e9b7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -168,6 +168,8 @@ #![feature(never_type)] #![feature(sync_unsafe_cell)] #![feature(slice_take)] +#![feature(get_many_mut)] +#![feature(array_try_from_fn, array_try_map)] /// Internal re-exports used in macros. #[doc(hidden)] diff --git a/src/storage.rs b/src/storage.rs index cf0d5056b3..e2ff28466f 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -106,6 +106,12 @@ pub trait Partition<'t>: Access + Send + Sync + Sized + 't { /// Same as [`get_mut`](Access::get_mut), but returns a reference with lifetime `'t`. fn into_mut(self, entity: Self::RawEntity) -> Option<&'t mut Self::Comp>; + + /// Same as [`get_many_mut`](Access::get_many_mut), but returns a reference with lifetime `'t`. + fn into_many_mut( + self, + entities: [Self::RawEntity; N], + ) -> Option<[&'t mut Self::Comp; N]>; } /// Mutable access functions for a storage, generalizing [`Storage`] and [`Partition`]. @@ -118,6 +124,15 @@ pub trait Access { /// Gets a mutable reference to the component for a specific entity if it is present. fn get_mut(&mut self, entity: Self::RawEntity) -> Option<&mut Self::Comp>; + /// Gets mutable references to the components for specific entities if they are present. + /// + /// Returns `None` if any entity is uninitialized + /// or if any entity appeared in `entities` more than once. + fn get_many_mut( + &mut self, + entities: [Self::RawEntity; N], + ) -> Option<[&mut Self::Comp; N]>; + /// Return value of [`iter_mut`](Self::iter_mut). type IterMut<'u>: Iterator + 'u where diff --git a/src/storage/tree.rs b/src/storage/tree.rs index c7c677059e..279b3605d1 100644 --- a/src/storage/tree.rs +++ b/src/storage/tree.rs @@ -1,9 +1,10 @@ use std::cell::SyncUnsafeCell; use std::collections::BTreeMap; -use std::slice; +use std::ptr::NonNull; +use std::{array, slice}; use super::{Access, ChunkMut, ChunkRef, Partition, Storage}; -use crate::entity; +use crate::{entity, util}; /// A storage based on [`BTreeMap`]. pub struct Tree { @@ -26,6 +27,36 @@ impl Access for Tree { self.data.get_mut(&id).map(|cell| cell.get_mut()) } + fn get_many_mut( + &mut self, + entities: [Self::RawEntity; N], + ) -> Option<[&mut Self::Comp; N]> { + let ptrs = entities.map(|entity| { + let datum = self.data.get(&entity)?; + let ptr = datum.get(); + NonNull::new(ptr) + }); + + if !util::is_all_distinct_quadtime(&ptrs) { + return None; + } + + if ptrs.iter().any(|ptr| ptr.is_none()) { + return None; + } + + Some(ptrs.map(|ptr| { + let mut ptr = ptr.expect("checked all are not none"); + + unsafe { + // All pointers originated from a `&mut self`, so all possible aliases are in locals. + // We have checked that all `ptrs` are distinct, + // and since they come from UnsafeCell, they cannot overlap. + ptr.as_mut() + } + })) + } + type IterMut<'t> = impl Iterator + 't; fn iter_mut(&mut self) -> Self::IterMut<'_> { Box::new(self.data.iter_mut().map(|(&entity, cell)| (entity, cell.get_mut()))) @@ -111,6 +142,13 @@ impl<'t, RawT: entity::Raw, C: Send + Sync + 'static> Access for StoragePartitio } } + fn get_many_mut( + &mut self, + entities: [RawT; N], + ) -> Option<[&mut Self::Comp; N]> { + self.by_ref().into_many_mut(entities) + } + type IterMut<'u> = impl Iterator + 'u where Self: 'u; fn iter_mut(&mut self) -> Self::IterMut<'_> { self.by_ref().into_iter_mut() } } @@ -158,6 +196,36 @@ impl<'t, RawT: entity::Raw, C: Send + Sync + 'static> Partition<'t> } } + fn into_many_mut( + self, + entities: [Self::RawEntity; N], + ) -> Option<[&'t mut Self::Comp; N]> { + for entity in entities { + self.assert_bounds(entity); + } + + let ptrs = entities.map(|entity| { + let datum = self.data.get(&entity)?; + let ptr = datum.get(); + NonNull::new(ptr) + }); + + if !util::is_all_distinct_quadtime(&ptrs) { + return None; + } + + array::try_from_fn(|i| { + let mut ptr = ptrs[i]?; + + unsafe { + // All pointers originated from a `&mut self`, so all possible aliases are in locals. + // We have checked that all `ptrs` are distinct, + // and since they come from UnsafeCell, they cannot overlap. + Some(ptr.as_mut()) + } + }) + } + fn split_out(&mut self, entity: RawT) -> Self { self.assert_bounds(entity); diff --git a/src/storage/vec.rs b/src/storage/vec.rs index 69cb06929a..ec712c6b3c 100644 --- a/src/storage/vec.rs +++ b/src/storage/vec.rs @@ -74,6 +74,24 @@ impl Access for VecStorage } } + fn get_many_mut( + &mut self, + entities: [RawT; N], + ) -> Option<[&mut Self::Comp; N]> { + let indices = entities.map(|id| id.to_primitive()); + + if !indices.iter().all(|&index| self.bit(index)) { + return None; + } + + let values = self.data.get_many_mut(indices).ok()?; + + Some(values.map(|value| { + // Safety: values correspond to indices checked above. + unsafe { value.assume_init_mut() } + })) + } + type IterMut<'t> = impl Iterator + 't; fn iter_mut(&mut self) -> Self::IterMut<'_> { iter_mut(0, &self.bits, &mut self.data) } } @@ -186,6 +204,13 @@ impl<'t, RawT: entity::Raw, C: Send + Sync + 'static> Access for StoragePartitio fn get_mut(&mut self, entity: RawT) -> Option<&mut C> { self.by_ref().into_mut(entity) } + fn get_many_mut( + &mut self, + entities: [RawT; N], + ) -> Option<[&mut Self::Comp; N]> { + self.by_ref().into_many_mut(entities) + } + type IterMut<'u> = impl Iterator + 'u where Self: 'u; fn iter_mut(&mut self) -> Self::IterMut<'_> { self.by_ref().into_iter_mut() } } @@ -220,6 +245,26 @@ impl<'t, RawT: entity::Raw, C: Send + Sync + 'static> Partition<'t> } } + fn into_many_mut( + self, + entities: [Self::RawEntity; N], + ) -> Option<[&'t mut Self::Comp; N]> { + let indices: [usize; N] = + entities.try_map(|entity| match entity.to_primitive().checked_sub(self.offset) { + Some(index) => match self.bits.get(index) { + Some(bit) if *bit => Some(index), + _ => None, + }, + None => panic!("Entity {entity:?} is not in the partition {:?}..", self.offset), + })?; + let values = self.data.get_many_mut(indices).ok()?; + Some(values.map(move |value| { + // Safety: all indices have been checked to be initialized + // before getting mapped into `indices` + unsafe { value.assume_init_mut() } + })) + } + fn split_out(&mut self, entity: RawT) -> Self { let index = entity.to_primitive().checked_sub(self.offset).expect("parameter out of bounds"); diff --git a/src/system/access/single.rs b/src/system/access/single.rs index 9cc4bb1d5b..1e5aedde82 100644 --- a/src/system/access/single.rs +++ b/src/system/access/single.rs @@ -184,6 +184,17 @@ where self.storage.get_mut(entity.id()) } + /// Returns mutable references to the components for the specified entities. + /// + /// Returns `None` if any component is not present in the entity + /// or if the same entity is passed multiple times. + pub fn try_get_many_mut( + &mut self, + entities: [impl entity::Ref; N], + ) -> Option<[&mut C; N]> { + self.storage.get_many_mut(entities.map(|entity| entity.id())) + } + /// Iterates over mutable references to all initialized components in this storage. pub fn iter_mut<'t>( &'t mut self, diff --git a/src/util.rs b/src/util.rs index 8927be1ee2..f0e85ecb3b 100644 --- a/src/util.rs +++ b/src/util.rs @@ -145,9 +145,22 @@ unsafe impl UnsafeEqOrd for usize {} /// Transforms a value behind a mutable reference with a function that moves it. /// /// The placeholder value will be left at the position of `ref_` if the transform function panics. -pub fn transform_mut(ref_: &mut T, placeholder: T, transform: impl FnOnce(T) -> (T, R)) -> R { +pub(crate) fn transform_mut( + ref_: &mut T, + placeholder: T, + transform: impl FnOnce(T) -> (T, R), +) -> R { let old = mem::replace(ref_, placeholder); let (new, ret) = transform(old); *ref_ = new; ret } + +pub(crate) fn is_all_distinct_quadtime(slice: &[T]) -> bool { + for (i, item) in slice.iter().enumerate() { + if !slice[(i + 1)..].iter().all(|other| item == other) { + return false; + } + } + true +} From 77f986ee3a97e233f66c98530211dfc5b83efa19 Mon Sep 17 00:00:00 2001 From: SOFe Date: Sun, 17 Dec 2023 12:16:30 +0800 Subject: [PATCH 2/2] test: add unit tests for get_many_mut --- src/entity.rs | 3 +++ src/system/access/single.rs | 19 +++++++++++++++++ src/system/access/single/tests.rs | 34 ++++++++++++++++++++++++++++++- 3 files changed, 55 insertions(+), 1 deletion(-) diff --git a/src/entity.rs b/src/entity.rs index 8ce71c223a..92061dd593 100644 --- a/src/entity.rs +++ b/src/entity.rs @@ -43,6 +43,9 @@ pub trait Ref: sealed::Sealed { /// The underlying entity ID referenced. fn id(&self) -> ::RawEntity; + + /// Converts this entity reference to a homogeneous temporary. + fn as_ref(&self) -> TempRef<'_, Self::Archetype> { TempRef::new(self.id()) } } /// A temporary, non-`'static` reference to an entity. diff --git a/src/system/access/single.rs b/src/system/access/single.rs index 1e5aedde82..2c96e69f58 100644 --- a/src/system/access/single.rs +++ b/src/system/access/single.rs @@ -234,6 +234,25 @@ where ), } } + + /// Returns mutable references to the component for the specified entities. + /// + /// # Panics + /// Panics if `entities` contains duplicate items. + pub fn get_many_mut( + &mut self, + entities: [impl entity::Ref; N], + ) -> [&mut C; N] { + match self.try_get_many_mut(entities) { + Some(comps) => comps, + None => panic!( + "Parameter contains duplicate entities, or component {}/{} implements comp::Must \ + but is not present", + any::type_name::(), + any::type_name::(), + ), + } + } } #[derive_trait(pub Set{ diff --git a/src/system/access/single/tests.rs b/src/system/access/single/tests.rs index d7511d8981..e99912e050 100644 --- a/src/system/access/single/tests.rs +++ b/src/system/access/single/tests.rs @@ -1,5 +1,6 @@ //! Tests simple storage access. +use crate::entity::{generation, Ref as _}; use crate::test_util::*; use crate::{system, system_test, tracer}; @@ -10,7 +11,7 @@ fn test_simple_fetch() { mut comp5: system::WriteSimple, #[dynec(global)] initials: &InitialEntities, ) { - let ent = initials.strong.as_ref().expect("initials.strong is None"); + let ent = initials.strong.as_ref().expect("initials.strong is assigned during init"); let comp = comp5.get_mut(ent); assert_eq!(comp.0, 7); @@ -28,3 +29,34 @@ fn test_simple_fetch() { let comp = storage.try_get(ent); assert_eq!(comp, Some(&Simple5RequiredNoInit(20))); } + +#[test] +fn test_get_many() { + #[system(dynec_as(crate))] + fn test_system( + mut comp5: system::WriteSimple, + #[dynec(global)] initials: &InitialEntities, + ) { + let strong = initials.strong.as_ref().expect("initials.stonrg is assigned during init"); + let weak = initials.weak.as_ref().expect("initials.weak is assigned during init"); + + let [strong_comp, weak_comp] = comp5.get_many_mut([strong.as_ref(), weak.as_ref()]); + strong_comp.0 += 11; + weak_comp.0 += 13; + } + + let mut world = system_test!(test_system.build();); + + let strong = world.create(crate::comps![@(crate) TestArch => Simple5RequiredNoInit(7)]); + world.get_global::().strong = Some(strong.clone()); + + let weak = world.create(crate::comps![@(crate) TestArch => Simple5RequiredNoInit(3)]); + world.get_global::().weak = + Some(weak.weak(world.get_global::())); + + world.execute(&tracer::Log(log::Level::Trace)); + + let storage = world.components.get_simple_storage::(); + assert_eq!(storage.try_get(&strong), Some(&Simple5RequiredNoInit(7 + 11))); + assert_eq!(storage.try_get(&weak), Some(&Simple5RequiredNoInit(3 + 13))); +}