Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 67 additions & 16 deletions crates/ty_python_semantic/src/types/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ impl<'db> UnionBuilder<'db> {
}
}
for ty in replace_with {
self.add_in_place_impl(ty, seen_aliases);
self.add_in_place_impl(ty, seen_aliases, None);
}
}

Expand All @@ -312,10 +312,15 @@ impl<'db> UnionBuilder<'db> {

/// Adds a type to this union.
pub(crate) fn add_in_place(&mut self, ty: Type<'db>) {
self.add_in_place_impl(ty, &mut vec![]);
self.add_in_place_impl(ty, &mut vec![], None);
}

pub(crate) fn add_in_place_impl(&mut self, ty: Type<'db>, seen_aliases: &mut Vec<Type<'db>>) {
pub(crate) fn add_in_place_impl(
&mut self,
ty: Type<'db>,
seen_aliases: &mut Vec<Type<'db>>,
skip_checks_after: Option<usize>,
) {
let cycle_recovery = self.cycle_recovery;
let should_widen = |literals, recursively_defined: RecursivelyDefined| {
if recursively_defined.is_yes() && cycle_recovery {
Expand All @@ -329,8 +334,22 @@ impl<'db> UnionBuilder<'db> {
Type::Union(union) => {
let new_elements = union.elements(self.db);
self.elements.reserve(new_elements.len());
// Capture the current element count to avoid comparing union elements against each other.
// The union has already been simplified, so its elements don't need redundancy checks
// between themselves, only against pre-existing elements. However, we only apply this
// optimization when not in cycle recovery AND both the builder and union have
// recursively_defined=No (indicating neither is involved in recursive type definitions
// where simplification might be incomplete).
let batch_start = if !self.cycle_recovery
&& self.recursively_defined == RecursivelyDefined::No
&& union.recursively_defined(self.db) == RecursivelyDefined::No
{
Some(self.elements.len())
} else {
None
};
for element in new_elements {
self.add_in_place_impl(*element, seen_aliases);
self.add_in_place_impl(*element, seen_aliases, batch_start);
}
self.recursively_defined = self
.recursively_defined
Expand All @@ -355,7 +374,9 @@ impl<'db> UnionBuilder<'db> {
// leave out the recursive alias. TODO surface this error.
} else {
seen_aliases.push(ty);
self.add_in_place_impl(alias.value_type(self.db), seen_aliases);
// Don't pass through skip_checks_after when unpacking aliases, since the alias
// value is not part of the original union batch.
self.add_in_place_impl(alias.value_type(self.db), seen_aliases, None);
}
}
// If adding a string literal, look for an existing `UnionElement::StringLiterals` to
Expand All @@ -371,13 +392,16 @@ impl<'db> UnionBuilder<'db> {
UnionElement::StringLiterals(literals) => {
if should_widen(literals.len(), self.recursively_defined) {
let replace_with = KnownClass::Str.to_instance(self.db);
self.add_in_place_impl(replace_with, seen_aliases);
self.add_in_place_impl(replace_with, seen_aliases, None);
return;
}
found = Some(literals);
continue;
}
UnionElement::Type(existing) => {
// Skip redundancy checks against elements from the same union batch.
if skip_checks_after.is_some_and(|batch_start| index >= batch_start) {
continue;
}
// e.g. `existing` could be `Literal[""] & Any`,
// and `ty` could be `Literal[""]`
if ty.is_subtype_of(self.db, *existing) {
Expand Down Expand Up @@ -420,13 +444,16 @@ impl<'db> UnionBuilder<'db> {
UnionElement::BytesLiterals(literals) => {
if should_widen(literals.len(), self.recursively_defined) {
let replace_with = KnownClass::Bytes.to_instance(self.db);
self.add_in_place_impl(replace_with, seen_aliases);
self.add_in_place_impl(replace_with, seen_aliases, None);
return;
}
found = Some(literals);
continue;
}
UnionElement::Type(existing) => {
// Skip redundancy checks against elements from the same union batch.
if skip_checks_after.is_some_and(|batch_start| index >= batch_start) {
continue;
}
if ty.is_subtype_of(self.db, *existing) {
return;
}
Expand Down Expand Up @@ -468,13 +495,16 @@ impl<'db> UnionBuilder<'db> {
UnionElement::IntLiterals(literals) => {
if should_widen(literals.len(), self.recursively_defined) {
let replace_with = KnownClass::Int.to_instance(self.db);
self.add_in_place_impl(replace_with, seen_aliases);
self.add_in_place_impl(replace_with, seen_aliases, None);
return;
}
found = Some(literals);
continue;
}
UnionElement::Type(existing) => {
// Skip redundancy checks against elements from the same union batch.
if skip_checks_after.is_some_and(|batch_start| index >= batch_start) {
continue;
}
if ty.is_subtype_of(self.db, *existing) {
return;
}
Expand Down Expand Up @@ -524,30 +554,42 @@ impl<'db> UnionBuilder<'db> {
.all(|name| enum_members_in_union.contains(name));

if all_members_are_in_union {
// Don't pass through skip_checks_after when converting to enum class instance,
// since this is a transformation and not part of the original union batch.
self.add_in_place_impl(
enum_member_to_add.enum_class_instance(self.db),
seen_aliases,
None,
);
} else if !self
.elements
.iter()
.filter_map(UnionElement::to_type_element)
.any(|ty| Type::EnumLiteral(enum_member_to_add).is_subtype_of(self.db, ty))
{
self.push_type(Type::EnumLiteral(enum_member_to_add), seen_aliases);
self.push_type(
Type::EnumLiteral(enum_member_to_add),
seen_aliases,
skip_checks_after,
);
}
}
// Adding `object` to a union results in `object`.
ty if ty.is_object() => {
self.collapse_to_object();
}
_ => {
self.push_type(ty, seen_aliases);
self.push_type(ty, seen_aliases, skip_checks_after);
}
}
}

fn push_type(&mut self, ty: Type<'db>, seen_aliases: &mut Vec<Type<'db>>) {
fn push_type(
&mut self,
ty: Type<'db>,
seen_aliases: &mut Vec<Type<'db>>,
skip_checks_after: Option<usize>,
) {
let bool_pair = if let Type::BooleanLiteral(b) = ty {
Some(Type::BooleanLiteral(!b))
} else {
Expand Down Expand Up @@ -585,10 +627,16 @@ impl<'db> UnionBuilder<'db> {
}

if Some(element_type) == bool_pair {
self.add_in_place_impl(KnownClass::Bool.to_instance(self.db), seen_aliases);
self.add_in_place_impl(KnownClass::Bool.to_instance(self.db), seen_aliases, None);
return;
}

// Skip redundancy checks when comparing elements from the same union batch.
// The union has already been simplified, so its elements don't need to be checked
// against each other.
let skip_redundancy_checks =
skip_checks_after.is_some_and(|batch_start| i >= batch_start);

// Comparing `TypedDict`s for redundancy requires iterating over their fields, which is
// problematic if some of those fields point to recursive `Union`s. To avoid cycles,
// compare `TypedDict`s by name/identity instead of using the `has_relation_to`
Expand All @@ -597,7 +645,10 @@ impl<'db> UnionBuilder<'db> {
continue;
}

if should_simplify_full && !matches!(element_type, Type::TypeAlias(_)) {
if should_simplify_full
&& !skip_redundancy_checks
&& !matches!(element_type, Type::TypeAlias(_))
{
if ty.is_redundant_with(self.db, element_type) {
return;
}
Expand Down
Loading