Skip to content

Commit

Permalink
Auto merge of rust-lang#115515 - the8472:zip-for-arrays, r=scottmcm
Browse files Browse the repository at this point in the history
optimize zipping over array iterators

Fixes rust-lang#115339 (somewhat)

the new assembly:

```asm
zip_arrays:
        .cfi_startproc
        vmovups (%rdx), %ymm0
        leaq    32(%rsi), %rcx
        vxorps  %xmm1, %xmm1, %xmm1
        vmovups %xmm1, -24(%rsp)
        movq    $0, -8(%rsp)
        movq    %rsi, -88(%rsp)
        movq    %rdi, %rax
        movq    %rcx, -80(%rsp)
        vmovups %ymm0, -72(%rsp)
        movq    $0, -40(%rsp)
        movq    $32, -32(%rsp)
        movq    -24(%rsp), %rcx
        vmovups (%rsi,%rcx), %ymm0
        vorps   -72(%rsp,%rcx), %ymm0, %ymm0
        vmovups %ymm0, (%rsi,%rcx)
        vmovups (%rsi), %ymm0
        vmovups %ymm0, (%rdi)
        vzeroupper
        retq
```

This is still longer than the slice version given in the issue but at least it eliminates the terrible  `vpextrb`/`orb` chain. I guess this is due to excessive memcpys again (haven't looked at the llvmir)?

The `TrustedLen` specialization is a drive-by change since I had to do something for the default impl anyway to be able to specialize the `TrustedRandomAccessNoCoerce` impl.
  • Loading branch information
bors committed Oct 6, 2023
2 parents ff05789 + 0580b27 commit d603d0f
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 2 deletions.
27 changes: 26 additions & 1 deletion library/core/src/array/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::num::NonZeroUsize;
use crate::{
fmt,
intrinsics::transmute_unchecked,
iter::{self, ExactSizeIterator, FusedIterator, TrustedLen},
iter::{self, ExactSizeIterator, FusedIterator, TrustedLen, TrustedRandomAccessNoCoerce},
mem::MaybeUninit,
ops::{IndexRange, Range},
ptr,
Expand Down Expand Up @@ -293,6 +293,12 @@ impl<T, const N: usize> Iterator for IntoIter<T, N> {

NonZeroUsize::new(remaining).map_or(Ok(()), Err)
}

#[inline]
unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> Self::Item {
// SAFETY: The caller must provide an idx that is in bound of the remainder.
unsafe { self.data.as_ptr().add(self.alive.start()).add(idx).cast::<T>().read() }
}
}

#[stable(feature = "array_value_iter_impls", since = "1.40.0")]
Expand Down Expand Up @@ -374,6 +380,25 @@ impl<T, const N: usize> FusedIterator for IntoIter<T, N> {}
#[stable(feature = "array_value_iter_impls", since = "1.40.0")]
unsafe impl<T, const N: usize> TrustedLen for IntoIter<T, N> {}

#[doc(hidden)]
#[unstable(issue = "none", feature = "std_internals")]
#[rustc_unsafe_specialization_marker]
pub trait NonDrop {}

// T: Copy as approximation for !Drop since get_unchecked does not advance self.alive
// and thus we can't implement drop-handling
#[unstable(issue = "none", feature = "std_internals")]
impl<T: Copy> NonDrop for T {}

#[doc(hidden)]
#[unstable(issue = "none", feature = "std_internals")]
unsafe impl<T, const N: usize> TrustedRandomAccessNoCoerce for IntoIter<T, N>
where
T: NonDrop,
{
const MAY_HAVE_SIDE_EFFECT: bool = false;
}

#[stable(feature = "array_value_iter_impls", since = "1.40.0")]
impl<T: Clone, const N: usize> Clone for IntoIter<T, N> {
fn clone(&self) -> Self {
Expand Down
90 changes: 90 additions & 0 deletions library/core/src/iter/adapters/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ where
ZipImpl::nth(self, n)
}

#[inline]
fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
where
F: FnMut(Acc, Self::Item) -> Acc,
{
ZipImpl::fold(self, init, f)
}

#[inline]
unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> Self::Item
where
Expand Down Expand Up @@ -129,6 +137,9 @@ trait ZipImpl<A, B> {
where
A: DoubleEndedIterator + ExactSizeIterator,
B: DoubleEndedIterator + ExactSizeIterator;
fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
where
F: FnMut(Acc, Self::Item) -> Acc;
// This has the same safety requirements as `Iterator::__iterator_get_unchecked`
unsafe fn get_unchecked(&mut self, idx: usize) -> <Self as Iterator>::Item
where
Expand Down Expand Up @@ -228,6 +239,14 @@ where
{
unreachable!("Always specialized");
}

#[inline]
default fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
where
F: FnMut(Acc, Self::Item) -> Acc,
{
SpecFold::spec_fold(self, init, f)
}
}

#[doc(hidden)]
Expand All @@ -251,6 +270,24 @@ where
// `Iterator::__iterator_get_unchecked`.
unsafe { (self.a.__iterator_get_unchecked(idx), self.b.__iterator_get_unchecked(idx)) }
}

#[inline]
fn fold<Acc, F>(mut self, init: Acc, mut f: F) -> Acc
where
F: FnMut(Acc, Self::Item) -> Acc,
{
let mut accum = init;
let len = ZipImpl::size_hint(&self).0;
for i in 0..len {
// SAFETY: since Self: TrustedRandomAccessNoCoerce we can trust the size-hint to
// calculate the length and then use that to do unchecked iteration.
// fold consumes the iterator so we don't need to fixup any state.
unsafe {
accum = f(accum, self.get_unchecked(i));
}
}
accum
}
}

#[doc(hidden)]
Expand Down Expand Up @@ -590,3 +627,56 @@ unsafe impl<I: Iterator + TrustedRandomAccessNoCoerce> SpecTrustedRandomAccess f
unsafe { self.__iterator_get_unchecked(index) }
}
}

trait SpecFold: Iterator {
fn spec_fold<B, F>(self, init: B, f: F) -> B
where
Self: Sized,
F: FnMut(B, Self::Item) -> B;
}

impl<A: Iterator, B: Iterator> SpecFold for Zip<A, B> {
// Adapted from default impl from the Iterator trait
#[inline]
default fn spec_fold<Acc, F>(mut self, init: Acc, mut f: F) -> Acc
where
F: FnMut(Acc, Self::Item) -> Acc,
{
let mut accum = init;
while let Some(x) = ZipImpl::next(&mut self) {
accum = f(accum, x);
}
accum
}
}

impl<A: TrustedLen, B: TrustedLen> SpecFold for Zip<A, B> {
#[inline]
fn spec_fold<Acc, F>(mut self, init: Acc, mut f: F) -> Acc
where
F: FnMut(Acc, Self::Item) -> Acc,
{
let mut accum = init;
loop {
let (upper, more) = if let Some(upper) = ZipImpl::size_hint(&self).1 {
(upper, false)
} else {
// Per TrustedLen contract a None upper bound means more than usize::MAX items
(usize::MAX, true)
};

for _ in 0..upper {
let pair =
// SAFETY: TrustedLen guarantees that at least `upper` many items are available
// therefore we know they can't be None
unsafe { (self.a.next().unwrap_unchecked(), self.b.next().unwrap_unchecked()) };
accum = f(accum, pair);
}

if !more {
break;
}
}
accum
}
}
6 changes: 5 additions & 1 deletion library/core/tests/iter/adapters/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,11 @@ fn test_zip_nested_sideffectful() {
let it = xs.iter_mut().map(|x| *x = 1).enumerate().zip(&ys);
it.count();
}
assert_eq!(&xs, &[1, 1, 1, 1, 1, 0]);
let length_aware = &xs == &[1, 1, 1, 1, 0, 0];
let probe_first = &xs == &[1, 1, 1, 1, 1, 0];

// either implementation is valid according to zip documentation
assert!(length_aware || probe_first);
}

#[test]
Expand Down
25 changes: 25 additions & 0 deletions tests/assembly/libs/issue-115339-zip-arrays.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// assembly-output: emit-asm
// # zen3 previously exhibited odd vectorization
// compile-flags: --crate-type=lib -Ctarget-cpu=znver3
// only-x86_64
// ignore-sgx

use std::iter;

// previously this produced a long chain of
// 56: vpextrb $6, %xmm0, %ecx
// 57: orb %cl, 22(%rsi)
// 58: vpextrb $7, %xmm0, %ecx
// 59: orb %cl, 23(%rsi)
// [...]

// CHECK-LABEL: zip_arrays:
#[no_mangle]
pub fn zip_arrays(mut a: [u8; 32], b: [u8; 32]) -> [u8; 32] {
// CHECK-NOT: vpextrb
// CHECK-NOT: orb %cl
// CHECK: vorps
iter::zip(&mut a, b).for_each(|(a, b)| *a |= b);
// CHECK: retq
a
}

0 comments on commit d603d0f

Please sign in to comment.