Skip to content

Commit d603d0f

Browse files
committed
Auto merge of #115515 - the8472:zip-for-arrays, r=scottmcm
optimize zipping over array iterators Fixes #115339 (somewhat) the new assembly: ```asm zip_arrays: .cfi_startproc vmovups (%rdx), %ymm0 leaq 32(%rsi), %rcx vxorps %xmm1, %xmm1, %xmm1 vmovups %xmm1, -24(%rsp) movq $0, -8(%rsp) movq %rsi, -88(%rsp) movq %rdi, %rax movq %rcx, -80(%rsp) vmovups %ymm0, -72(%rsp) movq $0, -40(%rsp) movq $32, -32(%rsp) movq -24(%rsp), %rcx vmovups (%rsi,%rcx), %ymm0 vorps -72(%rsp,%rcx), %ymm0, %ymm0 vmovups %ymm0, (%rsi,%rcx) vmovups (%rsi), %ymm0 vmovups %ymm0, (%rdi) vzeroupper retq ``` This is still longer than the slice version given in the issue but at least it eliminates the terrible `vpextrb`/`orb` chain. I guess this is due to excessive memcpys again (haven't looked at the llvmir)? The `TrustedLen` specialization is a drive-by change since I had to do something for the default impl anyway to be able to specialize the `TrustedRandomAccessNoCoerce` impl.
2 parents ff05789 + 0580b27 commit d603d0f

File tree

4 files changed

+146
-2
lines changed

4 files changed

+146
-2
lines changed

library/core/src/array/iter.rs

+26-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::num::NonZeroUsize;
44
use crate::{
55
fmt,
66
intrinsics::transmute_unchecked,
7-
iter::{self, ExactSizeIterator, FusedIterator, TrustedLen},
7+
iter::{self, ExactSizeIterator, FusedIterator, TrustedLen, TrustedRandomAccessNoCoerce},
88
mem::MaybeUninit,
99
ops::{IndexRange, Range},
1010
ptr,
@@ -293,6 +293,12 @@ impl<T, const N: usize> Iterator for IntoIter<T, N> {
293293

294294
NonZeroUsize::new(remaining).map_or(Ok(()), Err)
295295
}
296+
297+
#[inline]
298+
unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> Self::Item {
299+
// SAFETY: The caller must provide an idx that is in bound of the remainder.
300+
unsafe { self.data.as_ptr().add(self.alive.start()).add(idx).cast::<T>().read() }
301+
}
296302
}
297303

298304
#[stable(feature = "array_value_iter_impls", since = "1.40.0")]
@@ -374,6 +380,25 @@ impl<T, const N: usize> FusedIterator for IntoIter<T, N> {}
374380
#[stable(feature = "array_value_iter_impls", since = "1.40.0")]
375381
unsafe impl<T, const N: usize> TrustedLen for IntoIter<T, N> {}
376382

383+
#[doc(hidden)]
384+
#[unstable(issue = "none", feature = "std_internals")]
385+
#[rustc_unsafe_specialization_marker]
386+
pub trait NonDrop {}
387+
388+
// T: Copy as approximation for !Drop since get_unchecked does not advance self.alive
389+
// and thus we can't implement drop-handling
390+
#[unstable(issue = "none", feature = "std_internals")]
391+
impl<T: Copy> NonDrop for T {}
392+
393+
#[doc(hidden)]
394+
#[unstable(issue = "none", feature = "std_internals")]
395+
unsafe impl<T, const N: usize> TrustedRandomAccessNoCoerce for IntoIter<T, N>
396+
where
397+
T: NonDrop,
398+
{
399+
const MAY_HAVE_SIDE_EFFECT: bool = false;
400+
}
401+
377402
#[stable(feature = "array_value_iter_impls", since = "1.40.0")]
378403
impl<T: Clone, const N: usize> Clone for IntoIter<T, N> {
379404
fn clone(&self) -> Self {

library/core/src/iter/adapters/zip.rs

+90
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,14 @@ where
9494
ZipImpl::nth(self, n)
9595
}
9696

97+
#[inline]
98+
fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
99+
where
100+
F: FnMut(Acc, Self::Item) -> Acc,
101+
{
102+
ZipImpl::fold(self, init, f)
103+
}
104+
97105
#[inline]
98106
unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> Self::Item
99107
where
@@ -129,6 +137,9 @@ trait ZipImpl<A, B> {
129137
where
130138
A: DoubleEndedIterator + ExactSizeIterator,
131139
B: DoubleEndedIterator + ExactSizeIterator;
140+
fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
141+
where
142+
F: FnMut(Acc, Self::Item) -> Acc;
132143
// This has the same safety requirements as `Iterator::__iterator_get_unchecked`
133144
unsafe fn get_unchecked(&mut self, idx: usize) -> <Self as Iterator>::Item
134145
where
@@ -228,6 +239,14 @@ where
228239
{
229240
unreachable!("Always specialized");
230241
}
242+
243+
#[inline]
244+
default fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
245+
where
246+
F: FnMut(Acc, Self::Item) -> Acc,
247+
{
248+
SpecFold::spec_fold(self, init, f)
249+
}
231250
}
232251

233252
#[doc(hidden)]
@@ -251,6 +270,24 @@ where
251270
// `Iterator::__iterator_get_unchecked`.
252271
unsafe { (self.a.__iterator_get_unchecked(idx), self.b.__iterator_get_unchecked(idx)) }
253272
}
273+
274+
#[inline]
275+
fn fold<Acc, F>(mut self, init: Acc, mut f: F) -> Acc
276+
where
277+
F: FnMut(Acc, Self::Item) -> Acc,
278+
{
279+
let mut accum = init;
280+
let len = ZipImpl::size_hint(&self).0;
281+
for i in 0..len {
282+
// SAFETY: since Self: TrustedRandomAccessNoCoerce we can trust the size-hint to
283+
// calculate the length and then use that to do unchecked iteration.
284+
// fold consumes the iterator so we don't need to fixup any state.
285+
unsafe {
286+
accum = f(accum, self.get_unchecked(i));
287+
}
288+
}
289+
accum
290+
}
254291
}
255292

256293
#[doc(hidden)]
@@ -590,3 +627,56 @@ unsafe impl<I: Iterator + TrustedRandomAccessNoCoerce> SpecTrustedRandomAccess f
590627
unsafe { self.__iterator_get_unchecked(index) }
591628
}
592629
}
630+
631+
trait SpecFold: Iterator {
632+
fn spec_fold<B, F>(self, init: B, f: F) -> B
633+
where
634+
Self: Sized,
635+
F: FnMut(B, Self::Item) -> B;
636+
}
637+
638+
impl<A: Iterator, B: Iterator> SpecFold for Zip<A, B> {
639+
// Adapted from default impl from the Iterator trait
640+
#[inline]
641+
default fn spec_fold<Acc, F>(mut self, init: Acc, mut f: F) -> Acc
642+
where
643+
F: FnMut(Acc, Self::Item) -> Acc,
644+
{
645+
let mut accum = init;
646+
while let Some(x) = ZipImpl::next(&mut self) {
647+
accum = f(accum, x);
648+
}
649+
accum
650+
}
651+
}
652+
653+
impl<A: TrustedLen, B: TrustedLen> SpecFold for Zip<A, B> {
654+
#[inline]
655+
fn spec_fold<Acc, F>(mut self, init: Acc, mut f: F) -> Acc
656+
where
657+
F: FnMut(Acc, Self::Item) -> Acc,
658+
{
659+
let mut accum = init;
660+
loop {
661+
let (upper, more) = if let Some(upper) = ZipImpl::size_hint(&self).1 {
662+
(upper, false)
663+
} else {
664+
// Per TrustedLen contract a None upper bound means more than usize::MAX items
665+
(usize::MAX, true)
666+
};
667+
668+
for _ in 0..upper {
669+
let pair =
670+
// SAFETY: TrustedLen guarantees that at least `upper` many items are available
671+
// therefore we know they can't be None
672+
unsafe { (self.a.next().unwrap_unchecked(), self.b.next().unwrap_unchecked()) };
673+
accum = f(accum, pair);
674+
}
675+
676+
if !more {
677+
break;
678+
}
679+
}
680+
accum
681+
}
682+
}

library/core/tests/iter/adapters/zip.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,11 @@ fn test_zip_nested_sideffectful() {
184184
let it = xs.iter_mut().map(|x| *x = 1).enumerate().zip(&ys);
185185
it.count();
186186
}
187-
assert_eq!(&xs, &[1, 1, 1, 1, 1, 0]);
187+
let length_aware = &xs == &[1, 1, 1, 1, 0, 0];
188+
let probe_first = &xs == &[1, 1, 1, 1, 1, 0];
189+
190+
// either implementation is valid according to zip documentation
191+
assert!(length_aware || probe_first);
188192
}
189193

190194
#[test]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// assembly-output: emit-asm
2+
// # zen3 previously exhibited odd vectorization
3+
// compile-flags: --crate-type=lib -Ctarget-cpu=znver3
4+
// only-x86_64
5+
// ignore-sgx
6+
7+
use std::iter;
8+
9+
// previously this produced a long chain of
10+
// 56: vpextrb $6, %xmm0, %ecx
11+
// 57: orb %cl, 22(%rsi)
12+
// 58: vpextrb $7, %xmm0, %ecx
13+
// 59: orb %cl, 23(%rsi)
14+
// [...]
15+
16+
// CHECK-LABEL: zip_arrays:
17+
#[no_mangle]
18+
pub fn zip_arrays(mut a: [u8; 32], b: [u8; 32]) -> [u8; 32] {
19+
// CHECK-NOT: vpextrb
20+
// CHECK-NOT: orb %cl
21+
// CHECK: vorps
22+
iter::zip(&mut a, b).for_each(|(a, b)| *a |= b);
23+
// CHECK: retq
24+
a
25+
}

0 commit comments

Comments
 (0)