Skip to content

Commit

Permalink
implement Input for str (pydantic#1229)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt authored Mar 18, 2024
1 parent b92a69b commit 352d40f
Show file tree
Hide file tree
Showing 53 changed files with 461 additions and 443 deletions.
33 changes: 20 additions & 13 deletions src/errors/line_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,27 @@ use pyo3::DowncastIntoError;

use jiter::JsonValue;

use crate::input::BorrowInput;
use crate::input::Input;

use super::location::{LocItem, Location};
use super::types::ErrorType;

pub type ValResult<T> = Result<T, ValError>;

pub trait AsErrorValue {
fn as_error_value(&self) -> InputValue;
pub trait ToErrorValue {
fn to_error_value(&self) -> InputValue;
}

impl<'a, T: Input<'a>> AsErrorValue for T {
fn as_error_value(&self) -> InputValue {
Input::as_error_value(self)
impl<'a, T: BorrowInput<'a>> ToErrorValue for T {
fn to_error_value(&self) -> InputValue {
Input::as_error_value(self.borrow_input())
}
}

impl ToErrorValue for &'_ dyn ToErrorValue {
fn to_error_value(&self) -> InputValue {
(**self).to_error_value()
}
}

Expand Down Expand Up @@ -55,11 +62,11 @@ impl From<Vec<ValLineError>> for ValError {
}

impl ValError {
pub fn new(error_type: ErrorType, input: &impl AsErrorValue) -> ValError {
pub fn new(error_type: ErrorType, input: impl ToErrorValue) -> ValError {
Self::LineErrors(vec![ValLineError::new(error_type, input)])
}

pub fn new_with_loc(error_type: ErrorType, input: &impl AsErrorValue, loc: impl Into<LocItem>) -> ValError {
pub fn new_with_loc(error_type: ErrorType, input: impl ToErrorValue, loc: impl Into<LocItem>) -> ValError {
Self::LineErrors(vec![ValLineError::new_with_loc(error_type, input, loc)])
}

Expand Down Expand Up @@ -94,26 +101,26 @@ pub struct ValLineError {
}

impl ValLineError {
pub fn new(error_type: ErrorType, input: &impl AsErrorValue) -> ValLineError {
pub fn new(error_type: ErrorType, input: impl ToErrorValue) -> ValLineError {
Self {
error_type,
input_value: input.as_error_value(),
input_value: input.to_error_value(),
location: Location::default(),
}
}

pub fn new_with_loc(error_type: ErrorType, input: &impl AsErrorValue, loc: impl Into<LocItem>) -> ValLineError {
pub fn new_with_loc(error_type: ErrorType, input: impl ToErrorValue, loc: impl Into<LocItem>) -> ValLineError {
Self {
error_type,
input_value: input.as_error_value(),
input_value: input.to_error_value(),
location: Location::new_some(loc.into()),
}
}

pub fn new_with_full_loc(error_type: ErrorType, input: &impl AsErrorValue, location: Location) -> ValLineError {
pub fn new_with_full_loc(error_type: ErrorType, input: impl ToErrorValue, location: Location) -> ValLineError {
Self {
error_type,
input_value: input.as_error_value(),
input_value: input.to_error_value(),
location,
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/errors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod types;
mod validation_exception;
mod value_exception;

pub use self::line_error::{AsErrorValue, InputValue, ValError, ValLineError, ValResult};
pub use self::line_error::{InputValue, ToErrorValue, ValError, ValLineError, ValResult};
pub use self::location::LocItem;
pub use self::types::{list_all_errors, ErrorType, ErrorTypeDefaults, Number};
pub use self::validation_exception::ValidationError;
Expand Down
6 changes: 3 additions & 3 deletions src/errors/value_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use pyo3::types::{PyDict, PyString};
use crate::input::InputType;
use crate::tools::extract_i64;

use super::line_error::AsErrorValue;
use super::line_error::ToErrorValue;
use super::{ErrorType, ValError};

#[pyclass(extends=PyException, module="pydantic_core._pydantic_core")]
Expand Down Expand Up @@ -106,7 +106,7 @@ impl PydanticCustomError {
}

impl PydanticCustomError {
pub fn into_val_error(self, input: &impl AsErrorValue) -> ValError {
pub fn into_val_error(self, input: impl ToErrorValue) -> ValError {
let error_type = ErrorType::CustomError {
error_type: self.error_type,
message_template: self.message_template,
Expand Down Expand Up @@ -181,7 +181,7 @@ impl PydanticKnownError {
}

impl PydanticKnownError {
pub fn into_val_error(self, input: &impl AsErrorValue) -> ValError {
pub fn into_val_error(self, input: impl ToErrorValue) -> ValError {
ValError::new(self.error_type, input)
}
}
47 changes: 24 additions & 23 deletions src/input/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use std::hash::Hasher;
use strum::EnumMessage;

use super::Input;
use crate::errors::ToErrorValue;
use crate::errors::{ErrorType, ValError, ValResult};
use crate::tools::py_err;

Expand Down Expand Up @@ -285,7 +286,7 @@ impl<'a> EitherDateTime<'a> {
}
}

pub fn bytes_as_date<'a>(input: &'a impl Input<'a>, bytes: &[u8]) -> ValResult<EitherDate<'a>> {
pub fn bytes_as_date<'py>(input: &(impl Input<'py> + ?Sized), bytes: &[u8]) -> ValResult<EitherDate<'py>> {
match Date::parse_bytes(bytes) {
Ok(date) => Ok(date.into()),
Err(err) => Err(ValError::new(
Expand All @@ -298,11 +299,11 @@ pub fn bytes_as_date<'a>(input: &'a impl Input<'a>, bytes: &[u8]) -> ValResult<E
}
}

pub fn bytes_as_time<'a>(
input: &'a impl Input<'a>,
pub fn bytes_as_time<'py>(
input: &(impl Input<'py> + ?Sized),
bytes: &[u8],
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTime<'a>> {
) -> ValResult<EitherTime<'py>> {
match Time::parse_bytes_with_config(
bytes,
&TimeConfig {
Expand All @@ -321,11 +322,11 @@ pub fn bytes_as_time<'a>(
}
}

pub fn bytes_as_datetime<'a, 'b>(
input: &'a impl Input<'a>,
bytes: &'b [u8],
pub fn bytes_as_datetime<'py>(
input: &(impl Input<'py> + ?Sized),
bytes: &[u8],
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherDateTime<'a>> {
) -> ValResult<EitherDateTime<'py>> {
match DateTime::parse_bytes_with_config(
bytes,
&TimeConfig {
Expand All @@ -344,11 +345,11 @@ pub fn bytes_as_datetime<'a, 'b>(
}
}

pub fn int_as_datetime<'a>(
input: &'a impl Input<'a>,
pub fn int_as_datetime<'py>(
input: &(impl Input<'py> + ?Sized),
timestamp: i64,
timestamp_microseconds: u32,
) -> ValResult<EitherDateTime> {
) -> ValResult<EitherDateTime<'py>> {
match DateTime::from_timestamp_with_config(
timestamp,
timestamp_microseconds,
Expand Down Expand Up @@ -382,7 +383,7 @@ macro_rules! nan_check {
};
}

pub fn float_as_datetime<'a>(input: &'a impl Input<'a>, timestamp: f64) -> ValResult<EitherDateTime> {
pub fn float_as_datetime<'py>(input: &(impl Input<'py> + ?Sized), timestamp: f64) -> ValResult<EitherDateTime<'py>> {
nan_check!(input, timestamp, DatetimeParsing);
let microseconds = timestamp.fract().abs() * 1_000_000.0;
// checking for extra digits in microseconds is unreliable with large floats,
Expand All @@ -408,11 +409,11 @@ pub fn date_as_datetime<'py>(date: &Bound<'py, PyDate>) -> PyResult<EitherDateTi

const MAX_U32: i64 = u32::MAX as i64;

pub fn int_as_time<'a>(
input: &'a impl Input<'a>,
pub fn int_as_time<'py>(
input: &(impl Input<'py> + ?Sized),
timestamp: i64,
timestamp_microseconds: u32,
) -> ValResult<EitherTime> {
) -> ValResult<EitherTime<'py>> {
let time_timestamp: u32 = match timestamp {
t if t < 0_i64 => {
return Err(ValError::new(
Expand Down Expand Up @@ -447,14 +448,14 @@ pub fn int_as_time<'a>(
}
}

pub fn float_as_time<'a>(input: &'a impl Input<'a>, timestamp: f64) -> ValResult<EitherTime> {
pub fn float_as_time<'py>(input: &(impl Input<'py> + ?Sized), timestamp: f64) -> ValResult<EitherTime<'py>> {
nan_check!(input, timestamp, TimeParsing);
let microseconds = timestamp.fract().abs() * 1_000_000.0;
// round for same reason as above
int_as_time(input, timestamp.floor() as i64, microseconds.round() as u32)
}

fn map_timedelta_err<'a>(input: &'a impl Input<'a>, err: ParseError) -> ValError {
fn map_timedelta_err(input: impl ToErrorValue, err: ParseError) -> ValError {
ValError::new(
ErrorType::TimeDeltaParsing {
error: Cow::Borrowed(err.get_documentation().unwrap_or_default()),
Expand All @@ -464,11 +465,11 @@ fn map_timedelta_err<'a>(input: &'a impl Input<'a>, err: ParseError) -> ValError
)
}

pub fn bytes_as_timedelta<'a, 'b>(
input: &'a impl Input<'a>,
bytes: &'b [u8],
pub fn bytes_as_timedelta<'py>(
input: &(impl Input<'py> + ?Sized),
bytes: &[u8],
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTimedelta<'a>> {
) -> ValResult<EitherTimedelta<'py>> {
match Duration::parse_bytes_with_config(
bytes,
&TimeConfig {
Expand All @@ -481,7 +482,7 @@ pub fn bytes_as_timedelta<'a, 'b>(
}
}

pub fn int_as_duration<'a>(input: &'a impl Input<'a>, total_seconds: i64) -> ValResult<Duration> {
pub fn int_as_duration(input: impl ToErrorValue, total_seconds: i64) -> ValResult<Duration> {
let positive = total_seconds >= 0;
let total_seconds = total_seconds.unsigned_abs();
// we can safely unwrap here since we've guaranteed seconds and microseconds can't cause overflow
Expand All @@ -490,7 +491,7 @@ pub fn int_as_duration<'a>(input: &'a impl Input<'a>, total_seconds: i64) -> Val
Duration::new(positive, days, seconds, 0).map_err(|err| map_timedelta_err(input, err))
}

pub fn float_as_duration<'a>(input: &'a impl Input<'a>, total_seconds: f64) -> ValResult<Duration> {
pub fn float_as_duration(input: impl ToErrorValue, total_seconds: f64) -> ValResult<Duration> {
nan_check!(input, total_seconds, TimeDeltaParsing);
let positive = total_seconds >= 0_f64;
let total_seconds = total_seconds.abs();
Expand Down
25 changes: 16 additions & 9 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use pyo3::exceptions::PyValueError;
use pyo3::types::{PyDict, PyType};
use pyo3::{intern, prelude::*};

use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::errors::{ErrorTypeDefaults, InputValue, ValError, ValResult};
use crate::tools::py_err;
use crate::{PyMultiHostUrl, PyUrl};

Expand Down Expand Up @@ -46,7 +46,7 @@ impl TryFrom<&str> for InputType {
/// the convention is to either implement:
/// * `strict_*` & `lax_*` if they have different behavior
/// * or, `validate_*` and `strict_*` to just call `validate_*` if the behavior for strict and lax is the same
pub trait Input<'py>: fmt::Debug + ToPyObject + Into<LocItem> + Sized {
pub trait Input<'py>: fmt::Debug + ToPyObject {
fn as_error_value(&self) -> InputValue;

fn identity(&self) -> Option<usize> {
Expand Down Expand Up @@ -83,9 +83,9 @@ pub trait Input<'py>: fmt::Debug + ToPyObject + Into<LocItem> + Sized {
false
}

fn validate_args(&self) -> ValResult<GenericArguments<'_>>;
fn validate_args(&self) -> ValResult<GenericArguments<'_, 'py>>;

fn validate_dataclass_args<'a>(&'a self, dataclass_name: &str) -> ValResult<GenericArguments<'a>>;
fn validate_dataclass_args<'a>(&'a self, dataclass_name: &str) -> ValResult<GenericArguments<'a, 'py>>;

fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValResult<ValidationMatch<EitherString<'_>>>;

Expand Down Expand Up @@ -201,25 +201,25 @@ pub trait Input<'py>: fmt::Debug + ToPyObject + Into<LocItem> + Sized {

fn validate_iter(&self) -> ValResult<GenericIterator>;

fn validate_date(&self, strict: bool) -> ValResult<ValidationMatch<EitherDate>>;
fn validate_date(&self, strict: bool) -> ValResult<ValidationMatch<EitherDate<'py>>>;

fn validate_time(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<ValidationMatch<EitherTime>>;
) -> ValResult<ValidationMatch<EitherTime<'py>>>;

fn validate_datetime(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<ValidationMatch<EitherDateTime>>;
) -> ValResult<ValidationMatch<EitherDateTime<'py>>>;

fn validate_timedelta(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<ValidationMatch<EitherTimedelta>>;
) -> ValResult<ValidationMatch<EitherTimedelta<'py>>>;
}

/// The problem to solve here is that iterating collections often returns owned
Expand All @@ -228,6 +228,13 @@ pub trait Input<'py>: fmt::Debug + ToPyObject + Into<LocItem> + Sized {
/// or borrowed; all we care about is that we can borrow it again with `borrow_input`
/// for some lifetime 'a.
pub trait BorrowInput<'py> {
type Input: Input<'py>;
type Input: Input<'py> + ?Sized;
fn borrow_input(&self) -> &Self::Input;
}

impl<'py, T: Input<'py> + ?Sized> BorrowInput<'py> for &'_ T {
type Input = T;
fn borrow_input(&self) -> &Self::Input {
self
}
}
Loading

0 comments on commit 352d40f

Please sign in to comment.