Skip to content

Commit

Permalink
Prism::CodeUnitsCache
Browse files Browse the repository at this point in the history
Calculating code unit offsets for a source can be very expensive,
especially when the source is large. This commit introduces a new
class that wraps the source and desired encoding into a cache that
reuses pre-computed offsets. It performs quite a bit better.

There are still some problems with this approach, namely character
boundaries and the fact that the cache is unbounded, but both of
these may be addressed in subsequent commits.
  • Loading branch information
kddnewton committed Oct 10, 2024
1 parent d6e9b8d commit 0056890
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 0 deletions.
112 changes: 112 additions & 0 deletions lib/prism/parse_result.rb
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def code_units_offset(byte_offset, encoding)
end
end

# Generate a cache that targets a specific encoding for calculating code
# unit offsets.
def code_units_cache(encoding)
CodeUnitsCache.new(source, encoding)
end

# Returns the column number in code units for the given encoding for the
# given byte offset.
def code_units_column(byte_offset, encoding)
Expand Down Expand Up @@ -149,6 +155,76 @@ def find_line(byte_offset)
end
end

# A cache that can be used to quickly compute code unit offsets from byte
# offsets. It purposefully provides only a single #[] method to access the
# cache in order to minimize surface area.
#
# Note that there are some known issues here that may or may not be addressed
# in the future:
#
# * The first is that there are issues when the cache computes values that are
# not on character boundaries. This can result in subsequent computations
# being off by one or more code units.
# * The second is that this cache is currently unbounded. In theory we could
# introduce some kind of LRU cache to limit the number of entries, but this
# has not yet been implemented.
#
class CodeUnitsCache
class UTF16Counter # :nodoc:
def initialize(source, encoding)
@source = source
@encoding = encoding
end

def count(byte_offset, byte_length)
@source.byteslice(byte_offset, byte_length).encode(@encoding, invalid: :replace, undef: :replace).bytesize / 2
end
end

class LengthCounter # :nodoc:
def initialize(source, encoding)
@source = source
@encoding = encoding
end

def count(byte_offset, byte_length)
@source.byteslice(byte_offset, byte_length).encode(@encoding, invalid: :replace, undef: :replace).length
end
end

private_constant :UTF16Counter, :LengthCounter

# Initialize a new cache with the given source and encoding.
def initialize(source, encoding)
@source = source
@counter =
if encoding == Encoding::UTF_16LE || encoding == Encoding::UTF_16BE
UTF16Counter.new(source, encoding)
else
LengthCounter.new(source, encoding)
end

@cache = {}
@offsets = []
end

# Retrieve the code units offset from the given byte offset.
def [](byte_offset)
@cache[byte_offset] ||=
if (index = @offsets.bsearch_index { |offset| offset > byte_offset }).nil?
@offsets << byte_offset
@counter.count(0, byte_offset)
elsif index == 0
@offsets.unshift(byte_offset)
@counter.count(0, byte_offset)
else
@offsets.insert(index, byte_offset)
offset = @offsets[index - 1]
@cache[offset] + @counter.count(offset, byte_offset - offset)
end
end
end

# Specialized version of Prism::Source for source code that includes ASCII
# characters only. This class is used to apply performance optimizations that
# cannot be applied to sources that include multibyte characters.
Expand Down Expand Up @@ -178,6 +254,13 @@ def code_units_offset(byte_offset, encoding)
byte_offset
end

# Returns a cache that is the identity function in order to maintain the
# same interface. We can do this because code units are always equivalent to
# byte offsets for ASCII-only sources.
def code_units_cache(encoding)
->(byte_offset) { byte_offset }
end

# Specialized version of `code_units_column` that does not depend on
# `code_units_offset`, which is a more expensive operation. This is
# essentially the same as `Prism::Source#column`.
Expand Down Expand Up @@ -287,6 +370,12 @@ def start_code_units_offset(encoding = Encoding::UTF_16LE)
source.code_units_offset(start_offset, encoding)
end

# The start offset from the start of the file in code units using the given
# cache to fetch or calculate the value.
def cached_start_code_units_offset(cache)
cache[start_offset]
end

# The byte offset from the beginning of the source where this location ends.
def end_offset
start_offset + length
Expand All @@ -303,6 +392,12 @@ def end_code_units_offset(encoding = Encoding::UTF_16LE)
source.code_units_offset(end_offset, encoding)
end

# The end offset from the start of the file in code units using the given
# cache to fetch or calculate the value.
def cached_end_code_units_offset(cache)
cache[end_offset]
end

# The line number where this location starts.
def start_line
source.line(start_offset)
Expand Down Expand Up @@ -337,6 +432,12 @@ def start_code_units_column(encoding = Encoding::UTF_16LE)
source.code_units_column(start_offset, encoding)
end

# The start column in code units using the given cache to fetch or calculate
# the value.
def cached_start_code_units_column(cache)
cache[start_offset] - cache[source.line_start(start_offset)]
end

# The column number in bytes where this location ends from the start of the
# line.
def end_column
Expand All @@ -355,6 +456,12 @@ def end_code_units_column(encoding = Encoding::UTF_16LE)
source.code_units_column(end_offset, encoding)
end

# The end column in code units using the given cache to fetch or calculate
# the value.
def cached_end_code_units_column(cache)
cache[end_offset] - cache[source.line_start(end_offset)]
end

# Implement the hash pattern matching interface for Location.
def deconstruct_keys(keys)
{ start_offset: start_offset, end_offset: end_offset }
Expand Down Expand Up @@ -604,6 +711,11 @@ def success?
def failure?
!success?
end

# Create a code units cache for the given encoding.
def code_units_cache(encoding)
source.code_units_cache(encoding)
end
end

# This is a result specific to the `parse` and `parse_file` methods.
Expand Down
29 changes: 29 additions & 0 deletions rbi/prism/parse_result.rbi
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,21 @@ class Prism::Source
sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) }
def code_units_offset(byte_offset, encoding); end

sig { params(encoding: Encoding).returns(T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))) }
def code_units_cache(encoding); end

sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) }
def code_units_column(byte_offset, encoding); end
end

class Prism::CodeUnitsCache
sig { params(source: Source, encoding: Encoding).void }
def initialize(source, encoding); end

sig { params(byte_offset: Integer).returns(Integer) }
def [](byte_offset); end
end

class Prism::ASCIISource < Prism::Source
sig { params(byte_offset: Integer).returns(Integer) }
def character_offset(byte_offset); end
Expand All @@ -54,6 +65,9 @@ class Prism::ASCIISource < Prism::Source
sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) }
def code_units_offset(byte_offset, encoding); end

sig { params(encoding: Encoding).returns(T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))) }
def code_units_cache(encoding); end

sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) }
def code_units_column(byte_offset, encoding); end
end
Expand Down Expand Up @@ -107,6 +121,9 @@ class Prism::Location
sig { params(encoding: Encoding).returns(Integer) }
def start_code_units_offset(encoding = Encoding::UTF_16LE); end

sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) }
def cached_start_code_units_offset(cache); end

sig { returns(Integer) }
def end_offset; end

Expand All @@ -116,6 +133,9 @@ class Prism::Location
sig { params(encoding: Encoding).returns(Integer) }
def end_code_units_offset(encoding = Encoding::UTF_16LE); end

sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) }
def cached_end_code_units_offset(cache); end

sig { returns(Integer) }
def start_line; end

Expand All @@ -134,6 +154,9 @@ class Prism::Location
sig { params(encoding: Encoding).returns(Integer) }
def start_code_units_column(encoding = Encoding::UTF_16LE); end

sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) }
def cached_start_code_units_column(cache); end

sig { returns(Integer) }
def end_column; end

Expand All @@ -143,6 +166,9 @@ class Prism::Location
sig { params(encoding: Encoding).returns(Integer) }
def end_code_units_column(encoding = Encoding::UTF_16LE); end

sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) }
def cached_end_code_units_column(cache); end

sig { params(keys: T.nilable(T::Array[Symbol])).returns(T::Hash[Symbol, T.untyped]) }
def deconstruct_keys(keys); end

Expand Down Expand Up @@ -296,6 +322,9 @@ class Prism::Result

sig { returns(T::Boolean) }
def failure?; end

sig { params(encoding: Encoding).returns(T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))) }
def code_units_cache(encoding); end
end

class Prism::ParseResult < Prism::Result
Expand Down
12 changes: 12 additions & 0 deletions sig/prism/_private/parse_result.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@ module Prism
def find_line: (Integer) -> Integer
end

class CodeUnitsCache
class UTF16Counter
def initialize: (String source, Encoding encoding) -> void
def count: (Integer byte_offset, Integer byte_length) -> Integer
end

class LengthCounter
def initialize: (String source, Encoding encoding) -> void
def count: (Integer byte_offset, Integer byte_length) -> Integer
end
end

class Location
private

Expand Down
20 changes: 20 additions & 0 deletions sig/prism/parse_result.rbs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
module Prism
interface _CodeUnitsCache
def []: (Integer byte_offset) -> Integer
end

class Source
attr_reader source: String
attr_reader start_line: Integer
Expand All @@ -16,15 +20,22 @@ module Prism
def character_offset: (Integer byte_offset) -> Integer
def character_column: (Integer byte_offset) -> Integer
def code_units_offset: (Integer byte_offset, Encoding encoding) -> Integer
def code_units_cache: (Encoding encoding) -> _CodeUnitsCache
def code_units_column: (Integer byte_offset, Encoding encoding) -> Integer

def self.for: (String source) -> Source
end

class CodeUnitsCache
def initialize: (String source, Encoding encoding) -> void
def []: (Integer byte_offset) -> Integer
end

class ASCIISource < Source
def character_offset: (Integer byte_offset) -> Integer
def character_column: (Integer byte_offset) -> Integer
def code_units_offset: (Integer byte_offset, Encoding encoding) -> Integer
def code_units_cache: (Encoding encoding) -> _CodeUnitsCache
def code_units_column: (Integer byte_offset, Encoding encoding) -> Integer
end

Expand All @@ -45,15 +56,23 @@ module Prism
def slice: () -> String
def slice_lines: () -> String
def start_character_offset: () -> Integer
def start_code_units_offset: (Encoding encoding) -> Integer
def cached_start_code_units_offset: (_CodeUnitsCache cache) -> Integer
def end_offset: () -> Integer
def end_character_offset: () -> Integer
def end_code_units_offset: (Encoding encoding) -> Integer
def cached_end_code_units_offset: (_CodeUnitsCache cache) -> Integer
def start_line: () -> Integer
def start_line_slice: () -> String
def end_line: () -> Integer
def start_column: () -> Integer
def start_character_column: () -> Integer
def start_code_units_column: (Encoding encoding) -> Integer
def cached_start_code_units_column: (_CodeUnitsCache cache) -> Integer
def end_column: () -> Integer
def end_character_column: () -> Integer
def end_code_units_column: (Encoding encoding) -> Integer
def cached_end_code_units_column: (_CodeUnitsCache cache) -> Integer
def deconstruct_keys: (Array[Symbol]? keys) -> Hash[Symbol, untyped]
def pretty_print: (untyped q) -> untyped
def join: (Location other) -> Location
Expand Down Expand Up @@ -125,6 +144,7 @@ module Prism
def deconstruct_keys: (Array[Symbol]? keys) -> Hash[Symbol, untyped]
def success?: () -> bool
def failure?: () -> bool
def code_units_cache: (Encoding encoding) -> _CodeUnitsCache
end

class ParseResult < Result
Expand Down
46 changes: 46 additions & 0 deletions test/prism/ruby/location_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,52 @@ def test_code_units
assert_equal 7, location.end_code_units_column(Encoding::UTF_32LE)
end

def test_cached_code_units
result = Prism.parse("πŸ˜€ + πŸ˜€\n😍 ||= 😍")

utf8_cache = result.code_units_cache(Encoding::UTF_8)
utf16_cache = result.code_units_cache(Encoding::UTF_16LE)
utf32_cache = result.code_units_cache(Encoding::UTF_32LE)

# first πŸ˜€
location = result.value.statements.body.first.receiver.location

assert_equal 0, location.cached_start_code_units_offset(utf8_cache)
assert_equal 0, location.cached_start_code_units_offset(utf16_cache)
assert_equal 0, location.cached_start_code_units_offset(utf32_cache)

assert_equal 1, location.cached_end_code_units_offset(utf8_cache)
assert_equal 2, location.cached_end_code_units_offset(utf16_cache)
assert_equal 1, location.cached_end_code_units_offset(utf32_cache)

assert_equal 0, location.cached_start_code_units_column(utf8_cache)
assert_equal 0, location.cached_start_code_units_column(utf16_cache)
assert_equal 0, location.cached_start_code_units_column(utf32_cache)

assert_equal 1, location.cached_end_code_units_column(utf8_cache)
assert_equal 2, location.cached_end_code_units_column(utf16_cache)
assert_equal 1, location.cached_end_code_units_column(utf32_cache)

# second πŸ˜€
location = result.value.statements.body.first.arguments.arguments.first.location

assert_equal 4, location.cached_start_code_units_offset(utf8_cache)
assert_equal 5, location.cached_start_code_units_offset(utf16_cache)
assert_equal 4, location.cached_start_code_units_offset(utf32_cache)

assert_equal 5, location.cached_end_code_units_offset(utf8_cache)
assert_equal 7, location.cached_end_code_units_offset(utf16_cache)
assert_equal 5, location.cached_end_code_units_offset(utf32_cache)

assert_equal 4, location.cached_start_code_units_column(utf8_cache)
assert_equal 5, location.cached_start_code_units_column(utf16_cache)
assert_equal 4, location.cached_start_code_units_column(utf32_cache)

assert_equal 5, location.cached_end_code_units_column(utf8_cache)
assert_equal 7, location.cached_end_code_units_column(utf16_cache)
assert_equal 5, location.cached_end_code_units_column(utf32_cache)
end

def test_code_units_binary_valid_utf8
program = Prism.parse(<<~RUBY).value
# -*- encoding: binary -*-
Expand Down

0 comments on commit 0056890

Please sign in to comment.