Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def f1(a: int = 1) -> None: ...
def f2(a: int = 2) -> None: ...

static_assert(is_equivalent_to(CallableTypeOf[f1], CallableTypeOf[f2]))
static_assert(is_equivalent_to(CallableTypeOf[f1] | bool | CallableTypeOf[f2], CallableTypeOf[f2] | bool | CallableTypeOf[f1]))
```

The names of the positional-only, variadic and keyword-variadic parameters does not need to be the
Expand All @@ -144,6 +145,7 @@ def f3(a1: int, /, *args1: int, **kwargs2: int) -> None: ...
def f4(a2: int, /, *args2: int, **kwargs1: int) -> None: ...

static_assert(is_equivalent_to(CallableTypeOf[f3], CallableTypeOf[f4]))
static_assert(is_equivalent_to(CallableTypeOf[f3] | bool | CallableTypeOf[f4], CallableTypeOf[f4] | bool | CallableTypeOf[f3]))
```

Putting it all together, the following two callables are equivalent:
Expand All @@ -153,6 +155,7 @@ def f5(a1: int, /, b: float, c: bool = False, *args1: int, d: int = 1, e: str, *
def f6(a2: int, /, b: float, c: bool = True, *args2: int, d: int = 2, e: str, **kwargs2: float) -> None: ...

static_assert(is_equivalent_to(CallableTypeOf[f5], CallableTypeOf[f6]))
static_assert(is_equivalent_to(CallableTypeOf[f5] | bool | CallableTypeOf[f6], CallableTypeOf[f6] | bool | CallableTypeOf[f5]))
```

### Not equivalent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ def f4(a=2): ...
def f5(a): ...

static_assert(is_gradual_equivalent_to(CallableTypeOf[f3], CallableTypeOf[f4]))
static_assert(
is_gradual_equivalent_to(CallableTypeOf[f3] | bool | CallableTypeOf[f4], CallableTypeOf[f4] | bool | CallableTypeOf[f3])
)
static_assert(not is_gradual_equivalent_to(CallableTypeOf[f3], CallableTypeOf[f5]))

def f6(a, /): ...
Expand Down
79 changes: 44 additions & 35 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,19 +595,22 @@ impl<'db> Type<'db> {
}
}

/// Return a normalized version of `self` in which all unions and intersections are sorted
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
/// Return a "normalized" version of `self` that ensures that equivalent types have the same Salsa ID.
///
/// A normalized type:
/// - Has all unions and intersections sorted according to a canonical order,
/// no matter how "deeply" a union/intersection may be nested.
/// - Strips the names of positional-only parameters and variadic parameters from `Callable` types,
/// as these are irrelevant to whether a callable type `X` is equivalent to a callable type `Y`.
/// - Strips the types of default values from parameters in `Callable` types: only whether a parameter
/// *has* or *does not have* a default value is relevant to whether two `Callable` types are equivalent.
#[must_use]
pub fn with_sorted_unions_and_intersections(self, db: &'db dyn Db) -> Self {
pub fn normalized(self, db: &'db dyn Db) -> Self {
match self {
Type::Union(union) => Type::Union(union.to_sorted_union(db)),
Type::Intersection(intersection) => {
Type::Intersection(intersection.to_sorted_intersection(db))
}
Type::Tuple(tuple) => Type::Tuple(tuple.with_sorted_unions_and_intersections(db)),
Type::Callable(callable) => {
Type::Callable(callable.with_sorted_unions_and_intersections(db))
}
Type::Union(union) => Type::Union(union.normalized(db)),
Type::Intersection(intersection) => Type::Intersection(intersection.normalized(db)),
Type::Tuple(tuple) => Type::Tuple(tuple.normalized(db)),
Type::Callable(callable) => Type::Callable(callable.normalized(db)),
Type::LiteralString
| Type::Instance(_)
| Type::PropertyInstance(_)
Expand Down Expand Up @@ -4656,16 +4659,19 @@ impl<'db> CallableType<'db> {
)
}

fn with_sorted_unions_and_intersections(self, db: &'db dyn Db) -> Self {
/// Return a "normalized" version of this `Callable` type.
///
/// See [`Type::normalized`] for more details.
fn normalized(self, db: &'db dyn Db) -> Self {
let signature = self.signature(db);
let parameters = signature
.parameters()
.iter()
.map(|param| param.clone().with_sorted_unions_and_intersections(db))
.map(|param| param.normalized(db))
.collect();
let return_ty = signature
.return_ty
.map(|return_ty| return_ty.with_sorted_unions_and_intersections(db));
.map(|return_ty| return_ty.normalized(db));
CallableType::new(db, Signature::new(parameters, return_ty))
}

Expand Down Expand Up @@ -5423,13 +5429,15 @@ impl<'db> UnionType<'db> {
self.elements(db).iter().all(|ty| ty.is_fully_static(db))
}

/// Create a new union type with the elements sorted according to a canonical ordering.
/// Create a new union type with the elements normalized.
///
/// See [`Type::normalized`] for more details.
#[must_use]
pub fn to_sorted_union(self, db: &'db dyn Db) -> Self {
pub fn normalized(self, db: &'db dyn Db) -> Self {
let mut new_elements: Vec<Type<'db>> = self
.elements(db)
.iter()
.map(|element| element.with_sorted_unions_and_intersections(db))
.map(|element| element.normalized(db))
.collect();
new_elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r));
UnionType::new(db, new_elements.into_boxed_slice())
Expand Down Expand Up @@ -5463,13 +5471,13 @@ impl<'db> UnionType<'db> {
return true;
}

let sorted_self = self.to_sorted_union(db);
let sorted_self = self.normalized(db);

if sorted_self == other {
return true;
}

sorted_self == other.to_sorted_union(db)
sorted_self == other.normalized(db)
}

/// Return `true` if `self` has exactly the same set of possible static materializations as `other`
Expand All @@ -5486,13 +5494,13 @@ impl<'db> UnionType<'db> {
return false;
}

let sorted_self = self.to_sorted_union(db);
let sorted_self = self.normalized(db);

if sorted_self == other {
return true;
}

let sorted_other = other.to_sorted_union(db);
let sorted_other = other.normalized(db);

if sorted_self == sorted_other {
return true;
Expand Down Expand Up @@ -5523,17 +5531,17 @@ pub struct IntersectionType<'db> {

impl<'db> IntersectionType<'db> {
/// Return a new `IntersectionType` instance with the positive and negative types sorted
/// according to a canonical ordering.
/// according to a canonical ordering, and other normalizations applied to each element as applicable.
///
/// See [`Type::normalized`] for more details.
#[must_use]
pub fn to_sorted_intersection(self, db: &'db dyn Db) -> Self {
pub fn normalized(self, db: &'db dyn Db) -> Self {
fn normalized_set<'db>(
db: &'db dyn Db,
elements: &FxOrderSet<Type<'db>>,
) -> FxOrderSet<Type<'db>> {
let mut elements: FxOrderSet<Type<'db>> = elements
.iter()
.map(|ty| ty.with_sorted_unions_and_intersections(db))
.collect();
let mut elements: FxOrderSet<Type<'db>> =
elements.iter().map(|ty| ty.normalized(db)).collect();

elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r));
elements
Expand Down Expand Up @@ -5596,13 +5604,13 @@ impl<'db> IntersectionType<'db> {
return true;
}

let sorted_self = self.to_sorted_intersection(db);
let sorted_self = self.normalized(db);

if sorted_self == other {
return true;
}

sorted_self == other.to_sorted_intersection(db)
sorted_self == other.normalized(db)
}

/// Return `true` if `self` has exactly the same set of possible static materializations as `other`
Expand All @@ -5618,13 +5626,13 @@ impl<'db> IntersectionType<'db> {
return false;
}

let sorted_self = self.to_sorted_intersection(db);
let sorted_self = self.normalized(db);

if sorted_self == other {
return true;
}

let sorted_other = other.to_sorted_intersection(db);
let sorted_other = other.normalized(db);

if sorted_self == sorted_other {
return true;
Expand Down Expand Up @@ -5806,14 +5814,15 @@ impl<'db> TupleType<'db> {
Type::Tuple(Self::new(db, elements.into_boxed_slice()))
}

/// Return a normalized version of `self` in which all unions and intersections are sorted
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
/// Return a normalized version of `self`.
///
/// See [`Type::normalized`] for more details.
#[must_use]
pub fn with_sorted_unions_and_intersections(self, db: &'db dyn Db) -> Self {
pub fn normalized(self, db: &'db dyn Db) -> Self {
let elements: Box<[Type<'db>]> = self
.elements(db)
.iter()
.map(|ty| ty.with_sorted_unions_and_intersections(db))
.map(|ty| ty.normalized(db))
.collect();
TupleType::new(db, elements)
}
Expand Down
55 changes: 39 additions & 16 deletions crates/red_knot_python_semantic/src/types/signatures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,31 +606,54 @@ impl<'db> Parameter<'db> {
self
}

pub(crate) fn with_sorted_unions_and_intersections(mut self, db: &'db dyn Db) -> Self {
self.annotated_type = self
.annotated_type
.map(|ty| ty.with_sorted_unions_and_intersections(db));

self.kind = match self.kind {
ParameterKind::PositionalOnly { name, default_type } => ParameterKind::PositionalOnly {
name,
default_type: default_type.map(|ty| ty.with_sorted_unions_and_intersections(db)),
/// Strip information from the parameter so that two equivalent parameters compare equal.
/// Normalize nested unions and intersections in the annotated type, if any.
///
/// See [`Type::normalized`] for more details.
pub(crate) fn normalized(&self, db: &'db dyn Db) -> Self {
let Parameter {
annotated_type,
kind,
form,
} = self;

// Ensure unions and intersections are ordered in the annotated type (if there is one)
let annotated_type = annotated_type.map(|ty| ty.normalized(db));

// Ensure that parameter names are stripped from positional-only, variadic and keyword-variadic parameters.
// Ensure that we only record whether a parameter *has* a default
// (strip the precise *type* of the default from the parameter, replacing it with `Never`).
let kind = match kind {
ParameterKind::PositionalOnly {
name: _,
default_type,
} => ParameterKind::PositionalOnly {
name: None,
default_type: default_type.map(|_| Type::Never),
},
ParameterKind::PositionalOrKeyword { name, default_type } => {
ParameterKind::PositionalOrKeyword {
name,
default_type: default_type
.map(|ty| ty.with_sorted_unions_and_intersections(db)),
name: name.clone(),
default_type: default_type.map(|_| Type::Never),
}
}
ParameterKind::KeywordOnly { name, default_type } => ParameterKind::KeywordOnly {
name,
default_type: default_type.map(|ty| ty.with_sorted_unions_and_intersections(db)),
name: name.clone(),
default_type: default_type.map(|_| Type::Never),
},
ParameterKind::Variadic { name: _ } => ParameterKind::Variadic {
name: Name::new_static("args"),
},
ParameterKind::KeywordVariadic { name: _ } => ParameterKind::KeywordVariadic {
name: Name::new_static("kwargs"),
},
ParameterKind::Variadic { .. } | ParameterKind::KeywordVariadic { .. } => self.kind,
};

self
Self {
annotated_type,
kind,
form: *form,
}
}

fn from_node_and_kind(
Expand Down
14 changes: 9 additions & 5 deletions crates/red_knot_python_semantic/src/types/type_ordering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,17 @@ pub(super) fn union_or_intersection_elements_ordering<'db>(
(Type::WrapperDescriptor(_), _) => Ordering::Less,
(_, Type::WrapperDescriptor(_)) => Ordering::Greater,

(Type::Callable(left), Type::Callable(right)) => left.cmp(right),
(Type::Callable(left), Type::Callable(right)) => {
debug_assert_eq!(*left, left.normalized(db));
debug_assert_eq!(*right, right.normalized(db));
left.cmp(right)
}
(Type::Callable(_), _) => Ordering::Less,
(_, Type::Callable(_)) => Ordering::Greater,

(Type::Tuple(left), Type::Tuple(right)) => {
debug_assert_eq!(*left, left.with_sorted_unions_and_intersections(db));
debug_assert_eq!(*right, right.with_sorted_unions_and_intersections(db));
debug_assert_eq!(*left, left.normalized(db));
debug_assert_eq!(*right, right.normalized(db));
left.cmp(right)
}
(Type::Tuple(_), _) => Ordering::Less,
Expand Down Expand Up @@ -271,8 +275,8 @@ pub(super) fn union_or_intersection_elements_ordering<'db>(
}

(Type::Intersection(left), Type::Intersection(right)) => {
debug_assert_eq!(*left, left.to_sorted_intersection(db));
debug_assert_eq!(*right, right.to_sorted_intersection(db));
debug_assert_eq!(*left, left.normalized(db));
debug_assert_eq!(*right, right.normalized(db));

if left == right {
return Ordering::Equal;
Expand Down
Loading