From 4d6c1c14768047f23a0098ba6d4c64fde2d79664 Mon Sep 17 00:00:00 2001 From: UnboundVariable Date: Sat, 19 Jul 2025 16:20:34 -0700 Subject: [PATCH 01/10] [ty] Implemented support for "find references" language server feature. --- crates/ty_ide/src/goto.rs | 107 +++ crates/ty_ide/src/lib.rs | 2 + crates/ty_ide/src/references.rs | 883 ++++++++++++++++++ crates/ty_server/src/server.rs | 1 + crates/ty_server/src/server/api.rs | 3 + crates/ty_server/src/server/api/requests.rs | 2 + .../src/server/api/requests/references.rs | 65 ++ 7 files changed, 1063 insertions(+) create mode 100644 crates/ty_ide/src/references.rs create mode 100644 crates/ty_server/src/server/api/requests/references.rs diff --git a/crates/ty_ide/src/goto.rs b/crates/ty_ide/src/goto.rs index 134e7dd140d6a..bb4f0d5254418 100644 --- a/crates/ty_ide/src/goto.rs +++ b/crates/ty_ide/src/goto.rs @@ -270,10 +270,117 @@ impl GotoTarget<'_> { definitions_to_navigation_targets(db, stub_mapper, definitions) } + // For exception variables, they are their own definitions (like parameters) + GotoTarget::ExceptVariable(except_handler) => { + if let Some(name) = &except_handler.name { + let range = name.range; + Some(crate::NavigationTargets::single(NavigationTarget { + file, + focus_range: range, + full_range: range, + })) + } else { + None + } + } + + // For pattern match rest variables, they are their own definitions + GotoTarget::PatternMatchRest(pattern_mapping) => { + if let Some(rest_name) = &pattern_mapping.rest { + let range = rest_name.range; + Some(crate::NavigationTargets::single(NavigationTarget { + file, + focus_range: range, + full_range: range, + })) + } else { + None + } + } + + // For pattern match as names, they are their own definitions + GotoTarget::PatternMatchAsName(pattern_as) => { + if let Some(name) = &pattern_as.name { + let range = name.range; + Some(crate::NavigationTargets::single(NavigationTarget { + file, + focus_range: range, + full_range: range, + })) + } else { + None + } + } + // TODO: Handle string literals that map to TypedDict fields _ => None, } } + + /// Returns the text representation of this goto target. + /// Returns `None` if no meaningful string representation can be provided. + /// This is used by the "references" feature, which looks for references + /// to this goto target. + pub(crate) fn as_str(&self) -> Option { + match self { + GotoTarget::Expression(expression) => match expression { + ast::ExprRef::Name(name) => Some(name.id.as_str().to_string()), + ast::ExprRef::Attribute(attr) => Some(attr.attr.as_str().to_string()), + _ => None, + }, + GotoTarget::FunctionDef(function) => Some(function.name.as_str().to_string()), + GotoTarget::ClassDef(class) => Some(class.name.as_str().to_string()), + GotoTarget::Parameter(parameter) => Some(parameter.name.as_str().to_string()), + GotoTarget::ImportSymbolAlias { alias, .. } => { + if let Some(asname) = &alias.asname { + Some(asname.as_str().to_string()) + } else { + Some(alias.name.as_str().to_string()) + } + } + GotoTarget::ImportModuleComponent { + module_name, + component_index, + .. + } => { + let components: Vec<&str> = module_name.split('.').collect(); + if let Some(component) = components.get(*component_index) { + Some((*component).to_string()) + } else { + Some(module_name.clone()) + } + } + GotoTarget::ImportModuleAlias { alias } => { + if let Some(asname) = &alias.asname { + Some(asname.as_str().to_string()) + } else { + Some(alias.name.as_str().to_string()) + } + } + GotoTarget::ExceptVariable(except) => { + except.name.as_ref().map(|name| name.as_str().to_string()) + } + GotoTarget::KeywordArgument { keyword, .. } => { + keyword.arg.as_ref().map(|arg| arg.as_str().to_string()) + } + GotoTarget::PatternMatchRest(rest) => rest + .rest + .as_ref() + .map(|rest_name| rest_name.as_str().to_string()), + GotoTarget::PatternKeywordArgument(keyword) => Some(keyword.attr.as_str().to_string()), + GotoTarget::PatternMatchStarName(star) => { + star.name.as_ref().map(|name| name.as_str().to_string()) + } + GotoTarget::PatternMatchAsName(as_name) => { + as_name.name.as_ref().map(|name| name.as_str().to_string()) + } + GotoTarget::TypeParamTypeVarName(type_var) => Some(type_var.name.as_str().to_string()), + GotoTarget::TypeParamParamSpecName(spec) => Some(spec.name.as_str().to_string()), + GotoTarget::TypeParamTypeVarTupleName(tuple) => Some(tuple.name.as_str().to_string()), + GotoTarget::NonLocal { identifier, .. } => Some(identifier.as_str().to_string()), + GotoTarget::Globals { identifier, .. } => Some(identifier.as_str().to_string()), + } + } } impl Ranged for GotoTarget<'_> { diff --git a/crates/ty_ide/src/lib.rs b/crates/ty_ide/src/lib.rs index ebb7c83e4c2da..8c63eff15d4b0 100644 --- a/crates/ty_ide/src/lib.rs +++ b/crates/ty_ide/src/lib.rs @@ -9,6 +9,7 @@ mod goto_type_definition; mod hover; mod inlay_hints; mod markup; +mod references; mod semantic_tokens; mod signature_help; mod stub_mapping; @@ -20,6 +21,7 @@ pub use goto::{goto_declaration, goto_definition, goto_type_definition}; pub use hover::hover; pub use inlay_hints::inlay_hints; pub use markup::MarkupKind; +pub use references::references; pub use semantic_tokens::{ SemanticToken, SemanticTokenModifier, SemanticTokenType, SemanticTokens, semantic_tokens, }; diff --git a/crates/ty_ide/src/references.rs b/crates/ty_ide/src/references.rs new file mode 100644 index 0000000000000..ec3b4e574c2fa --- /dev/null +++ b/crates/ty_ide/src/references.rs @@ -0,0 +1,883 @@ +use crate::goto::find_goto_target; +use crate::{Db, NavigationTarget, NavigationTargets, RangedValue}; +use ruff_db::files::{File, FileRange}; +use ruff_python_ast::{self as ast, visitor::Visitor}; +use ruff_text_size::{Ranged, TextSize}; + +// This module implements the core functionality of the "references" and +// "rename" language server features. It locates all references to a named +// symbol. Unlike a simple text search for the symbol's name, this is +// a "semantic search" where the text and the semantic meaning must match. +// +// Some symbols (such as parameters and local variables) are visible only +// within their scope. All other symbols, such as those defined at the global +// scope or within classes, are visible outside of the module. Finding +// all references to these externally-visible symbols therefore requires +// an expensive search of all source files in the workspace. + +/// Find all references to a symbol at the given position. +pub fn references( + db: &dyn Db, + file: File, + offset: TextSize, + include_declaration: bool, +) -> Option>> { + let parsed = ruff_db::parsed::parsed_module(db, file); + let module = parsed.load(db); + + // Get the definitions for the symbol at the cursor position + let goto_target = find_goto_target(&module, offset)?; + let target_definitions = goto_target.get_definition_targets(file, db, None)?; + + // Extract the target text from the goto target for fast comparison + let target_text = goto_target.as_str()?; + + // Find all of the references to the symbol within this file + find_local_references( + db, + file, + &target_definitions, + include_declaration, + &target_text, + ) + + // TODO: Determine whether the symbol is potentially visible outside + // of this module. + + // TODO: For symbols that are potentially visible outside of this + // module, we need to look for references in all other files within + // the workspace. + + // TODO: Eliminate the need to parse every file by first doing a simple + // text context search to see if there is a potential match in the file. +} + +/// Find all references to a local symbol within the current file. If +/// `include_declaration` is true, return the original declaration for symbols +/// such as functions or classes that have a single declaration location. +fn find_local_references( + db: &dyn Db, + file: File, + target_definitions: &NavigationTargets, + include_declaration: bool, + target_text: &str, +) -> Option>> { + let parsed = ruff_db::parsed::parsed_module(db, file); + let module = parsed.load(db); + + let mut finder = LocalReferencesFinder { + db, + file, + target_definitions, + references: Vec::new(), + include_declaration, + module: &module, + target_text, + }; + + finder.visit_body(&module.syntax().body); + + if finder.references.is_empty() { + None + } else { + Some(finder.references) + } +} + +/// AST visitor to find all references to a specific symbol by comparing semantic definitions +struct LocalReferencesFinder<'a> { + db: &'a dyn Db, + file: File, + target_definitions: &'a NavigationTargets, + references: Vec>, + include_declaration: bool, + module: &'a ruff_db::parsed::ParsedModuleRef, + target_text: &'a str, +} + +impl<'a> Visitor<'a> for LocalReferencesFinder<'a> { + fn visit_expr(&mut self, expr: &'a ast::Expr) { + match expr { + ast::Expr::Name(name_expr) => { + self.check_reference_at_offset(name_expr.range.start(), &name_expr.id); + } + ast::Expr::Call(call_expr) => { + // Handle keyword arguments in call expressions + for keyword in &call_expr.arguments.keywords { + if let Some(arg) = &keyword.arg { + self.check_reference_at_offset(arg.range.start(), &arg.id); + } + } + } + _ => {} + } + ast::visitor::walk_expr(self, expr); + } + + fn visit_stmt(&mut self, stmt: &'a ast::Stmt) { + match stmt { + ast::Stmt::FunctionDef(func) => { + if self.include_declaration { + self.check_reference_at_offset(func.name.range.start(), &func.name.id); + } + } + ast::Stmt::ClassDef(class) => { + if self.include_declaration { + self.check_reference_at_offset(class.name.range.start(), &class.name.id); + } + } + ast::Stmt::Nonlocal(nonlocal_stmt) => { + for name in &nonlocal_stmt.names { + self.check_reference_at_offset(name.range.start(), &name.id); + } + } + ast::Stmt::Global(global_stmt) => { + for name in &global_stmt.names { + self.check_reference_at_offset(name.range.start(), &name.id); + } + } + _ => {} + } + ast::visitor::walk_stmt(self, stmt); + } + + fn visit_parameter(&mut self, parameter: &'a ast::Parameter) { + if self.include_declaration { + self.check_reference_at_offset(parameter.name.range.start(), ¶meter.name.id); + } + ast::visitor::walk_parameter(self, parameter); + } + + fn visit_except_handler(&mut self, except_handler: &'a ast::ExceptHandler) { + match except_handler { + ast::ExceptHandler::ExceptHandler(handler) => { + if let Some(name) = &handler.name { + self.check_reference_at_offset(name.range.start(), &name.id); + } + } + } + ast::visitor::walk_except_handler(self, except_handler); + } + + fn visit_pattern(&mut self, pattern: &'a ast::Pattern) { + match pattern { + ast::Pattern::MatchAs(pattern_as) => { + if let Some(name) = &pattern_as.name { + self.check_reference_at_offset(name.range.start(), &name.id); + } + } + ast::Pattern::MatchMapping(pattern_mapping) => { + if let Some(rest_name) = &pattern_mapping.rest { + self.check_reference_at_offset(rest_name.range.start(), &rest_name.id); + } + } + _ => {} + } + ast::visitor::walk_pattern(self, pattern); + } +} + +impl LocalReferencesFinder<'_> { + /// Determines whether the node at the specified offset is a reference to + /// the symbol we are searching for + fn check_reference_at_offset(&mut self, offset: TextSize, text: &str) { + // Quick text-based check first - compare with our target text + if text != self.target_text { + return; + } + + // Use find_goto_target to get the GotoTarget for this identifier + if let Some(goto_target) = find_goto_target(self.module, offset) { + let range = goto_target.range(); + + // Get the definitions for this goto target + if let Some(current_definitions) = + goto_target.get_definition_targets(self.file, self.db, None) + { + // Check if any of the current definitions match our target definitions + if self.navigation_targets_match(¤t_definitions) { + let target = NavigationTarget { + file: self.file, + focus_range: range, + full_range: range, + }; + self.references.push(RangedValue { + value: NavigationTargets::single(target), + range: FileRange::new(self.file, range), + }); + } + } + } + } +} + +impl LocalReferencesFinder<'_> { + /// Check if `NavigationTargets` match our target definitions + fn navigation_targets_match(&self, current_targets: &NavigationTargets) -> bool { + // Since we're comparing the same symbol, all definitions should be equivalent + // We only need to check against the first target definition + if let Some(first_target) = self.target_definitions.iter().next() { + for current_target in current_targets { + if current_target.file == first_target.file + && current_target.focus_range == first_target.focus_range + { + return true; + } + } + } + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::{CursorTest, IntoDiagnostic, cursor_test}; + use insta::assert_snapshot; + use ruff_db::diagnostic::{Annotation, Diagnostic, DiagnosticId, LintName, Severity, Span}; + use ruff_text_size::Ranged; + + impl CursorTest { + fn references(&self) -> String { + let Some(reference_results) = + references(&self.db, self.cursor.file, self.cursor.offset, true) + else { + return "No references found".to_string(); + }; + + if reference_results.is_empty() { + return "No references found".to_string(); + } + + self.render_diagnostics(reference_results.into_iter().enumerate().map( + |(i, ref_item)| -> ReferenceResult { + ReferenceResult { + index: i, + file_range: ref_item.range, + } + }, + )) + } + } + + struct ReferenceResult { + index: usize, + file_range: FileRange, + } + + impl IntoDiagnostic for ReferenceResult { + fn into_diagnostic(self) -> Diagnostic { + let mut main = Diagnostic::new( + DiagnosticId::Lint(LintName::of("references")), + Severity::Info, + format!("Reference {}", self.index + 1), + ); + main.annotate(Annotation::primary( + Span::from(self.file_range.file()).with_range(self.file_range.range()), + )); + + main + } + } + + #[test] + fn test_parameter_references_in_function() { + let test = cursor_test( + " +def calculate_sum(value: int) -> int: + doubled = value * 2 + result = value + doubled + return value + +# Call with keyword argument +result = calculate_sum(value=42) +", + ); + + assert_snapshot!(test.references(), @r###" + info[references]: Reference 1 + --> main.py:2:19 + | + 2 | def calculate_sum(value: int) -> int: + | ^^^^^ + 3 | doubled = value * 2 + 4 | result = value + doubled + | + + info[references]: Reference 2 + --> main.py:3:15 + | + 2 | def calculate_sum(value: int) -> int: + 3 | doubled = value * 2 + | ^^^^^ + 4 | result = value + doubled + 5 | return value + | + + info[references]: Reference 3 + --> main.py:4:14 + | + 2 | def calculate_sum(value: int) -> int: + 3 | doubled = value * 2 + 4 | result = value + doubled + | ^^^^^ + 5 | return value + | + + info[references]: Reference 4 + --> main.py:5:12 + | + 3 | doubled = value * 2 + 4 | result = value + doubled + 5 | return value + | ^^^^^ + 6 | + 7 | # Call with keyword argument + | + + info[references]: Reference 5 + --> main.py:8:24 + | + 7 | # Call with keyword argument + 8 | result = calculate_sum(value=42) + | ^^^^^ + | + "###); + } + + #[test] + #[ignore] // TODO: Enable when nonlocal support is fully implemented in goto.rs + fn test_nonlocal_variable_references() { + let test = cursor_test( + " +def outer_function(): + counter = 0 + + def increment(): + nonlocal counter + counter += 1 + return counter + + def decrement(): + nonlocal counter + counter -= 1 + return counter + + # Use counter in outer scope + initial = counter + increment() + decrement() + final = counter + + return increment, decrement +", + ); + + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> main.py:3:5 + | + 2 | def outer_function(): + 3 | counter = 0 + | ^^^^^^^ + 4 | + 5 | def increment(): + | + + info[references]: Reference 2 + --> main.py:6:18 + | + 5 | def increment(): + 6 | nonlocal counter + | ^^^^^^^ + 7 | counter += 1 + 8 | return counter + | + + info[references]: Reference 3 + --> main.py:7:9 + | + 5 | def increment(): + 6 | nonlocal counter + 7 | counter += 1 + | ^^^^^^^ + 8 | return counter + | + + info[references]: Reference 4 + --> main.py:8:16 + | + 6 | nonlocal counter + 7 | counter += 1 + 8 | return counter + | ^^^^^^^ + 9 | + 10 | def decrement(): + | + + info[references]: Reference 5 + --> main.py:11:18 + | + 10 | def decrement(): + 11 | nonlocal counter + | ^^^^^^^ + 12 | counter -= 1 + 13 | return counter + | + + info[references]: Reference 6 + --> main.py:12:9 + | + 10 | def decrement(): + 11 | nonlocal counter + 12 | counter -= 1 + | ^^^^^^^ + 13 | return counter + | + + info[references]: Reference 7 + --> main.py:13:16 + | + 11 | nonlocal counter + 12 | counter -= 1 + 13 | return counter + | ^^^^^^^ + 14 | + 15 | # Use counter in outer scope + | + + info[references]: Reference 8 + --> main.py:16:15 + | + 15 | # Use counter in outer scope + 16 | initial = counter + | ^^^^^^^ + 17 | increment() + 18 | decrement() + | + + info[references]: Reference 9 + --> main.py:19:13 + | + 17 | increment() + 18 | decrement() + 19 | final = counter + | ^^^^^^^ + 20 | + 21 | return increment, decrement + | + "); + } + + #[test] + #[ignore] // TODO: Enable when global support is fully implemented in goto.rs + fn test_global_variable_references() { + let test = cursor_test( + " +global_counter = 0 + +def increment_global(): + global global_counter + global_counter += 1 + return global_counter + +def decrement_global(): + global global_counter + global_counter -= 1 + return global_counter + +# Use global_counter at module level +initial_value = global_counter +increment_global() +decrement_global() +final_value = global_counter +", + ); + + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> main.py:2:1 + | + 2 | global_counter = 0 + | ^^^^^^^^^^^^^^ + 3 | + 4 | def increment_global(): + | + + info[references]: Reference 2 + --> main.py:5:12 + | + 4 | def increment_global(): + 5 | global global_counter + | ^^^^^^^^^^^^^^ + 6 | global_counter += 1 + 7 | return global_counter + | + + info[references]: Reference 3 + --> main.py:6:5 + | + 4 | def increment_global(): + 5 | global global_counter + 6 | global_counter += 1 + | ^^^^^^^^^^^^^^ + 7 | return global_counter + | + + info[references]: Reference 4 + --> main.py:7:12 + | + 5 | global global_counter + 6 | global_counter += 1 + 7 | return global_counter + | ^^^^^^^^^^^^^^ + 8 | + 9 | def decrement_global(): + | + + info[references]: Reference 5 + --> main.py:10:12 + | + 9 | def decrement_global(): + 10 | global global_counter + | ^^^^^^^^^^^^^^ + 11 | global_counter -= 1 + 12 | return global_counter + | + + info[references]: Reference 6 + --> main.py:11:5 + | + 9 | def decrement_global(): + 10 | global global_counter + 11 | global_counter -= 1 + | ^^^^^^^^^^^^^^ + 12 | return global_counter + | + + info[references]: Reference 7 + --> main.py:12:12 + | + 10 | global global_counter + 11 | global_counter -= 1 + 12 | return global_counter + | ^^^^^^^^^^^^^^ + 13 | + 14 | # Use global_counter at module level + | + + info[references]: Reference 8 + --> main.py:15:17 + | + 14 | # Use global_counter at module level + 15 | initial_value = global_counter + | ^^^^^^^^^^^^^^ + 16 | increment_global() + 17 | decrement_global() + | + + info[references]: Reference 9 + --> main.py:18:15 + | + 16 | increment_global() + 17 | decrement_global() + 18 | final_value = global_counter + | ^^^^^^^^^^^^^^ + | + "); + } + + #[test] + fn test_except_handler_variable_references() { + let test = cursor_test( + " +try: + x = 1 / 0 +except ZeroDivisionError as err: + print(f'Error: {err}') + return err + +try: + y = 2 / 0 +except ValueError as err: + print(f'Different error: {err}') +", + ); + + // Note: Currently only finds the declaration, not the usages + // This is because semantic analysis for except handler variables isn't fully implemented + assert_snapshot!(test.references(), @r###" + info[references]: Reference 1 + --> main.py:4:29 + | + 2 | try: + 3 | x = 1 / 0 + 4 | except ZeroDivisionError as err: + | ^^^ + 5 | print(f'Error: {err}') + 6 | return err + | + "###); + } + + #[test] + fn test_pattern_match_as_references() { + let test = cursor_test( + " +match x: + case [a, b] as pattern: + print(f'Matched: {pattern}') + return pattern + case _: + pass +", + ); + + assert_snapshot!(test.references(), @r###" + info[references]: Reference 1 + --> main.py:3:20 + | + 2 | match x: + 3 | case [a, b] as pattern: + | ^^^^^^^ + 4 | print(f'Matched: {pattern}') + 5 | return pattern + | + + info[references]: Reference 2 + --> main.py:4:27 + | + 2 | match x: + 3 | case [a, b] as pattern: + 4 | print(f'Matched: {pattern}') + | ^^^^^^^ + 5 | return pattern + 6 | case _: + | + + info[references]: Reference 3 + --> main.py:5:16 + | + 3 | case [a, b] as pattern: + 4 | print(f'Matched: {pattern}') + 5 | return pattern + | ^^^^^^^ + 6 | case _: + 7 | pass + | + "###); + } + + #[test] + fn test_pattern_match_mapping_rest_references() { + let test = cursor_test( + " +match data: + case {'a': a, 'b': b, **rest}: + print(f'Rest data: {rest}') + process(rest) + return rest +", + ); + + assert_snapshot!(test.references(), @r###" + info[references]: Reference 1 + --> main.py:3:29 + | + 2 | match data: + 3 | case {'a': a, 'b': b, **rest}: + | ^^^^ + 4 | print(f'Rest data: {rest}') + 5 | process(rest) + | + + info[references]: Reference 2 + --> main.py:4:29 + | + 2 | match data: + 3 | case {'a': a, 'b': b, **rest}: + 4 | print(f'Rest data: {rest}') + | ^^^^ + 5 | process(rest) + 6 | return rest + | + + info[references]: Reference 3 + --> main.py:5:17 + | + 3 | case {'a': a, 'b': b, **rest}: + 4 | print(f'Rest data: {rest}') + 5 | process(rest) + | ^^^^ + 6 | return rest + | + + info[references]: Reference 4 + --> main.py:6:16 + | + 4 | print(f'Rest data: {rest}') + 5 | process(rest) + 6 | return rest + | ^^^^ + | + "###); + } + + #[test] + fn test_function_definition_references() { + let test = cursor_test( + " +def my_function(): + return 42 + +# Call the function multiple times +result1 = my_function() +result2 = my_function() + +# Function passed as an argument +callback = my_function + +# Function used in different contexts +print(my_function()) +value = my_function +", + ); + + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> main.py:2:5 + | + 2 | def my_function(): + | ^^^^^^^^^^^ + 3 | return 42 + | + + info[references]: Reference 2 + --> main.py:6:11 + | + 5 | # Call the function multiple times + 6 | result1 = my_function() + | ^^^^^^^^^^^ + 7 | result2 = my_function() + | + + info[references]: Reference 3 + --> main.py:7:11 + | + 5 | # Call the function multiple times + 6 | result1 = my_function() + 7 | result2 = my_function() + | ^^^^^^^^^^^ + 8 | + 9 | # Function passed as an argument + | + + info[references]: Reference 4 + --> main.py:10:12 + | + 9 | # Function passed as an argument + 10 | callback = my_function + | ^^^^^^^^^^^ + 11 | + 12 | # Function used in different contexts + | + + info[references]: Reference 5 + --> main.py:13:7 + | + 12 | # Function used in different contexts + 13 | print(my_function()) + | ^^^^^^^^^^^ + 14 | value = my_function + | + + info[references]: Reference 6 + --> main.py:14:9 + | + 12 | # Function used in different contexts + 13 | print(my_function()) + 14 | value = my_function + | ^^^^^^^^^^^ + | + "); + } + + #[test] + fn test_class_definition_references() { + let test = cursor_test( + " +class MyClass: + def __init__(self): + pass + +# Create instances +obj1 = MyClass() +obj2 = MyClass() + +# Use in type annotations +def process(instance: MyClass) -> MyClass: + return instance + +# Reference the class itself +cls = MyClass +", + ); + + assert_snapshot!(test.references(), @r" + info[references]: Reference 1 + --> main.py:2:7 + | + 2 | class MyClass: + | ^^^^^^^ + 3 | def __init__(self): + 4 | pass + | + + info[references]: Reference 2 + --> main.py:7:8 + | + 6 | # Create instances + 7 | obj1 = MyClass() + | ^^^^^^^ + 8 | obj2 = MyClass() + | + + info[references]: Reference 3 + --> main.py:8:8 + | + 6 | # Create instances + 7 | obj1 = MyClass() + 8 | obj2 = MyClass() + | ^^^^^^^ + 9 | + 10 | # Use in type annotations + | + + info[references]: Reference 4 + --> main.py:11:23 + | + 10 | # Use in type annotations + 11 | def process(instance: MyClass) -> MyClass: + | ^^^^^^^ + 12 | return instance + | + + info[references]: Reference 5 + --> main.py:11:35 + | + 10 | # Use in type annotations + 11 | def process(instance: MyClass) -> MyClass: + | ^^^^^^^ + 12 | return instance + | + + info[references]: Reference 6 + --> main.py:15:7 + | + 14 | # Reference the class itself + 15 | cls = MyClass + | ^^^^^^^ + | + "); + } +} diff --git a/crates/ty_server/src/server.rs b/crates/ty_server/src/server.rs index 03b7b8326556f..fd02b0d22577d 100644 --- a/crates/ty_server/src/server.rs +++ b/crates/ty_server/src/server.rs @@ -196,6 +196,7 @@ impl Server { type_definition_provider: Some(TypeDefinitionProviderCapability::Simple(true)), definition_provider: Some(lsp_types::OneOf::Left(true)), declaration_provider: Some(DeclarationCapability::Simple(true)), + references_provider: Some(lsp_types::OneOf::Left(true)), hover_provider: Some(HoverProviderCapability::Simple(true)), signature_help_provider: Some(SignatureHelpOptions { trigger_characters: Some(vec!["(".to_string(), ",".to_string()]), diff --git a/crates/ty_server/src/server/api.rs b/crates/ty_server/src/server/api.rs index 41d14cbf15523..062fc0e10373e 100644 --- a/crates/ty_server/src/server/api.rs +++ b/crates/ty_server/src/server/api.rs @@ -56,6 +56,9 @@ pub(super) fn request(req: server::Request) -> Task { requests::HoverRequestHandler::METHOD => background_document_request_task::< requests::HoverRequestHandler, >(req, BackgroundSchedule::Worker), + requests::ReferencesRequestHandler::METHOD => background_document_request_task::< + requests::ReferencesRequestHandler, + >(req, BackgroundSchedule::Worker), requests::InlayHintRequestHandler::METHOD => background_document_request_task::< requests::InlayHintRequestHandler, >(req, BackgroundSchedule::Worker), diff --git a/crates/ty_server/src/server/api/requests.rs b/crates/ty_server/src/server/api/requests.rs index 89a3eff61524e..ff7772c35daaf 100644 --- a/crates/ty_server/src/server/api/requests.rs +++ b/crates/ty_server/src/server/api/requests.rs @@ -5,6 +5,7 @@ mod goto_definition; mod goto_type_definition; mod hover; mod inlay_hints; +mod references; mod semantic_tokens; mod semantic_tokens_range; mod shutdown; @@ -18,6 +19,7 @@ pub(super) use goto_definition::GotoDefinitionRequestHandler; pub(super) use goto_type_definition::GotoTypeDefinitionRequestHandler; pub(super) use hover::HoverRequestHandler; pub(super) use inlay_hints::InlayHintRequestHandler; +pub(super) use references::ReferencesRequestHandler; pub(super) use semantic_tokens::SemanticTokensRequestHandler; pub(super) use semantic_tokens_range::SemanticTokensRangeRequestHandler; pub(super) use shutdown::ShutdownHandler; diff --git a/crates/ty_server/src/server/api/requests/references.rs b/crates/ty_server/src/server/api/requests/references.rs new file mode 100644 index 0000000000000..63c713ac3ed1b --- /dev/null +++ b/crates/ty_server/src/server/api/requests/references.rs @@ -0,0 +1,65 @@ +use std::borrow::Cow; + +use lsp_types::request::References; +use lsp_types::{Location, ReferenceParams, Url}; +use ruff_db::source::{line_index, source_text}; +use ty_ide::references; +use ty_project::ProjectDatabase; + +use crate::document::{PositionExt, ToLink}; +use crate::server::api::traits::{ + BackgroundDocumentRequestHandler, RequestHandler, RetriableRequestHandler, +}; +use crate::session::DocumentSnapshot; +use crate::session::client::Client; + +pub(crate) struct ReferencesRequestHandler; + +impl RequestHandler for ReferencesRequestHandler { + type RequestType = References; +} + +impl BackgroundDocumentRequestHandler for ReferencesRequestHandler { + fn document_url(params: &ReferenceParams) -> Cow { + Cow::Borrowed(¶ms.text_document_position.text_document.uri) + } + + fn run_with_snapshot( + db: &ProjectDatabase, + snapshot: DocumentSnapshot, + _client: &Client, + params: ReferenceParams, + ) -> crate::server::Result>> { + if snapshot.client_settings().is_language_services_disabled() { + return Ok(None); + } + + let Some(file) = snapshot.file(db) else { + return Ok(None); + }; + + let source = source_text(db, file); + let line_index = line_index(db, file); + let offset = params.text_document_position.position.to_text_size( + &source, + &line_index, + snapshot.encoding(), + ); + + let include_declaration = params.context.include_declaration; + + let Some(references_result) = references(db, file, offset, include_declaration) else { + return Ok(None); + }; + + let locations: Vec<_> = references_result + .into_iter() + .flat_map(|ranged| ranged.value.into_iter()) + .filter_map(|target| target.to_location(db, snapshot.encoding())) + .collect(); + + Ok(Some(locations)) + } +} + +impl RetriableRequestHandler for ReferencesRequestHandler {} From ea5d0ce361bbb9e496d7d72e5c440f80c9a6b9b2 Mon Sep 17 00:00:00 2001 From: UnboundVariable Date: Tue, 22 Jul 2025 14:11:04 -0700 Subject: [PATCH 02/10] Code review feedback. --- crates/ty_ide/src/goto.rs | 95 +++++++++++++++------------------ crates/ty_ide/src/lib.rs | 9 ++++ crates/ty_ide/src/references.rs | 32 +++++------ 3 files changed, 66 insertions(+), 70 deletions(-) diff --git a/crates/ty_ide/src/goto.rs b/crates/ty_ide/src/goto.rs index bb4f0d5254418..feb9403d309df 100644 --- a/crates/ty_ide/src/goto.rs +++ b/crates/ty_ide/src/goto.rs @@ -2,6 +2,8 @@ pub use crate::goto_declaration::goto_declaration; pub use crate::goto_definition::goto_definition; pub use crate::goto_type_definition::goto_type_definition; +use std::borrow::Cow; + use crate::find_node::covering_node; use crate::stub_mapping::StubMapper; use ruff_db::parsed::ParsedModuleRef; @@ -274,11 +276,9 @@ impl GotoTarget<'_> { GotoTarget::ExceptVariable(except_handler) => { if let Some(name) = &except_handler.name { let range = name.range; - Some(crate::NavigationTargets::single(NavigationTarget { - file, - focus_range: range, - full_range: range, - })) + Some(crate::NavigationTargets::single(NavigationTarget::new( + file, range, + ))) } else { None } @@ -288,11 +288,9 @@ impl GotoTarget<'_> { GotoTarget::PatternMatchRest(pattern_mapping) => { if let Some(rest_name) = &pattern_mapping.rest { let range = rest_name.range; - Some(crate::NavigationTargets::single(NavigationTarget { - file, - focus_range: range, - full_range: range, - })) + Some(crate::NavigationTargets::single(NavigationTarget::new( + file, range, + ))) } else { None } @@ -302,11 +300,9 @@ impl GotoTarget<'_> { GotoTarget::PatternMatchAsName(pattern_as) => { if let Some(name) = &pattern_as.name { let range = name.range; - Some(crate::NavigationTargets::single(NavigationTarget { - file, - focus_range: range, - full_range: range, - })) + Some(crate::NavigationTargets::single(NavigationTarget::new( + file, range, + ))) } else { None } @@ -321,21 +317,21 @@ impl GotoTarget<'_> { /// Returns `None` if no meaningful string representation can be provided. /// This is used by the "references" feature, which looks for references /// to this goto target. - pub(crate) fn as_str(&self) -> Option { + pub(crate) fn to_string(&self) -> Option> { match self { GotoTarget::Expression(expression) => match expression { - ast::ExprRef::Name(name) => Some(name.id.as_str().to_string()), - ast::ExprRef::Attribute(attr) => Some(attr.attr.as_str().to_string()), + ast::ExprRef::Name(name) => Some(Cow::Borrowed(name.id.as_str())), + ast::ExprRef::Attribute(attr) => Some(Cow::Borrowed(attr.attr.as_str())), _ => None, }, - GotoTarget::FunctionDef(function) => Some(function.name.as_str().to_string()), - GotoTarget::ClassDef(class) => Some(class.name.as_str().to_string()), - GotoTarget::Parameter(parameter) => Some(parameter.name.as_str().to_string()), + GotoTarget::FunctionDef(function) => Some(Cow::Borrowed(function.name.as_str())), + GotoTarget::ClassDef(class) => Some(Cow::Borrowed(class.name.as_str())), + GotoTarget::Parameter(parameter) => Some(Cow::Borrowed(parameter.name.as_str())), GotoTarget::ImportSymbolAlias { alias, .. } => { if let Some(asname) = &alias.asname { - Some(asname.as_str().to_string()) + Some(Cow::Borrowed(asname.as_str())) } else { - Some(alias.name.as_str().to_string()) + Some(Cow::Borrowed(alias.name.as_str())) } } GotoTarget::ImportModuleComponent { @@ -345,40 +341,43 @@ impl GotoTarget<'_> { } => { let components: Vec<&str> = module_name.split('.').collect(); if let Some(component) = components.get(*component_index) { - Some((*component).to_string()) + Some(Cow::Borrowed(*component)) } else { - Some(module_name.clone()) + Some(Cow::Borrowed(module_name)) } } GotoTarget::ImportModuleAlias { alias } => { if let Some(asname) = &alias.asname { - Some(asname.as_str().to_string()) + Some(Cow::Borrowed(asname.as_str())) } else { - Some(alias.name.as_str().to_string()) + Some(Cow::Borrowed(alias.name.as_str())) } } GotoTarget::ExceptVariable(except) => { - except.name.as_ref().map(|name| name.as_str().to_string()) + Some(Cow::Borrowed(except.name.as_ref()?.as_str())) } GotoTarget::KeywordArgument { keyword, .. } => { - keyword.arg.as_ref().map(|arg| arg.as_str().to_string()) + Some(Cow::Borrowed(keyword.arg.as_ref()?.as_str())) + } + GotoTarget::PatternMatchRest(rest) => Some(Cow::Borrowed(rest.rest.as_ref()?.as_str())), + GotoTarget::PatternKeywordArgument(keyword) => { + Some(Cow::Borrowed(keyword.attr.as_str())) } - GotoTarget::PatternMatchRest(rest) => rest - .rest - .as_ref() - .map(|rest_name| rest_name.as_str().to_string()), - GotoTarget::PatternKeywordArgument(keyword) => Some(keyword.attr.as_str().to_string()), GotoTarget::PatternMatchStarName(star) => { - star.name.as_ref().map(|name| name.as_str().to_string()) + Some(Cow::Borrowed(star.name.as_ref()?.as_str())) } GotoTarget::PatternMatchAsName(as_name) => { - as_name.name.as_ref().map(|name| name.as_str().to_string()) + Some(Cow::Borrowed(as_name.name.as_ref()?.as_str())) + } + GotoTarget::TypeParamTypeVarName(type_var) => { + Some(Cow::Borrowed(type_var.name.as_str())) } - GotoTarget::TypeParamTypeVarName(type_var) => Some(type_var.name.as_str().to_string()), - GotoTarget::TypeParamParamSpecName(spec) => Some(spec.name.as_str().to_string()), - GotoTarget::TypeParamTypeVarTupleName(tuple) => Some(tuple.name.as_str().to_string()), - GotoTarget::NonLocal { identifier, .. } => Some(identifier.as_str().to_string()), - GotoTarget::Globals { identifier, .. } => Some(identifier.as_str().to_string()), + GotoTarget::TypeParamParamSpecName(spec) => Some(Cow::Borrowed(spec.name.as_str())), + GotoTarget::TypeParamTypeVarTupleName(tuple) => { + Some(Cow::Borrowed(tuple.name.as_str())) + } + GotoTarget::NonLocal { identifier, .. } => Some(Cow::Borrowed(identifier.as_str())), + GotoTarget::Globals { identifier, .. } => Some(Cow::Borrowed(identifier.as_str())), } } } @@ -435,11 +434,7 @@ fn convert_resolved_definitions_to_targets( } ty_python_semantic::ResolvedDefinition::FileWithRange(file_range) => { // For file ranges, navigate to the specific range within the file - crate::NavigationTarget { - file: file_range.file(), - focus_range: file_range.range(), - full_range: file_range.range(), - } + crate::NavigationTarget::new(file_range.file(), file_range.range()) } }) .collect() @@ -633,11 +628,9 @@ fn resolve_module_to_navigation_target( if let Some(module_name) = ModuleName::new(module_name_str) { if let Some(resolved_module) = resolve_module(db, &module_name) { if let Some(module_file) = resolved_module.file() { - return Some(crate::NavigationTargets::single(crate::NavigationTarget { - file: module_file, - focus_range: TextRange::default(), - full_range: TextRange::default(), - })); + return Some(crate::NavigationTargets::single( + crate::NavigationTarget::new(module_file, TextRange::default()), + )); } } } diff --git a/crates/ty_ide/src/lib.rs b/crates/ty_ide/src/lib.rs index 8c63eff15d4b0..455724589aa86 100644 --- a/crates/ty_ide/src/lib.rs +++ b/crates/ty_ide/src/lib.rs @@ -89,6 +89,15 @@ pub struct NavigationTarget { } impl NavigationTarget { + /// Creates a new `NavigationTarget` where the focus and full range are identical. + pub fn new(file: File, range: TextRange) -> Self { + Self { + file, + focus_range: range, + full_range: range, + } + } + pub fn file(&self) -> File { self.file } diff --git a/crates/ty_ide/src/references.rs b/crates/ty_ide/src/references.rs index ec3b4e574c2fa..c5851e6a8711a 100644 --- a/crates/ty_ide/src/references.rs +++ b/crates/ty_ide/src/references.rs @@ -1,20 +1,20 @@ +//! This module implements the core functionality of the "references" and +//! "rename" language server features. It locates all references to a named +//! symbol. Unlike a simple text search for the symbol's name, this is +//! a "semantic search" where the text and the semantic meaning must match. +//! +//! Some symbols (such as parameters and local variables) are visible only +//! within their scope. All other symbols, such as those defined at the global +//! scope or within classes, are visible outside of the module. Finding +//! all references to these externally-visible symbols therefore requires +//! an expensive search of all source files in the workspace. + use crate::goto::find_goto_target; use crate::{Db, NavigationTarget, NavigationTargets, RangedValue}; use ruff_db::files::{File, FileRange}; use ruff_python_ast::{self as ast, visitor::Visitor}; use ruff_text_size::{Ranged, TextSize}; -// This module implements the core functionality of the "references" and -// "rename" language server features. It locates all references to a named -// symbol. Unlike a simple text search for the symbol's name, this is -// a "semantic search" where the text and the semantic meaning must match. -// -// Some symbols (such as parameters and local variables) are visible only -// within their scope. All other symbols, such as those defined at the global -// scope or within classes, are visible outside of the module. Finding -// all references to these externally-visible symbols therefore requires -// an expensive search of all source files in the workspace. - /// Find all references to a symbol at the given position. pub fn references( db: &dyn Db, @@ -30,7 +30,7 @@ pub fn references( let target_definitions = goto_target.get_definition_targets(file, db, None)?; // Extract the target text from the goto target for fast comparison - let target_text = goto_target.as_str()?; + let target_text = goto_target.to_string()?; // Find all of the references to the symbol within this file find_local_references( @@ -196,11 +196,7 @@ impl LocalReferencesFinder<'_> { { // Check if any of the current definitions match our target definitions if self.navigation_targets_match(¤t_definitions) { - let target = NavigationTarget { - file: self.file, - focus_range: range, - full_range: range, - }; + let target = NavigationTarget::new(self.file, range); self.references.push(RangedValue { value: NavigationTargets::single(target), range: FileRange::new(self.file, range), @@ -209,9 +205,7 @@ impl LocalReferencesFinder<'_> { } } } -} -impl LocalReferencesFinder<'_> { /// Check if `NavigationTargets` match our target definitions fn navigation_targets_match(&self, current_targets: &NavigationTargets) -> bool { // Since we're comparing the same symbol, all definitions should be equivalent From 155505ac2235d709012ec86f6b9e46aebf380a9f Mon Sep 17 00:00:00 2001 From: UnboundVariable Date: Tue, 22 Jul 2025 17:19:03 -0700 Subject: [PATCH 03/10] Code review feedback. --- .../src/visitor/source_order.rs | 7 +- crates/ty_ide/src/find_node.rs | 10 +- crates/ty_ide/src/goto.rs | 296 ++++++++++-------- crates/ty_ide/src/references.rs | 144 ++++----- 4 files changed, 238 insertions(+), 219 deletions(-) diff --git a/crates/ruff_python_ast/src/visitor/source_order.rs b/crates/ruff_python_ast/src/visitor/source_order.rs index 9e42766f0d4ac..2caaab9f9a3fe 100644 --- a/crates/ruff_python_ast/src/visitor/source_order.rs +++ b/crates/ruff_python_ast/src/visitor/source_order.rs @@ -235,12 +235,7 @@ impl TraversalSignal { } pub fn walk_annotation<'a, V: SourceOrderVisitor<'a> + ?Sized>(visitor: &mut V, expr: &'a Expr) { - let node = AnyNodeRef::from(expr); - if visitor.enter_node(node).is_traverse() { - visitor.visit_expr(expr); - } - - visitor.leave_node(node); + visitor.visit_expr(expr); } pub fn walk_decorator<'a, V>(visitor: &mut V, decorator: &'a Decorator) diff --git a/crates/ty_ide/src/find_node.rs b/crates/ty_ide/src/find_node.rs index 26c4b489d1596..a56be5f7b957e 100644 --- a/crates/ty_ide/src/find_node.rs +++ b/crates/ty_ide/src/find_node.rs @@ -52,9 +52,7 @@ pub(crate) fn covering_node(root: AnyNodeRef, range: TextRange) -> CoveringNode if visitor.ancestors.is_empty() { visitor.ancestors.push(root); } - CoveringNode { - nodes: visitor.ancestors, - } + CoveringNode::from_ancestors(visitor.ancestors) } /// The node with a minimal range that fully contains the search range. @@ -67,6 +65,12 @@ pub(crate) struct CoveringNode<'a> { } impl<'a> CoveringNode<'a> { + /// Creates a new `CoveringNode` from a list of ancestor nodes. + /// The ancestors should be ordered from root to the covering node. + pub(crate) fn from_ancestors(ancestors: Vec>) -> Self { + Self { nodes: ancestors } + } + /// Returns the covering node found. pub(crate) fn node(&self) -> AnyNodeRef<'a> { *self diff --git a/crates/ty_ide/src/goto.rs b/crates/ty_ide/src/goto.rs index feb9403d309df..3299b8929d176 100644 --- a/crates/ty_ide/src/goto.rs +++ b/crates/ty_ide/src/goto.rs @@ -380,6 +380,162 @@ impl GotoTarget<'_> { GotoTarget::Globals { identifier, .. } => Some(Cow::Borrowed(identifier.as_str())), } } + + /// Creates a `GotoTarget` from a `CoveringNode` and an offset within the node + pub(crate) fn from_covering_node<'a>( + covering_node: &crate::find_node::CoveringNode<'a>, + offset: TextSize, + ) -> Option> { + tracing::trace!("Covering node is of kind {:?}", covering_node.node().kind()); + + match covering_node.node() { + AnyNodeRef::Identifier(identifier) => match covering_node.parent() { + Some(AnyNodeRef::StmtFunctionDef(function)) => { + Some(GotoTarget::FunctionDef(function)) + } + Some(AnyNodeRef::StmtClassDef(class)) => Some(GotoTarget::ClassDef(class)), + Some(AnyNodeRef::Parameter(parameter)) => Some(GotoTarget::Parameter(parameter)), + Some(AnyNodeRef::Alias(alias)) => { + // Find the containing import statement to determine the type + let import_stmt = covering_node.ancestors().find(|node| { + matches!( + node, + AnyNodeRef::StmtImport(_) | AnyNodeRef::StmtImportFrom(_) + ) + }); + + match import_stmt { + Some(AnyNodeRef::StmtImport(_)) => { + // Regular import statement like "import x.y as z" + + // Is the offset within the alias name (asname) part? + if let Some(asname) = &alias.asname { + if asname.range.contains_inclusive(offset) { + return Some(GotoTarget::ImportModuleAlias { alias }); + } + } + + // Is the offset in the module name part? + if alias.name.range.contains_inclusive(offset) { + let full_name = alias.name.as_str(); + + if let Some((component_index, component_range)) = + find_module_component( + full_name, + alias.name.range.start(), + offset, + ) + { + return Some(GotoTarget::ImportModuleComponent { + module_name: full_name.to_string(), + component_index, + component_range, + }); + } + } + + None + } + Some(AnyNodeRef::StmtImportFrom(import_from)) => { + // From import statement like "from x import y as z" + + // Is the offset within the alias name (asname) part? + if let Some(asname) = &alias.asname { + if asname.range.contains_inclusive(offset) { + return Some(GotoTarget::ImportSymbolAlias { + alias, + range: asname.range, + import_from, + }); + } + } + + // Is the offset in the original name part? + if alias.name.range.contains_inclusive(offset) { + return Some(GotoTarget::ImportSymbolAlias { + alias, + range: alias.name.range, + import_from, + }); + } + + None + } + _ => None, + } + } + Some(AnyNodeRef::StmtImportFrom(from)) => { + // Handle offset within module name in from import statements + if let Some(module_expr) = &from.module { + let full_module_name = module_expr.to_string(); + + if let Some((component_index, component_range)) = find_module_component( + &full_module_name, + module_expr.range.start(), + offset, + ) { + return Some(GotoTarget::ImportModuleComponent { + module_name: full_module_name, + component_index, + component_range, + }); + } + } + + None + } + Some(AnyNodeRef::ExceptHandlerExceptHandler(handler)) => { + Some(GotoTarget::ExceptVariable(handler)) + } + Some(AnyNodeRef::Keyword(keyword)) => { + // Find the containing call expression from the ancestor chain + let call_expression = covering_node + .ancestors() + .find_map(ruff_python_ast::AnyNodeRef::expr_call)?; + Some(GotoTarget::KeywordArgument { + keyword, + call_expression, + }) + } + Some(AnyNodeRef::PatternMatchMapping(mapping)) => { + Some(GotoTarget::PatternMatchRest(mapping)) + } + Some(AnyNodeRef::PatternKeyword(keyword)) => { + Some(GotoTarget::PatternKeywordArgument(keyword)) + } + Some(AnyNodeRef::PatternMatchStar(star)) => { + Some(GotoTarget::PatternMatchStarName(star)) + } + Some(AnyNodeRef::PatternMatchAs(as_pattern)) => { + Some(GotoTarget::PatternMatchAsName(as_pattern)) + } + Some(AnyNodeRef::TypeParamTypeVar(var)) => { + Some(GotoTarget::TypeParamTypeVarName(var)) + } + Some(AnyNodeRef::TypeParamParamSpec(bound)) => { + Some(GotoTarget::TypeParamParamSpecName(bound)) + } + Some(AnyNodeRef::TypeParamTypeVarTuple(var_tuple)) => { + Some(GotoTarget::TypeParamTypeVarTupleName(var_tuple)) + } + Some(AnyNodeRef::ExprAttribute(attribute)) => { + Some(GotoTarget::Expression(attribute.into())) + } + Some(AnyNodeRef::StmtNonlocal(_)) => Some(GotoTarget::NonLocal { identifier }), + Some(AnyNodeRef::StmtGlobal(_)) => Some(GotoTarget::Globals { identifier }), + None => None, + Some(parent) => { + tracing::debug!( + "Missing `GoToTarget` for identifier with parent {:?}", + parent.kind() + ); + None + } + }, + + node => node.as_expr_ref().map(GotoTarget::Expression), + } + } } impl Ranged for GotoTarget<'_> { @@ -477,145 +633,7 @@ pub(crate) fn find_goto_target( .find_first(|node| node.is_identifier() || node.is_expression()) .ok()?; - tracing::trace!("Covering node is of kind {:?}", covering_node.node().kind()); - - match covering_node.node() { - AnyNodeRef::Identifier(identifier) => match covering_node.parent() { - Some(AnyNodeRef::StmtFunctionDef(function)) => Some(GotoTarget::FunctionDef(function)), - Some(AnyNodeRef::StmtClassDef(class)) => Some(GotoTarget::ClassDef(class)), - Some(AnyNodeRef::Parameter(parameter)) => Some(GotoTarget::Parameter(parameter)), - Some(AnyNodeRef::Alias(alias)) => { - // Find the containing import statement to determine the type - let import_stmt = covering_node.ancestors().find(|node| { - matches!( - node, - AnyNodeRef::StmtImport(_) | AnyNodeRef::StmtImportFrom(_) - ) - }); - - match import_stmt { - Some(AnyNodeRef::StmtImport(_)) => { - // Regular import statement like "import x.y as z" - - // Is the offset within the alias name (asname) part? - if let Some(asname) = &alias.asname { - if asname.range.contains_inclusive(offset) { - return Some(GotoTarget::ImportModuleAlias { alias }); - } - } - - // Is the offset in the module name part? - if alias.name.range.contains_inclusive(offset) { - let full_name = alias.name.as_str(); - - if let Some((component_index, component_range)) = - find_module_component(full_name, alias.name.range.start(), offset) - { - return Some(GotoTarget::ImportModuleComponent { - module_name: full_name.to_string(), - component_index, - component_range, - }); - } - } - - None - } - Some(AnyNodeRef::StmtImportFrom(import_from)) => { - // From import statement like "from x import y as z" - - // Is the offset within the alias name (asname) part? - if let Some(asname) = &alias.asname { - if asname.range.contains_inclusive(offset) { - return Some(GotoTarget::ImportSymbolAlias { - alias, - range: asname.range, - import_from, - }); - } - } - - // Is the offset in the original name part? - if alias.name.range.contains_inclusive(offset) { - return Some(GotoTarget::ImportSymbolAlias { - alias, - range: alias.name.range, - import_from, - }); - } - - None - } - _ => None, - } - } - Some(AnyNodeRef::StmtImportFrom(from)) => { - // Handle offset within module name in from import statements - if let Some(module_expr) = &from.module { - let full_module_name = module_expr.to_string(); - - if let Some((component_index, component_range)) = - find_module_component(&full_module_name, module_expr.range.start(), offset) - { - return Some(GotoTarget::ImportModuleComponent { - module_name: full_module_name, - component_index, - component_range, - }); - } - } - - None - } - Some(AnyNodeRef::ExceptHandlerExceptHandler(handler)) => { - Some(GotoTarget::ExceptVariable(handler)) - } - Some(AnyNodeRef::Keyword(keyword)) => { - // Find the containing call expression from the ancestor chain - let call_expression = covering_node - .ancestors() - .find_map(ruff_python_ast::AnyNodeRef::expr_call)?; - Some(GotoTarget::KeywordArgument { - keyword, - call_expression, - }) - } - Some(AnyNodeRef::PatternMatchMapping(mapping)) => { - Some(GotoTarget::PatternMatchRest(mapping)) - } - Some(AnyNodeRef::PatternKeyword(keyword)) => { - Some(GotoTarget::PatternKeywordArgument(keyword)) - } - Some(AnyNodeRef::PatternMatchStar(star)) => { - Some(GotoTarget::PatternMatchStarName(star)) - } - Some(AnyNodeRef::PatternMatchAs(as_pattern)) => { - Some(GotoTarget::PatternMatchAsName(as_pattern)) - } - Some(AnyNodeRef::TypeParamTypeVar(var)) => Some(GotoTarget::TypeParamTypeVarName(var)), - Some(AnyNodeRef::TypeParamParamSpec(bound)) => { - Some(GotoTarget::TypeParamParamSpecName(bound)) - } - Some(AnyNodeRef::TypeParamTypeVarTuple(var_tuple)) => { - Some(GotoTarget::TypeParamTypeVarTupleName(var_tuple)) - } - Some(AnyNodeRef::ExprAttribute(attribute)) => { - Some(GotoTarget::Expression(attribute.into())) - } - Some(AnyNodeRef::StmtNonlocal(_)) => Some(GotoTarget::NonLocal { identifier }), - Some(AnyNodeRef::StmtGlobal(_)) => Some(GotoTarget::Globals { identifier }), - None => None, - Some(parent) => { - tracing::debug!( - "Missing `GoToTarget` for identifier with parent {:?}", - parent.kind() - ); - None - } - }, - - node => node.as_expr_ref().map(GotoTarget::Expression), - } + GotoTarget::from_covering_node(&covering_node, offset) } /// Helper function to resolve a module name and create a navigation target. diff --git a/crates/ty_ide/src/references.rs b/crates/ty_ide/src/references.rs index c5851e6a8711a..95d8862c18c3f 100644 --- a/crates/ty_ide/src/references.rs +++ b/crates/ty_ide/src/references.rs @@ -9,10 +9,14 @@ //! all references to these externally-visible symbols therefore requires //! an expensive search of all source files in the workspace. -use crate::goto::find_goto_target; +use crate::find_node::CoveringNode; +use crate::goto::{GotoTarget, find_goto_target}; use crate::{Db, NavigationTarget, NavigationTargets, RangedValue}; use ruff_db::files::{File, FileRange}; -use ruff_python_ast::{self as ast, visitor::Visitor}; +use ruff_python_ast::{ + self as ast, AnyNodeRef, + visitor::source_order::{SourceOrderVisitor, TraversalSignal}, +}; use ruff_text_size::{Ranged, TextSize}; /// Find all references to a symbol at the given position. @@ -71,11 +75,11 @@ fn find_local_references( target_definitions, references: Vec::new(), include_declaration, - module: &module, target_text, + ancestors: Vec::new(), }; - finder.visit_body(&module.syntax().body); + AnyNodeRef::from(module.syntax()).visit_source_order(&mut finder); if finder.references.is_empty() { None @@ -91,103 +95,101 @@ struct LocalReferencesFinder<'a> { target_definitions: &'a NavigationTargets, references: Vec>, include_declaration: bool, - module: &'a ruff_db::parsed::ParsedModuleRef, target_text: &'a str, + ancestors: Vec>, } -impl<'a> Visitor<'a> for LocalReferencesFinder<'a> { - fn visit_expr(&mut self, expr: &'a ast::Expr) { - match expr { - ast::Expr::Name(name_expr) => { - self.check_reference_at_offset(name_expr.range.start(), &name_expr.id); - } - ast::Expr::Call(call_expr) => { - // Handle keyword arguments in call expressions - for keyword in &call_expr.arguments.keywords { - if let Some(arg) = &keyword.arg { - self.check_reference_at_offset(arg.range.start(), &arg.id); - } - } - } - _ => {} - } - ast::visitor::walk_expr(self, expr); - } +impl<'a> SourceOrderVisitor<'a> for LocalReferencesFinder<'a> { + fn enter_node(&mut self, node: AnyNodeRef<'a>) -> TraversalSignal { + self.ancestors.push(node); - fn visit_stmt(&mut self, stmt: &'a ast::Stmt) { - match stmt { - ast::Stmt::FunctionDef(func) => { - if self.include_declaration { - self.check_reference_at_offset(func.name.range.start(), &func.name.id); + match node { + AnyNodeRef::ExprName(name_expr) => { + // If the name doesn't match our target text, this isn't a match + if name_expr.id.as_str() != self.target_text { + return TraversalSignal::Traverse; } + + let covering_node = CoveringNode::from_ancestors(self.ancestors.clone()); + self.check_reference_from_covering_node(&covering_node); } - ast::Stmt::ClassDef(class) => { - if self.include_declaration { - self.check_reference_at_offset(class.name.range.start(), &class.name.id); - } + AnyNodeRef::StmtFunctionDef(func) if self.include_declaration => { + self.check_identifier_reference(&func.name); } - ast::Stmt::Nonlocal(nonlocal_stmt) => { - for name in &nonlocal_stmt.names { - self.check_reference_at_offset(name.range.start(), &name.id); + AnyNodeRef::StmtClassDef(class) if self.include_declaration => { + self.check_identifier_reference(&class.name); + } + AnyNodeRef::Parameter(parameter) if self.include_declaration => { + self.check_identifier_reference(¶meter.name); + } + AnyNodeRef::Keyword(keyword) => { + if let Some(arg) = &keyword.arg { + self.check_identifier_reference(arg); } } - ast::Stmt::Global(global_stmt) => { + AnyNodeRef::StmtGlobal(global_stmt) if self.include_declaration => { for name in &global_stmt.names { - self.check_reference_at_offset(name.range.start(), &name.id); + self.check_identifier_reference(name); } } - _ => {} - } - ast::visitor::walk_stmt(self, stmt); - } - - fn visit_parameter(&mut self, parameter: &'a ast::Parameter) { - if self.include_declaration { - self.check_reference_at_offset(parameter.name.range.start(), ¶meter.name.id); - } - ast::visitor::walk_parameter(self, parameter); - } - - fn visit_except_handler(&mut self, except_handler: &'a ast::ExceptHandler) { - match except_handler { - ast::ExceptHandler::ExceptHandler(handler) => { + AnyNodeRef::StmtNonlocal(nonlocal_stmt) if self.include_declaration => { + for name in &nonlocal_stmt.names { + self.check_identifier_reference(name); + } + } + AnyNodeRef::ExceptHandlerExceptHandler(handler) if self.include_declaration => { if let Some(name) = &handler.name { - self.check_reference_at_offset(name.range.start(), &name.id); + self.check_identifier_reference(name); } } - } - ast::visitor::walk_except_handler(self, except_handler); - } - - fn visit_pattern(&mut self, pattern: &'a ast::Pattern) { - match pattern { - ast::Pattern::MatchAs(pattern_as) => { + AnyNodeRef::PatternMatchAs(pattern_as) if self.include_declaration => { if let Some(name) = &pattern_as.name { - self.check_reference_at_offset(name.range.start(), &name.id); + self.check_identifier_reference(name); } } - ast::Pattern::MatchMapping(pattern_mapping) => { + AnyNodeRef::PatternMatchMapping(pattern_mapping) if self.include_declaration => { if let Some(rest_name) = &pattern_mapping.rest { - self.check_reference_at_offset(rest_name.range.start(), &rest_name.id); + self.check_identifier_reference(rest_name); } } _ => {} } - ast::visitor::walk_pattern(self, pattern); + + TraversalSignal::Traverse + } + + fn leave_node(&mut self, node: AnyNodeRef<'a>) { + debug_assert_eq!(self.ancestors.last(), Some(&node)); + self.ancestors.pop(); } } impl LocalReferencesFinder<'_> { - /// Determines whether the node at the specified offset is a reference to - /// the symbol we are searching for - fn check_reference_at_offset(&mut self, offset: TextSize, text: &str) { - // Quick text-based check first - compare with our target text - if text != self.target_text { + /// Helper method to check identifier references for declarations + fn check_identifier_reference(&mut self, identifier: &ast::Identifier) { + // Quick text-based check first + if identifier.id != self.target_text { return; } - // Use find_goto_target to get the GotoTarget for this identifier - if let Some(goto_target) = find_goto_target(self.module, offset) { + let mut ancestors_with_identifier = self.ancestors.clone(); + ancestors_with_identifier.push(AnyNodeRef::from(identifier)); + let covering_node = CoveringNode::from_ancestors(ancestors_with_identifier); + self.check_reference_from_covering_node(&covering_node); + } + + /// Determines whether the given covering node is a reference to + /// the symbol we are searching for + fn check_reference_from_covering_node( + &mut self, + covering_node: &crate::find_node::CoveringNode<'_>, + ) { + // Use the start of the covering node as the offset. Any offset within + // the node is fine here. Offsets matter only for import statements + // where the identifier might be a multi-part module name. + let offset = covering_node.node().range().start(); + + if let Some(goto_target) = GotoTarget::from_covering_node(covering_node, offset) { let range = goto_target.range(); // Get the definitions for this goto target From dde2831a19486eebc8498ebdc5fb68a3d8e11d6d Mon Sep 17 00:00:00 2001 From: UnboundVariable Date: Tue, 22 Jul 2025 18:19:56 -0700 Subject: [PATCH 04/10] Implemented multi-file reference searches. --- crates/ty_ide/src/lib.rs | 27 +- crates/ty_ide/src/references.rs | 302 +++++++++++++++++- .../src/server/api/requests/references.rs | 10 +- 3 files changed, 314 insertions(+), 25 deletions(-) diff --git a/crates/ty_ide/src/lib.rs b/crates/ty_ide/src/lib.rs index 455724589aa86..6c93f1570f4c8 100644 --- a/crates/ty_ide/src/lib.rs +++ b/crates/ty_ide/src/lib.rs @@ -244,6 +244,7 @@ mod tests { pub(super) struct CursorTest { pub(super) db: TestDb, pub(super) cursor: Cursor, + pub(super) files: Vec, _insta_settings_guard: SettingsBindDropGuard, } @@ -299,6 +300,8 @@ mod tests { pub(super) fn build(&self) -> CursorTest { let mut db = TestDb::new(); let mut cursor: Option = None; + let mut files = Vec::new(); + for &Source { ref path, ref contents, @@ -307,19 +310,20 @@ mod tests { { db.write_file(path, contents) .expect("write to memory file system to be successful"); - let Some(offset) = cursor_offset else { - continue; - }; let file = system_path_to_file(&db, path).expect("newly written file to existing"); - // This assert should generally never trip, since - // we have an assert on `CursorTestBuilder::source` - // to ensure we never have more than one marker. - assert!( - cursor.is_none(), - "found more than one source that contains ``" - ); - cursor = Some(Cursor { file, offset }); + files.push(file); + + if let Some(offset) = cursor_offset { + // This assert should generally never trip, since + // we have an assert on `CursorTestBuilder::source` + // to ensure we never have more than one marker. + assert!( + cursor.is_none(), + "found more than one source that contains ``" + ); + cursor = Some(Cursor { file, offset }); + } } let search_paths = SearchPathSettings::new(vec![SystemPathBuf::from("/")]) @@ -345,6 +349,7 @@ mod tests { CursorTest { db, cursor: cursor.expect("at least one source to contain ``"), + files, _insta_settings_guard: insta_settings_guard, } } diff --git a/crates/ty_ide/src/references.rs b/crates/ty_ide/src/references.rs index 95d8862c18c3f..05cf06fc0f61e 100644 --- a/crates/ty_ide/src/references.rs +++ b/crates/ty_ide/src/references.rs @@ -20,11 +20,14 @@ use ruff_python_ast::{ use ruff_text_size::{Ranged, TextSize}; /// Find all references to a symbol at the given position. +/// Search for references across all files in `project_files`. +/// To search only within the current file, pass an empty iterator. pub fn references( db: &dyn Db, file: File, offset: TextSize, include_declaration: bool, + project_files: impl IntoIterator, ) -> Option>> { let parsed = ruff_db::parsed::parsed_module(db, file); let module = parsed.load(db); @@ -37,23 +40,43 @@ pub fn references( let target_text = goto_target.to_string()?; // Find all of the references to the symbol within this file - find_local_references( + let mut references = find_local_references( db, file, &target_definitions, include_declaration, &target_text, - ) + )?; + + // Check if the symbol is potentially visible outside of this module + if is_symbol_externally_visible(&goto_target) { + // Look for references in all other files within the workspace + for other_file in project_files { + // Skip the current file as we already processed it + if other_file == file { + continue; + } - // TODO: Determine whether the symbol is potentially visible outside - // of this module. + // First do a simple text search to see if there is a potential match in the file + let source = ruff_db::source::source_text(db, other_file); + if !source.as_str().contains(target_text.as_ref()) { + continue; + } - // TODO: For symbols that are potentially visible outside of this - // module, we need to look for references in all other files within - // the workspace. + // If the target text is found, do the more expensive semantic analysis + if let Some(other_references) = find_local_references( + db, + other_file, + &target_definitions, + false, // Don't include declarations from other files + &target_text, + ) { + references.extend(other_references); + } + } + } - // TODO: Eliminate the need to parse every file by first doing a simple - // text context search to see if there is a potential match in the file. + Some(references) } /// Find all references to a local symbol within the current file. If @@ -88,6 +111,25 @@ fn find_local_references( } } +/// Determines whether a symbol is potentially visible outside of the current module. +fn is_symbol_externally_visible(goto_target: &GotoTarget<'_>) -> bool { + match goto_target { + GotoTarget::Parameter(_) + | GotoTarget::KeywordArgument { .. } + | GotoTarget::ExceptVariable(_) + | GotoTarget::TypeParamTypeVarName(_) + | GotoTarget::TypeParamParamSpecName(_) + | GotoTarget::TypeParamTypeVarTupleName(_) => false, + + // Assume all other goto target types are potentially visible. + + // TODO: For local variables, we should be able to return false + // except in cases where the variable is in the global scope + // or uses a "global" binding. + _ => true, + } +} + /// AST visitor to find all references to a specific symbol by comparing semantic definitions struct LocalReferencesFinder<'a> { db: &'a dyn Db, @@ -235,9 +277,38 @@ mod tests { impl CursorTest { fn references(&self) -> String { - let Some(reference_results) = - references(&self.db, self.cursor.file, self.cursor.offset, true) - else { + let Some(reference_results) = references( + &self.db, + self.cursor.file, + self.cursor.offset, + true, + std::iter::empty(), + ) else { + return "No references found".to_string(); + }; + + if reference_results.is_empty() { + return "No references found".to_string(); + } + + self.render_diagnostics(reference_results.into_iter().enumerate().map( + |(i, ref_item)| -> ReferenceResult { + ReferenceResult { + index: i, + file_range: ref_item.range, + } + }, + )) + } + + fn references_with_project_files(&self, project_files: Vec) -> String { + let Some(reference_results) = references( + &self.db, + self.cursor.file, + self.cursor.offset, + true, + project_files, + ) else { return "No references found".to_string(); }; @@ -876,4 +947,211 @@ cls = MyClass | "); } + + #[test] + fn test_multi_file_function_references() { + let test = CursorTest::builder() + .source( + "utils.py", + " +def helper_function(x): + return x * 2 +", + ) + .source( + "module.py", + " +from utils import helper_function + +def process_data(data): + return helper_function(data) + +def double_process(data): + result = helper_function(data) + return helper_function(result) +", + ) + .source( + "app.py", + " +from utils import helper_function + +class DataProcessor: + def __init__(self): + self.multiplier = helper_function + + def process(self, value): + return helper_function(value) +", + ) + .build(); + + assert_snapshot!(test.references_with_project_files(test.files.clone()), @r" + info[references]: Reference 1 + --> utils.py:2:5 + | + 2 | def helper_function(x): + | ^^^^^^^^^^^^^^^ + 3 | return x * 2 + | + + info[references]: Reference 2 + --> module.py:5:12 + | + 4 | def process_data(data): + 5 | return helper_function(data) + | ^^^^^^^^^^^^^^^ + 6 | + 7 | def double_process(data): + | + + info[references]: Reference 3 + --> module.py:8:14 + | + 7 | def double_process(data): + 8 | result = helper_function(data) + | ^^^^^^^^^^^^^^^ + 9 | return helper_function(result) + | + + info[references]: Reference 4 + --> module.py:9:12 + | + 7 | def double_process(data): + 8 | result = helper_function(data) + 9 | return helper_function(result) + | ^^^^^^^^^^^^^^^ + | + + info[references]: Reference 5 + --> app.py:6:27 + | + 4 | class DataProcessor: + 5 | def __init__(self): + 6 | self.multiplier = helper_function + | ^^^^^^^^^^^^^^^ + 7 | + 8 | def process(self, value): + | + + info[references]: Reference 6 + --> app.py:9:16 + | + 8 | def process(self, value): + 9 | return helper_function(value) + | ^^^^^^^^^^^^^^^ + | + "); + } + + #[test] + #[ignore] // TODO: Enable when attribute references are fully implemented + fn test_multi_file_class_attribute_references() { + let test = CursorTest::builder() + .source( + "models.py", + " +class MyModel: + def __init__(self): + self.special_attribute = 42 + + def get_attribute(self): + return self.attr +", + ) + .source( + "main.py", + " +from models import MyModel + +def process_model(): + model = MyModel() + value = model.attr + model.attr = 100 + return model.attr + +class ModelWrapper: + def __init__(self, model): + self.inner = model + + def get_special(self): + return self.inner.attr + + def set_special(self, value): + self.inner.attr = value +", + ) + .build(); + + assert_snapshot!(test.references_with_project_files(test.files.clone()), @r" + info[references]: Reference 1 + --> models.py:4:14 + | + 3 | def __init__(self): + 4 | self.attr = 42 + | ^^^^ + 5 | + 6 | def get_attribute(self): + | + + info[references]: Reference 2 + --> models.py:7:21 + | + 5 | + 6 | def get_attribute(self): + 7 | return self.attr + | ^^^^ + | + + info[references]: Reference 3 + --> main.py:6:19 + | + 5 | model = MyModel() + 6 | value = model.attr + | ^^^^ + 7 | model.attr = 100 + 8 | return model.attr + | + + info[references]: Reference 4 + --> main.py:7:11 + | + 5 | model = MyModel() + 6 | value = model.attr + 7 | model.attr = 100 + | ^^^^ + 8 | return model.attr + | + + info[references]: Reference 5 + --> main.py:8:18 + | + 6 | value = model.attr + 7 | model.attr = 100 + 8 | return model.attr + | ^^^^ + 9 | + 10 | class ModelWrapper: + | + + info[references]: Reference 6 + --> main.py:15:26 + | + 14 | def get_special(self): + 15 | return self.inner.attr + | ^^^^^ + 16 | + 17 | def set_special(self, value): + | + + info[references]: Reference 7 + --> main.py:18:20 + | + 16 | + 17 | def set_special(self, value): + 18 | self.inner.attr = value + | ^^^^ + | + "); + } } diff --git a/crates/ty_server/src/server/api/requests/references.rs b/crates/ty_server/src/server/api/requests/references.rs index 63c713ac3ed1b..4f6a854d7da6e 100644 --- a/crates/ty_server/src/server/api/requests/references.rs +++ b/crates/ty_server/src/server/api/requests/references.rs @@ -4,7 +4,7 @@ use lsp_types::request::References; use lsp_types::{Location, ReferenceParams, Url}; use ruff_db::source::{line_index, source_text}; use ty_ide::references; -use ty_project::ProjectDatabase; +use ty_project::{Db, ProjectDatabase}; use crate::document::{PositionExt, ToLink}; use crate::server::api::traits::{ @@ -48,7 +48,13 @@ impl BackgroundDocumentRequestHandler for ReferencesRequestHandler { let include_declaration = params.context.include_declaration; - let Some(references_result) = references(db, file, offset, include_declaration) else { + let Some(references_result) = references( + db, + file, + offset, + include_declaration, + &db.project().files(db), + ) else { return Ok(None); }; From 96296e63896394bc8f64b388ff671a51e0ffbff7 Mon Sep 17 00:00:00 2001 From: UnboundVariable Date: Tue, 22 Jul 2025 19:00:25 -0700 Subject: [PATCH 05/10] Fixed bugs in attribute expression handling. --- crates/ty_ide/src/references.rs | 100 ++++++++++++-------------------- 1 file changed, 36 insertions(+), 64 deletions(-) diff --git a/crates/ty_ide/src/references.rs b/crates/ty_ide/src/references.rs index 05cf06fc0f61e..a724eafef1c34 100644 --- a/crates/ty_ide/src/references.rs +++ b/crates/ty_ide/src/references.rs @@ -40,13 +40,15 @@ pub fn references( let target_text = goto_target.to_string()?; // Find all of the references to the symbol within this file - let mut references = find_local_references( + let mut references = Vec::new(); + references_for_file( db, file, &target_definitions, include_declaration, &target_text, - )?; + &mut references, + ); // Check if the symbol is potentially visible outside of this module if is_symbol_externally_visible(&goto_target) { @@ -64,31 +66,35 @@ pub fn references( } // If the target text is found, do the more expensive semantic analysis - if let Some(other_references) = find_local_references( + references_for_file( db, other_file, &target_definitions, false, // Don't include declarations from other files &target_text, - ) { - references.extend(other_references); - } + &mut references, + ); } } - Some(references) + if references.is_empty() { + None + } else { + Some(references) + } } /// Find all references to a local symbol within the current file. If /// `include_declaration` is true, return the original declaration for symbols /// such as functions or classes that have a single declaration location. -fn find_local_references( +fn references_for_file( db: &dyn Db, file: File, target_definitions: &NavigationTargets, include_declaration: bool, target_text: &str, -) -> Option>> { + references: &mut Vec>, +) { let parsed = ruff_db::parsed::parsed_module(db, file); let module = parsed.load(db); @@ -96,19 +102,13 @@ fn find_local_references( db, file, target_definitions, - references: Vec::new(), + references, include_declaration, target_text, ancestors: Vec::new(), }; AnyNodeRef::from(module.syntax()).visit_source_order(&mut finder); - - if finder.references.is_empty() { - None - } else { - Some(finder.references) - } } /// Determines whether a symbol is potentially visible outside of the current module. @@ -135,7 +135,7 @@ struct LocalReferencesFinder<'a> { db: &'a dyn Db, file: File, target_definitions: &'a NavigationTargets, - references: Vec>, + references: &'a mut Vec>, include_declaration: bool, target_text: &'a str, ancestors: Vec>, @@ -155,6 +155,9 @@ impl<'a> SourceOrderVisitor<'a> for LocalReferencesFinder<'a> { let covering_node = CoveringNode::from_ancestors(self.ancestors.clone()); self.check_reference_from_covering_node(&covering_node); } + AnyNodeRef::ExprAttribute(attr_expr) => { + self.check_identifier_reference(&attr_expr.attr); + } AnyNodeRef::StmtFunctionDef(func) if self.include_declaration => { self.check_identifier_reference(&func.name); } @@ -232,7 +235,9 @@ impl LocalReferencesFinder<'_> { let offset = covering_node.node().range().start(); if let Some(goto_target) = GotoTarget::from_covering_node(covering_node, offset) { - let range = goto_target.range(); + // Use the range of the covering node (the identifier) rather than the goto target + // This ensures we highlight just the identifier, not the entire expression + let range = covering_node.node().range(); // Get the definitions for this goto target if let Some(current_definitions) = @@ -1045,18 +1050,16 @@ class DataProcessor: } #[test] - #[ignore] // TODO: Enable when attribute references are fully implemented fn test_multi_file_class_attribute_references() { let test = CursorTest::builder() .source( "models.py", " class MyModel: - def __init__(self): - self.special_attribute = 42 + attr = 42 def get_attribute(self): - return self.attr + return MyModel.attr ", ) .source( @@ -1069,43 +1072,33 @@ def process_model(): value = model.attr model.attr = 100 return model.attr - -class ModelWrapper: - def __init__(self, model): - self.inner = model - - def get_special(self): - return self.inner.attr - - def set_special(self, value): - self.inner.attr = value ", ) .build(); assert_snapshot!(test.references_with_project_files(test.files.clone()), @r" info[references]: Reference 1 - --> models.py:4:14 + --> models.py:3:5 | - 3 | def __init__(self): - 4 | self.attr = 42 - | ^^^^ - 5 | - 6 | def get_attribute(self): + 2 | class MyModel: + 3 | attr = 42 + | ^^^^ + 4 | + 5 | def get_attribute(self): | info[references]: Reference 2 - --> models.py:7:21 + --> models.py:6:24 | - 5 | - 6 | def get_attribute(self): - 7 | return self.attr - | ^^^^ + 5 | def get_attribute(self): + 6 | return MyModel.attr + | ^^^^ | info[references]: Reference 3 --> main.py:6:19 | + 4 | def process_model(): 5 | model = MyModel() 6 | value = model.attr | ^^^^ @@ -1130,28 +1123,7 @@ class ModelWrapper: 7 | model.attr = 100 8 | return model.attr | ^^^^ - 9 | - 10 | class ModelWrapper: | - - info[references]: Reference 6 - --> main.py:15:26 - | - 14 | def get_special(self): - 15 | return self.inner.attr - | ^^^^^ - 16 | - 17 | def set_special(self, value): - | - - info[references]: Reference 7 - --> main.py:18:20 - | - 16 | - 17 | def set_special(self, value): - 18 | self.inner.attr = value - | ^^^^ - | "); } } From ed72129eb250f4906250cb27aebb9f324b422efa Mon Sep 17 00:00:00 2001 From: UnboundVariable Date: Tue, 22 Jul 2025 19:03:10 -0700 Subject: [PATCH 06/10] Fixed bug in handling of keyword arguments. --- crates/ty_ide/src/references.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/ty_ide/src/references.rs b/crates/ty_ide/src/references.rs index a724eafef1c34..8ce8719687f20 100644 --- a/crates/ty_ide/src/references.rs +++ b/crates/ty_ide/src/references.rs @@ -115,7 +115,6 @@ fn references_for_file( fn is_symbol_externally_visible(goto_target: &GotoTarget<'_>) -> bool { match goto_target { GotoTarget::Parameter(_) - | GotoTarget::KeywordArgument { .. } | GotoTarget::ExceptVariable(_) | GotoTarget::TypeParamTypeVarName(_) | GotoTarget::TypeParamParamSpecName(_) From a6c9665fe8b026aefa53227d31bc17177b626e25 Mon Sep 17 00:00:00 2001 From: UnboundVariable Date: Wed, 23 Jul 2025 08:42:59 -0700 Subject: [PATCH 07/10] Fixed compiler error introduced during merge. --- crates/ty_ide/src/goto.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ty_ide/src/goto.rs b/crates/ty_ide/src/goto.rs index 3299b8929d176..0527e7e1cf511 100644 --- a/crates/ty_ide/src/goto.rs +++ b/crates/ty_ide/src/goto.rs @@ -645,7 +645,7 @@ fn resolve_module_to_navigation_target( if let Some(module_name) = ModuleName::new(module_name_str) { if let Some(resolved_module) = resolve_module(db, &module_name) { - if let Some(module_file) = resolved_module.file() { + if let Some(module_file) = resolved_module.file(db) { return Some(crate::NavigationTargets::single( crate::NavigationTarget::new(module_file, TextRange::default()), )); From 1b550aae89ba3ac3986a14905a865840df97c9a4 Mon Sep 17 00:00:00 2001 From: UnboundVariable Date: Wed, 23 Jul 2025 08:43:46 -0700 Subject: [PATCH 08/10] Incorporated code review feedback. --- crates/ty_ide/src/lib.rs | 4 -- crates/ty_ide/src/references.rs | 45 +++---------------- .../src/server/api/requests/references.rs | 10 +---- 3 files changed, 9 insertions(+), 50 deletions(-) diff --git a/crates/ty_ide/src/lib.rs b/crates/ty_ide/src/lib.rs index a278b30d146ce..b171a3acecaec 100644 --- a/crates/ty_ide/src/lib.rs +++ b/crates/ty_ide/src/lib.rs @@ -243,7 +243,6 @@ mod tests { pub(super) struct CursorTest { pub(super) db: ty_project::TestDb, pub(super) cursor: Cursor, - pub(super) files: Vec, _insta_settings_guard: SettingsBindDropGuard, } @@ -303,7 +302,6 @@ mod tests { )); let mut cursor: Option = None; - let mut files = Vec::new(); for &Source { ref path, @@ -315,7 +313,6 @@ mod tests { .expect("write to memory file system to be successful"); let file = system_path_to_file(&db, path).expect("newly written file to existing"); - files.push(file); if let Some(offset) = cursor_offset { // This assert should generally never trip, since @@ -352,7 +349,6 @@ mod tests { CursorTest { db, cursor: cursor.expect("at least one source to contain ``"), - files, _insta_settings_guard: insta_settings_guard, } } diff --git a/crates/ty_ide/src/references.rs b/crates/ty_ide/src/references.rs index 8ce8719687f20..910cb8d248efb 100644 --- a/crates/ty_ide/src/references.rs +++ b/crates/ty_ide/src/references.rs @@ -20,14 +20,12 @@ use ruff_python_ast::{ use ruff_text_size::{Ranged, TextSize}; /// Find all references to a symbol at the given position. -/// Search for references across all files in `project_files`. -/// To search only within the current file, pass an empty iterator. +/// Search for references across all files in the project. pub fn references( db: &dyn Db, file: File, offset: TextSize, include_declaration: bool, - project_files: impl IntoIterator, ) -> Option>> { let parsed = ruff_db::parsed::parsed_module(db, file); let module = parsed.load(db); @@ -53,7 +51,7 @@ pub fn references( // Check if the symbol is potentially visible outside of this module if is_symbol_externally_visible(&goto_target) { // Look for references in all other files within the workspace - for other_file in project_files { + for other_file in &db.project().files(db) { // Skip the current file as we already processed it if other_file == file { continue; @@ -281,38 +279,9 @@ mod tests { impl CursorTest { fn references(&self) -> String { - let Some(reference_results) = references( - &self.db, - self.cursor.file, - self.cursor.offset, - true, - std::iter::empty(), - ) else { - return "No references found".to_string(); - }; - - if reference_results.is_empty() { - return "No references found".to_string(); - } - - self.render_diagnostics(reference_results.into_iter().enumerate().map( - |(i, ref_item)| -> ReferenceResult { - ReferenceResult { - index: i, - file_range: ref_item.range, - } - }, - )) - } - - fn references_with_project_files(&self, project_files: Vec) -> String { - let Some(reference_results) = references( - &self.db, - self.cursor.file, - self.cursor.offset, - true, - project_files, - ) else { + let Some(reference_results) = + references(&self.db, self.cursor.file, self.cursor.offset, true) + else { return "No references found".to_string(); }; @@ -990,7 +959,7 @@ class DataProcessor: ) .build(); - assert_snapshot!(test.references_with_project_files(test.files.clone()), @r" + assert_snapshot!(test.references(), @r" info[references]: Reference 1 --> utils.py:2:5 | @@ -1075,7 +1044,7 @@ def process_model(): ) .build(); - assert_snapshot!(test.references_with_project_files(test.files.clone()), @r" + assert_snapshot!(test.references(), @r" info[references]: Reference 1 --> models.py:3:5 | diff --git a/crates/ty_server/src/server/api/requests/references.rs b/crates/ty_server/src/server/api/requests/references.rs index 4f6a854d7da6e..63c713ac3ed1b 100644 --- a/crates/ty_server/src/server/api/requests/references.rs +++ b/crates/ty_server/src/server/api/requests/references.rs @@ -4,7 +4,7 @@ use lsp_types::request::References; use lsp_types::{Location, ReferenceParams, Url}; use ruff_db::source::{line_index, source_text}; use ty_ide::references; -use ty_project::{Db, ProjectDatabase}; +use ty_project::ProjectDatabase; use crate::document::{PositionExt, ToLink}; use crate::server::api::traits::{ @@ -48,13 +48,7 @@ impl BackgroundDocumentRequestHandler for ReferencesRequestHandler { let include_declaration = params.context.include_declaration; - let Some(references_result) = references( - db, - file, - offset, - include_declaration, - &db.project().files(db), - ) else { + let Some(references_result) = references(db, file, offset, include_declaration) else { return Ok(None); }; From 859ccb2e5d3e6c86807b4cad1de96b05fffbf502 Mon Sep 17 00:00:00 2001 From: UnboundVariable Date: Wed, 23 Jul 2025 08:48:01 -0700 Subject: [PATCH 09/10] Code review feedback. --- crates/ty_ide/src/references.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ty_ide/src/references.rs b/crates/ty_ide/src/references.rs index 910cb8d248efb..b1a713963fa16 100644 --- a/crates/ty_ide/src/references.rs +++ b/crates/ty_ide/src/references.rs @@ -229,7 +229,7 @@ impl LocalReferencesFinder<'_> { // Use the start of the covering node as the offset. Any offset within // the node is fine here. Offsets matter only for import statements // where the identifier might be a multi-part module name. - let offset = covering_node.node().range().start(); + let offset = covering_node.node().start(); if let Some(goto_target) = GotoTarget::from_covering_node(covering_node, offset) { // Use the range of the covering node (the identifier) rather than the goto target From af35936b25fc4632a96bd0a0f13919d1b16e415e Mon Sep 17 00:00:00 2001 From: UnboundVariable Date: Wed, 23 Jul 2025 09:03:23 -0700 Subject: [PATCH 10/10] Fixed server capabilities snapshots. --- .../src/snapshots/ty_server__server__tests__initialization.snap | 1 + .../ty_server__server__tests__initialization_with_workspace.snap | 1 + 2 files changed, 2 insertions(+) diff --git a/crates/ty_server/src/snapshots/ty_server__server__tests__initialization.snap b/crates/ty_server/src/snapshots/ty_server__server__tests__initialization.snap index b7f89f5e2abe3..9c76c557427fe 100644 --- a/crates/ty_server/src/snapshots/ty_server__server__tests__initialization.snap +++ b/crates/ty_server/src/snapshots/ty_server__server__tests__initialization.snap @@ -26,6 +26,7 @@ expression: initialization_result }, "definitionProvider": true, "typeDefinitionProvider": true, + "referencesProvider": true, "declarationProvider": true, "semanticTokensProvider": { "legend": { diff --git a/crates/ty_server/src/snapshots/ty_server__server__tests__initialization_with_workspace.snap b/crates/ty_server/src/snapshots/ty_server__server__tests__initialization_with_workspace.snap index b7f89f5e2abe3..9c76c557427fe 100644 --- a/crates/ty_server/src/snapshots/ty_server__server__tests__initialization_with_workspace.snap +++ b/crates/ty_server/src/snapshots/ty_server__server__tests__initialization_with_workspace.snap @@ -26,6 +26,7 @@ expression: initialization_result }, "definitionProvider": true, "typeDefinitionProvider": true, + "referencesProvider": true, "declarationProvider": true, "semanticTokensProvider": { "legend": {