diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index b0e058d9b..918de4054 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -98,18 +98,7 @@ pub trait Input<'py>: fmt::Debug + ToPyObject { fn validate_float(&self, strict: bool) -> ValMatch>; - fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValResult> { - if strict { - self.strict_decimal(py) - } else { - self.lax_decimal(py) - } - } - fn strict_decimal(&self, py: Python<'py>) -> ValResult>; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_decimal(&self, py: Python<'py>) -> ValResult> { - self.strict_decimal(py) - } + fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch>; type Dict<'a>: ValidatedDict<'py> where diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 3adc36ba6..897c2c1de 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -165,12 +165,13 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> { } } - fn strict_decimal(&self, py: Python<'py>) -> ValResult> { + fn validate_decimal(&self, _strict: bool, py: Python<'py>) -> ValMatch> { match self { - JsonValue::Float(f) => create_decimal(&PyString::new_bound(py, &f.to_string()), self), - + JsonValue::Float(f) => { + create_decimal(&PyString::new_bound(py, &f.to_string()), self).map(ValidationMatch::strict) + } JsonValue::Str(..) | JsonValue::Int(..) | JsonValue::BigInt(..) => { - create_decimal(self.to_object(py).bind(py), self) + create_decimal(self.to_object(py).bind(py), self).map(ValidationMatch::strict) } _ => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)), } @@ -373,8 +374,8 @@ impl<'py> Input<'py> for str { str_as_float(self, self).map(ValidationMatch::lax) } - fn strict_decimal(&self, py: Python<'py>) -> ValResult> { - create_decimal(self.to_object(py).bind(py), self) + fn validate_decimal(&self, _strict: bool, py: Python<'py>) -> ValMatch> { + create_decimal(self.to_object(py).bind(py), self).map(ValidationMatch::lax) } type Dict<'a> = Never; diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 7840a825a..84f2dccb7 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -248,8 +248,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { str_as_int(self, s) } else if self.is_exact_instance_of::() { float_as_int(self, self.extract::()?) - } else if let Ok(decimal) = self.strict_decimal(self.py()) { - decimal_as_int(self, &decimal) + } else if let Ok(decimal) = self.validate_decimal(true, self.py()) { + decimal_as_int(self, &decimal.into_inner()) } else if let Ok(float) = self.extract::() { float_as_int(self, float) } else if let Some(enum_val) = maybe_as_enum(self) { @@ -307,48 +307,44 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { Err(ValError::new(ErrorTypeDefaults::FloatType, self)) } - fn strict_decimal(&self, py: Python<'py>) -> ValResult> { + fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch> { let decimal_type = get_decimal_type(py); + // Fast path for existing decimal objects if self.is_exact_instance(decimal_type) { - return Ok(self.to_owned()); + return Ok(ValidationMatch::exact(self.to_owned().clone())); + } + + if !strict { + if self.is_instance_of::() || (self.is_instance_of::() && !self.is_instance_of::()) + { + // Checking isinstance for str / int / bool is fast compared to decimal / float + return create_decimal(self, self).map(ValidationMatch::lax); + } + + if self.is_instance_of::() { + return create_decimal(self.str()?.as_any(), self).map(ValidationMatch::lax); + } } - // Try subclasses of decimals, they will be upcast to Decimal if self.is_instance(decimal_type)? { - return create_decimal(self, self); + // Upcast subclasses to decimal + return create_decimal(self, self).map(ValidationMatch::strict); } - Err(ValError::new( + let error_type = if strict { ErrorType::IsInstanceOf { class: decimal_type .qualname() .and_then(|name| name.extract()) .unwrap_or_else(|_| "Decimal".to_owned()), context: None, - }, - self, - )) - } - - fn lax_decimal(&self, py: Python<'py>) -> ValResult> { - let decimal_type = get_decimal_type(py); - // Fast path for existing decimal objects - if self.is_exact_instance(decimal_type) { - return Ok(self.to_owned().clone()); - } - - if self.is_instance_of::() || (self.is_instance_of::() && !self.is_instance_of::()) { - // checking isinstance for str / int / bool is fast compared to decimal / float - create_decimal(self, self) - } else if self.is_instance(decimal_type)? { - // upcast subclasses to decimal - return create_decimal(self, self); - } else if self.is_instance_of::() { - create_decimal(self.str()?.as_any(), self) + } } else { - Err(ValError::new(ErrorTypeDefaults::DecimalType, self)) - } + ErrorTypeDefaults::DecimalType + }; + + Err(ValError::new(error_type, self)) } type Dict<'a> = GenericPyMapping<'a, 'py> where Self: 'a; diff --git a/src/input/input_string.rs b/src/input/input_string.rs index 3ef1b58ce..73ee6fd18 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -141,9 +141,9 @@ impl<'py> Input<'py> for StringMapping<'py> { } } - fn strict_decimal(&self, _py: Python<'py>) -> ValResult> { + fn validate_decimal(&self, _strict: bool, _py: Python<'py>) -> ValMatch> { match self { - Self::String(s) => create_decimal(s, self), + Self::String(s) => create_decimal(s, self).map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)), } } diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs index d1c15c299..8008d4260 100644 --- a/src/validators/decimal.rs +++ b/src/validators/decimal.rs @@ -122,7 +122,7 @@ impl Validator for DecimalValidator { input: &(impl Input<'py> + ?Sized), state: &mut ValidationState<'_, 'py>, ) -> ValResult { - let decimal = input.validate_decimal(state.strict_or(self.strict), py)?; + let decimal = input.validate_decimal(state.strict_or(self.strict), py)?.unpack(state); if !self.allow_inf_nan || self.check_digits { if !decimal.call_method0(intern!(py, "is_finite"))?.extract()? { diff --git a/tests/validators/test_decimal.py b/tests/validators/test_decimal.py index 931fe6377..fa1c0270d 100644 --- a/tests/validators/test_decimal.py +++ b/tests/validators/test_decimal.py @@ -9,7 +9,7 @@ import pytest from dirty_equals import FunctionCheck, IsStr -from pydantic_core import SchemaValidator, ValidationError +from pydantic_core import SchemaValidator, ValidationError, core_schema from ..conftest import Err, PyAndJson, plain_repr @@ -467,3 +467,23 @@ def test_validate_max_digits_and_decimal_places_edge_case() -> None: assert v.validate_python(Decimal('9999999999999999.999999999999999999')) == Decimal( '9999999999999999.999999999999999999' ) + + +def test_str_validation_w_strict() -> None: + s = SchemaValidator(core_schema.decimal_schema(strict=True)) + + with pytest.raises(ValidationError): + assert s.validate_python('1.23') + + +def test_str_validation_w_lax() -> None: + s = SchemaValidator(core_schema.decimal_schema(strict=False)) + + assert s.validate_python('1.23') == Decimal('1.23') + + +def test_union_with_str_prefers_str() -> None: + s = SchemaValidator(core_schema.union_schema([core_schema.decimal_schema(), core_schema.str_schema()])) + + assert s.validate_python('1.23') == '1.23' + assert s.validate_python(1.23) == Decimal('1.23')