Skip to content

Commit

Permalink
Parsing string annotation is not a salsa query
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruvmanila committed Nov 8, 2024
1 parent 69c2de3 commit 666e12e
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 98 deletions.
16 changes: 1 addition & 15 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ pub(super) struct SemanticIndexBuilder<'db> {
current_assignments: Vec<CurrentAssignment<'db>>,
/// The match case we're currently visiting.
current_match_case: Option<CurrentMatchCase<'db>>,
/// The visitor is currently visiting a type annotation.
in_type_annotation: bool,

/// Flow states at each `break` in the current loop.
loop_break_states: Vec<FlowSnapshot>,
Expand Down Expand Up @@ -81,7 +79,6 @@ impl<'db> SemanticIndexBuilder<'db> {
current_match_case: None,
loop_break_states: vec![],
try_node_context_stack_manager: TryNodeContextStackManager::default(),
in_type_annotation: false,

has_future_annotations: false,

Expand Down Expand Up @@ -448,13 +445,6 @@ impl<'db> SemanticIndexBuilder<'db> {
self.pop_scope();
}

fn visit_type_annotation(&mut self, expr: &'db ast::Expr) {
let was_in_type_annotation = self.in_type_annotation;
self.in_type_annotation = true;
self.visit_expr(expr);
self.in_type_annotation = was_in_type_annotation;
}

fn declare_parameter(&mut self, parameter: AnyParameterRef<'db>) {
let symbol = self.add_symbol(parameter.name().id().clone());

Expand Down Expand Up @@ -677,7 +667,7 @@ where
}
ast::Stmt::AnnAssign(node) => {
debug_assert_eq!(&self.current_assignments, &[]);
self.visit_type_annotation(&node.annotation);
self.visit_expr(&node.annotation);
if let Some(value) = &node.value {
self.visit_expr(value);
}
Expand Down Expand Up @@ -1200,10 +1190,6 @@ where
}
}
_ => {
if matches!(expr, ast::Expr::StringLiteral(_) if self.in_type_annotation) {
self.add_standalone_expression(expr);
}

walk_expr(self, expr);
}
}
Expand Down
15 changes: 15 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1801,6 +1801,21 @@ impl<'db> IterationOutcome<'db> {
}
}

#[derive(Debug)]
enum AnnotationOutcome<'db> {
Type(Type<'db>),
Deferred,
}

impl<'db> AnnotationOutcome<'db> {
fn expect_type(self) -> Type<'db> {
match self {
AnnotationOutcome::Type(ty) => ty,
AnnotationOutcome::Deferred => panic!("expected a type, but got a deferred annotation"),
}
}
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum Truthiness {
/// For an object `x`, `bool(x)` will always return `True`
Expand Down
127 changes: 62 additions & 65 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ use crate::types::mro::MroErrorKind;
use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{
bindings_ty, builtins_symbol, declarations_ty, global_symbol, symbol, typing_extensions_symbol,
Boundness, BytesLiteralType, Class, ClassLiteralType, FunctionType, InstanceType,
IterationOutcome, KnownClass, KnownFunction, KnownInstance, MetaclassErrorKind,
AnnotationOutcome, Boundness, BytesLiteralType, Class, ClassLiteralType, FunctionType,
InstanceType, IterationOutcome, KnownClass, KnownFunction, KnownInstance, MetaclassErrorKind,
SliceLiteralType, StringLiteralType, Symbol, Truthiness, TupleType, Type, TypeArrayDisplay,
UnionBuilder, UnionType,
};
Expand Down Expand Up @@ -610,7 +610,6 @@ impl<'db> TypeInferenceBuilder<'db> {

fn infer_region_expression(&mut self, expression: Expression<'db>) {
self.infer_expression_impl(expression.node_ref(self.db));
self.check_deferred();
}

fn check_deferred(&mut self) {
Expand Down Expand Up @@ -861,11 +860,12 @@ impl<'db> TypeInferenceBuilder<'db> {
if self.are_all_types_deferred() {
self.deferred.push(definition);
} else {
// TODO(sa): Check if any of the parameter annotation or the return type is
// stringified or contains forward references. They need to be deferred as well.
// This could be done using an AST traversal and will also require resolving
// imports to special case the logic for `typing.Literal` and `typing.Annotation`.
self.infer_optional_annotation_expression(returns.as_deref());
if matches!(
self.infer_optional_annotation_expression(returns.as_deref()),
Some(AnnotationOutcome::Deferred)
) {
self.deferred.push(definition);
}
}
}

Expand Down Expand Up @@ -1494,6 +1494,12 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}

/// Infer the types in an annotated assignment definition.
///
/// This will defer the inference of the annotation expression if:
/// - It is a string literal expression
/// - In a stub file
/// - The `__future__.annotations` feature is enabled
fn infer_annotated_assignment_definition(
&mut self,
assignment: &ast::StmtAnnAssign,
Expand All @@ -1507,16 +1513,21 @@ impl<'db> TypeInferenceBuilder<'db> {
simple: _,
} = assignment;

let annotation_ty = self.infer_annotation_expression(annotation);

match annotation_ty {
AnnotationOutcome::Deferred => self.deferred.push(definition),
AnnotationOutcome::Type(annotation_ty) => {
self.add_annotated_assignment_declaration(assignment, annotation_ty, definition);
if self.are_all_types_deferred() {
self.deferred.push(definition);
} else {
match self.infer_annotation_expression(annotation) {
AnnotationOutcome::Deferred => self.deferred.push(definition),
AnnotationOutcome::Type(annotation_ty) => {
self.add_annotated_assignment_declaration(
assignment,
annotation_ty,
definition,
);
self.infer_expression(target);
}
}
}

self.infer_expression(target);
}

fn infer_annotated_assignment_deferred(
Expand All @@ -1526,35 +1537,50 @@ impl<'db> TypeInferenceBuilder<'db> {
) {
let ast::StmtAnnAssign {
range: _,
target: _,
target,
annotation,
value: _,
simple: _,
} = assignment;

let annotation_expression = self.index.expression(&**annotation);

let annotation_ty = match parse_string_annotation(self.db, annotation_expression) {
Ok(parsed) => match self.infer_annotation_expression_no_store(parsed.expr()) {
AnnotationOutcome::Type(annotation_ty) => annotation_ty,
AnnotationOutcome::Deferred => {
self.deferred.push(definition);
return;
// "Foo" -> Foo
// "SomeGeneric[Foo]" -> SomeGeneric[Foo]
// "'Foo'" -> 'Foo'
// "SomeGeneric['Foo']" -> SomeGeneric['Foo']

let annotation_ty = if let Some(string) = annotation.as_string_literal_expr() {
let annotation_ty = match parse_string_annotation(self.db, self.file, string) {
Ok(parsed) => match self.infer_annotation_expression_no_store(parsed.expr()) {
AnnotationOutcome::Type(ty) => ty,
AnnotationOutcome::Deferred => {
debug_assert!(parsed.expr().is_string_literal_expr());
self.diagnostics.add(
parsed.expr().into(),
"nested-string-annotation",
format_args!("Nested string annotations are not supported"),
);
Type::Unknown
}
},
Err(diagnostics) => {
self.diagnostics.extend(&diagnostics);
Type::Unknown
}
},
Err(diagnostics) => {
self.diagnostics.extend(diagnostics);
Type::Unknown
}
};
};

// We don't store the expression type when `infer_annotation_expression` is called for the
// first time because it would return `Deferred`. So, now as we're in the deferred region,
// we need to store it and we will do so on the original string expression instead of the
// parsed expression because the latter doesn't exists in the semantic index.
self.store_expression_type(annotation, annotation_ty);
// We don't store the expression type when `infer_annotation_expression` is called for the
// first time because it would return `Deferred`. So, now as we're in the deferred region,
// we need to store it and we will do so on the original string expression instead of the
// parsed expression because the latter doesn't exists in the semantic index.
self.store_expression_type(annotation, annotation_ty);

annotation_ty
} else {
self.infer_annotation_expression(annotation).expect_type()
};

self.add_annotated_assignment_declaration(assignment, annotation_ty, definition);
self.infer_expression(target);
}

fn add_annotated_assignment_declaration(
Expand Down Expand Up @@ -3875,12 +3901,6 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}

#[derive(Debug)]
enum AnnotationOutcome<'db> {
Type(Type<'db>),
Deferred,
}

/// Annotation expressions.
impl<'db> TypeInferenceBuilder<'db> {
fn infer_annotation_expression(&mut self, expression: &ast::Expr) -> AnnotationOutcome<'db> {
Expand Down Expand Up @@ -5618,27 +5638,4 @@ mod tests {
);
Ok(())
}

#[test]
fn debug() {
let mut db = setup_db();

db.write_file(
"/src/a.py",
"
from typing import reveal_type
x: 'Foo'
class Foo:
pass
",
)
.unwrap();

let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = global_symbol(&db, a, "x").expect_type();
println!("x_ty: {:?}", x_ty.display(&db));
}
}
26 changes: 8 additions & 18 deletions crates/red_knot_python_semantic/src/types/string_annotation.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use ruff_db::files::File;
use ruff_db::source::source_text;
use ruff_python_ast::str::raw_contents;
use ruff_python_ast::{ModExpression, StringFlags};
use ruff_python_ast::{self as ast, ModExpression, StringFlags};
use ruff_python_parser::{parse_expression_range, Parsed};
use ruff_text_size::Ranged;

use salsa::plumbing::AsId;

use crate::semantic_index::expression::Expression;
use crate::types::diagnostic::{TypeCheckDiagnostics, TypeCheckDiagnosticsBuilder};
use crate::Db;

Expand All @@ -17,25 +15,17 @@ type AnnotationParseResult = Result<Parsed<ModExpression>, TypeCheckDiagnostics>
/// # Panics
///
/// Panics if the expression is not a string literal.
#[salsa::tracked(return_ref, no_eq)]
pub(crate) fn parse_string_annotation<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
pub(crate) fn parse_string_annotation(
db: &dyn Db,
file: File,
string_expr: &ast::ExprStringLiteral,
) -> AnnotationParseResult {
let file = expression.file(db);
let _span =
tracing::trace_span!("parse_string_annotation", expression=?expression.as_id(), file=%file.path(db))
.entered();
let _span = tracing::trace_span!("parse_string_annotation", string=?string_expr.range(), file=%file.path(db)).entered();

let source = source_text(db.upcast(), file);
let node = expression.node_ref(db).node();
let Some(string_expr) = node.as_string_literal_expr() else {
panic!("Expected a string literal expression");
};

let node_text = &source[string_expr.range()];
let mut diagnostics = TypeCheckDiagnosticsBuilder::new(db, file);

let node_text = &source[string_expr.range()];
if let [string_literal] = string_expr.value.as_slice() {
let prefix = string_literal.flags.prefix();
if prefix.is_raw() {
Expand Down

0 comments on commit 666e12e

Please sign in to comment.