From 20bf28e03539c1c9d82a808305a224237a65ed54 Mon Sep 17 00:00:00 2001 From: Vinicius Stock Date: Thu, 10 Oct 2024 11:49:34 -0400 Subject: [PATCH] Use code units cache for locating targets --- lib/ruby_lsp/erb_document.rb | 25 +++++---- lib/ruby_lsp/requests/code_action_resolve.rb | 10 +++- lib/ruby_lsp/requests/completion.rb | 2 +- lib/ruby_lsp/requests/definition.rb | 2 +- lib/ruby_lsp/requests/document_highlight.rb | 6 ++- lib/ruby_lsp/requests/hover.rb | 2 +- lib/ruby_lsp/requests/references.rb | 2 +- lib/ruby_lsp/requests/rename.rb | 2 +- lib/ruby_lsp/requests/signature_help.rb | 2 +- lib/ruby_lsp/ruby_document.rb | 55 ++++++++++++-------- 10 files changed, 66 insertions(+), 42 deletions(-) diff --git a/lib/ruby_lsp/erb_document.rb b/lib/ruby_lsp/erb_document.rb index 310040ed0..468f0d5d6 100644 --- a/lib/ruby_lsp/erb_document.rb +++ b/lib/ruby_lsp/erb_document.rb @@ -11,12 +11,24 @@ class ERBDocument < Document sig { returns(String) } attr_reader :host_language_source + sig do + returns(T.any( + T.proc.params(arg0: Integer).returns(Integer), + Prism::CodeUnitsCache, + )) + end + attr_reader :code_units_cache + sig { params(source: String, version: Integer, uri: URI::Generic, encoding: Encoding).void } def initialize(source:, version:, uri:, encoding: Encoding::UTF_8) # This has to be initialized before calling super because we call `parse` in the parent constructor, which # overrides this with the proper virtual host language source @host_language_source = T.let("", String) super + @code_units_cache = T.let(@parse_result.code_units_cache(@encoding), T.any( + T.proc.params(arg0: Integer).returns(Integer), + Prism::CodeUnitsCache, + )) end sig { override.returns(T::Boolean) } @@ -30,6 +42,7 @@ def parse! # Use partial script to avoid syntax errors in ERB files where keywords may be used without the full context in # which they will be evaluated @parse_result = Prism.parse(scanner.ruby, partial_script: true) + @code_units_cache = @parse_result.code_units_cache(@encoding) true end @@ -53,8 +66,8 @@ def locate_node(position, node_types: []) RubyDocument.locate( @parse_result.value, create_scanner.find_char_position(position), + code_units_cache: @code_units_cache, node_types: node_types, - encoding: @encoding, ) end @@ -64,16 +77,6 @@ def inside_host_language?(char_position) char && char != " " end - sig do - returns(T.any( - T.proc.params(arg0: Integer).returns(Integer), - Prism::CodeUnitsCache, - )) - end - def code_units_cache - @parse_result.code_units_cache(@encoding) - end - class ERBScanner extend T::Sig diff --git a/lib/ruby_lsp/requests/code_action_resolve.rb b/lib/ruby_lsp/requests/code_action_resolve.rb index 9d7a342e2..e6bd7b38d 100644 --- a/lib/ruby_lsp/requests/code_action_resolve.rb +++ b/lib/ruby_lsp/requests/code_action_resolve.rb @@ -99,7 +99,13 @@ def refactor_variable # Find the closest statements node, so that we place the refactor in a valid position node_context = RubyDocument - .locate(@document.parse_result.value, start_index, node_types: [Prism::StatementsNode, Prism::BlockNode]) + .locate(@document.parse_result.value, + start_index, + node_types: [ + Prism::StatementsNode, + Prism::BlockNode, + ], + code_units_cache: @document.code_units_cache) closest_statements = node_context.node parent_statements = node_context.parent @@ -196,7 +202,7 @@ def refactor_method @document.parse_result.value, start_index, node_types: [Prism::DefNode], - encoding: @global_state.encoding, + code_units_cache: @document.code_units_cache, ) closest_node = node_context.node return Error::InvalidTargetRange unless closest_node diff --git a/lib/ruby_lsp/requests/completion.rb b/lib/ruby_lsp/requests/completion.rb index 632c7ac30..7dd27a851 100644 --- a/lib/ruby_lsp/requests/completion.rb +++ b/lib/ruby_lsp/requests/completion.rb @@ -57,7 +57,7 @@ def initialize(document, global_state, params, sorbet_level, dispatcher) Prism::InstanceVariableTargetNode, Prism::InstanceVariableWriteNode, ], - encoding: global_state.encoding, + code_units_cache: document.code_units_cache, ) @response_builder = T.let( ResponseBuilders::CollectionResponseBuilder[Interface::CompletionItem].new, diff --git a/lib/ruby_lsp/requests/definition.rb b/lib/ruby_lsp/requests/definition.rb index fe368e4b8..59f3592c3 100644 --- a/lib/ruby_lsp/requests/definition.rb +++ b/lib/ruby_lsp/requests/definition.rb @@ -58,7 +58,7 @@ def initialize(document, global_state, position, dispatcher, sorbet_level) Prism::SuperNode, Prism::ForwardingSuperNode, ], - encoding: global_state.encoding, + code_units_cache: document.code_units_cache, ) target = node_context.node diff --git a/lib/ruby_lsp/requests/document_highlight.rb b/lib/ruby_lsp/requests/document_highlight.rb index 0639054c6..b04575636 100644 --- a/lib/ruby_lsp/requests/document_highlight.rb +++ b/lib/ruby_lsp/requests/document_highlight.rb @@ -28,7 +28,11 @@ def initialize(global_state, document, position, dispatcher) char_position = document.create_scanner.find_char_position(position) delegate_request_if_needed!(global_state, document, char_position) - node_context = RubyDocument.locate(document.parse_result.value, char_position, encoding: global_state.encoding) + node_context = RubyDocument.locate( + document.parse_result.value, + char_position, + code_units_cache: document.code_units_cache, + ) @response_builder = T.let( ResponseBuilders::CollectionResponseBuilder[Interface::DocumentHighlight].new, diff --git a/lib/ruby_lsp/requests/hover.rb b/lib/ruby_lsp/requests/hover.rb index c298c9187..a675f91be 100644 --- a/lib/ruby_lsp/requests/hover.rb +++ b/lib/ruby_lsp/requests/hover.rb @@ -41,7 +41,7 @@ def initialize(document, global_state, position, dispatcher, sorbet_level) document.parse_result.value, char_position, node_types: Listeners::Hover::ALLOWED_TARGETS, - encoding: global_state.encoding, + code_units_cache: document.code_units_cache, ) target = node_context.node parent = node_context.parent diff --git a/lib/ruby_lsp/requests/references.rb b/lib/ruby_lsp/requests/references.rb index 8ace4d2cd..a3c626ad9 100644 --- a/lib/ruby_lsp/requests/references.rb +++ b/lib/ruby_lsp/requests/references.rb @@ -42,7 +42,7 @@ def perform Prism::CallNode, Prism::DefNode, ], - encoding: @global_state.encoding, + code_units_cache: @document.code_units_cache, ) target = node_context.node parent = node_context.parent diff --git a/lib/ruby_lsp/requests/rename.rb b/lib/ruby_lsp/requests/rename.rb index b150935a5..2b92e7402 100644 --- a/lib/ruby_lsp/requests/rename.rb +++ b/lib/ruby_lsp/requests/rename.rb @@ -37,7 +37,7 @@ def perform @document.parse_result.value, char_position, node_types: [Prism::ConstantReadNode, Prism::ConstantPathNode, Prism::ConstantPathTargetNode], - encoding: @global_state.encoding, + code_units_cache: @document.code_units_cache, ) target = node_context.node parent = node_context.parent diff --git a/lib/ruby_lsp/requests/signature_help.rb b/lib/ruby_lsp/requests/signature_help.rb index f5451b71c..ce8709d88 100644 --- a/lib/ruby_lsp/requests/signature_help.rb +++ b/lib/ruby_lsp/requests/signature_help.rb @@ -43,7 +43,7 @@ def initialize(document, global_state, position, context, dispatcher, sorbet_lev document.parse_result.value, char_position, node_types: [Prism::CallNode], - encoding: global_state.encoding, + code_units_cache: document.code_units_cache, ) target = adjust_for_nested_target(node_context.node, node_context.parent, position) diff --git a/lib/ruby_lsp/ruby_document.rb b/lib/ruby_lsp/ruby_document.rb index 03476e234..4cde7e656 100644 --- a/lib/ruby_lsp/ruby_document.rb +++ b/lib/ruby_lsp/ruby_document.rb @@ -25,11 +25,14 @@ class << self params( node: Prism::Node, char_position: Integer, + code_units_cache: T.any( + T.proc.params(arg0: Integer).returns(Integer), + Prism::CodeUnitsCache, + ), node_types: T::Array[T.class_of(Prism::Node)], - encoding: Encoding, ).returns(NodeContext) end - def locate(node, char_position, node_types: [], encoding: Encoding::UTF_8) + def locate(node, char_position, code_units_cache:, node_types: []) queue = T.let(node.child_nodes.compact, T::Array[T.nilable(Prism::Node)]) closest = node parent = T.let(nil, T.nilable(Prism::Node)) @@ -62,8 +65,8 @@ def locate(node, char_position, node_types: [], encoding: Encoding::UTF_8) # Skip if the current node doesn't cover the desired position loc = candidate.location - loc_start_offset = loc.start_code_units_offset(encoding) - loc_end_offset = loc.end_code_units_offset(encoding) + loc_start_offset = loc.cached_start_code_units_offset(code_units_cache) + loc_end_offset = loc.cached_end_code_units_offset(code_units_cache) next unless (loc_start_offset...loc_end_offset).cover?(char_position) # If the node's start character is already past the position, then we should've found the closest node @@ -74,7 +77,7 @@ def locate(node, char_position, node_types: [], encoding: Encoding::UTF_8) # and need to pop the stack previous_level = nesting_nodes.last if previous_level && - (loc_start_offset > previous_level.location.end_code_units_offset(encoding)) + (loc_start_offset > previous_level.location.cached_end_code_units_offset(code_units_cache)) nesting_nodes.pop end @@ -89,10 +92,10 @@ def locate(node, char_position, node_types: [], encoding: Encoding::UTF_8) if candidate.is_a?(Prism::CallNode) arg_loc = candidate.arguments&.location blk_loc = candidate.block&.location - if (arg_loc && (arg_loc.start_code_units_offset(encoding)... - arg_loc.end_code_units_offset(encoding)).cover?(char_position)) || - (blk_loc && (blk_loc.start_code_units_offset(encoding)... - blk_loc.end_code_units_offset(encoding)).cover?(char_position)) + if (arg_loc && (arg_loc.cached_start_code_units_offset(code_units_cache)... + arg_loc.cached_end_code_units_offset(code_units_cache)).cover?(char_position)) || + (blk_loc && (blk_loc.cached_start_code_units_offset(code_units_cache)... + blk_loc.cached_end_code_units_offset(code_units_cache)).cover?(char_position)) call_node = candidate end end @@ -102,8 +105,8 @@ def locate(node, char_position, node_types: [], encoding: Encoding::UTF_8) # If the current node is narrower than or equal to the previous closest node, then it is more precise closest_loc = closest.location - closest_node_start_offset = closest_loc.start_code_units_offset(encoding) - closest_node_end_offset = closest_loc.end_code_units_offset(encoding) + closest_node_start_offset = closest_loc.cached_start_code_units_offset(code_units_cache) + closest_node_end_offset = closest_loc.cached_end_code_units_offset(code_units_cache) if loc_end_offset - loc_start_offset <= closest_node_end_offset - closest_node_start_offset parent = closest closest = candidate @@ -131,12 +134,30 @@ def locate(node, char_position, node_types: [], encoding: Encoding::UTF_8) end end + sig do + returns(T.any( + T.proc.params(arg0: Integer).returns(Integer), + Prism::CodeUnitsCache, + )) + end + attr_reader :code_units_cache + + sig { params(source: String, version: Integer, uri: URI::Generic, encoding: Encoding).void } + def initialize(source:, version:, uri:, encoding: Encoding::UTF_8) + super + @code_units_cache = T.let(@parse_result.code_units_cache(@encoding), T.any( + T.proc.params(arg0: Integer).returns(Integer), + Prism::CodeUnitsCache, + )) + end + sig { override.returns(T::Boolean) } def parse! return false unless @needs_parsing @needs_parsing = false @parse_result = Prism.parse(@source) + @code_units_cache = @parse_result.code_units_cache(@encoding) true end @@ -214,19 +235,9 @@ def locate_node(position, node_types: []) RubyDocument.locate( @parse_result.value, create_scanner.find_char_position(position), + code_units_cache: @code_units_cache, node_types: node_types, - encoding: @encoding, ) end - - sig do - returns(T.any( - T.proc.params(arg0: Integer).returns(Integer), - Prism::CodeUnitsCache, - )) - end - def code_units_cache - @parse_result.code_units_cache(@encoding) - end end end