Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/red_knot_python_semantic/resources/mdtest/protocols.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,13 @@ If `Protocol` is present in the bases tuple, all other bases in the tuple must b
or `TypeError` is raised at runtime when the class is created.

```py
# TODO: should emit `[invalid-protocol]`
# error: [invalid-protocol] "Protocol class `Invalid` cannot inherit from non-protocol class `NotAProtocol`"
class Invalid(NotAProtocol, Protocol): ...

# revealed: tuple[Literal[Invalid], Literal[NotAProtocol], typing.Protocol, typing.Generic, Literal[object]]
reveal_type(Invalid.__mro__)

# TODO: should emit an `[invalid-protocol`] error
# error: [invalid-protocol] "Protocol class `AlsoInvalid` cannot inherit from non-protocol class `NotAProtocol`"
class AlsoInvalid(MyProtocol, OtherProtocol, NotAProtocol, Protocol): ...

# revealed: tuple[Literal[AlsoInvalid], Literal[MyProtocol], Literal[OtherProtocol], Literal[NotAProtocol], typing.Protocol, typing.Generic, Literal[object]]
Expand Down
29 changes: 29 additions & 0 deletions crates/red_knot_python_semantic/src/types/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
registry.register_lint(&INVALID_EXCEPTION_CAUGHT);
registry.register_lint(&INVALID_METACLASS);
registry.register_lint(&INVALID_PARAMETER_DEFAULT);
registry.register_lint(&INVALID_PROTOCOL);
registry.register_lint(&INVALID_RAISE);
registry.register_lint(&INVALID_SUPER_ARGUMENT);
registry.register_lint(&INVALID_TYPE_CHECKING_CONSTANT);
Expand Down Expand Up @@ -230,6 +231,34 @@ declare_lint! {
}
}

declare_lint! {
/// ## What it does
/// Checks for invalidly defined protocol classes.
///
/// ## Why is this bad?
/// An invalidly defined protocol class may lead to the type checker inferring
/// unexpected things. It may also lead to `TypeError`s at runtime.
///
/// ## Examples
/// A `Protocol` class cannot inherit from a non-`Protocol` class;
/// this raises a `TypeError` at runtime:
///
/// ```pycon
/// >>> from typing import Protocol
/// >>> class Foo(int, Protocol): ...
/// ...
/// Traceback (most recent call last):
/// File "<python-input-1>", line 1, in <module>
/// class Foo(int, Protocol): ...
/// TypeError: Protocols can only inherit from other protocols, got <class 'int'>
/// ```
pub(crate) static INVALID_PROTOCOL = {
summary: "detects invalid protocol class definitions",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}

declare_lint! {
/// TODO #14889
pub(crate) static INCONSISTENT_MRO = {
Expand Down
52 changes: 35 additions & 17 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ use super::diagnostic::{
report_index_out_of_bounds, report_invalid_exception_caught, report_invalid_exception_cause,
report_invalid_exception_raised, report_invalid_type_checking_constant,
report_non_subscriptable, report_possibly_unresolved_reference, report_slice_step_size_zero,
report_unresolved_reference, INVALID_METACLASS, REDUNDANT_CAST, STATIC_ASSERT_ERROR,
SUBCLASS_OF_FINAL_CLASS, TYPE_ASSERTION_FAILURE,
report_unresolved_reference, INVALID_METACLASS, INVALID_PROTOCOL, REDUNDANT_CAST,
STATIC_ASSERT_ERROR, SUBCLASS_OF_FINAL_CLASS, TYPE_ASSERTION_FAILURE,
};
use super::slots::check_class_slots;
use super::string_annotation::{
Expand Down Expand Up @@ -763,17 +763,21 @@ impl<'db> TypeInferenceBuilder<'db> {
continue;
}

// (2) Check for inheritance from plain `Generic`,
// and from classes that inherit from `@final` classes
let is_protocol = class.is_protocol(self.db());

// (2) Iterate through the class's explicit bases to check for various possible errors:
// - Check for inheritance from plain `Generic`,
// - Check for inheritance from a `@final` classes
// - If the class is a protocol class: check for inheritance from a non-protocol class
for (i, base_class) in class.explicit_bases(self.db()).iter().enumerate() {
let base_class = match base_class {
Type::KnownInstance(KnownInstanceType::Generic) => {
// `Generic` can appear in the MRO of many classes,
// Unsubscripted `Generic` can appear in the MRO of many classes,
// but it is never valid as an explicit base class in user code.
self.context.report_lint_old(
&INVALID_BASE,
&class_node.bases()[i],
format_args!("Cannot inherit from plain `Generic`",),
format_args!("Cannot inherit from plain `Generic`"),
);
continue;
}
Expand All @@ -782,18 +786,32 @@ impl<'db> TypeInferenceBuilder<'db> {
_ => continue,
};

if !base_class.is_final(self.db()) {
continue;
if is_protocol
&& !(base_class.is_protocol(self.db())
|| base_class.is_known(self.db(), KnownClass::Object))
{
self.context.report_lint_old(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it makes sense to add new uses of report_lint_old?

But I guess there is some argument for keeping consistency with nearby code, and changing them all together?

cc @BurntSushi

Copy link
Member Author

@AlexWaygood AlexWaygood Apr 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Heh, yeah, I also wondered about this... I think I'll keep it as-is for now, purely for local consistency if nothing else 😄 apologies to @BurntSushi if this causes annoyance for him!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's okay, I'll fix it up. Thanks for highlighting this!

&INVALID_PROTOCOL,
&class_node.bases()[i],
format_args!(
"Protocol class `{}` cannot inherit from non-protocol class `{}`",
class.name(self.db()),
base_class.name(self.db()),
),
);
}

if base_class.is_final(self.db()) {
self.context.report_lint_old(
&SUBCLASS_OF_FINAL_CLASS,
&class_node.bases()[i],
format_args!(
"Class `{}` cannot inherit from final class `{}`",
class.name(self.db()),
base_class.name(self.db()),
),
);
}
self.context.report_lint_old(
&SUBCLASS_OF_FINAL_CLASS,
&class_node.bases()[i],
format_args!(
"Class `{}` cannot inherit from final class `{}`",
class.name(self.db()),
base_class.name(self.db()),
),
);
}

// (3) Check that the class's MRO is resolvable
Expand Down
10 changes: 10 additions & 0 deletions knot.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,16 @@
}
]
},
"invalid-protocol": {
"title": "detects invalid protocol class definitions",
"description": "## What it does\nChecks for invalidly defined protocol classes.\n\n## Why is this bad?\nAn invalidly defined protocol class may lead to the type checker inferring\nunexpected things. It may also lead to `TypeError`s at runtime.\n\n## Examples\nA `Protocol` class cannot inherit from a non-`Protocol` class;\nthis raises a `TypeError` at runtime:\n\n```pycon\n>>> from typing import Protocol\n>>> class Foo(int, Protocol): ...\n...\nTraceback (most recent call last):\n File \"<python-input-1>\", line 1, in <module>\n class Foo(int, Protocol): ...\nTypeError: Protocols can only inherit from other protocols, got <class 'int'>\n```",
"default": "error",
"oneOf": [
{
"$ref": "#/definitions/Level"
}
]
},
"invalid-raise": {
"title": "detects `raise` statements that raise invalid exceptions or use invalid causes",
"description": "Checks for `raise` statements that raise non-exceptions or use invalid\ncauses for their raised exceptions.\n\n## Why is this bad?\nOnly subclasses or instances of `BaseException` can be raised.\nFor an exception's cause, the same rules apply, except that `None` is also\npermitted. Violating these rules results in a `TypeError` at runtime.\n\n## Examples\n```python\ndef f():\n try:\n something()\n except NameError:\n raise \"oops!\" from f\n\ndef g():\n raise NotImplemented from 42\n```\n\nUse instead:\n```python\ndef f():\n try:\n something()\n except NameError as e:\n raise RuntimeError(\"oops!\") from e\n\ndef g():\n raise NotImplementedError from None\n```\n\n## References\n- [Python documentation: The `raise` statement](https://docs.python.org/3/reference/simple_stmts.html#raise)\n- [Python documentation: Built-in Exceptions](https://docs.python.org/3/library/exceptions.html#built-in-exceptions)",
Expand Down
Loading