diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index b7cc49497a177d..12374faa87e2f6 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -300,7 +300,7 @@ impl<'db> UnionBuilder<'db> { } } for ty in replace_with { - self.add_in_place_impl(ty, seen_aliases); + self.add_in_place_impl(ty, seen_aliases, None); } } @@ -312,10 +312,15 @@ impl<'db> UnionBuilder<'db> { /// Adds a type to this union. pub(crate) fn add_in_place(&mut self, ty: Type<'db>) { - self.add_in_place_impl(ty, &mut vec![]); + self.add_in_place_impl(ty, &mut vec![], None); } - pub(crate) fn add_in_place_impl(&mut self, ty: Type<'db>, seen_aliases: &mut Vec>) { + pub(crate) fn add_in_place_impl( + &mut self, + ty: Type<'db>, + seen_aliases: &mut Vec>, + skip_checks_after: Option, + ) { let cycle_recovery = self.cycle_recovery; let should_widen = |literals, recursively_defined: RecursivelyDefined| { if recursively_defined.is_yes() && cycle_recovery { @@ -329,8 +334,22 @@ impl<'db> UnionBuilder<'db> { Type::Union(union) => { let new_elements = union.elements(self.db); self.elements.reserve(new_elements.len()); + // Capture the current element count to avoid comparing union elements against each other. + // The union has already been simplified, so its elements don't need redundancy checks + // between themselves, only against pre-existing elements. However, we only apply this + // optimization when not in cycle recovery AND both the builder and union have + // recursively_defined=No (indicating neither is involved in recursive type definitions + // where simplification might be incomplete). + let batch_start = if !self.cycle_recovery + && self.recursively_defined == RecursivelyDefined::No + && union.recursively_defined(self.db) == RecursivelyDefined::No + { + Some(self.elements.len()) + } else { + None + }; for element in new_elements { - self.add_in_place_impl(*element, seen_aliases); + self.add_in_place_impl(*element, seen_aliases, batch_start); } self.recursively_defined = self .recursively_defined @@ -355,7 +374,9 @@ impl<'db> UnionBuilder<'db> { // leave out the recursive alias. TODO surface this error. } else { seen_aliases.push(ty); - self.add_in_place_impl(alias.value_type(self.db), seen_aliases); + // Don't pass through skip_checks_after when unpacking aliases, since the alias + // value is not part of the original union batch. + self.add_in_place_impl(alias.value_type(self.db), seen_aliases, None); } } // If adding a string literal, look for an existing `UnionElement::StringLiterals` to @@ -371,13 +392,16 @@ impl<'db> UnionBuilder<'db> { UnionElement::StringLiterals(literals) => { if should_widen(literals.len(), self.recursively_defined) { let replace_with = KnownClass::Str.to_instance(self.db); - self.add_in_place_impl(replace_with, seen_aliases); + self.add_in_place_impl(replace_with, seen_aliases, None); return; } found = Some(literals); - continue; } UnionElement::Type(existing) => { + // Skip redundancy checks against elements from the same union batch. + if skip_checks_after.is_some_and(|batch_start| index >= batch_start) { + continue; + } // e.g. `existing` could be `Literal[""] & Any`, // and `ty` could be `Literal[""]` if ty.is_subtype_of(self.db, *existing) { @@ -420,13 +444,16 @@ impl<'db> UnionBuilder<'db> { UnionElement::BytesLiterals(literals) => { if should_widen(literals.len(), self.recursively_defined) { let replace_with = KnownClass::Bytes.to_instance(self.db); - self.add_in_place_impl(replace_with, seen_aliases); + self.add_in_place_impl(replace_with, seen_aliases, None); return; } found = Some(literals); - continue; } UnionElement::Type(existing) => { + // Skip redundancy checks against elements from the same union batch. + if skip_checks_after.is_some_and(|batch_start| index >= batch_start) { + continue; + } if ty.is_subtype_of(self.db, *existing) { return; } @@ -468,13 +495,16 @@ impl<'db> UnionBuilder<'db> { UnionElement::IntLiterals(literals) => { if should_widen(literals.len(), self.recursively_defined) { let replace_with = KnownClass::Int.to_instance(self.db); - self.add_in_place_impl(replace_with, seen_aliases); + self.add_in_place_impl(replace_with, seen_aliases, None); return; } found = Some(literals); - continue; } UnionElement::Type(existing) => { + // Skip redundancy checks against elements from the same union batch. + if skip_checks_after.is_some_and(|batch_start| index >= batch_start) { + continue; + } if ty.is_subtype_of(self.db, *existing) { return; } @@ -524,9 +554,12 @@ impl<'db> UnionBuilder<'db> { .all(|name| enum_members_in_union.contains(name)); if all_members_are_in_union { + // Don't pass through skip_checks_after when converting to enum class instance, + // since this is a transformation and not part of the original union batch. self.add_in_place_impl( enum_member_to_add.enum_class_instance(self.db), seen_aliases, + None, ); } else if !self .elements @@ -534,7 +567,11 @@ impl<'db> UnionBuilder<'db> { .filter_map(UnionElement::to_type_element) .any(|ty| Type::EnumLiteral(enum_member_to_add).is_subtype_of(self.db, ty)) { - self.push_type(Type::EnumLiteral(enum_member_to_add), seen_aliases); + self.push_type( + Type::EnumLiteral(enum_member_to_add), + seen_aliases, + skip_checks_after, + ); } } // Adding `object` to a union results in `object`. @@ -542,12 +579,17 @@ impl<'db> UnionBuilder<'db> { self.collapse_to_object(); } _ => { - self.push_type(ty, seen_aliases); + self.push_type(ty, seen_aliases, skip_checks_after); } } } - fn push_type(&mut self, ty: Type<'db>, seen_aliases: &mut Vec>) { + fn push_type( + &mut self, + ty: Type<'db>, + seen_aliases: &mut Vec>, + skip_checks_after: Option, + ) { let bool_pair = if let Type::BooleanLiteral(b) = ty { Some(Type::BooleanLiteral(!b)) } else { @@ -585,10 +627,16 @@ impl<'db> UnionBuilder<'db> { } if Some(element_type) == bool_pair { - self.add_in_place_impl(KnownClass::Bool.to_instance(self.db), seen_aliases); + self.add_in_place_impl(KnownClass::Bool.to_instance(self.db), seen_aliases, None); return; } + // Skip redundancy checks when comparing elements from the same union batch. + // The union has already been simplified, so its elements don't need to be checked + // against each other. + let skip_redundancy_checks = + skip_checks_after.is_some_and(|batch_start| i >= batch_start); + // Comparing `TypedDict`s for redundancy requires iterating over their fields, which is // problematic if some of those fields point to recursive `Union`s. To avoid cycles, // compare `TypedDict`s by name/identity instead of using the `has_relation_to` @@ -597,7 +645,10 @@ impl<'db> UnionBuilder<'db> { continue; } - if should_simplify_full && !matches!(element_type, Type::TypeAlias(_)) { + if should_simplify_full + && !skip_redundancy_checks + && !matches!(element_type, Type::TypeAlias(_)) + { if ty.is_redundant_with(self.db, element_type) { return; }