Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
9117c6b
first pass at union match improvement
sydney-runkle Jun 16, 2024
fcc84ec
don't need to add explicit numbers
sydney-runkle Jun 16, 2024
0255211
figuring out fields set
sydney-runkle Jun 16, 2024
ceb5e1b
behavior for checking model fields set
sydney-runkle Jun 17, 2024
1a6b00c
working on my rust best practices
sydney-runkle Jun 17, 2024
1718e8b
using method on Validator trait to streamline things
sydney-runkle Jun 17, 2024
500a812
comment update
sydney-runkle Jun 17, 2024
6645ac2
new round of tests
sydney-runkle Jun 17, 2024
d479928
get all tests passing
sydney-runkle Jun 17, 2024
e16103a
use Union not pipe in tests
sydney-runkle Jun 17, 2024
9208282
abandon num_fields for a state based approach
sydney-runkle Jun 18, 2024
69f6ec0
get typed dicts working, add tests for other model like cases
sydney-runkle Jun 18, 2024
18b70d4
dataclass support + more efficient exact return
sydney-runkle Jun 18, 2024
e4f1e6b
all dataclass tests passing :)
sydney-runkle Jun 18, 2024
00ddca2
bubble up fields set in nested models
sydney-runkle Jun 18, 2024
fc21636
add nested counting for dataclasses, typed dicts
sydney-runkle Jun 18, 2024
72aa7bd
corresponding tests
sydney-runkle Jun 19, 2024
d2ef400
updating fields set at the end
sydney-runkle Jun 19, 2024
1585e96
comments and best practice with state updates
sydney-runkle Jun 19, 2024
d3f88b7
consistency w var names
sydney-runkle Jun 19, 2024
38f911f
abbreviated syntax
sydney-runkle Jun 19, 2024
c3b43e9
adding a bubble up test + doing some minor test refactoring
sydney-runkle Jun 19, 2024
31a439a
3.8 fixes
sydney-runkle Jun 19, 2024
ed920ad
oops, another list
sydney-runkle Jun 19, 2024
645f917
ugh, last 3.8 fix hopefully
sydney-runkle Jun 19, 2024
ed8ddb9
name change success -> best_match
sydney-runkle Jun 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ impl Validator for DataclassArgsValidator {
let mut used_keys: AHashSet<&str> = AHashSet::with_capacity(self.fields.len());

let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone()));
let mut fields_set_count: usize = 0;

macro_rules! set_item {
($field:ident, $value:expr) => {{
Expand All @@ -175,6 +176,7 @@ impl Validator for DataclassArgsValidator {
Ok(Some(value)) => {
// Default value exists, and passed validation if required
set_item!(field, value);
fields_set_count += 1;
}
Ok(None) | Err(ValError::Omit) => continue,
// Note: this will always use the field name even if there is an alias
Expand Down Expand Up @@ -214,15 +216,21 @@ impl Validator for DataclassArgsValidator {
}
// found a positional argument, validate it
(Some(pos_value), None) => match field.validator.validate(py, pos_value.borrow_input(), state) {
Ok(value) => set_item!(field, value),
Ok(value) => {
set_item!(field, value);
fields_set_count += 1;
}
Err(ValError::LineErrors(line_errors)) => {
errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index)));
}
Err(err) => return Err(err),
},
// found a keyword argument, validate it
(None, Some((lookup_path, kw_value))) => match field.validator.validate(py, kw_value, state) {
Ok(value) => set_item!(field, value),
Ok(value) => {
set_item!(field, value);
fields_set_count += 1;
}
Err(ValError::LineErrors(line_errors)) => {
errors.extend(
line_errors
Expand Down Expand Up @@ -336,6 +344,8 @@ impl Validator for DataclassArgsValidator {
}
}

state.add_fields_set(fields_set_count);

if errors.is_empty() {
if let Some(init_only_args) = init_only_args {
Ok((output_dict, PyTuple::new_bound(py, init_only_args)).to_object(py))
Expand Down
12 changes: 9 additions & 3 deletions src/validators/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ impl Validator for ModelValidator {
for field_name in validated_fields_set {
fields_set.add(field_name)?;
}
state.add_fields_set(fields_set.len());
}

force_setattr(py, model, intern!(py, DUNDER_DICT), validated_dict.to_object(py))?;
Expand Down Expand Up @@ -241,10 +242,13 @@ impl ModelValidator {
} else {
PySet::new_bound(py, [&String::from(ROOT_FIELD)])?
};
force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), fields_set)?;
force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?;
force_setattr(py, self_instance, intern!(py, ROOT_FIELD), &output)?;
state.add_fields_set(fields_set.len());
} else {
let (model_dict, model_extra, fields_set) = output.extract(py)?;
let (model_dict, model_extra, fields_set): (Bound<PyAny>, Bound<PyAny>, Bound<PyAny>) =
output.extract(py)?;
state.add_fields_set(fields_set.len().unwrap_or(0));
set_model_attrs(self_instance, &model_dict, &model_extra, &fields_set)?;
}
self.call_post_init(py, self_instance.clone(), input, state.extra())
Expand Down Expand Up @@ -281,11 +285,13 @@ impl ModelValidator {
} else {
PySet::new_bound(py, [&String::from(ROOT_FIELD)])?
};
force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), fields_set)?;
force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?;
force_setattr(py, &instance, intern!(py, ROOT_FIELD), output)?;
state.add_fields_set(fields_set.len());
} else {
let (model_dict, model_extra, val_fields_set) = output.extract(py)?;
let fields_set = existing_fields_set.unwrap_or(&val_fields_set);
state.add_fields_set(fields_set.len().unwrap_or(0));
set_model_attrs(&instance, &model_dict, &model_extra, fields_set)?;
}
self.call_post_init(py, instance, input, state.extra())
Expand Down
4 changes: 4 additions & 0 deletions src/validators/typed_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ impl Validator for TypedDictValidator {

{
let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone()));
let mut fields_set_count: usize = 0;

for field in &self.fields {
let op_key_value = match dict.get_item(&field.lookup_key) {
Expand All @@ -186,6 +187,7 @@ impl Validator for TypedDictValidator {
match field.validator.validate(py, value.borrow_input(), state) {
Ok(value) => {
output_dict.set_item(&field.name_py, value)?;
fields_set_count += 1;
}
Err(ValError::Omit) => continue,
Err(ValError::LineErrors(line_errors)) => {
Expand Down Expand Up @@ -227,6 +229,8 @@ impl Validator for TypedDictValidator {
Err(err) => return Err(err),
}
}

state.add_fields_set(fields_set_count);
}

if let Some(used_keys) = used_keys {
Expand Down
54 changes: 38 additions & 16 deletions src/validators/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,12 @@ impl UnionValidator {
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
let old_exactness = state.exactness;
let old_fields_set_count = state.fields_set_count;

let strict = state.strict_or(self.strict);
let mut errors = MaybeErrors::new(self.custom_error.as_ref());

let mut success = None;
let mut best_match: Option<(Py<PyAny>, Exactness, Option<usize>)> = None;

for (choice, label) in &self.choices {
let state = &mut state.rebind_extra(|extra| {
Expand All @@ -120,47 +122,67 @@ impl UnionValidator {
}
});
state.exactness = Some(Exactness::Exact);
state.fields_set_count = None;
let result = choice.validate(py, input, state);
match result {
Ok(new_success) => match state.exactness {
// exact match, return
Some(Exactness::Exact) => {
Ok(new_success) => match (state.exactness, state.fields_set_count) {
(Some(Exactness::Exact), None) => {
// exact match with no fields set data, return immediately
return {
// exact match, return, restore any previous exactness
state.exactness = old_exactness;
state.fields_set_count = old_fields_set_count;
Ok(new_success)
};
}
_ => {
// success should always have an exactness
debug_assert_ne!(state.exactness, None);

let new_exactness = state.exactness.unwrap_or(Exactness::Lax);
// if the new result has higher exactness than the current success, replace it
if success
.as_ref()
.map_or(true, |(_, current_exactness)| *current_exactness < new_exactness)
{
// TODO: is there a possible optimization here, where once there has
// been one success, we turn on strict mode, to avoid unnecessary
// coercions for further validation?
success = Some((new_success, new_exactness));
let new_fields_set_count = state.fields_set_count;

// we use both the exactness and the fields_set_count to determine the best union member match
// if fields_set_count is available for the current best match and the new candidate, we use this
// as the primary metric. If the new fields_set_count is greater, the new candidate is better.
// if the fields_set_count is the same, we use the exactness as a tie breaker to determine the best match.
// if the fields_set_count is not available for either the current best match or the new candidate,
// we use the exactness to determine the best match.
let new_success_is_best_match: bool =
best_match
.as_ref()
.map_or(true, |(_, cur_exactness, cur_fields_set_count)| {
match (*cur_fields_set_count, new_fields_set_count) {
(Some(cur), Some(new)) if cur != new => cur < new,
_ => *cur_exactness < new_exactness,
}
});

if new_success_is_best_match {
best_match = Some((new_success, new_exactness, new_fields_set_count));
}
}
},
Err(ValError::LineErrors(lines)) => {
// if we don't yet know this validation will succeed, record the error
if success.is_none() {
if best_match.is_none() {
errors.push(choice, label.as_deref(), lines);
}
}
otherwise => return otherwise,
}
}

// restore previous validation state to prepare for any future validations
state.exactness = old_exactness;
state.fields_set_count = old_fields_set_count;

if let Some((success, exactness)) = success {
if let Some((best_match, exactness, fields_set_count)) = best_match {
state.floor_exactness(exactness);
return Ok(success);
if let Some(count) = fields_set_count {
state.add_fields_set(count);
}
return Ok(best_match);
}

// no matches, build errors
Expand Down
6 changes: 6 additions & 0 deletions src/validators/validation_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub enum Exactness {
pub struct ValidationState<'a, 'py> {
pub recursion_guard: &'a mut RecursionState,
pub exactness: Option<Exactness>,
pub fields_set_count: Option<usize>,
// deliberately make Extra readonly
extra: Extra<'a, 'py>,
}
Expand All @@ -27,6 +28,7 @@ impl<'a, 'py> ValidationState<'a, 'py> {
Self {
recursion_guard, // Don't care about exactness unless doing union validation
exactness: None,
fields_set_count: None,
extra,
}
}
Expand Down Expand Up @@ -68,6 +70,10 @@ impl<'a, 'py> ValidationState<'a, 'py> {
}
}

/// Adds `fields_set_count` to the running tally held in `self.fields_set_count`.
///
/// A tally that is currently `None` is treated as zero, so the field is
/// always `Some(_)` once this method returns.
pub fn add_fields_set(&mut self, fields_set_count: usize) {
    let previous_total = self.fields_set_count.unwrap_or(0);
    self.fields_set_count = Some(previous_total + fields_set_count);
}

pub fn cache_str(&self) -> StringCacheMode {
self.extra.cache_str
}
Expand Down
Loading