Skip to content

Commit

Permalink
Limit MultiSlice impls to flat tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
jturner314 committed Sep 19, 2019
1 parent 8569279 commit 601c1bf
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 134 deletions.
2 changes: 1 addition & 1 deletion src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ where
M: MultiSlice<'a, A, D>,
S: DataMut,
{
unsafe { info.slice_and_deref(self.raw_view_mut()) }
info.multi_slice_move(self.view_mut())
}

/// Slice the array, possibly changing the number of dimensions.
Expand Down
4 changes: 2 additions & 2 deletions src/impl_views/splitting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ where
/// * if any of the views would intersect (i.e. if any element would appear in multiple slices)
/// * if an index is out of bounds or step size is zero
/// * if `D` is `IxDyn` and `info` does not match the number of array axes
pub fn multi_slice_move<M>(mut self, info: M) -> M::Output
pub fn multi_slice_move<M>(self, info: M) -> M::Output
where
M: MultiSlice<'a, A, D>,
{
unsafe { info.slice_and_deref(self.raw_view_mut()) }
info.multi_slice_move(self)
}
}
206 changes: 76 additions & 130 deletions src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
// except according to those terms.
use crate::dimension::slices_intersect;
use crate::error::{ErrorKind, ShapeError};
use crate::{ArrayViewMut, Dimension, RawArrayViewMut};
use crate::{ArrayViewMut, Dimension};
use std::fmt;
use std::marker::PhantomData;
use std::ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
Expand Down Expand Up @@ -633,186 +633,132 @@ macro_rules! s(

/// Slicing information describing multiple mutable, disjoint slices.
///
/// It's unfortunate that we need `'out` and `A` to be parameters of the trait,
/// It's unfortunate that we need `'a` and `A` to be parameters of the trait,
/// but they're necessary until Rust supports generic associated types.
///
/// # Safety
///
/// Implementers of this trait must ensure that:
///
/// * `.slice_and_deref()` panics or aborts if the slices would intersect, and
///
/// * the `.intersects_self()`, `.intersects_indices()`, and
/// `.intersects_other()` implementations are correct.
pub unsafe trait MultiSlice<'out, A, D>
pub trait MultiSlice<'a, A, D>
where
A: 'out,
A: 'a,
D: Dimension,
{
/// The type of the slices created by `.slice_and_deref()`.
/// The type of the slices created by `.multi_slice_move()`.
type Output;

/// Slice the raw view into multiple raw views, and dereference them.
///
/// **Panics** if performing any individual slice panics or if the slices
/// are not disjoint (i.e. if they intersect).
///
/// # Safety
///
/// The caller must ensure that it is safe to mutably dereference the view
/// using the lifetime `'out`.
unsafe fn slice_and_deref(&self, view: RawArrayViewMut<A, D>) -> Self::Output;

/// Returns `true` if slicing an array of the specified `shape` with `self`
/// would result in intersecting slices.
///
/// If `self.intersects_self(&view.raw_dim())` is `true`, then
/// `self.slice_and_deref(view)` must panic.
fn intersects_self(&self, shape: &D) -> bool;

/// Returns `true` if any slices created by slicing an array of the
/// specified `shape` with `self` would intersect with the specified
/// indices.
///
/// Note that even if this returns `false`, `self.intersects_self(shape)`
/// may still return `true`. (`.intersects_indices()` doesn't check for
/// intersections within `self`; it only checks for intersections between
/// `self` and `indices`.)
fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool;

/// Returns `true` if any slices created by slicing an array of the
/// specified `shape` with `self` would intersect any slices created by
/// slicing the array with `other`.
///
/// Note that even if this returns `false`, `self.intersects_self(shape)`
/// or `other.intersects_self(shape)` may still return `true`.
/// (`.intersects_other()` doesn't check for intersections within `self` or
/// within `other`; it only checks for intersections between `self` and
/// `other`.)
fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool;
fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output;
}

unsafe impl<'out, A, D, Do> MultiSlice<'out, A, D> for SliceInfo<D::SliceArg, Do>
impl<'a, A, D> MultiSlice<'a, A, D> for ()
where
A: 'out,
A: 'a,
D: Dimension,
Do: Dimension,
{
type Output = ArrayViewMut<'out, A, Do>;

unsafe fn slice_and_deref(&self, view: RawArrayViewMut<A, D>) -> Self::Output {
view.slice_move(self).deref_into_view_mut()
}

fn intersects_self(&self, _shape: &D) -> bool {
false
}

fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool {
slices_intersect(shape, &*self, indices)
}
type Output = ();

fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool {
other.intersects_indices(shape, &*self)
}
fn multi_slice_move(&self, _view: ArrayViewMut<'a, A, D>) -> Self::Output {}
}

unsafe impl<'out, A, D> MultiSlice<'out, A, D> for ()
impl<'a, A, D, Do0> MultiSlice<'a, A, D> for (SliceInfo<D::SliceArg, Do0>,)
where
A: 'out,
A: 'a,
D: Dimension,
D::SliceArg: Sized,
Do0: Dimension,
{
type Output = ();
type Output = (ArrayViewMut<'a, A, Do0>,);

unsafe fn slice_and_deref(&self, _view: RawArrayViewMut<A, D>) -> Self::Output {}

fn intersects_self(&self, _shape: &D) -> bool {
false
fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
(view.slice_move(&self.0),)
}
}

fn intersects_indices(&self, _shape: &D, _indices: &D::SliceArg) -> bool {
false
}
impl<'a, A, D, Do0> MultiSlice<'a, A, D> for (&SliceInfo<D::SliceArg, Do0>,)
where
A: 'a,
D: Dimension,
Do0: Dimension,
{
type Output = (ArrayViewMut<'a, A, Do0>,);

fn intersects_other(&self, _shape: &D, _other: impl MultiSlice<'out, A, D>) -> bool {
false
fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
(view.slice_move(self.0),)
}
}

macro_rules! impl_multislice_tuple {
($($T:ident,)*) => {
unsafe impl<'out, A, D, $($T,)*> MultiSlice<'out, A, D> for ($($T,)*)
($($Do:ident,)*) => {
impl<'a, A, D, $($Do,)*> MultiSlice<'a, A, D> for ($(SliceInfo<D::SliceArg, $Do>,)*)
where
A: 'out,
A: 'a,
D: Dimension,
$($T: MultiSlice<'out, A, D>,)*
D::SliceArg: Sized,
$($Do: Dimension,)*
{
type Output = ($($T::Output,)*);

unsafe fn slice_and_deref(&self, view: RawArrayViewMut<A, D>) -> Self::Output {
assert!(!self.intersects_self(&view.raw_dim()));
type Output = ($(ArrayViewMut<'a, A, $Do>,)*);

fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
#[allow(non_snake_case)]
let ($($T,)*) = self;
($($T.slice_and_deref(view.clone()),)*)
}
let ($($Do,)*) = self;

fn intersects_self(&self, shape: &D) -> bool {
#[allow(non_snake_case)]
let ($($T,)*) = self;
impl_multislice_tuple!(@intersects_self shape, ($($T,)*))
}
let shape = view.raw_dim();
assert!(!impl_multislice_tuple!(@intersects_self &shape, ($(&$Do,)*)));

fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool {
#[allow(non_snake_case)]
let ($($T,)*) = self;
$($T.intersects_indices(shape, indices)) ||*
let raw_view = view.into_raw_view_mut();
unsafe {
($(raw_view.clone().slice_move(&$Do).deref_into_view_mut(),)*)
}
}
}

impl<'a, A, D, $($Do,)*> MultiSlice<'a, A, D> for ($(&SliceInfo<D::SliceArg, $Do>,)*)
where
A: 'a,
D: Dimension,
$($Do: Dimension,)*
{
type Output = ($(ArrayViewMut<'a, A, $Do>,)*);

fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool {
fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
#[allow(non_snake_case)]
let ($($T,)*) = self;
$($T.intersects_other(shape, &other)) ||*
let ($($Do,)*) = self;

let shape = view.raw_dim();
assert!(!impl_multislice_tuple!(@intersects_self &shape, ($($Do,)*)));

let raw_view = view.into_raw_view_mut();
unsafe {
($(raw_view.clone().slice_move($Do).deref_into_view_mut(),)*)
}
}
}
};

(@intersects_self $shape:expr, ($head:expr,)) => {
$head.intersects_self($shape)
false
};
(@intersects_self $shape:expr, ($head:expr, $($tail:expr,)*)) => {
$head.intersects_self($shape) ||
$($head.intersects_other($shape, &$tail)) ||* ||
impl_multislice_tuple!(@intersects_self $shape, ($($tail,)*))
$(slices_intersect($shape, $head, $tail)) ||*
|| impl_multislice_tuple!(@intersects_self $shape, ($($tail,)*))
};
}
impl_multislice_tuple!(T0,);
impl_multislice_tuple!(T0, T1,);
impl_multislice_tuple!(T0, T1, T2,);
impl_multislice_tuple!(T0, T1, T2, T3,);
impl_multislice_tuple!(T0, T1, T2, T3, T4,);
impl_multislice_tuple!(T0, T1, T2, T3, T4, T5,);

unsafe impl<'out, A, D, T> MultiSlice<'out, A, D> for &'_ T

impl_multislice_tuple!(Do0, Do1,);
impl_multislice_tuple!(Do0, Do1, Do2,);
impl_multislice_tuple!(Do0, Do1, Do2, Do3,);
impl_multislice_tuple!(Do0, Do1, Do2, Do3, Do4,);
impl_multislice_tuple!(Do0, Do1, Do2, Do3, Do4, Do5,);

impl<'a, A, D, T> MultiSlice<'a, A, D> for &T
where
A: 'out,
A: 'a,
D: Dimension,
T: MultiSlice<'out, A, D>,
T: MultiSlice<'a, A, D>,
{
type Output = T::Output;

unsafe fn slice_and_deref(&self, view: RawArrayViewMut<A, D>) -> Self::Output {
T::slice_and_deref(self, view)
}

fn intersects_self(&self, shape: &D) -> bool {
T::intersects_self(self, shape)
}

fn intersects_indices(&self, shape: &D, indices: &D::SliceArg) -> bool {
T::intersects_indices(self, shape, indices)
}

fn intersects_other(&self, shape: &D, other: impl MultiSlice<'out, A, D>) -> bool {
T::intersects_other(self, shape, other)
fn multi_slice_move(&self, view: ArrayViewMut<'a, A, D>) -> Self::Output {
T::multi_slice_move(self, view)
}
}
2 changes: 1 addition & 1 deletion tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ fn test_multislice() {
});
let mut arr = Array1::from_iter(0..48).into_shape((8, 6)).unwrap();

assert_eq!(arr.clone().view(), arr.multi_slice_mut(s![.., ..]));
assert_eq!((arr.clone().view_mut(),), arr.multi_slice_mut((s![.., ..],)));
test_multislice!(&mut arr, s![0, ..], s![1, ..]);
test_multislice!(&mut arr, s![0, ..], s![-1, ..]);
test_multislice!(&mut arr, s![0, ..], s![1.., ..]);
Expand Down

0 comments on commit 601c1bf

Please sign in to comment.