diff --git a/crates/red_knot_python_semantic/resources/mdtest/protocols.md b/crates/red_knot_python_semantic/resources/mdtest/protocols.md index abba908825da8..7f19249d77bdd 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/protocols.md +++ b/crates/red_knot_python_semantic/resources/mdtest/protocols.md @@ -304,10 +304,12 @@ reveal_type(typing.Protocol is not typing_extensions.Protocol) # revealed: bool ## Calls to protocol classes + + Neither `Protocol`, nor any protocol class, can be directly instantiated: ```py -from typing import Protocol +from typing_extensions import Protocol, reveal_type # error: [call-non-callable] reveal_type(Protocol()) # revealed: Unknown @@ -315,7 +317,7 @@ reveal_type(Protocol()) # revealed: Unknown class MyProtocol(Protocol): x: int -# TODO: should emit error +# error: [call-non-callable] "Cannot instantiate class `MyProtocol`" reveal_type(MyProtocol()) # revealed: MyProtocol ``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/snapshots/protocols.md_-_Protocols_-_Calls_to_protocol_classes.snap b/crates/red_knot_python_semantic/resources/mdtest/snapshots/protocols.md_-_Protocols_-_Calls_to_protocol_classes.snap new file mode 100644 index 0000000000000..92abc59799e3e --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/snapshots/protocols.md_-_Protocols_-_Calls_to_protocol_classes.snap @@ -0,0 +1,117 @@ +--- +source: crates/red_knot_test/src/lib.rs +expression: snapshot +--- +--- +mdtest name: protocols.md - Protocols - Calls to protocol classes +mdtest path: crates/red_knot_python_semantic/resources/mdtest/protocols.md +--- + +# Python source files + +## mdtest_snippet.py + +``` + 1 | from typing_extensions import Protocol, reveal_type + 2 | + 3 | # error: [call-non-callable] + 4 | reveal_type(Protocol()) # revealed: Unknown + 5 | + 6 | class MyProtocol(Protocol): + 7 | x: int + 8 | + 9 | # error: [call-non-callable] "Cannot instantiate class `MyProtocol`" +10 | reveal_type(MyProtocol()) # revealed: MyProtocol +11 | class SubclassOfMyProtocol(MyProtocol): ... +12 | +13 | reveal_type(SubclassOfMyProtocol()) # revealed: SubclassOfMyProtocol +14 | def f(x: type[MyProtocol]): +15 | reveal_type(x()) # revealed: MyProtocol +``` + +# Diagnostics + +``` +error: lint:call-non-callable: Object of type `typing.Protocol` is not callable + --> /src/mdtest_snippet.py:4:13 + | +3 | # error: [call-non-callable] +4 | reveal_type(Protocol()) # revealed: Unknown + | ^^^^^^^^^^ +5 | +6 | class MyProtocol(Protocol): + | + +``` + +``` +info: revealed-type: Revealed type + --> /src/mdtest_snippet.py:4:1 + | +3 | # error: [call-non-callable] +4 | reveal_type(Protocol()) # revealed: Unknown + | ^^^^^^^^^^^^^^^^^^^^^^^ `Unknown` +5 | +6 | class MyProtocol(Protocol): + | + +``` + +``` +error: lint:call-non-callable: Cannot instantiate class `MyProtocol` + --> /src/mdtest_snippet.py:10:13 + | + 9 | # error: [call-non-callable] "Cannot instantiate class `MyProtocol`" +10 | reveal_type(MyProtocol()) # revealed: MyProtocol + | ^^^^^^^^^^^^ This call will raise `TypeError` at runtime +11 | class SubclassOfMyProtocol(MyProtocol): ... + | +info: Protocol classes cannot be instantiated + --> /src/mdtest_snippet.py:6:7 + | +4 | reveal_type(Protocol()) # revealed: Unknown +5 | +6 | class MyProtocol(Protocol): + | ^^^^^^^^^^^^^^^^^^^^ `MyProtocol` declared as a protocol here +7 | x: int + | + +``` + +``` +info: revealed-type: Revealed type + --> /src/mdtest_snippet.py:10:1 + | + 9 | # error: [call-non-callable] "Cannot instantiate class `MyProtocol`" +10 | reveal_type(MyProtocol()) # revealed: MyProtocol + | ^^^^^^^^^^^^^^^^^^^^^^^^^ `MyProtocol` +11 | class SubclassOfMyProtocol(MyProtocol): ... + | + +``` + +``` +info: revealed-type: Revealed type + --> /src/mdtest_snippet.py:13:1 + | +11 | class SubclassOfMyProtocol(MyProtocol): ... +12 | +13 | reveal_type(SubclassOfMyProtocol()) # revealed: SubclassOfMyProtocol + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `SubclassOfMyProtocol` +14 | def f(x: type[MyProtocol]): +15 | reveal_type(x()) # revealed: MyProtocol + | + +``` + +``` +info: revealed-type: Revealed type + --> /src/mdtest_snippet.py:15:5 + | +13 | reveal_type(SubclassOfMyProtocol()) # revealed: SubclassOfMyProtocol +14 | def f(x: type[MyProtocol]): +15 | reveal_type(x()) # revealed: MyProtocol + | ^^^^^^^^^^^^^^^^ `MyProtocol` + | + +``` diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index 20e0624dd35ac..9f7e2f37388d7 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -37,9 +37,11 @@ use crate::{ }; use indexmap::IndexSet; use itertools::Itertools as _; +use ruff_db::diagnostic::Span; use ruff_db::files::File; use ruff_python_ast::name::Name; use ruff_python_ast::{self as ast, PythonVersion}; +use ruff_text_size::{Ranged, TextRange}; use rustc_hash::{FxHashSet, FxHasher}; type FxOrderMap = ordermap::map::OrderMap>; @@ -1725,6 +1727,28 @@ impl<'db> ClassLiteralType<'db> { pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option> { self.is_protocol(db).then_some(ProtocolClassLiteral(self)) } + + /// Returns the [`Span`] of the class's "header": the class name + /// and any arguments passed to the `class` statement. E.g. + /// + /// ```ignore + /// class Foo(Bar, metaclass=Baz): ... + /// ^^^^^^^^^^^^^^^^^^^^^^^ + /// ``` + pub(super) fn header_span(self, db: &'db dyn Db) -> Span { + let class_scope = self.body_scope(db); + let class_node = class_scope.node(db).expect_class(); + let class_name = &class_node.name; + let header_range = TextRange::new( + class_name.start(), + class_node + .arguments + .as_deref() + .map(Ranged::end) + .unwrap_or_else(|| class_name.end()), + ); + Span::from(class_scope.file(db)).with_range(header_range) + } } impl<'db> From> for Type<'db> { diff --git a/crates/red_knot_python_semantic/src/types/diagnostic.rs b/crates/red_knot_python_semantic/src/types/diagnostic.rs index c4b7d4f87b31a..6cd6b3da483a1 100644 --- a/crates/red_knot_python_semantic/src/types/diagnostic.rs +++ b/crates/red_knot_python_semantic/src/types/diagnostic.rs @@ -11,7 +11,7 @@ use crate::types::string_annotation::{ use crate::types::{class::ProtocolClassLiteral, KnownFunction, KnownInstanceType, Type}; use ruff_db::diagnostic::{Annotation, Diagnostic, Severity, Span, SubDiagnostic}; use ruff_python_ast::{self as ast, AnyNodeRef}; -use ruff_text_size::{Ranged, TextRange}; +use ruff_text_size::Ranged; use rustc_hash::FxHashSet; use std::fmt::Formatter; @@ -1331,24 +1331,14 @@ pub(crate) fn report_bad_argument_to_get_protocol_members( diagnostic.set_primary_message("This call will raise `TypeError` at runtime"); diagnostic.info("Only protocol classes can be passed to `get_protocol_members`"); - let class_scope = class.body_scope(db); - let class_node = class_scope.node(db).expect_class(); - let class_name = &class_node.name; - let class_def_diagnostic_range = TextRange::new( - class_name.start(), - class_node - .arguments - .as_deref() - .map(Ranged::end) - .unwrap_or_else(|| class_name.end()), - ); let mut class_def_diagnostic = SubDiagnostic::new( Severity::Info, - format_args!("`{class_name}` is declared here, but it is not a protocol class:"), + format_args!( + "`{}` is declared here, but it is not a protocol class:", + class.name(db) + ), ); - class_def_diagnostic.annotate(Annotation::primary( - Span::from(class_scope.file(db)).with_range(class_def_diagnostic_range), - )); + class_def_diagnostic.annotate(Annotation::primary(class.header_span(db))); diagnostic.sub(class_def_diagnostic); diagnostic.info( @@ -1393,12 +1383,6 @@ pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol( )); diagnostic.set_primary_message("This call will raise `TypeError` at runtime"); - let class_scope = protocol.body_scope(db); - let class_node = class_scope.node(db).expect_class(); - let class_def_arguments = class_node - .arguments - .as_ref() - .expect("A `Protocol` class should always have at least one explicit base"); let mut class_def_diagnostic = SubDiagnostic::new( Severity::Info, format_args!( @@ -1407,11 +1391,8 @@ pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol( ), ); class_def_diagnostic.annotate( - Annotation::primary(Span::from(class_scope.file(db)).with_range(TextRange::new( - class_node.name.start(), - class_def_arguments.end(), - ))) - .message(format_args!("`{class_name}` declared here")), + Annotation::primary(protocol.header_span(db)) + .message(format_args!("`{class_name}` declared here")), ); diagnostic.sub(class_def_diagnostic); @@ -1421,3 +1402,28 @@ pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol( )); diagnostic.info("See https://docs.python.org/3/library/typing.html#typing.runtime_checkable"); } + +pub(crate) fn report_attempted_protocol_instantiation( + context: &InferContext, + call: &ast::ExprCall, + protocol: ProtocolClassLiteral, +) { + let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, call) else { + return; + }; + let db = context.db(); + let class_name = protocol.name(db); + let mut diagnostic = + builder.into_diagnostic(format_args!("Cannot instantiate class `{class_name}`",)); + diagnostic.set_primary_message("This call will raise `TypeError` at runtime"); + + let mut class_def_diagnostic = SubDiagnostic::new( + Severity::Info, + format_args!("Protocol classes cannot be instantiated"), + ); + class_def_diagnostic.annotate( + Annotation::primary(protocol.header_span(db)) + .message(format_args!("`{class_name}` declared as a protocol here")), + ); + diagnostic.sub(class_def_diagnostic); +} diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 9ca92e493eaad..7cd0d4c4f95e7 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -96,8 +96,8 @@ use crate::Db; use super::context::{InNoTypeCheck, InferContext}; use super::diagnostic::{ - report_bad_argument_to_get_protocol_members, report_index_out_of_bounds, - report_invalid_exception_caught, report_invalid_exception_cause, + report_attempted_protocol_instantiation, report_bad_argument_to_get_protocol_members, + report_index_out_of_bounds, report_invalid_exception_caught, report_invalid_exception_cause, report_invalid_exception_raised, report_invalid_type_checking_constant, report_non_subscriptable, report_possibly_unresolved_reference, report_runtime_check_against_non_runtime_checkable_protocol, report_slice_step_size_zero, @@ -4280,6 +4280,20 @@ impl<'db> TypeInferenceBuilder<'db> { let mut call_arguments = Self::parse_arguments(arguments); let callable_type = self.infer_expression(func); + // It might look odd here that we emit an error for class-literals but not `type[]` types. + // But it's deliberate! The typing spec explicitly mandates that `type[]` types can be called + // even though class-literals cannot. This is because even though a protocol class `SomeProtocol` + // is always an abstract class, `type[SomeProtocol]` can be a concrete subclass of that protocol + // -- and indeed, according to the spec, type checkers must disallow abstract subclasses of the + // protocol to be passed to parameters that accept `type[SomeProtocol]`. + // . + if let Some(protocol_class) = callable_type + .into_class_literal() + .and_then(|class| class.into_protocol_class(self.db())) + { + report_attempted_protocol_instantiation(&self.context, call_expression, protocol_class); + } + // For class literals we model the entire class instantiation logic, so it is handled // in a separate function. For some known classes we have manual signatures defined and use // the `try_call` path below.