Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,26 @@ static_assert(
)
```

## Unions containing tuples containing tuples containing unions (etc.)

```py
from knot_extensions import is_equivalent_to, static_assert, Intersection

class P: ...
class Q: ...

static_assert(
is_equivalent_to(
tuple[tuple[tuple[P | Q]]] | P,
tuple[tuple[tuple[Q | P]]] | P,
)
)
static_assert(
is_equivalent_to(
tuple[tuple[tuple[tuple[tuple[Intersection[P, Q]]]]]],
tuple[tuple[tuple[tuple[tuple[Intersection[Q, P]]]]]],
)
)
```

[the equivalence relation]: https://typing.readthedocs.io/en/latest/spec/glossary.html#term-equivalent
81 changes: 63 additions & 18 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,35 @@ impl<'db> Type<'db> {
}
}

/// Return a normalized version of `self` in which all unions and intersections are sorted
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
#[must_use]
pub fn with_sorted_unions(self, db: &'db dyn Db) -> Self {
match self {
Type::Union(union) => Type::Union(union.to_sorted_union(db)),
Type::Intersection(intersection) => {
Type::Intersection(intersection.to_sorted_intersection(db))
}
Type::Tuple(tuple) => Type::Tuple(tuple.with_sorted_unions(db)),
Type::LiteralString
| Type::Instance(_)
| Type::AlwaysFalsy
| Type::AlwaysTruthy
| Type::BooleanLiteral(_)
| Type::SliceLiteral(_)
| Type::BytesLiteral(_)
| Type::StringLiteral(_)
| Type::Dynamic(_)
| Type::Never
| Type::FunctionLiteral(_)
| Type::ModuleLiteral(_)
| Type::ClassLiteral(_)
| Type::KnownInstance(_)
| Type::IntLiteral(_)
| Type::SubclassOf(_) => self,
}
}

/// Return true if this type is a [subtype of] type `target`.
///
/// This method returns `false` if either `self` or `other` is not fully static.
Expand Down Expand Up @@ -1154,7 +1183,7 @@ impl<'db> Type<'db> {
left.is_equivalent_to(db, right)
}
(Type::Tuple(left), Type::Tuple(right)) => left.is_equivalent_to(db, right),
_ => self.is_fully_static(db) && other.is_fully_static(db) && self == other,
_ => self == other && self.is_fully_static(db) && other.is_fully_static(db),
Comment on lines -1157 to +1186
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have a teeny tiny speedup on Codspeed, which I suspect is just due to the change on this line

}
}

Expand Down Expand Up @@ -4352,12 +4381,11 @@ impl<'db> UnionType<'db> {
/// Create a new union type with the elements sorted according to a canonical ordering.
#[must_use]
pub fn to_sorted_union(self, db: &'db dyn Db) -> Self {
let mut new_elements = self.elements(db).to_vec();
for element in &mut new_elements {
if let Type::Intersection(intersection) = element {
intersection.sort(db);
}
}
let mut new_elements: Vec<Type<'db>> = self
.elements(db)
.iter()
.map(|element| element.with_sorted_unions(db))
.collect();
new_elements.sort_unstable_by(union_elements_ordering);
UnionType::new(db, new_elements.into_boxed_slice())
}
Expand Down Expand Up @@ -4453,19 +4481,24 @@ impl<'db> IntersectionType<'db> {
/// according to a canonical ordering.
#[must_use]
pub fn to_sorted_intersection(self, db: &'db dyn Db) -> Self {
let mut positive = self.positive(db).clone();
positive.sort_unstable_by(union_elements_ordering);

let mut negative = self.negative(db).clone();
negative.sort_unstable_by(union_elements_ordering);
fn normalized_set<'db>(
db: &'db dyn Db,
elements: &FxOrderSet<Type<'db>>,
) -> FxOrderSet<Type<'db>> {
let mut elements: FxOrderSet<Type<'db>> = elements
.iter()
.map(|ty| ty.with_sorted_unions(db))
.collect();

IntersectionType::new(db, positive, negative)
}
elements.sort_unstable_by(union_elements_ordering);
elements
}

/// Perform an in-place sort of this [`IntersectionType`] instance
/// according to a canonical ordering.
fn sort(&mut self, db: &'db dyn Db) {
*self = self.to_sorted_intersection(db);
IntersectionType::new(
db,
normalized_set(db, self.positive(db)),
normalized_set(db, self.negative(db)),
)
}

pub fn is_fully_static(self, db: &'db dyn Db) -> bool {
Expand Down Expand Up @@ -4608,6 +4641,18 @@ impl<'db> TupleType<'db> {
Type::Tuple(Self::new(db, elements.into_boxed_slice()))
}

/// Return a normalized version of `self` in which all unions and intersections are sorted
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
#[must_use]
pub fn with_sorted_unions(self, db: &'db dyn Db) -> Self {
let elements: Box<[Type<'db>]> = self
.elements(db)
.iter()
.map(|ty| ty.with_sorted_unions(db))
.collect();
TupleType::new(db, elements)
}

pub fn is_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
let self_elements = self.elements(db);
let other_elements = other.elements(db);
Expand Down
Loading