Skip to content

Commit

Permalink
Introduce exactness into Decimal validation logic (#1405)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle authored Aug 15, 2024
1 parent fdd1e85 commit 08a99b5
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 51 deletions.
13 changes: 1 addition & 12 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,7 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {

fn validate_float(&self, strict: bool) -> ValMatch<EitherFloat<'_>>;

fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
if strict {
self.strict_decimal(py)
} else {
self.lax_decimal(py)
}
}
fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
self.strict_decimal(py)
}
fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>>;

type Dict<'a>: ValidatedDict<'py>
where
Expand Down
13 changes: 7 additions & 6 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,13 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
}
}

fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
fn validate_decimal(&self, _strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
match self {
JsonValue::Float(f) => create_decimal(&PyString::new_bound(py, &f.to_string()), self),

JsonValue::Float(f) => {
create_decimal(&PyString::new_bound(py, &f.to_string()), self).map(ValidationMatch::strict)
}
JsonValue::Str(..) | JsonValue::Int(..) | JsonValue::BigInt(..) => {
create_decimal(self.to_object(py).bind(py), self)
create_decimal(self.to_object(py).bind(py), self).map(ValidationMatch::strict)
}
_ => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
}
Expand Down Expand Up @@ -399,8 +400,8 @@ impl<'py> Input<'py> for str {
str_as_float(self, self).map(ValidationMatch::lax)
}

fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
create_decimal(self.to_object(py).bind(py), self)
fn validate_decimal(&self, _strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
create_decimal(self.to_object(py).bind(py), self).map(ValidationMatch::lax)
}

type Dict<'a> = Never;
Expand Down
54 changes: 25 additions & 29 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
str_as_int(self, s)
} else if self.is_exact_instance_of::<PyFloat>() {
float_as_int(self, self.extract::<f64>()?)
} else if let Ok(decimal) = self.strict_decimal(self.py()) {
decimal_as_int(self, &decimal)
} else if let Ok(decimal) = self.validate_decimal(true, self.py()) {
decimal_as_int(self, &decimal.into_inner())
} else if let Ok(float) = self.extract::<f64>() {
float_as_int(self, float)
} else if let Some(enum_val) = maybe_as_enum(self) {
Expand Down Expand Up @@ -310,48 +310,44 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
Err(ValError::new(ErrorTypeDefaults::FloatType, self))
}

fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
let decimal_type = get_decimal_type(py);

// Fast path for existing decimal objects
if self.is_exact_instance(decimal_type) {
return Ok(self.to_owned());
return Ok(ValidationMatch::exact(self.to_owned().clone()));
}

if !strict {
if self.is_instance_of::<PyString>() || (self.is_instance_of::<PyInt>() && !self.is_instance_of::<PyBool>())
{
// Checking isinstance for str / int / bool is fast compared to decimal / float
return create_decimal(self, self).map(ValidationMatch::lax);
}

if self.is_instance_of::<PyFloat>() {
return create_decimal(self.str()?.as_any(), self).map(ValidationMatch::lax);
}
}

// Try subclasses of decimals, they will be upcast to Decimal
if self.is_instance(decimal_type)? {
return create_decimal(self, self);
// Upcast subclasses to decimal
return create_decimal(self, self).map(ValidationMatch::strict);
}

Err(ValError::new(
let error_type = if strict {
ErrorType::IsInstanceOf {
class: decimal_type
.qualname()
.and_then(|name| name.extract())
.unwrap_or_else(|_| "Decimal".to_owned()),
context: None,
},
self,
))
}

fn lax_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
let decimal_type = get_decimal_type(py);
// Fast path for existing decimal objects
if self.is_exact_instance(decimal_type) {
return Ok(self.to_owned().clone());
}

if self.is_instance_of::<PyString>() || (self.is_instance_of::<PyInt>() && !self.is_instance_of::<PyBool>()) {
// checking isinstance for str / int / bool is fast compared to decimal / float
create_decimal(self, self)
} else if self.is_instance(decimal_type)? {
// upcast subclasses to decimal
return create_decimal(self, self);
} else if self.is_instance_of::<PyFloat>() {
create_decimal(self.str()?.as_any(), self)
}
} else {
Err(ValError::new(ErrorTypeDefaults::DecimalType, self))
}
ErrorTypeDefaults::DecimalType
};

Err(ValError::new(error_type, self))
}

type Dict<'a> = GenericPyMapping<'a, 'py> where Self: 'a;
Expand Down
4 changes: 2 additions & 2 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,9 @@ impl<'py> Input<'py> for StringMapping<'py> {
}
}

fn strict_decimal(&self, _py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
fn validate_decimal(&self, _strict: bool, _py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
match self {
Self::String(s) => create_decimal(s, self),
Self::String(s) => create_decimal(s, self).map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/validators/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl Validator for DecimalValidator {
input: &(impl Input<'py> + ?Sized),
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
let decimal = input.validate_decimal(state.strict_or(self.strict), py)?;
let decimal = input.validate_decimal(state.strict_or(self.strict), py)?.unpack(state);

if !self.allow_inf_nan || self.check_digits {
if !decimal.call_method0(intern!(py, "is_finite"))?.extract()? {
Expand Down
22 changes: 21 additions & 1 deletion tests/validators/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from dirty_equals import FunctionCheck, IsStr

from pydantic_core import SchemaValidator, ValidationError
from pydantic_core import SchemaValidator, ValidationError, core_schema

from ..conftest import Err, PyAndJson, plain_repr

Expand Down Expand Up @@ -467,3 +467,23 @@ def test_validate_max_digits_and_decimal_places_edge_case() -> None:
assert v.validate_python(Decimal('9999999999999999.999999999999999999')) == Decimal(
'9999999999999999.999999999999999999'
)


def test_str_validation_w_strict() -> None:
s = SchemaValidator(core_schema.decimal_schema(strict=True))

with pytest.raises(ValidationError):
assert s.validate_python('1.23')


def test_str_validation_w_lax() -> None:
s = SchemaValidator(core_schema.decimal_schema(strict=False))

assert s.validate_python('1.23') == Decimal('1.23')


def test_union_with_str_prefers_str() -> None:
s = SchemaValidator(core_schema.union_schema([core_schema.decimal_schema(), core_schema.str_schema()]))

assert s.validate_python('1.23') == '1.23'
assert s.validate_python(1.23) == Decimal('1.23')

0 comments on commit 08a99b5

Please sign in to comment.