diff --git a/library/core/src/iter/sources/repeat_n.rs b/library/core/src/iter/sources/repeat_n.rs index 9c0621933638e..7e162ff387baf 100644 --- a/library/core/src/iter/sources/repeat_n.rs +++ b/library/core/src/iter/sources/repeat_n.rs @@ -1,5 +1,6 @@ +use crate::fmt; use crate::iter::{FusedIterator, TrustedLen, UncheckedIterator}; -use crate::mem::ManuallyDrop; +use crate::mem::{self, MaybeUninit}; use crate::num::NonZero; /// Creates a new iterator that repeats a single element a given number of times. @@ -58,14 +59,12 @@ use crate::num::NonZero; #[inline] #[stable(feature = "iter_repeat_n", since = "1.82.0")] pub fn repeat_n(element: T, count: usize) -> RepeatN { - let mut element = ManuallyDrop::new(element); - - if count == 0 { - // SAFETY: we definitely haven't dropped it yet, since we only just got - // passed it in, and because the count is zero the instance we're about - // to create won't drop it, so to avoid leaking we need to now. - unsafe { ManuallyDrop::drop(&mut element) }; - } + let element = if count == 0 { + // `element` gets dropped eagerly. + MaybeUninit::uninit() + } else { + MaybeUninit::new(element) + }; RepeatN { element, count } } @@ -74,15 +73,23 @@ pub fn repeat_n(element: T, count: usize) -> RepeatN { /// /// This `struct` is created by the [`repeat_n()`] function. /// See its documentation for more. -#[derive(Clone, Debug)] #[stable(feature = "iter_repeat_n", since = "1.82.0")] pub struct RepeatN { count: usize, - // Invariant: has been dropped iff count == 0. - element: ManuallyDrop, + // Invariant: uninit iff count == 0. + element: MaybeUninit, } impl RepeatN { + /// Returns the element if it hasn't been dropped already. + fn element_ref(&self) -> Option<&A> { + if self.count > 0 { + // SAFETY: The count is non-zero, so it must be initialized. + Some(unsafe { self.element.assume_init_ref() }) + } else { + None + } + } /// If we haven't already dropped the element, return it in an option. /// /// Clears the count so it won't be dropped again later. @@ -90,15 +97,36 @@ impl RepeatN { fn take_element(&mut self) -> Option { if self.count > 0 { self.count = 0; + let element = mem::replace(&mut self.element, MaybeUninit::uninit()); // SAFETY: We just set count to zero so it won't be dropped again, // and it used to be non-zero so it hasn't already been dropped. - unsafe { Some(ManuallyDrop::take(&mut self.element)) } + unsafe { Some(element.assume_init()) } } else { None } } } +#[stable(feature = "iter_repeat_n", since = "1.82.0")] +impl Clone for RepeatN { + fn clone(&self) -> RepeatN { + RepeatN { + count: self.count, + element: self.element_ref().cloned().map_or_else(MaybeUninit::uninit, MaybeUninit::new), + } + } +} + +#[stable(feature = "iter_repeat_n", since = "1.82.0")] +impl fmt::Debug for RepeatN { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RepeatN") + .field("count", &self.count) + .field("element", &self.element_ref()) + .finish() + } +} + #[stable(feature = "iter_repeat_n", since = "1.82.0")] impl Drop for RepeatN { fn drop(&mut self) { @@ -194,9 +222,11 @@ impl UncheckedIterator for RepeatN { // SAFETY: the check above ensured that the count used to be non-zero, // so element hasn't been dropped yet, and we just lowered the count to // zero so it won't be dropped later, and thus it's okay to take it here. - unsafe { ManuallyDrop::take(&mut self.element) } + unsafe { mem::replace(&mut self.element, MaybeUninit::uninit()).assume_init() } } else { - A::clone(&self.element) + // SAFETY: the count is non-zero, so it must have not been dropped yet. + let element = unsafe { self.element.assume_init_ref() }; + A::clone(element) } } } diff --git a/library/core/tests/iter/sources.rs b/library/core/tests/iter/sources.rs index eb8c80dd08724..506febaa056a8 100644 --- a/library/core/tests/iter/sources.rs +++ b/library/core/tests/iter/sources.rs @@ -156,3 +156,27 @@ fn test_repeat_n_drop() { drop((x0, x1, x2)); assert_eq!(count.get(), 3); } + +#[test] +fn test_repeat_n_soundness() { + let x = std::iter::repeat_n(String::from("use after free"), 0); + println!("{x:?}"); + + pub struct PanicOnClone; + + impl Clone for PanicOnClone { + fn clone(&self) -> Self { + unreachable!() + } + } + + // `repeat_n` should drop the element immediately if `count` is zero. + // `Clone` should then not try to clone the element. + let x = std::iter::repeat_n(PanicOnClone, 0); + let _ = x.clone(); + + let mut y = std::iter::repeat_n(Box::new(0), 1); + let x = y.next().unwrap(); + let _z = y; + assert_eq!(0, *x); +}