Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: max_size: various small refactoring #223

Merged
merged 3 commits into from
Sep 18, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 61 additions & 40 deletions borsh/src/schema/container_ext/max_size.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
use super::{BorshSchemaContainer, Declaration, Definition, Fields};
use crate::__private::maybestd::{string::ToString, vec::Vec};

use core::num::NonZeroUsize;

/// NonZeroUsize of value one.
// TODO: Replace usage by NonZeroUsize::MIN once MSRV is 1.70+.
const ONE: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(1) };

impl BorshSchemaContainer {
/// Returns the largest possible size of a serialised object based solely on its type.
///
Expand Down Expand Up @@ -37,9 +43,9 @@ impl BorshSchemaContainer {
/// assert_eq!(Err(borsh::schema::SchemaMaxSerializedSizeError::Overflow),
/// schema.max_serialized_size());
/// ```
pub fn max_serialized_size(&self) -> core::result::Result<usize, SchemaMaxSerializedSizeError> {
pub fn max_serialized_size(&self) -> Result<usize, SchemaMaxSerializedSizeError> {
let mut stack = Vec::new();
max_serialized_size_impl(1, self.declaration(), self, &mut stack)
max_serialized_size_impl(ONE, self.declaration(), self, &mut stack)
}
}

Expand All @@ -64,37 +70,38 @@ pub enum SchemaMaxSerializedSizeError {

/// Implementation of [`BorshSchema::max_serialized_size`].
fn max_serialized_size_impl<'a>(
count: usize,
count: NonZeroUsize,
declaration: &'a str,
schema: &'a BorshSchemaContainer,
stack: &mut Vec<&'a str>,
) -> core::result::Result<usize, SchemaMaxSerializedSizeError> {
) -> Result<usize, SchemaMaxSerializedSizeError> {
use core::convert::TryFrom;

/// Maximum number of elements in a vector or length of a string which can
/// be serialised.
const MAX_LEN: usize = u32::MAX as usize;
const MAX_LEN: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(u32::MAX as usize) };

fn add(x: usize, y: usize) -> core::result::Result<usize, SchemaMaxSerializedSizeError> {
fn add(x: usize, y: usize) -> Result<usize, SchemaMaxSerializedSizeError> {
x.checked_add(y)
.ok_or(SchemaMaxSerializedSizeError::Overflow)
}

fn mul(x: usize, y: usize) -> core::result::Result<usize, SchemaMaxSerializedSizeError> {
x.checked_mul(y)
fn mul(x: NonZeroUsize, y: usize) -> Result<usize, SchemaMaxSerializedSizeError> {
x.get()
.checked_mul(y)
.ok_or(SchemaMaxSerializedSizeError::Overflow)
}

/// Calculates max serialised size of a tuple with given members.
fn tuple<'a>(
count: usize,
count: NonZeroUsize,
elements: impl core::iter::IntoIterator<Item = &'a Declaration>,
schema: &'a BorshSchemaContainer,
stack: &mut Vec<&'a str>,
) -> ::core::result::Result<usize, SchemaMaxSerializedSizeError> {
) -> Result<usize, SchemaMaxSerializedSizeError> {
let mut sum: usize = 0;
for el in elements {
sum = add(sum, max_serialized_size_impl(1, el, schema, stack)?)?;
sum = add(sum, max_serialized_size_impl(ONE, el, schema, stack)?)?;
}
mul(count, sum)
}
Expand All @@ -108,18 +115,25 @@ fn max_serialized_size_impl<'a>(
Ok(Definition::Array { length, elements }) => {
// Aggregate `count` and `length` to a single number. If this
// overflows, check if array’s element is zero-sized.
let count = usize::try_from(*length)
let count_lengths = usize::try_from(*length)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using the same count name everywhere was confusing to me

.ok()
.and_then(|len| len.checked_mul(count));
match count {
Some(0) => Ok(0),
Some(count) => max_serialized_size_impl(count, elements, schema, stack),
None if is_zero_size(elements, schema)
.map_err(|_err| SchemaMaxSerializedSizeError::Recursive)? =>
{
Ok(0)
.and_then(|len| len.checked_mul(count.get()));
let count_lengths = match count_lengths {
Some(count_lengths) => count_lengths,
None if is_zero_size_impl(elements.as_str(), schema, stack)? => {
return Ok(0);
}
None => {
return Err(SchemaMaxSerializedSizeError::Overflow);
}
};
let count_lengths = NonZeroUsize::new(count_lengths);

match count_lengths {
None => Ok(0),
Some(count_lengths) => {
max_serialized_size_impl(count_lengths, elements, schema, stack)
}
None => Err(SchemaMaxSerializedSizeError::Overflow),
}
}
Ok(Definition::Sequence { elements }) => {
Expand All @@ -135,7 +149,7 @@ fn max_serialized_size_impl<'a>(
}) => {
let mut max = 0;
for (_, variant) in variants {
let sz = max_serialized_size_impl(1, variant, schema, stack)?;
let sz = max_serialized_size_impl(ONE, variant, schema, stack)?;
max = max.max(sz);
}
max.checked_add(usize::from(*tag_width))
Expand All @@ -154,14 +168,14 @@ fn max_serialized_size_impl<'a>(

// Primitive types.
Err("nil") => Ok(0),
Err("bool" | "i8" | "u8" | "nonzero_i8" | "nonzero_u8") => Ok(count),
Err("bool" | "i8" | "u8" | "nonzero_i8" | "nonzero_u8") => Ok(count.get()),
Err("i16" | "u16" | "nonzero_i16" | "nonzero_u16") => mul(count, 2),
Err("i32" | "u32" | "f32" | "nonzero_i32" | "nonzero_u32") => mul(count, 4),
Err("i64" | "u64" | "f64" | "nonzero_i64" | "nonzero_u64") => mul(count, 8),
Err("i128" | "u128" | "nonzero_i128" | "nonzero_u128") => mul(count, 16),

// string is just Vec<u8>
Err("string") => mul(count, add(MAX_LEN, 4)?),
Err("string") => mul(count, add(MAX_LEN.get(), 4)?),

Err(declaration) => Err(SchemaMaxSerializedSizeError::MissingDefinition(
declaration.to_string(),
Expand All @@ -187,37 +201,44 @@ fn max_serialized_size_impl<'a>(
pub(super) fn is_zero_size(
declaration: &Declaration,
schema: &BorshSchemaContainer,
) -> Result<bool, ()> {
) -> Result<bool, Recursive> {
let mut stack = Vec::new();
is_zero_size_impl(declaration, schema, &mut stack)
}
const RECURSIVE: () = ();

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(super) struct Recursive;

impl From<Recursive> for SchemaMaxSerializedSizeError {
fn from(_: Recursive) -> Self {
Self::Recursive
}
}

fn is_zero_size_impl<'a>(
declaration: &'a str,
schema: &'a BorshSchemaContainer,
stack: &mut Vec<Declaration>,
) -> Result<bool, ()> {
fn all<T>(
stack: &mut Vec<&'a str>,
) -> Result<bool, Recursive> {
fn all<'a, T: 'a>(
iter: impl Iterator<Item = T>,
f_key: impl Fn(&T) -> &Declaration,
schema: &BorshSchemaContainer,
stack: &mut Vec<Declaration>,
) -> Result<bool, ()> {
let mut all = true;
f_key: impl Fn(&T) -> &'a Declaration,
schema: &'a BorshSchemaContainer,
stack: &mut Vec<&'a str>,
) -> Result<bool, Recursive> {
for element in iter {
let declaration = f_key(&element);
if !is_zero_size_impl(declaration.as_str(), schema, stack)? {
all = false;
return Ok(false);
}
}
Ok(all)
Ok(true)
}

if stack.iter().any(|dec| *dec == declaration) {
return Err(RECURSIVE);
return Err(Recursive);
}
stack.push(declaration.to_string());
stack.push(declaration);

let res = match schema.get_definition(declaration).ok_or(declaration) {
Ok(Definition::Array { length, elements }) => {
Expand Down Expand Up @@ -308,7 +329,7 @@ mod tests {
struct RecursiveNoExitStructUnnamed(Box<RecursiveNoExitStructUnnamed>);

let schema = BorshSchemaContainer::for_type::<RecursiveNoExitStructUnnamed>();
assert_eq!(Err(()), is_zero_size(schema.declaration(), &schema));
assert_eq!(Err(Recursive), is_zero_size(schema.declaration(), &schema));
}

#[test]
Expand All @@ -325,7 +346,7 @@ mod tests {

test_ok::<Option<()>>(1);
test_ok::<Option<u8>>(2);
test_ok::<core::result::Result<u8, usize>>(9);
test_ok::<Result<u8, usize>>(9);

test_ok::<()>(0);
test_ok::<(u8,)>(1);
Expand Down
Loading