Skip to content

Commit

Permalink
Merge pull request #54 from SOF3/get-many-mut
Browse files Browse the repository at this point in the history
feat: add Access::try_get_many_mut
  • Loading branch information
SOF3 authored Dec 17, 2023
2 parents eceaed3 + 77f986e commit 1dac21b
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ pub trait Ref: sealed::Sealed {

/// The underlying entity ID referenced.
fn id(&self) -> <Self::Archetype as Archetype>::RawEntity;

/// Converts this entity reference to a homogeneous temporary.
fn as_ref(&self) -> TempRef<'_, Self::Archetype> { TempRef::new(self.id()) }
}

/// A temporary, non-`'static` reference to an entity.
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@
#![feature(never_type)]
#![feature(sync_unsafe_cell)]
#![feature(slice_take)]
#![feature(get_many_mut)]
#![feature(array_try_from_fn, array_try_map)]

/// Internal re-exports used in macros.
#[doc(hidden)]
Expand Down
15 changes: 15 additions & 0 deletions src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ pub trait Partition<'t>: Access + Send + Sync + Sized + 't {

/// Same as [`get_mut`](Access::get_mut), but returns a reference with lifetime `'t`.
fn into_mut(self, entity: Self::RawEntity) -> Option<&'t mut Self::Comp>;

/// Same as [`get_many_mut`](Access::get_many_mut), but returns a reference with lifetime `'t`.
fn into_many_mut<const N: usize>(
self,
entities: [Self::RawEntity; N],
) -> Option<[&'t mut Self::Comp; N]>;
}

/// Mutable access functions for a storage, generalizing [`Storage`] and [`Partition`].
Expand All @@ -118,6 +124,15 @@ pub trait Access {
/// Gets a mutable reference to the component for a specific entity if it is present.
fn get_mut(&mut self, entity: Self::RawEntity) -> Option<&mut Self::Comp>;

/// Gets mutable references to the components for specific entities if they are present.
///
/// Returns `None` if any entity is uninitialized
/// or if any entity appeared in `entities` more than once.
fn get_many_mut<const N: usize>(
&mut self,
entities: [Self::RawEntity; N],
) -> Option<[&mut Self::Comp; N]>;

/// Return value of [`iter_mut`](Self::iter_mut).
type IterMut<'u>: Iterator<Item = (Self::RawEntity, &'u mut Self::Comp)> + 'u
where
Expand Down
72 changes: 70 additions & 2 deletions src/storage/tree.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::cell::SyncUnsafeCell;
use std::collections::BTreeMap;
use std::slice;
use std::ptr::NonNull;
use std::{array, slice};

use super::{Access, ChunkMut, ChunkRef, Partition, Storage};
use crate::entity;
use crate::{entity, util};

/// A storage based on [`BTreeMap`].
pub struct Tree<RawT: entity::Raw, C> {
Expand All @@ -26,6 +27,36 @@ impl<RawT: entity::Raw, C: Send + Sync + 'static> Access for Tree<RawT, C> {
self.data.get_mut(&id).map(|cell| cell.get_mut())
}

fn get_many_mut<const N: usize>(
&mut self,
entities: [Self::RawEntity; N],
) -> Option<[&mut Self::Comp; N]> {
let ptrs = entities.map(|entity| {
let datum = self.data.get(&entity)?;
let ptr = datum.get();
NonNull::new(ptr)
});

if !util::is_all_distinct_quadtime(&ptrs) {
return None;
}

if ptrs.iter().any(|ptr| ptr.is_none()) {
return None;
}

Some(ptrs.map(|ptr| {
let mut ptr = ptr.expect("checked all are not none");

unsafe {
// All pointers originated from a `&mut self`, so all possible aliases are in locals.
// We have checked that all `ptrs` are distinct,
// and since they come from UnsafeCell, they cannot overlap.
ptr.as_mut()
}
}))
}

type IterMut<'t> = impl Iterator<Item = (Self::RawEntity, &'t mut Self::Comp)> + 't;
fn iter_mut(&mut self) -> Self::IterMut<'_> {
Box::new(self.data.iter_mut().map(|(&entity, cell)| (entity, cell.get_mut())))
Expand Down Expand Up @@ -111,6 +142,13 @@ impl<'t, RawT: entity::Raw, C: Send + Sync + 'static> Access for StoragePartitio
}
}

fn get_many_mut<const N: usize>(
&mut self,
entities: [RawT; N],
) -> Option<[&mut Self::Comp; N]> {
self.by_ref().into_many_mut(entities)
}

type IterMut<'u> = impl Iterator<Item = (Self::RawEntity, &'u mut Self::Comp)> + 'u where Self: 'u;
fn iter_mut(&mut self) -> Self::IterMut<'_> { self.by_ref().into_iter_mut() }
}
Expand Down Expand Up @@ -158,6 +196,36 @@ impl<'t, RawT: entity::Raw, C: Send + Sync + 'static> Partition<'t>
}
}

fn into_many_mut<const N: usize>(
self,
entities: [Self::RawEntity; N],
) -> Option<[&'t mut Self::Comp; N]> {
for entity in entities {
self.assert_bounds(entity);
}

let ptrs = entities.map(|entity| {
let datum = self.data.get(&entity)?;
let ptr = datum.get();
NonNull::new(ptr)
});

if !util::is_all_distinct_quadtime(&ptrs) {
return None;
}

array::try_from_fn(|i| {
let mut ptr = ptrs[i]?;

unsafe {
// All pointers originated from a `&mut self`, so all possible aliases are in locals.
// We have checked that all `ptrs` are distinct,
// and since they come from UnsafeCell, they cannot overlap.
Some(ptr.as_mut())
}
})
}

fn split_out(&mut self, entity: RawT) -> Self {
self.assert_bounds(entity);

Expand Down
45 changes: 45 additions & 0 deletions src/storage/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,24 @@ impl<RawT: entity::Raw, C: Send + Sync + 'static> Access for VecStorage<RawT, C>
}
}

fn get_many_mut<const N: usize>(
&mut self,
entities: [RawT; N],
) -> Option<[&mut Self::Comp; N]> {
let indices = entities.map(|id| id.to_primitive());

if !indices.iter().all(|&index| self.bit(index)) {
return None;
}

let values = self.data.get_many_mut(indices).ok()?;

Some(values.map(|value| {
// Safety: values correspond to indices checked above.
unsafe { value.assume_init_mut() }
}))
}

type IterMut<'t> = impl Iterator<Item = (RawT, &'t mut C)> + 't;
fn iter_mut(&mut self) -> Self::IterMut<'_> { iter_mut(0, &self.bits, &mut self.data) }
}
Expand Down Expand Up @@ -186,6 +204,13 @@ impl<'t, RawT: entity::Raw, C: Send + Sync + 'static> Access for StoragePartitio

fn get_mut(&mut self, entity: RawT) -> Option<&mut C> { self.by_ref().into_mut(entity) }

fn get_many_mut<const N: usize>(
&mut self,
entities: [RawT; N],
) -> Option<[&mut Self::Comp; N]> {
self.by_ref().into_many_mut(entities)
}

type IterMut<'u> = impl Iterator<Item = (RawT, &'u mut C)> + 'u where Self: 'u;
fn iter_mut(&mut self) -> Self::IterMut<'_> { self.by_ref().into_iter_mut() }
}
Expand Down Expand Up @@ -220,6 +245,26 @@ impl<'t, RawT: entity::Raw, C: Send + Sync + 'static> Partition<'t>
}
}

fn into_many_mut<const N: usize>(
self,
entities: [Self::RawEntity; N],
) -> Option<[&'t mut Self::Comp; N]> {
let indices: [usize; N] =
entities.try_map(|entity| match entity.to_primitive().checked_sub(self.offset) {
Some(index) => match self.bits.get(index) {
Some(bit) if *bit => Some(index),
_ => None,
},
None => panic!("Entity {entity:?} is not in the partition {:?}..", self.offset),
})?;
let values = self.data.get_many_mut(indices).ok()?;
Some(values.map(move |value| {
// Safety: all indices have been checked to be initialized
// before getting mapped into `indices`
unsafe { value.assume_init_mut() }
}))
}

fn split_out(&mut self, entity: RawT) -> Self {
let index =
entity.to_primitive().checked_sub(self.offset).expect("parameter out of bounds");
Expand Down
30 changes: 30 additions & 0 deletions src/system/access/single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,17 @@ where
self.storage.get_mut(entity.id())
}

/// Returns mutable references to the components for the specified entities.
///
/// Returns `None` if any component is not present in the entity
/// or if the same entity is passed multiple times.
pub fn try_get_many_mut<const N: usize>(
&mut self,
entities: [impl entity::Ref<Archetype = A>; N],
) -> Option<[&mut C; N]> {
self.storage.get_many_mut(entities.map(|entity| entity.id()))
}

/// Iterates over mutable references to all initialized components in this storage.
pub fn iter_mut<'t>(
&'t mut self,
Expand Down Expand Up @@ -223,6 +234,25 @@ where
),
}
}

/// Returns mutable references to the component for the specified entities.
///
/// # Panics
/// Panics if `entities` contains duplicate items.
pub fn get_many_mut<const N: usize>(
&mut self,
entities: [impl entity::Ref<Archetype = A>; N],
) -> [&mut C; N] {
match self.try_get_many_mut(entities) {
Some(comps) => comps,
None => panic!(
"Parameter contains duplicate entities, or component {}/{} implements comp::Must \
but is not present",
any::type_name::<A>(),
any::type_name::<C>(),
),
}
}
}

#[derive_trait(pub Set{
Expand Down
34 changes: 33 additions & 1 deletion src/system/access/single/tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Tests simple storage access.
use crate::entity::{generation, Ref as _};
use crate::test_util::*;
use crate::{system, system_test, tracer};

Expand All @@ -10,7 +11,7 @@ fn test_simple_fetch() {
mut comp5: system::WriteSimple<TestArch, Simple5RequiredNoInit>,
#[dynec(global)] initials: &InitialEntities,
) {
let ent = initials.strong.as_ref().expect("initials.strong is None");
let ent = initials.strong.as_ref().expect("initials.strong is assigned during init");

let comp = comp5.get_mut(ent);
assert_eq!(comp.0, 7);
Expand All @@ -28,3 +29,34 @@ fn test_simple_fetch() {
let comp = storage.try_get(ent);
assert_eq!(comp, Some(&Simple5RequiredNoInit(20)));
}

#[test]
fn test_get_many() {
#[system(dynec_as(crate))]
fn test_system(
mut comp5: system::WriteSimple<TestArch, Simple5RequiredNoInit>,
#[dynec(global)] initials: &InitialEntities,
) {
let strong = initials.strong.as_ref().expect("initials.stonrg is assigned during init");
let weak = initials.weak.as_ref().expect("initials.weak is assigned during init");

let [strong_comp, weak_comp] = comp5.get_many_mut([strong.as_ref(), weak.as_ref()]);
strong_comp.0 += 11;
weak_comp.0 += 13;
}

let mut world = system_test!(test_system.build(););

let strong = world.create(crate::comps![@(crate) TestArch => Simple5RequiredNoInit(7)]);
world.get_global::<InitialEntities>().strong = Some(strong.clone());

let weak = world.create(crate::comps![@(crate) TestArch => Simple5RequiredNoInit(3)]);
world.get_global::<InitialEntities>().weak =
Some(weak.weak(world.get_global::<generation::StoreMap>()));

world.execute(&tracer::Log(log::Level::Trace));

let storage = world.components.get_simple_storage::<TestArch, Simple5RequiredNoInit>();
assert_eq!(storage.try_get(&strong), Some(&Simple5RequiredNoInit(7 + 11)));
assert_eq!(storage.try_get(&weak), Some(&Simple5RequiredNoInit(3 + 13)));
}
15 changes: 14 additions & 1 deletion src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,22 @@ unsafe impl UnsafeEqOrd for usize {}
/// Transforms a value behind a mutable reference with a function that moves it.
///
/// The placeholder value will be left at the position of `ref_` if the transform function panics.
pub fn transform_mut<T, R>(ref_: &mut T, placeholder: T, transform: impl FnOnce(T) -> (T, R)) -> R {
pub(crate) fn transform_mut<T, R>(
ref_: &mut T,
placeholder: T,
transform: impl FnOnce(T) -> (T, R),
) -> R {
let old = mem::replace(ref_, placeholder);
let (new, ret) = transform(old);
*ref_ = new;
ret
}

pub(crate) fn is_all_distinct_quadtime<T: PartialEq>(slice: &[T]) -> bool {
for (i, item) in slice.iter().enumerate() {
if !slice[(i + 1)..].iter().all(|other| item == other) {
return false;
}
}
true
}

0 comments on commit 1dac21b

Please sign in to comment.