Skip to content

Commit 3a27298

Browse files
committed
Merge branch 'dh/input-assocs' into dh/json-cow
2 parents 2a3f32c + 230309d commit 3a27298

File tree

7 files changed

+259
-73
lines changed

7 files changed

+259
-73
lines changed

src/input/input_abstract.rs

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::fmt;
22

33
use pyo3::exceptions::PyValueError;
4-
use pyo3::types::{PyDict, PyType};
4+
use pyo3::types::{PyDict, PyList, PyType};
55
use pyo3::{intern, prelude::*};
66

77
use crate::errors::{ErrorTypeDefaults, InputValue, ValError, ValResult};
@@ -42,6 +42,8 @@ impl TryFrom<&str> for InputType {
4242
}
4343
}
4444

45+
pub type ValMatch<T> = ValResult<ValidationMatch<T>>;
46+
4547
/// all types have three methods: `validate_*`, `strict_*`, `lax_*`
4648
/// the convention is to either implement:
4749
/// * `strict_*` & `lax_*` if they have different behavior
@@ -87,13 +89,13 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
8789

8890
fn validate_dataclass_args<'a>(&'a self, dataclass_name: &str) -> ValResult<GenericArguments<'a, 'py>>;
8991

90-
fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValResult<ValidationMatch<EitherString<'_>>>;
92+
fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValMatch<EitherString<'_>>;
9193

92-
fn validate_bytes<'a>(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>>;
94+
fn validate_bytes<'a>(&'a self, strict: bool) -> ValMatch<EitherBytes<'a, 'py>>;
9395

94-
fn validate_bool(&self, strict: bool) -> ValResult<ValidationMatch<bool>>;
96+
fn validate_bool(&self, strict: bool) -> ValMatch<bool>;
9597

96-
fn validate_int(&self, strict: bool) -> ValResult<ValidationMatch<EitherInt<'_>>>;
98+
fn validate_int(&self, strict: bool) -> ValMatch<EitherInt<'_>>;
9799

98100
fn exact_int(&self) -> ValResult<EitherInt<'_>> {
99101
self.validate_int(true).and_then(|val_match| {
@@ -113,7 +115,7 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
113115
})
114116
}
115117

116-
fn validate_float(&self, strict: bool) -> ValResult<ValidationMatch<EitherFloat<'_>>>;
118+
fn validate_float(&self, strict: bool) -> ValMatch<EitherFloat<'_>>;
117119

118120
fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
119121
if strict {
@@ -145,18 +147,11 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
145147
self.validate_dict(strict)
146148
}
147149

148-
fn validate_list<'a>(&'a self, strict: bool) -> ValResult<GenericIterable<'a, 'py>> {
149-
if strict {
150-
self.strict_list()
151-
} else {
152-
self.lax_list()
153-
}
154-
}
155-
fn strict_list<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>>;
156-
#[cfg_attr(has_coverage_attribute, coverage(off))]
157-
fn lax_list<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>> {
158-
self.strict_list()
159-
}
150+
type List<'a>: Iterable<'py> + AsPyList<'py>
151+
where
152+
Self: 'a;
153+
154+
fn validate_list(&self, strict: bool) -> ValMatch<Self::List<'_>>;
160155

161156
fn validate_tuple<'a>(&'a self, strict: bool) -> ValResult<GenericIterable<'a, 'py>> {
162157
if strict {
@@ -201,25 +196,25 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
201196

202197
fn validate_iter(&self) -> ValResult<GenericIterator>;
203198

204-
fn validate_date(&self, strict: bool) -> ValResult<ValidationMatch<EitherDate<'py>>>;
199+
fn validate_date(&self, strict: bool) -> ValMatch<EitherDate<'py>>;
205200

206201
fn validate_time(
207202
&self,
208203
strict: bool,
209204
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
210-
) -> ValResult<ValidationMatch<EitherTime<'py>>>;
205+
) -> ValMatch<EitherTime<'py>>;
211206

212207
fn validate_datetime(
213208
&self,
214209
strict: bool,
215210
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
216-
) -> ValResult<ValidationMatch<EitherDateTime<'py>>>;
211+
) -> ValMatch<EitherDateTime<'py>>;
217212

218213
fn validate_timedelta(
219214
&self,
220215
strict: bool,
221216
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
222-
) -> ValResult<ValidationMatch<EitherTimedelta<'py>>>;
217+
) -> ValMatch<EitherTimedelta<'py>>;
223218
}
224219

225220
/// The problem to solve here is that iterating collections often returns owned
@@ -238,3 +233,42 @@ impl<'py, T: Input<'py> + ?Sized> BorrowInput<'py> for &'_ T {
238233
self
239234
}
240235
}
236+
237+
pub enum Never {}
238+
239+
// Pairs with `Iterable` below: `Iterable::iterate` hands its concrete iterator to a `ConsumeIterator`
240+
pub trait ConsumeIterator<T> {
241+
type Output;
242+
fn consume_iterator(self, iterator: impl Iterator<Item = T>) -> Self::Output;
243+
}
244+
245+
// This slightly awkward trait is used to define types which can be iterable. This formulation
246+
// arises because the Python enums have several different underlying iterator types, and we want to
247+
// be able to dispatch over each of them without overhead.
248+
pub trait Iterable<'py> {
249+
type Input: BorrowInput<'py>;
250+
fn len(&self) -> Option<usize>;
251+
fn iterate<R>(self, consumer: impl ConsumeIterator<PyResult<Self::Input>, Output = R>) -> ValResult<R>;
252+
}
253+
254+
// Necessary for inputs which can never validate as certain types, e.g. String -> list; `Never` has no values
255+
impl<'py> Iterable<'py> for Never {
256+
type Input = Bound<'py, PyAny>; // Doesn't really matter what this is
257+
fn len(&self) -> Option<usize> {
258+
unreachable!()
259+
}
260+
fn iterate<R>(self, _consumer: impl ConsumeIterator<PyResult<Self::Input>, Output = R>) -> ValResult<R> {
261+
unreachable!()
262+
}
263+
}
264+
265+
// Optimization pathway for inputs which are already Python lists
266+
pub trait AsPyList<'py>: Iterable<'py> {
267+
fn as_py_list(&self) -> Option<&Bound<'py, PyList>>;
268+
}
269+
270+
impl<'py> AsPyList<'py> for Never {
271+
fn as_py_list(&self) -> Option<&Bound<'py, PyList>> {
272+
unreachable!()
273+
}
274+
}

src/input/input_json.rs

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ use std::borrow::Cow;
22

33
use jiter::{JsonArray, JsonValue};
44
use pyo3::prelude::*;
5-
use pyo3::types::{PyDict, PyString};
5+
use pyo3::types::{PyDict, PyList, PyString};
6+
use smallvec::SmallVec;
67
use speedate::MicrosecondsPrecisionOverflowBehavior;
78
use strum::EnumMessage;
89

@@ -13,6 +14,7 @@ use super::datetime::{
1314
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
1415
float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime,
1516
};
17+
use super::input_abstract::{AsPyList, ConsumeIterator, Iterable, Never, ValMatch};
1618
use super::return_enums::ValidationMatch;
1719
use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_float, str_as_int};
1820
use super::{
@@ -37,7 +39,7 @@ impl From<JsonValue<'_>> for LocItem {
3739
}
3840
}
3941

40-
impl<'py> Input<'py> for JsonValue<'_> {
42+
impl<'py, 'data> Input<'py> for JsonValue<'data> {
4143
fn as_error_value(&self) -> InputValue {
4244
// cloning JsonValue is cheap due to use of Arc
4345
InputValue::Json(self.clone().into_static())
@@ -172,16 +174,14 @@ impl<'py> Input<'py> for JsonValue<'_> {
172174
self.validate_dict(false)
173175
}
174176

175-
fn validate_list<'a>(&'a self, _strict: bool) -> ValResult<GenericIterable<'a, 'py>> {
177+
type List<'a> = &'a JsonArray<'data> where Self: 'a;
178+
179+
fn validate_list(&self, _strict: bool) -> ValMatch<&JsonArray<'data>> {
176180
match self {
177-
JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)),
181+
JsonValue::Array(a) => Ok(ValidationMatch::strict(a)),
178182
_ => Err(ValError::new(ErrorTypeDefaults::ListType, self)),
179183
}
180184
}
181-
#[cfg_attr(has_coverage_attribute, coverage(off))]
182-
fn strict_list<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>> {
183-
self.validate_list(false)
184-
}
185185

186186
fn validate_tuple<'a>(&'a self, _strict: bool) -> ValResult<GenericIterable<'a, 'py>> {
187187
// just as in set's case, List has to be allowed
@@ -375,8 +375,9 @@ impl<'py> Input<'py> for str {
375375
Err(ValError::new(ErrorTypeDefaults::DictType, self))
376376
}
377377

378-
#[cfg_attr(has_coverage_attribute, coverage(off))]
379-
fn strict_list<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>> {
378+
type List<'a> = Never;
379+
380+
fn validate_list(&self, _strict: bool) -> ValMatch<Never> {
380381
Err(ValError::new(ErrorTypeDefaults::ListType, self))
381382
}
382383

@@ -449,3 +450,20 @@ impl BorrowInput<'_> for String {
449450
fn string_to_vec(s: &str) -> JsonArray {
450451
JsonArray::new(s.chars().map(|c| JsonValue::Str(c.to_string().into())).collect())
451452
}
453+
454+
impl<'a, 'data> Iterable<'_> for &'a JsonArray<'data> {
455+
type Input = &'a JsonValue<'data>;
456+
457+
fn len(&self) -> Option<usize> {
458+
Some(SmallVec::len(self))
459+
}
460+
fn iterate<R>(self, consumer: impl ConsumeIterator<PyResult<Self::Input>, Output = R>) -> ValResult<R> {
461+
Ok(consumer.consume_iterator(self.iter().map(Ok)))
462+
}
463+
}
464+
465+
impl<'py> AsPyList<'py> for &'_ JsonArray<'_> {
466+
fn as_py_list(&self) -> Option<&Bound<'py, PyList>> {
467+
None
468+
}
469+
}

src/input/input_python.rs

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use super::datetime::{
2323
float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime,
2424
EitherTime,
2525
};
26+
use super::input_abstract::ValMatch;
2627
use super::return_enums::ValidationMatch;
2728
use super::shared::{
2829
decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float, str_as_int,
@@ -461,24 +462,25 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
461462
}
462463
}
463464

464-
fn strict_list<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>> {
465-
match self.lax_list()? {
466-
GenericIterable::List(iter) => Ok(GenericIterable::List(iter)),
467-
_ => Err(ValError::new(ErrorTypeDefaults::ListType, self)),
468-
}
469-
}
465+
type List<'a> = GenericIterable<'a, 'py> where Self: 'a;
470466

471-
fn lax_list<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>> {
472-
match self
473-
.extract_generic_iterable()
474-
.map_err(|_| ValError::new(ErrorTypeDefaults::ListType, self))?
475-
{
476-
GenericIterable::PyString(_)
477-
| GenericIterable::Bytes(_)
478-
| GenericIterable::Dict(_)
479-
| GenericIterable::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::ListType, self)),
480-
other => Ok(other),
467+
fn validate_list<'a>(&'a self, strict: bool) -> ValMatch<GenericIterable<'a, 'py>> {
468+
if let Ok(list) = self.downcast::<PyList>() {
469+
return Ok(ValidationMatch::exact(GenericIterable::List(list)));
470+
} else if !strict {
471+
match self.extract_generic_iterable() {
472+
Ok(
473+
GenericIterable::PyString(_)
474+
| GenericIterable::Bytes(_)
475+
| GenericIterable::Dict(_)
476+
| GenericIterable::Mapping(_),
477+
)
478+
| Err(_) => {}
479+
Ok(other) => return Ok(ValidationMatch::lax(other)),
480+
}
481481
}
482+
483+
Err(ValError::new(ErrorTypeDefaults::ListType, self))
482484
}
483485

484486
fn strict_tuple<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>> {

src/input/input_string.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use crate::validators::decimal::create_decimal;
1111
use super::datetime::{
1212
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime,
1313
};
14+
use super::input_abstract::{Never, ValMatch};
1415
use super::shared::{str_as_bool, str_as_float, str_as_int};
1516
use super::{
1617
BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable,
@@ -138,7 +139,9 @@ impl<'py> Input<'py> for StringMapping<'py> {
138139
}
139140
}
140141

141-
fn strict_list<'a>(&'a self) -> ValResult<GenericIterable<'a, 'py>> {
142+
type List<'a> = Never where Self: 'a;
143+
144+
fn validate_list(&self, _strict: bool) -> ValMatch<Never> {
142145
Err(ValError::new(ErrorTypeDefaults::ListType, self))
143146
}
144147

src/input/mod.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@ pub(crate) use datetime::{
1515
duration_as_pytimedelta, pydate_as_date, pydatetime_as_datetime, pytime_as_time, EitherDate, EitherDateTime,
1616
EitherTime, EitherTimedelta,
1717
};
18-
pub(crate) use input_abstract::{BorrowInput, Input, InputType};
18+
pub(crate) use input_abstract::{AsPyList, BorrowInput, ConsumeIterator, Input, InputType, Iterable};
1919
pub(crate) use input_string::StringMapping;
2020
pub(crate) use return_enums::{
21-
py_string_str, AttributesGenericIterator, DictGenericIterator, EitherBytes, EitherFloat, EitherInt, EitherString,
22-
GenericArguments, GenericIterable, GenericIterator, GenericMapping, Int, JsonArgs, JsonObjectGenericIterator,
23-
MappingGenericIterator, PyArgs, StringMappingGenericIterator, ValidationMatch,
21+
no_validator_iter_to_vec, py_string_str, validate_iter_to_vec, AttributesGenericIterator, DictGenericIterator,
22+
EitherBytes, EitherFloat, EitherInt, EitherString, GenericArguments, GenericIterable, GenericIterator,
23+
GenericMapping, Int, JsonArgs, JsonObjectGenericIterator, MappingGenericIterator, MaxLengthCheck, PyArgs,
24+
StringMappingGenericIterator, ValidationMatch,
2425
};
2526

2627
// Defined here as it's not exported by pyo3

0 commit comments

Comments
 (0)