Skip to content
164 changes: 155 additions & 9 deletions wincode/src/schema/containers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,17 @@
#[cfg(all(feature = "alloc", target_has_atomic = "ptr"))]
use alloc::sync::Arc as AllocArc;
use {
crate::{TypeMeta, config::ConfigCore, error::ReadResult, io::Reader, schema::SchemaRead},
crate::{
TypeMeta,
config::ConfigCore,
error::{ReadResult, WriteResult},
io::{Reader, Writer},
len::SeqLen,
schema::{SchemaRead, SchemaWrite, size_of_elem_iter, write_elem_iter},
},
core::{
borrow::Borrow,
marker::PhantomData,
mem::{self, MaybeUninit},
ptr,
},
Expand All @@ -69,16 +78,9 @@ use {
use {
crate::{
context,
error::WriteResult,
io::Writer,
len::SeqLen,
schema::{
SchemaReadContext, SchemaWrite, size_of_elem_iter, size_of_elem_slice, write_elem_iter,
write_elem_slice_prealloc_check,
},
schema::{SchemaReadContext, size_of_elem_slice, write_elem_slice_prealloc_check},
},
alloc::{boxed::Box as AllocBox, collections, rc::Rc as AllocRc, vec},
core::marker::PhantomData,
};

/// A [`Vec`](std::vec::Vec) with a customizable length encoding.
Expand Down Expand Up @@ -517,6 +519,150 @@ where
}
}

/// Newtype that collects a fallible iterator into `Result<C, E>` while preserving `size_hint`.
///
/// Unlike `collect::<Result<V, E>>()`, which loses the size hint on error, this type
/// drives `V::from_iter` through an adaptor that stops on the first error but keeps
/// `size_hint` accurate so that `V` can preallocate its full expected capacity.
struct ResultPrealloc<T, E>(Result<T, E>);

impl<A, E, V: FromIterator<A>> FromIterator<Result<A, E>> for ResultPrealloc<V, E> {
fn from_iter<I: IntoIterator<Item = Result<A, E>>>(iter: I) -> ResultPrealloc<V, E> {
struct Iter<I, E> {
inner: I,
error: Option<E>,
}

impl<I: Iterator<Item = Result<T, E>>, T, E> Iterator for Iter<I, E> {
type Item = T;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
self.inner.next()?.map_err(|e| self.error = Some(e)).ok()
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}

let mut iter = Iter {
inner: iter.into_iter(),
error: None,
};
let result = V::from_iter(&mut iter);
ResultPrealloc(iter.error.map_or(Ok(result), Err))
}
}

/// Extension trait that adds [`collect_result_prealloc`](CollectResultExt::collect_result_prealloc)
/// to any fallible iterator, collecting into `Result<B, E>` with preallocation-friendly size hints.
trait CollectResultExt<T, E>: Iterator<Item = Result<T, E>> {
#[inline]
fn collect_result_prealloc<B: FromIterator<T>>(self) -> Result<B, E>
where
Self: Sized,
{
self.collect::<ResultPrealloc<B, E>>().0
}
}
impl<T, E, I> CollectResultExt<T, E> for I where I: Iterator<Item = Result<T, E>> {}

/// A generic sequence schema for custom collections that implement
/// [`FromIterator`] (for reading) and whose references implement
/// [`IntoIterator`] with an [`ExactSizeIterator`] (for writing).
///
/// Works for both element sequences and key-value maps:
/// - For element collections (sets, ordered sets, etc.) whose reference
/// iterators yield `&T`, the schema for `T` is used directly.
/// - For map-like collections whose reference iterators yield `(&K, &V)` pairs,
/// the pair itself acts as the schema (automatically satisfied when `K` and
/// `V` implement `SchemaWrite<C>`).
///
/// Intended for external collection types that cannot have a dedicated
/// schema impl added directly. Unlike [`Vec`], [`VecDeque`], and [`BinaryHeap`], this
/// container relies on the collection's [`FromIterator`] impl rather than
/// writing directly into preallocated memory.
///
/// # Allocation efficiency
///
/// During deserialization, the iterator passed to [`FromIterator`] has a
/// precise [`size_hint`](Iterator::size_hint) matching the number of elements
/// produced, unless a read error is encountered. Collections whose
/// [`FromIterator`] implementation uses the size hint to preallocate capacity
/// will allocate optimally. Collections that do not use it will not benefit.
///
/// # Examples
///
/// ```ignore
/// use some_crate::{IndexSet, MyMap};
/// use wincode::{SchemaRead, SchemaWrite, containers::FromIntoIterator, len::BincodeLen};
///
/// #[derive(SchemaRead, SchemaWrite)]
/// struct MyData {
/// #[wincode(with = "FromIntoIterator<IndexSet<u32>, BincodeLen>")]
/// items: IndexSet<u32>,
/// #[wincode(with = "FromIntoIterator<MyMap<u32, u64>, BincodeLen>")]
/// map: MyMap<u32, u64>,
/// }
/// ```
pub struct FromIntoIterator<Coll, Len>(PhantomData<(Coll, Len)>);

unsafe impl<Coll, Len, C: ConfigCore> SchemaWrite<C> for FromIntoIterator<Coll, Len>
where
Len: SeqLen<C>,
Coll: IntoIterator,
for<'a> &'a Coll: IntoIterator<Item: SchemaWrite<C>, IntoIter: ExactSizeIterator>,
for<'a> <&'a Coll as IntoIterator>::Item:
Borrow<<<&'a Coll as IntoIterator>::Item as SchemaWrite<C>>::Src>,
{
type Src = Coll;

#[inline]
fn size_of(src: &Coll) -> WriteResult<usize> {
size_of_elem_iter::<<&Coll as IntoIterator>::Item, Len, C>(src.into_iter())
}

#[inline]
fn write(writer: impl Writer, src: &Coll) -> WriteResult<()> {
let iter = src.into_iter();
Len::prealloc_check::<Coll::Item>(iter.len())?;
write_elem_iter::<<&Coll as IntoIterator>::Item, Len, C>(writer, iter)
}
}

unsafe impl<'de, Coll, Len, C: ConfigCore> SchemaRead<'de, C> for FromIntoIterator<Coll, Len>
where
Len: SeqLen<C>,
Coll: IntoIterator<Item: SchemaRead<'de, C>>,
Coll: FromIterator<<Coll::Item as SchemaRead<'de, C>>::Dst>,
{
type Dst = Coll;

#[inline]
fn read(mut reader: impl Reader<'de>, dst: &mut MaybeUninit<Coll>) -> ReadResult<()> {
let len =
Len::read_prealloc_check::<<Coll::Item as SchemaRead<'de, C>>::Dst>(reader.by_ref())?;

let coll = if let TypeMeta::Static { size, .. } = Coll::Item::TYPE_META {
#[allow(clippy::arithmetic_side_effects)]
// SAFETY: `Item::TYPE_META` specifies a static size, so `len` reads of `Item::Dst`
// will consume `size * len` bytes, fully consuming the trusted window.
let mut reader = unsafe { reader.as_trusted_for(size * len) }?;
(0..len)
.map(|_| Coll::Item::get(reader.by_ref()))
.collect_result_prealloc()?
} else {
(0..len)
.map(|_| Coll::Item::get(reader.by_ref()))
.collect_result_prealloc()?
};
dst.write(coll);
Ok(())
}
}

/// Decode `slice.len()` items of `T` into contiguous, uninitialized memory.
///
/// Errors if fewer than `slice.len()` items are available in the [`Reader`]
Expand Down
33 changes: 22 additions & 11 deletions wincode/src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ use {
io::*,
len::SeqLen,
},
core::mem::MaybeUninit,
core::{borrow::Borrow, mem::MaybeUninit},
};

pub mod containers;
Expand Down Expand Up @@ -422,21 +422,21 @@ impl<T, C: ConfigCore> SchemaReadOwned<C> for T where T: for<'de> SchemaRead<'de

#[inline(always)]
#[allow(clippy::arithmetic_side_effects)]
fn size_of_elem_iter<'a, T, Len, C>(
value: impl ExactSizeIterator<Item = &'a T::Src>,
fn size_of_elem_iter<T, Len, C>(
value: impl ExactSizeIterator<Item: Borrow<T::Src>>,
) -> WriteResult<usize>
where
C: ConfigCore,
Len: SeqLen<C>,
T: SchemaWrite<C> + 'a,
T: SchemaWrite<C>,
{
if let TypeMeta::Static { size, .. } = T::TYPE_META {
return Ok(Len::write_bytes_needed(value.len())? + size * value.len());
}
// Extremely unlikely a type-in-memory's size will overflow usize::MAX.
Ok(Len::write_bytes_needed(value.len())?
+ (value
.map(T::size_of)
.map(|x| T::size_of(x.borrow()))
.try_fold(0usize, |acc, x| x.map(|x| acc + x))?))
}

Expand All @@ -454,14 +454,14 @@ where
}

#[inline(always)]
fn write_elem_iter<'a, T, Len, C>(
fn write_elem_iter<T, Len, C>(
mut writer: impl Writer,
src: impl ExactSizeIterator<Item = &'a T::Src>,
src: impl ExactSizeIterator<Item: Borrow<T::Src>>,
) -> WriteResult<()>
where
C: ConfigCore,
Len: SeqLen<C>,
T: SchemaWrite<C> + 'a,
T: SchemaWrite<C>,
{
if let TypeMeta::Static { size, .. } = T::TYPE_META {
#[allow(clippy::arithmetic_side_effects)]
Expand All @@ -472,15 +472,15 @@ where
let mut writer = unsafe { writer.as_trusted_for(needed) }?;
Len::write(writer.by_ref(), src.len())?;
for item in src {
T::write(writer.by_ref(), item)?;
T::write(writer.by_ref(), item.borrow())?;
}
writer.finish()?;
return Ok(());
}

Len::write(writer.by_ref(), src.len())?;
for item in src {
T::write(writer.by_ref(), item)?;
T::write(writer.by_ref(), item.borrow())?;
}
Ok(())
}
Expand Down Expand Up @@ -2678,13 +2678,24 @@ mod tests {
prop_assert_eq!(&test_data, &wincode_deserialized);
prop_assert_eq!(wincode_deserialized, bincode_deserialized);

type TestMapSeq = containers::FromIntoIterator<TestMap, BincodeLen>;
let test_seq_serialized = TestMapSeq::serialize(&test_data).unwrap();
assert_eq!(test_seq_serialized, wincode_serialized);
let test_seq_deserialized = TestMapSeq::deserialize(&test_seq_serialized).unwrap();
prop_assert_eq!(&test_data, &test_seq_deserialized);

type RegularMap = HashMap<String, HashSet<u32>>;
let regular_deserialized: RegularMap = deserialize(&wincode_serialized).unwrap();
let regular_serialized = serialize(&regular_deserialized).unwrap();
let test_deserialized: TestMap = deserialize(&regular_serialized).unwrap();
prop_assert_eq!(test_data, test_deserialized);
}

type RegularMapSeq = containers::FromIntoIterator<RegularMap, BincodeLen>;
let regular_seq_serialized = RegularMapSeq::serialize(&regular_deserialized).unwrap();
assert_eq!(regular_serialized, regular_seq_serialized);
let regular_seq_deserialized = RegularMapSeq::deserialize(&regular_seq_serialized).unwrap();
prop_assert_eq!(&regular_deserialized, &regular_seq_deserialized);
}

#[test]
fn test_btree_map_zero_copy(map in proptest::collection::btree_map(any::<u8>(), any::<StructZeroCopy>(), 0..=100)) {
Expand Down