diff --git a/spec/std/string/utf16_spec.cr b/spec/std/string/utf16_spec.cr index 49141c79ddd2..a092bd2a48a4 100644 --- a/spec/std/string/utf16_spec.cr +++ b/spec/std/string/utf16_spec.cr @@ -78,5 +78,11 @@ describe "String UTF16" do string, pointer = String.from_utf16(pointer) string.should eq("hello\u{d7ff}") end + + it "allows creating from a null-terminated slice" do + String.from_utf16(Slice(UInt16).empty, truncate_at_null: true).should eq("") + String.from_utf16(UInt16.slice(102, 111, 111, 98, 97, 114), truncate_at_null: true).should eq("foobar") + String.from_utf16(UInt16.slice(102, 111, 111, 0, 98, 97, 114), truncate_at_null: true).should eq("foo") + end end end diff --git a/spec/std/string_spec.cr b/spec/std/string_spec.cr index fd85a196aaaa..6a42ec44ded1 100644 --- a/spec/std/string_spec.cr +++ b/spec/std/string_spec.cr @@ -2234,6 +2234,16 @@ describe "String" do it "allows creating from an empty slice" do String.new(Bytes.empty).should eq("") end + + it "allows creating from a non-empty slice" do + String.new(UInt8.slice(102, 111, 111, 0, 98, 97, 114)).should eq("foo\0bar") + end + + it "allows creating from a null-terminated slice" do + String.new(Bytes.empty, truncate_at_null: true).should eq("") + String.new(UInt8.slice(102, 111, 111, 98, 97, 114), truncate_at_null: true).should eq("foobar") + String.new(UInt8.slice(102, 111, 111, 0, 98, 97, 114), truncate_at_null: true).should eq("foo") + end end describe "tr" do diff --git a/src/crystal/pe.cr b/src/crystal/pe.cr index d1b19401ad19..a9f86f56ac4e 100644 --- a/src/crystal/pe.cr +++ b/src/crystal/pe.cr @@ -52,11 +52,11 @@ module Crystal if nt_header.name[0] === '/' # section name is longer than 8 bytes; look up the COFF string table name_buf = nt_header.name.to_slice + 1 - string_offset = String.new(name_buf.to_unsafe, name_buf.index(0) || name_buf.size).to_i + string_offset = String.new(name_buf, truncate_at_null: true).to_i io.seek(@string_table_base + string_offset) name = io.gets('\0', chomp: true).not_nil! else - name = String.new(nt_header.name.to_unsafe, nt_header.name.index(0) || nt_header.name.size) + name = String.new(nt_header.name.to_slice, truncate_at_null: true) end SectionHeader.new(name: name, virtual_offset: nt_header.virtualAddress, offset: nt_header.pointerToRawData, size: nt_header.virtualSize) @@ -84,7 +84,7 @@ module Crystal io.seek(@string_table_base + sym.n.name.long) name = io.gets('\0', chomp: true).not_nil! else - name = String.new(sym.n.shortName.to_slice).rstrip('\0') + name = String.new(sym.n.shortName.to_slice, truncate_at_null: true) end # `@coff_symbols` uses zero-based indices diff --git a/src/crystal/system/win32/library_archive.cr b/src/crystal/system/win32/library_archive.cr index 24c50f3405fa..25b848bf1118 100644 --- a/src/crystal/system/win32/library_archive.cr +++ b/src/crystal/system/win32/library_archive.cr @@ -112,7 +112,7 @@ module Crystal::System::LibraryArchive section_header = uninitialized LibC::IMAGE_SECTION_HEADER return unless io.read_fully?(pointerof(section_header).to_slice(1).to_unsafe_bytes) - name = String.new(section_header.name.to_unsafe, section_header.name.index(0) || section_header.name.size) + name = String.new(section_header.name.to_slice, truncate_at_null: true) next unless name == (msvc? ? ".idata$6" : ".idata$7") if msvc? ? section_header.characteristics.bits_set?(LibC::IMAGE_SCN_CNT_INITIALIZED_DATA) : section_header.pointerToRelocations == 0 diff --git a/src/socket/address.cr b/src/socket/address.cr index d78ba35621fc..ab4e5a00d7d4 100644 --- a/src/socket/address.cr +++ b/src/socket/address.cr @@ -916,7 +916,7 @@ class Socket {% unless flag?(:wasm32) %} protected def initialize(sockaddr : LibC::SockaddrUn*, size) @family = Family::UNIX - @path = String.new(sockaddr.value.sun_path.to_unsafe) + @path = String.new(sockaddr.value.sun_path.to_slice, truncate_at_null: true) @size = size || sizeof(LibC::SockaddrUn) end {% end %} @@ -933,7 +933,7 @@ class Socket {% else %} sockaddr = Pointer(LibC::SockaddrUn).malloc sockaddr.value.sun_family = family - sockaddr.value.sun_path.to_unsafe.copy_from(@path.to_unsafe, @path.bytesize + 1) + sockaddr.value.sun_path.to_unsafe.copy_from(@path.to_unsafe, {@path.bytesize + 1, sockaddr.value.sun_path.size}.min) sockaddr.as(LibC::Sockaddr*) {% end %} end diff --git a/src/string.cr b/src/string.cr index 19337fc7efc9..e642b2d21a94 100644 --- a/src/string.cr +++ b/src/string.cr @@ -159,12 +159,24 @@ class String # This method is always safe to call, and the resulting string will have # the contents and size of the slice. # + # If *truncate_at_null* is true, only the characters up to and not including + # the first null character are copied. + # # ``` # slice = Slice.new(4) { |i| ('a'.ord + i).to_u8 } # String.new(slice) # => "abcd" + # + # slice = UInt8.slice(102, 111, 111, 0, 98, 97, 114) + # String.new(slice, truncate_at_null: true) # => "foo" # ``` - def self.new(slice : Bytes) - new(slice.to_unsafe, slice.size) + def self.new(slice : Bytes, *, truncate_at_null : Bool = false) + bytesize = slice.size + if truncate_at_null + if index = slice.index(0) + bytesize = index + end + end + new(slice.to_unsafe, bytesize) end # Creates a new `String` from the given *bytes*, which are encoded in the given *encoding*. diff --git a/src/string/utf16.cr b/src/string/utf16.cr index 697c1b585a37..6f85a5c916d9 100644 --- a/src/string/utf16.cr +++ b/src/string/utf16.cr @@ -48,21 +48,27 @@ class String # Invalid values are encoded using the unicode replacement char with # codepoint `0xfffd`. # + # If *truncate_at_null* is true, only the characters up to and not including + # the first null character are copied. + # # ``` # slice = Slice[104_u16, 105_u16, 32_u16, 55296_u16, 56485_u16] # String.from_utf16(slice) # => "hi 𐂥" + # + # slice = UInt16.slice(102, 111, 111, 0, 98, 97, 114) + # String.from_utf16(slice, truncate_at_null: true) # => "foo" # ``` - def self.from_utf16(slice : Slice(UInt16)) : String + def self.from_utf16(slice : Slice(UInt16), *, truncate_at_null : Bool = false) : String bytesize = 0 size = 0 - each_utf16_char(slice) do |char| + each_utf16_char(slice, truncate_at_null: truncate_at_null) do |char| bytesize += char.bytesize size += 1 end String.new(bytesize) do |buffer| - each_utf16_char(slice) do |char| + each_utf16_char(slice, truncate_at_null: truncate_at_null) do |char| char.each_byte do |byte| buffer.value = byte buffer += 1 @@ -112,10 +118,11 @@ class String # :nodoc: # # Yields each decoded char in the given slice. - def self.each_utf16_char(slice : Slice(UInt16), &) + def self.each_utf16_char(slice : Slice(UInt16), *, truncate_at_null : Bool = false, &) i = 0 while i < slice.size byte = slice[i].to_i + break if truncate_at_null && byte == 0 if byte < 0xd800 || byte >= 0xe000 # One byte codepoint = byte