From bf00034493381f6ae2b49f04e101b3251bb099c9 Mon Sep 17 00:00:00 2001 From: Joshua Haberman Date: Tue, 23 Jan 2024 20:29:09 -0800 Subject: [PATCH] Breaking Change: Made text_format output default to UTF-8. Also hardened the text format printer against invalid UTF-8 in string fields. The output string will always be valid UTF-8, even if string fields contain invalid UTF-8. PiperOrigin-RevId: 600990001 --- .../protobuf/internal/text_encoding_test.py | 10 +-- .../protobuf/internal/text_format_test.py | 52 +++++++++++--- python/google/protobuf/text_encoding.py | 70 ++++++++++++------- python/google/protobuf/text_format.py | 14 ++-- 4 files changed, 100 insertions(+), 46 deletions(-) diff --git a/python/google/protobuf/internal/text_encoding_test.py b/python/google/protobuf/internal/text_encoding_test.py index a0253b2ce0c4..822ae4d3be5d 100755 --- a/python/google/protobuf/internal/text_encoding_test.py +++ b/python/google/protobuf/internal/text_encoding_test.py @@ -22,17 +22,17 @@ "signi\\\\fying\\\\ nothing\\\\", b"signi\\fying\\ nothing\\"), ("\\010\\t\\n\\013\\014\\r", - "\x08\\t\\n\x0b\x0c\\r", + "\\010\\t\\n\\013\\014\\r", b"\010\011\012\013\014\015")] class TextEncodingTestCase(unittest.TestCase): def testCEscape(self): for escaped, escaped_utf8, unescaped in TEST_VALUES: - self.assertEqual(escaped, - text_encoding.CEscape(unescaped, as_utf8=False)) - self.assertEqual(escaped_utf8, - text_encoding.CEscape(unescaped, as_utf8=True)) + self.assertEqual(escaped, text_encoding.CEscape(unescaped, as_utf8=False)) + self.assertEqual( + escaped_utf8, text_encoding.CEscape(unescaped, as_utf8=True) + ) def testCUnescape(self): for escaped, escaped_utf8, unescaped in TEST_VALUES: diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 75f3b5b35ac8..d5772d3d3e30 100644 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -86,7 +86,9 @@ def 
testPrintExotic(self, message_module): message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"') message.repeated_string.append(u'\u00fc\ua71f') self.CompareToGoldenText( - self.RemoveRedundantZeros(text_format.MessageToString(message)), + self.RemoveRedundantZeros( + text_format.MessageToString(message, as_utf8=True) + ), 'repeated_int64: -9223372036854775808\n' 'repeated_uint64: 18446744073709551615\n' 'repeated_double: 123.456\n' @@ -94,7 +96,8 @@ def testPrintExotic(self, message_module): 'repeated_double: 1.23e-18\n' 'repeated_string:' ' "\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n' - 'repeated_string: "\\303\\274\\352\\234\\237"\n') + 'repeated_string: "üꜟ"\n', + ) def testPrintFloatPrecision(self, message_module): message = message_module.TestAllTypes() @@ -204,8 +207,8 @@ class UnicodeSub(str): message = message_module.TestAllTypes() message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f')) self.CompareToGoldenText( - text_format.MessageToString(message), - 'repeated_string: "\\303\\274\\352\\234\\237"\n') + text_format.MessageToString(message, as_utf8=True), + 'repeated_string: "üꜟ"\n') def testPrintNestedMessageAsOneLine(self, message_module): message = message_module.TestAllTypes() @@ -282,7 +285,7 @@ def testPrintExoticAsOneLine(self, message_module): message.repeated_string.append(u'\u00fc\ua71f') self.CompareToGoldenText( self.RemoveRedundantZeros(text_format.MessageToString( - message, as_one_line=True)), + message, as_one_line=True, as_utf8=True)), 'repeated_int64: -9223372036854775808' ' repeated_uint64: 18446744073709551615' ' repeated_double: 123.456' @@ -290,7 +293,7 @@ def testPrintExoticAsOneLine(self, message_module): ' repeated_double: 1.23e-18' ' repeated_string: ' '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""' - ' repeated_string: "\\303\\274\\352\\234\\237"') + ' repeated_string: "üꜟ"') def testRoundTripExoticAsOneLine(self, message_module): message = message_module.TestAllTypes() @@ -616,8 +619,8 @@ def 
testMessageToBytes(self, message_module): def testRawUtf8RoundTrip(self, message_module): message = message_module.TestAllTypes() message.repeated_string.append(u'\u00fc\t\ua71f') - utf8_text = text_format.MessageToBytes(message, as_utf8=True) - golden_bytes = b'repeated_string: "\xc3\xbc\\t\xea\x9c\x9f"\n' + utf8_text = text_format.MessageToBytes(message, as_utf8=False) + golden_bytes = b'repeated_string: "\\303\\274\\t\\352\\234\\237"\n' self.CompareToGoldenText(utf8_text, golden_bytes) parsed_message = message_module.TestAllTypes() text_format.Parse(utf8_text, parsed_message) @@ -626,10 +629,41 @@ def testRawUtf8RoundTrip(self, message_module): (message, parsed_message, message.repeated_string[0], parsed_message.repeated_string[0])) + def testRawUtf8RoundTripAsUtf8(self, message_module): + message = message_module.TestAllTypes() + message.repeated_string.append(u'\u00fc\t\ua71f') + utf8_text = text_format.MessageToString(message, as_utf8=True) + parsed_message = message_module.TestAllTypes() + text_format.Parse(utf8_text, parsed_message) + self.assertEqual( + message, parsed_message, '\n%s != %s (%s != %s)' % + (message, parsed_message, message.repeated_string[0], + parsed_message.repeated_string[0])) + + # We can only test this case under proto2, because proto3 will reject invalid + # UTF-8 in the parser, so there should be no way of creating a string field + # that contains invalid UTF-8. + # + # We also can't test it in pure-Python, which validates all string fields for + # UTF-8 even when the spec says it shouldn't. 
+ @unittest.skipIf(api_implementation.Type() == 'python', + 'Python can\'t create invalid UTF-8 strings') + def testInvalidUtf8RoundTrip(self, message_module): + if message_module is not unittest_pb2: + return + one_bytes = unittest_pb2.OneBytes() + one_bytes.data = b'ABC\xff123' + one_string = unittest_pb2.OneString() + one_string.ParseFromString(one_bytes.SerializeToString()) + self.assertIn( + 'data: "ABC\\377123"', + text_format.MessageToString(one_string, as_utf8=True), + ) + def testEscapedUtf8ASCIIRoundTrip(self, message_module): message = message_module.TestAllTypes() message.repeated_string.append(u'\u00fc\t\ua71f') - ascii_text = text_format.MessageToBytes(message) # as_utf8=False default + ascii_text = text_format.MessageToBytes(message, as_utf8=False) golden_bytes = b'repeated_string: "\\303\\274\\t\\352\\234\\237"\n' self.CompareToGoldenText(ascii_text, golden_bytes) parsed_message = message_module.TestAllTypes() diff --git a/python/google/protobuf/text_encoding.py b/python/google/protobuf/text_encoding.py index 112e4ab105a6..03c27dc104b2 100644 --- a/python/google/protobuf/text_encoding.py +++ b/python/google/protobuf/text_encoding.py @@ -8,26 +8,42 @@ """Encoding related utilities.""" import re -_cescape_chr_to_symbol_map = {} -_cescape_chr_to_symbol_map[9] = r'\t' # optional escape -_cescape_chr_to_symbol_map[10] = r'\n' # optional escape -_cescape_chr_to_symbol_map[13] = r'\r' # optional escape -_cescape_chr_to_symbol_map[34] = r'\"' # necessary escape -_cescape_chr_to_symbol_map[39] = r"\'" # optional escape -_cescape_chr_to_symbol_map[92] = r'\\' # necessary escape - -# Lookup table for unicode -_cescape_unicode_to_str = [chr(i) for i in range(0, 256)] -for byte, string in _cescape_chr_to_symbol_map.items(): - _cescape_unicode_to_str[byte] = string - -# Lookup table for non-utf8, with necessary escapes at (o >= 127 or o < 32) -_cescape_byte_to_str = ([r'\%03o' % i for i in range(0, 32)] + - [chr(i) for i in range(32, 127)] + - [r'\%03o' % i for i 
in range(127, 256)]) -for byte, string in _cescape_chr_to_symbol_map.items(): - _cescape_byte_to_str[byte] = string -del byte, string +def _AsciiIsPrint(i): + return i >= 32 and i < 127 + +def _MakeStrEscapes(): + ret = {} + for i in range(0, 128): + if not _AsciiIsPrint(i): + ret[i] = r'\%03o' % i + ret[ord('\t')] = r'\t' # optional escape + ret[ord('\n')] = r'\n' # optional escape + ret[ord('\r')] = r'\r' # optional escape + ret[ord('"')] = r'\"' # necessary escape + ret[ord('\'')] = r"\'" # optional escape + ret[ord('\\')] = r'\\' # necessary escape + return ret + +# Maps int -> char, performing string escapes. +_str_escapes = _MakeStrEscapes() + +# Maps int -> char, performing byte escaping and string escapes +_byte_escapes = {i: chr(i) for i in range(0, 256)} +_byte_escapes.update(_str_escapes) +_byte_escapes.update({i: r'\%03o' % i for i in range(128, 256)}) + + +def _DecodeUtf8EscapeErrors(text_bytes): + ret = '' + while text_bytes: + try: + ret += text_bytes.decode('utf-8').translate(_str_escapes) + text_bytes = '' + except UnicodeDecodeError as e: + ret += text_bytes[:e.start].decode('utf-8').translate(_str_escapes) + ret += _byte_escapes[text_bytes[e.start]] + text_bytes = text_bytes[e.start+1:] + return ret def CEscape(text, as_utf8) -> str: @@ -47,13 +63,15 @@ def CEscape(text, as_utf8) -> str: # length. So, "\0011".encode('string_escape') ends up being "\\x011", which # will be decoded in C++ as a single-character string with char code 0x11. text_is_unicode = isinstance(text, str) - if as_utf8 and text_is_unicode: - # We're already unicode, no processing beyond control char escapes. - return text.translate(_cescape_chr_to_symbol_map) - ord_ = ord if text_is_unicode else lambda x: x # bytes iterate as ints. 
if as_utf8: - return ''.join(_cescape_unicode_to_str[ord_(c)] for c in text) - return ''.join(_cescape_byte_to_str[ord_(c)] for c in text) + if text_is_unicode: + return text.translate(_str_escapes) + else: + return _DecodeUtf8EscapeErrors(text) + else: + if text_is_unicode: + text = text.encode('utf-8') + return ''.join([_byte_escapes[c] for c in text]) _CUNESCAPE_HEX = re.compile(r'(\\+)x([0-9a-fA-F])(?![0-9a-fA-F])') diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index b448f660f512..4f6b94a9510c 100644 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -46,6 +46,8 @@ _ANY_FULL_TYPE_NAME = 'google.protobuf.Any' _DEBUG_STRING_SILENT_MARKER = '\t ' +_as_utf8_default = True + class Error(Exception): """Top-level module error for text_format.""" @@ -91,7 +93,7 @@ def getvalue(self): def MessageToString( message, - as_utf8=False, + as_utf8=_as_utf8_default, as_one_line=False, use_short_repeated_primitives=False, pointy_brackets=False, @@ -186,7 +188,7 @@ def _IsMapEntry(field): def PrintMessage(message, out, indent=0, - as_utf8=False, + as_utf8=_as_utf8_default, as_one_line=False, use_short_repeated_primitives=False, pointy_brackets=False, @@ -229,7 +231,7 @@ def PrintMessage(message, the field is a proto message. 
""" printer = _Printer( - out=out, indent=indent, as_utf8=as_utf8, + out=out, indent=indent, as_utf8=as_utf8, as_one_line=as_one_line, use_short_repeated_primitives=use_short_repeated_primitives, pointy_brackets=pointy_brackets, @@ -248,7 +250,7 @@ def PrintField(field, value, out, indent=0, - as_utf8=False, + as_utf8=_as_utf8_default, as_one_line=False, use_short_repeated_primitives=False, pointy_brackets=False, @@ -272,7 +274,7 @@ def PrintFieldValue(field, value, out, indent=0, - as_utf8=False, + as_utf8=_as_utf8_default, as_one_line=False, use_short_repeated_primitives=False, pointy_brackets=False, @@ -328,7 +330,7 @@ def __init__( self, out, indent=0, - as_utf8=False, + as_utf8=_as_utf8_default, as_one_line=False, use_short_repeated_primitives=False, pointy_brackets=False,